46 lines
1.4 KiB
Python
46 lines
1.4 KiB
Python
import torch
|
|
from ytracking.models.common import DetectMultiBackend
|
|
from ytracking.utils.torch_utils import select_device
|
|
from tools.config import cfg
|
|
from contrast.model.resnet_pre import resnet18
|
|
from ytracking.tracking.utils import Boxes, IterableSimpleNamespace, yaml_load
|
|
from ytracking.tracking.trackers import BOTSORT, BYTETracker
|
|
# import mediapipe as mp
|
|
# from pymilvus import (
|
|
# connections,
|
|
# utility,
|
|
# FieldSchema, CollectionSchema, DataType,
|
|
# Collection,
|
|
# Milvus
|
|
# )
|
|
|
|
class Models:
    """Holder for the detection and similarity models used by the pipeline.

    Every model slot starts out as ``None``. Calling :meth:`initModel` fills
    ``yoloModel`` and ``similarityModel``. ``reidModel`` and ``Milvus`` are
    not populated by any method of this class.
    """

    def __init__(self):
        # All model slots begin empty; assignment order is yoloModel,
        # reidModel, similarityModel, Milvus.
        self.yoloModel = self.reidModel = self.similarityModel = self.Milvus = None
        # Device string passed to select_device() by initYoloModel().
        # NOTE(review): initSimilarityModel() ignores this and uses cfg.device
        # instead, so the two models may land on different devices; confirm.
        self.device = 'cpu'

    def initSimilarityModel(self):
        """Build the ResNet-18 similarity network and load its weights.

        The network is moved to ``cfg.device``. The state dict stored at
        ``cfg.model_path`` is loaded with ``map_location=cfg.device``.

        Returns:
            The network with weights loaded, switched to evaluation mode.
        """
        target = cfg.device
        net = resnet18().to(target)
        # NOTE(review): torch.load unpickles the file, so cfg.model_path must
        # point at a trusted checkpoint.
        weights = torch.load(cfg.model_path, map_location=target)
        net.load_state_dict(weights)
        # nn.Module.eval() returns the module itself.
        return net.eval()

    def initYoloModel(self):
        """Build the YOLO tracking detector from ``cfg.tracking_model``.

        The device is resolved from ``self.device`` via ``select_device``.
        The backend is created with ``dnn=False`` and ``fp16=False``.

        Returns:
            The constructed ``DetectMultiBackend`` instance.
        """
        torch_device = select_device(self.device)
        detector = DetectMultiBackend(
            cfg.tracking_model,
            device=torch_device,
            dnn=False,
            fp16=False,
        )
        return detector

    def initModel(self):
        """Load the YOLO detector first, then the similarity model, onto self."""
        self.yoloModel = self.initYoloModel()
        self.similarityModel = self.initSimilarityModel()
|
|
# Module-level container. Nothing in this file calls initModel() on it, so
# its model slots stay None until some caller does.
models = Models()

if __name__ == "__main__":
    # Smoke run: load both models into a separate, throwaway container.
    # The shared ``models`` instance above is left untouched.
    smoke = Models()
    smoke.initModel()
|