回传数据解析,兼容v5和v10
This commit is contained in:
452
contrast/one2one_contrast.py
Normal file
452
contrast/one2one_contrast.py
Normal file
@ -0,0 +1,452 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Created on Fri Aug 30 17:53:03 2024
|
||||
功能:1:1比对性能测试程序
|
||||
1. 基于标准特征集所对应的原始图像样本,生成标准特征集并保存。
|
||||
func: generate_event_and_stdfeatures():
|
||||
(1) get_std_barcodeDict(stdSamplePath, stdBarcodePath)
|
||||
提取 stdSamplePath 中样本地址,生成字典{barcode: [imgpath1, imgpath1, ...]}
|
||||
并存储为 pickle 文件,barcode.pickle'''
|
||||
(2) stdfeat_infer(stdBarcodePath, stdFeaturePath, bcdSet=None)
|
||||
标准特征提取,并保存至文件夹 stdFeaturePath 中,
|
||||
也可在运行过程中根据与购物事件集合 barcodes 交集执行
|
||||
2. 1:1 比对性能测试,
|
||||
func: one2one_simi()
|
||||
(1) 求购物事件和标准特征级 Barcode 交集,构造 evtDict、stdDict
|
||||
(2) 构造扫 A 放 A、扫 A 放 B 组合,mergePairs = AA_list + AB_list
|
||||
(3) 循环计算 mergePairs 中元素 "(A, A) 或 (A, B)" 相似度;
|
||||
对于未保存的轨迹图像或标准 barcode 图像,保存图像
|
||||
(4) 保存计算结果
|
||||
|
||||
|
||||
3. precise、recall等指标计算
|
||||
func: compute_one2one_pr(pickpath)
|
||||
|
||||
|
||||
@author: ym
|
||||
|
||||
"""
|
||||
import numpy as np
|
||||
import os
|
||||
import sys
|
||||
import random
|
||||
import pickle
|
||||
import json
|
||||
|
||||
from pathlib import Path
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
FILE = Path(__file__).resolve()
|
||||
ROOT = FILE.parents[1] # YOLOv5 root directory
|
||||
if str(ROOT) not in sys.path:
|
||||
sys.path.append(str(ROOT))
|
||||
|
||||
from utils.calsimi import calsimi_vs_stdfeat, calsimi_vs_stdfeat_new
|
||||
from utils.tools import get_evtList, init_eventDict
|
||||
from utils.databits import data_precision_compare
|
||||
from genfeats import gen_bcd_features
|
||||
|
||||
|
||||
def build_std_evt_dict():
|
||||
'''
|
||||
stdFeaturePath: 标准特征集地址
|
||||
eventDataPath: Event对象地址
|
||||
'''
|
||||
|
||||
stdBarcode = [p.stem for p in Path(stdFeaturePath).iterdir() if p.is_file() and (p.suffix=='.json' or p.suffix=='.pickle')]
|
||||
|
||||
'''======1. 购物事件列表,该列表中的 Barcode 存在于标准的 stdBarcode 内 ==='''
|
||||
evtList = [(p.stem, p.stem.split('_')[-1]) for p in Path(eventDataPath).iterdir()
|
||||
if p.is_file()
|
||||
and p.suffix=='.pickle'
|
||||
and (len(p.stem.split('_'))==2 or len(p.stem.split('_'))==3)
|
||||
and p.stem.split('_')[-1].isdigit()
|
||||
and p.stem.split('_')[-1] in stdBarcode
|
||||
]
|
||||
barcodes = set([bcd for _, bcd in evtList])
|
||||
|
||||
'''======2. 构建用于比对的标准特征字典 ============='''
|
||||
stdDict = {}
|
||||
for stdfile in os.listdir(stdFeaturePath):
|
||||
barcode, ext = os.path.splitext(stdfile)
|
||||
if barcode not in barcodes:
|
||||
continue
|
||||
stdpath = os.path.join(stdFeaturePath, stdfile)
|
||||
|
||||
if ext == ".json":
|
||||
with open(stdpath, 'r', encoding='utf-8') as f:
|
||||
stddata = json.load(f)
|
||||
feat = np.array(stddata["value"])
|
||||
stdDict[barcode] = feat
|
||||
if ext == ".pickle":
|
||||
with open(stdpath, 'rb') as f:
|
||||
stddata = pickle.load(f)
|
||||
feat = stddata["feats_ft32"]
|
||||
stdDict[barcode] = feat
|
||||
|
||||
'''*********** USearch ***********'''
|
||||
# stdDict = {}
|
||||
# for barcode in barcodes:
|
||||
# stdDict[barcode] = stdlib[barcode]
|
||||
|
||||
'''======3. 构建用于比对的操作事件字典 ============='''
|
||||
evtDict = {}
|
||||
for evtname, barcode in evtList:
|
||||
evtpath = os.path.join(eventDataPath, evtname+'.pickle')
|
||||
try:
|
||||
with open(evtpath, 'rb') as f:
|
||||
evtdata = pickle.load(f)
|
||||
except Exception as e:
|
||||
print(evtname)
|
||||
|
||||
evtDict[evtname] = evtdata
|
||||
|
||||
return evtList, evtDict, stdDict
|
||||
|
||||
def one2SN_pr(evtList, evtDict, stdDict, simType="simple"):
|
||||
|
||||
std_barcodes = set([bcd for _, bcd in evtList])
|
||||
|
||||
|
||||
tp_events, fn_events, fp_events, tn_events = [], [], [], []
|
||||
tp_simi, fn_simi, tn_simi, fp_simi = [], [], [], []
|
||||
errorFile_one2SN = []
|
||||
|
||||
SN = 9
|
||||
for evtname, barcode in evtList:
|
||||
bcd_selected = [barcode]
|
||||
|
||||
dset = list(std_barcodes - set([barcode]))
|
||||
if len(dset) > SN:
|
||||
random.shuffle(dset)
|
||||
bcd_selected.extend(dset[:SN])
|
||||
else:
|
||||
bcd_selected.extend(dset)
|
||||
|
||||
event = evtDict[evtname]
|
||||
## 无轨迹判断
|
||||
if len(event.front_feats)+len(event.back_feats)==0:
|
||||
errorFile_one2SN.append(evtname)
|
||||
print(f"No trajectory: {evtname}")
|
||||
continue
|
||||
|
||||
barcodes, similars = [], []
|
||||
for stdbcd in bcd_selected:
|
||||
stdfeat = stdDict[stdbcd]
|
||||
|
||||
if simType=="typea":
|
||||
simi_mean, simi_max, simi_mfeat = calsimi_vs_stdfeat(event, stdfeat)
|
||||
elif simType=="typeb":
|
||||
pass
|
||||
else:
|
||||
simi_mean, simi_1, simi_2 = calsimi_vs_stdfeat_new(event, stdfeat)
|
||||
|
||||
|
||||
## 在event.front_feats和event.back_feats同时为空时,此处不需要保护
|
||||
# if simi_mean==None:
|
||||
# continue
|
||||
|
||||
barcodes.append(stdbcd)
|
||||
similars.append(simi_mean)
|
||||
|
||||
## 此处不需要保护
|
||||
# if len(similars)==0:
|
||||
# print(evtname)
|
||||
# continue
|
||||
|
||||
max_idx = similars.index(max(similars))
|
||||
max_sim = similars[max_idx]
|
||||
for i in range(len(barcodes)):
|
||||
bcd, simi = barcodes[i], similars[i]
|
||||
if bcd==barcode and simi==max_sim:
|
||||
tp_simi.append(simi)
|
||||
tp_events.append(evtname)
|
||||
elif bcd==barcode and simi!=max_sim:
|
||||
fn_simi.append(simi)
|
||||
fn_events.append(evtname)
|
||||
elif bcd!=barcode and simi!=max_sim:
|
||||
tn_simi.append(simi)
|
||||
tn_events.append(evtname)
|
||||
elif bcd!=barcode and simi==max_sim and barcode in barcodes:
|
||||
fp_simi.append(simi)
|
||||
fp_events.append(evtname)
|
||||
else:
|
||||
errorFile_one2SN.append(evtname)
|
||||
|
||||
PPreciseX, PRecallX = [], []
|
||||
NPreciseX, NRecallX = [], []
|
||||
Thresh = np.linspace(-0.2, 1, 100)
|
||||
for th in Thresh:
|
||||
'''适用于 (Precise, Recall) 计算方式:多个相似度计算并排序,barcode相等且排名第一为 TP '''
|
||||
'''===================================== 1:SN '''
|
||||
TPX = sum(np.array(tp_simi) >= th)
|
||||
FPX = sum(np.array(fp_simi) >= th)
|
||||
FNX = sum(np.array(fn_simi) < th)
|
||||
TNX = sum(np.array(tn_simi) < th)
|
||||
PPreciseX.append(TPX/(TPX+FPX+1e-6))
|
||||
PRecallX.append(TPX/(TPX+FNX+1e-6))
|
||||
|
||||
NPreciseX.append(TNX/(TNX+FNX+1e-6))
|
||||
NRecallX.append(TNX/(TNX+FPX+1e-6))
|
||||
|
||||
fig, ax = plt.subplots()
|
||||
ax.plot(Thresh, PPreciseX, 'r', label='Precise_Pos: TP/TPFP')
|
||||
ax.plot(Thresh, PRecallX, 'b', label='Recall_Pos: TP/TPFN')
|
||||
ax.plot(Thresh, NPreciseX, 'g', label='Precise_Neg: TN/TNFP')
|
||||
ax.plot(Thresh, NRecallX, 'c', label='Recall_Neg: TN/TNFN')
|
||||
ax.set_xlim([0, 1])
|
||||
ax.set_ylim([0, 1])
|
||||
ax.set_xticks(np.arange(0, 1, 0.1))
|
||||
ax.set_yticks(np.arange(0, 1, 0.1))
|
||||
ax.grid(True, linestyle='--')
|
||||
ax.set_title('1:SN Precise & Recall')
|
||||
ax.set_xlabel(f"Event Num: {len(tp_events) + len(fn_events)}")
|
||||
ax.legend()
|
||||
plt.show()
|
||||
|
||||
rltpath = os.path.join(similPath, f'pr_1toSN_{simType}.png')
|
||||
plt.savefig(rltpath)
|
||||
|
||||
## ============================= 1:N 展厅 直方图'''
|
||||
fig, axes = plt.subplots(2, 2)
|
||||
axes[0, 0].hist(tp_simi, bins=60, range=(-0.2, 1), edgecolor='black')
|
||||
axes[0, 0].set_xlim([-0.2, 1])
|
||||
axes[0, 0].set_title(f'TP({len(tp_simi)})')
|
||||
axes[0, 1].hist(fp_simi, bins=60, range=(-0.2, 1), edgecolor='black')
|
||||
axes[0, 1].set_xlim([-0.2, 1])
|
||||
axes[0, 1].set_title(f'FP({len(fp_simi)})')
|
||||
axes[1, 0].hist(tn_simi, bins=60, range=(-0.2, 1), edgecolor='black')
|
||||
axes[1, 0].set_xlim([-0.2, 1])
|
||||
axes[1, 0].set_title(f'TN({len(tn_simi)})')
|
||||
axes[1, 1].hist(fn_simi, bins=60, range=(-0.2, 1), edgecolor='black')
|
||||
axes[1, 1].set_xlim([-0.2, 1])
|
||||
axes[1, 1].set_title(f'FN({len(fn_simi)})')
|
||||
plt.show()
|
||||
|
||||
rltpath = os.path.join(similPath, f'hist_1toSN_{simType}.png')
|
||||
plt.savefig(rltpath)
|
||||
|
||||
|
||||
|
||||
|
||||
def one2one_simi(evtList, evtDict, stdDict, simType):
|
||||
|
||||
barcodes = set([bcd for _, bcd in evtList])
|
||||
'''======1 构造 3 个事件对: 扫 A 放 A, 扫 A 放 B, 合并 ===================='''
|
||||
AA_list = [(evtname, barcode, "same") for evtname, barcode in evtList]
|
||||
AB_list = []
|
||||
for evtname, barcode in evtList:
|
||||
dset = list(barcodes.symmetric_difference(set([barcode])))
|
||||
if len(dset):
|
||||
idx = random.randint(0, len(dset)-1)
|
||||
AB_list.append((evtname, dset[idx], "diff"))
|
||||
|
||||
mergePairs = AA_list + AB_list
|
||||
|
||||
'''======2 计算事件、标准特征集相似度 =================='''
|
||||
rltdata = []
|
||||
errorFile_one2one = []
|
||||
for i in range(len(mergePairs)):
|
||||
evtname, stdbcd, label = mergePairs[i]
|
||||
event = evtDict[evtname]
|
||||
if len(event.feats_compose)==0:
|
||||
errorFile_one2one.append(evtname)
|
||||
|
||||
continue
|
||||
|
||||
stdfeat = stdDict[stdbcd] # float32
|
||||
|
||||
if simType=="typea":
|
||||
simi_mean, simi_1, simi_2 = calsimi_vs_stdfeat_new(event, stdfeat)
|
||||
elif simType=="typeb":
|
||||
pass
|
||||
else:
|
||||
simi_mean, simi_1, simi_2 = calsimi_vs_stdfeat(event, stdfeat)
|
||||
|
||||
if simi_mean is None:
|
||||
continue
|
||||
|
||||
rltdata.append((label, stdbcd, evtname, simi_mean, simi_1, simi_2))
|
||||
|
||||
'''================ float32、16、int8 精度比较与存储 ============='''
|
||||
# data_precision_compare(stdfeat, evtfeat, mergePairs[i], similPath, save=True)
|
||||
|
||||
errorFile_one2one = list(set(errorFile_one2one))
|
||||
|
||||
return rltdata, errorFile_one2one
|
||||
|
||||
|
||||
def one2one_pr(evtList, evtDict, stdDict, simType="simple"):
|
||||
|
||||
rltdata, errorFile_one2one = one2one_simi(evtList, evtDict, stdDict, simType)
|
||||
|
||||
Same, Cross = [], []
|
||||
|
||||
for label, stdbcd, evtname, simi_mean, simi_max, simi_mft in rltdata:
|
||||
if simType=="simple" and label == "same":
|
||||
Same.append(simi_max)
|
||||
if simType=="simple" and label == "diff":
|
||||
Cross.append(simi_max)
|
||||
|
||||
if simType=="typea" and label == "same":
|
||||
Same.append(simi_mean)
|
||||
if simType=="typea" and label == "diff":
|
||||
Cross.append(simi_mean)
|
||||
|
||||
|
||||
# for label, stdbcd, evtname, simi_mean, simi_max, simi_mft in rltdata:
|
||||
# if label == "same":
|
||||
# Same.append(simi_mean)
|
||||
# if label == "diff":
|
||||
# Cross.append(simi_mean)
|
||||
|
||||
Same = np.array(Same)
|
||||
Cross = np.array(Cross)
|
||||
TPFN = len(Same)
|
||||
TNFP = len(Cross)
|
||||
|
||||
# fig, axs = plt.subplots(2, 1)
|
||||
# axs[0].hist(Same, bins=60, range=(-0.2, 1), edgecolor='black')
|
||||
# axs[0].set_xlim([-0.2, 1])
|
||||
# axs[0].set_title(f'Same Barcode, Num: {TPFN}')
|
||||
|
||||
# axs[1].hist(Cross, bins=60, range=(-0.2, 1), edgecolor='black')
|
||||
# axs[1].set_xlim([-0.2, 1])
|
||||
# axs[1].set_title(f'Cross Barcode, Num: {TNFP}')
|
||||
# plt.savefig(f'./result/{file}_hist.png') # svg, png, pdf
|
||||
|
||||
|
||||
Recall_Pos, Recall_Neg = [], []
|
||||
Precision_Pos, Precision_Neg = [], []
|
||||
Correct = []
|
||||
Thresh = np.linspace(-0.2, 1, 100)
|
||||
for th in Thresh:
|
||||
TP = np.sum(Same >= th)
|
||||
FN = np.sum(Same < th)
|
||||
# FN = TPFN - TP
|
||||
|
||||
TN = np.sum(Cross < th)
|
||||
FP = np.sum(Cross >= th)
|
||||
# FP = TNFP - TN
|
||||
|
||||
|
||||
Precision_Pos.append(TP/(TP+FP+1e-6))
|
||||
Precision_Neg.append(TN/(TN+FN+1e-6))
|
||||
Recall_Pos.append(TP/(TP+FN+1e-6))
|
||||
Recall_Neg.append(TN/(TN+FP+1e-6))
|
||||
|
||||
# Recall_Pos.append(TP/TPFN)
|
||||
# Recall_Neg.append(TN/TNFP)
|
||||
|
||||
|
||||
Correct.append((TN+TP)/(TPFN+TNFP))
|
||||
|
||||
fig, ax = plt.subplots()
|
||||
|
||||
ax.plot(Thresh, Precision_Pos, 'r', label='Precision_Pos: TP/(TP+FP)')
|
||||
ax.plot(Thresh, Recall_Pos, 'b', label='Recall_Pos: TP/TPFN')
|
||||
ax.plot(Thresh, Recall_Neg, 'g', label='Recall_Neg: TN/TNFP')
|
||||
ax.plot(Thresh, Correct, 'c', label='Correct: (TN+TP)/(TPFN+TNFP)')
|
||||
ax.plot(Thresh, Precision_Neg, 'm', label='Precision_Neg: TN/(TN+FN)')
|
||||
|
||||
ax.set_xlim([0, 1])
|
||||
ax.set_ylim([0, 1])
|
||||
|
||||
ax.set_xticks(np.arange(0, 1, 0.1))
|
||||
ax.set_yticks(np.arange(0, 1, 0.1))
|
||||
ax.grid(True, linestyle='--')
|
||||
|
||||
ax.set_title('PrecisePos & PreciseNeg')
|
||||
ax.set_xlabel(f"Same Num: {TPFN}, Cross Num: {TNFP}")
|
||||
ax.legend()
|
||||
plt.show()
|
||||
|
||||
rltpath = os.path.join(similPath, f'pr_1to1_{simType}.png')
|
||||
plt.savefig(rltpath) # svg, png, pdf
|
||||
|
||||
|
||||
fig, axes = plt.subplots(2,1)
|
||||
axes[0].hist(Same, bins=60, range=(-0.2, 1), edgecolor='black')
|
||||
axes[0].set_xlim([-0.2, 1])
|
||||
axes[0].set_title(f'TP({len(Same)})')
|
||||
|
||||
axes[1].hist(Cross, bins=60, range=(-0.2, 1), edgecolor='black')
|
||||
axes[1].set_xlim([-0.2, 1])
|
||||
axes[1].set_title(f'TN({len(Cross)})')
|
||||
|
||||
rltpath = os.path.join(similPath, f'hist_1to1_{simType}.png')
|
||||
plt.savefig(rltpath)
|
||||
|
||||
|
||||
plt.show()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def test_one2one_one2SN(simType):
|
||||
'''1:1性能评估'''
|
||||
|
||||
# evtpaths, bcdSet = get_evtList(eventSourcePath)
|
||||
|
||||
'''=== 1. 只需运行一次,生成事件对应的标准特征库字典,如已生成,无需运行 ===='''
|
||||
# gen_bcd_features(stdSamplePath, stdBarcodePath, stdFeaturePath, eventSourcePath)
|
||||
|
||||
'''==== 2. 生成事件字典, 只需运行一次 ===================='''
|
||||
# init_eventDict(eventSourcePath, eventDataPath, source_type)
|
||||
|
||||
'''==== 3. 基于事件barcode集和标准库barcode交集构造事件集合 ========='''
|
||||
evtList, evtDict, stdDict = build_std_evt_dict()
|
||||
|
||||
one2one_pr(evtList, evtDict, stdDict, simType)
|
||||
|
||||
one2SN_pr(evtList, evtDict, stdDict, simType)
|
||||
|
||||
if __name__ == '__main__':
|
||||
'''
|
||||
共7个地址:
|
||||
(1) stdSamplePath: 用于生成比对标准特征集的原始图像地址
|
||||
(2) stdBarcodePath: 比对标准特征集原始图像地址的pickle文件存储,{barcode: [imgpath1, imgpath1, ...]}
|
||||
(3) stdFeaturePath: 比对标准特征集特征存储地址
|
||||
(4) eventSourcePath: 事件地址, 包含data文件的文件夹或 Yolo-Resnet-Tracker输出的Pickle文件父文件夹
|
||||
(5) resultPath: 结果存储地址
|
||||
(6) eventDataPath: 用于1:1比对的购物事件存储地址,在resultPath下
|
||||
(7) similPath: 1:1比对结果存储地址(事件级),在resultPath下
|
||||
'''
|
||||
|
||||
stdSamplePath = "/home/wqg/dataset/total_barcode/totalBarcode"
|
||||
stdBarcodePath = "/home/wqg/dataset/total_barcode/bcdpath"
|
||||
stdFeaturePath = "/home/wqg/dataset/test_dataset/total_barcode/features_json/v11_barcode_0304/"
|
||||
|
||||
if not os.path.exists(stdBarcodePath):
|
||||
os.makedirs(stdBarcodePath)
|
||||
if not os.path.exists(stdFeaturePath):
|
||||
os.makedirs(stdFeaturePath)
|
||||
|
||||
'''source_type:
|
||||
"source": eventSourcePath 为 Yolo-Resnet-Tracker 输出的 pickle 文件
|
||||
"data": 基于事件切分的原 data 文件版本
|
||||
"realtime": 全实时生成的 data 文件
|
||||
'''
|
||||
source_type = 'source' # 'source', 'data', 'realtime'
|
||||
simType = "typea" # "simple", "typea", "typeb"
|
||||
|
||||
evttype = "single_event_V10"
|
||||
# evttype = "single_event_V5"
|
||||
# evttype = "performence_V10"
|
||||
# evttype = "performence_V5"
|
||||
eventSourcePath = "/home/wqg/dataset/pipeline/yrt/{}/shopping_pkl".format(evttype)
|
||||
|
||||
resultPath = "/home/wqg/dataset/pipeline/contrast/{}".format(evttype)
|
||||
eventDataPath = os.path.join(resultPath, "evtobjs")
|
||||
similPath = os.path.join(resultPath, "simidata")
|
||||
if not os.path.exists(eventDataPath):
|
||||
os.makedirs(eventDataPath)
|
||||
if not os.path.exists(similPath):
|
||||
os.makedirs(similPath)
|
||||
|
||||
test_one2one_one2SN(simType)
|
||||
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user