28 lines
852 B
Python
28 lines
852 B
Python
import torch
|
|
import torchvision.transforms.functional as F
|
|
import torchvision.transforms as T
|
|
|
|
|
|
def pad_to_square(img):
    """Pad ``img`` with zeros on its right and bottom edges until it is square.

    ``img.size`` is unpacked as ``(width, height)``, so ``img`` is presumably
    a PIL image -- confirm against callers. The existing pixels are not
    shifted: no padding is added on the left or on the top.

    Returns whatever ``torchvision.transforms.functional.pad`` returns for
    the given input.
    """
    width, height = img.size
    side = max(width, height)

    # Only the shorter dimension receives padding; the longer one gets 0.
    right_pad = side - width
    bottom_pad = side - height

    # Padding order expected by F.pad: (left, top, right, bottom).
    return F.pad(
        img,
        [0, 0, right_pad, bottom_pad],
        fill=0,
        padding_mode='constant',
    )
|
|
|
|
class Config:
    """Static settings shared through class attributes.

    All values are evaluated once, when the class body executes (i.e. at
    import time of this module).
    """

    # network settings

    # Checkpoint paths, relative to the process working directory.
    # Presumably a ResNet feature-extractor checkpoint and a YOLO detector
    # checkpoint -- inferred from the names, confirm against the loaders.
    resnet_model = './detecttracking/contrast/feat_extract/checkpoints/resnet18_0515/v11.pth'
    yolo_model = './detecttracking/tracking/ckpts/best_cls10_0906.pt'

    # Prefer the second GPU, but only when it actually exists. Selecting
    # 'cuda:1' on a single-GPU host builds a device object that fails on
    # first use, so fall back to 'cuda:0' and finally to the CPU.
    device = torch.device(
        'cuda:1' if torch.cuda.device_count() > 1
        else 'cuda:0' if torch.cuda.is_available()
        else 'cpu'
    )

    embedding_size = 256
    batch_size = 8

    # Edge length (pixels) used by the Resize step below.
    img_size = 224

    # Preprocessing pipeline; the step order is significant and unchanged.
    test_transform = T.Compose([
        T.ToTensor(),
        T.Resize((img_size, img_size)),
        T.ConvertImageDtype(torch.float32),
        # A single-element mean/std is applied to every channel.
        T.Normalize(mean=[0.5], std=[0.5]),
    ])
|
|
|
|
# Module-level Config instance, created when this module is imported.
config = Config()
|