42 lines
1.3 KiB
Python
42 lines
1.3 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""
|
|
Created on Fri Jan 19 14:01:46 2024
|
|
|
|
@author: ym
|
|
"""
|
|
|
|
import torch
|
|
import os
|
|
# import torchvision.transforms as T
|
|
class Config:
    """Static settings for loading the backbone network and its checkpoint.

    Every setting is a class attribute, so values can be read from the class
    itself or from any instance of it.
    """

    # ---- network settings ----
    # Supported backbones (per the original note):
    # resnet18, mobilevit_s, mobilenet_v2, mobilenetv3
    backbone = 'resnet18'
    batch_size = 8
    embedding_size = 256  # presumably the output feature dimension — confirm against the model
    img_size = 224  # presumably the (square) input resolution in pixels — confirm

    # Directory containing this file. The checkpoint path is resolved relative
    # to it, so loading does not depend on the process's working directory.
    current_path = os.path.dirname(os.path.abspath(__file__))

    # Join the components separately instead of using a hard-coded
    # Windows-style string such as r"ckpts\resnet18_1220\best.pth". That form
    # only resolves on Windows; POSIX systems treat the backslashes as literal
    # filename characters. os.path.join inserts the platform's own separator.
    model_path = os.path.join(current_path, "ckpts", "resnet18_1220", "best.pth")
    # model_path = "./trackers/reid/ckpts/resnet18_1220/best.pth"

    # Use the GPU when CUDA is available, otherwise fall back to the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # =============================================================================
    # Training-time settings, currently disabled (kept for reference):
    # metric = 'arcface'  # [cosface, arcface]
    # drop_ratio = 0.5
    #
    # # training settings
    # checkpoints = "checkpoints/Mobilev3Large_1225"  # [resnet18, mobilevit_s, mobilenet_v2, mobilenetv3]
    # restore = False
    #
    # test_model = "./checkpoints/resnet18_1220/best.pth"
    #
    # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # pin_memory = True  # if memory is large, set it True to speed up a bit
    # num_workers = 4  # dataloader
    # =============================================================================
|
|
|
|
# Shared module-level instance. Config's settings are class attributes,
# so this is a convenience handle rather than separate state.
config = Config()