Files
detecttracking/tracking/module_analysis.py
2024-07-22 20:16:45 +08:00

414 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
Created on Thu May 30 14:03:03 2024
现场测试性能分析 (performance analysis of on-site test data)
@author: ym
"""
import os
import cv2
import numpy as np
from pathlib import Path
import sys
sys.path.append(r"D:\DetectTracking")
from tracking.utils.plotting import Annotator, colors, draw_tracking_boxes
from tracking.utils import Boxes, IterableSimpleNamespace, yaml_load
from tracking.trackers import BOTSORT, BYTETracker
from tracking.dotrack.dotracks_back import doBackTracks
from tracking.dotrack.dotracks_front import doFrontTracks
from tracking.utils.drawtracks import plot_frameID_y2, draw_all_trajectories
from tracking.utils.read_data import extract_data, read_deletedBarcode_file, read_tracking_output
from contrast_analysis import contrast_analysis
from tracking.utils.annotator import TrackAnnotator
# NOTE(review): presumably frame width/height in pixels; not read by the code visible here — confirm.
W, H = 1024, 1280
# Alternative value: 'back'. NOTE(review): not read by the code visible here.
Mode = 'front'
# Image file extensions that read_imgs() accepts as frame images.
ImgFormat = ['.jpg', '.jpeg', '.png', '.bmp']
def video2imgs(path, max_videos=1000):
    """Split every video under ``<path>/videos`` into per-frame PNG images.

    Frames of ``<path>/videos/<name>.<ext>`` are written to
    ``<path>/<name>/<name>_<i>.png`` with ``i`` starting at 1. A video whose
    output directory already exists is skipped, so the function can be re-run
    to resume an interrupted conversion.

    Args:
        path: root directory that contains a ``videos`` sub-directory.
        max_videos: stop after this many videos have been converted in this
            call. The default of 1000 is the previously hard-coded limit.
            ``None`` means no limit.

    Returns:
        The number of videos converted in this call.
    """
    vpath = os.path.join(path, "videos")
    converted = 0
    for filename in os.listdir(vpath):
        if max_videos is not None and converted >= max_videos:
            break
        stem = os.path.splitext(filename)[0]
        imgdir = os.path.join(path, stem)
        # An existing output directory means this video was already handled.
        if os.path.exists(imgdir):
            continue
        os.mkdir(imgdir)

        cap = cv2.VideoCapture(os.path.join(vpath, filename))
        nframes = 0
        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                nframes += 1
                cv2.imwrite(os.path.join(imgdir, f"{stem}_{nframes}.png"), frame)
        finally:
            # Release the capture handle even if decoding or writing fails.
            cap.release()
        print(filename + f": {nframes}")
        converted += 1
    return converted
def draw_boxes(datapath=r'D:\datasets\ym\videos_test\20240530\1_tracker_inout(1).data',
               videopath=r"D:\datasets\ym\videos_test\20240530\134458234-1cd970cf-f8b9-4e80-9c2e-7ca3eec83b81-1_seek0.10415589124891511.mp4",
               outdir=r"D:\datasets\ym\videos_test\20240530\result\int8_front"):
    """Draw the recorded detection boxes of the first video onto its frames.

    Args:
        datapath: tracker input/output data file. Only ``VideosData[0][0]`` is
            used, as the per-frame box arrays. Each box row is unpacked as
            ``[*xyxy, conf, cls]``.
        videopath: video whose frames correspond one-to-one to those box arrays.
        outdir: directory for the annotated frames ``<frame number>.png``,
            numbered from 1. It is created if missing.

    Stops at the end of the video or when the box records run out, whichever
    comes first.
    """
    # NOTE(review): ``read_tracker_input`` is not imported at module level, so
    # this function used to fail with NameError. It is assumed to live next to
    # the other readers in tracking.utils.read_data — confirm.
    from tracking.utils.read_data import read_tracker_input

    VideosData = read_tracker_input(datapath)
    bboxes = VideosData[0][0]
    # cv2.imwrite fails silently when the target directory does not exist.
    os.makedirs(outdir, exist_ok=True)

    cap = cv2.VideoCapture(videopath)
    i = 0
    try:
        # Bound by len(bboxes): the video may contain more frames than records.
        while i < len(bboxes):
            ret, frame = cap.read()
            if not ret:
                break
            annotator = Annotator(frame.copy(), line_width=3)
            for *xyxy, conf, cls in reversed(bboxes[i]):
                label = f'{int(cls)}: {conf:.2f}'
                annotator.box_label(xyxy, label, color=colors(int(cls), True))
            cv2.imwrite(os.path.join(outdir, f"{i + 1}.png"), annotator.result())
            print(f"Output: {i}")
            i += 1
    finally:
        cap.release()
def read_imgs(imgspath, CamerType):
    """Load one camera's frame images from a directory, ordered by frame ID.

    Only files named ``<camID>_<a>_<b>_<frameID><ext>`` are considered. The
    stem must have exactly four ``_``-separated parts, and the extension
    (compared case-insensitively) must be in ``ImgFormat``. ``camID`` must
    equal ``CamerType``, and ``frameID`` must be an integer. Other files are
    ignored.

    Args:
        imgspath: directory to scan (not recursive).
        CamerType: camera ID string to select, compared with ``camID``.

    Returns:
        List of images as returned by ``cv2.imread``, sorted by ascending
        frame ID. The list is empty if nothing matches.
    """
    imgs, frmIDs = [], []
    for filename in os.listdir(imgspath):
        stem, ext = os.path.splitext(filename)
        parts = stem.split('_')
        if len(parts) != 4 or ext.lower() not in ImgFormat:
            continue
        if parts[0] != CamerType:
            continue  # other camera: do not waste time decoding it
        try:
            frmID = int(parts[-1])
        except ValueError:
            continue  # not a frame image (non-numeric frame ID)
        imgs.append(cv2.imread(os.path.join(imgspath, filename)))
        frmIDs.append(frmID)
    if frmIDs:
        # os.listdir order is arbitrary; sort the images by frame ID.
        order = np.argsort(np.array(frmIDs), kind="stable")
        imgs = [imgs[i] for i in order]
    return imgs
def init_tracker(tracker_yaml=None, bs=1, frame_rate=30):
    """
    Initialize tracker for object tracking during prediction.

    Args:
        tracker_yaml: path of the tracker YAML config. Its ``tracker_type``
            entry selects the tracker class ('bytetrack' or 'botsort').
        bs: batch size. Accepted for interface compatibility; not used.
        frame_rate: frame rate passed to the tracker. The default of 30 is
            the previously hard-coded value.

    Returns:
        The constructed BYTETracker or BOTSORT instance.

    Raises:
        ValueError: if no config path is given.
        KeyError: if the config names an unsupported tracker type.
    """
    if tracker_yaml is None:
        raise ValueError("init_tracker() needs the path of a tracker YAML config")
    TRACKER_MAP = {'bytetrack': BYTETracker, 'botsort': BOTSORT}
    cfg = IterableSimpleNamespace(**yaml_load(tracker_yaml))
    if cfg.tracker_type not in TRACKER_MAP:
        raise KeyError(f"Unsupported tracker_type {cfg.tracker_type!r}; "
                       f"expected one of {sorted(TRACKER_MAP)}")
    return TRACKER_MAP[cfg.tracker_type](args=cfg, frame_rate=frame_rate)
def tracking(bboxes, ffeats, tracker_yaml=r"./trackers/cfg/botsort.yaml"):
    """Re-run the tracker offline over recorded per-frame detections and features.

    Args:
        bboxes: per-frame detection arrays, each wrapped in ``Boxes``.
        ffeats: per-frame feature arrays. Row ``j`` of ``ffeats[k]`` must be
            the feature of detection ``j`` in ``bboxes[k]``; dets and feats
            must correspond strictly.
        tracker_yaml: tracker config file. The default is resolved relative
            to the current working directory, as before.

    Returns:
        TrackBoxes: every frame's tracker output stacked row-wise. Each row is
            [x1, y1, x2, y2, track_id, score, cls, frame_index, box_index].
        TracksDict: ``{"frame_<frame_index>": {"feats": {box_index: feature}}}``
            for each frame that produced at least one track.
    """
    tracker = init_tracker(tracker_yaml)
    # Collect per-frame outputs and concatenate once at the end. The empty
    # float32 seed keeps the same dtype promotion as the old incremental concat.
    chunks = [np.empty((0, 9), dtype=np.float32)]
    TracksDict = {}
    '''========================== Run tracking ============================='''
    for dets, feats in zip(bboxes, ffeats):
        det_tracking = Boxes(dets).cpu().numpy()
        tracks = tracker.update(det_tracking, features=feats)
        # tracks columns: [x1, y1, x2, y2, track_id, score, cls, frame_index, box_index]
        # frame_index could also be replaced by the video's frame ID; box_index
        # stays the row of the detection in dets/feats.
        if not len(tracks):
            continue
        chunks.append(tracks)
        # Key each feature by box_index (column 8) so it can be found from a track row.
        FeatDict = {}
        for track in tracks:
            bid = int(track[8])
            FeatDict[bid] = feats[bid, :]
        frameID = tracks[0, 7]
        # Each track must come from a distinct detection.
        assert len(tracks) == len(FeatDict), f"Please check the func: tracker.update() at frameID({int(frameID)})"
        TracksDict[f"frame_{int(frameID)}"] = {"feats": FeatDict}
    TrackBoxes = np.concatenate(chunks, axis=0)
    return TrackBoxes, TracksDict
def do_tracker_tracking(fpath, save_dir):
    """Re-track the detections stored in ``fpath`` and plot the trajectories.

    The camera type is the part of the file name before the first ``_``.
    For camera '1' (front), the frameID/y2 plot is saved as
    ``<save_dir>/<dir>_1_front_y2.png``. For camera '0' (back), the
    trajectories are drawn by ``draw_all_trajectories`` into ``save_dir``.
    Here ``<dir>`` is the name of the directory that contains ``fpath``.

    Args:
        fpath: track.data file written by the pipeline.
        save_dir: output directory for the trajectory images.
    """
    bboxes, ffeats, trackerboxes, tracker_feat_dict, trackingboxes, tracking_feat_dict = extract_data(fpath)
    tboxes, feats_dict = tracking(bboxes, ffeats)
    CamerType = os.path.basename(fpath).split('_')[0]
    dirname = os.path.split(os.path.split(fpath)[0])[1]
    filename = dirname + '_' + CamerType
    if CamerType == '1':
        vts = doFrontTracks(tboxes, feats_dict)
        vts.classify()
        plt = plot_frameID_y2(vts)
        # Previously saved to the CWD as 'front_y2.png', ignoring save_dir and
        # overwriting itself on every call. The figure is also closed so that
        # repeated calls do not accumulate open figures.
        os.makedirs(save_dir, exist_ok=True)
        plt.savefig(os.path.join(save_dir, filename + '_front_y2.png'))
        plt.close()
    elif CamerType == '0':
        vts = doBackTracks(tboxes, feats_dict)
        vts.classify()
        edgeline = cv2.imread("./shopcart/cart_tempt/edgeline.png")
        draw_all_trajectories(vts, edgeline, save_dir, filename)
    else:
        print("Please check data file!")
def do_tracking(fpath, savedir):
    '''Reproduce and visualise the pipeline outputs recorded in one track.data file.

    fpath: ``<CamerType>_track.data`` file written by the pipeline modules.
        The frame images and ``<CamerType>_tracking_output.data`` are read
        from the same directory.
    savedir: directory that receives the frames with tracker boxes drawn.
        Its parent directory receives the trajectory and algorithm-comparison
        images, plus a ``<prefix>_subimgs`` directory with crops of the
        residual tracks' boxes.
    Both fpath and savedir must be given when analysing a specific video.

    Returns early without writing anything when the tracking_output.data
    file is missing. If the camera type parsed from the file name is neither
    '0' nor '1', it returns after saving the annotated frames.
    '''
    # imgpath is the directory holding the data file and its frame images.
    imgpath, dfname = os.path.split(fpath)
    CamerType = dfname.split('_')[0]
    tracking_output_path = os.path.join(imgpath, CamerType + '_tracking_output.data')
    # Check before the (expensive) extract_data call: nothing to compare without it.
    if not os.path.isfile(tracking_output_path):
        return
    bboxes, ffeats, trackerboxes, tracker_feat_dict, trackingboxes, tracking_feat_dict = extract_data(fpath)
    tracking_output_boxes, _ = read_tracking_output(tracking_output_path)

    # Output directory for the annotated frames.
    save_dir, basename = os.path.split(savedir)
    os.makedirs(savedir, exist_ok=True)
    # Output directory for crops of the boxes belonging to residual tracks.
    subimg_dir = os.path.join(save_dir, basename.split('_')[0] + '_subimgs')
    os.makedirs(subimg_dir, exist_ok=True)

    # Frame images that belong to this track.data file.
    imgs = read_imgs(imgpath, CamerType)

    # Draw the tracker boxes onto the frames. Per the original note, an empty
    # result means the tracker frame count and image count differ; then the
    # raw frames are saved instead.
    imgs_dw = draw_tracking_boxes(imgs, trackerboxes)
    if len(imgs_dw) == 0:
        imgs_dw = list(imgs)
        print(f"fpath: {imgpath}, savedir: {savedir}。Tracker输出的图像数和 imgs 中图像数不相等,无法一一匹配并画框")
    for i, img_dw in enumerate(imgs_dw):
        cv2.imwrite(os.path.join(savedir, CamerType + "_" + f"{i}.png"), img_dw)

    save_dir = Path(savedir).parent
    traj_graphic = basename + '_' + CamerType
    if CamerType == '1':
        vts = doFrontTracks(trackerboxes, tracker_feat_dict)
        vts.classify()
        plt = plot_frameID_y2(vts)
        plt.savefig(str(save_dir.joinpath(f"{traj_graphic}_front_y2.png")))
        plt.close()
        aline = cv2.imread("./shopcart/cart_tempt/board_ftmp_line.png")
    elif CamerType == '0':
        vts = doBackTracks(trackerboxes, tracker_feat_dict)
        vts.classify()
        edgeline = cv2.imread("./shopcart/cart_tempt/edgeline.png")
        img = draw_all_trajectories(vts, edgeline, save_dir, traj_graphic)
        cv2.imwrite(str(save_dir.joinpath(f"{traj_graphic}_show.png")), img)
        aline = cv2.imread("./shopcart/cart_tempt/edgeline.png")
    else:
        # vts / aline would be undefined below. This used to print the message
        # and then crash with UnboundLocalError.
        print("Please check data file!")
        return

    # Save a crop of every box of the residual tracks. Coordinates are halved
    # before cropping, as before; presumably the stored frames are at half the
    # tracker's coordinate resolution — TODO confirm.
    for track in vts.Residual:
        for *xyxy, tid, conf, cls, fid, bid in track.boxes:
            fidx = int(fid - 1)
            if not 0 <= fidx < len(imgs) or imgs[fidx] is None:
                continue  # no image for this frame (e.g. frame-count mismatch)
            # Clamp at 0: negative slice starts would wrap to the far edge.
            x1, y1, x2, y2 = (max(int(v / 2), 0) for v in xyxy)
            subimg = imgs[fidx][y1:y2, x1:x2]
            if subimg.size == 0:
                continue  # cv2.imwrite raises on empty images
            subimg_path = os.path.join(subimg_dir, f'{CamerType}_{int(tid)}_{fidx}_{int(bid)}.png')
            cv2.imwrite(subimg_path, subimg)

    '''================== On-site tracking() algorithm output =================='''
    if aline is None:
        print(f"Cannot read the template image for camera {CamerType}")
        return
    # Left panel: trackingboxes from track.data; right panel: tracking_output.data.
    bline = aline.copy()
    annotator = TrackAnnotator(aline, line_width=2)
    for track in trackingboxes:
        annotator.plotting_track(track)
    aline = annotator.result()

    annotator = TrackAnnotator(bline, line_width=2)
    if not isinstance(tracking_output_boxes, list):
        tracking_output_boxes = [tracking_output_boxes]
    for track in tracking_output_boxes:
        annotator.plotting_track(track)
    bline = annotator.result()

    abimg = np.concatenate((aline, bline), axis=1)
    abH, abW = abimg.shape[:2]
    # Vertical divider between the two panels.
    cv2.line(abimg, (int(abW / 2), 0), (int(abW / 2), abH), (128, 255, 128), 2)
    cv2.imwrite(str(save_dir.joinpath(f"{traj_graphic}_Alg.png")), abimg)
def main_loop(del_barcode_file=r'\\192.168.1.28\share\测试_202406\deletedBarcode\bad\deletedBarcode_0719_4.txt',
              basepath=r'\\192.168.1.28\share\测试_202406\0719\719_4',
              SavePath=r'D:\contrast\dataset\result',
              max_tasks=1):
    '''Run do_tracking() on every track.data file of the on-site test tasks.

    Args:
        del_barcode_file: deleted-barcode log passed to contrast_analysis().
        basepath: folder with the test data.
        SavePath: folder for the results. Each task gets a sub-folder named by
            joining a name token taken from each of its data paths.
        max_tasks: stop after this many tasks. The default of 1 matches the
            old hard-coded ``k == 1`` check; pass None to process every task.
    '''
    prefix = ["getout_", "input_", "error_"]
    # Paths of the performance-test data, one tuple of paths per task.
    relative_paths = contrast_analysis(del_barcode_file, basepath, SavePath)
    for k, tuple_paths in enumerate(relative_paths):
        if max_tasks is not None and k >= max_tasks:
            break
        # Folder name token per path: the last '_'-separated part of the base
        # name, or the first part if the last one is empty.
        namedirs = []
        for data_path in tuple_paths:
            base_name = os.path.basename(data_path).strip().split('_')
            namedirs.append(base_name[-1] if base_name[-1] else base_name[0])
        savepath = os.path.join(SavePath, "_".join(namedirs))
        os.makedirs(savepath, exist_ok=True)

        for path in tuple_paths:
            for filename in os.listdir(path):
                fpath = os.path.join(path, filename)
                if not (os.path.isfile(fpath) and filename.find("track.data") > 0):
                    continue
                # Result prefix: the role of the first task name found in fpath.
                # Only as many names as there are prefixes are considered.
                enent_name = ''
                for i, name in enumerate(namedirs[:len(prefix)]):
                    if fpath.find(name) > 0:
                        enent_name = prefix[i] + name
                        break
                do_tracking(fpath, os.path.join(savepath, enent_name))
def main(fpath=r'\\192.168.1.28\share\测试_202406\0719\719_4\20240719-164209_\0_track.data',
         save_dir=r'D:\contrast\dataset\result\20240719-164209_6971284204320_6902890247777\getout'):
    '''Analyse a single track.data file with do_tracking().

    fpath: data file containing the outputs of every pipeline module.
    save_dir: two-level output location. do_tracking() writes the frame
        sequence for the data file into save_dir itself, and the trajectory
        images into its parent directory.
    The defaults are the paths that used to be hard-coded here.
    '''
    do_tracking(fpath, save_dir)
if __name__ == "__main__":
    try:
        # main()
        main_loop()
    except Exception as e:
        # Printing only the message hid where the failure happened, and the
        # script exited with status 0. Show the traceback and exit non-zero.
        import traceback
        print(f'Error: {e}')
        traceback.print_exc()
        sys.exit(1)