# -*- coding: utf-8 -*-
"""Flask service exposing a ViT-based "is the image empty" classifier.

A single POST endpoint, ``/isempty``, accepts a JSON body with a base64
encoded image under the key ``"pic"`` and answers with the ``repr`` of a
dict holding ``"success"`` and ``"rst_cls"``.
"""
from flask import request, Flask
import numpy as np
import json
import time
import cv2
import base64
import argparse
import sys
import os
import traceback
import torch
from PIL import Image
from torchvision import transforms
from models.modeling import VisionTransformer, CONFIGS

sys.path.insert(0, ".")

app = Flask(__name__)
# NOTE(review): this only sets an attribute on the app object; Flask's
# reloader is normally controlled through app.run(use_reloader=...).
app.use_reloader = False


def parse_args(model_file="ckpts/emptyjudge5_checkpoint.bin"):
    """Build the inference options from the command line.

    Unknown command-line arguments are ignored (``parse_known_args``), so the
    module can be imported by a host process that has its own flags.

    Args:
        model_file: Default checkpoint path used for ``--pretrained_model``.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--img_size", default=448, type=int,
                        help="Resolution size")
    parser.add_argument('--split', type=str, default='overlap',
                        help="Split method")
    parser.add_argument('--slide_step', type=int, default=12,
                        help="Slide step for overlap split")
    parser.add_argument('--smoothing_value', type=float, default=0.0,
                        help="Label smoothing value")
    parser.add_argument("--pretrained_model", type=str, default=model_file,
                        help="load pretrained model")
    opt, _unknown = parser.parse_known_args()
    return opt


class Predictor(object):
    """Wraps the 5-class VisionTransformer and reduces it to a binary answer."""

    # Fine-grained class ids that are reported as binary label 0; every other
    # id is reported as 1 (see ``cls_dict`` in ``prepare_model`` for names).
    _LABEL_ZERO_IDS = frozenset({0, 2, 3})

    def __init__(self, args):
        """Pick the device, load the model and build the preprocessing.

        Args:
            args: Namespace produced by ``parse_args``; ``device`` and
                ``nprocs`` attributes are added to it here.
        """
        self.args = args
        self.args.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        print(self.args.device)
        self.args.nprocs = torch.cuda.device_count()
        self.cls_dict = {}
        self.num_classes = 0
        self.model = None
        self.prepare_model()
        # Resize to the same resolution the model was constructed with.  The
        # original hard-coded 448 here, which only matched the default
        # --img_size.
        size = self.args.img_size
        self.test_transform = transforms.Compose([
            transforms.Resize((size, size), Image.BILINEAR),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225]),
        ])

    def prepare_model(self):
        """Instantiate the ViT-B_16 model and load the checkpoint weights."""
        config = CONFIGS["ViT-B_16"]
        config.split = self.args.split
        config.slide_step = self.args.slide_step

        self.num_classes = 5
        self.cls_dict = {0: "noemp", 1: "yesemp", 2: "hard", 3: "fly",
                         4: "stack"}
        self.model = VisionTransformer(
            config, self.args.img_size, zero_head=True,
            num_classes=self.num_classes,
            smoothing_value=self.args.smoothing_value)

        if self.args.pretrained_model is not None:
            model_name = os.path.basename(
                self.args.pretrained_model).replace("_checkpoint.bin", "")
            print("use model_name: ", model_name)
            # Always deserialize onto the CPU: the model itself is still on
            # the CPU at this point and is moved to the target device below,
            # so this works identically with and without CUDA.
            state_dict = torch.load(self.args.pretrained_model,
                                    map_location="cpu")['model']
            self.model.load_state_dict(state_dict)

        self.model.eval()
        self.model.to(self.args.device)

    def normal_predict(self, img_data, result):
        """Classify one image and record the binary label in ``result``.

        Args:
            img_data: PIL image to classify, or ``None``.
            result: Dict that is updated in place with ``"success"`` and
                ``"rst_cls"`` on a successful prediction.

        Returns:
            The same ``result`` dict (untouched when ``img_data`` is ``None``).
        """
        if img_data is None:
            print('error, img data is None')
            return result

        with torch.no_grad():
            x = self.test_transform(img_data).to(self.args.device)
            part_logits = self.model(x.unsqueeze(0))
            probs = torch.softmax(part_logits, dim=-1)
            top_id = int(torch.argmax(probs, dim=-1)[0])
            # Read the score *before* collapsing to the binary label; the
            # original indexed probs with the collapsed label and therefore
            # logged the probability of class 0 or 1 instead of the winner.
            top_score = probs[0, top_id].item()

        clas_ids = 0 if top_id in self._LABEL_ZERO_IDS else 1
        print("cur_img result: class id: %d, score: %0.3f"
              % (clas_ids, top_score))
        result["success"] = "true"
        result["rst_cls"] = str(clas_ids)
        return result


# The model is loaded once at import time and shared by all requests.
model_file = "/data/ieemoo/emptypredict_pfc_FG/ckpts/emptyjudge5_checkpoint.bin"
args = parse_args(model_file)
predictor = Predictor(args)


@app.route("/isempty", methods=['POST'])
def get_isempty():
    """Handle ``POST /isempty``.

    Expects a UTF-8 JSON object whose ``"pic"`` entry is a base64 encoded
    image.  Always answers with ``repr`` of a dict; on any failure the dict is
    ``{"success": "false", "rst_cls": "-1"}``.

    NOTE(review): ``repr`` of a dict is not valid JSON (single quotes); it is
    kept because existing clients presumably parse this format.
    """
    start = time.time()
    print('--------------------EmptyPredict-----------------')
    data = request.get_data()
    ip = request.remote_addr
    print('------ ip = %s ------' % ip)

    result = {"success": "false", "rst_cls": '-1'}

    # A malformed or non-object body is reported as a failed prediction
    # instead of escaping as an unhandled exception.
    try:
        json_data = json.loads(data.decode("utf-8"))
        pic = json_data.get("pic")
    except (ValueError, AttributeError) as err:
        print('error, invalid request body: %s' % err)
        return repr(result)

    getdateend = time.time()
    print('get date use time: {0:.2f}s'.format(getdateend - start))

    try:
        imgdata = base64.b64decode(pic)
        imgdata_np = np.frombuffer(imgdata, dtype='uint8')
        img_src = cv2.imdecode(imgdata_np, cv2.IMREAD_COLOR)
        if img_src is None:
            print('error, could not decode image data')
            return repr(result)
        # NOTE(review): cv2.imdecode yields BGR channel order while
        # Image.fromarray interprets the array as RGB.  Left unchanged because
        # the channel order used in training is not visible here - confirm.
        img_data = Image.fromarray(np.uint8(img_src))
        # 1==empty, 0==nonEmpty
        result = predictor.normal_predict(img_data, result)
    except Exception:
        # Best-effort endpoint: report the failure payload, but leave a trace
        # in the log instead of swallowing the error silently.
        traceback.print_exc()
        return repr(result)
    return repr(result)


if __name__ == "__main__":
    # Alternative binding used for deployment: app.run("0.0.0.0", port=8083)
    app.run()