370 lines
14 KiB
Python
370 lines
14 KiB
Python
import os
|
|
import pdb
|
|
import shutil
|
|
import sys
|
|
|
|
sys.path.append('../model')
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
from model.mlp import Net2, Net3, Net4
|
|
from model import resnet18
|
|
import torch
|
|
from gift_data_pretreatment import getFeatureList
|
|
|
|
# Device used when mapping checkpoints: CUDA when available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
|
|
|
def init_model(pkl_flag):
    """Create the ResNet-18 backbone and the gift model and load their weights.

    Args:
        pkl_flag: When truthy, build ``Net3`` and load
            ``action2/gift_v11.pth``; otherwise build ``Net4`` and load
            ``action3/best.pth``.

    Returns:
        tuple: ``(res_model, gift_model)``. The models are returned as
        constructed; this function neither calls ``eval()`` nor moves them
        to ``device``.
    """
    res_pth = r"../checkpoints/resnet18_1009/best.pth"

    def strip_module_prefix(state_dict):
        # Keys may carry a "module." prefix (presumably from a multi-GPU /
        # DataParallel training run - confirm); remove it so the keys match
        # the bare model.
        return {k.replace('module.', ''): v for k, v in state_dict.items()}

    if pkl_flag:
        gift_pth = r'../checkpoints/gift_model/action2/gift_v11.pth'
        gift_model = Net3(pretrained=True, num_classes=1)
        # map_location keeps a GPU-saved checkpoint loadable on CPU-only hosts,
        # matching how the other checkpoints in this function are loaded.
        gift_model.load_state_dict(torch.load(gift_pth, map_location=device))
    else:
        gift_pth = r'../checkpoints/gift_model/action3/best.pth'
        gift_model = Net4('resnet18', True, True)  # pretrained model
        print('>>multiple_cards load pre model <<')
        # Deserialize once; both attempts below reuse the same object.
        checkpoint = torch.load(gift_pth, map_location=device)
        try:
            gift_model.load_state_dict(strip_module_prefix(checkpoint))
        except Exception:
            # Best-effort fallback kept from the original: retry with the
            # checkpoint keys exactly as stored.
            print('>> load pre model <<')
            gift_model.load_state_dict(checkpoint)

    res_model = resnet18()
    res_model.load_state_dict(
        strip_module_prefix(torch.load(res_pth, map_location=device)))
    return res_model, gift_model
|
|
|
|
|
|
def showHist(nongifts, gifts):
    """Display two stacked 50-bin histograms: non-gift scores above gift scores.

    Args:
        nongifts: Values plotted in the upper panel (titled ``nongifts``).
        gifts: Values plotted in the lower panel (titled ``gifts``).

    Both panels share the x-range ``[-0.1, 1]``. The figure is shown with
    ``plt.show()``; nothing is saved or returned.
    """
    _, axes = plt.subplots(2, 1)
    panels = (
        (axes[0], nongifts, 'blue', 'nongifts'),
        (axes[1], gifts, 'green', 'gifts'),
    )
    for axis, values, edge_colour, title in panels:
        axis.hist(values, bins=50, edgecolor=edge_colour)
        axis.set_xlim([-0.1, 1])
        axis.set_title(title)
    # To keep a copy on disk, call plt.savefig('plot.png') before plt.show().
    plt.show()
|
|
|
|
|
|
def calculate_precision_recall(nongift, gift, points):
    """Compute precision and recall of the gift score at each threshold.

    A sample is predicted "gift" when its score is strictly greater than the
    threshold and "non-gift" otherwise, so a score exactly equal to the
    threshold counts as a negative prediction (previously such samples were
    dropped from every count).

    Args:
        nongift: Scores of true non-gift samples (array-like).
        gift: Scores of true gift samples (array-like).
        points: Iterable of thresholds to evaluate.

    Returns:
        tuple: ``(precision, recall)`` lists with one entry per threshold.
        Both entries are ``0`` for a threshold with no true positives.

    Side effects:
        Prints the confusion counts once per threshold.
    """
    gift = np.asarray(gift)
    nongift = np.asarray(nongift)
    precision, recall = [], []
    for point in points:
        TP = np.sum(gift > point)
        FN = np.sum(gift <= point)
        FP = np.sum(nongift > point)
        TN = np.sum(nongift <= point)
        if TP == 0:
            # No positives recovered: report 0 instead of dividing by zero.
            precision.append(0)
            recall.append(0)
        else:
            precision.append(TP / (TP + FP))
            recall.append(TP / (TP + FN))
        print("point >> {} TP>>{}, FP>>{}, TN>>{}, FN>>{}".format(point, TP, FP, TN, FN))
    return precision, recall
|
|
|
|
|
|
def showgrid(all_prec, all_recall, points):
    """Plot precision and recall against the threshold and show the figure.

    Args:
        all_prec: Precision values, one per threshold.
        all_recall: Recall values, one per threshold.
        points: Thresholds matching the two series.

    The last element of every sequence is left out of the plot. The figure
    is shown with ``plt.show()`` and then closed; nothing is returned.
    """
    plt.figure(figsize=(10, 6))
    curves = (
        (all_prec, 'blue', 'precision'),
        (all_recall, 'red', 'recall'),
    )
    for series, colour, name in curves:
        plt.plot(points[:-1], series[:-1], color=colour, label=name)
    plt.legend()
    plt.xlabel('threshold')
    plt.grid(True, linestyle='--', alpha=0.5)
    # To keep a copy on disk, call plt.savefig('grid.png') before plt.show().
    plt.show()
    plt.close()
|
|
|
|
|
|
def discriminate_action(roots):
    """Decide whether a recording is an add-to-cart or a return event.

    Scans ``process.data`` inside ``roots`` and uses the first line that
    contains ``weightValue``. The integer parsed from that line is the text
    after the last ``:`` up to the first ``,``.

    Args:
        roots: Directory that contains ``process.data``.

    Returns:
        ``'add'`` when the parsed value is positive, ``'return'`` when it is
        zero or negative, and ``None`` when no line mentions ``weightValue``.

    Raises:
        OSError: ``process.data`` cannot be opened.
        ValueError: The weight field is not an integer literal.
    """
    pth = os.path.join(roots, 'process.data')
    with open(pth, 'r') as f:
        # Stream the file: only the first matching line matters.
        for line in f:
            content = line.strip()
            if 'weightValue' not in content:
                continue
            weight = int(content.split(":")[-1].split(',')[0])
            return 'add' if weight > 0 else 'return'
    return None
|
|
|
|
|
|
def median(lst):
    """Return the median of ``lst``.

    The input is sorted into a new list. An odd-length input yields its
    middle element; an even-length input yields the arithmetic mean of the
    two middle elements. An empty input raises ``IndexError``.
    """
    ordered = sorted(lst)
    half, is_odd = divmod(len(ordered), 2)
    if is_odd:
        # Odd length: the single middle element is the median.
        return ordered[half]
    # Even length: average the two elements around the centre.
    return (ordered[half - 1] + ordered[half]) / 2
|
|
|
|
|
|
def get_special_data(data, p):
    """Reduce a sequence of scores to one representative value.

    Sequences of five or fewer elements are always reduced to their plain
    mean and ``p`` is ignored. Longer sequences are reduced over their first
    half only (``round(len(data) * 0.5)`` leading elements) with the
    statistic chosen by ``p``.

    Args:
        data: Sequence of numeric scores.
        p: ``'max'``, ``'average'`` or ``'median'``.

    Returns:
        The selected statistic.

    Raises:
        ValueError: ``data`` has more than five elements and ``p`` is not a
            supported name (previously ``None`` was returned silently).
        ZeroDivisionError: ``data`` is empty.
    """
    length = len(data)
    if length <= 5:
        # Too few samples for a robust statistic: use the plain mean.
        return sum(data) / length

    head = data[:round(length * 0.5)]
    if p == 'max':
        return max(head)
    if p == 'average':
        return sum(head) / len(head)
    if p == 'median':
        return median(head)
    raise ValueError(
        "unsupported statistic {!r}; expected 'max', 'average' or 'median'".format(p))
|
|
|
|
|
|
def read_data_file(pth):
    """Read the ``free_gift__result`` scores from a tracking output file.

    A line qualifies when the text before its first ``:`` is exactly
    ``free_gift__result``. The segment after that colon is split on commas
    and converted to floats, leaving out the last field when ``pth``
    contains ``0_tracking_output.data`` and the last two fields otherwise.
    If several lines qualify, the last one wins.

    Args:
        pth: Path of the file to read.

    Returns:
        list[float]: The parsed scores, or ``[]`` when no line qualifies.
    """
    # Number of trailing comma-separated fields that are not scores.
    trailing = 1 if '0_tracking_output.data' in pth else 2
    result = []
    with open(pth, 'r') as data_file:
        for raw_line in data_file.readlines():
            segments = raw_line.split(':')
            if segments[0] != 'free_gift__result':
                continue
            tokens = segments[1].split(',')[:-trailing]
            result = [float(token) for token in tokens]
    return result
|
|
|
|
|
|
def get_tracking_data(pth):
    """Collect the 64-value rows stored in a tracking output file.

    Only lines that split into exactly 65 comma-separated fields are used;
    the first 64 fields of each are converted to floats and the final field
    is discarded.

    Args:
        pth: Path of the file to read.

    Returns:
        list[list[float]]: One 64-element list per qualifying line, in file
        order.
    """
    with open(pth, 'r') as data_file:
        split_lines = [raw_line.split(',') for raw_line in data_file.readlines()]
    return [
        [float(value) for value in fields[:-1]]
        for fields in split_lines
        if len(fields) == 65
    ]
|
|
|
|
|
|
def clean_reurn_data(pth):
    """Delete every leaf directory under ``pth`` that records a return event.

    A directory is treated as a leaf when ``os.walk`` reports no
    sub-directories for it. Each leaf is classified with
    ``discriminate_action`` and removed with ``shutil.rmtree`` when the
    result is ``'return'``. The deletion is permanent.

    Args:
        pth: Root directory to walk.
    """
    for roots, dirs, _files in os.walk(pth):
        if dirs:
            # Not a leaf directory; keep descending.
            continue
        if discriminate_action(roots) == 'return':
            shutil.rmtree(roots)
|
|
|
|
|
|
def get_gift_files(pth):
    """Collect gift scores from the result files written by a test run.

    Walks ``pth``; for every leaf directory (no sub-directories) the action
    is classified with ``discriminate_action`` and the files
    ``0_tracking_output.data`` / ``1_tracking_output.data`` are parsed with
    ``read_data_file``. According to the original comments, file ``0`` is
    the rear camera and file ``1`` the front camera. Any classification
    other than ``'add'`` is stored under the ``return_*`` keys.

    For each non-empty file two things are recorded:

    * ``<action>_special_output_<n>``: one value per file, the
      ``get_special_data(..., 'average')`` summary of its scores;
    * ``<action>_tracking_output_<n>``: all raw scores, concatenated.

    Args:
        pth: Root directory of the test output.

    Returns:
        dict: The eight lists described above, keyed by name.

    Side effects:
        Shows a histogram comparing ``add_tracking_output_0`` (upper panel)
        with ``add_tracking_output_1`` (lower panel).
    """
    camera_by_file = {
        '0_tracking_output.data': '0',
        '1_tracking_output.data': '1',
    }
    comprehensive_dicts = {
        "add_special_output_0": [],
        "return_special_output_0": [],
        "add_tracking_output_0": [],
        "return_tracking_output_0": [],
        "add_special_output_1": [],
        "return_special_output_1": [],
        "add_tracking_output_1": [],
        "return_tracking_output_1": [],
    }

    for roots, dirs, files in os.walk(pth):
        if dirs:
            continue
        action = 'add' if discriminate_action(roots) == 'add' else 'return'
        for file in files:
            camera = camera_by_file.get(file)
            if camera is None:
                continue
            # Parse each file once and reuse the scores for both outputs.
            result = read_data_file(os.path.join(roots, file))
            if not result:
                continue
            special_key = "{}_special_output_{}".format(action, camera)
            tracking_key = "{}_tracking_output_{}".format(action, camera)
            comprehensive_dicts[special_key].append(
                get_special_data(result, 'average'))
            comprehensive_dicts[tracking_key].extend(result)

    showHist(np.array(comprehensive_dicts['add_tracking_output_0']),
             np.array(comprehensive_dicts['add_tracking_output_1']))
    return comprehensive_dicts
|
|
|
|
|
|
def get_feature_array(img_pth_lists, res_model, gift_model, pkl_flag=True):
    """Run the gift model over groups of images and collect the outputs.

    Args:
        img_pth_lists: Iterable of image-path lists; each inner list is
            passed to ``getFeatureList`` as one group.
        res_model: Backbone passed to ``getFeatureList`` when ``pkl_flag``
            is truthy.
        gift_model: Gift model. With ``pkl_flag`` it is applied to the
            reshaped backbone features; without it, it is passed to
            ``getFeatureList`` directly.
        pkl_flag: Selects between the two pipelines described above.

    Returns:
        list: One numpy array per group that produced output. Groups whose
        features cannot be sliced, or that are empty, are skipped.
    """
    features_np = []
    if pkl_flag:
        for img_lists in img_pth_lists:
            fe_nps = getFeatureList(None, img_lists, res_model)
            try:
                # Keep the columns from index 256 onward of the first
                # returned array.
                fe_nps = fe_nps[0][:, 256:]
            except Exception as e:
                # Best-effort: report the problem and move to the next group.
                print(e)
                continue
            fe_nps = torch.from_numpy(fe_nps)
            fe_nps = fe_nps.view(fe_nps.shape[0], 64, 13, 13)
            if len(fe_nps):
                # Inference only: skip autograd bookkeeping.
                with torch.no_grad():
                    fe_np = gift_model(fe_nps)
                # .cpu() makes the conversion valid for CUDA outputs as well.
                fe_np = np.squeeze(fe_np.detach().cpu().numpy())
                features_np.append(fe_np)
    else:
        for img_lists in img_pth_lists:
            fe_nps = getFeatureList(None, img_lists, gift_model)
            if len(fe_nps) > 0:
                fe_nps = np.concatenate(fe_nps)
                features_np.append(fe_nps)
    return features_np
|
|
|
|
|
|
import pickle
|
|
|
|
|
|
def create_gift_subimg_np(data_pth, pkl_flag):
    """Compute model outputs for the sub-image folders and pickle them.

    Walks ``data_pth``. Every directory whose path contains ``subimg``
    contributes one list of its file paths: to the non-gift set when the
    path also contains ``commodity``, otherwise to the gift set when it
    contains ``Havegift``. Both sets are run through ``get_feature_array``
    and the results are written to ``nongift.pkl`` and ``gift.pkl`` inside
    ``data_pth``.

    Args:
        data_pth: Root directory of the data set; also receives the pickles.
        pkl_flag: Forwarded to ``init_model`` and ``get_feature_array``.
    """
    gift_array_pth = os.path.join(data_pth, 'gift.pkl')
    nongift_array_pth = os.path.join(data_pth, 'nongift.pkl')

    res_model, gift_model = init_model(pkl_flag)
    res_model = res_model.eval()
    gift_model = gift_model.eval()

    gift_lists, nongift_lists = [], []
    for root, dirs, files in os.walk(data_pth):
        if 'subimg' not in root:
            continue
        if 'commodity' in root:
            print("commodity >> {}".format(root))
            nongift_lists.append([os.sep.join([root, file]) for file in files])
        elif 'Havegift' in root:
            print("Havegift >> {}".format(root))
            gift_lists.append([os.sep.join([root, file]) for file in files])

    nongift = get_feature_array(nongift_lists, res_model, gift_model, pkl_flag)
    gift = get_feature_array(gift_lists, res_model, gift_model, pkl_flag)

    for target_pth, payload in ((nongift_array_pth, nongift), (gift_array_pth, gift)):
        with open(target_pth, 'wb') as file:
            pickle.dump(payload, file)
|
|
|
|
|
|
def top_25_percent_mean(arr):
    """Return the highest 25 % of the values in ``arr``, largest first.

    Despite its name, this function returns the selected values themselves,
    not their mean; callers average the result. That return contract is
    kept unchanged.

    Args:
        arr: Array-like of scores. A 0-d value is treated as a one-element
            array.

    Returns:
        numpy.ndarray: The top ``int(len * 0.25)`` values in descending
        order, and always at least one value for a non-empty input so that
        averaging the result never yields NaN. An empty input gives an
        empty array.
    """
    # 1. Sort from high to low (atleast_1d lets a scalar score through).
    sorted_arr = np.sort(np.atleast_1d(arr))[::-1]

    # 2. 25 % of the length, but never fewer than one element.
    top_25_percent_length = max(1, int(len(sorted_arr) * 0.25))

    # 3. Take the leading slice of the sorted array.
    return sorted_arr[:top_25_percent_length]
|
|
|
|
|
|
def assess_gift_subimg(data_pth, pkl_flag=False):
    """Analyse the segmented sub-images: score histograms and PR curves.

    Loads the cached per-folder model outputs from ``gift.pkl`` and
    ``nongift.pkl`` in ``data_pth``, creating them with
    ``create_gift_subimg_np`` when either file is missing. Each folder is
    reduced to the mean of its top-quartile scores, then the two score
    distributions and the precision/recall curves over the thresholds
    0.01 ... 1.00 are displayed.

    Args:
        data_pth: Data-set root that holds (or will receive) the pickles.
        pkl_flag: Forwarded to ``create_gift_subimg_np`` when the cache has
            to be built.
    """
    points = (np.linspace(1, 100, 100)) / 100
    gift_pkl_pth = os.path.join(data_pth, 'gift.pkl')
    nongift_pkl_pth = os.path.join(data_pth, 'nongift.pkl')

    # Rebuild the cache when either pickle is missing; both are read below.
    if not (os.path.exists(gift_pkl_pth) and os.path.exists(nongift_pkl_pth)):
        create_gift_subimg_np(data_pth, pkl_flag)

    # NOTE(review): pickle.load executes arbitrary code from the file; only
    # point data_pth at directories whose pickles this script produced.
    with open(nongift_pkl_pth, 'rb') as f:
        nongift = pickle.load(f)
    with open(gift_pkl_pth, 'rb') as f:
        gift = pickle.load(f)
    # showHist(nongift.flatten(), gift.flatten())

    # Per-folder statistic: mean of the top-quartile scores.
    nongift_mean = [np.mean(top_25_percent_mean(items)) for items in nongift]
    gift_mean = [np.mean(top_25_percent_mean(items)) for items in gift]

    # Alternative statistics that were tried:
    #   median:      [np.median(items) for items in nongift / gift]
    #   plain mean:  [np.mean(items) for items in nongift / gift]
    #   all results: showHist([items for items in nongift],
    #                         [items for items in gift])

    showHist(np.array(nongift_mean), np.array(gift_mean))
    precision, recall = calculate_precision_recall(np.array(nongift_mean),
                                                   np.array(gift_mean),
                                                   points)
    showgrid(precision, recall, points)
|
|
|
|
|
|
def get_comprehensive_dicts(data_pth):
    """Score every tracking row found under ``data_pth`` and print the result.

    Loads ``Net3`` from ``action2/best.pth``, gathers the 64-value rows of
    every ``0_tracking_output.data`` / ``1_tracking_output.data`` file in
    the leaf directories of ``data_pth`` (via ``get_tracking_data``),
    reshapes them to ``(N, 64, 1, 1)`` float32 on the CPU and prints the
    model output.

    Args:
        data_pth: Root directory to walk.

    Returns:
        None. The model output is only printed.
    """
    gift_pth = r'../checkpoints/gift_model/action2/best.pth'
    g_model = Net3(pretrained=True, num_classes=1)
    # Inference below is pinned to the CPU, so map the checkpoint there too;
    # otherwise a GPU-saved checkpoint fails to load on a CPU-only host.
    g_model.load_state_dict(torch.load(gift_pth, map_location='cpu'))
    g_model.eval()

    result = []
    file_name = ['0_tracking_output.data',
                 '1_tracking_output.data']
    for root, dirs, files in os.walk(data_pth):
        if dirs:
            continue
        for file in files:
            if file in file_name:
                print(os.path.join(root, file))
                result += get_tracking_data(os.path.join(root, file))

    features = torch.from_numpy(np.array(result))
    # Named `features` rather than `input` to avoid shadowing the builtin.
    features = features.view(features.shape[0], 64, 1, 1)
    features = features.to('cpu')
    features = features.to(torch.float32)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        ji = g_model(features)
    print(ji)
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: exactly one analysis is active at a time; the
    # others are kept below as commented-out alternatives.

    # Analyse the result files written by a test run.
    # pth = r'\\192.168.1.28\\share\\测试视频数据以及日志\\各模块测试记录\\赠品测试\\20241203赠品测试数据\\赠品\\images'
    # pth = r'\\192.168.1.28\\share\\测试视频数据以及日志\\各模块测试记录\\赠品测试\\20241203赠品测试数据\\没有赠品的商品\\images'
    # pth = r'\\192.168.1.28\\share\\测试视频数据以及日志\\各模块测试记录\\赠品测试\\20241203赠品测试数据\\同样的商品没有捆绑赠品\\images'
    # pth = r'\\192.168.1.28\\share\\测试视频数据以及日志\\各模块测试记录\\赠品测试\\20241213赠品测试数据\\赠品'
    # pth = r'C:\Users\HP\Desktop\zengpin\1227'
    # get_gift_files(pth)

    # Analyse the results from the sub-images (active).
    pth = r'D:\Project\contrast_nettest\data\gift_test'
    assess_gift_subimg(pth)

    # Analyse the results from the complete data set.
    # pth = r'C:\Users\HP\Desktop\zengpin\1231'
    # get_comprehensive_dicts(pth)

    # Delete the recordings of return events.
    # pth = r'C:\Users\HP\Desktop\gift_test\20241213\非赠品'
    # clean_reurn_data(pth)
|