183 lines
5.8 KiB
Python
183 lines
5.8 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""
|
||
|
||
@author: LiChen
|
||
"""
|
||
import json
|
||
import os
|
||
import pickle
|
||
import numpy as np
|
||
|
||
import sys
|
||
sys.path.append(r"D:\DetectTracking\contrast")
|
||
|
||
from config import config as conf
|
||
# from img_data import library_imgs, temp_imgs, main_library_imgs, main_imgs_2
|
||
# from test_logic import initModel,getFeatureList
|
||
from model import resnet18
|
||
import torch
|
||
from PIL import Image
|
||
|
||
device = conf.device
|
||
|
||
def initModel():
    """Create the ResNet-18 network and load the evaluation checkpoint.

    The network is moved to the module-level ``device``, its weights are read
    from ``conf.test_model`` (mapped onto ``conf.device``), and it is switched
    to evaluation mode before being returned.

    :return: the loaded network in eval mode
    """
    net = resnet18()
    net = net.to(device)
    weights = torch.load(conf.test_model, map_location=conf.device)
    net.load_state_dict(weights)
    net.eval()
    return net
|
||
|
||
|
||
from PIL import Image
|
||
|
||
|
||
def convert_rgba_to_rgb(image_path, output_path=None):
    """Convert a 4-channel (RGBA) image to 3-channel RGB and save it.

    Images whose mode is not ``'RGBA'`` are left untouched and nothing is
    written.

    :param image_path: path of the input image
    :param output_path: where to save the converted image; when ``None``
        (the default) the input file is overwritten in place
    """
    target_path = image_path if output_path is None else output_path
    # The context manager releases the file handle even if conversion fails.
    with Image.open(image_path) as img:
        if img.mode != 'RGBA':
            return
        # convert('RGB') drops the alpha channel and yields a converted copy.
        img_rgb = img.convert('RGB')
    # Save only after the source handle is closed so that overwriting the
    # input file in place does not write to a file that is still open.
    img_rgb.save(target_path)
    print(f"Image converted from RGBA to RGB and saved to {target_path}")
|
||
|
||
|
||
def test_preprocess(images: list, actionModel=False) -> torch.Tensor:
    """Apply ``conf.test_transform`` to every image and stack the results.

    Items that fail to load or transform are reported and skipped, so the
    returned batch may contain fewer entries than ``images``.

    :param images: image file paths; when ``actionModel`` is true the items
        are handed to the transform as-is (presumably already-loaded images
        -- TODO confirm against callers)
    :param actionModel: when true, skip opening each item with PIL
    :return: the transformed images stacked with ``torch.stack``
    :raises RuntimeError: if no image could be preprocessed
    """
    res = []
    for img in images:
        try:
            print(img)
            if actionModel:
                im = conf.test_transform(img)
            else:
                # Close the file handle as soon as the transform is done.
                with Image.open(img) as opened:
                    im = conf.test_transform(opened)
        except Exception as exc:
            # Best effort: one unreadable image must not abort the batch,
            # but the failure is reported instead of being hidden.
            print(f"Skipping {img}: {exc}")
            continue
        res.append(im)
    if not res:
        # torch.stack rejects an empty list; fail with an explicit message.
        raise RuntimeError("test_preprocess: no image could be preprocessed")
    return torch.stack(res)
|
||
|
||
|
||
def inference(images, model, actionModel=False):
    """Preprocess ``images`` and run them through ``model``.

    :param images: items accepted by ``test_preprocess``
    :param model: callable network applied to the stacked batch
    :param actionModel: forwarded unchanged to ``test_preprocess``
    :return: the model output for the batch
    """
    data = test_preprocess(images, actionModel)
    # Always move the batch to the configured device (initModel places the
    # network on conf.device); this is a no-op when it is already there and
    # does not depend on whether CUDA happens to be installed.
    data = data.to(conf.device)
    # Inference only: skip autograd bookkeeping to save memory.
    with torch.no_grad():
        features = model(data)
    return features
|
||
|
||
|
||
def group_image(images, batch=64) -> list:
    """Split ``images`` into consecutive chunks of at most ``batch`` items.

    :param images: sequence of image paths (any sliceable sequence)
    :param batch: maximum chunk length
    :return: list of slices of ``images``; the last one may be shorter
    """
    total = len(images)
    # Slicing clamps at the end of the sequence, so the final chunk is
    # automatically truncated to the remaining items.
    return [images[start:start + batch] for start in range(0, total, batch)]
|
||
|
||
def getFeatureList(barList, imgList, model):
    """Extract a feature array for every image of every barcode.

    :param barList: barcode identifiers; only the length is used, to size
        the result
    :param imgList: one collection of images per barcode, in the same order
        as ``barList``
    :param model: network handed to ``inference``
    :return: list with one entry per barcode, each a list of numpy arrays
    """
    featList = [[] for _ in range(len(barList))]
    for index, imgs in enumerate(imgList):
        for group in group_image(imgs):
            feat_tensor = inference(group, model)
            # .cpu() is a no-op for tensors already in host memory, so no
            # device check is needed (and comparing a torch.device with the
            # string 'cpu' is not a reliable check anyway).
            featList[index].extend(
                fe.squeeze().detach().cpu().numpy() for fe in feat_tensor
            )
    return featList
|
||
|
||
def get_files(folder):
    """Collect the usable image files of every leaf directory under ``folder``.

    A leaf directory is one that has files and no sub-directories. Each file
    is passed through ``convert_rgba_to_rgb``; files for which that call
    raises are reported and left out of the result.

    :param folder: root directory to walk
    :return: dict mapping leaf directory name to its list of file paths
    """
    file_dict = {}
    cnt = 0
    for root, dirs, files in os.walk(folder):
        folder_name = os.path.basename(root)  # name of the current directory
        print(folder_name)
        # Only directories with files and without sub-directories count.
        if dirs or not files:
            continue
        # Build a fresh list instead of removing from the list being
        # iterated: removing in place skips the entry after every failure.
        valid_files = []
        for file in files:
            file_name = os.path.join(root, file)
            try:
                convert_rgba_to_rgb(file_name)
            except Exception as exc:
                print(f"Skipping {file_name}: {exc}")
                continue
            valid_files.append(file_name)
        cnt += len(valid_files)
        # NOTE(review): keyed by base name, so two leaf directories with the
        # same name in different parents overwrite each other.
        file_dict[folder_name] = valid_files
    print(cnt)
    return file_dict
|
||
|
||
def normalize(queFeatList):
    """Scale every feature vector to unit L2 norm, in place.

    :param queFeatList: list of lists of numpy feature arrays
    :return: the same ``queFeatList`` object, with each array replaced by
        itself divided by its ``np.linalg.norm``
    """
    for group in queFeatList:
        for pos, vector in enumerate(group):
            group[pos] = vector / np.linalg.norm(vector)
    return queFeatList
|
||
def img2feature(imgs_dict, model, barcode_flag):
    """Turn a barcode -> images mapping into barcodes and normalized features.

    :param imgs_dict: mapping whose keys are barcode ids and whose values are
        the images belonging to each barcode
    :param model: network passed on to ``getFeatureList``
    :param barcode_flag: accepted but not used by this function
    :return: ``(barcodes, features)`` where ``features`` is the output of
        ``getFeatureList`` after ``normalize``
    :raises ValueError: if ``imgs_dict`` is empty
    """
    if len(imgs_dict) == 0:
        raise ValueError("No imgs files provided")
    barcodes = list(imgs_dict.keys())
    image_groups = list(imgs_dict.values())
    features = normalize(getFeatureList(barcodes, image_groups, model))
    return barcodes, features
|
||
|
||
|
||
def createFeatureDict(imgs_dict, model, barcode_flag=False,
                      output_path='data_0909.json'):
    """Extract features for every barcode and write them to a JSON file.

    The written document has the form
    ``{"total": [{"key": barcode, "value": [feature, ...]}, ...]}`` where
    each feature is a numpy array converted with ``tolist()``.

    :param imgs_dict: ``{barcode1: [img1_1 ... img1_n],
        barcode2: [img2_1 ... img2_n]}``
    :param model: network used for feature extraction
    :param barcode_flag: forwarded to ``img2feature``
    :param output_path: JSON file to write; defaults to the previously
        hard-coded ``data_0909.json``
    """
    barcode_list, feats_list = img2feature(imgs_dict, model,
                                           barcode_flag=barcode_flag)
    value_list = []
    for barcode, feats in zip(barcode_list, feats_list):
        value_list.append({
            'key': barcode,
            # numpy arrays are not JSON serializable; store plain lists.
            'value': [feat.tolist() for feat in feats],
        })
    dicts_all = {'total': value_list}
    print('dicts_all', dicts_all)
    # Write mode, not append: appending a second JSON document to an
    # existing file makes the whole file unparseable.
    with open(output_path, 'w') as json_file:
        json.dump(dicts_all, json_file)
|
||
|
||
|
||
def read_pkl_file(file_path):
    """Load and return the object stored in a pickle file.

    NOTE: unpickling can execute arbitrary code; only use this on files from
    a trusted source.

    :param file_path: path of the pickle file
    :return: the unpickled object
    """
    with open(file_path, 'rb') as handle:
        return pickle.load(handle)
|
||
|
||
|
||
if __name__ == "__main__":
    # Save the mapping from image folder name to model feature vectors as a
    # JSON file.
    base_dir = 'data/2000_train/base'
    createFeatureDict(get_files(base_dir), initModel(), barcode_flag=False)
    # =======================================================
    # ========= pkl to json =========
    # contents = read_pkl_file('dicts_list_1887.pkl')
    # print(contents)
    # with open('data_1887.json', 'w') as json_file:
    #     json.dump(contents, json_file)
|