# -*- coding: utf-8 -*-
"""
Created on Thu Jan 18 17:21:01 2024

@author: ym
"""
import numpy as np
import torch
import cv2
import torch.nn as nn
import torchvision.transforms as T

from .model import mobilevit_s, resnet18, resnet34, resnet50, mobilenet_v2, MobileNetV3_Small
from .config import config as conf


class ReIDInterface:
    def __init__(self, config=conf):
        self.device = config.device

        # Select the backbone specified in the configuration.
        if config.backbone == 'resnet18':
            # model = ResIRSE(img_size, embedding_size, conf.drop_ratio).to(device)
            model = resnet18().to(self.device)
        elif config.backbone == 'resnet34':
            model = resnet34().to(self.device)
        elif config.backbone == 'resnet50':
            model = resnet50().to(self.device)
        elif config.backbone == 'mobilevit_s':
            model = mobilevit_s().to(self.device)
        elif config.backbone == 'mobilenetv3':
            model = MobileNetV3_Small().to(self.device)
        else:
            model = mobilenet_v2().to(self.device)

        self.batch_size = config.batch_size
        self.embedding_size = config.embedding_size
        self.img_size = config.img_size
        self.model_path = config.model_path

        # The original input was a PIL image; T.ToTensor also accepts HxWxC numpy arrays.
        self.transform = T.Compose([
            T.ToTensor(),
            T.Resize((self.img_size, self.img_size)),
            T.ConvertImageDtype(torch.float32),
            T.Normalize(mean=[0.5], std=[0.5]),
        ])

        # self.model = nn.DataParallel(model).to(self.device)
        self.model = model
        self.model.load_state_dict(torch.load(self.model_path, map_location=self.device))
        self.model.eval()

        # Inference casts inputs to half precision on GPU, so the weights must match.
        if str(self.device) != "cpu":
            self.model = self.model.half()

    def inference(self, images, detections):
        # A single numpy array is treated as a full frame plus detection boxes;
        # otherwise `images` is assumed to be an iterable of pre-cropped patches.
        if isinstance(images, np.ndarray):
            features = self.inference_image(images, detections)
            return features

        batch_patches = []
        patches = []
        for i, img in enumerate(images):
            img = img.copy()
            patch = self.transform(img)
            if str(self.device) != "cpu":
                patch = patch.to(device=self.device).half()
            else:
                patch = patch.to(device=self.device)

            patches.append(patch)
            if (i + 1) % self.batch_size == 0:
                patches = torch.stack(patches, dim=0)
                batch_patches.append(patches)
                patches = []

        if len(patches):
            patches = torch.stack(patches, dim=0)
            batch_patches.append(patches)

        # Run each batch through the model and collect the embeddings.
        features = np.zeros((0, self.embedding_size))
        for patches in batch_patches:
            with torch.no_grad():
                pred = self.model(patches)
            pred[torch.isinf(pred)] = 1.0
            feat = pred.cpu().data.numpy()
            features = np.vstack((features, feat))
        return features

    def inference_image(self, image, detections):
        # Crop each detection (tlbr box) out of the BGR frame and embed it.
        H, W, _ = np.shape(image)

        batch_patches = []
        patches = []
        for d in range(np.size(detections, 0)):
            tlbr = detections[d, :4].astype(np.int_)
            # Clip the box to the image boundaries.
            tlbr[0] = max(0, tlbr[0])
            tlbr[1] = max(0, tlbr[1])
            tlbr[2] = min(W - 1, tlbr[2])
            tlbr[3] = min(H - 1, tlbr[3])
            img = image[tlbr[1]:tlbr[3], tlbr[0]:tlbr[2], :]

            img = img[:, :, ::-1].copy()  # BGR -> RGB: the model expects RGB inputs
            patch = self.transform(img)

            if str(self.device) != "cpu":
                patch = patch.to(device=self.device).half()
            else:
                patch = patch.to(device=self.device)

            patches.append(patch)
            if (d + 1) % self.batch_size == 0:
                patches = torch.stack(patches, dim=0)
                batch_patches.append(patches)
                patches = []

        if len(patches):
            patches = torch.stack(patches, dim=0)
            batch_patches.append(patches)

        # Run each batch through the model and collect the embeddings.
        features = np.zeros((0, self.embedding_size))
        for patches in batch_patches:
            with torch.no_grad():
                pred = self.model(patches)
            pred[torch.isinf(pred)] = 1.0
            feat = pred.cpu().data.numpy()
            features = np.vstack((features, feat))

        return features
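

# A minimal usage sketch (an assumption, not part of the original module): it
# presumes the package-level `config` exposes `device`, `backbone`, `batch_size`,
# `embedding_size`, `img_size` and `model_path`, that the weights at `model_path`
# match the chosen backbone, and that detections are (N, 4+) arrays of tlbr
# boxes, as inference_image expects. The frame path below is a placeholder.
if __name__ == "__main__":
    reid = ReIDInterface(conf)

    # Read one BGR frame and build two illustrative detection boxes.
    frame = cv2.imread("test_frame.jpg")  # placeholder path
    dets = np.array([[10, 20, 110, 220, 0.9],
                     [50, 60, 150, 260, 0.8]], dtype=np.float32)

    # A numpy-array input dispatches to inference_image internally.
    feats = reid.inference(frame, dets)
    print(feats.shape)  # (2, embedding_size)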