# Model training / evaluation settings -- currently DISABLED.
#
# NOTE(review): apart from the single ``test_val`` assignment further down,
# every line here is a comment. So these lines define no ``Config`` class and
# no ``config`` instance; importing them only creates the module-level name
# ``test_val``.
#
# The commented-out settings were previously collapsed onto two very long
# physical lines. They are re-wrapped below one statement per line. Where a
# ``#`` could be read either as a line break or as the start of a trailing
# comment, the most plausible reading was chosen. For example, it cannot be
# told whether ``batch_size`` and ``input_shape`` were active or commented
# out. TODO: check the layout against version control before uncommenting
# anything.

# import torch
# import torchvision.transforms as T
#
#
# class Config:
#     # network settings
#     backbone = 'resnet18'  # [resnet18, mobilevit_s, mobilenet_v2, mobilenetv3_small, mobilenetv3_large, mobilenet_v1, PPLCNET_x1_0, PPLCNET_x0_5, PPLCNET_x2_5]
#     metric = 'arcface'  # [cosface, arcface]
#     cbam = True
#     embedding_size = 256
#     drop_ratio = 0.5
#     img_size = 224
#
#     batch_size = 8
#
#     # data preprocess
#     # input_shape = [1, 128, 128]
#     """transforms.RandomCrop(size),
#     transforms.RandomVerticalFlip(p=0.5),
#     transforms.RandomHorizontalFlip(),
#     RandomRotate(15, 0.3),
#     # RandomGaussianBlur()"""
#
#     train_transform = T.Compose([
#         T.ToTensor(),
#         T.Resize((img_size, img_size)),
#         # T.RandomCrop(img_size),
#         # T.RandomHorizontalFlip(p=0.5),
#         T.RandomRotation(180),
#         T.ColorJitter(brightness=0.5),
#         T.ConvertImageDtype(torch.float32),
#         T.Normalize(mean=[0.5], std=[0.5]),
#     ])
#     test_transform = T.Compose([
#         T.ToTensor(),
#         T.Resize((img_size, img_size)),
#         T.ConvertImageDtype(torch.float32),
#         T.Normalize(mean=[0.5], std=[0.5]),
#     ])
#
#     # dataset
#     train_root = './data/2250_train/train'  # dataset already filtered once in an initial screening pass
#     # train_root = './data/0612_train/train'
#     test_root = "./data/2250_train/val/"
#     # test_root = "./data/0612_train/val"
#     test_list = "./data/2250_train/val_pair.txt"
#     # test_group_json = "./2250_train/cross_same_0508.json"
#
#     # test_list = "./data/test_data_100/val_pair.txt"
#
#     # training settings
#     checkpoints = "checkpoints/resnet18_0613/"  # [resnet18, mobilevit_s, mobilenet_v2, mobilenetv3]
#     restore = False
#     # restore_model = "checkpoints/renet18_2250_0315/best_resnet18_2250_0315.pth"  # best_resnet18_1491_0306.pth
#     restore_model = "checkpoints/resnet18_0515/best.pth"  # best_resnet18_1491_0306.pth
#
#     # test_model = "checkpoints/renet18_2250_0314/best_resnet18_2250_0314.pth"
#     testbackbone = 'resnet18'  # [resnet18, mobilevit_s, mobilenet_v2, mobilenetv3_small, mobilenetv3_large, mobilenet_v1, PPLCNET_x1_0, PPLCNET_x0_5]
#     test_val = "D:/比对/cl"
#

# NOTE(review): this is the only live statement in these lines. In the
# commented-out layout it follows the ``test_val = "D:/..."`` line, where it
# would override that class attribute. As written, though, it is a plain
# module-level variable and looks like a stray leftover. TODO: confirm whether
# it should be commented out like its neighbours, or whether the whole
# ``Config`` class should be re-enabled.
test_val: str = "./data/test_data_100"

#
#     # test_model = "checkpoints/zhanting_res_801.pth"
#     test_model = "checkpoints/resnet18_0515/v11.pth"
#
#     train_batch_size = 512  # 256
#     test_batch_size = 256  # 256
#
#     epoch = 300
#     optimizer = 'sgd'  # ['sgd', 'adam']
#     lr = 1.5e-2  # 1e-2
#     lr_step = 5  # 10
#     lr_decay = 0.95  # 0.98
#     weight_decay = 5e-4
#     loss = 'cross_entropy'  # ['focal_loss', 'cross_entropy']
#     # device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
#     device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
#
#     pin_memory = True  # if memory is large, set it True to speed up a bit
#     num_workers = 4  # dataloader
#
#     group_test = True
#
#
# config = Config()