This commit is contained in:
lee
2025-06-13 13:22:41 +08:00
parent e27e6c3d5b
commit 180a41ae90
2 changed files with 56 additions and 29 deletions

View File

@ -8,13 +8,13 @@ base:
log_level: "info" # 日志级别debug/info/warning/error log_level: "info" # 日志级别debug/info/warning/error
embedding_size: 256 # 特征维度 embedding_size: 256 # 特征维度
pin_memory: true # 是否启用pin_memory pin_memory: true # 是否启用pin_memory
distributed: false # 是否启用分布式训练 distributed: true # 是否启用分布式训练
# 模型配置 # 模型配置
models: models:
backbone: 'resnet18' backbone: 'resnet18'
channel_ratio: 0.75 channel_ratio: 0.75
model_path: "./checkpoints/resnet18_0515/best.pth" model_path: "./checkpoints/resnet18_1009/best.pth"
half: false # 是否启用半精度测试fp16 half: false # 是否启用半精度测试fp16
# 数据配置 # 数据配置
@ -24,7 +24,7 @@ data:
num_workers: 32 # 数据加载线程数 num_workers: 32 # 数据加载线程数
test_dir: "../data_center/contrast_learning/contrast_test_data" # 验证数据集根目录 test_dir: "../data_center/contrast_learning/contrast_test_data" # 验证数据集根目录
test_group_json: "../data_center/contrast_learning/model_test_data/test/inner_group_pairs.json" test_group_json: "../data_center/contrast_learning/model_test_data/test/inner_group_pairs.json"
test_list: "../data_center/contrast_learning/contrast_test_data/test_pair.txt" test_list: "../data_center/contrast_learning/contrast_test_data/cross_same.txt"
transform: transform:
img_size: 224 # 图像尺寸 img_size: 224 # 图像尺寸

View File

@ -5,6 +5,7 @@ from pathlib import Path
from typing import List, Tuple, Dict, Optional from typing import List, Tuple, Dict, Optional
import logging import logging
class PairGenerator: class PairGenerator:
"""Generate positive and negative image pairs for contrastive learning.""" """Generate positive and negative image pairs for contrastive learning."""
@ -31,10 +32,11 @@ class PairGenerator:
} }
def _generate_same_pairs( def _generate_same_pairs(
self, self,
files_dict: Dict[str, List[str]], files_dict: Dict[str, List[str]],
num_pairs: int, num_pairs: int,
group_size: Optional[int] = None min_size: int, # min_size is the minimum number of images per folder
group_size: Optional[int] = None
) -> List[Tuple[str, str, int]]: ) -> List[Tuple[str, str, int]]:
"""Generate positive pairs from same folder.""" """Generate positive pairs from same folder."""
pairs = [] pairs = []
@ -46,16 +48,16 @@ class PairGenerator:
if group_size: if group_size:
# Group mode: generate all possible pairs within group # Group mode: generate all possible pairs within group
for i in range(0, len(files), group_size): for i in range(0, len(files), group_size):
group = files[i:i+group_size] group = files[i:i + group_size]
pairs.extend([ pairs.extend([
(group[i], group[j], 1) (group[i], group[j], 1)
for i in range(len(group)) for i in range(len(group))
for j in range(i+1, len(group)) for j in range(i + 1, len(group))
]) ])
else: else:
# Individual mode: random pairs # Individual mode: random pairs
try: try:
pairs.extend(self._random_pairs(files, min(3, len(files)//2))) pairs.extend(self._random_pairs(files, min(min_size, len(files) // 2)))
except ValueError as e: except ValueError as e:
self.logger.warning(f"Skipping folder {folder}: {str(e)}") self.logger.warning(f"Skipping folder {folder}: {str(e)}")
@ -63,60 +65,82 @@ class PairGenerator:
return pairs[:num_pairs] return pairs[:num_pairs]
def _generate_cross_pairs( def _generate_cross_pairs(
self, self,
files_dict: Dict[str, List[str]], files_dict: Dict[str, List[str]],
num_pairs: int num_pairs: int
) -> List[Tuple[str, str, int]]: ) -> List[Tuple[str, str, int]]:
"""Generate negative pairs from different folders.""" """Generate negative pairs from different folders."""
folders = list(files_dict.keys()) folders = list(files_dict.keys())
pairs = [] pairs = []
existing_pairs = set()
while len(pairs) < num_pairs and len(folders) >= 2: while len(pairs) < num_pairs and len(folders) >= 2:
folder1, folder2 = random.sample(folders, 2) folder1, folder2 = random.sample(folders, 2)
file1 = random.choice(files_dict[folder1]) file1 = random.choice(files_dict[folder1])
file2 = random.choice(files_dict[folder2]) file2 = random.choice(files_dict[folder2])
if not any((f1 == file1 and f2 == file2) or (f1 == file2 and f2 == file1) pair_key = (file1, file2)
for f1, f2, _ in pairs): reverse_key = (file2, file1)
if pair_key not in existing_pairs and reverse_key not in existing_pairs:
pairs.append((file1, file2, 0)) pairs.append((file1, file2, 0))
existing_pairs.add(pair_key)
existing_pairs.add(reverse_key)
return pairs return pairs
def _random_pairs(self, files: List[str], num_pairs: int) -> List[Tuple[str, str, int]]: def _random_pairs(self, files: List[str], num_pairs: int) -> List[Tuple[str, str, int]]:
"""Generate random pairs from file list.""" """Generate random pairs from file list."""
if len(files) < 2 * num_pairs: max_possible = len(files) // 2
raise ValueError("Not enough files for requested pairs") if max_possible == 0:
return []
indices = random.sample(range(len(files)), 2 * num_pairs) actual_pairs = min(num_pairs, max_possible)
indices = random.sample(range(len(files)), 2 * actual_pairs)
indices.sort() indices.sort()
return [(files[i], files[i+1], 1) for i in range(0, len(indices), 2)] return [(files[i], files[i + 1], 1) for i in range(0, len(indices), 2)]
def get_pairs(self,
              root_dir: str,
              num_pairs: int = 6000,
              output_txt: Optional[str] = None
              ) -> List[Tuple[str, str, int]]:
    """
    Generate individual image pairs with labels (1=same, 0=different).

    Positive pairs are generated first; the same number of negative pairs
    is then requested so the two classes are balanced.

    Args:
        root_dir: Directory containing subfolders of images
        num_pairs: Number of positive pairs to request
        output_txt: Optional path to save pairs as a txt file, one
            "path1 path2 label" row per line

    Returns:
        List of (path1, path2, label) tuples
    """
    files_dict = self._get_image_files(root_dir)

    same_pairs = self._generate_same_pairs(files_dict, num_pairs, min_size=30)
    cross_pairs = self._generate_cross_pairs(files_dict, len(same_pairs))

    pairs = same_pairs + cross_pairs
    self.logger.info(f"Generated {len(pairs)} pairs ({len(same_pairs)} positive, {len(cross_pairs)} negative)")

    if output_txt:
        try:
            # Explicit UTF-8 keeps non-ASCII paths readable regardless of
            # the platform's default locale encoding.
            with open(output_txt, 'w', encoding='utf-8') as f:
                f.writelines(
                    f"{file1} {file2} {label}\n"
                    for file1, file2, label in pairs
                )
        except OSError as e:
            # Saving is best effort: the pairs are still returned.
            self.logger.warning(f"Failed to write pairs to {output_txt}: {str(e)}")
        else:
            self.logger.info(f"Saved pairs to {output_txt}")

    return pairs
def get_group_pairs( def get_group_pairs(
self, self,
root_dir: str, root_dir: str,
img_num: int = 20, img_num: int = 20,
group_num: int = 10, group_num: int = 10,
num_pairs: int = 5000 num_pairs: int = 5000
) -> List[Tuple[str, str, int]]: ) -> List[Tuple[str, str, int]]:
""" """
Generate grouped image pairs with labels (1=same, 0=different). Generate grouped image pairs with labels (1=same, 0=different).
@ -141,13 +165,13 @@ class PairGenerator:
for folder, files in files_dict.items(): for folder, files in files_dict.items():
random.shuffle(files) random.shuffle(files)
grouped_files[folder] = [ grouped_files[folder] = [
files[i:i+group_num] files[i:i + group_num]
for i in range(0, len(files), group_num) for i in range(0, len(files), group_num)
] ]
# Generate pairs # Generate pairs
same_pairs = self._generate_same_pairs( same_pairs = self._generate_same_pairs(
grouped_files, num_pairs, group_size=group_num grouped_files, num_pairs, min_size=30, group_size=group_num
) )
cross_pairs = self._generate_cross_pairs( cross_pairs = self._generate_cross_pairs(
grouped_files, len(same_pairs) grouped_files, len(same_pairs)
@ -164,8 +188,11 @@ class PairGenerator:
if __name__ == "__main__":
    # Directory containing the subfolders of test images.
    original_path = '/home/lc/data_center/contrast_learning/contrast_test_data/test'
    # The pair list is saved next to the image directory; pathlib builds the
    # path so the script does not rely on the os module being imported.
    output_path = Path(original_path).parent / 'cross_same.txt'

    generator = PairGenerator()

    # Example usage:
    pairs = generator.get_pairs(original_path, output_txt=str(output_path))  # Individual pairs
    # groups = generator.get_group_pairs('val')  # Group pairs