Files
detecttracking/tracking/trackers/reid/reid_interface.py
2024-06-03 15:25:39 +08:00

145 lines
4.5 KiB
Python

# -*- coding: utf-8 -*-
"""
Created on Thu Jan 18 17:21:01 2024
@author: ym
"""
import numpy as np
import torch
import cv2
import torch.nn as nn
import torchvision.transforms as T
from .model import mobilevit_s, resnet18, resnet34, resnet50, mobilenet_v2, MobileNetV3_Small
from .config import config as conf
class ReIDInterface:
    """Extract appearance (ReID) embeddings for image patches.

    The backbone, device, batch size, embedding size, input size and weight
    file are all read from the module-level ``conf`` object.
    """

    def __init__(self, config):
        """Build the backbone, load its weights and set up preprocessing.

        Args:
            config: Kept for interface compatibility. It is not used; every
                setting is taken from the module-level ``conf``.
        """
        self.device = conf.device

        # Any backbone name not listed here falls back to mobilenet_v2.
        backbones = {
            'resnet18': resnet18,
            'resnet34': resnet34,
            'resnet50': resnet50,
            'mobilevit_s': mobilevit_s,
            'mobilenetv3': MobileNetV3_Small,
        }
        build_backbone = backbones.get(conf.backbone, mobilenet_v2)
        model = build_backbone().to(self.device)

        self.batch_size = conf.batch_size
        self.embedding_size = conf.embedding_size
        self.img_size = conf.img_size
        self.model_path = conf.model_path

        # The original pipeline took PIL input; ToTensor is applied first so
        # the remaining transforms operate on tensors.
        self.transform = T.Compose([
            T.ToTensor(),
            T.Resize((self.img_size, self.img_size)),
            T.ConvertImageDtype(torch.float32),
            T.Normalize(mean=[0.5], std=[0.5]),
        ])

        # The weights are loaded into the bare model. The previous code built
        # an nn.DataParallel wrapper and discarded it on the next line, so the
        # wrapper is simply no longer created.
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        self.model = model
        self.model.load_state_dict(
            torch.load(self.model_path, map_location=self.device)
        )
        self.model.eval()

    def _input_dtype(self):
        """Return the floating dtype of the model weights (float32 if none)."""
        for param in self.model.parameters():
            if param.is_floating_point():
                return param.dtype
        return torch.float32

    def _prepare_patch(self, img, dtype):
        """Run the preprocessing transform and move the result to the device."""
        return self.transform(img).to(device=self.device, dtype=dtype)

    def _extract_features(self, patches):
        """Run the model over ``patches`` in batches and stack the outputs.

        Args:
            patches: List of preprocessed tensors, one per patch.

        Returns:
            A float64 array with one row per patch and ``embedding_size``
            columns; shape ``(0, embedding_size)`` when ``patches`` is empty.
        """
        # The leading empty float64 array keeps the result dtype and the
        # empty-input shape identical to the previous implementation.
        outputs = [np.zeros((0, self.embedding_size))]
        # Inference only: skip building an autograd graph.
        with torch.no_grad():
            for start in range(0, len(patches), self.batch_size):
                batch = torch.stack(
                    patches[start:start + self.batch_size], dim=0
                )
                pred = self.model(batch)
                # Replace infinities so they do not reach downstream code.
                pred[torch.isinf(pred)] = 1.0
                outputs.append(pred.detach().cpu().numpy())
        return np.vstack(outputs)

    def inference(self, images, detections):
        """Compute embeddings for a frame plus detections, or for patches.

        Args:
            images: Either a single ``np.ndarray`` frame, in which case the
                call is forwarded to :meth:`inference_image`, or an iterable
                of already-cropped patches (each must provide ``.copy()`` and
                be accepted by the preprocessing transform).
            detections: Detection array; only used when ``images`` is a
                single frame.

        Returns:
            ``np.ndarray`` with one embedding row per patch.
        """
        if isinstance(images, np.ndarray):
            return self.inference_image(images, detections)

        # Inputs are cast to the dtype of the model weights. Previously they
        # were forced to half precision on any non-CPU device although the
        # model itself was never converted.
        dtype = self._input_dtype()
        patches = [self._prepare_patch(img.copy(), dtype) for img in images]
        return self._extract_features(patches)

    def inference_image(self, image, detections):
        """Crop every detection from ``image`` and compute its embedding.

        Args:
            image: Frame as an ``H x W x C`` array. The channel axis is
                reversed before preprocessing because the model expects RGB
                inputs (so the frame is presumably BGR -- confirm with the
                caller).
            detections: 2-D array whose first four columns are the box
                corners ``x1, y1, x2, y2`` in pixels.

        Returns:
            ``np.ndarray`` with one embedding row per detection.
        """
        height, width, _ = np.shape(image)
        dtype = self._input_dtype()

        patches = []
        for d in range(np.size(detections, 0)):
            x1, y1, x2, y2 = detections[d, :4].astype(np.int_)
            # Clamp the box to the frame.
            # NOTE(review): the upper bounds are W - 1 / H - 1 while the slice
            # end is exclusive, so the last row/column is never included --
            # confirm this is intended.
            x1 = max(0, x1)
            y1 = max(0, y1)
            x2 = min(width - 1, x2)
            y2 = min(height - 1, y2)

            crop = image[y1:y2, x1:x2, :]
            crop = crop[:, :, ::-1].copy()  # the model expects RGB inputs
            patches.append(self._prepare_patch(crop, dtype))

        return self._extract_features(patches)