Files
ieemoo-ai-review/llm/qwe_agent_old.py
2025-01-22 11:47:02 +08:00

160 lines
5.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import ast
from io import BytesIO

import requests
import torch
from PIL import Image
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from transformers.utils import is_flash_attn_2_available

from detecttracking.stream_pipeline import stream_pipeline
# Hugging Face checkpoint shared by the model and its processor.
MODEL_ID = "Qwen/Qwen2-VL-7B-Instruct"

# attn_implementation is a model-loading option: the processor is built
# independently of the model, so passing it to AutoProcessor cannot change the
# model's attention kernel. Request FlashAttention-2 on the model instead, but
# only when transformers reports it usable, so hosts without flash-attn keep
# the default attention implementation.
_model_kwargs = {}
if is_flash_attn_2_available():
    _model_kwargs["attn_implementation"] = "flash_attention_2"

# Load the model on the available device(s); torch_dtype="auto" takes the
# dtype from the checkpoint config.
model = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype="auto",
    device_map="auto",
    **_model_kwargs,
)
# Default processor (tokenizer + image preprocessor) for the same checkpoint.
processor = AutoProcessor.from_pretrained(MODEL_ID)
def qwen_prompt(img_list, messages, max_new_tokens=256):
    """Run one Qwen2-VL chat generation and return the decoded reply text.

    Args:
        img_list: Images handed to the processor; they fill the
            ``{"type": "image"}`` placeholders found in ``messages``.
        messages: Chat messages in the processor's chat-template format.
        max_new_tokens: Upper bound on generated tokens (default 256).

    Returns:
        The generated text for the single prompt, with the prompt tokens
        removed and special tokens skipped.
    """
    # Render the chat template to a prompt string ending with the
    # assistant-turn header so the model starts answering.
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = processor(
        text=[text],
        images=img_list,
        padding=True,
        return_tensors="pt",
    )
    # The model is loaded with device_map="auto", so follow its device rather
    # than hard-coding "cuda" (which fails on CPU-only hosts).
    inputs = inputs.to(model.device)
    generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # generate() returns prompt + continuation; keep only the new tokens.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return output_text[0]
def get_best_image(track_imgs):
    """Ask Qwen2-VL to pick the best-quality product image from a track.

    The model receives quality criteria (legible text, recognisable pattern,
    rich extractable description) plus every image in ``track_imgs``. It is
    asked to answer with a dict whose ``'index'`` key is the 1-based position
    of the chosen image.

    Args:
        track_imgs: Sized sequence of candidate images, in prompt order.
            Presumably the frames returned by ``stream_pipeline`` -- TODO
            confirm their type.

    Returns:
        The element of ``track_imgs`` selected by the model. A one-image
        track is returned directly without querying the model.

    Raises:
        ValueError: ``track_imgs`` is empty or the reply contains no
            ``{...}`` span (``ast.literal_eval`` may also raise
            ``ValueError``/``SyntaxError`` on a malformed one).
        IndexError: the reported index is outside ``1..len(track_imgs)``.
    """
    n_imgs = len(track_imgs)
    if n_imgs == 0:
        raise ValueError("get_best_image: track_imgs is empty")
    if n_imgs == 1:
        # Only one candidate, so the answer is fixed; skip the model call.
        return track_imgs[0]

    # One image placeholder per frame; the processor fills them from
    # ``track_imgs`` (passed to qwen_prompt below).
    img_frames = [
        {"type": "image", "min_pixels": 224 * 224, "max_pixels": 1280 * 28 * 28}
        for _ in track_imgs
    ]
    messages = [
        {
            "role": "system",
            "content": (
                "你是一个在超市工作的chatbot你现在需要帮助顾客找到一张质量最好的商品图像。一个好的商品图像需要满足以下条件: "
                "1. 文字清晰且连贯。"
                "2. 商品图案清晰可识别。"
                "3. 商品可提取的描述信息丰富。"
                "基于以上条件从多张图像中筛选出最好的图像然后以dict输出该图像的索引信息key为'index'"
                # State the numbering explicitly: the code below treats the index as 1-based.
                "，第一张图像的索引为1"
            ),
        },
        {
            "role": "system",
            "content": img_frames,
        },
    ]
    output_text = qwen_prompt(track_imgs, messages)
    # Parse the outermost {...} span. str.strip() takes a *character set*,
    # not a prefix, so strip('```python\n') only works for replies fenced
    # exactly as ```python (a ```json fence would leave "json" before the dict).
    start, end = output_text.find("{"), output_text.rfind("}")
    if start == -1 or end < start:
        raise ValueError(f"get_best_image: no dict in model reply {output_text!r}")
    output_dict = ast.literal_eval(output_text[start:end + 1])
    index = int(output_dict["index"])
    # Reject 0 / negative indices instead of letting Python's negative
    # indexing silently return an image from the end of the list.
    if not 1 <= index <= n_imgs:
        raise IndexError(
            f"get_best_image: model returned index {index}, expected 1..{n_imgs}"
        )
    return track_imgs[index - 1]
def get_product_description(std_img, track_imgs):
    """Describe the best track image and compare it with a reference image.

    The model sees two images: the reference (``std_img``) first, then the
    best frame chosen by ``get_best_image``. It is asked to return a Python
    dict literal with ``'Text'``, ``'Color'``, ``'Shape'``, ``'Material'``,
    ``'Category'`` and ``'is_Same'`` entries for the second image.

    Args:
        std_img: Reference product image, or ``None`` when unavailable.
        track_imgs: Candidate frames, forwarded to ``get_best_image``.

    Returns:
        The object ``ast.literal_eval`` parses from the first ``{`` to the
        last ``}`` of the model reply (expected: a dict with the keys above).

    Raises:
        ValueError: the reply contains no ``{...}`` span (``ast.literal_eval``
            may also raise ``ValueError``/``SyntaxError`` on a malformed one).
        Any exception raised by ``get_best_image``.
    """
    messages = [
        {
            "role": "system",
            "content": (
                "你是一个在超市工作的chatbot你现在需要提取商品的信息信息需要按照以下python dict的格式输出: "
                "{"
                "'Text': 商品中提取出的文字信息, "
                "'Color': 商品的颜色, "
                "'Shape': 商品的形状, "
                "'Material': 商品的材质, "
                "'Category': 商品的类别, "
                "'is_Same': 如果比对的两件商品的['Text', 'Color', 'Shape', 'Material', 'Category']属性中至少有3个相同则输出True"
                "否则输出False, "
                "} "
            ),
        },
        {
            "role": "system",
            "content": [
                {
                    "type": "image",
                    "min_pixels": 224 * 224,
                    "max_pixels": 1280 * 28 * 28,
                },
            ],
        },
        {
            "role": "system",
            "content": [
                {
                    "type": "image",
                    "min_pixels": 224 * 224,
                    "max_pixels": 1280 * 28 * 28,
                },
            ],
        },
        {
            "role": "user",
            "content": "以python dict的形式输出第二张图像的比对信息。"
        }
    ]
    best_img = get_best_image(track_imgs)
    # Without a reference picture the best frame is compared with itself,
    # so 'is_Same' presumably carries no information in that case.
    reference = std_img if std_img is not None else best_img
    output_text = qwen_prompt([reference, best_img], messages)
    # Parse the outermost {...} span. str.strip() takes a character set, not a
    # prefix, so it is not a reliable way to remove a ``` code fence.
    start, end = output_text.find("{"), output_text.rfind("}")
    if start == -1 or end < start:
        raise ValueError(
            f"get_product_description: no dict in model reply {output_text!r}"
        )
    return ast.literal_eval(output_text[start:end + 1])
def item_analysis(stream_dict, *, timeout=10):
    """Run the tracking pipeline on a stream and describe the product seen.

    Args:
        stream_dict: Stream metadata. It is passed to ``stream_pipeline``, and
            its optional ``'goodsPic'`` value is used as the URL of a
            reference product image.
        timeout: Seconds to wait for the reference-image download (keyword
            only). Without a timeout, ``requests`` can block indefinitely.

    Returns:
        Whatever ``get_product_description`` returns for the tracked frames.

    Raises:
        requests.RequestException: the reference image could not be fetched
            (including non-2xx responses via ``raise_for_status``).
    """
    track_imgs = stream_pipeline(stream_dict)
    std_img = None
    # Missing key, None, or empty string all mean "no reference image".
    goods_pic = stream_dict.get('goodsPic')
    if goods_pic:
        response = requests.get(goods_pic, timeout=timeout)
        # Fail with the HTTP error instead of letting PIL choke on an error page.
        response.raise_for_status()
        std_img = Image.open(BytesIO(response.content))
    return get_product_description(std_img, track_imgs)
def main():
    """Analyse one hard-coded sample stream and print the resulting description."""
    sample = dict(
        goodsName="优诺优丝黄桃果粒风味发酵乳",
        measureProperty=0,
        qty=1,
        price=25.9,
        weight=560,  # unit: grams
        barcode="6931806801024",
        video="https://ieemoo-ai.obs.cn-east-3.myhuaweicloud.com/videos/20231009/04/04_20231009-082149_21f2ca35-f2c2-4386-8497-3e7a3b407f03_4901872831197.mp4",
        goodsPic="https://ieemoo-storage.obs.cn-east-3.myhuaweicloud.com/lhpic/6931806801024.jpg",
        measureUnit="",
        goodsSpec="405g",
    )
    print(item_analysis(sample))


if __name__ == "__main__":
    main()