From 0ac741c608ba9d894b97cfbd46fd417028fba0ac Mon Sep 17 00:00:00 2001 From: Brainway Date: Wed, 9 Nov 2022 02:13:14 +0000 Subject: [PATCH] update ieemoo-ai-isempty.py. --- ieemoo-ai-isempty.py | 63 ++++++++++++++++++++++++++++---------------- 1 file changed, 40 insertions(+), 23 deletions(-) diff --git a/ieemoo-ai-isempty.py b/ieemoo-ai-isempty.py index 3e3e879..ea10e73 100755 --- a/ieemoo-ai-isempty.py +++ b/ieemoo-ai-isempty.py @@ -12,6 +12,7 @@ from PIL import Image from torchvision import transforms from models.modeling import VisionTransformer, CONFIGS from vit_pytorch import ViT +import lightrise # import logging.config as log_config sys.path.insert(0, ".") @@ -49,9 +50,8 @@ def parse_args(): parser.add_argument('--split', type=str, default='overlap', help="Split method") parser.add_argument('--slide_step', type=int, default=2, help="Slide step for overlap split") parser.add_argument('--smoothing_value', type=float, default=0.0, help="Label smoothing value") - #parser.add_argument("--pretrained_model", type=str, default="../module/ieemoo-ai-isempty/model/now/emptyjudge5_checkpoint.bin", help="load pretrained model") - #parser.add_argument("--pretrained_model", type=str, default="output/ieemooempty_vit_checkpoint.pth", help="load pretrained model") #使用自定义VIT parser.add_argument("--pretrained_model", type=str, default="../module/ieemoo-ai-isempty/model/new/ieemooempty_vit_checkpoint.pth", help="load pretrained model") + #parser.add_argument("--pretrained_model", type=str, default="output/ieemooempty_vit_checkpoint.pth", help="load pretrained model") #使用自定义VIT opt, unknown = parser.parse_known_args() return opt @@ -83,7 +83,7 @@ class Predictor(object): # self.model = torch.load(self.args.pretrained_model) # else: # self.model = torch.load(self.args.pretrained_model,map_location='cpu') - self.model = torch.load(self.args.pretrained_model,map_location=torch.device('cpu')) + self.model = torch.load(self.args.pretrained_model) self.model.eval() self.model.to("cuda") @@ -103,9 +103,11 @@ class Predictor(object): topN = torch.argsort(probs, dim=-1, descending=True).tolist() clas_ids = topN[0][0] clas_ids = 0 if 0==int(clas_ids) or 2 == int(clas_ids) or 3 == int(clas_ids) else 1 - print("cur_img result: class id: %d, score: %0.3f" % (clas_ids, probs[0, clas_ids].item())) + #print("cur_img result: class id: %d, score: %0.3f" % (clas_ids, probs[0, clas_ids].item())) + result={} result["success"] = "true" result["rst_cls"] = str(clas_ids) + return result @@ -115,36 +117,51 @@ predictor = Predictor(args) @app.route("/isempty", methods=['POST']) def get_isempty(): - print("begin") - + start = time.time() + #print('--------------------EmptyPredict-----------------') data = request.get_data() + ip = request.remote_addr + #print('------ ip = %s ------' % ip) + print(ip) json_data = json.loads(data.decode("utf-8")) + getdateend = time.time() + #print('get date use time: {0:.2f}s'.format(getdateend - start)) pic = json_data.get("pic") + result = {} + imgdata = base64.b64decode(pic) - - result ={} imgdata_np = np.frombuffer(imgdata, dtype='uint8') img_src = cv2.imdecode(imgdata_np, cv2.IMREAD_COLOR) - img_data = Image.fromarray(np.uint8(img_src)) + cv2.imwrite('huanyuan.jpg',img_src) + #img_data = Image.fromarray(np.uint8(img_src)) #这个转换不能要,会导致判空错误增加 + img_data = Image.open('huanyuan.jpg') result = predictor.normal_predict(img_data, result) # 1==empty, 0==nonEmpty + + riseresult = lightrise.riseempty(img_data) + #print(riseresult["rst_cls"]) + + if(result["rst_cls"]==1): + if(riseresult["rst_cls"]==1): + result = {} + result["success"] = "true" + result["rst_cls"] = 1 + else: + result = {} + result["success"] = "true" + result["rst_cls"] = 0 + else: + if(riseresult["rst_cls"]==0): + result = {} + result["success"] = "true" + result["rst_cls"] = 0 + else: + result = {} + result["success"] = "true" + result["rst_cls"] = 1 return repr(result) -def getByte(path): - with open(path, 'rb') as f: - img_byte = base64.b64encode(f.read()) - img_str = img_byte.decode('utf-8') - return img_str - if __name__ == "__main__": app.run(host='0.0.0.0', port=8888) - - # result ={} - # imgdata = base64.b64decode(getByte("img.jpg")) - # imgdata_np = np.frombuffer(imgdata, dtype='uint8') - # img_src = cv2.imdecode(imgdata_np, cv2.IMREAD_COLOR) - # img_data = Image.fromarray(np.uint8(img_src)) - # result = predictor.normal_predict(img_data, result) - # print(result)