contrast performance evaluatation have done!

This commit is contained in:
王庆刚
2024-09-05 19:01:49 +08:00
parent f978d4174f
commit 7309dec166
85 changed files with 3941 additions and 248 deletions

View File

@ -4,15 +4,13 @@ import torchvision.transforms as T
class Config:
# network settings
backbone = 'vit' # [resnet18, mobilevit_s, mobilenet_v2, mobilenetv3_small, mobilenetv3_large, mobilenet_v1, PPLCNET_x1_0, PPLCNET_x0_5, PPLCNET_x2_5]
metric = 'softmax' # [cosface, arcface, softmax]
backbone = 'resnet18' # [resnet18, mobilevit_s, mobilenet_v2, mobilenetv3_small, mobilenetv3_large, mobilenet_v1, PPLCNET_x1_0, PPLCNET_x0_5, PPLCNET_x2_5]
metric = 'arcface' # [cosface, arcface]
cbam = True
embedding_size = 256 # 256
embedding_size = 256
drop_ratio = 0.5
img_size = 224
teacher = 'vit' # [resnet18, mobilevit_s, mobilenet_v2, mobilenetv3_small, mobilenetv3_large, mobilenet_v1, PPLCNET_x1_0, PPLCNET_x0_5, PPLCNET_x2_5]
student = 'resnet'
# data preprocess
# input_shape = [1, 128, 128]
"""transforms.RandomCrop(size),
@ -24,7 +22,7 @@ class Config:
train_transform = T.Compose([
T.ToTensor(),
T.Resize((img_size, img_size)),
# T.RandomCrop(img_size*4//5),
# T.RandomCrop(img_size),
# T.RandomHorizontalFlip(p=0.5),
T.RandomRotation(180),
T.ColorJitter(brightness=0.5),
@ -39,39 +37,39 @@ class Config:
])
# dataset
train_root = './data/2250_train/train' # 初始筛选过一次的数据集
# train_root = './data/0625_train/train'
train_root = './data/2250_train/train' # 初始筛选过一次的数据集
# train_root = './data/0612_train/train'
test_root = "./data/2250_train/val/"
# test_root = "./data/0625_train/val"
# test_root = "./data/0612_train/val"
test_list = "./data/2250_train/val_pair.txt"
test_group_json = "./data/2250_train/cross_same.json"
# test_group_json = "./data/0625_train/cross_same.json"
test_group_json = "./2250_train/cross_same_0508.json"
# test_list = "./data/test_data_100/val_pair.txt"
# training settings
checkpoints = "checkpoints/vit_b_16_0815/" # [resnet18, mobilevit_s, mobilenet_v2, mobilenetv3]
restore = True
checkpoints = "checkpoints/resnet18_0613/" # [resnet18, mobilevit_s, mobilenet_v2, mobilenetv3]
restore = False
# restore_model = "checkpoints/renet18_2250_0315/best_resnet18_2250_0315.pth" # best_resnet18_1491_0306.pth
restore_model = "checkpoints/vit_b_16_0730/best.pth" # best_resnet18_1491_0306.pth
restore_model = "checkpoints/resnet18_0515/best.pth" # best_resnet18_1491_0306.pth
# test_model = "./checkpoints/renet18_1887_0311/best_resnet18_1887_0311.pth"
# test_model = "checkpoints/renet18_2250_0314/best_resnet18_2250_0314.pth"
testbackbone = 'resnet18' # [resnet18, mobilevit_s, mobilenet_v2, mobilenetv3_small, mobilenetv3_large, mobilenet_v1, PPLCNET_x1_0, PPLCNET_x0_5]
# test_val = "./data/2250_train"
test_val = "./data/0625_train"
test_model = "checkpoints/resnet18_0721/best.pth"
test_val = "D:/比对/cl"
# test_val = "./data/test_data_100"
test_model = "checkpoints/resnet18_0515/best.pth"
train_batch_size = 128 # 256
train_batch_size = 512 # 256
test_batch_size = 256 # 256
epoch = 300
optimizer = 'adamw' # ['sgd', 'adam' 'adamw']
lr = 1e-3 # 1e-2
lr_step = 10 # 10
optimizer = 'sgd' # ['sgd', 'adam']
lr = 1.5e-2 # 1e-2
lr_step = 5 # 10
lr_decay = 0.95 # 0.98
weight_decay = 5e-4
loss = 'focal_loss' # ['focal_loss', 'cross_entropy']
loss = 'cross_entropy' # ['focal_loss', 'cross_entropy']
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
# device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
@ -79,6 +77,5 @@ class Config:
num_workers = 4 # dataloader
group_test = True
# group_test = False
config = Config()
config = Config()