llm_agent

This commit is contained in:
2025-04-15 09:26:24 +08:00
parent ad850221c5
commit 9400ae904a
25 changed files with 52650 additions and 39 deletions

View File

@ -1,15 +1,18 @@
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from accelerate import init_empty_weights, load_checkpoint_in_model
from stream_pipeline import stream_pipeline
from PIL import Image
from io import BytesIO
import torch
import ast
import requests
import random
# Load the Qwen2-VL checkpoint once, at import time, and let accelerate place
# it on the available device(s) via device_map="auto".
# torch_dtype is given exactly once: passing it twice is a SyntaxError
# ("keyword argument repeated"). An explicit half-precision dtype is used
# because the flash_attention_2 implementation only supports fp16/bf16 weights.
model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)
@ -30,23 +33,31 @@ def qwen_prompt(img_list, messages):
inputs = inputs.to("cuda")
# Inference: Generation of the output
generated_ids = model.generate(**inputs, max_new_tokens=256)
with torch.no_grad():
generated_ids = model.generate(**inputs, max_new_tokens=256)
generated_ids_trimmed = [
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
generated_ids_trimmed, add_special_tokens=False, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
del inputs
del generated_ids
del generated_ids_trimmed
torch.cuda.empty_cache()
return output_text[0]
def get_best_image(track_imgs):
if len(track_imgs) >= 5:
track_imgs = random.sample(track_imgs, 5)
img_frames = []
for i in range(len(track_imgs)):
content = {}
content['type'] = 'image'
content['min_pixels'] = 224 * 224
content['max_pixels'] = 1280 * 28 * 28
content['max_pixels'] = 800 * 800
img_frames.append(content)
messages = [
@ -66,6 +77,8 @@ def get_best_image(track_imgs):
output_text = qwen_prompt(track_imgs, messages)
output_dict = ast.literal_eval(output_text.strip('```python\n'))
if output_dict['index'] > len(track_imgs):
output_dict['index'] = len(track_imgs)
best_img = track_imgs[output_dict['index'] - 1]
return best_img
@ -74,42 +87,48 @@ def get_product_description(std_img, track_imgs):
messages = [
{
"role": "system",
"content": "你是一个在超市工作的chatbot你现在需要提取商品的信息信息需要按照以下python dict的格式输出: \
{\
'Text': 商品中提取出的文字信息, \
'Color': 商品的颜色, \
'Shape': 商品的形状, \
'Material': 商品的材质, \
'Category': 商品的类别, \
'is_Same': 如果比对的两件商品的['Text', 'Color', 'Shape', 'Material', 'Category']属性中至少有3个相同则输出True\
否则输出False, \
"content": "你是一个在超市工作的chatbot你现在需要提取图像中商品的信息信息需要按照以下python dict的格式输出,如果 \
信息模糊不清则输出'未知': \
{ \
'item1': {\
'Text': 第一张图像中商品中提取出的文字信息, \
'Color': 第一张图像中商品的颜色, \
'Shape': 第一张图像中商品的形状, \
'Material': 第一张图像中商品的材质, \
'Category': 第一张图像中商品的类别, \
} \
'item2': {\
'Text': 第二张图像中商品中提取出的文字信息, \
'Color': 第二张图像中商品的颜色, \
'Shape': 第二张图像中商品的形状, \
'Material': 第二张图像中商品的材质, \
'Category': 第二张图像中商品的类别, \
} \
'is_Same': 首先判断'Color'是否一致如果不一致则返回False如果一致则判断是否以上两个dict的['Text', 'Shape', 'Material', 'Category']key中至少有3个相同则输出True\
否则输出False。 \
} \
"
},
{
"role": "system",
"content": [
{
"type": "image",
"min_pixels": 224 * 224,
"max_pixels": 1280 * 28 * 28,
},
],
},
{
"role": "system",
"content": [
{
"type": "image",
"min_pixels": 224 * 224,
"max_pixels": 1280 * 28 * 28,
},
],
},
{
"role": "user",
"content": "以python dict的形式输出第二张图像的比对信息。"
}
"content": [
{
"type": "image",
"min_pixels": 224 * 224,
"max_pixels": 800 * 800,
},
{
"type": "image",
"min_pixels": 224 * 224,
"max_pixels": 800 * 800,
},
],
},
# {
# "role": "user",
# "content": "以python dict的形式输出第二张图像的比对信息。"
# "content": "输出一个listlist的内容包含两张图像提取出的dict信息。"
# }
]
best_img = get_best_image(track_imgs)
if std_img is not None:
@ -124,10 +143,13 @@ def get_product_description(std_img, track_imgs):
def item_analysis(stream_dict):
    """Run the tracking pipeline on a stream and describe the tracked product.

    Args:
        stream_dict: Request payload. It is forwarded unchanged to
            ``stream_pipeline`` and must contain the key ``'goodsPic'``,
            which is either ``None`` or something ``PIL.Image.open`` accepts
            (a local file path or a binary file object).

    Returns:
        An empty dict when ``stream_pipeline`` yields no tracked images,
        otherwise whatever ``get_product_description`` returns for the
        reference image and the tracked images.

    Raises:
        KeyError: If ``stream_dict`` has no ``'goodsPic'`` entry.
    """
    track_imgs = stream_pipeline(stream_dict)
    # Nothing was tracked, so there is nothing to describe or compare.
    if len(track_imgs) == 0:
        return {}

    std_img = None
    goods_pic = stream_dict['goodsPic']
    if goods_pic is not None:
        # The reference picture is read from local storage; no HTTP download
        # is performed. convert() returns an independent, fully loaded copy,
        # so the underlying file handle can be closed right away.
        with Image.open(goods_pic) as raw_img:
            std_img = raw_img.convert("RGB")

    return get_product_description(std_img, track_imgs)