# -*- coding: utf-8 -*-
"""
Created on Thu Oct 31 15:17:01 2024

@author: ym
"""
import os
|
||
import numpy as np
|
||
import pickle
|
||
from pathlib import Path
|
||
import matplotlib.pyplot as plt
|
||
from .event import ShoppingEvent
|
||
|
||
def init_eventDict(sourcePath, eventDataPath, stype="data"):
    """Build a ``ShoppingEvent`` for every event in ``sourcePath`` and pickle it.

    Args:
        sourcePath: Folder holding the events. Two kinds of entries are
            supported:
            (1) pickle files generated by the pipeline;
            (2) directly collected event folders.
        eventDataPath: Folder receiving one ``<event name>.pickle`` per event.
            Events whose pickle already exists there are skipped.
        stype: Kind of source data, forwarded to ``ShoppingEvent``:
            'source'   - pickle files generated from videos or images;
            'data'     - on-site run data read from data files;
            'realtime' - fully real-time data, on-site run data read from
                         data files.

    Side effects:
        Prints the name of each event that was pickled and an error line for
        each one that failed. The source paths of failed events are appended
        to ``error_events.txt`` in the parent folder of ``eventDataPath``.
    """
    errEvents = []
    for evtname in os.listdir(sourcePath):
        bname, ext = os.path.splitext(evtname)
        source_path = os.path.join(sourcePath, evtname)

        # 'source' events are pickle files; 'data'/'realtime' events are folders.
        if stype == "source" and ext not in ('.pkl', '.pickle'):
            continue
        if stype in ("data", "realtime") and os.path.isfile(source_path):
            continue

        # An event name must end with "_<at least 10 digits>".
        evt = bname.split('_')
        if not (len(evt) >= 2 and evt[-1].isdigit() and len(evt[-1]) >= 10):
            continue

        pickpath = os.path.join(eventDataPath, f"{bname}.pickle")
        if os.path.isfile(pickpath):
            continue

        # Write to a temporary file and promote it only after a complete dump.
        # Writing directly to ``pickpath`` would leave a truncated pickle behind
        # when the dump fails, and the existence check above would then skip
        # that event on every later run.
        tmppath = pickpath + ".tmp"
        try:
            event = ShoppingEvent(source_path, stype)
            with open(tmppath, 'wb') as f:
                pickle.dump(event, f)
            os.replace(tmppath, pickpath)
            print(evtname)
        except Exception as e:
            # Best effort: record the failure and carry on with the next event.
            try:
                if os.path.isfile(tmppath):
                    os.remove(tmppath)
            except OSError:
                pass
            errEvents.append(source_path)
            print(f"Error: {evtname}, {e}")

    errfile = Path(eventDataPath).parent / 'error_events.txt'
    with open(str(errfile), 'a', encoding='utf-8') as f:
        for line in errEvents:
            f.write(line + '\n')
|
||
|
||
|
||
def get_evtList(evtpath):
    """Collect the event paths found in ``evtpath`` and the set of their barcodes.

    An entry counts as an event when it is a folder, or a ``.pkl`` / ``.pickle``
    file (Yolo-Resnet-Tracker output), and its name (without the extension for
    files) ends with ``_<at least 10 digits>``. That trailing digit string is
    collected as the barcode.

    Returns:
        (evtpaths, bcdSet): the list of full paths of the accepted events and
        the set of their barcode strings.
    """
    event_paths, barcodes = [], set()

    for entry in os.listdir(evtpath):
        full_path = os.path.join(evtpath, entry)

        # Events come in two forms: folders and pickle files.
        if os.path.isdir(full_path):
            stem = entry
        elif os.path.isfile(full_path):
            stem, suffix = os.path.splitext(entry)
            if suffix not in (".pkl", ".pickle"):
                continue
        else:
            continue

        tokens = stem.split('_')
        if len(tokens) < 2:
            continue

        code = tokens[-1]
        if code.isdigit() and len(code) >= 10:
            barcodes.add(code)
            event_paths.append(full_path)

    return event_paths, barcodes
|
||
|
||
|
||
|
||
def showHist(err, correct):
    """Draw histograms of ``err`` and ``correct`` in two stacked panels.

    Each panel uses 50 bins, black bar edges, an x-axis limited to [0, 1] and
    is titled 'err' (top) or 'correct' (bottom). The figure is not shown here.

    Returns:
        The ``matplotlib.pyplot`` module, so the caller can show or save the
        current figure.
    """
    panels = (('err', np.array(err)), ('correct', np.array(correct)))

    _, axes = plt.subplots(2, 1)
    for axis, (title, values) in zip(axes, panels):
        axis.hist(values, bins=50, edgecolor='black')
        axis.set_xlim([0, 1])
        axis.set_title(title)

    return plt
|
||
|
||
def show_recall_prec(recall, prec, ths):
    """Plot the recall and precision curves against the thresholds ``ths``.

    A new 10x6 figure is created; recall is drawn in red (label 'recall') and
    precision in blue (label 'PrecisePos'), with a legend, an x-axis label and
    a dashed, half-transparent grid. The figure is not shown or saved here.

    Returns:
        The ``matplotlib.pyplot`` module, so the caller can show or save the
        current figure.
    """
    curves = (
        (recall, 'red', 'recall'),
        (prec, 'blue', 'PrecisePos'),
    )

    plt.figure(figsize=(10, 6))
    for values, colour, tag in curves:
        plt.plot(ths, values, color=colour, label=tag)

    plt.legend()
    plt.xlabel('threshold')
    plt.grid(True, linestyle='--', alpha=0.5)

    return plt
|
||
|
||
|
||
def compute_recall_precision(err_similarity, correct_similarity):
    """Compute recall and precision over 51 evenly spaced thresholds in [0, 1].

    For each threshold ``th``:
        TP = number of values in ``correct_similarity`` that are >= th
        FP = number of values in ``err_similarity`` that are >= th
        precision = TP / (TP + FP)
        recall    = TP / (len(err_similarity) + len(correct_similarity))
    When nothing reaches the threshold (TP + FP == 0), precision is recorded
    as 1 and recall as 0.

    Note that recall is normalised by the total number of samples (erroneous
    plus correct), not by the number of correct samples alone.

    Returns:
        (recall, prec, ths): two lists with one entry per threshold, and the
        numpy array of thresholds.
    """
    ths = np.linspace(0, 1, 51)
    sample_count = len(err_similarity) + len(correct_similarity)

    recall, prec = [], []
    for th in ths:
        hits = sum(1 for score in correct_similarity if score >= th)
        false_hits = sum(1 for score in err_similarity if score >= th)
        accepted = hits + false_hits

        if accepted == 0:
            # Nothing passes this threshold.
            prec.append(1)
            recall.append(0)
            continue

        prec.append(hits / accepted)
        recall.append(hits / sample_count)

    return recall, prec, ths