增加学习率调度方式

This commit is contained in:
lee
2025-06-13 10:45:53 +08:00
parent 37ecef40f7
commit 1803f319a5
13 changed files with 319 additions and 294 deletions

View File

@ -3,6 +3,18 @@
<component name="CopilotChatHistory"> <component name="CopilotChatHistory">
<option name="conversations"> <option name="conversations">
<list> <list>
<Conversation>
<option name="createTime" value="1749718122230" />
<option name="id" value="01976353bef6703884544447c919013c" />
<option name="title" value="新对话 2025年6月12日 16:48:42" />
<option name="updateTime" value="1749718122230" />
</Conversation>
<Conversation>
<option name="createTime" value="1749648208122" />
<option name="id" value="01975f28f0fa7128afe7feddcdedb740" />
<option name="title" value="新对话 2025年6月11日 21:23:28" />
<option name="updateTime" value="1749648208122" />
</Conversation>
<Conversation> <Conversation>
<option name="createTime" value="1749522765718" /> <option name="createTime" value="1749522765718" />
<option name="id" value="019757aed78e777c96c4b7007ff2fecc" /> <option name="id" value="019757aed78e777c96c4b7007ff2fecc" />
@ -57,16 +69,7 @@
</option> </option>
<option name="status" value="SUCCESS" /> <option name="status" value="SUCCESS" />
<option name="variables"> <option name="variables">
<list> <list />
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
</option> </option>
</Turn> </Turn>
<Turn> <Turn>
@ -91,16 +94,7 @@
</option> </option>
<option name="status" value="SUCCESS" /> <option name="status" value="SUCCESS" />
<option name="variables"> <option name="variables">
<list> <list />
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
</option> </option>
</Turn> </Turn>
</list> </list>
@ -135,16 +129,7 @@
</option> </option>
<option name="status" value="SUCCESS" /> <option name="status" value="SUCCESS" />
<option name="variables"> <option name="variables">
<list> <list />
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
</option> </option>
</Turn> </Turn>
<Turn> <Turn>
@ -169,16 +154,7 @@
</option> </option>
<option name="status" value="SUCCESS" /> <option name="status" value="SUCCESS" />
<option name="variables"> <option name="variables">
<list> <list />
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
</option> </option>
</Turn> </Turn>
<Turn> <Turn>
@ -203,16 +179,7 @@
</option> </option>
<option name="status" value="SUCCESS" /> <option name="status" value="SUCCESS" />
<option name="variables"> <option name="variables">
<list> <list />
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
</option> </option>
</Turn> </Turn>
<Turn> <Turn>
@ -237,16 +204,7 @@
</option> </option>
<option name="status" value="SUCCESS" /> <option name="status" value="SUCCESS" />
<option name="variables"> <option name="variables">
<list> <list />
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
</option> </option>
</Turn> </Turn>
<Turn> <Turn>
@ -271,16 +229,7 @@
</option> </option>
<option name="status" value="SUCCESS" /> <option name="status" value="SUCCESS" />
<option name="variables"> <option name="variables">
<list> <list />
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
</option> </option>
</Turn> </Turn>
<Turn> <Turn>
@ -305,16 +254,7 @@
</option> </option>
<option name="status" value="SUCCESS" /> <option name="status" value="SUCCESS" />
<option name="variables"> <option name="variables">
<list> <list />
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
</option> </option>
</Turn> </Turn>
<Turn> <Turn>
@ -339,16 +279,7 @@
</option> </option>
<option name="status" value="SUCCESS" /> <option name="status" value="SUCCESS" />
<option name="variables"> <option name="variables">
<list> <list />
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
</option> </option>
</Turn> </Turn>
<Turn> <Turn>
@ -373,16 +304,7 @@
</option> </option>
<option name="status" value="SUCCESS" /> <option name="status" value="SUCCESS" />
<option name="variables"> <option name="variables">
<list> <list />
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
</option> </option>
</Turn> </Turn>
<Turn> <Turn>
@ -407,16 +329,7 @@
</option> </option>
<option name="status" value="SUCCESS" /> <option name="status" value="SUCCESS" />
<option name="variables"> <option name="variables">
<list> <list />
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
</option> </option>
</Turn> </Turn>
<Turn> <Turn>
@ -441,16 +354,7 @@
</option> </option>
<option name="status" value="SUCCESS" /> <option name="status" value="SUCCESS" />
<option name="variables"> <option name="variables">
<list> <list />
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
</option> </option>
</Turn> </Turn>
<Turn> <Turn>
@ -475,16 +379,7 @@
</option> </option>
<option name="status" value="SUCCESS" /> <option name="status" value="SUCCESS" />
<option name="variables"> <option name="variables">
<list> <list />
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
</option> </option>
</Turn> </Turn>
<Turn> <Turn>
@ -509,16 +404,7 @@
</option> </option>
<option name="status" value="SUCCESS" /> <option name="status" value="SUCCESS" />
<option name="variables"> <option name="variables">
<list> <list />
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
</option> </option>
</Turn> </Turn>
<Turn> <Turn>
@ -543,16 +429,7 @@
</option> </option>
<option name="status" value="SUCCESS" /> <option name="status" value="SUCCESS" />
<option name="variables"> <option name="variables">
<list> <list />
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
</option> </option>
</Turn> </Turn>
<Turn> <Turn>
@ -577,16 +454,7 @@
</option> </option>
<option name="status" value="SUCCESS" /> <option name="status" value="SUCCESS" />
<option name="variables"> <option name="variables">
<list> <list />
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
</option> </option>
</Turn> </Turn>
<Turn> <Turn>
@ -611,16 +479,7 @@
</option> </option>
<option name="status" value="SUCCESS" /> <option name="status" value="SUCCESS" />
<option name="variables"> <option name="variables">
<list> <list />
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
</option> </option>
</Turn> </Turn>
<Turn> <Turn>
@ -645,16 +504,7 @@
</option> </option>
<option name="status" value="SUCCESS" /> <option name="status" value="SUCCESS" />
<option name="variables"> <option name="variables">
<list> <list />
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
</option> </option>
</Turn> </Turn>
<Turn> <Turn>
@ -679,16 +529,7 @@
</option> </option>
<option name="status" value="SUCCESS" /> <option name="status" value="SUCCESS" />
<option name="variables"> <option name="variables">
<list> <list />
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
</option> </option>
</Turn> </Turn>
<Turn> <Turn>
@ -713,16 +554,7 @@
</option> </option>
<option name="status" value="SUCCESS" /> <option name="status" value="SUCCESS" />
<option name="variables"> <option name="variables">
<list> <list />
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
</option> </option>
</Turn> </Turn>
<Turn> <Turn>
@ -773,16 +605,7 @@
</option> </option>
<option name="status" value="SUCCESS" /> <option name="status" value="SUCCESS" />
<option name="variables"> <option name="variables">
<list> <list />
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
</option> </option>
</Turn> </Turn>
<Turn> <Turn>
@ -807,16 +630,7 @@
</option> </option>
<option name="status" value="SUCCESS" /> <option name="status" value="SUCCESS" />
<option name="variables"> <option name="variables">
<list> <list />
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
</option> </option>
</Turn> </Turn>
<Turn> <Turn>
@ -841,16 +655,7 @@
</option> </option>
<option name="status" value="SUCCESS" /> <option name="status" value="SUCCESS" />
<option name="variables"> <option name="variables">
<list> <list />
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
</option> </option>
</Turn> </Turn>
</list> </list>

View File

@ -15,8 +15,8 @@ base:
# 模型配置 # 模型配置
models: models:
backbone: 'resnet18' backbone: 'resnet34'
channel_ratio: 0.75 channel_ratio: 1.0
# 训练参数 # 训练参数
training: training:
@ -29,11 +29,14 @@ training:
lr_step: 10 # 学习率调整间隔epoch lr_step: 10 # 学习率调整间隔epoch
lr_decay: 0.98 # 学习率衰减率 lr_decay: 0.98 # 学习率衰减率
weight_decay: 0.0005 # 权重衰减 weight_decay: 0.0005 # 权重衰减
scheduler: "cosine_annealing" # 学习率调度器可选cosine_annealing/step/none scheduler: "cosine" # 学习率调度器可选cosine/cosine_warm/step/None
num_workers: 32 # 数据加载线程数 num_workers: 32 # 数据加载线程数
checkpoints: "./checkpoints/resnet18_test/" # 模型保存目录 checkpoints: "./checkpoints/resnet34_20250612_scale=1.0/" # 模型保存目录
restore: false restore: false
restore_model: "resnet18_test/epoch_600.pth" # 模型恢复路径 restore_model: "resnet18_test/epoch_600.pth" # 模型恢复路径
cosine_t_0: 10 # 初始周期长度
cosine_t_mult: 1 # 周期长度倍率
cosine_eta_min: 0.00001 # 最小学习率
# 验证参数 # 验证参数
validation: validation:

View File

@ -8,13 +8,13 @@ base:
log_level: "info" # 日志级别debug/info/warning/error log_level: "info" # 日志级别debug/info/warning/error
embedding_size: 256 # 特征维度 embedding_size: 256 # 特征维度
pin_memory: true # 是否启用pin_memory pin_memory: true # 是否启用pin_memory
distributed: true # 是否启用分布式训练 distributed: false # 是否启用分布式训练
# 模型配置 # 模型配置
models: models:
backbone: 'resnet18' backbone: 'resnet18'
channel_ratio: 1.0 channel_ratio: 0.75
model_path: "./checkpoints/resnet18_scatter_6.2/best.pth" model_path: "./checkpoints/resnet18_0515/best.pth"
half: false # 是否启用半精度测试fp16 half: false # 是否启用半精度测试fp16
# 数据配置 # 数据配置
@ -22,9 +22,9 @@ data:
group_test: False # 数据集名称(示例用,可替换为实际数据集) group_test: False # 数据集名称(示例用,可替换为实际数据集)
test_batch_size: 128 # 训练批次大小 test_batch_size: 128 # 训练批次大小
num_workers: 32 # 数据加载线程数 num_workers: 32 # 数据加载线程数
test_dir: "../data_center/scatter/" # 验证数据集根目录 test_dir: "../data_center/contrast_learning/contrast_test_data" # 验证数据集根目录
test_group_json: "../data_center/contrast_learning/model_test_data/test/inner_group_pairs.json" test_group_json: "../data_center/contrast_learning/model_test_data/test/inner_group_pairs.json"
test_list: "../data_center/scatter/val_pair.txt" test_list: "../data_center/contrast_learning/contrast_test_data/test_pair.txt"
transform: transform:
img_size: 224 # 图像尺寸 img_size: 224 # 图像尺寸

27
configs/transform.yml Normal file
View File

@ -0,0 +1,27 @@
# configs/transform.yml
# pth转换onnx配置文件
# 基础配置
base:
experiment_name: "model_comparison" # 实验名称(用于结果保存目录)
seed: 42 # 随机种子(保证可复现性)
device: "cuda" # 训练设备cuda/cpu
log_level: "info" # 日志级别debug/info/warning/error
embedding_size: 256 # 特征维度
pin_memory: true # 是否启用pin_memory
distributed: true # 是否启用分布式训练
# 模型配置
models:
backbone: 'resnet50'
channel_ratio: 1.0
model_path: "../checkpoints/resnet50_0519/best.pth"
onnx_model: "../checkpoints/resnet50_0519/best.onnx"
rknn_model: "../checkpoints/resnet50_0519/best.rknn"
# 日志与监控
logging:
logging_dir: "./logs" # 日志保存目录
tensorboard: true # 是否启用TensorBoard
checkpoint_interval: 30 # 检查点保存间隔epoch

View File

@ -1,4 +1,4 @@
from model import (resnet18, mobilevit_s, MobileNetV3_Small, MobileNetV3_Large, mobilenet_v1, from model import (resnet18, resnet34, resnet50, mobilevit_s, MobileNetV3_Small, MobileNetV3_Large, mobilenet_v1,
PPLCNET_x1_0, PPLCNET_x0_5, PPLCNET_x2_5) PPLCNET_x1_0, PPLCNET_x0_5, PPLCNET_x2_5)
from timm.models import vit_base_patch16_224 as vit_base_16 from timm.models import vit_base_patch16_224 as vit_base_16
from model.metric import ArcFace, CosFace from model.metric import ArcFace, CosFace
@ -14,6 +14,8 @@ class trainer_tools:
def get_backbone(self): def get_backbone(self):
backbone_mapping = { backbone_mapping = {
'resnet18': lambda: resnet18(scale=self.conf['models']['channel_ratio']), 'resnet18': lambda: resnet18(scale=self.conf['models']['channel_ratio']),
'resnet34': lambda: resnet34(scale=self.conf['models']['channel_ratio']),
'resnet50': lambda: resnet50(scale=self.conf['models']['channel_ratio']),
'mobilevit_s': lambda: mobilevit_s(), 'mobilevit_s': lambda: mobilevit_s(),
'mobilenetv3_small': lambda: MobileNetV3_Small(), 'mobilenetv3_small': lambda: MobileNetV3_Small(),
'PPLCNET_x1_0': lambda: PPLCNET_x1_0(), 'PPLCNET_x1_0': lambda: PPLCNET_x1_0(),
@ -54,3 +56,24 @@ class trainer_tools:
) )
} }
return optimizer_mapping return optimizer_mapping
def get_scheduler(self, optimizer):
scheduler_mapping = {
'step': lambda: optim.lr_scheduler.StepLR(
optimizer,
step_size=self.conf['training']['lr_step'],
gamma=self.conf['training']['lr_decay']
),
'cosine': lambda: optim.lr_scheduler.CosineAnnealingLR(
optimizer,
T_max=self.conf['training']['epochs'],
eta_min=self.conf['training']['cosine_eta_min']
),
'cosine_warm': lambda: optim.lr_scheduler.CosineAnnealingWarmRestarts(
optimizer,
T_0=self.conf['training'].get('cosine_t_0', 10),
T_mult=self.conf['training'].get('cosine_t_mult', 1),
eta_min=self.conf['training'].get('cosine_eta_min', 0)
)
}
return scheduler_mapping

171
getpairs.py Normal file
View File

@ -0,0 +1,171 @@
import os
import random
import json
from pathlib import Path
from typing import List, Tuple, Dict, Optional
import logging
class PairGenerator:
"""Generate positive and negative image pairs for contrastive learning."""
def __init__(self):
self._setup_logging()
def _setup_logging(self):
"""Configure logging settings."""
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s'
)
self.logger = logging.getLogger(__name__)
def _get_image_files(self, root_dir: str) -> Dict[str, List[str]]:
"""Scan directory and return dict of {folder: [image_paths]}."""
root = Path(root_dir)
if not root.is_dir():
raise ValueError(f"Invalid directory: {root_dir}")
return {
str(folder): [str(f) for f in folder.iterdir() if f.is_file()]
for folder in root.iterdir() if folder.is_dir()
}
def _generate_same_pairs(
self,
files_dict: Dict[str, List[str]],
num_pairs: int,
group_size: Optional[int] = None
) -> List[Tuple[str, str, int]]:
"""Generate positive pairs from same folder."""
pairs = []
for folder, files in files_dict.items():
if len(files) < 2:
continue
if group_size:
# Group mode: generate all possible pairs within group
for i in range(0, len(files), group_size):
group = files[i:i+group_size]
pairs.extend([
(group[i], group[j], 1)
for i in range(len(group))
for j in range(i+1, len(group))
])
else:
# Individual mode: random pairs
try:
pairs.extend(self._random_pairs(files, min(3, len(files)//2)))
except ValueError as e:
self.logger.warning(f"Skipping folder {folder}: {str(e)}")
random.shuffle(pairs)
return pairs[:num_pairs]
def _generate_cross_pairs(
self,
files_dict: Dict[str, List[str]],
num_pairs: int
) -> List[Tuple[str, str, int]]:
"""Generate negative pairs from different folders."""
folders = list(files_dict.keys())
pairs = []
while len(pairs) < num_pairs and len(folders) >= 2:
folder1, folder2 = random.sample(folders, 2)
file1 = random.choice(files_dict[folder1])
file2 = random.choice(files_dict[folder2])
if not any((f1 == file1 and f2 == file2) or (f1 == file2 and f2 == file1)
for f1, f2, _ in pairs):
pairs.append((file1, file2, 0))
return pairs
def _random_pairs(self, files: List[str], num_pairs: int) -> List[Tuple[str, str, int]]:
"""Generate random pairs from file list."""
if len(files) < 2 * num_pairs:
raise ValueError("Not enough files for requested pairs")
indices = random.sample(range(len(files)), 2 * num_pairs)
indices.sort()
return [(files[i], files[i+1], 1) for i in range(0, len(indices), 2)]
def get_pairs(self, root_dir: str, num_pairs: int = 2000) -> List[Tuple[str, str, int]]:
"""
Generate individual image pairs with labels (1=same, 0=different).
Args:
root_dir: Directory containing subfolders of images
num_pairs: Number of pairs to generate
Returns:
List of (path1, path2, label) tuples
"""
files_dict = self._get_image_files(root_dir)
same_pairs = self._generate_same_pairs(files_dict, num_pairs)
cross_pairs = self._generate_cross_pairs(files_dict, len(same_pairs))
pairs = same_pairs + cross_pairs
self.logger.info(f"Generated {len(pairs)} pairs ({len(same_pairs)} positive, {len(cross_pairs)} negative)")
return pairs
def get_group_pairs(
self,
root_dir: str,
img_num: int = 20,
group_num: int = 10,
num_pairs: int = 5000
) -> List[Tuple[str, str, int]]:
"""
Generate grouped image pairs with labels (1=same, 0=different).
Args:
root_dir: Directory containing subfolders of images
img_num: Minimum images required per folder
group_num: Number of images per group
num_pairs: Number of pairs to generate
Returns:
List of (path1, path2, label) tuples
"""
# Filter folders with enough images
files_dict = {
k: v for k, v in self._get_image_files(root_dir).items()
if len(v) >= img_num
}
# Split into groups
grouped_files = {}
for folder, files in files_dict.items():
random.shuffle(files)
grouped_files[folder] = [
files[i:i+group_num]
for i in range(0, len(files), group_num)
]
# Generate pairs
same_pairs = self._generate_same_pairs(
grouped_files, num_pairs, group_size=group_num
)
cross_pairs = self._generate_cross_pairs(
grouped_files, len(same_pairs)
)
pairs = same_pairs + cross_pairs
self.logger.info(f"Generated {len(pairs)} group pairs")
# Save to JSON
with open("cross_same.json", 'w') as f:
json.dump(pairs, f)
return pairs
if __name__ == "__main__":
generator = PairGenerator()
# Example usage:
pairs = generator.get_pairs('/home/lc/contrast_nettest/data/contrast_test_data/test') # Individual pairs
# groups = generator.get_group_pairs('val') # Group pairs

View File

@ -297,8 +297,8 @@ def init_model():
first_param_dtype = next(model.parameters()).dtype first_param_dtype = next(model.parameters()).dtype
print("模型的第一个参数的数据类型: {}".format(first_param_dtype)) print("模型的第一个参数的数据类型: {}".format(first_param_dtype))
else: else:
model.load_state_dict(torch.load(conf['model']['model_path'], map_location=conf['base']['device'])) model.load_state_dict(torch.load(conf['models']['model_path'], map_location=conf['base']['device']))
if conf.model_half: if conf['models']['half']:
model.half() model.half()
first_param_dtype = next(model.parameters()).dtype first_param_dtype = next(model.parameters()).dtype
print("模型的第一个参数的数据类型: {}".format(first_param_dtype)) print("模型的第一个参数的数据类型: {}".format(first_param_dtype))

View File

@ -37,11 +37,11 @@ def load_data(training=True, cfg=None):
if training: if training:
dataroot = cfg['data']['data_train_dir'] dataroot = cfg['data']['data_train_dir']
transform = train_transform transform = train_transform
# transform = conf.train_transform # transform.yml = conf.train_transform
batch_size = cfg['data']['train_batch_size'] batch_size = cfg['data']['train_batch_size']
else: else:
dataroot = cfg['data']['data_val_dir'] dataroot = cfg['data']['data_val_dir']
# transform = conf.test_transform # transform.yml = conf.test_transform
transform = test_transform transform = test_transform
batch_size = cfg['data']['val_batch_size'] batch_size = cfg['data']['val_batch_size']
@ -56,13 +56,13 @@ def load_data(training=True, cfg=None):
return loader, class_num return loader, class_num
# def load_gift_data(action): # def load_gift_data(action):
# train_data = ImageFolder(conf.train_gift_root, transform=conf.train_transform) # train_data = ImageFolder(conf.train_gift_root, transform.yml=conf.train_transform)
# train_dataset = DataLoader(train_data, batch_size=conf.train_gift_batchsize, shuffle=True, # train_dataset = DataLoader(train_data, batch_size=conf.train_gift_batchsize, shuffle=True,
# pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True) # pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True)
# val_data = ImageFolder(conf.test_gift_root, transform=conf.test_transform) # val_data = ImageFolder(conf.test_gift_root, transform.yml=conf.test_transform)
# val_dataset = DataLoader(val_data, batch_size=conf.val_gift_batchsize, shuffle=True, # val_dataset = DataLoader(val_data, batch_size=conf.val_gift_batchsize, shuffle=True,
# pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True) # pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True)
# test_data = ImageFolder(conf.test_gift_root, transform=conf.test_transform) # test_data = ImageFolder(conf.test_gift_root, transform.yml=conf.test_transform)
# test_dataset = DataLoader(test_data, batch_size=conf.test_gift_batchsize, shuffle=True, # test_dataset = DataLoader(test_data, batch_size=conf.test_gift_batchsize, shuffle=True,
# pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True) # pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True)
# return train_dataset, val_dataset, test_dataset # return train_dataset, val_dataset, test_dataset

View File

@ -1,10 +1,10 @@
./quant_imgs/20179457_20240924-110903_back_addGood_b82d2842766e_80_15583929052_tid-8_fid-72_bid-3.jpg ../quant_imgs/20179457_20240924-110903_back_addGood_b82d2842766e_80_15583929052_tid-8_fid-72_bid-3.jpg
./quant_imgs/6928926002103_20240309-195044_front_returnGood_70f75407ef0e_225_18120111822_14_01.jpg ../quant_imgs/6928926002103_20240309-195044_front_returnGood_70f75407ef0e_225_18120111822_14_01.jpg
./quant_imgs/6928926002103_20240309-212145_front_returnGood_70f75407ef0e_225_18120111822_11_01.jpg ../quant_imgs/6928926002103_20240309-212145_front_returnGood_70f75407ef0e_225_18120111822_11_01.jpg
./quant_imgs/6928947479083_20241017-133830_front_returnGood_5478c9a48b7e_10_13799009402_tid-1_fid-20_bid-1.jpg ../quant_imgs/6928947479083_20241017-133830_front_returnGood_5478c9a48b7e_10_13799009402_tid-1_fid-20_bid-1.jpg
./quant_imgs/6928947479083_20241018-110450_front_addGood_5478c9a48c28_165_13773168720_tid-6_fid-36_bid-1.jpg ../quant_imgs/6928947479083_20241018-110450_front_addGood_5478c9a48c28_165_13773168720_tid-6_fid-36_bid-1.jpg
./quant_imgs/6930044166421_20240117-141516_c6a23f41-5b16-44c6-a03e-c32c25763442_back_returnGood_6930044166421_17_01.jpg ../quant_imgs/6930044166421_20240117-141516_c6a23f41-5b16-44c6-a03e-c32c25763442_back_returnGood_6930044166421_17_01.jpg
./quant_imgs/6930044166421_20240308-150916_back_returnGood_70f75407ef0e_175_13815402763_7_01.jpg ../quant_imgs/6930044166421_20240308-150916_back_returnGood_70f75407ef0e_175_13815402763_7_01.jpg
./quant_imgs/6930044168920_20240117-165633_3303629b-5fbd-423b-913d-8a64c1aa51dc_front_addGood_6930044168920_26_01.jpg ../quant_imgs/6930044168920_20240117-165633_3303629b-5fbd-423b-913d-8a64c1aa51dc_front_addGood_6930044168920_26_01.jpg
./quant_imgs/6930058201507_20240305-175434_front_addGood_70f75407ef0e_95_18120111822_28_01.jpg ../quant_imgs/6930058201507_20240305-175434_front_addGood_70f75407ef0e_95_18120111822_28_01.jpg
./quant_imgs/6930639267885_20241014-120446_back_addGood_5478c9a48c3e_135_13773168720_tid-5_fid-99_bid-0.jpg ../quant_imgs/6930639267885_20241014-120446_back_addGood_5478c9a48c3e_135_13773168720_tid-5_fid-99_bid-0.jpg

View File

@ -2,17 +2,29 @@ import pdb
import torch import torch
import torch.nn as nn import torch.nn as nn
from model import resnet18 from model import resnet18
from config import config as conf # from config import config as conf
from collections import OrderedDict from collections import OrderedDict
from configs import trainer_tools
import cv2 import cv2
import yaml
def tranform_onnx_model(model_name, pretrained_weights='checkpoints/v3_small.pth'): def tranform_onnx_model():
# 定义模型 # # 定义模型
if model_name == 'resnet18': # if model_name == 'resnet18':
model = resnet18(scale=0.75) # model = resnet18(scale=0.75)
print('model_name >>> {}'.format(model_name)) with open('../configs/transform.yml', 'r') as f:
if conf.multiple_cards: conf = yaml.load(f, Loader=yaml.FullLoader)
tr_tools = trainer_tools(conf)
backbone_mapping = tr_tools.get_backbone()
if conf['models']['backbone'] in backbone_mapping:
model = backbone_mapping[conf['models']['backbone']]().to(conf['base']['device'])
else:
raise ValueError('不支持该模型: {}'.format({conf['models']['backbone']}))
pretrained_weights = conf['models']['model_path']
print('model_name >>> {}'.format(conf['models']['backbone']))
if conf['base']['distributed']:
model = model.to(torch.device('cpu')) model = model.to(torch.device('cpu'))
checkpoint = torch.load(pretrained_weights) checkpoint = torch.load(pretrained_weights)
new_state_dict = OrderedDict() new_state_dict = OrderedDict()
@ -22,23 +34,9 @@ def tranform_onnx_model(model_name, pretrained_weights='checkpoints/v3_small.pth
model.load_state_dict(new_state_dict) model.load_state_dict(new_state_dict)
else: else:
model.load_state_dict(torch.load(pretrained_weights, map_location=torch.device('cpu'))) model.load_state_dict(torch.load(pretrained_weights, map_location=torch.device('cpu')))
# try:
# model.load_state_dict(torch.load(pretrained_weights, map_location=torch.device('cpu')))
# except Exception as e:
# print(e)
# # model.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(pretrained_weights, map_location='cpu').items()})
# model = nn.DataParallel(model).to(conf.device)
# model.load_state_dict(torch.load(conf.test_model, map_location=torch.device('cpu')))
# 转换为ONNX # 转换为ONNX
if model_name == 'gift_type2': input_shape = [1, 3, 224, 224]
input_shape = [1, 64, 13, 13]
elif model_name == 'gift_type3':
input_shape = [1, 3, 224, 224]
else:
# 假设输入数据的大小是通道数*高度*宽度例如3*224*224
input_shape = [1, 3, 224, 224]
img = cv2.imread('./dog_224x224.jpg') img = cv2.imread('./dog_224x224.jpg')
@ -59,5 +57,4 @@ def tranform_onnx_model(model_name, pretrained_weights='checkpoints/v3_small.pth
if __name__ == '__main__': if __name__ == '__main__':
tranform_onnx_model(model_name='resnet18', # ['resnet18', 'gift_type2', 'gift_type3'] #gift_type2指resnet18中间数据判断gift3_type3指resnet原图计算推理 tranform_onnx_model()
pretrained_weights='./checkpoints/resnet18_scale=1.0/best.pth')

View File

@ -6,15 +6,14 @@ import time
import sys import sys
import numpy as np import numpy as np
import cv2 import cv2
from config import config as conf
from rknn.api import RKNN from rknn.api import RKNN
import yaml
import config with open('../configs/transform.yml', 'r') as f:
conf = yaml.load(f, Loader=yaml.FullLoader)
# ONNX_MODEL = 'resnet50v2.onnx' # ONNX_MODEL = 'resnet50v2.onnx'
# RKNN_MODEL = 'resnet50v2.rknn' # RKNN_MODEL = 'resnet50v2.rknn'
ONNX_MODEL = 'checkpoints/resnet18_scale=1.0/best.onnx' ONNX_MODEL = conf['models']['onnx_model']
RKNN_MODEL = 'checkpoints/resnet18_scale=1.0/best.rknn' RKNN_MODEL = conf['models']['rknn_model']
# ONNX_MODEL = 'v3_small_0424.onnx' # ONNX_MODEL = 'v3_small_0424.onnx'

View File

@ -50,7 +50,7 @@ class FeatureExtractor:
raise FileNotFoundError(f"Model weights file not found: {model_path}") raise FileNotFoundError(f"Model weights file not found: {model_path}")
# Initialize model # Initialize model
model = resnet18().to(self.conf['base']['device']) model = resnet18(scale=self.conf['models']['channel_ratio']).to(self.conf['base']['device'])
# Handle multi-GPU case # Handle multi-GPU case
if conf['base']['distributed']: if conf['base']['distributed']:

View File

@ -12,7 +12,7 @@ import matplotlib.pyplot as plt
from configs import trainer_tools from configs import trainer_tools
import yaml import yaml
with open('configs/scatter.yml', 'r') as f: with open('configs/compare.yml', 'r') as f:
conf = yaml.load(f, Loader=yaml.FullLoader) conf = yaml.load(f, Loader=yaml.FullLoader)
# Data Setup # Data Setup
@ -47,11 +47,11 @@ else:
optimizer_mapping = tr_tools.get_optimizer(model, metric) optimizer_mapping = tr_tools.get_optimizer(model, metric)
if conf['training']['optimizer'] in optimizer_mapping: if conf['training']['optimizer'] in optimizer_mapping:
optimizer = optimizer_mapping[conf['training']['optimizer']]() optimizer = optimizer_mapping[conf['training']['optimizer']]()
scheduler = optim.lr_scheduler.StepLR( scheduler_mapping = tr_tools.get_scheduler(optimizer)
optimizer, scheduler = scheduler_mapping[conf['training']['scheduler']]()
step_size=conf['training']['lr_step'], print('使用{}优化器 使用{}调度器'.format(conf['training']['optimizer'],
gamma=conf['training']['lr_decay'] conf['training']['scheduler']))
)
else: else:
raise ValueError('不支持的优化器类型: {}'.format(conf['training']['optimizer'])) raise ValueError('不支持的优化器类型: {}'.format(conf['training']['optimizer']))