update ieemoo-ai-isempty.py.

This commit is contained in:
Brainway
2023-02-27 05:08:34 +00:00
committed by Gitee
parent f08ffe99f7
commit 489d1a2086

View File

@ -7,14 +7,36 @@ import cv2, base64
import argparse import argparse
import sys, os import sys, os
import torch import torch
from gevent.pywsgi import WSGIServer
from PIL import Image from PIL import Image
from torchvision import transforms from torchvision import transforms
from models.modeling import VisionTransformer, CONFIGS
from vit_pytorch import ViT
# import logging.config as log_config # import logging.config as log_config
sys.path.insert(0, ".") sys.path.insert(0, ".")
#Flask对外服务接口 #Flask对外服务接口
# from skywalking import agent, config
# SW_SERVER = os.environ.get('SW_AGENT_COLLECTOR_BACKEND_SERVICES')
# SW_SERVICE_NAME = os.environ.get('SW_AGENT_NAME')
# if SW_SERVER and SW_SERVICE_NAME:
# config.init() #采集服务的地址,给自己的服务起个名称
# #config.init(collector="123.60.56.51:11800", service='ieemoo-ai-search') #采集服务的地址,给自己的服务起个名称
# agent.start()
# def setup_logging(path):
# if os.path.exists(path):
# with open(path, 'r') as f:
# config = json.load(f)
# log_config.dictConfig(config)
# print = logging.getprint("root")
# return print
# print = setup_logging('utils/logging.json')
app = Flask(__name__) app = Flask(__name__)
#app.use_reloader=False #app.use_reloader=False
@ -23,12 +45,12 @@ print(torch.__version__)
def parse_args(): def parse_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--img_size", default=600, type=int, help="Resolution size") parser.add_argument("--img_size", default=320, type=int, help="Resolution size")
parser.add_argument('--split', type=str, default='overlap', help="Split method") parser.add_argument('--split', type=str, default='overlap', help="Split method")
parser.add_argument('--slide_step', type=int, default=2, help="Slide step for overlap split") parser.add_argument('--slide_step', type=int, default=2, help="Slide step for overlap split")
parser.add_argument('--smoothing_value', type=float, default=0.0, help="Label smoothing value") parser.add_argument('--smoothing_value', type=float, default=0.0, help="Label smoothing value")
#使用自定义VIT parser.add_argument("--pretrained_model", type=str, default="../module/ieemoo-ai-isempty/model/now/emptyjudge5_checkpoint.bin", help="load pretrained model")
parser.add_argument("--pretrained_model", type=str, default="../module/ieemoo-ai-isempty/model/now/ieemooempty_vit_checkpoint.pth", help="load pretrained model") #parser.add_argument("--pretrained_model", type=str, default="output/ieemooempty_vit_checkpoint.pth", help="load pretrained model") #使用自定义VIT
opt, unknown = parser.parse_known_args() opt, unknown = parser.parse_known_args()
return opt return opt
@ -43,7 +65,7 @@ class Predictor(object):
self.num_classes = 0 self.num_classes = 0
self.model = None self.model = None
self.prepare_model() self.prepare_model()
self.test_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR), self.test_transform = transforms.Compose([transforms.Resize((320, 320), Image.BILINEAR),
transforms.ToTensor(), transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
@ -62,8 +84,7 @@ class Predictor(object):
# self.model = torch.load(self.args.pretrained_model,map_location='cpu') # self.model = torch.load(self.args.pretrained_model,map_location='cpu')
self.model = torch.load(self.args.pretrained_model,map_location=torch.device('cpu')) self.model = torch.load(self.args.pretrained_model,map_location=torch.device('cpu'))
self.model.eval() self.model.eval()
if torch.cuda.is_available(): self.model.to("cuda")
self.model.to("cuda")
def normal_predict(self, img_data, result): def normal_predict(self, img_data, result):
# img = Image.open(img_path) # img = Image.open(img_path)
@ -74,8 +95,8 @@ class Predictor(object):
else: else:
with torch.no_grad(): with torch.no_grad():
x = self.test_transform(img_data) x = self.test_transform(img_data)
if torch.cuda.is_available(): # if torch.cuda.is_available():
x = x.cuda() # x = x.cuda()
part_logits = self.model(x.unsqueeze(0)) part_logits = self.model(x.unsqueeze(0))
probs = torch.nn.Softmax(dim=-1)(part_logits) probs = torch.nn.Softmax(dim=-1)(part_logits)
topN = torch.argsort(probs, dim=-1, descending=True).tolist() topN = torch.argsort(probs, dim=-1, descending=True).tolist()
@ -93,7 +114,7 @@ predictor = Predictor(args)
@app.route("/isempty", methods=['POST']) @app.route("/isempty", methods=['POST'])
def get_isempty(): def get_isempty():
#print("begin") print("begin")
data = request.get_data() data = request.get_data()