61 lines
2.0 KiB
Python
61 lines
2.0 KiB
Python
import pdb
|
|
import torch
|
|
import torch.nn as nn
|
|
from model import resnet18
|
|
# from config import config as conf
|
|
from collections import OrderedDict
|
|
from configs import trainer_tools
|
|
import cv2
|
|
import yaml
|
|
|
|
def _strip_module_prefix(state_dict):
    """Return a copy of *state_dict* with a leading ``"module."`` removed from keys.

    Checkpoints saved from a model wrapped in ``nn.DataParallel`` /
    ``DistributedDataParallel`` prefix every parameter name with ``"module."``.
    Keys without that prefix are kept unchanged instead of being blindly
    truncated by 7 characters.
    """
    prefix = 'module.'
    stripped = OrderedDict()
    for key, value in state_dict.items():
        stripped[key[len(prefix):] if key.startswith(prefix) else key] = value
    return stripped


def _export_paths(weights_path):
    """Return ``(onnx_path, torchscript_path)`` placed next to *weights_path*.

    Only the file extension is swapped. A plain ``str.replace('pth', 'onnx')``
    would rewrite every ``"pth"`` substring in the path (e.g. a ``depth/``
    directory). It would also return the weights path unchanged when the
    extension is not ``.pth``, making the export overwrite the checkpoint.
    """
    import os

    stem, _ = os.path.splitext(weights_path)
    return stem + '.onnx', stem + '.pt'


def tranform_onnx_model(config_path='../configs/transform.yml', input_shape=(1, 3, 224, 224)):
    """Export the configured backbone to ONNX and TorchScript.

    Reads the YAML config at *config_path* and builds the backbone named by
    ``conf['models']['backbone']`` from ``trainer_tools(conf).get_backbone()``.
    It then loads the weights at ``conf['models']['model_path']`` and writes
    ``<weights stem>.onnx`` and ``<weights stem>.pt`` next to that file.
    Both exports run on the CPU, in eval mode, with a random dummy input.

    Args:
        config_path: Path to the transform YAML config (relative to the
            current working directory by default, as before).
        input_shape: Shape of the dummy input used for ONNX export and
            TorchScript tracing.

    Raises:
        ValueError: If the configured backbone is not in the backbone mapping.
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        # NOTE(review): FullLoader can build some Python objects; acceptable for a
        # trusted local config, prefer yaml.safe_load if the file may be untrusted.
        conf = yaml.load(f, Loader=yaml.FullLoader)

    backbone_name = conf['models']['backbone']
    backbone_mapping = trainer_tools(conf).get_backbone()
    if backbone_name not in backbone_mapping:
        # "Unsupported model: <name>" -- pass the name itself, not a set literal.
        raise ValueError('不支持该模型: {}'.format(backbone_name))
    model = backbone_mapping[backbone_name]()

    pretrained_weights = conf['models']['model_path']
    print('model_name >>> {}'.format(backbone_name))

    cpu = torch.device('cpu')
    # Always load onto the CPU so GPU-saved checkpoints work on CPU-only hosts.
    state_dict = torch.load(pretrained_weights, map_location=cpu)
    if conf['base']['distributed']:
        # (Distributed)DataParallel checkpoints carry a "module." key prefix.
        state_dict = _strip_module_prefix(state_dict)
    model.load_state_dict(state_dict)

    # Export on the CPU in inference mode: the dummy input lives on the CPU,
    # and BatchNorm/Dropout must be traced with their eval-time behaviour.
    model = model.to(cpu).eval()

    onnx_path, torchscript_path = _export_paths(pretrained_weights)
    dummy_input = torch.randn(*input_shape)

    # Convert to ONNX (opset_version=12 can be pinned here if a runtime needs it).
    torch.onnx.export(model,
                      dummy_input,
                      onnx_path,
                      verbose=True,
                      input_names=['input'],
                      output_names=['output'])

    # Also save a TorchScript trace alongside the ONNX file.
    trace_model = torch.jit.trace(model, dummy_input)
    trace_model.save(torchscript_path)

    print(f"Model exported to {onnx_path}")
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: run the export using the default config location.
    tranform_onnx_model()
|