diff --git a/configs/test.yml b/configs/test.yml index 75649a1..0d94973 100644 --- a/configs/test.yml +++ b/configs/test.yml @@ -8,13 +8,13 @@ base: log_level: "info" # 日志级别(debug/info/warning/error) embedding_size: 256 # 特征维度 pin_memory: true # 是否启用pin_memory - distributed: false # 是否启用分布式训练 + distributed: true # 是否启用分布式训练 # 模型配置 models: backbone: 'resnet18' channel_ratio: 0.75 - model_path: "./checkpoints/resnet18_0515/best.pth" + model_path: "./checkpoints/resnet18_1009/best.pth" half: false # 是否启用半精度测试(fp16) # 数据配置 @@ -24,7 +24,7 @@ data: num_workers: 32 # 数据加载线程数 test_dir: "../data_center/contrast_learning/contrast_test_data" # 验证数据集根目录 test_group_json: "../data_center/contrast_learning/model_test_data/test/inner_group_pairs.json" - test_list: "../data_center/contrast_learning/contrast_test_data/test_pair.txt" + test_list: "../data_center/contrast_learning/contrast_test_data/cross_same.txt" transform: img_size: 224 # 图像尺寸 diff --git a/tools/getpairs.py b/tools/getpairs.py index 2e16f47..b35eb88 100644 --- a/tools/getpairs.py +++ b/tools/getpairs.py @@ -5,6 +5,7 @@ from pathlib import Path from typing import List, Tuple, Dict, Optional import logging + class PairGenerator: """Generate positive and negative image pairs for contrastive learning.""" @@ -31,10 +32,11 @@ class PairGenerator: } def _generate_same_pairs( - self, - files_dict: Dict[str, List[str]], - num_pairs: int, - group_size: Optional[int] = None + self, + files_dict: Dict[str, List[str]], + num_pairs: int, + min_size: int, # min_size is the minimum number of images per folder + group_size: Optional[int] = None ) -> List[Tuple[str, str, int]]: """Generate positive pairs from same folder.""" pairs = [] @@ -46,16 +48,16 @@ class PairGenerator: if group_size: # Group mode: generate all possible pairs within group for i in range(0, len(files), group_size): - group = files[i:i+group_size] + group = files[i:i + group_size] pairs.extend([ (group[i], group[j], 1) for i in range(len(group)) - for j in range(i+1, len(group)) + for j in range(i + 1, len(group)) ]) else: # Individual mode: random pairs try: - pairs.extend(self._random_pairs(files, min(3, len(files)//2))) + pairs.extend(self._random_pairs(files, min(min_size, len(files) // 2))) except ValueError as e: self.logger.warning(f"Skipping folder {folder}: {str(e)}") @@ -63,60 +65,82 @@ class PairGenerator: return pairs[:num_pairs] def _generate_cross_pairs( - self, - files_dict: Dict[str, List[str]], - num_pairs: int + self, + files_dict: Dict[str, List[str]], + num_pairs: int ) -> List[Tuple[str, str, int]]: """Generate negative pairs from different folders.""" folders = list(files_dict.keys()) pairs = [] + existing_pairs = set() while len(pairs) < num_pairs and len(folders) >= 2: folder1, folder2 = random.sample(folders, 2) file1 = random.choice(files_dict[folder1]) file2 = random.choice(files_dict[folder2]) - if not any((f1 == file1 and f2 == file2) or (f1 == file2 and f2 == file1) - for f1, f2, _ in pairs): + pair_key = (file1, file2) + reverse_key = (file2, file1) + + if pair_key not in existing_pairs and reverse_key not in existing_pairs: pairs.append((file1, file2, 0)) + existing_pairs.add(pair_key) + existing_pairs.add(reverse_key) return pairs def _random_pairs(self, files: List[str], num_pairs: int) -> List[Tuple[str, str, int]]: """Generate random pairs from file list.""" - if len(files) < 2 * num_pairs: - raise ValueError("Not enough files for requested pairs") + max_possible = len(files) // 2 + if max_possible == 0: + return [] - indices = random.sample(range(len(files)), 2 * num_pairs) + actual_pairs = min(num_pairs, max_possible) + indices = random.sample(range(len(files)), 2 * actual_pairs) indices.sort() - return [(files[i], files[i+1], 1) for i in range(0, len(indices), 2)] + return [(files[i], files[i + 1], 1) for i in range(0, len(indices), 2)] - def get_pairs(self, root_dir: str, num_pairs: int = 2000) -> List[Tuple[str, str, int]]: + def get_pairs(self, + root_dir: str, + num_pairs: int = 6000, + output_txt: Optional[str] = None + ) -> List[Tuple[str, str, int]]: """ Generate individual image pairs with labels (1=same, 0=different). Args: root_dir: Directory containing subfolders of images num_pairs: Number of pairs to generate + output_txt: Optional path to save pairs as txt file Returns: List of (path1, path2, label) tuples """ files_dict = self._get_image_files(root_dir) - same_pairs = self._generate_same_pairs(files_dict, num_pairs) + same_pairs = self._generate_same_pairs(files_dict, num_pairs, min_size=30) cross_pairs = self._generate_cross_pairs(files_dict, len(same_pairs)) pairs = same_pairs + cross_pairs self.logger.info(f"Generated {len(pairs)} pairs ({len(same_pairs)} positive, {len(cross_pairs)} negative)") + + if output_txt: + try: + with open(output_txt, 'w') as f: + for file1, file2, label in pairs: + f.write(f"{file1} {file2} {label}\n") + self.logger.info(f"Saved pairs to {output_txt}") + except IOError as e: + self.logger.warning(f"Failed to write pairs to {output_txt}: {str(e)}") + return pairs def get_group_pairs( - self, - root_dir: str, - img_num: int = 20, - group_num: int = 10, - num_pairs: int = 5000 + self, + root_dir: str, + img_num: int = 20, + group_num: int = 10, + num_pairs: int = 5000 ) -> List[Tuple[str, str, int]]: """ Generate grouped image pairs with labels (1=same, 0=different). @@ -141,13 +165,13 @@ class PairGenerator: for folder, files in files_dict.items(): random.shuffle(files) grouped_files[folder] = [ - files[i:i+group_num] + files[i:i + group_num] for i in range(0, len(files), group_num) ] # Generate pairs same_pairs = self._generate_same_pairs( - grouped_files, num_pairs, group_size=group_num + grouped_files, num_pairs, min_size=30, group_size=group_num ) cross_pairs = self._generate_cross_pairs( grouped_files, len(same_pairs) @@ -164,8 +188,11 @@ class PairGenerator: if __name__ == "__main__": + original_path = '/home/lc/data_center/contrast_learning/contrast_test_data/test' + parent_dir = str(Path(original_path).parent) generator = PairGenerator() # Example usage: - pairs = generator.get_pairs('/home/lc/contrast_nettest/data/contrast_test_data/test') # Individual pairs + pairs = generator.get_pairs(original_path, + output_txt=os.sep.join([parent_dir, 'cross_same.txt'])) # Individual pairs # groups = generator.get_group_pairs('val') # Group pairs