diff --git a/predict.py b/predict.py index ca2a087..f3dabaf 100755 --- a/predict.py +++ b/predict.py @@ -9,6 +9,7 @@ from sklearn.metrics import f1_score from PIL import Image from torchvision import transforms from models.modeling import VisionTransformer, CONFIGS +import lightrise #模型预测 def parse_args(): @@ -17,16 +18,15 @@ def parse_args(): parser.add_argument('--split', type=str, default='overlap', help="Split method") # non-overlap parser.add_argument('--slide_step', type=int, default=2, help="Slide step for overlap split") parser.add_argument('--smoothing_value', type=float, default=0.0, help="Label smoothing value\n") - parser.add_argument("--pretrained_model", type=str, default="output/ieemooempty_vit_checkpoint.pth", help="load pretrained model") - #parser.add_argument("--pretrained_model", type=str, default="../module/ieemoo-ai-isempty/model/now/emptyjudge5_checkpoint.bin", help="load pretrained model") - #parser.add_argument("--pretrained_model", type=str, default="output/ieemooempty_vitgood_checkpoint.pth", help="load pretrained model") + parser.add_argument("--pretrained_model", type=str, default="../module/ieemoo-ai-isempty/model/new/ieemooempty_vit_checkpoint.pth", help="load pretrained model") + #parser.add_argument("--pretrained_model", type=str, default="output/ieemooempty_vit_checkpoint.pth", help="load pretrained model") #使用自定义VIT return parser.parse_args() class Predictor(object): def __init__(self, args): self.args = args - self.args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.args.device = torch.device("cuda") print("self.args.device =", self.args.device) self.args.nprocs = torch.cuda.device_count() @@ -43,16 +43,12 @@ class Predictor(object): config.split = self.args.split config.slide_step = self.args.slide_step self.num_classes = 5 - self.cls_dict = {0: "noemp", 1: "yesemp", 2: "hard", 3: "fly", 4: "stack"} + self.cls_dict = {0: "noemp", 1: "yesemp"} self.model = VisionTransformer(config, self.args.img_size, zero_head=True, num_classes=self.num_classes, smoothing_value=self.args.smoothing_value) if self.args.pretrained_model is not None: - if not torch.cuda.is_available(): - self.model = torch.load(self.args.pretrained_model) - else: - self.model = torch.load(self.args.pretrained_model,map_location='cpu') - + self.model = torch.load(self.args.pretrained_model,map_location='cpu') self.model.to(self.args.device) self.model.eval() @@ -65,9 +61,7 @@ class Predictor(object): "Image file failed to read: {}".format(img_path)) else: x = self.test_transform(img) - if torch.cuda.is_available(): - x = x.cuda() - part_logits = self.model(x.unsqueeze(0)) + part_logits = self.model(x.unsqueeze(0).to(args.device)) probs = torch.nn.Softmax(dim=-1)(part_logits) topN = torch.argsort(probs, dim=-1, descending=True).tolist() clas_ids = topN[0][0] @@ -81,10 +75,12 @@ if __name__ == "__main__": y_true = [] y_pred = [] - # test_dir = "./emptyJudge5/images/" - # dir_dict = {"noemp":"0", "yesemp":"1", "hard": "2", "fly": "3", "stack": "4"} - test_dir = "../emptyJudge2/images" - dir_dict = {"noempty":"0", "empty":"1"} + test_dir = "./emptyJudge5/images/" + dir_dict = {"noemp":"0", "yesemp":"1", "hard": "2", "fly": "3", "stack": "4"} + + # test_dir = "../emptyJudge2/images" + # dir_dict = {"noempty":"0", "empty":"1"} + total = 0 num = 0 t0 = time.time() @@ -100,6 +96,19 @@ if __name__ == "__main__": cur_pred, pred_score = predictor.normal_predict(cur_img_file) label = 0 if 2 == int(label) or 3 == int(label) or 4 == int(label) else int(label) + + riseresult = lightrise.riseempty(Image.open(cur_img_file)) + if(label==1): + if(int(riseresult["rst_cls"])==1): + label=1 + else: + label=0 + # else: + # if(riseresult["rst_cls"]==0): + # label=0 + # else: + # label=1 + cur_pred = 0 if 2 == int(cur_pred) or 3 == int(cur_pred) or 4 == int(cur_pred) else int(cur_pred) y_true.append(int(label)) y_pred.append(int(cur_pred))