回传数据解析,兼容v5和v10
This commit is contained in:
321
ultralytics/trackers/README.md
Normal file
321
ultralytics/trackers/README.md
Normal file
@ -0,0 +1,321 @@
|
||||
# Multi-Object Tracking with Ultralytics YOLO
|
||||
|
||||
<img width="1024" src="https://user-images.githubusercontent.com/26833433/243418637-1d6250fd-1515-4c10-a844-a32818ae6d46.png" alt="YOLOv8 trackers visualization">
|
||||
|
||||
Object tracking in the realm of video analytics is a critical task that not only identifies the location and class of objects within the frame but also maintains a unique ID for each detected object as the video progresses. The applications are limitless—ranging from surveillance and security to real-time sports analytics.
|
||||
|
||||
## Why Choose Ultralytics YOLO for Object Tracking?
|
||||
|
||||
The output from Ultralytics trackers is consistent with standard object detection but has the added value of object IDs. This makes it easy to track objects in video streams and perform subsequent analytics. Here's why you should consider using Ultralytics YOLO for your object tracking needs:
|
||||
|
||||
- **Efficiency:** Process video streams in real-time without compromising accuracy.
|
||||
- **Flexibility:** Supports multiple tracking algorithms and configurations.
|
||||
- **Ease of Use:** Simple Python API and CLI options for quick integration and deployment.
|
||||
- **Customizability:** Easy to use with custom trained YOLO models, allowing integration into domain-specific applications.
|
||||
|
||||
**Video Tutorial:** [Object Detection and Tracking with Ultralytics YOLOv8](https://www.youtube.com/embed/hHyHmOtmEgs?si=VNZtXmm45Nb9s-N-).
|
||||
|
||||
## Features at a Glance
|
||||
|
||||
Ultralytics YOLO extends its object detection features to provide robust and versatile object tracking:
|
||||
|
||||
- **Real-Time Tracking:** Seamlessly track objects in high-frame-rate videos.
|
||||
- **Multiple Tracker Support:** Choose from a variety of established tracking algorithms.
|
||||
- **Customizable Tracker Configurations:** Tailor the tracking algorithm to meet specific requirements by adjusting various parameters.
|
||||
|
||||
## Available Trackers
|
||||
|
||||
Ultralytics YOLO supports the following tracking algorithms. They can be enabled by passing the relevant YAML configuration file such as `tracker=tracker_type.yaml`:
|
||||
|
||||
- [BoT-SORT](https://github.com/NirAharon/BoT-SORT) - Use `botsort.yaml` to enable this tracker.
|
||||
- [ByteTrack](https://github.com/ifzhang/ByteTrack) - Use `bytetrack.yaml` to enable this tracker.
|
||||
|
||||
The default tracker is BoT-SORT.
|
||||
|
||||
## Tracking
|
||||
|
||||
To run the tracker on video streams, use a trained Detect, Segment or Pose model such as YOLOv8n, YOLOv8n-seg and YOLOv8n-pose.
|
||||
|
||||
#### Python
|
||||
|
||||
```python
|
||||
from ultralytics import YOLO
|
||||
|
||||
# Load an official or custom model
|
||||
model = YOLO("yolov8n.pt") # Load an official Detect model
|
||||
model = YOLO("yolov8n-seg.pt") # Load an official Segment model
|
||||
model = YOLO("yolov8n-pose.pt") # Load an official Pose model
|
||||
model = YOLO("path/to/best.pt") # Load a custom trained model
|
||||
|
||||
# Perform tracking with the model
|
||||
results = model.track(
|
||||
source="https://youtu.be/LNwODJXcvt4", show=True
|
||||
) # Tracking with default tracker
|
||||
results = model.track(
|
||||
source="https://youtu.be/LNwODJXcvt4", show=True, tracker="bytetrack.yaml"
|
||||
) # Tracking with ByteTrack tracker
|
||||
```
|
||||
|
||||
#### CLI
|
||||
|
||||
```bash
|
||||
# Perform tracking with various models using the command line interface
|
||||
yolo track model=yolov8n.pt source="https://youtu.be/LNwODJXcvt4" # Official Detect model
|
||||
yolo track model=yolov8n-seg.pt source="https://youtu.be/LNwODJXcvt4" # Official Segment model
|
||||
yolo track model=yolov8n-pose.pt source="https://youtu.be/LNwODJXcvt4" # Official Pose model
|
||||
yolo track model=path/to/best.pt source="https://youtu.be/LNwODJXcvt4" # Custom trained model
|
||||
|
||||
# Track using ByteTrack tracker
|
||||
yolo track model=path/to/best.pt tracker="bytetrack.yaml"
|
||||
```
|
||||
|
||||
As can be seen in the above usage, tracking is available for all Detect, Segment and Pose models run on videos or streaming sources.
|
||||
|
||||
## Configuration
|
||||
|
||||
### Tracking Arguments
|
||||
|
||||
Tracking configuration shares properties with Predict mode, such as `conf`, `iou`, and `show`. For further configurations, refer to the [Predict](https://docs.ultralytics.com/modes/predict/) model page.
|
||||
|
||||
#### Python
|
||||
|
||||
```python
|
||||
from ultralytics import YOLO
|
||||
|
||||
# Configure the tracking parameters and run the tracker
|
||||
model = YOLO("yolov8n.pt")
|
||||
results = model.track(
|
||||
source="https://youtu.be/LNwODJXcvt4", conf=0.3, iou=0.5, show=True
|
||||
)
|
||||
```
|
||||
|
||||
#### CLI
|
||||
|
||||
```bash
|
||||
# Configure tracking parameters and run the tracker using the command line interface
|
||||
yolo track model=yolov8n.pt source="https://youtu.be/LNwODJXcvt4" conf=0.3, iou=0.5 show
|
||||
```
|
||||
|
||||
### Tracker Selection
|
||||
|
||||
Ultralytics also allows you to use a modified tracker configuration file. To do this, simply make a copy of a tracker config file (for example, `custom_tracker.yaml`) from [ultralytics/cfg/trackers](https://github.com/ultralytics/ultralytics/tree/main/ultralytics/cfg/trackers) and modify any configurations (except the `tracker_type`) as per your needs.
|
||||
|
||||
#### Python
|
||||
|
||||
```python
|
||||
from ultralytics import YOLO
|
||||
|
||||
# Load the model and run the tracker with a custom configuration file
|
||||
model = YOLO("yolov8n.pt")
|
||||
results = model.track(
|
||||
source="https://youtu.be/LNwODJXcvt4", tracker="custom_tracker.yaml"
|
||||
)
|
||||
```
|
||||
|
||||
#### CLI
|
||||
|
||||
```bash
|
||||
# Load the model and run the tracker with a custom configuration file using the command line interface
|
||||
yolo track model=yolov8n.pt source="https://youtu.be/LNwODJXcvt4" tracker='custom_tracker.yaml'
|
||||
```
|
||||
|
||||
For a comprehensive list of tracking arguments, refer to the [ultralytics/cfg/trackers](https://github.com/ultralytics/ultralytics/tree/main/ultralytics/cfg/trackers) page.
|
||||
|
||||
## Python Examples
|
||||
|
||||
### Persisting Tracks Loop
|
||||
|
||||
Here is a Python script using OpenCV (`cv2`) and YOLOv8 to run object tracking on video frames. This script still assumes you have already installed the necessary packages (`opencv-python` and `ultralytics`). The `persist=True` argument tells the tracker than the current image or frame is the next in a sequence and to expect tracks from the previous image in the current image.
|
||||
|
||||
#### Python
|
||||
|
||||
```python
|
||||
import cv2
|
||||
from ultralytics import YOLO
|
||||
|
||||
# Load the YOLOv8 model
|
||||
model = YOLO("yolov8n.pt")
|
||||
|
||||
# Open the video file
|
||||
video_path = "path/to/video.mp4"
|
||||
cap = cv2.VideoCapture(video_path)
|
||||
|
||||
# Loop through the video frames
|
||||
while cap.isOpened():
|
||||
# Read a frame from the video
|
||||
success, frame = cap.read()
|
||||
|
||||
if success:
|
||||
# Run YOLOv8 tracking on the frame, persisting tracks between frames
|
||||
results = model.track(frame, persist=True)
|
||||
|
||||
# Visualize the results on the frame
|
||||
annotated_frame = results[0].plot()
|
||||
|
||||
# Display the annotated frame
|
||||
cv2.imshow("YOLOv8 Tracking", annotated_frame)
|
||||
|
||||
# Break the loop if 'q' is pressed
|
||||
if cv2.waitKey(1) & 0xFF == ord("q"):
|
||||
break
|
||||
else:
|
||||
# Break the loop if the end of the video is reached
|
||||
break
|
||||
|
||||
# Release the video capture object and close the display window
|
||||
cap.release()
|
||||
cv2.destroyAllWindows()
|
||||
```
|
||||
|
||||
Please note the change from `model(frame)` to `model.track(frame)`, which enables object tracking instead of simple detection. This modified script will run the tracker on each frame of the video, visualize the results, and display them in a window. The loop can be exited by pressing 'q'.
|
||||
|
||||
### Plotting Tracks Over Time
|
||||
|
||||
Visualizing object tracks over consecutive frames can provide valuable insights into the movement patterns and behavior of detected objects within a video. With Ultralytics YOLOv8, plotting these tracks is a seamless and efficient process.
|
||||
|
||||
In the following example, we demonstrate how to utilize YOLOv8's tracking capabilities to plot the movement of detected objects across multiple video frames. This script involves opening a video file, reading it frame by frame, and utilizing the YOLO model to identify and track various objects. By retaining the center points of the detected bounding boxes and connecting them, we can draw lines that represent the paths followed by the tracked objects.
|
||||
|
||||
#### Python
|
||||
|
||||
```python
|
||||
from collections import defaultdict
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from ultralytics import YOLO
|
||||
|
||||
# Load the YOLOv8 model
|
||||
model = YOLO("yolov8n.pt")
|
||||
|
||||
# Open the video file
|
||||
video_path = "path/to/video.mp4"
|
||||
cap = cv2.VideoCapture(video_path)
|
||||
|
||||
# Store the track history
|
||||
track_history = defaultdict(lambda: [])
|
||||
|
||||
# Loop through the video frames
|
||||
while cap.isOpened():
|
||||
# Read a frame from the video
|
||||
success, frame = cap.read()
|
||||
|
||||
if success:
|
||||
# Run YOLOv8 tracking on the frame, persisting tracks between frames
|
||||
results = model.track(frame, persist=True)
|
||||
|
||||
# Get the boxes and track IDs
|
||||
boxes = results[0].boxes.xywh.cpu()
|
||||
track_ids = results[0].boxes.id.int().cpu().tolist()
|
||||
|
||||
# Visualize the results on the frame
|
||||
annotated_frame = results[0].plot()
|
||||
|
||||
# Plot the tracks
|
||||
for box, track_id in zip(boxes, track_ids):
|
||||
x, y, w, h = box
|
||||
track = track_history[track_id]
|
||||
track.append((float(x), float(y))) # x, y center point
|
||||
if len(track) > 30: # retain 90 tracks for 90 frames
|
||||
track.pop(0)
|
||||
|
||||
# Draw the tracking lines
|
||||
points = np.hstack(track).astype(np.int32).reshape((-1, 1, 2))
|
||||
cv2.polylines(
|
||||
annotated_frame,
|
||||
[points],
|
||||
isClosed=False,
|
||||
color=(230, 230, 230),
|
||||
thickness=10,
|
||||
)
|
||||
|
||||
# Display the annotated frame
|
||||
cv2.imshow("YOLOv8 Tracking", annotated_frame)
|
||||
|
||||
# Break the loop if 'q' is pressed
|
||||
if cv2.waitKey(1) & 0xFF == ord("q"):
|
||||
break
|
||||
else:
|
||||
# Break the loop if the end of the video is reached
|
||||
break
|
||||
|
||||
# Release the video capture object and close the display window
|
||||
cap.release()
|
||||
cv2.destroyAllWindows()
|
||||
```
|
||||
|
||||
### Multithreaded Tracking
|
||||
|
||||
Multithreaded tracking provides the capability to run object tracking on multiple video streams simultaneously. This is particularly useful when handling multiple video inputs, such as from multiple surveillance cameras, where concurrent processing can greatly enhance efficiency and performance.
|
||||
|
||||
In the provided Python script, we make use of Python's `threading` module to run multiple instances of the tracker concurrently. Each thread is responsible for running the tracker on one video file, and all the threads run simultaneously in the background.
|
||||
|
||||
To ensure that each thread receives the correct parameters (the video file and the model to use), we define a function `run_tracker_in_thread` that accepts these parameters and contains the main tracking loop. This function reads the video frame by frame, runs the tracker, and displays the results.
|
||||
|
||||
Two different models are used in this example: `yolov8n.pt` and `yolov8n-seg.pt`, each tracking objects in a different video file. The video files are specified in `video_file1` and `video_file2`.
|
||||
|
||||
The `daemon=True` parameter in `threading.Thread` means that these threads will be closed as soon as the main program finishes. We then start the threads with `start()` and use `join()` to make the main thread wait until both tracker threads have finished.
|
||||
|
||||
Finally, after all threads have completed their task, the windows displaying the results are closed using `cv2.destroyAllWindows()`.
|
||||
|
||||
#### Python
|
||||
|
||||
```python
|
||||
import threading
|
||||
|
||||
import cv2
|
||||
from ultralytics import YOLO
|
||||
|
||||
|
||||
def run_tracker_in_thread(filename, model):
|
||||
video = cv2.VideoCapture(filename)
|
||||
frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
for _ in range(frames):
|
||||
ret, frame = video.read()
|
||||
if ret:
|
||||
results = model.track(source=frame, persist=True)
|
||||
res_plotted = results[0].plot()
|
||||
cv2.imshow("p", res_plotted)
|
||||
if cv2.waitKey(1) == ord("q"):
|
||||
break
|
||||
|
||||
|
||||
# Load the models
|
||||
model1 = YOLO("yolov8n.pt")
|
||||
model2 = YOLO("yolov8n-seg.pt")
|
||||
|
||||
# Define the video files for the trackers
|
||||
video_file1 = "path/to/video1.mp4"
|
||||
video_file2 = "path/to/video2.mp4"
|
||||
|
||||
# Create the tracker threads
|
||||
tracker_thread1 = threading.Thread(
|
||||
target=run_tracker_in_thread, args=(video_file1, model1), daemon=True
|
||||
)
|
||||
tracker_thread2 = threading.Thread(
|
||||
target=run_tracker_in_thread, args=(video_file2, model2), daemon=True
|
||||
)
|
||||
|
||||
# Start the tracker threads
|
||||
tracker_thread1.start()
|
||||
tracker_thread2.start()
|
||||
|
||||
# Wait for the tracker threads to finish
|
||||
tracker_thread1.join()
|
||||
tracker_thread2.join()
|
||||
|
||||
# Clean up and close windows
|
||||
cv2.destroyAllWindows()
|
||||
```
|
||||
|
||||
This example can easily be extended to handle more video files and models by creating more threads and applying the same methodology.
|
||||
|
||||
## Contribute New Trackers
|
||||
|
||||
Are you proficient in multi-object tracking and have successfully implemented or adapted a tracking algorithm with Ultralytics YOLO? We invite you to contribute to our Trackers section in [ultralytics/cfg/trackers](https://github.com/ultralytics/ultralytics/tree/main/ultralytics/cfg/trackers)! Your real-world applications and solutions could be invaluable for users working on tracking tasks.
|
||||
|
||||
By contributing to this section, you help expand the scope of tracking solutions available within the Ultralytics YOLO framework, adding another layer of functionality and utility for the community.
|
||||
|
||||
To initiate your contribution, please refer to our [Contributing Guide](https://docs.ultralytics.com/help/contributing) for comprehensive instructions on submitting a Pull Request (PR) 🛠️. We are excited to see what you bring to the table!
|
||||
|
||||
Together, let's enhance the tracking capabilities of the Ultralytics YOLO ecosystem 🙏!
|
7
ultralytics/trackers/__init__.py
Normal file
7
ultralytics/trackers/__init__.py
Normal file
@ -0,0 +1,7 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
from .bot_sort import BOTSORT
|
||||
from .byte_tracker import BYTETracker
|
||||
from .track import register_tracker
|
||||
|
||||
__all__ = "register_tracker", "BOTSORT", "BYTETracker" # allow simpler import
|
105
ultralytics/trackers/basetrack.py
Normal file
105
ultralytics/trackers/basetrack.py
Normal file
@ -0,0 +1,105 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
"""This module defines the base classes and structures for object tracking in YOLO."""
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class TrackState:
|
||||
"""
|
||||
Enumeration class representing the possible states of an object being tracked.
|
||||
|
||||
Attributes:
|
||||
New (int): State when the object is newly detected.
|
||||
Tracked (int): State when the object is successfully tracked in subsequent frames.
|
||||
Lost (int): State when the object is no longer tracked.
|
||||
Removed (int): State when the object is removed from tracking.
|
||||
"""
|
||||
|
||||
New = 0
|
||||
Tracked = 1
|
||||
Lost = 2
|
||||
Removed = 3
|
||||
|
||||
|
||||
class BaseTrack:
|
||||
"""
|
||||
Base class for object tracking, providing foundational attributes and methods.
|
||||
|
||||
Attributes:
|
||||
_count (int): Class-level counter for unique track IDs.
|
||||
track_id (int): Unique identifier for the track.
|
||||
is_activated (bool): Flag indicating whether the track is currently active.
|
||||
state (TrackState): Current state of the track.
|
||||
history (OrderedDict): Ordered history of the track's states.
|
||||
features (list): List of features extracted from the object for tracking.
|
||||
curr_feature (any): The current feature of the object being tracked.
|
||||
score (float): The confidence score of the tracking.
|
||||
start_frame (int): The frame number where tracking started.
|
||||
frame_id (int): The most recent frame ID processed by the track.
|
||||
time_since_update (int): Frames passed since the last update.
|
||||
location (tuple): The location of the object in the context of multi-camera tracking.
|
||||
|
||||
Methods:
|
||||
end_frame: Returns the ID of the last frame where the object was tracked.
|
||||
next_id: Increments and returns the next global track ID.
|
||||
activate: Abstract method to activate the track.
|
||||
predict: Abstract method to predict the next state of the track.
|
||||
update: Abstract method to update the track with new data.
|
||||
mark_lost: Marks the track as lost.
|
||||
mark_removed: Marks the track as removed.
|
||||
reset_id: Resets the global track ID counter.
|
||||
"""
|
||||
|
||||
_count = 0
|
||||
|
||||
def __init__(self):
|
||||
"""Initializes a new track with unique ID and foundational tracking attributes."""
|
||||
self.track_id = 0
|
||||
self.is_activated = False
|
||||
self.state = TrackState.New
|
||||
self.history = OrderedDict()
|
||||
self.features = []
|
||||
self.curr_feature = None
|
||||
self.score = 0
|
||||
self.start_frame = 0
|
||||
self.frame_id = 0
|
||||
self.time_since_update = 0
|
||||
self.location = (np.inf, np.inf)
|
||||
|
||||
@property
|
||||
def end_frame(self):
|
||||
"""Return the last frame ID of the track."""
|
||||
return self.frame_id
|
||||
|
||||
@staticmethod
|
||||
def next_id():
|
||||
"""Increment and return the global track ID counter."""
|
||||
BaseTrack._count += 1
|
||||
return BaseTrack._count
|
||||
|
||||
def activate(self, *args):
|
||||
"""Abstract method to activate the track with provided arguments."""
|
||||
raise NotImplementedError
|
||||
|
||||
def predict(self):
|
||||
"""Abstract method to predict the next state of the track."""
|
||||
raise NotImplementedError
|
||||
|
||||
def update(self, *args, **kwargs):
|
||||
"""Abstract method to update the track with new observations."""
|
||||
raise NotImplementedError
|
||||
|
||||
def mark_lost(self):
|
||||
"""Mark the track as lost."""
|
||||
self.state = TrackState.Lost
|
||||
|
||||
def mark_removed(self):
|
||||
"""Mark the track as removed."""
|
||||
self.state = TrackState.Removed
|
||||
|
||||
@staticmethod
|
||||
def reset_id():
|
||||
"""Reset the global track ID counter."""
|
||||
BaseTrack._count = 0
|
200
ultralytics/trackers/bot_sort.py
Normal file
200
ultralytics/trackers/bot_sort.py
Normal file
@ -0,0 +1,200 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
from collections import deque
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .basetrack import TrackState
|
||||
from .byte_tracker import BYTETracker, STrack
|
||||
from .utils import matching
|
||||
from .utils.gmc import GMC
|
||||
from .utils.kalman_filter import KalmanFilterXYWH
|
||||
|
||||
|
||||
class BOTrack(STrack):
|
||||
"""
|
||||
An extended version of the STrack class for YOLOv8, adding object tracking features.
|
||||
|
||||
Attributes:
|
||||
shared_kalman (KalmanFilterXYWH): A shared Kalman filter for all instances of BOTrack.
|
||||
smooth_feat (np.ndarray): Smoothed feature vector.
|
||||
curr_feat (np.ndarray): Current feature vector.
|
||||
features (deque): A deque to store feature vectors with a maximum length defined by `feat_history`.
|
||||
alpha (float): Smoothing factor for the exponential moving average of features.
|
||||
mean (np.ndarray): The mean state of the Kalman filter.
|
||||
covariance (np.ndarray): The covariance matrix of the Kalman filter.
|
||||
|
||||
Methods:
|
||||
update_features(feat): Update features vector and smooth it using exponential moving average.
|
||||
predict(): Predicts the mean and covariance using Kalman filter.
|
||||
re_activate(new_track, frame_id, new_id): Reactivates a track with updated features and optionally new ID.
|
||||
update(new_track, frame_id): Update the YOLOv8 instance with new track and frame ID.
|
||||
tlwh: Property that gets the current position in tlwh format `(top left x, top left y, width, height)`.
|
||||
multi_predict(stracks): Predicts the mean and covariance of multiple object tracks using shared Kalman filter.
|
||||
convert_coords(tlwh): Converts tlwh bounding box coordinates to xywh format.
|
||||
tlwh_to_xywh(tlwh): Convert bounding box to xywh format `(center x, center y, width, height)`.
|
||||
|
||||
Usage:
|
||||
bo_track = BOTrack(tlwh, score, cls, feat)
|
||||
bo_track.predict()
|
||||
bo_track.update(new_track, frame_id)
|
||||
"""
|
||||
|
||||
shared_kalman = KalmanFilterXYWH()
|
||||
|
||||
def __init__(self, tlwh, score, cls, feat=None, feat_history=50):
|
||||
"""Initialize YOLOv8 object with temporal parameters, such as feature history, alpha and current features."""
|
||||
super().__init__(tlwh, score, cls)
|
||||
|
||||
self.smooth_feat = None
|
||||
self.curr_feat = None
|
||||
if feat is not None:
|
||||
self.update_features(feat)
|
||||
self.features = deque([], maxlen=feat_history)
|
||||
self.alpha = 0.9
|
||||
|
||||
def update_features(self, feat):
|
||||
"""Update features vector and smooth it using exponential moving average."""
|
||||
feat /= np.linalg.norm(feat)
|
||||
self.curr_feat = feat
|
||||
if self.smooth_feat is None:
|
||||
self.smooth_feat = feat
|
||||
else:
|
||||
self.smooth_feat = self.alpha * self.smooth_feat + (1 - self.alpha) * feat
|
||||
self.features.append(feat)
|
||||
self.smooth_feat /= np.linalg.norm(self.smooth_feat)
|
||||
|
||||
def predict(self):
|
||||
"""Predicts the mean and covariance using Kalman filter."""
|
||||
mean_state = self.mean.copy()
|
||||
if self.state != TrackState.Tracked:
|
||||
mean_state[6] = 0
|
||||
mean_state[7] = 0
|
||||
|
||||
self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)
|
||||
|
||||
def re_activate(self, new_track, frame_id, new_id=False):
|
||||
"""Reactivates a track with updated features and optionally assigns a new ID."""
|
||||
if new_track.curr_feat is not None:
|
||||
self.update_features(new_track.curr_feat)
|
||||
super().re_activate(new_track, frame_id, new_id)
|
||||
|
||||
def update(self, new_track, frame_id):
|
||||
"""Update the YOLOv8 instance with new track and frame ID."""
|
||||
if new_track.curr_feat is not None:
|
||||
self.update_features(new_track.curr_feat)
|
||||
super().update(new_track, frame_id)
|
||||
|
||||
@property
|
||||
def tlwh(self):
|
||||
"""Get current position in bounding box format `(top left x, top left y, width, height)`."""
|
||||
if self.mean is None:
|
||||
return self._tlwh.copy()
|
||||
ret = self.mean[:4].copy()
|
||||
ret[:2] -= ret[2:] / 2
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def multi_predict(stracks):
|
||||
"""Predicts the mean and covariance of multiple object tracks using shared Kalman filter."""
|
||||
if len(stracks) <= 0:
|
||||
return
|
||||
multi_mean = np.asarray([st.mean.copy() for st in stracks])
|
||||
multi_covariance = np.asarray([st.covariance for st in stracks])
|
||||
for i, st in enumerate(stracks):
|
||||
if st.state != TrackState.Tracked:
|
||||
multi_mean[i][6] = 0
|
||||
multi_mean[i][7] = 0
|
||||
multi_mean, multi_covariance = BOTrack.shared_kalman.multi_predict(multi_mean, multi_covariance)
|
||||
for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
|
||||
stracks[i].mean = mean
|
||||
stracks[i].covariance = cov
|
||||
|
||||
def convert_coords(self, tlwh):
|
||||
"""Converts Top-Left-Width-Height bounding box coordinates to X-Y-Width-Height format."""
|
||||
return self.tlwh_to_xywh(tlwh)
|
||||
|
||||
@staticmethod
|
||||
def tlwh_to_xywh(tlwh):
|
||||
"""Convert bounding box to format `(center x, center y, width, height)`."""
|
||||
ret = np.asarray(tlwh).copy()
|
||||
ret[:2] += ret[2:] / 2
|
||||
return ret
|
||||
|
||||
|
||||
class BOTSORT(BYTETracker):
|
||||
"""
|
||||
An extended version of the BYTETracker class for YOLOv8, designed for object tracking with ReID and GMC algorithm.
|
||||
|
||||
Attributes:
|
||||
proximity_thresh (float): Threshold for spatial proximity (IoU) between tracks and detections.
|
||||
appearance_thresh (float): Threshold for appearance similarity (ReID embeddings) between tracks and detections.
|
||||
encoder (object): Object to handle ReID embeddings, set to None if ReID is not enabled.
|
||||
gmc (GMC): An instance of the GMC algorithm for data association.
|
||||
args (object): Parsed command-line arguments containing tracking parameters.
|
||||
|
||||
Methods:
|
||||
get_kalmanfilter(): Returns an instance of KalmanFilterXYWH for object tracking.
|
||||
init_track(dets, scores, cls, img): Initialize track with detections, scores, and classes.
|
||||
get_dists(tracks, detections): Get distances between tracks and detections using IoU and (optionally) ReID.
|
||||
multi_predict(tracks): Predict and track multiple objects with YOLOv8 model.
|
||||
|
||||
Usage:
|
||||
bot_sort = BOTSORT(args, frame_rate)
|
||||
bot_sort.init_track(dets, scores, cls, img)
|
||||
bot_sort.multi_predict(tracks)
|
||||
|
||||
Note:
|
||||
The class is designed to work with the YOLOv8 object detection model and supports ReID only if enabled via args.
|
||||
"""
|
||||
|
||||
def __init__(self, args, frame_rate=30):
|
||||
"""Initialize YOLOv8 object with ReID module and GMC algorithm."""
|
||||
super().__init__(args, frame_rate)
|
||||
# ReID module
|
||||
self.proximity_thresh = args.proximity_thresh
|
||||
self.appearance_thresh = args.appearance_thresh
|
||||
|
||||
if args.with_reid:
|
||||
# Haven't supported BoT-SORT(reid) yet
|
||||
self.encoder = None
|
||||
self.gmc = GMC(method=args.gmc_method)
|
||||
|
||||
def get_kalmanfilter(self):
|
||||
"""Returns an instance of KalmanFilterXYWH for object tracking."""
|
||||
return KalmanFilterXYWH()
|
||||
|
||||
def init_track(self, dets, scores, cls, img=None):
|
||||
"""Initialize track with detections, scores, and classes."""
|
||||
if len(dets) == 0:
|
||||
return []
|
||||
if self.args.with_reid and self.encoder is not None:
|
||||
features_keep = self.encoder.inference(img, dets)
|
||||
return [BOTrack(xyxy, s, c, f) for (xyxy, s, c, f) in zip(dets, scores, cls, features_keep)] # detections
|
||||
else:
|
||||
return [BOTrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)] # detections
|
||||
|
||||
def get_dists(self, tracks, detections):
|
||||
"""Get distances between tracks and detections using IoU and (optionally) ReID embeddings."""
|
||||
dists = matching.iou_distance(tracks, detections)
|
||||
dists_mask = dists > self.proximity_thresh
|
||||
|
||||
# TODO: mot20
|
||||
# if not self.args.mot20:
|
||||
dists = matching.fuse_score(dists, detections)
|
||||
|
||||
if self.args.with_reid and self.encoder is not None:
|
||||
emb_dists = matching.embedding_distance(tracks, detections) / 2.0
|
||||
emb_dists[emb_dists > self.appearance_thresh] = 1.0
|
||||
emb_dists[dists_mask] = 1.0
|
||||
dists = np.minimum(dists, emb_dists)
|
||||
return dists
|
||||
|
||||
def multi_predict(self, tracks):
|
||||
"""Predict and track multiple objects with YOLOv8 model."""
|
||||
BOTrack.multi_predict(tracks)
|
||||
|
||||
def reset(self):
|
||||
"""Reset tracker."""
|
||||
super().reset()
|
||||
self.gmc.reset_params()
|
444
ultralytics/trackers/byte_tracker.py
Normal file
444
ultralytics/trackers/byte_tracker.py
Normal file
@ -0,0 +1,444 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .basetrack import BaseTrack, TrackState
|
||||
from .utils import matching
|
||||
from .utils.kalman_filter import KalmanFilterXYAH
|
||||
from ..utils.ops import xywh2ltwh
|
||||
from ..utils import LOGGER
|
||||
|
||||
|
||||
class STrack(BaseTrack):
|
||||
"""
|
||||
Single object tracking representation that uses Kalman filtering for state estimation.
|
||||
|
||||
This class is responsible for storing all the information regarding individual tracklets and performs state updates
|
||||
and predictions based on Kalman filter.
|
||||
|
||||
Attributes:
|
||||
shared_kalman (KalmanFilterXYAH): Shared Kalman filter that is used across all STrack instances for prediction.
|
||||
_tlwh (np.ndarray): Private attribute to store top-left corner coordinates and width and height of bounding box.
|
||||
kalman_filter (KalmanFilterXYAH): Instance of Kalman filter used for this particular object track.
|
||||
mean (np.ndarray): Mean state estimate vector.
|
||||
covariance (np.ndarray): Covariance of state estimate.
|
||||
is_activated (bool): Boolean flag indicating if the track has been activated.
|
||||
score (float): Confidence score of the track.
|
||||
tracklet_len (int): Length of the tracklet.
|
||||
cls (any): Class label for the object.
|
||||
idx (int): Index or identifier for the object.
|
||||
frame_id (int): Current frame ID.
|
||||
start_frame (int): Frame where the object was first detected.
|
||||
|
||||
Methods:
|
||||
predict(): Predict the next state of the object using Kalman filter.
|
||||
multi_predict(stracks): Predict the next states for multiple tracks.
|
||||
multi_gmc(stracks, H): Update multiple track states using a homography matrix.
|
||||
activate(kalman_filter, frame_id): Activate a new tracklet.
|
||||
re_activate(new_track, frame_id, new_id): Reactivate a previously lost tracklet.
|
||||
update(new_track, frame_id): Update the state of a matched track.
|
||||
convert_coords(tlwh): Convert bounding box to x-y-aspect-height format.
|
||||
tlwh_to_xyah(tlwh): Convert tlwh bounding box to xyah format.
|
||||
"""
|
||||
|
||||
shared_kalman = KalmanFilterXYAH()
|
||||
|
||||
def __init__(self, xywh, score, cls):
|
||||
"""Initialize new STrack instance."""
|
||||
super().__init__()
|
||||
# xywh+idx or xywha+idx
|
||||
assert len(xywh) in [5, 6], f"expected 5 or 6 values but got {len(xywh)}"
|
||||
self._tlwh = np.asarray(xywh2ltwh(xywh[:4]), dtype=np.float32)
|
||||
self.kalman_filter = None
|
||||
self.mean, self.covariance = None, None
|
||||
self.is_activated = False
|
||||
|
||||
self.score = score
|
||||
self.tracklet_len = 0
|
||||
self.cls = cls
|
||||
self.idx = xywh[-1]
|
||||
self.angle = xywh[4] if len(xywh) == 6 else None
|
||||
|
||||
def predict(self):
|
||||
"""Predicts mean and covariance using Kalman filter."""
|
||||
mean_state = self.mean.copy()
|
||||
if self.state != TrackState.Tracked:
|
||||
mean_state[7] = 0
|
||||
self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)
|
||||
|
||||
@staticmethod
|
||||
def multi_predict(stracks):
|
||||
"""Perform multi-object predictive tracking using Kalman filter for given stracks."""
|
||||
if len(stracks) <= 0:
|
||||
return
|
||||
multi_mean = np.asarray([st.mean.copy() for st in stracks])
|
||||
multi_covariance = np.asarray([st.covariance for st in stracks])
|
||||
for i, st in enumerate(stracks):
|
||||
if st.state != TrackState.Tracked:
|
||||
multi_mean[i][7] = 0
|
||||
multi_mean, multi_covariance = STrack.shared_kalman.multi_predict(multi_mean, multi_covariance)
|
||||
for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
|
||||
stracks[i].mean = mean
|
||||
stracks[i].covariance = cov
|
||||
|
||||
@staticmethod
|
||||
def multi_gmc(stracks, H=np.eye(2, 3)):
|
||||
"""Update state tracks positions and covariances using a homography matrix."""
|
||||
if len(stracks) > 0:
|
||||
multi_mean = np.asarray([st.mean.copy() for st in stracks])
|
||||
multi_covariance = np.asarray([st.covariance for st in stracks])
|
||||
|
||||
R = H[:2, :2]
|
||||
R8x8 = np.kron(np.eye(4, dtype=float), R)
|
||||
t = H[:2, 2]
|
||||
|
||||
for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
|
||||
mean = R8x8.dot(mean)
|
||||
mean[:2] += t
|
||||
cov = R8x8.dot(cov).dot(R8x8.transpose())
|
||||
|
||||
stracks[i].mean = mean
|
||||
stracks[i].covariance = cov
|
||||
|
||||
def activate(self, kalman_filter, frame_id):
|
||||
"""Start a new tracklet."""
|
||||
self.kalman_filter = kalman_filter
|
||||
self.track_id = self.next_id()
|
||||
self.mean, self.covariance = self.kalman_filter.initiate(self.convert_coords(self._tlwh))
|
||||
|
||||
self.tracklet_len = 0
|
||||
self.state = TrackState.Tracked
|
||||
if frame_id == 1:
|
||||
self.is_activated = True
|
||||
self.frame_id = frame_id
|
||||
self.start_frame = frame_id
|
||||
|
||||
def re_activate(self, new_track, frame_id, new_id=False):
|
||||
"""Reactivates a previously lost track with a new detection."""
|
||||
self.mean, self.covariance = self.kalman_filter.update(
|
||||
self.mean, self.covariance, self.convert_coords(new_track.tlwh)
|
||||
)
|
||||
self.tracklet_len = 0
|
||||
self.state = TrackState.Tracked
|
||||
self.is_activated = True
|
||||
self.frame_id = frame_id
|
||||
if new_id:
|
||||
self.track_id = self.next_id()
|
||||
self.score = new_track.score
|
||||
self.cls = new_track.cls
|
||||
self.angle = new_track.angle
|
||||
self.idx = new_track.idx
|
||||
|
||||
def update(self, new_track, frame_id):
|
||||
"""
|
||||
Update the state of a matched track.
|
||||
|
||||
Args:
|
||||
new_track (STrack): The new track containing updated information.
|
||||
frame_id (int): The ID of the current frame.
|
||||
"""
|
||||
self.frame_id = frame_id
|
||||
self.tracklet_len += 1
|
||||
|
||||
new_tlwh = new_track.tlwh
|
||||
self.mean, self.covariance = self.kalman_filter.update(
|
||||
self.mean, self.covariance, self.convert_coords(new_tlwh)
|
||||
)
|
||||
self.state = TrackState.Tracked
|
||||
self.is_activated = True
|
||||
|
||||
self.score = new_track.score
|
||||
self.cls = new_track.cls
|
||||
self.angle = new_track.angle
|
||||
self.idx = new_track.idx
|
||||
|
||||
def convert_coords(self, tlwh):
|
||||
"""Convert a bounding box's top-left-width-height format to its x-y-aspect-height equivalent."""
|
||||
return self.tlwh_to_xyah(tlwh)
|
||||
|
||||
@property
|
||||
def tlwh(self):
|
||||
"""Get current position in bounding box format (top left x, top left y, width, height)."""
|
||||
if self.mean is None:
|
||||
return self._tlwh.copy()
|
||||
ret = self.mean[:4].copy()
|
||||
ret[2] *= ret[3]
|
||||
ret[:2] -= ret[2:] / 2
|
||||
return ret
|
||||
|
||||
@property
|
||||
def xyxy(self):
|
||||
"""Convert bounding box to format (min x, min y, max x, max y), i.e., (top left, bottom right)."""
|
||||
ret = self.tlwh.copy()
|
||||
ret[2:] += ret[:2]
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def tlwh_to_xyah(tlwh):
|
||||
"""Convert bounding box to format (center x, center y, aspect ratio, height), where the aspect ratio is width /
|
||||
height.
|
||||
"""
|
||||
ret = np.asarray(tlwh).copy()
|
||||
ret[:2] += ret[2:] / 2
|
||||
ret[2] /= ret[3]
|
||||
return ret
|
||||
|
||||
@property
|
||||
def xywh(self):
|
||||
"""Get current position in bounding box format (center x, center y, width, height)."""
|
||||
ret = np.asarray(self.tlwh).copy()
|
||||
ret[:2] += ret[2:] / 2
|
||||
return ret
|
||||
|
||||
@property
|
||||
def xywha(self):
|
||||
"""Get current position in bounding box format (center x, center y, width, height, angle)."""
|
||||
if self.angle is None:
|
||||
LOGGER.warning("WARNING ⚠️ `angle` attr not found, returning `xywh` instead.")
|
||||
return self.xywh
|
||||
return np.concatenate([self.xywh, self.angle[None]])
|
||||
|
||||
@property
|
||||
def result(self):
|
||||
"""Get current tracking results."""
|
||||
coords = self.xyxy if self.angle is None else self.xywha
|
||||
return coords.tolist() + [self.track_id, self.score, self.cls, self.idx]
|
||||
|
||||
def __repr__(self):
|
||||
"""Return a string representation of the BYTETracker object with start and end frames and track ID."""
|
||||
return f"OT_{self.track_id}_({self.start_frame}-{self.end_frame})"
|
||||
|
||||
|
||||
class BYTETracker:
|
||||
"""
|
||||
BYTETracker: A tracking algorithm built on top of YOLOv8 for object detection and tracking.
|
||||
|
||||
The class is responsible for initializing, updating, and managing the tracks for detected objects in a video
|
||||
sequence. It maintains the state of tracked, lost, and removed tracks over frames, utilizes Kalman filtering for
|
||||
predicting the new object locations, and performs data association.
|
||||
|
||||
Attributes:
|
||||
tracked_stracks (list[STrack]): List of successfully activated tracks.
|
||||
lost_stracks (list[STrack]): List of lost tracks.
|
||||
removed_stracks (list[STrack]): List of removed tracks.
|
||||
frame_id (int): The current frame ID.
|
||||
args (namespace): Command-line arguments.
|
||||
max_time_lost (int): The maximum frames for a track to be considered as 'lost'.
|
||||
kalman_filter (object): Kalman Filter object.
|
||||
|
||||
Methods:
|
||||
update(results, img=None): Updates object tracker with new detections.
|
||||
get_kalmanfilter(): Returns a Kalman filter object for tracking bounding boxes.
|
||||
init_track(dets, scores, cls, img=None): Initialize object tracking with detections.
|
||||
get_dists(tracks, detections): Calculates the distance between tracks and detections.
|
||||
multi_predict(tracks): Predicts the location of tracks.
|
||||
reset_id(): Resets the ID counter of STrack.
|
||||
joint_stracks(tlista, tlistb): Combines two lists of stracks.
|
||||
sub_stracks(tlista, tlistb): Filters out the stracks present in the second list from the first list.
|
||||
remove_duplicate_stracks(stracksa, stracksb): Removes duplicate stracks based on IoU.
|
||||
"""
|
||||
|
||||
def __init__(self, args, frame_rate=30):
|
||||
"""Initialize a YOLOv8 object to track objects with given arguments and frame rate."""
|
||||
self.tracked_stracks = [] # type: list[STrack]
|
||||
self.lost_stracks = [] # type: list[STrack]
|
||||
self.removed_stracks = [] # type: list[STrack]
|
||||
|
||||
self.frame_id = 0
|
||||
self.args = args
|
||||
self.max_time_lost = int(frame_rate / 30.0 * args.track_buffer)
|
||||
self.kalman_filter = self.get_kalmanfilter()
|
||||
self.reset_id()
|
||||
|
||||
def update(self, results, img=None):
|
||||
"""Updates object tracker with new detections and returns tracked object bounding boxes."""
|
||||
self.frame_id += 1
|
||||
activated_stracks = []
|
||||
refind_stracks = []
|
||||
lost_stracks = []
|
||||
removed_stracks = []
|
||||
|
||||
scores = results.conf
|
||||
bboxes = results.xywhr if hasattr(results, "xywhr") else results.xywh
|
||||
# Add index
|
||||
bboxes = np.concatenate([bboxes, np.arange(len(bboxes)).reshape(-1, 1)], axis=-1)
|
||||
cls = results.cls
|
||||
|
||||
remain_inds = scores > self.args.track_high_thresh
|
||||
inds_low = scores > self.args.track_low_thresh
|
||||
inds_high = scores < self.args.track_high_thresh
|
||||
|
||||
inds_second = np.logical_and(inds_low, inds_high)
|
||||
dets_second = bboxes[inds_second]
|
||||
dets = bboxes[remain_inds]
|
||||
scores_keep = scores[remain_inds]
|
||||
scores_second = scores[inds_second]
|
||||
cls_keep = cls[remain_inds]
|
||||
cls_second = cls[inds_second]
|
||||
|
||||
detections = self.init_track(dets, scores_keep, cls_keep, img)
|
||||
# Add newly detected tracklets to tracked_stracks
|
||||
unconfirmed = []
|
||||
tracked_stracks = [] # type: list[STrack]
|
||||
for track in self.tracked_stracks:
|
||||
if not track.is_activated:
|
||||
unconfirmed.append(track)
|
||||
else:
|
||||
tracked_stracks.append(track)
|
||||
# Step 2: First association, with high score detection boxes
|
||||
strack_pool = self.joint_stracks(tracked_stracks, self.lost_stracks)
|
||||
# Predict the current location with KF
|
||||
self.multi_predict(strack_pool)
|
||||
if hasattr(self, "gmc") and img is not None:
|
||||
warp = self.gmc.apply(img, dets)
|
||||
STrack.multi_gmc(strack_pool, warp)
|
||||
STrack.multi_gmc(unconfirmed, warp)
|
||||
|
||||
dists = self.get_dists(strack_pool, detections)
|
||||
matches, u_track, u_detection = matching.linear_assignment(dists, thresh=self.args.match_thresh)
|
||||
|
||||
for itracked, idet in matches:
|
||||
track = strack_pool[itracked]
|
||||
det = detections[idet]
|
||||
if track.state == TrackState.Tracked:
|
||||
track.update(det, self.frame_id)
|
||||
activated_stracks.append(track)
|
||||
else:
|
||||
track.re_activate(det, self.frame_id, new_id=False)
|
||||
refind_stracks.append(track)
|
||||
# Step 3: Second association, with low score detection boxes association the untrack to the low score detections
|
||||
detections_second = self.init_track(dets_second, scores_second, cls_second, img)
|
||||
r_tracked_stracks = [strack_pool[i] for i in u_track if strack_pool[i].state == TrackState.Tracked]
|
||||
# TODO
|
||||
dists = matching.iou_distance(r_tracked_stracks, detections_second)
|
||||
matches, u_track, u_detection_second = matching.linear_assignment(dists, thresh=0.5)
|
||||
for itracked, idet in matches:
|
||||
track = r_tracked_stracks[itracked]
|
||||
det = detections_second[idet]
|
||||
if track.state == TrackState.Tracked:
|
||||
track.update(det, self.frame_id)
|
||||
activated_stracks.append(track)
|
||||
else:
|
||||
track.re_activate(det, self.frame_id, new_id=False)
|
||||
refind_stracks.append(track)
|
||||
|
||||
for it in u_track:
|
||||
track = r_tracked_stracks[it]
|
||||
if track.state != TrackState.Lost:
|
||||
track.mark_lost()
|
||||
lost_stracks.append(track)
|
||||
# Deal with unconfirmed tracks, usually tracks with only one beginning frame
|
||||
detections = [detections[i] for i in u_detection]
|
||||
dists = self.get_dists(unconfirmed, detections)
|
||||
matches, u_unconfirmed, u_detection = matching.linear_assignment(dists, thresh=0.7)
|
||||
for itracked, idet in matches:
|
||||
unconfirmed[itracked].update(detections[idet], self.frame_id)
|
||||
activated_stracks.append(unconfirmed[itracked])
|
||||
for it in u_unconfirmed:
|
||||
track = unconfirmed[it]
|
||||
track.mark_removed()
|
||||
removed_stracks.append(track)
|
||||
# Step 4: Init new stracks
|
||||
for inew in u_detection:
|
||||
track = detections[inew]
|
||||
if track.score < self.args.new_track_thresh:
|
||||
continue
|
||||
track.activate(self.kalman_filter, self.frame_id)
|
||||
activated_stracks.append(track)
|
||||
# Step 5: Update state
|
||||
for track in self.lost_stracks:
|
||||
if self.frame_id - track.end_frame > self.max_time_lost:
|
||||
track.mark_removed()
|
||||
removed_stracks.append(track)
|
||||
|
||||
self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked]
|
||||
self.tracked_stracks = self.joint_stracks(self.tracked_stracks, activated_stracks)
|
||||
self.tracked_stracks = self.joint_stracks(self.tracked_stracks, refind_stracks)
|
||||
self.lost_stracks = self.sub_stracks(self.lost_stracks, self.tracked_stracks)
|
||||
self.lost_stracks.extend(lost_stracks)
|
||||
self.lost_stracks = self.sub_stracks(self.lost_stracks, self.removed_stracks)
|
||||
self.tracked_stracks, self.lost_stracks = self.remove_duplicate_stracks(self.tracked_stracks, self.lost_stracks)
|
||||
self.removed_stracks.extend(removed_stracks)
|
||||
if len(self.removed_stracks) > 1000:
|
||||
self.removed_stracks = self.removed_stracks[-999:] # clip remove stracks to 1000 maximum
|
||||
|
||||
return np.asarray([x.result for x in self.tracked_stracks if x.is_activated], dtype=np.float32)
|
||||
|
||||
def get_kalmanfilter(self):
|
||||
"""Returns a Kalman filter object for tracking bounding boxes."""
|
||||
return KalmanFilterXYAH()
|
||||
|
||||
def init_track(self, dets, scores, cls, img=None):
|
||||
"""Initialize object tracking with detections and scores using STrack algorithm."""
|
||||
return [STrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)] if len(dets) else [] # detections
|
||||
|
||||
def get_dists(self, tracks, detections):
|
||||
"""Calculates the distance between tracks and detections using IoU and fuses scores."""
|
||||
dists = matching.iou_distance(tracks, detections)
|
||||
# TODO: mot20
|
||||
# if not self.args.mot20:
|
||||
dists = matching.fuse_score(dists, detections)
|
||||
return dists
|
||||
|
||||
def multi_predict(self, tracks):
|
||||
"""Returns the predicted tracks using the YOLOv8 network."""
|
||||
STrack.multi_predict(tracks)
|
||||
|
||||
@staticmethod
|
||||
def reset_id():
|
||||
"""Resets the ID counter of STrack."""
|
||||
STrack.reset_id()
|
||||
|
||||
def reset(self):
|
||||
"""Reset tracker."""
|
||||
self.tracked_stracks = [] # type: list[STrack]
|
||||
self.lost_stracks = [] # type: list[STrack]
|
||||
self.removed_stracks = [] # type: list[STrack]
|
||||
self.frame_id = 0
|
||||
self.kalman_filter = self.get_kalmanfilter()
|
||||
self.reset_id()
|
||||
|
||||
@staticmethod
|
||||
def joint_stracks(tlista, tlistb):
|
||||
"""Combine two lists of stracks into a single one."""
|
||||
exists = {}
|
||||
res = []
|
||||
for t in tlista:
|
||||
exists[t.track_id] = 1
|
||||
res.append(t)
|
||||
for t in tlistb:
|
||||
tid = t.track_id
|
||||
if not exists.get(tid, 0):
|
||||
exists[tid] = 1
|
||||
res.append(t)
|
||||
return res
|
||||
|
||||
@staticmethod
|
||||
def sub_stracks(tlista, tlistb):
|
||||
"""DEPRECATED CODE in https://github.com/ultralytics/ultralytics/pull/1890/
|
||||
stracks = {t.track_id: t for t in tlista}
|
||||
for t in tlistb:
|
||||
tid = t.track_id
|
||||
if stracks.get(tid, 0):
|
||||
del stracks[tid]
|
||||
return list(stracks.values())
|
||||
"""
|
||||
track_ids_b = {t.track_id for t in tlistb}
|
||||
return [t for t in tlista if t.track_id not in track_ids_b]
|
||||
|
||||
@staticmethod
|
||||
def remove_duplicate_stracks(stracksa, stracksb):
|
||||
"""Remove duplicate stracks with non-maximum IoU distance."""
|
||||
pdist = matching.iou_distance(stracksa, stracksb)
|
||||
pairs = np.where(pdist < 0.15)
|
||||
dupa, dupb = [], []
|
||||
for p, q in zip(*pairs):
|
||||
timep = stracksa[p].frame_id - stracksa[p].start_frame
|
||||
timeq = stracksb[q].frame_id - stracksb[q].start_frame
|
||||
if timep > timeq:
|
||||
dupb.append(q)
|
||||
else:
|
||||
dupa.append(p)
|
||||
resa = [t for i, t in enumerate(stracksa) if i not in dupa]
|
||||
resb = [t for i, t in enumerate(stracksb) if i not in dupb]
|
||||
return resa, resb
|
89
ultralytics/trackers/track.py
Normal file
89
ultralytics/trackers/track.py
Normal file
@ -0,0 +1,89 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
from ultralytics.utils import IterableSimpleNamespace, yaml_load
|
||||
from ultralytics.utils.checks import check_yaml
|
||||
from .bot_sort import BOTSORT
|
||||
from .byte_tracker import BYTETracker
|
||||
|
||||
# A mapping of tracker types to corresponding tracker classes
|
||||
TRACKER_MAP = {"bytetrack": BYTETracker, "botsort": BOTSORT}
|
||||
|
||||
|
||||
def on_predict_start(predictor: object, persist: bool = False) -> None:
|
||||
"""
|
||||
Initialize trackers for object tracking during prediction.
|
||||
|
||||
Args:
|
||||
predictor (object): The predictor object to initialize trackers for.
|
||||
persist (bool, optional): Whether to persist the trackers if they already exist. Defaults to False.
|
||||
|
||||
Raises:
|
||||
AssertionError: If the tracker_type is not 'bytetrack' or 'botsort'.
|
||||
"""
|
||||
if hasattr(predictor, "trackers") and persist:
|
||||
return
|
||||
|
||||
tracker = check_yaml(predictor.args.tracker)
|
||||
cfg = IterableSimpleNamespace(**yaml_load(tracker))
|
||||
|
||||
if cfg.tracker_type not in ["bytetrack", "botsort"]:
|
||||
raise AssertionError(f"Only 'bytetrack' and 'botsort' are supported for now, but got '{cfg.tracker_type}'")
|
||||
|
||||
trackers = []
|
||||
for _ in range(predictor.dataset.bs):
|
||||
tracker = TRACKER_MAP[cfg.tracker_type](args=cfg, frame_rate=30)
|
||||
trackers.append(tracker)
|
||||
if predictor.dataset.mode != "stream": # only need one tracker for other modes.
|
||||
break
|
||||
predictor.trackers = trackers
|
||||
predictor.vid_path = [None] * predictor.dataset.bs # for determining when to reset tracker on new video
|
||||
|
||||
|
||||
def on_predict_postprocess_end(predictor: object, persist: bool = False) -> None:
|
||||
"""
|
||||
Postprocess detected boxes and update with object tracking.
|
||||
|
||||
Args:
|
||||
predictor (object): The predictor object containing the predictions.
|
||||
persist (bool, optional): Whether to persist the trackers if they already exist. Defaults to False.
|
||||
"""
|
||||
path, im0s = predictor.batch[:2]
|
||||
|
||||
is_obb = predictor.args.task == "obb"
|
||||
is_stream = predictor.dataset.mode == "stream"
|
||||
for i in range(len(im0s)):
|
||||
tracker = predictor.trackers[i if is_stream else 0]
|
||||
vid_path = predictor.save_dir / Path(path[i]).name
|
||||
if not persist and predictor.vid_path[i if is_stream else 0] != vid_path:
|
||||
tracker.reset()
|
||||
predictor.vid_path[i if is_stream else 0] = vid_path
|
||||
|
||||
det = (predictor.results[i].obb if is_obb else predictor.results[i].boxes).cpu().numpy()
|
||||
if len(det) == 0:
|
||||
continue
|
||||
tracks = tracker.update(det, im0s[i])
|
||||
if len(tracks) == 0:
|
||||
continue
|
||||
idx = tracks[:, -1].astype(int)
|
||||
predictor.results[i] = predictor.results[i][idx]
|
||||
|
||||
update_args = dict()
|
||||
update_args["obb" if is_obb else "boxes"] = torch.as_tensor(tracks[:, :-1])
|
||||
predictor.results[i].update(**update_args)
|
||||
|
||||
|
||||
def register_tracker(model: object, persist: bool) -> None:
|
||||
"""
|
||||
Register tracking callbacks to the model for object tracking during prediction.
|
||||
|
||||
Args:
|
||||
model (object): The model object to register tracking callbacks for.
|
||||
persist (bool): Whether to persist the trackers if they already exist.
|
||||
"""
|
||||
model.add_callback("on_predict_start", partial(on_predict_start, persist=persist))
|
||||
model.add_callback("on_predict_postprocess_end", partial(on_predict_postprocess_end, persist=persist))
|
1
ultralytics/trackers/utils/__init__.py
Normal file
1
ultralytics/trackers/utils/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
363
ultralytics/trackers/utils/gmc.py
Normal file
363
ultralytics/trackers/utils/gmc.py
Normal file
@ -0,0 +1,363 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import copy
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from ultralytics.utils import LOGGER
|
||||
|
||||
|
||||
class GMC:
|
||||
"""
|
||||
Generalized Motion Compensation (GMC) class for tracking and object detection in video frames.
|
||||
|
||||
This class provides methods for tracking and detecting objects based on several tracking algorithms including ORB,
|
||||
SIFT, ECC, and Sparse Optical Flow. It also supports downscaling of frames for computational efficiency.
|
||||
|
||||
Attributes:
|
||||
method (str): The method used for tracking. Options include 'orb', 'sift', 'ecc', 'sparseOptFlow', 'none'.
|
||||
downscale (int): Factor by which to downscale the frames for processing.
|
||||
prevFrame (np.ndarray): Stores the previous frame for tracking.
|
||||
prevKeyPoints (list): Stores the keypoints from the previous frame.
|
||||
prevDescriptors (np.ndarray): Stores the descriptors from the previous frame.
|
||||
initializedFirstFrame (bool): Flag to indicate if the first frame has been processed.
|
||||
|
||||
Methods:
|
||||
__init__(self, method='sparseOptFlow', downscale=2): Initializes a GMC object with the specified method
|
||||
and downscale factor.
|
||||
apply(self, raw_frame, detections=None): Applies the chosen method to a raw frame and optionally uses
|
||||
provided detections.
|
||||
applyEcc(self, raw_frame, detections=None): Applies the ECC algorithm to a raw frame.
|
||||
applyFeatures(self, raw_frame, detections=None): Applies feature-based methods like ORB or SIFT to a raw frame.
|
||||
applySparseOptFlow(self, raw_frame, detections=None): Applies the Sparse Optical Flow method to a raw frame.
|
||||
"""
|
||||
|
||||
def __init__(self, method: str = "sparseOptFlow", downscale: int = 2) -> None:
|
||||
"""
|
||||
Initialize a video tracker with specified parameters.
|
||||
|
||||
Args:
|
||||
method (str): The method used for tracking. Options include 'orb', 'sift', 'ecc', 'sparseOptFlow', 'none'.
|
||||
downscale (int): Downscale factor for processing frames.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.method = method
|
||||
self.downscale = max(1, int(downscale))
|
||||
|
||||
if self.method == "orb":
|
||||
self.detector = cv2.FastFeatureDetector_create(20)
|
||||
self.extractor = cv2.ORB_create()
|
||||
self.matcher = cv2.BFMatcher(cv2.NORM_HAMMING)
|
||||
|
||||
elif self.method == "sift":
|
||||
self.detector = cv2.SIFT_create(nOctaveLayers=3, contrastThreshold=0.02, edgeThreshold=20)
|
||||
self.extractor = cv2.SIFT_create(nOctaveLayers=3, contrastThreshold=0.02, edgeThreshold=20)
|
||||
self.matcher = cv2.BFMatcher(cv2.NORM_L2)
|
||||
|
||||
elif self.method == "ecc":
|
||||
number_of_iterations = 5000
|
||||
termination_eps = 1e-6
|
||||
self.warp_mode = cv2.MOTION_EUCLIDEAN
|
||||
self.criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, number_of_iterations, termination_eps)
|
||||
|
||||
elif self.method == "sparseOptFlow":
|
||||
self.feature_params = dict(
|
||||
maxCorners=1000, qualityLevel=0.01, minDistance=1, blockSize=3, useHarrisDetector=False, k=0.04
|
||||
)
|
||||
|
||||
elif self.method in {"none", "None", None}:
|
||||
self.method = None
|
||||
else:
|
||||
raise ValueError(f"Error: Unknown GMC method:{method}")
|
||||
|
||||
self.prevFrame = None
|
||||
self.prevKeyPoints = None
|
||||
self.prevDescriptors = None
|
||||
self.initializedFirstFrame = False
|
||||
|
||||
def apply(self, raw_frame: np.array, detections: list = None) -> np.array:
|
||||
"""
|
||||
Apply object detection on a raw frame using specified method.
|
||||
|
||||
Args:
|
||||
raw_frame (np.ndarray): The raw frame to be processed.
|
||||
detections (list): List of detections to be used in the processing.
|
||||
|
||||
Returns:
|
||||
(np.ndarray): Processed frame.
|
||||
|
||||
Examples:
|
||||
>>> gmc = GMC()
|
||||
>>> gmc.apply(np.array([[1, 2, 3], [4, 5, 6]]))
|
||||
array([[1, 2, 3],
|
||||
[4, 5, 6]])
|
||||
"""
|
||||
if self.method in ["orb", "sift"]:
|
||||
return self.applyFeatures(raw_frame, detections)
|
||||
elif self.method == "ecc":
|
||||
return self.applyEcc(raw_frame)
|
||||
elif self.method == "sparseOptFlow":
|
||||
return self.applySparseOptFlow(raw_frame)
|
||||
else:
|
||||
return np.eye(2, 3)
|
||||
|
||||
def applyEcc(self, raw_frame: np.array) -> np.array:
|
||||
"""
|
||||
Apply ECC algorithm to a raw frame.
|
||||
|
||||
Args:
|
||||
raw_frame (np.ndarray): The raw frame to be processed.
|
||||
|
||||
Returns:
|
||||
(np.ndarray): Processed frame.
|
||||
|
||||
Examples:
|
||||
>>> gmc = GMC()
|
||||
>>> gmc.applyEcc(np.array([[1, 2, 3], [4, 5, 6]]))
|
||||
array([[1, 2, 3],
|
||||
[4, 5, 6]])
|
||||
"""
|
||||
height, width, _ = raw_frame.shape
|
||||
frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY)
|
||||
H = np.eye(2, 3, dtype=np.float32)
|
||||
|
||||
# Downscale image
|
||||
if self.downscale > 1.0:
|
||||
frame = cv2.GaussianBlur(frame, (3, 3), 1.5)
|
||||
frame = cv2.resize(frame, (width // self.downscale, height // self.downscale))
|
||||
width = width // self.downscale
|
||||
height = height // self.downscale
|
||||
|
||||
# Handle first frame
|
||||
if not self.initializedFirstFrame:
|
||||
# Initialize data
|
||||
self.prevFrame = frame.copy()
|
||||
|
||||
# Initialization done
|
||||
self.initializedFirstFrame = True
|
||||
|
||||
return H
|
||||
|
||||
# Run the ECC algorithm. The results are stored in warp_matrix.
|
||||
# (cc, H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode, self.criteria)
|
||||
try:
|
||||
(_, H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode, self.criteria, None, 1)
|
||||
except Exception as e:
|
||||
LOGGER.warning(f"WARNING: find transform failed. Set warp as identity {e}")
|
||||
|
||||
return H
|
||||
|
||||
def applyFeatures(self, raw_frame: np.array, detections: list = None) -> np.array:
|
||||
"""
|
||||
Apply feature-based methods like ORB or SIFT to a raw frame.
|
||||
|
||||
Args:
|
||||
raw_frame (np.ndarray): The raw frame to be processed.
|
||||
detections (list): List of detections to be used in the processing.
|
||||
|
||||
Returns:
|
||||
(np.ndarray): Processed frame.
|
||||
|
||||
Examples:
|
||||
>>> gmc = GMC()
|
||||
>>> gmc.applyFeatures(np.array([[1, 2, 3], [4, 5, 6]]))
|
||||
array([[1, 2, 3],
|
||||
[4, 5, 6]])
|
||||
"""
|
||||
height, width, _ = raw_frame.shape
|
||||
frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY)
|
||||
H = np.eye(2, 3)
|
||||
|
||||
# Downscale image
|
||||
if self.downscale > 1.0:
|
||||
frame = cv2.resize(frame, (width // self.downscale, height // self.downscale))
|
||||
width = width // self.downscale
|
||||
height = height // self.downscale
|
||||
|
||||
# Find the keypoints
|
||||
mask = np.zeros_like(frame)
|
||||
mask[int(0.02 * height) : int(0.98 * height), int(0.02 * width) : int(0.98 * width)] = 255
|
||||
if detections is not None:
|
||||
for det in detections:
|
||||
tlbr = (det[:4] / self.downscale).astype(np.int_)
|
||||
mask[tlbr[1] : tlbr[3], tlbr[0] : tlbr[2]] = 0
|
||||
|
||||
keypoints = self.detector.detect(frame, mask)
|
||||
|
||||
# Compute the descriptors
|
||||
keypoints, descriptors = self.extractor.compute(frame, keypoints)
|
||||
|
||||
# Handle first frame
|
||||
if not self.initializedFirstFrame:
|
||||
# Initialize data
|
||||
self.prevFrame = frame.copy()
|
||||
self.prevKeyPoints = copy.copy(keypoints)
|
||||
self.prevDescriptors = copy.copy(descriptors)
|
||||
|
||||
# Initialization done
|
||||
self.initializedFirstFrame = True
|
||||
|
||||
return H
|
||||
|
||||
# Match descriptors
|
||||
knnMatches = self.matcher.knnMatch(self.prevDescriptors, descriptors, 2)
|
||||
|
||||
# Filter matches based on smallest spatial distance
|
||||
matches = []
|
||||
spatialDistances = []
|
||||
|
||||
maxSpatialDistance = 0.25 * np.array([width, height])
|
||||
|
||||
# Handle empty matches case
|
||||
if len(knnMatches) == 0:
|
||||
# Store to next iteration
|
||||
self.prevFrame = frame.copy()
|
||||
self.prevKeyPoints = copy.copy(keypoints)
|
||||
self.prevDescriptors = copy.copy(descriptors)
|
||||
|
||||
return H
|
||||
|
||||
for m, n in knnMatches:
|
||||
if m.distance < 0.9 * n.distance:
|
||||
prevKeyPointLocation = self.prevKeyPoints[m.queryIdx].pt
|
||||
currKeyPointLocation = keypoints[m.trainIdx].pt
|
||||
|
||||
spatialDistance = (
|
||||
prevKeyPointLocation[0] - currKeyPointLocation[0],
|
||||
prevKeyPointLocation[1] - currKeyPointLocation[1],
|
||||
)
|
||||
|
||||
if (np.abs(spatialDistance[0]) < maxSpatialDistance[0]) and (
|
||||
np.abs(spatialDistance[1]) < maxSpatialDistance[1]
|
||||
):
|
||||
spatialDistances.append(spatialDistance)
|
||||
matches.append(m)
|
||||
|
||||
meanSpatialDistances = np.mean(spatialDistances, 0)
|
||||
stdSpatialDistances = np.std(spatialDistances, 0)
|
||||
|
||||
inliers = (spatialDistances - meanSpatialDistances) < 2.5 * stdSpatialDistances
|
||||
|
||||
goodMatches = []
|
||||
prevPoints = []
|
||||
currPoints = []
|
||||
for i in range(len(matches)):
|
||||
if inliers[i, 0] and inliers[i, 1]:
|
||||
goodMatches.append(matches[i])
|
||||
prevPoints.append(self.prevKeyPoints[matches[i].queryIdx].pt)
|
||||
currPoints.append(keypoints[matches[i].trainIdx].pt)
|
||||
|
||||
prevPoints = np.array(prevPoints)
|
||||
currPoints = np.array(currPoints)
|
||||
|
||||
# Draw the keypoint matches on the output image
|
||||
# if False:
|
||||
# import matplotlib.pyplot as plt
|
||||
# matches_img = np.hstack((self.prevFrame, frame))
|
||||
# matches_img = cv2.cvtColor(matches_img, cv2.COLOR_GRAY2BGR)
|
||||
# W = self.prevFrame.shape[1]
|
||||
# for m in goodMatches:
|
||||
# prev_pt = np.array(self.prevKeyPoints[m.queryIdx].pt, dtype=np.int_)
|
||||
# curr_pt = np.array(keypoints[m.trainIdx].pt, dtype=np.int_)
|
||||
# curr_pt[0] += W
|
||||
# color = np.random.randint(0, 255, 3)
|
||||
# color = (int(color[0]), int(color[1]), int(color[2]))
|
||||
#
|
||||
# matches_img = cv2.line(matches_img, prev_pt, curr_pt, tuple(color), 1, cv2.LINE_AA)
|
||||
# matches_img = cv2.circle(matches_img, prev_pt, 2, tuple(color), -1)
|
||||
# matches_img = cv2.circle(matches_img, curr_pt, 2, tuple(color), -1)
|
||||
#
|
||||
# plt.figure()
|
||||
# plt.imshow(matches_img)
|
||||
# plt.show()
|
||||
|
||||
# Find rigid matrix
|
||||
if prevPoints.shape[0] > 4:
|
||||
H, inliers = cv2.estimateAffinePartial2D(prevPoints, currPoints, cv2.RANSAC)
|
||||
|
||||
# Handle downscale
|
||||
if self.downscale > 1.0:
|
||||
H[0, 2] *= self.downscale
|
||||
H[1, 2] *= self.downscale
|
||||
else:
|
||||
LOGGER.warning("WARNING: not enough matching points")
|
||||
|
||||
# Store to next iteration
|
||||
self.prevFrame = frame.copy()
|
||||
self.prevKeyPoints = copy.copy(keypoints)
|
||||
self.prevDescriptors = copy.copy(descriptors)
|
||||
|
||||
return H
|
||||
|
||||
def applySparseOptFlow(self, raw_frame: np.array) -> np.array:
|
||||
"""
|
||||
Apply Sparse Optical Flow method to a raw frame.
|
||||
|
||||
Args:
|
||||
raw_frame (np.ndarray): The raw frame to be processed.
|
||||
|
||||
Returns:
|
||||
(np.ndarray): Processed frame.
|
||||
|
||||
Examples:
|
||||
>>> gmc = GMC()
|
||||
>>> gmc.applySparseOptFlow(np.array([[1, 2, 3], [4, 5, 6]]))
|
||||
array([[1, 2, 3],
|
||||
[4, 5, 6]])
|
||||
"""
|
||||
height, width, _ = raw_frame.shape
|
||||
frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY)
|
||||
H = np.eye(2, 3)
|
||||
|
||||
# Downscale image
|
||||
if self.downscale > 1.0:
|
||||
frame = cv2.resize(frame, (width // self.downscale, height // self.downscale))
|
||||
|
||||
# Find the keypoints
|
||||
keypoints = cv2.goodFeaturesToTrack(frame, mask=None, **self.feature_params)
|
||||
|
||||
# Handle first frame
|
||||
if not self.initializedFirstFrame:
|
||||
self.prevFrame = frame.copy()
|
||||
self.prevKeyPoints = copy.copy(keypoints)
|
||||
self.initializedFirstFrame = True
|
||||
return H
|
||||
|
||||
# Find correspondences
|
||||
matchedKeypoints, status, _ = cv2.calcOpticalFlowPyrLK(self.prevFrame, frame, self.prevKeyPoints, None)
|
||||
|
||||
# Leave good correspondences only
|
||||
prevPoints = []
|
||||
currPoints = []
|
||||
|
||||
for i in range(len(status)):
|
||||
if status[i]:
|
||||
prevPoints.append(self.prevKeyPoints[i])
|
||||
currPoints.append(matchedKeypoints[i])
|
||||
|
||||
prevPoints = np.array(prevPoints)
|
||||
currPoints = np.array(currPoints)
|
||||
|
||||
# Find rigid matrix
|
||||
if (prevPoints.shape[0] > 4) and (prevPoints.shape[0] == prevPoints.shape[0]):
|
||||
H, _ = cv2.estimateAffinePartial2D(prevPoints, currPoints, cv2.RANSAC)
|
||||
|
||||
if self.downscale > 1.0:
|
||||
H[0, 2] *= self.downscale
|
||||
H[1, 2] *= self.downscale
|
||||
else:
|
||||
LOGGER.warning("WARNING: not enough matching points")
|
||||
|
||||
self.prevFrame = frame.copy()
|
||||
self.prevKeyPoints = copy.copy(keypoints)
|
||||
|
||||
return H
|
||||
|
||||
def reset_params(self) -> None:
|
||||
"""Reset parameters."""
|
||||
self.prevFrame = None
|
||||
self.prevKeyPoints = None
|
||||
self.prevDescriptors = None
|
||||
self.initializedFirstFrame = False
|
360
ultralytics/trackers/utils/kalman_filter.py
Normal file
360
ultralytics/trackers/utils/kalman_filter.py
Normal file
@ -0,0 +1,360 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import numpy as np
|
||||
import scipy.linalg
|
||||
|
||||
|
||||
class KalmanFilterXYAH:
|
||||
"""
|
||||
For bytetrack. A simple Kalman filter for tracking bounding boxes in image space.
|
||||
|
||||
The 8-dimensional state space (x, y, a, h, vx, vy, va, vh) contains the bounding box center position (x, y), aspect
|
||||
ratio a, height h, and their respective velocities.
|
||||
|
||||
Object motion follows a constant velocity model. The bounding box location (x, y, a, h) is taken as direct
|
||||
observation of the state space (linear observation model).
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize Kalman filter model matrices with motion and observation uncertainty weights."""
|
||||
ndim, dt = 4, 1.0
|
||||
|
||||
# Create Kalman filter model matrices
|
||||
self._motion_mat = np.eye(2 * ndim, 2 * ndim)
|
||||
for i in range(ndim):
|
||||
self._motion_mat[i, ndim + i] = dt
|
||||
self._update_mat = np.eye(ndim, 2 * ndim)
|
||||
|
||||
# Motion and observation uncertainty are chosen relative to the current state estimate. These weights control
|
||||
# the amount of uncertainty in the model.
|
||||
self._std_weight_position = 1.0 / 20
|
||||
self._std_weight_velocity = 1.0 / 160
|
||||
|
||||
def initiate(self, measurement: np.ndarray) -> tuple:
|
||||
"""
|
||||
Create track from unassociated measurement.
|
||||
|
||||
Args:
|
||||
measurement (ndarray): Bounding box coordinates (x, y, a, h) with center position (x, y), aspect ratio a,
|
||||
and height h.
|
||||
|
||||
Returns:
|
||||
(tuple[ndarray, ndarray]): Returns the mean vector (8 dimensional) and covariance matrix (8x8 dimensional) of
|
||||
the new track. Unobserved velocities are initialized to 0 mean.
|
||||
"""
|
||||
mean_pos = measurement
|
||||
mean_vel = np.zeros_like(mean_pos)
|
||||
mean = np.r_[mean_pos, mean_vel]
|
||||
|
||||
std = [
|
||||
2 * self._std_weight_position * measurement[3],
|
||||
2 * self._std_weight_position * measurement[3],
|
||||
1e-2,
|
||||
2 * self._std_weight_position * measurement[3],
|
||||
10 * self._std_weight_velocity * measurement[3],
|
||||
10 * self._std_weight_velocity * measurement[3],
|
||||
1e-5,
|
||||
10 * self._std_weight_velocity * measurement[3],
|
||||
]
|
||||
covariance = np.diag(np.square(std))
|
||||
return mean, covariance
|
||||
|
||||
def predict(self, mean: np.ndarray, covariance: np.ndarray) -> tuple:
|
||||
"""
|
||||
Run Kalman filter prediction step.
|
||||
|
||||
Args:
|
||||
mean (ndarray): The 8 dimensional mean vector of the object state at the previous time step.
|
||||
covariance (ndarray): The 8x8 dimensional covariance matrix of the object state at the previous time step.
|
||||
|
||||
Returns:
|
||||
(tuple[ndarray, ndarray]): Returns the mean vector and covariance matrix of the predicted state. Unobserved
|
||||
velocities are initialized to 0 mean.
|
||||
"""
|
||||
std_pos = [
|
||||
self._std_weight_position * mean[3],
|
||||
self._std_weight_position * mean[3],
|
||||
1e-2,
|
||||
self._std_weight_position * mean[3],
|
||||
]
|
||||
std_vel = [
|
||||
self._std_weight_velocity * mean[3],
|
||||
self._std_weight_velocity * mean[3],
|
||||
1e-5,
|
||||
self._std_weight_velocity * mean[3],
|
||||
]
|
||||
motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))
|
||||
|
||||
mean = np.dot(mean, self._motion_mat.T)
|
||||
covariance = np.linalg.multi_dot((self._motion_mat, covariance, self._motion_mat.T)) + motion_cov
|
||||
|
||||
return mean, covariance
|
||||
|
||||
def project(self, mean: np.ndarray, covariance: np.ndarray) -> tuple:
|
||||
"""
|
||||
Project state distribution to measurement space.
|
||||
|
||||
Args:
|
||||
mean (ndarray): The state's mean vector (8 dimensional array).
|
||||
covariance (ndarray): The state's covariance matrix (8x8 dimensional).
|
||||
|
||||
Returns:
|
||||
(tuple[ndarray, ndarray]): Returns the projected mean and covariance matrix of the given state estimate.
|
||||
"""
|
||||
std = [
|
||||
self._std_weight_position * mean[3],
|
||||
self._std_weight_position * mean[3],
|
||||
1e-1,
|
||||
self._std_weight_position * mean[3],
|
||||
]
|
||||
innovation_cov = np.diag(np.square(std))
|
||||
|
||||
mean = np.dot(self._update_mat, mean)
|
||||
covariance = np.linalg.multi_dot((self._update_mat, covariance, self._update_mat.T))
|
||||
return mean, covariance + innovation_cov
|
||||
|
||||
def multi_predict(self, mean: np.ndarray, covariance: np.ndarray) -> tuple:
|
||||
"""
|
||||
Run Kalman filter prediction step (Vectorized version).
|
||||
|
||||
Args:
|
||||
mean (ndarray): The Nx8 dimensional mean matrix of the object states at the previous time step.
|
||||
covariance (ndarray): The Nx8x8 covariance matrix of the object states at the previous time step.
|
||||
|
||||
Returns:
|
||||
(tuple[ndarray, ndarray]): Returns the mean vector and covariance matrix of the predicted state. Unobserved
|
||||
velocities are initialized to 0 mean.
|
||||
"""
|
||||
std_pos = [
|
||||
self._std_weight_position * mean[:, 3],
|
||||
self._std_weight_position * mean[:, 3],
|
||||
1e-2 * np.ones_like(mean[:, 3]),
|
||||
self._std_weight_position * mean[:, 3],
|
||||
]
|
||||
std_vel = [
|
||||
self._std_weight_velocity * mean[:, 3],
|
||||
self._std_weight_velocity * mean[:, 3],
|
||||
1e-5 * np.ones_like(mean[:, 3]),
|
||||
self._std_weight_velocity * mean[:, 3],
|
||||
]
|
||||
sqr = np.square(np.r_[std_pos, std_vel]).T
|
||||
|
||||
motion_cov = [np.diag(sqr[i]) for i in range(len(mean))]
|
||||
motion_cov = np.asarray(motion_cov)
|
||||
|
||||
mean = np.dot(mean, self._motion_mat.T)
|
||||
left = np.dot(self._motion_mat, covariance).transpose((1, 0, 2))
|
||||
covariance = np.dot(left, self._motion_mat.T) + motion_cov
|
||||
|
||||
return mean, covariance
|
||||
|
||||
def update(self, mean: np.ndarray, covariance: np.ndarray, measurement: np.ndarray) -> tuple:
|
||||
"""
|
||||
Run Kalman filter correction step.
|
||||
|
||||
Args:
|
||||
mean (ndarray): The predicted state's mean vector (8 dimensional).
|
||||
covariance (ndarray): The state's covariance matrix (8x8 dimensional).
|
||||
measurement (ndarray): The 4 dimensional measurement vector (x, y, a, h), where (x, y) is the center
|
||||
position, a the aspect ratio, and h the height of the bounding box.
|
||||
|
||||
Returns:
|
||||
(tuple[ndarray, ndarray]): Returns the measurement-corrected state distribution.
|
||||
"""
|
||||
projected_mean, projected_cov = self.project(mean, covariance)
|
||||
|
||||
chol_factor, lower = scipy.linalg.cho_factor(projected_cov, lower=True, check_finite=False)
|
||||
kalman_gain = scipy.linalg.cho_solve(
|
||||
(chol_factor, lower), np.dot(covariance, self._update_mat.T).T, check_finite=False
|
||||
).T
|
||||
innovation = measurement - projected_mean
|
||||
|
||||
new_mean = mean + np.dot(innovation, kalman_gain.T)
|
||||
new_covariance = covariance - np.linalg.multi_dot((kalman_gain, projected_cov, kalman_gain.T))
|
||||
return new_mean, new_covariance
|
||||
|
||||
def gating_distance(
|
||||
self,
|
||||
mean: np.ndarray,
|
||||
covariance: np.ndarray,
|
||||
measurements: np.ndarray,
|
||||
only_position: bool = False,
|
||||
metric: str = "maha",
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Compute gating distance between state distribution and measurements. A suitable distance threshold can be
|
||||
obtained from `chi2inv95`. If `only_position` is False, the chi-square distribution has 4 degrees of freedom,
|
||||
otherwise 2.
|
||||
|
||||
Args:
|
||||
mean (ndarray): Mean vector over the state distribution (8 dimensional).
|
||||
covariance (ndarray): Covariance of the state distribution (8x8 dimensional).
|
||||
measurements (ndarray): An Nx4 matrix of N measurements, each in format (x, y, a, h) where (x, y)
|
||||
is the bounding box center position, a the aspect ratio, and h the height.
|
||||
only_position (bool, optional): If True, distance computation is done with respect to the bounding box
|
||||
center position only. Defaults to False.
|
||||
metric (str, optional): The metric to use for calculating the distance. Options are 'gaussian' for the
|
||||
squared Euclidean distance and 'maha' for the squared Mahalanobis distance. Defaults to 'maha'.
|
||||
|
||||
Returns:
|
||||
(np.ndarray): Returns an array of length N, where the i-th element contains the squared distance between
|
||||
(mean, covariance) and `measurements[i]`.
|
||||
"""
|
||||
mean, covariance = self.project(mean, covariance)
|
||||
if only_position:
|
||||
mean, covariance = mean[:2], covariance[:2, :2]
|
||||
measurements = measurements[:, :2]
|
||||
|
||||
d = measurements - mean
|
||||
if metric == "gaussian":
|
||||
return np.sum(d * d, axis=1)
|
||||
elif metric == "maha":
|
||||
cholesky_factor = np.linalg.cholesky(covariance)
|
||||
z = scipy.linalg.solve_triangular(cholesky_factor, d.T, lower=True, check_finite=False, overwrite_b=True)
|
||||
return np.sum(z * z, axis=0) # square maha
|
||||
else:
|
||||
raise ValueError("Invalid distance metric")
|
||||
|
||||
|
||||
class KalmanFilterXYWH(KalmanFilterXYAH):
|
||||
"""
|
||||
For BoT-SORT. A simple Kalman filter for tracking bounding boxes in image space.
|
||||
|
||||
The 8-dimensional state space (x, y, w, h, vx, vy, vw, vh) contains the bounding box center position (x, y), width
|
||||
w, height h, and their respective velocities.
|
||||
|
||||
Object motion follows a constant velocity model. The bounding box location (x, y, w, h) is taken as direct
|
||||
observation of the state space (linear observation model).
|
||||
"""
|
||||
|
||||
def initiate(self, measurement: np.ndarray) -> tuple:
|
||||
"""
|
||||
Create track from unassociated measurement.
|
||||
|
||||
Args:
|
||||
measurement (ndarray): Bounding box coordinates (x, y, w, h) with center position (x, y), width, and height.
|
||||
|
||||
Returns:
|
||||
(tuple[ndarray, ndarray]): Returns the mean vector (8 dimensional) and covariance matrix (8x8 dimensional) of
|
||||
the new track. Unobserved velocities are initialized to 0 mean.
|
||||
"""
|
||||
mean_pos = measurement
|
||||
mean_vel = np.zeros_like(mean_pos)
|
||||
mean = np.r_[mean_pos, mean_vel]
|
||||
|
||||
std = [
|
||||
2 * self._std_weight_position * measurement[2],
|
||||
2 * self._std_weight_position * measurement[3],
|
||||
2 * self._std_weight_position * measurement[2],
|
||||
2 * self._std_weight_position * measurement[3],
|
||||
10 * self._std_weight_velocity * measurement[2],
|
||||
10 * self._std_weight_velocity * measurement[3],
|
||||
10 * self._std_weight_velocity * measurement[2],
|
||||
10 * self._std_weight_velocity * measurement[3],
|
||||
]
|
||||
covariance = np.diag(np.square(std))
|
||||
return mean, covariance
|
||||
|
||||
def predict(self, mean, covariance) -> tuple:
|
||||
"""
|
||||
Run Kalman filter prediction step.
|
||||
|
||||
Args:
|
||||
mean (ndarray): The 8 dimensional mean vector of the object state at the previous time step.
|
||||
covariance (ndarray): The 8x8 dimensional covariance matrix of the object state at the previous time step.
|
||||
|
||||
Returns:
|
||||
(tuple[ndarray, ndarray]): Returns the mean vector and covariance matrix of the predicted state. Unobserved
|
||||
velocities are initialized to 0 mean.
|
||||
"""
|
||||
std_pos = [
|
||||
self._std_weight_position * mean[2],
|
||||
self._std_weight_position * mean[3],
|
||||
self._std_weight_position * mean[2],
|
||||
self._std_weight_position * mean[3],
|
||||
]
|
||||
std_vel = [
|
||||
self._std_weight_velocity * mean[2],
|
||||
self._std_weight_velocity * mean[3],
|
||||
self._std_weight_velocity * mean[2],
|
||||
self._std_weight_velocity * mean[3],
|
||||
]
|
||||
motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))
|
||||
|
||||
mean = np.dot(mean, self._motion_mat.T)
|
||||
covariance = np.linalg.multi_dot((self._motion_mat, covariance, self._motion_mat.T)) + motion_cov
|
||||
|
||||
return mean, covariance
|
||||
|
||||
def project(self, mean, covariance) -> tuple:
|
||||
"""
|
||||
Project state distribution to measurement space.
|
||||
|
||||
Args:
|
||||
mean (ndarray): The state's mean vector (8 dimensional array).
|
||||
covariance (ndarray): The state's covariance matrix (8x8 dimensional).
|
||||
|
||||
Returns:
|
||||
(tuple[ndarray, ndarray]): Returns the projected mean and covariance matrix of the given state estimate.
|
||||
"""
|
||||
std = [
|
||||
self._std_weight_position * mean[2],
|
||||
self._std_weight_position * mean[3],
|
||||
self._std_weight_position * mean[2],
|
||||
self._std_weight_position * mean[3],
|
||||
]
|
||||
innovation_cov = np.diag(np.square(std))
|
||||
|
||||
mean = np.dot(self._update_mat, mean)
|
||||
covariance = np.linalg.multi_dot((self._update_mat, covariance, self._update_mat.T))
|
||||
return mean, covariance + innovation_cov
|
||||
|
||||
def multi_predict(self, mean, covariance) -> tuple:
|
||||
"""
|
||||
Run Kalman filter prediction step (Vectorized version).
|
||||
|
||||
Args:
|
||||
mean (ndarray): The Nx8 dimensional mean matrix of the object states at the previous time step.
|
||||
covariance (ndarray): The Nx8x8 covariance matrix of the object states at the previous time step.
|
||||
|
||||
Returns:
|
||||
(tuple[ndarray, ndarray]): Returns the mean vector and covariance matrix of the predicted state. Unobserved
|
||||
velocities are initialized to 0 mean.
|
||||
"""
|
||||
std_pos = [
|
||||
self._std_weight_position * mean[:, 2],
|
||||
self._std_weight_position * mean[:, 3],
|
||||
self._std_weight_position * mean[:, 2],
|
||||
self._std_weight_position * mean[:, 3],
|
||||
]
|
||||
std_vel = [
|
||||
self._std_weight_velocity * mean[:, 2],
|
||||
self._std_weight_velocity * mean[:, 3],
|
||||
self._std_weight_velocity * mean[:, 2],
|
||||
self._std_weight_velocity * mean[:, 3],
|
||||
]
|
||||
sqr = np.square(np.r_[std_pos, std_vel]).T
|
||||
|
||||
motion_cov = [np.diag(sqr[i]) for i in range(len(mean))]
|
||||
motion_cov = np.asarray(motion_cov)
|
||||
|
||||
mean = np.dot(mean, self._motion_mat.T)
|
||||
left = np.dot(self._motion_mat, covariance).transpose((1, 0, 2))
|
||||
covariance = np.dot(left, self._motion_mat.T) + motion_cov
|
||||
|
||||
return mean, covariance
|
||||
|
||||
def update(self, mean, covariance, measurement) -> tuple:
|
||||
"""
|
||||
Run Kalman filter correction step.
|
||||
|
||||
Args:
|
||||
mean (ndarray): The predicted state's mean vector (8 dimensional).
|
||||
covariance (ndarray): The state's covariance matrix (8x8 dimensional).
|
||||
measurement (ndarray): The 4 dimensional measurement vector (x, y, w, h), where (x, y) is the center
|
||||
position, w the width, and h the height of the bounding box.
|
||||
|
||||
Returns:
|
||||
(tuple[ndarray, ndarray]): Returns the measurement-corrected state distribution.
|
||||
"""
|
||||
return super().update(mean, covariance, measurement)
|
138
ultralytics/trackers/utils/matching.py
Normal file
138
ultralytics/trackers/utils/matching.py
Normal file
@ -0,0 +1,138 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import numpy as np
|
||||
import scipy
|
||||
from scipy.spatial.distance import cdist
|
||||
|
||||
from ultralytics.utils.metrics import bbox_ioa, batch_probiou
|
||||
|
||||
try:
|
||||
import lap # for linear_assignment
|
||||
|
||||
assert lap.__version__ # verify package is not directory
|
||||
except (ImportError, AssertionError, AttributeError):
|
||||
from ultralytics.utils.checks import check_requirements
|
||||
|
||||
check_requirements("lapx>=0.5.2") # update to lap package from https://github.com/rathaROG/lapx
|
||||
import lap
|
||||
|
||||
|
||||
def linear_assignment(cost_matrix: np.ndarray, thresh: float, use_lap: bool = True) -> tuple:
|
||||
"""
|
||||
Perform linear assignment using scipy or lap.lapjv.
|
||||
|
||||
Args:
|
||||
cost_matrix (np.ndarray): The matrix containing cost values for assignments.
|
||||
thresh (float): Threshold for considering an assignment valid.
|
||||
use_lap (bool, optional): Whether to use lap.lapjv. Defaults to True.
|
||||
|
||||
Returns:
|
||||
Tuple with:
|
||||
- matched indices
|
||||
- unmatched indices from 'a'
|
||||
- unmatched indices from 'b'
|
||||
"""
|
||||
|
||||
if cost_matrix.size == 0:
|
||||
return np.empty((0, 2), dtype=int), tuple(range(cost_matrix.shape[0])), tuple(range(cost_matrix.shape[1]))
|
||||
|
||||
if use_lap:
|
||||
# Use lap.lapjv
|
||||
# https://github.com/gatagat/lap
|
||||
_, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh)
|
||||
matches = [[ix, mx] for ix, mx in enumerate(x) if mx >= 0]
|
||||
unmatched_a = np.where(x < 0)[0]
|
||||
unmatched_b = np.where(y < 0)[0]
|
||||
else:
|
||||
# Use scipy.optimize.linear_sum_assignment
|
||||
# https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.linear_sum_assignment.html
|
||||
x, y = scipy.optimize.linear_sum_assignment(cost_matrix) # row x, col y
|
||||
matches = np.asarray([[x[i], y[i]] for i in range(len(x)) if cost_matrix[x[i], y[i]] <= thresh])
|
||||
if len(matches) == 0:
|
||||
unmatched_a = list(np.arange(cost_matrix.shape[0]))
|
||||
unmatched_b = list(np.arange(cost_matrix.shape[1]))
|
||||
else:
|
||||
unmatched_a = list(set(np.arange(cost_matrix.shape[0])) - set(matches[:, 0]))
|
||||
unmatched_b = list(set(np.arange(cost_matrix.shape[1])) - set(matches[:, 1]))
|
||||
|
||||
return matches, unmatched_a, unmatched_b
|
||||
|
||||
|
||||
def iou_distance(atracks: list, btracks: list) -> np.ndarray:
|
||||
"""
|
||||
Compute cost based on Intersection over Union (IoU) between tracks.
|
||||
|
||||
Args:
|
||||
atracks (list[STrack] | list[np.ndarray]): List of tracks 'a' or bounding boxes.
|
||||
btracks (list[STrack] | list[np.ndarray]): List of tracks 'b' or bounding boxes.
|
||||
|
||||
Returns:
|
||||
(np.ndarray): Cost matrix computed based on IoU.
|
||||
"""
|
||||
|
||||
if atracks and isinstance(atracks[0], np.ndarray) or btracks and isinstance(btracks[0], np.ndarray):
|
||||
atlbrs = atracks
|
||||
btlbrs = btracks
|
||||
else:
|
||||
atlbrs = [track.xywha if track.angle is not None else track.xyxy for track in atracks]
|
||||
btlbrs = [track.xywha if track.angle is not None else track.xyxy for track in btracks]
|
||||
|
||||
ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float32)
|
||||
if len(atlbrs) and len(btlbrs):
|
||||
if len(atlbrs[0]) == 5 and len(btlbrs[0]) == 5:
|
||||
ious = batch_probiou(
|
||||
np.ascontiguousarray(atlbrs, dtype=np.float32),
|
||||
np.ascontiguousarray(btlbrs, dtype=np.float32),
|
||||
).numpy()
|
||||
else:
|
||||
ious = bbox_ioa(
|
||||
np.ascontiguousarray(atlbrs, dtype=np.float32),
|
||||
np.ascontiguousarray(btlbrs, dtype=np.float32),
|
||||
iou=True,
|
||||
)
|
||||
return 1 - ious # cost matrix
|
||||
|
||||
|
||||
def embedding_distance(tracks: list, detections: list, metric: str = "cosine") -> np.ndarray:
|
||||
"""
|
||||
Compute distance between tracks and detections based on embeddings.
|
||||
|
||||
Args:
|
||||
tracks (list[STrack]): List of tracks.
|
||||
detections (list[BaseTrack]): List of detections.
|
||||
metric (str, optional): Metric for distance computation. Defaults to 'cosine'.
|
||||
|
||||
Returns:
|
||||
(np.ndarray): Cost matrix computed based on embeddings.
|
||||
"""
|
||||
|
||||
cost_matrix = np.zeros((len(tracks), len(detections)), dtype=np.float32)
|
||||
if cost_matrix.size == 0:
|
||||
return cost_matrix
|
||||
det_features = np.asarray([track.curr_feat for track in detections], dtype=np.float32)
|
||||
# for i, track in enumerate(tracks):
|
||||
# cost_matrix[i, :] = np.maximum(0.0, cdist(track.smooth_feat.reshape(1,-1), det_features, metric))
|
||||
track_features = np.asarray([track.smooth_feat for track in tracks], dtype=np.float32)
|
||||
cost_matrix = np.maximum(0.0, cdist(track_features, det_features, metric)) # Normalized features
|
||||
return cost_matrix
|
||||
|
||||
|
||||
def fuse_score(cost_matrix: np.ndarray, detections: list) -> np.ndarray:
|
||||
"""
|
||||
Fuses cost matrix with detection scores to produce a single similarity matrix.
|
||||
|
||||
Args:
|
||||
cost_matrix (np.ndarray): The matrix containing cost values for assignments.
|
||||
detections (list[BaseTrack]): List of detections with scores.
|
||||
|
||||
Returns:
|
||||
(np.ndarray): Fused similarity matrix.
|
||||
"""
|
||||
|
||||
if cost_matrix.size == 0:
|
||||
return cost_matrix
|
||||
iou_sim = 1 - cost_matrix
|
||||
det_scores = np.array([det.score for det in detections])
|
||||
det_scores = np.expand_dims(det_scores, axis=0).repeat(cost_matrix.shape[0], axis=0)
|
||||
fuse_sim = iou_sim * det_scores
|
||||
return 1 - fuse_sim # fuse_cost
|
Reference in New Issue
Block a user