60 lines
2.4 KiB
Python
60 lines
2.4 KiB
Python
import os
|
|
import sys
|
|
import torch
|
|
from pathlib import Path
|
|
import torch.nn as nn
|
|
from utils.config import config as conf
|
|
from collections import OrderedDict
|
|
from transformers import Qwen2VLForConditionalGeneration
|
|
from detecttracking.contrast.feat_extract.model import resnet18
|
|
from detecttracking.utils.torch_utils import select_device
|
|
from detecttracking.models.common import DetectMultiBackend
|
|
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
|
|
# Absolute path of this module, and the directory that contains it.
FILE = Path(__file__).resolve()
ROOT = FILE.parent

# Make the module's own directory importable if it is not already on sys.path.
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))

# From here on ROOT is expressed relative to the current working directory.
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))
|
|
|
|
class InitModel:
    """Create and hold the YOLO, ResNet-18 and Qwen2-VL models used by the project.

    All models are built eagerly in ``__init__``; afterwards they are available
    as ``yolo_model``, ``resnet_model``, ``qwen_model`` and ``processor``.
    """

    def __init__(self):
        # Dataset description handed to DetectMultiBackend (relative to ROOT).
        self.data = ROOT / 'data/coco128.yaml'
        self.device = conf.device
        # Absolute directory of this module.
        self.curpath = Path(__file__).resolve().parents[0]
        self.yolo_model = self.init_yolo_model()
        self.resnet_model = self.init_resnet_model()
        # NOTE: the method name is misspelled ("mdoel"); it is kept as-is so
        # existing callers and subclasses keep working.
        self.qwen_model, self.processor = self.init_qwen_mdoel()

    def init_yolo_model(self):
        """Build the detection backend from ``conf.yolo_model`` on ``self.device``.

        Returns:
            The ``DetectMultiBackend`` instance (fp16 and OpenCV-DNN disabled).
        """
        yolo_model = DetectMultiBackend(conf.yolo_model, device=self.device, dnn=False, data=self.data, fp16=False)
        return yolo_model

    @staticmethod
    def _strip_module_prefix(state_dict, prefix="module."):
        """Return a copy of ``state_dict`` with ``prefix`` removed from keys that carry it.

        Keys that do not start with ``prefix`` are copied unchanged, so a
        checkpoint with mixed or already-clean keys is not corrupted.
        """
        return OrderedDict(
            (key[len(prefix):] if key.startswith(prefix) else key, value)
            for key, value in state_dict.items()
        )

    def init_resnet_model(self):
        """Build ResNet-18, load the weights at ``conf.resnet_model`` and set eval mode.

        The checkpoint is read once and mapped onto ``self.device``. If its keys
        do not match the model, a second attempt is made after removing a leading
        ``"module."`` from the keys (e.g. a checkpoint saved from a wrapped model).

        Returns:
            The ResNet-18 model on ``self.device`` in evaluation mode.

        Raises:
            RuntimeError: if the weights still do not fit the model after the
                prefix has been stripped.
        """
        resnet_model = resnet18().to(self.device)
        # map_location keeps loading working when the checkpoint was saved on a
        # device that is not available here (e.g. GPU checkpoint on a CPU host).
        checkpoint = torch.load(conf.resnet_model, map_location=self.device)
        try:
            resnet_model.load_state_dict(checkpoint)
        except RuntimeError:
            # load_state_dict reports missing/unexpected keys as RuntimeError;
            # retry with the "module." prefix removed.
            resnet_model.load_state_dict(self._strip_module_prefix(checkpoint))
        resnet_model.eval()
        return resnet_model

    def init_qwen_mdoel(self):
        """Load the Qwen2-VL-7B-Instruct model and its processor.

        Returns:
            tuple: ``(qwen_model, processor)``.
        """
        model_id = "Qwen/Qwen2-VL-7B-Instruct"
        qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
            model_id,
            torch_dtype="auto",
            device_map="auto",
        )
        # NOTE(review): attn_implementation is a model-loading option; passing it
        # to the processor presumably has no effect on attention. Kept unchanged
        # to preserve behaviour - confirm whether it was meant for the model call.
        processor = AutoProcessor.from_pretrained(model_id, attn_implementation="flash_attention_2")
        return qwen_model, processor
|
|
|
|
# Module-level instance. Constructing it runs InitModel.__init__, which calls the
# three init_* loaders, so importing this module loads every model immediately.
initModel = InitModel()
|