# -*- coding: utf-8 -*-
"""
ReID (re-identification) embedding extractor built on a configurable backbone.

Created on Thu Jan 18 17:21:01 2024

@author: ym
"""
import numpy as np
import torch
import cv2
import torch.nn as nn
import torchvision.transforms as T

from .model import mobilevit_s, resnet18, resnet34, resnet50, mobilenet_v2, MobileNetV3_Small
from .config import config as conf


# Backbone name (as read from ``conf.backbone``) -> model constructor.
# Any name not listed here falls back to ``mobilenet_v2``.
_BACKBONES = {
    'resnet18': resnet18,
    'resnet34': resnet34,
    'resnet50': resnet50,
    'mobilevit_s': mobilevit_s,
    'mobilenetv3': MobileNetV3_Small,
}


class ReIDInterface:
    """Extract ReID embeddings from image patches with a configured backbone.

    Every setting (device, backbone, batch size, embedding size, input size,
    checkpoint path) is read from the module-level ``conf`` object imported
    from ``.config``.
    """

    def __init__(self, config):
        """Build the backbone, load its weights and put it in eval mode.

        Args:
            config: Accepted for interface compatibility but currently unused;
                all settings come from the module-level ``conf``.
                NOTE(review): confirm whether callers expect this argument to
                take effect.
        """
        self.device = conf.device
        self.batch_size = conf.batch_size
        self.embedding_size = conf.embedding_size
        self.img_size = conf.img_size
        self.model_path = conf.model_path

        # NOTE: this pipeline was originally written with PIL inputs in mind;
        # inference_image feeds it HWC numpy crops, which ToTensor also accepts.
        self.transform = T.Compose([
            T.ToTensor(),
            T.Resize((self.img_size, self.img_size)),
            T.ConvertImageDtype(torch.float32),
            T.Normalize(mean=[0.5], std=[0.5]),
        ])

        model = _BACKBONES.get(conf.backbone, mobilenet_v2)().to(self.device)
        # torch.load unpickles the checkpoint: only load trusted files.
        model.load_state_dict(torch.load(self.model_path, map_location=self.device))
        model.eval()
        # Inputs are cast to fp16 on non-CPU devices (see _to_input), so the
        # weights must be fp16 as well; otherwise the forward pass fails with
        # an input/weight dtype mismatch.
        if self._use_half():
            model.half()
        self.model = model

    def _use_half(self):
        """Return True when inference runs in fp16 (any device other than CPU)."""
        return str(self.device) != "cpu"

    def _to_input(self, img):
        """Transform one image into a model-ready tensor on ``self.device``."""
        patch = self.transform(img).to(device=self.device)
        return patch.half() if self._use_half() else patch

    def _embed(self, patches):
        """Run the model over ``patches`` in chunks of ``self.batch_size``.

        Args:
            patches: List of equally sized tensors produced by ``_to_input``.

        Returns:
            float64 ``np.ndarray`` with one row per patch and
            ``self.embedding_size`` columns; shape ``(0, self.embedding_size)``
            when ``patches`` is empty.
        """
        # Seeding with an empty float64 array fixes the column count and
        # promotes every batch to float64.
        features = [np.zeros((0, self.embedding_size))]
        # Inference only: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            for start in range(0, len(patches), self.batch_size):
                batch = torch.stack(patches[start:start + self.batch_size], dim=0)
                pred = self.model(batch)
                # Replace infinities with 1.0 (presumably guarding against
                # fp16 overflow) so downstream distances stay finite.
                pred[torch.isinf(pred)] = 1.0
                features.append(pred.cpu().numpy())
        # One vstack at the end instead of re-copying the result per batch.
        return np.vstack(features)

    def inference(self, images, detections):
        """Compute one embedding per image patch.

        Args:
            images: Either a single full frame as an ``np.ndarray`` (then the
                crops are taken from ``detections``, see ``inference_image``),
                or an iterable of already-cropped patches, which are fed to
                the transform as-is (no channel reordering).
            detections: Boxes, used only when ``images`` is an ``np.ndarray``.

        Returns:
            float64 array of shape ``(num_patches, self.embedding_size)``.
        """
        if isinstance(images, np.ndarray):
            return self.inference_image(images, detections)
        # Defensive copy, kept from the original, so the caller's objects are
        # never handed to the transform directly.
        patches = [self._to_input(img.copy()) for img in images]
        return self._embed(patches)

    def inference_image(self, image, detections):
        """Crop each detection box out of ``image`` and embed the crops.

        Args:
            image: ``H x W x C`` frame. Channels are reversed before the
                transform because the model expects RGB (presumably the frame
                is BGR, as produced by OpenCV).
            detections: 2-D array whose first four columns are
                ``x1, y1, x2, y2`` pixel coordinates (truncated to int).

        Returns:
            float64 array of shape ``(len(detections), self.embedding_size)``.
        """
        H, W, _ = np.shape(image)
        patches = []
        for det in detections:
            x1, y1, x2, y2 = map(int, det[:4].astype(np.int_))
            # Clamp to the frame. Slice ends are exclusive, so x2/y2 may reach
            # W/H; forcing at least one pixel keeps degenerate or off-frame
            # boxes from yielding an empty crop, which the resize step would
            # likely reject.
            x1 = min(max(x1, 0), W - 1)
            y1 = min(max(y1, 0), H - 1)
            x2 = min(max(x2, x1 + 1), W)
            y2 = min(max(y2, y1 + 1), H)
            # the model expects RGB inputs
            crop = image[y1:y2, x1:x2, ::-1].copy()
            patches.append(self._to_input(crop))
        return self._embed(patches)