Files
ieemoo-ai-contrast/tools/getpairs.py
2025-06-13 10:57:02 +08:00

172 lines
5.7 KiB
Python

import os
import random
import json
from pathlib import Path
from typing import List, Tuple, Dict, Optional
import logging
class PairGenerator:
    """Generate positive and negative image pairs for contrastive learning.

    The expected layout is ``root_dir/<folder>/<file>``: each immediate
    sub-folder of ``root_dir`` is treated as one class. Pairs are
    ``(a, b, label)`` tuples where ``label`` is 1 when ``a`` and ``b`` come
    from the same folder and 0 when they come from different folders.
    """

    def __init__(self):
        self._setup_logging()

    def _setup_logging(self):
        """Configure logging settings and bind ``self.logger``.

        ``logging.basicConfig`` only takes effect if the root logger has no
        handlers yet, so an application's own logging setup is preserved.
        """
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s'
        )
        self.logger = logging.getLogger(__name__)

    @staticmethod
    def _as_key(item):
        """Return a hashable stand-in for a pair member (lists become tuples)."""
        return tuple(item) if isinstance(item, list) else item

    def _get_image_files(self, root_dir: str) -> Dict[str, List[str]]:
        """Scan ``root_dir`` and return ``{folder_path: [file_paths]}``.

        Only immediate sub-folders are considered. Every regular file inside
        them is listed; no extension filtering is done. Empty folders are kept
        with an empty list. Folders and files are sorted so results are
        reproducible under a fixed ``random`` seed, because ``Path.iterdir``
        order depends on the filesystem.

        Raises:
            ValueError: If ``root_dir`` is not a directory.
        """
        root = Path(root_dir)
        if not root.is_dir():
            raise ValueError(f"Invalid directory: {root_dir}")
        return {
            str(folder): sorted(str(f) for f in folder.iterdir() if f.is_file())
            for folder in sorted(root.iterdir())
            if folder.is_dir()
        }

    def _generate_same_pairs(
        self,
        files_dict: Dict[str, List[str]],
        num_pairs: int,
        group_size: Optional[int] = None
    ) -> List[Tuple[str, str, int]]:
        """Generate positive (label 1) pairs whose members share a folder.

        Args:
            files_dict: ``{folder: items}``. Items are file paths, or lists of
                file paths when called from ``get_group_pairs``.
            num_pairs: Maximum number of pairs to return.
            group_size: If truthy, each folder's items are cut into consecutive
                chunks of this many items, and every unordered pair inside a
                chunk is emitted. Otherwise up to 3 disjoint random pairs are
                drawn per folder.

        Returns:
            At most ``num_pairs`` ``(a, b, 1)`` tuples in random order.
        """
        pairs = []
        for folder, items in files_dict.items():
            if len(items) < 2:
                # A folder needs at least two items to form a pair.
                continue
            if group_size:
                # Group mode: all unordered pairs within each chunk.
                for start in range(0, len(items), group_size):
                    chunk = items[start:start + group_size]
                    pairs.extend(
                        (chunk[i], chunk[j], 1)
                        for i in range(len(chunk))
                        for j in range(i + 1, len(chunk))
                    )
            else:
                # Individual mode: a few random disjoint pairs per folder.
                try:
                    pairs.extend(self._random_pairs(items, min(3, len(items) // 2)))
                except ValueError as e:
                    self.logger.warning(f"Skipping folder {folder}: {str(e)}")
        random.shuffle(pairs)
        return pairs[:num_pairs]

    def _generate_cross_pairs(
        self,
        files_dict: Dict[str, List[str]],
        num_pairs: int
    ) -> List[Tuple[str, str, int]]:
        """Generate up to ``num_pairs`` unique negative (label 0) pairs.

        Each pair takes one random item from each of two distinct, non-empty
        folders. ``(a, b)`` and ``(b, a)`` count as the same pair. Sampling
        stops after a bounded number of attempts, so the call terminates even
        when fewer distinct cross-folder pairs exist than requested. In that
        case a warning is logged and fewer pairs are returned.

        Args:
            files_dict: ``{folder: items}``. Items are file paths or lists of
                file paths.
            num_pairs: Number of pairs wanted.

        Returns:
            List of ``(a, b, 0)`` tuples.
        """
        # Empty folders would make random.choice raise IndexError.
        folders = [folder for folder, items in files_dict.items() if items]
        # Generous bound: far more than enough draws when duplicates are rare,
        # while still guaranteeing termination when the pool is exhausted.
        max_attempts = max(1000, 100 * num_pairs)
        pairs = []
        seen = set()
        attempts = 0
        while len(pairs) < num_pairs and len(folders) >= 2 and attempts < max_attempts:
            attempts += 1
            folder1, folder2 = random.sample(folders, 2)
            item1 = random.choice(files_dict[folder1])
            item2 = random.choice(files_dict[folder2])
            # Order-insensitive key; lists (group mode) are made hashable.
            key = frozenset((self._as_key(item1), self._as_key(item2)))
            if key in seen:
                continue
            seen.add(key)
            pairs.append((item1, item2, 0))
        if len(pairs) < num_pairs:
            self.logger.warning(
                f"Only generated {len(pairs)} of {num_pairs} requested negative pairs"
            )
        return pairs

    def _random_pairs(self, files: List[str], num_pairs: int) -> List[Tuple[str, str, int]]:
        """Draw ``num_pairs`` disjoint random positive pairs from ``files``.

        Raises:
            ValueError: If ``files`` has fewer than ``2 * num_pairs`` entries.
        """
        if len(files) < 2 * num_pairs:
            raise ValueError("Not enough files for requested pairs")
        # Pair consecutive elements of a random sample: every file is used at
        # most once, and the pairing is uniformly random.
        chosen = random.sample(files, 2 * num_pairs)
        return [(chosen[k], chosen[k + 1], 1) for k in range(0, len(chosen), 2)]

    def get_pairs(self, root_dir: str, num_pairs: int = 2000) -> List[Tuple[str, str, int]]:
        """
        Generate individual image pairs with labels (1=same, 0=different).

        Each folder contributes at most ``min(3, n // 2)`` positive pairs,
        where ``n`` is its file count. As many negative pairs as positives are
        then requested. Positives come first in the result, followed by
        negatives.

        Args:
            root_dir: Directory containing subfolders of images
            num_pairs: Maximum number of positive pairs. Because an equal
                number of negatives is requested, up to ``2 * num_pairs``
                pairs are returned.
        Returns:
            List of (path1, path2, label) tuples
        """
        files_dict = self._get_image_files(root_dir)
        same_pairs = self._generate_same_pairs(files_dict, num_pairs)
        cross_pairs = self._generate_cross_pairs(files_dict, len(same_pairs))
        pairs = same_pairs + cross_pairs
        self.logger.info(f"Generated {len(pairs)} pairs ({len(same_pairs)} positive, {len(cross_pairs)} negative)")
        return pairs

    def get_group_pairs(
        self,
        root_dir: str,
        img_num: int = 20,
        group_num: int = 10,
        num_pairs: int = 5000,
        output_path: Optional[str] = "cross_same.json",
    ) -> List[Tuple[List[str], List[str], int]]:
        """
        Generate pairs of image groups with labels (1=same folder, 0=different).

        The files of each qualifying folder are shuffled (in place) and cut into
        groups of ``group_num`` paths; the last group may be shorter. Each pair
        member is one such group, i.e. a list of file paths, not a single path.

        Args:
            root_dir: Directory containing subfolders of images
            img_num: Minimum images required for a folder to be used
            group_num: Number of images per group
            num_pairs: Maximum number of positive pairs. An equal number of
                negatives is then requested.
            output_path: JSON file the pairs are written to. A relative path
                resolves against the current working directory; ``None``
                skips writing.
        Returns:
            List of (group1, group2, label) tuples
        """
        # Filter folders with enough images
        files_dict = {
            k: v for k, v in self._get_image_files(root_dir).items()
            if len(v) >= img_num
        }
        # Split into groups
        grouped_files = {}
        for folder, files in files_dict.items():
            random.shuffle(files)
            grouped_files[folder] = [
                files[i:i + group_num]
                for i in range(0, len(files), group_num)
            ]
        # Generate pairs.
        # NOTE(review): _generate_same_pairs re-chunks each folder's *groups*
        # by group_num, so only groups falling in the same chunk of group_num
        # groups are paired with each other. Confirm this is intended.
        same_pairs = self._generate_same_pairs(
            grouped_files, num_pairs, group_size=group_num
        )
        cross_pairs = self._generate_cross_pairs(
            grouped_files, len(same_pairs)
        )
        pairs = same_pairs + cross_pairs
        self.logger.info(f"Generated {len(pairs)} group pairs")
        # Save to JSON (tuples are serialized as JSON arrays)
        if output_path is not None:
            with open(output_path, 'w', encoding='utf-8') as f:
                json.dump(pairs, f)
        return pairs
if __name__ == "__main__":
    import sys

    generator = PairGenerator()
    # Example usage: pass the dataset root as the first CLI argument. When it
    # is omitted, fall back to the original hard-coded test directory.
    data_root = (
        sys.argv[1] if len(sys.argv) > 1
        else '/home/lc/contrast_nettest/data/contrast_test_data/test'
    )
    pairs = generator.get_pairs(data_root)  # Individual pairs
    # groups = generator.get_group_pairs('val')  # Group pairs