import os
import pdb
from collections import OrderedDict

import torch
import torch.nn as nn
import cv2

from model import resnet18
from config import config as conf

# Prefix carried by every state-dict key of a checkpoint saved from a
# DataParallel-wrapped model.
_DATA_PARALLEL_PREFIX = 'module.'

# Dummy-input shape used for export, keyed by model name. Names that are not
# listed fall back to _DEFAULT_INPUT_SHAPE (batch of one 3x224x224 input).
_DEFAULT_INPUT_SHAPE = [1, 3, 224, 224]
_INPUT_SHAPES = {
    'gift_type2': [1, 64, 13, 13],
    'gift_type3': [1, 3, 224, 224],
}


def _strip_data_parallel_prefix(state_dict):
    """Return a copy of *state_dict* with a leading ``module.`` removed from keys.

    Keys that do not start with the prefix are copied unchanged, so a
    checkpoint saved without the wrapper still loads.
    """
    prefix_len = len(_DATA_PARALLEL_PREFIX)
    return OrderedDict(
        (key[prefix_len:] if key.startswith(_DATA_PARALLEL_PREFIX) else key, value)
        for key, value in state_dict.items()
    )


def tranform_onnx_model(model_name: str,
                        pretrained_weights: str = 'checkpoints/v3_small.pth') -> None:
    """Export a trained checkpoint to ONNX and to a traced TorchScript file.

    The weights are loaded on the CPU, the model is switched to eval mode, and
    two files are written next to the checkpoint: ``<stem>.onnx`` and
    ``<stem>.pt``, where ``<stem>`` is the checkpoint path without its
    extension.

    Args:
        model_name: Which model to build. Only ``'resnet18'`` can be
            constructed here; ``'gift_type2'`` / ``'gift_type3'`` have input
            shapes registered but no builder in this file.
        pretrained_weights: Path to the saved ``state_dict``.

    Raises:
        ValueError: If ``model_name`` has no model builder.
    """
    # Build the model. Fail early with a clear message instead of hitting an
    # unbound ``model`` variable further down.
    if model_name == 'resnet18':
        model = resnet18(scale=0.75)
    else:
        raise ValueError(
            'unsupported model_name {!r}: only \'resnet18\' can be built'.format(model_name)
        )
    print('model_name >>> {}'.format(model_name))

    # Always map tensors to the CPU so checkpoints saved on a GPU can be
    # converted on a machine without one.
    cpu = torch.device('cpu')
    state_dict = torch.load(pretrained_weights, map_location=cpu)
    if conf.multiple_cards:
        model = model.to(cpu)
        # Multi-card checkpoints carry the "module." key prefix; drop it so
        # the keys match the bare model.
        state_dict = _strip_data_parallel_prefix(state_dict)
    model.load_state_dict(state_dict)

    # Put the model in inference mode before exporting or tracing.
    model.eval()

    # One dummy input is shared by both exports so that they always agree on
    # the input shape.
    input_shape = _INPUT_SHAPES.get(model_name, _DEFAULT_INPUT_SHAPE)
    dummy_input = torch.randn(input_shape)

    # Derive output paths from the extension only; a plain substring replace
    # would also rewrite "pth" inside directory names, and would leave the
    # path unchanged (overwriting the checkpoint) when there is no "pth".
    stem, _ = os.path.splitext(pretrained_weights)
    output_file = stem + '.onnx'
    traced_file = stem + '.pt'
    if traced_file == pretrained_weights:
        # The checkpoint itself is a ".pt" file; do not overwrite it.
        traced_file = stem + '_traced.pt'

    # Export to ONNX. Pass ``opset_version=...`` here if a specific opset is
    # required by the consumer.
    torch.onnx.export(model, dummy_input, output_file, verbose=True,
                      input_names=['input'], output_names=['output'])

    # Export a traced TorchScript module alongside the ONNX file.
    trace_model = torch.jit.trace(model, dummy_input)
    trace_model.save(traced_file)
    print(f"Model exported to {output_file}")


if __name__ == '__main__':
    # model_name is one of ['resnet18', 'gift_type2', 'gift_type3']:
    # 'gift_type2' is described as judging on resnet18 intermediate data and
    # 'gift_type3' as running resnet inference on the original image.
    # NOTE(review): the model is built with scale=0.75 while this path says
    # "scale=1.0" -- confirm the checkpoint matches the architecture.
    tranform_onnx_model(
        model_name='resnet18',
        pretrained_weights='./checkpoints/resnet18_scale=1.0/best.pth',
    )