update
This commit is contained in:
45
tools/initModel.py
Normal file
45
tools/initModel.py
Normal file
@ -0,0 +1,45 @@
|
||||
import torch
|
||||
from ytracking.models.common import DetectMultiBackend
|
||||
from ytracking.utils.torch_utils import select_device
|
||||
from tools.config import cfg
|
||||
from contrast.model.resnet_pre import resnet18
|
||||
from ytracking.tracking.utils import Boxes, IterableSimpleNamespace, yaml_load
|
||||
from ytracking.tracking.trackers import BOTSORT, BYTETracker
|
||||
# import mediapipe as mp
|
||||
# from pymilvus import (
|
||||
# connections,
|
||||
# utility,
|
||||
# FieldSchema, CollectionSchema, DataType,
|
||||
# Collection,
|
||||
# Milvus
|
||||
# )
|
||||
|
||||
class Models:
|
||||
def __init__(self):
|
||||
self.yoloModel = None
|
||||
self.reidModel = None
|
||||
self.similarityModel = None
|
||||
self.Milvus = None
|
||||
self.device = 'cpu'
|
||||
|
||||
def initSimilarityModel(self):
|
||||
# model = MobileNetV3_Large().to(cfg.device)
|
||||
model = resnet18().to(cfg.device)
|
||||
# model.load_state_dict(torch.load(cfg.test_model, map_location=cfg.device))
|
||||
model.load_state_dict(torch.load(cfg.model_path, map_location=cfg.device))
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
def initYoloModel(self):
|
||||
device = select_device(self.device)
|
||||
model = DetectMultiBackend(cfg.tracking_model, device=device, dnn=False, fp16=False)
|
||||
return model
|
||||
|
||||
def initModel(self):
|
||||
self.yoloModel = self.initYoloModel()
|
||||
self.similarityModel = self.initSimilarityModel()
|
||||
|
||||
models = Models()
|
||||
|
||||
if __name__ == "__main__":
|
||||
Models().initModel()
|
Reference in New Issue
Block a user