mirror of
https://gitee.com/nanjing-yimao-information/ieemoo-ai-gift.git
synced 2025-08-18 21:30:25 +00:00
增加测试维度
This commit is contained in:
109
ultralytics/utils/show_pr.py
Normal file
109
ultralytics/utils/show_pr.py
Normal file
@ -0,0 +1,109 @@
|
||||
import os.path
|
||||
import os
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
|
||||
# [tag, bandage, null]
|
||||
# [0, 1, 2]
|
||||
class ShowPR:
|
||||
def __init__(self, tags, prec_value, title_name=None):
|
||||
self.tags = tags
|
||||
self.prec_value = prec_value
|
||||
self.thres = [i * 0.01 for i in range(101)]
|
||||
self.title_name = title_name
|
||||
|
||||
def change_precValue(self, thre=0.5):
|
||||
values = []
|
||||
for i in range(len(self.prec_value)):
|
||||
value = []
|
||||
for j in range(len(self.prec_value[i])):
|
||||
if self.prec_value[i][j] > thre:
|
||||
value.append(1)
|
||||
else:
|
||||
value.append(0)
|
||||
values.append(value)
|
||||
return values
|
||||
|
||||
def _calculate_pr(self, prec_value):
|
||||
FN, FP, TN, TP = 0, 0, 0, 0
|
||||
for output, target in zip(prec_value, self.tags):
|
||||
# print("output >> {} , target >> {}".format(output, target))
|
||||
if output != target:
|
||||
if target == 0:
|
||||
FP += 1
|
||||
elif target == 1:
|
||||
FN += 1
|
||||
else:
|
||||
if target == 0:
|
||||
TN += 1
|
||||
elif target == 1:
|
||||
TP += 1
|
||||
if TP == 0:
|
||||
prec, recall = 0, 0
|
||||
else:
|
||||
prec = TP / (TP + FP)
|
||||
recall = TP / (TP + FN)
|
||||
if TN == 0:
|
||||
tn_prec, tn_recall = 0, 0
|
||||
else:
|
||||
tn_prec = TN / (TN + FN)
|
||||
tn_recall = TN / (TN + FP)
|
||||
# print("TP>>{}, FP>>{}, TN>>{}, FN>>{}".format(TP, FP, TN, FN))
|
||||
return prec, recall, tn_prec, tn_recall
|
||||
|
||||
def calculate_multiple(self):
|
||||
recall, recall_TN, PrecisePos, PreciseNeg = [], [], [], []
|
||||
for thre in self.thres:
|
||||
# prec_value = []
|
||||
# if self.prec_value >= thre:
|
||||
# prec_value.append(1)
|
||||
# else:
|
||||
# prec_value.append(0)
|
||||
prec_value = [1 if num >= thre else 0 for num in self.prec_value]
|
||||
prec, recall_pos, tn_prec, tn_recall = self._calculate_pr(prec_value)
|
||||
print(
|
||||
f"thre>>{thre:.2f}, recall>>{recall_pos:.4f}, precise_pos>>{prec:.4f}, recall_tn>>{tn_recall:.4f}, precise_neg>>{tn_prec:4f}")
|
||||
PrecisePos.append(prec)
|
||||
recall.append(recall_pos)
|
||||
PreciseNeg.append(tn_prec)
|
||||
recall_TN.append(tn_recall)
|
||||
return recall, recall_TN, PrecisePos, PreciseNeg
|
||||
|
||||
def write_results_to_file(self, recall, recall_TN, PrecisePos, PreciseNeg):
|
||||
file_path = os.sep.join(['./ckpts/tracePR', self.title_name + '.txt'])
|
||||
with open(file_path, 'w') as file:
|
||||
file.write("threshold, recall, recall_TN, PrecisePos, PreciseNeg\n")
|
||||
for thre, rec, rec_tn, prec_pos, prec_neg in zip(self.thres, recall, recall_TN, PrecisePos, PreciseNeg):
|
||||
file.write(
|
||||
f"thre>>{thre:.2f}, recall>>{rec:.4f}, precise_pos>>{prec_pos:.4f}, recall_tn>>{rec_tn:.4f}, precise_neg>>{prec_neg:4f}\n")
|
||||
|
||||
def show_pr(self, recall, recall_TN, PrecisePos, PreciseNeg):
|
||||
# self.calculate_multiple()
|
||||
x = self.thres
|
||||
plt.figure(figsize=(10, 6))
|
||||
plt.plot(x, recall, color='red', label='recall:TP/TPFN')
|
||||
plt.plot(x, recall_TN, color='black', label='recall_TN:TN/TNFP')
|
||||
plt.plot(x, PrecisePos, color='blue', label='PrecisePos:TP/TPFN')
|
||||
plt.plot(x, PreciseNeg, color='green', label='PreciseNeg:TN/TNFP')
|
||||
plt.legend()
|
||||
plt.xlabel('threshold')
|
||||
# if self.title_name is not None:
|
||||
# plt.title(f"PrecisePos & Recall ratio:{ratio:.2f}", fontdict={'fontsize': 12, 'fontweight': 'black'})
|
||||
# plt.grid(True, linestyle='--', alpha=0.5)
|
||||
# 启用次刻度
|
||||
# plt.minorticks_on()
|
||||
|
||||
# 设置主刻度的网格线
|
||||
plt.grid(which='major', linestyle='-', alpha=0.5, color='gray')
|
||||
|
||||
# 设置次刻度的网格线
|
||||
plt.grid(which='minor', linestyle=':', alpha=0.3, color='gray')
|
||||
plt.savefig(os.sep.join(['./ckpts/tracePR', self.title_name + '.png']))
|
||||
plt.show()
|
||||
plt.close()
|
||||
self.write_results_to_file(recall, recall_TN, PrecisePos, PreciseNeg)
|
||||
|
||||
def get_pic(self):
|
||||
recall, recall_TN, PrecisePos, PreciseNeg = self.calculate_multiple()
|
||||
self.show_pr(recall, recall_TN, PrecisePos, PreciseNeg)
|
@ -110,7 +110,7 @@ class ShowPR:
|
||||
def get_pic(self):
|
||||
for ratio in self.ratios:
|
||||
# ratio = 0.5
|
||||
if ratio < 0.2 or ratio > 0.95:
|
||||
if ratio < 0.1 or ratio > 0.95:
|
||||
continue
|
||||
recall, recall_TN, PrecisePos, PreciseNeg = self.calculate_multiple(ratio)
|
||||
self.show_pr(recall, recall_TN, PrecisePos, PreciseNeg, ratio)
|
||||
|
Reference in New Issue
Block a user