update ieemoo-ai-isempty.py.

This commit is contained in:
Brainway
2022-10-11 08:22:17 +00:00
committed by Gitee
parent 062e2245de
commit 8bdd089fe6

View File

@ -11,6 +11,7 @@ from gevent.pywsgi import WSGIServer
from PIL import Image from PIL import Image
from torchvision import transforms from torchvision import transforms
from models.modeling import VisionTransformer, CONFIGS from models.modeling import VisionTransformer, CONFIGS
from vit_pytorch import ViT
# import logging.config as log_config # import logging.config as log_config
sys.path.insert(0, ".") sys.path.insert(0, ".")
@ -40,7 +41,7 @@ app = Flask(__name__)
app.use_reloader=False app.use_reloader=False
def parse_args(model_file="../module/ieemoo-ai-isempty/model/now/emptyjudge5_checkpoint.bin"): def parse_args(model_file="./output/ieemooempty_vit_checkpoint.pth"):
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--img_size", default=600, type=int, help="Resolution size") parser.add_argument("--img_size", default=600, type=int, help="Resolution size")
parser.add_argument('--split', type=str, default='overlap', help="Split method") parser.add_argument('--split', type=str, default='overlap', help="Split method")
@ -54,9 +55,10 @@ def parse_args(model_file="../module/ieemoo-ai-isempty/model/now/emptyjudge5_che
class Predictor(object): class Predictor(object):
def __init__(self, args): def __init__(self, args):
self.args = args self.args = args
self.args.device = torch.device("cpu") #self.args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#self.args.device = torch.device("cpu")
#print(self.args.device) #print(self.args.device)
self.args.nprocs = torch.cuda.device_count() #self.args.nprocs = torch.cuda.device_count()
self.cls_dict = {} self.cls_dict = {}
self.num_classes = 0 self.num_classes = 0
self.model = None self.model = None
@ -66,19 +68,21 @@ class Predictor(object):
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
def prepare_model(self): def prepare_model(self):
config = CONFIGS["ViT-B_16"] # config = CONFIGS["ViT-B_16"]
config.split = self.args.split # config.split = self.args.split
config.slide_step = self.args.slide_step # config.slide_step = self.args.slide_step
self.num_classes = 5 # self.num_classes = 5
self.cls_dict = {0: "noemp", 1: "yesemp", 2: "hard", 3: "fly", 4: "stack"} # self.cls_dict = {0: "noemp", 1: "yesemp", 2: "hard", 3: "fly", 4: "stack"}
self.model = VisionTransformer(config, self.args.img_size, zero_head=True, num_classes=self.num_classes, smoothing_value=self.args.smoothing_value) # self.model = VisionTransformer(config, self.args.img_size, zero_head=True, num_classes=self.num_classes, smoothing_value=self.args.smoothing_value)
if self.args.pretrained_model is not None:
if not torch.cuda.is_available(): # if self.args.pretrained_model is not None:
# if not torch.cuda.is_available():
# self.model = torch.load(self.args.pretrained_model)
# else:
# self.model = torch.load(self.args.pretrained_model,map_location='cpu')
self.model = torch.load(self.args.pretrained_model) self.model = torch.load(self.args.pretrained_model)
else:
self.model = torch.load(self.args.pretrained_model,map_location='cpu')
self.model.eval() self.model.eval()
self.model.to(self.args.device) self.model.to("cuda")
def normal_predict(self, img_data, result): def normal_predict(self, img_data, result):
# img = Image.open(img_path) # img = Image.open(img_path)
@ -102,7 +106,7 @@ class Predictor(object):
return result return result
model_file ="../module/ieemoo-ai-isempty/model/now/emptyjudge5_checkpoint.bin" model_file ="./output/ieemooempty_checkpoint_good.pth"
args = parse_args(model_file) args = parse_args(model_file)
predictor = Predictor(args) predictor = Predictor(args)