update
This commit is contained in:
39
onx.py
Normal file
39
onx.py
Normal file
@ -0,0 +1,39 @@
|
||||
import numpy as np
|
||||
import json
|
||||
import time
|
||||
import cv2, base64
|
||||
import argparse
|
||||
import sys, os
|
||||
import torch
|
||||
from gevent.pywsgi import WSGIServer
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
from models.modeling import VisionTransformer, CONFIGS
|
||||
from vit_pytorch import ViT
|
||||
|
||||
model = torch.load("../module/ieemoo-ai-isempty/model/now/emptyjudge5_checkpoint.bin",map_location="cpu")
|
||||
model.eval()
|
||||
model.to("cpu")
|
||||
|
||||
test_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
|
||||
img = Image.open("img.jpg")
|
||||
x = test_transform(img)
|
||||
part_logits = model(x.unsqueeze(0))
|
||||
probs = torch.nn.Softmax(dim=-1)(part_logits)
|
||||
topN = torch.argsort(probs, dim=-1, descending=True).tolist()
|
||||
clas_ids = topN[0][0]
|
||||
clas_ids = 0 if 0==int(clas_ids) or 2 == int(clas_ids) or 3 == int(clas_ids) else 1
|
||||
result={}
|
||||
result["success"] = "true"
|
||||
result["rst_cls"] = str(clas_ids)
|
||||
|
||||
print(result)
|
||||
|
||||
|
||||
input = torch.randn(1, 3, 600, 600) # BCHW 其中Batch必须为1,因为测试时一般为1,尺寸HW必须和训练时的尺寸一致
|
||||
torch.onnx.export(model, input, "../module/ieemoo-ai-isempty/model/now/emptyjudge5_checkpoint.onx", verbose=False)
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user