update
This commit is contained in:
56
contrast/utils/tools.py
Normal file
56
contrast/utils/tools.py
Normal file
@ -0,0 +1,56 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Created on Thu Oct 31 15:17:01 2024
|
||||
|
||||
@author: ym
|
||||
"""
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
|
||||
def showHist(err, correct):
|
||||
err = np.array(err)
|
||||
correct = np.array(correct)
|
||||
|
||||
fig, axs = plt.subplots(2, 1)
|
||||
axs[0].hist(err, bins=50, edgecolor='black')
|
||||
axs[0].set_xlim([0, 1])
|
||||
axs[0].set_title('err')
|
||||
|
||||
axs[1].hist(correct, bins=50, edgecolor='black')
|
||||
axs[1].set_xlim([0, 1])
|
||||
axs[1].set_title('correct')
|
||||
# plt.show()
|
||||
|
||||
return plt
|
||||
|
||||
def show_recall_prec(recall, prec, ths):
|
||||
# x = np.linspace(start=-0, stop=1, num=11, endpoint=True).tolist()
|
||||
fig = plt.figure(figsize=(10, 6))
|
||||
plt.plot(ths, recall, color='red', label='recall')
|
||||
plt.plot(ths, prec, color='blue', label='PrecisePos')
|
||||
plt.legend()
|
||||
plt.xlabel(f'threshold')
|
||||
# plt.ylabel('Similarity')
|
||||
plt.grid(True, linestyle='--', alpha=0.5)
|
||||
# plt.savefig('accuracy_recall_grid.png')
|
||||
# plt.show()
|
||||
# plt.close()
|
||||
|
||||
return plt
|
||||
|
||||
|
||||
def compute_recall_precision(err_similarity, correct_similarity):
|
||||
ths = np.linspace(0, 1, 51)
|
||||
recall, prec = [], []
|
||||
for th in ths:
|
||||
TP = len([num for num in correct_similarity if num >= th])
|
||||
FP = len([num for num in err_similarity if num >= th])
|
||||
if (TP+FP) == 0:
|
||||
prec.append(1)
|
||||
recall.append(0)
|
||||
else:
|
||||
prec.append(TP / (TP + FP))
|
||||
recall.append(TP / (len(err_similarity) + len(correct_similarity)))
|
||||
return recall, prec, ths
|
Reference in New Issue
Block a user