更改
This commit is contained in:
@ -8,13 +8,13 @@ base:
|
|||||||
log_level: "info" # 日志级别(debug/info/warning/error)
|
log_level: "info" # 日志级别(debug/info/warning/error)
|
||||||
embedding_size: 256 # 特征维度
|
embedding_size: 256 # 特征维度
|
||||||
pin_memory: true # 是否启用pin_memory
|
pin_memory: true # 是否启用pin_memory
|
||||||
distributed: false # 是否启用分布式训练
|
distributed: true # 是否启用分布式训练
|
||||||
|
|
||||||
# 模型配置
|
# 模型配置
|
||||||
models:
|
models:
|
||||||
backbone: 'resnet18'
|
backbone: 'resnet18'
|
||||||
channel_ratio: 0.75
|
channel_ratio: 0.75
|
||||||
model_path: "./checkpoints/resnet18_0515/best.pth"
|
model_path: "./checkpoints/resnet18_1009/best.pth"
|
||||||
half: false # 是否启用半精度测试(fp16)
|
half: false # 是否启用半精度测试(fp16)
|
||||||
|
|
||||||
# 数据配置
|
# 数据配置
|
||||||
@ -24,7 +24,7 @@ data:
|
|||||||
num_workers: 32 # 数据加载线程数
|
num_workers: 32 # 数据加载线程数
|
||||||
test_dir: "../data_center/contrast_learning/contrast_test_data" # 验证数据集根目录
|
test_dir: "../data_center/contrast_learning/contrast_test_data" # 验证数据集根目录
|
||||||
test_group_json: "../data_center/contrast_learning/model_test_data/test/inner_group_pairs.json"
|
test_group_json: "../data_center/contrast_learning/model_test_data/test/inner_group_pairs.json"
|
||||||
test_list: "../data_center/contrast_learning/contrast_test_data/test_pair.txt"
|
test_list: "../data_center/contrast_learning/contrast_test_data/cross_same.txt"
|
||||||
|
|
||||||
transform:
|
transform:
|
||||||
img_size: 224 # 图像尺寸
|
img_size: 224 # 图像尺寸
|
||||||
|
@ -5,6 +5,7 @@ from pathlib import Path
|
|||||||
from typing import List, Tuple, Dict, Optional
|
from typing import List, Tuple, Dict, Optional
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
|
||||||
class PairGenerator:
|
class PairGenerator:
|
||||||
"""Generate positive and negative image pairs for contrastive learning."""
|
"""Generate positive and negative image pairs for contrastive learning."""
|
||||||
|
|
||||||
@ -31,10 +32,11 @@ class PairGenerator:
|
|||||||
}
|
}
|
||||||
|
|
||||||
def _generate_same_pairs(
|
def _generate_same_pairs(
|
||||||
self,
|
self,
|
||||||
files_dict: Dict[str, List[str]],
|
files_dict: Dict[str, List[str]],
|
||||||
num_pairs: int,
|
num_pairs: int,
|
||||||
group_size: Optional[int] = None
|
min_size: int, # min_size is the minimum number of images per folder
|
||||||
|
group_size: Optional[int] = None
|
||||||
) -> List[Tuple[str, str, int]]:
|
) -> List[Tuple[str, str, int]]:
|
||||||
"""Generate positive pairs from same folder."""
|
"""Generate positive pairs from same folder."""
|
||||||
pairs = []
|
pairs = []
|
||||||
@ -46,16 +48,16 @@ class PairGenerator:
|
|||||||
if group_size:
|
if group_size:
|
||||||
# Group mode: generate all possible pairs within group
|
# Group mode: generate all possible pairs within group
|
||||||
for i in range(0, len(files), group_size):
|
for i in range(0, len(files), group_size):
|
||||||
group = files[i:i+group_size]
|
group = files[i:i + group_size]
|
||||||
pairs.extend([
|
pairs.extend([
|
||||||
(group[i], group[j], 1)
|
(group[i], group[j], 1)
|
||||||
for i in range(len(group))
|
for i in range(len(group))
|
||||||
for j in range(i+1, len(group))
|
for j in range(i + 1, len(group))
|
||||||
])
|
])
|
||||||
else:
|
else:
|
||||||
# Individual mode: random pairs
|
# Individual mode: random pairs
|
||||||
try:
|
try:
|
||||||
pairs.extend(self._random_pairs(files, min(3, len(files)//2)))
|
pairs.extend(self._random_pairs(files, min(min_size, len(files) // 2)))
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
self.logger.warning(f"Skipping folder {folder}: {str(e)}")
|
self.logger.warning(f"Skipping folder {folder}: {str(e)}")
|
||||||
|
|
||||||
@ -63,60 +65,82 @@ class PairGenerator:
|
|||||||
return pairs[:num_pairs]
|
return pairs[:num_pairs]
|
||||||
|
|
||||||
def _generate_cross_pairs(
|
def _generate_cross_pairs(
|
||||||
self,
|
self,
|
||||||
files_dict: Dict[str, List[str]],
|
files_dict: Dict[str, List[str]],
|
||||||
num_pairs: int
|
num_pairs: int
|
||||||
) -> List[Tuple[str, str, int]]:
|
) -> List[Tuple[str, str, int]]:
|
||||||
"""Generate negative pairs from different folders."""
|
"""Generate negative pairs from different folders."""
|
||||||
folders = list(files_dict.keys())
|
folders = list(files_dict.keys())
|
||||||
pairs = []
|
pairs = []
|
||||||
|
existing_pairs = set()
|
||||||
|
|
||||||
while len(pairs) < num_pairs and len(folders) >= 2:
|
while len(pairs) < num_pairs and len(folders) >= 2:
|
||||||
folder1, folder2 = random.sample(folders, 2)
|
folder1, folder2 = random.sample(folders, 2)
|
||||||
file1 = random.choice(files_dict[folder1])
|
file1 = random.choice(files_dict[folder1])
|
||||||
file2 = random.choice(files_dict[folder2])
|
file2 = random.choice(files_dict[folder2])
|
||||||
|
|
||||||
if not any((f1 == file1 and f2 == file2) or (f1 == file2 and f2 == file1)
|
pair_key = (file1, file2)
|
||||||
for f1, f2, _ in pairs):
|
reverse_key = (file2, file1)
|
||||||
|
|
||||||
|
if pair_key not in existing_pairs and reverse_key not in existing_pairs:
|
||||||
pairs.append((file1, file2, 0))
|
pairs.append((file1, file2, 0))
|
||||||
|
existing_pairs.add(pair_key)
|
||||||
|
existing_pairs.add(reverse_key)
|
||||||
|
|
||||||
return pairs
|
return pairs
|
||||||
|
|
||||||
def _random_pairs(self, files: List[str], num_pairs: int) -> List[Tuple[str, str, int]]:
|
def _random_pairs(self, files: List[str], num_pairs: int) -> List[Tuple[str, str, int]]:
|
||||||
"""Generate random pairs from file list."""
|
"""Generate random pairs from file list."""
|
||||||
if len(files) < 2 * num_pairs:
|
max_possible = len(files) // 2
|
||||||
raise ValueError("Not enough files for requested pairs")
|
if max_possible == 0:
|
||||||
|
return []
|
||||||
|
|
||||||
indices = random.sample(range(len(files)), 2 * num_pairs)
|
actual_pairs = min(num_pairs, max_possible)
|
||||||
|
indices = random.sample(range(len(files)), 2 * actual_pairs)
|
||||||
indices.sort()
|
indices.sort()
|
||||||
return [(files[i], files[i+1], 1) for i in range(0, len(indices), 2)]
|
return [(files[i], files[i + 1], 1) for i in range(0, len(indices), 2)]
|
||||||
|
|
||||||
def get_pairs(self, root_dir: str, num_pairs: int = 2000) -> List[Tuple[str, str, int]]:
|
def get_pairs(self,
|
||||||
|
root_dir: str,
|
||||||
|
num_pairs: int = 6000,
|
||||||
|
output_txt: Optional[str] = None
|
||||||
|
) -> List[Tuple[str, str, int]]:
|
||||||
"""
|
"""
|
||||||
Generate individual image pairs with labels (1=same, 0=different).
|
Generate individual image pairs with labels (1=same, 0=different).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
root_dir: Directory containing subfolders of images
|
root_dir: Directory containing subfolders of images
|
||||||
num_pairs: Number of pairs to generate
|
num_pairs: Number of pairs to generate
|
||||||
|
output_txt: Optional path to save pairs as txt file
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of (path1, path2, label) tuples
|
List of (path1, path2, label) tuples
|
||||||
"""
|
"""
|
||||||
files_dict = self._get_image_files(root_dir)
|
files_dict = self._get_image_files(root_dir)
|
||||||
|
|
||||||
same_pairs = self._generate_same_pairs(files_dict, num_pairs)
|
same_pairs = self._generate_same_pairs(files_dict, num_pairs, min_size=30)
|
||||||
cross_pairs = self._generate_cross_pairs(files_dict, len(same_pairs))
|
cross_pairs = self._generate_cross_pairs(files_dict, len(same_pairs))
|
||||||
|
|
||||||
pairs = same_pairs + cross_pairs
|
pairs = same_pairs + cross_pairs
|
||||||
self.logger.info(f"Generated {len(pairs)} pairs ({len(same_pairs)} positive, {len(cross_pairs)} negative)")
|
self.logger.info(f"Generated {len(pairs)} pairs ({len(same_pairs)} positive, {len(cross_pairs)} negative)")
|
||||||
|
|
||||||
|
if output_txt:
|
||||||
|
try:
|
||||||
|
with open(output_txt, 'w') as f:
|
||||||
|
for file1, file2, label in pairs:
|
||||||
|
f.write(f"{file1} {file2} {label}\n")
|
||||||
|
self.logger.info(f"Saved pairs to {output_txt}")
|
||||||
|
except IOError as e:
|
||||||
|
self.logger.warning(f"Failed to write pairs to {output_txt}: {str(e)}")
|
||||||
|
|
||||||
return pairs
|
return pairs
|
||||||
|
|
||||||
def get_group_pairs(
|
def get_group_pairs(
|
||||||
self,
|
self,
|
||||||
root_dir: str,
|
root_dir: str,
|
||||||
img_num: int = 20,
|
img_num: int = 20,
|
||||||
group_num: int = 10,
|
group_num: int = 10,
|
||||||
num_pairs: int = 5000
|
num_pairs: int = 5000
|
||||||
) -> List[Tuple[str, str, int]]:
|
) -> List[Tuple[str, str, int]]:
|
||||||
"""
|
"""
|
||||||
Generate grouped image pairs with labels (1=same, 0=different).
|
Generate grouped image pairs with labels (1=same, 0=different).
|
||||||
@ -141,13 +165,13 @@ class PairGenerator:
|
|||||||
for folder, files in files_dict.items():
|
for folder, files in files_dict.items():
|
||||||
random.shuffle(files)
|
random.shuffle(files)
|
||||||
grouped_files[folder] = [
|
grouped_files[folder] = [
|
||||||
files[i:i+group_num]
|
files[i:i + group_num]
|
||||||
for i in range(0, len(files), group_num)
|
for i in range(0, len(files), group_num)
|
||||||
]
|
]
|
||||||
|
|
||||||
# Generate pairs
|
# Generate pairs
|
||||||
same_pairs = self._generate_same_pairs(
|
same_pairs = self._generate_same_pairs(
|
||||||
grouped_files, num_pairs, group_size=group_num
|
grouped_files, num_pairs, min_size=30, group_size=group_num
|
||||||
)
|
)
|
||||||
cross_pairs = self._generate_cross_pairs(
|
cross_pairs = self._generate_cross_pairs(
|
||||||
grouped_files, len(same_pairs)
|
grouped_files, len(same_pairs)
|
||||||
@ -164,8 +188,11 @@ class PairGenerator:
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
original_path = '/home/lc/data_center/contrast_learning/contrast_test_data/test'
|
||||||
|
parent_dir = str(Path(original_path).parent)
|
||||||
generator = PairGenerator()
|
generator = PairGenerator()
|
||||||
|
|
||||||
# Example usage:
|
# Example usage:
|
||||||
pairs = generator.get_pairs('/home/lc/contrast_nettest/data/contrast_test_data/test') # Individual pairs
|
pairs = generator.get_pairs(original_path,
|
||||||
|
output_txt=os.sep.join([parent_dir, 'cross_same.txt'])) # Individual pairs
|
||||||
# groups = generator.get_group_pairs('val') # Group pairs
|
# groups = generator.get_group_pairs('val') # Group pairs
|
||||||
|
Reference in New Issue
Block a user