This commit is contained in:
王庆刚
2024-11-25 18:05:08 +08:00
parent c47894ddc0
commit 8bbee310ba
109 changed files with 1003 additions and 305 deletions

View File

@ -6,26 +6,33 @@ Created on Sun Nov 3 12:05:19 2024
"""
import os
import time
import torch
# import torch
import pickle
# import json
import numpy as np
from config import config as conf
from model import resnet18 as resnet18
from feat_inference import inference_image
from PIL import Image
from feat_extract.config import config as conf
# from model import resnet18 as resnet18
from feat_extract.inference import FeatsInterface #, inference_image
IMG_FORMAT = ['.bmp', '.jpg', '.jpeg', '.png']
'''======= 0. 配置特征提取模型地址 ======='''
model_path = conf.test_model
model_path = r"D:\exhibition\ckpt\zhanting.pth"
# def model_init(conf, mpath=None):
# '''======= 0. 配置特征提取模型地址 ======='''
# if mpath is None:
# model_path = conf.test_model
# else:
# model_path = mpath
##============ load resnet mdoel
model = resnet18().to(conf.device)
# model = nn.DataParallel(model).to(conf.device)
model.load_state_dict(torch.load(model_path, map_location=conf.device))
model.eval()
print('load model {} '.format(conf.testbackbone))
# ##============ load resnet mdoel
# model = resnet18().to(conf.device)
# # model = nn.DataParallel(model).to(conf.device)
# model.load_state_dict(torch.load(model_path, map_location=conf.device))
# model.eval()
# print('load model {} '.format(conf.testbackbone))
# return model
def get_std_barcodeDict(bcdpath, savepath, bcdSet):
'''
@ -42,9 +49,9 @@ def get_std_barcodeDict(bcdpath, savepath, bcdSet):
'''读取数据集中 barcode 列表'''
stdBarcodeList = []
for filename in os.listdir(bcdpath):
# filepath = os.path.join(bcdpath, filename)
# if not os.path.isdir(filepath) or not filename.isdigit() or len(filename)<8:
# continue
filepath = os.path.join(bcdpath, filename)
if not os.path.isdir(filepath) or not filename.isdigit() or len(filename)<8:
continue
if bcdSet is None:
stdBarcodeList.append(filename)
elif filename in bcdSet:
@ -59,7 +66,7 @@ def get_std_barcodeDict(bcdpath, savepath, bcdSet):
for barcode, bpath in bcdPaths:
pickpath = os.path.join(savepath, f"{barcode}.pickle")
if os.path.isfile(pickpath):
continue
continue
stdBarcodeDict = {}
stdBarcodeDict[barcode] = []
@ -89,6 +96,7 @@ def get_std_barcodeDict(bcdpath, savepath, bcdSet):
pickpath = os.path.join(savepath, f"{barcode}.pickle")
with open(pickpath, 'wb') as f:
pickle.dump(stdBarcodeDict, f)
print(f"Barcode: {barcode}")
# k += 1
@ -115,32 +123,37 @@ def stdfeat_infer(imgPath, featPath, bcdSet=None):
stdBarcodeDict = {}
stdBarcodeDict_ft16 = {}
Encoder = FeatsInterface(conf)
'''4处同名: (1)barcode原始图像文件夹; (2)imgPath中的 .pickle 文件名、该pickle文件中字典的key值'''
'''4处同名: (1)barcode原始图像文件夹; (2)imgPath中的 .pickle 文件名;
(3)该pickle文件中字典的key值; (4)特征向量字典中的一个key值'''
k = 0
for filename in os.listdir(imgPath):
bcd, ext = os.path.splitext(filename)
pkpath = os.path.join(featPath, f"{bcd}.pickle")
if os.path.isfile(pkpath): continue
filepath = os.path.join(imgPath, filename)
if ext != ".pickle": continue
if bcdSet is not None and bcd not in bcdSet:
continue
filepath = os.path.join(imgPath, filename)
featpath = os.path.join(featPath, f"{bcd}.pickle")
stdbDict = {}
stdbDict_ft16 = {}
stdbDict_uint8 = {}
t1 = time.time()
try:
with open(filepath, 'rb') as f:
bpDict = pickle.load(f)
bpDict = pickle.load(f)
for barcode, imgpaths in bpDict.items():
# feature = batch_inference(imgpaths, 8) #from vit distilled model of LiChen
feature = inference_image(imgpaths, conf.test_transform, model, conf.device)
# feature = inference_image(imgpaths, conf.test_transform, model, conf.device)
imgs = []
for d, imgpath in enumerate(imgpaths):
img = Image.open(imgpath)
imgs.append(img)
feature = Encoder.inference(imgs)
feature /= np.linalg.norm(feature, axis=1)[:, None]
# float16
@ -162,7 +175,7 @@ def stdfeat_infer(imgPath, featPath, bcdSet=None):
stdbDict["feats_ft16"] = feature_ft16
stdbDict["feats_uint8"] = feature_uint8
with open(pkpath, 'wb') as f:
with open(featpath, 'wb') as f:
pickle.dump(stdbDict, f)
stdBarcodeDict[barcode] = feature
@ -174,21 +187,10 @@ def stdfeat_infer(imgPath, featPath, bcdSet=None):
# if k == 10:
# break
##================== float32
# pickpath = os.path.join(featPath, f"barcode_features_{k}.pickle")
# with open(pickpath, 'wb') as f:
# pickle.dump(stdBarcodeDict, f)
##================== float16
# pickpath_ft16 = os.path.join(featPath, f"barcode_features_ft16_{k}.pickle")
# with open(pickpath_ft16, 'wb') as f:
# pickle.dump(stdBarcodeDict_ft16, f)
return
def genfeatures(imgpath, bcdpath, featpath, bcdSet=None):
def gen_bcd_features(imgpath, bcdpath, featpath, bcdSet=None):
''' 生成标准特征集 '''
'''1. 提取 imgpath 中样本地址,生成字典{barcode: [imgpath1, imgpath1, ...]}
并存储于: bcdpath, 格式为 barcode.pickle'''
@ -198,11 +200,12 @@ def genfeatures(imgpath, bcdpath, featpath, bcdSet=None):
stdfeat_infer(bcdpath, featpath, bcdSet)
def main():
imgpath = r"\\192.168.1.28\share\展厅barcode数据\整理\zhantingBase"
imgpath = r"\\192.168.1.28\share\数据\已完成数据\展厅数据\v1.0\比对数据\整理\zhantingBase"
bcdpath = r"D:\exhibition\dataset\bcdpath"
featpath = r"D:\exhibition\dataset\feats"
genfeatures(imgpath, bcdpath, featpath)
gen_bcd_features(imgpath, bcdpath, featpath)