diff --git a/testsingle.py b/testsingle.py index e3ecfd9..a6f574f 100755 --- a/testsingle.py +++ b/testsingle.py @@ -36,7 +36,7 @@ model = None #model = VisionTransformer(config, args.img_size, zero_head=True, num_classes=num_classes, smoothing_value=args.smoothing_value) if args.pretrained_model is not None: - model = torch.load(args.pretrained_model) #自己预训练模型 + model = torch.load(args.pretrained_model,map_location=torch.device('cpu')) #自己预训练模型 model.to(args.device) model.eval()