guoqing bakeup
This commit is contained in:
182
contrast/utils/write_feature_json.py
Normal file
182
contrast/utils/write_feature_json.py
Normal file
@ -0,0 +1,182 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
|
||||
@author: LiChen
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
import numpy as np
|
||||
|
||||
import sys
|
||||
sys.path.append(r"D:\DetectTracking\contrast")
|
||||
|
||||
from config import config as conf
|
||||
# from img_data import library_imgs, temp_imgs, main_library_imgs, main_imgs_2
|
||||
# from test_logic import initModel,getFeatureList
|
||||
from model import resnet18
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
device = conf.device
|
||||
|
||||
def initModel():
|
||||
model = resnet18().to(device)
|
||||
model.load_state_dict(torch.load(conf.test_model, map_location=conf.device))
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def convert_rgba_to_rgb(image_path, output_path=None):
|
||||
"""
|
||||
将给定路径的4通道PNG图像转换为3通道,并保存到指定输出路径。
|
||||
|
||||
:param image_path: 输入图像的路径
|
||||
:param output_path: 转换后的图像保存路径
|
||||
"""
|
||||
# 打开图像
|
||||
img = Image.open(image_path)
|
||||
# 转换图像模式从RGBA到RGB
|
||||
# .convert('RGB')会丢弃Alpha通道并转换为纯RGB图像
|
||||
if img.mode == 'RGBA':
|
||||
# 转换为RGB模式
|
||||
img_rgb = img.convert('RGB')
|
||||
# 保存转换后的图像
|
||||
img_rgb.save(image_path)
|
||||
print(f"Image converted from RGBA to RGB and saved to {image_path}")
|
||||
# else:
|
||||
# # 如果已经是RGB或其他模式,直接保存
|
||||
# img.save(image_path)
|
||||
# print(f"Image already in {img.mode} mode, saved to {image_path}")
|
||||
|
||||
|
||||
def test_preprocess(images: list, actionModel=False) -> torch.Tensor:
|
||||
res = []
|
||||
for img in images:
|
||||
try:
|
||||
print(img)
|
||||
im = conf.test_transform(img) if actionModel else conf.test_transform(Image.open(img))
|
||||
res.append(im)
|
||||
except:
|
||||
continue
|
||||
data = torch.stack(res)
|
||||
return data
|
||||
|
||||
|
||||
def inference(images, model, actionModel=False):
|
||||
data = test_preprocess(images, actionModel)
|
||||
if torch.cuda.is_available():
|
||||
data = data.to(conf.device)
|
||||
features = model(data)
|
||||
return features
|
||||
|
||||
|
||||
def group_image(images, batch=64) -> list:
|
||||
"""Group image paths by batch size"""
|
||||
size = len(images)
|
||||
res = []
|
||||
for i in range(0, size, batch):
|
||||
end = min(batch + i, size)
|
||||
res.append(images[i:end])
|
||||
return res
|
||||
|
||||
def getFeatureList(barList, imgList, model):
|
||||
featList = [[] for i in range(len(barList))]
|
||||
for index, feat in enumerate(imgList):
|
||||
groups = group_image(feat)
|
||||
for group in groups:
|
||||
feat_tensor = inference(group, model)
|
||||
for fe in feat_tensor:
|
||||
if fe.device == 'cpu':
|
||||
fe_np = fe.squeeze().detach().numpy()
|
||||
else:
|
||||
fe_np = fe.squeeze().detach().cpu().numpy()
|
||||
featList[index].append(fe_np)
|
||||
return featList
|
||||
|
||||
def get_files(folder):
|
||||
file_dict = {}
|
||||
cnt = 0
|
||||
# barcode_list = ['6944649700065', '6924743915848', '6920459905012', '6901285991219', '6924882406269']
|
||||
for root, dirs, files in os.walk(folder):
|
||||
|
||||
folder_name = os.path.basename(root) # 获取当前文件夹名称
|
||||
print(folder_name)
|
||||
# with open('main_barcode.txt','a') as f:
|
||||
# f.write(folder_name + '\n')
|
||||
|
||||
# if len(dirs) == 0 and len(files) > 0 and folder_name in barcode_list: # 如果该文件夹没有子文件夹且有文件
|
||||
if len(dirs) == 0 and len(files) > 0: # 如果该文件夹没有子文件夹且有文件
|
||||
|
||||
file_names = [os.path.join(root, file) for file in files] # 获取所有文件名
|
||||
for file_name in file_names:
|
||||
try:
|
||||
convert_rgba_to_rgb(file_name)
|
||||
except:
|
||||
file_names.remove(file_name)
|
||||
cnt += len(file_names)
|
||||
file_dict[folder_name] = file_names
|
||||
print(cnt)
|
||||
|
||||
return file_dict
|
||||
|
||||
def normalize(queFeatList):
|
||||
for num1 in range(len(queFeatList)):
|
||||
for num2 in range(len(queFeatList[num1])):
|
||||
queFeatList[num1][num2] = queFeatList[num1][num2] / np.linalg.norm(queFeatList[num1][num2])
|
||||
return queFeatList
|
||||
def img2feature(imgs_dict, model, barcode_flag):
|
||||
if not len(imgs_dict) > 0:
|
||||
raise ValueError("No imgs files provided")
|
||||
queBarIdList = list(imgs_dict.keys())
|
||||
queImgsList = list(imgs_dict.values())
|
||||
queFeatList = getFeatureList(queBarIdList, queImgsList, model)
|
||||
queFeatList = normalize(queFeatList)
|
||||
return queBarIdList, queFeatList
|
||||
|
||||
|
||||
def createFeatureDict(imgs_dict, model,
|
||||
barcode_flag=False): ##imgs->{barcode1:[img1_1...img1_n], barcode2:[img2_1...img2_n]}
|
||||
dicts_all = {}
|
||||
value_list = []
|
||||
barcode_list, imgs_list = img2feature(imgs_dict, model, barcode_flag=False)
|
||||
for i in range(len(barcode_list)):
|
||||
dicts = {}
|
||||
|
||||
imgs_list_ = []
|
||||
for j in range(len(imgs_list[i])):
|
||||
imgs_list_.append(imgs_list[i][j].tolist())
|
||||
# with open('feature.txt','a') as f:
|
||||
# f.write(str(imgs_list[i][j].tolist())+'\n')
|
||||
|
||||
dicts['key'] = barcode_list[i]
|
||||
dicts['value'] = imgs_list_
|
||||
value_list.append(dicts)
|
||||
dicts_all['total'] = value_list
|
||||
print('dicts_all', dicts_all)
|
||||
with open('data_0909.json', 'a') as json_file:
|
||||
json.dump(dicts_all, json_file)
|
||||
|
||||
|
||||
def read_pkl_file(file_path):
|
||||
with open(file_path, 'rb') as file:
|
||||
data = pickle.load(file)
|
||||
return data
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
###将图片名称和模型推理特征向量字典存为json文件
|
||||
img_path = 'data/2000_train/base'
|
||||
imgs_dict = get_files(img_path)
|
||||
# print('imgs_dict', imgs_dict)
|
||||
model = initModel()
|
||||
createFeatureDict(imgs_dict, model, barcode_flag=False)
|
||||
###=======================================================
|
||||
# ## =========pkl转json================
|
||||
# contents = read_pkl_file('dicts_list_1887.pkl')
|
||||
# print(contents)
|
||||
# with open('data_1887.json', 'w') as json_file:
|
||||
# json.dump(contents, json_file)
|
Reference in New Issue
Block a user