Files
ieemoo-ai-imageassessment/tools/initModel.py
2024-11-27 15:37:10 +08:00

46 lines
1.4 KiB
Python

import torch
from ytracking.models.common import DetectMultiBackend
from ytracking.utils.torch_utils import select_device
from tools.config import cfg
from contrast.model.resnet_pre import resnet18
from ytracking.tracking.utils import Boxes, IterableSimpleNamespace, yaml_load
from ytracking.tracking.trackers import BOTSORT, BYTETracker
# import mediapipe as mp
# from pymilvus import (
# connections,
# utility,
# FieldSchema, CollectionSchema, DataType,
# Collection,
# Milvus
# )
class Models:
    """Container that builds and holds the models used by this tool.

    Constructing an instance is cheap: it only sets placeholder attributes and
    the device string. The networks themselves are built and their weights
    loaded when :meth:`initModel` is called.
    """

    def __init__(self, device='cpu'):
        """Create an empty model container.

        Args:
            device: Device string handed to ``select_device`` when the YOLO
                model is built in :meth:`initYoloModel`. Defaults to ``'cpu'``,
                the value that was previously hard-coded, so ``Models()`` keeps
                its old behavior.
        """
        self.yoloModel = None        # filled in by initModel() via initYoloModel()
        self.reidModel = None        # NOTE(review): never assigned in this file
        self.similarityModel = None  # filled in by initModel() via initSimilarityModel()
        self.Milvus = None           # NOTE(review): never assigned in this file
        self.device = device

    def initSimilarityModel(self):
        """Build the ResNet-18 similarity model and load its weights.

        Weights are read from ``cfg.model_path`` and mapped onto ``cfg.device``.

        Returns:
            The ``resnet18`` model moved to ``cfg.device``, with the state dict
            loaded and switched to evaluation mode via ``model.eval()``.
        """
        # NOTE(review): this model uses cfg.device, while initYoloModel uses
        # self.device, so the two models may end up on different devices.
        # Behavior is kept as-is; confirm whether that is intentional.
        model = resnet18().to(cfg.device)
        # NOTE(review): torch.load unpickles the file, so cfg.model_path must
        # point at a trusted checkpoint. If it holds a plain state dict,
        # weights_only=True (torch >= 1.13) would be a safer choice.
        state_dict = torch.load(cfg.model_path, map_location=cfg.device)
        model.load_state_dict(state_dict)
        model.eval()
        return model

    def initYoloModel(self):
        """Build the YOLO detection backend from ``cfg.tracking_model``.

        Returns:
            A ``DetectMultiBackend`` placed on the device resolved from
            ``self.device``, with ``dnn=False`` and ``fp16=False``.
        """
        device = select_device(self.device)
        return DetectMultiBackend(cfg.tracking_model, device=device, dnn=False, fp16=False)

    def initModel(self):
        """Build the YOLO and similarity models and store them on this instance.

        Only ``yoloModel`` and ``similarityModel`` are assigned here;
        ``reidModel`` and ``Milvus`` are left unchanged.
        """
        self.yoloModel = self.initYoloModel()
        self.similarityModel = self.initSimilarityModel()
# Module-level shared container. Creating it is cheap because no weights are
# loaded until initModel() is called on it.
models = Models()

if __name__ == "__main__":
    # Running this file directly acts as a load smoke test: it builds a
    # separate Models instance (not the shared `models` above) and loads the
    # YOLO and similarity networks into it.
    _standalone = Models()
    _standalone.initModel()