Files
ieemoo-ai-isempty/ieemoo-ai-isempty.py
2023-02-27 05:16:29 +00:00

129 lines
4.8 KiB
Python
Executable File

# -*- coding: utf-8 -*-
from flask import request, Flask
import numpy as np
import json
import time
import cv2, base64
import argparse
import sys, os
import torch
from PIL import Image
from torchvision import transforms
# import logging.config as log_config
# Put the current working directory first on the module search path.
sys.path[0:0] = ["."]

# Flask application that exposes this service's HTTP interface.
app = Flask(__name__)
# app.use_reloader=False

# Startup banner: author tag, then the installed torch version (one per line).
print("Autor:ieemoo_lc&ieemoo_lx", torch.__version__, sep="\n")
def parse_args(argv=None):
    """Parse the service's command-line options.

    Unrecognised options are ignored (``parse_known_args``) instead of
    causing an error exit.

    Args:
        argv: list of argument strings to parse. ``None`` (the default)
            parses ``sys.argv[1:]``, exactly as before this parameter existed.

    Returns:
        argparse.Namespace with ``img_size``, ``split``, ``slide_step``,
        ``smoothing_value`` and ``pretrained_model``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--img_size", default=600, type=int, help="Resolution size")
    parser.add_argument('--split', type=str, default='overlap', help="Split method")
    parser.add_argument('--slide_step', type=int, default=2, help="Slide step for overlap split")
    parser.add_argument('--smoothing_value', type=float, default=0.0, help="Label smoothing value")
    # Checkpoint of the custom ViT model.
    parser.add_argument("--pretrained_model", type=str,
                        default="../module/ieemoo-ai-isempty/model/now/ieemooempty_vit_checkpoint.pth",
                        help="load pretrained model")
    opt, _unknown = parser.parse_known_args(argv)
    return opt
class Predictor(object):
    """Empty / not-empty image classifier around a pickled PyTorch model.

    The model is loaded once in ``__init__``. ``normal_predict`` then maps
    the model's multi-class output onto a binary answer: "0" (not empty) or
    "1" (empty).
    """

    # Raw model class ids reported as "not empty" (0); every other id is
    # reported as "empty" (1). NOTE(review): earlier code in this service used
    # the label map {0: "noemp", 1: "yesemp", 2: "hard", 3: "fly",
    # 4: "stack"}; confirm it still matches the deployed checkpoint.
    NON_EMPTY_CLASS_IDS = (0, 2, 3)

    def __init__(self, args):
        """Load the model and build the inference-time preprocessing.

        Args:
            args: parsed options. Reads ``pretrained_model`` (checkpoint
                path) and ``img_size`` (square input side, default 600).
        """
        self.args = args
        self.cls_dict = {}
        self.num_classes = 0
        self.model = None
        self.prepare_model()
        # Use the configured --img_size (default 600) rather than a literal.
        size = getattr(self.args, "img_size", 600)
        # Resize to a size x size square, convert to a tensor, and normalise
        # with the widely used ImageNet per-channel mean/std.
        self.test_transform = transforms.Compose([
            transforms.Resize((size, size), Image.BILINEAR),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def prepare_model(self):
        """Load the checkpoint on CPU, switch to eval mode, move to GPU if any."""
        # The checkpoint is a whole pickled model object (it is called and
        # .eval()'d directly), so torch.load unpickles arbitrary code: only
        # point --pretrained_model at trusted files.
        # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True,
        # which rejects full-model pickles; pass weights_only=False there.
        self.model = torch.load(self.args.pretrained_model,
                                map_location=torch.device('cpu'))
        self.model.eval()
        if torch.cuda.is_available():
            self.model.to("cuda")

    def normal_predict(self, img_data, result):
        """Classify one image as empty (1) or not empty (0).

        Args:
            img_data: image accepted by ``test_transform`` (a PIL image when
                called from the HTTP route), or None.
            result: dict to fill in; it is mutated and returned.

        Returns:
            ``result`` with ``"success": "true"`` and ``"rst_cls"`` set to
            ``"0"`` (not empty) or ``"1"`` (empty). If ``img_data`` is None an
            error is printed and ``result`` is returned unchanged.
        """
        if img_data is None:
            print('error, img data is None')
            return result
        with torch.no_grad():
            x = self.test_transform(img_data)
            if torch.cuda.is_available():
                x = x.cuda()
            logits = self.model(x.unsqueeze(0))  # batch of one image
            probs = torch.softmax(logits, dim=-1)
            top_score, top_id = probs[0].max(dim=-1)
        clas_ids = 0 if int(top_id) in self.NON_EMPTY_CLASS_IDS else 1
        # Log the probability of the raw top-1 class, not of the collapsed
        # 0/1 id (indexing probs with the collapsed id logs the wrong class).
        print("cur_img result: class id: %d, score: %0.3f" % (clas_ids, top_score.item()))
        result["success"] = "true"
        result["rst_cls"] = str(clas_ids)
        return result
# Parse options and load the model once, at import time; the /isempty
# handler reuses this single Predictor instance for every request.
args = parse_args()
predictor = Predictor(args=args)
def _decode_request_image(raw):
    """Decode a request body of the form ``{"pic": "<base64 image>"}``.

    Args:
        raw: the raw request body bytes.

    Returns:
        An RGB ``PIL.Image`` on success, or None (after printing the reason)
        when the body is not UTF-8 JSON, lacks a usable "pic" string, is not
        valid base64, or does not decode to an image.
    """
    try:
        pic = json.loads(raw.decode("utf-8")).get("pic")
        imgdata = base64.b64decode(pic)
    except (ValueError, AttributeError, TypeError) as err:
        # ValueError covers UnicodeDecodeError, JSONDecodeError and
        # binascii.Error; AttributeError a non-object JSON body; TypeError a
        # missing or non-string "pic".
        print('error, bad request payload: %s' % err)
        return None
    if not imgdata:
        print('error, empty pic')
        return None
    img_src = cv2.imdecode(np.frombuffer(imgdata, dtype='uint8'), cv2.IMREAD_COLOR)
    if img_src is None:
        print('error, pic is not a decodable image')
        return None
    # cv2 decodes to BGR channel order, while PIL.Image.fromarray treats a
    # 3-channel array as RGB. NOTE(review): assumes the model was trained on
    # RGB input (ImageNet-style normalisation, PIL-loaded images); confirm.
    return Image.fromarray(cv2.cvtColor(img_src, cv2.COLOR_BGR2RGB))


@app.route("/isempty", methods=['POST'])
def get_isempty():
    """Classify the posted image as empty / not empty.

    Expects a JSON body ``{"pic": "<base64-encoded image bytes>"}``.
    Responds with ``repr()`` of the result dict, e.g.
    ``{'success': 'true', 'rst_cls': '1'}`` where "1" means empty and "0"
    means not empty. A malformed request gets ``{}`` with HTTP 400.
    """
    img_data = _decode_request_image(request.get_data())
    if img_data is None:
        return repr({}), 400
    result = predictor.normal_predict(img_data, {})  # 1==empty, 0==nonEmpty
    # NOTE(review): repr() yields a Python dict literal (single quotes), not
    # JSON; kept because existing clients may parse this exact format.
    return repr(result)
def getByte(path):
    """Return the bytes of the file at *path*, base64-encoded, as a str."""
    with open(path, 'rb') as fh:
        raw = fh.read()
    return base64.b64encode(raw).decode('utf-8')
if __name__ == "__main__":
    # Run Flask's built-in server on all interfaces, port 8888.
    app.run(port=8888, host='0.0.0.0')
    # Offline check of the decode + predict path without HTTP: encode a local
    # "img.jpg" and feed it through the predictor.
    # result ={}
    # imgdata = base64.b64decode(getByte("img.jpg"))
    # imgdata_np = np.frombuffer(imgdata, dtype='uint8')
    # img_src = cv2.imdecode(imgdata_np, cv2.IMREAD_COLOR)
    # img_data = Image.fromarray(np.uint8(img_src))
    # result = predictor.normal_predict(img_data, result)
    # print(result)