Files
detecttracking/contrast/one2one_contrast.py
2024-12-10 19:01:54 +08:00

541 lines
20 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
Created on Fri Aug 30 17:53:03 2024
功能1:1比对性能测试程序
1. 基于标准特征集所对应的原始图像样本,生成标准特征集并保存。
func: gen_bcd_features(stdSamplePath, stdBarcodePath, stdFeaturePath, bcdSet),在 genfeats.py 中实现:
(1) get_std_barcodeDict(stdSamplePath, stdBarcodePath)
提取 stdSamplePath 中样本地址,生成字典{barcode: [imgpath1, imgpath2, ...]}
并存储为 pickle 文件 barcode.pickle
(2) stdfeat_infer(stdBarcodePath, stdFeaturePath, bcdSet=None)
标准特征提取,并保存至文件夹 stdFeaturePath 中,
也可在运行过程中根据与购物事件集合 barcodes 交集执行
2. 1:1 比对性能测试,
func: one2one_simi()
(1) 求购物事件和标准特征集 Barcode 交集,构造 evtDict、stdDict
(2) 构造扫 A 放 A、扫 A 放 B 组合mergePairs = AA_list + AB_list
(3) 循环计算 mergePairs 中元素 "(A, A) 或 (A, B)" 相似度;
对于未保存的轨迹图像或标准 barcode 图像,保存图像
(4) 保存计算结果
3. precise、recall等指标计算
func: compute_precise_recall(rltdata)
@author: ym
"""
import numpy as np
import cv2
import os
import sys
import random
import pickle
import json
# import torch
import time
# import json
from pathlib import Path
from scipy.spatial.distance import cdist
import matplotlib.pyplot as plt
import shutil
from datetime import datetime
# from openpyxl import load_workbook, Workbook
# from config import config as conf
# from model import resnet18 as resnet18
# from feat_inference import inference_image
sys.path.append(r"D:\DetectTracking")
from tracking.utils.read_data import extract_data, read_tracking_output, read_similar, read_deletedBarcode_file
from tracking.utils.plotting import Annotator, colors
from feat_extract.config import config as conf
from feat_extract.inference import FeatsInterface
from utils.event import ShoppingEvent
from genfeats import gen_bcd_features
def int8_to_ft16(arr_uint8, amin, amax):
    """Dequantize uint8 codes back to float16 over the range [amin, amax].

    Code 0 maps to ``amin`` and code 255 maps to ``amax``; intermediate codes
    are spaced linearly in between.
    """
    span = amax - amin
    restored = arr_uint8 / 255 * span + amin
    return restored.astype(np.float16)
def ft16_to_uint8(arr_ft16):
    """Quantize an array to uint8 using its own min/max as the range.

    The array minimum maps to 0 and the maximum to 255; intermediate values
    are truncated (``astype``), not rounded.

    Args:
        arr_ft16: numeric numpy array (float16 in the original use case).

    Returns:
        (arr_uint8, arr_ft16_): the uint8 codes and their float16
        reconstruction produced by ``int8_to_ft16``.

    Raises:
        ValueError: if ``arr_ft16`` is empty (raised by ``np.min``).
    """
    amin = np.min(arr_ft16)
    amax = np.max(arr_ft16)
    span = amax - amin
    if span == 0:
        # Constant input: (x - amin) / span would be 0/0 = NaN, whose uint8
        # cast is undefined. Map everything to code 0, which dequantizes
        # back to amin.
        arr_uint8 = np.zeros_like(arr_ft16, dtype=np.uint8)
    else:
        arr_uint8 = ((arr_ft16 - amin) * 255 / span).astype(np.uint8)
    arr_ft16_ = int8_to_ft16(arr_uint8, amin, amax)
    return arr_uint8, arr_ft16_
def _group_boxes_by_frame(bboxes):
    """Group tracker box rows by their frame index (column 7).

    Rows are laid out as
    ``[x1, y1, x2, y2, track_id, score, cls, frame_index, box_index]``.
    Returns a list of ``(frame_id, rows)`` in ascending ``frame_id`` order.
    """
    frame_ids = bboxes[:, 7].astype(int)
    return [(f_id, bboxes[frame_ids == f_id, :]) for f_id in np.unique(frame_ids)]


def plot_save_image(event, savepath):
    """Draw all tracker boxes of an event on their frames and save the frames.

    For the front camera and then the back camera, the rows of
    ``event.<camera>_trackerboxes`` are grouped by frame index. Each frame
    ``<camera>_imgpaths[frame_index - 1]`` is read, annotated with one labelled
    rectangle per box, and written to ``savepath`` under the frame's original
    file name.

    Labels read "(track_id, cls)". Colours: tracked boxes (track_id >= 0) of
    class 0 use palette entry 0; other tracked boxes are coloured by
    track_id; untracked boxes use palette entry 19 (the last one).

    Args:
        event: object exposing ``front_trackerboxes`` / ``back_trackerboxes``
            (N x 9 arrays) and ``front_imgpaths`` / ``back_imgpaths``.
        savepath: existing output directory.

    Raises:
        FileNotFoundError: if a referenced frame image cannot be read.
    """
    for camera in ('front', 'back'):
        boxes = getattr(event, f'{camera}_trackerboxes')
        imgpaths = getattr(event, f'{camera}_imgpaths')

        for frame_id, frame_boxes in _group_boxes_by_frame(boxes):
            # Frame indices are 1-based; imgpaths is a 0-based list.
            imgpath = imgpaths[int(frame_id - 1)]
            image = cv2.imread(imgpath)
            if image is None:
                raise FileNotFoundError(f"cannot read image: {imgpath}")
            annotator = Annotator(image.copy(), line_width=2)
            # Unpack each 9-column row directly so the columns map onto names.
            for *xyxy, tid, _score, cls, _fid, _bid in frame_boxes:
                label = f'{int(tid), int(cls)}'
                if tid >= 0 and cls == 0:
                    color = colors(int(cls), True)
                elif tid >= 0:
                    color = colors(int(tid), True)
                else:
                    color = colors(19, True)  # 19 is the last palette entry
                annotator.box_label(xyxy, label, color=color)
            spath = os.path.join(savepath, Path(imgpath).name)
            cv2.imwrite(spath, annotator.result())
def save_event_subimg(event, savepath):
    """Save the cropped track sub-images of one shopping event.

    Boxes are taken from ``event.front_boxes`` and then ``event.back_boxes``;
    each row is ``x1, y1, x2, y2, track_id, score, cls, frame_index, box_index``.
    Front-camera crops are written first, then back-camera crops. Within a
    camera the running index ``i`` follows the row order of the boxes,
    presumably the same order as ``feats_compose``; TODO confirm.

    Box coordinates are halved before cropping. This presumably means the
    boxes are expressed at twice the stored image resolution; TODO confirm.

    Output name: ``cam{cameraType}_{i}_tid{track_id}_fid({frame_index}, {frameID}).png``,
    where cameraType and frameID are the 1st and 4th ``_``-separated fields
    of the source image's base name.

    Raises:
        FileNotFoundError: if a source frame image cannot be read.
        ValueError: if an image base name does not split into 4 ``_`` fields.
    """
    for camera in ('front', 'back'):
        boxes = getattr(event, f'{camera}_boxes')
        imgpaths = getattr(event, f'{camera}_imgpaths')

        for i, box in enumerate(boxes):
            x1, y1, x2, y2, tid, score, cls, fid, bid = box
            # Frame indices are 1-based; imgpaths is a 0-based list.
            imgpath = imgpaths[int(fid - 1)]
            image = cv2.imread(imgpath)
            if image is None:
                raise FileNotFoundError(f"cannot read image: {imgpath}")
            # Clamp start indices at 0. A box extending past the top/left edge
            # would otherwise give a negative start, which numpy treats as
            # counting from the end and which yields a wrong or empty crop.
            r0, r1 = max(0, int(y1 / 2)), int(y2 / 2)
            c0, c1 = max(0, int(x1 / 2)), int(x2 / 2)
            subimg = image[r0:r1, c0:c1, :]
            camerType, _timestamp, _, frameID = os.path.basename(imgpath).split('.')[0].split('_')
            subimgName = f"cam{camerType}_{i}_tid{int(tid)}_fid({int(fid)}, {frameID}).png"
            cv2.imwrite(os.path.join(savepath, subimgName), subimg)

    print(f"Image saved: {os.path.basename(event.eventpath)}")
def _simi_stats(stdfeat, evtfeat):
    """Return (mean, max, mean-feature) cosine similarity of two feature sets.

    ``mean`` and ``max`` are taken over the std x evt pairwise similarity
    matrix. The mean-feature score compares the averaged std feature with
    the averaged event feature; its cosine distance is clipped at 0, so the
    score is capped at 1.
    """
    matrix = 1 - cdist(stdfeat, evtfeat, 'cosine')
    stdfeatm = np.mean(stdfeat, axis=0, keepdims=True)
    evtfeatm = np.mean(evtfeat, axis=0, keepdims=True)
    simi_mfeat = 1 - np.maximum(0.0, cdist(stdfeatm, evtfeatm, 'cosine'))
    return np.mean(matrix), np.max(matrix), simi_mfeat[0, 0]


def _write_simi_result(rlt, basepath):
    """Pickle ``rlt`` to ``<basepath>.pickle``; write it to ``<basepath>.txt``.

    The text file holds one comma-separated line with floats shown to three
    decimals.
    """
    with open(basepath + '.pickle', 'wb') as f:
        pickle.dump(rlt, f)
    line = ', '.join(f"{x:.3f}" if isinstance(x, (float, np.floating)) else str(x)
                     for x in rlt)
    with open(basepath + '.txt', 'w', encoding='utf-8') as f:
        f.write(line + '\n')


def data_precision_compare(stdfeat, evtfeat, evtMessage, save=True, savepath=None):
    """Compare 1:1 similarity scores at float32, float16 and int8 precision.

    Computes three scores at each precision for one (event, standard barcode)
    pair:
    - the mean of the pairwise cosine-similarity matrix between the rows of
      ``stdfeat`` and ``evtfeat``;
    - the max of that matrix;
    - the similarity of the two mean features.

    The float16 variant L2-normalises each row after casting. The int8
    variant quantises those unit rows as ``int8(x * 128)`` and dequantises
    with ``/ 128``.

    Args:
        stdfeat: 2-D array of standard features (one row per feature).
        evtfeat: 2-D array of event features with the same width.
        evtMessage: ``(event_name, std_barcode, label)``, e.g. one element
            of ``mergePairs`` in ``one2one_simi``.
        save: if True, write each result to
            ``<savepath>/<event>_{ft32,ft16,uint8}.{pickle,txt}``.
        savepath: output directory. Defaults to the module-level ``similPath``.

    Returns:
        ``(rlt_ft32, rlt_ft16, rlt_int8)``, each a list
        ``[label, std_barcode, event_name, simi_mean, simi_max, simi_mfeat]``.
    """
    evt, stdbcd, label = evtMessage

    # float32: reference precision
    rltdata = [label, stdbcd, evt, *_simi_stats(stdfeat, evtfeat)]

    # float16: cast, then L2-normalise each row
    stdfeat_ft16 = stdfeat.astype(np.float16)
    evtfeat_ft16 = evtfeat.astype(np.float16)
    stdfeat_ft16 /= np.linalg.norm(stdfeat_ft16, axis=1)[:, None]
    evtfeat_ft16 /= np.linalg.norm(evtfeat_ft16, axis=1)[:, None]
    rltdata_ft16 = [label, stdbcd, evt, *_simi_stats(stdfeat_ft16, evtfeat_ft16)]

    # int8: symmetric quantisation of the unit-norm rows. Clip into the int8
    # range first, because a component of 1.0 gives 128, which overflows int8.
    stdfeat_int8 = np.clip(stdfeat_ft16 * 128, -128, 127).astype(np.int8)
    evtfeat_int8 = np.clip(evtfeat_ft16 * 128, -128, 127).astype(np.int8)
    stdfeat_ft16_ = stdfeat_int8.astype(np.float16) / 128
    evtfeat_ft16_ = evtfeat_int8.astype(np.float16) / 128
    rltdata_ft16_ = [label, stdbcd, evt, *_simi_stats(stdfeat_ft16_, evtfeat_ft16_)]

    if save:
        if savepath is None:
            savepath = similPath  # module-level global set in the __main__ block
        # NOTE(review): names depend only on the event name. If both the
        # "same" and "diff" pair of an event are compared, the second
        # overwrites the first.
        _write_simi_result(rltdata, os.path.join(savepath, f'{evt}_ft32'))
        _write_simi_result(rltdata_ft16, os.path.join(savepath, f'{evt}_ft16'))
        _write_simi_result(rltdata_ft16_, os.path.join(savepath, f'{evt}_uint8'))

    return rltdata, rltdata_ft16, rltdata_ft16_
def _load_pickle_file(path):
    """Unpickle and return the object stored at ``path``.

    NOTE(review): pickle.load can execute code embedded in the file; only use
    it on locally generated, trusted feature/event files.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


def _save_event_images(evtDict):
    """Save sub-images and annotated frames for events not yet saved.

    For each event, the sub-image folder ``<subimgPath>/<event>`` and the
    frame folder ``<imagePath>/<event>`` are handled independently. Each is
    created and filled only if it does not exist yet. A failed save still
    leaves its folder behind, so it is not retried on later runs.

    Returns:
        names of events for which a save raised, without duplicates.
    """
    error_event = []
    for evtname, event in evtDict.items():
        for root, saver in ((subimgPath, save_event_subimg), (imagePath, plot_save_image)):
            outdir = os.path.join(root, evtname)
            if os.path.exists(outdir):
                continue
            os.makedirs(outdir)
            try:
                saver(event, outdir)
            except Exception as e:
                # Best effort: record the failure and continue with other events.
                print(f"{saver.__name__} failed for {evtname}: {e}")
                if evtname not in error_event:
                    error_event.append(evtname)
    return error_event


def _build_eval_pairs(evtList, barcodes):
    """Build the (event, std_barcode, label) pairs to compare.

    First comes one "same" pair per event (its own barcode). Then comes one
    "diff" pair per event, drawn at random from the other barcodes; an
    event whose barcode is the only one available gets no "diff" pair.
    """
    AA_list = [(evtname, barcode, "same") for evtname, barcode in evtList]
    AB_list = []
    for evtname, barcode in evtList:
        # Sorted so the candidate order, and therefore the draw under a fixed
        # random seed, does not depend on string-hash randomisation of sets.
        others = sorted(barcodes - {barcode})
        if others:
            AB_list.append((evtname, random.choice(others), "diff"))
    return AA_list + AB_list


def one2one_simi():
    """Compute 1:1 similarities between shopping events and standard barcodes.

    Reads the module-level paths ``stdFeaturePath``, ``eventDataPath``,
    ``subimgPath`` and ``imagePath``.

    1. Select the event pickles in ``eventDataPath`` whose stem has 2 or 3
       ``_`` fields and a numeric last field (the barcode), and whose barcode
       has a standard feature file ``<stdFeaturePath>/<barcode>.pickle``.
    2. Load those standard feature dicts and event objects.
    3. Save sub-images and annotated frames for events not saved yet. Failed
       event names are written to ``<subimgPath>/error_event.txt``.
    4. Pair each event with its own barcode ("same") and with one random
       other barcode ("diff").
    5. For each pair, compute the mean and max of the cosine-similarity
       matrix, with negative values set to 0, and the similarity of the mean
       features.

    Returns:
        list of ``(label, std_barcode, event_name, simi_mean, simi_max, simi_mfeat)``.
        Pairs whose event has no ``feats_compose`` rows are skipped.
    """
    stdBarcode = {p.stem for p in Path(stdFeaturePath).iterdir()
                  if p.is_file() and p.suffix == '.pickle'}

    # 1. Events whose barcode exists in the standard feature set
    evtList = []
    for p in Path(eventDataPath).iterdir():
        fields = p.stem.split('_')
        if (p.is_file() and p.suffix == '.pickle' and len(fields) in (2, 3)
                and fields[-1].isdigit() and fields[-1] in stdBarcode):
            evtList.append((p.stem, fields[-1]))
    barcodes = {bcd for _, bcd in evtList}

    # 2. Standard features, and 3. events to compare
    stdDict = {bcd: _load_pickle_file(os.path.join(stdFeaturePath, bcd + '.pickle'))
               for bcd in barcodes}
    evtDict = {evtname: _load_pickle_file(os.path.join(eventDataPath, evtname + '.pickle'))
               for evtname, _ in evtList}

    # 4. Trajectory sub-images / annotated frames
    error_event = _save_event_images(evtDict)
    errfile = os.path.join(subimgPath, 'error_event.txt')
    with open(errfile, 'w', encoding='utf-8') as f:
        f.writelines(name + '\n' for name in error_event)

    # 5. "scan A put A" and "scan A put B" pairs
    mergePairs = _build_eval_pairs(evtList, barcodes)

    # 6. Event vs. standard feature similarity (float32)
    rltdata = []
    for evtname, stdbcd, label in mergePairs:
        evtfeat = evtDict[evtname].feats_compose
        if len(evtfeat) == 0:
            continue
        stdfeat = stdDict[stdbcd]["feats_ft32"]
        matrix = 1 - cdist(stdfeat, evtfeat, 'cosine')
        matrix[matrix < 0] = 0
        stdfeatm = np.mean(stdfeat, axis=0, keepdims=True)
        evtfeatm = np.mean(evtfeat, axis=0, keepdims=True)
        simi_mfeat = 1 - np.maximum(0.0, cdist(stdfeatm, evtfeatm, 'cosine'))
        rltdata.append((label, stdbcd, evtname, np.mean(matrix), np.max(matrix), simi_mfeat[0, 0]))
        # Optional float32/float16/int8 precision comparison:
        # data_precision_compare(stdfeat, evtfeat, (evtname, stdbcd, label), save=True)

    print("func: one2one_simi(), have finished!")
    return rltdata
def _safe_ratio(num, den):
    """Return ``num / den`` element-wise as floats, with NaN where ``den`` is 0."""
    num = np.asarray(num, dtype=float)
    den = np.broadcast_to(np.asarray(den, dtype=float), num.shape)
    return np.divide(num, den, out=np.full(num.shape, np.nan), where=den != 0)


def compute_precise_recall(rltdata, savepath=None):
    """Sweep a similarity threshold and plot precision/recall curves.

    Only the mean similarity (4th field) of each record is evaluated.
    Records labelled "same" are positives and "diff" are negatives. A record
    is predicted positive iff its similarity is strictly greater than the
    threshold. 100 thresholds in [-0.2, 1] are evaluated.

    Args:
        rltdata: iterable of
            ``(label, stdbcd, evtname, simi_mean, simi_max, simi_mfeat)``,
            as returned by ``one2one_simi``.
        savepath: directory for ``pr.png``. Defaults to the module-level
            ``similPath``.

    Returns:
        dict of 1-D arrays keyed by 'Thresh', 'Correct', 'Recall_Pos',
        'Recall_Neg', 'Precision_Pos' and 'Precision_Neg'. Recalls and
        Correct are NaN when their class is empty.
    """
    Same = np.array([rlt[3] for rlt in rltdata if rlt[0] == "same"], dtype=float)
    Cross = np.array([rlt[3] for rlt in rltdata if rlt[0] == "diff"], dtype=float)
    TPFN = len(Same)
    TNFP = len(Cross)

    Thresh = np.linspace(-0.2, 1, 100)
    # Apply one decision rule to both classes: a score equal to the
    # threshold is "negative" for positives (FN) and for negatives (TN).
    TP = np.array([np.sum(Same > th) for th in Thresh])
    TN = np.array([np.sum(Cross <= th) for th in Thresh])
    FN = TPFN - TP
    FP = TNFP - TN

    Recall_Pos = _safe_ratio(TP, TPFN)
    Recall_Neg = _safe_ratio(TN, TNFP)
    Precision_Pos = TP / (TP + FP + 1e-6)
    Precision_Neg = TN / (TN + FN + 1e-6)
    Correct = _safe_ratio(TN + TP, TPFN + TNFP)

    fig, ax = plt.subplots()
    ax.plot(Thresh, Correct, 'r', label='Correct: (TN+TP)/(TPFN+TNFP)')
    ax.plot(Thresh, Recall_Pos, 'b', label='Recall_Pos: TP/TPFN')
    ax.plot(Thresh, Recall_Neg, 'g', label='Recall_Neg: TN/TNFP')
    ax.plot(Thresh, Precision_Pos, 'c', label='Precision_Pos: TP/(TP+FP)')
    ax.plot(Thresh, Precision_Neg, 'm', label='Precision_Neg: TN/(TN+FN)')
    ax.set_xlim([0, 1])
    ax.set_ylim([0, 1])
    ax.grid(True)
    ax.set_title('PrecisePos & PreciseNeg')
    ax.set_xlabel(f"Same Num: {TPFN}, Cross Num: {TNFP}")
    ax.legend()

    if savepath is None:
        savepath = similPath  # module-level global set in the __main__ block
    # Save before show(): once show() returns, the figure may already be
    # closed (inline backends, closed GUI windows), and saving afterwards
    # writes a blank image.
    fig.savefig(os.path.join(savepath, 'pr.png'))  # svg, png, pdf
    plt.show()

    return {
        'Thresh': Thresh,
        'Correct': Correct,
        'Recall_Pos': Recall_Pos,
        'Recall_Neg': Recall_Neg,
        'Precision_Pos': Precision_Pos,
        'Precision_Neg': Precision_Neg,
    }
def gen_eventdict(sourcePath, saveimg=True):
    """Build and pickle a ShoppingEvent for every event folder not yet processed.

    For each path in ``sourcePath``, ``<eventDataPath>/<basename>.pickle`` is
    created unless it already exists. Paths whose processing raised are
    written one per line to ``<eventDataPath>/error_events.txt``; that file
    is rewritten on every call.

    Args:
        sourcePath: iterable of event directory paths.
        saveimg: unused; kept for backward compatibility.

    Returns:
        list of source paths that failed.
    """
    errEvents = []
    for source_path in sourcePath:
        bname = os.path.basename(source_path)
        pickpath = os.path.join(eventDataPath, f"{bname}.pickle")
        if os.path.isfile(pickpath):
            continue
        tmppath = pickpath + '.tmp'
        try:
            event = ShoppingEvent(source_path, stype="data")
            # Dump to a temp file and rename it into place. A failed or
            # interrupted dump would otherwise leave a truncated pickle that
            # later runs skip (it exists) and then cannot load.
            with open(tmppath, 'wb') as f:
                pickle.dump(event, f)
            os.replace(tmppath, pickpath)
            print(bname)
        except Exception as e:
            errEvents.append(source_path)
            print(e)
            if os.path.exists(tmppath):
                os.remove(tmppath)

    errfile = os.path.join(eventDataPath, 'error_events.txt')
    with open(errfile, 'w', encoding='utf-8') as f:
        f.writelines(line + '\n' for line in errEvents)
    return errEvents
def test_one2one():
    """Run the full 1:1 evaluation pipeline over ``eventSourcePath``.

    Collects every sub-directory whose name has at least two ``_`` fields,
    where the last field is all digits and at least 10 characters long (the
    barcode). It then pickles those events via ``gen_eventdict``, computes
    similarities with ``one2one_simi`` and plots metrics with
    ``compute_precise_recall``.

    Step 1 (standard-feature generation) is implemented in genfeats.py and is
    disabled here; ``barcode_set`` is only needed if it is re-enabled.
    """
    event_dirs = []
    barcode_list = []
    for root in eventSourcePath:
        for name in os.listdir(root):
            fullpath = os.path.join(root, name)
            if os.path.isfile(fullpath):
                continue
            fields = name.split('_')
            if len(fields) < 2 or not fields[-1].isdigit() or len(fields[-1]) < 10:
                continue
            barcode_list.append(fields[-1])
            event_dirs.append(fullpath)
    barcode_set = set(barcode_list)

    # 1. Standard feature set: generated once, see genfeats.py
    # gen_bcd_features(stdSamplePath, stdBarcodePath, stdFeaturePath, barcode_set)
    print("stdFeats have generated and saved!")

    # 2. Event pickles: only missing ones are generated
    gen_eventdict(event_dirs)
    print("eventList have generated and saved!")

    # 3. 1:1 performance evaluation
    compute_precise_recall(one2one_simi())
if __name__ == '__main__':
    # Paths used by the pipeline (all read as module-level globals):
    # (1) stdSamplePath:   raw images from which the standard feature set is built
    # (2) stdBarcodePath:  pickle files {barcode: [imgpath1, imgpath2, ...]} for those images
    # (3) stdFeaturePath:  standard feature set storage
    # (4) eventSourcePath: list of directories holding the raw event folders
    # (5) resultPath:      root folder for all results
    # (6) eventDataPath:   pickled shopping events used for the 1:1 comparison
    # (7) subimgPath:      event track sub-images (and error_event.txt)
    # (8) imagePath:       annotated event frames
    # (9) similPath:       1:1 comparison results (per event)

    # Previous configuration:
    # stdSamplePath = r"\\192.168.1.28\share\已标注数据备份\对比数据\barcode\barcode_500_1979_已清洗"
    # stdBarcodePath = r"\\192.168.1.28\share\测试_202406\contrast\std_barcodes_2192"
    # stdFeaturePath = r"\\192.168.1.28\share\测试_202406\contrast\std_features_ft32"
    # eventDataPath = r"\\192.168.1.28\share\测试_202406\contrast\events"
    # subimgPath = r'\\192.168.1.28\share\测试_202406\contrast\subimgs'
    # similPath = r"D:\DetectTracking\contrast\result\pickle"
    # eventSourcePath = [r'\\192.168.1.28\share\测试_202406\1101\images']

    stdSamplePath = r"\\192.168.1.28\share\数据\已完成数据\展厅数据\v1.0\比对数据\整理\zhantingBase"
    stdBarcodePath = r"D:\exhibition\dataset\bcdpath"
    stdFeaturePath = r"D:\exhibition\dataset\feats"
    resultPath = r"D:\exhibition\result\events"

    # Alternative event sources:
    # eventSourcePath = [r'D:\exhibition\images\20241202']
    # eventSourcePath = [r"\\192.168.1.28\share\测试视频数据以及日志\各模块测试记录\展厅测试\1129_展厅模型v801测试组测试"]
    eventSourcePath = [r"\\192.168.1.28\share\测试视频数据以及日志\各模块测试记录\展厅测试\1126_展厅模型v801测试"]

    # Output folders of the current run, grouped under the test-batch name;
    # create any that are missing.
    run_name = "1126"
    eventDataPath = os.path.join(resultPath, run_name, "evtobjs")
    subimgPath = os.path.join(resultPath, run_name, "subimgs")
    imagePath = os.path.join(resultPath, run_name, "image")
    similPath = os.path.join(resultPath, run_name, "simidata")
    for out_dir in (eventDataPath, subimgPath, imagePath, similPath):
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)

    test_one2one()