first commit

This commit is contained in:
lee
2025-01-22 11:47:02 +08:00
commit 2320468c40
18 changed files with 654 additions and 0 deletions

59
utils/model_init.py Normal file
View File

@ -0,0 +1,59 @@
import os
import sys
import torch
from pathlib import Path
import torch.nn as nn
from utils.config import config as conf
from collections import OrderedDict
from transformers import Qwen2VLForConditionalGeneration
from detecttracking.contrast.feat_extract.model import resnet18
from detecttracking.utils.torch_utils import select_device
from detecttracking.models.common import DetectMultiBackend
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
FILE = Path(__file__).resolve()
ROOT = FILE.parents[0] # YOLOv5 root directory
if str(ROOT) not in sys.path:
sys.path.append(str(ROOT)) # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
class InitModel:
def __init__(self):
self.data = ROOT / 'data/coco128.yaml'
self.device = conf.device
self.curpath = Path(__file__).resolve().parents[0]
self.yolo_model = self.init_yolo_model()
self.resnet_model = self.init_resnet_model()
self.qwen_model, self.processor = self.init_qwen_mdoel()
def init_yolo_model(self):
# device = select_device('')
yolo_model = DetectMultiBackend(conf.yolo_model, device=self.device, dnn=False, data=self.data, fp16=False)
return yolo_model
def init_resnet_model(self):
# self.device = conf.device
resnet_model = resnet18().to(self.device)
# resnet_mod_path = os.path.join(self.curpath, conf.resnet_model)
try:
resnet_model.load_state_dict(torch.load(conf.resnet_model, map_location=self.device))
except Exception as e:
resnet_model = resnet_model.to(self.device)
# resnet_model = resnet_model.to(torch.device('cpu'))
checkpoint = torch.load(conf.resnet_model)
new_state_dict = OrderedDict()
for k, v in checkpoint.items():
name = k[7:] # remove "module."
new_state_dict[name] = v
resnet_model.load_state_dict(new_state_dict)
resnet_model.eval()
return resnet_model
def init_qwen_mdoel(self):
qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
"Qwen/Qwen2-VL-7B-Instruct",
torch_dtype="auto",
device_map="auto"
)
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", attn_implementation="flash_attention_2")
return qwen_model, processor
initModel = InitModel()