"""Qwen_agent.py: Qwen2-VL helpers for comparing tracked product images.

The module loads the Qwen2-VL model and processor at import time, then offers
helpers that (1) ask the model to pick the best image out of a list and
(2) ask it to describe/compare a reference picture against that best image.
"""
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from stream_pipeline import stream_pipeline
from PIL import Image
import torch
import ast
import requests
from io import BytesIO

MODEL_ID = "Qwen/Qwen2-VL-7B-Instruct"

# Pixel budget attached to every image slot sent to the chat template.
MIN_PIXELS = 224 * 224
MAX_PIXELS = 1280 * 28 * 28

# Upper bound on the number of tokens generated per reply.
MAX_NEW_TOKENS = 256

# Seconds to wait for the reference picture download before giving up.
GOODS_PIC_TIMEOUT_S = 30

# default: Load the model on the available device(s)
model = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype="auto",
    device_map="auto",
)

# default processor
# NOTE(review): attn_implementation is usually a model-loading option; confirm
# whether it was meant for the model's from_pretrained() call above.
processor = AutoProcessor.from_pretrained(MODEL_ID, attn_implementation="flash_attention_2")


def _image_slot():
    """Return a fresh image placeholder entry for a chat message's content list."""
    return {
        'type': 'image',
        'min_pixels': MIN_PIXELS,
        'max_pixels': MAX_PIXELS,
    }


def _parse_dict_reply(reply):
    """Extract and evaluate the dict literal embedded in a model reply.

    The reply may be wrapped in a Markdown code fence or surrounded by extra
    text, so only the span from the first ``{`` to the last ``}`` is evaluated.

    Args:
        reply: Raw text returned by the model.

    Returns:
        The parsed ``dict``.

    Raises:
        ValueError: If no ``{...}`` span is present, or it is not a dict.
        SyntaxError: Propagated from ``ast.literal_eval`` on malformed literals.
    """
    start = reply.find('{')
    end = reply.rfind('}')
    if start == -1 or end < start:
        raise ValueError(f"model reply contains no dict literal: {reply!r}")

    # literal_eval only accepts Python literals, so model output is never executed.
    parsed = ast.literal_eval(reply[start:end + 1])
    if not isinstance(parsed, dict):
        raise ValueError(f"model reply is not a dict: {reply!r}")
    return parsed


def qwen_prompt(img_list, messages):
    """Run one chat-style generation and return the decoded reply.

    Args:
        img_list: Images passed to the processor as ``images``; one per image
            slot in ``messages``.
        messages: Chat messages fed to the processor's chat template.

    Returns:
        The decoded text of the first (and only) sequence in the batch.
    """
    # Preparation for inference
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = processor(
        text=[text],
        images=img_list,
        padding=True,
        return_tensors="pt",
    )
    # Follow the model's own device instead of assuming CUDA: the model is
    # loaded with device_map="auto".
    inputs = inputs.to(model.device)

    # Inference: Generation of the output
    generated_ids = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS)
    # Drop the prompt tokens so only the newly generated part is decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed,
        add_special_tokens=False,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )

    return output_text[0]


def get_best_image(track_imgs):
    """Ask the model which of ``track_imgs`` is the best product image.

    Args:
        track_imgs: Non-empty sequence of candidate images.

    Returns:
        The element of ``track_imgs`` selected by the model.

    Raises:
        ValueError: If ``track_imgs`` is empty or the reply holds no dict.
        KeyError: If the reply dict has no ``'index'`` key.
        IndexError: If the returned index is outside ``1..len(track_imgs)``.
    """
    if not track_imgs:
        raise ValueError("track_imgs must contain at least one image")

    img_frames = [_image_slot() for _ in track_imgs]

    # NOTE(review): the image slots are sent under the "system" role; confirm
    # this is intended rather than "user".
    messages = [
        {
            "role": "system",
            "content": "你是一个在超市工作的chatbot,你现在需要帮助顾客找到一张质量最好的商品图像。一个好的商品图像需要满足以下条件: \
                        1. 文字清晰且连贯。\
                        2. 商品图案清晰可识别。\
                        3. 商品可提取的描述信息丰富。\
                        基于以上条件,从多张图像中筛选出最好的图像,然后以dict输出该图像的索引信息,key为'index'。"
        },
        {
            "role": "system",
            "content": img_frames,
        },
    ]

    output_text = qwen_prompt(track_imgs, messages)
    output_dict = _parse_dict_reply(output_text)

    # The reply is treated as 1-based. Reject anything out of range explicitly:
    # a plain subtraction would turn 0 into -1 and silently pick the last image.
    index = int(output_dict['index'])
    if not 1 <= index <= len(track_imgs):
        raise IndexError(
            f"model returned image index {index}, expected 1..{len(track_imgs)}"
        )

    return track_imgs[index - 1]


def get_product_description(std_img, track_imgs):
    """Describe and compare the reference image against the best tracked image.

    The best image is first chosen from ``track_imgs`` via ``get_best_image``;
    the model is then prompted with ``[std_img, best_img]``.

    Args:
        std_img: Reference product image (first image in the prompt).
        track_imgs: Candidate images from which the second image is chosen.

    Returns:
        The dict parsed from the model reply.

    Raises:
        ValueError: If the reply holds no dict literal.
    """
    # NOTE(review): the image slots are sent under the "system" role; confirm
    # this is intended rather than "user".
    messages = [
        {
            "role": "system",
            "content": "你是一个在超市工作的chatbot,你现在需要提取商品的信息,信息需要按照以下python dict的格式输出: \
                        {\
                            'Text': 商品中提取出的文字信息, \
                            'Color': 商品的颜色, \
                            'Shape': 商品的形状, \
                            'Material': 商品的材质, \
                            'Category': 商品的类别, \
                            'is_Same': 如果比对的两件商品的['Text', 'Color', 'Shape', 'Material', 'Category']属性中至少有3个相同则输出True,\
                                        否则输出False, \
                        } \
                        "
        },
        {
            "role": "system",
            "content": [_image_slot()],
        },
        {
            "role": "system",
            "content": [_image_slot()],
        },
        {
            "role": "user",
            "content": "以python dict的形式输出第二张图像的比对信息。"
        }
    ]
    best_img = get_best_image(track_imgs)
    img_list = [std_img, best_img]

    output_text = qwen_prompt(img_list, messages)
    contrast_pair = _parse_dict_reply(output_text)

    return contrast_pair


def main():
    """Run the pipeline on a sample record and print the resulting dict."""
    # sample input dict
    stream_dict = {
        "goodsName": "优诺优丝黄桃果粒风味发酵乳",
        "measureProperty": 0,
        "qty": 1,
        "price": 25.9,
        "weight": 560,  # unit: grams
        "barcode": "6931806801024",
        "video": "https://ieemoo-ai.obs.cn-east-3.myhuaweicloud.com/videos/20231009/04/04_20231009-082149_21f2ca35-f2c2-4386-8497-3e7a3b407f03_4901872831197.mp4",
        "goodsPic": "https://ieemoo-storage.obs.cn-east-3.myhuaweicloud.com/lhpic/6931806801024.jpg",
        "measureUnit": "组",
        "goodsSpec": "405g"
    }

    track_imgs = stream_pipeline(stream_dict)

    # Bound the download time and fail loudly on HTTP errors, so an error page
    # is never handed to PIL as if it were image bytes.
    response = requests.get(stream_dict['goodsPic'], timeout=GOODS_PIC_TIMEOUT_S)
    response.raise_for_status()
    std_img = Image.open(BytesIO(response.content))

    description_dict = get_product_description(std_img, track_imgs)
    print(description_dict)


if __name__ == "__main__":
    main()