Files
ieemoo-ai-conpurchase/utils/embedding.py
2023-06-25 13:55:22 +08:00

64 lines
2.4 KiB
Python

from network.createNet import initnet
import cv2, torch
import numpy as np
class DataProcessing:
    """Extract backbone embeddings from images and compare them by cosine similarity."""

    # Container types accepted for the gallery side of ``cal_cosine``. Lists,
    # tuples, ndarrays and tensors could already be iterated by the original
    # code, so they all stay accepted; anything else gets the error string.
    _FEATURE_CONTAINERS = (list, tuple, np.ndarray, torch.Tensor)

    def __init__(self, backbone, model_path, device):
        """Build *backbone*, load its weights and put it in eval mode on *device*.

        Args:
            backbone: identifier passed to ``initnet`` to construct the network.
            model_path: path to a saved ``state_dict`` for that network.
            device: anything ``torch.device`` accepts, e.g. ``'cpu'`` or ``'cuda:0'``.
        """
        model = initnet(backbone)
        # Deserialize onto the CPU first so checkpoints saved on a GPU also load
        # on CPU-only hosts; the whole model is moved to *device* just below.
        # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
        state_dict = torch.load(model_path, map_location='cpu')
        model.load_state_dict(state_dict)
        model.to(torch.device(device))
        model.eval()
        self.model = model
        self.device = device

    def cosin_metric(self, x1, x2):
        """Return the cosine similarity of two 1-D vectors.

        Returns the sentinel ``100`` (outside the [-1, 1] range of a cosine) when
        the vectors have different lengths. A zero-length-norm vector makes the
        division 0/0, which numpy reports as ``nan``.
        """
        if len(x1) != len(x2):
            return 100
        return np.dot(x1, x2) / (np.linalg.norm(x1) * np.linalg.norm(x2))

    def load_image(self, image):
        """Turn an image array into a float32 batch of one in (1, C, 256, 256) layout.

        Args:
            image: an (H, W, C) array as accepted by ``cv2.resize``, or ``None``.

        Returns:
            The prepared array, or ``None`` when *image* is ``None``. Pixel values
            are only cast to float32; no scaling or normalisation is applied here.
        """
        if image is None:
            return None
        image = cv2.resize(image, (256, 256))
        # HWC -> CHW; a 2-D (grayscale) array would fail on this transpose.
        image = image.transpose((2, 0, 1))
        image = image[np.newaxis, :, :, :]
        return image.astype(np.float32, copy=False)

    def getFeatures(self, imgs):
        """Run the model on each image and collect the outputs as numpy arrays.

        Args:
            imgs: list of image arrays (see ``load_image``).

        Returns:
            A list with one model output per successfully prepared image. Images
            that ``load_image`` rejects are reported by index and skipped, so the
            result can be shorter than *imgs*.
        """
        assert isinstance(imgs, list), 'Err input need list'
        device = torch.device(self.device)
        features = []
        # Pure inference: no autograd graph is needed, which saves memory.
        with torch.no_grad():
            for i, img in enumerate(imgs):
                image = self.load_image(img)
                if image is None:
                    # The failing image itself is None here, so report its index.
                    print('read {} error'.format(i))
                    continue
                data = torch.from_numpy(image).to(device)
                output = self.model(data)
                features.append(output.detach().cpu().numpy())
        return features

    def cal_cosine(self, t_features, m_features):
        """Compute the cosine similarity of every query feature against every gallery feature.

        Args:
            t_features: iterable of query features (e.g. the list returned by
                ``getFeatures``).
            m_features: list, tuple, ndarray or tensor of gallery features.

        Returns:
            A flat list of ``len(t_features) * len(m_features)`` values ordered
            query-major (all gallery scores for the first query, then the next),
            or the string ``'Err m_features need list or ndarray'`` when
            *m_features* is not one of the accepted container types.
        """
        if not isinstance(m_features, self._FEATURE_CONTAINERS):
            return 'Err m_features need list or ndarray'

        def _flat(feature):
            # Plain lists are used as-is; array-likes are flattened to 1-D.
            return feature if isinstance(feature, list) else feature.reshape(-1)

        return [
            self.cosin_metric(_flat(tf), _flat(mf))
            for tf in t_features
            for mf in m_features
        ]