Files
ieemoo-ai-contrast/tools/getpairs.py
2025-07-02 14:41:12 +08:00

199 lines
6.7 KiB
Python

import os
import random
import json
from pathlib import Path
from typing import List, Tuple, Dict, Optional
import logging
class PairGenerator:
    """Generate positive and negative image pairs for contrastive learning.

    The dataset is a root directory whose immediate subfolders each hold the
    files of one class. Positive pairs (label 1) are drawn from inside one
    subfolder; negative pairs (label 0) combine entries of two different
    subfolders.
    """

    # Consecutive duplicate draws tolerated while sampling negative pairs
    # before giving up; keeps the rejection-sampling loop from spinning forever.
    _MAX_CONSECUTIVE_MISSES = 10_000

    def __init__(self):
        self._setup_logging()

    def _setup_logging(self):
        """Configure logging settings and create ``self.logger``."""
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s'
        )
        self.logger = logging.getLogger(__name__)

    @staticmethod
    def _as_key(item):
        """Return a hashable stand-in for *item* (a list becomes a tuple)."""
        return tuple(item) if isinstance(item, list) else item

    def _get_image_files(self, root_dir: str) -> Dict[str, List[str]]:
        """Scan *root_dir* and return ``{subfolder: [file paths]}``.

        Every regular file directly inside a subfolder is listed; no filtering
        by file extension is performed. Folders and files are sorted so that
        results are reproducible when the ``random`` module is seeded.

        Raises:
            ValueError: If *root_dir* is not an existing directory.
        """
        root = Path(root_dir)
        if not root.is_dir():
            raise ValueError(f"Invalid directory: {root_dir}")
        return {
            str(folder): sorted(str(f) for f in folder.iterdir() if f.is_file())
            for folder in sorted(root.iterdir())
            if folder.is_dir()
        }

    def _generate_same_pairs(
        self,
        files_dict: Dict[str, list],
        num_pairs: int,
        min_size: int,  # despite the name: cap on random pairs taken per folder
        group_size: Optional[int] = None
    ) -> List[tuple]:
        """Generate positive pairs (label 1) from entries of the same folder.

        Args:
            files_dict: Mapping of folder to its entries (file paths, or
                groups of file paths when called from ``get_group_pairs``).
            num_pairs: Maximum number of pairs to return.
            min_size: Individual mode only. Upper bound on the number of
                random pairs drawn from one folder.
            group_size: When truthy, switches to group mode: each folder's
                entries are cut into consecutive chunks of this length and
                every combination of two entries inside a chunk is emitted.

        Returns:
            A shuffled list of ``(entry1, entry2, 1)`` tuples, truncated to
            at most *num_pairs* items.
        """
        pairs = []
        for files in files_dict.values():
            if len(files) < 2:
                continue
            if group_size:
                for start in range(0, len(files), group_size):
                    chunk = files[start:start + group_size]
                    pairs.extend(
                        (chunk[a], chunk[b], 1)
                        for a in range(len(chunk))
                        for b in range(a + 1, len(chunk))
                    )
            else:
                pairs.extend(
                    self._random_pairs(files, min(min_size, len(files) // 2))
                )
        random.shuffle(pairs)
        # A negative limit would otherwise slice from the end of the list.
        return pairs[:max(num_pairs, 0)]

    def _generate_cross_pairs(
        self,
        files_dict: Dict[str, list],
        num_pairs: int
    ) -> List[tuple]:
        """Generate unique negative pairs (label 0) across different folders.

        A pair is unique regardless of the order of its two members. Entries
        may be file paths or lists of file paths (groups).

        Args:
            files_dict: Mapping of folder to its entries.
            num_pairs: Number of pairs requested. Fewer are returned when the
                data cannot provide that many distinct cross-folder pairs.

        Returns:
            A list of ``(entry1, entry2, 0)`` tuples.
        """
        # Empty folders cannot contribute and would make random.choice raise.
        usable = {folder: items for folder, items in files_dict.items() if items}
        folders = list(usable)
        if len(folders) < 2 or num_pairs <= 0:
            return []

        # Distinct unordered cross-folder pairs = sum over folder pairs of
        # n_i * n_j = (total^2 - sum(n_i^2)) / 2. Capping the request at this
        # value stops the sampling loop from running forever.
        sizes = [len(items) for items in usable.values()]
        total = sum(sizes)
        max_unique = (total * total - sum(n * n for n in sizes)) // 2
        target = min(num_pairs, max_unique)
        if target < num_pairs:
            self.logger.warning(
                f"Requested {num_pairs} negative pairs but only {max_unique} unique pairs exist"
            )

        pairs = []
        seen = set()
        misses = 0
        while len(pairs) < target:
            folder1, folder2 = random.sample(folders, 2)
            item1 = random.choice(usable[folder1])
            item2 = random.choice(usable[folder2])
            # frozenset makes the key order-independent; _as_key makes group
            # entries (lists) hashable.
            key = frozenset((self._as_key(item1), self._as_key(item2)))
            if key in seen:
                misses += 1
                if misses >= self._MAX_CONSECUTIVE_MISSES:
                    self.logger.warning(
                        f"Stopped after {len(pairs)} negative pairs: no new pair found in {misses} draws"
                    )
                    break
                continue
            misses = 0
            seen.add(key)
            pairs.append((item1, item2, 0))
        return pairs

    def _random_pairs(self, files: list, num_pairs: int) -> List[tuple]:
        """Draw up to *num_pairs* random, non-overlapping pairs from *files*.

        No entry appears in more than one pair, so at most ``len(files) // 2``
        pairs are returned. Each pair carries label 1.
        """
        actual_pairs = max(0, min(num_pairs, len(files) // 2))
        if actual_pairs == 0:
            return []
        # Sample the entries themselves; consecutive sampled items form a pair.
        chosen = random.sample(files, 2 * actual_pairs)
        return [(chosen[k], chosen[k + 1], 1) for k in range(0, len(chosen), 2)]

    def get_pairs(self,
                  root_dir: str,
                  num_pairs: int = 6000,
                  output_txt: Optional[str] = None
                  ) -> List[Tuple[str, str, int]]:
        """
        Generate individual image pairs with labels (1=same, 0=different).

        Args:
            root_dir: Directory containing subfolders of images
            num_pairs: Maximum number of positive pairs; the same number of
                negative pairs is requested, so the result holds up to
                ``2 * num_pairs`` entries
            output_txt: Optional path to save pairs as txt file, one
                ``path1 path2 label`` line per pair (paths containing spaces
                make that format ambiguous)

        Returns:
            List of (path1, path2, label) tuples, positives first

        Raises:
            ValueError: If root_dir is not a directory
        """
        files_dict = self._get_image_files(root_dir)
        same_pairs = self._generate_same_pairs(files_dict, num_pairs, min_size=30)
        cross_pairs = self._generate_cross_pairs(files_dict, len(same_pairs))
        pairs = same_pairs + cross_pairs
        self.logger.info(f"Generated {len(pairs)} pairs ({len(same_pairs)} positive, {len(cross_pairs)} negative)")
        if output_txt:
            try:
                # Explicit encoding so non-ASCII paths are written identically
                # on every platform.
                with open(output_txt, 'w', encoding='utf-8') as f:
                    f.writelines(
                        f"{file1} {file2} {label}\n"
                        for file1, file2, label in pairs
                    )
                self.logger.info(f"Saved pairs to {output_txt}")
            except OSError as e:
                # Best effort: the pairs are still returned to the caller.
                self.logger.warning(f"Failed to write pairs to {output_txt}: {str(e)}")
        return pairs

    def get_group_pairs(
        self,
        root_dir: str,
        img_num: int = 20,
        group_num: int = 10,
        num_pairs: int = 5000,
        output_json: Optional[str] = "cross_same.json"
    ) -> List[Tuple[List[str], List[str], int]]:
        """
        Generate grouped image pairs with labels (1=same, 0=different).

        Each folder's files are shuffled and cut into groups of ``group_num``
        paths (the last group of a folder may be shorter). A pair consists of
        two such groups: both from one folder for label 1, from two different
        folders for label 0.

        Args:
            root_dir: Directory containing subfolders of images
            img_num: Minimum images required per folder
            group_num: Number of images per group
            num_pairs: Maximum number of positive pairs; the same number of
                negative pairs is requested
            output_json: Path of the JSON file the pairs are dumped to;
                pass None to skip writing

        Returns:
            List of (group1, group2, label) tuples, where each group is a
            list of file paths

        Raises:
            ValueError: If root_dir is not a directory
        """
        # Filter folders with enough images
        files_dict = {
            k: v for k, v in self._get_image_files(root_dir).items()
            if len(v) >= img_num
        }
        # Split into groups
        grouped_files = {}
        for folder, files in files_dict.items():
            random.shuffle(files)
            grouped_files[folder] = [
                files[i:i + group_num]
                for i in range(0, len(files), group_num)
            ]
        # Generate pairs
        same_pairs = self._generate_same_pairs(
            grouped_files, num_pairs, min_size=30, group_size=group_num
        )
        cross_pairs = self._generate_cross_pairs(
            grouped_files, len(same_pairs)
        )
        pairs = same_pairs + cross_pairs
        self.logger.info(f"Generated {len(pairs)} group pairs")
        # Save to JSON (best effort, consistent with get_pairs)
        if output_json:
            try:
                with open(output_json, 'w', encoding='utf-8') as f:
                    json.dump(pairs, f)
                self.logger.info(f"Saved pairs to {output_json}")
            except OSError as e:
                self.logger.warning(f"Failed to write pairs to {output_json}: {str(e)}")
        return pairs
if __name__ == "__main__":
    # Dataset split to scan; the pair list is written into its parent folder.
    original_path = '/home/lc/data_center/scatter/val_extar'
    parent_dir = str(Path(original_path).parent)
    pairs_txt = str(Path(parent_dir) / 'cross_same.txt')

    generator = PairGenerator()
    # Example usage: individual pairs
    pairs = generator.get_pairs(original_path, output_txt=pairs_txt)
    # Example usage: group pairs
    # groups = generator.get_group_pairs('val')