commit e474ab5f9f
Author: lee
Date: 2025-06-18 14:35:43 +08:00
529 changed files with 80523 additions and 0 deletions

File diff suppressed because it is too large

ultralytics/utils/autobatch.py
@@ -0,0 +1,88 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""Functions for estimating the best YOLO batch size to use a fraction of the available CUDA memory in PyTorch."""
from copy import deepcopy
import numpy as np
import torch
from ultralytics.utils import DEFAULT_CFG, LOGGER, colorstr
from ultralytics.utils.torch_utils import profile
def check_train_batch_size(model, imgsz=640, amp=True):
"""
Check YOLO training batch size using the autobatch() function.
Args:
model (torch.nn.Module): YOLO model to check batch size for.
imgsz (int): Image size used for training.
amp (bool): If True, use automatic mixed precision (AMP) for training.
Returns:
(int): Optimal batch size computed using the autobatch() function.
"""
with torch.cuda.amp.autocast(amp):
return autobatch(deepcopy(model).train(), imgsz) # compute optimal batch size
def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch):
"""
Automatically estimate the best YOLO batch size to use a fraction of the available CUDA memory.
Args:
model (torch.nn.Module): YOLO model to compute batch size for.
imgsz (int, optional): The image size used as input for the YOLO model. Defaults to 640.
fraction (float, optional): The fraction of available CUDA memory to use. Defaults to 0.60.
batch_size (int, optional): The default batch size to use if an error is detected. Defaults to 16.
Returns:
(int): The optimal batch size.
"""
# Check device
prefix = colorstr("AutoBatch: ")
LOGGER.info(f"{prefix}Computing optimal batch size for imgsz={imgsz}")
device = next(model.parameters()).device # get model device
if device.type == "cpu":
LOGGER.info(f"{prefix}CUDA not detected, using default CPU batch-size {batch_size}")
return batch_size
if torch.backends.cudnn.benchmark:
LOGGER.info(f"{prefix} ⚠️ Requires torch.backends.cudnn.benchmark=False, using default batch-size {batch_size}")
return batch_size
# Inspect CUDA memory
gb = 1 << 30 # bytes to GiB (1024 ** 3)
d = str(device).upper() # 'CUDA:0'
properties = torch.cuda.get_device_properties(device) # device properties
t = properties.total_memory / gb # GiB total
r = torch.cuda.memory_reserved(device) / gb # GiB reserved
a = torch.cuda.memory_allocated(device) / gb # GiB allocated
f = t - (r + a) # GiB free
LOGGER.info(f"{prefix}{d} ({properties.name}) {t:.2f}G total, {r:.2f}G reserved, {a:.2f}G allocated, {f:.2f}G free")
# Profile batch sizes
batch_sizes = [1, 2, 4, 8, 16]
try:
img = [torch.empty(b, 3, imgsz, imgsz) for b in batch_sizes]
results = profile(img, model, n=3, device=device)
# Fit a solution
y = [x[2] for x in results if x] # memory usage is element [2] of each profile result
p = np.polyfit(batch_sizes[: len(y)], y, deg=1) # first-degree polynomial fit
b = int((f * fraction - p[1]) / p[0]) # solve the fit for the batch size at the target memory fraction
if None in results: # some sizes failed
i = results.index(None) # first fail index
if b >= batch_sizes[i]: # y intercept above failure point
b = batch_sizes[max(i - 1, 0)] # select prior safe point
if b < 1 or b > 1024: # b outside of safe range
b = batch_size
LOGGER.info(f"{prefix}WARNING ⚠️ CUDA anomaly detected, using default batch-size {batch_size}.")
fraction = (np.polyval(p, b) + r + a) / t # actual fraction predicted
LOGGER.info(f"{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%) ✅")
return b
except Exception as e:
LOGGER.warning(f"{prefix}WARNING ⚠️ error detected: {e}, using default batch-size {batch_size}.")
return batch_size
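For reference, the fit above reduces to simple arithmetic: profile memory at a few batch sizes, fit a line `mem ≈ p0*b + p1`, and solve for the batch size that lands on the target fraction of free memory. A minimal sketch with synthetic numbers (all values hypothetical):

```python
import numpy as np

batch_sizes = [1, 2, 4, 8, 16]
mem_gib = [0.5, 0.9, 1.7, 3.3, 6.5]  # hypothetical profiled memory per batch size (GiB)

p = np.polyfit(batch_sizes, mem_gib, deg=1)   # mem ≈ p[0] * b + p[1]
free_gib, fraction = 10.0, 0.60               # assumed free CUDA memory and target fraction
b = int((free_gib * fraction - p[1]) / p[0])  # solve p[0] * b + p[1] = free * fraction
print(b)  # -> 14 with these numbers
```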

ultralytics/utils/benchmarks.py
@@ -0,0 +1,402 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
Benchmark YOLO model formats for speed and accuracy.
Usage:
from ultralytics.utils.benchmarks import ProfileModels, benchmark
ProfileModels(['yolov8n.yaml', 'yolov8s.yaml']).profile()
benchmark(model='yolov8n.pt', imgsz=160)
Format | `format=argument` | Model
--- | --- | ---
PyTorch | - | yolov8n.pt
TorchScript | `torchscript` | yolov8n.torchscript
ONNX | `onnx` | yolov8n.onnx
OpenVINO | `openvino` | yolov8n_openvino_model/
TensorRT | `engine` | yolov8n.engine
CoreML | `coreml` | yolov8n.mlpackage
TensorFlow SavedModel | `saved_model` | yolov8n_saved_model/
TensorFlow GraphDef | `pb` | yolov8n.pb
TensorFlow Lite | `tflite` | yolov8n.tflite
TensorFlow Edge TPU | `edgetpu` | yolov8n_edgetpu.tflite
TensorFlow.js | `tfjs` | yolov8n_web_model/
PaddlePaddle | `paddle` | yolov8n_paddle_model/
NCNN | `ncnn` | yolov8n_ncnn_model/
"""
import glob
import platform
import time
from pathlib import Path
import numpy as np
import torch.cuda
from ultralytics import YOLO, YOLOWorld
from ultralytics.cfg import TASK2DATA, TASK2METRIC
from ultralytics.engine.exporter import export_formats
from ultralytics.utils import ASSETS, LINUX, LOGGER, MACOS, TQDM, WEIGHTS_DIR
from ultralytics.utils.checks import IS_PYTHON_3_12, check_requirements, check_yolo
from ultralytics.utils.files import file_size
from ultralytics.utils.torch_utils import select_device
def benchmark(
model=WEIGHTS_DIR / "yolov8n.pt", data=None, imgsz=160, half=False, int8=False, device="cpu", verbose=False
):
"""
Benchmark a YOLO model across different formats for speed and accuracy.
Args:
model (str | Path, optional): Path to the model file or directory. Default is
Path(SETTINGS['weights_dir']) / 'yolov8n.pt'.
data (str, optional): Dataset to evaluate on, inherited from TASK2DATA if not passed. Default is None.
imgsz (int, optional): Image size for the benchmark. Default is 160.
half (bool, optional): Use half-precision for the model if True. Default is False.
int8 (bool, optional): Use int8-precision for the model if True. Default is False.
device (str, optional): Device to run the benchmark on, either 'cpu' or 'cuda'. Default is 'cpu'.
verbose (bool | float, optional): If True or a float, assert benchmarks pass with the given metric floor.
Default is False.
Returns:
df (pandas.DataFrame): A pandas DataFrame with benchmark results for each format, including file size,
metric, and inference time.
Example:
```python
from ultralytics.utils.benchmarks import benchmark
benchmark(model='yolov8n.pt', imgsz=640)
```
"""
import pandas as pd
pd.options.display.max_columns = 10
pd.options.display.width = 120
device = select_device(device, verbose=False)
if isinstance(model, (str, Path)):
model = YOLO(model)
y = []
t0 = time.time()
for i, (name, format, suffix, cpu, gpu) in export_formats().iterrows(): # index, (name, format, suffix, CPU, GPU)
emoji, filename = "", None # export defaults
try:
# Checks
if i == 9: # Edge TPU
assert LINUX, "Edge TPU export only supported on Linux"
elif i == 7: # TF GraphDef
assert model.task != "obb", "TensorFlow GraphDef not supported for OBB task"
elif i in {5, 10}: # CoreML and TF.js
assert MACOS or LINUX, "export only supported on macOS and Linux"
if i in {3, 5}: # CoreML and OpenVINO
assert not IS_PYTHON_3_12, "CoreML and OpenVINO not supported on Python 3.12"
if i in {6, 7, 8, 9, 10}: # All TF formats
assert not isinstance(model, YOLOWorld), "YOLOWorldv2 TensorFlow exports not supported by onnx2tf yet"
if i in {11}: # Paddle
assert not isinstance(model, YOLOWorld), "YOLOWorldv2 Paddle exports not supported yet"
if i in {12}: # NCNN
assert not isinstance(model, YOLOWorld), "YOLOWorldv2 NCNN exports not supported yet"
if "cpu" in device.type:
assert cpu, "inference not supported on CPU"
if "cuda" in device.type:
assert gpu, "inference not supported on GPU"
# Export
if format == "-":
filename = model.ckpt_path or model.cfg
exported_model = model # PyTorch format
else:
filename = model.export(imgsz=imgsz, format=format, half=half, int8=int8, device=device, verbose=False)
exported_model = YOLO(filename, task=model.task)
assert suffix in str(filename), "export failed"
emoji = "" # indicates export succeeded
# Predict
assert model.task != "pose" or i != 7, "GraphDef Pose inference is not supported"
assert i not in (9, 10), "inference not supported" # Edge TPU and TF.js are unsupported
assert i != 5 or platform.system() == "Darwin", "inference only supported on macOS>=10.13" # CoreML
exported_model.predict(ASSETS / "bus.jpg", imgsz=imgsz, device=device, half=half)
# Validate
data = data or TASK2DATA[model.task] # task to dataset, i.e. coco8.yaml for task=detect
key = TASK2METRIC[model.task] # task to metric, i.e. metrics/mAP50-95(B) for task=detect
results = exported_model.val(
data=data, batch=1, imgsz=imgsz, plots=False, device=device, half=half, int8=int8, verbose=False
)
metric, speed = results.results_dict[key], results.speed["inference"]
y.append([name, "", round(file_size(filename), 1), round(metric, 4), round(speed, 2)])
except Exception as e:
if verbose:
assert type(e) is AssertionError, f"Benchmark failure for {name}: {e}"
LOGGER.warning(f"ERROR ❌️ Benchmark failure for {name}: {e}")
y.append([name, emoji, round(file_size(filename), 1), None, None]) # mAP, t_inference
# Print results
check_yolo(device=device) # print system info
df = pd.DataFrame(y, columns=["Format", "Status❔", "Size (MB)", key, "Inference time (ms/im)"])
name = Path(model.ckpt_path).name
s = f"\nBenchmarks complete for {name} on {data} at imgsz={imgsz} ({time.time() - t0:.2f}s)\n{df}\n"
LOGGER.info(s)
with open("benchmarks.log", "a", errors="ignore", encoding="utf-8") as f:
f.write(s)
if verbose and isinstance(verbose, float):
metrics = df[key].array # values to compare to floor
floor = verbose # minimum metric floor to pass, i.e. = 0.29 mAP for YOLOv5n
assert all(x > floor for x in metrics if pd.notna(x)), f"Benchmark failure: metric(s) < floor {floor}"
return df
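One usage note, sketched from the `verbose` handling above: passing a float instead of a bool turns the benchmark into a pass/fail gate that asserts every completed format exceeds that metric floor (the 0.29 threshold below is only illustrative):

```python
from ultralytics.utils.benchmarks import benchmark

# Raises AssertionError if any completed format's metric is <= 0.29; otherwise returns the DataFrame
df = benchmark(model="yolov8n.pt", imgsz=160, verbose=0.29)
```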
class ProfileModels:
"""
ProfileModels class for profiling different models on ONNX and TensorRT.
This class profiles the performance of different models, returning results such as model speed and FLOPs.
Attributes:
paths (list): Paths of the models to profile.
num_timed_runs (int): Number of timed runs for the profiling. Default is 100.
num_warmup_runs (int): Number of warmup runs before profiling. Default is 10.
min_time (float): Minimum number of seconds to profile for. Default is 60.
imgsz (int): Image size used in the models. Default is 640.
Methods:
profile(): Profiles the models and prints the result.
Example:
```python
from ultralytics.utils.benchmarks import ProfileModels
ProfileModels(['yolov8n.yaml', 'yolov8s.yaml'], imgsz=640).profile()
```
"""
def __init__(
self,
paths: list,
num_timed_runs=100,
num_warmup_runs=10,
min_time=60,
imgsz=640,
half=True,
trt=True,
device=None,
):
"""
Initialize the ProfileModels class for profiling models.
Args:
paths (list): List of paths of the models to be profiled.
num_timed_runs (int, optional): Number of timed runs for the profiling. Default is 100.
num_warmup_runs (int, optional): Number of warmup runs before the actual profiling starts. Default is 10.
min_time (float, optional): Minimum time in seconds for profiling a model. Default is 60.
imgsz (int, optional): Size of the image used during profiling. Default is 640.
half (bool, optional): Flag to indicate whether to use half-precision floating point for profiling.
trt (bool, optional): Flag to indicate whether to profile using TensorRT. Default is True.
device (torch.device, optional): Device used for profiling. If None, it is determined automatically.
"""
self.paths = paths
self.num_timed_runs = num_timed_runs
self.num_warmup_runs = num_warmup_runs
self.min_time = min_time
self.imgsz = imgsz
self.half = half
self.trt = trt # run TensorRT profiling
self.device = device or torch.device(0 if torch.cuda.is_available() else "cpu")
def profile(self):
"""Logs the benchmarking results of a model, checks metrics against floor and returns the results."""
files = self.get_files()
if not files:
print("No matching *.pt or *.onnx files found.")
return
table_rows = []
output = []
for file in files:
engine_file = file.with_suffix(".engine")
if file.suffix in (".pt", ".yaml", ".yml"):
model = YOLO(str(file))
model.fuse() # to report correct params and GFLOPs in model.info()
model_info = model.info()
if self.trt and self.device.type != "cpu" and not engine_file.is_file():
engine_file = model.export(
format="engine", half=self.half, imgsz=self.imgsz, device=self.device, verbose=False
)
onnx_file = model.export(
format="onnx", half=self.half, imgsz=self.imgsz, simplify=True, device=self.device, verbose=False
)
elif file.suffix == ".onnx":
model_info = self.get_onnx_model_info(file)
onnx_file = file
else:
continue
t_engine = self.profile_tensorrt_model(str(engine_file))
t_onnx = self.profile_onnx_model(str(onnx_file))
table_rows.append(self.generate_table_row(file.stem, t_onnx, t_engine, model_info))
output.append(self.generate_results_dict(file.stem, t_onnx, t_engine, model_info))
self.print_table(table_rows)
return output
def get_files(self):
"""Returns a list of paths for all relevant model files given by the user."""
files = []
for path in self.paths:
path = Path(path)
if path.is_dir():
extensions = ["*.pt", "*.onnx", "*.yaml"]
files.extend([file for ext in extensions for file in glob.glob(str(path / ext))])
elif path.suffix in {".pt", ".yaml", ".yml"}: # add files directly, including ones that do not exist yet
files.append(str(path))
else:
files.extend(glob.glob(str(path)))
print(f"Profiling: {sorted(files)}")
return [Path(file) for file in sorted(files)]
def get_onnx_model_info(self, onnx_file: str):
"""Retrieves the information including number of layers, parameters, gradients and FLOPs for an ONNX model
file.
"""
return 0.0, 0.0, 0.0, 0.0 # return (num_layers, num_params, num_gradients, num_flops)
@staticmethod
def iterative_sigma_clipping(data, sigma=2, max_iters=3):
"""Applies an iterative sigma clipping algorithm to the given data times number of iterations."""
data = np.array(data)
for _ in range(max_iters):
mean, std = np.mean(data), np.std(data)
clipped_data = data[(data > mean - sigma * std) & (data < mean + sigma * std)]
if len(clipped_data) == len(data):
break
data = clipped_data
return data
def profile_tensorrt_model(self, engine_file: str, eps: float = 1e-3):
"""Profiles the TensorRT model, measuring average run time and standard deviation among runs."""
if not self.trt or not Path(engine_file).is_file():
return 0.0, 0.0
# Model and input
model = YOLO(engine_file)
input_data = np.random.rand(self.imgsz, self.imgsz, 3).astype(np.float32) # must be FP32
# Warmup runs
elapsed = 0.0
for _ in range(3):
start_time = time.time()
for _ in range(self.num_warmup_runs):
model(input_data, imgsz=self.imgsz, verbose=False)
elapsed = time.time() - start_time
# Compute number of runs as higher of min_time or num_timed_runs
num_runs = max(round(self.min_time / (elapsed + eps) * self.num_warmup_runs), self.num_timed_runs * 50)
# Timed runs
run_times = []
for _ in TQDM(range(num_runs), desc=engine_file):
results = model(input_data, imgsz=self.imgsz, verbose=False)
run_times.append(results[0].speed["inference"]) # inference speed is already reported in milliseconds
run_times = self.iterative_sigma_clipping(np.array(run_times), sigma=2, max_iters=3) # sigma clipping
return np.mean(run_times), np.std(run_times)
def profile_onnx_model(self, onnx_file: str, eps: float = 1e-3):
"""Profiles an ONNX model by executing it multiple times and returns the mean and standard deviation of run
times.
"""
check_requirements("onnxruntime")
import onnxruntime as ort
# Session with either 'TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'
sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
sess_options.intra_op_num_threads = 8 # Limit the number of threads
sess = ort.InferenceSession(onnx_file, sess_options, providers=["CPUExecutionProvider"])
input_tensor = sess.get_inputs()[0]
input_type = input_tensor.type
# Mapping ONNX datatype to numpy datatype
if "float16" in input_type:
input_dtype = np.float16
elif "float" in input_type:
input_dtype = np.float32
elif "double" in input_type:
input_dtype = np.float64
elif "int64" in input_type:
input_dtype = np.int64
elif "int32" in input_type:
input_dtype = np.int32
else:
raise ValueError(f"Unsupported ONNX datatype {input_type}")
input_data = np.random.rand(*input_tensor.shape).astype(input_dtype)
input_name = input_tensor.name
output_name = sess.get_outputs()[0].name
# Warmup runs
elapsed = 0.0
for _ in range(3):
start_time = time.time()
for _ in range(self.num_warmup_runs):
sess.run([output_name], {input_name: input_data})
elapsed = time.time() - start_time
# Compute number of runs as higher of min_time or num_timed_runs
num_runs = max(round(self.min_time / (elapsed + eps) * self.num_warmup_runs), self.num_timed_runs)
# Timed runs
run_times = []
for _ in TQDM(range(num_runs), desc=onnx_file):
start_time = time.time()
sess.run([output_name], {input_name: input_data})
run_times.append((time.time() - start_time) * 1000) # Convert to milliseconds
run_times = self.iterative_sigma_clipping(np.array(run_times), sigma=2, max_iters=5) # sigma clipping
return np.mean(run_times), np.std(run_times)
def generate_table_row(self, model_name, t_onnx, t_engine, model_info):
"""Generates a formatted string for a table row that includes model performance and metric details."""
layers, params, gradients, flops = model_info
return (
f"| {model_name:18s} | {self.imgsz} | - | {t_onnx[0]:.2f} ± {t_onnx[1]:.2f} ms | {t_engine[0]:.2f} ± "
f"{t_engine[1]:.2f} ms | {params / 1e6:.1f} | {flops:.1f} |"
)
@staticmethod
def generate_results_dict(model_name, t_onnx, t_engine, model_info):
"""Generates a dictionary of model details including name, parameters, GFLOPS and speed metrics."""
layers, params, gradients, flops = model_info
return {
"model/name": model_name,
"model/parameters": params,
"model/GFLOPs": round(flops, 3),
"model/speed_ONNX(ms)": round(t_onnx[0], 3),
"model/speed_TensorRT(ms)": round(t_engine[0], 3),
}
@staticmethod
def print_table(table_rows):
"""Formats and prints a comparison table for different models with given statistics and performance data."""
gpu = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "GPU"
header = (
f"| Model | size<br><sup>(pixels) | mAP<sup>val<br>50-95 | Speed<br><sup>CPU ONNX<br>(ms) | "
f"Speed<br><sup>{gpu} TensorRT<br>(ms) | params<br><sup>(M) | FLOPs<br><sup>(B) |"
)
separator = (
"|-------------|---------------------|--------------------|------------------------------|"
"-----------------------------------|------------------|-----------------|"
)
print(f"\n\n{header}")
print(separator)
for row in table_rows:
print(row)
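As a quick illustration of the `iterative_sigma_clipping()` helper above (timings are synthetic), a single outlier is trimmed while the stable runs survive:

```python
import numpy as np

from ultralytics.utils.benchmarks import ProfileModels

run_times = np.array([10.1, 10.3, 9.9, 10.2, 10.0, 25.0])  # ms, with one 25 ms outlier
clipped = ProfileModels.iterative_sigma_clipping(run_times, sigma=2, max_iters=3)
print(clipped)  # the 25.0 ms run falls outside mean ± 2σ and is removed
```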

ultralytics/utils/callbacks/__init__.py
@@ -0,0 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from .base import add_integration_callbacks, default_callbacks, get_default_callbacks
__all__ = "add_integration_callbacks", "default_callbacks", "get_default_callbacks"

ultralytics/utils/callbacks/base.py
@@ -0,0 +1,219 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""Base callbacks."""
from collections import defaultdict
from copy import deepcopy
# Trainer callbacks ----------------------------------------------------------------------------------------------------
def on_pretrain_routine_start(trainer):
"""Called before the pretraining routine starts."""
pass
def on_pretrain_routine_end(trainer):
"""Called after the pretraining routine ends."""
pass
def on_train_start(trainer):
"""Called when the training starts."""
pass
def on_train_epoch_start(trainer):
"""Called at the start of each training epoch."""
pass
def on_train_batch_start(trainer):
"""Called at the start of each training batch."""
pass
def optimizer_step(trainer):
"""Called when the optimizer takes a step."""
pass
def on_before_zero_grad(trainer):
"""Called before the gradients are set to zero."""
pass
def on_train_batch_end(trainer):
"""Called at the end of each training batch."""
pass
def on_train_epoch_end(trainer):
"""Called at the end of each training epoch."""
pass
def on_fit_epoch_end(trainer):
"""Called at the end of each fit epoch (train + val)."""
pass
def on_model_save(trainer):
"""Called when the model is saved."""
pass
def on_train_end(trainer):
"""Called when the training ends."""
pass
def on_params_update(trainer):
"""Called when the model parameters are updated."""
pass
def teardown(trainer):
"""Called during the teardown of the training process."""
pass
# Validator callbacks --------------------------------------------------------------------------------------------------
def on_val_start(validator):
"""Called when the validation starts."""
pass
def on_val_batch_start(validator):
"""Called at the start of each validation batch."""
pass
def on_val_batch_end(validator):
"""Called at the end of each validation batch."""
pass
def on_val_end(validator):
"""Called when the validation ends."""
pass
# Predictor callbacks --------------------------------------------------------------------------------------------------
def on_predict_start(predictor):
"""Called when the prediction starts."""
pass
def on_predict_batch_start(predictor):
"""Called at the start of each prediction batch."""
pass
def on_predict_batch_end(predictor):
"""Called at the end of each prediction batch."""
pass
def on_predict_postprocess_end(predictor):
"""Called after the post-processing of the prediction ends."""
pass
def on_predict_end(predictor):
"""Called when the prediction ends."""
pass
# Exporter callbacks ---------------------------------------------------------------------------------------------------
def on_export_start(exporter):
"""Called when the model export starts."""
pass
def on_export_end(exporter):
"""Called when the model export ends."""
pass
default_callbacks = {
# Run in trainer
"on_pretrain_routine_start": [on_pretrain_routine_start],
"on_pretrain_routine_end": [on_pretrain_routine_end],
"on_train_start": [on_train_start],
"on_train_epoch_start": [on_train_epoch_start],
"on_train_batch_start": [on_train_batch_start],
"optimizer_step": [optimizer_step],
"on_before_zero_grad": [on_before_zero_grad],
"on_train_batch_end": [on_train_batch_end],
"on_train_epoch_end": [on_train_epoch_end],
"on_fit_epoch_end": [on_fit_epoch_end], # fit = train + val
"on_model_save": [on_model_save],
"on_train_end": [on_train_end],
"on_params_update": [on_params_update],
"teardown": [teardown],
# Run in validator
"on_val_start": [on_val_start],
"on_val_batch_start": [on_val_batch_start],
"on_val_batch_end": [on_val_batch_end],
"on_val_end": [on_val_end],
# Run in predictor
"on_predict_start": [on_predict_start],
"on_predict_batch_start": [on_predict_batch_start],
"on_predict_postprocess_end": [on_predict_postprocess_end],
"on_predict_batch_end": [on_predict_batch_end],
"on_predict_end": [on_predict_end],
# Run in exporter
"on_export_start": [on_export_start],
"on_export_end": [on_export_end],
}
def get_default_callbacks():
"""
Return a copy of the default_callbacks dictionary with lists as default values.
Returns:
(defaultdict): A defaultdict with keys from default_callbacks and empty lists as default values.
"""
return defaultdict(list, deepcopy(default_callbacks))
def add_integration_callbacks(instance):
"""
Add integration callbacks from various sources to the instance's callbacks.
Args:
instance (Trainer, Predictor, Validator, Exporter): An object with a 'callbacks' attribute that is a dictionary
of callback lists.
"""
# Load HUB callbacks
from .hub import callbacks as hub_cb
callbacks_list = [hub_cb]
# Load training callbacks
if "Trainer" in instance.__class__.__name__:
from .clearml import callbacks as clear_cb
from .comet import callbacks as comet_cb
from .dvc import callbacks as dvc_cb
from .mlflow import callbacks as mlflow_cb
from .neptune import callbacks as neptune_cb
from .raytune import callbacks as tune_cb
from .tensorboard import callbacks as tb_cb
from .wb import callbacks as wb_cb
callbacks_list.extend([clear_cb, comet_cb, dvc_cb, mlflow_cb, neptune_cb, tune_cb, tb_cb, wb_cb])
# Add the callbacks to the callbacks dictionary
for callbacks in callbacks_list:
for k, v in callbacks.items():
if v not in instance.callbacks[k]:
instance.callbacks[k].append(v)
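Because every event maps to a plain list, registering a custom hook is just an append. A small sketch (the `log_epoch` function and the `SimpleNamespace` trainer are stand-ins, not part of the API):

```python
from types import SimpleNamespace

from ultralytics.utils.callbacks.base import get_default_callbacks

def log_epoch(trainer):
    """Hypothetical user callback fired at the start of each training epoch."""
    print(f"starting epoch {trainer.epoch}")

callbacks = get_default_callbacks()           # defaultdict: event name -> list of handlers
callbacks["on_train_epoch_start"].append(log_epoch)

trainer = SimpleNamespace(epoch=0)            # stand-in for a real trainer object
for f in callbacks["on_train_epoch_start"]:   # how a trainer dispatches an event
    f(trainer)
```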

ultralytics/utils/callbacks/clearml.py
@@ -0,0 +1,152 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING
try:
assert not TESTS_RUNNING # do not log pytest
assert SETTINGS["clearml"] is True # verify integration is enabled
import clearml
from clearml import Task
from clearml.binding.frameworks.pytorch_bind import PatchPyTorchModelIO
from clearml.binding.matplotlib_bind import PatchedMatplotlib
assert hasattr(clearml, "__version__") # verify package is not directory
except (ImportError, AssertionError):
clearml = None
def _log_debug_samples(files, title="Debug Samples") -> None:
"""
Log files (images) as debug samples in the ClearML task.
Args:
files (list): A list of file paths in PosixPath format.
title (str): A title that groups together images with the same values.
"""
import re
if task := Task.current_task():
for f in files:
if f.exists():
it = re.search(r"_batch(\d+)", f.name)
iteration = int(it.groups()[0]) if it else 0
task.get_logger().report_image(
title=title, series=f.name.replace(it.group(), ""), local_path=str(f), iteration=iteration
)
def _log_plot(title, plot_path) -> None:
"""
Log an image as a plot in the plot section of ClearML.
Args:
title (str): The title of the plot.
plot_path (str): The path to the saved image file.
"""
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
img = mpimg.imread(plot_path)
fig = plt.figure()
ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect="auto", xticks=[], yticks=[]) # no ticks
ax.imshow(img)
Task.current_task().get_logger().report_matplotlib_figure(
title=title, series="", figure=fig, report_interactive=False
)
def on_pretrain_routine_start(trainer):
"""Runs at start of pretraining routine; initializes and connects/ logs task to ClearML."""
try:
if task := Task.current_task():
# Make sure the automatic pytorch and matplotlib bindings are disabled!
# We are logging these plots and model files manually in the integration
PatchPyTorchModelIO.update_current_task(None)
PatchedMatplotlib.update_current_task(None)
else:
task = Task.init(
project_name=trainer.args.project or "YOLOv8",
task_name=trainer.args.name,
tags=["YOLOv8"],
output_uri=True,
reuse_last_task_id=False,
auto_connect_frameworks={"pytorch": False, "matplotlib": False},
)
LOGGER.warning(
"ClearML Initialized a new task. If you want to run remotely, "
"please add clearml-init and connect your arguments before initializing YOLO."
)
task.connect(vars(trainer.args), name="General")
except Exception as e:
LOGGER.warning(f"WARNING ⚠️ ClearML installed but not initialized correctly, not logging this run. {e}")
def on_train_epoch_end(trainer):
"""Logs debug samples for the first epoch of YOLO training and report current training progress."""
if task := Task.current_task():
# Log debug samples
if trainer.epoch == 1:
_log_debug_samples(sorted(trainer.save_dir.glob("train_batch*.jpg")), "Mosaic")
# Report the current training progress
for k, v in trainer.label_loss_items(trainer.tloss, prefix="train").items():
task.get_logger().report_scalar("train", k, v, iteration=trainer.epoch)
for k, v in trainer.lr.items():
task.get_logger().report_scalar("lr", k, v, iteration=trainer.epoch)
def on_fit_epoch_end(trainer):
"""Reports model information to logger at the end of an epoch."""
if task := Task.current_task():
# You should have access to the validation bboxes under jdict
task.get_logger().report_scalar(
title="Epoch Time", series="Epoch Time", value=trainer.epoch_time, iteration=trainer.epoch
)
for k, v in trainer.metrics.items():
task.get_logger().report_scalar("val", k, v, iteration=trainer.epoch)
if trainer.epoch == 0:
from ultralytics.utils.torch_utils import model_info_for_loggers
for k, v in model_info_for_loggers(trainer).items():
task.get_logger().report_single_value(k, v)
def on_val_end(validator):
"""Logs validation results including labels and predictions."""
if Task.current_task():
# Log val_labels and val_pred
_log_debug_samples(sorted(validator.save_dir.glob("val*.jpg")), "Validation")
def on_train_end(trainer):
"""Logs final model and its name on training completion."""
if task := Task.current_task():
# Log final results, CM matrix + PR plots
files = [
"results.png",
"confusion_matrix.png",
"confusion_matrix_normalized.png",
*(f"{x}_curve.png" for x in ("F1", "PR", "P", "R")),
]
files = [(trainer.save_dir / f) for f in files if (trainer.save_dir / f).exists()] # filter
for f in files:
_log_plot(title=f.stem, plot_path=f)
# Report final metrics
for k, v in trainer.validator.metrics.results_dict.items():
task.get_logger().report_single_value(k, v)
# Log the final model
task.update_output_model(model_path=str(trainer.best), model_name=trainer.args.name, auto_delete_file=False)
callbacks = (
{
"on_pretrain_routine_start": on_pretrain_routine_start,
"on_train_epoch_end": on_train_epoch_end,
"on_fit_epoch_end": on_fit_epoch_end,
"on_val_end": on_val_end,
"on_train_end": on_train_end,
}
if clearml
else {}
)
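Every integration in this package, ClearML included, relies on the same guarded-import idiom: the try block asserts the setting is enabled and the package is genuine, and on any failure the module name is rebound to None so the trailing conditional resolves to an empty callbacks dict. A minimal sketch of the pattern (`some_logger` is hypothetical):

```python
try:
    import some_logger  # hypothetical integration package
    assert hasattr(some_logger, "__version__")  # guard against a same-named local directory
except (ImportError, AssertionError):
    some_logger = None

def on_train_end(trainer):
    some_logger.log("done")  # only ever reached when the import succeeded

callbacks = {"on_train_end": on_train_end} if some_logger else {}  # {} disables the integration
```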

ultralytics/utils/callbacks/comet.py
@@ -0,0 +1,375 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from ultralytics.utils import LOGGER, RANK, SETTINGS, TESTS_RUNNING, ops
try:
assert not TESTS_RUNNING # do not log pytest
assert SETTINGS["comet"] is True # verify integration is enabled
import comet_ml
assert hasattr(comet_ml, "__version__") # verify package is not directory
import os
from pathlib import Path
# Ensures certain logging functions only run for supported tasks
COMET_SUPPORTED_TASKS = ["detect"]
# Names of plots created by YOLOv8 that are logged to Comet
EVALUATION_PLOT_NAMES = "F1_curve", "P_curve", "R_curve", "PR_curve", "confusion_matrix"
LABEL_PLOT_NAMES = "labels", "labels_correlogram"
_comet_image_prediction_count = 0
except (ImportError, AssertionError):
comet_ml = None
def _get_comet_mode():
"""Returns the mode of comet set in the environment variables, defaults to 'online' if not set."""
return os.getenv("COMET_MODE", "online")
def _get_comet_model_name():
"""Returns the model name for Comet from the environment variable 'COMET_MODEL_NAME' or defaults to 'YOLOv8'."""
return os.getenv("COMET_MODEL_NAME", "YOLOv8")
def _get_eval_batch_logging_interval():
"""Get the evaluation batch logging interval from environment variable or use default value 1."""
return int(os.getenv("COMET_EVAL_BATCH_LOGGING_INTERVAL", 1))
def _get_max_image_predictions_to_log():
"""Get the maximum number of image predictions to log from the environment variables."""
return int(os.getenv("COMET_MAX_IMAGE_PREDICTIONS", 100))
def _scale_confidence_score(score):
"""Scales the given confidence score by a factor specified in an environment variable."""
scale = float(os.getenv("COMET_MAX_CONFIDENCE_SCORE", 100.0))
return score * scale
def _should_log_confusion_matrix():
"""Determines if the confusion matrix should be logged based on the environment variable settings."""
return os.getenv("COMET_EVAL_LOG_CONFUSION_MATRIX", "false").lower() == "true"
def _should_log_image_predictions():
"""Determines whether to log image predictions based on a specified environment variable."""
return os.getenv("COMET_EVAL_LOG_IMAGE_PREDICTIONS", "true").lower() == "true"
def _get_experiment_type(mode, project_name):
"""Return an experiment based on mode and project name."""
if mode == "offline":
return comet_ml.OfflineExperiment(project_name=project_name)
return comet_ml.Experiment(project_name=project_name)
def _create_experiment(args):
"""Ensures that the experiment object is only created in a single process during distributed training."""
if RANK not in (-1, 0):
return
try:
comet_mode = _get_comet_mode()
_project_name = os.getenv("COMET_PROJECT_NAME", args.project)
experiment = _get_experiment_type(comet_mode, _project_name)
experiment.log_parameters(vars(args))
experiment.log_others(
{
"eval_batch_logging_interval": _get_eval_batch_logging_interval(),
"log_confusion_matrix_on_eval": _should_log_confusion_matrix(),
"log_image_predictions": _should_log_image_predictions(),
"max_image_predictions": _get_max_image_predictions_to_log(),
}
)
experiment.log_other("Created from", "yolov8")
except Exception as e:
LOGGER.warning(f"WARNING ⚠️ Comet installed but not initialized correctly, not logging this run. {e}")
def _fetch_trainer_metadata(trainer):
"""Returns metadata for YOLO training including epoch and asset saving status."""
curr_epoch = trainer.epoch + 1
train_num_steps_per_epoch = len(trainer.train_loader.dataset) // trainer.batch_size
curr_step = curr_epoch * train_num_steps_per_epoch
final_epoch = curr_epoch == trainer.epochs
save = trainer.args.save
save_period = trainer.args.save_period
save_interval = curr_epoch % save_period == 0
save_assets = save and save_period > 0 and save_interval and not final_epoch
return dict(curr_epoch=curr_epoch, curr_step=curr_step, save_assets=save_assets, final_epoch=final_epoch)
def _scale_bounding_box_to_original_image_shape(box, resized_image_shape, original_image_shape, ratio_pad):
"""
YOLOv8 resizes images during training and the label values are normalized based on this resized shape.
This function rescales the bounding box labels to the original image shape.
"""
resized_image_height, resized_image_width = resized_image_shape
# Convert normalized xywh format predictions to xyxy in resized scale format
box = ops.xywhn2xyxy(box, h=resized_image_height, w=resized_image_width)
# Scale box predictions from resized image scale back to original image scale
box = ops.scale_boxes(resized_image_shape, box, original_image_shape, ratio_pad)
# Convert bounding box format from xyxy to xywh for Comet logging
box = ops.xyxy2xywh(box)
# Adjust xy center to correspond to the top-left corner
box[:2] -= box[2:] / 2
box = box.tolist()
return box
def _format_ground_truth_annotations_for_detection(img_idx, image_path, batch, class_name_map=None):
"""Format ground truth annotations for detection."""
indices = batch["batch_idx"] == img_idx
bboxes = batch["bboxes"][indices]
if len(bboxes) == 0:
LOGGER.debug(f"COMET WARNING: Image: {image_path} has no bounding boxes labels")
return None
cls_labels = batch["cls"][indices].squeeze(1).tolist()
if class_name_map:
cls_labels = [str(class_name_map[label]) for label in cls_labels]
original_image_shape = batch["ori_shape"][img_idx]
resized_image_shape = batch["resized_shape"][img_idx]
ratio_pad = batch["ratio_pad"][img_idx]
data = []
for box, label in zip(bboxes, cls_labels):
box = _scale_bounding_box_to_original_image_shape(box, resized_image_shape, original_image_shape, ratio_pad)
data.append(
{
"boxes": [box],
"label": f"gt_{label}",
"score": _scale_confidence_score(1.0),
}
)
return {"name": "ground_truth", "data": data}
def _format_prediction_annotations_for_detection(image_path, metadata, class_label_map=None):
"""Format YOLO predictions for object detection visualization."""
stem = image_path.stem
image_id = int(stem) if stem.isnumeric() else stem
predictions = metadata.get(image_id)
if not predictions:
LOGGER.debug(f"COMET WARNING: Image: {image_path} has no bounding boxes predictions")
return None
data = []
for prediction in predictions:
boxes = prediction["bbox"]
score = _scale_confidence_score(prediction["score"])
cls_label = prediction["category_id"]
if class_label_map:
cls_label = str(class_label_map[cls_label])
data.append({"boxes": [boxes], "label": cls_label, "score": score})
return {"name": "prediction", "data": data}
def _fetch_annotations(img_idx, image_path, batch, prediction_metadata_map, class_label_map):
"""Join the ground truth and prediction annotations if they exist."""
ground_truth_annotations = _format_ground_truth_annotations_for_detection(
img_idx, image_path, batch, class_label_map
)
prediction_annotations = _format_prediction_annotations_for_detection(
image_path, prediction_metadata_map, class_label_map
)
annotations = [
annotation for annotation in [ground_truth_annotations, prediction_annotations] if annotation is not None
]
return [annotations] if annotations else None
def _create_prediction_metadata_map(model_predictions):
"""Create metadata map for model predictions by groupings them based on image ID."""
pred_metadata_map = {}
for prediction in model_predictions:
pred_metadata_map.setdefault(prediction["image_id"], [])
pred_metadata_map[prediction["image_id"]].append(prediction)
return pred_metadata_map
def _log_confusion_matrix(experiment, trainer, curr_step, curr_epoch):
"""Log the confusion matrix to Comet experiment."""
conf_mat = trainer.validator.confusion_matrix.matrix
names = list(trainer.data["names"].values()) + ["background"]
experiment.log_confusion_matrix(
matrix=conf_mat, labels=names, max_categories=len(names), epoch=curr_epoch, step=curr_step
)
def _log_images(experiment, image_paths, curr_step, annotations=None):
"""Logs images to the experiment with optional annotations."""
if annotations:
for image_path, annotation in zip(image_paths, annotations):
experiment.log_image(image_path, name=image_path.stem, step=curr_step, annotations=annotation)
else:
for image_path in image_paths:
experiment.log_image(image_path, name=image_path.stem, step=curr_step)
def _log_image_predictions(experiment, validator, curr_step):
"""Logs predicted boxes for a single image during training."""
global _comet_image_prediction_count
task = validator.args.task
if task not in COMET_SUPPORTED_TASKS:
return
jdict = validator.jdict
if not jdict:
return
predictions_metadata_map = _create_prediction_metadata_map(jdict)
dataloader = validator.dataloader
class_label_map = validator.names
batch_logging_interval = _get_eval_batch_logging_interval()
max_image_predictions = _get_max_image_predictions_to_log()
for batch_idx, batch in enumerate(dataloader):
if (batch_idx + 1) % batch_logging_interval != 0:
continue
image_paths = batch["im_file"]
for img_idx, image_path in enumerate(image_paths):
if _comet_image_prediction_count >= max_image_predictions:
return
image_path = Path(image_path)
annotations = _fetch_annotations(
img_idx,
image_path,
batch,
predictions_metadata_map,
class_label_map,
)
_log_images(
experiment,
[image_path],
curr_step,
annotations=annotations,
)
_comet_image_prediction_count += 1
def _log_plots(experiment, trainer):
"""Logs evaluation plots and label plots for the experiment."""
plot_filenames = [trainer.save_dir / f"{plots}.png" for plots in EVALUATION_PLOT_NAMES]
_log_images(experiment, plot_filenames, None)
label_plot_filenames = [trainer.save_dir / f"{labels}.jpg" for labels in LABEL_PLOT_NAMES]
_log_images(experiment, label_plot_filenames, None)
def _log_model(experiment, trainer):
"""Log the best-trained model to Comet.ml."""
model_name = _get_comet_model_name()
experiment.log_model(model_name, file_or_folder=str(trainer.best), file_name="best_gift_v10n.pt", overwrite=True)
def on_pretrain_routine_start(trainer):
"""Creates or resumes a CometML experiment at the start of a YOLO pre-training routine."""
experiment = comet_ml.get_global_experiment()
is_alive = getattr(experiment, "alive", False)
if not experiment or not is_alive:
_create_experiment(trainer.args)
def on_train_epoch_end(trainer):
"""Log metrics and save batch images at the end of training epochs."""
experiment = comet_ml.get_global_experiment()
if not experiment:
return
metadata = _fetch_trainer_metadata(trainer)
curr_epoch = metadata["curr_epoch"]
curr_step = metadata["curr_step"]
experiment.log_metrics(trainer.label_loss_items(trainer.tloss, prefix="train"), step=curr_step, epoch=curr_epoch)
if curr_epoch == 1:
_log_images(experiment, trainer.save_dir.glob("train_batch*.jpg"), curr_step)
def on_fit_epoch_end(trainer):
"""Logs model assets at the end of each epoch."""
experiment = comet_ml.get_global_experiment()
if not experiment:
return
metadata = _fetch_trainer_metadata(trainer)
curr_epoch = metadata["curr_epoch"]
curr_step = metadata["curr_step"]
save_assets = metadata["save_assets"]
experiment.log_metrics(trainer.metrics, step=curr_step, epoch=curr_epoch)
experiment.log_metrics(trainer.lr, step=curr_step, epoch=curr_epoch)
if curr_epoch == 1:
from ultralytics.utils.torch_utils import model_info_for_loggers
experiment.log_metrics(model_info_for_loggers(trainer), step=curr_step, epoch=curr_epoch)
if not save_assets:
return
_log_model(experiment, trainer)
if _should_log_confusion_matrix():
_log_confusion_matrix(experiment, trainer, curr_step, curr_epoch)
if _should_log_image_predictions():
_log_image_predictions(experiment, trainer.validator, curr_step)
def on_train_end(trainer):
"""Perform operations at the end of training."""
experiment = comet_ml.get_global_experiment()
if not experiment:
return
metadata = _fetch_trainer_metadata(trainer)
curr_epoch = metadata["curr_epoch"]
curr_step = metadata["curr_step"]
plots = trainer.args.plots
_log_model(experiment, trainer)
if plots:
_log_plots(experiment, trainer)
_log_confusion_matrix(experiment, trainer, curr_step, curr_epoch)
_log_image_predictions(experiment, trainer.validator, curr_step)
experiment.end()
global _comet_image_prediction_count
_comet_image_prediction_count = 0
callbacks = (
{
"on_pretrain_routine_start": on_pretrain_routine_start,
"on_train_epoch_end": on_train_epoch_end,
"on_fit_epoch_end": on_fit_epoch_end,
"on_train_end": on_train_end,
}
if comet_ml
else {}
)
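A quick numeric check of the last step in `_scale_bounding_box_to_original_image_shape()`, which shifts an xywh box from a center origin to a top-left origin (box values are synthetic):

```python
import numpy as np

box = np.array([100.0, 80.0, 40.0, 20.0])  # center-x, center-y, width, height
box[:2] -= box[2:] / 2                     # same adjustment as in the callback
print(box)                                 # [80. 70. 40. 20.] -> top-left x, y, w, h
```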

ultralytics/utils/callbacks/dvc.py
@@ -0,0 +1,145 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, checks
try:
assert not TESTS_RUNNING # do not log pytest
assert SETTINGS["dvc"] is True # verify integration is enabled
import dvclive
assert checks.check_version("dvclive", "2.11.0", verbose=True)
import os
import re
from pathlib import Path
# DVCLive logger instance
live = None
_processed_plots = {}
# `on_fit_epoch_end` is also called on final validation (probably needs to be fixed); for now this is how we
# distinguish final evaluation of the best model from last-epoch validation
_training_epoch = False
except (ImportError, AssertionError, TypeError):
dvclive = None
def _log_images(path, prefix=""):
"""Logs images at specified path with an optional prefix using DVCLive."""
if live:
name = path.name
# Group images by batch to enable sliders in UI
if m := re.search(r"_batch(\d+)", name):
ni = m[1]
new_stem = re.sub(r"_batch(\d+)", "_batch", path.stem)
name = (Path(new_stem) / ni).with_suffix(path.suffix)
live.log_image(os.path.join(prefix, name), path)
def _log_plots(plots, prefix=""):
"""Logs plot images for training progress if they have not been previously processed."""
for name, params in plots.items():
timestamp = params["timestamp"]
if _processed_plots.get(name) != timestamp:
_log_images(name, prefix)
_processed_plots[name] = timestamp
def _log_confusion_matrix(validator):
"""Logs the confusion matrix for the given validator using DVCLive."""
targets = []
preds = []
matrix = validator.confusion_matrix.matrix
names = list(validator.names.values())
if validator.confusion_matrix.task == "detect":
names += ["background"]
for ti, pred in enumerate(matrix.T.astype(int)):
for pi, num in enumerate(pred):
targets.extend([names[ti]] * num)
preds.extend([names[pi]] * num)
live.log_sklearn_plot("confusion_matrix", targets, preds, name="cf.json", normalized=True)
def on_pretrain_routine_start(trainer):
"""Initializes DVCLive logger for training metadata during pre-training routine."""
try:
global live
live = dvclive.Live(save_dvc_exp=True, cache_images=True)
LOGGER.info("DVCLive is detected and auto logging is enabled (run 'yolo settings dvc=False' to disable).")
except Exception as e:
LOGGER.warning(f"WARNING ⚠️ DVCLive installed but not initialized correctly, not logging this run. {e}")
def on_pretrain_routine_end(trainer):
"""Logs plots related to the training process at the end of the pretraining routine."""
_log_plots(trainer.plots, "train")
def on_train_start(trainer):
"""Logs the training parameters if DVCLive logging is active."""
if live:
live.log_params(trainer.args)
def on_train_epoch_start(trainer):
"""Sets the global variable _training_epoch value to True at the start of training each epoch."""
global _training_epoch
_training_epoch = True
def on_fit_epoch_end(trainer):
"""Logs training metrics and model info, and advances to next step on the end of each fit epoch."""
global _training_epoch
if live and _training_epoch:
all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix="train"), **trainer.metrics, **trainer.lr}
for metric, value in all_metrics.items():
live.log_metric(metric, value)
if trainer.epoch == 0:
from ultralytics.utils.torch_utils import model_info_for_loggers
for metric, value in model_info_for_loggers(trainer).items():
live.log_metric(metric, value, plot=False)
_log_plots(trainer.plots, "train")
_log_plots(trainer.validator.plots, "val")
live.next_step()
_training_epoch = False
def on_train_end(trainer):
"""Logs the best metrics, plots, and confusion matrix at the end of training if DVCLive is active."""
if live:
# At the end log the best metrics. It runs validator on the best model internally.
all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix="train"), **trainer.metrics, **trainer.lr}
for metric, value in all_metrics.items():
live.log_metric(metric, value, plot=False)
_log_plots(trainer.plots, "val")
_log_plots(trainer.validator.plots, "val")
_log_confusion_matrix(trainer.validator)
if trainer.best.exists():
live.log_artifact(trainer.best, copy=True, type="model")
live.end()
callbacks = (
{
"on_pretrain_routine_start": on_pretrain_routine_start,
"on_pretrain_routine_end": on_pretrain_routine_end,
"on_train_start": on_train_start,
"on_train_epoch_start": on_train_epoch_start,
"on_fit_epoch_end": on_fit_epoch_end,
"on_train_end": on_train_end,
}
if dvclive
else {}
)
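The batch-grouping regex in `_log_images()` above deserves a worked example: it rewrites a name like `train_batch3.jpg` into a `train_batch/3.jpg` style path so the DVC UI can render one slider per plot group (the filename is synthetic):

```python
import re
from pathlib import Path

path = Path("train_batch3.jpg")  # hypothetical training plot
name = path.name
if m := re.search(r"_batch(\d+)", name):
    ni = m[1]  # the batch index, "3"
    new_stem = re.sub(r"_batch(\d+)", "_batch", path.stem)
    name = (Path(new_stem) / ni).with_suffix(path.suffix)
print(name)  # train_batch/3.jpg (POSIX path separator)
```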

ultralytics/utils/callbacks/hub.py
@@ -0,0 +1,108 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import json
from time import time
from ultralytics.hub.utils import HUB_WEB_ROOT, PREFIX, events
from ultralytics.utils import LOGGER, SETTINGS
def on_pretrain_routine_end(trainer):
"""Logs info before starting timer for upload rate limit."""
session = getattr(trainer, "hub_session", None)
if session:
# Start timer for upload rate limit
session.timers = {
"metrics": time(),
"ckpt": time(),
} # start timer on session.rate_limit
def on_fit_epoch_end(trainer):
"""Uploads training progress metrics at the end of each epoch."""
session = getattr(trainer, "hub_session", None)
if session:
# Upload metrics after val end
all_plots = {
**trainer.label_loss_items(trainer.tloss, prefix="train"),
**trainer.metrics,
}
if trainer.epoch == 0:
from ultralytics.utils.torch_utils import model_info_for_loggers
all_plots = {**all_plots, **model_info_for_loggers(trainer)}
session.metrics_queue[trainer.epoch] = json.dumps(all_plots)
# If any metrics fail to upload, add them to the queue to attempt uploading again.
if session.metrics_upload_failed_queue:
session.metrics_queue.update(session.metrics_upload_failed_queue)
if time() - session.timers["metrics"] > session.rate_limits["metrics"]:
session.upload_metrics()
session.timers["metrics"] = time() # reset timer
session.metrics_queue = {} # reset queue
def on_model_save(trainer):
"""Saves checkpoints to Ultralytics HUB with rate limiting."""
session = getattr(trainer, "hub_session", None)
if session:
# Upload checkpoints with rate limiting
is_best = trainer.best_fitness == trainer.fitness
if time() - session.timers["ckpt"] > session.rate_limits["ckpt"]:
LOGGER.info(f"{PREFIX}Uploading checkpoint {HUB_WEB_ROOT}/models/{session.model.id}")
session.upload_model(trainer.epoch, trainer.last, is_best)
session.timers["ckpt"] = time() # reset timer
def on_train_end(trainer):
"""Upload final model and metrics to Ultralytics HUB at the end of training."""
session = getattr(trainer, "hub_session", None)
if session:
# Upload final model and metrics with exponential backoff
LOGGER.info(f"{PREFIX}Syncing final model...")
session.upload_model(
trainer.epoch,
trainer.best,
map=trainer.metrics.get("metrics/mAP50-95(B)", 0),
final=True,
)
session.alive = False # stop heartbeats
LOGGER.info(f"{PREFIX}Done ✅\n" f"{PREFIX}View model at {session.model_url} 🚀")
def on_train_start(trainer):
"""Run events on train start."""
events(trainer.args)
def on_val_start(validator):
"""Runs events on validation start."""
events(validator.args)
def on_predict_start(predictor):
"""Run events on predict start."""
events(predictor.args)
def on_export_start(exporter):
"""Run events on export start."""
events(exporter.args)
callbacks = (
{
"on_pretrain_routine_end": on_pretrain_routine_end,
"on_fit_epoch_end": on_fit_epoch_end,
"on_model_save": on_model_save,
"on_train_end": on_train_end,
"on_train_start": on_train_start,
"on_val_start": on_val_start,
"on_predict_start": on_predict_start,
"on_export_start": on_export_start,
}
if SETTINGS["hub"] is True
else {}
) # verify enabled
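The timer bookkeeping above follows one rule: an upload runs only when more than `rate_limits[key]` seconds have passed since `timers[key]`, and the timer resets immediately afterwards. A self-contained sketch of that rule (the names mirror the session fields, but nothing here touches the HUB API, and the 900 s limit is assumed):

```python
from time import time

timers = {"metrics": time()}    # started in on_pretrain_routine_end
rate_limits = {"metrics": 900}  # assumed limit in seconds; the real value comes from the session

def maybe_upload_metrics():
    """Rate-limited upload check, mirroring on_fit_epoch_end."""
    if time() - timers["metrics"] > rate_limits["metrics"]:
        print("uploading queued metrics")  # stands in for session.upload_metrics()
        timers["metrics"] = time()         # reset the timer after an upload
```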

ultralytics/utils/callbacks/mlflow.py
@@ -0,0 +1,133 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
MLflow Logging for Ultralytics YOLO.
This module enables MLflow logging for Ultralytics YOLO. It logs metrics, parameters, and model artifacts.
For setting up, a tracking URI should be specified. The logging can be customized using environment variables.
Commands:
1. To set a project name:
`export MLFLOW_EXPERIMENT_NAME=<your_experiment_name>` or use the project=<project> argument
2. To set a run name:
`export MLFLOW_RUN=<your_run_name>` or use the name=<name> argument
3. To start a local MLflow server:
mlflow server --backend-store-uri runs/mlflow
It will by default start a local server at http://127.0.0.1:5000.
To specify a different URI, set the MLFLOW_TRACKING_URI environment variable.
4. To kill all running MLflow server instances:
ps aux | grep 'mlflow' | grep -v 'grep' | awk '{print $2}' | xargs kill -9
"""
from ultralytics.utils import LOGGER, RUNS_DIR, SETTINGS, TESTS_RUNNING, colorstr
try:
import os
assert not TESTS_RUNNING or "test_mlflow" in os.environ.get("PYTEST_CURRENT_TEST", "") # do not log pytest
assert SETTINGS["mlflow"] is True # verify integration is enabled
import mlflow
assert hasattr(mlflow, "__version__") # verify package is not directory
from pathlib import Path
PREFIX = colorstr("MLflow: ")
SANITIZE = lambda x: {k.replace("(", "").replace(")", ""): float(v) for k, v in x.items()}
except (ImportError, AssertionError):
mlflow = None
def on_pretrain_routine_end(trainer):
"""
Log training parameters to MLflow at the end of the pretraining routine.
This function sets up MLflow logging based on environment variables and trainer arguments. It sets the tracking URI,
experiment name, and run name, then starts the MLflow run if not already active. It finally logs the parameters
from the trainer.
Args:
trainer (ultralytics.engine.trainer.BaseTrainer): The training object with arguments and parameters to log.
Global:
mlflow: The imported mlflow module to use for logging.
Environment Variables:
MLFLOW_TRACKING_URI: The URI for MLflow tracking. If not set, defaults to 'runs/mlflow'.
MLFLOW_EXPERIMENT_NAME: The name of the MLflow experiment. If not set, defaults to trainer.args.project.
MLFLOW_RUN: The name of the MLflow run. If not set, defaults to trainer.args.name.
MLFLOW_KEEP_RUN_ACTIVE: Boolean indicating whether to keep the MLflow run active after the end of the training phase.
"""
global mlflow
uri = os.environ.get("MLFLOW_TRACKING_URI") or str(RUNS_DIR / "mlflow")
LOGGER.debug(f"{PREFIX} tracking uri: {uri}")
mlflow.set_tracking_uri(uri)
# Set experiment and run names
experiment_name = os.environ.get("MLFLOW_EXPERIMENT_NAME") or trainer.args.project or "/Shared/YOLOv8"
run_name = os.environ.get("MLFLOW_RUN") or trainer.args.name
mlflow.set_experiment(experiment_name)
mlflow.autolog()
try:
active_run = mlflow.active_run() or mlflow.start_run(run_name=run_name)
LOGGER.info(f"{PREFIX}logging run_id({active_run.info.run_id}) to {uri}")
if Path(uri).is_dir():
LOGGER.info(f"{PREFIX}view at http://127.0.0.1:5000 with 'mlflow server --backend-store-uri {uri}'")
LOGGER.info(f"{PREFIX}disable with 'yolo settings mlflow=False'")
mlflow.log_params(dict(trainer.args))
except Exception as e:
LOGGER.warning(f"{PREFIX}WARNING ⚠️ Failed to initialize: {e}\n" f"{PREFIX}WARNING ⚠️ Not tracking this run")
def on_train_epoch_end(trainer):
"""Log training metrics at the end of each train epoch to MLflow."""
if mlflow:
mlflow.log_metrics(
metrics={
**SANITIZE(trainer.lr),
**SANITIZE(trainer.label_loss_items(trainer.tloss, prefix="train")),
},
step=trainer.epoch,
)
def on_fit_epoch_end(trainer):
"""Log training metrics at the end of each fit epoch to MLflow."""
if mlflow:
mlflow.log_metrics(metrics=SANITIZE(trainer.metrics), step=trainer.epoch)
def on_train_end(trainer):
"""Log model artifacts at the end of the training."""
if mlflow:
mlflow.log_artifact(str(trainer.best.parent)) # log save_dir/weights directory with best_gift_v10n.pt and last.pt
for f in trainer.save_dir.glob("*"): # log all other files in save_dir
if f.suffix in {".png", ".jpg", ".csv", ".pt", ".yaml"}:
mlflow.log_artifact(str(f))
keep_run_active = os.environ.get("MLFLOW_KEEP_RUN_ACTIVE", "False").lower() == "true"
if keep_run_active:
LOGGER.info(f"{PREFIX}mlflow run still alive, remember to close it using mlflow.end_run()")
else:
mlflow.end_run()
LOGGER.debug(f"{PREFIX}mlflow run ended")
LOGGER.info(
f"{PREFIX}results logged to {mlflow.get_tracking_uri()}\n"
f"{PREFIX}disable with 'yolo settings mlflow=False'"
)
callbacks = (
{
"on_pretrain_routine_end": on_pretrain_routine_end,
"on_train_epoch_end": on_train_epoch_end,
"on_fit_epoch_end": on_fit_epoch_end,
"on_train_end": on_train_end,
}
if mlflow
else {}
)
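A quick check of the `SANITIZE` helper defined in the import guard above, which strips parentheses (characters MLflow rejects in metric names) and coerces values to float (the metrics dict is synthetic):

```python
SANITIZE = lambda x: {k.replace("(", "").replace(")", ""): float(v) for k, v in x.items()}

print(SANITIZE({"metrics/mAP50-95(B)": "0.52", "lr/pg0": 0.01}))
# -> {'metrics/mAP50-95B': 0.52, 'lr/pg0': 0.01}
```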

ultralytics/utils/callbacks/neptune.py
@@ -0,0 +1,112 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING
try:
assert not TESTS_RUNNING # do not log pytest
assert SETTINGS["neptune"] is True # verify integration is enabled
import neptune
from neptune.types import File
assert hasattr(neptune, "__version__")
run = None # NeptuneAI experiment logger instance
except (ImportError, AssertionError):
neptune = None
def _log_scalars(scalars, step=0):
"""Log scalars to the NeptuneAI experiment logger."""
if run:
for k, v in scalars.items():
run[k].append(value=v, step=step)
def _log_images(imgs_dict, group=""):
"""Log scalars to the NeptuneAI experiment logger."""
if run:
for k, v in imgs_dict.items():
run[f"{group}/{k}"].upload(File(v))
def _log_plot(title, plot_path):
"""
Log plots to the NeptuneAI experiment logger.
Args:
title (str): Title of the plot.
plot_path (PosixPath | str): Path to the saved image file.
"""
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
img = mpimg.imread(plot_path)
fig = plt.figure()
ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect="auto", xticks=[], yticks=[]) # no ticks
ax.imshow(img)
run[f"Plots/{title}"].upload(fig)
def on_pretrain_routine_start(trainer):
"""Callback function called before the training routine starts."""
try:
global run
run = neptune.init_run(project=trainer.args.project or "YOLOv8", name=trainer.args.name, tags=["YOLOv8"])
run["Configuration/Hyperparameters"] = {k: "" if v is None else v for k, v in vars(trainer.args).items()}
except Exception as e:
LOGGER.warning(f"WARNING ⚠️ NeptuneAI installed but not initialized correctly, not logging this run. {e}")
def on_train_epoch_end(trainer):
"""Callback function called at end of each training epoch."""
_log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1)
_log_scalars(trainer.lr, trainer.epoch + 1)
if trainer.epoch == 1:
_log_images({f.stem: str(f) for f in trainer.save_dir.glob("train_batch*.jpg")}, "Mosaic")
def on_fit_epoch_end(trainer):
"""Callback function called at end of each fit (train+val) epoch."""
if run and trainer.epoch == 0:
from ultralytics.utils.torch_utils import model_info_for_loggers
run["Configuration/Model"] = model_info_for_loggers(trainer)
_log_scalars(trainer.metrics, trainer.epoch + 1)
def on_val_end(validator):
"""Callback function called at end of each validation."""
if run:
# Log val_labels and val_pred
_log_images({f.stem: str(f) for f in validator.save_dir.glob("val*.jpg")}, "Validation")
def on_train_end(trainer):
"""Callback function called at end of training."""
if run:
# Log final results, CM matrix + PR plots
files = [
"results.png",
"confusion_matrix.png",
"confusion_matrix_normalized.png",
*(f"{x}_curve.png" for x in ("F1", "PR", "P", "R")),
]
files = [(trainer.save_dir / f) for f in files if (trainer.save_dir / f).exists()] # filter
for f in files:
_log_plot(title=f.stem, plot_path=f)
# Log the final model
run[f"weights/{trainer.args.name or trainer.args.task}/{trainer.best.name}"].upload(File(str(trainer.best)))
callbacks = (
{
"on_pretrain_routine_start": on_pretrain_routine_start,
"on_train_epoch_end": on_train_epoch_end,
"on_fit_epoch_end": on_fit_epoch_end,
"on_val_end": on_val_end,
"on_train_end": on_train_end,
}
if neptune
else {}
)

ultralytics/utils/callbacks/raytune.py
@@ -0,0 +1,29 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from ultralytics.utils import SETTINGS
try:
assert SETTINGS["raytune"] is True # verify integration is enabled
import ray
from ray import tune
from ray.air import session
except (ImportError, AssertionError):
tune = None
def on_fit_epoch_end(trainer):
"""Sends training metrics to Ray Tune at end of each epoch."""
if ray.tune.is_session_enabled():
metrics = trainer.metrics
metrics["epoch"] = trainer.epoch
session.report(metrics)
callbacks = (
{
"on_fit_epoch_end": on_fit_epoch_end,
}
if tune
else {}
)
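# Hedged sketch (not part of the original file): driving the on_fit_epoch_end
# hook above from a Ray Tune search. Assumes `pip install "ray[tune]"` and that
# SETTINGS["raytune"] is enabled; the lr0 search space is illustrative only.
if __name__ == "__main__":
    from ray import tune as ray_tune  # local alias, since `tune` above may be None

    def trainable(config):
        from ultralytics import YOLO

        model = YOLO("yolov8n.pt")
        model.train(data="coco8.yaml", epochs=1, lr0=config["lr0"])  # metrics reported per epoch

    tuner = ray_tune.Tuner(trainable, param_space={"lr0": ray_tune.loguniform(1e-5, 1e-1)})
    tuner.fit()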


@ -0,0 +1,106 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import contextlib
from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, colorstr
try:
# WARNING: do not move SummaryWriter import due to protobuf bug https://github.com/ultralytics/ultralytics/pull/4674
from torch.utils.tensorboard import SummaryWriter
assert not TESTS_RUNNING # do not log pytest
assert SETTINGS["tensorboard"] is True # verify integration is enabled
WRITER = None # TensorBoard SummaryWriter instance
PREFIX = colorstr("TensorBoard: ")
# Imports below only required if TensorBoard enabled
import warnings
from copy import deepcopy
from ultralytics.utils.torch_utils import de_parallel, torch
except (ImportError, AssertionError, TypeError, AttributeError):
# TypeError for handling 'Descriptors cannot not be created directly.' protobuf errors in Windows
# AttributeError: module 'tensorflow' has no attribute 'io' if 'tensorflow' not installed
SummaryWriter = None
def _log_scalars(scalars, step=0):
"""Logs scalar values to TensorBoard."""
if WRITER:
for k, v in scalars.items():
WRITER.add_scalar(k, v, step)
def _log_tensorboard_graph(trainer):
"""Log model graph to TensorBoard."""
# Input image
imgsz = trainer.args.imgsz
imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz
p = next(trainer.model.parameters()) # for device, type
im = torch.zeros((1, 3, *imgsz), device=p.device, dtype=p.dtype) # input image (must be zeros, not empty)
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning) # suppress jit trace warning
warnings.simplefilter("ignore", category=torch.jit.TracerWarning) # suppress jit trace warning
# Try simple method first (YOLO)
with contextlib.suppress(Exception):
trainer.model.eval() # place in .eval() mode to avoid BatchNorm statistics changes
WRITER.add_graph(torch.jit.trace(de_parallel(trainer.model), im, strict=False), [])
LOGGER.info(f"{PREFIX}model graph visualization added ✅")
return
# Fallback to TorchScript export steps (RTDETR)
try:
model = deepcopy(de_parallel(trainer.model))
model.eval()
model = model.fuse(verbose=False)
for m in model.modules():
if hasattr(m, "export"): # Detect, RTDETRDecoder (Segment and Pose use Detect base class)
m.export = True
m.format = "torchscript"
model(im) # dry run
WRITER.add_graph(torch.jit.trace(model, im, strict=False), [])
LOGGER.info(f"{PREFIX}model graph visualization added ✅")
except Exception as e:
LOGGER.warning(f"{PREFIX}WARNING ⚠️ TensorBoard graph visualization failure {e}")
def on_pretrain_routine_start(trainer):
"""Initialize TensorBoard logging with SummaryWriter."""
if SummaryWriter:
try:
global WRITER
WRITER = SummaryWriter(str(trainer.save_dir))
LOGGER.info(f"{PREFIX}Start with 'tensorboard --logdir {trainer.save_dir}', view at http://localhost:6006/")
except Exception as e:
LOGGER.warning(f"{PREFIX}WARNING ⚠️ TensorBoard not initialized correctly, not logging this run. {e}")
def on_train_start(trainer):
"""Log TensorBoard graph."""
if WRITER:
_log_tensorboard_graph(trainer)
def on_train_epoch_end(trainer):
"""Logs scalar statistics at the end of a training epoch."""
_log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1)
_log_scalars(trainer.lr, trainer.epoch + 1)
def on_fit_epoch_end(trainer):
"""Logs epoch metrics at end of training epoch."""
_log_scalars(trainer.metrics, trainer.epoch + 1)
callbacks = (
{
"on_pretrain_routine_start": on_pretrain_routine_start,
"on_train_start": on_train_start,
"on_fit_epoch_end": on_fit_epoch_end,
"on_train_epoch_end": on_train_epoch_end,
}
if SummaryWriter
else {}
)
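# Illustrative sketch (not part of the original file) of what _log_scalars does:
# each metric key becomes a TensorBoard scalar tag logged against the epoch index.
if __name__ == "__main__":
    from torch.utils.tensorboard import SummaryWriter as _SW

    writer = _SW("runs/example")  # hypothetical log directory
    for step, loss in enumerate([0.9, 0.7, 0.6], start=1):
        writer.add_scalar("train/box_loss", loss, step)  # the same call the callback makes
    writer.close()  # then: tensorboard --logdir runs/example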


@ -0,0 +1,163 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from ultralytics.utils import SETTINGS, TESTS_RUNNING
from ultralytics.utils.torch_utils import model_info_for_loggers
try:
assert not TESTS_RUNNING # do not log pytest
assert SETTINGS["wandb"] is True # verify integration is enabled
import wandb as wb
assert hasattr(wb, "__version__") # verify package is not directory
import numpy as np
import pandas as pd
_processed_plots = {}
except (ImportError, AssertionError):
wb = None
def _custom_table(x, y, classes, title="Precision Recall Curve", x_title="Recall", y_title="Precision"):
"""
Create and log a custom metric visualization to wandb.plot.pr_curve.
This function crafts a custom metric visualization that mimics the behavior of wandb's default precision-recall
curve while allowing for enhanced customization. The visual metric is useful for monitoring model performance across
different classes.
Args:
x (List): Values for the x-axis; expected to have length N.
y (List): Corresponding values for the y-axis; also expected to have length N.
classes (List): Labels identifying the class of each point; length N.
title (str, optional): Title for the plot; defaults to 'Precision Recall Curve'.
x_title (str, optional): Label for the x-axis; defaults to 'Recall'.
y_title (str, optional): Label for the y-axis; defaults to 'Precision'.
Returns:
(wandb.Object): A wandb object suitable for logging, showcasing the crafted metric visualization.
"""
df = pd.DataFrame({"class": classes, "y": y, "x": x}).round(3)
fields = {"x": "x", "y": "y", "class": "class"}
string_fields = {"title": title, "x-axis-title": x_title, "y-axis-title": y_title}
return wb.plot_table(
"wandb/area-under-curve/v0", wb.Table(dataframe=df), fields=fields, string_fields=string_fields
)
def _plot_curve(
x,
y,
names=None,
id="precision-recall",
title="Precision Recall Curve",
x_title="Recall",
y_title="Precision",
num_x=100,
only_mean=False,
):
"""
Log a metric curve visualization.
This function generates a metric curve based on input data and logs the visualization to wandb.
The curve can represent aggregated data (mean) or individual class data, depending on the 'only_mean' flag.
Args:
x (np.ndarray): Data points for the x-axis with length N.
y (np.ndarray): Corresponding data points for the y-axis with shape CxN, where C is the number of classes.
names (list, optional): Names of the classes corresponding to the y-axis data; length C. Defaults to [].
id (str, optional): Unique identifier for the logged data in wandb. Defaults to 'precision-recall'.
title (str, optional): Title for the visualization plot. Defaults to 'Precision Recall Curve'.
x_title (str, optional): Label for the x-axis. Defaults to 'Recall'.
y_title (str, optional): Label for the y-axis. Defaults to 'Precision'.
num_x (int, optional): Number of interpolated data points for visualization. Defaults to 100.
        only_mean (bool, optional): Flag to indicate if only the mean curve should be plotted. Defaults to False.
Note:
The function leverages the '_custom_table' function to generate the actual visualization.
"""
# Create new x
if names is None:
names = []
x_new = np.linspace(x[0], x[-1], num_x).round(5)
# Create arrays for logging
x_log = x_new.tolist()
y_log = np.interp(x_new, x, np.mean(y, axis=0)).round(3).tolist()
if only_mean:
table = wb.Table(data=list(zip(x_log, y_log)), columns=[x_title, y_title])
wb.run.log({title: wb.plot.line(table, x_title, y_title, title=title)})
else:
classes = ["mean"] * len(x_log)
for i, yi in enumerate(y):
x_log.extend(x_new) # add new x
y_log.extend(np.interp(x_new, x, yi)) # interpolate y to new x
classes.extend([names[i]] * len(x_new)) # add class names
wb.log({id: _custom_table(x_log, y_log, classes, title, x_title, y_title)}, commit=False)
def _log_plots(plots, step):
"""Logs plots from the input dictionary if they haven't been logged already at the specified step."""
for name, params in plots.items():
timestamp = params["timestamp"]
if _processed_plots.get(name) != timestamp:
wb.run.log({name.stem: wb.Image(str(name))}, step=step)
_processed_plots[name] = timestamp
def on_pretrain_routine_start(trainer):
"""Initiate and start project if module is present."""
wb.run or wb.init(project=trainer.args.project or "YOLOv8", name=trainer.args.name, config=vars(trainer.args))
def on_fit_epoch_end(trainer):
"""Logs training metrics and model information at the end of an epoch."""
wb.run.log(trainer.metrics, step=trainer.epoch + 1)
_log_plots(trainer.plots, step=trainer.epoch + 1)
_log_plots(trainer.validator.plots, step=trainer.epoch + 1)
if trainer.epoch == 0:
wb.run.log(model_info_for_loggers(trainer), step=trainer.epoch + 1)
def on_train_epoch_end(trainer):
"""Log metrics and save images at the end of each training epoch."""
wb.run.log(trainer.label_loss_items(trainer.tloss, prefix="train"), step=trainer.epoch + 1)
wb.run.log(trainer.lr, step=trainer.epoch + 1)
if trainer.epoch == 1:
_log_plots(trainer.plots, step=trainer.epoch + 1)
def on_train_end(trainer):
"""Save the best model as an artifact at end of training."""
_log_plots(trainer.validator.plots, step=trainer.epoch + 1)
_log_plots(trainer.plots, step=trainer.epoch + 1)
art = wb.Artifact(type="model", name=f"run_{wb.run.id}_model")
if trainer.best.exists():
art.add_file(trainer.best)
wb.run.log_artifact(art, aliases=["best"])
for curve_name, curve_values in zip(trainer.validator.metrics.curves, trainer.validator.metrics.curves_results):
x, y, x_title, y_title = curve_values
_plot_curve(
x,
y,
names=list(trainer.validator.metrics.names.values()),
id=f"curves/{curve_name}",
title=curve_name,
x_title=x_title,
y_title=y_title,
)
wb.run.finish() # required or run continues on dashboard
callbacks = (
{
"on_pretrain_routine_start": on_pretrain_routine_start,
"on_train_epoch_end": on_train_epoch_end,
"on_fit_epoch_end": on_fit_epoch_end,
"on_train_end": on_train_end,
}
if wb
else {}
)
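# Standalone sketch (not part of the original file) of the resampling inside
# _plot_curve: per-class curves y with shape (C, N) are interpolated onto num_x
# shared x-positions so all series fit one wandb table. Shapes are assumptions.
if __name__ == "__main__":
    import numpy as np

    x = np.linspace(0, 1, 50)  # original x-axis, length N
    y = np.stack([np.sqrt(x), x**2])  # two hypothetical class curves, shape (C, N)
    x_new = np.linspace(x[0], x[-1], 100).round(5)  # num_x resampled positions
    mean_curve = np.interp(x_new, x, np.mean(y, axis=0)).round(3)  # the 'mean' series
    print(mean_curve[:5])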

ultralytics/utils/checks.py Normal file

@ -0,0 +1,731 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import contextlib
import glob
import inspect
import math
import os
import platform
import re
import shutil
import subprocess
import time
from importlib import metadata
from pathlib import Path
from typing import Optional
import cv2
import numpy as np
import requests
import torch
from matplotlib import font_manager
from ultralytics.utils import (
ASSETS,
AUTOINSTALL,
LINUX,
LOGGER,
ONLINE,
ROOT,
USER_CONFIG_DIR,
SimpleNamespace,
ThreadingLocked,
TryExcept,
clean_url,
colorstr,
downloads,
emojis,
is_colab,
is_docker,
is_github_action_running,
is_jupyter,
is_kaggle,
is_online,
is_pip_package,
url2file,
)
PYTHON_VERSION = platform.python_version()
def parse_requirements(file_path=ROOT.parent / "requirements.txt", package=""):
"""
Parse a requirements.txt file, ignoring lines that start with '#' and any text after '#'.
Args:
file_path (Path): Path to the requirements.txt file.
package (str, optional): Python package to use instead of requirements.txt file, i.e. package='ultralytics'.
Returns:
(List[Dict[str, str]]): List of parsed requirements as dictionaries with `name` and `specifier` keys.
Example:
```python
from ultralytics.utils.checks import parse_requirements
parse_requirements(package='ultralytics')
```
"""
if package:
requires = [x for x in metadata.distribution(package).requires if "extra == " not in x]
else:
requires = Path(file_path).read_text().splitlines()
requirements = []
for line in requires:
line = line.strip()
if line and not line.startswith("#"):
line = line.split("#")[0].strip() # ignore inline comments
match = re.match(r"([a-zA-Z0-9-_]+)\s*([<>!=~]+.*)?", line)
if match:
requirements.append(SimpleNamespace(name=match[1], specifier=match[2].strip() if match[2] else ""))
return requirements
def parse_version(version="0.0.0") -> tuple:
"""
Convert a version string to a tuple of integers, ignoring any extra non-numeric string attached to the version. This
function replaces deprecated 'pkg_resources.parse_version(v)'.
Args:
version (str): Version string, i.e. '2.0.1+cpu'
Returns:
(tuple): Tuple of integers representing the numeric part of the version and the extra string, i.e. (2, 0, 1)
"""
try:
return tuple(map(int, re.findall(r"\d+", version)[:3])) # '2.0.1+cpu' -> (2, 0, 1)
except Exception as e:
LOGGER.warning(f"WARNING ⚠️ failure for parse_version({version}), returning (0, 0, 0): {e}")
return 0, 0, 0
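# Quick behavior check (not part of the original file); results follow directly
# from the regex above.
if __name__ == "__main__":
    assert parse_version("2.0.1+cpu") == (2, 0, 1)  # non-numeric '+cpu' suffix is ignored
    assert parse_version("1.13") == (1, 13)  # at most the first three numbers are kept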
def is_ascii(s) -> bool:
"""
Check if a string is composed of only ASCII characters.
Args:
s (str): String to be checked.
Returns:
(bool): True if the string is composed only of ASCII characters, False otherwise.
"""
# Convert list, tuple, None, etc. to string
s = str(s)
# Check if the string is composed of only ASCII characters
return all(ord(c) < 128 for c in s)
def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0):
"""
Verify image size is a multiple of the given stride in each dimension. If the image size is not a multiple of the
stride, update it to the nearest multiple of the stride that is greater than or equal to the given floor value.
Args:
        imgsz (int | List[int]): Image size.
stride (int): Stride value.
min_dim (int): Minimum number of dimensions.
max_dim (int): Maximum number of dimensions.
floor (int): Minimum allowed value for image size.
Returns:
(List[int]): Updated image size.
"""
# Convert stride to integer if it is a tensor
stride = int(stride.max() if isinstance(stride, torch.Tensor) else stride)
# Convert image size to list if it is an integer
if isinstance(imgsz, int):
imgsz = [imgsz]
elif isinstance(imgsz, (list, tuple)):
imgsz = list(imgsz)
elif isinstance(imgsz, str): # i.e. '640' or '[640,640]'
imgsz = [int(imgsz)] if imgsz.isnumeric() else eval(imgsz)
else:
raise TypeError(
f"'imgsz={imgsz}' is of invalid type {type(imgsz).__name__}. "
f"Valid imgsz types are int i.e. 'imgsz=640' or list i.e. 'imgsz=[640,640]'"
)
# Apply max_dim
if len(imgsz) > max_dim:
msg = (
"'train' and 'val' imgsz must be an integer, while 'predict' and 'export' imgsz may be a [h, w] list "
"or an integer, i.e. 'yolo export imgsz=640,480' or 'yolo export imgsz=640'"
)
if max_dim != 1:
raise ValueError(f"imgsz={imgsz} is not a valid image size. {msg}")
LOGGER.warning(f"WARNING ⚠️ updating to 'imgsz={max(imgsz)}'. {msg}")
imgsz = [max(imgsz)]
# Make image size a multiple of the stride
sz = [max(math.ceil(x / stride) * stride, floor) for x in imgsz]
# Print warning message if image size was updated
if sz != imgsz:
LOGGER.warning(f"WARNING ⚠️ imgsz={imgsz} must be multiple of max stride {stride}, updating to {sz}")
# Add missing dimensions if necessary
sz = [sz[0], sz[0]] if min_dim == 2 and len(sz) == 1 else sz[0] if min_dim == 1 and len(sz) == 1 else sz
return sz
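# Behavior sketch (not part of the original file); expected values follow from
# the rounding and dimension rules above.
if __name__ == "__main__":
    assert check_imgsz(640) == 640  # already a stride multiple, scalar in -> scalar out
    assert check_imgsz(630, stride=32) == 640  # rounded up to the nearest stride multiple (with a warning)
    assert check_imgsz(640, min_dim=2) == [640, 640]  # expanded to an [h, w] pair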
def check_version(
current: str = "0.0.0",
required: str = "0.0.0",
name: str = "version",
hard: bool = False,
verbose: bool = False,
msg: str = "",
) -> bool:
"""
Check current version against the required version or range.
Args:
current (str): Current version or package name to get version from.
required (str): Required version or range (in pip-style format).
name (str, optional): Name to be used in warning message.
hard (bool, optional): If True, raise an AssertionError if the requirement is not met.
verbose (bool, optional): If True, print warning message if requirement is not met.
msg (str, optional): Extra message to display if verbose.
Returns:
(bool): True if requirement is met, False otherwise.
Example:
```python
# Check if current version is exactly 22.04
check_version(current='22.04', required='==22.04')
# Check if current version is greater than or equal to 22.04
check_version(current='22.10', required='22.04') # assumes '>=' inequality if none passed
# Check if current version is less than or equal to 22.04
check_version(current='22.04', required='<=22.04')
# Check if current version is between 20.04 (inclusive) and 22.04 (exclusive)
check_version(current='21.10', required='>20.04,<22.04')
```
"""
if not current: # if current is '' or None
LOGGER.warning(f"WARNING ⚠️ invalid check_version({current}, {required}) requested, please check values.")
return True
elif not current[0].isdigit(): # current is package name rather than version string, i.e. current='ultralytics'
try:
name = current # assigned package name to 'name' arg
current = metadata.version(current) # get version string from package name
except metadata.PackageNotFoundError as e:
if hard:
raise ModuleNotFoundError(emojis(f"WARNING ⚠️ {current} package is required but not installed")) from e
else:
return False
if not required: # if required is '' or None
return True
op = ""
version = ""
result = True
c = parse_version(current) # '1.2.3' -> (1, 2, 3)
for r in required.strip(",").split(","):
op, version = re.match(r"([^0-9]*)([\d.]+)", r).groups() # split '>=22.04' -> ('>=', '22.04')
v = parse_version(version) # '1.2.3' -> (1, 2, 3)
if op == "==" and c != v:
result = False
elif op == "!=" and c == v:
result = False
elif op in (">=", "") and not (c >= v): # if no constraint passed assume '>=required'
result = False
elif op == "<=" and not (c <= v):
result = False
elif op == ">" and not (c > v):
result = False
elif op == "<" and not (c < v):
result = False
if not result:
warning = f"WARNING ⚠️ {name}{op}{version} is required, but {name}=={current} is currently installed {msg}"
if hard:
raise ModuleNotFoundError(emojis(warning)) # assert version requirements met
if verbose:
LOGGER.warning(warning)
return result
def check_latest_pypi_version(package_name="ultralytics"):
"""
Returns the latest version of a PyPI package without downloading or installing it.
    Args:
        package_name (str): The name of the package to find the latest version for.

    Returns:
        (str | None): The latest version of the package, or None if the lookup fails.
"""
with contextlib.suppress(Exception):
requests.packages.urllib3.disable_warnings() # Disable the InsecureRequestWarning
response = requests.get(f"https://pypi.org/pypi/{package_name}/json", timeout=3)
if response.status_code == 200:
return response.json()["info"]["version"]
def check_pip_update_available():
"""
Checks if a new version of the ultralytics package is available on PyPI.
Returns:
(bool): True if an update is available, False otherwise.
"""
if ONLINE and is_pip_package():
with contextlib.suppress(Exception):
from ultralytics import __version__
latest = check_latest_pypi_version()
if check_version(__version__, f"<{latest}"): # check if current version is < latest version
LOGGER.info(
f"New https://pypi.org/project/ultralytics/{latest} available 😃 "
f"Update with 'pip install -U ultralytics'"
)
return True
return False
@ThreadingLocked()
def check_font(font="Arial.ttf"):
"""
Find font locally or download to user's configuration directory if it does not already exist.
Args:
font (str): Path or name of font.
Returns:
file (Path): Resolved font file path.
"""
name = Path(font).name
# Check USER_CONFIG_DIR
file = USER_CONFIG_DIR / name
if file.exists():
return file
# Check system fonts
matches = [s for s in font_manager.findSystemFonts() if font in s]
if any(matches):
return matches[0]
# Download to USER_CONFIG_DIR if missing
url = f"https://ultralytics.com/assets/{name}"
if downloads.is_url(url, check=True):
downloads.safe_download(url=url, file=file)
return file
def check_python(minimum: str = "3.8.0") -> bool:
"""
Check current python version against the required minimum version.
Args:
minimum (str): Required minimum version of python.
Returns:
(bool): Whether the installed Python version meets the minimum constraints.
"""
return check_version(PYTHON_VERSION, minimum, name="Python ", hard=True)
@TryExcept()
def check_requirements(requirements=ROOT.parent / "requirements.txt", exclude=(), install=True, cmds=""):
"""
Check if installed dependencies meet YOLOv8 requirements and attempt to auto-update if needed.
Args:
requirements (Union[Path, str, List[str]]): Path to a requirements.txt file, a single package requirement as a
string, or a list of package requirements as strings.
exclude (Tuple[str]): Tuple of package names to exclude from checking.
install (bool): If True, attempt to auto-update packages that don't meet requirements.
cmds (str): Additional commands to pass to the pip install command when auto-updating.
Example:
```python
from ultralytics.utils.checks import check_requirements
# Check a requirements.txt file
check_requirements('path/to/requirements.txt')
# Check a single package
check_requirements('ultralytics>=8.0.0')
# Check multiple packages
check_requirements(['numpy', 'ultralytics>=8.0.0'])
```
"""
prefix = colorstr("red", "bold", "requirements:")
check_python() # check python version
check_torchvision() # check torch-torchvision compatibility
if isinstance(requirements, Path): # requirements.txt file
file = requirements.resolve()
assert file.exists(), f"{prefix} {file} not found, check failed."
requirements = [f"{x.name}{x.specifier}" for x in parse_requirements(file) if x.name not in exclude]
elif isinstance(requirements, str):
requirements = [requirements]
pkgs = []
for r in requirements:
r_stripped = r.split("/")[-1].replace(".git", "") # replace git+https://org/repo.git -> 'repo'
match = re.match(r"([a-zA-Z0-9-_]+)([<>!=~]+.*)?", r_stripped)
name, required = match[1], match[2].strip() if match[2] else ""
try:
assert check_version(metadata.version(name), required) # exception if requirements not met
except (AssertionError, metadata.PackageNotFoundError):
pkgs.append(r)
s = " ".join(f'"{x}"' for x in pkgs) # console string
if s:
if install and AUTOINSTALL: # check environment variable
n = len(pkgs) # number of packages updates
LOGGER.info(f"{prefix} Ultralytics requirement{'s' * (n > 1)} {pkgs} not found, attempting AutoUpdate...")
try:
t = time.time()
assert is_online(), "AutoUpdate skipped (offline)"
LOGGER.info(subprocess.check_output(f"pip install --no-cache {s} {cmds}", shell=True).decode())
dt = time.time() - t
LOGGER.info(
f"{prefix} AutoUpdate success ✅ {dt:.1f}s, installed {n} package{'s' * (n > 1)}: {pkgs}\n"
f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
)
except Exception as e:
LOGGER.warning(f"{prefix}{e}")
return False
else:
return False
return True
def check_torchvision():
"""
Checks the installed versions of PyTorch and Torchvision to ensure they're compatible.
This function checks the installed versions of PyTorch and Torchvision, and warns if they're incompatible according
to the provided compatibility table based on:
https://github.com/pytorch/vision#installation.
The compatibility table is a dictionary where the keys are PyTorch versions and the values are lists of compatible
Torchvision versions.
"""
import torchvision
# Compatibility table
compatibility_table = {"2.0": ["0.15"], "1.13": ["0.14"], "1.12": ["0.13"]}
# Extract only the major and minor versions
v_torch = ".".join(torch.__version__.split("+")[0].split(".")[:2])
v_torchvision = ".".join(torchvision.__version__.split("+")[0].split(".")[:2])
if v_torch in compatibility_table:
compatible_versions = compatibility_table[v_torch]
if all(v_torchvision != v for v in compatible_versions):
print(
f"WARNING ⚠️ torchvision=={v_torchvision} is incompatible with torch=={v_torch}.\n"
f"Run 'pip install torchvision=={compatible_versions[0]}' to fix torchvision or "
"'pip install -U torch torchvision' to update both.\n"
"For a full compatibility table see https://github.com/pytorch/vision#installation"
)
def check_suffix(file="yolov8n.pt", suffix=".pt", msg=""):
"""Check file(s) for acceptable suffix."""
if file and suffix:
if isinstance(suffix, str):
suffix = (suffix,)
for f in file if isinstance(file, (list, tuple)) else [file]:
s = Path(f).suffix.lower().strip() # file suffix
if len(s):
assert s in suffix, f"{msg}{f} acceptable suffix is {suffix}, not {s}"
def check_yolov5u_filename(file: str, verbose: bool = True):
"""Replace legacy YOLOv5 filenames with updated YOLOv5u filenames."""
if "yolov3" in file or "yolov5" in file:
if "u.yaml" in file:
file = file.replace("u.yaml", ".yaml") # i.e. yolov5nu.yaml -> yolov5n.yaml
elif ".pt" in file and "u" not in file:
original_file = file
file = re.sub(r"(.*yolov5([nsmlx]))\.pt", "\\1u.pt", file) # i.e. yolov5n.pt -> yolov5nu.pt
file = re.sub(r"(.*yolov5([nsmlx])6)\.pt", "\\1u.pt", file) # i.e. yolov5n6.pt -> yolov5n6u.pt
file = re.sub(r"(.*yolov3(|-tiny|-spp))\.pt", "\\1u.pt", file) # i.e. yolov3-spp.pt -> yolov3-sppu.pt
if file != original_file and verbose:
LOGGER.info(
f"PRO TIP 💡 Replace 'model={original_file}' with new 'model={file}'.\nYOLOv5 'u' models are "
f"trained with https://github.com/ultralytics/ultralytics and feature improved performance vs "
f"standard YOLOv5 models trained with https://github.com/ultralytics/yolov5.\n"
)
return file
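# Rename sketch (not part of the original file); outputs follow from the
# substitutions above.
if __name__ == "__main__":
    assert check_yolov5u_filename("yolov5n.pt", verbose=False) == "yolov5nu.pt"
    assert check_yolov5u_filename("path/to/yolov3-spp.pt", verbose=False) == "path/to/yolov3-sppu.pt"
    assert check_yolov5u_filename("yolov5nu.yaml") == "yolov5n.yaml"  # YAML names map back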
def check_model_file_from_stem(model="yolov8n"):
"""Return a model filename from a valid model stem."""
if model and not Path(model).suffix and Path(model).stem in downloads.GITHUB_ASSETS_STEMS:
return Path(model).with_suffix(".pt") # add suffix, i.e. yolov8n -> yolov8n.pt
else:
return model
def check_file(file, suffix="", download=True, hard=True):
"""Search/download file (if necessary) and return path."""
check_suffix(file, suffix) # optional
file = str(file).strip() # convert to string and strip spaces
file = check_yolov5u_filename(file) # yolov5n -> yolov5nu
if (
not file
or ("://" not in file and Path(file).exists()) # '://' check required in Windows Python<3.10
or file.lower().startswith("grpc://")
): # file exists or gRPC Triton images
return file
elif download and file.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")): # download
url = file # warning: Pathlib turns :// -> :/
file = url2file(file) # '%2F' to '/', split https://url.com/file.txt?auth
if Path(file).exists():
LOGGER.info(f"Found {clean_url(url)} locally at {file}") # file already exists
else:
downloads.safe_download(url=url, file=file, unzip=False)
return file
else: # search
files = glob.glob(str(ROOT / "**" / file), recursive=True) or glob.glob(str(ROOT.parent / file)) # find file
if not files and hard:
raise FileNotFoundError(f"'{file}' does not exist")
elif len(files) > 1 and hard:
raise FileNotFoundError(f"Multiple files match '{file}', specify exact path: {files}")
return files[0] if len(files) else [] if hard else file # return file
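# Resolution sketch (not part of the original file); filenames are illustrative.
if __name__ == "__main__":
    print(check_file("yolov8n.pt"))  # existing path returned as-is, else searched under ROOT
    print(check_file("https://ultralytics.com/images/bus.jpg"))  # downloaded once, local name returned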
def check_yaml(file, suffix=(".yaml", ".yml"), hard=True):
"""Search/download YAML file (if necessary) and return path, checking suffix."""
return check_file(file, suffix, hard=hard)
def check_is_path_safe(basedir, path):
"""
Check if the resolved path is under the intended directory to prevent path traversal.
Args:
basedir (Path | str): The intended directory.
path (Path | str): The path to check.
Returns:
(bool): True if the path is safe, False otherwise.
"""
base_dir_resolved = Path(basedir).resolve()
path_resolved = Path(path).resolve()
return path_resolved.is_file() and path_resolved.parts[: len(base_dir_resolved.parts)] == base_dir_resolved.parts
def check_imshow(warn=False):
"""Check if environment supports image displays."""
try:
if LINUX:
assert "DISPLAY" in os.environ and not is_docker() and not is_colab() and not is_kaggle()
cv2.imshow("test", np.zeros((8, 8, 3), dtype=np.uint8)) # show a small 8-pixel image
cv2.waitKey(1)
cv2.destroyAllWindows()
cv2.waitKey(1)
return True
except Exception as e:
if warn:
LOGGER.warning(f"WARNING ⚠️ Environment does not support cv2.imshow() or PIL Image.show()\n{e}")
return False
def check_yolo(verbose=True, device=""):
"""Return a human-readable YOLO software and hardware summary."""
import psutil
from ultralytics.utils.torch_utils import select_device
if is_jupyter():
if check_requirements("wandb", install=False):
os.system("pip uninstall -y wandb") # uninstall wandb: unwanted account creation prompt with infinite hang
if is_colab():
shutil.rmtree("sample_data", ignore_errors=True) # remove colab /sample_data directory
if verbose:
# System info
gib = 1 << 30 # bytes per GiB
ram = psutil.virtual_memory().total
total, used, free = shutil.disk_usage("/")
s = f"({os.cpu_count()} CPUs, {ram / gib:.1f} GB RAM, {(total - free) / gib:.1f}/{total / gib:.1f} GB disk)"
with contextlib.suppress(Exception): # clear display if ipython is installed
from IPython import display
display.clear_output()
else:
s = ""
select_device(device=device, newline=False)
LOGGER.info(f"Setup complete ✅ {s}")
def collect_system_info():
"""Collect and print relevant system information including OS, Python, RAM, CPU, and CUDA."""
import psutil
from ultralytics.utils import ENVIRONMENT, is_git_dir
from ultralytics.utils.torch_utils import get_cpu_info
ram_info = psutil.virtual_memory().total / (1024**3) # Convert bytes to GB
check_yolo()
LOGGER.info(
f"\n{'OS':<20}{platform.platform()}\n"
f"{'Environment':<20}{ENVIRONMENT}\n"
f"{'Python':<20}{PYTHON_VERSION}\n"
f"{'Install':<20}{'git' if is_git_dir() else 'pip' if is_pip_package() else 'other'}\n"
f"{'RAM':<20}{ram_info:.2f} GB\n"
f"{'CPU':<20}{get_cpu_info()}\n"
f"{'CUDA':<20}{torch.version.cuda if torch and torch.cuda.is_available() else None}\n"
)
for r in parse_requirements(package="ultralytics"):
try:
current = metadata.version(r.name)
is_met = "" if check_version(current, str(r.specifier), hard=True) else ""
except metadata.PackageNotFoundError:
current = "(not installed)"
is_met = ""
LOGGER.info(f"{r.name:<20}{is_met}{current}{r.specifier}")
if is_github_action_running():
LOGGER.info(
f"\nRUNNER_OS: {os.getenv('RUNNER_OS')}\n"
f"GITHUB_EVENT_NAME: {os.getenv('GITHUB_EVENT_NAME')}\n"
f"GITHUB_WORKFLOW: {os.getenv('GITHUB_WORKFLOW')}\n"
f"GITHUB_ACTOR: {os.getenv('GITHUB_ACTOR')}\n"
f"GITHUB_REPOSITORY: {os.getenv('GITHUB_REPOSITORY')}\n"
f"GITHUB_REPOSITORY_OWNER: {os.getenv('GITHUB_REPOSITORY_OWNER')}\n"
)
def check_amp(model):
"""
This function checks the PyTorch Automatic Mixed Precision (AMP) functionality of a YOLOv8 model. If the checks
fail, it means there are anomalies with AMP on the system that may cause NaN losses or zero-mAP results, so AMP will
be disabled during training.
Args:
model (nn.Module): A YOLOv8 model instance.
Example:
```python
from ultralytics import YOLO
from ultralytics.utils.checks import check_amp
model = YOLO('yolov8n.pt').model.cuda()
check_amp(model)
```
Returns:
(bool): Returns True if the AMP functionality works correctly with YOLOv8 model, else False.
"""
device = next(model.parameters()).device # get model device
if device.type in ("cpu", "mps"):
return False # AMP only used on CUDA devices
def amp_allclose(m, im):
"""All close FP32 vs AMP results."""
a = m(im, device=device, verbose=False)[0].boxes.data # FP32 inference
with torch.cuda.amp.autocast(True):
b = m(im, device=device, verbose=False)[0].boxes.data # AMP inference
del m
return a.shape == b.shape and torch.allclose(a, b.float(), atol=0.5) # close to 0.5 absolute tolerance
im = ASSETS / "bus.jpg" # image to check
prefix = colorstr("AMP: ")
LOGGER.info(f"{prefix}running Automatic Mixed Precision (AMP) checks with YOLOv8n...")
warning_msg = "Setting 'amp=True'. If you experience zero-mAP or NaN losses you can disable AMP with amp=False."
try:
from ultralytics import YOLO
assert amp_allclose(YOLO("yolov8n.pt"), im)
LOGGER.info(f"{prefix}checks passed ✅")
except ConnectionError:
LOGGER.warning(f"{prefix}checks skipped ⚠️, offline and unable to download YOLOv8n. {warning_msg}")
except (AttributeError, ModuleNotFoundError):
LOGGER.warning(
f"{prefix}checks skipped ⚠️. "
f"Unable to load YOLOv8n due to possible Ultralytics package modifications. {warning_msg}"
)
except AssertionError:
LOGGER.warning(
f"{prefix}checks failed ❌. Anomalies were detected with AMP on your system that may lead to "
f"NaN losses or zero-mAP results, so AMP will be disabled during training."
)
return False
return True
def git_describe(path=ROOT): # path must be a directory
"""Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe."""
with contextlib.suppress(Exception):
return subprocess.check_output(f"git -C {path} describe --tags --long --always", shell=True).decode()[:-1]
return ""
def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
"""Print function arguments (optional args dict)."""
def strip_auth(v):
"""Clean longer Ultralytics HUB URLs by stripping potential authentication information."""
return clean_url(v) if (isinstance(v, str) and v.startswith("http") and len(v) > 100) else v
x = inspect.currentframe().f_back # previous frame
file, _, func, _, _ = inspect.getframeinfo(x)
if args is None: # get args automatically
args, _, _, frm = inspect.getargvalues(x)
args = {k: v for k, v in frm.items() if k in args}
try:
file = Path(file).resolve().relative_to(ROOT).with_suffix("")
except ValueError:
file = Path(file).stem
s = (f"{file}: " if show_file else "") + (f"{func}: " if show_func else "")
LOGGER.info(colorstr(s) + ", ".join(f"{k}={strip_auth(v)}" for k, v in args.items()))
def cuda_device_count() -> int:
"""
Get the number of NVIDIA GPUs available in the environment.
Returns:
(int): The number of NVIDIA GPUs available.
"""
try:
# Run the nvidia-smi command and capture its output
output = subprocess.check_output(
["nvidia-smi", "--query-gpu=count", "--format=csv,noheader,nounits"], encoding="utf-8"
)
# Take the first line and strip any leading/trailing white space
first_line = output.strip().split("\n")[0]
return int(first_line)
except (subprocess.CalledProcessError, FileNotFoundError, ValueError):
# If the command fails, nvidia-smi is not found, or output is not an integer, assume no GPUs are available
return 0
def cuda_is_available() -> bool:
"""
Check if CUDA is available in the environment.
Returns:
(bool): True if one or more NVIDIA GPUs are available, False otherwise.
"""
return cuda_device_count() > 0
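# Smoke test (not part of the original file): these helpers shell out to
# nvidia-smi rather than importing torch.cuda, so they are safe to call before
# CUDA initialization.
if __name__ == "__main__":
    print(f"GPUs detected: {cuda_device_count()}")  # 0 if nvidia-smi is unavailable
    print(f"CUDA available: {cuda_is_available()}")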
# Define constants
IS_PYTHON_3_12 = PYTHON_VERSION.startswith("3.12")

ultralytics/utils/dist.py Normal file

@ -0,0 +1,71 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import os
import shutil
import socket
import sys
import tempfile
from . import USER_CONFIG_DIR
from .torch_utils import TORCH_1_9
def find_free_network_port() -> int:
"""
Finds a free port on localhost.
It is useful in single-node training when we don't want to connect to a real main node but have to set the
`MASTER_PORT` environment variable.
"""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("127.0.0.1", 0))
return s.getsockname()[1] # port
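# Hedged sketch (not part of the original file): wiring the free port into the
# standard torch.distributed env:// rendezvous variables.
if __name__ == "__main__":
    port = find_free_network_port()
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = str(port)
    print(f"Using MASTER_PORT={port} for single-node DDP")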
def generate_ddp_file(trainer):
"""Generates a DDP file and returns its file name."""
module, name = f"{trainer.__class__.__module__}.{trainer.__class__.__name__}".rsplit(".", 1)
content = f"""
# Ultralytics Multi-GPU training temp file (should be automatically deleted after use)
overrides = {vars(trainer.args)}
if __name__ == "__main__":
from {module} import {name}
from ultralytics.utils import DEFAULT_CFG_DICT
cfg = DEFAULT_CFG_DICT.copy()
cfg.update(save_dir='') # handle the extra key 'save_dir'
trainer = {name}(cfg=cfg, overrides=overrides)
results = trainer.train()
"""
(USER_CONFIG_DIR / "DDP").mkdir(exist_ok=True)
with tempfile.NamedTemporaryFile(
prefix="_temp_",
suffix=f"{id(trainer)}.py",
mode="w+",
encoding="utf-8",
dir=USER_CONFIG_DIR / "DDP",
delete=False,
) as file:
file.write(content)
return file.name
def generate_ddp_command(world_size, trainer):
"""Generates and returns command for distributed training."""
import __main__ # noqa local import to avoid https://github.com/Lightning-AI/lightning/issues/15218
if not trainer.resume:
shutil.rmtree(trainer.save_dir) # remove the save_dir
file = generate_ddp_file(trainer)
dist_cmd = "torch.distributed.run" if TORCH_1_9 else "torch.distributed.launch"
port = find_free_network_port()
cmd = [sys.executable, "-m", dist_cmd, "--nproc_per_node", f"{world_size}", "--master_port", f"{port}", file]
return cmd, file
def ddp_cleanup(trainer, file):
"""Delete temp file if created."""
if f"{id(trainer)}.py" in file: # if temp_file suffix in file
os.remove(file)
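# End-to-end sketch (not part of the original file) of the DDP launch flow above.
# DetectionTrainer and the overrides are assumptions; world_size=2 presumes 2 GPUs.
if __name__ == "__main__":
    import subprocess

    from ultralytics.models.yolo.detect import DetectionTrainer

    trainer = DetectionTrainer(overrides={"data": "coco8.yaml", "epochs": 1, "device": "0,1"})
    cmd, tmp_file = generate_ddp_command(world_size=2, trainer=trainer)
    subprocess.run(cmd, check=True)  # spawns torch.distributed.run with the generated temp file
    ddp_cleanup(trainer, tmp_file)  # removes the temp file afterwards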


@ -0,0 +1,500 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import contextlib
import re
import shutil
import subprocess
from itertools import repeat
from multiprocessing.pool import ThreadPool
from pathlib import Path
from urllib import parse, request
import requests
import torch
from ultralytics.utils import LOGGER, TQDM, checks, clean_url, emojis, is_online, url2file
# Define Ultralytics GitHub assets maintained at https://github.com/ultralytics/assets
GITHUB_ASSETS_REPO = "ultralytics/assets"
GITHUB_ASSETS_NAMES = (
[f"yolov8{k}{suffix}.pt" for k in "nsmlx" for suffix in ("", "-cls", "-seg", "-pose", "-obb")]
+ [f"yolov5{k}{resolution}u.pt" for k in "nsmlx" for resolution in ("", "6")]
+ [f"yolov3{k}u.pt" for k in ("", "-spp", "-tiny")]
+ [f"yolov8{k}-world.pt" for k in "smlx"]
+ [f"yolov8{k}-worldv2.pt" for k in "smlx"]
+ [f"yolov9{k}.pt" for k in "ce"]
+ [f"yolo_nas_{k}.pt" for k in "sml"]
+ [f"sam_{k}.pt" for k in "bl"]
+ [f"FastSAM-{k}.pt" for k in "sx"]
+ [f"rtdetr-{k}.pt" for k in "lx"]
+ ["mobile_sam.pt"]
+ ["calibration_image_sample_data_20x128x128x3_float32.npy.zip"]
)
GITHUB_ASSETS_STEMS = [Path(k).stem for k in GITHUB_ASSETS_NAMES]
def is_url(url, check=False):
"""
Validates if the given string is a URL and optionally checks if the URL exists online.
Args:
url (str): The string to be validated as a URL.
check (bool, optional): If True, performs an additional check to see if the URL exists online.
            Defaults to False.
Returns:
(bool): Returns True for a valid URL. If 'check' is True, also returns True if the URL exists online.
Returns False otherwise.
Example:
```python
valid = is_url("https://www.example.com")
```
"""
with contextlib.suppress(Exception):
url = str(url)
result = parse.urlparse(url)
assert all([result.scheme, result.netloc]) # check if is url
if check:
with request.urlopen(url) as response:
return response.getcode() == 200 # check if exists online
return True
return False
def delete_dsstore(path, files_to_delete=(".DS_Store", "__MACOSX")):
"""
Deletes all ".DS_store" files under a specified directory.
Args:
path (str, optional): The directory path where the ".DS_store" files should be deleted.
files_to_delete (tuple): The files to be deleted.
Example:
```python
from ultralytics.utils.downloads import delete_dsstore
delete_dsstore('path/to/dir')
```
Note:
".DS_store" files are created by the Apple operating system and contain metadata about folders and files. They
are hidden system files and can cause issues when transferring files between different operating systems.
"""
for file in files_to_delete:
matches = list(Path(path).rglob(file))
LOGGER.info(f"Deleting {file} files: {matches}")
for f in matches:
f.unlink()
def zip_directory(directory, compress=True, exclude=(".DS_Store", "__MACOSX"), progress=True):
"""
Zips the contents of a directory, excluding files containing strings in the exclude list. The resulting zip file is
named after the directory and placed alongside it.
Args:
directory (str | Path): The path to the directory to be zipped.
compress (bool): Whether to compress the files while zipping. Default is True.
exclude (tuple, optional): A tuple of filename strings to be excluded. Defaults to ('.DS_Store', '__MACOSX').
progress (bool, optional): Whether to display a progress bar. Defaults to True.
Returns:
(Path): The path to the resulting zip file.
Example:
```python
from ultralytics.utils.downloads import zip_directory
file = zip_directory('path/to/dir')
```
"""
from zipfile import ZIP_DEFLATED, ZIP_STORED, ZipFile
delete_dsstore(directory)
directory = Path(directory)
if not directory.is_dir():
raise FileNotFoundError(f"Directory '{directory}' does not exist.")
    # Zip with progress bar
files_to_zip = [f for f in directory.rglob("*") if f.is_file() and all(x not in f.name for x in exclude)]
zip_file = directory.with_suffix(".zip")
compression = ZIP_DEFLATED if compress else ZIP_STORED
with ZipFile(zip_file, "w", compression) as f:
for file in TQDM(files_to_zip, desc=f"Zipping {directory} to {zip_file}...", unit="file", disable=not progress):
f.write(file, file.relative_to(directory))
return zip_file # return path to zip file
def unzip_file(file, path=None, exclude=(".DS_Store", "__MACOSX"), exist_ok=False, progress=True):
"""
Unzips a *.zip file to the specified path, excluding files containing strings in the exclude list.
If the zipfile does not contain a single top-level directory, the function will create a new
directory with the same name as the zipfile (without the extension) to extract its contents.
If a path is not provided, the function will use the parent directory of the zipfile as the default path.
Args:
file (str): The path to the zipfile to be extracted.
path (str, optional): The path to extract the zipfile to. Defaults to None.
exclude (tuple, optional): A tuple of filename strings to be excluded. Defaults to ('.DS_Store', '__MACOSX').
exist_ok (bool, optional): Whether to overwrite existing contents if they exist. Defaults to False.
progress (bool, optional): Whether to display a progress bar. Defaults to True.
Raises:
BadZipFile: If the provided file does not exist or is not a valid zipfile.
Returns:
(Path): The path to the directory where the zipfile was extracted.
Example:
```python
from ultralytics.utils.downloads import unzip_file
dir = unzip_file('path/to/file.zip')
```
"""
from zipfile import BadZipFile, ZipFile, is_zipfile
if not (Path(file).exists() and is_zipfile(file)):
raise BadZipFile(f"File '{file}' does not exist or is a bad zip file.")
if path is None:
path = Path(file).parent # default path
# Unzip the file contents
with ZipFile(file) as zipObj:
files = [f for f in zipObj.namelist() if all(x not in f for x in exclude)]
top_level_dirs = {Path(f).parts[0] for f in files}
if len(top_level_dirs) > 1 or (len(files) > 1 and not files[0].endswith("/")):
# Zip has multiple files at top level
path = extract_path = Path(path) / Path(file).stem # i.e. ../datasets/coco8
else:
# Zip has 1 top-level directory
extract_path = path # i.e. ../datasets
path = Path(path) / list(top_level_dirs)[0] # i.e. ../datasets/coco8
# Check if destination directory already exists and contains files
if path.exists() and any(path.iterdir()) and not exist_ok:
# If it exists and is not empty, return the path without unzipping
LOGGER.warning(f"WARNING ⚠️ Skipping {file} unzip as destination directory {path} is not empty.")
return path
for f in TQDM(files, desc=f"Unzipping {file} to {Path(path).resolve()}...", unit="file", disable=not progress):
# Ensure the file is within the extract_path to avoid path traversal security vulnerability
if ".." in Path(f).parts:
LOGGER.warning(f"Potentially insecure file path: {f}, skipping extraction.")
continue
zipObj.extract(f, extract_path)
return path # return unzip dir
def check_disk_space(url="https://ultralytics.com/assets/coco128.zip", path=Path.cwd(), sf=1.5, hard=True):
"""
Check if there is sufficient disk space to download and store a file.
Args:
url (str, optional): The URL to the file. Defaults to 'https://ultralytics.com/assets/coco128.zip'.
path (str | Path, optional): The path or drive to check the available free space on.
        sf (float, optional): Safety factor, the multiplier for the required free space. Defaults to 1.5.
hard (bool, optional): Whether to throw an error or not on insufficient disk space. Defaults to True.
Returns:
(bool): True if there is sufficient disk space, False otherwise.
"""
try:
r = requests.head(url) # response
assert r.status_code < 400, f"URL error for {url}: {r.status_code} {r.reason}" # check response
except Exception:
return True # requests issue, default to True
# Check file size
gib = 1 << 30 # bytes per GiB
    data = int(r.headers.get("Content-Length", 0)) / gib  # file size (GiB)
    total, used, free = (x / gib for x in shutil.disk_usage(path))  # GiB
if data * sf < free:
return True # sufficient space
# Insufficient space
text = (
f"WARNING ⚠️ Insufficient free disk space {free:.1f} GB < {data * sf:.3f} GB required, "
f"Please free {data * sf - free:.1f} GB additional disk space and try again."
)
if hard:
raise MemoryError(text)
LOGGER.warning(text)
return False
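# Worked example (not part of the original file) of the safety-factor check: a
# 2.0 GiB download with sf=1.5 needs 3.0 GiB free, so 2.5 GiB is insufficient.
if __name__ == "__main__":
    data, sf, free = 2.0, 1.5, 2.5  # GiB, illustrative numbers
    print(data * sf < free)  # False -> would warn, or raise MemoryError when hard=True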
def get_google_drive_file_info(link):
"""
Retrieves the direct download link and filename for a shareable Google Drive file link.
Args:
link (str): The shareable link of the Google Drive file.
Returns:
(str): Direct download URL for the Google Drive file.
(str): Original filename of the Google Drive file. If filename extraction fails, returns None.
Example:
```python
from ultralytics.utils.downloads import get_google_drive_file_info
link = "https://drive.google.com/file/d/1cqT-cJgANNrhIHCrEufUYhQ4RqiWG_lJ/view?usp=drive_link"
url, filename = get_google_drive_file_info(link)
```
"""
file_id = link.split("/d/")[1].split("/view")[0]
drive_url = f"https://drive.google.com/uc?export=download&id={file_id}"
filename = None
# Start session
with requests.Session() as session:
response = session.get(drive_url, stream=True)
if "quota exceeded" in str(response.content.lower()):
raise ConnectionError(
emojis(
f"❌ Google Drive file download quota exceeded. "
f"Please try again later or download this file manually at {link}."
)
)
for k, v in response.cookies.items():
if k.startswith("download_warning"):
drive_url += f"&confirm={v}" # v is token
cd = response.headers.get("content-disposition")
if cd:
filename = re.findall('filename="(.+)"', cd)[0]
return drive_url, filename
def safe_download(
url,
file=None,
dir=None,
unzip=True,
delete=False,
curl=False,
retry=3,
min_bytes=1e0,
exist_ok=False,
progress=True,
):
"""
Downloads files from a URL, with options for retrying, unzipping, and deleting the downloaded file.
Args:
url (str): The URL of the file to be downloaded.
file (str, optional): The filename of the downloaded file.
If not provided, the file will be saved with the same name as the URL.
dir (str, optional): The directory to save the downloaded file.
If not provided, the file will be saved in the current working directory.
unzip (bool, optional): Whether to unzip the downloaded file. Default: True.
delete (bool, optional): Whether to delete the downloaded file after unzipping. Default: False.
curl (bool, optional): Whether to use curl command line tool for downloading. Default: False.
retry (int, optional): The number of times to retry the download in case of failure. Default: 3.
min_bytes (float, optional): The minimum number of bytes that the downloaded file should have, to be considered
a successful download. Default: 1E0.
exist_ok (bool, optional): Whether to overwrite existing contents during unzipping. Defaults to False.
progress (bool, optional): Whether to display a progress bar during the download. Default: True.
Example:
```python
from ultralytics.utils.downloads import safe_download
link = "https://ultralytics.com/assets/bus.jpg"
path = safe_download(link)
```
"""
gdrive = url.startswith("https://drive.google.com/") # check if the URL is a Google Drive link
if gdrive:
url, file = get_google_drive_file_info(url)
f = Path(dir or ".") / (file or url2file(url)) # URL converted to filename
if "://" not in str(url) and Path(url).is_file(): # URL exists ('://' check required in Windows Python<3.10)
f = Path(url) # filename
elif not f.is_file(): # URL and file do not exist
desc = f"Downloading {url if gdrive else clean_url(url)} to '{f}'"
LOGGER.info(f"{desc}...")
f.parent.mkdir(parents=True, exist_ok=True) # make directory if missing
check_disk_space(url, path=f.parent)
for i in range(retry + 1):
try:
if curl or i > 0: # curl download with retry, continue
s = "sS" * (not progress) # silent
r = subprocess.run(["curl", "-#", f"-{s}L", url, "-o", f, "--retry", "3", "-C", "-"]).returncode
assert r == 0, f"Curl return value {r}"
else: # urllib download
method = "torch"
if method == "torch":
torch.hub.download_url_to_file(url, f, progress=progress)
else:
with request.urlopen(url) as response, TQDM(
total=int(response.getheader("Content-Length", 0)),
desc=desc,
disable=not progress,
unit="B",
unit_scale=True,
unit_divisor=1024,
) as pbar:
with open(f, "wb") as f_opened:
for data in response:
f_opened.write(data)
pbar.update(len(data))
if f.exists():
if f.stat().st_size > min_bytes:
break # success
f.unlink() # remove partial downloads
except Exception as e:
if i == 0 and not is_online():
raise ConnectionError(emojis(f"❌ Download failure for {url}. Environment is not online.")) from e
elif i >= retry:
raise ConnectionError(emojis(f"❌ Download failure for {url}. Retry limit reached.")) from e
LOGGER.warning(f"⚠️ Download failure, retrying {i + 1}/{retry} {url}...")
if unzip and f.exists() and f.suffix in ("", ".zip", ".tar", ".gz"):
from zipfile import is_zipfile
unzip_dir = (dir or f.parent).resolve() # unzip to dir if provided else unzip in place
if is_zipfile(f):
unzip_dir = unzip_file(file=f, path=unzip_dir, exist_ok=exist_ok, progress=progress) # unzip
elif f.suffix in (".tar", ".gz"):
LOGGER.info(f"Unzipping {f} to {unzip_dir}...")
subprocess.run(["tar", "xf" if f.suffix == ".tar" else "xfz", f, "--directory", unzip_dir], check=True)
if delete:
f.unlink() # remove zip
return unzip_dir
def get_github_assets(repo="ultralytics/assets", version="latest", retry=False):
"""
Retrieve the specified version's tag and assets from a GitHub repository. If the version is not specified, the
function fetches the latest release assets.
Args:
repo (str, optional): The GitHub repository in the format 'owner/repo'. Defaults to 'ultralytics/assets'.
version (str, optional): The release version to fetch assets from. Defaults to 'latest'.
retry (bool, optional): Flag to retry the request in case of a failure. Defaults to False.
Returns:
(tuple): A tuple containing the release tag and a list of asset names.
Example:
```python
tag, assets = get_github_assets(repo='ultralytics/assets', version='latest')
```
"""
if version != "latest":
version = f"tags/{version}" # i.e. tags/v6.2
url = f"https://api.github.com/repos/{repo}/releases/{version}"
r = requests.get(url) # github api
if r.status_code != 200 and r.reason != "rate limit exceeded" and retry: # failed and not 403 rate limit exceeded
r = requests.get(url) # try again
if r.status_code != 200:
LOGGER.warning(f"⚠️ GitHub assets check failure for {url}: {r.status_code} {r.reason}")
return "", []
data = r.json()
return data["tag_name"], [x["name"] for x in data["assets"]] # tag, assets i.e. ['yolov8n.pt', 'yolov8s.pt', ...]
def attempt_download_asset(file, repo="ultralytics/assets", release="v8.1.0", **kwargs):
"""
Attempt to download a file from GitHub release assets if it is not found locally. The function checks for the file
locally first, then tries to download it from the specified GitHub repository release.
Args:
file (str | Path): The filename or file path to be downloaded.
repo (str, optional): The GitHub repository in the format 'owner/repo'. Defaults to 'ultralytics/assets'.
release (str, optional): The specific release version to be downloaded. Defaults to 'v8.1.0'.
**kwargs (any): Additional keyword arguments for the download process.
Returns:
(str): The path to the downloaded file.
Example:
```python
file_path = attempt_download_asset('yolov5s.pt', repo='ultralytics/assets', release='latest')
```
"""
from ultralytics.utils import SETTINGS # scoped for circular import
# YOLOv3/5u updates
file = str(file)
file = checks.check_yolov5u_filename(file)
file = Path(file.strip().replace("'", ""))
if file.exists():
return str(file)
elif (SETTINGS["weights_dir"] / file).exists():
return str(SETTINGS["weights_dir"] / file)
else:
# URL specified
name = Path(parse.unquote(str(file))).name # decode '%2F' to '/' etc.
download_url = f"https://github.com/{repo}/releases/download"
if str(file).startswith(("http:/", "https:/")): # download
url = str(file).replace(":/", "://") # Pathlib turns :// -> :/
file = url2file(name) # parse authentication https://url.com/file.txt?auth...
if Path(file).is_file():
LOGGER.info(f"Found {clean_url(url)} locally at {file}") # file already exists
else:
safe_download(url=url, file=file, min_bytes=1e5, **kwargs)
elif repo == GITHUB_ASSETS_REPO and name in GITHUB_ASSETS_NAMES:
safe_download(url=f"{download_url}/{release}/{name}", file=file, min_bytes=1e5, **kwargs)
else:
tag, assets = get_github_assets(repo, release)
if not assets:
tag, assets = get_github_assets(repo) # latest release
if name in assets:
safe_download(url=f"{download_url}/{tag}/{name}", file=file, min_bytes=1e5, **kwargs)
return str(file)
def download(url, dir=Path.cwd(), unzip=True, delete=False, curl=False, threads=1, retry=3, exist_ok=False):
"""
Downloads files from specified URLs to a given directory. Supports concurrent downloads if multiple threads are
specified.
Args:
url (str | list): The URL or list of URLs of the files to be downloaded.
dir (Path, optional): The directory where the files will be saved. Defaults to the current working directory.
unzip (bool, optional): Flag to unzip the files after downloading. Defaults to True.
delete (bool, optional): Flag to delete the zip files after extraction. Defaults to False.
curl (bool, optional): Flag to use curl for downloading. Defaults to False.
threads (int, optional): Number of threads to use for concurrent downloads. Defaults to 1.
retry (int, optional): Number of retries in case of download failure. Defaults to 3.
exist_ok (bool, optional): Whether to overwrite existing contents during unzipping. Defaults to False.
Example:
```python
download('https://ultralytics.com/assets/example.zip', dir='path/to/dir', unzip=True)
```
"""
dir = Path(dir)
dir.mkdir(parents=True, exist_ok=True) # make directory
if threads > 1:
with ThreadPool(threads) as pool:
pool.map(
lambda x: safe_download(
url=x[0],
dir=x[1],
unzip=unzip,
delete=delete,
curl=curl,
retry=retry,
exist_ok=exist_ok,
progress=threads <= 1,
),
zip(url, repeat(dir)),
)
pool.close()
pool.join()
else:
for u in [url] if isinstance(url, (str, Path)) else url:
safe_download(url=u, dir=dir, unzip=unzip, delete=delete, curl=curl, retry=retry, exist_ok=exist_ok)


@ -0,0 +1,22 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from ultralytics.utils import emojis
class HUBModelError(Exception):
"""
Custom exception class for handling errors related to model fetching in Ultralytics YOLO.
This exception is raised when a requested model is not found or cannot be retrieved.
The message is also processed to include emojis for better user experience.
Attributes:
message (str): The error message displayed when the exception is raised.
Note:
The message is automatically processed through the 'emojis' function from the 'ultralytics.utils' package.
"""
def __init__(self, message="Model not found. Please check model URL and try again."):
"""Create an exception for when a model is not found."""
super().__init__(emojis(message))
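# Minimal usage sketch (not part of the original file): the message is passed
# through emojis(), which strips emoji on platforms that cannot render them.
if __name__ == "__main__":
    try:
        raise HUBModelError()
    except HUBModelError as e:
        print(e)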

ultralytics/utils/files.py Normal file

@ -0,0 +1,188 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import contextlib
import glob
import os
import shutil
import tempfile
from contextlib import contextmanager
from datetime import datetime
from pathlib import Path
class WorkingDirectory(contextlib.ContextDecorator):
"""Usage: @WorkingDirectory(dir) decorator or 'with WorkingDirectory(dir):' context manager."""
def __init__(self, new_dir):
"""Sets the working directory to 'new_dir' upon instantiation."""
self.dir = new_dir # new dir
self.cwd = Path.cwd().resolve() # current dir
def __enter__(self):
"""Changes the current directory to the specified directory."""
os.chdir(self.dir)
def __exit__(self, exc_type, exc_val, exc_tb): # noqa
"""Restore the current working directory on context exit."""
os.chdir(self.cwd)
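# Sketch of both supported forms (not part of the original file); '/tmp' is an
# illustrative directory.
if __name__ == "__main__":
    with WorkingDirectory("/tmp"):
        print(Path.cwd())  # /tmp while inside the context

    @WorkingDirectory("/tmp")
    def task():
        print(Path.cwd())  # /tmp while inside the decorated call

    task()
    print(Path.cwd())  # original working directory restored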
@contextmanager
def spaces_in_path(path):
"""
Context manager to handle paths with spaces in their names. If a path contains spaces, it replaces them with
underscores, copies the file/directory to the new path, executes the context code block, then copies the
file/directory back to its original location.
Args:
path (str | Path): The original path.
Yields:
(Path): Temporary path with spaces replaced by underscores if spaces were present, otherwise the original path.
Example:
```python
        from ultralytics.utils.files import spaces_in_path
with spaces_in_path('/path/with spaces') as new_path:
# Your code here
```
"""
# If path has spaces, replace them with underscores
if " " in str(path):
string = isinstance(path, str) # input type
path = Path(path)
# Create a temporary directory and construct the new path
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_path = Path(tmp_dir) / path.name.replace(" ", "_")
# Copy file/directory
if path.is_dir():
# tmp_path.mkdir(parents=True, exist_ok=True)
shutil.copytree(path, tmp_path)
elif path.is_file():
tmp_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy2(path, tmp_path)
try:
# Yield the temporary path
yield str(tmp_path) if string else tmp_path
finally:
# Copy file/directory back
if tmp_path.is_dir():
shutil.copytree(tmp_path, path, dirs_exist_ok=True)
elif tmp_path.is_file():
shutil.copy2(tmp_path, path) # Copy back the file
else:
# If there are no spaces, just yield the original path
yield path
def increment_path(path, exist_ok=False, sep="", mkdir=False):
"""
Increments a file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
If the path exists and exist_ok is not set to True, the path will be incremented by appending a number and sep to
the end of the path. If the path is a file, the file extension will be preserved. If the path is a directory, the
number will be appended directly to the end of the path. If mkdir is set to True, the path will be created as a
directory if it does not already exist.
Args:
path (str, pathlib.Path): Path to increment.
exist_ok (bool, optional): If True, the path will not be incremented and returned as-is. Defaults to False.
sep (str, optional): Separator to use between the path and the incrementation number. Defaults to ''.
mkdir (bool, optional): Create a directory if it does not exist. Defaults to False.
Returns:
(pathlib.Path): Incremented path.
"""
path = Path(path) # os-agnostic
if path.exists() and not exist_ok:
path, suffix = (path.with_suffix(""), path.suffix) if path.is_file() else (path, "")
# Increment the suffix until an unused path is found
for n in range(2, 9999):
p = f"{path}{sep}{n}{suffix}" # increment path
if not os.path.exists(p):
break
path = Path(p)
if mkdir:
path.mkdir(parents=True, exist_ok=True) # make directory
return path
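For example, a minimal sketch of the incrementing behavior (paths are illustrative):

```python
from ultralytics.utils.files import increment_path

# First call returns runs/exp; later calls return runs/exp2, runs/exp3, ...
save_dir = increment_path("runs/exp", mkdir=True)
print(save_dir)
```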
def file_age(path=__file__):
"""Return days since last file update."""
dt = datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime) # delta
return dt.days # + dt.seconds / 86400 # fractional days
def file_date(path=__file__):
"""Return human-readable file modification date, i.e. '2021-3-26'."""
t = datetime.fromtimestamp(Path(path).stat().st_mtime)
return f"{t.year}-{t.month}-{t.day}"
def file_size(path):
"""Return file/dir size (MB)."""
if isinstance(path, (str, Path)):
mb = 1 << 20 # bytes to MiB (1024 ** 2)
path = Path(path)
if path.is_file():
return path.stat().st_size / mb
elif path.is_dir():
return sum(f.stat().st_size for f in path.glob("**/*") if f.is_file()) / mb
return 0.0
def get_latest_run(search_dir="."):
"""Return path to most recent 'last.pt' in /runs (i.e. to --resume from)."""
last_list = glob.glob(f"{search_dir}/**/last*.pt", recursive=True)
return max(last_list, key=os.path.getctime) if last_list else ""
def update_models(model_names=("yolov8n.pt",), source_dir=Path("."), update_names=False):
"""
Updates and re-saves specified YOLO models in an 'updated_models' subdirectory.
Args:
model_names (tuple, optional): Model filenames to update, defaults to ("yolov8n.pt",).
source_dir (Path, optional): Directory containing models and target subdirectory, defaults to current directory.
update_names (bool, optional): Update model names from a data YAML.
Example:
```python
from ultralytics.utils.files import update_models
model_names = (f"rtdetr-{size}.pt" for size in "lx")
update_models(model_names)
```
"""
from ultralytics import YOLO
from ultralytics.nn.autobackend import default_class_names
target_dir = source_dir / "updated_models"
target_dir.mkdir(parents=True, exist_ok=True) # Ensure target directory exists
for model_name in model_names:
model_path = source_dir / model_name
print(f"Loading model from {model_path}")
# Load model
model = YOLO(model_path)
model.half()
if update_names: # update model names from a dataset YAML
model.model.names = default_class_names("coco8.yaml")
# Define new save path
save_path = target_dir / model_name
# Save model using model.save()
print(f"Re-saving {model_name} model to {save_path}")
model.save(save_path, use_dill=False)

407
ultralytics/utils/instance.py Normal file

@ -0,0 +1,407 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from collections import abc
from itertools import repeat
from numbers import Number
from typing import List
import numpy as np
from .ops import ltwh2xywh, ltwh2xyxy, xywh2ltwh, xywh2xyxy, xyxy2ltwh, xyxy2xywh
def _ntuple(n):
"""From PyTorch internals."""
def parse(x):
"""Parse bounding boxes format between XYWH and LTWH."""
return x if isinstance(x, abc.Iterable) else tuple(repeat(x, n))
return parse
to_2tuple = _ntuple(2)
to_4tuple = _ntuple(4)
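A scalar is broadcast to an n-tuple while iterables pass through unchanged, e.g.:

```python
to_4tuple(0.5)           # (0.5, 0.5, 0.5, 0.5)
to_4tuple([1, 2, 3, 4])  # [1, 2, 3, 4], iterables are returned as-is
```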
# `xyxy` means left top and right bottom
# `xywh` means center x, center y and width, height(YOLO format)
# `ltwh` means left top and width, height(COCO format)
_formats = ["xyxy", "xywh", "ltwh"]
__all__ = ("Bboxes",) # tuple or list
class Bboxes:
"""
A class for handling bounding boxes.
The class supports various bounding box formats like 'xyxy', 'xywh', and 'ltwh'.
Bounding box data should be provided in numpy arrays.
Attributes:
bboxes (numpy.ndarray): The bounding boxes stored in a 2D numpy array.
format (str): The format of the bounding boxes ('xyxy', 'xywh', or 'ltwh').
Note:
This class does not handle normalization or denormalization of bounding boxes.
"""
def __init__(self, bboxes, format="xyxy") -> None:
"""Initializes the Bboxes class with bounding box data in a specified format."""
assert format in _formats, f"Invalid bounding box format: {format}, format must be one of {_formats}"
bboxes = bboxes[None, :] if bboxes.ndim == 1 else bboxes
assert bboxes.ndim == 2
assert bboxes.shape[1] == 4
self.bboxes = bboxes
self.format = format
# self.normalized = normalized
def convert(self, format):
"""Converts bounding box format from one type to another."""
assert format in _formats, f"Invalid bounding box format: {format}, format must be one of {_formats}"
if self.format == format:
return
elif self.format == "xyxy":
func = xyxy2xywh if format == "xywh" else xyxy2ltwh
elif self.format == "xywh":
func = xywh2xyxy if format == "xyxy" else xywh2ltwh
else:
func = ltwh2xyxy if format == "xyxy" else ltwh2xywh
self.bboxes = func(self.bboxes)
self.format = format
def areas(self):
"""Return box areas."""
self.convert("xyxy")
return (self.bboxes[:, 2] - self.bboxes[:, 0]) * (self.bboxes[:, 3] - self.bboxes[:, 1])
# def denormalize(self, w, h):
# if not self.normalized:
# return
# assert (self.bboxes <= 1.0).all()
# self.bboxes[:, 0::2] *= w
# self.bboxes[:, 1::2] *= h
# self.normalized = False
#
# def normalize(self, w, h):
# if self.normalized:
# return
# assert (self.bboxes > 1.0).any()
# self.bboxes[:, 0::2] /= w
# self.bboxes[:, 1::2] /= h
# self.normalized = True
def mul(self, scale):
"""
Args:
scale (tuple | list | int): the scale for four coords.
"""
if isinstance(scale, Number):
scale = to_4tuple(scale)
assert isinstance(scale, (tuple, list))
assert len(scale) == 4
self.bboxes[:, 0] *= scale[0]
self.bboxes[:, 1] *= scale[1]
self.bboxes[:, 2] *= scale[2]
self.bboxes[:, 3] *= scale[3]
def add(self, offset):
"""
Args:
offset (tuple | list | int): the offset for four coords.
"""
if isinstance(offset, Number):
offset = to_4tuple(offset)
assert isinstance(offset, (tuple, list))
assert len(offset) == 4
self.bboxes[:, 0] += offset[0]
self.bboxes[:, 1] += offset[1]
self.bboxes[:, 2] += offset[2]
self.bboxes[:, 3] += offset[3]
def __len__(self):
"""Return the number of boxes."""
return len(self.bboxes)
@classmethod
def concatenate(cls, boxes_list: List["Bboxes"], axis=0) -> "Bboxes":
"""
Concatenate a list of Bboxes objects into a single Bboxes object.
Args:
boxes_list (List[Bboxes]): A list of Bboxes objects to concatenate.
axis (int, optional): The axis along which to concatenate the bounding boxes.
Defaults to 0.
Returns:
Bboxes: A new Bboxes object containing the concatenated bounding boxes.
Note:
The input should be a list or tuple of Bboxes objects.
"""
assert isinstance(boxes_list, (list, tuple))
if not boxes_list:
return cls(np.empty((0, 4)))  # empty (0, 4) array keeps the 2D shape assertions valid
assert all(isinstance(box, Bboxes) for box in boxes_list)
if len(boxes_list) == 1:
return boxes_list[0]
return cls(np.concatenate([b.bboxes for b in boxes_list], axis=axis))
def __getitem__(self, index) -> "Bboxes":
"""
Retrieve a specific bounding box or a set of bounding boxes using indexing.
Args:
index (int, slice, or np.ndarray): The index, slice, or boolean array to select
the desired bounding boxes.
Returns:
Bboxes: A new Bboxes object containing the selected bounding boxes.
Raises:
AssertionError: If the indexed bounding boxes do not form a 2-dimensional matrix.
Note:
When using boolean indexing, make sure to provide a boolean array with the same
length as the number of bounding boxes.
"""
if isinstance(index, int):
return Bboxes(self.bboxes[index].reshape(1, -1))  # reshape keeps a single box 2D (works for numpy)
b = self.bboxes[index]
assert b.ndim == 2, f"Indexing on Bboxes with {index} failed to return a matrix!"
return Bboxes(b)
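A short usage sketch of the conversion/scaling flow (values illustrative; assumes this module is importable as ultralytics.utils.instance):

```python
import numpy as np
from ultralytics.utils.instance import Bboxes

boxes = Bboxes(np.array([[10.0, 10.0, 50.0, 40.0]]), format="xyxy")
boxes.convert("xywh")  # (cx, cy, w, h) = (30, 25, 40, 30)
boxes.mul(2.0)         # scalar is broadcast to all four coords via to_4tuple
print(boxes.areas())   # converts back to xyxy internally -> [4800.]
```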
class Instances:
"""
Container for bounding boxes, segments, and keypoints of detected objects in an image.
Attributes:
_bboxes (Bboxes): Internal object for handling bounding box operations.
keypoints (ndarray): keypoints(x, y, visible) with shape [N, 17, 3]. Default is None.
normalized (bool): Flag indicating whether the bounding box coordinates are normalized.
segments (ndarray): Segments array with shape [N, 1000, 2] after resampling.
Args:
bboxes (ndarray): An array of bounding boxes with shape [N, 4].
segments (list | ndarray, optional): A list or array of object segments. Default is None.
keypoints (ndarray, optional): An array of keypoints with shape [N, 17, 3]. Default is None.
bbox_format (str, optional): The format of bounding boxes ('xywh' or 'xyxy'). Default is 'xywh'.
normalized (bool, optional): Whether the bounding box coordinates are normalized. Default is True.
Examples:
```python
# Create an Instances object
instances = Instances(
bboxes=np.array([[10, 10, 30, 30], [20, 20, 40, 40]]),
segments=[np.array([[5, 5], [10, 10]]), np.array([[15, 15], [20, 20]])],
keypoints=np.array([[[5, 5, 1], [10, 10, 1]], [[15, 15, 1], [20, 20, 1]]])
)
```
Note:
The bounding box format is either 'xywh' or 'xyxy', and is determined by the `bbox_format` argument.
This class does not perform input validation, and it assumes the inputs are well-formed.
"""
def __init__(self, bboxes, segments=None, keypoints=None, bbox_format="xywh", normalized=True) -> None:
"""
Args:
bboxes (ndarray): bboxes with shape [N, 4].
segments (list | ndarray): segments.
keypoints (ndarray): keypoints(x, y, visible) with shape [N, 17, 3].
"""
self._bboxes = Bboxes(bboxes=bboxes, format=bbox_format)
self.keypoints = keypoints
self.normalized = normalized
self.segments = segments
def convert_bbox(self, format):
"""Convert bounding box format."""
self._bboxes.convert(format=format)
@property
def bbox_areas(self):
"""Calculate the area of bounding boxes."""
return self._bboxes.areas()
def scale(self, scale_w, scale_h, bbox_only=False):
"""This might be similar with denormalize func but without normalized sign."""
self._bboxes.mul(scale=(scale_w, scale_h, scale_w, scale_h))
if bbox_only:
return
self.segments[..., 0] *= scale_w
self.segments[..., 1] *= scale_h
if self.keypoints is not None:
self.keypoints[..., 0] *= scale_w
self.keypoints[..., 1] *= scale_h
def denormalize(self, w, h):
"""Denormalizes boxes, segments, and keypoints from normalized coordinates."""
if not self.normalized:
return
self._bboxes.mul(scale=(w, h, w, h))
self.segments[..., 0] *= w
self.segments[..., 1] *= h
if self.keypoints is not None:
self.keypoints[..., 0] *= w
self.keypoints[..., 1] *= h
self.normalized = False
def normalize(self, w, h):
"""Normalize bounding boxes, segments, and keypoints to image dimensions."""
if self.normalized:
return
self._bboxes.mul(scale=(1 / w, 1 / h, 1 / w, 1 / h))
self.segments[..., 0] /= w
self.segments[..., 1] /= h
if self.keypoints is not None:
self.keypoints[..., 0] /= w
self.keypoints[..., 1] /= h
self.normalized = True
def add_padding(self, padw, padh):
"""Handle rect and mosaic situation."""
assert not self.normalized, "you should add padding with absolute coordinates."
self._bboxes.add(offset=(padw, padh, padw, padh))
self.segments[..., 0] += padw
self.segments[..., 1] += padh
if self.keypoints is not None:
self.keypoints[..., 0] += padw
self.keypoints[..., 1] += padh
def __getitem__(self, index) -> "Instances":
"""
Retrieve a specific instance or a set of instances using indexing.
Args:
index (int, slice, or np.ndarray): The index, slice, or boolean array to select
the desired instances.
Returns:
Instances: A new Instances object containing the selected bounding boxes,
segments, and keypoints if present.
Note:
When using boolean indexing, make sure to provide a boolean array with the same
length as the number of instances.
"""
segments = self.segments[index] if len(self.segments) else self.segments
keypoints = self.keypoints[index] if self.keypoints is not None else None
bboxes = self.bboxes[index]
bbox_format = self._bboxes.format
return Instances(
bboxes=bboxes,
segments=segments,
keypoints=keypoints,
bbox_format=bbox_format,
normalized=self.normalized,
)
def flipud(self, h):
"""Flips the coordinates of bounding boxes, segments, and keypoints vertically."""
if self._bboxes.format == "xyxy":
y1 = self.bboxes[:, 1].copy()
y2 = self.bboxes[:, 3].copy()
self.bboxes[:, 1] = h - y2
self.bboxes[:, 3] = h - y1
else:
self.bboxes[:, 1] = h - self.bboxes[:, 1]
self.segments[..., 1] = h - self.segments[..., 1]
if self.keypoints is not None:
self.keypoints[..., 1] = h - self.keypoints[..., 1]
def fliplr(self, w):
"""Reverses the order of the bounding boxes and segments horizontally."""
if self._bboxes.format == "xyxy":
x1 = self.bboxes[:, 0].copy()
x2 = self.bboxes[:, 2].copy()
self.bboxes[:, 0] = w - x2
self.bboxes[:, 2] = w - x1
else:
self.bboxes[:, 0] = w - self.bboxes[:, 0]
self.segments[..., 0] = w - self.segments[..., 0]
if self.keypoints is not None:
self.keypoints[..., 0] = w - self.keypoints[..., 0]
def clip(self, w, h):
"""Clips bounding boxes, segments, and keypoints values to stay within image boundaries."""
ori_format = self._bboxes.format
self.convert_bbox(format="xyxy")
self.bboxes[:, [0, 2]] = self.bboxes[:, [0, 2]].clip(0, w)
self.bboxes[:, [1, 3]] = self.bboxes[:, [1, 3]].clip(0, h)
if ori_format != "xyxy":
self.convert_bbox(format=ori_format)
self.segments[..., 0] = self.segments[..., 0].clip(0, w)
self.segments[..., 1] = self.segments[..., 1].clip(0, h)
if self.keypoints is not None:
self.keypoints[..., 0] = self.keypoints[..., 0].clip(0, w)
self.keypoints[..., 1] = self.keypoints[..., 1].clip(0, h)
def remove_zero_area_boxes(self):
"""
Remove zero-area boxes, i.e. after clipping some boxes may have zero width or height.
This removes them.
"""
good = self.bbox_areas > 0
if not all(good):
self._bboxes = self._bboxes[good]
if len(self.segments):
self.segments = self.segments[good]
if self.keypoints is not None:
self.keypoints = self.keypoints[good]
return good
def update(self, bboxes, segments=None, keypoints=None):
"""Updates instance variables."""
self._bboxes = Bboxes(bboxes, format=self._bboxes.format)
if segments is not None:
self.segments = segments
if keypoints is not None:
self.keypoints = keypoints
def __len__(self):
"""Return the length of the instance list."""
return len(self.bboxes)
@classmethod
def concatenate(cls, instances_list: List["Instances"], axis=0) -> "Instances":
"""
Concatenates a list of Instances objects into a single Instances object.
Args:
instances_list (List[Instances]): A list of Instances objects to concatenate.
axis (int, optional): The axis along which the arrays will be concatenated. Defaults to 0.
Returns:
Instances: A new Instances object containing the concatenated bounding boxes,
segments, and keypoints if present.
Note:
The `Instances` objects in the list should have the same properties, such as
the format of the bounding boxes, whether keypoints are present, and if the
coordinates are normalized.
"""
assert isinstance(instances_list, (list, tuple))
if not instances_list:
return cls(np.empty((0, 4)))  # empty (0, 4) array keeps the underlying Bboxes assertions valid
assert all(isinstance(instance, Instances) for instance in instances_list)
if len(instances_list) == 1:
return instances_list[0]
use_keypoint = instances_list[0].keypoints is not None
bbox_format = instances_list[0]._bboxes.format
normalized = instances_list[0].normalized
cat_boxes = np.concatenate([ins.bboxes for ins in instances_list], axis=axis)
cat_segments = np.concatenate([b.segments for b in instances_list], axis=axis)
cat_keypoints = np.concatenate([b.keypoints for b in instances_list], axis=axis) if use_keypoint else None
return cls(cat_boxes, cat_segments, cat_keypoints, bbox_format, normalized)
@property
def bboxes(self):
"""Return bounding boxes."""
return self._bboxes.bboxes
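A minimal end-to-end sketch with dummy data (assumes the same ultralytics.utils.instance import path):

```python
import numpy as np
from ultralytics.utils.instance import Instances

instances = Instances(
    bboxes=np.array([[0.5, 0.5, 0.2, 0.2], [0.25, 0.25, 0.1, 0.1]]),  # normalized xywh
    segments=np.zeros((2, 1000, 2), dtype=np.float32),
    normalized=True,
)
instances.denormalize(w=640, h=480)  # boxes and segments scaled to pixels in place
instances.convert_bbox("xyxy")
print(instances.bbox_areas)  # [12288., 3072.]
```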

727
ultralytics/utils/loss.py Normal file

@ -0,0 +1,727 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import torch
import torch.nn as nn
import torch.nn.functional as F
from ultralytics.utils.metrics import OKS_SIGMA
from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh
from ultralytics.utils.tal import RotatedTaskAlignedAssigner, TaskAlignedAssigner, dist2bbox, dist2rbox, make_anchors
from .metrics import bbox_iou, probiou
from .tal import bbox2dist
class VarifocalLoss(nn.Module):
"""
Varifocal loss by Zhang et al.
https://arxiv.org/abs/2008.13367.
"""
def __init__(self):
"""Initialize the VarifocalLoss class."""
super().__init__()
@staticmethod
def forward(pred_score, gt_score, label, alpha=0.75, gamma=2.0):
"""Computes varfocal loss."""
weight = alpha * pred_score.sigmoid().pow(gamma) * (1 - label) + gt_score * label
with torch.cuda.amp.autocast(enabled=False):
loss = (
(F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction="none") * weight)
.mean(1)
.sum()
)
return loss
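A runnable sketch with random tensors (shapes illustrative: 8 anchors x 80 classes):

```python
import torch
from ultralytics.utils.loss import VarifocalLoss

pred_score = torch.randn(8, 80)   # raw logits
gt_score = torch.rand(8, 80)      # IoU-aware soft targets in [0, 1]
label = (gt_score > 0.5).float()  # binary positive mask
print(VarifocalLoss()(pred_score, gt_score, label))
```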
class FocalLoss(nn.Module):
"""Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)."""
def __init__(self):
"""Initializer for FocalLoss class with no parameters."""
super().__init__()
@staticmethod
def forward(pred, label, gamma=1.5, alpha=0.25):
"""Calculates and updates confusion matrix for object detection/classification tasks."""
loss = F.binary_cross_entropy_with_logits(pred, label, reduction="none")
# p_t = torch.exp(-loss)
# loss *= self.alpha * (1.000001 - p_t) ** self.gamma # non-zero power for gradient stability
# TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
pred_prob = pred.sigmoid() # prob from logits
p_t = label * pred_prob + (1 - label) * (1 - pred_prob)
modulating_factor = (1.0 - p_t) ** gamma
loss *= modulating_factor
if alpha > 0:
alpha_factor = label * alpha + (1 - label) * (1 - alpha)
loss *= alpha_factor
return loss.mean(1).sum()
class BboxLoss(nn.Module):
"""Criterion class for computing training losses during training."""
def __init__(self, reg_max, use_dfl=False):
"""Initialize the BboxLoss module with regularization maximum and DFL settings."""
super().__init__()
self.reg_max = reg_max
self.use_dfl = use_dfl
def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
"""IoU loss."""
weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True)
loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum
# DFL loss
if self.use_dfl:
target_ltrb = bbox2dist(anchor_points, target_bboxes, self.reg_max)
loss_dfl = self._df_loss(pred_dist[fg_mask].view(-1, self.reg_max + 1), target_ltrb[fg_mask]) * weight
loss_dfl = loss_dfl.sum() / target_scores_sum
else:
loss_dfl = torch.tensor(0.0).to(pred_dist.device)
return loss_iou, loss_dfl
@staticmethod
def _df_loss(pred_dist, target):
"""
Return sum of left and right DFL losses.
Distribution Focal Loss (DFL) proposed in Generalized Focal Loss
https://ieeexplore.ieee.org/document/9792391
"""
tl = target.long() # target left
tr = tl + 1 # target right
wl = tr - target # weight left
wr = 1 - wl # weight right
return (
F.cross_entropy(pred_dist, tl.view(-1), reduction="none").view(tl.shape) * wl
+ F.cross_entropy(pred_dist, tr.view(-1), reduction="none").view(tl.shape) * wr
).mean(-1, keepdim=True)
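A worked example of the target decomposition above: a fractional distance target is split between its two nearest integer bins, weighted by proximity.

```python
import torch

target = torch.tensor([2.7])
tl = target.long()  # left bin:  tensor([2])
tr = tl + 1         # right bin: tensor([3])
wl = tr - target    # left weight:  tensor([0.3000])
wr = 1 - wl         # right weight: tensor([0.7000])
# cross-entropy toward bin 2 is weighted by 0.3 and toward bin 3 by 0.7
```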
class RotatedBboxLoss(BboxLoss):
"""Criterion class for computing training losses during training."""
def __init__(self, reg_max, use_dfl=False):
"""Initialize the BboxLoss module with regularization maximum and DFL settings."""
super().__init__(reg_max, use_dfl)
def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
"""IoU loss."""
weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
iou = probiou(pred_bboxes[fg_mask], target_bboxes[fg_mask])
loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum
# DFL loss
if self.use_dfl:
target_ltrb = bbox2dist(anchor_points, xywh2xyxy(target_bboxes[..., :4]), self.reg_max)
loss_dfl = self._df_loss(pred_dist[fg_mask].view(-1, self.reg_max + 1), target_ltrb[fg_mask]) * weight
loss_dfl = loss_dfl.sum() / target_scores_sum
else:
loss_dfl = torch.tensor(0.0).to(pred_dist.device)
return loss_iou, loss_dfl
class KeypointLoss(nn.Module):
"""Criterion class for computing training losses."""
def __init__(self, sigmas) -> None:
"""Initialize the KeypointLoss class."""
super().__init__()
self.sigmas = sigmas
def forward(self, pred_kpts, gt_kpts, kpt_mask, area):
"""Calculates keypoint loss factor and Euclidean distance loss for predicted and actual keypoints."""
d = (pred_kpts[..., 0] - gt_kpts[..., 0]).pow(2) + (pred_kpts[..., 1] - gt_kpts[..., 1]).pow(2)
kpt_loss_factor = kpt_mask.shape[1] / (torch.sum(kpt_mask != 0, dim=1) + 1e-9)
# e = d / (2 * (area * self.sigmas) ** 2 + 1e-9) # from formula
e = d / ((2 * self.sigmas).pow(2) * (area + 1e-9) * 2) # from cocoeval
return (kpt_loss_factor.view(-1, 1) * ((1 - torch.exp(-e)) * kpt_mask)).mean()
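A self-contained sketch with dummy tensors (17 COCO-style keypoints, uniform sigmas; values illustrative):

```python
import torch
from ultralytics.utils.loss import KeypointLoss

criterion = KeypointLoss(sigmas=torch.ones(17) / 17)
pred_kpts = torch.rand(4, 17, 2) * 64  # (n_instances, n_kpts, xy)
gt_kpts = torch.rand(4, 17, 2) * 64
kpt_mask = torch.ones(4, 17)           # all keypoints labeled visible
area = torch.rand(4, 1) * 64**2        # per-instance box areas
print(criterion(pred_kpts, gt_kpts, kpt_mask, area))
```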
class v8DetectionLoss:
"""Criterion class for computing training losses."""
def __init__(self, model, tal_topk=10): # model must be de-paralleled
"""Initializes v8DetectionLoss with the model, defining model-related properties and BCE loss function."""
device = next(model.parameters()).device # get model device
h = model.args # hyperparameters
m = model.model[-1] # Detect() module
self.bce = nn.BCEWithLogitsLoss(reduction="none")
self.hyp = h
self.stride = m.stride # model strides
self.nc = m.nc # number of classes
self.no = m.no
self.reg_max = m.reg_max
self.device = device
self.use_dfl = m.reg_max > 1
self.assigner = TaskAlignedAssigner(topk=tal_topk, num_classes=self.nc, alpha=0.5, beta=6.0)
self.bbox_loss = BboxLoss(m.reg_max - 1, use_dfl=self.use_dfl).to(device)
self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device)
def preprocess(self, targets, batch_size, scale_tensor):
"""Preprocesses the target counts and matches with the input batch size to output a tensor."""
if targets.shape[0] == 0:
out = torch.zeros(batch_size, 0, 5, device=self.device)
else:
i = targets[:, 0] # image index
_, counts = i.unique(return_counts=True)
counts = counts.to(dtype=torch.int32)
out = torch.zeros(batch_size, counts.max(), 5, device=self.device)
for j in range(batch_size):
matches = i == j
n = matches.sum()
if n:
out[j, :n] = targets[matches, 1:]
out[..., 1:5] = xywh2xyxy(out[..., 1:5].mul_(scale_tensor))
return out
def bbox_decode(self, anchor_points, pred_dist):
"""Decode predicted object bounding box coordinates from anchor points and distribution."""
if self.use_dfl:
b, a, c = pred_dist.shape # batch, anchors, channels
pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
# pred_dist = pred_dist.view(b, a, c // 4, 4).transpose(2,3).softmax(3).matmul(self.proj.type(pred_dist.dtype))
# pred_dist = (pred_dist.view(b, a, c // 4, 4).softmax(2) * self.proj.type(pred_dist.dtype).view(1, 1, -1, 1)).sum(2)
return dist2bbox(pred_dist, anchor_points, xywh=False)
def __call__(self, preds, batch):
"""Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
loss = torch.zeros(3, device=self.device) # box, cls, dfl
feats = preds[1] if isinstance(preds, tuple) else preds
pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
(self.reg_max * 4, self.nc), 1
)
pred_scores = pred_scores.permute(0, 2, 1).contiguous()
pred_distri = pred_distri.permute(0, 2, 1).contiguous()
dtype = pred_scores.dtype
batch_size = pred_scores.shape[0]
imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w)
anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
# Targets
targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
# Pboxes
pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)
_, target_bboxes, target_scores, fg_mask, _ = self.assigner(
pred_scores.detach().sigmoid(),
(pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
anchor_points * stride_tensor,
gt_labels,
gt_bboxes,
mask_gt,
)
target_scores_sum = max(target_scores.sum(), 1)
# Cls loss
# loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way
loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE
# Bbox loss
if fg_mask.sum():
target_bboxes /= stride_tensor
loss[0], loss[2] = self.bbox_loss(
pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
)
loss[0] *= self.hyp.box # box gain
loss[1] *= self.hyp.cls # cls gain
loss[2] *= self.hyp.dfl # dfl gain
return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl)
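The DFL decode in bbox_decode above reduces each side's bin distribution to its expected value; a standalone sketch of just that step (shapes illustrative):

```python
import torch

b, a, reg_max = 2, 8400, 16                 # batch, anchors, bins per box side
pred_dist = torch.randn(b, a, 4 * reg_max)  # raw distribution logits
proj = torch.arange(reg_max, dtype=torch.float)

# softmax over each side's bins, then expected distance = sum_i p_i * i
dist = pred_dist.view(b, a, 4, reg_max).softmax(3).matmul(proj)
print(dist.shape)  # torch.Size([2, 8400, 4])
```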
class v8SegmentationLoss(v8DetectionLoss):
"""Criterion class for computing training losses."""
def __init__(self, model): # model must be de-paralleled
"""Initializes the v8SegmentationLoss class, taking a de-paralleled model as argument."""
super().__init__(model)
self.overlap = model.args.overlap_mask
def __call__(self, preds, batch):
"""Calculate and return the loss for the YOLO model."""
loss = torch.zeros(4, device=self.device)  # box, seg, cls, dfl
feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
batch_size, _, mask_h, mask_w = proto.shape # batch size, number of masks, mask height, mask width
pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
(self.reg_max * 4, self.nc), 1
)
# B, grids, ..
pred_scores = pred_scores.permute(0, 2, 1).contiguous()
pred_distri = pred_distri.permute(0, 2, 1).contiguous()
pred_masks = pred_masks.permute(0, 2, 1).contiguous()
dtype = pred_scores.dtype
imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w)
anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
# Targets
try:
batch_idx = batch["batch_idx"].view(-1, 1)
targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
except RuntimeError as e:
raise TypeError(
"ERROR ❌ segment dataset incorrectly formatted or not a segment dataset.\n"
"This error can occur when incorrectly training a 'segment' model on a 'detect' dataset, "
"i.e. 'yolo train model=yolov8n-seg.pt data=coco8.yaml'.\nVerify your dataset is a "
"correctly formatted 'segment' dataset using 'data=coco8-seg.yaml' "
"as an example.\nSee https://docs.ultralytics.com/datasets/segment/ for help."
) from e
# Pboxes
pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)
_, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
pred_scores.detach().sigmoid(),
(pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
anchor_points * stride_tensor,
gt_labels,
gt_bboxes,
mask_gt,
)
target_scores_sum = max(target_scores.sum(), 1)
# Cls loss
# loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way
loss[2] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE
if fg_mask.sum():
# Bbox loss
loss[0], loss[3] = self.bbox_loss(
pred_distri,
pred_bboxes,
anchor_points,
target_bboxes / stride_tensor,
target_scores,
target_scores_sum,
fg_mask,
)
# Masks loss
masks = batch["masks"].to(self.device).float()
if tuple(masks.shape[-2:]) != (mask_h, mask_w): # downsample
masks = F.interpolate(masks[None], (mask_h, mask_w), mode="nearest")[0]
loss[1] = self.calculate_segmentation_loss(
fg_mask, masks, target_gt_idx, target_bboxes, batch_idx, proto, pred_masks, imgsz, self.overlap
)
# WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
else:
loss[1] += (proto * 0).sum() + (pred_masks * 0).sum() # inf sums may lead to nan loss
loss[0] *= self.hyp.box # box gain
loss[1] *= self.hyp.box # seg gain
loss[2] *= self.hyp.cls # cls gain
loss[3] *= self.hyp.dfl # dfl gain
return loss.sum() * batch_size, loss.detach()  # loss(box, seg, cls, dfl)
@staticmethod
def single_mask_loss(
gt_mask: torch.Tensor, pred: torch.Tensor, proto: torch.Tensor, xyxy: torch.Tensor, area: torch.Tensor
) -> torch.Tensor:
"""
Compute the instance segmentation loss for a single image.
Args:
gt_mask (torch.Tensor): Ground truth mask of shape (n, H, W), where n is the number of objects.
pred (torch.Tensor): Predicted mask coefficients of shape (n, 32).
proto (torch.Tensor): Prototype masks of shape (32, H, W).
xyxy (torch.Tensor): Ground truth bounding boxes in xyxy format, in mask-space pixel coordinates, of shape (n, 4).
area (torch.Tensor): Area of each ground truth bounding box of shape (n,).
Returns:
(torch.Tensor): The calculated mask loss for a single image.
Notes:
The function uses the equation pred_mask = torch.einsum('in,nhw->ihw', pred, proto) to produce the
predicted masks from the prototype masks and predicted mask coefficients.
"""
pred_mask = torch.einsum("in,nhw->ihw", pred, proto) # (n, 32) @ (32, 80, 80) -> (n, 80, 80)
loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction="none")
return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).sum()
def calculate_segmentation_loss(
self,
fg_mask: torch.Tensor,
masks: torch.Tensor,
target_gt_idx: torch.Tensor,
target_bboxes: torch.Tensor,
batch_idx: torch.Tensor,
proto: torch.Tensor,
pred_masks: torch.Tensor,
imgsz: torch.Tensor,
overlap: bool,
) -> torch.Tensor:
"""
Calculate the loss for instance segmentation.
Args:
fg_mask (torch.Tensor): A binary tensor of shape (BS, N_anchors) indicating which anchors are positive.
masks (torch.Tensor): Ground truth masks of shape (BS, H, W) if `overlap` is False, otherwise (BS, ?, H, W).
target_gt_idx (torch.Tensor): Indexes of ground truth objects for each anchor of shape (BS, N_anchors).
target_bboxes (torch.Tensor): Ground truth bounding boxes for each anchor of shape (BS, N_anchors, 4).
batch_idx (torch.Tensor): Batch indices of shape (N_labels_in_batch, 1).
proto (torch.Tensor): Prototype masks of shape (BS, 32, H, W).
pred_masks (torch.Tensor): Predicted masks for each anchor of shape (BS, N_anchors, 32).
imgsz (torch.Tensor): Size of the input image as a tensor of shape (2), i.e., (H, W).
overlap (bool): Whether the masks in `masks` tensor overlap.
Returns:
(torch.Tensor): The calculated loss for instance segmentation.
Notes:
The batch loss can be computed for improved speed at higher memory usage.
For example, pred_mask can be computed as follows:
pred_mask = torch.einsum('in,nhw->ihw', pred, proto) # (i, 32) @ (32, 160, 160) -> (i, 160, 160)
"""
_, _, mask_h, mask_w = proto.shape
loss = 0
# Normalize to 0-1
target_bboxes_normalized = target_bboxes / imgsz[[1, 0, 1, 0]]
# Areas of target bboxes
marea = xyxy2xywh(target_bboxes_normalized)[..., 2:].prod(2)
# Normalize to mask size
mxyxy = target_bboxes_normalized * torch.tensor([mask_w, mask_h, mask_w, mask_h], device=proto.device)
for i, single_i in enumerate(zip(fg_mask, target_gt_idx, pred_masks, proto, mxyxy, marea, masks)):
fg_mask_i, target_gt_idx_i, pred_masks_i, proto_i, mxyxy_i, marea_i, masks_i = single_i
if fg_mask_i.any():
mask_idx = target_gt_idx_i[fg_mask_i]
if overlap:
gt_mask = masks_i == (mask_idx + 1).view(-1, 1, 1)
gt_mask = gt_mask.float()
else:
gt_mask = masks[batch_idx.view(-1) == i][mask_idx]
loss += self.single_mask_loss(
gt_mask, pred_masks_i[fg_mask_i], proto_i, mxyxy_i[fg_mask_i], marea_i[fg_mask_i]
)
# WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
else:
loss += (proto * 0).sum() + (pred_masks * 0).sum() # inf sums may lead to nan loss
return loss / fg_mask.sum()
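Because single_mask_loss is a static method, it can be exercised in isolation with dummy tensors (a sketch; shapes follow the docstring above, with boxes in mask-space pixels and normalized areas as in calculate_segmentation_loss):

```python
import torch
from ultralytics.utils.loss import v8SegmentationLoss

n, h, w = 3, 160, 160
gt_mask = (torch.rand(n, h, w) > 0.5).float()                 # binary GT masks
pred = torch.randn(n, 32)                                     # mask coefficients
proto = torch.randn(32, h, w)                                 # prototype masks
xyxy = torch.tensor([[10.0, 10.0, 80.0, 80.0]]).repeat(n, 1)  # mask-space boxes
area = torch.full((n,), (70.0 / 160) ** 2)                    # normalized box areas
print(v8SegmentationLoss.single_mask_loss(gt_mask, pred, proto, xyxy, area))
```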
class v8PoseLoss(v8DetectionLoss):
"""Criterion class for computing training losses."""
def __init__(self, model): # model must be de-paralleled
"""Initializes v8PoseLoss with model, sets keypoint variables and declares a keypoint loss instance."""
super().__init__(model)
self.kpt_shape = model.model[-1].kpt_shape
self.bce_pose = nn.BCEWithLogitsLoss()
is_pose = self.kpt_shape == [17, 3]
nkpt = self.kpt_shape[0] # number of keypoints
sigmas = torch.from_numpy(OKS_SIGMA).to(self.device) if is_pose else torch.ones(nkpt, device=self.device) / nkpt
self.keypoint_loss = KeypointLoss(sigmas=sigmas)
def __call__(self, preds, batch):
"""Calculate the total loss and detach it."""
loss = torch.zeros(5, device=self.device) # box, cls, dfl, kpt_location, kpt_visibility
feats, pred_kpts = preds if isinstance(preds[0], list) else preds[1]
pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
(self.reg_max * 4, self.nc), 1
)
# B, grids, ..
pred_scores = pred_scores.permute(0, 2, 1).contiguous()
pred_distri = pred_distri.permute(0, 2, 1).contiguous()
pred_kpts = pred_kpts.permute(0, 2, 1).contiguous()
dtype = pred_scores.dtype
imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w)
anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
# Targets
batch_size = pred_scores.shape[0]
batch_idx = batch["batch_idx"].view(-1, 1)
targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
# Pboxes
pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)
pred_kpts = self.kpts_decode(anchor_points, pred_kpts.view(batch_size, -1, *self.kpt_shape)) # (b, h*w, 17, 3)
_, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
pred_scores.detach().sigmoid(),
(pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
anchor_points * stride_tensor,
gt_labels,
gt_bboxes,
mask_gt,
)
target_scores_sum = max(target_scores.sum(), 1)
# Cls loss
# loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way
loss[3] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE
# Bbox loss
if fg_mask.sum():
target_bboxes /= stride_tensor
loss[0], loss[4] = self.bbox_loss(
pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
)
keypoints = batch["keypoints"].to(self.device).float().clone()
keypoints[..., 0] *= imgsz[1]
keypoints[..., 1] *= imgsz[0]
loss[1], loss[2] = self.calculate_keypoints_loss(
fg_mask, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes, pred_kpts
)
loss[0] *= self.hyp.box # box gain
loss[1] *= self.hyp.pose # pose gain
loss[2] *= self.hyp.kobj # kobj gain
loss[3] *= self.hyp.cls # cls gain
loss[4] *= self.hyp.dfl # dfl gain
return loss.sum() * batch_size, loss.detach()  # loss(box, pose, kobj, cls, dfl)
@staticmethod
def kpts_decode(anchor_points, pred_kpts):
"""Decodes predicted keypoints to image coordinates."""
y = pred_kpts.clone()
y[..., :2] *= 2.0
y[..., 0] += anchor_points[:, [0]] - 0.5
y[..., 1] += anchor_points[:, [1]] - 0.5
return y
def calculate_keypoints_loss(
self, masks, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes, pred_kpts
):
"""
Calculate the keypoints loss for the model.
This function calculates the keypoints loss and keypoints object loss for a given batch. The keypoints loss is
based on the difference between the predicted keypoints and ground truth keypoints. The keypoints object loss is
a binary classification loss that classifies whether a keypoint is present or not.
Args:
masks (torch.Tensor): Binary mask tensor indicating object presence, shape (BS, N_anchors).
target_gt_idx (torch.Tensor): Index tensor mapping anchors to ground truth objects, shape (BS, N_anchors).
keypoints (torch.Tensor): Ground truth keypoints, shape (N_kpts_in_batch, N_kpts_per_object, kpts_dim).
batch_idx (torch.Tensor): Batch index tensor for keypoints, shape (N_kpts_in_batch, 1).
stride_tensor (torch.Tensor): Stride tensor for anchors, shape (N_anchors, 1).
target_bboxes (torch.Tensor): Ground truth boxes in (x1, y1, x2, y2) format, shape (BS, N_anchors, 4).
pred_kpts (torch.Tensor): Predicted keypoints, shape (BS, N_anchors, N_kpts_per_object, kpts_dim).
Returns:
(tuple): Returns a tuple containing:
- kpts_loss (torch.Tensor): The keypoints loss.
- kpts_obj_loss (torch.Tensor): The keypoints object loss.
"""
batch_idx = batch_idx.flatten()
batch_size = len(masks)
# Find the maximum number of labeled instances in any single image
max_kpts = torch.unique(batch_idx, return_counts=True)[1].max()
# Create a tensor to hold batched keypoints
batched_keypoints = torch.zeros(
(batch_size, max_kpts, keypoints.shape[1], keypoints.shape[2]), device=keypoints.device
)
# TODO: any idea how to vectorize this?
# Fill batched_keypoints with keypoints based on batch_idx
for i in range(batch_size):
keypoints_i = keypoints[batch_idx == i]
batched_keypoints[i, : keypoints_i.shape[0]] = keypoints_i
# Expand dimensions of target_gt_idx to match the shape of batched_keypoints
target_gt_idx_expanded = target_gt_idx.unsqueeze(-1).unsqueeze(-1)
# Use target_gt_idx_expanded to select keypoints from batched_keypoints
selected_keypoints = batched_keypoints.gather(
1, target_gt_idx_expanded.expand(-1, -1, keypoints.shape[1], keypoints.shape[2])
)
# Divide coordinates by stride
selected_keypoints /= stride_tensor.view(1, -1, 1, 1)
kpts_loss = 0
kpts_obj_loss = 0
if masks.any():
gt_kpt = selected_keypoints[masks]
area = xyxy2xywh(target_bboxes[masks])[:, 2:].prod(1, keepdim=True)
pred_kpt = pred_kpts[masks]
kpt_mask = gt_kpt[..., 2] != 0 if gt_kpt.shape[-1] == 3 else torch.full_like(gt_kpt[..., 0], True)
kpts_loss = self.keypoint_loss(pred_kpt, gt_kpt, kpt_mask, area) # pose loss
if pred_kpt.shape[-1] == 3:
kpts_obj_loss = self.bce_pose(pred_kpt[..., 2], kpt_mask.float()) # keypoint obj loss
return kpts_loss, kpts_obj_loss
class v8ClassificationLoss:
"""Criterion class for computing training losses."""
def __call__(self, preds, batch):
"""Compute the classification loss between predictions and true labels."""
loss = torch.nn.functional.cross_entropy(preds, batch["cls"], reduction="mean")
loss_items = loss.detach()
return loss, loss_items
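A minimal sketch of the classification criterion with random logits (shapes illustrative):

```python
import torch
from ultralytics.utils.loss import v8ClassificationLoss

criterion = v8ClassificationLoss()
preds = torch.randn(16, 1000)                   # class logits
batch = {"cls": torch.randint(0, 1000, (16,))}  # integer class targets
loss, loss_items = criterion(preds, batch)
```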
class v8OBBLoss(v8DetectionLoss):
"""Criterion class for computing training losses for oriented bounding box (OBB) models."""
def __init__(self, model):
"""
Initializes v8OBBLoss with model, assigner, and rotated bbox loss.
Note model must be de-paralleled.
"""
super().__init__(model)
self.assigner = RotatedTaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
self.bbox_loss = RotatedBboxLoss(self.reg_max - 1, use_dfl=self.use_dfl).to(self.device)
def preprocess(self, targets, batch_size, scale_tensor):
"""Preprocesses the target counts and matches with the input batch size to output a tensor."""
if targets.shape[0] == 0:
out = torch.zeros(batch_size, 0, 6, device=self.device)
else:
i = targets[:, 0] # image index
_, counts = i.unique(return_counts=True)
counts = counts.to(dtype=torch.int32)
out = torch.zeros(batch_size, counts.max(), 6, device=self.device)
for j in range(batch_size):
matches = i == j
n = matches.sum()
if n:
bboxes = targets[matches, 2:]
bboxes[..., :4].mul_(scale_tensor)
out[j, :n] = torch.cat([targets[matches, 1:2], bboxes], dim=-1)
return out
def __call__(self, preds, batch):
"""Calculate and return the loss for the YOLO model."""
loss = torch.zeros(3, device=self.device) # box, cls, dfl
feats, pred_angle = preds if isinstance(preds[0], list) else preds[1]
batch_size = pred_angle.shape[0]  # batch size
pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
(self.reg_max * 4, self.nc), 1
)
# b, grids, ..
pred_scores = pred_scores.permute(0, 2, 1).contiguous()
pred_distri = pred_distri.permute(0, 2, 1).contiguous()
pred_angle = pred_angle.permute(0, 2, 1).contiguous()
dtype = pred_scores.dtype
imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w)
anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
# targets
try:
batch_idx = batch["batch_idx"].view(-1, 1)
targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"].view(-1, 5)), 1)
rw, rh = targets[:, 4] * imgsz[0].item(), targets[:, 5] * imgsz[1].item()
targets = targets[(rw >= 2) & (rh >= 2)] # filter rboxes of tiny size to stabilize training
targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
gt_labels, gt_bboxes = targets.split((1, 5), 2) # cls, xywhr
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
except RuntimeError as e:
raise TypeError(
"ERROR ❌ OBB dataset incorrectly formatted or not a OBB dataset.\n"
"This error can occur when incorrectly training a 'OBB' model on a 'detect' dataset, "
"i.e. 'yolo train model=yolov8n-obb.pt data=dota8.yaml'.\nVerify your dataset is a "
"correctly formatted 'OBB' dataset using 'data=dota8.yaml' "
"as an example.\nSee https://docs.ultralytics.com/datasets/obb/ for help."
) from e
# Pboxes
pred_bboxes = self.bbox_decode(anchor_points, pred_distri, pred_angle) # xyxy, (b, h*w, 4)
bboxes_for_assigner = pred_bboxes.clone().detach()
# Only the first four elements need to be scaled
bboxes_for_assigner[..., :4] *= stride_tensor
_, target_bboxes, target_scores, fg_mask, _ = self.assigner(
pred_scores.detach().sigmoid(),
bboxes_for_assigner.type(gt_bboxes.dtype),
anchor_points * stride_tensor,
gt_labels,
gt_bboxes,
mask_gt,
)
target_scores_sum = max(target_scores.sum(), 1)
# Cls loss
# loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way
loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE
# Bbox loss
if fg_mask.sum():
target_bboxes[..., :4] /= stride_tensor
loss[0], loss[2] = self.bbox_loss(
pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
)
else:
loss[0] += (pred_angle * 0).sum()
loss[0] *= self.hyp.box # box gain
loss[1] *= self.hyp.cls # cls gain
loss[2] *= self.hyp.dfl # dfl gain
return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl)
def bbox_decode(self, anchor_points, pred_dist, pred_angle):
"""
Decode predicted object bounding box coordinates from anchor points and distribution.
Args:
anchor_points (torch.Tensor): Anchor points, (h*w, 2).
pred_dist (torch.Tensor): Predicted rotated distance, (bs, h*w, 4).
pred_angle (torch.Tensor): Predicted angle, (bs, h*w, 1).
Returns:
(torch.Tensor): Predicted rotated bounding boxes with angles, (bs, h*w, 5).
"""
if self.use_dfl:
b, a, c = pred_dist.shape # batch, anchors, channels
pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
return torch.cat((dist2rbox(pred_dist, pred_angle, anchor_points), pred_angle), dim=-1)
class v10DetectLoss:
"""Criterion class for YOLOv10, combining one-to-many and one-to-one detection losses."""
def __init__(self, model):
"""Initialize v10DetectLoss with a one-to-many (topk=10) and a one-to-one (topk=1) v8DetectionLoss."""
self.one2many = v8DetectionLoss(model, tal_topk=10)
self.one2one = v8DetectionLoss(model, tal_topk=1)
def __call__(self, preds, batch):
"""Return the summed one-to-many and one-to-one losses and their concatenated detached components."""
one2many = preds["one2many"]
loss_one2many = self.one2many(one2many, batch)
one2one = preds["one2one"]
loss_one2one = self.one2one(one2one, batch)
return loss_one2many[0] + loss_one2one[0], torch.cat((loss_one2many[1], loss_one2one[1]))

1292
ultralytics/utils/metrics.py Normal file

File diff suppressed because it is too large.

865
ultralytics/utils/ops.py Normal file

@ -0,0 +1,865 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import contextlib
import math
import re
import time
import cv2
import numpy as np
import torch
import torch.nn.functional as F
import torchvision
from ultralytics.utils import LOGGER
from ultralytics.utils.metrics import batch_probiou
class Profile(contextlib.ContextDecorator):
"""
YOLOv8 Profile class. Use as a decorator with @Profile() or as a context manager with 'with Profile():'.
Example:
```python
from ultralytics.utils.ops import Profile
with Profile(device=device) as dt:
pass # slow operation here
print(dt) # prints "Elapsed time is 9.5367431640625e-07 s"
```
"""
def __init__(self, t=0.0, device: torch.device = None):
"""
Initialize the Profile class.
Args:
t (float): Initial time. Defaults to 0.0.
device (torch.device): Device used for model inference. Defaults to None (cpu).
"""
self.t = t
self.device = device
self.cuda = bool(device and str(device).startswith("cuda"))
def __enter__(self):
"""Start timing."""
self.start = self.time()
return self
def __exit__(self, type, value, traceback): # noqa
"""Stop timing."""
self.dt = self.time() - self.start # delta-time
self.t += self.dt # accumulate dt
def __str__(self):
"""Returns a human-readable string representing the accumulated elapsed time in the profiler."""
return f"Elapsed time is {self.t} s"
def time(self):
"""Get current time."""
if self.cuda:
torch.cuda.synchronize(self.device)
return time.time()
def segment2box(segment, width=640, height=640):
"""
Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy).
Args:
segment (torch.Tensor): the segment label
width (int): the width of the image. Defaults to 640
height (int): The height of the image. Defaults to 640
Returns:
(np.ndarray): the minimum and maximum x and y values of the segment.
"""
x, y = segment.T # segment xy
inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
x = x[inside]
y = y[inside]
return (
np.array([x.min(), y.min(), x.max(), y.max()], dtype=segment.dtype)
if any(x)
else np.zeros(4, dtype=segment.dtype)
) # xyxy
def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding=True, xywh=False):
"""
Rescales bounding boxes (in the format of xyxy by default) from the shape of the image they were originally
specified in (img1_shape) to the shape of a different image (img0_shape).
Args:
img1_shape (tuple): The shape of the image that the bounding boxes are for, in the format of (height, width).
boxes (torch.Tensor): the bounding boxes of the objects in the image, in the format of (x1, y1, x2, y2)
img0_shape (tuple): the shape of the target image, in the format of (height, width).
ratio_pad (tuple): a tuple of (ratio, pad) for scaling the boxes. If not provided, the ratio and pad will be
calculated based on the size difference between the two images.
padding (bool): If True, assume the boxes come from an image letterboxed in YOLO style, and remove the
padding before rescaling. If False, do regular rescaling.
xywh (bool): Whether the box format is xywh. Defaults to False.
Returns:
boxes (torch.Tensor): The scaled bounding boxes, in the format of (x1, y1, x2, y2)
"""
if ratio_pad is None: # calculate from img0_shape
gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
pad = (
round((img1_shape[1] - img0_shape[1] * gain) / 2 - 0.1),
round((img1_shape[0] - img0_shape[0] * gain) / 2 - 0.1),
) # wh padding
else:
gain = ratio_pad[0][0]
pad = ratio_pad[1]
if padding:
boxes[..., 0] -= pad[0] # x padding
boxes[..., 1] -= pad[1] # y padding
if not xywh:
boxes[..., 2] -= pad[0] # x padding
boxes[..., 3] -= pad[1] # y padding
boxes[..., :4] /= gain
return clip_boxes(boxes, img0_shape)
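A worked example: a 640x640 letterboxed prediction mapped back to a 480x640 (h, w) original, so the 80 px vertical padding is subtracted and the gain is 1.0 (values illustrative):

```python
import torch
from ultralytics.utils.ops import scale_boxes

boxes = torch.tensor([[120.0, 200.0, 520.0, 440.0]])  # xyxy on the 640x640 input
restored = scale_boxes((640, 640), boxes.clone(), (480, 640))
print(restored)  # tensor([[120., 120., 520., 360.]])
```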
def make_divisible(x, divisor):
"""
Returns the nearest number that is divisible by the given divisor.
Args:
x (int): The number to make divisible.
divisor (int | torch.Tensor): The divisor.
Returns:
(int): The nearest number divisible by the divisor.
"""
if isinstance(divisor, torch.Tensor):
divisor = int(divisor.max()) # to int
return math.ceil(x / divisor) * divisor
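For example:

```python
make_divisible(97, 32)  # -> 128, rounded up to the next multiple of 32
make_divisible(64, 32)  # -> 64, already divisible
```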
def nms_rotated(boxes, scores, threshold=0.45):
"""
Non-maximum suppression (NMS) for oriented bounding boxes, powered by probiou and fast-NMS.
Args:
boxes (torch.Tensor): (N, 5), xywhr.
scores (torch.Tensor): (N, ).
threshold (float): IoU threshold.
Returns:
(torch.Tensor): Indices of the boxes kept after NMS, sorted by descending score.
"""
if len(boxes) == 0:
return np.empty((0,), dtype=np.int8)
sorted_idx = torch.argsort(scores, descending=True)
boxes = boxes[sorted_idx]
ious = batch_probiou(boxes, boxes).triu_(diagonal=1)
pick = torch.nonzero(ious.max(dim=0)[0] < threshold).squeeze_(-1)
return sorted_idx[pick]
def non_max_suppression(
prediction,
conf_thres=0.25,
iou_thres=0.45,
classes=None,
agnostic=False,
multi_label=False,
labels=(),
max_det=300,
nc=0, # number of classes (optional)
max_time_img=0.05,
max_nms=30000,
max_wh=7680,
in_place=True,
rotated=False,
):
"""
Perform non-maximum suppression (NMS) on a set of boxes, with support for masks and multiple labels per box.
Args:
prediction (torch.Tensor): A tensor of shape (batch_size, num_classes + 4 + num_masks, num_boxes)
containing the predicted boxes, classes, and masks. The tensor should be in the format
output by a model, such as YOLO.
conf_thres (float): The confidence threshold below which boxes will be filtered out.
Valid values are between 0.0 and 1.0.
iou_thres (float): The IoU threshold below which boxes will be filtered out during NMS.
Valid values are between 0.0 and 1.0.
classes (List[int]): A list of class indices to consider. If None, all classes will be considered.
agnostic (bool): If True, the model is agnostic to the number of classes, and all
classes will be considered as one.
multi_label (bool): If True, each box may have multiple labels.
labels (List[List[Union[int, float, torch.Tensor]]]): A list of lists, where each inner
list contains the apriori labels for a given image. The list should be in the format
output by a dataloader, with each label being a tuple of (class_index, x1, y1, x2, y2).
max_det (int): The maximum number of boxes to keep after NMS.
nc (int, optional): The number of classes output by the model. Any indices after this will be considered masks.
max_time_img (float): The maximum time (seconds) for processing one image.
max_nms (int): The maximum number of boxes into torchvision.ops.nms().
max_wh (int): The maximum box width and height in pixels.
in_place (bool): If True, the input prediction tensor will be modified in place.
rotated (bool): If True, boxes are oriented (x, y, w, h, angle) boxes and rotated NMS is used.
Returns:
(List[torch.Tensor]): A list of length batch_size, where each element is a tensor of
shape (num_boxes, 6 + num_masks) containing the kept boxes, with columns
(x1, y1, x2, y2, confidence, class, mask1, mask2, ...).
"""
# Checks
assert 0 <= conf_thres <= 1, f"Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0"
assert 0 <= iou_thres <= 1, f"Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0"
if isinstance(prediction, (list, tuple)):  # YOLOv8 model in validation mode, output = (inference_out, loss_out)
prediction = prediction[0] # select only inference output
bs = prediction.shape[0] # batch size
nc = nc or (prediction.shape[1] - 4) # number of classes
nm = prediction.shape[1] - nc - 4
mi = 4 + nc # mask start index
xc = prediction[:, 4:mi].amax(1) > conf_thres # candidates
# Settings
# min_wh = 2 # (pixels) minimum box width and height
time_limit = 2.0 + max_time_img * bs # seconds to quit after
multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
prediction = prediction.transpose(-1, -2) # shape(1,84,6300) to shape(1,6300,84)
if not rotated:
if in_place:
prediction[..., :4] = xywh2xyxy(prediction[..., :4]) # xywh to xyxy
else:
prediction = torch.cat((xywh2xyxy(prediction[..., :4]), prediction[..., 4:]), dim=-1) # xywh to xyxy
t = time.time()
output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
for xi, x in enumerate(prediction): # image index, image inference
# Apply constraints
# x[((x[:, 2:4] < min_wh) | (x[:, 2:4] > max_wh)).any(1), 4] = 0 # width-height
x = x[xc[xi]] # confidence
# Cat apriori labels if autolabelling
if labels and len(labels[xi]) and not rotated:
lb = labels[xi]
v = torch.zeros((len(lb), nc + nm + 4), device=x.device)
v[:, :4] = xywh2xyxy(lb[:, 1:5]) # box
v[range(len(lb)), lb[:, 0].long() + 4] = 1.0 # cls
x = torch.cat((x, v), 0)
# If none remain process next image
if not x.shape[0]:
continue
# Detections matrix nx6 (xyxy, conf, cls)
box, cls, mask = x.split((4, nc, nm), 1)
if multi_label:
i, j = torch.where(cls > conf_thres)
x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1)
else: # best class only
conf, j = cls.max(1, keepdim=True)
x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]
# Filter by class
if classes is not None:
x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
# Check shape
n = x.shape[0] # number of boxes
if not n: # no boxes
continue
if n > max_nms: # excess boxes
x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence and remove excess boxes
# Batched NMS
c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
scores = x[:, 4] # scores
if rotated:
boxes = torch.cat((x[:, :2] + c, x[:, 2:4], x[:, -1:]), dim=-1) # xywhr
i = nms_rotated(boxes, scores, iou_thres)
else:
boxes = x[:, :4] + c # boxes (offset by class)
i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
i = i[:max_det] # limit detections
# # Experimental
# merge = False # use merge-NMS
# if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
# # Update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
# from .metrics import box_iou
# iou = box_iou(boxes[i], boxes) > iou_thres # IoU matrix
# weights = iou * scores[None] # box weights
# x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
# redundant = True # require redundant detections
# if redundant:
# i = i[iou.sum(1) > 1] # require redundancy
output[xi] = x[i]
if (time.time() - t) > time_limit:
LOGGER.warning(f"WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded")
break # time limit exceeded
return output
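A small end-to-end sketch: two heavily overlapping single-class candidates, of which NMS keeps only the higher-scoring one. The layout is (batch, 4 + nc, num_boxes) with nc = 1; all values are illustrative:

```python
import torch
from ultralytics.utils.ops import non_max_suppression

prediction = torch.tensor(
    [[[100.0, 102.0],  # cx
      [100.0, 101.0],  # cy
      [50.0, 52.0],    # w
      [50.0, 51.0],    # h
      [0.90, 0.80]]]   # class-0 confidence
)
out = non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45)
print(out[0])  # a single (x1, y1, x2, y2, conf, cls) row survives
```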
def clip_boxes(boxes, shape):
"""
Takes a list of bounding boxes and a shape (height, width) and clips the bounding boxes to the shape.
Args:
boxes (torch.Tensor): the bounding boxes to clip
shape (tuple): the shape of the image
Returns:
(torch.Tensor | numpy.ndarray): Clipped boxes
"""
if isinstance(boxes, torch.Tensor): # faster individually (WARNING: inplace .clamp_() Apple MPS bug)
boxes[..., 0] = boxes[..., 0].clamp(0, shape[1]) # x1
boxes[..., 1] = boxes[..., 1].clamp(0, shape[0]) # y1
boxes[..., 2] = boxes[..., 2].clamp(0, shape[1]) # x2
boxes[..., 3] = boxes[..., 3].clamp(0, shape[0]) # y2
else: # np.array (faster grouped)
boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, shape[1]) # x1, x2
boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, shape[0]) # y1, y2
return boxes
def clip_coords(coords, shape):
"""
Clip line coordinates to the image boundaries.
Args:
coords (torch.Tensor | numpy.ndarray): A list of line coordinates.
shape (tuple): A tuple of integers representing the size of the image in the format (height, width).
Returns:
(torch.Tensor | numpy.ndarray): Clipped coordinates
"""
if isinstance(coords, torch.Tensor): # faster individually (WARNING: inplace .clamp_() Apple MPS bug)
coords[..., 0] = coords[..., 0].clamp(0, shape[1]) # x
coords[..., 1] = coords[..., 1].clamp(0, shape[0]) # y
else: # np.array (faster grouped)
coords[..., 0] = coords[..., 0].clip(0, shape[1]) # x
coords[..., 1] = coords[..., 1].clip(0, shape[0]) # y
return coords
def scale_image(masks, im0_shape, ratio_pad=None):
"""
Takes a mask and resizes it to the original image size.
Args:
masks (np.ndarray): resized and padded masks/images, [h, w, num]/[h, w, 3].
im0_shape (tuple): the original image shape
ratio_pad (tuple): the ratio of the padding to the original image.
Returns:
masks (np.ndarray): The masks resized to the original image shape.
"""
# Rescale coordinates (xyxy) from im1_shape to im0_shape
im1_shape = masks.shape
if im1_shape[:2] == im0_shape[:2]:
return masks
if ratio_pad is None: # calculate from im0_shape
gain = min(im1_shape[0] / im0_shape[0], im1_shape[1] / im0_shape[1]) # gain = old / new
pad = (im1_shape[1] - im0_shape[1] * gain) / 2, (im1_shape[0] - im0_shape[0] * gain) / 2 # wh padding
else:
# gain = ratio_pad[0][0]
pad = ratio_pad[1]
top, left = int(pad[1]), int(pad[0]) # y, x
bottom, right = int(im1_shape[0] - pad[1]), int(im1_shape[1] - pad[0])
if len(masks.shape) < 2:
raise ValueError(f"masks should have 2 or 3 dimensions, but got {len(masks.shape)}")
masks = masks[top:bottom, left:right]
masks = cv2.resize(masks, (im0_shape[1], im0_shape[0]))
if len(masks.shape) == 2:
masks = masks[:, :, None]
return masks
def xyxy2xywh(x):
"""
Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height) format where (x1, y1) is the
top-left corner and (x2, y2) is the bottom-right corner.
Args:
x (np.ndarray | torch.Tensor): The input bounding box coordinates in (x1, y1, x2, y2) format.
Returns:
y (np.ndarray | torch.Tensor): The bounding box coordinates in (x, y, width, height) format.
"""
assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x) # faster than clone/copy
y[..., 0] = (x[..., 0] + x[..., 2]) / 2 # x center
y[..., 1] = (x[..., 1] + x[..., 3]) / 2 # y center
y[..., 2] = x[..., 2] - x[..., 0] # width
y[..., 3] = x[..., 3] - x[..., 1] # height
return y
def xywh2xyxy(x):
"""
Convert bounding box coordinates from (x, y, width, height) format to (x1, y1, x2, y2) format where (x1, y1) is the
top-left corner and (x2, y2) is the bottom-right corner.
Args:
x (np.ndarray | torch.Tensor): The input bounding box coordinates in (x, y, width, height) format.
Returns:
y (np.ndarray | torch.Tensor): The bounding box coordinates in (x1, y1, x2, y2) format.
"""
assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x) # faster than clone/copy
dw = x[..., 2] / 2 # half-width
dh = x[..., 3] / 2 # half-height
y[..., 0] = x[..., 0] - dw # top left x
y[..., 1] = x[..., 1] - dh # top left y
y[..., 2] = x[..., 0] + dw # bottom right x
y[..., 3] = x[..., 1] + dh # bottom right y
return y
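# Example (illustrative sketch, not part of the original module): xyxy2xywh and
# xywh2xyxy are inverses, shown here on a single made-up box.
def _demo_xyxy_xywh_roundtrip():
    """Convert one box to xywh and back; the round trip recovers the input."""
    box = torch.tensor([[100.0, 50.0, 200.0, 150.0]])  # x1, y1, x2, y2
    xywh = xyxy2xywh(box)  # -> [[150., 100., 100., 100.]] (cx, cy, w, h)
    return xywh2xyxy(xywh)  # -> [[100., 50., 200., 150.]]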
def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
"""
Convert normalized bounding box coordinates to pixel coordinates.
Args:
x (np.ndarray | torch.Tensor): The bounding box coordinates.
w (int): Width of the image. Defaults to 640
h (int): Height of the image. Defaults to 640
padw (int): Padding width. Defaults to 0
padh (int): Padding height. Defaults to 0
Returns:
y (np.ndarray | torch.Tensor): The coordinates of the bounding box in the format [x1, y1, x2, y2] where
x1,y1 is the top-left corner, x2,y2 is the bottom-right corner of the bounding box.
"""
assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x) # faster than clone/copy
y[..., 0] = w * (x[..., 0] - x[..., 2] / 2) + padw # top left x
y[..., 1] = h * (x[..., 1] - x[..., 3] / 2) + padh # top left y
y[..., 2] = w * (x[..., 0] + x[..., 2] / 2) + padw # bottom right x
y[..., 3] = h * (x[..., 1] + x[..., 3] / 2) + padh # bottom right y
return y
def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
"""
Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height, normalized) format. x, y,
width and height are normalized to image dimensions.
Args:
x (np.ndarray | torch.Tensor): The input bounding box coordinates in (x1, y1, x2, y2) format.
w (int): The width of the image. Defaults to 640
h (int): The height of the image. Defaults to 640
clip (bool): If True, the boxes will be clipped to the image boundaries. Defaults to False
eps (float): The minimum value of the box's width and height. Defaults to 0.0
Returns:
y (np.ndarray | torch.Tensor): The bounding box coordinates in (x, y, width, height, normalized) format
"""
if clip:
x = clip_boxes(x, (h - eps, w - eps))
assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x) # faster than clone/copy
y[..., 0] = ((x[..., 0] + x[..., 2]) / 2) / w # x center
y[..., 1] = ((x[..., 1] + x[..., 3]) / 2) / h # y center
y[..., 2] = (x[..., 2] - x[..., 0]) / w # width
y[..., 3] = (x[..., 3] - x[..., 1]) / h # height
return y
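# Example (illustrative sketch, not part of the original module): normalizing a
# made-up pixel-space box to a hypothetical 640x480 (w, h) image.
def _demo_xyxy2xywhn():
    """Normalize one xyxy box; all output values lie in [0, 1]."""
    box = np.array([[100.0, 50.0, 200.0, 150.0]])
    return xyxy2xywhn(box, w=640, h=480)  # -> approx [[0.2344, 0.2083, 0.1562, 0.2083]]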
def xywh2ltwh(x):
"""
Convert the bounding box format from [x, y, w, h] to [x1, y1, w, h], where x1, y1 are the top-left coordinates.
Args:
x (np.ndarray | torch.Tensor): The input tensor with the bounding box coordinates in the xywh format
Returns:
y (np.ndarray | torch.Tensor): The bounding box coordinates in the xyltwh format
"""
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
y[..., 0] = x[..., 0] - x[..., 2] / 2 # top left x
y[..., 1] = x[..., 1] - x[..., 3] / 2 # top left y
return y
def xyxy2ltwh(x):
"""
Convert nx4 bounding boxes from [x1, y1, x2, y2] to [x1, y1, w, h], where xy1=top-left, xy2=bottom-right.
Args:
x (np.ndarray | torch.Tensor): The input tensor with the bounding boxes coordinates in the xyxy format
Returns:
y (np.ndarray | torch.Tensor): The bounding box coordinates in the xyltwh format.
"""
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
y[..., 2] = x[..., 2] - x[..., 0] # width
y[..., 3] = x[..., 3] - x[..., 1] # height
return y
def ltwh2xywh(x):
"""
Convert nx4 boxes from [x1, y1, w, h] to [x, y, w, h] where xy1=top-left, xy=center.
Args:
x (torch.Tensor): the input tensor
Returns:
y (np.ndarray | torch.Tensor): The bounding box coordinates in the xywh format.
"""
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
y[..., 0] = x[..., 0] + x[..., 2] / 2 # center x
y[..., 1] = x[..., 1] + x[..., 3] / 2 # center y
return y
def xyxyxyxy2xywhr(corners):
"""
    Convert batched Oriented Bounding Boxes (OBB) from [xy1, xy2, xy3, xy4] to [xywh, rotation]. Rotation values are
    returned in radians, in the range [0, pi/2).
Args:
corners (numpy.ndarray | torch.Tensor): Input corners of shape (n, 8).
Returns:
(numpy.ndarray | torch.Tensor): Converted data in [cx, cy, w, h, rotation] format of shape (n, 5).
"""
is_torch = isinstance(corners, torch.Tensor)
points = corners.cpu().numpy() if is_torch else corners
points = points.reshape(len(corners), -1, 2)
rboxes = []
for pts in points:
        # NOTE: Use cv2.minAreaRect to get an accurate xywhr,
        # especially when objects are cut off by augmentations in the dataloader.
(x, y), (w, h), angle = cv2.minAreaRect(pts)
rboxes.append([x, y, w, h, angle / 180 * np.pi])
return (
torch.tensor(rboxes, device=corners.device, dtype=corners.dtype)
if is_torch
else np.asarray(rboxes, dtype=points.dtype)
) # rboxes
def xywhr2xyxyxyxy(rboxes):
"""
    Convert batched Oriented Bounding Boxes (OBB) from [xywh, rotation] to [xy1, xy2, xy3, xy4]. Rotation values should
    be in radians, in the range [0, pi/2).
Args:
rboxes (numpy.ndarray | torch.Tensor): Boxes in [cx, cy, w, h, rotation] format of shape (n, 5) or (b, n, 5).
Returns:
(numpy.ndarray | torch.Tensor): Converted corner points of shape (n, 4, 2) or (b, n, 4, 2).
"""
is_numpy = isinstance(rboxes, np.ndarray)
cos, sin = (np.cos, np.sin) if is_numpy else (torch.cos, torch.sin)
ctr = rboxes[..., :2]
w, h, angle = (rboxes[..., i : i + 1] for i in range(2, 5))
cos_value, sin_value = cos(angle), sin(angle)
vec1 = [w / 2 * cos_value, w / 2 * sin_value]
vec2 = [-h / 2 * sin_value, h / 2 * cos_value]
vec1 = np.concatenate(vec1, axis=-1) if is_numpy else torch.cat(vec1, dim=-1)
vec2 = np.concatenate(vec2, axis=-1) if is_numpy else torch.cat(vec2, dim=-1)
pt1 = ctr + vec1 + vec2
pt2 = ctr + vec1 - vec2
pt3 = ctr - vec1 - vec2
pt4 = ctr - vec1 + vec2
return np.stack([pt1, pt2, pt3, pt4], axis=-2) if is_numpy else torch.stack([pt1, pt2, pt3, pt4], dim=-2)
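# Example (illustrative sketch, not part of the original module): round trip of a
# made-up rotated box through the two OBB converters; angles are in radians. Note
# cv2.minAreaRect may return an equivalent but differently parameterized w/h/angle.
def _demo_obb_roundtrip():
    """Corners from xywhr and back again."""
    rbox = torch.tensor([[100.0, 100.0, 60.0, 30.0, math.pi / 6]])  # cx, cy, w, h, r
    corners = xywhr2xyxyxyxy(rbox)  # (1, 4, 2) corner points
    return xyxyxyxy2xywhr(corners)  # back to (1, 5) xywhr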
def ltwh2xyxy(x):
"""
It converts the bounding box from [x1, y1, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right.
Args:
        x (np.ndarray | torch.Tensor): the input boxes in ltwh format
Returns:
y (np.ndarray | torch.Tensor): the xyxy coordinates of the bounding boxes.
"""
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[..., 2] = x[..., 2] + x[..., 0]  # x2 = x1 + w
    y[..., 3] = x[..., 3] + x[..., 1]  # y2 = y1 + h
return y
def segments2boxes(segments):
"""
It converts segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
Args:
segments (list): list of segments, each segment is a list of points, each point is a list of x, y coordinates
Returns:
(np.ndarray): the xywh coordinates of the bounding boxes.
"""
boxes = []
for s in segments:
x, y = s.T # segment xy
        boxes.append([x.min(), y.min(), x.max(), y.max()])  # xyxy
    return xyxy2xywh(np.array(boxes))  # xywh
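# Example (illustrative sketch, not part of the original module): one made-up
# polygon segment converted to its enclosing xywh box.
def _demo_segments2boxes():
    """A 3-point polygon spanning x in [10, 50] and y in [20, 80]."""
    seg = np.array([[10.0, 20.0], [50.0, 80.0], [30.0, 40.0]])
    return segments2boxes([seg])  # -> [[30., 50., 40., 60.]] (cx, cy, w, h)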
def resample_segments(segments, n=1000):
"""
    Inputs a list of segments (m, 2) and returns a list of segments (n, 2) resampled to n points each.
Args:
segments (list): a list of (n,2) arrays, where n is the number of points in the segment.
n (int): number of points to resample the segment to. Defaults to 1000
Returns:
segments (list): the resampled segments.
"""
for i, s in enumerate(segments):
s = np.concatenate((s, s[0:1, :]), axis=0)
x = np.linspace(0, len(s) - 1, n)
xp = np.arange(len(s))
segments[i] = (
np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)], dtype=np.float32).reshape(2, -1).T
) # segment xy
return segments
def crop_mask(masks, boxes):
"""
It takes a mask and a bounding box, and returns a mask that is cropped to the bounding box.
Args:
masks (torch.Tensor): [n, h, w] tensor of masks
boxes (torch.Tensor): [n, 4] tensor of bbox coordinates in relative point form
Returns:
(torch.Tensor): The masks are being cropped to the bounding box.
"""
_, h, w = masks.shape
x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1) # x1 shape(n,1,1)
r = torch.arange(w, device=masks.device, dtype=x1.dtype)[None, None, :] # rows shape(1,1,w)
c = torch.arange(h, device=masks.device, dtype=x1.dtype)[None, :, None] # cols shape(1,h,1)
return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2))
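# Example (illustrative sketch, not part of the original module): crop_mask()
# zeroing a full mask outside a made-up box.
def _demo_crop_mask():
    """Keep only the region of an all-ones 8x8 mask inside the box."""
    masks = torch.ones(1, 8, 8)
    boxes = torch.tensor([[2.0, 2.0, 6.0, 6.0]])  # x1, y1, x2, y2
    return crop_mask(masks, boxes)  # ones inside [2, 6) in x and y, zeros elsewhere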
def process_mask_upsample(protos, masks_in, bboxes, shape):
"""
Takes the output of the mask head, and applies the mask to the bounding boxes. This produces masks of higher quality
but is slower.
Args:
protos (torch.Tensor): [mask_dim, mask_h, mask_w]
masks_in (torch.Tensor): [n, mask_dim], n is number of masks after nms
bboxes (torch.Tensor): [n, 4], n is number of masks after nms
shape (tuple): the size of the input image (h,w)
Returns:
(torch.Tensor): The upsampled masks.
"""
c, mh, mw = protos.shape # CHW
masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw)
masks = F.interpolate(masks[None], shape, mode="bilinear", align_corners=False)[0] # CHW
masks = crop_mask(masks, bboxes) # CHW
return masks.gt_(0.5)
def process_mask(protos, masks_in, bboxes, shape, upsample=False):
"""
Apply masks to bounding boxes using the output of the mask head.
Args:
protos (torch.Tensor): A tensor of shape [mask_dim, mask_h, mask_w].
masks_in (torch.Tensor): A tensor of shape [n, mask_dim], where n is the number of masks after NMS.
bboxes (torch.Tensor): A tensor of shape [n, 4], where n is the number of masks after NMS.
shape (tuple): A tuple of integers representing the size of the input image in the format (h, w).
upsample (bool): A flag to indicate whether to upsample the mask to the original image size. Default is False.
Returns:
(torch.Tensor): A binary mask tensor of shape [n, h, w], where n is the number of masks after NMS, and h and w
are the height and width of the input image. The mask is applied to the bounding boxes.
"""
c, mh, mw = protos.shape # CHW
ih, iw = shape
masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw) # CHW
width_ratio = mw / iw
height_ratio = mh / ih
downsampled_bboxes = bboxes.clone()
downsampled_bboxes[:, 0] *= width_ratio
downsampled_bboxes[:, 2] *= width_ratio
downsampled_bboxes[:, 3] *= height_ratio
downsampled_bboxes[:, 1] *= height_ratio
masks = crop_mask(masks, downsampled_bboxes) # CHW
if upsample:
masks = F.interpolate(masks[None], shape, mode="bilinear", align_corners=False)[0] # CHW
return masks.gt_(0.5)
def process_mask_native(protos, masks_in, bboxes, shape):
"""
It takes the output of the mask head, and crops it after upsampling to the bounding boxes.
Args:
protos (torch.Tensor): [mask_dim, mask_h, mask_w]
masks_in (torch.Tensor): [n, mask_dim], n is number of masks after nms
bboxes (torch.Tensor): [n, 4], n is number of masks after nms
shape (tuple): the size of the input image (h,w)
Returns:
masks (torch.Tensor): The returned masks with dimensions [h, w, n]
"""
c, mh, mw = protos.shape # CHW
masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw)
masks = scale_masks(masks[None], shape)[0] # CHW
masks = crop_mask(masks, bboxes) # CHW
return masks.gt_(0.5)
def scale_masks(masks, shape, padding=True):
"""
Rescale segment masks to shape.
Args:
masks (torch.Tensor): (N, C, H, W).
shape (tuple): Height and width.
padding (bool): If True, assuming the boxes is based on image augmented by yolo style. If False then do regular
rescaling.
"""
mh, mw = masks.shape[2:]
gain = min(mh / shape[0], mw / shape[1]) # gain = old / new
pad = [mw - shape[1] * gain, mh - shape[0] * gain] # wh padding
if padding:
pad[0] /= 2
pad[1] /= 2
top, left = (int(pad[1]), int(pad[0])) if padding else (0, 0) # y, x
bottom, right = (int(mh - pad[1]), int(mw - pad[0]))
masks = masks[..., top:bottom, left:right]
masks = F.interpolate(masks, shape, mode="bilinear", align_corners=False) # NCHW
return masks
def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize=False, padding=True):
"""
Rescale segment coordinates (xy) from img1_shape to img0_shape.
Args:
img1_shape (tuple): The shape of the image that the coords are from.
coords (torch.Tensor): the coords to be scaled of shape n,2.
img0_shape (tuple): the shape of the image that the segmentation is being applied to.
ratio_pad (tuple): the ratio of the image size to the padded image size.
normalize (bool): If True, the coordinates will be normalized to the range [0, 1]. Defaults to False.
padding (bool): If True, assuming the boxes is based on image augmented by yolo style. If False then do regular
rescaling.
Returns:
coords (torch.Tensor): The scaled coordinates.
"""
if ratio_pad is None: # calculate from img0_shape
gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
else:
gain = ratio_pad[0][0]
pad = ratio_pad[1]
if padding:
coords[..., 0] -= pad[0] # x padding
coords[..., 1] -= pad[1] # y padding
coords[..., 0] /= gain
coords[..., 1] /= gain
coords = clip_coords(coords, img0_shape)
if normalize:
coords[..., 0] /= img0_shape[1] # width
coords[..., 1] /= img0_shape[0] # height
return coords
def regularize_rboxes(rboxes):
"""
Regularize rotated boxes in range [0, pi/2].
Args:
rboxes (torch.Tensor): (N, 5), xywhr.
Returns:
(torch.Tensor): The regularized boxes.
"""
x, y, w, h, t = rboxes.unbind(dim=-1)
# Swap edge and angle if h >= w
w_ = torch.where(w > h, w, h)
h_ = torch.where(w > h, h, w)
t = torch.where(w > h, t, t + math.pi / 2) % math.pi
return torch.stack([x, y, w_, h_, t], dim=-1) # regularized boxes
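# Example (illustrative sketch, not part of the original module): a made-up box
# with h > w gets its edges swapped and its angle shifted by pi/2 (mod pi).
def _demo_regularize_rboxes():
    """Regularize one rotated box so that the long edge comes first."""
    rboxes = torch.tensor([[50.0, 50.0, 10.0, 40.0, 0.2]])  # w < h
    return regularize_rboxes(rboxes)  # -> [[50., 50., 40., 10., 0.2 + pi/2]]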
def masks2segments(masks, strategy="largest"):
"""
It takes a list of masks(n,h,w) and returns a list of segments(n,xy)
Args:
        masks (torch.Tensor): the output of the model, a tensor of shape (n, h, w)
strategy (str): 'concat' or 'largest'. Defaults to largest
Returns:
segments (List): list of segment masks
"""
segments = []
for x in masks.int().cpu().numpy().astype("uint8"):
c = cv2.findContours(x, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
if c:
if strategy == "concat": # concatenate all segments
c = np.concatenate([x.reshape(-1, 2) for x in c])
elif strategy == "largest": # select largest segment
c = np.array(c[np.array([len(x) for x in c]).argmax()]).reshape(-1, 2)
else:
c = np.zeros((0, 2)) # no segments found
segments.append(c.astype("float32"))
return segments
def convert_torch2numpy_batch(batch: torch.Tensor) -> np.ndarray:
"""
Convert a batch of FP32 torch tensors (0.0-1.0) to a NumPy uint8 array (0-255), changing from BCHW to BHWC layout.
Args:
batch (torch.Tensor): Input tensor batch of shape (Batch, Channels, Height, Width) and dtype torch.float32.
Returns:
(np.ndarray): Output NumPy array batch of shape (Batch, Height, Width, Channels) and dtype uint8.
"""
return (batch.permute(0, 2, 3, 1).contiguous() * 255).clamp(0, 255).to(torch.uint8).cpu().numpy()
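# Example (illustrative sketch, not part of the original module): shape and dtype
# of the converted batch for a random BCHW float tensor.
def _demo_convert_torch2numpy_batch():
    """A (2, 3, 4, 4) float batch becomes a (2, 4, 4, 3) uint8 array."""
    batch = torch.rand(2, 3, 4, 4)
    out = convert_torch2numpy_batch(batch)
    return out.shape, out.dtype  # -> (2, 4, 4, 3), uint8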
def clean_str(s):
"""
Cleans a string by replacing special characters with underscore _
Args:
s (str): a string needing special characters replaced
Returns:
(str): a string with special characters replaced by an underscore _
"""
return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
def v10postprocess(preds, max_det, nc=80):
    """Select the top max_det YOLOv10 predictions (boxes, scores, labels) by class score, without NMS."""
    assert 4 + nc == preds.shape[-1]
    boxes, scores = preds.split([4, nc], dim=-1)
    max_scores = scores.amax(dim=-1)
max_scores, index = torch.topk(max_scores, max_det, dim=-1)
index = index.unsqueeze(-1)
boxes = torch.gather(boxes, dim=1, index=index.repeat(1, 1, boxes.shape[-1]))
scores = torch.gather(scores, dim=1, index=index.repeat(1, 1, scores.shape[-1]))
scores, index = torch.topk(scores.flatten(1), max_det, dim=-1)
labels = index % nc
index = index // nc
boxes = boxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, boxes.shape[-1]))
return boxes, scores, labels
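# Example (illustrative sketch, not part of the original module): running
# v10postprocess() on random predictions with the default nc=80 classes.
def _demo_v10postprocess():
    """Top-10 boxes/scores/labels from a made-up (2, 100, 84) prediction tensor."""
    preds = torch.rand(2, 100, 84)  # (bs, anchors, 4 + nc)
    boxes, scores, labels = v10postprocess(preds, max_det=10, nc=80)
    return boxes.shape, scores.shape, labels.shape  # -> (2, 10, 4), (2, 10), (2, 10)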

View File

@ -0,0 +1,88 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""Monkey patches to update/extend functionality of existing functions."""
import time
from pathlib import Path
import cv2
import numpy as np
import torch
# OpenCV Multilanguage-friendly functions ------------------------------------------------------------------------------
_imshow = cv2.imshow # copy to avoid recursion errors
def imread(filename: str, flags: int = cv2.IMREAD_COLOR):
"""
Read an image from a file.
Args:
filename (str): Path to the file to read.
flags (int, optional): Flag that can take values of cv2.IMREAD_*. Defaults to cv2.IMREAD_COLOR.
Returns:
(np.ndarray): The read image.
"""
return cv2.imdecode(np.fromfile(filename, np.uint8), flags)
def imwrite(filename: str, img: np.ndarray, params=None):
"""
Write an image to a file.
Args:
filename (str): Path to the file to write.
img (np.ndarray): Image to write.
params (list of ints, optional): Additional parameters. See OpenCV documentation.
Returns:
(bool): True if the file was written, False otherwise.
"""
try:
cv2.imencode(Path(filename).suffix, img, params)[1].tofile(filename)
return True
except Exception:
return False
def imshow(winname: str, mat: np.ndarray):
"""
Displays an image in the specified window.
Args:
winname (str): Name of the window.
mat (np.ndarray): Image to be shown.
"""
_imshow(winname.encode("unicode_escape").decode(), mat)
# PyTorch functions ----------------------------------------------------------------------------------------------------
_torch_save = torch.save # copy to avoid recursion errors
def torch_save(*args, use_dill=True, **kwargs):
"""
    Optionally use dill to serialize lambda functions where pickle does not, adding robustness with 3 retries and
    exponential backoff in case of save failure.
Args:
*args (tuple): Positional arguments to pass to torch.save.
use_dill (bool): Whether to try using dill for serialization if available. Defaults to True.
**kwargs (any): Keyword arguments to pass to torch.save.
"""
try:
assert use_dill
import dill as pickle
except (AssertionError, ImportError):
import pickle
if "pickle_module" not in kwargs:
kwargs["pickle_module"] = pickle
for i in range(4): # 3 retries
try:
return _torch_save(*args, **kwargs)
except RuntimeError as e: # unable to save, possibly waiting for device to flush or antivirus scan
if i == 3:
raise e
            time.sleep((2**i) / 2)  # exponential backoff: 0.5s, 1.0s, 2.0s
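# Usage sketch (illustrative payload and path, not from this repo): torch_save()
# is a drop-in for torch.save that retries transient failures with backoff.
def _demo_torch_save():
    """Save a small state dict to a hypothetical checkpoint path."""
    torch_save({"weights": torch.zeros(2, 2)}, "demo_ckpt.pt")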

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,116 @@
import os
import matplotlib.pyplot as plt
import numpy as np
# [tag, bandage, null]
# [0, 1, 2]
class ShowPR:
def __init__(self, tags, prec_value, title_name=None):
self.tags = tags
self.prec_value = prec_value
self.thres = [i * 0.01 for i in range(101)]
self.ratios = [i * 0.05 for i in range(21)]
self.title_name = title_name
def change_precValue(self, thre=0.5):
values = []
for i in range(len(self.prec_value)):
value = []
for j in range(len(self.prec_value[i])):
if self.prec_value[i][j] > thre:
value.append(1)
else:
value.append(0)
values.append(value)
return values
def _calculate_pr(self, prec_value):
FN, FP, TN, TP = 0, 0, 0, 0
for output, target in zip(prec_value, self.tags):
# print("output >> {} , target >> {}".format(output, target))
if output != target:
if target == 0:
FP += 1
elif target == 1:
FN += 1
else:
if target == 0:
TN += 1
elif target == 1:
TP += 1
if TP == 0:
prec, recall = 0, 0
else:
prec = TP / (TP + FP)
recall = TP / (TP + FN)
if TN == 0:
tn_prec, tn_recall = 0, 0
else:
tn_prec = TN / (TN + FN)
tn_recall = TN / (TN + FP)
# print("TP>>{}, FP>>{}, TN>>{}, FN>>{}".format(TP, FP, TN, FN))
return prec, recall, tn_prec, tn_recall
def calculate_multiple(self, ratio=0.2):
recall, recall_TN, PrecisePos, PreciseNeg = [], [], [], []
for thre in self.thres:
prec_value = []
value = self.change_precValue(thre)
for num in range(len(value)):
proportion = value[num].count(1) / len(value[num])
if proportion >= ratio:
prec_value.append(1)
else:
prec_value.append(0)
prec, recall_pos, tn_prec, tn_recall = self._calculate_pr(prec_value)
            print(
                f"thre>>{thre:.2f}, recall>>{recall_pos:.4f}, precise_pos>>{prec:.4f}, recall_tn>>{tn_recall:.4f}, precise_neg>>{tn_prec:.4f}")
PrecisePos.append(prec)
recall.append(recall_pos)
PreciseNeg.append(tn_prec)
recall_TN.append(tn_recall)
return recall, recall_TN, PrecisePos, PreciseNeg
def write_results_to_file(self, recall, recall_TN, PrecisePos, PreciseNeg, ratio):
file_path = os.sep.join(['./ckpts/tracePR', self.title_name + f"_{ratio:.2f}" + '.txt'])
with open(file_path, 'w') as file:
file.write("threshold, recall, recall_TN, PrecisePos, PreciseNeg\n")
for thre, rec, rec_tn, prec_pos, prec_neg in zip(self.thres, recall, recall_TN, PrecisePos, PreciseNeg):
                file.write(
                    f"thre>>{thre:.2f}, recall>>{rec:.4f}, precise_pos>>{prec_pos:.4f}, recall_tn>>{rec_tn:.4f}, precise_neg>>{prec_neg:.4f}\n")
def show_pr(self, recall, recall_TN, PrecisePos, PreciseNeg, ratio):
# self.calculate_multiple()
x = self.thres
plt.figure(figsize=(10, 6))
        plt.plot(x, recall, color='red', label='recall:TP/(TP+FN)')
        plt.plot(x, recall_TN, color='black', label='recall_TN:TN/(TN+FP)')
        plt.plot(x, PrecisePos, color='blue', label='PrecisePos:TP/(TP+FP)')
        plt.plot(x, PreciseNeg, color='green', label='PreciseNeg:TN/(TN+FN)')
plt.legend()
plt.xlabel('threshold')
if self.title_name is not None:
plt.title(f"PrecisePos & Recall ratio:{ratio:.2f}", fontdict={'fontsize': 12, 'fontweight': 'black'})
# plt.grid(True, linestyle='--', alpha=0.5)
        # Enable minor ticks
        plt.minorticks_on()
        # Grid lines at major ticks
        plt.grid(which='major', linestyle='-', alpha=0.5, color='gray')
        # Grid lines at minor ticks
        plt.grid(which='minor', linestyle=':', alpha=0.3, color='gray')
plt.savefig(os.sep.join(['./ckpts/tracePR', self.title_name + f"_{ratio:.2f}" + '.png']))
plt.show()
plt.close()
self.write_results_to_file(recall, recall_TN, PrecisePos, PreciseNeg, ratio)
def get_pic(self):
for ratio in self.ratios:
# ratio = 0.5
if ratio < 0.2 or ratio > 0.95:
continue
recall, recall_TN, PrecisePos, PreciseNeg = self.calculate_multiple(ratio)
self.show_pr(recall, recall_TN, PrecisePos, PreciseNeg, ratio)

ultralytics/utils/tal.py Normal file
View File

@ -0,0 +1,345 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import torch
import torch.nn as nn
from .checks import check_version
from .metrics import bbox_iou, probiou
from .ops import xywhr2xyxyxyxy
TORCH_1_10 = check_version(torch.__version__, "1.10.0")
class TaskAlignedAssigner(nn.Module):
"""
A task-aligned assigner for object detection.
This class assigns ground-truth (gt) objects to anchors based on the task-aligned metric, which combines both
classification and localization information.
Attributes:
topk (int): The number of top candidates to consider.
num_classes (int): The number of object classes.
alpha (float): The alpha parameter for the classification component of the task-aligned metric.
beta (float): The beta parameter for the localization component of the task-aligned metric.
eps (float): A small value to prevent division by zero.
"""
def __init__(self, topk=13, num_classes=80, alpha=1.0, beta=6.0, eps=1e-9):
"""Initialize a TaskAlignedAssigner object with customizable hyperparameters."""
super().__init__()
self.topk = topk
self.num_classes = num_classes
self.bg_idx = num_classes
self.alpha = alpha
self.beta = beta
self.eps = eps
@torch.no_grad()
def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
"""
Compute the task-aligned assignment. Reference code is available at
https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py.
Args:
pd_scores (Tensor): shape(bs, num_total_anchors, num_classes)
pd_bboxes (Tensor): shape(bs, num_total_anchors, 4)
anc_points (Tensor): shape(num_total_anchors, 2)
gt_labels (Tensor): shape(bs, n_max_boxes, 1)
gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
mask_gt (Tensor): shape(bs, n_max_boxes, 1)
Returns:
target_labels (Tensor): shape(bs, num_total_anchors)
target_bboxes (Tensor): shape(bs, num_total_anchors, 4)
target_scores (Tensor): shape(bs, num_total_anchors, num_classes)
fg_mask (Tensor): shape(bs, num_total_anchors)
target_gt_idx (Tensor): shape(bs, num_total_anchors)
"""
self.bs = pd_scores.shape[0]
self.n_max_boxes = gt_bboxes.shape[1]
if self.n_max_boxes == 0:
device = gt_bboxes.device
return (
torch.full_like(pd_scores[..., 0], self.bg_idx).to(device),
torch.zeros_like(pd_bboxes).to(device),
torch.zeros_like(pd_scores).to(device),
torch.zeros_like(pd_scores[..., 0]).to(device),
torch.zeros_like(pd_scores[..., 0]).to(device),
)
mask_pos, align_metric, overlaps = self.get_pos_mask(
pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt
)
target_gt_idx, fg_mask, mask_pos = self.select_highest_overlaps(mask_pos, overlaps, self.n_max_boxes)
# Assigned target
target_labels, target_bboxes, target_scores = self.get_targets(gt_labels, gt_bboxes, target_gt_idx, fg_mask)
# Normalize
align_metric *= mask_pos
pos_align_metrics = align_metric.amax(dim=-1, keepdim=True) # b, max_num_obj
pos_overlaps = (overlaps * mask_pos).amax(dim=-1, keepdim=True) # b, max_num_obj
norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).amax(-2).unsqueeze(-1)
target_scores = target_scores * norm_align_metric
return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx
def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt):
"""Get in_gts mask, (b, max_num_obj, h*w)."""
mask_in_gts = self.select_candidates_in_gts(anc_points, gt_bboxes)
# Get anchor_align metric, (b, max_num_obj, h*w)
align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts * mask_gt)
# Get topk_metric mask, (b, max_num_obj, h*w)
mask_topk = self.select_topk_candidates(align_metric, topk_mask=mask_gt.expand(-1, -1, self.topk).bool())
# Merge all mask to a final mask, (b, max_num_obj, h*w)
mask_pos = mask_topk * mask_in_gts * mask_gt
return mask_pos, align_metric, overlaps
def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_gt):
"""Compute alignment metric given predicted and ground truth bounding boxes."""
na = pd_bboxes.shape[-2]
mask_gt = mask_gt.bool() # b, max_num_obj, h*w
overlaps = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_bboxes.dtype, device=pd_bboxes.device)
bbox_scores = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_scores.dtype, device=pd_scores.device)
ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long) # 2, b, max_num_obj
ind[0] = torch.arange(end=self.bs).view(-1, 1).expand(-1, self.n_max_boxes) # b, max_num_obj
ind[1] = gt_labels.squeeze(-1) # b, max_num_obj
# Get the scores of each grid for each gt cls
bbox_scores[mask_gt] = pd_scores[ind[0], :, ind[1]][mask_gt] # b, max_num_obj, h*w
# (b, max_num_obj, 1, 4), (b, 1, h*w, 4)
pd_boxes = pd_bboxes.unsqueeze(1).expand(-1, self.n_max_boxes, -1, -1)[mask_gt]
gt_boxes = gt_bboxes.unsqueeze(2).expand(-1, -1, na, -1)[mask_gt]
overlaps[mask_gt] = self.iou_calculation(gt_boxes, pd_boxes)
align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)
return align_metric, overlaps
def iou_calculation(self, gt_bboxes, pd_bboxes):
"""IoU calculation for horizontal bounding boxes."""
return bbox_iou(gt_bboxes, pd_bboxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0)
def select_topk_candidates(self, metrics, largest=True, topk_mask=None):
"""
Select the top-k candidates based on the given metrics.
Args:
metrics (Tensor): A tensor of shape (b, max_num_obj, h*w), where b is the batch size,
max_num_obj is the maximum number of objects, and h*w represents the
total number of anchor points.
largest (bool): If True, select the largest values; otherwise, select the smallest values.
topk_mask (Tensor): An optional boolean tensor of shape (b, max_num_obj, topk), where
topk is the number of top candidates to consider. If not provided,
the top-k values are automatically computed based on the given metrics.
Returns:
(Tensor): A tensor of shape (b, max_num_obj, h*w) containing the selected top-k candidates.
"""
# (b, max_num_obj, topk)
topk_metrics, topk_idxs = torch.topk(metrics, self.topk, dim=-1, largest=largest)
if topk_mask is None:
topk_mask = (topk_metrics.max(-1, keepdim=True)[0] > self.eps).expand_as(topk_idxs)
# (b, max_num_obj, topk)
topk_idxs.masked_fill_(~topk_mask, 0)
# (b, max_num_obj, topk, h*w) -> (b, max_num_obj, h*w)
count_tensor = torch.zeros(metrics.shape, dtype=torch.int8, device=topk_idxs.device)
ones = torch.ones_like(topk_idxs[:, :, :1], dtype=torch.int8, device=topk_idxs.device)
for k in range(self.topk):
# Expand topk_idxs for each value of k and add 1 at the specified positions
count_tensor.scatter_add_(-1, topk_idxs[:, :, k : k + 1], ones)
# count_tensor.scatter_add_(-1, topk_idxs, torch.ones_like(topk_idxs, dtype=torch.int8, device=topk_idxs.device))
# Filter invalid bboxes
count_tensor.masked_fill_(count_tensor > 1, 0)
return count_tensor.to(metrics.dtype)
def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask):
"""
Compute target labels, target bounding boxes, and target scores for the positive anchor points.
Args:
gt_labels (Tensor): Ground truth labels of shape (b, max_num_obj, 1), where b is the
batch size and max_num_obj is the maximum number of objects.
gt_bboxes (Tensor): Ground truth bounding boxes of shape (b, max_num_obj, 4).
target_gt_idx (Tensor): Indices of the assigned ground truth objects for positive
anchor points, with shape (b, h*w), where h*w is the total
number of anchor points.
fg_mask (Tensor): A boolean tensor of shape (b, h*w) indicating the positive
(foreground) anchor points.
Returns:
(Tuple[Tensor, Tensor, Tensor]): A tuple containing the following tensors:
- target_labels (Tensor): Shape (b, h*w), containing the target labels for
positive anchor points.
- target_bboxes (Tensor): Shape (b, h*w, 4), containing the target bounding boxes
for positive anchor points.
- target_scores (Tensor): Shape (b, h*w, num_classes), containing the target scores
for positive anchor points, where num_classes is the number
of object classes.
"""
# Assigned target labels, (b, 1)
batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]
target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes # (b, h*w)
target_labels = gt_labels.long().flatten()[target_gt_idx] # (b, h*w)
# Assigned target boxes, (b, max_num_obj, 4) -> (b, h*w, 4)
target_bboxes = gt_bboxes.view(-1, gt_bboxes.shape[-1])[target_gt_idx]
# Assigned target scores
target_labels.clamp_(0)
# 10x faster than F.one_hot()
target_scores = torch.zeros(
(target_labels.shape[0], target_labels.shape[1], self.num_classes),
dtype=torch.int64,
device=target_labels.device,
) # (b, h*w, 80)
target_scores.scatter_(2, target_labels.unsqueeze(-1), 1)
fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.num_classes) # (b, h*w, 80)
target_scores = torch.where(fg_scores_mask > 0, target_scores, 0)
return target_labels, target_bboxes, target_scores
@staticmethod
def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
"""
Select the positive anchor center in gt.
Args:
xy_centers (Tensor): shape(h*w, 2)
gt_bboxes (Tensor): shape(b, n_boxes, 4)
Returns:
(Tensor): shape(b, n_boxes, h*w)
"""
n_anchors = xy_centers.shape[0]
bs, n_boxes, _ = gt_bboxes.shape
lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2) # left-top, right-bottom
bbox_deltas = torch.cat((xy_centers[None] - lt, rb - xy_centers[None]), dim=2).view(bs, n_boxes, n_anchors, -1)
# return (bbox_deltas.min(3)[0] > eps).to(gt_bboxes.dtype)
return bbox_deltas.amin(3).gt_(eps)
@staticmethod
def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
"""
If an anchor box is assigned to multiple gts, the one with the highest IoU will be selected.
Args:
mask_pos (Tensor): shape(b, n_max_boxes, h*w)
overlaps (Tensor): shape(b, n_max_boxes, h*w)
Returns:
target_gt_idx (Tensor): shape(b, h*w)
fg_mask (Tensor): shape(b, h*w)
mask_pos (Tensor): shape(b, n_max_boxes, h*w)
"""
# (b, n_max_boxes, h*w) -> (b, h*w)
fg_mask = mask_pos.sum(-2)
if fg_mask.max() > 1: # one anchor is assigned to multiple gt_bboxes
mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, n_max_boxes, -1) # (b, n_max_boxes, h*w)
max_overlaps_idx = overlaps.argmax(1) # (b, h*w)
is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device)
is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1)
mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos).float() # (b, n_max_boxes, h*w)
fg_mask = mask_pos.sum(-2)
        # Find which gt (index) each grid cell serves
target_gt_idx = mask_pos.argmax(-2) # (b, h*w)
return target_gt_idx, fg_mask, mask_pos
class RotatedTaskAlignedAssigner(TaskAlignedAssigner):
def iou_calculation(self, gt_bboxes, pd_bboxes):
"""IoU calculation for rotated bounding boxes."""
return probiou(gt_bboxes, pd_bboxes).squeeze(-1).clamp_(0)
@staticmethod
def select_candidates_in_gts(xy_centers, gt_bboxes):
"""
Select the positive anchor center in gt for rotated bounding boxes.
Args:
xy_centers (Tensor): shape(h*w, 2)
gt_bboxes (Tensor): shape(b, n_boxes, 5)
Returns:
(Tensor): shape(b, n_boxes, h*w)
"""
# (b, n_boxes, 5) --> (b, n_boxes, 4, 2)
corners = xywhr2xyxyxyxy(gt_bboxes)
# (b, n_boxes, 1, 2)
a, b, _, d = corners.split(1, dim=-2)
ab = b - a
ad = d - a
# (b, n_boxes, h*w, 2)
ap = xy_centers - a
norm_ab = (ab * ab).sum(dim=-1)
norm_ad = (ad * ad).sum(dim=-1)
ap_dot_ab = (ap * ab).sum(dim=-1)
ap_dot_ad = (ap * ad).sum(dim=-1)
return (ap_dot_ab >= 0) & (ap_dot_ab <= norm_ab) & (ap_dot_ad >= 0) & (ap_dot_ad <= norm_ad) # is_in_box
def make_anchors(feats, strides, grid_cell_offset=0.5):
"""Generate anchors from features."""
anchor_points, stride_tensor = [], []
assert feats is not None
dtype, device = feats[0].dtype, feats[0].device
for i, stride in enumerate(strides):
_, _, h, w = feats[i].shape
sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset # shift x
sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset # shift y
sy, sx = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx)
anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2))
stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
return torch.cat(anchor_points), torch.cat(stride_tensor)
def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
"""Transform distance(ltrb) to box(xywh or xyxy)."""
assert(distance.shape[dim] == 4)
lt, rb = distance.split([2, 2], dim)
x1y1 = anchor_points - lt
x2y2 = anchor_points + rb
if xywh:
c_xy = (x1y1 + x2y2) / 2
wh = x2y2 - x1y1
return torch.cat((c_xy, wh), dim) # xywh bbox
return torch.cat((x1y1, x2y2), dim) # xyxy bbox
def bbox2dist(anchor_points, bbox, reg_max):
"""Transform bbox(xyxy) to dist(ltrb)."""
x1y1, x2y2 = bbox.chunk(2, -1)
return torch.cat((anchor_points - x1y1, x2y2 - anchor_points), -1).clamp_(0, reg_max - 0.01) # dist (lt, rb)
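# Example (illustrative sketch, not part of the original module): bbox2dist and
# dist2bbox are inverses around an anchor point while distances stay below reg_max.
def _demo_dist_bbox_roundtrip():
    """Encode one made-up xyxy box as ltrb distances and decode it back."""
    anchors = torch.tensor([[8.0, 8.0]])
    bbox = torch.tensor([[4.0, 5.0, 12.0, 14.0]])
    dist = bbox2dist(anchors, bbox, reg_max=16)  # -> [[4., 3., 4., 6.]]
    return dist2bbox(dist, anchors, xywh=False)  # -> [[4., 5., 12., 14.]]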
def dist2rbox(pred_dist, pred_angle, anchor_points, dim=-1):
"""
Decode predicted object bounding box coordinates from anchor points and distribution.
Args:
pred_dist (torch.Tensor): Predicted rotated distance, (bs, h*w, 4).
pred_angle (torch.Tensor): Predicted angle, (bs, h*w, 1).
anchor_points (torch.Tensor): Anchor points, (h*w, 2).
Returns:
(torch.Tensor): Predicted rotated bounding boxes, (bs, h*w, 4).
"""
lt, rb = pred_dist.split(2, dim=dim)
cos, sin = torch.cos(pred_angle), torch.sin(pred_angle)
# (bs, h*w, 1)
xf, yf = ((rb - lt) / 2).split(1, dim=dim)
x, y = xf * cos - yf * sin, xf * sin + yf * cos
xy = torch.cat([x, y], dim=dim) + anchor_points
return torch.cat([xy, lt + rb], dim=dim)

View File

@ -0,0 +1,610 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import math
import os
import random
import time
from contextlib import contextmanager
from copy import deepcopy
from pathlib import Path
from typing import Union
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, __version__
from ultralytics.utils.checks import PYTHON_VERSION, check_version
try:
import thop
except ImportError:
thop = None
# Version checks (all default to version>=min_version)
TORCH_1_9 = check_version(torch.__version__, "1.9.0")
TORCH_1_13 = check_version(torch.__version__, "1.13.0")
TORCH_2_0 = check_version(torch.__version__, "2.0.0")
TORCHVISION_0_10 = check_version(torchvision.__version__, "0.10.0")
TORCHVISION_0_11 = check_version(torchvision.__version__, "0.11.0")
TORCHVISION_0_13 = check_version(torchvision.__version__, "0.13.0")
@contextmanager
def torch_distributed_zero_first(local_rank: int):
"""Decorator to make all processes in distributed training wait for each local_master to do something."""
initialized = torch.distributed.is_available() and torch.distributed.is_initialized()
if initialized and local_rank not in (-1, 0):
dist.barrier(device_ids=[local_rank])
yield
if initialized and local_rank == 0:
dist.barrier(device_ids=[0])
def smart_inference_mode():
"""Applies torch.inference_mode() decorator if torch>=1.9.0 else torch.no_grad() decorator."""
def decorate(fn):
"""Applies appropriate torch decorator for inference mode based on torch version."""
if TORCH_1_9 and torch.is_inference_mode_enabled():
return fn # already in inference_mode, act as a pass-through
else:
return (torch.inference_mode if TORCH_1_9 else torch.no_grad)()(fn)
return decorate
def get_cpu_info():
"""Return a string with system CPU information, i.e. 'Apple M2'."""
import cpuinfo # pip install py-cpuinfo
k = "brand_raw", "hardware_raw", "arch_string_raw" # info keys sorted by preference (not all keys always available)
info = cpuinfo.get_cpu_info() # info dict
string = info.get(k[0] if k[0] in info else k[1] if k[1] in info else k[2], "unknown")
return string.replace("(R)", "").replace("CPU ", "").replace("@ ", "")
def select_device(device="", batch=0, newline=False, verbose=True):
"""
Selects the appropriate PyTorch device based on the provided arguments.
The function takes a string specifying the device or a torch.device object and returns a torch.device object
representing the selected device. The function also validates the number of available devices and raises an
exception if the requested device(s) are not available.
Args:
device (str | torch.device, optional): Device string or torch.device object.
Options are 'None', 'cpu', or 'cuda', or '0' or '0,1,2,3'. Defaults to an empty string, which auto-selects
the first available GPU, or CPU if no GPU is available.
batch (int, optional): Batch size being used in your model. Defaults to 0.
newline (bool, optional): If True, adds a newline at the end of the log string. Defaults to False.
verbose (bool, optional): If True, logs the device information. Defaults to True.
Returns:
(torch.device): Selected device.
Raises:
ValueError: If the specified device is not available or if the batch size is not a multiple of the number of
devices when using multiple GPUs.
Examples:
>>> select_device('cuda:0')
device(type='cuda', index=0)
>>> select_device('cpu')
device(type='cpu')
Note:
Sets the 'CUDA_VISIBLE_DEVICES' environment variable for specifying which GPUs to use.
"""
if isinstance(device, torch.device):
return device
s = f"Ultralytics YOLOv{__version__} 🚀 Python-{PYTHON_VERSION} torch-{torch.__version__} "
device = str(device).lower()
for remove in "cuda:", "none", "(", ")", "[", "]", "'", " ":
device = device.replace(remove, "") # to string, 'cuda:0' -> '0' and '(0, 1)' -> '0,1'
cpu = device == "cpu"
mps = device in ("mps", "mps:0") # Apple Metal Performance Shaders (MPS)
if cpu or mps:
os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # force torch.cuda.is_available() = False
elif device: # non-cpu device requested
if device == "cuda":
device = "0"
visible = os.environ.get("CUDA_VISIBLE_DEVICES", None)
os.environ["CUDA_VISIBLE_DEVICES"] = device # set environment variable - must be before assert is_available()
if not (torch.cuda.is_available() and torch.cuda.device_count() >= len(device.split(","))):
LOGGER.info(s)
install = (
"See https://pytorch.org/get-started/locally/ for up-to-date torch install instructions if no "
"CUDA devices are seen by torch.\n"
if torch.cuda.device_count() == 0
else ""
)
raise ValueError(
f"Invalid CUDA 'device={device}' requested."
f" Use 'device=cpu' or pass valid CUDA device(s) if available,"
f" i.e. 'device=0' or 'device=0,1,2,3' for Multi-GPU.\n"
f"\ntorch.cuda.is_available(): {torch.cuda.is_available()}"
f"\ntorch.cuda.device_count(): {torch.cuda.device_count()}"
f"\nos.environ['CUDA_VISIBLE_DEVICES']: {visible}\n"
f"{install}"
)
if not cpu and not mps and torch.cuda.is_available(): # prefer GPU if available
devices = device.split(",") if device else "0" # range(torch.cuda.device_count()) # i.e. 0,1,6,7
n = len(devices) # device count
if n > 1 and batch > 0 and batch % n != 0: # check batch_size is divisible by device_count
raise ValueError(
f"'batch={batch}' must be a multiple of GPU count {n}. Try 'batch={batch // n * n}' or "
f"'batch={batch // n * n + n}', the nearest batch sizes evenly divisible by {n}."
)
space = " " * (len(s) + 1)
for i, d in enumerate(devices):
p = torch.cuda.get_device_properties(i)
s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / (1 << 20):.0f}MiB)\n" # bytes to MB
arg = "cuda:0"
elif mps and TORCH_2_0 and torch.backends.mps.is_available():
# Prefer MPS if available
s += f"MPS ({get_cpu_info()})\n"
arg = "mps"
else: # revert to CPU
s += f"CPU ({get_cpu_info()})\n"
arg = "cpu"
if verbose:
LOGGER.info(s if newline else s.rstrip())
return torch.device(arg)
def time_sync():
"""PyTorch-accurate time."""
if torch.cuda.is_available():
torch.cuda.synchronize()
return time.time()
def fuse_conv_and_bn(conv, bn):
"""Fuse Conv2d() and BatchNorm2d() layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/."""
fusedconv = (
nn.Conv2d(
conv.in_channels,
conv.out_channels,
kernel_size=conv.kernel_size,
stride=conv.stride,
padding=conv.padding,
dilation=conv.dilation,
groups=conv.groups,
bias=True,
)
.requires_grad_(False)
.to(conv.weight.device)
)
# Prepare filters
w_conv = conv.weight.clone().view(conv.out_channels, -1)
w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))
# Prepare spatial bias
b_conv = torch.zeros(conv.weight.shape[0], device=conv.weight.device) if conv.bias is None else conv.bias
b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
return fusedconv
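# Example (illustrative sketch, not part of the original module): verifying that a
# fused conv matches conv followed by batchnorm in eval mode, on random inputs.
def _demo_fuse_conv_and_bn():
    """Fused output should match the unfused pair to within float tolerance."""
    conv = nn.Conv2d(3, 8, 3, padding=1, bias=False)
    bn = nn.BatchNorm2d(8).eval()  # eval mode so running stats are used
    fused = fuse_conv_and_bn(conv, bn)
    x = torch.randn(1, 3, 16, 16)
    return torch.allclose(bn(conv(x)), fused(x), atol=1e-5)  # -> True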
def fuse_deconv_and_bn(deconv, bn):
"""Fuse ConvTranspose2d() and BatchNorm2d() layers."""
fuseddconv = (
nn.ConvTranspose2d(
deconv.in_channels,
deconv.out_channels,
kernel_size=deconv.kernel_size,
stride=deconv.stride,
padding=deconv.padding,
output_padding=deconv.output_padding,
dilation=deconv.dilation,
groups=deconv.groups,
bias=True,
)
.requires_grad_(False)
.to(deconv.weight.device)
)
# Prepare filters
w_deconv = deconv.weight.clone().view(deconv.out_channels, -1)
w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
fuseddconv.weight.copy_(torch.mm(w_bn, w_deconv).view(fuseddconv.weight.shape))
# Prepare spatial bias
b_conv = torch.zeros(deconv.weight.shape[1], device=deconv.weight.device) if deconv.bias is None else deconv.bias
b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
fuseddconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
return fuseddconv
def model_info(model, detailed=False, verbose=True, imgsz=640):
"""
Model information.
imgsz may be int or list, i.e. imgsz=640 or imgsz=[640, 320].
"""
if not verbose:
return
n_p = get_num_params(model) # number of parameters
n_g = get_num_gradients(model) # number of gradients
n_l = len(list(model.modules())) # number of layers
if detailed:
LOGGER.info(
f"{'layer':>5} {'name':>40} {'gradient':>9} {'parameters':>12} {'shape':>20} {'mu':>10} {'sigma':>10}"
)
for i, (name, p) in enumerate(model.named_parameters()):
name = name.replace("module_list.", "")
LOGGER.info(
"%5g %40s %9s %12g %20s %10.3g %10.3g %10s"
% (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std(), p.dtype)
)
flops = get_flops(model, imgsz)
fused = " (fused)" if getattr(model, "is_fused", lambda: False)() else ""
fs = f", {flops:.1f} GFLOPs" if flops else ""
yaml_file = getattr(model, "yaml_file", "") or getattr(model, "yaml", {}).get("yaml_file", "")
model_name = Path(yaml_file).stem.replace("yolo", "YOLO") or "Model"
LOGGER.info(f"{model_name} summary{fused}: {n_l} layers, {n_p} parameters, {n_g} gradients{fs}")
return n_l, n_p, n_g, flops
def get_num_params(model):
"""Return the total number of parameters in a YOLO model."""
return sum(x.numel() for x in model.parameters())
def get_num_gradients(model):
"""Return the total number of parameters with gradients in a YOLO model."""
return sum(x.numel() for x in model.parameters() if x.requires_grad)
def model_info_for_loggers(trainer):
"""
Return model info dict with useful model information.
Example:
YOLOv8n info for loggers
```python
results = {'model/parameters': 3151904,
'model/GFLOPs': 8.746,
'model/speed_ONNX(ms)': 41.244,
'model/speed_TensorRT(ms)': 3.211,
'model/speed_PyTorch(ms)': 18.755}
```
"""
if trainer.args.profile: # profile ONNX and TensorRT times
from ultralytics.utils.benchmarks import ProfileModels
results = ProfileModels([trainer.last], device=trainer.device).profile()[0]
results.pop("model/name")
else: # only return PyTorch times from most recent validation
results = {
"model/parameters": get_num_params(trainer.model),
"model/GFLOPs": round(get_flops(trainer.model), 3),
}
results["model/speed_PyTorch(ms)"] = round(trainer.validator.speed["inference"], 3)
return results
def get_flops(model, imgsz=640):
"""Return a YOLO model's FLOPs."""
if not thop:
return 0.0 # if not installed return 0.0 GFLOPs
try:
model = de_parallel(model)
p = next(model.parameters())
if not isinstance(imgsz, list):
imgsz = [imgsz, imgsz] # expand if int/float
        # Stride-based FLOPs estimation is disabled here; always use the actual image
        # size for the input tensor (required e.g. for RTDETR models)
        im = torch.empty((1, p.shape[1], *imgsz), device=p.device)  # input image in BCHW format
        return thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1e9 * 2  # imgsz GFLOPs
except Exception:
return 0.0
def get_flops_with_torch_profiler(model, imgsz=640):
"""Compute model FLOPs (thop alternative)."""
if TORCH_2_0:
model = de_parallel(model)
p = next(model.parameters())
stride = (max(int(model.stride.max()), 32) if hasattr(model, "stride") else 32) * 2 # max stride
im = torch.zeros((1, p.shape[1], stride, stride), device=p.device) # input image in BCHW format
with torch.profiler.profile(with_flops=True) as prof:
model(im)
flops = sum(x.flops for x in prof.key_averages()) / 1e9
imgsz = imgsz if isinstance(imgsz, list) else [imgsz, imgsz] # expand if int/float
flops = flops * imgsz[0] / stride * imgsz[1] / stride # 640x640 GFLOPs
return flops
return 0
def initialize_weights(model):
"""Initialize model weights to random values."""
for m in model.modules():
t = type(m)
if t is nn.Conv2d:
pass # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif t is nn.BatchNorm2d:
m.eps = 1e-3
m.momentum = 0.03
elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
m.inplace = True
def scale_img(img, ratio=1.0, same_shape=False, gs=32):
"""Scales and pads an image tensor of shape img(bs,3,y,x) based on given ratio and grid size gs, optionally
retaining the original shape.
"""
if ratio == 1.0:
return img
h, w = img.shape[2:]
s = (int(h * ratio), int(w * ratio)) # new size
img = F.interpolate(img, size=s, mode="bilinear", align_corners=False) # resize
if not same_shape: # pad/crop img
h, w = (math.ceil(x * ratio / gs) * gs for x in (h, w))
return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean
def make_divisible(x, divisor):
"""Returns nearest x divisible by divisor."""
if isinstance(divisor, torch.Tensor):
divisor = int(divisor.max()) # to int
return math.ceil(x / divisor) * divisor
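# Example (illustrative sketch, not part of the original module): rounding a
# channel count up to the nearest multiple of a stride.
def _demo_make_divisible():
    """100 rounded up to a multiple of 32 is 128."""
    return make_divisible(100, 32)  # -> 128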
def copy_attr(a, b, include=(), exclude=()):
"""Copies attributes from object 'b' to object 'a', with options to include/exclude certain attributes."""
for k, v in b.__dict__.items():
if (len(include) and k not in include) or k.startswith("_") or k in exclude:
continue
else:
setattr(a, k, v)
def get_latest_opset():
"""Return second-most (for maturity) recently supported ONNX opset by this version of torch."""
return max(int(k[14:]) for k in vars(torch.onnx) if "symbolic_opset" in k) - 1 # opset
def intersect_dicts(da, db, exclude=()):
"""Returns a dictionary of intersecting keys with matching shapes, excluding 'exclude' keys, using da values."""
return {k: v for k, v in da.items() if k in db and all(x not in k for x in exclude) and v.shape == db[k].shape}
def is_parallel(model):
"""Returns True if model is of type DP or DDP."""
return isinstance(model, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel))
def de_parallel(model):
"""De-parallelize a model: returns single-GPU model if model is of type DP or DDP."""
return model.module if is_parallel(model) else model
def one_cycle(y1=0.0, y2=1.0, steps=100):
"""Returns a lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf."""
return lambda x: max((1 - math.cos(x * math.pi / steps)) / 2, 0) * (y2 - y1) + y1
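# Example (illustrative sketch, not part of the original module): sampling the
# one_cycle lambda at the start, middle and end of a 100-step schedule.
def _demo_one_cycle():
    """Cosine ramp from 1.0 down to 0.01 over 100 steps."""
    lf = one_cycle(y1=1.0, y2=0.01, steps=100)
    return lf(0), lf(50), lf(100)  # -> 1.0, 0.505, 0.01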
def init_seeds(seed=0, deterministic=False):
"""Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html."""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed) # for Multi-GPU, exception safe
# torch.backends.cudnn.benchmark = True # AutoBatch problem https://github.com/ultralytics/yolov5/issues/9287
if deterministic:
if TORCH_2_0:
torch.use_deterministic_algorithms(True, warn_only=True) # warn if deterministic is not possible
torch.backends.cudnn.deterministic = True
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
os.environ["PYTHONHASHSEED"] = str(seed)
else:
LOGGER.warning("WARNING ⚠️ Upgrade to torch>=2.0.0 for deterministic training.")
else:
torch.use_deterministic_algorithms(False)
torch.backends.cudnn.deterministic = False
class ModelEMA:
"""Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
Keeps a moving average of everything in the model state_dict (parameters and buffers)
For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
To disable EMA set the `enabled` attribute to `False`.
"""
def __init__(self, model, decay=0.9999, tau=2000, updates=0):
"""Create EMA."""
self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA
self.updates = updates # number of EMA updates
self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs)
for p in self.ema.parameters():
p.requires_grad_(False)
self.enabled = True
def update(self, model):
"""Update EMA parameters."""
if self.enabled:
self.updates += 1
d = self.decay(self.updates)
msd = de_parallel(model).state_dict() # model state_dict
for k, v in self.ema.state_dict().items():
if v.dtype.is_floating_point: # true for FP16 and FP32
v *= d
v += (1 - d) * msd[k].detach()
# assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype}, model {msd[k].dtype}'
def update_attr(self, model, include=(), exclude=("process_group", "reducer")):
"""Updates attributes and saves stripped model with optimizer removed."""
if self.enabled:
copy_attr(self.ema, model, include, exclude)
def strip_optimizer(f: Union[str, Path] = "best_gift_v10n.pt", s: str = "") -> None:
"""
Strip optimizer from 'f' to finalize training, optionally save as 's'.
Args:
f (str): file path to model to strip the optimizer from. Default is 'best_gift_v10n.pt'.
s (str): file path to save the model with stripped optimizer to. If not provided, 'f' will be overwritten.
Returns:
None
Example:
```python
from pathlib import Path
from ultralytics.utils.torch_utils import strip_optimizer
for f in Path('path/to/weights').rglob('*.pt'):
strip_optimizer(f)
```
"""
x = torch.load(f, map_location=torch.device("cpu"))
if "model" not in x:
LOGGER.info(f"Skipping {f}, not a valid Ultralytics model.")
return
if hasattr(x["model"], "args"):
x["model"].args = dict(x["model"].args) # convert from IterableSimpleNamespace to dict
args = {**DEFAULT_CFG_DICT, **x["train_args"]} if "train_args" in x else None # combine args
if x.get("ema"):
x["model"] = x["ema"] # replace model with ema
for k in "optimizer", "best_fitness", "ema", "updates": # keys
x[k] = None
x["epoch"] = -1
x["model"].half() # to FP16
for p in x["model"].parameters():
p.requires_grad = False
x["train_args"] = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # strip non-default keys
# x['model'].args = x['train_args']
torch.save(x, s or f)
mb = os.path.getsize(s or f) / 1e6 # file size
LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
def profile(input, ops, n=10, device=None):
"""
Ultralytics speed, memory and FLOPs profiler.
Example:
```python
from ultralytics.utils.torch_utils import profile
input = torch.randn(16, 3, 640, 640)
m1 = lambda x: x * torch.sigmoid(x)
m2 = nn.SiLU()
profile(input, [m1, m2], n=100) # profile over 100 iterations
```
"""
results = []
if not isinstance(device, torch.device):
device = select_device(device)
LOGGER.info(
f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
f"{'input':>24s}{'output':>24s}"
)
for x in input if isinstance(input, list) else [input]:
x = x.to(device)
x.requires_grad = True
for m in ops if isinstance(ops, list) else [ops]:
m = m.to(device) if hasattr(m, "to") else m # device
m = m.half() if hasattr(m, "half") and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m
tf, tb, t = 0, 0, [0, 0, 0] # dt forward, backward
try:
flops = thop.profile(m, inputs=[x], verbose=False)[0] / 1e9 * 2 if thop else 0 # GFLOPs
except Exception:
flops = 0
try:
for _ in range(n):
t[0] = time_sync()
y = m(x)
t[1] = time_sync()
try:
(sum(yi.sum() for yi in y) if isinstance(y, list) else y).sum().backward()
t[2] = time_sync()
except Exception: # no backward method
# print(e) # for debug
t[2] = float("nan")
tf += (t[1] - t[0]) * 1000 / n # ms per op forward
tb += (t[2] - t[1]) * 1000 / n # ms per op backward
mem = torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0 # (GB)
s_in, s_out = (tuple(x.shape) if isinstance(x, torch.Tensor) else "list" for x in (x, y)) # shapes
p = sum(x.numel() for x in m.parameters()) if isinstance(m, nn.Module) else 0 # parameters
LOGGER.info(f"{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}")
results.append([p, flops, mem, tf, tb, s_in, s_out])
except Exception as e:
LOGGER.info(e)
results.append(None)
torch.cuda.empty_cache()
return results
class EarlyStopping:
"""Early stopping class that stops training when a specified number of epochs have passed without improvement."""
def __init__(self, patience=50):
"""
Initialize early stopping object.
Args:
patience (int, optional): Number of epochs to wait after fitness stops improving before stopping.
"""
self.best_fitness = 0.0 # i.e. mAP
self.best_epoch = 0
self.patience = patience or float("inf") # epochs to wait after fitness stops improving to stop
self.possible_stop = False # possible stop may occur next epoch
def __call__(self, epoch, fitness):
"""
Check whether to stop training.
Args:
epoch (int): Current epoch of training
fitness (float): Fitness value of current epoch
Returns:
(bool): True if training should stop, False otherwise
"""
if fitness is None: # check if fitness=None (happens when val=False)
return False
if fitness >= self.best_fitness: # >= 0 to allow for early zero-fitness stage of training
self.best_epoch = epoch
self.best_fitness = fitness
delta = epoch - self.best_epoch # epochs without improvement
self.possible_stop = delta >= (self.patience - 1) # possible stop may occur next epoch
stop = delta >= self.patience # stop training if patience exceeded
if stop:
LOGGER.info(
f"Stopping training early as no improvement observed in last {self.patience} epochs. "
f"Best results observed at epoch {self.best_epoch}, best model saved as best_gift_v10n.pt.\n"
f"To update EarlyStopping(patience={self.patience}) pass a new patience value, "
f"i.e. `patience=300` or use `patience=0` to disable EarlyStopping."
)
return stop
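

# Example usage of EarlyStopping (a minimal sketch; `validate()` is a hypothetical stand-in
# for any function returning a fitness value such as mAP):
#   stopper = EarlyStopping(patience=30)
#   for epoch in range(300):
#       fitness = validate()  # not part of this module
#       if stopper(epoch, fitness):
#           break
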

92
ultralytics/utils/triton.py Normal file
View File

@ -0,0 +1,92 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from typing import List
from urllib.parse import urlsplit
import numpy as np
class TritonRemoteModel:
"""
Client for interacting with a remote Triton Inference Server model.
Attributes:
endpoint (str): The name of the model on the Triton server.
url (str): The URL of the Triton server.
triton_client: The Triton client (either HTTP or gRPC).
InferInput: The input class for the Triton client.
InferRequestedOutput: The output request class for the Triton client.
input_formats (List[str]): The data types of the model inputs.
np_input_formats (List[type]): The numpy data types of the model inputs.
input_names (List[str]): The names of the model inputs.
output_names (List[str]): The names of the model outputs.
"""
def __init__(self, url: str, endpoint: str = "", scheme: str = ""):
"""
Initialize the TritonRemoteModel.
Arguments may be provided individually or parsed from a collective 'url' argument of the form
<scheme>://<netloc>/<endpoint>/<task_name>
Args:
url (str): The URL of the Triton server.
endpoint (str): The name of the model on the Triton server.
scheme (str): The communication scheme ('http' or 'grpc').
"""
if not endpoint and not scheme: # Parse all args from URL string
splits = urlsplit(url)
endpoint = splits.path.strip("/").split("/")[0]
scheme = splits.scheme
url = splits.netloc
self.endpoint = endpoint
self.url = url
# Choose the Triton client based on the communication scheme
if scheme == "http":
import tritonclient.http as client # noqa
self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False)
config = self.triton_client.get_model_config(endpoint)
else:
import tritonclient.grpc as client # noqa
self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False)
config = self.triton_client.get_model_config(endpoint, as_json=True)["config"]
# Sort output names alphabetically, i.e. 'output0', 'output1', etc.
config["output"] = sorted(config["output"], key=lambda x: x.get("name"))
# Define model attributes
type_map = {"TYPE_FP32": np.float32, "TYPE_FP16": np.float16, "TYPE_UINT8": np.uint8}
self.InferRequestedOutput = client.InferRequestedOutput
self.InferInput = client.InferInput
self.input_formats = [x["data_type"] for x in config["input"]]
self.np_input_formats = [type_map[x] for x in self.input_formats]
self.input_names = [x["name"] for x in config["input"]]
self.output_names = [x["name"] for x in config["output"]]
def __call__(self, *inputs: np.ndarray) -> List[np.ndarray]:
"""
Call the model with the given inputs.
Args:
*inputs (List[np.ndarray]): Input data to the model.
Returns:
(List[np.ndarray]): Model outputs.
"""
infer_inputs = []
input_format = inputs[0].dtype
for i, x in enumerate(inputs):
if x.dtype != self.np_input_formats[i]:
x = x.astype(self.np_input_formats[i])
infer_input = self.InferInput(self.input_names[i], [*x.shape], self.input_formats[i].replace("TYPE_", ""))
infer_input.set_data_from_numpy(x)
infer_inputs.append(infer_input)
infer_outputs = [self.InferRequestedOutput(output_name) for output_name in self.output_names]
outputs = self.triton_client.infer(model_name=self.endpoint, inputs=infer_inputs, outputs=infer_outputs)
return [outputs.as_numpy(output_name).astype(input_format) for output_name in self.output_names]
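
# Example usage of TritonRemoteModel (a minimal sketch; the server URL, model name 'yolov8n'
# and the 1x3x640x640 FP32 input are assumptions about your Triton deployment):
#   model = TritonRemoteModel("http://localhost:8000/yolov8n")
#   outputs = model(np.random.rand(1, 3, 640, 640).astype(np.float32))
#   print([o.shape for o in outputs])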

148
ultralytics/utils/tuner.py Normal file
View File

@ -0,0 +1,148 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import subprocess
from ultralytics.cfg import TASK2DATA, TASK2METRIC, get_save_dir
from ultralytics.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, NUM_THREADS, checks
def run_ray_tune(
model, space: dict = None, grace_period: int = 10, gpu_per_trial: int = None, max_samples: int = 10, **train_args
):
"""
Runs hyperparameter tuning using Ray Tune.
Args:
model (YOLO): Model to run the tuner on.
space (dict, optional): The hyperparameter search space. Defaults to None.
grace_period (int, optional): The grace period in epochs of the ASHA scheduler. Defaults to 10.
gpu_per_trial (int, optional): The number of GPUs to allocate per trial. Defaults to None.
max_samples (int, optional): The maximum number of trials to run. Defaults to 10.
train_args (dict, optional): Additional arguments to pass to the `train()` method. Defaults to {}.
Returns:
(dict): A dictionary containing the results of the hyperparameter search.
Example:
```python
from ultralytics import YOLO
# Load a YOLOv8n model
model = YOLO('yolov8n.pt')
# Start tuning hyperparameters for YOLOv8n training on the COCO8 dataset
result_grid = model.tune(data='coco8.yaml', use_ray=True)
```
"""
LOGGER.info("💡 Learn about RayTune at https://docs.ultralytics.com/integrations/ray-tune")
if train_args is None:
train_args = {}
try:
subprocess.run("pip install ray[tune]<=2.9.3".split(), check=True) # do not add single quotes here
import ray
from ray import tune
from ray.air import RunConfig
from ray.air.integrations.wandb import WandbLoggerCallback
from ray.tune.schedulers import ASHAScheduler
except ImportError:
raise ModuleNotFoundError('Ray Tune required but not found. To install run: pip install "ray[tune]<=2.9.3"')
try:
import wandb
assert hasattr(wandb, "__version__")
except (ImportError, AssertionError):
wandb = False
checks.check_version(ray.__version__, "<=2.9.3", "ray")
default_space = {
# 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),
"lr0": tune.uniform(1e-5, 1e-1),
"lrf": tune.uniform(0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
"momentum": tune.uniform(0.6, 0.98), # SGD momentum/Adam beta1
"weight_decay": tune.uniform(0.0, 0.001), # optimizer weight decay 5e-4
"warmup_epochs": tune.uniform(0.0, 5.0), # warmup epochs (fractions ok)
"warmup_momentum": tune.uniform(0.0, 0.95), # warmup initial momentum
"box": tune.uniform(0.02, 0.2), # box loss gain
"cls": tune.uniform(0.2, 4.0), # cls loss gain (scale with pixels)
"hsv_h": tune.uniform(0.0, 0.1), # image HSV-Hue augmentation (fraction)
"hsv_s": tune.uniform(0.0, 0.9), # image HSV-Saturation augmentation (fraction)
"hsv_v": tune.uniform(0.0, 0.9), # image HSV-Value augmentation (fraction)
"degrees": tune.uniform(0.0, 45.0), # image rotation (+/- deg)
"translate": tune.uniform(0.0, 0.9), # image translation (+/- fraction)
"scale": tune.uniform(0.0, 0.9), # image scale (+/- gain)
"shear": tune.uniform(0.0, 10.0), # image shear (+/- deg)
"perspective": tune.uniform(0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
"flipud": tune.uniform(0.0, 1.0), # image flip up-down (probability)
"fliplr": tune.uniform(0.0, 1.0), # image flip left-right (probability)
"bgr": tune.uniform(0.0, 1.0), # image channel BGR (probability)
"mosaic": tune.uniform(0.0, 1.0), # image mixup (probability)
"mixup": tune.uniform(0.0, 1.0), # image mixup (probability)
"copy_paste": tune.uniform(0.0, 1.0), # segment copy-paste (probability)
}
# Put the model in ray store
task = model.task
model_in_store = ray.put(model)
def _tune(config):
"""
Trains the YOLO model with the specified hyperparameters and additional arguments.
Args:
config (dict): A dictionary of hyperparameters to use for training.
Returns:
None
"""
model_to_train = ray.get(model_in_store) # get the model from ray store for tuning
model_to_train.reset_callbacks()
config.update(train_args)
results = model_to_train.train(**config)
return results.results_dict
# Get search space
if not space:
space = default_space
LOGGER.warning("WARNING ⚠️ search space not provided, using default search space.")
# Get dataset
data = train_args.get("data", TASK2DATA[task])
space["data"] = data
if "data" not in train_args:
LOGGER.warning(f'WARNING ⚠️ data not provided, using default "data={data}".')
# Define the trainable function with allocated resources
trainable_with_resources = tune.with_resources(_tune, {"cpu": NUM_THREADS, "gpu": gpu_per_trial or 0})
# Define the ASHA scheduler for hyperparameter search
asha_scheduler = ASHAScheduler(
time_attr="epoch",
metric=TASK2METRIC[task],
mode="max",
max_t=train_args.get("epochs") or DEFAULT_CFG_DICT["epochs"] or 100,
grace_period=grace_period,
reduction_factor=3,
)
# Define the callbacks for the hyperparameter search
tuner_callbacks = [WandbLoggerCallback(project="YOLOv8-tune")] if wandb else []
# Create the Ray Tune hyperparameter search tuner
tune_dir = get_save_dir(DEFAULT_CFG, name="tune").resolve() # must be absolute dir
tune_dir.mkdir(parents=True, exist_ok=True)
tuner = tune.Tuner(
trainable_with_resources,
param_space=space,
tune_config=tune.TuneConfig(scheduler=asha_scheduler, num_samples=max_samples),
run_config=RunConfig(callbacks=tuner_callbacks, storage_path=tune_dir),
)
# Run the hyperparameter search
tuner.fit()
# Return the results of the hyperparameter search
return tuner.get_results()
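

# Example with a custom search space (a minimal sketch; the space below is an illustrative
# subset, and results depend on your dataset and hardware):
#   from ultralytics import YOLO
#   from ray import tune
#   model = YOLO("yolov8n.pt")
#   results = run_ray_tune(model, space={"lr0": tune.uniform(1e-5, 1e-2)}, max_samples=4, data="coco8.yaml", epochs=10)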