From 7a1c101307548b3c7bd368ac105ae6a591308a6a Mon Sep 17 00:00:00 2001 From: Brainway Date: Wed, 9 Nov 2022 03:30:40 +0000 Subject: [PATCH] update lightrise.py. --- lightrise.py | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/lightrise.py b/lightrise.py index ce0f6fd..3ad44e2 100644 --- a/lightrise.py +++ b/lightrise.py @@ -8,23 +8,23 @@ import argparse from models.modeling import VisionTransformer, CONFIGS import time -parser = argparse.ArgumentParser() -parser.add_argument("--pretrained_model", type=str, default="../module/ieemoo-ai-isempty/model/new/ieemooempty_vitlight_checkpoint.pth", help="load pretrained model") #使用自定义VIT -args = parser.parse_args() - -#args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -args.device = torch.device("cpu") - -num_classes = 2 -cls_dict = {0: "noemp", 1: "yesemp"} - -test_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR), - transforms.ToTensor(), - transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) #复杂场景小模型测试单张图片 def riseempty(imgdata): + parser = argparse.ArgumentParser() + parser.add_argument("--pretrained_model", type=str, default="../module/ieemoo-ai-isempty/model/new/ieemooempty_vitlight_checkpoint.pth", help="load pretrained model") #使用自定义VIT + args = parser.parse_args() + + #args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + args.device = torch.device("cpu") + + num_classes = 2 + cls_dict = {0: "noemp", 1: "yesemp"} + + test_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) # 准备模型 model = torch.load(args.pretrained_model,map_location=torch.device('cpu')) #自己预训练模型 model.to(args.device) @@ -41,8 +41,8 @@ def riseempty(imgdata): return riseresult -if __name__ == "__main__": - riseresult = riseempty("light.jpg") - print(riseresult) +# if __name__ == "__main__": +# riseresult = riseempty("light.jpg") +# print(riseresult)