update ieemoo-ai-isempty.py.
This commit is contained in:
@ -11,6 +11,7 @@ from gevent.pywsgi import WSGIServer
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
from models.modeling import VisionTransformer, CONFIGS
|
from models.modeling import VisionTransformer, CONFIGS
|
||||||
|
from vit_pytorch import ViT
|
||||||
# import logging.config as log_config
|
# import logging.config as log_config
|
||||||
sys.path.insert(0, ".")
|
sys.path.insert(0, ".")
|
||||||
|
|
||||||
@ -40,7 +41,7 @@ app = Flask(__name__)
|
|||||||
app.use_reloader=False
|
app.use_reloader=False
|
||||||
|
|
||||||
|
|
||||||
def parse_args(model_file="../module/ieemoo-ai-isempty/model/now/emptyjudge5_checkpoint.bin"):
|
def parse_args(model_file="./output/ieemooempty_vit_checkpoint.pth"):
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--img_size", default=600, type=int, help="Resolution size")
|
parser.add_argument("--img_size", default=600, type=int, help="Resolution size")
|
||||||
parser.add_argument('--split', type=str, default='overlap', help="Split method")
|
parser.add_argument('--split', type=str, default='overlap', help="Split method")
|
||||||
@ -54,9 +55,10 @@ def parse_args(model_file="../module/ieemoo-ai-isempty/model/now/emptyjudge5_che
|
|||||||
class Predictor(object):
|
class Predictor(object):
|
||||||
def __init__(self, args):
|
def __init__(self, args):
|
||||||
self.args = args
|
self.args = args
|
||||||
self.args.device = torch.device("cpu")
|
#self.args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
#self.args.device = torch.device("cpu")
|
||||||
#print(self.args.device)
|
#print(self.args.device)
|
||||||
self.args.nprocs = torch.cuda.device_count()
|
#self.args.nprocs = torch.cuda.device_count()
|
||||||
self.cls_dict = {}
|
self.cls_dict = {}
|
||||||
self.num_classes = 0
|
self.num_classes = 0
|
||||||
self.model = None
|
self.model = None
|
||||||
@ -66,19 +68,21 @@ class Predictor(object):
|
|||||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
|
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
|
||||||
|
|
||||||
def prepare_model(self):
|
def prepare_model(self):
|
||||||
config = CONFIGS["ViT-B_16"]
|
# config = CONFIGS["ViT-B_16"]
|
||||||
config.split = self.args.split
|
# config.split = self.args.split
|
||||||
config.slide_step = self.args.slide_step
|
# config.slide_step = self.args.slide_step
|
||||||
self.num_classes = 5
|
# self.num_classes = 5
|
||||||
self.cls_dict = {0: "noemp", 1: "yesemp", 2: "hard", 3: "fly", 4: "stack"}
|
# self.cls_dict = {0: "noemp", 1: "yesemp", 2: "hard", 3: "fly", 4: "stack"}
|
||||||
self.model = VisionTransformer(config, self.args.img_size, zero_head=True, num_classes=self.num_classes, smoothing_value=self.args.smoothing_value)
|
# self.model = VisionTransformer(config, self.args.img_size, zero_head=True, num_classes=self.num_classes, smoothing_value=self.args.smoothing_value)
|
||||||
if self.args.pretrained_model is not None:
|
|
||||||
if not torch.cuda.is_available():
|
# if self.args.pretrained_model is not None:
|
||||||
|
# if not torch.cuda.is_available():
|
||||||
|
# self.model = torch.load(self.args.pretrained_model)
|
||||||
|
# else:
|
||||||
|
# self.model = torch.load(self.args.pretrained_model,map_location='cpu')
|
||||||
self.model = torch.load(self.args.pretrained_model)
|
self.model = torch.load(self.args.pretrained_model)
|
||||||
else:
|
|
||||||
self.model = torch.load(self.args.pretrained_model,map_location='cpu')
|
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
self.model.to(self.args.device)
|
self.model.to("cuda")
|
||||||
|
|
||||||
def normal_predict(self, img_data, result):
|
def normal_predict(self, img_data, result):
|
||||||
# img = Image.open(img_path)
|
# img = Image.open(img_path)
|
||||||
@ -102,7 +106,7 @@ class Predictor(object):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
model_file ="../module/ieemoo-ai-isempty/model/now/emptyjudge5_checkpoint.bin"
|
model_file ="./output/ieemooempty_checkpoint_good.pth"
|
||||||
args = parse_args(model_file)
|
args = parse_args(model_file)
|
||||||
predictor = Predictor(args)
|
predictor = Predictor(args)
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user