first commit
This commit is contained in:
43
.idea/deployment.xml
generated
Normal file
43
.idea/deployment.xml
generated
Normal file
@ -0,0 +1,43 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="PublishConfigData" autoUpload="Always" serverName="ieemoo0169@192.168.1.28:22 password (3)" remoteFilesAllowedToDisappearOnAutoupload="false">
|
||||
<serverData>
|
||||
<paths name="ieemoo0169@192.168.1.28:22 password">
|
||||
<serverdata>
|
||||
<mappings>
|
||||
<mapping deploy="/tmp/pycharm_project_630" local="$PROJECT_DIR$" />
|
||||
</mappings>
|
||||
</serverdata>
|
||||
</paths>
|
||||
<paths name="ieemoo0169@192.168.1.28:22 password (2)">
|
||||
<serverdata>
|
||||
<mappings>
|
||||
<mapping deploy="/home/ieemoo0169/ieemoo-ai-review" local="$PROJECT_DIR$" />
|
||||
</mappings>
|
||||
</serverdata>
|
||||
</paths>
|
||||
<paths name="ieemoo0169@192.168.1.28:22 password (3)">
|
||||
<serverdata>
|
||||
<mappings>
|
||||
<mapping deploy="/home/ieemoo0169/ieemoo-ai-review" local="$PROJECT_DIR$" />
|
||||
</mappings>
|
||||
</serverdata>
|
||||
</paths>
|
||||
<paths name="lc@192.168.1.184:22 password">
|
||||
<serverdata>
|
||||
<mappings>
|
||||
<mapping deploy="/home/lc/ieemoo-ai-review" local="$PROJECT_DIR$" />
|
||||
</mappings>
|
||||
</serverdata>
|
||||
</paths>
|
||||
<paths name="大模型">
|
||||
<serverdata>
|
||||
<mappings>
|
||||
<mapping deploy="ieemoo-ai-review" local="$PROJECT_DIR$" web="/" />
|
||||
</mappings>
|
||||
</serverdata>
|
||||
</paths>
|
||||
</serverData>
|
||||
<option name="myAutoUpload" value="ALWAYS" />
|
||||
</component>
|
||||
</project>
|
19
.idea/ieemoo-ai-review.iml
generated
Normal file
19
.idea/ieemoo-ai-review.iml
generated
Normal file
@ -0,0 +1,19 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="Flask">
|
||||
<option name="enabled" value="true" />
|
||||
</component>
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$" />
|
||||
<orderEntry type="jdk" jdkName="Remote Python 3.9.21 (sftp://ieemoo0169@192.168.1.28:22/home/ieemoo0169/.conda/envs/py/bin/python3)" jdkType="Python SDK" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
<component name="TemplatesService">
|
||||
<option name="TEMPLATE_CONFIGURATION" value="Jinja2" />
|
||||
<option name="TEMPLATE_FOLDERS">
|
||||
<list>
|
||||
<option value="$MODULE_DIR$/../ieemoo-ai-review\templates" />
|
||||
</list>
|
||||
</option>
|
||||
</component>
|
||||
</module>
|
12
.idea/inspectionProfiles/Project_Default.xml
generated
Normal file
12
.idea/inspectionProfiles/Project_Default.xml
generated
Normal file
@ -0,0 +1,12 @@
|
||||
<component name="InspectionProjectProfileManager">
|
||||
<profile version="1.0">
|
||||
<option name="myName" value="Project Default" />
|
||||
<inspection_tool class="PyPep8NamingInspection" enabled="true" level="WEAK WARNING" enabled_by_default="true">
|
||||
<option name="ignoredErrors">
|
||||
<list>
|
||||
<option value="N803" />
|
||||
</list>
|
||||
</option>
|
||||
</inspection_tool>
|
||||
</profile>
|
||||
</component>
|
6
.idea/inspectionProfiles/profiles_settings.xml
generated
Normal file
6
.idea/inspectionProfiles/profiles_settings.xml
generated
Normal file
@ -0,0 +1,6 @@
|
||||
<component name="InspectionProjectProfileManager">
|
||||
<settings>
|
||||
<option name="USE_PROJECT_PROFILE" value="false" />
|
||||
<version value="1.0" />
|
||||
</settings>
|
||||
</component>
|
10
.idea/misc.xml
generated
Normal file
10
.idea/misc.xml
generated
Normal file
@ -0,0 +1,10 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="Black">
|
||||
<option name="sdkName" value="Python 3.8 (my_env)" />
|
||||
</component>
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="Remote Python 3.9.21 (sftp://ieemoo0169@192.168.1.28:22/home/ieemoo0169/.conda/envs/py/bin/python3)" project-jdk-type="Python SDK" />
|
||||
<component name="PyPackaging">
|
||||
<option name="earlyReleasesAsUpgrades" value="true" />
|
||||
</component>
|
||||
</project>
|
8
.idea/modules.xml
generated
Normal file
8
.idea/modules.xml
generated
Normal file
@ -0,0 +1,8 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectModuleManager">
|
||||
<modules>
|
||||
<module fileurl="file://$PROJECT_DIR$/.idea/ieemoo-ai-review.iml" filepath="$PROJECT_DIR$/.idea/ieemoo-ai-review.iml" />
|
||||
</modules>
|
||||
</component>
|
||||
</project>
|
8
.idea/sshConfigs.xml
generated
Normal file
8
.idea/sshConfigs.xml
generated
Normal file
@ -0,0 +1,8 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="SshConfigs">
|
||||
<configs>
|
||||
<sshConfig authType="PASSWORD" host="192.168.1.28" id="2da10617-4199-45cd-9d9d-d0ef6db9be05" port="22" nameFormat="DESCRIPTIVE" username="ieemoo0169" useOpenSSHConfig="true" />
|
||||
</configs>
|
||||
</component>
|
||||
</project>
|
6
.idea/vcs.xml
generated
Normal file
6
.idea/vcs.xml
generated
Normal file
@ -0,0 +1,6 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="VcsDirectoryMappings">
|
||||
<mapping directory="$PROJECT_DIR$/detecttracking" vcs="Git" />
|
||||
</component>
|
||||
</project>
|
14
.idea/webServers.xml
generated
Normal file
14
.idea/webServers.xml
generated
Normal file
@ -0,0 +1,14 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="WebServers">
|
||||
<option name="servers">
|
||||
<webServer id="138b53d9-6b24-464a-a577-b1bc7faa69e4" name="大模型">
|
||||
<fileTransfer rootFolder="/home/ieemoo0169" accessType="SFTP" host="192.168.1.28" port="22" sshConfigId="2da10617-4199-45cd-9d9d-d0ef6db9be05" sshConfig="ieemoo0169@192.168.1.28:22 password">
|
||||
<advancedOptions>
|
||||
<advancedOptions dataProtectionLevel="Private" keepAliveTimeout="0" passiveMode="true" shareSSLContext="true" />
|
||||
</advancedOptions>
|
||||
</fileTransfer>
|
||||
</webServer>
|
||||
</option>
|
||||
</component>
|
||||
</project>
|
92
app.py
Normal file
92
app.py
Normal file
@ -0,0 +1,92 @@
|
||||
from flask import Flask, request
|
||||
import requests
|
||||
import json
|
||||
import time
|
||||
import logging
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
from utils.model_init import initModel
|
||||
from detecttracking.stream_pipeline import stream_pipeline
|
||||
from logging.handlers import TimedRotatingFileHandler
|
||||
from llm.qwe_agent import get_product_description
|
||||
import pdb
|
||||
|
||||
app = Flask(__name__)

# File logging: one formatter, one timed rotating handler writing INFO and above.
log_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
log_handler = TimedRotatingFileHandler(
    './log/aiReview.log',
    when='midnight',
    interval=90,
    backupCount=1,
)
log_handler.suffix = "%Y-%m-%d"
log_handler.setLevel(logging.INFO)
log_handler.setFormatter(log_formatter)

# Attach the handler to the root logger so records from every module reach the file.
root_logger = logging.getLogger()
root_logger.setLevel(logging.INFO)
root_logger.addHandler(log_handler)
|
||||
|
||||
# data = {
|
||||
# "goodsName": "优诺优丝黄桃果粒风味发酵乳",
|
||||
# "measureProperty": 0,
|
||||
# "qty": 1,
|
||||
# "price": 25.9,
|
||||
# "weight": 560, # 单位克
|
||||
# "barcode": "6931806801024",
|
||||
# # "video": "https://resources.cos.yimaogo.com/bl/3203600/54:78:c9:a4:8c:5e/video/411173317367614619680.mp4",
|
||||
# "video": "https://ieemoo-ai.obs.cn-east-3.myhuaweicloud.com/videos/20231009/04/04_20231009-082149_21f2ca35-f2c2-4386-8497-3e7a3b407f03_4901872831197.mp4",
|
||||
# "goodsPic": "https://ieemoo-storage.obs.cn-east-3.myhuaweicloud.com/lhpic/6931806801024.jpg",
|
||||
# "measureUnit": "组",
|
||||
# "goodsSpec": "405g"
|
||||
# }
|
||||
|
||||
|
||||
def item_analysis(stream_dict):
    """Track the product in the request video and describe it with the VLM.

    Args:
        stream_dict: Request payload. It is forwarded unchanged to
            ``stream_pipeline``; the optional ``'goodsPic'`` URL is read here
            to fetch the reference picture.

    Returns:
        The value returned by ``get_product_description`` for the reference
        picture (``None`` when it is missing or cannot be loaded) and the
        tracked images.

    Raises:
        AssertionError: If the ResNet or YOLO model held by ``initModel`` is None.
    """
    assert initModel.resnet_model is not None, "resnetModel is None"
    assert initModel.yolo_model is not None, "yoloModel is None"
    track_imgs = stream_pipeline(stream_dict, initModel.resnet_model, initModel.yolo_model)

    std_img = None
    goods_pic_url = stream_dict.get('goodsPic')
    if goods_pic_url is not None:
        try:
            # A timeout keeps a stalled storage server from blocking the worker forever.
            response = requests.get(goods_pic_url, timeout=30)
            # Without this check an HTTP error page would be handed to PIL as image bytes.
            response.raise_for_status()
            std_img = Image.open(BytesIO(response.content))
        except (requests.RequestException, OSError) as exc:
            # Fall back to "no reference picture", which the description step already supports.
            logging.error(f'goodsPic:{goods_pic_url} could not be loaded: {exc}')

    description_dict = get_product_description(
        std_img, track_imgs, initModel.qwen_model, initModel.processor
    )
    print(description_dict)
    return description_dict
|
||||
|
||||
|
||||
def _download_ok(label, url):
    """Return True when ``url`` answers with HTTP 200, logging the outcome.

    A missing/empty URL returns False without logging; the caller's null check
    reports it.
    """
    if not url:
        return False
    try:
        status_code = requests.get(url, timeout=30).status_code
    except requests.RequestException as exc:
        logging.error(f'{label}:{url} download fail ({exc})')
        return False
    if status_code == 200:
        logging.info(f'{label}:{url} download success')
        return True
    logging.error(f'{label}:{url} download fail')
    return False


@app.route('/ai_review', methods=['POST'])
def aiReview():
    """Handle POST /ai_review: validate the payload, run the analysis, return JSON.

    The request body is a JSON object. Its ``'video'`` and ``'goodsPic'`` URLs
    are probed; a URL that cannot be downloaded is replaced by ``None`` in the
    payload before it is passed to ``item_analysis``.

    Returns:
        A ``(body, status, headers)`` tuple whose body is the analysis result
        serialised as JSON.
    """
    start = time.time()
    data = json.loads(request.get_data())

    # Log with the real URL first, then clear it in the payload itself so the
    # downstream code sees that the resource is unavailable.
    if not _download_ok('video', data.get('video')):
        data['video'] = None
    if not _download_ok('goodsPic', data.get('goodsPic')):
        data['goodsPic'] = None

    # Only None and the empty string count as missing; a plain truthiness test
    # would also wipe legitimate zeros such as measureProperty=0.
    for key, value in list(data.items()):
        if value is None or value == '':
            data[key] = None
            logging.error(f'{key} is null')

    result = item_analysis(data)
    logging.info(f'aiReview cost {time.time() - start}s')

    # Flask cannot turn a bare integer into a response; send the result as JSON.
    body = json.dumps(result, ensure_ascii=False, default=str)
    return body, 200, {'Content-Type': 'application/json; charset=utf-8'}
|
||||
|
||||
|
||||
# def main():
|
||||
# item_analysis(data)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Serve on every network interface, port 8060.
    app.run(host='0.0.0.0', port=8060)
|
23
client.py
Normal file
23
client.py
Normal file
@ -0,0 +1,23 @@
|
||||
import requests
|
||||
import json
|
||||
def aiReviewClient(url="http://192.168.1.28:8060/ai_review", data=None):
    """POST a review request to the AI review service and print the reply body.

    Args:
        url: Endpoint to call. Defaults to the address that was previously
            hard-coded.
        data: JSON-serialisable request payload. When omitted, the built-in
            sample request is sent.
    """
    if data is None:
        data = {
            "goodsName": "优诺优丝黄桃果粒风味发酵乳",
            "measureProperty": 0,
            "qty": 1,
            "price": 25.9,
            "weight": 560,  # in grams
            "barcode": "6931806801024",
            # "video": "https://resources.cos.yimaogo.com/bl/3203600/54:78:c9:a4:8c:5e/video/411173317367614619680.mp4",
            "video": "https://ieemoo-ai.obs.cn-east-3.myhuaweicloud.com/videos/20231009/04/04_20231009-082149_21f2ca35-f2c2-4386-8497-3e7a3b407f03_4901872831197.mp4",
            "goodsPic": "https://ieemoo-storage.obs.cn-east-3.myhuaweicloud.com/lhpic/6931806801024.jpg",
            "measureUnit": "组",
            "goodsSpec": "405g"
        }
    r = requests.post(url=url, data=json.dumps(data))
    print(r.text)


if __name__ == '__main__':
    aiReviewClient()
|
1
detecttracking
Submodule
1
detecttracking
Submodule
Submodule detecttracking added at 2feedd622d
161
llm/qwe_agent.py
Normal file
161
llm/qwe_agent.py
Normal file
@ -0,0 +1,161 @@
|
||||
# from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
|
||||
from detecttracking.stream_pipeline import stream_pipeline
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
import torch
|
||||
import ast
|
||||
import requests
|
||||
|
||||
# # default: Load the model on the available device(s)
|
||||
# model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||
# "Qwen/Qwen2-VL-7B-Instruct",
|
||||
# torch_dtype="auto",
|
||||
# device_map="auto"
|
||||
# )
|
||||
#
|
||||
# # default processer
|
||||
# processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", attn_implementation="flash_attention_2")
|
||||
|
||||
|
||||
def qwen_prompt(img_list, messages, model=None, processor=None):
    """Run one chat-style generation and return the decoded reply.

    Args:
        img_list: Images passed to the processor as ``images``.
        messages: Chat messages rendered through the processor's chat template.
        model: Object exposing ``generate``; required.
        processor: Object exposing ``apply_chat_template``, ``__call__`` and
            ``batch_decode``; required.

    Returns:
        The first decoded sequence, with the prompt tokens removed.

    Raises:
        AssertionError: If ``model`` or ``processor`` is None.
    """
    assert model is not None and processor is not None, "model and processor must be provided"

    # Preparation for inference
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = processor(
        text=[text],
        images=img_list,
        padding=True,
        return_tensors="pt",
    )
    # Follow the model's own device instead of assuming CUDA, so the same code
    # also runs when the model lives on CPU; "cuda" stays the fallback for
    # model objects that do not expose a device attribute.
    inputs = inputs.to(getattr(model, "device", "cuda"))

    # Inference: generation of the output
    generated_ids = model.generate(**inputs, max_new_tokens=256)
    # Each generated sequence starts with its prompt; keep only the new tokens.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed,
        add_special_tokens=False,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    return output_text[0]
|
||||
|
||||
|
||||
def get_best_image(track_imgs, model, processor):
    """Ask the model which tracked image is the best and return that image.

    Args:
        track_imgs: Sequence of candidate images.
        model: Passed through to ``qwen_prompt``.
        processor: Passed through to ``qwen_prompt``.

    Returns:
        The element of ``track_imgs`` selected by the model. The model's
        ``'index'`` value is treated as 1-based.

    Raises:
        ValueError: If ``track_imgs`` is empty or the reply holds no dict literal.
        IndexError: If the returned index is outside ``1..len(track_imgs)``.
    """
    if len(track_imgs) == 0:
        raise ValueError("track_imgs is empty; there is no candidate image to choose from")

    # One image placeholder per candidate, in the same order as track_imgs.
    img_frames = [
        {'type': 'image', 'min_pixels': 224 * 224, 'max_pixels': 1280 * 28 * 28}
        for _ in range(len(track_imgs))
    ]

    messages = [
        {
            "role": "system",
            "content": (
                "你是一个在超市工作的chatbot,你现在需要帮助顾客找到一张质量最好的商品图像。一个好的商品图像需要满足以下条件: "
                "1. 文字清晰且连贯。"
                "2. 商品图案清晰可识别。"
                "3. 商品可提取的描述信息丰富。"
                "基于以上条件,从多张图像中筛选出最好的图像,然后以dict输出该图像的索引信息,key为'index'。"
            ),
        },
        {
            "role": "system",
            "content": img_frames,
        },
    ]

    output_text = qwen_prompt(track_imgs, messages, model, processor)

    # str.strip() removes a *set of characters*, not a prefix, so it only works
    # for one exact fence style. Cut out the outermost {...} instead.
    start, end = output_text.find('{'), output_text.rfind('}')
    if start == -1 or end < start:
        raise ValueError(f"model reply contains no dict literal: {output_text!r}")
    output_dict = ast.literal_eval(output_text[start:end + 1])

    position = int(output_dict['index'])
    # Reject 0 and negatives explicitly; Python would otherwise wrap around and
    # silently return a frame from the end of the list.
    if not 1 <= position <= len(track_imgs):
        raise IndexError(
            f"model returned index {position}, expected 1..{len(track_imgs)}"
        )
    return track_imgs[position - 1]
|
||||
|
||||
|
||||
def get_product_description(std_img, track_imgs, model, processor):
    """Describe the tracked product and compare it with the reference picture.

    Args:
        std_img: Reference product picture, or ``None``. When ``None`` the best
            tracked image is used for both image slots.
        track_imgs: Candidate images; the best one is chosen with
            ``get_best_image``.
        model: Passed through to ``get_best_image`` and ``qwen_prompt``.
        processor: Passed through to ``get_best_image`` and ``qwen_prompt``.

    Returns:
        The Python literal parsed from the model reply (the prompt asks for a
        dict with the keys Text, Color, Shape, Material, Category, is_Same).

    Raises:
        ValueError: If the reply holds no dict literal.
    """
    messages = [
        {
            "role": "system",
            "content": (
                "你是一个在超市工作的chatbot,你现在需要提取商品的信息,信息需要按照以下python dict的格式输出: "
                "{"
                "'Text': 商品中提取出的文字信息, "
                "'Color': 商品的颜色, "
                "'Shape': 商品的形状, "
                "'Material': 商品的材质, "
                "'Category': 商品的类别, "
                "'is_Same': 如果比对的两件商品的['Text', 'Color', 'Shape', 'Material', 'Category']属性中至少有3个相同则输出True,"
                "否则输出False, "
                "} "
            ),
        },
        {
            # Slot for the first image in img_list.
            "role": "system",
            "content": [
                {
                    "type": "image",
                    "min_pixels": 224 * 224,
                    "max_pixels": 1280 * 28 * 28,
                },
            ],
        },
        {
            # Slot for the second image in img_list.
            "role": "system",
            "content": [
                {
                    "type": "image",
                    "min_pixels": 224 * 224,
                    "max_pixels": 1280 * 28 * 28,
                },
            ],
        },
        {
            "role": "user",
            "content": "以python dict的形式输出第二张图像的比对信息。",
        },
    ]

    best_img = get_best_image(track_imgs, model, processor)
    # Without a reference picture the best frame fills both slots.
    img_list = [std_img, best_img] if std_img is not None else [best_img, best_img]

    output_text = qwen_prompt(img_list, messages, model, processor)

    # str.strip() removes a *set of characters*, not a prefix, so it only works
    # for one exact fence style. Cut out the outermost {...} instead.
    start, end = output_text.find('{'), output_text.rfind('}')
    if start == -1 or end < start:
        raise ValueError(f"model reply contains no dict literal: {output_text!r}")
    contrast_pair = ast.literal_eval(output_text[start:end + 1])

    return contrast_pair
|
||||
|
||||
|
||||
def item_analysis(stream_dict, model, processor):
    """Track the product in the request video and describe it with the VLM.

    Args:
        stream_dict: Request payload; forwarded to ``stream_pipeline`` and read
            for the optional ``'goodsPic'`` URL.
        model: Passed through to ``get_product_description``.
        processor: Passed through to ``get_product_description``.

    Returns:
        The value returned by ``get_product_description``.

    Raises:
        requests.RequestException: If the reference picture cannot be downloaded.
    """
    # NOTE(review): stream_pipeline is called with a single argument here; confirm
    # that this matches its signature in the detecttracking submodule.
    track_imgs = stream_pipeline(stream_dict)

    std_img = None
    goods_pic_url = stream_dict.get('goodsPic')
    if goods_pic_url is not None:
        # A timeout keeps a stalled server from hanging the call forever.
        response = requests.get(goods_pic_url, timeout=30)
        # Fail with a clear HTTP error instead of handing an error page to PIL.
        response.raise_for_status()
        std_img = Image.open(BytesIO(response.content))

    description_dict = get_product_description(std_img, track_imgs, model, processor)
    return description_dict
|
||||
|
||||
|
||||
def main():
    """Run the analysis once on a hard-coded sample request and print the result."""
    # item_analysis needs a model and a processor. They are imported here rather
    # than at module level because utils.model_init builds its models on import.
    from utils.model_init import initModel

    # sample input dict
    stream_dict = {
        "goodsName": "优诺优丝黄桃果粒风味发酵乳",
        "measureProperty": 0,
        "qty": 1,
        "price": 25.9,
        "weight": 560,  # in grams
        "barcode": "6931806801024",
        "video": "https://ieemoo-ai.obs.cn-east-3.myhuaweicloud.com/videos/20231009/04/04_20231009-082149_21f2ca35-f2c2-4386-8497-3e7a3b407f03_4901872831197.mp4",
        "goodsPic": "https://ieemoo-storage.obs.cn-east-3.myhuaweicloud.com/lhpic/6931806801024.jpg",
        "measureUnit": "组",
        "goodsSpec": "405g"
    }

    result = item_analysis(stream_dict, initModel.qwen_model, initModel.processor)
    print(result)


if __name__ == "__main__":
    main()
|
160
llm/qwe_agent_old.py
Normal file
160
llm/qwe_agent_old.py
Normal file
@ -0,0 +1,160 @@
|
||||
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
|
||||
from detecttracking.stream_pipeline import stream_pipeline
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
import torch
|
||||
import ast
|
||||
import requests
|
||||
|
||||
# Hugging Face checkpoint used for both the model and its processor.
_QWEN_CHECKPOINT = "Qwen/Qwen2-VL-7B-Instruct"

# Load the model with "auto" dtype and device map.
model = Qwen2VLForConditionalGeneration.from_pretrained(
    _QWEN_CHECKPOINT, torch_dtype="auto", device_map="auto"
)

# Processor for the same checkpoint.
processor = AutoProcessor.from_pretrained(
    _QWEN_CHECKPOINT, attn_implementation="flash_attention_2"
)
|
||||
|
||||
|
||||
def qwen_prompt(img_list, messages):
    """Run one chat-style generation with the module-level model and processor.

    Args:
        img_list: Images passed to the processor as ``images``.
        messages: Chat messages rendered through the processor's chat template.

    Returns:
        The first decoded sequence, with the prompt tokens removed.
    """
    # Preparation for inference
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = processor(
        text=[text],
        images=img_list,
        padding=True,
        return_tensors="pt",
    )
    # Follow the model's own device instead of assuming CUDA; "cuda" stays the
    # fallback for model objects that do not expose a device attribute.
    inputs = inputs.to(getattr(model, "device", "cuda"))

    # Inference: generation of the output
    generated_ids = model.generate(**inputs, max_new_tokens=256)
    # Each generated sequence starts with its prompt; keep only the new tokens.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed,
        add_special_tokens=False,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    return output_text[0]
|
||||
|
||||
|
||||
def get_best_image(track_imgs):
    """Ask the model which tracked image is the best and return that image.

    Args:
        track_imgs: Sequence of candidate images.

    Returns:
        The element of ``track_imgs`` selected by the model. The model's
        ``'index'`` value is treated as 1-based.

    Raises:
        ValueError: If ``track_imgs`` is empty or the reply holds no dict literal.
        IndexError: If the returned index is outside ``1..len(track_imgs)``.
    """
    if len(track_imgs) == 0:
        raise ValueError("track_imgs is empty; there is no candidate image to choose from")

    # One image placeholder per candidate, in the same order as track_imgs.
    img_frames = [
        {'type': 'image', 'min_pixels': 224 * 224, 'max_pixels': 1280 * 28 * 28}
        for _ in range(len(track_imgs))
    ]

    messages = [
        {
            "role": "system",
            "content": (
                "你是一个在超市工作的chatbot,你现在需要帮助顾客找到一张质量最好的商品图像。一个好的商品图像需要满足以下条件: "
                "1. 文字清晰且连贯。"
                "2. 商品图案清晰可识别。"
                "3. 商品可提取的描述信息丰富。"
                "基于以上条件,从多张图像中筛选出最好的图像,然后以dict输出该图像的索引信息,key为'index'。"
            ),
        },
        {
            "role": "system",
            "content": img_frames,
        },
    ]

    output_text = qwen_prompt(track_imgs, messages)

    # str.strip() removes a *set of characters*, not a prefix, so it only works
    # for one exact fence style. Cut out the outermost {...} instead.
    start, end = output_text.find('{'), output_text.rfind('}')
    if start == -1 or end < start:
        raise ValueError(f"model reply contains no dict literal: {output_text!r}")
    output_dict = ast.literal_eval(output_text[start:end + 1])

    position = int(output_dict['index'])
    # Reject 0 and negatives explicitly; Python would otherwise wrap around and
    # silently return a frame from the end of the list.
    if not 1 <= position <= len(track_imgs):
        raise IndexError(
            f"model returned index {position}, expected 1..{len(track_imgs)}"
        )
    return track_imgs[position - 1]
|
||||
|
||||
|
||||
def get_product_description(std_img, track_imgs):
    """Describe the tracked product and compare it with the reference picture.

    Args:
        std_img: Reference product picture, or ``None``. When ``None`` the best
            tracked image is used for both image slots.
        track_imgs: Candidate images; the best one is chosen with
            ``get_best_image``.

    Returns:
        The Python literal parsed from the model reply (the prompt asks for a
        dict with the keys Text, Color, Shape, Material, Category, is_Same).

    Raises:
        ValueError: If the reply holds no dict literal.
    """
    messages = [
        {
            "role": "system",
            "content": (
                "你是一个在超市工作的chatbot,你现在需要提取商品的信息,信息需要按照以下python dict的格式输出: "
                "{"
                "'Text': 商品中提取出的文字信息, "
                "'Color': 商品的颜色, "
                "'Shape': 商品的形状, "
                "'Material': 商品的材质, "
                "'Category': 商品的类别, "
                "'is_Same': 如果比对的两件商品的['Text', 'Color', 'Shape', 'Material', 'Category']属性中至少有3个相同则输出True,"
                "否则输出False, "
                "} "
            ),
        },
        {
            # Slot for the first image in img_list.
            "role": "system",
            "content": [
                {
                    "type": "image",
                    "min_pixels": 224 * 224,
                    "max_pixels": 1280 * 28 * 28,
                },
            ],
        },
        {
            # Slot for the second image in img_list.
            "role": "system",
            "content": [
                {
                    "type": "image",
                    "min_pixels": 224 * 224,
                    "max_pixels": 1280 * 28 * 28,
                },
            ],
        },
        {
            "role": "user",
            "content": "以python dict的形式输出第二张图像的比对信息。",
        },
    ]

    best_img = get_best_image(track_imgs)
    # Without a reference picture the best frame fills both slots.
    img_list = [std_img, best_img] if std_img is not None else [best_img, best_img]

    output_text = qwen_prompt(img_list, messages)

    # str.strip() removes a *set of characters*, not a prefix, so it only works
    # for one exact fence style. Cut out the outermost {...} instead.
    start, end = output_text.find('{'), output_text.rfind('}')
    if start == -1 or end < start:
        raise ValueError(f"model reply contains no dict literal: {output_text!r}")
    contrast_pair = ast.literal_eval(output_text[start:end + 1])

    return contrast_pair
|
||||
|
||||
|
||||
def item_analysis(stream_dict):
    """Track the product in the request video and describe it with the VLM.

    Args:
        stream_dict: Request payload; forwarded to ``stream_pipeline`` and read
            for the optional ``'goodsPic'`` URL.

    Returns:
        The value returned by ``get_product_description``.

    Raises:
        requests.RequestException: If the reference picture cannot be downloaded.
    """
    track_imgs = stream_pipeline(stream_dict)

    std_img = None
    goods_pic_url = stream_dict.get('goodsPic')
    if goods_pic_url is not None:
        # A timeout keeps a stalled server from hanging the call forever.
        response = requests.get(goods_pic_url, timeout=30)
        # Fail with a clear HTTP error instead of handing an error page to PIL.
        response.raise_for_status()
        std_img = Image.open(BytesIO(response.content))

    description_dict = get_product_description(std_img, track_imgs)
    return description_dict
|
||||
|
||||
|
||||
def main():
    """Analyse one hard-coded sample request and print the outcome."""
    sample_request = dict(
        goodsName="优诺优丝黄桃果粒风味发酵乳",
        measureProperty=0,
        qty=1,
        price=25.9,
        weight=560,  # in grams
        barcode="6931806801024",
        video="https://ieemoo-ai.obs.cn-east-3.myhuaweicloud.com/videos/20231009/04/04_20231009-082149_21f2ca35-f2c2-4386-8497-3e7a3b407f03_4901872831197.mp4",
        goodsPic="https://ieemoo-storage.obs.cn-east-3.myhuaweicloud.com/lhpic/6931806801024.jpg",
        measureUnit="组",
        goodsSpec="405g",
    )
    print(item_analysis(sample_request))


if __name__ == "__main__":
    main()
|
27
utils/config.py
Normal file
27
utils/config.py
Normal file
@ -0,0 +1,27 @@
|
||||
import torch
|
||||
import torchvision.transforms.functional as F
|
||||
import torchvision.transforms as T
|
||||
|
||||
|
||||
def pad_to_square(img):
    """Zero-pad ``img`` on its right and bottom edges until width equals height.

    ``img.size`` is unpacked as ``(width, height)``.
    """
    width, height = img.size
    side = width if width > height else height
    # Padding order is (left, top, right, bottom): only the right/bottom grow.
    return F.pad(
        img,
        [0, 0, side - width, side - height],
        fill=0,
        padding_mode='constant',
    )
|
||||
|
||||
class Config:
    """Static settings for the detection / feature-extraction models."""

    # network settings
    resnet_model = './detecttracking/contrast/feat_extract/checkpoints/resnet18_0515/v11.pth'
    yolo_model = './detecttracking/tracking/ckpts/best_cls10_0906.pt'

    # Keep using the second GPU when there is one, but fall back to cuda:0 on a
    # single-GPU host, where 'cuda:1' does not exist.
    if torch.cuda.is_available():
        device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')
    else:
        device = torch.device('cpu')

    embedding_size = 256
    batch_size = 8
    img_size = 224

    # Preprocessing: tensor conversion, resize to img_size x img_size,
    # float32 conversion, then normalisation with mean 0.5 / std 0.5.
    test_transform = T.Compose([
        T.ToTensor(),
        T.Resize((img_size, img_size)),
        T.ConvertImageDtype(torch.float32),
        T.Normalize(mean=[0.5], std=[0.5]),
    ])


config = Config()
|
5
utils/load_model.py
Normal file
5
utils/load_model.py
Normal file
@ -0,0 +1,5 @@
|
||||
# Load model directly
|
||||
from transformers import AutoProcessor, AutoModelForImageTextToText
|
||||
|
||||
# Hugging Face checkpoint shared by the processor and the model.
_CHECKPOINT = "Qwen/Qwen2-VL-7B-Instruct"

processor = AutoProcessor.from_pretrained(_CHECKPOINT)
model = AutoModelForImageTextToText.from_pretrained(_CHECKPOINT)
|
59
utils/model_init.py
Normal file
59
utils/model_init.py
Normal file
@ -0,0 +1,59 @@
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
from pathlib import Path
|
||||
import torch.nn as nn
|
||||
from utils.config import config as conf
|
||||
from collections import OrderedDict
|
||||
from transformers import Qwen2VLForConditionalGeneration
|
||||
from detecttracking.contrast.feat_extract.model import resnet18
|
||||
from detecttracking.utils.torch_utils import select_device
|
||||
from detecttracking.models.common import DetectMultiBackend
|
||||
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
|
||||
FILE = Path(__file__).resolve()
ROOT = FILE.parent  # directory that contains this file

# Make that directory importable (appended only once).
_root_entry = str(ROOT)
if _root_entry not in sys.path:
    sys.path.append(_root_entry)

# From here on ROOT is expressed relative to the current working directory.
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))
|
||||
|
||||
class InitModel:
    """Loads the YOLO detector, the ResNet feature extractor and Qwen2-VL once.

    Attributes set by ``__init__``: ``data``, ``device``, ``curpath``,
    ``yolo_model``, ``resnet_model``, ``qwen_model`` and ``processor``.
    """

    def __init__(self):
        self.data = ROOT / 'data/coco128.yaml'
        self.device = conf.device
        self.curpath = Path(__file__).resolve().parents[0]
        self.yolo_model = self.init_yolo_model()
        self.resnet_model = self.init_resnet_model()
        self.qwen_model, self.processor = self.init_qwen_model()

    def init_yolo_model(self):
        """Build the YOLO backend from ``conf.yolo_model`` on ``self.device``."""
        return DetectMultiBackend(
            conf.yolo_model, device=self.device, dnn=False, data=self.data, fp16=False
        )

    def init_resnet_model(self):
        """Build ResNet-18, load the ``conf.resnet_model`` weights, switch to eval mode."""
        resnet_model = resnet18().to(self.device)
        # Load the checkpoint once, always mapped onto the configured device.
        checkpoint = torch.load(conf.resnet_model, map_location=self.device)
        try:
            resnet_model.load_state_dict(checkpoint)
        except RuntimeError:
            # Keys did not match; presumably the checkpoint carries a "module."
            # prefix on every key. Strip that prefix only where it is present.
            prefix = 'module.'
            new_state_dict = OrderedDict(
                (k[len(prefix):] if k.startswith(prefix) else k, v)
                for k, v in checkpoint.items()
            )
            resnet_model.load_state_dict(new_state_dict)
        resnet_model.eval()
        return resnet_model

    def init_qwen_model(self):
        """Load the Qwen2-VL model and its processor.

        Returns:
            A ``(model, processor)`` tuple.
        """
        qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
            "Qwen/Qwen2-VL-7B-Instruct",
            torch_dtype="auto",
            device_map="auto"
        )
        # NOTE(review): attn_implementation is passed to the processor here, not
        # to the model; confirm that this is intended.
        processor = AutoProcessor.from_pretrained(
            "Qwen/Qwen2-VL-7B-Instruct", attn_implementation="flash_attention_2"
        )
        return qwen_model, processor

    # Backward-compatible alias for the original, misspelled method name.
    init_qwen_mdoel = init_qwen_model


initModel = InitModel()
|
Reference in New Issue
Block a user