import torch
import torchvision.transforms as T
import torchvision.transforms.functional as F


def pad_to_square(img):
    """Pad *img* with black borders so that its width equals its height.

    The image is kept centred. When the difference between the two sides is
    odd, the one leftover pixel goes to the right/bottom border so the result
    is always exactly square.

    Args:
        img: A PIL image (its ``size`` attribute is unpacked as
            ``(width, height)``).

    Returns:
        The padded image, ``max(width, height)`` pixels on each side.
    """
    w, h = img.size
    side = max(w, h)
    left = (side - w) // 2
    top = (side - h) // 2
    # Derive right/bottom from what is left over instead of repeating the
    # floor division; otherwise an odd difference leaves the image one pixel
    # short of square.
    right = side - w - left
    bottom = side - h - top
    # torchvision expects a 4-element padding as (left, top, right, bottom).
    return F.pad(img, [left, top, right, bottom], fill=0, padding_mode='constant')


class Config:
    """Central settings for training, testing and distillation.

    All settings are plain class attributes; a ready-made instance is exposed
    at module level as ``config``.
    """

    # ------------------------------------------------------------------
    # network settings
    # ------------------------------------------------------------------
    # options: [resnet18, mobilevit_s, mobilenet_v2, mobilenetv3_small,
    #           mobilenetv3_large, mobilenet_v1, PPLCNET_x1_0, PPLCNET_x0_5,
    #           PPLCNET_x2_5, vit_base]
    backbone = 'resnet18'
    metric = 'arcface'  # options: [cosface, arcface, softmax]
    cbam = False
    embedding_size = 256  # 256; gift: 2, contrast: 256
    drop_ratio = 0.5
    img_size = 224
    multiple_cards = True  # multi-GPU loading
    model_half = False  # test with the model in half precision
    data_half = True  # test with the data in half precision
    channel_ratio = 0.75  # channel pruning ratio
    quantization_test = False  # test an int8-quantized model

    # ------------------------------------------------------------------
    # custom base_data settings
    # ------------------------------------------------------------------
    custom_backbone = False  # transfer learning: load every layer except the last
    custom_num_classes = 128  # number of classes for transfer learning

    # if quantization_test:
    #     device = torch.device('cpu')
    # else:
    #     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # options: [resnet18, mobilevit_s, mobilenet_v2, mobilenetv3_small,
    #           mobilenetv3_large, mobilenet_v1, PPLCNET_x1_0, PPLCNET_x0_5,
    #           PPLCNET_x2_5]
    teacher = 'vit'
    student = 'resnet'

    # ------------------------------------------------------------------
    # data preprocess
    # ------------------------------------------------------------------
    # Augmentations tried previously:
    #   transforms.RandomCrop(size),
    #   transforms.RandomVerticalFlip(p=0.5),
    #   transforms.RandomHorizontalFlip(),
    #   RandomRotate(15, 0.3),
    #   RandomGaussianBlur()
    train_transform = T.Compose([
        T.Lambda(pad_to_square),  # pad the borders to a square
        T.ToTensor(),
        T.Resize((img_size, img_size), antialias=True),
        # T.RandomCrop(img_size * 4 // 5),
        T.RandomHorizontalFlip(p=0.5),
        T.RandomRotation(180),
        T.ColorJitter(brightness=0.5),
        T.ConvertImageDtype(torch.float32),
        T.Normalize(mean=[0.5], std=[0.5]),
    ])
    test_transform = T.Compose([
        # T.Lambda(pad_to_square),  # pad the borders to a square
        T.ToTensor(),
        T.Resize((img_size, img_size), antialias=True),
        T.ConvertImageDtype(torch.float32),
        # T.Normalize(mean=[0,0,0], std=[255,255,255]),
        T.Normalize(mean=[0.5], std=[0.5]),
    ])

    # ------------------------------------------------------------------
    # dataset
    # ------------------------------------------------------------------
    # alternatives: ['./data/2250_train/base_data', './data/2000_train/base_data',
    #                './data/zhanting/base_data', './data/base_train/one_stage/train']
    train_root = '../data_center/scatter/train'
    # alternatives: ["./data/2250_train/val", "./data/2000_train/val/",
    #                './data/zhanting/val', './data/base_train/one_stage/val']
    test_root = '../data_center/scatter/val'

    # ------------------------------------------------------------------
    # training settings
    # ------------------------------------------------------------------
    checkpoints = "checkpoints/resnet18_scatter_6.2/"  # [resnet18, mobilevit_s, mobilenet_v2, mobilenetv3]
    restore = True
    # restore_model = "checkpoints/renet18_2250_0315/best_resnet18_2250_0315.pth"  # best_resnet18_1491_0306.pth
    restore_model = "checkpoints/resnet18_scatter_6.2/best.pth"  # best_resnet18_1491_0306.pth

    # ------------------------------------------------------------------
    # test settings
    # ------------------------------------------------------------------
    # options: [resnet18, mobilevit_s, mobilenet_v2, mobilenetv3_small,
    #           mobilenetv3_large, mobilenet_v1, PPLCNET_x1_0, PPLCNET_x0_5]
    testbackbone = 'resnet18'

    # test_val = "./data/2250_train"
    # test_list = "./data/2250_train/val_pair.txt"
    # test_group_json = "./data/2250_train/cross_same.json"
    test_val = "../data_center/scatter/"  # [../data_center/contrast_learning/model_test_data/val_2250]
    test_list = "../data_center/scatter/val_pair.txt"  # [./data/test/public_single_pairs.txt]
    test_group_json = "../data_center/contrast_learning/model_test_data/test/inner_group_pairs.json"  # [./data/2250_train/cross_same.json]
    # test_group_json = "./data/test/inner_group_pairs.json"
    # test_model = "checkpoints/resnet18_scatter_6.2/best.pth"
    test_model = "checkpoints/resnet18_1009/best.pth"
    # test_model = "checkpoints/zhanting/inland/res_801.pth"
    # test_model = "checkpoints/resnet18_20250504/best.pth"
    # test_model = "checkpoints/resnet18_vit-base_20250430/best.pth"
    group_test = False
    # group_test = False

    train_batch_size = 128  # 256
    test_batch_size = 128  # 256

    epoch = 5  # 512
    optimizer = 'sgd'  # options: ['sgd', 'adam', 'adamw']
    lr = 5e-3  # 1e-2
    lr_step = 10  # 10
    lr_decay = 0.98  # 0.98
    weight_decay = 5e-4
    loss = 'cross_entropy'  # options: ['focal_loss', 'cross_entropy']
    log_path = './log'
    lr_min = 1e-6  # min lr

    pin_memory = False  # if memory is large, set it True to speed up a bit
    num_workers = 32  # 64
    compare = False  # compare the result of different models

    # ------------------------------------------------------------------
    # train_distill settings
    # ------------------------------------------------------------------
    warmup_epochs = 3  # warmup_epoch
    distributed = True  # distributed training
    teacher_path = "./checkpoints/resnet50_0519/best.pth"
    distill_weight = 0.8  # distillation weight


# Shared configuration instance.
config = Config()