This commit is contained in:
lee
2025-06-13 13:22:41 +08:00
parent e27e6c3d5b
commit 180a41ae90
2 changed files with 56 additions and 29 deletions

View File

@ -8,13 +8,13 @@ base:
log_level: "info" # 日志级别debug/info/warning/error
embedding_size: 256 # 特征维度
pin_memory: true # 是否启用pin_memory
distributed: false # 是否启用分布式训练
distributed: true # 是否启用分布式训练
# 模型配置
models:
backbone: 'resnet18'
channel_ratio: 0.75
model_path: "./checkpoints/resnet18_0515/best.pth"
model_path: "./checkpoints/resnet18_1009/best.pth"
half: false # 是否启用半精度测试fp16
# 数据配置
@ -24,7 +24,7 @@ data:
num_workers: 32 # 数据加载线程数
test_dir: "../data_center/contrast_learning/contrast_test_data" # 验证数据集根目录
test_group_json: "../data_center/contrast_learning/model_test_data/test/inner_group_pairs.json"
test_list: "../data_center/contrast_learning/contrast_test_data/test_pair.txt"
test_list: "../data_center/contrast_learning/contrast_test_data/cross_same.txt"
transform:
img_size: 224 # 图像尺寸

View File

@ -5,6 +5,7 @@ from pathlib import Path
from typing import List, Tuple, Dict, Optional
import logging
class PairGenerator:
"""Generate positive and negative image pairs for contrastive learning."""
@ -34,6 +35,7 @@ class PairGenerator:
self,
files_dict: Dict[str, List[str]],
num_pairs: int,
min_size: int, # min_size is the minimum number of images per folder
group_size: Optional[int] = None
) -> List[Tuple[str, str, int]]:
"""Generate positive pairs from same folder."""
@ -55,7 +57,7 @@ class PairGenerator:
else:
# Individual mode: random pairs
try:
pairs.extend(self._random_pairs(files, min(3, len(files)//2)))
pairs.extend(self._random_pairs(files, min(min_size, len(files) // 2)))
except ValueError as e:
self.logger.warning(f"Skipping folder {folder}: {str(e)}")
@ -70,45 +72,67 @@ class PairGenerator:
"""Generate negative pairs from different folders."""
folders = list(files_dict.keys())
pairs = []
existing_pairs = set()
while len(pairs) < num_pairs and len(folders) >= 2:
folder1, folder2 = random.sample(folders, 2)
file1 = random.choice(files_dict[folder1])
file2 = random.choice(files_dict[folder2])
if not any((f1 == file1 and f2 == file2) or (f1 == file2 and f2 == file1)
for f1, f2, _ in pairs):
pair_key = (file1, file2)
reverse_key = (file2, file1)
if pair_key not in existing_pairs and reverse_key not in existing_pairs:
pairs.append((file1, file2, 0))
existing_pairs.add(pair_key)
existing_pairs.add(reverse_key)
return pairs
def _random_pairs(self, files: List[str], num_pairs: int) -> List[Tuple[str, str, int]]:
"""Generate random pairs from file list."""
if len(files) < 2 * num_pairs:
raise ValueError("Not enough files for requested pairs")
max_possible = len(files) // 2
if max_possible == 0:
return []
indices = random.sample(range(len(files)), 2 * num_pairs)
actual_pairs = min(num_pairs, max_possible)
indices = random.sample(range(len(files)), 2 * actual_pairs)
indices.sort()
return [(files[i], files[i + 1], 1) for i in range(0, len(indices), 2)]
def get_pairs(self, root_dir: str, num_pairs: int = 2000) -> List[Tuple[str, str, int]]:
def get_pairs(self,
root_dir: str,
num_pairs: int = 6000,
output_txt: Optional[str] = None
) -> List[Tuple[str, str, int]]:
"""
Generate individual image pairs with labels (1=same, 0=different).
Args:
root_dir: Directory containing subfolders of images
num_pairs: Number of pairs to generate
output_txt: Optional path to save pairs as txt file
Returns:
List of (path1, path2, label) tuples
"""
files_dict = self._get_image_files(root_dir)
same_pairs = self._generate_same_pairs(files_dict, num_pairs)
same_pairs = self._generate_same_pairs(files_dict, num_pairs, min_size=30)
cross_pairs = self._generate_cross_pairs(files_dict, len(same_pairs))
pairs = same_pairs + cross_pairs
self.logger.info(f"Generated {len(pairs)} pairs ({len(same_pairs)} positive, {len(cross_pairs)} negative)")
if output_txt:
try:
with open(output_txt, 'w') as f:
for file1, file2, label in pairs:
f.write(f"{file1} {file2} {label}\n")
self.logger.info(f"Saved pairs to {output_txt}")
except IOError as e:
self.logger.warning(f"Failed to write pairs to {output_txt}: {str(e)}")
return pairs
def get_group_pairs(
@ -147,7 +171,7 @@ class PairGenerator:
# Generate pairs
same_pairs = self._generate_same_pairs(
grouped_files, num_pairs, group_size=group_num
grouped_files, num_pairs, min_size=30, group_size=group_num
)
cross_pairs = self._generate_cross_pairs(
grouped_files, len(same_pairs)
@ -164,8 +188,11 @@ class PairGenerator:
if __name__ == "__main__":
original_path = '/home/lc/data_center/contrast_learning/contrast_test_data/test'
parent_dir = str(Path(original_path).parent)
generator = PairGenerator()
# Example usage:
pairs = generator.get_pairs('/home/lc/contrast_nettest/data/contrast_test_data/test') # Individual pairs
pairs = generator.get_pairs(original_path,
output_txt=os.sep.join([parent_dir, 'cross_same.txt'])) # Individual pairs
# groups = generator.get_group_pairs('val') # Group pairs