#!/usr/bin/env python
# coding=utf-8
"""Build retrieval descriptors for the anti-fraud image pipeline.

``ImageProcess`` gathers usable image paths from a directory tree, and
``AntiFraudFeatureDataset`` turns images into descriptors through the cirtorch
``extract_vectors`` / ``extract_vectors_o`` helpers.
"""
import sys

# Add the parent of the current working directory to the import path so the
# project packages imported below can be resolved.
sys.path.append('..')

import os

from PIL import Image
from cirtorch.networks.imageretrievalnet import extract_vectors, extract_vectors_o
from utils.config import cfg
from utils.monitor import Moniting
import cv2 as cv

# setting up the visible GPU
os.environ['CUDA_VISIBLE_DEVICES'] = "0"

# Images whose long side is at least this many times their short side are skipped.
MAX_ASPECT_RATIO = 5


class ImageProcess:
    """Collect image paths under ``img_dir`` that pass an aspect-ratio filter."""

    def __init__(self, img_dir):
        # Root directory walked recursively by ``process``.
        self.img_dir = img_dir

    def process(self, uuid_barcode):
        """Return the paths of usable images found under ``self.img_dir``.

        An image is usable when Pillow can open it, neither side is zero, and
        its long side is less than ``MAX_ASPECT_RATIO`` times its short side.
        Files that cannot be opened are reported and skipped. Images that
        fail the ratio test are skipped silently.

        Args:
            uuid_barcode: ``None`` to keep every usable image, printing a
                progress line for each one. Otherwise, only images whose full
                path contains this substring are kept.

        Returns:
            list of image file paths, in ``os.walk`` order.
        """
        imgs = []
        total = None  # progress denominator; computed once, on first use
        nu = 0
        for root, _dirs, files in os.walk(self.img_dir):
            for file in files:
                img_path = os.path.join(root, file)
                try:
                    # Image.open is lazy and keeps the file open; ``with``
                    # releases the handle once the header size has been read.
                    with Image.open(img_path) as image:
                        size = image.size
                except Exception as err:  # best effort: skip unreadable files
                    print("cannot open image {}: {}".format(img_path, err))
                    continue

                short_side = min(size)
                # Guard against zero-sized images before dividing.
                if short_side == 0 or max(size) / short_side >= MAX_ASPECT_RATIO:
                    continue

                if uuid_barcode is None:
                    imgs.append(img_path)
                    if total is None:
                        # Counts top-level entries of img_dir, as before.
                        total = len(os.listdir(self.img_dir))
                    print('\r>>>> {}/{} Train done...'.format(nu + 1, total), end='')
                    nu += 1
                elif uuid_barcode in img_path:
                    imgs.append(img_path)
        return imgs


class AntiFraudFeatureDataset:
    """Extract retrieval descriptors for images using a cirtorch network."""

    def __init__(self, uuid_barcode=None, test_img_dir=cfg.TEST_IMG_DIR):
        # Optional path substring used to filter images; None keeps all.
        self.uuid_barcode = uuid_barcode
        # Directory scanned by ``extractFeature``.
        self.TestImgDir = test_img_dir

    def extractFeature_o(self, net, image, transform, ms):
        """Return descriptors for a single in-memory image.

        Args:
            net: network forwarded unchanged to ``extract_vectors_o``.
            image: image data (a ``cv.imread`` result in ``__main__``).
            transform: preprocessing transform forwarded unchanged.
            ms: multi-scale setting forwarded as ``ms=``.

        Returns:
            list of numpy arrays, one per column of the tensor returned by
            ``extract_vectors_o``. (This assumes the result is 2-D — TODO
            confirm it is (dim, n_images).)
        """
        vecs = extract_vectors_o(net, image, cfg.RESIZE, transform, ms=ms)
        return list(vecs.detach().cpu().numpy().T)

    def extractFeature(self, net, transform, ms):
        """Return descriptors for every usable image under ``self.TestImgDir``.

        Images are selected by ``ImageProcess.process`` using
        ``self.uuid_barcode`` and then passed to ``extract_vectors``.

        Returns:
            list of numpy arrays, one per column of the tensor returned by
            ``extract_vectors``. (This presumably gives one descriptor per
            image — TODO confirm.)
        """
        # extract database and query vectors
        print('>> database images...')
        images = ImageProcess(self.TestImgDir).process(self.uuid_barcode)
        # The second return value (image paths) is not needed here.
        vecs, _img_paths = extract_vectors(
            net, images, cfg.RESIZE, transform, ms=ms)
        return list(vecs.detach().cpu().numpy().T)


if __name__ == '__main__':
    from utils.tools import createNet

    path = '../data/imgs/1.jpg'
    image = cv.imread(path)
    if image is None:
        # cv.imread signals failure by returning None instead of raising.
        raise FileNotFoundError('cannot read image: {}'.format(path))

    net, transform, ms = createNet()
    affd = AntiFraudFeatureDataset()
    feature = affd.extractFeature_o(net, image, transform, ms)
    print(len(feature))