From 2320468c40d4bccf5501c2b9726a20f581cac501 Mon Sep 17 00:00:00 2001
From: lee <770918727@qq.com>
Date: Wed, 22 Jan 2025 11:47:02 +0800
Subject: [PATCH] first commit
---
.idea/deployment.xml | 43 +++++
.idea/ieemoo-ai-review.iml | 19 +++
.idea/inspectionProfiles/Project_Default.xml | 12 ++
.../inspectionProfiles/profiles_settings.xml | 6 +
.idea/misc.xml | 10 ++
.idea/modules.xml | 8 +
.idea/sshConfigs.xml | 8 +
.idea/vcs.xml | 6 +
.idea/webServers.xml | 14 ++
README.md | 0
app.py | 92 ++++++++++
client.py | 23 +++
detecttracking | 1 +
llm/qwe_agent.py | 161 ++++++++++++++++++
llm/qwe_agent_old.py | 160 +++++++++++++++++
utils/config.py | 27 +++
utils/load_model.py | 5 +
utils/model_init.py | 59 +++++++
18 files changed, 654 insertions(+)
create mode 100644 .idea/deployment.xml
create mode 100644 .idea/ieemoo-ai-review.iml
create mode 100644 .idea/inspectionProfiles/Project_Default.xml
create mode 100644 .idea/inspectionProfiles/profiles_settings.xml
create mode 100644 .idea/misc.xml
create mode 100644 .idea/modules.xml
create mode 100644 .idea/sshConfigs.xml
create mode 100644 .idea/vcs.xml
create mode 100644 .idea/webServers.xml
create mode 100644 README.md
create mode 100644 app.py
create mode 100644 client.py
create mode 160000 detecttracking
create mode 100644 llm/qwe_agent.py
create mode 100644 llm/qwe_agent_old.py
create mode 100644 utils/config.py
create mode 100644 utils/load_model.py
create mode 100644 utils/model_init.py
diff --git a/.idea/deployment.xml b/.idea/deployment.xml
new file mode 100644
index 0000000..c3dc9f3
--- /dev/null
+++ b/.idea/deployment.xml
@@ -0,0 +1,43 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/ieemoo-ai-review.iml b/.idea/ieemoo-ai-review.iml
new file mode 100644
index 0000000..11e051d
--- /dev/null
+++ b/.idea/ieemoo-ai-review.iml
@@ -0,0 +1,19 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000..920d523
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,12 @@
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..6a7dcec
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,10 @@
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..cef2b4d
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/sshConfigs.xml b/.idea/sshConfigs.xml
new file mode 100644
index 0000000..e2bf899
--- /dev/null
+++ b/.idea/sshConfigs.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..3d9e9e0
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/webServers.xml b/.idea/webServers.xml
new file mode 100644
index 0000000..906a263
--- /dev/null
+++ b/.idea/webServers.xml
@@ -0,0 +1,14 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..e69de29
diff --git a/app.py b/app.py
new file mode 100644
index 0000000..fd0b968
--- /dev/null
+++ b/app.py
@@ -0,0 +1,92 @@
+from flask import Flask, request
+import requests
+import json
+import time
+import logging
+from PIL import Image
+from io import BytesIO
+from utils.model_init import initModel
+from detecttracking.stream_pipeline import stream_pipeline
+from logging.handlers import TimedRotatingFileHandler
+from llm.qwe_agent import get_product_description
+import pdb
+
+app = Flask(__name__)
+# Configure the log handler. NOTE(review): the ./log directory must already exist when this module is imported - confirm.
+log_handler = TimedRotatingFileHandler('./log/aiReview.log', when='midnight', interval=90, backupCount=1)
+log_handler.suffix = "%Y-%m-%d"  # date suffix used for rotated files
+log_handler.setLevel(logging.INFO)
+log_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
+log_handler.setFormatter(log_formatter)
+
+# Get the root logger and attach the handler, so the module-level logging.* calls below are written to the file.
+root_logger = logging.getLogger()
+root_logger.setLevel(logging.INFO)
+root_logger.addHandler(log_handler)
+
+# data = {
+# "goodsName": "优诺优丝黄桃果粒风味发酵乳",
+# "measureProperty": 0,
+# "qty": 1,
+# "price": 25.9,
+# "weight": 560, # 单位克
+# "barcode": "6931806801024",
+# # "video": "https://resources.cos.yimaogo.com/bl/3203600/54:78:c9:a4:8c:5e/video/411173317367614619680.mp4",
+# "video": "https://ieemoo-ai.obs.cn-east-3.myhuaweicloud.com/videos/20231009/04/04_20231009-082149_21f2ca35-f2c2-4386-8497-3e7a3b407f03_4901872831197.mp4",
+# "goodsPic": "https://ieemoo-storage.obs.cn-east-3.myhuaweicloud.com/lhpic/6931806801024.jpg",
+# "measureUnit": "组",
+# "goodsSpec": "405g"
+# }
+
+
+def item_analysis(stream_dict):
+    """Track the item in the video, then compare its best frame with the catalogue picture (if any)."""
+    track_imgs = stream_pipeline(stream_dict, initModel.resnet_model, initModel.yolo_model)
+    std_img = None
+    if stream_dict.get('goodsPic') is not None:
+        response = requests.get(stream_dict['goodsPic'], timeout=30)  # never block a worker forever
+        response.raise_for_status()  # an HTTP error page is not an image
+        std_img = Image.open(BytesIO(response.content))
+    description_dict = get_product_description(std_img, track_imgs, initModel.qwen_model, initModel.processor)
+    logging.info(f'description: {description_dict}')
+    return description_dict
+
+
+@app.route('/ai_review', methods=['POST'])
+def aiReview():
+    """Validate the JSON payload, run ``item_analysis`` on it and return the result as JSON text.
+
+    A ``video``/``goodsPic`` URL that does not answer with HTTP 200 is logged and set to ``None``
+    in the payload before the analysis runs.
+    """
+    start = time.time()
+    data = json.loads(request.get_data())
+    for field in ('video', 'goodsPic'):
+        url = data.get(field)
+        try:
+            with requests.get(url, stream=True, timeout=10) as response:  # status only, skip the body
+                reachable = response.status_code == 200
+        except requests.RequestException:
+            reachable = False
+        if reachable:
+            logging.info(f'{field}:{url} download success')
+        else:
+            logging.error(f'{field}:{url} download fail')  # log first: the URL is cleared next
+            data[field] = None  # clear it in the payload itself, not in a local variable
+    for key, value in data.items():
+        if value is None or value == '':  # 0 (e.g. measureProperty) is a valid value, not a missing one
+            data[key] = None
+            logging.error(f'{key} is null')
+    result = item_analysis(data)
+    logging.info(f'aiReview cost {time.time() - start}s')
+    # Flask rejects a bare int as a view result; send the analysis back to the caller instead.
+    return json.dumps(result, ensure_ascii=False, default=str), 200, {'Content-Type': 'application/json; charset=utf-8'}
+
+
+# def main():
+# item_analysis(data)
+
+
+if __name__ == '__main__':
+ # Run the Flask app directly, listening on all interfaces at port 8060.
+ app.run('0.0.0.0', port=8060)
diff --git a/client.py b/client.py
new file mode 100644
index 0000000..8dbb4f4
--- /dev/null
+++ b/client.py
@@ -0,0 +1,23 @@
+import requests
+import json
+def aiReviewClient(url="http://192.168.1.28:8060/ai_review", timeout=600):
+    """POST one sample review request to ``url`` and print the response body."""
+    data = {
+        "goodsName": "优诺优丝黄桃果粒风味发酵乳",
+        "measureProperty": 0,
+        "qty": 1,
+        "price": 25.9,
+        "weight": 560,  # unit: grams
+        "barcode": "6931806801024",
+        # "video": "https://resources.cos.yimaogo.com/bl/3203600/54:78:c9:a4:8c:5e/video/411173317367614619680.mp4",
+        "video": "https://ieemoo-ai.obs.cn-east-3.myhuaweicloud.com/videos/20231009/04/04_20231009-082149_21f2ca35-f2c2-4386-8497-3e7a3b407f03_4901872831197.mp4",
+        "goodsPic": "https://ieemoo-storage.obs.cn-east-3.myhuaweicloud.com/lhpic/6931806801024.jpg",
+        "measureUnit": "组",
+        "goodsSpec": "405g"
+    }
+    r = requests.post(url=url, data=json.dumps(data), timeout=timeout)  # requests waits forever without a timeout
+    print(r.text)
+
+
+if __name__ == '__main__':
+ aiReviewClient()  # send the sample request when run as a script
diff --git a/detecttracking b/detecttracking
new file mode 160000
index 0000000..2feedd6
--- /dev/null
+++ b/detecttracking
@@ -0,0 +1 @@
+Subproject commit 2feedd622d419657b43fdd71c5e27bfa65baae3c
diff --git a/llm/qwe_agent.py b/llm/qwe_agent.py
new file mode 100644
index 0000000..06ed62e
--- /dev/null
+++ b/llm/qwe_agent.py
@@ -0,0 +1,161 @@
+# from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
+from detecttracking.stream_pipeline import stream_pipeline
+from PIL import Image
+from io import BytesIO
+import torch
+import ast
+import requests
+
+# # default: Load the model on the available device(s)
+# model = Qwen2VLForConditionalGeneration.from_pretrained(
+# "Qwen/Qwen2-VL-7B-Instruct",
+# torch_dtype="auto",
+# device_map="auto"
+# )
+#
+# # default processer
+# processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", attn_implementation="flash_attention_2")
+
+
+def qwen_prompt(img_list, messages, model=None, processor=None, max_new_tokens=256):
+    """Run one chat turn through ``model`` and return the decoded reply text.
+
+    ``messages`` is rendered with the processor's chat template and ``img_list`` is passed
+    to the processor as the images. At most ``max_new_tokens`` tokens are generated.
+    Raises ``ValueError`` when ``model`` or ``processor`` is missing.
+    """
+    if model is None or processor is None:
+        raise ValueError("model and processor must be provided")
+    text = processor.apply_chat_template(
+        messages, tokenize=False, add_generation_prompt=True
+    )
+    inputs = processor(text=[text], images=img_list, padding=True, return_tensors="pt")
+    # Follow the model's own device instead of assuming "cuda" (cuda:0).
+    inputs = inputs.to(getattr(model, "device", "cuda"))
+    generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
+    # Drop the first len(in_ids) tokens of every row so only the generated tail is decoded.
+    generated_ids_trimmed = [
+        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+    ]
+    output_text = processor.batch_decode(
+        generated_ids_trimmed, add_special_tokens=False, skip_special_tokens=True, clean_up_tokenization_spaces=False
+    )
+    return output_text[0]
+
+
+def get_best_image(track_imgs, model, processor):
+    """Ask the model which frame in ``track_imgs`` has the best quality and return that frame.
+
+    The reply must contain a dict literal with an ``'index'`` key, read as 1-based and clamped
+    into range. Raises ``ValueError`` for empty ``track_imgs`` or a reply without ``{...}``.
+    """
+    if len(track_imgs) == 0:
+        raise ValueError("track_imgs is empty, there is no image to choose from")
+    img_frames = [{'type': 'image', 'min_pixels': 224 * 224, 'max_pixels': 1280 * 28 * 28} for _ in track_imgs]
+    messages = [
+        {
+            "role": "system",
+            "content": "你是一个在超市工作的chatbot,你现在需要帮助顾客找到一张质量最好的商品图像。一个好的商品图像需要满足以下条件: \
+ 1. 文字清晰且连贯。\
+ 2. 商品图案清晰可识别。\
+ 3. 商品可提取的描述信息丰富。\
+ 基于以上条件,从多张图像中筛选出最好的图像,然后以dict输出该图像的索引信息,key为'index'。"
+        },
+        {"role": "system", "content": img_frames},
+    ]
+
+    output_text = qwen_prompt(track_imgs, messages, model, processor)
+    # str.strip('```python\n') strips a character set, not the fence; parse the brace-delimited span instead.
+    start, end = output_text.find('{'), output_text.rfind('}')
+    if start == -1 or end < start:
+        raise ValueError(f"no dict literal in model reply: {output_text!r}")
+    index = int(ast.literal_eval(output_text[start:end + 1])['index'])
+    # Clamp: an index of 0 used to select the last frame, and an oversized one raised IndexError.
+    return track_imgs[min(max(index - 1, 0), len(track_imgs) - 1)]
+
+
+def get_product_description(std_img, track_imgs, model, processor):
+    """Describe the best tracked frame and compare it with the standard image.
+
+    ``std_img`` is the reference picture or ``None``; without it the best frame is passed as
+    both images. Returns the dict literal parsed from the model reply and raises
+    ``ValueError`` when the reply contains no ``{...}`` span.
+    """
+    messages = [
+        {
+            "role": "system",
+            "content": "你是一个在超市工作的chatbot,你现在需要提取商品的信息,信息需要按照以下python dict的格式输出: \
+ {\
+ 'Text': 商品中提取出的文字信息, \
+ 'Color': 商品的颜色, \
+ 'Shape': 商品的形状, \
+ 'Material': 商品的材质, \
+ 'Category': 商品的类别, \
+ 'is_Same': 如果比对的两件商品的['Text', 'Color', 'Shape', 'Material', 'Category']属性中至少有3个相同则输出True,\
+ 否则输出False, \
+ } \
+ "
+        },
+        # Two image entries, filled in order from img_list: reference picture, then the tracked frame.
+        {
+            "role": "system",
+            "content": [
+                {"type": "image", "min_pixels": 224 * 224, "max_pixels": 1280 * 28 * 28},
+            ],
+        },
+        {
+            "role": "system",
+            "content": [
+                {"type": "image", "min_pixels": 224 * 224, "max_pixels": 1280 * 28 * 28},
+            ],
+        },
+        {
+            "role": "user",
+            "content": "以python dict的形式输出第二张图像的比对信息。"
+        }
+    ]
+    best_img = get_best_image(track_imgs, model, processor)
+    # Without a reference picture the best frame fills both image slots.
+    img_list = [std_img if std_img is not None else best_img, best_img]
+
+    output_text = qwen_prompt(img_list, messages, model, processor)
+    # str.strip('```python\n') removes a set of characters rather than the code fence, so
+    # parse the brace-delimited span of the reply instead.
+    start, end = output_text.find('{'), output_text.rfind('}')
+    if start == -1 or end < start:
+        raise ValueError(f"no dict literal in model reply: {output_text!r}")
+    return ast.literal_eval(output_text[start:end + 1])
+
+
+def item_analysis(stream_dict, model, processor):
+    """Run the tracking pipeline on ``stream_dict`` and return the description of the tracked item."""
+    track_imgs = stream_pipeline(stream_dict)  # NOTE(review): app.py also passes the resnet/yolo models - confirm
+    std_img = None
+    if stream_dict.get('goodsPic') is not None:
+        response = requests.get(stream_dict['goodsPic'], timeout=30)  # never wait forever
+        response.raise_for_status()  # an HTTP error page is not an image
+        std_img = Image.open(BytesIO(response.content))
+    return get_product_description(std_img, track_imgs, model, processor)
+
+
+def main():
+    """Analyse one hard-coded sample item and print the result."""
+    from utils.model_init import initModel  # importing it loads every model, so do it only when run
+    stream_dict = {
+        "goodsName": "优诺优丝黄桃果粒风味发酵乳",
+        "measureProperty": 0,
+        "qty": 1,
+        "price": 25.9,
+        "weight": 560,  # unit: grams
+        "barcode": "6931806801024",
+        "video": "https://ieemoo-ai.obs.cn-east-3.myhuaweicloud.com/videos/20231009/04/04_20231009-082149_21f2ca35-f2c2-4386-8497-3e7a3b407f03_4901872831197.mp4",
+        "goodsPic": "https://ieemoo-storage.obs.cn-east-3.myhuaweicloud.com/lhpic/6931806801024.jpg",
+        "measureUnit": "组",
+        "goodsSpec": "405g"
+    }
+    result = item_analysis(stream_dict, initModel.qwen_model, initModel.processor)  # model and processor are required
+    print(result)
+
+
+if __name__ == "__main__":
+ main()  # run the sample analysis when executed as a script
\ No newline at end of file
diff --git a/llm/qwe_agent_old.py b/llm/qwe_agent_old.py
new file mode 100644
index 0000000..b8e38fd
--- /dev/null
+++ b/llm/qwe_agent_old.py
@@ -0,0 +1,160 @@
+from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
+from detecttracking.stream_pipeline import stream_pipeline
+from PIL import Image
+from io import BytesIO
+import torch
+import ast
+import requests
+
+# Default: load the model on the available device(s); this runs when the module is imported.
+model = Qwen2VLForConditionalGeneration.from_pretrained(
+ "Qwen/Qwen2-VL-7B-Instruct",
+ torch_dtype="auto",
+ device_map="auto"
+)
+
+# Default processor. NOTE(review): attn_implementation is normally a model argument - confirm it has any effect here.
+processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", attn_implementation="flash_attention_2")
+
+
+def qwen_prompt(img_list, messages, max_new_tokens=256):
+    """Run one chat turn on the module-level ``model``/``processor`` and return the decoded reply text."""
+    text = processor.apply_chat_template(
+        messages, tokenize=False, add_generation_prompt=True
+    )
+    inputs = processor(
+        text=[text],
+        images=img_list,
+        padding=True,
+        return_tensors="pt",
+    )
+    # The model is placed by device_map="auto"; follow its device instead of assuming "cuda" (cuda:0).
+    inputs = inputs.to(model.device)
+
+    generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
+    # Drop the first len(in_ids) tokens of every row so only the generated tail is decoded.
+    generated_ids_trimmed = [
+        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+    ]
+    output_text = processor.batch_decode(
+        generated_ids_trimmed, add_special_tokens=False, skip_special_tokens=True, clean_up_tokenization_spaces=False
+    )
+    return output_text[0]
+
+
+def get_best_image(track_imgs):
+    """Ask the model which frame in ``track_imgs`` has the best quality and return that frame.
+
+    The reply must contain a dict literal with an ``'index'`` key, read as 1-based and clamped
+    into range. Raises ``ValueError`` for empty ``track_imgs`` or a reply without ``{...}``.
+    """
+    if len(track_imgs) == 0:
+        raise ValueError("track_imgs is empty, there is no image to choose from")
+    img_frames = [{'type': 'image', 'min_pixels': 224 * 224, 'max_pixels': 1280 * 28 * 28} for _ in track_imgs]
+    messages = [
+        {
+            "role": "system",
+            "content": "你是一个在超市工作的chatbot,你现在需要帮助顾客找到一张质量最好的商品图像。一个好的商品图像需要满足以下条件: \
+ 1. 文字清晰且连贯。\
+ 2. 商品图案清晰可识别。\
+ 3. 商品可提取的描述信息丰富。\
+ 基于以上条件,从多张图像中筛选出最好的图像,然后以dict输出该图像的索引信息,key为'index'。"
+        },
+        {"role": "system", "content": img_frames},
+    ]
+
+    output_text = qwen_prompt(track_imgs, messages)
+    # str.strip('```python\n') strips a character set, not the fence; parse the brace-delimited span instead.
+    start, end = output_text.find('{'), output_text.rfind('}')
+    if start == -1 or end < start:
+        raise ValueError(f"no dict literal in model reply: {output_text!r}")
+    index = int(ast.literal_eval(output_text[start:end + 1])['index'])
+    # Clamp: an index of 0 used to select the last frame, and an oversized one raised IndexError.
+    return track_imgs[min(max(index - 1, 0), len(track_imgs) - 1)]
+
+
+def get_product_description(std_img, track_imgs):
+    """Describe the best tracked frame and compare it with the standard image.
+
+    ``std_img`` is the reference picture or ``None``; without it the best frame is passed as
+    both images. Returns the dict literal parsed from the model reply and raises
+    ``ValueError`` when the reply contains no ``{...}`` span.
+    """
+    messages = [
+        {
+            "role": "system",
+            "content": "你是一个在超市工作的chatbot,你现在需要提取商品的信息,信息需要按照以下python dict的格式输出: \
+ {\
+ 'Text': 商品中提取出的文字信息, \
+ 'Color': 商品的颜色, \
+ 'Shape': 商品的形状, \
+ 'Material': 商品的材质, \
+ 'Category': 商品的类别, \
+ 'is_Same': 如果比对的两件商品的['Text', 'Color', 'Shape', 'Material', 'Category']属性中至少有3个相同则输出True,\
+ 否则输出False, \
+ } \
+ "
+        },
+        # Two image entries, filled in order from img_list: reference picture, then the tracked frame.
+        {
+            "role": "system",
+            "content": [
+                {"type": "image", "min_pixels": 224 * 224, "max_pixels": 1280 * 28 * 28},
+            ],
+        },
+        {
+            "role": "system",
+            "content": [
+                {"type": "image", "min_pixels": 224 * 224, "max_pixels": 1280 * 28 * 28},
+            ],
+        },
+        {
+            "role": "user",
+            "content": "以python dict的形式输出第二张图像的比对信息。"
+        }
+    ]
+    best_img = get_best_image(track_imgs)
+    # Without a reference picture the best frame fills both image slots.
+    img_list = [std_img if std_img is not None else best_img, best_img]
+
+    output_text = qwen_prompt(img_list, messages)
+    # str.strip('```python\n') removes a set of characters rather than the code fence, so
+    # parse the brace-delimited span of the reply instead.
+    start, end = output_text.find('{'), output_text.rfind('}')
+    if start == -1 or end < start:
+        raise ValueError(f"no dict literal in model reply: {output_text!r}")
+    return ast.literal_eval(output_text[start:end + 1])
+
+
+def item_analysis(stream_dict):
+    """Run the tracking pipeline on ``stream_dict`` and return the description of the tracked item."""
+    track_imgs = stream_pipeline(stream_dict)  # NOTE(review): app.py also passes the resnet/yolo models - confirm
+    std_img = None
+    if stream_dict.get('goodsPic') is not None:
+        response = requests.get(stream_dict['goodsPic'], timeout=30)  # never wait forever
+        response.raise_for_status()  # an HTTP error page is not an image
+        std_img = Image.open(BytesIO(response.content))
+    return get_product_description(std_img, track_imgs)
+
+
+def main():
+    """Analyse one hard-coded sample item and print the resulting description."""
+    sample_item = dict(
+        goodsName="优诺优丝黄桃果粒风味发酵乳",
+        measureProperty=0,
+        qty=1,
+        price=25.9,
+        weight=560,  # unit: grams
+        barcode="6931806801024",
+        video="https://ieemoo-ai.obs.cn-east-3.myhuaweicloud.com/videos/20231009/04/04_20231009-082149_21f2ca35-f2c2-4386-8497-3e7a3b407f03_4901872831197.mp4",
+        goodsPic="https://ieemoo-storage.obs.cn-east-3.myhuaweicloud.com/lhpic/6931806801024.jpg",
+        measureUnit="组",
+        goodsSpec="405g",
+    )
+
+    description = item_analysis(sample_item)
+    print(description)
+
+
+if __name__ == "__main__":
+ main()  # run the sample analysis when executed as a script
\ No newline at end of file
diff --git a/utils/config.py b/utils/config.py
new file mode 100644
index 0000000..434b53a
--- /dev/null
+++ b/utils/config.py
@@ -0,0 +1,27 @@
+import torch
+import torchvision.transforms.functional as F
+import torchvision.transforms as T
+
+
+def pad_to_square(img):
+    """Zero-pad ``img`` on its right and bottom edges until its width equals its height."""
+    width, height = img.size
+    side = max(width, height)
+    return F.pad(img, [0, 0, side - width, side - height], fill=0, padding_mode='constant')
+
+class Config:
+    """Checkpoint paths, device choice and preprocessing settings used by the model loaders."""
+    resnet_model = './detecttracking/contrast/feat_extract/checkpoints/resnet18_0515/v11.pth'
+    yolo_model = './detecttracking/tracking/ckpts/best_cls10_0906.pt'
+    # Use the second GPU when there is one, otherwise the only GPU, otherwise the CPU;
+    # a fixed 'cuda:1' names a device that a single-GPU host does not have.
+    device = torch.device(f'cuda:{min(1, torch.cuda.device_count() - 1)}' if torch.cuda.is_available() else 'cpu')
+    embedding_size = 256
+    batch_size = 8
+    img_size = 224
+    test_transform = T.Compose([
+        T.ToTensor(), T.Resize((img_size, img_size)),
+        T.ConvertImageDtype(torch.float32), T.Normalize(mean=[0.5], std=[0.5]),
+    ])
+
+config = Config()
diff --git a/utils/load_model.py b/utils/load_model.py
new file mode 100644
index 0000000..1225629
--- /dev/null
+++ b/utils/load_model.py
@@ -0,0 +1,5 @@
+# Load the Qwen2-VL-7B-Instruct processor and model directly; both calls run when this module is imported.
+from transformers import AutoProcessor, AutoModelForImageTextToText
+
+processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
+model = AutoModelForImageTextToText.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
\ No newline at end of file
diff --git a/utils/model_init.py b/utils/model_init.py
new file mode 100644
index 0000000..6847b0e
--- /dev/null
+++ b/utils/model_init.py
@@ -0,0 +1,59 @@
+import os
+import sys
+import torch
+from pathlib import Path
+import torch.nn as nn
+from utils.config import config as conf
+from collections import OrderedDict
+from transformers import Qwen2VLForConditionalGeneration
+from detecttracking.contrast.feat_extract.model import resnet18
+from detecttracking.utils.torch_utils import select_device
+from detecttracking.models.common import DetectMultiBackend
+from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
+FILE = Path(__file__).resolve()
+ROOT = FILE.parents[0] # directory containing this file (NOTE(review): it was labelled "YOLOv5 root directory" - confirm intent)
+if str(ROOT) not in sys.path:
+ sys.path.append(str(ROOT)) # make ROOT importable via sys.path
+ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # ROOT relative to the current working directory
+
+class InitModel:
+    """Loads the YOLO detector, the resnet18 feature extractor and Qwen2-VL once, at construction."""
+    def __init__(self):
+        self.data = ROOT / 'data/coco128.yaml'
+        self.device = conf.device
+        self.curpath = Path(__file__).resolve().parents[0]
+        self.yolo_model = self.init_yolo_model()
+        self.resnet_model = self.init_resnet_model()
+        self.qwen_model, self.processor = self.init_qwen_mdoel()
+
+    def init_yolo_model(self):
+        """Return the detection backend built from ``conf.yolo_model`` on ``self.device``."""
+        return DetectMultiBackend(conf.yolo_model, device=self.device, dnn=False, data=self.data, fp16=False)
+
+    def init_resnet_model(self):
+        """Return resnet18 on ``self.device`` with the ``conf.resnet_model`` weights, in eval mode."""
+        resnet_model = resnet18().to(self.device)
+        checkpoint = torch.load(conf.resnet_model, map_location=self.device)  # one load, always remapped
+        try:
+            resnet_model.load_state_dict(checkpoint)
+        except RuntimeError:
+            # Retry without the "module." prefix, removing it only from keys that actually carry it.
+            resnet_model.load_state_dict(OrderedDict(
+                (k[len('module.'):] if k.startswith('module.') else k, v)
+                for k, v in checkpoint.items()
+            ))
+        resnet_model.eval()
+        return resnet_model
+
+    def init_qwen_mdoel(self):
+        """Return ``(model, processor)`` loaded from the Qwen/Qwen2-VL-7B-Instruct checkpoint."""
+        qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
+            "Qwen/Qwen2-VL-7B-Instruct",
+            torch_dtype="auto",
+            device_map="auto"
+        )
+        # NOTE(review): attn_implementation is normally a model argument - confirm it has any effect here.
+        processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", attn_implementation="flash_attention_2")
+        return qwen_model, processor
+
+initModel = InitModel()