增加学习率调度方式

This commit is contained in:
lee
2025-06-13 10:45:53 +08:00
parent 37ecef40f7
commit 1803f319a5
13 changed files with 319 additions and 294 deletions

View File

@ -2,17 +2,29 @@ import pdb
import torch
import torch.nn as nn
from model import resnet18
from config import config as conf
# from config import config as conf
from collections import OrderedDict
from configs import trainer_tools
import cv2
import yaml
def tranform_onnx_model(model_name, pretrained_weights='checkpoints/v3_small.pth'):
# 定义模型
if model_name == 'resnet18':
model = resnet18(scale=0.75)
def tranform_onnx_model():
# # 定义模型
# if model_name == 'resnet18':
# model = resnet18(scale=0.75)
print('model_name >>> {}'.format(model_name))
if conf.multiple_cards:
with open('../configs/transform.yml', 'r') as f:
conf = yaml.load(f, Loader=yaml.FullLoader)
tr_tools = trainer_tools(conf)
backbone_mapping = tr_tools.get_backbone()
if conf['models']['backbone'] in backbone_mapping:
model = backbone_mapping[conf['models']['backbone']]().to(conf['base']['device'])
else:
raise ValueError('不支持该模型: {}'.format({conf['models']['backbone']}))
pretrained_weights = conf['models']['model_path']
print('model_name >>> {}'.format(conf['models']['backbone']))
if conf['base']['distributed']:
model = model.to(torch.device('cpu'))
checkpoint = torch.load(pretrained_weights)
new_state_dict = OrderedDict()
@ -22,23 +34,9 @@ def tranform_onnx_model(model_name, pretrained_weights='checkpoints/v3_small.pth
model.load_state_dict(new_state_dict)
else:
model.load_state_dict(torch.load(pretrained_weights, map_location=torch.device('cpu')))
# try:
# model.load_state_dict(torch.load(pretrained_weights, map_location=torch.device('cpu')))
# except Exception as e:
# print(e)
# # model.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(pretrained_weights, map_location='cpu').items()})
# model = nn.DataParallel(model).to(conf.device)
# model.load_state_dict(torch.load(conf.test_model, map_location=torch.device('cpu')))
# 转换为ONNX
if model_name == 'gift_type2':
input_shape = [1, 64, 13, 13]
elif model_name == 'gift_type3':
input_shape = [1, 3, 224, 224]
else:
# 假设输入数据的大小是通道数*高度*宽度例如3*224*224
input_shape = [1, 3, 224, 224]
input_shape = [1, 3, 224, 224]
img = cv2.imread('./dog_224x224.jpg')
@ -59,5 +57,4 @@ def tranform_onnx_model(model_name, pretrained_weights='checkpoints/v3_small.pth
if __name__ == '__main__':
tranform_onnx_model(model_name='resnet18', # ['resnet18', 'gift_type2', 'gift_type3'] #gift_type2指resnet18中间数据判断gift3_type3指resnet原图计算推理
pretrained_weights='./checkpoints/resnet18_scale=1.0/best.pth')
tranform_onnx_model()