import os
import pdb
import torch
import torch.nn as nn
from model import resnet18
# from config import config as conf
from collections import OrderedDict
from configs import trainer_tools
import cv2
import yaml

# Key prefix stripped from checkpoints in the "distributed" branch.
_WRAPPER_PREFIX = 'module.'


def _strip_wrapper_prefix(state_dict):
    """Return a copy of *state_dict* with a leading ``module.`` removed from keys.

    Keys that do not start with the prefix are copied unchanged, so a
    checkpoint that mixes prefixed and unprefixed keys is not corrupted.
    """
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(_WRAPPER_PREFIX):
            key = key[len(_WRAPPER_PREFIX):]
        cleaned[key] = value
    return cleaned


def _derive_export_paths(weights_path):
    """Return ``(onnx_path, torchscript_path)`` derived from *weights_path*.

    Only the file extension is swapped; directory names are left untouched.
    If the TorchScript path would collide with the checkpoint itself (the
    checkpoint already ends in ``.pt``), ``_traced`` is appended to the stem
    so the source weights are never overwritten.
    """
    stem, _ext = os.path.splitext(weights_path)
    onnx_path = stem + '.onnx'
    traced_path = stem + '.pt'
    if os.path.abspath(traced_path) == os.path.abspath(weights_path):
        traced_path = stem + '_traced.pt'
    return onnx_path, traced_path


def tranform_onnx_model(config_path='../configs/transform.yml',
                        input_shape=(1, 3, 224, 224)):
    """Export the configured backbone to ONNX and to a traced TorchScript file.

    The YAML config at *config_path* is read for:
      * ``models.backbone``   - key into the mapping returned by
        ``trainer_tools(conf).get_backbone()``
      * ``models.model_path`` - path of the checkpoint to load
      * ``base.device``       - device the model is first moved to
      * ``base.distributed``  - when truthy, the model is moved to CPU and a
        leading ``module.`` is stripped from the checkpoint keys

    Args:
        config_path: Path of the YAML configuration file.
        input_shape: Shape of the random dummy input used for export/tracing.

    Raises:
        ValueError: If the configured backbone is not in the backbone mapping.
    """
    with open(config_path, 'r') as f:
        # NOTE(review): FullLoader accepts more than plain scalars/lists/dicts.
        # If this config only holds plain data, yaml.safe_load would be the
        # safer choice - confirm before switching.
        conf = yaml.load(f, Loader=yaml.FullLoader)

    backbone_name = conf['models']['backbone']
    backbone_mapping = trainer_tools(conf).get_backbone()
    if backbone_name not in backbone_mapping:
        # Format the name itself (the original wrapped it in a set literal).
        raise ValueError('不支持该模型: {}'.format(backbone_name))
    model = backbone_mapping[backbone_name]().to(conf['base']['device'])

    pretrained_weights = conf['models']['model_path']
    print('model_name >>> {}'.format(backbone_name))

    # Always deserialize onto CPU so a GPU-saved checkpoint also loads on a
    # CPU-only host; load_state_dict copies into the model's own parameters.
    checkpoint = torch.load(pretrained_weights, map_location=torch.device('cpu'))
    if conf['base']['distributed']:
        model = model.to(torch.device('cpu'))
        checkpoint = _strip_wrapper_prefix(checkpoint)
    model.load_state_dict(checkpoint)

    # Switch to inference mode before either export.
    model.eval()

    # Build the dummy input on the model's device to avoid a device mismatch
    # when base.device is not the CPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    dummy_input = torch.randn(tuple(input_shape), device=device)

    output_file, traced_file = _derive_export_paths(pretrained_weights)

    # Export to ONNX. opset_version is left at the exporter default; pass
    # opset_version=12 here if a specific opset is required.
    torch.onnx.export(model,
                      dummy_input,
                      output_file,
                      verbose=True,
                      input_names=['input'],
                      output_names=['output'])

    # Export a traced TorchScript module next to the ONNX file.
    trace_model = torch.jit.trace(model, dummy_input)
    trace_model.save(traced_file)

    print(f"Model exported to {output_file}")
    print(f"TorchScript model saved to {traced_file}")


if __name__ == '__main__':
    tranform_onnx_model()