import torch

from ytracking.models.common import DetectMultiBackend
from ytracking.utils.torch_utils import select_device
from tools.config import cfg
from contrast.model.resnet_pre import resnet18
from ytracking.tracking.utils import Boxes, IterableSimpleNamespace, yaml_load
from ytracking.tracking.trackers import BOTSORT, BYTETracker

# import mediapipe as mp
# from pymilvus import (
#     connections,
#     utility,
#     FieldSchema, CollectionSchema, DataType,
#     Collection,
#     Milvus
# )


class Models:
    """Hold the detection (YOLO) model and the similarity (resnet18) model.

    Constructing an instance loads nothing; the models are built only when
    ``initModel`` (or one of the ``init*Model`` helpers) is called.
    """

    def __init__(self, device: str = 'cpu'):
        """Create an empty container with no models loaded.

        Args:
            device: Device spec handed to ``select_device`` when the YOLO
                model is built. Defaults to ``'cpu'``, the value that was
                previously hard-coded, so ``Models()`` behaves as before.
                NOTE(review): the similarity model does NOT use this value;
                it is placed on ``cfg.device``. Confirm whether the two
                models are meant to share a device.
        """
        self.yoloModel = None        # populated by initModel()
        self.reidModel = None        # not assigned by any method of this class
        self.similarityModel = None  # populated by initModel()
        self.Milvus = None           # not assigned by any method of this class
        self.device = device

    def initSimilarityModel(self):
        """Build the resnet18 similarity model and load its weights.

        The network is moved to ``cfg.device``. Weights are read from
        ``cfg.model_path``, mapped onto that same device, and the model is
        switched to eval mode before it is returned.

        Returns:
            The resnet18 model with weights loaded, in eval mode.
        """
        # model = MobileNetV3_Large().to(cfg.device)
        model = resnet18().to(cfg.device)
        # model.load_state_dict(torch.load(cfg.test_model, map_location=cfg.device))
        # NOTE: torch.load unpickles the file -- only point cfg.model_path at
        # trusted checkpoints.
        state_dict = torch.load(cfg.model_path, map_location=cfg.device)
        model.load_state_dict(state_dict)
        model.eval()
        return model

    def initYoloModel(self):
        """Build the detection model from ``cfg.tracking_model`` on ``self.device``.

        ``dnn`` and ``fp16`` are explicitly disabled. The semantics of
        ``DetectMultiBackend`` / ``select_device`` live in ``ytracking``.
        Presumably they are YOLOv5-style helpers; verify there.

        Returns:
            The ``DetectMultiBackend`` instance.
        """
        device = select_device(self.device)
        return DetectMultiBackend(cfg.tracking_model, device=device, dnn=False, fp16=False)

    def initModel(self):
        """Load both models and store them on this instance."""
        self.yoloModel = self.initYoloModel()
        self.similarityModel = self.initSimilarityModel()


# Shared, not-yet-initialized instance; callers must invoke models.initModel().
models = Models()

if __name__ == "__main__":
    # Initialize the shared module-level instance rather than a throwaway one.
    models.initModel()