guoqing backup

This commit is contained in:
王庆刚
2024-10-04 12:12:44 +08:00
parent 09e92d63b3
commit 390c5d2d94
37 changed files with 1409 additions and 219 deletions


@@ -0,0 +1,182 @@
# -*- coding: utf-8 -*-
"""
@author: LiChen
"""
import json
import os
import pickle
import numpy as np
import sys
sys.path.append(r"D:\DetectTracking\contrast")
from config import config as conf
# from img_data import library_imgs, temp_imgs, main_library_imgs, main_imgs_2
# from test_logic import initModel,getFeatureList
from model import resnet18
import torch
from PIL import Image
device = conf.device
def initModel():
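    """Load resnet18, restore the weights from conf.test_model, and return the model in eval mode on conf.device."""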
model = resnet18().to(device)
model.load_state_dict(torch.load(conf.test_model, map_location=conf.device))
model.eval()
return model
def convert_rgba_to_rgb(image_path, output_path=None):
"""
将给定路径的4通道PNG图像转换为3通道并保存到指定输出路径。
:param image_path: 输入图像的路径
:param output_path: 转换后的图像保存路径
"""
# 打开图像
img = Image.open(image_path)
# 转换图像模式从RGBA到RGB
# .convert('RGB')会丢弃Alpha通道并转换为纯RGB图像
if img.mode == 'RGBA':
# 转换为RGB模式
img_rgb = img.convert('RGB')
# 保存转换后的图像
img_rgb.save(image_path)
print(f"Image converted from RGBA to RGB and saved to {image_path}")
# else:
# # 如果已经是RGB或其他模式直接保存
# img.save(image_path)
# print(f"Image already in {img.mode} mode, saved to {image_path}")
def test_preprocess(images: list, actionModel=False) -> torch.Tensor:
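    """Apply conf.test_transform to each item in images (PIL images when actionModel=True, otherwise file paths) and stack the results into a single batch tensor."""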
res = []
for img in images:
try:
print(img)
im = conf.test_transform(img) if actionModel else conf.test_transform(Image.open(img))
res.append(im)
        except Exception:
            # Skip images that fail to open or transform
            continue
data = torch.stack(res)
return data
def inference(images, model, actionModel=False):
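    """Preprocess a batch of images, move it to conf.device when CUDA is available, and return the model's feature output."""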
data = test_preprocess(images, actionModel)
if torch.cuda.is_available():
data = data.to(conf.device)
features = model(data)
return features
def group_image(images, batch=64) -> list:
"""Group image paths by batch size"""
size = len(images)
res = []
for i in range(0, size, batch):
end = min(batch + i, size)
res.append(images[i:end])
return res
def getFeatureList(barList, imgList, model):
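    """For each barcode in barList, run inference over the image paths in the matching imgList entry and collect the features as numpy arrays."""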
    featList = [[] for _ in range(len(barList))]
    for index, img_paths in enumerate(imgList):
        groups = group_image(img_paths)
for group in groups:
feat_tensor = inference(group, model)
            for fe in feat_tensor:
                # .cpu() is a no-op for tensors already on the CPU, so this covers both devices
                fe_np = fe.squeeze().detach().cpu().numpy()
                featList[index].append(fe_np)
return featList
def get_files(folder):
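    """Walk folder and map each leaf directory name (barcode) to its list of image paths, converting RGBA images to RGB on the way."""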
file_dict = {}
cnt = 0
# barcode_list = ['6944649700065', '6924743915848', '6920459905012', '6901285991219', '6924882406269']
for root, dirs, files in os.walk(folder):
        folder_name = os.path.basename(root)  # name of the current directory
        print(folder_name)
        # with open('main_barcode.txt','a') as f:
        #     f.write(folder_name + '\n')
        # if len(dirs) == 0 and len(files) > 0 and folder_name in barcode_list:  # leaf directory that contains files
        if len(dirs) == 0 and len(files) > 0:  # leaf directory (no subdirectories) that contains files
            file_names = [os.path.join(root, file) for file in files]  # full paths of all files
            valid_files = []
            for file_name in file_names:
                try:
                    convert_rgba_to_rgb(file_name)
                    valid_files.append(file_name)
                except Exception:
                    # Skip unreadable images instead of removing items from the list being iterated
                    continue
            cnt += len(valid_files)
            file_dict[folder_name] = valid_files
print(cnt)
return file_dict
def normalize(queFeatList):
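    """L2-normalize every feature vector in the nested feature list and return it."""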
for num1 in range(len(queFeatList)):
for num2 in range(len(queFeatList[num1])):
queFeatList[num1][num2] = queFeatList[num1][num2] / np.linalg.norm(queFeatList[num1][num2])
return queFeatList
def img2feature(imgs_dict, model, barcode_flag):
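    """Return the barcodes and their L2-normalized feature lists; barcode_flag is accepted but not used inside this function."""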
    if not imgs_dict:
        raise ValueError("No image files provided")
queBarIdList = list(imgs_dict.keys())
queImgsList = list(imgs_dict.values())
queFeatList = getFeatureList(queBarIdList, queImgsList, model)
queFeatList = normalize(queFeatList)
return queBarIdList, queFeatList
def createFeatureDict(imgs_dict, model,
barcode_flag=False): ##imgs->{barcode1:[img1_1...img1_n], barcode2:[img2_1...img2_n]}
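    """Build {'total': [{'key': barcode, 'value': [feature, ...]}, ...]} and append it to data_0909.json."""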
dicts_all = {}
value_list = []
    barcode_list, imgs_list = img2feature(imgs_dict, model, barcode_flag=barcode_flag)
for i in range(len(barcode_list)):
dicts = {}
imgs_list_ = []
for j in range(len(imgs_list[i])):
imgs_list_.append(imgs_list[i][j].tolist())
# with open('feature.txt','a') as f:
# f.write(str(imgs_list[i][j].tolist())+'\n')
dicts['key'] = barcode_list[i]
dicts['value'] = imgs_list_
value_list.append(dicts)
dicts_all['total'] = value_list
print('dicts_all', dicts_all)
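    # NOTE: append mode means repeated runs write multiple JSON documents into the same file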
with open('data_0909.json', 'a') as json_file:
json.dump(dicts_all, json_file)
def read_pkl_file(file_path):
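    """Load and return the object pickled at file_path."""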
with open(file_path, 'rb') as file:
data = pickle.load(file)
return data
if __name__ == "__main__":
    ### Save the dictionary of image names and model-inferred feature vectors to a JSON file
img_path = 'data/2000_train/base'
imgs_dict = get_files(img_path)
# print('imgs_dict', imgs_dict)
model = initModel()
createFeatureDict(imgs_dict, model, barcode_flag=False)
###=======================================================
    # ## ========= pkl to json ================
# contents = read_pkl_file('dicts_list_1887.pkl')
# print(contents)
# with open('data_1887.json', 'w') as json_file:
# json.dump(contents, json_file)