199 lines
5.0 KiB
Python
199 lines
5.0 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""
|
|
Created on Fri Feb 23 11:04:48 2024
|
|
|
|
@author: ym
|
|
"""
|
|
import numpy as np
|
|
import cv2
|
|
from scipy.spatial.distance import cdist
|
|
# from trackers.utils import matching
|
|
|
|
# TracksDict
|
|
def readDict(boxes, TracksDict):
    """Collect the stored feature vector for every row of ``boxes``.

    Parameters
    ----------
    boxes : numpy.ndarray
        2-D array, one detection per row. Column 7 is read as the frame
        index and column 8 as the index of the box inside that frame;
        columns 0-3 are compared against the stored box.
    TracksDict : dict
        Mapping with keys of the form ``"frame_<fid>"``. Each value must
        provide ``"feats"`` and ``"boxes"`` sequences indexable by the box
        index; the stored boxes must support ``[:4].astype(int)``.

    Returns
    -------
    numpy.ndarray
        ``float32`` array built from the looked-up features, in the same
        order as the rows of ``boxes``.

    Raises
    ------
    AssertionError
        If the integer-truncated first four values of a stored box differ
        from those of the corresponding row in ``boxes``.
    KeyError, IndexError
        If a frame key or box index is missing from ``TracksDict``.
    """
    feats = []
    for row in boxes:
        fid, bid = int(row[7]), int(row[8])
        frame = TracksDict[f"frame_{fid}"]

        # Sanity check that the row and the stored entry describe the same
        # box. Raised explicitly (rather than via ``assert``) so the check
        # is still performed when Python runs with -O.
        stored_box = frame["boxes"][bid]
        if not (stored_box[:4].astype(int) == row[:4].astype(int)).all():
            raise AssertionError(f"Please check: frame_{fid}")

        feats.append(frame["feats"][bid])

    return np.asarray(feats, dtype=np.float32)
|
|
|
|
|
|
|
|
def _cosine_similarity(u, v):
    """Return ``1 - cosine distance`` between two 1-D vectors."""
    return 1 - cdist(u[None, :], v[None, :], 'cosine')[0, 0]


def _center_box_iou(abox, bbox):
    """Return the IoU of two box rows whose first four values are x, y, w, h.

    ``x, y`` are treated as the box centre: corners are ``x -/+ w/2`` and
    ``y -/+ h/2``.
    """
    xa1, ya1 = abox[0] - abox[2] / 2, abox[1] - abox[3] / 2
    xa2, ya2 = abox[0] + abox[2] / 2, abox[1] + abox[3] / 2

    xb1, yb1 = bbox[0] - bbox[2] / 2, bbox[1] - bbox[3] / 2
    xb2, yb2 = bbox[0] + bbox[2] / 2, bbox[1] + bbox[3] / 2

    inter_w = np.clip(np.minimum(xb2, xa2) - np.maximum(xb1, xa1), 0, None)
    inter_h = np.clip(np.minimum(yb2, ya2) - np.maximum(yb1, ya1), 0, None)
    inter = inter_w * inter_h

    # The small epsilon keeps the division defined for zero-area boxes.
    union = abox[2] * abox[3] + bbox[2] * bbox[3] - inter + 1e-6
    return inter / union


def _adjacent_pairs(afids, bfids, max_gap):
    """Return ``(a_idx, b_idx)`` index pairs of temporally adjacent boxes.

    Both frame-index arrays are merged and sorted; every two neighbours in
    that order that come from different tracks and are at most ``max_gap``
    frames apart yield one pair of row indices into the respective tracks.
    """
    n_a = afids.size
    fids = np.concatenate((afids, bfids), axis=0)
    from_b = np.arange(fids.size) >= n_a  # False -> track a, True -> track b
    order = np.argsort(fids)

    pairs = []
    for idx1, idx2 in zip(order[:-1], order[1:]):
        if from_b[idx1] == from_b[idx2] or fids[idx2] - fids[idx1] > max_gap:
            continue
        a_idx, b_idx = (idx2, idx1) if from_b[idx1] else (idx1, idx2)
        # Indices of track b are offset by len(a) in the merged array.
        pairs.append((a_idx, b_idx - n_a))
    return pairs


def track_equal_track(atrack, btrack, *, mean_emb_thresh=0.66,
                      pair_emb_thresh=0.5, iou_thresh=0.5, max_frame_gap=3):
    """Decide whether two tracks may be treated as the same track.

    Three checks are applied in order:

    1. The tracks must not share any frame index.
    2. The cosine similarity of their mean feature vectors must not be
       below ``mean_emb_thresh``.
    3. For every pair of boxes (one per track) that are neighbours in the
       merged frame order and at most ``max_frame_gap`` frames apart, the
       feature cosine similarity must exceed ``pair_emb_thresh`` and the
       box IoU must exceed ``iou_thresh``. When no such pair exists this
       check passes vacuously.

    Parameters
    ----------
    atrack, btrack : object
        Objects exposing ``boxes`` and ``features`` arrays. ``boxes`` rows
        are laid out as
        ``[x, y, w, h, track_id, score, cls, frame_index, box_index]``;
        ``features`` is indexed by the same row index as ``boxes``.
    mean_emb_thresh : float, keyword-only
        Minimum similarity of the mean features (default 0.66).
    pair_emb_thresh : float, keyword-only
        Per-pair feature similarity that must be exceeded (default 0.5).
    iou_thresh : float, keyword-only
        Per-pair IoU that must be exceeded (default 0.5).
    max_frame_gap : int, keyword-only
        Largest frame distance for which a pair is checked (default 3).

    Returns
    -------
    bool
        ``True`` if all checks pass, otherwise ``False``.
    """
    # boxes: [x, y, w, h, track_id, score, cls, frame_index, box_index]
    #         0  1  2  3     4        5     6       7            8
    aboxes, bboxes = atrack.boxes, btrack.boxes
    afeat, bfeat = atrack.features, btrack.features

    # 1. Tracks that overlap in time cannot be the same track.
    afids = aboxes[:, 7].astype(np.int_)
    bfids = bboxes[:, 7].astype(np.int_)
    if not set(afids.tolist()).isdisjoint(bfids.tolist()):
        return False

    # 2. Similarity of the track-level (mean) features.
    mean_sim = _cosine_similarity(np.mean(afeat, axis=0),
                                  np.mean(bfeat, axis=0))
    if mean_sim < mean_emb_thresh:
        return False

    # 3. Appearance and spatial agreement where the tracks meet in time.
    for a_idx, b_idx in _adjacent_pairs(afids, bfids, max_frame_gap):
        emb = _cosine_similarity(afeat[a_idx, :], bfeat[b_idx, :])
        iou = _center_box_iou(aboxes[a_idx, :], bboxes[b_idx, :])
        # Written as "not (x > t)" so NaN values fail the check.
        if not (emb > pair_emb_thresh and iou > iou_thresh):
            return False

    return True
|
|
|
|
|
|
|
|
def track_equal_str(atrack, btrack):
    """Return ``True`` when ``atrack == btrack`` is truthy, else ``False``."""
    return bool(atrack == btrack)
|
|
|
|
|
|
def merge_track(Residual, equal=None):
    """Partition ``Residual`` into groups of mutually matching items.

    The first remaining item becomes the representative of a new group;
    every later item for which ``equal(representative, item)`` is truthy
    joins that group, the others are kept (in order) for the next round.
    Items are only compared with the representative, never with other
    members of the group.

    Parameters
    ----------
    Residual : iterable
        Items to group. It is copied, not modified.
    equal : callable, optional
        Two-argument predicate deciding whether an item belongs to the
        representative's group. Defaults to ``track_equal_str``.

    Returns
    -------
    list of list
        Groups in order of first appearance; each group keeps the input
        order of its members. Empty input yields ``[]``.
    """
    if equal is None:
        equal = track_equal_str

    groups = []
    pending = list(Residual)
    while pending:
        head, candidates = pending[0], pending[1:]
        group = [head]
        pending = []
        for candidate in candidates:
            if equal(head, candidate):
                group.append(candidate)
            else:
                pending.append(candidate)
        groups.append(group)
    return groups
|
|
|
|
def main():
    """Demo: group a fixed list of labels with ``merge_track`` and print
    the input list followed by the grouped result."""
    labels = list('abcdabcbcd')
    grouped = merge_track(labels)

    for item in (labels, grouped):
        print(item)
|
|
|
|
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|
|
|
|
|
|
|
|
|
# =============================================================================
|
|
# for i, atrack in enumerate(input_list):
|
|
# cur_list = []
|
|
# cur_list.append(atrack)
|
|
# del input_list[i]
|
|
#
|
|
# for j, btrack in enumerate(input_list):
|
|
# if track_equal(atrack, btrack):
|
|
# cur_list.append(btrack)
|
|
# del input_list[j]
|
|
#
|
|
# out_list.append(cur_list)
|
|
# =============================================================================
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|