This commit is contained in:
lee
2025-06-11 15:23:50 +08:00
commit 37ecef40f7
79 changed files with 26981 additions and 0 deletions

0
tools/__init__.py Normal file
View File

68
tools/dataset.py Normal file
View File

@ -0,0 +1,68 @@
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import torchvision.transforms.functional as F
import torchvision.transforms as T
# from config import config as conf
import torch
def pad_to_square(img):
    """Pad ``img`` with black borders so that it becomes exactly square.

    The shorter dimension is padded on both sides up to the longer one. When
    the difference is odd, the extra pixel goes to the right/bottom edge so
    the result is always ``max(w, h) x max(w, h)``.

    :param img: image exposing ``.size`` as ``(width, height)``
        (presumably a PIL image -- confirm against callers).
    :return: the padded image, as returned by ``F.pad``.
    """
    w, h = img.size
    side = max(w, h)
    left = (side - w) // 2
    top = (side - h) // 2
    # Using ``// 2`` on both sides loses one pixel for odd differences;
    # give the remainder to the right/bottom edge instead.
    right = side - w - left
    bottom = side - h - top
    padding = [left, top, right, bottom]  # (left, top, right, bottom)
    return F.pad(img, padding, fill=0, padding_mode='constant')
def get_transform(cfg):
    """Build the training and evaluation transform pipelines.

    All settings are read from ``cfg['transform']``: ``img_size``,
    ``RandomHorizontalFlip``, ``RandomRotation``, ``ColorJitter``,
    ``img_mean`` and ``img_std``.

    :param cfg: configuration mapping with a ``'transform'`` section.
    :return: ``(train_transform, test_transform)`` as ``T.Compose`` objects.
    """
    settings = cfg['transform']
    side = settings['img_size']

    train_steps = [
        T.Lambda(pad_to_square),  # pad to a square canvas before resizing
        T.ToTensor(),
        T.Resize((side, side), antialias=True),
        T.RandomHorizontalFlip(p=settings['RandomHorizontalFlip']),
        T.RandomRotation(settings['RandomRotation']),
        T.ColorJitter(brightness=settings['ColorJitter']),
        T.ConvertImageDtype(torch.float32),
        T.Normalize(mean=[settings['img_mean']], std=[settings['img_std']]),
    ]
    # NOTE(review): the evaluation pipeline does not pad to square, unlike
    # training -- confirm this asymmetry is intended.
    test_steps = [
        T.ToTensor(),
        T.Resize((side, side), antialias=True),
        T.ConvertImageDtype(torch.float32),
        T.Normalize(mean=[settings['img_mean']], std=[settings['img_std']]),
    ]
    return T.Compose(train_steps), T.Compose(test_steps)
def load_data(training=True, cfg=None):
    """Create an ``ImageFolder`` data loader for training or validation.

    :param training: ``True`` selects the training directory, transform and
        batch size; ``False`` selects the validation ones.
    :param cfg: configuration mapping with ``'data'``, ``'base'`` and
        ``'transform'`` sections (required despite the ``None`` default).
    :return: ``(loader, class_num)`` where ``class_num`` is the number of
        class folders found by ``ImageFolder``.
    """
    train_transform, test_transform = get_transform(cfg)
    data_cfg = cfg['data']
    if training:
        dataroot = data_cfg['data_train_dir']
        transform = train_transform
        batch_size = data_cfg['train_batch_size']
    else:
        dataroot = data_cfg['data_val_dir']
        transform = test_transform
        batch_size = data_cfg['val_batch_size']

    data = ImageFolder(dataroot, transform=transform)
    # NOTE(review): shuffling and drop_last are applied to validation too.
    loader = DataLoader(
        data,
        batch_size=batch_size,
        shuffle=True,
        pin_memory=cfg['base']['pin_memory'],
        num_workers=data_cfg['num_workers'],
        drop_last=True,
    )
    return loader, len(data.classes)
# def load_gift_data(action):
# train_data = ImageFolder(conf.train_gift_root, transform=conf.train_transform)
# train_dataset = DataLoader(train_data, batch_size=conf.train_gift_batchsize, shuffle=True,
# pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True)
# val_data = ImageFolder(conf.test_gift_root, transform=conf.test_transform)
# val_dataset = DataLoader(val_data, batch_size=conf.val_gift_batchsize, shuffle=True,
# pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True)
# test_data = ImageFolder(conf.test_gift_root, transform=conf.test_transform)
# test_dataset = DataLoader(test_data, batch_size=conf.test_gift_batchsize, shuffle=True,
# pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True)
# return train_dataset, val_dataset, test_dataset

10
tools/dataset.txt Normal file
View File

@ -0,0 +1,10 @@
./quant_imgs/20179457_20240924-110903_back_addGood_b82d2842766e_80_15583929052_tid-8_fid-72_bid-3.jpg
./quant_imgs/6928926002103_20240309-195044_front_returnGood_70f75407ef0e_225_18120111822_14_01.jpg
./quant_imgs/6928926002103_20240309-212145_front_returnGood_70f75407ef0e_225_18120111822_11_01.jpg
./quant_imgs/6928947479083_20241017-133830_front_returnGood_5478c9a48b7e_10_13799009402_tid-1_fid-20_bid-1.jpg
./quant_imgs/6928947479083_20241018-110450_front_addGood_5478c9a48c28_165_13773168720_tid-6_fid-36_bid-1.jpg
./quant_imgs/6930044166421_20240117-141516_c6a23f41-5b16-44c6-a03e-c32c25763442_back_returnGood_6930044166421_17_01.jpg
./quant_imgs/6930044166421_20240308-150916_back_returnGood_70f75407ef0e_175_13815402763_7_01.jpg
./quant_imgs/6930044168920_20240117-165633_3303629b-5fbd-423b-913d-8a64c1aa51dc_front_addGood_6930044168920_26_01.jpg
./quant_imgs/6930058201507_20240305-175434_front_addGood_70f75407ef0e_95_18120111822_28_01.jpg
./quant_imgs/6930639267885_20241014-120446_back_addGood_5478c9a48c3e_135_13773168720_tid-5_fid-99_bid-0.jpg

112
tools/fp32comparefp16.py Normal file
View File

@ -0,0 +1,112 @@
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from test_ori import group_image, init_model, featurize
from config import config as conf
import json
import os.path as osp
def compare_fp16_fp32(values_pf16, values_pf32, dataTest):
    """Print how far a half-precision feature drifts from its fp32 original.

    Prints the Euclidean distance between the two vectors and their cosine
    similarity. The fp16 tensor is promoted to fp32 first so that all
    arithmetic runs in one dtype.

    :param values_pf16: 1-D half-precision feature tensor.
    :param values_pf32: 1-D float32 feature tensor of the same length.
    :param dataTest: when falsy the function does nothing.
    :return: ``None``.
    """
    if not dataTest:
        return
    fp16_as_fp32 = values_pf16.float()
    norm_values_pf16 = torch.norm(fp16_as_fp32, p=2)
    norm_values_pf32 = torch.norm(values_pf32, p=2)
    # Distance between the vectors themselves; subtracting the two scalar
    # norms (as before) is ~0 even for very different vectors.
    euclidean_distance = torch.norm(fp16_as_fp32 - values_pf32, p=2)
    print(f"欧几里得距离: {euclidean_distance}")
    cosine_sim = torch.dot(fp16_as_fp32, values_pf32) / (norm_values_pf16 * norm_values_pf32)
    print(f"余弦相似度: {cosine_sim}")
def cosin_metric(x1, x2, fp32=True):
    """Return the cosine similarity of two numpy vectors.

    :param x1: first vector.
    :param x2: second vector.
    :param fp32: when ``True`` the inputs are used as given; when ``False``
        both are cast to ``float16`` first, so the whole computation runs in
        half precision.
    :return: ``dot(x1, x2) / (|x1| * |x2|)``.
    """
    if fp32:
        left, right = x1, x2
    else:
        left, right = x1.astype(np.float16), x2.astype(np.float16)
    denominator = np.linalg.norm(left) * np.linalg.norm(right)
    return np.dot(left, right) / denominator
def deal_group_pair(pairList1, pairList2):
    """Build fp32 and fp16 cosine-similarity matrices between two groups.

    Each element of both groups must support ``.cpu().numpy()``. Row ``i``,
    column ``j`` holds the similarity between ``pairList1[i]`` and
    ``pairList2[j]``.

    :return: ``(similarity_fp32, similarity_fp16)`` as numpy arrays.
    """
    rows_fp32, rows_fp16 = [], []
    for left in pairList1:
        rows_fp32.append(
            [cosin_metric(left.cpu().numpy(), right.cpu().numpy(), True) for right in pairList2])
        rows_fp16.append(
            [cosin_metric(left.cpu().numpy(), right.cpu().numpy(), False) for right in pairList2])
    return np.array(rows_fp32), np.array(rows_fp16)
def compute_group_accuracy(content_list_read, model):
    """Print fp32/fp16 similarity matrices for one 'same' and one 'diff' group.

    Walks ``content_list_read`` and processes only the first entry labelled
    ``'1'`` (same) and the first labelled ``'0'`` (diff); the label is the
    last item of each entry and items 0 and 1 are the two image lists.
    Entries that fail are reported and skipped instead of being silently
    ignored.

    :param content_list_read: iterable of ``[images_a, images_b, ..., label]``.
    :param model: feature extractor handed to ``featurize``.
    :return: ``(allSimilarity, allLabel)`` -- both are always empty lists;
        the function only prints its findings.
    """
    allSimilarity, allLabel = [], []
    # Labels still waiting for their first successfully processed group.
    pending = {'1': 'same', '0': 'diff'}
    for data_loaded in content_list_read:
        try:
            label = str(data_loaded[-1])
            if label not in pending:
                continue
            one_group_list = []
            for i in range(2):
                images = [osp.join(conf.test_val, img) for img in data_loaded[i]]
                group = group_image(images, conf.test_batch_size)
                # Only the first batch of each image list is featurized.
                d = featurize(group[0], conf.test_transform, model, conf.device)
                one_group_list.append(d.values())
            # Mark the label as done before comparing, as before: a failure
            # in the comparison does not trigger a retry on a later entry.
            tag = pending.pop(label)
            allsimilarity_fp32, allsimilarity_fp16 = deal_group_pair(one_group_list[0], one_group_list[1])
            print('fp32 {}-- >'.format(tag), allsimilarity_fp32)
            print('fp16 {}-- >'.format(tag), allsimilarity_fp16)
        except Exception as e:
            # Best-effort: keep going, but say which failure was skipped.
            print('skip group: {!r}'.format(e))
            continue
    return allSimilarity, allLabel
def get_feature_list(imgPth):
    """Extract features for the images found under ``imgPth``.

    Only the first batch (``conf.test_batch_size`` images) is featurized.

    :param imgPth: directory that is walked recursively for image files.
    :return: whatever ``featurize`` returns (callers use ``.values()`` on it).
    """
    batches = group_image(get_files(imgPth), conf.test_batch_size)
    model = init_model()
    model.eval()
    return featurize(batches[0], conf.test_transform, model, conf.device)
def get_files(imgPth):
    """Recursively list every file below ``imgPth``.

    :param imgPth: root directory to walk.
    :return: list of paths built as ``<directory><os.sep><file name>``;
        empty when the directory does not exist or holds no files.
    """
    return [
        os.sep.join([folder, file_name])
        for folder, _subdirs, file_names in os.walk(imgPth)
        for file_name in file_names
    ]
import pdb
def compare(imgPth, group=False):
    """Compare fp16 against fp32 features for a folder or for grouped pairs.

    :param imgPth: image directory, used only when ``group`` is ``False``.
    :param group: when ``False``, featurize ``imgPth`` and print the
        fp16/fp32 drift of every feature; when ``True``, read
        ``conf.test_group_json`` and run ``compute_group_accuracy``.
    :return: ``None``.
    """
    if not group:
        # get_feature_list builds its own model, so none is loaded here.
        fe = get_feature_list(imgPth)
        for value in fe.values():
            value_pf32 = value.cpu()
            compare_fp16_fp32(value_pf32.half(), value_pf32, dataTest=True)
        return

    model = init_model()
    model.eval()
    filename = conf.test_group_json
    with open(filename, 'r', encoding='utf-8') as file:
        content_list_read = json.load(file)
    compute_group_accuracy(content_list_read, model)
if __name__ == '__main__':
    # Ad-hoc entry point: print the fp16-vs-fp32 drift for the images of one
    # sample folder.
    imgPth = './data/test/inner/3701375401900'
    compare(imgPth)

369
tools/gift_assessment.py Normal file
View File

@ -0,0 +1,369 @@
import os
import pdb
import shutil
import sys
sys.path.append('../model')
import matplotlib.pyplot as plt
import numpy as np
from model.mlp import Net2, Net3, Net4
from model import resnet18
import torch
from gift_data_pretreatment import getFeatureList
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def init_model(pkl_flag):
    """Load the resnet18 backbone and the gift classifier from fixed paths.

    Every checkpoint is loaded with ``map_location=device`` so that weights
    saved on a GPU can also be loaded on a CPU-only machine.

    :param pkl_flag: truthy selects the ``Net3`` head (``action2``
        checkpoint); falsy selects the ``Net4`` model (``action3``).
    :return: ``(res_model, gift_model)``.
    """
    res_pth = r"../checkpoints/resnet18_1009/best.pth"
    if pkl_flag:
        gift_pth = r'../checkpoints/gift_model/action2/gift_v11.pth'
        gift_model = Net3(pretrained=True, num_classes=1)
        # Previously loaded without map_location, unlike the other loads.
        gift_model.load_state_dict(torch.load(gift_pth, map_location=device))
    else:
        gift_pth = r'../checkpoints/gift_model/action3/best.pth'
        gift_model = Net4('resnet18', True, True)  # pretrained model
        try:
            # Checkpoints saved from a multi-GPU wrapper carry a 'module.'
            # prefix on every key; strip it before loading.
            print('>>multiple_cards load pre model <<')
            gift_model.load_state_dict({k.replace('module.', ''): v for k, v in
                                        torch.load(gift_pth, map_location=device).items()})
        except Exception as e:
            # Fall back to loading the state dict with its keys untouched.
            print('>> load pre model <<')
            gift_model.load_state_dict(torch.load(gift_pth, map_location=device))
    res_model = resnet18()
    res_model.load_state_dict({k.replace('module.', ''): v for k, v in
                               torch.load(res_pth, map_location=device).items()})
    return res_model, gift_model
def showHist(nongifts, gifts):
    """Show two stacked 50-bin histograms: non-gift scores above gift scores.

    :param nongifts: scores for the upper panel.
    :param gifts: scores for the lower panel.
    """
    fig, axes = plt.subplots(2, 1)
    panels = (
        (nongifts, 'blue', 'nongifts'),
        (gifts, 'green', 'gifts'),
    )
    for axis, (values, edge, title) in zip(axes, panels):
        axis.hist(values, bins=50, edgecolor=edge)
        axis.set_xlim([-0.1, 1])
        axis.set_title(title)
    plt.show()
def calculate_precision_recall(nongift, gift, points):
    """Compute precision and recall of "score > threshold means gift".

    A sample is predicted as gift only when its score is strictly greater
    than the threshold; scores equal to the threshold are counted as
    negatives (previously they were counted nowhere, which inflated recall).

    :param nongift: numpy array of scores for true non-gift samples.
    :param gift: numpy array of scores for true gift samples.
    :param points: iterable of thresholds.
    :return: ``(precision, recall)`` lists aligned with ``points``; both are
        ``0`` for a threshold with no true positives.
    """
    precision, recall = [], []
    for point in points:
        TP = np.sum(gift > point)
        FN = np.sum(gift <= point)
        FP = np.sum(nongift > point)
        TN = np.sum(nongift <= point)
        if TP == 0:
            precision.append(0)
            recall.append(0)
        else:
            precision.append(TP / (TP + FP))
            recall.append(TP / (TP + FN))
        # One report line per threshold (the duplicate at 0.5 was dropped).
        print("point >> {} TP>>{}, FP>>{}, TN>>{}, FN>>{}".format(point, TP, FP, TN, FN))
    return precision, recall
def showgrid(all_prec, all_recall, points):
    """Plot precision and recall against the threshold and show the figure.

    The last element of every series is left out of the plot.

    :param all_prec: precision per threshold.
    :param all_recall: recall per threshold.
    :param points: the thresholds themselves.
    """
    plt.figure(figsize=(10, 6))
    thresholds = points[:-1]
    curves = (
        (all_prec, 'blue', 'precision'),
        (all_recall, 'red', 'recall'),
    )
    for series, colour, label in curves:
        plt.plot(thresholds, series[:-1], color=colour, label=label)
    plt.legend()
    plt.xlabel('threshold')
    plt.grid(True, linestyle='--', alpha=0.5)
    plt.show()
    plt.close()
def discriminate_action(roots):  # decide whether the event is an add or a return
    """Classify a recording folder as ``'add'`` or ``'return'``.

    Reads ``<roots>/process.data`` and looks at the first line containing
    ``weightValue``: the integer between the last ``:`` and the following
    ``,`` decides the result.

    :param roots: directory containing ``process.data``.
    :return: ``'add'`` when the weight is positive, ``'return'`` otherwise,
        or ``None`` when no ``weightValue`` line exists.
    """
    pth = os.sep.join([roots, 'process.data'])
    with open(pth, 'r') as f:
        lines = f.readlines()
    for raw in lines:
        content = raw.strip()
        if 'weightValue' not in content:
            continue
        weight = int(content.split(":")[-1].split(',')[0])
        return 'add' if weight > 0 else 'return'
    return None
def median(lst):
    """Return the median of ``lst`` without modifying it.

    :param lst: non-empty sequence of numbers.
    :return: the middle element for odd lengths, the average of the two
        middle elements for even lengths.
    :raises IndexError: if ``lst`` is empty.
    """
    ordered = sorted(lst)
    half, is_odd = divmod(len(ordered), 2)
    if is_odd:
        # Odd length: the median is the single middle element.
        return ordered[half]
    # Even length: average the two elements around the middle.
    return (ordered[half - 1] + ordered[half]) / 2
def get_special_data(data, p):
    """Summarise a score sequence with the statistic named by ``p``.

    Sequences with at most 5 items are summarised by their plain mean,
    whatever ``p`` is. Longer sequences use only their first half
    (``round(len / 2)`` items) with ``p`` in ``'max'``, ``'average'`` or
    ``'median'``.

    :param data: non-empty list of numbers.
    :param p: statistic name.
    :return: the statistic, or ``None`` for an unknown ``p`` on long input.
    """
    length = len(data)
    if length <= 5:
        return sum(data) / len(data)

    head = data[:round(length * 0.5)]
    reducers = {
        'max': max,
        'average': lambda values: sum(values) / len(values),
        'median': lambda values: median(values),
    }
    reducer = reducers.get(p)
    if reducer is None:
        return None
    return reducer(head)
def read_data_file(pth):
    """Parse the ``free_gift__result`` scores from a tracking output file.

    The line has the form ``free_gift__result:<v1>,<v2>,...``. The last
    comma-separated field is dropped when the path contains
    ``0_tracking_output.data``, otherwise the last two are dropped. When
    several such lines exist, the last one wins.

    :param pth: path of the data file.
    :return: list of floats, empty when no matching line exists.
    """
    trailing = 1 if '0_tracking_output.data' in pth else 2
    with open(pth, 'r') as data_file:
        lines = data_file.readlines()

    result = []
    for line in lines:
        fields = line.split(':')
        if fields[0] != 'free_gift__result':
            continue
        result = [float(token) for token in fields[1].split(',')[:-trailing]]
    return result
def get_tracking_data(pth):
    """Read every 64-value feature row from a tracking output file.

    A line qualifies when splitting it on ``,`` yields exactly 65 fields;
    the last field is discarded and the other 64 are parsed as floats.

    :param pth: path of the data file.
    :return: list of 64-float lists, in file order.
    """
    with open(pth, 'r') as data_file:
        lines = data_file.readlines()
    split_lines = (line.split(',') for line in lines)
    return [
        [float(item) for item in fields[:-1]]
        for fields in split_lines
        if len(fields) == 65
    ]
def clean_reurn_data(pth):
    """Delete every leaf folder below ``pth`` that records a 'return' event.

    A leaf folder is one without sub-directories; it is classified with
    ``discriminate_action`` and removed recursively when the result is
    ``'return'``. Destructive: removed folders cannot be recovered.

    :param pth: root directory to clean.
    """
    for roots, dirs, _files in os.walk(pth):
        if len(dirs) != 0:
            continue
        if discriminate_action(roots) == 'return':
            shutil.rmtree(roots)
def get_gift_files(pth):  # analyse the result files written by a test run
    """Collect gift scores from every leaf folder below ``pth``.

    For each leaf folder the action (``add``/``return``) comes from
    ``discriminate_action`` (anything other than ``'add'`` counts as
    return). The files ``0_tracking_output.data`` and
    ``1_tracking_output.data`` (the original notes call them back and front
    camera) are each parsed once, yielding:

    * ``<action>_tracking_output_<cam>``: all raw scores, concatenated;
    * ``<action>_special_output_<cam>``: one ``'average'`` summary per
      non-empty file.

    A histogram of the raw add scores of both cameras is shown before
    returning.

    :param pth: root directory of the test results.
    :return: dict with the eight lists described above.
    """
    camera_by_file = {'0_tracking_output.data': '0', '1_tracking_output.data': '1'}

    # Create the keys in the same order the result dict always had.
    comprehensive_dicts = {}
    for camera in ('0', '1'):
        for kind in ('special', 'tracking'):
            for action in ('add', 'return'):
                comprehensive_dicts['{}_{}_output_{}'.format(action, kind, camera)] = []

    for roots, dirs, files in os.walk(pth):
        if len(dirs) != 0:
            continue
        action = 'add' if discriminate_action(roots) == 'add' else 'return'
        for file in files:
            camera = camera_by_file.get(file)
            if camera is None:
                continue
            # Parse once and reuse for both the summary and the raw list.
            result = read_data_file(os.path.join(roots, file))
            if len(result) != 0:
                comprehensive_dicts['{}_special_output_{}'.format(action, camera)].append(
                    get_special_data(result, 'average'))
            comprehensive_dicts['{}_tracking_output_{}'.format(action, camera)] += result

    showHist(np.array(comprehensive_dicts['add_tracking_output_0']),
             np.array(comprehensive_dicts['add_tracking_output_1']))
    return comprehensive_dicts
def get_feature_array(img_pth_lists, res_model, gift_model, pkl_flag=True):
    """Compute one feature/score array per list of image paths.

    :param img_pth_lists: list of image-path lists (one list per sample).
    :param res_model: backbone passed to ``getFeatureList`` when ``pkl_flag``
        is true.
    :param gift_model: gift head applied to the backbone output when
        ``pkl_flag`` is true; otherwise passed to ``getFeatureList`` directly.
    :param pkl_flag: selects between the two pipelines below.
    :return: list of numpy arrays; samples that fail or yield nothing are
        left out, so the result may be shorter than ``img_pth_lists``.
    """
    features_np = []
    if pkl_flag:
        for img_lists in img_pth_lists:
            # print(img_lists)
            fe_nps = getFeatureList(None, img_lists, res_model)
            # fe_nps.squeeze()
            try:
                # Only the first batch is used; columns from 256 onwards are
                # kept. Fails (and the sample is skipped) when there is no
                # batch or the batch is not 2-D.
                fe_nps = fe_nps[0][:, 256:]
            except Exception as e:
                print(e)
                continue
            fe_nps = torch.from_numpy(fe_nps)
            # The remaining columns are reshaped to (N, 64, 13, 13), so each
            # row must hold 64 * 13 * 13 values.
            fe_nps = fe_nps.view(fe_nps.shape[0], 64, 13, 13)
            if len(fe_nps):
                fe_np = gift_model(fe_nps)
                # NOTE(review): .numpy() presumably requires the output to be
                # on the CPU -- confirm gift_model is never moved to a GPU.
                fe_np = np.squeeze(fe_np.detach().numpy())
                features_np.append(fe_np)
    else:
        for img_lists in img_pth_lists:
            fe_nps = getFeatureList(None, img_lists, gift_model)
            if len(fe_nps) > 0:
                # Merge the per-batch arrays into one array for the sample.
                fe_nps = np.concatenate(fe_nps)
                features_np.append(fe_nps)
    return features_np
import pickle
def create_gift_subimg_np(data_pth, pkl_flag):
    """Featurize gift / non-gift sub-images and cache the results as pickles.

    Folders whose path contains ``'subimg'`` are collected: those that also
    contain ``'commodity'`` count as non-gift, those containing
    ``'Havegift'`` as gift. Each folder becomes one sample. Results are
    written to ``<data_pth>/nongift.pkl`` and ``<data_pth>/gift.pkl``.

    :param data_pth: dataset root that is walked recursively.
    :param pkl_flag: forwarded to ``init_model`` and ``get_feature_array``.
    """
    gift_array_pth = os.path.join(data_pth, 'gift.pkl')
    nongift_array_pth = os.path.join(data_pth, 'nongift.pkl')
    res_model, gift_model = init_model(pkl_flag)
    res_model = res_model.eval()
    gift_model = gift_model.eval()

    gift_lists, nongift_lists = [], []
    for root, _dirs, files in os.walk(data_pth):
        if 'subimg' not in root:
            continue
        if 'commodity' in root:
            print("commodity >> {}".format(root))
            nongift_lists.append([os.sep.join([root, name]) for name in files])
        elif 'Havegift' in root:
            print("Havegift >> {}".format(root))
            gift_lists.append([os.sep.join([root, name]) for name in files])

    nongift = get_feature_array(nongift_lists, res_model, gift_model, pkl_flag)
    gift = get_feature_array(gift_lists, res_model, gift_model, pkl_flag)
    with open(nongift_array_pth, 'wb') as file:
        pickle.dump(nongift, file)
    with open(gift_array_pth, 'wb') as file:
        pickle.dump(gift, file)
def top_25_percent_mean(arr):
    """Return the highest quarter of ``arr``, sorted in descending order.

    Despite the name, the values themselves are returned; callers average
    them. At least one value is kept for any non-empty input, so inputs with
    fewer than four items no longer produce an empty slice (and a NaN mean
    downstream). A 0-d array is treated as a single value.

    :param arr: array-like of scores.
    :return: numpy array with the top ``max(1, int(len * 0.25))`` values,
        or an empty array for empty input.
    """
    # np.squeeze upstream can hand over a 0-d array, which np.sort rejects.
    values = np.atleast_1d(arr)
    # 1. Sort from high to low.
    sorted_arr = np.sort(values)[::-1]
    # 2. Size of the top quarter, but never zero for non-empty input.
    top_25_percent_length = int(len(sorted_arr) * 0.25)
    if top_25_percent_length == 0 and len(sorted_arr) > 0:
        top_25_percent_length = 1
    # 3. Keep the leading elements.
    return sorted_arr[:top_25_percent_length]
def assess_gift_subimg(data_pth, pkl_flag=False):  # analyse the segmented sub-images
    """Evaluate the gift classifier on cached sub-image scores.

    Loads ``gift.pkl`` / ``nongift.pkl`` from ``data_pth`` (regenerating
    both when either is missing), reduces every sample to the mean of its
    top-quartile scores, then shows the score histograms and the
    precision/recall curves for thresholds 0.01 .. 1.00.

    :param data_pth: dataset root holding (or receiving) the pickle caches.
    :param pkl_flag: forwarded to ``create_gift_subimg_np``.
    """
    points = (np.linspace(1, 100, 100)) / 100
    gift_pkl_pth = os.path.join(data_pth, 'gift.pkl')
    nongift_pkl_pth = os.path.join(data_pth, 'nongift.pkl')
    # Both caches are read below, so rebuild when either one is missing
    # (checking only gift.pkl crashed on a missing nongift.pkl).
    if not (os.path.exists(gift_pkl_pth) and os.path.exists(nongift_pkl_pth)):
        create_gift_subimg_np(data_pth, pkl_flag)
    # NOTE: pickle must only be used on these locally generated caches,
    # never on files from an untrusted source.
    with open(nongift_pkl_pth, 'rb') as f:
        nongift = pickle.load(f)
    with open(gift_pkl_pth, 'rb') as f:
        gift = pickle.load(f)

    # Per-sample statistic: mean of the top-quartile scores. The alternatives
    # tried before were np.median(items) and np.mean(items).
    nongift_mean = [np.mean(top_25_percent_mean(items)) for items in nongift]
    gift_mean = [np.mean(top_25_percent_mean(items)) for items in gift]

    showHist(np.array(nongift_mean), np.array(gift_mean))
    precision, recall = calculate_precision_recall(np.array(nongift_mean),
                                                   np.array(gift_mean),
                                                   points)
    showgrid(precision, recall, points)
def get_comprehensive_dicts(data_pth):
    """Run the ``Net3`` gift head on every tracked feature row and print it.

    Collects the 64-value rows of all ``0_tracking_output.data`` and
    ``1_tracking_output.data`` files in leaf folders below ``data_pth``,
    reshapes them to ``(N, 64, 1, 1)`` float32 on the CPU and prints the
    model output.

    :param data_pth: root directory of the recordings.
    :return: ``None``.
    """
    gift_pth = r'../checkpoints/gift_model/action2/best.pth'
    g_model = Net3(pretrained=True, num_classes=1)
    # Inference below is pinned to the CPU, so map the checkpoint there too.
    g_model.load_state_dict(torch.load(gift_pth, map_location='cpu'))
    g_model.eval()
    result = []
    file_name = ['0_tracking_output.data',
                 '1_tracking_output.data']
    for root, dirs, files in os.walk(data_pth):
        if len(dirs):
            continue
        for file in files:
            if file in file_name:
                print(os.path.join(root, file))
                result += get_tracking_data(os.path.join(root, file))
    result = torch.from_numpy(np.array(result))
    # Local renamed from ``input`` to avoid shadowing the builtin.
    batch = result.view(result.shape[0], 64, 1, 1)
    batch = batch.to('cpu')
    batch = batch.to(torch.float32)
    ji = g_model(batch)
    print(ji)
if __name__ == '__main__':
    # Ad-hoc entry point; switch between analyses by (un)commenting the
    # matching path and call below.
    # pth = r'\\192.168.1.28\\share\\测试视频数据以及日志\\各模块测试记录\\赠品测试\\20241203赠品测试数据\\赠品\\images'
    # pth = r'\\192.168.1.28\\share\\测试视频数据以及日志\\各模块测试记录\\赠品测试\\20241203赠品测试数据\\没有赠品的商品\\images'
    # pth = r'\\192.168.1.28\\share\\测试视频数据以及日志\\各模块测试记录\\赠品测试\\20241203赠品测试数据\\同样的商品没有捆绑赠品\\images'
    # pth = r'\\192.168.1.28\\share\\测试视频数据以及日志\\各模块测试记录\\赠品测试\\20241213赠品测试数据\\赠品'
    # pth = r'C:\Users\HP\Desktop\zengpin\1227'
    # get_gift_files(pth)
    # Analyse results from the sub-images
    pth = r'D:\Project\contrast_nettest\data\gift_test'
    assess_gift_subimg(pth)
    # Analyse results from the complete data set
    # pth = r'C:\Users\HP\Desktop\zengpin\1231'
    # get_comprehensive_dicts(pth)
    # Delete the 'return' recordings
    # pth = r'C:\Users\HP\Desktop\gift_test\20241213\非赠品'
    # clean_reurn_data(pth)

View File

@ -0,0 +1,92 @@
import torch
from config import config as conf
from PIL import Image
import numpy as np
def convert_rgba_to_rgb(image_path, output_path=None):
    """Convert a 4-channel RGBA image to 3-channel RGB and save it.

    Images in any other mode are left untouched and nothing is written.
    ``.convert('RGB')`` drops the alpha channel.

    :param image_path: path of the input image.
    :param output_path: where to save the converted image; when ``None``
        (the default) the input file is overwritten in place.
    """
    target_path = image_path if output_path is None else output_path
    # Close the source file before saving so that overwriting it in place
    # never races with an open handle.
    with Image.open(image_path) as img:
        if img.mode != 'RGBA':
            return
        img_rgb = img.convert('RGB')
    img_rgb.save(target_path)
def test_preprocess(images: list, actionModel=False) -> torch.Tensor:
    """Apply ``conf.test_transform`` to every image and stack the results.

    Images that cannot be opened or transformed are reported and skipped.

    :param images: image paths, or already-loaded images when
        ``actionModel`` is true.
    :param actionModel: when true, items are passed to the transform as-is;
        otherwise each item is opened with ``Image.open`` first.
    :return: tensor stacking the transformed images along a new first
        dimension.
    :raises RuntimeError: from ``torch.stack`` when no image could be
        processed.
    """
    res = []
    for img in images:
        try:
            im = conf.test_transform(img) if actionModel else conf.test_transform(Image.open(img))
        except Exception as e:
            # Best-effort batch: skip the bad item, but do not hide it, and
            # let KeyboardInterrupt/SystemExit propagate.
            print('skip image {!r}: {!r}'.format(img, e))
            continue
        res.append(im)
    data = torch.stack(res)
    return data
def inference(images, model, actionModel=False):
    """Preprocess ``images`` and run them through ``model``.

    The batch is moved to ``conf.device`` only when CUDA is available.

    :param images: items accepted by ``test_preprocess``.
    :param model: callable applied to the stacked batch.
    :param actionModel: forwarded to ``test_preprocess``.
    :return: the model output for the batch.
    """
    batch = test_preprocess(images, actionModel)
    if torch.cuda.is_available():
        batch = batch.to(conf.device)
    return model(batch)
def group_image(images, batch=64) -> list:
    """Split ``images`` into consecutive chunks of at most ``batch`` items.

    :param images: sliceable sequence (typically image paths).
    :param batch: maximum chunk size.
    :return: list of chunks; the last one may be shorter, and the list is
        empty for empty input.
    """
    return [images[start:start + batch] for start in range(0, len(images), batch)]
def normalize(queFeatList):
    """Scale every feature vector to unit L2 norm, in place.

    :param queFeatList: list of lists of numpy vectors; it is modified in
        place.
    :return: the same ``queFeatList`` object.
    """
    for group in queFeatList:
        for position, feature in enumerate(group):
            group[position] = feature / np.linalg.norm(feature)
    return queFeatList
def getFeatureList(barList, imgList, model):
    """Extract features for ``imgList`` in batches of 64.

    :param barList: unused; kept for interface compatibility.
    :param imgList: list of image paths.
    :param model: feature extractor passed to ``inference``.
    :return: list with one numpy array per batch, each being the squeezed
        model output for that batch.
    """
    fe_nps = []
    for group in group_image(imgList):
        feat_tensor = inference(group, model)
        # ``.cpu()`` is a no-op for tensors already on the CPU, so the former
        # ``device == 'cpu'`` string comparison (unreliable across torch
        # versions) is not needed.
        # NOTE(review): squeeze() also drops the batch axis when a batch
        # holds a single image -- confirm callers can handle that.
        fe_np = feat_tensor.squeeze().detach().cpu().numpy()
        fe_nps.append(fe_np)
    return fe_nps

118
tools/json_contrast.py Normal file
View File

@ -0,0 +1,118 @@
import json
import numpy as np
import matplotlib.pyplot as plt
import numpy as np
import random
def showHist(same, cross):
    """Show stacked 50-bin histograms of same- and cross-barcode scores.

    :param same: similarity scores of same-barcode pairs (upper panel).
    :param cross: similarity scores of cross-barcode pairs (lower panel).
    """
    fig, axes = plt.subplots(2, 1)
    panels = (
        (np.array(same), 'Same Barcode'),
        (np.array(cross), 'Cross Barcode'),
    )
    for axis, (values, title) in zip(axes, panels):
        axis.hist(values, bins=50, edgecolor='black')
        axis.set_xlim([-0.1, 1])
        axis.set_title(title)
    plt.show()
def showgrid(recall, recall_TN, PrecisePos, PreciseNeg, Correct, thresholds=None):
    """Plot the five metric curves against the threshold; save and show them.

    The figure is written to ``grid.png`` in the working directory.

    :param recall: TP / (TP + FN) per threshold.
    :param recall_TN: TN / (TN + FP) per threshold.
    :param PrecisePos: TP / (TP + FP) per threshold.
    :param PreciseNeg: TN / (TN + FN) per threshold.
    :param Correct: overall accuracy per threshold.
    :param thresholds: x-axis values, one per metric entry; when omitted the
        historical default of 50 points over [0, 1] is used.
    """
    if thresholds is None:
        x = np.linspace(start=0, stop=1.0, num=50, endpoint=True).tolist()
    else:
        x = list(thresholds)
    plt.figure(figsize=(10, 6))
    plt.plot(x, recall, color='red', label='recall:TP/TPFN')
    plt.plot(x, recall_TN, color='black', label='recall_TN:TN/TNFP')
    # Legend formulas corrected: precision divides by the predicted counts.
    plt.plot(x, PrecisePos, color='blue', label='PrecisePos:TP/TPFP')
    plt.plot(x, PreciseNeg, color='green', label='PreciseNeg:TN/TNFN')
    plt.plot(x, Correct, color='m', label='Correct(TN+TP)/(TPFN+TNFP)')
    plt.legend()
    plt.xlabel('threshold')
    plt.grid(True, linestyle='--', alpha=0.5)
    plt.savefig('grid.png')
    plt.show()
    plt.close()


def compute_accuracy_recall(score, labels):
    """Sweep 50 thresholds over [-1, 1] and plot the resulting metrics.

    A pair is predicted "same" when its score is strictly greater than the
    threshold and "different" otherwise, so every sample falls into exactly
    one of TP/FP/TN/FN. Metrics whose denominator is zero are reported as 0.

    :param score: numpy array of similarity scores.
    :param labels: numpy array of 1 (same barcode) / 0 (cross barcode),
        aligned with ``score``.
    :return: ``None``; prints the confusion counts and shows the plots.
    """
    squence = np.linspace(-1, 1, num=50)
    recall, PrecisePos, PreciseNeg, recall_TN, Correct = [], [], [], [], []
    is_same = (labels == 1)
    is_cross = (labels == 0)
    # Split by label instead of assuming the first half is "same".
    Same = score[is_same]
    Cross = score[is_cross]
    for th in squence:
        predicted_same = (score > th)
        predicted_cross = np.logical_not(predicted_same)
        TP = np.sum(np.logical_and(predicted_same, is_same))
        FN = np.sum(np.logical_and(predicted_cross, is_same))
        TN = np.sum(np.logical_and(predicted_cross, is_cross))
        FP = np.sum(np.logical_and(predicted_same, is_cross))
        print("Threshold:{} TP:{},FP:{},TN:{},FN:{}".format(th, TP, FP, TN, FN))
        total = TP + FP + TN + FN
        # Guard on the denominator: comparing a ratio with the string 'nan'
        # never matched, and accuracy must not be zeroed just because TP is.
        PrecisePos.append(TP / (TP + FP) if (TP + FP) else 0)
        PreciseNeg.append(TN / (TN + FN) if (TN + FN) else 0)
        recall.append(TP / (TP + FN) if (TP + FN) else 0)
        recall_TN.append(TN / (TN + FP) if (TN + FP) else 0)
        Correct.append((TP + TN) / total if total else 0)
    showHist(Same, Cross)
    # Plot against the thresholds that were actually evaluated.
    showgrid(recall, recall_TN, PrecisePos, PreciseNeg, Correct, thresholds=squence.tolist())
def get_similarity(features1, features2, n, m):
    """Compute all pairwise cosine similarities between two feature sets.

    :param features1: sequence of feature vectors.
    :param features2: sequence of feature vectors.
    :param n: index of the first set; used only for the debug print.
    :param m: index of the second set; used only for the debug print.
    :return: ``(mean_similarity, all_similarity)`` where ``all_similarity``
        is the flat row-major list of pair similarities.
    """
    features1 = np.array(features1)
    features2 = np.array(features2)
    all_similarity = [
        np.dot(left, right) / (np.linalg.norm(left) * np.linalg.norm(right))
        for left in features1
        for right in features2
    ]
    similarity_matrix = np.array(all_similarity).reshape(len(features1), len(features2))
    # Debug dump for one specific pair of sets.
    if n == 5 and m == 5:
        print(all_similarity)
    return np.mean(similarity_matrix), all_similarity
def deal_similarity(dicts):
    """Evaluate same-barcode against cross-barcode similarity for a library.

    Every barcode is compared with every barcode (including itself). Pair
    similarities are pooled into a same-barcode and a cross-barcode set, the
    cross set is randomly down-sampled to at most the size of the same set,
    and both are handed to ``compute_accuracy_recall``.

    :param dicts: mapping of barcode to its list of feature vectors.
    :return: ``None``; prints progress and the shape of the mean-similarity
        matrix.
    """
    all_similarity = []
    same_barcode, diff_barcode = [], []
    for n, (key1, value1) in enumerate(dicts.items()):
        print('key1 >> {}'.format(key1))
        row = []
        for m, (key2, value2) in enumerate(dicts.items()):
            print('key1 >> {} key2 >> {} peidui {}{}'.format(key1, key2, n, m))
            mean_similarity, pair_similarities = get_similarity(value1, value2, n, m)
            row.append(mean_similarity)
            if key1 == key2:
                # NOTE(review): this includes each vector paired with itself.
                same_barcode += pair_similarities
            else:
                diff_barcode += pair_similarities
        all_similarity.append(row)
    all_similarity = np.array(all_similarity)

    random.shuffle(diff_barcode)
    diff_sample = diff_barcode[:len(same_barcode)]
    # Size the label list from the sample actually taken: with fewer cross
    # pairs than same pairs, labels and scores used to differ in length.
    all_list = [1] * len(same_barcode) + [0] * len(diff_sample)
    all_score = same_barcode + diff_sample
    compute_accuracy_recall(np.array(all_score), np.array(all_list))
    print(all_similarity.shape)
# Load the barcode -> feature-list library and run the evaluation.
with open('../search_library/data_zhanting.json', 'r') as file:
    data = json.load(file)
# Each entry of data['total'] carries a 'key' (barcode) and a 'value'
# (feature vectors). The loop variable no longer shadows the builtin ``dict``.
dicts = {entry['key']: entry['value'] for entry in data['total']}
deal_similarity(dicts)

View File

@ -0,0 +1,63 @@
import pdb
import torch
import torch.nn as nn
from model import resnet18
from config import config as conf
from collections import OrderedDict
import cv2
def tranform_onnx_model(model_name, pretrained_weights='checkpoints/v3_small.pth'):
    """Export a trained model to ONNX and to a TorchScript trace.

    The ONNX file is written next to the checkpoint with the extension
    replaced by ``.onnx``; the traced model is saved alongside it as ``.pt``.

    :param model_name: ``'resnet18'`` is the only name with a constructor
        wired up; ``'gift_type2'`` / ``'gift_type3'`` only have their input
        shapes registered.
    :param pretrained_weights: path of the ``state_dict`` checkpoint.
    :raises ValueError: if no model can be built for ``model_name``
        (previously this surfaced as an unbound-variable error).
    """
    import os  # local import: this module does not import ``os`` at the top

    # Define the model.
    if model_name == 'resnet18':
        model = resnet18(scale=0.75)
    else:
        raise ValueError('no model constructor available for {!r}'.format(model_name))
    print('model_name >>> {}'.format(model_name))

    state_dict = torch.load(pretrained_weights, map_location=torch.device('cpu'))
    if conf.multiple_cards:
        model = model.to(torch.device('cpu'))
        # Multi-GPU checkpoints prefix keys with "module."; strip it only
        # where it is actually present.
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = k[7:] if k.startswith('module.') else k
            new_state_dict[name] = v
        state_dict = new_state_dict
    model.load_state_dict(state_dict)

    # Input size is channels * height * width, e.g. 3 * 224 * 224.
    input_shapes = {
        'gift_type2': [1, 64, 13, 13],
        'gift_type3': [1, 3, 224, 224],
    }
    input_shape = input_shapes.get(model_name, [1, 3, 224, 224])

    # Swap only the file extension; str.replace('pth', 'onnx') also rewrote
    # any 'pth' inside directory names.
    stem, _ext = os.path.splitext(pretrained_weights)
    output_file = stem + '.onnx'

    # Switch to inference mode before exporting or tracing.
    model.eval()
    torch.onnx.export(model,
                      torch.randn(input_shape),
                      output_file,
                      verbose=True,
                      input_names=['input'],
                      output_names=['output'])
    # Trace with the same input shape that was used for the ONNX export.
    trace_model = torch.jit.trace(model, torch.randn(input_shape))
    trace_model.save(output_file.replace('.onnx', '.pt'))
    print(f"Model exported to {output_file}")
if __name__ == '__main__':
    # Ad-hoc entry point: export the resnet18 checkpoint below.
    tranform_onnx_model(model_name='resnet18',  # one of ['resnet18', 'gift_type2', 'gift_type3']; gift_type2 judges from resnet18 intermediate features, gift_type3 runs inference on the original image
                        pretrained_weights='./checkpoints/resnet18_scale=1.0/best.pth')

View File

@ -0,0 +1,186 @@
import os
import pdb
import urllib
import traceback
import time
import sys
import numpy as np
import cv2
from config import config as conf
from rknn.api import RKNN
import config
# ONNX_MODEL = 'resnet50v2.onnx'
# RKNN_MODEL = 'resnet50v2.rknn'
ONNX_MODEL = 'checkpoints/resnet18_scale=1.0/best.onnx'
RKNN_MODEL = 'checkpoints/resnet18_scale=1.0/best.rknn'
# ONNX_MODEL = 'v3_small_0424.onnx'
# RKNN_MODEL = 'v3_small_0424.rknn'
def show_outputs(outputs):
    """Print the five largest values of ``outputs[0][0]`` with their indices.

    Each line has the form ``<index>: <value>``; values that are not
    positive are printed as ``-1: 0.0``.

    :param outputs: nested sequence whose ``[0][0]`` element is a numpy
        array with at least five values.
    """
    scores = outputs[0][0]
    ranked = sorted(scores, reverse=True)
    pieces = ['resnet50v2\n-----TOP 5-----\n']
    for rank in range(5):
        score = ranked[rank]
        # np.where yields one index array per dimension of ``scores``.
        hits = np.where(scores == score)
        for offset, hit in enumerate(hits):
            if rank + offset >= 5:
                break
            if score > 0:
                pieces.append('{}: {}\n'.format(hit, score))
            else:
                pieces.append('-1: 0.0\n')
    print(''.join(pieces))
def readable_speed(speed):
    """Format a transfer speed given in bytes/s as KB/s, MB/s or GB/s.

    A unit is used only when the value strictly exceeds 1024 of the next
    smaller unit; everything else is reported in KB/s.

    :param speed: speed in bytes per second (anything ``float()`` accepts).
    :return: string with two decimals and the unit.
    """
    speed_kbytes = float(speed) / 1024
    if speed_kbytes <= 1024:
        return "{:.2f} KB/s".format(speed_kbytes)
    speed_mbytes = speed_kbytes / 1024
    if speed_mbytes <= 1024:
        return "{:.2f} MB/s".format(speed_mbytes)
    return "{:.2f} GB/s".format(speed_mbytes / 1024)
def show_progress(blocknum, blocksize, totalsize):
    """Write one download progress line to stdout.

    Has the ``(blocknum, blocksize, totalsize)`` signature that the script
    passes to ``urllib.request.urlretrieve`` as its report hook. Relies on
    the module-level ``start_time`` set just before the download starts.

    :param blocknum: number of blocks transferred so far.
    :param blocksize: size of one block in bytes.
    :param totalsize: total size of the download in bytes.
    """
    recv_size = blocknum * blocksize
    speed = recv_size / (time.time() - start_time)
    speed_str = " Speed: {}".format(readable_speed(speed))

    progress = (recv_size / totalsize)
    progress_str = "{:.2f}%".format(progress * 100)
    # 50-character bar: '#' for the completed share, '-' for the rest.
    bar = ('#' * round(progress * 50)).ljust(50, '-')

    out = sys.stdout
    out.write(progress_str.ljust(8, ' ') + '[' + bar + ']' + speed_str)
    out.flush()
    out.write('\r\n')
if __name__ == '__main__':
    # ``import urllib`` alone does not guarantee that the ``request``
    # submodule is loaded; import it explicitly before using urlretrieve.
    import urllib.request

    # Create RKNN object
    rknn = RKNN(verbose=True)

    # If the ONNX model file does not exist, download the reference
    # resnet50v2 model to that path.
    # Download address:
    # https://s3.amazonaws.com/onnx-model-zoo/resnet/resnet50v2/resnet50v2.onnx
    if not os.path.exists(ONNX_MODEL):
        print('--> Download {}'.format(ONNX_MODEL))
        url = 'https://s3.amazonaws.com/onnx-model-zoo/resnet/resnet50v2/resnet50v2.onnx'
        download_file = ONNX_MODEL
        try:
            start_time = time.time()  # read by show_progress
            urllib.request.urlretrieve(url, download_file, show_progress)
        except Exception:
            print('Download {} failed.'.format(download_file))
            print(traceback.format_exc())
            exit(-1)
        print('done')

    # pre-process config
    print('--> config model')
    # For single-channel images mean/std could be set to [[127.5]].
    rknn.config(
        mean_values=[[127.5, 127.5, 127.5]],
        std_values=[[127.5, 127.5, 127.5]],
        target_platform='rk3588',
        model_pruning=False,
        compress_weight=False,
        single_core_mode=True)
    print('done')

    # Load model
    print('--> Loading model')
    ret = rknn.load_onnx(model=ONNX_MODEL)
    if ret != 0:
        print('Load model failed!')
        exit(ret)
    print('done')

    # Build model (quantized, calibrated with the images in dataset.txt)
    print('--> Building model')
    ret = rknn.build(do_quantization=True, dataset='./dataset.txt')
    # ret = rknn.build(do_quantization=False, dataset='./dataset.txt')
    if ret != 0:
        print('Build model failed!')
        exit(ret)
    print('done')

    # Export rknn model
    print('--> Export rknn model')
    ret = rknn.export_rknn(RKNN_MODEL)
    if ret != 0:
        print('Export rknn model failed!')
        exit(ret)
    print('done')

    # Set inputs
    img = cv2.imread('./dog_224x224.jpg')
    if img is None:
        # cv2.imread returns None for a missing/unreadable file; stop here
        # instead of failing later inside cv2.resize.
        print('Read test image failed!')
        rknn.release()
        exit(-1)
    img = cv2.resize(img, (224, 224))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # Init runtime environment
    print('--> Init runtime environment')
    ret = rknn.init_runtime()
    # ret = rknn.init_runtime('rk3588')
    if ret != 0:
        print('Init runtime environment failed!')
        exit(ret)
    print('done')

    # Inference
    print('--> Running model')
    T1 = time.time()
    outputs = rknn.inference(inputs=[img])
    T2 = time.time()
    print('消耗时间 >>> {}'.format(T2 - T1))
    with open('result_0415_128.txt', 'a') as f:
        f.write(str(outputs))
    print('***outputs', outputs)
    np.save('./onnx_resnet50v2_0.npy', outputs[0])
    # Softmax over the first output before printing the top-5 entries.
    x = outputs[0]
    output = np.exp(x) / np.sum(np.exp(x))
    outputs = [output]
    show_outputs(outputs)
    print('done')

    rknn.release()

233
tools/operate_usearch.py Normal file
View File

@ -0,0 +1,233 @@
import os
import numpy as np
from usearch.index import Index
import json
import struct
def create_index():
    """Build an empty usearch index configured for the feature library.

    Returns:
        A new ``Index`` for 256-dimensional vectors using the cosine metric,
        float16 storage, and allowing several vectors per key.
    """
    settings = {
        'ndim': 256,
        'metric': 'cos',
        'dtype': 'f16',  # 'f32' was the previously tried alternative
        'connectivity': 32,
        'expansion_add': 40,  # 128 was the previously tried alternative
        'expansion_search': 10,  # 64 was the previously tried alternative
        'multi': True,
    }
    return Index(**settings)
def compare_feature(features1, features2, model='1'):
    """Score how similar a set of query features is to a reference set.

    All pairwise cosine similarities between the query vectors and the
    reference vectors are computed once, then reduced according to ``model``.

    Args:
        features1: Sequence of query feature vectors (one row per image).
        features2: Container whose first element ``features2[0]`` holds the
            reference feature vectors (one row per vector).
        model: Comparison strategy.
            '0': for every query take its best match among the references,
                 then average those maxima (a track of images against the
                 standard library).
            '1': average, over the queries, of each query's mean similarity
                 to all references.
            '2': single best similarity over all query/reference pairs.

    Returns:
        The score as a float, ``-1`` when either side has no vectors, or
        ``None`` when ``model`` is not one of the strategies above (kept for
        backward compatibility with existing callers).
    """
    if model not in ('0', '1', '2'):
        return None

    references = None
    if features2 is not None and len(features2) > 0:
        references = features2[0]
    if features1 is None or references is None:
        return -1
    if len(features1) == 0 or len(references) == 0:
        # Same sentinel the original used for mode '1'; modes '0' and '2'
        # used to crash here (division by zero / max() of an empty list).
        return -1

    # Compute in float64 so float16 inputs do not lose precision in the dot
    # products and norms.
    queries = np.atleast_2d(np.asarray(features1, dtype=np.float64))
    gallery = np.atleast_2d(np.asarray(references, dtype=np.float64))

    query_norms = np.linalg.norm(queries, axis=1)[:, None]
    gallery_norms = np.linalg.norm(gallery, axis=1)[None, :]
    similarities = (queries @ gallery.T) / (query_norms * gallery_norms)

    if model == '0':
        score = similarities.max(axis=1).mean()
    elif model == '1':
        score = similarities.mean(axis=1).mean()
    else:  # model == '2'
        score = similarities.max()
    return float(score)
def get_barcode_feature(data):
    """Split one library entry into parallel key and feature lists.

    Args:
        data: Mapping with a ``'key'`` (barcode) and a ``'value'`` (sequence
            of feature vectors).

    Returns:
        Tuple ``(keys, features)`` where ``keys`` repeats the barcode once per
        feature vector and ``features`` is ``data['value']`` itself.
    """
    key = data['key']
    vectors = data['value']
    repeated_keys = [key for _ in range(len(vectors))]
    return repeated_keys, vectors
def analysis_file(file_path):
    """Load a feature-library JSON file into per-barcode key and feature lists.

    :param file_path: path of a JSON file whose ``'total'`` entry is a list of
        ``{'key': ..., 'value': ...}`` records.
    :return: tuple ``(barcodes, features)``; element ``i`` of each list is the
        result of ``get_barcode_feature`` for the ``i``-th record.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        payload = json.load(f)
    pairs = [get_barcode_feature(record) for record in payload['total']]
    barcodes = [keys for keys, _ in pairs]
    features = [vectors for _, vectors in pairs]
    return barcodes, features
def create_base_index(index_file_pth=None,
                      barcodes=None,
                      features=None,
                      save_index_name=None):
    """Build a usearch index from a JSON library or from in-memory data and save it.

    Args:
        index_file_pth: Optional path of a feature-library JSON file. When
            given, the keys/features are read from it and the index is saved
            next to it with the extension replaced by ``.data`` (this
            overrides ``save_index_name``, as before).
        barcodes: Per-barcode key lists; required when ``index_file_pth`` is
            not given.
        features: Per-barcode feature lists; required when ``index_file_pth``
            is not given.
        save_index_name: Destination passed to ``index.save`` when the data is
            supplied in memory.

    Entries that fail to be added are reported and skipped so that one bad
    record does not abort the whole build.
    """
    index = create_index()
    if index_file_pth is not None:
        # Derive the output name from the extension only; splitting on the
        # substring 'json' would truncate at the first "json" anywhere in the
        # path (e.g. a ".../feature_json/lib.json" directory).
        # save_index_name = os.path.splitext(index_file_pth)[0] + '.usearch'
        save_index_name = os.path.splitext(index_file_pth)[0] + '.data'
        barcodes, features = analysis_file(index_file_pth)
    else:
        assert barcodes is not None and features is not None, 'barcodes and features must be not None'
    for barcode, feature in zip(barcodes, features):
        try:
            index.add(np.array(barcode), np.array(feature))
        except Exception as e:
            # Best-effort build: report which record failed and carry on.
            print('failed to add {} to index: {}'.format(list(barcode[:1]), e))
            continue
    index.save(save_index_name)
def get_feature_index(index_file_pth=None,
                      barcodes=None):
    """Fetch the stored feature vectors for the given keys from a saved index.

    Args:
        index_file_pth: Path of the saved index file (required).
        barcodes: Keys to look up; converted with ``np.array`` before the query.

    Returns:
        Whatever ``Index.get`` returns for those keys. The index memory usage,
        its size and the fetched features are printed as a side effect.
    """
    assert index_file_pth is not None, 'index_file_pth must be not None'
    restored = Index.restore(index_file_pth, view=True)
    keys = np.array(barcodes)
    feature_lists = restored.get(keys)
    memory_report = "memory {} size {}".format(restored.memory_usage, restored.size)
    print(memory_report)
    feature_report = "feature_lists {}".format(feature_lists)
    print(feature_report)
    return feature_lists
def search_in_index(query=None,
                    barcode=None,  # barcode -> int or np.ndarray
                    index_name=None,
                    temp_index=False,  # whether this is a temporary library
                    model='0',
                    ):
    """Query a saved index, either 1:1 against one barcode or as a top-k search.

    Args:
        query: Query feature vector(s).
        barcode: When given, run a 1:1 comparison against the vectors stored
            under this key; otherwise run a nearest-neighbour search.
        index_name: Path of the saved index file (required).
        temp_index: ``True`` for the temporary library (top-5 search),
            ``False`` for the standard library (top-10 search).
        model: Comparison strategy forwarded to ``compare_feature`` for the
            standard library.

    Returns:
        The ``compare_feature`` score for a 1:1 comparison, otherwise the
        result of ``Index.search``.
    """
    assert index_name is not None, 'index_name must be not None'
    index = Index.restore(index_name, view=True)

    if barcode is None:
        top_k = 5 if temp_index else 10
        return index.search(query, count=top_k)

    # 1:1 comparison test
    feature_lists = index.get(np.array(barcode))
    if temp_index:
        # NOTE(review): the temporary-library path does not forward ``model``
        # and therefore uses compare_feature's default — confirm this is intended.
        return compare_feature(query, feature_lists)
    return compare_feature(query, feature_lists, model)
def delete_index(index_name=None, key=None, index=None):
    """Remove every vector stored under ``key`` from an index.

    Args:
        index_name: Path of a saved index file. Used only when ``index`` is
            not supplied; the file is loaded, modified and written back.
        key: Key to remove (required).
        index: Optional already-loaded index to modify in place. The caller is
            responsible for saving it afterwards.
    """
    assert key is not None, 'key must be not None'
    if index is None:
        assert index_name is not None, 'index_name must be not None'
        # Load a mutable copy: the original restored a view and then called
        # remove(index_name), i.e. it passed the file path as the key.
        index = Index.restore(index_name, view=False)
        index.remove(key)
        # Persist the change; otherwise the removal is lost when the locally
        # loaded index goes out of scope.
        index.save(index_name)
    else:
        index.remove(key)
from scipy.spatial.distance import cdist
def compute_similarity_matrix(featurelists1, featurelists2):
    """Return the pairwise cosine-similarity matrix of two feature sets.

    Entry ``[i, j]`` is the cosine similarity between row ``i`` of
    ``featurelists1`` and row ``j`` of ``featurelists2``, rounded to three
    decimals.
    """
    # cdist yields cosine *distance*; similarity is its complement.
    cosine_distances = cdist(featurelists1, featurelists2, metric='cosine')
    return np.around(1 - cosine_distances, decimals=3)
def check_usearch_json_diff(index_file_pth, json_file_pth, barcode='6923644272159'):
    """Compare the features of one barcode in a saved index against its JSON source.

    Args:
        index_file_pth: Path of the saved index file.
        json_file_pth: Path of the feature-library JSON file (``'total'`` list
            of ``{'key', 'value'}`` records).
        barcode: Key to check. Defaults to the barcode that used to be
            hard-coded here.

    Returns:
        The cosine-similarity matrix between the JSON vectors (rows) and the
        index vectors (columns) for that barcode.

    Raises:
        ValueError: If the barcode is not present in the JSON file.
    """
    feature_lists = get_feature_index(index_file_pth, [barcode])
    with open(json_file_pth, 'r') as json_file:
        json_data = json.load(json_file)

    json_features = None
    for data in json_data['total']:
        if data['key'] == barcode:
            json_features = data['value']
    if json_features is None:
        raise ValueError('barcode {} not found in {}'.format(barcode, json_file_pth))

    json_features = np.array(json_features)
    index_features = np.array(feature_lists[0])
    # Return the matrix so the caller can inspect it; it used to be discarded.
    return compute_similarity_matrix(json_features, index_features)
def write_binary_file(filename, datas):
    """Serialise feature records to a compact binary file.

    Layout (integers use the platform's native ``'I'`` format, as before):

    * ``uint32``  number of records
    * per record:
        * ``uint8``   byte length of the UTF-8 encoded key
        * ``bytes``   the UTF-8 encoded key
        * ``uint32``  total number of float16 values that follow
        * ``float16`` all feature values, vector after vector

    Args:
        filename: Destination path; the file is overwritten.
        datas: Iterable with ``len()`` of ``{'key': str, 'value': [[float, ...], ...]}``.

    Raises:
        struct.error: If an encoded key is longer than 255 bytes.
    """
    with open(filename, 'wb') as f:
        # Record count first, so the reader knows how many keys to expect.
        f.write(struct.pack('I', len(datas)))
        for data in datas:
            key_bytes = data['key'].encode('utf-8')
            # The prefix must be the *byte* length; len(key) is wrong for
            # keys containing non-ASCII characters.
            f.write(struct.pack('<B', len(key_bytes)))
            f.write(key_bytes)

            rows = [np.asarray(values, dtype=np.float16) for values in data['value']]
            # Actual number of values instead of assuming 256 per vector.
            f.write(struct.pack('I', sum(row.size for row in rows)))
            for row in rows:
                # One write per vector instead of one per value.
                f.write(row.tobytes())
def create_binary_file(json_path, flag=True):
    """Convert one feature JSON file into a ``.bin`` file next to it.

    Args:
        json_path: Path of the JSON file to convert.
        flag: ``True`` when the file is a library whose values are lists of
            records (e.g. ``{'total': [...]}``); ``False`` when the file holds
            a single ``{'key', 'value'}`` record.
    """
    with open(json_path, 'r', encoding='utf-8') as file:
        data = json.load(file)

    # Swap only the extension. The original replaced every "json" substring in
    # the path and, in library mode, read the module global `index_file_pth`,
    # which only exists when this file is run as a script.
    bin_path = os.path.splitext(json_path)[0] + '.bin'

    if flag:
        # Every top-level value targets the same output file, so with several
        # keys the last one wins (unchanged behaviour).
        for values in data.values():
            write_binary_file(bin_path, values)
    else:
        write_binary_file(bin_path, [data])
def create_binary_files(index_file_pth):
    """Convert a library JSON file, or every single-record JSON file in a directory.

    Args:
        index_file_pth: Either the path of one library JSON file, or a
            directory containing one JSON file per barcode.
    """
    if os.path.isfile(index_file_pth):
        create_binary_file(index_file_pth)
        return

    for name in sorted(os.listdir(index_file_pth)):
        # The .bin outputs are written into the same directory; skip them (and
        # any other non-JSON entry) so a second run does not try to parse them.
        if not name.lower().endswith('.json'):
            continue
        create_binary_file(os.path.join(index_file_pth, name), False)
if __name__ == '__main__':
# Script entry point: convert the configured feature JSON file to a binary file.
# index_file_pth = '../data/feature_json' # generate binary files (directory with several files)
index_file_pth = '../search_library/yunhedian_30-04.json'
# create_base_index(index_file_pth) # generate the usearch index file
create_binary_files(index_file_pth) # generate binary file(s)
# index_file_pth = '../search_library/test_index_10_normal_0717.usearch'
# # index_file_pth = '../search_library/data_10_normal_0718.index'
# search_in_index(query='693', index_name=index_file_pth, barcode='6934024590466')
# # check index data file
# index_file_pth = '../search_library/data_zhanting.data'
# # # get_feature_index(index_file_pth, ['6901070602818'])
# get_feature_index(index_file_pth, ['6923644272159'])
# index_file_pth = '../search_library/data_zhanting.data'
# json_file_pth = '../search_library/data_zhanting.json'
# check_usearch_json_diff(index_file_pth, json_file_pth)

View File

@ -0,0 +1,84 @@
'''
现场1:N测试确定阈值
'''
import os
import numpy as np
import matplotlib.pyplot as plt
def showHist(filtered_data):
    """Plot histograms of columns 1 and 2 of ``filtered_data`` in two stacked panels.

    Args:
        filtered_data: 2-D array; columns 1 and 2 are cast to float32 and
            drawn as the 'first' and 'second' histogram respectively.
    """
    first_scores = filtered_data[:, 1].astype(np.float32)
    second_scores = filtered_data[:, 2].astype(np.float32)
    _fig, axes = plt.subplots(2, 1)
    panels = zip(axes, (first_scores, second_scores), ('first', 'second'))
    for axis, scores, title in panels:
        axis.hist(scores, bins=50, edgecolor='black')
        axis.set_xlim([-0.1, 1])
        axis.set_title(title)
    # plt.savefig('plot.png')
    plt.show()
def get_tartget_list(nested_list):
    """Plot the score histograms of the correctly matched records.

    Args:
        nested_list: List of per-event result lists. Entries with fewer than
            two elements (events without any comparison data) are dropped.
            The first element of each remaining entry is the match label;
            entries labelled ``'wrong'`` are excluded from the plot.

    Prints the filtered records. When nothing survives the filter, only the
    empty array is printed and no plot is drawn.
    """
    records = [item for item in nested_list if len(item) >= 2]
    if not records:
        # np.array([]) is 1-D, so the column indexing below would raise.
        print(np.array(records))
        return

    filtered_list = np.array(records)
    labels = filtered_list[:, 0]
    filtered_correct = filtered_list[labels != 'wrong']  # correctly matched records
    showHist(filtered_correct)
    # To inspect the mismatches instead: showHist(filtered_list[labels == 'wrong'])
    print(filtered_list)
def deal_process(file_pth):
"""Parse one process log file into a flat result list.

The path is split on backslashes: the second-to-last component is used as
the event name, and the text after its last underscore as the expected
barcode. Lines are ignored until one containing 'oneToOne' is seen; the
lines after it are split on commas, the first field being the predicted
barcode and the text after the last ':' of the last field being the score.

NOTE(review): the indentation of this chunk was lost, so the exact nesting
of the statements below (in particular which appends happen per line and
which happen once) cannot be confirmed from here.

:param file_pth: path of the log file, using backslash separators.
:return: list holding a 'correct'/'wrong' label, float score(s) and the
event name.
"""
# Becomes True once the 'oneToOne' marker line has been read.
flag = False
event = file_pth.split('\\')[-2]
target_barcode = file_pth.split('\\')[-2].split('_')[-1]
temp_list = []
with open(file_pth, 'r') as f:
for line in f:
if 'oneToOne' in line:
flag = True
continue
if flag:
line = line.replace('\n', '')
comparison_data = line.split(',')
forecast_barcode = comparison_data[0]
value = comparison_data[-1].split(':')[-1]
# An empty score stops the parsing.
if value == '':
break
# The label is decided while the result list is still empty.
if len(temp_list) == 0:
if forecast_barcode == target_barcode:
temp_list.append('correct')
else:
temp_list.append('wrong')
temp_list.append(float(value))
temp_list.append(event)
return temp_list
def anaylze_scratch(scratch_pth):
    """Walk a test-record tree, parse every ``process.data`` and plot the results.

    Directories are classified by how many underscore-separated parts their
    full path has: four parts go to the purchase (add-to-cart) list, three
    parts to the put-back list. The put-back results are plotted and the
    purchase results are printed.

    Args:
        scratch_pth: Root directory to scan.
    """
    purchase, back = [], []
    buckets = {4: purchase, 3: back}  # number of '_' parts -> result list
    for root, _dirs, _files in os.walk(scratch_pth):
        if not root:
            continue
        bucket = buckets.get(len(root.split('_')))
        if bucket is None:
            continue
        process = os.path.join(root, 'process.data')
        if os.path.exists(process):
            bucket.append(deal_process(process))
    # get_tartget_list(purchase)
    get_tartget_list(back)
    print(purchase)
if __name__ == '__main__':
    # Analyse the on-site test records stored on the network share.
    # Earlier run, kept for reference:
    # r'\\192.168.1.28\\share\\测试视频数据以及日志\\各模块测试记录\\展厅测试\\1108_展厅模型v800测试\\'
    anaylze_scratch(
        r'\\192.168.1.28\\share\\测试视频数据以及日志\\各模块测试记录\\展厅测试\\1120_展厅模型v801测试\\扫A放A\\'
    )

411
tools/write_feature_json.py Normal file
View File

@ -0,0 +1,411 @@
import json
import os
import logging
import numpy as np
from typing import Dict, List, Optional, Tuple
from tools.dataset import get_transform
from model import resnet18
import torch
from PIL import Image
import pandas as pd
from tqdm import tqdm
import yaml
import shutil
import struct
# Logging setup: timestamped INFO-level records on the root logger, plus a
# module-level logger for this file.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
class FeatureExtractor:
    """Extract ResNet18 features for barcode image folders and export them.

    The extractor loads a checkpoint described by ``conf``, turns the images
    found below a directory tree into feature vectors, and writes them to JSON
    and to a compact float16 binary file.

    Args:
        conf: Parsed configuration mapping. The entries read by this class are
            ``models.checkpoints``, ``base.device``, ``base.distributed``,
            ``data.half``, ``save.json_path``, ``save.json_bin``,
            ``save.error_barcodes`` and ``save.barcodes_statistics``.
    """

    # Leading feature dimensions that are L2-normalised and exported.
    FEATURE_DIM = 256

    def __init__(self, conf):
        self.conf = conf
        self.model = self.initModel()
        _, self.test_transform = get_transform(self.conf)

    def initModel(self, inference_model: Optional[str] = None) -> torch.nn.Module:
        """
        Initialize and load the ResNet18 model for inference.

        Args:
            inference_model: Optional path to model weights. Uses
                ``conf['models']['checkpoints']`` if None.

        Returns:
            Loaded and configured PyTorch model in evaluation mode.

        Raises:
            FileNotFoundError: If model weights file is not found
            Exception: Any error raised while building or loading the model is
                logged and re-raised unchanged.
        """
        # Read the instance configuration; the module-level `conf` only exists
        # when this file is executed as a script.
        model_path = inference_model if inference_model else self.conf['models']['checkpoints']
        try:
            # Verify model file exists
            if not os.path.exists(model_path):
                raise FileNotFoundError(f"Model weights file not found: {model_path}")

            device = self.conf['base']['device']
            model = resnet18().to(device)

            # Handle multi-GPU case
            if self.conf['base']['distributed']:
                model = torch.nn.DataParallel(model)

            # Load weights
            state_dict = torch.load(model_path, map_location=device)
            model.load_state_dict(state_dict)
            model.eval()
            logger.info(f"Successfully loaded model from {model_path}")
            return model
        except Exception as e:
            logger.error(f"Failed to initialize model: {str(e)}")
            raise

    def convert_rgba_to_rgb(self, image_path):
        """Rewrite the image at ``image_path`` as RGB if it is stored as RGBA.

        Images in any other mode are left untouched. The alpha channel is
        dropped by ``convert('RGB')`` and the file is overwritten in place.
        """
        # Close the source handle before overwriting the same path.
        with Image.open(image_path) as img:
            if img.mode != 'RGBA':
                return
            img_rgb = img.convert('RGB')
        img_rgb.save(image_path)
        print(f"Image converted from RGBA to RGB and saved to {image_path}")

    def test_preprocess(self, images: list, actionModel=False) -> torch.Tensor:
        """Apply the test transform to every image and stack the results.

        Args:
            images: Image paths, or already-loaded images when
                ``actionModel`` is True.
            actionModel: Whether ``images`` are in-memory images rather than
                paths.

        Returns:
            A tensor stacking all successfully transformed images.

        Raises:
            RuntimeError: If no image could be preprocessed.
        """
        res = []
        for img in images:
            try:
                if actionModel:
                    res.append(self.test_transform(img))
                else:
                    with Image.open(img) as pil_img:
                        res.append(self.test_transform(pil_img))
            except Exception as e:
                # Best effort: one unreadable image must not drop the batch,
                # but it should not disappear silently either.
                logger.warning(f"Skipping image that failed preprocessing {img}: {str(e)}")
        if not res:
            raise RuntimeError("No image could be preprocessed for this batch")
        return torch.stack(res)

    def inference(self, images, model, actionModel=False):
        """Run ``model`` on ``images`` and return the feature tensor.

        The result is converted to half precision when
        ``conf['data']['half']`` is set.
        """
        data = self.test_preprocess(images, actionModel)
        if torch.cuda.is_available():
            data = data.to(self.conf['base']['device'])
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            features = model(data)
        if self.conf['data']['half']:
            features = features.half()
        return features

    def group_image(self, images, batch=64) -> list:
        """Group image paths by batch size"""
        return [images[start:start + batch] for start in range(0, len(images), batch)]

    def getFeatureList(self, barList, imgList):
        """Extract the feature vectors for every barcode.

        Args:
            barList: Barcode identifiers (only the length is used).
            imgList: For each barcode, the list of its image paths.

        Returns:
            One list of numpy feature vectors per barcode. The first
            ``FEATURE_DIM`` dimensions of each vector are L2-normalised, the
            remaining ones are kept as produced by the model. A barcode whose
            processing fails keeps whatever was extracted before the error.
        """
        dim = self.FEATURE_DIM
        featList = [[] for _ in barList]
        for index, image_paths in enumerate(imgList):
            try:
                for batch in self.group_image(image_paths):
                    features = self.inference(batch, self.model)
                    # Move the whole batch to the CPU once instead of per row.
                    for feat in features.detach().cpu().numpy():
                        feat_np = feat.squeeze()
                        normalized = self.normalize_256(feat_np[:dim])
                        combined = np.concatenate([normalized, feat_np[dim:]], axis=0)
                        featList[index].append(combined)
            except Exception as e:
                logger.error(f"Error processing images for index {index}: {str(e)}")
                continue
        return featList

    def get_files(
            self,
            folder: str,
            filter: Optional[List[str]] = None,
            create_single_json: bool = False
    ) -> Dict[str, List[str]]:
        """
        Recursively collect image files from directory structure.

        Leaf directories named ``0``..``99`` are treated as sub-classes of the
        barcode directory above them and get the key ``<barcode>_<subclass>``.
        In batch mode sub-class ``0`` is stored under the plain barcode.

        Args:
            folder: Root directory to scan
            filter: Optional list of barcodes to include
            create_single_json: Whether to create individual JSON files per barcode

        Returns:
            Dictionary mapping barcode names to lists of image paths (empty
            when ``create_single_json`` is True, because each barcode is
            exported immediately).

        Raises:
            ValueError: If ``folder`` is not a directory.

        Example:
            {
                "barcode1": ["path/to/img1.jpg", "path/to/img2.jpg"],
                "barcode2": ["path/to/img3.jpg"]
            }
        """
        if not os.path.isdir(folder):
            raise ValueError(f"Invalid directory: {folder}")

        file_dicts = {}
        total_files = 0
        barcode_count = 0
        subclass = {str(i) for i in range(100)}
        allowed = set(filter) if filter else None

        for root, dirs, files in tqdm(os.walk(folder), desc="Scanning directories"):
            if dirs:  # Only leaf directories contain images
                continue

            basename = os.path.basename(root)
            if basename in subclass:
                # Parent directory name, independent of the path separator.
                ori_barcode = os.path.basename(os.path.dirname(root))
                barcode = ori_barcode + '_' + basename
            else:
                ori_barcode = basename
                barcode = basename

            # Apply filter if provided
            if allowed is not None and ori_barcode not in allowed:
                continue
            if len(ori_barcode) > 13 or len(ori_barcode) < 8:
                logger.warning(f"Skipping invalid barcode {ori_barcode}")
                with open(self.conf['save']['error_barcodes'], 'a') as f:
                    f.write(ori_barcode + '\n')
                continue

            if not files:
                continue
            image_paths = self._process_image_files(root, files)
            if not image_paths:
                continue

            barcode_count += 1
            total_files += len(image_paths)

            if create_single_json:
                self._process_single_barcode(barcode, image_paths)
            else:
                if barcode.split('_')[-1] == '0':
                    barcode = barcode.split('_')[0]
                file_dicts[barcode] = image_paths

        logger.info(f"Processed {barcode_count} barcodes with {total_files} total images")

        # Batch process if not creating individual JSONs
        if not create_single_json and file_dicts:
            self.createFeatureDict(
                file_dicts,
                create_single_json=False,
            )
        return file_dicts

    def _process_image_files(self, root: str, files: List[str]) -> List[str]:
        """Process and validate image files in a directory."""
        valid_paths = []
        for filename in files:
            file_path = os.path.join(root, filename)
            try:
                # Convert RGBA to RGB if needed
                self.convert_rgba_to_rgb(file_path)
                valid_paths.append(file_path)
            except Exception as e:
                logger.warning(f"Skipping invalid image {file_path}: {str(e)}")
        return valid_paths

    def _process_single_barcode(self, barcode: str, image_paths: List[str]):
        """Process a single barcode and create individual JSON file."""
        self.createFeatureDict(
            {barcode: image_paths},
            create_single_json=True,
        )

    def normalize_256(self, queFeatList):
        """L2-normalise a feature vector; an all-zero vector is returned unchanged."""
        norm = np.linalg.norm(queFeatList)
        if norm == 0:
            # Avoid NaNs, which would also produce invalid JSON later on.
            return queFeatList
        return queFeatList / norm

    def img2feature(
            self,
            imgs_dict: Dict[str, List[str]]
    ) -> Tuple[List[str], List[List[np.ndarray]]]:
        """
        Extract features for all images in the dictionary.

        Args:
            imgs_dict: Dictionary mapping barcodes to image paths

        Returns:
            Tuple containing:
            - List of barcode IDs
            - List of feature lists (one per barcode)

        Raises:
            ValueError: If input dictionary is empty
            RuntimeError: If feature extraction fails
        """
        if not imgs_dict:
            raise ValueError("No images provided for feature extraction")
        try:
            barcode_list = list(imgs_dict.keys())
            image_list = list(imgs_dict.values())
            feature_list = self.getFeatureList(barcode_list, image_list)
            logger.info(f"Successfully extracted features for {len(barcode_list)} barcodes")
            return barcode_list, feature_list
        except Exception as e:
            logger.error(f"Feature extraction failed: {str(e)}")
            raise RuntimeError(f"Feature extraction failed: {str(e)}") from e

    def createFeatureDict(self, imgs_dict,
                          create_single_json=False):
        """Extract features and write them to JSON (and binary) files.

        Args:
            imgs_dict: ``{barcode1: [img1_1 ... img1_n], barcode2: [...]}``.
            create_single_json: When True, write one
                ``<json_path>/<barcode>.json`` per barcode. Otherwise collect
                all records under ``'total'`` in ``conf['save']['json_bin']``
                and convert that file to binary.
        """
        dim = self.FEATURE_DIM
        barcode_list, feature_lists = self.img2feature(imgs_dict)
        value_list = []
        for barcode, features in zip(barcode_list, feature_lists):
            entry = {
                'key': barcode,
                # Only the normalised leading dimensions are exported.
                'value': [feat.tolist()[:dim] for feat in features],
            }
            if create_single_json:
                json_path = os.path.join(self.conf['save']['json_path'], str(barcode) + '.json')
                with open(json_path, 'w') as json_file:
                    json.dump(entry, json_file)
            else:
                value_list.append(entry)

        if not create_single_json:
            with open(self.conf['save']['json_bin'], 'w') as json_file:
                json.dump({'total': value_list}, json_file)
            self.create_binary_files(self.conf['save']['json_bin'])

    def statisticsBarcodes(self, pth, filter=None):
        """Write the barcode names found in ``pth`` to a file and print totals.

        Args:
            pth: Directory whose entries are barcode folders.
            filter: Optional collection of barcodes; when given, only barcodes
                contained in it contribute to the printed totals (every
                valid-length barcode is still written to the file).
        """
        feature_num = 0
        nn = 0
        with open(self.conf['save']['barcodes_statistics'], 'w', encoding='utf-8') as f:
            for barcode in os.listdir(pth):
                print("barcode length >> {}".format(len(barcode)))
                if len(barcode) > 13 or len(barcode) < 8:
                    continue
                if filter is not None:
                    f.writelines(barcode + '\n')
                    if barcode in filter:
                        print(barcode)
                        feature_num += len(os.listdir(os.path.join(pth, barcode)))
                        nn += 1
                else:
                    print('barcode name >>{}'.format(barcode))
                    f.writelines(barcode + '\n')
                    feature_num += len(os.listdir(os.path.join(pth, barcode)))
                    # Count the barcode here too; without a filter the total
                    # used to be reported as 0.
                    nn += 1
        print("特征总量: {}".format(feature_num))
        print("barcode总量 {}".format(nn))

    def get_shop_barcodes(self, file_path):
        """Read the barcodes from column index 6 of an Excel sheet.

        Returns:
            The column values as strings, or None when ``file_path`` is empty.
        """
        if not file_path:
            return None
        df = pd.read_excel(file_path)
        return [str(value) for value in df.iloc[:, 6].values]

    def del_base_dir(self, pth):
        """Delete every directory named ``base`` that is its parent's only sub-directory."""
        for root, dirs, _files in os.walk(pth):
            if dirs == ['base']:
                shutil.rmtree(os.path.join(root, 'base'))
                # Do not let os.walk descend into the directory just removed.
                dirs.clear()

    def write_binary_file(self, filename, datas):
        """Serialise feature records to a compact binary file.

        Layout (integers use the platform's native ``'I'`` format, as before):

        * ``uint32``  number of records
        * per record:
            * ``uint8``   byte length of the UTF-8 encoded key
            * ``bytes``   the UTF-8 encoded key
            * ``uint32``  total number of float16 values that follow
            * ``float16`` all feature values, vector after vector

        Raises:
            struct.error: If an encoded key is longer than 255 bytes.
        """
        with open(filename, 'wb') as f:
            # Record count first, so the reader knows how many keys to expect.
            f.write(struct.pack('I', len(datas)))
            for data in datas:
                key_bytes = data['key'].encode('utf-8')
                # The prefix must be the *byte* length, not the character count.
                f.write(struct.pack('<B', len(key_bytes)))
                f.write(key_bytes)

                rows = [np.asarray(values, dtype=np.float16) for values in data['value']]
                # Actual number of values instead of assuming 256 per vector.
                f.write(struct.pack('I', sum(row.size for row in rows)))
                for row in rows:
                    f.write(row.tobytes())

    def create_binary_file(self, json_path, flag=True):
        """Convert one feature JSON file into a ``.bin`` file next to it.

        Args:
            json_path: Path of the JSON file to convert.
            flag: True when the file is a library whose values are lists of
                records (``{'total': [...]}``); False when it holds a single
                ``{'key', 'value'}`` record.
        """
        with open(json_path, 'r', encoding='utf-8') as file:
            data = json.load(file)

        # Swap only the extension; replacing the substring 'json' would also
        # rewrite directory names that contain it.
        bin_path = os.path.splitext(json_path)[0] + '.bin'

        if flag:
            # Every top-level value targets the same file, so with several
            # keys the last one wins (unchanged behaviour).
            for values in data.values():
                self.write_binary_file(bin_path, values)
        else:
            self.write_binary_file(bin_path, [data])

    def create_binary_files(self, index_file_pth):
        """Convert a library JSON file, or every single-record JSON file in a directory."""
        if os.path.isfile(index_file_pth):
            self.create_binary_file(index_file_pth)
            return

        for name in sorted(os.listdir(index_file_pth)):
            # The .bin outputs live in the same directory; skip them and any
            # other non-JSON entry so a second run does not try to parse them.
            if not name.lower().endswith('.json'):
                continue
            self.create_binary_file(os.path.join(index_file_pth, name), False)
if __name__ == "__main__":
# Script entry point: load the YAML configuration, build the extractor,
# export the features of every barcode folder, then print barcode statistics.
with open('../configs/write_feature.yml', 'r') as f:
conf = yaml.load(f, Loader=yaml.FullLoader)
# Store the mapping from image names to model-inferred feature vectors as a JSON file.
# xlsx_pth = './shop_xlsx/曹家桥门店在售商品表.xlsx'
# xlsx_pth = None
# del_base_dir(mg_path)
extractor = FeatureExtractor(conf)
# Optional barcode whitelist read from the Excel sheet named in the configuration.
column_values = extractor.get_shop_barcodes(conf['data']['xlsx_pth'])
imgs_dict = extractor.get_files(conf['data']['img_dirs_path'],
filter=column_values,
create_single_json=False) # False
extractor.statisticsBarcodes(conf['data']['img_dirs_path'], column_values)