guoqing bakeup
This commit is contained in:
454
contrast/one2n_contrast.py
Normal file
454
contrast/one2n_contrast.py
Normal file
@ -0,0 +1,454 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Created on Sat Jul ** 14:07:25 2024
|
||||
|
||||
现场测试精度、召回率分析程序,是 feat_select.py 的简化版,
|
||||
但支持循环计算,并输出总的pr曲线
|
||||
|
||||
@author: ym
|
||||
"""
|
||||
|
||||
|
||||
|
||||
import os.path
|
||||
import shutil
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import cv2
|
||||
|
||||
import sys
|
||||
sys.path.append(r"D:\DetectTracking")
|
||||
from tracking.utils.plotting import Annotator, colors
|
||||
from tracking.utils.read_data import extract_data, read_deletedBarcode_file, read_tracking_output
|
||||
from tracking.utils.plotting import draw_tracking_boxes
|
||||
|
||||
|
||||
|
||||
def showHist(err, correct):
|
||||
err = np.array(err)
|
||||
correct = np.array(correct)
|
||||
|
||||
fig, axs = plt.subplots(2, 1)
|
||||
axs[0].hist(err, bins=50, edgecolor='black')
|
||||
axs[0].set_xlim([0, 1])
|
||||
axs[0].set_title('err')
|
||||
|
||||
axs[1].hist(correct, bins=50, edgecolor='black')
|
||||
axs[1].set_xlim([0, 1])
|
||||
axs[1].set_title('correct')
|
||||
# plt.show()
|
||||
|
||||
return plt
|
||||
|
||||
def show_recall_prec(recall, prec, ths):
|
||||
# x = np.linspace(start=-0, stop=1, num=11, endpoint=True).tolist()
|
||||
fig = plt.figure(figsize=(10, 6))
|
||||
plt.plot(ths, recall, color='red', label='recall')
|
||||
plt.plot(ths, prec, color='blue', label='PrecisePos')
|
||||
plt.legend()
|
||||
plt.xlabel(f'threshold')
|
||||
# plt.ylabel('Similarity')
|
||||
plt.grid(True, linestyle='--', alpha=0.5)
|
||||
# plt.savefig('accuracy_recall_grid.png')
|
||||
# plt.show()
|
||||
# plt.close()
|
||||
|
||||
return plt
|
||||
|
||||
|
||||
def compute_recall_precision(err_similarity, correct_similarity):
|
||||
ths = np.linspace(0, 1, 51)
|
||||
recall, prec = [], []
|
||||
for th in ths:
|
||||
TP = len([num for num in correct_similarity if num >= th])
|
||||
FP = len([num for num in err_similarity if num >= th])
|
||||
if (TP+FP) == 0:
|
||||
prec.append(1)
|
||||
recall.append(0)
|
||||
else:
|
||||
prec.append(TP / (TP + FP))
|
||||
recall.append(TP / (len(err_similarity) + len(correct_similarity)))
|
||||
return recall, prec, ths
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# def read_tracking_output(filepath):
|
||||
# boxes = []
|
||||
# feats = []
|
||||
# with open(filepath, 'r', encoding='utf-8') as file:
|
||||
# for line in file:
|
||||
# line = line.strip() # 去除行尾的换行符和可能的空白字符
|
||||
#
|
||||
# if not line:
|
||||
# continue
|
||||
#
|
||||
# if line.endswith(','):
|
||||
# line = line[:-1]
|
||||
#
|
||||
# data = np.array([float(x) for x in line.split(",")])
|
||||
# if data.size == 9:
|
||||
# boxes.append(data)
|
||||
# if data.size == 256:
|
||||
# feats.append(data)
|
||||
#
|
||||
# return np.array(boxes), np.array(feats)
|
||||
# =============================================================================
|
||||
|
||||
def read_tracking_imgs(imgspath):
|
||||
'''
|
||||
input:
|
||||
imgspath:该路径中的图像为Yolo算法的输入图像,640x512
|
||||
output:
|
||||
imgs_0:后摄图像,根据 frameId 进行了排序
|
||||
imgs_1:前摄图像,根据 frameId 进行了排序
|
||||
'''
|
||||
imgs_0, frmIDs_0, imgs_1, frmIDs_1 = [], [], [], []
|
||||
|
||||
for filename in os.listdir(imgspath):
|
||||
file, ext = os.path.splitext(filename)
|
||||
flist = file.split('_')
|
||||
if len(flist)==4 and ext==".jpg":
|
||||
camID, frmID = flist[0], int(flist[-1])
|
||||
imgpath = os.path.join(imgspath, filename)
|
||||
img = cv2.imread(imgpath)
|
||||
|
||||
if camID=='0':
|
||||
imgs_0.append(img)
|
||||
frmIDs_0.append(frmID)
|
||||
if camID=='1':
|
||||
imgs_1.append(img)
|
||||
frmIDs_1.append(frmID)
|
||||
|
||||
if len(frmIDs_0):
|
||||
indice = np.argsort(np.array(frmIDs_0))
|
||||
imgs_0 = [imgs_0[i] for i in indice ]
|
||||
if len(frmIDs_1):
|
||||
indice = np.argsort(np.array(frmIDs_1))
|
||||
imgs_1 = [imgs_1[i] for i in indice ]
|
||||
|
||||
return imgs_0, imgs_1
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# def draw_tracking_boxes(imgs, tracks):
|
||||
# '''tracks: [x1, y1, x2, y2, track_id, score, cls, frame_index, box_index]
|
||||
# 0 1 2 3 4 5 6 7 8
|
||||
# 关键:imgs中的次序和 track 中的 fid 对应
|
||||
# '''
|
||||
# subimgs = []
|
||||
# for *xyxy, tid, conf, cls, fid, bid in tracks:
|
||||
# label = f'id:{int(tid)}_{int(cls)}_{conf:.2f}'
|
||||
#
|
||||
# annotator = Annotator(imgs[int(fid-1)].copy())
|
||||
# if cls==0:
|
||||
# color = colors(int(cls), True)
|
||||
# elif tid>0 and cls!=0:
|
||||
# color = colors(int(tid), True)
|
||||
# else:
|
||||
# color = colors(19, True) # 19为调色板的最后一个元素
|
||||
#
|
||||
# pt2 = [p/2 for p in xyxy]
|
||||
# annotator.box_label(pt2, label, color=color)
|
||||
# img0 = annotator.result()
|
||||
#
|
||||
# subimgs.append(img0)
|
||||
#
|
||||
# return subimgs
|
||||
# =============================================================================
|
||||
|
||||
def get_contrast_paths(pair, basepath):
|
||||
assert(len(pair)==2 or len(pair)==3), "pair: seqdir, delete, barcodes"
|
||||
|
||||
getout_fold = pair[0] # 取出操作对应的文件夹
|
||||
relvt_barcode = pair[1] # 取出操作对应放入操作的 Barcode
|
||||
if len(pair)==3:
|
||||
error_match = pair[2] # 取出操作错误匹配的 Barcode
|
||||
else:
|
||||
error_match = ''
|
||||
|
||||
|
||||
getoutpath, inputpath, errorpath = '', '', ''
|
||||
|
||||
day, hms = getout_fold.strip('_').split('-')
|
||||
|
||||
input_folds, times = [], []
|
||||
errmatch_folds, errmatch_times = [], []
|
||||
for pathname in os.listdir(basepath):
|
||||
if pathname.endswith('_'): continue
|
||||
if os.path.isfile(os.path.join(basepath, pathname)):continue
|
||||
infold = pathname.split('_')
|
||||
if len(infold)!=2: continue
|
||||
|
||||
day1, hms1 = infold[0].split('-')
|
||||
|
||||
if day1==day and infold[1]==relvt_barcode and int(hms1)<int(hms):
|
||||
input_folds.append(pathname)
|
||||
times.append(int(hms1))
|
||||
|
||||
if day1==day and len(error_match) and infold[1]==error_match and int(hms1)<int(hms):
|
||||
errmatch_folds.append(pathname)
|
||||
errmatch_times.append(int(hms1))
|
||||
|
||||
''' 根据时间排序,选择离取出操作最近时间的文件夹,
|
||||
作为取出操作应正确匹配的放入操作所对应的文件夹 '''
|
||||
if len(input_folds):
|
||||
indice = np.argsort(np.array(times))
|
||||
input_fold = input_folds[indice[-1]]
|
||||
inputpath = os.path.join(basepath, input_fold)
|
||||
|
||||
'''取出操作错误匹配的放入操作对应的文件夹'''
|
||||
if len(errmatch_folds):
|
||||
indice = np.argsort(np.array(errmatch_times))
|
||||
errmatch_fold = errmatch_folds[indice[-1]]
|
||||
errorpath = os.path.join(basepath, errmatch_fold)
|
||||
|
||||
'''放入事件文件夹地址、取出事件文件夹地址'''
|
||||
getoutpath = os.path.join(basepath, getout_fold)
|
||||
|
||||
|
||||
return getoutpath, inputpath, errorpath
|
||||
|
||||
|
||||
def save_tracking_imgpairs(pair, basepath, savepath):
|
||||
'''
|
||||
basepath: 原始测试数据文件夹的路径
|
||||
savepath: 保存的目标文件夹
|
||||
'''
|
||||
|
||||
getoutpath, inputpath, errorpath = get_contrast_paths(pair, basepath)
|
||||
|
||||
if len(inputpath)==0:
|
||||
return
|
||||
|
||||
|
||||
'''==== 读取放入、取出事件对应的 Yolo输入的前后摄图像,0:后摄,1:前摄 ===='''
|
||||
|
||||
|
||||
'''==== 读取放入、取出事件对应的 tracking 输出:boxes, feats ===='''
|
||||
|
||||
if len(inputpath):
|
||||
imgs_input_0, imgs_input_1 = read_tracking_imgs(inputpath)
|
||||
|
||||
input_data_0 = os.path.join(inputpath, '0_tracking_output.data')
|
||||
input_data_1 = os.path.join(inputpath, '1_tracking_output.data')
|
||||
boxes_input_0, feats_input_0 = read_tracking_output(input_data_0)
|
||||
boxes_input_1, feats_input_1 = read_tracking_output(input_data_1)
|
||||
ImgsInput_0 = draw_tracking_boxes(imgs_input_0, boxes_input_0)
|
||||
ImgsInput_1 = draw_tracking_boxes(imgs_input_1, boxes_input_1)
|
||||
|
||||
if len(getoutpath):
|
||||
imgs_getout_0, imgs_getout_1 = read_tracking_imgs(getoutpath)
|
||||
|
||||
getout_data_0 = os.path.join(getoutpath, '0_tracking_output.data')
|
||||
getout_data_1 = os.path.join(getoutpath, '1_tracking_output.data')
|
||||
boxes_output_0, feats_output_0 = read_tracking_output(getout_data_0)
|
||||
boxes_output_1, feats_output_1 = read_tracking_output(getout_data_1)
|
||||
ImgsGetout_0 = draw_tracking_boxes(imgs_getout_0, boxes_output_0)
|
||||
ImgsGetout_1 = draw_tracking_boxes(imgs_getout_1, boxes_output_1)
|
||||
|
||||
if len(errorpath):
|
||||
|
||||
imgs_error_0, imgs_error_1 = read_tracking_imgs(errorpath)
|
||||
|
||||
error_data_0 = os.path.join(errorpath, '0_tracking_output.data')
|
||||
error_data_1 = os.path.join(errorpath, '1_tracking_output.data')
|
||||
boxes_error_0, feats_error_0 = read_tracking_output(error_data_0)
|
||||
boxes_error_1, feats_error_1 = read_tracking_output(error_data_1)
|
||||
ImgsError_0 = draw_tracking_boxes(imgs_error_0, boxes_error_0)
|
||||
ImgsError_1 = draw_tracking_boxes(imgs_error_1, boxes_error_1)
|
||||
|
||||
|
||||
savedir = pair[0] + pair[1]
|
||||
if len(errorpath):
|
||||
savedir = savedir + '_' + errorpath.split('_')[-1]
|
||||
foldname = os.path.join(savepath, 'imgpairs', savedir)
|
||||
if not os.path.exists(foldname):
|
||||
os.makedirs(foldname)
|
||||
|
||||
for i, img in enumerate(ImgsInput_0):
|
||||
imgpath = os.path.join(foldname, f'input_0_{i}.png')
|
||||
cv2.imwrite(imgpath, img)
|
||||
for i, img in enumerate(ImgsInput_1):
|
||||
imgpath = os.path.join(foldname, f'input_1_{i}.png')
|
||||
cv2.imwrite(imgpath, img)
|
||||
for i, img in enumerate(ImgsGetout_0):
|
||||
imgpath = os.path.join(foldname, f'getout_0_{i}.png')
|
||||
cv2.imwrite(imgpath, img)
|
||||
for i, img in enumerate(ImgsGetout_1):
|
||||
imgpath = os.path.join(foldname, f'getout_1_{i}.png')
|
||||
cv2.imwrite(imgpath, img)
|
||||
|
||||
for i, img in enumerate(ImgsError_0):
|
||||
imgpath = os.path.join(foldname, f'errMatch_0_{i}.png')
|
||||
cv2.imwrite(imgpath, img)
|
||||
for i, img in enumerate(ImgsError_1):
|
||||
imgpath = os.path.join(foldname, f'errMatch_1_{i}.png')
|
||||
cv2.imwrite(imgpath, img)
|
||||
|
||||
|
||||
# def performance_evaluate(all_list, isshow=False):
|
||||
|
||||
# corrpairs, correct_barcode_list, correct_similarity, errpairs, err_barcode_list, err_similarity = [], [], [], [], [], []
|
||||
# for s_list in all_list:
|
||||
# seqdir = s_list['SeqDir'].strip()
|
||||
# delete = s_list['Deleted'].strip()
|
||||
# barcodes = [s.strip() for s in s_list['barcode']]
|
||||
# similarity = [float(s.strip()) for s in s_list['similarity']]
|
||||
|
||||
# if delete in barcodes[:1]:
|
||||
# corrpairs.append((seqdir, delete))
|
||||
# correct_barcode_list.append(delete)
|
||||
# correct_similarity.append(similarity[0])
|
||||
# else:
|
||||
# errpairs.append((seqdir, delete, barcodes[0]))
|
||||
# err_barcode_list.append(delete)
|
||||
# err_similarity.append(similarity[0])
|
||||
|
||||
def performance_evaluate(all_list, isshow=False):
|
||||
|
||||
corrpairs, correct_barcode_list, correct_similarity, errpairs, err_barcode_list, err_similarity = [], [], [], [], [], []
|
||||
for s_list in all_list:
|
||||
seqdir = s_list['SeqDir'].strip()
|
||||
delete = s_list['Deleted'].strip()
|
||||
barcodes = [s.strip() for s in s_list['barcode']]
|
||||
|
||||
|
||||
similarity_comp, similarity_front = [], []
|
||||
for simil in s_list['similarity']:
|
||||
ss = [float(s.strip()) for s in simil.split(',')]
|
||||
|
||||
similarity_comp.append(ss[0])
|
||||
if len(ss)==3:
|
||||
similarity_front.append(ss[2])
|
||||
|
||||
if len(similarity_front):
|
||||
similarity = [s for s in similarity_front]
|
||||
else:
|
||||
similarity = [s for s in similarity_comp]
|
||||
|
||||
|
||||
index = similarity.index(max(similarity))
|
||||
matched_barcode = barcodes[index]
|
||||
if matched_barcode == delete:
|
||||
corrpairs.append((seqdir, delete))
|
||||
correct_barcode_list.append(delete)
|
||||
correct_similarity.append(max(similarity))
|
||||
else:
|
||||
errpairs.append((seqdir, delete, matched_barcode))
|
||||
err_barcode_list.append(delete)
|
||||
err_similarity.append(max(similarity))
|
||||
|
||||
'''3. 计算比对性能 '''
|
||||
if isshow:
|
||||
recall, prec, ths = compute_recall_precision(err_similarity, correct_similarity)
|
||||
show_recall_prec(recall, prec, ths)
|
||||
showHist(err_similarity, correct_similarity)
|
||||
|
||||
return errpairs, corrpairs, err_similarity, correct_similarity
|
||||
|
||||
|
||||
|
||||
def contrast_analysis(del_barcode_file, basepath, savepath, saveimgs=False):
|
||||
'''
|
||||
del_barcode_file: 测试数据文件,利用该文件进行算法性能分析
|
||||
|
||||
'''
|
||||
|
||||
'''1. 读取 deletedBarcode 文件 '''
|
||||
all_list = read_deletedBarcode_file(del_barcode_file)
|
||||
|
||||
'''2. 算法性能评估,并输出 (取出,删除, 错误匹配) 对 '''
|
||||
errpairs, corrpairs, _, _ = performance_evaluate(all_list)
|
||||
|
||||
'''3. 获取 (取出,删除, 错误匹配) 对应路径,保存相应轨迹图像'''
|
||||
relative_paths = []
|
||||
for errpair in errpairs:
|
||||
GetoutPath, InputPath, ErrorPath = get_contrast_paths(errpair, basepath)
|
||||
relative_paths.append((GetoutPath, InputPath, ErrorPath))
|
||||
|
||||
if saveimgs:
|
||||
save_tracking_imgpairs(errpair, basepath, savepath)
|
||||
|
||||
return relative_paths
|
||||
|
||||
|
||||
def contrast_loop(fpath):
|
||||
savepath = r'\\192.168.1.28\share\测试_202406\deletedBarcode\illustration'
|
||||
# savepath = r'D:\contrast\dataset\1_to_n\illustration'
|
||||
if not os.path.exists(savepath):
|
||||
os.mkdir(savepath)
|
||||
|
||||
if os.path.isfile(fpath):
|
||||
fpath, filename = os.path.split(fpath)
|
||||
|
||||
BarLists, blists = {}, []
|
||||
for filename in os.listdir(fpath):
|
||||
file = os.path.splitext(filename)[0][15:]
|
||||
|
||||
filepath = os.path.join(fpath, filename)
|
||||
blist = read_deletedBarcode_file(filepath)
|
||||
|
||||
BarLists.update({file: blist})
|
||||
blists.extend(blist)
|
||||
|
||||
BarLists.update({file: blist})
|
||||
BarLists.update({"Total": blists})
|
||||
for file, blist in BarLists.items():
|
||||
errpairs, corrpairs, err_similarity, correct_similarity = performance_evaluate(blist)
|
||||
|
||||
recall, prec, ths = compute_recall_precision(err_similarity, correct_similarity)
|
||||
|
||||
plt1 = show_recall_prec(recall, prec, ths)
|
||||
# plt1.show()
|
||||
plt1.xlabel(f'threshold, Num: {len(blist)}')
|
||||
plt1.savefig(os.path.join(savepath, file+'_pr.png'))
|
||||
# plt1.close()
|
||||
|
||||
# plt2 = showHist(err_similarity, correct_similarity)
|
||||
# plt2.show()
|
||||
# plt2.savefig(os.path.join(savepath, file+'_hist.png'))
|
||||
# plt.close()
|
||||
|
||||
|
||||
def main():
|
||||
fpath = r'\\192.168.1.28\share\测试_202406\deletedBarcode\other'
|
||||
contrast_loop(fpath)
|
||||
|
||||
def main1():
|
||||
del_barcode_file = r'\\192.168.1.28\share\测试_202406\709\deletedBarcode.txt'
|
||||
basepath = r'\\192.168.1.28\share\测试_202406\709'
|
||||
savepath = r'D:\contrast\dataset\result'
|
||||
|
||||
try:
|
||||
relative_path = contrast_analysis(del_barcode_file, basepath, savepath)
|
||||
except Exception as e:
|
||||
print(f'Error Type: {e}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
||||
# main1()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user