import os.path
import os
import matplotlib.pyplot as plt
import numpy as np

# Label mapping of the tags: [tag, bandage, null] -> [0, 1, 2].
# The PR computation below only counts targets 0 (negative) and 1 (positive).


class ShowPR:
    """Sweep decision thresholds over per-sample prediction-score sequences and
    plot / record positive and negative precision-recall curves.

    ``tags[i]`` is the ground-truth label of sample ``i``; ``prec_value[i]`` is
    that sample's sequence of prediction scores.
    """

    # Directory that receives the generated ``.png`` plots and ``.txt`` tables.
    OUTPUT_DIR = './ckpts/tracePR'

    def __init__(self, tags, prec_value, title_name=None):
        self.tags = tags
        self.prec_value = prec_value
        # Build the grids by division instead of ``i * step`` so every entry is
        # the exact float literal (0.35, 0.95, ...). Repeated multiplication
        # yields e.g. 0.35000000000000003, and 19 * 0.05 > 0.95 made get_pic
        # silently skip the 0.95 ratio.
        self.thres = [i / 100 for i in range(101)]
        self.ratios = [i / 20 for i in range(21)]
        self.title_name = title_name

    def change_precValue(self, thre=0.5):
        """Binarize every score: 1 if strictly greater than ``thre``, else 0.

        Returns a new nested list with the same layout as ``self.prec_value``.
        """
        return [[1 if score > thre else 0 for score in seq]
                for seq in self.prec_value]

    def calculate_mena(self, ratio=0.5):
        """Return, per sample, the mean of its top ``ratio`` fraction of scores.

        At least the single highest score is used when ``ratio`` selects none.
        A sample with no scores yields 0.0 instead of raising ZeroDivisionError.
        """
        values = []
        for data in self.prec_value:
            if len(data) == 0:
                values.append(0.0)
                continue
            ranked = sorted(data, reverse=True)
            # Fall back to the maximum score when the fraction rounds to zero.
            top = ranked[:int(len(data) * ratio)] or ranked[:1]
            values.append(sum(top) / len(top))
        return values

    def _calculate_pr(self, prec_value):
        """Compare binary predictions against ``self.tags``.

        Returns ``(prec, recall, tn_prec, tn_recall)`` where
        prec = TP/(TP+FP), recall = TP/(TP+FN),
        tn_prec = TN/(TN+FN), tn_recall = TN/(TN+FP).
        A pair is reported as (0, 0) when its TP (resp. TN) count is zero.
        Targets other than 0 and 1 are ignored.
        """
        TP = FP = TN = FN = 0
        for output, target in zip(prec_value, self.tags):
            if target == 1:
                if output == target:
                    TP += 1
                else:
                    FN += 1
            elif target == 0:
                if output == target:
                    TN += 1
                else:
                    FP += 1
        if TP == 0:
            prec, recall = 0, 0
        else:
            prec = TP / (TP + FP)
            recall = TP / (TP + FN)
        if TN == 0:
            tn_prec, tn_recall = 0, 0
        else:
            tn_prec = TN / (TN + FN)
            tn_recall = TN / (TN + FP)
        return prec, recall, tn_prec, tn_recall

    def _sweep(self, ratio, predict):
        """Evaluate ``predict(thre)`` at every threshold and collect the curves.

        ``predict`` maps a threshold to one binary prediction per sample.
        Prints one progress line per threshold and returns
        ``(recall, recall_TN, PrecisePos, PreciseNeg)``, each with one entry per
        value of ``self.thres``.
        """
        recall, recall_TN, PrecisePos, PreciseNeg = [], [], [], []
        for thre in self.thres:
            prec, recall_pos, tn_prec, tn_recall = self._calculate_pr(predict(thre))
            print(
                f"ratio>>{ratio:.2f}, thre>>{thre:.2f}, recall>>{recall_pos:.4f}, "
                f"precise_pos>>{prec:.4f}, recall_tn>>{tn_recall:.4f}, "
                f"precise_neg>>{tn_prec:.4f}")
            PrecisePos.append(prec)
            recall.append(recall_pos)
            PreciseNeg.append(tn_prec)
            recall_TN.append(tn_recall)
        return recall, recall_TN, PrecisePos, PreciseNeg

    def calculate_multiple_1(self, ratio=0.2):
        """Scheme 1: a sample is predicted positive when the fraction of its
        scores that exceed the threshold is at least ``ratio``.

        Returns ``(recall, recall_TN, PrecisePos, PreciseNeg)`` over ``self.thres``.
        """
        def predict(thre):
            preds = []
            for seq in self.change_precValue(thre):
                # An empty sequence has no passing scores: proportion 0.0.
                proportion = seq.count(1) / len(seq) if seq else 0.0
                preds.append(1 if proportion >= ratio else 0)
            return preds

        return self._sweep(ratio, predict)

    def calculate_multiple_2(self, ratio=0.2):
        """Scheme 2: a sample is predicted positive when the mean of its top
        ``ratio`` fraction of scores is >= the threshold, otherwise negative.

        Returns ``(recall, recall_TN, PrecisePos, PreciseNeg)`` over ``self.thres``.
        """
        event_value = self.calculate_mena(ratio)
        return self._sweep(
            ratio, lambda thre: [1 if num >= thre else 0 for num in event_value])

    def _output_path(self, ratio, suffix):
        """Return ``<OUTPUT_DIR>/<title>_<ratio><suffix>``, creating the directory."""
        os.makedirs(self.OUTPUT_DIR, exist_ok=True)
        # Fall back to a fixed stem so the default title_name=None does not crash.
        stem = self.title_name if self.title_name is not None else 'PR'
        return os.path.join(self.OUTPUT_DIR, f"{stem}_{ratio:.2f}{suffix}")

    def write_results_to_file(self, recall, recall_TN, PrecisePos, PreciseNeg, ratio):
        """Write one line per threshold with the four metrics to a ``.txt`` file."""
        file_path = self._output_path(ratio, '.txt')
        with open(file_path, 'w', encoding='utf-8') as file:
            # Header column order matches the order of the fields in each row.
            file.write("threshold, recall, PrecisePos, recall_TN, PreciseNeg\n")
            for thre, rec, rec_tn, prec_pos, prec_neg in zip(
                    self.thres, recall, recall_TN, PrecisePos, PreciseNeg):
                file.write(
                    f"thre>>{thre:.2f}, recall>>{rec:.4f}, precise_pos>>{prec_pos:.4f}, "
                    f"recall_tn>>{rec_tn:.4f}, precise_neg>>{prec_neg:.4f}\n")

    def show_pr(self, recall, recall_TN, PrecisePos, PreciseNeg, ratio):
        """Plot the four metric curves against threshold, save and show the
        figure, then write the same numbers to a text file."""
        x = self.thres
        plt.figure(figsize=(10, 6))
        plt.plot(x, recall, color='red', label='recall:TP/TPFN')
        plt.plot(x, recall_TN, color='black', label='recall_TN:TN/TNFP')
        plt.plot(x, PrecisePos, color='blue', label='PrecisePos:TP/TPFP')
        plt.plot(x, PreciseNeg, color='green', label='PreciseNeg:TN/TNFN')
        plt.legend()
        plt.xlabel('threshold')
        if self.title_name is not None:
            plt.title(f"PrecisePos & Recall ratio:{ratio:.2f}",
                      fontdict={'fontsize': 12, 'fontweight': 'black'})
        # Enable minor ticks, then draw solid major and dotted minor grid lines.
        plt.minorticks_on()
        plt.grid(which='major', linestyle='-', alpha=0.5, color='gray')
        plt.grid(which='minor', linestyle=':', alpha=0.3, color='gray')
        # Save before show(): show() may leave the current figure empty.
        plt.savefig(self._output_path(ratio, '.png'))
        plt.show()
        plt.close()
        self.write_results_to_file(recall, recall_TN, PrecisePos, PreciseNeg, ratio)

    def get_pic(self):
        """Compute, plot and save the scheme-1 curves for each ratio in [0.1, 0.95].

        Swap in ``calculate_multiple_2`` to plot scheme 2 instead.
        """
        for ratio in self.ratios:
            if ratio < 0.1 or ratio > 0.95:
                continue
            recall, recall_TN, PrecisePos, PreciseNeg = self.calculate_multiple_1(ratio)
            self.show_pr(recall, recall_TN, PrecisePos, PreciseNeg, ratio)