Files
detecttracking/contrast/one2n_contrast.py
王庆刚 5ecc1285d4 update
2024-11-04 18:06:52 +08:00

464 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
Created on Sat Jul ** 14:07:25 2024
现场测试精度、召回率分析程序,是 feat_select.py 的简化版,
但支持循环计算并输出总的pr曲线
@author: ym
"""
import os.path
import shutil
import numpy as np
import matplotlib.pyplot as plt
import cv2
from pathlib import Path
import sys
sys.path.append(r"D:\DetectTracking")
from tracking.utils.plotting import Annotator, colors
from tracking.utils.read_data import extract_data, read_deletedBarcode_file, read_tracking_output, read_returnGoods_file
from tracking.utils.plotting import draw_tracking_boxes
from contrast.utils.tools import showHist, show_recall_prec, compute_recall_precision
# =============================================================================
# def read_tracking_output(filepath):
# boxes = []
# feats = []
# with open(filepath, 'r', encoding='utf-8') as file:
# for line in file:
# line = line.strip() # 去除行尾的换行符和可能的空白字符
#
# if not line:
# continue
#
# if line.endswith(','):
# line = line[:-1]
#
# data = np.array([float(x) for x in line.split(",")])
# if data.size == 9:
# boxes.append(data)
# if data.size == 256:
# feats.append(data)
#
# return np.array(boxes), np.array(feats)
# =============================================================================
def read_tracking_imgs(imgspath):
    '''
    Load the frames of both cameras stored in one event folder.

    Only ``.jpg`` files whose stem splits into exactly four ``_``-separated
    fields are read; the first field is used as the camera id and the last
    one as the integer frame id.

    input:
        imgspath: folder holding the frames (per the original author's note
                  these are the 640x512 images that were fed to YOLO)
    output:
        imgs_0: images of camera '0' (back camera per the original note),
                ordered by ascending frame id
        imgs_1: images of camera '1' (front camera per the original note),
                ordered by ascending frame id
    '''
    frames = {'0': [], '1': []}
    frame_ids = {'0': [], '1': []}

    for fname in os.listdir(imgspath):
        stem, suffix = os.path.splitext(fname)
        fields = stem.split('_')
        if suffix != ".jpg" or len(fields) != 4:
            continue

        cam, fid = fields[0], int(fields[-1])
        # Every matching file is decoded, even for camera ids other than
        # '0'/'1' whose image is then simply dropped.
        image = cv2.imread(os.path.join(imgspath, fname))
        if cam in frames:
            frames[cam].append(image)
            frame_ids[cam].append(fid)

    def _by_frame_id(images, ids):
        # Nothing to reorder for a camera without frames.
        if not ids:
            return images
        return [images[k] for k in np.argsort(np.array(ids))]

    return (_by_frame_id(frames['0'], frame_ids['0']),
            _by_frame_id(frames['1'], frame_ids['1']))
# =============================================================================
# def draw_tracking_boxes(imgs, tracks):
# '''tracks: [x1, y1, x2, y2, track_id, score, cls, frame_index, box_index]
# 0 1 2 3 4 5 6 7 8
# 关键imgs中的次序和 track 中的 fid 对应
# '''
# subimgs = []
# for *xyxy, tid, conf, cls, fid, bid in tracks:
# label = f'id:{int(tid)}_{int(cls)}_{conf:.2f}'
#
# annotator = Annotator(imgs[int(fid-1)].copy())
# if cls==0:
# color = colors(int(cls), True)
# elif tid>0 and cls!=0:
# color = colors(int(tid), True)
# else:
# color = colors(19, True) # 19为调色板的最后一个元素
#
# pt2 = [p/2 for p in xyxy]
# annotator.box_label(pt2, label, color=color)
# img0 = annotator.result()
#
# subimgs.append(img0)
#
# return subimgs
# =============================================================================
def get_contrast_paths(pair, basepath):
    '''
    Locate the folders of a take-out event and of its related put-in events.

    Folder names inside ``basepath`` are expected to look like
    ``<day>-<hms>_<barcode>`` for put-in events and ``<day>-<hms>_`` for
    take-out events. Sub-folders that do not follow this pattern (no ``-``
    in the stamp, non-numeric time, ...) are ignored instead of aborting
    the scan.

    input:
        pair: (seqdir, deleted_barcode) or
              (seqdir, deleted_barcode, wrongly_matched_barcode), where
              seqdir is the take-out folder name
        basepath: folder that contains all event folders
    output:
        getoutpath: path of the take-out event folder
        inputpath: path of the latest put-in folder of the same day, earlier
                   than the take-out, whose barcode equals the deleted
                   barcode; '' if there is none
        errorpath: same as inputpath but for the wrongly matched barcode;
                   '' if there is none or pair has only two elements
    raises:
        AssertionError: pair does not have 2 or 3 elements
        ValueError: the take-out folder name is not of the form day-hms
    '''
    assert len(pair) in (2, 3), "pair: seqdir, delete, barcodes"

    getout_fold = pair[0]       # folder of the take-out event
    relvt_barcode = pair[1]     # barcode of the put-in event it should match
    error_match = pair[2] if len(pair) == 3 else ''   # wrongly matched barcode

    day, hms = getout_fold.strip('_').split('-')
    getout_time = int(hms)

    # (time, folder name) of the latest qualifying put-in event seen so far.
    latest_input, latest_error = None, None
    for pathname in os.listdir(basepath):
        if pathname.endswith('_'):
            continue            # take-out folders are not candidates
        if os.path.isfile(os.path.join(basepath, pathname)):
            continue
        infold = pathname.split('_')
        if len(infold) != 2:
            continue

        stamp = infold[0].split('-')
        if len(stamp) != 2 or stamp[0] != day:
            continue
        try:
            put_time = int(stamp[1])
        except ValueError:
            continue            # not an event folder
        if put_time >= getout_time:
            continue            # only put-ins that happened before the take-out

        if infold[1] == relvt_barcode:
            if latest_input is None or put_time > latest_input[0]:
                latest_input = (put_time, pathname)
        if error_match and infold[1] == error_match:
            if latest_error is None or put_time > latest_error[0]:
                latest_error = (put_time, pathname)

    # The put-in closest in time to the take-out is the one it refers to.
    inputpath = os.path.join(basepath, latest_input[1]) if latest_input else ''
    errorpath = os.path.join(basepath, latest_error[1]) if latest_error else ''
    getoutpath = os.path.join(basepath, getout_fold)

    return getoutpath, inputpath, errorpath
def _annotated_event_frames(eventpath):
    '''Return (back-camera frames, front-camera frames) of one event folder
    with the tracking boxes drawn on them.'''
    imgs_0, imgs_1 = read_tracking_imgs(eventpath)

    # Only the boxes are needed for drawing; the features are discarded.
    boxes_0, _ = read_tracking_output(os.path.join(eventpath, '0_tracking_output.data'))
    boxes_1, _ = read_tracking_output(os.path.join(eventpath, '1_tracking_output.data'))

    return draw_tracking_boxes(imgs_0, boxes_0), draw_tracking_boxes(imgs_1, boxes_1)


def save_tracking_imgpairs(pair, basepath, savepath):
    '''
    Save the annotated tracking frames of a take-out event, of the put-in
    event it should match and, when available, of the wrongly matched
    put-in event.

    Images are written to ``<savepath>/imgpairs/<pair[0]><pair[1]>[_<error
    barcode>]`` as ``input_<cam>_<i>.png``, ``getout_<cam>_<i>.png`` and
    ``errMatch_<cam>_<i>.png`` (cam 0 = back camera, 1 = front camera per
    the original author's note).

    Nothing is written when no matching put-in folder is found. A missing
    error-match folder is not an error: the ``errMatch_*`` images are simply
    omitted.

    input:
        pair: (seqdir, deleted_barcode[, wrongly_matched_barcode])
        basepath: folder holding the raw test data
        savepath: destination root folder
    '''
    getoutpath, inputpath, errorpath = get_contrast_paths(pair, basepath)
    if not inputpath:
        return

    # Load and annotate every event that was actually found, in the order
    # in which the images are written below.
    events = (('input', inputpath), ('getout', getoutpath), ('errMatch', errorpath))
    annotated = [(prefix, _annotated_event_frames(path))
                 for prefix, path in events if path]

    savedir = pair[0] + pair[1]
    if errorpath:
        # The error folder name ends with "_<barcode>".
        savedir = savedir + '_' + errorpath.split('_')[-1]
    foldname = os.path.join(savepath, 'imgpairs', savedir)
    os.makedirs(foldname, exist_ok=True)

    for prefix, per_camera in annotated:
        for cam, frames in enumerate(per_camera):
            for i, img in enumerate(frames):
                cv2.imwrite(os.path.join(foldname, f'{prefix}_{cam}_{i}.png'), img)
def one2n_old(all_list):
    '''
    Evaluate 1:n matching results read from a deletedBarcode-format file.

    For every record the candidate with the highest similarity is taken as
    the match and compared with the deleted barcode. When similarity rows
    carry three comma-separated values, the third one is used; otherwise
    the first one is used. Records without any similarity row cannot be
    scored and are skipped.

    input:
        all_list: list of dicts with keys 'SeqDir', 'Deleted', 'barcode'
                  (list of str) and 'similarity' (list of comma-separated
                  number strings)
    output (note: correct-first order, unlike one2n_new):
        corrpairs: [(seqdir, deleted_barcode), ...]
        errpairs: [(seqdir, deleted_barcode, matched_barcode), ...]
        correct_similarity: best similarity of every correct match
        err_similarity: best similarity of every wrong match
    '''
    corrpairs, errpairs, correct_similarity, err_similarity = [], [], [], []
    for s_list in all_list:
        seqdir = s_list['SeqDir'].strip()
        delete = s_list['Deleted'].strip()
        barcodes = [s.strip() for s in s_list['barcode']]

        similarity_comp, similarity_front = [], []
        for simil in s_list['similarity']:
            ss = [float(s.strip()) for s in simil.split(',')]
            similarity_comp.append(ss[0])
            if len(ss) == 3:
                similarity_front.append(ss[2])

        # NOTE(review): if only some rows carry three values, similarity_front
        # is shorter than barcodes and its indices no longer line up - confirm
        # that the input never mixes both row layouts.
        similarity = similarity_front if similarity_front else similarity_comp
        if not similarity:
            continue            # no candidate: nothing to score

        best = max(similarity)
        matched_barcode = barcodes[similarity.index(best)]
        if matched_barcode == delete:
            corrpairs.append((seqdir, delete))
            correct_similarity.append(best)
        else:
            errpairs.append((seqdir, delete, matched_barcode))
            err_similarity.append(best)

    return corrpairs, errpairs, correct_similarity, err_similarity
def one2n_new(all_list):
    '''
    Evaluate 1:n matching results read from a returnGoods-format file.

    For every record the candidate with the highest similarity is taken as
    the match and compared with the deleted barcode. When similarity rows
    carry three comma-separated values, the third one is used; otherwise
    the first one is used. Records without any similarity row cannot be
    scored and are skipped.

    input:
        all_list: list of dicts with keys 'SeqDir', 'Deleted', 'barcode',
                  'event' (list of event names ending in "_<barcode>") and
                  'similarity' (list of comma-separated number strings)
    output (note: error-first order, unlike one2n_old):
        errpairs: [(seqdir, expected_event, matched_event), ...] where
                  expected_event is the most similar candidate event that
                  carries the deleted barcode, or '' when no candidate
                  event carries it
        corrpairs: [(seqdir, matched_event), ...]
        err_similarity: best similarity of every wrong match
        correct_similarity: best similarity of every correct match
    '''
    corrpairs, correct_similarity, errpairs, err_similarity = [], [], [], []
    for s_list in all_list:
        seqdir = s_list['SeqDir'].strip()
        delete = s_list['Deleted'].strip()
        barcodes = [s.strip() for s in s_list['barcode']]
        events = [s.strip() for s in s_list['event']]

        # Read the similarity values.
        similarity_comp, similarity_front = [], []
        for simil in s_list['similarity']:
            ss = [float(s.strip()) for s in simil.split(',')]
            similarity_comp.append(ss[0])
            if len(ss) == 3:
                similarity_front.append(ss[2])

        # NOTE(review): if only some rows carry three values, similarity_front
        # is shorter than barcodes/events and its indices no longer line up -
        # confirm that the input never mixes both row layouts.
        similarity = similarity_front if similarity_front else similarity_comp
        if not similarity:
            continue            # no candidate: nothing to score

        best = max(similarity)
        index = similarity.index(best)
        if barcodes[index] == delete:
            corrpairs.append((seqdir, events[index]))
            correct_similarity.append(best)
            continue

        # Wrong match: look up the event that should have been matched,
        # i.e. the most similar one among the events of the deleted barcode.
        candidates = [i for i, name in enumerate(events)
                      if name.split('_')[-1] == delete]
        if candidates:
            expected_event = events[max(candidates, key=lambda k: similarity[k])]
        else:
            # Previously an index of -1 silently selected the last event.
            expected_event = ''
        errpairs.append((seqdir, expected_event, events[index]))
        err_similarity.append(best)

    return errpairs, corrpairs, err_similarity, correct_similarity
# def contrast_analysis(del_barcode_file, basepath, savepath, saveimgs=False):
def get_relative_paths(del_barcode_file, basepath, savepath, saveimgs=False):
    '''
    Collect the event folders involved in every wrong 1:n match.

    Only valid for deletedBarcode.txt-format result files; returnGoods.txt
    files must not be passed to this function.

    input:
        del_barcode_file: 1:n result file in deletedBarcode.txt format
        basepath: folder that contains all event folders
        savepath: destination root used when saveimgs is True
        saveimgs: also save the annotated tracking frames of every wrong
                  match
    output:
        relative_paths: [(getout_path, input_path, error_path), ...], one
                        entry per wrong match
    '''
    relative_paths = []

    # 1. Read the deletedBarcode file.
    all_list = read_deletedBarcode_file(del_barcode_file)

    # 2. Evaluate the matches. one2n_old() returns
    #    (corrpairs, errpairs, correct_similarity, err_similarity); the wrong
    #    matches are the SECOND element, each a 3-tuple
    #    (take-out, put-in/deleted barcode, wrongly matched barcode).
    _, errpairs, _, _ = one2n_old(all_list)

    # 3. Resolve the folders of every wrong match and optionally save the
    #    corresponding tracking images.
    for errpair in errpairs:
        GetoutPath, InputPath, ErrorPath = get_contrast_paths(errpair, basepath)
        relative_paths.append((GetoutPath, InputPath, ErrorPath))

        if saveimgs:
            save_tracking_imgpairs(errpair, basepath, savepath)

    return relative_paths
def one2n_test():
    '''
    Plot a precision/recall curve for every 1:n result file and for all of
    them together.

    ``fpath`` may be a single result file or a folder; in a folder every
    ``.txt`` file whose name contains 'deletedBarcode' or 'returnGoods' is
    used. One ``<file stem>_pr.png`` per file plus ``Total_pr.png`` is
    written to ``savepath``.
    '''
    # Alternative input: r'\\192.168.1.28\share\测试_202406\deletedBarcode\other'
    fpath = r'\\192.168.1.28\share\测试_202406\1030\images'
    savepath = r'\\192.168.1.28\share\测试_202406\deletedBarcode\illustration'
    os.makedirs(savepath, exist_ok=True)

    if os.path.isdir(fpath):
        filepaths = [os.path.join(fpath, f) for f in os.listdir(fpath)
                     if f.endswith('.txt')
                     and ('deletedBarcode' in f or 'returnGoods' in f)]
    elif os.path.isfile(fpath):
        filepaths = [fpath]
    else:
        return
    if not filepaths:
        return                  # nothing to analyse

    FileFormat = {}
    BarLists, blists = {}, []
    for pth in filepaths:
        file = str(Path(pth).stem)
        if 'deletedBarcode' in file:
            FileFormat[file] = 'deletedBarcode'
            blist = read_deletedBarcode_file(pth)
        elif 'returnGoods' in file:
            FileFormat[file] = 'returnGoods'
            blist = read_returnGoods_file(pth)
        else:
            return
        BarLists[file] = blist
        blists.extend(blist)
    BarLists["Total"] = blists

    for file, blist in BarLists.items():
        # "Total" has no entry in FileFormat and, like deletedBarcode files,
        # is evaluated with one2n_old().
        if FileFormat.get(file) == 'returnGoods':
            _, _, err_similarity, correct_similarity = one2n_new(blist)
        else:
            # one2n_old() returns the correct-match values BEFORE the
            # error-match values (the opposite of one2n_new()).
            _, _, correct_similarity, err_similarity = one2n_old(blist)

        recall, prec, ths = compute_recall_precision(err_similarity, correct_similarity)

        plt1 = show_recall_prec(recall, prec, ths)
        plt1.xlabel(f'threshold, Num: {len(blist)}')
        plt1.savefig(os.path.join(savepath, file+'_pr.png'))
        # NOTE(review): the figure is not closed here; whether that is
        # needed depends on what show_recall_prec() returns - confirm.
def test_getreltpath():
    '''
    Manual driver for get_relative_paths() on one test-data folder, with
    saving of the annotated image pairs switched on.

    Works for deletedBarcode.txt only, not for returnGoods.txt. Any
    exception is caught and its message is printed.
    '''
    # returnGoods data set, kept for reference (not supported here):
    #   r'\\192.168.1.28\share\测试_202406\1030\images\returnGoods.txt'
    #   r'\\192.168.1.28\share\测试_202406\1030\images'
    source_dir = r'\\192.168.1.28\share\测试_202406\709'
    barcode_file = r'\\192.168.1.28\share\测试_202406\709\deletedBarcode.txt'
    output_dir = r'D:\contrast\dataset\result'

    try:
        get_relative_paths(barcode_file, source_dir, output_dir, True)
    except Exception as e:
        print(f'Error Type: {e}')
# Script entry point: by default run the precision/recall analysis; the
# path-extraction driver is kept below as a commented-out alternative.
if __name__ == '__main__':
    one2n_test()
    # test_getreltpath()