first commit
This commit is contained in:
27
utils/config.py
Normal file
27
utils/config.py
Normal file
@ -0,0 +1,27 @@
|
||||
import torch
|
||||
import torchvision.transforms.functional as F
|
||||
import torchvision.transforms as T
|
||||
|
||||
|
||||
def pad_to_square(img):
|
||||
w, h = img.size
|
||||
max_wh = max(w, h)
|
||||
padding = [0, 0, max_wh - w, max_wh - h] # (left, top, right, bottom)
|
||||
return F.pad(img, padding, fill=0, padding_mode='constant')
|
||||
|
||||
class Config:
|
||||
# network settings
|
||||
resnet_model = './detecttracking/contrast/feat_extract/checkpoints/resnet18_0515/v11.pth'
|
||||
yolo_model = './detecttracking/tracking/ckpts/best_cls10_0906.pt'
|
||||
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
|
||||
embedding_size = 256
|
||||
batch_size = 8
|
||||
img_size = 224
|
||||
test_transform = T.Compose([
|
||||
T.ToTensor(),
|
||||
T.Resize((img_size, img_size)),
|
||||
T.ConvertImageDtype(torch.float32),
|
||||
T.Normalize(mean=[0.5], std=[0.5]),
|
||||
])
|
||||
|
||||
config = Config()
|
5
utils/load_model.py
Normal file
5
utils/load_model.py
Normal file
@ -0,0 +1,5 @@
|
||||
# Load model directly
|
||||
from transformers import AutoProcessor, AutoModelForImageTextToText
|
||||
|
||||
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
|
||||
model = AutoModelForImageTextToText.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
|
59
utils/model_init.py
Normal file
59
utils/model_init.py
Normal file
@ -0,0 +1,59 @@
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
from pathlib import Path
|
||||
import torch.nn as nn
|
||||
from utils.config import config as conf
|
||||
from collections import OrderedDict
|
||||
from transformers import Qwen2VLForConditionalGeneration
|
||||
from detecttracking.contrast.feat_extract.model import resnet18
|
||||
from detecttracking.utils.torch_utils import select_device
|
||||
from detecttracking.models.common import DetectMultiBackend
|
||||
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
|
||||
FILE = Path(__file__).resolve()
|
||||
ROOT = FILE.parents[0] # YOLOv5 root directory
|
||||
if str(ROOT) not in sys.path:
|
||||
sys.path.append(str(ROOT)) # add ROOT to PATH
|
||||
ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
|
||||
|
||||
class InitModel:
|
||||
def __init__(self):
|
||||
self.data = ROOT / 'data/coco128.yaml'
|
||||
self.device = conf.device
|
||||
self.curpath = Path(__file__).resolve().parents[0]
|
||||
self.yolo_model = self.init_yolo_model()
|
||||
self.resnet_model = self.init_resnet_model()
|
||||
self.qwen_model, self.processor = self.init_qwen_mdoel()
|
||||
|
||||
def init_yolo_model(self):
|
||||
# device = select_device('')
|
||||
yolo_model = DetectMultiBackend(conf.yolo_model, device=self.device, dnn=False, data=self.data, fp16=False)
|
||||
return yolo_model
|
||||
|
||||
def init_resnet_model(self):
|
||||
# self.device = conf.device
|
||||
resnet_model = resnet18().to(self.device)
|
||||
# resnet_mod_path = os.path.join(self.curpath, conf.resnet_model)
|
||||
try:
|
||||
resnet_model.load_state_dict(torch.load(conf.resnet_model, map_location=self.device))
|
||||
except Exception as e:
|
||||
resnet_model = resnet_model.to(self.device)
|
||||
# resnet_model = resnet_model.to(torch.device('cpu'))
|
||||
checkpoint = torch.load(conf.resnet_model)
|
||||
new_state_dict = OrderedDict()
|
||||
for k, v in checkpoint.items():
|
||||
name = k[7:] # remove "module."
|
||||
new_state_dict[name] = v
|
||||
resnet_model.load_state_dict(new_state_dict)
|
||||
resnet_model.eval()
|
||||
return resnet_model
|
||||
def init_qwen_mdoel(self):
|
||||
qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||
"Qwen/Qwen2-VL-7B-Instruct",
|
||||
torch_dtype="auto",
|
||||
device_map="auto"
|
||||
)
|
||||
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", attn_implementation="flash_attention_2")
|
||||
return qwen_model, processor
|
||||
|
||||
initModel = InitModel()
|
Reference in New Issue
Block a user