"""Eagerly load the YOLO detector, ResNet-18 feature extractor and Qwen2-VL model.

Importing this module constructs ``initModel`` (see the bottom of the file),
so all models are loaded as an import side effect.
"""
import os
import sys
from collections import OrderedDict
from pathlib import Path

import torch
import torch.nn as nn
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

from utils.config import config as conf
from detecttracking.contrast.feat_extract.model import resnet18
from detecttracking.utils.torch_utils import select_device
from detecttracking.models.common import DetectMultiBackend

FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]  # directory containing this file
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative to the current working directory


class InitModel:
    """Load and hold the models used by the pipeline.

    Attributes set by ``__init__``:
        data: Path to the dataset YAML handed to ``DetectMultiBackend``.
        device: Device taken from ``conf.device``.
        curpath: Directory containing this file.
        yolo_model: ``DetectMultiBackend`` built from ``conf.yolo_model``.
        resnet_model: ``resnet18`` with weights from ``conf.resnet_model``, in eval mode.
        qwen_model: ``Qwen2VLForConditionalGeneration`` loaded from the Hub.
        processor: ``AutoProcessor`` matching ``qwen_model``.
    """

    # Hugging Face Hub id used for both the Qwen2-VL model and its processor.
    QWEN_MODEL_ID = "Qwen/Qwen2-VL-7B-Instruct"

    def __init__(self):
        self.data = ROOT / 'data/coco128.yaml'
        self.device = conf.device
        self.curpath = Path(__file__).resolve().parents[0]
        self.yolo_model = self.init_yolo_model()
        self.resnet_model = self.init_resnet_model()
        self.qwen_model, self.processor = self.init_qwen_model()

    def init_yolo_model(self):
        """Build the YOLO detector from ``conf.yolo_model`` on ``self.device``.

        Returns:
            A ``DetectMultiBackend`` instance (no OpenCV DNN backend, fp32).
        """
        return DetectMultiBackend(
            conf.yolo_model,
            device=self.device,
            dnn=False,
            data=self.data,
            fp16=False,
        )

    def init_resnet_model(self):
        """Build ``resnet18`` on ``self.device`` and load weights from ``conf.resnet_model``.

        The checkpoint is loaded once, mapped directly onto ``self.device``, so
        a GPU-saved checkpoint also loads on a CPU-only host. If its keys do
        not match the model, a leading ``"module."`` prefix is stripped and
        loading is retried.

        Returns:
            The ResNet-18 model in eval mode.

        Raises:
            RuntimeError: if the weights still do not match after stripping.
        """
        resnet_model = resnet18().to(self.device)
        # NOTE(review): torch.load unpickles the file; only point conf.resnet_model
        # at trusted checkpoints.
        checkpoint = torch.load(conf.resnet_model, map_location=self.device)
        try:
            resnet_model.load_state_dict(checkpoint)
        except RuntimeError:
            # load_state_dict raises RuntimeError on missing/unexpected keys.
            # A "module." prefix is typically present when the weights were
            # saved from a DataParallel/DistributedDataParallel-wrapped model.
            resnet_model.load_state_dict(self._strip_state_dict_prefix(checkpoint))
        resnet_model.eval()
        return resnet_model

    @staticmethod
    def _strip_state_dict_prefix(state_dict, prefix="module."):
        """Return a copy of *state_dict* with *prefix* removed from keys that start with it."""
        return OrderedDict(
            (key[len(prefix):] if key.startswith(prefix) else key, value)
            for key, value in state_dict.items()
        )

    def init_qwen_model(self, model_id=None):
        """Load the Qwen2-VL model and its processor.

        Args:
            model_id: Hub id or local path. Defaults to ``QWEN_MODEL_ID``.

        Returns:
            ``(qwen_model, processor)``. The model is loaded with
            ``torch_dtype="auto"`` and ``device_map="auto"``.
        """
        if model_id is None:
            model_id = self.QWEN_MODEL_ID
        qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
            model_id, torch_dtype="auto", device_map="auto"
        )
        # NOTE: attn_implementation is a model-loading option, not a processor
        # option; the processor previously received it and ignored it. To enable
        # FlashAttention-2, pass attn_implementation="flash_attention_2" to the
        # model's from_pretrained above (requires the flash-attn package).
        processor = AutoProcessor.from_pretrained(model_id)
        return qwen_model, processor

    # Backward-compatible alias for the original (misspelled) method name.
    init_qwen_mdoel = init_qwen_model


# Module-level singleton; constructing it loads every model at import time.
initModel = InitModel()