增加学习率调度方式
This commit is contained in:
265
.idea/CopilotChatHistory.xml
generated
265
.idea/CopilotChatHistory.xml
generated
@ -3,6 +3,18 @@
|
||||
<component name="CopilotChatHistory">
|
||||
<option name="conversations">
|
||||
<list>
|
||||
<Conversation>
|
||||
<option name="createTime" value="1749718122230" />
|
||||
<option name="id" value="01976353bef6703884544447c919013c" />
|
||||
<option name="title" value="新对话 2025年6月12日 16:48:42" />
|
||||
<option name="updateTime" value="1749718122230" />
|
||||
</Conversation>
|
||||
<Conversation>
|
||||
<option name="createTime" value="1749648208122" />
|
||||
<option name="id" value="01975f28f0fa7128afe7feddcdedb740" />
|
||||
<option name="title" value="新对话 2025年6月11日 21:23:28" />
|
||||
<option name="updateTime" value="1749648208122" />
|
||||
</Conversation>
|
||||
<Conversation>
|
||||
<option name="createTime" value="1749522765718" />
|
||||
<option name="id" value="019757aed78e777c96c4b7007ff2fecc" />
|
||||
@ -57,16 +69,7 @@
|
||||
</option>
|
||||
<option name="status" value="SUCCESS" />
|
||||
<option name="variables">
|
||||
<list>
|
||||
<CodebaseVariable>
|
||||
<option name="selectedPlaceHolder">
|
||||
<Object />
|
||||
</option>
|
||||
<option name="selectedVariable">
|
||||
<Object />
|
||||
</option>
|
||||
</CodebaseVariable>
|
||||
</list>
|
||||
<list />
|
||||
</option>
|
||||
</Turn>
|
||||
<Turn>
|
||||
@ -91,16 +94,7 @@
|
||||
</option>
|
||||
<option name="status" value="SUCCESS" />
|
||||
<option name="variables">
|
||||
<list>
|
||||
<CodebaseVariable>
|
||||
<option name="selectedPlaceHolder">
|
||||
<Object />
|
||||
</option>
|
||||
<option name="selectedVariable">
|
||||
<Object />
|
||||
</option>
|
||||
</CodebaseVariable>
|
||||
</list>
|
||||
<list />
|
||||
</option>
|
||||
</Turn>
|
||||
</list>
|
||||
@ -135,16 +129,7 @@
|
||||
</option>
|
||||
<option name="status" value="SUCCESS" />
|
||||
<option name="variables">
|
||||
<list>
|
||||
<CodebaseVariable>
|
||||
<option name="selectedPlaceHolder">
|
||||
<Object />
|
||||
</option>
|
||||
<option name="selectedVariable">
|
||||
<Object />
|
||||
</option>
|
||||
</CodebaseVariable>
|
||||
</list>
|
||||
<list />
|
||||
</option>
|
||||
</Turn>
|
||||
<Turn>
|
||||
@ -169,16 +154,7 @@
|
||||
</option>
|
||||
<option name="status" value="SUCCESS" />
|
||||
<option name="variables">
|
||||
<list>
|
||||
<CodebaseVariable>
|
||||
<option name="selectedPlaceHolder">
|
||||
<Object />
|
||||
</option>
|
||||
<option name="selectedVariable">
|
||||
<Object />
|
||||
</option>
|
||||
</CodebaseVariable>
|
||||
</list>
|
||||
<list />
|
||||
</option>
|
||||
</Turn>
|
||||
<Turn>
|
||||
@ -203,16 +179,7 @@
|
||||
</option>
|
||||
<option name="status" value="SUCCESS" />
|
||||
<option name="variables">
|
||||
<list>
|
||||
<CodebaseVariable>
|
||||
<option name="selectedPlaceHolder">
|
||||
<Object />
|
||||
</option>
|
||||
<option name="selectedVariable">
|
||||
<Object />
|
||||
</option>
|
||||
</CodebaseVariable>
|
||||
</list>
|
||||
<list />
|
||||
</option>
|
||||
</Turn>
|
||||
<Turn>
|
||||
@ -237,16 +204,7 @@
|
||||
</option>
|
||||
<option name="status" value="SUCCESS" />
|
||||
<option name="variables">
|
||||
<list>
|
||||
<CodebaseVariable>
|
||||
<option name="selectedPlaceHolder">
|
||||
<Object />
|
||||
</option>
|
||||
<option name="selectedVariable">
|
||||
<Object />
|
||||
</option>
|
||||
</CodebaseVariable>
|
||||
</list>
|
||||
<list />
|
||||
</option>
|
||||
</Turn>
|
||||
<Turn>
|
||||
@ -271,16 +229,7 @@
|
||||
</option>
|
||||
<option name="status" value="SUCCESS" />
|
||||
<option name="variables">
|
||||
<list>
|
||||
<CodebaseVariable>
|
||||
<option name="selectedPlaceHolder">
|
||||
<Object />
|
||||
</option>
|
||||
<option name="selectedVariable">
|
||||
<Object />
|
||||
</option>
|
||||
</CodebaseVariable>
|
||||
</list>
|
||||
<list />
|
||||
</option>
|
||||
</Turn>
|
||||
<Turn>
|
||||
@ -305,16 +254,7 @@
|
||||
</option>
|
||||
<option name="status" value="SUCCESS" />
|
||||
<option name="variables">
|
||||
<list>
|
||||
<CodebaseVariable>
|
||||
<option name="selectedPlaceHolder">
|
||||
<Object />
|
||||
</option>
|
||||
<option name="selectedVariable">
|
||||
<Object />
|
||||
</option>
|
||||
</CodebaseVariable>
|
||||
</list>
|
||||
<list />
|
||||
</option>
|
||||
</Turn>
|
||||
<Turn>
|
||||
@ -339,16 +279,7 @@
|
||||
</option>
|
||||
<option name="status" value="SUCCESS" />
|
||||
<option name="variables">
|
||||
<list>
|
||||
<CodebaseVariable>
|
||||
<option name="selectedPlaceHolder">
|
||||
<Object />
|
||||
</option>
|
||||
<option name="selectedVariable">
|
||||
<Object />
|
||||
</option>
|
||||
</CodebaseVariable>
|
||||
</list>
|
||||
<list />
|
||||
</option>
|
||||
</Turn>
|
||||
<Turn>
|
||||
@ -373,16 +304,7 @@
|
||||
</option>
|
||||
<option name="status" value="SUCCESS" />
|
||||
<option name="variables">
|
||||
<list>
|
||||
<CodebaseVariable>
|
||||
<option name="selectedPlaceHolder">
|
||||
<Object />
|
||||
</option>
|
||||
<option name="selectedVariable">
|
||||
<Object />
|
||||
</option>
|
||||
</CodebaseVariable>
|
||||
</list>
|
||||
<list />
|
||||
</option>
|
||||
</Turn>
|
||||
<Turn>
|
||||
@ -407,16 +329,7 @@
|
||||
</option>
|
||||
<option name="status" value="SUCCESS" />
|
||||
<option name="variables">
|
||||
<list>
|
||||
<CodebaseVariable>
|
||||
<option name="selectedPlaceHolder">
|
||||
<Object />
|
||||
</option>
|
||||
<option name="selectedVariable">
|
||||
<Object />
|
||||
</option>
|
||||
</CodebaseVariable>
|
||||
</list>
|
||||
<list />
|
||||
</option>
|
||||
</Turn>
|
||||
<Turn>
|
||||
@ -441,16 +354,7 @@
|
||||
</option>
|
||||
<option name="status" value="SUCCESS" />
|
||||
<option name="variables">
|
||||
<list>
|
||||
<CodebaseVariable>
|
||||
<option name="selectedPlaceHolder">
|
||||
<Object />
|
||||
</option>
|
||||
<option name="selectedVariable">
|
||||
<Object />
|
||||
</option>
|
||||
</CodebaseVariable>
|
||||
</list>
|
||||
<list />
|
||||
</option>
|
||||
</Turn>
|
||||
<Turn>
|
||||
@ -475,16 +379,7 @@
|
||||
</option>
|
||||
<option name="status" value="SUCCESS" />
|
||||
<option name="variables">
|
||||
<list>
|
||||
<CodebaseVariable>
|
||||
<option name="selectedPlaceHolder">
|
||||
<Object />
|
||||
</option>
|
||||
<option name="selectedVariable">
|
||||
<Object />
|
||||
</option>
|
||||
</CodebaseVariable>
|
||||
</list>
|
||||
<list />
|
||||
</option>
|
||||
</Turn>
|
||||
<Turn>
|
||||
@ -509,16 +404,7 @@
|
||||
</option>
|
||||
<option name="status" value="SUCCESS" />
|
||||
<option name="variables">
|
||||
<list>
|
||||
<CodebaseVariable>
|
||||
<option name="selectedPlaceHolder">
|
||||
<Object />
|
||||
</option>
|
||||
<option name="selectedVariable">
|
||||
<Object />
|
||||
</option>
|
||||
</CodebaseVariable>
|
||||
</list>
|
||||
<list />
|
||||
</option>
|
||||
</Turn>
|
||||
<Turn>
|
||||
@ -543,16 +429,7 @@
|
||||
</option>
|
||||
<option name="status" value="SUCCESS" />
|
||||
<option name="variables">
|
||||
<list>
|
||||
<CodebaseVariable>
|
||||
<option name="selectedPlaceHolder">
|
||||
<Object />
|
||||
</option>
|
||||
<option name="selectedVariable">
|
||||
<Object />
|
||||
</option>
|
||||
</CodebaseVariable>
|
||||
</list>
|
||||
<list />
|
||||
</option>
|
||||
</Turn>
|
||||
<Turn>
|
||||
@ -577,16 +454,7 @@
|
||||
</option>
|
||||
<option name="status" value="SUCCESS" />
|
||||
<option name="variables">
|
||||
<list>
|
||||
<CodebaseVariable>
|
||||
<option name="selectedPlaceHolder">
|
||||
<Object />
|
||||
</option>
|
||||
<option name="selectedVariable">
|
||||
<Object />
|
||||
</option>
|
||||
</CodebaseVariable>
|
||||
</list>
|
||||
<list />
|
||||
</option>
|
||||
</Turn>
|
||||
<Turn>
|
||||
@ -611,16 +479,7 @@
|
||||
</option>
|
||||
<option name="status" value="SUCCESS" />
|
||||
<option name="variables">
|
||||
<list>
|
||||
<CodebaseVariable>
|
||||
<option name="selectedPlaceHolder">
|
||||
<Object />
|
||||
</option>
|
||||
<option name="selectedVariable">
|
||||
<Object />
|
||||
</option>
|
||||
</CodebaseVariable>
|
||||
</list>
|
||||
<list />
|
||||
</option>
|
||||
</Turn>
|
||||
<Turn>
|
||||
@ -645,16 +504,7 @@
|
||||
</option>
|
||||
<option name="status" value="SUCCESS" />
|
||||
<option name="variables">
|
||||
<list>
|
||||
<CodebaseVariable>
|
||||
<option name="selectedPlaceHolder">
|
||||
<Object />
|
||||
</option>
|
||||
<option name="selectedVariable">
|
||||
<Object />
|
||||
</option>
|
||||
</CodebaseVariable>
|
||||
</list>
|
||||
<list />
|
||||
</option>
|
||||
</Turn>
|
||||
<Turn>
|
||||
@ -679,16 +529,7 @@
|
||||
</option>
|
||||
<option name="status" value="SUCCESS" />
|
||||
<option name="variables">
|
||||
<list>
|
||||
<CodebaseVariable>
|
||||
<option name="selectedPlaceHolder">
|
||||
<Object />
|
||||
</option>
|
||||
<option name="selectedVariable">
|
||||
<Object />
|
||||
</option>
|
||||
</CodebaseVariable>
|
||||
</list>
|
||||
<list />
|
||||
</option>
|
||||
</Turn>
|
||||
<Turn>
|
||||
@ -713,16 +554,7 @@
|
||||
</option>
|
||||
<option name="status" value="SUCCESS" />
|
||||
<option name="variables">
|
||||
<list>
|
||||
<CodebaseVariable>
|
||||
<option name="selectedPlaceHolder">
|
||||
<Object />
|
||||
</option>
|
||||
<option name="selectedVariable">
|
||||
<Object />
|
||||
</option>
|
||||
</CodebaseVariable>
|
||||
</list>
|
||||
<list />
|
||||
</option>
|
||||
</Turn>
|
||||
<Turn>
|
||||
@ -773,16 +605,7 @@
|
||||
</option>
|
||||
<option name="status" value="SUCCESS" />
|
||||
<option name="variables">
|
||||
<list>
|
||||
<CodebaseVariable>
|
||||
<option name="selectedPlaceHolder">
|
||||
<Object />
|
||||
</option>
|
||||
<option name="selectedVariable">
|
||||
<Object />
|
||||
</option>
|
||||
</CodebaseVariable>
|
||||
</list>
|
||||
<list />
|
||||
</option>
|
||||
</Turn>
|
||||
<Turn>
|
||||
@ -807,16 +630,7 @@
|
||||
</option>
|
||||
<option name="status" value="SUCCESS" />
|
||||
<option name="variables">
|
||||
<list>
|
||||
<CodebaseVariable>
|
||||
<option name="selectedPlaceHolder">
|
||||
<Object />
|
||||
</option>
|
||||
<option name="selectedVariable">
|
||||
<Object />
|
||||
</option>
|
||||
</CodebaseVariable>
|
||||
</list>
|
||||
<list />
|
||||
</option>
|
||||
</Turn>
|
||||
<Turn>
|
||||
@ -841,16 +655,7 @@
|
||||
</option>
|
||||
<option name="status" value="SUCCESS" />
|
||||
<option name="variables">
|
||||
<list>
|
||||
<CodebaseVariable>
|
||||
<option name="selectedPlaceHolder">
|
||||
<Object />
|
||||
</option>
|
||||
<option name="selectedVariable">
|
||||
<Object />
|
||||
</option>
|
||||
</CodebaseVariable>
|
||||
</list>
|
||||
<list />
|
||||
</option>
|
||||
</Turn>
|
||||
</list>
|
||||
|
@ -15,8 +15,8 @@ base:
|
||||
|
||||
# 模型配置
|
||||
models:
|
||||
backbone: 'resnet18'
|
||||
channel_ratio: 0.75
|
||||
backbone: 'resnet34'
|
||||
channel_ratio: 1.0
|
||||
|
||||
# 训练参数
|
||||
training:
|
||||
@ -29,11 +29,14 @@ training:
|
||||
lr_step: 10 # 学习率调整间隔(epoch)
|
||||
lr_decay: 0.98 # 学习率衰减率
|
||||
weight_decay: 0.0005 # 权重衰减
|
||||
scheduler: "cosine_annealing" # 学习率调度器(可选:cosine_annealing/step/none)
|
||||
scheduler: "cosine" # 学习率调度器(可选:cosine/cosine_warm/step/None)
|
||||
num_workers: 32 # 数据加载线程数
|
||||
checkpoints: "./checkpoints/resnet18_test/" # 模型保存目录
|
||||
checkpoints: "./checkpoints/resnet34_20250612_scale=1.0/" # 模型保存目录
|
||||
restore: false
|
||||
restore_model: "resnet18_test/epoch_600.pth" # 模型恢复路径
|
||||
cosine_t_0: 10 # 初始周期长度
|
||||
cosine_t_mult: 1 # 周期长度倍率
|
||||
cosine_eta_min: 0.00001 # 最小学习率
|
||||
|
||||
# 验证参数
|
||||
validation:
|
||||
|
@ -8,13 +8,13 @@ base:
|
||||
log_level: "info" # 日志级别(debug/info/warning/error)
|
||||
embedding_size: 256 # 特征维度
|
||||
pin_memory: true # 是否启用pin_memory
|
||||
distributed: true # 是否启用分布式训练
|
||||
distributed: false # 是否启用分布式训练
|
||||
|
||||
# 模型配置
|
||||
models:
|
||||
backbone: 'resnet18'
|
||||
channel_ratio: 1.0
|
||||
model_path: "./checkpoints/resnet18_scatter_6.2/best.pth"
|
||||
channel_ratio: 0.75
|
||||
model_path: "./checkpoints/resnet18_0515/best.pth"
|
||||
half: false # 是否启用半精度测试(fp16)
|
||||
|
||||
# 数据配置
|
||||
@ -22,9 +22,9 @@ data:
|
||||
group_test: False # 数据集名称(示例用,可替换为实际数据集)
|
||||
test_batch_size: 128 # 训练批次大小
|
||||
num_workers: 32 # 数据加载线程数
|
||||
test_dir: "../data_center/scatter/" # 验证数据集根目录
|
||||
test_dir: "../data_center/contrast_learning/contrast_test_data" # 验证数据集根目录
|
||||
test_group_json: "../data_center/contrast_learning/model_test_data/test/inner_group_pairs.json"
|
||||
test_list: "../data_center/scatter/val_pair.txt"
|
||||
test_list: "../data_center/contrast_learning/contrast_test_data/test_pair.txt"
|
||||
|
||||
transform:
|
||||
img_size: 224 # 图像尺寸
|
||||
|
27
configs/transform.yml
Normal file
27
configs/transform.yml
Normal file
@ -0,0 +1,27 @@
|
||||
# configs/transform.yml
|
||||
# pth转换onnx配置文件
|
||||
|
||||
# 基础配置
|
||||
base:
|
||||
experiment_name: "model_comparison" # 实验名称(用于结果保存目录)
|
||||
seed: 42 # 随机种子(保证可复现性)
|
||||
device: "cuda" # 训练设备(cuda/cpu)
|
||||
log_level: "info" # 日志级别(debug/info/warning/error)
|
||||
embedding_size: 256 # 特征维度
|
||||
pin_memory: true # 是否启用pin_memory
|
||||
distributed: true # 是否启用分布式训练
|
||||
|
||||
|
||||
# 模型配置
|
||||
models:
|
||||
backbone: 'resnet50'
|
||||
channel_ratio: 1.0
|
||||
model_path: "../checkpoints/resnet50_0519/best.pth"
|
||||
onnx_model: "../checkpoints/resnet50_0519/best.onnx"
|
||||
rknn_model: "../checkpoints/resnet50_0519/best.rknn"
|
||||
|
||||
# 日志与监控
|
||||
logging:
|
||||
logging_dir: "./logs" # 日志保存目录
|
||||
tensorboard: true # 是否启用TensorBoard
|
||||
checkpoint_interval: 30 # 检查点保存间隔(epoch)
|
@ -1,4 +1,4 @@
|
||||
from model import (resnet18, mobilevit_s, MobileNetV3_Small, MobileNetV3_Large, mobilenet_v1,
|
||||
from model import (resnet18, resnet34, resnet50, mobilevit_s, MobileNetV3_Small, MobileNetV3_Large, mobilenet_v1,
|
||||
PPLCNET_x1_0, PPLCNET_x0_5, PPLCNET_x2_5)
|
||||
from timm.models import vit_base_patch16_224 as vit_base_16
|
||||
from model.metric import ArcFace, CosFace
|
||||
@ -14,6 +14,8 @@ class trainer_tools:
|
||||
def get_backbone(self):
|
||||
backbone_mapping = {
|
||||
'resnet18': lambda: resnet18(scale=self.conf['models']['channel_ratio']),
|
||||
'resnet34': lambda: resnet34(scale=self.conf['models']['channel_ratio']),
|
||||
'resnet50': lambda: resnet50(scale=self.conf['models']['channel_ratio']),
|
||||
'mobilevit_s': lambda: mobilevit_s(),
|
||||
'mobilenetv3_small': lambda: MobileNetV3_Small(),
|
||||
'PPLCNET_x1_0': lambda: PPLCNET_x1_0(),
|
||||
@ -54,3 +56,24 @@ class trainer_tools:
|
||||
)
|
||||
}
|
||||
return optimizer_mapping
|
||||
|
||||
def get_scheduler(self, optimizer):
|
||||
scheduler_mapping = {
|
||||
'step': lambda: optim.lr_scheduler.StepLR(
|
||||
optimizer,
|
||||
step_size=self.conf['training']['lr_step'],
|
||||
gamma=self.conf['training']['lr_decay']
|
||||
),
|
||||
'cosine': lambda: optim.lr_scheduler.CosineAnnealingLR(
|
||||
optimizer,
|
||||
T_max=self.conf['training']['epochs'],
|
||||
eta_min=self.conf['training']['cosine_eta_min']
|
||||
),
|
||||
'cosine_warm': lambda: optim.lr_scheduler.CosineAnnealingWarmRestarts(
|
||||
optimizer,
|
||||
T_0=self.conf['training'].get('cosine_t_0', 10),
|
||||
T_mult=self.conf['training'].get('cosine_t_mult', 1),
|
||||
eta_min=self.conf['training'].get('cosine_eta_min', 0)
|
||||
)
|
||||
}
|
||||
return scheduler_mapping
|
||||
|
171
getpairs.py
Normal file
171
getpairs.py
Normal file
@ -0,0 +1,171 @@
|
||||
import os
|
||||
import random
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple, Dict, Optional
|
||||
import logging
|
||||
|
||||
class PairGenerator:
|
||||
"""Generate positive and negative image pairs for contrastive learning."""
|
||||
|
||||
def __init__(self):
|
||||
self._setup_logging()
|
||||
|
||||
def _setup_logging(self):
|
||||
"""Configure logging settings."""
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
def _get_image_files(self, root_dir: str) -> Dict[str, List[str]]:
|
||||
"""Scan directory and return dict of {folder: [image_paths]}."""
|
||||
root = Path(root_dir)
|
||||
if not root.is_dir():
|
||||
raise ValueError(f"Invalid directory: {root_dir}")
|
||||
|
||||
return {
|
||||
str(folder): [str(f) for f in folder.iterdir() if f.is_file()]
|
||||
for folder in root.iterdir() if folder.is_dir()
|
||||
}
|
||||
|
||||
def _generate_same_pairs(
|
||||
self,
|
||||
files_dict: Dict[str, List[str]],
|
||||
num_pairs: int,
|
||||
group_size: Optional[int] = None
|
||||
) -> List[Tuple[str, str, int]]:
|
||||
"""Generate positive pairs from same folder."""
|
||||
pairs = []
|
||||
|
||||
for folder, files in files_dict.items():
|
||||
if len(files) < 2:
|
||||
continue
|
||||
|
||||
if group_size:
|
||||
# Group mode: generate all possible pairs within group
|
||||
for i in range(0, len(files), group_size):
|
||||
group = files[i:i+group_size]
|
||||
pairs.extend([
|
||||
(group[i], group[j], 1)
|
||||
for i in range(len(group))
|
||||
for j in range(i+1, len(group))
|
||||
])
|
||||
else:
|
||||
# Individual mode: random pairs
|
||||
try:
|
||||
pairs.extend(self._random_pairs(files, min(3, len(files)//2)))
|
||||
except ValueError as e:
|
||||
self.logger.warning(f"Skipping folder {folder}: {str(e)}")
|
||||
|
||||
random.shuffle(pairs)
|
||||
return pairs[:num_pairs]
|
||||
|
||||
def _generate_cross_pairs(
|
||||
self,
|
||||
files_dict: Dict[str, List[str]],
|
||||
num_pairs: int
|
||||
) -> List[Tuple[str, str, int]]:
|
||||
"""Generate negative pairs from different folders."""
|
||||
folders = list(files_dict.keys())
|
||||
pairs = []
|
||||
|
||||
while len(pairs) < num_pairs and len(folders) >= 2:
|
||||
folder1, folder2 = random.sample(folders, 2)
|
||||
file1 = random.choice(files_dict[folder1])
|
||||
file2 = random.choice(files_dict[folder2])
|
||||
|
||||
if not any((f1 == file1 and f2 == file2) or (f1 == file2 and f2 == file1)
|
||||
for f1, f2, _ in pairs):
|
||||
pairs.append((file1, file2, 0))
|
||||
|
||||
return pairs
|
||||
|
||||
def _random_pairs(self, files: List[str], num_pairs: int) -> List[Tuple[str, str, int]]:
|
||||
"""Generate random pairs from file list."""
|
||||
if len(files) < 2 * num_pairs:
|
||||
raise ValueError("Not enough files for requested pairs")
|
||||
|
||||
indices = random.sample(range(len(files)), 2 * num_pairs)
|
||||
indices.sort()
|
||||
return [(files[i], files[i+1], 1) for i in range(0, len(indices), 2)]
|
||||
|
||||
def get_pairs(self, root_dir: str, num_pairs: int = 2000) -> List[Tuple[str, str, int]]:
|
||||
"""
|
||||
Generate individual image pairs with labels (1=same, 0=different).
|
||||
|
||||
Args:
|
||||
root_dir: Directory containing subfolders of images
|
||||
num_pairs: Number of pairs to generate
|
||||
|
||||
Returns:
|
||||
List of (path1, path2, label) tuples
|
||||
"""
|
||||
files_dict = self._get_image_files(root_dir)
|
||||
|
||||
same_pairs = self._generate_same_pairs(files_dict, num_pairs)
|
||||
cross_pairs = self._generate_cross_pairs(files_dict, len(same_pairs))
|
||||
|
||||
pairs = same_pairs + cross_pairs
|
||||
self.logger.info(f"Generated {len(pairs)} pairs ({len(same_pairs)} positive, {len(cross_pairs)} negative)")
|
||||
return pairs
|
||||
|
||||
def get_group_pairs(
|
||||
self,
|
||||
root_dir: str,
|
||||
img_num: int = 20,
|
||||
group_num: int = 10,
|
||||
num_pairs: int = 5000
|
||||
) -> List[Tuple[str, str, int]]:
|
||||
"""
|
||||
Generate grouped image pairs with labels (1=same, 0=different).
|
||||
|
||||
Args:
|
||||
root_dir: Directory containing subfolders of images
|
||||
img_num: Minimum images required per folder
|
||||
group_num: Number of images per group
|
||||
num_pairs: Number of pairs to generate
|
||||
|
||||
Returns:
|
||||
List of (path1, path2, label) tuples
|
||||
"""
|
||||
# Filter folders with enough images
|
||||
files_dict = {
|
||||
k: v for k, v in self._get_image_files(root_dir).items()
|
||||
if len(v) >= img_num
|
||||
}
|
||||
|
||||
# Split into groups
|
||||
grouped_files = {}
|
||||
for folder, files in files_dict.items():
|
||||
random.shuffle(files)
|
||||
grouped_files[folder] = [
|
||||
files[i:i+group_num]
|
||||
for i in range(0, len(files), group_num)
|
||||
]
|
||||
|
||||
# Generate pairs
|
||||
same_pairs = self._generate_same_pairs(
|
||||
grouped_files, num_pairs, group_size=group_num
|
||||
)
|
||||
cross_pairs = self._generate_cross_pairs(
|
||||
grouped_files, len(same_pairs)
|
||||
)
|
||||
|
||||
pairs = same_pairs + cross_pairs
|
||||
self.logger.info(f"Generated {len(pairs)} group pairs")
|
||||
|
||||
# Save to JSON
|
||||
with open("cross_same.json", 'w') as f:
|
||||
json.dump(pairs, f)
|
||||
|
||||
return pairs
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
generator = PairGenerator()
|
||||
|
||||
# Example usage:
|
||||
pairs = generator.get_pairs('/home/lc/contrast_nettest/data/contrast_test_data/test') # Individual pairs
|
||||
# groups = generator.get_group_pairs('val') # Group pairs
|
@ -297,8 +297,8 @@ def init_model():
|
||||
first_param_dtype = next(model.parameters()).dtype
|
||||
print("模型的第一个参数的数据类型: {}".format(first_param_dtype))
|
||||
else:
|
||||
model.load_state_dict(torch.load(conf['model']['model_path'], map_location=conf['base']['device']))
|
||||
if conf.model_half:
|
||||
model.load_state_dict(torch.load(conf['models']['model_path'], map_location=conf['base']['device']))
|
||||
if conf['models']['half']:
|
||||
model.half()
|
||||
first_param_dtype = next(model.parameters()).dtype
|
||||
print("模型的第一个参数的数据类型: {}".format(first_param_dtype))
|
||||
|
@ -37,11 +37,11 @@ def load_data(training=True, cfg=None):
|
||||
if training:
|
||||
dataroot = cfg['data']['data_train_dir']
|
||||
transform = train_transform
|
||||
# transform = conf.train_transform
|
||||
# transform.yml = conf.train_transform
|
||||
batch_size = cfg['data']['train_batch_size']
|
||||
else:
|
||||
dataroot = cfg['data']['data_val_dir']
|
||||
# transform = conf.test_transform
|
||||
# transform.yml = conf.test_transform
|
||||
transform = test_transform
|
||||
batch_size = cfg['data']['val_batch_size']
|
||||
|
||||
@ -56,13 +56,13 @@ def load_data(training=True, cfg=None):
|
||||
return loader, class_num
|
||||
|
||||
# def load_gift_data(action):
|
||||
# train_data = ImageFolder(conf.train_gift_root, transform=conf.train_transform)
|
||||
# train_data = ImageFolder(conf.train_gift_root, transform.yml=conf.train_transform)
|
||||
# train_dataset = DataLoader(train_data, batch_size=conf.train_gift_batchsize, shuffle=True,
|
||||
# pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True)
|
||||
# val_data = ImageFolder(conf.test_gift_root, transform=conf.test_transform)
|
||||
# val_data = ImageFolder(conf.test_gift_root, transform.yml=conf.test_transform)
|
||||
# val_dataset = DataLoader(val_data, batch_size=conf.val_gift_batchsize, shuffle=True,
|
||||
# pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True)
|
||||
# test_data = ImageFolder(conf.test_gift_root, transform=conf.test_transform)
|
||||
# test_data = ImageFolder(conf.test_gift_root, transform.yml=conf.test_transform)
|
||||
# test_dataset = DataLoader(test_data, batch_size=conf.test_gift_batchsize, shuffle=True,
|
||||
# pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True)
|
||||
# return train_dataset, val_dataset, test_dataset
|
||||
|
@ -1,10 +1,10 @@
|
||||
./quant_imgs/20179457_20240924-110903_back_addGood_b82d2842766e_80_15583929052_tid-8_fid-72_bid-3.jpg
|
||||
./quant_imgs/6928926002103_20240309-195044_front_returnGood_70f75407ef0e_225_18120111822_14_01.jpg
|
||||
./quant_imgs/6928926002103_20240309-212145_front_returnGood_70f75407ef0e_225_18120111822_11_01.jpg
|
||||
./quant_imgs/6928947479083_20241017-133830_front_returnGood_5478c9a48b7e_10_13799009402_tid-1_fid-20_bid-1.jpg
|
||||
./quant_imgs/6928947479083_20241018-110450_front_addGood_5478c9a48c28_165_13773168720_tid-6_fid-36_bid-1.jpg
|
||||
./quant_imgs/6930044166421_20240117-141516_c6a23f41-5b16-44c6-a03e-c32c25763442_back_returnGood_6930044166421_17_01.jpg
|
||||
./quant_imgs/6930044166421_20240308-150916_back_returnGood_70f75407ef0e_175_13815402763_7_01.jpg
|
||||
./quant_imgs/6930044168920_20240117-165633_3303629b-5fbd-423b-913d-8a64c1aa51dc_front_addGood_6930044168920_26_01.jpg
|
||||
./quant_imgs/6930058201507_20240305-175434_front_addGood_70f75407ef0e_95_18120111822_28_01.jpg
|
||||
./quant_imgs/6930639267885_20241014-120446_back_addGood_5478c9a48c3e_135_13773168720_tid-5_fid-99_bid-0.jpg
|
||||
../quant_imgs/20179457_20240924-110903_back_addGood_b82d2842766e_80_15583929052_tid-8_fid-72_bid-3.jpg
|
||||
../quant_imgs/6928926002103_20240309-195044_front_returnGood_70f75407ef0e_225_18120111822_14_01.jpg
|
||||
../quant_imgs/6928926002103_20240309-212145_front_returnGood_70f75407ef0e_225_18120111822_11_01.jpg
|
||||
../quant_imgs/6928947479083_20241017-133830_front_returnGood_5478c9a48b7e_10_13799009402_tid-1_fid-20_bid-1.jpg
|
||||
../quant_imgs/6928947479083_20241018-110450_front_addGood_5478c9a48c28_165_13773168720_tid-6_fid-36_bid-1.jpg
|
||||
../quant_imgs/6930044166421_20240117-141516_c6a23f41-5b16-44c6-a03e-c32c25763442_back_returnGood_6930044166421_17_01.jpg
|
||||
../quant_imgs/6930044166421_20240308-150916_back_returnGood_70f75407ef0e_175_13815402763_7_01.jpg
|
||||
../quant_imgs/6930044168920_20240117-165633_3303629b-5fbd-423b-913d-8a64c1aa51dc_front_addGood_6930044168920_26_01.jpg
|
||||
../quant_imgs/6930058201507_20240305-175434_front_addGood_70f75407ef0e_95_18120111822_28_01.jpg
|
||||
../quant_imgs/6930639267885_20241014-120446_back_addGood_5478c9a48c3e_135_13773168720_tid-5_fid-99_bid-0.jpg
|
||||
|
@ -2,17 +2,29 @@ import pdb
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from model import resnet18
|
||||
from config import config as conf
|
||||
# from config import config as conf
|
||||
from collections import OrderedDict
|
||||
from configs import trainer_tools
|
||||
import cv2
|
||||
import yaml
|
||||
|
||||
def tranform_onnx_model(model_name, pretrained_weights='checkpoints/v3_small.pth'):
|
||||
# 定义模型
|
||||
if model_name == 'resnet18':
|
||||
model = resnet18(scale=0.75)
|
||||
def tranform_onnx_model():
|
||||
# # 定义模型
|
||||
# if model_name == 'resnet18':
|
||||
# model = resnet18(scale=0.75)
|
||||
|
||||
print('model_name >>> {}'.format(model_name))
|
||||
if conf.multiple_cards:
|
||||
with open('../configs/transform.yml', 'r') as f:
|
||||
conf = yaml.load(f, Loader=yaml.FullLoader)
|
||||
|
||||
tr_tools = trainer_tools(conf)
|
||||
backbone_mapping = tr_tools.get_backbone()
|
||||
if conf['models']['backbone'] in backbone_mapping:
|
||||
model = backbone_mapping[conf['models']['backbone']]().to(conf['base']['device'])
|
||||
else:
|
||||
raise ValueError('不支持该模型: {}'.format({conf['models']['backbone']}))
|
||||
pretrained_weights = conf['models']['model_path']
|
||||
print('model_name >>> {}'.format(conf['models']['backbone']))
|
||||
if conf['base']['distributed']:
|
||||
model = model.to(torch.device('cpu'))
|
||||
checkpoint = torch.load(pretrained_weights)
|
||||
new_state_dict = OrderedDict()
|
||||
@ -22,23 +34,9 @@ def tranform_onnx_model(model_name, pretrained_weights='checkpoints/v3_small.pth
|
||||
model.load_state_dict(new_state_dict)
|
||||
else:
|
||||
model.load_state_dict(torch.load(pretrained_weights, map_location=torch.device('cpu')))
|
||||
# try:
|
||||
# model.load_state_dict(torch.load(pretrained_weights, map_location=torch.device('cpu')))
|
||||
# except Exception as e:
|
||||
# print(e)
|
||||
# # model.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(pretrained_weights, map_location='cpu').items()})
|
||||
# model = nn.DataParallel(model).to(conf.device)
|
||||
# model.load_state_dict(torch.load(conf.test_model, map_location=torch.device('cpu')))
|
||||
|
||||
|
||||
# 转换为ONNX
|
||||
if model_name == 'gift_type2':
|
||||
input_shape = [1, 64, 13, 13]
|
||||
elif model_name == 'gift_type3':
|
||||
input_shape = [1, 3, 224, 224]
|
||||
else:
|
||||
# 假设输入数据的大小是通道数*高度*宽度,例如3*224*224
|
||||
input_shape = [1, 3, 224, 224]
|
||||
input_shape = [1, 3, 224, 224]
|
||||
|
||||
img = cv2.imread('./dog_224x224.jpg')
|
||||
|
||||
@ -59,5 +57,4 @@ def tranform_onnx_model(model_name, pretrained_weights='checkpoints/v3_small.pth
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tranform_onnx_model(model_name='resnet18', # ['resnet18', 'gift_type2', 'gift_type3'] #gift_type2指resnet18中间数据判断;gift3_type3指resnet原图计算推理
|
||||
pretrained_weights='./checkpoints/resnet18_scale=1.0/best.pth')
|
||||
tranform_onnx_model()
|
||||
|
@ -6,15 +6,14 @@ import time
|
||||
import sys
|
||||
import numpy as np
|
||||
import cv2
|
||||
from config import config as conf
|
||||
from rknn.api import RKNN
|
||||
|
||||
import config
|
||||
|
||||
import yaml
|
||||
with open('../configs/transform.yml', 'r') as f:
|
||||
conf = yaml.load(f, Loader=yaml.FullLoader)
|
||||
# ONNX_MODEL = 'resnet50v2.onnx'
|
||||
# RKNN_MODEL = 'resnet50v2.rknn'
|
||||
ONNX_MODEL = 'checkpoints/resnet18_scale=1.0/best.onnx'
|
||||
RKNN_MODEL = 'checkpoints/resnet18_scale=1.0/best.rknn'
|
||||
ONNX_MODEL = conf['models']['onnx_model']
|
||||
RKNN_MODEL = conf['models']['rknn_model']
|
||||
|
||||
|
||||
# ONNX_MODEL = 'v3_small_0424.onnx'
|
||||
|
@ -50,7 +50,7 @@ class FeatureExtractor:
|
||||
raise FileNotFoundError(f"Model weights file not found: {model_path}")
|
||||
|
||||
# Initialize model
|
||||
model = resnet18().to(self.conf['base']['device'])
|
||||
model = resnet18(scale=self.conf['models']['channel_ratio']).to(self.conf['base']['device'])
|
||||
|
||||
# Handle multi-GPU case
|
||||
if conf['base']['distributed']:
|
||||
|
@ -12,7 +12,7 @@ import matplotlib.pyplot as plt
|
||||
from configs import trainer_tools
|
||||
import yaml
|
||||
|
||||
with open('configs/scatter.yml', 'r') as f:
|
||||
with open('configs/compare.yml', 'r') as f:
|
||||
conf = yaml.load(f, Loader=yaml.FullLoader)
|
||||
|
||||
# Data Setup
|
||||
@ -47,11 +47,11 @@ else:
|
||||
optimizer_mapping = tr_tools.get_optimizer(model, metric)
|
||||
if conf['training']['optimizer'] in optimizer_mapping:
|
||||
optimizer = optimizer_mapping[conf['training']['optimizer']]()
|
||||
scheduler = optim.lr_scheduler.StepLR(
|
||||
optimizer,
|
||||
step_size=conf['training']['lr_step'],
|
||||
gamma=conf['training']['lr_decay']
|
||||
)
|
||||
scheduler_mapping = tr_tools.get_scheduler(optimizer)
|
||||
scheduler = scheduler_mapping[conf['training']['scheduler']]()
|
||||
print('使用{}优化器 使用{}调度器'.format(conf['training']['optimizer'],
|
||||
conf['training']['scheduler']))
|
||||
|
||||
else:
|
||||
raise ValueError('不支持的优化器类型: {}'.format(conf['training']['optimizer']))
|
||||
|
||||
|
Reference in New Issue
Block a user