This commit is contained in:
lee
2025-06-11 15:23:50 +08:00
commit 37ecef40f7
79 changed files with 26981 additions and 0 deletions

122
config.py Normal file
View File

@@ -0,0 +1,122 @@
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as F
def pad_to_square(img):
    """Zero-pad an image so that its width and height become equal.

    The shorter side is extended to the length of the longer side while the
    original content stays centred. When the size difference is odd, the
    extra pixel is added on the right/bottom edge so the result is exactly
    square.

    Args:
        img: Image exposing ``size`` as ``(width, height)`` and accepted by
            ``F.pad`` (presumably a PIL image -- confirm against callers).

    Returns:
        The padded image as returned by ``F.pad``.
    """
    w, h = img.size
    max_wh = max(w, h)
    left = (max_wh - w) // 2
    top = (max_wh - h) // 2
    # Give the remainder to the far edge: using the floored half on both
    # sides leaves the output one pixel short when the difference is odd.
    right = max_wh - w - left
    bottom = max_wh - h - top
    padding = [left, top, right, bottom]  # (left, top, right, bottom)
    return F.pad(img, padding, fill=0, padding_mode='constant')
class Config:
    """Static settings for the network, data pipeline, training and testing.

    Every setting is a plain class attribute, evaluated once when the class
    body is executed.
    """

    # ------------------------------------------------------------------
    # Network settings
    # ------------------------------------------------------------------
    # Options noted by the author: resnet18, mobilevit_s, mobilenet_v2,
    # mobilenetv3_small, mobilenetv3_large, mobilenet_v1, PPLCNET_x1_0,
    # PPLCNET_x0_5, PPLCNET_x2_5, vit_base
    backbone = "resnet18"
    # Options noted by the author: cosface, arcface, softmax
    metric = "arcface"
    cbam = False
    # Author's note: gift -> 2, contrast -> 256
    embedding_size = 256
    drop_ratio = 0.5
    img_size = 224
    # Multi-GPU loading
    multiple_cards = True
    # Test with the model in half precision
    model_half = False
    # Test with the data in half precision
    data_half = True
    # Channel pruning ratio
    channel_ratio = 0.75
    # Test an int8-quantized model
    quantization_test = False

    # ------------------------------------------------------------------
    # Custom base_data settings
    # ------------------------------------------------------------------
    # Transfer learning: load every layer except the last one
    custom_backbone = False
    # Number of classes used for transfer learning
    custom_num_classes = 128

    # The author left a disabled variant that forced the CPU device when
    # quantization_test is set; currently CUDA is used whenever available.
    device = (
        torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    )

    # Options noted by the author: resnet18, mobilevit_s, mobilenet_v2,
    # mobilenetv3_small, mobilenetv3_large, mobilenet_v1, PPLCNET_x1_0,
    # PPLCNET_x0_5, PPLCNET_x2_5
    teacher = "vit"
    student = "resnet"

    # ------------------------------------------------------------------
    # Data preprocessing
    # ------------------------------------------------------------------
    # Augmentations the author listed for reference (not applied here):
    #   RandomCrop(size), RandomVerticalFlip(p=0.5), RandomHorizontalFlip(),
    #   RandomRotate(15, 0.3), RandomGaussianBlur()
    train_transform = T.Compose(
        [
            # Pad the borders via pad_to_square
            T.Lambda(pad_to_square),
            T.ToTensor(),
            T.Resize(size=(img_size, img_size), antialias=True),
            # Disabled by the author: T.RandomCrop(img_size * 4 // 5)
            T.RandomHorizontalFlip(p=0.5),
            T.RandomRotation(degrees=180),
            T.ColorJitter(brightness=0.5),
            T.ConvertImageDtype(dtype=torch.float32),
            T.Normalize(mean=[0.5], std=[0.5]),
        ]
    )
    test_transform = T.Compose(
        [
            # Border padding (T.Lambda(pad_to_square)) is disabled for testing
            T.ToTensor(),
            T.Resize(size=(img_size, img_size), antialias=True),
            T.ConvertImageDtype(dtype=torch.float32),
            # Disabled alternative: T.Normalize(mean=[0,0,0], std=[255,255,255])
            T.Normalize(mean=[0.5], std=[0.5]),
        ]
    )

    # ------------------------------------------------------------------
    # Dataset
    # ------------------------------------------------------------------
    # Other roots noted by the author: './data/2250_train/base_data',
    # './data/2000_train/base_data', './data/zhanting/base_data',
    # './data/base_train/one_stage/train'
    train_root = "../data_center/scatter/train"
    # Other roots noted by the author: "./data/2250_train/val",
    # "./data/2000_train/val/", './data/zhanting/val',
    # './data/base_train/one_stage/val'
    test_root = "../data_center/scatter/val"

    # ------------------------------------------------------------------
    # Training settings
    # ------------------------------------------------------------------
    # Author's note lists: resnet18, mobilevit_s, mobilenet_v2, mobilenetv3
    checkpoints = "checkpoints/resnet18_scatter_6.2/"
    restore = True
    # Earlier values noted by the author:
    #   "checkpoints/renet18_2250_0315/best_resnet18_2250_0315.pth"
    #   best_resnet18_1491_0306.pth
    restore_model = "checkpoints/resnet18_scatter_6.2/best.pth"

    # ------------------------------------------------------------------
    # Test settings
    # ------------------------------------------------------------------
    # Options noted by the author: resnet18, mobilevit_s, mobilenet_v2,
    # mobilenetv3_small, mobilenetv3_large, mobilenet_v1, PPLCNET_x1_0,
    # PPLCNET_x0_5
    testbackbone = "resnet18"
    # Disabled alternatives for the three paths below:
    #   test_val = "./data/2250_train"
    #   test_list = "./data/2250_train/val_pair.txt"
    #   test_group_json = "./data/2250_train/cross_same.json"
    #   test_group_json = "./data/test/inner_group_pairs.json"
    # Alternative: ../data_center/contrast_learning/model_test_data/val_2250
    test_val = "../data_center/scatter/"
    # Alternative: ./data/test/public_single_pairs.txt
    test_list = "../data_center/scatter/val_pair.txt"
    test_group_json = (
        "../data_center/contrast_learning/model_test_data/test/inner_group_pairs.json"
    )
    # Disabled alternatives for test_model:
    #   "checkpoints/resnet18_scatter_6.2/best.pth"
    #   "checkpoints/zhanting/inland/res_801.pth"
    #   "checkpoints/resnet18_20250504/best.pth"
    #   "checkpoints/resnet18_vit-base_20250430/best.pth"
    test_model = "checkpoints/resnet18_1009/best.pth"
    group_test = False

    # ------------------------------------------------------------------
    # Optimisation and runtime settings
    # ------------------------------------------------------------------
    train_batch_size = 128  # author's note: 256
    test_batch_size = 128  # author's note: 256
    epoch = 5  # author's note: 512
    # Options noted by the author: 'sgd', 'adam', 'adamw'
    optimizer = "sgd"
    lr = 5e-3  # author's note: 1e-2
    lr_step = 10
    lr_decay = 0.98
    weight_decay = 5e-4
    # Options noted by the author: 'focal_loss', 'cross_entropy'
    loss = "cross_entropy"
    log_path = "./log"
    # Minimum learning rate
    lr_min = 1e-6
    # If memory is large, set this to True to speed things up a bit
    pin_memory = False
    num_workers = 32  # author's note: 64
    # Compare the results of different models
    compare = False

    # ------------------------------------------------------------------
    # train_distill settings
    # ------------------------------------------------------------------
    # Number of warm-up epochs
    warmup_epochs = 3
    # Distributed training
    distributed = True
    teacher_path = "./checkpoints/resnet50_0519/best.pth"
    # Distillation weight
    distill_weight = 0.8
# Module-level Config instance, created when this module is imported.
config = Config()