rebuild
This commit is contained in:
0
tools/__init__.py
Normal file
0
tools/__init__.py
Normal file
BIN
tools/__pycache__/gift_data_pretreatment.cpython-38.pyc
Normal file
BIN
tools/__pycache__/gift_data_pretreatment.cpython-38.pyc
Normal file
Binary file not shown.
68
tools/dataset.py
Normal file
68
tools/dataset.py
Normal file
@ -0,0 +1,68 @@
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision.datasets import ImageFolder
|
||||
import torchvision.transforms.functional as F
|
||||
import torchvision.transforms as T
|
||||
# from config import config as conf
|
||||
import torch
|
||||
|
||||
def pad_to_square(img):
|
||||
w, h = img.size
|
||||
max_wh = max(w, h)
|
||||
padding = [(max_wh - w) // 2, (max_wh - h) // 2, (max_wh - w) // 2, (max_wh - h) // 2] # (left, top, right, bottom)
|
||||
return F.pad(img, padding, fill=0, padding_mode='constant')
|
||||
|
||||
def get_transform(cfg):
|
||||
train_transform = T.Compose([
|
||||
T.Lambda(pad_to_square), # 补边
|
||||
T.ToTensor(),
|
||||
T.Resize((cfg['transform']['img_size'], cfg['transform']['img_size']), antialias=True),
|
||||
# T.RandomCrop(img_size * 4 // 5),
|
||||
T.RandomHorizontalFlip(p=cfg['transform']['RandomHorizontalFlip']),
|
||||
T.RandomRotation(cfg['transform']['RandomRotation']),
|
||||
T.ColorJitter(brightness=cfg['transform']['ColorJitter']),
|
||||
T.ConvertImageDtype(torch.float32),
|
||||
T.Normalize(mean=[cfg['transform']['img_mean']], std=[cfg['transform']['img_std']]),
|
||||
])
|
||||
test_transform = T.Compose([
|
||||
# T.Lambda(pad_to_square), # 补边
|
||||
T.ToTensor(),
|
||||
T.Resize((cfg['transform']['img_size'], cfg['transform']['img_size']), antialias=True),
|
||||
T.ConvertImageDtype(torch.float32),
|
||||
T.Normalize(mean=[cfg['transform']['img_mean']], std=[cfg['transform']['img_std']]),
|
||||
])
|
||||
return train_transform, test_transform
|
||||
|
||||
def load_data(training=True, cfg=None):
|
||||
train_transform, test_transform = get_transform(cfg)
|
||||
if training:
|
||||
dataroot = cfg['data']['data_train_dir']
|
||||
transform = train_transform
|
||||
# transform = conf.train_transform
|
||||
batch_size = cfg['data']['train_batch_size']
|
||||
else:
|
||||
dataroot = cfg['data']['data_val_dir']
|
||||
# transform = conf.test_transform
|
||||
transform = test_transform
|
||||
batch_size = cfg['data']['val_batch_size']
|
||||
|
||||
data = ImageFolder(dataroot, transform=transform)
|
||||
class_num = len(data.classes)
|
||||
loader = DataLoader(data,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
pin_memory=cfg['base']['pin_memory'],
|
||||
num_workers=cfg['data']['num_workers'],
|
||||
drop_last=True)
|
||||
return loader, class_num
|
||||
|
||||
# def load_gift_data(action):
|
||||
# train_data = ImageFolder(conf.train_gift_root, transform=conf.train_transform)
|
||||
# train_dataset = DataLoader(train_data, batch_size=conf.train_gift_batchsize, shuffle=True,
|
||||
# pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True)
|
||||
# val_data = ImageFolder(conf.test_gift_root, transform=conf.test_transform)
|
||||
# val_dataset = DataLoader(val_data, batch_size=conf.val_gift_batchsize, shuffle=True,
|
||||
# pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True)
|
||||
# test_data = ImageFolder(conf.test_gift_root, transform=conf.test_transform)
|
||||
# test_dataset = DataLoader(test_data, batch_size=conf.test_gift_batchsize, shuffle=True,
|
||||
# pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True)
|
||||
# return train_dataset, val_dataset, test_dataset
|
10
tools/dataset.txt
Normal file
10
tools/dataset.txt
Normal file
@ -0,0 +1,10 @@
|
||||
./quant_imgs/20179457_20240924-110903_back_addGood_b82d2842766e_80_15583929052_tid-8_fid-72_bid-3.jpg
|
||||
./quant_imgs/6928926002103_20240309-195044_front_returnGood_70f75407ef0e_225_18120111822_14_01.jpg
|
||||
./quant_imgs/6928926002103_20240309-212145_front_returnGood_70f75407ef0e_225_18120111822_11_01.jpg
|
||||
./quant_imgs/6928947479083_20241017-133830_front_returnGood_5478c9a48b7e_10_13799009402_tid-1_fid-20_bid-1.jpg
|
||||
./quant_imgs/6928947479083_20241018-110450_front_addGood_5478c9a48c28_165_13773168720_tid-6_fid-36_bid-1.jpg
|
||||
./quant_imgs/6930044166421_20240117-141516_c6a23f41-5b16-44c6-a03e-c32c25763442_back_returnGood_6930044166421_17_01.jpg
|
||||
./quant_imgs/6930044166421_20240308-150916_back_returnGood_70f75407ef0e_175_13815402763_7_01.jpg
|
||||
./quant_imgs/6930044168920_20240117-165633_3303629b-5fbd-423b-913d-8a64c1aa51dc_front_addGood_6930044168920_26_01.jpg
|
||||
./quant_imgs/6930058201507_20240305-175434_front_addGood_70f75407ef0e_95_18120111822_28_01.jpg
|
||||
./quant_imgs/6930639267885_20241014-120446_back_addGood_5478c9a48c3e_135_13773168720_tid-5_fid-99_bid-0.jpg
|
112
tools/fp32comparefp16.py
Normal file
112
tools/fp32comparefp16.py
Normal file
@ -0,0 +1,112 @@
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from test_ori import group_image, init_model, featurize
|
||||
from config import config as conf
|
||||
import json
|
||||
import os.path as osp
|
||||
|
||||
def compare_fp16_fp32(values_pf16, values_pf32, dataTest):
|
||||
if dataTest:
|
||||
norm_values_pf16 = torch.norm(values_pf16, p=2)
|
||||
norm_values_pf32 = torch.norm(values_pf32, p=2)
|
||||
euclidean_distance = torch.norm(norm_values_pf16 - norm_values_pf32, p=2)
|
||||
print(f"欧几里得距离: {euclidean_distance}")
|
||||
cosine_sim = torch.dot(values_pf16.float(), values_pf32) / (norm_values_pf16 * norm_values_pf32)
|
||||
print(f"余弦相似度: {cosine_sim}")
|
||||
else:
|
||||
|
||||
pass
|
||||
def cosin_metric(x1, x2, fp32=True):
|
||||
if fp32:
|
||||
return np.dot(x1, x2) / (np.linalg.norm(x1) * np.linalg.norm(x2))
|
||||
else:
|
||||
x1_fp16 = x1.astype(np.float16)
|
||||
x2_fp16 = x2.astype(np.float16)
|
||||
# print(type(x1))
|
||||
# pdb.set_trace()
|
||||
return np.dot(x1_fp16, x2_fp16) / (np.linalg.norm(x1_fp16) * np.linalg.norm(x2_fp16))
|
||||
def deal_group_pair(pairList1, pairList2):
|
||||
one_similarity_fp16, one_similarity_fp32, allsimilarity_fp32, allsimilarity_fp16 = [], [], [], []
|
||||
for pair1 in pairList1:
|
||||
for pair2 in pairList2:
|
||||
# similarity = cosin_metric(pair1.cpu().numpy(), pair2.cpu().numpy())
|
||||
one_similarity_fp32.append(cosin_metric(pair1.cpu().numpy(), pair2.cpu().numpy(), True))
|
||||
one_similarity_fp16.append(cosin_metric(pair1.cpu().numpy(), pair2.cpu().numpy(), False))
|
||||
allsimilarity_fp32.append(one_similarity_fp32)
|
||||
allsimilarity_fp16.append(one_similarity_fp16)
|
||||
one_similarity_fp16, one_similarity_fp32 = [], []
|
||||
return np.array(allsimilarity_fp32), np.array(allsimilarity_fp16)
|
||||
|
||||
def compute_group_accuracy(content_list_read, model):
|
||||
allSimilarity, allLabel = [], []
|
||||
Same, Cross = [], []
|
||||
flag_same = True
|
||||
flag_diff = True
|
||||
for data_loaded in content_list_read:
|
||||
one_group_list = []
|
||||
try:
|
||||
if (flag_same and str(data_loaded[-1]) == '1') or (flag_diff and str(data_loaded[-1]) == '0'):
|
||||
for i in range(2):
|
||||
images = [osp.join(conf.test_val, img) for img in data_loaded[i]]
|
||||
group = group_image(images, conf.test_batch_size)
|
||||
d = featurize(group[0], conf.test_transform, model, conf.device)
|
||||
one_group_list.append(d.values())
|
||||
if str(data_loaded[-1]) == '1':
|
||||
flag_same = False
|
||||
allsimilarity_fp32, allsimilarity_fp16 = deal_group_pair(one_group_list[0], one_group_list[1])
|
||||
print('fp32 same-- >', allsimilarity_fp32)
|
||||
print('fp16 same-- >', allsimilarity_fp16)
|
||||
else:
|
||||
flag_diff = False
|
||||
allsimilarity_fp32, allsimilarity_fp16 = deal_group_pair(one_group_list[0], one_group_list[1])
|
||||
print('fp32 diff-- >', allsimilarity_fp32)
|
||||
print('fp16 diff-- >', allsimilarity_fp16)
|
||||
except Exception as e:
|
||||
continue
|
||||
# print(allSimilarity)
|
||||
# print(allLabel)
|
||||
return allSimilarity, allLabel
|
||||
def get_feature_list(imgPth):
|
||||
imgs = get_files(imgPth)
|
||||
group = group_image(imgs, conf.test_batch_size)
|
||||
model = init_model()
|
||||
model.eval()
|
||||
fe = featurize(group[0], conf.test_transform, model, conf.device)
|
||||
return fe
|
||||
|
||||
|
||||
def get_files(imgPth):
|
||||
imgsList = []
|
||||
for img in os.walk(imgPth):
|
||||
for img_name in img[2]:
|
||||
img_path = os.sep.join([img[0], img_name])
|
||||
imgsList.append(img_path)
|
||||
return imgsList
|
||||
import pdb
|
||||
|
||||
def compare(imgPth, group=False):
|
||||
model = init_model()
|
||||
model.eval()
|
||||
if not group:
|
||||
values_pf16, values_pf32 = [], []
|
||||
fe = get_feature_list(imgPth)
|
||||
# pdb.set_trace()
|
||||
values_pf32 += [value.cpu() for value in fe.values()]
|
||||
values_pf16 += [value.cpu().half() for value in fe.values()]
|
||||
for value_pf16, value_pf32 in zip(values_pf16, values_pf32):
|
||||
compare_fp16_fp32(value_pf16, value_pf32, dataTest=True)
|
||||
else:
|
||||
filename = conf.test_group_json
|
||||
with open(filename, 'r', encoding='utf-8') as file:
|
||||
content_list_read = json.load(file)
|
||||
compute_group_accuracy(content_list_read, model)
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
imgPth = './data/test/inner/3701375401900'
|
||||
compare(imgPth)
|
369
tools/gift_assessment.py
Normal file
369
tools/gift_assessment.py
Normal file
@ -0,0 +1,369 @@
|
||||
import os
|
||||
import pdb
|
||||
import shutil
|
||||
import sys
|
||||
|
||||
sys.path.append('../model')
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from model.mlp import Net2, Net3, Net4
|
||||
from model import resnet18
|
||||
import torch
|
||||
from gift_data_pretreatment import getFeatureList
|
||||
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
|
||||
def init_model(pkl_flag):
|
||||
res_pth = r"../checkpoints/resnet18_1009/best.pth"
|
||||
if pkl_flag:
|
||||
gift_pth = r'../checkpoints/gift_model/action2/gift_v11.pth'
|
||||
gift_model = Net3(pretrained=True, num_classes=1)
|
||||
gift_model.load_state_dict(torch.load(gift_pth))
|
||||
else:
|
||||
gift_pth = r'../checkpoints/gift_model/action3/best.pth'
|
||||
gift_model = Net4('resnet18', True, True) # 预训练模型
|
||||
try:
|
||||
print('>>multiple_cards load pre model <<')
|
||||
gift_model.load_state_dict({k.replace('module.', ''): v for k, v in
|
||||
torch.load(gift_pth,
|
||||
map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu')).items()})
|
||||
except Exception as e:
|
||||
print('>> load pre model <<')
|
||||
gift_model.load_state_dict(torch.load(gift_pth,
|
||||
map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu')))
|
||||
res_model = resnet18()
|
||||
res_model.load_state_dict({k.replace('module.', ''): v for k, v in
|
||||
torch.load(res_pth, map_location=torch.device(device)).items()})
|
||||
return res_model, gift_model
|
||||
|
||||
|
||||
def showHist(nongifts, gifts):
|
||||
# Same = filtered_data[:, 1].astype(np.float32)
|
||||
# Cross = filtered_data[:, 2].astype(np.float32)
|
||||
|
||||
fig, axs = plt.subplots(2, 1)
|
||||
axs[0].hist(nongifts, bins=50, edgecolor='blue')
|
||||
axs[0].set_xlim([-0.1, 1])
|
||||
axs[0].set_title('nongifts')
|
||||
|
||||
axs[1].hist(gifts, bins=50, edgecolor='green')
|
||||
axs[1].set_xlim([-0.1, 1])
|
||||
axs[1].set_title('gifts')
|
||||
# plt.savefig('plot.png')
|
||||
plt.show()
|
||||
|
||||
|
||||
def calculate_precision_recall(nongift, gift, points):
|
||||
precision, recall = [], []
|
||||
for point in points:
|
||||
TP = np.sum(gift > point)
|
||||
FN = np.sum(gift < point)
|
||||
FP = np.sum(nongift > point)
|
||||
TN = np.sum(nongift < point)
|
||||
if TP == 0:
|
||||
precision.append(0)
|
||||
recall.append(0)
|
||||
else:
|
||||
precision.append(TP / (TP + FP))
|
||||
recall.append(TP / (TP + FN))
|
||||
print("point >> {} TP>>{}, FP>>{}, TN>>{}, FN>>{}".format(point, TP, FP, TN, FN))
|
||||
if point == 0.5:
|
||||
print("point >> {} TP>>{}, FP>>{}, TN>>{}, FN>>{}".format(point, TP, FP, TN, FN))
|
||||
return precision, recall
|
||||
|
||||
|
||||
def showgrid(all_prec, all_recall, points):
|
||||
plt.figure(figsize=(10, 6))
|
||||
plt.plot(points[:-1], all_prec[:-1], color='blue', label='precision')
|
||||
plt.plot(points[:-1], all_recall[:-1], color='red', label='recall')
|
||||
plt.legend()
|
||||
plt.xlabel('threshold')
|
||||
# plt.ylabel('Similarity')
|
||||
plt.grid(True, linestyle='--', alpha=0.5)
|
||||
# plt.savefig('grid.png')
|
||||
plt.show()
|
||||
plt.close()
|
||||
pass
|
||||
|
||||
|
||||
def discriminate_action(roots): # 判断加购还是退购
|
||||
pth = os.sep.join([roots, 'process.data'])
|
||||
with open(pth, 'r') as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
content = line.strip()
|
||||
if 'weightValue' in content:
|
||||
# print(content.split(":")[-1].split(',')[0])
|
||||
if int(content.split(":")[-1].split(',')[0]) > 0:
|
||||
return 'add'
|
||||
else:
|
||||
return 'return'
|
||||
|
||||
|
||||
def median(lst):
|
||||
sorted_lst = sorted(lst)
|
||||
n = len(sorted_lst)
|
||||
if n % 2 == 1:
|
||||
# 如果列表长度是奇数,中位数是中间的那个元素
|
||||
return sorted_lst[n // 2]
|
||||
else:
|
||||
# 如果列表长度是偶数,中位数是中间两个元素的平均值
|
||||
mid1 = sorted_lst[(n // 2) - 1]
|
||||
mid2 = sorted_lst[n // 2]
|
||||
return (mid1 + mid2) / 2
|
||||
|
||||
|
||||
def get_special_data(data, p):
|
||||
# print(data)
|
||||
length = len(data)
|
||||
if length > 5:
|
||||
if p == 'max':
|
||||
return max(data[:round(length * 0.5)])
|
||||
elif p == 'average':
|
||||
return sum(data[:round(length * 0.5)]) / len(data[:round(length * 0.5)])
|
||||
elif p == 'median':
|
||||
return median(data[:round(length * 0.5)])
|
||||
else:
|
||||
return sum(data) / len(data)
|
||||
|
||||
|
||||
def read_data_file(pth):
|
||||
result = []
|
||||
with open(pth, 'r') as data_file:
|
||||
lines = data_file.readlines()
|
||||
for line in lines:
|
||||
if line.split(':')[0] == 'free_gift__result':
|
||||
if '0_tracking_output.data' in pth:
|
||||
result = line.split(':')[1].split(',')[:-1]
|
||||
else:
|
||||
result = line.split(':')[1].split(',')[:-2]
|
||||
result = [float(i) for i in result]
|
||||
return result
|
||||
|
||||
|
||||
def get_tracking_data(pth):
|
||||
result = []
|
||||
with open(pth, 'r') as data_file:
|
||||
lines = data_file.readlines()
|
||||
for line in lines:
|
||||
if len(line.split(',')) == 65:
|
||||
result.append([float(item) for item in line.split(',')[:-1]])
|
||||
return result
|
||||
|
||||
|
||||
def clean_reurn_data(pth):
|
||||
for roots, dirs, files in os.walk(pth):
|
||||
# print(roots, dirs, files)
|
||||
if len(dirs) == 0:
|
||||
flag = discriminate_action(roots)
|
||||
if flag == 'return':
|
||||
shutil.rmtree(roots)
|
||||
|
||||
|
||||
def get_gift_files(pth): # 测试后直接分析测试结果文件
|
||||
add_special_output_0, return_special_output_0, return_special_output_1, add_special_output_1 = [], [], [], []
|
||||
add_tracking_output_0, return_tracking_output_0, add_tracking_output_1, return_tracking_output_1 = [], [], [], []
|
||||
for roots, dirs, files in os.walk(pth):
|
||||
# print(roots, dirs, files)
|
||||
if len(dirs) == 0:
|
||||
flag = discriminate_action(roots)
|
||||
for file in files:
|
||||
if file == '0_tracking_output.data':
|
||||
result = read_data_file(os.path.join(roots, file))
|
||||
if not len(result) == 0:
|
||||
if flag == 'add':
|
||||
add_special_output_0.append(get_special_data(result, 'average')) # 加购后摄
|
||||
else:
|
||||
return_special_output_0.append(get_special_data(result, 'average')) # 退购后摄
|
||||
if flag == 'add':
|
||||
add_tracking_output_0 += read_data_file(os.path.join(roots, file))
|
||||
else:
|
||||
return_tracking_output_0 += read_data_file(os.path.join(roots, file))
|
||||
elif file == '1_tracking_output.data':
|
||||
result = read_data_file(os.path.join(roots, file))
|
||||
if not len(result) == 0:
|
||||
if flag == 'add':
|
||||
add_special_output_1.append(get_special_data(result, 'average')) # 加购前摄
|
||||
else:
|
||||
return_special_output_1.append(get_special_data(result, 'average')) # 退购前摄
|
||||
if flag == 'add':
|
||||
add_tracking_output_1 += read_data_file(os.path.join(roots, file))
|
||||
else:
|
||||
return_tracking_output_1 += read_data_file(os.path.join(roots, file))
|
||||
comprehensive_dicts = {"add_special_output_0": add_special_output_0,
|
||||
"return_special_output_0": return_special_output_0,
|
||||
"add_tracking_output_0": add_tracking_output_0,
|
||||
"return_tracking_output_0": return_tracking_output_0,
|
||||
"add_special_output_1": add_special_output_1,
|
||||
"return_special_output_1": return_special_output_1,
|
||||
"add_tracking_output_1": add_tracking_output_1,
|
||||
"return_tracking_output_1": return_tracking_output_1,
|
||||
}
|
||||
# print(tracking_output_0, tracking_output_1)
|
||||
showHist(np.array(comprehensive_dicts['add_tracking_output_0']),
|
||||
np.array(comprehensive_dicts['add_tracking_output_1']))
|
||||
# showHist(np.array(comprehensive_dicts['add_special_output_0']),
|
||||
# np.array(comprehensive_dicts['add_special_output_1']))
|
||||
return comprehensive_dicts
|
||||
|
||||
|
||||
def get_feature_array(img_pth_lists, res_model, gift_model, pkl_flag=True):
|
||||
features_np = []
|
||||
if pkl_flag:
|
||||
for img_lists in img_pth_lists:
|
||||
# print(img_lists)
|
||||
fe_nps = getFeatureList(None, img_lists, res_model)
|
||||
# fe_nps.squeeze()
|
||||
try:
|
||||
fe_nps = fe_nps[0][:, 256:]
|
||||
except Exception as e:
|
||||
print(e)
|
||||
continue
|
||||
fe_nps = torch.from_numpy(fe_nps)
|
||||
fe_nps = fe_nps.view(fe_nps.shape[0], 64, 13, 13)
|
||||
if len(fe_nps):
|
||||
fe_np = gift_model(fe_nps)
|
||||
fe_np = np.squeeze(fe_np.detach().numpy())
|
||||
features_np.append(fe_np)
|
||||
else:
|
||||
for img_lists in img_pth_lists:
|
||||
fe_nps = getFeatureList(None, img_lists, gift_model)
|
||||
if len(fe_nps) > 0:
|
||||
fe_nps = np.concatenate(fe_nps)
|
||||
features_np.append(fe_nps)
|
||||
return features_np
|
||||
|
||||
|
||||
import pickle
|
||||
|
||||
|
||||
def create_gift_subimg_np(data_pth, pkl_flag):
|
||||
gift_array_pth = os.path.join(data_pth, 'gift.pkl')
|
||||
nongift_array_pth = os.path.join(data_pth, 'nongift.pkl')
|
||||
res_model, gift_model = init_model(pkl_flag)
|
||||
res_model = res_model.eval()
|
||||
gift_model = gift_model.eval()
|
||||
gift_img_pth_list, gift_lists, nongift_img_pth_list, nongift_lists = [], [], [], []
|
||||
|
||||
for root, dirs, files in os.walk(data_pth):
|
||||
if ('commodity' in root and 'subimg' in root):
|
||||
print("commodity >> {}".format(root))
|
||||
for file in files:
|
||||
nongift_img_pth_list.append(os.sep.join([root, file]))
|
||||
nongift_lists.append(nongift_img_pth_list)
|
||||
nongift_img_pth_list = []
|
||||
elif ('Havegift' in root and 'subimg' in root):
|
||||
print("Havegift >> {}".format(root))
|
||||
for file in files:
|
||||
gift_img_pth_list.append(os.sep.join([root, file]))
|
||||
gift_lists.append(gift_img_pth_list)
|
||||
gift_img_pth_list = []
|
||||
nongift = get_feature_array(nongift_lists, res_model, gift_model, pkl_flag)
|
||||
gift = get_feature_array(gift_lists, res_model, gift_model, pkl_flag)
|
||||
with open(nongift_array_pth, 'wb') as file:
|
||||
pickle.dump(nongift, file)
|
||||
with open(gift_array_pth, 'wb') as file:
|
||||
pickle.dump(gift, file)
|
||||
|
||||
|
||||
def top_25_percent_mean(arr):
|
||||
# 1. 对数组进行从高到低排序
|
||||
sorted_arr = np.sort(arr)[::-1]
|
||||
|
||||
# 2. 计算数组长度的25%
|
||||
top_25_percent_length = int(len(sorted_arr) * 0.25)
|
||||
|
||||
# 3. 取排序后数组的前25%元素
|
||||
top_25_percent = sorted_arr[:top_25_percent_length]
|
||||
|
||||
# 4. 计算这些元素的平均值
|
||||
mean_value = np.mean(top_25_percent)
|
||||
|
||||
return top_25_percent
|
||||
|
||||
|
||||
def assess_gift_subimg(data_pth, pkl_flag=False): # 分析分割后子图,
|
||||
points = (np.linspace(1, 100, 100)) / 100
|
||||
gift_pkl_pth = os.path.join(data_pth, 'gift.pkl')
|
||||
nongift_pkl_pth = os.path.join(data_pth, 'nongift.pkl')
|
||||
if not os.path.exists(gift_pkl_pth):
|
||||
create_gift_subimg_np(data_pth, pkl_flag)
|
||||
with open(nongift_pkl_pth, 'rb') as f:
|
||||
nongift = pickle.load(f)
|
||||
with open(gift_pkl_pth, 'rb') as f:
|
||||
gift = pickle.load(f)
|
||||
# showHist(nongift.flatten(), gift.flatten())
|
||||
|
||||
'''
|
||||
一分位均值
|
||||
'''
|
||||
nongift_mean = [np.mean(top_25_percent_mean(items)) for items in nongift]
|
||||
gift_mean = [np.mean(top_25_percent_mean(items)) for items in gift]
|
||||
'''
|
||||
中位数
|
||||
'''
|
||||
# nongift_mean = [np.median(items) for items in nongift]
|
||||
# gift_mean = [np.median(items) for items in gift] # 平均值
|
||||
|
||||
'''
|
||||
全部结果
|
||||
'''
|
||||
# nongifts = [items for items in nongift]
|
||||
# gifts = [items for items in gift]
|
||||
# showHist(nongifts, gifts)
|
||||
|
||||
'''
|
||||
平均值
|
||||
'''
|
||||
# nongift_mean = [np.mean(items) for items in nongift]
|
||||
# gift_mean = [np.mean(items) for items in gift]
|
||||
|
||||
showHist(np.array(nongift_mean), np.array(gift_mean)) # 最大值
|
||||
precision, recall = calculate_precision_recall(np.array(nongift_mean),
|
||||
np.array(gift_mean),
|
||||
points)
|
||||
showgrid(precision, recall, points)
|
||||
|
||||
|
||||
def get_comprehensive_dicts(data_pth):
|
||||
gift_pth = r'../checkpoints/gift_model/action2/best.pth'
|
||||
g_model = Net3(pretrained=True, num_classes=1)
|
||||
g_model.load_state_dict(torch.load(gift_pth))
|
||||
g_model.eval()
|
||||
result = []
|
||||
file_name = ['0_tracking_output.data',
|
||||
'1_tracking_output.data']
|
||||
for root, dirs, files in os.walk(data_pth):
|
||||
if not len(dirs):
|
||||
for file in files:
|
||||
if file in file_name:
|
||||
print(os.path.join(root, file))
|
||||
result += get_tracking_data(os.path.join(root, file))
|
||||
result = torch.from_numpy(np.array(result))
|
||||
input = result.view(result.shape[0], 64, 1, 1)
|
||||
input = input.to('cpu')
|
||||
input = input.to(torch.float32)
|
||||
ji = g_model(input)
|
||||
print(ji)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# pth = r'\\192.168.1.28\\share\\测试视频数据以及日志\\各模块测试记录\\赠品测试\\20241203赠品测试数据\\赠品\\images'
|
||||
# pth = r'\\192.168.1.28\\share\\测试视频数据以及日志\\各模块测试记录\\赠品测试\\20241203赠品测试数据\\没有赠品的商品\\images'
|
||||
# pth = r'\\192.168.1.28\\share\\测试视频数据以及日志\\各模块测试记录\\赠品测试\\20241203赠品测试数据\\同样的商品没有捆绑赠品\\images'
|
||||
# pth = r'\\192.168.1.28\\share\\测试视频数据以及日志\\各模块测试记录\\赠品测试\\20241213赠品测试数据\\赠品'
|
||||
# pth = r'C:\Users\HP\Desktop\zengpin\1227'
|
||||
# get_gift_files(pth)
|
||||
|
||||
# 根据子图分析结果
|
||||
pth = r'D:\Project\contrast_nettest\data\gift_test'
|
||||
assess_gift_subimg(pth)
|
||||
|
||||
# 根据完整数据集分析结果
|
||||
# pth = r'C:\Users\HP\Desktop\zengpin\1231'
|
||||
# get_comprehensive_dicts(pth)
|
||||
|
||||
# 删除退购视频
|
||||
# pth = r'C:\Users\HP\Desktop\gift_test\20241213\非赠品'
|
||||
# clean_reurn_data(pth)
|
92
tools/gift_data_pretreatment.py
Normal file
92
tools/gift_data_pretreatment.py
Normal file
@ -0,0 +1,92 @@
|
||||
import torch
|
||||
from config import config as conf
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
|
||||
def convert_rgba_to_rgb(image_path, output_path=None):
|
||||
"""
|
||||
将给定路径的4通道PNG图像转换为3通道,并保存到指定输出路径。
|
||||
|
||||
:param image_path: 输入图像的路径
|
||||
:param output_path: 转换后的图像保存路径
|
||||
"""
|
||||
# 打开图像
|
||||
img = Image.open(image_path)
|
||||
# 转换图像模式从RGBA到RGB
|
||||
# .convert('RGB')会丢弃Alpha通道并转换为纯RGB图像
|
||||
if img.mode == 'RGBA':
|
||||
# 转换为RGB模式
|
||||
img_rgb = img.convert('RGB')
|
||||
# 保存转换后的图像
|
||||
img_rgb.save(image_path)
|
||||
# print(f"Image converted from RGBA to RGB and saved to {image_path}")
|
||||
# else:
|
||||
# # 如果已经是RGB或其他模式,直接保存
|
||||
# img.save(image_path)
|
||||
# print(f"Image already in {img.mode} mode, saved to {image_path}")
|
||||
|
||||
|
||||
def test_preprocess(images: list, actionModel=False) -> torch.Tensor:
|
||||
res = []
|
||||
for img in images:
|
||||
try:
|
||||
# print(img)
|
||||
im = conf.test_transform(img) if actionModel else conf.test_transform(Image.open(img))
|
||||
res.append(im)
|
||||
except:
|
||||
continue
|
||||
data = torch.stack(res)
|
||||
return data
|
||||
|
||||
|
||||
def inference(images, model, actionModel=False):
|
||||
data = test_preprocess(images, actionModel)
|
||||
if torch.cuda.is_available():
|
||||
data = data.to(conf.device)
|
||||
features = model(data)
|
||||
return features
|
||||
|
||||
|
||||
def group_image(images, batch=64) -> list:
|
||||
"""Group image paths by batch size"""
|
||||
size = len(images)
|
||||
res = []
|
||||
for i in range(0, size, batch):
|
||||
end = min(batch + i, size)
|
||||
res.append(images[i:end])
|
||||
return res
|
||||
|
||||
def normalize(queFeatList):
|
||||
for num1 in range(len(queFeatList)):
|
||||
for num2 in range(len(queFeatList[num1])):
|
||||
queFeatList[num1][num2] = queFeatList[num1][num2] / np.linalg.norm(queFeatList[num1][num2])
|
||||
return queFeatList
|
||||
|
||||
def getFeatureList(barList, imgList, model):
|
||||
# featList = [[] for i in range(len(barList))]
|
||||
# for index, feat in enumerate(imgList):
|
||||
fe_nps = []
|
||||
groups = group_image(imgList)
|
||||
for group in groups:
|
||||
feat_tensor = inference(group, model)
|
||||
# for fe in feat_tensor:
|
||||
if feat_tensor.device == 'cpu':
|
||||
fe_np = feat_tensor.squeeze().detach().numpy()
|
||||
# fe_np = fe_np[:, 256:]
|
||||
# fe_np = fe_np.reshape(fe_np.shape[0], fe_np.shape[1], 1, 1)
|
||||
else:
|
||||
fe_np = feat_tensor.squeeze().detach().cpu().numpy()
|
||||
# fe_np = fe_np[:, 256:]
|
||||
# fe_np = fe_np[256:]
|
||||
# fe_np = fe_np.reshape(fe_np.shape[0], fe_np.shape[1], 1, 1)
|
||||
# fe_np = fe_np.reshape(1, fe_np.shape[0], 1, 1)
|
||||
# print(fe_np)
|
||||
|
||||
fe_nps.append(fe_np)
|
||||
# if fe_nps:
|
||||
# merged_fe_np = np.concatenate(fe_nps, axis=0)
|
||||
# else:
|
||||
# merged_fe_np = np.array([]) #
|
||||
# fe_list = normalize(fe_nps)
|
||||
return fe_nps
|
118
tools/json_contrast.py
Normal file
118
tools/json_contrast.py
Normal file
@ -0,0 +1,118 @@
|
||||
import json
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import random
|
||||
|
||||
|
||||
def showHist(same, cross):
|
||||
Same = np.array(same)
|
||||
Cross = np.array(cross)
|
||||
|
||||
fig, axs = plt.subplots(2, 1)
|
||||
axs[0].hist(Same, bins=50, edgecolor='black')
|
||||
axs[0].set_xlim([-0.1, 1])
|
||||
axs[0].set_title('Same Barcode')
|
||||
|
||||
axs[1].hist(Cross, bins=50, edgecolor='black')
|
||||
axs[1].set_xlim([-0.1, 1])
|
||||
axs[1].set_title('Cross Barcode')
|
||||
# plt.savefig('plot.png')
|
||||
plt.show()
|
||||
|
||||
|
||||
def showgrid(recall, recall_TN, PrecisePos, PreciseNeg, Correct):
|
||||
x = np.linspace(start=0, stop=1.0, num=50, endpoint=True).tolist()
|
||||
plt.figure(figsize=(10, 6))
|
||||
plt.plot(x, recall, color='red', label='recall:TP/TPFN')
|
||||
plt.plot(x, recall_TN, color='black', label='recall_TN:TN/TNFP')
|
||||
plt.plot(x, PrecisePos, color='blue', label='PrecisePos:TP/TPFN')
|
||||
plt.plot(x, PreciseNeg, color='green', label='PreciseNeg:TN/TNFP')
|
||||
plt.plot(x, Correct, color='m', label='Correct:(TN+TP)/(TPFN+TNFP)')
|
||||
plt.legend()
|
||||
plt.xlabel('threshold')
|
||||
# plt.ylabel('Similarity')
|
||||
plt.grid(True, linestyle='--', alpha=0.5)
|
||||
plt.savefig('grid.png')
|
||||
plt.show()
|
||||
plt.close()
|
||||
|
||||
|
||||
def compute_accuracy_recall(score, labels):
|
||||
th = 0.1
|
||||
squence = np.linspace(-1, 1, num=50)
|
||||
recall, PrecisePos, PreciseNeg, recall_TN, Correct = [], [], [], [], []
|
||||
Same = score[:len(score) // 2]
|
||||
Cross = score[len(score) // 2:]
|
||||
for th in squence:
|
||||
t_score = (score > th)
|
||||
t_labels = (labels == 1)
|
||||
TP = np.sum(np.logical_and(t_score, t_labels))
|
||||
FN = np.sum(np.logical_and(np.logical_not(t_score), t_labels))
|
||||
f_score = (score < th)
|
||||
f_labels = (labels == 0)
|
||||
TN = np.sum(np.logical_and(f_score, f_labels))
|
||||
FP = np.sum(np.logical_and(np.logical_not(f_score), f_labels))
|
||||
print("Threshold:{} TP:{},FP:{},TN:{},FN:{}".format(th, TP, FP, TN, FN))
|
||||
|
||||
PrecisePos.append(0 if TP / (TP + FP) == 'nan' else TP / (TP + FP))
|
||||
PreciseNeg.append(0 if TN == 0 else TN / (TN + FN))
|
||||
recall.append(0 if TP == 0 else TP / (TP + FN))
|
||||
recall_TN.append(0 if TN == 0 else TN / (TN + FP))
|
||||
Correct.append(0 if TP == 0 else (TP + TN) / (TP + FP + TN + FN))
|
||||
|
||||
showHist(Same, Cross)
|
||||
showgrid(recall, recall_TN, PrecisePos, PreciseNeg, Correct)
|
||||
|
||||
|
||||
def get_similarity(features1, features2, n, m):
|
||||
features1 = np.array(features1)
|
||||
features2 = np.array(features2)
|
||||
all_similarity = []
|
||||
for feature1 in features1:
|
||||
for feature2 in features2:
|
||||
similarity = np.dot(feature1, feature2) / (np.linalg.norm(feature1) * np.linalg.norm(feature2))
|
||||
all_similarity.append(similarity)
|
||||
test_similarity = np.array(all_similarity)
|
||||
np_all_array = np.array(all_similarity).reshape(len(features1), len(features2))
|
||||
if n == 5 and m == 5:
|
||||
print(all_similarity)
|
||||
return np.mean(np_all_array), all_similarity
|
||||
# return sum(all_similarity)/len(all_similarity), all_similarity
|
||||
# return max(all_similarity), all_similarity
|
||||
|
||||
|
||||
def deal_similarity(dicts):
|
||||
all_similarity = []
|
||||
similarity = []
|
||||
same_barcode, diff_barcode = [], []
|
||||
for n, (key1, value1) in enumerate(dicts.items()):
|
||||
print('key1 >> {}'.format(key1))
|
||||
for m, (key2, value2) in enumerate(dicts.items()):
|
||||
print('key1 >> {} key2 >> {} peidui {}{}'.format(key1, key2, n, m))
|
||||
max_similarity, some_similarity = get_similarity(value1, value2, n, m)
|
||||
similarity.append(max_similarity)
|
||||
if key1 == key2:
|
||||
same_barcode += some_similarity
|
||||
else:
|
||||
diff_barcode += some_similarity
|
||||
all_similarity.append(similarity)
|
||||
similarity = []
|
||||
all_similarity = np.array(all_similarity)
|
||||
random.shuffle(diff_barcode)
|
||||
same_list = [1] * len(same_barcode)
|
||||
diff_list = [0] * len(same_barcode)
|
||||
all_list = same_list + diff_list
|
||||
all_score = same_barcode + diff_barcode[:len(same_barcode)]
|
||||
compute_accuracy_recall(np.array(all_score), np.array(all_list))
|
||||
print(all_similarity.shape)
|
||||
|
||||
|
||||
with open('../search_library/data_zhanting.json', 'r') as file:
|
||||
data = json.load(file)
|
||||
dicts = {}
|
||||
for dict in data['total']:
|
||||
key = dict['key']
|
||||
value = dict['value']
|
||||
dicts[key] = value
|
||||
deal_similarity(dicts)
|
63
tools/model_onnx_transform.py
Normal file
63
tools/model_onnx_transform.py
Normal file
@ -0,0 +1,63 @@
|
||||
import pdb
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from model import resnet18
|
||||
from config import config as conf
|
||||
from collections import OrderedDict
|
||||
import cv2
|
||||
|
||||
def tranform_onnx_model(model_name, pretrained_weights='checkpoints/v3_small.pth'):
|
||||
# 定义模型
|
||||
if model_name == 'resnet18':
|
||||
model = resnet18(scale=0.75)
|
||||
|
||||
print('model_name >>> {}'.format(model_name))
|
||||
if conf.multiple_cards:
|
||||
model = model.to(torch.device('cpu'))
|
||||
checkpoint = torch.load(pretrained_weights)
|
||||
new_state_dict = OrderedDict()
|
||||
for k, v in checkpoint.items():
|
||||
name = k[7:] # remove "module."
|
||||
new_state_dict[name] = v
|
||||
model.load_state_dict(new_state_dict)
|
||||
else:
|
||||
model.load_state_dict(torch.load(pretrained_weights, map_location=torch.device('cpu')))
|
||||
# try:
|
||||
# model.load_state_dict(torch.load(pretrained_weights, map_location=torch.device('cpu')))
|
||||
# except Exception as e:
|
||||
# print(e)
|
||||
# # model.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(pretrained_weights, map_location='cpu').items()})
|
||||
# model = nn.DataParallel(model).to(conf.device)
|
||||
# model.load_state_dict(torch.load(conf.test_model, map_location=torch.device('cpu')))
|
||||
|
||||
|
||||
# 转换为ONNX
|
||||
if model_name == 'gift_type2':
|
||||
input_shape = [1, 64, 13, 13]
|
||||
elif model_name == 'gift_type3':
|
||||
input_shape = [1, 3, 224, 224]
|
||||
else:
|
||||
# 假设输入数据的大小是通道数*高度*宽度,例如3*224*224
|
||||
input_shape = [1, 3, 224, 224]
|
||||
|
||||
img = cv2.imread('./dog_224x224.jpg')
|
||||
|
||||
output_file = pretrained_weights.replace('pth', 'onnx')
|
||||
|
||||
# 导出模型
|
||||
torch.onnx.export(model,
|
||||
torch.randn(input_shape),
|
||||
output_file,
|
||||
verbose=True,
|
||||
input_names=['input'],
|
||||
output_names=['output']) ##, optset_version=12
|
||||
|
||||
model.eval()
|
||||
trace_model = torch.jit.trace(model, torch.randn(1, 3, 224, 224))
|
||||
trace_model.save(output_file.replace('.onnx', '.pt'))
|
||||
print(f"Model exported to {output_file}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tranform_onnx_model(model_name='resnet18', # ['resnet18', 'gift_type2', 'gift_type3'] #gift_type2指resnet18中间数据判断;gift3_type3指resnet原图计算推理
|
||||
pretrained_weights='./checkpoints/resnet18_scale=1.0/best.pth')
|
186
tools/model_rknn_transform.py
Normal file
186
tools/model_rknn_transform.py
Normal file
@ -0,0 +1,186 @@
|
||||
import os
|
||||
import pdb
|
||||
import urllib
|
||||
import traceback
|
||||
import time
|
||||
import sys
|
||||
import numpy as np
|
||||
import cv2
|
||||
from config import config as conf
|
||||
from rknn.api import RKNN
|
||||
|
||||
import config
|
||||
|
||||
# ONNX_MODEL = 'resnet50v2.onnx'
|
||||
# RKNN_MODEL = 'resnet50v2.rknn'
|
||||
ONNX_MODEL = 'checkpoints/resnet18_scale=1.0/best.onnx'
|
||||
RKNN_MODEL = 'checkpoints/resnet18_scale=1.0/best.rknn'
|
||||
|
||||
|
||||
# ONNX_MODEL = 'v3_small_0424.onnx'
|
||||
# RKNN_MODEL = 'v3_small_0424.rknn'
|
||||
|
||||
def show_outputs(outputs):
|
||||
# print('***************outputs', outputs)
|
||||
output = outputs[0][0]
|
||||
# print('len(outputs)',len(output), output)
|
||||
output_sorted = sorted(output, reverse=True)
|
||||
top5_str = 'resnet50v2\n-----TOP 5-----\n'
|
||||
for i in range(5):
|
||||
value = output_sorted[i]
|
||||
index = np.where(output == value)
|
||||
for j in range(len(index)):
|
||||
if (i + j) >= 5:
|
||||
break
|
||||
if value > 0:
|
||||
topi = '{}: {}\n'.format(index[j], value)
|
||||
else:
|
||||
topi = '-1: 0.0\n'
|
||||
top5_str += topi
|
||||
# pdb.set_trace()
|
||||
print(top5_str)
|
||||
|
||||
|
||||
def readable_speed(speed):
|
||||
speed_bytes = float(speed)
|
||||
speed_kbytes = speed_bytes / 1024
|
||||
if speed_kbytes > 1024:
|
||||
speed_mbytes = speed_kbytes / 1024
|
||||
if speed_mbytes > 1024:
|
||||
speed_gbytes = speed_mbytes / 1024
|
||||
return "{:.2f} GB/s".format(speed_gbytes)
|
||||
else:
|
||||
return "{:.2f} MB/s".format(speed_mbytes)
|
||||
else:
|
||||
return "{:.2f} KB/s".format(speed_kbytes)
|
||||
|
||||
|
||||
def show_progress(blocknum, blocksize, totalsize):
|
||||
speed = (blocknum * blocksize) / (time.time() - start_time)
|
||||
speed_str = " Speed: {}".format(readable_speed(speed))
|
||||
recv_size = blocknum * blocksize
|
||||
|
||||
f = sys.stdout
|
||||
progress = (recv_size / totalsize)
|
||||
progress_str = "{:.2f}%".format(progress * 100)
|
||||
n = round(progress * 50)
|
||||
s = ('#' * n).ljust(50, '-')
|
||||
f.write(progress_str.ljust(8, ' ') + '[' + s + ']' + speed_str)
|
||||
f.flush()
|
||||
f.write('\r\n')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
# Create RKNN object
|
||||
rknn = RKNN(verbose=True)
|
||||
|
||||
# If resnet50v2 does not exist, download it.
|
||||
# Download address:
|
||||
# https://s3.amazonaws.com/onnx-model-zoo/resnet/resnet50v2/resnet50v2.onnx
|
||||
if not os.path.exists(ONNX_MODEL):
|
||||
print('--> Download {}'.format(ONNX_MODEL))
|
||||
url = 'https://s3.amazonaws.com/onnx-model-zoo/resnet/resnet50v2/resnet50v2.onnx'
|
||||
download_file = ONNX_MODEL
|
||||
try:
|
||||
start_time = time.time()
|
||||
urllib.request.urlretrieve(url, download_file, show_progress)
|
||||
except:
|
||||
print('Download {} failed.'.format(download_file))
|
||||
print(traceback.format_exc())
|
||||
exit(-1)
|
||||
print('done')
|
||||
|
||||
# pre-process config
|
||||
print('--> config model')
|
||||
# rknn.config(mean_values=[123.675, 116.28, 103.53], std_values=[58.82, 58.82, 58.82])
|
||||
rknn.config(
|
||||
mean_values=[[127.5, 127.5, 127.5]],
|
||||
std_values=[[127.5, 127.5, 127.5]],
|
||||
target_platform='rk3588',
|
||||
model_pruning=False,
|
||||
compress_weight=False,
|
||||
single_core_mode=True)
|
||||
# rknn.config(
|
||||
# mean_values=[[127.5, 127.5, 127.5]], # 对于单通道图像,可以设置为 [[127.5]]
|
||||
# std_values=[[127.5, 127.5, 127.5]], # 对于单通道图像,可以设置为 [[127.5]]
|
||||
# target_platform='rk3588', # 设置目标平台
|
||||
# # quantize_dtype='int8',
|
||||
# # quantize_algo='normal',
|
||||
# # output_optimize=False,
|
||||
# # output_format='rknnb'
|
||||
# )
|
||||
print('done')
|
||||
|
||||
# Load model
|
||||
print('--> Loading model')
|
||||
ret = rknn.load_onnx(model=ONNX_MODEL)
|
||||
if ret != 0:
|
||||
print('Load model failed!')
|
||||
exit(ret)
|
||||
print('done')
|
||||
|
||||
# Build model
|
||||
print('--> Building model')
|
||||
ret = rknn.build(do_quantization=True, dataset='./dataset.txt')
|
||||
# ret = rknn.build(do_quantization=False, dataset='./dataset.txt')
|
||||
if ret != 0:
|
||||
print('Build model failed!')
|
||||
exit(ret)
|
||||
print('done')
|
||||
|
||||
# Export rknn model
|
||||
print('--> Export rknn model')
|
||||
ret = rknn.export_rknn(RKNN_MODEL)
|
||||
if ret != 0:
|
||||
print('Export rknn model failed!')
|
||||
exit(ret)
|
||||
print('done')
|
||||
|
||||
# Set inputs
|
||||
img = cv2.imread('./dog_224x224.jpg')
|
||||
# img = cv2.imread('./data/gift_test/Havegift/20241213-161415-cb8e0762-f376-45d1-8f36-7dc070990fa5/subimg/cam1_9_tid2_fid(18, 33250169482).png')
|
||||
# print('img', img)
|
||||
# with open('pixel_values.txt', 'w') as file:
|
||||
|
||||
# for y in range(img.shape[0]):
|
||||
# for x in range(img.shape[1]):
|
||||
# b, g, r = img[y, x]
|
||||
# file.write(f'{r},{g},{b}\n')
|
||||
|
||||
# img = cv2.imread('./810115161912_810115161912_20240131-145622_0da14e4d-a3da-499f-b512-2d4168ab1c87_front_addGood_70f75407b7ae_29_01.jpg')
|
||||
img = cv2.resize(img, (224, 224))
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# img = conf.test_transform(img)
|
||||
# img = img.numpy()
|
||||
# img = img.transpose(1, 2, 0)
|
||||
|
||||
# Init runtime environment
|
||||
print('--> Init runtime environment')
|
||||
ret = rknn.init_runtime()
|
||||
# ret = rknn.init_runtime('rk3588')
|
||||
if ret != 0:
|
||||
print('Init runtime environment failed!')
|
||||
exit(ret)
|
||||
print('done')
|
||||
|
||||
# Inference
|
||||
print('--> Running model')
|
||||
T1 = time.time()
|
||||
outputs = rknn.inference(inputs=[img])
|
||||
# outputs = rknn.inference(inputs=img)
|
||||
T2 = time.time()
|
||||
print('消耗时间 >>> {}'.format(T2 - T1))
|
||||
with open('result_0415_128.txt', 'a') as f:
|
||||
f.write(str(outputs))
|
||||
# pdb.set_trace()
|
||||
print('***outputs', outputs)
|
||||
np.save('./onnx_resnet50v2_0.npy', outputs[0])
|
||||
x = outputs[0]
|
||||
output = np.exp(x) / np.sum(np.exp(x))
|
||||
outputs = [output]
|
||||
show_outputs(outputs)
|
||||
print('done')
|
||||
|
||||
rknn.release()
|
233
tools/operate_usearch.py
Normal file
233
tools/operate_usearch.py
Normal file
@ -0,0 +1,233 @@
|
||||
import os
|
||||
import numpy as np
|
||||
from usearch.index import Index
|
||||
import json
|
||||
import struct
|
||||
|
||||
|
||||
def create_index():
|
||||
index = Index(
|
||||
ndim=256,
|
||||
metric='cos',
|
||||
# dtype='f32',
|
||||
dtype='f16',
|
||||
connectivity=32,
|
||||
expansion_add=40, # 128,
|
||||
expansion_search=10, # 64,
|
||||
multi=True
|
||||
)
|
||||
return index
|
||||
|
||||
|
||||
def compare_feature(features1, features2, model='1'):
|
||||
"""
|
||||
:param model 比对策略
|
||||
'0':模拟一个轨迹的图像(所有的图像、或者挑选的若干图像)与标准库,先求每个图片与标准库的最大值,再求所有图片对应最大值的均值
|
||||
'1':带对比的所有相似度的均值
|
||||
'2':比对1:1的最大值
|
||||
:param feature1:
|
||||
:param feature2:
|
||||
:return:
|
||||
"""
|
||||
similarity_group, similarity_groups = [], []
|
||||
if model == '0':
|
||||
for feature1 in features1:
|
||||
for feature2 in features2[0]:
|
||||
similarity = np.dot(feature1, feature2) / (np.linalg.norm(feature1) * np.linalg.norm(feature2))
|
||||
similarity_group.append(similarity)
|
||||
similarity_groups.append(max(similarity_group))
|
||||
similarity_group = []
|
||||
return sum(similarity_groups) / len(similarity_groups)
|
||||
|
||||
elif model == '1':
|
||||
feature2 = features2[0]
|
||||
for feature1 in features1:
|
||||
for num in range(len(feature2)):
|
||||
similarity = np.dot(feature1, feature2[num]) / (
|
||||
np.linalg.norm(feature1) * np.linalg.norm(feature2[num]))
|
||||
similarity_group.append(similarity)
|
||||
similarity_groups.append(sum(similarity_group) / len(similarity_group))
|
||||
similarity_group = []
|
||||
# return sum(similarity_groups)/len(similarity_groups), max(similarity_groups)
|
||||
if len(similarity_groups) == 0:
|
||||
return -1
|
||||
return sum(similarity_groups) / len(similarity_groups)
|
||||
elif model == '2':
|
||||
feature2 = features2[0]
|
||||
for feature1 in features1:
|
||||
for num in range(len(feature2)):
|
||||
similarity = np.dot(feature1, feature2[num]) / (
|
||||
np.linalg.norm(feature1) * np.linalg.norm(feature2[num]))
|
||||
similarity_group.append(similarity)
|
||||
return max(similarity_group)
|
||||
|
||||
def get_barcode_feature(data):
|
||||
barcode = data['key']
|
||||
features = data['value']
|
||||
return [barcode] * len(features), features
|
||||
|
||||
|
||||
def analysis_file(file_path):
|
||||
"""
|
||||
:param file_path:
|
||||
:return:
|
||||
"""
|
||||
barcodes, features = [], []
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
for dic in data['total']:
|
||||
barcode, feature = get_barcode_feature(dic)
|
||||
barcodes.append(barcode)
|
||||
features.append(feature)
|
||||
return barcodes, features
|
||||
|
||||
|
||||
def create_base_index(index_file_pth=None,
|
||||
barcodes=None,
|
||||
features=None,
|
||||
save_index_name=None):
|
||||
index = create_index()
|
||||
if index_file_pth is not None:
|
||||
# save_index_name = index_file_pth.split('json')[0] + 'usearch'
|
||||
save_index_name = index_file_pth.split('json')[0] + 'data'
|
||||
barcodes, features = analysis_file(index_file_pth)
|
||||
else:
|
||||
assert barcodes is not None and features is not None, 'barcodes and features must be not None'
|
||||
for barcode, feature in zip(barcodes, features):
|
||||
try:
|
||||
index.add(np.array(barcode), np.array(feature))
|
||||
except Exception as e:
|
||||
print(e)
|
||||
continue
|
||||
index.save(save_index_name)
|
||||
|
||||
|
||||
def get_feature_index(index_file_pth=None,
|
||||
barcodes=None):
|
||||
assert index_file_pth is not None, 'index_file_pth must be not None'
|
||||
index = Index.restore(index_file_pth, view=True)
|
||||
feature_lists = index.get(np.array(barcodes))
|
||||
print("memory {} size {}".format(index.memory_usage, index.size))
|
||||
print("feature_lists {}".format(feature_lists))
|
||||
return feature_lists
|
||||
|
||||
|
||||
def search_in_index(query=None,
|
||||
barcode=None, # barcode -> int or np.ndarray
|
||||
index_name=None,
|
||||
temp_index=False, # 是否为临时库
|
||||
model='0',
|
||||
):
|
||||
if temp_index:
|
||||
assert index_name is not None, 'index_name must be not None'
|
||||
index = Index.restore(index_name, view=True)
|
||||
if barcode is not None: # 1:1对比测试
|
||||
feature_lists = index.get(np.array(barcode))
|
||||
results = compare_feature(query, feature_lists)
|
||||
else:
|
||||
results = index.search(query, count=5)
|
||||
return results
|
||||
else: # 标准库
|
||||
assert index_name is not None, 'index_name must be not None'
|
||||
index = Index.restore(index_name, view=True)
|
||||
if barcode is not None: # 1:1对比测试
|
||||
feature_lists = index.get(np.array(barcode))
|
||||
results = compare_feature(query, feature_lists, model)
|
||||
else:
|
||||
results = index.search(query, count=10)
|
||||
return results
|
||||
|
||||
|
||||
def delete_index(index_name=None, key=None, index=None):
|
||||
assert key is not None, 'key must be not None'
|
||||
if index is None:
|
||||
assert index_name is not None, 'index_name must be not None'
|
||||
index = Index.restore(index_name, view=True)
|
||||
index.remove(index_name)
|
||||
else:
|
||||
index.remove(key)
|
||||
|
||||
from scipy.spatial.distance import cdist
|
||||
def compute_similarity_matrix(featurelists1, featurelists2):
|
||||
"""计算图片之间的余弦相似度矩阵"""
|
||||
# 计算所有向量对之间的余弦相似度
|
||||
cosine_similarities = 1 - cdist(featurelists1, featurelists2, metric='cosine')
|
||||
cosine_similarities = np.around(cosine_similarities, decimals=3)
|
||||
return cosine_similarities
|
||||
|
||||
def check_usearch_json_diff(index_file_pth, json_file_pth):
|
||||
json_features = None
|
||||
feature_lists = get_feature_index(index_file_pth, ['6923644272159'])
|
||||
with open(json_file_pth, 'r') as json_file:
|
||||
json_data = json.load(json_file)
|
||||
for data in json_data['total']:
|
||||
if data['key'] == '6923644272159':
|
||||
json_features = data['value']
|
||||
json_features = np.array(json_features)
|
||||
feature_lists = np.array(feature_lists[0])
|
||||
compute_similarity_matrix(json_features, feature_lists)
|
||||
|
||||
|
||||
def write_binary_file(filename, datas):
|
||||
with open(filename, 'wb') as f:
|
||||
# 先写入数据中的key数量(为C++读取提供便利)
|
||||
key_count = len(datas)
|
||||
f.write(struct.pack('I', key_count)) # 'I'代表无符号整型(4字节)
|
||||
|
||||
for data in datas:
|
||||
key = data['key']
|
||||
feats = data['value']
|
||||
key_bytes = key.encode('utf-8')
|
||||
key_len = len(key)
|
||||
length_byte = struct.pack('<B', key_len)
|
||||
f.write(length_byte)
|
||||
# f.write(struct.pack('Q', len(key_bytes)))
|
||||
f.write(key_bytes)
|
||||
value_count = len(feats)
|
||||
f.write(struct.pack('I', (value_count * 256)))
|
||||
# 遍历字典,写入每个key及其对应的浮点数值列表
|
||||
for values in feats:
|
||||
# 写入每个浮点数值(保留小数点后六位)
|
||||
for value in values:
|
||||
# 使用'f'格式(单精度浮点,4字节),并四舍五入保留六位小数
|
||||
value_half = np.float16(value)
|
||||
# print(value_half.tobytes())
|
||||
f.write(value_half.tobytes())
|
||||
def create_binary_file(json_path, flag=True):
|
||||
# 1. 打开JSON文件
|
||||
with open(json_path, 'r', encoding='utf-8') as file:
|
||||
# 2. 读取并解析JSON文件内容
|
||||
data = json.load(file)
|
||||
if flag:
|
||||
for flag, values in data.items():
|
||||
# 逐个写入values中的每个值,保留小数点后六位,每个值占一行
|
||||
write_binary_file(index_file_pth.replace('json', 'bin'), values)
|
||||
else:
|
||||
write_binary_file(json_path.replace('.json', '.bin'), [data])
|
||||
|
||||
def create_binary_files(index_file_pth):
|
||||
if os.path.isfile(index_file_pth):
|
||||
create_binary_file(index_file_pth)
|
||||
else:
|
||||
for name in os.listdir(index_file_pth):
|
||||
jsonpth = os.sep.join([index_file_pth, name])
|
||||
create_binary_file(jsonpth, False)
|
||||
|
||||
if __name__ == '__main__':
|
||||
# index_file_pth = '../data/feature_json' # 生成二进制文件 多文件
|
||||
index_file_pth = '../search_library/yunhedian_30-04.json'
|
||||
# create_base_index(index_file_pth) # 生成usearch文件
|
||||
create_binary_files(index_file_pth) # 生成二进制文件 多文件
|
||||
|
||||
# index_file_pth = '../search_library/test_index_10_normal_0717.usearch'
|
||||
# # index_file_pth = '../search_library/data_10_normal_0718.index'
|
||||
# search_in_index(query='693', index_name=index_file_pth, barcode='6934024590466')
|
||||
|
||||
# # check index data file
|
||||
# index_file_pth = '../search_library/data_zhanting.data'
|
||||
# # # get_feature_index(index_file_pth, ['6901070602818'])
|
||||
# get_feature_index(index_file_pth, ['6923644272159'])
|
||||
|
||||
# index_file_pth = '../search_library/data_zhanting.data'
|
||||
# json_file_pth = '../search_library/data_zhanting.json'
|
||||
# check_usearch_json_diff(index_file_pth, json_file_pth)
|
84
tools/threshold_partition.py
Normal file
84
tools/threshold_partition.py
Normal file
@ -0,0 +1,84 @@
|
||||
'''
|
||||
现场1:N测试,确定阈值
|
||||
'''
|
||||
import os
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
def showHist(filtered_data):
|
||||
Same = filtered_data[:, 1].astype(np.float32)
|
||||
Cross = filtered_data[:, 2].astype(np.float32)
|
||||
|
||||
fig, axs = plt.subplots(2, 1)
|
||||
axs[0].hist(Same, bins=50, edgecolor='black')
|
||||
axs[0].set_xlim([-0.1, 1])
|
||||
axs[0].set_title('first')
|
||||
|
||||
axs[1].hist(Cross, bins=50, edgecolor='black')
|
||||
axs[1].set_xlim([-0.1, 1])
|
||||
axs[1].set_title('second')
|
||||
# plt.savefig('plot.png')
|
||||
plt.show()
|
||||
|
||||
|
||||
def get_tartget_list(nested_list):
|
||||
filtered_list = np.array(list(filter(lambda x: len(x) >= 2, nested_list))) # 去除无轨迹的数据
|
||||
filtered_correct = filtered_list[filtered_list[:, 0] != 'wrong'] # 获取比对正确的时项
|
||||
filtered_wrong = filtered_list[filtered_list[:, 0] == 'wrong'] # 获取比对错误的时项
|
||||
showHist(filtered_correct)
|
||||
# showHist(filtered_wrong)
|
||||
print(filtered_list)
|
||||
|
||||
|
||||
def deal_process(file_pth):
|
||||
flag = False
|
||||
event = file_pth.split('\\')[-2]
|
||||
target_barcode = file_pth.split('\\')[-2].split('_')[-1]
|
||||
temp_list = []
|
||||
|
||||
with open(file_pth, 'r') as f:
|
||||
for line in f:
|
||||
if 'oneToOne' in line:
|
||||
flag = True
|
||||
continue
|
||||
if flag:
|
||||
line = line.replace('\n', '')
|
||||
comparison_data = line.split(',')
|
||||
forecast_barcode = comparison_data[0]
|
||||
value = comparison_data[-1].split(':')[-1]
|
||||
if value == '':
|
||||
break
|
||||
if len(temp_list) == 0:
|
||||
if forecast_barcode == target_barcode:
|
||||
temp_list.append('correct')
|
||||
else:
|
||||
temp_list.append('wrong')
|
||||
temp_list.append(float(value))
|
||||
temp_list.append(event)
|
||||
return temp_list
|
||||
|
||||
|
||||
def anaylze_scratch(scratch_pth):
|
||||
purchase, back = [], []
|
||||
for root, dirs, files in os.walk(scratch_pth):
|
||||
if len(root) > 0:
|
||||
if len(root.split('_')) == 4: # 加购
|
||||
process = os.path.join(root, 'process.data')
|
||||
if not os.path.exists(process):
|
||||
continue
|
||||
purchase.append(deal_process(process))
|
||||
elif len(root.split('_')) == 3:
|
||||
process = os.path.join(root, 'process.data')
|
||||
if not os.path.exists(process):
|
||||
continue
|
||||
back.append(deal_process(process))
|
||||
# get_tartget_list(purchase)
|
||||
get_tartget_list(back)
|
||||
print(purchase)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# scratch_pth = r'\\192.168.1.28\\share\\测试视频数据以及日志\\各模块测试记录\\展厅测试\\1108_展厅模型v800测试\\'
|
||||
scratch_pth = r'\\192.168.1.28\\share\\测试视频数据以及日志\\各模块测试记录\\展厅测试\\1120_展厅模型v801测试\\扫A放A\\'
|
||||
anaylze_scratch(scratch_pth)
|
411
tools/write_feature_json.py
Normal file
411
tools/write_feature_json.py
Normal file
@ -0,0 +1,411 @@
|
||||
import json
|
||||
import os
|
||||
import logging
|
||||
import numpy as np
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from tools.dataset import get_transform
|
||||
from model import resnet18
|
||||
import torch
|
||||
from PIL import Image
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
import yaml
|
||||
import shutil
|
||||
import struct
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FeatureExtractor:
|
||||
def __init__(self, conf):
|
||||
self.conf = conf
|
||||
self.model = self.initModel()
|
||||
_, self.test_transform = get_transform(self.conf)
|
||||
pass
|
||||
|
||||
def initModel(self, inference_model: Optional[str] = None) -> torch.nn.Module:
|
||||
"""
|
||||
Initialize and load the ResNet18 model for inference.
|
||||
|
||||
Args:
|
||||
inference_model: Optional path to model weights. Uses conf.test_model if None.
|
||||
|
||||
Returns:
|
||||
Loaded and configured PyTorch model in evaluation mode.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If model weights file is not found
|
||||
RuntimeError: If model loading fails
|
||||
"""
|
||||
model_path = inference_model if inference_model else self.conf['models']['checkpoints']
|
||||
|
||||
try:
|
||||
# Verify model file exists
|
||||
if not os.path.exists(model_path):
|
||||
raise FileNotFoundError(f"Model weights file not found: {model_path}")
|
||||
|
||||
# Initialize model
|
||||
model = resnet18().to(self.conf['base']['device'])
|
||||
|
||||
# Handle multi-GPU case
|
||||
if conf['base']['distributed']:
|
||||
model = torch.nn.DataParallel(model)
|
||||
|
||||
# Load weights
|
||||
state_dict = torch.load(model_path, map_location=conf['base']['device'])
|
||||
model.load_state_dict(state_dict)
|
||||
|
||||
model.eval()
|
||||
logger.info(f"Successfully loaded model from {model_path}")
|
||||
return model
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize model: {str(e)}")
|
||||
raise
|
||||
|
||||
def convert_rgba_to_rgb(self, image_path):
|
||||
# 打开图像
|
||||
img = Image.open(image_path)
|
||||
# 转换图像模式从RGBA到RGB
|
||||
# .convert('RGB')会丢弃Alpha通道并转换为纯RGB图像
|
||||
if img.mode == 'RGBA':
|
||||
# 转换为RGB模式
|
||||
img_rgb = img.convert('RGB')
|
||||
# 保存转换后的图像
|
||||
img_rgb.save(image_path)
|
||||
print(f"Image converted from RGBA to RGB and saved to {image_path}")
|
||||
|
||||
def test_preprocess(self, images: list, actionModel=False) -> torch.Tensor:
|
||||
res = []
|
||||
for img in images:
|
||||
try:
|
||||
im = self.test_transform(img) if actionModel else self.test_transform(Image.open(img))
|
||||
res.append(im)
|
||||
except:
|
||||
continue
|
||||
data = torch.stack(res)
|
||||
return data
|
||||
|
||||
def inference(self, images, model, actionModel=False):
|
||||
data = self.test_preprocess(images, actionModel)
|
||||
if torch.cuda.is_available():
|
||||
data = data.to(conf['base']['device'])
|
||||
features = model(data)
|
||||
if conf['data']['half']:
|
||||
features = features.half()
|
||||
return features
|
||||
|
||||
def group_image(self, images, batch=64) -> list:
|
||||
"""Group image paths by batch size"""
|
||||
size = len(images)
|
||||
res = []
|
||||
for i in range(0, size, batch):
|
||||
end = min(batch + i, size)
|
||||
res.append(images[i:end])
|
||||
return res
|
||||
|
||||
def getFeatureList(self, barList, imgList):
|
||||
featList = [[] for _ in range(len(barList))]
|
||||
|
||||
for index, image_paths in enumerate(imgList):
|
||||
try:
|
||||
# Process images in batches
|
||||
for batch in self.group_image(image_paths):
|
||||
# Get features for batch
|
||||
features = self.inference(batch, self.model)
|
||||
|
||||
# Process each feature in batch
|
||||
for feat in features:
|
||||
# Move to CPU and convert to numpy
|
||||
feat_np = feat.squeeze().detach().cpu().numpy()
|
||||
|
||||
# Normalize first 256 dimensions
|
||||
normalized = self.normalize_256(feat_np[:256])
|
||||
|
||||
# Combine with remaining dimensions
|
||||
combined = np.concatenate([normalized, feat_np[256:]], axis=0)
|
||||
|
||||
featList[index].append(combined)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing images for index {index}: {str(e)}")
|
||||
continue
|
||||
return featList
|
||||
|
||||
def get_files(
|
||||
self,
|
||||
folder: str,
|
||||
filter: Optional[List[str]] = None,
|
||||
create_single_json: bool = False
|
||||
) -> Dict[str, List[str]]:
|
||||
"""
|
||||
Recursively collect image files from directory structure.
|
||||
|
||||
Args:
|
||||
folder: Root directory to scan
|
||||
filter: Optional list of barcodes to include
|
||||
create_single_json: Whether to create individual JSON files per barcode
|
||||
|
||||
Returns:
|
||||
Dictionary mapping barcode names to lists of image paths
|
||||
|
||||
Example:
|
||||
{
|
||||
"barcode1": ["path/to/img1.jpg", "path/to/img2.jpg"],
|
||||
"barcode2": ["path/to/img3.jpg"]
|
||||
}
|
||||
"""
|
||||
file_dicts = {}
|
||||
total_files = 0
|
||||
feature_counts = []
|
||||
barcode_count = 0
|
||||
subclass = [str(i) for i in range(100)]
|
||||
# Validate input directory
|
||||
if not os.path.isdir(folder):
|
||||
raise ValueError(f"Invalid directory: {folder}")
|
||||
|
||||
# Process each barcode directory
|
||||
for root, dirs, files in tqdm(os.walk(folder), desc="Scanning directories"):
|
||||
if not dirs: # Leaf directory (contains images)
|
||||
basename = os.path.basename(root)
|
||||
if basename in subclass:
|
||||
ori_barcode = root.split('/')[-2]
|
||||
barcode = root.split('/')[-2] + '_' + basename
|
||||
else:
|
||||
ori_barcode = basename
|
||||
barcode = basename
|
||||
# Apply filter if provided
|
||||
if filter and ori_barcode not in filter:
|
||||
continue
|
||||
elif len(ori_barcode) > 13 or len(ori_barcode) < 8:
|
||||
logger.warning(f"Skipping invalid barcode {ori_barcode}")
|
||||
with open(conf['save']['error_barcodes'], 'a') as f:
|
||||
f.write(ori_barcode + '\n')
|
||||
f.close()
|
||||
continue
|
||||
|
||||
# Process image files
|
||||
if files:
|
||||
image_paths = self._process_image_files(root, files)
|
||||
if not image_paths:
|
||||
continue
|
||||
|
||||
# Update counters
|
||||
barcode_count += 1
|
||||
file_count = len(image_paths)
|
||||
total_files += file_count
|
||||
feature_counts.append(file_count)
|
||||
|
||||
# Handle output mode
|
||||
if create_single_json:
|
||||
self._process_single_barcode(barcode, image_paths)
|
||||
else:
|
||||
if barcode.split('_')[-1] == '0':
|
||||
barcode = barcode.split('_')[0]
|
||||
file_dicts[barcode] = image_paths
|
||||
|
||||
# # Log summary
|
||||
# logger.info(f"Processed {barcode_count} barcodes with {total_files} total images")
|
||||
# logger.debug(f"Image counts per barcode: {feature_counts}")
|
||||
|
||||
# Batch process if not creating individual JSONs
|
||||
if not create_single_json and file_dicts:
|
||||
self.createFeatureDict(
|
||||
file_dicts,
|
||||
create_single_json=False,
|
||||
)
|
||||
return file_dicts
|
||||
|
||||
def _process_image_files(self, root: str, files: List[str]) -> List[str]:
|
||||
"""Process and validate image files in a directory."""
|
||||
valid_paths = []
|
||||
for filename in files:
|
||||
file_path = os.path.join(root, filename)
|
||||
try:
|
||||
# Convert RGBA to RGB if needed
|
||||
self.convert_rgba_to_rgb(file_path)
|
||||
valid_paths.append(file_path)
|
||||
except Exception as e:
|
||||
logger.warning(f"Skipping invalid image {file_path}: {str(e)}")
|
||||
return valid_paths
|
||||
|
||||
def _process_single_barcode(self, barcode: str, image_paths: List[str]):
|
||||
"""Process a single barcode and create individual JSON file."""
|
||||
temp_dict = {barcode: image_paths}
|
||||
self.createFeatureDict(
|
||||
temp_dict,
|
||||
create_single_json=True,
|
||||
)
|
||||
|
||||
def normalize_256(self, queFeatList):
|
||||
queFeatList = queFeatList / np.linalg.norm(queFeatList)
|
||||
return queFeatList
|
||||
|
||||
def img2feature(
|
||||
self,
|
||||
imgs_dict: Dict[str, List[str]]
|
||||
) -> Tuple[List[str], List[List[np.ndarray]]]:
|
||||
"""
|
||||
Extract features for all images in the dictionary.
|
||||
|
||||
Args:
|
||||
imgs_dict: Dictionary mapping barcodes to image paths
|
||||
model: Pretrained feature extraction model
|
||||
barcode_flag: Whether to include barcode info (unused)
|
||||
|
||||
Returns:
|
||||
Tuple containing:
|
||||
- List of barcode IDs
|
||||
- List of feature lists (one per barcode)
|
||||
|
||||
Raises:
|
||||
ValueError: If input dictionary is empty
|
||||
RuntimeError: If feature extraction fails
|
||||
"""
|
||||
if not imgs_dict:
|
||||
raise ValueError("No images provided for feature extraction")
|
||||
|
||||
try:
|
||||
barcode_list = list(imgs_dict.keys())
|
||||
image_list = list(imgs_dict.values())
|
||||
feature_list = self.getFeatureList(barcode_list, image_list)
|
||||
|
||||
logger.info(f"Successfully extracted features for {len(barcode_list)} barcodes")
|
||||
return barcode_list, feature_list
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Feature extraction failed: {str(e)}")
|
||||
raise RuntimeError(f"Feature extraction failed: {str(e)}")
|
||||
|
||||
def createFeatureDict(self, imgs_dict,
|
||||
create_single_json=False): # imgs->{barcode1:[img1_1...img1_n], barcode2:[img2_1...img2_n]}
|
||||
dicts_all = {}
|
||||
value_list = []
|
||||
barcode_list, imgs_list = self.img2feature(imgs_dict)
|
||||
for i in range(len(barcode_list)):
|
||||
dicts = {}
|
||||
|
||||
imgs_list_ = []
|
||||
for j in range(len(imgs_list[i])):
|
||||
imgs_list_.append(imgs_list[i][j].tolist())
|
||||
|
||||
dicts['key'] = barcode_list[i]
|
||||
truncated_imgs_list = [subarray[:256] for subarray in imgs_list_]
|
||||
dicts['value'] = truncated_imgs_list
|
||||
if create_single_json:
|
||||
# json_path = os.path.join("./search_library/v8021_overseas/", str(barcode_list[i]) + '.json')
|
||||
json_path = os.path.join(self.conf['save']['json_path'], str(barcode_list[i]) + '.json')
|
||||
with open(json_path, 'w') as json_file:
|
||||
json.dump(dicts, json_file)
|
||||
else:
|
||||
value_list.append(dicts)
|
||||
if not create_single_json:
|
||||
dicts_all['total'] = value_list
|
||||
with open(self.conf['save']['json_bin'], 'w') as json_file:
|
||||
json.dump(dicts_all, json_file)
|
||||
self.create_binary_files(self.conf['save']['json_bin'])
|
||||
|
||||
def statisticsBarcodes(self, pth, filter=None):
|
||||
feature_num = 0
|
||||
feature_num_lists = []
|
||||
nn = 0
|
||||
with open(conf['save']['barcodes_statistics'], 'w', encoding='utf-8') as f:
|
||||
for barcode in os.listdir(pth):
|
||||
print("barcode length >> {}".format(len(barcode)))
|
||||
if len(barcode) > 13 or len(barcode) < 8:
|
||||
continue
|
||||
if filter is not None:
|
||||
f.writelines(barcode + '\n')
|
||||
if barcode in filter:
|
||||
print(barcode)
|
||||
feature_num += len(os.listdir(os.path.join(pth, barcode)))
|
||||
nn += 1
|
||||
else:
|
||||
print('barcode name >>{}'.format(barcode))
|
||||
f.writelines(barcode + '\n')
|
||||
feature_num += len(os.listdir(os.path.join(pth, barcode)))
|
||||
feature_num_lists.append(feature_num)
|
||||
print("特征总量: {}".format(feature_num))
|
||||
print("barcode总量: {}".format(nn))
|
||||
f.close()
|
||||
|
||||
def get_shop_barcodes(self, file_path):
|
||||
if file_path:
|
||||
df = pd.read_excel(file_path)
|
||||
column_values = list(df.iloc[:, 6].values)
|
||||
column_values = list(map(str, column_values))
|
||||
return column_values
|
||||
else:
|
||||
return None
|
||||
|
||||
def del_base_dir(self, pth):
|
||||
for root, dirs, files in os.walk(pth):
|
||||
if len(dirs) == 1:
|
||||
if dirs[0] == 'base':
|
||||
shutil.rmtree(os.path.join(root, dirs[0]))
|
||||
|
||||
def write_binary_file(self, filename, datas):
|
||||
with open(filename, 'wb') as f:
|
||||
# 先写入数据中的key数量(为C++读取提供便利)
|
||||
key_count = len(datas)
|
||||
f.write(struct.pack('I', key_count)) # 'I'代表无符号整型(4字节)
|
||||
for data in datas:
|
||||
key = data['key']
|
||||
feats = data['value']
|
||||
key_bytes = key.encode('utf-8')
|
||||
key_len = len(key)
|
||||
length_byte = struct.pack('<B', key_len)
|
||||
f.write(length_byte)
|
||||
# f.write(struct.pack('Q', len(key_bytes)))
|
||||
f.write(key_bytes)
|
||||
value_count = len(feats)
|
||||
f.write(struct.pack('I', (value_count * 256)))
|
||||
# 遍历字典,写入每个key及其对应的浮点数值列表
|
||||
for values in feats:
|
||||
# 写入每个浮点数值(保留小数点后六位)
|
||||
for value in values:
|
||||
# 使用'f'格式(单精度浮点,4字节),并四舍五入保留六位小数
|
||||
value_half = np.float16(value)
|
||||
# print(value_half.tobytes())
|
||||
f.write(value_half.tobytes())
|
||||
|
||||
def create_binary_file(self, json_path, flag=True):
|
||||
# 1. 打开JSON文件
|
||||
with open(json_path, 'r', encoding='utf-8') as file:
|
||||
# 2. 读取并解析JSON文件内容
|
||||
data = json.load(file)
|
||||
if flag:
|
||||
for flag, values in data.items():
|
||||
# 逐个写入values中的每个值,保留小数点后六位,每个值占一行
|
||||
self.write_binary_file(self.conf['save']['json_bin'].replace('json', 'bin'), values)
|
||||
else:
|
||||
self.write_binary_file(json_path.replace('.json', '.bin'), [data])
|
||||
|
||||
def create_binary_files(self, index_file_pth):
|
||||
if os.path.isfile(index_file_pth):
|
||||
self.create_binary_file(index_file_pth)
|
||||
else:
|
||||
for name in os.listdir(index_file_pth):
|
||||
jsonpth = os.sep.join([index_file_pth, name])
|
||||
self.create_binary_file(jsonpth, False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
with open('../configs/write_feature.yml', 'r') as f:
|
||||
conf = yaml.load(f, Loader=yaml.FullLoader)
|
||||
###将图片名称和模型推理特征向量字典存为json文件
|
||||
# xlsx_pth = './shop_xlsx/曹家桥门店在售商品表.xlsx'
|
||||
# xlsx_pth = None
|
||||
# del_base_dir(mg_path)
|
||||
|
||||
extractor = FeatureExtractor(conf)
|
||||
column_values = extractor.get_shop_barcodes(conf['data']['xlsx_pth'])
|
||||
imgs_dict = extractor.get_files(conf['data']['img_dirs_path'],
|
||||
filter=column_values,
|
||||
create_single_json=False) # False
|
||||
extractor.statisticsBarcodes(conf['data']['img_dirs_path'], column_values)
|
Reference in New Issue
Block a user