"""Product analysis for supermarket checkout video using Qwen2-VL.

The module loads the Qwen2-VL-7B-Instruct model once at import time and
exposes helpers that:

* pick the best-quality image out of a set of tracked product images,
* extract a structured description of the product and compare it against
  a reference ("standard") product picture.
"""
import ast
import random
import re
from io import BytesIO

import requests
import torch
from accelerate import init_empty_weights, load_checkpoint_in_model
from PIL import Image
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor

from stream_pipeline import stream_pipeline

_MODEL_NAME = "Qwen/Qwen2-VL-7B-Instruct"

# Upper bound on how many candidate images are shown to the model at once.
_MAX_CANDIDATE_IMAGES = 5

# Timeout (seconds) for downloading the reference product picture.
_DOWNLOAD_TIMEOUT = 30

# Matches the first Markdown code fence in the model output, with an optional
# language tag. An unterminated fence (generation cut off) is accepted too.
_CODE_FENCE_RE = re.compile(r"```[A-Za-z0-9_+-]*\s*(.*?)\s*(?:```|\Z)", re.DOTALL)

_BEST_IMAGE_SYSTEM_PROMPT = (
    "你是一个在超市工作的chatbot,你现在需要帮助顾客找到一张质量最好的商品图像。"
    "一个好的商品图像需要满足以下条件: "
    "1. 文字清晰且连贯。 "
    "2. 商品图案清晰可识别。 "
    "3. 商品可提取的描述信息丰富。 "
    "基于以上条件,从多张图像中筛选出最好的图像,然后以dict输出该图像的索引信息,key为'index'。"
)

# NOTE: the template below separates 'item1', 'item2' and 'is_Same' with
# commas so that the format the model is asked to imitate is a valid Python
# dict literal (the reply is parsed with ast.literal_eval).
_DESCRIPTION_SYSTEM_PROMPT = (
    "你是一个在超市工作的chatbot,你现在需要提取图像中商品的信息,"
    "信息需要按照以下python dict的格式输出,如果 信息模糊不清则输出'未知': "
    "{ "
    "'item1': { "
    "'Text': 第一张图像中商品中提取出的文字信息, "
    "'Color': 第一张图像中商品的颜色, "
    "'Shape': 第一张图像中商品的形状, "
    "'Material': 第一张图像中商品的材质, "
    "'Category': 第一张图像中商品的类别, "
    "}, "
    "'item2': { "
    "'Text': 第二张图像中商品中提取出的文字信息, "
    "'Color': 第二张图像中商品的颜色, "
    "'Shape': 第二张图像中商品的形状, "
    "'Material': 第二张图像中商品的材质, "
    "'Category': 第二张图像中商品的类别, "
    "}, "
    "'is_Same': 首先判断'Color'是否一致,如果不一致则返回False,"
    "如果一致则判断是否以上两个dict的['Text', 'Shape', 'Material', 'Category']key中"
    "至少有3个相同则输出True, 否则输出False。 "
    "} "
)

# Load the model on the available device(s).
model = Qwen2VLForConditionalGeneration.from_pretrained(
    _MODEL_NAME,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)

# Default processor.
# NOTE(review): attn_implementation is a model option; it is kept here only
# to preserve the original call - confirm it can be dropped.
processor = AutoProcessor.from_pretrained(
    _MODEL_NAME, attn_implementation="flash_attention_2"
)


def _image_placeholder():
    """Return one image content entry for a chat message."""
    return {
        "type": "image",
        "min_pixels": 224 * 224,
        "max_pixels": 800 * 800,
    }


def _parse_model_literal(output_text):
    """Parse a Python literal out of the model's textual reply.

    If the reply contains a Markdown code fence, only the fenced payload is
    parsed; otherwise the whole reply is parsed.

    ``str.strip('```python\\n')`` is deliberately NOT used: ``strip`` treats
    its argument as a set of characters, not as a prefix, and it fails as
    soon as the model adds any prose before or after the fence.

    Raises:
        ValueError, SyntaxError: propagated from ``ast.literal_eval`` when
            the payload is not a valid Python literal.
    """
    fenced = _CODE_FENCE_RE.search(output_text)
    payload = fenced.group(1) if fenced else output_text
    return ast.literal_eval(payload.strip())


def _load_image(source):
    """Open ``source`` as an RGB PIL image.

    ``source`` may be an ``http(s)://`` URL, which is downloaded first, or
    anything ``PIL.Image.open`` accepts (e.g. a local path).

    Raises:
        requests.RequestException: if the download fails or the server
            answers with an error status.
    """
    if isinstance(source, str) and source.startswith(("http://", "https://")):
        response = requests.get(source, timeout=_DOWNLOAD_TIMEOUT)
        response.raise_for_status()
        source = BytesIO(response.content)
    return Image.open(source).convert("RGB")


def qwen_prompt(img_list, messages, max_new_tokens=256):
    """Run one chat round through the model and return the decoded reply.

    Args:
        img_list: images handed to the processor together with the rendered
            chat text (one per image entry in ``messages``).
        messages: chat messages in the format accepted by
            ``processor.apply_chat_template``.
        max_new_tokens: generation budget; defaults to 256.

    Returns:
        The decoded text of the first (only) sequence in the batch, with the
        prompt tokens removed.
    """
    # Preparation for inference.
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = processor(
        text=[text],
        images=img_list,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to("cuda")

    # Inference: generation of the output.
    with torch.no_grad():
        generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)

    # generate() returns prompt + completion; keep only the completion.
    generated_ids_trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed,
        add_special_tokens=False,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )

    # Drop the tensor references before asking CUDA to release cached blocks.
    del inputs, generated_ids, generated_ids_trimmed
    torch.cuda.empty_cache()

    return output_text[0]


def get_best_image(track_imgs):
    """Ask the model which of the tracked images has the best quality.

    At most ``_MAX_CANDIDATE_IMAGES`` randomly sampled images are shown to
    the model. The model is expected to answer with a dict holding the
    1-based position of the best image under the key ``'index'``.

    Args:
        track_imgs: non-empty sequence of candidate images.

    Returns:
        The selected element of (the sampled subset of) ``track_imgs``.

    Raises:
        ValueError: if ``track_imgs`` is empty, or the reply is not a valid
            Python literal.
        SyntaxError: if the reply cannot be parsed.
    """
    if len(track_imgs) == 0:
        raise ValueError("track_imgs must contain at least one image")
    if len(track_imgs) == 1:
        # Nothing to choose from; skip the model round-trip.
        return track_imgs[0]
    if len(track_imgs) >= _MAX_CANDIDATE_IMAGES:
        track_imgs = random.sample(track_imgs, _MAX_CANDIDATE_IMAGES)

    messages = [
        {"role": "system", "content": _BEST_IMAGE_SYSTEM_PROMPT},
        # Images are sent as the user turn, matching get_product_description.
        {"role": "user", "content": [_image_placeholder() for _ in track_imgs]},
    ]

    output_text = qwen_prompt(track_imgs, messages)
    output_dict = _parse_model_literal(output_text)

    # Clamp into the valid 1-based range so that neither an overshoot nor a
    # 0 / negative answer silently selects an image from the end of the list.
    index = max(1, min(int(output_dict['index']), len(track_imgs)))
    return track_imgs[index - 1]


def get_product_description(std_img, track_imgs):
    """Describe the product in two images and judge whether they match.

    The first image is the reference picture ``std_img`` and the second is
    the best tracked image. When ``std_img`` is ``None`` the best tracked
    image is used for both slots.

    Args:
        std_img: reference product image, or ``None``.
        track_imgs: non-empty sequence of tracked candidate images.

    Returns:
        The parsed model reply; the prompt asks for a dict with the keys
        ``'item1'``, ``'item2'`` and ``'is_Same'``.

    Raises:
        ValueError, SyntaxError: if the reply is not a valid Python literal.
    """
    messages = [
        {"role": "system", "content": _DESCRIPTION_SYSTEM_PROMPT},
        {"role": "user", "content": [_image_placeholder(), _image_placeholder()]},
    ]

    best_img = get_best_image(track_imgs)
    first_img = std_img if std_img is not None else best_img
    img_list = [first_img, best_img]

    output_text = qwen_prompt(img_list, messages)
    return _parse_model_literal(output_text)


def item_analysis(stream_dict):
    """Analyse one checkout record.

    Args:
        stream_dict: record passed to ``stream_pipeline``; its optional
            ``'goodsPic'`` entry is the reference product picture, given
            either as an ``http(s)`` URL or as a local path.

    Returns:
        The product description dict, or ``{}`` when ``stream_pipeline``
        yields no images.
    """
    track_imgs = stream_pipeline(stream_dict)
    if len(track_imgs) == 0:
        return {}

    goods_pic = stream_dict.get('goodsPic')
    std_img = _load_image(goods_pic) if goods_pic is not None else None

    return get_product_description(std_img, track_imgs)


def main():
    """Run the analysis on a sample record and print the result."""
    # Sample input dict.
    stream_dict = {
        "goodsName": "优诺优丝黄桃果粒风味发酵乳",
        "measureProperty": 0,
        "qty": 1,
        "price": 25.9,
        "weight": 560,  # unit: grams
        "barcode": "6931806801024",
        "video": "https://ieemoo-ai.obs.cn-east-3.myhuaweicloud.com/videos/20231009/04/04_20231009-082149_21f2ca35-f2c2-4386-8497-3e7a3b407f03_4901872831197.mp4",
        "goodsPic": "https://ieemoo-storage.obs.cn-east-3.myhuaweicloud.com/lhpic/6931806801024.jpg",
        "measureUnit": "组",
        "goodsSpec": "405g",
    }
    result = item_analysis(stream_dict)
    print(result)


if __name__ == "__main__":
    main()