initial project version!

tracking/trackers/reid/reid_interface.py (new file, +143 lines)

@@ -0,0 +1,143 @@
# -*- coding: utf-8 -*-
"""
Created on Thu Jan 18 17:21:01 2024

@author: ym
"""
import numpy as np
import torch
import cv2
import torch.nn as nn
import torchvision.transforms as T

from .model import mobilevit_s, resnet18, resnet34, resnet50, mobilenet_v2, MobileNetV3_Small
from .config import config as conf


class ReIDInterface:
    def __init__(self, config):
        # NOTE: the `config` argument is currently unused; all settings are
        # read from the module-level `conf` imported above.
        self.device = conf.device
        if conf.backbone == 'resnet18':
            # model = ResIRSE(img_size, embedding_size, conf.drop_ratio).to(device)
            model = resnet18().to(self.device)
        elif conf.backbone == 'resnet34':
            model = resnet34().to(self.device)
        elif conf.backbone == 'resnet50':
            model = resnet50().to(self.device)
        elif conf.backbone == 'mobilevit_s':
            model = mobilevit_s().to(self.device)
        elif conf.backbone == 'mobilenetv3':
            model = MobileNetV3_Small().to(self.device)
        else:
            model = mobilenet_v2().to(self.device)

        self.batch_size = conf.batch_size
        self.embedding_size = conf.embedding_size
        self.img_size = conf.img_size

        self.model_path = conf.model_path

        # The transform was originally written for PIL inputs; T.ToTensor also
        # accepts HxWxC uint8 ndarrays, which is what the inference paths pass.
        self.transform = T.Compose([
            T.ToTensor(),
            T.Resize((self.img_size, self.img_size)),
            T.ConvertImageDtype(torch.float32),
            T.Normalize(mean=[0.5], std=[0.5]),
        ])

        self.model = nn.DataParallel(model).to(self.device)
        self.model.load_state_dict(torch.load(self.model_path, map_location=self.device))
        if str(self.device) != "cpu":
            # The inference paths cast patches to half precision on GPU, so the
            # weights must be converted as well or the dtypes will not match.
            self.model = self.model.half()
        self.model.eval()

    def inference(self, images, detections):
        # A full frame (ndarray) plus detections goes through the cropping
        # path; otherwise `images` is treated as an iterable of RGB crops.
        if isinstance(images, np.ndarray):
            features = self.inference_image(images, detections)
            return features

        batch_patches = []
        patches = []
        for i, img in enumerate(images):
            img = img.copy()
            patch = self.transform(img)
            if str(self.device) != "cpu":
                patch = patch.to(device=self.device).half()
            else:
                patch = patch.to(device=self.device)

            patches.append(patch)
            if (i + 1) % self.batch_size == 0:
                patches = torch.stack(patches, dim=0)
                batch_patches.append(patches)
                patches = []

        if len(patches):
            patches = torch.stack(patches, dim=0)
            batch_patches.append(patches)

        features = np.zeros((0, self.embedding_size))
        for patches in batch_patches:
            with torch.no_grad():
                pred = self.model(patches)
            pred[torch.isinf(pred)] = 1.0
            feat = pred.cpu().numpy()
            features = np.vstack((features, feat))
        return features

    def inference_image(self, image, detections):
        H, W, _ = np.shape(image)

        batch_patches = []
        patches = []
        for d in range(np.size(detections, 0)):
            tlbr = detections[d, :4].astype(np.int_)
            tlbr[0] = max(0, tlbr[0])
            tlbr[1] = max(0, tlbr[1])
            tlbr[2] = min(W - 1, tlbr[2])
            tlbr[3] = min(H - 1, tlbr[3])
            img = image[tlbr[1]:tlbr[3], tlbr[0]:tlbr[2], :]

            img = img[:, :, ::-1].copy()  # BGR -> RGB; the model expects RGB inputs
            patch = self.transform(img)

            if str(self.device) != "cpu":
                patch = patch.to(device=self.device).half()
            else:
                patch = patch.to(device=self.device)

            patches.append(patch)
            if (d + 1) % self.batch_size == 0:
                patches = torch.stack(patches, dim=0)
                batch_patches.append(patches)
                patches = []

        if len(patches):
            patches = torch.stack(patches, dim=0)
            batch_patches.append(patches)

        features = np.zeros((0, self.embedding_size))
        for patches in batch_patches:
            with torch.no_grad():
                pred = self.model(patches)
            pred[torch.isinf(pred)] = 1.0
            feat = pred.cpu().numpy()
            features = np.vstack((features, feat))

        return features
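
A minimal usage sketch, not part of the commit: it assumes the package layout
shown above, a valid checkpoint at conf.model_path, a hypothetical test image
frame.jpg, and detections given as an (N, >=4) array of [x1, y1, x2, y2, ...]
boxes in pixel coordinates, as implied by inference_image().

    import cv2
    import numpy as np

    from tracking.trackers.reid.reid_interface import ReIDInterface
    from tracking.trackers.reid.config import config as conf

    reid = ReIDInterface(conf)

    frame = cv2.imread("frame.jpg")  # hypothetical test image; BGR, as OpenCV loads it
    detections = np.array([[ 50,  60, 180, 320, 0.9],   # x1, y1, x2, y2, score
                           [200, 100, 310, 400, 0.8]], dtype=np.float32)

    feats = reid.inference(frame, detections)  # shape: (2, conf.embedding_size)

    # Cosine similarity between the two embeddings, e.g. for track association
    a, b = feats[0], feats[1]
    sim = float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-12))
    print(feats.shape, sim)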