Files
detecttracking/contrast/config.py
2024-09-02 18:39:12 +08:00

84 lines
3.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torchvision.transforms as T
class Config:
    """Static training / evaluation settings, exposed as class attributes.

    Groups the network choice, data transforms, dataset paths, optimisation
    hyper-parameters and runtime (device / DataLoader) options in one place.
    """

    # network settings
    backbone = 'vit'  # 'vit' or one of [resnet18, mobilevit_s, mobilenet_v2, mobilenetv3_small, mobilenetv3_large, mobilenet_v1, PPLCNET_x1_0, PPLCNET_x0_5, PPLCNET_x2_5]
    metric = 'softmax'  # [cosface, arcface, softmax]
    cbam = True
    embedding_size = 256  # 256
    drop_ratio = 0.5
    img_size = 224
    teacher = 'vit'  # same choices as `backbone`
    student = 'resnet'

    # data preprocess
    # input_shape = [1, 128, 128]
    # Augmentations tried previously, kept for reference only (not applied):
    #   transforms.RandomCrop(size)
    #   transforms.RandomVerticalFlip(p=0.5)
    #   transforms.RandomHorizontalFlip()
    #   RandomRotate(15, 0.3)
    #   RandomGaussianBlur()
    train_transform = T.Compose([
        T.ToTensor(),
        T.Resize((img_size, img_size)),
        # T.RandomCrop(img_size*4//5),
        # T.RandomHorizontalFlip(p=0.5),
        T.RandomRotation(180),
        T.ColorJitter(brightness=0.5),
        # ToTensor yields float32 for PIL/uint8 input; this cast also covers
        # other input dtypes (e.g. float64 arrays, which ToTensor keeps as-is).
        T.ConvertImageDtype(torch.float32),
        # A single mean/std value is broadcast over every channel.
        T.Normalize(mean=[0.5], std=[0.5]),
    ])
    test_transform = T.Compose([
        T.ToTensor(),
        T.Resize((img_size, img_size)),
        T.ConvertImageDtype(torch.float32),
        T.Normalize(mean=[0.5], std=[0.5]),
    ])

    # dataset
    train_root = './data/2250_train/train'  # dataset after one round of initial filtering
    # train_root = './data/0625_train/train'
    test_root = "./data/2250_train/val/"
    # test_root = "./data/0625_train/val"
    test_list = "./data/2250_train/val_pair.txt"
    test_group_json = "./data/2250_train/cross_same.json"
    # test_group_json = "./data/0625_train/cross_same.json"
    # test_list = "./data/test_data_100/val_pair.txt"

    # training settings
    checkpoints = "checkpoints/vit_b_16_0815/"  # checkpoint directory for the current run
    restore = True
    # restore_model = "checkpoints/renet18_2250_0315/best_resnet18_2250_0315.pth"  # best_resnet18_1491_0306.pth
    restore_model = "checkpoints/vit_b_16_0730/best.pth"  # presumably loaded when `restore` is True
    # test_model = "./checkpoints/renet18_1887_0311/best_resnet18_1887_0311.pth"
    testbackbone = 'resnet18'  # [resnet18, mobilevit_s, mobilenet_v2, mobilenetv3_small, mobilenetv3_large, mobilenet_v1, PPLCNET_x1_0, PPLCNET_x0_5]
    # test_val = "./data/2250_train"
    test_val = "./data/0625_train"
    test_model = "checkpoints/resnet18_0721/best.pth"
    train_batch_size = 128  # 256
    test_batch_size = 256  # 256
    epoch = 300
    optimizer = 'adamw'  # ['sgd', 'adam', 'adamw']
    lr = 1e-3  # 1e-2
    lr_step = 10  # 10
    lr_decay = 0.95  # 0.98
    weight_decay = 5e-4
    loss = 'focal_loss'  # ['focal_loss', 'cross_entropy']

    # The second GPU ('cuda:1') is preferred. On a single-GPU host 'cuda:1' is
    # an invalid device ordinal, so fall back to 'cuda:0'; use the CPU when
    # CUDA is unavailable.
    device = torch.device(
        'cuda:1' if torch.cuda.is_available() and torch.cuda.device_count() > 1
        else 'cuda:0' if torch.cuda.is_available()
        else 'cpu'
    )
    # device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # Pinned host memory only speeds up host-to-GPU copies, so enable it only
    # when a CUDA device is present (it is useless and warns on CPU-only hosts).
    pin_memory = torch.cuda.is_available()
    num_workers = 4  # dataloader
    group_test = True
    # group_test = False
# Shared module-level settings object built from the class defaults above.
config: Config = Config()