161 lines
5.3 KiB
Python
161 lines
5.3 KiB
Python
# from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
|
||
from detecttracking.stream_pipeline import stream_pipeline
|
||
from PIL import Image
|
||
from io import BytesIO
|
||
import torch
|
||
import ast
|
||
import requests
|
||
|
||
# # default: Load the model on the available device(s)
|
||
# model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||
# "Qwen/Qwen2-VL-7B-Instruct",
|
||
# torch_dtype="auto",
|
||
# device_map="auto"
|
||
# )
|
||
#
|
||
# # default processor
|
||
# processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", attn_implementation="flash_attention_2")
|
||
|
||
|
||
def qwen_prompt(
    img_list,
    messages,
    model=None,
    processor=None,
    *,
    max_new_tokens=256,
    device=None,
):
    """Run one chat-style generation with a Qwen2-VL-style model and return the reply text.

    Args:
        img_list: Images passed to ``processor(images=...)``. NOTE(review):
            presumably one per ``{"type": "image"}`` placeholder in
            ``messages``, in the same order -- confirm against the processor.
        messages: Chat messages rendered with ``processor.apply_chat_template``.
        model: Loaded generation model exposing ``generate``.
        processor: Matching processor (chat template, image preprocessing,
            batch decoding).
        max_new_tokens: Upper bound on generated tokens (default 256).
        device: Device the encoded inputs are moved to. Defaults to
            ``model.device`` when the model has that attribute, otherwise
            ``"cuda"``.

    Returns:
        The decoded reply for the single prompt, with the prompt tokens
        removed and special tokens skipped.

    Raises:
        ValueError: If ``model`` or ``processor`` is not provided.
    """
    # Explicit check instead of ``assert``, which is stripped under ``python -O``.
    if model is None or processor is None:
        raise ValueError("model and processor must be provided")

    # Render the conversation to a single prompt string, ending with an open
    # assistant turn for the model to complete.
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = processor(
        text=[text],
        images=img_list,
        padding=True,
        return_tensors="pt",
    )

    if device is None:
        # Follow the model's own placement rather than assuming "cuda".
        device = getattr(model, "device", "cuda")
    inputs = inputs.to(device)

    generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # generate() echoes the prompt before the continuation; keep only new tokens.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )

    # Exactly one prompt was encoded, so there is exactly one decoded reply.
    return output_text[0]
|
||
|
||
|
||
def get_best_image(track_imgs, model, processor):
    """Ask the vision-language model to pick the best-quality product image.

    Args:
        track_imgs: Candidate images, forwarded to ``qwen_prompt`` in order.
        model: Model forwarded to ``qwen_prompt``.
        processor: Processor forwarded to ``qwen_prompt``.

    Returns:
        The element of ``track_imgs`` selected by the model. A single
        candidate is returned directly without querying the model.

    Raises:
        ValueError: If ``track_imgs`` is empty, or the model reply contains no
            parsable dict with a usable 1-based ``'index'`` inside the
            candidate range.
    """
    if not track_imgs:
        raise ValueError("track_imgs must contain at least one image")
    if len(track_imgs) == 1:
        # Nothing to choose between; skip the expensive model call.
        return track_imgs[0]

    # One image placeholder per candidate; qwen_prompt supplies the pixels in
    # the same order. NOTE(review): min/max_pixels are carried as in the
    # original -- confirm the processor actually honours them here.
    img_frames = [
        {"type": "image", "min_pixels": 224 * 224, "max_pixels": 1280 * 28 * 28}
        for _ in track_imgs
    ]

    messages = [
        {
            "role": "system",
            "content": (
                "你是一个在超市工作的chatbot,你现在需要帮助顾客找到一张质量最好的商品图像。"
                "一个好的商品图像需要满足以下条件: "
                "1. 文字清晰且连贯。"
                "2. 商品图案清晰可识别。"
                "3. 商品可提取的描述信息丰富。"
                "基于以上条件,从多张图像中筛选出最好的图像,"
                "然后以dict输出该图像的索引信息,key为'index',索引从1开始计数。"
            ),
        },
        {
            "role": "system",
            "content": img_frames,
        },
    ]

    output_text = qwen_prompt(track_imgs, messages, model, processor)

    # The reply may be wrapped in a Markdown code fence or surrounded by prose,
    # so evaluate only the outermost {...} span.
    start, end = output_text.find("{"), output_text.rfind("}")
    if start == -1 or end < start:
        raise ValueError(f"model reply contains no dict literal: {output_text!r}")
    try:
        output_dict = ast.literal_eval(output_text[start:end + 1])
        index = int(output_dict["index"])
    except (SyntaxError, ValueError, TypeError, KeyError) as err:
        raise ValueError(f"model reply has no usable 'index': {output_text!r}") from err

    # The prompt asks for 1-based indices; reject anything outside the range
    # instead of letting 0 or negatives silently wrap to the end of the list.
    if not 1 <= index <= len(track_imgs):
        raise ValueError(
            f"model chose index {index}, expected 1..{len(track_imgs)}"
        )
    return track_imgs[index - 1]
|
||
|
||
|
||
def get_product_description(std_img, track_imgs, model, processor):
    """Extract product attributes from the best track image and compare with a reference.

    The model is shown two images: the reference image (or the best track
    image again when no reference is available), followed by the best image
    chosen from ``track_imgs`` by ``get_best_image``. It is asked to describe
    the second image as a Python dict.

    Args:
        std_img: Reference product image, or ``None`` if unavailable.
        track_imgs: Candidate images passed to ``get_best_image``.
        model: Model forwarded to ``get_best_image`` and ``qwen_prompt``.
        processor: Processor forwarded to ``get_best_image`` and ``qwen_prompt``.

    Returns:
        The dict parsed from the model reply. The prompt requests the keys
        'Text', 'Color', 'Shape', 'Material', 'Category' and 'is_Same', but
        the presence of those keys is not checked here.

    Raises:
        ValueError: If the reply contains no parsable dict literal.
    """
    # Image placeholder template; one copy per image below. NOTE(review):
    # min/max_pixels are carried as in the original -- confirm the processor
    # actually honours them here.
    image_slot = {"type": "image", "min_pixels": 224 * 224, "max_pixels": 1280 * 28 * 28}

    messages = [
        {
            "role": "system",
            "content": (
                "你是一个在超市工作的chatbot,你现在需要提取商品的信息,"
                "信息需要按照以下python dict的格式输出: "
                "{"
                "'Text': 商品中提取出的文字信息, "
                "'Color': 商品的颜色, "
                "'Shape': 商品的形状, "
                "'Material': 商品的材质, "
                "'Category': 商品的类别, "
                "'is_Same': 如果比对的两件商品的['Text', 'Color', 'Shape', 'Material', 'Category']"
                "属性中至少有3个相同则输出True,"
                "否则输出False, "
                "} "
            ),
        },
        # First image: the reference (see img_list below).
        {"role": "system", "content": [dict(image_slot)]},
        # Second image: the best track image, the one the model must describe.
        {"role": "system", "content": [dict(image_slot)]},
        {"role": "user", "content": "以python dict的形式输出第二张图像的比对信息。"},
    ]

    best_img = get_best_image(track_imgs, model, processor)
    # Without a reference, compare the best image with itself so the two-image
    # layout of the prompt still holds.
    img_list = [std_img if std_img is not None else best_img, best_img]

    output_text = qwen_prompt(img_list, messages, model, processor)

    # The reply may be wrapped in a Markdown code fence or surrounded by prose,
    # so evaluate only the outermost {...} span.
    start, end = output_text.find("{"), output_text.rfind("}")
    if start == -1 or end < start:
        raise ValueError(f"model reply contains no dict literal: {output_text!r}")
    try:
        contrast_pair = ast.literal_eval(output_text[start:end + 1])
    except (SyntaxError, ValueError) as err:
        raise ValueError(f"model reply is not a valid dict literal: {output_text!r}") from err
    if not isinstance(contrast_pair, dict):
        raise ValueError(f"model reply is not a dict: {output_text!r}")

    return contrast_pair
|
||
|
||
|
||
def item_analysis(stream_dict, model, processor, *, timeout=10):
    """Describe the product in a stream record, optionally against its reference photo.

    Args:
        stream_dict: Record passed to ``stream_pipeline`` to obtain the track
            images. Its optional ``'goodsPic'`` entry is a URL of the
            reference product image; a missing, ``None`` or empty value means
            no reference image.
        model: Model forwarded to ``get_product_description``.
        processor: Processor forwarded to ``get_product_description``.
        timeout: Timeout in seconds passed to ``requests.get`` for the
            reference-image download (default 10).

    Returns:
        The dict returned by ``get_product_description``.

    Raises:
        requests.HTTPError: If the reference-image URL returns an error status.
        requests.Timeout: If the download exceeds ``timeout``.
    """
    track_imgs = stream_pipeline(stream_dict)

    std_img = None
    goods_pic_url = stream_dict.get('goodsPic')
    if goods_pic_url:
        # Without a timeout, requests can block forever on a stalled server.
        response = requests.get(goods_pic_url, timeout=timeout)
        # Surface HTTP failures directly instead of handing an error page to PIL.
        response.raise_for_status()
        std_img = Image.open(BytesIO(response.content))

    return get_product_description(std_img, track_imgs, model, processor)
|
||
|
||
|
||
def main(model=None, processor=None):
    """Run ``item_analysis`` on a hard-coded sample record and print the result.

    Args:
        model: Optional preloaded model. When omitted,
            ``Qwen/Qwen2-VL-7B-Instruct`` is loaded with transformers
            (``torch_dtype="auto"``, ``device_map="auto"``).
        processor: Optional preloaded processor. When omitted, the processor
            for ``Qwen/Qwen2-VL-7B-Instruct`` is loaded with transformers.
    """
    if model is None or processor is None:
        # Imported lazily so the analysis helpers stay importable without
        # transformers.
        from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

        model_id = "Qwen/Qwen2-VL-7B-Instruct"
        if model is None:
            model = Qwen2VLForConditionalGeneration.from_pretrained(
                model_id, torch_dtype="auto", device_map="auto"
            )
        if processor is None:
            processor = AutoProcessor.from_pretrained(model_id)

    # sample input dict
    stream_dict = {
        "goodsName": "优诺优丝黄桃果粒风味发酵乳",
        "measureProperty": 0,
        "qty": 1,
        "price": 25.9,
        "weight": 560,  # unit: grams
        "barcode": "6931806801024",
        "video": "https://ieemoo-ai.obs.cn-east-3.myhuaweicloud.com/videos/20231009/04/04_20231009-082149_21f2ca35-f2c2-4386-8497-3e7a3b407f03_4901872831197.mp4",
        "goodsPic": "https://ieemoo-storage.obs.cn-east-3.myhuaweicloud.com/lhpic/6931806801024.jpg",
        "measureUnit": "组",
        "goodsSpec": "405g",
    }

    # item_analysis requires model and processor; the original call omitted them.
    result = item_analysis(stream_dict, model, processor)
    print(result)
|
||
|
||
|
||
if __name__ == "__main__":
    # Script entry point: run the sample analysis defined in main().
    main()