From 8d6aa18fcada87e536d6e8a2b3863978a46bff80 Mon Sep 17 00:00:00 2001 From: Brainway Date: Tue, 1 Nov 2022 02:24:27 +0000 Subject: [PATCH] update ieemoo-ai-isempty.py. --- ieemoo-ai-isempty.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/ieemoo-ai-isempty.py b/ieemoo-ai-isempty.py index 7bba370..3e3e879 100755 --- a/ieemoo-ai-isempty.py +++ b/ieemoo-ai-isempty.py @@ -45,12 +45,13 @@ print(torch.__version__) def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument("--img_size", default=320, type=int, help="Resolution size") + parser.add_argument("--img_size", default=600, type=int, help="Resolution size") parser.add_argument('--split', type=str, default='overlap', help="Split method") parser.add_argument('--slide_step', type=int, default=2, help="Slide step for overlap split") parser.add_argument('--smoothing_value', type=float, default=0.0, help="Label smoothing value") - parser.add_argument("--pretrained_model", type=str, default="../module/ieemoo-ai-isempty/model/now/emptyjudge5_checkpoint.bin", help="load pretrained model") + #parser.add_argument("--pretrained_model", type=str, default="../module/ieemoo-ai-isempty/model/now/emptyjudge5_checkpoint.bin", help="load pretrained model") #parser.add_argument("--pretrained_model", type=str, default="output/ieemooempty_vit_checkpoint.pth", help="load pretrained model") #使用自定义VIT + parser.add_argument("--pretrained_model", type=str, default="../module/ieemoo-ai-isempty/model/new/ieemooempty_vit_checkpoint.pth", help="load pretrained model") opt, unknown = parser.parse_known_args() return opt @@ -65,7 +66,7 @@ class Predictor(object): self.num_classes = 0 self.model = None self.prepare_model() - self.test_transform = transforms.Compose([transforms.Resize((320, 320), Image.BILINEAR), + self.test_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) @@ -95,8 +96,8 @@ class Predictor(object): else: with torch.no_grad(): x = self.test_transform(img_data) - # if torch.cuda.is_available(): - # x = x.cuda() + if torch.cuda.is_available(): + x = x.cuda() part_logits = self.model(x.unsqueeze(0)) probs = torch.nn.Softmax(dim=-1)(part_logits) topN = torch.argsort(probs, dim=-1, descending=True).tolist()