Files
ieemoo-ai-contrast/tools/model_onnx_transform.py
2025-06-11 15:23:50 +08:00

64 lines
2.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import pdb
from collections import OrderedDict

import torch
import torch.nn as nn
import cv2

from model import resnet18
from config import config as conf
def tranform_onnx_model(model_name, pretrained_weights='checkpoints/v3_small.pth', scale=0.75):
    """Load a trained checkpoint and export it as ONNX and as TorchScript.

    The ONNX graph is written next to the checkpoint with its extension
    replaced by ``.onnx``. A traced TorchScript module is written alongside
    it with the ``.pt`` extension. If that name would overwrite the input
    checkpoint, ``_traced.pt`` is used instead.

    Args:
        model_name: Network to export. Only ``'resnet18'`` can be constructed
            here. ``'gift_type2'`` / ``'gift_type3'`` have known dummy-input
            shapes but no constructor in this function yet.
        pretrained_weights: Path to a checkpoint holding a ``state_dict``.
        scale: Width argument forwarded to ``resnet18``. Presumably it must
            match the configuration the checkpoint was trained with
            (otherwise ``load_state_dict`` is expected to reject it).

    Raises:
        ValueError: If ``model_name`` has no constructor here. Previously this
            surfaced as an UnboundLocalError on ``model``.
    """
    if model_name != 'resnet18':
        raise ValueError(
            "tranform_onnx_model can only build 'resnet18', got {!r}".format(model_name))

    # Build the model.
    model = resnet18(scale=scale)
    print('model_name >>> {}'.format(model_name))

    # Always load onto the CPU so GPU-saved checkpoints work on CPU-only hosts.
    cpu = torch.device('cpu')
    model = model.to(cpu)
    state_dict = torch.load(pretrained_weights, map_location=cpu)
    if conf.multiple_cards:
        # Multi-card (nn.DataParallel-wrapped) checkpoints prefix keys with
        # "module."; strip it only where it is actually present.
        prefix = 'module.'
        state_dict = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v)
            for k, v in state_dict.items()
        )
    model.load_state_dict(state_dict)

    # Dummy-input shape: batch x channels x height x width (default 1x3x224x224).
    # The gift_type entries are kept for when those models get a constructor above.
    input_shapes = {
        'gift_type2': [1, 64, 13, 13],   # resnet18 intermediate feature data
        'gift_type3': [1, 3, 224, 224],  # resnet run on the original image
    }
    input_shape = input_shapes.get(model_name, [1, 3, 224, 224])

    # Derive output paths from the extension only; str.replace would also
    # rewrite matching substrings elsewhere in the path.
    base, _ = os.path.splitext(pretrained_weights)
    output_file = base + '.onnx'
    trace_file = base + '.pt'
    if os.path.normpath(trace_file) == os.path.normpath(pretrained_weights):
        # Never overwrite the source checkpoint with the traced module.
        trace_file = base + '_traced.pt'

    # Export in inference mode, using the same dummy input for ONNX and tracing.
    model.eval()
    dummy_input = torch.randn(input_shape)
    torch.onnx.export(model,
                      dummy_input,
                      output_file,
                      verbose=True,
                      input_names=['input'],
                      output_names=['output'])  # add opset_version=12 if the runtime needs it
    trace_model = torch.jit.trace(model, dummy_input)
    trace_model.save(trace_file)
    print(f"Model exported to {output_file}")
if __name__ == '__main__':
    # Candidate model names: 'resnet18', 'gift_type2', 'gift_type3'.
    #   gift_type2 - decision made from resnet18 intermediate data
    #   gift_type3 - resnet inference computed on the original image
    # NOTE(review): the checkpoint directory says scale=1.0 while
    # tranform_onnx_model builds resnet18 with scale=0.75 by default;
    # confirm these match before exporting.
    weights_path = './checkpoints/resnet18_scale=1.0/best.pth'
    tranform_onnx_model(pretrained_weights=weights_path, model_name='resnet18')