Files
ieemoo-ai-searchv2/utils/retrieval_feature.py
2022-11-22 15:32:06 +08:00

72 lines
2.5 KiB
Python

# coding=utf-8
# /usr/bin/env python
import sys
sys.path.append('..')
import os
from PIL import Image
from cirtorch.networks.imageretrievalnet import extract_vectors, extract_vectors_o
from utils.config import cfg
from utils.monitor import Moniting
import cv2 as cv
# Expose only GPU "0" to CUDA for this process. This runs at import time and
# overwrites any CUDA_VISIBLE_DEVICES value already present in the environment.
os.environ.update(CUDA_VISIBLE_DEVICES="0")
class ImageProcess():
    """Collect usable image paths from a directory tree.

    An image is kept when PIL can read its size and its longer side is less
    than 5x its shorter side. Files that PIL cannot open are reported and
    skipped rather than aborting the scan.
    """

    def __init__(self, img_dir):
        # Root folder that process() walks recursively with os.walk.
        self.img_dir = img_dir

    def process(self, uuid_barcode):
        """Return the paths of usable images found under ``self.img_dir``.

        Args:
            uuid_barcode: when None, every usable image is returned and a
                progress line is printed; otherwise only files whose path
                contains this substring are considered.

        Returns:
            list[str]: matching image paths, in ``os.walk`` order. A missing
            directory yields an empty list.
        """
        imgs = []
        nu = 0
        total = None  # progress denominator; computed once, on first use
        for root, _dirs, files in os.walk(self.img_dir):
            for file in files:
                img_path = os.path.join(root, file)
                # Cheap substring filter first, so unrelated files are never opened.
                if uuid_barcode is not None and uuid_barcode not in img_path:
                    continue
                try:
                    # The context manager releases the file handle that
                    # Image.open keeps open after its lazy header read.
                    with Image.open(img_path) as image:
                        width, height = image.size
                except Exception as exc:  # best-effort scan: skip unreadable files
                    print("skipping unreadable image {}: {}".format(img_path, exc))
                    continue
                short_side, long_side = min(width, height), max(width, height)
                # Drop degenerate (zero-sized) and overly elongated images.
                if short_side == 0 or long_side / short_side >= 5:
                    continue
                if uuid_barcode is None:
                    if total is None:
                        # NOTE(review): counts only the top-level entries of
                        # img_dir while the walk is recursive, so the denominator
                        # can be off for nested folders.
                        total = len(os.listdir(self.img_dir))
                    nu += 1
                    print('\r>>>> {}/{} Train done...'.format(nu, total), end='')
                imgs.append(img_path)
        return imgs
class AntiFraudFeatureDataset():
    """Turn images into retrieval feature vectors with a caller-supplied network.

    ``extractFeature_o`` embeds one already-loaded image; ``extractFeature``
    embeds every image that ``ImageProcess`` selects from ``TestImgDir``.
    """

    def __init__(self, uuid_barcode=None, test_img_dir=cfg.TEST_IMG_DIR):
        # Substring filter forwarded to ImageProcess.process (None = no filter).
        self.uuid_barcode = uuid_barcode
        # Folder scanned by extractFeature. The default comes from cfg when this
        # class is defined, not on each instantiation.
        self.TestImgDir = test_img_dir

    @staticmethod
    def _vectors_to_list(vecs):
        """Detach ``vecs``, copy it to a host NumPy array, transpose, and split into rows.

        Presumably ``vecs`` is laid out (dim, num_images), so each list entry
        is one image's vector -- TODO confirm against extract_vectors.
        """
        host = vecs.detach().cpu().numpy()
        return list(host.T)

    def extractFeature_o(self, net, image, transform, ms):
        """Embed a single in-memory image and return its feature rows as a list."""
        vectors = extract_vectors_o(net, image, cfg.RESIZE, transform, ms=ms)
        return self._vectors_to_list(vectors)

    def extractFeature(self, net, transform, ms):
        """Embed every selected image under ``TestImgDir`` and return the rows as a list."""
        print('>> database images...')
        selected = ImageProcess(self.TestImgDir).process(self.uuid_barcode)
        vectors, _paths = extract_vectors(net, selected, cfg.RESIZE, transform, ms=ms)
        return self._vectors_to_list(vectors)
if __name__ == '__main__':
    # Smoke test: embed one image and print how many feature entries come back.
    from utils.tools import createNet

    # An optional first CLI argument overrides the bundled sample image.
    path = sys.argv[1] if len(sys.argv) > 1 else '../data/imgs/1.jpg'
    image = cv.imread(path)
    # cv.imread returns None (it does not raise) when the file is missing or
    # unreadable; fail fast, before the network is built.
    # NOTE(review): cv.imread yields BGR channel order -- confirm that
    # extract_vectors_o expects that.
    if image is None:
        sys.exit('could not read image: {}'.format(path))
    net, transform, ms = createNet()
    affd = AntiFraudFeatureDataset()
    feature = affd.extractFeature_o(net, image, transform, ms)
    print(len(feature))