Files
2024-11-27 15:37:10 +08:00

255 lines
7.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
Created on Wed Sep 20 17:33:00 2023
@author: ym
"""
import sys
import cv2
import os
import numpy as np
import time
import pickle
import matplotlib.pyplot as plt
from scipy.spatial.distance import cdist
from pathlib import Path
from datetime import datetime
# ================= using for import ultralytics
# sys.path.append(r"D:\DeepLearning\yolov5_track")
# from utils.proBoxes import boxes_add_fid
from ytracking.tracking.utils.plotting import boxing_img # , Annotator, colors,
from ytracking.tracking.utils.gen import Profile
from ytracking.tracking.utils.drawtracks import draw5points, drawTrack, drawtracefeat, plot_frameID_y2, drawFeatures, \
draw_all_trajectories
from ytracking.tracking.utils import Boxes, IterableSimpleNamespace, yaml_load
from ytracking.tracking.trackers import BOTSORT, BYTETracker
sys.path.append("ytracking/tracking/")
from dotrack.dotracks_back import doBackTracks
from dotrack.dotracks_front import doFrontTracks
# from utils.mergetrack import track_equal_track
# from utils.basetrack import MoveState, ShoppingCart, doTracks
def init_tracker(tracker_yaml=None, bs=1, frame_rate=30):
    """
    Create a tracker instance from a tracker YAML config.

    Args:
        tracker_yaml: Path to the tracker config. Its ``tracker_type`` entry
            selects the class: 'bytetrack' -> BYTETracker, 'botsort' -> BOTSORT.
        bs: Batch size. Not used by this function; kept for interface
            compatibility with existing callers.
        frame_rate: Frame rate passed to the tracker constructor. Defaults to
            30, the value that was previously hard-coded.

    Returns:
        The constructed tracker instance.

    Raises:
        ValueError: if ``tracker_yaml`` is not given.
        KeyError: if ``tracker_type`` names an unsupported tracker.
    """
    if tracker_yaml is None:
        raise ValueError("tracker_yaml is required: pass the path to a tracker config file")
    cfg = IterableSimpleNamespace(**yaml_load(tracker_yaml))
    tracker_map = {'bytetrack': BYTETracker, 'botsort': BOTSORT}
    try:
        tracker_cls = tracker_map[cfg.tracker_type]
    except KeyError:
        raise KeyError(
            f"unsupported tracker_type {cfg.tracker_type!r}; expected one of {sorted(tracker_map)}"
        ) from None
    return tracker_cls(args=cfg, frame_rate=frame_rate)
def have_tracked_front():
    """
    Analyse front-camera tracks that were already produced by a tracking run.

    Tracking is not re-run here; only the saved results are analysed. For
    every file in ``./data/tracks`` whose name contains "front":
    1. Load its boxes with ``np.load``.
    2. Load the feature file with the same stem and a ``.pkl`` extension
       from ``./data/trackfeats``.
    3. Run ``doFrontTracks`` and ``classify`` on them (timed with ``Profile``).
    4. Save the ``plot_frameID_y2`` figure to ``./result/<stem>_y2.png``.
    """
    featdir = r"./data/trackfeats"
    npydir = r"./data/tracks"
    # savefig does not create missing directories, so make sure ./result exists.
    os.makedirs("./result", exist_ok=True)
    gt = Profile()
    for filename in os.listdir(npydir):
        if "front" not in filename:
            continue
        stem, _ = os.path.splitext(filename)
        bboxes = np.load(os.path.join(npydir, filename))
        # With allow_pickle=True, np.load unpickles non-.npy files; only use on trusted data.
        features_dict = np.load(os.path.join(featdir, stem + '.pkl'), allow_pickle=True)
        with gt:
            vts = doFrontTracks(bboxes, features_dict)
            vts.classify()
        # Named `plotter` so the module-level `plt` (matplotlib.pyplot) is not shadowed.
        plotter = plot_frameID_y2(vts)
        plotter.savefig(f'./result/{stem}_y2.png')
        plotter.close()
        print(stem + f" need time: {gt.dt:.2f}s")
def have_tracked_back(max_files=1, save_dir=None):
    """
    Analyse back-camera tracks that were already produced by a tracking run.

    Tracking is not re-run here; only the saved results are analysed. For
    each file in ``./data/tracks`` whose name contains "back":
    1. Load its boxes (``np.load``) and the matching ``.pkl`` feature file
       (same stem) from ``./data/trackfeats``.
    2. Run ``doBackTracks`` and ``classify``.
    3. Draw the trajectories over the edge-line template.

    After the loop, the features of all processed files are drawn together.

    Args:
        max_files: Maximum number of back-camera files to process. The
            default of 1 keeps the previous debug behaviour; ``None``
            processes every matching file.
        save_dir: Output directory for the drawings. Defaults to the
            module-level ``save_dir`` (set by the ``__main__`` section), or
            ``./result/`` when that global is not defined, so the function
            also works when imported.
    """
    if save_dir is None:
        save_dir = globals().get("save_dir", Path("./result/"))
    Path(save_dir).mkdir(parents=True, exist_ok=True)
    featdir = r"./data/trackfeats"
    npydir = r"./data/tracks"
    # The template is the same for every file, so read it once.
    # NOTE(review): cv2.imread returns None if the file is missing — confirm
    # draw_all_trajectories copes with that.
    edgeline = cv2.imread("./shopcart/cart_tempt/edgeline.png")
    alltracks = []
    gt = Profile()
    for filename in os.listdir(npydir):
        if "back" not in filename:
            continue
        if max_files is not None and len(alltracks) >= max_files:
            break
        stem, _ = os.path.splitext(filename)
        bboxes = np.load(os.path.join(npydir, filename))
        # With allow_pickle=True, np.load unpickles non-.npy files; only use on trusted data.
        features_dict = np.load(os.path.join(featdir, stem + '.pkl'), allow_pickle=True)
        with gt:
            vts = doBackTracks(bboxes, features_dict)
            vts.classify()
        print(stem + f" need time: {gt.dt:.2f}s")
        draw_all_trajectories(vts, edgeline, save_dir, filename)
        alltracks.append(vts)
    if alltracks:
        drawFeatures(alltracks, save_dir)
def _crop_det_patches(det, img):
    """
    Crop one image patch per detection row, with the channel order reversed.

    Args:
        det: Detection array. Columns 0-3 of each row are read as
            (x1, y1, x2, y2) pixel coordinates.
        img: Frame array indexed as [row, col, channel].

    Returns:
        List of patches, one per row of ``det``, in row order.
    """
    h, w = img.shape[:2]
    patches = []
    for row in det:
        x1, y1, x2, y2 = row[:4].astype(np.int_)
        # Clamp to the image. Slice ends are exclusive, so the upper bound is
        # the width/height itself; clamping to w-1/h-1 would drop the last column/row.
        x1, y1 = max(0, x1), max(0, y1)
        x2, y2 = min(w, x2), min(h, y2)
        # The original pipeline read crops with PIL (RGB), while OpenCV images are BGR.
        patches.append(img[y1:y2, x1:x2, ::-1])
    return patches


def tracking(vboxes, tracker_yaml=r"./trackers/cfg/botsort.yaml"):
    """
    Run the tracker over per-frame detections, then build and classify tracks.

    Args:
        vboxes: Iterable of ``(det, img, frame)`` tuples, consumed in order.
        tracker_yaml: Tracker config passed to ``init_tracker``. The default
            is the BoT-SORT config that was previously hard-coded.

    Returns:
        ``(vts, images)``:
        - ``vts``: the ``doBackTracks`` result after ``classify()``.
        - ``images``: a list of ``(annotated_img, frame)`` pairs, one for each
          frame in which the tracker returned at least one track.
    """
    tracker = init_tracker(tracker_yaml)
    track_chunks = []
    images = []
    features_dict = {}
    # NOTE(review): an earlier comment said frames need re-sorting by frame_id;
    # vboxes is consumed in the order given.
    for det, img, frame in vboxes:
        det_tracking = Boxes(det).cpu().numpy()
        imgs = _crop_det_patches(det, img)
        tracks = tracker.update(det_tracking, imgs)
        if len(tracks):
            track_chunks.append(tracks)
            feat_dict = {int(x.idx): x.curr_feat for x in tracker.tracked_stracks if x.is_activated}
            # Column 7 of the tracker output is used as the frame id
            # (assumes the tracker's output column layout — confirm).
            frame_id = tracks[0, 7]
            features_dict[int(frame_id)] = feat_dict
            images.append((boxing_img(tracks, img), frame))
    # Concatenate once at the end instead of re-copying the growing array every frame.
    track_boxes = np.concatenate([np.empty((0, 9), dtype=np.float32), *track_chunks], axis=0)
    vts = doBackTracks(track_boxes, features_dict)
    vts.classify()
    return vts, images
def _save_tracking_frames(images, outdir, stem, fps=30):
    """
    Write annotated frames as PNG files and as one MP4 video into ``outdir``.

    Args:
        images: List of ``(img, frame)`` pairs as returned by ``tracking``.
        outdir: ``pathlib.Path`` of the target directory (created if missing).
        stem: Base name for the video (``<stem>.mp4``) and for the frame files
            (``<stem>_<frame>.png``).
        fps: Frame rate of the written video.
    """
    if not images:
        # Nothing was tracked; the original code crashed on images[0] here.
        return
    outdir.mkdir(parents=True, exist_ok=True)
    h, w = images[0][0].shape[:2]
    # NOTE(review): OpenCV file I/O can fail on non-ASCII paths on some
    # platforms (e.g. Windows); stems such as '加购_18' are non-ASCII — verify.
    vidwriter = cv2.VideoWriter(str(outdir / f"{stem}.mp4"), cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
    try:
        for img, frame in images:
            imgpath = outdir / f"{stem}_{frame}.png"
            if not cv2.imwrite(str(imgpath), img):
                print(f"failed to write {imgpath}")
            vidwriter.write(img)
    finally:
        vidwriter.release()


def do_tracking(target="加购_18.pkl", max_files=1, imgdir=None, save_dir=None):
    """
    Run tracking on pickled detection results and save the visualisations.

    Each ``.pkl`` file in ``./data/boxes_imgs`` holds the list of
    ``(det, img, frame)`` tuples that ``tracking`` consumes.

    Args:
        target: Single file in ``./data/boxes_imgs`` to process. The default
            is the file the original code hard-coded. Pass ``None`` to
            iterate over the directory listing instead.
        max_files: When ``target`` is None, process at most this many files
            (``None`` for all). The default of 1 matches the previous debug limit.
        imgdir: Root directory for per-file frame images and videos. Defaults
            to the module-level ``imgdir`` (set in ``__main__``), or
            ``./result/<YYYYMMDD>_imgs/`` if that global is not defined.
        save_dir: Directory for trajectory/feature drawings. Defaults to the
            module-level ``save_dir``, or ``./result/``.

    Raises:
        ValueError: if a detection file contains no data.
    """
    pkldir = r"./data/boxes_imgs"
    save_result = True
    if imgdir is None:
        imgdir = globals().get("imgdir", Path(f"./result/{datetime.now():%Y%m%d}_imgs/"))
    if save_dir is None:
        save_dir = globals().get("save_dir", Path("./result/"))
    imgdir = Path(imgdir)

    if target is not None:
        filenames = [target]
    else:
        filenames = os.listdir(pkldir)
        if max_files is not None:
            filenames = filenames[:max_files]

    edgeline = cv2.imread("./shopcart/cart_tempt/edgeline.png")
    alltracks = []
    gt = Profile()
    for filename in filenames:
        stem, _ = os.path.splitext(filename)
        # pickle.load can execute arbitrary code: only load trusted detection dumps.
        with open(os.path.join(pkldir, filename), 'rb') as f:
            vboxes = pickle.load(f)
        if len(vboxes) == 0:
            raise ValueError(f"{filename} contains no detection data")
        with gt:
            vts, images = tracking(vboxes)
        alltracks.append(vts)
        print(stem + f" need time: {gt.dt * 1E3:.1f}ms")
        if save_result:
            # Save images, video and the track trajectories.
            _save_tracking_frames(images, imgdir / stem, stem)
            draw_all_trajectories(vts, edgeline, save_dir, filename)
    if alltracks:
        drawFeatures(alltracks, save_dir)
def have_tracked(bboxes, features_dict, camera_id):
    """
    Build and classify tracks from already-tracked boxes for one camera.

    Args:
        bboxes: Tracked boxes, passed unchanged to the track builder.
        features_dict: Per-frame features, passed unchanged to the track builder.
        camera_id: '0' selects ``doBackTracks``; '1' selects ``doFrontTracks``.
            The integer ids 0 / 1 are accepted too (compared via ``str()``).

    Returns:
        The track collection after ``classify()`` has been called on it.

    Raises:
        ValueError: if ``camera_id`` is neither '0' nor '1'.
    """
    cam = str(camera_id)
    if cam == '0':
        vts = doBackTracks(bboxes, features_dict)
    elif cam == '1':
        vts = doFrontTracks(bboxes, features_dict)
    else:
        raise ValueError(f"unsupported camera_id {camera_id!r}: expected '0' or '1'")
    vts.classify()
    return vts
if __name__ == "__main__":
    # Date stamp (YYYYMMDD) used to name per-run output folders.
    time_string = datetime.now().strftime("%Y%m%d")
    save_dir = Path('./result/')
    # The analysis and plotting helpers write into save_dir, so create it up front.
    save_dir.mkdir(parents=True, exist_ok=True)
    # "merge": tracking is already done, only analyse saved tracks;
    # anything else: run tracking first.
    mode = "merge"
    if mode == "merge":
        # have_tracked_back()  # switch to this for back-camera tracks
        have_tracked_front()
    else:
        # Where do_tracking() stores its images and videos.
        imgdir = Path(f'./result/{time_string}_imgs/')
        imgdir.mkdir(parents=True, exist_ok=True)
        do_tracking()