diff --git a/lightrise.py b/lightrise.py new file mode 100644 index 0000000..16d1554 --- /dev/null +++ b/lightrise.py @@ -0,0 +1,49 @@ +# coding=utf-8 +import os +import torch +import numpy as np +from PIL import Image +from torchvision import transforms +import argparse +from models.modeling import VisionTransformer, CONFIGS +import time + +parser = argparse.ArgumentParser() +parser.add_argument("--pretrained_model", type=str, default="../module/ieemoo-ai-isempty/model/new/ieemooempty_vitlight_checkpoint.pth", help="load pretrained model") #使用自定义VIT +args = parser.parse_args() + +#args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +args.device = torch.device("cpu") + +# 准备模型 +model = torch.load(args.pretrained_model,map_location=torch.device('cpu')) #自己预训练模型 +model.to(args.device) +model.eval() + +num_classes = 2 +cls_dict = {0: "noemp", 1: "yesemp"} + +test_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) + + +#复杂场景小模型测试单张图片 +def riseempty(imgdata): + x = test_transform(imgdata) + part_logits = model(x.unsqueeze(0).to(args.device)) + probs = torch.nn.Softmax(dim=-1)(part_logits) + top2 = torch.argsort(probs, dim=-1, descending=True) + riseclas_ids = top2[0][0] + #print("cur_img result: class id: %d, score: %0.3f" % (riseclas_ids, probs[0, riseclas_ids].item())) + riseresult={} + riseresult["success"] = "true" + riseresult["rst_cls"] = int(riseclas_ids) + + return riseresult + +if __name__ == "__main__": + riseresult = riseempty("light.jpg") + print(riseresult) + +