Files
ieemoo-ai-contrast/config.py
2025-06-11 15:23:50 +08:00

122 lines
4.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torchvision.transforms as T
import torchvision.transforms.functional as F
def pad_to_square(img):
    """Pad ``img`` with zero-valued borders so its width equals its height.

    The shorter dimension is extended to match the longer one, keeping the
    original content centred. When the size difference is odd, the extra
    pixel goes to the right/bottom edge so the result is always exactly
    square (splitting with ``// 2`` on both sides would leave it one pixel
    short).

    Args:
        img: Image whose ``size`` attribute unpacks to ``(width, height)``.
            Presumably a PIL image, since this runs before ``T.ToTensor()``
            in the training pipeline -- verify against the dataset loader.

    Returns:
        The result of ``F.pad`` with constant fill value 0.
    """
    w, h = img.size
    side = max(w, h)
    extra_w = side - w
    extra_h = side - h
    left = extra_w // 2
    top = extra_h // 2
    # The remainder (not a second ``// 2``) goes to the opposite edge so that
    # left + right == extra_w and top + bottom == extra_h even for odd gaps.
    padding = [left, top, extra_w - left, extra_h - top]  # (left, top, right, bottom)
    return F.pad(img, padding, fill=0, padding_mode='constant')
class Config:
    """Settings for training, testing and distillation, as class attributes.

    Everything here is evaluated once, when the class body runs -- including
    the CUDA availability check and the construction of both transform
    pipelines.
    """

    # ------------------------------------------------------------------
    # Network settings
    # ------------------------------------------------------------------
    # Backbone options: resnet18, mobilevit_s, mobilenet_v2,
    # mobilenetv3_small, mobilenetv3_large, mobilenet_v1, PPLCNET_x1_0,
    # PPLCNET_x0_5, PPLCNET_x2_5, vit_base
    backbone = "resnet18"
    metric = "arcface"  # options: cosface, arcface, softmax
    cbam = False
    embedding_size = 256  # gift: 2, contrast: 256
    drop_ratio = 0.5
    img_size = 224
    multiple_cards = True  # load on multiple GPUs
    model_half = False  # half-precision model test
    data_half = True  # half-precision data test
    channel_ratio = 0.75  # channel pruning ratio
    quantization_test = False  # int8 quantized model test

    # ------------------------------------------------------------------
    # Custom base_data settings
    # ------------------------------------------------------------------
    custom_backbone = False  # transfer learning: load every layer except the last
    custom_num_classes = 128  # number of classes for transfer learning

    # A disabled earlier variant forced the CPU whenever quantization_test
    # was set; the active rule only checks CUDA availability.
    device = (
        torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    )

    # Teacher options: resnet18, mobilevit_s, mobilenet_v2, mobilenetv3_small,
    # mobilenetv3_large, mobilenet_v1, PPLCNET_x1_0, PPLCNET_x0_5, PPLCNET_x2_5
    teacher = "vit"
    student = "resnet"

    # ------------------------------------------------------------------
    # Data preprocessing
    # ------------------------------------------------------------------
    # Augmentations kept as a note only (not applied):
    #   transforms.RandomCrop(size),
    #   transforms.RandomVerticalFlip(p=0.5),
    #   transforms.RandomHorizontalFlip(),
    #   RandomRotate(15, 0.3),
    #   RandomGaussianBlur()
    train_transform = T.Compose(
        [
            T.Lambda(pad_to_square),  # pad the borders
            T.ToTensor(),
            T.Resize(size=(img_size, img_size), antialias=True),
            # disabled: T.RandomCrop(img_size * 4 // 5),
            T.RandomHorizontalFlip(p=0.5),
            T.RandomRotation(degrees=180),
            T.ColorJitter(brightness=0.5),
            T.ConvertImageDtype(torch.float32),
            T.Normalize(mean=[0.5], std=[0.5]),
        ]
    )
    # NOTE(review): border padding is active for training but disabled here,
    # so test images are resized without padding first -- confirm intended.
    test_transform = T.Compose(
        [
            # disabled: T.Lambda(pad_to_square),  # pad the borders
            T.ToTensor(),
            T.Resize(size=(img_size, img_size), antialias=True),
            T.ConvertImageDtype(torch.float32),
            # disabled: T.Normalize(mean=[0,0,0], std=[255,255,255]),
            T.Normalize(mean=[0.5], std=[0.5]),
        ]
    )

    # ------------------------------------------------------------------
    # Dataset
    # ------------------------------------------------------------------
    # Other train roots noted: './data/2250_train/base_data',
    # './data/2000_train/base_data', './data/zhanting/base_data',
    # './data/base_train/one_stage/train'
    train_root = "../data_center/scatter/train"
    # Other test roots noted: "./data/2250_train/val",
    # "./data/2000_train/val/", './data/zhanting/val',
    # './data/base_train/one_stage/val'
    test_root = "../data_center/scatter/val"

    # ------------------------------------------------------------------
    # Training settings
    # ------------------------------------------------------------------
    # Checkpoint folder; names noted: resnet18, mobilevit_s, mobilenet_v2,
    # mobilenetv3
    checkpoints = "checkpoints/resnet18_scatter_6.2/"
    restore = True
    # Other restore targets noted:
    #   "checkpoints/renet18_2250_0315/best_resnet18_2250_0315.pth"
    #   best_resnet18_1491_0306.pth
    restore_model = "checkpoints/resnet18_scatter_6.2/best.pth"

    # ------------------------------------------------------------------
    # Test settings
    # ------------------------------------------------------------------
    # Options: resnet18, mobilevit_s, mobilenet_v2, mobilenetv3_small,
    # mobilenetv3_large, mobilenet_v1, PPLCNET_x1_0, PPLCNET_x0_5
    testbackbone = "resnet18"
    # Disabled earlier values:
    #   test_val = "./data/2250_train"
    #   test_list = "./data/2250_train/val_pair.txt"
    #   test_group_json = "./data/2250_train/cross_same.json"
    # Other test_val noted: ../data_center/contrast_learning/model_test_data/val_2250
    test_val = "../data_center/scatter/"
    # Other test_list noted: ./data/test/public_single_pairs.txt
    test_list = "../data_center/scatter/val_pair.txt"
    # Other test_group_json noted: ./data/2250_train/cross_same.json,
    # "./data/test/inner_group_pairs.json"
    test_group_json = (
        "../data_center/contrast_learning/model_test_data/test/inner_group_pairs.json"
    )
    # Other test_model values noted:
    #   "checkpoints/resnet18_scatter_6.2/best.pth"
    #   "checkpoints/zhanting/inland/res_801.pth"
    #   "checkpoints/resnet18_20250504/best.pth"
    #   "checkpoints/resnet18_vit-base_20250430/best.pth"
    test_model = "checkpoints/resnet18_1009/best.pth"
    group_test = False

    # ------------------------------------------------------------------
    # Optimisation / runtime
    # ------------------------------------------------------------------
    train_batch_size = 128  # other value noted: 256
    test_batch_size = 128  # other value noted: 256
    epoch = 5  # other value noted: 512
    optimizer = "sgd"  # options: 'sgd', 'adam', 'adamw'
    lr = 5e-3  # other value noted: 1e-2
    lr_step = 10
    lr_decay = 0.98
    weight_decay = 5e-4
    loss = "cross_entropy"  # options: 'focal_loss', 'cross_entropy'
    log_path = "./log"
    lr_min = 1e-6  # minimum learning rate
    pin_memory = False  # if memory is large, set it True to speed up a bit
    num_workers = 32  # other value noted: 64
    compare = False  # compare the result of different models

    # ------------------------------------------------------------------
    # train_distill settings
    # ------------------------------------------------------------------
    warmup_epochs = 3  # number of warmup epochs
    distributed = True  # distributed training
    teacher_path = "./checkpoints/resnet50_0519/best.pth"
    distill_weight = 0.8  # distillation weight
# Module-level Config instance, created once when this module is imported.
config = Config()