64 lines
2.4 KiB
Python
64 lines
2.4 KiB
Python
import pdb
|
||
import torch
|
||
import torch.nn as nn
|
||
from model import resnet18
|
||
from config import config as conf
|
||
from collections import OrderedDict
|
||
import cv2
|
||
|
||
def tranform_onnx_model(model_name, pretrained_weights='checkpoints/v3_small.pth', scale=0.75):
    """Load trained weights into a model and export it to ONNX and TorchScript.

    Two files are written next to the checkpoint, sharing its base name
    (extension replaced): ``<stem>.onnx`` produced by ``torch.onnx.export``
    and ``<stem>.pt`` produced by ``torch.jit.trace(...).save``.

    Args:
        model_name: Architecture to build. Only ``'resnet18'`` can currently
            be constructed here; any other name raises ``ValueError``.
        pretrained_weights: Path to the state-dict checkpoint to load.
            ``torch.load`` unpickles this file, so only pass trusted files.
        scale: Value forwarded as ``resnet18(scale=...)``. Defaults to 0.75,
            the value that was previously hard-coded.

    Raises:
        ValueError: If ``model_name`` is not a model this function can build.
    """
    import os  # local import keeps the file's top-level import block untouched

    # Define the model.
    if model_name == 'resnet18':
        model = resnet18(scale=scale)
    else:
        # Without this guard `model` would be unbound below and the function
        # would fail with an unhelpful UnboundLocalError.
        raise ValueError(
            "unsupported model_name {!r}: only 'resnet18' can be built".format(model_name))

    print('model_name >>> {}'.format(model_name))

    cpu = torch.device('cpu')
    # The dummy inputs below are CPU tensors, so keep the model on CPU too.
    model = model.to(cpu)
    if conf.multiple_cards:
        # Presumably the checkpoint was saved from an nn.DataParallel-wrapped
        # model, whose keys carry a "module." prefix -- TODO confirm against
        # the training script. map_location lets GPU-saved tensors load on a
        # CPU-only machine.
        checkpoint = torch.load(pretrained_weights, map_location=cpu)
        prefix = 'module.'
        # Strip the prefix only where it is present; blindly dropping the
        # first 7 characters would corrupt keys that do not carry it.
        new_state_dict = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v)
            for k, v in checkpoint.items()
        )
        model.load_state_dict(new_state_dict)
    else:
        model.load_state_dict(torch.load(pretrained_weights, map_location=cpu))

    # Convert to ONNX.
    # Dummy-input shape; the leading 1 is presumably the batch dimension.
    # Default assumes channels x height x width = 3 x 224 x 224. The gift_type*
    # entries are kept for when those models can be built by this function.
    input_shapes = {
        'gift_type2': [1, 64, 13, 13],
        'gift_type3': [1, 3, 224, 224],
    }
    input_shape = input_shapes.get(model_name, [1, 3, 224, 224])
    dummy_input = torch.randn(input_shape)

    # Derive output paths by swapping only the file extension. A plain
    # str.replace('pth', ...) would rewrite every 'pth' in the path and, for a
    # path without 'pth', would overwrite the checkpoint itself.
    stem, _ = os.path.splitext(pretrained_weights)
    output_file = stem + '.onnx'
    traced_file = stem + '.pt'

    # Put the model in inference mode before both exports so they see the
    # same model configuration.
    model.eval()

    # Export the model.
    torch.onnx.export(model,
                      dummy_input,
                      output_file,
                      verbose=True,
                      input_names=['input'],
                      output_names=['output'])

    # Trace with the same dummy input so the TorchScript module is built for
    # the same input shape as the ONNX graph.
    trace_model = torch.jit.trace(model, dummy_input)
    trace_model.save(traced_file)
    print(f"Model exported to {output_file}")
|
||
|
||
|
||
if __name__ == '__main__':
    # Valid model_name values: 'resnet18', 'gift_type2', 'gift_type3'.
    #   gift_type2 - makes its decision from ResNet18 intermediate data;
    #   gift_type3 - ResNet inference computed on the original image.
    # NOTE(review): the checkpoint directory name says scale=1.0, while
    # tranform_onnx_model builds resnet18 with scale=0.75 by default --
    # confirm the two match before exporting.
    checkpoint_path = './checkpoints/resnet18_scale=1.0/best.pth'
    tranform_onnx_model(pretrained_weights=checkpoint_path, model_name='resnet18')
|