update ieemoo-ai-isempty.py.

This commit is contained in:
Brainway
2022-11-01 02:24:27 +00:00
committed by Gitee
parent 5e4279f4e1
commit 8d6aa18fca

View File

@ -45,12 +45,13 @@ print(torch.__version__)
def parse_args():
    """Build and parse the command-line options for the is-empty classifier.

    Unknown arguments are tolerated (``parse_known_args``) so the service can
    be launched by wrappers that pass extra flags.

    Returns:
        argparse.Namespace: the recognized options (img_size, split,
        slide_step, smoothing_value, pretrained_model).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--img_size", default=600, type=int, help="Resolution size")
    parser.add_argument('--split', type=str, default='overlap', help="Split method")
    parser.add_argument('--slide_step', type=int, default=2, help="Slide step for overlap split")
    parser.add_argument('--smoothing_value', type=float, default=0.0, help="Label smoothing value")
    #parser.add_argument("--pretrained_model", type=str, default="../module/ieemoo-ai-isempty/model/now/emptyjudge5_checkpoint.bin", help="load pretrained model")
    #parser.add_argument("--pretrained_model", type=str, default="output/ieemooempty_vit_checkpoint.pth", help="load pretrained model") # uses the custom ViT checkpoint
    parser.add_argument("--pretrained_model", type=str, default="../module/ieemoo-ai-isempty/model/new/ieemooempty_vit_checkpoint.pth", help="load pretrained model")
    # Ignore unrecognized flags instead of erroring out.
    opt, unknown = parser.parse_known_args()
    return opt
@ -65,7 +66,7 @@ class Predictor(object):
self.num_classes = 0 self.num_classes = 0
self.model = None self.model = None
self.prepare_model() self.prepare_model()
self.test_transform = transforms.Compose([transforms.Resize((320, 320), Image.BILINEAR), self.test_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
transforms.ToTensor(), transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
@ -95,8 +96,8 @@ class Predictor(object):
else: else:
with torch.no_grad(): with torch.no_grad():
x = self.test_transform(img_data) x = self.test_transform(img_data)
# if torch.cuda.is_available(): if torch.cuda.is_available():
# x = x.cuda() x = x.cuda()
part_logits = self.model(x.unsqueeze(0)) part_logits = self.model(x.unsqueeze(0))
probs = torch.nn.Softmax(dim=-1)(part_logits) probs = torch.nn.Softmax(dim=-1)(part_logits)
topN = torch.argsort(probs, dim=-1, descending=True).tolist() topN = torch.argsort(probs, dim=-1, descending=True).tolist()