# File-viewer metadata (scrape residue): 40 lines, 1.3 KiB, Python
import numpy as np
import json
import time
import cv2, base64
import argparse
import sys, os
import torch
from gevent.pywsgi import WSGIServer
from PIL import Image
from torchvision import transforms
from models.modeling import VisionTransformer, CONFIGS
from vit_pytorch import ViT
# Sanity-check the "is-empty" classifier checkpoint on one local image and
# export it to ONNX.
#
# Steps: load the pickled PyTorch model on CPU, classify "img.jpg", print the
# binary result dict, then trace the model with a dummy input and write ONNX.

CHECKPOINT_PATH = "../module/ieemoo-ai-isempty/model/now/emptyjudge5_checkpoint.bin"
# NOTE(review): ".onx" looks like a typo for the conventional ".onnx"
# extension; kept unchanged because downstream loaders may expect this name.
ONNX_PATH = "../module/ieemoo-ai-isempty/model/now/emptyjudge5_checkpoint.onx"
# (H, W) used both for preprocessing and for the ONNX dummy input; it must
# match the resolution the model was trained at.
INPUT_SIZE = (600, 600)
# Raw predicted class ids that collapse to output class 0; any other id -> 1.
CLASS_ZERO_IDS = frozenset({0, 2, 3})


def _load_full_model(path):
    """Load a complete pickled model object from *path* onto the CPU.

    The result is used directly as a module (``.eval()`` is called on it), so
    the checkpoint holds a whole pickled object rather than a state dict.
    torch >= 2.6 defaults ``torch.load`` to ``weights_only=True``, which
    refuses such objects, so ``weights_only=False`` is passed explicitly when
    the installed torch supports that keyword (older releases do not).

    Security: unpickling can execute arbitrary code from the file — only load
    trusted checkpoints.
    """
    import inspect  # local import: only needed for this signature check

    load_kwargs = {"map_location": "cpu"}
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    return torch.load(path, **load_kwargs)


def _map_class_id(raw_id):
    """Collapse a raw predicted class id to the binary output class.

    Returns 0 for raw ids 0, 2 and 3, and 1 for every other id.
    """
    return 0 if int(raw_id) in CLASS_ZERO_IDS else 1


model = _load_full_model(CHECKPOINT_PATH)
model.eval()
model.to("cpu")

test_transform = transforms.Compose([
    transforms.Resize(INPUT_SIZE, Image.BILINEAR),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# convert("RGB"): Normalize above uses 3-channel statistics, so grayscale,
# palette, RGBA or CMYK images would otherwise fail. ``with`` closes the file.
with Image.open("img.jpg") as img:
    x = test_transform(img.convert("RGB"))

# Inference only: no_grad avoids building an autograd graph.
with torch.no_grad():
    part_logits = model(x.unsqueeze(0))  # prepend a batch dimension of size 1
    probs = torch.nn.Softmax(dim=-1)(part_logits)

# Top-1 class of the single batch element.
predicted_id = int(torch.argmax(probs, dim=-1)[0])
clas_ids = _map_class_id(predicted_id)

result = {"success": "true", "rst_cls": str(clas_ids)}
print(result)

# Dummy input in BCHW layout. Batch must be 1 because inference is run one
# image at a time; H and W must equal the training resolution.
dummy_input = torch.randn(1, 3, *INPUT_SIZE)
torch.onnx.export(model, dummy_input, ONNX_PATH, verbose=False)