add yolo v10 and modify pipeline

This commit is contained in:
王庆刚
2025-03-28 13:19:54 +08:00
parent 183299c06b
commit 798c596acc
471 changed files with 19109 additions and 7342 deletions

View File

@ -9,67 +9,121 @@ from ultralytics.utils import LOGGER
class GMC:
"""
Generalized Motion Compensation (GMC) class for tracking and object detection in video frames.
def __init__(self, method='sparseOptFlow', downscale=2):
"""Initialize a video tracker with specified parameters."""
This class provides methods for tracking and detecting objects based on several tracking algorithms including ORB,
SIFT, ECC, and Sparse Optical Flow. It also supports downscaling of frames for computational efficiency.
Attributes:
method (str): The method used for tracking. Options include 'orb', 'sift', 'ecc', 'sparseOptFlow', 'none'.
downscale (int): Factor by which to downscale the frames for processing.
prevFrame (np.ndarray): Stores the previous frame for tracking.
prevKeyPoints (list): Stores the keypoints from the previous frame.
prevDescriptors (np.ndarray): Stores the descriptors from the previous frame.
initializedFirstFrame (bool): Flag to indicate if the first frame has been processed.
Methods:
__init__(self, method='sparseOptFlow', downscale=2): Initializes a GMC object with the specified method
and downscale factor.
apply(self, raw_frame, detections=None): Applies the chosen method to a raw frame and optionally uses
provided detections.
applyEcc(self, raw_frame, detections=None): Applies the ECC algorithm to a raw frame.
applyFeatures(self, raw_frame, detections=None): Applies feature-based methods like ORB or SIFT to a raw frame.
applySparseOptFlow(self, raw_frame, detections=None): Applies the Sparse Optical Flow method to a raw frame.
"""
def __init__(self, method: str = "sparseOptFlow", downscale: int = 2) -> None:
"""
Initialize a video tracker with specified parameters.
Args:
method (str): The method used for tracking. Options include 'orb', 'sift', 'ecc', 'sparseOptFlow', 'none'.
downscale (int): Downscale factor for processing frames.
"""
super().__init__()
self.method = method
self.downscale = max(1, int(downscale))
if self.method == 'orb':
if self.method == "orb":
self.detector = cv2.FastFeatureDetector_create(20)
self.extractor = cv2.ORB_create()
self.matcher = cv2.BFMatcher(cv2.NORM_HAMMING)
elif self.method == 'sift':
elif self.method == "sift":
self.detector = cv2.SIFT_create(nOctaveLayers=3, contrastThreshold=0.02, edgeThreshold=20)
self.extractor = cv2.SIFT_create(nOctaveLayers=3, contrastThreshold=0.02, edgeThreshold=20)
self.matcher = cv2.BFMatcher(cv2.NORM_L2)
elif self.method == 'ecc':
elif self.method == "ecc":
number_of_iterations = 5000
termination_eps = 1e-6
self.warp_mode = cv2.MOTION_EUCLIDEAN
self.criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, number_of_iterations, termination_eps)
elif self.method == 'sparseOptFlow':
self.feature_params = dict(maxCorners=1000,
qualityLevel=0.01,
minDistance=1,
blockSize=3,
useHarrisDetector=False,
k=0.04)
elif self.method == "sparseOptFlow":
self.feature_params = dict(
maxCorners=1000, qualityLevel=0.01, minDistance=1, blockSize=3, useHarrisDetector=False, k=0.04
)
elif self.method in ['none', 'None', None]:
elif self.method in {"none", "None", None}:
self.method = None
else:
raise ValueError(f'Error: Unknown GMC method:{method}')
raise ValueError(f"Error: Unknown GMC method:{method}")
self.prevFrame = None
self.prevKeyPoints = None
self.prevDescriptors = None
self.initializedFirstFrame = False
def apply(self, raw_frame, detections=None):
"""Apply object detection on a raw frame using specified method."""
if self.method in ['orb', 'sift']:
def apply(self, raw_frame: np.array, detections: list = None) -> np.array:
"""
Apply object detection on a raw frame using specified method.
Args:
raw_frame (np.ndarray): The raw frame to be processed.
detections (list): List of detections to be used in the processing.
Returns:
(np.ndarray): Processed frame.
Examples:
>>> gmc = GMC()
>>> gmc.apply(np.array([[1, 2, 3], [4, 5, 6]]))
array([[1, 2, 3],
[4, 5, 6]])
"""
if self.method in ["orb", "sift"]:
return self.applyFeatures(raw_frame, detections)
elif self.method == 'ecc':
return self.applyEcc(raw_frame, detections)
elif self.method == 'sparseOptFlow':
return self.applySparseOptFlow(raw_frame, detections)
elif self.method == "ecc":
return self.applyEcc(raw_frame)
elif self.method == "sparseOptFlow":
return self.applySparseOptFlow(raw_frame)
else:
return np.eye(2, 3)
def applyEcc(self, raw_frame, detections=None):
"""Initialize."""
def applyEcc(self, raw_frame: np.array) -> np.array:
"""
Apply ECC algorithm to a raw frame.
Args:
raw_frame (np.ndarray): The raw frame to be processed.
Returns:
(np.ndarray): Processed frame.
Examples:
>>> gmc = GMC()
>>> gmc.applyEcc(np.array([[1, 2, 3], [4, 5, 6]]))
array([[1, 2, 3],
[4, 5, 6]])
"""
height, width, _ = raw_frame.shape
frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY)
H = np.eye(2, 3, dtype=np.float32)
# Downscale image (TODO: consider using pyramids)
# Downscale image
if self.downscale > 1.0:
frame = cv2.GaussianBlur(frame, (3, 3), 1.5)
frame = cv2.resize(frame, (width // self.downscale, height // self.downscale))
@ -89,33 +143,46 @@ class GMC:
# Run the ECC algorithm. The results are stored in warp_matrix.
# (cc, H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode, self.criteria)
try:
(cc, H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode, self.criteria, None, 1)
(_, H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode, self.criteria, None, 1)
except Exception as e:
LOGGER.warning(f'WARNING: find transform failed. Set warp as identity {e}')
LOGGER.warning(f"WARNING: find transform failed. Set warp as identity {e}")
return H
def applyFeatures(self, raw_frame, detections=None):
"""Initialize."""
def applyFeatures(self, raw_frame: np.array, detections: list = None) -> np.array:
"""
Apply feature-based methods like ORB or SIFT to a raw frame.
Args:
raw_frame (np.ndarray): The raw frame to be processed.
detections (list): List of detections to be used in the processing.
Returns:
(np.ndarray): Processed frame.
Examples:
>>> gmc = GMC()
>>> gmc.applyFeatures(np.array([[1, 2, 3], [4, 5, 6]]))
array([[1, 2, 3],
[4, 5, 6]])
"""
height, width, _ = raw_frame.shape
frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY)
H = np.eye(2, 3)
# Downscale image (TODO: consider using pyramids)
# Downscale image
if self.downscale > 1.0:
# frame = cv2.GaussianBlur(frame, (3, 3), 1.5)
frame = cv2.resize(frame, (width // self.downscale, height // self.downscale))
width = width // self.downscale
height = height // self.downscale
# Find the keypoints
mask = np.zeros_like(frame)
# mask[int(0.05 * height): int(0.95 * height), int(0.05 * width): int(0.95 * width)] = 255
mask[int(0.02 * height):int(0.98 * height), int(0.02 * width):int(0.98 * width)] = 255
mask[int(0.02 * height) : int(0.98 * height), int(0.02 * width) : int(0.98 * width)] = 255
if detections is not None:
for det in detections:
tlbr = (det[:4] / self.downscale).astype(np.int_)
mask[tlbr[1]:tlbr[3], tlbr[0]:tlbr[2]] = 0
mask[tlbr[1] : tlbr[3], tlbr[0] : tlbr[2]] = 0
keypoints = self.detector.detect(frame, mask)
@ -134,10 +201,10 @@ class GMC:
return H
# Match descriptors.
# Match descriptors
knnMatches = self.matcher.knnMatch(self.prevDescriptors, descriptors, 2)
# Filtered matches based on smallest spatial distance
# Filter matches based on smallest spatial distance
matches = []
spatialDistances = []
@ -157,11 +224,14 @@ class GMC:
prevKeyPointLocation = self.prevKeyPoints[m.queryIdx].pt
currKeyPointLocation = keypoints[m.trainIdx].pt
spatialDistance = (prevKeyPointLocation[0] - currKeyPointLocation[0],
prevKeyPointLocation[1] - currKeyPointLocation[1])
spatialDistance = (
prevKeyPointLocation[0] - currKeyPointLocation[0],
prevKeyPointLocation[1] - currKeyPointLocation[1],
)
if (np.abs(spatialDistance[0]) < maxSpatialDistance[0]) and \
(np.abs(spatialDistance[1]) < maxSpatialDistance[1]):
if (np.abs(spatialDistance[0]) < maxSpatialDistance[0]) and (
np.abs(spatialDistance[1]) < maxSpatialDistance[1]
):
spatialDistances.append(spatialDistance)
matches.append(m)
@ -187,7 +257,7 @@ class GMC:
# import matplotlib.pyplot as plt
# matches_img = np.hstack((self.prevFrame, frame))
# matches_img = cv2.cvtColor(matches_img, cv2.COLOR_GRAY2BGR)
# W = np.size(self.prevFrame, 1)
# W = self.prevFrame.shape[1]
# for m in goodMatches:
# prev_pt = np.array(self.prevKeyPoints[m.queryIdx].pt, dtype=np.int_)
# curr_pt = np.array(keypoints[m.trainIdx].pt, dtype=np.int_)
@ -204,7 +274,7 @@ class GMC:
# plt.show()
# Find rigid matrix
if (np.size(prevPoints, 0) > 4) and (np.size(prevPoints, 0) == np.size(prevPoints, 0)):
if prevPoints.shape[0] > 4:
H, inliers = cv2.estimateAffinePartial2D(prevPoints, currPoints, cv2.RANSAC)
# Handle downscale
@ -212,7 +282,7 @@ class GMC:
H[0, 2] *= self.downscale
H[1, 2] *= self.downscale
else:
LOGGER.warning('WARNING: not enough matching points')
LOGGER.warning("WARNING: not enough matching points")
# Store to next iteration
self.prevFrame = frame.copy()
@ -221,15 +291,28 @@ class GMC:
return H
def applySparseOptFlow(self, raw_frame, detections=None):
"""Initialize."""
def applySparseOptFlow(self, raw_frame: np.array) -> np.array:
"""
Apply Sparse Optical Flow method to a raw frame.
Args:
raw_frame (np.ndarray): The raw frame to be processed.
Returns:
(np.ndarray): Processed frame.
Examples:
>>> gmc = GMC()
>>> gmc.applySparseOptFlow(np.array([[1, 2, 3], [4, 5, 6]]))
array([[1, 2, 3],
[4, 5, 6]])
"""
height, width, _ = raw_frame.shape
frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY)
H = np.eye(2, 3)
# Downscale image
if self.downscale > 1.0:
# frame = cv2.GaussianBlur(frame, (3, 3), 1.5)
frame = cv2.resize(frame, (width // self.downscale, height // self.downscale))
# Find the keypoints
@ -237,17 +320,13 @@ class GMC:
# Handle first frame
if not self.initializedFirstFrame:
# Initialize data
self.prevFrame = frame.copy()
self.prevKeyPoints = copy.copy(keypoints)
# Initialization done
self.initializedFirstFrame = True
return H
# Find correspondences
matchedKeypoints, status, err = cv2.calcOpticalFlowPyrLK(self.prevFrame, frame, self.prevKeyPoints, None)
matchedKeypoints, status, _ = cv2.calcOpticalFlowPyrLK(self.prevFrame, frame, self.prevKeyPoints, None)
# Leave good correspondences only
prevPoints = []
@ -262,18 +341,23 @@ class GMC:
currPoints = np.array(currPoints)
# Find rigid matrix
if (np.size(prevPoints, 0) > 4) and (np.size(prevPoints, 0) == np.size(prevPoints, 0)):
H, inliers = cv2.estimateAffinePartial2D(prevPoints, currPoints, cv2.RANSAC)
if (prevPoints.shape[0] > 4) and (prevPoints.shape[0] == prevPoints.shape[0]):
H, _ = cv2.estimateAffinePartial2D(prevPoints, currPoints, cv2.RANSAC)
# Handle downscale
if self.downscale > 1.0:
H[0, 2] *= self.downscale
H[1, 2] *= self.downscale
else:
LOGGER.warning('WARNING: not enough matching points')
LOGGER.warning("WARNING: not enough matching points")
# Store to next iteration
self.prevFrame = frame.copy()
self.prevKeyPoints = copy.copy(keypoints)
return H
def reset_params(self) -> None:
"""Reset parameters."""
self.prevFrame = None
self.prevKeyPoints = None
self.prevDescriptors = None
self.initializedFirstFrame = False

View File

@ -8,8 +8,8 @@ class KalmanFilterXYAH:
"""
For bytetrack. A simple Kalman filter for tracking bounding boxes in image space.
The 8-dimensional state space (x, y, a, h, vx, vy, va, vh) contains the bounding box center position (x, y),
aspect ratio a, height h, and their respective velocities.
The 8-dimensional state space (x, y, a, h, vx, vy, va, vh) contains the bounding box center position (x, y), aspect
ratio a, height h, and their respective velocities.
Object motion follows a constant velocity model. The bounding box location (x, y, a, h) is taken as direct
observation of the state space (linear observation model).
@ -17,126 +17,126 @@ class KalmanFilterXYAH:
def __init__(self):
"""Initialize Kalman filter model matrices with motion and observation uncertainty weights."""
ndim, dt = 4, 1.
ndim, dt = 4, 1.0
# Create Kalman filter model matrices.
# Create Kalman filter model matrices
self._motion_mat = np.eye(2 * ndim, 2 * ndim)
for i in range(ndim):
self._motion_mat[i, ndim + i] = dt
self._update_mat = np.eye(ndim, 2 * ndim)
# Motion and observation uncertainty are chosen relative to the current state estimate. These weights control
# the amount of uncertainty in the model. This is a bit hacky.
self._std_weight_position = 1. / 20
self._std_weight_velocity = 1. / 160
# the amount of uncertainty in the model.
self._std_weight_position = 1.0 / 20
self._std_weight_velocity = 1.0 / 160
def initiate(self, measurement):
def initiate(self, measurement: np.ndarray) -> tuple:
"""
Create track from unassociated measurement.
Parameters
----------
measurement : ndarray
Bounding box coordinates (x, y, a, h) with center position (x, y),
aspect ratio a, and height h.
Args:
measurement (ndarray): Bounding box coordinates (x, y, a, h) with center position (x, y), aspect ratio a,
and height h.
Returns
-------
(ndarray, ndarray)
Returns the mean vector (8 dimensional) and covariance matrix (8x8
dimensional) of the new track. Unobserved velocities are initialized
to 0 mean.
Returns:
(tuple[ndarray, ndarray]): Returns the mean vector (8 dimensional) and covariance matrix (8x8 dimensional) of
the new track. Unobserved velocities are initialized to 0 mean.
"""
mean_pos = measurement
mean_vel = np.zeros_like(mean_pos)
mean = np.r_[mean_pos, mean_vel]
std = [
2 * self._std_weight_position * measurement[3], 2 * self._std_weight_position * measurement[3], 1e-2,
2 * self._std_weight_position * measurement[3], 10 * self._std_weight_velocity * measurement[3],
10 * self._std_weight_velocity * measurement[3], 1e-5, 10 * self._std_weight_velocity * measurement[3]]
2 * self._std_weight_position * measurement[3],
2 * self._std_weight_position * measurement[3],
1e-2,
2 * self._std_weight_position * measurement[3],
10 * self._std_weight_velocity * measurement[3],
10 * self._std_weight_velocity * measurement[3],
1e-5,
10 * self._std_weight_velocity * measurement[3],
]
covariance = np.diag(np.square(std))
return mean, covariance
def predict(self, mean, covariance):
def predict(self, mean: np.ndarray, covariance: np.ndarray) -> tuple:
"""
Run Kalman filter prediction step.
Parameters
----------
mean : ndarray
The 8 dimensional mean vector of the object state at the previous time step.
covariance : ndarray
The 8x8 dimensional covariance matrix of the object state at the previous time step.
Args:
mean (ndarray): The 8 dimensional mean vector of the object state at the previous time step.
covariance (ndarray): The 8x8 dimensional covariance matrix of the object state at the previous time step.
Returns
-------
(ndarray, ndarray)
Returns the mean vector and covariance matrix of the predicted state. Unobserved velocities are
initialized to 0 mean.
Returns:
(tuple[ndarray, ndarray]): Returns the mean vector and covariance matrix of the predicted state. Unobserved
velocities are initialized to 0 mean.
"""
std_pos = [
self._std_weight_position * mean[3], self._std_weight_position * mean[3], 1e-2,
self._std_weight_position * mean[3]]
self._std_weight_position * mean[3],
self._std_weight_position * mean[3],
1e-2,
self._std_weight_position * mean[3],
]
std_vel = [
self._std_weight_velocity * mean[3], self._std_weight_velocity * mean[3], 1e-5,
self._std_weight_velocity * mean[3]]
self._std_weight_velocity * mean[3],
self._std_weight_velocity * mean[3],
1e-5,
self._std_weight_velocity * mean[3],
]
motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))
# mean = np.dot(self._motion_mat, mean)
mean = np.dot(mean, self._motion_mat.T)
covariance = np.linalg.multi_dot((self._motion_mat, covariance, self._motion_mat.T)) + motion_cov
return mean, covariance
def project(self, mean, covariance):
def project(self, mean: np.ndarray, covariance: np.ndarray) -> tuple:
"""
Project state distribution to measurement space.
Parameters
----------
mean : ndarray
The state's mean vector (8 dimensional array).
covariance : ndarray
The state's covariance matrix (8x8 dimensional).
Args:
mean (ndarray): The state's mean vector (8 dimensional array).
covariance (ndarray): The state's covariance matrix (8x8 dimensional).
Returns
-------
(ndarray, ndarray)
Returns the projected mean and covariance matrix of the given state estimate.
Returns:
(tuple[ndarray, ndarray]): Returns the projected mean and covariance matrix of the given state estimate.
"""
std = [
self._std_weight_position * mean[3], self._std_weight_position * mean[3], 1e-1,
self._std_weight_position * mean[3]]
self._std_weight_position * mean[3],
self._std_weight_position * mean[3],
1e-1,
self._std_weight_position * mean[3],
]
innovation_cov = np.diag(np.square(std))
mean = np.dot(self._update_mat, mean)
covariance = np.linalg.multi_dot((self._update_mat, covariance, self._update_mat.T))
return mean, covariance + innovation_cov
def multi_predict(self, mean, covariance):
def multi_predict(self, mean: np.ndarray, covariance: np.ndarray) -> tuple:
"""
Run Kalman filter prediction step (Vectorized version).
Parameters
----------
mean : ndarray
The Nx8 dimensional mean matrix of the object states at the previous time step.
covariance : ndarray
The Nx8x8 dimensional covariance matrix of the object states at the previous time step.
Args:
mean (ndarray): The Nx8 dimensional mean matrix of the object states at the previous time step.
covariance (ndarray): The Nx8x8 covariance matrix of the object states at the previous time step.
Returns
-------
(ndarray, ndarray)
Returns the mean vector and covariance matrix of the predicted state. Unobserved velocities are
initialized to 0 mean.
Returns:
(tuple[ndarray, ndarray]): Returns the mean vector and covariance matrix of the predicted state. Unobserved
velocities are initialized to 0 mean.
"""
std_pos = [
self._std_weight_position * mean[:, 3], self._std_weight_position * mean[:, 3],
1e-2 * np.ones_like(mean[:, 3]), self._std_weight_position * mean[:, 3]]
self._std_weight_position * mean[:, 3],
self._std_weight_position * mean[:, 3],
1e-2 * np.ones_like(mean[:, 3]),
self._std_weight_position * mean[:, 3],
]
std_vel = [
self._std_weight_velocity * mean[:, 3], self._std_weight_velocity * mean[:, 3],
1e-5 * np.ones_like(mean[:, 3]), self._std_weight_velocity * mean[:, 3]]
self._std_weight_velocity * mean[:, 3],
self._std_weight_velocity * mean[:, 3],
1e-5 * np.ones_like(mean[:, 3]),
self._std_weight_velocity * mean[:, 3],
]
sqr = np.square(np.r_[std_pos, std_vel]).T
motion_cov = [np.diag(sqr[i]) for i in range(len(mean))]
@ -148,60 +148,57 @@ class KalmanFilterXYAH:
return mean, covariance
def update(self, mean, covariance, measurement):
def update(self, mean: np.ndarray, covariance: np.ndarray, measurement: np.ndarray) -> tuple:
"""
Run Kalman filter correction step.
Parameters
----------
mean : ndarray
The predicted state's mean vector (8 dimensional).
covariance : ndarray
The state's covariance matrix (8x8 dimensional).
measurement : ndarray
The 4 dimensional measurement vector (x, y, a, h), where (x, y) is the center position, a the aspect
ratio, and h the height of the bounding box.
Args:
mean (ndarray): The predicted state's mean vector (8 dimensional).
covariance (ndarray): The state's covariance matrix (8x8 dimensional).
measurement (ndarray): The 4 dimensional measurement vector (x, y, a, h), where (x, y) is the center
position, a the aspect ratio, and h the height of the bounding box.
Returns
-------
(ndarray, ndarray)
Returns the measurement-corrected state distribution.
Returns:
(tuple[ndarray, ndarray]): Returns the measurement-corrected state distribution.
"""
projected_mean, projected_cov = self.project(mean, covariance)
chol_factor, lower = scipy.linalg.cho_factor(projected_cov, lower=True, check_finite=False)
kalman_gain = scipy.linalg.cho_solve((chol_factor, lower),
np.dot(covariance, self._update_mat.T).T,
check_finite=False).T
kalman_gain = scipy.linalg.cho_solve(
(chol_factor, lower), np.dot(covariance, self._update_mat.T).T, check_finite=False
).T
innovation = measurement - projected_mean
new_mean = mean + np.dot(innovation, kalman_gain.T)
new_covariance = covariance - np.linalg.multi_dot((kalman_gain, projected_cov, kalman_gain.T))
return new_mean, new_covariance
def gating_distance(self, mean, covariance, measurements, only_position=False, metric='maha'):
def gating_distance(
self,
mean: np.ndarray,
covariance: np.ndarray,
measurements: np.ndarray,
only_position: bool = False,
metric: str = "maha",
) -> np.ndarray:
"""
Compute gating distance between state distribution and measurements. A suitable distance threshold can be
obtained from `chi2inv95`. If `only_position` is False, the chi-square distribution has 4 degrees of
freedom, otherwise 2.
obtained from `chi2inv95`. If `only_position` is False, the chi-square distribution has 4 degrees of freedom,
otherwise 2.
Parameters
----------
mean : ndarray
Mean vector over the state distribution (8 dimensional).
covariance : ndarray
Covariance of the state distribution (8x8 dimensional).
measurements : ndarray
An Nx4 dimensional matrix of N measurements, each in format (x, y, a, h) where (x, y) is the bounding box
center position, a the aspect ratio, and h the height.
only_position : Optional[bool]
If True, distance computation is done with respect to the bounding box center position only.
Args:
mean (ndarray): Mean vector over the state distribution (8 dimensional).
covariance (ndarray): Covariance of the state distribution (8x8 dimensional).
measurements (ndarray): An Nx4 matrix of N measurements, each in format (x, y, a, h) where (x, y)
is the bounding box center position, a the aspect ratio, and h the height.
only_position (bool, optional): If True, distance computation is done with respect to the bounding box
center position only. Defaults to False.
metric (str, optional): The metric to use for calculating the distance. Options are 'gaussian' for the
squared Euclidean distance and 'maha' for the squared Mahalanobis distance. Defaults to 'maha'.
Returns
-------
ndarray
Returns an array of length N, where the i-th element contains the squared Mahalanobis distance between
(mean, covariance) and `measurements[i]`.
Returns:
(np.ndarray): Returns an array of length N, where the i-th element contains the squared distance between
(mean, covariance) and `measurements[i]`.
"""
mean, covariance = self.project(mean, covariance)
if only_position:
@ -209,77 +206,79 @@ class KalmanFilterXYAH:
measurements = measurements[:, :2]
d = measurements - mean
if metric == 'gaussian':
if metric == "gaussian":
return np.sum(d * d, axis=1)
elif metric == 'maha':
elif metric == "maha":
cholesky_factor = np.linalg.cholesky(covariance)
z = scipy.linalg.solve_triangular(cholesky_factor, d.T, lower=True, check_finite=False, overwrite_b=True)
return np.sum(z * z, axis=0) # square maha
else:
raise ValueError('invalid distance metric')
raise ValueError("Invalid distance metric")
class KalmanFilterXYWH(KalmanFilterXYAH):
"""
For BoT-SORT. A simple Kalman filter for tracking bounding boxes in image space.
The 8-dimensional state space (x, y, w, h, vx, vy, vw, vh) contains the bounding box center position (x, y),
width w, height h, and their respective velocities.
The 8-dimensional state space (x, y, w, h, vx, vy, vw, vh) contains the bounding box center position (x, y), width
w, height h, and their respective velocities.
Object motion follows a constant velocity model. The bounding box location (x, y, w, h) is taken as direct
observation of the state space (linear observation model).
"""
def initiate(self, measurement):
def initiate(self, measurement: np.ndarray) -> tuple:
"""
Create track from unassociated measurement.
Parameters
----------
measurement : ndarray
Bounding box coordinates (x, y, w, h) with center position (x, y), width w, and height h.
Args:
measurement (ndarray): Bounding box coordinates (x, y, w, h) with center position (x, y), width, and height.
Returns
-------
(ndarray, ndarray)
Returns the mean vector (8 dimensional) and covariance matrix (8x8 dimensional) of the new track.
Unobserved velocities are initialized to 0 mean.
Returns:
(tuple[ndarray, ndarray]): Returns the mean vector (8 dimensional) and covariance matrix (8x8 dimensional) of
the new track. Unobserved velocities are initialized to 0 mean.
"""
mean_pos = measurement
mean_vel = np.zeros_like(mean_pos)
mean = np.r_[mean_pos, mean_vel]
std = [
2 * self._std_weight_position * measurement[2], 2 * self._std_weight_position * measurement[3],
2 * self._std_weight_position * measurement[2], 2 * self._std_weight_position * measurement[3],
10 * self._std_weight_velocity * measurement[2], 10 * self._std_weight_velocity * measurement[3],
10 * self._std_weight_velocity * measurement[2], 10 * self._std_weight_velocity * measurement[3]]
2 * self._std_weight_position * measurement[2],
2 * self._std_weight_position * measurement[3],
2 * self._std_weight_position * measurement[2],
2 * self._std_weight_position * measurement[3],
10 * self._std_weight_velocity * measurement[2],
10 * self._std_weight_velocity * measurement[3],
10 * self._std_weight_velocity * measurement[2],
10 * self._std_weight_velocity * measurement[3],
]
covariance = np.diag(np.square(std))
return mean, covariance
def predict(self, mean, covariance):
def predict(self, mean, covariance) -> tuple:
"""
Run Kalman filter prediction step.
Parameters
----------
mean : ndarray
The 8 dimensional mean vector of the object state at the previous time step.
covariance : ndarray
The 8x8 dimensional covariance matrix of the object state at the previous time step.
Args:
mean (ndarray): The 8 dimensional mean vector of the object state at the previous time step.
covariance (ndarray): The 8x8 dimensional covariance matrix of the object state at the previous time step.
Returns
-------
(ndarray, ndarray)
Returns the mean vector and covariance matrix of the predicted state. Unobserved velocities are
initialized to 0 mean.
Returns:
(tuple[ndarray, ndarray]): Returns the mean vector and covariance matrix of the predicted state. Unobserved
velocities are initialized to 0 mean.
"""
std_pos = [
self._std_weight_position * mean[2], self._std_weight_position * mean[3],
self._std_weight_position * mean[2], self._std_weight_position * mean[3]]
self._std_weight_position * mean[2],
self._std_weight_position * mean[3],
self._std_weight_position * mean[2],
self._std_weight_position * mean[3],
]
std_vel = [
self._std_weight_velocity * mean[2], self._std_weight_velocity * mean[3],
self._std_weight_velocity * mean[2], self._std_weight_velocity * mean[3]]
self._std_weight_velocity * mean[2],
self._std_weight_velocity * mean[3],
self._std_weight_velocity * mean[2],
self._std_weight_velocity * mean[3],
]
motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))
mean = np.dot(mean, self._motion_mat.T)
@ -287,54 +286,53 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
return mean, covariance
def project(self, mean, covariance):
def project(self, mean, covariance) -> tuple:
"""
Project state distribution to measurement space.
Parameters
----------
mean : ndarray
The state's mean vector (8 dimensional array).
covariance : ndarray
The state's covariance matrix (8x8 dimensional).
Args:
mean (ndarray): The state's mean vector (8 dimensional array).
covariance (ndarray): The state's covariance matrix (8x8 dimensional).
Returns
-------
(ndarray, ndarray)
Returns the projected mean and covariance matrix of the given state estimate.
Returns:
(tuple[ndarray, ndarray]): Returns the projected mean and covariance matrix of the given state estimate.
"""
std = [
self._std_weight_position * mean[2], self._std_weight_position * mean[3],
self._std_weight_position * mean[2], self._std_weight_position * mean[3]]
self._std_weight_position * mean[2],
self._std_weight_position * mean[3],
self._std_weight_position * mean[2],
self._std_weight_position * mean[3],
]
innovation_cov = np.diag(np.square(std))
mean = np.dot(self._update_mat, mean)
covariance = np.linalg.multi_dot((self._update_mat, covariance, self._update_mat.T))
return mean, covariance + innovation_cov
def multi_predict(self, mean, covariance):
def multi_predict(self, mean, covariance) -> tuple:
"""
Run Kalman filter prediction step (Vectorized version).
Parameters
----------
mean : ndarray
The Nx8 dimensional mean matrix of the object states at the previous time step.
covariance : ndarray
The Nx8x8 dimensional covariance matrix of the object states at the previous time step.
Args:
mean (ndarray): The Nx8 dimensional mean matrix of the object states at the previous time step.
covariance (ndarray): The Nx8x8 covariance matrix of the object states at the previous time step.
Returns
-------
(ndarray, ndarray)
Returns the mean vector and covariance matrix of the predicted state. Unobserved velocities are
initialized to 0 mean.
Returns:
(tuple[ndarray, ndarray]): Returns the mean vector and covariance matrix of the predicted state. Unobserved
velocities are initialized to 0 mean.
"""
std_pos = [
self._std_weight_position * mean[:, 2], self._std_weight_position * mean[:, 3],
self._std_weight_position * mean[:, 2], self._std_weight_position * mean[:, 3]]
self._std_weight_position * mean[:, 2],
self._std_weight_position * mean[:, 3],
self._std_weight_position * mean[:, 2],
self._std_weight_position * mean[:, 3],
]
std_vel = [
self._std_weight_velocity * mean[:, 2], self._std_weight_velocity * mean[:, 3],
self._std_weight_velocity * mean[:, 2], self._std_weight_velocity * mean[:, 3]]
self._std_weight_velocity * mean[:, 2],
self._std_weight_velocity * mean[:, 3],
self._std_weight_velocity * mean[:, 2],
self._std_weight_velocity * mean[:, 3],
]
sqr = np.square(np.r_[std_pos, std_vel]).T
motion_cov = [np.diag(sqr[i]) for i in range(len(mean))]
@ -346,23 +344,17 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
return mean, covariance
def update(self, mean, covariance, measurement):
def update(self, mean, covariance, measurement) -> tuple:
"""
Run Kalman filter correction step.
Parameters
----------
mean : ndarray
The predicted state's mean vector (8 dimensional).
covariance : ndarray
The state's covariance matrix (8x8 dimensional).
measurement : ndarray
The 4 dimensional measurement vector (x, y, w, h), where (x, y) is the center position, w the width,
and h the height of the bounding box.
Args:
mean (ndarray): The predicted state's mean vector (8 dimensional).
covariance (ndarray): The state's covariance matrix (8x8 dimensional).
measurement (ndarray): The 4 dimensional measurement vector (x, y, w, h), where (x, y) is the center
position, w the width, and h the height of the bounding box.
Returns
-------
(ndarray, ndarray)
Returns the measurement-corrected state distribution.
Returns:
(tuple[ndarray, ndarray]): Returns the measurement-corrected state distribution.
"""
return super().update(mean, covariance, measurement)

View File

@ -4,7 +4,7 @@ import numpy as np
import scipy
from scipy.spatial.distance import cdist
from ultralytics.utils.metrics import bbox_ioa
from ultralytics.utils.metrics import bbox_ioa, batch_probiou
try:
import lap # for linear_assignment
@ -13,11 +13,11 @@ try:
except (ImportError, AssertionError, AttributeError):
from ultralytics.utils.checks import check_requirements
check_requirements('lapx>=0.5.2') # update to lap package from https://github.com/rathaROG/lapx
check_requirements("lapx>=0.5.2") # update to lap package from https://github.com/rathaROG/lapx
import lap
def linear_assignment(cost_matrix, thresh, use_lap=True):
def linear_assignment(cost_matrix: np.ndarray, thresh: float, use_lap: bool = True) -> tuple:
"""
Perform linear assignment using scipy or lap.lapjv.
@ -27,19 +27,24 @@ def linear_assignment(cost_matrix, thresh, use_lap=True):
use_lap (bool, optional): Whether to use lap.lapjv. Defaults to True.
Returns:
(tuple): Tuple containing matched indices, unmatched indices from 'a', and unmatched indices from 'b'.
Tuple with:
- matched indices
- unmatched indices from 'a'
- unmatched indices from 'b'
"""
if cost_matrix.size == 0:
return np.empty((0, 2), dtype=int), tuple(range(cost_matrix.shape[0])), tuple(range(cost_matrix.shape[1]))
if use_lap:
# Use lap.lapjv
# https://github.com/gatagat/lap
_, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh)
matches = [[ix, mx] for ix, mx in enumerate(x) if mx >= 0]
unmatched_a = np.where(x < 0)[0]
unmatched_b = np.where(y < 0)[0]
else:
# Use scipy.optimize.linear_sum_assignment
# https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.linear_sum_assignment.html
x, y = scipy.optimize.linear_sum_assignment(cost_matrix) # row x, col y
matches = np.asarray([[x[i], y[i]] for i in range(len(x)) if cost_matrix[x[i], y[i]] <= thresh])
@ -53,7 +58,7 @@ def linear_assignment(cost_matrix, thresh, use_lap=True):
return matches, unmatched_a, unmatched_b
def iou_distance(atracks, btracks):
def iou_distance(atracks: list, btracks: list) -> np.ndarray:
"""
Compute cost based on Intersection over Union (IoU) between tracks.
@ -65,23 +70,30 @@ def iou_distance(atracks, btracks):
(np.ndarray): Cost matrix computed based on IoU.
"""
if (len(atracks) > 0 and isinstance(atracks[0], np.ndarray)) \
or (len(btracks) > 0 and isinstance(btracks[0], np.ndarray)):
if atracks and isinstance(atracks[0], np.ndarray) or btracks and isinstance(btracks[0], np.ndarray):
atlbrs = atracks
btlbrs = btracks
else:
atlbrs = [track.tlbr for track in atracks]
btlbrs = [track.tlbr for track in btracks]
atlbrs = [track.xywha if track.angle is not None else track.xyxy for track in atracks]
btlbrs = [track.xywha if track.angle is not None else track.xyxy for track in btracks]
ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float32)
if len(atlbrs) and len(btlbrs):
ious = bbox_ioa(np.ascontiguousarray(atlbrs, dtype=np.float32),
np.ascontiguousarray(btlbrs, dtype=np.float32),
iou=True)
if len(atlbrs[0]) == 5 and len(btlbrs[0]) == 5:
ious = batch_probiou(
np.ascontiguousarray(atlbrs, dtype=np.float32),
np.ascontiguousarray(btlbrs, dtype=np.float32),
).numpy()
else:
ious = bbox_ioa(
np.ascontiguousarray(atlbrs, dtype=np.float32),
np.ascontiguousarray(btlbrs, dtype=np.float32),
iou=True,
)
return 1 - ious # cost matrix
def embedding_distance(tracks, detections, metric='cosine'):
def embedding_distance(tracks: list, detections: list, metric: str = "cosine") -> np.ndarray:
"""
Compute distance between tracks and detections based on embeddings.
@ -105,7 +117,7 @@ def embedding_distance(tracks, detections, metric='cosine'):
return cost_matrix
def fuse_score(cost_matrix, detections):
def fuse_score(cost_matrix: np.ndarray, detections: list) -> np.ndarray:
"""
Fuses cost matrix with detection scores to produce a single similarity matrix.