add yolo v10 and modify pipeline

This commit is contained in:
王庆刚
2025-03-28 13:19:54 +08:00
parent 183299c06b
commit 798c596acc
471 changed files with 19109 additions and 7342 deletions

View File

@ -1,91 +1,318 @@
# Tracker
# Multi-Object Tracking with Ultralytics YOLO
## Supported Trackers
<img width="1024" src="https://user-images.githubusercontent.com/26833433/243418637-1d6250fd-1515-4c10-a844-a32818ae6d46.png" alt="YOLOv8 trackers visualization">
- [x] ByteTracker
- [x] BoT-SORT
Object tracking in the realm of video analytics is a critical task that not only identifies the location and class of objects within the frame but also maintains a unique ID for each detected object as the video progresses. The applications are limitless—ranging from surveillance and security to real-time sports analytics.
## Usage
## Why Choose Ultralytics YOLO for Object Tracking?
### Python interface
The output from Ultralytics trackers is consistent with standard object detection but has the added value of object IDs. This makes it easy to track objects in video streams and perform subsequent analytics. Here's why you should consider using Ultralytics YOLO for your object tracking needs:
You can use the Python interface to track objects using the YOLO model.
- **Efficiency:** Process video streams in real-time without compromising accuracy.
- **Flexibility:** Supports multiple tracking algorithms and configurations.
- **Ease of Use:** Simple Python API and CLI options for quick integration and deployment.
- **Customizability:** Easy to use with custom trained YOLO models, allowing integration into domain-specific applications.
**Video Tutorial:** [Object Detection and Tracking with Ultralytics YOLOv8](https://www.youtube.com/embed/hHyHmOtmEgs?si=VNZtXmm45Nb9s-N-).
## Features at a Glance
Ultralytics YOLO extends its object detection features to provide robust and versatile object tracking:
- **Real-Time Tracking:** Seamlessly track objects in high-frame-rate videos.
- **Multiple Tracker Support:** Choose from a variety of established tracking algorithms.
- **Customizable Tracker Configurations:** Tailor the tracking algorithm to meet specific requirements by adjusting various parameters.
## Available Trackers
Ultralytics YOLO supports the following tracking algorithms. They can be enabled by passing the relevant YAML configuration file such as `tracker=tracker_type.yaml`:
- [BoT-SORT](https://github.com/NirAharon/BoT-SORT) - Use `botsort.yaml` to enable this tracker.
- [ByteTrack](https://github.com/ifzhang/ByteTrack) - Use `bytetrack.yaml` to enable this tracker.
The default tracker is BoT-SORT.
## Tracking
To run the tracker on video streams, use a trained Detect, Segment or Pose model such as YOLOv8n, YOLOv8n-seg and YOLOv8n-pose.
#### Python
```python
from ultralytics import YOLO

model = YOLO("yolov8n.pt")  # or a segmentation model, i.e. yolov8n-seg.pt
model.track(
    source="video/streams",
    stream=True,
    tracker="botsort.yaml",  # or 'bytetrack.yaml'
    show=True,
)

# Load an official or custom model
model = YOLO("yolov8n.pt")  # Load an official Detect model
model = YOLO("yolov8n-seg.pt")  # Load an official Segment model
model = YOLO("yolov8n-pose.pt")  # Load an official Pose model
model = YOLO("path/to/best.pt")  # Load a custom trained model

# Perform tracking with the model
results = model.track(source="https://youtu.be/LNwODJXcvt4", show=True)  # Tracking with default tracker
results = model.track(source="https://youtu.be/LNwODJXcvt4", show=True, tracker="bytetrack.yaml")  # Tracking with ByteTrack tracker
```
#### CLI
```bash
# Perform tracking with various models using the command line interface
yolo track model=yolov8n.pt source="https://youtu.be/LNwODJXcvt4" # Official Detect model
yolo track model=yolov8n-seg.pt source="https://youtu.be/LNwODJXcvt4" # Official Segment model
yolo track model=yolov8n-pose.pt source="https://youtu.be/LNwODJXcvt4" # Official Pose model
yolo track model=path/to/best.pt source="https://youtu.be/LNwODJXcvt4" # Custom trained model
# Track using ByteTrack tracker
yolo track model=path/to/best.pt tracker="bytetrack.yaml"
```
As can be seen in the above usage, tracking is available for all Detect, Segment and Pose models run on videos or streaming sources.
## Configuration
### Tracking Arguments
Tracking configuration shares properties with Predict mode, such as `conf`, `iou`, and `show`. For further configurations, refer to the [Predict](https://docs.ultralytics.com/modes/predict/) mode page.
#### Python
```python
from ultralytics import YOLO
# Configure the tracking parameters and run the tracker
model = YOLO("yolov8n.pt")
results = model.track(source="https://youtu.be/LNwODJXcvt4", conf=0.3, iou=0.5, show=True)
```
You can get the IDs of the tracked objects using the following code:
#### CLI
```bash
# Configure tracking parameters and run the tracker using the command line interface
yolo track model=yolov8n.pt source="https://youtu.be/LNwODJXcvt4" conf=0.3 iou=0.5 show
```
### Tracker Selection
Ultralytics also allows you to use a modified tracker configuration file. To do this, simply make a copy of a tracker config file (for example, `custom_tracker.yaml`) from [ultralytics/cfg/trackers](https://github.com/ultralytics/ultralytics/tree/main/ultralytics/cfg/trackers) and modify any configurations (except the `tracker_type`) as per your needs.
#### Python
```python
from ultralytics import YOLO
# Load the model and run the tracker with a custom configuration file
model = YOLO("yolov8n.pt")
for result in model.track(source="video.mp4"):
    print(result.boxes.id.cpu().numpy().astype(int))  # this will print the IDs of the tracked objects in the frame

results = model.track(source="https://youtu.be/LNwODJXcvt4", tracker="custom_tracker.yaml")
```
If you want to use the tracker on a folder of images, or when you loop over video frames yourself, use the `persist` parameter to tell the model that successive frames belong to the same sequence, so the same object keeps the same ID. Otherwise, each call creates a fresh tracker and the IDs change from frame to frame; `persist=True` reuses the same tracker object across calls.
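A minimal sketch of that pattern (here `frames` is an assumed iterable of images you have already loaded, e.g. with OpenCV):

```python
from ultralytics import YOLO

model = YOLO("yolov8n.pt")

# persist=True tells the tracker that successive calls belong to one sequence,
# so the same physical object keeps the same ID across frames
for frame in frames:  # `frames` is an assumed iterable of images
    results = model.track(frame, persist=True)
```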
#### CLI
```bash
# Load the model and run the tracker with a custom configuration file using the command line interface
yolo track model=yolov8n.pt source="https://youtu.be/LNwODJXcvt4" tracker='custom_tracker.yaml'
```
For a comprehensive list of tracking arguments, refer to the [ultralytics/cfg/trackers](https://github.com/ultralytics/ultralytics/tree/main/ultralytics/cfg/trackers) page.
## Python Examples
### Persisting Tracks Loop
Here is a Python script using OpenCV (`cv2`) and YOLOv8 to run object tracking on video frames. This script assumes you have already installed the necessary packages (`opencv-python` and `ultralytics`). The `persist=True` argument tells the tracker that the current image or frame is the next in a sequence and to expect tracks from the previous image in the current image.
#### Python
```python
import cv2

from ultralytics import YOLO

cap = cv2.VideoCapture("video.mp4")
model = YOLO("yolov8n.pt")
while True:
    ret, frame = cap.read()
    if not ret:
        break
    results = model.track(frame, persist=True)
    boxes = results[0].boxes.xyxy.cpu().numpy().astype(int)
    ids = results[0].boxes.id.cpu().numpy().astype(int)
    for box, id in zip(boxes, ids):
        cv2.rectangle(frame, (box[0], box[1]), (box[2], box[3]), (0, 255, 0), 2)
        cv2.putText(
            frame,
            f"Id {id}",
            (box[0], box[1]),
            cv2.FONT_HERSHEY_SIMPLEX,
            1,
            (0, 0, 255),
            2,
        )
    cv2.imshow("frame", frame)
    if cv2.waitKey(1) & 0xFF == ord("q"):
        break

# Load the YOLOv8 model
model = YOLO("yolov8n.pt")

# Open the video file
video_path = "path/to/video.mp4"
cap = cv2.VideoCapture(video_path)

# Loop through the video frames
while cap.isOpened():
    # Read a frame from the video
    success, frame = cap.read()
    if success:
        # Run YOLOv8 tracking on the frame, persisting tracks between frames
        results = model.track(frame, persist=True)

        # Visualize the results on the frame
        annotated_frame = results[0].plot()

        # Display the annotated frame
        cv2.imshow("YOLOv8 Tracking", annotated_frame)

        # Break the loop if 'q' is pressed
        if cv2.waitKey(1) & 0xFF == ord("q"):
            break
    else:
        # Break the loop if the end of the video is reached
        break

# Release the video capture object and close the display window
cap.release()
cv2.destroyAllWindows()
```
## Change tracker parameters
Please note the change from `model(frame)` to `model.track(frame)`, which enables object tracking instead of simple detection. This modified script will run the tracker on each frame of the video, visualize the results, and display them in a window. The loop can be exited by pressing 'q'.
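For instance, a minimal sketch of the difference (the image path is illustrative):

```python
from ultralytics import YOLO

model = YOLO("yolov8n.pt")

results = model("frame.jpg")  # plain detection: results[0].boxes.id is not populated
results = model.track("frame.jpg")  # tracking: results[0].boxes.id carries track IDs
```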
You can change the tracker parameters by editing the `tracker.yaml` file, which is located in the `ultralytics/cfg/trackers` folder.
### Plotting Tracks Over Time
## Command Line Interface (CLI)
Visualizing object tracks over consecutive frames can provide valuable insights into the movement patterns and behavior of detected objects within a video. With Ultralytics YOLOv8, plotting these tracks is a seamless and efficient process.
You can also use the command line interface to track objects using the YOLO model.
In the following example, we demonstrate how to utilize YOLOv8's tracking capabilities to plot the movement of detected objects across multiple video frames. This script involves opening a video file, reading it frame by frame, and utilizing the YOLO model to identify and track various objects. By retaining the center points of the detected bounding boxes and connecting them, we can draw lines that represent the paths followed by the tracked objects.
```bash
yolo detect track source=... tracker=...
yolo segment track source=... tracker=...
yolo pose track source=... tracker=...
```

#### Python
```python
from collections import defaultdict

import cv2
import numpy as np

from ultralytics import YOLO

# Load the YOLOv8 model
model = YOLO("yolov8n.pt")

# Open the video file
video_path = "path/to/video.mp4"
cap = cv2.VideoCapture(video_path)

# Store the track history
track_history = defaultdict(lambda: [])

# Loop through the video frames
while cap.isOpened():
    # Read a frame from the video
    success, frame = cap.read()
    if success:
        # Run YOLOv8 tracking on the frame, persisting tracks between frames
        results = model.track(frame, persist=True)

        # Get the boxes and track IDs
        boxes = results[0].boxes.xywh.cpu()
        track_ids = results[0].boxes.id.int().cpu().tolist()

        # Visualize the results on the frame
        annotated_frame = results[0].plot()

        # Plot the tracks
        for box, track_id in zip(boxes, track_ids):
            x, y, w, h = box
            track = track_history[track_id]
            track.append((float(x), float(y)))  # x, y center point
            if len(track) > 30:  # retain track history for the last 30 frames
                track.pop(0)

            # Draw the tracking lines
            points = np.hstack(track).astype(np.int32).reshape((-1, 1, 2))
            cv2.polylines(
                annotated_frame,
                [points],
                isClosed=False,
                color=(230, 230, 230),
                thickness=10,
            )

        # Display the annotated frame
        cv2.imshow("YOLOv8 Tracking", annotated_frame)

        # Break the loop if 'q' is pressed
        if cv2.waitKey(1) & 0xFF == ord("q"):
            break
    else:
        # Break the loop if the end of the video is reached
        break

# Release the video capture object and close the display window
cap.release()
cv2.destroyAllWindows()
```
By default, trackers will use the configuration in `ultralytics/cfg/trackers`. We also support using a modified tracker config file. Please refer to the tracker config files in `ultralytics/cfg/trackers`.
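To see what a tracker config actually contains, one option is to load it with the same helpers the tracker registration code uses (`yaml_load` and `IterableSimpleNamespace` from `ultralytics.utils`). This is a hedged sketch; the path below assumes a source checkout, so adjust it to your installation:

```python
from ultralytics.utils import IterableSimpleNamespace, yaml_load

# Path assumes a source checkout of ultralytics; adjust as needed
cfg = IterableSimpleNamespace(**yaml_load("ultralytics/cfg/trackers/bytetrack.yaml"))
print(cfg.tracker_type)  # -> "bytetrack"
```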
### Multithreaded Tracking
## Contribute to Our Trackers Section
Multithreaded tracking provides the capability to run object tracking on multiple video streams simultaneously. This is particularly useful when handling multiple video inputs, such as from multiple surveillance cameras, where concurrent processing can greatly enhance efficiency and performance.
Are you proficient in multi-object tracking and have successfully implemented or adapted a tracking algorithm with Ultralytics YOLO? We invite you to contribute to our Trackers section! Your real-world applications and solutions could be invaluable for users working on tracking tasks.
In the provided Python script, we make use of Python's `threading` module to run multiple instances of the tracker concurrently. Each thread is responsible for running the tracker on one video file, and all the threads run simultaneously in the background.
To ensure that each thread receives the correct parameters (the video file and the model to use), we define a function `run_tracker_in_thread` that accepts these parameters and contains the main tracking loop. This function reads the video frame by frame, runs the tracker, and displays the results.
Two different models are used in this example: `yolov8n.pt` and `yolov8n-seg.pt`, each tracking objects in a different video file. The video files are specified in `video_file1` and `video_file2`.
The `daemon=True` parameter in `threading.Thread` means that these threads will be closed as soon as the main program finishes. We then start the threads with `start()` and use `join()` to make the main thread wait until both tracker threads have finished.
Finally, after all threads have completed their task, the windows displaying the results are closed using `cv2.destroyAllWindows()`.
#### Python
```python
import threading

import cv2

from ultralytics import YOLO


def run_tracker_in_thread(filename, model):
    """Run YOLO tracking on one video file in its own thread."""
    video = cv2.VideoCapture(filename)
    frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
    for _ in range(frames):
        ret, frame = video.read()
        if ret:
            results = model.track(source=frame, persist=True)
            res_plotted = results[0].plot()
            cv2.imshow("p", res_plotted)
            if cv2.waitKey(1) == ord("q"):
                break


# Load the models
model1 = YOLO("yolov8n.pt")
model2 = YOLO("yolov8n-seg.pt")

# Define the video files for the trackers
video_file1 = "path/to/video1.mp4"
video_file2 = "path/to/video2.mp4"

# Create the tracker threads
tracker_thread1 = threading.Thread(target=run_tracker_in_thread, args=(video_file1, model1), daemon=True)
tracker_thread2 = threading.Thread(target=run_tracker_in_thread, args=(video_file2, model2), daemon=True)

# Start the tracker threads
tracker_thread1.start()
tracker_thread2.start()

# Wait for the tracker threads to finish
tracker_thread1.join()
tracker_thread2.join()

# Clean up and close windows
cv2.destroyAllWindows()
```
This example can easily be extended to handle more video files and models by creating more threads and applying the same methodology.
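A hedged sketch of that extension, reusing `run_tracker_in_thread` from the block above (the file paths and weights are placeholders):

```python
import threading

import cv2

from ultralytics import YOLO

# Placeholder (video file, weights) pairs; extend the list as needed
sources = [
    ("path/to/video1.mp4", "yolov8n.pt"),
    ("path/to/video2.mp4", "yolov8n-seg.pt"),
    ("path/to/video3.mp4", "yolov8n-pose.pt"),
]

# One daemon thread per video, as in the example above
threads = [
    threading.Thread(target=run_tracker_in_thread, args=(path, YOLO(weights)), daemon=True)
    for path, weights in sources
]
for t in threads:
    t.start()
for t in threads:
    t.join()

cv2.destroyAllWindows()
```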
## Contribute New Trackers
Are you proficient in multi-object tracking and have successfully implemented or adapted a tracking algorithm with Ultralytics YOLO? We invite you to contribute to our Trackers section in [ultralytics/cfg/trackers](https://github.com/ultralytics/ultralytics/tree/main/ultralytics/cfg/trackers)! Your real-world applications and solutions could be invaluable for users working on tracking tasks.
By contributing to this section, you help expand the scope of tracking solutions available within the Ultralytics YOLO framework, adding another layer of functionality and utility for the community.

View File

@ -4,4 +4,4 @@ from .bot_sort import BOTSORT
from .byte_tracker import BYTETracker
from .track import register_tracker
__all__ = 'register_tracker', 'BOTSORT', 'BYTETracker' # allow simpler import
__all__ = "register_tracker", "BOTSORT", "BYTETracker" # allow simpler import

View File

@ -1,4 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""This module defines the base classes and structures for object tracking in YOLO."""
from collections import OrderedDict
@ -6,7 +7,15 @@ import numpy as np
class TrackState:
"""Enumeration of possible object tracking states."""
"""
Enumeration class representing the possible states of an object being tracked.
Attributes:
New (int): State when the object is newly detected.
Tracked (int): State when the object is successfully tracked in subsequent frames.
Lost (int): State when the object is no longer tracked.
Removed (int): State when the object is removed from tracking.
"""
New = 0
Tracked = 1
@ -15,24 +24,49 @@ class TrackState:
class BaseTrack:
"""Base class for object tracking, handling basic track attributes and operations."""
"""
Base class for object tracking, providing foundational attributes and methods.
Attributes:
_count (int): Class-level counter for unique track IDs.
track_id (int): Unique identifier for the track.
is_activated (bool): Flag indicating whether the track is currently active.
state (TrackState): Current state of the track.
history (OrderedDict): Ordered history of the track's states.
features (list): List of features extracted from the object for tracking.
curr_feature (any): The current feature of the object being tracked.
score (float): The confidence score of the tracking.
start_frame (int): The frame number where tracking started.
frame_id (int): The most recent frame ID processed by the track.
time_since_update (int): Frames passed since the last update.
location (tuple): The location of the object in the context of multi-camera tracking.
Methods:
end_frame: Returns the ID of the last frame where the object was tracked.
next_id: Increments and returns the next global track ID.
activate: Abstract method to activate the track.
predict: Abstract method to predict the next state of the track.
update: Abstract method to update the track with new data.
mark_lost: Marks the track as lost.
mark_removed: Marks the track as removed.
reset_id: Resets the global track ID counter.
"""
_count = 0
track_id = 0
is_activated = False
state = TrackState.New
history = OrderedDict()
features = []
curr_feature = None
score = 0
start_frame = 0
frame_id = 0
time_since_update = 0
# Multi-camera
location = (np.inf, np.inf)
def __init__(self):
"""Initializes a new track with unique ID and foundational tracking attributes."""
self.track_id = 0
self.is_activated = False
self.state = TrackState.New
self.history = OrderedDict()
self.features = []
self.curr_feature = None
self.score = 0
self.start_frame = 0
self.frame_id = 0
self.time_since_update = 0
self.location = (np.inf, np.inf)
@property
def end_frame(self):
@ -46,15 +80,15 @@ class BaseTrack:
return BaseTrack._count
def activate(self, *args):
"""Activate the track with the provided arguments."""
"""Abstract method to activate the track with provided arguments."""
raise NotImplementedError
def predict(self):
"""Predict the next state of the track."""
"""Abstract method to predict the next state of the track."""
raise NotImplementedError
def update(self, *args, **kwargs):
"""Update the track with new observations."""
"""Abstract method to update the track with new observations."""
raise NotImplementedError
def mark_lost(self):

View File

@ -12,6 +12,34 @@ from .utils.kalman_filter import KalmanFilterXYWH
class BOTrack(STrack):
"""
An extended version of the STrack class for YOLOv8, adding object tracking features.
Attributes:
shared_kalman (KalmanFilterXYWH): A shared Kalman filter for all instances of BOTrack.
smooth_feat (np.ndarray): Smoothed feature vector.
curr_feat (np.ndarray): Current feature vector.
features (deque): A deque to store feature vectors with a maximum length defined by `feat_history`.
alpha (float): Smoothing factor for the exponential moving average of features.
mean (np.ndarray): The mean state of the Kalman filter.
covariance (np.ndarray): The covariance matrix of the Kalman filter.
Methods:
update_features(feat): Update features vector and smooth it using exponential moving average.
predict(): Predicts the mean and covariance using Kalman filter.
re_activate(new_track, frame_id, new_id): Reactivates a track with updated features and optionally new ID.
update(new_track, frame_id): Update the YOLOv8 instance with new track and frame ID.
tlwh: Property that gets the current position in tlwh format `(top left x, top left y, width, height)`.
multi_predict(stracks): Predicts the mean and covariance of multiple object tracks using shared Kalman filter.
convert_coords(tlwh): Converts tlwh bounding box coordinates to xywh format.
tlwh_to_xywh(tlwh): Convert bounding box to xywh format `(center x, center y, width, height)`.
Usage:
bo_track = BOTrack(tlwh, score, cls, feat)
bo_track.predict()
bo_track.update(new_track, frame_id)
"""
shared_kalman = KalmanFilterXYWH()
def __init__(self, tlwh, score, cls, feat=None, feat_history=50):
@ -59,9 +87,7 @@ class BOTrack(STrack):
@property
def tlwh(self):
"""Get current position in bounding box format `(top left x, top left y,
width, height)`.
"""
"""Get current position in bounding box format `(top left x, top left y, width, height)`."""
if self.mean is None:
return self._tlwh.copy()
ret = self.mean[:4].copy()
@ -90,15 +116,37 @@ class BOTrack(STrack):
@staticmethod
def tlwh_to_xywh(tlwh):
"""Convert bounding box to format `(center x, center y, width,
height)`.
"""
"""Convert bounding box to format `(center x, center y, width, height)`."""
ret = np.asarray(tlwh).copy()
ret[:2] += ret[2:] / 2
return ret
class BOTSORT(BYTETracker):
"""
An extended version of the BYTETracker class for YOLOv8, designed for object tracking with ReID and GMC algorithm.
Attributes:
proximity_thresh (float): Threshold for spatial proximity (IoU) between tracks and detections.
appearance_thresh (float): Threshold for appearance similarity (ReID embeddings) between tracks and detections.
encoder (object): Object to handle ReID embeddings, set to None if ReID is not enabled.
gmc (GMC): An instance of the GMC algorithm for data association.
args (object): Parsed command-line arguments containing tracking parameters.
Methods:
get_kalmanfilter(): Returns an instance of KalmanFilterXYWH for object tracking.
init_track(dets, scores, cls, img): Initialize track with detections, scores, and classes.
get_dists(tracks, detections): Get distances between tracks and detections using IoU and (optionally) ReID.
multi_predict(tracks): Predict and track multiple objects with YOLOv8 model.
Usage:
bot_sort = BOTSORT(args, frame_rate)
bot_sort.init_track(dets, scores, cls, img)
bot_sort.multi_predict(tracks)
Note:
The class is designed to work with the YOLOv8 object detection model and supports ReID only if enabled via args.
"""
def __init__(self, args, frame_rate=30):
"""Initialize YOLOv8 object with ReID module and GMC algorithm."""
@ -110,8 +158,7 @@ class BOTSORT(BYTETracker):
if args.with_reid:
# BoT-SORT(reid) is not supported yet
self.encoder = None
# self.gmc = GMC(method=args.gmc_method) # commented by WQG
self.gmc = GMC(method=args.gmc_method)
def get_kalmanfilter(self):
"""Returns an instance of KalmanFilterXYWH for object tracking."""
@ -130,7 +177,7 @@ class BOTSORT(BYTETracker):
def get_dists(self, tracks, detections):
"""Get distances between tracks and detections using IoU and (optionally) ReID embeddings."""
dists = matching.iou_distance(tracks, detections)
dists_mask = (dists > self.proximity_thresh)
dists_mask = dists > self.proximity_thresh
# TODO: mot20
# if not self.args.mot20:
@ -146,3 +193,8 @@ class BOTSORT(BYTETracker):
def multi_predict(self, tracks):
"""Predict and track multiple objects with YOLOv8 model."""
BOTrack.multi_predict(tracks)
def reset(self):
"""Reset tracker."""
super().reset()
self.gmc.reset_params()

View File

@ -1,29 +1,54 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import numpy as np
from .basetrack import BaseTrack, TrackState
from .utils import matching
from .utils.kalman_filter import KalmanFilterXYAH
def dists_update(dists, strack_pool, detections):
"""Raise the cost to 1 for track-detection pairs whose class labels differ, suppressing cross-class matches."""
if len(strack_pool) and len(detections):
alabel = np.array([int(stack.cls) for stack in strack_pool])
blabel = np.array([int(stack.cls) for stack in detections])
amlabel = np.expand_dims(alabel, axis=1).repeat(len(detections), axis=1)
bmlabel = np.expand_dims(blabel, axis=0).repeat(len(strack_pool), axis=0)
dist_label = 1 - (bmlabel == amlabel)
dists = np.where(dists > dist_label, dists, dist_label)
return dists
from ..utils.ops import xywh2ltwh
from ..utils import LOGGER
class STrack(BaseTrack):
"""
Single object tracking representation that uses Kalman filtering for state estimation.
This class is responsible for storing all the information regarding individual tracklets and performs state updates
and predictions based on Kalman filter.
Attributes:
shared_kalman (KalmanFilterXYAH): Shared Kalman filter that is used across all STrack instances for prediction.
_tlwh (np.ndarray): Private attribute to store top-left corner coordinates and width and height of bounding box.
kalman_filter (KalmanFilterXYAH): Instance of Kalman filter used for this particular object track.
mean (np.ndarray): Mean state estimate vector.
covariance (np.ndarray): Covariance of state estimate.
is_activated (bool): Boolean flag indicating if the track has been activated.
score (float): Confidence score of the track.
tracklet_len (int): Length of the tracklet.
cls (any): Class label for the object.
idx (int): Index or identifier for the object.
frame_id (int): Current frame ID.
start_frame (int): Frame where the object was first detected.
Methods:
predict(): Predict the next state of the object using Kalman filter.
multi_predict(stracks): Predict the next states for multiple tracks.
multi_gmc(stracks, H): Update multiple track states using a homography matrix.
activate(kalman_filter, frame_id): Activate a new tracklet.
re_activate(new_track, frame_id, new_id): Reactivate a previously lost tracklet.
update(new_track, frame_id): Update the state of a matched track.
convert_coords(tlwh): Convert bounding box to x-y-aspect-height format.
tlwh_to_xyah(tlwh): Convert tlwh bounding box to xyah format.
"""
shared_kalman = KalmanFilterXYAH()
def __init__(self, tlwh, score, cls):
"""wait activate."""
self._tlwh = np.asarray(self.tlbr_to_tlwh(tlwh[:-1]), dtype=np.float32)
def __init__(self, xywh, score, cls):
"""Initialize new STrack instance."""
super().__init__()
# xywh+idx or xywha+idx
assert len(xywh) in [5, 6], f"expected 5 or 6 values but got {len(xywh)}"
self._tlwh = np.asarray(xywh2ltwh(xywh[:4]), dtype=np.float32)
self.kalman_filter = None
self.mean, self.covariance = None, None
self.is_activated = False
@ -31,7 +56,8 @@ class STrack(BaseTrack):
self.score = score
self.tracklet_len = 0
self.cls = cls
self.idx = tlwh[-1]
self.idx = xywh[-1]
self.angle = xywh[4] if len(xywh) == 6 else None
def predict(self):
"""Predicts mean and covariance using Kalman filter."""
@ -89,8 +115,9 @@ class STrack(BaseTrack):
def re_activate(self, new_track, frame_id, new_id=False):
"""Reactivates a previously lost track with a new detection."""
self.mean, self.covariance = self.kalman_filter.update(self.mean, self.covariance,
self.convert_coords(new_track.tlwh))
self.mean, self.covariance = self.kalman_filter.update(
self.mean, self.covariance, self.convert_coords(new_track.tlwh)
)
self.tracklet_len = 0
self.state = TrackState.Tracked
self.is_activated = True
@ -99,37 +126,39 @@ class STrack(BaseTrack):
self.track_id = self.next_id()
self.score = new_track.score
self.cls = new_track.cls
self.angle = new_track.angle
self.idx = new_track.idx
def update(self, new_track, frame_id):
"""
Update a matched track
:type new_track: STrack
:type frame_id: int
:return:
Update the state of a matched track.
Args:
new_track (STrack): The new track containing updated information.
frame_id (int): The ID of the current frame.
"""
self.frame_id = frame_id
self.tracklet_len += 1
new_tlwh = new_track.tlwh
self.mean, self.covariance = self.kalman_filter.update(self.mean, self.covariance,
self.convert_coords(new_tlwh))
self.mean, self.covariance = self.kalman_filter.update(
self.mean, self.covariance, self.convert_coords(new_tlwh)
)
self.state = TrackState.Tracked
self.is_activated = True
self.score = new_track.score
self.cls = new_track.cls
self.angle = new_track.angle
self.idx = new_track.idx
def convert_coords(self, tlwh):
"""Convert a bounding box's top-left-width-height format to its x-y-angle-height equivalent."""
"""Convert a bounding box's top-left-width-height format to its x-y-aspect-height equivalent."""
return self.tlwh_to_xyah(tlwh)
@property
def tlwh(self):
"""Get current position in bounding box format `(top left x, top left y,
width, height)`.
"""
"""Get current position in bounding box format (top left x, top left y, width, height)."""
if self.mean is None:
return self._tlwh.copy()
ret = self.mean[:4].copy()
@ -138,44 +167,76 @@ class STrack(BaseTrack):
return ret
@property
def tlbr(self):
"""Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
`(top left, bottom right)`.
"""
def xyxy(self):
"""Convert bounding box to format (min x, min y, max x, max y), i.e., (top left, bottom right)."""
ret = self.tlwh.copy()
ret[2:] += ret[:2]
return ret
@staticmethod
def tlwh_to_xyah(tlwh):
"""Convert bounding box to format `(center x, center y, aspect ratio,
height)`, where the aspect ratio is `width / height`.
"""Convert bounding box to format (center x, center y, aspect ratio, height), where the aspect ratio is width /
height.
"""
ret = np.asarray(tlwh).copy()
ret[:2] += ret[2:] / 2
ret[2] /= ret[3]
return ret
@staticmethod
def tlbr_to_tlwh(tlbr):
"""Converts top-left bottom-right format to top-left width height format."""
ret = np.asarray(tlbr).copy()
ret[2:] -= ret[:2]
@property
def xywh(self):
"""Get current position in bounding box format (center x, center y, width, height)."""
ret = np.asarray(self.tlwh).copy()
ret[:2] += ret[2:] / 2
return ret
@staticmethod
def tlwh_to_tlbr(tlwh):
"""Converts tlwh bounding box format to tlbr format."""
ret = np.asarray(tlwh).copy()
ret[2:] += ret[:2]
return ret
@property
def xywha(self):
"""Get current position in bounding box format (center x, center y, width, height, angle)."""
if self.angle is None:
LOGGER.warning("WARNING ⚠️ `angle` attr not found, returning `xywh` instead.")
return self.xywh
return np.concatenate([self.xywh, self.angle[None]])
@property
def result(self):
"""Get current tracking results."""
coords = self.xyxy if self.angle is None else self.xywha
return coords.tolist() + [self.track_id, self.score, self.cls, self.idx]
def __repr__(self):
"""Return a string representation of the BYTETracker object with start and end frames and track ID."""
return f'OT_{self.track_id}_({self.start_frame}-{self.end_frame})'
return f"OT_{self.track_id}_({self.start_frame}-{self.end_frame})"
class BYTETracker:
"""
BYTETracker: A tracking algorithm built on top of YOLOv8 for object detection and tracking.
The class is responsible for initializing, updating, and managing the tracks for detected objects in a video
sequence. It maintains the state of tracked, lost, and removed tracks over frames, utilizes Kalman filtering for
predicting the new object locations, and performs data association.
Attributes:
tracked_stracks (list[STrack]): List of successfully activated tracks.
lost_stracks (list[STrack]): List of lost tracks.
removed_stracks (list[STrack]): List of removed tracks.
frame_id (int): The current frame ID.
args (namespace): Command-line arguments.
max_time_lost (int): The maximum frames for a track to be considered as 'lost'.
kalman_filter (object): Kalman Filter object.
Methods:
update(results, img=None): Updates object tracker with new detections.
get_kalmanfilter(): Returns a Kalman filter object for tracking bounding boxes.
init_track(dets, scores, cls, img=None): Initialize object tracking with detections.
get_dists(tracks, detections): Calculates the distance between tracks and detections.
multi_predict(tracks): Predicts the location of tracks.
reset_id(): Resets the ID counter of STrack.
joint_stracks(tlista, tlistb): Combines two lists of stracks.
sub_stracks(tlista, tlistb): Filters out the stracks present in the second list from the first list.
remove_duplicate_stracks(stracksa, stracksb): Removes duplicate stracks based on IoU.
"""
def __init__(self, args, frame_rate=30):
"""Initialize a YOLOv8 object to track objects with given arguments and frame rate."""
@ -198,7 +259,7 @@ class BYTETracker:
removed_stracks = []
scores = results.conf
bboxes = results.xyxy
bboxes = results.xywhr if hasattr(results, "xywhr") else results.xywh
# Add index
bboxes = np.concatenate([bboxes, np.arange(len(bboxes)).reshape(-1, 1)], axis=-1)
cls = results.cls
@ -216,7 +277,6 @@ class BYTETracker:
cls_second = cls[inds_second]
detections = self.init_track(dets, scores_keep, cls_keep, img)
# Add newly detected tracklets to tracked_stracks
unconfirmed = []
tracked_stracks = [] # type: list[STrack]
@ -225,24 +285,18 @@ class BYTETracker:
unconfirmed.append(track)
else:
tracked_stracks.append(track)
# Step 2: First association, with high score detection boxes
strack_pool = self.joint_stracks(tracked_stracks, self.lost_stracks)
# Predict the current location with KF
self.multi_predict(strack_pool)
# ============================================================= GMC not necessary (WQG)
# if hasattr(self, 'gmc') and img is not None:
# warp = self.gmc.apply(img, dets)
# STrack.multi_gmc(strack_pool, warp)
# STrack.multi_gmc(unconfirmed, warp)
# =============================================================================
if hasattr(self, "gmc") and img is not None:
warp = self.gmc.apply(img, dets)
STrack.multi_gmc(strack_pool, warp)
STrack.multi_gmc(unconfirmed, warp)
dists = self.get_dists(strack_pool, detections)
dists = dists_update(dists, strack_pool, detections)
matches, u_track, u_detection = matching.linear_assignment(dists, thresh=self.args.match_thresh)
for itracked, idet in matches:
track = strack_pool[itracked]
det = detections[idet]
@ -252,17 +306,11 @@ class BYTETracker:
else:
track.re_activate(det, self.frame_id, new_id=False)
refind_stracks.append(track)
# Step 3: Second association, with low score detection boxes
# association the untrack to the low score detections
# Step 3: Second association, matching the remaining tracks to the low score detection boxes
detections_second = self.init_track(dets_second, scores_second, cls_second, img)
r_tracked_stracks = [strack_pool[i] for i in u_track if strack_pool[i].state == TrackState.Tracked]
# TODO
dists = matching.iou_distance(r_tracked_stracks, detections_second)
dists = dists_update(dists, r_tracked_stracks, detections_second)
matches, u_track, u_detection_second = matching.linear_assignment(dists, thresh=0.5)
for itracked, idet in matches:
track = r_tracked_stracks[itracked]
@ -279,13 +327,9 @@ class BYTETracker:
if track.state != TrackState.Lost:
track.mark_lost()
lost_stracks.append(track)
# Deal with unconfirmed tracks, usually tracks with only one beginning frame
detections = [detections[i] for i in u_detection]
dists = self.get_dists(unconfirmed, detections)
dists = dists_update(dists, unconfirmed, detections)
matches, u_unconfirmed, u_detection = matching.linear_assignment(dists, thresh=0.7)
for itracked, idet in matches:
unconfirmed[itracked].update(detections[idet], self.frame_id)
@ -317,9 +361,8 @@ class BYTETracker:
self.removed_stracks.extend(removed_stracks)
if len(self.removed_stracks) > 1000:
self.removed_stracks = self.removed_stracks[-999:] # clip removed stracks to 1000 maximum
return np.asarray(
[x.tlbr.tolist() + [x.track_id, x.score, x.cls, x.idx] for x in self.tracked_stracks if x.is_activated],
dtype=np.float32)
return np.asarray([x.result for x in self.tracked_stracks if x.is_activated], dtype=np.float32)
def get_kalmanfilter(self):
"""Returns a Kalman filter object for tracking bounding boxes."""
@ -330,7 +373,7 @@ class BYTETracker:
return [STrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)] if len(dets) else [] # detections
def get_dists(self, tracks, detections):
"""Calculates the distance between tracks and detections using IOU and fuses scores."""
"""Calculates the distance between tracks and detections using IoU and fuses scores."""
dists = matching.iou_distance(tracks, detections)
# TODO: mot20
# if not self.args.mot20:
@ -341,10 +384,20 @@ class BYTETracker:
"""Returns the predicted tracks using the YOLOv8 network."""
STrack.multi_predict(tracks)
def reset_id(self):
@staticmethod
def reset_id():
"""Resets the ID counter of STrack."""
STrack.reset_id()
def reset(self):
"""Reset tracker."""
self.tracked_stracks = [] # type: list[STrack]
self.lost_stracks = [] # type: list[STrack]
self.removed_stracks = [] # type: list[STrack]
self.frame_id = 0
self.kalman_filter = self.get_kalmanfilter()
self.reset_id()
@staticmethod
def joint_stracks(tlista, tlistb):
"""Combine two lists of stracks into a single one."""
@ -375,7 +428,7 @@ class BYTETracker:
@staticmethod
def remove_duplicate_stracks(stracksa, stracksb):
"""Remove duplicate stracks with non-maximum IOU distance."""
"""Remove duplicate stracks with non-maximum IoU distance."""
pdist = matching.iou_distance(stracksa, stracksb)
pairs = np.where(pdist < 0.15)
dupa, dupb = [], []

View File

@ -1,19 +1,20 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from functools import partial
from pathlib import Path
import torch
from ultralytics.utils import IterableSimpleNamespace, yaml_load
from ultralytics.utils.checks import check_yaml
from .bot_sort import BOTSORT
from .byte_tracker import BYTETracker
TRACKER_MAP = {'bytetrack': BYTETracker, 'botsort': BOTSORT}
# A mapping of tracker types to corresponding tracker classes
TRACKER_MAP = {"bytetrack": BYTETracker, "botsort": BOTSORT}
def on_predict_start(predictor, persist=False):
def on_predict_start(predictor: object, persist: bool = False) -> None:
"""
Initialize trackers for object tracking during prediction.
@ -24,43 +25,65 @@ def on_predict_start(predictor, persist=False):
Raises:
AssertionError: If the tracker_type is not 'bytetrack' or 'botsort'.
"""
if hasattr(predictor, 'trackers') and persist:
if hasattr(predictor, "trackers") and persist:
return
tracker = check_yaml(predictor.args.tracker)
cfg = IterableSimpleNamespace(**yaml_load(tracker))
assert cfg.tracker_type in ['bytetrack', 'botsort'], \
f"Only support 'bytetrack' and 'botsort' for now, but got '{cfg.tracker_type}'"
if cfg.tracker_type not in ["bytetrack", "botsort"]:
raise AssertionError(f"Only 'bytetrack' and 'botsort' are supported for now, but got '{cfg.tracker_type}'")
trackers = []
for _ in range(predictor.dataset.bs):
tracker = TRACKER_MAP[cfg.tracker_type](args=cfg, frame_rate=30)
trackers.append(tracker)
if predictor.dataset.mode != "stream": # only need one tracker for other modes.
break
predictor.trackers = trackers
predictor.vid_path = [None] * predictor.dataset.bs # for determining when to reset tracker on new video
def on_predict_postprocess_end(predictor):
"""Postprocess detected boxes and update with object tracking."""
bs = predictor.dataset.bs
im0s = predictor.batch[1]
for i in range(bs):
det = predictor.results[i].boxes.cpu().numpy()
def on_predict_postprocess_end(predictor: object, persist: bool = False) -> None:
"""
Postprocess detected boxes and update with object tracking.
Args:
predictor (object): The predictor object containing the predictions.
persist (bool, optional): Whether to persist the trackers if they already exist. Defaults to False.
"""
path, im0s = predictor.batch[:2]
is_obb = predictor.args.task == "obb"
is_stream = predictor.dataset.mode == "stream"
for i in range(len(im0s)):
tracker = predictor.trackers[i if is_stream else 0]
vid_path = predictor.save_dir / Path(path[i]).name
if not persist and predictor.vid_path[i if is_stream else 0] != vid_path:
tracker.reset()
predictor.vid_path[i if is_stream else 0] = vid_path
det = (predictor.results[i].obb if is_obb else predictor.results[i].boxes).cpu().numpy()
if len(det) == 0:
continue
tracks = predictor.trackers[i].update(det, im0s[i])
tracks = tracker.update(det, im0s[i])
if len(tracks) == 0:
continue
idx = tracks[:, -1].astype(int)
predictor.results[i] = predictor.results[i][idx]
predictor.results[i].update(boxes=torch.as_tensor(tracks[:, :-1]))
update_args = dict()
update_args["obb" if is_obb else "boxes"] = torch.as_tensor(tracks[:, :-1])
predictor.results[i].update(**update_args)
def register_tracker(model, persist):
def register_tracker(model: object, persist: bool) -> None:
"""
Register tracking callbacks to the model for object tracking during prediction.
Args:
model (object): The model object to register tracking callbacks for.
persist (bool): Whether to persist the trackers if they already exist.
"""
model.add_callback('on_predict_start', partial(on_predict_start, persist=persist))
model.add_callback('on_predict_postprocess_end', on_predict_postprocess_end)
model.add_callback("on_predict_start", partial(on_predict_start, persist=persist))
model.add_callback("on_predict_postprocess_end", partial(on_predict_postprocess_end, persist=persist))

View File

@ -9,67 +9,121 @@ from ultralytics.utils import LOGGER
class GMC:
"""
Generalized Motion Compensation (GMC) class for tracking and object detection in video frames.
def __init__(self, method='sparseOptFlow', downscale=2):
"""Initialize a video tracker with specified parameters."""
This class provides methods for tracking and detecting objects based on several tracking algorithms including ORB,
SIFT, ECC, and Sparse Optical Flow. It also supports downscaling of frames for computational efficiency.
Attributes:
method (str): The method used for tracking. Options include 'orb', 'sift', 'ecc', 'sparseOptFlow', 'none'.
downscale (int): Factor by which to downscale the frames for processing.
prevFrame (np.ndarray): Stores the previous frame for tracking.
prevKeyPoints (list): Stores the keypoints from the previous frame.
prevDescriptors (np.ndarray): Stores the descriptors from the previous frame.
initializedFirstFrame (bool): Flag to indicate if the first frame has been processed.
Methods:
__init__(self, method='sparseOptFlow', downscale=2): Initializes a GMC object with the specified method
and downscale factor.
apply(self, raw_frame, detections=None): Applies the chosen method to a raw frame and optionally uses
provided detections.
applyEcc(self, raw_frame, detections=None): Applies the ECC algorithm to a raw frame.
applyFeatures(self, raw_frame, detections=None): Applies feature-based methods like ORB or SIFT to a raw frame.
applySparseOptFlow(self, raw_frame, detections=None): Applies the Sparse Optical Flow method to a raw frame.
"""
def __init__(self, method: str = "sparseOptFlow", downscale: int = 2) -> None:
"""
Initialize a video tracker with specified parameters.
Args:
method (str): The method used for tracking. Options include 'orb', 'sift', 'ecc', 'sparseOptFlow', 'none'.
downscale (int): Downscale factor for processing frames.
"""
super().__init__()
self.method = method
self.downscale = max(1, int(downscale))
if self.method == 'orb':
if self.method == "orb":
self.detector = cv2.FastFeatureDetector_create(20)
self.extractor = cv2.ORB_create()
self.matcher = cv2.BFMatcher(cv2.NORM_HAMMING)
elif self.method == 'sift':
elif self.method == "sift":
self.detector = cv2.SIFT_create(nOctaveLayers=3, contrastThreshold=0.02, edgeThreshold=20)
self.extractor = cv2.SIFT_create(nOctaveLayers=3, contrastThreshold=0.02, edgeThreshold=20)
self.matcher = cv2.BFMatcher(cv2.NORM_L2)
elif self.method == 'ecc':
elif self.method == "ecc":
number_of_iterations = 5000
termination_eps = 1e-6
self.warp_mode = cv2.MOTION_EUCLIDEAN
self.criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, number_of_iterations, termination_eps)
elif self.method == 'sparseOptFlow':
self.feature_params = dict(maxCorners=1000,
qualityLevel=0.01,
minDistance=1,
blockSize=3,
useHarrisDetector=False,
k=0.04)
elif self.method == "sparseOptFlow":
self.feature_params = dict(
maxCorners=1000, qualityLevel=0.01, minDistance=1, blockSize=3, useHarrisDetector=False, k=0.04
)
elif self.method in ['none', 'None', None]:
elif self.method in {"none", "None", None}:
self.method = None
else:
raise ValueError(f'Error: Unknown GMC method:{method}')
raise ValueError(f"Error: Unknown GMC method:{method}")
self.prevFrame = None
self.prevKeyPoints = None
self.prevDescriptors = None
self.initializedFirstFrame = False
def apply(self, raw_frame, detections=None):
"""Apply object detection on a raw frame using specified method."""
if self.method in ['orb', 'sift']:
def apply(self, raw_frame: np.array, detections: list = None) -> np.array:
"""
Apply object detection on a raw frame using specified method.
Args:
raw_frame (np.ndarray): The raw frame to be processed.
detections (list): List of detections to be used in the processing.
Returns:
(np.ndarray): Processed frame.
Examples:
>>> gmc = GMC()
>>> gmc.apply(np.array([[1, 2, 3], [4, 5, 6]]))
array([[1, 2, 3],
[4, 5, 6]])
"""
if self.method in ["orb", "sift"]:
return self.applyFeatures(raw_frame, detections)
elif self.method == 'ecc':
return self.applyEcc(raw_frame, detections)
elif self.method == 'sparseOptFlow':
return self.applySparseOptFlow(raw_frame, detections)
elif self.method == "ecc":
return self.applyEcc(raw_frame)
elif self.method == "sparseOptFlow":
return self.applySparseOptFlow(raw_frame)
else:
return np.eye(2, 3)
def applyEcc(self, raw_frame, detections=None):
"""Initialize."""
def applyEcc(self, raw_frame: np.array) -> np.array:
"""
Apply ECC algorithm to a raw frame.
Args:
raw_frame (np.ndarray): The raw frame to be processed.
Returns:
(np.ndarray): Processed frame.
Examples:
>>> gmc = GMC()
>>> gmc.applyEcc(np.array([[1, 2, 3], [4, 5, 6]]))
array([[1, 2, 3],
[4, 5, 6]])
"""
height, width, _ = raw_frame.shape
frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY)
H = np.eye(2, 3, dtype=np.float32)
# Downscale image (TODO: consider using pyramids)
# Downscale image
if self.downscale > 1.0:
frame = cv2.GaussianBlur(frame, (3, 3), 1.5)
frame = cv2.resize(frame, (width // self.downscale, height // self.downscale))
@ -89,33 +143,46 @@ class GMC:
# Run the ECC algorithm. The results are stored in warp_matrix.
# (cc, H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode, self.criteria)
try:
(cc, H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode, self.criteria, None, 1)
(_, H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode, self.criteria, None, 1)
except Exception as e:
LOGGER.warning(f'WARNING: find transform failed. Set warp as identity {e}')
LOGGER.warning(f"WARNING: find transform failed. Set warp as identity {e}")
return H
def applyFeatures(self, raw_frame, detections=None):
"""Initialize."""
def applyFeatures(self, raw_frame: np.array, detections: list = None) -> np.array:
"""
Apply feature-based methods like ORB or SIFT to a raw frame.
Args:
raw_frame (np.ndarray): The raw frame to be processed.
detections (list): List of detections to be used in the processing.
Returns:
(np.ndarray): Processed frame.
Examples:
>>> gmc = GMC()
>>> gmc.applyFeatures(np.array([[1, 2, 3], [4, 5, 6]]))
array([[1, 2, 3],
[4, 5, 6]])
"""
height, width, _ = raw_frame.shape
frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY)
H = np.eye(2, 3)
# Downscale image (TODO: consider using pyramids)
# Downscale image
if self.downscale > 1.0:
# frame = cv2.GaussianBlur(frame, (3, 3), 1.5)
frame = cv2.resize(frame, (width // self.downscale, height // self.downscale))
width = width // self.downscale
height = height // self.downscale
# Find the keypoints
mask = np.zeros_like(frame)
# mask[int(0.05 * height): int(0.95 * height), int(0.05 * width): int(0.95 * width)] = 255
mask[int(0.02 * height):int(0.98 * height), int(0.02 * width):int(0.98 * width)] = 255
mask[int(0.02 * height) : int(0.98 * height), int(0.02 * width) : int(0.98 * width)] = 255
if detections is not None:
for det in detections:
tlbr = (det[:4] / self.downscale).astype(np.int_)
mask[tlbr[1]:tlbr[3], tlbr[0]:tlbr[2]] = 0
mask[tlbr[1] : tlbr[3], tlbr[0] : tlbr[2]] = 0
keypoints = self.detector.detect(frame, mask)
@ -134,10 +201,10 @@ class GMC:
return H
# Match descriptors.
# Match descriptors
knnMatches = self.matcher.knnMatch(self.prevDescriptors, descriptors, 2)
# Filtered matches based on smallest spatial distance
# Filter matches based on smallest spatial distance
matches = []
spatialDistances = []
@ -157,11 +224,14 @@ class GMC:
prevKeyPointLocation = self.prevKeyPoints[m.queryIdx].pt
currKeyPointLocation = keypoints[m.trainIdx].pt
spatialDistance = (prevKeyPointLocation[0] - currKeyPointLocation[0],
prevKeyPointLocation[1] - currKeyPointLocation[1])
spatialDistance = (
prevKeyPointLocation[0] - currKeyPointLocation[0],
prevKeyPointLocation[1] - currKeyPointLocation[1],
)
if (np.abs(spatialDistance[0]) < maxSpatialDistance[0]) and \
(np.abs(spatialDistance[1]) < maxSpatialDistance[1]):
if (np.abs(spatialDistance[0]) < maxSpatialDistance[0]) and (
np.abs(spatialDistance[1]) < maxSpatialDistance[1]
):
spatialDistances.append(spatialDistance)
matches.append(m)
@ -187,7 +257,7 @@ class GMC:
# import matplotlib.pyplot as plt
# matches_img = np.hstack((self.prevFrame, frame))
# matches_img = cv2.cvtColor(matches_img, cv2.COLOR_GRAY2BGR)
# W = np.size(self.prevFrame, 1)
# W = self.prevFrame.shape[1]
# for m in goodMatches:
# prev_pt = np.array(self.prevKeyPoints[m.queryIdx].pt, dtype=np.int_)
# curr_pt = np.array(keypoints[m.trainIdx].pt, dtype=np.int_)
@ -204,7 +274,7 @@ class GMC:
# plt.show()
# Find rigid matrix
if (np.size(prevPoints, 0) > 4) and (np.size(prevPoints, 0) == np.size(prevPoints, 0)):
if prevPoints.shape[0] > 4:
H, inliers = cv2.estimateAffinePartial2D(prevPoints, currPoints, cv2.RANSAC)
# Handle downscale
@ -212,7 +282,7 @@ class GMC:
H[0, 2] *= self.downscale
H[1, 2] *= self.downscale
else:
LOGGER.warning('WARNING: not enough matching points')
LOGGER.warning("WARNING: not enough matching points")
# Store to next iteration
self.prevFrame = frame.copy()
@ -221,15 +291,28 @@ class GMC:
return H
def applySparseOptFlow(self, raw_frame, detections=None):
"""Initialize."""
def applySparseOptFlow(self, raw_frame: np.array) -> np.array:
"""
Apply Sparse Optical Flow method to a raw frame.
Args:
raw_frame (np.ndarray): The raw frame to be processed.
Returns:
(np.ndarray): Processed frame.
Examples:
>>> gmc = GMC()
>>> gmc.applySparseOptFlow(np.array([[1, 2, 3], [4, 5, 6]]))
array([[1, 2, 3],
[4, 5, 6]])
"""
height, width, _ = raw_frame.shape
frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY)
H = np.eye(2, 3)
# Downscale image
if self.downscale > 1.0:
# frame = cv2.GaussianBlur(frame, (3, 3), 1.5)
frame = cv2.resize(frame, (width // self.downscale, height // self.downscale))
# Find the keypoints
@ -237,17 +320,13 @@ class GMC:
# Handle first frame
if not self.initializedFirstFrame:
# Initialize data
self.prevFrame = frame.copy()
self.prevKeyPoints = copy.copy(keypoints)
# Initialization done
self.initializedFirstFrame = True
return H
# Find correspondences
matchedKeypoints, status, err = cv2.calcOpticalFlowPyrLK(self.prevFrame, frame, self.prevKeyPoints, None)
matchedKeypoints, status, _ = cv2.calcOpticalFlowPyrLK(self.prevFrame, frame, self.prevKeyPoints, None)
# Leave good correspondences only
prevPoints = []
@ -262,18 +341,23 @@ class GMC:
currPoints = np.array(currPoints)
# Find rigid matrix
if (np.size(prevPoints, 0) > 4) and (np.size(prevPoints, 0) == np.size(prevPoints, 0)):
H, inliers = cv2.estimateAffinePartial2D(prevPoints, currPoints, cv2.RANSAC)
if (prevPoints.shape[0] > 4) and (prevPoints.shape[0] == prevPoints.shape[0]):
H, _ = cv2.estimateAffinePartial2D(prevPoints, currPoints, cv2.RANSAC)
# Handle downscale
if self.downscale > 1.0:
H[0, 2] *= self.downscale
H[1, 2] *= self.downscale
else:
LOGGER.warning('WARNING: not enough matching points')
LOGGER.warning("WARNING: not enough matching points")
# Store to next iteration
self.prevFrame = frame.copy()
self.prevKeyPoints = copy.copy(keypoints)
return H
def reset_params(self) -> None:
"""Reset parameters."""
self.prevFrame = None
self.prevKeyPoints = None
self.prevDescriptors = None
self.initializedFirstFrame = False

View File

@ -8,8 +8,8 @@ class KalmanFilterXYAH:
"""
For bytetrack. A simple Kalman filter for tracking bounding boxes in image space.
The 8-dimensional state space (x, y, a, h, vx, vy, va, vh) contains the bounding box center position (x, y),
aspect ratio a, height h, and their respective velocities.
The 8-dimensional state space (x, y, a, h, vx, vy, va, vh) contains the bounding box center position (x, y), aspect
ratio a, height h, and their respective velocities.
Object motion follows a constant velocity model. The bounding box location (x, y, a, h) is taken as direct
observation of the state space (linear observation model).
@ -17,126 +17,126 @@ class KalmanFilterXYAH:
def __init__(self):
"""Initialize Kalman filter model matrices with motion and observation uncertainty weights."""
ndim, dt = 4, 1.
ndim, dt = 4, 1.0
# Create Kalman filter model matrices.
# Create Kalman filter model matrices
self._motion_mat = np.eye(2 * ndim, 2 * ndim)
for i in range(ndim):
self._motion_mat[i, ndim + i] = dt
self._update_mat = np.eye(ndim, 2 * ndim)
# Motion and observation uncertainty are chosen relative to the current state estimate. These weights control
# the amount of uncertainty in the model. This is a bit hacky.
self._std_weight_position = 1. / 20
self._std_weight_velocity = 1. / 160
# the amount of uncertainty in the model.
self._std_weight_position = 1.0 / 20
self._std_weight_velocity = 1.0 / 160
def initiate(self, measurement):
def initiate(self, measurement: np.ndarray) -> tuple:
"""
Create track from unassociated measurement.
Parameters
----------
measurement : ndarray
Bounding box coordinates (x, y, a, h) with center position (x, y),
aspect ratio a, and height h.
Args:
measurement (ndarray): Bounding box coordinates (x, y, a, h) with center position (x, y), aspect ratio a,
and height h.
Returns
-------
(ndarray, ndarray)
Returns the mean vector (8 dimensional) and covariance matrix (8x8
dimensional) of the new track. Unobserved velocities are initialized
to 0 mean.
Returns:
(tuple[ndarray, ndarray]): Returns the mean vector (8 dimensional) and covariance matrix (8x8 dimensional) of
the new track. Unobserved velocities are initialized to 0 mean.
"""
mean_pos = measurement
mean_vel = np.zeros_like(mean_pos)
mean = np.r_[mean_pos, mean_vel]
std = [
2 * self._std_weight_position * measurement[3], 2 * self._std_weight_position * measurement[3], 1e-2,
2 * self._std_weight_position * measurement[3], 10 * self._std_weight_velocity * measurement[3],
10 * self._std_weight_velocity * measurement[3], 1e-5, 10 * self._std_weight_velocity * measurement[3]]
2 * self._std_weight_position * measurement[3],
2 * self._std_weight_position * measurement[3],
1e-2,
2 * self._std_weight_position * measurement[3],
10 * self._std_weight_velocity * measurement[3],
10 * self._std_weight_velocity * measurement[3],
1e-5,
10 * self._std_weight_velocity * measurement[3],
]
covariance = np.diag(np.square(std))
return mean, covariance
def predict(self, mean, covariance):
def predict(self, mean: np.ndarray, covariance: np.ndarray) -> tuple:
"""
Run Kalman filter prediction step.
Parameters
----------
mean : ndarray
The 8 dimensional mean vector of the object state at the previous time step.
covariance : ndarray
The 8x8 dimensional covariance matrix of the object state at the previous time step.
Args:
mean (ndarray): The 8 dimensional mean vector of the object state at the previous time step.
covariance (ndarray): The 8x8 dimensional covariance matrix of the object state at the previous time step.
Returns
-------
(ndarray, ndarray)
Returns the mean vector and covariance matrix of the predicted state. Unobserved velocities are
initialized to 0 mean.
Returns:
(tuple[ndarray, ndarray]): Returns the mean vector and covariance matrix of the predicted state. Unobserved
velocities are initialized to 0 mean.
"""
std_pos = [
self._std_weight_position * mean[3], self._std_weight_position * mean[3], 1e-2,
self._std_weight_position * mean[3]]
self._std_weight_position * mean[3],
self._std_weight_position * mean[3],
1e-2,
self._std_weight_position * mean[3],
]
std_vel = [
self._std_weight_velocity * mean[3], self._std_weight_velocity * mean[3], 1e-5,
self._std_weight_velocity * mean[3]]
self._std_weight_velocity * mean[3],
self._std_weight_velocity * mean[3],
1e-5,
self._std_weight_velocity * mean[3],
]
motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))
# mean = np.dot(self._motion_mat, mean)
mean = np.dot(mean, self._motion_mat.T)
covariance = np.linalg.multi_dot((self._motion_mat, covariance, self._motion_mat.T)) + motion_cov
return mean, covariance
def project(self, mean, covariance):
def project(self, mean: np.ndarray, covariance: np.ndarray) -> tuple:
"""
Project state distribution to measurement space.
Parameters
----------
mean : ndarray
The state's mean vector (8 dimensional array).
covariance : ndarray
The state's covariance matrix (8x8 dimensional).
Args:
mean (ndarray): The state's mean vector (8 dimensional array).
covariance (ndarray): The state's covariance matrix (8x8 dimensional).
Returns
-------
(ndarray, ndarray)
Returns the projected mean and covariance matrix of the given state estimate.
Returns:
(tuple[ndarray, ndarray]): Returns the projected mean and covariance matrix of the given state estimate.
"""
std = [
self._std_weight_position * mean[3], self._std_weight_position * mean[3], 1e-1,
self._std_weight_position * mean[3]]
self._std_weight_position * mean[3],
self._std_weight_position * mean[3],
1e-1,
self._std_weight_position * mean[3],
]
innovation_cov = np.diag(np.square(std))
mean = np.dot(self._update_mat, mean)
covariance = np.linalg.multi_dot((self._update_mat, covariance, self._update_mat.T))
return mean, covariance + innovation_cov
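The observation model here is linear: `_update_mat` is `[I4 | 0]`, so projecting simply drops the velocity block and adds the measurement noise `innovation_cov`. A quick sketch of that identity:

```python
import numpy as np

H = np.hstack([np.eye(4), np.zeros((4, 4))])  # mirrors self._update_mat
mean = np.array([320.0, 240.0, 0.5, 100.0, 2.0, -1.0, 0.0, 0.0])
assert np.allclose(H @ mean, mean[:4])  # projected mean is just the box part
```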
def multi_predict(self, mean, covariance):
def multi_predict(self, mean: np.ndarray, covariance: np.ndarray) -> tuple:
"""
Run Kalman filter prediction step (Vectorized version).
Parameters
----------
mean : ndarray
The Nx8 dimensional mean matrix of the object states at the previous time step.
covariance : ndarray
The Nx8x8 dimensional covariance matrix of the object states at the previous time step.
Args:
mean (ndarray): The Nx8 dimensional mean matrix of the object states at the previous time step.
covariance (ndarray): The Nx8x8 covariance matrix of the object states at the previous time step.
Returns
-------
(ndarray, ndarray)
Returns the mean vector and covariance matrix of the predicted state. Unobserved velocities are
initialized to 0 mean.
Returns:
(tuple[ndarray, ndarray]): Returns the mean vector and covariance matrix of the predicted state. Unobserved
velocities are initialized to 0 mean.
"""
std_pos = [
self._std_weight_position * mean[:, 3], self._std_weight_position * mean[:, 3],
1e-2 * np.ones_like(mean[:, 3]), self._std_weight_position * mean[:, 3]]
self._std_weight_position * mean[:, 3],
self._std_weight_position * mean[:, 3],
1e-2 * np.ones_like(mean[:, 3]),
self._std_weight_position * mean[:, 3],
]
std_vel = [
self._std_weight_velocity * mean[:, 3], self._std_weight_velocity * mean[:, 3],
1e-5 * np.ones_like(mean[:, 3]), self._std_weight_velocity * mean[:, 3]]
self._std_weight_velocity * mean[:, 3],
self._std_weight_velocity * mean[:, 3],
1e-5 * np.ones_like(mean[:, 3]),
self._std_weight_velocity * mean[:, 3],
]
sqr = np.square(np.r_[std_pos, std_vel]).T
motion_cov = [np.diag(sqr[i]) for i in range(len(mean))]
@@ -148,60 +148,57 @@ class KalmanFilterXYAH:
return mean, covariance
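A sketch of the vectorized call, which predicts all active tracks in one shot instead of looping over `predict` (module path assumed as above):

```python
import numpy as np

from ultralytics.trackers.utils.kalman_filter import KalmanFilterXYAH  # assumed module path

kf = KalmanFilterXYAH()
states = [kf.initiate(np.array([x, 240.0, 0.5, 80.0])) for x in (100.0, 200.0, 300.0)]
means = np.stack([m for m, _ in states])  # (3, 8)
covs = np.stack([c for _, c in states])   # (3, 8, 8)

means, covs = kf.multi_predict(means, covs)
print(means.shape, covs.shape)  # (3, 8) (3, 8, 8)
```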
def update(self, mean, covariance, measurement):
def update(self, mean: np.ndarray, covariance: np.ndarray, measurement: np.ndarray) -> tuple:
"""
Run Kalman filter correction step.
Parameters
----------
mean : ndarray
The predicted state's mean vector (8 dimensional).
covariance : ndarray
The state's covariance matrix (8x8 dimensional).
measurement : ndarray
The 4 dimensional measurement vector (x, y, a, h), where (x, y) is the center position, a the aspect
ratio, and h the height of the bounding box.
Args:
mean (ndarray): The predicted state's mean vector (8 dimensional).
covariance (ndarray): The state's covariance matrix (8x8 dimensional).
measurement (ndarray): The 4 dimensional measurement vector (x, y, a, h), where (x, y) is the center
position, a the aspect ratio, and h the height of the bounding box.
Returns
-------
(ndarray, ndarray)
Returns the measurement-corrected state distribution.
Returns:
(tuple[ndarray, ndarray]): Returns the measurement-corrected state distribution.
"""
projected_mean, projected_cov = self.project(mean, covariance)
chol_factor, lower = scipy.linalg.cho_factor(projected_cov, lower=True, check_finite=False)
kalman_gain = scipy.linalg.cho_solve((chol_factor, lower),
np.dot(covariance, self._update_mat.T).T,
check_finite=False).T
kalman_gain = scipy.linalg.cho_solve(
(chol_factor, lower), np.dot(covariance, self._update_mat.T).T, check_finite=False
).T
innovation = measurement - projected_mean
new_mean = mean + np.dot(innovation, kalman_gain.T)
new_covariance = covariance - np.linalg.multi_dot((kalman_gain, projected_cov, kalman_gain.T))
return new_mean, new_covariance
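Putting prediction and correction together, a minimal one-frame cycle (again assuming the module path above); the Cholesky-based `cho_solve` is simply a numerically stable way to compute the Kalman gain without an explicit matrix inverse:

```python
import numpy as np

from ultralytics.trackers.utils.kalman_filter import KalmanFilterXYAH  # assumed module path

kf = KalmanFilterXYAH()
mean, cov = kf.initiate(np.array([320.0, 240.0, 0.5, 100.0]))

mean, cov = kf.predict(mean, cov)         # project the state forward one frame
z = np.array([323.0, 239.0, 0.5, 101.0])  # next-frame detection
mean, cov = kf.update(mean, cov, z)       # correct with the measurement

print(mean[:2])  # corrected center, pulled between prediction and measurement
```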
def gating_distance(self, mean, covariance, measurements, only_position=False, metric='maha'):
def gating_distance(
self,
mean: np.ndarray,
covariance: np.ndarray,
measurements: np.ndarray,
only_position: bool = False,
metric: str = "maha",
) -> np.ndarray:
"""
Compute gating distance between state distribution and measurements. A suitable distance threshold can be
obtained from `chi2inv95`. If `only_position` is False, the chi-square distribution has 4 degrees of
freedom, otherwise 2.
obtained from `chi2inv95`. If `only_position` is False, the chi-square distribution has 4 degrees of freedom,
otherwise 2.
Parameters
----------
mean : ndarray
Mean vector over the state distribution (8 dimensional).
covariance : ndarray
Covariance of the state distribution (8x8 dimensional).
measurements : ndarray
An Nx4 dimensional matrix of N measurements, each in format (x, y, a, h) where (x, y) is the bounding box
center position, a the aspect ratio, and h the height.
only_position : Optional[bool]
If True, distance computation is done with respect to the bounding box center position only.
Args:
mean (ndarray): Mean vector over the state distribution (8 dimensional).
covariance (ndarray): Covariance of the state distribution (8x8 dimensional).
measurements (ndarray): An Nx4 matrix of N measurements, each in format (x, y, a, h) where (x, y)
is the bounding box center position, a the aspect ratio, and h the height.
only_position (bool, optional): If True, distance computation is done with respect to the bounding box
center position only. Defaults to False.
metric (str, optional): The metric to use for calculating the distance. Options are 'gaussian' for the
squared Euclidean distance and 'maha' for the squared Mahalanobis distance. Defaults to 'maha'.
Returns
-------
ndarray
Returns an array of length N, where the i-th element contains the squared Mahalanobis distance between
(mean, covariance) and `measurements[i]`.
Returns:
(np.ndarray): Returns an array of length N, where the i-th element contains the squared distance between
(mean, covariance) and `measurements[i]`.
"""
mean, covariance = self.project(mean, covariance)
if only_position:
@@ -209,77 +206,79 @@ class KalmanFilterXYAH:
measurements = measurements[:, :2]
d = measurements - mean
if metric == 'gaussian':
if metric == "gaussian":
return np.sum(d * d, axis=1)
elif metric == 'maha':
elif metric == "maha":
cholesky_factor = np.linalg.cholesky(covariance)
z = scipy.linalg.solve_triangular(cholesky_factor, d.T, lower=True, check_finite=False, overwrite_b=True)
return np.sum(z * z, axis=0) # square maha
else:
raise ValueError('invalid distance metric')
raise ValueError("Invalid distance metric")
class KalmanFilterXYWH(KalmanFilterXYAH):
"""
For BoT-SORT. A simple Kalman filter for tracking bounding boxes in image space.
The 8-dimensional state space (x, y, w, h, vx, vy, vw, vh) contains the bounding box center position (x, y),
width w, height h, and their respective velocities.
The 8-dimensional state space (x, y, w, h, vx, vy, vw, vh) contains the bounding box center position (x, y), width
w, height h, and their respective velocities.
Object motion follows a constant velocity model. The bounding box location (x, y, w, h) is taken as direct
observation of the state space (linear observation model).
"""
def initiate(self, measurement):
def initiate(self, measurement: np.ndarray) -> tuple:
"""
Create track from unassociated measurement.
Parameters
----------
measurement : ndarray
Bounding box coordinates (x, y, w, h) with center position (x, y), width w, and height h.
Args:
measurement (ndarray): Bounding box coordinates (x, y, w, h) with center position (x, y), width w, and height h.
Returns
-------
(ndarray, ndarray)
Returns the mean vector (8 dimensional) and covariance matrix (8x8 dimensional) of the new track.
Unobserved velocities are initialized to 0 mean.
Returns:
(tuple[ndarray, ndarray]): Returns the mean vector (8 dimensional) and covariance matrix (8x8 dimensional) of
the new track. Unobserved velocities are initialized to 0 mean.
"""
mean_pos = measurement
mean_vel = np.zeros_like(mean_pos)
mean = np.r_[mean_pos, mean_vel]
std = [
2 * self._std_weight_position * measurement[2], 2 * self._std_weight_position * measurement[3],
2 * self._std_weight_position * measurement[2], 2 * self._std_weight_position * measurement[3],
10 * self._std_weight_velocity * measurement[2], 10 * self._std_weight_velocity * measurement[3],
10 * self._std_weight_velocity * measurement[2], 10 * self._std_weight_velocity * measurement[3]]
2 * self._std_weight_position * measurement[2],
2 * self._std_weight_position * measurement[3],
2 * self._std_weight_position * measurement[2],
2 * self._std_weight_position * measurement[3],
10 * self._std_weight_velocity * measurement[2],
10 * self._std_weight_velocity * measurement[3],
10 * self._std_weight_velocity * measurement[2],
10 * self._std_weight_velocity * measurement[3],
]
covariance = np.diag(np.square(std))
return mean, covariance
def predict(self, mean, covariance):
def predict(self, mean: np.ndarray, covariance: np.ndarray) -> tuple:
"""
Run Kalman filter prediction step.
Parameters
----------
mean : ndarray
The 8 dimensional mean vector of the object state at the previous time step.
covariance : ndarray
The 8x8 dimensional covariance matrix of the object state at the previous time step.
Args:
mean (ndarray): The 8 dimensional mean vector of the object state at the previous time step.
covariance (ndarray): The 8x8 dimensional covariance matrix of the object state at the previous time step.
Returns
-------
(ndarray, ndarray)
Returns the mean vector and covariance matrix of the predicted state. Unobserved velocities are
initialized to 0 mean.
Returns:
(tuple[ndarray, ndarray]): Returns the mean vector and covariance matrix of the predicted state. Unobserved
velocities are initialized to 0 mean.
"""
std_pos = [
self._std_weight_position * mean[2], self._std_weight_position * mean[3],
self._std_weight_position * mean[2], self._std_weight_position * mean[3]]
self._std_weight_position * mean[2],
self._std_weight_position * mean[3],
self._std_weight_position * mean[2],
self._std_weight_position * mean[3],
]
std_vel = [
self._std_weight_velocity * mean[2], self._std_weight_velocity * mean[3],
self._std_weight_velocity * mean[2], self._std_weight_velocity * mean[3]]
self._std_weight_velocity * mean[2],
self._std_weight_velocity * mean[3],
self._std_weight_velocity * mean[2],
self._std_weight_velocity * mean[3],
]
motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))
mean = np.dot(mean, self._motion_mat.T)
@@ -287,54 +286,53 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
return mean, covariance
def project(self, mean, covariance):
def project(self, mean: np.ndarray, covariance: np.ndarray) -> tuple:
"""
Project state distribution to measurement space.
Parameters
----------
mean : ndarray
The state's mean vector (8 dimensional array).
covariance : ndarray
The state's covariance matrix (8x8 dimensional).
Args:
mean (ndarray): The state's mean vector (8 dimensional array).
covariance (ndarray): The state's covariance matrix (8x8 dimensional).
Returns
-------
(ndarray, ndarray)
Returns the projected mean and covariance matrix of the given state estimate.
Returns:
(tuple[ndarray, ndarray]): Returns the projected mean and covariance matrix of the given state estimate.
"""
std = [
self._std_weight_position * mean[2], self._std_weight_position * mean[3],
self._std_weight_position * mean[2], self._std_weight_position * mean[3]]
self._std_weight_position * mean[2],
self._std_weight_position * mean[3],
self._std_weight_position * mean[2],
self._std_weight_position * mean[3],
]
innovation_cov = np.diag(np.square(std))
mean = np.dot(self._update_mat, mean)
covariance = np.linalg.multi_dot((self._update_mat, covariance, self._update_mat.T))
return mean, covariance + innovation_cov
def multi_predict(self, mean, covariance):
def multi_predict(self, mean: np.ndarray, covariance: np.ndarray) -> tuple:
"""
Run Kalman filter prediction step (Vectorized version).
Parameters
----------
mean : ndarray
The Nx8 dimensional mean matrix of the object states at the previous time step.
covariance : ndarray
The Nx8x8 dimensional covariance matrix of the object states at the previous time step.
Args:
mean (ndarray): The Nx8 dimensional mean matrix of the object states at the previous time step.
covariance (ndarray): The Nx8x8 covariance matrix of the object states at the previous time step.
Returns
-------
(ndarray, ndarray)
Returns the mean vector and covariance matrix of the predicted state. Unobserved velocities are
initialized to 0 mean.
Returns:
(tuple[ndarray, ndarray]): Returns the mean vector and covariance matrix of the predicted state. Unobserved
velocities are initialized to 0 mean.
"""
std_pos = [
self._std_weight_position * mean[:, 2], self._std_weight_position * mean[:, 3],
self._std_weight_position * mean[:, 2], self._std_weight_position * mean[:, 3]]
self._std_weight_position * mean[:, 2],
self._std_weight_position * mean[:, 3],
self._std_weight_position * mean[:, 2],
self._std_weight_position * mean[:, 3],
]
std_vel = [
self._std_weight_velocity * mean[:, 2], self._std_weight_velocity * mean[:, 3],
self._std_weight_velocity * mean[:, 2], self._std_weight_velocity * mean[:, 3]]
self._std_weight_velocity * mean[:, 2],
self._std_weight_velocity * mean[:, 3],
self._std_weight_velocity * mean[:, 2],
self._std_weight_velocity * mean[:, 3],
]
sqr = np.square(np.r_[std_pos, std_vel]).T
motion_cov = [np.diag(sqr[i]) for i in range(len(mean))]
@@ -346,23 +344,17 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
return mean, covariance
def update(self, mean, covariance, measurement):
def update(self, mean: np.ndarray, covariance: np.ndarray, measurement: np.ndarray) -> tuple:
"""
Run Kalman filter correction step.
Parameters
----------
mean : ndarray
The predicted state's mean vector (8 dimensional).
covariance : ndarray
The state's covariance matrix (8x8 dimensional).
measurement : ndarray
The 4 dimensional measurement vector (x, y, w, h), where (x, y) is the center position, w the width,
and h the height of the bounding box.
Args:
mean (ndarray): The predicted state's mean vector (8 dimensional).
covariance (ndarray): The state's covariance matrix (8x8 dimensional).
measurement (ndarray): The 4 dimensional measurement vector (x, y, w, h), where (x, y) is the center
position, w the width, and h the height of the bounding box.
Returns
-------
(ndarray, ndarray)
Returns the measurement-corrected state distribution.
Returns:
(tuple[ndarray, ndarray]): Returns the measurement-corrected state distribution.
"""
return super().update(mean, covariance, measurement)
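Since `KalmanFilterXYWH` only swaps the box parameterization (with noise scales driven by both width and height), usage mirrors the XYAH filter; a short sketch under the same assumed module path:

```python
import numpy as np

from ultralytics.trackers.utils.kalman_filter import KalmanFilterXYWH  # assumed module path

kf = KalmanFilterXYWH()
mean, cov = kf.initiate(np.array([320.0, 240.0, 50.0, 100.0]))  # (x, y, w, h)
mean, cov = kf.predict(mean, cov)
mean, cov = kf.update(mean, cov, np.array([322.0, 241.0, 51.0, 101.0]))
print(mean[:4])  # corrected (x, y, w, h)
```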

View File

@@ -4,7 +4,7 @@ import numpy as np
import scipy
from scipy.spatial.distance import cdist
from ultralytics.utils.metrics import bbox_ioa
from ultralytics.utils.metrics import bbox_ioa, batch_probiou
try:
import lap # for linear_assignment
@@ -13,11 +13,11 @@ try:
except (ImportError, AssertionError, AttributeError):
from ultralytics.utils.checks import check_requirements
check_requirements('lapx>=0.5.2') # update to lap package from https://github.com/rathaROG/lapx
check_requirements("lapx>=0.5.2") # update to lap package from https://github.com/rathaROG/lapx
import lap
def linear_assignment(cost_matrix, thresh, use_lap=True):
def linear_assignment(cost_matrix: np.ndarray, thresh: float, use_lap: bool = True) -> tuple:
"""
Perform linear assignment using scipy or lap.lapjv.
@@ -27,19 +27,24 @@ def linear_assignment(cost_matrix, thresh, use_lap=True):
use_lap (bool, optional): Whether to use lap.lapjv. Defaults to True.
Returns:
(tuple): Tuple containing matched indices, unmatched indices from 'a', and unmatched indices from 'b'.
Tuple with:
- matched indices
- unmatched indices from 'a'
- unmatched indices from 'b'
"""
if cost_matrix.size == 0:
return np.empty((0, 2), dtype=int), tuple(range(cost_matrix.shape[0])), tuple(range(cost_matrix.shape[1]))
if use_lap:
# Use lap.lapjv
# https://github.com/gatagat/lap
_, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh)
matches = [[ix, mx] for ix, mx in enumerate(x) if mx >= 0]
unmatched_a = np.where(x < 0)[0]
unmatched_b = np.where(y < 0)[0]
else:
# Use scipy.optimize.linear_sum_assignment
# https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.linear_sum_assignment.html
x, y = scipy.optimize.linear_sum_assignment(cost_matrix) # row x, col y
matches = np.asarray([[x[i], y[i]] for i in range(len(x)) if cost_matrix[x[i], y[i]] <= thresh])
@@ -53,7 +58,7 @@ def linear_assignment(cost_matrix, thresh, use_lap=True):
return matches, unmatched_a, unmatched_b
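A small worked example of the assignment semantics (assuming the function lives in `ultralytics.trackers.utils.matching`); any pairing whose cost exceeds `thresh` is left unmatched:

```python
import numpy as np

from ultralytics.trackers.utils.matching import linear_assignment  # assumed module path

cost = np.array([
    [0.1, 0.9],
    [0.8, 0.2],
    [0.7, 0.6],  # this row's best cost still exceeds the threshold
])
matches, unmatched_a, unmatched_b = linear_assignment(cost, thresh=0.5)
print(matches)      # [[0, 0], [1, 1]] -> rows 0/1 matched to cols 0/1
print(unmatched_a)  # [2]
print(unmatched_b)  # []
```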
def iou_distance(atracks, btracks):
def iou_distance(atracks: list, btracks: list) -> np.ndarray:
"""
Compute cost based on Intersection over Union (IoU) between tracks.
@@ -65,23 +70,30 @@ def iou_distance(atracks, btracks):
(np.ndarray): Cost matrix computed based on IoU.
"""
if (len(atracks) > 0 and isinstance(atracks[0], np.ndarray)) \
or (len(btracks) > 0 and isinstance(btracks[0], np.ndarray)):
if atracks and isinstance(atracks[0], np.ndarray) or btracks and isinstance(btracks[0], np.ndarray):
atlbrs = atracks
btlbrs = btracks
else:
atlbrs = [track.tlbr for track in atracks]
btlbrs = [track.tlbr for track in btracks]
atlbrs = [track.xywha if track.angle is not None else track.xyxy for track in atracks]
btlbrs = [track.xywha if track.angle is not None else track.xyxy for track in btracks]
ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float32)
if len(atlbrs) and len(btlbrs):
ious = bbox_ioa(np.ascontiguousarray(atlbrs, dtype=np.float32),
np.ascontiguousarray(btlbrs, dtype=np.float32),
iou=True)
if len(atlbrs[0]) == 5 and len(btlbrs[0]) == 5:
ious = batch_probiou(
np.ascontiguousarray(atlbrs, dtype=np.float32),
np.ascontiguousarray(btlbrs, dtype=np.float32),
).numpy()
else:
ious = bbox_ioa(
np.ascontiguousarray(atlbrs, dtype=np.float32),
np.ascontiguousarray(btlbrs, dtype=np.float32),
iou=True,
)
return 1 - ious # cost matrix
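The new branch means raw boxes also work directly: 4-element rows are treated as axis-aligned `xyxy` and scored with `bbox_ioa(iou=True)`, while 5-element `(x, y, w, h, angle)` rows are routed through `batch_probiou` for rotated (OBB) tracking. A sketch with plain arrays:

```python
import numpy as np

from ultralytics.trackers.utils.matching import iou_distance  # assumed module path

a = [np.array([0.0, 0.0, 10.0, 10.0]), np.array([20.0, 20.0, 30.0, 30.0])]
b = [np.array([0.0, 0.0, 10.0, 10.0])]
print(iou_distance(a, b))  # [[0.] [1.]] -> identical box costs 0, disjoint box costs 1
```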
def embedding_distance(tracks, detections, metric='cosine'):
def embedding_distance(tracks: list, detections: list, metric: str = "cosine") -> np.ndarray:
"""
Compute distance between tracks and detections based on embeddings.
@@ -105,7 +117,7 @@ def embedding_distance(tracks, detections, metric='cosine'):
return cost_matrix
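The core of `embedding_distance` is a clipped `cdist` between track and detection feature vectors; a sketch of that computation on synthetic embeddings (the real function reads the feature vectors off track/detection objects):

```python
import numpy as np
from scipy.spatial.distance import cdist

track_feats = np.random.rand(3, 128).astype(np.float32)  # e.g. smoothed track embeddings
det_feats = np.random.rand(2, 128).astype(np.float32)    # current detection embeddings
cost = np.maximum(0.0, cdist(track_feats, det_feats, metric="cosine"))
print(cost.shape)  # (3, 2) cost matrix, one row per track
```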
def fuse_score(cost_matrix, detections):
def fuse_score(cost_matrix: np.ndarray, detections: list) -> np.ndarray:
"""
Fuses cost matrix with detection scores to produce a single similarity matrix.