diff --git a/lightrise.py b/lightrise.py index 16d1554..ce0f6fd 100644 --- a/lightrise.py +++ b/lightrise.py @@ -15,11 +15,6 @@ args = parser.parse_args() #args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") args.device = torch.device("cpu") -# 准备模型 -model = torch.load(args.pretrained_model,map_location=torch.device('cpu')) #自己预训练模型 -model.to(args.device) -model.eval() - num_classes = 2 cls_dict = {0: "noemp", 1: "yesemp"} @@ -30,6 +25,10 @@ test_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEA #复杂场景小模型测试单张图片 def riseempty(imgdata): + # 准备模型 + model = torch.load(args.pretrained_model,map_location=torch.device('cpu')) #自己预训练模型 + model.to(args.device) + model.eval() x = test_transform(imgdata) part_logits = model(x.unsqueeze(0).to(args.device)) probs = torch.nn.Softmax(dim=-1)(part_logits)