# -*- coding: utf-8 -*-
"""
Plotting and threshold-sweep helpers for similarity scores of correct vs.
erroneous matches.

Created on Thu Oct 31 15:17:01 2024

@author: ym
"""
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
|
|
|
def showHist(err, correct, bins=50):
    """Plot histograms of the ``err`` and ``correct`` scores, one panel above the other.

    Both panels use the same styling: ``bins`` bins with black bin edges and
    an x-axis view limited to [0, 1].

    Args:
        err: Sequence or array of scores for the erroneous samples.
        correct: Sequence or array of scores for the correct samples.
        bins: Number of histogram bins per panel (default 50, the value that
            used to be hard-coded).

    Returns:
        The ``matplotlib.pyplot`` module. The newly created figure is the
        current figure, so callers can chain ``.show()`` / ``.savefig()``.
    """
    fig, axs = plt.subplots(2, 1)
    for ax, data, title in zip(axs, (err, correct), ('err', 'correct')):
        # NOTE: without ``range=`` the bins span the data's min..max, not
        # [0, 1]; set_xlim only crops the view, it does not re-bin.
        ax.hist(np.asarray(data), bins=bins, edgecolor='black')
        ax.set_xlim([0, 1])
        ax.set_title(title)
    # Stop the lower panel's title from overlapping the upper panel's
    # x tick labels.
    fig.tight_layout()
    return plt
|
|
|
|
def show_recall_prec(recall, prec, ths):
    """Plot recall and precision curves against the decision threshold.

    Args:
        recall: Recall value for each entry of ``ths``.
        prec: Precision value for each entry of ``ths``.
        ths: Threshold values used as the x-axis.

    Returns:
        The ``matplotlib.pyplot`` module with the new 10x6-inch figure as the
        current figure. Nothing is shown, saved, or closed here; that is
        left to the caller.
    """
    plt.figure(figsize=(10, 6))
    plt.plot(ths, recall, color='red', label='recall')
    # NOTE(review): 'PrecisePos' presumably means precision; confirm before
    # renaming, since downstream consumers may match on this label.
    plt.plot(ths, prec, color='blue', label='PrecisePos')
    plt.legend()
    plt.xlabel('threshold')
    plt.grid(True, linestyle='--', alpha=0.5)
    return plt
|
|
|
|
|
|
def compute_recall_precision(err_similarity, correct_similarity, *, ths=None):
    """Sweep a similarity threshold and compute recall and precision at each step.

    At threshold ``th`` a sample is accepted when its similarity is
    ``>= th``. Accepted ``correct_similarity`` entries are true positives
    (TP); accepted ``err_similarity`` entries are false positives (FP).

    * precision = TP / (TP + FP), or 1 when nothing is accepted;
    * recall = TP / (len(err_similarity) + len(correct_similarity)), or 0
      when nothing is accepted.

    NOTE(review): recall is normalised by *all* samples, not by
    ``len(correct_similarity)``. That fits a setup where every sample has a
    ground-truth match, so an erroneous match is also a miss. Confirm this
    is intended before comparing against textbook recall.

    Args:
        err_similarity: Similarity scores of erroneous matches.
        correct_similarity: Similarity scores of correct matches.
        ths: Optional thresholds to evaluate. Defaults to 51 evenly spaced
            values from 0 to 1 inclusive (step 0.02).

    Returns:
        ``(recall, prec, ths)``: two lists of Python numbers with one entry
        per threshold, and the thresholds as a NumPy array.
    """
    if ths is None:
        ths = np.linspace(0, 1, 51)
    else:
        ths = np.asarray(ths, dtype=np.float64)

    # Convert once so each threshold costs one vectorised comparison instead
    # of two Python-level scans. float64 keeps every comparison identical to
    # comparing the original scalars against the float64 thresholds; NaN
    # compares False and is never counted as accepted.
    err = np.asarray(err_similarity, dtype=np.float64)
    correct = np.asarray(correct_similarity, dtype=np.float64)
    total = err.size + correct.size

    recall, prec = [], []
    for th in ths:
        tp = int(np.count_nonzero(correct >= th))
        fp = int(np.count_nonzero(err >= th))
        if tp + fp == 0:
            # Nothing accepted: precision is vacuously perfect, recall is zero.
            prec.append(1)
            recall.append(0)
        else:
            prec.append(tp / (tp + fp))
            recall.append(tp / total)
    return recall, prec, ths