更新 detacttracking
This commit is contained in:
250
detecttracking/tracking/utils/read_pipeline_data.py
Normal file
250
detecttracking/tracking/utils/read_pipeline_data.py
Normal file
@ -0,0 +1,250 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Created on Tue May 21 15:25:23 2024
|
||||
读取 Pipeline 各模块的数据,主代码由 马晓慧 完成
|
||||
|
||||
@author: ieemoo-zl003
|
||||
"""
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
# 替换为你的目录路径
|
||||
files_path = 'D:/contrast/dataset/1_to_n/709/20240709-112658_6903148351833/'
|
||||
|
||||
def str_to_float_arr(s):
|
||||
# 移除字符串末尾的逗号(如果存在)
|
||||
if s.endswith(','):
|
||||
s = s[:-1]
|
||||
|
||||
# 使用split()方法分割字符串,然后将每个元素转化为float
|
||||
float_array = np.array([float(x) for x in s.split(",")])
|
||||
return float_array
|
||||
|
||||
def extract_tracker_input_boxes_feats(file_name):
|
||||
boxes = []
|
||||
feats = []
|
||||
with open(file_name, 'r', encoding='utf-8') as file:
|
||||
for line in file:
|
||||
line = line.strip() # 去除行尾的换行符和可能的空白字符
|
||||
|
||||
# 跳过空行
|
||||
if not line:
|
||||
continue
|
||||
|
||||
# 检查是否以'box:'或'feat:'开始
|
||||
if line.find("box:") >= 0 and line.find("output_box:") < 0:
|
||||
box = line[line.find("box:") + 4:].strip()
|
||||
boxes.append(str_to_float_arr(box)) # 去掉'box:'并去除可能的空白字符
|
||||
|
||||
if line.find("feat:") >= 0:
|
||||
feat = line[line.find("feat:") + 5:].strip()
|
||||
feats.append(str_to_float_arr(feat)) # 去掉'box:'并去除可能的空白字符
|
||||
|
||||
return np.array(boxes), np.array(feats)
|
||||
|
||||
def find_string_in_array(arr, target):
|
||||
"""
|
||||
在字符串数组中找到目标字符串对应的行(索引)。
|
||||
|
||||
参数:
|
||||
arr -- 字符串数组
|
||||
target -- 要查找的目标字符串
|
||||
|
||||
返回:
|
||||
目标字符串在数组中的索引。如果未找到,则返回-1。
|
||||
"""
|
||||
tg = [float(t) for k, t in enumerate(target.split(',')) if k<4][:4]
|
||||
for i, st in enumerate(arr):
|
||||
st = [float(s) for k, s in enumerate(target.split(',')) if k<4][:4]
|
||||
|
||||
if st == tg:
|
||||
return i
|
||||
|
||||
# if st[:20] == target[:20]:
|
||||
# return i
|
||||
return -1
|
||||
|
||||
def find_samebox_in_array(arr, target):
|
||||
|
||||
for i, st in enumerate(arr):
|
||||
if all(st[:4] == target[:4]):
|
||||
return i
|
||||
return -1
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def extract_tracker_output_boxes_feats(read_file_name):
|
||||
|
||||
input_boxes, input_feats = extract_tracker_input_boxes_feats(read_file_name)
|
||||
|
||||
boxes = []
|
||||
feats = []
|
||||
with open(read_file_name, 'r', encoding='utf-8') as file:
|
||||
for line in file:
|
||||
line = line.strip() # 去除行尾的换行符和可能的空白字符
|
||||
|
||||
# 跳过空行
|
||||
if not line:
|
||||
continue
|
||||
|
||||
# 检查是否以'output_box:'开始
|
||||
if line.find("output_box:") >= 0:
|
||||
box = str_to_float_arr(line[line.find("output_box:") + 11:].strip())
|
||||
boxes.append(box) # 去掉'output_box:'并去除可能的空白字符
|
||||
index = find_samebox_in_array(input_boxes, box)
|
||||
if index >= 0:
|
||||
# feat_f = str_to_float_arr(input_feats[index])
|
||||
feat_f = input_feats[index]
|
||||
norm_f = np.linalg.norm(feat_f)
|
||||
feat_f = feat_f / norm_f
|
||||
feats.append(feat_f)
|
||||
return input_boxes, input_feats, np.array(boxes), np.array(feats)
|
||||
|
||||
def extract_tracking_output_boxes_feats(read_file_name):
|
||||
tracker_boxes, tracker_feats, input_boxes, input_feats = extract_tracker_output_boxes_feats(read_file_name)
|
||||
boxes = []
|
||||
feats = []
|
||||
|
||||
tracking_flag = False
|
||||
with open(read_file_name, 'r', encoding='utf-8') as file:
|
||||
for line in file:
|
||||
line = line.strip() # 去除行尾的换行符和可能的空白字符
|
||||
|
||||
# 跳过空行
|
||||
if not line:
|
||||
continue
|
||||
|
||||
if tracking_flag:
|
||||
if line.find("tracking_") >= 0:
|
||||
tracking_flag = False
|
||||
else:
|
||||
box = str_to_float_arr(line)
|
||||
boxes.append(box)
|
||||
index = find_samebox_in_array(input_boxes, box)
|
||||
if index >= 0:
|
||||
feats.append(input_feats[index])
|
||||
# 检查是否以tracking_'开始
|
||||
if line.find("tracking_") >= 0:
|
||||
tracking_flag = True
|
||||
|
||||
assert(len(tracker_boxes)==len(tracker_feats)), "Error at Yolo output"
|
||||
assert(len(input_boxes)==len(input_feats)), "Error at tracker output"
|
||||
assert(len(boxes)==len(feats)), "Error at tracking output"
|
||||
|
||||
return tracker_boxes, tracker_feats, input_boxes, input_feats, np.array(boxes), np.array(feats)
|
||||
|
||||
def read_tracking_input(datapath):
|
||||
with open(datapath, 'r') as file:
|
||||
lines = file.readlines()
|
||||
|
||||
data = []
|
||||
for line in lines:
|
||||
data.append([s for s in line.split(',') if len(s)>=3])
|
||||
# data.append([float(s) for s in line.split(',') if len(s)>=3])
|
||||
|
||||
# data = np.array(data, dtype = np.float32)
|
||||
try:
|
||||
data = np.array(data, dtype = np.float32)
|
||||
except Exception as e:
|
||||
data = np.array([], dtype = np.float32)
|
||||
print('DataError for func: read_tracking_input()')
|
||||
|
||||
|
||||
return data
|
||||
|
||||
|
||||
|
||||
def read_tracker_input(datapath):
|
||||
with open(datapath, 'r') as file:
|
||||
lines = file.readlines()
|
||||
Videos = []
|
||||
FrameBoxes, FrameFeats = [], []
|
||||
boxes, feats = [], []
|
||||
|
||||
timestamp = []
|
||||
t1 = None
|
||||
for line in lines:
|
||||
if line.find('CameraId') >= 0:
|
||||
t = int(line.split(',')[1].split(':')[1])
|
||||
timestamp.append(t)
|
||||
|
||||
if len(boxes) and len(feats):
|
||||
FrameBoxes.append(np.array(boxes, dtype = np.float32))
|
||||
FrameFeats.append(np.array(feats, dtype = np.float32))
|
||||
boxes, feats = [], []
|
||||
|
||||
if t1 and t - t1 > 1e3:
|
||||
Videos.append((FrameBoxes, FrameFeats))
|
||||
FrameBoxes, FrameFeats = [], []
|
||||
t1 = int(line.split(',')[1].split(':')[1])
|
||||
|
||||
if line.find('box') >= 0:
|
||||
box = line.split(':', )[1].split(',')[:-1]
|
||||
boxes.append(box)
|
||||
|
||||
|
||||
if line.find('feat') >= 0:
|
||||
feat = line.split(':', )[1].split(',')[:-1]
|
||||
feats.append(feat)
|
||||
|
||||
FrameBoxes.append(np.array(boxes, dtype = np.float32))
|
||||
FrameFeats.append(np.array(feats, dtype = np.float32))
|
||||
Videos.append((FrameBoxes, FrameFeats))
|
||||
|
||||
# TimeStamp = np.array(timestamp, dtype = np.int64)
|
||||
# DimesDiff = np.diff((TimeStamp))
|
||||
# sorted_indices = np.argsort(TimeStamp)
|
||||
# TimeStamp_sorted = TimeStamp[sorted_indices]
|
||||
# DimesDiff_sorted = np.diff((TimeStamp_sorted))
|
||||
|
||||
return Videos
|
||||
|
||||
|
||||
|
||||
def main():
|
||||
files_path = 'D:/contrast/dataset/1_to_n/709/20240709-112658_6903148351833/'
|
||||
|
||||
# 遍历目录下的所有文件和目录
|
||||
for filename in os.listdir(files_path):
|
||||
# 构造完整的文件路径
|
||||
file_path = os.path.join(files_path, filename)
|
||||
if os.path.isfile(file_path) and filename.find("track.data")>0:
|
||||
tracker_boxes, tracker_feats, tracking_boxes, tracking_feats, output_boxes, output_feats = extract_tracking_output_boxes_feats(file_path)
|
||||
|
||||
print("Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user