mirror of https://gitee.com/nanjing-yimao-information/ieemoo-ai-gift.git (synced 2025-08-23 23:50:25 +00:00)
update
pipline_test/display_result.py | 144 (new file)
@@ -0,0 +1,144 @@
from sklearn.preprocessing import label_binarize
from sklearn.metrics import precision_recall_curve, average_precision_score
import matplotlib.pyplot as plt
import numpy as np
# from utils.config import conf
import os


# def show_multiclass_pr(y_true, y_scores):  # multi-class PR curves
#     # Binarize the ground-truth labels
#     y_true_binarized = label_binarize(y_true, classes=np.arange(conf.n_classes))
#
#     # Per-class PR curve and average precision
#     precision = dict()
#     recall = dict()
#     average_precision = dict()
#
#     for i in range(conf.n_classes):
#         precision[i], recall[i], _ = precision_recall_curve(y_true_binarized[:, i], y_scores[:, i])
#         average_precision[i] = average_precision_score(y_true_binarized[:, i], y_scores[:, i])
#
#     # Micro-averaged PR curve and average precision
#     precision["micro"], recall["micro"], _ = precision_recall_curve(y_true_binarized.ravel(), y_scores.ravel())
#     average_precision["micro"] = average_precision_score(y_true_binarized, y_scores, average="micro")
#
#     # Macro-averaged PR curve and average precision
#     precision["macro"] = np.mean([precision[i] for i in range(conf.n_classes)], axis=0)
#     recall["macro"] = np.mean([recall[i] for i in range(conf.n_classes)], axis=0)
#     average_precision["macro"] = np.mean([average_precision[i] for i in range(conf.n_classes)])
#
#     # Plot the micro-averaged PR curve
#     plt.figure(figsize=(10, 7))
#     plt.plot(recall["micro"], precision["micro"], color='gold', lw=2,
#              label='micro-average Precision-recall curve (area = {0:0.2f})'
#                    ''.format(average_precision["micro"]))
#
#     # Plot the macro-averaged PR curve
#     plt.plot(recall["macro"], precision["macro"], color='navy', lw=2,
#              label='macro-average Precision-recall curve (area = {0:0.2f})'
#                    ''.format(average_precision["macro"]))
#
#     # Plot the PR curve of each class
#     colors = plt.cm.tab20(np.linspace(0, 1, conf.n_classes))
#     for i, color in zip(range(conf.n_classes), colors):
#         plt.plot(recall[i], precision[i], color=color, lw=2,
#                  label='Precision-recall curve of class {0} (area = {1:0.2f})'
#                        ''.format(i, average_precision[i]))
#
#     plt.xlabel('Recall')
#     plt.ylabel('Precision')
#     plt.title('Extension of Precision-Recall curve to multi-class')
#     plt.legend(loc="best")
#     plt.show()
#     pass


def calculate_similarity(outputs, targets, threshold):
    # Accumulate confusion-matrix counts for the binarized predictions vs. targets.
    FN, FP, TN, TP = 0, 0, 0, 0
    for output, target in zip(outputs, targets):
        # print("output >> {} , target >> {}".format(output, target))
        if output != target:
            if target == 0:
                FP += 1
            elif target == 1:
                FN += 1
        else:
            if target == 0:
                TN += 1
            elif target == 1:
                TP += 1
    if TP == 0:
        prec, recall = 0, 0
    else:
        prec = TP / (TP + FP)
        recall = TP / (TP + FN)
    if TN == 0:
        tn_prec, tn_recall = 0, 0
    else:
        tn_prec = TN / (TN + FN)
        tn_recall = TN / (TN + FP)
    # print("TP>>{}, FP>>{}, TN>>{}, FN>>{}".format(TP, FP, TN, FN))
    if threshold == 0.1:
        print("TP>>{}, FP>>{}, TN>>{}, FN>>{}".format(TP, FP, TN, FN))
    return prec, recall, tn_prec, tn_recall


def show_pr(prec, recall, tn_prec, tn_recall, thres, title_name):
    x = thres
    plt.figure(figsize=(10, 6))
    plt.plot(x, recall, color='red', label='recall: TP/(TP+FN)')
    plt.plot(x, tn_recall, color='black', label='recall_TN: TN/(TN+FP)')
    plt.plot(x, prec, color='blue', label='PrecisePos: TP/(TP+FP)')
    plt.plot(x, tn_prec, color='green', label='PreciseNeg: TN/(TN+FN)')
    plt.legend()
    plt.xlabel('threshold')
    # if self.title_name is not None:
    plt.title(title_name)
    plt.grid(True, linestyle='--', alpha=0.5)
    os.makedirs('../pr_test', exist_ok=True)  # make sure the output directory exists
    plt.savefig(os.sep.join(['../pr_test', 'yolo_pr.png']))
    plt.show()
    plt.close()


def write_results_to_file(title_name, thres, recall, recall_TN, PrecisePos, PreciseNeg):
    file_path = os.sep.join(['../pr_test', title_name + '.txt'])
    with open(file_path, 'w') as file:
        file.write("threshold, recall, recall_TN, PrecisePos, PreciseNeg\n")
        for thre, rec, rec_tn, prec_pos, prec_neg in zip(thres, recall, recall_TN, PrecisePos, PreciseNeg):
            file.write(
                f"thre>>{thre:.2f}, recall>>{rec:.4f}, precise_pos>>{prec_pos:.4f}, recall_tn>>{rec_tn:.4f}, precise_neg>>{prec_neg:.4f}\n")


def show_to_pr(y_true, y_prob, type):  # binary classification
    thres = [i * 0.01 for i in range(101)]
    title_name = 'yolo_pr'
    recall, recall_TN, PrecisePos, PreciseNeg = [], [], [], []
    for threshold in thres:
        y_scores_adjusted = []
        if type == 0:
            # y_prob is a flat score array: binarize it directly at the threshold
            y_scores_adjusted = np.where(y_prob < threshold, 0, 1)
        elif type == 1:
            # y_prob is a list of per-sample score lists: a sample is positive when
            # more than 10% of its scores exceed the threshold
            for yp in y_prob:
                if yp != 0:
                    yp = np.array(yp)
                    yp_num = np.sum(np.array(yp) > threshold)
                    if yp_num / len(yp) > 0.1:
                        y_scores_adjusted.append(1)
                    else:
                        y_scores_adjusted.append(0)
                else:
                    # 0 placeholder: no scores for this sample, predict negative
                    y_scores_adjusted.append(0)
        prec, pos_recall, tn_prec, tn_recall = calculate_similarity(y_scores_adjusted, y_true, threshold)
        recall.append(pos_recall)
        recall_TN.append(tn_recall)
        PrecisePos.append(prec)
        PreciseNeg.append(tn_prec)
        # print(" prec>>{} recall>>{} tn_prec>>{} tn_recall>>{} threshold>>{}\n".format(prec, pos_recall, tn_prec,
        #                                                                               tn_recall, threshold))
    show_pr(PrecisePos, recall, PreciseNeg, recall_TN, thres, title_name)
    write_results_to_file(title_name, thres, recall, recall_TN, PrecisePos, PreciseNeg)


if __name__ == '__main__':
    pass
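Below is a minimal, hypothetical usage sketch, not part of this commit: it assumes the snippet sits next to display_result.py (so `from display_result import show_to_pr` resolves) and that `../pr_test` is a valid output location; the labels and scores are made up for illustration and call the type-0 (flat score array) branch.

# Hypothetical usage sketch -- not part of display_result.py.
import numpy as np
from display_result import show_to_pr

y_true = np.array([1, 0, 1, 1, 0, 0, 1, 0])                    # illustrative ground-truth binary labels
y_prob = np.array([0.9, 0.2, 0.7, 0.4, 0.1, 0.6, 0.8, 0.3])    # illustrative predicted scores
show_to_pr(y_true, y_prob, type=0)   # plots/saves the PR-vs-threshold curves and writes ../pr_test/yolo_pr.txt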