From 52427ac8a9d95143a22980838b808c54b7758d83 Mon Sep 17 00:00:00 2001 From: Brainway Date: Wed, 12 Oct 2022 02:04:28 +0000 Subject: [PATCH] update --- onx.py | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 onx.py diff --git a/onx.py b/onx.py new file mode 100644 index 0000000..badaffe --- /dev/null +++ b/onx.py @@ -0,0 +1,39 @@ +import numpy as np +import json +import time +import cv2, base64 +import argparse +import sys, os +import torch +from gevent.pywsgi import WSGIServer +from PIL import Image +from torchvision import transforms +from models.modeling import VisionTransformer, CONFIGS +from vit_pytorch import ViT + +model = torch.load("../module/ieemoo-ai-isempty/model/now/emptyjudge5_checkpoint.bin",map_location="cpu") +model.eval() +model.to("cpu") + +test_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) +img = Image.open("img.jpg") +x = test_transform(img) +part_logits = model(x.unsqueeze(0)) +probs = torch.nn.Softmax(dim=-1)(part_logits) +topN = torch.argsort(probs, dim=-1, descending=True).tolist() +clas_ids = topN[0][0] +clas_ids = 0 if 0==int(clas_ids) or 2 == int(clas_ids) or 3 == int(clas_ids) else 1 +result={} +result["success"] = "true" +result["rst_cls"] = str(clas_ids) + +print(result) + + +input = torch.randn(1, 3, 600, 600) # BCHW 其中Batch必须为1,因为测试时一般为1,尺寸HW必须和训练时的尺寸一致 +torch.onnx.export(model, input, "../module/ieemoo-ai-isempty/model/now/emptyjudge5_checkpoint.onx", verbose=False) + + +