# -*- coding: utf-8 -*-
"""
@author: LiChen

Extract a feature vector for every image found in the leaf directories of a
folder tree using a ResNet-18 model, L2-normalize them, and write them to a
JSON file grouped by leaf-folder name (used as the barcode key).
"""
import json
import os
import pickle
import sys

import numpy as np
import torch
from PIL import Image

# NOTE(review): machine-specific path; the project-local `config` and `model`
# modules are resolved from here. Consider making this configurable.
sys.path.append(r"D:\DetectTracking\contrast")
from config import config as conf  # noqa: E402
from model import resnet18  # noqa: E402

device = conf.device


class EmptyBatchError(RuntimeError):
    """Raised when no image in a batch could be loaded and transformed.

    Subclasses RuntimeError so callers that previously caught the error raised by
    ``torch.stack([])`` keep working.
    """


def initModel():
    """Build the ResNet-18 model, load weights from ``conf.test_model`` and set eval mode.

    :return: the model, placed on ``conf.device``
    """
    model = resnet18().to(device)
    model.load_state_dict(torch.load(conf.test_model, map_location=conf.device))
    model.eval()
    return model


def convert_rgba_to_rgb(image_path, output_path=None):
    """Convert a 4-channel RGBA image to 3-channel RGB and save it.

    Images whose mode is not ``RGBA`` are left untouched (nothing is written).

    :param image_path: path of the input image
    :param output_path: where to save the converted image; defaults to
        ``image_path`` (overwrite in place)
    :raises OSError: if the file cannot be opened or is not a recognised image
    """
    target = output_path if output_path is not None else image_path
    with Image.open(image_path) as img:
        if img.mode != 'RGBA':
            return
        # .convert('RGB') drops the alpha channel and returns a new in-memory image,
        # so the source file can be closed before we (possibly) overwrite it.
        img_rgb = img.convert('RGB')
    img_rgb.save(target)
    print(f"Image converted from RGBA to RGB and saved to {target}")


def test_preprocess(images: list, actionModel=False) -> torch.Tensor:
    """Apply ``conf.test_transform`` to each image and stack the results into one batch.

    Images that fail to load or transform are skipped (with a message) rather than
    aborting the whole batch.

    :param images: image file paths, or already-loaded images when ``actionModel`` is True
    :param actionModel: if True, items of ``images`` are passed to the transform
        directly instead of being opened from disk
    :return: the transformed images stacked along a new first dimension
    :raises EmptyBatchError: if none of the images could be processed
    """
    res = []
    for img in images:
        try:
            if actionModel:
                res.append(conf.test_transform(img))
            else:
                # Context manager closes the file handle once the transform has run.
                with Image.open(img) as im:
                    res.append(conf.test_transform(im))
        except Exception as exc:  # best-effort: skip unreadable / untransformable images
            print(f"Skipping image {img!r}: {exc}")
    if not res:
        raise EmptyBatchError("No valid images in batch")
    return torch.stack(res)


def inference(images, model, actionModel=False):
    """Run ``model`` on a preprocessed batch of ``images`` and return its output.

    :raises EmptyBatchError: if no image in ``images`` could be processed
    """
    # The model lives on `device` (see initModel), so the input must too;
    # `.to` is a no-op when the tensor is already there.
    data = test_preprocess(images, actionModel).to(device)
    # Features are only read, never back-propagated: skip building autograd graphs.
    with torch.no_grad():
        features = model(data)
    return features


def group_image(images, batch=64) -> list:
    """Split ``images`` into consecutive chunks of at most ``batch`` items."""
    return [images[start:start + batch] for start in range(0, len(images), batch)]


def getFeatureList(barList, imgList, model):
    """Compute one feature vector per image, grouped like ``imgList``.

    :param barList: barcode ids; used to size the output and in log messages
    :param imgList: list parallel to ``barList``; each entry is a list of image paths
    :param model: model passed to :func:`inference`
    :return: list of lists of numpy arrays; ``featList[i]`` holds the features of
        the images in ``imgList[i]`` that could be loaded
    """
    featList = [[] for _ in range(len(barList))]
    for index, img_paths in enumerate(imgList):
        for group in group_image(img_paths):
            try:
                feat_tensor = inference(group, model)
            except EmptyBatchError as exc:  # every image in this chunk was unreadable
                print(f"Skipping batch for {barList[index]}: {exc}")
                continue
            for fe in feat_tensor:
                # .cpu() is a no-op for tensors already on the CPU.
                featList[index].append(fe.squeeze().detach().cpu().numpy())
    return featList


def get_files(folder):
    """Collect file paths from every leaf directory (no sub-directories) under ``folder``.

    RGBA images are converted to RGB in place; files that cannot be opened as
    images are left out of the result.

    :param folder: root directory to walk
    :return: ``{leaf_folder_name: [file paths]}``
    """
    file_dict = {}
    cnt = 0
    for root, dirs, files in os.walk(folder):
        folder_name = os.path.basename(root)  # name of the current directory
        print(folder_name)
        # Only directories with no sub-directories and at least one file are collected.
        if dirs or not files:
            continue
        # Build a new list instead of removing from the one being iterated,
        # which would silently skip the entry after each removed file.
        file_names = []
        for file in files:
            file_name = os.path.join(root, file)
            try:
                convert_rgba_to_rgb(file_name)
            except Exception as exc:  # not an image / unreadable: drop it
                print(f"Skipping {file_name}: {exc}")
                continue
            file_names.append(file_name)
        cnt += len(file_names)
        # NOTE(review): leaf folders that share a basename overwrite each other here.
        file_dict[folder_name] = file_names
    print(cnt)
    return file_dict


def normalize(queFeatList):
    """L2-normalize every feature vector in place and return the same list.

    Zero-norm vectors are left unchanged instead of being turned into NaNs.
    """
    for feats in queFeatList:
        for i, vec in enumerate(feats):
            norm = np.linalg.norm(vec)
            if norm > 0:
                feats[i] = vec / norm
    return queFeatList


def img2feature(imgs_dict, model, barcode_flag):
    """Extract normalized features for ``{barcode: [image paths]}``.

    :param imgs_dict: mapping of barcode -> list of image paths
    :param model: model passed to :func:`getFeatureList`
    :param barcode_flag: currently unused; kept for interface compatibility
    :return: ``(barcodes, features)`` where ``features[i]`` belongs to ``barcodes[i]``
    :raises ValueError: if ``imgs_dict`` is empty
    """
    if not imgs_dict:
        raise ValueError("No imgs files provided")
    queBarIdList = list(imgs_dict.keys())
    queImgsList = list(imgs_dict.values())
    queFeatList = getFeatureList(queBarIdList, queImgsList, model)
    queFeatList = normalize(queFeatList)
    return queBarIdList, queFeatList


def createFeatureDict(imgs_dict, model, barcode_flag=False, output_path='data_0909.json'):
    """Extract features for ``imgs_dict`` and write them to a JSON file.

    Output format::

        {"total": [{"key": barcode, "value": [[f1, f2, ...], ...]}, ...]}

    :param imgs_dict: ``{barcode1: [img1_1 ... img1_n], barcode2: [img2_1 ... img2_n]}``
    :param model: model used for feature extraction
    :param barcode_flag: forwarded to :func:`img2feature`
    :param output_path: JSON file to write (overwritten)
    :return: the dictionary that was written
    """
    barcode_list, imgs_list = img2feature(imgs_dict, model, barcode_flag=barcode_flag)
    value_list = [
        {'key': barcode, 'value': [feat.tolist() for feat in feats]}
        for barcode, feats in zip(barcode_list, imgs_list)
    ]
    dicts_all = {'total': value_list}
    # Write mode, not append: appending a second JSON document to the same file
    # would make it unparseable by json.load.
    with open(output_path, 'w') as json_file:
        json.dump(dicts_all, json_file)
    print(f"Saved features for {len(value_list)} barcodes to {output_path}")
    return dicts_all


def read_pkl_file(file_path):
    """Load and return the object stored in a pickle file.

    WARNING: unpickling can execute arbitrary code; only use on trusted files.
    """
    with open(file_path, 'rb') as file:
        data = pickle.load(file)
    return data


if __name__ == "__main__":
    # Save each barcode's model-inferred image feature vectors to a JSON file.
    img_path = 'data/2000_train/base'
    imgs_dict = get_files(img_path)
    model = initModel()
    createFeatureDict(imgs_dict, model, barcode_flag=False)