Files
detecttracking/tracking/utils/read_pipeline_data.py
2024-07-18 17:52:12 +08:00

250 lines
7.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
Created on Tue May 21 15:25:23 2024
读取 Pipeline 各模块的数据,主代码由 马晓慧 完成
@author: ieemoo-zl003
"""
import os
import numpy as np
# Replace with your own directory path.
# Module-level default; main() assigns its own local variable of the same name.
files_path = 'D:/contrast/dataset/1_to_n/709/20240709-112658_6903148351833/'
def str_to_float_arr(s):
    """Parse a comma-separated string of numbers into a 1-D float array.

    A single trailing comma, if present, is dropped before splitting.

    Args:
        s: Text such as ``"1.0,2.5,3"`` or ``"1.0,2.5,3,"``.

    Returns:
        ``numpy.ndarray`` holding the parsed values in their original order.

    Raises:
        ValueError: If any field cannot be converted by ``float`` (this
            includes an empty string or an empty field between commas).
    """
    # Only one trailing comma is removed; anything else is left for float().
    text = s[:-1] if s.endswith(',') else s
    values = list(map(float, text.split(",")))
    return np.array(values)
def extract_tracker_input_boxes_feats(file_name):
    """Collect the ``box:`` and ``feat:`` records from a pipeline data file.

    The file is read line by line as UTF-8. Blank lines are skipped. A line
    that contains ``box:`` but not ``output_box:`` contributes one box; a
    line that contains ``feat:`` contributes one feature vector. In both
    cases the text after the first occurrence of the tag is parsed with
    ``str_to_float_arr``.

    Args:
        file_name: Path of the text file to read.

    Returns:
        Tuple ``(boxes, feats)`` of ``numpy.ndarray`` objects built from the
        parsed rows, in file order.
    """
    box_rows, feat_rows = [], []
    with open(file_name, 'r', encoding='utf-8') as fh:
        for raw in fh:
            text = raw.strip()  # drop the newline and surrounding blanks
            if not text:
                continue
            # "output_box:" also contains "box:", so it is excluded explicitly.
            if "box:" in text and "output_box:" not in text:
                payload = text.partition("box:")[2].strip()
                box_rows.append(str_to_float_arr(payload))
            if "feat:" in text:
                payload = text.partition("feat:")[2].strip()
                feat_rows.append(str_to_float_arr(payload))
    return np.array(box_rows), np.array(feat_rows)
def find_string_in_array(arr, target):
    """Find the first string in *arr* matching *target* on its leading values.

    Two strings match when their first four comma-separated fields are
    numerically equal after conversion with ``float``; any further fields
    are ignored.

    Args:
        arr: Iterable of comma-separated numeric strings.
        target: Comma-separated numeric string to look for.

    Returns:
        Index of the first matching element of *arr*, or -1 if none matches.

    Raises:
        ValueError: If one of the first four fields of *target*, or of an
            element of *arr* that is examined, is not a valid float.
    """
    def _leading_values(text):
        # Only the first four fields take part in the comparison.
        return [float(tok) for tok in text.split(',')[:4]]

    wanted = _leading_values(target)
    for i, st in enumerate(arr):
        # Parse the candidate itself; the previous code re-parsed `target`
        # here, so every comparison succeeded and index 0 was always returned.
        if _leading_values(st) == wanted:
            return i
    return -1
def find_samebox_in_array(arr, target):
    """Return the index of the first row of *arr* sharing *target*'s box.

    Rows are compared on their first four entries only, using element-wise
    ``==`` followed by ``all``.

    Args:
        arr: Sequence of rows (e.g. a 2-D ``numpy.ndarray``).
        target: Row to look for; only ``target[:4]`` is used.

    Returns:
        Index of the first matching row, or -1 when no row matches.
    """
    hits = (
        idx for idx, row in enumerate(arr)
        if all(row[:4] == target[:4])
    )
    return next(hits, -1)
def extract_tracker_output_boxes_feats(read_file_name):
    """Read ``output_box:`` records and pair them with input features.

    The input boxes/features are first loaded with
    ``extract_tracker_input_boxes_feats``. The file is then scanned again:
    every line containing ``output_box:`` yields one output box. If an input
    box with the same first four values exists, that input's feature vector,
    divided by its ``np.linalg.norm``, is appended to the output features.
    Output boxes without a matching input box get no feature, so the two
    output arrays may differ in length.

    Args:
        read_file_name: Path of the pipeline data file.

    Returns:
        Tuple ``(input_boxes, input_feats, output_boxes, output_feats)`` of
        ``numpy.ndarray`` objects.
    """
    input_boxes, input_feats = extract_tracker_input_boxes_feats(read_file_name)
    out_boxes, out_feats = [], []
    marker = "output_box:"
    with open(read_file_name, 'r', encoding='utf-8') as fh:
        for raw in fh:
            text = raw.strip()
            # Blank lines and lines without the tag carry nothing for us.
            if not text or marker not in text:
                continue
            box = str_to_float_arr(text.partition(marker)[2].strip())
            out_boxes.append(box)
            index = find_samebox_in_array(input_boxes, box)
            if index < 0:
                continue
            matched = input_feats[index]
            out_feats.append(matched / np.linalg.norm(matched))
    return input_boxes, input_feats, np.array(out_boxes), np.array(out_feats)
def extract_tracking_output_boxes_feats(read_file_name):
    """Read the boxes that follow a ``tracking_`` marker and attach features.

    The first two stages are loaded with
    ``extract_tracker_output_boxes_feats``. The file is then scanned again:
    after the first line containing ``tracking_`` has been seen, every later
    non-blank line that does not itself contain ``tracking_`` is parsed as a
    box with ``str_to_float_arr``. Each such box is looked up (by its first
    four values) among the second-stage boxes, and the matching second-stage
    feature is collected.

    Args:
        read_file_name: Path of the pipeline data file.

    Returns:
        Six-tuple of the first-stage boxes and features, the second-stage
        boxes and features, and the boxes and features collected here (the
        last two as ``numpy.ndarray``).

    Raises:
        AssertionError: If the boxes and features of any stage differ in
            length.
    """
    in_boxes, in_feats, out_boxes, out_feats = extract_tracker_output_boxes_feats(read_file_name)
    trk_boxes, trk_feats = [], []
    collecting = False
    with open(read_file_name, 'r', encoding='utf-8') as fh:
        for raw in fh:
            text = raw.strip()
            if not text:
                continue
            if "tracking_" in text:
                # Marker lines are never parsed as boxes. Once the first one
                # is seen, collection stays enabled for the rest of the file.
                collecting = True
                continue
            if not collecting:
                continue
            box = str_to_float_arr(text)
            trk_boxes.append(box)
            index = find_samebox_in_array(out_boxes, box)
            if index >= 0:
                trk_feats.append(out_feats[index])
    assert len(in_boxes) == len(in_feats), "Error at Yolo output"
    assert len(out_boxes) == len(out_feats), "Error at tracker output"
    assert len(trk_boxes) == len(trk_feats), "Error at tracking output"
    return in_boxes, in_feats, out_boxes, out_feats, np.array(trk_boxes), np.array(trk_feats)
def read_tracking_input(datapath):
    """Load a comma-separated numeric text file as a float32 array.

    Each line is split on commas, and only fields at least three characters
    long are kept (this discards the trailing newline field, and also any
    short field). The resulting rows are converted to ``numpy.float32``.

    Args:
        datapath: Path of the text file to read.

    Returns:
        ``numpy.ndarray`` of dtype float32. If the conversion fails for any
        reason, a message is printed and an empty float32 array is returned.
    """
    with open(datapath, 'r') as fh:
        rows = [
            [field for field in row.split(',') if len(field) >= 3]
            for row in fh.readlines()
        ]
    try:
        return np.array(rows, dtype=np.float32)
    except Exception:
        # Best-effort loader: report the problem and hand back empty data.
        fallback = np.array([], dtype=np.float32)
        print('DataError for func: read_tracking_input()')
        return fallback
def read_tracker_input(datapath):
    """Group the box/feat records of a file into frames and videos.

    The file is processed line by line:

    * A line containing ``CameraId`` starts a new frame. Its timestamp is the
      integer after the colon in the second comma-separated field. The boxes
      and features gathered so far are stored as one frame (only when both
      are non-empty). If the timestamp exceeds the previous one by more than
      1e3, the frames gathered so far are closed off as one video.
    * A line containing ``box`` / ``feat`` adds one row to the current
      frame's boxes / features: the text after the first colon, split on
      commas, with the last field dropped.

    After the last line, the pending frame and video are always appended,
    even when they are empty.

    Args:
        datapath: Path of the text file to read.

    Returns:
        List of ``(frame_boxes, frame_feats)`` tuples, one per video; each
        element of the tuple is a list of float32 ``numpy.ndarray`` objects,
        one per frame.
    """
    with open(datapath, 'r') as fh:
        lines = fh.readlines()

    videos = []
    frame_boxes, frame_feats = [], []
    cur_boxes, cur_feats = [], []
    prev_stamp = None
    for line in lines:
        if 'CameraId' in line:
            stamp = int(line.split(',')[1].split(':')[1])

            # NOTE(review): the per-frame buffers are reset only when a frame
            # is actually flushed -- confirm against the original layout.
            if cur_boxes and cur_feats:
                frame_boxes.append(np.array(cur_boxes, dtype=np.float32))
                frame_feats.append(np.array(cur_feats, dtype=np.float32))
                cur_boxes, cur_feats = [], []

            # A gap larger than 1e3 between consecutive timestamps splits
            # the stream into separate videos.
            if prev_stamp and stamp - prev_stamp > 1e3:
                videos.append((frame_boxes, frame_feats))
                frame_boxes, frame_feats = [], []
            prev_stamp = stamp

        if 'box' in line:
            cur_boxes.append(line.split(':')[1].split(',')[:-1])

        if 'feat' in line:
            cur_feats.append(line.split(':')[1].split(',')[:-1])

    # Flush whatever is still pending once the file is exhausted.
    frame_boxes.append(np.array(cur_boxes, dtype=np.float32))
    frame_feats.append(np.array(cur_feats, dtype=np.float32))
    videos.append((frame_boxes, frame_feats))
    return videos
def main():
    """Run the tracking extraction on every matching file of a directory.

    Every regular file in the hard-coded directory whose name contains
    ``track.data`` at an index greater than zero is passed to
    ``extract_tracking_output_boxes_feats``; ``Done`` is printed after each
    processed file.
    """
    files_path = 'D:/contrast/dataset/1_to_n/709/20240709-112658_6903148351833/'
    # Walk over every entry of the directory.
    for filename in os.listdir(files_path):
        # Build the full path of the entry.
        file_path = os.path.join(files_path, filename)
        if not os.path.isfile(file_path):
            continue
        # find() > 0: the name must contain "track.data" but not start with it.
        if not filename.find("track.data") > 0:
            continue
        results = extract_tracking_output_boxes_feats(file_path)
        (tracker_boxes, tracker_feats,
         tracking_boxes, tracking_feats,
         output_boxes, output_feats) = results
        print("Done")
# Run the directory scan only when this file is executed as a script.
if __name__ == "__main__":
    main()