mirror of
https://gitee.com/nanjing-yimao-information/ieemoo-ai-gift.git
synced 2025-08-18 13:20:25 +00:00
145 lines
5.8 KiB
Python
145 lines
5.8 KiB
Python
from sklearn.preprocessing import label_binarize
|
|
from sklearn.metrics import precision_recall_curve, average_precision_score
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
# from utils.config import conf
|
|
import os
|
|
|
|
|
|
# def show_multiclass_pr(y_true, y_scores): # 多分类
|
|
# # 将真实标签二值化
|
|
# y_true_binarized = label_binarize(y_true, classes=np.arange(conf.n_classes))
|
|
#
|
|
# # 计算每个类别的PR曲线和平均精度
|
|
# precision = dict()
|
|
# recall = dict()
|
|
# average_precision = dict()
|
|
#
|
|
# for i in range(conf.n_classes):
|
|
# precision[i], recall[i], _ = precision_recall_curve(y_true_binarized[:, i], y_scores[:, i])
|
|
# average_precision[i] = average_precision_score(y_true_binarized[:, i], y_scores[:, i])
|
|
#
|
|
# # 计算微平均PR曲线和平均精度
|
|
# precision["micro"], recall["micro"], _ = precision_recall_curve(y_true_binarized.ravel(), y_scores.ravel())
|
|
# average_precision["micro"] = average_precision_score(y_true_binarized, y_scores, average="micro")
|
|
#
|
|
# # 计算宏平均PR曲线和平均精度
|
|
# precision["macro"] = np.mean([precision[i] for i in range(conf.n_classes)], axis=0)
|
|
# recall["macro"] = np.mean([recall[i] for i in range(conf.n_classes)], axis=0)
|
|
# average_precision["macro"] = np.mean([average_precision[i] for i in range(conf.n_classes)])
|
|
#
|
|
# # 绘制微平均PR曲线
|
|
# plt.figure(figsize=(10, 7))
|
|
# plt.plot(recall["micro"], precision["micro"], color='gold', lw=2,
|
|
# label='micro-average Precision-recall curve (area = {0:0.2f})'
|
|
# ''.format(average_precision["micro"]))
|
|
#
|
|
# # 绘制宏平均PR曲线
|
|
# plt.plot(recall["macro"], precision["macro"], color='navy', lw=2,
|
|
# label='macro-average Precision-recall curve (area = {0:0.2f})'
|
|
# ''.format(average_precision["macro"]))
|
|
#
|
|
# # 绘制每个类别的PR曲线
|
|
# colors = plt.cm.tab20(np.linspace(0, 1, conf.n_classes))
|
|
# for i, color in zip(range(conf.n_classes), colors):
|
|
# plt.plot(recall[i], precision[i], color=color, lw=2,
|
|
# label='Precision-recall curve of class {0} (area = {1:0.2f})'
|
|
# ''.format(i, average_precision[i]))
|
|
#
|
|
# plt.xlabel('Recall')
|
|
# plt.ylabel('Precision')
|
|
# plt.title('Extension of Precision-Recall curve to multi-class')
|
|
# plt.legend(loc="best")
|
|
# plt.show()
|
|
# pass
|
|
|
|
|
|
def calculate_similarity(outputs, targets, threshold):
    """Compute precision and recall for both the positive and the negative class.

    Predictions and labels are compared pairwise; only labels equal to 0 or 1
    contribute to the confusion counts, any other label value is ignored.

    Args:
        outputs: iterable of predicted labels.
        targets: iterable of ground-truth labels, paired with ``outputs`` by
            position (iteration stops at the shorter of the two).
        threshold: the decision threshold that produced ``outputs``; it is
            only used to decide whether the raw counts are printed.

    Returns:
        A tuple ``(prec, recall, tn_prec, tn_recall)`` where
        ``prec = TP / (TP + FP)``, ``recall = TP / (TP + FN)``,
        ``tn_prec = TN / (TN + FN)`` and ``tn_recall = TN / (TN + FP)``.
        The positive pair is ``0, 0`` when there are no true positives and the
        negative pair is ``0, 0`` when there are no true negatives.
    """
    true_pos = false_pos = true_neg = false_neg = 0

    for predicted, expected in zip(outputs, targets):
        mismatch = bool(predicted != expected)
        if expected == 0:
            # Ground truth is negative: a mismatch is a false alarm.
            if mismatch:
                false_pos += 1
            else:
                true_neg += 1
        elif expected == 1:
            # Ground truth is positive: a mismatch is a miss.
            if mismatch:
                false_neg += 1
            else:
                true_pos += 1

    # Positive-class metrics; the zero check also protects the divisions.
    prec, recall = 0, 0
    if true_pos != 0:
        prec = true_pos / (true_pos + false_pos)
        recall = true_pos / (true_pos + false_neg)

    # Negative-class metrics, guarded the same way.
    tn_prec, tn_recall = 0, 0
    if true_neg != 0:
        tn_prec = true_neg / (true_neg + false_neg)
        tn_recall = true_neg / (true_neg + false_pos)

    # The raw confusion counts are reported for the 0.1 threshold only.
    if threshold == 0.1:
        print("TP>>{}, FP>>{}, TN>>{}, FN>>{}".format(true_pos, false_pos, true_neg, false_neg))

    return prec, recall, tn_prec, tn_recall
|
|
|
|
|
|
def show_pr(prec, recall, tn_prec, tn_recall, thres, title_name):
    """Plot the four metric curves against the threshold and save the figure.

    The figure is written to ``../pr_test/yolo_pr.png`` (relative to the
    current working directory), then shown and closed.

    Args:
        prec: positive-class precision per threshold.
        recall: positive-class recall per threshold.
        tn_prec: negative-class precision per threshold.
        tn_recall: negative-class recall per threshold.
        thres: threshold values used as the x axis; the four metric sequences
            are plotted against it.
        title_name: text used as the plot title (the output file name is
            fixed and does not depend on it).
    """
    out_dir = '../pr_test'
    # savefig does not create missing directories, so make sure it exists.
    os.makedirs(out_dir, exist_ok=True)

    plt.figure(figsize=(10, 6))
    plt.plot(thres, recall, color='red', label='recall:TP/TPFN')
    plt.plot(thres, tn_recall, color='black', label='recall_TN:TN/TNFP')
    # Precision divides by the number of *predicted* samples of the class
    # (TP+FP for positives, TN+FN for negatives), matching the computation in
    # calculate_similarity; the legend states those formulas.
    plt.plot(thres, prec, color='blue', label='PrecisePos:TP/TPFP')
    plt.plot(thres, tn_prec, color='green', label='PreciseNeg:TN/TNFN')
    plt.legend()
    plt.xlabel('threshold')
    plt.title(title_name)
    plt.grid(True, linestyle='--', alpha=0.5)
    plt.savefig(os.sep.join([out_dir, 'yolo_pr.png']))
    plt.show()
    plt.close()
|
|
|
|
|
|
def write_results_to_file(title_name, thres, recall, recall_TN, PrecisePos, PreciseNeg):
    """Write one line of metrics per threshold to ``../pr_test/<title_name>.txt``.

    The file starts with a header line; every following line carries labelled
    ``name>>value`` fields, so each value is identified by its label rather
    than by its column position.

    Args:
        title_name: base name (without extension) of the output file.
        thres: threshold values.
        recall: positive-class recall per threshold.
        recall_TN: negative-class recall per threshold.
        PrecisePos: positive-class precision per threshold.
        PreciseNeg: negative-class precision per threshold.
    """
    out_dir = '../pr_test'
    # open() does not create missing directories, so make sure it exists.
    os.makedirs(out_dir, exist_ok=True)
    file_path = os.sep.join([out_dir, title_name + '.txt'])

    rows = ["threshold, recall, recall_TN, PrecisePos, PreciseNeg\n"]
    for thre, rec, rec_tn, prec_pos, prec_neg in zip(thres, recall, recall_TN, PrecisePos, PreciseNeg):
        # Every metric is written with four decimals; the last field used to
        # be formatted with ':4f' (width 4, six decimals) by mistake.
        rows.append(
            f"thre>>{thre:.2f}, recall>>{rec:.4f}, precise_pos>>{prec_pos:.4f}, "
            f"recall_tn>>{rec_tn:.4f}, precise_neg>>{prec_neg:.4f}\n")

    with open(file_path, 'w', encoding='utf-8') as file:
        file.writelines(rows)
|
|
|
|
|
|
def _vote_on_scores(scores, threshold, min_ratio=0.1):
    """Return 1 when more than ``min_ratio`` of ``scores`` exceed ``threshold``, else 0."""
    # A bare scalar 0 is the "no scores" marker and is always predicted 0.
    if np.ndim(scores) == 0 and scores == 0:
        return 0
    values = np.atleast_1d(np.asarray(scores))
    if len(values) == 0:
        # Nothing to vote on; also avoids dividing by zero below.
        return 0
    above = np.sum(values > threshold)
    return 1 if above / len(values) > min_ratio else 0


def show_to_pr(y_true, y_prob, type):  # binary classification
    """Sweep thresholds 0.00..1.00, then plot and save the resulting metrics.

    For every threshold the scores are turned into 0/1 predictions, the
    predictions are scored against ``y_true`` with ``calculate_similarity``,
    and the collected curves are passed to ``show_pr`` and
    ``write_results_to_file`` under the title ``'yolo_pr'``.

    Args:
        y_true: ground-truth labels, one per sample.
        y_prob: scores, interpreted according to ``type``.
        type: prediction mode (the name shadows the builtin but is kept for
            backward compatibility).
            ``0`` - ``y_prob`` holds one score per sample; a sample is
            predicted 1 unless its score is below the threshold.
            ``1`` - every entry of ``y_prob`` is either the scalar ``0`` or a
            collection of scores; the sample is predicted 1 when more than
            10% of its scores exceed the threshold.
            Any other value yields no predictions, so every metric is 0.
    """
    thres = [i * 0.01 for i in range(101)]
    title_name = 'yolo_pr'
    recall, recall_TN, PrecisePos, PreciseNeg = [], [], [], []

    for threshold in thres:
        if type == 0:
            # asarray lets plain lists be compared element-wise as well.
            y_scores_adjusted = np.where(np.asarray(y_prob) < threshold, 0, 1)
        elif type == 1:
            y_scores_adjusted = [_vote_on_scores(yp, threshold) for yp in y_prob]
        else:
            y_scores_adjusted = []

        prec, pos_recall, tn_prec, tn_recall = calculate_similarity(y_scores_adjusted, y_true, threshold)
        recall.append(pos_recall)
        recall_TN.append(tn_recall)
        PrecisePos.append(prec)
        PreciseNeg.append(tn_prec)

    show_pr(PrecisePos, recall, PreciseNeg, recall_TN, thres, title_name)
    write_results_to_file(title_name, thres, recall, recall_TN, PrecisePos, PreciseNeg)
|
|
|
|
|
|
if __name__ == '__main__':
    # Running this file directly does nothing; the guard body is a no-op.
    pass
|