modified for site test
This commit is contained in:
365
tracking/contrast_analysis.py
Normal file
365
tracking/contrast_analysis.py
Normal file
@ -0,0 +1,365 @@
|
||||
import os.path
|
||||
import shutil
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import cv2
|
||||
from utils.plotting import Annotator, colors
|
||||
import sys
|
||||
sys.path.append(r"D:\DetectTracking")
|
||||
from tracking.utils.read_data import extract_data, read_deletedBarcode_file, read_tracking_output
|
||||
from tracking.utils.plotting import draw_tracking_boxes
|
||||
|
||||
def showHist(err, correct):
|
||||
err = np.array(err)
|
||||
correct = np.array(correct)
|
||||
|
||||
fig, axs = plt.subplots(2, 1)
|
||||
axs[0].hist(err, bins=50, edgecolor='black')
|
||||
axs[0].set_xlim([0, 1])
|
||||
axs[0].set_title('err')
|
||||
|
||||
axs[1].hist(correct, bins=50, edgecolor='black')
|
||||
axs[1].set_xlim([0, 1])
|
||||
axs[1].set_title('correct')
|
||||
plt.show()
|
||||
|
||||
def showgrid(recall, prec, ths):
|
||||
# x = np.linspace(start=-0, stop=1, num=11, endpoint=True).tolist()
|
||||
fig = plt.figure(figsize=(10, 6))
|
||||
plt.plot(ths, recall, color='red', label='recall')
|
||||
plt.plot(ths, prec, color='blue', label='PrecisePos')
|
||||
plt.legend()
|
||||
plt.xlabel('threshold')
|
||||
# plt.ylabel('Similarity')
|
||||
plt.grid(True, linestyle='--', alpha=0.5)
|
||||
plt.savefig('accuracy_recall_grid.png')
|
||||
plt.show()
|
||||
# plt.close()
|
||||
|
||||
|
||||
def compute_recall_precision(err_similarity, correct_similarity):
|
||||
ths = np.linspace(0, 1, 11)
|
||||
recall, prec = [], []
|
||||
for th in ths:
|
||||
TP = len([num for num in correct_similarity if num >= th])
|
||||
FP = len([num for num in err_similarity if num >= th])
|
||||
if (TP+FP) == 0:
|
||||
prec.append(1)
|
||||
recall.append(0)
|
||||
else:
|
||||
prec.append(TP / (TP + FP))
|
||||
recall.append(TP / (len(err_similarity) + len(correct_similarity)))
|
||||
|
||||
showgrid(recall, prec, ths)
|
||||
return recall, prec
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# def read_tracking_output(filepath):
|
||||
# boxes = []
|
||||
# feats = []
|
||||
# with open(filepath, 'r', encoding='utf-8') as file:
|
||||
# for line in file:
|
||||
# line = line.strip() # 去除行尾的换行符和可能的空白字符
|
||||
#
|
||||
# if not line:
|
||||
# continue
|
||||
#
|
||||
# if line.endswith(','):
|
||||
# line = line[:-1]
|
||||
#
|
||||
# data = np.array([float(x) for x in line.split(",")])
|
||||
# if data.size == 9:
|
||||
# boxes.append(data)
|
||||
# if data.size == 256:
|
||||
# feats.append(data)
|
||||
#
|
||||
# return np.array(boxes), np.array(feats)
|
||||
# =============================================================================
|
||||
|
||||
def read_tracking_imgs(imgspath):
|
||||
'''
|
||||
input:
|
||||
imgspath:该路径中的图像为Yolo算法的输入图像,640x512
|
||||
output:
|
||||
imgs_0:后摄图像,根据 frameId 进行了排序
|
||||
imgs_1:前摄图像,根据 frameId 进行了排序
|
||||
'''
|
||||
imgs_0, frmIDs_0, imgs_1, frmIDs_1 = [], [], [], []
|
||||
|
||||
for filename in os.listdir(imgspath):
|
||||
file, ext = os.path.splitext(filename)
|
||||
flist = file.split('_')
|
||||
if len(flist)==4 and ext==".jpg":
|
||||
camID, frmID = flist[0], int(flist[-1])
|
||||
imgpath = os.path.join(imgspath, filename)
|
||||
img = cv2.imread(imgpath)
|
||||
|
||||
if camID=='0':
|
||||
imgs_0.append(img)
|
||||
frmIDs_0.append(frmID)
|
||||
if camID=='1':
|
||||
imgs_1.append(img)
|
||||
frmIDs_1.append(frmID)
|
||||
|
||||
if len(frmIDs_0):
|
||||
indice = np.argsort(np.array(frmIDs_0))
|
||||
imgs_0 = [imgs_0[i] for i in indice ]
|
||||
if len(frmIDs_1):
|
||||
indice = np.argsort(np.array(frmIDs_1))
|
||||
imgs_1 = [imgs_1[i] for i in indice ]
|
||||
|
||||
return imgs_0, imgs_1
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# def draw_tracking_boxes(imgs, tracks):
|
||||
# '''tracks: [x1, y1, x2, y2, track_id, score, cls, frame_index, box_index]
|
||||
# 0 1 2 3 4 5 6 7 8
|
||||
# 关键:imgs中的次序和 track 中的 fid 对应
|
||||
# '''
|
||||
# subimgs = []
|
||||
# for *xyxy, tid, conf, cls, fid, bid in tracks:
|
||||
# label = f'id:{int(tid)}_{int(cls)}_{conf:.2f}'
|
||||
#
|
||||
# annotator = Annotator(imgs[int(fid-1)].copy())
|
||||
# if cls==0:
|
||||
# color = colors(int(cls), True)
|
||||
# elif tid>0 and cls!=0:
|
||||
# color = colors(int(tid), True)
|
||||
# else:
|
||||
# color = colors(19, True) # 19为调色板的最后一个元素
|
||||
#
|
||||
# pt2 = [p/2 for p in xyxy]
|
||||
# annotator.box_label(pt2, label, color=color)
|
||||
# img0 = annotator.result()
|
||||
#
|
||||
# subimgs.append(img0)
|
||||
#
|
||||
# return subimgs
|
||||
# =============================================================================
|
||||
|
||||
def get_contrast_paths(pair, basepath):
|
||||
assert(len(pair)==2 or len(pair)==3), "pair: seqdir, delete, barcodes"
|
||||
|
||||
getout_fold = pair[0] # 取出操作对应的文件夹
|
||||
relvt_barcode = pair[1] # 取出操作对应放入操作的 Barcode
|
||||
if len(pair)==3:
|
||||
error_match = pair[2] # 取出操作错误匹配的 Barcode
|
||||
else:
|
||||
error_match = ''
|
||||
|
||||
|
||||
getoutpath, inputpath, errorpath = '', '', ''
|
||||
|
||||
day, hms = getout_fold.strip('_').split('-')
|
||||
|
||||
input_folds, times = [], []
|
||||
errmatch_folds, errmatch_times = [], []
|
||||
for pathname in os.listdir(basepath):
|
||||
if pathname.endswith('_'): continue
|
||||
if os.path.isfile(os.path.join(basepath, pathname)):continue
|
||||
infold = pathname.split('_')
|
||||
if len(infold)!=2: continue
|
||||
|
||||
day1, hms1 = infold[0].split('-')
|
||||
|
||||
if day1==day and infold[1]==relvt_barcode and int(hms1)<int(hms):
|
||||
input_folds.append(pathname)
|
||||
times.append(int(hms1))
|
||||
|
||||
if day1==day and len(error_match) and infold[1]==error_match and int(hms1)<int(hms):
|
||||
errmatch_folds.append(pathname)
|
||||
errmatch_times.append(int(hms1))
|
||||
|
||||
''' 根据时间排序,选择离取出操作最近时间的文件夹,
|
||||
作为取出操作应正确匹配的放入操作所对应的文件夹 '''
|
||||
if len(input_folds):
|
||||
indice = np.argsort(np.array(times))
|
||||
input_fold = input_folds[indice[-1]]
|
||||
inputpath = os.path.join(basepath, input_fold)
|
||||
|
||||
'''取出操作错误匹配的放入操作对应的文件夹'''
|
||||
if len(errmatch_folds):
|
||||
indice = np.argsort(np.array(errmatch_times))
|
||||
errmatch_fold = errmatch_folds[indice[-1]]
|
||||
errorpath = os.path.join(basepath, errmatch_fold)
|
||||
|
||||
'''放入事件文件夹地址、取出事件文件夹地址'''
|
||||
getoutpath = os.path.join(basepath, getout_fold)
|
||||
|
||||
|
||||
return getoutpath, inputpath, errorpath
|
||||
|
||||
|
||||
def save_tracking_imgpairs(pair, basepath, savepath):
|
||||
'''
|
||||
basepath: 原始测试数据文件夹的路径
|
||||
savepath: 保存的目标文件夹
|
||||
'''
|
||||
|
||||
getoutpath, inputpath, errorpath = get_contrast_paths(pair, basepath)
|
||||
|
||||
if len(inputpath)==0:
|
||||
return
|
||||
|
||||
|
||||
'''==== 读取放入、取出事件对应的 Yolo输入的前后摄图像,0:后摄,1:前摄 ===='''
|
||||
|
||||
|
||||
'''==== 读取放入、取出事件对应的 tracking 输出:boxes, feats ===='''
|
||||
|
||||
if len(inputpath):
|
||||
imgs_input_0, imgs_input_1 = read_tracking_imgs(inputpath)
|
||||
|
||||
input_data_0 = os.path.join(inputpath, '0_tracking_output.data')
|
||||
input_data_1 = os.path.join(inputpath, '1_tracking_output.data')
|
||||
boxes_input_0, feats_input_0 = read_tracking_output(input_data_0)
|
||||
boxes_input_1, feats_input_1 = read_tracking_output(input_data_1)
|
||||
ImgsInput_0 = draw_tracking_boxes(imgs_input_0, boxes_input_0)
|
||||
ImgsInput_1 = draw_tracking_boxes(imgs_input_1, boxes_input_1)
|
||||
|
||||
if len(getoutpath):
|
||||
imgs_getout_0, imgs_getout_1 = read_tracking_imgs(getoutpath)
|
||||
|
||||
getout_data_0 = os.path.join(getoutpath, '0_tracking_output.data')
|
||||
getout_data_1 = os.path.join(getoutpath, '1_tracking_output.data')
|
||||
boxes_output_0, feats_output_0 = read_tracking_output(getout_data_0)
|
||||
boxes_output_1, feats_output_1 = read_tracking_output(getout_data_1)
|
||||
ImgsGetout_0 = draw_tracking_boxes(imgs_getout_0, boxes_output_0)
|
||||
ImgsGetout_1 = draw_tracking_boxes(imgs_getout_1, boxes_output_1)
|
||||
|
||||
if len(errorpath):
|
||||
|
||||
imgs_error_0, imgs_error_1 = read_tracking_imgs(errorpath)
|
||||
|
||||
error_data_0 = os.path.join(errorpath, '0_tracking_output.data')
|
||||
error_data_1 = os.path.join(errorpath, '1_tracking_output.data')
|
||||
boxes_error_0, feats_error_0 = read_tracking_output(error_data_0)
|
||||
boxes_error_1, feats_error_1 = read_tracking_output(error_data_1)
|
||||
ImgsError_0 = draw_tracking_boxes(imgs_error_0, boxes_error_0)
|
||||
ImgsError_1 = draw_tracking_boxes(imgs_error_1, boxes_error_1)
|
||||
|
||||
|
||||
savedir = pair[0] + pair[1]
|
||||
if len(errorpath):
|
||||
savedir = savedir + '_' + errorpath.split('_')[-1]
|
||||
foldname = os.path.join(savepath, 'imgpairs', savedir)
|
||||
if not os.path.exists(foldname):
|
||||
os.makedirs(foldname)
|
||||
|
||||
for i, img in enumerate(ImgsInput_0):
|
||||
imgpath = os.path.join(foldname, f'input_0_{i}.png')
|
||||
cv2.imwrite(imgpath, img)
|
||||
for i, img in enumerate(ImgsInput_1):
|
||||
imgpath = os.path.join(foldname, f'input_1_{i}.png')
|
||||
cv2.imwrite(imgpath, img)
|
||||
for i, img in enumerate(ImgsGetout_0):
|
||||
imgpath = os.path.join(foldname, f'getout_0_{i}.png')
|
||||
cv2.imwrite(imgpath, img)
|
||||
for i, img in enumerate(ImgsGetout_1):
|
||||
imgpath = os.path.join(foldname, f'getout_1_{i}.png')
|
||||
cv2.imwrite(imgpath, img)
|
||||
|
||||
for i, img in enumerate(ImgsError_0):
|
||||
imgpath = os.path.join(foldname, f'errMatch_0_{i}.png')
|
||||
cv2.imwrite(imgpath, img)
|
||||
for i, img in enumerate(ImgsError_1):
|
||||
imgpath = os.path.join(foldname, f'errMatch_1_{i}.png')
|
||||
cv2.imwrite(imgpath, img)
|
||||
|
||||
|
||||
def performance_evaluate(all_list, isshow=False):
|
||||
|
||||
corrpairs, correct_barcode_list, correct_similarity, errpairs, err_barcode_list, err_similarity = [], [], [], [], [], []
|
||||
for s_list in all_list:
|
||||
seqdir = s_list['SeqDir'].strip()
|
||||
delete = s_list['Deleted'].strip()
|
||||
barcodes = [s.strip() for s in s_list['barcode']]
|
||||
similarity = [float(s.strip()) for s in s_list['similarity']]
|
||||
|
||||
if delete in barcodes[:1]:
|
||||
corrpairs.append((seqdir, delete))
|
||||
correct_barcode_list.append(delete)
|
||||
correct_similarity.append(similarity[0])
|
||||
else:
|
||||
errpairs.append((seqdir, delete, barcodes[0]))
|
||||
err_barcode_list.append(delete)
|
||||
err_similarity.append(similarity[0])
|
||||
|
||||
|
||||
'''3. 计算比对性能 '''
|
||||
if isshow:
|
||||
compute_recall_precision(err_similarity, correct_similarity)
|
||||
showHist(err_similarity, correct_similarity)
|
||||
|
||||
return errpairs, corrpairs, err_similarity, correct_similarity
|
||||
|
||||
|
||||
|
||||
def contrast_analysis(del_barcode_file, basepath, savepath, saveimgs=False):
|
||||
'''
|
||||
del_barcode_file: 测试数据文件,利用该文件进行算法性能分析
|
||||
|
||||
'''
|
||||
|
||||
'''1. 读取 deletedBarcode 文件 '''
|
||||
all_list = read_deletedBarcode_file(del_barcode_file)
|
||||
|
||||
|
||||
'''2. 算法性能评估,并输出 (取出,删除, 错误匹配) 对 '''
|
||||
errpairs, corrpairs, _, _ = performance_evaluate(all_list)
|
||||
|
||||
'''3. 获取 (取出,删除, 错误匹配) 对应路径,保存相应轨迹图像'''
|
||||
relative_paths = []
|
||||
for errpair in errpairs:
|
||||
GetoutPath, InputPath, ErrorPath = get_contrast_paths(errpair, basepath)
|
||||
relative_paths.append((GetoutPath, InputPath, ErrorPath))
|
||||
|
||||
if saveimgs:
|
||||
save_tracking_imgpairs(errpair, basepath, savepath)
|
||||
|
||||
|
||||
|
||||
|
||||
return relative_paths
|
||||
|
||||
def main():
|
||||
del_barcode_file = 'D:/contrast/dataset/compairsonResult/deletedBarcode_20240709_pm.txt'
|
||||
basepath = r'D:\contrast\dataset\1_to_n\709'
|
||||
savepath = r'D:\contrast\dataset\result'
|
||||
|
||||
try:
|
||||
relative_path = contrast_analysis(del_barcode_file, basepath, savepath)
|
||||
except Exception as e:
|
||||
print(f'Error Type: {e}')
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user