rebuild
This commit is contained in:
122
config.py
Normal file
122
config.py
Normal file
@ -0,0 +1,122 @@
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
import torchvision.transforms.functional as F
|
||||
|
||||
|
||||
def pad_to_square(img):
|
||||
w, h = img.size
|
||||
max_wh = max(w, h)
|
||||
padding = [(max_wh - w) // 2, (max_wh - h) // 2, (max_wh - w) // 2, (max_wh - h) // 2] # (left, top, right, bottom)
|
||||
return F.pad(img, padding, fill=0, padding_mode='constant')
|
||||
|
||||
|
||||
class Config:
|
||||
# network settings
|
||||
backbone = 'resnet18' # [resnet18, mobilevit_s, mobilenet_v2, mobilenetv3_small, mobilenetv3_large,
|
||||
# mobilenet_v1, PPLCNET_x1_0, PPLCNET_x0_5, PPLCNET_x2_5, vit_base]
|
||||
metric = 'arcface' # [cosface, arcface, softmax]
|
||||
cbam = False
|
||||
embedding_size = 256 # 256 # gift:2 contrast:256
|
||||
drop_ratio = 0.5
|
||||
img_size = 224
|
||||
multiple_cards = True # 多卡加载
|
||||
model_half = False # 模型半精度测试
|
||||
data_half = True # 数据半精度测试
|
||||
channel_ratio = 0.75 # 通道剪枝比例
|
||||
quantization_test = False # int8量化模型测试
|
||||
|
||||
# custom base_data settings
|
||||
custom_backbone = False # 迁移学习载入除最后一层的所有层
|
||||
custom_num_classes = 128 # 迁移学习的类别数量
|
||||
|
||||
# if quantization_test:
|
||||
# device = torch.device('cpu')
|
||||
# else:
|
||||
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
teacher = 'vit' # [resnet18, mobilevit_s, mobilenet_v2, mobilenetv3_small, mobilenetv3_large, mobilenet_v1,
|
||||
# PPLCNET_x1_0, PPLCNET_x0_5, PPLCNET_x2_5]
|
||||
|
||||
student = 'resnet'
|
||||
# data preprocess
|
||||
"""transforms.RandomCrop(size),
|
||||
transforms.RandomVerticalFlip(p=0.5),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
RandomRotate(15, 0.3),
|
||||
# RandomGaussianBlur()"""
|
||||
train_transform = T.Compose([
|
||||
T.Lambda(pad_to_square), # 补边
|
||||
T.ToTensor(),
|
||||
T.Resize((img_size, img_size), antialias=True),
|
||||
# T.RandomCrop(img_size * 4 // 5),
|
||||
T.RandomHorizontalFlip(p=0.5),
|
||||
T.RandomRotation(180),
|
||||
T.ColorJitter(brightness=0.5),
|
||||
T.ConvertImageDtype(torch.float32),
|
||||
T.Normalize(mean=[0.5], std=[0.5]),
|
||||
])
|
||||
test_transform = T.Compose([
|
||||
# T.Lambda(pad_to_square), # 补边
|
||||
T.ToTensor(),
|
||||
T.Resize((img_size, img_size), antialias=True),
|
||||
T.ConvertImageDtype(torch.float32),
|
||||
# T.Normalize(mean=[0,0,0], std=[255,255,255]),
|
||||
T.Normalize(mean=[0.5], std=[0.5]),
|
||||
])
|
||||
|
||||
# dataset
|
||||
train_root = '../data_center/scatter/train' # ['./data/2250_train/base_data', # './data/2000_train/base_data', './data/zhanting/base_data', './data/base_train/one_stage/train']
|
||||
test_root = '../data_center/scatter/val' # ["./data/2250_train/val", "./data/2000_train/val/", './data/zhanting/val', './data/base_train/one_stage/val']
|
||||
|
||||
# training settings
|
||||
checkpoints = "checkpoints/resnet18_scatter_6.2/" # [resnet18, mobilevit_s, mobilenet_v2, mobilenetv3]
|
||||
restore = True
|
||||
# restore_model = "checkpoints/renet18_2250_0315/best_resnet18_2250_0315.pth" # best_resnet18_1491_0306.pth
|
||||
restore_model = "checkpoints/resnet18_scatter_6.2/best.pth" # best_resnet18_1491_0306.pth
|
||||
|
||||
# test settings
|
||||
testbackbone = 'resnet18' # [resnet18, mobilevit_s, mobilenet_v2, mobilenetv3_small, mobilenetv3_large, mobilenet_v1, PPLCNET_x1_0, PPLCNET_x0_5]
|
||||
|
||||
# test_val = "./data/2250_train"
|
||||
# test_list = "./data/2250_train/val_pair.txt"
|
||||
# test_group_json = "./data/2250_train/cross_same.json"
|
||||
|
||||
test_val = "../data_center/scatter/" # [../data_center/contrast_learning/model_test_data/val_2250]
|
||||
test_list = "../data_center/scatter/val_pair.txt" # [./data/test/public_single_pairs.txt]
|
||||
test_group_json = "../data_center/contrast_learning/model_test_data/test/inner_group_pairs.json" # [./data/2250_train/cross_same.json]
|
||||
# test_group_json = "./data/test/inner_group_pairs.json"
|
||||
|
||||
# test_model = "checkpoints/resnet18_scatter_6.2/best.pth"
|
||||
test_model = "checkpoints/resnet18_1009/best.pth"
|
||||
# test_model = "checkpoints/zhanting/inland/res_801.pth"
|
||||
# test_model = "checkpoints/resnet18_20250504/best.pth"
|
||||
# test_model = "checkpoints/resnet18_vit-base_20250430/best.pth"
|
||||
group_test = False
|
||||
# group_test = False
|
||||
|
||||
train_batch_size = 128 # 256
|
||||
test_batch_size = 128 # 256
|
||||
|
||||
epoch = 5 # 512
|
||||
optimizer = 'sgd' # ['sgd', 'adam', 'adamw']
|
||||
lr = 5e-3 # 1e-2
|
||||
lr_step = 10 # 10
|
||||
lr_decay = 0.98 # 0.98
|
||||
weight_decay = 5e-4
|
||||
loss = 'cross_entropy' # ['focal_loss', 'cross_entropy']
|
||||
log_path = './log'
|
||||
lr_min = 1e-6 # min lr
|
||||
|
||||
pin_memory = False # if memory is large, set it True to speed up a bit
|
||||
num_workers = 32 # 64
|
||||
compare = False # compare the result of different models
|
||||
|
||||
'''
|
||||
train_distill settings
|
||||
'''
|
||||
warmup_epochs = 3 # warmup_epoch
|
||||
distributed = True # distributed training
|
||||
teacher_path = "./checkpoints/resnet50_0519/best.pth"
|
||||
distill_weight = 0.8 # 蒸馏权重
|
||||
|
||||
config = Config()
|
Reference in New Issue
Block a user