"""Generate positive and negative image pairs for contrastive learning."""
import os
|
|
import random
|
|
import json
|
|
from pathlib import Path
|
|
from typing import List, Tuple, Dict, Optional
|
|
import logging
|
|
|
|
|
|
class PairGenerator:
    """Generate positive and negative image pairs for contrastive learning.

    A dataset root is expected to contain one sub-folder per group of related
    images.  Two entries taken from the same sub-folder form a positive pair
    (label ``1``); two entries taken from different sub-folders form a
    negative pair (label ``0``).

    NOTE: constructing an instance modifies the directory tree rooted at
    ``original_path`` in place (see ``_delete_space``).
    """

    # Extensions treated as image files when sanitising names on disk.
    _IMAGE_SUFFIXES = ('.jpg', '.png')

    def __init__(self, original_path):
        """Set up logging, remember ``original_path`` and sanitise it on disk.

        Args:
            original_path: Root directory whose file and directory names are
                cleaned up by ``_delete_space``.
        """
        self._setup_logging()
        self.original_path = original_path
        self._delete_space()

    def _delete_space(self) -> None:
        """Remove spaces from image file names and from directory names.

        For every ``.jpg``/``.png`` file below ``self.original_path``:
        files whose name contains ``'rotate'`` are deleted, all others are
        renamed with spaces stripped.  Every directory name has its spaces
        stripped as well.

        NOTE(review): the deletion of ``'rotate'`` files is restricted to
        image files here, which is the less destructive reading of the
        original code -- confirm this matches the intended clean-up.
        """
        self.logger.info("Sanitising names under %s", self.original_path)

        # Walk bottom-up: a directory is only renamed after everything below
        # it has been processed, so no path handed out by os.walk goes stale.
        for root, dirs, files in os.walk(self.original_path, topdown=False):
            for file_name in files:
                # str.endswith needs a tuple to test several suffixes.
                if not file_name.endswith(self._IMAGE_SUFFIXES):
                    continue
                if 'rotate' in file_name:
                    # Delete under the current name; renaming first would
                    # leave nothing to delete at the old path.
                    os.remove(os.path.join(root, file_name))
                    continue
                self._rename_without_spaces(root, file_name)

            for dir_name in dirs:
                self._rename_without_spaces(root, dir_name)

    def _rename_without_spaces(self, root: str, name: str) -> None:
        """Rename ``root/name`` to the same name with all spaces removed.

        Nothing happens when the name contains no space.  If the target name
        already exists the rename is skipped with a warning rather than
        overwriting the existing entry.
        """
        new_name = name.replace(' ', '')
        if new_name == name:
            return
        source = os.path.join(root, name)
        target = os.path.join(root, new_name)
        if os.path.exists(target):
            self.logger.warning("Cannot rename %s: %s already exists", source, target)
            return
        os.rename(source, target)

    def _setup_logging(self):
        """Configure logging settings and create ``self.logger``."""
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s'
        )
        self.logger = logging.getLogger(__name__)

    def _get_image_files(self, root_dir: str) -> Dict[str, List[str]]:
        """Scan ``root_dir`` and return ``{folder_path: [file_paths]}``.

        Only the immediate sub-folders of ``root_dir`` are scanned, and every
        regular file directly inside them is listed (no extension filter is
        applied).  Listings are sorted so results are reproducible for a
        fixed random seed.

        Raises:
            ValueError: If ``root_dir`` is not a directory.
        """
        root = Path(root_dir)
        if not root.is_dir():
            raise ValueError(f"Invalid directory: {root_dir}")

        return {
            str(folder): sorted(str(f) for f in folder.iterdir() if f.is_file())
            for folder in sorted(root.iterdir()) if folder.is_dir()
        }

    def _generate_same_pairs(
        self,
        files_dict: Dict[str, List[str]],
        num_pairs: int,
        min_size: int,
        group_size: Optional[int] = None
    ) -> List[Tuple[str, str, int]]:
        """Generate positive pairs (label ``1``) from within each folder.

        Args:
            files_dict: Mapping of folder to its entries.  Entries are file
                paths, or lists of file paths when called from
                ``get_group_pairs``.
            num_pairs: Maximum number of pairs returned.
            min_size: Upper bound on the number of random pairs drawn per
                folder in individual mode (ignored in group mode).
            group_size: When truthy, entries are split into consecutive
                chunks of this size and every pair inside a chunk is emitted.

        Returns:
            A shuffled list of at most ``num_pairs`` ``(a, b, 1)`` tuples.
        """
        pairs = []

        for folder, files in files_dict.items():
            if len(files) < 2:
                continue

            if group_size:
                # Group mode: all pairs within each consecutive chunk.
                for start in range(0, len(files), group_size):
                    group = files[start:start + group_size]
                    pairs.extend(
                        (first, second, 1)
                        for idx, first in enumerate(group)
                        for second in group[idx + 1:]
                    )
            else:
                # Individual mode: random disjoint pairs.
                try:
                    pairs.extend(self._random_pairs(files, min(min_size, len(files) // 2)))
                except ValueError as e:
                    self.logger.warning(f"Skipping folder {folder}: {str(e)}")

        random.shuffle(pairs)
        return pairs[:num_pairs]

    @staticmethod
    def _pair_key(first, second):
        """Return a hashable, order-insensitive key for a pair of entries.

        Entries that are lists (groups of paths) are converted to tuples so
        they can be stored in a set.
        """
        return frozenset(
            tuple(item) if isinstance(item, list) else item
            for item in (first, second)
        )

    def _all_cross_pairs(self, files_dict, folders):
        """Enumerate every unique ``(a, b, 0)`` pair across distinct folders."""
        seen = set()
        candidates = []
        for idx, folder1 in enumerate(folders):
            for folder2 in folders[idx + 1:]:
                for file1 in files_dict[folder1]:
                    for file2 in files_dict[folder2]:
                        key = self._pair_key(file1, file2)
                        if key not in seen:
                            seen.add(key)
                            candidates.append((file1, file2, 0))
        return candidates

    def _generate_cross_pairs(
        self,
        files_dict: Dict[str, List[str]],
        num_pairs: int
    ) -> List[Tuple[str, str, int]]:
        """Generate unique negative pairs (label ``0``) across folders.

        Args:
            files_dict: Mapping of folder to its entries (paths, or lists of
                paths in group mode).  Empty folders are ignored.
            num_pairs: Number of pairs requested.

        Returns:
            ``min(num_pairs, number of distinct cross-folder pairs)`` tuples
            ``(a, b, 0)``; a pair and its reverse count as the same pair.
        """
        # random.choice fails on an empty sequence, so drop empty folders.
        folders = [folder for folder, files in files_dict.items() if files]
        if len(folders) < 2 or num_pairs <= 0:
            return []

        # Number of cross-folder combinations: (S^2 - sum(n_i^2)) / 2.
        sizes = [len(files_dict[folder]) for folder in folders]
        total = sum(sizes)
        max_unique = (total * total - sum(n * n for n in sizes)) // 2
        target = min(num_pairs, max_unique)

        # When most of the available pairs are wanted, rejection sampling
        # would stall (or never finish); enumerate and sample instead.
        if 2 * target >= max_unique:
            candidates = self._all_cross_pairs(files_dict, folders)
            return random.sample(candidates, min(target, len(candidates)))

        pairs = []
        existing_pairs = set()
        while len(pairs) < target:
            folder1, folder2 = random.sample(folders, 2)
            file1 = random.choice(files_dict[folder1])
            file2 = random.choice(files_dict[folder2])

            pair_key = self._pair_key(file1, file2)
            if pair_key in existing_pairs:
                continue
            existing_pairs.add(pair_key)
            pairs.append((file1, file2, 0))

        return pairs

    def _random_pairs(self, files: List[str], num_pairs: int) -> List[Tuple[str, str, int]]:
        """Draw random disjoint positive pairs from ``files``.

        Args:
            files: Candidate entries; each is used in at most one pair.
            num_pairs: Requested number of pairs, capped at
                ``len(files) // 2``.

        Returns:
            A list of ``(a, b, 1)`` tuples, empty when fewer than two entries
            are available.

        Raises:
            ValueError: If ``num_pairs`` is negative (from ``random.sample``).
        """
        max_possible = len(files) // 2
        if max_possible == 0:
            return []

        actual_pairs = min(num_pairs, max_possible)
        indices = random.sample(range(len(files)), 2 * actual_pairs)
        # After sorting, each sampled entry is paired with its nearest
        # sampled neighbour in list order.
        indices.sort()
        # Index through the sampled positions, not the loop counter.
        return [
            (files[indices[k]], files[indices[k + 1]], 1)
            for k in range(0, len(indices), 2)
        ]

    def get_pairs(self,
                  root_dir: str,
                  num_pairs: int = 6000,
                  output_txt: Optional[str] = None
                  ) -> List[Tuple[str, str, int]]:
        """
        Generate individual image pairs with labels (1=same, 0=different).

        As many negative pairs are requested as positive pairs were found.

        Args:
            root_dir: Directory containing subfolders of images
            num_pairs: Maximum number of positive pairs to generate
            output_txt: Optional path to save pairs as txt file, one
                ``path1 path2 label`` line per pair

        Returns:
            List of (path1, path2, label) tuples, positives first

        Raises:
            ValueError: If ``root_dir`` is not a directory.
        """
        files_dict = self._get_image_files(root_dir)

        same_pairs = self._generate_same_pairs(files_dict, num_pairs, min_size=30)
        cross_pairs = self._generate_cross_pairs(files_dict, len(same_pairs))

        pairs = same_pairs + cross_pairs
        self.logger.info(f"Generated {len(pairs)} pairs ({len(same_pairs)} positive, {len(cross_pairs)} negative)")

        if output_txt:
            # A failed write is reported but does not discard the result.
            try:
                with open(output_txt, 'w', encoding='utf-8') as f:
                    for file1, file2, label in pairs:
                        f.write(f"{file1} {file2} {label}\n")
                self.logger.info(f"Saved pairs to {output_txt}")
            except IOError as e:
                self.logger.warning(f"Failed to write pairs to {output_txt}: {str(e)}")

        return pairs

    def get_group_pairs(
        self,
        root_dir: str,
        img_num: int = 20,
        group_num: int = 10,
        num_pairs: int = 5000,
        output_json: Optional[str] = "cross_same.json"
    ) -> List[Tuple[str, str, int]]:
        """
        Generate grouped image pairs with labels (1=same, 0=different).

        Each folder's files are shuffled and split into groups of
        ``group_num`` paths; the elements of every returned pair are such
        groups (lists of paths), not single paths.

        Args:
            root_dir: Directory containing subfolders of images
            img_num: Minimum images required per folder
            group_num: Number of images per group
            num_pairs: Maximum number of positive pairs to generate
            output_json: Path the pairs are dumped to as JSON; pass ``None``
                to skip writing.  Defaults to ``cross_same.json`` in the
                current working directory.

        Returns:
            List of (group1, group2, label) tuples, positives first

        Raises:
            ValueError: If ``root_dir`` is not a directory.
        """
        # Filter folders with enough images
        files_dict = {
            k: v for k, v in self._get_image_files(root_dir).items()
            if len(v) >= img_num
        }

        # Split into groups
        grouped_files = {}
        for folder, files in files_dict.items():
            random.shuffle(files)
            grouped_files[folder] = [
                files[i:i + group_num]
                for i in range(0, len(files), group_num)
            ]

        # Generate pairs
        same_pairs = self._generate_same_pairs(
            grouped_files, num_pairs, min_size=30, group_size=group_num
        )
        cross_pairs = self._generate_cross_pairs(
            grouped_files, len(same_pairs)
        )

        pairs = same_pairs + cross_pairs
        self.logger.info(f"Generated {len(pairs)} group pairs")

        # Save to JSON
        if output_json:
            with open(output_json, 'w') as f:
                json.dump(pairs, f)

        return pairs
|
|
|
|
|
|
if __name__ == "__main__":
    # Directory containing the sub-folders of images to pair up.
    original_path = '/home/lc/data_center/contrast_data/v1/extra'

    # Constructing the generator sanitises names below original_path in place.
    generator = PairGenerator(original_path)

    # The pair list is written next to the dataset directory, in its parent.
    parent_dir = str(Path(original_path).parent)
    pairs_file = os.sep.join((parent_dir, 'extra_cross_same.txt'))

    # Individual pairs.
    pairs = generator.get_pairs(original_path, output_txt=pairs_file)

    # Group pairs (alternative usage):
    # groups = generator.get_group_pairs('val')
|