From 489d1a208689de0c3ce894d0fe35ff44dc4e15e6 Mon Sep 17 00:00:00 2001 From: Brainway Date: Mon, 27 Feb 2023 05:08:34 +0000 Subject: [PATCH] update ieemoo-ai-isempty.py. --- ieemoo-ai-isempty.py | 39 ++++++++++++++++++++++++++++++--------- 1 file changed, 30 insertions(+), 9 deletions(-) diff --git a/ieemoo-ai-isempty.py b/ieemoo-ai-isempty.py index a787196..7bba370 100755 --- a/ieemoo-ai-isempty.py +++ b/ieemoo-ai-isempty.py @@ -7,14 +7,36 @@ import cv2, base64 import argparse import sys, os import torch +from gevent.pywsgi import WSGIServer from PIL import Image from torchvision import transforms +from models.modeling import VisionTransformer, CONFIGS +from vit_pytorch import ViT # import logging.config as log_config sys.path.insert(0, ".") #Flask对外服务接口 +# from skywalking import agent, config +# SW_SERVER = os.environ.get('SW_AGENT_COLLECTOR_BACKEND_SERVICES') +# SW_SERVICE_NAME = os.environ.get('SW_AGENT_NAME') +# if SW_SERVER and SW_SERVICE_NAME: +# config.init() #采集服务的地址,给自己的服务起个名称 +# #config.init(collector="123.60.56.51:11800", service='ieemoo-ai-search') #采集服务的地址,给自己的服务起个名称 +# agent.start() + +# def setup_logging(path): +# if os.path.exists(path): +# with open(path, 'r') as f: +# config = json.load(f) +# log_config.dictConfig(config) +# print = logging.getprint("root") +# return print + +# print = setup_logging('utils/logging.json') + + app = Flask(__name__) #app.use_reloader=False @@ -23,12 +45,12 @@ print(torch.__version__) def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument("--img_size", default=600, type=int, help="Resolution size") + parser.add_argument("--img_size", default=320, type=int, help="Resolution size") parser.add_argument('--split', type=str, default='overlap', help="Split method") parser.add_argument('--slide_step', type=int, default=2, help="Slide step for overlap split") parser.add_argument('--smoothing_value', type=float, default=0.0, help="Label smoothing value") - #使用自定义VIT - parser.add_argument("--pretrained_model", type=str, default="../module/ieemoo-ai-isempty/model/now/ieemooempty_vit_checkpoint.pth", help="load pretrained model") + parser.add_argument("--pretrained_model", type=str, default="../module/ieemoo-ai-isempty/model/now/emptyjudge5_checkpoint.bin", help="load pretrained model") + #parser.add_argument("--pretrained_model", type=str, default="output/ieemooempty_vit_checkpoint.pth", help="load pretrained model") #使用自定义VIT opt, unknown = parser.parse_known_args() return opt @@ -43,7 +65,7 @@ class Predictor(object): self.num_classes = 0 self.model = None self.prepare_model() - self.test_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR), + self.test_transform = transforms.Compose([transforms.Resize((320, 320), Image.BILINEAR), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) @@ -62,8 +84,7 @@ class Predictor(object): # self.model = torch.load(self.args.pretrained_model,map_location='cpu') self.model = torch.load(self.args.pretrained_model,map_location=torch.device('cpu')) self.model.eval() - if torch.cuda.is_available(): - self.model.to("cuda") + self.model.to("cuda") def normal_predict(self, img_data, result): # img = Image.open(img_path) @@ -74,8 +95,8 @@ class Predictor(object): else: with torch.no_grad(): x = self.test_transform(img_data) - if torch.cuda.is_available(): - x = x.cuda() + # if torch.cuda.is_available(): + # x = x.cuda() part_logits = self.model(x.unsqueeze(0)) probs = torch.nn.Softmax(dim=-1)(part_logits) topN = torch.argsort(probs, dim=-1, descending=True).tolist() @@ -93,7 +114,7 @@ predictor = Predictor(args) @app.route("/isempty", methods=['POST']) def get_isempty(): - #print("begin") + print("begin") data = request.get_data()