rebuild
This commit is contained in:
112
tools/fp32comparefp16.py
Normal file
112
tools/fp32comparefp16.py
Normal file
@ -0,0 +1,112 @@
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from test_ori import group_image, init_model, featurize
|
||||
from config import config as conf
|
||||
import json
|
||||
import os.path as osp
|
||||
|
||||
def compare_fp16_fp32(values_pf16, values_pf32, dataTest):
|
||||
if dataTest:
|
||||
norm_values_pf16 = torch.norm(values_pf16, p=2)
|
||||
norm_values_pf32 = torch.norm(values_pf32, p=2)
|
||||
euclidean_distance = torch.norm(norm_values_pf16 - norm_values_pf32, p=2)
|
||||
print(f"欧几里得距离: {euclidean_distance}")
|
||||
cosine_sim = torch.dot(values_pf16.float(), values_pf32) / (norm_values_pf16 * norm_values_pf32)
|
||||
print(f"余弦相似度: {cosine_sim}")
|
||||
else:
|
||||
|
||||
pass
|
||||
def cosin_metric(x1, x2, fp32=True):
|
||||
if fp32:
|
||||
return np.dot(x1, x2) / (np.linalg.norm(x1) * np.linalg.norm(x2))
|
||||
else:
|
||||
x1_fp16 = x1.astype(np.float16)
|
||||
x2_fp16 = x2.astype(np.float16)
|
||||
# print(type(x1))
|
||||
# pdb.set_trace()
|
||||
return np.dot(x1_fp16, x2_fp16) / (np.linalg.norm(x1_fp16) * np.linalg.norm(x2_fp16))
|
||||
def deal_group_pair(pairList1, pairList2):
|
||||
one_similarity_fp16, one_similarity_fp32, allsimilarity_fp32, allsimilarity_fp16 = [], [], [], []
|
||||
for pair1 in pairList1:
|
||||
for pair2 in pairList2:
|
||||
# similarity = cosin_metric(pair1.cpu().numpy(), pair2.cpu().numpy())
|
||||
one_similarity_fp32.append(cosin_metric(pair1.cpu().numpy(), pair2.cpu().numpy(), True))
|
||||
one_similarity_fp16.append(cosin_metric(pair1.cpu().numpy(), pair2.cpu().numpy(), False))
|
||||
allsimilarity_fp32.append(one_similarity_fp32)
|
||||
allsimilarity_fp16.append(one_similarity_fp16)
|
||||
one_similarity_fp16, one_similarity_fp32 = [], []
|
||||
return np.array(allsimilarity_fp32), np.array(allsimilarity_fp16)
|
||||
|
||||
def compute_group_accuracy(content_list_read, model):
|
||||
allSimilarity, allLabel = [], []
|
||||
Same, Cross = [], []
|
||||
flag_same = True
|
||||
flag_diff = True
|
||||
for data_loaded in content_list_read:
|
||||
one_group_list = []
|
||||
try:
|
||||
if (flag_same and str(data_loaded[-1]) == '1') or (flag_diff and str(data_loaded[-1]) == '0'):
|
||||
for i in range(2):
|
||||
images = [osp.join(conf.test_val, img) for img in data_loaded[i]]
|
||||
group = group_image(images, conf.test_batch_size)
|
||||
d = featurize(group[0], conf.test_transform, model, conf.device)
|
||||
one_group_list.append(d.values())
|
||||
if str(data_loaded[-1]) == '1':
|
||||
flag_same = False
|
||||
allsimilarity_fp32, allsimilarity_fp16 = deal_group_pair(one_group_list[0], one_group_list[1])
|
||||
print('fp32 same-- >', allsimilarity_fp32)
|
||||
print('fp16 same-- >', allsimilarity_fp16)
|
||||
else:
|
||||
flag_diff = False
|
||||
allsimilarity_fp32, allsimilarity_fp16 = deal_group_pair(one_group_list[0], one_group_list[1])
|
||||
print('fp32 diff-- >', allsimilarity_fp32)
|
||||
print('fp16 diff-- >', allsimilarity_fp16)
|
||||
except Exception as e:
|
||||
continue
|
||||
# print(allSimilarity)
|
||||
# print(allLabel)
|
||||
return allSimilarity, allLabel
|
||||
def get_feature_list(imgPth):
|
||||
imgs = get_files(imgPth)
|
||||
group = group_image(imgs, conf.test_batch_size)
|
||||
model = init_model()
|
||||
model.eval()
|
||||
fe = featurize(group[0], conf.test_transform, model, conf.device)
|
||||
return fe
|
||||
|
||||
|
||||
def get_files(imgPth):
|
||||
imgsList = []
|
||||
for img in os.walk(imgPth):
|
||||
for img_name in img[2]:
|
||||
img_path = os.sep.join([img[0], img_name])
|
||||
imgsList.append(img_path)
|
||||
return imgsList
|
||||
import pdb
|
||||
|
||||
def compare(imgPth, group=False):
|
||||
model = init_model()
|
||||
model.eval()
|
||||
if not group:
|
||||
values_pf16, values_pf32 = [], []
|
||||
fe = get_feature_list(imgPth)
|
||||
# pdb.set_trace()
|
||||
values_pf32 += [value.cpu() for value in fe.values()]
|
||||
values_pf16 += [value.cpu().half() for value in fe.values()]
|
||||
for value_pf16, value_pf32 in zip(values_pf16, values_pf32):
|
||||
compare_fp16_fp32(value_pf16, value_pf32, dataTest=True)
|
||||
else:
|
||||
filename = conf.test_group_json
|
||||
with open(filename, 'r', encoding='utf-8') as file:
|
||||
content_list_read = json.load(file)
|
||||
compute_group_accuracy(content_list_read, model)
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
imgPth = './data/test/inner/3701375401900'
|
||||
compare(imgPth)
|
Reference in New Issue
Block a user