Files
ieemoo-ai-contrast/tools/json_contrast.py
2025-06-11 15:23:50 +08:00

119 lines
4.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import numpy as np
import matplotlib.pyplot as plt
import numpy as np
import random
def showHist(same, cross):
    """Show two stacked 50-bin histograms: `same` on top, `cross` below.

    Args:
        same: Scores drawn in the upper panel, titled 'Same Barcode'.
        cross: Scores drawn in the lower panel, titled 'Cross Barcode'.

    Both panels share the x-range [-0.1, 1]. The figure is displayed with
    ``plt.show()`` and is not written to disk.
    """
    _, axes = plt.subplots(2, 1)
    panels = (
        (axes[0], np.array(same), 'Same Barcode'),
        (axes[1], np.array(cross), 'Cross Barcode'),
    )
    for axis, values, title in panels:
        axis.hist(values, bins=50, edgecolor='black')
        axis.set_xlim([-0.1, 1])
        axis.set_title(title)
    plt.show()
def showgrid(recall, recall_TN, PrecisePos, PreciseNeg, Correct, thresholds=None):
    """Plot five per-threshold metric curves, save them to 'grid.png' and show them.

    Args:
        recall: TP / (TP + FN) for each threshold.
        recall_TN: TN / (TN + FP) for each threshold.
        PrecisePos: TP / (TP + FP) for each threshold.
        PreciseNeg: TN / (TN + FN) for each threshold.
        Correct: (TP + TN) / (TP + FP + TN + FN) for each threshold.
        thresholds: Optional x-axis values, one per metric sample. When
            omitted, the samples are spread evenly over [0, 1], which is
            what this function always did before the parameter existed.
    """
    if thresholds is None:
        x = np.linspace(start=0, stop=1.0, num=len(recall), endpoint=True).tolist()
    else:
        # Use the thresholds the metrics were really evaluated at, so a point
        # on a curve can be read off against the correct threshold.
        x = list(thresholds)
    plt.figure(figsize=(10, 6))
    plt.plot(x, recall, color='red', label='recall:TP/TPFN')
    plt.plot(x, recall_TN, color='black', label='recall_TN:TN/TNFP')
    # Precision divides by the number of *predicted* samples of that class,
    # so the legends name TP+FP and TN+FN rather than the recall denominators.
    plt.plot(x, PrecisePos, color='blue', label='PrecisePos:TP/TPFP')
    plt.plot(x, PreciseNeg, color='green', label='PreciseNeg:TN/TNFN')
    plt.plot(x, Correct, color='m', label='Correct(TN+TP)/(TPFN+TNFP)')
    plt.legend()
    plt.xlabel('threshold')
    plt.grid(True, linestyle='--', alpha=0.5)
    plt.savefig('grid.png')
    plt.show()
    plt.close()


def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or 0 when the denominator is zero."""
    return 0 if denominator == 0 else numerator / denominator


def compute_accuracy_recall(score, labels):
    """Sweep 50 thresholds over [-1, 1], print the confusion counts and plot the metrics.

    A sample is predicted positive when ``score > threshold`` and negative
    otherwise, so every sample falls into exactly one of TP / FP / TN / FN.

    Args:
        score: 1-D array-like of scores.
        labels: 1-D array-like of the same length; 1 marks a positive
            sample and 0 a negative one.

    Returns:
        None. For every threshold one line with TP/FP/TN/FN is printed, then
        the score histograms (``showHist``) and the metric curves
        (``showgrid``) are displayed.
    """
    score = np.asarray(score)
    labels = np.asarray(labels)
    squence = np.linspace(-1, 1, num=50)
    recall, PrecisePos, PreciseNeg, recall_TN, Correct = [], [], [], [], []

    is_positive = labels == 1
    is_negative = labels == 0
    # Split the histogram data by label instead of by position, so the plot
    # stays right even if the two classes are not arranged as two equal halves.
    Same = score[is_positive]
    Cross = score[is_negative]

    for th in squence:
        predicted_pos = score > th
        predicted_neg = np.logical_not(predicted_pos)
        TP = np.sum(np.logical_and(predicted_pos, is_positive))
        FN = np.sum(np.logical_and(predicted_neg, is_positive))
        TN = np.sum(np.logical_and(predicted_neg, is_negative))
        FP = np.sum(np.logical_and(predicted_pos, is_negative))
        print("Threshold:{} TP:{},FP:{},TN:{},FN:{}".format(th, TP, FP, TN, FN))
        # Each ratio falls back to 0 when its denominator is empty instead of
        # producing NaN from a 0/0 division.
        PrecisePos.append(_safe_ratio(TP, TP + FP))
        PreciseNeg.append(_safe_ratio(TN, TN + FN))
        recall.append(_safe_ratio(TP, TP + FN))
        recall_TN.append(_safe_ratio(TN, TN + FP))
        # Accuracy depends on TP and TN together; it is only undefined when
        # there are no samples at all.
        Correct.append(_safe_ratio(TP + TN, TP + FP + TN + FN))

    showHist(Same, Cross)
    showgrid(recall, recall_TN, PrecisePos, PreciseNeg, Correct, thresholds=squence)
def get_similarity(features1, features2, n, m):
    """Compute the cosine similarity of every vector in `features1` with every vector in `features2`.

    Args:
        features1: Sequence of feature vectors (one vector per entry).
        features2: Sequence of feature vectors with the same dimensionality.
        n: Index of the first group; only used to trigger the debug print.
        m: Index of the second group; only used to trigger the debug print.

    Returns:
        A ``(mean_similarity, all_similarity)`` tuple. ``all_similarity`` is
        a flat list ordered with `features1` as the outer loop and
        `features2` as the inner loop; ``mean_similarity`` is the mean over
        all of those pairs (NaN when either input is empty). A pair that
        involves a zero-length vector has similarity 0.0 rather than NaN.
    """
    feats_a = np.asarray(features1, dtype=float)
    feats_b = np.asarray(features2, dtype=float)

    if len(feats_a) == 0 or len(feats_b) == 0:
        # No pairs to compare: nothing to average.
        all_similarity = []
        mean_similarity = np.nan
    else:
        # A flat list of scalars is treated as a list of 1-dimensional vectors.
        if feats_a.ndim == 1:
            feats_a = feats_a.reshape(-1, 1)
        if feats_b.ndim == 1:
            feats_b = feats_b.reshape(-1, 1)

        # One matrix product yields every pairwise dot product; the norms are
        # computed once per vector instead of once per pair.
        dots = feats_a @ feats_b.T
        norm_products = np.outer(
            np.linalg.norm(feats_a, axis=1),
            np.linalg.norm(feats_b, axis=1),
        )
        # Where a norm is zero the cosine is undefined; leave 0.0 there so a
        # single degenerate vector cannot turn the mean into NaN.
        sim_matrix = np.divide(
            dots,
            norm_products,
            out=np.zeros_like(dots),
            where=norm_products > 0,
        )
        all_similarity = list(sim_matrix.ravel())
        mean_similarity = np.mean(sim_matrix)

    # Debug aid kept from the original: dump the raw scores of group pair (5, 5).
    if n == 5 and m == 5:
        print(all_similarity)
    return mean_similarity, all_similarity
def deal_similarity(dicts):
    """Compare every group in `dicts` with every group (itself included) and evaluate the scores.

    Args:
        dicts: Mapping from a key (presumably a barcode, judging by the plot
            titles; confirm against the data file) to the list of feature
            vectors stored for that key.

    Side effects:
        Prints progress for every pair of keys, calls
        ``compute_accuracy_recall`` on a label-balanced sample of pairwise
        scores, and finally prints the shape of the key-by-key matrix of
        mean similarities.
    """
    all_similarity = []
    same_barcode, diff_barcode = [], []
    for n, (key1, value1) in enumerate(dicts.items()):
        print('key1 >> {}'.format(key1))
        row = []
        for m, (key2, value2) in enumerate(dicts.items()):
            print('key1 >> {} key2 >> {} peidui {}{}'.format(key1, key2, n, m))
            mean_similarity, pair_similarities = get_similarity(value1, value2, n, m)
            row.append(mean_similarity)
            # Scores from a key compared with itself are positives; everything
            # else is a negative.
            if key1 == key2:
                same_barcode.extend(pair_similarities)
            else:
                diff_barcode.extend(pair_similarities)
        all_similarity.append(row)
    all_similarity = np.array(all_similarity)

    # Draw a random subset of negatives no larger than the positive set.
    random.shuffle(diff_barcode)
    sampled_diff = diff_barcode[:len(same_barcode)]
    # Size the negative labels by what was actually sampled: there may be
    # fewer cross-key scores than same-key scores, and scores and labels must
    # stay the same length.
    all_list = [1] * len(same_barcode) + [0] * len(sampled_diff)
    all_score = same_barcode + sampled_diff
    compute_accuracy_recall(np.array(all_score), np.array(all_list))
    print(all_similarity.shape)
# Load the feature library and run the similarity evaluation on it.
# JSON text is UTF-8, so the encoding is stated explicitly instead of relying
# on the platform default.
with open('../search_library/data_zhanting.json', 'r', encoding='utf-8') as file:
    data = json.load(file)

# Re-key the records under 'total' as {key: value}. The loop variable is not
# called `dict`, so the builtin stays usable in this module.
dicts = {}
for entry in data['total']:
    dicts[entry['key']] = entry['value']

deal_similarity(dicts)