增加学习率调度方式

This commit is contained in:
lee
2025-06-13 10:45:53 +08:00
parent 37ecef40f7
commit 1803f319a5
13 changed files with 319 additions and 294 deletions

View File

@ -3,6 +3,18 @@
<component name="CopilotChatHistory">
<option name="conversations">
<list>
<Conversation>
<option name="createTime" value="1749718122230" />
<option name="id" value="01976353bef6703884544447c919013c" />
<option name="title" value="新对话 2025年6月12日 16:48:42" />
<option name="updateTime" value="1749718122230" />
</Conversation>
<Conversation>
<option name="createTime" value="1749648208122" />
<option name="id" value="01975f28f0fa7128afe7feddcdedb740" />
<option name="title" value="新对话 2025年6月11日 21:23:28" />
<option name="updateTime" value="1749648208122" />
</Conversation>
<Conversation>
<option name="createTime" value="1749522765718" />
<option name="id" value="019757aed78e777c96c4b7007ff2fecc" />
@ -57,16 +69,7 @@
</option>
<option name="status" value="SUCCESS" />
<option name="variables">
<list>
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
<list />
</option>
</Turn>
<Turn>
@ -91,16 +94,7 @@
</option>
<option name="status" value="SUCCESS" />
<option name="variables">
<list>
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
<list />
</option>
</Turn>
</list>
@ -135,16 +129,7 @@
</option>
<option name="status" value="SUCCESS" />
<option name="variables">
<list>
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
<list />
</option>
</Turn>
<Turn>
@ -169,16 +154,7 @@
</option>
<option name="status" value="SUCCESS" />
<option name="variables">
<list>
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
<list />
</option>
</Turn>
<Turn>
@ -203,16 +179,7 @@
</option>
<option name="status" value="SUCCESS" />
<option name="variables">
<list>
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
<list />
</option>
</Turn>
<Turn>
@ -237,16 +204,7 @@
</option>
<option name="status" value="SUCCESS" />
<option name="variables">
<list>
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
<list />
</option>
</Turn>
<Turn>
@ -271,16 +229,7 @@
</option>
<option name="status" value="SUCCESS" />
<option name="variables">
<list>
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
<list />
</option>
</Turn>
<Turn>
@ -305,16 +254,7 @@
</option>
<option name="status" value="SUCCESS" />
<option name="variables">
<list>
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
<list />
</option>
</Turn>
<Turn>
@ -339,16 +279,7 @@
</option>
<option name="status" value="SUCCESS" />
<option name="variables">
<list>
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
<list />
</option>
</Turn>
<Turn>
@ -373,16 +304,7 @@
</option>
<option name="status" value="SUCCESS" />
<option name="variables">
<list>
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
<list />
</option>
</Turn>
<Turn>
@ -407,16 +329,7 @@
</option>
<option name="status" value="SUCCESS" />
<option name="variables">
<list>
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
<list />
</option>
</Turn>
<Turn>
@ -441,16 +354,7 @@
</option>
<option name="status" value="SUCCESS" />
<option name="variables">
<list>
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
<list />
</option>
</Turn>
<Turn>
@ -475,16 +379,7 @@
</option>
<option name="status" value="SUCCESS" />
<option name="variables">
<list>
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
<list />
</option>
</Turn>
<Turn>
@ -509,16 +404,7 @@
</option>
<option name="status" value="SUCCESS" />
<option name="variables">
<list>
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
<list />
</option>
</Turn>
<Turn>
@ -543,16 +429,7 @@
</option>
<option name="status" value="SUCCESS" />
<option name="variables">
<list>
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
<list />
</option>
</Turn>
<Turn>
@ -577,16 +454,7 @@
</option>
<option name="status" value="SUCCESS" />
<option name="variables">
<list>
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
<list />
</option>
</Turn>
<Turn>
@ -611,16 +479,7 @@
</option>
<option name="status" value="SUCCESS" />
<option name="variables">
<list>
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
<list />
</option>
</Turn>
<Turn>
@ -645,16 +504,7 @@
</option>
<option name="status" value="SUCCESS" />
<option name="variables">
<list>
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
<list />
</option>
</Turn>
<Turn>
@ -679,16 +529,7 @@
</option>
<option name="status" value="SUCCESS" />
<option name="variables">
<list>
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
<list />
</option>
</Turn>
<Turn>
@ -713,16 +554,7 @@
</option>
<option name="status" value="SUCCESS" />
<option name="variables">
<list>
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
<list />
</option>
</Turn>
<Turn>
@ -773,16 +605,7 @@
</option>
<option name="status" value="SUCCESS" />
<option name="variables">
<list>
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
<list />
</option>
</Turn>
<Turn>
@ -807,16 +630,7 @@
</option>
<option name="status" value="SUCCESS" />
<option name="variables">
<list>
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
<list />
</option>
</Turn>
<Turn>
@ -841,16 +655,7 @@
</option>
<option name="status" value="SUCCESS" />
<option name="variables">
<list>
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
<list />
</option>
</Turn>
</list>

View File

@ -15,8 +15,8 @@ base:
# 模型配置
models:
backbone: 'resnet18'
channel_ratio: 0.75
backbone: 'resnet34'
channel_ratio: 1.0
# 训练参数
training:
@ -29,11 +29,14 @@ training:
lr_step: 10 # 学习率调整间隔epoch
lr_decay: 0.98 # 学习率衰减率
weight_decay: 0.0005 # 权重衰减
scheduler: "cosine_annealing" # 学习率调度器可选cosine_annealing/step/none
scheduler: "cosine" # 学习率调度器可选cosine/cosine_warm/step/None
num_workers: 32 # 数据加载线程数
checkpoints: "./checkpoints/resnet18_test/" # 模型保存目录
checkpoints: "./checkpoints/resnet34_20250612_scale=1.0/" # 模型保存目录
restore: false
restore_model: "resnet18_test/epoch_600.pth" # 模型恢复路径
cosine_t_0: 10 # 初始周期长度
cosine_t_mult: 1 # 周期长度倍率
cosine_eta_min: 0.00001 # 最小学习率
# 验证参数
validation:

View File

@ -8,13 +8,13 @@ base:
log_level: "info" # 日志级别debug/info/warning/error
embedding_size: 256 # 特征维度
pin_memory: true # 是否启用pin_memory
distributed: true # 是否启用分布式训练
distributed: false # 是否启用分布式训练
# 模型配置
models:
backbone: 'resnet18'
channel_ratio: 1.0
model_path: "./checkpoints/resnet18_scatter_6.2/best.pth"
channel_ratio: 0.75
model_path: "./checkpoints/resnet18_0515/best.pth"
half: false # 是否启用半精度测试fp16
# 数据配置
@ -22,9 +22,9 @@ data:
group_test: False # 数据集名称(示例用,可替换为实际数据集)
test_batch_size: 128 # 训练批次大小
num_workers: 32 # 数据加载线程数
test_dir: "../data_center/scatter/" # 验证数据集根目录
test_dir: "../data_center/contrast_learning/contrast_test_data" # 验证数据集根目录
test_group_json: "../data_center/contrast_learning/model_test_data/test/inner_group_pairs.json"
test_list: "../data_center/scatter/val_pair.txt"
test_list: "../data_center/contrast_learning/contrast_test_data/test_pair.txt"
transform:
img_size: 224 # 图像尺寸

27
configs/transform.yml Normal file
View File

@ -0,0 +1,27 @@
# configs/transform.yml
# pth转换onnx配置文件
# 基础配置
base:
experiment_name: "model_comparison" # 实验名称(用于结果保存目录)
seed: 42 # 随机种子(保证可复现性)
device: "cuda" # 训练设备cuda/cpu
log_level: "info" # 日志级别debug/info/warning/error
embedding_size: 256 # 特征维度
pin_memory: true # 是否启用pin_memory
distributed: true # 是否启用分布式训练
# 模型配置
models:
backbone: 'resnet50'
channel_ratio: 1.0
model_path: "../checkpoints/resnet50_0519/best.pth"
onnx_model: "../checkpoints/resnet50_0519/best.onnx"
rknn_model: "../checkpoints/resnet50_0519/best.rknn"
# 日志与监控
logging:
logging_dir: "./logs" # 日志保存目录
tensorboard: true # 是否启用TensorBoard
checkpoint_interval: 30 # 检查点保存间隔epoch

View File

@ -1,4 +1,4 @@
from model import (resnet18, mobilevit_s, MobileNetV3_Small, MobileNetV3_Large, mobilenet_v1,
from model import (resnet18, resnet34, resnet50, mobilevit_s, MobileNetV3_Small, MobileNetV3_Large, mobilenet_v1,
PPLCNET_x1_0, PPLCNET_x0_5, PPLCNET_x2_5)
from timm.models import vit_base_patch16_224 as vit_base_16
from model.metric import ArcFace, CosFace
@ -14,6 +14,8 @@ class trainer_tools:
def get_backbone(self):
backbone_mapping = {
'resnet18': lambda: resnet18(scale=self.conf['models']['channel_ratio']),
'resnet34': lambda: resnet34(scale=self.conf['models']['channel_ratio']),
'resnet50': lambda: resnet50(scale=self.conf['models']['channel_ratio']),
'mobilevit_s': lambda: mobilevit_s(),
'mobilenetv3_small': lambda: MobileNetV3_Small(),
'PPLCNET_x1_0': lambda: PPLCNET_x1_0(),
@ -54,3 +56,24 @@ class trainer_tools:
)
}
return optimizer_mapping
def get_scheduler(self, optimizer):
scheduler_mapping = {
'step': lambda: optim.lr_scheduler.StepLR(
optimizer,
step_size=self.conf['training']['lr_step'],
gamma=self.conf['training']['lr_decay']
),
'cosine': lambda: optim.lr_scheduler.CosineAnnealingLR(
optimizer,
T_max=self.conf['training']['epochs'],
eta_min=self.conf['training']['cosine_eta_min']
),
'cosine_warm': lambda: optim.lr_scheduler.CosineAnnealingWarmRestarts(
optimizer,
T_0=self.conf['training'].get('cosine_t_0', 10),
T_mult=self.conf['training'].get('cosine_t_mult', 1),
eta_min=self.conf['training'].get('cosine_eta_min', 0)
)
}
return scheduler_mapping

171
getpairs.py Normal file
View File

@ -0,0 +1,171 @@
import os
import random
import json
from pathlib import Path
from typing import List, Tuple, Dict, Optional
import logging
class PairGenerator:
"""Generate positive and negative image pairs for contrastive learning."""
def __init__(self):
self._setup_logging()
def _setup_logging(self):
"""Configure logging settings."""
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s'
)
self.logger = logging.getLogger(__name__)
def _get_image_files(self, root_dir: str) -> Dict[str, List[str]]:
"""Scan directory and return dict of {folder: [image_paths]}."""
root = Path(root_dir)
if not root.is_dir():
raise ValueError(f"Invalid directory: {root_dir}")
return {
str(folder): [str(f) for f in folder.iterdir() if f.is_file()]
for folder in root.iterdir() if folder.is_dir()
}
def _generate_same_pairs(
self,
files_dict: Dict[str, List[str]],
num_pairs: int,
group_size: Optional[int] = None
) -> List[Tuple[str, str, int]]:
"""Generate positive pairs from same folder."""
pairs = []
for folder, files in files_dict.items():
if len(files) < 2:
continue
if group_size:
# Group mode: generate all possible pairs within group
for i in range(0, len(files), group_size):
group = files[i:i+group_size]
pairs.extend([
(group[i], group[j], 1)
for i in range(len(group))
for j in range(i+1, len(group))
])
else:
# Individual mode: random pairs
try:
pairs.extend(self._random_pairs(files, min(3, len(files)//2)))
except ValueError as e:
self.logger.warning(f"Skipping folder {folder}: {str(e)}")
random.shuffle(pairs)
return pairs[:num_pairs]
def _generate_cross_pairs(
self,
files_dict: Dict[str, List[str]],
num_pairs: int
) -> List[Tuple[str, str, int]]:
"""Generate negative pairs from different folders."""
folders = list(files_dict.keys())
pairs = []
while len(pairs) < num_pairs and len(folders) >= 2:
folder1, folder2 = random.sample(folders, 2)
file1 = random.choice(files_dict[folder1])
file2 = random.choice(files_dict[folder2])
if not any((f1 == file1 and f2 == file2) or (f1 == file2 and f2 == file1)
for f1, f2, _ in pairs):
pairs.append((file1, file2, 0))
return pairs
def _random_pairs(self, files: List[str], num_pairs: int) -> List[Tuple[str, str, int]]:
"""Generate random pairs from file list."""
if len(files) < 2 * num_pairs:
raise ValueError("Not enough files for requested pairs")
indices = random.sample(range(len(files)), 2 * num_pairs)
indices.sort()
return [(files[i], files[i+1], 1) for i in range(0, len(indices), 2)]
def get_pairs(self, root_dir: str, num_pairs: int = 2000) -> List[Tuple[str, str, int]]:
"""
Generate individual image pairs with labels (1=same, 0=different).
Args:
root_dir: Directory containing subfolders of images
num_pairs: Number of pairs to generate
Returns:
List of (path1, path2, label) tuples
"""
files_dict = self._get_image_files(root_dir)
same_pairs = self._generate_same_pairs(files_dict, num_pairs)
cross_pairs = self._generate_cross_pairs(files_dict, len(same_pairs))
pairs = same_pairs + cross_pairs
self.logger.info(f"Generated {len(pairs)} pairs ({len(same_pairs)} positive, {len(cross_pairs)} negative)")
return pairs
def get_group_pairs(
self,
root_dir: str,
img_num: int = 20,
group_num: int = 10,
num_pairs: int = 5000
) -> List[Tuple[str, str, int]]:
"""
Generate grouped image pairs with labels (1=same, 0=different).
Args:
root_dir: Directory containing subfolders of images
img_num: Minimum images required per folder
group_num: Number of images per group
num_pairs: Number of pairs to generate
Returns:
List of (path1, path2, label) tuples
"""
# Filter folders with enough images
files_dict = {
k: v for k, v in self._get_image_files(root_dir).items()
if len(v) >= img_num
}
# Split into groups
grouped_files = {}
for folder, files in files_dict.items():
random.shuffle(files)
grouped_files[folder] = [
files[i:i+group_num]
for i in range(0, len(files), group_num)
]
# Generate pairs
same_pairs = self._generate_same_pairs(
grouped_files, num_pairs, group_size=group_num
)
cross_pairs = self._generate_cross_pairs(
grouped_files, len(same_pairs)
)
pairs = same_pairs + cross_pairs
self.logger.info(f"Generated {len(pairs)} group pairs")
# Save to JSON
with open("cross_same.json", 'w') as f:
json.dump(pairs, f)
return pairs
if __name__ == "__main__":
generator = PairGenerator()
# Example usage:
pairs = generator.get_pairs('/home/lc/contrast_nettest/data/contrast_test_data/test') # Individual pairs
# groups = generator.get_group_pairs('val') # Group pairs

View File

@ -297,8 +297,8 @@ def init_model():
first_param_dtype = next(model.parameters()).dtype
print("模型的第一个参数的数据类型: {}".format(first_param_dtype))
else:
model.load_state_dict(torch.load(conf['model']['model_path'], map_location=conf['base']['device']))
if conf.model_half:
model.load_state_dict(torch.load(conf['models']['model_path'], map_location=conf['base']['device']))
if conf['models']['half']:
model.half()
first_param_dtype = next(model.parameters()).dtype
print("模型的第一个参数的数据类型: {}".format(first_param_dtype))

View File

@ -37,11 +37,11 @@ def load_data(training=True, cfg=None):
if training:
dataroot = cfg['data']['data_train_dir']
transform = train_transform
# transform = conf.train_transform
# transform.yml = conf.train_transform
batch_size = cfg['data']['train_batch_size']
else:
dataroot = cfg['data']['data_val_dir']
# transform = conf.test_transform
# transform.yml = conf.test_transform
transform = test_transform
batch_size = cfg['data']['val_batch_size']
@ -56,13 +56,13 @@ def load_data(training=True, cfg=None):
return loader, class_num
# def load_gift_data(action):
# train_data = ImageFolder(conf.train_gift_root, transform=conf.train_transform)
# train_data = ImageFolder(conf.train_gift_root, transform.yml=conf.train_transform)
# train_dataset = DataLoader(train_data, batch_size=conf.train_gift_batchsize, shuffle=True,
# pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True)
# val_data = ImageFolder(conf.test_gift_root, transform=conf.test_transform)
# val_data = ImageFolder(conf.test_gift_root, transform.yml=conf.test_transform)
# val_dataset = DataLoader(val_data, batch_size=conf.val_gift_batchsize, shuffle=True,
# pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True)
# test_data = ImageFolder(conf.test_gift_root, transform=conf.test_transform)
# test_data = ImageFolder(conf.test_gift_root, transform.yml=conf.test_transform)
# test_dataset = DataLoader(test_data, batch_size=conf.test_gift_batchsize, shuffle=True,
# pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True)
# return train_dataset, val_dataset, test_dataset

View File

@ -1,10 +1,10 @@
./quant_imgs/20179457_20240924-110903_back_addGood_b82d2842766e_80_15583929052_tid-8_fid-72_bid-3.jpg
./quant_imgs/6928926002103_20240309-195044_front_returnGood_70f75407ef0e_225_18120111822_14_01.jpg
./quant_imgs/6928926002103_20240309-212145_front_returnGood_70f75407ef0e_225_18120111822_11_01.jpg
./quant_imgs/6928947479083_20241017-133830_front_returnGood_5478c9a48b7e_10_13799009402_tid-1_fid-20_bid-1.jpg
./quant_imgs/6928947479083_20241018-110450_front_addGood_5478c9a48c28_165_13773168720_tid-6_fid-36_bid-1.jpg
./quant_imgs/6930044166421_20240117-141516_c6a23f41-5b16-44c6-a03e-c32c25763442_back_returnGood_6930044166421_17_01.jpg
./quant_imgs/6930044166421_20240308-150916_back_returnGood_70f75407ef0e_175_13815402763_7_01.jpg
./quant_imgs/6930044168920_20240117-165633_3303629b-5fbd-423b-913d-8a64c1aa51dc_front_addGood_6930044168920_26_01.jpg
./quant_imgs/6930058201507_20240305-175434_front_addGood_70f75407ef0e_95_18120111822_28_01.jpg
./quant_imgs/6930639267885_20241014-120446_back_addGood_5478c9a48c3e_135_13773168720_tid-5_fid-99_bid-0.jpg
../quant_imgs/20179457_20240924-110903_back_addGood_b82d2842766e_80_15583929052_tid-8_fid-72_bid-3.jpg
../quant_imgs/6928926002103_20240309-195044_front_returnGood_70f75407ef0e_225_18120111822_14_01.jpg
../quant_imgs/6928926002103_20240309-212145_front_returnGood_70f75407ef0e_225_18120111822_11_01.jpg
../quant_imgs/6928947479083_20241017-133830_front_returnGood_5478c9a48b7e_10_13799009402_tid-1_fid-20_bid-1.jpg
../quant_imgs/6928947479083_20241018-110450_front_addGood_5478c9a48c28_165_13773168720_tid-6_fid-36_bid-1.jpg
../quant_imgs/6930044166421_20240117-141516_c6a23f41-5b16-44c6-a03e-c32c25763442_back_returnGood_6930044166421_17_01.jpg
../quant_imgs/6930044166421_20240308-150916_back_returnGood_70f75407ef0e_175_13815402763_7_01.jpg
../quant_imgs/6930044168920_20240117-165633_3303629b-5fbd-423b-913d-8a64c1aa51dc_front_addGood_6930044168920_26_01.jpg
../quant_imgs/6930058201507_20240305-175434_front_addGood_70f75407ef0e_95_18120111822_28_01.jpg
../quant_imgs/6930639267885_20241014-120446_back_addGood_5478c9a48c3e_135_13773168720_tid-5_fid-99_bid-0.jpg

View File

@ -2,17 +2,29 @@ import pdb
import torch
import torch.nn as nn
from model import resnet18
from config import config as conf
# from config import config as conf
from collections import OrderedDict
from configs import trainer_tools
import cv2
import yaml
def tranform_onnx_model(model_name, pretrained_weights='checkpoints/v3_small.pth'):
# 定义模型
if model_name == 'resnet18':
model = resnet18(scale=0.75)
def tranform_onnx_model():
# # 定义模型
# if model_name == 'resnet18':
# model = resnet18(scale=0.75)
print('model_name >>> {}'.format(model_name))
if conf.multiple_cards:
with open('../configs/transform.yml', 'r') as f:
conf = yaml.load(f, Loader=yaml.FullLoader)
tr_tools = trainer_tools(conf)
backbone_mapping = tr_tools.get_backbone()
if conf['models']['backbone'] in backbone_mapping:
model = backbone_mapping[conf['models']['backbone']]().to(conf['base']['device'])
else:
raise ValueError('不支持该模型: {}'.format({conf['models']['backbone']}))
pretrained_weights = conf['models']['model_path']
print('model_name >>> {}'.format(conf['models']['backbone']))
if conf['base']['distributed']:
model = model.to(torch.device('cpu'))
checkpoint = torch.load(pretrained_weights)
new_state_dict = OrderedDict()
@ -22,23 +34,9 @@ def tranform_onnx_model(model_name, pretrained_weights='checkpoints/v3_small.pth
model.load_state_dict(new_state_dict)
else:
model.load_state_dict(torch.load(pretrained_weights, map_location=torch.device('cpu')))
# try:
# model.load_state_dict(torch.load(pretrained_weights, map_location=torch.device('cpu')))
# except Exception as e:
# print(e)
# # model.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(pretrained_weights, map_location='cpu').items()})
# model = nn.DataParallel(model).to(conf.device)
# model.load_state_dict(torch.load(conf.test_model, map_location=torch.device('cpu')))
# 转换为ONNX
if model_name == 'gift_type2':
input_shape = [1, 64, 13, 13]
elif model_name == 'gift_type3':
input_shape = [1, 3, 224, 224]
else:
# 假设输入数据的大小是通道数*高度*宽度例如3*224*224
input_shape = [1, 3, 224, 224]
input_shape = [1, 3, 224, 224]
img = cv2.imread('./dog_224x224.jpg')
@ -59,5 +57,4 @@ def tranform_onnx_model(model_name, pretrained_weights='checkpoints/v3_small.pth
if __name__ == '__main__':
tranform_onnx_model(model_name='resnet18', # ['resnet18', 'gift_type2', 'gift_type3'] #gift_type2指resnet18中间数据判断gift3_type3指resnet原图计算推理
pretrained_weights='./checkpoints/resnet18_scale=1.0/best.pth')
tranform_onnx_model()

View File

@ -6,15 +6,14 @@ import time
import sys
import numpy as np
import cv2
from config import config as conf
from rknn.api import RKNN
import config
import yaml
with open('../configs/transform.yml', 'r') as f:
conf = yaml.load(f, Loader=yaml.FullLoader)
# ONNX_MODEL = 'resnet50v2.onnx'
# RKNN_MODEL = 'resnet50v2.rknn'
ONNX_MODEL = 'checkpoints/resnet18_scale=1.0/best.onnx'
RKNN_MODEL = 'checkpoints/resnet18_scale=1.0/best.rknn'
ONNX_MODEL = conf['models']['onnx_model']
RKNN_MODEL = conf['models']['rknn_model']
# ONNX_MODEL = 'v3_small_0424.onnx'

View File

@ -50,7 +50,7 @@ class FeatureExtractor:
raise FileNotFoundError(f"Model weights file not found: {model_path}")
# Initialize model
model = resnet18().to(self.conf['base']['device'])
model = resnet18(scale=self.conf['models']['channel_ratio']).to(self.conf['base']['device'])
# Handle multi-GPU case
if conf['base']['distributed']:

View File

@ -12,7 +12,7 @@ import matplotlib.pyplot as plt
from configs import trainer_tools
import yaml
with open('configs/scatter.yml', 'r') as f:
with open('configs/compare.yml', 'r') as f:
conf = yaml.load(f, Loader=yaml.FullLoader)
# Data Setup
@ -47,11 +47,11 @@ else:
optimizer_mapping = tr_tools.get_optimizer(model, metric)
if conf['training']['optimizer'] in optimizer_mapping:
optimizer = optimizer_mapping[conf['training']['optimizer']]()
scheduler = optim.lr_scheduler.StepLR(
optimizer,
step_size=conf['training']['lr_step'],
gamma=conf['training']['lr_decay']
)
scheduler_mapping = tr_tools.get_scheduler(optimizer)
scheduler = scheduler_mapping[conf['training']['scheduler']]()
print('使用{}优化器 使用{}调度器'.format(conf['training']['optimizer'],
conf['training']['scheduler']))
else:
raise ValueError('不支持的优化器类型: {}'.format(conf['training']['optimizer']))