Files
ieemoo-ai-isempty/onx.py
Brainway 52427ac8a9 update
2022-10-12 02:04:28 +00:00

40 lines
1.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import numpy as np
import json
import time
import cv2, base64
import argparse
import sys, os
import torch
from gevent.pywsgi import WSGIServer
from PIL import Image
from torchvision import transforms
from models.modeling import VisionTransformer, CONFIGS
from vit_pytorch import ViT
# --- Configuration -----------------------------------------------------------
# All paths are resolved relative to the current working directory.
CHECKPOINT_PATH = "../module/ieemoo-ai-isempty/model/now/emptyjudge5_checkpoint.bin"
# NOTE(review): ".onx" is unconventional (ONNX files are usually ".onnx"); kept
# as-is because downstream code may load this exact path.
ONNX_PATH = "../module/ieemoo-ai-isempty/model/now/emptyjudge5_checkpoint.onx"
IMAGE_PATH = "img.jpg"
# Square input side length. Shared by the preprocessing resize and the dummy
# tensor traced for ONNX export so the two cannot drift apart.
INPUT_SIZE = 600
# Predicted class ids collapsed to output label 0; every other class maps to 1.
# Presumably this is the binary "empty" decision the model is named for —
# TODO confirm against the training label set.
LABEL_ZERO_CLASS_IDS = frozenset({0, 2, 3})

# NOTE(review): torch.load unpickles a full model object, which can execute
# arbitrary code. Only load checkpoints from a trusted source.
model = torch.load(CHECKPOINT_PATH, map_location="cpu")
model.eval()  # switch layers such as dropout/batch-norm to inference behaviour
model.to("cpu")

test_transform = transforms.Compose([
    transforms.Resize((INPUT_SIZE, INPUT_SIZE), Image.BILINEAR),
    transforms.ToTensor(),
    # ImageNet per-channel mean / std.
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# The context manager releases the file handle.
# convert("RGB") forces 3 channels so grayscale / palette / RGBA / CMYK inputs
# match the 3-element Normalize above instead of failing.
with Image.open(IMAGE_PATH) as img:
    x = test_transform(img.convert("RGB"))

# A single forward pass needs no autograd bookkeeping.
with torch.no_grad():
    part_logits = model(x.unsqueeze(0))  # add batch dim -> (1, C, H, W)
    probs = torch.nn.Softmax(dim=-1)(part_logits)

# Top-1 class of the only image in the batch, then collapse it to a 0/1 label.
top_class = int(probs[0].argmax())
clas_ids = 0 if top_class in LABEL_ZERO_CLASS_IDS else 1

result = {"success": "true", "rst_cls": str(clas_ids)}
print(result)

# Dummy BCHW tensor traced by the ONNX exporter.
# The batch is 1 because inference normally handles one image at a time.
# H and W must match the input size used during training.
# Without dynamic_axes the exported graph has these shapes fixed.
dummy_input = torch.randn(1, 3, INPUT_SIZE, INPUT_SIZE)
torch.onnx.export(model, dummy_input, ONNX_PATH, verbose=False)