mirror of https://gitee.com/nanjing-yimao-information/ieemoo-ai-gift.git
synced 2025-08-19 22:00:25 +00:00
update
1049 ultralytics/utils/__init__.py Normal file
File diff suppressed because it is too large
88 ultralytics/utils/autobatch.py Normal file
@@ -0,0 +1,88 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""Functions for estimating the best YOLO batch size to use a fraction of the available CUDA memory in PyTorch."""

from copy import deepcopy

import numpy as np
import torch

from ultralytics.utils import DEFAULT_CFG, LOGGER, colorstr
from ultralytics.utils.torch_utils import profile


def check_train_batch_size(model, imgsz=640, amp=True):
    """
    Check YOLO training batch size using the autobatch() function.

    Args:
        model (torch.nn.Module): YOLO model to check batch size for.
        imgsz (int): Image size used for training.
        amp (bool): If True, use automatic mixed precision (AMP) for training.

    Returns:
        (int): Optimal batch size computed using the autobatch() function.
    """

    with torch.cuda.amp.autocast(amp):
        return autobatch(deepcopy(model).train(), imgsz)  # compute optimal batch size


def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch):
    """
    Automatically estimate the best YOLO batch size to use a fraction of the available CUDA memory.

    Args:
        model (torch.nn.Module): YOLO model to compute batch size for.
        imgsz (int, optional): The image size used as input for the YOLO model. Defaults to 640.
        fraction (float, optional): The fraction of available CUDA memory to use. Defaults to 0.60.
        batch_size (int, optional): The default batch size to use if an error is detected. Defaults to 16.

    Returns:
        (int): The optimal batch size.
    """

    # Check device
    prefix = colorstr("AutoBatch: ")
    LOGGER.info(f"{prefix}Computing optimal batch size for imgsz={imgsz}")
    device = next(model.parameters()).device  # get model device
    if device.type == "cpu":
        LOGGER.info(f"{prefix}CUDA not detected, using default CPU batch-size {batch_size}")
        return batch_size
    if torch.backends.cudnn.benchmark:
        LOGGER.info(f"{prefix} ⚠️ Requires torch.backends.cudnn.benchmark=False, using default batch-size {batch_size}")
        return batch_size

    # Inspect CUDA memory
    gb = 1 << 30  # bytes to GiB (1024 ** 3)
    d = str(device).upper()  # 'CUDA:0'
    properties = torch.cuda.get_device_properties(device)  # device properties
    t = properties.total_memory / gb  # GiB total
    r = torch.cuda.memory_reserved(device) / gb  # GiB reserved
    a = torch.cuda.memory_allocated(device) / gb  # GiB allocated
    f = t - (r + a)  # GiB free
    LOGGER.info(f"{prefix}{d} ({properties.name}) {t:.2f}G total, {r:.2f}G reserved, {a:.2f}G allocated, {f:.2f}G free")

    # Profile batch sizes
    batch_sizes = [1, 2, 4, 8, 16]
    try:
        img = [torch.empty(b, 3, imgsz, imgsz) for b in batch_sizes]
        results = profile(img, model, n=3, device=device)

        # Fit a solution
        y = [x[2] for x in results if x]  # memory [2]
        p = np.polyfit(batch_sizes[: len(y)], y, deg=1)  # first-degree polynomial fit
        b = int((f * fraction - p[1]) / p[0])  # solve for the batch size at the target memory fraction
        if None in results:  # some sizes failed
            i = results.index(None)  # first fail index
            if b >= batch_sizes[i]:  # intercept above failure point
                b = batch_sizes[max(i - 1, 0)]  # select prior safe point
        if b < 1 or b > 1024:  # b outside of safe range
            b = batch_size
            LOGGER.info(f"{prefix}WARNING ⚠️ CUDA anomaly detected, using default batch-size {batch_size}.")

        fraction = (np.polyval(p, b) + r + a) / t  # actual fraction predicted
        LOGGER.info(f"{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%) ✅")
        return b
    except Exception as e:
        LOGGER.warning(f"{prefix}WARNING ⚠️ error detected: {e}, using default batch-size {batch_size}.")
        return batch_size
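As a quick illustration of how these utilities are meant to be called (not part of the diff; a minimal sketch assuming a CUDA device and the standard `YOLO` model API):

```python
from ultralytics import YOLO
from ultralytics.utils.autobatch import check_train_batch_size

model = YOLO("yolov8n.pt").model.cuda()  # underlying torch.nn.Module, moved to GPU (assumed available)
batch = check_train_batch_size(model, imgsz=640, amp=True)  # profiles batch sizes 1-16 and extrapolates
print(f"Estimated batch size: {batch}")  # result depends on available GPU memory
```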
402 ultralytics/utils/benchmarks.py Normal file
@@ -0,0 +1,402 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
Benchmark YOLO model formats for speed and accuracy.

Usage:
    from ultralytics.utils.benchmarks import ProfileModels, benchmark
    ProfileModels(['yolov8n.yaml', 'yolov8s.yaml']).profile()
    benchmark(model='yolov8n.pt', imgsz=160)

Format | `format=argument` | Model
--- | --- | ---
PyTorch | - | yolov8n.pt
TorchScript | `torchscript` | yolov8n.torchscript
ONNX | `onnx` | yolov8n.onnx
OpenVINO | `openvino` | yolov8n_openvino_model/
TensorRT | `engine` | yolov8n.engine
CoreML | `coreml` | yolov8n.mlpackage
TensorFlow SavedModel | `saved_model` | yolov8n_saved_model/
TensorFlow GraphDef | `pb` | yolov8n.pb
TensorFlow Lite | `tflite` | yolov8n.tflite
TensorFlow Edge TPU | `edgetpu` | yolov8n_edgetpu.tflite
TensorFlow.js | `tfjs` | yolov8n_web_model/
PaddlePaddle | `paddle` | yolov8n_paddle_model/
NCNN | `ncnn` | yolov8n_ncnn_model/
"""

import glob
import platform
import time
from pathlib import Path

import numpy as np
import torch.cuda

from ultralytics import YOLO, YOLOWorld
from ultralytics.cfg import TASK2DATA, TASK2METRIC
from ultralytics.engine.exporter import export_formats
from ultralytics.utils import ASSETS, LINUX, LOGGER, MACOS, TQDM, WEIGHTS_DIR
from ultralytics.utils.checks import IS_PYTHON_3_12, check_requirements, check_yolo
from ultralytics.utils.files import file_size
from ultralytics.utils.torch_utils import select_device


def benchmark(
    model=WEIGHTS_DIR / "yolov8n.pt", data=None, imgsz=160, half=False, int8=False, device="cpu", verbose=False
):
    """
    Benchmark a YOLO model across different formats for speed and accuracy.

    Args:
        model (str | Path, optional): Path to the model file or directory. Default is
            Path(SETTINGS['weights_dir']) / 'yolov8n.pt'.
        data (str, optional): Dataset to evaluate on, inherited from TASK2DATA if not passed. Default is None.
        imgsz (int, optional): Image size for the benchmark. Default is 160.
        half (bool, optional): Use half-precision for the model if True. Default is False.
        int8 (bool, optional): Use int8-precision for the model if True. Default is False.
        device (str, optional): Device to run the benchmark on, either 'cpu' or 'cuda'. Default is 'cpu'.
        verbose (bool | float, optional): If True or a float, assert benchmarks pass with the given metric.
            Default is False.

    Returns:
        df (pandas.DataFrame): A pandas DataFrame with benchmark results for each format, including file size,
            metric, and inference time.

    Example:
        ```python
        from ultralytics.utils.benchmarks import benchmark

        benchmark(model='yolov8n.pt', imgsz=640)
        ```
    """

    import pandas as pd

    pd.options.display.max_columns = 10
    pd.options.display.width = 120
    device = select_device(device, verbose=False)
    if isinstance(model, (str, Path)):
        model = YOLO(model)

    y = []
    t0 = time.time()
    for i, (name, format, suffix, cpu, gpu) in export_formats().iterrows():  # index, (name, format, suffix, CPU, GPU)
        emoji, filename = "❌", None  # export defaults
        try:
            # Checks
            if i == 9:  # Edge TPU
                assert LINUX, "Edge TPU export only supported on Linux"
            elif i == 7:  # TF GraphDef
                assert model.task != "obb", "TensorFlow GraphDef not supported for OBB task"
            elif i in {5, 10}:  # CoreML and TF.js
                assert MACOS or LINUX, "export only supported on macOS and Linux"
            if i in {3, 5}:  # CoreML and OpenVINO
                assert not IS_PYTHON_3_12, "CoreML and OpenVINO not supported on Python 3.12"
            if i in {6, 7, 8, 9, 10}:  # All TF formats
                assert not isinstance(model, YOLOWorld), "YOLOWorldv2 TensorFlow exports not supported by onnx2tf yet"
            if i in {11}:  # Paddle
                assert not isinstance(model, YOLOWorld), "YOLOWorldv2 Paddle exports not supported yet"
            if i in {12}:  # NCNN
                assert not isinstance(model, YOLOWorld), "YOLOWorldv2 NCNN exports not supported yet"
            if "cpu" in device.type:
                assert cpu, "inference not supported on CPU"
            if "cuda" in device.type:
                assert gpu, "inference not supported on GPU"

            # Export
            if format == "-":
                filename = model.ckpt_path or model.cfg
                exported_model = model  # PyTorch format
            else:
                filename = model.export(imgsz=imgsz, format=format, half=half, int8=int8, device=device, verbose=False)
                exported_model = YOLO(filename, task=model.task)
            assert suffix in str(filename), "export failed"
            emoji = "❎"  # indicates export succeeded

            # Predict
            assert model.task != "pose" or i != 7, "GraphDef Pose inference is not supported"
            assert i not in (9, 10), "inference not supported"  # Edge TPU and TF.js are unsupported
            assert i != 5 or platform.system() == "Darwin", "inference only supported on macOS>=10.13"  # CoreML
            exported_model.predict(ASSETS / "bus.jpg", imgsz=imgsz, device=device, half=half)

            # Validate
            data = data or TASK2DATA[model.task]  # task to dataset, i.e. coco8.yaml for task=detect
            key = TASK2METRIC[model.task]  # task to metric, i.e. metrics/mAP50-95(B) for task=detect
            results = exported_model.val(
                data=data, batch=1, imgsz=imgsz, plots=False, device=device, half=half, int8=int8, verbose=False
            )
            metric, speed = results.results_dict[key], results.speed["inference"]
            y.append([name, "✅", round(file_size(filename), 1), round(metric, 4), round(speed, 2)])
        except Exception as e:
            if verbose:
                assert type(e) is AssertionError, f"Benchmark failure for {name}: {e}"
            LOGGER.warning(f"ERROR ❌️ Benchmark failure for {name}: {e}")
            y.append([name, emoji, round(file_size(filename), 1), None, None])  # mAP, t_inference

    # Print results
    check_yolo(device=device)  # print system info
    df = pd.DataFrame(y, columns=["Format", "Status❔", "Size (MB)", key, "Inference time (ms/im)"])

    name = Path(model.ckpt_path).name
    s = f"\nBenchmarks complete for {name} on {data} at imgsz={imgsz} ({time.time() - t0:.2f}s)\n{df}\n"
    LOGGER.info(s)
    with open("benchmarks.log", "a", errors="ignore", encoding="utf-8") as f:
        f.write(s)

    if verbose and isinstance(verbose, float):
        metrics = df[key].array  # values to compare to floor
        floor = verbose  # minimum metric floor to pass, i.e. = 0.29 mAP for YOLOv5n
        assert all(x > floor for x in metrics if pd.notna(x)), f"Benchmark failure: metric(s) < floor {floor}"

    return df


class ProfileModels:
    """
    ProfileModels class for profiling different models on ONNX and TensorRT.

    This class profiles the performance of different models, returning results such as model speed and FLOPs.

    Attributes:
        paths (list): Paths of the models to profile.
        num_timed_runs (int): Number of timed runs for the profiling. Default is 100.
        num_warmup_runs (int): Number of warmup runs before profiling. Default is 10.
        min_time (float): Minimum number of seconds to profile for. Default is 60.
        imgsz (int): Image size used in the models. Default is 640.

    Methods:
        profile(): Profiles the models and prints the result.

    Example:
        ```python
        from ultralytics.utils.benchmarks import ProfileModels

        ProfileModels(['yolov8n.yaml', 'yolov8s.yaml'], imgsz=640).profile()
        ```
    """

    def __init__(
        self,
        paths: list,
        num_timed_runs=100,
        num_warmup_runs=10,
        min_time=60,
        imgsz=640,
        half=True,
        trt=True,
        device=None,
    ):
        """
        Initialize the ProfileModels class for profiling models.

        Args:
            paths (list): List of paths of the models to be profiled.
            num_timed_runs (int, optional): Number of timed runs for the profiling. Default is 100.
            num_warmup_runs (int, optional): Number of warmup runs before the actual profiling starts. Default is 10.
            min_time (float, optional): Minimum time in seconds for profiling a model. Default is 60.
            imgsz (int, optional): Size of the image used during profiling. Default is 640.
            half (bool, optional): Flag to indicate whether to use half-precision floating point for profiling.
            trt (bool, optional): Flag to indicate whether to profile using TensorRT. Default is True.
            device (torch.device, optional): Device used for profiling. If None, it is determined automatically.
        """
        self.paths = paths
        self.num_timed_runs = num_timed_runs
        self.num_warmup_runs = num_warmup_runs
        self.min_time = min_time
        self.imgsz = imgsz
        self.half = half
        self.trt = trt  # run TensorRT profiling
        self.device = device or torch.device(0 if torch.cuda.is_available() else "cpu")

    def profile(self):
        """Logs the benchmarking results of a model, checks metrics against floor and returns the results."""
        files = self.get_files()

        if not files:
            print("No matching *.pt or *.onnx files found.")
            return

        table_rows = []
        output = []
        for file in files:
            engine_file = file.with_suffix(".engine")
            if file.suffix in (".pt", ".yaml", ".yml"):
                model = YOLO(str(file))
                model.fuse()  # to report correct params and GFLOPs in model.info()
                model_info = model.info()
                if self.trt and self.device.type != "cpu" and not engine_file.is_file():
                    engine_file = model.export(
                        format="engine", half=self.half, imgsz=self.imgsz, device=self.device, verbose=False
                    )
                onnx_file = model.export(
                    format="onnx", half=self.half, imgsz=self.imgsz, simplify=True, device=self.device, verbose=False
                )
            elif file.suffix == ".onnx":
                model_info = self.get_onnx_model_info(file)
                onnx_file = file
            else:
                continue

            t_engine = self.profile_tensorrt_model(str(engine_file))
            t_onnx = self.profile_onnx_model(str(onnx_file))
            table_rows.append(self.generate_table_row(file.stem, t_onnx, t_engine, model_info))
            output.append(self.generate_results_dict(file.stem, t_onnx, t_engine, model_info))

        self.print_table(table_rows)
        return output

    def get_files(self):
        """Returns a list of paths for all relevant model files given by the user."""
        files = []
        for path in self.paths:
            path = Path(path)
            if path.is_dir():
                extensions = ["*.pt", "*.onnx", "*.yaml"]
                files.extend([file for ext in extensions for file in glob.glob(str(path / ext))])
            elif path.suffix in {".pt", ".yaml", ".yml"}:  # add non-existing
                files.append(str(path))
            else:
                files.extend(glob.glob(str(path)))

        print(f"Profiling: {sorted(files)}")
        return [Path(file) for file in sorted(files)]

    def get_onnx_model_info(self, onnx_file: str):
        """Retrieves information (number of layers, parameters, gradients and FLOPs) for an ONNX model file."""
        return 0.0, 0.0, 0.0, 0.0  # return (num_layers, num_params, num_gradients, num_flops)

    @staticmethod
    def iterative_sigma_clipping(data, sigma=2, max_iters=3):
        """Applies an iterative sigma clipping algorithm to the given data for a number of iterations."""
        data = np.array(data)
        for _ in range(max_iters):
            mean, std = np.mean(data), np.std(data)
            clipped_data = data[(data > mean - sigma * std) & (data < mean + sigma * std)]
            if len(clipped_data) == len(data):
                break
            data = clipped_data
        return data
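A tiny sketch (not part of the diff) of what `iterative_sigma_clipping` does: values more than `sigma` standard deviations from the mean are dropped, repeatedly, until the data stops changing or `max_iters` is reached. The timing values below are made up:

```python
import numpy as np

from ultralytics.utils.benchmarks import ProfileModels

times = [5.1, 5.0, 5.2, 4.9, 25.0]  # hypothetical per-run inference times in ms, with one outlier
clipped = ProfileModels.iterative_sigma_clipping(np.array(times), sigma=2, max_iters=3)
print(clipped)  # the 25.0 ms outlier falls outside the 2-sigma band and is removed
```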
    def profile_tensorrt_model(self, engine_file: str, eps: float = 1e-3):
        """Profiles the TensorRT model, measuring average run time and standard deviation among runs."""
        if not self.trt or not Path(engine_file).is_file():
            return 0.0, 0.0

        # Model and input
        model = YOLO(engine_file)
        input_data = np.random.rand(self.imgsz, self.imgsz, 3).astype(np.float32)  # must be FP32

        # Warmup runs
        elapsed = 0.0
        for _ in range(3):
            start_time = time.time()
            for _ in range(self.num_warmup_runs):
                model(input_data, imgsz=self.imgsz, verbose=False)
            elapsed = time.time() - start_time

        # Compute number of runs as higher of min_time or num_timed_runs
        num_runs = max(round(self.min_time / (elapsed + eps) * self.num_warmup_runs), self.num_timed_runs * 50)

        # Timed runs
        run_times = []
        for _ in TQDM(range(num_runs), desc=engine_file):
            results = model(input_data, imgsz=self.imgsz, verbose=False)
            run_times.append(results[0].speed["inference"])  # inference time, already in milliseconds

        run_times = self.iterative_sigma_clipping(np.array(run_times), sigma=2, max_iters=3)  # sigma clipping
        return np.mean(run_times), np.std(run_times)

    def profile_onnx_model(self, onnx_file: str, eps: float = 1e-3):
        """Profiles an ONNX model by executing it multiple times; returns the mean and standard deviation of run times."""
        check_requirements("onnxruntime")
        import onnxruntime as ort

        # Session with either 'TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'
        sess_options = ort.SessionOptions()
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        sess_options.intra_op_num_threads = 8  # Limit the number of threads
        sess = ort.InferenceSession(onnx_file, sess_options, providers=["CPUExecutionProvider"])

        input_tensor = sess.get_inputs()[0]
        input_type = input_tensor.type

        # Mapping ONNX datatype to numpy datatype
        if "float16" in input_type:
            input_dtype = np.float16
        elif "float" in input_type:
            input_dtype = np.float32
        elif "double" in input_type:
            input_dtype = np.float64
        elif "int64" in input_type:
            input_dtype = np.int64
        elif "int32" in input_type:
            input_dtype = np.int32
        else:
            raise ValueError(f"Unsupported ONNX datatype {input_type}")

        input_data = np.random.rand(*input_tensor.shape).astype(input_dtype)
        input_name = input_tensor.name
        output_name = sess.get_outputs()[0].name

        # Warmup runs
        elapsed = 0.0
        for _ in range(3):
            start_time = time.time()
            for _ in range(self.num_warmup_runs):
                sess.run([output_name], {input_name: input_data})
            elapsed = time.time() - start_time

        # Compute number of runs as higher of min_time or num_timed_runs
        num_runs = max(round(self.min_time / (elapsed + eps) * self.num_warmup_runs), self.num_timed_runs)

        # Timed runs
        run_times = []
        for _ in TQDM(range(num_runs), desc=onnx_file):
            start_time = time.time()
            sess.run([output_name], {input_name: input_data})
            run_times.append((time.time() - start_time) * 1000)  # Convert to milliseconds

        run_times = self.iterative_sigma_clipping(np.array(run_times), sigma=2, max_iters=5)  # sigma clipping
        return np.mean(run_times), np.std(run_times)

    def generate_table_row(self, model_name, t_onnx, t_engine, model_info):
        """Generates a formatted string for a table row that includes model performance and metric details."""
        layers, params, gradients, flops = model_info
        return (
            f"| {model_name:18s} | {self.imgsz} | - | {t_onnx[0]:.2f} ± {t_onnx[1]:.2f} ms | {t_engine[0]:.2f} ± "
            f"{t_engine[1]:.2f} ms | {params / 1e6:.1f} | {flops:.1f} |"
        )

    @staticmethod
    def generate_results_dict(model_name, t_onnx, t_engine, model_info):
        """Generates a dictionary of model details including name, parameters, GFLOPs and speed metrics."""
        layers, params, gradients, flops = model_info
        return {
            "model/name": model_name,
            "model/parameters": params,
            "model/GFLOPs": round(flops, 3),
            "model/speed_ONNX(ms)": round(t_onnx[0], 3),
            "model/speed_TensorRT(ms)": round(t_engine[0], 3),
        }

    @staticmethod
    def print_table(table_rows):
        """Formats and prints a comparison table for different models with given statistics and performance data."""
        gpu = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "GPU"
        header = (
            f"| Model | size<br><sup>(pixels) | mAP<sup>val<br>50-95 | Speed<br><sup>CPU ONNX<br>(ms) | "
            f"Speed<br><sup>{gpu} TensorRT<br>(ms) | params<br><sup>(M) | FLOPs<br><sup>(B) |"
        )
        separator = (
            "|-------------|---------------------|--------------------|------------------------------|"
            "-----------------------------------|------------------|-----------------|"
        )

        print(f"\n\n{header}")
        print(separator)
        for row in table_rows:
            print(row)
5 ultralytics/utils/callbacks/__init__.py Normal file
@@ -0,0 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

from .base import add_integration_callbacks, default_callbacks, get_default_callbacks

__all__ = "add_integration_callbacks", "default_callbacks", "get_default_callbacks"
219 ultralytics/utils/callbacks/base.py Normal file
@@ -0,0 +1,219 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""Base callbacks."""

from collections import defaultdict
from copy import deepcopy


# Trainer callbacks ----------------------------------------------------------------------------------------------------


def on_pretrain_routine_start(trainer):
    """Called before the pretraining routine starts."""
    pass


def on_pretrain_routine_end(trainer):
    """Called after the pretraining routine ends."""
    pass


def on_train_start(trainer):
    """Called when the training starts."""
    pass


def on_train_epoch_start(trainer):
    """Called at the start of each training epoch."""
    pass


def on_train_batch_start(trainer):
    """Called at the start of each training batch."""
    pass


def optimizer_step(trainer):
    """Called when the optimizer takes a step."""
    pass


def on_before_zero_grad(trainer):
    """Called before the gradients are set to zero."""
    pass


def on_train_batch_end(trainer):
    """Called at the end of each training batch."""
    pass


def on_train_epoch_end(trainer):
    """Called at the end of each training epoch."""
    pass


def on_fit_epoch_end(trainer):
    """Called at the end of each fit epoch (train + val)."""
    pass


def on_model_save(trainer):
    """Called when the model is saved."""
    pass


def on_train_end(trainer):
    """Called when the training ends."""
    pass


def on_params_update(trainer):
    """Called when the model parameters are updated."""
    pass


def teardown(trainer):
    """Called during the teardown of the training process."""
    pass


# Validator callbacks --------------------------------------------------------------------------------------------------


def on_val_start(validator):
    """Called when the validation starts."""
    pass


def on_val_batch_start(validator):
    """Called at the start of each validation batch."""
    pass


def on_val_batch_end(validator):
    """Called at the end of each validation batch."""
    pass


def on_val_end(validator):
    """Called when the validation ends."""
    pass


# Predictor callbacks --------------------------------------------------------------------------------------------------


def on_predict_start(predictor):
    """Called when the prediction starts."""
    pass


def on_predict_batch_start(predictor):
    """Called at the start of each prediction batch."""
    pass


def on_predict_batch_end(predictor):
    """Called at the end of each prediction batch."""
    pass


def on_predict_postprocess_end(predictor):
    """Called after the post-processing of the prediction ends."""
    pass


def on_predict_end(predictor):
    """Called when the prediction ends."""
    pass


# Exporter callbacks ---------------------------------------------------------------------------------------------------


def on_export_start(exporter):
    """Called when the model export starts."""
    pass


def on_export_end(exporter):
    """Called when the model export ends."""
    pass


default_callbacks = {
    # Run in trainer
    "on_pretrain_routine_start": [on_pretrain_routine_start],
    "on_pretrain_routine_end": [on_pretrain_routine_end],
    "on_train_start": [on_train_start],
    "on_train_epoch_start": [on_train_epoch_start],
    "on_train_batch_start": [on_train_batch_start],
    "optimizer_step": [optimizer_step],
    "on_before_zero_grad": [on_before_zero_grad],
    "on_train_batch_end": [on_train_batch_end],
    "on_train_epoch_end": [on_train_epoch_end],
    "on_fit_epoch_end": [on_fit_epoch_end],  # fit = train + val
    "on_model_save": [on_model_save],
    "on_train_end": [on_train_end],
    "on_params_update": [on_params_update],
    "teardown": [teardown],
    # Run in validator
    "on_val_start": [on_val_start],
    "on_val_batch_start": [on_val_batch_start],
    "on_val_batch_end": [on_val_batch_end],
    "on_val_end": [on_val_end],
    # Run in predictor
    "on_predict_start": [on_predict_start],
    "on_predict_batch_start": [on_predict_batch_start],
    "on_predict_postprocess_end": [on_predict_postprocess_end],
    "on_predict_batch_end": [on_predict_batch_end],
    "on_predict_end": [on_predict_end],
    # Run in exporter
    "on_export_start": [on_export_start],
    "on_export_end": [on_export_end],
}


def get_default_callbacks():
    """
    Return a copy of the default_callbacks dictionary with lists as default values.

    Returns:
        (defaultdict): A defaultdict with keys from default_callbacks and empty lists as default values.
    """
    return defaultdict(list, deepcopy(default_callbacks))


def add_integration_callbacks(instance):
    """
    Add integration callbacks from various sources to the instance's callbacks.

    Args:
        instance (Trainer, Predictor, Validator, Exporter): An object with a 'callbacks' attribute that is a dictionary
            of callback lists.
    """

    # Load HUB callbacks
    from .hub import callbacks as hub_cb

    callbacks_list = [hub_cb]

    # Load training callbacks
    if "Trainer" in instance.__class__.__name__:
        from .clearml import callbacks as clear_cb
        from .comet import callbacks as comet_cb
        from .dvc import callbacks as dvc_cb
        from .mlflow import callbacks as mlflow_cb
        from .neptune import callbacks as neptune_cb
        from .raytune import callbacks as tune_cb
        from .tensorboard import callbacks as tb_cb
        from .wb import callbacks as wb_cb

        callbacks_list.extend([clear_cb, comet_cb, dvc_cb, mlflow_cb, neptune_cb, tune_cb, tb_cb, wb_cb])

    # Add the callbacks to the callbacks dictionary
    for callbacks in callbacks_list:
        for k, v in callbacks.items():
            if v not in instance.callbacks[k]:
                instance.callbacks[k].append(v)
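For context (not part of the diff), a minimal sketch of hooking a custom function into this callback system through the standard `Model.add_callback` API; `log_epoch` is a hypothetical user function:

```python
from ultralytics import YOLO


def log_epoch(trainer):
    """Hypothetical user callback: print a one-line summary after every training epoch."""
    print(f"epoch {trainer.epoch} finished")


model = YOLO("yolov8n.pt")
model.add_callback("on_train_epoch_end", log_epoch)  # appended alongside the defaults above
```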
152 ultralytics/utils/callbacks/clearml.py Normal file
@@ -0,0 +1,152 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING

try:
    assert not TESTS_RUNNING  # do not log pytest
    assert SETTINGS["clearml"] is True  # verify integration is enabled
    import clearml
    from clearml import Task
    from clearml.binding.frameworks.pytorch_bind import PatchPyTorchModelIO
    from clearml.binding.matplotlib_bind import PatchedMatplotlib

    assert hasattr(clearml, "__version__")  # verify package is not directory

except (ImportError, AssertionError):
    clearml = None


def _log_debug_samples(files, title="Debug Samples") -> None:
    """
    Log files (images) as debug samples in the ClearML task.

    Args:
        files (list): A list of file paths in PosixPath format.
        title (str): A title that groups together images with the same values.
    """
    import re

    if task := Task.current_task():
        for f in files:
            if f.exists():
                it = re.search(r"_batch(\d+)", f.name)
                iteration = int(it.groups()[0]) if it else 0
                task.get_logger().report_image(
                    title=title, series=f.name.replace(it.group(), ""), local_path=str(f), iteration=iteration
                )


def _log_plot(title, plot_path) -> None:
    """
    Log an image as a plot in the plot section of ClearML.

    Args:
        title (str): The title of the plot.
        plot_path (str): The path to the saved image file.
    """
    import matplotlib.image as mpimg
    import matplotlib.pyplot as plt

    img = mpimg.imread(plot_path)
    fig = plt.figure()
    ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect="auto", xticks=[], yticks=[])  # no ticks
    ax.imshow(img)

    Task.current_task().get_logger().report_matplotlib_figure(
        title=title, series="", figure=fig, report_interactive=False
    )


def on_pretrain_routine_start(trainer):
    """Runs at the start of the pretraining routine; initializes and connects/logs the task to ClearML."""
    try:
        if task := Task.current_task():
            # Make sure the automatic pytorch and matplotlib bindings are disabled!
            # We are logging these plots and model files manually in the integration
            PatchPyTorchModelIO.update_current_task(None)
            PatchedMatplotlib.update_current_task(None)
        else:
            task = Task.init(
                project_name=trainer.args.project or "YOLOv8",
                task_name=trainer.args.name,
                tags=["YOLOv8"],
                output_uri=True,
                reuse_last_task_id=False,
                auto_connect_frameworks={"pytorch": False, "matplotlib": False},
            )
            LOGGER.warning(
                "ClearML Initialized a new task. If you want to run remotely, "
                "please add clearml-init and connect your arguments before initializing YOLO."
            )
        task.connect(vars(trainer.args), name="General")
    except Exception as e:
        LOGGER.warning(f"WARNING ⚠️ ClearML installed but not initialized correctly, not logging this run. {e}")


def on_train_epoch_end(trainer):
    """Logs debug samples for the first epoch of YOLO training and reports current training progress."""
    if task := Task.current_task():
        # Log debug samples
        if trainer.epoch == 1:
            _log_debug_samples(sorted(trainer.save_dir.glob("train_batch*.jpg")), "Mosaic")
        # Report the current training progress
        for k, v in trainer.label_loss_items(trainer.tloss, prefix="train").items():
            task.get_logger().report_scalar("train", k, v, iteration=trainer.epoch)
        for k, v in trainer.lr.items():
            task.get_logger().report_scalar("lr", k, v, iteration=trainer.epoch)


def on_fit_epoch_end(trainer):
    """Reports model information to the logger at the end of an epoch."""
    if task := Task.current_task():
        # You should have access to the validation bboxes under jdict
        task.get_logger().report_scalar(
            title="Epoch Time", series="Epoch Time", value=trainer.epoch_time, iteration=trainer.epoch
        )
        for k, v in trainer.metrics.items():
            task.get_logger().report_scalar("val", k, v, iteration=trainer.epoch)
        if trainer.epoch == 0:
            from ultralytics.utils.torch_utils import model_info_for_loggers

            for k, v in model_info_for_loggers(trainer).items():
                task.get_logger().report_single_value(k, v)


def on_val_end(validator):
    """Logs validation results including labels and predictions."""
    if Task.current_task():
        # Log val_labels and val_pred
        _log_debug_samples(sorted(validator.save_dir.glob("val*.jpg")), "Validation")


def on_train_end(trainer):
    """Logs the final model and its name on training completion."""
    if task := Task.current_task():
        # Log final results, CM matrix + PR plots
        files = [
            "results.png",
            "confusion_matrix.png",
            "confusion_matrix_normalized.png",
            *(f"{x}_curve.png" for x in ("F1", "PR", "P", "R")),
        ]
        files = [(trainer.save_dir / f) for f in files if (trainer.save_dir / f).exists()]  # filter
        for f in files:
            _log_plot(title=f.stem, plot_path=f)
        # Report final metrics
        for k, v in trainer.validator.metrics.results_dict.items():
            task.get_logger().report_single_value(k, v)
        # Log the final model
        task.update_output_model(model_path=str(trainer.best), model_name=trainer.args.name, auto_delete_file=False)


callbacks = (
    {
        "on_pretrain_routine_start": on_pretrain_routine_start,
        "on_train_epoch_end": on_train_epoch_end,
        "on_fit_epoch_end": on_fit_epoch_end,
        "on_val_end": on_val_end,
        "on_train_end": on_train_end,
    }
    if clearml
    else {}
)
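The try/except guard above only activates this logger when the 'clearml' setting is enabled. A minimal sketch of toggling it (assuming the `SETTINGS.update` API; the CLI equivalent is `yolo settings clearml=True`):

```python
from ultralytics.utils import SETTINGS

SETTINGS.update({"clearml": True})  # persisted to the settings file; the callbacks dict above is then non-empty
```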
375 ultralytics/utils/callbacks/comet.py Normal file
@@ -0,0 +1,375 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

from ultralytics.utils import LOGGER, RANK, SETTINGS, TESTS_RUNNING, ops

try:
    assert not TESTS_RUNNING  # do not log pytest
    assert SETTINGS["comet"] is True  # verify integration is enabled
    import comet_ml

    assert hasattr(comet_ml, "__version__")  # verify package is not directory

    import os
    from pathlib import Path

    # Ensures certain logging functions only run for supported tasks
    COMET_SUPPORTED_TASKS = ["detect"]

    # Names of plots created by YOLOv8 that are logged to Comet
    EVALUATION_PLOT_NAMES = "F1_curve", "P_curve", "R_curve", "PR_curve", "confusion_matrix"
    LABEL_PLOT_NAMES = "labels", "labels_correlogram"

    _comet_image_prediction_count = 0

except (ImportError, AssertionError):
    comet_ml = None


def _get_comet_mode():
    """Returns the Comet mode set in the environment variables, defaulting to 'online' if not set."""
    return os.getenv("COMET_MODE", "online")


def _get_comet_model_name():
    """Returns the model name for Comet from the environment variable 'COMET_MODEL_NAME' or defaults to 'YOLOv8'."""
    return os.getenv("COMET_MODEL_NAME", "YOLOv8")


def _get_eval_batch_logging_interval():
    """Get the evaluation batch logging interval from environment variable or use default value 1."""
    return int(os.getenv("COMET_EVAL_BATCH_LOGGING_INTERVAL", 1))


def _get_max_image_predictions_to_log():
    """Get the maximum number of image predictions to log from the environment variables."""
    return int(os.getenv("COMET_MAX_IMAGE_PREDICTIONS", 100))


def _scale_confidence_score(score):
    """Scales the given confidence score by a factor specified in an environment variable."""
    scale = float(os.getenv("COMET_MAX_CONFIDENCE_SCORE", 100.0))
    return score * scale


def _should_log_confusion_matrix():
    """Determines if the confusion matrix should be logged based on the environment variable settings."""
    return os.getenv("COMET_EVAL_LOG_CONFUSION_MATRIX", "false").lower() == "true"


def _should_log_image_predictions():
    """Determines whether to log image predictions based on a specified environment variable."""
    return os.getenv("COMET_EVAL_LOG_IMAGE_PREDICTIONS", "true").lower() == "true"


def _get_experiment_type(mode, project_name):
    """Return an experiment based on mode and project name."""
    if mode == "offline":
        return comet_ml.OfflineExperiment(project_name=project_name)

    return comet_ml.Experiment(project_name=project_name)


def _create_experiment(args):
    """Ensures that the experiment object is only created in a single process during distributed training."""
    if RANK not in (-1, 0):
        return
    try:
        comet_mode = _get_comet_mode()
        _project_name = os.getenv("COMET_PROJECT_NAME", args.project)
        experiment = _get_experiment_type(comet_mode, _project_name)
        experiment.log_parameters(vars(args))
        experiment.log_others(
            {
                "eval_batch_logging_interval": _get_eval_batch_logging_interval(),
                "log_confusion_matrix_on_eval": _should_log_confusion_matrix(),
                "log_image_predictions": _should_log_image_predictions(),
                "max_image_predictions": _get_max_image_predictions_to_log(),
            }
        )
        experiment.log_other("Created from", "yolov8")

    except Exception as e:
        LOGGER.warning(f"WARNING ⚠️ Comet installed but not initialized correctly, not logging this run. {e}")


def _fetch_trainer_metadata(trainer):
    """Returns metadata for YOLO training including epoch and asset saving status."""
    curr_epoch = trainer.epoch + 1

    train_num_steps_per_epoch = len(trainer.train_loader.dataset) // trainer.batch_size
    curr_step = curr_epoch * train_num_steps_per_epoch
    final_epoch = curr_epoch == trainer.epochs

    save = trainer.args.save
    save_period = trainer.args.save_period
    save_interval = curr_epoch % save_period == 0
    save_assets = save and save_period > 0 and save_interval and not final_epoch

    return dict(curr_epoch=curr_epoch, curr_step=curr_step, save_assets=save_assets, final_epoch=final_epoch)


def _scale_bounding_box_to_original_image_shape(box, resized_image_shape, original_image_shape, ratio_pad):
    """
    YOLOv8 resizes images during training, and the label values are normalized based on this resized shape.

    This function rescales the bounding box labels to the original image shape.
    """

    resized_image_height, resized_image_width = resized_image_shape

    # Convert normalized xywh format predictions to xyxy in resized scale format
    box = ops.xywhn2xyxy(box, h=resized_image_height, w=resized_image_width)
    # Scale box predictions from resized image scale back to original image scale
    box = ops.scale_boxes(resized_image_shape, box, original_image_shape, ratio_pad)
    # Convert bounding box format from xyxy to xywh for Comet logging
    box = ops.xyxy2xywh(box)
    # Adjust the xy center to correspond to the top-left corner
    box[:2] -= box[2:] / 2
    box = box.tolist()

    return box


def _format_ground_truth_annotations_for_detection(img_idx, image_path, batch, class_name_map=None):
    """Format ground truth annotations for detection."""
    indices = batch["batch_idx"] == img_idx
    bboxes = batch["bboxes"][indices]
    if len(bboxes) == 0:
        LOGGER.debug(f"COMET WARNING: Image: {image_path} has no bounding boxes labels")
        return None

    cls_labels = batch["cls"][indices].squeeze(1).tolist()
    if class_name_map:
        cls_labels = [str(class_name_map[label]) for label in cls_labels]

    original_image_shape = batch["ori_shape"][img_idx]
    resized_image_shape = batch["resized_shape"][img_idx]
    ratio_pad = batch["ratio_pad"][img_idx]

    data = []
    for box, label in zip(bboxes, cls_labels):
        box = _scale_bounding_box_to_original_image_shape(box, resized_image_shape, original_image_shape, ratio_pad)
        data.append(
            {
                "boxes": [box],
                "label": f"gt_{label}",
                "score": _scale_confidence_score(1.0),
            }
        )

    return {"name": "ground_truth", "data": data}


def _format_prediction_annotations_for_detection(image_path, metadata, class_label_map=None):
    """Format YOLO predictions for object detection visualization."""
    stem = image_path.stem
    image_id = int(stem) if stem.isnumeric() else stem

    predictions = metadata.get(image_id)
    if not predictions:
        LOGGER.debug(f"COMET WARNING: Image: {image_path} has no bounding boxes predictions")
        return None

    data = []
    for prediction in predictions:
        boxes = prediction["bbox"]
        score = _scale_confidence_score(prediction["score"])
        cls_label = prediction["category_id"]
        if class_label_map:
            cls_label = str(class_label_map[cls_label])

        data.append({"boxes": [boxes], "label": cls_label, "score": score})

    return {"name": "prediction", "data": data}


def _fetch_annotations(img_idx, image_path, batch, prediction_metadata_map, class_label_map):
    """Join the ground truth and prediction annotations if they exist."""
    ground_truth_annotations = _format_ground_truth_annotations_for_detection(
        img_idx, image_path, batch, class_label_map
    )
    prediction_annotations = _format_prediction_annotations_for_detection(
        image_path, prediction_metadata_map, class_label_map
    )

    annotations = [
        annotation for annotation in [ground_truth_annotations, prediction_annotations] if annotation is not None
    ]
    return [annotations] if annotations else None


def _create_prediction_metadata_map(model_predictions):
    """Create a metadata map for model predictions by grouping them based on image ID."""
    pred_metadata_map = {}
    for prediction in model_predictions:
        pred_metadata_map.setdefault(prediction["image_id"], [])
        pred_metadata_map[prediction["image_id"]].append(prediction)

    return pred_metadata_map


def _log_confusion_matrix(experiment, trainer, curr_step, curr_epoch):
    """Log the confusion matrix to the Comet experiment."""
    conf_mat = trainer.validator.confusion_matrix.matrix
    names = list(trainer.data["names"].values()) + ["background"]
    experiment.log_confusion_matrix(
        matrix=conf_mat, labels=names, max_categories=len(names), epoch=curr_epoch, step=curr_step
    )


def _log_images(experiment, image_paths, curr_step, annotations=None):
    """Logs images to the experiment with optional annotations."""
    if annotations:
        for image_path, annotation in zip(image_paths, annotations):
            experiment.log_image(image_path, name=image_path.stem, step=curr_step, annotations=annotation)

    else:
        for image_path in image_paths:
            experiment.log_image(image_path, name=image_path.stem, step=curr_step)


def _log_image_predictions(experiment, validator, curr_step):
    """Logs predicted boxes for a single image during training."""
    global _comet_image_prediction_count

    task = validator.args.task
    if task not in COMET_SUPPORTED_TASKS:
        return

    jdict = validator.jdict
    if not jdict:
        return

    predictions_metadata_map = _create_prediction_metadata_map(jdict)
    dataloader = validator.dataloader
    class_label_map = validator.names

    batch_logging_interval = _get_eval_batch_logging_interval()
    max_image_predictions = _get_max_image_predictions_to_log()

    for batch_idx, batch in enumerate(dataloader):
        if (batch_idx + 1) % batch_logging_interval != 0:
            continue

        image_paths = batch["im_file"]
        for img_idx, image_path in enumerate(image_paths):
            if _comet_image_prediction_count >= max_image_predictions:
                return

            image_path = Path(image_path)
            annotations = _fetch_annotations(
                img_idx,
                image_path,
                batch,
                predictions_metadata_map,
                class_label_map,
            )
            _log_images(
                experiment,
                [image_path],
                curr_step,
                annotations=annotations,
            )
            _comet_image_prediction_count += 1


def _log_plots(experiment, trainer):
    """Logs evaluation plots and label plots for the experiment."""
    plot_filenames = [trainer.save_dir / f"{plots}.png" for plots in EVALUATION_PLOT_NAMES]
    _log_images(experiment, plot_filenames, None)

    label_plot_filenames = [trainer.save_dir / f"{labels}.jpg" for labels in LABEL_PLOT_NAMES]
    _log_images(experiment, label_plot_filenames, None)


def _log_model(experiment, trainer):
    """Log the best-trained model to Comet.ml."""
    model_name = _get_comet_model_name()
    experiment.log_model(model_name, file_or_folder=str(trainer.best), file_name="best_gift_v10n.pt", overwrite=True)


def on_pretrain_routine_start(trainer):
    """Creates or resumes a CometML experiment at the start of a YOLO pre-training routine."""
    experiment = comet_ml.get_global_experiment()
    is_alive = getattr(experiment, "alive", False)
    if not experiment or not is_alive:
        _create_experiment(trainer.args)


def on_train_epoch_end(trainer):
    """Log metrics and save batch images at the end of training epochs."""
    experiment = comet_ml.get_global_experiment()
    if not experiment:
        return

    metadata = _fetch_trainer_metadata(trainer)
    curr_epoch = metadata["curr_epoch"]
    curr_step = metadata["curr_step"]

    experiment.log_metrics(trainer.label_loss_items(trainer.tloss, prefix="train"), step=curr_step, epoch=curr_epoch)

    if curr_epoch == 1:
        _log_images(experiment, trainer.save_dir.glob("train_batch*.jpg"), curr_step)


def on_fit_epoch_end(trainer):
    """Logs model assets at the end of each epoch."""
    experiment = comet_ml.get_global_experiment()
    if not experiment:
        return

    metadata = _fetch_trainer_metadata(trainer)
    curr_epoch = metadata["curr_epoch"]
    curr_step = metadata["curr_step"]
    save_assets = metadata["save_assets"]

    experiment.log_metrics(trainer.metrics, step=curr_step, epoch=curr_epoch)
    experiment.log_metrics(trainer.lr, step=curr_step, epoch=curr_epoch)
    if curr_epoch == 1:
        from ultralytics.utils.torch_utils import model_info_for_loggers

        experiment.log_metrics(model_info_for_loggers(trainer), step=curr_step, epoch=curr_epoch)

    if not save_assets:
        return

    _log_model(experiment, trainer)
    if _should_log_confusion_matrix():
        _log_confusion_matrix(experiment, trainer, curr_step, curr_epoch)
    if _should_log_image_predictions():
        _log_image_predictions(experiment, trainer.validator, curr_step)


def on_train_end(trainer):
    """Perform operations at the end of training."""
    experiment = comet_ml.get_global_experiment()
    if not experiment:
        return

    metadata = _fetch_trainer_metadata(trainer)
    curr_epoch = metadata["curr_epoch"]
    curr_step = metadata["curr_step"]
    plots = trainer.args.plots

    _log_model(experiment, trainer)
    if plots:
        _log_plots(experiment, trainer)

    _log_confusion_matrix(experiment, trainer, curr_step, curr_epoch)
    _log_image_predictions(experiment, trainer.validator, curr_step)
    experiment.end()

    global _comet_image_prediction_count
    _comet_image_prediction_count = 0


callbacks = (
    {
        "on_pretrain_routine_start": on_pretrain_routine_start,
        "on_train_epoch_end": on_train_epoch_end,
        "on_fit_epoch_end": on_fit_epoch_end,
        "on_train_end": on_train_end,
    }
    if comet_ml
    else {}
)
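The Comet callback is configured entirely through environment variables read by the helper functions above. A minimal sketch (not part of the diff) of setting a few of them before training:

```python
import os

os.environ["COMET_MODE"] = "offline"  # use comet_ml.OfflineExperiment instead of online logging
os.environ["COMET_EVAL_BATCH_LOGGING_INTERVAL"] = "2"  # log every 2nd evaluation batch
os.environ["COMET_MAX_IMAGE_PREDICTIONS"] = "50"  # cap the number of logged prediction images
```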
145 ultralytics/utils/callbacks/dvc.py Normal file
@@ -0,0 +1,145 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, checks

try:
    assert not TESTS_RUNNING  # do not log pytest
    assert SETTINGS["dvc"] is True  # verify integration is enabled
    import dvclive

    assert checks.check_version("dvclive", "2.11.0", verbose=True)

    import os
    import re
    from pathlib import Path

    # DVCLive logger instance
    live = None
    _processed_plots = {}

    # `on_fit_epoch_end` is called on final validation (probably needs to be fixed). For now, this is how we
    # distinguish the final evaluation of the best model from last-epoch validation.
    _training_epoch = False

except (ImportError, AssertionError, TypeError):
    dvclive = None


def _log_images(path, prefix=""):
    """Logs images at the specified path with an optional prefix using DVCLive."""
    if live:
        name = path.name

        # Group images by batch to enable sliders in UI
        if m := re.search(r"_batch(\d+)", name):
            ni = m[1]
            new_stem = re.sub(r"_batch(\d+)", "_batch", path.stem)
            name = (Path(new_stem) / ni).with_suffix(path.suffix)

        live.log_image(os.path.join(prefix, name), path)


def _log_plots(plots, prefix=""):
    """Logs plot images for training progress if they have not been previously processed."""
    for name, params in plots.items():
        timestamp = params["timestamp"]
        if _processed_plots.get(name) != timestamp:
            _log_images(name, prefix)
            _processed_plots[name] = timestamp


def _log_confusion_matrix(validator):
    """Logs the confusion matrix for the given validator using DVCLive."""
    targets = []
    preds = []
    matrix = validator.confusion_matrix.matrix
    names = list(validator.names.values())
    if validator.confusion_matrix.task == "detect":
        names += ["background"]

    for ti, pred in enumerate(matrix.T.astype(int)):
        for pi, num in enumerate(pred):
            targets.extend([names[ti]] * num)
            preds.extend([names[pi]] * num)

    live.log_sklearn_plot("confusion_matrix", targets, preds, name="cf.json", normalized=True)


def on_pretrain_routine_start(trainer):
    """Initializes the DVCLive logger for training metadata during the pre-training routine."""
    try:
        global live
        live = dvclive.Live(save_dvc_exp=True, cache_images=True)
        LOGGER.info("DVCLive is detected and auto logging is enabled (run 'yolo settings dvc=False' to disable).")
    except Exception as e:
        LOGGER.warning(f"WARNING ⚠️ DVCLive installed but not initialized correctly, not logging this run. {e}")


def on_pretrain_routine_end(trainer):
    """Logs plots related to the training process at the end of the pretraining routine."""
    _log_plots(trainer.plots, "train")


def on_train_start(trainer):
    """Logs the training parameters if DVCLive logging is active."""
    if live:
        live.log_params(trainer.args)


def on_train_epoch_start(trainer):
    """Sets the global variable _training_epoch to True at the start of each training epoch."""
    global _training_epoch
    _training_epoch = True


def on_fit_epoch_end(trainer):
    """Logs training metrics and model info, and advances to the next step at the end of each fit epoch."""
    global _training_epoch
    if live and _training_epoch:
        all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix="train"), **trainer.metrics, **trainer.lr}
        for metric, value in all_metrics.items():
            live.log_metric(metric, value)

        if trainer.epoch == 0:
            from ultralytics.utils.torch_utils import model_info_for_loggers

            for metric, value in model_info_for_loggers(trainer).items():
                live.log_metric(metric, value, plot=False)

        _log_plots(trainer.plots, "train")
        _log_plots(trainer.validator.plots, "val")

        live.next_step()
        _training_epoch = False


def on_train_end(trainer):
    """Logs the best metrics, plots, and confusion matrix at the end of training if DVCLive is active."""
    if live:
        # At the end, log the best metrics. This runs the validator on the best model internally.
        all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix="train"), **trainer.metrics, **trainer.lr}
        for metric, value in all_metrics.items():
            live.log_metric(metric, value, plot=False)

        _log_plots(trainer.plots, "val")
        _log_plots(trainer.validator.plots, "val")
        _log_confusion_matrix(trainer.validator)

        if trainer.best.exists():
            live.log_artifact(trainer.best, copy=True, type="model")

        live.end()


callbacks = (
    {
        "on_pretrain_routine_start": on_pretrain_routine_start,
        "on_pretrain_routine_end": on_pretrain_routine_end,
        "on_train_start": on_train_start,
        "on_train_epoch_start": on_train_epoch_start,
        "on_fit_epoch_end": on_fit_epoch_end,
        "on_train_end": on_train_end,
    }
    if dvclive
    else {}
)
108 ultralytics/utils/callbacks/hub.py Normal file
@ -0,0 +1,108 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license

import json
from time import time

from ultralytics.hub.utils import HUB_WEB_ROOT, PREFIX, events
from ultralytics.utils import LOGGER, SETTINGS


def on_pretrain_routine_end(trainer):
    """Logs info before starting timer for upload rate limit."""
    session = getattr(trainer, "hub_session", None)
    if session:
        # Start timer for upload rate limit
        session.timers = {
            "metrics": time(),
            "ckpt": time(),
        }  # start timer on session.rate_limit


def on_fit_epoch_end(trainer):
    """Uploads training progress metrics at the end of each epoch."""
    session = getattr(trainer, "hub_session", None)
    if session:
        # Upload metrics after validation ends
        all_plots = {
            **trainer.label_loss_items(trainer.tloss, prefix="train"),
            **trainer.metrics,
        }
        if trainer.epoch == 0:
            from ultralytics.utils.torch_utils import model_info_for_loggers

            all_plots = {**all_plots, **model_info_for_loggers(trainer)}

        session.metrics_queue[trainer.epoch] = json.dumps(all_plots)

        # If any metrics failed to upload previously, re-queue them for another attempt
        if session.metrics_upload_failed_queue:
            session.metrics_queue.update(session.metrics_upload_failed_queue)

        if time() - session.timers["metrics"] > session.rate_limits["metrics"]:
            session.upload_metrics()
            session.timers["metrics"] = time()  # reset timer
            session.metrics_queue = {}  # reset queue
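

# The rate-limit pattern above, distilled (a sketch with hypothetical values;
# the real intervals live in session.rate_limits, set by Ultralytics HUB):
#
#     timers = {"metrics": time()}
#     rate_limits = {"metrics": 60.0}  # assumed: seconds between metric uploads
#     if time() - timers["metrics"] > rate_limits["metrics"]:
#         upload()                    # hypothetical upload call
#         timers["metrics"] = time()  # reset the timer only after an upload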


def on_model_save(trainer):
    """Saves checkpoints to Ultralytics HUB with rate limiting."""
    session = getattr(trainer, "hub_session", None)
    if session:
        # Upload checkpoints with rate limiting
        is_best = trainer.best_fitness == trainer.fitness
        if time() - session.timers["ckpt"] > session.rate_limits["ckpt"]:
            LOGGER.info(f"{PREFIX}Uploading checkpoint {HUB_WEB_ROOT}/models/{session.model.id}")
            session.upload_model(trainer.epoch, trainer.last, is_best)
            session.timers["ckpt"] = time()  # reset timer


def on_train_end(trainer):
    """Upload final model and metrics to Ultralytics HUB at the end of training."""
    session = getattr(trainer, "hub_session", None)
    if session:
        # Upload final model and metrics with exponential backoff
        LOGGER.info(f"{PREFIX}Syncing final model...")
        session.upload_model(
            trainer.epoch,
            trainer.best,
            map=trainer.metrics.get("metrics/mAP50-95(B)", 0),
            final=True,
        )
        session.alive = False  # stop heartbeats
        LOGGER.info(f"{PREFIX}Done ✅\n" f"{PREFIX}View model at {session.model_url} 🚀")


def on_train_start(trainer):
    """Run events on train start."""
    events(trainer.args)


def on_val_start(validator):
    """Run events on validation start."""
    events(validator.args)


def on_predict_start(predictor):
    """Run events on predict start."""
    events(predictor.args)


def on_export_start(exporter):
    """Run events on export start."""
    events(exporter.args)


callbacks = (
    {
        "on_pretrain_routine_end": on_pretrain_routine_end,
        "on_fit_epoch_end": on_fit_epoch_end,
        "on_model_save": on_model_save,
        "on_train_end": on_train_end,
        "on_train_start": on_train_start,
        "on_val_start": on_val_start,
        "on_predict_start": on_predict_start,
        "on_export_start": on_export_start,
    }
    if SETTINGS["hub"] is True
    else {}
)  # verify enabled
133
ultralytics/utils/callbacks/mlflow.py
Normal file
@@ -0,0 +1,133 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
MLflow Logging for Ultralytics YOLO.

This module enables MLflow logging for Ultralytics YOLO. It logs metrics, parameters, and model artifacts.
For setting up, a tracking URI should be specified. The logging can be customized using environment variables.

Commands:
    1. To set a project name:
        `export MLFLOW_EXPERIMENT_NAME=<your_experiment_name>` or use the project=<project> argument

    2. To set a run name:
        `export MLFLOW_RUN=<your_run_name>` or use the name=<name> argument

    3. To start a local MLflow server:
        mlflow server --backend-store-uri runs/mlflow
        It will by default start a local server at http://127.0.0.1:5000.
        To specify a different URI, set the MLFLOW_TRACKING_URI environment variable.

    4. To kill all running MLflow server instances:
        ps aux | grep 'mlflow' | grep -v 'grep' | awk '{print $2}' | xargs kill -9
"""

from ultralytics.utils import LOGGER, RUNS_DIR, SETTINGS, TESTS_RUNNING, colorstr

try:
    import os

    assert not TESTS_RUNNING or "test_mlflow" in os.environ.get("PYTEST_CURRENT_TEST", "")  # do not log pytest
    assert SETTINGS["mlflow"] is True  # verify integration is enabled
    import mlflow

    assert hasattr(mlflow, "__version__")  # verify package is not directory
    from pathlib import Path

    PREFIX = colorstr("MLflow: ")
    SANITIZE = lambda x: {k.replace("(", "").replace(")", ""): float(v) for k, v in x.items()}

except (ImportError, AssertionError):
    mlflow = None
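

# What SANITIZE does, shown with a made-up metric dict: MLflow rejects '(' and
# ')' in metric names, so the keys are stripped and values coerced to float:
#
#     SANITIZE({"metrics/mAP50(B)": "0.52", "lr/pg0": 0.001})
#     # -> {"metrics/mAP50B": 0.52, "lr/pg0": 0.001}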


def on_pretrain_routine_end(trainer):
    """
    Log training parameters to MLflow at the end of the pretraining routine.

    This function sets up MLflow logging based on environment variables and trainer arguments. It sets the tracking
    URI, experiment name, and run name, then starts the MLflow run if not already active. It finally logs the
    parameters from the trainer.

    Args:
        trainer (ultralytics.engine.trainer.BaseTrainer): The training object with arguments and parameters to log.

    Global:
        mlflow: The imported mlflow module to use for logging.

    Environment Variables:
        MLFLOW_TRACKING_URI: The URI for MLflow tracking. If not set, defaults to 'runs/mlflow'.
        MLFLOW_EXPERIMENT_NAME: The name of the MLflow experiment. If not set, defaults to trainer.args.project.
        MLFLOW_RUN: The name of the MLflow run. If not set, defaults to trainer.args.name.
        MLFLOW_KEEP_RUN_ACTIVE: Boolean indicating whether to keep the MLflow run active after training ends.
    """
    global mlflow

    uri = os.environ.get("MLFLOW_TRACKING_URI") or str(RUNS_DIR / "mlflow")
    LOGGER.debug(f"{PREFIX} tracking uri: {uri}")
    mlflow.set_tracking_uri(uri)

    # Set experiment and run names
    experiment_name = os.environ.get("MLFLOW_EXPERIMENT_NAME") or trainer.args.project or "/Shared/YOLOv8"
    run_name = os.environ.get("MLFLOW_RUN") or trainer.args.name
    mlflow.set_experiment(experiment_name)

    mlflow.autolog()
    try:
        active_run = mlflow.active_run() or mlflow.start_run(run_name=run_name)
        LOGGER.info(f"{PREFIX}logging run_id({active_run.info.run_id}) to {uri}")
        if Path(uri).is_dir():
            LOGGER.info(f"{PREFIX}view at http://127.0.0.1:5000 with 'mlflow server --backend-store-uri {uri}'")
        LOGGER.info(f"{PREFIX}disable with 'yolo settings mlflow=False'")
        mlflow.log_params(dict(trainer.args))
    except Exception as e:
        LOGGER.warning(f"{PREFIX}WARNING ⚠️ Failed to initialize: {e}\n" f"{PREFIX}WARNING ⚠️ Not tracking this run")


def on_train_epoch_end(trainer):
    """Log training metrics at the end of each train epoch to MLflow."""
    if mlflow:
        mlflow.log_metrics(
            metrics={
                **SANITIZE(trainer.lr),
                **SANITIZE(trainer.label_loss_items(trainer.tloss, prefix="train")),
            },
            step=trainer.epoch,
        )


def on_fit_epoch_end(trainer):
    """Log training metrics at the end of each fit epoch to MLflow."""
    if mlflow:
        mlflow.log_metrics(metrics=SANITIZE(trainer.metrics), step=trainer.epoch)


def on_train_end(trainer):
    """Log model artifacts at the end of the training."""
    if mlflow:
        mlflow.log_artifact(str(trainer.best.parent))  # log save_dir/weights directory with best_gift_v10n.pt and last.pt
        for f in trainer.save_dir.glob("*"):  # log all other files in save_dir
            if f.suffix in {".png", ".jpg", ".csv", ".pt", ".yaml"}:
                mlflow.log_artifact(str(f))
        keep_run_active = os.environ.get("MLFLOW_KEEP_RUN_ACTIVE", "False").lower() == "true"
        if keep_run_active:
            LOGGER.info(f"{PREFIX}mlflow run still alive, remember to close it using mlflow.end_run()")
        else:
            mlflow.end_run()
            LOGGER.debug(f"{PREFIX}mlflow run ended")

        LOGGER.info(
            f"{PREFIX}results logged to {mlflow.get_tracking_uri()}\n"
            f"{PREFIX}disable with 'yolo settings mlflow=False'"
        )


callbacks = (
    {
        "on_pretrain_routine_end": on_pretrain_routine_end,
        "on_train_epoch_end": on_train_epoch_end,
        "on_fit_epoch_end": on_fit_epoch_end,
        "on_train_end": on_train_end,
    }
    if mlflow
    else {}
)
112
ultralytics/utils/callbacks/neptune.py
Normal file
@@ -0,0 +1,112 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING

try:
    assert not TESTS_RUNNING  # do not log pytest
    assert SETTINGS["neptune"] is True  # verify integration is enabled
    import neptune
    from neptune.types import File

    assert hasattr(neptune, "__version__")

    run = None  # NeptuneAI experiment logger instance

except (ImportError, AssertionError):
    neptune = None


def _log_scalars(scalars, step=0):
    """Log scalars to the NeptuneAI experiment logger."""
    if run:
        for k, v in scalars.items():
            run[k].append(value=v, step=step)
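

# Neptune stores each scalar key as a time series, so appending with an explicit
# step keeps train and val curves aligned. An illustrative call (assuming an
# active `run` and made-up values):
#
#     _log_scalars({"train/box_loss": 0.41, "lr/pg0": 0.001}, step=3)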


def _log_images(imgs_dict, group=""):
    """Log images to the NeptuneAI experiment logger."""
    if run:
        for k, v in imgs_dict.items():
            run[f"{group}/{k}"].upload(File(v))


def _log_plot(title, plot_path):
    """
    Log plots to the NeptuneAI experiment logger.

    Args:
        title (str): Title of the plot.
        plot_path (PosixPath | str): Path to the saved image file.
    """
    import matplotlib.image as mpimg
    import matplotlib.pyplot as plt

    img = mpimg.imread(plot_path)
    fig = plt.figure()
    ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect="auto", xticks=[], yticks=[])  # no ticks
    ax.imshow(img)
    run[f"Plots/{title}"].upload(fig)


def on_pretrain_routine_start(trainer):
    """Callback function called before the training routine starts."""
    try:
        global run
        run = neptune.init_run(project=trainer.args.project or "YOLOv8", name=trainer.args.name, tags=["YOLOv8"])
        run["Configuration/Hyperparameters"] = {k: "" if v is None else v for k, v in vars(trainer.args).items()}
    except Exception as e:
        LOGGER.warning(f"WARNING ⚠️ NeptuneAI installed but not initialized correctly, not logging this run. {e}")


def on_train_epoch_end(trainer):
    """Callback function called at end of each training epoch."""
    _log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1)
    _log_scalars(trainer.lr, trainer.epoch + 1)
    if trainer.epoch == 1:
        _log_images({f.stem: str(f) for f in trainer.save_dir.glob("train_batch*.jpg")}, "Mosaic")


def on_fit_epoch_end(trainer):
    """Callback function called at end of each fit (train+val) epoch."""
    if run and trainer.epoch == 0:
        from ultralytics.utils.torch_utils import model_info_for_loggers

        run["Configuration/Model"] = model_info_for_loggers(trainer)
    _log_scalars(trainer.metrics, trainer.epoch + 1)


def on_val_end(validator):
    """Callback function called at end of each validation."""
    if run:
        # Log val_labels and val_pred
        _log_images({f.stem: str(f) for f in validator.save_dir.glob("val*.jpg")}, "Validation")


def on_train_end(trainer):
    """Callback function called at end of training."""
    if run:
        # Log final results, confusion matrix and PR plots
        files = [
            "results.png",
            "confusion_matrix.png",
            "confusion_matrix_normalized.png",
            *(f"{x}_curve.png" for x in ("F1", "PR", "P", "R")),
        ]
        files = [(trainer.save_dir / f) for f in files if (trainer.save_dir / f).exists()]  # filter
        for f in files:
            _log_plot(title=f.stem, plot_path=f)
        # Log the final model
        run[f"weights/{trainer.args.name or trainer.args.task}/{trainer.best.name}"].upload(File(str(trainer.best)))


callbacks = (
    {
        "on_pretrain_routine_start": on_pretrain_routine_start,
        "on_train_epoch_end": on_train_epoch_end,
        "on_fit_epoch_end": on_fit_epoch_end,
        "on_val_end": on_val_end,
        "on_train_end": on_train_end,
    }
    if neptune
    else {}
)
29
ultralytics/utils/callbacks/raytune.py
Normal file
@@ -0,0 +1,29 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

from ultralytics.utils import SETTINGS

try:
    assert SETTINGS["raytune"] is True  # verify integration is enabled
    import ray
    from ray import tune
    from ray.air import session

except (ImportError, AssertionError):
    tune = None


def on_fit_epoch_end(trainer):
    """Sends training metrics to Ray Tune at end of each epoch."""
    if ray.tune.is_session_enabled():
        metrics = trainer.metrics
        metrics["epoch"] = trainer.epoch
        session.report(metrics)
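

# Sketch of the consuming side, assuming a standard Ray Tune search over YOLO
# hyperparameters; each session.report() above becomes one result row for the
# trial. `train_fn` is a hypothetical trainable that calls model.train():
#
#     from ray import tune
#     tuner = tune.Tuner(train_fn, param_space={"lr0": tune.uniform(1e-5, 1e-1)})
#     results = tuner.fit()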


callbacks = (
    {
        "on_fit_epoch_end": on_fit_epoch_end,
    }
    if tune
    else {}
)
106
ultralytics/utils/callbacks/tensorboard.py
Normal file
@@ -0,0 +1,106 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import contextlib

from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, colorstr

try:
    # WARNING: do not move SummaryWriter import due to protobuf bug https://github.com/ultralytics/ultralytics/pull/4674
    from torch.utils.tensorboard import SummaryWriter

    assert not TESTS_RUNNING  # do not log pytest
    assert SETTINGS["tensorboard"] is True  # verify integration is enabled
    WRITER = None  # TensorBoard SummaryWriter instance
    PREFIX = colorstr("TensorBoard: ")

    # Imports below only required if TensorBoard enabled
    import warnings
    from copy import deepcopy

    from ultralytics.utils.torch_utils import de_parallel, torch

except (ImportError, AssertionError, TypeError, AttributeError):
    # TypeError for handling 'Descriptors cannot not be created directly.' protobuf errors in Windows
    # AttributeError: module 'tensorflow' has no attribute 'io' if 'tensorflow' not installed
    SummaryWriter = None


def _log_scalars(scalars, step=0):
    """Logs scalar values to TensorBoard."""
    if WRITER:
        for k, v in scalars.items():
            WRITER.add_scalar(k, v, step)


def _log_tensorboard_graph(trainer):
    """Log model graph to TensorBoard."""

    # Input image
    imgsz = trainer.args.imgsz
    imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz
    p = next(trainer.model.parameters())  # for device, type
    im = torch.zeros((1, 3, *imgsz), device=p.device, dtype=p.dtype)  # input image (must be zeros, not empty)

    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=UserWarning)  # suppress jit trace warning
        warnings.simplefilter("ignore", category=torch.jit.TracerWarning)  # suppress jit trace warning

        # Try simple method first (YOLO)
        with contextlib.suppress(Exception):
            trainer.model.eval()  # place in .eval() mode to avoid BatchNorm statistics changes
            WRITER.add_graph(torch.jit.trace(de_parallel(trainer.model), im, strict=False), [])
            LOGGER.info(f"{PREFIX}model graph visualization added ✅")
            return

        # Fallback to TorchScript export steps (RTDETR)
        try:
            model = deepcopy(de_parallel(trainer.model))
            model.eval()
            model = model.fuse(verbose=False)
            for m in model.modules():
                if hasattr(m, "export"):  # Detect, RTDETRDecoder (Segment and Pose use Detect base class)
                    m.export = True
                    m.format = "torchscript"
            model(im)  # dry run
            WRITER.add_graph(torch.jit.trace(model, im, strict=False), [])
            LOGGER.info(f"{PREFIX}model graph visualization added ✅")
        except Exception as e:
            LOGGER.warning(f"{PREFIX}WARNING ⚠️ TensorBoard graph visualization failure {e}")
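

# The graph-logging step above, distilled into a standalone sketch (assumes any
# nn.Module `m` and a matching dummy input); strict=False lets the tracer
# tolerate dict/list outputs like YOLO heads produce:
#
#     writer = SummaryWriter("runs/graph_demo")
#     writer.add_graph(torch.jit.trace(m, torch.zeros(1, 3, 640, 640), strict=False), [])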


def on_pretrain_routine_start(trainer):
    """Initialize TensorBoard logging with SummaryWriter."""
    if SummaryWriter:
        try:
            global WRITER
            WRITER = SummaryWriter(str(trainer.save_dir))
            LOGGER.info(f"{PREFIX}Start with 'tensorboard --logdir {trainer.save_dir}', view at http://localhost:6006/")
        except Exception as e:
            LOGGER.warning(f"{PREFIX}WARNING ⚠️ TensorBoard not initialized correctly, not logging this run. {e}")


def on_train_start(trainer):
    """Log TensorBoard graph."""
    if WRITER:
        _log_tensorboard_graph(trainer)


def on_train_epoch_end(trainer):
    """Logs scalar statistics at the end of a training epoch."""
    _log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1)
    _log_scalars(trainer.lr, trainer.epoch + 1)


def on_fit_epoch_end(trainer):
    """Logs epoch metrics at end of training epoch."""
    _log_scalars(trainer.metrics, trainer.epoch + 1)


callbacks = (
    {
        "on_pretrain_routine_start": on_pretrain_routine_start,
        "on_train_start": on_train_start,
        "on_fit_epoch_end": on_fit_epoch_end,
        "on_train_epoch_end": on_train_epoch_end,
    }
    if SummaryWriter
    else {}
)
163
ultralytics/utils/callbacks/wb.py
Normal file
@@ -0,0 +1,163 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

from ultralytics.utils import SETTINGS, TESTS_RUNNING
from ultralytics.utils.torch_utils import model_info_for_loggers

try:
    assert not TESTS_RUNNING  # do not log pytest
    assert SETTINGS["wandb"] is True  # verify integration is enabled
    import wandb as wb

    assert hasattr(wb, "__version__")  # verify package is not directory

    import numpy as np
    import pandas as pd

    _processed_plots = {}

except (ImportError, AssertionError):
    wb = None


def _custom_table(x, y, classes, title="Precision Recall Curve", x_title="Recall", y_title="Precision"):
    """
    Create and log a custom metric visualization to wandb.plot.pr_curve.

    This function crafts a custom metric visualization that mimics the behavior of wandb's default precision-recall
    curve while allowing for enhanced customization. The visual metric is useful for monitoring model performance
    across different classes.

    Args:
        x (List): Values for the x-axis; expected to have length N.
        y (List): Corresponding values for the y-axis; also expected to have length N.
        classes (List): Labels identifying the class of each point; length N.
        title (str, optional): Title for the plot; defaults to 'Precision Recall Curve'.
        x_title (str, optional): Label for the x-axis; defaults to 'Recall'.
        y_title (str, optional): Label for the y-axis; defaults to 'Precision'.

    Returns:
        (wandb.Object): A wandb object suitable for logging, showcasing the crafted metric visualization.
    """
    df = pd.DataFrame({"class": classes, "y": y, "x": x}).round(3)
    fields = {"x": "x", "y": "y", "class": "class"}
    string_fields = {"title": title, "x-axis-title": x_title, "y-axis-title": y_title}
    return wb.plot_table(
        "wandb/area-under-curve/v0", wb.Table(dataframe=df), fields=fields, string_fields=string_fields
    )
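

# Illustrative inputs for _custom_table (hypothetical values): two per-class
# curves, three points each, flattened into parallel lists of equal length N=6:
#
#     x = [0.0, 0.5, 1.0, 0.0, 0.5, 1.0]
#     y = [1.0, 0.8, 0.2, 0.9, 0.7, 0.1]
#     classes = ["person"] * 3 + ["car"] * 3
#     wb.log({"curves/PR": _custom_table(x, y, classes)})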


def _plot_curve(
    x,
    y,
    names=None,
    id="precision-recall",
    title="Precision Recall Curve",
    x_title="Recall",
    y_title="Precision",
    num_x=100,
    only_mean=False,
):
    """
    Log a metric curve visualization.

    This function generates a metric curve based on input data and logs the visualization to wandb.
    The curve can represent aggregated data (mean) or individual class data, depending on the 'only_mean' flag.

    Args:
        x (np.ndarray): Data points for the x-axis with length N.
        y (np.ndarray): Corresponding data points for the y-axis with shape CxN, where C is the number of classes.
        names (list, optional): Names of the classes corresponding to the y-axis data; length C. Defaults to [].
        id (str, optional): Unique identifier for the logged data in wandb. Defaults to 'precision-recall'.
        title (str, optional): Title for the visualization plot. Defaults to 'Precision Recall Curve'.
        x_title (str, optional): Label for the x-axis. Defaults to 'Recall'.
        y_title (str, optional): Label for the y-axis. Defaults to 'Precision'.
        num_x (int, optional): Number of interpolated data points for visualization. Defaults to 100.
        only_mean (bool, optional): Flag to indicate if only the mean curve should be plotted. Defaults to False.

    Note:
        The function leverages the '_custom_table' function to generate the actual visualization.
    """
    # Create new x
    if names is None:
        names = []
    x_new = np.linspace(x[0], x[-1], num_x).round(5)

    # Create arrays for logging
    x_log = x_new.tolist()
    y_log = np.interp(x_new, x, np.mean(y, axis=0)).round(3).tolist()

    if only_mean:
        table = wb.Table(data=list(zip(x_log, y_log)), columns=[x_title, y_title])
        wb.run.log({title: wb.plot.line(table, x_title, y_title, title=title)})
    else:
        classes = ["mean"] * len(x_log)
        for i, yi in enumerate(y):
            x_log.extend(x_new)  # add new x
            y_log.extend(np.interp(x_new, x, yi))  # interpolate y to new x
            classes.extend([names[i]] * len(x_new))  # add class names
        wb.log({id: _custom_table(x_log, y_log, classes, title, x_title, y_title)}, commit=False)


def _log_plots(plots, step):
    """Logs plots from the input dictionary if they haven't been logged already at the specified step."""
    for name, params in plots.items():
        timestamp = params["timestamp"]
        if _processed_plots.get(name) != timestamp:
            wb.run.log({name.stem: wb.Image(str(name))}, step=step)
            _processed_plots[name] = timestamp


def on_pretrain_routine_start(trainer):
    """Initiate and start project if module is present."""
    wb.run or wb.init(project=trainer.args.project or "YOLOv8", name=trainer.args.name, config=vars(trainer.args))


def on_fit_epoch_end(trainer):
    """Logs training metrics and model information at the end of an epoch."""
    wb.run.log(trainer.metrics, step=trainer.epoch + 1)
    _log_plots(trainer.plots, step=trainer.epoch + 1)
    _log_plots(trainer.validator.plots, step=trainer.epoch + 1)
    if trainer.epoch == 0:
        wb.run.log(model_info_for_loggers(trainer), step=trainer.epoch + 1)


def on_train_epoch_end(trainer):
    """Log metrics and save images at the end of each training epoch."""
    wb.run.log(trainer.label_loss_items(trainer.tloss, prefix="train"), step=trainer.epoch + 1)
    wb.run.log(trainer.lr, step=trainer.epoch + 1)
    if trainer.epoch == 1:
        _log_plots(trainer.plots, step=trainer.epoch + 1)


def on_train_end(trainer):
    """Save the best model as an artifact at end of training."""
    _log_plots(trainer.validator.plots, step=trainer.epoch + 1)
    _log_plots(trainer.plots, step=trainer.epoch + 1)
    art = wb.Artifact(type="model", name=f"run_{wb.run.id}_model")
    if trainer.best.exists():
        art.add_file(trainer.best)
        wb.run.log_artifact(art, aliases=["best"])
    for curve_name, curve_values in zip(trainer.validator.metrics.curves, trainer.validator.metrics.curves_results):
        x, y, x_title, y_title = curve_values
        _plot_curve(
            x,
            y,
            names=list(trainer.validator.metrics.names.values()),
            id=f"curves/{curve_name}",
            title=curve_name,
            x_title=x_title,
            y_title=y_title,
        )
    wb.run.finish()  # required or run continues on dashboard


callbacks = (
    {
        "on_pretrain_routine_start": on_pretrain_routine_start,
        "on_train_epoch_end": on_train_epoch_end,
        "on_fit_epoch_end": on_fit_epoch_end,
        "on_train_end": on_train_end,
    }
    if wb
    else {}
)
731
ultralytics/utils/checks.py
Normal file
@@ -0,0 +1,731 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

import contextlib
import glob
import inspect
import math
import os
import platform
import re
import shutil
import subprocess
import time
from importlib import metadata
from pathlib import Path
from typing import Optional

import cv2
import numpy as np
import requests
import torch
from matplotlib import font_manager

from ultralytics.utils import (
    ASSETS,
    AUTOINSTALL,
    LINUX,
    LOGGER,
    ONLINE,
    ROOT,
    USER_CONFIG_DIR,
    SimpleNamespace,
    ThreadingLocked,
    TryExcept,
    clean_url,
    colorstr,
    downloads,
    emojis,
    is_colab,
    is_docker,
    is_github_action_running,
    is_jupyter,
    is_kaggle,
    is_online,
    is_pip_package,
    url2file,
)

PYTHON_VERSION = platform.python_version()


def parse_requirements(file_path=ROOT.parent / "requirements.txt", package=""):
    """
    Parse a requirements.txt file, ignoring lines that start with '#' and any text after '#'.

    Args:
        file_path (Path): Path to the requirements.txt file.
        package (str, optional): Python package to use instead of requirements.txt file, i.e. package='ultralytics'.

    Returns:
        (List[SimpleNamespace]): List of parsed requirements as SimpleNamespace objects with `name` and `specifier`
            attributes.

    Example:
        ```python
        from ultralytics.utils.checks import parse_requirements

        parse_requirements(package='ultralytics')
        ```
    """

    if package:
        requires = [x for x in metadata.distribution(package).requires if "extra == " not in x]
    else:
        requires = Path(file_path).read_text().splitlines()

    requirements = []
    for line in requires:
        line = line.strip()
        if line and not line.startswith("#"):
            line = line.split("#")[0].strip()  # ignore inline comments
            match = re.match(r"([a-zA-Z0-9-_]+)\s*([<>!=~]+.*)?", line)
            if match:
                requirements.append(SimpleNamespace(name=match[1], specifier=match[2].strip() if match[2] else ""))

    return requirements


def parse_version(version="0.0.0") -> tuple:
    """
    Convert a version string to a tuple of integers, ignoring any extra non-numeric string attached to the version.
    This function replaces the deprecated 'pkg_resources.parse_version(v)'.

    Args:
        version (str): Version string, i.e. '2.0.1+cpu'

    Returns:
        (tuple): Tuple of integers representing the numeric part of the version, i.e. (2, 0, 1); any extra
            non-numeric suffix is ignored.
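
    Example:
        ```python
        parse_version('2.0.1+cpu')  # returns (2, 0, 1)
        parse_version('1.13.0a0+git1234')  # returns (1, 13, 0), ignoring the pre-release tag
        ```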
"""
|
||||
try:
|
||||
return tuple(map(int, re.findall(r"\d+", version)[:3])) # '2.0.1+cpu' -> (2, 0, 1)
|
||||
except Exception as e:
|
||||
LOGGER.warning(f"WARNING ⚠️ failure for parse_version({version}), returning (0, 0, 0): {e}")
|
||||
return 0, 0, 0
|
||||
|
||||
|
||||
def is_ascii(s) -> bool:
|
||||
"""
|
||||
Check if a string is composed of only ASCII characters.
|
||||
|
||||
Args:
|
||||
s (str): String to be checked.
|
||||
|
||||
Returns:
|
||||
(bool): True if the string is composed only of ASCII characters, False otherwise.
|
||||
"""
|
||||
# Convert list, tuple, None, etc. to string
|
||||
s = str(s)
|
||||
|
||||
# Check if the string is composed of only ASCII characters
|
||||
return all(ord(c) < 128 for c in s)
|
||||
|
||||
|
||||
def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0):
|
||||
"""
|
||||
Verify image size is a multiple of the given stride in each dimension. If the image size is not a multiple of the
|
||||
stride, update it to the nearest multiple of the stride that is greater than or equal to the given floor value.
|
||||
|
||||
Args:
|
||||
imgsz (int | cList[int]): Image size.
|
||||
stride (int): Stride value.
|
||||
min_dim (int): Minimum number of dimensions.
|
||||
max_dim (int): Maximum number of dimensions.
|
||||
floor (int): Minimum allowed value for image size.
|
||||
|
||||
Returns:
|
||||
(List[int]): Updated image size.
|
||||
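
    Example:
        ```python
        from ultralytics.utils.checks import check_imgsz

        check_imgsz(638)            # -> 640, rounded up to the nearest multiple of stride 32
        check_imgsz(640, min_dim=2) # -> [640, 640], expanded to two dimensions
        ```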
"""
|
||||
# Convert stride to integer if it is a tensor
|
||||
stride = int(stride.max() if isinstance(stride, torch.Tensor) else stride)
|
||||
|
||||
# Convert image size to list if it is an integer
|
||||
if isinstance(imgsz, int):
|
||||
imgsz = [imgsz]
|
||||
elif isinstance(imgsz, (list, tuple)):
|
||||
imgsz = list(imgsz)
|
||||
elif isinstance(imgsz, str): # i.e. '640' or '[640,640]'
|
||||
imgsz = [int(imgsz)] if imgsz.isnumeric() else eval(imgsz)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"'imgsz={imgsz}' is of invalid type {type(imgsz).__name__}. "
|
||||
f"Valid imgsz types are int i.e. 'imgsz=640' or list i.e. 'imgsz=[640,640]'"
|
||||
)
|
||||
|
||||
# Apply max_dim
|
||||
if len(imgsz) > max_dim:
|
||||
msg = (
|
||||
"'train' and 'val' imgsz must be an integer, while 'predict' and 'export' imgsz may be a [h, w] list "
|
||||
"or an integer, i.e. 'yolo export imgsz=640,480' or 'yolo export imgsz=640'"
|
||||
)
|
||||
if max_dim != 1:
|
||||
raise ValueError(f"imgsz={imgsz} is not a valid image size. {msg}")
|
||||
LOGGER.warning(f"WARNING ⚠️ updating to 'imgsz={max(imgsz)}'. {msg}")
|
||||
imgsz = [max(imgsz)]
|
||||
# Make image size a multiple of the stride
|
||||
sz = [max(math.ceil(x / stride) * stride, floor) for x in imgsz]
|
||||
|
||||
# Print warning message if image size was updated
|
||||
if sz != imgsz:
|
||||
LOGGER.warning(f"WARNING ⚠️ imgsz={imgsz} must be multiple of max stride {stride}, updating to {sz}")
|
||||
|
||||
# Add missing dimensions if necessary
|
||||
sz = [sz[0], sz[0]] if min_dim == 2 and len(sz) == 1 else sz[0] if min_dim == 1 and len(sz) == 1 else sz
|
||||
|
||||
return sz
|
||||
|
||||
|
||||


def check_version(
    current: str = "0.0.0",
    required: str = "0.0.0",
    name: str = "version",
    hard: bool = False,
    verbose: bool = False,
    msg: str = "",
) -> bool:
    """
    Check current version against the required version or range.

    Args:
        current (str): Current version or package name to get version from.
        required (str): Required version or range (in pip-style format).
        name (str, optional): Name to be used in warning message.
        hard (bool, optional): If True, raise an AssertionError if the requirement is not met.
        verbose (bool, optional): If True, print warning message if requirement is not met.
        msg (str, optional): Extra message to display if verbose.

    Returns:
        (bool): True if requirement is met, False otherwise.

    Example:
        ```python
        # Check if current version is exactly 22.04
        check_version(current='22.04', required='==22.04')

        # Check if current version is greater than or equal to 22.04
        check_version(current='22.10', required='22.04')  # assumes '>=' inequality if none passed

        # Check if current version is less than or equal to 22.04
        check_version(current='22.04', required='<=22.04')

        # Check if current version is between 20.04 (inclusive) and 22.04 (exclusive)
        check_version(current='21.10', required='>20.04,<22.04')
        ```
    """
    if not current:  # if current is '' or None
        LOGGER.warning(f"WARNING ⚠️ invalid check_version({current}, {required}) requested, please check values.")
        return True
    elif not current[0].isdigit():  # current is package name rather than version string, i.e. current='ultralytics'
        try:
            name = current  # assigned package name to 'name' arg
            current = metadata.version(current)  # get version string from package name
        except metadata.PackageNotFoundError as e:
            if hard:
                raise ModuleNotFoundError(emojis(f"WARNING ⚠️ {current} package is required but not installed")) from e
            else:
                return False

    if not required:  # if required is '' or None
        return True

    op = ""
    version = ""
    result = True
    c = parse_version(current)  # '1.2.3' -> (1, 2, 3)
    for r in required.strip(",").split(","):
        op, version = re.match(r"([^0-9]*)([\d.]+)", r).groups()  # split '>=22.04' -> ('>=', '22.04')
        v = parse_version(version)  # '1.2.3' -> (1, 2, 3)
        if op == "==" and c != v:
            result = False
        elif op == "!=" and c == v:
            result = False
        elif op in (">=", "") and not (c >= v):  # if no constraint passed assume '>=required'
            result = False
        elif op == "<=" and not (c <= v):
            result = False
        elif op == ">" and not (c > v):
            result = False
        elif op == "<" and not (c < v):
            result = False
    if not result:
        warning = f"WARNING ⚠️ {name}{op}{version} is required, but {name}=={current} is currently installed {msg}"
        if hard:
            raise ModuleNotFoundError(emojis(warning))  # assert version requirements met
        if verbose:
            LOGGER.warning(warning)
    return result


def check_latest_pypi_version(package_name="ultralytics"):
    """
    Returns the latest version of a PyPI package without downloading or installing it.

    Args:
        package_name (str): The name of the package to find the latest version for.

    Returns:
        (str): The latest version of the package.
    """
    with contextlib.suppress(Exception):
        requests.packages.urllib3.disable_warnings()  # Disable the InsecureRequestWarning
        response = requests.get(f"https://pypi.org/pypi/{package_name}/json", timeout=3)
        if response.status_code == 200:
            return response.json()["info"]["version"]


def check_pip_update_available():
    """
    Checks if a new version of the ultralytics package is available on PyPI.

    Returns:
        (bool): True if an update is available, False otherwise.
    """
    if ONLINE and is_pip_package():
        with contextlib.suppress(Exception):
            from ultralytics import __version__

            latest = check_latest_pypi_version()
            if check_version(__version__, f"<{latest}"):  # check if current version is < latest version
                LOGGER.info(
                    f"New https://pypi.org/project/ultralytics/{latest} available 😃 "
                    f"Update with 'pip install -U ultralytics'"
                )
                return True
    return False


@ThreadingLocked()
def check_font(font="Arial.ttf"):
    """
    Find font locally or download to user's configuration directory if it does not already exist.

    Args:
        font (str): Path or name of font.

    Returns:
        file (Path): Resolved font file path.
    """
    name = Path(font).name

    # Check USER_CONFIG_DIR
    file = USER_CONFIG_DIR / name
    if file.exists():
        return file

    # Check system fonts
    matches = [s for s in font_manager.findSystemFonts() if font in s]
    if any(matches):
        return matches[0]

    # Download to USER_CONFIG_DIR if missing
    url = f"https://ultralytics.com/assets/{name}"
    if downloads.is_url(url, check=True):
        downloads.safe_download(url=url, file=file)
        return file


def check_python(minimum: str = "3.8.0") -> bool:
    """
    Check current python version against the required minimum version.

    Args:
        minimum (str): Required minimum version of python.

    Returns:
        (bool): Whether the installed Python version meets the minimum constraints.
    """
    return check_version(PYTHON_VERSION, minimum, name="Python ", hard=True)


@TryExcept()
def check_requirements(requirements=ROOT.parent / "requirements.txt", exclude=(), install=True, cmds=""):
    """
    Check if installed dependencies meet YOLOv8 requirements and attempt to auto-update if needed.

    Args:
        requirements (Union[Path, str, List[str]]): Path to a requirements.txt file, a single package requirement as a
            string, or a list of package requirements as strings.
        exclude (Tuple[str]): Tuple of package names to exclude from checking.
        install (bool): If True, attempt to auto-update packages that don't meet requirements.
        cmds (str): Additional commands to pass to the pip install command when auto-updating.

    Example:
        ```python
        from ultralytics.utils.checks import check_requirements

        # Check a requirements.txt file
        check_requirements('path/to/requirements.txt')

        # Check a single package
        check_requirements('ultralytics>=8.0.0')

        # Check multiple packages
        check_requirements(['numpy', 'ultralytics>=8.0.0'])
        ```
    """

    prefix = colorstr("red", "bold", "requirements:")
    check_python()  # check python version
    check_torchvision()  # check torch-torchvision compatibility
    if isinstance(requirements, Path):  # requirements.txt file
        file = requirements.resolve()
        assert file.exists(), f"{prefix} {file} not found, check failed."
        requirements = [f"{x.name}{x.specifier}" for x in parse_requirements(file) if x.name not in exclude]
    elif isinstance(requirements, str):
        requirements = [requirements]

    pkgs = []
    for r in requirements:
        r_stripped = r.split("/")[-1].replace(".git", "")  # replace git+https://org/repo.git -> 'repo'
        match = re.match(r"([a-zA-Z0-9-_]+)([<>!=~]+.*)?", r_stripped)
        name, required = match[1], match[2].strip() if match[2] else ""
        try:
            assert check_version(metadata.version(name), required)  # exception if requirements not met
        except (AssertionError, metadata.PackageNotFoundError):
            pkgs.append(r)

    s = " ".join(f'"{x}"' for x in pkgs)  # console string
    if s:
        if install and AUTOINSTALL:  # check environment variable
            n = len(pkgs)  # number of package updates
            LOGGER.info(f"{prefix} Ultralytics requirement{'s' * (n > 1)} {pkgs} not found, attempting AutoUpdate...")
            try:
                t = time.time()
                assert is_online(), "AutoUpdate skipped (offline)"
                LOGGER.info(subprocess.check_output(f"pip install --no-cache {s} {cmds}", shell=True).decode())
                dt = time.time() - t
                LOGGER.info(
                    f"{prefix} AutoUpdate success ✅ {dt:.1f}s, installed {n} package{'s' * (n > 1)}: {pkgs}\n"
                    f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
                )
            except Exception as e:
                LOGGER.warning(f"{prefix} ❌ {e}")
                return False
        else:
            return False

    return True


def check_torchvision():
    """
    Checks the installed versions of PyTorch and Torchvision to ensure they're compatible.

    This function checks the installed versions of PyTorch and Torchvision, and warns if they're incompatible
    according to the provided compatibility table based on:
    https://github.com/pytorch/vision#installation.

    The compatibility table is a dictionary where the keys are PyTorch versions and the values are lists of compatible
    Torchvision versions.
    """

    import torchvision

    # Compatibility table
    compatibility_table = {"2.0": ["0.15"], "1.13": ["0.14"], "1.12": ["0.13"]}

    # Extract only the major and minor versions
    v_torch = ".".join(torch.__version__.split("+")[0].split(".")[:2])
    v_torchvision = ".".join(torchvision.__version__.split("+")[0].split(".")[:2])

    if v_torch in compatibility_table:
        compatible_versions = compatibility_table[v_torch]
        if all(v_torchvision != v for v in compatible_versions):
            print(
                f"WARNING ⚠️ torchvision=={v_torchvision} is incompatible with torch=={v_torch}.\n"
                f"Run 'pip install torchvision=={compatible_versions[0]}' to fix torchvision or "
                "'pip install -U torch torchvision' to update both.\n"
                "For a full compatibility table see https://github.com/pytorch/vision#installation"
            )


def check_suffix(file="yolov8n.pt", suffix=".pt", msg=""):
    """Check file(s) for acceptable suffix."""
    if file and suffix:
        if isinstance(suffix, str):
            suffix = (suffix,)
        for f in file if isinstance(file, (list, tuple)) else [file]:
            s = Path(f).suffix.lower().strip()  # file suffix
            if len(s):
                assert s in suffix, f"{msg}{f} acceptable suffix is {suffix}, not {s}"


def check_yolov5u_filename(file: str, verbose: bool = True):
    """Replace legacy YOLOv5 filenames with updated YOLOv5u filenames."""
    if "yolov3" in file or "yolov5" in file:
        if "u.yaml" in file:
            file = file.replace("u.yaml", ".yaml")  # i.e. yolov5nu.yaml -> yolov5n.yaml
        elif ".pt" in file and "u" not in file:
            original_file = file
            file = re.sub(r"(.*yolov5([nsmlx]))\.pt", "\\1u.pt", file)  # i.e. yolov5n.pt -> yolov5nu.pt
            file = re.sub(r"(.*yolov5([nsmlx])6)\.pt", "\\1u.pt", file)  # i.e. yolov5n6.pt -> yolov5n6u.pt
            file = re.sub(r"(.*yolov3(|-tiny|-spp))\.pt", "\\1u.pt", file)  # i.e. yolov3-spp.pt -> yolov3-sppu.pt
            if file != original_file and verbose:
                LOGGER.info(
                    f"PRO TIP 💡 Replace 'model={original_file}' with new 'model={file}'.\nYOLOv5 'u' models are "
                    f"trained with https://github.com/ultralytics/ultralytics and feature improved performance vs "
                    f"standard YOLOv5 models trained with https://github.com/ultralytics/yolov5.\n"
                )
    return file


def check_model_file_from_stem(model="yolov8n"):
    """Return a model filename from a valid model stem."""
    if model and not Path(model).suffix and Path(model).stem in downloads.GITHUB_ASSETS_STEMS:
        return Path(model).with_suffix(".pt")  # add suffix, i.e. yolov8n -> yolov8n.pt
    else:
        return model


def check_file(file, suffix="", download=True, hard=True):
    """Search/download file (if necessary) and return path."""
    check_suffix(file, suffix)  # optional
    file = str(file).strip()  # convert to string and strip spaces
    file = check_yolov5u_filename(file)  # yolov5n -> yolov5nu
    if (
        not file
        or ("://" not in file and Path(file).exists())  # '://' check required in Windows Python<3.10
        or file.lower().startswith("grpc://")
    ):  # file exists or gRPC Triton images
        return file
    elif download and file.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")):  # download
        url = file  # warning: Pathlib turns :// -> :/
        file = url2file(file)  # '%2F' to '/', split https://url.com/file.txt?auth
        if Path(file).exists():
            LOGGER.info(f"Found {clean_url(url)} locally at {file}")  # file already exists
        else:
            downloads.safe_download(url=url, file=file, unzip=False)
        return file
    else:  # search
        files = glob.glob(str(ROOT / "**" / file), recursive=True) or glob.glob(str(ROOT.parent / file))  # find file
        if not files and hard:
            raise FileNotFoundError(f"'{file}' does not exist")
        elif len(files) > 1 and hard:
            raise FileNotFoundError(f"Multiple files match '{file}', specify exact path: {files}")
        return files[0] if len(files) else [] if hard else file  # return file


def check_yaml(file, suffix=(".yaml", ".yml"), hard=True):
    """Search/download YAML file (if necessary) and return path, checking suffix."""
    return check_file(file, suffix, hard=hard)


def check_is_path_safe(basedir, path):
    """
    Check if the resolved path is under the intended directory to prevent path traversal.

    Args:
        basedir (Path | str): The intended directory.
        path (Path | str): The path to check.

    Returns:
        (bool): True if the path is safe, False otherwise.
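
    Example:
        ```python
        # Hypothetical paths; the second escapes basedir via '..' and is rejected
        check_is_path_safe("/data", "/data/coco/train.txt")   # True (if the file exists)
        check_is_path_safe("/data", "/data/../etc/passwd")    # False
        ```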
"""
|
||||
base_dir_resolved = Path(basedir).resolve()
|
||||
path_resolved = Path(path).resolve()
|
||||
|
||||
return path_resolved.is_file() and path_resolved.parts[: len(base_dir_resolved.parts)] == base_dir_resolved.parts
|
||||
|
||||
|
||||
def check_imshow(warn=False):
|
||||
"""Check if environment supports image displays."""
|
||||
try:
|
||||
if LINUX:
|
||||
assert "DISPLAY" in os.environ and not is_docker() and not is_colab() and not is_kaggle()
|
||||
cv2.imshow("test", np.zeros((8, 8, 3), dtype=np.uint8)) # show a small 8-pixel image
|
||||
cv2.waitKey(1)
|
||||
cv2.destroyAllWindows()
|
||||
cv2.waitKey(1)
|
||||
return True
|
||||
except Exception as e:
|
||||
if warn:
|
||||
LOGGER.warning(f"WARNING ⚠️ Environment does not support cv2.imshow() or PIL Image.show()\n{e}")
|
||||
return False
|
||||
|
||||
|
||||


def check_yolo(verbose=True, device=""):
    """Return a human-readable YOLO software and hardware summary."""
    import psutil

    from ultralytics.utils.torch_utils import select_device

    if is_jupyter():
        if check_requirements("wandb", install=False):
            os.system("pip uninstall -y wandb")  # uninstall wandb: unwanted account creation prompt with infinite hang
        if is_colab():
            shutil.rmtree("sample_data", ignore_errors=True)  # remove colab /sample_data directory

    if verbose:
        # System info
        gib = 1 << 30  # bytes per GiB
        ram = psutil.virtual_memory().total
        total, used, free = shutil.disk_usage("/")
        s = f"({os.cpu_count()} CPUs, {ram / gib:.1f} GB RAM, {(total - free) / gib:.1f}/{total / gib:.1f} GB disk)"
        with contextlib.suppress(Exception):  # clear display if ipython is installed
            from IPython import display

            display.clear_output()
    else:
        s = ""

    select_device(device=device, newline=False)
    LOGGER.info(f"Setup complete ✅ {s}")


def collect_system_info():
    """Collect and print relevant system information including OS, Python, RAM, CPU, and CUDA."""

    import psutil

    from ultralytics.utils import ENVIRONMENT, is_git_dir
    from ultralytics.utils.torch_utils import get_cpu_info

    ram_info = psutil.virtual_memory().total / (1024**3)  # convert bytes to GB
    check_yolo()
    LOGGER.info(
        f"\n{'OS':<20}{platform.platform()}\n"
        f"{'Environment':<20}{ENVIRONMENT}\n"
        f"{'Python':<20}{PYTHON_VERSION}\n"
        f"{'Install':<20}{'git' if is_git_dir() else 'pip' if is_pip_package() else 'other'}\n"
        f"{'RAM':<20}{ram_info:.2f} GB\n"
        f"{'CPU':<20}{get_cpu_info()}\n"
        f"{'CUDA':<20}{torch.version.cuda if torch and torch.cuda.is_available() else None}\n"
    )

    for r in parse_requirements(package="ultralytics"):
        try:
            current = metadata.version(r.name)
            is_met = "✅ " if check_version(current, str(r.specifier), hard=True) else "❌ "
        except metadata.PackageNotFoundError:
            current = "(not installed)"
            is_met = "❌ "
        LOGGER.info(f"{r.name:<20}{is_met}{current}{r.specifier}")

    if is_github_action_running():
        LOGGER.info(
            f"\nRUNNER_OS: {os.getenv('RUNNER_OS')}\n"
            f"GITHUB_EVENT_NAME: {os.getenv('GITHUB_EVENT_NAME')}\n"
            f"GITHUB_WORKFLOW: {os.getenv('GITHUB_WORKFLOW')}\n"
            f"GITHUB_ACTOR: {os.getenv('GITHUB_ACTOR')}\n"
            f"GITHUB_REPOSITORY: {os.getenv('GITHUB_REPOSITORY')}\n"
            f"GITHUB_REPOSITORY_OWNER: {os.getenv('GITHUB_REPOSITORY_OWNER')}\n"
        )


def check_amp(model):
    """
    This function checks the PyTorch Automatic Mixed Precision (AMP) functionality of a YOLOv8 model. If the checks
    fail, it means there are anomalies with AMP on the system that may cause NaN losses or zero-mAP results, so AMP
    will be disabled during training.

    Args:
        model (nn.Module): A YOLOv8 model instance.

    Example:
        ```python
        from ultralytics import YOLO
        from ultralytics.utils.checks import check_amp

        model = YOLO('yolov8n.pt').model.cuda()
        check_amp(model)
        ```

    Returns:
        (bool): Returns True if the AMP functionality works correctly with YOLOv8 model, else False.
    """
    device = next(model.parameters()).device  # get model device
    if device.type in ("cpu", "mps"):
        return False  # AMP only used on CUDA devices

    def amp_allclose(m, im):
        """All close FP32 vs AMP results."""
        a = m(im, device=device, verbose=False)[0].boxes.data  # FP32 inference
        with torch.cuda.amp.autocast(True):
            b = m(im, device=device, verbose=False)[0].boxes.data  # AMP inference
        del m
        return a.shape == b.shape and torch.allclose(a, b.float(), atol=0.5)  # close to 0.5 absolute tolerance

    im = ASSETS / "bus.jpg"  # image to check
    prefix = colorstr("AMP: ")
    LOGGER.info(f"{prefix}running Automatic Mixed Precision (AMP) checks with YOLOv8n...")
    warning_msg = "Setting 'amp=True'. If you experience zero-mAP or NaN losses you can disable AMP with amp=False."
    try:
        from ultralytics import YOLO

        assert amp_allclose(YOLO("yolov8n.pt"), im)
        LOGGER.info(f"{prefix}checks passed ✅")
    except ConnectionError:
        LOGGER.warning(f"{prefix}checks skipped ⚠️, offline and unable to download YOLOv8n. {warning_msg}")
    except (AttributeError, ModuleNotFoundError):
        LOGGER.warning(
            f"{prefix}checks skipped ⚠️. "
            f"Unable to load YOLOv8n due to possible Ultralytics package modifications. {warning_msg}"
        )
    except AssertionError:
        LOGGER.warning(
            f"{prefix}checks failed ❌. Anomalies were detected with AMP on your system that may lead to "
            f"NaN losses or zero-mAP results, so AMP will be disabled during training."
        )
        return False
    return True


def git_describe(path=ROOT):  # path must be a directory
    """Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe."""
    with contextlib.suppress(Exception):
        return subprocess.check_output(f"git -C {path} describe --tags --long --always", shell=True).decode()[:-1]
    return ""


def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
    """Print function arguments (optional args dict)."""

    def strip_auth(v):
        """Clean longer Ultralytics HUB URLs by stripping potential authentication information."""
        return clean_url(v) if (isinstance(v, str) and v.startswith("http") and len(v) > 100) else v

    x = inspect.currentframe().f_back  # previous frame
    file, _, func, _, _ = inspect.getframeinfo(x)
    if args is None:  # get args automatically
        args, _, _, frm = inspect.getargvalues(x)
        args = {k: v for k, v in frm.items() if k in args}
    try:
        file = Path(file).resolve().relative_to(ROOT).with_suffix("")
    except ValueError:
        file = Path(file).stem
    s = (f"{file}: " if show_file else "") + (f"{func}: " if show_func else "")
    LOGGER.info(colorstr(s) + ", ".join(f"{k}={strip_auth(v)}" for k, v in args.items()))


def cuda_device_count() -> int:
    """
    Get the number of NVIDIA GPUs available in the environment.

    Returns:
        (int): The number of NVIDIA GPUs available.
    """
    try:
        # Run the nvidia-smi command and capture its output
        output = subprocess.check_output(
            ["nvidia-smi", "--query-gpu=count", "--format=csv,noheader,nounits"], encoding="utf-8"
        )

        # Take the first line and strip any leading/trailing white space
        first_line = output.strip().split("\n")[0]

        return int(first_line)
    except (subprocess.CalledProcessError, FileNotFoundError, ValueError):
        # If the command fails, nvidia-smi is not found, or output is not an integer, assume no GPUs are available
        return 0


def cuda_is_available() -> bool:
    """
    Check if CUDA is available in the environment.

    Returns:
        (bool): True if one or more NVIDIA GPUs are available, False otherwise.
    """
    return cuda_device_count() > 0


# Define constants
IS_PYTHON_3_12 = PYTHON_VERSION.startswith("3.12")
71
ultralytics/utils/dist.py
Normal file
@@ -0,0 +1,71 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

import os
import shutil
import socket
import sys
import tempfile

from . import USER_CONFIG_DIR
from .torch_utils import TORCH_1_9


def find_free_network_port() -> int:
    """
    Finds a free port on localhost.

    It is useful in single-node training when we don't want to connect to a real main node but have to set the
    `MASTER_PORT` environment variable.
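
    Example:
        ```python
        import os

        # e.g. before calling torch.distributed.init_process_group
        os.environ["MASTER_PORT"] = str(find_free_network_port())
        ```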
"""
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(("127.0.0.1", 0))
|
||||
return s.getsockname()[1] # port
|


def generate_ddp_file(trainer):
    """Generates a DDP file and returns its file name."""
    module, name = f"{trainer.__class__.__module__}.{trainer.__class__.__name__}".rsplit(".", 1)

    content = f"""
# Ultralytics Multi-GPU training temp file (should be automatically deleted after use)
overrides = {vars(trainer.args)}

if __name__ == "__main__":
    from {module} import {name}
    from ultralytics.utils import DEFAULT_CFG_DICT

    cfg = DEFAULT_CFG_DICT.copy()
    cfg.update(save_dir='')  # handle the extra key 'save_dir'
    trainer = {name}(cfg=cfg, overrides=overrides)
    results = trainer.train()
"""
    (USER_CONFIG_DIR / "DDP").mkdir(exist_ok=True)
    with tempfile.NamedTemporaryFile(
        prefix="_temp_",
        suffix=f"{id(trainer)}.py",
        mode="w+",
        encoding="utf-8",
        dir=USER_CONFIG_DIR / "DDP",
        delete=False,
    ) as file:
        file.write(content)
    return file.name


def generate_ddp_command(world_size, trainer):
    """Generates and returns command for distributed training."""
    import __main__  # noqa local import to avoid https://github.com/Lightning-AI/lightning/issues/15218

    if not trainer.resume:
        shutil.rmtree(trainer.save_dir)  # remove the save_dir
    file = generate_ddp_file(trainer)
    dist_cmd = "torch.distributed.run" if TORCH_1_9 else "torch.distributed.launch"
    port = find_free_network_port()
    cmd = [sys.executable, "-m", dist_cmd, "--nproc_per_node", f"{world_size}", "--master_port", f"{port}", file]
    return cmd, file


def ddp_cleanup(trainer, file):
    """Delete temp file if created."""
    if f"{id(trainer)}.py" in file:  # if temp_file suffix in file
        os.remove(file)
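
# Example (illustrative, not part of the original file) — how these three helpers
# chain together around a trainer object (names are assumptions):
#
#   cmd, tmp_file = generate_ddp_command(world_size=2, trainer=trainer)
#   subprocess.run(cmd, check=True)       # launches torch.distributed.run
#   ddp_cleanup(trainer, str(tmp_file))   # removes the generated temp file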
500
ultralytics/utils/downloads.py
Normal file
@ -0,0 +1,500 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

import contextlib
import re
import shutil
import subprocess
from itertools import repeat
from multiprocessing.pool import ThreadPool
from pathlib import Path
from urllib import parse, request

import requests
import torch

from ultralytics.utils import LOGGER, TQDM, checks, clean_url, emojis, is_online, url2file

# Define Ultralytics GitHub assets maintained at https://github.com/ultralytics/assets
GITHUB_ASSETS_REPO = "ultralytics/assets"
GITHUB_ASSETS_NAMES = (
    [f"yolov8{k}{suffix}.pt" for k in "nsmlx" for suffix in ("", "-cls", "-seg", "-pose", "-obb")]
    + [f"yolov5{k}{resolution}u.pt" for k in "nsmlx" for resolution in ("", "6")]
    + [f"yolov3{k}u.pt" for k in ("", "-spp", "-tiny")]
    + [f"yolov8{k}-world.pt" for k in "smlx"]
    + [f"yolov8{k}-worldv2.pt" for k in "smlx"]
    + [f"yolov9{k}.pt" for k in "ce"]
    + [f"yolo_nas_{k}.pt" for k in "sml"]
    + [f"sam_{k}.pt" for k in "bl"]
    + [f"FastSAM-{k}.pt" for k in "sx"]
    + [f"rtdetr-{k}.pt" for k in "lx"]
    + ["mobile_sam.pt"]
    + ["calibration_image_sample_data_20x128x128x3_float32.npy.zip"]
)
GITHUB_ASSETS_STEMS = [Path(k).stem for k in GITHUB_ASSETS_NAMES]
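
# Example (illustrative, not part of the original file) — these constants let
# callers check whether a weight name is an official release asset:
#
#   assert "yolov8n.pt" in GITHUB_ASSETS_NAMES
#   assert "yolov8n" in GITHUB_ASSETS_STEMS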


def is_url(url, check=False):
    """
    Validates if the given string is a URL and optionally checks if the URL exists online.

    Args:
        url (str): The string to be validated as a URL.
        check (bool, optional): If True, performs an additional check to see if the URL exists online.
            Defaults to False.

    Returns:
        (bool): Returns True for a valid URL. If 'check' is True, also returns True if the URL exists online.
            Returns False otherwise.

    Example:
        ```python
        valid = is_url("https://www.example.com")
        ```
    """
    with contextlib.suppress(Exception):
        url = str(url)
        result = parse.urlparse(url)
        assert all([result.scheme, result.netloc])  # check if is url
        if check:
            with request.urlopen(url) as response:
                return response.getcode() == 200  # check if exists online
        return True
    return False


def delete_dsstore(path, files_to_delete=(".DS_Store", "__MACOSX")):
    """
    Deletes all ".DS_store" files under a specified directory.

    Args:
        path (str): The directory path where the ".DS_store" files should be deleted.
        files_to_delete (tuple): The files to be deleted.

    Example:
        ```python
        from ultralytics.utils.downloads import delete_dsstore

        delete_dsstore('path/to/dir')
        ```

    Note:
        ".DS_store" files are created by the Apple operating system and contain metadata about folders and files. They
        are hidden system files and can cause issues when transferring files between different operating systems.
    """
    for file in files_to_delete:
        matches = list(Path(path).rglob(file))
        LOGGER.info(f"Deleting {file} files: {matches}")
        for f in matches:
            f.unlink()


def zip_directory(directory, compress=True, exclude=(".DS_Store", "__MACOSX"), progress=True):
    """
    Zips the contents of a directory, excluding files containing strings in the exclude list. The resulting zip file is
    named after the directory and placed alongside it.

    Args:
        directory (str | Path): The path to the directory to be zipped.
        compress (bool): Whether to compress the files while zipping. Default is True.
        exclude (tuple, optional): A tuple of filename strings to be excluded. Defaults to ('.DS_Store', '__MACOSX').
        progress (bool, optional): Whether to display a progress bar. Defaults to True.

    Returns:
        (Path): The path to the resulting zip file.

    Example:
        ```python
        from ultralytics.utils.downloads import zip_directory

        file = zip_directory('path/to/dir')
        ```
    """
    from zipfile import ZIP_DEFLATED, ZIP_STORED, ZipFile

    delete_dsstore(directory)
    directory = Path(directory)
    if not directory.is_dir():
        raise FileNotFoundError(f"Directory '{directory}' does not exist.")

    # Zip with progress bar
    files_to_zip = [f for f in directory.rglob("*") if f.is_file() and all(x not in f.name for x in exclude)]
    zip_file = directory.with_suffix(".zip")
    compression = ZIP_DEFLATED if compress else ZIP_STORED
    with ZipFile(zip_file, "w", compression) as f:
        for file in TQDM(files_to_zip, desc=f"Zipping {directory} to {zip_file}...", unit="file", disable=not progress):
            f.write(file, file.relative_to(directory))

    return zip_file  # return path to zip file


def unzip_file(file, path=None, exclude=(".DS_Store", "__MACOSX"), exist_ok=False, progress=True):
    """
    Unzips a *.zip file to the specified path, excluding files containing strings in the exclude list.

    If the zipfile does not contain a single top-level directory, the function will create a new
    directory with the same name as the zipfile (without the extension) to extract its contents.
    If a path is not provided, the function will use the parent directory of the zipfile as the default path.

    Args:
        file (str): The path to the zipfile to be extracted.
        path (str, optional): The path to extract the zipfile to. Defaults to None.
        exclude (tuple, optional): A tuple of filename strings to be excluded. Defaults to ('.DS_Store', '__MACOSX').
        exist_ok (bool, optional): Whether to overwrite existing contents if they exist. Defaults to False.
        progress (bool, optional): Whether to display a progress bar. Defaults to True.

    Raises:
        BadZipFile: If the provided file does not exist or is not a valid zipfile.

    Returns:
        (Path): The path to the directory where the zipfile was extracted.

    Example:
        ```python
        from ultralytics.utils.downloads import unzip_file

        dir = unzip_file('path/to/file.zip')
        ```
    """
    from zipfile import BadZipFile, ZipFile, is_zipfile

    if not (Path(file).exists() and is_zipfile(file)):
        raise BadZipFile(f"File '{file}' does not exist or is a bad zip file.")
    if path is None:
        path = Path(file).parent  # default path

    # Unzip the file contents
    with ZipFile(file) as zipObj:
        files = [f for f in zipObj.namelist() if all(x not in f for x in exclude)]
        top_level_dirs = {Path(f).parts[0] for f in files}

        if len(top_level_dirs) > 1 or (len(files) > 1 and not files[0].endswith("/")):
            # Zip has multiple files at top level
            path = extract_path = Path(path) / Path(file).stem  # i.e. ../datasets/coco8
        else:
            # Zip has 1 top-level directory
            extract_path = path  # i.e. ../datasets
            path = Path(path) / list(top_level_dirs)[0]  # i.e. ../datasets/coco8

        # Check if destination directory already exists and contains files
        if path.exists() and any(path.iterdir()) and not exist_ok:
            # If it exists and is not empty, return the path without unzipping
            LOGGER.warning(f"WARNING ⚠️ Skipping {file} unzip as destination directory {path} is not empty.")
            return path

        for f in TQDM(files, desc=f"Unzipping {file} to {Path(path).resolve()}...", unit="file", disable=not progress):
            # Ensure the file is within the extract_path to avoid path traversal security vulnerability
            if ".." in Path(f).parts:
                LOGGER.warning(f"Potentially insecure file path: {f}, skipping extraction.")
                continue
            zipObj.extract(f, extract_path)

    return path  # return unzip dir


def check_disk_space(url="https://ultralytics.com/assets/coco128.zip", path=Path.cwd(), sf=1.5, hard=True):
    """
    Check if there is sufficient disk space to download and store a file.

    Args:
        url (str, optional): The URL to the file. Defaults to 'https://ultralytics.com/assets/coco128.zip'.
        path (str | Path, optional): The path or drive to check the available free space on.
        sf (float, optional): Safety factor, the multiplier for the required free space. Defaults to 1.5.
        hard (bool, optional): Whether to throw an error or not on insufficient disk space. Defaults to True.

    Returns:
        (bool): True if there is sufficient disk space, False otherwise.
    """
    try:
        r = requests.head(url)  # response
        assert r.status_code < 400, f"URL error for {url}: {r.status_code} {r.reason}"  # check response
    except Exception:
        return True  # requests issue, default to True

    # Check file size
    gib = 1 << 30  # bytes per GiB
    data = int(r.headers.get("Content-Length", 0)) / gib  # file size (GiB)
    total, used, free = (x / gib for x in shutil.disk_usage(path))  # GiB

    if data * sf < free:
        return True  # sufficient space

    # Insufficient space
    text = (
        f"WARNING ⚠️ Insufficient free disk space {free:.1f} GB < {data * sf:.3f} GB required, "
        f"Please free {data * sf - free:.1f} GB additional disk space and try again."
    )
    if hard:
        raise MemoryError(text)
    LOGGER.warning(text)
    return False
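
# Example (illustrative, not part of the original file) — check_disk_space() is a
# pre-flight guard; with hard=False it only warns on insufficient space:
#
#   ok = check_disk_space("https://ultralytics.com/assets/coco128.zip", hard=False)
#   if ok:
#       ...  # safe to proceed with the download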


def get_google_drive_file_info(link):
    """
    Retrieves the direct download link and filename for a shareable Google Drive file link.

    Args:
        link (str): The shareable link of the Google Drive file.

    Returns:
        (str): Direct download URL for the Google Drive file.
        (str): Original filename of the Google Drive file. If filename extraction fails, returns None.

    Example:
        ```python
        from ultralytics.utils.downloads import get_google_drive_file_info

        link = "https://drive.google.com/file/d/1cqT-cJgANNrhIHCrEufUYhQ4RqiWG_lJ/view?usp=drive_link"
        url, filename = get_google_drive_file_info(link)
        ```
    """
    file_id = link.split("/d/")[1].split("/view")[0]
    drive_url = f"https://drive.google.com/uc?export=download&id={file_id}"
    filename = None

    # Start session
    with requests.Session() as session:
        response = session.get(drive_url, stream=True)
        if "quota exceeded" in str(response.content.lower()):
            raise ConnectionError(
                emojis(
                    f"❌ Google Drive file download quota exceeded. "
                    f"Please try again later or download this file manually at {link}."
                )
            )
        for k, v in response.cookies.items():
            if k.startswith("download_warning"):
                drive_url += f"&confirm={v}"  # v is token
        cd = response.headers.get("content-disposition")
        if cd:
            filename = re.findall('filename="(.+)"', cd)[0]
    return drive_url, filename


def safe_download(
    url,
    file=None,
    dir=None,
    unzip=True,
    delete=False,
    curl=False,
    retry=3,
    min_bytes=1e0,
    exist_ok=False,
    progress=True,
):
    """
    Downloads files from a URL, with options for retrying, unzipping, and deleting the downloaded file.

    Args:
        url (str): The URL of the file to be downloaded.
        file (str, optional): The filename of the downloaded file.
            If not provided, the file will be saved with the same name as the URL.
        dir (str, optional): The directory to save the downloaded file.
            If not provided, the file will be saved in the current working directory.
        unzip (bool, optional): Whether to unzip the downloaded file. Default: True.
        delete (bool, optional): Whether to delete the downloaded file after unzipping. Default: False.
        curl (bool, optional): Whether to use curl command line tool for downloading. Default: False.
        retry (int, optional): The number of times to retry the download in case of failure. Default: 3.
        min_bytes (float, optional): The minimum number of bytes that the downloaded file should have, to be considered
            a successful download. Default: 1E0.
        exist_ok (bool, optional): Whether to overwrite existing contents during unzipping. Defaults to False.
        progress (bool, optional): Whether to display a progress bar during the download. Default: True.

    Example:
        ```python
        from ultralytics.utils.downloads import safe_download

        link = "https://ultralytics.com/assets/bus.jpg"
        path = safe_download(link)
        ```
    """
    gdrive = url.startswith("https://drive.google.com/")  # check if the URL is a Google Drive link
    if gdrive:
        url, file = get_google_drive_file_info(url)

    f = Path(dir or ".") / (file or url2file(url))  # URL converted to filename
    if "://" not in str(url) and Path(url).is_file():  # URL exists ('://' check required in Windows Python<3.10)
        f = Path(url)  # filename
    elif not f.is_file():  # URL and file do not exist
        desc = f"Downloading {url if gdrive else clean_url(url)} to '{f}'"
        LOGGER.info(f"{desc}...")
        f.parent.mkdir(parents=True, exist_ok=True)  # make directory if missing
        check_disk_space(url, path=f.parent)
        for i in range(retry + 1):
            try:
                if curl or i > 0:  # curl download with retry, continue
                    s = "sS" * (not progress)  # silent
                    r = subprocess.run(["curl", "-#", f"-{s}L", url, "-o", f, "--retry", "3", "-C", "-"]).returncode
                    assert r == 0, f"Curl return value {r}"
                else:  # urllib download
                    method = "torch"
                    if method == "torch":
                        torch.hub.download_url_to_file(url, f, progress=progress)
                    else:
                        with request.urlopen(url) as response, TQDM(
                            total=int(response.getheader("Content-Length", 0)),
                            desc=desc,
                            disable=not progress,
                            unit="B",
                            unit_scale=True,
                            unit_divisor=1024,
                        ) as pbar:
                            with open(f, "wb") as f_opened:
                                for data in response:
                                    f_opened.write(data)
                                    pbar.update(len(data))

                if f.exists():
                    if f.stat().st_size > min_bytes:
                        break  # success
                    f.unlink()  # remove partial downloads
            except Exception as e:
                if i == 0 and not is_online():
                    raise ConnectionError(emojis(f"❌ Download failure for {url}. Environment is not online.")) from e
                elif i >= retry:
                    raise ConnectionError(emojis(f"❌ Download failure for {url}. Retry limit reached.")) from e
                LOGGER.warning(f"⚠️ Download failure, retrying {i + 1}/{retry} {url}...")

    if unzip and f.exists() and f.suffix in ("", ".zip", ".tar", ".gz"):
        from zipfile import is_zipfile

        unzip_dir = (dir or f.parent).resolve()  # unzip to dir if provided else unzip in place
        if is_zipfile(f):
            unzip_dir = unzip_file(file=f, path=unzip_dir, exist_ok=exist_ok, progress=progress)  # unzip
        elif f.suffix in (".tar", ".gz"):
            LOGGER.info(f"Unzipping {f} to {unzip_dir}...")
            subprocess.run(["tar", "xf" if f.suffix == ".tar" else "xfz", f, "--directory", unzip_dir], check=True)
        if delete:
            f.unlink()  # remove zip
        return unzip_dir


def get_github_assets(repo="ultralytics/assets", version="latest", retry=False):
    """
    Retrieve the specified version's tag and assets from a GitHub repository. If the version is not specified, the
    function fetches the latest release assets.

    Args:
        repo (str, optional): The GitHub repository in the format 'owner/repo'. Defaults to 'ultralytics/assets'.
        version (str, optional): The release version to fetch assets from. Defaults to 'latest'.
        retry (bool, optional): Flag to retry the request in case of a failure. Defaults to False.

    Returns:
        (tuple): A tuple containing the release tag and a list of asset names.

    Example:
        ```python
        tag, assets = get_github_assets(repo='ultralytics/assets', version='latest')
        ```
    """

    if version != "latest":
        version = f"tags/{version}"  # i.e. tags/v6.2
    url = f"https://api.github.com/repos/{repo}/releases/{version}"
    r = requests.get(url)  # github api
    if r.status_code != 200 and r.reason != "rate limit exceeded" and retry:  # failed and not 403 rate limit exceeded
        r = requests.get(url)  # try again
    if r.status_code != 200:
        LOGGER.warning(f"⚠️ GitHub assets check failure for {url}: {r.status_code} {r.reason}")
        return "", []
    data = r.json()
    return data["tag_name"], [x["name"] for x in data["assets"]]  # tag, assets i.e. ['yolov8n.pt', 'yolov8s.pt', ...]


def attempt_download_asset(file, repo="ultralytics/assets", release="v8.1.0", **kwargs):
    """
    Attempt to download a file from GitHub release assets if it is not found locally. The function checks for the file
    locally first, then tries to download it from the specified GitHub repository release.

    Args:
        file (str | Path): The filename or file path to be downloaded.
        repo (str, optional): The GitHub repository in the format 'owner/repo'. Defaults to 'ultralytics/assets'.
        release (str, optional): The specific release version to be downloaded. Defaults to 'v8.1.0'.
        **kwargs (any): Additional keyword arguments for the download process.

    Returns:
        (str): The path to the downloaded file.

    Example:
        ```python
        file_path = attempt_download_asset('yolov5s.pt', repo='ultralytics/assets', release='latest')
        ```
    """
    from ultralytics.utils import SETTINGS  # scoped for circular import

    # YOLOv3/5u updates
    file = str(file)
    file = checks.check_yolov5u_filename(file)
    file = Path(file.strip().replace("'", ""))
    if file.exists():
        return str(file)
    elif (SETTINGS["weights_dir"] / file).exists():
        return str(SETTINGS["weights_dir"] / file)
    else:
        # URL specified
        name = Path(parse.unquote(str(file))).name  # decode '%2F' to '/' etc.
        download_url = f"https://github.com/{repo}/releases/download"
        if str(file).startswith(("http:/", "https:/")):  # download
            url = str(file).replace(":/", "://")  # Pathlib turns :// -> :/
            file = url2file(name)  # parse authentication https://url.com/file.txt?auth...
            if Path(file).is_file():
                LOGGER.info(f"Found {clean_url(url)} locally at {file}")  # file already exists
            else:
                safe_download(url=url, file=file, min_bytes=1e5, **kwargs)

        elif repo == GITHUB_ASSETS_REPO and name in GITHUB_ASSETS_NAMES:
            safe_download(url=f"{download_url}/{release}/{name}", file=file, min_bytes=1e5, **kwargs)

        else:
            tag, assets = get_github_assets(repo, release)
            if not assets:
                tag, assets = get_github_assets(repo)  # latest release
            if name in assets:
                safe_download(url=f"{download_url}/{tag}/{name}", file=file, min_bytes=1e5, **kwargs)

    return str(file)


def download(url, dir=Path.cwd(), unzip=True, delete=False, curl=False, threads=1, retry=3, exist_ok=False):
    """
    Downloads files from specified URLs to a given directory. Supports concurrent downloads if multiple threads are
    specified.

    Args:
        url (str | list): The URL or list of URLs of the files to be downloaded.
        dir (Path, optional): The directory where the files will be saved. Defaults to the current working directory.
        unzip (bool, optional): Flag to unzip the files after downloading. Defaults to True.
        delete (bool, optional): Flag to delete the zip files after extraction. Defaults to False.
        curl (bool, optional): Flag to use curl for downloading. Defaults to False.
        threads (int, optional): Number of threads to use for concurrent downloads. Defaults to 1.
        retry (int, optional): Number of retries in case of download failure. Defaults to 3.
        exist_ok (bool, optional): Whether to overwrite existing contents during unzipping. Defaults to False.

    Example:
        ```python
        download('https://ultralytics.com/assets/example.zip', dir='path/to/dir', unzip=True)
        ```
    """
    dir = Path(dir)
    dir.mkdir(parents=True, exist_ok=True)  # make directory
    if threads > 1:
        with ThreadPool(threads) as pool:
            pool.map(
                lambda x: safe_download(
                    url=x[0],
                    dir=x[1],
                    unzip=unzip,
                    delete=delete,
                    curl=curl,
                    retry=retry,
                    exist_ok=exist_ok,
                    progress=threads <= 1,
                ),
                zip(url, repeat(dir)),
            )
            pool.close()
            pool.join()
    else:
        for u in [url] if isinstance(url, (str, Path)) else url:
            safe_download(url=u, dir=dir, unzip=unzip, delete=delete, curl=curl, retry=retry, exist_ok=exist_ok)
22
ultralytics/utils/errors.py
Normal file
@ -0,0 +1,22 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

from ultralytics.utils import emojis


class HUBModelError(Exception):
    """
    Custom exception class for handling errors related to model fetching in Ultralytics YOLO.

    This exception is raised when a requested model is not found or cannot be retrieved.
    The message is also processed to include emojis for better user experience.

    Attributes:
        message (str): The error message displayed when the exception is raised.

    Note:
        The message is automatically processed through the 'emojis' function from the 'ultralytics.utils' package.
    """

    def __init__(self, message="Model not found. Please check model URL and try again."):
        """Create an exception for when a model is not found."""
        super().__init__(emojis(message))
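
# Example (illustrative, not part of the original file) — raising and catching
# the exception:
#
#   try:
#       raise HUBModelError()
#   except HUBModelError as e:
#       print(e)  # message is passed through emojis() before display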
188
ultralytics/utils/files.py
Normal file
@ -0,0 +1,188 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

import contextlib
import glob
import os
import shutil
import tempfile
from contextlib import contextmanager
from datetime import datetime
from pathlib import Path


class WorkingDirectory(contextlib.ContextDecorator):
    """Usage: @WorkingDirectory(dir) decorator or 'with WorkingDirectory(dir):' context manager."""

    def __init__(self, new_dir):
        """Sets the working directory to 'new_dir' upon instantiation."""
        self.dir = new_dir  # new dir
        self.cwd = Path.cwd().resolve()  # current dir

    def __enter__(self):
        """Changes the current directory to the specified directory."""
        os.chdir(self.dir)

    def __exit__(self, exc_type, exc_val, exc_tb):  # noqa
        """Restore the current working directory on context exit."""
        os.chdir(self.cwd)


@contextmanager
def spaces_in_path(path):
    """
    Context manager to handle paths with spaces in their names. If a path contains spaces, it replaces them with
    underscores, copies the file/directory to the new path, executes the context code block, then copies the
    file/directory back to its original location.

    Args:
        path (str | Path): The original path.

    Yields:
        (Path): Temporary path with spaces replaced by underscores if spaces were present, otherwise the original path.

    Example:
        ```python
        from ultralytics.utils.files import spaces_in_path

        with spaces_in_path('/path/with spaces') as new_path:
            # Your code here
        ```
    """

    # If path has spaces, replace them with underscores
    if " " in str(path):
        string = isinstance(path, str)  # input type
        path = Path(path)

        # Create a temporary directory and construct the new path
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_path = Path(tmp_dir) / path.name.replace(" ", "_")

            # Copy file/directory
            if path.is_dir():
                # tmp_path.mkdir(parents=True, exist_ok=True)
                shutil.copytree(path, tmp_path)
            elif path.is_file():
                tmp_path.parent.mkdir(parents=True, exist_ok=True)
                shutil.copy2(path, tmp_path)

            try:
                # Yield the temporary path
                yield str(tmp_path) if string else tmp_path

            finally:
                # Copy file/directory back
                if tmp_path.is_dir():
                    shutil.copytree(tmp_path, path, dirs_exist_ok=True)
                elif tmp_path.is_file():
                    shutil.copy2(tmp_path, path)  # Copy back the file

    else:
        # If there are no spaces, just yield the original path
        yield path


def increment_path(path, exist_ok=False, sep="", mkdir=False):
    """
    Increments a file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.

    If the path exists and exist_ok is not set to True, the path will be incremented by appending a number and sep to
    the end of the path. If the path is a file, the file extension will be preserved. If the path is a directory, the
    number will be appended directly to the end of the path. If mkdir is set to True, the path will be created as a
    directory if it does not already exist.

    Args:
        path (str, pathlib.Path): Path to increment.
        exist_ok (bool, optional): If True, the path will not be incremented and returned as-is. Defaults to False.
        sep (str, optional): Separator to use between the path and the incrementation number. Defaults to ''.
        mkdir (bool, optional): Create a directory if it does not exist. Defaults to False.

    Returns:
        (pathlib.Path): Incremented path.
    """
    path = Path(path)  # os-agnostic
    if path.exists() and not exist_ok:
        path, suffix = (path.with_suffix(""), path.suffix) if path.is_file() else (path, "")

        # Method 1
        for n in range(2, 9999):
            p = f"{path}{sep}{n}{suffix}"  # increment path
            if not os.path.exists(p):
                break
        path = Path(p)

    if mkdir:
        path.mkdir(parents=True, exist_ok=True)  # make directory

    return path
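
# Example (illustrative, not part of the original file) — successive calls walk
# the numeric suffixes:
#
#   increment_path("runs/exp")      # -> runs/exp   (if it does not exist yet)
#   increment_path("runs/exp")      # -> runs/exp2  (once runs/exp exists)
#   increment_path("runs/exp.txt")  # -> runs/exp2.txt, extension preserved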


def file_age(path=__file__):
    """Return days since last file update."""
    dt = datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime)  # delta
    return dt.days  # + dt.seconds / 86400  # fractional days


def file_date(path=__file__):
    """Return human-readable file modification date, i.e. '2021-3-26'."""
    t = datetime.fromtimestamp(Path(path).stat().st_mtime)
    return f"{t.year}-{t.month}-{t.day}"


def file_size(path):
    """Return file/dir size (MiB)."""
    if isinstance(path, (str, Path)):
        mb = 1 << 20  # bytes to MiB (1024 ** 2)
        path = Path(path)
        if path.is_file():
            return path.stat().st_size / mb
        elif path.is_dir():
            return sum(f.stat().st_size for f in path.glob("**/*") if f.is_file()) / mb
    return 0.0


def get_latest_run(search_dir="."):
    """Return path to most recent 'last.pt' in /runs (i.e. to --resume from)."""
    last_list = glob.glob(f"{search_dir}/**/last*.pt", recursive=True)
    return max(last_list, key=os.path.getctime) if last_list else ""


def update_models(model_names=("yolov8n.pt",), source_dir=Path("."), update_names=False):
    """
    Updates and re-saves specified YOLO models in an 'updated_models' subdirectory.

    Args:
        model_names (tuple, optional): Model filenames to update, defaults to ("yolov8n.pt").
        source_dir (Path, optional): Directory containing models and target subdirectory, defaults to current directory.
        update_names (bool, optional): Update model names from a data YAML.

    Example:
        ```python
        from ultralytics.utils.files import update_models

        model_names = (f"rtdetr-{size}.pt" for size in "lx")
        update_models(model_names)
        ```
    """
    from ultralytics import YOLO
    from ultralytics.nn.autobackend import default_class_names

    target_dir = source_dir / "updated_models"
    target_dir.mkdir(parents=True, exist_ok=True)  # Ensure target directory exists

    for model_name in model_names:
        model_path = source_dir / model_name
        print(f"Loading model from {model_path}")

        # Load model
        model = YOLO(model_path)
        model.half()
        if update_names:  # update model names from a dataset YAML
            model.model.names = default_class_names("coco8.yaml")

        # Define new save path
        save_path = target_dir / model_name

        # Save model using model.save()
        print(f"Re-saving {model_name} model to {save_path}")
        model.save(save_path, use_dill=False)
407
ultralytics/utils/instance.py
Normal file
@ -0,0 +1,407 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

from collections import abc
from itertools import repeat
from numbers import Number
from typing import List

import numpy as np

from .ops import ltwh2xywh, ltwh2xyxy, xywh2ltwh, xywh2xyxy, xyxy2ltwh, xyxy2xywh


def _ntuple(n):
    """From PyTorch internals."""

    def parse(x):
        """Parse input to a tuple of length n, repeating scalars as needed."""
        return x if isinstance(x, abc.Iterable) else tuple(repeat(x, n))

    return parse


to_2tuple = _ntuple(2)
to_4tuple = _ntuple(4)

# `xyxy` means left top and right bottom
# `xywh` means center x, center y and width, height(YOLO format)
# `ltwh` means left top and width, height(COCO format)
_formats = ["xyxy", "xywh", "ltwh"]

__all__ = ("Bboxes",)  # tuple or list


class Bboxes:
    """
    A class for handling bounding boxes.

    The class supports various bounding box formats like 'xyxy', 'xywh', and 'ltwh'.
    Bounding box data should be provided in numpy arrays.

    Attributes:
        bboxes (numpy.ndarray): The bounding boxes stored in a 2D numpy array.
        format (str): The format of the bounding boxes ('xyxy', 'xywh', or 'ltwh').

    Note:
        This class does not handle normalization or denormalization of bounding boxes.
    """

    def __init__(self, bboxes, format="xyxy") -> None:
        """Initializes the Bboxes class with bounding box data in a specified format."""
        assert format in _formats, f"Invalid bounding box format: {format}, format must be one of {_formats}"
        bboxes = bboxes[None, :] if bboxes.ndim == 1 else bboxes
        assert bboxes.ndim == 2
        assert bboxes.shape[1] == 4
        self.bboxes = bboxes
        self.format = format
        # self.normalized = normalized

    def convert(self, format):
        """Converts bounding box format from one type to another."""
        assert format in _formats, f"Invalid bounding box format: {format}, format must be one of {_formats}"
        if self.format == format:
            return
        elif self.format == "xyxy":
            func = xyxy2xywh if format == "xywh" else xyxy2ltwh
        elif self.format == "xywh":
            func = xywh2xyxy if format == "xyxy" else xywh2ltwh
        else:
            func = ltwh2xyxy if format == "xyxy" else ltwh2xywh
        self.bboxes = func(self.bboxes)
        self.format = format

    def areas(self):
        """Return box areas."""
        self.convert("xyxy")
        return (self.bboxes[:, 2] - self.bboxes[:, 0]) * (self.bboxes[:, 3] - self.bboxes[:, 1])

    # def denormalize(self, w, h):
    #     if not self.normalized:
    #         return
    #     assert (self.bboxes <= 1.0).all()
    #     self.bboxes[:, 0::2] *= w
    #     self.bboxes[:, 1::2] *= h
    #     self.normalized = False
    #
    # def normalize(self, w, h):
    #     if self.normalized:
    #         return
    #     assert (self.bboxes > 1.0).any()
    #     self.bboxes[:, 0::2] /= w
    #     self.bboxes[:, 1::2] /= h
    #     self.normalized = True

    def mul(self, scale):
        """
        Args:
            scale (tuple | list | int): the scale for four coords.
        """
        if isinstance(scale, Number):
            scale = to_4tuple(scale)
        assert isinstance(scale, (tuple, list))
        assert len(scale) == 4
        self.bboxes[:, 0] *= scale[0]
        self.bboxes[:, 1] *= scale[1]
        self.bboxes[:, 2] *= scale[2]
        self.bboxes[:, 3] *= scale[3]

    def add(self, offset):
        """
        Args:
            offset (tuple | list | int): the offset for four coords.
        """
        if isinstance(offset, Number):
            offset = to_4tuple(offset)
        assert isinstance(offset, (tuple, list))
        assert len(offset) == 4
        self.bboxes[:, 0] += offset[0]
        self.bboxes[:, 1] += offset[1]
        self.bboxes[:, 2] += offset[2]
        self.bboxes[:, 3] += offset[3]

    def __len__(self):
        """Return the number of boxes."""
        return len(self.bboxes)

    @classmethod
    def concatenate(cls, boxes_list: List["Bboxes"], axis=0) -> "Bboxes":
        """
        Concatenate a list of Bboxes objects into a single Bboxes object.

        Args:
            boxes_list (List[Bboxes]): A list of Bboxes objects to concatenate.
            axis (int, optional): The axis along which to concatenate the bounding boxes.
                Defaults to 0.

        Returns:
            Bboxes: A new Bboxes object containing the concatenated bounding boxes.

        Note:
            The input should be a list or tuple of Bboxes objects.
        """
        assert isinstance(boxes_list, (list, tuple))
        if not boxes_list:
            return cls(np.empty(0))
        assert all(isinstance(box, Bboxes) for box in boxes_list)

        if len(boxes_list) == 1:
            return boxes_list[0]
        return cls(np.concatenate([b.bboxes for b in boxes_list], axis=axis))

    def __getitem__(self, index) -> "Bboxes":
        """
        Retrieve a specific bounding box or a set of bounding boxes using indexing.

        Args:
            index (int, slice, or np.ndarray): The index, slice, or boolean array to select
                the desired bounding boxes.

        Returns:
            Bboxes: A new Bboxes object containing the selected bounding boxes.

        Raises:
            AssertionError: If the indexed bounding boxes do not form a 2-dimensional matrix.

        Note:
            When using boolean indexing, make sure to provide a boolean array with the same
            length as the number of bounding boxes.
        """
        if isinstance(index, int):
            return Bboxes(self.bboxes[index].view(1, -1))
        b = self.bboxes[index]
        assert b.ndim == 2, f"Indexing on Bboxes with {index} failed to return a matrix!"
        return Bboxes(b)
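
# Example (illustrative, not part of the original file) — converting formats and
# reading areas:
#
#   boxes = Bboxes(np.array([[10.0, 10.0, 30.0, 30.0]]), format="xyxy")
#   boxes.convert("xywh")   # in-place: [[20, 20, 20, 20]]
#   print(boxes.areas())    # [400.]; areas() converts back to xyxy internally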


class Instances:
    """
    Container for bounding boxes, segments, and keypoints of detected objects in an image.

    Attributes:
        _bboxes (Bboxes): Internal object for handling bounding box operations.
        keypoints (ndarray): keypoints(x, y, visible) with shape [N, 17, 3]. Default is None.
        normalized (bool): Flag indicating whether the bounding box coordinates are normalized.
        segments (ndarray): Segments array with shape [N, 1000, 2] after resampling.

    Args:
        bboxes (ndarray): An array of bounding boxes with shape [N, 4].
        segments (list | ndarray, optional): A list or array of object segments. Default is None.
        keypoints (ndarray, optional): An array of keypoints with shape [N, 17, 3]. Default is None.
        bbox_format (str, optional): The format of bounding boxes ('xywh' or 'xyxy'). Default is 'xywh'.
        normalized (bool, optional): Whether the bounding box coordinates are normalized. Default is True.

    Examples:
        ```python
        # Create an Instances object
        instances = Instances(
            bboxes=np.array([[10, 10, 30, 30], [20, 20, 40, 40]]),
            segments=[np.array([[5, 5], [10, 10]]), np.array([[15, 15], [20, 20]])],
            keypoints=np.array([[[5, 5, 1], [10, 10, 1]], [[15, 15, 1], [20, 20, 1]]])
        )
        ```

    Note:
        The bounding box format is either 'xywh' or 'xyxy', and is determined by the `bbox_format` argument.
        This class does not perform input validation, and it assumes the inputs are well-formed.
    """

    def __init__(self, bboxes, segments=None, keypoints=None, bbox_format="xywh", normalized=True) -> None:
        """
        Args:
            bboxes (ndarray): bboxes with shape [N, 4].
            segments (list | ndarray): segments.
            keypoints (ndarray): keypoints(x, y, visible) with shape [N, 17, 3].
        """
        self._bboxes = Bboxes(bboxes=bboxes, format=bbox_format)
        self.keypoints = keypoints
        self.normalized = normalized
        self.segments = segments

    def convert_bbox(self, format):
        """Convert bounding box format."""
        self._bboxes.convert(format=format)

    @property
    def bbox_areas(self):
        """Calculate the area of bounding boxes."""
        return self._bboxes.areas()

    def scale(self, scale_w, scale_h, bbox_only=False):
        """Scales coordinates by the given factors; similar to denormalize() but leaves the normalized flag unchanged."""
        self._bboxes.mul(scale=(scale_w, scale_h, scale_w, scale_h))
        if bbox_only:
            return
        self.segments[..., 0] *= scale_w
        self.segments[..., 1] *= scale_h
        if self.keypoints is not None:
            self.keypoints[..., 0] *= scale_w
            self.keypoints[..., 1] *= scale_h

    def denormalize(self, w, h):
        """Denormalizes boxes, segments, and keypoints from normalized coordinates."""
        if not self.normalized:
            return
        self._bboxes.mul(scale=(w, h, w, h))
        self.segments[..., 0] *= w
        self.segments[..., 1] *= h
        if self.keypoints is not None:
            self.keypoints[..., 0] *= w
            self.keypoints[..., 1] *= h
        self.normalized = False

    def normalize(self, w, h):
        """Normalize bounding boxes, segments, and keypoints to image dimensions."""
        if self.normalized:
            return
        self._bboxes.mul(scale=(1 / w, 1 / h, 1 / w, 1 / h))
        self.segments[..., 0] /= w
        self.segments[..., 1] /= h
        if self.keypoints is not None:
            self.keypoints[..., 0] /= w
            self.keypoints[..., 1] /= h
        self.normalized = True

    def add_padding(self, padw, padh):
        """Handle rect and mosaic situation."""
        assert not self.normalized, "you should add padding with absolute coordinates."
        self._bboxes.add(offset=(padw, padh, padw, padh))
        self.segments[..., 0] += padw
        self.segments[..., 1] += padh
        if self.keypoints is not None:
            self.keypoints[..., 0] += padw
            self.keypoints[..., 1] += padh

    def __getitem__(self, index) -> "Instances":
        """
        Retrieve a specific instance or a set of instances using indexing.

        Args:
            index (int, slice, or np.ndarray): The index, slice, or boolean array to select
                the desired instances.

        Returns:
            Instances: A new Instances object containing the selected bounding boxes,
                segments, and keypoints if present.

        Note:
            When using boolean indexing, make sure to provide a boolean array with the same
            length as the number of instances.
        """
        segments = self.segments[index] if len(self.segments) else self.segments
        keypoints = self.keypoints[index] if self.keypoints is not None else None
        bboxes = self.bboxes[index]
        bbox_format = self._bboxes.format
        return Instances(
            bboxes=bboxes,
            segments=segments,
            keypoints=keypoints,
            bbox_format=bbox_format,
            normalized=self.normalized,
        )

    def flipud(self, h):
        """Flips the coordinates of bounding boxes, segments, and keypoints vertically."""
        if self._bboxes.format == "xyxy":
            y1 = self.bboxes[:, 1].copy()
            y2 = self.bboxes[:, 3].copy()
            self.bboxes[:, 1] = h - y2
            self.bboxes[:, 3] = h - y1
        else:
            self.bboxes[:, 1] = h - self.bboxes[:, 1]
        self.segments[..., 1] = h - self.segments[..., 1]
        if self.keypoints is not None:
            self.keypoints[..., 1] = h - self.keypoints[..., 1]

    def fliplr(self, w):
        """Flips the coordinates of bounding boxes, segments, and keypoints horizontally."""
        if self._bboxes.format == "xyxy":
            x1 = self.bboxes[:, 0].copy()
            x2 = self.bboxes[:, 2].copy()
            self.bboxes[:, 0] = w - x2
            self.bboxes[:, 2] = w - x1
        else:
            self.bboxes[:, 0] = w - self.bboxes[:, 0]
        self.segments[..., 0] = w - self.segments[..., 0]
        if self.keypoints is not None:
            self.keypoints[..., 0] = w - self.keypoints[..., 0]

    def clip(self, w, h):
        """Clips bounding boxes, segments, and keypoints values to stay within image boundaries."""
        ori_format = self._bboxes.format
        self.convert_bbox(format="xyxy")
        self.bboxes[:, [0, 2]] = self.bboxes[:, [0, 2]].clip(0, w)
        self.bboxes[:, [1, 3]] = self.bboxes[:, [1, 3]].clip(0, h)
        if ori_format != "xyxy":
            self.convert_bbox(format=ori_format)
        self.segments[..., 0] = self.segments[..., 0].clip(0, w)
        self.segments[..., 1] = self.segments[..., 1].clip(0, h)
        if self.keypoints is not None:
            self.keypoints[..., 0] = self.keypoints[..., 0].clip(0, w)
            self.keypoints[..., 1] = self.keypoints[..., 1].clip(0, h)

    def remove_zero_area_boxes(self):
        """
        Remove zero-area boxes, i.e. boxes that have zero width or height after clipping.

        Returns a boolean mask of the boxes that were kept.
        """
        good = self.bbox_areas > 0
        if not all(good):
            self._bboxes = self._bboxes[good]
            if len(self.segments):
                self.segments = self.segments[good]
            if self.keypoints is not None:
                self.keypoints = self.keypoints[good]
        return good

    def update(self, bboxes, segments=None, keypoints=None):
        """Updates instance variables."""
        self._bboxes = Bboxes(bboxes, format=self._bboxes.format)
        if segments is not None:
            self.segments = segments
        if keypoints is not None:
            self.keypoints = keypoints

    def __len__(self):
        """Return the length of the instance list."""
        return len(self.bboxes)

    @classmethod
    def concatenate(cls, instances_list: List["Instances"], axis=0) -> "Instances":
        """
        Concatenates a list of Instances objects into a single Instances object.

        Args:
            instances_list (List[Instances]): A list of Instances objects to concatenate.
            axis (int, optional): The axis along which the arrays will be concatenated. Defaults to 0.

        Returns:
            Instances: A new Instances object containing the concatenated bounding boxes,
                segments, and keypoints if present.

        Note:
            The `Instances` objects in the list should have the same properties, such as
            the format of the bounding boxes, whether keypoints are present, and if the
            coordinates are normalized.
        """
        assert isinstance(instances_list, (list, tuple))
        if not instances_list:
            return cls(np.empty(0))
        assert all(isinstance(instance, Instances) for instance in instances_list)

        if len(instances_list) == 1:
            return instances_list[0]

        use_keypoint = instances_list[0].keypoints is not None
        bbox_format = instances_list[0]._bboxes.format
        normalized = instances_list[0].normalized

        cat_boxes = np.concatenate([ins.bboxes for ins in instances_list], axis=axis)
        cat_segments = np.concatenate([b.segments for b in instances_list], axis=axis)
        cat_keypoints = np.concatenate([b.keypoints for b in instances_list], axis=axis) if use_keypoint else None
        return cls(cat_boxes, cat_segments, cat_keypoints, bbox_format, normalized)

    @property
    def bboxes(self):
        """Return bounding boxes."""
        return self._bboxes.bboxes
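
# Example (illustrative, not part of the original file) — a typical augmentation
# round-trip on absolute pixel coordinates:
#
#   inst = Instances(bboxes=np.array([[20.0, 20.0, 10.0, 10.0]]), segments=np.zeros((1, 1000, 2)),
#                    bbox_format="xywh", normalized=False)
#   inst.fliplr(640)          # mirror horizontally within a 640px-wide image
#   inst.clip(640, 480)       # keep coordinates inside the image
#   kept = inst.remove_zero_area_boxes()  # boolean mask of surviving boxes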
727
ultralytics/utils/loss.py
Normal file
@ -0,0 +1,727 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

import torch
import torch.nn as nn
import torch.nn.functional as F

from ultralytics.utils.metrics import OKS_SIGMA
from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh
from ultralytics.utils.tal import RotatedTaskAlignedAssigner, TaskAlignedAssigner, dist2bbox, dist2rbox, make_anchors

from .metrics import bbox_iou, probiou
from .tal import bbox2dist


class VarifocalLoss(nn.Module):
    """
    Varifocal loss by Zhang et al.

    https://arxiv.org/abs/2008.13367.
    """

    def __init__(self):
        """Initialize the VarifocalLoss class."""
        super().__init__()

    @staticmethod
    def forward(pred_score, gt_score, label, alpha=0.75, gamma=2.0):
        """Computes varifocal loss."""
        weight = alpha * pred_score.sigmoid().pow(gamma) * (1 - label) + gt_score * label
        with torch.cuda.amp.autocast(enabled=False):
            loss = (
                (F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction="none") * weight)
                .mean(1)
                .sum()
            )
        return loss
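
# Example (illustrative, not part of the original file) — the loss weights
# positives by their IoU-aware gt_score and negatives by the predicted score:
#
#   pred = torch.randn(2, 80)       # raw logits
#   gt = torch.rand(2, 80)          # IoU-aware targets in [0, 1]
#   lbl = (gt > 0.5).float()        # positive mask
#   loss = VarifocalLoss()(pred, gt, lbl)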


class FocalLoss(nn.Module):
    """Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)."""

    def __init__(self):
        """Initializer for FocalLoss class with no parameters."""
        super().__init__()

    @staticmethod
    def forward(pred, label, gamma=1.5, alpha=0.25):
        """Calculates focal loss, down-weighting well-classified examples via a modulating factor."""
        loss = F.binary_cross_entropy_with_logits(pred, label, reduction="none")
        # p_t = torch.exp(-loss)
        # loss *= self.alpha * (1.000001 - p_t) ** self.gamma  # non-zero power for gradient stability

        # TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
        pred_prob = pred.sigmoid()  # prob from logits
        p_t = label * pred_prob + (1 - label) * (1 - pred_prob)
        modulating_factor = (1.0 - p_t) ** gamma
        loss *= modulating_factor
        if alpha > 0:
            alpha_factor = label * alpha + (1 - label) * (1 - alpha)
            loss *= alpha_factor
        return loss.mean(1).sum()


class BboxLoss(nn.Module):
    """Criterion class for computing training losses during training."""

    def __init__(self, reg_max, use_dfl=False):
        """Initialize the BboxLoss module with regularization maximum and DFL settings."""
        super().__init__()
        self.reg_max = reg_max
        self.use_dfl = use_dfl

    def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
        """IoU loss."""
        weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
        iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True)
        loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum

        # DFL loss
        if self.use_dfl:
            target_ltrb = bbox2dist(anchor_points, target_bboxes, self.reg_max)
            loss_dfl = self._df_loss(pred_dist[fg_mask].view(-1, self.reg_max + 1), target_ltrb[fg_mask]) * weight
            loss_dfl = loss_dfl.sum() / target_scores_sum
        else:
            loss_dfl = torch.tensor(0.0).to(pred_dist.device)

        return loss_iou, loss_dfl

    @staticmethod
    def _df_loss(pred_dist, target):
        """
        Return sum of left and right DFL losses.

        Distribution Focal Loss (DFL) proposed in Generalized Focal Loss
        https://ieeexplore.ieee.org/document/9792391
        """
        tl = target.long()  # target left
        tr = tl + 1  # target right
        wl = tr - target  # weight left
        wr = 1 - wl  # weight right
        return (
            F.cross_entropy(pred_dist, tl.view(-1), reduction="none").view(tl.shape) * wl
            + F.cross_entropy(pred_dist, tr.view(-1), reduction="none").view(tl.shape) * wr
        ).mean(-1, keepdim=True)
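
# Worked example (illustrative, not part of the original file) — DFL treats each
# box side as a distribution over integer bins; a continuous target t = 2.7 is
# split between bins 2 and 3 with weights 0.3 and 0.7:
#
#   tl, tr = 2, 3                # floor(t) and floor(t) + 1
#   wl, wr = 3 - 2.7, 2.7 - 2    # 0.3 and 0.7
#   # loss = 0.3 * CE(pred_dist, 2) + 0.7 * CE(pred_dist, 3)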
||||
|
||||
|
||||
class RotatedBboxLoss(BboxLoss):
|
||||
"""Criterion class for computing training losses during training."""
|
||||
|
||||
def __init__(self, reg_max, use_dfl=False):
|
||||
"""Initialize the BboxLoss module with regularization maximum and DFL settings."""
|
||||
super().__init__(reg_max, use_dfl)
|
||||
|
||||
def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
|
||||
"""IoU loss."""
|
||||
weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
|
||||
iou = probiou(pred_bboxes[fg_mask], target_bboxes[fg_mask])
|
||||
loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum
|
||||
|
||||
# DFL loss
|
||||
if self.use_dfl:
|
||||
target_ltrb = bbox2dist(anchor_points, xywh2xyxy(target_bboxes[..., :4]), self.reg_max)
|
||||
loss_dfl = self._df_loss(pred_dist[fg_mask].view(-1, self.reg_max + 1), target_ltrb[fg_mask]) * weight
|
||||
loss_dfl = loss_dfl.sum() / target_scores_sum
|
||||
else:
|
||||
loss_dfl = torch.tensor(0.0).to(pred_dist.device)
|
||||
|
||||
return loss_iou, loss_dfl
|
||||
|
||||
|
||||
class KeypointLoss(nn.Module):
|
||||
"""Criterion class for computing training losses."""
|
||||
|
||||
def __init__(self, sigmas) -> None:
|
||||
"""Initialize the KeypointLoss class."""
|
||||
super().__init__()
|
||||
self.sigmas = sigmas
|
||||
|
||||
def forward(self, pred_kpts, gt_kpts, kpt_mask, area):
|
||||
"""Calculates keypoint loss factor and Euclidean distance loss for predicted and actual keypoints."""
|
||||
d = (pred_kpts[..., 0] - gt_kpts[..., 0]).pow(2) + (pred_kpts[..., 1] - gt_kpts[..., 1]).pow(2)
|
||||
kpt_loss_factor = kpt_mask.shape[1] / (torch.sum(kpt_mask != 0, dim=1) + 1e-9)
|
||||
# e = d / (2 * (area * self.sigmas) ** 2 + 1e-9) # from formula
|
||||
e = d / ((2 * self.sigmas).pow(2) * (area + 1e-9) * 2) # from cocoeval
|
||||
return (kpt_loss_factor.view(-1, 1) * ((1 - torch.exp(-e)) * kpt_mask)).mean()
|
||||
|
||||
|
||||
class v8DetectionLoss:
|
||||
"""Criterion class for computing training losses."""
|
||||
|
||||
def __init__(self, model, tal_topk=10): # model must be de-paralleled
|
||||
"""Initializes v8DetectionLoss with the model, defining model-related properties and BCE loss function."""
|
||||
device = next(model.parameters()).device # get model device
|
||||
h = model.args # hyperparameters
|
||||
|
||||
m = model.model[-1] # Detect() module
|
||||
self.bce = nn.BCEWithLogitsLoss(reduction="none")
|
||||
self.hyp = h
|
||||
self.stride = m.stride # model strides
|
||||
self.nc = m.nc # number of classes
|
||||
self.no = m.no
|
||||
self.reg_max = m.reg_max
|
||||
self.device = device
|
||||
|
||||
self.use_dfl = m.reg_max > 1
|
||||
|
||||
self.assigner = TaskAlignedAssigner(topk=tal_topk, num_classes=self.nc, alpha=0.5, beta=6.0)
|
||||
self.bbox_loss = BboxLoss(m.reg_max - 1, use_dfl=self.use_dfl).to(device)
|
||||
self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device)
|
||||
|
||||
def preprocess(self, targets, batch_size, scale_tensor):
|
||||
"""Preprocesses the target counts and matches with the input batch size to output a tensor."""
|
||||
if targets.shape[0] == 0:
|
||||
out = torch.zeros(batch_size, 0, 5, device=self.device)
|
||||
else:
|
||||
i = targets[:, 0] # image index
|
||||
_, counts = i.unique(return_counts=True)
|
||||
counts = counts.to(dtype=torch.int32)
|
||||
out = torch.zeros(batch_size, counts.max(), 5, device=self.device)
|
||||
for j in range(batch_size):
|
||||
matches = i == j
|
||||
n = matches.sum()
|
||||
if n:
|
||||
out[j, :n] = targets[matches, 1:]
|
||||
out[..., 1:5] = xywh2xyxy(out[..., 1:5].mul_(scale_tensor))
|
||||
return out
|
||||
|
||||
def bbox_decode(self, anchor_points, pred_dist):
|
||||
"""Decode predicted object bounding box coordinates from anchor points and distribution."""
|
||||
if self.use_dfl:
|
||||
b, a, c = pred_dist.shape # batch, anchors, channels
|
||||
pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
|
||||
# pred_dist = pred_dist.view(b, a, c // 4, 4).transpose(2,3).softmax(3).matmul(self.proj.type(pred_dist.dtype))
|
||||
# pred_dist = (pred_dist.view(b, a, c // 4, 4).softmax(2) * self.proj.type(pred_dist.dtype).view(1, 1, -1, 1)).sum(2)
|
||||
return dist2bbox(pred_dist, anchor_points, xywh=False)
|
||||
|
||||
def __call__(self, preds, batch):
|
||||
"""Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
|
||||
loss = torch.zeros(3, device=self.device) # box, cls, dfl
|
||||
feats = preds[1] if isinstance(preds, tuple) else preds
|
||||
pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
|
||||
(self.reg_max * 4, self.nc), 1
|
||||
)
|
||||
|
||||
pred_scores = pred_scores.permute(0, 2, 1).contiguous()
|
||||
pred_distri = pred_distri.permute(0, 2, 1).contiguous()
|
||||
|
||||
dtype = pred_scores.dtype
|
||||
batch_size = pred_scores.shape[0]
|
||||
imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w)
|
||||
anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
|
||||
|
||||
# Targets
|
||||
targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
|
||||
targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
|
||||
gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
|
||||
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
|
||||
|
||||
# Pboxes
|
||||
pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)
|
||||
|
||||
_, target_bboxes, target_scores, fg_mask, _ = self.assigner(
|
||||
pred_scores.detach().sigmoid(),
|
||||
(pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
|
||||
anchor_points * stride_tensor,
|
||||
gt_labels,
|
||||
gt_bboxes,
|
||||
mask_gt,
|
||||
)
|
||||
|
||||
target_scores_sum = max(target_scores.sum(), 1)
|
||||
|
||||
# Cls loss
|
||||
# loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way
|
||||
loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE
|
||||
|
||||
# Bbox loss
|
||||
if fg_mask.sum():
|
||||
target_bboxes /= stride_tensor
|
||||
loss[0], loss[2] = self.bbox_loss(
|
||||
pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
|
||||
)
|
||||
|
||||
loss[0] *= self.hyp.box # box gain
|
||||
loss[1] *= self.hyp.cls # cls gain
|
||||
loss[2] *= self.hyp.dfl # dfl gain
|
||||
|
||||
return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl)
|
||||
|
||||
|
||||


class v8SegmentationLoss(v8DetectionLoss):
    """Criterion class for computing training losses."""

    def __init__(self, model):  # model must be de-paralleled
        """Initializes the v8SegmentationLoss class, taking a de-paralleled model as argument."""
        super().__init__(model)
        self.overlap = model.args.overlap_mask

    def __call__(self, preds, batch):
        """Calculate and return the loss for the YOLO model."""
        loss = torch.zeros(4, device=self.device)  # box, seg, cls, dfl
        feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
        batch_size, _, mask_h, mask_w = proto.shape  # batch size, number of masks, mask height, mask width
        pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
            (self.reg_max * 4, self.nc), 1
        )

        # B, grids, ..
        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
        pred_distri = pred_distri.permute(0, 2, 1).contiguous()
        pred_masks = pred_masks.permute(0, 2, 1).contiguous()

        dtype = pred_scores.dtype
        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)

        # Targets
        try:
            batch_idx = batch["batch_idx"].view(-1, 1)
            targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
            targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
            gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
            mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
        except RuntimeError as e:
            raise TypeError(
                "ERROR ❌ segment dataset incorrectly formatted or not a segment dataset.\n"
                "This error can occur when incorrectly training a 'segment' model on a 'detect' dataset, "
                "i.e. 'yolo train model=yolov8n-seg.pt data=coco8.yaml'.\nVerify your dataset is a "
                "correctly formatted 'segment' dataset using 'data=coco8-seg.yaml' "
                "as an example.\nSee https://docs.ultralytics.com/datasets/segment/ for help."
            ) from e

        # Pboxes
        pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)

        _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
            pred_scores.detach().sigmoid(),
            (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
            anchor_points * stride_tensor,
            gt_labels,
            gt_bboxes,
            mask_gt,
        )

        target_scores_sum = max(target_scores.sum(), 1)

        # Cls loss
        # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
        loss[2] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE

        if fg_mask.sum():
            # Bbox loss
            loss[0], loss[3] = self.bbox_loss(
                pred_distri,
                pred_bboxes,
                anchor_points,
                target_bboxes / stride_tensor,
                target_scores,
                target_scores_sum,
                fg_mask,
            )
            # Masks loss
            masks = batch["masks"].to(self.device).float()
            if tuple(masks.shape[-2:]) != (mask_h, mask_w):  # downsample
                masks = F.interpolate(masks[None], (mask_h, mask_w), mode="nearest")[0]

            loss[1] = self.calculate_segmentation_loss(
                fg_mask, masks, target_gt_idx, target_bboxes, batch_idx, proto, pred_masks, imgsz, self.overlap
            )

        # WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
        else:
            loss[1] += (proto * 0).sum() + (pred_masks * 0).sum()  # inf sums may lead to nan loss

        loss[0] *= self.hyp.box  # box gain
        loss[1] *= self.hyp.box  # seg gain
        loss[2] *= self.hyp.cls  # cls gain
        loss[3] *= self.hyp.dfl  # dfl gain

        return loss.sum() * batch_size, loss.detach()  # loss(box, seg, cls, dfl)

    @staticmethod
    def single_mask_loss(
        gt_mask: torch.Tensor, pred: torch.Tensor, proto: torch.Tensor, xyxy: torch.Tensor, area: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute the instance segmentation loss for a single image.

        Args:
            gt_mask (torch.Tensor): Ground truth mask of shape (n, H, W), where n is the number of objects.
            pred (torch.Tensor): Predicted mask coefficients of shape (n, 32).
            proto (torch.Tensor): Prototype masks of shape (32, H, W).
            xyxy (torch.Tensor): Ground truth bounding boxes in xyxy format, normalized to [0, 1], of shape (n, 4).
            area (torch.Tensor): Area of each ground truth bounding box of shape (n,).

        Returns:
            (torch.Tensor): The calculated mask loss for a single image.

        Notes:
            The function uses the equation pred_mask = torch.einsum('in,nhw->ihw', pred, proto) to produce the
            predicted masks from the prototype masks and predicted mask coefficients.
        """
        pred_mask = torch.einsum("in,nhw->ihw", pred, proto)  # (n, 32) @ (32, 80, 80) -> (n, 80, 80)
        loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction="none")
        return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).sum()
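
    # Illustrative note (added): the einsum mixes n coefficient vectors with the 32 shared
    # prototypes; shapes below are hypothetical:
    #   >>> pred, proto = torch.rand(5, 32), torch.rand(32, 160, 160)
    #   >>> torch.einsum("in,nhw->ihw", pred, proto).shape
    #   torch.Size([5, 160, 160])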

    def calculate_segmentation_loss(
        self,
        fg_mask: torch.Tensor,
        masks: torch.Tensor,
        target_gt_idx: torch.Tensor,
        target_bboxes: torch.Tensor,
        batch_idx: torch.Tensor,
        proto: torch.Tensor,
        pred_masks: torch.Tensor,
        imgsz: torch.Tensor,
        overlap: bool,
    ) -> torch.Tensor:
        """
        Calculate the loss for instance segmentation.

        Args:
            fg_mask (torch.Tensor): A binary tensor of shape (BS, N_anchors) indicating which anchors are positive.
            masks (torch.Tensor): Ground truth masks of shape (BS, H, W) if `overlap` is False, otherwise (BS, ?, H, W).
            target_gt_idx (torch.Tensor): Indexes of ground truth objects for each anchor of shape (BS, N_anchors).
            target_bboxes (torch.Tensor): Ground truth bounding boxes for each anchor of shape (BS, N_anchors, 4).
            batch_idx (torch.Tensor): Batch indices of shape (N_labels_in_batch, 1).
            proto (torch.Tensor): Prototype masks of shape (BS, 32, H, W).
            pred_masks (torch.Tensor): Predicted masks for each anchor of shape (BS, N_anchors, 32).
            imgsz (torch.Tensor): Size of the input image as a tensor of shape (2), i.e., (H, W).
            overlap (bool): Whether the masks in `masks` tensor overlap.

        Returns:
            (torch.Tensor): The calculated loss for instance segmentation.

        Notes:
            The batch loss can be computed for improved speed at higher memory usage.
            For example, pred_mask can be computed as follows:
                pred_mask = torch.einsum('in,nhw->ihw', pred, proto)  # (i, 32) @ (32, 160, 160) -> (i, 160, 160)
        """
        _, _, mask_h, mask_w = proto.shape
        loss = 0

        # Normalize to 0-1
        target_bboxes_normalized = target_bboxes / imgsz[[1, 0, 1, 0]]

        # Areas of target bboxes
        marea = xyxy2xywh(target_bboxes_normalized)[..., 2:].prod(2)

        # Normalize to mask size
        mxyxy = target_bboxes_normalized * torch.tensor([mask_w, mask_h, mask_w, mask_h], device=proto.device)

        for i, single_i in enumerate(zip(fg_mask, target_gt_idx, pred_masks, proto, mxyxy, marea, masks)):
            fg_mask_i, target_gt_idx_i, pred_masks_i, proto_i, mxyxy_i, marea_i, masks_i = single_i
            if fg_mask_i.any():
                mask_idx = target_gt_idx_i[fg_mask_i]
                if overlap:
                    gt_mask = masks_i == (mask_idx + 1).view(-1, 1, 1)
                    gt_mask = gt_mask.float()
                else:
                    gt_mask = masks[batch_idx.view(-1) == i][mask_idx]

                loss += self.single_mask_loss(
                    gt_mask, pred_masks_i[fg_mask_i], proto_i, mxyxy_i[fg_mask_i], marea_i[fg_mask_i]
                )

            # WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
            else:
                loss += (proto * 0).sum() + (pred_masks * 0).sum()  # inf sums may lead to nan loss

        return loss / fg_mask.sum()


class v8PoseLoss(v8DetectionLoss):
    """Criterion class for computing training losses."""

    def __init__(self, model):  # model must be de-paralleled
        """Initializes v8PoseLoss with model, sets keypoint variables and declares a keypoint loss instance."""
        super().__init__(model)
        self.kpt_shape = model.model[-1].kpt_shape
        self.bce_pose = nn.BCEWithLogitsLoss()
        is_pose = self.kpt_shape == [17, 3]
        nkpt = self.kpt_shape[0]  # number of keypoints
        sigmas = torch.from_numpy(OKS_SIGMA).to(self.device) if is_pose else torch.ones(nkpt, device=self.device) / nkpt
        self.keypoint_loss = KeypointLoss(sigmas=sigmas)

    def __call__(self, preds, batch):
        """Calculate the total loss and detach it."""
        loss = torch.zeros(5, device=self.device)  # box, pose, kobj, cls, dfl
        feats, pred_kpts = preds if isinstance(preds[0], list) else preds[1]
        pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
            (self.reg_max * 4, self.nc), 1
        )

        # B, grids, ..
        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
        pred_distri = pred_distri.permute(0, 2, 1).contiguous()
        pred_kpts = pred_kpts.permute(0, 2, 1).contiguous()

        dtype = pred_scores.dtype
        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)

        # Targets
        batch_size = pred_scores.shape[0]
        batch_idx = batch["batch_idx"].view(-1, 1)
        targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
        targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
        gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
        mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)

        # Pboxes
        pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)
        pred_kpts = self.kpts_decode(anchor_points, pred_kpts.view(batch_size, -1, *self.kpt_shape))  # (b, h*w, 17, 3)

        _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
            pred_scores.detach().sigmoid(),
            (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
            anchor_points * stride_tensor,
            gt_labels,
            gt_bboxes,
            mask_gt,
        )

        target_scores_sum = max(target_scores.sum(), 1)

        # Cls loss
        # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
        loss[3] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE

        # Bbox loss
        if fg_mask.sum():
            target_bboxes /= stride_tensor
            loss[0], loss[4] = self.bbox_loss(
                pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
            )
            keypoints = batch["keypoints"].to(self.device).float().clone()
            keypoints[..., 0] *= imgsz[1]
            keypoints[..., 1] *= imgsz[0]

            loss[1], loss[2] = self.calculate_keypoints_loss(
                fg_mask, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes, pred_kpts
            )

        loss[0] *= self.hyp.box  # box gain
        loss[1] *= self.hyp.pose  # pose gain
        loss[2] *= self.hyp.kobj  # kobj gain
        loss[3] *= self.hyp.cls  # cls gain
        loss[4] *= self.hyp.dfl  # dfl gain

        return loss.sum() * batch_size, loss.detach()  # loss(box, pose, kobj, cls, dfl)

    @staticmethod
    def kpts_decode(anchor_points, pred_kpts):
        """Decodes predicted keypoints to image coordinates."""
        y = pred_kpts.clone()
        y[..., :2] *= 2.0
        y[..., 0] += anchor_points[:, [0]] - 0.5
        y[..., 1] += anchor_points[:, [1]] - 0.5
        return y
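
    # Illustrative note (added): a raw x-prediction of 0.5 at anchor x=10.5 decodes to
    # 0.5 * 2.0 + (10.5 - 0.5) = 11.0 grid units; the factor 2 lets keypoints reach beyond
    # the anchor's own cell. For example (hypothetical values):
    #   >>> v8PoseLoss.kpts_decode(torch.tensor([[10.5, 20.5]]), torch.tensor([[[0.5, 0.5, 1.0]]]))
    # returns [[[11.0, 21.0, 1.0]]] with the visibility channel untouched.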

    def calculate_keypoints_loss(
        self, masks, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes, pred_kpts
    ):
        """
        Calculate the keypoints loss for the model.

        This function calculates the keypoints loss and keypoints object loss for a given batch. The keypoints loss is
        based on the difference between the predicted keypoints and ground truth keypoints. The keypoints object loss
        is a binary classification loss that classifies whether a keypoint is present or not.

        Args:
            masks (torch.Tensor): Binary mask tensor indicating object presence, shape (BS, N_anchors).
            target_gt_idx (torch.Tensor): Index tensor mapping anchors to ground truth objects, shape (BS, N_anchors).
            keypoints (torch.Tensor): Ground truth keypoints, shape (N_kpts_in_batch, N_kpts_per_object, kpts_dim).
            batch_idx (torch.Tensor): Batch index tensor for keypoints, shape (N_kpts_in_batch, 1).
            stride_tensor (torch.Tensor): Stride tensor for anchors, shape (N_anchors, 1).
            target_bboxes (torch.Tensor): Ground truth boxes in (x1, y1, x2, y2) format, shape (BS, N_anchors, 4).
            pred_kpts (torch.Tensor): Predicted keypoints, shape (BS, N_anchors, N_kpts_per_object, kpts_dim).

        Returns:
            (tuple): Returns a tuple containing:
                - kpts_loss (torch.Tensor): The keypoints loss.
                - kpts_obj_loss (torch.Tensor): The keypoints object loss.
        """
        batch_idx = batch_idx.flatten()
        batch_size = len(masks)

        # Find the maximum number of keypoints in a single image
        max_kpts = torch.unique(batch_idx, return_counts=True)[1].max()

        # Create a tensor to hold batched keypoints
        batched_keypoints = torch.zeros(
            (batch_size, max_kpts, keypoints.shape[1], keypoints.shape[2]), device=keypoints.device
        )

        # TODO: any idea how to vectorize this?
        # Fill batched_keypoints with keypoints based on batch_idx
        for i in range(batch_size):
            keypoints_i = keypoints[batch_idx == i]
            batched_keypoints[i, : keypoints_i.shape[0]] = keypoints_i

        # Expand dimensions of target_gt_idx to match the shape of batched_keypoints
        target_gt_idx_expanded = target_gt_idx.unsqueeze(-1).unsqueeze(-1)

        # Use target_gt_idx_expanded to select keypoints from batched_keypoints
        selected_keypoints = batched_keypoints.gather(
            1, target_gt_idx_expanded.expand(-1, -1, keypoints.shape[1], keypoints.shape[2])
        )

        # Divide coordinates by stride
        selected_keypoints /= stride_tensor.view(1, -1, 1, 1)

        kpts_loss = 0
        kpts_obj_loss = 0

        if masks.any():
            gt_kpt = selected_keypoints[masks]
            area = xyxy2xywh(target_bboxes[masks])[:, 2:].prod(1, keepdim=True)
            pred_kpt = pred_kpts[masks]
            kpt_mask = gt_kpt[..., 2] != 0 if gt_kpt.shape[-1] == 3 else torch.full_like(gt_kpt[..., 0], True)
            kpts_loss = self.keypoint_loss(pred_kpt, gt_kpt, kpt_mask, area)  # pose loss

            if pred_kpt.shape[-1] == 3:
                kpts_obj_loss = self.bce_pose(pred_kpt[..., 2], kpt_mask.float())  # keypoint obj loss

        return kpts_loss, kpts_obj_loss


class v8ClassificationLoss:
    """Criterion class for computing training losses."""

    def __call__(self, preds, batch):
        """Compute the classification loss between predictions and true labels."""
        loss = torch.nn.functional.cross_entropy(preds, batch["cls"], reduction="mean")
        loss_items = loss.detach()
        return loss, loss_items
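
# Example usage (illustrative; shapes are hypothetical — 8 images over 10 classes):
#   >>> criterion = v8ClassificationLoss()
#   >>> preds, batch = torch.randn(8, 10), {"cls": torch.randint(0, 10, (8,))}
#   >>> loss, loss_items = criterion(preds, batch)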


class v8OBBLoss(v8DetectionLoss):
    """Criterion class for computing training losses for rotated (oriented) bounding boxes."""

    def __init__(self, model):
        """
        Initializes v8OBBLoss with model, assigner, and rotated bbox loss.

        Note model must be de-paralleled.
        """
        super().__init__(model)
        self.assigner = RotatedTaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
        self.bbox_loss = RotatedBboxLoss(self.reg_max - 1, use_dfl=self.use_dfl).to(self.device)

    def preprocess(self, targets, batch_size, scale_tensor):
        """Preprocesses the target counts and matches with the input batch size to output a tensor."""
        if targets.shape[0] == 0:
            out = torch.zeros(batch_size, 0, 6, device=self.device)
        else:
            i = targets[:, 0]  # image index
            _, counts = i.unique(return_counts=True)
            counts = counts.to(dtype=torch.int32)
            out = torch.zeros(batch_size, counts.max(), 6, device=self.device)
            for j in range(batch_size):
                matches = i == j
                n = matches.sum()
                if n:
                    bboxes = targets[matches, 2:]
                    bboxes[..., :4].mul_(scale_tensor)
                    out[j, :n] = torch.cat([targets[matches, 1:2], bboxes], dim=-1)
        return out

    def __call__(self, preds, batch):
        """Calculate and return the loss for the YOLO model."""
        loss = torch.zeros(3, device=self.device)  # box, cls, dfl
        feats, pred_angle = preds if isinstance(preds[0], list) else preds[1]
        batch_size = pred_angle.shape[0]  # batch size
        pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
            (self.reg_max * 4, self.nc), 1
        )

        # b, grids, ..
        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
        pred_distri = pred_distri.permute(0, 2, 1).contiguous()
        pred_angle = pred_angle.permute(0, 2, 1).contiguous()

        dtype = pred_scores.dtype
        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)

        # targets
        try:
            batch_idx = batch["batch_idx"].view(-1, 1)
            targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"].view(-1, 5)), 1)
            rw, rh = targets[:, 4] * imgsz[0].item(), targets[:, 5] * imgsz[1].item()
            targets = targets[(rw >= 2) & (rh >= 2)]  # filter rboxes of tiny size to stabilize training
            targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
            gt_labels, gt_bboxes = targets.split((1, 5), 2)  # cls, xywhr
            mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
        except RuntimeError as e:
            raise TypeError(
                "ERROR ❌ OBB dataset incorrectly formatted or not an OBB dataset.\n"
                "This error can occur when incorrectly training an 'OBB' model on a 'detect' dataset, "
                "i.e. 'yolo train model=yolov8n-obb.pt data=dota8.yaml'.\nVerify your dataset is a "
                "correctly formatted 'OBB' dataset using 'data=dota8.yaml' "
                "as an example.\nSee https://docs.ultralytics.com/datasets/obb/ for help."
            ) from e

        # Pboxes
        pred_bboxes = self.bbox_decode(anchor_points, pred_distri, pred_angle)  # xywhr, (b, h*w, 5)

        bboxes_for_assigner = pred_bboxes.clone().detach()
        # Only the first four elements need to be scaled
        bboxes_for_assigner[..., :4] *= stride_tensor
        _, target_bboxes, target_scores, fg_mask, _ = self.assigner(
            pred_scores.detach().sigmoid(),
            bboxes_for_assigner.type(gt_bboxes.dtype),
            anchor_points * stride_tensor,
            gt_labels,
            gt_bboxes,
            mask_gt,
        )

        target_scores_sum = max(target_scores.sum(), 1)

        # Cls loss
        # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
        loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE

        # Bbox loss
        if fg_mask.sum():
            target_bboxes[..., :4] /= stride_tensor
            loss[0], loss[2] = self.bbox_loss(
                pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
            )
        else:
            loss[0] += (pred_angle * 0).sum()

        loss[0] *= self.hyp.box  # box gain
        loss[1] *= self.hyp.cls  # cls gain
        loss[2] *= self.hyp.dfl  # dfl gain

        return loss.sum() * batch_size, loss.detach()  # loss(box, cls, dfl)

    def bbox_decode(self, anchor_points, pred_dist, pred_angle):
        """
        Decode predicted object bounding box coordinates from anchor points and distribution.

        Args:
            anchor_points (torch.Tensor): Anchor points, (h*w, 2).
            pred_dist (torch.Tensor): Predicted rotated distance, (bs, h*w, 4).
            pred_angle (torch.Tensor): Predicted angle, (bs, h*w, 1).

        Returns:
            (torch.Tensor): Predicted rotated bounding boxes with angles, (bs, h*w, 5).
        """
        if self.use_dfl:
            b, a, c = pred_dist.shape  # batch, anchors, channels
            pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
        return torch.cat((dist2rbox(pred_dist, pred_angle, anchor_points), pred_angle), dim=-1)


class v10DetectLoss:
    """Criterion class computing the combined one-to-many and one-to-one training losses for YOLOv10."""

    def __init__(self, model):
        """Initializes v10DetectLoss with a one-to-many (tal_topk=10) and a one-to-one (tal_topk=1) detection loss."""
        self.one2many = v8DetectionLoss(model, tal_topk=10)
        self.one2one = v8DetectionLoss(model, tal_topk=1)

    def __call__(self, preds, batch):
        """Calculate the sum of the one-to-many and one-to-one losses and return their detached items."""
        one2many = preds["one2many"]
        loss_one2many = self.one2many(one2many, batch)
        one2one = preds["one2one"]
        loss_one2one = self.one2one(one2one, batch)
        return loss_one2many[0] + loss_one2one[0], torch.cat((loss_one2many[1], loss_one2one[1]))
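
# Illustrative note (added): YOLOv10 trains two heads jointly, so `preds` is expected to be a
# dict, e.g. {"one2many": p1, "one2one": p2}; the one-to-many branch (topk=10) provides rich
# supervision while the one-to-one branch (topk=1) enables NMS-free inference. The returned
# loss items concatenate both (box, cls, dfl) triplets into a 6-element tensor.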
1292
ultralytics/utils/metrics.py
Normal file
File diff suppressed because it is too large
Load Diff
865
ultralytics/utils/ops.py
Normal file
@ -0,0 +1,865 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

import contextlib
import math
import re
import time

import cv2
import numpy as np
import torch
import torch.nn.functional as F
import torchvision

from ultralytics.utils import LOGGER
from ultralytics.utils.metrics import batch_probiou


class Profile(contextlib.ContextDecorator):
    """
    YOLOv8 Profile class. Use as a decorator with @Profile() or as a context manager with 'with Profile():'.

    Example:
        ```python
        from ultralytics.utils.ops import Profile

        with Profile(device=device) as dt:
            pass  # slow operation here

        print(dt)  # prints "Elapsed time is 9.5367431640625e-07 s"
        ```
    """

    def __init__(self, t=0.0, device: torch.device = None):
        """
        Initialize the Profile class.

        Args:
            t (float): Initial time. Defaults to 0.0.
            device (torch.device): Device used for model inference. Defaults to None (cpu).
        """
        self.t = t
        self.device = device
        self.cuda = bool(device and str(device).startswith("cuda"))

    def __enter__(self):
        """Start timing."""
        self.start = self.time()
        return self

    def __exit__(self, type, value, traceback):  # noqa
        """Stop timing."""
        self.dt = self.time() - self.start  # delta-time
        self.t += self.dt  # accumulate dt

    def __str__(self):
        """Returns a human-readable string representing the accumulated elapsed time in the profiler."""
        return f"Elapsed time is {self.t} s"

    def time(self):
        """Get current time."""
        if self.cuda:
            torch.cuda.synchronize(self.device)
        return time.time()


def segment2box(segment, width=640, height=640):
    """
    Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy).

    Args:
        segment (torch.Tensor): the segment label
        width (int): the width of the image. Defaults to 640
        height (int): The height of the image. Defaults to 640

    Returns:
        (np.ndarray): the minimum and maximum x and y values of the segment.
    """
    x, y = segment.T  # segment xy
    inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
    x = x[inside]
    y = y[inside]
    return (
        np.array([x.min(), y.min(), x.max(), y.max()], dtype=segment.dtype)
        if any(x)
        else np.zeros(4, dtype=segment.dtype)
    )  # xyxy


def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding=True, xywh=False):
    """
    Rescales bounding boxes (in the format of xyxy by default) from the shape of the image they were originally
    specified in (img1_shape) to the shape of a different image (img0_shape).

    Args:
        img1_shape (tuple): The shape of the image that the bounding boxes are for, in the format of (height, width).
        boxes (torch.Tensor): the bounding boxes of the objects in the image, in the format of (x1, y1, x2, y2)
        img0_shape (tuple): the shape of the target image, in the format of (height, width).
        ratio_pad (tuple): a tuple of (ratio, pad) for scaling the boxes. If not provided, the ratio and pad will be
            calculated based on the size difference between the two images.
        padding (bool): If True, assume the boxes are based on an image augmented with YOLO-style letterbox padding.
            If False, do regular rescaling.
        xywh (bool): The box format is xywh or not, default=False.

    Returns:
        boxes (torch.Tensor): The scaled bounding boxes, in the format of (x1, y1, x2, y2)
    """
    if ratio_pad is None:  # calculate from img0_shape
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
        pad = (
            round((img1_shape[1] - img0_shape[1] * gain) / 2 - 0.1),
            round((img1_shape[0] - img0_shape[0] * gain) / 2 - 0.1),
        )  # wh padding
    else:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]

    if padding:
        boxes[..., 0] -= pad[0]  # x padding
        boxes[..., 1] -= pad[1]  # y padding
        if not xywh:
            boxes[..., 2] -= pad[0]  # x padding
            boxes[..., 3] -= pad[1]  # y padding
    boxes[..., :4] /= gain
    return clip_boxes(boxes, img0_shape)


def make_divisible(x, divisor):
    """
    Returns the nearest number that is divisible by the given divisor.

    Args:
        x (int): The number to make divisible.
        divisor (int | torch.Tensor): The divisor.

    Returns:
        (int): The nearest number divisible by the divisor.
    """
    if isinstance(divisor, torch.Tensor):
        divisor = int(divisor.max())  # to int
    return math.ceil(x / divisor) * divisor
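

# Examples (illustrative): channel counts are rounded up to the nearest multiple:
#   >>> make_divisible(97, 32)
#   128
#   >>> make_divisible(64, 32)
#   64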


def nms_rotated(boxes, scores, threshold=0.45):
    """
    NMS for oriented bounding boxes, powered by probiou and fast-nms.

    Args:
        boxes (torch.Tensor): (N, 5), xywhr.
        scores (torch.Tensor): (N, ).
        threshold (float): IoU threshold.

    Returns:
        (torch.Tensor): Indices of the boxes to keep after NMS.
    """
    if len(boxes) == 0:
        return np.empty((0,), dtype=np.int8)
    sorted_idx = torch.argsort(scores, descending=True)
    boxes = boxes[sorted_idx]
    ious = batch_probiou(boxes, boxes).triu_(diagonal=1)
    pick = torch.nonzero(ious.max(dim=0)[0] < threshold).squeeze_(-1)
    return sorted_idx[pick]
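

# Illustrative note (added): `triu_(diagonal=1)` keeps each pairwise IoU only against
# higher-scoring boxes, so `ious.max(dim=0)[0] < threshold` implements Fast-NMS: a box
# survives iff no higher-scoring box overlaps it above the threshold.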


def non_max_suppression(
    prediction,
    conf_thres=0.25,
    iou_thres=0.45,
    classes=None,
    agnostic=False,
    multi_label=False,
    labels=(),
    max_det=300,
    nc=0,  # number of classes (optional)
    max_time_img=0.05,
    max_nms=30000,
    max_wh=7680,
    in_place=True,
    rotated=False,
):
    """
    Perform non-maximum suppression (NMS) on a set of boxes, with support for masks and multiple labels per box.

    Args:
        prediction (torch.Tensor): A tensor of shape (batch_size, num_classes + 4 + num_masks, num_boxes)
            containing the predicted boxes, classes, and masks. The tensor should be in the format
            output by a model, such as YOLO.
        conf_thres (float): The confidence threshold below which boxes will be filtered out.
            Valid values are between 0.0 and 1.0.
        iou_thres (float): The IoU threshold below which boxes will be filtered out during NMS.
            Valid values are between 0.0 and 1.0.
        classes (List[int]): A list of class indices to consider. If None, all classes will be considered.
        agnostic (bool): If True, the model is agnostic to the number of classes, and all
            classes will be considered as one.
        multi_label (bool): If True, each box may have multiple labels.
        labels (List[List[Union[int, float, torch.Tensor]]]): A list of lists, where each inner
            list contains the apriori labels for a given image. The list should be in the format
            output by a dataloader, with each label being a tuple of (class_index, x1, y1, x2, y2).
        max_det (int): The maximum number of boxes to keep after NMS.
        nc (int, optional): The number of classes output by the model. Any indices after this will be considered masks.
        max_time_img (float): The maximum time (seconds) for processing one image.
        max_nms (int): The maximum number of boxes into torchvision.ops.nms().
        max_wh (int): The maximum box width and height in pixels.
        in_place (bool): If True, the input prediction tensor will be modified in place.
        rotated (bool): If True, boxes are oriented (xywhr) and rotated NMS is applied.

    Returns:
        (List[torch.Tensor]): A list of length batch_size, where each element is a tensor of
            shape (num_boxes, 6 + num_masks) containing the kept boxes, with columns
            (x1, y1, x2, y2, confidence, class, mask1, mask2, ...).
    """

    # Checks
    assert 0 <= conf_thres <= 1, f"Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0"
    assert 0 <= iou_thres <= 1, f"Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0"
    if isinstance(prediction, (list, tuple)):  # YOLOv8 model in validation mode, output = (inference_out, loss_out)
        prediction = prediction[0]  # select only inference output

    bs = prediction.shape[0]  # batch size
    nc = nc or (prediction.shape[1] - 4)  # number of classes
    nm = prediction.shape[1] - nc - 4
    mi = 4 + nc  # mask start index
    xc = prediction[:, 4:mi].amax(1) > conf_thres  # candidates

    # Settings
    # min_wh = 2  # (pixels) minimum box width and height
    time_limit = 2.0 + max_time_img * bs  # seconds to quit after
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)

    prediction = prediction.transpose(-1, -2)  # shape(1,84,6300) to shape(1,6300,84)
    if not rotated:
        if in_place:
            prediction[..., :4] = xywh2xyxy(prediction[..., :4])  # xywh to xyxy
        else:
            prediction = torch.cat((xywh2xyxy(prediction[..., :4]), prediction[..., 4:]), dim=-1)  # xywh to xyxy

    t = time.time()
    output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[:, 2:4] < min_wh) | (x[:, 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]) and not rotated:
            lb = labels[xi]
            v = torch.zeros((len(lb), nc + nm + 4), device=x.device)
            v[:, :4] = xywh2xyxy(lb[:, 1:5])  # box
            v[range(len(lb)), lb[:, 0].long() + 4] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Detections matrix nx6 (xyxy, conf, cls)
        box, cls, mask = x.split((4, nc, nm), 1)

        if multi_label:
            i, j = torch.where(cls > conf_thres)
            x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1)
        else:  # best class only
            conf, j = cls.max(1, keepdim=True)
            x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        if n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence and remove excess boxes

        # Batched NMS
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        scores = x[:, 4]  # scores
        if rotated:
            boxes = torch.cat((x[:, :2] + c, x[:, 2:4], x[:, -1:]), dim=-1)  # xywhr
            i = nms_rotated(boxes, scores, iou_thres)
        else:
            boxes = x[:, :4] + c  # boxes (offset by class)
            i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        i = i[:max_det]  # limit detections

        # # Experimental
        # merge = False  # use merge-NMS
        # if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
        #     # Update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
        #     from .metrics import box_iou
        #     iou = box_iou(boxes[i], boxes) > iou_thres  # IoU matrix
        #     weights = iou * scores[None]  # box weights
        #     x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
        #     redundant = True  # require redundant detections
        #     if redundant:
        #         i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            LOGGER.warning(f"WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded")
            break  # time limit exceeded

    return output
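

# Example usage (illustrative): for a raw YOLOv8 output of shape (1, 84, 8400) on COCO (nc=80),
#   >>> det = non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45)[0]
#   >>> boxes, conf, cls = det[:, :4], det[:, 4], det[:, 5]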


def clip_boxes(boxes, shape):
    """
    Takes a list of bounding boxes and a shape (height, width) and clips the bounding boxes to the shape.

    Args:
        boxes (torch.Tensor): the bounding boxes to clip
        shape (tuple): the shape of the image

    Returns:
        (torch.Tensor | numpy.ndarray): Clipped boxes
    """
    if isinstance(boxes, torch.Tensor):  # faster individually (WARNING: inplace .clamp_() Apple MPS bug)
        boxes[..., 0] = boxes[..., 0].clamp(0, shape[1])  # x1
        boxes[..., 1] = boxes[..., 1].clamp(0, shape[0])  # y1
        boxes[..., 2] = boxes[..., 2].clamp(0, shape[1])  # x2
        boxes[..., 3] = boxes[..., 3].clamp(0, shape[0])  # y2
    else:  # np.array (faster grouped)
        boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, shape[1])  # x1, x2
        boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, shape[0])  # y1, y2
    return boxes


def clip_coords(coords, shape):
    """
    Clip line coordinates to the image boundaries.

    Args:
        coords (torch.Tensor | numpy.ndarray): A list of line coordinates.
        shape (tuple): A tuple of integers representing the size of the image in the format (height, width).

    Returns:
        (torch.Tensor | numpy.ndarray): Clipped coordinates
    """
    if isinstance(coords, torch.Tensor):  # faster individually (WARNING: inplace .clamp_() Apple MPS bug)
        coords[..., 0] = coords[..., 0].clamp(0, shape[1])  # x
        coords[..., 1] = coords[..., 1].clamp(0, shape[0])  # y
    else:  # np.array (faster grouped)
        coords[..., 0] = coords[..., 0].clip(0, shape[1])  # x
        coords[..., 1] = coords[..., 1].clip(0, shape[0])  # y
    return coords


def scale_image(masks, im0_shape, ratio_pad=None):
    """
    Takes a mask, and resizes it to the original image size.

    Args:
        masks (np.ndarray): resized and padded masks/images, [h, w, num]/[h, w, 3].
        im0_shape (tuple): the original image shape
        ratio_pad (tuple): the ratio of the padding to the original image.

    Returns:
        masks (np.ndarray): The masks that are being returned.
    """
    # Rescale coordinates (xyxy) from im1_shape to im0_shape
    im1_shape = masks.shape
    if im1_shape[:2] == im0_shape[:2]:
        return masks
    if ratio_pad is None:  # calculate from im0_shape
        gain = min(im1_shape[0] / im0_shape[0], im1_shape[1] / im0_shape[1])  # gain = old / new
        pad = (im1_shape[1] - im0_shape[1] * gain) / 2, (im1_shape[0] - im0_shape[0] * gain) / 2  # wh padding
    else:
        # gain = ratio_pad[0][0]
        pad = ratio_pad[1]
    top, left = int(pad[1]), int(pad[0])  # y, x
    bottom, right = int(im1_shape[0] - pad[1]), int(im1_shape[1] - pad[0])

    if len(masks.shape) < 2:
        raise ValueError(f'"len of masks shape" should be 2 or 3, but got {len(masks.shape)}')
    masks = masks[top:bottom, left:right]
    masks = cv2.resize(masks, (im0_shape[1], im0_shape[0]))
    if len(masks.shape) == 2:
        masks = masks[:, :, None]

    return masks


def xyxy2xywh(x):
    """
    Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height) format where (x1, y1) is the
    top-left corner and (x2, y2) is the bottom-right corner.

    Args:
        x (np.ndarray | torch.Tensor): The input bounding box coordinates in (x1, y1, x2, y2) format.

    Returns:
        y (np.ndarray | torch.Tensor): The bounding box coordinates in (x, y, width, height) format.
    """
    assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
    y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x)  # faster than clone/copy
    y[..., 0] = (x[..., 0] + x[..., 2]) / 2  # x center
    y[..., 1] = (x[..., 1] + x[..., 3]) / 2  # y center
    y[..., 2] = x[..., 2] - x[..., 0]  # width
    y[..., 3] = x[..., 3] - x[..., 1]  # height
    return y


def xywh2xyxy(x):
    """
    Convert bounding box coordinates from (x, y, width, height) format to (x1, y1, x2, y2) format where (x1, y1) is the
    top-left corner and (x2, y2) is the bottom-right corner.

    Args:
        x (np.ndarray | torch.Tensor): The input bounding box coordinates in (x, y, width, height) format.

    Returns:
        y (np.ndarray | torch.Tensor): The bounding box coordinates in (x1, y1, x2, y2) format.
    """
    assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
    y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x)  # faster than clone/copy
    dw = x[..., 2] / 2  # half-width
    dh = x[..., 3] / 2  # half-height
    y[..., 0] = x[..., 0] - dw  # top left x
    y[..., 1] = x[..., 1] - dh  # top left y
    y[..., 2] = x[..., 0] + dw  # bottom right x
    y[..., 3] = x[..., 1] + dh  # bottom right y
    return y
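

# Example (illustrative):
#   >>> xywh2xyxy(torch.tensor([[50.0, 50.0, 20.0, 10.0]]))
#   tensor([[40., 45., 60., 55.]])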


def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
    """
    Convert normalized bounding box coordinates to pixel coordinates.

    Args:
        x (np.ndarray | torch.Tensor): The bounding box coordinates.
        w (int): Width of the image. Defaults to 640
        h (int): Height of the image. Defaults to 640
        padw (int): Padding width. Defaults to 0
        padh (int): Padding height. Defaults to 0

    Returns:
        y (np.ndarray | torch.Tensor): The coordinates of the bounding box in the format [x1, y1, x2, y2] where
            x1,y1 is the top-left corner, x2,y2 is the bottom-right corner of the bounding box.
    """
    assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
    y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x)  # faster than clone/copy
    y[..., 0] = w * (x[..., 0] - x[..., 2] / 2) + padw  # top left x
    y[..., 1] = h * (x[..., 1] - x[..., 3] / 2) + padh  # top left y
    y[..., 2] = w * (x[..., 0] + x[..., 2] / 2) + padw  # bottom right x
    y[..., 3] = h * (x[..., 1] + x[..., 3] / 2) + padh  # bottom right y
    return y


def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
    """
    Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height, normalized) format. x, y,
    width and height are normalized to image dimensions.

    Args:
        x (np.ndarray | torch.Tensor): The input bounding box coordinates in (x1, y1, x2, y2) format.
        w (int): The width of the image. Defaults to 640
        h (int): The height of the image. Defaults to 640
        clip (bool): If True, the boxes will be clipped to the image boundaries. Defaults to False
        eps (float): The minimum value of the box's width and height. Defaults to 0.0

    Returns:
        y (np.ndarray | torch.Tensor): The bounding box coordinates in (x, y, width, height, normalized) format
    """
    if clip:
        x = clip_boxes(x, (h - eps, w - eps))
    assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
    y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x)  # faster than clone/copy
    y[..., 0] = ((x[..., 0] + x[..., 2]) / 2) / w  # x center
    y[..., 1] = ((x[..., 1] + x[..., 3]) / 2) / h  # y center
    y[..., 2] = (x[..., 2] - x[..., 0]) / w  # width
    y[..., 3] = (x[..., 3] - x[..., 1]) / h  # height
    return y


def xywh2ltwh(x):
    """
    Convert the bounding box format from [x, y, w, h] to [x1, y1, w, h], where x1, y1 are the top-left coordinates.

    Args:
        x (np.ndarray | torch.Tensor): The input tensor with the bounding box coordinates in the xywh format

    Returns:
        y (np.ndarray | torch.Tensor): The bounding box coordinates in the xyltwh format
    """
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[..., 0] = x[..., 0] - x[..., 2] / 2  # top left x
    y[..., 1] = x[..., 1] - x[..., 3] / 2  # top left y
    return y


def xyxy2ltwh(x):
    """
    Convert nx4 bounding boxes from [x1, y1, x2, y2] to [x1, y1, w, h], where xy1=top-left, xy2=bottom-right.

    Args:
        x (np.ndarray | torch.Tensor): The input tensor with the bounding boxes coordinates in the xyxy format

    Returns:
        y (np.ndarray | torch.Tensor): The bounding box coordinates in the xyltwh format.
    """
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[..., 2] = x[..., 2] - x[..., 0]  # width
    y[..., 3] = x[..., 3] - x[..., 1]  # height
    return y


def ltwh2xywh(x):
    """
    Convert nx4 boxes from [x1, y1, w, h] to [x, y, w, h] where xy1=top-left, xy=center.

    Args:
        x (torch.Tensor): the input tensor

    Returns:
        y (np.ndarray | torch.Tensor): The bounding box coordinates in the xywh format.
    """
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[..., 0] = x[..., 0] + x[..., 2] / 2  # center x
    y[..., 1] = x[..., 1] + x[..., 3] / 2  # center y
    return y


def xyxyxyxy2xywhr(corners):
    """
    Convert batched Oriented Bounding Boxes (OBB) from [xy1, xy2, xy3, xy4] to [xywh, rotation]. Rotation values are
    expected in degrees from 0 to 90.

    Args:
        corners (numpy.ndarray | torch.Tensor): Input corners of shape (n, 8).

    Returns:
        (numpy.ndarray | torch.Tensor): Converted data in [cx, cy, w, h, rotation] format of shape (n, 5).
    """
    is_torch = isinstance(corners, torch.Tensor)
    points = corners.cpu().numpy() if is_torch else corners
    points = points.reshape(len(corners), -1, 2)
    rboxes = []
    for pts in points:
        # NOTE: Use cv2.minAreaRect to get accurate xywhr,
        # especially when some objects are cut off by augmentations in the dataloader.
        (x, y), (w, h), angle = cv2.minAreaRect(pts)
        rboxes.append([x, y, w, h, angle / 180 * np.pi])
    return (
        torch.tensor(rboxes, device=corners.device, dtype=corners.dtype)
        if is_torch
        else np.asarray(rboxes, dtype=points.dtype)
    )  # rboxes


def xywhr2xyxyxyxy(rboxes):
    """
    Convert batched Oriented Bounding Boxes (OBB) from [xywh, rotation] to [xy1, xy2, xy3, xy4]. Rotation values should
    be in degrees from 0 to 90.

    Args:
        rboxes (numpy.ndarray | torch.Tensor): Boxes in [cx, cy, w, h, rotation] format of shape (n, 5) or (b, n, 5).

    Returns:
        (numpy.ndarray | torch.Tensor): Converted corner points of shape (n, 4, 2) or (b, n, 4, 2).
    """
    is_numpy = isinstance(rboxes, np.ndarray)
    cos, sin = (np.cos, np.sin) if is_numpy else (torch.cos, torch.sin)

    ctr = rboxes[..., :2]
    w, h, angle = (rboxes[..., i : i + 1] for i in range(2, 5))
    cos_value, sin_value = cos(angle), sin(angle)
    vec1 = [w / 2 * cos_value, w / 2 * sin_value]
    vec2 = [-h / 2 * sin_value, h / 2 * cos_value]
    vec1 = np.concatenate(vec1, axis=-1) if is_numpy else torch.cat(vec1, dim=-1)
    vec2 = np.concatenate(vec2, axis=-1) if is_numpy else torch.cat(vec2, dim=-1)
    pt1 = ctr + vec1 + vec2
    pt2 = ctr + vec1 - vec2
    pt3 = ctr - vec1 - vec2
    pt4 = ctr - vec1 + vec2
    return np.stack([pt1, pt2, pt3, pt4], axis=-2) if is_numpy else torch.stack([pt1, pt2, pt3, pt4], dim=-2)
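

# Worked example (illustrative): an axis-aligned box (rotation 0) of width 4 and height 2
# centred at the origin yields corners (2, 1), (2, -1), (-2, -1), (-2, 1):
#   >>> xywhr2xyxyxyxy(torch.tensor([[0.0, 0.0, 4.0, 2.0, 0.0]]))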


def ltwh2xyxy(x):
    """
    It converts the bounding box from [x1, y1, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right.

    Args:
        x (np.ndarray | torch.Tensor): the input boxes

    Returns:
        y (np.ndarray | torch.Tensor): the xyxy coordinates of the bounding boxes.
    """
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[..., 2] = x[..., 2] + x[..., 0]  # bottom right x
    y[..., 3] = x[..., 3] + x[..., 1]  # bottom right y
    return y


def segments2boxes(segments):
    """
    It converts segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh).

    Args:
        segments (list): list of segments, each segment is a list of points, each point is a list of x, y coordinates

    Returns:
        (np.ndarray): the xywh coordinates of the bounding boxes.
    """
    boxes = []
    for s in segments:
        x, y = s.T  # segment xy
        boxes.append([x.min(), y.min(), x.max(), y.max()])  # xyxy
    return xyxy2xywh(np.array(boxes))  # xywh


def resample_segments(segments, n=1000):
    """
    Inputs a list of segments (each an (m, 2) array) and returns the segments up-sampled to n points each.

    Args:
        segments (list): a list of (m, 2) arrays, where m is the number of points in the segment.
        n (int): number of points to resample the segment to. Defaults to 1000

    Returns:
        segments (list): the resampled segments.
    """
    for i, s in enumerate(segments):
        s = np.concatenate((s, s[0:1, :]), axis=0)  # close the polygon
        x = np.linspace(0, len(s) - 1, n)
        xp = np.arange(len(s))
        segments[i] = (
            np.concatenate([np.interp(x, xp, s[:, j]) for j in range(2)], dtype=np.float32).reshape(2, -1).T
        )  # segment xy
    return segments


def crop_mask(masks, boxes):
    """
    It takes a mask and a bounding box, and returns a mask that is cropped to the bounding box.

    Args:
        masks (torch.Tensor): [n, h, w] tensor of masks
        boxes (torch.Tensor): [n, 4] tensor of bbox coordinates in relative point form

    Returns:
        (torch.Tensor): The masks are being cropped to the bounding box.
    """
    _, h, w = masks.shape
    x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1)  # x1 shape(n,1,1)
    r = torch.arange(w, device=masks.device, dtype=x1.dtype)[None, None, :]  # rows shape(1,1,w)
    c = torch.arange(h, device=masks.device, dtype=x1.dtype)[None, :, None]  # cols shape(1,h,1)

    return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2))
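

# Illustrative note (added): r and c broadcast against the (n, 1, 1) box coordinates to
# (n, h, w), so the comparison chain builds one rectangular binary window per box and
# zeroes every mask pixel outside its own bbox without any Python loops.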


def process_mask_upsample(protos, masks_in, bboxes, shape):
    """
    Takes the output of the mask head, and applies the mask to the bounding boxes. This produces masks of higher
    quality but is slower.

    Args:
        protos (torch.Tensor): [mask_dim, mask_h, mask_w]
        masks_in (torch.Tensor): [n, mask_dim], n is number of masks after nms
        bboxes (torch.Tensor): [n, 4], n is number of masks after nms
        shape (tuple): the size of the input image (h,w)

    Returns:
        (torch.Tensor): The upsampled masks.
    """
    c, mh, mw = protos.shape  # CHW
    masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw)
    masks = F.interpolate(masks[None], shape, mode="bilinear", align_corners=False)[0]  # CHW
    masks = crop_mask(masks, bboxes)  # CHW
    return masks.gt_(0.5)


def process_mask(protos, masks_in, bboxes, shape, upsample=False):
    """
    Apply masks to bounding boxes using the output of the mask head.

    Args:
        protos (torch.Tensor): A tensor of shape [mask_dim, mask_h, mask_w].
        masks_in (torch.Tensor): A tensor of shape [n, mask_dim], where n is the number of masks after NMS.
        bboxes (torch.Tensor): A tensor of shape [n, 4], where n is the number of masks after NMS.
        shape (tuple): A tuple of integers representing the size of the input image in the format (h, w).
        upsample (bool): A flag to indicate whether to upsample the mask to the original image size. Default is False.

    Returns:
        (torch.Tensor): A binary mask tensor of shape [n, h, w], where n is the number of masks after NMS, and h and w
            are the height and width of the input image. The mask is applied to the bounding boxes.
    """
    c, mh, mw = protos.shape  # CHW
    ih, iw = shape
    masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw)  # CHW
    width_ratio = mw / iw
    height_ratio = mh / ih

    downsampled_bboxes = bboxes.clone()
    downsampled_bboxes[:, 0] *= width_ratio
    downsampled_bboxes[:, 2] *= width_ratio
    downsampled_bboxes[:, 3] *= height_ratio
    downsampled_bboxes[:, 1] *= height_ratio

    masks = crop_mask(masks, downsampled_bboxes)  # CHW
    if upsample:
        masks = F.interpolate(masks[None], shape, mode="bilinear", align_corners=False)[0]  # CHW
    return masks.gt_(0.5)


def process_mask_native(protos, masks_in, bboxes, shape):
    """
    It takes the output of the mask head, and crops it after upsampling to the bounding boxes.

    Args:
        protos (torch.Tensor): [mask_dim, mask_h, mask_w]
        masks_in (torch.Tensor): [n, mask_dim], n is number of masks after nms
        bboxes (torch.Tensor): [n, 4], n is number of masks after nms
        shape (tuple): the size of the input image (h,w)

    Returns:
        masks (torch.Tensor): The returned masks with dimensions [h, w, n]
    """
    c, mh, mw = protos.shape  # CHW
    masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw)
    masks = scale_masks(masks[None], shape)[0]  # CHW
    masks = crop_mask(masks, bboxes)  # CHW
    return masks.gt_(0.5)


def scale_masks(masks, shape, padding=True):
    """
    Rescale segment masks to shape.

    Args:
        masks (torch.Tensor): (N, C, H, W).
        shape (tuple): Height and width.
        padding (bool): If True, assume the masks are from an image augmented with YOLO-style letterbox padding.
            If False, do regular rescaling.

    Returns:
        (torch.Tensor): The rescaled masks.
    """
    mh, mw = masks.shape[2:]
    gain = min(mh / shape[0], mw / shape[1])  # gain = old / new
    pad = [mw - shape[1] * gain, mh - shape[0] * gain]  # wh padding
    if padding:
        pad[0] /= 2
        pad[1] /= 2
    top, left = (int(pad[1]), int(pad[0])) if padding else (0, 0)  # y, x
    bottom, right = (int(mh - pad[1]), int(mw - pad[0]))
    masks = masks[..., top:bottom, left:right]

    masks = F.interpolate(masks, shape, mode="bilinear", align_corners=False)  # NCHW
    return masks


def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize=False, padding=True):
    """
    Rescale segment coordinates (xy) from img1_shape to img0_shape.

    Args:
        img1_shape (tuple): The shape of the image that the coords are from.
        coords (torch.Tensor): the coords to be scaled of shape n,2.
        img0_shape (tuple): the shape of the image that the segmentation is being applied to.
        ratio_pad (tuple): the ratio of the image size to the padded image size.
        normalize (bool): If True, the coordinates will be normalized to the range [0, 1]. Defaults to False.
        padding (bool): If True, assume the coords are from an image augmented with YOLO-style letterbox padding.
            If False, do regular rescaling.

    Returns:
        coords (torch.Tensor): The scaled coordinates.
    """
    if ratio_pad is None:  # calculate from img0_shape
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
        pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
    else:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]

    if padding:
        coords[..., 0] -= pad[0]  # x padding
        coords[..., 1] -= pad[1]  # y padding
    coords[..., 0] /= gain
    coords[..., 1] /= gain
    coords = clip_coords(coords, img0_shape)
    if normalize:
        coords[..., 0] /= img0_shape[1]  # width
        coords[..., 1] /= img0_shape[0]  # height
    return coords


def regularize_rboxes(rboxes):
    """
    Regularize rotated boxes in range [0, pi/2].

    Args:
        rboxes (torch.Tensor): (N, 5), xywhr.

    Returns:
        (torch.Tensor): The regularized boxes.
    """
    x, y, w, h, t = rboxes.unbind(dim=-1)
    # Swap edge and angle if h >= w
    w_ = torch.where(w > h, w, h)
    h_ = torch.where(w > h, h, w)
    t = torch.where(w > h, t, t + math.pi / 2) % math.pi
    return torch.stack([x, y, w_, h_, t], dim=-1)  # regularized boxes
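

# Example (illustrative): a box taller than it is wide is swapped and its angle shifted by pi/2:
#   >>> regularize_rboxes(torch.tensor([[0.0, 0.0, 2.0, 4.0, 0.0]]))
#   tensor([[0.0000, 0.0000, 4.0000, 2.0000, 1.5708]])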


def masks2segments(masks, strategy="largest"):
    """
    It takes a list of masks(n,h,w) and returns a list of segments(n,xy).

    Args:
        masks (torch.Tensor): the output of the model, which is a tensor of shape (batch_size, 160, 160)
        strategy (str): 'concat' or 'largest'. Defaults to largest

    Returns:
        segments (List): list of segment masks
    """
    segments = []
    for x in masks.int().cpu().numpy().astype("uint8"):
        c = cv2.findContours(x, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
        if c:
            if strategy == "concat":  # concatenate all segments
                c = np.concatenate([x.reshape(-1, 2) for x in c])
            elif strategy == "largest":  # select largest segment
                c = np.array(c[np.array([len(x) for x in c]).argmax()]).reshape(-1, 2)
        else:
            c = np.zeros((0, 2))  # no segments found
        segments.append(c.astype("float32"))
    return segments


def convert_torch2numpy_batch(batch: torch.Tensor) -> np.ndarray:
    """
    Convert a batch of FP32 torch tensors (0.0-1.0) to a NumPy uint8 array (0-255), changing from BCHW to BHWC layout.

    Args:
        batch (torch.Tensor): Input tensor batch of shape (Batch, Channels, Height, Width) and dtype torch.float32.

    Returns:
        (np.ndarray): Output NumPy array batch of shape (Batch, Height, Width, Channels) and dtype uint8.
    """
    return (batch.permute(0, 2, 3, 1).contiguous() * 255).clamp(0, 255).to(torch.uint8).cpu().numpy()


def clean_str(s):
    """
    Cleans a string by replacing special characters with an underscore _

    Args:
        s (str): a string needing special characters replaced

    Returns:
        (str): a string with special characters replaced by an underscore _
    """
    return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
def v10postprocess(preds, max_det, nc=80):
|
||||
assert(4 + nc == preds.shape[-1])
|
||||
boxes, scores = preds.split([4, nc], dim=-1)
|
||||
max_scores = scores.amax(dim=-1)
|
||||
# print(max_scores.shape, max_det)
|
||||
max_scores, index = torch.topk(max_scores, max_det, dim=-1)
|
||||
index = index.unsqueeze(-1)
|
||||
boxes = torch.gather(boxes, dim=1, index=index.repeat(1, 1, boxes.shape[-1]))
|
||||
scores = torch.gather(scores, dim=1, index=index.repeat(1, 1, scores.shape[-1]))
|
||||
|
||||
scores, index = torch.topk(scores.flatten(1), max_det, dim=-1)
|
||||
labels = index % nc
|
||||
index = index // nc
|
||||
boxes = boxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, boxes.shape[-1]))
|
||||
return boxes, scores, labels
|
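A minimal sketch of how `v10postprocess` can be exercised; the batch size, anchor count and `max_det` below are illustrative assumptions, not values taken from this repository.

```python
import torch

# Hypothetical raw head output: batch of 2, 8400 anchor predictions, 4 box coords + 80 class scores
preds = torch.rand(2, 8400, 4 + 80)
boxes, scores, labels = v10postprocess(preds, max_det=300, nc=80)
print(boxes.shape, scores.shape, labels.shape)  # torch.Size([2, 300, 4]) torch.Size([2, 300]) torch.Size([2, 300])
```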
88
ultralytics/utils/patches.py
Normal file
@ -0,0 +1,88 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""Monkey patches to update/extend functionality of existing functions."""

import time
from pathlib import Path

import cv2
import numpy as np
import torch

# OpenCV Multilanguage-friendly functions ------------------------------------------------------------------------------
_imshow = cv2.imshow  # copy to avoid recursion errors


def imread(filename: str, flags: int = cv2.IMREAD_COLOR):
    """
    Read an image from a file.

    Args:
        filename (str): Path to the file to read.
        flags (int, optional): Flag that can take values of cv2.IMREAD_*. Defaults to cv2.IMREAD_COLOR.

    Returns:
        (np.ndarray): The read image.
    """
    return cv2.imdecode(np.fromfile(filename, np.uint8), flags)


def imwrite(filename: str, img: np.ndarray, params=None):
    """
    Write an image to a file.

    Args:
        filename (str): Path to the file to write.
        img (np.ndarray): Image to write.
        params (list of ints, optional): Additional parameters. See OpenCV documentation.

    Returns:
        (bool): True if the file was written, False otherwise.
    """
    try:
        cv2.imencode(Path(filename).suffix, img, params)[1].tofile(filename)
        return True
    except Exception:
        return False


def imshow(winname: str, mat: np.ndarray):
    """
    Display an image in the specified window.

    Args:
        winname (str): Name of the window.
        mat (np.ndarray): Image to be shown.
    """
    _imshow(winname.encode("unicode_escape").decode(), mat)


# PyTorch functions ----------------------------------------------------------------------------------------------------
_torch_save = torch.save  # copy to avoid recursion errors


def torch_save(*args, use_dill=True, **kwargs):
    """
    Optionally use dill to serialize lambda functions where pickle does not, adding robustness with 3 retries and
    exponential backoff in case of save failure.

    Args:
        *args (tuple): Positional arguments to pass to torch.save.
        use_dill (bool): Whether to try using dill for serialization if available. Defaults to True.
        **kwargs (any): Keyword arguments to pass to torch.save.
    """
    try:
        assert use_dill
        import dill as pickle
    except (AssertionError, ImportError):
        import pickle

    if "pickle_module" not in kwargs:
        kwargs["pickle_module"] = pickle

    for i in range(4):  # 3 retries
        try:
            return _torch_save(*args, **kwargs)
        except RuntimeError as e:  # unable to save, possibly waiting for device to flush or antivirus scan
            if i == 3:
                raise e
            time.sleep((2**i) / 2)  # exponential backoff: 0.5s, 1.0s, 2.0s
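A small sketch of why these OpenCV patches exist: they route file IO through `np.fromfile`/`tofile`, which tolerates non-ASCII paths that plain `cv2.imread`/`cv2.imwrite` often cannot handle, notably on Windows. The file name below is an arbitrary illustration.

```python
import numpy as np

img = np.zeros((64, 64, 3), dtype=np.uint8)
ok = imwrite("测试_图像.png", img)  # encodes in memory, then writes raw bytes with tofile()
back = imread("测试_图像.png")      # reads raw bytes with fromfile(), then decodes
print(ok, back.shape)  # True (64, 64, 3)
```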
1047
ultralytics/utils/plotting.py
Normal file
File diff suppressed because it is too large
Load Diff
116
ultralytics/utils/show_trace_pr.py
Normal file
@ -0,0 +1,116 @@
import os

import matplotlib.pyplot as plt
import numpy as np


# [tag, bandage, null]
# [0, 1, 2]
class ShowPR:
    def __init__(self, tags, prec_value, title_name=None):
        self.tags = tags
        self.prec_value = prec_value
        self.thres = [i * 0.01 for i in range(101)]
        self.ratios = [i * 0.05 for i in range(21)]
        self.title_name = title_name

    def change_precValue(self, thre=0.5):
        """Binarize every per-frame score against the threshold 'thre'."""
        values = []
        for i in range(len(self.prec_value)):
            value = []
            for j in range(len(self.prec_value[i])):
                if self.prec_value[i][j] > thre:
                    value.append(1)
                else:
                    value.append(0)
            values.append(value)
        return values

    def _calculate_pr(self, prec_value):
        """Count TP/FP/TN/FN against self.tags and derive precision/recall for both classes."""
        FN, FP, TN, TP = 0, 0, 0, 0
        for output, target in zip(prec_value, self.tags):
            if output != target:
                if target == 0:
                    FP += 1
                elif target == 1:
                    FN += 1
            else:
                if target == 0:
                    TN += 1
                elif target == 1:
                    TP += 1
        if TP == 0:
            prec, recall = 0, 0
        else:
            prec = TP / (TP + FP)
            recall = TP / (TP + FN)
        if TN == 0:
            tn_prec, tn_recall = 0, 0
        else:
            tn_prec = TN / (TN + FN)
            tn_recall = TN / (TN + FP)
        return prec, recall, tn_prec, tn_recall

    def calculate_multiple(self, ratio=0.2):
        """Sweep all thresholds: a track is positive if at least 'ratio' of its frames score above the threshold."""
        recall, recall_TN, PrecisePos, PreciseNeg = [], [], [], []
        for thre in self.thres:
            prec_value = []
            value = self.change_precValue(thre)
            for num in range(len(value)):
                proportion = value[num].count(1) / len(value[num])
                if proportion >= ratio:
                    prec_value.append(1)
                else:
                    prec_value.append(0)
            prec, recall_pos, tn_prec, tn_recall = self._calculate_pr(prec_value)
            print(
                f"thre>>{thre:.2f}, recall>>{recall_pos:.4f}, precise_pos>>{prec:.4f}, "
                f"recall_tn>>{tn_recall:.4f}, precise_neg>>{tn_prec:.4f}")
            PrecisePos.append(prec)
            recall.append(recall_pos)
            PreciseNeg.append(tn_prec)
            recall_TN.append(tn_recall)
        return recall, recall_TN, PrecisePos, PreciseNeg

    def write_results_to_file(self, recall, recall_TN, PrecisePos, PreciseNeg, ratio):
        file_path = os.sep.join(['./ckpts/tracePR', self.title_name + f"_{ratio:.2f}" + '.txt'])
        with open(file_path, 'w') as file:
            file.write("threshold, recall, recall_TN, PrecisePos, PreciseNeg\n")
            for thre, rec, rec_tn, prec_pos, prec_neg in zip(self.thres, recall, recall_TN, PrecisePos, PreciseNeg):
                file.write(
                    f"thre>>{thre:.2f}, recall>>{rec:.4f}, precise_pos>>{prec_pos:.4f}, "
                    f"recall_tn>>{rec_tn:.4f}, precise_neg>>{prec_neg:.4f}\n")

    def show_pr(self, recall, recall_TN, PrecisePos, PreciseNeg, ratio):
        x = self.thres
        plt.figure(figsize=(10, 6))
        plt.plot(x, recall, color='red', label='recall:TP/(TP+FN)')
        plt.plot(x, recall_TN, color='black', label='recall_TN:TN/(TN+FP)')
        plt.plot(x, PrecisePos, color='blue', label='PrecisePos:TP/(TP+FP)')
        plt.plot(x, PreciseNeg, color='green', label='PreciseNeg:TN/(TN+FN)')
        plt.legend()
        plt.xlabel('threshold')
        if self.title_name is not None:
            plt.title(f"PrecisePos & Recall ratio:{ratio:.2f}", fontdict={'fontsize': 12, 'fontweight': 'black'})
        # Enable minor ticks
        plt.minorticks_on()

        # Grid lines on major ticks
        plt.grid(which='major', linestyle='-', alpha=0.5, color='gray')

        # Grid lines on minor ticks
        plt.grid(which='minor', linestyle=':', alpha=0.3, color='gray')
        plt.savefig(os.sep.join(['./ckpts/tracePR', self.title_name + f"_{ratio:.2f}" + '.png']))
        plt.show()
        plt.close()
        self.write_results_to_file(recall, recall_TN, PrecisePos, PreciseNeg, ratio)

    def get_pic(self):
        for ratio in self.ratios:
            if ratio < 0.2 or ratio > 0.95:
                continue
            recall, recall_TN, PrecisePos, PreciseNeg = self.calculate_multiple(ratio)
            self.show_pr(recall, recall_TN, PrecisePos, PreciseNeg, ratio)
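A hedged usage sketch for `ShowPR` with synthetic data; the track count, frame count and the directory handling below are assumptions for illustration (the class writes into `./ckpts/tracePR` but does not create it).

```python
import os
import numpy as np

os.makedirs("./ckpts/tracePR", exist_ok=True)

rng = np.random.default_rng(0)
tags = [int(v) for v in rng.integers(0, 2, size=20)]  # 1 = positive track, 0 = negative track
scores = [list(rng.random(30)) for _ in range(20)]    # 30 per-frame scores per track

pr = ShowPR(tags, scores, title_name="demo")
pr.get_pic()  # sweeps ratios in [0.2, 0.95], saving one PR plot and one .txt per ratio
```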
345
ultralytics/utils/tal.py
Normal file
@ -0,0 +1,345 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

import torch
import torch.nn as nn

from .checks import check_version
from .metrics import bbox_iou, probiou
from .ops import xywhr2xyxyxyxy

TORCH_1_10 = check_version(torch.__version__, "1.10.0")


class TaskAlignedAssigner(nn.Module):
    """
    A task-aligned assigner for object detection.

    This class assigns ground-truth (gt) objects to anchors based on the task-aligned metric, which combines both
    classification and localization information.

    Attributes:
        topk (int): The number of top candidates to consider.
        num_classes (int): The number of object classes.
        alpha (float): The alpha parameter for the classification component of the task-aligned metric.
        beta (float): The beta parameter for the localization component of the task-aligned metric.
        eps (float): A small value to prevent division by zero.
    """

    def __init__(self, topk=13, num_classes=80, alpha=1.0, beta=6.0, eps=1e-9):
        """Initialize a TaskAlignedAssigner object with customizable hyperparameters."""
        super().__init__()
        self.topk = topk
        self.num_classes = num_classes
        self.bg_idx = num_classes
        self.alpha = alpha
        self.beta = beta
        self.eps = eps

    @torch.no_grad()
    def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
        """
        Compute the task-aligned assignment. Reference code is available at
        https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py.

        Args:
            pd_scores (Tensor): shape(bs, num_total_anchors, num_classes)
            pd_bboxes (Tensor): shape(bs, num_total_anchors, 4)
            anc_points (Tensor): shape(num_total_anchors, 2)
            gt_labels (Tensor): shape(bs, n_max_boxes, 1)
            gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
            mask_gt (Tensor): shape(bs, n_max_boxes, 1)

        Returns:
            target_labels (Tensor): shape(bs, num_total_anchors)
            target_bboxes (Tensor): shape(bs, num_total_anchors, 4)
            target_scores (Tensor): shape(bs, num_total_anchors, num_classes)
            fg_mask (Tensor): shape(bs, num_total_anchors)
            target_gt_idx (Tensor): shape(bs, num_total_anchors)
        """
        self.bs = pd_scores.shape[0]
        self.n_max_boxes = gt_bboxes.shape[1]

        if self.n_max_boxes == 0:
            device = gt_bboxes.device
            return (
                torch.full_like(pd_scores[..., 0], self.bg_idx).to(device),
                torch.zeros_like(pd_bboxes).to(device),
                torch.zeros_like(pd_scores).to(device),
                torch.zeros_like(pd_scores[..., 0]).to(device),
                torch.zeros_like(pd_scores[..., 0]).to(device),
            )

        mask_pos, align_metric, overlaps = self.get_pos_mask(
            pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt
        )

        target_gt_idx, fg_mask, mask_pos = self.select_highest_overlaps(mask_pos, overlaps, self.n_max_boxes)

        # Assigned target
        target_labels, target_bboxes, target_scores = self.get_targets(gt_labels, gt_bboxes, target_gt_idx, fg_mask)

        # Normalize
        align_metric *= mask_pos
        pos_align_metrics = align_metric.amax(dim=-1, keepdim=True)  # b, max_num_obj
        pos_overlaps = (overlaps * mask_pos).amax(dim=-1, keepdim=True)  # b, max_num_obj
        norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).amax(-2).unsqueeze(-1)
        target_scores = target_scores * norm_align_metric

        return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx

    def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt):
        """Get in_gts mask, (b, max_num_obj, h*w)."""
        mask_in_gts = self.select_candidates_in_gts(anc_points, gt_bboxes)
        # Get anchor_align metric, (b, max_num_obj, h*w)
        align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts * mask_gt)
        # Get topk_metric mask, (b, max_num_obj, h*w)
        mask_topk = self.select_topk_candidates(align_metric, topk_mask=mask_gt.expand(-1, -1, self.topk).bool())
        # Merge all masks into a final mask, (b, max_num_obj, h*w)
        mask_pos = mask_topk * mask_in_gts * mask_gt

        return mask_pos, align_metric, overlaps

    def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_gt):
        """Compute the alignment metric given predicted and ground truth bounding boxes."""
        na = pd_bboxes.shape[-2]
        mask_gt = mask_gt.bool()  # b, max_num_obj, h*w
        overlaps = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_bboxes.dtype, device=pd_bboxes.device)
        bbox_scores = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_scores.dtype, device=pd_scores.device)

        ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long)  # 2, b, max_num_obj
        ind[0] = torch.arange(end=self.bs).view(-1, 1).expand(-1, self.n_max_boxes)  # b, max_num_obj
        ind[1] = gt_labels.squeeze(-1)  # b, max_num_obj
        # Get the scores of each grid cell for each gt cls
        bbox_scores[mask_gt] = pd_scores[ind[0], :, ind[1]][mask_gt]  # b, max_num_obj, h*w

        # (b, max_num_obj, 1, 4), (b, 1, h*w, 4)
        pd_boxes = pd_bboxes.unsqueeze(1).expand(-1, self.n_max_boxes, -1, -1)[mask_gt]
        gt_boxes = gt_bboxes.unsqueeze(2).expand(-1, -1, na, -1)[mask_gt]
        overlaps[mask_gt] = self.iou_calculation(gt_boxes, pd_boxes)

        align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)
        return align_metric, overlaps

    def iou_calculation(self, gt_bboxes, pd_bboxes):
        """IoU calculation for horizontal bounding boxes."""
        return bbox_iou(gt_bboxes, pd_bboxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0)

    def select_topk_candidates(self, metrics, largest=True, topk_mask=None):
        """
        Select the top-k candidates based on the given metrics.

        Args:
            metrics (Tensor): A tensor of shape (b, max_num_obj, h*w), where b is the batch size,
                              max_num_obj is the maximum number of objects, and h*w represents the
                              total number of anchor points.
            largest (bool): If True, select the largest values; otherwise, select the smallest values.
            topk_mask (Tensor): An optional boolean tensor of shape (b, max_num_obj, topk), where
                                topk is the number of top candidates to consider. If not provided,
                                the top-k values are automatically computed based on the given metrics.

        Returns:
            (Tensor): A tensor of shape (b, max_num_obj, h*w) containing the selected top-k candidates.
        """
        # (b, max_num_obj, topk)
        topk_metrics, topk_idxs = torch.topk(metrics, self.topk, dim=-1, largest=largest)
        if topk_mask is None:
            topk_mask = (topk_metrics.max(-1, keepdim=True)[0] > self.eps).expand_as(topk_idxs)
        # (b, max_num_obj, topk)
        topk_idxs.masked_fill_(~topk_mask, 0)

        # (b, max_num_obj, topk, h*w) -> (b, max_num_obj, h*w)
        count_tensor = torch.zeros(metrics.shape, dtype=torch.int8, device=topk_idxs.device)
        ones = torch.ones_like(topk_idxs[:, :, :1], dtype=torch.int8, device=topk_idxs.device)
        for k in range(self.topk):
            # Expand topk_idxs for each value of k and add 1 at the specified positions
            count_tensor.scatter_add_(-1, topk_idxs[:, :, k : k + 1], ones)
        # Filter invalid bboxes (an anchor counted more than once was selected by a masked index)
        count_tensor.masked_fill_(count_tensor > 1, 0)

        return count_tensor.to(metrics.dtype)

    def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask):
        """
        Compute target labels, target bounding boxes, and target scores for the positive anchor points.

        Args:
            gt_labels (Tensor): Ground truth labels of shape (b, max_num_obj, 1), where b is the
                                batch size and max_num_obj is the maximum number of objects.
            gt_bboxes (Tensor): Ground truth bounding boxes of shape (b, max_num_obj, 4).
            target_gt_idx (Tensor): Indices of the assigned ground truth objects for positive
                                    anchor points, with shape (b, h*w), where h*w is the total
                                    number of anchor points.
            fg_mask (Tensor): A boolean tensor of shape (b, h*w) indicating the positive
                              (foreground) anchor points.

        Returns:
            (Tuple[Tensor, Tensor, Tensor]): A tuple containing the following tensors:
                - target_labels (Tensor): Shape (b, h*w), containing the target labels for positive anchor points.
                - target_bboxes (Tensor): Shape (b, h*w, 4), containing the target bounding boxes for positive
                  anchor points.
                - target_scores (Tensor): Shape (b, h*w, num_classes), containing the target scores for positive
                  anchor points, where num_classes is the number of object classes.
        """
        # Assigned target labels, (b, 1)
        batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]
        target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes  # (b, h*w)
        target_labels = gt_labels.long().flatten()[target_gt_idx]  # (b, h*w)

        # Assigned target boxes, (b, max_num_obj, 4) -> (b, h*w, 4)
        target_bboxes = gt_bboxes.view(-1, gt_bboxes.shape[-1])[target_gt_idx]

        # Assigned target scores
        target_labels.clamp_(0)

        # 10x faster than F.one_hot()
        target_scores = torch.zeros(
            (target_labels.shape[0], target_labels.shape[1], self.num_classes),
            dtype=torch.int64,
            device=target_labels.device,
        )  # (b, h*w, 80)
        target_scores.scatter_(2, target_labels.unsqueeze(-1), 1)

        fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.num_classes)  # (b, h*w, 80)
        target_scores = torch.where(fg_scores_mask > 0, target_scores, 0)

        return target_labels, target_bboxes, target_scores

    @staticmethod
    def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
        """
        Select the positive anchor centers inside gt boxes.

        Args:
            xy_centers (Tensor): shape(h*w, 2)
            gt_bboxes (Tensor): shape(b, n_boxes, 4)

        Returns:
            (Tensor): shape(b, n_boxes, h*w)
        """
        n_anchors = xy_centers.shape[0]
        bs, n_boxes, _ = gt_bboxes.shape
        lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2)  # left-top, right-bottom
        bbox_deltas = torch.cat((xy_centers[None] - lt, rb - xy_centers[None]), dim=2).view(bs, n_boxes, n_anchors, -1)
        return bbox_deltas.amin(3).gt_(eps)

    @staticmethod
    def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
        """
        If an anchor box is assigned to multiple gts, the one with the highest IoU will be selected.

        Args:
            mask_pos (Tensor): shape(b, n_max_boxes, h*w)
            overlaps (Tensor): shape(b, n_max_boxes, h*w)

        Returns:
            target_gt_idx (Tensor): shape(b, h*w)
            fg_mask (Tensor): shape(b, h*w)
            mask_pos (Tensor): shape(b, n_max_boxes, h*w)
        """
        # (b, n_max_boxes, h*w) -> (b, h*w)
        fg_mask = mask_pos.sum(-2)
        if fg_mask.max() > 1:  # one anchor is assigned to multiple gt_bboxes
            mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, n_max_boxes, -1)  # (b, n_max_boxes, h*w)
            max_overlaps_idx = overlaps.argmax(1)  # (b, h*w)

            is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device)
            is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1)

            mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos).float()  # (b, n_max_boxes, h*w)
            fg_mask = mask_pos.sum(-2)
        # Find which gt (index) each grid cell serves
        target_gt_idx = mask_pos.argmax(-2)  # (b, h*w)
        return target_gt_idx, fg_mask, mask_pos


class RotatedTaskAlignedAssigner(TaskAlignedAssigner):
    def iou_calculation(self, gt_bboxes, pd_bboxes):
        """IoU calculation for rotated bounding boxes."""
        return probiou(gt_bboxes, pd_bboxes).squeeze(-1).clamp_(0)

    @staticmethod
    def select_candidates_in_gts(xy_centers, gt_bboxes):
        """
        Select the positive anchor centers inside gt boxes for rotated bounding boxes.

        Args:
            xy_centers (Tensor): shape(h*w, 2)
            gt_bboxes (Tensor): shape(b, n_boxes, 5)

        Returns:
            (Tensor): shape(b, n_boxes, h*w)
        """
        # (b, n_boxes, 5) --> (b, n_boxes, 4, 2)
        corners = xywhr2xyxyxyxy(gt_bboxes)
        # (b, n_boxes, 1, 2)
        a, b, _, d = corners.split(1, dim=-2)
        ab = b - a
        ad = d - a

        # (b, n_boxes, h*w, 2)
        ap = xy_centers - a
        norm_ab = (ab * ab).sum(dim=-1)
        norm_ad = (ad * ad).sum(dim=-1)
        ap_dot_ab = (ap * ab).sum(dim=-1)
        ap_dot_ad = (ap * ad).sum(dim=-1)
        return (ap_dot_ab >= 0) & (ap_dot_ab <= norm_ab) & (ap_dot_ad >= 0) & (ap_dot_ad <= norm_ad)  # is_in_box


def make_anchors(feats, strides, grid_cell_offset=0.5):
    """Generate anchors from features."""
    anchor_points, stride_tensor = [], []
    assert feats is not None
    dtype, device = feats[0].dtype, feats[0].device
    for i, stride in enumerate(strides):
        _, _, h, w = feats[i].shape
        sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset  # shift x
        sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset  # shift y
        sy, sx = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx)
        anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2))
        stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
    return torch.cat(anchor_points), torch.cat(stride_tensor)


def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
    """Transform distance(ltrb) to box(xywh or xyxy)."""
    assert distance.shape[dim] == 4
    lt, rb = distance.split([2, 2], dim)
    x1y1 = anchor_points - lt
    x2y2 = anchor_points + rb
    if xywh:
        c_xy = (x1y1 + x2y2) / 2
        wh = x2y2 - x1y1
        return torch.cat((c_xy, wh), dim)  # xywh bbox
    return torch.cat((x1y1, x2y2), dim)  # xyxy bbox


def bbox2dist(anchor_points, bbox, reg_max):
    """Transform bbox(xyxy) to dist(ltrb)."""
    x1y1, x2y2 = bbox.chunk(2, -1)
    return torch.cat((anchor_points - x1y1, x2y2 - anchor_points), -1).clamp_(0, reg_max - 0.01)  # dist (lt, rb)


def dist2rbox(pred_dist, pred_angle, anchor_points, dim=-1):
    """
    Decode predicted rotated bounding box coordinates from anchor points and distribution.

    Args:
        pred_dist (torch.Tensor): Predicted rotated distance, (bs, h*w, 4).
        pred_angle (torch.Tensor): Predicted angle, (bs, h*w, 1).
        anchor_points (torch.Tensor): Anchor points, (h*w, 2).

    Returns:
        (torch.Tensor): Predicted rotated bounding boxes, (bs, h*w, 4).
    """
    lt, rb = pred_dist.split(2, dim=dim)
    cos, sin = torch.cos(pred_angle), torch.sin(pred_angle)
    # (bs, h*w, 1)
    xf, yf = ((rb - lt) / 2).split(1, dim=dim)
    x, y = xf * cos - yf * sin, xf * sin + yf * cos
    xy = torch.cat([x, y], dim=dim) + anchor_points
    return torch.cat([xy, lt + rb], dim=dim)
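A small sketch of how `make_anchors`, `dist2bbox` and `bbox2dist` fit together; the feature-map shapes and strides below are assumptions for illustration, not values from this repository.

```python
import torch

# Hypothetical FPN feature maps for a 640x640 input at strides 8/16/32
feats = [torch.zeros(1, 64, 640 // s, 640 // s) for s in (8, 16, 32)]
anchor_points, stride_tensor = make_anchors(feats, strides=(8, 16, 32))

# Round trip: ltrb distances -> xyxy boxes -> ltrb distances
dist = torch.rand(1, anchor_points.shape[0], 4) * 4  # distances in grid units, below reg_max
boxes = dist2bbox(dist, anchor_points, xywh=False)
dist2 = bbox2dist(anchor_points, boxes, reg_max=16)
print(torch.allclose(dist, dist2, atol=1e-5))  # True while distances stay below reg_max
```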
610
ultralytics/utils/torch_utils.py
Normal file
@ -0,0 +1,610 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

import math
import os
import random
import time
from contextlib import contextmanager
from copy import deepcopy
from pathlib import Path
from typing import Union

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torchvision

from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, __version__
from ultralytics.utils.checks import PYTHON_VERSION, check_version

try:
    import thop
except ImportError:
    thop = None

# Version checks (all default to version>=min_version)
TORCH_1_9 = check_version(torch.__version__, "1.9.0")
TORCH_1_13 = check_version(torch.__version__, "1.13.0")
TORCH_2_0 = check_version(torch.__version__, "2.0.0")
TORCHVISION_0_10 = check_version(torchvision.__version__, "0.10.0")
TORCHVISION_0_11 = check_version(torchvision.__version__, "0.11.0")
TORCHVISION_0_13 = check_version(torchvision.__version__, "0.13.0")


@contextmanager
def torch_distributed_zero_first(local_rank: int):
    """Context manager that makes all processes in distributed training wait for the local master to do something."""
    initialized = torch.distributed.is_available() and torch.distributed.is_initialized()
    if initialized and local_rank not in (-1, 0):
        dist.barrier(device_ids=[local_rank])
    yield
    if initialized and local_rank == 0:
        dist.barrier(device_ids=[0])


def smart_inference_mode():
    """Applies the torch.inference_mode() decorator if torch>=1.9.0 else the torch.no_grad() decorator."""

    def decorate(fn):
        """Applies the appropriate torch decorator for inference mode based on the torch version."""
        if TORCH_1_9 and torch.is_inference_mode_enabled():
            return fn  # already in inference_mode, act as a pass-through
        else:
            return (torch.inference_mode if TORCH_1_9 else torch.no_grad)()(fn)

    return decorate


def get_cpu_info():
    """Return a string with system CPU information, i.e. 'Apple M2'."""
    import cpuinfo  # pip install py-cpuinfo

    k = "brand_raw", "hardware_raw", "arch_string_raw"  # info keys sorted by preference (not all keys always available)
    info = cpuinfo.get_cpu_info()  # info dict
    string = info.get(k[0] if k[0] in info else k[1] if k[1] in info else k[2], "unknown")
    return string.replace("(R)", "").replace("CPU ", "").replace("@ ", "")


def select_device(device="", batch=0, newline=False, verbose=True):
    """
    Selects the appropriate PyTorch device based on the provided arguments.

    The function takes a string specifying the device or a torch.device object and returns a torch.device object
    representing the selected device. The function also validates the number of available devices and raises an
    exception if the requested device(s) are not available.

    Args:
        device (str | torch.device, optional): Device string or torch.device object.
            Options are 'None', 'cpu', or 'cuda', or '0' or '0,1,2,3'. Defaults to an empty string, which auto-selects
            the first available GPU, or CPU if no GPU is available.
        batch (int, optional): Batch size being used in your model. Defaults to 0.
        newline (bool, optional): If True, adds a newline at the end of the log string. Defaults to False.
        verbose (bool, optional): If True, logs the device information. Defaults to True.

    Returns:
        (torch.device): Selected device.

    Raises:
        ValueError: If the specified device is not available or if the batch size is not a multiple of the number of
            devices when using multiple GPUs.

    Examples:
        >>> select_device('cuda:0')
        device(type='cuda', index=0)

        >>> select_device('cpu')
        device(type='cpu')

    Note:
        Sets the 'CUDA_VISIBLE_DEVICES' environment variable for specifying which GPUs to use.
    """
    if isinstance(device, torch.device):
        return device

    s = f"Ultralytics YOLOv{__version__} 🚀 Python-{PYTHON_VERSION} torch-{torch.__version__} "
    device = str(device).lower()
    for remove in "cuda:", "none", "(", ")", "[", "]", "'", " ":
        device = device.replace(remove, "")  # to string, 'cuda:0' -> '0' and '(0, 1)' -> '0,1'
    cpu = device == "cpu"
    mps = device in ("mps", "mps:0")  # Apple Metal Performance Shaders (MPS)
    if cpu or mps:
        os.environ["CUDA_VISIBLE_DEVICES"] = "-1"  # force torch.cuda.is_available() = False
    elif device:  # non-cpu device requested
        if device == "cuda":
            device = "0"
        visible = os.environ.get("CUDA_VISIBLE_DEVICES", None)
        os.environ["CUDA_VISIBLE_DEVICES"] = device  # set environment variable - must be before assert is_available()
        if not (torch.cuda.is_available() and torch.cuda.device_count() >= len(device.split(","))):
            LOGGER.info(s)
            install = (
                "See https://pytorch.org/get-started/locally/ for up-to-date torch install instructions if no "
                "CUDA devices are seen by torch.\n"
                if torch.cuda.device_count() == 0
                else ""
            )
            raise ValueError(
                f"Invalid CUDA 'device={device}' requested."
                f" Use 'device=cpu' or pass valid CUDA device(s) if available,"
                f" i.e. 'device=0' or 'device=0,1,2,3' for Multi-GPU.\n"
                f"\ntorch.cuda.is_available(): {torch.cuda.is_available()}"
                f"\ntorch.cuda.device_count(): {torch.cuda.device_count()}"
                f"\nos.environ['CUDA_VISIBLE_DEVICES']: {visible}\n"
                f"{install}"
            )

    if not cpu and not mps and torch.cuda.is_available():  # prefer GPU if available
        devices = device.split(",") if device else "0"  # i.e. '0,1,6,7'
        n = len(devices)  # device count
        if n > 1 and batch > 0 and batch % n != 0:  # check batch_size is divisible by device_count
            raise ValueError(
                f"'batch={batch}' must be a multiple of GPU count {n}. Try 'batch={batch // n * n}' or "
                f"'batch={batch // n * n + n}', the nearest batch sizes evenly divisible by {n}."
            )
        space = " " * (len(s) + 1)
        for i, d in enumerate(devices):
            p = torch.cuda.get_device_properties(i)
            s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / (1 << 20):.0f}MiB)\n"  # bytes to MB
        arg = "cuda:0"
    elif mps and TORCH_2_0 and torch.backends.mps.is_available():
        # Prefer MPS if available
        s += f"MPS ({get_cpu_info()})\n"
        arg = "mps"
    else:  # revert to CPU
        s += f"CPU ({get_cpu_info()})\n"
        arg = "cpu"

    if verbose:
        LOGGER.info(s if newline else s.rstrip())
    return torch.device(arg)


def time_sync():
    """PyTorch-accurate time."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()


def fuse_conv_and_bn(conv, bn):
    """Fuse Conv2d() and BatchNorm2d() layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/."""
    fusedconv = (
        nn.Conv2d(
            conv.in_channels,
            conv.out_channels,
            kernel_size=conv.kernel_size,
            stride=conv.stride,
            padding=conv.padding,
            dilation=conv.dilation,
            groups=conv.groups,
            bias=True,
        )
        .requires_grad_(False)
        .to(conv.weight.device)
    )

    # Prepare filters
    w_conv = conv.weight.clone().view(conv.out_channels, -1)
    w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
    fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))

    # Prepare spatial bias
    b_conv = torch.zeros(conv.weight.shape[0], device=conv.weight.device) if conv.bias is None else conv.bias
    b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
    fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)

    return fusedconv
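A quick numerical check of the Conv+BN fusion above; the layer sizes are arbitrary. With the BN layer in eval mode (so it uses running statistics), the fused convolution should reproduce `bn(conv(x))` to within floating-point tolerance.

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(8, 16, 3, padding=1, bias=False)
bn = nn.BatchNorm2d(16)
conv.eval()
bn.eval()  # use running stats, not batch stats

with torch.no_grad():
    x = torch.randn(1, 8, 32, 32)
    fused = fuse_conv_and_bn(conv, bn)
    print(torch.allclose(bn(conv(x)), fused(x), atol=1e-6))  # True
```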
def fuse_deconv_and_bn(deconv, bn):
    """Fuse ConvTranspose2d() and BatchNorm2d() layers."""
    fuseddconv = (
        nn.ConvTranspose2d(
            deconv.in_channels,
            deconv.out_channels,
            kernel_size=deconv.kernel_size,
            stride=deconv.stride,
            padding=deconv.padding,
            output_padding=deconv.output_padding,
            dilation=deconv.dilation,
            groups=deconv.groups,
            bias=True,
        )
        .requires_grad_(False)
        .to(deconv.weight.device)
    )

    # Prepare filters
    w_deconv = deconv.weight.clone().view(deconv.out_channels, -1)
    w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
    fuseddconv.weight.copy_(torch.mm(w_bn, w_deconv).view(fuseddconv.weight.shape))

    # Prepare spatial bias
    b_conv = torch.zeros(deconv.weight.shape[1], device=deconv.weight.device) if deconv.bias is None else deconv.bias
    b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
    fuseddconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)

    return fuseddconv


def model_info(model, detailed=False, verbose=True, imgsz=640):
    """
    Log and return model information.

    imgsz may be an int or a list, i.e. imgsz=640 or imgsz=[640, 320].
    """
    if not verbose:
        return
    n_p = get_num_params(model)  # number of parameters
    n_g = get_num_gradients(model)  # number of gradients
    n_l = len(list(model.modules()))  # number of layers
    if detailed:
        LOGGER.info(
            f"{'layer':>5} {'name':>40} {'gradient':>9} {'parameters':>12} {'shape':>20} {'mu':>10} {'sigma':>10}"
        )
        for i, (name, p) in enumerate(model.named_parameters()):
            name = name.replace("module_list.", "")
            LOGGER.info(
                "%5g %40s %9s %12g %20s %10.3g %10.3g %10s"
                % (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std(), p.dtype)
            )

    flops = get_flops(model, imgsz)
    fused = " (fused)" if getattr(model, "is_fused", lambda: False)() else ""
    fs = f", {flops:.1f} GFLOPs" if flops else ""
    yaml_file = getattr(model, "yaml_file", "") or getattr(model, "yaml", {}).get("yaml_file", "")
    model_name = Path(yaml_file).stem.replace("yolo", "YOLO") or "Model"
    LOGGER.info(f"{model_name} summary{fused}: {n_l} layers, {n_p} parameters, {n_g} gradients{fs}")
    return n_l, n_p, n_g, flops


def get_num_params(model):
    """Return the total number of parameters in a YOLO model."""
    return sum(x.numel() for x in model.parameters())


def get_num_gradients(model):
    """Return the total number of parameters with gradients in a YOLO model."""
    return sum(x.numel() for x in model.parameters() if x.requires_grad)


def model_info_for_loggers(trainer):
    """
    Return a model info dict with useful model information.

    Example:
        YOLOv8n info for loggers
        ```python
        results = {'model/parameters': 3151904,
                   'model/GFLOPs': 8.746,
                   'model/speed_ONNX(ms)': 41.244,
                   'model/speed_TensorRT(ms)': 3.211,
                   'model/speed_PyTorch(ms)': 18.755}
        ```
    """
    if trainer.args.profile:  # profile ONNX and TensorRT times
        from ultralytics.utils.benchmarks import ProfileModels

        results = ProfileModels([trainer.last], device=trainer.device).profile()[0]
        results.pop("model/name")
    else:  # only return PyTorch times from the most recent validation
        results = {
            "model/parameters": get_num_params(trainer.model),
            "model/GFLOPs": round(get_flops(trainer.model), 3),
        }
    results["model/speed_PyTorch(ms)"] = round(trainer.validator.speed["inference"], 3)
    return results


def get_flops(model, imgsz=640):
    """Return a YOLO model's FLOPs."""
    if not thop:
        return 0.0  # if thop is not installed, return 0.0 GFLOPs

    try:
        model = de_parallel(model)
        p = next(model.parameters())
        if not isinstance(imgsz, list):
            imgsz = [imgsz, imgsz]  # expand if int/float
        # Use the actual image size for the input tensor (required for RTDETR models)
        im = torch.empty((1, p.shape[1], *imgsz), device=p.device)  # input image in BCHW format
        return thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1e9 * 2  # imgsz GFLOPs
    except Exception:
        return 0.0


def get_flops_with_torch_profiler(model, imgsz=640):
    """Compute model FLOPs (thop alternative)."""
    if TORCH_2_0:
        model = de_parallel(model)
        p = next(model.parameters())
        stride = (max(int(model.stride.max()), 32) if hasattr(model, "stride") else 32) * 2  # max stride
        im = torch.zeros((1, p.shape[1], stride, stride), device=p.device)  # input image in BCHW format
        with torch.profiler.profile(with_flops=True) as prof:
            model(im)
        flops = sum(x.flops for x in prof.key_averages()) / 1e9
        imgsz = imgsz if isinstance(imgsz, list) else [imgsz, imgsz]  # expand if int/float
        flops = flops * imgsz[0] / stride * imgsz[1] / stride  # imgsz GFLOPs
        return flops
    return 0


def initialize_weights(model):
    """Initialize model weights to random values."""
    for m in model.modules():
        t = type(m)
        if t is nn.Conv2d:
            pass  # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        elif t is nn.BatchNorm2d:
            m.eps = 1e-3
            m.momentum = 0.03
        elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
            m.inplace = True


def scale_img(img, ratio=1.0, same_shape=False, gs=32):
    """Scales and pads an image tensor of shape img(bs,3,y,x) based on the given ratio and grid size gs, optionally
    retaining the original shape.
    """
    if ratio == 1.0:
        return img
    h, w = img.shape[2:]
    s = (int(h * ratio), int(w * ratio))  # new size
    img = F.interpolate(img, size=s, mode="bilinear", align_corners=False)  # resize
    if not same_shape:  # pad/crop img
        h, w = (math.ceil(x * ratio / gs) * gs for x in (h, w))
    return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447)  # value = imagenet mean


def make_divisible(x, divisor):
    """Returns the nearest x divisible by divisor."""
    if isinstance(divisor, torch.Tensor):
        divisor = int(divisor.max())  # to int
    return math.ceil(x / divisor) * divisor


def copy_attr(a, b, include=(), exclude=()):
    """Copies attributes from object 'b' to object 'a', with options to include/exclude certain attributes."""
    for k, v in b.__dict__.items():
        if (len(include) and k not in include) or k.startswith("_") or k in exclude:
            continue
        else:
            setattr(a, k, v)


def get_latest_opset():
    """Return the second-most recent ONNX opset supported by this version of torch (chosen for maturity)."""
    return max(int(k[14:]) for k in vars(torch.onnx) if "symbolic_opset" in k) - 1  # opset


def intersect_dicts(da, db, exclude=()):
    """Returns a dictionary of intersecting keys with matching shapes, excluding 'exclude' keys, using da values."""
    return {k: v for k, v in da.items() if k in db and all(x not in k for x in exclude) and v.shape == db[k].shape}


def is_parallel(model):
    """Returns True if model is of type DP or DDP."""
    return isinstance(model, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel))


def de_parallel(model):
    """De-parallelize a model: returns a single-GPU model if model is of type DP or DDP."""
    return model.module if is_parallel(model) else model


def one_cycle(y1=0.0, y2=1.0, steps=100):
    """Returns a lambda function for a sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf."""
    return lambda x: max((1 - math.cos(x * math.pi / steps)) / 2, 0) * (y2 - y1) + y1
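A tiny sketch evaluating the `one_cycle` cosine ramp, e.g. as a learning-rate schedule factor; the endpoints and step count are arbitrary.

```python
lf = one_cycle(y1=1.0, y2=0.01, steps=100)  # cosine ramp from 1.0 down to 0.01 over 100 steps
print([round(lf(x), 3) for x in (0, 25, 50, 75, 100)])  # [1.0, 0.855, 0.505, 0.155, 0.01]
```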
def init_seeds(seed=0, deterministic=False):
    """Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # for Multi-GPU, exception safe
    # torch.backends.cudnn.benchmark = True  # AutoBatch problem https://github.com/ultralytics/yolov5/issues/9287
    if deterministic:
        if TORCH_2_0:
            torch.use_deterministic_algorithms(True, warn_only=True)  # warn if deterministic is not possible
            torch.backends.cudnn.deterministic = True
            os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
            os.environ["PYTHONHASHSEED"] = str(seed)
        else:
            LOGGER.warning("WARNING ⚠️ Upgrade to torch>=2.0.0 for deterministic training.")
    else:
        torch.use_deterministic_algorithms(False)
        torch.backends.cudnn.deterministic = False


class ModelEMA:
    """
    Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models.

    Keeps a moving average of everything in the model state_dict (parameters and buffers).
    For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage.
    To disable EMA set the `enabled` attribute to `False`.
    """

    def __init__(self, model, decay=0.9999, tau=2000, updates=0):
        """Create EMA."""
        self.ema = deepcopy(de_parallel(model)).eval()  # FP32 EMA
        self.updates = updates  # number of EMA updates
        self.decay = lambda x: decay * (1 - math.exp(-x / tau))  # decay exponential ramp (to help early epochs)
        for p in self.ema.parameters():
            p.requires_grad_(False)
        self.enabled = True

    def update(self, model):
        """Update EMA parameters."""
        if self.enabled:
            self.updates += 1
            d = self.decay(self.updates)

            msd = de_parallel(model).state_dict()  # model state_dict
            for k, v in self.ema.state_dict().items():
                if v.dtype.is_floating_point:  # true for FP16 and FP32
                    v *= d
                    v += (1 - d) * msd[k].detach()

    def update_attr(self, model, include=(), exclude=("process_group", "reducer")):
        """Updates attributes and saves stripped model with optimizer removed."""
        if self.enabled:
            copy_attr(self.ema, model, include, exclude)
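A minimal sketch of `ModelEMA` in a training loop; the stand-in model, optimizer and loss are illustrative assumptions, not the trainer wiring used by this repository.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))  # stand-in model
ema = ModelEMA(model, decay=0.9999, tau=2000)
opt = torch.optim.SGD(model.parameters(), lr=0.01)

for _ in range(10):  # toy training loop
    loss = model(torch.randn(16, 4)).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    ema.update(model)  # EMA weights trail the live weights

# Evaluate or checkpoint with the smoothed copy in ema.ema
print(type(ema.ema))  # deep copy of the model with requires_grad disabled
```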
def strip_optimizer(f: Union[str, Path] = "best_gift_v10n.pt", s: str = "") -> None:
    """
    Strip the optimizer from 'f' to finalize training, optionally saving as 's'.

    Args:
        f (str): file path of the model to strip the optimizer from. Default is 'best_gift_v10n.pt'.
        s (str): file path to save the model with stripped optimizer to. If not provided, 'f' will be overwritten.

    Returns:
        None

    Example:
        ```python
        from pathlib import Path
        from ultralytics.utils.torch_utils import strip_optimizer

        for f in Path('path/to/weights').rglob('*.pt'):
            strip_optimizer(f)
        ```
    """
    x = torch.load(f, map_location=torch.device("cpu"))
    if "model" not in x:
        LOGGER.info(f"Skipping {f}, not a valid Ultralytics model.")
        return

    if hasattr(x["model"], "args"):
        x["model"].args = dict(x["model"].args)  # convert from IterableSimpleNamespace to dict
    args = {**DEFAULT_CFG_DICT, **x["train_args"]} if "train_args" in x else None  # combine args
    if x.get("ema"):
        x["model"] = x["ema"]  # replace model with ema
    for k in "optimizer", "best_fitness", "ema", "updates":  # keys
        x[k] = None
    x["epoch"] = -1
    x["model"].half()  # to FP16
    for p in x["model"].parameters():
        p.requires_grad = False
    x["train_args"] = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS}  # strip non-default keys
    torch.save(x, s or f)
    mb = os.path.getsize(s or f) / 1e6  # file size
    LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")


def profile(input, ops, n=10, device=None):
    """
    Ultralytics speed, memory and FLOPs profiler.

    Example:
        ```python
        from ultralytics.utils.torch_utils import profile

        input = torch.randn(16, 3, 640, 640)
        m1 = lambda x: x * torch.sigmoid(x)
        m2 = nn.SiLU()
        profile(input, [m1, m2], n=100)  # profile over 100 iterations
        ```
    """
    results = []
    if not isinstance(device, torch.device):
        device = select_device(device)
    LOGGER.info(
        f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
        f"{'input':>24s}{'output':>24s}"
    )

    for x in input if isinstance(input, list) else [input]:
        x = x.to(device)
        x.requires_grad = True
        for m in ops if isinstance(ops, list) else [ops]:
            m = m.to(device) if hasattr(m, "to") else m  # device
            m = m.half() if hasattr(m, "half") and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m
            tf, tb, t = 0, 0, [0, 0, 0]  # dt forward, backward
            try:
                flops = thop.profile(m, inputs=[x], verbose=False)[0] / 1e9 * 2 if thop else 0  # GFLOPs
            except Exception:
                flops = 0

            try:
                for _ in range(n):
                    t[0] = time_sync()
                    y = m(x)
                    t[1] = time_sync()
                    try:
                        (sum(yi.sum() for yi in y) if isinstance(y, list) else y).sum().backward()
                        t[2] = time_sync()
                    except Exception:  # no backward method
                        t[2] = float("nan")
                    tf += (t[1] - t[0]) * 1000 / n  # ms per op forward
                    tb += (t[2] - t[1]) * 1000 / n  # ms per op backward
                mem = torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0  # (GB)
                s_in, s_out = (tuple(x.shape) if isinstance(x, torch.Tensor) else "list" for x in (x, y))  # shapes
                p = sum(x.numel() for x in m.parameters()) if isinstance(m, nn.Module) else 0  # parameters
                LOGGER.info(f"{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}")
                results.append([p, flops, mem, tf, tb, s_in, s_out])
            except Exception as e:
                LOGGER.info(e)
                results.append(None)
            torch.cuda.empty_cache()
    return results


class EarlyStopping:
    """Early stopping class that stops training when a specified number of epochs have passed without improvement."""

    def __init__(self, patience=50):
        """
        Initialize the early stopping object.

        Args:
            patience (int, optional): Number of epochs to wait after fitness stops improving before stopping.
        """
        self.best_fitness = 0.0  # i.e. mAP
        self.best_epoch = 0
        self.patience = patience or float("inf")  # epochs to wait after fitness stops improving before stopping
        self.possible_stop = False  # possible stop may occur next epoch

    def __call__(self, epoch, fitness):
        """
        Check whether to stop training.

        Args:
            epoch (int): Current epoch of training
            fitness (float): Fitness value of current epoch

        Returns:
            (bool): True if training should stop, False otherwise
        """
        if fitness is None:  # check if fitness=None (happens when val=False)
            return False

        if fitness >= self.best_fitness:  # >= 0 to allow for early zero-fitness stage of training
            self.best_epoch = epoch
            self.best_fitness = fitness
        delta = epoch - self.best_epoch  # epochs without improvement
        self.possible_stop = delta >= (self.patience - 1)  # possible stop may occur next epoch
        stop = delta >= self.patience  # stop training if patience exceeded
        if stop:
            LOGGER.info(
                f"Stopping training early as no improvement observed in last {self.patience} epochs. "
                f"Best results observed at epoch {self.best_epoch}, best model saved as best_gift_v10n.pt.\n"
                f"To update EarlyStopping(patience={self.patience}) pass a new patience value, "
                f"i.e. `patience=300` or use `patience=0` to disable EarlyStopping."
            )
        return stop
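A hedged sketch of `EarlyStopping` driven by a synthetic fitness curve; the patience and the curve values are arbitrary illustrations.

```python
stopper = EarlyStopping(patience=5)  # stop after 5 epochs with no fitness improvement
fitness_curve = [0.10, 0.20, 0.30, 0.29, 0.28, 0.27, 0.26, 0.25, 0.24]  # peaks at epoch 2

for epoch, fitness in enumerate(fitness_curve):
    if stopper(epoch, fitness):
        print(f"stopped at epoch {epoch}")  # epoch 7: 5 epochs after the best epoch (2)
        break
```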
92
ultralytics/utils/triton.py
Normal file
@ -0,0 +1,92 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

from typing import List
from urllib.parse import urlsplit

import numpy as np


class TritonRemoteModel:
    """
    Client for interacting with a remote Triton Inference Server model.

    Attributes:
        endpoint (str): The name of the model on the Triton server.
        url (str): The URL of the Triton server.
        triton_client: The Triton client (either HTTP or gRPC).
        InferInput: The input class for the Triton client.
        InferRequestedOutput: The output request class for the Triton client.
        input_formats (List[str]): The data types of the model inputs.
        np_input_formats (List[type]): The numpy data types of the model inputs.
        input_names (List[str]): The names of the model inputs.
        output_names (List[str]): The names of the model outputs.
    """

    def __init__(self, url: str, endpoint: str = "", scheme: str = ""):
        """
        Initialize the TritonRemoteModel.

        Arguments may be provided individually or parsed from a collective 'url' argument of the form
        <scheme>://<netloc>/<endpoint>/<task_name>

        Args:
            url (str): The URL of the Triton server.
            endpoint (str): The name of the model on the Triton server.
            scheme (str): The communication scheme ('http' or 'grpc').
        """
        if not endpoint and not scheme:  # Parse all args from URL string
            splits = urlsplit(url)
            endpoint = splits.path.strip("/").split("/")[0]
            scheme = splits.scheme
            url = splits.netloc

        self.endpoint = endpoint
        self.url = url

        # Choose the Triton client based on the communication scheme
        if scheme == "http":
            import tritonclient.http as client  # noqa

            self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False)
            config = self.triton_client.get_model_config(endpoint)
        else:
            import tritonclient.grpc as client  # noqa

            self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False)
            config = self.triton_client.get_model_config(endpoint, as_json=True)["config"]

        # Sort output names alphabetically, i.e. 'output0', 'output1', etc.
        config["output"] = sorted(config["output"], key=lambda x: x.get("name"))

        # Define model attributes
        type_map = {"TYPE_FP32": np.float32, "TYPE_FP16": np.float16, "TYPE_UINT8": np.uint8}
        self.InferRequestedOutput = client.InferRequestedOutput
        self.InferInput = client.InferInput
        self.input_formats = [x["data_type"] for x in config["input"]]
        self.np_input_formats = [type_map[x] for x in self.input_formats]
        self.input_names = [x["name"] for x in config["input"]]
        self.output_names = [x["name"] for x in config["output"]]

    def __call__(self, *inputs: np.ndarray) -> List[np.ndarray]:
        """
        Call the model with the given inputs.

        Args:
            *inputs (List[np.ndarray]): Input data to the model.

        Returns:
            (List[np.ndarray]): Model outputs.
        """
        infer_inputs = []
        input_format = inputs[0].dtype
        for i, x in enumerate(inputs):
            if x.dtype != self.np_input_formats[i]:
                x = x.astype(self.np_input_formats[i])
            infer_input = self.InferInput(self.input_names[i], [*x.shape], self.input_formats[i].replace("TYPE_", ""))
            infer_input.set_data_from_numpy(x)
            infer_inputs.append(infer_input)

        infer_outputs = [self.InferRequestedOutput(output_name) for output_name in self.output_names]
        outputs = self.triton_client.infer(model_name=self.endpoint, inputs=infer_inputs, outputs=infer_outputs)

        return [outputs.as_numpy(output_name).astype(input_format) for output_name in self.output_names]
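A hedged usage sketch for `TritonRemoteModel`; the server URL, model name and input shape are placeholders, and running it requires a reachable Triton instance plus the `tritonclient` package.

```python
import numpy as np

# Hypothetical server and model name; scheme and endpoint are parsed from the URL
model = TritonRemoteModel("http://localhost:8000/yolo_model")
outputs = model(np.random.rand(1, 3, 640, 640).astype(np.float32))
print([o.shape for o in outputs])
```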
148
ultralytics/utils/tuner.py
Normal file
@ -0,0 +1,148 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

import subprocess

from ultralytics.cfg import TASK2DATA, TASK2METRIC, get_save_dir
from ultralytics.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, NUM_THREADS, checks


def run_ray_tune(
    model, space: dict = None, grace_period: int = 10, gpu_per_trial: int = None, max_samples: int = 10, **train_args
):
    """
    Runs hyperparameter tuning using Ray Tune.

    Args:
        model (YOLO): Model to run the tuner on.
        space (dict, optional): The hyperparameter search space. Defaults to None.
        grace_period (int, optional): The grace period in epochs of the ASHA scheduler. Defaults to 10.
        gpu_per_trial (int, optional): The number of GPUs to allocate per trial. Defaults to None.
        max_samples (int, optional): The maximum number of trials to run. Defaults to 10.
        **train_args (any): Additional arguments to pass to the `train()` method. Defaults to {}.

    Returns:
        (dict): A dictionary containing the results of the hyperparameter search.

    Example:
        ```python
        from ultralytics import YOLO

        # Load a YOLOv8n model
        model = YOLO('yolov8n.pt')

        # Start tuning hyperparameters for YOLOv8n training on the COCO8 dataset
        result_grid = model.tune(data='coco8.yaml', use_ray=True)
        ```
    """

    LOGGER.info("💡 Learn about RayTune at https://docs.ultralytics.com/integrations/ray-tune")
    if train_args is None:
        train_args = {}

    try:
        subprocess.run("pip install ray[tune]<=2.9.3".split(), check=True)  # do not add single quotes here

        import ray
        from ray import tune
        from ray.air import RunConfig
        from ray.air.integrations.wandb import WandbLoggerCallback
        from ray.tune.schedulers import ASHAScheduler
    except ImportError:
        raise ModuleNotFoundError('Ray Tune required but not found. To install run: pip install "ray[tune]<=2.9.3"')

    try:
        import wandb

        assert hasattr(wandb, "__version__")
    except (ImportError, AssertionError):
        wandb = False

    checks.check_version(ray.__version__, "<=2.9.3", "ray")
    default_space = {
        # 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),
        "lr0": tune.uniform(1e-5, 1e-1),
        "lrf": tune.uniform(0.01, 1.0),  # final OneCycleLR learning rate (lr0 * lrf)
        "momentum": tune.uniform(0.6, 0.98),  # SGD momentum/Adam beta1
        "weight_decay": tune.uniform(0.0, 0.001),  # optimizer weight decay 5e-4
        "warmup_epochs": tune.uniform(0.0, 5.0),  # warmup epochs (fractions ok)
        "warmup_momentum": tune.uniform(0.0, 0.95),  # warmup initial momentum
        "box": tune.uniform(0.02, 0.2),  # box loss gain
        "cls": tune.uniform(0.2, 4.0),  # cls loss gain (scale with pixels)
        "hsv_h": tune.uniform(0.0, 0.1),  # image HSV-Hue augmentation (fraction)
        "hsv_s": tune.uniform(0.0, 0.9),  # image HSV-Saturation augmentation (fraction)
        "hsv_v": tune.uniform(0.0, 0.9),  # image HSV-Value augmentation (fraction)
        "degrees": tune.uniform(0.0, 45.0),  # image rotation (+/- deg)
        "translate": tune.uniform(0.0, 0.9),  # image translation (+/- fraction)
        "scale": tune.uniform(0.0, 0.9),  # image scale (+/- gain)
        "shear": tune.uniform(0.0, 10.0),  # image shear (+/- deg)
        "perspective": tune.uniform(0.0, 0.001),  # image perspective (+/- fraction), range 0-0.001
        "flipud": tune.uniform(0.0, 1.0),  # image flip up-down (probability)
        "fliplr": tune.uniform(0.0, 1.0),  # image flip left-right (probability)
        "bgr": tune.uniform(0.0, 1.0),  # image channel BGR (probability)
        "mosaic": tune.uniform(0.0, 1.0),  # image mosaic (probability)
        "mixup": tune.uniform(0.0, 1.0),  # image mixup (probability)
        "copy_paste": tune.uniform(0.0, 1.0),  # segment copy-paste (probability)
    }

    # Put the model in the ray store
    task = model.task
    model_in_store = ray.put(model)

    def _tune(config):
        """
        Trains the YOLO model with the specified hyperparameters and additional arguments.

        Args:
            config (dict): A dictionary of hyperparameters to use for training.

        Returns:
            (dict): Results dictionary from the trial's training run.
        """
        model_to_train = ray.get(model_in_store)  # get the model from the ray store for tuning
        model_to_train.reset_callbacks()
        config.update(train_args)
        results = model_to_train.train(**config)
        return results.results_dict

    # Get search space
    if not space:
        space = default_space
        LOGGER.warning("WARNING ⚠️ search space not provided, using default search space.")

    # Get dataset
    data = train_args.get("data", TASK2DATA[task])
    space["data"] = data
    if "data" not in train_args:
        LOGGER.warning(f'WARNING ⚠️ data not provided, using default "data={data}".')

    # Define the trainable function with allocated resources
    trainable_with_resources = tune.with_resources(_tune, {"cpu": NUM_THREADS, "gpu": gpu_per_trial or 0})

    # Define the ASHA scheduler for hyperparameter search
    asha_scheduler = ASHAScheduler(
        time_attr="epoch",
        metric=TASK2METRIC[task],
        mode="max",
        max_t=train_args.get("epochs") or DEFAULT_CFG_DICT["epochs"] or 100,
        grace_period=grace_period,
        reduction_factor=3,
    )

    # Define the callbacks for the hyperparameter search
    tuner_callbacks = [WandbLoggerCallback(project="YOLOv8-tune")] if wandb else []

    # Create the Ray Tune hyperparameter search tuner
    tune_dir = get_save_dir(DEFAULT_CFG, name="tune").resolve()  # must be an absolute dir
    tune_dir.mkdir(parents=True, exist_ok=True)
    tuner = tune.Tuner(
        trainable_with_resources,
        param_space=space,
        tune_config=tune.TuneConfig(scheduler=asha_scheduler, num_samples=max_samples),
        run_config=RunConfig(callbacks=tuner_callbacks, storage_path=tune_dir),
    )

    # Run the hyperparameter search
    tuner.fit()

    # Return the results of the hyperparameter search
    return tuner.get_results()