From f1734594a9703258a7dd55ed989cbc8277c08db7 Mon Sep 17 00:00:00 2001 From: Brainway Date: Wed, 9 Nov 2022 03:55:51 +0000 Subject: [PATCH] update ieemoo-ai-isempty.py. --- ieemoo-ai-isempty.py | 62 +++++++++++++++++++++++++++++--------------- 1 file changed, 41 insertions(+), 21 deletions(-) diff --git a/ieemoo-ai-isempty.py b/ieemoo-ai-isempty.py index 6c22206..6a477b7 100755 --- a/ieemoo-ai-isempty.py +++ b/ieemoo-ai-isempty.py @@ -12,7 +12,7 @@ from PIL import Image from torchvision import transforms from models.modeling import VisionTransformer, CONFIGS from vit_pytorch import ViT -import lightrise +#import lightrise # import logging.config as log_config sys.path.insert(0, ".") @@ -115,6 +115,26 @@ class Predictor(object): args = parse_args() predictor = Predictor(args) +def riseempty(imgdata): + risemodel = torch.load("../module/ieemoo-ai-isempty/model/new/ieemooempty_vitlight_checkpoint.pth",map_location=torch.device('cpu')) #自己预训练模型 + risemodel.to("cpu") + risemodel.eval() + test_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) + x = test_transform(imgdata) + part_logits = risemodel(x.unsqueeze(0).to('cpu')) + probs = torch.nn.Softmax(dim=-1)(part_logits) + top2 = torch.argsort(probs, dim=-1, descending=True) + riseclas_ids = top2[0][0] + #print("cur_img result: class id: %d, score: %0.3f" % (riseclas_ids, probs[0, riseclas_ids].item())) + riseresult={} + riseresult["success"] = "true" + riseresult["rst_cls"] = int(riseclas_ids) + return riseresult + + + @app.route("/isempty", methods=['POST']) def get_isempty(): start = time.time() @@ -139,27 +159,27 @@ def get_isempty(): img_data = Image.open('huanyuan.jpg') result = predictor.normal_predict(img_data, result) # 1==empty, 0==nonEmpty - #riseresult = lightrise.riseempty(img_data) - #print(riseresult["rst_cls"]) + riseresult = riseempty(img_data) + print(riseresult["rst_cls"]) - # if(result["rst_cls"]==1): - # if(riseresult["rst_cls"]==1): - # result = {} - # result["success"] = "true" - # result["rst_cls"] = 1 - # else: - # result = {} - # result["success"] = "true" - # result["rst_cls"] = 0 - # else: - # if(riseresult["rst_cls"]==0): - # result = {} - # result["success"] = "true" - # result["rst_cls"] = 0 - # else: - # result = {} - # result["success"] = "true" - # result["rst_cls"] = 1 + if(result["rst_cls"]==1): + if(riseresult["rst_cls"]==1): + result = {} + result["success"] = "true" + result["rst_cls"] = 1 + else: + result = {} + result["success"] = "true" + result["rst_cls"] = 0 + else: + if(riseresult["rst_cls"]==0): + result = {} + result["success"] = "true" + result["rst_cls"] = 0 + else: + result = {} + result["success"] = "true" + result["rst_cls"] = 1 return repr(result)