更新 detacttracking
This commit is contained in:
87
detecttracking/contrast/feat_extract/config.py
Normal file
87
detecttracking/contrast/feat_extract/config.py
Normal file
@ -0,0 +1,87 @@
|
||||
# import torch
|
||||
# import torchvision.transforms as T
|
||||
#
|
||||
#
|
||||
# class Config:
|
||||
# # network settings
|
||||
# backbone = 'resnet18' # [resnet18, mobilevit_s, mobilenet_v2, mobilenetv3_small, mobilenetv3_large, mobilenet_v1, PPLCNET_x1_0, PPLCNET_x0_5, PPLCNET_x2_5]
|
||||
# metric = 'arcface' # [cosface, arcface]
|
||||
# cbam = True
|
||||
# embedding_size = 256
|
||||
# drop_ratio = 0.5
|
||||
# img_size = 224
|
||||
#
|
||||
# batch_size = 8
|
||||
#
|
||||
# # data preprocess
|
||||
# # input_shape = [1, 128, 128]
|
||||
# """transforms.RandomCrop(size),
|
||||
# transforms.RandomVerticalFlip(p=0.5),
|
||||
# transforms.RandomHorizontalFlip(),
|
||||
# RandomRotate(15, 0.3),
|
||||
# # RandomGaussianBlur()"""
|
||||
#
|
||||
# train_transform = T.Compose([
|
||||
# T.ToTensor(),
|
||||
# T.Resize((img_size, img_size)),
|
||||
# # T.RandomCrop(img_size),
|
||||
# # T.RandomHorizontalFlip(p=0.5),
|
||||
# T.RandomRotation(180),
|
||||
# T.ColorJitter(brightness=0.5),
|
||||
# T.ConvertImageDtype(torch.float32),
|
||||
# T.Normalize(mean=[0.5], std=[0.5]),
|
||||
# ])
|
||||
# test_transform = T.Compose([
|
||||
# T.ToTensor(),
|
||||
# T.Resize((img_size, img_size)),
|
||||
# T.ConvertImageDtype(torch.float32),
|
||||
# T.Normalize(mean=[0.5], std=[0.5]),
|
||||
# ])
|
||||
#
|
||||
# # dataset
|
||||
# train_root = './data/2250_train/train' # 初始筛选过一次的数据集
|
||||
# # train_root = './data/0612_train/train'
|
||||
# test_root = "./data/2250_train/val/"
|
||||
# # test_root = "./data/0612_train/val"
|
||||
# test_list = "./data/2250_train/val_pair.txt"
|
||||
#
|
||||
# test_group_json = "./2250_train/cross_same_0508.json"
|
||||
#
|
||||
#
|
||||
# # test_list = "./data/test_data_100/val_pair.txt"
|
||||
#
|
||||
# # training settings
|
||||
# checkpoints = "checkpoints/resnet18_0613/" # [resnet18, mobilevit_s, mobilenet_v2, mobilenetv3]
|
||||
# restore = False
|
||||
# # restore_model = "checkpoints/renet18_2250_0315/best_resnet18_2250_0315.pth" # best_resnet18_1491_0306.pth
|
||||
# restore_model = "checkpoints/resnet18_0515/best.pth" # best_resnet18_1491_0306.pth
|
||||
#
|
||||
# # test_model = "checkpoints/renet18_2250_0314/best_resnet18_2250_0314.pth"
|
||||
# testbackbone = 'resnet18' # [resnet18, mobilevit_s, mobilenet_v2, mobilenetv3_small, mobilenetv3_large, mobilenet_v1, PPLCNET_x1_0, PPLCNET_x0_5]
|
||||
# test_val = "D:/比对/cl"
|
||||
# # test_val = "./data/test_data_100"
|
||||
#
|
||||
# # test_model = "checkpoints/zhanting_res_801.pth"
|
||||
# test_model = "checkpoints/resnet18_0515/v11.pth"
|
||||
#
|
||||
#
|
||||
#
|
||||
# train_batch_size = 512 # 256
|
||||
# test_batch_size = 256 # 256
|
||||
#
|
||||
# epoch = 300
|
||||
# optimizer = 'sgd' # ['sgd', 'adam']
|
||||
# lr = 1.5e-2 # 1e-2
|
||||
# lr_step = 5 # 10
|
||||
# lr_decay = 0.95 # 0.98
|
||||
# weight_decay = 5e-4
|
||||
# loss = 'cross_entropy' # ['focal_loss', 'cross_entropy']
|
||||
# # device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
|
||||
# device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
||||
#
|
||||
# pin_memory = True # if memory is large, set it True to speed up a bit
|
||||
# num_workers = 4 # dataloader
|
||||
#
|
||||
# group_test = True
|
||||
#
|
||||
# config = Config()
|
Reference in New Issue
Block a user