first push
This commit is contained in:
71
utils/retrieval_feature.py
Normal file
71
utils/retrieval_feature.py
Normal file
@ -0,0 +1,71 @@
|
||||
# coding=utf-8
|
||||
# /usr/bin/env pythpn
|
||||
|
||||
import sys
|
||||
sys.path.append('..')
|
||||
import os
|
||||
from PIL import Image
|
||||
from cirtorch.networks.imageretrievalnet import extract_vectors, extract_vectors_o
|
||||
from utils.config import cfg
|
||||
from utils.monitor import Moniting
|
||||
import cv2 as cv
|
||||
# setting up the visible GPU
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
|
||||
|
||||
class ImageProcess():
|
||||
def __init__(self, img_dir):
|
||||
self.img_dir = img_dir
|
||||
|
||||
def process(self, uuid_barcode):
|
||||
imgs = list()
|
||||
nu = 0
|
||||
for root, dirs, files in os.walk(self.img_dir):
|
||||
for file in files:
|
||||
img_path = os.path.join(root + os.sep, file)
|
||||
try:
|
||||
image = Image.open(img_path)
|
||||
if max(image.size) / min(image.size) < 5:
|
||||
if uuid_barcode == None:
|
||||
imgs.append(img_path)
|
||||
print('\r>>>> {}/{} Train done...'.format((nu + 1),
|
||||
len(os.listdir(self.img_dir))),
|
||||
end='')
|
||||
nu+=1
|
||||
else:
|
||||
if uuid_barcode in img_path:
|
||||
imgs.append(img_path)
|
||||
except:
|
||||
print("image height/width ratio is small")
|
||||
return imgs
|
||||
|
||||
class AntiFraudFeatureDataset():
|
||||
def __init__(self, uuid_barcode=None, test_img_dir=cfg.TEST_IMG_DIR):#, model='work'):
|
||||
self.uuid_barcode = uuid_barcode
|
||||
self.TestImgDir = test_img_dir
|
||||
#self.model = model
|
||||
|
||||
def extractFeature_o(self, net, image, transform, ms):
|
||||
size = cfg.RESIZE
|
||||
#image = cv.resize(image, (size, size))
|
||||
vecs = extract_vectors_o(net, image, size,transform, ms=ms)
|
||||
feature_dict = list(vecs.detach().cpu().numpy().T)
|
||||
return feature_dict
|
||||
|
||||
def extractFeature(self, net, transform, ms):
|
||||
# extract database and query vectors
|
||||
print('>> database images...')
|
||||
images = ImageProcess(self.TestImgDir).process(self.uuid_barcode)
|
||||
#print('ori', images)
|
||||
vecs, img_paths = extract_vectors( net, images, cfg.RESIZE, transform, ms=ms)
|
||||
feature_dict = list(vecs.detach().cpu().numpy().T)
|
||||
return feature_dict
|
||||
|
||||
if __name__ == '__main__':
|
||||
from utils.tools import createNet
|
||||
net, transform, ms = createNet()
|
||||
path = '../data/imgs/1.jpg'
|
||||
image = cv.imread(path)
|
||||
affd = AntiFraudFeatureDataset()
|
||||
feature = affd.extractFeature_o(net, image, transform, ms)
|
||||
print(len(feature))
|
||||
|
Reference in New Issue
Block a user