# Listing metadata from the original paste: 172 lines, 5.7 KiB, Python.
import itertools
import json
import logging
import os
import random
from pathlib import Path
from typing import List, Tuple, Dict, Optional


class PairGenerator:
    """Generate positive and negative image pairs for contrastive learning.

    A dataset is a root directory whose immediate sub-directories each hold
    the files of one class.  Positive pairs (label ``1``) are built from two
    items of the same sub-directory, negative pairs (label ``0``) from items
    of two different sub-directories.
    """

    # Upper bound on the random positive pairs drawn from one folder in
    # individual (non-group) mode.
    MAX_RANDOM_PAIRS_PER_FOLDER = 3

    def __init__(self):
        self._setup_logging()

    def _setup_logging(self):
        """Configure logging settings and bind ``self.logger``.

        ``logging.basicConfig`` leaves the root logger untouched when it
        already has handlers, so creating several generators is harmless.
        """
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s'
        )
        self.logger = logging.getLogger(__name__)

    def _get_image_files(self, root_dir: str) -> Dict[str, List[str]]:
        """Scan ``root_dir`` and return ``{folder_path: [file_paths]}``.

        Only the immediate sub-directories of ``root_dir`` are considered and
        only their direct regular-file children are listed.  No filtering by
        file extension is performed, so every regular file is returned.
        Folders and files are sorted so the result does not depend on the
        directory enumeration order of the operating system.

        Raises:
            ValueError: If ``root_dir`` is not an existing directory.
        """
        root = Path(root_dir)
        if not root.is_dir():
            raise ValueError(f"Invalid directory: {root_dir}")

        return {
            str(folder): sorted(str(f) for f in folder.iterdir() if f.is_file())
            for folder in sorted(root.iterdir())
            if folder.is_dir()
        }

    def _generate_same_pairs(
        self,
        files_dict: Dict[str, list],
        num_pairs: int,
        group_size: Optional[int] = None
    ) -> List[tuple]:
        """Generate positive pairs (label ``1``) from items of the same folder.

        Args:
            files_dict: Mapping of folder to its items.  Items are file paths
                in individual mode; ``get_group_pairs`` passes lists of
                paths (groups) instead.
            num_pairs: Maximum number of pairs to return; values below zero
                are treated as zero.
            group_size: When truthy, each folder's items are cut into
                consecutive chunks of this length and every 2-combination
                inside a chunk becomes a pair.  Otherwise at most
                ``MAX_RANDOM_PAIRS_PER_FOLDER`` random pairs are drawn per
                folder.

        Returns:
            A shuffled list of ``(item1, item2, 1)`` tuples, truncated to
            ``num_pairs`` entries.
        """
        pairs = []

        for folder, files in files_dict.items():
            if len(files) < 2:
                # A pair needs two distinct items.
                continue

            if group_size:
                # Group mode: all possible pairs within each chunk.
                for start in range(0, len(files), group_size):
                    chunk = files[start:start + group_size]
                    pairs.extend(
                        (first, second, 1)
                        for first, second in itertools.combinations(chunk, 2)
                    )
            else:
                # Individual mode: a few random disjoint pairs per folder.
                wanted = min(self.MAX_RANDOM_PAIRS_PER_FOLDER, len(files) // 2)
                try:
                    pairs.extend(self._random_pairs(files, wanted))
                except ValueError as e:
                    self.logger.warning(f"Skipping folder {folder}: {str(e)}")

        random.shuffle(pairs)
        # Clamp so a negative request cannot turn into "all but the last n".
        return pairs[:max(num_pairs, 0)]

    def _generate_cross_pairs(
        self,
        files_dict: Dict[str, list],
        num_pairs: int
    ) -> List[tuple]:
        """Generate unique negative pairs (label ``0``) across folders.

        Empty folders are ignored.  The request is capped at the number of
        distinct cross-folder pairs that exist, so the sampling loop always
        terminates.

        Args:
            files_dict: Mapping of folder to its items (paths or groups).
            num_pairs: Number of negative pairs wanted.

        Returns:
            A list of ``(item1, item2, 0)`` tuples whose two items come from
            different folders; no unordered pair occurs twice.  The list is
            shorter than ``num_pairs`` when fewer distinct pairs exist, and
            empty when fewer than two non-empty folders are available.
        """
        populated = [files for files in files_dict.values() if files]
        if num_pairs <= 0 or len(populated) < 2:
            return []

        # Distinct cross-folder pairs: sum over folder pairs of n_a * n_b,
        # computed as ((sum n)^2 - sum n^2) / 2.
        total = sum(len(files) for files in populated)
        max_unique = (total * total - sum(len(files) ** 2 for files in populated)) // 2
        if num_pairs > max_unique:
            self.logger.warning(
                "Requested %d negative pairs but only %d unique pairs exist",
                num_pairs, max_unique
            )
            num_pairs = max_unique

        if 2 * num_pairs > max_unique:
            # Dense request: rejection sampling would mostly hit duplicates,
            # so enumerate every candidate (at most twice the output size)
            # and sample without replacement.
            every_pair = [
                (first, second, 0)
                for files_a, files_b in itertools.combinations(populated, 2)
                for first in files_a
                for second in files_b
            ]
            return random.sample(every_pair, num_pairs)

        # Sparse request: rejection sampling.  Pairs are tracked by
        # (folder index, item index) so items need not be hashable, which
        # matters in group mode where items are lists.
        pairs = []
        seen = set()
        folder_ids = range(len(populated))
        while len(pairs) < num_pairs:
            folder1, folder2 = random.sample(folder_ids, 2)
            index1 = random.randrange(len(populated[folder1]))
            index2 = random.randrange(len(populated[folder2]))
            key = frozenset(((folder1, index1), (folder2, index2)))
            if key in seen:
                continue
            seen.add(key)
            pairs.append((populated[folder1][index1], populated[folder2][index2], 0))

        return pairs

    def _random_pairs(self, files: List[str], num_pairs: int) -> List[Tuple[str, str, int]]:
        """Draw ``num_pairs`` disjoint random pairs from ``files``.

        Raises:
            ValueError: If ``files`` holds fewer than ``2 * num_pairs`` items.
        """
        if len(files) < 2 * num_pairs:
            raise ValueError("Not enough files for requested pairs")

        # Sample distinct positions and pair them up in ascending order.  The
        # sampled positions (not the loop counter) must index ``files``.
        positions = sorted(random.sample(range(len(files)), 2 * num_pairs))
        return [
            (files[first], files[second], 1)
            for first, second in zip(positions[::2], positions[1::2])
        ]

    def get_pairs(self, root_dir: str, num_pairs: int = 2000) -> List[Tuple[str, str, int]]:
        """
        Generate individual image pairs with labels (1=same, 0=different).

        Args:
            root_dir: Directory containing subfolders of images
            num_pairs: Maximum number of positive pairs; the same number of
                negative pairs is requested, so up to ``2 * num_pairs``
                tuples are returned.  Each folder contributes at most
                ``MAX_RANDOM_PAIRS_PER_FOLDER`` positive pairs.

        Returns:
            List of (path1, path2, label) tuples, positives first

        Raises:
            ValueError: If ``root_dir`` is not a directory.
        """
        files_dict = self._get_image_files(root_dir)

        same_pairs = self._generate_same_pairs(files_dict, num_pairs)
        cross_pairs = self._generate_cross_pairs(files_dict, len(same_pairs))

        pairs = same_pairs + cross_pairs
        self.logger.info(f"Generated {len(pairs)} pairs ({len(same_pairs)} positive, {len(cross_pairs)} negative)")
        return pairs

    def get_group_pairs(
        self,
        root_dir: str,
        img_num: int = 20,
        group_num: int = 10,
        num_pairs: int = 5000,
        output_path: Optional[str] = "cross_same.json"
    ) -> List[Tuple[List[str], List[str], int]]:
        """
        Generate grouped image pairs with labels (1=same, 0=different).

        Each qualifying folder is shuffled and cut into groups of up to
        ``group_num`` paths.  The paired items are these groups (lists of
        paths), not single paths.

        NOTE(review): ``group_num`` is also passed to ``_generate_same_pairs``
        as the chunk size over a folder's *groups*, so positives are only
        formed between groups of the same run of ``group_num`` groups --
        confirm this is intended.

        Args:
            root_dir: Directory containing subfolders of images
            img_num: Minimum images required per folder
            group_num: Number of images per group (the last group of a folder
                may be smaller)
            num_pairs: Maximum number of positive pairs; the same number of
                negative pairs is requested
            output_path: JSON file the pairs are written to; ``None`` skips
                writing.  Defaults to ``cross_same.json`` in the current
                working directory.

        Returns:
            List of (group1, group2, label) tuples, positives first

        Raises:
            ValueError: If ``root_dir`` is not a directory or ``group_num``
                is smaller than 1.
        """
        if group_num < 1:
            raise ValueError(f"group_num must be a positive integer, got {group_num}")

        # Filter folders with enough images
        files_dict = {
            folder: files
            for folder, files in self._get_image_files(root_dir).items()
            if len(files) >= img_num
        }

        # Split into groups
        grouped_files = {}
        for folder, files in files_dict.items():
            random.shuffle(files)
            grouped_files[folder] = [
                files[start:start + group_num]
                for start in range(0, len(files), group_num)
            ]

        # Generate pairs
        same_pairs = self._generate_same_pairs(
            grouped_files, num_pairs, group_size=group_num
        )
        cross_pairs = self._generate_cross_pairs(
            grouped_files, len(same_pairs)
        )

        pairs = same_pairs + cross_pairs
        self.logger.info(f"Generated {len(pairs)} group pairs")

        # Save to JSON
        if output_path is not None:
            with open(output_path, 'w', encoding='utf-8') as f:
                json.dump(pairs, f)

        return pairs


if __name__ == "__main__":
    import sys

    # Dataset location used by the original author; pass a directory as the
    # first command-line argument to run against a different dataset.
    default_root = '/home/lc/contrast_nettest/data/contrast_test_data/test'
    root_dir = sys.argv[1] if len(sys.argv) > 1 else default_root

    generator = PairGenerator()

    # Example usage: individual pairs.  For grouped pairs call
    # generator.get_group_pairs(root_dir) instead.
    pairs = generator.get_pairs(root_dir)