228 lines
7.8 KiB
Python
228 lines
7.8 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""
|
||
Created on Sun Nov 3 12:05:19 2024
|
||
|
||
@author: ym
|
||
"""
|
||
import os
|
||
import time
|
||
# import torch
|
||
import pickle
|
||
# import json
|
||
import numpy as np
|
||
from PIL import Image
|
||
from feat_extract.config import config as conf
|
||
# from model import resnet18 as resnet18
|
||
from feat_extract.inference import FeatsInterface #, inference_image
|
||
|
||
|
||
# Image file extensions (lower-case, with leading dot) accepted when
# collecting barcode sample images.
IMG_FORMAT = ['.bmp', '.jpg', '.jpeg', '.png']
|
||
|
||
# def model_init(conf, mpath=None):
|
||
# '''======= 0. 配置特征提取模型地址 ======='''
|
||
# if mpath is None:
|
||
# model_path = conf.test_model
|
||
# else:
|
||
# model_path = mpath
|
||
|
||
# ##============ load resnet mdoel
|
||
# model = resnet18().to(conf.device)
|
||
# # model = nn.DataParallel(model).to(conf.device)
|
||
# model.load_state_dict(torch.load(model_path, map_location=conf.device))
|
||
# model.eval()
|
||
# print('load model {} '.format(conf.testbackbone))
|
||
|
||
# return model
|
||
|
||
def get_std_barcodeDict(bcdpath, savepath, bcdSet):
    r"""Save one ``{barcode: [imgpath1, imgpath2, ...]}`` pickle per barcode.

    Inputs:
        bcdpath: root folder of cleaned barcode sample images, e.g.
            \\192.168.1.28\share\已标注数据备份\对比数据\barcode\barcode_1771
            Every sub-folder whose name is all digits and at least 8
            characters long is treated as one barcode.
        savepath: output folder; the dict for a barcode is written to
            ``<savepath>/<barcode>.pickle``. Barcodes whose pickle already
            exists are skipped, so the function can be re-run to resume.
        bcdSet: optional collection of barcodes to restrict processing to;
            ``None`` processes every barcode folder found.

    Image selection (os.walk, top-down): if a directory level contains a
    ``base`` sub-folder, only the images directly inside that ``base`` folder
    are added and the walk stops; otherwise the files at that level are added
    and the walk continues into sub-folders. File extensions are matched
    case-insensitively against IMG_FORMAT.

    Barcodes for which no image is found are not saved; they are collected in
    ``errbarcodes`` and their count is printed at the end.
    """

    def _is_image(name):
        # Compare extensions case-insensitively so e.g. "IMG_01.JPG" is kept.
        return os.path.splitext(name)[1].lower() in IMG_FORMAT

    # Barcode folders to process: all-digit names of length >= 8,
    # optionally restricted to bcdSet.
    stdBarcodeList = []
    for filename in os.listdir(bcdpath):
        filepath = os.path.join(bcdpath, filename)
        if not os.path.isdir(filepath) or not filename.isdigit() or len(filename) < 8:
            continue
        if bcdSet is None or filename in bcdSet:
            stdBarcodeList.append(filename)

    errbarcodes = []
    for barcode in stdBarcodeList:
        pickpath = os.path.join(savepath, f"{barcode}.pickle")
        if os.path.isfile(pickpath):
            continue  # already generated on a previous run

        imgpaths = []
        for root, dirs, files in os.walk(os.path.join(bcdpath, barcode)):
            if "base" in dirs:
                broot = os.path.join(root, "base")
                imgpaths.extend(os.path.join(broot, name)
                                for name in os.listdir(broot) if _is_image(name))
                break  # a 'base' folder wins: do not descend any further
            imgpaths.extend(os.path.join(root, name)
                            for name in files if _is_image(name))

        if not imgpaths:
            # An empty list gives the feature-extraction step nothing to
            # encode; record the barcode instead of writing a useless pickle.
            errbarcodes.append(barcode)
            continue

        with open(pickpath, 'wb') as f:
            pickle.dump({barcode: imgpaths}, f)

        print(f"Barcode: {barcode}")

    print(f"Len of errbarcodes: {len(errbarcodes)}")
    return


def stdfeat_infer(imgPath, featPath, bcdSet=None):
    """Extract and save features for every barcode image-list pickle.

    Inputs:
        imgPath: folder of ``<barcode>.pickle`` files, each holding a dict
            ``{barcode: [imgpath1, imgpath2, ...]}``.
        featPath: output folder; features for ``<imgPath>/<bcd>.pickle`` are
            written to ``<featPath>/<bcd>.pickle``. Existing outputs are
            skipped, so the function can be re-run to resume.
        bcdSet: optional collection of barcodes; when given, only pickles
            whose file stem is in it are processed.

    Each output pickle is a dict with keys:
        "barcode", "imgpaths",
        "feats_ft32"  - row-wise L2-normalised features as float32,
        "feats_ft16"  - the same features as float16, re-normalised,
        "feats_uint8" - float16 features scaled by 128 and stored as int8
                        (signed, despite the key name).
    The feature width is whatever Encoder.inference returns (the original
    notes mention (nsample, 256) - not checked here).

    The same barcode string is expected in 4 places: (1) the folder of raw
    barcode images; (2) the .pickle file name in imgPath; (3) the key of the
    dict inside that pickle; (4) the "barcode" value in the feature dict.

    Errors for a single file are printed and processing continues with the
    next file.
    """
    Encoder = FeatsInterface(conf)

    for filename in os.listdir(imgPath):
        bcd, ext = os.path.splitext(filename)
        if ext != ".pickle":
            continue
        if bcdSet is not None and bcd not in bcdSet:
            continue

        filepath = os.path.join(imgPath, filename)
        featpath = os.path.join(featPath, f"{bcd}.pickle")
        if os.path.isfile(featpath):
            continue

        stdbDict = {}
        t1 = time.time()
        try:
            with open(filepath, 'rb') as f:
                bpDict = pickle.load(f)

            for barcode, imgpaths in bpDict.items():
                imgs = [Image.open(imgpath) for imgpath in imgpaths]
                feature = Encoder.inference(imgs)

                # Row-wise L2 normalisation.
                feature /= np.linalg.norm(feature, axis=1)[:, None]

                feature_ft32 = feature.astype(np.float32)

                # float16: re-normalise after the precision drop.
                feature_ft16 = feature.astype(np.float16)
                feature_ft16 /= np.linalg.norm(feature_ft16, axis=1)[:, None]

                # 8-bit quantisation by scaling with 128. A component equal to
                # 1.0 would give 128, which does not fit in int8 and would wrap
                # to -128, so clip to the representable range first.
                feature_uint8 = np.clip(feature_ft16 * 128, -128, 127).astype(np.int8)

                # Save the features of this single barcode.
                stdbDict["barcode"] = barcode
                stdbDict["imgpaths"] = imgpaths
                stdbDict["feats_ft32"] = feature_ft32
                stdbDict["feats_ft16"] = feature_ft16
                stdbDict["feats_uint8"] = feature_uint8

                with open(featpath, 'wb') as f:
                    pickle.dump(stdbDict, f)

        except Exception as e:
            print(f"Error accured at: {filename}, with Exception is: {e}")

        t2 = time.time()
        # Use the file stem: `barcode` is unbound if the pickle failed to load.
        print(f"Barcode: {bcd}, need time: {t2-t1:.1f} secs")

    return


def gen_bcd_features(imgpath, bcdpath, featpath, eventSourcePath=None):
    """Generate the standard feature set.

    1. Collect sample image paths under ``imgpath`` and store one
       ``{barcode: [imgpath1, imgpath2, ...]}`` dict per barcode in
       ``bcdpath`` as ``<barcode>.pickle`` (see get_std_barcodeDict).
    2. Extract features for those dicts and save them into ``featpath``
       (see stdfeat_infer).

    Inputs:
        imgpath: root folder of barcode sample-image folders.
        bcdpath: output folder for the per-barcode image-list pickles.
        featpath: output folder for the per-barcode feature pickles.
        eventSourcePath: optional folder whose entry names end in
            ``_<barcode>`` (an all-digit suffix of at least 10 characters).
            When given, both steps are restricted to the barcodes found
            there. When ``None`` (default), every barcode is processed.
    """
    if eventSourcePath is None:
        bcdSet = None
    else:
        bcdSet = set()
        for evtname in os.listdir(eventSourcePath):
            evt = os.path.splitext(evtname)[0].split('_')
            if len(evt) >= 2 and evt[-1].isdigit() and len(evt[-1]) >= 10:
                bcdSet.add(evt[-1])

    get_std_barcodeDict(imgpath, bcdpath, bcdSet)

    stdfeat_infer(bcdpath, featpath, bcdSet)


def main():
    """Build barcode image-list pickles and their features for the abroad
    exhibition dataset, creating the output folders when they are missing.

    NOTE(review): no event-source path is passed to gen_bcd_features here,
    so that function must accept ``eventSourcePath`` as optional.
    """
    imgpath = r"\\192.168.1.28\share\数据\已完成数据\展厅数据\v2.0_abroad\比对数据\all_base_二筛"
    bcdpath = r"D:\exhibition\dataset\bcdpath_abroad"
    featpath = r"D:\exhibition\dataset\feats_abroad"

    # Make sure both output folders exist before anything is written.
    for outdir in (bcdpath, featpath):
        if not os.path.exists(outdir):
            os.makedirs(outdir)

    gen_bcd_features(imgpath, bcdpath, featpath)


if __name__ == '__main__':
    main()