add yolo v10 and modify pipeline

This commit is contained in:
王庆刚
2025-03-28 13:19:54 +08:00
parent 183299c06b
commit 798c596acc
471 changed files with 19109 additions and 7342 deletions

View File

@ -9,6 +9,7 @@ import re
import subprocess
import sys
import threading
import time
import urllib
import uuid
from pathlib import Path
@ -25,23 +26,22 @@ from tqdm import tqdm as tqdm_original
from ultralytics import __version__
# PyTorch Multi-GPU DDP Constants
RANK = int(os.getenv('RANK', -1))
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
RANK = int(os.getenv("RANK", -1))
LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1)) # https://pytorch.org/docs/stable/elastic/run.html
# Other Constants
FILE = Path(__file__).resolve()
ROOT = FILE.parents[1] # YOLO
ASSETS = ROOT / 'assets' # default images
DEFAULT_CFG_PATH = ROOT / 'cfg/default.yaml'
ASSETS = ROOT / "assets" # default images
DEFAULT_CFG_PATH = ROOT / "cfg/default.yaml"
NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads
AUTOINSTALL = str(os.getenv('YOLO_AUTOINSTALL', True)).lower() == 'true' # global auto-install mode
VERBOSE = str(os.getenv('YOLO_VERBOSE', True)).lower() == 'true' # global verbose mode
TQDM_BAR_FORMAT = '{l_bar}{bar:10}{r_bar}' if VERBOSE else None # tqdm bar format
LOGGING_NAME = 'ultralytics'
MACOS, LINUX, WINDOWS = (platform.system() == x for x in ['Darwin', 'Linux', 'Windows']) # environment booleans
ARM64 = platform.machine() in ('arm64', 'aarch64') # ARM64 booleans
HELP_MSG = \
"""
AUTOINSTALL = str(os.getenv("YOLO_AUTOINSTALL", True)).lower() == "true" # global auto-install mode
VERBOSE = str(os.getenv("YOLO_VERBOSE", True)).lower() == "true" # global verbose mode
TQDM_BAR_FORMAT = "{l_bar}{bar:10}{r_bar}" if VERBOSE else None # tqdm bar format
LOGGING_NAME = "ultralytics"
MACOS, LINUX, WINDOWS = (platform.system() == x for x in ["Darwin", "Linux", "Windows"]) # environment booleans
ARM64 = platform.machine() in ("arm64", "aarch64") # ARM64 booleans
HELP_MSG = """
Usage examples for running YOLOv8:
1. Install the ultralytics package:
@ -77,7 +77,7 @@ HELP_MSG = \
yolo detect train data=coco128.yaml model=yolov8n.pt epochs=10 lr0=0.01
- Predict a YouTube video using a pretrained segmentation model at image size 320:
yolo segment predict model=yolov8n-seg.pt source='https://youtu.be/Zgi9g1ksQHc' imgsz=320
yolo segment predict model=yolov8n-seg.pt source='https://youtu.be/LNwODJXcvt4' imgsz=320
- Val a pretrained detection model at batch-size 1 and image size 640:
yolo detect val model=yolov8n.pt data=coco128.yaml batch=1 imgsz=640
@ -99,12 +99,12 @@ HELP_MSG = \
"""
# Settings
torch.set_printoptions(linewidth=320, precision=4, profile='default')
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
torch.set_printoptions(linewidth=320, precision=4, profile="default")
np.set_printoptions(linewidth=320, formatter={"float_kind": "{:11.5g}".format}) # format short g, %precision=5
cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS) # NumExpr max threads
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' # for deterministic training
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # suppress verbose TF compiler warnings in Colab
os.environ["NUMEXPR_MAX_THREADS"] = str(NUM_THREADS) # NumExpr max threads
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" # for deterministic training
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # suppress verbose TF compiler warnings in Colab
class TQDM(tqdm_original):
@ -113,19 +113,22 @@ class TQDM(tqdm_original):
Args:
*args (list): Positional arguments passed to original tqdm.
**kwargs (dict): Keyword arguments, with custom defaults applied.
**kwargs (any): Keyword arguments, with custom defaults applied.
"""
def __init__(self, *args, **kwargs):
    """
    Initialize the custom tqdm subclass with project-specific default arguments.

    Args:
        *args: Positional arguments forwarded unchanged to the parent tqdm constructor.
        **kwargs: Keyword arguments forwarded to the parent tqdm constructor after the
            two adjustments below.
    """
    # The bar is disabled when global VERBOSE is off OR the caller asked for disable=True;
    # a caller cannot force the bar on while VERBOSE is False.
    kwargs["disable"] = not VERBOSE or kwargs.get("disable", False)
    # Only fill in bar_format when the caller did not supply one.
    kwargs.setdefault("bar_format", TQDM_BAR_FORMAT)
    super().__init__(*args, **kwargs)
class SimpleClass:
"""
Ultralytics SimpleClass is a base class providing helpful string representation, error reporting, and attribute
"""Ultralytics SimpleClass is a base class providing helpful string representation, error reporting, and attribute
access methods for easier debugging and usage.
"""
@ -134,14 +137,14 @@ class SimpleClass:
attr = []
for a in dir(self):
v = getattr(self, a)
if not callable(v) and not a.startswith('_'):
if not callable(v) and not a.startswith("_"):
if isinstance(v, SimpleClass):
# Display only the module and class name for subclasses
s = f'{a}: {v.__module__}.{v.__class__.__name__} object'
s = f"{a}: {v.__module__}.{v.__class__.__name__} object"
else:
s = f'{a}: {repr(v)}'
s = f"{a}: {repr(v)}"
attr.append(s)
return f'{self.__module__}.{self.__class__.__name__} object with attributes:\n\n' + '\n'.join(attr)
return f"{self.__module__}.{self.__class__.__name__} object with attributes:\n\n" + "\n".join(attr)
def __repr__(self):
"""Return a machine-readable string representation of the object."""
@ -154,8 +157,7 @@ class SimpleClass:
class IterableSimpleNamespace(SimpleNamespace):
"""
Ultralytics IterableSimpleNamespace is an extension class of SimpleNamespace that adds iterable functionality and
"""Ultralytics IterableSimpleNamespace is an extension class of SimpleNamespace that adds iterable functionality and
enables usage with dict() and for loops.
"""
@ -165,24 +167,26 @@ class IterableSimpleNamespace(SimpleNamespace):
def __str__(self):
"""Return a human-readable string representation of the object."""
return '\n'.join(f'{k}={v}' for k, v in vars(self).items())
return "\n".join(f"{k}={v}" for k, v in vars(self).items())
def __getattr__(self, attr):
    """
    Raise an AttributeError with upgrade guidance for a missing attribute.

    Python only calls __getattr__ after normal attribute lookup has failed, so this
    method never returns a value; it always raises.

    Args:
        attr (str): Name of the attribute that was not found.

    Raises:
        AttributeError: Always, with a message pointing at DEFAULT_CFG_PATH.
    """
    name = self.__class__.__name__
    raise AttributeError(
        f"""
'{name}' object has no attribute '{attr}'. This may be caused by a modified or out of date ultralytics
'default.yaml' file.\nPlease update your code with 'pip install -U ultralytics' and if necessary replace
{DEFAULT_CFG_PATH} with the latest version from
https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/default.yaml
"""
    )
def get(self, key, default=None):
    """
    Return the attribute named `key`, or `default` when the lookup raises AttributeError.

    Args:
        key (str): Attribute name to look up on this object.
        default (Any, optional): Value returned when the attribute is missing. Defaults to None.

    Returns:
        (Any): The attribute value, or `default`.
    """
    return getattr(self, key, default)
def plt_settings(rcparams=None, backend='Agg'):
def plt_settings(rcparams=None, backend="Agg"):
"""
Decorator to temporarily set rc parameters and the backend for a plotting function.
@ -200,7 +204,7 @@ def plt_settings(rcparams=None, backend='Agg'):
"""
if rcparams is None:
rcparams = {'font.size': 11}
rcparams = {"font.size": 11}
def decorator(func):
"""Decorator to apply temporary rc parameters and backend to a function."""
@ -208,12 +212,16 @@ def plt_settings(rcparams=None, backend='Agg'):
def wrapper(*args, **kwargs):
"""Sets rc parameters and backend, calls the original function, and restores the settings."""
original_backend = plt.get_backend()
plt.switch_backend(backend)
if backend.lower() != original_backend.lower():
plt.close("all") # auto-close()ing of figures upon backend switching is deprecated since 3.8
plt.switch_backend(backend)
with plt.rc_context(rcparams):
result = func(*args, **kwargs)
plt.switch_backend(original_backend)
if backend != original_backend:
plt.close("all")
plt.switch_backend(original_backend)
return result
return wrapper
@ -222,58 +230,59 @@ def plt_settings(rcparams=None, backend='Agg'):
def set_logging(name=LOGGING_NAME, verbose=True):
    """
    Set up and return a stdout logger for the given name, with UTF-8 handling on Windows.

    Args:
        name (str): Name of the logger to configure. Defaults to LOGGING_NAME.
        verbose (bool): When True and RANK is -1 or 0 the level is INFO, otherwise ERROR.

    Returns:
        (logging.Logger): The configured logger, with propagation to parent loggers disabled.
    """
    level = logging.INFO if verbose and RANK in {-1, 0} else logging.ERROR  # rank in world for Multi-GPU trainings

    # Configure the console (stdout) encoding to UTF-8
    formatter = logging.Formatter("%(message)s")  # Default formatter
    if WINDOWS and sys.stdout.encoding != "utf-8":
        try:
            if hasattr(sys.stdout, "reconfigure"):
                sys.stdout.reconfigure(encoding="utf-8")
            elif hasattr(sys.stdout, "buffer"):
                import io

                sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8")
            else:
                sys.stdout.encoding = "utf-8"
        except Exception as e:
            print(f"Creating custom formatter for non UTF-8 environments due to {e}")

            class CustomFormatter(logging.Formatter):
                """Formatter that passes every formatted message through emojis()."""

                def format(self, record):
                    """Format the record normally, then return the emojis()-filtered text."""
                    return emojis(super().format(record))

            formatter = CustomFormatter("%(message)s")  # Use CustomFormatter to eliminate UTF-8 output as last recourse

    # Create and configure the StreamHandler
    stream_handler = logging.StreamHandler(sys.stdout)
    stream_handler.setFormatter(formatter)
    stream_handler.setLevel(level)

    logger = logging.getLogger(name)
    logger.setLevel(level)
    logger.addHandler(stream_handler)
    logger.propagate = False
    return logger
# Set logger
LOGGER = set_logging(LOGGING_NAME, verbose=VERBOSE)  # define globally (used in train.py, val.py, predict.py, etc.)
# A threshold above CRITICAL means these third-party loggers emit nothing at any standard level.
for logger in "sentry_sdk", "urllib3.connectionpool":
    logging.getLogger(logger).setLevel(logging.CRITICAL + 1)
def emojis(string=""):
    """
    Return a platform-dependent emoji-safe version of a string.

    Args:
        string (str): Text to sanitise. Defaults to "".

    Returns:
        (str): On Windows, the text with every non-ASCII character dropped (encode, then
            decode as ASCII ignoring errors); on other platforms the input unchanged.
    """
    return string.encode().decode("ascii", "ignore") if WINDOWS else string
class ThreadingLocked:
"""
A decorator class for ensuring thread-safe execution of a function or method.
This class can be used as a decorator to make sure that if the decorated function
is called from multiple threads, only one thread at a time will be able to execute the function.
A decorator class for ensuring thread-safe execution of a function or method. This class can be used as a decorator
to make sure that if the decorated function is called from multiple threads, only one thread at a time will be able
to execute the function.
Attributes:
lock (threading.Lock): A lock object used to manage access to the decorated function.
@ -290,20 +299,23 @@ class ThreadingLocked:
"""
def __init__(self):
"""Initializes the decorator class for thread-safe execution of a function or method."""
self.lock = threading.Lock()
def __call__(self, f):
"""Run thread-safe execution of function or method."""
from functools import wraps
@wraps(f)
def decorated(*args, **kwargs):
"""Applies thread-safety to the decorated function or method."""
with self.lock:
return f(*args, **kwargs)
return decorated
def yaml_save(file='data.yaml', data=None, header=''):
def yaml_save(file="data.yaml", data=None, header=""):
"""
Save YAML data to a file.
@ -323,18 +335,19 @@ def yaml_save(file='data.yaml', data=None, header=''):
file.parent.mkdir(parents=True, exist_ok=True)
# Convert Path objects to strings
valid_types = int, float, str, bool, list, tuple, dict, type(None)
for k, v in data.items():
if isinstance(v, Path):
if not isinstance(v, valid_types):
data[k] = str(v)
# Dump data to file in YAML format
with open(file, 'w', errors='ignore', encoding='utf-8') as f:
with open(file, "w", errors="ignore", encoding="utf-8") as f:
if header:
f.write(header)
yaml.safe_dump(data, f, sort_keys=False, allow_unicode=True)
def yaml_load(file='data.yaml', append_filename=False):
def yaml_load(file="data.yaml", append_filename=False):
"""
Load YAML data from a file.
@ -345,18 +358,18 @@ def yaml_load(file='data.yaml', append_filename=False):
Returns:
(dict): YAML data and file name.
"""
assert Path(file).suffix in ('.yaml', '.yml'), f'Attempting to load non-YAML file {file} with yaml_load()'
with open(file, errors='ignore', encoding='utf-8') as f:
assert Path(file).suffix in (".yaml", ".yml"), f"Attempting to load non-YAML file {file} with yaml_load()"
with open(file, errors="ignore", encoding="utf-8") as f:
s = f.read() # string
# Remove special characters
if not s.isprintable():
s = re.sub(r'[^\x09\x0A\x0D\x20-\x7E\x85\xA0-\uD7FF\uE000-\uFFFD\U00010000-\U0010ffff]+', '', s)
s = re.sub(r"[^\x09\x0A\x0D\x20-\x7E\x85\xA0-\uD7FF\uE000-\uFFFD\U00010000-\U0010ffff]+", "", s)
# Add YAML filename to dict and return
data = yaml.safe_load(s) or {} # always return a dict (yaml.safe_load() may return None for empty files)
if append_filename:
data['yaml_file'] = str(file)
data["yaml_file"] = str(file)
return data
@ -368,7 +381,7 @@ def yaml_print(yaml_file: Union[str, Path, dict]) -> None:
yaml_file: The file path of the YAML file or a YAML-formatted dictionary.
Returns:
None
(None)
"""
yaml_dict = yaml_load(yaml_file) if isinstance(yaml_file, (str, Path)) else yaml_file
dump = yaml.dump(yaml_dict, sort_keys=False, allow_unicode=True)
@ -378,7 +391,7 @@ def yaml_print(yaml_file: Union[str, Path, dict]) -> None:
# Default configuration
DEFAULT_CFG_DICT = yaml_load(DEFAULT_CFG_PATH)
for k, v in DEFAULT_CFG_DICT.items():
if isinstance(v, str) and v.lower() == 'none':
if isinstance(v, str) and v.lower() == "none":
DEFAULT_CFG_DICT[k] = None
DEFAULT_CFG_KEYS = DEFAULT_CFG_DICT.keys()
DEFAULT_CFG = IterableSimpleNamespace(**DEFAULT_CFG_DICT)
@ -392,8 +405,8 @@ def is_ubuntu() -> bool:
(bool): True if OS is Ubuntu, False otherwise.
"""
with contextlib.suppress(FileNotFoundError):
with open('/etc/os-release') as f:
return 'ID=ubuntu' in f.read()
with open("/etc/os-release") as f:
return "ID=ubuntu" in f.read()
return False
@ -404,7 +417,7 @@ def is_colab():
Returns:
(bool): True if running inside a Colab notebook, False otherwise.
"""
return 'COLAB_RELEASE_TAG' in os.environ or 'COLAB_BACKEND_VERSION' in os.environ
return "COLAB_RELEASE_TAG" in os.environ or "COLAB_BACKEND_VERSION" in os.environ
def is_kaggle():
@ -414,19 +427,19 @@ def is_kaggle():
Returns:
(bool): True if running inside a Kaggle kernel, False otherwise.
"""
return os.environ.get('PWD') == '/kaggle/working' and os.environ.get('KAGGLE_URL_BASE') == 'https://www.kaggle.com'
return os.environ.get("PWD") == "/kaggle/working" and os.environ.get("KAGGLE_URL_BASE") == "https://www.kaggle.com"
def is_jupyter():
    """
    Check if the current script is running inside a Jupyter Notebook. Verified on Colab, Jupyterlab, Kaggle, Paperspace.

    The check imports IPython and tests whether `get_ipython()` returns a shell instance.
    Any exception, including IPython not being installed, is suppressed and gives False.

    Returns:
        (bool): True if running inside a Jupyter Notebook, False otherwise.
    """
    with contextlib.suppress(Exception):
        from IPython import get_ipython

        return get_ipython() is not None
    return False
@ -438,10 +451,10 @@ def is_docker() -> bool:
Returns:
(bool): True if the script is running inside a Docker container, False otherwise.
"""
file = Path('/proc/self/cgroup')
file = Path("/proc/self/cgroup")
if file.exists():
with open(file) as f:
return 'docker' in f.read()
return "docker" in f.read()
else:
return False
@ -455,7 +468,7 @@ def is_online() -> bool:
"""
import socket
for host in '1.1.1.1', '8.8.8.8', '223.5.5.5': # Cloudflare, Google, AliDNS:
for host in "1.1.1.1", "8.8.8.8", "223.5.5.5": # Cloudflare, Google, AliDNS:
try:
test_connection = socket.create_connection(address=(host, 53), timeout=2)
except (socket.timeout, socket.gaierror, OSError):
@ -509,23 +522,23 @@ def is_pytest_running():
Returns:
(bool): True if pytest is running, False otherwise.
"""
return ('PYTEST_CURRENT_TEST' in os.environ) or ('pytest' in sys.modules) or ('pytest' in Path(sys.argv[0]).stem)
return ("PYTEST_CURRENT_TEST" in os.environ) or ("pytest" in sys.modules) or ("pytest" in Path(sys.argv[0]).stem)
def is_github_action_running() -> bool:
    """
    Determine if the current environment is a GitHub Actions runner.

    Returns:
        (bool): True only when the GITHUB_ACTIONS, GITHUB_WORKFLOW and RUNNER_OS environment
            variables are all present; their values are not inspected.
    """
    return "GITHUB_ACTIONS" in os.environ and "GITHUB_WORKFLOW" in os.environ and "RUNNER_OS" in os.environ
def is_git_dir():
"""
Determines whether the current file is part of a git repository.
If the current file is not part of a git repository, returns None.
Determines whether the current file is part of a git repository. If the current file is not part of a git
repository, returns None.
Returns:
(bool): True if current file is part of a git repository.
@ -535,14 +548,14 @@ def is_git_dir():
def get_git_dir():
    """
    Determines whether the current file is part of a git repository and if so, returns the repository root directory. If
    the current file is not part of a git repository, returns None.

    The search walks upward through the parent directories of this file and stops at the
    first one that contains a `.git` directory.

    Returns:
        (Path | None): Git root directory if found or None if not found.
    """
    for d in Path(__file__).parents:
        if (d / ".git").is_dir():
            return d
    return None  # explicit: no enclosing repository was found
@ -555,7 +568,7 @@ def get_git_origin_url():
"""
if is_git_dir():
with contextlib.suppress(subprocess.CalledProcessError):
origin = subprocess.check_output(['git', 'config', '--get', 'remote.origin.url'])
origin = subprocess.check_output(["git", "config", "--get", "remote.origin.url"])
return origin.decode().strip()
@ -568,12 +581,13 @@ def get_git_branch():
"""
if is_git_dir():
with contextlib.suppress(subprocess.CalledProcessError):
origin = subprocess.check_output(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
origin = subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"])
return origin.decode().strip()
def get_default_args(func):
"""Returns a dictionary of default arguments for a function.
"""
Returns a dictionary of default arguments for a function.
Args:
func (callable): The function to inspect.
@ -594,13 +608,13 @@ def get_ubuntu_version():
"""
if is_ubuntu():
with contextlib.suppress(FileNotFoundError, AttributeError):
with open('/etc/os-release') as f:
with open("/etc/os-release") as f:
return re.search(r'VERSION_ID="(\d+\.\d+)"', f.read())[1]
def get_user_config_dir(sub_dir='Ultralytics'):
def get_user_config_dir(sub_dir="yolov10"):
"""
Get the user config directory.
Return the appropriate config directory based on the environment operating system.
Args:
sub_dir (str): The name of the subdirectory to create.
@ -608,21 +622,22 @@ def get_user_config_dir(sub_dir='Ultralytics'):
Returns:
(Path): The path to the user config directory.
"""
# Return the appropriate config directory for each operating system
if WINDOWS:
path = Path.home() / 'AppData' / 'Roaming' / sub_dir
path = Path.home() / "AppData" / "Roaming" / sub_dir
elif MACOS: # macOS
path = Path.home() / 'Library' / 'Application Support' / sub_dir
path = Path.home() / "Library" / "Application Support" / sub_dir
elif LINUX:
path = Path.home() / '.config' / sub_dir
path = Path.home() / ".config" / sub_dir
else:
raise ValueError(f'Unsupported operating system: {platform.system()}')
raise ValueError(f"Unsupported operating system: {platform.system()}")
# GCP and AWS lambda fix, only /tmp is writeable
if not is_dir_writeable(path.parent):
LOGGER.warning(f"WARNING ⚠️ user config directory '{path}' is not writeable, defaulting to '/tmp' or CWD."
'Alternatively you can define a YOLO_CONFIG_DIR environment variable for this path.')
path = Path('/tmp') / sub_dir if is_dir_writeable('/tmp') else Path().cwd() / sub_dir
LOGGER.warning(
f"WARNING ⚠️ user config directory '{path}' is not writeable, defaulting to '/tmp' or CWD."
"Alternatively you can define a YOLO_CONFIG_DIR environment variable for this path."
)
path = Path("/tmp") / sub_dir if is_dir_writeable("/tmp") else Path().cwd() / sub_dir
# Create the subdirectory if it does not exist
path.mkdir(parents=True, exist_ok=True)
@ -630,40 +645,99 @@ def get_user_config_dir(sub_dir='Ultralytics'):
return path
USER_CONFIG_DIR = Path(os.getenv('YOLO_CONFIG_DIR') or get_user_config_dir()) # Ultralytics settings dir
SETTINGS_YAML = USER_CONFIG_DIR / 'settings.yaml'
USER_CONFIG_DIR = Path(os.getenv("YOLO_CONFIG_DIR") or get_user_config_dir()) # Ultralytics settings dir
SETTINGS_YAML = USER_CONFIG_DIR / "settings.yaml"
def colorstr(*input):
    """
    Colors a string based on the provided color and style arguments. Utilizes ANSI escape codes.
    See https://en.wikipedia.org/wiki/ANSI_escape_code for more details.

    This function can be called in two ways:
        - colorstr('color', 'style', 'your string')
        - colorstr('your string')

    In the second form, 'blue' and 'bold' will be applied by default.

    Args:
        *input (str): A sequence of strings where the first n-1 strings are color and style arguments,
                      and the last string is the one to be colored.

    Supported Colors and Styles:
        Basic Colors: 'black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'
        Bright Colors: 'bright_black', 'bright_red', 'bright_green', 'bright_yellow',
                       'bright_blue', 'bright_magenta', 'bright_cyan', 'bright_white'
        Misc: 'end', 'bold', 'underline'

    Returns:
        (str): The input string wrapped with ANSI escape codes for the specified color and style.

    Raises:
        KeyError: If a color or style argument is not one of the supported names.
        IndexError: If called with no arguments.

    Examples:
        >>> colorstr('blue', 'bold', 'hello world')
        '\\x1b[34m\\x1b[1mhello world\\x1b[0m'
    """
    *args, string = input if len(input) > 1 else ("blue", "bold", input[0])  # color arguments, string
    colors = {
        "black": "\033[30m",  # basic colors
        "red": "\033[31m",
        "green": "\033[32m",
        "yellow": "\033[33m",
        "blue": "\033[34m",
        "magenta": "\033[35m",
        "cyan": "\033[36m",
        "white": "\033[37m",
        "bright_black": "\033[90m",  # bright colors
        "bright_red": "\033[91m",
        "bright_green": "\033[92m",
        "bright_yellow": "\033[93m",
        "bright_blue": "\033[94m",
        "bright_magenta": "\033[95m",
        "bright_cyan": "\033[96m",
        "bright_white": "\033[97m",
        "end": "\033[0m",  # misc
        "bold": "\033[1m",
        "underline": "\033[4m",
    }
    return "".join(colors[x] for x in args) + f"{string}" + colors["end"]
# Matches ANSI CSI sequences: ESC, '[', optional digits/semicolons, then one final letter.
_ANSI_ESCAPE_RE = re.compile(r"\x1B\[[0-9;]*[A-Za-z]")


def remove_colorstr(input_string):
    """
    Removes ANSI escape codes from a string, effectively un-coloring it.

    Args:
        input_string (str): The string to remove color and style from.

    Returns:
        (str): A new string with all ANSI escape codes removed.

    Examples:
        >>> remove_colorstr(colorstr('blue', 'bold', 'hello world'))
        'hello world'
    """
    return _ANSI_ESCAPE_RE.sub("", input_string)
class TryExcept(contextlib.ContextDecorator):
"""YOLOv8 TryExcept class. Usage: @TryExcept() decorator or 'with TryExcept():' context manager."""
"""
Ultralytics TryExcept class. Use as @TryExcept() decorator or 'with TryExcept():' context manager.
def __init__(self, msg='', verbose=True):
Examples:
As a decorator:
>>> @TryExcept(msg="Error occurred in func", verbose=True)
>>> def func():
>>> # Function logic here
>>> pass
As a context manager:
>>> with TryExcept(msg="Error occurred in block", verbose=True):
>>> # Code block here
>>> pass
"""
def __init__(self, msg="", verbose=True):
"""Initialize TryExcept class with optional message and verbosity settings."""
self.msg = msg
self.verbose = verbose
@ -679,14 +753,80 @@ class TryExcept(contextlib.ContextDecorator):
return True
class Retry(contextlib.ContextDecorator):
    """
    Retry class for function execution with exponential backoff.

    As a decorator, the wrapped function is called again after each exception, up to `times`
    attempts, sleeping `delay * 2**attempt` seconds between attempts. The final failure is
    re-raised to the caller.

    As a context manager, a `with` block cannot be re-executed: on an exception `__exit__`
    only prints, sleeps and suppresses it while attempts remain. The caller must loop over
    the `with` statement to get an actual retry.

    Args:
        times (int): Maximum number of attempts. Defaults to 3.
        delay (int | float): Base delay in seconds for the backoff. Defaults to 2.

    Examples:
        Example usage as a decorator:
        >>> @Retry(times=3, delay=2)
        >>> def test_func():
        >>>     # Replace with function logic that may raise exceptions
        >>>     return True

        Example usage as a context manager:
        >>> with Retry(times=3, delay=2):
        >>>     # Replace with code block that may raise exceptions
        >>>     pass
    """

    def __init__(self, times=3, delay=2):
        """Initialize Retry class with specified number of retries and delay."""
        self.times = times
        self.delay = delay
        self._attempts = 0

    def __call__(self, func):
        """Decorator implementation for Retry with exponential backoff."""
        from functools import wraps

        @wraps(func)  # keep the wrapped function's name and docstring
        def wrapped_func(*args, **kwargs):
            """Applies retries to the decorated function or method."""
            self._attempts = 0
            while self._attempts < self.times:
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    self._attempts += 1
                    print(f"Retry {self._attempts}/{self.times} failed: {e}")
                    if self._attempts >= self.times:
                        raise  # bare raise keeps the original traceback
                    time.sleep(self.delay * (2**self._attempts))  # exponential backoff delay

        return wrapped_func

    def __enter__(self):
        """Enter the runtime context and reset the attempt counter."""
        self._attempts = 0

    def __exit__(self, exc_type, exc_value, traceback):
        """Exit the runtime context, suppressing the exception after a backoff sleep while attempts remain."""
        if exc_type is not None:
            self._attempts += 1
            if self._attempts < self.times:
                print(f"Retry {self._attempts}/{self.times} failed: {exc_value}")
                time.sleep(self.delay * (2**self._attempts))  # exponential backoff delay
                return True  # Suppresses the exception; the with-block itself is not re-run
        return False  # Re-raises the exception if retries are exhausted
def threaded(func):
    """
    Multi-threads a target function by default and returns the thread or function result.

    Use as @threaded decorator. The function runs in a separate daemon thread unless
    'threaded=False' is passed; that keyword is consumed by the wrapper and never reaches
    the wrapped function.

    Args:
        func (callable): The function to wrap.

    Returns:
        (callable): Wrapper returning a started `threading.Thread`, or the function's own
            return value when called with `threaded=False`.
    """
    from functools import wraps

    @wraps(func)  # keep the wrapped function's name and docstring
    def wrapper(*args, **kwargs):
        """Multi-threads a given function based on 'threaded' kwarg and returns the thread or function result."""
        if kwargs.pop("threaded", True):  # run in thread
            thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
            thread.start()
            return thread
        else:
            return func(*args, **kwargs)

    return wrapper
@ -723,27 +863,28 @@ def set_sentry():
Returns:
dict: The modified event or None if the event should not be sent to Sentry.
"""
if 'exc_info' in hint:
exc_type, exc_value, tb = hint['exc_info']
if exc_type in (KeyboardInterrupt, FileNotFoundError) \
or 'out of memory' in str(exc_value):
if "exc_info" in hint:
exc_type, exc_value, tb = hint["exc_info"]
if exc_type in (KeyboardInterrupt, FileNotFoundError) or "out of memory" in str(exc_value):
return None # do not send event
event['tags'] = {
'sys_argv': sys.argv[0],
'sys_argv_name': Path(sys.argv[0]).name,
'install': 'git' if is_git_dir() else 'pip' if is_pip_package() else 'other',
'os': ENVIRONMENT}
event["tags"] = {
"sys_argv": sys.argv[0],
"sys_argv_name": Path(sys.argv[0]).name,
"install": "git" if is_git_dir() else "pip" if is_pip_package() else "other",
"os": ENVIRONMENT,
}
return event
if SETTINGS['sync'] and \
RANK in (-1, 0) and \
Path(sys.argv[0]).name == 'yolo' and \
not TESTS_RUNNING and \
ONLINE and \
is_pip_package() and \
not is_git_dir():
if (
SETTINGS["sync"]
and RANK in (-1, 0)
and Path(sys.argv[0]).name == "yolo"
and not TESTS_RUNNING
and ONLINE
and is_pip_package()
and not is_git_dir()
):
# If sentry_sdk package is not installed then return and do not use Sentry
try:
import sentry_sdk # noqa
@ -751,18 +892,15 @@ def set_sentry():
return
sentry_sdk.init(
dsn='https://5ff1556b71594bfea135ff0203a0d290@o4504521589325824.ingest.sentry.io/4504521592406016',
dsn="https://5ff1556b71594bfea135ff0203a0d290@o4504521589325824.ingest.sentry.io/4504521592406016",
debug=False,
traces_sample_rate=1.0,
release=__version__,
environment='production', # 'dev' or 'production'
environment="production", # 'dev' or 'production'
before_send=before_send,
ignore_errors=[KeyboardInterrupt, FileNotFoundError])
sentry_sdk.set_user({'id': SETTINGS['uuid']}) # SHA-256 anonymized UUID hash
# Disable all sentry logging
for logger in 'sentry_sdk', 'sentry_sdk.errors':
logging.getLogger(logger).setLevel(logging.CRITICAL)
ignore_errors=[KeyboardInterrupt, FileNotFoundError],
)
sentry_sdk.set_user({"id": SETTINGS["uuid"]}) # SHA-256 anonymized UUID hash
class SettingsManager(dict):
@ -774,7 +912,10 @@ class SettingsManager(dict):
version (str): Settings version. In case of local version mismatch, new default settings will be saved.
"""
def __init__(self, file=SETTINGS_YAML, version='0.0.4'):
def __init__(self, file=SETTINGS_YAML, version="0.0.4"):
"""Initialize the SettingsManager with default settings, load and validate current settings from the YAML
file.
"""
import copy
import hashlib
@ -788,22 +929,24 @@ class SettingsManager(dict):
self.file = Path(file)
self.version = version
self.defaults = {
'settings_version': version,
'datasets_dir': str(datasets_root / 'datasets'),
'weights_dir': str(root / 'weights'),
'runs_dir': str(root / 'runs'),
'uuid': hashlib.sha256(str(uuid.getnode()).encode()).hexdigest(),
'sync': True,
'api_key': '',
'clearml': True, # integrations
'comet': True,
'dvc': True,
'hub': True,
'mlflow': True,
'neptune': True,
'raytune': True,
'tensorboard': True,
'wandb': True}
"settings_version": version,
"datasets_dir": str(datasets_root / "datasets"),
"weights_dir": str(root / "weights"),
"runs_dir": str(root / "runs"),
"uuid": hashlib.sha256(str(uuid.getnode()).encode()).hexdigest(),
"sync": True,
"api_key": "",
"openai_api_key": "",
"clearml": True, # integrations
"comet": True,
"dvc": True,
"hub": True,
"mlflow": True,
"neptune": True,
"raytune": True,
"tensorboard": True,
"wandb": True,
}
super().__init__(copy.deepcopy(self.defaults))
@ -814,15 +957,26 @@ class SettingsManager(dict):
self.load()
correct_keys = self.keys() == self.defaults.keys()
correct_types = all(type(a) is type(b) for a, b in zip(self.values(), self.defaults.values()))
correct_version = check_version(self['settings_version'], self.version)
correct_version = check_version(self["settings_version"], self.version)
help_msg = (
f"\nView settings with 'yolo settings' or at '{self.file}'"
"\nUpdate settings with 'yolo settings key=value', i.e. 'yolo settings runs_dir=path/to/dir'. "
"For help see https://docs.ultralytics.com/quickstart/#ultralytics-settings."
)
if not (correct_keys and correct_types and correct_version):
LOGGER.warning(
'WARNING ⚠️ Ultralytics settings reset to default values. This may be due to a possible problem '
'with your settings or a recent ultralytics package update. '
f"\nView settings with 'yolo settings' or at '{self.file}'"
"\nUpdate settings with 'yolo settings key=value', i.e. 'yolo settings runs_dir=path/to/dir'.")
"WARNING ⚠️ Ultralytics settings reset to default values. This may be due to a possible problem "
f"with your settings or a recent ultralytics package update. {help_msg}"
)
self.reset()
if self.get("datasets_dir") == self.get("runs_dir"):
LOGGER.warning(
f"WARNING ⚠️ Ultralytics setting 'datasets_dir: {self.get('datasets_dir')}' "
f"must be different than 'runs_dir: {self.get('runs_dir')}'. "
f"Please change one to avoid possible issues during training. {help_msg}"
)
def load(self):
"""Loads settings from the YAML file."""
super().update(yaml_load(self.file))
@ -847,14 +1001,16 @@ def deprecation_warn(arg, new_arg, version=None):
"""Issue a deprecation warning when a deprecated argument is used, suggesting an updated argument."""
if not version:
version = float(__version__[:3]) + 0.2 # deprecate after 2nd major release
LOGGER.warning(f"WARNING ⚠️ '{arg}' is deprecated and will be removed in 'ultralytics {version}' in the future. "
f"Please use '{new_arg}' instead.")
LOGGER.warning(
f"WARNING ⚠️ '{arg}' is deprecated and will be removed in 'ultralytics {version}' in the future. "
f"Please use '{new_arg}' instead."
)
def clean_url(url):
    """
    Strip auth from URL, i.e. https://url.com/file.txt?auth -> https://url.com/file.txt.

    Args:
        url (str | Path): URL (or path-like string) to clean.

    Returns:
        (str): The URL with percent-escapes decoded and everything from the first '?' onwards removed.
    """
    # Import the submodule explicitly: a bare `import urllib` does not guarantee that `urllib.parse` is loaded
    from urllib.parse import unquote

    url = Path(url).as_posix().replace(":/", "://")  # Pathlib turns :// -> :/, as_posix() for Windows
    return unquote(url).split("?")[0]  # '%2F' to '/', split https://url.com/file.txt?auth
def url2file(url):
@ -865,12 +1021,23 @@ def url2file(url):
# Run below code on utils init ------------------------------------------------------------------------------------
# Check first-install steps
PREFIX = colorstr("Ultralytics: ")
SETTINGS = SettingsManager()  # initialize settings
DATASETS_DIR = Path(SETTINGS["datasets_dir"])  # global datasets directory
WEIGHTS_DIR = Path(SETTINGS["weights_dir"])  # global weights directory
RUNS_DIR = Path(SETTINGS["runs_dir"])  # global runs directory
# First matching environment wins; checks run in this order and fall back to the OS name
ENVIRONMENT = (
    "Colab" if is_colab()
    else "Kaggle" if is_kaggle()
    else "Jupyter" if is_jupyter()
    else "Docker" if is_docker()
    else platform.system()
)
TESTS_RUNNING = is_pytest_running() or is_github_action_running()
set_sentry()
# Apply monkey patches

View File

@ -1,7 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
Functions for estimating the best YOLO batch size to use a fraction of the available CUDA memory in PyTorch.
"""
"""Functions for estimating the best YOLO batch size to use a fraction of the available CUDA memory in PyTorch."""
from copy import deepcopy
@ -36,7 +34,7 @@ def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch):
Args:
model (torch.nn.module): YOLO model to compute batch size for.
imgsz (int, optional): The image size used as input for the YOLO model. Defaults to 640.
fraction (float, optional): The fraction of available CUDA memory to use. Defaults to 0.67.
fraction (float, optional): The fraction of available CUDA memory to use. Defaults to 0.60.
batch_size (int, optional): The default batch size to use if an error is detected. Defaults to 16.
Returns:
@ -44,14 +42,14 @@ def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch):
"""
# Check device
prefix = colorstr('AutoBatch: ')
LOGGER.info(f'{prefix}Computing optimal batch size for imgsz={imgsz}')
prefix = colorstr("AutoBatch: ")
LOGGER.info(f"{prefix}Computing optimal batch size for imgsz={imgsz}")
device = next(model.parameters()).device # get model device
if device.type == 'cpu':
LOGGER.info(f'{prefix}CUDA not detected, using default CPU batch-size {batch_size}')
if device.type == "cpu":
LOGGER.info(f"{prefix}CUDA not detected, using default CPU batch-size {batch_size}")
return batch_size
if torch.backends.cudnn.benchmark:
LOGGER.info(f'{prefix} ⚠️ Requires torch.backends.cudnn.benchmark=False, using default batch-size {batch_size}')
LOGGER.info(f"{prefix} ⚠️ Requires torch.backends.cudnn.benchmark=False, using default batch-size {batch_size}")
return batch_size
# Inspect CUDA memory
@ -62,7 +60,7 @@ def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch):
r = torch.cuda.memory_reserved(device) / gb # GiB reserved
a = torch.cuda.memory_allocated(device) / gb # GiB allocated
f = t - (r + a) # GiB free
LOGGER.info(f'{prefix}{d} ({properties.name}) {t:.2f}G total, {r:.2f}G reserved, {a:.2f}G allocated, {f:.2f}G free')
LOGGER.info(f"{prefix}{d} ({properties.name}) {t:.2f}G total, {r:.2f}G reserved, {a:.2f}G allocated, {f:.2f}G free")
# Profile batch sizes
batch_sizes = [1, 2, 4, 8, 16]
@ -72,7 +70,7 @@ def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch):
# Fit a solution
y = [x[2] for x in results if x] # memory [2]
p = np.polyfit(batch_sizes[:len(y)], y, deg=1) # first degree polynomial fit
p = np.polyfit(batch_sizes[: len(y)], y, deg=1) # first degree polynomial fit
b = int((f * fraction - p[1]) / p[0]) # y intercept (optimal batch size)
if None in results: # some sizes failed
i = results.index(None) # first fail index
@ -80,11 +78,11 @@ def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch):
b = batch_sizes[max(i - 1, 0)] # select prior safe point
if b < 1 or b > 1024: # b outside of safe range
b = batch_size
LOGGER.info(f'{prefix}WARNING ⚠️ CUDA anomaly detected, using default batch-size {batch_size}.')
LOGGER.info(f"{prefix}WARNING ⚠️ CUDA anomaly detected, using default batch-size {batch_size}.")
fraction = (np.polyval(p, b) + r + a) / t # actual fraction predicted
LOGGER.info(f'{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%) ✅')
LOGGER.info(f"{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%) ✅")
return b
except Exception as e:
LOGGER.warning(f'{prefix}WARNING ⚠️ error detected: {e}, using default batch-size {batch_size}.')
LOGGER.warning(f"{prefix}WARNING ⚠️ error detected: {e}, using default batch-size {batch_size}.")
return batch_size

View File

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
Benchmark a YOLO model formats for speed and accuracy
Benchmark a YOLO model formats for speed and accuracy.
Usage:
from ultralytics.utils.benchmarks import ProfileModels, benchmark
@ -21,34 +21,29 @@ TensorFlow Lite | `tflite` | yolov8n.tflite
TensorFlow Edge TPU | `edgetpu` | yolov8n_edgetpu.tflite
TensorFlow.js | `tfjs` | yolov8n_web_model/
PaddlePaddle | `paddle` | yolov8n_paddle_model/
ncnn | `ncnn` | yolov8n_ncnn_model/
NCNN | `ncnn` | yolov8n_ncnn_model/
"""
import glob
import platform
import sys
import time
from pathlib import Path
import numpy as np
import torch.cuda
from ultralytics import YOLO
from ultralytics import YOLO, YOLOWorld
from ultralytics.cfg import TASK2DATA, TASK2METRIC
from ultralytics.engine.exporter import export_formats
from ultralytics.utils import ASSETS, LINUX, LOGGER, MACOS, SETTINGS, TQDM
from ultralytics.utils.checks import check_requirements, check_yolo
from ultralytics.utils import ASSETS, LINUX, LOGGER, MACOS, TQDM, WEIGHTS_DIR
from ultralytics.utils.checks import IS_PYTHON_3_12, check_requirements, check_yolo
from ultralytics.utils.files import file_size
from ultralytics.utils.torch_utils import select_device
def benchmark(model=Path(SETTINGS['weights_dir']) / 'yolov8n.pt',
data=None,
imgsz=160,
half=False,
int8=False,
device='cpu',
verbose=False):
def benchmark(
model=WEIGHTS_DIR / "yolov8n.pt", data=None, imgsz=160, half=False, int8=False, device="cpu", verbose=False
):
"""
Benchmark a YOLO model across different formats for speed and accuracy.
@ -76,6 +71,7 @@ def benchmark(model=Path(SETTINGS['weights_dir']) / 'yolov8n.pt',
"""
import pandas as pd
pd.options.display.max_columns = 10
pd.options.display.width = 120
device = select_device(device, verbose=False)
@ -85,67 +81,72 @@ def benchmark(model=Path(SETTINGS['weights_dir']) / 'yolov8n.pt',
y = []
t0 = time.time()
for i, (name, format, suffix, cpu, gpu) in export_formats().iterrows(): # index, (name, format, suffix, CPU, GPU)
emoji, filename = '', None # export defaults
emoji, filename = "", None # export defaults
try:
assert i != 9 or LINUX, 'Edge TPU export only supported on Linux'
if i == 10:
assert MACOS or LINUX, 'TF.js export only supported on macOS and Linux'
elif i == 11:
assert sys.version_info < (3, 11), 'PaddlePaddle export only supported on Python<=3.10'
if 'cpu' in device.type:
assert cpu, 'inference not supported on CPU'
if 'cuda' in device.type:
assert gpu, 'inference not supported on GPU'
# Checks
if i == 9: # Edge TPU
assert LINUX, "Edge TPU export only supported on Linux"
elif i == 7: # TF GraphDef
assert model.task != "obb", "TensorFlow GraphDef not supported for OBB task"
elif i in {5, 10}: # CoreML and TF.js
assert MACOS or LINUX, "export only supported on macOS and Linux"
if i in {3, 5}: # CoreML and OpenVINO
assert not IS_PYTHON_3_12, "CoreML and OpenVINO not supported on Python 3.12"
if i in {6, 7, 8, 9, 10}: # All TF formats
assert not isinstance(model, YOLOWorld), "YOLOWorldv2 TensorFlow exports not supported by onnx2tf yet"
if i in {11}: # Paddle
assert not isinstance(model, YOLOWorld), "YOLOWorldv2 Paddle exports not supported yet"
if i in {12}: # NCNN
assert not isinstance(model, YOLOWorld), "YOLOWorldv2 NCNN exports not supported yet"
if "cpu" in device.type:
assert cpu, "inference not supported on CPU"
if "cuda" in device.type:
assert gpu, "inference not supported on GPU"
# Export
if format == '-':
if format == "-":
filename = model.ckpt_path or model.cfg
export = model # PyTorch format
exported_model = model # PyTorch format
else:
filename = model.export(imgsz=imgsz, format=format, half=half, int8=int8, device=device, verbose=False)
export = YOLO(filename, task=model.task)
assert suffix in str(filename), 'export failed'
emoji = '' # indicates export succeeded
exported_model = YOLO(filename, task=model.task)
assert suffix in str(filename), "export failed"
emoji = "" # indicates export succeeded
# Predict
assert model.task != 'pose' or i != 7, 'GraphDef Pose inference is not supported'
assert i not in (9, 10), 'inference not supported' # Edge TPU and TF.js are unsupported
assert i != 5 or platform.system() == 'Darwin', 'inference only supported on macOS>=10.13' # CoreML
export.predict(ASSETS / 'bus.jpg', imgsz=imgsz, device=device, half=half)
assert model.task != "pose" or i != 7, "GraphDef Pose inference is not supported"
assert i not in (9, 10), "inference not supported" # Edge TPU and TF.js are unsupported
assert i != 5 or platform.system() == "Darwin", "inference only supported on macOS>=10.13" # CoreML
exported_model.predict(ASSETS / "bus.jpg", imgsz=imgsz, device=device, half=half)
# Validate
data = data or TASK2DATA[model.task] # task to dataset, i.e. coco8.yaml for task=detect
key = TASK2METRIC[model.task] # task to metric, i.e. metrics/mAP50-95(B) for task=detect
results = export.val(data=data,
batch=1,
imgsz=imgsz,
plots=False,
device=device,
half=half,
int8=int8,
verbose=False)
metric, speed = results.results_dict[key], results.speed['inference']
y.append([name, '', round(file_size(filename), 1), round(metric, 4), round(speed, 2)])
results = exported_model.val(
data=data, batch=1, imgsz=imgsz, plots=False, device=device, half=half, int8=int8, verbose=False
)
metric, speed = results.results_dict[key], results.speed["inference"]
y.append([name, "", round(file_size(filename), 1), round(metric, 4), round(speed, 2)])
except Exception as e:
if verbose:
assert type(e) is AssertionError, f'Benchmark failure for {name}: {e}'
LOGGER.warning(f'ERROR ❌️ Benchmark failure for {name}: {e}')
assert type(e) is AssertionError, f"Benchmark failure for {name}: {e}"
LOGGER.warning(f"ERROR ❌️ Benchmark failure for {name}: {e}")
y.append([name, emoji, round(file_size(filename), 1), None, None]) # mAP, t_inference
# Print results
check_yolo(device=device) # print system info
df = pd.DataFrame(y, columns=['Format', 'Status❔', 'Size (MB)', key, 'Inference time (ms/im)'])
df = pd.DataFrame(y, columns=["Format", "Status❔", "Size (MB)", key, "Inference time (ms/im)"])
name = Path(model.ckpt_path).name
s = f'\nBenchmarks complete for {name} on {data} at imgsz={imgsz} ({time.time() - t0:.2f}s)\n{df}\n'
s = f"\nBenchmarks complete for {name} on {data} at imgsz={imgsz} ({time.time() - t0:.2f}s)\n{df}\n"
LOGGER.info(s)
with open('benchmarks.log', 'a', errors='ignore', encoding='utf-8') as f:
with open("benchmarks.log", "a", errors="ignore", encoding="utf-8") as f:
f.write(s)
if verbose and isinstance(verbose, float):
metrics = df[key].array # values to compare to floor
floor = verbose # minimum metric floor to pass, i.e. = 0.29 mAP for YOLOv5n
assert all(x > floor for x in metrics if pd.notna(x)), f'Benchmark failure: metric(s) < floor {floor}'
assert all(x > floor for x in metrics if pd.notna(x)), f"Benchmark failure: metric(s) < floor {floor}"
return df
@ -154,8 +155,7 @@ class ProfileModels:
"""
ProfileModels class for profiling different models on ONNX and TensorRT.
This class profiles the performance of different models, provided their paths. The profiling includes parameters such as
model speed and FLOPs.
This class profiles the performance of different models, returning results such as model speed and FLOPs.
Attributes:
paths (list): Paths of the models to profile.
@ -175,15 +175,30 @@ class ProfileModels:
```
"""
def __init__(self,
paths: list,
num_timed_runs=100,
num_warmup_runs=10,
min_time=60,
imgsz=640,
half=True,
trt=True,
device=None):
def __init__(
self,
paths: list,
num_timed_runs=100,
num_warmup_runs=10,
min_time=60,
imgsz=640,
half=True,
trt=True,
device=None,
):
"""
Initialize the ProfileModels class for profiling models.
Args:
paths (list): List of paths of the models to be profiled.
num_timed_runs (int, optional): Number of timed runs for the profiling. Default is 100.
num_warmup_runs (int, optional): Number of warmup runs before the actual profiling starts. Default is 10.
min_time (float, optional): Minimum time in seconds for profiling a model. Default is 60.
imgsz (int, optional): Size of the image used during profiling. Default is 640.
half (bool, optional): Flag to indicate whether to use half-precision floating point for profiling.
trt (bool, optional): Flag to indicate whether to profile using TensorRT. Default is True.
device (torch.device, optional): Device used for profiling. If None, it is determined automatically.
"""
self.paths = paths
self.num_timed_runs = num_timed_runs
self.num_warmup_runs = num_warmup_runs
@ -191,36 +206,32 @@ class ProfileModels:
self.imgsz = imgsz
self.half = half
self.trt = trt # run TensorRT profiling
self.device = device or torch.device(0 if torch.cuda.is_available() else 'cpu')
self.device = device or torch.device(0 if torch.cuda.is_available() else "cpu")
def profile(self):
"""Logs the benchmarking results of a model, checks metrics against floor and returns the results."""
files = self.get_files()
if not files:
print('No matching *.pt or *.onnx files found.')
print("No matching *.pt or *.onnx files found.")
return
table_rows = []
output = []
for file in files:
engine_file = file.with_suffix('.engine')
if file.suffix in ('.pt', '.yaml', '.yml'):
engine_file = file.with_suffix(".engine")
if file.suffix in (".pt", ".yaml", ".yml"):
model = YOLO(str(file))
model.fuse() # to report correct params and GFLOPs in model.info()
model_info = model.info()
if self.trt and self.device.type != 'cpu' and not engine_file.is_file():
engine_file = model.export(format='engine',
half=self.half,
imgsz=self.imgsz,
device=self.device,
verbose=False)
onnx_file = model.export(format='onnx',
half=self.half,
imgsz=self.imgsz,
simplify=True,
device=self.device,
verbose=False)
elif file.suffix == '.onnx':
if self.trt and self.device.type != "cpu" and not engine_file.is_file():
engine_file = model.export(
format="engine", half=self.half, imgsz=self.imgsz, device=self.device, verbose=False
)
onnx_file = model.export(
format="onnx", half=self.half, imgsz=self.imgsz, simplify=True, device=self.device, verbose=False
)
elif file.suffix == ".onnx":
model_info = self.get_onnx_model_info(file)
onnx_file = file
else:
@ -235,25 +246,30 @@ class ProfileModels:
return output
def get_files(self):
    """
    Returns a list of paths for all relevant model files given by the user.

    Each entry of `self.paths` is handled as follows: a directory contributes its '*.pt', '*.onnx' and '*.yaml'
    files; a path with a '.pt', '.yaml' or '.yml' suffix is added as-is, even if it does not exist; anything else is
    expanded as a glob pattern.

    Returns:
        (list[Path]): The collected paths, sorted by their string form.
    """
    files = []
    for path in self.paths:
        path = Path(path)
        if path.is_dir():
            extensions = ["*.pt", "*.onnx", "*.yaml"]
            files.extend([file for ext in extensions for file in glob.glob(str(path / ext))])
        elif path.suffix in {".pt", ".yaml", ".yml"}:  # add non-existing
            files.append(str(path))
        else:
            files.extend(glob.glob(str(path)))

    files.sort()  # sort once and reuse for both the log line and the return value
    print(f"Profiling: {files}")
    return [Path(file) for file in files]
def get_onnx_model_info(self, onnx_file: str):
    """
    Retrieves the information including number of layers, parameters, gradients and FLOPs for an ONNX model file.

    This is currently a placeholder: `onnx_file` is not read and all four values are returned as 0.0.

    Args:
        onnx_file (str): Path to the ONNX model file (unused).

    Returns:
        (tuple): (num_layers, num_params, num_gradients, num_flops), all 0.0.
    """
    return 0.0, 0.0, 0.0, 0.0  # return (num_layers, num_params, num_gradients, num_flops)
def iterative_sigma_clipping(self, data, sigma=2, max_iters=3):
@staticmethod
def iterative_sigma_clipping(data, sigma=2, max_iters=3):
"""Applies an iterative sigma clipping algorithm to the given data times number of iterations."""
data = np.array(data)
for _ in range(max_iters):
mean, std = np.mean(data), np.std(data)
@ -264,6 +280,7 @@ class ProfileModels:
return data
def profile_tensorrt_model(self, engine_file: str, eps: float = 1e-3):
"""Profiles the TensorRT model, measuring average run time and standard deviation among runs."""
if not self.trt or not Path(engine_file).is_file():
return 0.0, 0.0
@ -286,39 +303,44 @@ class ProfileModels:
run_times = []
for _ in TQDM(range(num_runs), desc=engine_file):
results = model(input_data, imgsz=self.imgsz, verbose=False)
run_times.append(results[0].speed['inference']) # Convert to milliseconds
run_times.append(results[0].speed["inference"]) # Convert to milliseconds
run_times = self.iterative_sigma_clipping(np.array(run_times), sigma=2, max_iters=3) # sigma clipping
return np.mean(run_times), np.std(run_times)
def profile_onnx_model(self, onnx_file: str, eps: float = 1e-3):
check_requirements('onnxruntime')
"""Profiles an ONNX model by executing it multiple times and returns the mean and standard deviation of run
times.
"""
check_requirements("onnxruntime")
import onnxruntime as ort
# Session with either 'TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'
sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
sess_options.intra_op_num_threads = 8 # Limit the number of threads
sess = ort.InferenceSession(onnx_file, sess_options, providers=['CPUExecutionProvider'])
sess = ort.InferenceSession(onnx_file, sess_options, providers=["CPUExecutionProvider"])
input_tensor = sess.get_inputs()[0]
input_type = input_tensor.type
dynamic = not all(isinstance(dim, int) and dim >= 0 for dim in input_tensor.shape) # dynamic input shape
input_shape = (1, 3, self.imgsz, self.imgsz) if dynamic else input_tensor.shape
# Mapping ONNX datatype to numpy datatype
if 'float16' in input_type:
if "float16" in input_type:
input_dtype = np.float16
elif 'float' in input_type:
elif "float" in input_type:
input_dtype = np.float32
elif 'double' in input_type:
elif "double" in input_type:
input_dtype = np.float64
elif 'int64' in input_type:
elif "int64" in input_type:
input_dtype = np.int64
elif 'int32' in input_type:
elif "int32" in input_type:
input_dtype = np.int32
else:
raise ValueError(f'Unsupported ONNX datatype {input_type}')
raise ValueError(f"Unsupported ONNX datatype {input_type}")
input_data = np.random.rand(*input_tensor.shape).astype(input_dtype)
input_data = np.random.rand(*input_shape).astype(input_dtype)
input_name = input_tensor.name
output_name = sess.get_outputs()[0].name
@ -344,24 +366,39 @@ class ProfileModels:
return np.mean(run_times), np.std(run_times)
def generate_table_row(self, model_name, t_onnx, t_engine, model_info):
    """
    Generates a formatted string for a table row that includes model performance and metric details.

    Args:
        model_name (str): Model name, left-aligned and padded to 18 characters.
        t_onnx (tuple): (mean, std) of the ONNX run times in ms.
        t_engine (tuple): (mean, std) of the TensorRT run times in ms.
        model_info (tuple): (layers, params, gradients, flops); only params and flops are shown.

    Returns:
        (str): One Markdown table row; params are shown in millions, the mAP column is always '-'.
    """
    _, params, _, flops = model_info  # layers and gradients are not part of the table
    return (
        f"| {model_name:18s} | {self.imgsz} | - | {t_onnx[0]:.2f} ± {t_onnx[1]:.2f} ms | {t_engine[0]:.2f} ± "
        f"{t_engine[1]:.2f} ms | {params / 1e6:.1f} | {flops:.1f} |"
    )
@staticmethod
def generate_results_dict(model_name, t_onnx, t_engine, model_info):
    """
    Generates a dictionary of model details including name, parameters, GFLOPS and speed metrics.

    Args:
        model_name (str): Model name.
        t_onnx (tuple): ONNX timing; only the first element (mean, ms) is used.
        t_engine (tuple): TensorRT timing; only the first element (mean, ms) is used.
        model_info (tuple): (layers, params, gradients, flops); only params and flops are used.

    Returns:
        (dict): Results keyed by 'model/...' names, with FLOPs and speeds rounded to 3 decimals.
    """
    _, params, _, flops = model_info  # layers and gradients are not reported
    return {
        "model/name": model_name,
        "model/parameters": params,
        "model/GFLOPs": round(flops, 3),
        "model/speed_ONNX(ms)": round(t_onnx[0], 3),
        "model/speed_TensorRT(ms)": round(t_engine[0], 3),
    }
@staticmethod
def print_table(table_rows):
    """
    Formats and prints a comparison table for different models with given statistics and performance data.

    Prints two blank lines, a Markdown header and separator, then each row on its own line. The TensorRT column
    header carries the name of CUDA device 0 when CUDA is available, otherwise the literal 'GPU'.

    Args:
        table_rows (Iterable[str]): Preformatted Markdown table rows, printed in order.
    """
    gpu = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "GPU"
    header = (
        f"| Model | size<br><sup>(pixels) | mAP<sup>val<br>50-95 | Speed<br><sup>CPU ONNX<br>(ms) | "
        f"Speed<br><sup>{gpu} TensorRT<br>(ms) | params<br><sup>(M) | FLOPs<br><sup>(B) |"
    )
    separator = (
        "|-------------|---------------------|--------------------|------------------------------|"
        "-----------------------------------|------------------|-----------------|"
    )

    print(f"\n\n{header}")
    print(separator)
    for row in table_rows:
        print(row)

View File

@ -2,4 +2,4 @@
from .base import add_integration_callbacks, default_callbacks, get_default_callbacks
__all__ = 'add_integration_callbacks', 'default_callbacks', 'get_default_callbacks'
__all__ = "add_integration_callbacks", "default_callbacks", "get_default_callbacks"

View File

@ -1,11 +1,10 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
Base callbacks
"""
"""Base callbacks."""
from collections import defaultdict
from copy import deepcopy
# Trainer callbacks ----------------------------------------------------------------------------------------------------
@ -145,37 +144,35 @@ def on_export_end(exporter):
# Maps each callback event name to the list of default handlers run for that event
default_callbacks = {
    # Run in trainer
    "on_pretrain_routine_start": [on_pretrain_routine_start],
    "on_pretrain_routine_end": [on_pretrain_routine_end],
    "on_train_start": [on_train_start],
    "on_train_epoch_start": [on_train_epoch_start],
    "on_train_batch_start": [on_train_batch_start],
    "optimizer_step": [optimizer_step],
    "on_before_zero_grad": [on_before_zero_grad],
    "on_train_batch_end": [on_train_batch_end],
    "on_train_epoch_end": [on_train_epoch_end],
    "on_fit_epoch_end": [on_fit_epoch_end],  # fit = train + val
    "on_model_save": [on_model_save],
    "on_train_end": [on_train_end],
    "on_params_update": [on_params_update],
    "teardown": [teardown],
    # Run in validator
    "on_val_start": [on_val_start],
    "on_val_batch_start": [on_val_batch_start],
    "on_val_batch_end": [on_val_batch_end],
    "on_val_end": [on_val_end],
    # Run in predictor
    "on_predict_start": [on_predict_start],
    "on_predict_batch_start": [on_predict_batch_start],
    "on_predict_postprocess_end": [on_predict_postprocess_end],
    "on_predict_batch_end": [on_predict_batch_end],
    "on_predict_end": [on_predict_end],
    # Run in exporter
    "on_export_start": [on_export_start],
    "on_export_end": [on_export_end],
}
def get_default_callbacks():
@ -199,10 +196,11 @@ def add_integration_callbacks(instance):
# Load HUB callbacks
from .hub import callbacks as hub_cb
callbacks_list = [hub_cb]
# Load training callbacks
if 'Trainer' in instance.__class__.__name__:
if "Trainer" in instance.__class__.__name__:
from .clearml import callbacks as clear_cb
from .comet import callbacks as comet_cb
from .dvc import callbacks as dvc_cb
@ -211,12 +209,8 @@ def add_integration_callbacks(instance):
from .raytune import callbacks as tune_cb
from .tensorboard import callbacks as tb_cb
from .wb import callbacks as wb_cb
callbacks_list.extend([clear_cb, comet_cb, dvc_cb, mlflow_cb, neptune_cb, tune_cb, tb_cb, wb_cb])
# Load export callbacks (patch to avoid CoreML protobuf error)
if 'Exporter' in instance.__class__.__name__:
from .tensorboard import callbacks as tb_cb
callbacks_list.append(tb_cb)
callbacks_list.extend([clear_cb, comet_cb, dvc_cb, mlflow_cb, neptune_cb, tune_cb, tb_cb, wb_cb])
# Add the callbacks to the callbacks dictionary
for callbacks in callbacks_list:

View File

@ -4,19 +4,19 @@ from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING
try:
assert not TESTS_RUNNING # do not log pytest
assert SETTINGS['clearml'] is True # verify integration is enabled
assert SETTINGS["clearml"] is True # verify integration is enabled
import clearml
from clearml import Task
from clearml.binding.frameworks.pytorch_bind import PatchPyTorchModelIO
from clearml.binding.matplotlib_bind import PatchedMatplotlib
assert hasattr(clearml, '__version__') # verify package is not directory
assert hasattr(clearml, "__version__") # verify package is not directory
except (ImportError, AssertionError):
clearml = None
def _log_debug_samples(files, title='Debug Samples') -> None:
def _log_debug_samples(files, title="Debug Samples") -> None:
"""
Log files (images) as debug samples in the ClearML task.
@ -29,12 +29,11 @@ def _log_debug_samples(files, title='Debug Samples') -> None:
if task := Task.current_task():
for f in files:
if f.exists():
it = re.search(r'_batch(\d+)', f.name)
it = re.search(r"_batch(\d+)", f.name)
iteration = int(it.groups()[0]) if it else 0
task.get_logger().report_image(title=title,
series=f.name.replace(it.group(), ''),
local_path=str(f),
iteration=iteration)
task.get_logger().report_image(
title=title, series=f.name.replace(it.group(), ""), local_path=str(f), iteration=iteration
)
def _log_plot(title, plot_path) -> None:
@ -50,13 +49,12 @@ def _log_plot(title, plot_path) -> None:
img = mpimg.imread(plot_path)
fig = plt.figure()
ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect='auto', xticks=[], yticks=[]) # no ticks
ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect="auto", xticks=[], yticks=[]) # no ticks
ax.imshow(img)
Task.current_task().get_logger().report_matplotlib_figure(title=title,
series='',
figure=fig,
report_interactive=False)
Task.current_task().get_logger().report_matplotlib_figure(
title=title, series="", figure=fig, report_interactive=False
)
def on_pretrain_routine_start(trainer):
@ -68,19 +66,21 @@ def on_pretrain_routine_start(trainer):
PatchPyTorchModelIO.update_current_task(None)
PatchedMatplotlib.update_current_task(None)
else:
task = Task.init(project_name=trainer.args.project or 'YOLOv8',
task_name=trainer.args.name,
tags=['YOLOv8'],
output_uri=True,
reuse_last_task_id=False,
auto_connect_frameworks={
'pytorch': False,
'matplotlib': False})
LOGGER.warning('ClearML Initialized a new task. If you want to run remotely, '
'please add clearml-init and connect your arguments before initializing YOLO.')
task.connect(vars(trainer.args), name='General')
task = Task.init(
project_name=trainer.args.project or "YOLOv8",
task_name=trainer.args.name,
tags=["YOLOv8"],
output_uri=True,
reuse_last_task_id=False,
auto_connect_frameworks={"pytorch": False, "matplotlib": False},
)
LOGGER.warning(
"ClearML Initialized a new task. If you want to run remotely, "
"please add clearml-init and connect your arguments before initializing YOLO."
)
task.connect(vars(trainer.args), name="General")
except Exception as e:
LOGGER.warning(f'WARNING ⚠️ ClearML installed but not initialized correctly, not logging this run. {e}')
LOGGER.warning(f"WARNING ⚠️ ClearML installed but not initialized correctly, not logging this run. {e}")
def on_train_epoch_end(trainer):
@ -88,22 +88,26 @@ def on_train_epoch_end(trainer):
if task := Task.current_task():
# Log debug samples
if trainer.epoch == 1:
_log_debug_samples(sorted(trainer.save_dir.glob('train_batch*.jpg')), 'Mosaic')
_log_debug_samples(sorted(trainer.save_dir.glob("train_batch*.jpg")), "Mosaic")
# Report the current training progress
for k, v in trainer.validator.metrics.results_dict.items():
task.get_logger().report_scalar('train', k, v, iteration=trainer.epoch)
for k, v in trainer.label_loss_items(trainer.tloss, prefix="train").items():
task.get_logger().report_scalar("train", k, v, iteration=trainer.epoch)
for k, v in trainer.lr.items():
task.get_logger().report_scalar("lr", k, v, iteration=trainer.epoch)
def on_fit_epoch_end(trainer):
    """
    Report per-epoch information to the current ClearML task.

    Does nothing when `Task.current_task()` returns a falsy value.

    Args:
        trainer: Trainer object; `epoch`, `epoch_time` and `metrics` are read from it.
    """
    if task := Task.current_task():
        # You should have access to the validation bboxes under jdict
        task.get_logger().report_scalar(
            title="Epoch Time", series="Epoch Time", value=trainer.epoch_time, iteration=trainer.epoch
        )
        # Every entry of trainer.metrics is reported under the "val" title.
        for k, v in trainer.metrics.items():
            task.get_logger().report_scalar("val", k, v, iteration=trainer.epoch)
        if trainer.epoch == 0:
            # Imported lazily, only needed on the first epoch.
            from ultralytics.utils.torch_utils import model_info_for_loggers

            for k, v in model_info_for_loggers(trainer).items():
                task.get_logger().report_single_value(k, v)
@ -112,7 +116,7 @@ def on_val_end(validator):
"""Logs validation results including labels and predictions."""
if Task.current_task():
# Log val_labels and val_pred
_log_debug_samples(sorted(validator.save_dir.glob('val*.jpg')), 'Validation')
_log_debug_samples(sorted(validator.save_dir.glob("val*.jpg")), "Validation")
def on_train_end(trainer):
@ -120,8 +124,11 @@ def on_train_end(trainer):
if task := Task.current_task():
# Log final results, CM matrix + PR plots
files = [
'results.png', 'confusion_matrix.png', 'confusion_matrix_normalized.png',
*(f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R'))]
"results.png",
"confusion_matrix.png",
"confusion_matrix_normalized.png",
*(f"{x}_curve.png" for x in ("F1", "PR", "P", "R")),
]
files = [(trainer.save_dir / f) for f in files if (trainer.save_dir / f).exists()] # filter
for f in files:
_log_plot(title=f.stem, plot_path=f)
@ -132,9 +139,14 @@ def on_train_end(trainer):
task.update_output_model(model_path=str(trainer.best), model_name=trainer.args.name, auto_delete_file=False)
# Callback table for the ClearML integration; empty when `clearml` is falsy.
callbacks = (
    {
        "on_pretrain_routine_start": on_pretrain_routine_start,
        "on_train_epoch_end": on_train_epoch_end,
        "on_fit_epoch_end": on_fit_epoch_end,
        "on_val_end": on_val_end,
        "on_train_end": on_train_end,
    }
    if clearml
    else {}
)

View File

@ -4,20 +4,20 @@ from ultralytics.utils import LOGGER, RANK, SETTINGS, TESTS_RUNNING, ops
try:
assert not TESTS_RUNNING # do not log pytest
assert SETTINGS['comet'] is True # verify integration is enabled
assert SETTINGS["comet"] is True # verify integration is enabled
import comet_ml
assert hasattr(comet_ml, '__version__') # verify package is not directory
assert hasattr(comet_ml, "__version__") # verify package is not directory
import os
from pathlib import Path
# Ensures certain logging functions only run for supported tasks
COMET_SUPPORTED_TASKS = ['detect']
COMET_SUPPORTED_TASKS = ["detect"]
# Names of plots created by YOLOv8 that are logged to Comet
EVALUATION_PLOT_NAMES = 'F1_curve', 'P_curve', 'R_curve', 'PR_curve', 'confusion_matrix'
LABEL_PLOT_NAMES = 'labels', 'labels_correlogram'
EVALUATION_PLOT_NAMES = "F1_curve", "P_curve", "R_curve", "PR_curve", "confusion_matrix"
LABEL_PLOT_NAMES = "labels", "labels_correlogram"
_comet_image_prediction_count = 0
@ -26,37 +26,44 @@ except (ImportError, AssertionError):
def _get_comet_mode():
    """Return the value of the COMET_MODE environment variable, or 'online' when it is not set."""
    return os.getenv("COMET_MODE", "online")
def _get_comet_model_name():
    """Return the value of the COMET_MODEL_NAME environment variable, or 'YOLOv8' when it is not set."""
    return os.getenv("COMET_MODEL_NAME", "YOLOv8")
def _get_eval_batch_logging_interval():
    """
    Return COMET_EVAL_BATCH_LOGGING_INTERVAL as an int, defaulting to 1 when it is not set.

    Raises:
        ValueError: If the environment variable is set to a value `int()` cannot parse.
    """
    return int(os.getenv("COMET_EVAL_BATCH_LOGGING_INTERVAL", 1))
def _get_max_image_predictions_to_log():
    """
    Return COMET_MAX_IMAGE_PREDICTIONS as an int, defaulting to 100 when it is not set.

    Raises:
        ValueError: If the environment variable is set to a value `int()` cannot parse.
    """
    return int(os.getenv("COMET_MAX_IMAGE_PREDICTIONS", 100))
def _scale_confidence_score(score):
    """
    Multiply a confidence score by the factor given in COMET_MAX_CONFIDENCE_SCORE.

    Args:
        score: Numeric confidence score to scale.

    Returns:
        `score` multiplied by the environment-configured factor (100.0 when the variable is not set).
    """
    scale = float(os.getenv("COMET_MAX_CONFIDENCE_SCORE", 100.0))
    return score * scale
def _should_log_confusion_matrix():
    """Return True when COMET_EVAL_LOG_CONFUSION_MATRIX equals 'true' (case-insensitive); defaults to False."""
    return os.getenv("COMET_EVAL_LOG_CONFUSION_MATRIX", "false").lower() == "true"
def _should_log_image_predictions():
    """Return True when COMET_EVAL_LOG_IMAGE_PREDICTIONS equals 'true' (case-insensitive); defaults to True."""
    return os.getenv("COMET_EVAL_LOG_IMAGE_PREDICTIONS", "true").lower() == "true"
def _get_experiment_type(mode, project_name):
    """
    Return a Comet experiment object chosen by mode.

    Args:
        mode: Comet mode string; 'offline' selects `comet_ml.OfflineExperiment`.
        project_name: Project name forwarded to the experiment constructor.

    Returns:
        A `comet_ml.OfflineExperiment` for mode 'offline', otherwise a `comet_ml.Experiment`.
    """
    if mode == "offline":
        return comet_ml.OfflineExperiment(project_name=project_name)

    return comet_ml.Experiment(project_name=project_name)
@ -68,18 +75,21 @@ def _create_experiment(args):
return
try:
comet_mode = _get_comet_mode()
_project_name = os.getenv('COMET_PROJECT_NAME', args.project)
_project_name = os.getenv("COMET_PROJECT_NAME", args.project)
experiment = _get_experiment_type(comet_mode, _project_name)
experiment.log_parameters(vars(args))
experiment.log_others({
'eval_batch_logging_interval': _get_eval_batch_logging_interval(),
'log_confusion_matrix_on_eval': _should_log_confusion_matrix(),
'log_image_predictions': _should_log_image_predictions(),
'max_image_predictions': _get_max_image_predictions_to_log(), })
experiment.log_other('Created from', 'yolov8')
experiment.log_others(
{
"eval_batch_logging_interval": _get_eval_batch_logging_interval(),
"log_confusion_matrix_on_eval": _should_log_confusion_matrix(),
"log_image_predictions": _should_log_image_predictions(),
"max_image_predictions": _get_max_image_predictions_to_log(),
}
)
experiment.log_other("Created from", "yolov8")
except Exception as e:
LOGGER.warning(f'WARNING ⚠️ Comet installed but not initialized correctly, not logging this run. {e}')
LOGGER.warning(f"WARNING ⚠️ Comet installed but not initialized correctly, not logging this run. {e}")
def _fetch_trainer_metadata(trainer):
@ -95,18 +105,14 @@ def _fetch_trainer_metadata(trainer):
save_interval = curr_epoch % save_period == 0
save_assets = save and save_period > 0 and save_interval and not final_epoch
return dict(
curr_epoch=curr_epoch,
curr_step=curr_step,
save_assets=save_assets,
final_epoch=final_epoch,
)
return dict(curr_epoch=curr_epoch, curr_step=curr_step, save_assets=save_assets, final_epoch=final_epoch)
def _scale_bounding_box_to_original_image_shape(box, resized_image_shape, original_image_shape, ratio_pad):
"""YOLOv8 resizes images during training and the label values
are normalized based on this resized shape. This function rescales the
bounding box labels to the original image shape.
"""
YOLOv8 resizes images during training and the label values are normalized based on this resized shape.
This function rescales the bounding box labels to the original image shape.
"""
resized_image_height, resized_image_width = resized_image_shape
@ -126,29 +132,32 @@ def _scale_bounding_box_to_original_image_shape(box, resized_image_shape, origin
def _format_ground_truth_annotations_for_detection(img_idx, image_path, batch, class_name_map=None):
    """
    Format the ground-truth boxes of one image in a batch as a Comet annotation dict.

    Args:
        img_idx: Index of the image within the batch.
        image_path: Path of the image; used only in the debug message.
        batch: Mapping with the keys 'batch_idx', 'bboxes', 'cls', 'ori_shape', 'resized_shape' and 'ratio_pad'.
        class_name_map: Optional mapping from class label to name; when truthy, labels are replaced by `str(name)`.

    Returns:
        `{"name": "ground_truth", "data": [...]}`, or None when the image has no boxes.
    """
    indices = batch["batch_idx"] == img_idx  # selects the rows belonging to this image
    bboxes = batch["bboxes"][indices]
    if len(bboxes) == 0:
        LOGGER.debug(f"COMET WARNING: Image: {image_path} has no bounding boxes labels")
        return None

    cls_labels = batch["cls"][indices].squeeze(1).tolist()
    if class_name_map:
        cls_labels = [str(class_name_map[label]) for label in cls_labels]

    original_image_shape = batch["ori_shape"][img_idx]
    resized_image_shape = batch["resized_shape"][img_idx]
    ratio_pad = batch["ratio_pad"][img_idx]

    data = []
    for box, label in zip(bboxes, cls_labels):
        box = _scale_bounding_box_to_original_image_shape(box, resized_image_shape, original_image_shape, ratio_pad)
        data.append(
            {
                "boxes": [box],
                "label": f"gt_{label}",
                # Ground truth is given the score 1.0 before scaling.
                "score": _scale_confidence_score(1.0),
            }
        )

    return {"name": "ground_truth", "data": data}
def _format_prediction_annotations_for_detection(image_path, metadata, class_label_map=None):
@ -158,31 +167,34 @@ def _format_prediction_annotations_for_detection(image_path, metadata, class_lab
predictions = metadata.get(image_id)
if not predictions:
LOGGER.debug(f'COMET WARNING: Image: {image_path} has no bounding boxes predictions')
LOGGER.debug(f"COMET WARNING: Image: {image_path} has no bounding boxes predictions")
return None
data = []
for prediction in predictions:
boxes = prediction['bbox']
score = _scale_confidence_score(prediction['score'])
cls_label = prediction['category_id']
boxes = prediction["bbox"]
score = _scale_confidence_score(prediction["score"])
cls_label = prediction["category_id"]
if class_label_map:
cls_label = str(class_label_map[cls_label])
data.append({'boxes': [boxes], 'label': cls_label, 'score': score})
data.append({"boxes": [boxes], "label": cls_label, "score": score})
return {'name': 'prediction', 'data': data}
return {"name": "prediction", "data": data}
def _fetch_annotations(img_idx, image_path, batch, prediction_metadata_map, class_label_map):
    """
    Join the ground-truth and prediction annotations of one image, keeping only those that exist.

    Returns:
        A one-element list wrapping the list of non-None annotation dicts, or None when both are missing.
    """
    ground_truth_annotations = _format_ground_truth_annotations_for_detection(
        img_idx, image_path, batch, class_label_map
    )
    prediction_annotations = _format_prediction_annotations_for_detection(
        image_path, prediction_metadata_map, class_label_map
    )

    annotations = [
        annotation for annotation in [ground_truth_annotations, prediction_annotations] if annotation is not None
    ]
    return [annotations] if annotations else None
@ -190,8 +202,8 @@ def _create_prediction_metadata_map(model_predictions):
"""Create metadata map for model predictions by groupings them based on image ID."""
pred_metadata_map = {}
for prediction in model_predictions:
pred_metadata_map.setdefault(prediction['image_id'], [])
pred_metadata_map[prediction['image_id']].append(prediction)
pred_metadata_map.setdefault(prediction["image_id"], [])
pred_metadata_map[prediction["image_id"]].append(prediction)
return pred_metadata_map
@ -199,13 +211,9 @@ def _create_prediction_metadata_map(model_predictions):
def _log_confusion_matrix(experiment, trainer, curr_step, curr_epoch):
    """
    Log the validator's confusion matrix to a Comet experiment.

    Args:
        experiment: Object exposing `log_confusion_matrix(...)`.
        trainer: Trainer; `validator.confusion_matrix.matrix` and `data["names"]` are read from it.
        curr_step: Step value forwarded to Comet.
        curr_epoch: Epoch value forwarded to Comet.
    """
    conf_mat = trainer.validator.confusion_matrix.matrix
    # A trailing 'background' label is appended after the dataset class names.
    names = list(trainer.data["names"].values()) + ["background"]
    experiment.log_confusion_matrix(
        matrix=conf_mat, labels=names, max_categories=len(names), epoch=curr_epoch, step=curr_step
    )
@ -243,7 +251,7 @@ def _log_image_predictions(experiment, validator, curr_step):
if (batch_idx + 1) % batch_logging_interval != 0:
continue
image_paths = batch['im_file']
image_paths = batch["im_file"]
for img_idx, image_path in enumerate(image_paths):
if _comet_image_prediction_count >= max_image_predictions:
return
@ -267,28 +275,23 @@ def _log_image_predictions(experiment, validator, curr_step):
def _log_plots(experiment, trainer):
    """
    Log the evaluation plots (.png) and label plots (.jpg) found under `trainer.save_dir`.

    File names are built from EVALUATION_PLOT_NAMES and LABEL_PLOT_NAMES; both groups are
    passed to `_log_images` with a step of None.
    """
    plot_filenames = [trainer.save_dir / f"{plots}.png" for plots in EVALUATION_PLOT_NAMES]
    _log_images(experiment, plot_filenames, None)

    label_plot_filenames = [trainer.save_dir / f"{labels}.jpg" for labels in LABEL_PLOT_NAMES]
    _log_images(experiment, label_plot_filenames, None)
def _log_model(experiment, trainer):
    """
    Log the file at `trainer.best` to Comet as 'best.pt'.

    The Comet model name comes from `_get_comet_model_name()`; `overwrite=True` is passed to `log_model`.
    """
    model_name = _get_comet_model_name()
    experiment.log_model(model_name, file_or_folder=str(trainer.best), file_name="best.pt", overwrite=True)
def on_pretrain_routine_start(trainer):
    """
    Ensure a Comet experiment exists at the start of the pre-training routine.

    A new experiment is created from `trainer.args` when there is no global experiment or
    its `alive` attribute is missing or falsy.
    """
    experiment = comet_ml.get_global_experiment()
    is_alive = getattr(experiment, "alive", False)
    if not experiment or not is_alive:
        _create_experiment(trainer.args)
@ -300,17 +303,13 @@ def on_train_epoch_end(trainer):
return
metadata = _fetch_trainer_metadata(trainer)
curr_epoch = metadata['curr_epoch']
curr_step = metadata['curr_step']
curr_epoch = metadata["curr_epoch"]
curr_step = metadata["curr_step"]
experiment.log_metrics(
trainer.label_loss_items(trainer.tloss, prefix='train'),
step=curr_step,
epoch=curr_epoch,
)
experiment.log_metrics(trainer.label_loss_items(trainer.tloss, prefix="train"), step=curr_step, epoch=curr_epoch)
if curr_epoch == 1:
_log_images(experiment, trainer.save_dir.glob('train_batch*.jpg'), curr_step)
_log_images(experiment, trainer.save_dir.glob("train_batch*.jpg"), curr_step)
def on_fit_epoch_end(trainer):
@ -320,14 +319,15 @@ def on_fit_epoch_end(trainer):
return
metadata = _fetch_trainer_metadata(trainer)
curr_epoch = metadata['curr_epoch']
curr_step = metadata['curr_step']
save_assets = metadata['save_assets']
curr_epoch = metadata["curr_epoch"]
curr_step = metadata["curr_step"]
save_assets = metadata["save_assets"]
experiment.log_metrics(trainer.metrics, step=curr_step, epoch=curr_epoch)
experiment.log_metrics(trainer.lr, step=curr_step, epoch=curr_epoch)
if curr_epoch == 1:
from ultralytics.utils.torch_utils import model_info_for_loggers
experiment.log_metrics(model_info_for_loggers(trainer), step=curr_step, epoch=curr_epoch)
if not save_assets:
@ -347,8 +347,8 @@ def on_train_end(trainer):
return
metadata = _fetch_trainer_metadata(trainer)
curr_epoch = metadata['curr_epoch']
curr_step = metadata['curr_step']
curr_epoch = metadata["curr_epoch"]
curr_step = metadata["curr_step"]
plots = trainer.args.plots
_log_model(experiment, trainer)
@ -363,8 +363,13 @@ def on_train_end(trainer):
_comet_image_prediction_count = 0
# Callback table for the Comet integration; empty when `comet_ml` is falsy.
callbacks = (
    {
        "on_pretrain_routine_start": on_pretrain_routine_start,
        "on_train_epoch_end": on_train_epoch_end,
        "on_fit_epoch_end": on_fit_epoch_end,
        "on_train_end": on_train_end,
    }
    if comet_ml
    else {}
)

View File

@ -1,26 +1,18 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING
from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, checks
try:
assert not TESTS_RUNNING # do not log pytest
assert SETTINGS['dvc'] is True # verify integration is enabled
assert SETTINGS["dvc"] is True # verify integration is enabled
import dvclive
assert hasattr(dvclive, '__version__') # verify package is not directory
assert checks.check_version("dvclive", "2.11.0", verbose=True)
import os
import re
from importlib.metadata import version
from pathlib import Path
import pkg_resources as pkg
ver = version('dvclive')
if pkg.parse_version(ver) < pkg.parse_version('2.11.0'):
LOGGER.debug(f'DVCLive is detected but version {ver} is incompatible (>=2.11 required).')
dvclive = None # noqa: F811
# DVCLive logger instance
live = None
_processed_plots = {}
@ -33,108 +25,121 @@ except (ImportError, AssertionError, TypeError):
dvclive = None
def _log_images(path, prefix=""):
    """
    Log the image at `path` through DVCLive, optionally under a prefix.

    Does nothing when the module-level `live` logger is falsy.

    Args:
        path: Path-like object with `name`, `stem` and `suffix` attributes.
        prefix: Leading path component joined in front of the logged name.
    """
    if live:
        name = path.name

        # Group images by batch to enable sliders in UI
        if m := re.search(r"_batch(\d+)", name):
            ni = m[1]
            # e.g. 'train_batch12.jpg' is logged as 'train_batch/12.jpg'
            new_stem = re.sub(r"_batch(\d+)", "_batch", path.stem)
            name = (Path(new_stem) / ni).with_suffix(path.suffix)

        live.log_image(os.path.join(prefix, name), path)
def _log_plots(plots, prefix=""):
    """
    Log each plot whose timestamp differs from the one recorded in `_processed_plots`.

    Args:
        plots: Mapping of plot name to a dict holding a 'timestamp' entry.
        prefix: Prefix forwarded to `_log_images`.
    """
    for name, params in plots.items():
        timestamp = params["timestamp"]
        if _processed_plots.get(name) != timestamp:
            _log_images(name, prefix)
            # Remember the timestamp so an unchanged plot is not logged again.
            _processed_plots[name] = timestamp
def _log_confusion_matrix(validator):
    """
    Log the validator's confusion matrix through DVCLive as a sklearn-style plot.

    The count matrix is expanded into parallel `targets` / `preds` label lists, because
    `live.log_sklearn_plot` is called with label sequences rather than a matrix.
    """
    targets = []
    preds = []
    matrix = validator.confusion_matrix.matrix
    names = list(validator.names.values())
    if validator.confusion_matrix.task == "detect":
        names += ["background"]

    # The outer index of the transposed matrix picks the target name, the inner index the prediction name.
    for ti, pred in enumerate(matrix.T.astype(int)):
        for pi, num in enumerate(pred):
            targets.extend([names[ti]] * num)
            preds.extend([names[pi]] * num)

    live.log_sklearn_plot("confusion_matrix", targets, preds, name="cf.json", normalized=True)
def on_pretrain_routine_start(trainer):
    """
    Create the module-level DVCLive logger at the start of the pre-training routine.

    Any exception raised while creating the logger is caught and reported as a warning.
    """
    try:
        global live
        live = dvclive.Live(save_dvc_exp=True, cache_images=True)
        LOGGER.info("DVCLive is detected and auto logging is enabled (run 'yolo settings dvc=False' to disable).")
    except Exception as e:
        LOGGER.warning(f"WARNING ⚠️ DVCLive installed but not initialized correctly, not logging this run. {e}")
def on_pretrain_routine_end(trainer):
    """Log the plots in `trainer.plots` under the 'train' prefix at the end of the pre-training routine."""
    _log_plots(trainer.plots, "train")
def on_train_start(trainer):
    """Log `trainer.args` as parameters when the DVCLive logger is active."""
    if live:
        live.log_params(trainer.args)
def on_train_epoch_start(trainer):
    """Set the module-level `_training_epoch` flag to True at the start of each training epoch."""
    global _training_epoch
    _training_epoch = True
def on_fit_epoch_end(trainer):
    """
    Log metrics and plots for the finished epoch and advance the DVCLive step.

    Runs only when the logger is active and `_training_epoch` is set; the flag is
    cleared again at the end.
    """
    global _training_epoch
    if live and _training_epoch:
        all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix="train"), **trainer.metrics, **trainer.lr}
        for metric, value in all_metrics.items():
            live.log_metric(metric, value)

        if trainer.epoch == 0:
            # Imported lazily, only needed on the first epoch.
            from ultralytics.utils.torch_utils import model_info_for_loggers

            for metric, value in model_info_for_loggers(trainer).items():
                live.log_metric(metric, value, plot=False)

        _log_plots(trainer.plots, "train")
        _log_plots(trainer.validator.plots, "val")

        live.next_step()
        _training_epoch = False
def on_train_end(trainer):
    """
    Log final metrics, plots, the confusion matrix and the best model, then end the DVCLive run.

    Does nothing when the DVCLive logger is not active.
    """
    if live:
        # At the end log the best metrics. It runs validator on the best model internally.
        all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix="train"), **trainer.metrics, **trainer.lr}
        for metric, value in all_metrics.items():
            live.log_metric(metric, value, plot=False)

        _log_plots(trainer.plots, "val")
        _log_plots(trainer.validator.plots, "val")
        _log_confusion_matrix(trainer.validator)

        # The model artifact is logged only when the best-weights file exists.
        if trainer.best.exists():
            live.log_artifact(trainer.best, copy=True, type="model")

        live.end()
# Callback table for the DVCLive integration; empty when `dvclive` is falsy.
callbacks = (
    {
        "on_pretrain_routine_start": on_pretrain_routine_start,
        "on_pretrain_routine_end": on_pretrain_routine_end,
        "on_train_start": on_train_start,
        "on_train_epoch_start": on_train_epoch_start,
        "on_fit_epoch_end": on_fit_epoch_end,
        "on_train_end": on_train_end,
    }
    if dvclive
    else {}
)

View File

@ -9,51 +9,67 @@ from ultralytics.utils import LOGGER, SETTINGS
def on_pretrain_routine_end(trainer):
    """
    Start the upload rate-limit timers on the trainer's HUB session.

    Does nothing when the trainer has no truthy `hub_session` attribute.
    """
    session = getattr(trainer, "hub_session", None)
    if session:
        # Start timer for upload rate limit
        session.timers = {
            "metrics": time(),
            "ckpt": time(),
        }  # start timer on session.rate_limit
def on_fit_epoch_end(trainer):
    """
    Queue this epoch's metrics for the HUB session and upload them when the rate limit allows.

    Does nothing when the trainer has no truthy `hub_session` attribute.
    """
    session = getattr(trainer, "hub_session", None)
    if session:
        # Upload metrics after val end
        all_plots = {
            **trainer.label_loss_items(trainer.tloss, prefix="train"),
            **trainer.metrics,
        }
        if trainer.epoch == 0:
            # Imported lazily, only needed on the first epoch.
            from ultralytics.utils.torch_utils import model_info_for_loggers

            all_plots = {**all_plots, **model_info_for_loggers(trainer)}

        session.metrics_queue[trainer.epoch] = json.dumps(all_plots)

        # If any metrics fail to upload, add them to the queue to attempt uploading again.
        if session.metrics_upload_failed_queue:
            session.metrics_queue.update(session.metrics_upload_failed_queue)

        if time() - session.timers["metrics"] > session.rate_limits["metrics"]:
            session.upload_metrics()
            session.timers["metrics"] = time()  # reset timer
            session.metrics_queue = {}  # reset queue
def on_model_save(trainer):
    """
    Upload the latest checkpoint to Ultralytics HUB, subject to the 'ckpt' rate limit.

    Does nothing when the trainer has no truthy `hub_session` attribute.
    """
    session = getattr(trainer, "hub_session", None)
    if session:
        # Upload checkpoints with rate limiting
        is_best = trainer.best_fitness == trainer.fitness
        if time() - session.timers["ckpt"] > session.rate_limits["ckpt"]:
            LOGGER.info(f"{PREFIX}Uploading checkpoint {HUB_WEB_ROOT}/models/{session.model.id}")
            session.upload_model(trainer.epoch, trainer.last, is_best)
            session.timers["ckpt"] = time()  # reset timer
def on_train_end(trainer):
    """
    Upload the final model to Ultralytics HUB at the end of training and stop heartbeats.

    Does nothing when the trainer has no truthy `hub_session` attribute.
    """
    session = getattr(trainer, "hub_session", None)
    if session:
        # Upload final model and metrics with exponential standoff
        LOGGER.info(f"{PREFIX}Syncing final model...")
        session.upload_model(
            trainer.epoch,
            trainer.best,
            # Falls back to 0 when the metric key is absent.
            map=trainer.metrics.get("metrics/mAP50-95(B)", 0),
            final=True,
        )
        session.alive = False  # stop heartbeats
        LOGGER.info(f"{PREFIX}Done ✅\n" f"{PREFIX}View model at {session.model_url} 🚀")
def on_train_start(trainer):
@ -76,12 +92,17 @@ def on_export_start(exporter):
events(exporter.args)
# Callback table for the Ultralytics HUB integration; empty unless SETTINGS["hub"] is exactly True.
callbacks = (
    {
        "on_pretrain_routine_end": on_pretrain_routine_end,
        "on_fit_epoch_end": on_fit_epoch_end,
        "on_model_save": on_model_save,
        "on_train_end": on_train_end,
        "on_train_start": on_train_start,
        "on_val_start": on_val_start,
        "on_predict_start": on_predict_start,
        "on_export_start": on_export_start,
    }
    if SETTINGS["hub"] is True
    else {}
)  # verify enabled

View File

@ -1,70 +1,133 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
MLflow Logging for Ultralytics YOLO.
from ultralytics.utils import LOGGER, ROOT, SETTINGS, TESTS_RUNNING, colorstr
This module enables MLflow logging for Ultralytics YOLO. It logs metrics, parameters, and model artifacts.
For setting up, a tracking URI should be specified. The logging can be customized using environment variables.
Commands:
1. To set a project name:
`export MLFLOW_EXPERIMENT_NAME=<your_experiment_name>` or use the project=<project> argument
2. To set a run name:
`export MLFLOW_RUN=<your_run_name>` or use the name=<name> argument
3. To start a local MLflow server:
mlflow server --backend-store-uri runs/mlflow
It will by default start a local server at http://127.0.0.1:5000.
To specify a different URI, set the MLFLOW_TRACKING_URI environment variable.
4. To kill all running MLflow server instances:
ps aux | grep 'mlflow' | grep -v 'grep' | awk '{print $2}' | xargs kill -9
"""
from ultralytics.utils import LOGGER, RUNS_DIR, SETTINGS, TESTS_RUNNING, colorstr
try:
assert not TESTS_RUNNING # do not log pytest
assert SETTINGS['mlflow'] is True # verify integration is enabled
import os
assert not TESTS_RUNNING or "test_mlflow" in os.environ.get("PYTEST_CURRENT_TEST", "") # do not log pytest
assert SETTINGS["mlflow"] is True # verify integration is enabled
import mlflow
assert hasattr(mlflow, '__version__') # verify package is not directory
assert hasattr(mlflow, "__version__") # verify package is not directory
from pathlib import Path
import os
import re
PREFIX = colorstr("MLflow: ")
SANITIZE = lambda x: {k.replace("(", "").replace(")", ""): float(v) for k, v in x.items()}
except (ImportError, AssertionError):
mlflow = None
def on_pretrain_routine_end(trainer):
    """
    Log training parameters to MLflow at the end of the pretraining routine.

    This function sets up MLflow logging based on environment variables and trainer arguments. It sets the tracking URI,
    experiment name, and run name, then starts the MLflow run if not already active. It finally logs the parameters
    from the trainer.

    Args:
        trainer (ultralytics.engine.trainer.BaseTrainer): The training object with arguments and parameters to log.

    Global:
        mlflow: The imported mlflow module to use for logging.

    Environment Variables:
        MLFLOW_TRACKING_URI: The URI for MLflow tracking. If not set, defaults to 'runs/mlflow'.
        MLFLOW_EXPERIMENT_NAME: The name of the MLflow experiment. If not set, defaults to trainer.args.project.
        MLFLOW_RUN: The name of the MLflow run. If not set, defaults to trainer.args.name.
        MLFLOW_KEEP_RUN_ACTIVE: Boolean indicating whether to keep the MLflow run active after the end of the training phase.
    """
    global mlflow

    uri = os.environ.get("MLFLOW_TRACKING_URI") or str(RUNS_DIR / "mlflow")
    LOGGER.debug(f"{PREFIX} tracking uri: {uri}")
    mlflow.set_tracking_uri(uri)

    # Set experiment and run names
    experiment_name = os.environ.get("MLFLOW_EXPERIMENT_NAME") or trainer.args.project or "/Shared/YOLOv8"
    run_name = os.environ.get("MLFLOW_RUN") or trainer.args.name
    mlflow.set_experiment(experiment_name)

    mlflow.autolog()
    try:
        # Reuse an already active run; otherwise start a new one under run_name.
        active_run = mlflow.active_run() or mlflow.start_run(run_name=run_name)
        LOGGER.info(f"{PREFIX}logging run_id({active_run.info.run_id}) to {uri}")
        if Path(uri).is_dir():
            LOGGER.info(f"{PREFIX}view at http://127.0.0.1:5000 with 'mlflow server --backend-store-uri {uri}'")
        LOGGER.info(f"{PREFIX}disable with 'yolo settings mlflow=False'")
        mlflow.log_params(dict(trainer.args))
    except Exception as e:
        LOGGER.warning(f"{PREFIX}WARNING ⚠️ Failed to initialize: {e}\n" f"{PREFIX}WARNING ⚠️ Not tracking this run")
def on_train_epoch_end(trainer):
    """
    Log learning rates and training loss items to MLflow at the end of each train epoch.

    Does nothing when the module-level `mlflow` is falsy. Values are passed through
    SANITIZE first and logged with `step=trainer.epoch`.
    """
    if mlflow:
        mlflow.log_metrics(
            metrics={
                **SANITIZE(trainer.lr),
                **SANITIZE(trainer.label_loss_items(trainer.tloss, prefix="train")),
            },
            step=trainer.epoch,
        )
def on_fit_epoch_end(trainer):
    """
    Log `trainer.metrics` to MLflow at the end of each fit epoch.

    Does nothing when the module-level `mlflow` is falsy. Metrics are passed through
    SANITIZE first and logged with `step=trainer.epoch`.
    """
    if mlflow:
        mlflow.log_metrics(metrics=SANITIZE(trainer.metrics), step=trainer.epoch)
def on_train_end(trainer):
"""Called at end of train loop to log model artifact info."""
"""Log model artifacts at the end of the training."""
if mlflow:
run.log_artifact(trainer.last)
run.log_artifact(trainer.best)
run.pyfunc.log_model(artifact_path=experiment_name,
code_path=[str(ROOT.parent)],
artifacts={'model_path': str(trainer.save_dir)},
python_model=run.pyfunc.PythonModel())
mlflow.log_artifact(str(trainer.best.parent)) # log save_dir/weights directory with best.pt and last.pt
for f in trainer.save_dir.glob("*"): # log all other files in save_dir
if f.suffix in {".png", ".jpg", ".csv", ".pt", ".yaml"}:
mlflow.log_artifact(str(f))
keep_run_active = os.environ.get("MLFLOW_KEEP_RUN_ACTIVE", "False").lower() in ("true")
if keep_run_active:
LOGGER.info(f"{PREFIX}mlflow run still alive, remember to close it using mlflow.end_run()")
else:
mlflow.end_run()
LOGGER.debug(f"{PREFIX}mlflow run ended")
LOGGER.info(
f"{PREFIX}results logged to {mlflow.get_tracking_uri()}\n"
f"{PREFIX}disable with 'yolo settings mlflow=False'"
)
callbacks = {
'on_pretrain_routine_end': on_pretrain_routine_end,
'on_fit_epoch_end': on_fit_epoch_end,
'on_train_end': on_train_end} if mlflow else {}
callbacks = (
{
"on_pretrain_routine_end": on_pretrain_routine_end,
"on_train_epoch_end": on_train_epoch_end,
"on_fit_epoch_end": on_fit_epoch_end,
"on_train_end": on_train_end,
}
if mlflow
else {}
)

View File

@ -4,11 +4,11 @@ from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING
try:
assert not TESTS_RUNNING # do not log pytest
assert SETTINGS['neptune'] is True # verify integration is enabled
assert SETTINGS["neptune"] is True # verify integration is enabled
import neptune
from neptune.types import File
assert hasattr(neptune, '__version__')
assert hasattr(neptune, "__version__")
run = None # NeptuneAI experiment logger instance
@ -23,55 +23,55 @@ def _log_scalars(scalars, step=0):
run[k].append(value=v, step=step)
def _log_images(imgs_dict, group=''):
def _log_images(imgs_dict, group=""):
"""Log scalars to the NeptuneAI experiment logger."""
if run:
for k, v in imgs_dict.items():
run[f'{group}/{k}'].upload(File(v))
run[f"{group}/{k}"].upload(File(v))
def _log_plot(title, plot_path):
"""Log plots to the NeptuneAI experiment logger."""
"""
Log image as plot in the plot section of NeptuneAI
Log plots to the NeptuneAI experiment logger.
arguments:
title (str) Title of the plot
plot_path (PosixPath or str) Path to the saved image file
"""
Args:
title (str): Title of the plot.
plot_path (PosixPath | str): Path to the saved image file.
"""
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
img = mpimg.imread(plot_path)
fig = plt.figure()
ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect='auto', xticks=[], yticks=[]) # no ticks
ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect="auto", xticks=[], yticks=[]) # no ticks
ax.imshow(img)
run[f'Plots/{title}'].upload(fig)
run[f"Plots/{title}"].upload(fig)
def on_pretrain_routine_start(trainer):
"""Callback function called before the training routine starts."""
try:
global run
run = neptune.init_run(project=trainer.args.project or 'YOLOv8', name=trainer.args.name, tags=['YOLOv8'])
run['Configuration/Hyperparameters'] = {k: '' if v is None else v for k, v in vars(trainer.args).items()}
run = neptune.init_run(project=trainer.args.project or "YOLOv8", name=trainer.args.name, tags=["YOLOv8"])
run["Configuration/Hyperparameters"] = {k: "" if v is None else v for k, v in vars(trainer.args).items()}
except Exception as e:
LOGGER.warning(f'WARNING ⚠️ NeptuneAI installed but not initialized correctly, not logging this run. {e}')
LOGGER.warning(f"WARNING ⚠️ NeptuneAI installed but not initialized correctly, not logging this run. {e}")
def on_train_epoch_end(trainer):
"""Callback function called at end of each training epoch."""
_log_scalars(trainer.label_loss_items(trainer.tloss, prefix='train'), trainer.epoch + 1)
_log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1)
_log_scalars(trainer.lr, trainer.epoch + 1)
if trainer.epoch == 1:
_log_images({f.stem: str(f) for f in trainer.save_dir.glob('train_batch*.jpg')}, 'Mosaic')
_log_images({f.stem: str(f) for f in trainer.save_dir.glob("train_batch*.jpg")}, "Mosaic")
def on_fit_epoch_end(trainer):
"""Callback function called at end of each fit (train+val) epoch."""
if run and trainer.epoch == 0:
from ultralytics.utils.torch_utils import model_info_for_loggers
run['Configuration/Model'] = model_info_for_loggers(trainer)
run["Configuration/Model"] = model_info_for_loggers(trainer)
_log_scalars(trainer.metrics, trainer.epoch + 1)
@ -79,7 +79,7 @@ def on_val_end(validator):
"""Callback function called at end of each validation."""
if run:
# Log val_labels and val_pred
_log_images({f.stem: str(f) for f in validator.save_dir.glob('val*.jpg')}, 'Validation')
_log_images({f.stem: str(f) for f in validator.save_dir.glob("val*.jpg")}, "Validation")
def on_train_end(trainer):
@ -87,19 +87,26 @@ def on_train_end(trainer):
if run:
# Log final results, CM matrix + PR plots
files = [
'results.png', 'confusion_matrix.png', 'confusion_matrix_normalized.png',
*(f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R'))]
"results.png",
"confusion_matrix.png",
"confusion_matrix_normalized.png",
*(f"{x}_curve.png" for x in ("F1", "PR", "P", "R")),
]
files = [(trainer.save_dir / f) for f in files if (trainer.save_dir / f).exists()] # filter
for f in files:
_log_plot(title=f.stem, plot_path=f)
# Log the final model
run[f'weights/{trainer.args.name or trainer.args.task}/{str(trainer.best.name)}'].upload(File(str(
trainer.best)))
run[f"weights/{trainer.args.name or trainer.args.task}/{trainer.best.name}"].upload(File(str(trainer.best)))
callbacks = {
'on_pretrain_routine_start': on_pretrain_routine_start,
'on_train_epoch_end': on_train_epoch_end,
'on_fit_epoch_end': on_fit_epoch_end,
'on_val_end': on_val_end,
'on_train_end': on_train_end} if neptune else {}
callbacks = (
{
"on_pretrain_routine_start": on_pretrain_routine_start,
"on_train_epoch_end": on_train_epoch_end,
"on_fit_epoch_end": on_fit_epoch_end,
"on_val_end": on_val_end,
"on_train_end": on_train_end,
}
if neptune
else {}
)

View File

@ -3,7 +3,7 @@
from ultralytics.utils import SETTINGS
try:
assert SETTINGS['raytune'] is True # verify integration is enabled
assert SETTINGS["raytune"] is True # verify integration is enabled
import ray
from ray import tune
from ray.air import session
@ -16,9 +16,14 @@ def on_fit_epoch_end(trainer):
"""Sends training metrics to Ray Tune at end of each epoch."""
if ray.tune.is_session_enabled():
metrics = trainer.metrics
metrics['epoch'] = trainer.epoch
metrics["epoch"] = trainer.epoch
session.report(metrics)
callbacks = {
'on_fit_epoch_end': on_fit_epoch_end, } if tune else {}
callbacks = (
{
"on_fit_epoch_end": on_fit_epoch_end,
}
if tune
else {}
)

View File

@ -1,17 +1,25 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import contextlib
from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, colorstr
try:
# WARNING: do not move import due to protobuf issue in https://github.com/ultralytics/ultralytics/pull/4674
# WARNING: do not move SummaryWriter import due to protobuf bug https://github.com/ultralytics/ultralytics/pull/4674
from torch.utils.tensorboard import SummaryWriter
assert not TESTS_RUNNING # do not log pytest
assert SETTINGS['tensorboard'] is True # verify integration is enabled
assert SETTINGS["tensorboard"] is True # verify integration is enabled
WRITER = None # TensorBoard SummaryWriter instance
PREFIX = colorstr("TensorBoard: ")
except (ImportError, AssertionError, TypeError):
# Imports below only required if TensorBoard enabled
import warnings
from copy import deepcopy
from ultralytics.utils.torch_utils import de_parallel, torch
except (ImportError, AssertionError, TypeError, AttributeError):
# TypeError for handling 'Descriptors cannot not be created directly.' protobuf errors in Windows
# AttributeError: module 'tensorflow' has no attribute 'io' if 'tensorflow' not installed
SummaryWriter = None
@ -24,20 +32,38 @@ def _log_scalars(scalars, step=0):
def _log_tensorboard_graph(trainer):
"""Log model graph to TensorBoard."""
try:
import warnings
from ultralytics.utils.torch_utils import de_parallel, torch
# Input image
imgsz = trainer.args.imgsz
imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz
p = next(trainer.model.parameters()) # for device, type
im = torch.zeros((1, 3, *imgsz), device=p.device, dtype=p.dtype) # input image (must be zeros, not empty)
imgsz = trainer.args.imgsz
imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz
p = next(trainer.model.parameters()) # for device, type
im = torch.zeros((1, 3, *imgsz), device=p.device, dtype=p.dtype) # input image (must be zeros, not empty)
with warnings.catch_warnings():
warnings.simplefilter('ignore', category=UserWarning) # suppress jit trace warning
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning) # suppress jit trace warning
warnings.simplefilter("ignore", category=torch.jit.TracerWarning) # suppress jit trace warning
# Try simple method first (YOLO)
with contextlib.suppress(Exception):
trainer.model.eval() # place in .eval() mode to avoid BatchNorm statistics changes
WRITER.add_graph(torch.jit.trace(de_parallel(trainer.model), im, strict=False), [])
except Exception as e:
LOGGER.warning(f'WARNING ⚠️ TensorBoard graph visualization failure {e}')
LOGGER.info(f"{PREFIX}model graph visualization added ✅")
return
# Fallback to TorchScript export steps (RTDETR)
try:
model = deepcopy(de_parallel(trainer.model))
model.eval()
model = model.fuse(verbose=False)
for m in model.modules():
if hasattr(m, "export"): # Detect, RTDETRDecoder (Segment and Pose use Detect base class)
m.export = True
m.format = "torchscript"
model(im) # dry run
WRITER.add_graph(torch.jit.trace(model, im, strict=False), [])
LOGGER.info(f"{PREFIX}model graph visualization added ✅")
except Exception as e:
LOGGER.warning(f"{PREFIX}WARNING ⚠️ TensorBoard graph visualization failure {e}")
def on_pretrain_routine_start(trainer):
@ -46,10 +72,9 @@ def on_pretrain_routine_start(trainer):
try:
global WRITER
WRITER = SummaryWriter(str(trainer.save_dir))
prefix = colorstr('TensorBoard: ')
LOGGER.info(f"{prefix}Start with 'tensorboard --logdir {trainer.save_dir}', view at http://localhost:6006/")
LOGGER.info(f"{PREFIX}Start with 'tensorboard --logdir {trainer.save_dir}', view at http://localhost:6006/")
except Exception as e:
LOGGER.warning(f'WARNING ⚠️ TensorBoard not initialized correctly, not logging this run. {e}')
LOGGER.warning(f"{PREFIX}WARNING ⚠️ TensorBoard not initialized correctly, not logging this run. {e}")
def on_train_start(trainer):
@ -58,9 +83,10 @@ def on_train_start(trainer):
_log_tensorboard_graph(trainer)
def on_batch_end(trainer):
"""Logs scalar statistics at the end of a training batch."""
_log_scalars(trainer.label_loss_items(trainer.tloss, prefix='train'), trainer.epoch + 1)
def on_train_epoch_end(trainer):
    """
    Forward end-of-epoch training scalars to `_log_scalars`.

    Two groups are logged, both at step `trainer.epoch + 1`: the labelled loss items built from
    `trainer.tloss` with the "train" prefix, and the contents of `trainer.lr`.

    Args:
        trainer: Trainer object exposing `epoch`, `tloss`, `lr` and `label_loss_items()`.
    """
    step = trainer.epoch + 1  # steps are 1-based while trainer.epoch is used as-is
    train_losses = trainer.label_loss_items(trainer.tloss, prefix="train")
    _log_scalars(train_losses, step)
    _log_scalars(trainer.lr, step)
def on_fit_epoch_end(trainer):
@ -68,8 +94,13 @@ def on_fit_epoch_end(trainer):
_log_scalars(trainer.metrics, trainer.epoch + 1)
callbacks = {
'on_pretrain_routine_start': on_pretrain_routine_start,
'on_train_start': on_train_start,
'on_fit_epoch_end': on_fit_epoch_end,
'on_batch_end': on_batch_end} if SummaryWriter else {}
callbacks = (
{
"on_pretrain_routine_start": on_pretrain_routine_start,
"on_train_start": on_train_start,
"on_fit_epoch_end": on_fit_epoch_end,
"on_train_epoch_end": on_train_epoch_end,
}
if SummaryWriter
else {}
)

View File

@ -5,10 +5,13 @@ from ultralytics.utils.torch_utils import model_info_for_loggers
try:
assert not TESTS_RUNNING # do not log pytest
assert SETTINGS['wandb'] is True # verify integration is enabled
assert SETTINGS["wandb"] is True # verify integration is enabled
import wandb as wb
assert hasattr(wb, '__version__')
assert hasattr(wb, "__version__") # verify package is not directory
import numpy as np
import pandas as pd
_processed_plots = {}
@ -16,9 +19,89 @@ except (ImportError, AssertionError):
wb = None
def _custom_table(x, y, classes, title="Precision Recall Curve", x_title="Recall", y_title="Precision"):
    """
    Build a wandb custom chart from labelled (x, y) points.

    The points are collected into a three-column table ("class", "y", "x"), rounded to 3 decimals, and bound to
    the "wandb/area-under-curve/v0" chart preset. This function only constructs the chart object; the caller is
    responsible for logging it.

    Args:
        x (List): Values for the x-axis; expected to have length N.
        y (List): Corresponding values for the y-axis; expected to have length N.
        classes (List): Label identifying the series each point belongs to; expected to have length N.
        title (str, optional): Chart title. Defaults to 'Precision Recall Curve'.
        x_title (str, optional): Label for the x-axis. Defaults to 'Recall'.
        y_title (str, optional): Label for the y-axis. Defaults to 'Precision'.

    Returns:
        The object returned by `wb.plot_table`, suitable for passing to a wandb log call.
    """
    frame = pd.DataFrame({"class": classes, "y": y, "x": x})
    table = wb.Table(dataframe=frame.round(3))
    return wb.plot_table(
        "wandb/area-under-curve/v0",
        table,
        fields={column: column for column in ("x", "y", "class")},  # chart field -> table column of the same name
        string_fields={"title": title, "x-axis-title": x_title, "y-axis-title": y_title},
    )
def _plot_curve(
    x,
    y,
    names=None,
    id="precision-recall",
    title="Precision Recall Curve",
    x_title="Recall",
    y_title="Precision",
    num_x=100,
    only_mean=False,
):
    """
    Log a metric curve visualization.

    The input curves are resampled onto `num_x` evenly spaced x values and logged to wandb. The mean curve over
    all rows of `y` is always computed; per-row curves are added unless `only_mean` is set.

    Args:
        x (np.ndarray): Data points for the x-axis with length N.
        y (np.ndarray): Corresponding data points for the y-axis with shape CxN, where C is the number of classes.
        names (list, optional): Names of the classes corresponding to the rows of `y`. Defaults to None, which is
            treated as an empty list; any row without a name is labelled with its row index.
        id (str, optional): Key under which the per-class chart is logged in wandb. Defaults to 'precision-recall'.
        title (str, optional): Title for the visualization plot. Defaults to 'Precision Recall Curve'.
        x_title (str, optional): Label for the x-axis. Defaults to 'Recall'.
        y_title (str, optional): Label for the y-axis. Defaults to 'Precision'.
        num_x (int, optional): Number of interpolated data points for visualization. Defaults to 100.
        only_mean (bool, optional): If True, log only the mean curve as a line plot keyed by `title`.
            Defaults to False.

    Note:
        When `only_mean` is False the chart is built by `_custom_table` and logged with `commit=False`.
    """
    if names is None:
        names = []

    # Evenly spaced x grid spanning the original x range
    x_new = np.linspace(x[0], x[-1], num_x).round(5)

    # The mean curve is either the only series (only_mean) or the first series in the table
    x_log = x_new.tolist()
    y_log = np.interp(x_new, x, np.mean(y, axis=0)).round(3).tolist()

    if only_mean:
        table = wb.Table(data=list(zip(x_log, y_log)), columns=[x_title, y_title])
        wb.run.log({title: wb.plot.line(table, x_title, y_title, title=title)})
    else:
        classes = ["mean"] * len(x_log)
        for i, yi in enumerate(y):
            try:
                label = names[i]
            except (IndexError, KeyError):
                label = str(i)  # no name supplied for this row (e.g. default names=None)
            x_log.extend(x_new)  # add new x
            y_log.extend(np.interp(x_new, x, yi))  # interpolate y to new x
            classes.extend([label] * len(x_new))  # add class names
        wb.log({id: _custom_table(x_log, y_log, classes, title, x_title, y_title)}, commit=False)
def _log_plots(plots, step):
"""Logs plots from the input dictionary if they haven't been logged already at the specified step."""
for name, params in plots.items():
timestamp = params['timestamp']
timestamp = params["timestamp"]
if _processed_plots.get(name) != timestamp:
wb.run.log({name.stem: wb.Image(str(name))}, step=step)
_processed_plots[name] = timestamp
@ -26,7 +109,7 @@ def _log_plots(plots, step):
def on_pretrain_routine_start(trainer):
"""Initiate and start project if module is present."""
wb.run or wb.init(project=trainer.args.project or 'YOLOv8', name=trainer.args.name, config=vars(trainer.args))
wb.run or wb.init(project=trainer.args.project or "YOLOv8", name=trainer.args.name, config=vars(trainer.args))
def on_fit_epoch_end(trainer):
@ -40,7 +123,7 @@ def on_fit_epoch_end(trainer):
def on_train_epoch_end(trainer):
"""Log metrics and save images at the end of each training epoch."""
wb.run.log(trainer.label_loss_items(trainer.tloss, prefix='train'), step=trainer.epoch + 1)
wb.run.log(trainer.label_loss_items(trainer.tloss, prefix="train"), step=trainer.epoch + 1)
wb.run.log(trainer.lr, step=trainer.epoch + 1)
if trainer.epoch == 1:
_log_plots(trainer.plots, step=trainer.epoch + 1)
@ -50,14 +133,31 @@ def on_train_end(trainer):
"""Save the best model as an artifact at end of training."""
_log_plots(trainer.validator.plots, step=trainer.epoch + 1)
_log_plots(trainer.plots, step=trainer.epoch + 1)
art = wb.Artifact(type='model', name=f'run_{wb.run.id}_model')
art = wb.Artifact(type="model", name=f"run_{wb.run.id}_model")
if trainer.best.exists():
art.add_file(trainer.best)
wb.run.log_artifact(art, aliases=['best'])
wb.run.log_artifact(art, aliases=["best"])
for curve_name, curve_values in zip(trainer.validator.metrics.curves, trainer.validator.metrics.curves_results):
x, y, x_title, y_title = curve_values
_plot_curve(
x,
y,
names=list(trainer.validator.metrics.names.values()),
id=f"curves/{curve_name}",
title=curve_name,
x_title=x_title,
y_title=y_title,
)
wb.run.finish() # required or run continues on dashboard
callbacks = {
'on_pretrain_routine_start': on_pretrain_routine_start,
'on_train_epoch_end': on_train_epoch_end,
'on_fit_epoch_end': on_fit_epoch_end,
'on_train_end': on_train_end} if wb else {}
callbacks = (
{
"on_pretrain_routine_start": on_pretrain_routine_start,
"on_train_epoch_end": on_train_epoch_end,
"on_fit_epoch_end": on_fit_epoch_end,
"on_train_end": on_train_end,
}
if wb
else {}
)

View File

@ -1,4 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import contextlib
import glob
import inspect
@ -9,20 +10,96 @@ import re
import shutil
import subprocess
import time
from importlib import metadata
from pathlib import Path
from typing import Optional
import cv2
import numpy as np
import pkg_resources as pkg
import psutil
import requests
import torch
from matplotlib import font_manager
from ultralytics.utils import (ASSETS, AUTOINSTALL, LINUX, LOGGER, ONLINE, ROOT, USER_CONFIG_DIR, ThreadingLocked,
TryExcept, clean_url, colorstr, downloads, emojis, is_colab, is_docker, is_jupyter,
is_kaggle, is_online, is_pip_package, url2file)
from ultralytics.utils import (
ASSETS,
AUTOINSTALL,
LINUX,
LOGGER,
ONLINE,
ROOT,
USER_CONFIG_DIR,
SimpleNamespace,
ThreadingLocked,
TryExcept,
clean_url,
colorstr,
downloads,
emojis,
is_colab,
is_docker,
is_github_action_running,
is_jupyter,
is_kaggle,
is_online,
is_pip_package,
url2file,
)
PYTHON_VERSION = platform.python_version()
def parse_requirements(file_path=ROOT.parent / "requirements.txt", package=""):
    """
    Parse a requirements.txt file, ignoring lines that start with '#' and any text after '#'.

    Args:
        file_path (Path): Path to the requirements.txt file. Ignored when `package` is given.
        package (str, optional): Installed Python package whose declared requirements are used instead of the
            requirements.txt file, i.e. package='ultralytics'. Requirements tied to an extra are skipped.

    Returns:
        (List[SimpleNamespace]): Parsed requirements, each with a `name` and a `specifier` attribute. `specifier`
            is an empty string when the line carries no version constraint.

    Example:
        ```python
        from ultralytics.utils.checks import parse_requirements

        parse_requirements(package='ultralytics')
        ```
    """
    if package:
        # Distribution.requires is None (not an empty list) when the package declares no dependencies
        declared = metadata.distribution(package).requires or []
        requires = [x for x in declared if "extra == " not in x]
    else:
        requires = Path(file_path).read_text().splitlines()

    requirements = []
    for line in requires:
        line = line.strip()
        if line and not line.startswith("#"):
            line = line.split("#")[0].strip()  # ignore inline comments
            match = re.match(r"([a-zA-Z0-9-_]+)\s*([<>!=~]+.*)?", line)
            if match:
                requirements.append(SimpleNamespace(name=match[1], specifier=match[2].strip() if match[2] else ""))

    return requirements
def parse_version(version="0.0.0") -> tuple:
    """
    Extract up to the first three integer components of a version string as a tuple.

    Every run of digits in the string is treated as a component, so suffixes such as '+cpu' are ignored. This
    function replaces deprecated 'pkg_resources.parse_version(v)'.

    Args:
        version (str): Version string, i.e. '2.0.1+cpu'

    Returns:
        (tuple): Tuple of integers, i.e. (2, 0, 1). It has fewer than three items when the string contains fewer
            than three digit runs. (0, 0, 0) is returned, with a warning, if parsing raises.
    """
    try:
        digit_runs = re.findall(r"\d+", version)
        return tuple(int(part) for part in digit_runs[:3])  # '2.0.1+cpu' -> (2, 0, 1)
    except Exception as e:
        LOGGER.warning(f"WARNING ⚠️ failure for parse_version({version}), returning (0, 0, 0): {e}")
        return 0, 0, 0
def is_ascii(s) -> bool:
@ -33,7 +110,7 @@ def is_ascii(s) -> bool:
s (str): String to be checked.
Returns:
bool: True if the string is composed only of ASCII characters, False otherwise.
(bool): True if the string is composed only of ASCII characters, False otherwise.
"""
# Convert list, tuple, None, etc. to string
s = str(s)
@ -65,16 +142,22 @@ def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0):
imgsz = [imgsz]
elif isinstance(imgsz, (list, tuple)):
imgsz = list(imgsz)
elif isinstance(imgsz, str): # i.e. '640' or '[640,640]'
imgsz = [int(imgsz)] if imgsz.isnumeric() else eval(imgsz)
else:
raise TypeError(f"'imgsz={imgsz}' is of invalid type {type(imgsz).__name__}. "
f"Valid imgsz types are int i.e. 'imgsz=640' or list i.e. 'imgsz=[640,640]'")
raise TypeError(
f"'imgsz={imgsz}' is of invalid type {type(imgsz).__name__}. "
f"Valid imgsz types are int i.e. 'imgsz=640' or list i.e. 'imgsz=[640,640]'"
)
# Apply max_dim
if len(imgsz) > max_dim:
msg = "'train' and 'val' imgsz must be an integer, while 'predict' and 'export' imgsz may be a [h, w] list " \
"or an integer, i.e. 'yolo export imgsz=640,480' or 'yolo export imgsz=640'"
msg = (
"'train' and 'val' imgsz must be an integer, while 'predict' and 'export' imgsz may be a [h, w] list "
"or an integer, i.e. 'yolo export imgsz=640,480' or 'yolo export imgsz=640'"
)
if max_dim != 1:
raise ValueError(f'imgsz={imgsz} is not a valid image size. {msg}')
raise ValueError(f"imgsz={imgsz} is not a valid image size. {msg}")
LOGGER.warning(f"WARNING ⚠️ updating to 'imgsz={max(imgsz)}'. {msg}")
imgsz = [max(imgsz)]
# Make image size a multiple of the stride
@ -82,7 +165,7 @@ def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0):
# Print warning message if image size was updated
if sz != imgsz:
LOGGER.warning(f'WARNING ⚠️ imgsz={imgsz} must be multiple of max stride {stride}, updating to {sz}')
LOGGER.warning(f"WARNING ⚠️ imgsz={imgsz} must be multiple of max stride {stride}, updating to {sz}")
# Add missing dimensions if necessary
sz = [sz[0], sz[0]] if min_dim == 2 and len(sz) == 1 else sz[0] if min_dim == 1 and len(sz) == 1 else sz
@ -90,66 +173,88 @@ def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0):
return sz
def check_version(current: str = '0.0.0',
required: str = '0.0.0',
name: str = 'version ',
hard: bool = False,
verbose: bool = False) -> bool:
def check_version(
current: str = "0.0.0",
required: str = "0.0.0",
name: str = "version",
hard: bool = False,
verbose: bool = False,
msg: str = "",
) -> bool:
"""
Check current version against the required version or range.
Args:
current (str): Current version.
current (str): Current version or package name to get version from.
required (str): Required version or range (in pip-style format).
name (str): Name to be used in warning message.
hard (bool): If True, raise an AssertionError if the requirement is not met.
verbose (bool): If True, print warning message if requirement is not met.
name (str, optional): Name to be used in warning message.
hard (bool, optional): If True, raise an AssertionError if the requirement is not met.
verbose (bool, optional): If True, print warning message if requirement is not met.
msg (str, optional): Extra message to display if verbose.
Returns:
(bool): True if requirement is met, False otherwise.
Example:
# check if current version is exactly 22.04
```python
# Check if current version is exactly 22.04
check_version(current='22.04', required='==22.04')
# check if current version is greater than or equal to 22.04
# Check if current version is greater than or equal to 22.04
check_version(current='22.10', required='22.04') # assumes '>=' inequality if none passed
# check if current version is less than or equal to 22.04
# Check if current version is less than or equal to 22.04
check_version(current='22.04', required='<=22.04')
# check if current version is between 20.04 (inclusive) and 22.04 (exclusive)
# Check if current version is between 20.04 (inclusive) and 22.04 (exclusive)
check_version(current='21.10', required='>20.04,<22.04')
```
"""
current = pkg.parse_version(current)
constraints = re.findall(r'([<>!=]{1,2}\s*\d+\.\d+)', required) or [f'>={required}']
if not current: # if current is '' or None
LOGGER.warning(f"WARNING ⚠️ invalid check_version({current}, {required}) requested, please check values.")
return True
elif not current[0].isdigit(): # current is package name rather than version string, i.e. current='ultralytics'
try:
name = current # assigned package name to 'name' arg
current = metadata.version(current) # get version string from package name
except metadata.PackageNotFoundError as e:
if hard:
raise ModuleNotFoundError(emojis(f"WARNING ⚠️ {current} package is required but not installed")) from e
else:
return False
if not required: # if required is '' or None
return True
op = ""
version = ""
result = True
for constraint in constraints:
op, version = re.match(r'([<>!=]{1,2})\s*(\d+\.\d+)', constraint).groups()
version = pkg.parse_version(version)
if op == '==' and current != version:
c = parse_version(current) # '1.2.3' -> (1, 2, 3)
for r in required.strip(",").split(","):
op, version = re.match(r"([^0-9]*)([\d.]+)", r).groups() # split '>=22.04' -> ('>=', '22.04')
v = parse_version(version) # '1.2.3' -> (1, 2, 3)
if op == "==" and c != v:
result = False
elif op == '!=' and current == version:
elif op == "!=" and c == v:
result = False
elif op == '>=' and not (current >= version):
elif op in (">=", "") and not (c >= v): # if no constraint passed assume '>=required'
result = False
elif op == '<=' and not (current <= version):
elif op == "<=" and not (c <= v):
result = False
elif op == '>' and not (current > version):
elif op == ">" and not (c > v):
result = False
elif op == '<' and not (current < version):
elif op == "<" and not (c < v):
result = False
if not result:
warning_message = f'WARNING ⚠️ {name}{required} is required, but {name}{current} is currently installed'
warning = f"WARNING ⚠️ {name}{op}{version} is required, but {name}=={current} is currently installed {msg}"
if hard:
raise ModuleNotFoundError(emojis(warning_message)) # assert version requirements met
raise ModuleNotFoundError(emojis(warning)) # assert version requirements met
if verbose:
LOGGER.warning(warning_message)
LOGGER.warning(warning)
return result
def check_latest_pypi_version(package_name='ultralytics'):
def check_latest_pypi_version(package_name="ultralytics"):
"""
Returns the latest version of a PyPI package without downloading or installing it.
@ -161,9 +266,9 @@ def check_latest_pypi_version(package_name='ultralytics'):
"""
with contextlib.suppress(Exception):
requests.packages.urllib3.disable_warnings() # Disable the InsecureRequestWarning
response = requests.get(f'https://pypi.org/pypi/{package_name}/json', timeout=3)
response = requests.get(f"https://pypi.org/pypi/{package_name}/json", timeout=3)
if response.status_code == 200:
return response.json()['info']['version']
return response.json()["info"]["version"]
def check_pip_update_available():
@ -176,16 +281,19 @@ def check_pip_update_available():
if ONLINE and is_pip_package():
with contextlib.suppress(Exception):
from ultralytics import __version__
latest = check_latest_pypi_version()
if pkg.parse_version(__version__) < pkg.parse_version(latest): # update is available
LOGGER.info(f'New https://pypi.org/project/ultralytics/{latest} available 😃 '
f"Update with 'pip install -U ultralytics'")
if check_version(__version__, f"<{latest}"): # check if current version is < latest version
LOGGER.info(
f"New https://pypi.org/project/ultralytics/{latest} available 😃 "
f"Update with 'pip install -U ultralytics'"
)
return True
return False
@ThreadingLocked()
def check_font(font='Arial.ttf'):
def check_font(font="Arial.ttf"):
"""
Find font locally or download to user's configuration directory if it does not already exist.
@ -208,13 +316,13 @@ def check_font(font='Arial.ttf'):
return matches[0]
# Download to USER_CONFIG_DIR if missing
url = f'https://ultralytics.com/assets/{name}'
if downloads.is_url(url):
url = f"https://ultralytics.com/assets/{name}"
if downloads.is_url(url, check=True):
downloads.safe_download(url=url, file=file)
return file
def check_python(minimum: str = '3.8.0') -> bool:
def check_python(minimum: str = "3.8.0") -> bool:
"""
Check current python version against the required minimum version.
@ -222,13 +330,13 @@ def check_python(minimum: str = '3.8.0') -> bool:
minimum (str): Required minimum version of python.
Returns:
None
(bool): Whether the installed Python version meets the minimum constraints.
"""
return check_version(platform.python_version(), minimum, name='Python ', hard=True)
return check_version(PYTHON_VERSION, minimum, name="Python ", hard=True)
@TryExcept()
def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=(), install=True, cmds=''):
def check_requirements(requirements=ROOT.parent / "requirements.txt", exclude=(), install=True, cmds=""):
"""
Check if installed dependencies meet YOLOv8 requirements and attempt to auto-update if needed.
@ -253,46 +361,43 @@ def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=()
check_requirements(['numpy', 'ultralytics>=8.0.0'])
```
"""
prefix = colorstr('red', 'bold', 'requirements:')
prefix = colorstr("red", "bold", "requirements:")
check_python() # check python version
check_torchvision() # check torch-torchvision compatibility
if isinstance(requirements, Path): # requirements.txt file
file = requirements.resolve()
assert file.exists(), f'{prefix} {file} not found, check failed.'
with file.open() as f:
requirements = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements(f) if x.name not in exclude]
assert file.exists(), f"{prefix} {file} not found, check failed."
requirements = [f"{x.name}{x.specifier}" for x in parse_requirements(file) if x.name not in exclude]
elif isinstance(requirements, str):
requirements = [requirements]
pkgs = []
for r in requirements:
r_stripped = r.split('/')[-1].replace('.git', '') # replace git+https://org/repo.git -> 'repo'
r_stripped = r.split("/")[-1].replace(".git", "") # replace git+https://org/repo.git -> 'repo'
match = re.match(r"([a-zA-Z0-9-_]+)([<>!=~]+.*)?", r_stripped)
name, required = match[1], match[2].strip() if match[2] else ""
try:
pkg.require(r_stripped) # exception if requirements not met
except pkg.DistributionNotFound:
try: # attempt to import (slower but more accurate)
import importlib
importlib.import_module(next(pkg.parse_requirements(r_stripped)).name)
except ImportError:
pkgs.append(r)
except pkg.VersionConflict:
assert check_version(metadata.version(name), required) # exception if requirements not met
except (AssertionError, metadata.PackageNotFoundError):
pkgs.append(r)
s = ' '.join(f'"{x}"' for x in pkgs) # console string
s = " ".join(f'"{x}"' for x in pkgs) # console string
if s:
if install and AUTOINSTALL: # check environment variable
n = len(pkgs) # number of packages updates
LOGGER.info(f"{prefix} Ultralytics requirement{'s' * (n > 1)} {pkgs} not found, attempting AutoUpdate...")
try:
t = time.time()
assert is_online(), 'AutoUpdate skipped (offline)'
LOGGER.info(subprocess.check_output(f'pip install --no-cache {s} {cmds}', shell=True).decode())
assert is_online(), "AutoUpdate skipped (offline)"
LOGGER.info(subprocess.check_output(f"pip install --no-cache {s} {cmds}", shell=True).decode())
dt = time.time() - t
LOGGER.info(
f"{prefix} AutoUpdate success ✅ {dt:.1f}s, installed {n} package{'s' * (n > 1)}: {pkgs}\n"
f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n")
f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
)
except Exception as e:
LOGGER.warning(f'{prefix}{e}')
LOGGER.warning(f"{prefix}{e}")
return False
else:
return False
@ -305,134 +410,211 @@ def check_torchvision():
Checks the installed versions of PyTorch and Torchvision to ensure they're compatible.
This function checks the installed versions of PyTorch and Torchvision, and warns if they're incompatible according
to the provided compatibility table based on https://github.com/pytorch/vision#installation. The
compatibility table is a dictionary where the keys are PyTorch versions and the values are lists of compatible
to the provided compatibility table based on:
https://github.com/pytorch/vision#installation.
The compatibility table is a dictionary where the keys are PyTorch versions and the values are lists of compatible
Torchvision versions.
"""
import torchvision
# Compatibility table
compatibility_table = {'2.0': ['0.15'], '1.13': ['0.14'], '1.12': ['0.13']}
compatibility_table = {"2.0": ["0.15"], "1.13": ["0.14"], "1.12": ["0.13"]}
# Extract only the major and minor versions
v_torch = '.'.join(torch.__version__.split('+')[0].split('.')[:2])
v_torchvision = '.'.join(torchvision.__version__.split('+')[0].split('.')[:2])
v_torch = ".".join(torch.__version__.split("+")[0].split(".")[:2])
v_torchvision = ".".join(torchvision.__version__.split("+")[0].split(".")[:2])
if v_torch in compatibility_table:
compatible_versions = compatibility_table[v_torch]
if all(pkg.parse_version(v_torchvision) != pkg.parse_version(v) for v in compatible_versions):
print(f'WARNING ⚠️ torchvision=={v_torchvision} is incompatible with torch=={v_torch}.\n'
f"Run 'pip install torchvision=={compatible_versions[0]}' to fix torchvision or "
"'pip install -U torch torchvision' to update both.\n"
'For a full compatibility table see https://github.com/pytorch/vision#installation')
if all(v_torchvision != v for v in compatible_versions):
print(
f"WARNING ⚠️ torchvision=={v_torchvision} is incompatible with torch=={v_torch}.\n"
f"Run 'pip install torchvision=={compatible_versions[0]}' to fix torchvision or "
"'pip install -U torch torchvision' to update both.\n"
"For a full compatibility table see https://github.com/pytorch/vision#installation"
)
def check_suffix(file="yolov8n.pt", suffix=".pt", msg=""):
    """
    Check that one or more file names carry an acceptable suffix.

    Files without any suffix are accepted, and nothing is checked when either `file` or `suffix` is empty.

    Args:
        file (str | Path | list | tuple): File name(s) to validate.
        suffix (str | list | tuple): Acceptable suffix(es), including the leading dot.
        msg (str): Text prepended to the error message.

    Raises:
        AssertionError: If a file has a suffix that is not one of the acceptable suffixes.
    """
    if not (file and suffix):
        return
    if isinstance(suffix, str):
        suffix = (suffix,)
    # The file suffix is lowercased below, so the allowed suffixes are lowercased too for a fair comparison
    allowed = {x.lower() for x in suffix}
    for f in file if isinstance(file, (list, tuple)) else [file]:
        s = Path(f).suffix.lower().strip()  # file suffix
        if s and s not in allowed:
            # Raised explicitly rather than with `assert` so the check is not stripped under `python -O`
            raise AssertionError(f"{msg}{f} acceptable suffix is {suffix}, not {s}")
def check_yolov5u_filename(file: str, verbose: bool = True):
    """
    Map legacy YOLOv3/YOLOv5 file names onto their updated 'u' counterparts.

    Args:
        file (str): Model file name or path.
        verbose (bool): Whether to log a tip when a '.pt' name was rewritten.

    Returns:
        (str): The possibly rewritten file name; names without 'yolov3'/'yolov5' are returned unchanged.
    """
    if "yolov3" not in file and "yolov5" not in file:
        return file

    if "u.yaml" in file:
        return file.replace("u.yaml", ".yaml")  # i.e. yolov5nu.yaml -> yolov5n.yaml

    if ".pt" not in file or "u" in file:
        return file

    legacy_patterns = (
        r"(.*yolov5([nsmlx]))\.pt",  # i.e. yolov5n.pt -> yolov5nu.pt
        r"(.*yolov5([nsmlx])6)\.pt",  # i.e. yolov5n6.pt -> yolov5n6u.pt
        r"(.*yolov3(|-tiny|-spp))\.pt",  # i.e. yolov3-spp.pt -> yolov3-sppu.pt
    )
    renamed = file
    for pattern in legacy_patterns:
        renamed = re.sub(pattern, "\\1u.pt", renamed)

    if verbose and renamed != file:
        LOGGER.info(
            f"PRO TIP 💡 Replace 'model={file}' with new 'model={renamed}'.\nYOLOv5 'u' models are "
            f"trained with https://github.com/ultralytics/ultralytics and feature improved performance vs "
            f"standard YOLOv5 models trained with https://github.com/ultralytics/yolov5.\n"
        )
    return renamed
def check_file(file, suffix='', download=True, hard=True):
def check_model_file_from_stem(model="yolov8n"):
    """
    Turn a known model stem into a '.pt' file name.

    Returns `Path(model).with_suffix(".pt")` when `model` has no suffix and its stem is listed in
    `downloads.GITHUB_ASSETS_STEMS` (i.e. yolov8n -> yolov8n.pt); otherwise `model` is returned unchanged.
    """
    if not model:
        return model
    candidate = Path(model)
    if candidate.suffix or candidate.stem not in downloads.GITHUB_ASSETS_STEMS:
        return model
    return candidate.with_suffix(".pt")  # add suffix
def check_file(file, suffix="", download=True, hard=True):
    """
    Search for or download a file (if necessary) and return its path.

    Args:
        file (str | Path): File name, path or URL.
        suffix (str | tuple): Acceptable suffix(es) passed to check_suffix(); an empty value skips that check.
        download (bool): Whether http(s)/rtsp/rtmp/tcp URLs are fetched when no local copy exists.
        hard (bool): Whether a missing or ambiguous search result raises an error.

    Returns:
        (str): The input itself when it is empty, an existing local path or a 'grpc://' address; the local file
            name for a URL; otherwise the single search match, or the unchanged input when nothing was found and
            `hard` is False.

    Raises:
        FileNotFoundError: If `hard` is True and the search finds no match or more than one match.
    """
    check_suffix(file, suffix)  # optional
    file = str(file).strip()  # convert to string and strip spaces
    file = check_yolov5u_filename(file)  # yolov5n -> yolov5nu
    if (
        not file
        or ("://" not in file and Path(file).exists())  # '://' check required in Windows Python<3.10
        or file.lower().startswith("grpc://")
    ):  # file exists or gRPC Triton images
        return file

    if download and file.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")):  # download
        url = file  # warning: Pathlib turns :// -> :/
        file = url2file(file)  # '%2F' to '/', split https://url.com/file.txt?auth
        if Path(file).exists():
            LOGGER.info(f"Found {clean_url(url)} locally at {file}")  # file already exists
        else:
            downloads.safe_download(url=url, file=file, unzip=False)
        return file

    # Search: anywhere below ROOT first, then directly in ROOT's parent directory
    files = glob.glob(str(ROOT / "**" / file), recursive=True) or glob.glob(str(ROOT.parent / file))
    if not files and hard:
        raise FileNotFoundError(f"'{file}' does not exist")
    if len(files) > 1 and hard:
        raise FileNotFoundError(f"Multiple files match '{file}', specify exact path: {files}")
    # An empty result can only reach this point when hard=False, in which case the input is handed back unchanged
    return files[0] if files else file
def check_yaml(file, suffix=(".yaml", ".yml"), hard=True):
    """Locate (or download) a YAML file through check_file(), restricted to the given suffixes, and return its path."""
    yaml_path = check_file(file, suffix, hard=hard)
    return yaml_path
def check_is_path_safe(basedir, path):
    """
    Tell whether a path resolves to an existing file located under a given directory (path traversal guard).

    Args:
        basedir (Path | str): The intended directory.
        path (Path | str): The path to check.

    Returns:
        (bool): True if the resolved path is a file whose leading components equal the resolved `basedir`.
    """
    root_parts = Path(basedir).resolve().parts
    target = Path(path).resolve()
    if not target.is_file():
        return False
    return target.parts[: len(root_parts)] == root_parts
def check_imshow(warn=False):
    """
    Probe whether the environment can display images with cv2.imshow().

    Args:
        warn (bool): Whether to log a warning describing the failure.

    Returns:
        (bool): True if a small test window could be shown and closed, False if any step raised.
    """
    try:
        if LINUX:
            assert not ("DISPLAY" not in os.environ or is_docker() or is_colab() or is_kaggle())
        probe = np.zeros((8, 8, 3), dtype=np.uint8)  # a small 8-pixel image
        cv2.imshow("test", probe)
        cv2.waitKey(1)
        cv2.destroyAllWindows()
        cv2.waitKey(1)
    except Exception as e:
        if warn:
            LOGGER.warning(f"WARNING ⚠️ Environment does not support cv2.imshow() or PIL Image.show()\n{e}")
        return False
    return True
def check_yolo(verbose=True, device=""):
    """
    Log a human-readable YOLO software and hardware summary.

    Args:
        verbose (bool): Whether to include CPU count, RAM and disk usage in the summary.
        device (str): Device specification forwarded to select_device().
    """
    import psutil

    from ultralytics.utils.torch_utils import select_device

    if is_jupyter():
        wandb_installed = check_requirements("wandb", install=False)
        if wandb_installed:
            os.system("pip uninstall -y wandb")  # uninstall wandb: unwanted account creation prompt with infinite hang
        if is_colab():
            shutil.rmtree("sample_data", ignore_errors=True)  # remove colab /sample_data directory

    s = ""
    if verbose:
        # System info
        gib = 1 << 30  # bytes per GiB
        ram = psutil.virtual_memory().total
        total, _, free = shutil.disk_usage("/")
        s = f"({os.cpu_count()} CPUs, {ram / gib:.1f} GB RAM, {(total - free) / gib:.1f}/{total / gib:.1f} GB disk)"
        with contextlib.suppress(Exception):  # clear display if ipython is installed
            from IPython import display

            display.clear_output()

    select_device(device=device, newline=False)
    LOGGER.info(f"Setup complete ✅ {s}")
def collect_system_info():
    """
    Collect and log relevant system information including OS, Python, RAM, CPU, and CUDA.

    After the check_yolo() summary, one line is logged per requirement returned by
    parse_requirements(package="ultralytics") with a status marker, the installed version and the required
    specifier. When running inside GitHub Actions, the runner/workflow environment variables are logged as well.
    """
    import psutil

    from ultralytics.utils import ENVIRONMENT, is_git_dir
    from ultralytics.utils.torch_utils import get_cpu_info

    ram_info = psutil.virtual_memory().total / (1024**3)  # Convert bytes to GB
    check_yolo()
    LOGGER.info(
        f"\n{'OS':<20}{platform.platform()}\n"
        f"{'Environment':<20}{ENVIRONMENT}\n"
        f"{'Python':<20}{PYTHON_VERSION}\n"
        f"{'Install':<20}{'git' if is_git_dir() else 'pip' if is_pip_package() else 'other'}\n"
        f"{'RAM':<20}{ram_info:.2f} GB\n"
        f"{'CPU':<20}{get_cpu_info()}\n"
        f"{'CUDA':<20}{torch.version.cuda if torch and torch.cuda.is_available() else None}\n"
    )

    for r in parse_requirements(package="ultralytics"):
        # NOTE(review): both status markers were empty strings here, which made the conditional a no-op;
        # they are restored as pass/fail markers - confirm the intended glyphs.
        try:
            current = metadata.version(r.name)
            is_met = "✅ " if check_version(current, str(r.specifier), hard=True) else "❌ "
        except metadata.PackageNotFoundError:
            current = "(not installed)"
            is_met = "❌ "
        LOGGER.info(f"{r.name:<20}{is_met}{current}{r.specifier}")

    if is_github_action_running():
        LOGGER.info(
            f"\nRUNNER_OS: {os.getenv('RUNNER_OS')}\n"
            f"GITHUB_EVENT_NAME: {os.getenv('GITHUB_EVENT_NAME')}\n"
            f"GITHUB_WORKFLOW: {os.getenv('GITHUB_WORKFLOW')}\n"
            f"GITHUB_ACTOR: {os.getenv('GITHUB_ACTOR')}\n"
            f"GITHUB_REPOSITORY: {os.getenv('GITHUB_REPOSITORY')}\n"
            f"GITHUB_REPOSITORY_OWNER: {os.getenv('GITHUB_REPOSITORY_OWNER')}\n"
        )
def check_amp(model):
"""
This function checks the PyTorch Automatic Mixed Precision (AMP) functionality of a YOLOv8 model.
If the checks fail, it means there are anomalies with AMP on the system that may cause NaN losses or zero-mAP
results, so AMP will be disabled during training.
This function checks the PyTorch Automatic Mixed Precision (AMP) functionality of a YOLOv8 model. If the checks
fail, it means there are anomalies with AMP on the system that may cause NaN losses or zero-mAP results, so AMP will
be disabled during training.
Args:
model (nn.Module): A YOLOv8 model instance.
@ -450,7 +632,7 @@ def check_amp(model):
(bool): Returns True if the AMP functionality works correctly with YOLOv8 model, else False.
"""
device = next(model.parameters()).device # get model device
if device.type in ('cpu', 'mps'):
if device.type in ("cpu", "mps"):
return False # AMP only used on CUDA devices
def amp_allclose(m, im):
@ -461,23 +643,27 @@ def check_amp(model):
del m
return a.shape == b.shape and torch.allclose(a, b.float(), atol=0.5) # close to 0.5 absolute tolerance
im = ASSETS / 'bus.jpg' # image to check
prefix = colorstr('AMP: ')
LOGGER.info(f'{prefix}running Automatic Mixed Precision (AMP) checks with YOLOv8n...')
im = ASSETS / "bus.jpg" # image to check
prefix = colorstr("AMP: ")
LOGGER.info(f"{prefix}running Automatic Mixed Precision (AMP) checks with YOLOv8n...")
warning_msg = "Setting 'amp=True'. If you experience zero-mAP or NaN losses you can disable AMP with amp=False."
try:
from ultralytics import YOLO
assert amp_allclose(YOLO('yolov8n.pt'), im)
LOGGER.info(f'{prefix}checks passed ✅')
assert amp_allclose(YOLO("yolov8n.pt"), im)
LOGGER.info(f"{prefix}checks passed ✅")
except ConnectionError:
LOGGER.warning(f'{prefix}checks skipped ⚠️, offline and unable to download YOLOv8n. {warning_msg}')
LOGGER.warning(f"{prefix}checks skipped ⚠️, offline and unable to download YOLOv8n. {warning_msg}")
except (AttributeError, ModuleNotFoundError):
LOGGER.warning(
f'{prefix}checks skipped ⚠️. Unable to load YOLOv8n due to possible Ultralytics package modifications. {warning_msg}'
f"{prefix}checks skipped ⚠️. "
f"Unable to load YOLOv8n due to possible Ultralytics package modifications. {warning_msg}"
)
except AssertionError:
LOGGER.warning(f'{prefix}checks failed ❌. Anomalies were detected with AMP on your system that may lead to '
f'NaN losses or zero-mAP results, so AMP will be disabled during training.')
LOGGER.warning(
f"{prefix}checks failed ❌. Anomalies were detected with AMP on your system that may lead to "
f"NaN losses or zero-mAP results, so AMP will be disabled during training."
)
return False
return True
@ -485,8 +671,8 @@ def check_amp(model):
def git_describe(path=ROOT):  # path must be a directory
    """
    Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe.

    Args:
        path (str | Path): Directory passed to `git -C`.

    Returns:
        (str): Output of `git describe --tags --long --always` without its final character (the trailing newline),
            or an empty string if the command fails for any reason.
    """
    with contextlib.suppress(Exception):
        # Argument list without a shell: a path containing spaces or shell metacharacters is passed verbatim
        cmd = ["git", "-C", str(path), "describe", "--tags", "--long", "--always"]
        return subprocess.check_output(cmd).decode()[:-1]
    return ""
def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
@ -494,7 +680,7 @@ def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
def strip_auth(v):
"""Clean longer Ultralytics HUB URLs by stripping potential authentication information."""
return clean_url(v) if (isinstance(v, str) and v.startswith('http') and len(v) > 100) else v
return clean_url(v) if (isinstance(v, str) and v.startswith("http") and len(v) > 100) else v
x = inspect.currentframe().f_back # previous frame
file, _, func, _, _ = inspect.getframeinfo(x)
@ -502,26 +688,28 @@ def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
args, _, _, frm = inspect.getargvalues(x)
args = {k: v for k, v in frm.items() if k in args}
try:
file = Path(file).resolve().relative_to(ROOT).with_suffix('')
file = Path(file).resolve().relative_to(ROOT).with_suffix("")
except ValueError:
file = Path(file).stem
s = (f'{file}: ' if show_file else '') + (f'{func}: ' if show_func else '')
LOGGER.info(colorstr(s) + ', '.join(f'{k}={strip_auth(v)}' for k, v in args.items()))
s = (f"{file}: " if show_file else "") + (f"{func}: " if show_func else "")
LOGGER.info(colorstr(s) + ", ".join(f"{k}={strip_auth(v)}" for k, v in args.items()))
def cuda_device_count() -> int:
"""Get the number of NVIDIA GPUs available in the environment.
"""
Get the number of NVIDIA GPUs available in the environment.
Returns:
(int): The number of NVIDIA GPUs available.
"""
try:
# Run the nvidia-smi command and capture its output
output = subprocess.check_output(['nvidia-smi', '--query-gpu=count', '--format=csv,noheader,nounits'],
encoding='utf-8')
output = subprocess.check_output(
["nvidia-smi", "--query-gpu=count", "--format=csv,noheader,nounits"], encoding="utf-8"
)
# Take the first line and strip any leading/trailing white space
first_line = output.strip().split('\n')[0]
first_line = output.strip().split("\n")[0]
return int(first_line)
except (subprocess.CalledProcessError, FileNotFoundError, ValueError):
@ -530,9 +718,14 @@ def cuda_device_count() -> int:
def cuda_is_available() -> bool:
    """
    Report whether CUDA is usable in the environment.

    Returns:
        (bool): True if cuda_device_count() reports one or more NVIDIA GPUs, False otherwise.
    """
    gpu_count = cuda_device_count()
    return gpu_count > 0
# Define constants
# True when the PYTHON_VERSION string starts with "3.12"
IS_PYTHON_3_12 = PYTHON_VERSION.startswith("3.12")

View File

@ -1,47 +1,53 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import os
import re
import shutil
import socket
import sys
import tempfile
from pathlib import Path
from . import USER_CONFIG_DIR
from .torch_utils import TORCH_1_9
def find_free_network_port() -> int:
    """
    Ask the OS for an unused TCP port on localhost and return its number.

    Handy for single-node training, where no real main node is contacted but the `MASTER_PORT` environment variable
    still has to be set.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("127.0.0.1", 0))  # port 0 lets the OS choose
        _, port = sock.getsockname()
        return port
def generate_ddp_file(trainer):
"""
Generates a DDP file and returns its file name.

The file is a small Python script that rebuilds the trainer class from `vars(trainer.args)` and calls `train()`.
It is written to the `DDP` folder under USER_CONFIG_DIR as a temporary file that is not deleted on close.
"""
# Split the trainer's fully qualified class path into module and class name
module, name = f'{trainer.__class__.__module__}.{trainer.__class__.__name__}'.rsplit('.', 1)
module, name = f"{trainer.__class__.__module__}.{trainer.__class__.__name__}".rsplit(".", 1)
content = f'''overrides = {vars(trainer.args)} \nif __name__ == "__main__":
content = f"""
# Ultralytics Multi-GPU training temp file (should be automatically deleted after use)
overrides = {vars(trainer.args)}
if __name__ == "__main__":
from {module} import {name}
from ultralytics.utils import DEFAULT_CFG_DICT
cfg = DEFAULT_CFG_DICT.copy()
cfg.update(save_dir='') # handle the extra key 'save_dir'
trainer = {name}(cfg=cfg, overrides=overrides)
trainer.train()'''
(USER_CONFIG_DIR / 'DDP').mkdir(exist_ok=True)
with tempfile.NamedTemporaryFile(prefix='_temp_',
suffix=f'{id(trainer)}.py',
mode='w+',
encoding='utf-8',
dir=USER_CONFIG_DIR / 'DDP',
delete=False) as file:
results = trainer.train()
"""
# The file name suffix embeds id(trainer); delete=False keeps the file on disk after the handle is closed
(USER_CONFIG_DIR / "DDP").mkdir(exist_ok=True)
with tempfile.NamedTemporaryFile(
prefix="_temp_",
suffix=f"{id(trainer)}.py",
mode="w+",
encoding="utf-8",
dir=USER_CONFIG_DIR / "DDP",
delete=False,
) as file:
file.write(content)
return file.name
@ -49,19 +55,17 @@ def generate_ddp_file(trainer):
def generate_ddp_command(world_size, trainer):
    """
    Build the launch command for distributed training.

    Args:
        world_size (int): Value passed to `--nproc_per_node`.
        trainer: Trainer whose `resume` and `save_dir` attributes are read and which is passed to generate_ddp_file().

    Returns:
        (tuple): The command as a list of arguments and the name of the generated DDP file.
    """
    import __main__  # noqa local import to avoid https://github.com/Lightning-AI/lightning/issues/15218

    if not trainer.resume:
        shutil.rmtree(trainer.save_dir)  # remove the save_dir
    file = generate_ddp_file(trainer)
    launcher = "torch.distributed.run" if TORCH_1_9 else "torch.distributed.launch"
    port = find_free_network_port()
    cmd = [sys.executable, "-m", launcher]
    cmd += ["--nproc_per_node", f"{world_size}", "--master_port", f"{port}", file]
    return cmd, file
def ddp_cleanup(trainer, file):
    """Remove `file` when its name contains the temp-file suffix built from id(trainer); otherwise leave it alone."""
    temp_suffix = f"{id(trainer)}.py"
    if temp_suffix not in file:
        return
    os.remove(file)

View File

@ -15,20 +15,42 @@ import torch
from ultralytics.utils import LOGGER, TQDM, checks, clean_url, emojis, is_online, url2file
# Ultralytics GitHub assets, maintained at https://github.com/ultralytics/assets
GITHUB_ASSETS_REPO = "ultralytics/assets"
GITHUB_ASSETS_NAMES = [
    *(f"yolov8{k}{suffix}.pt" for k in "nsmlx" for suffix in ("", "-cls", "-seg", "-pose", "-obb")),
    *(f"yolov5{k}{resolution}u.pt" for k in "nsmlx" for resolution in ("", "6")),
    *(f"yolov3{k}u.pt" for k in ("", "-spp", "-tiny")),
    *(f"yolov8{k}-world.pt" for k in "smlx"),
    *(f"yolov8{k}-worldv2.pt" for k in "smlx"),
    *(f"yolov9{k}.pt" for k in "ce"),
    *(f"yolo_nas_{k}.pt" for k in "sml"),
    *(f"sam_{k}.pt" for k in "bl"),
    *(f"FastSAM-{k}.pt" for k in "sx"),
    *(f"rtdetr-{k}.pt" for k in "lx"),
    "mobile_sam.pt",
    "calibration_image_sample_data_20x128x128x3_float32.npy.zip",
]
# File stems (final suffix removed) of every asset name, in the same order
GITHUB_ASSETS_STEMS = [Path(asset).stem for asset in GITHUB_ASSETS_NAMES]
def is_url(url, check=True):
"""Check if string is URL and check if URL exists."""
def is_url(url, check=False):
"""
Validates if the given string is a URL and optionally checks if the URL exists online.
Args:
url (str): The string to be validated as a URL.
check (bool, optional): If True, performs an additional check to see if the URL exists online.
Defaults to False.
Returns:
(bool): Returns True for a valid URL. If 'check' is True, also returns True if the URL exists online.
Returns False otherwise.
Example:
```python
valid = is_url("https://www.example.com")
```
"""
with contextlib.suppress(Exception):
url = str(url)
result = parse.urlparse(url)
@ -40,7 +62,7 @@ def is_url(url, check=True):
return False
def delete_dsstore(path, files_to_delete=('.DS_Store', '__MACOSX')):
def delete_dsstore(path, files_to_delete=(".DS_Store", "__MACOSX")):
"""
Deletes all ".DS_store" files under a specified directory.
@ -59,18 +81,17 @@ def delete_dsstore(path, files_to_delete=('.DS_Store', '__MACOSX')):
".DS_store" files are created by the Apple operating system and contain metadata about folders and files. They
are hidden system files and can cause issues when transferring files between different operating systems.
"""
# Delete Apple .DS_store files
for file in files_to_delete:
matches = list(Path(path).rglob(file))
LOGGER.info(f'Deleting {file} files: {matches}')
LOGGER.info(f"Deleting {file} files: {matches}")
for f in matches:
f.unlink()
def zip_directory(directory, compress=True, exclude=('.DS_Store', '__MACOSX'), progress=True):
def zip_directory(directory, compress=True, exclude=(".DS_Store", "__MACOSX"), progress=True):
"""
Zips the contents of a directory, excluding files containing strings in the exclude list.
The resulting zip file is named after the directory and placed alongside it.
Zips the contents of a directory, excluding files containing strings in the exclude list. The resulting zip file is
named after the directory and placed alongside it.
Args:
directory (str | Path): The path to the directory to be zipped.
@ -96,17 +117,17 @@ def zip_directory(directory, compress=True, exclude=('.DS_Store', '__MACOSX'), p
raise FileNotFoundError(f"Directory '{directory}' does not exist.")
# Zip with progress bar
files_to_zip = [f for f in directory.rglob('*') if f.is_file() and all(x not in f.name for x in exclude)]
zip_file = directory.with_suffix('.zip')
files_to_zip = [f for f in directory.rglob("*") if f.is_file() and all(x not in f.name for x in exclude)]
zip_file = directory.with_suffix(".zip")
compression = ZIP_DEFLATED if compress else ZIP_STORED
with ZipFile(zip_file, 'w', compression) as f:
for file in TQDM(files_to_zip, desc=f'Zipping {directory} to {zip_file}...', unit='file', disable=not progress):
with ZipFile(zip_file, "w", compression) as f:
for file in TQDM(files_to_zip, desc=f"Zipping {directory} to {zip_file}...", unit="file", disable=not progress):
f.write(file, file.relative_to(directory))
return zip_file # return path to zip file
def unzip_file(file, path=None, exclude=('.DS_Store', '__MACOSX'), exist_ok=False, progress=True):
def unzip_file(file, path=None, exclude=(".DS_Store", "__MACOSX"), exist_ok=False, progress=True):
"""
Unzips a *.zip file to the specified path, excluding files containing strings in the exclude list.
@ -146,51 +167,62 @@ def unzip_file(file, path=None, exclude=('.DS_Store', '__MACOSX'), exist_ok=Fals
files = [f for f in zipObj.namelist() if all(x not in f for x in exclude)]
top_level_dirs = {Path(f).parts[0] for f in files}
if len(top_level_dirs) > 1 or not files[0].endswith('/'): # zip has multiple files at top level
if len(top_level_dirs) > 1 or (len(files) > 1 and not files[0].endswith("/")):
# Zip has multiple files at top level
path = extract_path = Path(path) / Path(file).stem # i.e. ../datasets/coco8
else: # zip has 1 top-level directory
else:
# Zip has 1 top-level directory
extract_path = path # i.e. ../datasets
path = Path(path) / list(top_level_dirs)[0] # i.e. ../datasets/coco8
# Check if destination directory already exists and contains files
if path.exists() and any(path.iterdir()) and not exist_ok:
# If it exists and is not empty, return the path without unzipping
LOGGER.warning(f'WARNING ⚠️ Skipping {file} unzip as destination directory {path} is not empty.')
LOGGER.warning(f"WARNING ⚠️ Skipping {file} unzip as destination directory {path} is not empty.")
return path
for f in TQDM(files, desc=f'Unzipping {file} to {Path(path).resolve()}...', unit='file', disable=not progress):
zipObj.extract(f, path=extract_path)
for f in TQDM(files, desc=f"Unzipping {file} to {Path(path).resolve()}...", unit="file", disable=not progress):
# Ensure the file is within the extract_path to avoid path traversal security vulnerability
if ".." in Path(f).parts:
LOGGER.warning(f"Potentially insecure file path: {f}, skipping extraction.")
continue
zipObj.extract(f, extract_path)
return path # return unzip dir
def check_disk_space(url='https://ultralytics.com/assets/coco128.zip', sf=1.5, hard=True):
def check_disk_space(url="https://ultralytics.com/assets/coco128.zip", path=Path.cwd(), sf=1.5, hard=True):
"""
Check if there is sufficient disk space to download and store a file.
Args:
url (str, optional): The URL to the file. Defaults to 'https://ultralytics.com/assets/coco128.zip'.
path (str | Path, optional): The path or drive to check the available free space on.
sf (float, optional): Safety factor, the multiplier for the required free space. Defaults to 1.5.
hard (bool, optional): Whether to throw an error or not on insufficient disk space. Defaults to True.
Returns:
(bool): True if there is sufficient disk space, False otherwise.
"""
r = requests.head(url) # response
# Check response
assert r.status_code < 400, f'URL error for {url}: {r.status_code} {r.reason}'
try:
r = requests.head(url) # response
assert r.status_code < 400, f"URL error for {url}: {r.status_code} {r.reason}" # check response
except Exception:
return True # requests issue, default to True
# Check file size
gib = 1 << 30 # bytes per GiB
data = int(r.headers.get('Content-Length', 0)) / gib # file size (GB)
total, used, free = (x / gib for x in shutil.disk_usage('/')) # bytes
data = int(r.headers.get("Content-Length", 0)) / gib # file size (GB)
total, used, free = (x / gib for x in shutil.disk_usage(path)) # bytes
if data * sf < free:
return True # sufficient space
# Insufficient space
text = (f'WARNING ⚠️ Insufficient free disk space {free:.1f} GB < {data * sf:.3f} GB required, '
f'Please free {data * sf - free:.1f} GB additional disk space and try again.')
text = (
f"WARNING ⚠️ Insufficient free disk space {free:.1f} GB < {data * sf:.3f} GB required, "
f"Please free {data * sf - free:.1f} GB additional disk space and try again."
)
if hard:
raise MemoryError(text)
LOGGER.warning(text)
@ -216,35 +248,41 @@ def get_google_drive_file_info(link):
url, filename = get_google_drive_file_info(link)
```
"""
file_id = link.split('/d/')[1].split('/view')[0]
drive_url = f'https://drive.google.com/uc?export=download&id={file_id}'
file_id = link.split("/d/")[1].split("/view")[0]
drive_url = f"https://drive.google.com/uc?export=download&id={file_id}"
filename = None
# Start session
with requests.Session() as session:
response = session.get(drive_url, stream=True)
if 'quota exceeded' in str(response.content.lower()):
if "quota exceeded" in str(response.content.lower()):
raise ConnectionError(
emojis(f'❌ Google Drive file download quota exceeded. '
f'Please try again later or download this file manually at {link}.'))
emojis(
f"❌ Google Drive file download quota exceeded. "
f"Please try again later or download this file manually at {link}."
)
)
for k, v in response.cookies.items():
if k.startswith('download_warning'):
drive_url += f'&confirm={v}' # v is token
cd = response.headers.get('content-disposition')
if k.startswith("download_warning"):
drive_url += f"&confirm={v}" # v is token
cd = response.headers.get("content-disposition")
if cd:
filename = re.findall('filename="(.+)"', cd)[0]
return drive_url, filename
def safe_download(url,
file=None,
dir=None,
unzip=True,
delete=False,
curl=False,
retry=3,
min_bytes=1E0,
progress=True):
def safe_download(
url,
file=None,
dir=None,
unzip=True,
delete=False,
curl=False,
retry=3,
min_bytes=1e0,
exist_ok=False,
progress=True,
):
"""
Downloads files from a URL, with options for retrying, unzipping, and deleting the downloaded file.
@ -260,41 +298,49 @@ def safe_download(url,
retry (int, optional): The number of times to retry the download in case of failure. Default: 3.
min_bytes (float, optional): The minimum number of bytes that the downloaded file should have, to be considered
a successful download. Default: 1E0.
exist_ok (bool, optional): Whether to overwrite existing contents during unzipping. Defaults to False.
progress (bool, optional): Whether to display a progress bar during the download. Default: True.
"""
# Check if the URL is a Google Drive link
gdrive = url.startswith('https://drive.google.com/')
Example:
```python
from ultralytics.utils.downloads import safe_download
link = "https://ultralytics.com/assets/bus.jpg"
path = safe_download(link)
```
"""
gdrive = url.startswith("https://drive.google.com/") # check if the URL is a Google Drive link
if gdrive:
url, file = get_google_drive_file_info(url)
f = dir / (file if gdrive else url2file(url)) if dir else Path(file) # URL converted to filename
if '://' not in str(url) and Path(url).is_file(): # URL exists ('://' check required in Windows Python<3.10)
f = Path(dir or ".") / (file or url2file(url)) # URL converted to filename
if "://" not in str(url) and Path(url).is_file(): # URL exists ('://' check required in Windows Python<3.10)
f = Path(url) # filename
elif not f.is_file(): # URL and file do not exist
assert dir or file, 'dir or file required for download'
desc = f"Downloading {url if gdrive else clean_url(url)} to '{f}'"
LOGGER.info(f'{desc}...')
LOGGER.info(f"{desc}...")
f.parent.mkdir(parents=True, exist_ok=True) # make directory if missing
check_disk_space(url)
check_disk_space(url, path=f.parent)
for i in range(retry + 1):
try:
if curl or i > 0: # curl download with retry, continue
s = 'sS' * (not progress) # silent
r = subprocess.run(['curl', '-#', f'-{s}L', url, '-o', f, '--retry', '3', '-C', '-']).returncode
assert r == 0, f'Curl return value {r}'
s = "sS" * (not progress) # silent
r = subprocess.run(["curl", "-#", f"-{s}L", url, "-o", f, "--retry", "3", "-C", "-"]).returncode
assert r == 0, f"Curl return value {r}"
else: # urllib download
method = 'torch'
if method == 'torch':
method = "torch"
if method == "torch":
torch.hub.download_url_to_file(url, f, progress=progress)
else:
with request.urlopen(url) as response, TQDM(total=int(response.getheader('Content-Length', 0)),
desc=desc,
disable=not progress,
unit='B',
unit_scale=True,
unit_divisor=1024) as pbar:
with open(f, 'wb') as f_opened:
with request.urlopen(url) as response, TQDM(
total=int(response.getheader("Content-Length", 0)),
desc=desc,
disable=not progress,
unit="B",
unit_scale=True,
unit_divisor=1024,
) as pbar:
with open(f, "wb") as f_opened:
for data in response:
f_opened.write(data)
pbar.update(len(data))
@ -305,88 +351,150 @@ def safe_download(url,
f.unlink() # remove partial downloads
except Exception as e:
if i == 0 and not is_online():
raise ConnectionError(emojis(f'❌ Download failure for {url}. Environment is not online.')) from e
raise ConnectionError(emojis(f"❌ Download failure for {url}. Environment is not online.")) from e
elif i >= retry:
raise ConnectionError(emojis(f'❌ Download failure for {url}. Retry limit reached.')) from e
LOGGER.warning(f'⚠️ Download failure, retrying {i + 1}/{retry} {url}...')
raise ConnectionError(emojis(f"❌ Download failure for {url}. Retry limit reached.")) from e
LOGGER.warning(f"⚠️ Download failure, retrying {i + 1}/{retry} {url}...")
if unzip and f.exists() and f.suffix in ('', '.zip', '.tar', '.gz'):
if unzip and f.exists() and f.suffix in ("", ".zip", ".tar", ".gz"):
from zipfile import is_zipfile
unzip_dir = dir or f.parent # unzip to dir if provided else unzip in place
unzip_dir = (dir or f.parent).resolve() # unzip to dir if provided else unzip in place
if is_zipfile(f):
unzip_dir = unzip_file(file=f, path=unzip_dir, progress=progress) # unzip
elif f.suffix in ('.tar', '.gz'):
LOGGER.info(f'Unzipping {f} to {unzip_dir.resolve()}...')
subprocess.run(['tar', 'xf' if f.suffix == '.tar' else 'xfz', f, '--directory', unzip_dir], check=True)
unzip_dir = unzip_file(file=f, path=unzip_dir, exist_ok=exist_ok, progress=progress) # unzip
elif f.suffix in (".tar", ".gz"):
LOGGER.info(f"Unzipping {f} to {unzip_dir}...")
subprocess.run(["tar", "xf" if f.suffix == ".tar" else "xfz", f, "--directory", unzip_dir], check=True)
if delete:
f.unlink() # remove zip
return unzip_dir
def get_github_assets(repo='ultralytics/assets', version='latest', retry=False):
"""Return GitHub repo tag and assets (i.e. ['yolov8n.pt', 'yolov8s.pt', ...])."""
if version != 'latest':
version = f'tags/{version}' # i.e. tags/v6.2
url = f'https://api.github.com/repos/{repo}/releases/{version}'
def get_github_assets(repo="ultralytics/assets", version="latest", retry=False):
"""
Retrieve the specified version's tag and assets from a GitHub repository. If the version is not specified, the
function fetches the latest release assets.
Args:
repo (str, optional): The GitHub repository in the format 'owner/repo'. Defaults to 'ultralytics/assets'.
version (str, optional): The release version to fetch assets from. Defaults to 'latest'.
retry (bool, optional): Flag to retry the request in case of a failure. Defaults to False.
Returns:
(tuple): A tuple containing the release tag and a list of asset names.
Example:
```python
tag, assets = get_github_assets(repo='ultralytics/assets', version='latest')
```
"""
if version != "latest":
version = f"tags/{version}" # i.e. tags/v6.2
url = f"https://api.github.com/repos/{repo}/releases/{version}"
r = requests.get(url) # github api
if r.status_code != 200 and r.reason != 'rate limit exceeded' and retry: # failed and not 403 rate limit exceeded
if r.status_code != 200 and r.reason != "rate limit exceeded" and retry: # failed and not 403 rate limit exceeded
r = requests.get(url) # try again
if r.status_code != 200:
LOGGER.warning(f'⚠️ GitHub assets check failure for {url}: {r.status_code} {r.reason}')
return '', []
LOGGER.warning(f"⚠️ GitHub assets check failure for {url}: {r.status_code} {r.reason}")
return "", []
data = r.json()
return data['tag_name'], [x['name'] for x in data['assets']] # tag, assets
return data["tag_name"], [x["name"] for x in data["assets"]] # tag, assets i.e. ['yolov8n.pt', 'yolov8s.pt', ...]
def attempt_download_asset(file, repo='ultralytics/assets', release='v0.0.0'):
"""Attempt file download from GitHub release assets if not found locally. release = 'latest', 'v6.2', etc."""
def attempt_download_asset(file, repo="ultralytics/assets", release="v8.1.0", **kwargs):
"""
Attempt to download a file from GitHub release assets if it is not found locally. The function checks for the file
locally first, then tries to download it from the specified GitHub repository release.
Args:
file (str | Path): The filename or file path to be downloaded.
repo (str, optional): The GitHub repository in the format 'owner/repo'. Defaults to 'ultralytics/assets'.
release (str, optional): The specific release version to be downloaded. Defaults to 'v8.1.0'.
**kwargs (any): Additional keyword arguments for the download process.
Returns:
(str): The path to the downloaded file.
Example:
```python
file_path = attempt_download_asset('yolov5s.pt', repo='ultralytics/assets', release='latest')
```
"""
from ultralytics.utils import SETTINGS # scoped for circular import
# YOLOv3/5u updates
file = str(file)
file = checks.check_yolov5u_filename(file)
file = Path(file.strip().replace("'", ''))
file = Path(file.strip().replace("'", ""))
if file.exists():
return str(file)
elif (SETTINGS['weights_dir'] / file).exists():
return str(SETTINGS['weights_dir'] / file)
elif (SETTINGS["weights_dir"] / file).exists():
return str(SETTINGS["weights_dir"] / file)
else:
# URL specified
name = Path(parse.unquote(str(file))).name # decode '%2F' to '/' etc.
if str(file).startswith(('http:/', 'https:/')): # download
url = str(file).replace(':/', '://') # Pathlib turns :// -> :/
download_url = f"https://github.com/{repo}/releases/download"
if str(file).startswith(("http:/", "https:/")): # download
url = str(file).replace(":/", "://") # Pathlib turns :// -> :/
file = url2file(name) # parse authentication https://url.com/file.txt?auth...
if Path(file).is_file():
LOGGER.info(f'Found {clean_url(url)} locally at {file}') # file already exists
LOGGER.info(f"Found {clean_url(url)} locally at {file}") # file already exists
else:
safe_download(url=url, file=file, min_bytes=1E5)
safe_download(url=url, file=file, min_bytes=1e5, **kwargs)
elif repo == GITHUB_ASSETS_REPO and name in GITHUB_ASSETS_NAMES:
safe_download(url=f'https://github.com/{repo}/releases/download/{release}/{name}', file=file, min_bytes=1E5)
safe_download(url=f"{download_url}/{release}/{name}", file=file, min_bytes=1e5, **kwargs)
else:
tag, assets = get_github_assets(repo, release)
if not assets:
tag, assets = get_github_assets(repo) # latest release
if name in assets:
safe_download(url=f'https://github.com/{repo}/releases/download/{tag}/{name}', file=file, min_bytes=1E5)
safe_download(url=f"{download_url}/{tag}/{name}", file=file, min_bytes=1e5, **kwargs)
return str(file)
def download(url, dir=Path.cwd(), unzip=True, delete=False, curl=False, threads=1, retry=3):
"""Downloads and unzips files concurrently if threads > 1, else sequentially."""
def download(url, dir=Path.cwd(), unzip=True, delete=False, curl=False, threads=1, retry=3, exist_ok=False):
"""
Downloads files from specified URLs to a given directory. Supports concurrent downloads if multiple threads are
specified.
Args:
url (str | list): The URL or list of URLs of the files to be downloaded.
dir (Path, optional): The directory where the files will be saved. Defaults to the current working directory.
unzip (bool, optional): Flag to unzip the files after downloading. Defaults to True.
delete (bool, optional): Flag to delete the zip files after extraction. Defaults to False.
curl (bool, optional): Flag to use curl for downloading. Defaults to False.
threads (int, optional): Number of threads to use for concurrent downloads. Defaults to 1.
retry (int, optional): Number of retries in case of download failure. Defaults to 3.
exist_ok (bool, optional): Whether to overwrite existing contents during unzipping. Defaults to False.
Example:
```python
download('https://ultralytics.com/assets/example.zip', dir='path/to/dir', unzip=True)
```
"""
dir = Path(dir)
dir.mkdir(parents=True, exist_ok=True) # make directory
if threads > 1:
with ThreadPool(threads) as pool:
pool.map(
lambda x: safe_download(
url=x[0], dir=x[1], unzip=unzip, delete=delete, curl=curl, retry=retry, progress=threads <= 1),
zip(url, repeat(dir)))
url=x[0],
dir=x[1],
unzip=unzip,
delete=delete,
curl=curl,
retry=retry,
exist_ok=exist_ok,
progress=threads <= 1,
),
zip(url, repeat(dir)),
)
pool.close()
pool.join()
else:
for u in [url] if isinstance(url, (str, Path)) else url:
safe_download(url=u, dir=dir, unzip=unzip, delete=delete, curl=curl, retry=retry)
safe_download(url=u, dir=dir, unzip=unzip, delete=delete, curl=curl, retry=retry, exist_ok=exist_ok)

View File

@ -4,7 +4,19 @@ from ultralytics.utils import emojis
class HUBModelError(Exception):
"""
Custom exception class for handling errors related to model fetching in Ultralytics YOLO.
def __init__(self, message='Model not found. Please check model URL and try again.'):
This exception is raised when a requested model is not found or cannot be retrieved.
The message is also processed to include emojis for better user experience.
Attributes:
message (str): The error message displayed when the exception is raised.
Note:
The message is automatically processed through the 'emojis' function from the 'ultralytics.utils' package.
"""
def __init__(self, message="Model not found. Please check model URL and try again."):
"""Create an exception for when a model is not found."""
super().__init__(emojis(message))

View File

@ -30,9 +30,9 @@ class WorkingDirectory(contextlib.ContextDecorator):
@contextmanager
def spaces_in_path(path):
"""
Context manager to handle paths with spaces in their names.
If a path contains spaces, it replaces them with underscores, copies the file/directory to the new path,
executes the context code block, then copies the file/directory back to its original location.
Context manager to handle paths with spaces in their names. If a path contains spaces, it replaces them with
underscores, copies the file/directory to the new path, executes the context code block, then copies the
file/directory back to its original location.
Args:
path (str | Path): The original path.
@ -45,18 +45,18 @@ def spaces_in_path(path):
from ultralytics.utils.files import spaces_in_path
with spaces_in_path('/path/with spaces') as new_path:
# your code here
# Your code here
```
"""
# If path has spaces, replace them with underscores
if ' ' in str(path):
if " " in str(path):
string = isinstance(path, str) # input type
path = Path(path)
# Create a temporary directory and construct the new path
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_path = Path(tmp_dir) / path.name.replace(' ', '_')
tmp_path = Path(tmp_dir) / path.name.replace(" ", "_")
# Copy file/directory
if path.is_dir():
@ -82,7 +82,7 @@ def spaces_in_path(path):
yield path
def increment_path(path, exist_ok=False, sep='', mkdir=False):
def increment_path(path, exist_ok=False, sep="", mkdir=False):
"""
Increments a file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
@ -102,12 +102,12 @@ def increment_path(path, exist_ok=False, sep='', mkdir=False):
"""
path = Path(path) # os-agnostic
if path.exists() and not exist_ok:
path, suffix = (path.with_suffix(''), path.suffix) if path.is_file() else (path, '')
path, suffix = (path.with_suffix(""), path.suffix) if path.is_file() else (path, "")
# Method 1
for n in range(2, 9999):
p = f'{path}{sep}{n}{suffix}' # increment path
if not os.path.exists(p): #
p = f"{path}{sep}{n}{suffix}" # increment path
if not os.path.exists(p):
break
path = Path(p)
@ -119,14 +119,14 @@ def increment_path(path, exist_ok=False, sep='', mkdir=False):
def file_age(path=__file__):
"""Return days since last file update."""
dt = (datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime)) # delta
dt = datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime) # delta
return dt.days # + dt.seconds / 86400 # fractional days
def file_date(path=__file__):
"""Return human-readable file modification date, i.e. '2021-3-26'."""
t = datetime.fromtimestamp(Path(path).stat().st_mtime)
return f'{t.year}-{t.month}-{t.day}'
return f"{t.year}-{t.month}-{t.day}"
def file_size(path):
@ -137,11 +137,52 @@ def file_size(path):
if path.is_file():
return path.stat().st_size / mb
elif path.is_dir():
return sum(f.stat().st_size for f in path.glob('**/*') if f.is_file()) / mb
return sum(f.stat().st_size for f in path.glob("**/*") if f.is_file()) / mb
return 0.0
def get_latest_run(search_dir='.'):
def get_latest_run(search_dir="."):
"""Return path to most recent 'last.pt' in /runs (i.e. to --resume from)."""
last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
return max(last_list, key=os.path.getctime) if last_list else ''
last_list = glob.glob(f"{search_dir}/**/last*.pt", recursive=True)
return max(last_list, key=os.path.getctime) if last_list else ""
def update_models(model_names=("yolov8n.pt",), source_dir=Path("."), update_names=False):
    """
    Updates and re-saves specified YOLO models in an 'updated_models' subdirectory.

    Each model is loaded from `source_dir`, `model.half()` is called on it, and it is saved under the same
    filename inside `source_dir / "updated_models"` (created if missing).

    Args:
        model_names (Iterable[str], optional): Model filenames to update, defaults to ("yolov8n.pt",).
        source_dir (str | Path, optional): Directory containing models and target subdirectory, defaults to the
            current directory. Plain strings are accepted and converted to `Path`.
        update_names (bool, optional): If True, replace the model's class names with those returned by
            `default_class_names("coco8.yaml")`.

    Example:
        ```python
        from ultralytics.utils.files import update_models

        model_names = (f"rtdetr-{size}.pt" for size in "lx")
        update_models(model_names)
        ```
    """
    # Scoped imports, kept inside the function as in the original implementation
    from ultralytics import YOLO
    from ultralytics.nn.autobackend import default_class_names

    source_dir = Path(source_dir)  # a str argument would otherwise fail on the '/' operator below
    target_dir = source_dir / "updated_models"
    target_dir.mkdir(parents=True, exist_ok=True)  # Ensure target directory exists

    for model_name in model_names:
        model_path = source_dir / model_name
        print(f"Loading model from {model_path}")

        # Load model
        model = YOLO(model_path)
        model.half()
        if update_names:  # update model names from a dataset YAML
            model.model.names = default_class_names("coco8.yaml")

        # Define new save path
        save_path = target_dir / model_name

        # Save model using model.save()
        print(f"Re-saving {model_name} model to {save_path}")
        model.save(save_path, use_dill=False)

View File

@ -7,7 +7,7 @@ from typing import List
import numpy as np
from .ops import ltwh2xywh, ltwh2xyxy, resample_segments, xywh2ltwh, xywh2xyxy, xyxy2ltwh, xyxy2xywh
from .ops import ltwh2xywh, ltwh2xyxy, xywh2ltwh, xywh2xyxy, xyxy2ltwh, xyxy2xywh
def _ntuple(n):
@ -26,16 +26,29 @@ to_4tuple = _ntuple(4)
# `xyxy` means left top and right bottom
# `xywh` means center x, center y and width, height(YOLO format)
# `ltwh` means left top and width, height(COCO format)
_formats = ['xyxy', 'xywh', 'ltwh']
_formats = ["xyxy", "xywh", "ltwh"]
__all__ = 'Bboxes', # tuple or list
__all__ = ("Bboxes",) # tuple or list
class Bboxes:
"""Bounding Boxes class. Only numpy variables are supported."""
"""
A class for handling bounding boxes.
def __init__(self, bboxes, format='xyxy') -> None:
assert format in _formats, f'Invalid bounding box format: {format}, format must be one of {_formats}'
The class supports various bounding box formats like 'xyxy', 'xywh', and 'ltwh'.
Bounding box data should be provided in numpy arrays.
Attributes:
bboxes (numpy.ndarray): The bounding boxes stored in a 2D numpy array.
format (str): The format of the bounding boxes ('xyxy', 'xywh', or 'ltwh').
Note:
This class does not handle normalization or denormalization of bounding boxes.
"""
def __init__(self, bboxes, format="xyxy") -> None:
"""Initializes the Bboxes class with bounding box data in a specified format."""
assert format in _formats, f"Invalid bounding box format: {format}, format must be one of {_formats}"
bboxes = bboxes[None, :] if bboxes.ndim == 1 else bboxes
assert bboxes.ndim == 2
assert bboxes.shape[1] == 4
@ -45,21 +58,21 @@ class Bboxes:
def convert(self, format):
"""Converts bounding box format from one type to another."""
assert format in _formats, f'Invalid bounding box format: {format}, format must be one of {_formats}'
assert format in _formats, f"Invalid bounding box format: {format}, format must be one of {_formats}"
if self.format == format:
return
elif self.format == 'xyxy':
func = xyxy2xywh if format == 'xywh' else xyxy2ltwh
elif self.format == 'xywh':
func = xywh2xyxy if format == 'xyxy' else xywh2ltwh
elif self.format == "xyxy":
func = xyxy2xywh if format == "xywh" else xyxy2ltwh
elif self.format == "xywh":
func = xywh2xyxy if format == "xyxy" else xywh2ltwh
else:
func = ltwh2xyxy if format == 'xyxy' else ltwh2xywh
func = ltwh2xyxy if format == "xyxy" else ltwh2xywh
self.bboxes = func(self.bboxes)
self.format = format
def areas(self):
"""Return box areas."""
self.convert('xyxy')
self.convert("xyxy")
return (self.bboxes[:, 2] - self.bboxes[:, 0]) * (self.bboxes[:, 3] - self.bboxes[:, 1])
# def denormalize(self, w, h):
@ -111,7 +124,7 @@ class Bboxes:
return len(self.bboxes)
@classmethod
def concatenate(cls, boxes_list: List['Bboxes'], axis=0) -> 'Bboxes':
def concatenate(cls, boxes_list: List["Bboxes"], axis=0) -> "Bboxes":
"""
Concatenate a list of Bboxes objects into a single Bboxes object.
@ -135,7 +148,7 @@ class Bboxes:
return boxes_list[0]
return cls(np.concatenate([b.bboxes for b in boxes_list], axis=axis))
def __getitem__(self, index) -> 'Bboxes':
def __getitem__(self, index) -> "Bboxes":
"""
Retrieve a specific bounding box or a set of bounding boxes using indexing.
@ -156,32 +169,52 @@ class Bboxes:
if isinstance(index, int):
return Bboxes(self.bboxes[index].view(1, -1))
b = self.bboxes[index]
assert b.ndim == 2, f'Indexing on Bboxes with {index} failed to return a matrix!'
assert b.ndim == 2, f"Indexing on Bboxes with {index} failed to return a matrix!"
return Bboxes(b)
class Instances:
"""
Container for bounding boxes, segments, and keypoints of detected objects in an image.
def __init__(self, bboxes, segments=None, keypoints=None, bbox_format='xywh', normalized=True) -> None:
Attributes:
_bboxes (Bboxes): Internal object for handling bounding box operations.
keypoints (ndarray): keypoints(x, y, visible) with shape [N, 17, 3]. Default is None.
normalized (bool): Flag indicating whether the bounding box coordinates are normalized.
segments (ndarray): Segments array with shape [N, 1000, 2] after resampling.
Args:
bboxes (ndarray): An array of bounding boxes with shape [N, 4].
segments (list | ndarray, optional): A list or array of object segments. Default is None.
keypoints (ndarray, optional): An array of keypoints with shape [N, 17, 3]. Default is None.
bbox_format (str, optional): The format of bounding boxes ('xywh' or 'xyxy'). Default is 'xywh'.
normalized (bool, optional): Whether the bounding box coordinates are normalized. Default is True.
Examples:
```python
# Create an Instances object
instances = Instances(
bboxes=np.array([[10, 10, 30, 30], [20, 20, 40, 40]]),
segments=[np.array([[5, 5], [10, 10]]), np.array([[15, 15], [20, 20]])],
keypoints=np.array([[[5, 5, 1], [10, 10, 1]], [[15, 15, 1], [20, 20, 1]]])
)
```
Note:
The bounding box format is either 'xywh' or 'xyxy', and is determined by the `bbox_format` argument.
This class does not perform input validation, and it assumes the inputs are well-formed.
"""
def __init__(self, bboxes, segments=None, keypoints=None, bbox_format="xywh", normalized=True) -> None:
"""
Args:
bboxes (ndarray): bboxes with shape [N, 4].
segments (list | ndarray): segments.
keypoints (ndarray): keypoints(x, y, visible) with shape [N, 17, 3].
"""
if segments is None:
segments = []
self._bboxes = Bboxes(bboxes=bboxes, format=bbox_format)
self.keypoints = keypoints
self.normalized = normalized
if len(segments) > 0:
# list[np.array(1000, 2)] * num_samples
segments = resample_segments(segments)
# (N, 1000, 2)
segments = np.stack(segments, axis=0)
else:
segments = np.zeros((0, 1000, 2), dtype=np.float32)
self.segments = segments
def convert_bbox(self, format):
@ -194,7 +227,7 @@ class Instances:
return self._bboxes.areas()
def scale(self, scale_w, scale_h, bbox_only=False):
"""this might be similar with denormalize func but without normalized sign."""
"""This might be similar with denormalize func but without normalized sign."""
self._bboxes.mul(scale=(scale_w, scale_h, scale_w, scale_h))
if bbox_only:
return
@ -230,7 +263,7 @@ class Instances:
def add_padding(self, padw, padh):
"""Handle rect and mosaic situation."""
assert not self.normalized, 'you should add padding with absolute coordinates.'
assert not self.normalized, "you should add padding with absolute coordinates."
self._bboxes.add(offset=(padw, padh, padw, padh))
self.segments[..., 0] += padw
self.segments[..., 1] += padh
@ -238,7 +271,7 @@ class Instances:
self.keypoints[..., 0] += padw
self.keypoints[..., 1] += padh
def __getitem__(self, index) -> 'Instances':
def __getitem__(self, index) -> "Instances":
"""
Retrieve a specific instance or a set of instances using indexing.
@ -268,7 +301,7 @@ class Instances:
def flipud(self, h):
"""Flips the coordinates of bounding boxes, segments, and keypoints vertically."""
if self._bboxes.format == 'xyxy':
if self._bboxes.format == "xyxy":
y1 = self.bboxes[:, 1].copy()
y2 = self.bboxes[:, 3].copy()
self.bboxes[:, 1] = h - y2
@ -281,7 +314,7 @@ class Instances:
def fliplr(self, w):
"""Reverses the order of the bounding boxes and segments horizontally."""
if self._bboxes.format == 'xyxy':
if self._bboxes.format == "xyxy":
x1 = self.bboxes[:, 0].copy()
x2 = self.bboxes[:, 2].copy()
self.bboxes[:, 0] = w - x2
@ -295,10 +328,10 @@ class Instances:
def clip(self, w, h):
"""Clips bounding boxes, segments, and keypoints values to stay within image boundaries."""
ori_format = self._bboxes.format
self.convert_bbox(format='xyxy')
self.convert_bbox(format="xyxy")
self.bboxes[:, [0, 2]] = self.bboxes[:, [0, 2]].clip(0, w)
self.bboxes[:, [1, 3]] = self.bboxes[:, [1, 3]].clip(0, h)
if ori_format != 'xyxy':
if ori_format != "xyxy":
self.convert_bbox(format=ori_format)
self.segments[..., 0] = self.segments[..., 0].clip(0, w)
self.segments[..., 1] = self.segments[..., 1].clip(0, h)
@ -307,7 +340,11 @@ class Instances:
self.keypoints[..., 1] = self.keypoints[..., 1].clip(0, h)
def remove_zero_area_boxes(self):
"""Remove zero-area boxes, i.e. after clipping some boxes may have zero width or height. This removes them."""
"""
Remove zero-area boxes, i.e. after clipping some boxes may have zero width or height.
This removes them.
"""
good = self.bbox_areas > 0
if not all(good):
self._bboxes = self._bboxes[good]
@ -330,7 +367,7 @@ class Instances:
return len(self.bboxes)
@classmethod
def concatenate(cls, instances_list: List['Instances'], axis=0) -> 'Instances':
def concatenate(cls, instances_list: List["Instances"], axis=0) -> "Instances":
"""
Concatenates a list of Instances objects into a single Instances object.

View File

@ -6,14 +6,17 @@ import torch.nn.functional as F
from ultralytics.utils.metrics import OKS_SIGMA
from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh
from ultralytics.utils.tal import TaskAlignedAssigner, dist2bbox, make_anchors
from .metrics import bbox_iou
from ultralytics.utils.tal import RotatedTaskAlignedAssigner, TaskAlignedAssigner, dist2bbox, dist2rbox, make_anchors
from .metrics import bbox_iou, probiou
from .tal import bbox2dist
class VarifocalLoss(nn.Module):
"""Varifocal loss by Zhang et al. https://arxiv.org/abs/2008.13367."""
"""
Varifocal loss by Zhang et al.
https://arxiv.org/abs/2008.13367.
"""
def __init__(self):
"""Initialize the VarifocalLoss class."""
@ -24,21 +27,25 @@ class VarifocalLoss(nn.Module):
"""Computes varifocal loss."""
weight = alpha * pred_score.sigmoid().pow(gamma) * (1 - label) + gt_score * label
with torch.cuda.amp.autocast(enabled=False):
loss = (F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction='none') *
weight).mean(1).sum()
loss = (
(F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction="none") * weight)
.mean(1)
.sum()
)
return loss
class FocalLoss(nn.Module):
"""Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)."""
def __init__(self, ):
def __init__(self):
"""Initializer for FocalLoss class with no parameters."""
super().__init__()
@staticmethod
def forward(pred, label, gamma=1.5, alpha=0.25):
"""Calculates focal loss from prediction logits and labels."""
loss = F.binary_cross_entropy_with_logits(pred, label, reduction='none')
loss = F.binary_cross_entropy_with_logits(pred, label, reduction="none")
# p_t = torch.exp(-loss)
# loss *= self.alpha * (1.000001 - p_t) ** self.gamma # non-zero power for gradient stability
@ -54,6 +61,7 @@ class FocalLoss(nn.Module):
class BboxLoss(nn.Module):
"""Criterion class for computing training losses during training."""
def __init__(self, reg_max, use_dfl=False):
"""Initialize the BboxLoss module with regularization maximum and DFL settings."""
@ -79,42 +87,73 @@ class BboxLoss(nn.Module):
@staticmethod
def _df_loss(pred_dist, target):
"""Return sum of left and right DFL losses."""
# Distribution Focal Loss (DFL) proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391
"""
Return sum of left and right DFL losses.
Distribution Focal Loss (DFL) proposed in Generalized Focal Loss
https://ieeexplore.ieee.org/document/9792391
"""
tl = target.long() # target left
tr = tl + 1 # target right
wl = tr - target # weight left
wr = 1 - wl # weight right
return (F.cross_entropy(pred_dist, tl.view(-1), reduction='none').view(tl.shape) * wl +
F.cross_entropy(pred_dist, tr.view(-1), reduction='none').view(tl.shape) * wr).mean(-1, keepdim=True)
return (
F.cross_entropy(pred_dist, tl.view(-1), reduction="none").view(tl.shape) * wl
+ F.cross_entropy(pred_dist, tr.view(-1), reduction="none").view(tl.shape) * wr
).mean(-1, keepdim=True)
class RotatedBboxLoss(BboxLoss):
    """Box criterion that scores predictions with `probiou` and, optionally, a DFL term."""

    def __init__(self, reg_max, use_dfl=False):
        """Pass `reg_max` and `use_dfl` straight through to the BboxLoss initializer."""
        super().__init__(reg_max, use_dfl)

    def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
        """Return the pair `(loss_iou, loss_dfl)` computed over the entries selected by `fg_mask`."""
        # Per-selected-entry weight: summed target scores with a trailing singleton dimension
        fg_weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)

        # IoU term: weighted (1 - probiou), normalised by the target score sum
        overlap = probiou(pred_bboxes[fg_mask], target_bboxes[fg_mask])
        loss_iou = ((1.0 - overlap) * fg_weight).sum() / target_scores_sum

        # Without DFL the second loss is a zero scalar on the prediction's device
        if not self.use_dfl:
            return loss_iou, torch.tensor(0.0).to(pred_dist.device)

        # DFL term: distances are built from the first four target box values converted with xywh2xyxy
        ltrb_targets = bbox2dist(anchor_points, xywh2xyxy(target_bboxes[..., :4]), self.reg_max)
        dfl_terms = self._df_loss(pred_dist[fg_mask].view(-1, self.reg_max + 1), ltrb_targets[fg_mask]) * fg_weight
        return loss_iou, dfl_terms.sum() / target_scores_sum
class KeypointLoss(nn.Module):
"""Criterion class for computing training losses."""
def __init__(self, sigmas) -> None:
"""Initialize the KeypointLoss class."""
super().__init__()
self.sigmas = sigmas
def forward(self, pred_kpts, gt_kpts, kpt_mask, area):
"""Calculates keypoint loss factor and Euclidean distance loss for predicted and actual keypoints."""
d = (pred_kpts[..., 0] - gt_kpts[..., 0]) ** 2 + (pred_kpts[..., 1] - gt_kpts[..., 1]) ** 2
kpt_loss_factor = (torch.sum(kpt_mask != 0) + torch.sum(kpt_mask == 0)) / (torch.sum(kpt_mask != 0) + 1e-9)
d = (pred_kpts[..., 0] - gt_kpts[..., 0]).pow(2) + (pred_kpts[..., 1] - gt_kpts[..., 1]).pow(2)
kpt_loss_factor = kpt_mask.shape[1] / (torch.sum(kpt_mask != 0, dim=1) + 1e-9)
# e = d / (2 * (area * self.sigmas) ** 2 + 1e-9) # from formula
e = d / (2 * self.sigmas) ** 2 / (area + 1e-9) / 2 # from cocoeval
return kpt_loss_factor * ((1 - torch.exp(-e)) * kpt_mask).mean()
e = d / ((2 * self.sigmas).pow(2) * (area + 1e-9) * 2) # from cocoeval
return (kpt_loss_factor.view(-1, 1) * ((1 - torch.exp(-e)) * kpt_mask)).mean()
class v8DetectionLoss:
"""Criterion class for computing training losses."""
def __init__(self, model): # model must be de-paralleled
def __init__(self, model, tal_topk=10): # model must be de-paralleled
"""Initializes v8DetectionLoss with the model, defining model-related properties and BCE loss function."""
device = next(model.parameters()).device # get model device
h = model.args # hyperparameters
m = model.model[-1] # Detect() module
self.bce = nn.BCEWithLogitsLoss(reduction='none')
self.bce = nn.BCEWithLogitsLoss(reduction="none")
self.hyp = h
self.stride = m.stride # model strides
self.nc = m.nc # number of classes
@ -124,7 +163,7 @@ class v8DetectionLoss:
self.use_dfl = m.reg_max > 1
self.assigner = TaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
self.assigner = TaskAlignedAssigner(topk=tal_topk, num_classes=self.nc, alpha=0.5, beta=6.0)
self.bbox_loss = BboxLoss(m.reg_max - 1, use_dfl=self.use_dfl).to(device)
self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device)
@ -159,7 +198,8 @@ class v8DetectionLoss:
loss = torch.zeros(3, device=self.device) # box, cls, dfl
feats = preds[1] if isinstance(preds, tuple) else preds
pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
(self.reg_max * 4, self.nc), 1)
(self.reg_max * 4, self.nc), 1
)
pred_scores = pred_scores.permute(0, 2, 1).contiguous()
pred_distri = pred_distri.permute(0, 2, 1).contiguous()
@ -169,30 +209,36 @@ class v8DetectionLoss:
imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w)
anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
# targets
targets = torch.cat((batch['batch_idx'].view(-1, 1), batch['cls'].view(-1, 1), batch['bboxes']), 1)
# Targets
targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
# pboxes
# Pboxes
pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)
_, target_bboxes, target_scores, fg_mask, _ = self.assigner(
pred_scores.detach().sigmoid(), (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt)
pred_scores.detach().sigmoid(),
(pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
anchor_points * stride_tensor,
gt_labels,
gt_bboxes,
mask_gt,
)
target_scores_sum = max(target_scores.sum(), 1)
# cls loss
# Cls loss
# loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way
loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE
# bbox loss
# Bbox loss
if fg_mask.sum():
target_bboxes /= stride_tensor
loss[0], loss[2] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores,
target_scores_sum, fg_mask)
loss[0], loss[2] = self.bbox_loss(
pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
)
loss[0] *= self.hyp.box # box gain
loss[1] *= self.hyp.cls # cls gain
@ -205,8 +251,8 @@ class v8SegmentationLoss(v8DetectionLoss):
"""Criterion class for computing training losses."""
def __init__(self, model): # model must be de-paralleled
"""Initializes the v8SegmentationLoss class, taking a de-paralleled model as argument."""
super().__init__(model)
self.nm = model.model[-1].nm # number of masks
self.overlap = model.args.overlap_mask
def __call__(self, preds, batch):
@ -215,9 +261,10 @@ class v8SegmentationLoss(v8DetectionLoss):
feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
batch_size, _, mask_h, mask_w = proto.shape # batch size, number of masks, mask height, mask width
pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
(self.reg_max * 4, self.nc), 1)
(self.reg_max * 4, self.nc), 1
)
# b, grids, ..
# B, grids, ..
pred_scores = pred_scores.permute(0, 2, 1).contiguous()
pred_distri = pred_distri.permute(0, 2, 1).contiguous()
pred_masks = pred_masks.permute(0, 2, 1).contiguous()
@ -226,80 +273,168 @@ class v8SegmentationLoss(v8DetectionLoss):
imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w)
anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
# targets
# Targets
try:
batch_idx = batch['batch_idx'].view(-1, 1)
targets = torch.cat((batch_idx, batch['cls'].view(-1, 1), batch['bboxes']), 1)
batch_idx = batch["batch_idx"].view(-1, 1)
targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
except RuntimeError as e:
raise TypeError('ERROR ❌ segment dataset incorrectly formatted or not a segment dataset.\n'
"This error can occur when incorrectly training a 'segment' model on a 'detect' dataset, "
"i.e. 'yolo train model=yolov8n-seg.pt data=coco128.yaml'.\nVerify your dataset is a "
"correctly formatted 'segment' dataset using 'data=coco128-seg.yaml' "
'as an example.\nSee https://docs.ultralytics.com/tasks/segment/ for help.') from e
raise TypeError(
"ERROR ❌ segment dataset incorrectly formatted or not a segment dataset.\n"
"This error can occur when incorrectly training a 'segment' model on a 'detect' dataset, "
"i.e. 'yolo train model=yolov8n-seg.pt data=coco8.yaml'.\nVerify your dataset is a "
"correctly formatted 'segment' dataset using 'data=coco8-seg.yaml' "
"as an example.\nSee https://docs.ultralytics.com/datasets/segment/ for help."
) from e
# pboxes
# Pboxes
pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)
_, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
pred_scores.detach().sigmoid(), (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt)
pred_scores.detach().sigmoid(),
(pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
anchor_points * stride_tensor,
gt_labels,
gt_bboxes,
mask_gt,
)
target_scores_sum = max(target_scores.sum(), 1)
# cls loss
# Cls loss
# loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way
loss[2] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE
if fg_mask.sum():
# bbox loss
loss[0], loss[3] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes / stride_tensor,
target_scores, target_scores_sum, fg_mask)
# masks loss
masks = batch['masks'].to(self.device).float()
# Bbox loss
loss[0], loss[3] = self.bbox_loss(
pred_distri,
pred_bboxes,
anchor_points,
target_bboxes / stride_tensor,
target_scores,
target_scores_sum,
fg_mask,
)
# Masks loss
masks = batch["masks"].to(self.device).float()
if tuple(masks.shape[-2:]) != (mask_h, mask_w): # downsample
masks = F.interpolate(masks[None], (mask_h, mask_w), mode='nearest')[0]
masks = F.interpolate(masks[None], (mask_h, mask_w), mode="nearest")[0]
for i in range(batch_size):
if fg_mask[i].sum():
mask_idx = target_gt_idx[i][fg_mask[i]]
if self.overlap:
gt_mask = torch.where(masks[[i]] == (mask_idx + 1).view(-1, 1, 1), 1.0, 0.0)
else:
gt_mask = masks[batch_idx.view(-1) == i][mask_idx]
xyxyn = target_bboxes[i][fg_mask[i]] / imgsz[[1, 0, 1, 0]]
marea = xyxy2xywh(xyxyn)[:, 2:].prod(1)
mxyxy = xyxyn * torch.tensor([mask_w, mask_h, mask_w, mask_h], device=self.device)
loss[1] += self.single_mask_loss(gt_mask, pred_masks[i][fg_mask[i]], proto[i], mxyxy, marea) # seg
# WARNING: lines below prevents Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
else:
loss[1] += (proto * 0).sum() + (pred_masks * 0).sum() # inf sums may lead to nan loss
loss[1] = self.calculate_segmentation_loss(
fg_mask, masks, target_gt_idx, target_bboxes, batch_idx, proto, pred_masks, imgsz, self.overlap
)
# WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
else:
loss[1] += (proto * 0).sum() + (pred_masks * 0).sum() # inf sums may lead to nan loss
loss[0] *= self.hyp.box # box gain
loss[1] *= self.hyp.box / batch_size # seg gain
loss[1] *= self.hyp.box # seg gain
loss[2] *= self.hyp.cls # cls gain
loss[3] *= self.hyp.dfl # dfl gain
return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl)
@staticmethod
def single_mask_loss(
    gt_mask: torch.Tensor, pred: torch.Tensor, proto: torch.Tensor, xyxy: torch.Tensor, area: torch.Tensor
) -> torch.Tensor:
    """
    Compute the instance segmentation loss for a single image.

    Args:
        gt_mask (torch.Tensor): Ground truth masks of shape (n, H, W), where n is the number of objects.
        pred (torch.Tensor): Predicted mask coefficients of shape (n, nm).
        proto (torch.Tensor): Prototype masks of shape (nm, H, W).
        xyxy (torch.Tensor): Boxes in xyxy format of shape (n, 4), used to crop the per-pixel loss. They must be
            expressed in the coordinate system of the (H, W) mask grid, not normalized to [0, 1].
        area (torch.Tensor): Area of each box of shape (n,), used to normalize the cropped loss per object.

    Returns:
        (torch.Tensor): The mask loss summed over the n objects of the image.

    Notes:
        The predicted masks are produced with pred_mask = torch.einsum('in,nhw->ihw', pred, proto), i.e. a linear
        combination of the prototype masks weighted by the predicted coefficients.
        The result is a sum (not a mean) over objects; averaging is left to the caller.
    """
    pred_mask = torch.einsum("in,nhw->ihw", pred, proto)  # (n, nm) @ (nm, H, W) -> (n, H, W)
    loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction="none")
    # Only pixels inside each box contribute; dividing by the box area balances small and large objects
    return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).sum()
def calculate_segmentation_loss(
    self,
    fg_mask: torch.Tensor,
    masks: torch.Tensor,
    target_gt_idx: torch.Tensor,
    target_bboxes: torch.Tensor,
    batch_idx: torch.Tensor,
    proto: torch.Tensor,
    pred_masks: torch.Tensor,
    imgsz: torch.Tensor,
    overlap: bool,
) -> torch.Tensor:
    """
    Calculate the loss for instance segmentation.

    Args:
        fg_mask (torch.Tensor): A binary tensor of shape (BS, N_anchors) indicating which anchors are positive.
        masks (torch.Tensor): Ground truth masks. If `overlap` is True, one mask per image of shape (BS, H, W)
            whose pixel values are the 1-based object index; otherwise one mask per label of shape
            (N_labels_in_batch, H, W).
        target_gt_idx (torch.Tensor): Indexes of ground truth objects for each anchor of shape (BS, N_anchors).
        target_bboxes (torch.Tensor): Ground truth bounding boxes for each anchor of shape (BS, N_anchors, 4).
        batch_idx (torch.Tensor): Batch indices of shape (N_labels_in_batch, 1).
        proto (torch.Tensor): Prototype masks of shape (BS, 32, H, W).
        pred_masks (torch.Tensor): Predicted masks for each anchor of shape (BS, N_anchors, 32).
        imgsz (torch.Tensor): Size of the input image as a tensor of shape (2), i.e., (H, W).
        overlap (bool): Whether the masks in `masks` tensor overlap.

    Returns:
        (torch.Tensor): The calculated loss for instance segmentation, averaged over all positive anchors.

    Notes:
        The batch loss can be computed for improved speed at higher memory usage.
        For example, pred_mask can be computed as follows:
            pred_mask = torch.einsum('in,nhw->ihw', pred, proto)  # (i, 32) @ (32, 160, 160) -> (i, 160, 160)
        The caller is expected to invoke this only when `fg_mask` has at least one positive anchor, since the
        result is divided by `fg_mask.sum()`.
    """
    _, _, mask_h, mask_w = proto.shape
    loss = 0

    # Normalize to 0-1
    target_bboxes_normalized = target_bboxes / imgsz[[1, 0, 1, 0]]

    # Areas of target bboxes
    marea = xyxy2xywh(target_bboxes_normalized)[..., 2:].prod(2)

    # Normalize to mask size
    mxyxy = target_bboxes_normalized * torch.tensor([mask_w, mask_h, mask_w, mask_h], device=proto.device)

    label_image_idx = batch_idx.view(-1)  # image index of every label, hoisted out of the loop

    # NOTE: `masks` is deliberately NOT part of the zip. With overlap=False its first dimension is the number of
    # labels, not the batch size, so zipping it would silently stop early whenever labels < batch size.
    per_image = zip(fg_mask, target_gt_idx, pred_masks, proto, mxyxy, marea)
    for i, (fg_mask_i, target_gt_idx_i, pred_masks_i, proto_i, mxyxy_i, marea_i) in enumerate(per_image):
        if fg_mask_i.any():
            mask_idx = target_gt_idx_i[fg_mask_i]
            if overlap:
                # Pixel values encode the 1-based object index, so compare to build one binary mask per anchor
                gt_mask = (masks[i] == (mask_idx + 1).view(-1, 1, 1)).float()
            else:
                gt_mask = masks[label_image_idx == i][mask_idx]

            loss += self.single_mask_loss(
                gt_mask, pred_masks_i[fg_mask_i], proto_i, mxyxy_i[fg_mask_i], marea_i[fg_mask_i]
            )

        # WARNING: lines below prevents Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
        else:
            loss += (proto * 0).sum() + (pred_masks * 0).sum()  # inf sums may lead to nan loss

    return loss / fg_mask.sum()
class v8PoseLoss(v8DetectionLoss):
"""Criterion class for computing training losses."""
def __init__(self, model): # model must be de-paralleled
"""Initializes v8PoseLoss with model, sets keypoint variables and declares a keypoint loss instance."""
super().__init__(model)
self.kpt_shape = model.model[-1].kpt_shape
self.bce_pose = nn.BCEWithLogitsLoss()
@ -313,9 +448,10 @@ class v8PoseLoss(v8DetectionLoss):
loss = torch.zeros(5, device=self.device) # box, cls, dfl, kpt_location, kpt_visibility
feats, pred_kpts = preds if isinstance(preds[0], list) else preds[1]
pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
(self.reg_max * 4, self.nc), 1)
(self.reg_max * 4, self.nc), 1
)
# b, grids, ..
# B, grids, ..
pred_scores = pred_scores.permute(0, 2, 1).contiguous()
pred_distri = pred_distri.permute(0, 2, 1).contiguous()
pred_kpts = pred_kpts.permute(0, 2, 1).contiguous()
@ -324,53 +460,50 @@ class v8PoseLoss(v8DetectionLoss):
imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w)
anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
# targets
# Targets
batch_size = pred_scores.shape[0]
batch_idx = batch['batch_idx'].view(-1, 1)
targets = torch.cat((batch_idx, batch['cls'].view(-1, 1), batch['bboxes']), 1)
batch_idx = batch["batch_idx"].view(-1, 1)
targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
# pboxes
# Pboxes
pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)
pred_kpts = self.kpts_decode(anchor_points, pred_kpts.view(batch_size, -1, *self.kpt_shape)) # (b, h*w, 17, 3)
_, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
pred_scores.detach().sigmoid(), (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt)
pred_scores.detach().sigmoid(),
(pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
anchor_points * stride_tensor,
gt_labels,
gt_bboxes,
mask_gt,
)
target_scores_sum = max(target_scores.sum(), 1)
# cls loss
# Cls loss
# loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way
loss[3] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE
# bbox loss
# Bbox loss
if fg_mask.sum():
target_bboxes /= stride_tensor
loss[0], loss[4] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores,
target_scores_sum, fg_mask)
keypoints = batch['keypoints'].to(self.device).float().clone()
loss[0], loss[4] = self.bbox_loss(
pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
)
keypoints = batch["keypoints"].to(self.device).float().clone()
keypoints[..., 0] *= imgsz[1]
keypoints[..., 1] *= imgsz[0]
for i in range(batch_size):
if fg_mask[i].sum():
idx = target_gt_idx[i][fg_mask[i]]
gt_kpt = keypoints[batch_idx.view(-1) == i][idx] # (n, 51)
gt_kpt[..., 0] /= stride_tensor[fg_mask[i]]
gt_kpt[..., 1] /= stride_tensor[fg_mask[i]]
area = xyxy2xywh(target_bboxes[i][fg_mask[i]])[:, 2:].prod(1, keepdim=True)
pred_kpt = pred_kpts[i][fg_mask[i]]
kpt_mask = gt_kpt[..., 2] != 0
loss[1] += self.keypoint_loss(pred_kpt, gt_kpt, kpt_mask, area) # pose loss
# kpt_score loss
if pred_kpt.shape[-1] == 3:
loss[2] += self.bce_pose(pred_kpt[..., 2], kpt_mask.float()) # keypoint obj loss
loss[1], loss[2] = self.calculate_keypoints_loss(
fg_mask, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes, pred_kpts
)
loss[0] *= self.hyp.box # box gain
loss[1] *= self.hyp.pose / batch_size # pose gain
loss[2] *= self.hyp.kobj / batch_size # kobj gain
loss[1] *= self.hyp.pose # pose gain
loss[2] *= self.hyp.kobj # kobj gain
loss[3] *= self.hyp.cls # cls gain
loss[4] *= self.hyp.dfl # dfl gain
@ -385,12 +518,210 @@ class v8PoseLoss(v8DetectionLoss):
y[..., 1] += anchor_points[:, [1]] - 0.5
return y
def calculate_keypoints_loss(
    self, masks, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes, pred_kpts
):
    """
    Calculate the keypoints loss for the model.

    This function calculates the keypoints loss and keypoints object loss for a given batch. The keypoints loss is
    based on the difference between the predicted keypoints and ground truth keypoints. The keypoints object loss is
    a binary classification loss that classifies whether a keypoint is present or not.

    Args:
        masks (torch.Tensor): Binary mask tensor indicating object presence, shape (BS, N_anchors).
        target_gt_idx (torch.Tensor): Index tensor mapping anchors to ground truth objects, shape (BS, N_anchors).
        keypoints (torch.Tensor): Ground truth keypoints, shape (N_kpts_in_batch, N_kpts_per_object, kpts_dim).
        batch_idx (torch.Tensor): Batch index tensor for keypoints, shape (N_kpts_in_batch, 1).
        stride_tensor (torch.Tensor): Stride tensor for anchors, shape (N_anchors, 1).
        target_bboxes (torch.Tensor): Ground truth boxes in (x1, y1, x2, y2) format, shape (BS, N_anchors, 4).
        pred_kpts (torch.Tensor): Predicted keypoints, shape (BS, N_anchors, N_kpts_per_object, kpts_dim).

    Returns:
        (tuple): Returns a tuple containing:
            - kpts_loss (torch.Tensor): The keypoints loss.
            - kpts_obj_loss (torch.Tensor): The keypoints object loss.
        Both are the integer 0 when `masks` contains no positive anchor.
    """
    batch_idx = batch_idx.flatten()
    batch_size = len(masks)

    # Find the maximum number of ground truth objects (keypoint sets) in a single image
    max_kpts = torch.unique(batch_idx, return_counts=True)[1].max()

    # Create a zero-padded tensor to hold the keypoints grouped per image
    batched_keypoints = torch.zeros(
        (batch_size, max_kpts, keypoints.shape[1], keypoints.shape[2]), device=keypoints.device
    )

    # TODO: any idea how to vectorize this?
    # Fill batched_keypoints with keypoints based on batch_idx
    for i in range(batch_size):
        keypoints_i = keypoints[batch_idx == i]
        batched_keypoints[i, : keypoints_i.shape[0]] = keypoints_i

    # Expand dimensions of target_gt_idx to match the shape of batched_keypoints
    target_gt_idx_expanded = target_gt_idx.unsqueeze(-1).unsqueeze(-1)

    # Use target_gt_idx_expanded to select keypoints from batched_keypoints
    selected_keypoints = batched_keypoints.gather(
        1, target_gt_idx_expanded.expand(-1, -1, keypoints.shape[1], keypoints.shape[2])
    )

    # Divide only the x, y coordinates by stride; a third (visibility) channel, when present, is a flag and
    # must not be rescaled
    selected_keypoints[..., :2] /= stride_tensor.view(1, -1, 1, 1)

    kpts_loss = 0
    kpts_obj_loss = 0

    if masks.any():
        gt_kpt = selected_keypoints[masks]
        area = xyxy2xywh(target_bboxes[masks])[:, 2:].prod(1, keepdim=True)
        pred_kpt = pred_kpts[masks]
        # Without a visibility channel every keypoint is treated as labelled
        kpt_mask = gt_kpt[..., 2] != 0 if gt_kpt.shape[-1] == 3 else torch.full_like(gt_kpt[..., 0], True)
        kpts_loss = self.keypoint_loss(pred_kpt, gt_kpt, kpt_mask, area)  # pose loss

        if pred_kpt.shape[-1] == 3:
            kpts_obj_loss = self.bce_pose(pred_kpt[..., 2], kpt_mask.float())  # keypoint obj loss

    return kpts_loss, kpts_obj_loss
class v8ClassificationLoss:
    """Criterion class for computing training losses."""

    def __call__(self, preds, batch):
        """
        Compute the classification loss between predictions and true labels.

        Args:
            preds (torch.Tensor): Class logits, passed unchanged to `cross_entropy`.
            batch (dict): Batch dictionary; only the "cls" entry (target classes) is read.

        Returns:
            (tuple): The mean cross-entropy loss, and the same value detached from the graph for logging.
        """
        # Mean reduction keeps the loss scale independent of the batch size
        loss = torch.nn.functional.cross_entropy(preds, batch["cls"], reduction="mean")
        loss_items = loss.detach()
        return loss, loss_items
class v8OBBLoss(v8DetectionLoss):
    """Criterion class for computing training losses of oriented (rotated) bounding box models."""

    def __init__(self, model):
        """
        Initializes v8OBBLoss with model, assigner, and rotated bbox loss.

        Note model must be de-paralleled.
        """
        super().__init__(model)
        self.assigner = RotatedTaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
        self.bbox_loss = RotatedBboxLoss(self.reg_max - 1, use_dfl=self.use_dfl).to(self.device)

    def preprocess(self, targets, batch_size, scale_tensor):
        """
        Group the flat target rows per image into a zero-padded tensor.

        Args:
            targets (torch.Tensor): Rows of (image index, class, 5 box values), shape (N, 7).
            batch_size (int): Number of images in the batch.
            scale_tensor (torch.Tensor): Factors applied to the first four box values of every target.

        Returns:
            (torch.Tensor): Tensor of shape (batch_size, max targets per image, 6) holding (class, 5 box values),
                zero-padded for images with fewer targets.
        """
        if targets.shape[0] == 0:
            out = torch.zeros(batch_size, 0, 6, device=self.device)
        else:
            i = targets[:, 0]  # image index
            _, counts = i.unique(return_counts=True)
            counts = counts.to(dtype=torch.int32)
            out = torch.zeros(batch_size, counts.max(), 6, device=self.device)
            for j in range(batch_size):
                matches = i == j
                n = matches.sum()
                if n:
                    bboxes = targets[matches, 2:]
                    bboxes[..., :4].mul_(scale_tensor)  # the fifth value (angle) is left unscaled
                    out[j, :n] = torch.cat([targets[matches, 1:2], bboxes], dim=-1)
        return out

    def __call__(self, preds, batch):
        """
        Calculate and return the loss for the YOLO model.

        Args:
            preds (tuple | list): Either (feats, pred_angle) or a sequence whose second item is that pair.
            batch (dict): Batch dictionary; the "batch_idx", "cls" and "bboxes" entries are read.

        Returns:
            (tuple): The summed loss multiplied by the batch size, and the detached (box, cls, dfl) loss items.

        Raises:
            TypeError: If building the targets raises a RuntimeError, which indicates a wrongly formatted dataset.
        """
        loss = torch.zeros(3, device=self.device)  # box, cls, dfl
        feats, pred_angle = preds if isinstance(preds[0], list) else preds[1]
        batch_size = pred_angle.shape[0]  # batch size
        pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
            (self.reg_max * 4, self.nc), 1
        )

        # b, grids, ..
        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
        pred_distri = pred_distri.permute(0, 2, 1).contiguous()
        pred_angle = pred_angle.permute(0, 2, 1).contiguous()

        dtype = pred_scores.dtype
        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)

        # targets
        try:
            batch_idx = batch["batch_idx"].view(-1, 1)
            targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"].view(-1, 5)), 1)
            # Columns 4 and 5 are box width and height; imgsz is (h, w), so width scales with imgsz[1] and
            # height with imgsz[0], consistent with scale_tensor=imgsz[[1, 0, 1, 0]] below
            rw, rh = targets[:, 4] * imgsz[1].item(), targets[:, 5] * imgsz[0].item()
            targets = targets[(rw >= 2) & (rh >= 2)]  # filter rboxes of tiny size to stabilize training
            targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
            gt_labels, gt_bboxes = targets.split((1, 5), 2)  # cls, xywhr
            mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
        except RuntimeError as e:
            raise TypeError(
                "ERROR ❌ OBB dataset incorrectly formatted or not a OBB dataset.\n"
                "This error can occur when incorrectly training a 'OBB' model on a 'detect' dataset, "
                "i.e. 'yolo train model=yolov8n-obb.pt data=dota8.yaml'.\nVerify your dataset is a "
                "correctly formatted 'OBB' dataset using 'data=dota8.yaml' "
                "as an example.\nSee https://docs.ultralytics.com/datasets/obb/ for help."
            ) from e

        # Pboxes
        pred_bboxes = self.bbox_decode(anchor_points, pred_distri, pred_angle)  # rotated boxes + angle, (b, h*w, 5)

        bboxes_for_assigner = pred_bboxes.clone().detach()
        # Only the first four elements need to be scaled
        bboxes_for_assigner[..., :4] *= stride_tensor
        _, target_bboxes, target_scores, fg_mask, _ = self.assigner(
            pred_scores.detach().sigmoid(),
            bboxes_for_assigner.type(gt_bboxes.dtype),
            anchor_points * stride_tensor,
            gt_labels,
            gt_bboxes,
            mask_gt,
        )

        target_scores_sum = max(target_scores.sum(), 1)

        # Cls loss
        # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
        loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE

        # Bbox loss
        if fg_mask.sum():
            target_bboxes[..., :4] /= stride_tensor
            loss[0], loss[2] = self.bbox_loss(
                pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
            )
        else:
            # Keep the angle branch in the autograd graph when there are no positives
            loss[0] += (pred_angle * 0).sum()

        loss[0] *= self.hyp.box  # box gain
        loss[1] *= self.hyp.cls  # cls gain
        loss[2] *= self.hyp.dfl  # dfl gain

        return loss.sum() * batch_size, loss.detach()  # loss(box, cls, dfl)

    def bbox_decode(self, anchor_points, pred_dist, pred_angle):
        """
        Decode predicted object bounding box coordinates from anchor points and distribution.

        Args:
            anchor_points (torch.Tensor): Anchor points, (h*w, 2).
            pred_dist (torch.Tensor): Predicted rotated distance, (bs, h*w, 4).
            pred_angle (torch.Tensor): Predicted angle, (bs, h*w, 1).

        Returns:
            (torch.Tensor): Predicted rotated bounding boxes with angles, (bs, h*w, 5).
        """
        if self.use_dfl:
            b, a, c = pred_dist.shape  # batch, anchors, channels
            # Expected value of the per-side distance distribution
            pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
        return torch.cat((dist2rbox(pred_dist, pred_angle, anchor_points), pred_angle), dim=-1)
class v10DetectLoss:
    """Criterion combining a one-to-many and a one-to-one detection loss over the same batch."""

    def __init__(self, model):
        """Build the two detection criteria, which differ only in the assigner's `tal_topk` (10 and 1)."""
        self.one2many = v8DetectionLoss(model, tal_topk=10)
        self.one2one = v8DetectionLoss(model, tal_topk=1)

    def __call__(self, preds, batch):
        """
        Evaluate both criteria and merge their outputs.

        Args:
            preds (dict): Predictions with the keys "one2many" and "one2one".
            batch: Batch passed unchanged to both criteria.

        Returns:
            (tuple): Sum of the two first outputs, and the concatenation of the two second outputs.
        """
        many_out = self.one2many(preds["one2many"], batch)
        one_out = self.one2one(preds["one2one"], batch)
        total = many_out[0] + one_out[0]
        items = torch.cat((many_out[1], one_out[1]))
        return total, items

File diff suppressed because it is too large Load Diff

View File

@ -12,6 +12,7 @@ import torch.nn.functional as F
import torchvision
from ultralytics.utils import LOGGER
from ultralytics.utils.metrics import batch_probiou
class Profile(contextlib.ContextDecorator):
@ -22,22 +23,24 @@ class Profile(contextlib.ContextDecorator):
```python
from ultralytics.utils.ops import Profile
with Profile() as dt:
with Profile(device=device) as dt:
pass # slow operation here
print(dt) # prints "Elapsed time is 9.5367431640625e-07 s"
```
"""
def __init__(self, t=0.0):
def __init__(self, t=0.0, device: torch.device = None):
"""
Initialize the Profile class.
Args:
t (float): Initial time. Defaults to 0.0.
device (torch.device): Devices used for model inference. Defaults to None (cpu).
"""
self.t = t
self.cuda = torch.cuda.is_available()
self.device = device
self.cuda = bool(device and str(device).startswith("cuda"))
def __enter__(self):
"""Start timing."""
@ -50,12 +53,13 @@ class Profile(contextlib.ContextDecorator):
self.t += self.dt # accumulate dt
def __str__(self):
return f'Elapsed time is {self.t} s'
"""Returns a human-readable string representing the accumulated elapsed time in the profiler."""
return f"Elapsed time is {self.t} s"
def time(self):
"""Get current time."""
if self.cuda:
torch.cuda.synchronize()
torch.cuda.synchronize(self.device)
return time.time()
@ -71,18 +75,21 @@ def segment2box(segment, width=640, height=640):
Returns:
(np.ndarray): the minimum and maximum x and y values of the segment.
"""
# Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
x, y = segment.T # segment xy
inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
x, y, = x[inside], y[inside]
return np.array([x.min(), y.min(), x.max(), y.max()], dtype=segment.dtype) if any(x) else np.zeros(
4, dtype=segment.dtype) # xyxy
x = x[inside]
y = y[inside]
return (
np.array([x.min(), y.min(), x.max(), y.max()], dtype=segment.dtype)
if any(x)
else np.zeros(4, dtype=segment.dtype)
) # xyxy
def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding=True):
def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding=True, xywh=False):
"""
Rescales bounding boxes (in the format of xyxy) from the shape of the image they were originally specified in
(img1_shape) to the shape of a different image (img0_shape).
Rescales bounding boxes (in the format of xyxy by default) from the shape of the image they were originally
specified in (img1_shape) to the shape of a different image (img0_shape).
Args:
img1_shape (tuple): The shape of the image that the bounding boxes are for, in the format of (height, width).
@ -92,24 +99,29 @@ def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding=True):
calculated based on the size difference between the two images.
padding (bool): If True, assuming the boxes is based on image augmented by yolo style. If False then do regular
rescaling.
xywh (bool): The box format is xywh or not, default=False.
Returns:
boxes (torch.Tensor): The scaled bounding boxes, in the format of (x1, y1, x2, y2)
"""
if ratio_pad is None: # calculate from img0_shape
gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
pad = round((img1_shape[1] - img0_shape[1] * gain) / 2 - 0.1), round(
(img1_shape[0] - img0_shape[0] * gain) / 2 - 0.1) # wh padding
pad = (
round((img1_shape[1] - img0_shape[1] * gain) / 2 - 0.1),
round((img1_shape[0] - img0_shape[0] * gain) / 2 - 0.1),
) # wh padding
else:
gain = ratio_pad[0][0]
pad = ratio_pad[1]
if padding:
boxes[..., [0, 2]] -= pad[0] # x padding
boxes[..., [1, 3]] -= pad[1] # y padding
boxes[..., 0] -= pad[0] # x padding
boxes[..., 1] -= pad[1] # y padding
if not xywh:
boxes[..., 2] -= pad[0] # x padding
boxes[..., 3] -= pad[1] # y padding
boxes[..., :4] /= gain
clip_boxes(boxes, img0_shape)
return boxes
return clip_boxes(boxes, img0_shape)
def make_divisible(x, divisor):
@ -128,19 +140,41 @@ def make_divisible(x, divisor):
return math.ceil(x / divisor) * divisor
def nms_rotated(boxes, scores, threshold=0.45):
    """
    NMS for obbs, powered by probiou and fast-nms.

    Args:
        boxes (torch.Tensor): (N, 5), xywhr.
        scores (torch.Tensor): (N, ).
        threshold (float): IoU threshold.

    Returns:
        (torch.Tensor): Indices into the input of the boxes that are kept, ordered by descending score. An empty
            index tensor is returned when there are no boxes.
    """
    if len(boxes) == 0:
        # Same type as the non-empty path so callers can index tensors with the result unconditionally
        return torch.empty((0,), dtype=torch.long, device=boxes.device)
    sorted_idx = torch.argsort(scores, descending=True)
    boxes = boxes[sorted_idx]
    # Upper triangle only: column j holds the IoU of box j with every higher-scoring box
    ious = batch_probiou(boxes, boxes).triu_(diagonal=1)
    # Keep a box when none of the higher-scoring boxes overlaps it by the threshold or more
    pick = torch.nonzero(ious.max(dim=0)[0] < threshold).squeeze_(-1)
    return sorted_idx[pick]
def non_max_suppression(
prediction,
conf_thres=0.25,
iou_thres=0.45,
classes=None,
agnostic=False,
multi_label=False,
labels=(),
max_det=300,
nc=0, # number of classes (optional)
max_time_img=0.05,
max_nms=30000,
max_wh=7680,
prediction,
conf_thres=0.25,
iou_thres=0.45,
classes=None,
agnostic=False,
multi_label=False,
labels=(),
max_det=300,
nc=0, # number of classes (optional)
max_time_img=0.05,
max_nms=30000,
max_wh=7680,
in_place=True,
rotated=False,
):
"""
Perform non-maximum suppression (NMS) on a set of boxes, with support for masks and multiple labels per box.
@ -164,7 +198,8 @@ def non_max_suppression(
nc (int, optional): The number of classes output by the model. Any indices after this will be considered masks.
max_time_img (float): The maximum time (seconds) for processing one image.
max_nms (int): The maximum number of boxes into torchvision.ops.nms().
max_wh (int): The maximum box width and height in pixels
max_wh (int): The maximum box width and height in pixels.
in_place (bool): If True, the input prediction tensor will be modified in place.
Returns:
(List[torch.Tensor]): A list of length batch_size, where each element is a tensor of
@ -173,15 +208,11 @@ def non_max_suppression(
"""
# Checks
assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
assert 0 <= conf_thres <= 1, f"Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0"
assert 0 <= iou_thres <= 1, f"Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0"
if isinstance(prediction, (list, tuple)): # YOLOv8 model in validation model, output = (inference_out, loss_out)
prediction = prediction[0] # select only inference output
device = prediction.device
mps = 'mps' in device.type # Apple MPS
if mps: # MPS not fully supported yet, convert tensors to CPU before NMS
prediction = prediction.cpu()
bs = prediction.shape[0] # batch size
nc = nc or (prediction.shape[1] - 4) # number of classes
nm = prediction.shape[1] - nc - 4
@ -190,11 +221,15 @@ def non_max_suppression(
# Settings
# min_wh = 2 # (pixels) minimum box width and height
time_limit = 0.5 + max_time_img * bs # seconds to quit after
time_limit = 2.0 + max_time_img * bs # seconds to quit after
multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
prediction = prediction.transpose(-1, -2) # shape(1,84,6300) to shape(1,6300,84)
prediction[..., :4] = xywh2xyxy(prediction[..., :4]) # xywh to xyxy
if not rotated:
if in_place:
prediction[..., :4] = xywh2xyxy(prediction[..., :4]) # xywh to xyxy
else:
prediction = torch.cat((xywh2xyxy(prediction[..., :4]), prediction[..., 4:]), dim=-1) # xywh to xyxy
t = time.time()
output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
@ -204,7 +239,7 @@ def non_max_suppression(
x = x[xc[xi]] # confidence
# Cat apriori labels if autolabelling
if labels and len(labels[xi]):
if labels and len(labels[xi]) and not rotated:
lb = labels[xi]
v = torch.zeros((len(lb), nc + nm + 4), device=x.device)
v[:, :4] = xywh2xyxy(lb[:, 1:5]) # box
@ -238,8 +273,13 @@ def non_max_suppression(
# Batched NMS
c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
scores = x[:, 4] # scores
if rotated:
boxes = torch.cat((x[:, :2] + c, x[:, 2:4], x[:, -1:]), dim=-1) # xywhr
i = nms_rotated(boxes, scores, iou_thres)
else:
boxes = x[:, :4] + c # boxes (offset by class)
i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
i = i[:max_det] # limit detections
# # Experimental
@ -247,7 +287,7 @@ def non_max_suppression(
# if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
# # Update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
# from .metrics import box_iou
# iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
# iou = box_iou(boxes[i], boxes) > iou_thres # IoU matrix
# weights = iou * scores[None] # box weights
# x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
# redundant = True # require redundant detections
@ -255,10 +295,8 @@ def non_max_suppression(
# i = i[iou.sum(1) > 1] # require redundancy
output[xi] = x[i]
if mps:
output[xi] = output[xi].to(device)
if (time.time() - t) > time_limit:
LOGGER.warning(f'WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded')
LOGGER.warning(f"WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded")
break # time limit exceeded
return output
@ -269,17 +307,21 @@ def clip_boxes(boxes, shape):
Takes a list of bounding boxes and a shape (height, width) and clips the bounding boxes to the shape.
Args:
boxes (torch.Tensor): the bounding boxes to clip
shape (tuple): the shape of the image
boxes (torch.Tensor): the bounding boxes to clip
shape (tuple): the shape of the image
Returns:
(torch.Tensor | numpy.ndarray): Clipped boxes
"""
if isinstance(boxes, torch.Tensor): # faster individually
boxes[..., 0].clamp_(0, shape[1]) # x1
boxes[..., 1].clamp_(0, shape[0]) # y1
boxes[..., 2].clamp_(0, shape[1]) # x2
boxes[..., 3].clamp_(0, shape[0]) # y2
if isinstance(boxes, torch.Tensor): # faster individually (WARNING: inplace .clamp_() Apple MPS bug)
boxes[..., 0] = boxes[..., 0].clamp(0, shape[1]) # x1
boxes[..., 1] = boxes[..., 1].clamp(0, shape[0]) # y1
boxes[..., 2] = boxes[..., 2].clamp(0, shape[1]) # x2
boxes[..., 3] = boxes[..., 3].clamp(0, shape[0]) # y2
else: # np.array (faster grouped)
boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, shape[1]) # x1, x2
boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, shape[0]) # y1, y2
return boxes
def clip_coords(coords, shape):
@ -291,19 +333,20 @@ def clip_coords(coords, shape):
shape (tuple): A tuple of integers representing the size of the image in the format (height, width).
Returns:
(None): The function modifies the input `coordinates` in place, by clipping each coordinate to the image boundaries.
(torch.Tensor | numpy.ndarray): Clipped coordinates
"""
if isinstance(coords, torch.Tensor): # faster individually
coords[..., 0].clamp_(0, shape[1]) # x
coords[..., 1].clamp_(0, shape[0]) # y
if isinstance(coords, torch.Tensor): # faster individually (WARNING: inplace .clamp_() Apple MPS bug)
coords[..., 0] = coords[..., 0].clamp(0, shape[1]) # x
coords[..., 1] = coords[..., 1].clamp(0, shape[0]) # y
else: # np.array (faster grouped)
coords[..., 0] = coords[..., 0].clip(0, shape[1]) # x
coords[..., 1] = coords[..., 1].clip(0, shape[0]) # y
return coords
def scale_image(masks, im0_shape, ratio_pad=None):
"""
Takes a mask, and resizes it to the original image size
Takes a mask, and resizes it to the original image size.
Args:
masks (np.ndarray): resized and padded masks/images, [h, w, num]/[h, w, 3].
@ -321,7 +364,7 @@ def scale_image(masks, im0_shape, ratio_pad=None):
gain = min(im1_shape[0] / im0_shape[0], im1_shape[1] / im0_shape[1]) # gain = old / new
pad = (im1_shape[1] - im0_shape[1] * gain) / 2, (im1_shape[0] - im0_shape[0] * gain) / 2 # wh padding
else:
gain = ratio_pad[0][0]
# gain = ratio_pad[0][0]
pad = ratio_pad[1]
top, left = int(pad[1]), int(pad[0]) # y, x
bottom, right = int(im1_shape[0] - pad[1]), int(im1_shape[1] - pad[0])
@ -347,7 +390,7 @@ def xyxy2xywh(x):
Returns:
y (np.ndarray | torch.Tensor): The bounding box coordinates in (x, y, width, height) format.
"""
assert x.shape[-1] == 4, f'input shape last dimension expected 4 but input shape is {x.shape}'
assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x) # faster than clone/copy
y[..., 0] = (x[..., 0] + x[..., 2]) / 2 # x center
y[..., 1] = (x[..., 1] + x[..., 3]) / 2 # y center
@ -367,7 +410,7 @@ def xywh2xyxy(x):
Returns:
y (np.ndarray | torch.Tensor): The bounding box coordinates in (x1, y1, x2, y2) format.
"""
assert x.shape[-1] == 4, f'input shape last dimension expected 4 but input shape is {x.shape}'
assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x) # faster than clone/copy
dw = x[..., 2] / 2 # half-width
dh = x[..., 3] / 2 # half-height
@ -392,7 +435,7 @@ def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
y (np.ndarray | torch.Tensor): The coordinates of the bounding box in the format [x1, y1, x2, y2] where
x1,y1 is the top-left corner, x2,y2 is the bottom-right corner of the bounding box.
"""
assert x.shape[-1] == 4, f'input shape last dimension expected 4 but input shape is {x.shape}'
assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x) # faster than clone/copy
y[..., 0] = w * (x[..., 0] - x[..., 2] / 2) + padw # top left x
y[..., 1] = h * (x[..., 1] - x[..., 3] / 2) + padh # top left y
@ -403,8 +446,8 @@ def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
"""
Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height, normalized) format.
x, y, width and height are normalized to image dimensions
Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height, normalized) format. x, y,
width and height are normalized to image dimensions.
Args:
x (np.ndarray | torch.Tensor): The input bounding box coordinates in (x1, y1, x2, y2) format.
@ -417,8 +460,8 @@ def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
y (np.ndarray | torch.Tensor): The bounding box coordinates in (x, y, width, height, normalized) format
"""
if clip:
clip_boxes(x, (h - eps, w - eps)) # warning: inplace clip
assert x.shape[-1] == 4, f'input shape last dimension expected 4 but input shape is {x.shape}'
x = clip_boxes(x, (h - eps, w - eps))
assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x) # faster than clone/copy
y[..., 0] = ((x[..., 0] + x[..., 2]) / 2) / w # x center
y[..., 1] = ((x[..., 1] + x[..., 3]) / 2) / h # y center
@ -445,7 +488,7 @@ def xywh2ltwh(x):
def xyxy2ltwh(x):
"""
Convert nx4 bounding boxes from [x1, y1, x2, y2] to [x1, y1, w, h], where xy1=top-left, xy2=bottom-right
Convert nx4 bounding boxes from [x1, y1, x2, y2] to [x1, y1, w, h], where xy1=top-left, xy2=bottom-right.
Args:
x (np.ndarray | torch.Tensor): The input tensor with the bounding boxes coordinates in the xyxy format
@ -461,7 +504,7 @@ def xyxy2ltwh(x):
def ltwh2xywh(x):
"""
Convert nx4 boxes from [x1, y1, w, h] to [x, y, w, h] where xy1=top-left, xy=center
Convert nx4 boxes from [x1, y1, w, h] to [x, y, w, h] where xy1=top-left, xy=center.
Args:
x (torch.Tensor): the input tensor
@ -477,7 +520,8 @@ def ltwh2xywh(x):
def xyxyxyxy2xywhr(corners):
"""
Convert batched Oriented Bounding Boxes (OBB) from [xy1, xy2, xy3, xy4] to [xywh, rotation].
Convert batched Oriented Bounding Boxes (OBB) from [xy1, xy2, xy3, xy4] to [xywh, rotation]. Rotation values are
returned in radians (the angle reported by cv2.minAreaRect is converted with angle / 180 * pi).
Args:
corners (numpy.ndarray | torch.Tensor): Input corners of shape (n, 8).
@ -485,66 +529,53 @@ def xyxyxyxy2xywhr(corners):
Returns:
(numpy.ndarray | torch.Tensor): Converted data in [cx, cy, w, h, rotation] format of shape (n, 5).
"""
is_numpy = isinstance(corners, np.ndarray)
atan2, sqrt = (np.arctan2, np.sqrt) if is_numpy else (torch.atan2, torch.sqrt)
x1, y1, x2, y2, x3, y3, x4, y4 = corners.T
cx = (x1 + x3) / 2
cy = (y1 + y3) / 2
dx21 = x2 - x1
dy21 = y2 - y1
w = sqrt(dx21 ** 2 + dy21 ** 2)
h = sqrt((x2 - x3) ** 2 + (y2 - y3) ** 2)
rotation = atan2(-dy21, dx21)
rotation *= 180.0 / math.pi # radians to degrees
return np.vstack((cx, cy, w, h, rotation)).T if is_numpy else torch.stack((cx, cy, w, h, rotation), dim=1)
is_torch = isinstance(corners, torch.Tensor)
points = corners.cpu().numpy() if is_torch else corners
points = points.reshape(len(corners), -1, 2)
rboxes = []
for pts in points:
# NOTE: Use cv2.minAreaRect to get accurate xywhr,
# especially some objects are cut off by augmentations in dataloader.
(x, y), (w, h), angle = cv2.minAreaRect(pts)
rboxes.append([x, y, w, h, angle / 180 * np.pi])
return (
torch.tensor(rboxes, device=corners.device, dtype=corners.dtype)
if is_torch
else np.asarray(rboxes, dtype=points.dtype)
) # rboxes
def xywhr2xyxyxyxy(center):
def xywhr2xyxyxyxy(rboxes):
"""
Convert batched Oriented Bounding Boxes (OBB) from [xywh, rotation] to [xy1, xy2, xy3, xy4].
Convert batched Oriented Bounding Boxes (OBB) from [xywh, rotation] to [xy1, xy2, xy3, xy4]. Rotation values should
be in radians, since they are passed directly to cos/sin.
Args:
center (numpy.ndarray | torch.Tensor): Input data in [cx, cy, w, h, rotation] format of shape (n, 5).
rboxes (numpy.ndarray | torch.Tensor): Boxes in [cx, cy, w, h, rotation] format of shape (n, 5) or (b, n, 5).
Returns:
(numpy.ndarray | torch.Tensor): Converted corner points of shape (n, 8).
(numpy.ndarray | torch.Tensor): Converted corner points of shape (n, 4, 2) or (b, n, 4, 2).
"""
is_numpy = isinstance(center, np.ndarray)
is_numpy = isinstance(rboxes, np.ndarray)
cos, sin = (np.cos, np.sin) if is_numpy else (torch.cos, torch.sin)
cx, cy, w, h, rotation = center.T
rotation *= math.pi / 180.0 # degrees to radians
dx = w / 2
dy = h / 2
cos_rot = cos(rotation)
sin_rot = sin(rotation)
dx_cos_rot = dx * cos_rot
dx_sin_rot = dx * sin_rot
dy_cos_rot = dy * cos_rot
dy_sin_rot = dy * sin_rot
x1 = cx - dx_cos_rot - dy_sin_rot
y1 = cy + dx_sin_rot - dy_cos_rot
x2 = cx + dx_cos_rot - dy_sin_rot
y2 = cy - dx_sin_rot - dy_cos_rot
x3 = cx + dx_cos_rot + dy_sin_rot
y3 = cy - dx_sin_rot + dy_cos_rot
x4 = cx - dx_cos_rot + dy_sin_rot
y4 = cy + dx_sin_rot + dy_cos_rot
return np.vstack((x1, y1, x2, y2, x3, y3, x4, y4)).T if is_numpy else torch.stack(
(x1, y1, x2, y2, x3, y3, x4, y4), dim=1)
ctr = rboxes[..., :2]
w, h, angle = (rboxes[..., i : i + 1] for i in range(2, 5))
cos_value, sin_value = cos(angle), sin(angle)
vec1 = [w / 2 * cos_value, w / 2 * sin_value]
vec2 = [-h / 2 * sin_value, h / 2 * cos_value]
vec1 = np.concatenate(vec1, axis=-1) if is_numpy else torch.cat(vec1, dim=-1)
vec2 = np.concatenate(vec2, axis=-1) if is_numpy else torch.cat(vec2, dim=-1)
pt1 = ctr + vec1 + vec2
pt2 = ctr + vec1 - vec2
pt3 = ctr - vec1 - vec2
pt4 = ctr - vec1 + vec2
return np.stack([pt1, pt2, pt3, pt4], axis=-2) if is_numpy else torch.stack([pt1, pt2, pt3, pt4], dim=-2)
def ltwh2xyxy(x):
"""
It converts the bounding box from [x1, y1, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
It converts the bounding box from [x1, y1, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right.
Args:
x (np.ndarray | torch.Tensor): the input image
@ -590,8 +621,9 @@ def resample_segments(segments, n=1000):
s = np.concatenate((s, s[0:1, :]), axis=0)
x = np.linspace(0, len(s) - 1, n)
xp = np.arange(len(s))
segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)],
dtype=np.float32).reshape(2, -1).T # segment xy
segments[i] = (
np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)], dtype=np.float32).reshape(2, -1).T
) # segment xy
return segments
@ -606,7 +638,7 @@ def crop_mask(masks, boxes):
Returns:
(torch.Tensor): The masks are being cropped to the bounding box.
"""
n, h, w = masks.shape
_, h, w = masks.shape
x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1) # x1 shape(n,1,1)
r = torch.arange(w, device=masks.device, dtype=x1.dtype)[None, None, :] # rows shape(1,1,w)
c = torch.arange(h, device=masks.device, dtype=x1.dtype)[None, :, None] # cols shape(1,h,1)
@ -616,8 +648,8 @@ def crop_mask(masks, boxes):
def process_mask_upsample(protos, masks_in, bboxes, shape):
"""
Takes the output of the mask head, and applies the mask to the bounding boxes. This produces masks of higher
quality but is slower.
Takes the output of the mask head, and applies the mask to the bounding boxes. This produces masks of higher quality
but is slower.
Args:
protos (torch.Tensor): [mask_dim, mask_h, mask_w]
@ -630,7 +662,7 @@ def process_mask_upsample(protos, masks_in, bboxes, shape):
"""
c, mh, mw = protos.shape # CHW
masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw)
masks = F.interpolate(masks[None], shape, mode='bilinear', align_corners=False)[0] # CHW
masks = F.interpolate(masks[None], shape, mode="bilinear", align_corners=False)[0] # CHW
masks = crop_mask(masks, bboxes) # CHW
return masks.gt_(0.5)
@ -654,16 +686,18 @@ def process_mask(protos, masks_in, bboxes, shape, upsample=False):
c, mh, mw = protos.shape # CHW
ih, iw = shape
masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw) # CHW
width_ratio = mw / iw
height_ratio = mh / ih
downsampled_bboxes = bboxes.clone()
downsampled_bboxes[:, 0] *= mw / iw
downsampled_bboxes[:, 2] *= mw / iw
downsampled_bboxes[:, 3] *= mh / ih
downsampled_bboxes[:, 1] *= mh / ih
downsampled_bboxes[:, 0] *= width_ratio
downsampled_bboxes[:, 2] *= width_ratio
downsampled_bboxes[:, 3] *= height_ratio
downsampled_bboxes[:, 1] *= height_ratio
masks = crop_mask(masks, downsampled_bboxes) # CHW
if upsample:
masks = F.interpolate(masks[None], shape, mode='bilinear', align_corners=False)[0] # CHW
masks = F.interpolate(masks[None], shape, mode="bilinear", align_corners=False)[0] # CHW
return masks.gt_(0.5)
@ -707,13 +741,13 @@ def scale_masks(masks, shape, padding=True):
bottom, right = (int(mh - pad[1]), int(mw - pad[0]))
masks = masks[..., top:bottom, left:right]
masks = F.interpolate(masks, shape, mode='bilinear', align_corners=False) # NCHW
masks = F.interpolate(masks, shape, mode="bilinear", align_corners=False) # NCHW
return masks
def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize=False, padding=True):
"""
Rescale segment coordinates (xy) from img1_shape to img0_shape
Rescale segment coordinates (xy) from img1_shape to img0_shape.
Args:
img1_shape (tuple): The shape of the image that the coords are from.
@ -739,14 +773,32 @@ def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize=False
coords[..., 1] -= pad[1] # y padding
coords[..., 0] /= gain
coords[..., 1] /= gain
clip_coords(coords, img0_shape)
coords = clip_coords(coords, img0_shape)
if normalize:
coords[..., 0] /= img0_shape[1] # width
coords[..., 1] /= img0_shape[0] # height
return coords
def masks2segments(masks, strategy='largest'):
def regularize_rboxes(rboxes):
    """
    Put rotated boxes into a canonical form in which the first side is the longer one.

    Boxes whose width is not strictly greater than their height get their two sides swapped and a quarter turn
    (pi / 2) added to the angle. Every angle is finally wrapped with a modulo by pi.

    Args:
        rboxes (torch.Tensor): Boxes whose last dimension holds (x, y, w, h, angle). The angle arithmetic uses
            pi, i.e. the angle is treated as radians.

    Returns:
        (torch.Tensor): Boxes with the same layout, where w >= h and the angle lies in [0, pi).
    """
    cx, cy, width, height, theta = rboxes.unbind(dim=-1)
    width_major = width > height  # True where the box already has its long side first
    long_side = torch.where(width_major, width, height)
    short_side = torch.where(width_major, height, width)
    # Swapping the two sides describes the same rectangle only if the angle is turned by 90 degrees as well
    theta = torch.where(width_major, theta, theta + math.pi / 2) % math.pi
    return torch.stack([cx, cy, long_side, short_side, theta], dim=-1)
def masks2segments(masks, strategy="largest"):
"""
It takes a list of masks(n,h,w) and returns a list of segments(n,xy)
@ -758,16 +810,16 @@ def masks2segments(masks, strategy='largest'):
segments (List): list of segment masks
"""
segments = []
for x in masks.int().cpu().numpy().astype('uint8'):
for x in masks.int().cpu().numpy().astype("uint8"):
c = cv2.findContours(x, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
if c:
if strategy == 'concat': # concatenate all segments
if strategy == "concat": # concatenate all segments
c = np.concatenate([x.reshape(-1, 2) for x in c])
elif strategy == 'largest': # select largest segment
elif strategy == "largest": # select largest segment
c = np.array(c[np.array([len(x) for x in c]).argmax()]).reshape(-1, 2)
else:
c = np.zeros((0, 2)) # no segments found
segments.append(c.astype('float32'))
segments.append(c.astype("float32"))
return segments
@ -794,4 +846,19 @@ def clean_str(s):
Returns:
(str): a string with special characters replaced by an underscore _
"""
return re.sub(pattern='[|@#!¡·$€%&()=?¿^*;:,¨´><+]', repl='_', string=s)
return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
def v10postprocess(preds, max_det, nc=80):
    """
    Pick the highest-scoring detections from raw predictions using two top-k passes (no NMS).

    The first pass keeps the anchors with the best per-anchor class score; the second pass ranks every
    (anchor, class) pair among those survivors and keeps the best ones.

    Args:
        preds (torch.Tensor): Predictions of shape (batch, anchors, 4 + nc). The first 4 channels are the box,
            the remaining nc channels are the class scores.
        max_det (int): Maximum number of detections to return per batch element.
        nc (int): Number of classes.

    Returns:
        boxes (torch.Tensor): Selected boxes of shape (batch, k, 4).
        scores (torch.Tensor): Selected scores of shape (batch, k), sorted in descending order.
        labels (torch.Tensor): Class indices of shape (batch, k).

        Here k = min(max_det, anchors).
    """
    assert 4 + nc == preds.shape[-1], f"expected last dimension {4 + nc} (4 + nc) but got {preds.shape[-1]}"
    boxes, scores = preds.split([4, nc], dim=-1)

    # torch.topk raises if k exceeds the size of the dimension, so never ask for more than there are anchors
    k = min(max_det, scores.shape[1])

    # Pass 1: keep the k anchors whose best class score is highest
    _, index = torch.topk(scores.amax(dim=-1), k, dim=-1)
    index = index.unsqueeze(-1)
    boxes = torch.gather(boxes, dim=1, index=index.repeat(1, 1, boxes.shape[-1]))
    scores = torch.gather(scores, dim=1, index=index.repeat(1, 1, scores.shape[-1]))

    # Pass 2: rank all (anchor, class) pairs of the survivors; the flat index encodes anchor * nc + class
    scores, index = torch.topk(scores.flatten(1), k, dim=-1)
    labels = index % nc
    index = index // nc
    boxes = boxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, boxes.shape[-1]))
    return boxes, scores, labels

View File

@ -1,8 +1,7 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
Monkey patches to update/extend functionality of existing functions
"""
"""Monkey patches to update/extend functionality of existing functions."""
import time
from pathlib import Path
import cv2
@ -14,7 +13,8 @@ _imshow = cv2.imshow # copy to avoid recursion errors
def imread(filename: str, flags: int = cv2.IMREAD_COLOR):
"""Read an image from a file.
"""
Read an image from a file.
Args:
filename (str): Path to the file to read.
@ -27,7 +27,8 @@ def imread(filename: str, flags: int = cv2.IMREAD_COLOR):
def imwrite(filename: str, img: np.ndarray, params=None):
"""Write an image to a file.
"""
Write an image to a file.
Args:
filename (str): Path to the file to write.
@ -45,31 +46,43 @@ def imwrite(filename: str, img: np.ndarray, params=None):
def imshow(winname: str, mat: np.ndarray):
"""Displays an image in the specified window.
"""
Displays an image in the specified window.
Args:
winname (str): Name of the window.
mat (np.ndarray): Image to be shown.
"""
_imshow(winname.encode('unicode_escape').decode(), mat)
_imshow(winname.encode("unicode_escape").decode(), mat)
# PyTorch functions ----------------------------------------------------------------------------------------------------
_torch_save = torch.save # copy to avoid recursion errors
def torch_save(*args, **kwargs):
"""Use dill (if exists) to serialize the lambda functions where pickle does not do this.
def torch_save(*args, use_dill=True, **kwargs):
"""
Optionally use dill to serialize lambda functions where pickle does not, adding robustness with 3 retries and
exponential standoff in case of save failure.
Args:
*args (tuple): Positional arguments to pass to torch.save.
**kwargs (dict): Keyword arguments to pass to torch.save.
use_dill (bool): Whether to try using dill for serialization if available. Defaults to True.
**kwargs (any): Keyword arguments to pass to torch.save.
"""
try:
import dill as pickle # noqa
except ImportError:
assert use_dill
import dill as pickle
except (AssertionError, ImportError):
import pickle
if 'pickle_module' not in kwargs:
kwargs['pickle_module'] = pickle # noqa
return _torch_save(*args, **kwargs)
if "pickle_module" not in kwargs:
kwargs["pickle_module"] = pickle
for i in range(4): # 3 retries
try:
return _torch_save(*args, **kwargs)
except RuntimeError as e: # unable to save, possibly waiting for device to flush or antivirus scan
if i == 3:
raise e
time.sleep((2**i) / 2) # exponential standoff: 0.5s, 1.0s, 2.0s

View File

@ -13,7 +13,6 @@ from PIL import Image, ImageDraw, ImageFont
from PIL import __version__ as pil_version
from ultralytics.utils import LOGGER, TryExcept, ops, plt_settings, threaded
from .checks import check_font, check_version, is_ascii
from .files import increment_path
@ -28,20 +27,60 @@ class Colors:
Attributes:
palette (list of tuple): List of RGB color values.
n (int): The number of colors in the palette.
pose_palette (np.array): A specific color palette array with dtype np.uint8.
pose_palette (np.ndarray): A specific color palette array with dtype np.uint8.
"""
def __init__(self):
"""Initialize colors as hex = matplotlib.colors.TABLEAU_COLORS.values()."""
hexs = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
'2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7')
self.palette = [self.hex2rgb(f'#{c}') for c in hexs]
hexs = (
"FF3838",
"FF9D97",
"FF701F",
"FFB21D",
"CFD231",
"48F90A",
"92CC17",
"3DDB86",
"1A9334",
"00D4BB",
"2C99A8",
"00C2FF",
"344593",
"6473FF",
"0018EC",
"8438FF",
"520085",
"CB38FF",
"FF95C8",
"FF37C7",
)
self.palette = [self.hex2rgb(f"#{c}") for c in hexs]
self.n = len(self.palette)
self.pose_palette = np.array([[255, 128, 0], [255, 153, 51], [255, 178, 102], [230, 230, 0], [255, 153, 255],
[153, 204, 255], [255, 102, 255], [255, 51, 255], [102, 178, 255], [51, 153, 255],
[255, 153, 153], [255, 102, 102], [255, 51, 51], [153, 255, 153], [102, 255, 102],
[51, 255, 51], [0, 255, 0], [0, 0, 255], [255, 0, 0], [255, 255, 255]],
dtype=np.uint8)
self.pose_palette = np.array(
[
[255, 128, 0],
[255, 153, 51],
[255, 178, 102],
[230, 230, 0],
[255, 153, 255],
[153, 204, 255],
[255, 102, 255],
[255, 51, 255],
[102, 178, 255],
[51, 153, 255],
[255, 153, 153],
[255, 102, 102],
[255, 51, 51],
[153, 255, 153],
[102, 255, 102],
[51, 255, 51],
[0, 255, 0],
[0, 0, 255],
[255, 0, 0],
[255, 255, 255],
],
dtype=np.uint8,
)
def __call__(self, i, bgr=False):
"""Converts hex color codes to RGB values."""
@ -51,7 +90,7 @@ class Colors:
@staticmethod
def hex2rgb(h):
"""Converts hex color codes to RGB values (i.e. default PIL order)."""
return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4))
return tuple(int(h[1 + i : 1 + i + 2], 16) for i in (0, 2, 4))
colors = Colors() # create instance for 'from utils.plots import colors'
@ -71,65 +110,99 @@ class Annotator:
kpt_color (List[int]): Color palette for keypoints.
"""
def __init__(self, im, line_width=None, font_size=None, font='Arial.ttf', pil=False, example='abc'):
def __init__(self, im, line_width=None, font_size=None, font="Arial.ttf", pil=False, example="abc"):
"""Initialize the Annotator class with image and line width along with color palette for keypoints and limbs."""
assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to Annotator() input images.'
non_ascii = not is_ascii(example) # non-latin labels, i.e. asian, arabic, cyrillic
self.pil = pil or non_ascii
input_is_pil = isinstance(im, Image.Image)
self.pil = pil or non_ascii or input_is_pil
self.lw = line_width or max(round(sum(im.size if input_is_pil else im.shape) / 2 * 0.003), 2)
if self.pil: # use PIL
self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
self.im = im if input_is_pil else Image.fromarray(im)
self.draw = ImageDraw.Draw(self.im)
try:
font = check_font('Arial.Unicode.ttf' if non_ascii else font)
font = check_font("Arial.Unicode.ttf" if non_ascii else font)
size = font_size or max(round(sum(self.im.size) / 2 * 0.035), 12)
self.font = ImageFont.truetype(str(font), size)
except Exception:
self.font = ImageFont.load_default()
# Deprecation fix for w, h = getsize(string) -> _, _, w, h = getbox(string)
if check_version(pil_version, '9.2.0'):
if check_version(pil_version, "9.2.0"):
self.font.getsize = lambda x: self.font.getbbox(x)[2:4] # text width, height
else: # use cv2
self.im = im
self.lw = line_width or max(round(sum(im.shape) / 2 * 0.003), 2) # line width
assert im.data.contiguous, "Image not contiguous. Apply np.ascontiguousarray(im) to Annotator input images."
self.im = im if im.flags.writeable else im.copy()
self.tf = max(self.lw - 1, 1) # font thickness
self.sf = self.lw / 3 # font scale
# Pose
self.skeleton = [[16, 14], [14, 12], [17, 15], [15, 13], [12, 13], [6, 12], [7, 13], [6, 7], [6, 8], [7, 9],
[8, 10], [9, 11], [2, 3], [1, 2], [1, 3], [2, 4], [3, 5], [4, 6], [5, 7]]
self.skeleton = [
[16, 14],
[14, 12],
[17, 15],
[15, 13],
[12, 13],
[6, 12],
[7, 13],
[6, 7],
[6, 8],
[7, 9],
[8, 10],
[9, 11],
[2, 3],
[1, 2],
[1, 3],
[2, 4],
[3, 5],
[4, 6],
[5, 7],
]
self.limb_color = colors.pose_palette[[9, 9, 9, 9, 7, 7, 7, 0, 0, 0, 0, 0, 16, 16, 16, 16, 16, 16, 16]]
self.kpt_color = colors.pose_palette[[16, 16, 16, 16, 16, 0, 0, 0, 0, 0, 0, 9, 9, 9, 9, 9, 9]]
def box_label(self, box, label='', color=(128, 128, 128), txt_color=(255, 255, 255)):
def box_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), rotated=False):
"""Add one xyxy box to image with label."""
if isinstance(box, torch.Tensor):
box = box.tolist()
if self.pil or not is_ascii(label):
self.draw.rectangle(box, width=self.lw, outline=color) # box
if rotated:
p1 = box[0]
# NOTE: PIL-version polygon needs tuple type.
self.draw.polygon([tuple(b) for b in box], width=self.lw, outline=color)
else:
p1 = (box[0], box[1])
self.draw.rectangle(box, width=self.lw, outline=color) # box
if label:
w, h = self.font.getsize(label) # text width, height
outside = box[1] - h >= 0 # label fits outside box
outside = p1[1] - h >= 0 # label fits outside box
self.draw.rectangle(
(box[0], box[1] - h if outside else box[1], box[0] + w + 1,
box[1] + 1 if outside else box[1] + h + 1),
(p1[0], p1[1] - h if outside else p1[1], p1[0] + w + 1, p1[1] + 1 if outside else p1[1] + h + 1),
fill=color,
)
# self.draw.text((box[0], box[1]), label, fill=txt_color, font=self.font, anchor='ls') # for PIL>8.0
self.draw.text((box[0], box[1] - h if outside else box[1]), label, fill=txt_color, font=self.font)
self.draw.text((p1[0], p1[1] - h if outside else p1[1]), label, fill=txt_color, font=self.font)
else: # cv2
p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
cv2.rectangle(self.im, p1, p2, color, thickness=self.lw, lineType=cv2.LINE_AA)
if rotated:
p1 = [int(b) for b in box[0]]
# NOTE: cv2-version polylines needs np.asarray type.
cv2.polylines(self.im, [np.asarray(box, dtype=int)], True, color, self.lw)
else:
p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
cv2.rectangle(self.im, p1, p2, color, thickness=self.lw, lineType=cv2.LINE_AA)
if label:
tf = max(self.lw - 1, 1) # font thickness
w, h = cv2.getTextSize(label, 0, fontScale=self.lw / 3, thickness=tf)[0] # text width, height
w, h = cv2.getTextSize(label, 0, fontScale=self.sf, thickness=self.tf)[0] # text width, height
outside = p1[1] - h >= 3
p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA) # filled
cv2.putText(self.im,
label, (p1[0], p1[1] - 2 if outside else p1[1] + h + 2),
0,
self.lw / 3,
txt_color,
thickness=tf,
lineType=cv2.LINE_AA)
cv2.putText(
self.im,
label,
(p1[0], p1[1] - 2 if outside else p1[1] + h + 2),
0,
self.sf,
txt_color,
thickness=self.tf,
lineType=cv2.LINE_AA,
)
def masks(self, masks, colors, im_gpu, alpha=0.5, retina_masks=False):
"""
@ -154,13 +227,13 @@ class Annotator:
masks = masks.unsqueeze(3) # shape(n,h,w,1)
masks_color = masks * (colors * alpha) # shape(n,h,w,3)
inv_alph_masks = (1 - masks * alpha).cumprod(0) # shape(n,h,w,1)
inv_alpha_masks = (1 - masks * alpha).cumprod(0) # shape(n,h,w,1)
mcs = masks_color.max(dim=0).values # shape(n,h,w,3)
im_gpu = im_gpu.flip(dims=[0]) # flip channel
im_gpu = im_gpu.permute(1, 2, 0).contiguous() # shape(h,w,3)
im_gpu = im_gpu * inv_alph_masks[-1] + mcs
im_mask = (im_gpu * 255)
im_gpu = im_gpu * inv_alpha_masks[-1] + mcs
im_mask = im_gpu * 255
im_mask_np = im_mask.byte().cpu().numpy()
self.im[:] = im_mask_np if retina_masks else ops.scale_image(im_mask_np, self.im.shape)
if self.pil:
@ -178,13 +251,14 @@ class Annotator:
kpt_line (bool, optional): If True, the function will draw lines connecting keypoints
for human pose. Default is True.
Note: `kpt_line=True` currently only supports human pose plotting.
Note:
`kpt_line=True` currently only supports human pose plotting.
"""
if self.pil:
# Convert to numpy first
self.im = np.asarray(self.im).copy()
nkpt, ndim = kpts.shape
is_pose = nkpt == 17 and ndim == 3
is_pose = nkpt == 17 and ndim in {2, 3}
kpt_line &= is_pose # `kpt_line=True` for now only supports human pose plotting
for i, k in enumerate(kpts):
color_k = [int(x) for x in self.kpt_color[i]] if is_pose else colors(i)
@ -219,9 +293,9 @@ class Annotator:
"""Add rectangle to image (PIL-only)."""
self.draw.rectangle(xy, fill, outline, width)
def text(self, xy, text, txt_color=(255, 255, 255), anchor='top', box_style=False):
def text(self, xy, text, txt_color=(255, 255, 255), anchor="top", box_style=False):
"""Adds text to an image using PIL or cv2."""
if anchor == 'bottom': # start y from font bottom
if anchor == "bottom": # start y from font bottom
w, h = self.font.getsize(text) # text width, height
xy[1] += 1 - h
if self.pil:
@ -230,8 +304,8 @@ class Annotator:
self.draw.rectangle((xy[0], xy[1], xy[0] + w + 1, xy[1] + h + 1), fill=txt_color)
# Using `txt_color` for background and draw fg with white color
txt_color = (255, 255, 255)
if '\n' in text:
lines = text.split('\n')
if "\n" in text:
lines = text.split("\n")
_, h = self.font.getsize(text)
for line in lines:
self.draw.text(xy, line, fill=txt_color, font=self.font)
@ -240,15 +314,13 @@ class Annotator:
self.draw.text(xy, text, fill=txt_color, font=self.font)
else:
if box_style:
tf = max(self.lw - 1, 1) # font thickness
w, h = cv2.getTextSize(text, 0, fontScale=self.lw / 3, thickness=tf)[0] # text width, height
w, h = cv2.getTextSize(text, 0, fontScale=self.sf, thickness=self.tf)[0] # text width, height
outside = xy[1] - h >= 3
p2 = xy[0] + w, xy[1] - h - 3 if outside else xy[1] + h + 3
cv2.rectangle(self.im, xy, p2, txt_color, -1, cv2.LINE_AA) # filled
# Using `txt_color` for background and draw fg with white color
txt_color = (255, 255, 255)
tf = max(self.lw - 1, 1) # font thickness
cv2.putText(self.im, text, xy, 0, self.lw / 3, txt_color, thickness=tf, lineType=cv2.LINE_AA)
cv2.putText(self.im, text, xy, 0, self.sf, txt_color, thickness=self.tf, lineType=cv2.LINE_AA)
def fromarray(self, im):
"""Update self.im from a numpy array."""
@ -259,27 +331,289 @@ class Annotator:
"""Return annotated image as array."""
return np.asarray(self.im)
def show(self, title=None):
    """Open the annotated image in a viewer via PIL, after reversing the order of its last (channel) axis."""
    pixels = np.asarray(self.im)
    reversed_channels = pixels[..., ::-1]
    Image.fromarray(reversed_channels).show(title)
def save(self, filename="image.jpg"):
    """Write the annotated image to `filename` using cv2.imwrite."""
    frame = np.asarray(self.im)
    cv2.imwrite(filename, frame)
def draw_region(self, reg_pts=None, color=(0, 255, 0), thickness=5):
    """
    Outline a region on the image as a closed polyline.

    Args:
        reg_pts (list): Points of the region outline; they are converted to an int32 array before drawing.
        color (tuple): Colour of the outline.
        thickness (int): Thickness of the outline.
    """
    outline = np.array(reg_pts, dtype=np.int32)
    cv2.polylines(self.im, [outline], isClosed=True, color=color, thickness=thickness)
def draw_centroid_and_tracks(self, track, color=(255, 0, 255), track_thickness=2):
    """
    Draw a track as an open polyline and mark its most recent point with a filled circle.

    Args:
        track (list): Sequence of (x, y) points of the track; the last entry is the current position.
        color (tuple): Colour used for both the trail and the circle.
        track_thickness (int): Thickness of the trail; the circle radius is twice this value.
    """
    # cv2.polylines expects integer points laid out as (num_points, 1, 2)
    trail = np.hstack(track).astype(np.int32).reshape((-1, 1, 2))
    cv2.polylines(self.im, [trail], isClosed=False, color=color, thickness=track_thickness)

    newest = track[-1]
    centre = (int(newest[0]), int(newest[1]))
    cv2.circle(self.im, centre, track_thickness * 2, color, -1)
def count_labels(self, counts=0, count_txt_size=2, color=(255, 255, 255), txt_color=(0, 0, 0)):
    """
    Draw a counter value, horizontally centred near the top of the image, on a filled background box.

    Args:
        counts (int): Value to display; it is rendered with str().
        count_txt_size (int): Text size. It is stored in `self.tf`, so it also replaces the annotator's font
            thickness for later drawing calls.
        color (tuple): Fill colour of the background box.
        txt_color (tuple): Colour of the text.
    """
    self.tf = count_txt_size
    caption = str(counts)

    # Fall back to a size derived from the image dimensions only when the requested size is falsy (0 / None)
    tl = self.tf or round(0.002 * (self.im.shape[0] + self.im.shape[1]) / 2) + 1
    scale = tl / 2
    measure_thickness = max(tl - 1, 1)

    text_w, text_h = cv2.getTextSize(caption, 0, fontScale=scale, thickness=measure_thickness)[0]

    left = (self.im.shape[1] - text_w) // 2  # centre the text horizontally
    top = text_h

    # Background box with a few pixels of padding around the text
    box_start = (left - 5, top - 5)
    box_end = (left + text_w + 7, top + text_h + 7)
    cv2.rectangle(self.im, box_start, box_end, color, -1)

    baseline = (left, top + text_h)
    cv2.putText(self.im, caption, baseline, 0, scale, txt_color, self.tf, lineType=cv2.LINE_AA)
@staticmethod
def estimate_pose_angle(a, b, c):
    """
    Compute the angle at point `b` formed by the segments b->a and b->c.

    Args:
        a (array-like): First outer point; only its first two entries (x, y) are used.
        b (array-like): Vertex point of the angle; only its first two entries (x, y) are used.
        c (array-like): Second outer point; only its first two entries (x, y) are used.

    Returns:
        (float): Angle in degrees, folded into the range [0, 180].
    """
    first, vertex, last = np.array(a), np.array(b), np.array(c)
    heading_last = np.arctan2(last[1] - vertex[1], last[0] - vertex[0])
    heading_first = np.arctan2(first[1] - vertex[1], first[0] - vertex[0])
    degrees = np.abs((heading_last - heading_first) * 180.0 / np.pi)
    # Report the smaller of the two angles between the segments
    return 360 - degrees if degrees > 180.0 else degrees
def draw_specific_points(self, keypoints, indices=None, shape=(640, 640), radius=2):
    """
    Draw a filled green circle for selected keypoints.

    A keypoint is skipped when its index is not in `indices`, when its x is a multiple of `shape[1]` or its y is
    a multiple of `shape[0]` (this includes 0), or when it carries a third value (confidence) below 0.5.

    Args:
        keypoints (list): Keypoints, each given as (x, y) or (x, y, conf).
        indices (list, optional): Indices of the keypoints to draw. Defaults to [2, 5, 7].
        shape (tuple): (height, width) used for the multiple-of-size check described above.
        radius (int): Radius of the drawn circles.

    Returns:
        The annotator image `self.im`.
    """
    if indices is None:  # avoid a mutable default argument shared between calls
        indices = [2, 5, 7]

    for i, k in enumerate(keypoints):
        if i not in indices:
            continue
        x_coord, y_coord = k[0], k[1]
        if x_coord % shape[1] == 0 or y_coord % shape[0] == 0:
            continue
        if len(k) == 3 and k[2] < 0.5:  # low-confidence keypoint
            continue
        cv2.circle(self.im, (int(x_coord), int(y_coord)), radius, (0, 255, 0), -1, lineType=cv2.LINE_AA)
    return self.im
def plot_angle_and_count_and_stage(self, angle_text, count_text, stage_text, center_kpt, line_thickness=2):
    """
    Draw three stacked labels (angle, step count, stage) starting at a keypoint position.

    Each label is black text on a filled white rectangle; the labels are placed one below the other.

    Args:
        angle_text (float): Angle value; it is formatted with two decimals.
        count_text (int | str): Count value, shown after the "Steps : " prefix.
        stage_text (str): Stage description.
        center_kpt (sequence): Point whose first two entries (x, y) anchor the labels.
        line_thickness (int): Text thickness; it also increases the font scale and the box heights.
    """
    angle_caption = f" {angle_text:.2f}"
    count_caption = f"Steps : {count_text}"
    stage_caption = f" {stage_text}"

    white, black = (255, 255, 255), (0, 0, 0)
    scale = 0.6 + (line_thickness / 10.0)

    def boxed_text(caption, text_origin, box_origin, box_size):
        """Fill a white box of `box_size` at `box_origin`, then write `caption` in black at `text_origin`."""
        box_end = (box_origin[0] + box_size[0], box_origin[1] + box_size[1])
        cv2.rectangle(self.im, box_origin, box_end, white, -1)
        cv2.putText(self.im, caption, text_origin, 0, scale, black, line_thickness)

    # Angle label: text starts at the keypoint, box starts just above the text
    (angle_w, angle_h), _ = cv2.getTextSize(angle_caption, 0, scale, line_thickness)
    anchor_x, anchor_y = int(center_kpt[0]), int(center_kpt[1])
    angle_box_origin = (anchor_x, anchor_y - angle_h - 5)
    angle_box_size = (angle_w + 2 * 5, angle_h + 2 * 5 + (line_thickness * 2))
    boxed_text(angle_caption, (anchor_x, anchor_y), angle_box_origin, angle_box_size)

    # Count label: box sits 5 px below the angle box
    (count_w, count_h), _ = cv2.getTextSize(count_caption, 0, scale, line_thickness)
    count_origin = (anchor_x, anchor_y + angle_h + 20)
    count_box_origin = (angle_box_origin[0], angle_box_origin[1] + angle_box_size[1] + 5)
    count_box_size = (count_w + 10, count_h + 10 + (line_thickness * 2))
    boxed_text(count_caption, count_origin, count_box_origin, count_box_size)

    # Stage label: positioned from the keypoint using the heights of the two labels above it
    (stage_w, stage_h), _ = cv2.getTextSize(stage_caption, 0, scale, line_thickness)
    stage_origin = (anchor_x, anchor_y + angle_h + count_h + 40)
    stage_box_origin = (stage_origin[0], stage_origin[1] - stage_h - 5)
    stage_box_size = (stage_w + 10, stage_h + 10)
    boxed_text(stage_caption, stage_origin, stage_box_origin, stage_box_size)
def seg_bbox(self, mask, mask_color=(255, 0, 255), det_label=None, track_label=None):
    """
    Draw a segmentation outline and, when a label is available, a filled label tag at its first vertex.

    Args:
        mask (list): Polygon vertices of the segmented instance; ``mask[0]`` anchors the label tag.
        mask_color (tuple): Color of the outline and of the label background.
        det_label (str): Detection label text, used when ``track_label`` is falsy.
        track_label (str): Tracking label; when truthy the tag reads ``Track ID: <track_label>``.
    """
    cv2.polylines(self.im, [np.int32([mask])], isClosed=True, color=mask_color, thickness=2)

    label = f"Track ID: {track_label}" if track_label else det_label
    if label is None:
        # Neither label was supplied (this is the default call): cv2.getTextSize cannot measure None,
        # so only the outline is drawn instead of raising.
        return

    text_size, _ = cv2.getTextSize(label, 0, 0.7, 1)
    anchor_x = int(mask[0][0])

    # Filled background, horizontally centered on the first polygon vertex
    cv2.rectangle(
        self.im,
        (anchor_x - text_size[0] // 2 - 10, int(mask[0][1]) - text_size[1] - 10),
        (anchor_x + text_size[0] // 2 + 5, int(mask[0][1] + 5)),
        mask_color,
        -1,
    )
    cv2.putText(
        self.im, label, (anchor_x - text_size[0] // 2, int(mask[0][1]) - 5), 0, 0.7, (255, 255, 255), 2
    )
def plot_distance_and_line(self, distance_m, distance_mm, centroids, line_color, centroid_color):
    """
    Draw two distance readouts in the top-left corner and connect the two centroids on the frame.

    Args:
        distance_m (float): Distance between the two centroids in meters.
        distance_mm (float): Distance between the two centroids in millimeters.
        centroids (list): Two centroid points; ``centroids[0]`` and ``centroids[1]`` are used.
        line_color (tuple): Color of the connecting line.
        centroid_color (tuple): Color of the filled centroid markers.
    """
    font, font_scale, stroke = cv2.FONT_HERSHEY_SIMPLEX, 0.8, 2

    def draw_readout(caption, box_top, baseline_y):
        """Draw one caption on a white box whose width follows the rendered text size."""
        (caption_w, caption_h), _ = cv2.getTextSize(caption, font, font_scale, stroke)
        cv2.rectangle(
            self.im, (15, box_top), (15 + caption_w + 10, box_top + caption_h + 20), (255, 255, 255), -1
        )
        cv2.putText(self.im, caption, (20, baseline_y), font, font_scale, (0, 0, 0), stroke, cv2.LINE_AA)

    # Meters readout first, millimeters readout stacked below it
    draw_readout(f"Distance M: {distance_m:.2f}m", 25, 50)
    draw_readout(f"Distance MM: {distance_mm:.2f}mm", 75, 100)

    # Connecting line, then a filled marker on each endpoint
    cv2.line(self.im, centroids[0], centroids[1], line_color, 3)
    for endpoint in (centroids[0], centroids[1]):
        cv2.circle(self.im, endpoint, 6, centroid_color, -1)
def visioneye(self, box, center_point, color=(235, 219, 11), pin_color=(255, 0, 255), thickness=2, pins_radius=10):
    """
    Draw a "vision eye" link: a pin at the viewing point, a pin at the box center, and a line between them.

    Args:
        box (list): Bounding box coordinates; indices 0-3 are read as x1, y1, x2, y2.
        center_point (tuple): Viewing point the line starts from.
        color (tuple): Color of the box-center pin and of the connecting line.
        pin_color (tuple): Color of the viewing-point pin.
        thickness (int): Thickness of the connecting line.
        pins_radius (int): Radius of both pins.
    """
    mid_x = int((box[0] + box[2]) / 2)
    mid_y = int((box[1] + box[3]) / 2)
    box_center = (mid_x, mid_y)

    # Both pins are filled circles of the same radius, only the color differs
    for point, fill in ((center_point, pin_color), (box_center, color)):
        cv2.circle(self.im, point, pins_radius, fill, -1)
    cv2.line(self.im, center_point, box_center, color, thickness)
@TryExcept() # known issue https://github.com/ultralytics/yolov5/issues/5395
@plt_settings()
def plot_labels(boxes, cls, names=(), save_dir=Path(''), on_plot=None):
def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
"""Plot training labels including class histograms and box statistics."""
import pandas as pd
import seaborn as sn
# Filter matplotlib>=3.7.2 warning and Seaborn use_inf and is_categorical FutureWarnings
warnings.filterwarnings('ignore', category=UserWarning, message='The figure layout has changed to tight')
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning, message="The figure layout has changed to tight")
warnings.filterwarnings("ignore", category=FutureWarning)
# Plot dataset labels
LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")
nc = int(cls.max() + 1) # number of classes
boxes = boxes[:1000000] # limit to 1M boxes
x = pd.DataFrame(boxes, columns=['x', 'y', 'width', 'height'])
x = pd.DataFrame(boxes, columns=["x", "y", "width", "height"])
# Seaborn correlogram
sn.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200)
sn.pairplot(x, corner=True, diag_kind="auto", kind="hist", diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
plt.savefig(save_dir / "labels_correlogram.jpg", dpi=200)
plt.close()
# Matplotlib labels
@ -287,14 +621,14 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(''), on_plot=None):
y = ax[0].hist(cls, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
for i in range(nc):
y[2].patches[i].set_color([x / 255 for x in colors(i)])
ax[0].set_ylabel('instances')
ax[0].set_ylabel("instances")
if 0 < len(names) < 30:
ax[0].set_xticks(range(len(names)))
ax[0].set_xticklabels(list(names.values()), rotation=90, fontsize=10)
else:
ax[0].set_xlabel('classes')
sn.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9)
sn.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9)
ax[0].set_xlabel("classes")
sn.histplot(x, x="x", y="y", ax=ax[2], bins=50, pmax=0.9)
sn.histplot(x, x="width", y="height", ax=ax[3], bins=50, pmax=0.9)
# Rectangles
boxes[:, 0:2] = 0.5 # center
@ -303,21 +637,22 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(''), on_plot=None):
for cls, box in zip(cls[:500], boxes[:500]):
ImageDraw.Draw(img).rectangle(box, width=1, outline=colors(cls)) # plot
ax[1].imshow(img)
ax[1].axis('off')
ax[1].axis("off")
for a in [0, 1, 2, 3]:
for s in ['top', 'right', 'left', 'bottom']:
for s in ["top", "right", "left", "bottom"]:
ax[a].spines[s].set_visible(False)
fname = save_dir / 'labels.jpg'
fname = save_dir / "labels.jpg"
plt.savefig(fname, dpi=200)
plt.close()
if on_plot:
on_plot(fname)
def save_one_box(xyxy, im, file=Path('im.jpg'), gain=1.02, pad=10, square=False, BGR=False, save=True):
"""Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop.
def save_one_box(xyxy, im, file=Path("im.jpg"), gain=1.02, pad=10, square=False, BGR=False, save=True):
"""
Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop.
This function takes a bounding box and an image, and then saves a cropped portion of the image according
to the bounding box. Optionally, the crop can be squared, and the function allows for gain and padding
@ -353,27 +688,33 @@ def save_one_box(xyxy, im, file=Path('im.jpg'), gain=1.02, pad=10, square=False,
b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # attempt rectangle to square
b[:, 2:] = b[:, 2:] * gain + pad # box wh * gain + pad
xyxy = ops.xywh2xyxy(b).long()
ops.clip_boxes(xyxy, im.shape)
crop = im[int(xyxy[0, 1]):int(xyxy[0, 3]), int(xyxy[0, 0]):int(xyxy[0, 2]), ::(1 if BGR else -1)]
xyxy = ops.clip_boxes(xyxy, im.shape)
crop = im[int(xyxy[0, 1]) : int(xyxy[0, 3]), int(xyxy[0, 0]) : int(xyxy[0, 2]), :: (1 if BGR else -1)]
if save:
file.parent.mkdir(parents=True, exist_ok=True) # make directory
f = str(increment_path(file).with_suffix('.jpg'))
f = str(increment_path(file).with_suffix(".jpg"))
# cv2.imwrite(f, crop) # save BGR, https://github.com/ultralytics/yolov5/issues/7007 chroma subsampling issue
Image.fromarray(crop[..., ::-1]).save(f, quality=95, subsampling=0) # save RGB
return crop
@threaded
def plot_images(images,
batch_idx,
cls,
bboxes=np.zeros(0, dtype=np.float32),
masks=np.zeros(0, dtype=np.uint8),
kpts=np.zeros((0, 51), dtype=np.float32),
paths=None,
fname='images.jpg',
names=None,
on_plot=None):
def plot_images(
images,
batch_idx,
cls,
bboxes=np.zeros(0, dtype=np.float32),
confs=None,
masks=np.zeros(0, dtype=np.uint8),
kpts=np.zeros((0, 51), dtype=np.float32),
paths=None,
fname="images.jpg",
names=None,
on_plot=None,
max_subplots=16,
save=True,
conf_thres=0.25,
):
"""Plot image grid with labels."""
if isinstance(images, torch.Tensor):
images = images.cpu().float().numpy()
@ -389,21 +730,17 @@ def plot_images(images,
batch_idx = batch_idx.cpu().numpy()
max_size = 1920 # max image size
max_subplots = 16 # max image subplots, i.e. 4x4
bs, _, h, w = images.shape # batch size, _, height, width
bs = min(bs, max_subplots) # limit plot images
ns = np.ceil(bs ** 0.5) # number of subplots (square)
ns = np.ceil(bs**0.5) # number of subplots (square)
if np.max(images[0]) <= 1:
images *= 255 # de-normalise (optional)
# Build Image
mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8) # init
for i, im in enumerate(images):
if i == max_subplots: # if last batch has fewer images than we expect
break
for i in range(bs):
x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin
im = im.transpose(1, 2, 0)
mosaic[y:y + h, x:x + w, :] = im
mosaic[y : y + h, x : x + w, :] = images[i].transpose(1, 2, 0)
# Resize (optional)
scale = max_size / ns / max(h, w)
@ -415,40 +752,42 @@ def plot_images(images,
# Annotate
fs = int((h + w) * ns * 0.01) # font size
annotator = Annotator(mosaic, line_width=round(fs / 10), font_size=fs, pil=True, example=names)
for i in range(i + 1):
for i in range(bs):
x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin
annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2) # borders
if paths:
annotator.text((x + 5, y + 5), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220)) # filenames
if len(cls) > 0:
idx = batch_idx == i
classes = cls[idx].astype('int')
classes = cls[idx].astype("int")
labels = confs is None
if len(bboxes):
boxes = ops.xywh2xyxy(bboxes[idx, :4]).T
labels = bboxes.shape[1] == 4 # labels if no conf column
conf = None if labels else bboxes[idx, 4] # check for confidence presence (label vs pred)
if boxes.shape[1]:
if boxes.max() <= 1.01: # if normalized with tolerance 0.01
boxes[[0, 2]] *= w # scale to pixels
boxes[[1, 3]] *= h
boxes = bboxes[idx]
conf = confs[idx] if confs is not None else None # check for confidence presence (label vs pred)
is_obb = boxes.shape[-1] == 5 # xywhr
boxes = ops.xywhr2xyxyxyxy(boxes) if is_obb else ops.xywh2xyxy(boxes)
if len(boxes):
if boxes[:, :4].max() <= 1.1: # if normalized with tolerance 0.1
boxes[..., 0::2] *= w # scale to pixels
boxes[..., 1::2] *= h
elif scale < 1: # absolute coords need scale if image scales
boxes *= scale
boxes[[0, 2]] += x
boxes[[1, 3]] += y
for j, box in enumerate(boxes.T.tolist()):
boxes[..., :4] *= scale
boxes[..., 0::2] += x
boxes[..., 1::2] += y
for j, box in enumerate(boxes.astype(np.int64).tolist()):
c = classes[j]
color = colors(c)
c = names.get(c, c) if names else c
if labels or conf[j] > 0.25: # 0.25 conf thresh
label = f'{c}' if labels else f'{c} {conf[j]:.1f}'
annotator.box_label(box, label, color=color)
if labels or conf[j] > conf_thres:
label = f"{c}" if labels else f"{c} {conf[j]:.1f}"
annotator.box_label(box, label, color=color, rotated=is_obb)
elif len(classes):
for c in classes:
color = colors(c)
c = names.get(c, c) if names else c
annotator.text((x, y), f'{c}', txt_color=color, box_style=True)
annotator.text((x, y), f"{c}", txt_color=color, box_style=True)
# Plot keypoints
if len(kpts):
@ -462,7 +801,7 @@ def plot_images(images,
kpts_[..., 0] += x
kpts_[..., 1] += y
for j in range(len(kpts_)):
if labels or conf[j] > 0.25: # 0.25 conf thresh
if labels or conf[j] > conf_thres:
annotator.kpts(kpts_[j])
# Plot masks
@ -477,8 +816,8 @@ def plot_images(images,
image_masks = np.where(image_masks == index, 1.0, 0.0)
im = np.asarray(annotator.im).copy()
for j, box in enumerate(boxes.T.tolist()):
if labels or conf[j] > 0.25: # 0.25 conf thresh
for j in range(len(image_masks)):
if labels or conf[j] > conf_thres:
color = colors(classes[j])
mh, mw = image_masks[j].shape
if mh != h or mw != w:
@ -488,27 +827,42 @@ def plot_images(images,
else:
mask = image_masks[j].astype(bool)
with contextlib.suppress(Exception):
im[y:y + h, x:x + w, :][mask] = im[y:y + h, x:x + w, :][mask] * 0.4 + np.array(color) * 0.6
im[y : y + h, x : x + w, :][mask] = (
im[y : y + h, x : x + w, :][mask] * 0.4 + np.array(color) * 0.6
)
annotator.fromarray(im)
if not save:
return np.asarray(annotator.im)
annotator.im.save(fname) # save
if on_plot:
on_plot(fname)
@plt_settings()
def plot_results(file='path/to/results.csv', dir='', segment=False, pose=False, classify=False, on_plot=None):
def plot_results(file="path/to/results.csv", dir="", segment=False, pose=False, classify=False, on_plot=None):
"""
Plot training results from results CSV file.
Plot training results from a results CSV file. The function supports various types of data including segmentation,
pose estimation, and classification. Plots are saved as 'results.png' in the directory where the CSV is located.
Args:
file (str, optional): Path to the CSV file containing the training results. Defaults to 'path/to/results.csv'.
dir (str, optional): Directory where the CSV file is located if 'file' is not provided. Defaults to ''.
segment (bool, optional): Flag to indicate if the data is for segmentation. Defaults to False.
pose (bool, optional): Flag to indicate if the data is for pose estimation. Defaults to False.
classify (bool, optional): Flag to indicate if the data is for classification. Defaults to False.
on_plot (callable, optional): Callback function to be executed after plotting. Takes filename as an argument.
Defaults to None.
Example:
```python
from ultralytics.utils.plotting import plot_results
plot_results('path/to/results.csv')
plot_results('path/to/results.csv', segment=True)
```
"""
import pandas as pd
from scipy.ndimage import gaussian_filter1d
save_dir = Path(file).parent if file else Path(dir)
if classify:
fig, ax = plt.subplots(2, 2, figsize=(6, 6), tight_layout=True)
@ -523,31 +877,121 @@ def plot_results(file='path/to/results.csv', dir='', segment=False, pose=False,
fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)
index = [1, 2, 3, 4, 5, 8, 9, 10, 6, 7]
ax = ax.ravel()
files = list(save_dir.glob('results*.csv'))
assert len(files), f'No results.csv files found in {save_dir.resolve()}, nothing to plot.'
files = list(save_dir.glob("results*.csv"))
assert len(files), f"No results.csv files found in {save_dir.resolve()}, nothing to plot."
for f in files:
try:
data = pd.read_csv(f)
s = [x.strip() for x in data.columns]
x = data.values[:, 0]
for i, j in enumerate(index):
y = data.values[:, j].astype('float')
y = data.values[:, j].astype("float")
# y[y == 0] = np.nan # don't show zero values
ax[i].plot(x, y, marker='.', label=f.stem, linewidth=2, markersize=8) # actual results
ax[i].plot(x, gaussian_filter1d(y, sigma=3), ':', label='smooth', linewidth=2) # smoothing line
ax[i].plot(x, y, marker=".", label=f.stem, linewidth=2, markersize=8) # actual results
ax[i].plot(x, gaussian_filter1d(y, sigma=3), ":", label="smooth", linewidth=2) # smoothing line
ax[i].set_title(s[j], fontsize=12)
# if j in [8, 9, 10]: # share train and val loss y axes
# ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
except Exception as e:
LOGGER.warning(f'WARNING: Plotting error for {f}: {e}')
LOGGER.warning(f"WARNING: Plotting error for {f}: {e}")
ax[1].legend()
fname = save_dir / 'results.png'
fname = save_dir / "results.png"
fig.savefig(fname, dpi=200)
plt.close()
if on_plot:
on_plot(fname)
def plt_color_scatter(v, f, bins=20, cmap="viridis", alpha=0.8, edgecolors="none"):
    """
    Plots a scatter plot with points colored by the count of the 2D histogram bin they fall into.

    Args:
        v (array-like): Values for the x-axis.
        f (array-like): Values for the y-axis.
        bins (int, optional): Number of bins for the histogram. Defaults to 20.
        cmap (str, optional): Colormap for the scatter plot. Defaults to 'viridis'.
        alpha (float, optional): Alpha for the scatter plot. Defaults to 0.8.
        edgecolors (str, optional): Edge colors for the scatter plot. Defaults to 'none'.

    Examples:
        >>> v = np.random.rand(100)
        >>> f = np.random.rand(100)
        >>> plt_color_scatter(v, f)
    """
    # Calculate 2D histogram; each point is colored with the population of its own bin
    hist, xedges, yedges = np.histogram2d(v, f, bins=bins)

    # Bin lookup uses left-closed intervals, clipped to the valid index range. This mirrors histogram2d,
    # whose bins are half-open except the last one (which includes the right edge). The earlier
    # `digitize(..., right=True) - 1` produced -1 for the minimum value, so that point was colored from
    # the LAST bin via negative indexing.
    x_bin = np.clip(np.digitize(v, xedges) - 1, 0, hist.shape[0] - 1)
    y_bin = np.clip(np.digitize(f, yedges) - 1, 0, hist.shape[1] - 1)
    density = hist[x_bin, y_bin]

    # Scatter plot
    plt.scatter(v, f, c=density, cmap=cmap, alpha=alpha, edgecolors=edgecolors)
def plot_tune_results(csv_file="tune_results.csv"):
    """
    Plot hyperparameter tuning results stored in a CSV file.

    Two figures are written next to the CSV: 'tune_scatter_plots.png' with one fitness-vs-value scatter per
    hyperparameter column (the row with the highest fitness is marked with a '+'), and 'tune_fitness.png' with
    fitness per iteration plus a Gaussian-smoothed trend line.

    Args:
        csv_file (str, optional): Path to the CSV file containing the tuning results. Defaults to 'tune_results.csv'.

    Examples:
        >>> plot_tune_results('path/to/tune_results.csv')
    """
    import pandas as pd
    from scipy.ndimage import gaussian_filter1d

    csv_file = Path(csv_file)
    table = pd.read_csv(csv_file)
    num_metrics_columns = 1  # the first column is read as fitness; every later column is one hyperparameter
    keys = [name.strip() for name in table.columns][num_metrics_columns:]
    values = table.values
    fitness = values[:, 0]
    best_row = np.argmax(fitness)  # row index of the highest fitness
    grid = math.ceil(len(keys) ** 0.5)  # square subplot grid large enough for every key

    # Figure 1: one scatter per hyperparameter
    plt.figure(figsize=(10, 10), tight_layout=True)
    for i, k in enumerate(keys):
        column = values[:, i + num_metrics_columns]
        mu = column[best_row]  # value of this hyperparameter in the best row
        plt.subplot(grid, grid, i + 1)
        plt_color_scatter(column, fitness, cmap="viridis", alpha=0.8, edgecolors="none")
        plt.plot(mu, fitness.max(), "k+", markersize=15)
        plt.title(f"{k} = {mu:.3g}", fontdict={"size": 9})
        plt.tick_params(axis="both", labelsize=8)
        if i % grid:  # keep y tick labels only on the first subplot of each row
            plt.yticks([])

    file = csv_file.with_name("tune_scatter_plots.png")
    plt.savefig(file, dpi=200)
    plt.close()
    LOGGER.info(f"Saved {file}")

    # Figure 2: fitness against iteration number (1-based)
    iterations = range(1, len(fitness) + 1)
    plt.figure(figsize=(10, 6), tight_layout=True)
    plt.plot(iterations, fitness, marker="o", linestyle="none", label="fitness")
    smoothed = gaussian_filter1d(fitness, sigma=3)
    plt.plot(iterations, smoothed, ":", label="smoothed", linewidth=2)
    plt.title("Fitness vs Iteration")
    plt.xlabel("Iteration")
    plt.ylabel("Fitness")
    plt.grid(True)
    plt.legend()

    file = csv_file.with_name("tune_fitness.png")
    plt.savefig(file, dpi=200)
    plt.close()
    LOGGER.info(f"Saved {file}")
def output_to_target(output, max_det=300):
"""Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting."""
targets = []
@ -556,10 +1000,21 @@ def output_to_target(output, max_det=300):
j = torch.full((conf.shape[0], 1), i)
targets.append(torch.cat((j, cls, ops.xyxy2xywh(box), conf), 1))
targets = torch.cat(targets, 0).numpy()
return targets[:, 0], targets[:, 1], targets[:, 2:]
return targets[:, 0], targets[:, 1], targets[:, 2:-1], targets[:, -1]
def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detect/exp')):
def output_to_rotated_target(output, max_det=300):
    """
    Convert per-image rotated-box predictions into flat arrays for plotting.

    Each prediction row is split into 4 box values, confidence, class and angle (in that column order). The
    returned tuple is (batch_ids, class_ids, boxes_with_angle, confidences), where boxes_with_angle holds the
    4 box values followed by the angle; box values are passed through unchanged.
    """
    rows = []
    for image_index, pred in enumerate(output):
        kept = pred[:max_det].cpu()  # at most max_det rows per image
        box, conf, cls, angle = kept.split((4, 1, 1, 1), 1)
        batch_col = torch.full((conf.shape[0], 1), image_index)
        rows.append(torch.cat((batch_col, cls, box, angle, conf), 1))
    stacked = torch.cat(rows, 0).numpy()
    batch_ids, class_ids = stacked[:, 0], stacked[:, 1]
    return batch_ids, class_ids, stacked[:, 2:-1], stacked[:, -1]
def feature_visualization(x, module_type, stage, n=32, save_dir=Path("runs/detect/exp")):
"""
Visualize feature maps of a given model module during inference.
@ -570,23 +1025,23 @@ def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detec
n (int, optional): Maximum number of feature maps to plot. Defaults to 32.
save_dir (Path, optional): Directory to save results. Defaults to Path('runs/detect/exp').
"""
for m in ['Detect', 'Pose', 'Segment']:
for m in ["Detect", "Pose", "Segment"]:
if m in module_type:
return
batch, channels, height, width = x.shape # batch, channels, height, width
_, channels, height, width = x.shape # batch, channels, height, width
if height > 1 and width > 1:
f = save_dir / f"stage{stage}_{module_type.split('.')[-1]}_features.png" # filename
blocks = torch.chunk(x[0].cpu(), channels, dim=0) # select batch index 0, block by channels
n = min(n, channels) # number of plots
fig, ax = plt.subplots(math.ceil(n / 8), 8, tight_layout=True) # 8 rows x n/8 cols
_, ax = plt.subplots(math.ceil(n / 8), 8, tight_layout=True) # 8 rows x n/8 cols
ax = ax.ravel()
plt.subplots_adjust(wspace=0.05, hspace=0.05)
for i in range(n):
ax[i].imshow(blocks[i].squeeze()) # cmap='gray'
ax[i].axis('off')
ax[i].axis("off")
LOGGER.info(f'Saving {f}... ({n}/{channels})')
plt.savefig(f, dpi=300, bbox_inches='tight')
LOGGER.info(f"Saving {f}... ({n}/{channels})")
plt.savefig(f, dpi=300, bbox_inches="tight")
plt.close()
np.save(str(f.with_suffix('.npy')), x[0].cpu().numpy()) # npy save
np.save(str(f.with_suffix(".npy")), x[0].cpu().numpy()) # npy save

View File

@ -4,65 +4,18 @@ import torch
import torch.nn as nn
from .checks import check_version
from .metrics import bbox_iou
from .metrics import bbox_iou, probiou
from .ops import xywhr2xyxyxyxy
TORCH_1_10 = check_version(torch.__version__, '1.10.0')
def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
    """
    Flag the anchor centers that lie strictly inside each ground-truth box.

    Args:
        xy_centers (Tensor): shape(h*w, 2)
        gt_bboxes (Tensor): shape(b, n_boxes, 4)

    Returns:
        (Tensor): shape(b, n_boxes, h*w), 1 where the smallest distance to a box side exceeds eps, else 0
    """
    num_anchors = xy_centers.shape[0]
    batch, num_boxes, _ = gt_bboxes.shape
    top_left, bottom_right = gt_bboxes.view(-1, 1, 4).chunk(2, 2)
    centers = xy_centers[None]
    # Signed distances from every center to the four box sides; any non-positive value means "outside"
    margins = torch.cat((centers - top_left, bottom_right - centers), dim=2)
    margins = margins.view(batch, num_boxes, num_anchors, -1)
    smallest_margin = margins.amin(3)
    return smallest_margin.gt_(eps)  # in-place compare keeps the floating dtype
def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
    """
    Resolve anchors assigned to several gts by keeping only the gt with the highest IoU.

    Args:
        mask_pos (Tensor): shape(b, n_max_boxes, h*w)
        overlaps (Tensor): shape(b, n_max_boxes, h*w)

    Returns:
        target_gt_idx (Tensor): shape(b, h*w)
        fg_mask (Tensor): shape(b, h*w)
        mask_pos (Tensor): shape(b, n_max_boxes, h*w)
    """
    # Number of gts claiming each anchor: (b, n_max_boxes, h*w) -> (b, h*w)
    fg_mask = mask_pos.sum(-2)
    has_conflict = fg_mask.max() > 1
    if has_conflict:
        contested = (fg_mask.unsqueeze(1) > 1).expand(-1, n_max_boxes, -1)  # (b, n_max_boxes, h*w)
        best_gt = overlaps.argmax(1)  # (b, h*w)
        winners = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device)
        winners.scatter_(1, best_gt.unsqueeze(1), 1)  # one-hot over the gt axis
        # Only contested anchors are replaced by the one-hot winner; the others keep their assignment
        mask_pos = torch.where(contested, winners, mask_pos).float()
        fg_mask = mask_pos.sum(-2)
    # Index of the gt each anchor serves
    target_gt_idx = mask_pos.argmax(-2)  # (b, h*w)
    return target_gt_idx, fg_mask, mask_pos
TORCH_1_10 = check_version(torch.__version__, "1.10.0")
class TaskAlignedAssigner(nn.Module):
"""
A task-aligned assigner for object detection.
This class assigns ground-truth (gt) objects to anchors based on the task-aligned metric,
which combines both classification and localization information.
This class assigns ground-truth (gt) objects to anchors based on the task-aligned metric, which combines both
classification and localization information.
Attributes:
topk (int): The number of top candidates to consider.
@ -85,8 +38,8 @@ class TaskAlignedAssigner(nn.Module):
@torch.no_grad()
def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
"""
Compute the task-aligned assignment.
Reference https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py
Compute the task-aligned assignment. Reference code is available at
https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py.
Args:
pd_scores (Tensor): shape(bs, num_total_anchors, num_classes)
@ -103,19 +56,24 @@ class TaskAlignedAssigner(nn.Module):
fg_mask (Tensor): shape(bs, num_total_anchors)
target_gt_idx (Tensor): shape(bs, num_total_anchors)
"""
self.bs = pd_scores.size(0)
self.n_max_boxes = gt_bboxes.size(1)
self.bs = pd_scores.shape[0]
self.n_max_boxes = gt_bboxes.shape[1]
if self.n_max_boxes == 0:
device = gt_bboxes.device
return (torch.full_like(pd_scores[..., 0], self.bg_idx).to(device), torch.zeros_like(pd_bboxes).to(device),
torch.zeros_like(pd_scores).to(device), torch.zeros_like(pd_scores[..., 0]).to(device),
torch.zeros_like(pd_scores[..., 0]).to(device))
return (
torch.full_like(pd_scores[..., 0], self.bg_idx).to(device),
torch.zeros_like(pd_bboxes).to(device),
torch.zeros_like(pd_scores).to(device),
torch.zeros_like(pd_scores[..., 0]).to(device),
torch.zeros_like(pd_scores[..., 0]).to(device),
)
mask_pos, align_metric, overlaps = self.get_pos_mask(pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points,
mask_gt)
mask_pos, align_metric, overlaps = self.get_pos_mask(
pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt
)
target_gt_idx, fg_mask, mask_pos = select_highest_overlaps(mask_pos, overlaps, self.n_max_boxes)
target_gt_idx, fg_mask, mask_pos = self.select_highest_overlaps(mask_pos, overlaps, self.n_max_boxes)
# Assigned target
target_labels, target_bboxes, target_scores = self.get_targets(gt_labels, gt_bboxes, target_gt_idx, fg_mask)
@ -131,7 +89,7 @@ class TaskAlignedAssigner(nn.Module):
def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt):
"""Get in_gts mask, (b, max_num_obj, h*w)."""
mask_in_gts = select_candidates_in_gts(anc_points, gt_bboxes)
mask_in_gts = self.select_candidates_in_gts(anc_points, gt_bboxes)
# Get anchor_align metric, (b, max_num_obj, h*w)
align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts * mask_gt)
# Get topk_metric mask, (b, max_num_obj, h*w)
@ -157,11 +115,15 @@ class TaskAlignedAssigner(nn.Module):
# (b, max_num_obj, 1, 4), (b, 1, h*w, 4)
pd_boxes = pd_bboxes.unsqueeze(1).expand(-1, self.n_max_boxes, -1, -1)[mask_gt]
gt_boxes = gt_bboxes.unsqueeze(2).expand(-1, -1, na, -1)[mask_gt]
overlaps[mask_gt] = bbox_iou(gt_boxes, pd_boxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0)
overlaps[mask_gt] = self.iou_calculation(gt_boxes, pd_boxes)
align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)
return align_metric, overlaps
def iou_calculation(self, gt_bboxes, pd_bboxes):
    """Return the CIoU between ground-truth and predicted horizontal boxes, clamped to be non-negative."""
    ciou = bbox_iou(gt_bboxes, pd_bboxes, xywh=False, CIoU=True)
    return ciou.squeeze(-1).clamp_(0)
def select_topk_candidates(self, metrics, largest=True, topk_mask=None):
"""
Select the top-k candidates based on the given metrics.
@ -191,9 +153,9 @@ class TaskAlignedAssigner(nn.Module):
ones = torch.ones_like(topk_idxs[:, :, :1], dtype=torch.int8, device=topk_idxs.device)
for k in range(self.topk):
# Expand topk_idxs for each value of k and add 1 at the specified positions
count_tensor.scatter_add_(-1, topk_idxs[:, :, k:k + 1], ones)
count_tensor.scatter_add_(-1, topk_idxs[:, :, k : k + 1], ones)
# count_tensor.scatter_add_(-1, topk_idxs, torch.ones_like(topk_idxs, dtype=torch.int8, device=topk_idxs.device))
# filter invalid bboxes
# Filter invalid bboxes
count_tensor.masked_fill_(count_tensor > 1, 0)
return count_tensor.to(metrics.dtype)
@ -229,15 +191,17 @@ class TaskAlignedAssigner(nn.Module):
target_labels = gt_labels.long().flatten()[target_gt_idx] # (b, h*w)
# Assigned target boxes, (b, max_num_obj, 4) -> (b, h*w, 4)
target_bboxes = gt_bboxes.view(-1, 4)[target_gt_idx]
target_bboxes = gt_bboxes.view(-1, gt_bboxes.shape[-1])[target_gt_idx]
# Assigned target scores
target_labels.clamp_(0)
# 10x faster than F.one_hot()
target_scores = torch.zeros((target_labels.shape[0], target_labels.shape[1], self.num_classes),
dtype=torch.int64,
device=target_labels.device) # (b, h*w, 80)
target_scores = torch.zeros(
(target_labels.shape[0], target_labels.shape[1], self.num_classes),
dtype=torch.int64,
device=target_labels.device,
) # (b, h*w, 80)
target_scores.scatter_(2, target_labels.unsqueeze(-1), 1)
fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.num_classes) # (b, h*w, 80)
@ -245,6 +209,87 @@ class TaskAlignedAssigner(nn.Module):
return target_labels, target_bboxes, target_scores
@staticmethod
def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
"""
Select the positive anchor center in gt.
Args:
xy_centers (Tensor): shape(h*w, 2)
gt_bboxes (Tensor): shape(b, n_boxes, 4)
Returns:
(Tensor): shape(b, n_boxes, h*w)
"""
n_anchors = xy_centers.shape[0]
bs, n_boxes, _ = gt_bboxes.shape
lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2) # left-top, right-bottom
bbox_deltas = torch.cat((xy_centers[None] - lt, rb - xy_centers[None]), dim=2).view(bs, n_boxes, n_anchors, -1)
# return (bbox_deltas.min(3)[0] > eps).to(gt_bboxes.dtype)
return bbox_deltas.amin(3).gt_(eps)
@staticmethod
def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
"""
If an anchor box is assigned to multiple gts, the one with the highest IoU will be selected.
Args:
mask_pos (Tensor): shape(b, n_max_boxes, h*w)
overlaps (Tensor): shape(b, n_max_boxes, h*w)
Returns:
target_gt_idx (Tensor): shape(b, h*w)
fg_mask (Tensor): shape(b, h*w)
mask_pos (Tensor): shape(b, n_max_boxes, h*w)
"""
# (b, n_max_boxes, h*w) -> (b, h*w)
fg_mask = mask_pos.sum(-2)
if fg_mask.max() > 1: # one anchor is assigned to multiple gt_bboxes
mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, n_max_boxes, -1) # (b, n_max_boxes, h*w)
max_overlaps_idx = overlaps.argmax(1) # (b, h*w)
is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device)
is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1)
mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos).float() # (b, n_max_boxes, h*w)
fg_mask = mask_pos.sum(-2)
# Find each grid serve which gt(index)
target_gt_idx = mask_pos.argmax(-2) # (b, h*w)
return target_gt_idx, fg_mask, mask_pos
class RotatedTaskAlignedAssigner(TaskAlignedAssigner):
    """Task-aligned assigner for rotated boxes: overrides the overlap metric and the anchor-in-box test."""

    def iou_calculation(self, gt_bboxes, pd_bboxes):
        """Return the probiou overlap between ground-truth and predicted rotated boxes, clamped at zero."""
        overlap = probiou(gt_bboxes, pd_bboxes)
        return overlap.squeeze(-1).clamp_(0)

    @staticmethod
    def select_candidates_in_gts(xy_centers, gt_bboxes):
        """
        Flag the anchor centers that lie inside each rotated ground-truth box.

        Args:
            xy_centers (Tensor): shape(h*w, 2)
            gt_bboxes (Tensor): shape(b, n_boxes, 5)

        Returns:
            (Tensor): shape(b, n_boxes, h*w)
        """
        # (b, n_boxes, 5) --> (b, n_boxes, 4, 2) corner points
        corners = xywhr2xyxyxyxy(gt_bboxes)
        # Corner a plus its two neighbours b and d, each (b, n_boxes, 1, 2); the third corner is not needed
        a, b, _, d = corners.split(1, dim=-2)
        edge_ab, edge_ad = b - a, d - a

        # Vector from corner a to every anchor center, (b, n_boxes, h*w, 2)
        to_center = xy_centers - a
        ab_sq = (edge_ab * edge_ab).sum(dim=-1)
        ad_sq = (edge_ad * edge_ad).sum(dim=-1)
        proj_ab = (to_center * edge_ab).sum(dim=-1)
        proj_ad = (to_center * edge_ad).sum(dim=-1)

        # Inside when the projection onto each edge lies between 0 and that edge's squared length
        within_ab = (proj_ab >= 0) & (proj_ab <= ab_sq)
        return within_ab & (proj_ad >= 0) & (proj_ad <= ad_sq)  # is_in_box
def make_anchors(feats, strides, grid_cell_offset=0.5):
"""Generate anchors from features."""
@ -255,7 +300,7 @@ def make_anchors(feats, strides, grid_cell_offset=0.5):
_, _, h, w = feats[i].shape
sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset # shift x
sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset # shift y
sy, sx = torch.meshgrid(sy, sx, indexing='ij') if TORCH_1_10 else torch.meshgrid(sy, sx)
sy, sx = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx)
anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2))
stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
return torch.cat(anchor_points), torch.cat(stride_tensor)
@ -263,7 +308,8 @@ def make_anchors(feats, strides, grid_cell_offset=0.5):
def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
"""Transform distance(ltrb) to box(xywh or xyxy)."""
lt, rb = distance.chunk(2, dim)
assert(distance.shape[dim] == 4)
lt, rb = distance.split([2, 2], dim)
x1y1 = anchor_points - lt
x2y2 = anchor_points + rb
if xywh:
@ -277,3 +323,23 @@ def bbox2dist(anchor_points, bbox, reg_max):
"""Transform bbox(xyxy) to dist(ltrb)."""
x1y1, x2y2 = bbox.chunk(2, -1)
return torch.cat((anchor_points - x1y1, x2y2 - anchor_points), -1).clamp_(0, reg_max - 0.01) # dist (lt, rb)
def dist2rbox(pred_dist, pred_angle, anchor_points, dim=-1):
    """
    Decode rotated box centers and sizes from predicted ltrb distances, angles and anchor points.

    Args:
        pred_dist (torch.Tensor): Predicted rotated distance, (bs, h*w, 4).
        pred_angle (torch.Tensor): Predicted angle, (bs, h*w, 1).
        anchor_points (torch.Tensor): Anchor points, (h*w, 2).
        dim (int, optional): Dimension along which the distances are split and the result is concatenated.

    Returns:
        (torch.Tensor): Predicted rotated bounding boxes as (center x, center y, width, height), (bs, h*w, 4).
    """
    left_top, right_bottom = pred_dist.split(2, dim=dim)
    cos_a = torch.cos(pred_angle)
    sin_a = torch.sin(pred_angle)

    # Center offset from the anchor in the box's own (unrotated) frame, each (bs, h*w, 1)
    half_offset = (right_bottom - left_top) / 2
    off_x, off_y = half_offset.split(1, dim=dim)

    # Rotate the offset by the predicted angle, then translate by the anchor position
    rot_x = off_x * cos_a - off_y * sin_a
    rot_y = off_x * sin_a + off_y * cos_a
    center = torch.cat([rot_x, rot_y], dim=dim) + anchor_points

    size = left_top + right_bottom  # width, height
    return torch.cat([center, size], dim=dim)

View File

@ -2,7 +2,6 @@
import math
import os
import platform
import random
import time
from contextlib import contextmanager
@ -15,17 +14,23 @@ import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, __version__
from ultralytics.utils.checks import check_version
from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, __version__
from ultralytics.utils.checks import PYTHON_VERSION, check_version
try:
import thop
except ImportError:
thop = None
TORCH_1_9 = check_version(torch.__version__, '1.9.0')
TORCH_2_0 = check_version(torch.__version__, '2.0.0')
# Version checks (all default to version>=min_version)
TORCH_1_9 = check_version(torch.__version__, "1.9.0")
TORCH_1_13 = check_version(torch.__version__, "1.13.0")
TORCH_2_0 = check_version(torch.__version__, "2.0.0")
TORCHVISION_0_10 = check_version(torchvision.__version__, "0.10.0")
TORCHVISION_0_11 = check_version(torchvision.__version__, "0.11.0")
TORCHVISION_0_13 = check_version(torchvision.__version__, "0.13.0")
@contextmanager
@ -44,7 +49,10 @@ def smart_inference_mode():
def decorate(fn):
"""Applies appropriate torch decorator for inference mode based on torch version."""
return (torch.inference_mode if TORCH_1_9 else torch.no_grad)()(fn)
if TORCH_1_9 and torch.is_inference_mode_enabled():
return fn # already in inference_mode, act as a pass-through
else:
return (torch.inference_mode if TORCH_1_9 else torch.no_grad)()(fn)
return decorate
@ -53,59 +61,102 @@ def get_cpu_info():
"""Return a string with system CPU information, i.e. 'Apple M2'."""
import cpuinfo # pip install py-cpuinfo
k = 'brand_raw', 'hardware_raw', 'arch_string_raw' # info keys sorted by preference (not all keys always available)
k = "brand_raw", "hardware_raw", "arch_string_raw" # info keys sorted by preference (not all keys always available)
info = cpuinfo.get_cpu_info() # info dict
string = info.get(k[0] if k[0] in info else k[1] if k[1] in info else k[2], 'unknown')
return string.replace('(R)', '').replace('CPU ', '').replace('@ ', '')
string = info.get(k[0] if k[0] in info else k[1] if k[1] in info else k[2], "unknown")
return string.replace("(R)", "").replace("CPU ", "").replace("@ ", "")
def select_device(device='', batch=0, newline=False, verbose=True):
"""Selects PyTorch Device. Options are device = None or 'cpu' or 0 or '0' or '0,1,2,3'."""
s = f'Ultralytics YOLOv{__version__} 🚀 Python-{platform.python_version()} torch-{torch.__version__} '
def select_device(device="", batch=0, newline=False, verbose=True):
"""
Selects the appropriate PyTorch device based on the provided arguments.
The function takes a string specifying the device or a torch.device object and returns a torch.device object
representing the selected device. The function also validates the number of available devices and raises an
exception if the requested device(s) are not available.
Args:
device (str | torch.device, optional): Device string or torch.device object.
Options are 'None', 'cpu', or 'cuda', or '0' or '0,1,2,3'. Defaults to an empty string, which auto-selects
the first available GPU, or CPU if no GPU is available.
batch (int, optional): Batch size being used in your model. Defaults to 0.
newline (bool, optional): If True, adds a newline at the end of the log string. Defaults to False.
verbose (bool, optional): If True, logs the device information. Defaults to True.
Returns:
(torch.device): Selected device.
Raises:
ValueError: If the specified device is not available or if the batch size is not a multiple of the number of
devices when using multiple GPUs.
Examples:
>>> select_device('cuda:0')
device(type='cuda', index=0)
>>> select_device('cpu')
device(type='cpu')
Note:
Sets the 'CUDA_VISIBLE_DEVICES' environment variable for specifying which GPUs to use.
"""
if isinstance(device, torch.device):
return device
s = f"Ultralytics YOLOv{__version__} 🚀 Python-{PYTHON_VERSION} torch-{torch.__version__} "
device = str(device).lower()
for remove in 'cuda:', 'none', '(', ')', '[', ']', "'", ' ':
device = device.replace(remove, '') # to string, 'cuda:0' -> '0' and '(0, 1)' -> '0,1'
cpu = device == 'cpu'
mps = device == 'mps' # Apple Metal Performance Shaders (MPS)
for remove in "cuda:", "none", "(", ")", "[", "]", "'", " ":
device = device.replace(remove, "") # to string, 'cuda:0' -> '0' and '(0, 1)' -> '0,1'
cpu = device == "cpu"
mps = device in ("mps", "mps:0") # Apple Metal Performance Shaders (MPS)
if cpu or mps:
os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False
os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # force torch.cuda.is_available() = False
elif device: # non-cpu device requested
if device == 'cuda':
device = '0'
visible = os.environ.get('CUDA_VISIBLE_DEVICES', None)
os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable - must be before assert is_available()
if not (torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(',', ''))):
if device == "cuda":
device = "0"
visible = os.environ.get("CUDA_VISIBLE_DEVICES", None)
os.environ["CUDA_VISIBLE_DEVICES"] = device # set environment variable - must be before assert is_available()
if not (torch.cuda.is_available() and torch.cuda.device_count() >= len(device.split(","))):
LOGGER.info(s)
install = 'See https://pytorch.org/get-started/locally/ for up-to-date torch install instructions if no ' \
'CUDA devices are seen by torch.\n' if torch.cuda.device_count() == 0 else ''
raise ValueError(f"Invalid CUDA 'device={device}' requested."
f" Use 'device=cpu' or pass valid CUDA device(s) if available,"
f" i.e. 'device=0' or 'device=0,1,2,3' for Multi-GPU.\n"
f'\ntorch.cuda.is_available(): {torch.cuda.is_available()}'
f'\ntorch.cuda.device_count(): {torch.cuda.device_count()}'
f"\nos.environ['CUDA_VISIBLE_DEVICES']: {visible}\n"
f'{install}')
install = (
"See https://pytorch.org/get-started/locally/ for up-to-date torch install instructions if no "
"CUDA devices are seen by torch.\n"
if torch.cuda.device_count() == 0
else ""
)
raise ValueError(
f"Invalid CUDA 'device={device}' requested."
f" Use 'device=cpu' or pass valid CUDA device(s) if available,"
f" i.e. 'device=0' or 'device=0,1,2,3' for Multi-GPU.\n"
f"\ntorch.cuda.is_available(): {torch.cuda.is_available()}"
f"\ntorch.cuda.device_count(): {torch.cuda.device_count()}"
f"\nos.environ['CUDA_VISIBLE_DEVICES']: {visible}\n"
f"{install}"
)
if not cpu and not mps and torch.cuda.is_available(): # prefer GPU if available
devices = device.split(',') if device else '0' # range(torch.cuda.device_count()) # i.e. 0,1,6,7
devices = device.split(",") if device else "0" # range(torch.cuda.device_count()) # i.e. 0,1,6,7
n = len(devices) # device count
if n > 1 and batch > 0 and batch % n != 0: # check batch_size is divisible by device_count
raise ValueError(f"'batch={batch}' must be a multiple of GPU count {n}. Try 'batch={batch // n * n}' or "
f"'batch={batch // n * n + n}', the nearest batch sizes evenly divisible by {n}.")
space = ' ' * (len(s) + 1)
raise ValueError(
f"'batch={batch}' must be a multiple of GPU count {n}. Try 'batch={batch // n * n}' or "
f"'batch={batch // n * n + n}', the nearest batch sizes evenly divisible by {n}."
)
space = " " * (len(s) + 1)
for i, d in enumerate(devices):
p = torch.cuda.get_device_properties(i)
s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / (1 << 20):.0f}MiB)\n" # bytes to MB
arg = 'cuda:0'
elif mps and getattr(torch, 'has_mps', False) and torch.backends.mps.is_available() and TORCH_2_0:
arg = "cuda:0"
elif mps and TORCH_2_0 and torch.backends.mps.is_available():
# Prefer MPS if available
s += f'MPS ({get_cpu_info()})\n'
arg = 'mps'
s += f"MPS ({get_cpu_info()})\n"
arg = "mps"
else: # revert to CPU
s += f'CPU ({get_cpu_info()})\n'
arg = 'cpu'
s += f"CPU ({get_cpu_info()})\n"
arg = "cpu"
if verbose and RANK == -1:
if verbose:
LOGGER.info(s if newline else s.rstrip())
return torch.device(arg)
@ -119,14 +170,20 @@ def time_sync():
def fuse_conv_and_bn(conv, bn):
"""Fuse Conv2d() and BatchNorm2d() layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/."""
fusedconv = nn.Conv2d(conv.in_channels,
conv.out_channels,
kernel_size=conv.kernel_size,
stride=conv.stride,
padding=conv.padding,
dilation=conv.dilation,
groups=conv.groups,
bias=True).requires_grad_(False).to(conv.weight.device)
fusedconv = (
nn.Conv2d(
conv.in_channels,
conv.out_channels,
kernel_size=conv.kernel_size,
stride=conv.stride,
padding=conv.padding,
dilation=conv.dilation,
groups=conv.groups,
bias=True,
)
.requires_grad_(False)
.to(conv.weight.device)
)
# Prepare filters
w_conv = conv.weight.clone().view(conv.out_channels, -1)
@ -134,7 +191,7 @@ def fuse_conv_and_bn(conv, bn):
fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))
# Prepare spatial bias
b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
b_conv = torch.zeros(conv.weight.shape[0], device=conv.weight.device) if conv.bias is None else conv.bias
b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
@ -143,15 +200,21 @@ def fuse_conv_and_bn(conv, bn):
def fuse_deconv_and_bn(deconv, bn):
"""Fuse ConvTranspose2d() and BatchNorm2d() layers."""
fuseddconv = nn.ConvTranspose2d(deconv.in_channels,
deconv.out_channels,
kernel_size=deconv.kernel_size,
stride=deconv.stride,
padding=deconv.padding,
output_padding=deconv.output_padding,
dilation=deconv.dilation,
groups=deconv.groups,
bias=True).requires_grad_(False).to(deconv.weight.device)
fuseddconv = (
nn.ConvTranspose2d(
deconv.in_channels,
deconv.out_channels,
kernel_size=deconv.kernel_size,
stride=deconv.stride,
padding=deconv.padding,
output_padding=deconv.output_padding,
dilation=deconv.dilation,
groups=deconv.groups,
bias=True,
)
.requires_grad_(False)
.to(deconv.weight.device)
)
# Prepare filters
w_deconv = deconv.weight.clone().view(deconv.out_channels, -1)
@ -159,7 +222,7 @@ def fuse_deconv_and_bn(deconv, bn):
fuseddconv.weight.copy_(torch.mm(w_bn, w_deconv).view(fuseddconv.weight.shape))
# Prepare spatial bias
b_conv = torch.zeros(deconv.weight.size(1), device=deconv.weight.device) if deconv.bias is None else deconv.bias
b_conv = torch.zeros(deconv.weight.shape[1], device=deconv.weight.device) if deconv.bias is None else deconv.bias
b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
fuseddconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
@ -167,7 +230,11 @@ def fuse_deconv_and_bn(deconv, bn):
def model_info(model, detailed=False, verbose=True, imgsz=640):
"""Model information. imgsz may be int or list, i.e. imgsz=640 or imgsz=[640, 320]."""
"""
Model information.
imgsz may be int or list, i.e. imgsz=640 or imgsz=[640, 320].
"""
if not verbose:
return
n_p = get_num_params(model) # number of parameters
@ -175,18 +242,21 @@ def model_info(model, detailed=False, verbose=True, imgsz=640):
n_l = len(list(model.modules())) # number of layers
if detailed:
LOGGER.info(
f"{'layer':>5} {'name':>40} {'gradient':>9} {'parameters':>12} {'shape':>20} {'mu':>10} {'sigma':>10}")
f"{'layer':>5} {'name':>40} {'gradient':>9} {'parameters':>12} {'shape':>20} {'mu':>10} {'sigma':>10}"
)
for i, (name, p) in enumerate(model.named_parameters()):
name = name.replace('module_list.', '')
LOGGER.info('%5g %40s %9s %12g %20s %10.3g %10.3g %10s' %
(i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std(), p.dtype))
name = name.replace("module_list.", "")
LOGGER.info(
"%5g %40s %9s %12g %20s %10.3g %10.3g %10s"
% (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std(), p.dtype)
)
flops = get_flops(model, imgsz)
fused = ' (fused)' if getattr(model, 'is_fused', lambda: False)() else ''
fs = f', {flops:.1f} GFLOPs' if flops else ''
yaml_file = getattr(model, 'yaml_file', '') or getattr(model, 'yaml', {}).get('yaml_file', '')
model_name = Path(yaml_file).stem.replace('yolo', 'YOLO') or 'Model'
LOGGER.info(f'{model_name} summary{fused}: {n_l} layers, {n_p} parameters, {n_g} gradients{fs}')
fused = " (fused)" if getattr(model, "is_fused", lambda: False)() else ""
fs = f", {flops:.1f} GFLOPs" if flops else ""
yaml_file = getattr(model, "yaml_file", "") or getattr(model, "yaml", {}).get("yaml_file", "")
model_name = Path(yaml_file).stem.replace("yolo", "YOLO") or "Model"
LOGGER.info(f"{model_name} summary{fused}: {n_l} layers, {n_p} parameters, {n_g} gradients{fs}")
return n_l, n_p, n_g, flops
@ -204,37 +274,53 @@ def model_info_for_loggers(trainer):
"""
Return model info dict with useful model information.
Example for YOLOv8n:
{'model/parameters': 3151904,
'model/GFLOPs': 8.746,
'model/speed_ONNX(ms)': 41.244,
'model/speed_TensorRT(ms)': 3.211,
'model/speed_PyTorch(ms)': 18.755}
Example:
YOLOv8n info for loggers
```python
results = {'model/parameters': 3151904,
'model/GFLOPs': 8.746,
'model/speed_ONNX(ms)': 41.244,
'model/speed_TensorRT(ms)': 3.211,
'model/speed_PyTorch(ms)': 18.755}
```
"""
if trainer.args.profile: # profile ONNX and TensorRT times
from ultralytics.utils.benchmarks import ProfileModels
results = ProfileModels([trainer.last], device=trainer.device).profile()[0]
results.pop('model/name')
results.pop("model/name")
else: # only return PyTorch times from most recent validation
results = {
'model/parameters': get_num_params(trainer.model),
'model/GFLOPs': round(get_flops(trainer.model), 3)}
results['model/speed_PyTorch(ms)'] = round(trainer.validator.speed['inference'], 3)
"model/parameters": get_num_params(trainer.model),
"model/GFLOPs": round(get_flops(trainer.model), 3),
}
results["model/speed_PyTorch(ms)"] = round(trainer.validator.speed["inference"], 3)
return results
def get_flops(model, imgsz=640):
"""Return a YOLO model's FLOPs."""
if not thop:
return 0.0 # if not installed return 0.0 GFLOPs
try:
model = de_parallel(model)
p = next(model.parameters())
stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32 # max stride
im = torch.empty((1, p.shape[1], stride, stride), device=p.device) # input image in BCHW format
flops = thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1E9 * 2 if thop else 0 # stride GFLOPs
imgsz = imgsz if isinstance(imgsz, list) else [imgsz, imgsz] # expand if int/float
return flops * imgsz[0] / stride * imgsz[1] / stride # 640x640 GFLOPs
if not isinstance(imgsz, list):
imgsz = [imgsz, imgsz] # expand if int/float
try:
# Use stride size for input tensor
# stride = max(int(model.stride.max()), 32) if hasattr(model, "stride") else 32 # max stride
# im = torch.empty((1, p.shape[1], stride, stride), device=p.device) # input image in BCHW format
# flops = thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1e9 * 2 # stride GFLOPs
# return flops * imgsz[0] / stride * imgsz[1] / stride # imgsz GFLOPs
raise Exception
except Exception:
# Use actual image size for input tensor (i.e. required for RTDETR models)
im = torch.empty((1, p.shape[1], *imgsz), device=p.device) # input image in BCHW format
return thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1e9 * 2 # imgsz GFLOPs
except Exception:
return 0
return 0.0
def get_flops_with_torch_profiler(model, imgsz=640):
@ -242,11 +328,11 @@ def get_flops_with_torch_profiler(model, imgsz=640):
if TORCH_2_0:
model = de_parallel(model)
p = next(model.parameters())
stride = (max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32) * 2 # max stride
stride = (max(int(model.stride.max()), 32) if hasattr(model, "stride") else 32) * 2 # max stride
im = torch.zeros((1, p.shape[1], stride, stride), device=p.device) # input image in BCHW format
with torch.profiler.profile(with_flops=True) as prof:
model(im)
flops = sum(x.flops for x in prof.key_averages()) / 1E9
flops = sum(x.flops for x in prof.key_averages()) / 1e9
imgsz = imgsz if isinstance(imgsz, list) else [imgsz, imgsz] # expand if int/float
flops = flops * imgsz[0] / stride * imgsz[1] / stride # 640x640 GFLOPs
return flops
@ -266,13 +352,15 @@ def initialize_weights(model):
m.inplace = True
def scale_img(img, ratio=1.0, same_shape=False, gs=32): # img(16,3,256,416)
# Scales img(bs,3,y,x) by ratio constrained to gs-multiple
def scale_img(img, ratio=1.0, same_shape=False, gs=32):
"""Scales and pads an image tensor of shape img(bs,3,y,x) based on given ratio and grid size gs, optionally
retaining the original shape.
"""
if ratio == 1.0:
return img
h, w = img.shape[2:]
s = (int(h * ratio), int(w * ratio)) # new size
img = F.interpolate(img, size=s, mode='bilinear', align_corners=False) # resize
img = F.interpolate(img, size=s, mode="bilinear", align_corners=False) # resize
if not same_shape: # pad/crop img
h, w = (math.ceil(x * ratio / gs) * gs for x in (h, w))
return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean
@ -288,7 +376,7 @@ def make_divisible(x, divisor):
def copy_attr(a, b, include=(), exclude=()):
"""Copies attributes from object 'b' to object 'a', with options to include/exclude certain attributes."""
for k, v in b.__dict__.items():
if (len(include) and k not in include) or k.startswith('_') or k in exclude:
if (len(include) and k not in include) or k.startswith("_") or k in exclude:
continue
else:
setattr(a, k, v)
@ -296,7 +384,7 @@ def copy_attr(a, b, include=(), exclude=()):
def get_latest_opset():
"""Return second-most (for maturity) recently supported ONNX opset by this version of torch."""
return max(int(k[14:]) for k in vars(torch.onnx) if 'symbolic_opset' in k) - 1 # opset
return max(int(k[14:]) for k in vars(torch.onnx) if "symbolic_opset" in k) - 1 # opset
def intersect_dicts(da, db, exclude=()):
@ -316,7 +404,7 @@ def de_parallel(model):
def one_cycle(y1=0.0, y2=1.0, steps=100):
"""Returns a lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf."""
return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
return lambda x: max((1 - math.cos(x * math.pi / steps)) / 2, 0) * (y2 - y1) + y1
def init_seeds(seed=0, deterministic=False):
@ -331,10 +419,10 @@ def init_seeds(seed=0, deterministic=False):
if TORCH_2_0:
torch.use_deterministic_algorithms(True, warn_only=True) # warn if deterministic is not possible
torch.backends.cudnn.deterministic = True
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
os.environ['PYTHONHASHSEED'] = str(seed)
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
os.environ["PYTHONHASHSEED"] = str(seed)
else:
LOGGER.warning('WARNING ⚠️ Upgrade to torch>=2.0.0 for deterministic training.')
LOGGER.warning("WARNING ⚠️ Upgrade to torch>=2.0.0 for deterministic training.")
else:
torch.use_deterministic_algorithms(False)
torch.backends.cudnn.deterministic = False
@ -369,13 +457,13 @@ class ModelEMA:
v += (1 - d) * msd[k].detach()
# assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype}, model {msd[k].dtype}'
def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
def update_attr(self, model, include=(), exclude=("process_group", "reducer")):
"""Updates attributes and saves stripped model with optimizer removed."""
if self.enabled:
copy_attr(self.ema, model, include, exclude)
def strip_optimizer(f: Union[str, Path] = 'best.pt', s: str = '') -> None:
def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "") -> None:
"""
Strip optimizer from 'f' to finalize training, optionally save as 's'.
@ -395,32 +483,26 @@ def strip_optimizer(f: Union[str, Path] = 'best.pt', s: str = '') -> None:
strip_optimizer(f)
```
"""
# Use dill (if exists) to serialize the lambda functions where pickle does not do this
try:
import dill as pickle
except ImportError:
import pickle
x = torch.load(f, map_location=torch.device('cpu'))
if 'model' not in x:
LOGGER.info(f'Skipping {f}, not a valid Ultralytics model.')
x = torch.load(f, map_location=torch.device("cpu"))
if "model" not in x:
LOGGER.info(f"Skipping {f}, not a valid Ultralytics model.")
return
if hasattr(x['model'], 'args'):
x['model'].args = dict(x['model'].args) # convert from IterableSimpleNamespace to dict
args = {**DEFAULT_CFG_DICT, **x['train_args']} if 'train_args' in x else None # combine args
if x.get('ema'):
x['model'] = x['ema'] # replace model with ema
for k in 'optimizer', 'best_fitness', 'ema', 'updates': # keys
if hasattr(x["model"], "args"):
x["model"].args = dict(x["model"].args) # convert from IterableSimpleNamespace to dict
args = {**DEFAULT_CFG_DICT, **x["train_args"]} if "train_args" in x else None # combine args
if x.get("ema"):
x["model"] = x["ema"] # replace model with ema
for k in "optimizer", "best_fitness", "ema", "updates": # keys
x[k] = None
x['epoch'] = -1
x['model'].half() # to FP16
for p in x['model'].parameters():
x["epoch"] = -1
x["model"].half() # to FP16
for p in x["model"].parameters():
p.requires_grad = False
x['train_args'] = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # strip non-default keys
x["train_args"] = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # strip non-default keys
# x['model'].args = x['train_args']
torch.save(x, s or f, pickle_module=pickle)
mb = os.path.getsize(s or f) / 1E6 # filesize
torch.save(x, s or f)
mb = os.path.getsize(s or f) / 1e6 # file size
LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
@ -441,18 +523,20 @@ def profile(input, ops, n=10, device=None):
results = []
if not isinstance(device, torch.device):
device = select_device(device)
LOGGER.info(f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
f"{'input':>24s}{'output':>24s}")
LOGGER.info(
f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
f"{'input':>24s}{'output':>24s}"
)
for x in input if isinstance(input, list) else [input]:
x = x.to(device)
x.requires_grad = True
for m in ops if isinstance(ops, list) else [ops]:
m = m.to(device) if hasattr(m, 'to') else m # device
m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m
m = m.to(device) if hasattr(m, "to") else m # device
m = m.half() if hasattr(m, "half") and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m
tf, tb, t = 0, 0, [0, 0, 0] # dt forward, backward
try:
flops = thop.profile(m, inputs=[x], verbose=False)[0] / 1E9 * 2 if thop else 0 # GFLOPs
flops = thop.profile(m, inputs=[x], verbose=False)[0] / 1e9 * 2 if thop else 0 # GFLOPs
except Exception:
flops = 0
@ -466,13 +550,13 @@ def profile(input, ops, n=10, device=None):
t[2] = time_sync()
except Exception: # no backward method
# print(e) # for debug
t[2] = float('nan')
t[2] = float("nan")
tf += (t[1] - t[0]) * 1000 / n # ms per op forward
tb += (t[2] - t[1]) * 1000 / n # ms per op backward
mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0 # (GB)
s_in, s_out = (tuple(x.shape) if isinstance(x, torch.Tensor) else 'list' for x in (x, y)) # shapes
mem = torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0 # (GB)
s_in, s_out = (tuple(x.shape) if isinstance(x, torch.Tensor) else "list" for x in (x, y)) # shapes
p = sum(x.numel() for x in m.parameters()) if isinstance(m, nn.Module) else 0 # parameters
LOGGER.info(f'{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}')
LOGGER.info(f"{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}")
results.append([p, flops, mem, tf, tb, s_in, s_out])
except Exception as e:
LOGGER.info(e)
@ -482,25 +566,23 @@ def profile(input, ops, n=10, device=None):
class EarlyStopping:
"""
Early stopping class that stops training when a specified number of epochs have passed without improvement.
"""
"""Early stopping class that stops training when a specified number of epochs have passed without improvement."""
def __init__(self, patience=50):
"""
Initialize early stopping object
Initialize early stopping object.
Args:
patience (int, optional): Number of epochs to wait after fitness stops improving before stopping.
"""
self.best_fitness = 0.0 # i.e. mAP
self.best_epoch = 0
self.patience = patience or float('inf') # epochs to wait after fitness stops improving to stop
self.patience = patience or float("inf") # epochs to wait after fitness stops improving to stop
self.possible_stop = False # possible stop may occur next epoch
def __call__(self, epoch, fitness):
"""
Check whether to stop training
Check whether to stop training.
Args:
epoch (int): Current epoch of training
@ -519,8 +601,10 @@ class EarlyStopping:
self.possible_stop = delta >= (self.patience - 1) # possible stop may occur next epoch
stop = delta >= self.patience # stop training if patience exceeded
if stop:
LOGGER.info(f'Stopping training early as no improvement observed in last {self.patience} epochs. '
f'Best results observed at epoch {self.best_epoch}, best model saved as best.pt.\n'
f'To update EarlyStopping(patience={self.patience}) pass a new patience value, '
f'i.e. `patience=300` or use `patience=0` to disable EarlyStopping.')
LOGGER.info(
f"Stopping training early as no improvement observed in last {self.patience} epochs. "
f"Best results observed at epoch {self.best_epoch}, best model saved as best.pt.\n"
f"To update EarlyStopping(patience={self.patience}) pass a new patience value, "
f"i.e. `patience=300` or use `patience=0` to disable EarlyStopping."
)
return stop

View File

@ -0,0 +1,92 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from typing import List
from urllib.parse import urlsplit
import numpy as np
class TritonRemoteModel:
    """
    Thin client wrapper around a model served by a remote Triton Inference Server.

    Attributes:
        endpoint (str): The name of the model on the Triton server.
        url (str): The URL of the Triton server.
        triton_client: The Triton client (either HTTP or gRPC).
        InferInput: The input class for the Triton client.
        InferRequestedOutput: The output request class for the Triton client.
        input_formats (List[str]): The data types of the model inputs.
        np_input_formats (List[type]): The numpy data types of the model inputs.
        input_names (List[str]): The names of the model inputs.
        output_names (List[str]): The names of the model outputs.
    """

    def __init__(self, url: str, endpoint: str = "", scheme: str = ""):
        """
        Connect to the Triton server and read the model's input/output configuration.

        Arguments may be provided individually or parsed from a collective 'url' argument of the form
            <scheme>://<netloc>/<endpoint>/<task_name>

        Args:
            url (str): The URL of the Triton server.
            endpoint (str): The name of the model on the Triton server.
            scheme (str): The communication scheme ('http' or 'grpc').
        """
        if not (endpoint or scheme):  # nothing given explicitly, so take every part from the URL string
            parsed = urlsplit(url)
            endpoint = parsed.path.strip("/").split("/")[0]
            scheme = parsed.scheme
            url = parsed.netloc

        self.endpoint = endpoint
        self.url = url

        # Any scheme other than 'http' falls through to the gRPC client
        use_http = scheme == "http"
        if use_http:
            import tritonclient.http as client  # noqa
        else:
            import tritonclient.grpc as client  # noqa

        self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False)
        if use_http:
            config = self.triton_client.get_model_config(endpoint)
        else:
            config = self.triton_client.get_model_config(endpoint, as_json=True)["config"]

        # Sort output names alphabetically, i.e. 'output0', 'output1', etc.
        config["output"] = sorted(config["output"], key=lambda entry: entry.get("name"))

        # Triton data-type strings mapped to their numpy equivalents (only these three are supported)
        type_map = {"TYPE_FP32": np.float32, "TYPE_FP16": np.float16, "TYPE_UINT8": np.uint8}
        model_inputs = config["input"]
        self.InferRequestedOutput = client.InferRequestedOutput
        self.InferInput = client.InferInput
        self.input_formats = [spec["data_type"] for spec in model_inputs]
        self.np_input_formats = [type_map[fmt] for fmt in self.input_formats]
        self.input_names = [spec["name"] for spec in model_inputs]
        self.output_names = [spec["name"] for spec in config["output"]]

    def __call__(self, *inputs: np.ndarray) -> List[np.ndarray]:
        """
        Run inference on the remote model.

        Each input is cast to the dtype the model expects before sending; every output is cast back to the
        dtype of the first input.

        Args:
            *inputs (np.ndarray): Input arrays, in the order of the model's inputs.

        Returns:
            (List[np.ndarray]): Model outputs, ordered by output name.
        """
        input_format = inputs[0].dtype  # remembered so the outputs can be returned in the caller's dtype
        infer_inputs = []
        for idx, array in enumerate(inputs):
            expected_dtype = self.np_input_formats[idx]
            if array.dtype != expected_dtype:
                array = array.astype(expected_dtype)
            triton_dtype = self.input_formats[idx].replace("TYPE_", "")
            request_input = self.InferInput(self.input_names[idx], [*array.shape], triton_dtype)
            request_input.set_data_from_numpy(array)
            infer_inputs.append(request_input)

        infer_outputs = [self.InferRequestedOutput(name) for name in self.output_names]
        outputs = self.triton_client.infer(model_name=self.endpoint, inputs=infer_inputs, outputs=infer_outputs)

        return [outputs.as_numpy(name).astype(input_format) for name in self.output_names]

View File

@ -2,16 +2,13 @@
import subprocess
from ultralytics.cfg import TASK2DATA, TASK2METRIC
from ultralytics.utils import DEFAULT_CFG_DICT, LOGGER, NUM_THREADS
from ultralytics.cfg import TASK2DATA, TASK2METRIC, get_save_dir
from ultralytics.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, NUM_THREADS, checks
def run_ray_tune(model,
space: dict = None,
grace_period: int = 10,
gpu_per_trial: int = None,
max_samples: int = 10,
**train_args):
def run_ray_tune(
model, space: dict = None, grace_period: int = 10, gpu_per_trial: int = None, max_samples: int = 10, **train_args
):
"""
Runs hyperparameter tuning using Ray Tune.
@ -37,49 +34,59 @@ def run_ray_tune(model,
result_grid = model.tune(data='coco8.yaml', use_ray=True)
```
"""
LOGGER.info("💡 Learn about RayTune at https://docs.ultralytics.com/integrations/ray-tune")
if train_args is None:
train_args = {}
try:
subprocess.run('pip install ray[tune]'.split(), check=True)
subprocess.run("pip install ray[tune]<=2.9.3".split(), check=True) # do not add single quotes here
import ray
from ray import tune
from ray.air import RunConfig
from ray.air.integrations.wandb import WandbLoggerCallback
from ray.tune.schedulers import ASHAScheduler
except ImportError:
raise ModuleNotFoundError('Tuning hyperparameters requires Ray Tune. Install with: pip install "ray[tune]"')
raise ModuleNotFoundError('Ray Tune required but not found. To install run: pip install "ray[tune]<=2.9.3"')
try:
import wandb
assert hasattr(wandb, '__version__')
assert hasattr(wandb, "__version__")
except (ImportError, AssertionError):
wandb = False
checks.check_version(ray.__version__, "<=2.9.3", "ray")
default_space = {
# 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),
'lr0': tune.uniform(1e-5, 1e-1),
'lrf': tune.uniform(0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
'momentum': tune.uniform(0.6, 0.98), # SGD momentum/Adam beta1
'weight_decay': tune.uniform(0.0, 0.001), # optimizer weight decay 5e-4
'warmup_epochs': tune.uniform(0.0, 5.0), # warmup epochs (fractions ok)
'warmup_momentum': tune.uniform(0.0, 0.95), # warmup initial momentum
'box': tune.uniform(0.02, 0.2), # box loss gain
'cls': tune.uniform(0.2, 4.0), # cls loss gain (scale with pixels)
'hsv_h': tune.uniform(0.0, 0.1), # image HSV-Hue augmentation (fraction)
'hsv_s': tune.uniform(0.0, 0.9), # image HSV-Saturation augmentation (fraction)
'hsv_v': tune.uniform(0.0, 0.9), # image HSV-Value augmentation (fraction)
'degrees': tune.uniform(0.0, 45.0), # image rotation (+/- deg)
'translate': tune.uniform(0.0, 0.9), # image translation (+/- fraction)
'scale': tune.uniform(0.0, 0.9), # image scale (+/- gain)
'shear': tune.uniform(0.0, 10.0), # image shear (+/- deg)
'perspective': tune.uniform(0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
'flipud': tune.uniform(0.0, 1.0), # image flip up-down (probability)
'fliplr': tune.uniform(0.0, 1.0), # image flip left-right (probability)
'mosaic': tune.uniform(0.0, 1.0), # image mixup (probability)
'mixup': tune.uniform(0.0, 1.0), # image mixup (probability)
'copy_paste': tune.uniform(0.0, 1.0)} # segment copy-paste (probability)
"lr0": tune.uniform(1e-5, 1e-1),
"lrf": tune.uniform(0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
"momentum": tune.uniform(0.6, 0.98), # SGD momentum/Adam beta1
"weight_decay": tune.uniform(0.0, 0.001), # optimizer weight decay 5e-4
"warmup_epochs": tune.uniform(0.0, 5.0), # warmup epochs (fractions ok)
"warmup_momentum": tune.uniform(0.0, 0.95), # warmup initial momentum
"box": tune.uniform(0.02, 0.2), # box loss gain
"cls": tune.uniform(0.2, 4.0), # cls loss gain (scale with pixels)
"hsv_h": tune.uniform(0.0, 0.1), # image HSV-Hue augmentation (fraction)
"hsv_s": tune.uniform(0.0, 0.9), # image HSV-Saturation augmentation (fraction)
"hsv_v": tune.uniform(0.0, 0.9), # image HSV-Value augmentation (fraction)
"degrees": tune.uniform(0.0, 45.0), # image rotation (+/- deg)
"translate": tune.uniform(0.0, 0.9), # image translation (+/- fraction)
"scale": tune.uniform(0.0, 0.9), # image scale (+/- gain)
"shear": tune.uniform(0.0, 10.0), # image shear (+/- deg)
"perspective": tune.uniform(0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
"flipud": tune.uniform(0.0, 1.0), # image flip up-down (probability)
"fliplr": tune.uniform(0.0, 1.0), # image flip left-right (probability)
"bgr": tune.uniform(0.0, 1.0), # image channel BGR (probability)
"mosaic": tune.uniform(0.0, 1.0), # image mixup (probability)
"mixup": tune.uniform(0.0, 1.0), # image mixup (probability)
"copy_paste": tune.uniform(0.0, 1.0), # segment copy-paste (probability)
}
# Put the model in ray store
task = model.task
model_in_store = ray.put(model)
def _tune(config):
"""
@ -89,42 +96,50 @@ def run_ray_tune(model,
config (dict): A dictionary of hyperparameters to use for training.
Returns:
None.
None
"""
model._reset_callbacks()
model_to_train = ray.get(model_in_store) # get the model from ray store for tuning
model_to_train.reset_callbacks()
config.update(train_args)
model.train(**config)
results = model_to_train.train(**config)
return results.results_dict
# Get search space
if not space:
space = default_space
LOGGER.warning('WARNING ⚠️ search space not provided, using default search space.')
LOGGER.warning("WARNING ⚠️ search space not provided, using default search space.")
# Get dataset
data = train_args.get('data', TASK2DATA[model.task])
space['data'] = data
if 'data' not in train_args:
data = train_args.get("data", TASK2DATA[task])
space["data"] = data
if "data" not in train_args:
LOGGER.warning(f'WARNING ⚠️ data not provided, using default "data={data}".')
# Define the trainable function with allocated resources
trainable_with_resources = tune.with_resources(_tune, {'cpu': NUM_THREADS, 'gpu': gpu_per_trial or 0})
trainable_with_resources = tune.with_resources(_tune, {"cpu": NUM_THREADS, "gpu": gpu_per_trial or 0})
# Define the ASHA scheduler for hyperparameter search
asha_scheduler = ASHAScheduler(time_attr='epoch',
metric=TASK2METRIC[model.task],
mode='max',
max_t=train_args.get('epochs') or DEFAULT_CFG_DICT['epochs'] or 100,
grace_period=grace_period,
reduction_factor=3)
asha_scheduler = ASHAScheduler(
time_attr="epoch",
metric=TASK2METRIC[task],
mode="max",
max_t=train_args.get("epochs") or DEFAULT_CFG_DICT["epochs"] or 100,
grace_period=grace_period,
reduction_factor=3,
)
# Define the callbacks for the hyperparameter search
tuner_callbacks = [WandbLoggerCallback(project='YOLOv8-tune')] if wandb else []
tuner_callbacks = [WandbLoggerCallback(project="YOLOv8-tune")] if wandb else []
# Create the Ray Tune hyperparameter search tuner
tuner = tune.Tuner(trainable_with_resources,
param_space=space,
tune_config=tune.TuneConfig(scheduler=asha_scheduler, num_samples=max_samples),
run_config=RunConfig(callbacks=tuner_callbacks, storage_path='./runs/tune'))
tune_dir = get_save_dir(DEFAULT_CFG, name="tune").resolve() # must be absolute dir
tune_dir.mkdir(parents=True, exist_ok=True)
tuner = tune.Tuner(
trainable_with_resources,
param_space=space,
tune_config=tune.TuneConfig(scheduler=asha_scheduler, num_samples=max_samples),
run_config=RunConfig(callbacks=tuner_callbacks, storage_path=tune_dir),
)
# Run the hyperparameter search
tuner.fit()