"""Evaluation utilities for the gift / non-gift classifier.

The module loads a ResNet-18 backbone plus a gift head (``Net3`` or ``Net4``),
scores cropped sub-images or parses ``*_tracking_output.data`` log files, and
plots score histograms and precision/recall-vs-threshold curves.
"""
import os
import pdb  # noqa: F401 - currently unused
import pickle
import shutil
import sys

sys.path.append('../model')

import matplotlib.pyplot as plt
import numpy as np
from model.mlp import Net2, Net3, Net4  # noqa: F401
from model import resnet18
import torch
from gift_data_pretreatment import getFeatureList

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _load_checkpoint(pth):
    """Load a checkpoint file with its tensors mapped onto ``device``."""
    return torch.load(pth, map_location=device)


def _strip_module_prefix(state_dict):
    """Return ``state_dict`` with every ``'module.'`` substring removed from its keys.

    Such prefixes typically come from saving a model wrapped in
    ``nn.DataParallel``.
    """
    return {k.replace('module.', ''): v for k, v in state_dict.items()}


def init_model(pkl_flag):
    """Build the ResNet-18 backbone and the gift head and load their weights.

    Args:
        pkl_flag: If true, build ``Net3(pretrained=True, num_classes=1)`` from
            ``action2/gift_v11.pth`` (fed with backbone feature maps in
            ``get_feature_array``); otherwise build ``Net4('resnet18', True, True)``
            from ``action3/best.pth``.

    Returns:
        ``(res_model, gift_model)`` with weights loaded. Neither is switched to
        eval mode here; callers do that.
    """
    res_pth = r"../checkpoints/resnet18_1009/best.pth"
    if pkl_flag:
        gift_pth = r'../checkpoints/gift_model/action2/gift_v11.pth'
        gift_model = Net3(pretrained=True, num_classes=1)
        # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
        gift_model.load_state_dict(_load_checkpoint(gift_pth))
    else:
        gift_pth = r'../checkpoints/gift_model/action3/best.pth'
        gift_model = Net4('resnet18', True, True)  # pretrained model
        state_dict = _load_checkpoint(gift_pth)
        try:
            print('>>multiple_cards load pre model <<')
            gift_model.load_state_dict(_strip_module_prefix(state_dict))
        except RuntimeError:
            # Key mismatch after stripping: retry with the raw key names.
            print('>> load pre model <<')
            gift_model.load_state_dict(state_dict)

    res_model = resnet18()
    res_model.load_state_dict(_strip_module_prefix(_load_checkpoint(res_pth)))
    return res_model, gift_model


def showHist(nongifts, gifts):
    """Show two stacked score histograms (50 bins, x-range [-0.1, 1]).

    The top subplot is titled 'nongifts' and drawn from ``nongifts``; the bottom
    one is titled 'gifts' and drawn from ``gifts``.
    """
    fig, axs = plt.subplots(2, 1)
    panels = ((axs[0], nongifts, 'blue', 'nongifts'),
              (axs[1], gifts, 'green', 'gifts'))
    for ax, values, colour, title in panels:
        ax.hist(values, bins=50, edgecolor=colour)
        ax.set_xlim([-0.1, 1])
        ax.set_title(title)
    plt.show()


def calculate_precision_recall(nongift, gift, points):
    """Compute precision and recall of the rule "score > threshold => gift".

    Args:
        nongift: Array of scores for negative (non-gift) samples.
        gift: Array of scores for positive (gift) samples.
        points: Iterable of thresholds.

    Returns:
        ``(precision, recall)`` lists aligned with ``points``. Both entries are
        0 for a threshold at which no gift score is above the threshold.
    """
    precision, recall = [], []
    for point in points:
        # Strictly-above is "predicted gift"; everything else, including ties,
        # is "predicted non-gift", so TP + FN == len(gift) and
        # FP + TN == len(nongift).
        TP = np.sum(gift > point)
        FN = np.sum(gift <= point)
        FP = np.sum(nongift > point)
        TN = np.sum(nongift <= point)
        if TP == 0:
            precision.append(0)
            recall.append(0)
        else:
            precision.append(TP / (TP + FP))
            recall.append(TP / (TP + FN))
        print("point >> {} TP>>{}, FP>>{}, TN>>{}, FN>>{}".format(point, TP, FP, TN, FN))
        if point == 0.5:
            # The 0.5 threshold is printed a second time, presumably so it is
            # easy to spot in the log.
            print("point >> {} TP>>{}, FP>>{}, TN>>{}, FN>>{}".format(point, TP, FP, TN, FN))
    return precision, recall


def showgrid(all_prec, all_recall, points):
    """Plot precision (blue) and recall (red) against the threshold.

    The last threshold/value pair is left out of the plot.
    """
    plt.figure(figsize=(10, 6))
    plt.plot(points[:-1], all_prec[:-1], color='blue', label='precision')
    plt.plot(points[:-1], all_recall[:-1], color='red', label='recall')
    plt.legend()
    plt.xlabel('threshold')
    plt.grid(True, linestyle='--', alpha=0.5)
    plt.show()
    plt.close()


def discriminate_action(roots):
    """Decide whether an event directory is an add-to-cart or a return.

    Reads ``<roots>/process.data``. On the first line containing
    ``weightValue``, the integer after the last ':' (up to the next ',') is
    inspected.

    Returns:
        ``'add'`` if that integer is positive, ``'return'`` if it is zero or
        negative, and ``None`` when no ``weightValue`` line exists.
    """
    pth = os.path.join(roots, 'process.data')
    with open(pth, 'r') as f:
        for line in f:
            content = line.strip()
            if 'weightValue' in content:
                weight = int(content.split(":")[-1].split(',')[0])
                return 'add' if weight > 0 else 'return'
    return None


def median(lst):
    """Return the median of ``lst``; even-length inputs average the middle two.

    Raises:
        ValueError: if ``lst`` is empty.
    """
    if not lst:
        raise ValueError('median() of an empty sequence')
    sorted_lst = sorted(lst)
    n = len(sorted_lst)
    if n % 2 == 1:
        return sorted_lst[n // 2]
    return (sorted_lst[n // 2 - 1] + sorted_lst[n // 2]) / 2


def get_special_data(data, p):
    """Reduce a score sequence to one number.

    Sequences of 5 or fewer values return their plain mean, whatever ``p`` is.
    Longer ones use only the first ``round(len(data) * 0.5)`` values, reduced
    according to ``p``: ``'max'``, ``'average'`` or ``'median'``.

    Raises:
        ValueError: if ``data`` has more than 5 values and ``p`` is not one of
            the modes above (previously this silently returned ``None``).
        ZeroDivisionError: if ``data`` is empty.
    """
    length = len(data)
    if length <= 5:
        return sum(data) / length
    head = data[:round(length * 0.5)]
    if p == 'max':
        return max(head)
    if p == 'average':
        return sum(head) / len(head)
    if p == 'median':
        return median(head)
    raise ValueError("unknown reduction mode: {!r}".format(p))


def read_data_file(pth):
    """Parse the ``free_gift__result`` scores of a tracking output file.

    The value part of a ``free_gift__result:<v1>,<v2>,...`` line is split on
    ','. Paths containing ``0_tracking_output.data`` drop the last field and
    other files drop the last two. Presumably these are trailing non-score
    fields; TODO confirm the file format. If several such lines exist, the last
    one wins.

    Returns:
        List of floats, empty when no ``free_gift__result`` line is present.
    """
    result = []
    drop = 1 if '0_tracking_output.data' in pth else 2
    with open(pth, 'r') as data_file:
        for line in data_file:
            key, _, values = line.partition(':')
            if key == 'free_gift__result':
                result = [float(i) for i in values.split(',')[:-drop]]
    return result


def get_tracking_data(pth):
    """Return 64-float rows from lines that have exactly 65 comma-separated fields.

    The 65th field is discarded.
    """
    result = []
    with open(pth, 'r') as data_file:
        for line in data_file:
            fields = line.split(',')
            if len(fields) == 65:
                result.append([float(item) for item in fields[:-1]])
    return result


def clean_reurn_data(pth):
    """Delete (``shutil.rmtree``) every leaf directory under ``pth`` classified as a return.

    Destructive: directories are removed from disk without confirmation.
    """
    for roots, dirs, files in os.walk(pth):
        if len(dirs) == 0 and discriminate_action(roots) == 'return':
            shutil.rmtree(roots)


def get_gift_files(pth):
    """Collect gift scores from test-result files under ``pth``.

    For every leaf directory, the action comes from ``discriminate_action``.
    Anything other than ``'add'`` is treated as a return. For camera 0 (rear)
    and camera 1 (front) ``*_tracking_output.data`` files:

    * ``<action>_special_output_<cam>`` gets the first-half average of the
      file's scores (non-empty files only);
    * ``<action>_tracking_output_<cam>`` is extended with all of its scores.

    The add scores of camera 0 (top) and camera 1 (bottom) are then shown with
    ``showHist``; note that the panels are therefore titled 'nongifts'/'gifts'.

    Returns:
        Dict with the eight lists described above.
    """
    comprehensive_dicts = {
        "add_special_output_0": [],
        "return_special_output_0": [],
        "add_tracking_output_0": [],
        "return_tracking_output_0": [],
        "add_special_output_1": [],
        "return_special_output_1": [],
        "add_tracking_output_1": [],
        "return_tracking_output_1": [],
    }
    camera_of_file = {'0_tracking_output.data': '0', '1_tracking_output.data': '1'}
    for roots, dirs, files in os.walk(pth):
        if len(dirs) != 0:
            continue
        action = 'add' if discriminate_action(roots) == 'add' else 'return'
        for file in files:
            cam = camera_of_file.get(file)
            if cam is None:
                continue
            # Parse once and reuse for both the summary and the raw scores.
            result = read_data_file(os.path.join(roots, file))
            if result:
                comprehensive_dicts['{}_special_output_{}'.format(action, cam)].append(
                    get_special_data(result, 'average'))
            comprehensive_dicts['{}_tracking_output_{}'.format(action, cam)] += result

    showHist(np.array(comprehensive_dicts['add_tracking_output_0']),
             np.array(comprehensive_dicts['add_tracking_output_1']))
    return comprehensive_dicts


def get_feature_array(img_pth_lists, res_model, gift_model, pkl_flag=True):
    """Score each group of image paths with the gift model.

    Args:
        img_pth_lists: List of image-path lists, one list per sample directory.
        res_model: Backbone used for feature extraction when ``pkl_flag`` is true.
        gift_model: Gift head (``pkl_flag`` true) or full gift model (false).
        pkl_flag: If true, take the first array from ``getFeatureList``, keep
            columns 256 onward, reshape them to ``(N, 64, 13, 13)`` and run
            ``gift_model`` on that. Otherwise concatenate what ``getFeatureList``
            returns for ``gift_model``.

    Returns:
        One numpy array per group. Groups whose features could not be sliced,
        or that produced no features, are skipped.
    """
    features_np = []
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        if pkl_flag:
            for img_lists in img_pth_lists:
                fe_nps = getFeatureList(None, img_lists, res_model)
                try:
                    fe_nps = fe_nps[0][:, 256:]
                except Exception as e:  # best-effort: skip unusable groups
                    print(e)
                    continue
                fe_nps = torch.from_numpy(fe_nps)
                fe_nps = fe_nps.view(fe_nps.shape[0], 64, 13, 13)
                if len(fe_nps):
                    fe_np = gift_model(fe_nps)
                    features_np.append(np.squeeze(fe_np.detach().numpy()))
        else:
            for img_lists in img_pth_lists:
                fe_nps = getFeatureList(None, img_lists, gift_model)
                if len(fe_nps) > 0:
                    features_np.append(np.concatenate(fe_nps))
    return features_np


def create_gift_subimg_np(data_pth, pkl_flag):
    """Score all sub-image directories under ``data_pth`` and cache the results.

    Directories whose path contains both 'subimg' and 'commodity' are
    non-gift samples. Those containing 'subimg' and 'Havegift' are gift
    samples. The per-directory score arrays are pickled to
    ``nongift.pkl`` and ``gift.pkl`` inside ``data_pth``.
    """
    gift_array_pth = os.path.join(data_pth, 'gift.pkl')
    nongift_array_pth = os.path.join(data_pth, 'nongift.pkl')
    res_model, gift_model = init_model(pkl_flag)
    res_model = res_model.eval()
    gift_model = gift_model.eval()

    gift_lists, nongift_lists = [], []
    for root, dirs, files in os.walk(data_pth):
        if 'subimg' not in root:
            continue
        if 'commodity' in root:
            print("commodity >> {}".format(root))
            nongift_lists.append([os.path.join(root, file) for file in files])
        elif 'Havegift' in root:
            print("Havegift >> {}".format(root))
            gift_lists.append([os.path.join(root, file) for file in files])

    nongift = get_feature_array(nongift_lists, res_model, gift_model, pkl_flag)
    gift = get_feature_array(gift_lists, res_model, gift_model, pkl_flag)
    with open(nongift_array_pth, 'wb') as file:
        pickle.dump(nongift, file)
    with open(gift_array_pth, 'wb') as file:
        pickle.dump(gift, file)


def top_25_percent_mean(arr):
    """Return the largest quarter of the values in ``arr``, sorted descending.

    Despite the name this returns the selected values, not their mean; callers
    apply ``np.mean`` themselves. The input is flattened first, so 0-d arrays
    and column vectors such as ``(N, 1)`` are handled as a flat list of
    scores. At least one value is kept for any non-empty input so that
    sequences shorter than 4 do not yield an empty slice (whose mean is NaN).
    """
    descending = np.sort(np.ravel(arr))[::-1]
    keep = max(1, int(len(descending) * 0.25))
    return descending[:keep]


def assess_gift_subimg(data_pth, pkl_flag=False):
    """Evaluate gift/non-gift separation on cached or freshly computed sub-image scores.

    Each sample is scored as the mean of its top 25% values. Histograms and a
    precision/recall curve over thresholds 0.01..1.00 are then shown.
    Alternatives used earlier were the per-sample median and the plain mean.
    """
    points = np.linspace(1, 100, 100) / 100
    gift_pkl_pth = os.path.join(data_pth, 'gift.pkl')
    nongift_pkl_pth = os.path.join(data_pth, 'nongift.pkl')
    # Both caches are read below, so regenerate if either one is missing.
    if not (os.path.exists(gift_pkl_pth) and os.path.exists(nongift_pkl_pth)):
        create_gift_subimg_np(data_pth, pkl_flag)

    # pickle.load can execute arbitrary code: these files must only ever be the
    # locally produced caches from create_gift_subimg_np, never untrusted input.
    with open(nongift_pkl_pth, 'rb') as f:
        nongift = pickle.load(f)
    with open(gift_pkl_pth, 'rb') as f:
        gift = pickle.load(f)

    nongift_mean = np.array([np.mean(top_25_percent_mean(items)) for items in nongift])
    gift_mean = np.array([np.mean(top_25_percent_mean(items)) for items in gift])

    showHist(nongift_mean, gift_mean)
    precision, recall = calculate_precision_recall(nongift_mean, gift_mean, points)
    showgrid(precision, recall, points)


def get_comprehensive_dicts(data_pth):
    """Run the action2 ``Net3`` head on raw 64-value tracking rows and print the output.

    All ``0_``/``1_tracking_output.data`` files in leaf directories under
    ``data_pth`` are parsed with ``get_tracking_data``. Each row is reshaped to
    ``(64, 1, 1)`` and scored as float32 on the CPU.

    Returns:
        The model output tensor, or ``None`` if no usable rows were found. The
        function originally returned nothing, so this is additive.
    """
    gift_pth = r'../checkpoints/gift_model/action2/best.pth'
    g_model = Net3(pretrained=True, num_classes=1)
    # Inference below runs on CPU, so load the weights there directly.
    g_model.load_state_dict(torch.load(gift_pth, map_location='cpu'))
    g_model.eval()

    file_name = ['0_tracking_output.data', '1_tracking_output.data']
    result = []
    for root, dirs, files in os.walk(data_pth):
        if len(dirs) != 0:
            continue
        for file in files:
            if file in file_name:
                print(os.path.join(root, file))
                result += get_tracking_data(os.path.join(root, file))

    if not result:
        print('no tracking rows found under {}'.format(data_pth))
        return None

    features = torch.from_numpy(np.array(result)).to(torch.float32)
    model_input = features.view(features.shape[0], 64, 1, 1).to('cpu')
    with torch.no_grad():
        ji = g_model(model_input)
    print(ji)
    return ji


if __name__ == '__main__':
    # Analyse the result files written by a test run directly:
    # pth = r'\\192.168.1.28\\share\\测试视频数据以及日志\\各模块测试记录\\赠品测试\\20241203赠品测试数据\\赠品\\images'
    # pth = r'\\192.168.1.28\\share\\测试视频数据以及日志\\各模块测试记录\\赠品测试\\20241203赠品测试数据\\没有赠品的商品\\images'
    # pth = r'\\192.168.1.28\\share\\测试视频数据以及日志\\各模块测试记录\\赠品测试\\20241203赠品测试数据\\同样的商品没有捆绑赠品\\images'
    # pth = r'\\192.168.1.28\\share\\测试视频数据以及日志\\各模块测试记录\\赠品测试\\20241213赠品测试数据\\赠品'
    # pth = r'C:\Users\HP\Desktop\zengpin\1227'
    # get_gift_files(pth)

    # Analyse results from the cropped sub-images.
    pth = r'D:\Project\contrast_nettest\data\gift_test'
    assess_gift_subimg(pth)

    # Analyse results from the complete dataset:
    # pth = r'C:\Users\HP\Desktop\zengpin\1231'
    # get_comprehensive_dicts(pth)

    # Delete return (退购) recordings:
    # pth = r'C:\Users\HP\Desktop\gift_test\20241213\非赠品'
    # clean_reurn_data(pth)