import torch import torchvision.transforms as T class Config: # network settings backbone = 'vit' # [resnet18, mobilevit_s, mobilenet_v2, mobilenetv3_small, mobilenetv3_large, mobilenet_v1, PPLCNET_x1_0, PPLCNET_x0_5, PPLCNET_x2_5] metric = 'softmax' # [cosface, arcface, softmax] cbam = True embedding_size = 256 # 256 drop_ratio = 0.5 img_size = 224 teacher = 'vit' # [resnet18, mobilevit_s, mobilenet_v2, mobilenetv3_small, mobilenetv3_large, mobilenet_v1, PPLCNET_x1_0, PPLCNET_x0_5, PPLCNET_x2_5] student = 'resnet' # data preprocess # input_shape = [1, 128, 128] """transforms.RandomCrop(size), transforms.RandomVerticalFlip(p=0.5), transforms.RandomHorizontalFlip(), RandomRotate(15, 0.3), # RandomGaussianBlur()""" train_transform = T.Compose([ T.ToTensor(), T.Resize((img_size, img_size)), # T.RandomCrop(img_size*4//5), # T.RandomHorizontalFlip(p=0.5), T.RandomRotation(180), T.ColorJitter(brightness=0.5), T.ConvertImageDtype(torch.float32), T.Normalize(mean=[0.5], std=[0.5]), ]) test_transform = T.Compose([ T.ToTensor(), T.Resize((img_size, img_size)), T.ConvertImageDtype(torch.float32), T.Normalize(mean=[0.5], std=[0.5]), ]) # dataset train_root = './data/2250_train/train' # 初始筛选过一次的数据集 # train_root = './data/0625_train/train' test_root = "./data/2250_train/val/" # test_root = "./data/0625_train/val" test_list = "./data/2250_train/val_pair.txt" test_group_json = "./data/2250_train/cross_same.json" # test_group_json = "./data/0625_train/cross_same.json" # test_list = "./data/test_data_100/val_pair.txt" # training settings checkpoints = "checkpoints/vit_b_16_0815/" # [resnet18, mobilevit_s, mobilenet_v2, mobilenetv3] restore = True # restore_model = "checkpoints/renet18_2250_0315/best_resnet18_2250_0315.pth" # best_resnet18_1491_0306.pth restore_model = "checkpoints/vit_b_16_0730/best.pth" # best_resnet18_1491_0306.pth # test_model = "./checkpoints/renet18_1887_0311/best_resnet18_1887_0311.pth" testbackbone = 'resnet18' # [resnet18, mobilevit_s, mobilenet_v2, mobilenetv3_small, mobilenetv3_large, mobilenet_v1, PPLCNET_x1_0, PPLCNET_x0_5] # test_val = "./data/2250_train" test_val = "./data/0625_train" test_model = "checkpoints/resnet18_0721/best.pth" train_batch_size = 128 # 256 test_batch_size = 256 # 256 epoch = 300 optimizer = 'adamw' # ['sgd', 'adam', 'adamw'] lr = 1e-3 # 1e-2 lr_step = 10 # 10 lr_decay = 0.95 # 0.98 weight_decay = 5e-4 loss = 'focal_loss' # ['focal_loss', 'cross_entropy'] device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu') # device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') pin_memory = True # if memory is large, set it True to speed up a bit num_workers = 4 # dataloader group_test = True # group_test = False config = Config()