This commit is contained in:
王庆刚
2024-11-04 18:06:52 +08:00
parent dfb2272a15
commit 5ecc1285d4
41 changed files with 2552 additions and 440 deletions

209
contrast/genfeats.py Normal file
View File

@ -0,0 +1,209 @@
# -*- coding: utf-8 -*-
"""
Created on Sun Nov 3 12:05:19 2024
@author: ym
"""
import os
import time
import torch
import pickle
import numpy as np
from config import config as conf
from model import resnet18 as resnet18
from feat_inference import inference_image
IMG_FORMAT = ['.bmp', '.jpg', '.jpeg', '.png']
'''======= 0. 配置特征提取模型地址 ======='''
model_path = conf.test_model
model_path = r"D:\exhibition\ckpt\zhanting.pth"
##============ load resnet mdoel
model = resnet18().to(conf.device)
# model = nn.DataParallel(model).to(conf.device)
model.load_state_dict(torch.load(model_path, map_location=conf.device))
model.eval()
print('load model {} '.format(conf.testbackbone))
def get_std_barcodeDict(bcdpath, savepath):
'''
inputs:
bcdpath: 已清洗的barcode样本图像如果barcode下有'base'文件夹,只选用该文件夹下图像
(default = r'\\192.168.1.28\share\已标注数据备份\对比数据\barcode\barcode_1771')
功能:
生成并保存只有一个key值的字典 {barcode: [imgpath1, imgpath1, ...]}
savepath: 字典存储地址文件名格式barcode.pickle
'''
# savepath = r'\\192.168.1.28\share\测试_202406\contrast\std_barcodes'
'''读取数据集中 barcode 列表'''
stdBarcodeList = []
for filename in os.listdir(bcdpath):
filepath = os.path.join(bcdpath, filename)
# if not os.path.isdir(filepath) or not filename.isdigit() or len(filename)<8:
# continue
stdBarcodeList.append(filename)
bcdPaths = [(barcode, os.path.join(bcdpath, barcode)) for barcode in stdBarcodeList]
'''遍历数据集针对每一个barcode生成并保存字典{barcode: [imgpath1, imgpath1, ...]}'''
k = 0
errbarcodes = []
for barcode, bpath in bcdPaths:
pickpath = os.path.join(savepath, f"{barcode}.pickle")
if os.path.isfile(pickpath):
continue
stdBarcodeDict = {}
stdBarcodeDict[barcode] = []
for root, dirs, files in os.walk(bpath):
imgpaths = []
if "base" in dirs:
broot = os.path.join(root, "base")
for imgname in os.listdir(broot):
imgpath = os.path.join(broot, imgname)
file, ext = os.path.splitext(imgpath)
if ext not in IMG_FORMAT:
continue
imgpaths.append(imgpath)
stdBarcodeDict[barcode].extend(imgpaths)
break
else:
for imgname in files:
imgpath = os.path.join(root, imgname)
_, ext = os.path.splitext(imgpath)
if ext not in IMG_FORMAT: continue
imgpaths.append(imgpath)
stdBarcodeDict[barcode].extend(imgpaths)
pickpath = os.path.join(savepath, f"{barcode}.pickle")
with open(pickpath, 'wb') as f:
pickle.dump(stdBarcodeDict, f)
print(f"Barcode: {barcode}")
# k += 1
# if k == 10:
# break
print(f"Len of errbarcodes: {len(errbarcodes)}")
return
def stdfeat_infer(imgPath, featPath, bcdSet=None):
'''
inputs:
imgPath: 该文件夹下的 pickle 文件格式 {barcode: [imgpath1, imgpath1, ...]}
featPath: imgPath图像对应特征的存储地址
功能:
对 imgPath中图像进行特征提取生成只有一个key值的字典
{barcode: features}features.shape=(nsample, 256),并保存至 featPath 中
'''
# imgPath = r"\\192.168.1.28\share\测试_202406\contrast\std_barcodes"
# featPath = r"\\192.168.1.28\share\测试_202406\contrast\std_features"
stdBarcodeDict = {}
stdBarcodeDict_ft16 = {}
'''4处同名: (1)barcode原始图像文件夹; (2)imgPath中的 .pickle 文件名、该pickle文件中字典的key值'''
k = 0
for filename in os.listdir(imgPath):
bcd, ext = os.path.splitext(filename)
pkpath = os.path.join(featPath, f"{bcd}.pickle")
if os.path.isfile(pkpath): continue
if bcdSet is not None and bcd not in bcdSet:
continue
filepath = os.path.join(imgPath, filename)
stdbDict = {}
stdbDict_ft16 = {}
stdbDict_uint8 = {}
t1 = time.time()
try:
with open(filepath, 'rb') as f:
bpDict = pickle.load(f)
for barcode, imgpaths in bpDict.items():
# feature = batch_inference(imgpaths, 8) #from vit distilled model of LiChen
feature = inference_image(imgpaths, conf.test_transform, model, conf.device)
feature /= np.linalg.norm(feature, axis=1)[:, None]
# float16
feature_ft16 = feature.astype(np.float16)
feature_ft16 /= np.linalg.norm(feature_ft16, axis=1)[:, None]
# uint8, 两种策略1) 精度损失小, 2) 计算复杂度小
# feature_uint8, _ = ft16_to_uint8(feature_ft16)
feature_uint8 = (feature_ft16*128).astype(np.int8)
except Exception as e:
print(f"Error accured at: {filename}, with Exception is: {e}")
'''================ 保存单个barcode特征 ================'''
##================== float32
stdbDict["barcode"] = barcode
stdbDict["imgpaths"] = imgpaths
stdbDict["feats_ft32"] = feature
stdbDict["feats_ft16"] = feature_ft16
stdbDict["feats_uint8"] = feature_uint8
with open(pkpath, 'wb') as f:
pickle.dump(stdbDict, f)
stdBarcodeDict[barcode] = feature
stdBarcodeDict_ft16[barcode] = feature_ft16
t2 = time.time()
print(f"Barcode: {barcode}, need time: {t2-t1:.1f} secs")
# k += 1
# if k == 10:
# break
##================== float32
# pickpath = os.path.join(featPath, f"barcode_features_{k}.pickle")
# with open(pickpath, 'wb') as f:
# pickle.dump(stdBarcodeDict, f)
##================== float16
# pickpath_ft16 = os.path.join(featPath, f"barcode_features_ft16_{k}.pickle")
# with open(pickpath_ft16, 'wb') as f:
# pickle.dump(stdBarcodeDict_ft16, f)
return
def genfeatures(imgpath, bcdpath, featpath):
get_std_barcodeDict(imgpath, bcdpath)
stdfeat_infer(bcdpath, featpath, bcdSet=None)
print(f"Features have generated, saved in: {featpath}")
def main():
imgpath = r"\\192.168.1.28\share\展厅barcode数据\整理\zhantingBase"
bcdpath = r"D:\exhibition\dataset\bcdpath"
featpath = r"D:\exhibition\dataset\feats"
genfeatures(imgpath, bcdpath, featpath)
if __name__ == '__main__':
main()