增加学习率调度方式
This commit is contained in:
@ -37,11 +37,11 @@ def load_data(training=True, cfg=None):
|
||||
if training:
|
||||
dataroot = cfg['data']['data_train_dir']
|
||||
transform = train_transform
|
||||
# transform = conf.train_transform
|
||||
# transform.yml = conf.train_transform
|
||||
batch_size = cfg['data']['train_batch_size']
|
||||
else:
|
||||
dataroot = cfg['data']['data_val_dir']
|
||||
# transform = conf.test_transform
|
||||
# transform.yml = conf.test_transform
|
||||
transform = test_transform
|
||||
batch_size = cfg['data']['val_batch_size']
|
||||
|
||||
@ -56,13 +56,13 @@ def load_data(training=True, cfg=None):
|
||||
return loader, class_num
|
||||
|
||||
# def load_gift_data(action):
|
||||
# train_data = ImageFolder(conf.train_gift_root, transform=conf.train_transform)
|
||||
# train_data = ImageFolder(conf.train_gift_root, transform.yml=conf.train_transform)
|
||||
# train_dataset = DataLoader(train_data, batch_size=conf.train_gift_batchsize, shuffle=True,
|
||||
# pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True)
|
||||
# val_data = ImageFolder(conf.test_gift_root, transform=conf.test_transform)
|
||||
# val_data = ImageFolder(conf.test_gift_root, transform.yml=conf.test_transform)
|
||||
# val_dataset = DataLoader(val_data, batch_size=conf.val_gift_batchsize, shuffle=True,
|
||||
# pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True)
|
||||
# test_data = ImageFolder(conf.test_gift_root, transform=conf.test_transform)
|
||||
# test_data = ImageFolder(conf.test_gift_root, transform.yml=conf.test_transform)
|
||||
# test_dataset = DataLoader(test_data, batch_size=conf.test_gift_batchsize, shuffle=True,
|
||||
# pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True)
|
||||
# return train_dataset, val_dataset, test_dataset
|
||||
|
@ -1,10 +1,10 @@
|
||||
./quant_imgs/20179457_20240924-110903_back_addGood_b82d2842766e_80_15583929052_tid-8_fid-72_bid-3.jpg
|
||||
./quant_imgs/6928926002103_20240309-195044_front_returnGood_70f75407ef0e_225_18120111822_14_01.jpg
|
||||
./quant_imgs/6928926002103_20240309-212145_front_returnGood_70f75407ef0e_225_18120111822_11_01.jpg
|
||||
./quant_imgs/6928947479083_20241017-133830_front_returnGood_5478c9a48b7e_10_13799009402_tid-1_fid-20_bid-1.jpg
|
||||
./quant_imgs/6928947479083_20241018-110450_front_addGood_5478c9a48c28_165_13773168720_tid-6_fid-36_bid-1.jpg
|
||||
./quant_imgs/6930044166421_20240117-141516_c6a23f41-5b16-44c6-a03e-c32c25763442_back_returnGood_6930044166421_17_01.jpg
|
||||
./quant_imgs/6930044166421_20240308-150916_back_returnGood_70f75407ef0e_175_13815402763_7_01.jpg
|
||||
./quant_imgs/6930044168920_20240117-165633_3303629b-5fbd-423b-913d-8a64c1aa51dc_front_addGood_6930044168920_26_01.jpg
|
||||
./quant_imgs/6930058201507_20240305-175434_front_addGood_70f75407ef0e_95_18120111822_28_01.jpg
|
||||
./quant_imgs/6930639267885_20241014-120446_back_addGood_5478c9a48c3e_135_13773168720_tid-5_fid-99_bid-0.jpg
|
||||
../quant_imgs/20179457_20240924-110903_back_addGood_b82d2842766e_80_15583929052_tid-8_fid-72_bid-3.jpg
|
||||
../quant_imgs/6928926002103_20240309-195044_front_returnGood_70f75407ef0e_225_18120111822_14_01.jpg
|
||||
../quant_imgs/6928926002103_20240309-212145_front_returnGood_70f75407ef0e_225_18120111822_11_01.jpg
|
||||
../quant_imgs/6928947479083_20241017-133830_front_returnGood_5478c9a48b7e_10_13799009402_tid-1_fid-20_bid-1.jpg
|
||||
../quant_imgs/6928947479083_20241018-110450_front_addGood_5478c9a48c28_165_13773168720_tid-6_fid-36_bid-1.jpg
|
||||
../quant_imgs/6930044166421_20240117-141516_c6a23f41-5b16-44c6-a03e-c32c25763442_back_returnGood_6930044166421_17_01.jpg
|
||||
../quant_imgs/6930044166421_20240308-150916_back_returnGood_70f75407ef0e_175_13815402763_7_01.jpg
|
||||
../quant_imgs/6930044168920_20240117-165633_3303629b-5fbd-423b-913d-8a64c1aa51dc_front_addGood_6930044168920_26_01.jpg
|
||||
../quant_imgs/6930058201507_20240305-175434_front_addGood_70f75407ef0e_95_18120111822_28_01.jpg
|
||||
../quant_imgs/6930639267885_20241014-120446_back_addGood_5478c9a48c3e_135_13773168720_tid-5_fid-99_bid-0.jpg
|
||||
|
@ -2,17 +2,29 @@ import pdb
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from model import resnet18
|
||||
from config import config as conf
|
||||
# from config import config as conf
|
||||
from collections import OrderedDict
|
||||
from configs import trainer_tools
|
||||
import cv2
|
||||
import yaml
|
||||
|
||||
def tranform_onnx_model(model_name, pretrained_weights='checkpoints/v3_small.pth'):
|
||||
# 定义模型
|
||||
if model_name == 'resnet18':
|
||||
model = resnet18(scale=0.75)
|
||||
def tranform_onnx_model():
|
||||
# # 定义模型
|
||||
# if model_name == 'resnet18':
|
||||
# model = resnet18(scale=0.75)
|
||||
|
||||
print('model_name >>> {}'.format(model_name))
|
||||
if conf.multiple_cards:
|
||||
with open('../configs/transform.yml', 'r') as f:
|
||||
conf = yaml.load(f, Loader=yaml.FullLoader)
|
||||
|
||||
tr_tools = trainer_tools(conf)
|
||||
backbone_mapping = tr_tools.get_backbone()
|
||||
if conf['models']['backbone'] in backbone_mapping:
|
||||
model = backbone_mapping[conf['models']['backbone']]().to(conf['base']['device'])
|
||||
else:
|
||||
raise ValueError('不支持该模型: {}'.format({conf['models']['backbone']}))
|
||||
pretrained_weights = conf['models']['model_path']
|
||||
print('model_name >>> {}'.format(conf['models']['backbone']))
|
||||
if conf['base']['distributed']:
|
||||
model = model.to(torch.device('cpu'))
|
||||
checkpoint = torch.load(pretrained_weights)
|
||||
new_state_dict = OrderedDict()
|
||||
@ -22,23 +34,9 @@ def tranform_onnx_model(model_name, pretrained_weights='checkpoints/v3_small.pth
|
||||
model.load_state_dict(new_state_dict)
|
||||
else:
|
||||
model.load_state_dict(torch.load(pretrained_weights, map_location=torch.device('cpu')))
|
||||
# try:
|
||||
# model.load_state_dict(torch.load(pretrained_weights, map_location=torch.device('cpu')))
|
||||
# except Exception as e:
|
||||
# print(e)
|
||||
# # model.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(pretrained_weights, map_location='cpu').items()})
|
||||
# model = nn.DataParallel(model).to(conf.device)
|
||||
# model.load_state_dict(torch.load(conf.test_model, map_location=torch.device('cpu')))
|
||||
|
||||
|
||||
# 转换为ONNX
|
||||
if model_name == 'gift_type2':
|
||||
input_shape = [1, 64, 13, 13]
|
||||
elif model_name == 'gift_type3':
|
||||
input_shape = [1, 3, 224, 224]
|
||||
else:
|
||||
# 假设输入数据的大小是通道数*高度*宽度,例如3*224*224
|
||||
input_shape = [1, 3, 224, 224]
|
||||
input_shape = [1, 3, 224, 224]
|
||||
|
||||
img = cv2.imread('./dog_224x224.jpg')
|
||||
|
||||
@ -59,5 +57,4 @@ def tranform_onnx_model(model_name, pretrained_weights='checkpoints/v3_small.pth
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tranform_onnx_model(model_name='resnet18', # ['resnet18', 'gift_type2', 'gift_type3'] #gift_type2指resnet18中间数据判断;gift3_type3指resnet原图计算推理
|
||||
pretrained_weights='./checkpoints/resnet18_scale=1.0/best.pth')
|
||||
tranform_onnx_model()
|
||||
|
@ -6,15 +6,14 @@ import time
|
||||
import sys
|
||||
import numpy as np
|
||||
import cv2
|
||||
from config import config as conf
|
||||
from rknn.api import RKNN
|
||||
|
||||
import config
|
||||
|
||||
import yaml
|
||||
with open('../configs/transform.yml', 'r') as f:
|
||||
conf = yaml.load(f, Loader=yaml.FullLoader)
|
||||
# ONNX_MODEL = 'resnet50v2.onnx'
|
||||
# RKNN_MODEL = 'resnet50v2.rknn'
|
||||
ONNX_MODEL = 'checkpoints/resnet18_scale=1.0/best.onnx'
|
||||
RKNN_MODEL = 'checkpoints/resnet18_scale=1.0/best.rknn'
|
||||
ONNX_MODEL = conf['models']['onnx_model']
|
||||
RKNN_MODEL = conf['models']['rknn_model']
|
||||
|
||||
|
||||
# ONNX_MODEL = 'v3_small_0424.onnx'
|
||||
|
@ -50,7 +50,7 @@ class FeatureExtractor:
|
||||
raise FileNotFoundError(f"Model weights file not found: {model_path}")
|
||||
|
||||
# Initialize model
|
||||
model = resnet18().to(self.conf['base']['device'])
|
||||
model = resnet18(scale=self.conf['models']['channel_ratio']).to(self.conf['base']['device'])
|
||||
|
||||
# Handle multi-GPU case
|
||||
if conf['base']['distributed']:
|
||||
|
Reference in New Issue
Block a user