72 lines
2.5 KiB
Python
72 lines
2.5 KiB
Python
# coding=utf-8
|
|
# /usr/bin/env python
|
|
|
|
import sys
|
|
sys.path.append('..')
|
|
import os
|
|
from PIL import Image
|
|
from cirtorch.networks.imageretrievalnet import extract_vectors, extract_vectors_o
|
|
from utils.config import cfg
|
|
from utils.monitor import Moniting
|
|
import cv2 as cv
|
|
# Default the visible GPU to device 0, but respect a CUDA_VISIBLE_DEVICES value
# the caller already exported (e.g. `CUDA_VISIBLE_DEVICES=1 python ...`).
# CUDA reads this variable when it initialises, so it must be set before any
# GPU work is done.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', "0")
|
|
|
|
class ImageProcess():
    """Collect image file paths under a directory tree.

    Files are walked recursively. An image is kept only if it can be opened
    and its long-side/short-side ratio is below 5. Optionally, images can also
    be filtered by a substring of their path.
    """

    def __init__(self, img_dir):
        """
        Args:
            img_dir: root directory that ``process`` walks recursively.
        """
        self.img_dir = img_dir

    def process(self, uuid_barcode):
        """Return the paths of readable, not-too-elongated images.

        Args:
            uuid_barcode: if ``None``, every qualifying image is kept and a
                progress line is printed. Otherwise only images whose full
                path contains this substring are kept.

        Returns:
            List of image paths in ``os.walk`` order. Files that cannot be
            opened as images are reported and skipped.
        """
        # Materialise the file list once so the progress total matches what is
        # actually scanned (the walk is recursive, not just the top level).
        file_paths = [os.path.join(root, name)
                      for root, _dirs, files in os.walk(self.img_dir)
                      for name in files]
        total = len(file_paths)

        imgs = []
        nu = 0
        for img_path in file_paths:
            try:
                # Image.open is lazy (reads the header only). The context
                # manager closes the file handle as soon as we have the size.
                with Image.open(img_path) as image:
                    width, height = image.size
                ratio = max(width, height) / min(width, height)
            except Exception as exc:
                # Best-effort scan: skip non-image, truncated or zero-sized
                # files instead of aborting the whole directory.
                print("skipping {}: {}".format(img_path, exc))
                continue

            # Reject extremely elongated images.
            if ratio >= 5:
                continue

            if uuid_barcode is None:
                imgs.append(img_path)
                print('\r>>>> {}/{} Train done...'.format((nu + 1), total),
                      end='')
                nu += 1
            elif uuid_barcode in img_path:
                imgs.append(img_path)
        return imgs
|
|
|
|
class AntiFraudFeatureDataset():
    """Extract retrieval feature vectors for one image or for a directory of images."""

    def __init__(self, uuid_barcode=None, test_img_dir=None):
        """
        Args:
            uuid_barcode: optional path substring used by ``extractFeature``
                to select images. ``None`` keeps every qualifying image.
            test_img_dir: directory scanned by ``extractFeature``. ``None``
                (the default) means ``cfg.TEST_IMG_DIR``.
        """
        self.uuid_barcode = uuid_barcode
        # Resolve the default when the object is created rather than when the
        # module is imported, so later changes to cfg are honoured.
        self.TestImgDir = cfg.TEST_IMG_DIR if test_img_dir is None else test_img_dir

    @staticmethod
    def _to_feature_list(vecs):
        """Convert a vector tensor into a list of numpy arrays, one per column.

        NOTE(review): presumably ``vecs`` is (dim, n_images), so this yields
        one feature vector per image. Confirm against the extract_vectors
        helpers.
        """
        return list(vecs.detach().cpu().numpy().T)

    def extractFeature_o(self, net, image, transform, ms):
        """Extract features for a single in-memory image.

        Args:
            net: network passed through to ``extract_vectors_o``.
            image: image array as supplied by the caller.
            transform: preprocessing transform passed through.
            ms: multi-scale setting passed through as ``ms``.

        Returns:
            List of numpy arrays (see ``_to_feature_list``).
        """
        vecs = extract_vectors_o(net, image, cfg.RESIZE, transform, ms=ms)
        return self._to_feature_list(vecs)

    def extractFeature(self, net, transform, ms):
        """Extract features for every selected image under ``self.TestImgDir``.

        Images are chosen by ``ImageProcess.process`` using ``self.uuid_barcode``.

        Returns:
            List of numpy arrays (see ``_to_feature_list``).
        """
        print('>> database images...')
        images = ImageProcess(self.TestImgDir).process(self.uuid_barcode)
        # The second return value (image paths) is not needed here.
        vecs, _ = extract_vectors(net, images, cfg.RESIZE, transform, ms=ms)
        return self._to_feature_list(vecs)
|
|
|
|
if __name__ == '__main__':
    # Demo: extract features for one image and print how many vectors came back.
    # Usage: python <this script> [image_path]
    from utils.tools import createNet

    path = sys.argv[1] if len(sys.argv) > 1 else '../data/imgs/1.jpg'

    # cv.imread returns None (it does not raise) for missing or unreadable
    # files, so check before doing the expensive network setup.
    image = cv.imread(path)
    if image is None:
        sys.exit('could not read image: {}'.format(path))

    # NOTE(review): cv.imread yields BGR channel order. Confirm that
    # extract_vectors_o / transform expect BGR rather than RGB.
    net, transform, ms = createNet()

    affd = AntiFraudFeatureDataset()
    feature = affd.extractFeature_o(net, image, transform, ms)

    print(len(feature))
|
|
|