update predict.py.

This commit is contained in:
Brainway
2022-10-26 15:43:43 +00:00
committed by Gitee
parent 5c21167991
commit a0dda64ad5

View File

@ -17,7 +17,9 @@ def parse_args():
parser.add_argument('--split', type=str, default='overlap', help="Split method") # non-overlap parser.add_argument('--split', type=str, default='overlap', help="Split method") # non-overlap
parser.add_argument('--slide_step', type=int, default=2, help="Slide step for overlap split") parser.add_argument('--slide_step', type=int, default=2, help="Slide step for overlap split")
parser.add_argument('--smoothing_value', type=float, default=0.0, help="Label smoothing value\n") parser.add_argument('--smoothing_value', type=float, default=0.0, help="Label smoothing value\n")
parser.add_argument("--pretrained_model", type=str, default="./output/ieemooempty_vit_checkpoint.pth", help="load pretrained model") #parser.add_argument("--pretrained_model", type=str, default="output/ieemooempty_checkpoint_good.pth", help="load pretrained model")
parser.add_argument("--pretrained_model", type=str, default="../module/ieemoo-ai-isempty/model/now/emptyjudge5_checkpoint.bin", help="load pretrained model")
#parser.add_argument("--pretrained_model", type=str, default="output/ieemooempty_vitgood_checkpoint.pth", help="load pretrained model")
return parser.parse_args() return parser.parse_args()
@ -41,7 +43,7 @@ class Predictor(object):
config.split = self.args.split config.split = self.args.split
config.slide_step = self.args.slide_step config.slide_step = self.args.slide_step
self.num_classes = 5 self.num_classes = 5
self.cls_dict = {0: "noemp", 1: "yesemp"} self.cls_dict = {0: "noemp", 1: "yesemp", 2: "hard", 3: "fly", 4: "stack"}
self.model = VisionTransformer(config, self.args.img_size, zero_head=True, num_classes=self.num_classes, smoothing_value=self.args.smoothing_value) self.model = VisionTransformer(config, self.args.img_size, zero_head=True, num_classes=self.num_classes, smoothing_value=self.args.smoothing_value)
@ -79,15 +81,10 @@ if __name__ == "__main__":
y_true = [] y_true = []
y_pred = [] y_pred = []
#test_dir = "./emptyJudge5/images/" # test_dir = "./emptyJudge5/images/"
# dir_dict = {"noemp":"0", "yesemp":"1", "hard": "2", "fly": "3", "stack": "4"} # dir_dict = {"noemp":"0", "yesemp":"1", "hard": "2", "fly": "3", "stack": "4"}
test_dir = "../emptyJudge2/images"
# test_dir = "../emptyJudge2/images" dir_dict = {"noempty":"0", "empty":"1"}
# dir_dict = {"noempty":"0", "empty":"1"}
test_dir = "../emptyJudge5/images"
dir_dict = {"noemp":"0", "yesemp":"1"}
total = 0 total = 0
num = 0 num = 0
t0 = time.time() t0 = time.time()