From 5d5788fe6fe25f4adc9c3aecd55424d851a152a1 Mon Sep 17 00:00:00 2001 From: Brainway Date: Tue, 22 Nov 2022 08:19:05 +0000 Subject: [PATCH] update ieemoo-ai-isempty.py. --- ieemoo-ai-isempty.py | 82 ++++++++++++++------------------------------ 1 file changed, 26 insertions(+), 56 deletions(-) diff --git a/ieemoo-ai-isempty.py b/ieemoo-ai-isempty.py index 1d75946..7bba370 100755 --- a/ieemoo-ai-isempty.py +++ b/ieemoo-ai-isempty.py @@ -12,7 +12,6 @@ from PIL import Image from torchvision import transforms from models.modeling import VisionTransformer, CONFIGS from vit_pytorch import ViT -#import lightrise # import logging.config as log_config sys.path.insert(0, ".") @@ -46,11 +45,11 @@ print(torch.__version__) def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument("--img_size", default=600, type=int, help="Resolution size") + parser.add_argument("--img_size", default=320, type=int, help="Resolution size") parser.add_argument('--split', type=str, default='overlap', help="Split method") parser.add_argument('--slide_step', type=int, default=2, help="Slide step for overlap split") parser.add_argument('--smoothing_value', type=float, default=0.0, help="Label smoothing value") - parser.add_argument("--pretrained_model", type=str, default="../module/ieemoo-ai-isempty/model/new/ieemooempty_vit_checkpoint.pth", help="load pretrained model") + parser.add_argument("--pretrained_model", type=str, default="../module/ieemoo-ai-isempty/model/now/emptyjudge5_checkpoint.bin", help="load pretrained model") #parser.add_argument("--pretrained_model", type=str, default="output/ieemooempty_vit_checkpoint.pth", help="load pretrained model") #使用自定义VIT opt, unknown = parser.parse_known_args() return opt @@ -66,7 +65,7 @@ class Predictor(object): self.num_classes = 0 self.model = None self.prepare_model() - self.test_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR), + self.test_transform = transforms.Compose([transforms.Resize((320, 320), Image.BILINEAR), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) @@ -83,7 +82,7 @@ class Predictor(object): # self.model = torch.load(self.args.pretrained_model) # else: # self.model = torch.load(self.args.pretrained_model,map_location='cpu') - self.model = torch.load(self.args.pretrained_model) + self.model = torch.load(self.args.pretrained_model,map_location=torch.device('cpu')) self.model.eval() self.model.to("cuda") @@ -96,18 +95,16 @@ class Predictor(object): else: with torch.no_grad(): x = self.test_transform(img_data) - if torch.cuda.is_available(): - x = x.cuda() + # if torch.cuda.is_available(): + # x = x.cuda() part_logits = self.model(x.unsqueeze(0)) probs = torch.nn.Softmax(dim=-1)(part_logits) topN = torch.argsort(probs, dim=-1, descending=True).tolist() clas_ids = topN[0][0] clas_ids = 0 if 0==int(clas_ids) or 2 == int(clas_ids) or 3 == int(clas_ids) else 1 - #print("cur_img result: class id: %d, score: %0.3f" % (clas_ids, probs[0, clas_ids].item())) - result={} + print("cur_img result: class id: %d, score: %0.3f" % (clas_ids, probs[0, clas_ids].item())) result["success"] = "true" result["rst_cls"] = str(clas_ids) - return result @@ -115,65 +112,38 @@ class Predictor(object): args = parse_args() predictor = Predictor(args) -def riseempty(imgdata): - risemodel = torch.load("../module/ieemoo-ai-isempty/model/new/ieemooempty_vitlight_checkpoint.pth",map_location=torch.device('cpu')) #自己预训练模型 - risemodel.to("cpu") - risemodel.eval() - test_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR), - transforms.ToTensor(), - transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) - x = test_transform(imgdata) - part_logits = risemodel(x.unsqueeze(0).to('cpu')) - probs = torch.nn.Softmax(dim=-1)(part_logits) - top2 = torch.argsort(probs, dim=-1, descending=True) - riseclas_ids = top2[0][0] - #print("cur_img result: class id: %d, score: %0.3f" % (riseclas_ids, probs[0, riseclas_ids].item())) - riseresult={} - riseresult["success"] = "true" - riseresult["rst_cls"] = int(riseclas_ids) - return riseresult - - - @app.route("/isempty", methods=['POST']) def get_isempty(): - start = time.time() - #print('--------------------EmptyPredict-----------------') + print("begin") + data = request.get_data() - ip = request.remote_addr - #print('------ ip = %s ------' % ip) - print(ip) json_data = json.loads(data.decode("utf-8")) - getdateend = time.time() - #print('get date use time: {0:.2f}s'.format(getdateend - start)) pic = json_data.get("pic") - result = {} - imgdata = base64.b64decode(pic) + + result ={} imgdata_np = np.frombuffer(imgdata, dtype='uint8') img_src = cv2.imdecode(imgdata_np, cv2.IMREAD_COLOR) - cv2.imwrite('huanyuan.jpg',img_src) - #img_data = Image.fromarray(np.uint8(img_src)) #这个转换不能要,会导致判空错误增加 - img_data = Image.open('huanyuan.jpg') + img_data = Image.fromarray(np.uint8(img_src)) result = predictor.normal_predict(img_data, result) # 1==empty, 0==nonEmpty - - riseresult = riseempty(img_data) - #print(riseresult["rst_cls"]) - - if(int(result["rst_cls"])==1): - if(int(riseresult["rst_cls"])==1): - result = {} - result["success"] = "true" - result["rst_cls"] = 1 - else: - result = {} - result["success"] = "true" - result["rst_cls"] = 0 - return repr(result) +def getByte(path): + with open(path, 'rb') as f: + img_byte = base64.b64encode(f.read()) + img_str = img_byte.decode('utf-8') + return img_str + if __name__ == "__main__": app.run(host='0.0.0.0', port=8888) + + # result ={} + # imgdata = base64.b64decode(getByte("img.jpg")) + # imgdata_np = np.frombuffer(imgdata, dtype='uint8') + # img_src = cv2.imdecode(imgdata_np, cv2.IMREAD_COLOR) + # img_data = Image.fromarray(np.uint8(img_src)) + # result = predictor.normal_predict(img_data, result) + # print(result)