first commit

This commit is contained in:
lee
2025-01-22 11:47:02 +08:00
commit 2320468c40
18 changed files with 654 additions and 0 deletions

27
utils/config.py Normal file
View File

@ -0,0 +1,27 @@
import torch
import torchvision.transforms.functional as F
import torchvision.transforms as T
def pad_to_square(img):
    """Pad ``img`` on its right and bottom edges so width equals height.

    The existing content stays anchored at the top-left corner and the
    added border is filled with the constant value 0.

    Args:
        img: Image whose ``size`` attribute unpacks to ``(width, height)``
            (presumably a PIL image -- confirm against callers).

    Returns:
        Whatever ``torchvision.transforms.functional.pad`` returns for the
        computed padding.
    """
    width, height = img.size
    side = max(width, height)
    pad_right = side - width
    pad_bottom = side - height
    # torchvision takes the padding in (left, top, right, bottom) order.
    return F.pad(img, [0, 0, pad_right, pad_bottom], fill=0, padding_mode='constant')
class Config:
    """Static settings shared by the detection / feature-extraction code.

    All values are class attributes evaluated once, when the class body runs.
    """

    # network settings
    # Checkpoint paths are relative, so they resolve against the process's
    # current working directory.
    resnet_model = './detecttracking/contrast/feat_extract/checkpoints/resnet18_0515/v11.pth'
    yolo_model = './detecttracking/tracking/ckpts/best_cls10_0906.pt'

    # Prefer the second GPU, but fall back to the first one when only a
    # single CUDA device is visible: 'cuda:1' is an invalid ordinal there and
    # any later ``.to(device)`` would fail. Use the CPU when CUDA is absent.
    if torch.cuda.is_available():
        device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')
    else:
        device = torch.device('cpu')

    embedding_size = 256
    batch_size = 8
    img_size = 224

    # Preprocessing pipeline: convert to tensor, resize to a square of
    # ``img_size`` pixels, cast to float32, then normalise with mean 0.5 and
    # std 0.5.
    test_transform = T.Compose([
        T.ToTensor(),
        T.Resize((img_size, img_size)),
        T.ConvertImageDtype(torch.float32),
        T.Normalize(mean=[0.5], std=[0.5]),
    ])
# Shared settings instance; utils/model_init.py imports it as ``conf``.
config = Config()

5
utils/load_model.py Normal file
View File

@ -0,0 +1,5 @@
# Load model directly
from transformers import AutoProcessor, AutoModelForImageTextToText
# Both the processor and the model are loaded from the same checkpoint id.
_QWEN_CHECKPOINT = "Qwen/Qwen2-VL-7B-Instruct"
processor = AutoProcessor.from_pretrained(_QWEN_CHECKPOINT)
model = AutoModelForImageTextToText.from_pretrained(_QWEN_CHECKPOINT)

59
utils/model_init.py Normal file
View File

@ -0,0 +1,59 @@
import os
import sys
import torch
from pathlib import Path
import torch.nn as nn
from utils.config import config as conf
from collections import OrderedDict
from transformers import Qwen2VLForConditionalGeneration
from detecttracking.contrast.feat_extract.model import resnet18
from detecttracking.utils.torch_utils import select_device
from detecttracking.models.common import DetectMultiBackend
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
# Absolute, symlink-resolved path of this module.
FILE = Path(__file__).resolve()
# Directory that contains this module.
# NOTE(review): the original comment called this the "YOLOv5 root directory";
# confirm whether a parent directory was intended.
ROOT = FILE.parent
# Make that directory importable, without adding a duplicate entry.
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))
# From here on ROOT is expressed relative to the current working directory.
ROOT = Path(os.path.relpath(ROOT, start=Path.cwd()))
class InitModel:
    """Load and hold the YOLO detector, the ResNet model and Qwen2-VL.

    All three models are loaded eagerly by ``__init__``.

    Attributes:
        data: Path to ``data/coco128.yaml`` under ``ROOT``, passed to the
            YOLO backend.
        device: Torch device taken from the shared config.
        curpath: Absolute directory containing this module.
        yolo_model: ``DetectMultiBackend`` instance.
        resnet_model: ResNet-18 model in eval mode.
        qwen_model: Qwen2-VL conditional-generation model.
        processor: Processor loaded from the same Qwen2-VL checkpoint.
    """

    def __init__(self):
        self.data = ROOT / 'data/coco128.yaml'
        self.device = conf.device
        self.curpath = Path(__file__).resolve().parents[0]
        self.yolo_model = self.init_yolo_model()
        self.resnet_model = self.init_resnet_model()
        self.qwen_model, self.processor = self.init_qwen_mdoel()

    def init_yolo_model(self):
        """Build the YOLO backend from ``conf.yolo_model`` on ``self.device``."""
        return DetectMultiBackend(
            conf.yolo_model,
            device=self.device,
            dnn=False,
            data=self.data,
            fp16=False,
        )

    @staticmethod
    def _strip_module_prefix(state_dict):
        """Return a copy of ``state_dict`` with a leading ``module.`` removed.

        Keys that do not start with ``module.`` are kept unchanged, so an
        unprefixed key is never truncated. Key order is preserved.
        """
        prefix = 'module.'
        return OrderedDict(
            (key[len(prefix):] if key.startswith(prefix) else key, value)
            for key, value in state_dict.items()
        )

    def init_resnet_model(self):
        """Load the ResNet-18 weights and return the model in eval mode.

        The checkpoint is read once and mapped onto ``self.device``. If the
        state dict does not load as-is, it is retried with the ``module.``
        key prefix stripped (presumably the checkpoint was saved from a
        wrapped model -- confirm).

        Raises:
            RuntimeError: If the weights match the model neither with nor
                without the ``module.`` prefix.
        """
        resnet_model = resnet18().to(self.device)
        # map_location keeps the tensors off devices that may not exist on
        # this host (e.g. a CUDA-saved checkpoint loaded on a CPU-only box).
        checkpoint = torch.load(conf.resnet_model, map_location=self.device)
        try:
            resnet_model.load_state_dict(checkpoint)
        except RuntimeError:
            # load_state_dict reports missing/unexpected keys as RuntimeError.
            resnet_model.load_state_dict(self._strip_module_prefix(checkpoint))
        resnet_model.eval()
        return resnet_model

    def init_qwen_mdoel(self):
        """Load Qwen2-VL-7B-Instruct and its processor.

        Returns:
            A ``(qwen_model, processor)`` tuple.
        """
        qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
            "Qwen/Qwen2-VL-7B-Instruct",
            torch_dtype="auto",
            device_map="auto"
        )
        # NOTE(review): attn_implementation is normally a model-loading
        # option; passing it to the processor looks unintended. Left as-is
        # because moving it to the model would require flash-attn -- confirm.
        processor = AutoProcessor.from_pretrained(
            "Qwen/Qwen2-VL-7B-Instruct",
            attn_implementation="flash_attention_2",
        )
        return qwen_model, processor

    # Correctly spelt alias; the misspelt name stays for existing callers.
    init_qwen_model = init_qwen_mdoel
# Module-level instance: importing this module constructs InitModel, which
# loads the YOLO, ResNet and Qwen2-VL models immediately.
initModel = InitModel()