# Source: ieemoo-ai-isempty/ieemoo-ai-isempty.py
# (repository file listing: 2022-05-05 11:54:40 +08:00, 142 lines, 5.0 KiB, Python)

# -*- coding: utf-8 -*-
from flask import request, Flask
import numpy as np
import json
import time
import cv2, base64
import argparse
import sys, os
import torch
from gevent.pywsgi import WSGIServer
from PIL import Image
from torchvision import transforms
from models.modeling import VisionTransformer, CONFIGS
# Put the working directory at the front of the import path.
# NOTE(review): this runs after `from models.modeling import ...` above, so it
# cannot help that particular import resolve — confirm it is still needed.
sys.path[:0] = ["."]

app = Flask(__name__)
# NOTE(review): Flask takes `use_reloader` as an argument to app.run(); setting it
# as an attribute on the app object is presumably not read by Flask — confirm.
app.use_reloader = False
def parse_args(model_file="../module/ieemoo-ai-isempty/model/now/emptyjudge5_checkpoint.bin", argv=None):
    """Parse the service's command-line options.

    Unrecognised options are ignored rather than rejected (``parse_known_args``).

    Args:
        model_file: default value for ``--pretrained_model`` (checkpoint path).
        argv: list of argument strings to parse; ``None`` (the default) parses
            ``sys.argv[1:]`` exactly as before.

    Returns:
        argparse.Namespace with ``img_size``, ``split``, ``slide_step``,
        ``smoothing_value`` and ``pretrained_model``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--img_size", default=448, type=int, help="Resolution size")
    parser.add_argument('--split', type=str, default='overlap', help="Split method")
    parser.add_argument('--slide_step', type=int, default=12, help="Slide step for overlap split")
    parser.add_argument('--smoothing_value', type=float, default=0.0, help="Label smoothing value")
    parser.add_argument("--pretrained_model", type=str, default=model_file, help="load pretrained model")
    opt, _unknown = parser.parse_known_args(argv)
    return opt
class Predictor(object):
    """Wrap the 5-class ViT classifier and expose a binary empty/non-empty verdict.

    The model predicts one of the five classes in ``cls_dict``; ``normal_predict``
    collapses them to ``0`` for raw classes 0/2/3 (noemp/hard/fly) and ``1`` for
    any other class (yesemp/stack).
    """

    # Raw class ids that collapse to the binary label 0; every other id maps to 1.
    NONEMPTY_CLASS_IDS = (0, 2, 3)

    def __init__(self, args):
        """Build the model from ``args`` and the matching preprocessing pipeline.

        Args:
            args: namespace from ``parse_args``; must provide ``img_size``, ``split``,
                ``slide_step``, ``smoothing_value`` and ``pretrained_model``.
                ``device`` and ``nprocs`` are written onto it here.
        """
        self.args = args
        self.args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(self.args.device)
        self.args.nprocs = torch.cuda.device_count()
        self.cls_dict = {}
        self.num_classes = 0
        self.model = None
        self.prepare_model()
        # Resize to the same resolution the model is built with in prepare_model
        # (previously hard-coded to 448, which diverged from --img_size).
        img_size = self.args.img_size
        self.test_transform = transforms.Compose([
            transforms.Resize((img_size, img_size), Image.BILINEAR),
            transforms.ToTensor(),
            # Per-channel mean/std; these are the commonly used ImageNet statistics.
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def prepare_model(self):
        """Instantiate the ViT-B_16 classifier and load ``args.pretrained_model``.

        The checkpoint file is expected to hold a dict with the weights under
        ``'model'``. When ``pretrained_model`` is ``None`` the freshly built model
        is used as-is.
        """
        config = CONFIGS["ViT-B_16"]
        # NOTE(review): this sets attributes on the CONFIGS entry itself, presumably
        # a shared object — fine for a single-model service, but confirm.
        config.split = self.args.split
        config.slide_step = self.args.slide_step
        pretrained = self.args.pretrained_model
        if pretrained is not None:
            # Computed only when a path exists: basename(None) would raise before
            # the None check below could ever take effect.
            model_name = os.path.basename(pretrained).replace("_checkpoint.bin", "")
            print("use model_name: ", model_name)
        self.num_classes = 5
        self.cls_dict = {0: "noemp", 1: "yesemp", 2: "hard", 3: "fly", 4: "stack"}
        self.model = VisionTransformer(config, self.args.img_size, zero_head=True,
                                       num_classes=self.num_classes,
                                       smoothing_value=self.args.smoothing_value)
        if pretrained is not None:
            # map_location loads the weights onto this process's device (CPU when CUDA
            # is unavailable) regardless of where they were saved.
            # torch.load unpickles the file: only point it at trusted checkpoints.
            checkpoint = torch.load(pretrained, map_location=self.args.device)
            self.model.load_state_dict(checkpoint['model'])
        self.model.eval()
        self.model.to(self.args.device)

    def normal_predict(self, img_data, result):
        """Classify one image and record the binary verdict in ``result``.

        Args:
            img_data: input for ``self.test_transform`` (the HTTP route passes a
                PIL image), or ``None``.
            result: dict updated in place; on success ``"success"`` is set to
                ``"true"`` and ``"rst_cls"`` to the binary label as a string.

        Returns:
            The same ``result`` dict (left untouched when ``img_data`` is ``None``).
        """
        if img_data is None:
            print('error, img data is None')
            return result
        with torch.no_grad():
            x = self.test_transform(img_data).to(self.args.device)
            logits = self.model(x.unsqueeze(0))
            probs = torch.softmax(logits, dim=-1)
            top_id = int(torch.argmax(probs, dim=-1)[0])
            top_score = probs[0, top_id].item()
        # Collapse the 5-way prediction to the binary label.
        clas_ids = 0 if top_id in self.NONEMPTY_CLASS_IDS else 1
        # Log the probability of the class actually predicted; indexing probs with
        # the collapsed 0/1 label reported the wrong class's score.
        print("cur_img result: class id: %d, raw class: %d, score: %0.3f"
              % (clas_ids, top_id, top_score))
        result["success"] = "true"
        result["rst_cls"] = str(clas_ids)
        return result
# Build the predictor once at import time; the /isempty handler reuses this
# module-level instance for every request.
# NOTE(review): the checkpoint path is relative, so it resolves against the
# process's working directory — confirm the service is launched from the expected dir.
model_file = "../module/ieemoo-ai-isempty/model/now/emptyjudge5_checkpoint.bin"
args = parse_args(model_file=model_file)
predictor = Predictor(args)
@app.route("/isempty", methods=['POST'])
def get_isempty():
    """Handle ``POST /isempty``: classify a base64-encoded image as empty or not.

    Expects a JSON object body with the encoded image bytes under ``"pic"``.
    Always responds with ``repr()`` of a dict with keys ``"success"`` and
    ``"rst_cls"``. On success ``rst_cls`` is ``"1"`` (empty) or ``"0"``
    (non-empty). When the request cannot be processed the defaults
    ``"false"`` / ``"-1"`` are returned and the error is printed.
    """
    import traceback

    start = time.time()
    print('--------------------EmptyPredict-----------------')
    data = request.get_data()
    ip = request.remote_addr
    print('------ ip = %s ------' % ip)
    result = {"success": "false",
              "rst_cls": '-1',
              }
    try:
        # Parsing inside the try: a malformed body now yields the failure payload
        # instead of an unhandled server error.
        json_data = json.loads(data.decode("utf-8"))
        getdateend = time.time()
        print('get date use time: {0:.2f}s'.format(getdateend - start))
        pic = json_data.get("pic")
        imgdata = base64.b64decode(pic)
        imgdata_np = np.frombuffer(imgdata, dtype='uint8')
        img_src = cv2.imdecode(imgdata_np, cv2.IMREAD_COLOR)
        if img_src is None:
            # cv2.imdecode signals undecodable bytes by returning None.
            print('error, could not decode image from "pic"')
            return repr(result)
        # cv2 decodes to BGR channel order; PIL interprets a 3-channel array as RGB,
        # so convert before handing it to the (PIL/torchvision) preprocessing.
        img_data = Image.fromarray(cv2.cvtColor(img_src, cv2.COLOR_BGR2RGB))
        result = predictor.normal_predict(img_data, result)  # 1==empty, 0==nonEmpty
    except Exception:
        # Keep returning the failure payload, but don't hide the cause.
        traceback.print_exc()
    return repr(result)
if __name__ == "__main__":
    # Pass use_reloader explicitly: Flask reads it as an app.run() option, not as
    # an attribute on the app. If the reloader were enabled (e.g. in debug mode) it
    # would re-run this module in a child process and load the model a second time.
    app.run(use_reloader=False)
    # Alternative: serve with gevent's WSGIServer instead of Flask's built-in server:
    # http_server = WSGIServer(('0.0.0.0', 8000), app)
    # http_server.serve_forever()