Files
ieemoo-ai-review/utils/model_init.py
2025-01-22 11:47:02 +08:00

60 lines
2.4 KiB
Python

import os
import sys
import torch
from pathlib import Path
import torch.nn as nn
from utils.config import config as conf
from collections import OrderedDict
from transformers import Qwen2VLForConditionalGeneration
from detecttracking.contrast.feat_extract.model import resnet18
from detecttracking.utils.torch_utils import select_device
from detecttracking.models.common import DetectMultiBackend
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
# Absolute, symlink-resolved path of this module.
FILE = Path(__file__).resolve()
# Directory holding this module. NOTE(review): this file lives in utils/, so ROOT is
# that directory rather than a YOLOv5 checkout root — confirm that
# data/coco128.yaml (read by InitModel via ROOT) actually exists under it.
ROOT = FILE.parent
# Put that directory on the import path once so modules beside this file import cleanly.
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))
# Rebind ROOT as a path relative to the working directory at import time.
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))
class InitModel:
    """Load and hold the models used by the review pipeline.

    Constructing an instance loads every model up front, in this order:

    * ``yolo_model``   - a ``DetectMultiBackend`` built from ``conf.yolo_model``;
    * ``resnet_model`` - a ``resnet18`` with weights from ``conf.resnet_model``,
      switched to eval mode;
    * ``qwen_model`` / ``processor`` - Qwen2-VL-7B-Instruct and its ``AutoProcessor``.
    """

    # Hugging Face Hub id shared by the Qwen model and its processor.
    QWEN_MODEL_ID = "Qwen/Qwen2-VL-7B-Instruct"

    def __init__(self):
        # Dataset YAML handed to DetectMultiBackend, resolved against module-level ROOT.
        self.data = ROOT / 'data/coco128.yaml'
        self.device = conf.device
        # Directory containing this file; not read by the loaders below.
        self.curpath = Path(__file__).resolve().parents[0]
        self.yolo_model = self.init_yolo_model()
        self.resnet_model = self.init_resnet_model()
        self.qwen_model, self.processor = self.init_qwen_model()

    def init_yolo_model(self):
        """Build the YOLO detector from ``conf.yolo_model`` on ``self.device``.

        ``self.data`` is passed as the dataset YAML; ``dnn`` and ``fp16`` are disabled.
        """
        return DetectMultiBackend(
            conf.yolo_model, device=self.device, dnn=False, data=self.data, fp16=False
        )

    @staticmethod
    def _strip_module_prefix(state_dict, prefix="module."):
        """Return a copy of *state_dict* with a leading *prefix* removed from each key.

        Keys that do not start with *prefix* are kept unchanged; key order is preserved.
        """
        return OrderedDict(
            (key[len(prefix):] if key.startswith(prefix) else key, value)
            for key, value in state_dict.items()
        )

    def init_resnet_model(self):
        """Build ``resnet18`` on ``self.device`` and load weights from ``conf.resnet_model``.

        The checkpoint is read once, with tensors mapped onto ``self.device``. If its
        keys do not match the model, loading is retried after stripping a leading
        ``"module."`` from each key.

        Returns:
            The loaded model, in eval mode.

        Raises:
            RuntimeError: if the state dict still does not match after stripping.
        """
        resnet_model = resnet18().to(self.device)
        # map_location lets a checkpoint saved on GPU load on a CPU-only device.
        state_dict = torch.load(conf.resnet_model, map_location=self.device)
        try:
            resnet_model.load_state_dict(state_dict)
        except RuntimeError:
            # Key mismatch: presumably the checkpoint was saved from a
            # DataParallel/DDP wrapper, which prefixes keys with "module.".
            resnet_model.load_state_dict(self._strip_module_prefix(state_dict))
        resnet_model.eval()
        return resnet_model

    def init_qwen_model(self):
        """Load the Qwen2-VL model and its processor from ``QWEN_MODEL_ID``.

        The dtype comes from the checkpoint (``torch_dtype="auto"``). Device
        placement is left to ``device_map="auto"``; ``self.device`` is not used here.

        Returns:
            A ``(qwen_model, processor)`` tuple.
        """
        qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
            self.QWEN_MODEL_ID,
            torch_dtype="auto",
            device_map="auto",
        )
        # NOTE: attn_implementation is a model-loading option, not a processor option.
        # To use FlashAttention-2, add attn_implementation="flash_attention_2" to the
        # model call above (this requires the flash-attn package).
        processor = AutoProcessor.from_pretrained(self.QWEN_MODEL_ID)
        return qwen_model, processor

    # Backward-compatible alias for the original, misspelled method name.
    init_qwen_mdoel = init_qwen_model
# Module-level shared instance: importing this module constructs InitModel, which
# immediately loads the YOLO, ResNet-18 and Qwen2-VL models.
initModel = InitModel()