first commit
This commit is contained in:
59
utils/model_init.py
Normal file
59
utils/model_init.py
Normal file
@ -0,0 +1,59 @@
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
from pathlib import Path
|
||||
import torch.nn as nn
|
||||
from utils.config import config as conf
|
||||
from collections import OrderedDict
|
||||
from transformers import Qwen2VLForConditionalGeneration
|
||||
from detecttracking.contrast.feat_extract.model import resnet18
|
||||
from detecttracking.utils.torch_utils import select_device
|
||||
from detecttracking.models.common import DetectMultiBackend
|
||||
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
|
||||
FILE = Path(__file__).resolve()
|
||||
ROOT = FILE.parents[0] # YOLOv5 root directory
|
||||
if str(ROOT) not in sys.path:
|
||||
sys.path.append(str(ROOT)) # add ROOT to PATH
|
||||
ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
|
||||
|
||||
class InitModel:
|
||||
def __init__(self):
|
||||
self.data = ROOT / 'data/coco128.yaml'
|
||||
self.device = conf.device
|
||||
self.curpath = Path(__file__).resolve().parents[0]
|
||||
self.yolo_model = self.init_yolo_model()
|
||||
self.resnet_model = self.init_resnet_model()
|
||||
self.qwen_model, self.processor = self.init_qwen_mdoel()
|
||||
|
||||
def init_yolo_model(self):
|
||||
# device = select_device('')
|
||||
yolo_model = DetectMultiBackend(conf.yolo_model, device=self.device, dnn=False, data=self.data, fp16=False)
|
||||
return yolo_model
|
||||
|
||||
def init_resnet_model(self):
|
||||
# self.device = conf.device
|
||||
resnet_model = resnet18().to(self.device)
|
||||
# resnet_mod_path = os.path.join(self.curpath, conf.resnet_model)
|
||||
try:
|
||||
resnet_model.load_state_dict(torch.load(conf.resnet_model, map_location=self.device))
|
||||
except Exception as e:
|
||||
resnet_model = resnet_model.to(self.device)
|
||||
# resnet_model = resnet_model.to(torch.device('cpu'))
|
||||
checkpoint = torch.load(conf.resnet_model)
|
||||
new_state_dict = OrderedDict()
|
||||
for k, v in checkpoint.items():
|
||||
name = k[7:] # remove "module."
|
||||
new_state_dict[name] = v
|
||||
resnet_model.load_state_dict(new_state_dict)
|
||||
resnet_model.eval()
|
||||
return resnet_model
|
||||
def init_qwen_mdoel(self):
|
||||
qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||
"Qwen/Qwen2-VL-7B-Instruct",
|
||||
torch_dtype="auto",
|
||||
device_map="auto"
|
||||
)
|
||||
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", attn_implementation="flash_attention_2")
|
||||
return qwen_model, processor
|
||||
|
||||
initModel = InitModel()
|
Reference in New Issue
Block a user