Files
detecttracking/contrast/one2n_contrast.py
王庆刚 8bbee310ba bakeup
2024-11-25 18:05:08 +08:00

499 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
Created on Sat Jul ** 14:07:25 2024
现场测试精度、召回率分析程序,是 feat_select.py 的简化版,
但支持循环计算并输出总的pr曲线
@author: ym
"""
import os.path
import shutil
import numpy as np
import matplotlib.pyplot as plt
import cv2
from pathlib import Path
import sys
sys.path.append(r"D:\DetectTracking")
from tracking.utils.plotting import Annotator, colors
from tracking.utils.read_data import extract_data, read_deletedBarcode_file, read_tracking_output, read_returnGoods_file
from tracking.utils.plotting import draw_tracking_boxes, get_subimgs
from contrast.utils.tools import showHist, show_recall_prec, compute_recall_precision
# =============================================================================
# def read_tracking_output(filepath):
# boxes = []
# feats = []
# with open(filepath, 'r', encoding='utf-8') as file:
# for line in file:
# line = line.strip() # 去除行尾的换行符和可能的空白字符
#
# if not line:
# continue
#
# if line.endswith(','):
# line = line[:-1]
#
# data = np.array([float(x) for x in line.split(",")])
# if data.size == 9:
# boxes.append(data)
# if data.size == 256:
# feats.append(data)
#
# return np.array(boxes), np.array(feats)
# =============================================================================
def read_tracking_imgs(imgspath):
    '''
    Load the YOLO input frames stored in ``imgspath`` (640x512 according to
    the original note) and split them by camera.

    Only ``.jpg`` files whose stem has exactly four '_'-separated fields are
    used: the first field is the camera id, the last one the integer frame id.

    Returns:
        (imgs_0, imgs_1): images of camera '0' (rear) and camera '1'
        (front), each ordered by ascending frame id.
    '''
    frames = {'0': [], '1': []}  # camera id -> [(frame_id, image), ...]
    for fname in os.listdir(imgspath):
        stem, suffix = os.path.splitext(fname)
        fields = stem.split('_')
        if not (len(fields) == 4 and suffix == ".jpg"):
            continue
        cam, frame_id = fields[0], int(fields[-1])
        # The image is decoded for every matching file, even for unknown camera ids.
        image = cv2.imread(os.path.join(imgspath, fname))
        if cam in frames:
            frames[cam].append((frame_id, image))

    def by_frame_id(items):
        # Order the images by their frame id (numpy argsort ordering).
        if not items:
            return []
        order = np.argsort(np.array([fid for fid, _ in items]))
        return [items[k][1] for k in order]

    return by_frame_id(frames['0']), by_frame_id(frames['1'])
# =============================================================================
# def draw_tracking_boxes(imgs, tracks):
# '''tracks: [x1, y1, x2, y2, track_id, score, cls, frame_index, box_index]
# 0 1 2 3 4 5 6 7 8
# 关键imgs中的次序和 track 中的 fid 对应
# '''
# subimgs = []
# for *xyxy, tid, conf, cls, fid, bid in tracks:
# label = f'id:{int(tid)}_{int(cls)}_{conf:.2f}'
#
# annotator = Annotator(imgs[int(fid-1)].copy())
# if cls==0:
# color = colors(int(cls), True)
# elif tid>0 and cls!=0:
# color = colors(int(tid), True)
# else:
# color = colors(19, True) # 19为调色板的最后一个元素
#
# pt2 = [p/2 for p in xyxy]
# annotator.box_label(pt2, label, color=color)
# img0 = annotator.result()
#
# subimgs.append(img0)
#
# return subimgs
# =============================================================================
def get_contrast_paths(pair, basepath):
    '''
    Locate the event folders that belong to one take-out record.

    Parameters:
        pair: (getout_fold, relvt_barcode[, error_match])
            getout_fold   -- folder name of the take-out event, "<day>-<hms>"
                             (leading/trailing '_' are ignored)
            relvt_barcode -- barcode of the put-in event it should match
            error_match   -- barcode it was wrongly matched to (optional)
        basepath: directory holding the event folders, named
            "<day>-<hms>_<barcode>".

    Returns:
        (getoutpath, inputpath, errorpath). inputpath / errorpath are the
        folders carrying the requested barcode on the same day with the
        latest time strictly before the take-out event, or '' if none exists.
        Folders whose names do not follow "<day>-<hms>_<barcode>" are ignored.

    Raises:
        ValueError: if ``pair`` does not have 2 or 3 elements, or the take-out
            folder name is not "<day>-<hms>".
    '''
    if len(pair) not in (2, 3):
        raise ValueError("pair: seqdir, delete, barcodes")
    getout_fold = pair[0]      # folder of the take-out event
    relvt_barcode = pair[1]    # barcode of the put-in event it should match
    error_match = pair[2] if len(pair) == 3 else ''  # wrongly matched barcode

    day, hms = getout_fold.strip('_').split('-')
    getout_time = int(hms)

    input_candidates, error_candidates = [], []  # (time, folder name)
    for pathname in os.listdir(basepath):
        if pathname.endswith('_'):
            continue
        if os.path.isfile(os.path.join(basepath, pathname)):
            continue
        infold = pathname.split('_')
        if len(infold) != 2:
            continue
        # Unrelated folders (e.g. "result_x") would otherwise crash the
        # "<day>-<hms>" unpacking or the int() conversion; skip them.
        try:
            day1, hms1 = infold[0].split('-')
            folder_time = int(hms1)
        except ValueError:
            continue
        if day1 != day or folder_time >= getout_time:
            continue
        if infold[1] == relvt_barcode:
            input_candidates.append((folder_time, pathname))
        if error_match and infold[1] == error_match:
            error_candidates.append((folder_time, pathname))

    # The put-in event closest in time before the take-out is taken as the
    # one it should (or did wrongly) match. Equal times with the same
    # barcode would mean identical folder names, so max() has no real ties.
    inputpath = os.path.join(basepath, max(input_candidates)[1]) if input_candidates else ''
    errorpath = os.path.join(basepath, max(error_candidates)[1]) if error_candidates else ''
    getoutpath = os.path.join(basepath, getout_fold)
    return getoutpath, inputpath, errorpath
def save_tracking_imgpairs(pairs, savepath):
    '''
    Render and save the tracking results of a matched event pair/triple.

    Parameters:
        pairs: (getout, input[, error]) event paths. Each is
            "<basepath>/<event name prefix>"; the actual event folder is the
            first entry of <basepath> whose name starts with that prefix.
        savepath: root output directory. Results are written to
            savepath/pairs/<savedir>/eventpairs  -- frames with boxes drawn
            savepath/pairs/<savedir>/subimgpairs -- per-box sub-images
            where <savedir> joins the found event folder names with '+'.

    Events whose folder cannot be found (empty name, missing parent, no
    matching entry) are skipped; if none is found nothing is written.
    '''
    def get_event_path(evtpath):
        # First entry of evtpath's parent starting with its basename, or ''.
        basepath, eventname = os.path.split(evtpath)
        # An empty name would match every entry (startswith('') is True) and
        # silently pick an arbitrary folder; a missing parent would raise.
        if not eventname or not os.path.isdir(basepath):
            return ''
        for filename in os.listdir(basepath):
            if filename.startswith(eventname):
                return os.path.join(basepath, filename)
        return ''

    def load_event(evtpath):
        # Read both cameras' frames and tracking output of one event folder;
        # return (drawn frames per camera, sub-images per camera).
        imgs_0, imgs_1 = read_tracking_imgs(evtpath)
        boxes_0, _ = read_tracking_output(os.path.join(evtpath, '0_tracking_output.data'))
        boxes_1, _ = read_tracking_output(os.path.join(evtpath, '1_tracking_output.data'))
        drawn = (draw_tracking_boxes(imgs_0, boxes_0),
                 draw_tracking_boxes(imgs_1, boxes_1))
        # Camera-1 sub-images must come from camera-1 frames and boxes.
        crops = (get_subimgs(imgs_0, boxes_0),
                 get_subimgs(imgs_1, boxes_1))
        return drawn, crops

    # Role prefixes used in the saved file names.
    roles = [('getout', pairs[0]), ('input', pairs[1])]
    if len(pairs) == 3:
        roles.append(('errMatch', pairs[2]))

    events = []  # (role, event folder, drawn frames, sub-images)
    for role, evtpath in roles:
        folder = get_event_path(evtpath)
        if folder:
            drawn, crops = load_event(folder)
            events.append((role, folder, drawn, crops))
    if not events:
        return

    savedir = '+'.join(os.path.basename(folder) for _, folder, _, _ in events)
    entpairs = os.path.join(savepath, 'pairs', savedir, 'eventpairs')
    subimgpairs = os.path.join(savepath, 'pairs', savedir, 'subimgpairs')
    os.makedirs(entpairs, exist_ok=True)
    os.makedirs(subimgpairs, exist_ok=True)

    for role, _, drawn, crops in events:
        for cam, frames in enumerate(drawn):
            for fid, img in frames:
                cv2.imwrite(os.path.join(entpairs, f'{role}_{cam}_{fid}.png'), img)
        for cam, subimgs in enumerate(crops):
            for fid, bid, img in subimgs:
                cv2.imwrite(os.path.join(subimgpairs, f'{role}_{cam}_{fid}_{bid}.png'), img)
def one2n_deleted(all_list):
    '''
    Evaluate 1:N matching for deletedBarcode.txt-style records.

    Each record supplies 'SeqDir', 'Deleted', a 'barcode' list and a parallel
    'similarity' list of comma-separated number strings. The third value of
    each similarity entry is used when at least one entry has three values;
    otherwise the first value is used. The barcode with the highest score
    (first one on ties) is the matched barcode.

    Returns:
        corrpairs          -- [(seqdir, deleted)] for correct matches
        errpairs           -- [(seqdir, deleted, matched_barcode)] otherwise
        correct_similarity -- winning score of each correct match
        err_similarity     -- winning score of each wrong match
    '''
    corrpairs, errpairs = [], []
    correct_similarity, err_similarity = [], []
    for record in all_list:
        seqdir = record['SeqDir'].strip()
        deleted = record['Deleted'].strip()
        candidates = [code.strip() for code in record['barcode']]

        parsed = [[float(v.strip()) for v in entry.split(',')]
                  for entry in record['similarity']]
        front = [vals[2] for vals in parsed if len(vals) == 3]
        scores = front if front else [vals[0] for vals in parsed]

        best = max(scores)
        winner = candidates[scores.index(best)]
        if winner == deleted:
            corrpairs.append((seqdir, deleted))
            correct_similarity.append(best)
        else:
            errpairs.append((seqdir, deleted, winner))
            err_similarity.append(best)
    return corrpairs, errpairs, correct_similarity, err_similarity
def one2n_return(all_list):
    '''
    Evaluate 1:N matching for returnGoods.txt-style records.

    Each record supplies 'SeqDir', 'Deleted', parallel 'barcode' / 'event'
    lists and a parallel 'similarity' list of comma-separated number strings.
    The third value of each similarity entry is used when at least one entry
    has three values; otherwise the first value is used.

    Returns:
        corrpairs       -- [(seqdir, matched_event)] for correct matches
        errpairs        -- [(seqdir, expected_event, matched_event)] otherwise,
                           where expected_event is the most similar event whose
                           name ends with "_<Deleted>" ('' if there is none)
        corr_similarity -- winning score of each correct match
        err_similarity  -- winning score of each wrong match
    '''
    corrpairs, corr_similarity, errpairs, err_similarity = [], [], [], []
    for s_list in all_list:
        seqdir = s_list['SeqDir'].strip()
        delete = s_list['Deleted'].strip()
        barcodes = [s.strip() for s in s_list['barcode']]
        events = [s.strip() for s in s_list['event']]

        # Read the similarity values.
        similarity_comp, similarity_front = [], []
        for simil in s_list['similarity']:
            ss = [float(s.strip()) for s in simil.split(',')]
            similarity_comp.append(ss[0])
            if len(ss) == 3:
                similarity_front.append(ss[2])
        similarity = similarity_front if similarity_front else similarity_comp

        best = max(similarity)
        index = similarity.index(best)
        if barcodes[index] == delete:
            corrpairs.append((seqdir, events[index]))
            corr_similarity.append(best)
        else:
            # Expected put-in event: the most similar event carrying the
            # deleted barcode (first on ties). No sentinel threshold, so a
            # score of -1 or lower is still selected.
            idx = [i for i, name in enumerate(events) if name.split('_')[-1] == delete]
            input_event = events[max(idx, key=similarity.__getitem__)] if idx else ''
            errpairs.append((seqdir, input_event, events[index]))
            err_similarity.append(best)
    return corrpairs, errpairs, corr_similarity, err_similarity
def test_rpath_deleted():
    '''
    Driver for deletedBarcode.txt-format 1:N result files (returnGoods.txt
    files do not need this function).

    1. Read the deletedBarcode file.
    2. Evaluate matching and collect (take-out, deleted, wrongly matched)
       triples.
    3. For every wrong match, locate the take-out / put-in / wrongly matched
       event folders, print the put-in path and, if ``saveimgs`` is set, save
       the tracking images of that triple.
    '''
    del_bfile = r'\\192.168.1.28\share\测试_202406\709\deletedBarcode.txt'
    basepath = r'\\192.168.1.28\share\测试_202406\709'
    savepath = r'D:\DetectTracking\contrast\result'
    saveimgs = True

    all_list = read_deletedBarcode_file(del_bfile)
    _, errpairs, _, _ = one2n_deleted(all_list)

    for errpair in errpairs:
        GetoutPath, InputPath, ErrorPath = get_contrast_paths(errpair, basepath)
        print(InputPath)
        if not saveimgs:
            continue
        # Save every triple (not only the last one). Without a put-in folder
        # there is nothing to contrast; an empty error path is left out
        # rather than passed on as ''.
        if not (GetoutPath and InputPath):
            continue
        if ErrorPath:
            pairs = (GetoutPath, InputPath, ErrorPath)
        else:
            pairs = (GetoutPath, InputPath)
        save_tracking_imgpairs(pairs, savepath)
def test_rpath_return():
    '''
    Driver for returnGoods.txt-format 1:N result files: save the tracking
    images of every correctly matched (take-out, put-in) pair and every
    wrongly matched (take-out, expected put-in, matched put-in) triple.
    '''
    return_bfile = r'\\192.168.1.28\share\测试_202406\1101\images\returnGoods.txt'
    basepath = r'\\192.168.1.28\share\测试_202406\1101\images'
    savepath = r'D:\DetectTracking\contrast\result'

    all_list = read_returnGoods_file(return_bfile)
    corrpairs, errpairs, _, _ = one2n_return(all_list)

    for getout_evt, input_evt in corrpairs:
        pairs = (os.path.join(basepath, getout_evt), os.path.join(basepath, input_evt))
        save_tracking_imgpairs(pairs, savepath)

    for getout_evt, input_evt, error_evt in errpairs:
        # one2n_return reports '' when no put-in event carries the deleted
        # barcode; os.path.join(basepath, '') ends with a separator and the
        # event lookup would then match an arbitrary folder, so skip it.
        if not input_evt:
            print(f'no input event for {getout_evt}, skipped')
            continue
        pairs = (os.path.join(basepath, getout_evt),
                 os.path.join(basepath, input_evt),
                 os.path.join(basepath, error_evt))
        save_tracking_imgpairs(pairs, savepath)
def test_one2n():
    '''
    1:N performance test.

    Supports both txt formats (returnGoods.txt, deletedBarcode.txt).
        fpath:    a txt file, or a folder containing several such txt files
        savepath: where the PR curves and histograms are saved

    One PR curve and one histogram are produced per file, plus a "Total"
    over all files. For "Total" (or any list mixing both formats) the
    records of each format are evaluated with their own routine and the
    similarities are combined.
    '''
    # fpath = r'\\192.168.1.28\share\测试_202406\deletedBarcode\other'  # deletedBarcode.txt
    fpath = r'\\192.168.1.28\share\测试_202406\1108_展厅模型v800测试'  # returnGoods.txt
    savepath = r'\\192.168.1.28\share\测试_202406\deletedBarcode\illustration'

    if os.path.isdir(fpath):
        filepaths = [os.path.join(fpath, f) for f in os.listdir(fpath)
                     if f.find('.txt') > 0
                     and (f.find('deletedBarcode') >= 0 or f.find('returnGoods') >= 0)]
    elif os.path.isfile(fpath):
        filepaths = [fpath]
    else:
        return
    os.makedirs(savepath, exist_ok=True)

    BarLists, blists = {}, []
    for pth in filepaths:
        file = str(Path(pth).stem)
        # returnGoods wins when a name contains both markers.
        if 'returnGoods' in file:
            blist = read_returnGoods_file(pth)
        elif 'deletedBarcode' in file:
            blist = read_deletedBarcode_file(pth)
        else:
            continue  # unknown format: nothing to evaluate
        BarLists[file] = blist
        blists.extend(blist)
    if blists:
        BarLists["Total"] = blists

    for file, blist in BarLists.items():
        # Evaluate each format with its own routine (the original tested the
        # combined list instead of this file's list for returnGoods).
        deleted = [b for b in blist if b['filetype'] == "deletedBarcode"]
        returned = [b for b in blist if b['filetype'] == "returnGoods"]
        correct_similarity, err_similarity = [], []
        for records, evaluate in ((deleted, one2n_deleted), (returned, one2n_return)):
            if records:
                _, _, corr, err = evaluate(records)
                correct_similarity.extend(corr)
                err_similarity.extend(err)
        if not (correct_similarity or err_similarity):
            continue

        recall, prec, ths = compute_recall_precision(err_similarity, correct_similarity)
        plt1 = show_recall_prec(recall, prec, ths)
        plt1.xlabel(f'threshold, Num: {len(blist)}')
        plt1.savefig(os.path.join(savepath, file + '_pr.png'))

        plt2 = showHist(err_similarity, correct_similarity)
        # Save before show(): once the interactive window is closed a later
        # savefig may write an empty figure.
        plt2.savefig(os.path.join(savepath, file + '_hist.png'))
        plt2.show()
        plt.close('all')  # avoid accumulating open figures across files
if __name__ == '__main__':
    # Choose the driver matching the result-file format:
    #   test_one2n()          -> PR curves / histograms for all txt files
    #   test_rpath_deleted()  -> deletedBarcode.txt
    #   test_rpath_return()   -> returnGoods.txt (active)
    test_rpath_return()