# -*- coding: utf-8 -*-
"""
Created on Fri Aug 30 17:53:03 2024

Purpose: 1:1 comparison performance test program.

1. Generate and save the standard feature set from the raw image samples
   that the standard feature set corresponds to.
   func: generate_event_and_stdfeatures():
       (1) get_std_barcodeDict(stdSamplePath, stdBarcodePath)
           Collect the sample paths found in stdSamplePath, build the dict
           {barcode: [imgpath1, imgpath1, ...]} and store it as a pickle
           file, barcode.pickle
       (2) stdfeat_infer(stdBarcodePath, stdFeaturePath, bcdSet=None)
           Extract the standard features and save them into the folder
           stdFeaturePath; may also be restricted at run time to the
           intersection with the shopping-event barcode set.

2. 1:1 comparison performance test,
   func: one2one_simi()
       (1) Intersect the shopping-event barcodes with the standard-feature
           barcodes and build evtDict and stdDict.
       (2) Build "scan A, put A" and "scan A, put B" combinations,
           mergePairs = AA_list + AB_list
       (3) Loop over mergePairs and compute the similarity of every
           "(A, A) or (A, B)" element; save trajectory images or standard
           barcode images that have not been saved yet.
       (4) Save the results.

3. Precision, recall and related metrics
   func: compute_one2one_pr(pickpath)

@author: ym
"""

import numpy as np
import os
import sys
import random
import pickle
import json
from pathlib import Path
import matplotlib.pyplot as plt

FILE = Path(__file__).resolve()
ROOT = FILE.parents[1]  # YOLOv5 root directory
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))

from utils.calsimi import calsimi_vs_stdfeat, calsimi_vs_stdfeat_new
from utils.tools import get_evtList, init_eventDict
from utils.databits import data_precision_compare
from genfeats import gen_bcd_features


def build_std_evt_dict():
    """Load the events and standard features that share a barcode.

    Reads the module-level names ``stdFeaturePath`` (directory of standard
    feature files) and ``eventDataPath`` (directory of pickled Event
    objects).  Both are assigned in the ``__main__`` section of this file,
    so they must exist as globals before this function is called.

    Returns:
        tuple ``(evtList, evtDict, stdDict)``:
            evtList: list of ``(event_name, barcode)`` for every event file
                whose barcode has a standard feature file AND whose pickle
                could be loaded.
            evtDict: ``{event_name: unpickled event object}``.
            stdDict: ``{barcode: feature}`` taken from the ``"value"`` entry
                of a ``.json`` file or the ``"feats_ft32"`` entry of a
                ``.pickle`` file.
    """
    # A set gives O(1) membership tests in the event filter below.
    stdBarcode = {p.stem for p in Path(stdFeaturePath).iterdir()
                  if p.is_file() and p.suffix in ('.json', '.pickle')}

    # ====== 1. Shopping events whose barcode exists in stdBarcode ======
    # Event file stems have 2 or 3 '_'-separated fields, the last one being
    # an all-digit barcode.
    evtList = []
    for p in Path(eventDataPath).iterdir():
        if not (p.is_file() and p.suffix == '.pickle'):
            continue
        fields = p.stem.split('_')
        bcd = fields[-1]
        if len(fields) in (2, 3) and bcd.isdigit() and bcd in stdBarcode:
            evtList.append((p.stem, bcd))

    barcodes = {bcd for _, bcd in evtList}

    # ====== 2. Standard-feature dictionary used for the comparison ======
    # NOTE(review): pickle.load executes arbitrary code from the file; only
    # point these paths at trusted data.
    stdDict = {}
    for stdfile in os.listdir(stdFeaturePath):
        barcode, ext = os.path.splitext(stdfile)
        if barcode not in barcodes:
            continue
        stdpath = os.path.join(stdFeaturePath, stdfile)
        if ext == ".json":
            with open(stdpath, 'r', encoding='utf-8') as f:
                stddata = json.load(f)
            stdDict[barcode] = np.array(stddata["value"])
        elif ext == ".pickle":
            with open(stdpath, 'rb') as f:
                stddata = pickle.load(f)
            stdDict[barcode] = stddata["feats_ft32"]

    # *********** USearch ***********
    # stdDict = {}
    # for barcode in barcodes:
    #     stdDict[barcode] = stdlib[barcode]

    # ====== 3. Event dictionary used for the comparison ======
    evtDict = {}
    for evtname, _ in evtList:
        evtpath = os.path.join(eventDataPath, evtname + '.pickle')
        try:
            with open(evtpath, 'rb') as f:
                evtDict[evtname] = pickle.load(f)
        except Exception as e:
            # Best effort: report and skip.  Previously the assignment ran
            # even after a failed load, raising NameError on the first
            # failure or storing the previous event's data under this name.
            print(f"Failed to load event {evtname}: {e}")

    # Keep evtList and evtDict consistent so callers can index evtDict with
    # every name in evtList.
    evtList = [(evtname, bcd) for evtname, bcd in evtList if evtname in evtDict]

    return evtList, evtDict, stdDict


def one2SN_pr(evtList, evtDict, stdDict, simType="simple"):
    """Evaluate 1:SN ranking performance and plot PR curves and histograms.

    Each event is compared with the standard feature of its own barcode and
    of up to ``SN`` (9) other randomly chosen barcodes.  A candidate counts
    as TP/FN when it is the event's own barcode and does / does not reach
    the maximum similarity, and as FP/TN when it is another barcode that
    does / does not reach the maximum.

    Two figures are written into the module-level ``similPath`` directory
    (``pr_1toSN_{simType}.png`` and ``hist_1toSN_{simType}.png``) and shown.

    Args:
        evtList: list of ``(event_name, barcode)``.
        evtDict: ``{event_name: event}``; each event must expose
            ``front_feats`` and ``back_feats``.
        stdDict: ``{barcode: standard feature}``.
        simType: ``"typea"`` selects ``calsimi_vs_stdfeat``; ``"typeb"`` is
            not implemented; anything else selects
            ``calsimi_vs_stdfeat_new``.

    Raises:
        NotImplementedError: if ``simType == "typeb"`` and at least one
            comparison has to be computed.
    """
    std_barcodes = {bcd for _, bcd in evtList}

    tp_events, fn_events, fp_events, tn_events = [], [], [], []
    tp_simi, fn_simi, tn_simi, fp_simi = [], [], [], []
    errorFile_one2SN = []
    SN = 9
    for evtname, barcode in evtList:
        # Candidate list: the true barcode first, then up to SN distractors.
        bcd_selected = [barcode]
        dset = list(std_barcodes - {barcode})
        if len(dset) > SN:
            random.shuffle(dset)
            bcd_selected.extend(dset[:SN])
        else:
            bcd_selected.extend(dset)

        event = evtDict[evtname]
        # Events without any trajectory feature cannot be compared.
        if len(event.front_feats) + len(event.back_feats) == 0:
            errorFile_one2SN.append(evtname)
            print(f"No trajectory: {evtname}")
            continue

        barcodes, similars = [], []
        for stdbcd in bcd_selected:
            stdfeat = stdDict[stdbcd]

            # NOTE(review): this mapping is the opposite of one2one_simi,
            # where "typea" uses calsimi_vs_stdfeat_new.  Confirm which one
            # is intended.
            if simType == "typea":
                simi_mean, simi_max, simi_mfeat = calsimi_vs_stdfeat(event, stdfeat)
            elif simType == "typeb":
                # Previously a bare `pass`, which left simi_mean undefined
                # (NameError) or stale from the previous iteration.
                raise NotImplementedError('simType "typeb" is not implemented')
            else:
                simi_mean, simi_1, simi_2 = calsimi_vs_stdfeat_new(event, stdfeat)

            # The original author notes that no None-guard is needed here
            # because events with empty front_feats and back_feats were
            # already skipped above.
            barcodes.append(stdbcd)
            similars.append(simi_mean)

        max_sim = max(similars)
        for bcd, simi in zip(barcodes, similars):
            if bcd == barcode and simi == max_sim:
                tp_simi.append(simi)
                tp_events.append(evtname)
            elif bcd == barcode and simi != max_sim:
                fn_simi.append(simi)
                fn_events.append(evtname)
            elif bcd != barcode and simi != max_sim:
                tn_simi.append(simi)
                tn_events.append(evtname)
            elif bcd != barcode and simi == max_sim and barcode in barcodes:
                fp_simi.append(simi)
                fp_events.append(evtname)
            else:
                errorFile_one2SN.append(evtname)

    PPreciseX, PRecallX = [], []
    NPreciseX, NRecallX = [], []
    Thresh = np.linspace(-0.2, 1, 100)

    # Convert once instead of once per threshold.
    tp_arr, fp_arr = np.array(tp_simi), np.array(fp_simi)
    fn_arr, tn_arr = np.array(fn_simi), np.array(tn_simi)
    for th in Thresh:
        # (Precise, Recall) scheme for 1:SN: several similarities are
        # computed and ranked; the matching barcode ranked first is a TP.
        TPX = np.sum(tp_arr >= th)
        FPX = np.sum(fp_arr >= th)
        FNX = np.sum(fn_arr < th)
        TNX = np.sum(tn_arr < th)
        # 1e-6 avoids division by zero when a bucket is empty.
        PPreciseX.append(TPX/(TPX+FPX+1e-6))
        PRecallX.append(TPX/(TPX+FNX+1e-6))
        NPreciseX.append(TNX/(TNX+FNX+1e-6))
        NRecallX.append(TNX/(TNX+FPX+1e-6))

    fig, ax = plt.subplots()
    ax.plot(Thresh, PPreciseX, 'r', label='Precise_Pos: TP/TPFP')
    ax.plot(Thresh, PRecallX, 'b', label='Recall_Pos: TP/TPFN')
    # Legend formulas corrected to match the values actually plotted:
    # NPreciseX = TN/(TN+FN), NRecallX = TN/(TN+FP).
    ax.plot(Thresh, NPreciseX, 'g', label='Precise_Neg: TN/TNFN')
    ax.plot(Thresh, NRecallX, 'c', label='Recall_Neg: TN/TNFP')
    ax.set_xlim([0, 1])
    ax.set_ylim([0, 1])
    ax.set_xticks(np.arange(0, 1, 0.1))
    ax.set_yticks(np.arange(0, 1, 0.1))
    ax.grid(True, linestyle='--')
    ax.set_title('1:SN Precise & Recall')
    ax.set_xlabel(f"Event Num: {len(tp_events) + len(fn_events)}")
    ax.legend()
    # Save before show(): with an interactive backend the figure may be
    # destroyed when its window is closed, so saving afterwards can write
    # an empty image.
    rltpath = os.path.join(similPath, f'pr_1toSN_{simType}.png')
    fig.savefig(rltpath)
    plt.show()

    # ============================= 1:SN histograms
    fig, axes = plt.subplots(2, 2)
    panels = (
        (axes[0, 0], tp_simi, 'TP'),
        (axes[0, 1], fp_simi, 'FP'),
        (axes[1, 0], tn_simi, 'TN'),
        (axes[1, 1], fn_simi, 'FN'),
    )
    for panel, values, tag in panels:
        panel.hist(values, bins=60, range=(-0.2, 1), edgecolor='black')
        panel.set_xlim([-0.2, 1])
        panel.set_title(f'{tag}({len(values)})')
    rltpath = os.path.join(similPath, f'hist_1toSN_{simType}.png')
    fig.savefig(rltpath)
    plt.show()


def one2one_simi(evtList, evtDict, stdDict, simType):
    """Compute event-vs-standard similarities for same and different barcodes.

    For every event one "same" pair (event, own barcode) is built, plus one
    "diff" pair (event, randomly chosen other barcode) when another barcode
    exists.

    Args:
        evtList: list of ``(event_name, barcode)``.
        evtDict: ``{event_name: event}``; each event must expose
            ``feats_compose``.
        stdDict: ``{barcode: standard feature}``.
        simType: ``"typea"`` selects ``calsimi_vs_stdfeat_new``; ``"typeb"``
            is not implemented; anything else selects
            ``calsimi_vs_stdfeat``.

    Returns:
        tuple ``(rltdata, errorFile_one2one)``:
            rltdata: list of ``(label, stdbcd, evtname, s0, s1, s2)`` where
                ``s0, s1, s2`` are the three values returned by the selected
                similarity function; pairs whose first value is ``None``
                are omitted.
            errorFile_one2one: de-duplicated names of events whose
                ``feats_compose`` is empty.

    Raises:
        NotImplementedError: if ``simType == "typeb"`` and at least one
            pair has to be computed.
    """
    barcodes = {bcd for _, bcd in evtList}

    # ====== 1. Build the pairs: scan A put A, scan A put B, then merge ======
    AA_list = [(evtname, barcode, "same") for evtname, barcode in evtList]
    AB_list = []
    for evtname, barcode in evtList:
        # Every barcode comes from evtList, so this is "all other barcodes".
        dset = list(barcodes.symmetric_difference({barcode}))
        if dset:
            AB_list.append((evtname, random.choice(dset), "diff"))

    mergePairs = AA_list + AB_list

    # ====== 2. Similarity between each event and the standard feature ======
    rltdata = []
    errorFile_one2one = []
    for evtname, stdbcd, label in mergePairs:
        event = evtDict[evtname]
        if len(event.feats_compose) == 0:
            errorFile_one2one.append(evtname)
            continue

        stdfeat = stdDict[stdbcd]  # float32

        if simType == "typea":
            simi_mean, simi_1, simi_2 = calsimi_vs_stdfeat_new(event, stdfeat)
        elif simType == "typeb":
            # Previously a bare `pass`, which left simi_mean undefined
            # (NameError) or stale from the previous pair.
            raise NotImplementedError('simType "typeb" is not implemented')
        else:
            simi_mean, simi_1, simi_2 = calsimi_vs_stdfeat(event, stdfeat)

        if simi_mean is None:
            continue
        rltdata.append((label, stdbcd, evtname, simi_mean, simi_1, simi_2))

        # ====== float32 / float16 / int8 precision comparison and storage ======
        # data_precision_compare(stdfeat, evtfeat, mergePairs[i], similPath, save=True)

    errorFile_one2one = list(set(errorFile_one2one))

    return rltdata, errorFile_one2one


def one2one_pr(evtList, evtDict, stdDict, simType="simple"):
    """Plot 1:1 precision/recall curves and similarity histograms.

    Calls :func:`one2one_simi`, then sweeps 100 thresholds over [-0.2, 1].
    "same" pairs at or above the threshold are TP, "diff" pairs below it
    are TN.  For ``simType == "simple"`` the second similarity value of each
    result is used, for ``"typea"`` the first; any other ``simType``
    contributes no samples.

    Two figures are written into the module-level ``similPath`` directory
    (``pr_1to1_{simType}.png`` and ``hist_1to1_{simType}.png``) and shown.

    Args:
        evtList: list of ``(event_name, barcode)``.
        evtDict: ``{event_name: event}``.
        stdDict: ``{barcode: standard feature}``.
        simType: similarity variant, forwarded to :func:`one2one_simi`.
    """
    rltdata, errorFile_one2one = one2one_simi(evtList, evtDict, stdDict, simType)

    Same, Cross = [], []
    for label, stdbcd, evtname, simi_mean, simi_max, simi_mft in rltdata:
        if simType == "simple":
            score = simi_max
        elif simType == "typea":
            score = simi_mean
        else:
            continue
        if label == "same":
            Same.append(score)
        if label == "diff":
            Cross.append(score)

    Same = np.array(Same)
    Cross = np.array(Cross)
    TPFN = len(Same)
    TNFP = len(Cross)

    Recall_Pos, Recall_Neg = [], []
    Precision_Pos, Precision_Neg = [], []
    Correct = []
    Thresh = np.linspace(-0.2, 1, 100)
    for th in Thresh:
        TP = np.sum(Same >= th)
        FN = np.sum(Same < th)      # FN = TPFN - TP

        TN = np.sum(Cross < th)
        FP = np.sum(Cross >= th)    # FP = TNFP - TN

        # 1e-6 avoids division by zero when a bucket is empty.
        Precision_Pos.append(TP/(TP+FP+1e-6))
        Precision_Neg.append(TN/(TN+FN+1e-6))
        Recall_Pos.append(TP/(TP+FN+1e-6))
        Recall_Neg.append(TN/(TN+FP+1e-6))
        Correct.append((TN+TP)/(TPFN+TNFP))

    fig, ax = plt.subplots()
    ax.plot(Thresh, Precision_Pos, 'r', label='Precision_Pos: TP/(TP+FP)')
    ax.plot(Thresh, Recall_Pos, 'b', label='Recall_Pos: TP/TPFN')
    ax.plot(Thresh, Recall_Neg, 'g', label='Recall_Neg: TN/TNFP')
    ax.plot(Thresh, Correct, 'c', label='Correct: (TN+TP)/(TPFN+TNFP)')
    ax.plot(Thresh, Precision_Neg, 'm', label='Precision_Neg: TN/(TN+FN)')

    ax.set_xlim([0, 1])
    ax.set_ylim([0, 1])
    ax.set_xticks(np.arange(0, 1, 0.1))
    ax.set_yticks(np.arange(0, 1, 0.1))
    ax.grid(True, linestyle='--')
    ax.set_title('PrecisePos & PreciseNeg')
    ax.set_xlabel(f"Same Num: {TPFN}, Cross Num: {TNFP}")
    ax.legend()
    # Save before show() (as the histogram below already does): after an
    # interactive window is closed the figure may no longer hold the plot.
    rltpath = os.path.join(similPath, f'pr_1to1_{simType}.png')
    fig.savefig(rltpath)  # svg, png, pdf
    plt.show()

    fig, axes = plt.subplots(2, 1)
    axes[0].hist(Same, bins=60, range=(-0.2, 1), edgecolor='black')
    axes[0].set_xlim([-0.2, 1])
    axes[0].set_title(f'TP({len(Same)})')

    axes[1].hist(Cross, bins=60, range=(-0.2, 1), edgecolor='black')
    axes[1].set_xlim([-0.2, 1])
    axes[1].set_title(f'TN({len(Cross)})')

    rltpath = os.path.join(similPath, f'hist_1to1_{simType}.png')
    fig.savefig(rltpath)
    plt.show()


def test_one2one_one2SN(simType):
    """Run the 1:1 and 1:SN performance evaluations.

    Relies on the module-level path globals assigned in the ``__main__``
    section (read by :func:`build_std_evt_dict` and the plotting functions).

    Args:
        simType: similarity variant forwarded to :func:`one2one_pr` and
            :func:`one2SN_pr`.
    """
    # evtpaths, bcdSet = get_evtList(eventSourcePath)

    # === 1. Run once only: build the standard feature library for the
    #        events; skip if it has already been generated. ===
    # gen_bcd_features(stdSamplePath, stdBarcodePath, stdFeaturePath, eventSourcePath)

    # ==== 2. Build the event dictionary; run once only. ====
    # init_eventDict(eventSourcePath, eventDataPath, source_type)

    # ==== 3. Build the event set from the intersection of the event
    #         barcodes and the standard-library barcodes. ====
    evtList, evtDict, stdDict = build_std_evt_dict()

    one2one_pr(evtList, evtDict, stdDict, simType)

    one2SN_pr(evtList, evtDict, stdDict, simType)


if __name__ == '__main__':
    # Seven paths in total:
    #   (1) stdSamplePath: raw images used to generate the standard feature set
    #   (2) stdBarcodePath: pickle storage of the raw image paths of the
    #       standard feature set, {barcode: [imgpath1, imgpath1, ...]}
    #   (3) stdFeaturePath: where the standard feature set is stored
    #   (4) eventSourcePath: event location; a folder containing data files, or
    #       the parent folder of the pickle files output by Yolo-Resnet-Tracker
    #   (5) resultPath: where results are stored
    #   (6) eventDataPath: shopping events used for the 1:1 comparison,
    #       under resultPath
    #   (7) similPath: 1:1 comparison results (event level), under resultPath
    stdSamplePath = "/home/wqg/dataset/total_barcode/totalBarcode"
    stdBarcodePath = "/home/wqg/dataset/total_barcode/bcdpath"
    stdFeaturePath = "/home/wqg/dataset/test_dataset/total_barcode/features_json/v11_barcode_0304/"

    os.makedirs(stdBarcodePath, exist_ok=True)
    os.makedirs(stdFeaturePath, exist_ok=True)

    # source_type:
    #   "source": eventSourcePath holds pickle files output by Yolo-Resnet-Tracker
    #   "data": the original data-file version split by event
    #   "realtime": data files generated fully in real time
    source_type = 'source'  # 'source', 'data', 'realtime'
    simType = "typea"  # "simple", "typea", "typeb"

    evttype = "single_event_V10"
    # evttype = "single_event_V5"
    # evttype = "performence_V10"
    # evttype = "performence_V5"
    eventSourcePath = "/home/wqg/dataset/pipeline/yrt/{}/shopping_pkl".format(evttype)
    resultPath = "/home/wqg/dataset/pipeline/contrast/{}".format(evttype)

    eventDataPath = os.path.join(resultPath, "evtobjs")
    similPath = os.path.join(resultPath, "simidata")
    os.makedirs(eventDataPath, exist_ok=True)
    os.makedirs(similPath, exist_ok=True)

    test_one2one_one2SN(simType)