# -*- coding: utf-8 -*-
"""
Created on Sun Nov  3 12:05:19 2024

Build per-barcode image lists and extract per-barcode feature matrices.

@author: ym
"""

import os
import time
import torch
import pickle
import numpy as np
from config import config as conf
from model import resnet18 as resnet18
from feat_inference import inference_image

# Image file extensions (lower case) accepted as barcode samples.
IMG_FORMAT = ['.bmp', '.jpg', '.jpeg', '.png']

# ======= 0. Location of the feature-extraction model =======
model_path = conf.test_model
# NOTE(review): the configured path is overridden by this hard-coded
# checkpoint; confirm this is intentional before running elsewhere.
model_path = r"D:\exhibition\ckpt\zhanting.pth"

# ============ load resnet model
model = resnet18().to(conf.device)
# model = nn.DataParallel(model).to(conf.device)
model.load_state_dict(torch.load(model_path, map_location=conf.device))
model.eval()
print('load model {} '.format(conf.testbackbone))


def _is_image(path):
    """Return True when *path* has one of the IMG_FORMAT extensions.

    The comparison ignores case so that e.g. ``.JPG`` files are accepted.
    """
    return os.path.splitext(path)[1].lower() in IMG_FORMAT


def _collect_image_paths(bpath):
    """Return the image paths found below one barcode folder.

    The tree is walked top-down. As soon as a directory containing a
    ``base`` sub-folder is reached, only the images directly inside that
    ``base`` folder are taken from that level and the walk stops. Until
    then, the images of every visited directory are collected.
    """
    collected = []
    for root, dirs, files in os.walk(bpath):
        if "base" in dirs:
            broot = os.path.join(root, "base")
            candidates = (os.path.join(broot, name)
                          for name in os.listdir(broot))
            collected.extend(p for p in candidates if _is_image(p))
            break
        candidates = (os.path.join(root, name) for name in files)
        collected.extend(p for p in candidates if _is_image(p))
    return collected


def get_std_barcodeDict(bcdpath, savepath):
    """Write one ``{barcode: [imgpath, ...]}`` pickle per barcode folder.

    Args:
        bcdpath: Directory of cleaned barcode sample images; every entry
            in it is treated as one barcode. If a barcode folder contains
            a ``base`` sub-folder, only the images in ``base`` are used.
        savepath: Output directory. Each dictionary is stored as
            ``<barcode>.pickle``; the directory is created when missing.

    Barcodes whose pickle already exists in ``savepath`` are skipped, so
    the function can be re-run to resume an interrupted pass.
    """
    os.makedirs(savepath, exist_ok=True)

    # Nothing records failures at present; the list is kept only so the
    # summary line printed at the end stays unchanged.
    errbarcodes = []

    for barcode in os.listdir(bcdpath):
        pickpath = os.path.join(savepath, f"{barcode}.pickle")
        if os.path.isfile(pickpath):
            continue

        imgpaths = _collect_image_paths(os.path.join(bcdpath, barcode))
        with open(pickpath, 'wb') as f:
            pickle.dump({barcode: imgpaths}, f)
        print(f"Barcode: {barcode}")

    print(f"Len of errbarcodes: {len(errbarcodes)}")
    return


def _extract_features(imgpaths):
    """Return (float32, float16, int8) L2-normalised features of images.

    The float32 matrix comes from ``inference_image`` and is normalised
    row-wise. The float16 copy is re-normalised after the cast, and the
    8-bit version is the float16 matrix scaled by 128.
    """
    feature = inference_image(imgpaths, conf.test_transform, model,
                              conf.device)
    feature /= np.linalg.norm(feature, axis=1)[:, None]

    # float16
    feature_ft16 = feature.astype(np.float16)
    feature_ft16 /= np.linalg.norm(feature_ft16, axis=1)[:, None]

    # 8-bit quantisation. A component equal to 1.0 scales to 128, which
    # does not fit in int8, so clip before casting instead of letting the
    # value wrap around to -128.
    feature_uint8 = np.clip(feature_ft16 * 128, -128, 127).astype(np.int8)
    return feature, feature_ft16, feature_uint8


def stdfeat_infer(imgPath, featPath, bcdSet=None):
    """Extract and save features for every barcode pickle in ``imgPath``.

    Args:
        imgPath: Directory of ``<barcode>.pickle`` files, each holding a
            ``{barcode: [imgpath, ...]}`` dictionary. The pickle file
            name and the dictionary key are expected to be the same
            barcode string as the original image folder.
        featPath: Output directory. One ``<barcode>.pickle`` is written
            per input file; the directory is created when missing.
        bcdSet: Optional collection of barcodes; when given, only files
            whose name (without extension) is in it are processed.

    Each output pickle is a dictionary with the keys ``barcode``,
    ``imgpaths``, ``feats_ft32``, ``feats_ft16`` and ``feats_uint8``
    (the last one holds int8 data despite its name). Existing outputs
    are skipped. An input file that fails to load or to run through the
    model is reported and skipped, and nothing is written for it.
    """
    os.makedirs(featPath, exist_ok=True)

    for filename in os.listdir(imgPath):
        bcd, ext = os.path.splitext(filename)
        if ext.lower() != ".pickle":
            continue

        pkpath = os.path.join(featPath, f"{bcd}.pickle")
        if os.path.isfile(pkpath):
            continue
        if bcdSet is not None and bcd not in bcdSet:
            continue

        filepath = os.path.join(imgPath, filename)
        t1 = time.time()
        try:
            # NOTE: pickle must only be used on trusted files; these are
            # expected to be the ones written by get_std_barcodeDict.
            with open(filepath, 'rb') as f:
                bpDict = pickle.load(f)
            if not bpDict:
                raise ValueError("empty barcode dict")

            # A single key is expected; with several keys the last one
            # is the one that gets saved.
            for barcode, imgpaths in bpDict.items():
                # feature = batch_inference(imgpaths, 8)  # ViT distilled model
                feature, feature_ft16, feature_uint8 = \
                    _extract_features(imgpaths)
        except Exception as e:
            print(f"Error accured at: {filename}, with Exception is: {e}")
            # Skip saving: the result variables are either unbound or
            # still hold the previous barcode's data.
            continue

        # ================ save the features of one barcode ================
        stdbDict = {
            "barcode": barcode,
            "imgpaths": imgpaths,
            "feats_ft32": feature,
            "feats_ft16": feature_ft16,
            "feats_uint8": feature_uint8,
        }
        with open(pkpath, 'wb') as f:
            pickle.dump(stdbDict, f)

        t2 = time.time()
        print(f"Barcode: {barcode}, need time: {t2-t1:.1f} secs")

    return


def genfeatures(imgpath, bcdpath, featpath):
    """Run both stages: build the image-list pickles, then the features.

    Args:
        imgpath: Root directory of the barcode image folders.
        bcdpath: Directory receiving the ``{barcode: [imgpath, ...]}``
            pickles.
        featpath: Directory receiving the feature pickles.
    """
    get_std_barcodeDict(imgpath, bcdpath)
    stdfeat_infer(bcdpath, featpath, bcdSet=None)

    print(f"Features have generated, saved in: {featpath}")


def main():
    """Generate features for the hard-coded exhibition dataset paths."""
    imgpath = r"\\192.168.1.28\share\展厅barcode数据\整理\zhantingBase"
    bcdpath = r"D:\exhibition\dataset\bcdpath"
    featpath = r"D:\exhibition\dataset\feats"

    genfeatures(imgpath, bcdpath, featpath)


if __name__ == '__main__':
    main()