# coding=utf-8
import os
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
import argparse
# Not referenced directly; presumably required so torch.load can unpickle the
# custom ViT model class stored in the checkpoint -- keep it importable.
from models.modeling import VisionTransformer, CONFIGS
import time

# Single-image inference with the small ViT model for complex scenes
# (empty / not-empty classification).

# Default checkpoint location. Can be overridden per call (model_path=...) or
# via --pretrained_model on the host process command line.
_DEFAULT_MODEL_PATH = "../module/ieemoo-ai-isempty/model/new/ieemooempty_vitlight_checkpoint.pth"

# Inference runs on CPU only.
_DEVICE = torch.device("cpu")

# Class id -> label of the checkpoint (rst_cls in the result is the id).
_CLS_DICT = {0: "noemp", 1: "yesemp"}

# Preprocessing: resize to 600x600, convert to tensor, normalize with the
# standard ImageNet per-channel mean/std.
_TEST_TRANSFORM = transforms.Compose([
    transforms.Resize((600, 600), Image.BILINEAR),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Loaded models keyed by checkpoint path, so each checkpoint is read once.
_MODEL_CACHE = {}


def _resolve_model_path(model_path):
    """Return model_path, or the --pretrained_model CLI value / default if it is None.

    parse_known_args is used so unrelated command-line arguments of the host
    process do not make this function raise SystemExit, and add_help=False
    keeps "-h" from printing this parser's help and exiting.
    """
    if model_path is not None:
        return model_path
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument("--pretrained_model", type=str, default=_DEFAULT_MODEL_PATH,
                        help="load pretrained model")
    args, _unknown = parser.parse_known_args()
    return args.pretrained_model


def _load_model(model_path):
    """Load the pickled model at model_path onto _DEVICE in eval mode, with caching."""
    model = _MODEL_CACHE.get(model_path)
    if model is None:
        # NOTE(security): torch.load unpickles arbitrary Python objects; only
        # load trusted checkpoints.
        # NOTE(review): on torch >= 2.6 weights_only defaults to True, which
        # rejects full-model pickles like this one; pass weights_only=False there.
        model = torch.load(model_path, map_location=_DEVICE)
        model.to(_DEVICE)
        model.eval()
        _MODEL_CACHE[model_path] = model
    return model


def riseempty(imgdata, model_path=None):
    """Classify a single image with the small complex-scene model.

    Args:
        imgdata: input image, presumably a PIL image (anything torchvision's
            Resize/ToTensor accept). Non-RGB PIL images are converted to RGB so
            the 3-channel normalization applies.
        model_path: checkpoint to load. When None (default), the
            --pretrained_model command-line value is used if present,
            otherwise _DEFAULT_MODEL_PATH.

    Returns:
        dict: {"success": "true", "rst_cls": <int class id>}; the id indexes
        _CLS_DICT (0 -> "noemp", 1 -> "yesemp").
    """
    model = _load_model(_resolve_model_path(model_path))

    # Normalize expects 3 channels; grayscale/RGBA PIL input would otherwise fail.
    if isinstance(imgdata, Image.Image) and imgdata.mode != "RGB":
        imgdata = imgdata.convert("RGB")
    x = _TEST_TRANSFORM(imgdata).unsqueeze(0).to(_DEVICE)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        part_logits = model(x)
    probs = torch.softmax(part_logits, dim=-1)
    riseclas_ids = int(torch.argmax(probs[0]))
    # print("cur_img result: class id: %d, score: %0.3f" % (riseclas_ids, probs[0, riseclas_ids].item()))

    return {"success": "true", "rst_cls": riseclas_ids}


# if __name__ == "__main__":
#     riseresult = riseempty(Image.open("light.jpg"))
#     print(riseresult)