Files
detecttracking/contrast/one2one_contrast.py
2025-04-11 17:02:39 +08:00

452 lines
16 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
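"""
Created on Fri Aug 30 17:53:03 2024

Purpose: performance test program for 1:1 comparison.

1. Generate and save the standard feature set from the raw image samples
   that correspond to it.
   func: generate_event_and_stdfeatures():
   (1) get_std_barcodeDict(stdSamplePath, stdBarcodePath)
       Collect the sample paths under stdSamplePath, build the dict
       {barcode: [imgpath1, imgpath2, ...]} and store it as the pickle
       file barcode.pickle.
   (2) stdfeat_infer(stdBarcodePath, stdFeaturePath, bcdSet=None)
       Extract the standard features and save them into the folder
       stdFeaturePath; this can also be run on the fly, restricted to the
       intersection with the barcode set of the shopping events.

2. 1:1 comparison performance test.
   func: one2one_simi()
   (1) Intersect the barcodes of the shopping events with those of the
       standard feature set and build evtDict and stdDict.
   (2) Build the "scan A, put A" and "scan A, put B" pairs:
       mergePairs = AA_list + AB_list
   (3) Loop over mergePairs and compute the similarity of every
       "(A, A) or (A, B)" element; save any trajectory image or standard
       barcode image that has not been saved yet.
   (4) Save the computed results.

3. Compute precision, recall and related metrics.
   func: compute_one2one_pr(pickpath)

@author: ym
"""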
import numpy as np
import os
import sys
import random
import pickle
import json
from pathlib import Path
import matplotlib.pyplot as plt
FILE = Path(__file__).resolve()
ROOT = FILE.parents[1] # YOLOv5 root directory
if str(ROOT) not in sys.path:
sys.path.append(str(ROOT))
from utils.calsimi import calsimi_vs_stdfeat, calsimi_vs_stdfeat_new
from utils.tools import get_evtList, init_eventDict
from utils.databits import data_precision_compare
from genfeats import gen_bcd_features
def build_std_evt_dict(std_feature_path=None, event_data_path=None):
    """Build the event list, event dict and standard-feature dict for comparison.

    Args:
        std_feature_path: Folder of the standard feature set, one file per
            barcode named ``<barcode>.json`` or ``<barcode>.pickle``.
            Defaults to the module-level ``stdFeaturePath``.
        event_data_path: Folder of the pickled Event objects. Only files
            whose stem has 2 or 3 ``_``-separated parts and ends with an
            all-digit barcode are considered. Defaults to the module-level
            ``eventDataPath``.

    Returns:
        tuple: ``(evtList, evtDict, stdDict)``.
            * ``evtList`` -- list of ``(event_name, barcode)`` whose barcode
              exists in the standard feature set and whose pickle loaded.
            * ``evtDict`` -- ``{event_name: unpickled event object}``.
            * ``stdDict`` -- ``{barcode: features}``; the ``"value"`` entry
              (as ``np.ndarray``) for JSON files, the ``"feats_ft32"`` entry
              for pickle files.
    """
    if std_feature_path is None:
        std_feature_path = stdFeaturePath
    if event_data_path is None:
        event_data_path = eventDataPath

    # A set makes the per-event membership test below O(1).
    std_barcodes = {
        p.stem
        for p in Path(std_feature_path).iterdir()
        if p.is_file() and p.suffix in ('.json', '.pickle')
    }

    # ---- 1. shopping events whose barcode is part of the standard set ----
    candidates = []
    for p in Path(event_data_path).iterdir():
        if not (p.is_file() and p.suffix == '.pickle'):
            continue
        parts = p.stem.split('_')
        barcode = parts[-1]
        if len(parts) in (2, 3) and barcode.isdigit() and barcode in std_barcodes:
            candidates.append((p.stem, barcode))
    barcodes = {bcd for _, bcd in candidates}

    # ---- 2. standard-feature dict, restricted to barcodes seen in events ----
    stdDict = {}
    for stdfile in os.listdir(std_feature_path):
        barcode, ext = os.path.splitext(stdfile)
        if barcode not in barcodes:
            continue
        stdpath = os.path.join(std_feature_path, stdfile)
        if ext == ".json":
            with open(stdpath, 'r', encoding='utf-8') as f:
                stddata = json.load(f)
            stdDict[barcode] = np.array(stddata["value"])
        elif ext == ".pickle":
            # NOTE: pickle can execute arbitrary code; only load trusted files.
            with open(stdpath, 'rb') as f:
                stddata = pickle.load(f)
            stdDict[barcode] = stddata["feats_ft32"]

    # ---- 3. event dict -------------------------------------------------
    evtList, evtDict = [], {}
    for evtname, barcode in candidates:
        evtpath = os.path.join(event_data_path, evtname + '.pickle')
        try:
            with open(evtpath, 'rb') as f:
                evtdata = pickle.load(f)
        except Exception:
            # Best effort: report the unreadable event and skip it. Storing
            # anything here would either reuse the previous event's data
            # under the wrong name or fail with NameError on the first event.
            print(evtname)
            continue
        evtList.append((evtname, barcode))
        evtDict[evtname] = evtdata

    return evtList, evtDict, stdDict
def one2SN_pr(evtList, evtDict, stdDict, simType="simple"):
    """Evaluate 1:SN matching and save precision/recall and histogram plots.

    For every event its own barcode plus up to ``SN`` randomly chosen other
    barcodes are compared against the event. The candidate with the highest
    similarity is treated as the prediction, and each (event, candidate)
    pair is classified as TP / FN / TN / FP. Precision and recall are then
    swept over 100 thresholds in [-0.2, 1].

    Args:
        evtList: list of ``(event_name, barcode)`` tuples.
        evtDict: ``{event_name: event}``; events must expose ``front_feats``
            and ``back_feats``.
        stdDict: ``{barcode: standard features}``.
        simType: ``"typea"`` uses ``calsimi_vs_stdfeat``, any other value
            except ``"typeb"`` uses ``calsimi_vs_stdfeat_new``.

    Returns:
        None. Writes ``pr_1toSN_<simType>.png`` and
        ``hist_1toSN_<simType>.png`` into the module-level ``similPath``
        and shows both figures.

    Raises:
        NotImplementedError: if ``simType == "typeb"`` and at least one
            event has to be scored.
    """
    std_barcodes = {bcd for _, bcd in evtList}
    tp_events, fn_events, fp_events, tn_events = [], [], [], []
    tp_simi, fn_simi, tn_simi, fp_simi = [], [], [], []
    errorFile_one2SN = []
    SN = 9  # maximum number of non-matching barcodes compared per event

    for evtname, barcode in evtList:
        bcd_selected = [barcode]
        dset = list(std_barcodes - {barcode})
        if len(dset) > SN:
            random.shuffle(dset)
            bcd_selected.extend(dset[:SN])
        else:
            bcd_selected.extend(dset)

        event = evtDict[evtname]
        # Events without any trajectory feature cannot be scored.
        if len(event.front_feats) + len(event.back_feats) == 0:
            errorFile_one2SN.append(evtname)
            print(f"No trajectory: {evtname}")
            continue

        barcodes, similars = [], []
        for stdbcd in bcd_selected:
            stdfeat = stdDict[stdbcd]
            # NOTE(review): one2one_simi maps "typea" to calsimi_vs_stdfeat_new
            # and the default to calsimi_vs_stdfeat, i.e. the opposite of this
            # function -- confirm which mapping is intended.
            if simType == "typea":
                simi_mean, simi_max, simi_mfeat = calsimi_vs_stdfeat(event, stdfeat)
            elif simType == "typeb":
                # Previously a bare `pass`, which reused a stale (or undefined)
                # simi_mean from an earlier iteration.
                raise NotImplementedError('simType "typeb" is not implemented')
            else:
                simi_mean, simi_1, simi_2 = calsimi_vs_stdfeat_new(event, stdfeat)
            barcodes.append(stdbcd)
            similars.append(simi_mean)

        max_sim = max(similars)
        for bcd, simi in zip(barcodes, similars):
            is_target = bcd == barcode
            is_top = simi == max_sim
            if is_target and is_top:
                tp_simi.append(simi)
                tp_events.append(evtname)
            elif is_target:
                fn_simi.append(simi)
                fn_events.append(evtname)
            elif not is_top:
                tn_simi.append(simi)
                tn_events.append(evtname)
            elif barcode in barcodes:
                fp_simi.append(simi)
                fp_events.append(evtname)
            else:
                errorFile_one2SN.append(evtname)

    # Convert once instead of once per threshold.
    tp_arr = np.array(tp_simi)
    fp_arr = np.array(fp_simi)
    fn_arr = np.array(fn_simi)
    tn_arr = np.array(tn_simi)

    PPreciseX, PRecallX = [], []
    NPreciseX, NRecallX = [], []
    Thresh = np.linspace(-0.2, 1, 100)
    for th in Thresh:
        # 1:SN rule: same barcode and ranked first counts as TP.
        TPX = np.sum(tp_arr >= th)
        FPX = np.sum(fp_arr >= th)
        FNX = np.sum(fn_arr < th)
        TNX = np.sum(tn_arr < th)
        # 1e-6 keeps the ratios defined when a denominator is empty.
        PPreciseX.append(TPX/(TPX+FPX+1e-6))
        PRecallX.append(TPX/(TPX+FNX+1e-6))
        NPreciseX.append(TNX/(TNX+FNX+1e-6))
        NRecallX.append(TNX/(TNX+FPX+1e-6))

    fig, ax = plt.subplots()
    ax.plot(Thresh, PPreciseX, 'r', label='Precise_Pos: TP/TPFP')
    ax.plot(Thresh, PRecallX, 'b', label='Recall_Pos: TP/TPFN')
    # Legend formulas now match what is computed above (they were swapped).
    ax.plot(Thresh, NPreciseX, 'g', label='Precise_Neg: TN/TNFN')
    ax.plot(Thresh, NRecallX, 'c', label='Recall_Neg: TN/TNFP')
    ax.set_xlim([0, 1])
    ax.set_ylim([0, 1])
    ax.set_xticks(np.arange(0, 1, 0.1))
    ax.set_yticks(np.arange(0, 1, 0.1))
    ax.grid(True, linestyle='--')
    ax.set_title('1:SN Precise & Recall')
    ax.set_xlabel(f"Event Num: {len(tp_events) + len(fn_events)}")
    ax.legend()
    # Save before show(): with a blocking backend the figure may already be
    # closed once show() returns, and savefig would then write a blank image.
    rltpath = os.path.join(similPath, f'pr_1toSN_{simType}.png')
    fig.savefig(rltpath)
    plt.show()

    # ---- 1:SN similarity histograms ----
    fig, axes = plt.subplots(2, 2)
    panels = (
        (axes[0, 0], tp_simi, 'TP'),
        (axes[0, 1], fp_simi, 'FP'),
        (axes[1, 0], tn_simi, 'TN'),
        (axes[1, 1], fn_simi, 'FN'),
    )
    for panel, values, tag in panels:
        panel.hist(values, bins=60, range=(-0.2, 1), edgecolor='black')
        panel.set_xlim([-0.2, 1])
        panel.set_title(f'{tag}({len(values)})')
    rltpath = os.path.join(similPath, f'hist_1toSN_{simType}.png')
    fig.savefig(rltpath)
    plt.show()
def one2one_simi(evtList, evtDict, stdDict, simType):
    """Compute event-vs-standard-feature similarities for 1:1 comparison.

    Two kinds of pairs are built: "scan A, put A" (every event against its
    own barcode, label ``"same"``) and "scan A, put B" (every event against
    one randomly chosen other barcode, label ``"diff"``).

    Args:
        evtList: list of ``(event_name, barcode)`` tuples.
        evtDict: ``{event_name: event}``; events must expose
            ``feats_compose``.
        stdDict: ``{barcode: standard features}``.
        simType: ``"typea"`` uses ``calsimi_vs_stdfeat_new``, any other
            value except ``"typeb"`` uses ``calsimi_vs_stdfeat``.

    Returns:
        tuple: ``(rltdata, errorFile_one2one)`` where ``rltdata`` is a list
        of ``(label, std_barcode, event_name, simi_mean, simi_1, simi_2)``
        and ``errorFile_one2one`` is the de-duplicated list of event names
        that have no composed features.

    Raises:
        NotImplementedError: if ``simType == "typeb"`` and at least one
            pair has to be scored.
    """
    barcodes = {bcd for _, bcd in evtList}

    # ---- 1. build the pairs: scan A put A, scan A put B, then merge ----
    AA_list = [(evtname, barcode, "same") for evtname, barcode in evtList]
    AB_list = []
    for evtname, barcode in evtList:
        # barcode is always a member of barcodes, so a plain difference gives
        # the same result as the former symmetric_difference.
        dset = list(barcodes - {barcode})
        if dset:
            AB_list.append((evtname, random.choice(dset), "diff"))
    mergePairs = AA_list + AB_list

    # ---- 2. similarity between each event and the paired standard features ----
    rltdata = []
    errorFile_one2one = []
    for evtname, stdbcd, label in mergePairs:
        event = evtDict[evtname]
        if len(event.feats_compose) == 0:
            errorFile_one2one.append(evtname)
            continue

        stdfeat = stdDict[stdbcd]  # float32
        if simType == "typea":
            simi_mean, simi_1, simi_2 = calsimi_vs_stdfeat_new(event, stdfeat)
        elif simType == "typeb":
            # Previously a bare `pass`: the similarities of the preceding pair
            # were silently recorded again (or NameError on the first pair).
            raise NotImplementedError('simType "typeb" is not implemented')
        else:
            simi_mean, simi_1, simi_2 = calsimi_vs_stdfeat(event, stdfeat)

        if simi_mean is None:
            continue
        rltdata.append((label, stdbcd, evtname, simi_mean, simi_1, simi_2))

    # An event without features is hit once per pair it appears in.
    return rltdata, list(set(errorFile_one2one))
def one2one_pr(evtList, evtDict, stdDict, simType="simple"):
    """Plot precision/recall curves and score histograms for 1:1 comparison.

    Similarities come from ``one2one_simi``. For ``simType == "simple"`` the
    second similarity of each record is used as the score, for
    ``simType == "typea"`` the first (mean) one; records are ignored for any
    other ``simType``. Scores of ``"same"`` pairs and ``"diff"`` pairs are
    thresholded over 100 values in [-0.2, 1].

    Args:
        evtList: list of ``(event_name, barcode)`` tuples.
        evtDict: ``{event_name: event}``.
        stdDict: ``{barcode: standard features}``.
        simType: similarity variant, forwarded to ``one2one_simi``.

    Returns:
        None. Writes ``pr_1to1_<simType>.png`` and
        ``hist_1to1_<simType>.png`` into the module-level ``similPath`` and
        shows both figures.
    """
    rltdata, errorFile_one2one = one2one_simi(evtList, evtDict, stdDict, simType)

    Same, Cross = [], []
    for label, stdbcd, evtname, simi_mean, simi_max, simi_mft in rltdata:
        if simType == "simple":
            score = simi_max
        elif simType == "typea":
            score = simi_mean
        else:
            continue
        if label == "same":
            Same.append(score)
        elif label == "diff":
            Cross.append(score)

    Same = np.array(Same)
    Cross = np.array(Cross)
    TPFN = len(Same)
    TNFP = len(Cross)
    total = TPFN + TNFP

    Recall_Pos, Recall_Neg = [], []
    Precision_Pos, Precision_Neg = [], []
    Correct = []
    Thresh = np.linspace(-0.2, 1, 100)
    for th in Thresh:
        TP = np.sum(Same >= th)
        FN = np.sum(Same < th)
        TN = np.sum(Cross < th)
        FP = np.sum(Cross >= th)
        # 1e-6 keeps the ratios defined when a denominator is empty.
        Precision_Pos.append(TP/(TP+FP+1e-6))
        Precision_Neg.append(TN/(TN+FN+1e-6))
        Recall_Pos.append(TP/(TP+FN+1e-6))
        Recall_Neg.append(TN/(TN+FP+1e-6))
        # Guard the empty case, which used to divide by zero.
        Correct.append((TN+TP)/total if total else 0.0)

    fig, ax = plt.subplots()
    ax.plot(Thresh, Precision_Pos, 'r', label='Precision_Pos: TP/(TP+FP)')
    ax.plot(Thresh, Recall_Pos, 'b', label='Recall_Pos: TP/TPFN')
    ax.plot(Thresh, Recall_Neg, 'g', label='Recall_Neg: TN/TNFP')
    ax.plot(Thresh, Correct, 'c', label='Correct: (TN+TP)/(TPFN+TNFP)')
    ax.plot(Thresh, Precision_Neg, 'm', label='Precision_Neg: TN/(TN+FN)')
    ax.set_xlim([0, 1])
    ax.set_ylim([0, 1])
    ax.set_xticks(np.arange(0, 1, 0.1))
    ax.set_yticks(np.arange(0, 1, 0.1))
    ax.grid(True, linestyle='--')
    ax.set_title('PrecisePos & PreciseNeg')
    ax.set_xlabel(f"Same Num: {TPFN}, Cross Num: {TNFP}")
    ax.legend()
    # Save before show(): with a blocking backend the figure may already be
    # closed once show() returns, and savefig would then write a blank image.
    rltpath = os.path.join(similPath, f'pr_1to1_{simType}.png')
    fig.savefig(rltpath)  # svg, png, pdf
    plt.show()

    fig, axes = plt.subplots(2, 1)
    axes[0].hist(Same, bins=60, range=(-0.2, 1), edgecolor='black')
    axes[0].set_xlim([-0.2, 1])
    axes[0].set_title(f'TP({len(Same)})')
    axes[1].hist(Cross, bins=60, range=(-0.2, 1), edgecolor='black')
    axes[1].set_xlim([-0.2, 1])
    axes[1].set_title(f'TN({len(Cross)})')
    rltpath = os.path.join(similPath, f'hist_1to1_{simType}.png')
    fig.savefig(rltpath)
    plt.show()
def test_one2one_one2SN(simType):
    """Run the 1:1 and then the 1:SN performance evaluation for ``simType``."""
    # evtpaths, bcdSet = get_evtList(eventSourcePath)

    # Step 1 (run once only): generate the standard feature library for the
    # events; skip if it has already been generated.
    # gen_bcd_features(stdSamplePath, stdBarcodePath, stdFeaturePath, eventSourcePath)

    # Step 2 (run once only): generate the event dictionary.
    # init_eventDict(eventSourcePath, eventDataPath, source_type)

    # Step 3: build the event set from the intersection of the event barcodes
    # and the standard-library barcodes, then feed it to both evaluations.
    dataset = build_std_evt_dict()
    one2one_pr(*dataset, simType)
    one2SN_pr(*dataset, simType)
if __name__ == '__main__':
    # Seven paths are involved:
    #   (1) stdSamplePath:   raw images used to generate the standard feature set
    #   (2) stdBarcodePath:  pickle storage of the raw-image paths of the standard
    #                        feature set, {barcode: [imgpath1, imgpath2, ...]}
    #   (3) stdFeaturePath:  storage of the standard feature set features
    #   (4) eventSourcePath: events; a folder containing data files, or the parent
    #                        folder of the pickle files output by Yolo-Resnet-Tracker
    #   (5) resultPath:      result storage
    #   (6) eventDataPath:   shopping events used for 1:1 comparison, under resultPath
    #   (7) similPath:       1:1 comparison results (event level), under resultPath
    stdSamplePath = "/home/wqg/dataset/total_barcode/totalBarcode"
    stdBarcodePath = "/home/wqg/dataset/total_barcode/bcdpath"
    stdFeaturePath = "/home/wqg/dataset/test_dataset/total_barcode/features_json/v11_barcode_0304/"

    # source_type:
    #   "source":   eventSourcePath holds pickle files output by Yolo-Resnet-Tracker
    #   "data":     original data-file version based on event splitting
    #   "realtime": data files generated fully in real time
    source_type = 'source'   # 'source', 'data', 'realtime'
    simType = "typea"        # "simple", "typea", "typeb"

    evttype = "single_event_V10"
    # evttype = "single_event_V5"
    # evttype = "performence_V10"
    # evttype = "performence_V5"

    eventSourcePath = "/home/wqg/dataset/pipeline/yrt/{}/shopping_pkl".format(evttype)
    resultPath = "/home/wqg/dataset/pipeline/contrast/{}".format(evttype)
    eventDataPath = os.path.join(resultPath, "evtobjs")
    similPath = os.path.join(resultPath, "simidata")

    # Create every folder the run writes to or reads from, if missing.
    for folder in (stdBarcodePath, stdFeaturePath, eventDataPath, similPath):
        if not os.path.exists(folder):
            os.makedirs(folder)

    test_one2one_one2SN(simType)