commit 37ecef40f7
Author: lee
Date: 2025-06-11 15:23:50 +08:00

79 changed files with 26981 additions and 0 deletions

tools/gift_assessment.py Normal file

@ -0,0 +1,369 @@
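"""Offline assessment tool for the free-gift classifier.

Loads a ResNet-18 feature extractor plus a gift/non-gift head, scores
cropped sub-images or tracking logs, and reports score histograms and a
precision/recall sweep over decision thresholds.
"""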
import os
import pickle
import shutil
import sys

sys.path.append('..')  # put the parent dir on sys.path so the `model` package resolves

import matplotlib.pyplot as plt
import numpy as np
import torch

from model.mlp import Net3, Net4
from model import resnet18
from gift_data_pretreatment import getFeatureList

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def init_model(pkl_flag):
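    """Load the ResNet-18 feature extractor and the gift classification head.

    pkl_flag=True selects the Net3 head fed by precomputed ResNet features;
    otherwise a Net4 model with its own resnet18 backbone is loaded.
    """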
    res_pth = r"../checkpoints/resnet18_1009/best.pth"
    if pkl_flag:
        gift_pth = r'../checkpoints/gift_model/action2/gift_v11.pth'
        gift_model = Net3(pretrained=True, num_classes=1)
        gift_model.load_state_dict(torch.load(gift_pth))
    else:
        gift_pth = r'../checkpoints/gift_model/action3/best.pth'
        gift_model = Net4('resnet18', True, True)  # pretrained backbone
        try:
            # checkpoints saved with DataParallel carry a 'module.' prefix; strip it
            print('>> loading multi-card checkpoint <<')
            gift_model.load_state_dict({k.replace('module.', ''): v for k, v in
                                        torch.load(gift_pth, map_location=device).items()})
        except Exception:
            print('>> loading single-card checkpoint <<')
            gift_model.load_state_dict(torch.load(gift_pth, map_location=device))
    res_model = resnet18()
    res_model.load_state_dict({k.replace('module.', ''): v for k, v in
                               torch.load(res_pth, map_location=device).items()})
    return res_model, gift_model

def showHist(nongifts, gifts):
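    """Plot score histograms: non-gift samples on top, gift samples below."""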
    fig, axs = plt.subplots(2, 1)
    axs[0].hist(nongifts, bins=50, edgecolor='blue')
    axs[0].set_xlim([-0.1, 1])
    axs[0].set_title('nongifts')
    axs[1].hist(gifts, bins=50, edgecolor='green')
    axs[1].set_xlim([-0.1, 1])
    axs[1].set_title('gifts')
    # plt.savefig('plot.png')
    plt.show()

def calculate_precision_recall(nongift, gift, points):
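    """Sweep thresholds and collect precision/recall for gift vs. non-gift scores."""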
    precision, recall = [], []
    for point in points:
        # >= so that scores equal to the threshold count as positive predictions
        TP = np.sum(gift >= point)
        FN = np.sum(gift < point)
        FP = np.sum(nongift >= point)
        TN = np.sum(nongift < point)
        if TP == 0:
            precision.append(0)
            recall.append(0)
        else:
            precision.append(TP / (TP + FP))
            recall.append(TP / (TP + FN))
        print("point >> {} TP>>{}, FP>>{}, TN>>{}, FN>>{}".format(point, TP, FP, TN, FN))
    return precision, recall

def showgrid(all_prec, all_recall, points):
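    """Plot precision and recall curves against the decision threshold."""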
    plt.figure(figsize=(10, 6))
    plt.plot(points[:-1], all_prec[:-1], color='blue', label='precision')
    plt.plot(points[:-1], all_recall[:-1], color='red', label='recall')
    plt.legend()
    plt.xlabel('threshold')
    plt.grid(True, linestyle='--', alpha=0.5)
    # plt.savefig('grid.png')
    plt.show()
    plt.close()

def discriminate_action(roots):
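    """Read process.data and classify the session as an add-to-cart or a return."""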
    pth = os.sep.join([roots, 'process.data'])
    with open(pth, 'r') as f:
        lines = f.readlines()
    for line in lines:
        content = line.strip()
        if 'weightValue' in content:
            # a positive weight change means items were added to the cart
            if int(content.split(":")[-1].split(',')[0]) > 0:
                return 'add'
            else:
                return 'return'
    return None  # no weightValue record found

def median(lst):
    sorted_lst = sorted(lst)
    n = len(sorted_lst)
    if n % 2 == 1:
        # odd length: the median is the middle element
        return sorted_lst[n // 2]
    else:
        # even length: the median is the mean of the two middle elements
        mid1 = sorted_lst[(n // 2) - 1]
        mid2 = sorted_lst[n // 2]
        return (mid1 + mid2) / 2

def get_special_data(data, p):
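    """Summarize the first half of a score sequence with max/average/median.

    Sequences of five or fewer scores fall back to the plain mean.
    """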
    length = len(data)
    if length > 5:
        head = data[:round(length * 0.5)]  # first half of the sequence
        if p == 'max':
            return max(head)
        elif p == 'average':
            return sum(head) / len(head)
        elif p == 'median':
            return median(head)
    else:
        return sum(data) / len(data)

def read_data_file(pth):
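    """Parse the 'free_gift__result' line of a .data log into a list of floats.

    The front-camera file carries one extra trailing field compared with the
    rear-camera file (0_tracking_output.data), hence the different slices.
    """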
    result = []
    with open(pth, 'r') as data_file:
        lines = data_file.readlines()
        for line in lines:
            if line.split(':')[0] == 'free_gift__result':
                if '0_tracking_output.data' in pth:
                    result = line.split(':')[1].split(',')[:-1]
                else:
                    result = line.split(':')[1].split(',')[:-2]
                result = [float(i) for i in result]
    return result

def get_tracking_data(pth):
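    """Collect 64-dim feature rows from a tracking log (lines with 65 CSV fields)."""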
    result = []
    with open(pth, 'r') as data_file:
        for line in data_file.readlines():
            if len(line.split(',')) == 65:
                result.append([float(item) for item in line.split(',')[:-1]])
    return result

def clean_return_data(pth):
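    """Delete every leaf session folder under pth that is flagged as a return. Destructive."""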
    for roots, dirs, files in os.walk(pth):
        if len(dirs) == 0:
            flag = discriminate_action(roots)
            if flag == 'return':
                shutil.rmtree(roots)

def get_gift_files(pth):
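    """Aggregate gift scores straight from test-run result files, split by action
    (add/return) and by camera (0 = rear, 1 = front), then plot histograms.
    """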
    add_special_output_0, return_special_output_0, return_special_output_1, add_special_output_1 = [], [], [], []
    add_tracking_output_0, return_tracking_output_0, add_tracking_output_1, return_tracking_output_1 = [], [], [], []
    for roots, dirs, files in os.walk(pth):
        if len(dirs) == 0:
            flag = discriminate_action(roots)
            for file in files:
                if file == '0_tracking_output.data':
                    result = read_data_file(os.path.join(roots, file))
                    if not len(result) == 0:
                        if flag == 'add':
                            add_special_output_0.append(get_special_data(result, 'average'))  # add, rear camera
                        else:
                            return_special_output_0.append(get_special_data(result, 'average'))  # return, rear camera
                    if flag == 'add':
                        add_tracking_output_0 += read_data_file(os.path.join(roots, file))
                    else:
                        return_tracking_output_0 += read_data_file(os.path.join(roots, file))
                elif file == '1_tracking_output.data':
                    result = read_data_file(os.path.join(roots, file))
                    if not len(result) == 0:
                        if flag == 'add':
                            add_special_output_1.append(get_special_data(result, 'average'))  # add, front camera
                        else:
                            return_special_output_1.append(get_special_data(result, 'average'))  # return, front camera
                    if flag == 'add':
                        add_tracking_output_1 += read_data_file(os.path.join(roots, file))
                    else:
                        return_tracking_output_1 += read_data_file(os.path.join(roots, file))
    comprehensive_dicts = {"add_special_output_0": add_special_output_0,
                           "return_special_output_0": return_special_output_0,
                           "add_tracking_output_0": add_tracking_output_0,
                           "return_tracking_output_0": return_tracking_output_0,
                           "add_special_output_1": add_special_output_1,
                           "return_special_output_1": return_special_output_1,
                           "add_tracking_output_1": add_tracking_output_1,
                           "return_tracking_output_1": return_tracking_output_1,
                           }
    # compares rear vs. front camera scores for 'add' sessions
    showHist(np.array(comprehensive_dicts['add_tracking_output_0']),
             np.array(comprehensive_dicts['add_tracking_output_1']))
    # showHist(np.array(comprehensive_dicts['add_special_output_0']),
    #          np.array(comprehensive_dicts['add_special_output_1']))
    return comprehensive_dicts

def get_feature_array(img_pth_lists, res_model, gift_model, pkl_flag=True):
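    """Run every image group through the models and return one score array per group.

    With pkl_flag=True the ResNet feature matrix (columns 256 onward) is reshaped
    into 64x13x13 maps and scored by gift_model; otherwise gift_model embeds the
    images directly via getFeatureList.
    """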
    features_np = []
    if pkl_flag:
        for img_lists in img_pth_lists:
            fe_nps = getFeatureList(None, img_lists, res_model)
            try:
                fe_nps = fe_nps[0][:, 256:]  # drop the first 256 feature columns
            except Exception as e:
                print(e)
                continue
            fe_nps = torch.from_numpy(fe_nps)
            fe_nps = fe_nps.view(fe_nps.shape[0], 64, 13, 13)
            if len(fe_nps):
                fe_np = gift_model(fe_nps)
                fe_np = np.squeeze(fe_np.detach().numpy())
                features_np.append(fe_np)
    else:
        for img_lists in img_pth_lists:
            fe_nps = getFeatureList(None, img_lists, gift_model)
            if len(fe_nps) > 0:
                fe_nps = np.concatenate(fe_nps)
                features_np.append(fe_nps)
    return features_np

def create_gift_subimg_np(data_pth, pkl_flag):
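    """Score all sub-image folders under data_pth and cache the arrays.

    Folders matching 'commodity'+'subimg' are treated as non-gift samples,
    'Havegift'+'subimg' as gift samples; results go to nongift.pkl / gift.pkl.
    """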
    gift_array_pth = os.path.join(data_pth, 'gift.pkl')
    nongift_array_pth = os.path.join(data_pth, 'nongift.pkl')
    res_model, gift_model = init_model(pkl_flag)
    res_model = res_model.eval()
    gift_model = gift_model.eval()
    gift_img_pth_list, gift_lists, nongift_img_pth_list, nongift_lists = [], [], [], []
    for root, dirs, files in os.walk(data_pth):
        if 'commodity' in root and 'subimg' in root:
            print("commodity >> {}".format(root))
            for file in files:
                nongift_img_pth_list.append(os.sep.join([root, file]))
            nongift_lists.append(nongift_img_pth_list)
            nongift_img_pth_list = []
        elif 'Havegift' in root and 'subimg' in root:
            print("Havegift >> {}".format(root))
            for file in files:
                gift_img_pth_list.append(os.sep.join([root, file]))
            gift_lists.append(gift_img_pth_list)
            gift_img_pth_list = []
    nongift = get_feature_array(nongift_lists, res_model, gift_model, pkl_flag)
    gift = get_feature_array(gift_lists, res_model, gift_model, pkl_flag)
    with open(nongift_array_pth, 'wb') as file:
        pickle.dump(nongift, file)
    with open(gift_array_pth, 'wb') as file:
        pickle.dump(gift, file)

def top_25_percent_mean(arr):
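    """Return the mean of the highest 25% of scores in arr."""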
    # sort descending and keep the top quarter (at least one element)
    sorted_arr = np.sort(arr)[::-1]
    top_25_percent_length = max(1, int(len(sorted_arr) * 0.25))
    top_25_percent = sorted_arr[:top_25_percent_length]
    return np.mean(top_25_percent)

def assess_gift_subimg(data_pth, pkl_flag=False):
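    """Evaluate cached sub-image scores: plot histograms and a precision/recall
    sweep over thresholds 0.01..1.00.

    Scores are cached in gift.pkl / nongift.pkl and recomputed if missing.
    """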
    points = np.linspace(1, 100, 100) / 100
    gift_pkl_pth = os.path.join(data_pth, 'gift.pkl')
    nongift_pkl_pth = os.path.join(data_pth, 'nongift.pkl')
    if not os.path.exists(gift_pkl_pth):
        create_gift_subimg_np(data_pth, pkl_flag)
    with open(nongift_pkl_pth, 'rb') as f:
        nongift = pickle.load(f)
    with open(gift_pkl_pth, 'rb') as f:
        gift = pickle.load(f)
    # top-quartile mean per sample (alternatives kept below: median / raw / mean)
    nongift_mean = [top_25_percent_mean(items) for items in nongift]
    gift_mean = [top_25_percent_mean(items) for items in gift]
    # median:
    # nongift_mean = [np.median(items) for items in nongift]
    # gift_mean = [np.median(items) for items in gift]
    # all raw scores:
    # showHist(nongift.flatten(), gift.flatten())
    # plain mean:
    # nongift_mean = [np.mean(items) for items in nongift]
    # gift_mean = [np.mean(items) for items in gift]
    showHist(np.array(nongift_mean), np.array(gift_mean))
    precision, recall = calculate_precision_recall(np.array(nongift_mean),
                                                   np.array(gift_mean),
                                                   points)
    showgrid(precision, recall, points)

def get_comprehensive_dicts(data_pth):
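    """Score raw 64-dim tracking features from both cameras with the Net3 head
    and print the predictions.
    """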
    gift_pth = r'../checkpoints/gift_model/action2/best.pth'
    g_model = Net3(pretrained=True, num_classes=1)
    g_model.load_state_dict(torch.load(gift_pth))
    g_model.eval()
    result = []
    file_name = ['0_tracking_output.data',
                 '1_tracking_output.data']
    for root, dirs, files in os.walk(data_pth):
        if not len(dirs):
            for file in files:
                if file in file_name:
                    print(os.path.join(root, file))
                    result += get_tracking_data(os.path.join(root, file))
    result = torch.from_numpy(np.array(result))
    inputs = result.view(result.shape[0], 64, 1, 1)  # renamed to avoid shadowing builtin input()
    inputs = inputs.to('cpu').to(torch.float32)
    scores = g_model(inputs)
    print(scores)

if __name__ == '__main__':
    # pth = r'\\192.168.1.28\\share\\测试视频数据以及日志\\各模块测试记录\\赠品测试\\20241203赠品测试数据\\赠品\\images'
    # pth = r'\\192.168.1.28\\share\\测试视频数据以及日志\\各模块测试记录\\赠品测试\\20241203赠品测试数据\\没有赠品的商品\\images'
    # pth = r'\\192.168.1.28\\share\\测试视频数据以及日志\\各模块测试记录\\赠品测试\\20241203赠品测试数据\\同样的商品没有捆绑赠品\\images'
    # pth = r'\\192.168.1.28\\share\\测试视频数据以及日志\\各模块测试记录\\赠品测试\\20241213赠品测试数据\\赠品'
    # pth = r'C:\Users\HP\Desktop\zengpin\1227'
    # get_gift_files(pth)

    # analyze results from segmented sub-images
    pth = r'D:\Project\contrast_nettest\data\gift_test'
    assess_gift_subimg(pth)

    # analyze results from the full dataset
    # pth = r'C:\Users\HP\Desktop\zengpin\1231'
    # get_comprehensive_dicts(pth)

    # delete return-session videos (destructive)
    # pth = r'C:\Users\HP\Desktop\gift_test\20241213\非赠品'
    # clean_return_data(pth)