import itertools
import json
import logging
import os
import random
from pathlib import Path
from typing import Dict, List, Optional, Tuple


class PairGenerator:
    """Generate positive and negative image pairs for contrastive learning.

    Positive pairs (label 1) join two entries from the same sub-folder;
    negative pairs (label 0) join entries from two different sub-folders.
    """

    # Upper bound on positive pairs drawn from one folder in individual mode.
    PAIRS_PER_FOLDER = 3

    def __init__(self):
        self._setup_logging()

    def _setup_logging(self):
        """Configure root logging and attach a module logger as ``self.logger``."""
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s'
        )
        self.logger = logging.getLogger(__name__)

    def _get_image_files(self, root_dir: str) -> Dict[str, List[str]]:
        """Scan ``root_dir`` and return ``{sub-folder path: [file paths]}``.

        Only immediate sub-folders and their immediate regular files are
        listed. Entries are sorted so the result does not depend on the
        OS-specific order of ``iterdir``.

        NOTE(review): every regular file is included, not only images;
        confirm whether stray non-image files should be filtered out.

        Raises:
            ValueError: If ``root_dir`` is not an existing directory.
        """
        root = Path(root_dir)
        if not root.is_dir():
            raise ValueError(f"Invalid directory: {root_dir}")
        return {
            str(folder): [
                str(entry) for entry in sorted(folder.iterdir())
                if entry.is_file()
            ]
            for folder in sorted(root.iterdir())
            if folder.is_dir()
        }

    def _generate_same_pairs(
        self,
        files_dict: Dict[str, List[str]],
        num_pairs: int,
        group_size: Optional[int] = None
    ) -> List[Tuple[str, str, int]]:
        """Generate positive (label 1) pairs from entries of the same folder.

        Args:
            files_dict: Mapping of folder to its list of entries.
            num_pairs: Maximum number of pairs to return.
            group_size: When truthy, each folder's list is cut into
                consecutive chunks of this size and every 2-combination
                inside a chunk becomes a pair. Otherwise up to
                ``PAIRS_PER_FOLDER`` disjoint random pairs are drawn per
                folder.

        Returns:
            A shuffled list of at most ``num_pairs`` ``(a, b, 1)`` tuples.
        """
        pairs = []
        for files in files_dict.values():
            if len(files) < 2:
                continue
            if group_size:
                # Group mode: all possible pairs within each chunk.
                for start in range(0, len(files), group_size):
                    chunk = files[start:start + group_size]
                    pairs.extend(
                        (first, second, 1)
                        for first, second in itertools.combinations(chunk, 2)
                    )
            else:
                # Individual mode: the count is capped at len(files) // 2,
                # so _random_pairs always has enough files and cannot raise.
                count = min(self.PAIRS_PER_FOLDER, len(files) // 2)
                pairs.extend(self._random_pairs(files, count))
        random.shuffle(pairs)
        return pairs[:num_pairs]

    @staticmethod
    def _pair_key(first, second) -> frozenset:
        """Return an order-independent, hashable key for a pair of entries.

        Entries may be path strings or, in group mode, lists of paths; lists
        are converted to tuples so they can be hashed.
        """
        return frozenset(
            tuple(item) if isinstance(item, (list, tuple)) else item
            for item in (first, second)
        )

    def _generate_cross_pairs(
        self,
        files_dict: Dict[str, List[str]],
        num_pairs: int
    ) -> List[Tuple[str, str, int]]:
        """Generate distinct negative (label 0) pairs across different folders.

        Folders with no entries are ignored. If fewer distinct cross-folder
        pairs exist than requested, all of them are returned and a warning is
        logged, rather than looping forever.

        Args:
            files_dict: Mapping of folder to its list of entries.
            num_pairs: Number of pairs requested.

        Returns:
            A list of at most ``num_pairs`` unique ``(a, b, 0)`` tuples.
        """
        # Empty folders cannot contribute an entry to a pair.
        pools = [entries for entries in files_dict.values() if entries]
        if num_pairs <= 0 or len(pools) < 2:
            return []

        # Number of unordered pairs whose members come from different pools:
        # (S^2 - sum(n_i^2)) / 2 with S the total number of entries.
        total = sum(len(pool) for pool in pools)
        max_unique = (total * total - sum(len(pool) ** 2 for pool in pools)) // 2
        target = min(num_pairs, max_unique)
        if target < num_pairs:
            self.logger.warning(
                "Only %d distinct cross-folder pairs available; requested %d",
                max_unique, num_pairs
            )

        if 2 * target >= max_unique:
            # Dense request: enumerate every candidate (at most 2 * target of
            # them) and sample, which always terminates.
            unique = {}
            for pool_a, pool_b in itertools.combinations(pools, 2):
                for first in pool_a:
                    for second in pool_b:
                        unique.setdefault(
                            self._pair_key(first, second), (first, second, 0)
                        )
            candidates = list(unique.values())
            return random.sample(candidates, min(target, len(candidates)))

        # Sparse request: rejection sampling with O(1) duplicate detection.
        # Assumes entries are distinct (true for the directory scan), so the
        # target is always reachable.
        seen = set()
        pairs = []
        while len(pairs) < target:
            pool_a, pool_b = random.sample(pools, 2)
            first = random.choice(pool_a)
            second = random.choice(pool_b)
            key = self._pair_key(first, second)
            if key not in seen:
                seen.add(key)
                pairs.append((first, second, 0))
        return pairs

    def _random_pairs(self, files: List[str], num_pairs: int) -> List[Tuple[str, str, int]]:
        """Draw ``num_pairs`` disjoint random positive pairs from ``files``.

        Args:
            files: Entries that all belong to the same folder.
            num_pairs: Number of pairs to draw; needs ``2 * num_pairs`` files.

        Returns:
            A list of ``(a, b, 1)`` tuples in which no entry appears twice.

        Raises:
            ValueError: If ``files`` has fewer than ``2 * num_pairs`` entries.
        """
        if len(files) < 2 * num_pairs:
            raise ValueError("Not enough files for requested pairs")
        indices = sorted(random.sample(range(len(files)), 2 * num_pairs))
        # Pair up consecutive *sampled* indices; indexing ``files`` with the
        # loop counter itself would always return the first files in order.
        return [
            (files[indices[pos]], files[indices[pos + 1]], 1)
            for pos in range(0, len(indices), 2)
        ]

    def get_pairs(self, root_dir: str, num_pairs: int = 2000) -> List[Tuple[str, str, int]]:
        """Generate individual image pairs with labels (1=same, 0=different).

        Args:
            root_dir: Directory containing subfolders of images.
            num_pairs: Maximum number of positive pairs; the same number of
                negative pairs is requested to keep the set balanced.

        Returns:
            List of ``(path1, path2, label)`` tuples, positives first.

        Raises:
            ValueError: If ``root_dir`` is not a directory.
        """
        files_dict = self._get_image_files(root_dir)
        same_pairs = self._generate_same_pairs(files_dict, num_pairs)
        cross_pairs = self._generate_cross_pairs(files_dict, len(same_pairs))
        pairs = same_pairs + cross_pairs
        self.logger.info(
            "Generated %d pairs (%d positive, %d negative)",
            len(pairs), len(same_pairs), len(cross_pairs)
        )
        return pairs

    def get_group_pairs(
        self,
        root_dir: str,
        img_num: int = 20,
        group_num: int = 10,
        num_pairs: int = 5000,
        output_path: Optional[str] = "cross_same.json"
    ) -> List[Tuple[str, str, int]]:
        """Generate grouped image pairs with labels (1=same, 0=different).

        Each qualifying folder is shuffled and split into groups of
        ``group_num`` paths (the last group may be shorter). The members of
        every returned pair are such groups, i.e. lists of paths.

        NOTE(review): the annotated return type still says ``str`` members;
        confirm whether group-level pairs are what consumers expect.

        Args:
            root_dir: Directory containing subfolders of images.
            img_num: Minimum images required per folder.
            group_num: Number of images per group; must be at least 1.
            num_pairs: Maximum number of positive pairs; the same number of
                negative pairs is requested.
            output_path: File the pairs are dumped to as JSON. Pass ``None``
                to skip writing.

        Returns:
            List of ``(group1, group2, label)`` tuples, positives first.

        Raises:
            ValueError: If ``root_dir`` is not a directory or ``group_num``
                is smaller than 1.
        """
        if group_num < 1:
            raise ValueError(f"group_num must be >= 1, got {group_num}")

        # Filter folders with enough images.
        files_dict = {
            folder: files
            for folder, files in self._get_image_files(root_dir).items()
            if len(files) >= img_num
        }

        # Split each folder into groups.
        grouped_files = {}
        for folder, files in files_dict.items():
            random.shuffle(files)
            grouped_files[folder] = [
                files[start:start + group_num]
                for start in range(0, len(files), group_num)
            ]

        # Generate pairs.
        same_pairs = self._generate_same_pairs(
            grouped_files, num_pairs, group_size=group_num
        )
        cross_pairs = self._generate_cross_pairs(
            grouped_files, len(same_pairs)
        )
        pairs = same_pairs + cross_pairs
        self.logger.info("Generated %d group pairs", len(pairs))

        # Save to JSON.
        if output_path is not None:
            with open(output_path, 'w', encoding='utf-8') as f:
                json.dump(pairs, f)
        return pairs


if __name__ == "__main__":
    generator = PairGenerator()
    # Example usage:
    pairs = generator.get_pairs('/home/lc/contrast_nettest/data/contrast_test_data/test')  # Individual pairs
    # groups = generator.get_group_pairs('val')  # Group pairs