bakeup
This commit is contained in:
@ -6,26 +6,33 @@ Created on Sun Nov 3 12:05:19 2024
|
||||
"""
|
||||
import os
|
||||
import time
|
||||
import torch
|
||||
# import torch
|
||||
import pickle
|
||||
# import json
|
||||
import numpy as np
|
||||
from config import config as conf
|
||||
from model import resnet18 as resnet18
|
||||
from feat_inference import inference_image
|
||||
from PIL import Image
|
||||
from feat_extract.config import config as conf
|
||||
# from model import resnet18 as resnet18
|
||||
from feat_extract.inference import FeatsInterface #, inference_image
|
||||
|
||||
|
||||
IMG_FORMAT = ['.bmp', '.jpg', '.jpeg', '.png']
|
||||
|
||||
'''======= 0. 配置特征提取模型地址 ======='''
|
||||
model_path = conf.test_model
|
||||
model_path = r"D:\exhibition\ckpt\zhanting.pth"
|
||||
# def model_init(conf, mpath=None):
|
||||
# '''======= 0. 配置特征提取模型地址 ======='''
|
||||
# if mpath is None:
|
||||
# model_path = conf.test_model
|
||||
# else:
|
||||
# model_path = mpath
|
||||
|
||||
##============ load resnet mdoel
|
||||
model = resnet18().to(conf.device)
|
||||
# model = nn.DataParallel(model).to(conf.device)
|
||||
model.load_state_dict(torch.load(model_path, map_location=conf.device))
|
||||
model.eval()
|
||||
print('load model {} '.format(conf.testbackbone))
|
||||
# ##============ load resnet mdoel
|
||||
# model = resnet18().to(conf.device)
|
||||
# # model = nn.DataParallel(model).to(conf.device)
|
||||
# model.load_state_dict(torch.load(model_path, map_location=conf.device))
|
||||
# model.eval()
|
||||
# print('load model {} '.format(conf.testbackbone))
|
||||
|
||||
# return model
|
||||
|
||||
def get_std_barcodeDict(bcdpath, savepath, bcdSet):
|
||||
'''
|
||||
@ -42,9 +49,9 @@ def get_std_barcodeDict(bcdpath, savepath, bcdSet):
|
||||
'''读取数据集中 barcode 列表'''
|
||||
stdBarcodeList = []
|
||||
for filename in os.listdir(bcdpath):
|
||||
# filepath = os.path.join(bcdpath, filename)
|
||||
# if not os.path.isdir(filepath) or not filename.isdigit() or len(filename)<8:
|
||||
# continue
|
||||
filepath = os.path.join(bcdpath, filename)
|
||||
if not os.path.isdir(filepath) or not filename.isdigit() or len(filename)<8:
|
||||
continue
|
||||
if bcdSet is None:
|
||||
stdBarcodeList.append(filename)
|
||||
elif filename in bcdSet:
|
||||
@ -59,7 +66,7 @@ def get_std_barcodeDict(bcdpath, savepath, bcdSet):
|
||||
for barcode, bpath in bcdPaths:
|
||||
pickpath = os.path.join(savepath, f"{barcode}.pickle")
|
||||
if os.path.isfile(pickpath):
|
||||
continue
|
||||
continue
|
||||
|
||||
stdBarcodeDict = {}
|
||||
stdBarcodeDict[barcode] = []
|
||||
@ -89,6 +96,7 @@ def get_std_barcodeDict(bcdpath, savepath, bcdSet):
|
||||
pickpath = os.path.join(savepath, f"{barcode}.pickle")
|
||||
with open(pickpath, 'wb') as f:
|
||||
pickle.dump(stdBarcodeDict, f)
|
||||
|
||||
print(f"Barcode: {barcode}")
|
||||
|
||||
# k += 1
|
||||
@ -115,32 +123,37 @@ def stdfeat_infer(imgPath, featPath, bcdSet=None):
|
||||
stdBarcodeDict = {}
|
||||
stdBarcodeDict_ft16 = {}
|
||||
|
||||
Encoder = FeatsInterface(conf)
|
||||
|
||||
'''4处同名: (1)barcode原始图像文件夹; (2)imgPath中的 .pickle 文件名、该pickle文件中字典的key值'''
|
||||
|
||||
'''4处同名: (1)barcode原始图像文件夹; (2)imgPath中的 .pickle 文件名;
|
||||
(3)该pickle文件中字典的key值; (4)特征向量字典中的一个key值'''
|
||||
k = 0
|
||||
for filename in os.listdir(imgPath):
|
||||
bcd, ext = os.path.splitext(filename)
|
||||
pkpath = os.path.join(featPath, f"{bcd}.pickle")
|
||||
|
||||
if os.path.isfile(pkpath): continue
|
||||
filepath = os.path.join(imgPath, filename)
|
||||
if ext != ".pickle": continue
|
||||
if bcdSet is not None and bcd not in bcdSet:
|
||||
continue
|
||||
|
||||
filepath = os.path.join(imgPath, filename)
|
||||
featpath = os.path.join(featPath, f"{bcd}.pickle")
|
||||
|
||||
stdbDict = {}
|
||||
stdbDict_ft16 = {}
|
||||
stdbDict_uint8 = {}
|
||||
|
||||
t1 = time.time()
|
||||
|
||||
try:
|
||||
with open(filepath, 'rb') as f:
|
||||
bpDict = pickle.load(f)
|
||||
bpDict = pickle.load(f)
|
||||
|
||||
for barcode, imgpaths in bpDict.items():
|
||||
# feature = batch_inference(imgpaths, 8) #from vit distilled model of LiChen
|
||||
feature = inference_image(imgpaths, conf.test_transform, model, conf.device)
|
||||
# feature = inference_image(imgpaths, conf.test_transform, model, conf.device)
|
||||
|
||||
imgs = []
|
||||
for d, imgpath in enumerate(imgpaths):
|
||||
img = Image.open(imgpath)
|
||||
imgs.append(img)
|
||||
|
||||
feature = Encoder.inference(imgs)
|
||||
|
||||
feature /= np.linalg.norm(feature, axis=1)[:, None]
|
||||
|
||||
# float16
|
||||
@ -162,7 +175,7 @@ def stdfeat_infer(imgPath, featPath, bcdSet=None):
|
||||
stdbDict["feats_ft16"] = feature_ft16
|
||||
stdbDict["feats_uint8"] = feature_uint8
|
||||
|
||||
with open(pkpath, 'wb') as f:
|
||||
with open(featpath, 'wb') as f:
|
||||
pickle.dump(stdbDict, f)
|
||||
|
||||
stdBarcodeDict[barcode] = feature
|
||||
@ -174,21 +187,10 @@ def stdfeat_infer(imgPath, featPath, bcdSet=None):
|
||||
# if k == 10:
|
||||
# break
|
||||
|
||||
##================== float32
|
||||
# pickpath = os.path.join(featPath, f"barcode_features_{k}.pickle")
|
||||
# with open(pickpath, 'wb') as f:
|
||||
# pickle.dump(stdBarcodeDict, f)
|
||||
|
||||
##================== float16
|
||||
# pickpath_ft16 = os.path.join(featPath, f"barcode_features_ft16_{k}.pickle")
|
||||
# with open(pickpath_ft16, 'wb') as f:
|
||||
# pickle.dump(stdBarcodeDict_ft16, f)
|
||||
|
||||
return
|
||||
|
||||
|
||||
|
||||
def genfeatures(imgpath, bcdpath, featpath, bcdSet=None):
|
||||
def gen_bcd_features(imgpath, bcdpath, featpath, bcdSet=None):
|
||||
''' 生成标准特征集 '''
|
||||
'''1. 提取 imgpath 中样本地址,生成字典{barcode: [imgpath1, imgpath1, ...]}
|
||||
并存储于: bcdpath, 格式为 barcode.pickle'''
|
||||
@ -198,11 +200,12 @@ def genfeatures(imgpath, bcdpath, featpath, bcdSet=None):
|
||||
stdfeat_infer(bcdpath, featpath, bcdSet)
|
||||
|
||||
def main():
|
||||
imgpath = r"\\192.168.1.28\share\展厅barcode数据\整理\zhantingBase"
|
||||
imgpath = r"\\192.168.1.28\share\数据\已完成数据\展厅数据\v1.0\比对数据\整理\zhantingBase"
|
||||
bcdpath = r"D:\exhibition\dataset\bcdpath"
|
||||
featpath = r"D:\exhibition\dataset\feats"
|
||||
|
||||
genfeatures(imgpath, bcdpath, featpath)
|
||||
|
||||
|
||||
gen_bcd_features(imgpath, bcdpath, featpath)
|
||||
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user