更改
This commit is contained in:
171
tools/getpairs.py
Normal file
171
tools/getpairs.py
Normal file
@ -0,0 +1,171 @@
|
||||
import os
|
||||
import random
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple, Dict, Optional
|
||||
import logging
|
||||
|
||||
class PairGenerator:
|
||||
"""Generate positive and negative image pairs for contrastive learning."""
|
||||
|
||||
def __init__(self):
|
||||
self._setup_logging()
|
||||
|
||||
def _setup_logging(self):
|
||||
"""Configure logging settings."""
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
def _get_image_files(self, root_dir: str) -> Dict[str, List[str]]:
|
||||
"""Scan directory and return dict of {folder: [image_paths]}."""
|
||||
root = Path(root_dir)
|
||||
if not root.is_dir():
|
||||
raise ValueError(f"Invalid directory: {root_dir}")
|
||||
|
||||
return {
|
||||
str(folder): [str(f) for f in folder.iterdir() if f.is_file()]
|
||||
for folder in root.iterdir() if folder.is_dir()
|
||||
}
|
||||
|
||||
def _generate_same_pairs(
|
||||
self,
|
||||
files_dict: Dict[str, List[str]],
|
||||
num_pairs: int,
|
||||
group_size: Optional[int] = None
|
||||
) -> List[Tuple[str, str, int]]:
|
||||
"""Generate positive pairs from same folder."""
|
||||
pairs = []
|
||||
|
||||
for folder, files in files_dict.items():
|
||||
if len(files) < 2:
|
||||
continue
|
||||
|
||||
if group_size:
|
||||
# Group mode: generate all possible pairs within group
|
||||
for i in range(0, len(files), group_size):
|
||||
group = files[i:i+group_size]
|
||||
pairs.extend([
|
||||
(group[i], group[j], 1)
|
||||
for i in range(len(group))
|
||||
for j in range(i+1, len(group))
|
||||
])
|
||||
else:
|
||||
# Individual mode: random pairs
|
||||
try:
|
||||
pairs.extend(self._random_pairs(files, min(3, len(files)//2)))
|
||||
except ValueError as e:
|
||||
self.logger.warning(f"Skipping folder {folder}: {str(e)}")
|
||||
|
||||
random.shuffle(pairs)
|
||||
return pairs[:num_pairs]
|
||||
|
||||
def _generate_cross_pairs(
|
||||
self,
|
||||
files_dict: Dict[str, List[str]],
|
||||
num_pairs: int
|
||||
) -> List[Tuple[str, str, int]]:
|
||||
"""Generate negative pairs from different folders."""
|
||||
folders = list(files_dict.keys())
|
||||
pairs = []
|
||||
|
||||
while len(pairs) < num_pairs and len(folders) >= 2:
|
||||
folder1, folder2 = random.sample(folders, 2)
|
||||
file1 = random.choice(files_dict[folder1])
|
||||
file2 = random.choice(files_dict[folder2])
|
||||
|
||||
if not any((f1 == file1 and f2 == file2) or (f1 == file2 and f2 == file1)
|
||||
for f1, f2, _ in pairs):
|
||||
pairs.append((file1, file2, 0))
|
||||
|
||||
return pairs
|
||||
|
||||
def _random_pairs(self, files: List[str], num_pairs: int) -> List[Tuple[str, str, int]]:
|
||||
"""Generate random pairs from file list."""
|
||||
if len(files) < 2 * num_pairs:
|
||||
raise ValueError("Not enough files for requested pairs")
|
||||
|
||||
indices = random.sample(range(len(files)), 2 * num_pairs)
|
||||
indices.sort()
|
||||
return [(files[i], files[i+1], 1) for i in range(0, len(indices), 2)]
|
||||
|
||||
def get_pairs(self, root_dir: str, num_pairs: int = 2000) -> List[Tuple[str, str, int]]:
|
||||
"""
|
||||
Generate individual image pairs with labels (1=same, 0=different).
|
||||
|
||||
Args:
|
||||
root_dir: Directory containing subfolders of images
|
||||
num_pairs: Number of pairs to generate
|
||||
|
||||
Returns:
|
||||
List of (path1, path2, label) tuples
|
||||
"""
|
||||
files_dict = self._get_image_files(root_dir)
|
||||
|
||||
same_pairs = self._generate_same_pairs(files_dict, num_pairs)
|
||||
cross_pairs = self._generate_cross_pairs(files_dict, len(same_pairs))
|
||||
|
||||
pairs = same_pairs + cross_pairs
|
||||
self.logger.info(f"Generated {len(pairs)} pairs ({len(same_pairs)} positive, {len(cross_pairs)} negative)")
|
||||
return pairs
|
||||
|
||||
def get_group_pairs(
|
||||
self,
|
||||
root_dir: str,
|
||||
img_num: int = 20,
|
||||
group_num: int = 10,
|
||||
num_pairs: int = 5000
|
||||
) -> List[Tuple[str, str, int]]:
|
||||
"""
|
||||
Generate grouped image pairs with labels (1=same, 0=different).
|
||||
|
||||
Args:
|
||||
root_dir: Directory containing subfolders of images
|
||||
img_num: Minimum images required per folder
|
||||
group_num: Number of images per group
|
||||
num_pairs: Number of pairs to generate
|
||||
|
||||
Returns:
|
||||
List of (path1, path2, label) tuples
|
||||
"""
|
||||
# Filter folders with enough images
|
||||
files_dict = {
|
||||
k: v for k, v in self._get_image_files(root_dir).items()
|
||||
if len(v) >= img_num
|
||||
}
|
||||
|
||||
# Split into groups
|
||||
grouped_files = {}
|
||||
for folder, files in files_dict.items():
|
||||
random.shuffle(files)
|
||||
grouped_files[folder] = [
|
||||
files[i:i+group_num]
|
||||
for i in range(0, len(files), group_num)
|
||||
]
|
||||
|
||||
# Generate pairs
|
||||
same_pairs = self._generate_same_pairs(
|
||||
grouped_files, num_pairs, group_size=group_num
|
||||
)
|
||||
cross_pairs = self._generate_cross_pairs(
|
||||
grouped_files, len(same_pairs)
|
||||
)
|
||||
|
||||
pairs = same_pairs + cross_pairs
|
||||
self.logger.info(f"Generated {len(pairs)} group pairs")
|
||||
|
||||
# Save to JSON
|
||||
with open("cross_same.json", 'w') as f:
|
||||
json.dump(pairs, f)
|
||||
|
||||
return pairs
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
generator = PairGenerator()
|
||||
|
||||
# Example usage:
|
||||
pairs = generator.get_pairs('/home/lc/contrast_nettest/data/contrast_test_data/test') # Individual pairs
|
||||
# groups = generator.get_group_pairs('val') # Group pairs
|
Reference in New Issue
Block a user