# Training / evaluation configuration (network, data, and optimisation settings).
import torch
import torchvision.transforms as T


class Config:
    """Central container for network, data, and training settings.

    Every setting is a plain class attribute, so values can be read either
    from the class itself or from the module-level ``config`` instance
    created below. The transforms are built once, when the class body is
    executed at import time.
    """

    # ------------------------------------------------------------------
    # network settings
    # ------------------------------------------------------------------
    # Backbone names used in earlier experiments: resnet18, mobilevit_s,
    # mobilenet_v2, mobilenetv3_small, mobilenetv3_large, mobilenet_v1,
    # PPLCNET_x1_0, PPLCNET_x0_5, PPLCNET_x2_5.
    # NOTE(review): 'vit' is not in that list -- confirm the model factory
    # accepts it.
    backbone = 'vit'
    metric = 'softmax'  # one of: cosface, arcface, softmax
    cbam = True
    embedding_size = 256
    drop_ratio = 0.5
    img_size = 224  # images are resized to (img_size, img_size) below

    # Teacher / student model names (presumably for distillation -- confirm
    # against the training script).
    # NOTE(review): 'resnet' does not match the 'resnet18' spelling used
    # elsewhere in this file; verify against the model factory.
    teacher = 'vit'
    student = 'resnet'

    # ------------------------------------------------------------------
    # data preprocess
    # ------------------------------------------------------------------
    # input_shape = [1, 128, 128]
    # Augmentations tried previously (kept for reference):
    #   transforms.RandomCrop(size),
    #   transforms.RandomVerticalFlip(p=0.5),
    #   transforms.RandomHorizontalFlip(),
    #   RandomRotate(15, 0.3),
    #   # RandomGaussianBlur()

    train_transform = T.Compose([
        T.ToTensor(),
        T.Resize((img_size, img_size)),
        # T.RandomCrop(img_size*4//5),
        # T.RandomHorizontalFlip(p=0.5),
        T.RandomRotation(180),  # random angle drawn from [-180, 180] degrees
        T.ColorJitter(brightness=0.5),
        T.ConvertImageDtype(torch.float32),
        # A single mean/std value is given; it broadcasts over all channels.
        T.Normalize(mean=[0.5], std=[0.5]),
    ])
    # Same pipeline as training, minus the random augmentations.
    test_transform = T.Compose([
        T.ToTensor(),
        T.Resize((img_size, img_size)),
        T.ConvertImageDtype(torch.float32),
        T.Normalize(mean=[0.5], std=[0.5]),
    ])

    # ------------------------------------------------------------------
    # dataset
    # ------------------------------------------------------------------
    train_root = './data/2250_train/train'  # dataset after one initial round of filtering
    # train_root = './data/0625_train/train'
    test_root = "./data/2250_train/val/"
    # test_root = "./data/0625_train/val"

    test_list = "./data/2250_train/val_pair.txt"
    test_group_json = "./data/2250_train/cross_same.json"
    # test_group_json = "./data/0625_train/cross_same.json"
    # test_list = "./data/test_data_100/val_pair.txt"

    # ------------------------------------------------------------------
    # training settings
    # ------------------------------------------------------------------
    checkpoints = "checkpoints/vit_b_16_0815/"  # checkpoint directory for this run
    restore = True
    # restore_model = "checkpoints/renet18_2250_0315/best_resnet18_2250_0315.pth"
    restore_model = "checkpoints/vit_b_16_0730/best.pth"

    # Settings for standalone testing.
    # test_model = "./checkpoints/renet18_1887_0311/best_resnet18_1887_0311.pth"
    # Backbone names used so far: resnet18, mobilevit_s, mobilenet_v2,
    # mobilenetv3_small, mobilenetv3_large, mobilenet_v1, PPLCNET_x1_0,
    # PPLCNET_x0_5.
    testbackbone = 'resnet18'
    # test_val = "./data/2250_train"
    # NOTE(review): test_val points at 0625_train while test_root/test_list
    # point at 2250_train -- confirm this is intentional.
    test_val = "./data/0625_train"
    test_model = "checkpoints/resnet18_0721/best.pth"

    train_batch_size = 128  # previously 256
    test_batch_size = 256

    epoch = 300
    optimizer = 'adamw'  # one of: sgd, adam, adamw
    lr = 1e-3  # previously 1e-2
    lr_step = 10
    lr_decay = 0.95  # previously 0.98
    weight_decay = 5e-4
    loss = 'focal_loss'  # one of: focal_loss, cross_entropy

    # Prefer the second GPU, as before, but fall back to the first GPU when
    # only one is visible: 'cuda:1' is an invalid ordinal on such hosts.
    # Falls back to the CPU when CUDA is unavailable.
    device = torch.device(
        ('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')
        if torch.cuda.is_available()
        else 'cpu'
    )
    # device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    pin_memory = True  # if memory is large, set it True to speed up a bit
    num_workers = 4  # dataloader worker count

    group_test = True
    # group_test = False


# Shared instance imported by the rest of the project.
config = Config()