更改
This commit is contained in:
@ -8,13 +8,13 @@ base:
|
||||
log_level: "info" # 日志级别(debug/info/warning/error)
|
||||
embedding_size: 256 # 特征维度
|
||||
pin_memory: true # 是否启用pin_memory
|
||||
distributed: false # 是否启用分布式训练
|
||||
distributed: true # 是否启用分布式训练
|
||||
|
||||
# 模型配置
|
||||
models:
|
||||
backbone: 'resnet18'
|
||||
channel_ratio: 0.75
|
||||
model_path: "./checkpoints/resnet18_0515/best.pth"
|
||||
model_path: "./checkpoints/resnet18_1009/best.pth"
|
||||
half: false # 是否启用半精度测试(fp16)
|
||||
|
||||
# 数据配置
|
||||
@ -24,7 +24,7 @@ data:
|
||||
num_workers: 32 # 数据加载线程数
|
||||
test_dir: "../data_center/contrast_learning/contrast_test_data" # 验证数据集根目录
|
||||
test_group_json: "../data_center/contrast_learning/model_test_data/test/inner_group_pairs.json"
|
||||
test_list: "../data_center/contrast_learning/contrast_test_data/test_pair.txt"
|
||||
test_list: "../data_center/contrast_learning/contrast_test_data/cross_same.txt"
|
||||
|
||||
transform:
|
||||
img_size: 224 # 图像尺寸
|
||||
|
@ -5,6 +5,7 @@ from pathlib import Path
|
||||
from typing import List, Tuple, Dict, Optional
|
||||
import logging
|
||||
|
||||
|
||||
class PairGenerator:
|
||||
"""Generate positive and negative image pairs for contrastive learning."""
|
||||
|
||||
@ -34,6 +35,7 @@ class PairGenerator:
|
||||
self,
|
||||
files_dict: Dict[str, List[str]],
|
||||
num_pairs: int,
|
||||
min_size: int, # min_size is the minimum number of images per folder
|
||||
group_size: Optional[int] = None
|
||||
) -> List[Tuple[str, str, int]]:
|
||||
"""Generate positive pairs from same folder."""
|
||||
@ -46,16 +48,16 @@ class PairGenerator:
|
||||
if group_size:
|
||||
# Group mode: generate all possible pairs within group
|
||||
for i in range(0, len(files), group_size):
|
||||
group = files[i:i+group_size]
|
||||
group = files[i:i + group_size]
|
||||
pairs.extend([
|
||||
(group[i], group[j], 1)
|
||||
for i in range(len(group))
|
||||
for j in range(i+1, len(group))
|
||||
for j in range(i + 1, len(group))
|
||||
])
|
||||
else:
|
||||
# Individual mode: random pairs
|
||||
try:
|
||||
pairs.extend(self._random_pairs(files, min(3, len(files)//2)))
|
||||
pairs.extend(self._random_pairs(files, min(min_size, len(files) // 2)))
|
||||
except ValueError as e:
|
||||
self.logger.warning(f"Skipping folder {folder}: {str(e)}")
|
||||
|
||||
@ -70,45 +72,67 @@ class PairGenerator:
|
||||
"""Generate negative pairs from different folders."""
|
||||
folders = list(files_dict.keys())
|
||||
pairs = []
|
||||
existing_pairs = set()
|
||||
|
||||
while len(pairs) < num_pairs and len(folders) >= 2:
|
||||
folder1, folder2 = random.sample(folders, 2)
|
||||
file1 = random.choice(files_dict[folder1])
|
||||
file2 = random.choice(files_dict[folder2])
|
||||
|
||||
if not any((f1 == file1 and f2 == file2) or (f1 == file2 and f2 == file1)
|
||||
for f1, f2, _ in pairs):
|
||||
pair_key = (file1, file2)
|
||||
reverse_key = (file2, file1)
|
||||
|
||||
if pair_key not in existing_pairs and reverse_key not in existing_pairs:
|
||||
pairs.append((file1, file2, 0))
|
||||
existing_pairs.add(pair_key)
|
||||
existing_pairs.add(reverse_key)
|
||||
|
||||
return pairs
|
||||
|
||||
def _random_pairs(self, files: List[str], num_pairs: int) -> List[Tuple[str, str, int]]:
|
||||
"""Generate random pairs from file list."""
|
||||
if len(files) < 2 * num_pairs:
|
||||
raise ValueError("Not enough files for requested pairs")
|
||||
max_possible = len(files) // 2
|
||||
if max_possible == 0:
|
||||
return []
|
||||
|
||||
indices = random.sample(range(len(files)), 2 * num_pairs)
|
||||
actual_pairs = min(num_pairs, max_possible)
|
||||
indices = random.sample(range(len(files)), 2 * actual_pairs)
|
||||
indices.sort()
|
||||
return [(files[i], files[i+1], 1) for i in range(0, len(indices), 2)]
|
||||
return [(files[i], files[i + 1], 1) for i in range(0, len(indices), 2)]
|
||||
|
||||
def get_pairs(self,
              root_dir: str,
              num_pairs: int = 6000,
              output_txt: Optional[str] = None
              ) -> List[Tuple[str, str, int]]:
    """
    Generate individual image pairs with labels (1=same, 0=different).

    Positive pairs are produced first; the number of negative pairs
    requested equals the number of positive pairs actually produced.

    Args:
        root_dir: Directory containing subfolders of images
        num_pairs: Number of pairs to generate (passed to the positive-pair
            generator)
        output_txt: Optional path to save pairs as txt file, one
            "path1 path2 label" line per pair

    Returns:
        List of (path1, path2, label) tuples
    """
    files_dict = self._get_image_files(root_dir)

    same_pairs = self._generate_same_pairs(files_dict, num_pairs, min_size=30)
    cross_pairs = self._generate_cross_pairs(files_dict, len(same_pairs))

    pairs = same_pairs + cross_pairs
    self.logger.info(
        "Generated %d pairs (%d positive, %d negative)",
        len(pairs), len(same_pairs), len(cross_pairs),
    )

    if output_txt:
        try:
            # Explicit UTF-8 so non-ASCII paths do not depend on the
            # platform's default encoding.
            with open(output_txt, 'w', encoding='utf-8') as f:
                f.writelines(
                    f"{file1} {file2} {label}\n"
                    for file1, file2, label in pairs
                )
        except OSError as e:
            # Saving is best-effort: the pairs are still returned.
            self.logger.warning("Failed to write pairs to %s: %s", output_txt, e)
        else:
            self.logger.info("Saved pairs to %s", output_txt)

    return pairs
|
||||
|
||||
def get_group_pairs(
|
||||
@ -141,13 +165,13 @@ class PairGenerator:
|
||||
for folder, files in files_dict.items():
|
||||
random.shuffle(files)
|
||||
grouped_files[folder] = [
|
||||
files[i:i+group_num]
|
||||
files[i:i + group_num]
|
||||
for i in range(0, len(files), group_num)
|
||||
]
|
||||
|
||||
# Generate pairs
|
||||
same_pairs = self._generate_same_pairs(
|
||||
grouped_files, num_pairs, group_size=group_num
|
||||
grouped_files, num_pairs, min_size=30, group_size=group_num
|
||||
)
|
||||
cross_pairs = self._generate_cross_pairs(
|
||||
grouped_files, len(same_pairs)
|
||||
@ -164,8 +188,11 @@ class PairGenerator:
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Machine-specific root folder that is scanned for image subfolders.
    original_path = Path('/home/lc/data_center/contrast_learning/contrast_test_data/test')
    generator = PairGenerator()

    # Example usage:
    # Individual pairs, also saved as cross_same.txt in the parent folder
    # of the data root.
    pairs = generator.get_pairs(str(original_path),
                                output_txt=str(original_path.parent / 'cross_same.txt'))
    # groups = generator.get_group_pairs('val')  # Group pairs
|
||||
|
Reference in New Issue
Block a user