113 lines
4.4 KiB
Python
113 lines
4.4 KiB
Python
import os
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from test_ori import group_image, init_model, featurize
|
|
from config import config as conf
|
|
import json
|
|
import os.path as osp
|
|
|
|
def compare_fp16_fp32(values_pf16, values_pf32, dataTest):
    """Print the Euclidean distance and cosine similarity of two feature vectors.

    Args:
        values_pf16: 1-D tensor, the half-precision copy of a feature
            (``torch.dot`` below only accepts 1-D inputs).
        values_pf32: 1-D tensor, the single-precision feature.
        dataTest: When falsy, nothing is computed or printed.

    Returns:
        None. Both metrics are written to stdout.
    """
    if not dataTest:
        return

    # Do the arithmetic in float32: the fp16 values keep their rounding error,
    # but the reductions do not depend on half-precision kernel support.
    vec16 = values_pf16.float()
    vec32 = values_pf32.float()

    # Euclidean distance is the norm of the element-wise difference.
    # Subtracting the two norms instead would report 0 for any pair of
    # vectors that merely have the same length.
    euclidean_distance = torch.norm(vec16 - vec32, p=2)
    print(f"欧几里得距离: {euclidean_distance}")

    norm_product = torch.norm(vec16, p=2) * torch.norm(vec32, p=2)
    cosine_sim = torch.dot(vec16, vec32) / norm_product
    print(f"余弦相似度: {cosine_sim}")
|
|
def cosin_metric(x1, x2, fp32=True):
    """Return the cosine similarity of two NumPy vectors.

    Args:
        x1: First vector.
        x2: Second vector.
        fp32: When True the inputs are used as given; when False both are
            cast to ``np.float16`` before any arithmetic.

    Returns:
        ``dot(x1, x2) / (norm(x1) * norm(x2))``.
    """
    if not fp32:
        # Down-cast both operands so the whole computation runs on float16 data.
        x1, x2 = x1.astype(np.float16), x2.astype(np.float16)

    numerator = np.dot(x1, x2)
    denominator = np.linalg.norm(x1) * np.linalg.norm(x2)
    return numerator / denominator
|
|
def deal_group_pair(pairList1, pairList2):
    """Build pairwise cosine-similarity matrices for two groups of features.

    Args:
        pairList1: Iterable of tensors; each one produces a row of the result.
        pairList2: Iterable of tensors; each one produces a column of the result.

    Returns:
        Tuple ``(fp32_matrix, fp16_matrix)`` of NumPy arrays. Entry ``[i][j]``
        is ``cosin_metric`` of item ``i`` of ``pairList1`` and item ``j`` of
        ``pairList2``, called with ``fp32=True`` and ``fp32=False`` respectively.
    """
    # Convert every tensor exactly once. Doing it inside the nested loop would
    # repeat the device-to-host copy for every pair, and materialising the
    # second group also lets one-shot iterables be reused for each row.
    rows = [feature.cpu().numpy() for feature in pairList1]
    cols = [feature.cpu().numpy() for feature in pairList2]

    allsimilarity_fp32 = [[cosin_metric(row, col, True) for col in cols] for row in rows]
    allsimilarity_fp16 = [[cosin_metric(row, col, False) for col in cols] for row in rows]
    return np.array(allsimilarity_fp32), np.array(allsimilarity_fp16)
|
|
|
|
def compute_group_accuracy(content_list_read, model):
    """Print fp32/fp16 similarity matrices for one "same" and one "diff" pair.

    Each entry of ``content_list_read`` is read as: ``entry[0]`` and
    ``entry[1]`` are lists of image names (joined onto ``conf.test_val``) and
    ``entry[-1]`` is a label whose string form is ``'1'`` (printed as "same")
    or ``'0'`` (printed as "diff"). Only the first entry of each label that
    featurizes successfully is processed; entries with any other label are
    ignored.

    Args:
        content_list_read: Iterable of entries as described above.
        model: Model object forwarded unchanged to ``featurize``.

    Returns:
        ``(allSimilarity, allLabel)``. Nothing in this function fills either
        list, so both are always empty; the result of interest is what gets
        printed.
    """
    allSimilarity, allLabel = [], []
    # Labels still waiting for their single sample, mapped to the printed tag.
    pending = {'1': 'same', '0': 'diff'}

    for data_loaded in content_list_read:
        if not pending:
            # Both samples have been handled; further entries would be no-ops.
            break
        try:
            label = str(data_loaded[-1])
            if label not in pending:
                continue

            one_group_list = []
            for i in range(2):
                images = [osp.join(conf.test_val, img) for img in data_loaded[i]]
                group = group_image(images, conf.test_batch_size)
                # NOTE(review): only the first batch is featurized; images past
                # conf.test_batch_size are not compared -- confirm intended.
                d = featurize(group[0], conf.test_transform, model, conf.device)
                one_group_list.append(d.values())

            # The label counts as handled once both groups are featurized,
            # even if the similarity step below fails.
            tag = pending.pop(label)
            allsimilarity_fp32, allsimilarity_fp16 = deal_group_pair(one_group_list[0], one_group_list[1])
            print(f'fp32 {tag}-- >', allsimilarity_fp32)
            print(f'fp16 {tag}-- >', allsimilarity_fp16)
        except Exception as e:
            # Best effort: one bad entry must not abort the run, but report it
            # so that missing files or malformed entries are visible.
            print(f'skip group entry {data_loaded!r}: {type(e).__name__}: {e}')
            continue

    return allSimilarity, allLabel
|
|
def get_feature_list(imgPth):
    """Featurize the first batch of images found under ``imgPth``.

    Files are collected with ``get_files``, batched with ``group_image`` using
    ``conf.test_batch_size``, and only the first batch is passed to
    ``featurize`` with a freshly initialised model in eval mode.

    Args:
        imgPth: Directory that is searched recursively for files.

    Returns:
        Whatever ``featurize`` returns for that first batch (callers in this
        file call ``.values()`` on it).
    """
    batches = group_image(get_files(imgPth), conf.test_batch_size)
    net = init_model()
    net.eval()
    return featurize(batches[0], conf.test_transform, net, conf.device)
|
|
|
|
|
|
def get_files(imgPth):
    """Recursively list every file below ``imgPth``.

    Args:
        imgPth: Root directory handed to ``os.walk``.

    Returns:
        List of paths, each built as ``<directory><os.sep><file name>``, in
        the order ``os.walk`` yields them. Empty when nothing is found.
    """
    return [
        os.sep.join([dirpath, filename])
        for dirpath, _dirnames, filenames in os.walk(imgPth)
        for filename in filenames
    ]
|
|
import pdb
|
|
|
|
def compare(imgPth, group=False):
    """Run the fp16-vs-fp32 comparison, per image or per group.

    Args:
        imgPth: Directory of images used when ``group`` is False.
        group: When False, every feature returned by ``get_feature_list`` is
            compared against its own half-precision copy. When True, the
            entries of the JSON file at ``conf.test_group_json`` are passed to
            ``compute_group_accuracy``.

    Returns:
        None. Results are printed by the functions this one calls.
    """
    if group:
        # Only this path uses a model directly; the per-image path gets its
        # own from get_feature_list, so building one up front would load the
        # weights twice.
        model = init_model()
        model.eval()
        with open(conf.test_group_json, 'r', encoding='utf-8') as file:
            content_list_read = json.load(file)
        compute_group_accuracy(content_list_read, model)
        return

    fe = get_feature_list(imgPth)
    for value in fe.values():
        value_pf32 = value.cpu()
        # The fp16 operand is the same feature rounded to half precision.
        compare_fp16_fp32(value_pf32.half(), value_pf32, dataTest=True)
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: per-image fp16/fp32 comparison on a sample folder.
    imgPth = './data/test/inner/3701375401900'
    compare(imgPth, group=False)
|