Files
2025-04-18 14:41:53 +08:00

228 lines
7.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
Created on Sun Nov 3 12:05:19 2024
@author: ym
"""
import os
import time
# import torch
import pickle
# import json
import numpy as np
from PIL import Image
from feat_extract.config import config as conf
# from model import resnet18 as resnet18
from feat_extract.inference import FeatsInterface #, inference_image
IMG_FORMAT = ['.bmp', '.jpg', '.jpeg', '.png']
# def model_init(conf, mpath=None):
#     '''======= 0. Configure the path of the feature-extraction model ======='''
#     if mpath is None:
#         model_path = conf.test_model
#     else:
#         model_path = mpath
#     ##============ load resnet model
#     model = resnet18().to(conf.device)
#     # model = nn.DataParallel(model).to(conf.device)
#     model.load_state_dict(torch.load(model_path, map_location=conf.device))
#     model.eval()
#     print('load model {} '.format(conf.testbackbone))
#     return model
def get_std_barcodeDict(bcdpath, savepath, bcdSet=None):
    """Index the sample images of each barcode folder and pickle one dict per barcode.

    Every immediate sub-directory of ``bcdpath`` whose name consists only of
    digits and is at least 8 characters long is treated as a barcode. For each
    selected barcode the file ``<savepath>/<barcode>.pickle`` is written; it
    holds a single-key dict ``{barcode: [imgpath1, imgpath2, ...]}``. Barcodes
    whose pickle file already exists are skipped, so a run can be repeated
    without redoing finished barcodes.

    Image selection inside one barcode folder (top-down ``os.walk``): as soon
    as a visited directory contains a ``base`` sub-folder, only the images
    directly inside that ``base`` folder are added and the walk stops;
    directories visited before that point contribute their own image files.
    A file counts as an image when its extension, compared case-insensitively,
    is listed in ``IMG_FORMAT``.

    Args:
        bcdpath: Root folder with one sub-folder of cleaned sample images per
            barcode.
        savepath: Output folder for the ``<barcode>.pickle`` files; it is
            created when missing.
        bcdSet: Optional collection of barcode strings. When given, only
            barcodes contained in it are processed; ``None`` (the default)
            processes every barcode folder found.

    Returns:
        None. The result is the set of pickle files written to ``savepath``.
    """
    os.makedirs(savepath, exist_ok=True)

    # Read the list of barcodes available in the data set.
    stdBarcodeList = []
    for filename in os.listdir(bcdpath):
        filepath = os.path.join(bcdpath, filename)
        if not os.path.isdir(filepath) or not filename.isdigit() or len(filename) < 8:
            continue
        if bcdSet is None or filename in bcdSet:
            stdBarcodeList.append(filename)

    # For every barcode build and save {barcode: [imgpath1, imgpath2, ...]}.
    # NOTE: errbarcodes is never filled by this function; the list and the
    # summary line below are kept so the console output stays unchanged.
    errbarcodes = []
    for barcode in stdBarcodeList:
        pickpath = os.path.join(savepath, f"{barcode}.pickle")
        if os.path.isfile(pickpath):
            continue

        imgpaths = []
        for root, dirs, files in os.walk(os.path.join(bcdpath, barcode)):
            if "base" in dirs:
                # A "base" folder takes precedence: use its images and stop walking.
                broot = os.path.join(root, "base")
                imgpaths.extend(_image_paths(broot, os.listdir(broot)))
                break
            imgpaths.extend(_image_paths(root, files))

        with open(pickpath, 'wb') as f:
            pickle.dump({barcode: imgpaths}, f)
        print(f"Barcode: {barcode}")

    print(f"Len of errbarcodes: {len(errbarcodes)}")
    return


def _image_paths(folder, names):
    """Return ``folder/name`` for every name whose extension is in IMG_FORMAT (case-insensitive)."""
    return [
        os.path.join(folder, name)
        for name in names
        if os.path.splitext(name)[1].lower() in IMG_FORMAT
    ]
def stdfeat_infer(imgPath, featPath, bcdSet=None):
    """Extract features for the image lists in ``imgPath`` and store them in ``featPath``.

    Every ``<barcode>.pickle`` file in ``imgPath`` is expected to hold a dict
    ``{barcode: [imgpath1, imgpath2, ...]}``. The listed images are opened
    with PIL and passed to ``FeatsInterface(conf).inference``; the returned
    2-D array is normalised row by row and saved to
    ``<featPath>/<barcode>.pickle`` as a dict with the keys ``barcode``,
    ``imgpaths``, ``feats_ft32``, ``feats_ft16`` and ``feats_uint8``.
    (The original author's note gives the feature shape as (nsample, 256).)

    The same barcode string is used in four places: the raw image folder name,
    the pickle file name in ``imgPath``, the key inside that pickle, and the
    ``barcode`` entry of the stored feature dict.

    Files whose feature pickle already exists are skipped. A failure while
    processing one file is printed and the loop continues with the next file.

    Args:
        imgPath: Folder containing the ``<barcode>.pickle`` image-list files.
        featPath: Folder receiving the feature pickles; created when missing.
        bcdSet: Optional collection of barcode strings; when given, only files
            whose stem is contained in it are processed.

    Returns:
        None.
    """
    os.makedirs(featPath, exist_ok=True)

    # Built on first use so that a run in which everything is filtered out or
    # already cached does not construct the feature extractor at all.
    Encoder = None

    for filename in os.listdir(imgPath):
        bcd, ext = os.path.splitext(filename)
        if ext != ".pickle":
            continue
        if bcdSet is not None and bcd not in bcdSet:
            continue

        featpath = os.path.join(featPath, f"{bcd}.pickle")
        if os.path.isfile(featpath):
            continue

        if Encoder is None:
            Encoder = FeatsInterface(conf)

        filepath = os.path.join(imgPath, filename)
        stdbDict = {}
        t1 = time.time()
        try:
            # NOTE(review): pickle.load executes arbitrary code from the file;
            # only point imgPath at files produced by this pipeline.
            with open(filepath, 'rb') as f:
                bpDict = pickle.load(f)

            for barcode, imgpaths in bpDict.items():
                imgs = [Image.open(imgpath) for imgpath in imgpaths]

                feature = Encoder.inference(imgs)
                feature /= np.linalg.norm(feature, axis=1)[:, None]
                feature_ft32 = feature.astype(np.float32)

                # float16 copy, re-normalised after the precision loss.
                feature_ft16 = feature.astype(np.float16)
                feature_ft16 /= np.linalg.norm(feature_ft16, axis=1)[:, None]

                # 8-bit copy (stored as int8 under the historical key name
                # "feats_uint8"). Clip first: a component of 1.0 scales to 128,
                # which does not fit into int8.
                feature_uint8 = np.clip(feature_ft16 * 128, -128, 127).astype(np.int8)

                # Collect everything that is stored for this barcode.
                stdbDict["barcode"] = barcode
                stdbDict["imgpaths"] = imgpaths
                stdbDict["feats_ft32"] = feature_ft32
                stdbDict["feats_ft16"] = feature_ft16
                stdbDict["feats_uint8"] = feature_uint8

            # Do not write an empty result: the file would be treated as
            # finished and skipped on every later run.
            if stdbDict:
                with open(featpath, 'wb') as f:
                    pickle.dump(stdbDict, f)
        except Exception as e:
            print(f"Error accured at: {filename}, with Exception is: {e}")

        t2 = time.time()
        # Report the file stem rather than the inner loop variable, which is
        # unbound (or stale) when the failure happens before the loop runs.
        print(f"Barcode: {bcd}, need time: {t2-t1:.1f} secs")
    return
def gen_bcd_features(imgpath, bcdpath, featpath, eventSourcePath=None):
    """Build the standard feature set for the barcodes found under ``imgpath``.

    Step 1 calls ``get_std_barcodeDict`` to write one
    ``{barcode: [imgpath1, imgpath2, ...]}`` pickle per barcode into
    ``bcdpath``. Step 2 calls ``stdfeat_infer`` to extract features for those
    pickles into ``featpath``. Both steps use the same barcode filter.

    Args:
        imgpath: Root folder of the per-barcode sample images.
        bcdpath: Folder receiving the ``<barcode>.pickle`` image-list files.
        featpath: Folder receiving the ``<barcode>.pickle`` feature files.
        eventSourcePath: Optional folder whose entry names end in
            ``_<barcode>`` (after removing a file extension). Barcodes are
            taken from entries whose last ``_``-separated field is all digits
            and at least 10 characters long, and only those barcodes are
            processed. ``None`` (the default) disables the filter so that
            every barcode under ``imgpath`` is processed.

    Returns:
        None.
    """
    bcdSet = None
    if eventSourcePath is not None:
        bcdSet = set()
        for evtname in os.listdir(eventSourcePath):
            bname, _ = os.path.splitext(evtname)
            evt = bname.split('_')
            if len(evt) >= 2 and evt[-1].isdigit() and len(evt[-1]) >= 10:
                bcdSet.add(evt[-1])

    # 1. Collect the sample image paths per barcode into bcdpath.
    get_std_barcodeDict(imgpath, bcdpath, bcdSet)

    # 2. Extract the features into featpath, restricted to the same barcodes.
    stdfeat_infer(bcdpath, featpath, bcdSet)
def main():
    """Run the feature generation with the hard-coded data set locations."""
    imgpath = r"\\192.168.1.28\share\数据\已完成数据\展厅数据\v2.0_abroad\比对数据\all_base_二筛"
    bcdpath = r"D:\exhibition\dataset\bcdpath_abroad"
    featpath = r"D:\exhibition\dataset\feats_abroad"

    # Make sure both output folders exist before anything is written.
    for outdir in (bcdpath, featpath):
        if not os.path.exists(outdir):
            os.makedirs(outdir)

    # NOTE(review): no eventSourcePath is passed here; this call only works if
    # gen_bcd_features treats that parameter as optional - confirm.
    gen_bcd_features(imgpath, bcdpath, featpath)
# Script entry point: only run the pipeline when executed directly.
if __name__ == "__main__":
    main()