diff --git a/ieemoo-ai-isempty.py b/ieemoo-ai-isempty.py
index 06c1248..197ea00 100755
--- a/ieemoo-ai-isempty.py
+++ b/ieemoo-ai-isempty.py
@@ -11,36 +11,40 @@ from gevent.pywsgi import WSGIServer
 from PIL import Image
 from torchvision import transforms
 from models.modeling import VisionTransformer, CONFIGS
+# import logging.config as log_config
 sys.path.insert(0, ".")
-import logging.config
-from skywalking import agent, config
-SW_SERVER = os.environ.get('SW_AGENT_COLLECTOR_BACKEND_SERVICES')
-SW_SERVICE_NAME = os.environ.get('SW_AGENT_NAME')
-if SW_SERVER and SW_SERVICE_NAME:
-    config.init()  # address of the collector service; give this service its own name
-    #config.init(collector="123.60.56.51:11800", service='ieemoo-ai-search')  # address of the collector service; give this service its own name
-    agent.start()
-def setup_logging(path):
-    if os.path.exists(path):
-        with open(path, 'r') as f:
-            config = json.load(f)
-            logging.config.dictConfig(config)
-    logger = logging.getLogger("root")
-    return logger
-logger = setup_logging('utils/logging.json')
+# Flask external service interface
+
+
+# from skywalking import agent, config
+# SW_SERVER = os.environ.get('SW_AGENT_COLLECTOR_BACKEND_SERVICES')
+# SW_SERVICE_NAME = os.environ.get('SW_AGENT_NAME')
+# if SW_SERVER and SW_SERVICE_NAME:
+#     config.init()  # address of the collector service; give this service its own name
+#     #config.init(collector="123.60.56.51:11800", service='ieemoo-ai-search')  # address of the collector service; give this service its own name
+#     agent.start()
+
+# def setup_logging(path):
+#     if os.path.exists(path):
+#         with open(path, 'r') as f:
+#             config = json.load(f)
+#             log_config.dictConfig(config)
+#     logger = logging.getLogger("root")
+#     return logger
+
+# logger = setup_logging('utils/logging.json')
 
 
 app = Flask(__name__)
 app.use_reloader=False
 
 
-def parse_args(model_file="../module/ieemoo-ai-isempty/model/now/emptyjudge5_checkpoint.bin"):
-#def parse_args(model_file="output/emptyjudge5_checkpoint.bin"):
+def parse_args(model_file="./output/ieemooempty_checkpoint_good.pth"):
     parser = argparse.ArgumentParser()
-    parser.add_argument("--img_size", default=448, type=int, help="Resolution size")
+    parser.add_argument("--img_size", default=600, type=int, help="Resolution size")
     parser.add_argument('--split', type=str, default='overlap', help="Split method")
-    parser.add_argument('--slide_step', type=int, default=12, help="Slide step for overlap split")
+    parser.add_argument('--slide_step', type=int, default=2, help="Slide step for overlap split")
     parser.add_argument('--smoothing_value', type=float, default=0.0, help="Label smoothing value")
     parser.add_argument("--pretrained_model", type=str, default=model_file, help="load pretrained model")
     opt, unknown = parser.parse_known_args()
@@ -57,7 +61,7 @@ class Predictor(object):
         self.num_classes = 0
         self.model = None
         self.prepare_model()
-        self.test_transform = transforms.Compose([transforms.Resize((448, 448), Image.BILINEAR),
+        self.test_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
                                                   transforms.ToTensor(),
                                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
 
@@ -65,27 +69,22 @@ class Predictor(object):
         config = CONFIGS["ViT-B_16"]
         config.split = self.args.split
         config.slide_step = self.args.slide_step
-        model_name = os.path.basename(self.args.pretrained_model).replace("_checkpoint.bin", "")
-        #print("use model_name: ", model_name)
         self.num_classes = 5
         self.cls_dict = {0: "noemp", 1: "yesemp", 2: "hard", 3: "fly", 4: "stack"}
         self.model = VisionTransformer(config, self.args.img_size, zero_head=True, num_classes=self.num_classes, smoothing_value=self.args.smoothing_value)
         if self.args.pretrained_model is not None:
             if not torch.cuda.is_available():
-                pretrained_model = torch.load(self.args.pretrained_model, map_location=torch.device('cpu'))['model']
-                self.model.load_state_dict(pretrained_model)
+                self.model = torch.load(self.args.pretrained_model, map_location='cpu')
             else:
-                pretrained_model = torch.load(self.args.pretrained_model)['model']
-                self.model.load_state_dict(pretrained_model)
+                self.model = torch.load(self.args.pretrained_model)
             self.model.eval()
             self.model.to(self.args.device)
-        #self.model.eval()
-
+
     def normal_predict(self, img_data, result):
         # img = Image.open(img_path)
         if img_data is None:
             #print('error, img data is None')
-            logger.warning('error, img data is None')
+            print('error, img data is None')
             return result
         else:
             with torch.no_grad():
@@ -103,8 +102,7 @@ class Predictor(object):
 
         return result
 
-model_file = "../module/ieemoo-ai-isempty/model/now/emptyjudge5_checkpoint.bin"
-#model_file = "output/emptyjudge5_checkpoint.bin"
+model_file = "./output/ieemooempty_checkpoint_good.pth"
 args = parse_args(model_file)
 predictor = Predictor(args)
 
@@ -116,7 +114,7 @@ def get_isempty():
     data = request.get_data()
     ip = request.remote_addr
     #print('------ ip = %s ------' % ip)
-    logger.info(ip)
+    print(ip)
    json_data = json.loads(data.decode("utf-8"))
     getdateend = time.time()
 
@@ -133,10 +131,10 @@ def get_isempty():
             img_data = Image.fromarray(np.uint8(img_src))
             result = predictor.normal_predict(img_data, result)  # 1==empty, 0==nonEmpty
         except Exception as e:
-            logger.warning(e)
+            print(e)
             return repr(result)
-    logger.info(repr(result))
+    print(repr(result))
     return repr(result)
 
 if __name__ == "__main__":
-    app.run(host='192.168.1.142', port=8000)
+    app.run(host='0.0.0.0', port=14465)
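
Note on the preprocessing change: test_transform now resizes to 600x600, matching the new --img_size default. A quick standalone sanity check of that pipeline, assuming nothing beyond what the diff shows (the input file name sample.jpg is illustrative, not part of the repo):

    from PIL import Image
    from torchvision import transforms

    # Mirrors the transform in the diff: 600x600 bilinear resize + ImageNet normalization.
    test_transform = transforms.Compose([
        transforms.Resize((600, 600), Image.BILINEAR),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    img = Image.open("sample.jpg").convert("RGB")  # illustrative input path
    x = test_transform(img).unsqueeze(0)           # add a batch dim -> [1, 3, 600, 600]
    print(x.shape)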
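
Note on the checkpoint change: the new ./output/ieemooempty_checkpoint_good.pth is read with a bare torch.load, i.e. a whole pickled model rather than a state_dict stored under a 'model' key. A minimal device-aware loading sketch under that assumption; the helper name load_checkpoint is illustrative and not part of the script:

    import torch

    def load_checkpoint(path, device):
        # Assumes the file was written with torch.save(model), as the diff implies;
        # map_location keeps CPU-only hosts from failing on CUDA-saved tensors.
        model = torch.load(path, map_location=device)
        model.eval()
        return model.to(device)

    # Illustrative usage, mirroring args.device in the script:
    # device = "cuda" if torch.cuda.is_available() else "cpu"
    # model = load_checkpoint("./output/ieemooempty_checkpoint_good.pth", device)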