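"""Evaluate barcode feature similarity.

Loads per-barcode feature vectors from a JSON search library, computes pairwise
cosine similarities within and across barcodes, and plots similarity histograms
plus precision/recall/accuracy curves over a threshold sweep.
"""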
import json
import random

import matplotlib.pyplot as plt
import numpy as np


def showHist(same, cross):
    """Plot similarity histograms for same-barcode and cross-barcode pairs."""
    Same = np.array(same)
    Cross = np.array(cross)

    fig, axs = plt.subplots(2, 1)
    axs[0].hist(Same, bins=50, edgecolor='black')
    axs[0].set_xlim([-0.1, 1])
    axs[0].set_title('Same Barcode')

    axs[1].hist(Cross, bins=50, edgecolor='black')
    axs[1].set_xlim([-0.1, 1])
    axs[1].set_title('Cross Barcode')
    # plt.savefig('plot.png')
    plt.show()


def showgrid(recall, recall_TN, PrecisePos, PreciseNeg, Correct):
    """Plot recall/precision/accuracy curves against the similarity threshold."""
    # The x axis must match the threshold sweep used in compute_accuracy_recall.
    x = np.linspace(start=-1, stop=1.0, num=50, endpoint=True).tolist()
    plt.figure(figsize=(10, 6))
    plt.plot(x, recall, color='red', label='recall: TP/(TP+FN)')
    plt.plot(x, recall_TN, color='black', label='recall_TN: TN/(TN+FP)')
    plt.plot(x, PrecisePos, color='blue', label='PrecisePos: TP/(TP+FP)')
    plt.plot(x, PreciseNeg, color='green', label='PreciseNeg: TN/(TN+FN)')
    plt.plot(x, Correct, color='m', label='Correct: (TP+TN)/(TP+FP+TN+FN)')
    plt.legend()
    plt.xlabel('threshold')
    # plt.ylabel('Similarity')
    plt.grid(True, linestyle='--', alpha=0.5)
    plt.savefig('grid.png')
    plt.show()
    plt.close()


def compute_accuracy_recall(score, labels):
    """Sweep a similarity threshold and collect recall/precision/accuracy curves."""
    sequence = np.linspace(-1, 1, num=50)
    recall, PrecisePos, PreciseNeg, recall_TN, Correct = [], [], [], [], []
    # The first half of `score` holds same-barcode pairs, the second half
    # cross-barcode pairs (see how all_score is built in deal_similarity).
    Same = score[:len(score) // 2]
    Cross = score[len(score) // 2:]
    for th in sequence:
        t_score = (score > th)
        t_labels = (labels == 1)
        TP = np.sum(np.logical_and(t_score, t_labels))
        FN = np.sum(np.logical_and(np.logical_not(t_score), t_labels))
        f_score = (score < th)
        f_labels = (labels == 0)
        TN = np.sum(np.logical_and(f_score, f_labels))
        FP = np.sum(np.logical_and(np.logical_not(f_score), f_labels))
        print("Threshold:{} TP:{},FP:{},TN:{},FN:{}".format(th, TP, FP, TN, FN))

        # Guard against empty denominators instead of comparing a float to the string 'nan'.
        PrecisePos.append(0 if (TP + FP) == 0 else TP / (TP + FP))
        PreciseNeg.append(0 if (TN + FN) == 0 else TN / (TN + FN))
        recall.append(0 if (TP + FN) == 0 else TP / (TP + FN))
        recall_TN.append(0 if (TN + FP) == 0 else TN / (TN + FP))
        Correct.append(0 if (TP + FP + TN + FN) == 0 else (TP + TN) / (TP + FP + TN + FN))

    showHist(Same, Cross)
    showgrid(recall, recall_TN, PrecisePos, PreciseNeg, Correct)


def get_similarity(features1, features2, n, m):
    """Return the mean pairwise cosine similarity between two feature sets, plus all pair scores."""
    features1 = np.array(features1)
    features2 = np.array(features2)
    all_similarity = []
    for feature1 in features1:
        for feature2 in features2:
            # Cosine similarity between one pair of feature vectors.
            similarity = np.dot(feature1, feature2) / (np.linalg.norm(feature1) * np.linalg.norm(feature2))
            all_similarity.append(similarity)
    np_all_array = np.array(all_similarity).reshape(len(features1), len(features2))
    if n == 5 and m == 5:  # debug output for one specific key pair
        print(all_similarity)
    return np.mean(np_all_array), all_similarity
    # return sum(all_similarity) / len(all_similarity), all_similarity
    # return max(all_similarity), all_similarity


def deal_similarity(dicts):
    """Compare every key's features against every other key's and evaluate the split."""
    all_similarity = []
    similarity = []
    same_barcode, diff_barcode = [], []
    for n, (key1, value1) in enumerate(dicts.items()):
        print('key1 >> {}'.format(key1))
        for m, (key2, value2) in enumerate(dicts.items()):
            print('key1 >> {} key2 >> {} pair {}{}'.format(key1, key2, n, m))
            # max_similarity is currently the mean over all pairs (see get_similarity).
            max_similarity, some_similarity = get_similarity(value1, value2, n, m)
            similarity.append(max_similarity)
            if key1 == key2:
                same_barcode += some_similarity
            else:
                diff_barcode += some_similarity
        all_similarity.append(similarity)
        similarity = []
    all_similarity = np.array(all_similarity)
    # Balance the negative set: keep as many cross-barcode scores as same-barcode ones.
    random.shuffle(diff_barcode)
    same_list = [1] * len(same_barcode)
    diff_list = [0] * len(same_barcode)
    all_list = same_list + diff_list
    all_score = same_barcode + diff_barcode[:len(same_barcode)]
    compute_accuracy_recall(np.array(all_score), np.array(all_list))
    print(all_similarity.shape)


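# Assumed structure of data_zhanting.json, inferred from the loading code below
# (the exact contents of each feature vector are an assumption):
# {
#     "total": [
#         {"key": "<barcode>", "value": [[...feature vector...], [...], ...]},
#         ...
#     ]
# }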
with open('../search_library/data_zhanting.json', 'r') as file:
    data = json.load(file)

# Build {key: feature_list} from the 'total' entries (avoid shadowing the built-in `dict`).
dicts = {}
for item in data['total']:
    key = item['key']
    value = item['value']
    dicts[key] = value

deal_similarity(dicts)
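# Hypothetical stand-alone check of compute_accuracy_recall with synthetic scores
# (not part of the original pipeline): same-barcode scores come first, then
# cross-barcode scores, with matching 1/0 labels, mirroring deal_similarity's inputs.
#   compute_accuracy_recall(np.array([0.92, 0.85, 0.30, 0.15]), np.array([1, 1, 0, 0]))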