Files
detecttracking/tracking/utils/drawtracks.py
2024-07-18 17:52:12 +08:00

379 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
Created on Mon Jan 15 15:26:38 2024
@author: ym
"""
import numpy as np
import cv2
import os
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from utils.annotator import TrackAnnotator
from utils.plotting import colors
from pathlib import Path
def _scatter_y2(ax, tracks, markers, palette, img_height):
    """Scatter ``img_height - y2`` against frame index for each track onto ``ax``.

    Tracks with ``cls == 0`` use the fixed marker ``'4'``; other tracks pick a
    marker from ``markers`` by track id. The colour is picked from ``palette``
    by track id. Each series is labelled ``ID_<tid>``.
    """
    for track in tracks:
        boxes = track.boxes
        # tid is used as a list index below, so force a Python int (track ids
        # may be stored as floats alongside the box coordinates).
        tid = int(track.tid)
        y2, fids = boxes[:, 3], boxes[:, 7]
        color = palette[tid % len(palette)]
        marker = '4' if track.cls == 0 else markers[tid % len(markers)]
        ax.scatter(fids, img_height - y2, marker=marker, s=50, color=color,
                   label=f"ID_{tid}")


def plot_frameID_y2(vts):
    """Plot the box bottom edge (y2) against frame index for the tracks of ``vts``.

    boxes: [x1, y1, x2, y2, track_id, score, cls, frame_index, box_index]
            0   1   2   3   4         5      6    7            8

    The vertical axis is flipped as ``1280 - y2`` (1280 is a hard-coded image
    height). A dashed blue line marks ``1280 - CART_HIGH_THRESH1``.

    Top panel: every track in ``vts.tracks``.
    Bottom panel: ``vts.Residual`` joined (via ``vts.join_tracks``) with the
    hands in ``vts.Hands`` whose ``isHandStatic`` is false.

    Returns:
        The ``matplotlib.pyplot`` module, so the caller can show/save/close
        the figure.
    """
    markers = ['o', 'v', '^', '<', '>', 's', 'p', 'P', '*', '+', 'x', 'X', 'd', 'D', 'H']
    # Named "palette" so it does not shadow the imported utils.plotting.colors.
    palette = ['b', 'g', 'c', 'm', 'y']
    IMG_HEIGHT = 1280
    CART_HIGH_THRESH1 = 430

    maxfid = max(vts.bboxes[:, 7])
    thresh_y = IMG_HEIGHT - CART_HIGH_THRESH1

    fig = plt.figure(figsize=(16, 12))
    gs = fig.add_gridspec(2, 1, left=0.1, right=0.9, bottom=0.1, top=0.9,
                          wspace=0.05, hspace=0.15)
    ax1 = fig.add_subplot(gs[0, 0])
    ax2 = fig.add_subplot(gs[1, 0])

    hands = [t for t in vts.Hands if not t.isHandStatic]
    joined = vts.join_tracks(vts.Residual, hands)

    for ax, tracks in ((ax1, vts.tracks), (ax2, joined)):
        ax.plot((0, maxfid + 5), (thresh_y, thresh_y), 'b--', linewidth=2)
        _scatter_y2(ax, tracks, markers, palette, IMG_HEIGHT)
        ax.set_ylim([-50, 1350])
        ax.grid(True)
        ax.set_xlim(0, maxfid + 5)
        ax.set_title('y2')
        ax.legend()
    return plt
def draw_all_trajectories(vts, edgeline, save_dir, file, draw5p=False):
    """Render track trajectories over ``edgeline`` and write the images to ``save_dir``.

    Always writes ``{file}_show.png``. It shows two copies of ``edgeline``
    side by side, separated by a vertical green line:
    - left: ``vts.tracks`` drawn with ``drawTrack``;
    - right: ``vts.Residual`` drawn with ``drawTrack``.

    When ``draw5p`` is true, it also writes one ``draw5points`` image per
    track into ``save_dir/trajectory/{file}/``:
    - ``{tid}.png`` for each track in ``vts.tracks``;
    - ``{tid}_.png`` for each track in ``vts.merged_tracks``.

    Args:
        vts: object providing ``tracks``, ``Residual`` and (when ``draw5p``)
            ``merged_tracks``.
        edgeline: background image. It is copied for every drawing and never
            modified.
        save_dir: output directory (str or Path). It is created if missing.
        file: stem used for the output file names.
        draw5p: whether to also write the per-track five-point images.
    """
    save_dir = Path(save_dir)
    # cv2.imwrite does not create directories and only reports failure through
    # its return value, so make sure the target exists.
    save_dir.mkdir(parents=True, exist_ok=True)

    # Centre trajectories: all tracks (left) vs residual tracks (right).
    left = drawTrack(vts.tracks, edgeline.copy())
    right = drawTrack(vts.Residual, edgeline.copy())
    img = np.concatenate((left, right), axis=1)
    H, W = img.shape[:2]
    cv2.line(img, (W // 2, 0), (W // 2, H), (128, 255, 128), 2)
    cv2.imwrite(str(save_dir / f"{file}_show.png"), img)

    if not draw5p:
        return

    # Five-point trajectories, one image per track.
    trackpth = save_dir / "trajectory" / f"{file}"
    trackpth.mkdir(parents=True, exist_ok=True)
    for tracks, suffix in ((vts.tracks, ""), (vts.merged_tracks, "_")):
        for track in tracks:
            img = draw5points(track, edgeline.copy())
            cv2.imwrite(str(trackpth / f"{track.tid}{suffix}.png"), img)
def drawFeatures(allvts, save_dir):
    """Scatter-plot trajectory features of every track in ``allvts``.

    ``track.TrajFeat`` is stacked into an array with one row per track. Column
    3 is plotted on x and column 1 on y. Assuming the layout that
    ``draw5points`` unpacks — (trajlen_min, trajlen_max, trajdist_min,
    trajdist_max, trajlen_rate, trajdist_rate) — this is trajdist_max (x)
    against trajlen_max (y). TODO confirm that this pairing is intended.

    The figure is saved to ``save_dir/scatter.png``, then shown, then closed.
    Nothing is drawn when there are no tracks.

    Args:
        allvts: iterable of objects, each with a ``tracks`` sequence.
        save_dir: output directory (str or Path).
    """
    save_dir = Path(save_dir)
    feats = np.array([track.TrajFeat for vts in allvts for track in vts.tracks])
    if feats.size == 0:
        # Nothing to plot; indexing feats[:, 3] would raise on an empty array.
        return

    fig, ax = plt.subplots()
    ax.scatter(feats[:, 3], feats[:, 1], s=10)
    ax.grid(True)
    # Save BEFORE show(): once an interactive window is closed, pyplot's
    # current figure is gone and a later plt.savefig writes a blank image.
    fig.savefig(save_dir / "scatter.png")
    plt.show()
    plt.close('all')
def _interp_smooth(x, y, kernel):
    """Resample ``y(x)`` onto every integer frame in ``[min(x), max(x)]`` and smooth it.

    ``np.interp`` fills the missing frames linearly (``x`` must be
    increasing). The interior samples are then replaced by a 'valid'
    convolution with the 3-tap ``kernel``, so the two end samples keep their
    interpolated values.

    Returns:
        ``(frames, values)``: ``frames`` is a list of ints and ``values`` is
        a float ndarray.
    """
    frames = list(range(int(min(x)), int(max(x) + 1)))
    values = np.interp(frames, x, y)
    # A 'valid' convolution of n samples with a 3-tap kernel yields n-2 values,
    # which only fits values[1:-1] when n >= 3.
    if len(values) >= 3:
        values[1:-1] = np.convolve(values, kernel, 'valid')
    return frames, values


def drawtracefeat(vts):
    """Plot per-frame trajectory features of the tracks in ``vts`` on a 2x2 grid.

    boxes: [x1, y1, x2, y2, track_id, score, cls, frame_index, box_index]
            0   1   2   3   4         5      6    7            8

    Panels: trajmin (top-left), arearate = Area / max(Area) (top-right),
    trajmax (bottom-left), incartrate (bottom-right, y limited to [-0.1, 1.1]).

    Tracks with ``frnum >= 5`` are resampled to every frame and smoothed with
    the kernel [0.15, 0.7, 0.15]. Tracks with 2-4 frames are plotted raw.
    Shorter tracks are skipped. A track is drawn only if all of these hold:
    - cls is not 0 or 9;
    - posState >= 2;
    - max(trajmin) > 10 and max(trajmax) > 10;
    - some arearate sample exceeds 0.1.

    The author notes these curves are intended for feature extraction and
    classification.

    Returns:
        The ``matplotlib.pyplot`` module, so the caller can show/save/close
        the figure.
    """
    fid = vts.frameid
    fid1, fid2 = min(fid), max(fid)
    fig, axs = plt.subplots(2, 2, figsize=(18, 8))
    kernel = [0.15, 0.7, 0.15]

    any_plotted = False
    for track in vts.tracks:
        boxes = track.boxes
        tid = int(track.tid)
        cls = int(track.cls)
        posState = track.posState
        # trajmin/trajmax are plotted against boxes[1:], while Area/incartrates
        # are plotted against all boxes (presumably one value per frame step
        # vs. per frame — TODO confirm).
        if track.frnum >= 5:
            x11, y11 = _interp_smooth(boxes[1:, 7], track.trajmin, kernel)
            x33, y33 = _interp_smooth(boxes[1:, 7], track.trajmax, kernel)
            x22, y22 = _interp_smooth(boxes[:, 7], track.Area / max(track.Area), kernel)
            x44, y44 = _interp_smooth(boxes[:, 7], track.incartrates, kernel)
        elif track.frnum >= 2:
            x11, y11 = boxes[1:, 7], track.trajmin
            x33, y33 = boxes[1:, 7], track.trajmax
            x22, y22 = boxes[:, 7], track.Area / max(track.Area)
            x44, y44 = boxes[:, 7], track.incartrates
        else:
            continue

        if (cls != 0 and cls != 9 and posState >= 2
                and max(y11) > 10 and max(y33) > 10 and np.any(y22 > 0.1)):
            axs[0, 0].plot(x11, y11, label=f"ID_{tid}")
            axs[0, 1].plot(x22, y22, label=f"ID_{tid}")
            axs[1, 0].plot(x33, y33, label=f"ID_{tid}")
            axs[1, 1].plot(x44, y44, label=f"ID_{tid}")
            any_plotted = True

    # Build legends once, and only when something labelled was drawn.
    if any_plotted:
        for ax in axs.flat:
            ax.legend()

    titles = (('trajmin', 'arearate'), ('trajmax', 'incartrate'))
    for r in range(2):
        for c in range(2):
            ax = axs[r, c]
            ax.grid(True)
            ax.set_xlim(fid1, fid2 + 10)
            ax.set_title(titles[r][c])
    axs[1, 1].set_ylim(-0.1, 1.1)
    return plt
def draw5points(track, img, show_labels=False):
    """Draw the five tracked points of every box in ``track``, plus trajectory diagnostics.

    For each box row, the five (x, y) pairs stored in the 10 columns of
    ``track.cornpoints`` are drawn as coloured circles. The author describes
    these as the centre point and the four corners.

    One point is then selected: index 0 when ``track.imgBorder`` is truthy,
    otherwise the point with the smallest entry in ``track.trajlens``. For
    that point:
    - if its trajectory length exceeds 12, a PCA is fitted to its positions
      and each principal axis is drawn, scaled by its explained-variance
      ratio;
    - the rotated rectangle ``track.trajrects[idx]`` is drawn in green.

    Args:
        track: track object. Always reads ``boxes``, ``cornpoints``,
            ``trajlens``, ``imgBorder`` and ``trajrects``. When
            ``show_labels`` is set, also reads ``tid``, ``cls``,
            ``trajdist``, ``trajmin``, ``TrajFeat`` and ``mwh``.
        img: image that is drawn on in place and also returned.
        show_labels: when True, also write a text panel of trajectory
            features in the top-left corner. Defaults to False, which
            matches the previous behaviour.

    Returns:
        The annotated image.
    """
    boxes = track.boxes
    cornpoints = track.cornpoints
    trajlens = [int(t) for t in track.trajlens]

    # One circle colour per tracked point (columns 2k, 2k+1 of cornpoints).
    point_colors = ((255, 255, 255), (255, 0, 255), (0, 255, 0),
                    (64, 128, 255), (255, 128, 64))
    for i in range(boxes.shape[0]):
        for k, pc in enumerate(point_colors):
            center = (int(cornpoints[i, 2 * k]), int(cornpoints[i, 2 * k + 1]))
            cv2.circle(img, center, 6, pc, 2)

    # Index of the point whose trajectory is analysed below.
    if track.imgBorder:
        idx = 0
    else:
        idx = trajlens.index(min(trajlens))

    # PCA of the selected point's trajectory.
    pca_text = "PCA(variance_ratio) : "
    if trajlens[idx] > 12:
        X = cornpoints[:, 2 * idx:2 * (idx + 1)]
        pca = PCA()
        pca.fit(X)
        pca_text = "PCA(variance_ratio): {:.2f}".format(pca.explained_variance_ratio_[0])
        axis_colors = ((0, 0, 255), (255, 128, 0))
        for comp, var, ac in zip(pca.components_, pca.explained_variance_ratio_, axis_colors):
            # OpenCV drawing functions take points as plain Python int tuples.
            pt1 = tuple(int(v) for v in pca.mean_ - comp * var * 200)
            pt2 = tuple(int(v) for v in pca.mean_ + comp * var * 200)
            cv2.line(img, pt1, pt2, color=ac, thickness=2)

    # Rotated rectangle of the selected point's trajectory.
    rect = track.trajrects[idx]
    # np.int0 (an alias of np.intp) was removed in NumPy 2.0.
    box = cv2.boxPoints(rect).astype(np.intp)
    cv2.drawContours(img, [box], 0, (0, 255, 0), 2)

    if show_labels:
        trajdist = [int(t) for t in track.trajdist]
        trajstd = np.std(track.trajmin) if len(track.trajmin) else 0
        (trajlen_min, trajlen_max, trajdist_min, trajdist_max,
         trajlen_rate, trajdist_rate) = track.TrajFeat
        labels = [
            f"ID: {track.tid}, Class: {track.cls}",
            f"trajlens: {trajlens}, trajlen_min: {int(trajlen_min)}",
            f"trajdist: {trajdist}, trajdist_max: {int(trajdist_max)}",
            "trajlen_min/trajlen_max: {:.2f}/{:.2f} = {:.2f}".format(trajlen_min, trajlen_max, trajlen_rate),
            "trajdist_min/mwh : {:.2f}/{:.2f} = {:.2f}".format(trajdist_min, track.mwh, trajdist_rate),
            "std(trajmin) : {:.2f}".format(trajstd),
            pca_text,
            "Rect W&H&Ratio: {}, {}, {:.2f}".format(
                int(rect[1][0]), int(rect[1][1]), min(rect[1]) / (max(rect[1]) + 0.001)),
            "",
        ]
        label_colors = ((255, 255, 255), (255, 255, 255), (255, 255, 255),
                        (255, 255, 255), (255, 255, 255), (0, 0, 255),
                        (0, 255, 0), (255, 51, 255), (102, 178, 255))
        _, h = cv2.getTextSize('abc', 0, fontScale=2, thickness=1)[0]
        for k, (text, lc) in enumerate(zip(labels, label_colors)):
            cv2.putText(img, text, (20, int((k + 1) * 1.1 * h)), 0, 1,
                        lc, 2, lineType=cv2.LINE_AA)
    return img
def drawTrack(tracks, img):
    """Overlay the box trajectory of each track in ``tracks`` onto ``img``.

    Drawing is delegated to ``TrackAnnotator`` with a line width of 2. Each
    track's ``boxes`` are passed to ``plotting_track`` in order, and the
    annotator's ``result()`` is returned. Whether ``img`` itself is modified
    depends on TrackAnnotator (not visible here).
    """
    painter = TrackAnnotator(img, line_width=2)
    for trk in tracks:
        painter.plotting_track(trk.boxes)
    return painter.result()
if __name__ == "__main__":
y = np.array([5.0, 20, 40, 41, 42, 55, 56])