add yolo v10 and modify pipeline
This commit is contained in:
@ -1,91 +1,318 @@
|
||||
# Tracker
|
||||
# Multi-Object Tracking with Ultralytics YOLO
|
||||
|
||||
## Supported Trackers
|
||||
<img width="1024" src="https://user-images.githubusercontent.com/26833433/243418637-1d6250fd-1515-4c10-a844-a32818ae6d46.png" alt="YOLOv8 trackers visualization">
|
||||
|
||||
- [x] ByteTracker
|
||||
- [x] BoT-SORT
|
||||
Object tracking in the realm of video analytics is a critical task that not only identifies the location and class of objects within the frame but also maintains a unique ID for each detected object as the video progresses. The applications are limitless—ranging from surveillance and security to real-time sports analytics.
|
||||
|
||||
## Usage
|
||||
## Why Choose Ultralytics YOLO for Object Tracking?
|
||||
|
||||
### python interface:
|
||||
The output from Ultralytics trackers is consistent with standard object detection but has the added value of object IDs. This makes it easy to track objects in video streams and perform subsequent analytics. Here's why you should consider using Ultralytics YOLO for your object tracking needs:
|
||||
|
||||
You can use the Python interface to track objects using the YOLO model.
|
||||
- **Efficiency:** Process video streams in real-time without compromising accuracy.
|
||||
- **Flexibility:** Supports multiple tracking algorithms and configurations.
|
||||
- **Ease of Use:** Simple Python API and CLI options for quick integration and deployment.
|
||||
- **Customizability:** Easy to use with custom trained YOLO models, allowing integration into domain-specific applications.
|
||||
|
||||
**Video Tutorial:** [Object Detection and Tracking with Ultralytics YOLOv8](https://www.youtube.com/embed/hHyHmOtmEgs?si=VNZtXmm45Nb9s-N-).
|
||||
|
||||
## Features at a Glance
|
||||
|
||||
Ultralytics YOLO extends its object detection features to provide robust and versatile object tracking:
|
||||
|
||||
- **Real-Time Tracking:** Seamlessly track objects in high-frame-rate videos.
|
||||
- **Multiple Tracker Support:** Choose from a variety of established tracking algorithms.
|
||||
- **Customizable Tracker Configurations:** Tailor the tracking algorithm to meet specific requirements by adjusting various parameters.
|
||||
|
||||
## Available Trackers
|
||||
|
||||
Ultralytics YOLO supports the following tracking algorithms. They can be enabled by passing the relevant YAML configuration file such as `tracker=tracker_type.yaml`:
|
||||
|
||||
- [BoT-SORT](https://github.com/NirAharon/BoT-SORT) - Use `botsort.yaml` to enable this tracker.
|
||||
- [ByteTrack](https://github.com/ifzhang/ByteTrack) - Use `bytetrack.yaml` to enable this tracker.
|
||||
|
||||
The default tracker is BoT-SORT.
|
||||
|
||||
## Tracking
|
||||
|
||||
To run the tracker on video streams, use a trained Detect, Segment or Pose model such as YOLOv8n, YOLOv8n-seg and YOLOv8n-pose.
|
||||
|
||||
#### Python
|
||||
|
||||
```python
|
||||
from ultralytics import YOLO
|
||||
|
||||
model = YOLO("yolov8n.pt") # or a segmentation model .i.e yolov8n-seg.pt
|
||||
model.track(
|
||||
source="video/streams",
|
||||
stream=True,
|
||||
tracker="botsort.yaml", # or 'bytetrack.yaml'
|
||||
show=True,
|
||||
# Load an official or custom model
|
||||
model = YOLO("yolov8n.pt") # Load an official Detect model
|
||||
model = YOLO("yolov8n-seg.pt") # Load an official Segment model
|
||||
model = YOLO("yolov8n-pose.pt") # Load an official Pose model
|
||||
model = YOLO("path/to/best.pt") # Load a custom trained model
|
||||
|
||||
# Perform tracking with the model
|
||||
results = model.track(
|
||||
source="https://youtu.be/LNwODJXcvt4", show=True
|
||||
) # Tracking with default tracker
|
||||
results = model.track(
|
||||
source="https://youtu.be/LNwODJXcvt4", show=True, tracker="bytetrack.yaml"
|
||||
) # Tracking with ByteTrack tracker
|
||||
```
|
||||
|
||||
#### CLI
|
||||
|
||||
```bash
|
||||
# Perform tracking with various models using the command line interface
|
||||
yolo track model=yolov8n.pt source="https://youtu.be/LNwODJXcvt4" # Official Detect model
|
||||
yolo track model=yolov8n-seg.pt source="https://youtu.be/LNwODJXcvt4" # Official Segment model
|
||||
yolo track model=yolov8n-pose.pt source="https://youtu.be/LNwODJXcvt4" # Official Pose model
|
||||
yolo track model=path/to/best.pt source="https://youtu.be/LNwODJXcvt4" # Custom trained model
|
||||
|
||||
# Track using ByteTrack tracker
|
||||
yolo track model=path/to/best.pt tracker="bytetrack.yaml"
|
||||
```
|
||||
|
||||
As can be seen in the above usage, tracking is available for all Detect, Segment and Pose models run on videos or streaming sources.
|
||||
|
||||
## Configuration
|
||||
|
||||
### Tracking Arguments
|
||||
|
||||
Tracking configuration shares properties with Predict mode, such as `conf`, `iou`, and `show`. For further configurations, refer to the [Predict](https://docs.ultralytics.com/modes/predict/) model page.
|
||||
|
||||
#### Python
|
||||
|
||||
```python
|
||||
from ultralytics import YOLO
|
||||
|
||||
# Configure the tracking parameters and run the tracker
|
||||
model = YOLO("yolov8n.pt")
|
||||
results = model.track(
|
||||
source="https://youtu.be/LNwODJXcvt4", conf=0.3, iou=0.5, show=True
|
||||
)
|
||||
```
|
||||
|
||||
You can get the IDs of the tracked objects using the following code:
|
||||
#### CLI
|
||||
|
||||
```bash
|
||||
# Configure tracking parameters and run the tracker using the command line interface
|
||||
yolo track model=yolov8n.pt source="https://youtu.be/LNwODJXcvt4" conf=0.3, iou=0.5 show
|
||||
```
|
||||
|
||||
### Tracker Selection
|
||||
|
||||
Ultralytics also allows you to use a modified tracker configuration file. To do this, simply make a copy of a tracker config file (for example, `custom_tracker.yaml`) from [ultralytics/cfg/trackers](https://github.com/ultralytics/ultralytics/tree/main/ultralytics/cfg/trackers) and modify any configurations (except the `tracker_type`) as per your needs.
|
||||
|
||||
#### Python
|
||||
|
||||
```python
|
||||
from ultralytics import YOLO
|
||||
|
||||
# Load the model and run the tracker with a custom configuration file
|
||||
model = YOLO("yolov8n.pt")
|
||||
|
||||
for result in model.track(source="video.mp4"):
|
||||
print(
|
||||
result.boxes.id.cpu().numpy().astype(int)
|
||||
) # this will print the IDs of the tracked objects in the frame
|
||||
results = model.track(
|
||||
source="https://youtu.be/LNwODJXcvt4", tracker="custom_tracker.yaml"
|
||||
)
|
||||
```
|
||||
|
||||
If you want to use the tracker with a folder of images or when you loop on the video frames, you should use the `persist` parameter to tell the model that these frames are related to each other so the IDs will be fixed for the same objects. Otherwise, the IDs will be different in each frame because in each loop, the model creates a new object for tracking, but the `persist` parameter makes it use the same object for tracking.
|
||||
#### CLI
|
||||
|
||||
```bash
|
||||
# Load the model and run the tracker with a custom configuration file using the command line interface
|
||||
yolo track model=yolov8n.pt source="https://youtu.be/LNwODJXcvt4" tracker='custom_tracker.yaml'
|
||||
```
|
||||
|
||||
For a comprehensive list of tracking arguments, refer to the [ultralytics/cfg/trackers](https://github.com/ultralytics/ultralytics/tree/main/ultralytics/cfg/trackers) page.
|
||||
|
||||
## Python Examples
|
||||
|
||||
### Persisting Tracks Loop
|
||||
|
||||
Here is a Python script using OpenCV (`cv2`) and YOLOv8 to run object tracking on video frames. This script still assumes you have already installed the necessary packages (`opencv-python` and `ultralytics`). The `persist=True` argument tells the tracker than the current image or frame is the next in a sequence and to expect tracks from the previous image in the current image.
|
||||
|
||||
#### Python
|
||||
|
||||
```python
|
||||
import cv2
|
||||
from ultralytics import YOLO
|
||||
|
||||
cap = cv2.VideoCapture("video.mp4")
|
||||
# Load the YOLOv8 model
|
||||
model = YOLO("yolov8n.pt")
|
||||
while True:
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
break
|
||||
results = model.track(frame, persist=True)
|
||||
boxes = results[0].boxes.xyxy.cpu().numpy().astype(int)
|
||||
ids = results[0].boxes.id.cpu().numpy().astype(int)
|
||||
for box, id in zip(boxes, ids):
|
||||
cv2.rectangle(frame, (box[0], box[1]), (box[2], box[3]), (0, 255, 0), 2)
|
||||
cv2.putText(
|
||||
frame,
|
||||
f"Id {id}",
|
||||
(box[0], box[1]),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
1,
|
||||
(0, 0, 255),
|
||||
2,
|
||||
)
|
||||
cv2.imshow("frame", frame)
|
||||
if cv2.waitKey(1) & 0xFF == ord("q"):
|
||||
|
||||
# Open the video file
|
||||
video_path = "path/to/video.mp4"
|
||||
cap = cv2.VideoCapture(video_path)
|
||||
|
||||
# Loop through the video frames
|
||||
while cap.isOpened():
|
||||
# Read a frame from the video
|
||||
success, frame = cap.read()
|
||||
|
||||
if success:
|
||||
# Run YOLOv8 tracking on the frame, persisting tracks between frames
|
||||
results = model.track(frame, persist=True)
|
||||
|
||||
# Visualize the results on the frame
|
||||
annotated_frame = results[0].plot()
|
||||
|
||||
# Display the annotated frame
|
||||
cv2.imshow("YOLOv8 Tracking", annotated_frame)
|
||||
|
||||
# Break the loop if 'q' is pressed
|
||||
if cv2.waitKey(1) & 0xFF == ord("q"):
|
||||
break
|
||||
else:
|
||||
# Break the loop if the end of the video is reached
|
||||
break
|
||||
|
||||
# Release the video capture object and close the display window
|
||||
cap.release()
|
||||
cv2.destroyAllWindows()
|
||||
```
|
||||
|
||||
## Change tracker parameters
|
||||
Please note the change from `model(frame)` to `model.track(frame)`, which enables object tracking instead of simple detection. This modified script will run the tracker on each frame of the video, visualize the results, and display them in a window. The loop can be exited by pressing 'q'.
|
||||
|
||||
You can change the tracker parameters by editing the `tracker.yaml` file which is located in the ultralytics/cfg/trackers folder.
|
||||
### Plotting Tracks Over Time
|
||||
|
||||
## Command Line Interface (CLI)
|
||||
Visualizing object tracks over consecutive frames can provide valuable insights into the movement patterns and behavior of detected objects within a video. With Ultralytics YOLOv8, plotting these tracks is a seamless and efficient process.
|
||||
|
||||
You can also use the command line interface to track objects using the YOLO model.
|
||||
In the following example, we demonstrate how to utilize YOLOv8's tracking capabilities to plot the movement of detected objects across multiple video frames. This script involves opening a video file, reading it frame by frame, and utilizing the YOLO model to identify and track various objects. By retaining the center points of the detected bounding boxes and connecting them, we can draw lines that represent the paths followed by the tracked objects.
|
||||
|
||||
```bash
|
||||
yolo detect track source=... tracker=...
|
||||
yolo segment track source=... tracker=...
|
||||
yolo pose track source=... tracker=...
|
||||
#### Python
|
||||
|
||||
```python
|
||||
from collections import defaultdict
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from ultralytics import YOLO
|
||||
|
||||
# Load the YOLOv8 model
|
||||
model = YOLO("yolov8n.pt")
|
||||
|
||||
# Open the video file
|
||||
video_path = "path/to/video.mp4"
|
||||
cap = cv2.VideoCapture(video_path)
|
||||
|
||||
# Store the track history
|
||||
track_history = defaultdict(lambda: [])
|
||||
|
||||
# Loop through the video frames
|
||||
while cap.isOpened():
|
||||
# Read a frame from the video
|
||||
success, frame = cap.read()
|
||||
|
||||
if success:
|
||||
# Run YOLOv8 tracking on the frame, persisting tracks between frames
|
||||
results = model.track(frame, persist=True)
|
||||
|
||||
# Get the boxes and track IDs
|
||||
boxes = results[0].boxes.xywh.cpu()
|
||||
track_ids = results[0].boxes.id.int().cpu().tolist()
|
||||
|
||||
# Visualize the results on the frame
|
||||
annotated_frame = results[0].plot()
|
||||
|
||||
# Plot the tracks
|
||||
for box, track_id in zip(boxes, track_ids):
|
||||
x, y, w, h = box
|
||||
track = track_history[track_id]
|
||||
track.append((float(x), float(y))) # x, y center point
|
||||
if len(track) > 30: # retain 90 tracks for 90 frames
|
||||
track.pop(0)
|
||||
|
||||
# Draw the tracking lines
|
||||
points = np.hstack(track).astype(np.int32).reshape((-1, 1, 2))
|
||||
cv2.polylines(
|
||||
annotated_frame,
|
||||
[points],
|
||||
isClosed=False,
|
||||
color=(230, 230, 230),
|
||||
thickness=10,
|
||||
)
|
||||
|
||||
# Display the annotated frame
|
||||
cv2.imshow("YOLOv8 Tracking", annotated_frame)
|
||||
|
||||
# Break the loop if 'q' is pressed
|
||||
if cv2.waitKey(1) & 0xFF == ord("q"):
|
||||
break
|
||||
else:
|
||||
# Break the loop if the end of the video is reached
|
||||
break
|
||||
|
||||
# Release the video capture object and close the display window
|
||||
cap.release()
|
||||
cv2.destroyAllWindows()
|
||||
```
|
||||
|
||||
By default, trackers will use the configuration in `ultralytics/cfg/trackers`. We also support using a modified tracker config file. Please refer to the tracker config files in `ultralytics/cfg/trackers`.
|
||||
### Multithreaded Tracking
|
||||
|
||||
## Contribute to Our Trackers Section
|
||||
Multithreaded tracking provides the capability to run object tracking on multiple video streams simultaneously. This is particularly useful when handling multiple video inputs, such as from multiple surveillance cameras, where concurrent processing can greatly enhance efficiency and performance.
|
||||
|
||||
Are you proficient in multi-object tracking and have successfully implemented or adapted a tracking algorithm with Ultralytics YOLO? We invite you to contribute to our Trackers section! Your real-world applications and solutions could be invaluable for users working on tracking tasks.
|
||||
In the provided Python script, we make use of Python's `threading` module to run multiple instances of the tracker concurrently. Each thread is responsible for running the tracker on one video file, and all the threads run simultaneously in the background.
|
||||
|
||||
To ensure that each thread receives the correct parameters (the video file and the model to use), we define a function `run_tracker_in_thread` that accepts these parameters and contains the main tracking loop. This function reads the video frame by frame, runs the tracker, and displays the results.
|
||||
|
||||
Two different models are used in this example: `yolov8n.pt` and `yolov8n-seg.pt`, each tracking objects in a different video file. The video files are specified in `video_file1` and `video_file2`.
|
||||
|
||||
The `daemon=True` parameter in `threading.Thread` means that these threads will be closed as soon as the main program finishes. We then start the threads with `start()` and use `join()` to make the main thread wait until both tracker threads have finished.
|
||||
|
||||
Finally, after all threads have completed their task, the windows displaying the results are closed using `cv2.destroyAllWindows()`.
|
||||
|
||||
#### Python
|
||||
|
||||
```python
|
||||
import threading
|
||||
|
||||
import cv2
|
||||
from ultralytics import YOLO
|
||||
|
||||
|
||||
def run_tracker_in_thread(filename, model):
|
||||
video = cv2.VideoCapture(filename)
|
||||
frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
for _ in range(frames):
|
||||
ret, frame = video.read()
|
||||
if ret:
|
||||
results = model.track(source=frame, persist=True)
|
||||
res_plotted = results[0].plot()
|
||||
cv2.imshow("p", res_plotted)
|
||||
if cv2.waitKey(1) == ord("q"):
|
||||
break
|
||||
|
||||
|
||||
# Load the models
|
||||
model1 = YOLO("yolov8n.pt")
|
||||
model2 = YOLO("yolov8n-seg.pt")
|
||||
|
||||
# Define the video files for the trackers
|
||||
video_file1 = "path/to/video1.mp4"
|
||||
video_file2 = "path/to/video2.mp4"
|
||||
|
||||
# Create the tracker threads
|
||||
tracker_thread1 = threading.Thread(
|
||||
target=run_tracker_in_thread, args=(video_file1, model1), daemon=True
|
||||
)
|
||||
tracker_thread2 = threading.Thread(
|
||||
target=run_tracker_in_thread, args=(video_file2, model2), daemon=True
|
||||
)
|
||||
|
||||
# Start the tracker threads
|
||||
tracker_thread1.start()
|
||||
tracker_thread2.start()
|
||||
|
||||
# Wait for the tracker threads to finish
|
||||
tracker_thread1.join()
|
||||
tracker_thread2.join()
|
||||
|
||||
# Clean up and close windows
|
||||
cv2.destroyAllWindows()
|
||||
```
|
||||
|
||||
This example can easily be extended to handle more video files and models by creating more threads and applying the same methodology.
|
||||
|
||||
## Contribute New Trackers
|
||||
|
||||
Are you proficient in multi-object tracking and have successfully implemented or adapted a tracking algorithm with Ultralytics YOLO? We invite you to contribute to our Trackers section in [ultralytics/cfg/trackers](https://github.com/ultralytics/ultralytics/tree/main/ultralytics/cfg/trackers)! Your real-world applications and solutions could be invaluable for users working on tracking tasks.
|
||||
|
||||
By contributing to this section, you help expand the scope of tracking solutions available within the Ultralytics YOLO framework, adding another layer of functionality and utility for the community.
|
||||
|
||||
|
@ -4,4 +4,4 @@ from .bot_sort import BOTSORT
|
||||
from .byte_tracker import BYTETracker
|
||||
from .track import register_tracker
|
||||
|
||||
__all__ = 'register_tracker', 'BOTSORT', 'BYTETracker' # allow simpler import
|
||||
__all__ = "register_tracker", "BOTSORT", "BYTETracker" # allow simpler import
|
||||
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -1,4 +1,5 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
"""This module defines the base classes and structures for object tracking in YOLO."""
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
@ -6,7 +7,15 @@ import numpy as np
|
||||
|
||||
|
||||
class TrackState:
|
||||
"""Enumeration of possible object tracking states."""
|
||||
"""
|
||||
Enumeration class representing the possible states of an object being tracked.
|
||||
|
||||
Attributes:
|
||||
New (int): State when the object is newly detected.
|
||||
Tracked (int): State when the object is successfully tracked in subsequent frames.
|
||||
Lost (int): State when the object is no longer tracked.
|
||||
Removed (int): State when the object is removed from tracking.
|
||||
"""
|
||||
|
||||
New = 0
|
||||
Tracked = 1
|
||||
@ -15,24 +24,49 @@ class TrackState:
|
||||
|
||||
|
||||
class BaseTrack:
|
||||
"""Base class for object tracking, handling basic track attributes and operations."""
|
||||
"""
|
||||
Base class for object tracking, providing foundational attributes and methods.
|
||||
|
||||
Attributes:
|
||||
_count (int): Class-level counter for unique track IDs.
|
||||
track_id (int): Unique identifier for the track.
|
||||
is_activated (bool): Flag indicating whether the track is currently active.
|
||||
state (TrackState): Current state of the track.
|
||||
history (OrderedDict): Ordered history of the track's states.
|
||||
features (list): List of features extracted from the object for tracking.
|
||||
curr_feature (any): The current feature of the object being tracked.
|
||||
score (float): The confidence score of the tracking.
|
||||
start_frame (int): The frame number where tracking started.
|
||||
frame_id (int): The most recent frame ID processed by the track.
|
||||
time_since_update (int): Frames passed since the last update.
|
||||
location (tuple): The location of the object in the context of multi-camera tracking.
|
||||
|
||||
Methods:
|
||||
end_frame: Returns the ID of the last frame where the object was tracked.
|
||||
next_id: Increments and returns the next global track ID.
|
||||
activate: Abstract method to activate the track.
|
||||
predict: Abstract method to predict the next state of the track.
|
||||
update: Abstract method to update the track with new data.
|
||||
mark_lost: Marks the track as lost.
|
||||
mark_removed: Marks the track as removed.
|
||||
reset_id: Resets the global track ID counter.
|
||||
"""
|
||||
|
||||
_count = 0
|
||||
|
||||
track_id = 0
|
||||
is_activated = False
|
||||
state = TrackState.New
|
||||
|
||||
history = OrderedDict()
|
||||
features = []
|
||||
curr_feature = None
|
||||
score = 0
|
||||
start_frame = 0
|
||||
frame_id = 0
|
||||
time_since_update = 0
|
||||
|
||||
# Multi-camera
|
||||
location = (np.inf, np.inf)
|
||||
def __init__(self):
|
||||
"""Initializes a new track with unique ID and foundational tracking attributes."""
|
||||
self.track_id = 0
|
||||
self.is_activated = False
|
||||
self.state = TrackState.New
|
||||
self.history = OrderedDict()
|
||||
self.features = []
|
||||
self.curr_feature = None
|
||||
self.score = 0
|
||||
self.start_frame = 0
|
||||
self.frame_id = 0
|
||||
self.time_since_update = 0
|
||||
self.location = (np.inf, np.inf)
|
||||
|
||||
@property
|
||||
def end_frame(self):
|
||||
@ -46,15 +80,15 @@ class BaseTrack:
|
||||
return BaseTrack._count
|
||||
|
||||
def activate(self, *args):
|
||||
"""Activate the track with the provided arguments."""
|
||||
"""Abstract method to activate the track with provided arguments."""
|
||||
raise NotImplementedError
|
||||
|
||||
def predict(self):
|
||||
"""Predict the next state of the track."""
|
||||
"""Abstract method to predict the next state of the track."""
|
||||
raise NotImplementedError
|
||||
|
||||
def update(self, *args, **kwargs):
|
||||
"""Update the track with new observations."""
|
||||
"""Abstract method to update the track with new observations."""
|
||||
raise NotImplementedError
|
||||
|
||||
def mark_lost(self):
|
||||
|
@ -12,6 +12,34 @@ from .utils.kalman_filter import KalmanFilterXYWH
|
||||
|
||||
|
||||
class BOTrack(STrack):
|
||||
"""
|
||||
An extended version of the STrack class for YOLOv8, adding object tracking features.
|
||||
|
||||
Attributes:
|
||||
shared_kalman (KalmanFilterXYWH): A shared Kalman filter for all instances of BOTrack.
|
||||
smooth_feat (np.ndarray): Smoothed feature vector.
|
||||
curr_feat (np.ndarray): Current feature vector.
|
||||
features (deque): A deque to store feature vectors with a maximum length defined by `feat_history`.
|
||||
alpha (float): Smoothing factor for the exponential moving average of features.
|
||||
mean (np.ndarray): The mean state of the Kalman filter.
|
||||
covariance (np.ndarray): The covariance matrix of the Kalman filter.
|
||||
|
||||
Methods:
|
||||
update_features(feat): Update features vector and smooth it using exponential moving average.
|
||||
predict(): Predicts the mean and covariance using Kalman filter.
|
||||
re_activate(new_track, frame_id, new_id): Reactivates a track with updated features and optionally new ID.
|
||||
update(new_track, frame_id): Update the YOLOv8 instance with new track and frame ID.
|
||||
tlwh: Property that gets the current position in tlwh format `(top left x, top left y, width, height)`.
|
||||
multi_predict(stracks): Predicts the mean and covariance of multiple object tracks using shared Kalman filter.
|
||||
convert_coords(tlwh): Converts tlwh bounding box coordinates to xywh format.
|
||||
tlwh_to_xywh(tlwh): Convert bounding box to xywh format `(center x, center y, width, height)`.
|
||||
|
||||
Usage:
|
||||
bo_track = BOTrack(tlwh, score, cls, feat)
|
||||
bo_track.predict()
|
||||
bo_track.update(new_track, frame_id)
|
||||
"""
|
||||
|
||||
shared_kalman = KalmanFilterXYWH()
|
||||
|
||||
def __init__(self, tlwh, score, cls, feat=None, feat_history=50):
|
||||
@ -59,9 +87,7 @@ class BOTrack(STrack):
|
||||
|
||||
@property
|
||||
def tlwh(self):
|
||||
"""Get current position in bounding box format `(top left x, top left y,
|
||||
width, height)`.
|
||||
"""
|
||||
"""Get current position in bounding box format `(top left x, top left y, width, height)`."""
|
||||
if self.mean is None:
|
||||
return self._tlwh.copy()
|
||||
ret = self.mean[:4].copy()
|
||||
@ -90,15 +116,37 @@ class BOTrack(STrack):
|
||||
|
||||
@staticmethod
|
||||
def tlwh_to_xywh(tlwh):
|
||||
"""Convert bounding box to format `(center x, center y, width,
|
||||
height)`.
|
||||
"""
|
||||
"""Convert bounding box to format `(center x, center y, width, height)`."""
|
||||
ret = np.asarray(tlwh).copy()
|
||||
ret[:2] += ret[2:] / 2
|
||||
return ret
|
||||
|
||||
|
||||
class BOTSORT(BYTETracker):
|
||||
"""
|
||||
An extended version of the BYTETracker class for YOLOv8, designed for object tracking with ReID and GMC algorithm.
|
||||
|
||||
Attributes:
|
||||
proximity_thresh (float): Threshold for spatial proximity (IoU) between tracks and detections.
|
||||
appearance_thresh (float): Threshold for appearance similarity (ReID embeddings) between tracks and detections.
|
||||
encoder (object): Object to handle ReID embeddings, set to None if ReID is not enabled.
|
||||
gmc (GMC): An instance of the GMC algorithm for data association.
|
||||
args (object): Parsed command-line arguments containing tracking parameters.
|
||||
|
||||
Methods:
|
||||
get_kalmanfilter(): Returns an instance of KalmanFilterXYWH for object tracking.
|
||||
init_track(dets, scores, cls, img): Initialize track with detections, scores, and classes.
|
||||
get_dists(tracks, detections): Get distances between tracks and detections using IoU and (optionally) ReID.
|
||||
multi_predict(tracks): Predict and track multiple objects with YOLOv8 model.
|
||||
|
||||
Usage:
|
||||
bot_sort = BOTSORT(args, frame_rate)
|
||||
bot_sort.init_track(dets, scores, cls, img)
|
||||
bot_sort.multi_predict(tracks)
|
||||
|
||||
Note:
|
||||
The class is designed to work with the YOLOv8 object detection model and supports ReID only if enabled via args.
|
||||
"""
|
||||
|
||||
def __init__(self, args, frame_rate=30):
|
||||
"""Initialize YOLOv8 object with ReID module and GMC algorithm."""
|
||||
@ -110,8 +158,7 @@ class BOTSORT(BYTETracker):
|
||||
if args.with_reid:
|
||||
# Haven't supported BoT-SORT(reid) yet
|
||||
self.encoder = None
|
||||
|
||||
# self.gmc = GMC(method=args.gmc_method) # commented by WQG
|
||||
self.gmc = GMC(method=args.gmc_method)
|
||||
|
||||
def get_kalmanfilter(self):
|
||||
"""Returns an instance of KalmanFilterXYWH for object tracking."""
|
||||
@ -130,7 +177,7 @@ class BOTSORT(BYTETracker):
|
||||
def get_dists(self, tracks, detections):
|
||||
"""Get distances between tracks and detections using IoU and (optionally) ReID embeddings."""
|
||||
dists = matching.iou_distance(tracks, detections)
|
||||
dists_mask = (dists > self.proximity_thresh)
|
||||
dists_mask = dists > self.proximity_thresh
|
||||
|
||||
# TODO: mot20
|
||||
# if not self.args.mot20:
|
||||
@ -146,3 +193,8 @@ class BOTSORT(BYTETracker):
|
||||
def multi_predict(self, tracks):
|
||||
"""Predict and track multiple objects with YOLOv8 model."""
|
||||
BOTrack.multi_predict(tracks)
|
||||
|
||||
def reset(self):
|
||||
"""Reset tracker."""
|
||||
super().reset()
|
||||
self.gmc.reset_params()
|
||||
|
@ -1,29 +1,54 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .basetrack import BaseTrack, TrackState
|
||||
from .utils import matching
|
||||
from .utils.kalman_filter import KalmanFilterXYAH
|
||||
|
||||
|
||||
def dists_update(dists, strack_pool, detections):
|
||||
if len(strack_pool) and len(detections):
|
||||
alabel = np.array([int(stack.cls) for stack in strack_pool])
|
||||
blabel = np.array([int(stack.cls) for stack in detections])
|
||||
amlabel = np.expand_dims(alabel, axis=1).repeat(len(detections),axis=1)
|
||||
bmlabel = np.expand_dims(blabel, axis=0).repeat(len(strack_pool),axis=0)
|
||||
dist_label = 1 - (bmlabel == amlabel)
|
||||
dists = np.where(dists > dist_label, dists, dist_label)
|
||||
return dists
|
||||
from ..utils.ops import xywh2ltwh
|
||||
from ..utils import LOGGER
|
||||
|
||||
|
||||
class STrack(BaseTrack):
|
||||
"""
|
||||
Single object tracking representation that uses Kalman filtering for state estimation.
|
||||
|
||||
This class is responsible for storing all the information regarding individual tracklets and performs state updates
|
||||
and predictions based on Kalman filter.
|
||||
|
||||
Attributes:
|
||||
shared_kalman (KalmanFilterXYAH): Shared Kalman filter that is used across all STrack instances for prediction.
|
||||
_tlwh (np.ndarray): Private attribute to store top-left corner coordinates and width and height of bounding box.
|
||||
kalman_filter (KalmanFilterXYAH): Instance of Kalman filter used for this particular object track.
|
||||
mean (np.ndarray): Mean state estimate vector.
|
||||
covariance (np.ndarray): Covariance of state estimate.
|
||||
is_activated (bool): Boolean flag indicating if the track has been activated.
|
||||
score (float): Confidence score of the track.
|
||||
tracklet_len (int): Length of the tracklet.
|
||||
cls (any): Class label for the object.
|
||||
idx (int): Index or identifier for the object.
|
||||
frame_id (int): Current frame ID.
|
||||
start_frame (int): Frame where the object was first detected.
|
||||
|
||||
Methods:
|
||||
predict(): Predict the next state of the object using Kalman filter.
|
||||
multi_predict(stracks): Predict the next states for multiple tracks.
|
||||
multi_gmc(stracks, H): Update multiple track states using a homography matrix.
|
||||
activate(kalman_filter, frame_id): Activate a new tracklet.
|
||||
re_activate(new_track, frame_id, new_id): Reactivate a previously lost tracklet.
|
||||
update(new_track, frame_id): Update the state of a matched track.
|
||||
convert_coords(tlwh): Convert bounding box to x-y-aspect-height format.
|
||||
tlwh_to_xyah(tlwh): Convert tlwh bounding box to xyah format.
|
||||
"""
|
||||
|
||||
shared_kalman = KalmanFilterXYAH()
|
||||
|
||||
def __init__(self, tlwh, score, cls):
|
||||
"""wait activate."""
|
||||
self._tlwh = np.asarray(self.tlbr_to_tlwh(tlwh[:-1]), dtype=np.float32)
|
||||
def __init__(self, xywh, score, cls):
|
||||
"""Initialize new STrack instance."""
|
||||
super().__init__()
|
||||
# xywh+idx or xywha+idx
|
||||
assert len(xywh) in [5, 6], f"expected 5 or 6 values but got {len(xywh)}"
|
||||
self._tlwh = np.asarray(xywh2ltwh(xywh[:4]), dtype=np.float32)
|
||||
self.kalman_filter = None
|
||||
self.mean, self.covariance = None, None
|
||||
self.is_activated = False
|
||||
@ -31,7 +56,8 @@ class STrack(BaseTrack):
|
||||
self.score = score
|
||||
self.tracklet_len = 0
|
||||
self.cls = cls
|
||||
self.idx = tlwh[-1]
|
||||
self.idx = xywh[-1]
|
||||
self.angle = xywh[4] if len(xywh) == 6 else None
|
||||
|
||||
def predict(self):
|
||||
"""Predicts mean and covariance using Kalman filter."""
|
||||
@ -89,8 +115,9 @@ class STrack(BaseTrack):
|
||||
|
||||
def re_activate(self, new_track, frame_id, new_id=False):
|
||||
"""Reactivates a previously lost track with a new detection."""
|
||||
self.mean, self.covariance = self.kalman_filter.update(self.mean, self.covariance,
|
||||
self.convert_coords(new_track.tlwh))
|
||||
self.mean, self.covariance = self.kalman_filter.update(
|
||||
self.mean, self.covariance, self.convert_coords(new_track.tlwh)
|
||||
)
|
||||
self.tracklet_len = 0
|
||||
self.state = TrackState.Tracked
|
||||
self.is_activated = True
|
||||
@ -99,37 +126,39 @@ class STrack(BaseTrack):
|
||||
self.track_id = self.next_id()
|
||||
self.score = new_track.score
|
||||
self.cls = new_track.cls
|
||||
self.angle = new_track.angle
|
||||
self.idx = new_track.idx
|
||||
|
||||
def update(self, new_track, frame_id):
|
||||
"""
|
||||
Update a matched track
|
||||
:type new_track: STrack
|
||||
:type frame_id: int
|
||||
:return:
|
||||
Update the state of a matched track.
|
||||
|
||||
Args:
|
||||
new_track (STrack): The new track containing updated information.
|
||||
frame_id (int): The ID of the current frame.
|
||||
"""
|
||||
self.frame_id = frame_id
|
||||
self.tracklet_len += 1
|
||||
|
||||
new_tlwh = new_track.tlwh
|
||||
self.mean, self.covariance = self.kalman_filter.update(self.mean, self.covariance,
|
||||
self.convert_coords(new_tlwh))
|
||||
self.mean, self.covariance = self.kalman_filter.update(
|
||||
self.mean, self.covariance, self.convert_coords(new_tlwh)
|
||||
)
|
||||
self.state = TrackState.Tracked
|
||||
self.is_activated = True
|
||||
|
||||
self.score = new_track.score
|
||||
self.cls = new_track.cls
|
||||
self.angle = new_track.angle
|
||||
self.idx = new_track.idx
|
||||
|
||||
def convert_coords(self, tlwh):
|
||||
"""Convert a bounding box's top-left-width-height format to its x-y-angle-height equivalent."""
|
||||
"""Convert a bounding box's top-left-width-height format to its x-y-aspect-height equivalent."""
|
||||
return self.tlwh_to_xyah(tlwh)
|
||||
|
||||
@property
|
||||
def tlwh(self):
|
||||
"""Get current position in bounding box format `(top left x, top left y,
|
||||
width, height)`.
|
||||
"""
|
||||
"""Get current position in bounding box format (top left x, top left y, width, height)."""
|
||||
if self.mean is None:
|
||||
return self._tlwh.copy()
|
||||
ret = self.mean[:4].copy()
|
||||
@ -138,44 +167,76 @@ class STrack(BaseTrack):
|
||||
return ret
|
||||
|
||||
@property
|
||||
def tlbr(self):
|
||||
"""Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
|
||||
`(top left, bottom right)`.
|
||||
"""
|
||||
def xyxy(self):
|
||||
"""Convert bounding box to format (min x, min y, max x, max y), i.e., (top left, bottom right)."""
|
||||
ret = self.tlwh.copy()
|
||||
ret[2:] += ret[:2]
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def tlwh_to_xyah(tlwh):
|
||||
"""Convert bounding box to format `(center x, center y, aspect ratio,
|
||||
height)`, where the aspect ratio is `width / height`.
|
||||
"""Convert bounding box to format (center x, center y, aspect ratio, height), where the aspect ratio is width /
|
||||
height.
|
||||
"""
|
||||
ret = np.asarray(tlwh).copy()
|
||||
ret[:2] += ret[2:] / 2
|
||||
ret[2] /= ret[3]
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def tlbr_to_tlwh(tlbr):
|
||||
"""Converts top-left bottom-right format to top-left width height format."""
|
||||
ret = np.asarray(tlbr).copy()
|
||||
ret[2:] -= ret[:2]
|
||||
@property
|
||||
def xywh(self):
|
||||
"""Get current position in bounding box format (center x, center y, width, height)."""
|
||||
ret = np.asarray(self.tlwh).copy()
|
||||
ret[:2] += ret[2:] / 2
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def tlwh_to_tlbr(tlwh):
|
||||
"""Converts tlwh bounding box format to tlbr format."""
|
||||
ret = np.asarray(tlwh).copy()
|
||||
ret[2:] += ret[:2]
|
||||
return ret
|
||||
@property
|
||||
def xywha(self):
|
||||
"""Get current position in bounding box format (center x, center y, width, height, angle)."""
|
||||
if self.angle is None:
|
||||
LOGGER.warning("WARNING ⚠️ `angle` attr not found, returning `xywh` instead.")
|
||||
return self.xywh
|
||||
return np.concatenate([self.xywh, self.angle[None]])
|
||||
|
||||
@property
|
||||
def result(self):
|
||||
"""Get current tracking results."""
|
||||
coords = self.xyxy if self.angle is None else self.xywha
|
||||
return coords.tolist() + [self.track_id, self.score, self.cls, self.idx]
|
||||
|
||||
def __repr__(self):
|
||||
"""Return a string representation of the BYTETracker object with start and end frames and track ID."""
|
||||
return f'OT_{self.track_id}_({self.start_frame}-{self.end_frame})'
|
||||
return f"OT_{self.track_id}_({self.start_frame}-{self.end_frame})"
|
||||
|
||||
|
||||
class BYTETracker:
|
||||
"""
|
||||
BYTETracker: A tracking algorithm built on top of YOLOv8 for object detection and tracking.
|
||||
|
||||
The class is responsible for initializing, updating, and managing the tracks for detected objects in a video
|
||||
sequence. It maintains the state of tracked, lost, and removed tracks over frames, utilizes Kalman filtering for
|
||||
predicting the new object locations, and performs data association.
|
||||
|
||||
Attributes:
|
||||
tracked_stracks (list[STrack]): List of successfully activated tracks.
|
||||
lost_stracks (list[STrack]): List of lost tracks.
|
||||
removed_stracks (list[STrack]): List of removed tracks.
|
||||
frame_id (int): The current frame ID.
|
||||
args (namespace): Command-line arguments.
|
||||
max_time_lost (int): The maximum frames for a track to be considered as 'lost'.
|
||||
kalman_filter (object): Kalman Filter object.
|
||||
|
||||
Methods:
|
||||
update(results, img=None): Updates object tracker with new detections.
|
||||
get_kalmanfilter(): Returns a Kalman filter object for tracking bounding boxes.
|
||||
init_track(dets, scores, cls, img=None): Initialize object tracking with detections.
|
||||
get_dists(tracks, detections): Calculates the distance between tracks and detections.
|
||||
multi_predict(tracks): Predicts the location of tracks.
|
||||
reset_id(): Resets the ID counter of STrack.
|
||||
joint_stracks(tlista, tlistb): Combines two lists of stracks.
|
||||
sub_stracks(tlista, tlistb): Filters out the stracks present in the second list from the first list.
|
||||
remove_duplicate_stracks(stracksa, stracksb): Removes duplicate stracks based on IoU.
|
||||
"""
|
||||
|
||||
def __init__(self, args, frame_rate=30):
|
||||
"""Initialize a YOLOv8 object to track objects with given arguments and frame rate."""
|
||||
@ -198,7 +259,7 @@ class BYTETracker:
|
||||
removed_stracks = []
|
||||
|
||||
scores = results.conf
|
||||
bboxes = results.xyxy
|
||||
bboxes = results.xywhr if hasattr(results, "xywhr") else results.xywh
|
||||
# Add index
|
||||
bboxes = np.concatenate([bboxes, np.arange(len(bboxes)).reshape(-1, 1)], axis=-1)
|
||||
cls = results.cls
|
||||
@ -216,7 +277,6 @@ class BYTETracker:
|
||||
cls_second = cls[inds_second]
|
||||
|
||||
detections = self.init_track(dets, scores_keep, cls_keep, img)
|
||||
|
||||
# Add newly detected tracklets to tracked_stracks
|
||||
unconfirmed = []
|
||||
tracked_stracks = [] # type: list[STrack]
|
||||
@ -225,24 +285,18 @@ class BYTETracker:
|
||||
unconfirmed.append(track)
|
||||
else:
|
||||
tracked_stracks.append(track)
|
||||
|
||||
|
||||
# Step 2: First association, with high score detection boxes
|
||||
strack_pool = self.joint_stracks(tracked_stracks, self.lost_stracks)
|
||||
# Predict the current location with KF
|
||||
self.multi_predict(strack_pool)
|
||||
|
||||
# ============================================================= 没必要gmc,WQG
|
||||
# if hasattr(self, 'gmc') and img is not None:
|
||||
# warp = self.gmc.apply(img, dets)
|
||||
# STrack.multi_gmc(strack_pool, warp)
|
||||
# STrack.multi_gmc(unconfirmed, warp)
|
||||
# =============================================================================
|
||||
if hasattr(self, "gmc") and img is not None:
|
||||
warp = self.gmc.apply(img, dets)
|
||||
STrack.multi_gmc(strack_pool, warp)
|
||||
STrack.multi_gmc(unconfirmed, warp)
|
||||
|
||||
dists = self.get_dists(strack_pool, detections)
|
||||
dists = dists_update(dists, strack_pool, detections)
|
||||
|
||||
matches, u_track, u_detection = matching.linear_assignment(dists, thresh=self.args.match_thresh)
|
||||
|
||||
for itracked, idet in matches:
|
||||
track = strack_pool[itracked]
|
||||
det = detections[idet]
|
||||
@ -252,17 +306,11 @@ class BYTETracker:
|
||||
else:
|
||||
track.re_activate(det, self.frame_id, new_id=False)
|
||||
refind_stracks.append(track)
|
||||
|
||||
|
||||
# Step 3: Second association, with low score detection boxes
|
||||
# association the untrack to the low score detections
|
||||
# Step 3: Second association, with low score detection boxes association the untrack to the low score detections
|
||||
detections_second = self.init_track(dets_second, scores_second, cls_second, img)
|
||||
r_tracked_stracks = [strack_pool[i] for i in u_track if strack_pool[i].state == TrackState.Tracked]
|
||||
|
||||
# TODO
|
||||
dists = matching.iou_distance(r_tracked_stracks, detections_second)
|
||||
dists = dists_update(dists, r_tracked_stracks, detections_second)
|
||||
|
||||
matches, u_track, u_detection_second = matching.linear_assignment(dists, thresh=0.5)
|
||||
for itracked, idet in matches:
|
||||
track = r_tracked_stracks[itracked]
|
||||
@ -279,13 +327,9 @@ class BYTETracker:
|
||||
if track.state != TrackState.Lost:
|
||||
track.mark_lost()
|
||||
lost_stracks.append(track)
|
||||
|
||||
# Deal with unconfirmed tracks, usually tracks with only one beginning frame
|
||||
detections = [detections[i] for i in u_detection]
|
||||
dists = self.get_dists(unconfirmed, detections)
|
||||
|
||||
dists = dists_update(dists, unconfirmed, detections)
|
||||
|
||||
matches, u_unconfirmed, u_detection = matching.linear_assignment(dists, thresh=0.7)
|
||||
for itracked, idet in matches:
|
||||
unconfirmed[itracked].update(detections[idet], self.frame_id)
|
||||
@ -317,9 +361,8 @@ class BYTETracker:
|
||||
self.removed_stracks.extend(removed_stracks)
|
||||
if len(self.removed_stracks) > 1000:
|
||||
self.removed_stracks = self.removed_stracks[-999:] # clip remove stracks to 1000 maximum
|
||||
return np.asarray(
|
||||
[x.tlbr.tolist() + [x.track_id, x.score, x.cls, x.idx] for x in self.tracked_stracks if x.is_activated],
|
||||
dtype=np.float32)
|
||||
|
||||
return np.asarray([x.result for x in self.tracked_stracks if x.is_activated], dtype=np.float32)
|
||||
|
||||
def get_kalmanfilter(self):
|
||||
"""Returns a Kalman filter object for tracking bounding boxes."""
|
||||
@ -330,7 +373,7 @@ class BYTETracker:
|
||||
return [STrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)] if len(dets) else [] # detections
|
||||
|
||||
def get_dists(self, tracks, detections):
|
||||
"""Calculates the distance between tracks and detections using IOU and fuses scores."""
|
||||
"""Calculates the distance between tracks and detections using IoU and fuses scores."""
|
||||
dists = matching.iou_distance(tracks, detections)
|
||||
# TODO: mot20
|
||||
# if not self.args.mot20:
|
||||
@ -341,10 +384,20 @@ class BYTETracker:
|
||||
"""Returns the predicted tracks using the YOLOv8 network."""
|
||||
STrack.multi_predict(tracks)
|
||||
|
||||
def reset_id(self):
|
||||
@staticmethod
|
||||
def reset_id():
|
||||
"""Resets the ID counter of STrack."""
|
||||
STrack.reset_id()
|
||||
|
||||
def reset(self):
|
||||
"""Reset tracker."""
|
||||
self.tracked_stracks = [] # type: list[STrack]
|
||||
self.lost_stracks = [] # type: list[STrack]
|
||||
self.removed_stracks = [] # type: list[STrack]
|
||||
self.frame_id = 0
|
||||
self.kalman_filter = self.get_kalmanfilter()
|
||||
self.reset_id()
|
||||
|
||||
@staticmethod
|
||||
def joint_stracks(tlista, tlistb):
|
||||
"""Combine two lists of stracks into a single one."""
|
||||
@ -375,7 +428,7 @@ class BYTETracker:
|
||||
|
||||
@staticmethod
|
||||
def remove_duplicate_stracks(stracksa, stracksb):
|
||||
"""Remove duplicate stracks with non-maximum IOU distance."""
|
||||
"""Remove duplicate stracks with non-maximum IoU distance."""
|
||||
pdist = matching.iou_distance(stracksa, stracksb)
|
||||
pairs = np.where(pdist < 0.15)
|
||||
dupa, dupb = [], []
|
||||
|
@ -1,19 +1,20 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
from ultralytics.utils import IterableSimpleNamespace, yaml_load
|
||||
from ultralytics.utils.checks import check_yaml
|
||||
|
||||
from .bot_sort import BOTSORT
|
||||
from .byte_tracker import BYTETracker
|
||||
|
||||
TRACKER_MAP = {'bytetrack': BYTETracker, 'botsort': BOTSORT}
|
||||
# A mapping of tracker types to corresponding tracker classes
|
||||
TRACKER_MAP = {"bytetrack": BYTETracker, "botsort": BOTSORT}
|
||||
|
||||
|
||||
def on_predict_start(predictor, persist=False):
|
||||
def on_predict_start(predictor: object, persist: bool = False) -> None:
|
||||
"""
|
||||
Initialize trackers for object tracking during prediction.
|
||||
|
||||
@ -24,43 +25,65 @@ def on_predict_start(predictor, persist=False):
|
||||
Raises:
|
||||
AssertionError: If the tracker_type is not 'bytetrack' or 'botsort'.
|
||||
"""
|
||||
if hasattr(predictor, 'trackers') and persist:
|
||||
if hasattr(predictor, "trackers") and persist:
|
||||
return
|
||||
|
||||
tracker = check_yaml(predictor.args.tracker)
|
||||
cfg = IterableSimpleNamespace(**yaml_load(tracker))
|
||||
assert cfg.tracker_type in ['bytetrack', 'botsort'], \
|
||||
f"Only support 'bytetrack' and 'botsort' for now, but got '{cfg.tracker_type}'"
|
||||
|
||||
if cfg.tracker_type not in ["bytetrack", "botsort"]:
|
||||
raise AssertionError(f"Only 'bytetrack' and 'botsort' are supported for now, but got '{cfg.tracker_type}'")
|
||||
|
||||
trackers = []
|
||||
for _ in range(predictor.dataset.bs):
|
||||
tracker = TRACKER_MAP[cfg.tracker_type](args=cfg, frame_rate=30)
|
||||
trackers.append(tracker)
|
||||
if predictor.dataset.mode != "stream": # only need one tracker for other modes.
|
||||
break
|
||||
predictor.trackers = trackers
|
||||
predictor.vid_path = [None] * predictor.dataset.bs # for determining when to reset tracker on new video
|
||||
|
||||
|
||||
def on_predict_postprocess_end(predictor):
|
||||
"""Postprocess detected boxes and update with object tracking."""
|
||||
bs = predictor.dataset.bs
|
||||
im0s = predictor.batch[1]
|
||||
for i in range(bs):
|
||||
det = predictor.results[i].boxes.cpu().numpy()
|
||||
def on_predict_postprocess_end(predictor: object, persist: bool = False) -> None:
|
||||
"""
|
||||
Postprocess detected boxes and update with object tracking.
|
||||
|
||||
Args:
|
||||
predictor (object): The predictor object containing the predictions.
|
||||
persist (bool, optional): Whether to persist the trackers if they already exist. Defaults to False.
|
||||
"""
|
||||
path, im0s = predictor.batch[:2]
|
||||
|
||||
is_obb = predictor.args.task == "obb"
|
||||
is_stream = predictor.dataset.mode == "stream"
|
||||
for i in range(len(im0s)):
|
||||
tracker = predictor.trackers[i if is_stream else 0]
|
||||
vid_path = predictor.save_dir / Path(path[i]).name
|
||||
if not persist and predictor.vid_path[i if is_stream else 0] != vid_path:
|
||||
tracker.reset()
|
||||
predictor.vid_path[i if is_stream else 0] = vid_path
|
||||
|
||||
det = (predictor.results[i].obb if is_obb else predictor.results[i].boxes).cpu().numpy()
|
||||
if len(det) == 0:
|
||||
continue
|
||||
tracks = predictor.trackers[i].update(det, im0s[i])
|
||||
tracks = tracker.update(det, im0s[i])
|
||||
if len(tracks) == 0:
|
||||
continue
|
||||
idx = tracks[:, -1].astype(int)
|
||||
predictor.results[i] = predictor.results[i][idx]
|
||||
predictor.results[i].update(boxes=torch.as_tensor(tracks[:, :-1]))
|
||||
|
||||
update_args = dict()
|
||||
update_args["obb" if is_obb else "boxes"] = torch.as_tensor(tracks[:, :-1])
|
||||
predictor.results[i].update(**update_args)
|
||||
|
||||
|
||||
def register_tracker(model, persist):
|
||||
def register_tracker(model: object, persist: bool) -> None:
|
||||
"""
|
||||
Register tracking callbacks to the model for object tracking during prediction.
|
||||
|
||||
Args:
|
||||
model (object): The model object to register tracking callbacks for.
|
||||
persist (bool): Whether to persist the trackers if they already exist.
|
||||
|
||||
"""
|
||||
model.add_callback('on_predict_start', partial(on_predict_start, persist=persist))
|
||||
model.add_callback('on_predict_postprocess_end', on_predict_postprocess_end)
|
||||
model.add_callback("on_predict_start", partial(on_predict_start, persist=persist))
|
||||
model.add_callback("on_predict_postprocess_end", partial(on_predict_postprocess_end, persist=persist))
|
||||
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -9,67 +9,121 @@ from ultralytics.utils import LOGGER
|
||||
|
||||
|
||||
class GMC:
|
||||
"""
|
||||
Generalized Motion Compensation (GMC) class for tracking and object detection in video frames.
|
||||
|
||||
def __init__(self, method='sparseOptFlow', downscale=2):
|
||||
"""Initialize a video tracker with specified parameters."""
|
||||
This class provides methods for tracking and detecting objects based on several tracking algorithms including ORB,
|
||||
SIFT, ECC, and Sparse Optical Flow. It also supports downscaling of frames for computational efficiency.
|
||||
|
||||
Attributes:
|
||||
method (str): The method used for tracking. Options include 'orb', 'sift', 'ecc', 'sparseOptFlow', 'none'.
|
||||
downscale (int): Factor by which to downscale the frames for processing.
|
||||
prevFrame (np.ndarray): Stores the previous frame for tracking.
|
||||
prevKeyPoints (list): Stores the keypoints from the previous frame.
|
||||
prevDescriptors (np.ndarray): Stores the descriptors from the previous frame.
|
||||
initializedFirstFrame (bool): Flag to indicate if the first frame has been processed.
|
||||
|
||||
Methods:
|
||||
__init__(self, method='sparseOptFlow', downscale=2): Initializes a GMC object with the specified method
|
||||
and downscale factor.
|
||||
apply(self, raw_frame, detections=None): Applies the chosen method to a raw frame and optionally uses
|
||||
provided detections.
|
||||
applyEcc(self, raw_frame, detections=None): Applies the ECC algorithm to a raw frame.
|
||||
applyFeatures(self, raw_frame, detections=None): Applies feature-based methods like ORB or SIFT to a raw frame.
|
||||
applySparseOptFlow(self, raw_frame, detections=None): Applies the Sparse Optical Flow method to a raw frame.
|
||||
"""
|
||||
|
||||
def __init__(self, method: str = "sparseOptFlow", downscale: int = 2) -> None:
|
||||
"""
|
||||
Initialize a video tracker with specified parameters.
|
||||
|
||||
Args:
|
||||
method (str): The method used for tracking. Options include 'orb', 'sift', 'ecc', 'sparseOptFlow', 'none'.
|
||||
downscale (int): Downscale factor for processing frames.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.method = method
|
||||
self.downscale = max(1, int(downscale))
|
||||
|
||||
if self.method == 'orb':
|
||||
if self.method == "orb":
|
||||
self.detector = cv2.FastFeatureDetector_create(20)
|
||||
self.extractor = cv2.ORB_create()
|
||||
self.matcher = cv2.BFMatcher(cv2.NORM_HAMMING)
|
||||
|
||||
elif self.method == 'sift':
|
||||
elif self.method == "sift":
|
||||
self.detector = cv2.SIFT_create(nOctaveLayers=3, contrastThreshold=0.02, edgeThreshold=20)
|
||||
self.extractor = cv2.SIFT_create(nOctaveLayers=3, contrastThreshold=0.02, edgeThreshold=20)
|
||||
self.matcher = cv2.BFMatcher(cv2.NORM_L2)
|
||||
|
||||
elif self.method == 'ecc':
|
||||
elif self.method == "ecc":
|
||||
number_of_iterations = 5000
|
||||
termination_eps = 1e-6
|
||||
self.warp_mode = cv2.MOTION_EUCLIDEAN
|
||||
self.criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, number_of_iterations, termination_eps)
|
||||
|
||||
elif self.method == 'sparseOptFlow':
|
||||
self.feature_params = dict(maxCorners=1000,
|
||||
qualityLevel=0.01,
|
||||
minDistance=1,
|
||||
blockSize=3,
|
||||
useHarrisDetector=False,
|
||||
k=0.04)
|
||||
elif self.method == "sparseOptFlow":
|
||||
self.feature_params = dict(
|
||||
maxCorners=1000, qualityLevel=0.01, minDistance=1, blockSize=3, useHarrisDetector=False, k=0.04
|
||||
)
|
||||
|
||||
elif self.method in ['none', 'None', None]:
|
||||
elif self.method in {"none", "None", None}:
|
||||
self.method = None
|
||||
else:
|
||||
raise ValueError(f'Error: Unknown GMC method:{method}')
|
||||
raise ValueError(f"Error: Unknown GMC method:{method}")
|
||||
|
||||
self.prevFrame = None
|
||||
self.prevKeyPoints = None
|
||||
self.prevDescriptors = None
|
||||
|
||||
self.initializedFirstFrame = False
|
||||
|
||||
def apply(self, raw_frame, detections=None):
|
||||
"""Apply object detection on a raw frame using specified method."""
|
||||
if self.method in ['orb', 'sift']:
|
||||
def apply(self, raw_frame: np.array, detections: list = None) -> np.array:
|
||||
"""
|
||||
Apply object detection on a raw frame using specified method.
|
||||
|
||||
Args:
|
||||
raw_frame (np.ndarray): The raw frame to be processed.
|
||||
detections (list): List of detections to be used in the processing.
|
||||
|
||||
Returns:
|
||||
(np.ndarray): Processed frame.
|
||||
|
||||
Examples:
|
||||
>>> gmc = GMC()
|
||||
>>> gmc.apply(np.array([[1, 2, 3], [4, 5, 6]]))
|
||||
array([[1, 2, 3],
|
||||
[4, 5, 6]])
|
||||
"""
|
||||
if self.method in ["orb", "sift"]:
|
||||
return self.applyFeatures(raw_frame, detections)
|
||||
elif self.method == 'ecc':
|
||||
return self.applyEcc(raw_frame, detections)
|
||||
elif self.method == 'sparseOptFlow':
|
||||
return self.applySparseOptFlow(raw_frame, detections)
|
||||
elif self.method == "ecc":
|
||||
return self.applyEcc(raw_frame)
|
||||
elif self.method == "sparseOptFlow":
|
||||
return self.applySparseOptFlow(raw_frame)
|
||||
else:
|
||||
return np.eye(2, 3)
|
||||
|
||||
def applyEcc(self, raw_frame, detections=None):
|
||||
"""Initialize."""
|
||||
def applyEcc(self, raw_frame: np.array) -> np.array:
|
||||
"""
|
||||
Apply ECC algorithm to a raw frame.
|
||||
|
||||
Args:
|
||||
raw_frame (np.ndarray): The raw frame to be processed.
|
||||
|
||||
Returns:
|
||||
(np.ndarray): Processed frame.
|
||||
|
||||
Examples:
|
||||
>>> gmc = GMC()
|
||||
>>> gmc.applyEcc(np.array([[1, 2, 3], [4, 5, 6]]))
|
||||
array([[1, 2, 3],
|
||||
[4, 5, 6]])
|
||||
"""
|
||||
height, width, _ = raw_frame.shape
|
||||
frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY)
|
||||
H = np.eye(2, 3, dtype=np.float32)
|
||||
|
||||
# Downscale image (TODO: consider using pyramids)
|
||||
# Downscale image
|
||||
if self.downscale > 1.0:
|
||||
frame = cv2.GaussianBlur(frame, (3, 3), 1.5)
|
||||
frame = cv2.resize(frame, (width // self.downscale, height // self.downscale))
|
||||
@ -89,33 +143,46 @@ class GMC:
|
||||
# Run the ECC algorithm. The results are stored in warp_matrix.
|
||||
# (cc, H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode, self.criteria)
|
||||
try:
|
||||
(cc, H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode, self.criteria, None, 1)
|
||||
(_, H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode, self.criteria, None, 1)
|
||||
except Exception as e:
|
||||
LOGGER.warning(f'WARNING: find transform failed. Set warp as identity {e}')
|
||||
LOGGER.warning(f"WARNING: find transform failed. Set warp as identity {e}")
|
||||
|
||||
return H
|
||||
|
||||
def applyFeatures(self, raw_frame, detections=None):
|
||||
"""Initialize."""
|
||||
def applyFeatures(self, raw_frame: np.array, detections: list = None) -> np.array:
|
||||
"""
|
||||
Apply feature-based methods like ORB or SIFT to a raw frame.
|
||||
|
||||
Args:
|
||||
raw_frame (np.ndarray): The raw frame to be processed.
|
||||
detections (list): List of detections to be used in the processing.
|
||||
|
||||
Returns:
|
||||
(np.ndarray): Processed frame.
|
||||
|
||||
Examples:
|
||||
>>> gmc = GMC()
|
||||
>>> gmc.applyFeatures(np.array([[1, 2, 3], [4, 5, 6]]))
|
||||
array([[1, 2, 3],
|
||||
[4, 5, 6]])
|
||||
"""
|
||||
height, width, _ = raw_frame.shape
|
||||
frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY)
|
||||
H = np.eye(2, 3)
|
||||
|
||||
# Downscale image (TODO: consider using pyramids)
|
||||
# Downscale image
|
||||
if self.downscale > 1.0:
|
||||
# frame = cv2.GaussianBlur(frame, (3, 3), 1.5)
|
||||
frame = cv2.resize(frame, (width // self.downscale, height // self.downscale))
|
||||
width = width // self.downscale
|
||||
height = height // self.downscale
|
||||
|
||||
# Find the keypoints
|
||||
mask = np.zeros_like(frame)
|
||||
# mask[int(0.05 * height): int(0.95 * height), int(0.05 * width): int(0.95 * width)] = 255
|
||||
mask[int(0.02 * height):int(0.98 * height), int(0.02 * width):int(0.98 * width)] = 255
|
||||
mask[int(0.02 * height) : int(0.98 * height), int(0.02 * width) : int(0.98 * width)] = 255
|
||||
if detections is not None:
|
||||
for det in detections:
|
||||
tlbr = (det[:4] / self.downscale).astype(np.int_)
|
||||
mask[tlbr[1]:tlbr[3], tlbr[0]:tlbr[2]] = 0
|
||||
mask[tlbr[1] : tlbr[3], tlbr[0] : tlbr[2]] = 0
|
||||
|
||||
keypoints = self.detector.detect(frame, mask)
|
||||
|
||||
@ -134,10 +201,10 @@ class GMC:
|
||||
|
||||
return H
|
||||
|
||||
# Match descriptors.
|
||||
# Match descriptors
|
||||
knnMatches = self.matcher.knnMatch(self.prevDescriptors, descriptors, 2)
|
||||
|
||||
# Filtered matches based on smallest spatial distance
|
||||
# Filter matches based on smallest spatial distance
|
||||
matches = []
|
||||
spatialDistances = []
|
||||
|
||||
@ -157,11 +224,14 @@ class GMC:
|
||||
prevKeyPointLocation = self.prevKeyPoints[m.queryIdx].pt
|
||||
currKeyPointLocation = keypoints[m.trainIdx].pt
|
||||
|
||||
spatialDistance = (prevKeyPointLocation[0] - currKeyPointLocation[0],
|
||||
prevKeyPointLocation[1] - currKeyPointLocation[1])
|
||||
spatialDistance = (
|
||||
prevKeyPointLocation[0] - currKeyPointLocation[0],
|
||||
prevKeyPointLocation[1] - currKeyPointLocation[1],
|
||||
)
|
||||
|
||||
if (np.abs(spatialDistance[0]) < maxSpatialDistance[0]) and \
|
||||
(np.abs(spatialDistance[1]) < maxSpatialDistance[1]):
|
||||
if (np.abs(spatialDistance[0]) < maxSpatialDistance[0]) and (
|
||||
np.abs(spatialDistance[1]) < maxSpatialDistance[1]
|
||||
):
|
||||
spatialDistances.append(spatialDistance)
|
||||
matches.append(m)
|
||||
|
||||
@ -187,7 +257,7 @@ class GMC:
|
||||
# import matplotlib.pyplot as plt
|
||||
# matches_img = np.hstack((self.prevFrame, frame))
|
||||
# matches_img = cv2.cvtColor(matches_img, cv2.COLOR_GRAY2BGR)
|
||||
# W = np.size(self.prevFrame, 1)
|
||||
# W = self.prevFrame.shape[1]
|
||||
# for m in goodMatches:
|
||||
# prev_pt = np.array(self.prevKeyPoints[m.queryIdx].pt, dtype=np.int_)
|
||||
# curr_pt = np.array(keypoints[m.trainIdx].pt, dtype=np.int_)
|
||||
@ -204,7 +274,7 @@ class GMC:
|
||||
# plt.show()
|
||||
|
||||
# Find rigid matrix
|
||||
if (np.size(prevPoints, 0) > 4) and (np.size(prevPoints, 0) == np.size(prevPoints, 0)):
|
||||
if prevPoints.shape[0] > 4:
|
||||
H, inliers = cv2.estimateAffinePartial2D(prevPoints, currPoints, cv2.RANSAC)
|
||||
|
||||
# Handle downscale
|
||||
@ -212,7 +282,7 @@ class GMC:
|
||||
H[0, 2] *= self.downscale
|
||||
H[1, 2] *= self.downscale
|
||||
else:
|
||||
LOGGER.warning('WARNING: not enough matching points')
|
||||
LOGGER.warning("WARNING: not enough matching points")
|
||||
|
||||
# Store to next iteration
|
||||
self.prevFrame = frame.copy()
|
||||
@ -221,15 +291,28 @@ class GMC:
|
||||
|
||||
return H
|
||||
|
||||
def applySparseOptFlow(self, raw_frame, detections=None):
|
||||
"""Initialize."""
|
||||
def applySparseOptFlow(self, raw_frame: np.array) -> np.array:
|
||||
"""
|
||||
Apply Sparse Optical Flow method to a raw frame.
|
||||
|
||||
Args:
|
||||
raw_frame (np.ndarray): The raw frame to be processed.
|
||||
|
||||
Returns:
|
||||
(np.ndarray): Processed frame.
|
||||
|
||||
Examples:
|
||||
>>> gmc = GMC()
|
||||
>>> gmc.applySparseOptFlow(np.array([[1, 2, 3], [4, 5, 6]]))
|
||||
array([[1, 2, 3],
|
||||
[4, 5, 6]])
|
||||
"""
|
||||
height, width, _ = raw_frame.shape
|
||||
frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY)
|
||||
H = np.eye(2, 3)
|
||||
|
||||
# Downscale image
|
||||
if self.downscale > 1.0:
|
||||
# frame = cv2.GaussianBlur(frame, (3, 3), 1.5)
|
||||
frame = cv2.resize(frame, (width // self.downscale, height // self.downscale))
|
||||
|
||||
# Find the keypoints
|
||||
@ -237,17 +320,13 @@ class GMC:
|
||||
|
||||
# Handle first frame
|
||||
if not self.initializedFirstFrame:
|
||||
# Initialize data
|
||||
self.prevFrame = frame.copy()
|
||||
self.prevKeyPoints = copy.copy(keypoints)
|
||||
|
||||
# Initialization done
|
||||
self.initializedFirstFrame = True
|
||||
|
||||
return H
|
||||
|
||||
# Find correspondences
|
||||
matchedKeypoints, status, err = cv2.calcOpticalFlowPyrLK(self.prevFrame, frame, self.prevKeyPoints, None)
|
||||
matchedKeypoints, status, _ = cv2.calcOpticalFlowPyrLK(self.prevFrame, frame, self.prevKeyPoints, None)
|
||||
|
||||
# Leave good correspondences only
|
||||
prevPoints = []
|
||||
@ -262,18 +341,23 @@ class GMC:
|
||||
currPoints = np.array(currPoints)
|
||||
|
||||
# Find rigid matrix
|
||||
if (np.size(prevPoints, 0) > 4) and (np.size(prevPoints, 0) == np.size(prevPoints, 0)):
|
||||
H, inliers = cv2.estimateAffinePartial2D(prevPoints, currPoints, cv2.RANSAC)
|
||||
if (prevPoints.shape[0] > 4) and (prevPoints.shape[0] == prevPoints.shape[0]):
|
||||
H, _ = cv2.estimateAffinePartial2D(prevPoints, currPoints, cv2.RANSAC)
|
||||
|
||||
# Handle downscale
|
||||
if self.downscale > 1.0:
|
||||
H[0, 2] *= self.downscale
|
||||
H[1, 2] *= self.downscale
|
||||
else:
|
||||
LOGGER.warning('WARNING: not enough matching points')
|
||||
LOGGER.warning("WARNING: not enough matching points")
|
||||
|
||||
# Store to next iteration
|
||||
self.prevFrame = frame.copy()
|
||||
self.prevKeyPoints = copy.copy(keypoints)
|
||||
|
||||
return H
|
||||
|
||||
def reset_params(self) -> None:
|
||||
"""Reset parameters."""
|
||||
self.prevFrame = None
|
||||
self.prevKeyPoints = None
|
||||
self.prevDescriptors = None
|
||||
self.initializedFirstFrame = False
|
||||
|
@ -8,8 +8,8 @@ class KalmanFilterXYAH:
|
||||
"""
|
||||
For bytetrack. A simple Kalman filter for tracking bounding boxes in image space.
|
||||
|
||||
The 8-dimensional state space (x, y, a, h, vx, vy, va, vh) contains the bounding box center position (x, y),
|
||||
aspect ratio a, height h, and their respective velocities.
|
||||
The 8-dimensional state space (x, y, a, h, vx, vy, va, vh) contains the bounding box center position (x, y), aspect
|
||||
ratio a, height h, and their respective velocities.
|
||||
|
||||
Object motion follows a constant velocity model. The bounding box location (x, y, a, h) is taken as direct
|
||||
observation of the state space (linear observation model).
|
||||
@ -17,126 +17,126 @@ class KalmanFilterXYAH:
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize Kalman filter model matrices with motion and observation uncertainty weights."""
|
||||
ndim, dt = 4, 1.
|
||||
ndim, dt = 4, 1.0
|
||||
|
||||
# Create Kalman filter model matrices.
|
||||
# Create Kalman filter model matrices
|
||||
self._motion_mat = np.eye(2 * ndim, 2 * ndim)
|
||||
for i in range(ndim):
|
||||
self._motion_mat[i, ndim + i] = dt
|
||||
self._update_mat = np.eye(ndim, 2 * ndim)
|
||||
|
||||
# Motion and observation uncertainty are chosen relative to the current state estimate. These weights control
|
||||
# the amount of uncertainty in the model. This is a bit hacky.
|
||||
self._std_weight_position = 1. / 20
|
||||
self._std_weight_velocity = 1. / 160
|
||||
# the amount of uncertainty in the model.
|
||||
self._std_weight_position = 1.0 / 20
|
||||
self._std_weight_velocity = 1.0 / 160
|
||||
|
||||
def initiate(self, measurement):
|
||||
def initiate(self, measurement: np.ndarray) -> tuple:
|
||||
"""
|
||||
Create track from unassociated measurement.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
measurement : ndarray
|
||||
Bounding box coordinates (x, y, a, h) with center position (x, y),
|
||||
aspect ratio a, and height h.
|
||||
Args:
|
||||
measurement (ndarray): Bounding box coordinates (x, y, a, h) with center position (x, y), aspect ratio a,
|
||||
and height h.
|
||||
|
||||
Returns
|
||||
-------
|
||||
(ndarray, ndarray)
|
||||
Returns the mean vector (8 dimensional) and covariance matrix (8x8
|
||||
dimensional) of the new track. Unobserved velocities are initialized
|
||||
to 0 mean.
|
||||
Returns:
|
||||
(tuple[ndarray, ndarray]): Returns the mean vector (8 dimensional) and covariance matrix (8x8 dimensional) of
|
||||
the new track. Unobserved velocities are initialized to 0 mean.
|
||||
"""
|
||||
mean_pos = measurement
|
||||
mean_vel = np.zeros_like(mean_pos)
|
||||
mean = np.r_[mean_pos, mean_vel]
|
||||
|
||||
std = [
|
||||
2 * self._std_weight_position * measurement[3], 2 * self._std_weight_position * measurement[3], 1e-2,
|
||||
2 * self._std_weight_position * measurement[3], 10 * self._std_weight_velocity * measurement[3],
|
||||
10 * self._std_weight_velocity * measurement[3], 1e-5, 10 * self._std_weight_velocity * measurement[3]]
|
||||
2 * self._std_weight_position * measurement[3],
|
||||
2 * self._std_weight_position * measurement[3],
|
||||
1e-2,
|
||||
2 * self._std_weight_position * measurement[3],
|
||||
10 * self._std_weight_velocity * measurement[3],
|
||||
10 * self._std_weight_velocity * measurement[3],
|
||||
1e-5,
|
||||
10 * self._std_weight_velocity * measurement[3],
|
||||
]
|
||||
covariance = np.diag(np.square(std))
|
||||
return mean, covariance
|
||||
|
||||
def predict(self, mean, covariance):
|
||||
def predict(self, mean: np.ndarray, covariance: np.ndarray) -> tuple:
|
||||
"""
|
||||
Run Kalman filter prediction step.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mean : ndarray
|
||||
The 8 dimensional mean vector of the object state at the previous time step.
|
||||
covariance : ndarray
|
||||
The 8x8 dimensional covariance matrix of the object state at the previous time step.
|
||||
Args:
|
||||
mean (ndarray): The 8 dimensional mean vector of the object state at the previous time step.
|
||||
covariance (ndarray): The 8x8 dimensional covariance matrix of the object state at the previous time step.
|
||||
|
||||
Returns
|
||||
-------
|
||||
(ndarray, ndarray)
|
||||
Returns the mean vector and covariance matrix of the predicted state. Unobserved velocities are
|
||||
initialized to 0 mean.
|
||||
Returns:
|
||||
(tuple[ndarray, ndarray]): Returns the mean vector and covariance matrix of the predicted state. Unobserved
|
||||
velocities are initialized to 0 mean.
|
||||
"""
|
||||
std_pos = [
|
||||
self._std_weight_position * mean[3], self._std_weight_position * mean[3], 1e-2,
|
||||
self._std_weight_position * mean[3]]
|
||||
self._std_weight_position * mean[3],
|
||||
self._std_weight_position * mean[3],
|
||||
1e-2,
|
||||
self._std_weight_position * mean[3],
|
||||
]
|
||||
std_vel = [
|
||||
self._std_weight_velocity * mean[3], self._std_weight_velocity * mean[3], 1e-5,
|
||||
self._std_weight_velocity * mean[3]]
|
||||
self._std_weight_velocity * mean[3],
|
||||
self._std_weight_velocity * mean[3],
|
||||
1e-5,
|
||||
self._std_weight_velocity * mean[3],
|
||||
]
|
||||
motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))
|
||||
|
||||
# mean = np.dot(self._motion_mat, mean)
|
||||
mean = np.dot(mean, self._motion_mat.T)
|
||||
covariance = np.linalg.multi_dot((self._motion_mat, covariance, self._motion_mat.T)) + motion_cov
|
||||
|
||||
return mean, covariance
|
||||
|
||||
def project(self, mean, covariance):
|
||||
def project(self, mean: np.ndarray, covariance: np.ndarray) -> tuple:
|
||||
"""
|
||||
Project state distribution to measurement space.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mean : ndarray
|
||||
The state's mean vector (8 dimensional array).
|
||||
covariance : ndarray
|
||||
The state's covariance matrix (8x8 dimensional).
|
||||
Args:
|
||||
mean (ndarray): The state's mean vector (8 dimensional array).
|
||||
covariance (ndarray): The state's covariance matrix (8x8 dimensional).
|
||||
|
||||
Returns
|
||||
-------
|
||||
(ndarray, ndarray)
|
||||
Returns the projected mean and covariance matrix of the given state estimate.
|
||||
Returns:
|
||||
(tuple[ndarray, ndarray]): Returns the projected mean and covariance matrix of the given state estimate.
|
||||
"""
|
||||
std = [
|
||||
self._std_weight_position * mean[3], self._std_weight_position * mean[3], 1e-1,
|
||||
self._std_weight_position * mean[3]]
|
||||
self._std_weight_position * mean[3],
|
||||
self._std_weight_position * mean[3],
|
||||
1e-1,
|
||||
self._std_weight_position * mean[3],
|
||||
]
|
||||
innovation_cov = np.diag(np.square(std))
|
||||
|
||||
mean = np.dot(self._update_mat, mean)
|
||||
covariance = np.linalg.multi_dot((self._update_mat, covariance, self._update_mat.T))
|
||||
return mean, covariance + innovation_cov
|
||||
|
||||
def multi_predict(self, mean, covariance):
|
||||
def multi_predict(self, mean: np.ndarray, covariance: np.ndarray) -> tuple:
|
||||
"""
|
||||
Run Kalman filter prediction step (Vectorized version).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mean : ndarray
|
||||
The Nx8 dimensional mean matrix of the object states at the previous time step.
|
||||
covariance : ndarray
|
||||
The Nx8x8 dimensional covariance matrix of the object states at the previous time step.
|
||||
Args:
|
||||
mean (ndarray): The Nx8 dimensional mean matrix of the object states at the previous time step.
|
||||
covariance (ndarray): The Nx8x8 covariance matrix of the object states at the previous time step.
|
||||
|
||||
Returns
|
||||
-------
|
||||
(ndarray, ndarray)
|
||||
Returns the mean vector and covariance matrix of the predicted state. Unobserved velocities are
|
||||
initialized to 0 mean.
|
||||
Returns:
|
||||
(tuple[ndarray, ndarray]): Returns the mean vector and covariance matrix of the predicted state. Unobserved
|
||||
velocities are initialized to 0 mean.
|
||||
"""
|
||||
std_pos = [
|
||||
self._std_weight_position * mean[:, 3], self._std_weight_position * mean[:, 3],
|
||||
1e-2 * np.ones_like(mean[:, 3]), self._std_weight_position * mean[:, 3]]
|
||||
self._std_weight_position * mean[:, 3],
|
||||
self._std_weight_position * mean[:, 3],
|
||||
1e-2 * np.ones_like(mean[:, 3]),
|
||||
self._std_weight_position * mean[:, 3],
|
||||
]
|
||||
std_vel = [
|
||||
self._std_weight_velocity * mean[:, 3], self._std_weight_velocity * mean[:, 3],
|
||||
1e-5 * np.ones_like(mean[:, 3]), self._std_weight_velocity * mean[:, 3]]
|
||||
self._std_weight_velocity * mean[:, 3],
|
||||
self._std_weight_velocity * mean[:, 3],
|
||||
1e-5 * np.ones_like(mean[:, 3]),
|
||||
self._std_weight_velocity * mean[:, 3],
|
||||
]
|
||||
sqr = np.square(np.r_[std_pos, std_vel]).T
|
||||
|
||||
motion_cov = [np.diag(sqr[i]) for i in range(len(mean))]
|
||||
@ -148,60 +148,57 @@ class KalmanFilterXYAH:
|
||||
|
||||
return mean, covariance
|
||||
|
||||
def update(self, mean, covariance, measurement):
|
||||
def update(self, mean: np.ndarray, covariance: np.ndarray, measurement: np.ndarray) -> tuple:
|
||||
"""
|
||||
Run Kalman filter correction step.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mean : ndarray
|
||||
The predicted state's mean vector (8 dimensional).
|
||||
covariance : ndarray
|
||||
The state's covariance matrix (8x8 dimensional).
|
||||
measurement : ndarray
|
||||
The 4 dimensional measurement vector (x, y, a, h), where (x, y) is the center position, a the aspect
|
||||
ratio, and h the height of the bounding box.
|
||||
Args:
|
||||
mean (ndarray): The predicted state's mean vector (8 dimensional).
|
||||
covariance (ndarray): The state's covariance matrix (8x8 dimensional).
|
||||
measurement (ndarray): The 4 dimensional measurement vector (x, y, a, h), where (x, y) is the center
|
||||
position, a the aspect ratio, and h the height of the bounding box.
|
||||
|
||||
Returns
|
||||
-------
|
||||
(ndarray, ndarray)
|
||||
Returns the measurement-corrected state distribution.
|
||||
Returns:
|
||||
(tuple[ndarray, ndarray]): Returns the measurement-corrected state distribution.
|
||||
"""
|
||||
projected_mean, projected_cov = self.project(mean, covariance)
|
||||
|
||||
chol_factor, lower = scipy.linalg.cho_factor(projected_cov, lower=True, check_finite=False)
|
||||
kalman_gain = scipy.linalg.cho_solve((chol_factor, lower),
|
||||
np.dot(covariance, self._update_mat.T).T,
|
||||
check_finite=False).T
|
||||
kalman_gain = scipy.linalg.cho_solve(
|
||||
(chol_factor, lower), np.dot(covariance, self._update_mat.T).T, check_finite=False
|
||||
).T
|
||||
innovation = measurement - projected_mean
|
||||
|
||||
new_mean = mean + np.dot(innovation, kalman_gain.T)
|
||||
new_covariance = covariance - np.linalg.multi_dot((kalman_gain, projected_cov, kalman_gain.T))
|
||||
return new_mean, new_covariance
|
||||
|
||||
def gating_distance(self, mean, covariance, measurements, only_position=False, metric='maha'):
|
||||
def gating_distance(
|
||||
self,
|
||||
mean: np.ndarray,
|
||||
covariance: np.ndarray,
|
||||
measurements: np.ndarray,
|
||||
only_position: bool = False,
|
||||
metric: str = "maha",
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Compute gating distance between state distribution and measurements. A suitable distance threshold can be
|
||||
obtained from `chi2inv95`. If `only_position` is False, the chi-square distribution has 4 degrees of
|
||||
freedom, otherwise 2.
|
||||
obtained from `chi2inv95`. If `only_position` is False, the chi-square distribution has 4 degrees of freedom,
|
||||
otherwise 2.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mean : ndarray
|
||||
Mean vector over the state distribution (8 dimensional).
|
||||
covariance : ndarray
|
||||
Covariance of the state distribution (8x8 dimensional).
|
||||
measurements : ndarray
|
||||
An Nx4 dimensional matrix of N measurements, each in format (x, y, a, h) where (x, y) is the bounding box
|
||||
center position, a the aspect ratio, and h the height.
|
||||
only_position : Optional[bool]
|
||||
If True, distance computation is done with respect to the bounding box center position only.
|
||||
Args:
|
||||
mean (ndarray): Mean vector over the state distribution (8 dimensional).
|
||||
covariance (ndarray): Covariance of the state distribution (8x8 dimensional).
|
||||
measurements (ndarray): An Nx4 matrix of N measurements, each in format (x, y, a, h) where (x, y)
|
||||
is the bounding box center position, a the aspect ratio, and h the height.
|
||||
only_position (bool, optional): If True, distance computation is done with respect to the bounding box
|
||||
center position only. Defaults to False.
|
||||
metric (str, optional): The metric to use for calculating the distance. Options are 'gaussian' for the
|
||||
squared Euclidean distance and 'maha' for the squared Mahalanobis distance. Defaults to 'maha'.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ndarray
|
||||
Returns an array of length N, where the i-th element contains the squared Mahalanobis distance between
|
||||
(mean, covariance) and `measurements[i]`.
|
||||
Returns:
|
||||
(np.ndarray): Returns an array of length N, where the i-th element contains the squared distance between
|
||||
(mean, covariance) and `measurements[i]`.
|
||||
"""
|
||||
mean, covariance = self.project(mean, covariance)
|
||||
if only_position:
|
||||
@ -209,77 +206,79 @@ class KalmanFilterXYAH:
|
||||
measurements = measurements[:, :2]
|
||||
|
||||
d = measurements - mean
|
||||
if metric == 'gaussian':
|
||||
if metric == "gaussian":
|
||||
return np.sum(d * d, axis=1)
|
||||
elif metric == 'maha':
|
||||
elif metric == "maha":
|
||||
cholesky_factor = np.linalg.cholesky(covariance)
|
||||
z = scipy.linalg.solve_triangular(cholesky_factor, d.T, lower=True, check_finite=False, overwrite_b=True)
|
||||
return np.sum(z * z, axis=0) # square maha
|
||||
else:
|
||||
raise ValueError('invalid distance metric')
|
||||
raise ValueError("Invalid distance metric")
|
||||
|
||||
|
||||
class KalmanFilterXYWH(KalmanFilterXYAH):
|
||||
"""
|
||||
For BoT-SORT. A simple Kalman filter for tracking bounding boxes in image space.
|
||||
|
||||
The 8-dimensional state space (x, y, w, h, vx, vy, vw, vh) contains the bounding box center position (x, y),
|
||||
width w, height h, and their respective velocities.
|
||||
The 8-dimensional state space (x, y, w, h, vx, vy, vw, vh) contains the bounding box center position (x, y), width
|
||||
w, height h, and their respective velocities.
|
||||
|
||||
Object motion follows a constant velocity model. The bounding box location (x, y, w, h) is taken as direct
|
||||
observation of the state space (linear observation model).
|
||||
"""
|
||||
|
||||
def initiate(self, measurement):
|
||||
def initiate(self, measurement: np.ndarray) -> tuple:
|
||||
"""
|
||||
Create track from unassociated measurement.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
measurement : ndarray
|
||||
Bounding box coordinates (x, y, w, h) with center position (x, y), width w, and height h.
|
||||
Args:
|
||||
measurement (ndarray): Bounding box coordinates (x, y, w, h) with center position (x, y), width, and height.
|
||||
|
||||
Returns
|
||||
-------
|
||||
(ndarray, ndarray)
|
||||
Returns the mean vector (8 dimensional) and covariance matrix (8x8 dimensional) of the new track.
|
||||
Unobserved velocities are initialized to 0 mean.
|
||||
Returns:
|
||||
(tuple[ndarray, ndarray]): Returns the mean vector (8 dimensional) and covariance matrix (8x8 dimensional) of
|
||||
the new track. Unobserved velocities are initialized to 0 mean.
|
||||
"""
|
||||
mean_pos = measurement
|
||||
mean_vel = np.zeros_like(mean_pos)
|
||||
mean = np.r_[mean_pos, mean_vel]
|
||||
|
||||
std = [
|
||||
2 * self._std_weight_position * measurement[2], 2 * self._std_weight_position * measurement[3],
|
||||
2 * self._std_weight_position * measurement[2], 2 * self._std_weight_position * measurement[3],
|
||||
10 * self._std_weight_velocity * measurement[2], 10 * self._std_weight_velocity * measurement[3],
|
||||
10 * self._std_weight_velocity * measurement[2], 10 * self._std_weight_velocity * measurement[3]]
|
||||
2 * self._std_weight_position * measurement[2],
|
||||
2 * self._std_weight_position * measurement[3],
|
||||
2 * self._std_weight_position * measurement[2],
|
||||
2 * self._std_weight_position * measurement[3],
|
||||
10 * self._std_weight_velocity * measurement[2],
|
||||
10 * self._std_weight_velocity * measurement[3],
|
||||
10 * self._std_weight_velocity * measurement[2],
|
||||
10 * self._std_weight_velocity * measurement[3],
|
||||
]
|
||||
covariance = np.diag(np.square(std))
|
||||
return mean, covariance
|
||||
|
||||
def predict(self, mean, covariance):
|
||||
def predict(self, mean, covariance) -> tuple:
|
||||
"""
|
||||
Run Kalman filter prediction step.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mean : ndarray
|
||||
The 8 dimensional mean vector of the object state at the previous time step.
|
||||
covariance : ndarray
|
||||
The 8x8 dimensional covariance matrix of the object state at the previous time step.
|
||||
Args:
|
||||
mean (ndarray): The 8 dimensional mean vector of the object state at the previous time step.
|
||||
covariance (ndarray): The 8x8 dimensional covariance matrix of the object state at the previous time step.
|
||||
|
||||
Returns
|
||||
-------
|
||||
(ndarray, ndarray)
|
||||
Returns the mean vector and covariance matrix of the predicted state. Unobserved velocities are
|
||||
initialized to 0 mean.
|
||||
Returns:
|
||||
(tuple[ndarray, ndarray]): Returns the mean vector and covariance matrix of the predicted state. Unobserved
|
||||
velocities are initialized to 0 mean.
|
||||
"""
|
||||
std_pos = [
|
||||
self._std_weight_position * mean[2], self._std_weight_position * mean[3],
|
||||
self._std_weight_position * mean[2], self._std_weight_position * mean[3]]
|
||||
self._std_weight_position * mean[2],
|
||||
self._std_weight_position * mean[3],
|
||||
self._std_weight_position * mean[2],
|
||||
self._std_weight_position * mean[3],
|
||||
]
|
||||
std_vel = [
|
||||
self._std_weight_velocity * mean[2], self._std_weight_velocity * mean[3],
|
||||
self._std_weight_velocity * mean[2], self._std_weight_velocity * mean[3]]
|
||||
self._std_weight_velocity * mean[2],
|
||||
self._std_weight_velocity * mean[3],
|
||||
self._std_weight_velocity * mean[2],
|
||||
self._std_weight_velocity * mean[3],
|
||||
]
|
||||
motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))
|
||||
|
||||
mean = np.dot(mean, self._motion_mat.T)
|
||||
@ -287,54 +286,53 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
|
||||
|
||||
return mean, covariance
|
||||
|
||||
def project(self, mean, covariance):
|
||||
def project(self, mean, covariance) -> tuple:
|
||||
"""
|
||||
Project state distribution to measurement space.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mean : ndarray
|
||||
The state's mean vector (8 dimensional array).
|
||||
covariance : ndarray
|
||||
The state's covariance matrix (8x8 dimensional).
|
||||
Args:
|
||||
mean (ndarray): The state's mean vector (8 dimensional array).
|
||||
covariance (ndarray): The state's covariance matrix (8x8 dimensional).
|
||||
|
||||
Returns
|
||||
-------
|
||||
(ndarray, ndarray)
|
||||
Returns the projected mean and covariance matrix of the given state estimate.
|
||||
Returns:
|
||||
(tuple[ndarray, ndarray]): Returns the projected mean and covariance matrix of the given state estimate.
|
||||
"""
|
||||
std = [
|
||||
self._std_weight_position * mean[2], self._std_weight_position * mean[3],
|
||||
self._std_weight_position * mean[2], self._std_weight_position * mean[3]]
|
||||
self._std_weight_position * mean[2],
|
||||
self._std_weight_position * mean[3],
|
||||
self._std_weight_position * mean[2],
|
||||
self._std_weight_position * mean[3],
|
||||
]
|
||||
innovation_cov = np.diag(np.square(std))
|
||||
|
||||
mean = np.dot(self._update_mat, mean)
|
||||
covariance = np.linalg.multi_dot((self._update_mat, covariance, self._update_mat.T))
|
||||
return mean, covariance + innovation_cov
|
||||
|
||||
def multi_predict(self, mean, covariance):
|
||||
def multi_predict(self, mean, covariance) -> tuple:
|
||||
"""
|
||||
Run Kalman filter prediction step (Vectorized version).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mean : ndarray
|
||||
The Nx8 dimensional mean matrix of the object states at the previous time step.
|
||||
covariance : ndarray
|
||||
The Nx8x8 dimensional covariance matrix of the object states at the previous time step.
|
||||
Args:
|
||||
mean (ndarray): The Nx8 dimensional mean matrix of the object states at the previous time step.
|
||||
covariance (ndarray): The Nx8x8 covariance matrix of the object states at the previous time step.
|
||||
|
||||
Returns
|
||||
-------
|
||||
(ndarray, ndarray)
|
||||
Returns the mean vector and covariance matrix of the predicted state. Unobserved velocities are
|
||||
initialized to 0 mean.
|
||||
Returns:
|
||||
(tuple[ndarray, ndarray]): Returns the mean vector and covariance matrix of the predicted state. Unobserved
|
||||
velocities are initialized to 0 mean.
|
||||
"""
|
||||
std_pos = [
|
||||
self._std_weight_position * mean[:, 2], self._std_weight_position * mean[:, 3],
|
||||
self._std_weight_position * mean[:, 2], self._std_weight_position * mean[:, 3]]
|
||||
self._std_weight_position * mean[:, 2],
|
||||
self._std_weight_position * mean[:, 3],
|
||||
self._std_weight_position * mean[:, 2],
|
||||
self._std_weight_position * mean[:, 3],
|
||||
]
|
||||
std_vel = [
|
||||
self._std_weight_velocity * mean[:, 2], self._std_weight_velocity * mean[:, 3],
|
||||
self._std_weight_velocity * mean[:, 2], self._std_weight_velocity * mean[:, 3]]
|
||||
self._std_weight_velocity * mean[:, 2],
|
||||
self._std_weight_velocity * mean[:, 3],
|
||||
self._std_weight_velocity * mean[:, 2],
|
||||
self._std_weight_velocity * mean[:, 3],
|
||||
]
|
||||
sqr = np.square(np.r_[std_pos, std_vel]).T
|
||||
|
||||
motion_cov = [np.diag(sqr[i]) for i in range(len(mean))]
|
||||
@ -346,23 +344,17 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
|
||||
|
||||
return mean, covariance
|
||||
|
||||
def update(self, mean, covariance, measurement):
|
||||
def update(self, mean, covariance, measurement) -> tuple:
|
||||
"""
|
||||
Run Kalman filter correction step.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mean : ndarray
|
||||
The predicted state's mean vector (8 dimensional).
|
||||
covariance : ndarray
|
||||
The state's covariance matrix (8x8 dimensional).
|
||||
measurement : ndarray
|
||||
The 4 dimensional measurement vector (x, y, w, h), where (x, y) is the center position, w the width,
|
||||
and h the height of the bounding box.
|
||||
Args:
|
||||
mean (ndarray): The predicted state's mean vector (8 dimensional).
|
||||
covariance (ndarray): The state's covariance matrix (8x8 dimensional).
|
||||
measurement (ndarray): The 4 dimensional measurement vector (x, y, w, h), where (x, y) is the center
|
||||
position, w the width, and h the height of the bounding box.
|
||||
|
||||
Returns
|
||||
-------
|
||||
(ndarray, ndarray)
|
||||
Returns the measurement-corrected state distribution.
|
||||
Returns:
|
||||
(tuple[ndarray, ndarray]): Returns the measurement-corrected state distribution.
|
||||
"""
|
||||
return super().update(mean, covariance, measurement)
|
||||
|
@ -4,7 +4,7 @@ import numpy as np
|
||||
import scipy
|
||||
from scipy.spatial.distance import cdist
|
||||
|
||||
from ultralytics.utils.metrics import bbox_ioa
|
||||
from ultralytics.utils.metrics import bbox_ioa, batch_probiou
|
||||
|
||||
try:
|
||||
import lap # for linear_assignment
|
||||
@ -13,11 +13,11 @@ try:
|
||||
except (ImportError, AssertionError, AttributeError):
|
||||
from ultralytics.utils.checks import check_requirements
|
||||
|
||||
check_requirements('lapx>=0.5.2') # update to lap package from https://github.com/rathaROG/lapx
|
||||
check_requirements("lapx>=0.5.2") # update to lap package from https://github.com/rathaROG/lapx
|
||||
import lap
|
||||
|
||||
|
||||
def linear_assignment(cost_matrix, thresh, use_lap=True):
|
||||
def linear_assignment(cost_matrix: np.ndarray, thresh: float, use_lap: bool = True) -> tuple:
|
||||
"""
|
||||
Perform linear assignment using scipy or lap.lapjv.
|
||||
|
||||
@ -27,19 +27,24 @@ def linear_assignment(cost_matrix, thresh, use_lap=True):
|
||||
use_lap (bool, optional): Whether to use lap.lapjv. Defaults to True.
|
||||
|
||||
Returns:
|
||||
(tuple): Tuple containing matched indices, unmatched indices from 'a', and unmatched indices from 'b'.
|
||||
Tuple with:
|
||||
- matched indices
|
||||
- unmatched indices from 'a'
|
||||
- unmatched indices from 'b'
|
||||
"""
|
||||
|
||||
if cost_matrix.size == 0:
|
||||
return np.empty((0, 2), dtype=int), tuple(range(cost_matrix.shape[0])), tuple(range(cost_matrix.shape[1]))
|
||||
|
||||
if use_lap:
|
||||
# Use lap.lapjv
|
||||
# https://github.com/gatagat/lap
|
||||
_, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh)
|
||||
matches = [[ix, mx] for ix, mx in enumerate(x) if mx >= 0]
|
||||
unmatched_a = np.where(x < 0)[0]
|
||||
unmatched_b = np.where(y < 0)[0]
|
||||
else:
|
||||
# Use scipy.optimize.linear_sum_assignment
|
||||
# https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.linear_sum_assignment.html
|
||||
x, y = scipy.optimize.linear_sum_assignment(cost_matrix) # row x, col y
|
||||
matches = np.asarray([[x[i], y[i]] for i in range(len(x)) if cost_matrix[x[i], y[i]] <= thresh])
|
||||
@ -53,7 +58,7 @@ def linear_assignment(cost_matrix, thresh, use_lap=True):
|
||||
return matches, unmatched_a, unmatched_b
|
||||
|
||||
|
||||
def iou_distance(atracks, btracks):
|
||||
def iou_distance(atracks: list, btracks: list) -> np.ndarray:
|
||||
"""
|
||||
Compute cost based on Intersection over Union (IoU) between tracks.
|
||||
|
||||
@ -65,23 +70,30 @@ def iou_distance(atracks, btracks):
|
||||
(np.ndarray): Cost matrix computed based on IoU.
|
||||
"""
|
||||
|
||||
if (len(atracks) > 0 and isinstance(atracks[0], np.ndarray)) \
|
||||
or (len(btracks) > 0 and isinstance(btracks[0], np.ndarray)):
|
||||
if atracks and isinstance(atracks[0], np.ndarray) or btracks and isinstance(btracks[0], np.ndarray):
|
||||
atlbrs = atracks
|
||||
btlbrs = btracks
|
||||
else:
|
||||
atlbrs = [track.tlbr for track in atracks]
|
||||
btlbrs = [track.tlbr for track in btracks]
|
||||
atlbrs = [track.xywha if track.angle is not None else track.xyxy for track in atracks]
|
||||
btlbrs = [track.xywha if track.angle is not None else track.xyxy for track in btracks]
|
||||
|
||||
ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float32)
|
||||
if len(atlbrs) and len(btlbrs):
|
||||
ious = bbox_ioa(np.ascontiguousarray(atlbrs, dtype=np.float32),
|
||||
np.ascontiguousarray(btlbrs, dtype=np.float32),
|
||||
iou=True)
|
||||
if len(atlbrs[0]) == 5 and len(btlbrs[0]) == 5:
|
||||
ious = batch_probiou(
|
||||
np.ascontiguousarray(atlbrs, dtype=np.float32),
|
||||
np.ascontiguousarray(btlbrs, dtype=np.float32),
|
||||
).numpy()
|
||||
else:
|
||||
ious = bbox_ioa(
|
||||
np.ascontiguousarray(atlbrs, dtype=np.float32),
|
||||
np.ascontiguousarray(btlbrs, dtype=np.float32),
|
||||
iou=True,
|
||||
)
|
||||
return 1 - ious # cost matrix
|
||||
|
||||
|
||||
def embedding_distance(tracks, detections, metric='cosine'):
|
||||
def embedding_distance(tracks: list, detections: list, metric: str = "cosine") -> np.ndarray:
|
||||
"""
|
||||
Compute distance between tracks and detections based on embeddings.
|
||||
|
||||
@ -105,7 +117,7 @@ def embedding_distance(tracks, detections, metric='cosine'):
|
||||
return cost_matrix
|
||||
|
||||
|
||||
def fuse_score(cost_matrix, detections):
|
||||
def fuse_score(cost_matrix: np.ndarray, detections: list) -> np.ndarray:
|
||||
"""
|
||||
Fuses cost matrix with detection scores to produce a single similarity matrix.
|
||||
|
||||
|
Reference in New Issue
Block a user