499 lines
18 KiB
Python
499 lines
18 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""
|
||
Created on Sat Jul ** 14:07:25 2024
|
||
|
||
现场测试精度、召回率分析程序,是 feat_select.py 的简化版,
|
||
但支持循环计算,并输出总的pr曲线
|
||
|
||
@author: ym
|
||
"""
|
||
|
||
|
||
|
||
import os.path
|
||
import shutil
|
||
|
||
import numpy as np
|
||
import matplotlib.pyplot as plt
|
||
import cv2
|
||
from pathlib import Path
|
||
import sys
|
||
sys.path.append(r"D:\DetectTracking")
|
||
from tracking.utils.plotting import Annotator, colors
|
||
from tracking.utils.read_data import extract_data, read_deletedBarcode_file, read_tracking_output, read_returnGoods_file
|
||
from tracking.utils.plotting import draw_tracking_boxes, get_subimgs
|
||
from contrast.utils.tools import showHist, show_recall_prec, compute_recall_precision
|
||
|
||
|
||
# =============================================================================
|
||
# def read_tracking_output(filepath):
|
||
# boxes = []
|
||
# feats = []
|
||
# with open(filepath, 'r', encoding='utf-8') as file:
|
||
# for line in file:
|
||
# line = line.strip() # 去除行尾的换行符和可能的空白字符
|
||
#
|
||
# if not line:
|
||
# continue
|
||
#
|
||
# if line.endswith(','):
|
||
# line = line[:-1]
|
||
#
|
||
# data = np.array([float(x) for x in line.split(",")])
|
||
# if data.size == 9:
|
||
# boxes.append(data)
|
||
# if data.size == 256:
|
||
# feats.append(data)
|
||
#
|
||
# return np.array(boxes), np.array(feats)
|
||
# =============================================================================
|
||
|
||
def read_tracking_imgs(imgspath):
    """Load the YOLO input frames of one event folder, split by camera.

    Only files whose stem has exactly four '_'-separated fields and whose
    extension is exactly ".jpg" are used; the first field is the camera id
    and the last field is the integer frame id.

    Args:
        imgspath: folder holding the YOLO input images (640x512 per the
            original author's note).

    Returns:
        (imgs_0, imgs_1): images loaded with ``cv2.imread`` for camera '0'
        (rear camera) and camera '1' (front camera), each ordered by frame id.
    """
    # camera id -> (images, frame ids), filled in directory-listing order
    per_cam = {'0': ([], []), '1': ([], [])}

    for name in os.listdir(imgspath):
        stem, suffix = os.path.splitext(name)
        fields = stem.split('_')
        if suffix != ".jpg" or len(fields) != 4:
            continue

        cam, frame_id = fields[0], int(fields[-1])
        if cam not in per_cam:
            continue

        images, frame_ids = per_cam[cam]
        images.append(cv2.imread(os.path.join(imgspath, name)))
        frame_ids.append(frame_id)

    def _by_frame(images, frame_ids):
        # reorder the images so they follow ascending frame id
        order = np.argsort(np.array(frame_ids))
        return [images[k] for k in order]

    return _by_frame(*per_cam['0']), _by_frame(*per_cam['1'])
|
||
|
||
|
||
# =============================================================================
|
||
# def draw_tracking_boxes(imgs, tracks):
|
||
# '''tracks: [x1, y1, x2, y2, track_id, score, cls, frame_index, box_index]
|
||
# 0 1 2 3 4 5 6 7 8
|
||
# 关键:imgs中的次序和 track 中的 fid 对应
|
||
# '''
|
||
# subimgs = []
|
||
# for *xyxy, tid, conf, cls, fid, bid in tracks:
|
||
# label = f'id:{int(tid)}_{int(cls)}_{conf:.2f}'
|
||
#
|
||
# annotator = Annotator(imgs[int(fid-1)].copy())
|
||
# if cls==0:
|
||
# color = colors(int(cls), True)
|
||
# elif tid>0 and cls!=0:
|
||
# color = colors(int(tid), True)
|
||
# else:
|
||
# color = colors(19, True) # 19为调色板的最后一个元素
|
||
#
|
||
# pt2 = [p/2 for p in xyxy]
|
||
# annotator.box_label(pt2, label, color=color)
|
||
# img0 = annotator.result()
|
||
#
|
||
# subimgs.append(img0)
|
||
#
|
||
# return subimgs
|
||
# =============================================================================
|
||
|
||
def get_contrast_paths(pair, basepath):
    """Resolve the event folders that belong to one take-out event.

    Args:
        pair: ``(seqdir, barcode)`` or ``(seqdir, barcode, error_barcode)``.
            ``seqdir`` is the take-out event folder name, ``<day>-<hms>``
            with optional leading/trailing '_'; ``barcode`` is the barcode of
            the put-in event the take-out should correspond to;
            ``error_barcode`` is the barcode it was wrongly matched to.
        basepath: directory that contains all event folders.

    Returns:
        (getoutpath, inputpath, errorpath). ``inputpath`` / ``errorpath`` is
        the latest sub-folder named ``<day>-<hms>_<barcode>`` on the same day
        with a time strictly before the take-out event, or ``''`` if none.

    Raises:
        ValueError: if ``pair`` does not have 2 or 3 elements.
    """
    if len(pair) not in (2, 3):
        raise ValueError("pair: seqdir, delete, barcodes")

    getout_fold = pair[0]      # folder of the take-out event
    relvt_barcode = pair[1]    # barcode of the matching put-in event
    error_match = pair[2] if len(pair) == 3 else ''   # wrongly matched barcode

    day, hms = getout_fold.strip('_').split('-')
    getout_time = int(hms)

    # (time, folder name) of the latest candidate found so far
    input_best, error_best = None, None
    for pathname in os.listdir(basepath):
        # folders ending with '_' are take-out events, not put-in candidates
        if pathname.endswith('_'):
            continue
        if os.path.isfile(os.path.join(basepath, pathname)):
            continue
        infold = pathname.split('_')
        if len(infold) != 2:
            continue

        # Unrelated folders whose name is not '<day>-<hms>' are skipped
        # instead of aborting the whole scan.
        try:
            day1, hms1 = infold[0].split('-')
            folder_time = int(hms1)
        except ValueError:
            continue

        if day1 != day or folder_time >= getout_time:
            continue

        if infold[1] == relvt_barcode and (input_best is None or folder_time > input_best[0]):
            input_best = (folder_time, pathname)

        if error_match and infold[1] == error_match and (error_best is None or folder_time > error_best[0]):
            error_best = (folder_time, pathname)

    # the candidate closest in time before the take-out event is the one used
    inputpath = os.path.join(basepath, input_best[1]) if input_best else ''
    errorpath = os.path.join(basepath, error_best[1]) if error_best else ''
    getoutpath = os.path.join(basepath, getout_fold)

    return getoutpath, inputpath, errorpath
|
||
|
||
|
||
def _get_event_path(evtpath):
    """Resolve an event path whose basename may be just a folder-name prefix.

    Returns ``<dirname>/<first entry in dirname starting with basename>``, or
    ``''`` when ``evtpath`` is empty, has no basename (e.g. ends with a path
    separator), or nothing in its directory matches.
    """
    if not evtpath:
        return ''
    basepath, eventname = os.path.split(evtpath)
    # An empty name would match every entry (every string starts with ''),
    # silently selecting an arbitrary folder.
    if not eventname:
        return ''
    for filename in os.listdir(basepath or '.'):
        if filename.startswith(eventname):
            return os.path.join(basepath, filename)
    return ''


def _load_tracking_event(evtpath):
    """Load one event folder and render its tracking output for both cameras.

    Returns:
        (drawn, subimgs): two-element lists indexed by camera id (0: rear,
        1: front). ``drawn[c]`` is the ``draw_tracking_boxes`` result and
        ``subimgs[c]`` the ``get_subimgs`` result for camera ``c``, both fed
        with that camera's frames and its ``<c>_tracking_output.data`` boxes.
    """
    drawn, subimgs = [], []
    for cam, cam_imgs in enumerate(read_tracking_imgs(evtpath)):
        data_path = os.path.join(evtpath, f'{cam}_tracking_output.data')
        boxes, _ = read_tracking_output(data_path)
        drawn.append(draw_tracking_boxes(cam_imgs, boxes))
        subimgs.append(get_subimgs(cam_imgs, boxes))
    return drawn, subimgs


def save_tracking_imgpairs(pairs, savepath):
    """Draw and save the tracking results of a matched event pair / triple.

    Args:
        pairs: ``(getout, input)`` or ``(getout, input, error)`` event paths.
            The basename of each path may be a prefix of the real folder
            name; it is resolved to the first matching entry of its parent
            directory. Empty strings (unknown events) are skipped.
        savepath: output root. Frames with drawn boxes go to
            ``savepath/pairs/<savedir>/eventpairs`` and track sub-images to
            ``savepath/pairs/<savedir>/subimgpairs``, where ``savedir`` joins
            the resolved event folder names with '+'. File names are
            ``<getout|input|errMatch>_<cam>_<fid>.png`` and
            ``<getout|input|errMatch>_<cam>_<fid>_<bid>.png``.
    """
    getoutpath = _get_event_path(pairs[0])
    inputpath = _get_event_path(pairs[1])
    errorpath = _get_event_path(pairs[2]) if len(pairs) == 3 else ''

    # Events that were actually found, with the file-name prefix of each.
    events = [(prefix, path)
              for prefix, path in (('getout', getoutpath),
                                   ('input', inputpath),
                                   ('errMatch', errorpath))
              if path]
    if not events:
        return

    savedir = '+'.join(os.path.basename(path) for _, path in events)
    entpairs = os.path.join(savepath, 'pairs', savedir, 'eventpairs')
    subimgpairs = os.path.join(savepath, 'pairs', savedir, 'subimgpairs')
    os.makedirs(entpairs, exist_ok=True)
    os.makedirs(subimgpairs, exist_ok=True)

    for prefix, path in events:
        drawn, subimgs = _load_tracking_event(path)
        # each camera uses its own frames and boxes (the error event's
        # camera-1 crops used to be built from camera-0 data)
        for cam in (0, 1):
            for fid, img in drawn[cam]:
                cv2.imwrite(os.path.join(entpairs, f'{prefix}_{cam}_{fid}.png'), img)
            for fid, bid, img in subimgs[cam]:
                cv2.imwrite(os.path.join(subimgpairs, f'{prefix}_{cam}_{fid}_{bid}.png'), img)
|
||
|
||
|
||
def one2n_deleted(all_list):
    """Evaluate 1:n matching results of deletedBarcode-format records.

    Each record is a dict with 'SeqDir' (take-out event folder), 'Deleted'
    (the barcode actually removed), 'barcode' (candidate barcodes) and
    'similarity' (one comma-separated string of floats per candidate). The
    candidate with the highest similarity is taken as the matched barcode.
    Records without any similarity values are skipped.

    Returns:
        corrpairs: [(seqdir, deleted)] for correct matches.
        errpairs: [(seqdir, deleted, matched_barcode)] for wrong matches.
        correct_similarity, err_similarity: best similarity of each record in
            the corresponding group.
    """
    corrpairs, errpairs, correct_similarity, err_similarity = [], [], [], []

    for s_list in all_list:
        seqdir = s_list['SeqDir'].strip()
        delete = s_list['Deleted'].strip()
        barcodes = [s.strip() for s in s_list['barcode']]

        parsed = [[float(v.strip()) for v in simil.split(',')]
                  for simil in s_list['similarity']]
        if not parsed:
            continue  # no candidates: nothing to evaluate (max() would fail)

        # Use the 3rd (front) value when every candidate provides it, else
        # the 1st value, so list indices stay aligned with `barcodes`.
        if all(len(ss) == 3 for ss in parsed):
            similarity = [ss[2] for ss in parsed]
        else:
            similarity = [ss[0] for ss in parsed]

        best = max(similarity)
        matched_barcode = barcodes[similarity.index(best)]
        if matched_barcode == delete:
            corrpairs.append((seqdir, delete))
            correct_similarity.append(best)
        else:
            errpairs.append((seqdir, delete, matched_barcode))
            err_similarity.append(best)

    return corrpairs, errpairs, correct_similarity, err_similarity
|
||
|
||
|
||
|
||
def one2n_return(all_list):
    """Evaluate 1:n matching results of returnGoods-format records.

    Each record is a dict with 'SeqDir' (take-out event folder), 'Deleted'
    (the barcode actually taken out), 'barcode' (candidate barcodes),
    'event' (candidate put-in event names, ending in ``_<barcode>``) and
    'similarity' (one comma-separated string of floats per candidate).
    Records without any similarity values are skipped.

    Returns:
        corrpairs: [(seqdir, matched_event)] for correct matches.
        errpairs: [(seqdir, expected_event, matched_event)] for wrong
            matches; ``expected_event`` is the highest-similarity candidate
            event whose name ends with the deleted barcode, or '' if none.
        corr_similarity, err_similarity: best similarity of each record in
            the corresponding group.
    """
    corrpairs, corr_similarity, errpairs, err_similarity = [], [], [], []

    for s_list in all_list:
        seqdir = s_list['SeqDir'].strip()
        delete = s_list['Deleted'].strip()
        barcodes = [s.strip() for s in s_list['barcode']]
        events = [s.strip() for s in s_list['event']]

        # ---- read similarity values
        parsed = [[float(v.strip()) for v in simil.split(',')]
                  for simil in s_list['similarity']]
        if not parsed:
            continue  # no candidates: nothing to evaluate (max() would fail)

        # Use the 3rd (front) value when every candidate provides it, else
        # the 1st value, so list indices stay aligned with barcodes/events.
        if all(len(ss) == 3 for ss in parsed):
            similarity = [ss[2] for ss in parsed]
        else:
            similarity = [ss[0] for ss in parsed]

        best = max(similarity)
        index = similarity.index(best)
        matched_barcode = barcodes[index]

        if matched_barcode == delete:
            corrpairs.append((seqdir, events[index]))
            corr_similarity.append(best)
        else:
            # Among the candidate events carrying the deleted barcode, the one
            # with the highest similarity is the event that should have matched.
            candidates = [i for i, name in enumerate(events)
                          if name.split('_')[-1] == delete]
            if candidates:
                input_event = events[max(candidates, key=lambda k: similarity[k])]
            else:
                input_event = ''

            errpairs.append((seqdir, input_event, events[index]))
            err_similarity.append(best)

    return corrpairs, errpairs, corr_similarity, err_similarity
|
||
|
||
|
||
def test_rpath_deleted(del_bfile=r'\\192.168.1.28\share\测试_202406\709\deletedBarcode.txt',
                       basepath=r'\\192.168.1.28\share\测试_202406\709',
                       savepath=r'D:\DetectTracking\contrast\result',
                       saveimgs=True):
    """Resolve and optionally save the event folders of wrong 1:n matches.

    For 1:n result files in deletedBarcode.txt format; files in
    returnGoods.txt format do not need this function.

    Args:
        del_bfile: deletedBarcode.txt result file.
        basepath: directory holding the event folders.
        savepath: output root passed to ``save_tracking_imgpairs``.
        saveimgs: when True, save the tracking images of every triple.

    Returns:
        List of (take-out path, put-in path, wrongly matched path) tuples as
        returned by ``get_contrast_paths`` (put-in / error may be '').
    """
    relative_paths = []

    # 1. read the deletedBarcode file
    all_list = read_deletedBarcode_file(del_bfile)

    # 2. evaluate the algorithm; errpairs are (take-out, deleted, wrong match)
    _, errpairs, _, _ = one2n_deleted(all_list)

    # 3. build the (take-out, put-in & deleted, wrong match) folder paths and
    #    save the corresponding track images
    for errpair in errpairs:
        getout_path, input_path, error_path = get_contrast_paths(errpair, basepath)

        pairs = (getout_path, input_path, error_path)
        relative_paths.append(pairs)

        print(input_path)

        if saveimgs:
            save_tracking_imgpairs(pairs, savepath)

    return relative_paths
|
||
|
||
def test_rpath_return(return_bfile=r'\\192.168.1.28\share\测试_202406\1101\images\returnGoods.txt',
                      basepath=r'\\192.168.1.28\share\测试_202406\1101\images',
                      savepath=r'D:\DetectTracking\contrast\result'):
    """Save track images for every 1:n result of a returnGoods.txt file.

    Correct matches are saved as (take-out, put-in) pairs; wrong matches as
    (take-out, expected put-in, wrongly matched put-in) triples. Event names
    from the file are joined onto ``basepath``.

    Args:
        return_bfile: returnGoods.txt result file.
        basepath: directory holding the event folders.
        savepath: output root passed to ``save_tracking_imgpairs``.
    """
    all_list = read_returnGoods_file(return_bfile)
    corrpairs, errpairs, _, _ = one2n_return(all_list)

    for getout_evt, input_evt in corrpairs:
        pairs = (os.path.join(basepath, getout_evt),
                 os.path.join(basepath, input_evt))
        save_tracking_imgpairs(pairs, savepath)

    for getout_evt, input_evt, error_evt in errpairs:
        # one2n_return yields '' when no candidate event carries the deleted
        # barcode; joined onto basepath that would not name any event folder.
        if not input_evt:
            print(f'{getout_evt}: no put-in event with the deleted barcode, skipped')
            continue

        pairs = (os.path.join(basepath, getout_evt),
                 os.path.join(basepath, input_evt),
                 os.path.join(basepath, error_evt))
        save_tracking_imgpairs(pairs, savepath)
|
||
|
||
|
||
def test_one2n(fpath=r'\\192.168.1.28\share\测试_202406\1108_展厅模型v800测试',
               savepath=r'\\192.168.1.28\share\测试_202406\deletedBarcode\illustration'):
    """1:n performance test: PR curve and similarity histogram per file.

    Supports both txt formats: returnGoods.txt and deletedBarcode.txt. A
    "Total" entry over all records of all files is evaluated as well; when it
    mixes both formats, each format is evaluated with its own routine and the
    similarities are merged.

    Args:
        fpath: a result file, or a folder containing several such txt files
            (alternative: r'\\192.168.1.28\share\测试_202406\deletedBarcode\other').
        savepath: folder where the ``<name>_pr.png`` / ``<name>_hist.png``
            plots are saved (created if missing).
    """
    if os.path.isdir(fpath):
        filepaths = [os.path.join(fpath, f) for f in os.listdir(fpath)
                     if f.find('.txt') > 0
                     and (f.find('deletedBarcode') >= 0 or f.find('returnGoods') >= 0)]
    elif os.path.isfile(fpath):
        filepaths = [fpath]
    else:
        return

    os.makedirs(savepath, exist_ok=True)

    BarLists, blists = {}, []
    for pth in filepaths:
        file = Path(pth).stem
        if file.find('returnGoods') >= 0:
            blist = read_returnGoods_file(pth)
        elif file.find('deletedBarcode') >= 0:
            blist = read_deletedBarcode_file(pth)
        else:
            continue  # unknown format (possible when fpath is a single file)

        BarLists.update({file: blist})
        blists.extend(blist)

    if len(blists):
        BarLists.update({"Total": blists})

    for file, blist in BarLists.items():
        # Split by format so that each list is evaluated with its own routine
        # and never reuses the results of a previous iteration.
        deleted = [b for b in blist if b['filetype'] == "deletedBarcode"]
        returned = [b for b in blist if b['filetype'] == "returnGoods"]

        correct_similarity, err_similarity = [], []
        if deleted:
            _, _, cs, es = one2n_deleted(deleted)
            correct_similarity.extend(cs)
            err_similarity.extend(es)
        if returned:
            _, _, cs, es = one2n_return(returned)
            correct_similarity.extend(cs)
            err_similarity.extend(es)

        if not correct_similarity and not err_similarity:
            continue  # nothing to plot for this list

        recall, prec, ths = compute_recall_precision(err_similarity, correct_similarity)

        plt1 = show_recall_prec(recall, prec, ths)
        plt1.xlabel(f'threshold, Num: {len(blist)}')
        plt1.savefig(os.path.join(savepath, file + '_pr.png'))

        plt2 = showHist(err_similarity, correct_similarity)
        # Save before show(): after a blocking show() returns the figure may
        # already be closed, and saving then writes a blank image.
        plt2.savefig(os.path.join(savepath, file + '_hist.png'))
        plt2.show()
|
||
|
||
|
||
|
||
if __name__ == '__main__':
    # Choose the analysis to run by (un)commenting one of the calls below.
    # test_one2n()             # 1:n PR curves / histograms for the txt result files

    test_rpath_return()        # returnGoods.txt: save track images of matched events

    # test_rpath_deleted()     # deletedBarcode.txt: save track images of wrong matches

    # Alternative: run both and only print any exception instead of aborting.
    # try:
    #     test_rpath_return()
    #     test_rpath_deleted()
    # except Exception as e:
    #     print(e)
|
||
|
||
|