Add YOLOv10 and modify pipeline
ultralytics/utils/__init__.py
@@ -9,6 +9,7 @@ import re
 import subprocess
 import sys
 import threading
+import time
 import urllib
 import uuid
 from pathlib import Path
@@ -25,23 +26,22 @@ from tqdm import tqdm as tqdm_original
 from ultralytics import __version__

 # PyTorch Multi-GPU DDP Constants
-RANK = int(os.getenv('RANK', -1))
-LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # https://pytorch.org/docs/stable/elastic/run.html
+RANK = int(os.getenv("RANK", -1))
+LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1))  # https://pytorch.org/docs/stable/elastic/run.html

 # Other Constants
 FILE = Path(__file__).resolve()
 ROOT = FILE.parents[1]  # YOLO
-ASSETS = ROOT / 'assets'  # default images
-DEFAULT_CFG_PATH = ROOT / 'cfg/default.yaml'
+ASSETS = ROOT / "assets"  # default images
+DEFAULT_CFG_PATH = ROOT / "cfg/default.yaml"
 NUM_THREADS = min(8, max(1, os.cpu_count() - 1))  # number of YOLOv5 multiprocessing threads
-AUTOINSTALL = str(os.getenv('YOLO_AUTOINSTALL', True)).lower() == 'true'  # global auto-install mode
-VERBOSE = str(os.getenv('YOLO_VERBOSE', True)).lower() == 'true'  # global verbose mode
-TQDM_BAR_FORMAT = '{l_bar}{bar:10}{r_bar}' if VERBOSE else None  # tqdm bar format
-LOGGING_NAME = 'ultralytics'
-MACOS, LINUX, WINDOWS = (platform.system() == x for x in ['Darwin', 'Linux', 'Windows'])  # environment booleans
-ARM64 = platform.machine() in ('arm64', 'aarch64')  # ARM64 booleans
-HELP_MSG = \
-    """
+AUTOINSTALL = str(os.getenv("YOLO_AUTOINSTALL", True)).lower() == "true"  # global auto-install mode
+VERBOSE = str(os.getenv("YOLO_VERBOSE", True)).lower() == "true"  # global verbose mode
+TQDM_BAR_FORMAT = "{l_bar}{bar:10}{r_bar}" if VERBOSE else None  # tqdm bar format
+LOGGING_NAME = "ultralytics"
+MACOS, LINUX, WINDOWS = (platform.system() == x for x in ["Darwin", "Linux", "Windows"])  # environment booleans
+ARM64 = platform.machine() in ("arm64", "aarch64")  # ARM64 booleans
+HELP_MSG = """
     Usage examples for running YOLOv8:

    1. Install the ultralytics package:
@@ -77,7 +77,7 @@ HELP_MSG = \
        yolo detect train data=coco128.yaml model=yolov8n.pt epochs=10 lr0=0.01

    - Predict a YouTube video using a pretrained segmentation model at image size 320:
-        yolo segment predict model=yolov8n-seg.pt source='https://youtu.be/Zgi9g1ksQHc' imgsz=320
+        yolo segment predict model=yolov8n-seg.pt source='https://youtu.be/LNwODJXcvt4' imgsz=320

    - Val a pretrained detection model at batch-size 1 and image size 640:
        yolo detect val model=yolov8n.pt data=coco128.yaml batch=1 imgsz=640
@@ -99,12 +99,12 @@ HELP_MSG = \
    """

 # Settings
-torch.set_printoptions(linewidth=320, precision=4, profile='default')
-np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format})  # format short g, %precision=5
+torch.set_printoptions(linewidth=320, precision=4, profile="default")
+np.set_printoptions(linewidth=320, formatter={"float_kind": "{:11.5g}".format})  # format short g, %precision=5
 cv2.setNumThreads(0)  # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
-os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS)  # NumExpr max threads
-os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'  # for deterministic training
-os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # suppress verbose TF compiler warnings in Colab
+os.environ["NUMEXPR_MAX_THREADS"] = str(NUM_THREADS)  # NumExpr max threads
+os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"  # for deterministic training
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"  # suppress verbose TF compiler warnings in Colab


 class TQDM(tqdm_original):
@@ -113,19 +113,22 @@ class TQDM(tqdm_original):

     Args:
         *args (list): Positional arguments passed to original tqdm.
-        **kwargs (dict): Keyword arguments, with custom defaults applied.
+        **kwargs (any): Keyword arguments, with custom defaults applied.
     """

     def __init__(self, *args, **kwargs):
-        # Set new default values (these can still be overridden when calling TQDM)
-        kwargs['disable'] = not VERBOSE or kwargs.get('disable', False)  # logical 'and' with default value if passed
-        kwargs.setdefault('bar_format', TQDM_BAR_FORMAT)  # override default value if passed
+        """
+        Initialize custom Ultralytics tqdm class with different default arguments.
+
+        Note these can still be overridden when calling TQDM.
+        """
+        kwargs["disable"] = not VERBOSE or kwargs.get("disable", False)  # logical 'and' with default value if passed
+        kwargs.setdefault("bar_format", TQDM_BAR_FORMAT)  # override default value if passed
         super().__init__(*args, **kwargs)
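Note: the TQDM change above is organizational (comments become a docstring); the injected defaults behave as before. A minimal usage sketch, assuming the ultralytics package is importable, with a placeholder loop body:

    from ultralytics.utils import TQDM

    for _ in TQDM(range(100), desc="processing"):  # disable/bar_format defaults applied automatically
        pass  # placeholder work

    TQDM(range(10), disable=True)  # explicitly passed kwargs still override the defaults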
 class SimpleClass:
-    """
-    Ultralytics SimpleClass is a base class providing helpful string representation, error reporting, and attribute
+    """Ultralytics SimpleClass is a base class providing helpful string representation, error reporting, and attribute
     access methods for easier debugging and usage.
     """

@@ -134,14 +137,14 @@ class SimpleClass:
         attr = []
         for a in dir(self):
             v = getattr(self, a)
-            if not callable(v) and not a.startswith('_'):
+            if not callable(v) and not a.startswith("_"):
                 if isinstance(v, SimpleClass):
                     # Display only the module and class name for subclasses
-                    s = f'{a}: {v.__module__}.{v.__class__.__name__} object'
+                    s = f"{a}: {v.__module__}.{v.__class__.__name__} object"
                 else:
-                    s = f'{a}: {repr(v)}'
+                    s = f"{a}: {repr(v)}"
                 attr.append(s)
-        return f'{self.__module__}.{self.__class__.__name__} object with attributes:\n\n' + '\n'.join(attr)
+        return f"{self.__module__}.{self.__class__.__name__} object with attributes:\n\n" + "\n".join(attr)

     def __repr__(self):
         """Return a machine-readable string representation of the object."""
@@ -154,8 +157,7 @@ class SimpleClass:


 class IterableSimpleNamespace(SimpleNamespace):
-    """
-    Ultralytics IterableSimpleNamespace is an extension class of SimpleNamespace that adds iterable functionality and
+    """Ultralytics IterableSimpleNamespace is an extension class of SimpleNamespace that adds iterable functionality and
     enables usage with dict() and for loops.
     """

@@ -165,24 +167,26 @@ class IterableSimpleNamespace(SimpleNamespace):

     def __str__(self):
         """Return a human-readable string representation of the object."""
-        return '\n'.join(f'{k}={v}' for k, v in vars(self).items())
+        return "\n".join(f"{k}={v}" for k, v in vars(self).items())

     def __getattr__(self, attr):
         """Custom attribute access error message with helpful information."""
         name = self.__class__.__name__
-        raise AttributeError(f"""
+        raise AttributeError(
+            f"""
            '{name}' object has no attribute '{attr}'. This may be caused by a modified or out of date ultralytics
            'default.yaml' file.\nPlease update your code with 'pip install -U ultralytics' and if necessary replace
            {DEFAULT_CFG_PATH} with the latest version from
            https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/default.yaml
-            """)
+            """
+        )

     def get(self, key, default=None):
         """Return the value of the specified key if it exists; otherwise, return the default value."""
         return getattr(self, key, default)
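For context, the iterability this class adds on top of SimpleNamespace is what lets DEFAULT_CFG below be looped over or converted with dict(). A short sketch, assuming __iter__ yields (key, value) pairs as in the ultralytics source; the keys are hypothetical:

    from ultralytics.utils import IterableSimpleNamespace

    ns = IterableSimpleNamespace(epochs=100, imgsz=640)  # hypothetical keys
    print(dict(ns))             # {'epochs': 100, 'imgsz': 640}
    for k, v in ns:             # iteration enabled by __iter__
        print(f"{k}={v}")
    print(ns.get("lr0", 0.01))  # 0.01 returned as the default for a missing key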


-def plt_settings(rcparams=None, backend='Agg'):
+def plt_settings(rcparams=None, backend="Agg"):
     """
     Decorator to temporarily set rc parameters and the backend for a plotting function.

@@ -200,7 +204,7 @@ def plt_settings(rcparams=None, backend='Agg'):
     """

     if rcparams is None:
-        rcparams = {'font.size': 11}
+        rcparams = {"font.size": 11}

     def decorator(func):
         """Decorator to apply temporary rc parameters and backend to a function."""
@@ -208,12 +212,16 @@ def plt_settings(rcparams=None, backend='Agg'):
         def wrapper(*args, **kwargs):
             """Sets rc parameters and backend, calls the original function, and restores the settings."""
             original_backend = plt.get_backend()
-            plt.switch_backend(backend)
+            if backend.lower() != original_backend.lower():
+                plt.close("all")  # auto-close()ing of figures upon backend switching is deprecated since 3.8
+                plt.switch_backend(backend)

             with plt.rc_context(rcparams):
                 result = func(*args, **kwargs)

-            plt.switch_backend(original_backend)
+            if backend != original_backend:
+                plt.close("all")
+                plt.switch_backend(original_backend)
             return result

         return wrapper
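The reworked wrapper switches backends only when the requested backend differs from the current one, closing all figures around each switch because auto-closing on backend switches is deprecated since Matplotlib 3.8. Caller-side usage is unchanged; a minimal sketch with a hypothetical plotting helper:

    from ultralytics.utils import plt_settings

    @plt_settings({"font.size": 8}, backend="Agg")  # temporary rcParams and backend
    def save_plot(path="plot.png"):  # hypothetical helper
        import matplotlib.pyplot as plt

        plt.plot([0, 1], [0, 1])
        plt.savefig(path)

    save_plot()  # runs under Agg; the previous backend is restored afterwards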
@@ -222,58 +230,59 @@ def plt_settings(rcparams=None, backend='Agg'):


 def set_logging(name=LOGGING_NAME, verbose=True):
-    """Sets up logging for the given name."""
-    rank = int(os.getenv('RANK', -1))  # rank in world for Multi-GPU trainings
-    level = logging.INFO if verbose and rank in {-1, 0} else logging.ERROR
-    logging.config.dictConfig({
-        'version': 1,
-        'disable_existing_loggers': False,
-        'formatters': {
-            name: {
-                'format': '%(message)s'}},
-        'handlers': {
-            name: {
-                'class': 'logging.StreamHandler',
-                'formatter': name,
-                'level': level}},
-        'loggers': {
-            name: {
-                'level': level,
-                'handlers': [name],
-                'propagate': False}}})
-
-
-def emojis(string=''):
-    """Return platform-dependent emoji-safe version of string."""
-    return string.encode().decode('ascii', 'ignore') if WINDOWS else string
-
-
-class EmojiFilter(logging.Filter):
-    """
-    A custom logging filter class for removing emojis in log messages.
-
-    This filter is particularly useful for ensuring compatibility with Windows terminals
-    that may not support the display of emojis in log messages.
-    """
-
-    def filter(self, record):
-        """Filter logs by emoji unicode characters on windows."""
-        record.msg = emojis(record.msg)
-        return super().filter(record)
+    """Sets up logging for the given name with UTF-8 encoding support."""
+    level = logging.INFO if verbose and RANK in {-1, 0} else logging.ERROR  # rank in world for Multi-GPU trainings
+
+    # Configure the console (stdout) encoding to UTF-8
+    formatter = logging.Formatter("%(message)s")  # Default formatter
+    if WINDOWS and sys.stdout.encoding != "utf-8":
+        try:
+            if hasattr(sys.stdout, "reconfigure"):
+                sys.stdout.reconfigure(encoding="utf-8")
+            elif hasattr(sys.stdout, "buffer"):
+                import io
+
+                sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8")
+            else:
+                sys.stdout.encoding = "utf-8"
+        except Exception as e:
+            print(f"Creating custom formatter for non UTF-8 environments due to {e}")
+
+            class CustomFormatter(logging.Formatter):
+                def format(self, record):
+                    """Sets up logging with UTF-8 encoding and configurable verbosity."""
+                    return emojis(super().format(record))
+
+            formatter = CustomFormatter("%(message)s")  # Use CustomFormatter to eliminate UTF-8 output as last recourse
+
+    # Create and configure the StreamHandler
+    stream_handler = logging.StreamHandler(sys.stdout)
+    stream_handler.setFormatter(formatter)
+    stream_handler.setLevel(level)
+
+    logger = logging.getLogger(name)
+    logger.setLevel(level)
+    logger.addHandler(stream_handler)
+    logger.propagate = False
+    return logger


 # Set logger
-set_logging(LOGGING_NAME, verbose=VERBOSE)  # run before defining LOGGER
-LOGGER = logging.getLogger(LOGGING_NAME)  # define globally (used in train.py, val.py, detect.py, etc.)
-if WINDOWS:  # emoji-safe logging
-    LOGGER.addFilter(EmojiFilter())
+LOGGER = set_logging(LOGGING_NAME, verbose=VERBOSE)  # define globally (used in train.py, val.py, predict.py, etc.)
+for logger in "sentry_sdk", "urllib3.connectionpool":
+    logging.getLogger(logger).setLevel(logging.CRITICAL + 1)
+
+
+def emojis(string=""):
+    """Return platform-dependent emoji-safe version of string."""
+    return string.encode().decode("ascii", "ignore") if WINDOWS else string
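The rewrite replaces the old dictConfig-plus-EmojiFilter setup with one explicitly built StreamHandler: stdout is reconfigured to UTF-8 on Windows where possible, and the emoji-stripping CustomFormatter is only a last-resort fallback. set_logging now returns the logger, which the module binds directly to LOGGER. Downstream usage is unchanged, for example:

    from ultralytics.utils import LOGGER

    LOGGER.info("training started ✅")    # emoji-safe on Windows via UTF-8 stdout or the fallback formatter
    LOGGER.warning("WARNING ⚠️ example")  # level is rank-aware: ERROR-only on non-zero DDP ranks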
 class ThreadingLocked:
     """
-    A decorator class for ensuring thread-safe execution of a function or method.
-    This class can be used as a decorator to make sure that if the decorated function
-    is called from multiple threads, only one thread at a time will be able to execute the function.
+    A decorator class for ensuring thread-safe execution of a function or method. This class can be used as a decorator
+    to make sure that if the decorated function is called from multiple threads, only one thread at a time will be able
+    to execute the function.

     Attributes:
         lock (threading.Lock): A lock object used to manage access to the decorated function.
@@ -290,20 +299,23 @@ class ThreadingLocked:
     """

     def __init__(self):
         """Initializes the decorator class for thread-safe execution of a function or method."""
         self.lock = threading.Lock()

     def __call__(self, f):
         """Run thread-safe execution of function or method."""
         from functools import wraps

         @wraps(f)
         def decorated(*args, **kwargs):
             """Applies thread-safety to the decorated function or method."""
             with self.lock:
                 return f(*args, **kwargs)

         return decorated
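A minimal sketch of the decorator in use; the function body and thread count are placeholders:

    import threading

    from ultralytics.utils import ThreadingLocked

    @ThreadingLocked()  # one shared lock per decorated function
    def update_shared_state():
        pass  # placeholder critical section

    threads = [threading.Thread(target=update_shared_state) for _ in range(4)]
    for t in threads:
        t.start()  # concurrent calls serialize on the shared lock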
-def yaml_save(file='data.yaml', data=None, header=''):
+def yaml_save(file="data.yaml", data=None, header=""):
     """
     Save YAML data to a file.

@@ -323,18 +335,19 @@ def yaml_save(file='data.yaml', data=None, header=''):
         file.parent.mkdir(parents=True, exist_ok=True)

     # Convert Path objects to strings
+    valid_types = int, float, str, bool, list, tuple, dict, type(None)
     for k, v in data.items():
-        if isinstance(v, Path):
+        if not isinstance(v, valid_types):
             data[k] = str(v)

     # Dump data to file in YAML format
-    with open(file, 'w', errors='ignore', encoding='utf-8') as f:
+    with open(file, "w", errors="ignore", encoding="utf-8") as f:
         if header:
             f.write(header)
         yaml.safe_dump(data, f, sort_keys=False, allow_unicode=True)


-def yaml_load(file='data.yaml', append_filename=False):
+def yaml_load(file="data.yaml", append_filename=False):
     """
     Load YAML data from a file.

@@ -345,18 +358,18 @@ def yaml_load(file='data.yaml', append_filename=False):
     Returns:
         (dict): YAML data and file name.
     """
-    assert Path(file).suffix in ('.yaml', '.yml'), f'Attempting to load non-YAML file {file} with yaml_load()'
-    with open(file, errors='ignore', encoding='utf-8') as f:
+    assert Path(file).suffix in (".yaml", ".yml"), f"Attempting to load non-YAML file {file} with yaml_load()"
+    with open(file, errors="ignore", encoding="utf-8") as f:
         s = f.read()  # string

         # Remove special characters
         if not s.isprintable():
-            s = re.sub(r'[^\x09\x0A\x0D\x20-\x7E\x85\xA0-\uD7FF\uE000-\uFFFD\U00010000-\U0010ffff]+', '', s)
+            s = re.sub(r"[^\x09\x0A\x0D\x20-\x7E\x85\xA0-\uD7FF\uE000-\uFFFD\U00010000-\U0010ffff]+", "", s)

         # Add YAML filename to dict and return
         data = yaml.safe_load(s) or {}  # always return a dict (yaml.safe_load() may return None for empty files)
         if append_filename:
-            data['yaml_file'] = str(file)
+            data["yaml_file"] = str(file)
         return data
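The valid_types change broadens yaml_save: anything YAML cannot represent natively (not just Path) is now stringified before dumping. A round-trip sketch with an illustrative file name:

    from pathlib import Path

    from ultralytics.utils import yaml_load, yaml_save

    data = {"model": Path("yolov8n.pt"), "imgsz": 640}  # Path is not natively YAML-serializable
    yaml_save("example.yaml", data)                     # Path coerced to str on save
    loaded = yaml_load("example.yaml", append_filename=True)
    assert loaded["model"] == "yolov8n.pt" and loaded["yaml_file"] == "example.yaml"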
@@ -368,7 +381,7 @@ def yaml_print(yaml_file: Union[str, Path, dict]) -> None:
         yaml_file: The file path of the YAML file or a YAML-formatted dictionary.

     Returns:
-        None
+        (None)
     """
     yaml_dict = yaml_load(yaml_file) if isinstance(yaml_file, (str, Path)) else yaml_file
     dump = yaml.dump(yaml_dict, sort_keys=False, allow_unicode=True)
@@ -378,7 +391,7 @@ def yaml_print(yaml_file: Union[str, Path, dict]) -> None:
 # Default configuration
 DEFAULT_CFG_DICT = yaml_load(DEFAULT_CFG_PATH)
 for k, v in DEFAULT_CFG_DICT.items():
-    if isinstance(v, str) and v.lower() == 'none':
+    if isinstance(v, str) and v.lower() == "none":
         DEFAULT_CFG_DICT[k] = None
 DEFAULT_CFG_KEYS = DEFAULT_CFG_DICT.keys()
 DEFAULT_CFG = IterableSimpleNamespace(**DEFAULT_CFG_DICT)
@@ -392,8 +405,8 @@ def is_ubuntu() -> bool:
         (bool): True if OS is Ubuntu, False otherwise.
     """
     with contextlib.suppress(FileNotFoundError):
-        with open('/etc/os-release') as f:
-            return 'ID=ubuntu' in f.read()
+        with open("/etc/os-release") as f:
+            return "ID=ubuntu" in f.read()
     return False


@@ -404,7 +417,7 @@ def is_colab():
     Returns:
         (bool): True if running inside a Colab notebook, False otherwise.
     """
-    return 'COLAB_RELEASE_TAG' in os.environ or 'COLAB_BACKEND_VERSION' in os.environ
+    return "COLAB_RELEASE_TAG" in os.environ or "COLAB_BACKEND_VERSION" in os.environ


 def is_kaggle():
@@ -414,19 +427,19 @@ def is_kaggle():
     Returns:
         (bool): True if running inside a Kaggle kernel, False otherwise.
     """
-    return os.environ.get('PWD') == '/kaggle/working' and os.environ.get('KAGGLE_URL_BASE') == 'https://www.kaggle.com'
+    return os.environ.get("PWD") == "/kaggle/working" and os.environ.get("KAGGLE_URL_BASE") == "https://www.kaggle.com"


 def is_jupyter():
     """
-    Check if the current script is running inside a Jupyter Notebook.
-    Verified on Colab, Jupyterlab, Kaggle, Paperspace.
+    Check if the current script is running inside a Jupyter Notebook. Verified on Colab, Jupyterlab, Kaggle, Paperspace.

     Returns:
         (bool): True if running inside a Jupyter Notebook, False otherwise.
     """
     with contextlib.suppress(Exception):
         from IPython import get_ipython

         return get_ipython() is not None
     return False

@@ -438,10 +451,10 @@ def is_docker() -> bool:
     Returns:
         (bool): True if the script is running inside a Docker container, False otherwise.
     """
-    file = Path('/proc/self/cgroup')
+    file = Path("/proc/self/cgroup")
     if file.exists():
         with open(file) as f:
-            return 'docker' in f.read()
+            return "docker" in f.read()
     else:
         return False

@@ -455,7 +468,7 @@ def is_online() -> bool:
     """
     import socket

-    for host in '1.1.1.1', '8.8.8.8', '223.5.5.5':  # Cloudflare, Google, AliDNS:
+    for host in "1.1.1.1", "8.8.8.8", "223.5.5.5":  # Cloudflare, Google, AliDNS:
         try:
             test_connection = socket.create_connection(address=(host, 53), timeout=2)
         except (socket.timeout, socket.gaierror, OSError):
@@ -509,23 +522,23 @@ def is_pytest_running():
     Returns:
         (bool): True if pytest is running, False otherwise.
     """
-    return ('PYTEST_CURRENT_TEST' in os.environ) or ('pytest' in sys.modules) or ('pytest' in Path(sys.argv[0]).stem)
+    return ("PYTEST_CURRENT_TEST" in os.environ) or ("pytest" in sys.modules) or ("pytest" in Path(sys.argv[0]).stem)


-def is_github_actions_ci() -> bool:
+def is_github_action_running() -> bool:
     """
-    Determine if the current environment is a GitHub Actions CI Python runner.
+    Determine if the current environment is a GitHub Actions runner.

     Returns:
-        (bool): True if the current environment is a GitHub Actions CI Python runner, False otherwise.
+        (bool): True if the current environment is a GitHub Actions runner, False otherwise.
     """
-    return 'GITHUB_ACTIONS' in os.environ and 'RUNNER_OS' in os.environ and 'RUNNER_TOOL_CACHE' in os.environ
+    return "GITHUB_ACTIONS" in os.environ and "GITHUB_WORKFLOW" in os.environ and "RUNNER_OS" in os.environ


 def is_git_dir():
     """
-    Determines whether the current file is part of a git repository.
-    If the current file is not part of a git repository, returns None.
+    Determines whether the current file is part of a git repository. If the current file is not part of a git
+    repository, returns None.

     Returns:
         (bool): True if current file is part of a git repository.
@@ -535,14 +548,14 @@ def is_git_dir():

 def get_git_dir():
     """
-    Determines whether the current file is part of a git repository and if so, returns the repository root directory.
-    If the current file is not part of a git repository, returns None.
+    Determines whether the current file is part of a git repository and if so, returns the repository root directory. If
+    the current file is not part of a git repository, returns None.

     Returns:
         (Path | None): Git root directory if found or None if not found.
     """
     for d in Path(__file__).parents:
-        if (d / '.git').is_dir():
+        if (d / ".git").is_dir():
             return d


@@ -555,7 +568,7 @@ def get_git_origin_url():
     """
     if is_git_dir():
         with contextlib.suppress(subprocess.CalledProcessError):
-            origin = subprocess.check_output(['git', 'config', '--get', 'remote.origin.url'])
+            origin = subprocess.check_output(["git", "config", "--get", "remote.origin.url"])
             return origin.decode().strip()


@@ -568,12 +581,13 @@ def get_git_branch():
     """
     if is_git_dir():
         with contextlib.suppress(subprocess.CalledProcessError):
-            origin = subprocess.check_output(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
+            origin = subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"])
             return origin.decode().strip()


 def get_default_args(func):
-    """Returns a dictionary of default arguments for a function.
+    """
+    Returns a dictionary of default arguments for a function.

     Args:
         func (callable): The function to inspect.
@@ -594,13 +608,13 @@ def get_ubuntu_version():
     """
     if is_ubuntu():
         with contextlib.suppress(FileNotFoundError, AttributeError):
-            with open('/etc/os-release') as f:
+            with open("/etc/os-release") as f:
                 return re.search(r'VERSION_ID="(\d+\.\d+)"', f.read())[1]


-def get_user_config_dir(sub_dir='Ultralytics'):
+def get_user_config_dir(sub_dir="yolov10"):
     """
-    Get the user config directory.
+    Return the appropriate config directory based on the environment operating system.

     Args:
         sub_dir (str): The name of the subdirectory to create.
@@ -608,21 +622,22 @@ def get_user_config_dir(sub_dir='Ultralytics'):
     Returns:
         (Path): The path to the user config directory.
     """
     # Return the appropriate config directory for each operating system
     if WINDOWS:
-        path = Path.home() / 'AppData' / 'Roaming' / sub_dir
+        path = Path.home() / "AppData" / "Roaming" / sub_dir
     elif MACOS:  # macOS
-        path = Path.home() / 'Library' / 'Application Support' / sub_dir
+        path = Path.home() / "Library" / "Application Support" / sub_dir
     elif LINUX:
-        path = Path.home() / '.config' / sub_dir
+        path = Path.home() / ".config" / sub_dir
     else:
-        raise ValueError(f'Unsupported operating system: {platform.system()}')
+        raise ValueError(f"Unsupported operating system: {platform.system()}")

     # GCP and AWS lambda fix, only /tmp is writeable
     if not is_dir_writeable(path.parent):
-        LOGGER.warning(f"WARNING ⚠️ user config directory '{path}' is not writeable, defaulting to '/tmp' or CWD."
-                       'Alternatively you can define a YOLO_CONFIG_DIR environment variable for this path.')
-        path = Path('/tmp') / sub_dir if is_dir_writeable('/tmp') else Path().cwd() / sub_dir
+        LOGGER.warning(
+            f"WARNING ⚠️ user config directory '{path}' is not writeable, defaulting to '/tmp' or CWD."
+            "Alternatively you can define a YOLO_CONFIG_DIR environment variable for this path."
+        )
+        path = Path("/tmp") / sub_dir if is_dir_writeable("/tmp") else Path().cwd() / sub_dir

     # Create the subdirectory if it does not exist
     path.mkdir(parents=True, exist_ok=True)
@@ -630,40 +645,99 @@ def get_user_config_dir(sub_dir='Ultralytics'):
     return path


-USER_CONFIG_DIR = Path(os.getenv('YOLO_CONFIG_DIR') or get_user_config_dir())  # Ultralytics settings dir
-SETTINGS_YAML = USER_CONFIG_DIR / 'settings.yaml'
+USER_CONFIG_DIR = Path(os.getenv("YOLO_CONFIG_DIR") or get_user_config_dir())  # Ultralytics settings dir
+SETTINGS_YAML = USER_CONFIG_DIR / "settings.yaml"


 def colorstr(*input):
-    """Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')."""
-    *args, string = input if len(input) > 1 else ('blue', 'bold', input[0])  # color arguments, string
+    """
+    Colors a string based on the provided color and style arguments. Utilizes ANSI escape codes.
+    See https://en.wikipedia.org/wiki/ANSI_escape_code for more details.
+
+    This function can be called in two ways:
+        - colorstr('color', 'style', 'your string')
+        - colorstr('your string')
+
+    In the second form, 'blue' and 'bold' will be applied by default.
+
+    Args:
+        *input (str): A sequence of strings where the first n-1 strings are color and style arguments,
+                      and the last string is the one to be colored.
+
+    Supported Colors and Styles:
+        Basic Colors: 'black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'
+        Bright Colors: 'bright_black', 'bright_red', 'bright_green', 'bright_yellow',
+                       'bright_blue', 'bright_magenta', 'bright_cyan', 'bright_white'
+        Misc: 'end', 'bold', 'underline'
+
+    Returns:
+        (str): The input string wrapped with ANSI escape codes for the specified color and style.
+
+    Examples:
+        >>> colorstr('blue', 'bold', 'hello world')
+        >>> '\033[34m\033[1mhello world\033[0m'
+    """
+    *args, string = input if len(input) > 1 else ("blue", "bold", input[0])  # color arguments, string
     colors = {
-        'black': '\033[30m',  # basic colors
-        'red': '\033[31m',
-        'green': '\033[32m',
-        'yellow': '\033[33m',
-        'blue': '\033[34m',
-        'magenta': '\033[35m',
-        'cyan': '\033[36m',
-        'white': '\033[37m',
-        'bright_black': '\033[90m',  # bright colors
-        'bright_red': '\033[91m',
-        'bright_green': '\033[92m',
-        'bright_yellow': '\033[93m',
-        'bright_blue': '\033[94m',
-        'bright_magenta': '\033[95m',
-        'bright_cyan': '\033[96m',
-        'bright_white': '\033[97m',
-        'end': '\033[0m',  # misc
-        'bold': '\033[1m',
-        'underline': '\033[4m'}
-    return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
+        "black": "\033[30m",  # basic colors
+        "red": "\033[31m",
+        "green": "\033[32m",
+        "yellow": "\033[33m",
+        "blue": "\033[34m",
+        "magenta": "\033[35m",
+        "cyan": "\033[36m",
+        "white": "\033[37m",
+        "bright_black": "\033[90m",  # bright colors
+        "bright_red": "\033[91m",
+        "bright_green": "\033[92m",
+        "bright_yellow": "\033[93m",
+        "bright_blue": "\033[94m",
+        "bright_magenta": "\033[95m",
+        "bright_cyan": "\033[96m",
+        "bright_white": "\033[97m",
+        "end": "\033[0m",  # misc
+        "bold": "\033[1m",
+        "underline": "\033[4m",
+    }
+    return "".join(colors[x] for x in args) + f"{string}" + colors["end"]
+
+
+def remove_colorstr(input_string):
+    """
+    Removes ANSI escape codes from a string, effectively un-coloring it.
+
+    Args:
+        input_string (str): The string to remove color and style from.
+
+    Returns:
+        (str): A new string with all ANSI escape codes removed.
+
+    Examples:
+        >>> remove_colorstr(colorstr('blue', 'bold', 'hello world'))
+        >>> 'hello world'
+    """
+    ansi_escape = re.compile(r"\x1B\[[0-9;]*[A-Za-z]")
+    return ansi_escape.sub("", input_string)
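remove_colorstr is the inverse of colorstr, useful when colored console lines are also written to plain-text logs. A quick round trip using the docstring's own example:

    from ultralytics.utils import colorstr, remove_colorstr

    colored = colorstr("blue", "bold", "hello world")  # '\x1b[34m\x1b[1mhello world\x1b[0m'
    assert remove_colorstr(colored) == "hello world"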
 class TryExcept(contextlib.ContextDecorator):
-    """YOLOv8 TryExcept class. Usage: @TryExcept() decorator or 'with TryExcept():' context manager."""
+    """
+    Ultralytics TryExcept class. Use as @TryExcept() decorator or 'with TryExcept():' context manager.

-    def __init__(self, msg='', verbose=True):
+    Examples:
+        As a decorator:
+        >>> @TryExcept(msg="Error occurred in func", verbose=True)
+        >>> def func():
+        >>>     # Function logic here
+        >>>     pass
+
+        As a context manager:
+        >>> with TryExcept(msg="Error occurred in block", verbose=True):
+        >>>     # Code block here
+        >>>     pass
+    """
+
+    def __init__(self, msg="", verbose=True):
         """Initialize TryExcept class with optional message and verbosity settings."""
         self.msg = msg
         self.verbose = verbose
@@ -679,14 +753,80 @@ class TryExcept(contextlib.ContextDecorator):
         return True


+class Retry(contextlib.ContextDecorator):
+    """
+    Retry class for function execution with exponential backoff.
+
+    Can be used as a decorator or a context manager to retry a function or block of code on exceptions, up to a
+    specified number of times with an exponentially increasing delay between retries.
+
+    Examples:
+        Example usage as a decorator:
+        >>> @Retry(times=3, delay=2)
+        >>> def test_func():
+        >>>     # Replace with function logic that may raise exceptions
+        >>>     return True
+
+        Example usage as a context manager:
+        >>> with Retry(times=3, delay=2):
+        >>>     # Replace with code block that may raise exceptions
+        >>>     pass
+    """
+
+    def __init__(self, times=3, delay=2):
+        """Initialize Retry class with specified number of retries and delay."""
+        self.times = times
+        self.delay = delay
+        self._attempts = 0
+
+    def __call__(self, func):
+        """Decorator implementation for Retry with exponential backoff."""
+
+        def wrapped_func(*args, **kwargs):
+            """Applies retries to the decorated function or method."""
+            self._attempts = 0
+            while self._attempts < self.times:
+                try:
+                    return func(*args, **kwargs)
+                except Exception as e:
+                    self._attempts += 1
+                    print(f"Retry {self._attempts}/{self.times} failed: {e}")
+                    if self._attempts >= self.times:
+                        raise e
+                    time.sleep(self.delay * (2**self._attempts))  # exponential backoff delay
+
+        return wrapped_func
+
+    def __enter__(self):
+        """Enter the runtime context related to this object."""
+        self._attempts = 0
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        """Exit the runtime context related to this object with exponential backoff."""
+        if exc_type is not None:
+            self._attempts += 1
+            if self._attempts < self.times:
+                print(f"Retry {self._attempts}/{self.times} failed: {exc_value}")
+                time.sleep(self.delay * (2**self._attempts))  # exponential backoff delay
+                return True  # Suppresses the exception and retries
+        return False  # Re-raises the exception if retries are exhausted
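Note the backoff schedule: attempt n sleeps delay * 2**n seconds, so times=3, delay=2 waits 4 s and then 8 s before the third and final attempt raises. A decorator sketch with a deliberately flaky placeholder function:

    from ultralytics.utils import Retry

    calls = {"n": 0}

    @Retry(times=3, delay=1)
    def flaky():
        calls["n"] += 1
        if calls["n"] < 3:
            raise ConnectionError("transient failure")  # retried with exponential backoff
        return "ok"

    assert flaky() == "ok" and calls["n"] == 3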
 def threaded(func):
-    """Multi-threads a target function and returns thread. Usage: @threaded decorator."""
+    """
+    Multi-threads a target function by default and returns the thread or function result.
+
+    Use as @threaded decorator. The function runs in a separate thread unless 'threaded=False' is passed.
+    """

     def wrapper(*args, **kwargs):
-        """Multi-threads a given function and returns the thread."""
-        thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
-        thread.start()
-        return thread
+        """Multi-threads a given function based on 'threaded' kwarg and returns the thread or function result."""
+        if kwargs.pop("threaded", True):  # run in thread
+            thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
+            thread.start()
+            return thread
+        else:
+            return func(*args, **kwargs)

     return wrapper
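This threaded change is behavioral, not just stylistic: callers can now opt out of threading per call by passing threaded=False, which the wrapper pops before forwarding kwargs. A sketch with a placeholder function:

    from ultralytics.utils import threaded

    @threaded
    def preprocess(batch):
        return batch  # placeholder work

    t = preprocess([1, 2, 3])                       # returns a started daemon Thread
    result = preprocess([1, 2, 3], threaded=False)  # runs inline and returns the value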
@@ -723,27 +863,28 @@ def set_sentry():
         Returns:
             dict: The modified event or None if the event should not be sent to Sentry.
         """
-        if 'exc_info' in hint:
-            exc_type, exc_value, tb = hint['exc_info']
-            if exc_type in (KeyboardInterrupt, FileNotFoundError) \
-                    or 'out of memory' in str(exc_value):
+        if "exc_info" in hint:
+            exc_type, exc_value, tb = hint["exc_info"]
+            if exc_type in (KeyboardInterrupt, FileNotFoundError) or "out of memory" in str(exc_value):
                 return None  # do not send event

-        event['tags'] = {
-            'sys_argv': sys.argv[0],
-            'sys_argv_name': Path(sys.argv[0]).name,
-            'install': 'git' if is_git_dir() else 'pip' if is_pip_package() else 'other',
-            'os': ENVIRONMENT}
+        event["tags"] = {
+            "sys_argv": sys.argv[0],
+            "sys_argv_name": Path(sys.argv[0]).name,
+            "install": "git" if is_git_dir() else "pip" if is_pip_package() else "other",
+            "os": ENVIRONMENT,
+        }
         return event

-    if SETTINGS['sync'] and \
-            RANK in (-1, 0) and \
-            Path(sys.argv[0]).name == 'yolo' and \
-            not TESTS_RUNNING and \
-            ONLINE and \
-            is_pip_package() and \
-            not is_git_dir():
-
+    if (
+        SETTINGS["sync"]
+        and RANK in (-1, 0)
+        and Path(sys.argv[0]).name == "yolo"
+        and not TESTS_RUNNING
+        and ONLINE
+        and is_pip_package()
+        and not is_git_dir()
+    ):
         # If sentry_sdk package is not installed then return and do not use Sentry
         try:
             import sentry_sdk  # noqa
@@ -751,18 +892,15 @@ def set_sentry():
             return

         sentry_sdk.init(
-            dsn='https://5ff1556b71594bfea135ff0203a0d290@o4504521589325824.ingest.sentry.io/4504521592406016',
+            dsn="https://5ff1556b71594bfea135ff0203a0d290@o4504521589325824.ingest.sentry.io/4504521592406016",
             debug=False,
             traces_sample_rate=1.0,
             release=__version__,
-            environment='production',  # 'dev' or 'production'
+            environment="production",  # 'dev' or 'production'
             before_send=before_send,
-            ignore_errors=[KeyboardInterrupt, FileNotFoundError])
-        sentry_sdk.set_user({'id': SETTINGS['uuid']})  # SHA-256 anonymized UUID hash
-
-        # Disable all sentry logging
-        for logger in 'sentry_sdk', 'sentry_sdk.errors':
-            logging.getLogger(logger).setLevel(logging.CRITICAL)
+            ignore_errors=[KeyboardInterrupt, FileNotFoundError],
+        )
+        sentry_sdk.set_user({"id": SETTINGS["uuid"]})  # SHA-256 anonymized UUID hash


 class SettingsManager(dict):
@@ -774,7 +912,10 @@ class SettingsManager(dict):
         version (str): Settings version. In case of local version mismatch, new default settings will be saved.
     """

-    def __init__(self, file=SETTINGS_YAML, version='0.0.4'):
+    def __init__(self, file=SETTINGS_YAML, version="0.0.4"):
+        """Initialize the SettingsManager with default settings, load and validate current settings from the YAML
+        file.
+        """
         import copy
         import hashlib

@@ -788,22 +929,24 @@ class SettingsManager(dict):
         self.file = Path(file)
         self.version = version
         self.defaults = {
-            'settings_version': version,
-            'datasets_dir': str(datasets_root / 'datasets'),
-            'weights_dir': str(root / 'weights'),
-            'runs_dir': str(root / 'runs'),
-            'uuid': hashlib.sha256(str(uuid.getnode()).encode()).hexdigest(),
-            'sync': True,
-            'api_key': '',
-            'clearml': True,  # integrations
-            'comet': True,
-            'dvc': True,
-            'hub': True,
-            'mlflow': True,
-            'neptune': True,
-            'raytune': True,
-            'tensorboard': True,
-            'wandb': True}
+            "settings_version": version,
+            "datasets_dir": str(datasets_root / "datasets"),
+            "weights_dir": str(root / "weights"),
+            "runs_dir": str(root / "runs"),
+            "uuid": hashlib.sha256(str(uuid.getnode()).encode()).hexdigest(),
+            "sync": True,
+            "api_key": "",
+            "openai_api_key": "",
+            "clearml": True,  # integrations
+            "comet": True,
+            "dvc": True,
+            "hub": True,
+            "mlflow": True,
+            "neptune": True,
+            "raytune": True,
+            "tensorboard": True,
+            "wandb": True,
+        }

         super().__init__(copy.deepcopy(self.defaults))

@@ -814,15 +957,26 @@ class SettingsManager(dict):
             self.load()
             correct_keys = self.keys() == self.defaults.keys()
             correct_types = all(type(a) is type(b) for a, b in zip(self.values(), self.defaults.values()))
-            correct_version = check_version(self['settings_version'], self.version)
+            correct_version = check_version(self["settings_version"], self.version)
+            help_msg = (
+                f"\nView settings with 'yolo settings' or at '{self.file}'"
+                "\nUpdate settings with 'yolo settings key=value', i.e. 'yolo settings runs_dir=path/to/dir'. "
+                "For help see https://docs.ultralytics.com/quickstart/#ultralytics-settings."
+            )
             if not (correct_keys and correct_types and correct_version):
                 LOGGER.warning(
-                    'WARNING ⚠️ Ultralytics settings reset to default values. This may be due to a possible problem '
-                    'with your settings or a recent ultralytics package update. '
-                    f"\nView settings with 'yolo settings' or at '{self.file}'"
-                    "\nUpdate settings with 'yolo settings key=value', i.e. 'yolo settings runs_dir=path/to/dir'.")
+                    "WARNING ⚠️ Ultralytics settings reset to default values. This may be due to a possible problem "
+                    f"with your settings or a recent ultralytics package update. {help_msg}"
+                )
                 self.reset()

+            if self.get("datasets_dir") == self.get("runs_dir"):
+                LOGGER.warning(
+                    f"WARNING ⚠️ Ultralytics setting 'datasets_dir: {self.get('datasets_dir')}' "
+                    f"must be different than 'runs_dir: {self.get('runs_dir')}'. "
+                    f"Please change one to avoid possible issues during training. {help_msg}"
+                )
+
     def load(self):
         """Loads settings from the YAML file."""
         super().update(yaml_load(self.file))
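Because "openai_api_key" is a new default key, existing settings.yaml files will fail the correct_keys comparison and be reset to defaults on first load after this change. Settings can then be inspected or updated from the CLI, as the warning text itself suggests:

    yolo settings                       # view current settings
    yolo settings runs_dir=path/to/dir  # update a single key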
@@ -847,14 +1001,16 @@ def deprecation_warn(arg, new_arg, version=None):
     """Issue a deprecation warning when a deprecated argument is used, suggesting an updated argument."""
     if not version:
         version = float(__version__[:3]) + 0.2  # deprecate after 2nd major release
-    LOGGER.warning(f"WARNING ⚠️ '{arg}' is deprecated and will be removed in 'ultralytics {version}' in the future. "
-                   f"Please use '{new_arg}' instead.")
+    LOGGER.warning(
+        f"WARNING ⚠️ '{arg}' is deprecated and will be removed in 'ultralytics {version}' in the future. "
+        f"Please use '{new_arg}' instead."
+    )


 def clean_url(url):
     """Strip auth from URL, i.e. https://url.com/file.txt?auth -> https://url.com/file.txt."""
-    url = Path(url).as_posix().replace(':/', '://')  # Pathlib turns :// -> :/, as_posix() for Windows
-    return urllib.parse.unquote(url).split('?')[0]  # '%2F' to '/', split https://url.com/file.txt?auth
+    url = Path(url).as_posix().replace(":/", "://")  # Pathlib turns :// -> :/, as_posix() for Windows
+    return urllib.parse.unquote(url).split("?")[0]  # '%2F' to '/', split https://url.com/file.txt?auth


 def url2file(url):
@@ -865,12 +1021,23 @@ def url2file(url):

 # Run below code on utils init ------------------------------------------------------------------------------------

 # Check first-install steps
-PREFIX = colorstr('Ultralytics: ')
+PREFIX = colorstr("Ultralytics: ")
 SETTINGS = SettingsManager()  # initialize settings
-DATASETS_DIR = Path(SETTINGS['datasets_dir'])  # global datasets directory
-ENVIRONMENT = 'Colab' if is_colab() else 'Kaggle' if is_kaggle() else 'Jupyter' if is_jupyter() else \
-    'Docker' if is_docker() else platform.system()
-TESTS_RUNNING = is_pytest_running() or is_github_actions_ci()
+DATASETS_DIR = Path(SETTINGS["datasets_dir"])  # global datasets directory
+WEIGHTS_DIR = Path(SETTINGS["weights_dir"])  # global weights directory
+RUNS_DIR = Path(SETTINGS["runs_dir"])  # global runs directory
+ENVIRONMENT = (
+    "Colab"
+    if is_colab()
+    else "Kaggle"
+    if is_kaggle()
+    else "Jupyter"
+    if is_jupyter()
+    else "Docker"
+    if is_docker()
+    else platform.system()
+)
+TESTS_RUNNING = is_pytest_running() or is_github_action_running()
 set_sentry()

 # Apply monkey patches
Binary file not shown.
ultralytics/utils/autobatch.py
@@ -1,7 +1,5 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
-"""
-Functions for estimating the best YOLO batch size to use a fraction of the available CUDA memory in PyTorch.
-"""
+"""Functions for estimating the best YOLO batch size to use a fraction of the available CUDA memory in PyTorch."""

 from copy import deepcopy

@@ -36,7 +34,7 @@ def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch):
     Args:
         model (torch.nn.module): YOLO model to compute batch size for.
         imgsz (int, optional): The image size used as input for the YOLO model. Defaults to 640.
-        fraction (float, optional): The fraction of available CUDA memory to use. Defaults to 0.67.
+        fraction (float, optional): The fraction of available CUDA memory to use. Defaults to 0.60.
         batch_size (int, optional): The default batch size to use if an error is detected. Defaults to 16.

     Returns:
@@ -44,14 +42,14 @@ def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch):
     """

     # Check device
-    prefix = colorstr('AutoBatch: ')
-    LOGGER.info(f'{prefix}Computing optimal batch size for imgsz={imgsz}')
+    prefix = colorstr("AutoBatch: ")
+    LOGGER.info(f"{prefix}Computing optimal batch size for imgsz={imgsz}")
     device = next(model.parameters()).device  # get model device
-    if device.type == 'cpu':
-        LOGGER.info(f'{prefix}CUDA not detected, using default CPU batch-size {batch_size}')
+    if device.type == "cpu":
+        LOGGER.info(f"{prefix}CUDA not detected, using default CPU batch-size {batch_size}")
         return batch_size
     if torch.backends.cudnn.benchmark:
-        LOGGER.info(f'{prefix} ⚠️ Requires torch.backends.cudnn.benchmark=False, using default batch-size {batch_size}')
+        LOGGER.info(f"{prefix} ⚠️ Requires torch.backends.cudnn.benchmark=False, using default batch-size {batch_size}")
         return batch_size

     # Inspect CUDA memory
@@ -62,7 +60,7 @@ def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch):
     r = torch.cuda.memory_reserved(device) / gb  # GiB reserved
     a = torch.cuda.memory_allocated(device) / gb  # GiB allocated
     f = t - (r + a)  # GiB free
-    LOGGER.info(f'{prefix}{d} ({properties.name}) {t:.2f}G total, {r:.2f}G reserved, {a:.2f}G allocated, {f:.2f}G free')
+    LOGGER.info(f"{prefix}{d} ({properties.name}) {t:.2f}G total, {r:.2f}G reserved, {a:.2f}G allocated, {f:.2f}G free")

     # Profile batch sizes
     batch_sizes = [1, 2, 4, 8, 16]
@@ -72,7 +70,7 @@ def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch):

         # Fit a solution
         y = [x[2] for x in results if x]  # memory [2]
-        p = np.polyfit(batch_sizes[:len(y)], y, deg=1)  # first degree polynomial fit
+        p = np.polyfit(batch_sizes[: len(y)], y, deg=1)  # first degree polynomial fit
         b = int((f * fraction - p[1]) / p[0])  # y intercept (optimal batch size)
         if None in results:  # some sizes failed
             i = results.index(None)  # first fail index
@@ -80,11 +78,11 @@ def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch):
                 b = batch_sizes[max(i - 1, 0)]  # select prior safe point
         if b < 1 or b > 1024:  # b outside of safe range
             b = batch_size
-            LOGGER.info(f'{prefix}WARNING ⚠️ CUDA anomaly detected, using default batch-size {batch_size}.')
+            LOGGER.info(f"{prefix}WARNING ⚠️ CUDA anomaly detected, using default batch-size {batch_size}.")

         fraction = (np.polyval(p, b) + r + a) / t  # actual fraction predicted
-        LOGGER.info(f'{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%) ✅')
+        LOGGER.info(f"{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%) ✅")
         return b
     except Exception as e:
-        LOGGER.warning(f'{prefix}WARNING ⚠️ error detected: {e}, using default batch-size {batch_size}.')
+        LOGGER.warning(f"{prefix}WARNING ⚠️ error detected: {e}, using default batch-size {batch_size}.")
         return batch_size
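For intuition on the fit above: measured memory is modeled as linear in batch size, mem ≈ p[0]*b + p[1], so solving p[0]*b + p[1] = f*fraction for b yields the largest batch expected to fit in the target fraction of free memory. A worked toy example with invented numbers:

    import numpy as np

    batch_sizes = [1, 2, 4, 8, 16]
    y = [0.5, 0.9, 1.7, 3.3, 6.5]          # profiled GiB per batch size (invented)
    p = np.polyfit(batch_sizes, y, deg=1)  # slope ≈ 0.4 GiB/image, intercept ≈ 0.1 GiB
    f, fraction = 10.0, 0.60               # 10 GiB free, target 60% utilization
    b = int((f * fraction - p[1]) / p[0])  # (6.0 - 0.1) / 0.4 -> batch size 14
    print(b)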
ultralytics/utils/benchmarks.py
@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 """
-Benchmark a YOLO model formats for speed and accuracy
+Benchmark a YOLO model formats for speed and accuracy.

 Usage:
     from ultralytics.utils.benchmarks import ProfileModels, benchmark
@@ -21,34 +21,29 @@ TensorFlow Lite | `tflite` | yolov8n.tflite
 TensorFlow Edge TPU     | `edgetpu`                 | yolov8n_edgetpu.tflite
 TensorFlow.js           | `tfjs`                    | yolov8n_web_model/
 PaddlePaddle            | `paddle`                  | yolov8n_paddle_model/
-ncnn                    | `ncnn`                    | yolov8n_ncnn_model/
+NCNN                    | `ncnn`                    | yolov8n_ncnn_model/
 """

 import glob
 import platform
-import sys
 import time
 from pathlib import Path

 import numpy as np
 import torch.cuda

-from ultralytics import YOLO
+from ultralytics import YOLO, YOLOWorld
 from ultralytics.cfg import TASK2DATA, TASK2METRIC
 from ultralytics.engine.exporter import export_formats
-from ultralytics.utils import ASSETS, LINUX, LOGGER, MACOS, SETTINGS, TQDM
-from ultralytics.utils.checks import check_requirements, check_yolo
+from ultralytics.utils import ASSETS, LINUX, LOGGER, MACOS, TQDM, WEIGHTS_DIR
+from ultralytics.utils.checks import IS_PYTHON_3_12, check_requirements, check_yolo
 from ultralytics.utils.files import file_size
 from ultralytics.utils.torch_utils import select_device


-def benchmark(model=Path(SETTINGS['weights_dir']) / 'yolov8n.pt',
-              data=None,
-              imgsz=160,
-              half=False,
-              int8=False,
-              device='cpu',
-              verbose=False):
+def benchmark(
+    model=WEIGHTS_DIR / "yolov8n.pt", data=None, imgsz=160, half=False, int8=False, device="cpu", verbose=False
+):
     """
     Benchmark a YOLO model across different formats for speed and accuracy.

@@ -76,6 +71,7 @@ def benchmark(model=Path(SETTINGS['weights_dir']) / 'yolov8n.pt',
     """

     import pandas as pd
+
     pd.options.display.max_columns = 10
     pd.options.display.width = 120
     device = select_device(device, verbose=False)
@@ -85,67 +81,72 @@ def benchmark(model=Path(SETTINGS['weights_dir']) / 'yolov8n.pt',
     y = []
     t0 = time.time()
     for i, (name, format, suffix, cpu, gpu) in export_formats().iterrows():  # index, (name, format, suffix, CPU, GPU)
-        emoji, filename = '❌', None  # export defaults
+        emoji, filename = "❌", None  # export defaults
         try:
-            assert i != 9 or LINUX, 'Edge TPU export only supported on Linux'
-            if i == 10:
-                assert MACOS or LINUX, 'TF.js export only supported on macOS and Linux'
-            elif i == 11:
-                assert sys.version_info < (3, 11), 'PaddlePaddle export only supported on Python<=3.10'
-            if 'cpu' in device.type:
-                assert cpu, 'inference not supported on CPU'
-            if 'cuda' in device.type:
-                assert gpu, 'inference not supported on GPU'
+            # Checks
+            if i == 9:  # Edge TPU
+                assert LINUX, "Edge TPU export only supported on Linux"
+            elif i == 7:  # TF GraphDef
+                assert model.task != "obb", "TensorFlow GraphDef not supported for OBB task"
+            elif i in {5, 10}:  # CoreML and TF.js
+                assert MACOS or LINUX, "export only supported on macOS and Linux"
+            if i in {3, 5}:  # CoreML and OpenVINO
+                assert not IS_PYTHON_3_12, "CoreML and OpenVINO not supported on Python 3.12"
+            if i in {6, 7, 8, 9, 10}:  # All TF formats
+                assert not isinstance(model, YOLOWorld), "YOLOWorldv2 TensorFlow exports not supported by onnx2tf yet"
+            if i in {11}:  # Paddle
+                assert not isinstance(model, YOLOWorld), "YOLOWorldv2 Paddle exports not supported yet"
+            if i in {12}:  # NCNN
+                assert not isinstance(model, YOLOWorld), "YOLOWorldv2 NCNN exports not supported yet"
+            if "cpu" in device.type:
+                assert cpu, "inference not supported on CPU"
+            if "cuda" in device.type:
+                assert gpu, "inference not supported on GPU"

             # Export
-            if format == '-':
+            if format == "-":
                 filename = model.ckpt_path or model.cfg
-                export = model  # PyTorch format
+                exported_model = model  # PyTorch format
             else:
                 filename = model.export(imgsz=imgsz, format=format, half=half, int8=int8, device=device, verbose=False)
-                export = YOLO(filename, task=model.task)
-                assert suffix in str(filename), 'export failed'
-            emoji = '❎'  # indicates export succeeded
+                exported_model = YOLO(filename, task=model.task)
+                assert suffix in str(filename), "export failed"
+            emoji = "❎"  # indicates export succeeded

             # Predict
-            assert model.task != 'pose' or i != 7, 'GraphDef Pose inference is not supported'
-            assert i not in (9, 10), 'inference not supported'  # Edge TPU and TF.js are unsupported
-            assert i != 5 or platform.system() == 'Darwin', 'inference only supported on macOS>=10.13'  # CoreML
-            export.predict(ASSETS / 'bus.jpg', imgsz=imgsz, device=device, half=half)
+            assert model.task != "pose" or i != 7, "GraphDef Pose inference is not supported"
+            assert i not in (9, 10), "inference not supported"  # Edge TPU and TF.js are unsupported
+            assert i != 5 or platform.system() == "Darwin", "inference only supported on macOS>=10.13"  # CoreML
+            exported_model.predict(ASSETS / "bus.jpg", imgsz=imgsz, device=device, half=half)

             # Validate
             data = data or TASK2DATA[model.task]  # task to dataset, i.e. coco8.yaml for task=detect
             key = TASK2METRIC[model.task]  # task to metric, i.e. metrics/mAP50-95(B) for task=detect
-            results = export.val(data=data,
-                                 batch=1,
-                                 imgsz=imgsz,
-                                 plots=False,
-                                 device=device,
-                                 half=half,
-                                 int8=int8,
-                                 verbose=False)
-            metric, speed = results.results_dict[key], results.speed['inference']
-            y.append([name, '✅', round(file_size(filename), 1), round(metric, 4), round(speed, 2)])
+            results = exported_model.val(
+                data=data, batch=1, imgsz=imgsz, plots=False, device=device, half=half, int8=int8, verbose=False
+            )
+            metric, speed = results.results_dict[key], results.speed["inference"]
+            y.append([name, "✅", round(file_size(filename), 1), round(metric, 4), round(speed, 2)])
         except Exception as e:
             if verbose:
-                assert type(e) is AssertionError, f'Benchmark failure for {name}: {e}'
-            LOGGER.warning(f'ERROR ❌️ Benchmark failure for {name}: {e}')
+                assert type(e) is AssertionError, f"Benchmark failure for {name}: {e}"
+            LOGGER.warning(f"ERROR ❌️ Benchmark failure for {name}: {e}")
             y.append([name, emoji, round(file_size(filename), 1), None, None])  # mAP, t_inference

     # Print results
     check_yolo(device=device)  # print system info
-    df = pd.DataFrame(y, columns=['Format', 'Status❔', 'Size (MB)', key, 'Inference time (ms/im)'])
+    df = pd.DataFrame(y, columns=["Format", "Status❔", "Size (MB)", key, "Inference time (ms/im)"])

     name = Path(model.ckpt_path).name
-    s = f'\nBenchmarks complete for {name} on {data} at imgsz={imgsz} ({time.time() - t0:.2f}s)\n{df}\n'
+    s = f"\nBenchmarks complete for {name} on {data} at imgsz={imgsz} ({time.time() - t0:.2f}s)\n{df}\n"
     LOGGER.info(s)
-    with open('benchmarks.log', 'a', errors='ignore', encoding='utf-8') as f:
+    with open("benchmarks.log", "a", errors="ignore", encoding="utf-8") as f:
         f.write(s)

     if verbose and isinstance(verbose, float):
         metrics = df[key].array  # values to compare to floor
         floor = verbose  # minimum metric floor to pass, i.e. = 0.29 mAP for YOLOv5n
-        assert all(x > floor for x in metrics if pd.notna(x)), f'Benchmark failure: metric(s) < floor {floor}'
+        assert all(x > floor for x in metrics if pd.notna(x)), f"Benchmark failure: metric(s) < floor {floor}"

     return df
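A typical invocation of the refactored function, following the module docstring (the default weights path now resolves under the new WEIGHTS_DIR constant; a string or Path model argument is accepted by the ultralytics source):

    from ultralytics.utils.benchmarks import benchmark

    # Benchmark YOLOv8n across export formats at imgsz=160 on CPU; returns a pandas DataFrame
    df = benchmark(model="yolov8n.pt", imgsz=160, half=False, device="cpu")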
@ -154,8 +155,7 @@ class ProfileModels:
"""
ProfileModels class for profiling different models on ONNX and TensorRT.

This class profiles the performance of different models, provided their paths. The profiling includes parameters such as
model speed and FLOPs.
This class profiles the performance of different models, returning results such as model speed and FLOPs.

Attributes:
paths (list): Paths of the models to profile.
@ -175,15 +175,30 @@ class ProfileModels:
```
"""

def __init__(self,
paths: list,
num_timed_runs=100,
num_warmup_runs=10,
min_time=60,
imgsz=640,
half=True,
trt=True,
device=None):
def __init__(
self,
paths: list,
num_timed_runs=100,
num_warmup_runs=10,
min_time=60,
imgsz=640,
half=True,
trt=True,
device=None,
):
"""
Initialize the ProfileModels class for profiling models.

Args:
paths (list): List of paths of the models to be profiled.
num_timed_runs (int, optional): Number of timed runs for the profiling. Default is 100.
num_warmup_runs (int, optional): Number of warmup runs before the actual profiling starts. Default is 10.
min_time (float, optional): Minimum time in seconds for profiling a model. Default is 60.
imgsz (int, optional): Size of the image used during profiling. Default is 640.
half (bool, optional): Flag to indicate whether to use half-precision floating point for profiling.
trt (bool, optional): Flag to indicate whether to profile using TensorRT. Default is True.
device (torch.device, optional): Device used for profiling. If None, it is determined automatically.
"""
self.paths = paths
self.num_timed_runs = num_timed_runs
self.num_warmup_runs = num_warmup_runs
@ -191,36 +206,32 @@ class ProfileModels:
self.imgsz = imgsz
self.half = half
self.trt = trt  # run TensorRT profiling
self.device = device or torch.device(0 if torch.cuda.is_available() else 'cpu')
self.device = device or torch.device(0 if torch.cuda.is_available() else "cpu")

def profile(self):
"""Logs the benchmarking results of a model, checks metrics against floor and returns the results."""
files = self.get_files()

if not files:
print('No matching *.pt or *.onnx files found.')
print("No matching *.pt or *.onnx files found.")
return

table_rows = []
output = []
for file in files:
engine_file = file.with_suffix('.engine')
if file.suffix in ('.pt', '.yaml', '.yml'):
engine_file = file.with_suffix(".engine")
if file.suffix in (".pt", ".yaml", ".yml"):
model = YOLO(str(file))
model.fuse()  # to report correct params and GFLOPs in model.info()
model_info = model.info()
if self.trt and self.device.type != 'cpu' and not engine_file.is_file():
engine_file = model.export(format='engine',
half=self.half,
imgsz=self.imgsz,
device=self.device,
verbose=False)
onnx_file = model.export(format='onnx',
half=self.half,
imgsz=self.imgsz,
simplify=True,
device=self.device,
verbose=False)
elif file.suffix == '.onnx':
if self.trt and self.device.type != "cpu" and not engine_file.is_file():
engine_file = model.export(
format="engine", half=self.half, imgsz=self.imgsz, device=self.device, verbose=False
)
onnx_file = model.export(
format="onnx", half=self.half, imgsz=self.imgsz, simplify=True, device=self.device, verbose=False
)
elif file.suffix == ".onnx":
model_info = self.get_onnx_model_info(file)
onnx_file = file
else:
@ -235,25 +246,30 @@ class ProfileModels:
return output

def get_files(self):
"""Returns a list of paths for all relevant model files given by the user."""
files = []
for path in self.paths:
path = Path(path)
if path.is_dir():
extensions = ['*.pt', '*.onnx', '*.yaml']
extensions = ["*.pt", "*.onnx", "*.yaml"]
files.extend([file for ext in extensions for file in glob.glob(str(path / ext))])
elif path.suffix in {'.pt', '.yaml', '.yml'}:  # add non-existing
elif path.suffix in {".pt", ".yaml", ".yml"}:  # add non-existing
files.append(str(path))
else:
files.extend(glob.glob(str(path)))

print(f'Profiling: {sorted(files)}')
print(f"Profiling: {sorted(files)}")
return [Path(file) for file in sorted(files)]

def get_onnx_model_info(self, onnx_file: str):
# return (num_layers, num_params, num_gradients, num_flops)
return 0.0, 0.0, 0.0, 0.0
"""Retrieves the information including number of layers, parameters, gradients and FLOPs for an ONNX model
file.
"""
return 0.0, 0.0, 0.0, 0.0  # return (num_layers, num_params, num_gradients, num_flops)

def iterative_sigma_clipping(self, data, sigma=2, max_iters=3):
@staticmethod
def iterative_sigma_clipping(data, sigma=2, max_iters=3):
"""Applies an iterative sigma clipping algorithm to the given data times number of iterations."""
|
||||
data = np.array(data)
for _ in range(max_iters):
mean, std = np.mean(data), np.std(data)
@ -264,6 +280,7 @@ class ProfileModels:
return data

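The clipping body elided between the two hunks above keeps only points within `sigma` standard deviations of the mean and repeats. A toy re-implementation of the same idea, not the file's verbatim body:

```
import numpy as np

def sigma_clip(data, sigma=2, max_iters=3):
    """Toy re-implementation: repeatedly drop points beyond sigma * std of the mean."""
    data = np.array(data)
    for _ in range(max_iters):
        mean, std = np.mean(data), np.std(data)
        clipped = data[(data > mean - sigma * std) & (data < mean + sigma * std)]
        if len(clipped) == len(data):  # nothing was clipped, stop early
            break
        data = clipped
    return data

# Nine stable timings plus one stall: the 100.0 outlier is clipped on the first pass.
print(sigma_clip([10.1, 10.3, 9.9, 10.0, 10.2, 9.8, 10.1, 10.0, 9.9, 100.0]))
```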
def profile_tensorrt_model(self, engine_file: str, eps: float = 1e-3):
"""Profiles the TensorRT model, measuring average run time and standard deviation among runs."""
if not self.trt or not Path(engine_file).is_file():
return 0.0, 0.0

@ -286,39 +303,44 @@ class ProfileModels:
run_times = []
for _ in TQDM(range(num_runs), desc=engine_file):
results = model(input_data, imgsz=self.imgsz, verbose=False)
run_times.append(results[0].speed['inference'])  # Convert to milliseconds
run_times.append(results[0].speed["inference"])  # Convert to milliseconds

run_times = self.iterative_sigma_clipping(np.array(run_times), sigma=2, max_iters=3)  # sigma clipping
return np.mean(run_times), np.std(run_times)

def profile_onnx_model(self, onnx_file: str, eps: float = 1e-3):
check_requirements('onnxruntime')
"""Profiles an ONNX model by executing it multiple times and returns the mean and standard deviation of run
times.
"""
check_requirements("onnxruntime")
import onnxruntime as ort

# Session with either 'TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'
sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
sess_options.intra_op_num_threads = 8  # Limit the number of threads
sess = ort.InferenceSession(onnx_file, sess_options, providers=['CPUExecutionProvider'])
sess = ort.InferenceSession(onnx_file, sess_options, providers=["CPUExecutionProvider"])

input_tensor = sess.get_inputs()[0]
input_type = input_tensor.type
dynamic = not all(isinstance(dim, int) and dim >= 0 for dim in input_tensor.shape)  # dynamic input shape
input_shape = (1, 3, self.imgsz, self.imgsz) if dynamic else input_tensor.shape

# Mapping ONNX datatype to numpy datatype
if 'float16' in input_type:
if "float16" in input_type:
input_dtype = np.float16
elif 'float' in input_type:
elif "float" in input_type:
input_dtype = np.float32
elif 'double' in input_type:
elif "double" in input_type:
input_dtype = np.float64
elif 'int64' in input_type:
elif "int64" in input_type:
input_dtype = np.int64
elif 'int32' in input_type:
elif "int32" in input_type:
input_dtype = np.int32
else:
raise ValueError(f'Unsupported ONNX datatype {input_type}')
raise ValueError(f"Unsupported ONNX datatype {input_type}")

input_data = np.random.rand(*input_tensor.shape).astype(input_dtype)
input_data = np.random.rand(*input_shape).astype(input_dtype)
input_name = input_tensor.name
output_name = sess.get_outputs()[0].name

@ -344,24 +366,39 @@ class ProfileModels:
return np.mean(run_times), np.std(run_times)

def generate_table_row(self, model_name, t_onnx, t_engine, model_info):
"""Generates a formatted string for a table row that includes model performance and metric details."""
layers, params, gradients, flops = model_info
return f'| {model_name:18s} | {self.imgsz} | - | {t_onnx[0]:.2f} ± {t_onnx[1]:.2f} ms | {t_engine[0]:.2f} ± {t_engine[1]:.2f} ms | {params / 1e6:.1f} | {flops:.1f} |'
return (
f"| {model_name:18s} | {self.imgsz} | - | {t_onnx[0]:.2f} ± {t_onnx[1]:.2f} ms | {t_engine[0]:.2f} ± "
f"{t_engine[1]:.2f} ms | {params / 1e6:.1f} | {flops:.1f} |"
)

def generate_results_dict(self, model_name, t_onnx, t_engine, model_info):
@staticmethod
def generate_results_dict(model_name, t_onnx, t_engine, model_info):
"""Generates a dictionary of model details including name, parameters, GFLOPS and speed metrics."""
layers, params, gradients, flops = model_info
return {
'model/name': model_name,
'model/parameters': params,
'model/GFLOPs': round(flops, 3),
'model/speed_ONNX(ms)': round(t_onnx[0], 3),
'model/speed_TensorRT(ms)': round(t_engine[0], 3)}
"model/name": model_name,
"model/parameters": params,
"model/GFLOPs": round(flops, 3),
"model/speed_ONNX(ms)": round(t_onnx[0], 3),
"model/speed_TensorRT(ms)": round(t_engine[0], 3),
}

def print_table(self, table_rows):
gpu = torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'GPU'
header = f'| Model | size<br><sup>(pixels) | mAP<sup>val<br>50-95 | Speed<br><sup>CPU ONNX<br>(ms) | Speed<br><sup>{gpu} TensorRT<br>(ms) | params<br><sup>(M) | FLOPs<br><sup>(B) |'
separator = '|-------------|---------------------|--------------------|------------------------------|-----------------------------------|------------------|-----------------|'
@staticmethod
def print_table(table_rows):
"""Formats and prints a comparison table for different models with given statistics and performance data."""
gpu = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "GPU"
header = (
f"| Model | size<br><sup>(pixels) | mAP<sup>val<br>50-95 | Speed<br><sup>CPU ONNX<br>(ms) | "
f"Speed<br><sup>{gpu} TensorRT<br>(ms) | params<br><sup>(M) | FLOPs<br><sup>(B) |"
)
separator = (
"|-------------|---------------------|--------------------|------------------------------|"
"-----------------------------------|------------------|-----------------|"
)

print(f'\n\n{header}')
print(f"\n\n{header}")
print(separator)
for row in table_rows:
print(row)
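Putting the refactored pieces together, the class is driven roughly like this (a sketch; constructor arguments are the ones documented in `__init__` above, and yolov8n.pt is an assumed local checkpoint):

```
from ultralytics.utils.benchmarks import ProfileModels

# Each .pt is fused, exported to ONNX (and TensorRT when a GPU is available),
# timed for num_timed_runs after num_warmup_runs, then summarized in a markdown table.
profiler = ProfileModels(["yolov8n.pt"], imgsz=640, half=True, trt=True)
results = profiler.profile()  # list of dicts shaped by generate_results_dict()
```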
@ -2,4 +2,4 @@

from .base import add_integration_callbacks, default_callbacks, get_default_callbacks

__all__ = 'add_integration_callbacks', 'default_callbacks', 'get_default_callbacks'
__all__ = "add_integration_callbacks", "default_callbacks", "get_default_callbacks"
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -1,11 +1,10 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
Base callbacks
"""
"""Base callbacks."""

from collections import defaultdict
from copy import deepcopy


# Trainer callbacks ----------------------------------------------------------------------------------------------------


@ -145,37 +144,35 @@ def on_export_end(exporter):

default_callbacks = {
# Run in trainer
'on_pretrain_routine_start': [on_pretrain_routine_start],
'on_pretrain_routine_end': [on_pretrain_routine_end],
'on_train_start': [on_train_start],
'on_train_epoch_start': [on_train_epoch_start],
'on_train_batch_start': [on_train_batch_start],
'optimizer_step': [optimizer_step],
'on_before_zero_grad': [on_before_zero_grad],
'on_train_batch_end': [on_train_batch_end],
'on_train_epoch_end': [on_train_epoch_end],
'on_fit_epoch_end': [on_fit_epoch_end],  # fit = train + val
'on_model_save': [on_model_save],
'on_train_end': [on_train_end],
'on_params_update': [on_params_update],
'teardown': [teardown],

"on_pretrain_routine_start": [on_pretrain_routine_start],
"on_pretrain_routine_end": [on_pretrain_routine_end],
"on_train_start": [on_train_start],
"on_train_epoch_start": [on_train_epoch_start],
"on_train_batch_start": [on_train_batch_start],
"optimizer_step": [optimizer_step],
"on_before_zero_grad": [on_before_zero_grad],
"on_train_batch_end": [on_train_batch_end],
"on_train_epoch_end": [on_train_epoch_end],
"on_fit_epoch_end": [on_fit_epoch_end],  # fit = train + val
"on_model_save": [on_model_save],
"on_train_end": [on_train_end],
"on_params_update": [on_params_update],
"teardown": [teardown],
# Run in validator
'on_val_start': [on_val_start],
'on_val_batch_start': [on_val_batch_start],
'on_val_batch_end': [on_val_batch_end],
'on_val_end': [on_val_end],

"on_val_start": [on_val_start],
"on_val_batch_start": [on_val_batch_start],
"on_val_batch_end": [on_val_batch_end],
"on_val_end": [on_val_end],
# Run in predictor
'on_predict_start': [on_predict_start],
'on_predict_batch_start': [on_predict_batch_start],
'on_predict_postprocess_end': [on_predict_postprocess_end],
'on_predict_batch_end': [on_predict_batch_end],
'on_predict_end': [on_predict_end],

"on_predict_start": [on_predict_start],
"on_predict_batch_start": [on_predict_batch_start],
"on_predict_postprocess_end": [on_predict_postprocess_end],
"on_predict_batch_end": [on_predict_batch_end],
"on_predict_end": [on_predict_end],
# Run in exporter
'on_export_start': [on_export_start],
'on_export_end': [on_export_end]}
"on_export_start": [on_export_start],
"on_export_end": [on_export_end],
}


def get_default_callbacks():
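The body of get_default_callbacks() falls outside these hunks; given the defaultdict and deepcopy imports kept above and the dict-of-lists shape of default_callbacks, a plausible sketch (hypothetical, not the verbatim body) is:

```
from collections import defaultdict
from copy import deepcopy

def get_default_callbacks():
    """Hypothetical body: hand each caller an independent copy of the module-level defaults."""
    return defaultdict(list, deepcopy(default_callbacks))  # default_callbacks defined in this file
```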
@ -199,10 +196,11 @@ def add_integration_callbacks(instance):

# Load HUB callbacks
from .hub import callbacks as hub_cb

callbacks_list = [hub_cb]

# Load training callbacks
if 'Trainer' in instance.__class__.__name__:
if "Trainer" in instance.__class__.__name__:
from .clearml import callbacks as clear_cb
from .comet import callbacks as comet_cb
from .dvc import callbacks as dvc_cb
@ -211,12 +209,8 @@ def add_integration_callbacks(instance):
from .raytune import callbacks as tune_cb
from .tensorboard import callbacks as tb_cb
from .wb import callbacks as wb_cb
callbacks_list.extend([clear_cb, comet_cb, dvc_cb, mlflow_cb, neptune_cb, tune_cb, tb_cb, wb_cb])

# Load export callbacks (patch to avoid CoreML protobuf error)
if 'Exporter' in instance.__class__.__name__:
from .tensorboard import callbacks as tb_cb
callbacks_list.append(tb_cb)
callbacks_list.extend([clear_cb, comet_cb, dvc_cb, mlflow_cb, neptune_cb, tune_cb, tb_cb, wb_cb])

# Add the callbacks to the callbacks dictionary
for callbacks in callbacks_list:
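The merge step under this loop header is also outside the hunk; a sketch consistent with the dict-of-lists shape of default_callbacks (assumed, not verbatim):

```
# Assumed continuation: merge each integration's {event: callback} dict, skipping duplicates.
for callbacks in callbacks_list:
    for k, v in callbacks.items():
        if v not in instance.callbacks[k]:
            instance.callbacks[k].append(v)
```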
@ -4,19 +4,19 @@ from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING

try:
assert not TESTS_RUNNING  # do not log pytest
assert SETTINGS['clearml'] is True  # verify integration is enabled
assert SETTINGS["clearml"] is True  # verify integration is enabled
import clearml
from clearml import Task
from clearml.binding.frameworks.pytorch_bind import PatchPyTorchModelIO
from clearml.binding.matplotlib_bind import PatchedMatplotlib

assert hasattr(clearml, '__version__')  # verify package is not directory
assert hasattr(clearml, "__version__")  # verify package is not directory

except (ImportError, AssertionError):
clearml = None


def _log_debug_samples(files, title='Debug Samples') -> None:
def _log_debug_samples(files, title="Debug Samples") -> None:
"""
Log files (images) as debug samples in the ClearML task.

@ -29,12 +29,11 @@ def _log_debug_samples(files, title='Debug Samples') -> None:
if task := Task.current_task():
for f in files:
if f.exists():
it = re.search(r'_batch(\d+)', f.name)
it = re.search(r"_batch(\d+)", f.name)
iteration = int(it.groups()[0]) if it else 0
task.get_logger().report_image(title=title,
series=f.name.replace(it.group(), ''),
local_path=str(f),
iteration=iteration)
task.get_logger().report_image(
title=title, series=f.name.replace(it.group(), ""), local_path=str(f), iteration=iteration
)


def _log_plot(title, plot_path) -> None:
@ -50,13 +49,12 @@ def _log_plot(title, plot_path) -> None:

img = mpimg.imread(plot_path)
fig = plt.figure()
ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect='auto', xticks=[], yticks=[])  # no ticks
ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect="auto", xticks=[], yticks=[])  # no ticks
ax.imshow(img)

Task.current_task().get_logger().report_matplotlib_figure(title=title,
series='',
figure=fig,
report_interactive=False)
Task.current_task().get_logger().report_matplotlib_figure(
title=title, series="", figure=fig, report_interactive=False
)


def on_pretrain_routine_start(trainer):
@ -68,19 +66,21 @@ def on_pretrain_routine_start(trainer):
PatchPyTorchModelIO.update_current_task(None)
PatchedMatplotlib.update_current_task(None)
else:
task = Task.init(project_name=trainer.args.project or 'YOLOv8',
task_name=trainer.args.name,
tags=['YOLOv8'],
output_uri=True,
reuse_last_task_id=False,
auto_connect_frameworks={
'pytorch': False,
'matplotlib': False})
LOGGER.warning('ClearML Initialized a new task. If you want to run remotely, '
'please add clearml-init and connect your arguments before initializing YOLO.')
task.connect(vars(trainer.args), name='General')
task = Task.init(
project_name=trainer.args.project or "YOLOv8",
task_name=trainer.args.name,
tags=["YOLOv8"],
output_uri=True,
reuse_last_task_id=False,
auto_connect_frameworks={"pytorch": False, "matplotlib": False},
)
LOGGER.warning(
"ClearML Initialized a new task. If you want to run remotely, "
"please add clearml-init and connect your arguments before initializing YOLO."
)
task.connect(vars(trainer.args), name="General")
except Exception as e:
LOGGER.warning(f'WARNING ⚠️ ClearML installed but not initialized correctly, not logging this run. {e}')
LOGGER.warning(f"WARNING ⚠️ ClearML installed but not initialized correctly, not logging this run. {e}")


def on_train_epoch_end(trainer):
@ -88,22 +88,26 @@ def on_train_epoch_end(trainer):
if task := Task.current_task():
# Log debug samples
if trainer.epoch == 1:
_log_debug_samples(sorted(trainer.save_dir.glob('train_batch*.jpg')), 'Mosaic')
_log_debug_samples(sorted(trainer.save_dir.glob("train_batch*.jpg")), "Mosaic")
# Report the current training progress
for k, v in trainer.validator.metrics.results_dict.items():
task.get_logger().report_scalar('train', k, v, iteration=trainer.epoch)
for k, v in trainer.label_loss_items(trainer.tloss, prefix="train").items():
task.get_logger().report_scalar("train", k, v, iteration=trainer.epoch)
for k, v in trainer.lr.items():
task.get_logger().report_scalar("lr", k, v, iteration=trainer.epoch)


def on_fit_epoch_end(trainer):
"""Reports model information to logger at the end of an epoch."""
if task := Task.current_task():
# You should have access to the validation bboxes under jdict
task.get_logger().report_scalar(title='Epoch Time',
series='Epoch Time',
value=trainer.epoch_time,
iteration=trainer.epoch)
task.get_logger().report_scalar(
title="Epoch Time", series="Epoch Time", value=trainer.epoch_time, iteration=trainer.epoch
)
for k, v in trainer.metrics.items():
task.get_logger().report_scalar("val", k, v, iteration=trainer.epoch)
if trainer.epoch == 0:
from ultralytics.utils.torch_utils import model_info_for_loggers

for k, v in model_info_for_loggers(trainer).items():
task.get_logger().report_single_value(k, v)

@ -112,7 +116,7 @@ def on_val_end(validator):
"""Logs validation results including labels and predictions."""
if Task.current_task():
# Log val_labels and val_pred
_log_debug_samples(sorted(validator.save_dir.glob('val*.jpg')), 'Validation')
_log_debug_samples(sorted(validator.save_dir.glob("val*.jpg")), "Validation")


def on_train_end(trainer):
@ -120,8 +124,11 @@ def on_train_end(trainer):
if task := Task.current_task():
# Log final results, CM matrix + PR plots
files = [
'results.png', 'confusion_matrix.png', 'confusion_matrix_normalized.png',
*(f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R'))]
"results.png",
"confusion_matrix.png",
"confusion_matrix_normalized.png",
*(f"{x}_curve.png" for x in ("F1", "PR", "P", "R")),
]
files = [(trainer.save_dir / f) for f in files if (trainer.save_dir / f).exists()]  # filter
for f in files:
_log_plot(title=f.stem, plot_path=f)
@ -132,9 +139,14 @@ def on_train_end(trainer):
task.update_output_model(model_path=str(trainer.best), model_name=trainer.args.name, auto_delete_file=False)


callbacks = {
'on_pretrain_routine_start': on_pretrain_routine_start,
'on_train_epoch_end': on_train_epoch_end,
'on_fit_epoch_end': on_fit_epoch_end,
'on_val_end': on_val_end,
'on_train_end': on_train_end} if clearml else {}
callbacks = (
{
"on_pretrain_routine_start": on_pretrain_routine_start,
"on_train_epoch_end": on_train_epoch_end,
"on_fit_epoch_end": on_fit_epoch_end,
"on_val_end": on_val_end,
"on_train_end": on_train_end,
}
if clearml
else {}
)
@ -4,20 +4,20 @@ from ultralytics.utils import LOGGER, RANK, SETTINGS, TESTS_RUNNING, ops

try:
assert not TESTS_RUNNING  # do not log pytest
assert SETTINGS['comet'] is True  # verify integration is enabled
assert SETTINGS["comet"] is True  # verify integration is enabled
import comet_ml

assert hasattr(comet_ml, '__version__')  # verify package is not directory
assert hasattr(comet_ml, "__version__")  # verify package is not directory

import os
from pathlib import Path

# Ensures certain logging functions only run for supported tasks
COMET_SUPPORTED_TASKS = ['detect']
COMET_SUPPORTED_TASKS = ["detect"]

# Names of plots created by YOLOv8 that are logged to Comet
EVALUATION_PLOT_NAMES = 'F1_curve', 'P_curve', 'R_curve', 'PR_curve', 'confusion_matrix'
LABEL_PLOT_NAMES = 'labels', 'labels_correlogram'
EVALUATION_PLOT_NAMES = "F1_curve", "P_curve", "R_curve", "PR_curve", "confusion_matrix"
LABEL_PLOT_NAMES = "labels", "labels_correlogram"

_comet_image_prediction_count = 0

@ -26,37 +26,44 @@ except (ImportError, AssertionError):


def _get_comet_mode():
return os.getenv('COMET_MODE', 'online')
"""Returns the mode of comet set in the environment variables, defaults to 'online' if not set."""
return os.getenv("COMET_MODE", "online")


def _get_comet_model_name():
return os.getenv('COMET_MODEL_NAME', 'YOLOv8')
"""Returns the model name for Comet from the environment variable 'COMET_MODEL_NAME' or defaults to 'YOLOv8'."""
return os.getenv("COMET_MODEL_NAME", "YOLOv8")


def _get_eval_batch_logging_interval():
return int(os.getenv('COMET_EVAL_BATCH_LOGGING_INTERVAL', 1))
"""Get the evaluation batch logging interval from environment variable or use default value 1."""
return int(os.getenv("COMET_EVAL_BATCH_LOGGING_INTERVAL", 1))


def _get_max_image_predictions_to_log():
return int(os.getenv('COMET_MAX_IMAGE_PREDICTIONS', 100))
"""Get the maximum number of image predictions to log from the environment variables."""
return int(os.getenv("COMET_MAX_IMAGE_PREDICTIONS", 100))


def _scale_confidence_score(score):
scale = float(os.getenv('COMET_MAX_CONFIDENCE_SCORE', 100.0))
"""Scales the given confidence score by a factor specified in an environment variable."""
scale = float(os.getenv("COMET_MAX_CONFIDENCE_SCORE", 100.0))
return score * scale


def _should_log_confusion_matrix():
return os.getenv('COMET_EVAL_LOG_CONFUSION_MATRIX', 'false').lower() == 'true'
"""Determines if the confusion matrix should be logged based on the environment variable settings."""
return os.getenv("COMET_EVAL_LOG_CONFUSION_MATRIX", "false").lower() == "true"


def _should_log_image_predictions():
return os.getenv('COMET_EVAL_LOG_IMAGE_PREDICTIONS', 'true').lower() == 'true'
"""Determines whether to log image predictions based on a specified environment variable."""
return os.getenv("COMET_EVAL_LOG_IMAGE_PREDICTIONS", "true").lower() == "true"

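All of these getters read plain environment variables, so the Comet integration can be tuned without code changes, for example (variable names taken from the getters above, values illustrative):

```
import os

# Configure Comet before training starts; the callbacks read these at call time.
os.environ["COMET_MODE"] = "offline"                    # buffer locally instead of streaming
os.environ["COMET_EVAL_BATCH_LOGGING_INTERVAL"] = "2"   # log every 2nd eval batch
os.environ["COMET_MAX_IMAGE_PREDICTIONS"] = "50"        # cap logged prediction images
os.environ["COMET_EVAL_LOG_CONFUSION_MATRIX"] = "true"  # opt in to confusion matrices
```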
def _get_experiment_type(mode, project_name):
"""Return an experiment based on mode and project name."""
if mode == 'offline':
if mode == "offline":
return comet_ml.OfflineExperiment(project_name=project_name)

return comet_ml.Experiment(project_name=project_name)
@ -68,18 +75,21 @@ def _create_experiment(args):
return
try:
comet_mode = _get_comet_mode()
_project_name = os.getenv('COMET_PROJECT_NAME', args.project)
_project_name = os.getenv("COMET_PROJECT_NAME", args.project)
experiment = _get_experiment_type(comet_mode, _project_name)
experiment.log_parameters(vars(args))
experiment.log_others({
'eval_batch_logging_interval': _get_eval_batch_logging_interval(),
'log_confusion_matrix_on_eval': _should_log_confusion_matrix(),
'log_image_predictions': _should_log_image_predictions(),
'max_image_predictions': _get_max_image_predictions_to_log(), })
experiment.log_other('Created from', 'yolov8')
experiment.log_others(
{
"eval_batch_logging_interval": _get_eval_batch_logging_interval(),
"log_confusion_matrix_on_eval": _should_log_confusion_matrix(),
"log_image_predictions": _should_log_image_predictions(),
"max_image_predictions": _get_max_image_predictions_to_log(),
}
)
experiment.log_other("Created from", "yolov8")

except Exception as e:
LOGGER.warning(f'WARNING ⚠️ Comet installed but not initialized correctly, not logging this run. {e}')
LOGGER.warning(f"WARNING ⚠️ Comet installed but not initialized correctly, not logging this run. {e}")


def _fetch_trainer_metadata(trainer):
@ -95,18 +105,14 @@ def _fetch_trainer_metadata(trainer):
save_interval = curr_epoch % save_period == 0
save_assets = save and save_period > 0 and save_interval and not final_epoch

return dict(
curr_epoch=curr_epoch,
curr_step=curr_step,
save_assets=save_assets,
final_epoch=final_epoch,
)
return dict(curr_epoch=curr_epoch, curr_step=curr_step, save_assets=save_assets, final_epoch=final_epoch)


def _scale_bounding_box_to_original_image_shape(box, resized_image_shape, original_image_shape, ratio_pad):
"""YOLOv8 resizes images during training and the label values
are normalized based on this resized shape. This function rescales the
bounding box labels to the original image shape.
"""
YOLOv8 resizes images during training and the label values are normalized based on this resized shape.

This function rescales the bounding box labels to the original image shape.
"""

resized_image_height, resized_image_width = resized_image_shape
@ -126,29 +132,32 @@ def _scale_bounding_box_to_original_image_shape(box, resized_image_shape, origin

def _format_ground_truth_annotations_for_detection(img_idx, image_path, batch, class_name_map=None):
"""Format ground truth annotations for detection."""
indices = batch['batch_idx'] == img_idx
bboxes = batch['bboxes'][indices]
indices = batch["batch_idx"] == img_idx
bboxes = batch["bboxes"][indices]
if len(bboxes) == 0:
LOGGER.debug(f'COMET WARNING: Image: {image_path} has no bounding boxes labels')
LOGGER.debug(f"COMET WARNING: Image: {image_path} has no bounding boxes labels")
return None

cls_labels = batch['cls'][indices].squeeze(1).tolist()
cls_labels = batch["cls"][indices].squeeze(1).tolist()
if class_name_map:
cls_labels = [str(class_name_map[label]) for label in cls_labels]

original_image_shape = batch['ori_shape'][img_idx]
resized_image_shape = batch['resized_shape'][img_idx]
ratio_pad = batch['ratio_pad'][img_idx]
original_image_shape = batch["ori_shape"][img_idx]
resized_image_shape = batch["resized_shape"][img_idx]
ratio_pad = batch["ratio_pad"][img_idx]

data = []
for box, label in zip(bboxes, cls_labels):
box = _scale_bounding_box_to_original_image_shape(box, resized_image_shape, original_image_shape, ratio_pad)
data.append({
'boxes': [box],
'label': f'gt_{label}',
'score': _scale_confidence_score(1.0), })
data.append(
{
"boxes": [box],
"label": f"gt_{label}",
"score": _scale_confidence_score(1.0),
}
)

return {'name': 'ground_truth', 'data': data}
return {"name": "ground_truth", "data": data}


def _format_prediction_annotations_for_detection(image_path, metadata, class_label_map=None):
@ -158,31 +167,34 @@ def _format_prediction_annotations_for_detection(image_path, metadata, class_lab

predictions = metadata.get(image_id)
if not predictions:
LOGGER.debug(f'COMET WARNING: Image: {image_path} has no bounding boxes predictions')
LOGGER.debug(f"COMET WARNING: Image: {image_path} has no bounding boxes predictions")
return None

data = []
for prediction in predictions:
boxes = prediction['bbox']
score = _scale_confidence_score(prediction['score'])
cls_label = prediction['category_id']
boxes = prediction["bbox"]
score = _scale_confidence_score(prediction["score"])
cls_label = prediction["category_id"]
if class_label_map:
cls_label = str(class_label_map[cls_label])

data.append({'boxes': [boxes], 'label': cls_label, 'score': score})
data.append({"boxes": [boxes], "label": cls_label, "score": score})

return {'name': 'prediction', 'data': data}
return {"name": "prediction", "data": data}


def _fetch_annotations(img_idx, image_path, batch, prediction_metadata_map, class_label_map):
"""Join the ground truth and prediction annotations if they exist."""
ground_truth_annotations = _format_ground_truth_annotations_for_detection(img_idx, image_path, batch,
class_label_map)
prediction_annotations = _format_prediction_annotations_for_detection(image_path, prediction_metadata_map,
class_label_map)
ground_truth_annotations = _format_ground_truth_annotations_for_detection(
img_idx, image_path, batch, class_label_map
)
prediction_annotations = _format_prediction_annotations_for_detection(
image_path, prediction_metadata_map, class_label_map
)

annotations = [
annotation for annotation in [ground_truth_annotations, prediction_annotations] if annotation is not None]
annotation for annotation in [ground_truth_annotations, prediction_annotations] if annotation is not None
]
return [annotations] if annotations else None


@ -190,8 +202,8 @@ def _create_prediction_metadata_map(model_predictions):
"""Create metadata map for model predictions by groupings them based on image ID."""
pred_metadata_map = {}
for prediction in model_predictions:
pred_metadata_map.setdefault(prediction['image_id'], [])
pred_metadata_map[prediction['image_id']].append(prediction)
pred_metadata_map.setdefault(prediction["image_id"], [])
pred_metadata_map[prediction["image_id"]].append(prediction)

return pred_metadata_map

@ -199,13 +211,9 @@ def _create_prediction_metadata_map(model_predictions):
def _log_confusion_matrix(experiment, trainer, curr_step, curr_epoch):
"""Log the confusion matrix to Comet experiment."""
conf_mat = trainer.validator.confusion_matrix.matrix
names = list(trainer.data['names'].values()) + ['background']
names = list(trainer.data["names"].values()) + ["background"]
experiment.log_confusion_matrix(
matrix=conf_mat,
labels=names,
max_categories=len(names),
epoch=curr_epoch,
step=curr_step,
matrix=conf_mat, labels=names, max_categories=len(names), epoch=curr_epoch, step=curr_step
)


@ -243,7 +251,7 @@ def _log_image_predictions(experiment, validator, curr_step):
if (batch_idx + 1) % batch_logging_interval != 0:
continue

image_paths = batch['im_file']
image_paths = batch["im_file"]
for img_idx, image_path in enumerate(image_paths):
if _comet_image_prediction_count >= max_image_predictions:
return
@ -267,28 +275,23 @@ def _log_image_predictions(experiment, validator, curr_step):

def _log_plots(experiment, trainer):
"""Logs evaluation plots and label plots for the experiment."""
plot_filenames = [trainer.save_dir / f'{plots}.png' for plots in EVALUATION_PLOT_NAMES]
plot_filenames = [trainer.save_dir / f"{plots}.png" for plots in EVALUATION_PLOT_NAMES]
_log_images(experiment, plot_filenames, None)

label_plot_filenames = [trainer.save_dir / f'{labels}.jpg' for labels in LABEL_PLOT_NAMES]
label_plot_filenames = [trainer.save_dir / f"{labels}.jpg" for labels in LABEL_PLOT_NAMES]
_log_images(experiment, label_plot_filenames, None)


def _log_model(experiment, trainer):
"""Log the best-trained model to Comet.ml."""
model_name = _get_comet_model_name()
experiment.log_model(
model_name,
file_or_folder=str(trainer.best),
file_name='best.pt',
overwrite=True,
)
experiment.log_model(model_name, file_or_folder=str(trainer.best), file_name="best.pt", overwrite=True)


def on_pretrain_routine_start(trainer):
"""Creates or resumes a CometML experiment at the start of a YOLO pre-training routine."""
experiment = comet_ml.get_global_experiment()
is_alive = getattr(experiment, 'alive', False)
is_alive = getattr(experiment, "alive", False)
if not experiment or not is_alive:
_create_experiment(trainer.args)

@ -300,17 +303,13 @@ def on_train_epoch_end(trainer):
return

metadata = _fetch_trainer_metadata(trainer)
curr_epoch = metadata['curr_epoch']
curr_step = metadata['curr_step']
curr_epoch = metadata["curr_epoch"]
curr_step = metadata["curr_step"]

experiment.log_metrics(
trainer.label_loss_items(trainer.tloss, prefix='train'),
step=curr_step,
epoch=curr_epoch,
)
experiment.log_metrics(trainer.label_loss_items(trainer.tloss, prefix="train"), step=curr_step, epoch=curr_epoch)

if curr_epoch == 1:
_log_images(experiment, trainer.save_dir.glob('train_batch*.jpg'), curr_step)
_log_images(experiment, trainer.save_dir.glob("train_batch*.jpg"), curr_step)


def on_fit_epoch_end(trainer):
@ -320,14 +319,15 @@ def on_fit_epoch_end(trainer):
return

metadata = _fetch_trainer_metadata(trainer)
curr_epoch = metadata['curr_epoch']
curr_step = metadata['curr_step']
save_assets = metadata['save_assets']
curr_epoch = metadata["curr_epoch"]
curr_step = metadata["curr_step"]
save_assets = metadata["save_assets"]

experiment.log_metrics(trainer.metrics, step=curr_step, epoch=curr_epoch)
experiment.log_metrics(trainer.lr, step=curr_step, epoch=curr_epoch)
if curr_epoch == 1:
from ultralytics.utils.torch_utils import model_info_for_loggers

experiment.log_metrics(model_info_for_loggers(trainer), step=curr_step, epoch=curr_epoch)

if not save_assets:
@ -347,8 +347,8 @@ def on_train_end(trainer):
return

metadata = _fetch_trainer_metadata(trainer)
curr_epoch = metadata['curr_epoch']
curr_step = metadata['curr_step']
curr_epoch = metadata["curr_epoch"]
curr_step = metadata["curr_step"]
plots = trainer.args.plots

_log_model(experiment, trainer)
@ -363,8 +363,13 @@ def on_train_end(trainer):
_comet_image_prediction_count = 0


callbacks = {
'on_pretrain_routine_start': on_pretrain_routine_start,
'on_train_epoch_end': on_train_epoch_end,
'on_fit_epoch_end': on_fit_epoch_end,
'on_train_end': on_train_end} if comet_ml else {}
callbacks = (
{
"on_pretrain_routine_start": on_pretrain_routine_start,
"on_train_epoch_end": on_train_epoch_end,
"on_fit_epoch_end": on_fit_epoch_end,
"on_train_end": on_train_end,
}
if comet_ml
else {}
)
@ -1,26 +1,18 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING
from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, checks

try:
assert not TESTS_RUNNING  # do not log pytest
assert SETTINGS['dvc'] is True  # verify integration is enabled
assert SETTINGS["dvc"] is True  # verify integration is enabled
import dvclive

assert hasattr(dvclive, '__version__')  # verify package is not directory
assert checks.check_version("dvclive", "2.11.0", verbose=True)

import os
import re
from importlib.metadata import version
from pathlib import Path

import pkg_resources as pkg

ver = version('dvclive')
if pkg.parse_version(ver) < pkg.parse_version('2.11.0'):
LOGGER.debug(f'DVCLive is detected but version {ver} is incompatible (>=2.11 required).')
dvclive = None  # noqa: F811

# DVCLive logger instance
live = None
_processed_plots = {}
@ -33,108 +25,121 @@ except (ImportError, AssertionError, TypeError):
dvclive = None


def _log_images(path, prefix=''):
def _log_images(path, prefix=""):
"""Logs images at specified path with an optional prefix using DVCLive."""
if live:
name = path.name

# Group images by batch to enable sliders in UI
if m := re.search(r'_batch(\d+)', name):
if m := re.search(r"_batch(\d+)", name):
ni = m[1]
new_stem = re.sub(r'_batch(\d+)', '_batch', path.stem)
new_stem = re.sub(r"_batch(\d+)", "_batch", path.stem)
name = (Path(new_stem) / ni).with_suffix(path.suffix)

live.log_image(os.path.join(prefix, name), path)


def _log_plots(plots, prefix=''):
def _log_plots(plots, prefix=""):
"""Logs plot images for training progress if they have not been previously processed."""
for name, params in plots.items():
timestamp = params['timestamp']
timestamp = params["timestamp"]
if _processed_plots.get(name) != timestamp:
_log_images(name, prefix)
_processed_plots[name] = timestamp


def _log_confusion_matrix(validator):
"""Logs the confusion matrix for the given validator using DVCLive."""
targets = []
preds = []
matrix = validator.confusion_matrix.matrix
names = list(validator.names.values())
if validator.confusion_matrix.task == 'detect':
names += ['background']
if validator.confusion_matrix.task == "detect":
names += ["background"]

for ti, pred in enumerate(matrix.T.astype(int)):
for pi, num in enumerate(pred):
targets.extend([names[ti]] * num)
preds.extend([names[pi]] * num)

live.log_sklearn_plot('confusion_matrix', targets, preds, name='cf.json', normalized=True)
live.log_sklearn_plot("confusion_matrix", targets, preds, name="cf.json", normalized=True)


def on_pretrain_routine_start(trainer):
"""Initializes DVCLive logger for training metadata during pre-training routine."""
try:
global live
live = dvclive.Live(save_dvc_exp=True, cache_images=True)
LOGGER.info(
f'DVCLive is detected and auto logging is enabled (can be disabled in the {SETTINGS.file} with `dvc: false`).'
)
LOGGER.info("DVCLive is detected and auto logging is enabled (run 'yolo settings dvc=False' to disable).")
except Exception as e:
LOGGER.warning(f'WARNING ⚠️ DVCLive installed but not initialized correctly, not logging this run. {e}')
LOGGER.warning(f"WARNING ⚠️ DVCLive installed but not initialized correctly, not logging this run. {e}")


def on_pretrain_routine_end(trainer):
_log_plots(trainer.plots, 'train')
"""Logs plots related to the training process at the end of the pretraining routine."""
_log_plots(trainer.plots, "train")


def on_train_start(trainer):
"""Logs the training parameters if DVCLive logging is active."""
if live:
live.log_params(trainer.args)


def on_train_epoch_start(trainer):
"""Sets the global variable _training_epoch value to True at the start of training each epoch."""
global _training_epoch
_training_epoch = True


def on_fit_epoch_end(trainer):
"""Logs training metrics and model info, and advances to next step on the end of each fit epoch."""
global _training_epoch
if live and _training_epoch:
all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics, **trainer.lr}
all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix="train"), **trainer.metrics, **trainer.lr}
for metric, value in all_metrics.items():
live.log_metric(metric, value)

if trainer.epoch == 0:
from ultralytics.utils.torch_utils import model_info_for_loggers

for metric, value in model_info_for_loggers(trainer).items():
live.log_metric(metric, value, plot=False)

_log_plots(trainer.plots, 'train')
_log_plots(trainer.validator.plots, 'val')
_log_plots(trainer.plots, "train")
_log_plots(trainer.validator.plots, "val")

live.next_step()
_training_epoch = False


def on_train_end(trainer):
"""Logs the best metrics, plots, and confusion matrix at the end of training if DVCLive is active."""
if live:
# At the end log the best metrics. It runs validator on the best model internally.
all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics, **trainer.lr}
all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix="train"), **trainer.metrics, **trainer.lr}
for metric, value in all_metrics.items():
live.log_metric(metric, value, plot=False)

_log_plots(trainer.plots, 'val')
_log_plots(trainer.validator.plots, 'val')
_log_plots(trainer.plots, "val")
_log_plots(trainer.validator.plots, "val")
_log_confusion_matrix(trainer.validator)

if trainer.best.exists():
live.log_artifact(trainer.best, copy=True, type='model')
live.log_artifact(trainer.best, copy=True, type="model")

live.end()


callbacks = {
'on_pretrain_routine_start': on_pretrain_routine_start,
'on_pretrain_routine_end': on_pretrain_routine_end,
'on_train_start': on_train_start,
'on_train_epoch_start': on_train_epoch_start,
'on_fit_epoch_end': on_fit_epoch_end,
'on_train_end': on_train_end} if dvclive else {}
callbacks = (
{
"on_pretrain_routine_start": on_pretrain_routine_start,
"on_pretrain_routine_end": on_pretrain_routine_end,
"on_train_start": on_train_start,
"on_train_epoch_start": on_train_epoch_start,
"on_fit_epoch_end": on_fit_epoch_end,
"on_train_end": on_train_end,
}
if dvclive
else {}
)
@ -9,51 +9,67 @@ from ultralytics.utils import LOGGER, SETTINGS

def on_pretrain_routine_end(trainer):
"""Logs info before starting timer for upload rate limit."""
session = getattr(trainer, 'hub_session', None)
session = getattr(trainer, "hub_session", None)
if session:
# Start timer for upload rate limit
LOGGER.info(f'{PREFIX}View model at {HUB_WEB_ROOT}/models/{session.model_id} 🚀')
session.timers = {'metrics': time(), 'ckpt': time()}  # start timer on session.rate_limit
session.timers = {
"metrics": time(),
"ckpt": time(),
}  # start timer on session.rate_limit


def on_fit_epoch_end(trainer):
"""Uploads training progress metrics at the end of each epoch."""
session = getattr(trainer, 'hub_session', None)
session = getattr(trainer, "hub_session", None)
if session:
# Upload metrics after val end
all_plots = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics}
all_plots = {
**trainer.label_loss_items(trainer.tloss, prefix="train"),
**trainer.metrics,
}
if trainer.epoch == 0:
from ultralytics.utils.torch_utils import model_info_for_loggers

all_plots = {**all_plots, **model_info_for_loggers(trainer)}

session.metrics_queue[trainer.epoch] = json.dumps(all_plots)
if time() - session.timers['metrics'] > session.rate_limits['metrics']:

# If any metrics fail to upload, add them to the queue to attempt uploading again.
if session.metrics_upload_failed_queue:
session.metrics_queue.update(session.metrics_upload_failed_queue)

if time() - session.timers["metrics"] > session.rate_limits["metrics"]:
session.upload_metrics()
session.timers['metrics'] = time()  # reset timer
session.timers["metrics"] = time()  # reset timer
session.metrics_queue = {}  # reset queue


def on_model_save(trainer):
"""Saves checkpoints to Ultralytics HUB with rate limiting."""
session = getattr(trainer, 'hub_session', None)
session = getattr(trainer, "hub_session", None)
if session:
# Upload checkpoints with rate limiting
is_best = trainer.best_fitness == trainer.fitness
if time() - session.timers['ckpt'] > session.rate_limits['ckpt']:
LOGGER.info(f'{PREFIX}Uploading checkpoint {HUB_WEB_ROOT}/models/{session.model_id}')
if time() - session.timers["ckpt"] > session.rate_limits["ckpt"]:
LOGGER.info(f"{PREFIX}Uploading checkpoint {HUB_WEB_ROOT}/models/{session.model.id}")
session.upload_model(trainer.epoch, trainer.last, is_best)
session.timers['ckpt'] = time()  # reset timer
session.timers["ckpt"] = time()  # reset timer


def on_train_end(trainer):
"""Upload final model and metrics to Ultralytics HUB at the end of training."""
session = getattr(trainer, 'hub_session', None)
session = getattr(trainer, "hub_session", None)
if session:
# Upload final model and metrics with exponential backoff
LOGGER.info(f'{PREFIX}Syncing final model...')
session.upload_model(trainer.epoch, trainer.best, map=trainer.metrics.get('metrics/mAP50-95(B)', 0), final=True)
LOGGER.info(f"{PREFIX}Syncing final model...")
session.upload_model(
trainer.epoch,
trainer.best,
map=trainer.metrics.get("metrics/mAP50-95(B)", 0),
final=True,
)
session.alive = False  # stop heartbeats
LOGGER.info(f'{PREFIX}Done ✅\n'
f'{PREFIX}View model at {HUB_WEB_ROOT}/models/{session.model_id} 🚀')
LOGGER.info(f"{PREFIX}Done ✅\n" f"{PREFIX}View model at {session.model_url} 🚀")


def on_train_start(trainer):
@ -76,12 +92,17 @@ def on_export_start(exporter):
events(exporter.args)


callbacks = {
'on_pretrain_routine_end': on_pretrain_routine_end,
'on_fit_epoch_end': on_fit_epoch_end,
'on_model_save': on_model_save,
'on_train_end': on_train_end,
'on_train_start': on_train_start,
'on_val_start': on_val_start,
'on_predict_start': on_predict_start,
'on_export_start': on_export_start} if SETTINGS['hub'] is True else {}  # verify enabled
callbacks = (
{
"on_pretrain_routine_end": on_pretrain_routine_end,
"on_fit_epoch_end": on_fit_epoch_end,
"on_model_save": on_model_save,
"on_train_end": on_train_end,
"on_train_start": on_train_start,
"on_val_start": on_val_start,
"on_predict_start": on_predict_start,
"on_export_start": on_export_start,
}
if SETTINGS["hub"] is True
else {}
)  # verify enabled
@ -1,70 +1,133 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
MLflow Logging for Ultralytics YOLO.

from ultralytics.utils import LOGGER, ROOT, SETTINGS, TESTS_RUNNING, colorstr
This module enables MLflow logging for Ultralytics YOLO. It logs metrics, parameters, and model artifacts.
For setting up, a tracking URI should be specified. The logging can be customized using environment variables.

Commands:
1. To set a project name:
`export MLFLOW_EXPERIMENT_NAME=<your_experiment_name>` or use the project=<project> argument

2. To set a run name:
`export MLFLOW_RUN=<your_run_name>` or use the name=<name> argument

3. To start a local MLflow server:
mlflow server --backend-store-uri runs/mlflow
It will by default start a local server at http://127.0.0.1:5000.
To specify a different URI, set the MLFLOW_TRACKING_URI environment variable.

4. To kill all running MLflow server instances:
ps aux | grep 'mlflow' | grep -v 'grep' | awk '{print $2}' | xargs kill -9
"""

from ultralytics.utils import LOGGER, RUNS_DIR, SETTINGS, TESTS_RUNNING, colorstr

try:
assert not TESTS_RUNNING  # do not log pytest
assert SETTINGS['mlflow'] is True  # verify integration is enabled
import os

assert not TESTS_RUNNING or "test_mlflow" in os.environ.get("PYTEST_CURRENT_TEST", "")  # do not log pytest
assert SETTINGS["mlflow"] is True  # verify integration is enabled
import mlflow

assert hasattr(mlflow, '__version__')  # verify package is not directory
assert hasattr(mlflow, "__version__")  # verify package is not directory
from pathlib import Path

import os
import re
PREFIX = colorstr("MLflow: ")
SANITIZE = lambda x: {k.replace("(", "").replace(")", ""): float(v) for k, v in x.items()}

except (ImportError, AssertionError):
mlflow = None


def on_pretrain_routine_end(trainer):
"""Logs training parameters to MLflow."""
global mlflow, run, experiment_name
"""
Log training parameters to MLflow at the end of the pretraining routine.

if os.environ.get('MLFLOW_TRACKING_URI') is None:
mlflow = None
This function sets up MLflow logging based on environment variables and trainer arguments. It sets the tracking URI,
experiment name, and run name, then starts the MLflow run if not already active. It finally logs the parameters
from the trainer.

Args:
trainer (ultralytics.engine.trainer.BaseTrainer): The training object with arguments and parameters to log.

Global:
mlflow: The imported mlflow module to use for logging.

Environment Variables:
MLFLOW_TRACKING_URI: The URI for MLflow tracking. If not set, defaults to 'runs/mlflow'.
MLFLOW_EXPERIMENT_NAME: The name of the MLflow experiment. If not set, defaults to trainer.args.project.
MLFLOW_RUN: The name of the MLflow run. If not set, defaults to trainer.args.name.
MLFLOW_KEEP_RUN_ACTIVE: Boolean indicating whether to keep the MLflow run active after the end of the training phase.
"""
global mlflow

uri = os.environ.get("MLFLOW_TRACKING_URI") or str(RUNS_DIR / "mlflow")
LOGGER.debug(f"{PREFIX} tracking uri: {uri}")
mlflow.set_tracking_uri(uri)

# Set experiment and run names
experiment_name = os.environ.get("MLFLOW_EXPERIMENT_NAME") or trainer.args.project or "/Shared/YOLOv8"
run_name = os.environ.get("MLFLOW_RUN") or trainer.args.name
mlflow.set_experiment(experiment_name)

mlflow.autolog()
try:
active_run = mlflow.active_run() or mlflow.start_run(run_name=run_name)
LOGGER.info(f"{PREFIX}logging run_id({active_run.info.run_id}) to {uri}")
if Path(uri).is_dir():
LOGGER.info(f"{PREFIX}view at http://127.0.0.1:5000 with 'mlflow server --backend-store-uri {uri}'")
LOGGER.info(f"{PREFIX}disable with 'yolo settings mlflow=False'")
mlflow.log_params(dict(trainer.args))
except Exception as e:
LOGGER.warning(f"{PREFIX}WARNING ⚠️ Failed to initialize: {e}\n" f"{PREFIX}WARNING ⚠️ Not tracking this run")

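Since the rewritten integration is configured entirely through environment variables, a typical setup before training might look like this (a sketch using the variables documented in the docstring above; values illustrative):

```
import os

os.environ["MLFLOW_TRACKING_URI"] = "http://127.0.0.1:5000"  # or leave unset for runs/mlflow
os.environ["MLFLOW_EXPERIMENT_NAME"] = "yolo-experiments"
os.environ["MLFLOW_RUN"] = "yolov8n-baseline"
os.environ["MLFLOW_KEEP_RUN_ACTIVE"] = "True"  # keep the run open for extra logging after training
```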
def on_train_epoch_end(trainer):
"""Log training metrics at the end of each train epoch to MLflow."""
if mlflow:
mlflow_location = os.environ['MLFLOW_TRACKING_URI']  # "http://192.168.xxx.xxx:5000"
mlflow.set_tracking_uri(mlflow_location)

experiment_name = os.environ.get('MLFLOW_EXPERIMENT_NAME') or trainer.args.project or '/Shared/YOLOv8'
run_name = os.environ.get('MLFLOW_RUN') or trainer.args.name
experiment = mlflow.get_experiment_by_name(experiment_name)
if experiment is None:
mlflow.create_experiment(experiment_name)
mlflow.set_experiment(experiment_name)

prefix = colorstr('MLFlow: ')
try:
run, active_run = mlflow, mlflow.active_run()
if not active_run:
active_run = mlflow.start_run(experiment_id=experiment.experiment_id, run_name=run_name)
LOGGER.info(f'{prefix}Using run_id({active_run.info.run_id}) at {mlflow_location}')
run.log_params(vars(trainer.model.args))
except Exception as err:
LOGGER.error(f'{prefix}Failing init - {repr(err)}')
LOGGER.warning(f'{prefix}Continuing without Mlflow')
mlflow.log_metrics(
metrics={
**SANITIZE(trainer.lr),
**SANITIZE(trainer.label_loss_items(trainer.tloss, prefix="train")),
},
step=trainer.epoch,
)


def on_fit_epoch_end(trainer):
"""Logs training metrics to Mlflow."""
"""Log training metrics at the end of each fit epoch to MLflow."""
if mlflow:
metrics_dict = {f"{re.sub('[()]', '', k)}": float(v) for k, v in trainer.metrics.items()}
run.log_metrics(metrics=metrics_dict, step=trainer.epoch)
mlflow.log_metrics(metrics=SANITIZE(trainer.metrics), step=trainer.epoch)


def on_train_end(trainer):
"""Called at end of train loop to log model artifact info."""
"""Log model artifacts at the end of the training."""
if mlflow:
run.log_artifact(trainer.last)
run.log_artifact(trainer.best)
run.pyfunc.log_model(artifact_path=experiment_name,
code_path=[str(ROOT.parent)],
artifacts={'model_path': str(trainer.save_dir)},
python_model=run.pyfunc.PythonModel())
mlflow.log_artifact(str(trainer.best.parent))  # log save_dir/weights directory with best.pt and last.pt
for f in trainer.save_dir.glob("*"):  # log all other files in save_dir
if f.suffix in {".png", ".jpg", ".csv", ".pt", ".yaml"}:
mlflow.log_artifact(str(f))
keep_run_active = os.environ.get("MLFLOW_KEEP_RUN_ACTIVE", "False").lower() in ("true")
        if keep_run_active:
            LOGGER.info(f"{PREFIX}mlflow run still alive, remember to close it using mlflow.end_run()")
        else:
            mlflow.end_run()
            LOGGER.debug(f"{PREFIX}mlflow run ended")

        LOGGER.info(
            f"{PREFIX}results logged to {mlflow.get_tracking_uri()}\n"
            f"{PREFIX}disable with 'yolo settings mlflow=False'"
        )


callbacks = {
    'on_pretrain_routine_end': on_pretrain_routine_end,
    'on_fit_epoch_end': on_fit_epoch_end,
    'on_train_end': on_train_end} if mlflow else {}
callbacks = (
    {
        "on_pretrain_routine_end": on_pretrain_routine_end,
        "on_train_epoch_end": on_train_epoch_end,
        "on_fit_epoch_end": on_fit_epoch_end,
        "on_train_end": on_train_end,
    }
    if mlflow
    else {}
)

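The rewritten MLflow callbacks above reduce to a standard MLflow round-trip. A minimal standalone sketch of that flow (the file-store URI, experiment name, and metric keys below are illustrative placeholders, not part of the commit):

```python
import mlflow

mlflow.set_tracking_uri("file:./runs/mlflow")  # assumption: local file store instead of a tracking server
mlflow.set_experiment("/Shared/YOLOv8")

with mlflow.start_run(run_name="demo"):
    mlflow.log_params({"epochs": 10, "lr0": 0.01})  # stands in for dict(trainer.args)
    for epoch in range(3):
        # SANITIZE in the callback strips '()' from metric keys; plain keys need no cleanup
        mlflow.log_metrics({"train/box_loss": 1.0 / (epoch + 1)}, step=epoch)
# leaving the context manager ends the run, like mlflow.end_run() in on_train_end
```
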
@ -4,11 +4,11 @@ from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING

try:
    assert not TESTS_RUNNING  # do not log pytest
    assert SETTINGS['neptune'] is True  # verify integration is enabled
    assert SETTINGS["neptune"] is True  # verify integration is enabled
    import neptune
    from neptune.types import File

    assert hasattr(neptune, '__version__')
    assert hasattr(neptune, "__version__")

    run = None  # NeptuneAI experiment logger instance

@ -23,55 +23,55 @@ def _log_scalars(scalars, step=0):
            run[k].append(value=v, step=step)


def _log_images(imgs_dict, group=''):
def _log_images(imgs_dict, group=""):
"""Log scalars to the NeptuneAI experiment logger."""
    if run:
        for k, v in imgs_dict.items():
            run[f'{group}/{k}'].upload(File(v))
            run[f"{group}/{k}"].upload(File(v))


def _log_plot(title, plot_path):
    """Log plots to the NeptuneAI experiment logger."""
    """
    Log image as plot in the plot section of NeptuneAI
    Log plots to the NeptuneAI experiment logger.

    arguments:
    title (str) Title of the plot
    plot_path (PosixPath or str) Path to the saved image file
    """
    Args:
        title (str): Title of the plot.
        plot_path (PosixPath | str): Path to the saved image file.
    """
    import matplotlib.image as mpimg
    import matplotlib.pyplot as plt

    img = mpimg.imread(plot_path)
    fig = plt.figure()
    ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect='auto', xticks=[], yticks=[])  # no ticks
    ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect="auto", xticks=[], yticks=[])  # no ticks
    ax.imshow(img)
    run[f'Plots/{title}'].upload(fig)
    run[f"Plots/{title}"].upload(fig)


def on_pretrain_routine_start(trainer):
    """Callback function called before the training routine starts."""
    try:
        global run
        run = neptune.init_run(project=trainer.args.project or 'YOLOv8', name=trainer.args.name, tags=['YOLOv8'])
        run['Configuration/Hyperparameters'] = {k: '' if v is None else v for k, v in vars(trainer.args).items()}
        run = neptune.init_run(project=trainer.args.project or "YOLOv8", name=trainer.args.name, tags=["YOLOv8"])
        run["Configuration/Hyperparameters"] = {k: "" if v is None else v for k, v in vars(trainer.args).items()}
    except Exception as e:
        LOGGER.warning(f'WARNING ⚠️ NeptuneAI installed but not initialized correctly, not logging this run. {e}')
        LOGGER.warning(f"WARNING ⚠️ NeptuneAI installed but not initialized correctly, not logging this run. {e}")


def on_train_epoch_end(trainer):
    """Callback function called at end of each training epoch."""
    _log_scalars(trainer.label_loss_items(trainer.tloss, prefix='train'), trainer.epoch + 1)
    _log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1)
    _log_scalars(trainer.lr, trainer.epoch + 1)
    if trainer.epoch == 1:
        _log_images({f.stem: str(f) for f in trainer.save_dir.glob('train_batch*.jpg')}, 'Mosaic')
        _log_images({f.stem: str(f) for f in trainer.save_dir.glob("train_batch*.jpg")}, "Mosaic")


def on_fit_epoch_end(trainer):
    """Callback function called at end of each fit (train+val) epoch."""
    if run and trainer.epoch == 0:
        from ultralytics.utils.torch_utils import model_info_for_loggers
        run['Configuration/Model'] = model_info_for_loggers(trainer)

        run["Configuration/Model"] = model_info_for_loggers(trainer)
    _log_scalars(trainer.metrics, trainer.epoch + 1)


@ -79,7 +79,7 @@ def on_val_end(validator):
    """Callback function called at end of each validation."""
    if run:
        # Log val_labels and val_pred
        _log_images({f.stem: str(f) for f in validator.save_dir.glob('val*.jpg')}, 'Validation')
        _log_images({f.stem: str(f) for f in validator.save_dir.glob("val*.jpg")}, "Validation")


def on_train_end(trainer):
@ -87,19 +87,26 @@ def on_train_end(trainer):
    if run:
        # Log final results, CM matrix + PR plots
        files = [
            'results.png', 'confusion_matrix.png', 'confusion_matrix_normalized.png',
            *(f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R'))]
            "results.png",
            "confusion_matrix.png",
            "confusion_matrix_normalized.png",
            *(f"{x}_curve.png" for x in ("F1", "PR", "P", "R")),
        ]
        files = [(trainer.save_dir / f) for f in files if (trainer.save_dir / f).exists()]  # filter
        for f in files:
            _log_plot(title=f.stem, plot_path=f)
        # Log the final model
        run[f'weights/{trainer.args.name or trainer.args.task}/{str(trainer.best.name)}'].upload(File(str(
            trainer.best)))
        run[f"weights/{trainer.args.name or trainer.args.task}/{trainer.best.name}"].upload(File(str(trainer.best)))


callbacks = {
    'on_pretrain_routine_start': on_pretrain_routine_start,
    'on_train_epoch_end': on_train_epoch_end,
    'on_fit_epoch_end': on_fit_epoch_end,
    'on_val_end': on_val_end,
    'on_train_end': on_train_end} if neptune else {}
callbacks = (
    {
        "on_pretrain_routine_start": on_pretrain_routine_start,
        "on_train_epoch_end": on_train_epoch_end,
        "on_fit_epoch_end": on_fit_epoch_end,
        "on_val_end": on_val_end,
        "on_train_end": on_train_end,
    }
    if neptune
    else {}
)

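For reference, the Neptune pattern used above is append-based namespace logging. A minimal sketch, assuming the neptune 1.x client (the project name, values, and file path are placeholders):

```python
import neptune
from neptune.types import File

run = neptune.init_run(project="workspace/YOLOv8", tags=["YOLOv8"])  # placeholder project
run["Configuration/Hyperparameters"] = {"epochs": 10, "imgsz": 640}
for epoch in range(3):
    run["train/box_loss"].append(value=1.0 / (epoch + 1), step=epoch + 1)  # mirrors _log_scalars
run["Plots/results"].upload(File("results.png"))  # mirrors _log_plot/_log_images; file assumed to exist
run.stop()
```
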
@ -3,7 +3,7 @@
from ultralytics.utils import SETTINGS

try:
    assert SETTINGS['raytune'] is True  # verify integration is enabled
    assert SETTINGS["raytune"] is True  # verify integration is enabled
    import ray
    from ray import tune
    from ray.air import session
@ -16,9 +16,14 @@ def on_fit_epoch_end(trainer):
    """Sends training metrics to Ray Tune at end of each epoch."""
    if ray.tune.is_session_enabled():
        metrics = trainer.metrics
        metrics['epoch'] = trainer.epoch
        metrics["epoch"] = trainer.epoch
        session.report(metrics)


callbacks = {
    'on_fit_epoch_end': on_fit_epoch_end, } if tune else {}
callbacks = (
    {
        "on_fit_epoch_end": on_fit_epoch_end,
    }
    if tune
    else {}
)

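The guard on ray.tune.is_session_enabled() matters because session.report() only works inside a Tune trial. A minimal sketch of the reporting side, assuming Ray 2.x (the search space and metric key are illustrative):

```python
from ray import tune
from ray.air import session


def trainable(config):
    """Toy trainable that reports a metric per epoch, as on_fit_epoch_end does."""
    for epoch in range(3):
        session.report({"epoch": epoch, "metrics/mAP50(B)": config["lr"] * (epoch + 1)})


tuner = tune.Tuner(trainable, param_space={"lr": tune.grid_search([0.001, 0.01])})
results = tuner.fit()
```
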
@ -1,17 +1,25 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import contextlib

from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, colorstr

try:
    # WARNING: do not move import due to protobuf issue in https://github.com/ultralytics/ultralytics/pull/4674
    # WARNING: do not move SummaryWriter import due to protobuf bug https://github.com/ultralytics/ultralytics/pull/4674
    from torch.utils.tensorboard import SummaryWriter

    assert not TESTS_RUNNING  # do not log pytest
    assert SETTINGS['tensorboard'] is True  # verify integration is enabled
    assert SETTINGS["tensorboard"] is True  # verify integration is enabled
    WRITER = None  # TensorBoard SummaryWriter instance
    PREFIX = colorstr("TensorBoard: ")

except (ImportError, AssertionError, TypeError):
    # Imports below only required if TensorBoard enabled
    import warnings
    from copy import deepcopy
    from ultralytics.utils.torch_utils import de_parallel, torch

except (ImportError, AssertionError, TypeError, AttributeError):
    # TypeError for handling 'Descriptors cannot not be created directly.' protobuf errors in Windows
    # AttributeError: module 'tensorflow' has no attribute 'io' if 'tensorflow' not installed
    SummaryWriter = None


@ -24,20 +32,38 @@ def _log_scalars(scalars, step=0):

def _log_tensorboard_graph(trainer):
    """Log model graph to TensorBoard."""
    try:
        import warnings

        from ultralytics.utils.torch_utils import de_parallel, torch
    # Input image
    imgsz = trainer.args.imgsz
    imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz
    p = next(trainer.model.parameters())  # for device, type
    im = torch.zeros((1, 3, *imgsz), device=p.device, dtype=p.dtype)  # input image (must be zeros, not empty)

        imgsz = trainer.args.imgsz
        imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz
        p = next(trainer.model.parameters())  # for device, type
        im = torch.zeros((1, 3, *imgsz), device=p.device, dtype=p.dtype)  # input image (must be zeros, not empty)
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', category=UserWarning)  # suppress jit trace warning
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=UserWarning)  # suppress jit trace warning
        warnings.simplefilter("ignore", category=torch.jit.TracerWarning)  # suppress jit trace warning

        # Try simple method first (YOLO)
        with contextlib.suppress(Exception):
            trainer.model.eval()  # place in .eval() mode to avoid BatchNorm statistics changes
            WRITER.add_graph(torch.jit.trace(de_parallel(trainer.model), im, strict=False), [])
    except Exception as e:
        LOGGER.warning(f'WARNING ⚠️ TensorBoard graph visualization failure {e}')
            LOGGER.info(f"{PREFIX}model graph visualization added ✅")
            return

        # Fallback to TorchScript export steps (RTDETR)
        try:
            model = deepcopy(de_parallel(trainer.model))
            model.eval()
            model = model.fuse(verbose=False)
            for m in model.modules():
                if hasattr(m, "export"):  # Detect, RTDETRDecoder (Segment and Pose use Detect base class)
                    m.export = True
                    m.format = "torchscript"
            model(im)  # dry run
            WRITER.add_graph(torch.jit.trace(model, im, strict=False), [])
            LOGGER.info(f"{PREFIX}model graph visualization added ✅")
        except Exception as e:
            LOGGER.warning(f"{PREFIX}WARNING ⚠️ TensorBoard graph visualization failure {e}")


def on_pretrain_routine_start(trainer):
@ -46,10 +72,9 @@ def on_pretrain_routine_start(trainer):
    try:
        global WRITER
        WRITER = SummaryWriter(str(trainer.save_dir))
        prefix = colorstr('TensorBoard: ')
        LOGGER.info(f"{prefix}Start with 'tensorboard --logdir {trainer.save_dir}', view at http://localhost:6006/")
        LOGGER.info(f"{PREFIX}Start with 'tensorboard --logdir {trainer.save_dir}', view at http://localhost:6006/")
    except Exception as e:
        LOGGER.warning(f'WARNING ⚠️ TensorBoard not initialized correctly, not logging this run. {e}')
        LOGGER.warning(f"{PREFIX}WARNING ⚠️ TensorBoard not initialized correctly, not logging this run. {e}")


def on_train_start(trainer):
@ -58,9 +83,10 @@ def on_train_start(trainer):
        _log_tensorboard_graph(trainer)


def on_batch_end(trainer):
    """Logs scalar statistics at the end of a training batch."""
    _log_scalars(trainer.label_loss_items(trainer.tloss, prefix='train'), trainer.epoch + 1)
def on_train_epoch_end(trainer):
    """Logs scalar statistics at the end of a training epoch."""
    _log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1)
    _log_scalars(trainer.lr, trainer.epoch + 1)


def on_fit_epoch_end(trainer):
@ -68,8 +94,13 @@ def on_fit_epoch_end(trainer):
    _log_scalars(trainer.metrics, trainer.epoch + 1)


callbacks = {
    'on_pretrain_routine_start': on_pretrain_routine_start,
    'on_train_start': on_train_start,
    'on_fit_epoch_end': on_fit_epoch_end,
    'on_batch_end': on_batch_end} if SummaryWriter else {}
callbacks = (
    {
        "on_pretrain_routine_start": on_pretrain_routine_start,
        "on_train_start": on_train_start,
        "on_fit_epoch_end": on_fit_epoch_end,
        "on_train_epoch_end": on_train_epoch_end,
    }
    if SummaryWriter
    else {}
)

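The two-stage graph logging above (plain jit.trace first, TorchScript-export fallback second) reduces to the following pattern. A minimal sketch with a toy model (the model and sizes are illustrative, not the trainer's):

```python
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU()).eval()  # toy stand-in for trainer.model
im = torch.zeros(1, 3, 64, 64)  # zeros, not empty, so traced ops see valid values

writer = SummaryWriter("runs/graph_demo")
writer.add_graph(torch.jit.trace(model, im, strict=False), [])  # same call as _log_tensorboard_graph
writer.close()
```
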
@ -5,10 +5,13 @@ from ultralytics.utils.torch_utils import model_info_for_loggers

try:
    assert not TESTS_RUNNING  # do not log pytest
    assert SETTINGS['wandb'] is True  # verify integration is enabled
    assert SETTINGS["wandb"] is True  # verify integration is enabled
    import wandb as wb

    assert hasattr(wb, '__version__')
    assert hasattr(wb, "__version__")  # verify package is not directory

    import numpy as np
    import pandas as pd

    _processed_plots = {}

@ -16,9 +19,89 @@ except (ImportError, AssertionError):
    wb = None


def _custom_table(x, y, classes, title="Precision Recall Curve", x_title="Recall", y_title="Precision"):
    """
    Create and log a custom metric visualization to wandb.plot.pr_curve.

    This function crafts a custom metric visualization that mimics the behavior of wandb's default precision-recall
    curve while allowing for enhanced customization. The visual metric is useful for monitoring model performance across
    different classes.

    Args:
        x (List): Values for the x-axis; expected to have length N.
        y (List): Corresponding values for the y-axis; also expected to have length N.
        classes (List): Labels identifying the class of each point; length N.
        title (str, optional): Title for the plot; defaults to 'Precision Recall Curve'.
        x_title (str, optional): Label for the x-axis; defaults to 'Recall'.
        y_title (str, optional): Label for the y-axis; defaults to 'Precision'.

    Returns:
        (wandb.Object): A wandb object suitable for logging, showcasing the crafted metric visualization.
    """
    df = pd.DataFrame({"class": classes, "y": y, "x": x}).round(3)
    fields = {"x": "x", "y": "y", "class": "class"}
    string_fields = {"title": title, "x-axis-title": x_title, "y-axis-title": y_title}
    return wb.plot_table(
        "wandb/area-under-curve/v0", wb.Table(dataframe=df), fields=fields, string_fields=string_fields
    )


def _plot_curve(
    x,
    y,
    names=None,
    id="precision-recall",
    title="Precision Recall Curve",
    x_title="Recall",
    y_title="Precision",
    num_x=100,
    only_mean=False,
):
    """
    Log a metric curve visualization.

    This function generates a metric curve based on input data and logs the visualization to wandb.
    The curve can represent aggregated data (mean) or individual class data, depending on the 'only_mean' flag.

    Args:
        x (np.ndarray): Data points for the x-axis with length N.
        y (np.ndarray): Corresponding data points for the y-axis with shape CxN, where C is the number of classes.
        names (list, optional): Names of the classes corresponding to the y-axis data; length C. Defaults to [].
        id (str, optional): Unique identifier for the logged data in wandb. Defaults to 'precision-recall'.
        title (str, optional): Title for the visualization plot. Defaults to 'Precision Recall Curve'.
        x_title (str, optional): Label for the x-axis. Defaults to 'Recall'.
        y_title (str, optional): Label for the y-axis. Defaults to 'Precision'.
        num_x (int, optional): Number of interpolated data points for visualization. Defaults to 100.
        only_mean (bool, optional): Flag to indicate if only the mean curve should be plotted. Defaults to False.

    Note:
        The function leverages the '_custom_table' function to generate the actual visualization.
    """
    # Create new x
    if names is None:
        names = []
    x_new = np.linspace(x[0], x[-1], num_x).round(5)

    # Create arrays for logging
    x_log = x_new.tolist()
    y_log = np.interp(x_new, x, np.mean(y, axis=0)).round(3).tolist()

    if only_mean:
        table = wb.Table(data=list(zip(x_log, y_log)), columns=[x_title, y_title])
        wb.run.log({title: wb.plot.line(table, x_title, y_title, title=title)})
    else:
        classes = ["mean"] * len(x_log)
        for i, yi in enumerate(y):
            x_log.extend(x_new)  # add new x
            y_log.extend(np.interp(x_new, x, yi))  # interpolate y to new x
            classes.extend([names[i]] * len(x_new))  # add class names
        wb.log({id: _custom_table(x_log, y_log, classes, title, x_title, y_title)}, commit=False)


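Usage of the two curve helpers above, with synthetic precision-recall data (shapes follow the docstrings; the project name and class names are placeholders, and an active wandb run is assumed):

```python
import numpy as np
import wandb as wb

wb.init(project="curves-demo")  # placeholder project
x = np.linspace(0, 1, 50)  # recall axis, length N
y = np.stack([1 - x, (1 - x) ** 2, (1 - x) ** 3])  # CxN, one synthetic curve per class

_plot_curve(x, y, names=["person", "car", "dog"], id="curves/PR", only_mean=False)
wb.log({"epoch": 0})  # _plot_curve logs with commit=False, so a later log flushes the step
wb.run.finish()
```
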
def _log_plots(plots, step):
    """Logs plots from the input dictionary if they haven't been logged already at the specified step."""
    for name, params in plots.items():
        timestamp = params['timestamp']
        timestamp = params["timestamp"]
        if _processed_plots.get(name) != timestamp:
            wb.run.log({name.stem: wb.Image(str(name))}, step=step)
            _processed_plots[name] = timestamp
@ -26,7 +109,7 @@ def _log_plots(plots, step):

def on_pretrain_routine_start(trainer):
    """Initiate and start project if module is present."""
    wb.run or wb.init(project=trainer.args.project or 'YOLOv8', name=trainer.args.name, config=vars(trainer.args))
    wb.run or wb.init(project=trainer.args.project or "YOLOv8", name=trainer.args.name, config=vars(trainer.args))


def on_fit_epoch_end(trainer):
@ -40,7 +123,7 @@ def on_fit_epoch_end(trainer):

def on_train_epoch_end(trainer):
    """Log metrics and save images at the end of each training epoch."""
    wb.run.log(trainer.label_loss_items(trainer.tloss, prefix='train'), step=trainer.epoch + 1)
    wb.run.log(trainer.label_loss_items(trainer.tloss, prefix="train"), step=trainer.epoch + 1)
    wb.run.log(trainer.lr, step=trainer.epoch + 1)
    if trainer.epoch == 1:
        _log_plots(trainer.plots, step=trainer.epoch + 1)
@ -50,14 +133,31 @@ def on_train_end(trainer):
    """Save the best model as an artifact at end of training."""
    _log_plots(trainer.validator.plots, step=trainer.epoch + 1)
    _log_plots(trainer.plots, step=trainer.epoch + 1)
    art = wb.Artifact(type='model', name=f'run_{wb.run.id}_model')
    art = wb.Artifact(type="model", name=f"run_{wb.run.id}_model")
    if trainer.best.exists():
        art.add_file(trainer.best)
        wb.run.log_artifact(art, aliases=['best'])
        wb.run.log_artifact(art, aliases=["best"])
    for curve_name, curve_values in zip(trainer.validator.metrics.curves, trainer.validator.metrics.curves_results):
        x, y, x_title, y_title = curve_values
        _plot_curve(
            x,
            y,
            names=list(trainer.validator.metrics.names.values()),
            id=f"curves/{curve_name}",
            title=curve_name,
            x_title=x_title,
            y_title=y_title,
        )
    wb.run.finish()  # required or run continues on dashboard


callbacks = {
    'on_pretrain_routine_start': on_pretrain_routine_start,
    'on_train_epoch_end': on_train_epoch_end,
    'on_fit_epoch_end': on_fit_epoch_end,
    'on_train_end': on_train_end} if wb else {}
callbacks = (
    {
        "on_pretrain_routine_start": on_pretrain_routine_start,
        "on_train_epoch_end": on_train_epoch_end,
        "on_fit_epoch_end": on_fit_epoch_end,
        "on_train_end": on_train_end,
    }
    if wb
    else {}
)

@ -1,4 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

import contextlib
import glob
import inspect
@ -9,20 +10,96 @@ import re
import shutil
import subprocess
import time
from importlib import metadata
from pathlib import Path
from typing import Optional

import cv2
import numpy as np
import pkg_resources as pkg
import psutil
import requests
import torch
from matplotlib import font_manager

from ultralytics.utils import (ASSETS, AUTOINSTALL, LINUX, LOGGER, ONLINE, ROOT, USER_CONFIG_DIR, ThreadingLocked,
                               TryExcept, clean_url, colorstr, downloads, emojis, is_colab, is_docker, is_jupyter,
                               is_kaggle, is_online, is_pip_package, url2file)
from ultralytics.utils import (
    ASSETS,
    AUTOINSTALL,
    LINUX,
    LOGGER,
    ONLINE,
    ROOT,
    USER_CONFIG_DIR,
    SimpleNamespace,
    ThreadingLocked,
    TryExcept,
    clean_url,
    colorstr,
    downloads,
    emojis,
    is_colab,
    is_docker,
    is_github_action_running,
    is_jupyter,
    is_kaggle,
    is_online,
    is_pip_package,
    url2file,
)

PYTHON_VERSION = platform.python_version()


def parse_requirements(file_path=ROOT.parent / "requirements.txt", package=""):
    """
    Parse a requirements.txt file, ignoring lines that start with '#' and any text after '#'.

    Args:
        file_path (Path): Path to the requirements.txt file.
        package (str, optional): Python package to use instead of requirements.txt file, i.e. package='ultralytics'.

    Returns:
        (List[SimpleNamespace]): List of parsed requirements as SimpleNamespace objects with `name` and `specifier` attributes.

    Example:
        ```python
        from ultralytics.utils.checks import parse_requirements

        parse_requirements(package='ultralytics')
        ```
    """

    if package:
        requires = [x for x in metadata.distribution(package).requires if "extra == " not in x]
    else:
        requires = Path(file_path).read_text().splitlines()

    requirements = []
    for line in requires:
        line = line.strip()
        if line and not line.startswith("#"):
            line = line.split("#")[0].strip()  # ignore inline comments
            match = re.match(r"([a-zA-Z0-9-_]+)\s*([<>!=~]+.*)?", line)
            if match:
                requirements.append(SimpleNamespace(name=match[1], specifier=match[2].strip() if match[2] else ""))

    return requirements


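Each parsed entry is a SimpleNamespace, so specifiers can be fed straight to check_version. A quick sketch against a hand-written requirements file (the file content is illustrative):

```python
from pathlib import Path

Path("reqs_demo.txt").write_text("torch>=1.8.0  # core\n# comment line\nnumpy\n")
for r in parse_requirements(file_path=Path("reqs_demo.txt")):
    print(r.name, repr(r.specifier))  # -> torch '>=1.8.0', then numpy ''
```
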
def parse_version(version="0.0.0") -> tuple:
    """
    Convert a version string to a tuple of integers, ignoring any extra non-numeric string attached to the version. This
    function replaces deprecated 'pkg_resources.parse_version(v)'.

    Args:
        version (str): Version string, i.e. '2.0.1+cpu'

    Returns:
        (tuple): Tuple of integers representing the numeric part of the version, with any extra string ignored, i.e. (2, 0, 1)
    """
    try:
        return tuple(map(int, re.findall(r"\d+", version)[:3]))  # '2.0.1+cpu' -> (2, 0, 1)
    except Exception as e:
        LOGGER.warning(f"WARNING ⚠️ failure for parse_version({version}), returning (0, 0, 0): {e}")
        return 0, 0, 0


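The regex keeps only the first three integer groups, so build and local suffixes drop out:

```python
assert parse_version("2.0.1+cpu") == (2, 0, 1)
assert parse_version("1.13.0a0+git1234") == (1, 13, 0)
assert parse_version("22.04") == (22, 4)  # note: '04' parses to the integer 4
```
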
def is_ascii(s) -> bool:
@ -33,7 +110,7 @@ def is_ascii(s) -> bool:
        s (str): String to be checked.

    Returns:
        bool: True if the string is composed only of ASCII characters, False otherwise.
        (bool): True if the string is composed only of ASCII characters, False otherwise.
    """
    # Convert list, tuple, None, etc. to string
    s = str(s)
@ -65,16 +142,22 @@ def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0):
        imgsz = [imgsz]
    elif isinstance(imgsz, (list, tuple)):
        imgsz = list(imgsz)
    elif isinstance(imgsz, str):  # i.e. '640' or '[640,640]'
        imgsz = [int(imgsz)] if imgsz.isnumeric() else eval(imgsz)
    else:
        raise TypeError(f"'imgsz={imgsz}' is of invalid type {type(imgsz).__name__}. "
                        f"Valid imgsz types are int i.e. 'imgsz=640' or list i.e. 'imgsz=[640,640]'")
        raise TypeError(
            f"'imgsz={imgsz}' is of invalid type {type(imgsz).__name__}. "
            f"Valid imgsz types are int i.e. 'imgsz=640' or list i.e. 'imgsz=[640,640]'"
        )

    # Apply max_dim
    if len(imgsz) > max_dim:
        msg = "'train' and 'val' imgsz must be an integer, while 'predict' and 'export' imgsz may be a [h, w] list " \
              "or an integer, i.e. 'yolo export imgsz=640,480' or 'yolo export imgsz=640'"
        msg = (
            "'train' and 'val' imgsz must be an integer, while 'predict' and 'export' imgsz may be a [h, w] list "
            "or an integer, i.e. 'yolo export imgsz=640,480' or 'yolo export imgsz=640'"
        )
        if max_dim != 1:
            raise ValueError(f'imgsz={imgsz} is not a valid image size. {msg}')
            raise ValueError(f"imgsz={imgsz} is not a valid image size. {msg}")
        LOGGER.warning(f"WARNING ⚠️ updating to 'imgsz={max(imgsz)}'. {msg}")
        imgsz = [max(imgsz)]
    # Make image size a multiple of the stride
@ -82,7 +165,7 @@ def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0):

    # Print warning message if image size was updated
    if sz != imgsz:
        LOGGER.warning(f'WARNING ⚠️ imgsz={imgsz} must be multiple of max stride {stride}, updating to {sz}')
        LOGGER.warning(f"WARNING ⚠️ imgsz={imgsz} must be multiple of max stride {stride}, updating to {sz}")

    # Add missing dimensions if necessary
    sz = [sz[0], sz[0]] if min_dim == 2 and len(sz) == 1 else sz[0] if min_dim == 1 and len(sz) == 1 else sz
@ -90,66 +173,88 @@ def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0):
    return sz


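The stride rounding elided by the hunk above follows the usual make-divisible arithmetic; a sketch of the idea (this mirrors the standard approach, not the exact elided lines):

```python
import math

stride, floor = 32, 0
imgsz = [638]
sz = [max(math.ceil(x / stride) * stride, floor) for x in imgsz]
print(sz)  # [640] -> 638 is rounded up to the next multiple of 32
```
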
def check_version(current: str = '0.0.0',
                  required: str = '0.0.0',
                  name: str = 'version ',
                  hard: bool = False,
                  verbose: bool = False) -> bool:
def check_version(
    current: str = "0.0.0",
    required: str = "0.0.0",
    name: str = "version",
    hard: bool = False,
    verbose: bool = False,
    msg: str = "",
) -> bool:
    """
    Check current version against the required version or range.

    Args:
        current (str): Current version.
        current (str): Current version or package name to get version from.
        required (str): Required version or range (in pip-style format).
        name (str): Name to be used in warning message.
        hard (bool): If True, raise an AssertionError if the requirement is not met.
        verbose (bool): If True, print warning message if requirement is not met.
        name (str, optional): Name to be used in warning message.
        hard (bool, optional): If True, raise an AssertionError if the requirement is not met.
        verbose (bool, optional): If True, print warning message if requirement is not met.
        msg (str, optional): Extra message to display if verbose.

    Returns:
        (bool): True if requirement is met, False otherwise.

    Example:
        # check if current version is exactly 22.04
        ```python
        # Check if current version is exactly 22.04
        check_version(current='22.04', required='==22.04')

        # check if current version is greater than or equal to 22.04
        # Check if current version is greater than or equal to 22.04
        check_version(current='22.10', required='22.04')  # assumes '>=' inequality if none passed

        # check if current version is less than or equal to 22.04
        # Check if current version is less than or equal to 22.04
        check_version(current='22.04', required='<=22.04')

        # check if current version is between 20.04 (inclusive) and 22.04 (exclusive)
        # Check if current version is between 20.04 (inclusive) and 22.04 (exclusive)
        check_version(current='21.10', required='>20.04,<22.04')
        ```
    """
    current = pkg.parse_version(current)
    constraints = re.findall(r'([<>!=]{1,2}\s*\d+\.\d+)', required) or [f'>={required}']
    if not current:  # if current is '' or None
        LOGGER.warning(f"WARNING ⚠️ invalid check_version({current}, {required}) requested, please check values.")
        return True
    elif not current[0].isdigit():  # current is package name rather than version string, i.e. current='ultralytics'
        try:
            name = current  # assigned package name to 'name' arg
            current = metadata.version(current)  # get version string from package name
        except metadata.PackageNotFoundError as e:
            if hard:
                raise ModuleNotFoundError(emojis(f"WARNING ⚠️ {current} package is required but not installed")) from e
            else:
                return False

    if not required:  # if required is '' or None
        return True

    op = ""
    version = ""
    result = True
    for constraint in constraints:
        op, version = re.match(r'([<>!=]{1,2})\s*(\d+\.\d+)', constraint).groups()
        version = pkg.parse_version(version)
        if op == '==' and current != version:
    c = parse_version(current)  # '1.2.3' -> (1, 2, 3)
    for r in required.strip(",").split(","):
        op, version = re.match(r"([^0-9]*)([\d.]+)", r).groups()  # split '>=22.04' -> ('>=', '22.04')
        v = parse_version(version)  # '1.2.3' -> (1, 2, 3)
        if op == "==" and c != v:
            result = False
        elif op == '!=' and current == version:
        elif op == "!=" and c == v:
            result = False
        elif op == '>=' and not (current >= version):
        elif op in (">=", "") and not (c >= v):  # if no constraint passed assume '>=required'
            result = False
        elif op == '<=' and not (current <= version):
        elif op == "<=" and not (c <= v):
            result = False
        elif op == '>' and not (current > version):
        elif op == ">" and not (c > v):
            result = False
        elif op == '<' and not (current < version):
        elif op == "<" and not (c < v):
            result = False
    if not result:
        warning_message = f'WARNING ⚠️ {name}{required} is required, but {name}{current} is currently installed'
        warning = f"WARNING ⚠️ {name}{op}{version} is required, but {name}=={current} is currently installed {msg}"
        if hard:
            raise ModuleNotFoundError(emojis(warning_message))  # assert version requirements met
            raise ModuleNotFoundError(emojis(warning))  # assert version requirements met
        if verbose:
            LOGGER.warning(warning_message)
            LOGGER.warning(warning)
    return result


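The new constraint parser is just a split on ',' plus a regex per clause; a sketch showing what each clause decomposes into:

```python
import re

for clause in ">20.04,<22.04".strip(",").split(","):
    op, version = re.match(r"([^0-9]*)([\d.]+)", clause).groups()
    print(op, version)  # -> '>' '20.04', then '<' '22.04'
```
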
def check_latest_pypi_version(package_name='ultralytics'):
def check_latest_pypi_version(package_name="ultralytics"):
    """
    Returns the latest version of a PyPI package without downloading or installing it.

@ -161,9 +266,9 @@ def check_latest_pypi_version(package_name='ultralytics'):
    """
    with contextlib.suppress(Exception):
        requests.packages.urllib3.disable_warnings()  # Disable the InsecureRequestWarning
        response = requests.get(f'https://pypi.org/pypi/{package_name}/json', timeout=3)
        response = requests.get(f"https://pypi.org/pypi/{package_name}/json", timeout=3)
        if response.status_code == 200:
            return response.json()['info']['version']
            return response.json()["info"]["version"]


def check_pip_update_available():
@ -176,16 +281,19 @@ def check_pip_update_available():
    if ONLINE and is_pip_package():
        with contextlib.suppress(Exception):
            from ultralytics import __version__

            latest = check_latest_pypi_version()
            if pkg.parse_version(__version__) < pkg.parse_version(latest):  # update is available
                LOGGER.info(f'New https://pypi.org/project/ultralytics/{latest} available 😃 '
                            f"Update with 'pip install -U ultralytics'")
            if check_version(__version__, f"<{latest}"):  # check if current version is < latest version
                LOGGER.info(
                    f"New https://pypi.org/project/ultralytics/{latest} available 😃 "
                    f"Update with 'pip install -U ultralytics'"
                )
                return True
    return False


@ThreadingLocked()
def check_font(font='Arial.ttf'):
def check_font(font="Arial.ttf"):
    """
    Find font locally or download to user's configuration directory if it does not already exist.

@ -208,13 +316,13 @@ def check_font(font='Arial.ttf'):
        return matches[0]

    # Download to USER_CONFIG_DIR if missing
    url = f'https://ultralytics.com/assets/{name}'
    if downloads.is_url(url):
    url = f"https://ultralytics.com/assets/{name}"
    if downloads.is_url(url, check=True):
        downloads.safe_download(url=url, file=file)
        return file


def check_python(minimum: str = '3.8.0') -> bool:
def check_python(minimum: str = "3.8.0") -> bool:
    """
    Check current python version against the required minimum version.

@ -222,13 +330,13 @@ def check_python(minimum: str = '3.8.0') -> bool:
        minimum (str): Required minimum version of python.

    Returns:
        None
        (bool): Whether the installed Python version meets the minimum constraints.
    """
    return check_version(platform.python_version(), minimum, name='Python ', hard=True)
    return check_version(PYTHON_VERSION, minimum, name="Python ", hard=True)


@TryExcept()
def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=(), install=True, cmds=''):
def check_requirements(requirements=ROOT.parent / "requirements.txt", exclude=(), install=True, cmds=""):
    """
    Check if installed dependencies meet YOLOv8 requirements and attempt to auto-update if needed.

@ -253,46 +361,43 @@ def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=()
        check_requirements(['numpy', 'ultralytics>=8.0.0'])
        ```
    """
    prefix = colorstr('red', 'bold', 'requirements:')

    prefix = colorstr("red", "bold", "requirements:")
    check_python()  # check python version
    check_torchvision()  # check torch-torchvision compatibility
    if isinstance(requirements, Path):  # requirements.txt file
        file = requirements.resolve()
        assert file.exists(), f'{prefix} {file} not found, check failed.'
        with file.open() as f:
            requirements = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements(f) if x.name not in exclude]
        assert file.exists(), f"{prefix} {file} not found, check failed."
        requirements = [f"{x.name}{x.specifier}" for x in parse_requirements(file) if x.name not in exclude]
    elif isinstance(requirements, str):
        requirements = [requirements]

    pkgs = []
    for r in requirements:
        r_stripped = r.split('/')[-1].replace('.git', '')  # replace git+https://org/repo.git -> 'repo'
        r_stripped = r.split("/")[-1].replace(".git", "")  # replace git+https://org/repo.git -> 'repo'
        match = re.match(r"([a-zA-Z0-9-_]+)([<>!=~]+.*)?", r_stripped)
        name, required = match[1], match[2].strip() if match[2] else ""
        try:
            pkg.require(r_stripped)  # exception if requirements not met
        except pkg.DistributionNotFound:
            try:  # attempt to import (slower but more accurate)
                import importlib
                importlib.import_module(next(pkg.parse_requirements(r_stripped)).name)
            except ImportError:
                pkgs.append(r)
        except pkg.VersionConflict:
            assert check_version(metadata.version(name), required)  # exception if requirements not met
        except (AssertionError, metadata.PackageNotFoundError):
            pkgs.append(r)

    s = ' '.join(f'"{x}"' for x in pkgs)  # console string
    s = " ".join(f'"{x}"' for x in pkgs)  # console string
    if s:
        if install and AUTOINSTALL:  # check environment variable
            n = len(pkgs)  # number of packages updates
            LOGGER.info(f"{prefix} Ultralytics requirement{'s' * (n > 1)} {pkgs} not found, attempting AutoUpdate...")
            try:
                t = time.time()
                assert is_online(), 'AutoUpdate skipped (offline)'
                LOGGER.info(subprocess.check_output(f'pip install --no-cache {s} {cmds}', shell=True).decode())
                assert is_online(), "AutoUpdate skipped (offline)"
                LOGGER.info(subprocess.check_output(f"pip install --no-cache {s} {cmds}", shell=True).decode())
                dt = time.time() - t
                LOGGER.info(
                    f"{prefix} AutoUpdate success ✅ {dt:.1f}s, installed {n} package{'s' * (n > 1)}: {pkgs}\n"
                    f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n")
                    f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
                )
            except Exception as e:
                LOGGER.warning(f'{prefix} ❌ {e}')
                LOGGER.warning(f"{prefix} ❌ {e}")
                return False
        else:
            return False
@ -305,134 +410,211 @@ def check_torchvision():
    Checks the installed versions of PyTorch and Torchvision to ensure they're compatible.

    This function checks the installed versions of PyTorch and Torchvision, and warns if they're incompatible according
    to the provided compatibility table based on https://github.com/pytorch/vision#installation. The
    compatibility table is a dictionary where the keys are PyTorch versions and the values are lists of compatible
    to the provided compatibility table based on:
    https://github.com/pytorch/vision#installation.

    The compatibility table is a dictionary where the keys are PyTorch versions and the values are lists of compatible
    Torchvision versions.
    """

    import torchvision

    # Compatibility table
    compatibility_table = {'2.0': ['0.15'], '1.13': ['0.14'], '1.12': ['0.13']}
    compatibility_table = {"2.0": ["0.15"], "1.13": ["0.14"], "1.12": ["0.13"]}

    # Extract only the major and minor versions
    v_torch = '.'.join(torch.__version__.split('+')[0].split('.')[:2])
    v_torchvision = '.'.join(torchvision.__version__.split('+')[0].split('.')[:2])
    v_torch = ".".join(torch.__version__.split("+")[0].split(".")[:2])
    v_torchvision = ".".join(torchvision.__version__.split("+")[0].split(".")[:2])

    if v_torch in compatibility_table:
        compatible_versions = compatibility_table[v_torch]
        if all(pkg.parse_version(v_torchvision) != pkg.parse_version(v) for v in compatible_versions):
            print(f'WARNING ⚠️ torchvision=={v_torchvision} is incompatible with torch=={v_torch}.\n'
                  f"Run 'pip install torchvision=={compatible_versions[0]}' to fix torchvision or "
                  "'pip install -U torch torchvision' to update both.\n"
                  'For a full compatibility table see https://github.com/pytorch/vision#installation')
        if all(v_torchvision != v for v in compatible_versions):
            print(
                f"WARNING ⚠️ torchvision=={v_torchvision} is incompatible with torch=={v_torch}.\n"
                f"Run 'pip install torchvision=={compatible_versions[0]}' to fix torchvision or "
                "'pip install -U torch torchvision' to update both.\n"
                "For a full compatibility table see https://github.com/pytorch/vision#installation"
            )


def check_suffix(file='yolov8n.pt', suffix='.pt', msg=''):
def check_suffix(file="yolov8n.pt", suffix=".pt", msg=""):
    """Check file(s) for acceptable suffix."""
    if file and suffix:
        if isinstance(suffix, str):
            suffix = (suffix, )
            suffix = (suffix,)
        for f in file if isinstance(file, (list, tuple)) else [file]:
            s = Path(f).suffix.lower().strip()  # file suffix
            if len(s):
                assert s in suffix, f'{msg}{f} acceptable suffix is {suffix}, not {s}'
                assert s in suffix, f"{msg}{f} acceptable suffix is {suffix}, not {s}"


def check_yolov5u_filename(file: str, verbose: bool = True):
    """Replace legacy YOLOv5 filenames with updated YOLOv5u filenames."""
    if 'yolov3' in file or 'yolov5' in file:
        if 'u.yaml' in file:
            file = file.replace('u.yaml', '.yaml')  # i.e. yolov5nu.yaml -> yolov5n.yaml
        elif '.pt' in file and 'u' not in file:
    if "yolov3" in file or "yolov5" in file:
        if "u.yaml" in file:
            file = file.replace("u.yaml", ".yaml")  # i.e. yolov5nu.yaml -> yolov5n.yaml
        elif ".pt" in file and "u" not in file:
            original_file = file
            file = re.sub(r'(.*yolov5([nsmlx]))\.pt', '\\1u.pt', file)  # i.e. yolov5n.pt -> yolov5nu.pt
            file = re.sub(r'(.*yolov5([nsmlx])6)\.pt', '\\1u.pt', file)  # i.e. yolov5n6.pt -> yolov5n6u.pt
            file = re.sub(r'(.*yolov3(|-tiny|-spp))\.pt', '\\1u.pt', file)  # i.e. yolov3-spp.pt -> yolov3-sppu.pt
            file = re.sub(r"(.*yolov5([nsmlx]))\.pt", "\\1u.pt", file)  # i.e. yolov5n.pt -> yolov5nu.pt
            file = re.sub(r"(.*yolov5([nsmlx])6)\.pt", "\\1u.pt", file)  # i.e. yolov5n6.pt -> yolov5n6u.pt
            file = re.sub(r"(.*yolov3(|-tiny|-spp))\.pt", "\\1u.pt", file)  # i.e. yolov3-spp.pt -> yolov3-sppu.pt
            if file != original_file and verbose:
                LOGGER.info(
                    f"PRO TIP 💡 Replace 'model={original_file}' with new 'model={file}'.\nYOLOv5 'u' models are "
                    f'trained with https://github.com/ultralytics/ultralytics and feature improved performance vs '
                    f'standard YOLOv5 models trained with https://github.com/ultralytics/yolov5.\n')
                    f"trained with https://github.com/ultralytics/ultralytics and feature improved performance vs "
                    f"standard YOLOv5 models trained with https://github.com/ultralytics/yolov5.\n"
                )
    return file


def check_file(file, suffix='', download=True, hard=True):
def check_model_file_from_stem(model="yolov8n"):
    """Return a model filename from a valid model stem."""
    if model and not Path(model).suffix and Path(model).stem in downloads.GITHUB_ASSETS_STEMS:
        return Path(model).with_suffix(".pt")  # add suffix, i.e. yolov8n -> yolov8n.pt
    else:
        return model


def check_file(file, suffix="", download=True, hard=True):
    """Search/download file (if necessary) and return path."""
    check_suffix(file, suffix)  # optional
    file = str(file).strip()  # convert to string and strip spaces
    file = check_yolov5u_filename(file)  # yolov5n -> yolov5nu
    if not file or ('://' not in file and Path(file).exists()):  # exists ('://' check required in Windows Python<3.10)
    if (
        not file
        or ("://" not in file and Path(file).exists())  # '://' check required in Windows Python<3.10
        or file.lower().startswith("grpc://")
    ):  # file exists or gRPC Triton images
        return file
    elif download and file.lower().startswith(('https://', 'http://', 'rtsp://', 'rtmp://')):  # download
    elif download and file.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")):  # download
        url = file  # warning: Pathlib turns :// -> :/
        file = url2file(file)  # '%2F' to '/', split https://url.com/file.txt?auth
        if Path(file).exists():
            LOGGER.info(f'Found {clean_url(url)} locally at {file}')  # file already exists
            LOGGER.info(f"Found {clean_url(url)} locally at {file}")  # file already exists
        else:
            downloads.safe_download(url=url, file=file, unzip=False)
        return file
    else:  # search
        files = glob.glob(str(ROOT / 'cfg' / '**' / file), recursive=True)  # find file
        files = glob.glob(str(ROOT / "**" / file), recursive=True) or glob.glob(str(ROOT.parent / file))  # find file
        if not files and hard:
            raise FileNotFoundError(f"'{file}' does not exist")
        elif len(files) > 1 and hard:
            raise FileNotFoundError(f"Multiple files match '{file}', specify exact path: {files}")
        return files[0] if len(files) else []  # return file
        return files[0] if len(files) else [] if hard else file  # return file


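The three branches of check_file in one sketch (URLs and filenames are placeholders; the download branch needs network access):

```python
print(check_file("coco128.yaml"))            # search branch: resolved from the package tree
print(check_file("grpc://host:8000/model"))  # passthrough branch: gRPC Triton route is returned untouched
# check_file("https://example.com/img.jpg")  # download branch: fetched via downloads.safe_download
```
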
def check_yaml(file, suffix=('.yaml', '.yml'), hard=True):
def check_yaml(file, suffix=(".yaml", ".yml"), hard=True):
    """Search/download YAML file (if necessary) and return path, checking suffix."""
    return check_file(file, suffix, hard=hard)


def check_is_path_safe(basedir, path):
    """
    Check if the resolved path is under the intended directory to prevent path traversal.

    Args:
        basedir (Path | str): The intended directory.
        path (Path | str): The path to check.

    Returns:
        (bool): True if the path is safe, False otherwise.
    """
    base_dir_resolved = Path(basedir).resolve()
    path_resolved = Path(path).resolve()

    return path_resolved.is_file() and path_resolved.parts[: len(base_dir_resolved.parts)] == base_dir_resolved.parts


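The parts-prefix comparison rejects classic traversal attempts; a quick sketch (the directory layout is illustrative):

```python
from pathlib import Path

base = Path("datasets")
(base / "images").mkdir(parents=True, exist_ok=True)
(base / "images" / "a.jpg").touch()

print(check_is_path_safe(base, base / "images" / "a.jpg"))  # True: resolves inside base
print(check_is_path_safe(base, base / ".." / "etc" / "passwd"))  # False: escapes base
```
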
def check_imshow(warn=False):
    """Check if environment supports image displays."""
    try:
        if LINUX:
            assert 'DISPLAY' in os.environ and not is_docker() and not is_colab() and not is_kaggle()
        cv2.imshow('test', np.zeros((8, 8, 3), dtype=np.uint8))  # show a small 8-pixel image
            assert "DISPLAY" in os.environ and not is_docker() and not is_colab() and not is_kaggle()
        cv2.imshow("test", np.zeros((8, 8, 3), dtype=np.uint8))  # show a small 8-pixel image
        cv2.waitKey(1)
        cv2.destroyAllWindows()
        cv2.waitKey(1)
        return True
    except Exception as e:
        if warn:
            LOGGER.warning(f'WARNING ⚠️ Environment does not support cv2.imshow() or PIL Image.show()\n{e}')
            LOGGER.warning(f"WARNING ⚠️ Environment does not support cv2.imshow() or PIL Image.show()\n{e}")
        return False


def check_yolo(verbose=True, device=''):
def check_yolo(verbose=True, device=""):
    """Return a human-readable YOLO software and hardware summary."""
    import psutil

    from ultralytics.utils.torch_utils import select_device

    if is_jupyter():
        if check_requirements('wandb', install=False):
            os.system('pip uninstall -y wandb')  # uninstall wandb: unwanted account creation prompt with infinite hang
        if check_requirements("wandb", install=False):
            os.system("pip uninstall -y wandb")  # uninstall wandb: unwanted account creation prompt with infinite hang
        if is_colab():
            shutil.rmtree('sample_data', ignore_errors=True)  # remove colab /sample_data directory
            shutil.rmtree("sample_data", ignore_errors=True)  # remove colab /sample_data directory

    if verbose:
        # System info
        gib = 1 << 30  # bytes per GiB
        ram = psutil.virtual_memory().total
        total, used, free = shutil.disk_usage('/')
        s = f'({os.cpu_count()} CPUs, {ram / gib:.1f} GB RAM, {(total - free) / gib:.1f}/{total / gib:.1f} GB disk)'
        total, used, free = shutil.disk_usage("/")
        s = f"({os.cpu_count()} CPUs, {ram / gib:.1f} GB RAM, {(total - free) / gib:.1f}/{total / gib:.1f} GB disk)"
        with contextlib.suppress(Exception):  # clear display if ipython is installed
            from IPython import display

            display.clear_output()
    else:
        s = ''
        s = ""

    select_device(device=device, newline=False)
    LOGGER.info(f'Setup complete ✅ {s}')
    LOGGER.info(f"Setup complete ✅ {s}")


def collect_system_info():
    """Collect and print relevant system information including OS, Python, RAM, CPU, and CUDA."""

    import psutil

    from ultralytics.utils import ENVIRONMENT, is_git_dir
    from ultralytics.utils.torch_utils import get_cpu_info

    ram_info = psutil.virtual_memory().total / (1024**3)  # Convert bytes to GB
    check_yolo()
    LOGGER.info(
        f"\n{'OS':<20}{platform.platform()}\n"
        f"{'Environment':<20}{ENVIRONMENT}\n"
        f"{'Python':<20}{PYTHON_VERSION}\n"
        f"{'Install':<20}{'git' if is_git_dir() else 'pip' if is_pip_package() else 'other'}\n"
        f"{'RAM':<20}{ram_info:.2f} GB\n"
        f"{'CPU':<20}{get_cpu_info()}\n"
        f"{'CUDA':<20}{torch.version.cuda if torch and torch.cuda.is_available() else None}\n"
    )

    for r in parse_requirements(package="ultralytics"):
        try:
            current = metadata.version(r.name)
            is_met = "✅ " if check_version(current, str(r.specifier), hard=True) else "❌ "
        except metadata.PackageNotFoundError:
            current = "(not installed)"
            is_met = "❌ "
        LOGGER.info(f"{r.name:<20}{is_met}{current}{r.specifier}")

    if is_github_action_running():
        LOGGER.info(
            f"\nRUNNER_OS: {os.getenv('RUNNER_OS')}\n"
            f"GITHUB_EVENT_NAME: {os.getenv('GITHUB_EVENT_NAME')}\n"
            f"GITHUB_WORKFLOW: {os.getenv('GITHUB_WORKFLOW')}\n"
            f"GITHUB_ACTOR: {os.getenv('GITHUB_ACTOR')}\n"
            f"GITHUB_REPOSITORY: {os.getenv('GITHUB_REPOSITORY')}\n"
            f"GITHUB_REPOSITORY_OWNER: {os.getenv('GITHUB_REPOSITORY_OWNER')}\n"
        )


def check_amp(model):
    """
    This function checks the PyTorch Automatic Mixed Precision (AMP) functionality of a YOLOv8 model.
    If the checks fail, it means there are anomalies with AMP on the system that may cause NaN losses or zero-mAP
    results, so AMP will be disabled during training.
    This function checks the PyTorch Automatic Mixed Precision (AMP) functionality of a YOLOv8 model. If the checks
    fail, it means there are anomalies with AMP on the system that may cause NaN losses or zero-mAP results, so AMP will
    be disabled during training.

    Args:
        model (nn.Module): A YOLOv8 model instance.
@ -450,7 +632,7 @@ def check_amp(model):
        (bool): Returns True if the AMP functionality works correctly with YOLOv8 model, else False.
    """
    device = next(model.parameters()).device  # get model device
    if device.type in ('cpu', 'mps'):
    if device.type in ("cpu", "mps"):
        return False  # AMP only used on CUDA devices

    def amp_allclose(m, im):
@ -461,23 +643,27 @@ def check_amp(model):
        del m
        return a.shape == b.shape and torch.allclose(a, b.float(), atol=0.5)  # close to 0.5 absolute tolerance

    im = ASSETS / 'bus.jpg'  # image to check
    prefix = colorstr('AMP: ')
    LOGGER.info(f'{prefix}running Automatic Mixed Precision (AMP) checks with YOLOv8n...')
    im = ASSETS / "bus.jpg"  # image to check
    prefix = colorstr("AMP: ")
    LOGGER.info(f"{prefix}running Automatic Mixed Precision (AMP) checks with YOLOv8n...")
    warning_msg = "Setting 'amp=True'. If you experience zero-mAP or NaN losses you can disable AMP with amp=False."
    try:
        from ultralytics import YOLO
        assert amp_allclose(YOLO('yolov8n.pt'), im)
        LOGGER.info(f'{prefix}checks passed ✅')

        assert amp_allclose(YOLO("yolov8n.pt"), im)
        LOGGER.info(f"{prefix}checks passed ✅")
    except ConnectionError:
        LOGGER.warning(f'{prefix}checks skipped ⚠️, offline and unable to download YOLOv8n. {warning_msg}')
        LOGGER.warning(f"{prefix}checks skipped ⚠️, offline and unable to download YOLOv8n. {warning_msg}")
    except (AttributeError, ModuleNotFoundError):
        LOGGER.warning(
            f'{prefix}checks skipped ⚠️. Unable to load YOLOv8n due to possible Ultralytics package modifications. {warning_msg}'
            f"{prefix}checks skipped ⚠️. "
            f"Unable to load YOLOv8n due to possible Ultralytics package modifications. {warning_msg}"
        )
    except AssertionError:
        LOGGER.warning(f'{prefix}checks failed ❌. Anomalies were detected with AMP on your system that may lead to '
                       f'NaN losses or zero-mAP results, so AMP will be disabled during training.')
        LOGGER.warning(
            f"{prefix}checks failed ❌. Anomalies were detected with AMP on your system that may lead to "
            f"NaN losses or zero-mAP results, so AMP will be disabled during training."
        )
        return False
    return True


@ -485,8 +671,8 @@ def check_amp(model):
def git_describe(path=ROOT):  # path must be a directory
    """Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe."""
    with contextlib.suppress(Exception):
        return subprocess.check_output(f'git -C {path} describe --tags --long --always', shell=True).decode()[:-1]
    return ''
        return subprocess.check_output(f"git -C {path} describe --tags --long --always", shell=True).decode()[:-1]
    return ""


def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
@ -494,7 +680,7 @@ def print_args(args: Optional[dict] = None, show_file=True, show_func=False):

    def strip_auth(v):
        """Clean longer Ultralytics HUB URLs by stripping potential authentication information."""
        return clean_url(v) if (isinstance(v, str) and v.startswith('http') and len(v) > 100) else v
        return clean_url(v) if (isinstance(v, str) and v.startswith("http") and len(v) > 100) else v

    x = inspect.currentframe().f_back  # previous frame
    file, _, func, _, _ = inspect.getframeinfo(x)
@ -502,26 +688,28 @@ def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
        args, _, _, frm = inspect.getargvalues(x)
        args = {k: v for k, v in frm.items() if k in args}
    try:
        file = Path(file).resolve().relative_to(ROOT).with_suffix('')
        file = Path(file).resolve().relative_to(ROOT).with_suffix("")
    except ValueError:
        file = Path(file).stem
    s = (f'{file}: ' if show_file else '') + (f'{func}: ' if show_func else '')
    LOGGER.info(colorstr(s) + ', '.join(f'{k}={strip_auth(v)}' for k, v in args.items()))
    s = (f"{file}: " if show_file else "") + (f"{func}: " if show_func else "")
    LOGGER.info(colorstr(s) + ", ".join(f"{k}={strip_auth(v)}" for k, v in args.items()))


def cuda_device_count() -> int:
    """Get the number of NVIDIA GPUs available in the environment.
    """
    Get the number of NVIDIA GPUs available in the environment.

    Returns:
        (int): The number of NVIDIA GPUs available.
    """
    try:
        # Run the nvidia-smi command and capture its output
        output = subprocess.check_output(['nvidia-smi', '--query-gpu=count', '--format=csv,noheader,nounits'],
                                         encoding='utf-8')
        output = subprocess.check_output(
            ["nvidia-smi", "--query-gpu=count", "--format=csv,noheader,nounits"], encoding="utf-8"
        )

        # Take the first line and strip any leading/trailing white space
        first_line = output.strip().split('\n')[0]
        first_line = output.strip().split("\n")[0]

        return int(first_line)
    except (subprocess.CalledProcessError, FileNotFoundError, ValueError):
@ -530,9 +718,14 @@ def cuda_device_count() -> int:


def cuda_is_available() -> bool:
    """Check if CUDA is available in the environment.
    """
    Check if CUDA is available in the environment.

    Returns:
        (bool): True if one or more NVIDIA GPUs are available, False otherwise.
    """
    return cuda_device_count() > 0


# Define constants
IS_PYTHON_3_12 = PYTHON_VERSION.startswith("3.12")

@ -1,47 +1,53 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import socket
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from . import USER_CONFIG_DIR
|
||||
from .torch_utils import TORCH_1_9
|
||||
|
||||
|
||||
def find_free_network_port() -> int:
|
||||
"""Finds a free port on localhost.
|
||||
"""
|
||||
Finds a free port on localhost.
|
||||
|
||||
It is useful in single-node training when we don't want to connect to a real main node but have to set the
|
||||
`MASTER_PORT` environment variable.
|
||||
"""
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(('127.0.0.1', 0))
|
||||
s.bind(("127.0.0.1", 0))
|
||||
return s.getsockname()[1] # port
|
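Binding to port 0 asks the OS for any free port; a usage sketch for wiring the result into DDP via `MASTER_PORT`, as the docstring describes:

```python
import os
import socket

with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
    s.bind(("127.0.0.1", 0))  # port 0 -> OS assigns a free port
    port = s.getsockname()[1]
os.environ["MASTER_PORT"] = str(port)
```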
||||
|
||||
|
||||
def generate_ddp_file(trainer):
|
||||
"""Generates a DDP file and returns its file name."""
|
||||
module, name = f'{trainer.__class__.__module__}.{trainer.__class__.__name__}'.rsplit('.', 1)
|
||||
module, name = f"{trainer.__class__.__module__}.{trainer.__class__.__name__}".rsplit(".", 1)
|
||||
|
||||
content = f'''overrides = {vars(trainer.args)} \nif __name__ == "__main__":
|
||||
content = f"""
|
||||
# Ultralytics Multi-GPU training temp file (should be automatically deleted after use)
|
||||
overrides = {vars(trainer.args)}
|
||||
|
||||
if __name__ == "__main__":
|
||||
from {module} import {name}
|
||||
from ultralytics.utils import DEFAULT_CFG_DICT
|
||||
|
||||
cfg = DEFAULT_CFG_DICT.copy()
|
||||
cfg.update(save_dir='') # handle the extra key 'save_dir'
|
||||
trainer = {name}(cfg=cfg, overrides=overrides)
|
||||
trainer.train()'''
|
||||
(USER_CONFIG_DIR / 'DDP').mkdir(exist_ok=True)
|
||||
with tempfile.NamedTemporaryFile(prefix='_temp_',
|
||||
suffix=f'{id(trainer)}.py',
|
||||
mode='w+',
|
||||
encoding='utf-8',
|
||||
dir=USER_CONFIG_DIR / 'DDP',
|
||||
delete=False) as file:
|
||||
results = trainer.train()
|
||||
"""
|
||||
(USER_CONFIG_DIR / "DDP").mkdir(exist_ok=True)
|
||||
with tempfile.NamedTemporaryFile(
|
||||
prefix="_temp_",
|
||||
suffix=f"{id(trainer)}.py",
|
||||
mode="w+",
|
||||
encoding="utf-8",
|
||||
dir=USER_CONFIG_DIR / "DDP",
|
||||
delete=False,
|
||||
) as file:
|
||||
file.write(content)
|
||||
return file.name
|
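For illustration, a generated DDP temp file would look roughly like the following; the trainer class and overrides here are hypothetical examples, not the output of a real run:

```python
# Ultralytics Multi-GPU training temp file (should be automatically deleted after use)
overrides = {'model': 'yolov8n.pt', 'data': 'coco8.yaml', 'epochs': 3}  # illustrative

if __name__ == "__main__":
    from ultralytics.models.yolo.detect import DetectionTrainer  # hypothetical example class
    from ultralytics.utils import DEFAULT_CFG_DICT

    cfg = DEFAULT_CFG_DICT.copy()
    cfg.update(save_dir="")  # handle the extra key 'save_dir'
    trainer = DetectionTrainer(cfg=cfg, overrides=overrides)
    results = trainer.train()
```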
||||
|
||||
@ -49,19 +55,17 @@ def generate_ddp_file(trainer):
|
||||
def generate_ddp_command(world_size, trainer):
|
||||
"""Generates and returns command for distributed training."""
|
||||
import __main__ # noqa local import to avoid https://github.com/Lightning-AI/lightning/issues/15218
|
||||
|
||||
if not trainer.resume:
|
||||
shutil.rmtree(trainer.save_dir) # remove the save_dir
|
||||
file = str(Path(sys.argv[0]).resolve())
|
||||
safe_pattern = re.compile(r'^[a-zA-Z0-9_. /\\-]{1,128}$') # allowed characters and maximum of 128 characters
|
||||
if not (safe_pattern.match(file) and Path(file).exists() and file.endswith('.py')): # using CLI
|
||||
file = generate_ddp_file(trainer)
|
||||
dist_cmd = 'torch.distributed.run' if TORCH_1_9 else 'torch.distributed.launch'
|
||||
file = generate_ddp_file(trainer)
|
||||
dist_cmd = "torch.distributed.run" if TORCH_1_9 else "torch.distributed.launch"
|
||||
port = find_free_network_port()
|
||||
cmd = [sys.executable, '-m', dist_cmd, '--nproc_per_node', f'{world_size}', '--master_port', f'{port}', file]
|
||||
cmd = [sys.executable, "-m", dist_cmd, "--nproc_per_node", f"{world_size}", "--master_port", f"{port}", file]
|
||||
return cmd, file
|
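The returned command is a standard `torch.distributed.run` invocation; a sketch with illustrative values:

```python
import sys

world_size, port, file = 2, 51515, "train.py"  # illustrative values
cmd = [sys.executable, "-m", "torch.distributed.run",
       "--nproc_per_node", f"{world_size}", "--master_port", f"{port}", file]
print(" ".join(cmd))
# e.g. /usr/bin/python -m torch.distributed.run --nproc_per_node 2 --master_port 51515 train.py
```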
||||
|
||||
|
||||
def ddp_cleanup(trainer, file):
|
||||
"""Delete temp file if created."""
|
||||
if f'{id(trainer)}.py' in file: # if temp_file suffix in file
|
||||
if f"{id(trainer)}.py" in file: # if temp_file suffix in file
|
||||
os.remove(file)
|
||||
|
@ -15,20 +15,42 @@ import torch
|
||||
from ultralytics.utils import LOGGER, TQDM, checks, clean_url, emojis, is_online, url2file
|
||||
|
||||
# Define Ultralytics GitHub assets maintained at https://github.com/ultralytics/assets
|
||||
GITHUB_ASSETS_REPO = 'ultralytics/assets'
|
||||
GITHUB_ASSETS_NAMES = [f'yolov8{k}{suffix}.pt' for k in 'nsmlx' for suffix in ('', '6', '-cls', '-seg', '-pose')] + \
|
||||
[f'yolov5{k}u.pt' for k in 'nsmlx'] + \
|
||||
[f'yolov3{k}u.pt' for k in ('', '-spp', '-tiny')] + \
|
||||
[f'yolo_nas_{k}.pt' for k in 'sml'] + \
|
||||
[f'sam_{k}.pt' for k in 'bl'] + \
|
||||
[f'FastSAM-{k}.pt' for k in 'sx'] + \
|
||||
[f'rtdetr-{k}.pt' for k in 'lx'] + \
|
||||
['mobile_sam.pt']
|
||||
GITHUB_ASSETS_REPO = "ultralytics/assets"
|
||||
GITHUB_ASSETS_NAMES = (
|
||||
[f"yolov8{k}{suffix}.pt" for k in "nsmlx" for suffix in ("", "-cls", "-seg", "-pose", "-obb")]
|
||||
+ [f"yolov5{k}{resolution}u.pt" for k in "nsmlx" for resolution in ("", "6")]
|
||||
+ [f"yolov3{k}u.pt" for k in ("", "-spp", "-tiny")]
|
||||
+ [f"yolov8{k}-world.pt" for k in "smlx"]
|
||||
+ [f"yolov8{k}-worldv2.pt" for k in "smlx"]
|
||||
+ [f"yolov9{k}.pt" for k in "ce"]
|
||||
+ [f"yolo_nas_{k}.pt" for k in "sml"]
|
||||
+ [f"sam_{k}.pt" for k in "bl"]
|
||||
+ [f"FastSAM-{k}.pt" for k in "sx"]
|
||||
+ [f"rtdetr-{k}.pt" for k in "lx"]
|
||||
+ ["mobile_sam.pt"]
|
||||
+ ["calibration_image_sample_data_20x128x128x3_float32.npy.zip"]
|
||||
)
|
||||
GITHUB_ASSETS_STEMS = [Path(k).stem for k in GITHUB_ASSETS_NAMES]
|
||||
|
||||
|
||||
def is_url(url, check=True):
|
||||
"""Check if string is URL and check if URL exists."""
|
||||
def is_url(url, check=False):
|
||||
"""
|
||||
Validates if the given string is a URL and optionally checks if the URL exists online.
|
||||
|
||||
Args:
|
||||
url (str): The string to be validated as a URL.
|
||||
check (bool, optional): If True, performs an additional check to see if the URL exists online.
|
||||
Defaults to False.
|
||||
|
||||
Returns:
|
||||
(bool): Returns True for a valid URL. If 'check' is True, also returns True if the URL exists online.
|
||||
Returns False otherwise.
|
||||
|
||||
Example:
|
||||
```python
|
||||
valid = is_url("https://www.example.com")
|
||||
```
|
||||
"""
|
||||
with contextlib.suppress(Exception):
|
||||
url = str(url)
|
||||
result = parse.urlparse(url)
|
||||
@ -40,7 +62,7 @@ def is_url(url, check=True):
|
||||
return False
|
||||
|
||||
|
||||
def delete_dsstore(path, files_to_delete=('.DS_Store', '__MACOSX')):
|
||||
def delete_dsstore(path, files_to_delete=(".DS_Store", "__MACOSX")):
|
||||
"""
|
||||
Deletes all ".DS_store" files under a specified directory.
|
||||
|
||||
@ -59,18 +81,17 @@ def delete_dsstore(path, files_to_delete=('.DS_Store', '__MACOSX')):
|
||||
".DS_store" files are created by the Apple operating system and contain metadata about folders and files. They
|
||||
are hidden system files and can cause issues when transferring files between different operating systems.
|
||||
"""
|
||||
# Delete Apple .DS_store files
|
||||
for file in files_to_delete:
|
||||
matches = list(Path(path).rglob(file))
|
||||
LOGGER.info(f'Deleting {file} files: {matches}')
|
||||
LOGGER.info(f"Deleting {file} files: {matches}")
|
||||
for f in matches:
|
||||
f.unlink()
|
||||
|
||||
|
||||
def zip_directory(directory, compress=True, exclude=('.DS_Store', '__MACOSX'), progress=True):
|
||||
def zip_directory(directory, compress=True, exclude=(".DS_Store", "__MACOSX"), progress=True):
|
||||
"""
|
||||
Zips the contents of a directory, excluding files containing strings in the exclude list.
|
||||
The resulting zip file is named after the directory and placed alongside it.
|
||||
Zips the contents of a directory, excluding files containing strings in the exclude list. The resulting zip file is
|
||||
named after the directory and placed alongside it.
|
||||
|
||||
Args:
|
||||
directory (str | Path): The path to the directory to be zipped.
|
||||
@ -96,17 +117,17 @@ def zip_directory(directory, compress=True, exclude=('.DS_Store', '__MACOSX'), p
|
||||
raise FileNotFoundError(f"Directory '{directory}' does not exist.")
|
||||
|
||||
# Unzip with progress bar
|
||||
files_to_zip = [f for f in directory.rglob('*') if f.is_file() and all(x not in f.name for x in exclude)]
|
||||
zip_file = directory.with_suffix('.zip')
|
||||
files_to_zip = [f for f in directory.rglob("*") if f.is_file() and all(x not in f.name for x in exclude)]
|
||||
zip_file = directory.with_suffix(".zip")
|
||||
compression = ZIP_DEFLATED if compress else ZIP_STORED
|
||||
with ZipFile(zip_file, 'w', compression) as f:
|
||||
for file in TQDM(files_to_zip, desc=f'Zipping {directory} to {zip_file}...', unit='file', disable=not progress):
|
||||
with ZipFile(zip_file, "w", compression) as f:
|
||||
for file in TQDM(files_to_zip, desc=f"Zipping {directory} to {zip_file}...", unit="file", disable=not progress):
|
||||
f.write(file, file.relative_to(directory))
|
||||
|
||||
return zip_file # return path to zip file
|
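A usage sketch, assuming a local `my_dataset/` directory exists:

```python
from ultralytics.utils.downloads import zip_directory

# Creates my_dataset.zip alongside the directory and returns its path
zip_path = zip_directory("my_dataset", compress=True)
```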
||||
|
||||
|
||||
def unzip_file(file, path=None, exclude=('.DS_Store', '__MACOSX'), exist_ok=False, progress=True):
|
||||
def unzip_file(file, path=None, exclude=(".DS_Store", "__MACOSX"), exist_ok=False, progress=True):
|
||||
"""
|
||||
Unzips a *.zip file to the specified path, excluding files containing strings in the exclude list.
|
||||
|
||||
@ -146,51 +167,62 @@ def unzip_file(file, path=None, exclude=('.DS_Store', '__MACOSX'), exist_ok=Fals
|
||||
files = [f for f in zipObj.namelist() if all(x not in f for x in exclude)]
|
||||
top_level_dirs = {Path(f).parts[0] for f in files}
|
||||
|
||||
if len(top_level_dirs) > 1 or not files[0].endswith('/'): # zip has multiple files at top level
|
||||
if len(top_level_dirs) > 1 or (len(files) > 1 and not files[0].endswith("/")):
|
||||
# Zip has multiple files at top level
|
||||
path = extract_path = Path(path) / Path(file).stem # i.e. ../datasets/coco8
|
||||
else: # zip has 1 top-level directory
|
||||
else:
|
||||
# Zip has 1 top-level directory
|
||||
extract_path = path # i.e. ../datasets
|
||||
path = Path(path) / list(top_level_dirs)[0] # i.e. ../datasets/coco8
|
||||
|
||||
# Check if destination directory already exists and contains files
|
||||
if path.exists() and any(path.iterdir()) and not exist_ok:
|
||||
# If it exists and is not empty, return the path without unzipping
|
||||
LOGGER.warning(f'WARNING ⚠️ Skipping {file} unzip as destination directory {path} is not empty.')
|
||||
LOGGER.warning(f"WARNING ⚠️ Skipping {file} unzip as destination directory {path} is not empty.")
|
||||
return path
|
||||
|
||||
for f in TQDM(files, desc=f'Unzipping {file} to {Path(path).resolve()}...', unit='file', disable=not progress):
|
||||
zipObj.extract(f, path=extract_path)
|
||||
for f in TQDM(files, desc=f"Unzipping {file} to {Path(path).resolve()}...", unit="file", disable=not progress):
|
||||
# Ensure the file is within the extract_path to avoid path traversal security vulnerability
|
||||
if ".." in Path(f).parts:
|
||||
LOGGER.warning(f"Potentially insecure file path: {f}, skipping extraction.")
|
||||
continue
|
||||
zipObj.extract(f, extract_path)
|
||||
|
||||
return path # return unzip dir
|
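The `..` guard above blocks zip-slip style archive entries that would escape the extraction directory; a minimal illustration of the check:

```python
from pathlib import Path

for name in ("images/train/1.jpg", "../../etc/passwd"):
    safe = ".." not in Path(name).parts
    print(name, "->", "extract" if safe else "skip")
# images/train/1.jpg -> extract
# ../../etc/passwd -> skip
```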
||||
|
||||
|
||||
def check_disk_space(url='https://ultralytics.com/assets/coco128.zip', sf=1.5, hard=True):
|
||||
def check_disk_space(url="https://ultralytics.com/assets/coco128.zip", path=Path.cwd(), sf=1.5, hard=True):
|
||||
"""
|
||||
Check if there is sufficient disk space to download and store a file.
|
||||
|
||||
Args:
|
||||
url (str, optional): The URL to the file. Defaults to 'https://ultralytics.com/assets/coco128.zip'.
|
||||
path (str | Path, optional): The path or drive to check the available free space on.
|
||||
sf (float, optional): Safety factor, the multiplier for the required free space. Defaults to 1.5.
|
||||
hard (bool, optional): Whether to throw an error or not on insufficient disk space. Defaults to True.
|
||||
|
||||
Returns:
|
||||
(bool): True if there is sufficient disk space, False otherwise.
|
||||
"""
|
||||
r = requests.head(url) # response
|
||||
|
||||
# Check response
|
||||
assert r.status_code < 400, f'URL error for {url}: {r.status_code} {r.reason}'
|
||||
try:
|
||||
r = requests.head(url) # response
|
||||
assert r.status_code < 400, f"URL error for {url}: {r.status_code} {r.reason}" # check response
|
||||
except Exception:
|
||||
return True # requests issue, default to True
|
||||
|
||||
# Check file size
|
||||
gib = 1 << 30 # bytes per GiB
|
||||
data = int(r.headers.get('Content-Length', 0)) / gib # file size (GB)
|
||||
total, used, free = (x / gib for x in shutil.disk_usage('/')) # bytes
|
||||
data = int(r.headers.get("Content-Length", 0)) / gib # file size (GB)
|
||||
total, used, free = (x / gib for x in shutil.disk_usage(path)) # bytes
|
||||
|
||||
if data * sf < free:
|
||||
return True # sufficient space
|
||||
|
||||
# Insufficient space
|
||||
text = (f'WARNING ⚠️ Insufficient free disk space {free:.1f} GB < {data * sf:.3f} GB required, '
|
||||
f'Please free {data * sf - free:.1f} GB additional disk space and try again.')
|
||||
text = (
|
||||
f"WARNING ⚠️ Insufficient free disk space {free:.1f} GB < {data * sf:.3f} GB required, "
|
||||
f"Please free {data * sf - free:.1f} GB additional disk space and try again."
|
||||
)
|
||||
if hard:
|
||||
raise MemoryError(text)
|
||||
LOGGER.warning(text)
|
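A worked example of the safety-factor arithmetic: a ~0.24 GiB download with `sf=1.5` needs ~0.36 GiB free (numbers illustrative):

```python
gib = 1 << 30                  # bytes per GiB
content_length = 257_000_000   # bytes, as reported by the HEAD response
data = content_length / gib    # ~0.24 GiB
required = data * 1.5          # safety factor sf=1.5
print(f"need {required:.2f} GiB free")  # need 0.36 GiB free
```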
||||
@ -216,35 +248,41 @@ def get_google_drive_file_info(link):
|
||||
url, filename = get_google_drive_file_info(link)
|
||||
```
|
||||
"""
|
||||
file_id = link.split('/d/')[1].split('/view')[0]
|
||||
drive_url = f'https://drive.google.com/uc?export=download&id={file_id}'
|
||||
file_id = link.split("/d/")[1].split("/view")[0]
|
||||
drive_url = f"https://drive.google.com/uc?export=download&id={file_id}"
|
||||
filename = None
|
||||
|
||||
# Start session
|
||||
with requests.Session() as session:
|
||||
response = session.get(drive_url, stream=True)
|
||||
if 'quota exceeded' in str(response.content.lower()):
|
||||
if "quota exceeded" in str(response.content.lower()):
|
||||
raise ConnectionError(
|
||||
emojis(f'❌ Google Drive file download quota exceeded. '
|
||||
f'Please try again later or download this file manually at {link}.'))
|
||||
emojis(
|
||||
f"❌ Google Drive file download quota exceeded. "
|
||||
f"Please try again later or download this file manually at {link}."
|
||||
)
|
||||
)
|
||||
for k, v in response.cookies.items():
|
||||
if k.startswith('download_warning'):
|
||||
drive_url += f'&confirm={v}' # v is token
|
||||
cd = response.headers.get('content-disposition')
|
||||
if k.startswith("download_warning"):
|
||||
drive_url += f"&confirm={v}" # v is token
|
||||
cd = response.headers.get("content-disposition")
|
||||
if cd:
|
||||
filename = re.findall('filename="(.+)"', cd)[0]
|
||||
return drive_url, filename
|
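Extracting the file id from a share link is plain string splitting; a sketch with a made-up id:

```python
link = "https://drive.google.com/file/d/1aBcD/view?usp=sharing"  # hypothetical link
file_id = link.split("/d/")[1].split("/view")[0]                 # -> '1aBcD'
url = f"https://drive.google.com/uc?export=download&id={file_id}"
```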
||||
|
||||
|
||||
def safe_download(url,
|
||||
file=None,
|
||||
dir=None,
|
||||
unzip=True,
|
||||
delete=False,
|
||||
curl=False,
|
||||
retry=3,
|
||||
min_bytes=1E0,
|
||||
progress=True):
|
||||
def safe_download(
|
||||
url,
|
||||
file=None,
|
||||
dir=None,
|
||||
unzip=True,
|
||||
delete=False,
|
||||
curl=False,
|
||||
retry=3,
|
||||
min_bytes=1e0,
|
||||
exist_ok=False,
|
||||
progress=True,
|
||||
):
|
||||
"""
|
||||
Downloads files from a URL, with options for retrying, unzipping, and deleting the downloaded file.
|
||||
|
||||
@ -260,41 +298,49 @@ def safe_download(url,
|
||||
retry (int, optional): The number of times to retry the download in case of failure. Default: 3.
|
||||
min_bytes (float, optional): The minimum number of bytes that the downloaded file should have, to be considered
|
||||
a successful download. Default: 1E0.
|
||||
exist_ok (bool, optional): Whether to overwrite existing contents during unzipping. Defaults to False.
|
||||
progress (bool, optional): Whether to display a progress bar during the download. Default: True.
|
||||
"""
|
||||
|
||||
# Check if the URL is a Google Drive link
|
||||
gdrive = url.startswith('https://drive.google.com/')
|
||||
Example:
|
||||
```python
|
||||
from ultralytics.utils.downloads import safe_download
|
||||
|
||||
link = "https://ultralytics.com/assets/bus.jpg"
|
||||
path = safe_download(link)
|
||||
```
|
||||
"""
|
||||
gdrive = url.startswith("https://drive.google.com/") # check if the URL is a Google Drive link
|
||||
if gdrive:
|
||||
url, file = get_google_drive_file_info(url)
|
||||
|
||||
f = dir / (file if gdrive else url2file(url)) if dir else Path(file) # URL converted to filename
|
||||
if '://' not in str(url) and Path(url).is_file(): # URL exists ('://' check required in Windows Python<3.10)
|
||||
f = Path(dir or ".") / (file or url2file(url)) # URL converted to filename
|
||||
if "://" not in str(url) and Path(url).is_file(): # URL exists ('://' check required in Windows Python<3.10)
|
||||
f = Path(url) # filename
|
||||
elif not f.is_file(): # URL and file do not exist
|
||||
assert dir or file, 'dir or file required for download'
|
||||
desc = f"Downloading {url if gdrive else clean_url(url)} to '{f}'"
|
||||
LOGGER.info(f'{desc}...')
|
||||
LOGGER.info(f"{desc}...")
|
||||
f.parent.mkdir(parents=True, exist_ok=True) # make directory if missing
|
||||
check_disk_space(url)
|
||||
check_disk_space(url, path=f.parent)
|
||||
for i in range(retry + 1):
|
||||
try:
|
||||
if curl or i > 0: # curl download with retry, continue
|
||||
s = 'sS' * (not progress) # silent
|
||||
r = subprocess.run(['curl', '-#', f'-{s}L', url, '-o', f, '--retry', '3', '-C', '-']).returncode
|
||||
assert r == 0, f'Curl return value {r}'
|
||||
s = "sS" * (not progress) # silent
|
||||
r = subprocess.run(["curl", "-#", f"-{s}L", url, "-o", f, "--retry", "3", "-C", "-"]).returncode
|
||||
assert r == 0, f"Curl return value {r}"
|
||||
else: # urllib download
|
||||
method = 'torch'
|
||||
if method == 'torch':
|
||||
method = "torch"
|
||||
if method == "torch":
|
||||
torch.hub.download_url_to_file(url, f, progress=progress)
|
||||
else:
|
||||
with request.urlopen(url) as response, TQDM(total=int(response.getheader('Content-Length', 0)),
|
||||
desc=desc,
|
||||
disable=not progress,
|
||||
unit='B',
|
||||
unit_scale=True,
|
||||
unit_divisor=1024) as pbar:
|
||||
with open(f, 'wb') as f_opened:
|
||||
with request.urlopen(url) as response, TQDM(
|
||||
total=int(response.getheader("Content-Length", 0)),
|
||||
desc=desc,
|
||||
disable=not progress,
|
||||
unit="B",
|
||||
unit_scale=True,
|
||||
unit_divisor=1024,
|
||||
) as pbar:
|
||||
with open(f, "wb") as f_opened:
|
||||
for data in response:
|
||||
f_opened.write(data)
|
||||
pbar.update(len(data))
|
||||
@ -305,88 +351,150 @@ def safe_download(url,
|
||||
f.unlink() # remove partial downloads
|
||||
except Exception as e:
|
||||
if i == 0 and not is_online():
|
||||
raise ConnectionError(emojis(f'❌ Download failure for {url}. Environment is not online.')) from e
|
||||
raise ConnectionError(emojis(f"❌ Download failure for {url}. Environment is not online.")) from e
|
||||
elif i >= retry:
|
||||
raise ConnectionError(emojis(f'❌ Download failure for {url}. Retry limit reached.')) from e
|
||||
LOGGER.warning(f'⚠️ Download failure, retrying {i + 1}/{retry} {url}...')
|
||||
raise ConnectionError(emojis(f"❌ Download failure for {url}. Retry limit reached.")) from e
|
||||
LOGGER.warning(f"⚠️ Download failure, retrying {i + 1}/{retry} {url}...")
|
||||
|
||||
if unzip and f.exists() and f.suffix in ('', '.zip', '.tar', '.gz'):
|
||||
if unzip and f.exists() and f.suffix in ("", ".zip", ".tar", ".gz"):
|
||||
from zipfile import is_zipfile
|
||||
|
||||
unzip_dir = dir or f.parent # unzip to dir if provided else unzip in place
|
||||
unzip_dir = (dir or f.parent).resolve() # unzip to dir if provided else unzip in place
|
||||
if is_zipfile(f):
|
||||
unzip_dir = unzip_file(file=f, path=unzip_dir, progress=progress) # unzip
|
||||
elif f.suffix in ('.tar', '.gz'):
|
||||
LOGGER.info(f'Unzipping {f} to {unzip_dir.resolve()}...')
|
||||
subprocess.run(['tar', 'xf' if f.suffix == '.tar' else 'xfz', f, '--directory', unzip_dir], check=True)
|
||||
unzip_dir = unzip_file(file=f, path=unzip_dir, exist_ok=exist_ok, progress=progress) # unzip
|
||||
elif f.suffix in (".tar", ".gz"):
|
||||
LOGGER.info(f"Unzipping {f} to {unzip_dir}...")
|
||||
subprocess.run(["tar", "xf" if f.suffix == ".tar" else "xfz", f, "--directory", unzip_dir], check=True)
|
||||
if delete:
|
||||
f.unlink() # remove zip
|
||||
return unzip_dir
|
||||
|
||||
|
||||
def get_github_assets(repo='ultralytics/assets', version='latest', retry=False):
|
||||
"""Return GitHub repo tag and assets (i.e. ['yolov8n.pt', 'yolov8s.pt', ...])."""
|
||||
if version != 'latest':
|
||||
version = f'tags/{version}' # i.e. tags/v6.2
|
||||
url = f'https://api.github.com/repos/{repo}/releases/{version}'
|
||||
def get_github_assets(repo="ultralytics/assets", version="latest", retry=False):
|
||||
"""
|
||||
Retrieve the specified version's tag and assets from a GitHub repository. If the version is not specified, the
|
||||
function fetches the latest release assets.
|
||||
|
||||
Args:
|
||||
repo (str, optional): The GitHub repository in the format 'owner/repo'. Defaults to 'ultralytics/assets'.
|
||||
version (str, optional): The release version to fetch assets from. Defaults to 'latest'.
|
||||
retry (bool, optional): Flag to retry the request in case of a failure. Defaults to False.
|
||||
|
||||
Returns:
|
||||
(tuple): A tuple containing the release tag and a list of asset names.
|
||||
|
||||
Example:
|
||||
```python
|
||||
tag, assets = get_github_assets(repo='ultralytics/assets', version='latest')
|
||||
```
|
||||
"""
|
||||
|
||||
if version != "latest":
|
||||
version = f"tags/{version}" # i.e. tags/v6.2
|
||||
url = f"https://api.github.com/repos/{repo}/releases/{version}"
|
||||
r = requests.get(url) # github api
|
||||
if r.status_code != 200 and r.reason != 'rate limit exceeded' and retry: # failed and not 403 rate limit exceeded
|
||||
if r.status_code != 200 and r.reason != "rate limit exceeded" and retry: # failed and not 403 rate limit exceeded
|
||||
r = requests.get(url) # try again
|
||||
if r.status_code != 200:
|
||||
LOGGER.warning(f'⚠️ GitHub assets check failure for {url}: {r.status_code} {r.reason}')
|
||||
return '', []
|
||||
LOGGER.warning(f"⚠️ GitHub assets check failure for {url}: {r.status_code} {r.reason}")
|
||||
return "", []
|
||||
data = r.json()
|
||||
return data['tag_name'], [x['name'] for x in data['assets']] # tag, assets
|
||||
return data["tag_name"], [x["name"] for x in data["assets"]] # tag, assets i.e. ['yolov8n.pt', 'yolov8s.pt', ...]
|
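A hedged sketch of the same GitHub releases API call (requires network access; field names per the public API):

```python
import requests

r = requests.get("https://api.github.com/repos/ultralytics/assets/releases/latest")
if r.status_code == 200:
    data = r.json()
    print(data["tag_name"], [x["name"] for x in data["assets"]][:3])
```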
||||
|
||||
|
||||
def attempt_download_asset(file, repo='ultralytics/assets', release='v0.0.0'):
|
||||
"""Attempt file download from GitHub release assets if not found locally. release = 'latest', 'v6.2', etc."""
|
||||
def attempt_download_asset(file, repo="ultralytics/assets", release="v8.1.0", **kwargs):
|
||||
"""
|
||||
Attempt to download a file from GitHub release assets if it is not found locally. The function checks for the file
|
||||
locally first, then tries to download it from the specified GitHub repository release.
|
||||
|
||||
Args:
|
||||
file (str | Path): The filename or file path to be downloaded.
|
||||
repo (str, optional): The GitHub repository in the format 'owner/repo'. Defaults to 'ultralytics/assets'.
|
||||
release (str, optional): The specific release version to be downloaded. Defaults to 'v8.1.0'.
|
||||
**kwargs (any): Additional keyword arguments for the download process.
|
||||
|
||||
Returns:
|
||||
(str): The path to the downloaded file.
|
||||
|
||||
Example:
|
||||
```python
|
||||
file_path = attempt_download_asset('yolov5s.pt', repo='ultralytics/assets', release='latest')
|
||||
```
|
||||
"""
|
||||
from ultralytics.utils import SETTINGS # scoped for circular import
|
||||
|
||||
# YOLOv3/5u updates
|
||||
file = str(file)
|
||||
file = checks.check_yolov5u_filename(file)
|
||||
file = Path(file.strip().replace("'", ''))
|
||||
file = Path(file.strip().replace("'", ""))
|
||||
if file.exists():
|
||||
return str(file)
|
||||
elif (SETTINGS['weights_dir'] / file).exists():
|
||||
return str(SETTINGS['weights_dir'] / file)
|
||||
elif (SETTINGS["weights_dir"] / file).exists():
|
||||
return str(SETTINGS["weights_dir"] / file)
|
||||
else:
|
||||
# URL specified
|
||||
name = Path(parse.unquote(str(file))).name # decode '%2F' to '/' etc.
|
||||
if str(file).startswith(('http:/', 'https:/')): # download
|
||||
url = str(file).replace(':/', '://') # Pathlib turns :// -> :/
|
||||
download_url = f"https://github.com/{repo}/releases/download"
|
||||
if str(file).startswith(("http:/", "https:/")): # download
|
||||
url = str(file).replace(":/", "://") # Pathlib turns :// -> :/
|
||||
file = url2file(name) # parse authentication https://url.com/file.txt?auth...
|
||||
if Path(file).is_file():
|
||||
LOGGER.info(f'Found {clean_url(url)} locally at {file}') # file already exists
|
||||
LOGGER.info(f"Found {clean_url(url)} locally at {file}") # file already exists
|
||||
else:
|
||||
safe_download(url=url, file=file, min_bytes=1E5)
|
||||
safe_download(url=url, file=file, min_bytes=1e5, **kwargs)
|
||||
|
||||
elif repo == GITHUB_ASSETS_REPO and name in GITHUB_ASSETS_NAMES:
|
||||
safe_download(url=f'https://github.com/{repo}/releases/download/{release}/{name}', file=file, min_bytes=1E5)
|
||||
safe_download(url=f"{download_url}/{release}/{name}", file=file, min_bytes=1e5, **kwargs)
|
||||
|
||||
else:
|
||||
tag, assets = get_github_assets(repo, release)
|
||||
if not assets:
|
||||
tag, assets = get_github_assets(repo) # latest release
|
||||
if name in assets:
|
||||
safe_download(url=f'https://github.com/{repo}/releases/download/{tag}/{name}', file=file, min_bytes=1E5)
|
||||
safe_download(url=f"{download_url}/{tag}/{name}", file=file, min_bytes=1e5, **kwargs)
|
||||
|
||||
return str(file)
|
||||
|
||||
|
||||
def download(url, dir=Path.cwd(), unzip=True, delete=False, curl=False, threads=1, retry=3):
|
||||
"""Downloads and unzips files concurrently if threads > 1, else sequentially."""
|
||||
def download(url, dir=Path.cwd(), unzip=True, delete=False, curl=False, threads=1, retry=3, exist_ok=False):
|
||||
"""
|
||||
Downloads files from specified URLs to a given directory. Supports concurrent downloads if multiple threads are
|
||||
specified.
|
||||
|
||||
Args:
|
||||
url (str | list): The URL or list of URLs of the files to be downloaded.
|
||||
dir (Path, optional): The directory where the files will be saved. Defaults to the current working directory.
|
||||
unzip (bool, optional): Flag to unzip the files after downloading. Defaults to True.
|
||||
delete (bool, optional): Flag to delete the zip files after extraction. Defaults to False.
|
||||
curl (bool, optional): Flag to use curl for downloading. Defaults to False.
|
||||
threads (int, optional): Number of threads to use for concurrent downloads. Defaults to 1.
|
||||
retry (int, optional): Number of retries in case of download failure. Defaults to 3.
|
||||
exist_ok (bool, optional): Whether to overwrite existing contents during unzipping. Defaults to False.
|
||||
|
||||
Example:
|
||||
```python
|
||||
download('https://ultralytics.com/assets/example.zip', dir='path/to/dir', unzip=True)
|
||||
```
|
||||
"""
|
||||
dir = Path(dir)
|
||||
dir.mkdir(parents=True, exist_ok=True) # make directory
|
||||
if threads > 1:
|
||||
with ThreadPool(threads) as pool:
|
||||
pool.map(
|
||||
lambda x: safe_download(
|
||||
url=x[0], dir=x[1], unzip=unzip, delete=delete, curl=curl, retry=retry, progress=threads <= 1),
|
||||
zip(url, repeat(dir)))
|
||||
url=x[0],
|
||||
dir=x[1],
|
||||
unzip=unzip,
|
||||
delete=delete,
|
||||
curl=curl,
|
||||
retry=retry,
|
||||
exist_ok=exist_ok,
|
||||
progress=threads <= 1,
|
||||
),
|
||||
zip(url, repeat(dir)),
|
||||
)
|
||||
pool.close()
|
||||
pool.join()
|
||||
else:
|
||||
for u in [url] if isinstance(url, (str, Path)) else url:
|
||||
safe_download(url=u, dir=dir, unzip=unzip, delete=delete, curl=curl, retry=retry)
|
||||
safe_download(url=u, dir=dir, unzip=unzip, delete=delete, curl=curl, retry=retry, exist_ok=exist_ok)
|
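A usage sketch of threaded downloads; the asset URLs follow the release pattern above and are illustrative:

```python
from ultralytics.utils.downloads import download

urls = [
    f"https://github.com/ultralytics/assets/releases/download/v8.1.0/yolov8{k}.pt"
    for k in "nsm"
]
download(urls, dir="weights", threads=3)  # concurrent; per-file progress bars disabled
```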
||||
|
@ -4,7 +4,19 @@ from ultralytics.utils import emojis
|
||||
|
||||
|
||||
class HUBModelError(Exception):
|
||||
"""
|
||||
Custom exception class for handling errors related to model fetching in Ultralytics YOLO.
|
||||
|
||||
def __init__(self, message='Model not found. Please check model URL and try again.'):
|
||||
This exception is raised when a requested model is not found or cannot be retrieved.
|
||||
The message is also processed to include emojis for better user experience.
|
||||
|
||||
Attributes:
|
||||
message (str): The error message displayed when the exception is raised.
|
||||
|
||||
Note:
|
||||
The message is automatically processed through the 'emojis' function from the 'ultralytics.utils' package.
|
||||
"""
|
||||
|
||||
def __init__(self, message="Model not found. Please check model URL and try again."):
|
||||
"""Create an exception for when a model is not found."""
|
||||
super().__init__(emojis(message))
|
||||
|
@ -30,9 +30,9 @@ class WorkingDirectory(contextlib.ContextDecorator):
|
||||
@contextmanager
|
||||
def spaces_in_path(path):
|
||||
"""
|
||||
Context manager to handle paths with spaces in their names.
|
||||
If a path contains spaces, it replaces them with underscores, copies the file/directory to the new path,
|
||||
executes the context code block, then copies the file/directory back to its original location.
|
||||
Context manager to handle paths with spaces in their names. If a path contains spaces, it replaces them with
|
||||
underscores, copies the file/directory to the new path, executes the context code block, then copies the
|
||||
file/directory back to its original location.
|
||||
|
||||
Args:
|
||||
path (str | Path): The original path.
|
||||
@ -45,18 +45,18 @@ def spaces_in_path(path):
|
||||
with ultralytics.utils.files import spaces_in_path
|
||||
|
||||
with spaces_in_path('/path/with spaces') as new_path:
|
||||
# your code here
|
||||
# Your code here
|
||||
```
|
||||
"""
|
||||
|
||||
# If path has spaces, replace them with underscores
|
||||
if ' ' in str(path):
|
||||
if " " in str(path):
|
||||
string = isinstance(path, str) # input type
|
||||
path = Path(path)
|
||||
|
||||
# Create a temporary directory and construct the new path
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_path = Path(tmp_dir) / path.name.replace(' ', '_')
|
||||
tmp_path = Path(tmp_dir) / path.name.replace(" ", "_")
|
||||
|
||||
# Copy file/directory
|
||||
if path.is_dir():
|
||||
@ -82,7 +82,7 @@ def spaces_in_path(path):
|
||||
yield path
|
||||
|
||||
|
||||
def increment_path(path, exist_ok=False, sep='', mkdir=False):
|
||||
def increment_path(path, exist_ok=False, sep="", mkdir=False):
|
||||
"""
|
||||
Increments a file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
|
||||
|
||||
@ -102,12 +102,12 @@ def increment_path(path, exist_ok=False, sep='', mkdir=False):
|
||||
"""
|
||||
path = Path(path) # os-agnostic
|
||||
if path.exists() and not exist_ok:
|
||||
path, suffix = (path.with_suffix(''), path.suffix) if path.is_file() else (path, '')
|
||||
path, suffix = (path.with_suffix(""), path.suffix) if path.is_file() else (path, "")
|
||||
|
||||
# Method 1
|
||||
for n in range(2, 9999):
|
||||
p = f'{path}{sep}{n}{suffix}' # increment path
|
||||
if not os.path.exists(p): #
|
||||
p = f"{path}{sep}{n}{suffix}" # increment path
|
||||
if not os.path.exists(p):
|
||||
break
|
||||
path = Path(p)
|
||||
|
||||
@ -119,14 +119,14 @@ def increment_path(path, exist_ok=False, sep='', mkdir=False):
|
||||
|
||||
def file_age(path=__file__):
|
||||
"""Return days since last file update."""
|
||||
dt = (datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime)) # delta
|
||||
dt = datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime) # delta
|
||||
return dt.days # + dt.seconds / 86400 # fractional days
|
||||
|
||||
|
||||
def file_date(path=__file__):
|
||||
"""Return human-readable file modification date, i.e. '2021-3-26'."""
|
||||
t = datetime.fromtimestamp(Path(path).stat().st_mtime)
|
||||
return f'{t.year}-{t.month}-{t.day}'
|
||||
return f"{t.year}-{t.month}-{t.day}"
|
||||
|
||||
|
||||
def file_size(path):
|
||||
@ -137,11 +137,52 @@ def file_size(path):
|
||||
if path.is_file():
|
||||
return path.stat().st_size / mb
|
||||
elif path.is_dir():
|
||||
return sum(f.stat().st_size for f in path.glob('**/*') if f.is_file()) / mb
|
||||
return sum(f.stat().st_size for f in path.glob("**/*") if f.is_file()) / mb
|
||||
return 0.0
|
||||
|
||||
|
||||
def get_latest_run(search_dir='.'):
|
||||
def get_latest_run(search_dir="."):
|
||||
"""Return path to most recent 'last.pt' in /runs (i.e. to --resume from)."""
|
||||
last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
|
||||
return max(last_list, key=os.path.getctime) if last_list else ''
|
||||
last_list = glob.glob(f"{search_dir}/**/last*.pt", recursive=True)
|
||||
return max(last_list, key=os.path.getctime) if last_list else ""
|
||||
|
||||
|
||||
def update_models(model_names=("yolov8n.pt",), source_dir=Path("."), update_names=False):
|
||||
"""
|
||||
Updates and re-saves specified YOLO models in an 'updated_models' subdirectory.
|
||||
|
||||
Args:
|
||||
model_names (tuple, optional): Model filenames to update, defaults to ("yolov8n.pt",).
|
||||
source_dir (Path, optional): Directory containing models and target subdirectory, defaults to current directory.
|
||||
update_names (bool, optional): Update model names from a data YAML.
|
||||
|
||||
Example:
|
||||
```python
|
||||
from ultralytics.utils.files import update_models
|
||||
|
||||
model_names = (f"rtdetr-{size}.pt" for size in "lx")
|
||||
update_models(model_names)
|
||||
```
|
||||
"""
|
||||
from ultralytics import YOLO
|
||||
from ultralytics.nn.autobackend import default_class_names
|
||||
|
||||
target_dir = source_dir / "updated_models"
|
||||
target_dir.mkdir(parents=True, exist_ok=True) # Ensure target directory exists
|
||||
|
||||
for model_name in model_names:
|
||||
model_path = source_dir / model_name
|
||||
print(f"Loading model from {model_path}")
|
||||
|
||||
# Load model
|
||||
model = YOLO(model_path)
|
||||
model.half()
|
||||
if update_names: # update model names from a dataset YAML
|
||||
model.model.names = default_class_names("coco8.yaml")
|
||||
|
||||
# Define new save path
|
||||
save_path = target_dir / model_name
|
||||
|
||||
# Save model using model.save()
|
||||
print(f"Re-saving {model_name} model to {save_path}")
|
||||
model.save(save_path, use_dill=False)
|
||||
|
@ -7,7 +7,7 @@ from typing import List
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .ops import ltwh2xywh, ltwh2xyxy, resample_segments, xywh2ltwh, xywh2xyxy, xyxy2ltwh, xyxy2xywh
|
||||
from .ops import ltwh2xywh, ltwh2xyxy, xywh2ltwh, xywh2xyxy, xyxy2ltwh, xyxy2xywh
|
||||
|
||||
|
||||
def _ntuple(n):
|
||||
@ -26,16 +26,29 @@ to_4tuple = _ntuple(4)
|
||||
# `xyxy` means left top and right bottom
|
||||
# `xywh` means center x, center y and width, height(YOLO format)
|
||||
# `ltwh` means left top and width, height(COCO format)
|
||||
_formats = ['xyxy', 'xywh', 'ltwh']
|
||||
_formats = ["xyxy", "xywh", "ltwh"]
|
||||
|
||||
__all__ = 'Bboxes', # tuple or list
|
||||
__all__ = ("Bboxes",) # tuple or list
|
||||
|
||||
|
||||
class Bboxes:
|
||||
"""Bounding Boxes class. Only numpy variables are supported."""
|
||||
"""
|
||||
A class for handling bounding boxes.
|
||||
|
||||
def __init__(self, bboxes, format='xyxy') -> None:
|
||||
assert format in _formats, f'Invalid bounding box format: {format}, format must be one of {_formats}'
|
||||
The class supports various bounding box formats like 'xyxy', 'xywh', and 'ltwh'.
|
||||
Bounding box data should be provided in numpy arrays.
|
||||
|
||||
Attributes:
|
||||
bboxes (numpy.ndarray): The bounding boxes stored in a 2D numpy array.
|
||||
format (str): The format of the bounding boxes ('xyxy', 'xywh', or 'ltwh').
|
||||
|
||||
Note:
|
||||
This class does not handle normalization or denormalization of bounding boxes.
|
||||
"""
|
||||
|
||||
def __init__(self, bboxes, format="xyxy") -> None:
|
||||
"""Initializes the Bboxes class with bounding box data in a specified format."""
|
||||
assert format in _formats, f"Invalid bounding box format: {format}, format must be one of {_formats}"
|
||||
bboxes = bboxes[None, :] if bboxes.ndim == 1 else bboxes
|
||||
assert bboxes.ndim == 2
|
||||
assert bboxes.shape[1] == 4
|
||||
@ -45,21 +58,21 @@ class Bboxes:
|
||||
|
||||
def convert(self, format):
|
||||
"""Converts bounding box format from one type to another."""
|
||||
assert format in _formats, f'Invalid bounding box format: {format}, format must be one of {_formats}'
|
||||
assert format in _formats, f"Invalid bounding box format: {format}, format must be one of {_formats}"
|
||||
if self.format == format:
|
||||
return
|
||||
elif self.format == 'xyxy':
|
||||
func = xyxy2xywh if format == 'xywh' else xyxy2ltwh
|
||||
elif self.format == 'xywh':
|
||||
func = xywh2xyxy if format == 'xyxy' else xywh2ltwh
|
||||
elif self.format == "xyxy":
|
||||
func = xyxy2xywh if format == "xywh" else xyxy2ltwh
|
||||
elif self.format == "xywh":
|
||||
func = xywh2xyxy if format == "xyxy" else xywh2ltwh
|
||||
else:
|
||||
func = ltwh2xyxy if format == 'xyxy' else ltwh2xywh
|
||||
func = ltwh2xyxy if format == "xyxy" else ltwh2xywh
|
||||
self.bboxes = func(self.bboxes)
|
||||
self.format = format
|
||||
|
||||
def areas(self):
|
||||
"""Return box areas."""
|
||||
self.convert('xyxy')
|
||||
self.convert("xyxy")
|
||||
return (self.bboxes[:, 2] - self.bboxes[:, 0]) * (self.bboxes[:, 3] - self.bboxes[:, 1])
|
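The conversions behind `convert()` are simple affine maps; a worked `xyxy -> xywh` example:

```python
import numpy as np

xyxy = np.array([[10, 10, 30, 50]], dtype=np.float32)
cx = (xyxy[:, 0] + xyxy[:, 2]) / 2   # center x
cy = (xyxy[:, 1] + xyxy[:, 3]) / 2   # center y
w = xyxy[:, 2] - xyxy[:, 0]          # width
h = xyxy[:, 3] - xyxy[:, 1]          # height
print(cx, cy, w, h)  # [20.] [30.] [20.] [40.]
```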
||||
|
||||
# def denormalize(self, w, h):
|
||||
@ -111,7 +124,7 @@ class Bboxes:
|
||||
return len(self.bboxes)
|
||||
|
||||
@classmethod
|
||||
def concatenate(cls, boxes_list: List['Bboxes'], axis=0) -> 'Bboxes':
|
||||
def concatenate(cls, boxes_list: List["Bboxes"], axis=0) -> "Bboxes":
|
||||
"""
|
||||
Concatenate a list of Bboxes objects into a single Bboxes object.
|
||||
|
||||
@ -135,7 +148,7 @@ class Bboxes:
|
||||
return boxes_list[0]
|
||||
return cls(np.concatenate([b.bboxes for b in boxes_list], axis=axis))
|
||||
|
||||
def __getitem__(self, index) -> 'Bboxes':
|
||||
def __getitem__(self, index) -> "Bboxes":
|
||||
"""
|
||||
Retrieve a specific bounding box or a set of bounding boxes using indexing.
|
||||
|
||||
@ -156,32 +169,52 @@ class Bboxes:
|
||||
if isinstance(index, int):
|
||||
return Bboxes(self.bboxes[index].view(1, -1))
|
||||
b = self.bboxes[index]
|
||||
assert b.ndim == 2, f'Indexing on Bboxes with {index} failed to return a matrix!'
|
||||
assert b.ndim == 2, f"Indexing on Bboxes with {index} failed to return a matrix!"
|
||||
return Bboxes(b)
|
||||
|
||||
|
||||
class Instances:
|
||||
"""
|
||||
Container for bounding boxes, segments, and keypoints of detected objects in an image.
|
||||
|
||||
def __init__(self, bboxes, segments=None, keypoints=None, bbox_format='xywh', normalized=True) -> None:
|
||||
Attributes:
|
||||
_bboxes (Bboxes): Internal object for handling bounding box operations.
|
||||
keypoints (ndarray): keypoints(x, y, visible) with shape [N, 17, 3]. Default is None.
|
||||
normalized (bool): Flag indicating whether the bounding box coordinates are normalized.
|
||||
segments (ndarray): Segments array with shape [N, 1000, 2] after resampling.
|
||||
|
||||
Args:
|
||||
bboxes (ndarray): An array of bounding boxes with shape [N, 4].
|
||||
segments (list | ndarray, optional): A list or array of object segments. Default is None.
|
||||
keypoints (ndarray, optional): An array of keypoints with shape [N, 17, 3]. Default is None.
|
||||
bbox_format (str, optional): The format of bounding boxes ('xywh' or 'xyxy'). Default is 'xywh'.
|
||||
normalized (bool, optional): Whether the bounding box coordinates are normalized. Default is True.
|
||||
|
||||
Examples:
|
||||
```python
|
||||
# Create an Instances object
|
||||
instances = Instances(
|
||||
bboxes=np.array([[10, 10, 30, 30], [20, 20, 40, 40]]),
|
||||
segments=[np.array([[5, 5], [10, 10]]), np.array([[15, 15], [20, 20]])],
|
||||
keypoints=np.array([[[5, 5, 1], [10, 10, 1]], [[15, 15, 1], [20, 20, 1]]])
|
||||
)
|
||||
```
|
||||
|
||||
Note:
|
||||
The bounding box format is either 'xywh' or 'xyxy', and is determined by the `bbox_format` argument.
|
||||
This class does not perform input validation, and it assumes the inputs are well-formed.
|
||||
"""
|
||||
|
||||
def __init__(self, bboxes, segments=None, keypoints=None, bbox_format="xywh", normalized=True) -> None:
|
||||
"""
|
||||
Args:
|
||||
bboxes (ndarray): bboxes with shape [N, 4].
|
||||
segments (list | ndarray): segments.
|
||||
keypoints (ndarray): keypoints(x, y, visible) with shape [N, 17, 3].
|
||||
"""
|
||||
if segments is None:
|
||||
segments = []
|
||||
self._bboxes = Bboxes(bboxes=bboxes, format=bbox_format)
|
||||
self.keypoints = keypoints
|
||||
self.normalized = normalized
|
||||
|
||||
if len(segments) > 0:
|
||||
# list[np.array(1000, 2)] * num_samples
|
||||
segments = resample_segments(segments)
|
||||
# (N, 1000, 2)
|
||||
segments = np.stack(segments, axis=0)
|
||||
else:
|
||||
segments = np.zeros((0, 1000, 2), dtype=np.float32)
|
||||
self.segments = segments
|
||||
|
||||
def convert_bbox(self, format):
|
||||
@ -194,7 +227,7 @@ class Instances:
|
||||
return self._bboxes.areas()
|
||||
|
||||
def scale(self, scale_w, scale_h, bbox_only=False):
|
||||
"""this might be similar with denormalize func but without normalized sign."""
|
||||
"""This might be similar with denormalize func but without normalized sign."""
|
||||
self._bboxes.mul(scale=(scale_w, scale_h, scale_w, scale_h))
|
||||
if bbox_only:
|
||||
return
|
||||
@ -230,7 +263,7 @@ class Instances:
|
||||
|
||||
def add_padding(self, padw, padh):
|
||||
"""Handle rect and mosaic situation."""
|
||||
assert not self.normalized, 'you should add padding with absolute coordinates.'
|
||||
assert not self.normalized, "you should add padding with absolute coordinates."
|
||||
self._bboxes.add(offset=(padw, padh, padw, padh))
|
||||
self.segments[..., 0] += padw
|
||||
self.segments[..., 1] += padh
|
||||
@ -238,7 +271,7 @@ class Instances:
|
||||
self.keypoints[..., 0] += padw
|
||||
self.keypoints[..., 1] += padh
|
||||
|
||||
def __getitem__(self, index) -> 'Instances':
|
||||
def __getitem__(self, index) -> "Instances":
|
||||
"""
|
||||
Retrieve a specific instance or a set of instances using indexing.
|
||||
|
||||
@ -268,7 +301,7 @@ class Instances:
|
||||
|
||||
def flipud(self, h):
|
||||
"""Flips the coordinates of bounding boxes, segments, and keypoints vertically."""
|
||||
if self._bboxes.format == 'xyxy':
|
||||
if self._bboxes.format == "xyxy":
|
||||
y1 = self.bboxes[:, 1].copy()
|
||||
y2 = self.bboxes[:, 3].copy()
|
||||
self.bboxes[:, 1] = h - y2
|
||||
@ -281,7 +314,7 @@ class Instances:
|
||||
|
||||
def fliplr(self, w):
|
||||
"""Reverses the order of the bounding boxes and segments horizontally."""
|
||||
if self._bboxes.format == 'xyxy':
|
||||
if self._bboxes.format == "xyxy":
|
||||
x1 = self.bboxes[:, 0].copy()
|
||||
x2 = self.bboxes[:, 2].copy()
|
||||
self.bboxes[:, 0] = w - x2
|
||||
@ -295,10 +328,10 @@ class Instances:
|
||||
def clip(self, w, h):
|
||||
"""Clips bounding boxes, segments, and keypoints values to stay within image boundaries."""
|
||||
ori_format = self._bboxes.format
|
||||
self.convert_bbox(format='xyxy')
|
||||
self.convert_bbox(format="xyxy")
|
||||
self.bboxes[:, [0, 2]] = self.bboxes[:, [0, 2]].clip(0, w)
|
||||
self.bboxes[:, [1, 3]] = self.bboxes[:, [1, 3]].clip(0, h)
|
||||
if ori_format != 'xyxy':
|
||||
if ori_format != "xyxy":
|
||||
self.convert_bbox(format=ori_format)
|
||||
self.segments[..., 0] = self.segments[..., 0].clip(0, w)
|
||||
self.segments[..., 1] = self.segments[..., 1].clip(0, h)
|
||||
@ -307,7 +340,11 @@ class Instances:
|
||||
self.keypoints[..., 1] = self.keypoints[..., 1].clip(0, h)
|
||||
|
||||
def remove_zero_area_boxes(self):
|
||||
"""Remove zero-area boxes, i.e. after clipping some boxes may have zero width or height. This removes them."""
|
||||
"""
|
||||
Remove zero-area boxes, i.e. after clipping some boxes may have zero width or height.
|
||||
|
||||
This removes them.
|
||||
"""
|
||||
good = self.bbox_areas > 0
|
||||
if not all(good):
|
||||
self._bboxes = self._bboxes[good]
|
||||
@ -330,7 +367,7 @@ class Instances:
|
||||
return len(self.bboxes)
|
||||
|
||||
@classmethod
|
||||
def concatenate(cls, instances_list: List['Instances'], axis=0) -> 'Instances':
|
||||
def concatenate(cls, instances_list: List["Instances"], axis=0) -> "Instances":
|
||||
"""
|
||||
Concatenates a list of Instances objects into a single Instances object.
|
||||
|
||||
|
@ -6,14 +6,17 @@ import torch.nn.functional as F
|
||||
|
||||
from ultralytics.utils.metrics import OKS_SIGMA
|
||||
from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh
|
||||
from ultralytics.utils.tal import TaskAlignedAssigner, dist2bbox, make_anchors
|
||||
|
||||
from .metrics import bbox_iou
|
||||
from ultralytics.utils.tal import RotatedTaskAlignedAssigner, TaskAlignedAssigner, dist2bbox, dist2rbox, make_anchors
|
||||
from .metrics import bbox_iou, probiou
|
||||
from .tal import bbox2dist
|
||||
|
||||
|
||||
class VarifocalLoss(nn.Module):
|
||||
"""Varifocal loss by Zhang et al. https://arxiv.org/abs/2008.13367."""
|
||||
"""
|
||||
Varifocal loss by Zhang et al.
|
||||
|
||||
https://arxiv.org/abs/2008.13367.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the VarifocalLoss class."""
|
||||
@ -24,21 +27,25 @@ class VarifocalLoss(nn.Module):
|
||||
"""Computes varfocal loss."""
|
||||
weight = alpha * pred_score.sigmoid().pow(gamma) * (1 - label) + gt_score * label
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
loss = (F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction='none') *
|
||||
weight).mean(1).sum()
|
||||
loss = (
|
||||
(F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction="none") * weight)
|
||||
.mean(1)
|
||||
.sum()
|
||||
)
|
||||
return loss
|
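A numeric sketch of the varifocal weight above, using `alpha=0.75, gamma=2.0` (assumed here, not shown in this hunk): positives keep their IoU-aware target score as weight, while negatives are down-weighted by `alpha * p**gamma`.

```python
import torch

pred_score = torch.tensor([2.0, -1.0])  # logits: one positive, one negative anchor
gt_score = torch.tensor([0.9, 0.0])     # IoU-aware targets
label = torch.tensor([1.0, 0.0])
alpha, gamma = 0.75, 2.0                # assumed defaults

weight = alpha * pred_score.sigmoid().pow(gamma) * (1 - label) + gt_score * label
print(weight)  # tensor([0.9000, 0.0542]) -> the negative contributes ~6% of a positive
```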
||||
|
||||
|
||||
class FocalLoss(nn.Module):
|
||||
"""Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)."""
|
||||
|
||||
def __init__(self, ):
|
||||
def __init__(self):
|
||||
"""Initializer for FocalLoss class with no parameters."""
|
||||
super().__init__()
|
||||
|
||||
@staticmethod
|
||||
def forward(pred, label, gamma=1.5, alpha=0.25):
|
||||
"""Calculates and updates confusion matrix for object detection/classification tasks."""
|
||||
loss = F.binary_cross_entropy_with_logits(pred, label, reduction='none')
|
||||
loss = F.binary_cross_entropy_with_logits(pred, label, reduction="none")
|
||||
# p_t = torch.exp(-loss)
|
||||
# loss *= self.alpha * (1.000001 - p_t) ** self.gamma # non-zero power for gradient stability
|
||||
|
||||
@ -54,6 +61,7 @@ class FocalLoss(nn.Module):
|
||||
|
||||
|
||||
class BboxLoss(nn.Module):
|
||||
"""Criterion class for computing training losses during training."""
|
||||
|
||||
def __init__(self, reg_max, use_dfl=False):
|
||||
"""Initialize the BboxLoss module with regularization maximum and DFL settings."""
|
||||
@ -79,42 +87,73 @@ class BboxLoss(nn.Module):
|
||||
|
||||
@staticmethod
|
||||
def _df_loss(pred_dist, target):
|
||||
"""Return sum of left and right DFL losses."""
|
||||
# Distribution Focal Loss (DFL) proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391
|
||||
"""
|
||||
Return sum of left and right DFL losses.
|
||||
|
||||
Distribution Focal Loss (DFL) proposed in Generalized Focal Loss
|
||||
https://ieeexplore.ieee.org/document/9792391
|
||||
"""
|
||||
tl = target.long() # target left
|
||||
tr = tl + 1 # target right
|
||||
wl = tr - target # weight left
|
||||
wr = 1 - wl # weight right
|
||||
return (F.cross_entropy(pred_dist, tl.view(-1), reduction='none').view(tl.shape) * wl +
|
||||
F.cross_entropy(pred_dist, tr.view(-1), reduction='none').view(tl.shape) * wr).mean(-1, keepdim=True)
|
||||
return (
|
||||
F.cross_entropy(pred_dist, tl.view(-1), reduction="none").view(tl.shape) * wl
|
||||
+ F.cross_entropy(pred_dist, tr.view(-1), reduction="none").view(tl.shape) * wr
|
||||
).mean(-1, keepdim=True)
|
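A worked example of the DFL target split: a continuous target `t = 2.7` is spread over the two adjacent integer bins with linear weights, so the cross-entropy pulls probability mass toward both:

```python
import torch

target = torch.tensor([2.7])
tl = target.long()       # left bin: 2
wl = (tl + 1) - target   # weight on the left bin: 0.3
wr = 1 - wl              # weight on the right bin: 0.7
print(tl.item(), round(wl.item(), 2), round(wr.item(), 2))  # 2 0.3 0.7
```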
||||
|
||||
|
||||
class RotatedBboxLoss(BboxLoss):
|
||||
"""Criterion class for computing training losses during training."""
|
||||
|
||||
def __init__(self, reg_max, use_dfl=False):
|
||||
"""Initialize the BboxLoss module with regularization maximum and DFL settings."""
|
||||
super().__init__(reg_max, use_dfl)
|
||||
|
||||
def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
|
||||
"""IoU loss."""
|
||||
weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
|
||||
iou = probiou(pred_bboxes[fg_mask], target_bboxes[fg_mask])
|
||||
loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum
|
||||
|
||||
# DFL loss
|
||||
if self.use_dfl:
|
||||
target_ltrb = bbox2dist(anchor_points, xywh2xyxy(target_bboxes[..., :4]), self.reg_max)
|
||||
loss_dfl = self._df_loss(pred_dist[fg_mask].view(-1, self.reg_max + 1), target_ltrb[fg_mask]) * weight
|
||||
loss_dfl = loss_dfl.sum() / target_scores_sum
|
||||
else:
|
||||
loss_dfl = torch.tensor(0.0).to(pred_dist.device)
|
||||
|
||||
return loss_iou, loss_dfl
|
||||
|
||||
|
||||
class KeypointLoss(nn.Module):
|
||||
"""Criterion class for computing training losses."""
|
||||
|
||||
def __init__(self, sigmas) -> None:
|
||||
"""Initialize the KeypointLoss class."""
|
||||
super().__init__()
|
||||
self.sigmas = sigmas
|
||||
|
||||
def forward(self, pred_kpts, gt_kpts, kpt_mask, area):
|
||||
"""Calculates keypoint loss factor and Euclidean distance loss for predicted and actual keypoints."""
|
||||
d = (pred_kpts[..., 0] - gt_kpts[..., 0]) ** 2 + (pred_kpts[..., 1] - gt_kpts[..., 1]) ** 2
|
||||
kpt_loss_factor = (torch.sum(kpt_mask != 0) + torch.sum(kpt_mask == 0)) / (torch.sum(kpt_mask != 0) + 1e-9)
|
||||
d = (pred_kpts[..., 0] - gt_kpts[..., 0]).pow(2) + (pred_kpts[..., 1] - gt_kpts[..., 1]).pow(2)
|
||||
kpt_loss_factor = kpt_mask.shape[1] / (torch.sum(kpt_mask != 0, dim=1) + 1e-9)
|
||||
# e = d / (2 * (area * self.sigmas) ** 2 + 1e-9) # from formula
|
||||
e = d / (2 * self.sigmas) ** 2 / (area + 1e-9) / 2 # from cocoeval
|
||||
return kpt_loss_factor * ((1 - torch.exp(-e)) * kpt_mask).mean()
|
||||
e = d / ((2 * self.sigmas).pow(2) * (area + 1e-9) * 2) # from cocoeval
|
||||
return (kpt_loss_factor.view(-1, 1) * ((1 - torch.exp(-e)) * kpt_mask)).mean()
|
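A numeric sketch of the cocoeval-style keypoint term above: the squared pixel distance is normalized by per-keypoint sigma and object area before the exponential similarity (values illustrative):

```python
import torch

d = torch.tensor([25.0])       # squared pixel distance for one keypoint
sigmas = torch.tensor([0.05])  # per-keypoint sigma
area = torch.tensor([1000.0])  # object area

e = d / ((2 * sigmas).pow(2) * (area + 1e-9) * 2)  # same normalization as above
print(1 - torch.exp(-e))  # tensor([0.7135]) -> loss contribution before masking
```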
||||
|
||||
|
||||
class v8DetectionLoss:
|
||||
"""Criterion class for computing training losses."""
|
||||
|
||||
def __init__(self, model): # model must be de-paralleled
|
||||
|
||||
def __init__(self, model, tal_topk=10): # model must be de-paralleled
|
||||
"""Initializes v8DetectionLoss with the model, defining model-related properties and BCE loss function."""
|
||||
device = next(model.parameters()).device # get model device
|
||||
h = model.args # hyperparameters
|
||||
|
||||
m = model.model[-1] # Detect() module
|
||||
self.bce = nn.BCEWithLogitsLoss(reduction='none')
|
||||
self.bce = nn.BCEWithLogitsLoss(reduction="none")
|
||||
self.hyp = h
|
||||
self.stride = m.stride # model strides
|
||||
self.nc = m.nc # number of classes
|
||||
@ -124,7 +163,7 @@ class v8DetectionLoss:
|
||||
|
||||
self.use_dfl = m.reg_max > 1
|
||||
|
||||
self.assigner = TaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
|
||||
self.assigner = TaskAlignedAssigner(topk=tal_topk, num_classes=self.nc, alpha=0.5, beta=6.0)
|
||||
self.bbox_loss = BboxLoss(m.reg_max - 1, use_dfl=self.use_dfl).to(device)
|
||||
self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device)
|
||||
|
||||
@ -159,7 +198,8 @@ class v8DetectionLoss:
|
||||
loss = torch.zeros(3, device=self.device) # box, cls, dfl
|
||||
feats = preds[1] if isinstance(preds, tuple) else preds
|
||||
pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
|
||||
(self.reg_max * 4, self.nc), 1)
|
||||
(self.reg_max * 4, self.nc), 1
|
||||
)
|
||||
|
||||
pred_scores = pred_scores.permute(0, 2, 1).contiguous()
|
||||
pred_distri = pred_distri.permute(0, 2, 1).contiguous()
|
||||
@ -169,30 +209,36 @@ class v8DetectionLoss:
|
||||
imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w)
|
||||
anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
|
||||
|
||||
# targets
|
||||
targets = torch.cat((batch['batch_idx'].view(-1, 1), batch['cls'].view(-1, 1), batch['bboxes']), 1)
|
||||
# Targets
|
||||
targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
|
||||
targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
|
||||
gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
|
||||
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
|
||||
|
||||
# pboxes
|
||||
# Pboxes
|
||||
pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)
|
||||
|
||||
_, target_bboxes, target_scores, fg_mask, _ = self.assigner(
|
||||
pred_scores.detach().sigmoid(), (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
|
||||
anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt)
|
||||
pred_scores.detach().sigmoid(),
|
||||
(pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
|
||||
anchor_points * stride_tensor,
|
||||
gt_labels,
|
||||
gt_bboxes,
|
||||
mask_gt,
|
||||
)
|
||||
|
||||
target_scores_sum = max(target_scores.sum(), 1)
|
||||
|
||||
# cls loss
|
||||
# Cls loss
|
||||
# loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way
|
||||
loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE
|
||||
|
||||
# bbox loss
|
||||
# Bbox loss
|
||||
if fg_mask.sum():
|
||||
target_bboxes /= stride_tensor
|
||||
loss[0], loss[2] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores,
|
||||
target_scores_sum, fg_mask)
|
||||
loss[0], loss[2] = self.bbox_loss(
|
||||
pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
|
||||
)
|
||||
|
||||
loss[0] *= self.hyp.box # box gain
|
||||
loss[1] *= self.hyp.cls # cls gain
|
||||
@ -205,8 +251,8 @@ class v8SegmentationLoss(v8DetectionLoss):
    """Criterion class for computing training losses."""

    def __init__(self, model):  # model must be de-paralleled
        """Initializes the v8SegmentationLoss class, taking a de-paralleled model as argument."""
        super().__init__(model)
        self.nm = model.model[-1].nm  # number of masks
        self.overlap = model.args.overlap_mask

    def __call__(self, preds, batch):
@ -215,9 +261,10 @@ class v8SegmentationLoss(v8DetectionLoss):
        feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
        batch_size, _, mask_h, mask_w = proto.shape  # batch size, number of masks, mask height, mask width
        pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
            (self.reg_max * 4, self.nc), 1)
            (self.reg_max * 4, self.nc), 1
        )

        # b, grids, ..
        # B, grids, ..
        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
        pred_distri = pred_distri.permute(0, 2, 1).contiguous()
        pred_masks = pred_masks.permute(0, 2, 1).contiguous()
@ -226,80 +273,168 @@ class v8SegmentationLoss(v8DetectionLoss):
        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)

        # targets
        # Targets
        try:
            batch_idx = batch['batch_idx'].view(-1, 1)
            targets = torch.cat((batch_idx, batch['cls'].view(-1, 1), batch['bboxes']), 1)
            batch_idx = batch["batch_idx"].view(-1, 1)
            targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
            targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
            gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
            mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
        except RuntimeError as e:
            raise TypeError('ERROR ❌ segment dataset incorrectly formatted or not a segment dataset.\n'
                            "This error can occur when incorrectly training a 'segment' model on a 'detect' dataset, "
                            "i.e. 'yolo train model=yolov8n-seg.pt data=coco128.yaml'.\nVerify your dataset is a "
                            "correctly formatted 'segment' dataset using 'data=coco128-seg.yaml' "
                            'as an example.\nSee https://docs.ultralytics.com/tasks/segment/ for help.') from e
            raise TypeError(
                "ERROR ❌ segment dataset incorrectly formatted or not a segment dataset.\n"
                "This error can occur when incorrectly training a 'segment' model on a 'detect' dataset, "
                "i.e. 'yolo train model=yolov8n-seg.pt data=coco8.yaml'.\nVerify your dataset is a "
                "correctly formatted 'segment' dataset using 'data=coco8-seg.yaml' "
                "as an example.\nSee https://docs.ultralytics.com/datasets/segment/ for help."
            ) from e

        # pboxes
        # Pboxes
        pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)

        _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
            pred_scores.detach().sigmoid(), (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
            anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt)
            pred_scores.detach().sigmoid(),
            (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
            anchor_points * stride_tensor,
            gt_labels,
            gt_bboxes,
            mask_gt,
        )

        target_scores_sum = max(target_scores.sum(), 1)

        # cls loss
        # Cls loss
        # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
        loss[2] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE

        if fg_mask.sum():
            # bbox loss
            loss[0], loss[3] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes / stride_tensor,
                                              target_scores, target_scores_sum, fg_mask)
            # masks loss
            masks = batch['masks'].to(self.device).float()
            # Bbox loss
            loss[0], loss[3] = self.bbox_loss(
                pred_distri,
                pred_bboxes,
                anchor_points,
                target_bboxes / stride_tensor,
                target_scores,
                target_scores_sum,
                fg_mask,
            )
            # Masks loss
            masks = batch["masks"].to(self.device).float()
            if tuple(masks.shape[-2:]) != (mask_h, mask_w):  # downsample
                masks = F.interpolate(masks[None], (mask_h, mask_w), mode='nearest')[0]
                masks = F.interpolate(masks[None], (mask_h, mask_w), mode="nearest")[0]

            for i in range(batch_size):
                if fg_mask[i].sum():
                    mask_idx = target_gt_idx[i][fg_mask[i]]
                    if self.overlap:
                        gt_mask = torch.where(masks[[i]] == (mask_idx + 1).view(-1, 1, 1), 1.0, 0.0)
                    else:
                        gt_mask = masks[batch_idx.view(-1) == i][mask_idx]
                    xyxyn = target_bboxes[i][fg_mask[i]] / imgsz[[1, 0, 1, 0]]
                    marea = xyxy2xywh(xyxyn)[:, 2:].prod(1)
                    mxyxy = xyxyn * torch.tensor([mask_w, mask_h, mask_w, mask_h], device=self.device)
                    loss[1] += self.single_mask_loss(gt_mask, pred_masks[i][fg_mask[i]], proto[i], mxyxy, marea)  # seg

                # WARNING: lines below prevents Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
                else:
                    loss[1] += (proto * 0).sum() + (pred_masks * 0).sum()  # inf sums may lead to nan loss
            loss[1] = self.calculate_segmentation_loss(
                fg_mask, masks, target_gt_idx, target_bboxes, batch_idx, proto, pred_masks, imgsz, self.overlap
            )

        # WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
        else:
            loss[1] += (proto * 0).sum() + (pred_masks * 0).sum()  # inf sums may lead to nan loss

        loss[0] *= self.hyp.box  # box gain
        loss[1] *= self.hyp.box / batch_size  # seg gain
        loss[1] *= self.hyp.box  # seg gain
        loss[2] *= self.hyp.cls  # cls gain
        loss[3] *= self.hyp.dfl  # dfl gain

        return loss.sum() * batch_size, loss.detach()  # loss(box, cls, dfl)
    def single_mask_loss(self, gt_mask, pred, proto, xyxy, area):
        """Mask loss for one image."""
        pred_mask = (pred @ proto.view(self.nm, -1)).view(-1, *proto.shape[1:])  # (n, 32) @ (32,80,80) -> (n,80,80)
        loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction='none')
        return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).mean()
    @staticmethod
    def single_mask_loss(
        gt_mask: torch.Tensor, pred: torch.Tensor, proto: torch.Tensor, xyxy: torch.Tensor, area: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute the instance segmentation loss for a single image.

        Args:
            gt_mask (torch.Tensor): Ground truth mask of shape (n, H, W), where n is the number of objects.
            pred (torch.Tensor): Predicted mask coefficients of shape (n, 32).
            proto (torch.Tensor): Prototype masks of shape (32, H, W).
            xyxy (torch.Tensor): Ground truth bounding boxes in xyxy format, normalized to [0, 1], of shape (n, 4).
            area (torch.Tensor): Area of each ground truth bounding box of shape (n,).

        Returns:
            (torch.Tensor): The calculated mask loss for a single image.

        Notes:
            The function uses the equation pred_mask = torch.einsum('in,nhw->ihw', pred, proto) to produce the
            predicted masks from the prototype masks and predicted mask coefficients.
        """
        pred_mask = torch.einsum("in,nhw->ihw", pred, proto)  # (n, 32) @ (32, 80, 80) -> (n, 80, 80)
        loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction="none")
        return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).sum()
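The refactor swaps the explicit matmul-and-reshape for an einsum and changes the per-image reduction from mean to sum; the sum is normalized later by `fg_mask.sum()` in `calculate_segmentation_loss`. A minimal sketch of the einsum/matmul equivalence, with hypothetical shapes:

```python
# Illustrative sketch (not part of the diff): how the einsum in single_mask_loss
# combines per-instance coefficients with the shared prototype masks.
import torch

n, c, h, w = 3, 32, 80, 80      # hypothetical: 3 instances, 32 prototypes, 80x80 grid
pred = torch.randn(n, c)        # per-instance mask coefficients
proto = torch.randn(c, h, w)    # shared prototype masks

# Each instance mask is a coefficient-weighted sum over the 32 prototypes
pred_mask = torch.einsum("in,nhw->ihw", pred, proto)  # -> (3, 80, 80)

# Equivalent matrix-multiply form used by the old implementation
pred_mask_mm = (pred @ proto.view(c, -1)).view(-1, h, w)
assert torch.allclose(pred_mask, pred_mask_mm, atol=1e-6)
```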
    def calculate_segmentation_loss(
        self,
        fg_mask: torch.Tensor,
        masks: torch.Tensor,
        target_gt_idx: torch.Tensor,
        target_bboxes: torch.Tensor,
        batch_idx: torch.Tensor,
        proto: torch.Tensor,
        pred_masks: torch.Tensor,
        imgsz: torch.Tensor,
        overlap: bool,
    ) -> torch.Tensor:
        """
        Calculate the loss for instance segmentation.

        Args:
            fg_mask (torch.Tensor): A binary tensor of shape (BS, N_anchors) indicating which anchors are positive.
            masks (torch.Tensor): Ground truth masks of shape (BS, H, W) if `overlap` is False, otherwise (BS, ?, H, W).
            target_gt_idx (torch.Tensor): Indexes of ground truth objects for each anchor of shape (BS, N_anchors).
            target_bboxes (torch.Tensor): Ground truth bounding boxes for each anchor of shape (BS, N_anchors, 4).
            batch_idx (torch.Tensor): Batch indices of shape (N_labels_in_batch, 1).
            proto (torch.Tensor): Prototype masks of shape (BS, 32, H, W).
            pred_masks (torch.Tensor): Predicted masks for each anchor of shape (BS, N_anchors, 32).
            imgsz (torch.Tensor): Size of the input image as a tensor of shape (2), i.e., (H, W).
            overlap (bool): Whether the masks in `masks` tensor overlap.

        Returns:
            (torch.Tensor): The calculated loss for instance segmentation.

        Notes:
            The batch loss can be computed for improved speed at higher memory usage.
            For example, pred_mask can be computed as follows:
                pred_mask = torch.einsum('in,nhw->ihw', pred, proto)  # (i, 32) @ (32, 160, 160) -> (i, 160, 160)
        """
        _, _, mask_h, mask_w = proto.shape
        loss = 0

        # Normalize to 0-1
        target_bboxes_normalized = target_bboxes / imgsz[[1, 0, 1, 0]]

        # Areas of target bboxes
        marea = xyxy2xywh(target_bboxes_normalized)[..., 2:].prod(2)

        # Normalize to mask size
        mxyxy = target_bboxes_normalized * torch.tensor([mask_w, mask_h, mask_w, mask_h], device=proto.device)

        for i, single_i in enumerate(zip(fg_mask, target_gt_idx, pred_masks, proto, mxyxy, marea, masks)):
            fg_mask_i, target_gt_idx_i, pred_masks_i, proto_i, mxyxy_i, marea_i, masks_i = single_i
            if fg_mask_i.any():
                mask_idx = target_gt_idx_i[fg_mask_i]
                if overlap:
                    gt_mask = masks_i == (mask_idx + 1).view(-1, 1, 1)
                    gt_mask = gt_mask.float()
                else:
                    gt_mask = masks[batch_idx.view(-1) == i][mask_idx]

                loss += self.single_mask_loss(
                    gt_mask, pred_masks_i[fg_mask_i], proto_i, mxyxy_i[fg_mask_i], marea_i[fg_mask_i]
                )

            # WARNING: lines below prevents Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
            else:
                loss += (proto * 0).sum() + (pred_masks * 0).sum()  # inf sums may lead to nan loss

        return loss / fg_mask.sum()
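With `overlap_mask=True` all instances of an image share one index-encoded mask, so the per-anchor ground-truth mask is recovered by comparing pixel values against `mask_idx + 1`. A small sketch with hypothetical values:

```python
# Illustrative sketch (not in the diff): decoding an index-encoded overlap mask.
import torch

# Hypothetical 4x4 mask where pixel values are instance indices (0 = background)
masks_i = torch.tensor([
    [0, 1, 1, 0],
    [0, 1, 2, 2],
    [0, 0, 2, 2],
    [0, 0, 0, 0],
])
mask_idx = torch.tensor([0, 1])  # GT indices matched to two positive anchors

# (mask_idx + 1) selects instances 1 and 2; broadcasting yields one binary mask per anchor
gt_mask = (masks_i == (mask_idx + 1).view(-1, 1, 1)).float()  # -> (2, 4, 4)
```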
class v8PoseLoss(v8DetectionLoss):
    """Criterion class for computing training losses."""

    def __init__(self, model):  # model must be de-paralleled
        """Initializes v8PoseLoss with model, sets keypoint variables and declares a keypoint loss instance."""
        super().__init__(model)
        self.kpt_shape = model.model[-1].kpt_shape
        self.bce_pose = nn.BCEWithLogitsLoss()
@ -313,9 +448,10 @@ class v8PoseLoss(v8DetectionLoss):
        loss = torch.zeros(5, device=self.device)  # box, cls, dfl, kpt_location, kpt_visibility
        feats, pred_kpts = preds if isinstance(preds[0], list) else preds[1]
        pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
            (self.reg_max * 4, self.nc), 1)
            (self.reg_max * 4, self.nc), 1
        )

        # b, grids, ..
        # B, grids, ..
        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
        pred_distri = pred_distri.permute(0, 2, 1).contiguous()
        pred_kpts = pred_kpts.permute(0, 2, 1).contiguous()
@ -324,53 +460,50 @@ class v8PoseLoss(v8DetectionLoss):
        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)

        # targets
        # Targets
        batch_size = pred_scores.shape[0]
        batch_idx = batch['batch_idx'].view(-1, 1)
        targets = torch.cat((batch_idx, batch['cls'].view(-1, 1), batch['bboxes']), 1)
        batch_idx = batch["batch_idx"].view(-1, 1)
        targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
        targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
        gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
        mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)

        # pboxes
        # Pboxes
        pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)
        pred_kpts = self.kpts_decode(anchor_points, pred_kpts.view(batch_size, -1, *self.kpt_shape))  # (b, h*w, 17, 3)

        _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
            pred_scores.detach().sigmoid(), (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
            anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt)
            pred_scores.detach().sigmoid(),
            (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
            anchor_points * stride_tensor,
            gt_labels,
            gt_bboxes,
            mask_gt,
        )

        target_scores_sum = max(target_scores.sum(), 1)

        # cls loss
        # Cls loss
        # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
        loss[3] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE

        # bbox loss
        # Bbox loss
        if fg_mask.sum():
            target_bboxes /= stride_tensor
            loss[0], loss[4] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores,
                                              target_scores_sum, fg_mask)
            keypoints = batch['keypoints'].to(self.device).float().clone()
            loss[0], loss[4] = self.bbox_loss(
                pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
            )
            keypoints = batch["keypoints"].to(self.device).float().clone()
            keypoints[..., 0] *= imgsz[1]
            keypoints[..., 1] *= imgsz[0]
            for i in range(batch_size):
                if fg_mask[i].sum():
                    idx = target_gt_idx[i][fg_mask[i]]
                    gt_kpt = keypoints[batch_idx.view(-1) == i][idx]  # (n, 51)
                    gt_kpt[..., 0] /= stride_tensor[fg_mask[i]]
                    gt_kpt[..., 1] /= stride_tensor[fg_mask[i]]
                    area = xyxy2xywh(target_bboxes[i][fg_mask[i]])[:, 2:].prod(1, keepdim=True)
                    pred_kpt = pred_kpts[i][fg_mask[i]]
                    kpt_mask = gt_kpt[..., 2] != 0
                    loss[1] += self.keypoint_loss(pred_kpt, gt_kpt, kpt_mask, area)  # pose loss
                    # kpt_score loss
                    if pred_kpt.shape[-1] == 3:
                        loss[2] += self.bce_pose(pred_kpt[..., 2], kpt_mask.float())  # keypoint obj loss

            loss[1], loss[2] = self.calculate_keypoints_loss(
                fg_mask, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes, pred_kpts
            )

        loss[0] *= self.hyp.box  # box gain
        loss[1] *= self.hyp.pose / batch_size  # pose gain
        loss[2] *= self.hyp.kobj / batch_size  # kobj gain
        loss[1] *= self.hyp.pose  # pose gain
        loss[2] *= self.hyp.kobj  # kobj gain
        loss[3] *= self.hyp.cls  # cls gain
        loss[4] *= self.hyp.dfl  # dfl gain
@ -385,12 +518,210 @@ class v8PoseLoss(v8DetectionLoss):
        y[..., 1] += anchor_points[:, [1]] - 0.5
        return y

    def calculate_keypoints_loss(
        self, masks, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes, pred_kpts
    ):
        """
        Calculate the keypoints loss for the model.

        This function calculates the keypoints loss and keypoints object loss for a given batch. The keypoints loss is
        based on the difference between the predicted keypoints and ground truth keypoints. The keypoints object loss
        is a binary classification loss that classifies whether a keypoint is present or not.

        Args:
            masks (torch.Tensor): Binary mask tensor indicating object presence, shape (BS, N_anchors).
            target_gt_idx (torch.Tensor): Index tensor mapping anchors to ground truth objects, shape (BS, N_anchors).
            keypoints (torch.Tensor): Ground truth keypoints, shape (N_kpts_in_batch, N_kpts_per_object, kpts_dim).
            batch_idx (torch.Tensor): Batch index tensor for keypoints, shape (N_kpts_in_batch, 1).
            stride_tensor (torch.Tensor): Stride tensor for anchors, shape (N_anchors, 1).
            target_bboxes (torch.Tensor): Ground truth boxes in (x1, y1, x2, y2) format, shape (BS, N_anchors, 4).
            pred_kpts (torch.Tensor): Predicted keypoints, shape (BS, N_anchors, N_kpts_per_object, kpts_dim).

        Returns:
            (tuple): Returns a tuple containing:
                - kpts_loss (torch.Tensor): The keypoints loss.
                - kpts_obj_loss (torch.Tensor): The keypoints object loss.
        """
        batch_idx = batch_idx.flatten()
        batch_size = len(masks)

        # Find the maximum number of keypoints in a single image
        max_kpts = torch.unique(batch_idx, return_counts=True)[1].max()

        # Create a tensor to hold batched keypoints
        batched_keypoints = torch.zeros(
            (batch_size, max_kpts, keypoints.shape[1], keypoints.shape[2]), device=keypoints.device
        )

        # TODO: any idea how to vectorize this?
        # Fill batched_keypoints with keypoints based on batch_idx
        for i in range(batch_size):
            keypoints_i = keypoints[batch_idx == i]
            batched_keypoints[i, : keypoints_i.shape[0]] = keypoints_i

        # Expand dimensions of target_gt_idx to match the shape of batched_keypoints
        target_gt_idx_expanded = target_gt_idx.unsqueeze(-1).unsqueeze(-1)

        # Use target_gt_idx_expanded to select keypoints from batched_keypoints
        selected_keypoints = batched_keypoints.gather(
            1, target_gt_idx_expanded.expand(-1, -1, keypoints.shape[1], keypoints.shape[2])
        )

        # Divide coordinates by stride
        selected_keypoints /= stride_tensor.view(1, -1, 1, 1)

        kpts_loss = 0
        kpts_obj_loss = 0

        if masks.any():
            gt_kpt = selected_keypoints[masks]
            area = xyxy2xywh(target_bboxes[masks])[:, 2:].prod(1, keepdim=True)
            pred_kpt = pred_kpts[masks]
            kpt_mask = gt_kpt[..., 2] != 0 if gt_kpt.shape[-1] == 3 else torch.full_like(gt_kpt[..., 0], True)
            kpts_loss = self.keypoint_loss(pred_kpt, gt_kpt, kpt_mask, area)  # pose loss

            if pred_kpt.shape[-1] == 3:
                kpts_obj_loss = self.bce_pose(pred_kpt[..., 2], kpt_mask.float())  # keypoint obj loss

        return kpts_loss, kpts_obj_loss
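The refactor replaces the per-image Python loop with a single `gather` over zero-padded, per-image keypoint tensors. A sketch of that indexing trick with hypothetical sizes:

```python
# Illustrative sketch (not in the diff): how gather() maps per-anchor GT indices
# to per-anchor keypoints without looping over anchors.
import torch

bs, n_anchors, n_kpts, kpt_dim = 2, 4, 17, 3               # hypothetical sizes
batched_keypoints = torch.randn(bs, 5, n_kpts, kpt_dim)    # up to 5 objects per image
target_gt_idx = torch.randint(0, 5, (bs, n_anchors))       # GT object index per anchor

idx = target_gt_idx.unsqueeze(-1).unsqueeze(-1)            # (bs, n_anchors, 1, 1)
selected = batched_keypoints.gather(1, idx.expand(-1, -1, n_kpts, kpt_dim))
assert selected.shape == (bs, n_anchors, n_kpts, kpt_dim)  # one keypoint set per anchor
```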
class v8ClassificationLoss:
    """Criterion class for computing training losses."""

    def __call__(self, preds, batch):
        """Compute the classification loss between predictions and true labels."""
        loss = torch.nn.functional.cross_entropy(preds, batch['cls'], reduction='sum') / 64
        loss = torch.nn.functional.cross_entropy(preds, batch["cls"], reduction="mean")
        loss_items = loss.detach()
        return loss, loss_items
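The change drops a hard-coded `sum / 64` normalization in favor of `reduction="mean"`, so the loss no longer assumes a nominal batch size of 64. A small sketch of the difference, with hypothetical shapes:

```python
# Illustrative sketch (not in the diff): the two normalizations agree only when
# the batch size happens to be 64.
import torch
import torch.nn.functional as F

preds = torch.randn(16, 80)            # hypothetical logits: batch of 16, 80 classes
cls = torch.randint(0, 80, (16,))      # ground-truth class indices

old = F.cross_entropy(preds, cls, reduction="sum") / 64   # pre-change behavior
new = F.cross_entropy(preds, cls, reduction="mean")       # post-change behavior
assert torch.allclose(old, new * 16 / 64)  # old form under-weighted small batches
```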
class v8OBBLoss(v8DetectionLoss):
    def __init__(self, model):
        """
        Initializes v8OBBLoss with model, assigner, and rotated bbox loss.

        Note model must be de-paralleled.
        """
        super().__init__(model)
        self.assigner = RotatedTaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
        self.bbox_loss = RotatedBboxLoss(self.reg_max - 1, use_dfl=self.use_dfl).to(self.device)

    def preprocess(self, targets, batch_size, scale_tensor):
        """Preprocesses the target counts and matches with the input batch size to output a tensor."""
        if targets.shape[0] == 0:
            out = torch.zeros(batch_size, 0, 6, device=self.device)
        else:
            i = targets[:, 0]  # image index
            _, counts = i.unique(return_counts=True)
            counts = counts.to(dtype=torch.int32)
            out = torch.zeros(batch_size, counts.max(), 6, device=self.device)
            for j in range(batch_size):
                matches = i == j
                n = matches.sum()
                if n:
                    bboxes = targets[matches, 2:]
                    bboxes[..., :4].mul_(scale_tensor)
                    out[j, :n] = torch.cat([targets[matches, 1:2], bboxes], dim=-1)
        return out

    def __call__(self, preds, batch):
        """Calculate and return the loss for the YOLO model."""
        loss = torch.zeros(3, device=self.device)  # box, cls, dfl
        feats, pred_angle = preds if isinstance(preds[0], list) else preds[1]
        batch_size = pred_angle.shape[0]  # batch size, number of masks, mask height, mask width
        pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
            (self.reg_max * 4, self.nc), 1
        )

        # b, grids, ..
        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
        pred_distri = pred_distri.permute(0, 2, 1).contiguous()
        pred_angle = pred_angle.permute(0, 2, 1).contiguous()

        dtype = pred_scores.dtype
        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)

        # targets
        try:
            batch_idx = batch["batch_idx"].view(-1, 1)
            targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"].view(-1, 5)), 1)
            rw, rh = targets[:, 4] * imgsz[0].item(), targets[:, 5] * imgsz[1].item()
            targets = targets[(rw >= 2) & (rh >= 2)]  # filter rboxes of tiny size to stabilize training
            targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
            gt_labels, gt_bboxes = targets.split((1, 5), 2)  # cls, xywhr
            mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
        except RuntimeError as e:
            raise TypeError(
                "ERROR ❌ OBB dataset incorrectly formatted or not a OBB dataset.\n"
                "This error can occur when incorrectly training a 'OBB' model on a 'detect' dataset, "
                "i.e. 'yolo train model=yolov8n-obb.pt data=dota8.yaml'.\nVerify your dataset is a "
                "correctly formatted 'OBB' dataset using 'data=dota8.yaml' "
                "as an example.\nSee https://docs.ultralytics.com/datasets/obb/ for help."
            ) from e

        # Pboxes
        pred_bboxes = self.bbox_decode(anchor_points, pred_distri, pred_angle)  # xyxy, (b, h*w, 4)

        bboxes_for_assigner = pred_bboxes.clone().detach()
        # Only the first four elements need to be scaled
        bboxes_for_assigner[..., :4] *= stride_tensor
        _, target_bboxes, target_scores, fg_mask, _ = self.assigner(
            pred_scores.detach().sigmoid(),
            bboxes_for_assigner.type(gt_bboxes.dtype),
            anchor_points * stride_tensor,
            gt_labels,
            gt_bboxes,
            mask_gt,
        )

        target_scores_sum = max(target_scores.sum(), 1)

        # Cls loss
        # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
        loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE

        # Bbox loss
        if fg_mask.sum():
            target_bboxes[..., :4] /= stride_tensor
            loss[0], loss[2] = self.bbox_loss(
                pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
            )
        else:
            loss[0] += (pred_angle * 0).sum()

        loss[0] *= self.hyp.box  # box gain
        loss[1] *= self.hyp.cls  # cls gain
        loss[2] *= self.hyp.dfl  # dfl gain

        return loss.sum() * batch_size, loss.detach()  # loss(box, cls, dfl)
    def bbox_decode(self, anchor_points, pred_dist, pred_angle):
        """
        Decode predicted object bounding box coordinates from anchor points and distribution.

        Args:
            anchor_points (torch.Tensor): Anchor points, (h*w, 2).
            pred_dist (torch.Tensor): Predicted rotated distance, (bs, h*w, 4).
            pred_angle (torch.Tensor): Predicted angle, (bs, h*w, 1).

        Returns:
            (torch.Tensor): Predicted rotated bounding boxes with angles, (bs, h*w, 5).
        """
        if self.use_dfl:
            b, a, c = pred_dist.shape  # batch, anchors, channels
            pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
        return torch.cat((dist2rbox(pred_dist, pred_angle, anchor_points), pred_angle), dim=-1)
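The DFL branch predicts a discrete distribution over `reg_max` bins per box side and decodes it as an expectation under softmax. A sketch of just that step, with hypothetical sizes:

```python
# Illustrative sketch (not in the diff): the DFL expectation decode used above,
# assuming reg_max=16 bins per box side.
import torch

b, a, reg_max = 2, 8400, 16
pred_dist = torch.randn(b, a, 4 * reg_max)        # raw logits, 4 sides per anchor
proj = torch.arange(reg_max, dtype=torch.float)   # bin values 0..15

# Softmax over bins, then expectation -> one continuous distance per side
dist = pred_dist.view(b, a, 4, reg_max).softmax(3).matmul(proj)  # -> (b, a, 4)
```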
class v10DetectLoss:
    def __init__(self, model):
        self.one2many = v8DetectionLoss(model, tal_topk=10)
        self.one2one = v8DetectionLoss(model, tal_topk=1)

    def __call__(self, preds, batch):
        one2many = preds["one2many"]
        loss_one2many = self.one2many(one2many, batch)
        one2one = preds["one2one"]
        loss_one2one = self.one2one(one2one, batch)
        return loss_one2many[0] + loss_one2one[0], torch.cat((loss_one2many[1], loss_one2one[1]))
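`v10DetectLoss` is the new YOLOv10 criterion: it supervises a one-to-many head (TAL top-10 assignments, rich training signal) and a one-to-one head (top-1 assignment, enabling NMS-free inference) in parallel and simply sums the two losses. A hedged sketch of how a trainer might consume it; the prediction dict keys are the ones the loss expects, while `model`, `batch`, and the surrounding step are hypothetical stand-ins:

```python
# Hedged sketch (not in the diff): wiring v10DetectLoss into a training step.
criterion = v10DetectLoss(model)  # builds one2many (topk=10) and one2one (topk=1) losses

preds = model(batch["img"])  # expected to be {"one2many": ..., "one2one": ...}
total_loss, loss_items = criterion(preds, batch)  # summed scalar + concatenated items

total_loss.backward()  # gradients flow to both heads; only one2one is used at inference
```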
File diff suppressed because it is too large
@ -12,6 +12,7 @@ import torch.nn.functional as F
import torchvision

from ultralytics.utils import LOGGER
from ultralytics.utils.metrics import batch_probiou


class Profile(contextlib.ContextDecorator):
@ -22,22 +23,24 @@ class Profile(contextlib.ContextDecorator):
        ```python
        from ultralytics.utils.ops import Profile

        with Profile() as dt:
        with Profile(device=device) as dt:
            pass  # slow operation here

        print(dt)  # prints "Elapsed time is 9.5367431640625e-07 s"
        ```
    """

    def __init__(self, t=0.0):
    def __init__(self, t=0.0, device: torch.device = None):
        """
        Initialize the Profile class.

        Args:
            t (float): Initial time. Defaults to 0.0.
            device (torch.device): Devices used for model inference. Defaults to None (cpu).
        """
        self.t = t
        self.cuda = torch.cuda.is_available()
        self.device = device
        self.cuda = bool(device and str(device).startswith("cuda"))

    def __enter__(self):
        """Start timing."""
@ -50,12 +53,13 @@ class Profile(contextlib.ContextDecorator):
        self.t += self.dt  # accumulate dt

    def __str__(self):
        return f'Elapsed time is {self.t} s'
        """Returns a human-readable string representing the accumulated elapsed time in the profiler."""
        return f"Elapsed time is {self.t} s"

    def time(self):
        """Get current time."""
        if self.cuda:
            torch.cuda.synchronize()
            torch.cuda.synchronize(self.device)
        return time.time()


@ -71,18 +75,21 @@ def segment2box(segment, width=640, height=640):
    Returns:
        (np.ndarray): the minimum and maximum x and y values of the segment.
    """
    # Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
    x, y = segment.T  # segment xy
    inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
    x, y, = x[inside], y[inside]
    return np.array([x.min(), y.min(), x.max(), y.max()], dtype=segment.dtype) if any(x) else np.zeros(
        4, dtype=segment.dtype)  # xyxy
    x = x[inside]
    y = y[inside]
    return (
        np.array([x.min(), y.min(), x.max(), y.max()], dtype=segment.dtype)
        if any(x)
        else np.zeros(4, dtype=segment.dtype)
    )  # xyxy
def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding=True):
def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding=True, xywh=False):
    """
    Rescales bounding boxes (in the format of xyxy) from the shape of the image they were originally specified in
    (img1_shape) to the shape of a different image (img0_shape).
    Rescales bounding boxes (in the format of xyxy by default) from the shape of the image they were originally
    specified in (img1_shape) to the shape of a different image (img0_shape).

    Args:
        img1_shape (tuple): The shape of the image that the bounding boxes are for, in the format of (height, width).
@ -92,24 +99,29 @@ def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding=True):
            calculated based on the size difference between the two images.
        padding (bool): If True, assuming the boxes is based on image augmented by yolo style. If False then do regular
            rescaling.
        xywh (bool): The box format is xywh or not, default=False.

    Returns:
        boxes (torch.Tensor): The scaled bounding boxes, in the format of (x1, y1, x2, y2)
    """
    if ratio_pad is None:  # calculate from img0_shape
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
        pad = round((img1_shape[1] - img0_shape[1] * gain) / 2 - 0.1), round(
            (img1_shape[0] - img0_shape[0] * gain) / 2 - 0.1)  # wh padding
        pad = (
            round((img1_shape[1] - img0_shape[1] * gain) / 2 - 0.1),
            round((img1_shape[0] - img0_shape[0] * gain) / 2 - 0.1),
        )  # wh padding
    else:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]

    if padding:
        boxes[..., [0, 2]] -= pad[0]  # x padding
        boxes[..., [1, 3]] -= pad[1]  # y padding
        boxes[..., 0] -= pad[0]  # x padding
        boxes[..., 1] -= pad[1]  # y padding
        if not xywh:
            boxes[..., 2] -= pad[0]  # x padding
            boxes[..., 3] -= pad[1]  # y padding
    boxes[..., :4] /= gain
    clip_boxes(boxes, img0_shape)
    return boxes
    return clip_boxes(boxes, img0_shape)
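A sketch of the typical call, mapping letterboxed predictions back to a hypothetical source frame; the values are illustrative only:

```python
# Illustrative sketch (not in the diff): mapping letterboxed 640x640 predictions
# back to a hypothetical 480x640 source frame.
import torch

boxes = torch.tensor([[100.0, 120.0, 300.0, 360.0]])  # xyxy in 640x640 letterbox space
restored = scale_boxes((640, 640), boxes.clone(), (480, 640))  # undo pad, apply gain, clip

# With xywh=True only the center point is shifted by the padding, since the
# width/height components are translation-invariant.
```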
def make_divisible(x, divisor):
@ -128,19 +140,41 @@ def make_divisible(x, divisor):
    return math.ceil(x / divisor) * divisor


def nms_rotated(boxes, scores, threshold=0.45):
    """
    NMS for obbs, powered by probiou and fast-nms.

    Args:
        boxes (torch.Tensor): (N, 5), xywhr.
        scores (torch.Tensor): (N, ).
        threshold (float): IoU threshold.

    Returns:
        (torch.Tensor): Indices of the boxes to keep after NMS.
    """
    if len(boxes) == 0:
        return np.empty((0,), dtype=np.int8)
    sorted_idx = torch.argsort(scores, descending=True)
    boxes = boxes[sorted_idx]
    ious = batch_probiou(boxes, boxes).triu_(diagonal=1)
    pick = torch.nonzero(ious.max(dim=0)[0] < threshold).squeeze_(-1)
    return sorted_idx[pick]
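A short sketch of the call; the boxes and scores are hypothetical, and the near-duplicate box is expected to be suppressed because its probabilistic IoU with the higher-scoring box exceeds the threshold:

```python
# Illustrative sketch (not in the diff): suppressing overlapping rotated boxes.
# Boxes are (cx, cy, w, h, rotation in radians).
import torch

boxes = torch.tensor([
    [50.0, 50.0, 20.0, 10.0, 0.3],
    [51.0, 50.0, 20.0, 10.0, 0.3],   # near-duplicate of the first box
    [120.0, 80.0, 30.0, 15.0, 1.2],
])
scores = torch.tensor([0.9, 0.8, 0.7])

keep = nms_rotated(boxes, scores, threshold=0.45)  # expected to drop the duplicate
```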
def non_max_suppression(
        prediction,
        conf_thres=0.25,
        iou_thres=0.45,
        classes=None,
        agnostic=False,
        multi_label=False,
        labels=(),
        max_det=300,
        nc=0,  # number of classes (optional)
        max_time_img=0.05,
        max_nms=30000,
        max_wh=7680,
    prediction,
    conf_thres=0.25,
    iou_thres=0.45,
    classes=None,
    agnostic=False,
    multi_label=False,
    labels=(),
    max_det=300,
    nc=0,  # number of classes (optional)
    max_time_img=0.05,
    max_nms=30000,
    max_wh=7680,
    in_place=True,
    rotated=False,
):
    """
    Perform non-maximum suppression (NMS) on a set of boxes, with support for masks and multiple labels per box.
@ -164,7 +198,8 @@ def non_max_suppression(
        nc (int, optional): The number of classes output by the model. Any indices after this will be considered masks.
        max_time_img (float): The maximum time (seconds) for processing one image.
        max_nms (int): The maximum number of boxes into torchvision.ops.nms().
        max_wh (int): The maximum box width and height in pixels
        max_wh (int): The maximum box width and height in pixels.
        in_place (bool): If True, the input prediction tensor will be modified in place.

    Returns:
        (List[torch.Tensor]): A list of length batch_size, where each element is a tensor of
@ -173,15 +208,11 @@
    """

    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= conf_thres <= 1, f"Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0"
    assert 0 <= iou_thres <= 1, f"Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0"
    if isinstance(prediction, (list, tuple)):  # YOLOv8 model in validation model, output = (inference_out, loss_out)
        prediction = prediction[0]  # select only inference output

    device = prediction.device
    mps = 'mps' in device.type  # Apple MPS
    if mps:  # MPS not fully supported yet, convert tensors to CPU before NMS
        prediction = prediction.cpu()
    bs = prediction.shape[0]  # batch size
    nc = nc or (prediction.shape[1] - 4)  # number of classes
    nm = prediction.shape[1] - nc - 4
@ -190,11 +221,15 @@ def non_max_suppression(

    # Settings
    # min_wh = 2  # (pixels) minimum box width and height
    time_limit = 0.5 + max_time_img * bs  # seconds to quit after
    time_limit = 2.0 + max_time_img * bs  # seconds to quit after
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)

    prediction = prediction.transpose(-1, -2)  # shape(1,84,6300) to shape(1,6300,84)
    prediction[..., :4] = xywh2xyxy(prediction[..., :4])  # xywh to xyxy
    if not rotated:
        if in_place:
            prediction[..., :4] = xywh2xyxy(prediction[..., :4])  # xywh to xyxy
        else:
            prediction = torch.cat((xywh2xyxy(prediction[..., :4]), prediction[..., 4:]), dim=-1)  # xywh to xyxy

    t = time.time()
    output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
@ -204,7 +239,7 @@ def non_max_suppression(
        x = x[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
        if labels and len(labels[xi]) and not rotated:
            lb = labels[xi]
            v = torch.zeros((len(lb), nc + nm + 4), device=x.device)
            v[:, :4] = xywh2xyxy(lb[:, 1:5])  # box
@ -238,8 +273,13 @@ def non_max_suppression(

        # Batched NMS
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        scores = x[:, 4]  # scores
        if rotated:
            boxes = torch.cat((x[:, :2] + c, x[:, 2:4], x[:, -1:]), dim=-1)  # xywhr
            i = nms_rotated(boxes, scores, iou_thres)
        else:
            boxes = x[:, :4] + c  # boxes (offset by class)
            i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        i = i[:max_det]  # limit detections

        # # Experimental
@ -247,7 +287,7 @@ def non_max_suppression(
        # if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
        #     # Update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
        #     from .metrics import box_iou
        #     iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
        #     iou = box_iou(boxes[i], boxes) > iou_thres  # IoU matrix
        #     weights = iou * scores[None]  # box weights
        #     x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
        #     redundant = True  # require redundant detections
@ -255,10 +295,8 @@ def non_max_suppression(
        #     i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if mps:
            output[xi] = output[xi].to(device)
        if (time.time() - t) > time_limit:
            LOGGER.warning(f'WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded')
            LOGGER.warning(f"WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded")
            break  # time limit exceeded

    return output
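A sketch of the common call on raw head output; shapes are hypothetical (84 = 4 box coordinates + 80 class scores):

```python
# Illustrative sketch (not in the diff): running NMS on raw model output.
import torch

prediction = torch.randn(1, 84, 6300)  # (batch, 4 + nc, anchors), as from a YOLOv8 head
out = non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, in_place=False)
# out[0] is (n, 6): x1, y1, x2, y2, confidence, class

# For OBB heads the last column is the angle and rotated=True routes to nms_rotated.
```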
@ -269,17 +307,21 @@ def clip_boxes(boxes, shape):
    Takes a list of bounding boxes and a shape (height, width) and clips the bounding boxes to the shape.

    Args:
        boxes (torch.Tensor): the bounding boxes to clip
        shape (tuple): the shape of the image
        boxes (torch.Tensor): the bounding boxes to clip
        shape (tuple): the shape of the image

    Returns:
        (torch.Tensor | numpy.ndarray): Clipped boxes
    """
    if isinstance(boxes, torch.Tensor):  # faster individually
        boxes[..., 0].clamp_(0, shape[1])  # x1
        boxes[..., 1].clamp_(0, shape[0])  # y1
        boxes[..., 2].clamp_(0, shape[1])  # x2
        boxes[..., 3].clamp_(0, shape[0])  # y2
    if isinstance(boxes, torch.Tensor):  # faster individually (WARNING: inplace .clamp_() Apple MPS bug)
        boxes[..., 0] = boxes[..., 0].clamp(0, shape[1])  # x1
        boxes[..., 1] = boxes[..., 1].clamp(0, shape[0])  # y1
        boxes[..., 2] = boxes[..., 2].clamp(0, shape[1])  # x2
        boxes[..., 3] = boxes[..., 3].clamp(0, shape[0])  # y2
    else:  # np.array (faster grouped)
        boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, shape[1])  # x1, x2
        boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, shape[0])  # y1, y2
    return boxes


def clip_coords(coords, shape):
@ -291,19 +333,20 @@ def clip_coords(coords, shape):
        shape (tuple): A tuple of integers representing the size of the image in the format (height, width).

    Returns:
        (None): The function modifies the input `coordinates` in place, by clipping each coordinate to the image boundaries.
        (torch.Tensor | numpy.ndarray): Clipped coordinates
    """
    if isinstance(coords, torch.Tensor):  # faster individually
        coords[..., 0].clamp_(0, shape[1])  # x
        coords[..., 1].clamp_(0, shape[0])  # y
    if isinstance(coords, torch.Tensor):  # faster individually (WARNING: inplace .clamp_() Apple MPS bug)
        coords[..., 0] = coords[..., 0].clamp(0, shape[1])  # x
        coords[..., 1] = coords[..., 1].clamp(0, shape[0])  # y
    else:  # np.array (faster grouped)
        coords[..., 0] = coords[..., 0].clip(0, shape[1])  # x
        coords[..., 1] = coords[..., 1].clip(0, shape[0])  # y
    return coords
def scale_image(masks, im0_shape, ratio_pad=None):
    """
    Takes a mask, and resizes it to the original image size
    Takes a mask, and resizes it to the original image size.

    Args:
        masks (np.ndarray): resized and padded masks/images, [h, w, num]/[h, w, 3].
@ -321,7 +364,7 @@ def scale_image(masks, im0_shape, ratio_pad=None):
        gain = min(im1_shape[0] / im0_shape[0], im1_shape[1] / im0_shape[1])  # gain = old / new
        pad = (im1_shape[1] - im0_shape[1] * gain) / 2, (im1_shape[0] - im0_shape[0] * gain) / 2  # wh padding
    else:
        gain = ratio_pad[0][0]
        # gain = ratio_pad[0][0]
        pad = ratio_pad[1]
    top, left = int(pad[1]), int(pad[0])  # y, x
    bottom, right = int(im1_shape[0] - pad[1]), int(im1_shape[1] - pad[0])
@ -347,7 +390,7 @@ def xyxy2xywh(x):
    Returns:
        y (np.ndarray | torch.Tensor): The bounding box coordinates in (x, y, width, height) format.
    """
    assert x.shape[-1] == 4, f'input shape last dimension expected 4 but input shape is {x.shape}'
    assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
    y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x)  # faster than clone/copy
    y[..., 0] = (x[..., 0] + x[..., 2]) / 2  # x center
    y[..., 1] = (x[..., 1] + x[..., 3]) / 2  # y center
@ -367,7 +410,7 @@ def xywh2xyxy(x):
    Returns:
        y (np.ndarray | torch.Tensor): The bounding box coordinates in (x1, y1, x2, y2) format.
    """
    assert x.shape[-1] == 4, f'input shape last dimension expected 4 but input shape is {x.shape}'
    assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
    y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x)  # faster than clone/copy
    dw = x[..., 2] / 2  # half-width
    dh = x[..., 3] / 2  # half-height
@ -392,7 +435,7 @@ def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
        y (np.ndarray | torch.Tensor): The coordinates of the bounding box in the format [x1, y1, x2, y2] where
            x1,y1 is the top-left corner, x2,y2 is the bottom-right corner of the bounding box.
    """
    assert x.shape[-1] == 4, f'input shape last dimension expected 4 but input shape is {x.shape}'
    assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
    y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x)  # faster than clone/copy
    y[..., 0] = w * (x[..., 0] - x[..., 2] / 2) + padw  # top left x
    y[..., 1] = h * (x[..., 1] - x[..., 3] / 2) + padh  # top left y
@ -403,8 +446,8 @@ def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):

def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
    """
    Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height, normalized) format.
    x, y, width and height are normalized to image dimensions
    Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height, normalized) format. x, y,
    width and height are normalized to image dimensions.

    Args:
        x (np.ndarray | torch.Tensor): The input bounding box coordinates in (x1, y1, x2, y2) format.
@ -417,8 +460,8 @@ def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
        y (np.ndarray | torch.Tensor): The bounding box coordinates in (x, y, width, height, normalized) format
    """
    if clip:
        clip_boxes(x, (h - eps, w - eps))  # warning: inplace clip
        x = clip_boxes(x, (h - eps, w - eps))
    assert x.shape[-1] == 4, f'input shape last dimension expected 4 but input shape is {x.shape}'
    assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
    y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x)  # faster than clone/copy
    y[..., 0] = ((x[..., 0] + x[..., 2]) / 2) / w  # x center
    y[..., 1] = ((x[..., 1] + x[..., 3]) / 2) / h  # y center
@ -445,7 +488,7 @@ def xywh2ltwh(x):

def xyxy2ltwh(x):
    """
    Convert nx4 bounding boxes from [x1, y1, x2, y2] to [x1, y1, w, h], where xy1=top-left, xy2=bottom-right
    Convert nx4 bounding boxes from [x1, y1, x2, y2] to [x1, y1, w, h], where xy1=top-left, xy2=bottom-right.

    Args:
        x (np.ndarray | torch.Tensor): The input tensor with the bounding boxes coordinates in the xyxy format
@ -461,7 +504,7 @@ def xyxy2ltwh(x):

def ltwh2xywh(x):
    """
    Convert nx4 boxes from [x1, y1, w, h] to [x, y, w, h] where xy1=top-left, xy=center
    Convert nx4 boxes from [x1, y1, w, h] to [x, y, w, h] where xy1=top-left, xy=center.

    Args:
        x (torch.Tensor): the input tensor
@ -477,7 +520,8 @@ def ltwh2xywh(x):

def xyxyxyxy2xywhr(corners):
    """
    Convert batched Oriented Bounding Boxes (OBB) from [xy1, xy2, xy3, xy4] to [xywh, rotation].
    Convert batched Oriented Bounding Boxes (OBB) from [xy1, xy2, xy3, xy4] to [xywh, rotation]. Rotation values are
    expected in degrees from 0 to 90.

    Args:
        corners (numpy.ndarray | torch.Tensor): Input corners of shape (n, 8).
@ -485,66 +529,53 @@ def xyxyxyxy2xywhr(corners):
    Returns:
        (numpy.ndarray | torch.Tensor): Converted data in [cx, cy, w, h, rotation] format of shape (n, 5).
    """
    is_numpy = isinstance(corners, np.ndarray)
    atan2, sqrt = (np.arctan2, np.sqrt) if is_numpy else (torch.atan2, torch.sqrt)

    x1, y1, x2, y2, x3, y3, x4, y4 = corners.T
    cx = (x1 + x3) / 2
    cy = (y1 + y3) / 2
    dx21 = x2 - x1
    dy21 = y2 - y1

    w = sqrt(dx21 ** 2 + dy21 ** 2)
    h = sqrt((x2 - x3) ** 2 + (y2 - y3) ** 2)

    rotation = atan2(-dy21, dx21)
    rotation *= 180.0 / math.pi  # radians to degrees

    return np.vstack((cx, cy, w, h, rotation)).T if is_numpy else torch.stack((cx, cy, w, h, rotation), dim=1)
    is_torch = isinstance(corners, torch.Tensor)
    points = corners.cpu().numpy() if is_torch else corners
    points = points.reshape(len(corners), -1, 2)
    rboxes = []
    for pts in points:
        # NOTE: Use cv2.minAreaRect to get accurate xywhr,
        # especially some objects are cut off by augmentations in dataloader.
        (x, y), (w, h), angle = cv2.minAreaRect(pts)
        rboxes.append([x, y, w, h, angle / 180 * np.pi])
    return (
        torch.tensor(rboxes, device=corners.device, dtype=corners.dtype)
        if is_torch
        else np.asarray(rboxes, dtype=points.dtype)
    )  # rboxes
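The rewrite replaces the analytic corner math with `cv2.minAreaRect`, which stays accurate when augmentation clips a polygon. A small sketch with hypothetical corner points:

```python
# Illustrative sketch (not in the diff): fitting the tightest rotated rect.
import cv2
import numpy as np

pts = np.array([[10, 10], [50, 12], [48, 30], [8, 28]], dtype=np.float32)  # hypothetical corners
(cx, cy), (w, h), angle_deg = cv2.minAreaRect(pts)   # OpenCV returns the angle in degrees
rbox = [cx, cy, w, h, angle_deg / 180 * np.pi]       # stored with the angle in radians
```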
def xywhr2xyxyxyxy(center):
def xywhr2xyxyxyxy(rboxes):
    """
    Convert batched Oriented Bounding Boxes (OBB) from [xywh, rotation] to [xy1, xy2, xy3, xy4].
    Convert batched Oriented Bounding Boxes (OBB) from [xywh, rotation] to [xy1, xy2, xy3, xy4]. Rotation values should
    be in degrees from 0 to 90.

    Args:
        center (numpy.ndarray | torch.Tensor): Input data in [cx, cy, w, h, rotation] format of shape (n, 5).
        rboxes (numpy.ndarray | torch.Tensor): Boxes in [cx, cy, w, h, rotation] format of shape (n, 5) or (b, n, 5).

    Returns:
        (numpy.ndarray | torch.Tensor): Converted corner points of shape (n, 8).
        (numpy.ndarray | torch.Tensor): Converted corner points of shape (n, 4, 2) or (b, n, 4, 2).
    """
    is_numpy = isinstance(center, np.ndarray)
    is_numpy = isinstance(rboxes, np.ndarray)
    cos, sin = (np.cos, np.sin) if is_numpy else (torch.cos, torch.sin)

    cx, cy, w, h, rotation = center.T
    rotation *= math.pi / 180.0  # degrees to radians

    dx = w / 2
    dy = h / 2

    cos_rot = cos(rotation)
    sin_rot = sin(rotation)
    dx_cos_rot = dx * cos_rot
    dx_sin_rot = dx * sin_rot
    dy_cos_rot = dy * cos_rot
    dy_sin_rot = dy * sin_rot

    x1 = cx - dx_cos_rot - dy_sin_rot
    y1 = cy + dx_sin_rot - dy_cos_rot
    x2 = cx + dx_cos_rot - dy_sin_rot
    y2 = cy - dx_sin_rot - dy_cos_rot
    x3 = cx + dx_cos_rot + dy_sin_rot
    y3 = cy - dx_sin_rot + dy_cos_rot
    x4 = cx - dx_cos_rot + dy_sin_rot
    y4 = cy + dx_sin_rot + dy_cos_rot

    return np.vstack((x1, y1, x2, y2, x3, y3, x4, y4)).T if is_numpy else torch.stack(
        (x1, y1, x2, y2, x3, y3, x4, y4), dim=1)
    ctr = rboxes[..., :2]
    w, h, angle = (rboxes[..., i : i + 1] for i in range(2, 5))
    cos_value, sin_value = cos(angle), sin(angle)
    vec1 = [w / 2 * cos_value, w / 2 * sin_value]
    vec2 = [-h / 2 * sin_value, h / 2 * cos_value]
    vec1 = np.concatenate(vec1, axis=-1) if is_numpy else torch.cat(vec1, dim=-1)
    vec2 = np.concatenate(vec2, axis=-1) if is_numpy else torch.cat(vec2, dim=-1)
    pt1 = ctr + vec1 + vec2
    pt2 = ctr + vec1 - vec2
    pt3 = ctr - vec1 - vec2
    pt4 = ctr - vec1 + vec2
    return np.stack([pt1, pt2, pt3, pt4], axis=-2) if is_numpy else torch.stack([pt1, pt2, pt3, pt4], dim=-2)
def ltwh2xyxy(x):
    """
    It converts the bounding box from [x1, y1, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
    It converts the bounding box from [x1, y1, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right.

    Args:
        x (np.ndarray | torch.Tensor): the input image
@ -590,8 +621,9 @@ def resample_segments(segments, n=1000):
        s = np.concatenate((s, s[0:1, :]), axis=0)
        x = np.linspace(0, len(s) - 1, n)
        xp = np.arange(len(s))
        segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)],
                                     dtype=np.float32).reshape(2, -1).T  # segment xy
        segments[i] = (
            np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)], dtype=np.float32).reshape(2, -1).T
        )  # segment xy
    return segments


@ -606,7 +638,7 @@ def crop_mask(masks, boxes):
    Returns:
        (torch.Tensor): The masks are being cropped to the bounding box.
    """
    n, h, w = masks.shape
    _, h, w = masks.shape
    x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1)  # x1 shape(n,1,1)
    r = torch.arange(w, device=masks.device, dtype=x1.dtype)[None, None, :]  # rows shape(1,1,w)
    c = torch.arange(h, device=masks.device, dtype=x1.dtype)[None, :, None]  # cols shape(1,h,1)
@ -616,8 +648,8 @@ def crop_mask(masks, boxes):

def process_mask_upsample(protos, masks_in, bboxes, shape):
    """
    Takes the output of the mask head, and applies the mask to the bounding boxes. This produces masks of higher
    quality but is slower.
    Takes the output of the mask head, and applies the mask to the bounding boxes. This produces masks of higher quality
    but is slower.

    Args:
        protos (torch.Tensor): [mask_dim, mask_h, mask_w]
@ -630,7 +662,7 @@ def process_mask_upsample(protos, masks_in, bboxes, shape):
    """
    c, mh, mw = protos.shape  # CHW
    masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw)
    masks = F.interpolate(masks[None], shape, mode='bilinear', align_corners=False)[0]  # CHW
    masks = F.interpolate(masks[None], shape, mode="bilinear", align_corners=False)[0]  # CHW
    masks = crop_mask(masks, bboxes)  # CHW
    return masks.gt_(0.5)

@ -654,16 +686,18 @@ def process_mask(protos, masks_in, bboxes, shape, upsample=False):
    c, mh, mw = protos.shape  # CHW
    ih, iw = shape
    masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw)  # CHW
    width_ratio = mw / iw
    height_ratio = mh / ih

    downsampled_bboxes = bboxes.clone()
    downsampled_bboxes[:, 0] *= mw / iw
    downsampled_bboxes[:, 2] *= mw / iw
    downsampled_bboxes[:, 3] *= mh / ih
    downsampled_bboxes[:, 1] *= mh / ih
    downsampled_bboxes[:, 0] *= width_ratio
    downsampled_bboxes[:, 2] *= width_ratio
    downsampled_bboxes[:, 3] *= height_ratio
    downsampled_bboxes[:, 1] *= height_ratio

    masks = crop_mask(masks, downsampled_bboxes)  # CHW
    if upsample:
        masks = F.interpolate(masks[None], shape, mode='bilinear', align_corners=False)[0]  # CHW
        masks = F.interpolate(masks[None], shape, mode="bilinear", align_corners=False)[0]  # CHW
    return masks.gt_(0.5)
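`crop_mask` zeroes every mask pixel outside its box by comparing coordinate grids against the box edges with broadcasting. A small self-contained sketch of that trick, with hypothetical values:

```python
# Illustrative sketch (not in the diff): the broadcast trick inside crop_mask.
import torch

masks = torch.rand(2, 4, 4)                          # hypothetical (n, h, w) masks
boxes = torch.tensor([[0.0, 0.0, 2.0, 2.0],
                      [1.0, 1.0, 4.0, 3.0]])         # xyxy in mask pixels

x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1)  # each (n, 1, 1)
r = torch.arange(4, dtype=x1.dtype)[None, None, :]     # column indices (1, 1, w)
c = torch.arange(4, dtype=x1.dtype)[None, :, None]     # row indices (1, h, 1)
cropped = masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2))  # zero outside boxes
```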
@ -707,13 +741,13 @@ def scale_masks(masks, shape, padding=True):
        bottom, right = (int(mh - pad[1]), int(mw - pad[0]))
        masks = masks[..., top:bottom, left:right]

    masks = F.interpolate(masks, shape, mode='bilinear', align_corners=False)  # NCHW
    masks = F.interpolate(masks, shape, mode="bilinear", align_corners=False)  # NCHW
    return masks


def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize=False, padding=True):
    """
    Rescale segment coordinates (xy) from img1_shape to img0_shape
    Rescale segment coordinates (xy) from img1_shape to img0_shape.

    Args:
        img1_shape (tuple): The shape of the image that the coords are from.
@ -739,14 +773,32 @@ def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize=False
        coords[..., 1] -= pad[1]  # y padding
    coords[..., 0] /= gain
    coords[..., 1] /= gain
    clip_coords(coords, img0_shape)
    coords = clip_coords(coords, img0_shape)
    if normalize:
        coords[..., 0] /= img0_shape[1]  # width
        coords[..., 1] /= img0_shape[0]  # height
    return coords


def masks2segments(masks, strategy='largest'):
def regularize_rboxes(rboxes):
    """
    Regularize rotated boxes in range [0, pi/2].

    Args:
        rboxes (torch.Tensor): (N, 5), xywhr.

    Returns:
        (torch.Tensor): The regularized boxes.
    """
    x, y, w, h, t = rboxes.unbind(dim=-1)
    # Swap edge and angle if h >= w
    w_ = torch.where(w > h, w, h)
    h_ = torch.where(w > h, h, w)
    t = torch.where(w > h, t, t + math.pi / 2) % math.pi
    return torch.stack([x, y, w_, h_, t], dim=-1)  # regularized boxes
def masks2segments(masks, strategy="largest"):
    """
    It takes a list of masks(n,h,w) and returns a list of segments(n,xy)

@ -758,16 +810,16 @@ def masks2segments(masks, strategy='largest'):
        segments (List): list of segment masks
    """
    segments = []
    for x in masks.int().cpu().numpy().astype('uint8'):
    for x in masks.int().cpu().numpy().astype("uint8"):
        c = cv2.findContours(x, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
        if c:
            if strategy == 'concat':  # concatenate all segments
            if strategy == "concat":  # concatenate all segments
                c = np.concatenate([x.reshape(-1, 2) for x in c])
            elif strategy == 'largest':  # select largest segment
            elif strategy == "largest":  # select largest segment
                c = np.array(c[np.array([len(x) for x in c]).argmax()]).reshape(-1, 2)
        else:
            c = np.zeros((0, 2))  # no segments found
        segments.append(c.astype('float32'))
        segments.append(c.astype("float32"))
    return segments


@ -794,4 +846,19 @@ def clean_str(s):
    Returns:
        (str): a string with special characters replaced by an underscore _
    """
    return re.sub(pattern='[|@#!¡·$€%&()=?¿^*;:,¨´><+]', repl='_', string=s)
    return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)


def v10postprocess(preds, max_det, nc=80):
    assert 4 + nc == preds.shape[-1]
    boxes, scores = preds.split([4, nc], dim=-1)
    max_scores = scores.amax(dim=-1)
    max_scores, index = torch.topk(max_scores, max_det, dim=-1)
    index = index.unsqueeze(-1)
    boxes = torch.gather(boxes, dim=1, index=index.repeat(1, 1, boxes.shape[-1]))
    scores = torch.gather(scores, dim=1, index=index.repeat(1, 1, scores.shape[-1]))

    scores, index = torch.topk(scores.flatten(1), max_det, dim=-1)
    labels = index % nc
    index = index // nc
    boxes = boxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, boxes.shape[-1]))
    return boxes, scores, labels
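`v10postprocess` is the NMS-free selection step for YOLOv10's one-to-one head: a first top-k over per-anchor max scores prunes anchors, then a second top-k over the flattened class scores recovers both the anchor and the class label. A sketch of the call with hypothetical shapes:

```python
# Illustrative sketch (not in the diff): NMS-free selection from one-to-one predictions.
import torch

preds = torch.rand(1, 8400, 84)  # hypothetical (batch, anchors, 4 box + 80 classes)
boxes, scores, labels = v10postprocess(preds, max_det=300, nc=80)
# boxes: (1, 300, 4); scores and labels: (1, 300). No IoU suppression is needed
# because the one-to-one head was trained with top-1 assignments.
```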
@ -1,8 +1,7 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
Monkey patches to update/extend functionality of existing functions
"""
"""Monkey patches to update/extend functionality of existing functions."""

import time
from pathlib import Path

import cv2
@ -14,7 +13,8 @@ _imshow = cv2.imshow  # copy to avoid recursion errors


def imread(filename: str, flags: int = cv2.IMREAD_COLOR):
    """Read an image from a file.
    """
    Read an image from a file.

    Args:
        filename (str): Path to the file to read.
@ -27,7 +27,8 @@ def imread(filename: str, flags: int = cv2.IMREAD_COLOR):


def imwrite(filename: str, img: np.ndarray, params=None):
    """Write an image to a file.
    """
    Write an image to a file.

    Args:
        filename (str): Path to the file to write.
@ -45,31 +46,43 @@ def imwrite(filename: str, img: np.ndarray, params=None):


def imshow(winname: str, mat: np.ndarray):
    """Displays an image in the specified window.
    """
    Displays an image in the specified window.

    Args:
        winname (str): Name of the window.
        mat (np.ndarray): Image to be shown.
    """
    _imshow(winname.encode('unicode_escape').decode(), mat)
    _imshow(winname.encode("unicode_escape").decode(), mat)


# PyTorch functions ----------------------------------------------------------------------------------------------------
_torch_save = torch.save  # copy to avoid recursion errors


def torch_save(*args, **kwargs):
    """Use dill (if exists) to serialize the lambda functions where pickle does not do this.
def torch_save(*args, use_dill=True, **kwargs):
    """
    Optionally use dill to serialize lambda functions where pickle does not, adding robustness with 3 retries and
    exponential standoff in case of save failure.

    Args:
        *args (tuple): Positional arguments to pass to torch.save.
        **kwargs (dict): Keyword arguments to pass to torch.save.
        use_dill (bool): Whether to try using dill for serialization if available. Defaults to True.
        **kwargs (any): Keyword arguments to pass to torch.save.
    """
    try:
        import dill as pickle  # noqa
    except ImportError:
        assert use_dill
        import dill as pickle
    except (AssertionError, ImportError):
        import pickle

    if 'pickle_module' not in kwargs:
        kwargs['pickle_module'] = pickle  # noqa
    return _torch_save(*args, **kwargs)
    if "pickle_module" not in kwargs:
        kwargs["pickle_module"] = pickle

    for i in range(4):  # 3 retries
        try:
            return _torch_save(*args, **kwargs)
        except RuntimeError as e:  # unable to save, possibly waiting for device to flush or antivirus scan
            if i == 3:
                raise e
            time.sleep((2**i) / 2)  # exponential standoff: 0.5s, 1.0s, 2.0s
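A sketch of the patched saver in use; the checkpoint payload is hypothetical:

```python
# Illustrative sketch (not in the diff): the patched saver transparently retries
# transient RuntimeErrors (e.g. a file temporarily locked by an antivirus scan).
ckpt = {"epoch": 3, "model": None}  # hypothetical checkpoint payload

torch_save(ckpt, "last.pt")                  # dill if available, 3 retries with backoff
torch_save(ckpt, "last.pt", use_dill=False)  # force the stdlib pickle module
```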
@ -13,7 +13,6 @@ from PIL import Image, ImageDraw, ImageFont
from PIL import __version__ as pil_version

from ultralytics.utils import LOGGER, TryExcept, ops, plt_settings, threaded

from .checks import check_font, check_version, is_ascii
from .files import increment_path

@ -28,20 +27,60 @@ class Colors:
    Attributes:
        palette (list of tuple): List of RGB color values.
        n (int): The number of colors in the palette.
        pose_palette (np.array): A specific color palette array with dtype np.uint8.
        pose_palette (np.ndarray): A specific color palette array with dtype np.uint8.
    """

    def __init__(self):
        """Initialize colors as hex = matplotlib.colors.TABLEAU_COLORS.values()."""
        hexs = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
                '2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7')
        self.palette = [self.hex2rgb(f'#{c}') for c in hexs]
        hexs = (
            "FF3838",
            "FF9D97",
            "FF701F",
            "FFB21D",
            "CFD231",
            "48F90A",
            "92CC17",
            "3DDB86",
            "1A9334",
            "00D4BB",
            "2C99A8",
            "00C2FF",
            "344593",
            "6473FF",
            "0018EC",
            "8438FF",
            "520085",
            "CB38FF",
            "FF95C8",
            "FF37C7",
        )
        self.palette = [self.hex2rgb(f"#{c}") for c in hexs]
        self.n = len(self.palette)
        self.pose_palette = np.array([[255, 128, 0], [255, 153, 51], [255, 178, 102], [230, 230, 0], [255, 153, 255],
                                      [153, 204, 255], [255, 102, 255], [255, 51, 255], [102, 178, 255], [51, 153, 255],
                                      [255, 153, 153], [255, 102, 102], [255, 51, 51], [153, 255, 153], [102, 255, 102],
                                      [51, 255, 51], [0, 255, 0], [0, 0, 255], [255, 0, 0], [255, 255, 255]],
                                     dtype=np.uint8)
        self.pose_palette = np.array(
            [
                [255, 128, 0],
                [255, 153, 51],
                [255, 178, 102],
                [230, 230, 0],
                [255, 153, 255],
                [153, 204, 255],
                [255, 102, 255],
                [255, 51, 255],
                [102, 178, 255],
                [51, 153, 255],
                [255, 153, 153],
                [255, 102, 102],
                [255, 51, 51],
                [153, 255, 153],
                [102, 255, 102],
                [51, 255, 51],
                [0, 255, 0],
                [0, 0, 255],
                [255, 0, 0],
                [255, 255, 255],
            ],
            dtype=np.uint8,
        )

    def __call__(self, i, bgr=False):
        """Converts hex color codes to RGB values."""
@ -51,7 +90,7 @@ class Colors:
    @staticmethod
    def hex2rgb(h):
        """Converts hex color codes to RGB values (i.e. default PIL order)."""
        return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4))
        return tuple(int(h[1 + i : 1 + i + 2], 16) for i in (0, 2, 4))


colors = Colors()  # create instance for 'from utils.plots import colors'
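
A quick sanity check of the shared palette instance (illustrative values): indices wrap modulo the palette length, and `bgr=True` reverses the channel order for cv2 drawing.

```python
from ultralytics.utils.plotting import colors

print(colors(2))                          # (255, 112, 31) for palette entry 'FF701F'
print(colors(2, bgr=True))                # (31, 112, 255), the same color in BGR order
print(colors(2) == colors(2 + colors.n))  # True, indices wrap around the palette
```
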
@ -71,65 +110,99 @@ class Annotator:
        kpt_color (List[int]): Color palette for keypoints.
    """

    def __init__(self, im, line_width=None, font_size=None, font='Arial.ttf', pil=False, example='abc'):
    def __init__(self, im, line_width=None, font_size=None, font="Arial.ttf", pil=False, example="abc"):
        """Initialize the Annotator class with image and line width along with color palette for keypoints and limbs."""
        assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to Annotator() input images.'
        non_ascii = not is_ascii(example)  # non-latin labels, i.e. asian, arabic, cyrillic
        self.pil = pil or non_ascii
        input_is_pil = isinstance(im, Image.Image)
        self.pil = pil or non_ascii or input_is_pil
        self.lw = line_width or max(round(sum(im.size if input_is_pil else im.shape) / 2 * 0.003), 2)
        if self.pil:  # use PIL
            self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
            self.im = im if input_is_pil else Image.fromarray(im)
            self.draw = ImageDraw.Draw(self.im)
            try:
                font = check_font('Arial.Unicode.ttf' if non_ascii else font)
                font = check_font("Arial.Unicode.ttf" if non_ascii else font)
                size = font_size or max(round(sum(self.im.size) / 2 * 0.035), 12)
                self.font = ImageFont.truetype(str(font), size)
            except Exception:
                self.font = ImageFont.load_default()
            # Deprecation fix for w, h = getsize(string) -> _, _, w, h = getbox(string)
            if check_version(pil_version, '9.2.0'):
            if check_version(pil_version, "9.2.0"):
                self.font.getsize = lambda x: self.font.getbbox(x)[2:4]  # text width, height
        else:  # use cv2
            self.im = im
            self.lw = line_width or max(round(sum(im.shape) / 2 * 0.003), 2)  # line width
            assert im.data.contiguous, "Image not contiguous. Apply np.ascontiguousarray(im) to Annotator input images."
            self.im = im if im.flags.writeable else im.copy()
        self.tf = max(self.lw - 1, 1)  # font thickness
        self.sf = self.lw / 3  # font scale
        # Pose
        self.skeleton = [[16, 14], [14, 12], [17, 15], [15, 13], [12, 13], [6, 12], [7, 13], [6, 7], [6, 8], [7, 9],
                         [8, 10], [9, 11], [2, 3], [1, 2], [1, 3], [2, 4], [3, 5], [4, 6], [5, 7]]
        self.skeleton = [
            [16, 14],
            [14, 12],
            [17, 15],
            [15, 13],
            [12, 13],
            [6, 12],
            [7, 13],
            [6, 7],
            [6, 8],
            [7, 9],
            [8, 10],
            [9, 11],
            [2, 3],
            [1, 2],
            [1, 3],
            [2, 4],
            [3, 5],
            [4, 6],
            [5, 7],
        ]

        self.limb_color = colors.pose_palette[[9, 9, 9, 9, 7, 7, 7, 0, 0, 0, 0, 0, 16, 16, 16, 16, 16, 16, 16]]
        self.kpt_color = colors.pose_palette[[16, 16, 16, 16, 16, 0, 0, 0, 0, 0, 0, 9, 9, 9, 9, 9, 9]]

    def box_label(self, box, label='', color=(128, 128, 128), txt_color=(255, 255, 255)):
    def box_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), rotated=False):
        """Add one xyxy box to image with label."""
        if isinstance(box, torch.Tensor):
            box = box.tolist()
        if self.pil or not is_ascii(label):
            self.draw.rectangle(box, width=self.lw, outline=color)  # box
            if rotated:
                p1 = box[0]
                # NOTE: PIL-version polygon needs tuple type.
                self.draw.polygon([tuple(b) for b in box], width=self.lw, outline=color)
            else:
                p1 = (box[0], box[1])
                self.draw.rectangle(box, width=self.lw, outline=color)  # box
            if label:
                w, h = self.font.getsize(label)  # text width, height
                outside = box[1] - h >= 0  # label fits outside box
                outside = p1[1] - h >= 0  # label fits outside box
                self.draw.rectangle(
                    (box[0], box[1] - h if outside else box[1], box[0] + w + 1,
                     box[1] + 1 if outside else box[1] + h + 1),
                    (p1[0], p1[1] - h if outside else p1[1], p1[0] + w + 1, p1[1] + 1 if outside else p1[1] + h + 1),
                    fill=color,
                )
                # self.draw.text((box[0], box[1]), label, fill=txt_color, font=self.font, anchor='ls')  # for PIL>8.0
                self.draw.text((box[0], box[1] - h if outside else box[1]), label, fill=txt_color, font=self.font)
                self.draw.text((p1[0], p1[1] - h if outside else p1[1]), label, fill=txt_color, font=self.font)
        else:  # cv2
            p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
            cv2.rectangle(self.im, p1, p2, color, thickness=self.lw, lineType=cv2.LINE_AA)
            if rotated:
                p1 = [int(b) for b in box[0]]
                # NOTE: cv2-version polylines needs np.asarray type.
                cv2.polylines(self.im, [np.asarray(box, dtype=int)], True, color, self.lw)
            else:
                p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
                cv2.rectangle(self.im, p1, p2, color, thickness=self.lw, lineType=cv2.LINE_AA)
            if label:
                tf = max(self.lw - 1, 1)  # font thickness
                w, h = cv2.getTextSize(label, 0, fontScale=self.lw / 3, thickness=tf)[0]  # text width, height
                w, h = cv2.getTextSize(label, 0, fontScale=self.sf, thickness=self.tf)[0]  # text width, height
                outside = p1[1] - h >= 3
                p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
                cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA)  # filled
                cv2.putText(self.im,
                            label, (p1[0], p1[1] - 2 if outside else p1[1] + h + 2),
                            0,
                            self.lw / 3,
                            txt_color,
                            thickness=tf,
                            lineType=cv2.LINE_AA)
                cv2.putText(
                    self.im,
                    label,
                    (p1[0], p1[1] - 2 if outside else p1[1] + h + 2),
                    0,
                    self.sf,
                    txt_color,
                    thickness=self.tf,
                    lineType=cv2.LINE_AA,
                )
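
Under the new `rotated=True` flag, `box` is interpreted as four corner points (xyxyxyxy) rather than an xyxy pair. A minimal sketch with assumed values:

```python
# Illustrative sketch of the rotated-box path added in this diff.
import numpy as np

from ultralytics.utils.plotting import Annotator

im = np.zeros((640, 640, 3), dtype=np.uint8)
ann = Annotator(im, line_width=2)
corners = np.array([[100, 120], [220, 100], [240, 180], [120, 200]])  # 4 corner points
ann.box_label(corners, label="obb 0.91", color=(0, 255, 0), rotated=True)
```
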

    def masks(self, masks, colors, im_gpu, alpha=0.5, retina_masks=False):
        """
@ -154,13 +227,13 @@ class Annotator:
            masks = masks.unsqueeze(3)  # shape(n,h,w,1)
            masks_color = masks * (colors * alpha)  # shape(n,h,w,3)

            inv_alph_masks = (1 - masks * alpha).cumprod(0)  # shape(n,h,w,1)
            inv_alpha_masks = (1 - masks * alpha).cumprod(0)  # shape(n,h,w,1)
            mcs = masks_color.max(dim=0).values  # shape(n,h,w,3)

            im_gpu = im_gpu.flip(dims=[0])  # flip channel
            im_gpu = im_gpu.permute(1, 2, 0).contiguous()  # shape(h,w,3)
            im_gpu = im_gpu * inv_alph_masks[-1] + mcs
            im_mask = (im_gpu * 255)
            im_gpu = im_gpu * inv_alpha_masks[-1] + mcs
            im_mask = im_gpu * 255
            im_mask_np = im_mask.byte().cpu().numpy()
            self.im[:] = im_mask_np if retina_masks else ops.scale_image(im_mask_np, self.im.shape)
            if self.pil:
@ -178,13 +251,14 @@ class Annotator:
            kpt_line (bool, optional): If True, the function will draw lines connecting keypoints
                for human pose. Default is True.

        Note: `kpt_line=True` currently only supports human pose plotting.
        Note:
            `kpt_line=True` currently only supports human pose plotting.
        """
        if self.pil:
            # Convert to numpy first
            self.im = np.asarray(self.im).copy()
        nkpt, ndim = kpts.shape
        is_pose = nkpt == 17 and ndim == 3
        is_pose = nkpt == 17 and ndim in {2, 3}
        kpt_line &= is_pose  # `kpt_line=True` for now only supports human pose plotting
        for i, k in enumerate(kpts):
            color_k = [int(x) for x in self.kpt_color[i]] if is_pose else colors(i)
@ -219,9 +293,9 @@ class Annotator:
        """Add rectangle to image (PIL-only)."""
        self.draw.rectangle(xy, fill, outline, width)

    def text(self, xy, text, txt_color=(255, 255, 255), anchor='top', box_style=False):
    def text(self, xy, text, txt_color=(255, 255, 255), anchor="top", box_style=False):
        """Adds text to an image using PIL or cv2."""
        if anchor == 'bottom':  # start y from font bottom
        if anchor == "bottom":  # start y from font bottom
            w, h = self.font.getsize(text)  # text width, height
            xy[1] += 1 - h
        if self.pil:
@ -230,8 +304,8 @@ class Annotator:
            self.draw.rectangle((xy[0], xy[1], xy[0] + w + 1, xy[1] + h + 1), fill=txt_color)
            # Using `txt_color` for background and draw fg with white color
            txt_color = (255, 255, 255)
            if '\n' in text:
                lines = text.split('\n')
            if "\n" in text:
                lines = text.split("\n")
                _, h = self.font.getsize(text)
                for line in lines:
                    self.draw.text(xy, line, fill=txt_color, font=self.font)
@ -240,15 +314,13 @@ class Annotator:
            self.draw.text(xy, text, fill=txt_color, font=self.font)
        else:
            if box_style:
                tf = max(self.lw - 1, 1)  # font thickness
                w, h = cv2.getTextSize(text, 0, fontScale=self.lw / 3, thickness=tf)[0]  # text width, height
                w, h = cv2.getTextSize(text, 0, fontScale=self.sf, thickness=self.tf)[0]  # text width, height
                outside = xy[1] - h >= 3
                p2 = xy[0] + w, xy[1] - h - 3 if outside else xy[1] + h + 3
                cv2.rectangle(self.im, xy, p2, txt_color, -1, cv2.LINE_AA)  # filled
                # Using `txt_color` for background and draw fg with white color
                txt_color = (255, 255, 255)
            tf = max(self.lw - 1, 1)  # font thickness
            cv2.putText(self.im, text, xy, 0, self.lw / 3, txt_color, thickness=tf, lineType=cv2.LINE_AA)
            cv2.putText(self.im, text, xy, 0, self.sf, txt_color, thickness=self.tf, lineType=cv2.LINE_AA)

    def fromarray(self, im):
        """Update self.im from a numpy array."""
@ -259,27 +331,289 @@ class Annotator:
        """Return annotated image as array."""
        return np.asarray(self.im)

    def show(self, title=None):
        """Show the annotated image."""
        Image.fromarray(np.asarray(self.im)[..., ::-1]).show(title)

    def save(self, filename="image.jpg"):
        """Save the annotated image to 'filename'."""
        cv2.imwrite(filename, np.asarray(self.im))

    def draw_region(self, reg_pts=None, color=(0, 255, 0), thickness=5):
        """
        Draw region line.

        Args:
            reg_pts (list): Region Points (for line 2 points, for region 4 points)
            color (tuple): Region Color value
            thickness (int): Region area thickness value
        """
        cv2.polylines(self.im, [np.array(reg_pts, dtype=np.int32)], isClosed=True, color=color, thickness=thickness)

    def draw_centroid_and_tracks(self, track, color=(255, 0, 255), track_thickness=2):
        """
        Draw centroid point and track trails.

        Args:
            track (list): object tracking points for trails display
            color (tuple): tracks line color
            track_thickness (int): track line thickness value
        """
        points = np.hstack(track).astype(np.int32).reshape((-1, 1, 2))
        cv2.polylines(self.im, [points], isClosed=False, color=color, thickness=track_thickness)
        cv2.circle(self.im, (int(track[-1][0]), int(track[-1][1])), track_thickness * 2, color, -1)

    def count_labels(self, counts=0, count_txt_size=2, color=(255, 255, 255), txt_color=(0, 0, 0)):
        """
        Plot counts for object counter.

        Args:
            counts (int): objects counts value
            count_txt_size (int): text size for counts display
            color (tuple): background color of counts display
            txt_color (tuple): text color of counts display
        """
        self.tf = count_txt_size
        tl = self.tf or round(0.002 * (self.im.shape[0] + self.im.shape[1]) / 2) + 1
        tf = max(tl - 1, 1)

        # Get text size for in_count and out_count
        t_size_in = cv2.getTextSize(str(counts), 0, fontScale=tl / 2, thickness=tf)[0]

        # Calculate positions for counts label
        text_width = t_size_in[0]
        text_x = (self.im.shape[1] - text_width) // 2  # Center x-coordinate
        text_y = t_size_in[1]

        # Create a rounded rectangle for in_count
        cv2.rectangle(
            self.im, (text_x - 5, text_y - 5), (text_x + text_width + 7, text_y + t_size_in[1] + 7), color, -1
        )
        cv2.putText(
            self.im, str(counts), (text_x, text_y + t_size_in[1]), 0, tl / 2, txt_color, self.tf, lineType=cv2.LINE_AA
        )

    @staticmethod
    def estimate_pose_angle(a, b, c):
        """
        Calculate the pose angle for object.

        Args:
            a (float): The value of pose point a
            b (float): The value of pose point b
            c (float): The value of pose point c

        Returns:
            angle (degree): Degree value of angle between three points
        """
        a, b, c = np.array(a), np.array(b), np.array(c)
        radians = np.arctan2(c[1] - b[1], c[0] - b[0]) - np.arctan2(a[1] - b[1], a[0] - b[0])
        angle = np.abs(radians * 180.0 / np.pi)
        if angle > 180.0:
            angle = 360 - angle
        return angle
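
A quick worked check of the arctangent math above, with illustrative points: for a right angle at `b`, the two `atan2` terms differ by π/2, giving 90 degrees.

```python
from ultralytics.utils.plotting import Annotator

a, b, c = (0, 0), (0, 1), (1, 1)  # right angle at b
# atan2(0, 1) - atan2(-1, 0) = 0 - (-pi/2) = pi/2 radians
print(Annotator.estimate_pose_angle(a, b, c))  # 90.0
```
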

    def draw_specific_points(self, keypoints, indices=[2, 5, 7], shape=(640, 640), radius=2):
        """
        Draw specific keypoints for gym steps counting.

        Args:
            keypoints (list): list of keypoints data to be plotted
            indices (list): keypoints ids list to be plotted
            shape (tuple): imgsz for model inference
            radius (int): Keypoint radius value
        """
        for i, k in enumerate(keypoints):
            if i in indices:
                x_coord, y_coord = k[0], k[1]
                if x_coord % shape[1] != 0 and y_coord % shape[0] != 0:
                    if len(k) == 3:
                        conf = k[2]
                        if conf < 0.5:
                            continue
                    cv2.circle(self.im, (int(x_coord), int(y_coord)), radius, (0, 255, 0), -1, lineType=cv2.LINE_AA)
        return self.im

    def plot_angle_and_count_and_stage(self, angle_text, count_text, stage_text, center_kpt, line_thickness=2):
        """
        Plot the pose angle, count value and step stage.

        Args:
            angle_text (str): angle value for workout monitoring
            count_text (str): counts value for workout monitoring
            stage_text (str): stage decision for workout monitoring
            center_kpt (int): centroid pose index for workout monitoring
            line_thickness (int): thickness for text display
        """
        angle_text, count_text, stage_text = (f" {angle_text:.2f}", f"Steps : {count_text}", f" {stage_text}")
        font_scale = 0.6 + (line_thickness / 10.0)

        # Draw angle
        (angle_text_width, angle_text_height), _ = cv2.getTextSize(angle_text, 0, font_scale, line_thickness)
        angle_text_position = (int(center_kpt[0]), int(center_kpt[1]))
        angle_background_position = (angle_text_position[0], angle_text_position[1] - angle_text_height - 5)
        angle_background_size = (angle_text_width + 2 * 5, angle_text_height + 2 * 5 + (line_thickness * 2))
        cv2.rectangle(
            self.im,
            angle_background_position,
            (
                angle_background_position[0] + angle_background_size[0],
                angle_background_position[1] + angle_background_size[1],
            ),
            (255, 255, 255),
            -1,
        )
        cv2.putText(self.im, angle_text, angle_text_position, 0, font_scale, (0, 0, 0), line_thickness)

        # Draw Counts
        (count_text_width, count_text_height), _ = cv2.getTextSize(count_text, 0, font_scale, line_thickness)
        count_text_position = (angle_text_position[0], angle_text_position[1] + angle_text_height + 20)
        count_background_position = (
            angle_background_position[0],
            angle_background_position[1] + angle_background_size[1] + 5,
        )
        count_background_size = (count_text_width + 10, count_text_height + 10 + (line_thickness * 2))

        cv2.rectangle(
            self.im,
            count_background_position,
            (
                count_background_position[0] + count_background_size[0],
                count_background_position[1] + count_background_size[1],
            ),
            (255, 255, 255),
            -1,
        )
        cv2.putText(self.im, count_text, count_text_position, 0, font_scale, (0, 0, 0), line_thickness)

        # Draw Stage
        (stage_text_width, stage_text_height), _ = cv2.getTextSize(stage_text, 0, font_scale, line_thickness)
        stage_text_position = (int(center_kpt[0]), int(center_kpt[1]) + angle_text_height + count_text_height + 40)
        stage_background_position = (stage_text_position[0], stage_text_position[1] - stage_text_height - 5)
        stage_background_size = (stage_text_width + 10, stage_text_height + 10)

        cv2.rectangle(
            self.im,
            stage_background_position,
            (
                stage_background_position[0] + stage_background_size[0],
                stage_background_position[1] + stage_background_size[1],
            ),
            (255, 255, 255),
            -1,
        )
        cv2.putText(self.im, stage_text, stage_text_position, 0, font_scale, (0, 0, 0), line_thickness)

    def seg_bbox(self, mask, mask_color=(255, 0, 255), det_label=None, track_label=None):
        """
        Function for drawing a segmented object in bounding box shape.

        Args:
            mask (list): masks data list for instance segmentation area plotting
            mask_color (tuple): mask foreground color
            det_label (str): Detection label text
            track_label (str): Tracking label text
        """
        cv2.polylines(self.im, [np.int32([mask])], isClosed=True, color=mask_color, thickness=2)

        label = f"Track ID: {track_label}" if track_label else det_label
        text_size, _ = cv2.getTextSize(label, 0, 0.7, 1)

        cv2.rectangle(
            self.im,
            (int(mask[0][0]) - text_size[0] // 2 - 10, int(mask[0][1]) - text_size[1] - 10),
            (int(mask[0][0]) + text_size[0] // 2 + 5, int(mask[0][1] + 5)),
            mask_color,
            -1,
        )

        cv2.putText(
            self.im, label, (int(mask[0][0]) - text_size[0] // 2, int(mask[0][1]) - 5), 0, 0.7, (255, 255, 255), 2
        )

    def plot_distance_and_line(self, distance_m, distance_mm, centroids, line_color, centroid_color):
        """
        Plot the distance and line on frame.

        Args:
            distance_m (float): Distance between two bbox centroids in meters.
            distance_mm (float): Distance between two bbox centroids in millimeters.
            centroids (list): Bounding box centroids data.
            line_color (RGB): Distance line color.
            centroid_color (RGB): Bounding box centroid color.
        """
        (text_width_m, text_height_m), _ = cv2.getTextSize(
            f"Distance M: {distance_m:.2f}m", cv2.FONT_HERSHEY_SIMPLEX, 0.8, 2
        )
        cv2.rectangle(self.im, (15, 25), (15 + text_width_m + 10, 25 + text_height_m + 20), (255, 255, 255), -1)
        cv2.putText(
            self.im,
            f"Distance M: {distance_m:.2f}m",
            (20, 50),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.8,
            (0, 0, 0),
            2,
            cv2.LINE_AA,
        )

        (text_width_mm, text_height_mm), _ = cv2.getTextSize(
            f"Distance MM: {distance_mm:.2f}mm", cv2.FONT_HERSHEY_SIMPLEX, 0.8, 2
        )
        cv2.rectangle(self.im, (15, 75), (15 + text_width_mm + 10, 75 + text_height_mm + 20), (255, 255, 255), -1)
        cv2.putText(
            self.im,
            f"Distance MM: {distance_mm:.2f}mm",
            (20, 100),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.8,
            (0, 0, 0),
            2,
            cv2.LINE_AA,
        )

        cv2.line(self.im, centroids[0], centroids[1], line_color, 3)
        cv2.circle(self.im, centroids[0], 6, centroid_color, -1)
        cv2.circle(self.im, centroids[1], 6, centroid_color, -1)

    def visioneye(self, box, center_point, color=(235, 219, 11), pin_color=(255, 0, 255), thickness=2, pins_radius=10):
        """
        Function for pinpoint human-vision eye mapping and plotting.

        Args:
            box (list): Bounding box coordinates
            center_point (tuple): center point for vision eye view
            color (tuple): object centroid and line color value
            pin_color (tuple): visioneye point color value
            thickness (int): int value for line thickness
            pins_radius (int): visioneye point radius value
        """
        center_bbox = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2)
        cv2.circle(self.im, center_point, pins_radius, pin_color, -1)
        cv2.circle(self.im, center_bbox, pins_radius, color, -1)
        cv2.line(self.im, center_point, center_bbox, color, thickness)


@TryExcept()  # known issue https://github.com/ultralytics/yolov5/issues/5395
@plt_settings()
def plot_labels(boxes, cls, names=(), save_dir=Path(''), on_plot=None):
def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
    """Plot training labels including class histograms and box statistics."""
    import pandas as pd
    import seaborn as sn

    # Filter matplotlib>=3.7.2 warning and Seaborn use_inf and is_categorical FutureWarnings
    warnings.filterwarnings('ignore', category=UserWarning, message='The figure layout has changed to tight')
    warnings.filterwarnings('ignore', category=FutureWarning)
    warnings.filterwarnings("ignore", category=UserWarning, message="The figure layout has changed to tight")
    warnings.filterwarnings("ignore", category=FutureWarning)

    # Plot dataset labels
    LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")
    nc = int(cls.max() + 1)  # number of classes
    boxes = boxes[:1000000]  # limit to 1M boxes
    x = pd.DataFrame(boxes, columns=['x', 'y', 'width', 'height'])
    x = pd.DataFrame(boxes, columns=["x", "y", "width", "height"])

    # Seaborn correlogram
    sn.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
    plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200)
    sn.pairplot(x, corner=True, diag_kind="auto", kind="hist", diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
    plt.savefig(save_dir / "labels_correlogram.jpg", dpi=200)
    plt.close()

    # Matplotlib labels
@ -287,14 +621,14 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(''), on_plot=None):
    y = ax[0].hist(cls, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
    for i in range(nc):
        y[2].patches[i].set_color([x / 255 for x in colors(i)])
    ax[0].set_ylabel('instances')
    ax[0].set_ylabel("instances")
    if 0 < len(names) < 30:
        ax[0].set_xticks(range(len(names)))
        ax[0].set_xticklabels(list(names.values()), rotation=90, fontsize=10)
    else:
        ax[0].set_xlabel('classes')
    sn.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9)
    sn.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9)
        ax[0].set_xlabel("classes")
    sn.histplot(x, x="x", y="y", ax=ax[2], bins=50, pmax=0.9)
    sn.histplot(x, x="width", y="height", ax=ax[3], bins=50, pmax=0.9)

    # Rectangles
    boxes[:, 0:2] = 0.5  # center
@ -303,21 +637,22 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(''), on_plot=None):
    for cls, box in zip(cls[:500], boxes[:500]):
        ImageDraw.Draw(img).rectangle(box, width=1, outline=colors(cls))  # plot
    ax[1].imshow(img)
    ax[1].axis('off')
    ax[1].axis("off")

    for a in [0, 1, 2, 3]:
        for s in ['top', 'right', 'left', 'bottom']:
        for s in ["top", "right", "left", "bottom"]:
            ax[a].spines[s].set_visible(False)

    fname = save_dir / 'labels.jpg'
    fname = save_dir / "labels.jpg"
    plt.savefig(fname, dpi=200)
    plt.close()
    if on_plot:
        on_plot(fname)


def save_one_box(xyxy, im, file=Path('im.jpg'), gain=1.02, pad=10, square=False, BGR=False, save=True):
    """Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop.
def save_one_box(xyxy, im, file=Path("im.jpg"), gain=1.02, pad=10, square=False, BGR=False, save=True):
    """
    Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop.

    This function takes a bounding box and an image, and then saves a cropped portion of the image according
    to the bounding box. Optionally, the crop can be squared, and the function allows for gain and padding
@ -353,27 +688,33 @@ def save_one_box(xyxy, im, file=Path('im.jpg'), gain=1.02, pad=10, square=False,
    b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1)  # attempt rectangle to square
    b[:, 2:] = b[:, 2:] * gain + pad  # box wh * gain + pad
    xyxy = ops.xywh2xyxy(b).long()
    ops.clip_boxes(xyxy, im.shape)
    crop = im[int(xyxy[0, 1]):int(xyxy[0, 3]), int(xyxy[0, 0]):int(xyxy[0, 2]), ::(1 if BGR else -1)]
    xyxy = ops.clip_boxes(xyxy, im.shape)
    crop = im[int(xyxy[0, 1]) : int(xyxy[0, 3]), int(xyxy[0, 0]) : int(xyxy[0, 2]), :: (1 if BGR else -1)]
    if save:
        file.parent.mkdir(parents=True, exist_ok=True)  # make directory
        f = str(increment_path(file).with_suffix('.jpg'))
        f = str(increment_path(file).with_suffix(".jpg"))
        # cv2.imwrite(f, crop)  # save BGR, https://github.com/ultralytics/yolov5/issues/7007 chroma subsampling issue
        Image.fromarray(crop[..., ::-1]).save(f, quality=95, subsampling=0)  # save RGB
    return crop


@threaded
def plot_images(images,
                batch_idx,
                cls,
                bboxes=np.zeros(0, dtype=np.float32),
                masks=np.zeros(0, dtype=np.uint8),
                kpts=np.zeros((0, 51), dtype=np.float32),
                paths=None,
                fname='images.jpg',
                names=None,
                on_plot=None):
def plot_images(
    images,
    batch_idx,
    cls,
    bboxes=np.zeros(0, dtype=np.float32),
    confs=None,
    masks=np.zeros(0, dtype=np.uint8),
    kpts=np.zeros((0, 51), dtype=np.float32),
    paths=None,
    fname="images.jpg",
    names=None,
    on_plot=None,
    max_subplots=16,
    save=True,
    conf_thres=0.25,
):
    """Plot image grid with labels."""
    if isinstance(images, torch.Tensor):
        images = images.cpu().float().numpy()
@ -389,21 +730,17 @@ def plot_images(images,
        batch_idx = batch_idx.cpu().numpy()

    max_size = 1920  # max image size
    max_subplots = 16  # max image subplots, i.e. 4x4
    bs, _, h, w = images.shape  # batch size, _, height, width
    bs = min(bs, max_subplots)  # limit plot images
    ns = np.ceil(bs ** 0.5)  # number of subplots (square)
    ns = np.ceil(bs**0.5)  # number of subplots (square)
    if np.max(images[0]) <= 1:
        images *= 255  # de-normalise (optional)

    # Build Image
    mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8)  # init
    for i, im in enumerate(images):
        if i == max_subplots:  # if last batch has fewer images than we expect
            break
    for i in range(bs):
        x, y = int(w * (i // ns)), int(h * (i % ns))  # block origin
        im = im.transpose(1, 2, 0)
        mosaic[y:y + h, x:x + w, :] = im
        mosaic[y : y + h, x : x + w, :] = images[i].transpose(1, 2, 0)

    # Resize (optional)
    scale = max_size / ns / max(h, w)
@ -415,40 +752,42 @@ def plot_images(images,
    # Annotate
    fs = int((h + w) * ns * 0.01)  # font size
    annotator = Annotator(mosaic, line_width=round(fs / 10), font_size=fs, pil=True, example=names)
    for i in range(i + 1):
    for i in range(bs):
        x, y = int(w * (i // ns)), int(h * (i % ns))  # block origin
        annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2)  # borders
        if paths:
            annotator.text((x + 5, y + 5), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220))  # filenames
        if len(cls) > 0:
            idx = batch_idx == i
            classes = cls[idx].astype('int')
            classes = cls[idx].astype("int")
            labels = confs is None

            if len(bboxes):
                boxes = ops.xywh2xyxy(bboxes[idx, :4]).T
                labels = bboxes.shape[1] == 4  # labels if no conf column
                conf = None if labels else bboxes[idx, 4]  # check for confidence presence (label vs pred)

                if boxes.shape[1]:
                    if boxes.max() <= 1.01:  # if normalized with tolerance 0.01
                        boxes[[0, 2]] *= w  # scale to pixels
                        boxes[[1, 3]] *= h
                boxes = bboxes[idx]
                conf = confs[idx] if confs is not None else None  # check for confidence presence (label vs pred)
                is_obb = boxes.shape[-1] == 5  # xywhr
                boxes = ops.xywhr2xyxyxyxy(boxes) if is_obb else ops.xywh2xyxy(boxes)
                if len(boxes):
                    if boxes[:, :4].max() <= 1.1:  # if normalized with tolerance 0.1
                        boxes[..., 0::2] *= w  # scale to pixels
                        boxes[..., 1::2] *= h
                    elif scale < 1:  # absolute coords need scale if image scales
                        boxes *= scale
                boxes[[0, 2]] += x
                boxes[[1, 3]] += y
                for j, box in enumerate(boxes.T.tolist()):
                        boxes[..., :4] *= scale
                boxes[..., 0::2] += x
                boxes[..., 1::2] += y
                for j, box in enumerate(boxes.astype(np.int64).tolist()):
                    c = classes[j]
                    color = colors(c)
                    c = names.get(c, c) if names else c
                    if labels or conf[j] > 0.25:  # 0.25 conf thresh
                        label = f'{c}' if labels else f'{c} {conf[j]:.1f}'
                        annotator.box_label(box, label, color=color)
                    if labels or conf[j] > conf_thres:
                        label = f"{c}" if labels else f"{c} {conf[j]:.1f}"
                        annotator.box_label(box, label, color=color, rotated=is_obb)

            elif len(classes):
                for c in classes:
                    color = colors(c)
                    c = names.get(c, c) if names else c
                    annotator.text((x, y), f'{c}', txt_color=color, box_style=True)
                    annotator.text((x, y), f"{c}", txt_color=color, box_style=True)

            # Plot keypoints
            if len(kpts):
@ -462,7 +801,7 @@ def plot_images(images,
                kpts_[..., 0] += x
                kpts_[..., 1] += y
                for j in range(len(kpts_)):
                    if labels or conf[j] > 0.25:  # 0.25 conf thresh
                    if labels or conf[j] > conf_thres:
                        annotator.kpts(kpts_[j])

            # Plot masks
@ -477,8 +816,8 @@ def plot_images(images,
                    image_masks = np.where(image_masks == index, 1.0, 0.0)

                im = np.asarray(annotator.im).copy()
                for j, box in enumerate(boxes.T.tolist()):
                    if labels or conf[j] > 0.25:  # 0.25 conf thresh
                for j in range(len(image_masks)):
                    if labels or conf[j] > conf_thres:
                        color = colors(classes[j])
                        mh, mw = image_masks[j].shape
                        if mh != h or mw != w:
@ -488,27 +827,42 @@ def plot_images(images,
                        else:
                            mask = image_masks[j].astype(bool)
                        with contextlib.suppress(Exception):
                            im[y:y + h, x:x + w, :][mask] = im[y:y + h, x:x + w, :][mask] * 0.4 + np.array(color) * 0.6
                            im[y : y + h, x : x + w, :][mask] = (
                                im[y : y + h, x : x + w, :][mask] * 0.4 + np.array(color) * 0.6
                            )
                annotator.fromarray(im)
    if not save:
        return np.asarray(annotator.im)
    annotator.im.save(fname)  # save
    if on_plot:
        on_plot(fname)
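
A hedged usage sketch of the extended signature with synthetic arrays: confidences now travel in a separate `confs` array and are filtered by `conf_thres`. Note that `plot_images` is wrapped by `@threaded`, so this call draws and saves the mosaic in a background thread.

```python
# Synthetic example only; values are illustrative.
import numpy as np

from ultralytics.utils.plotting import plot_images

imgs = np.random.rand(2, 3, 320, 320).astype(np.float32)  # 2 CHW images in [0, 1]
batch_idx = np.array([0.0, 1.0])                          # one box per image
cls = np.array([0.0, 1.0])
bboxes = np.array([[0.5, 0.5, 0.2, 0.3], [0.4, 0.6, 0.1, 0.2]], dtype=np.float32)  # normalized xywh
confs = np.array([0.9, 0.4], dtype=np.float32)

# Only the 0.9-confidence box survives conf_thres=0.5 in the saved mosaic
plot_images(imgs, batch_idx, cls, bboxes, confs=confs, fname="preds.jpg", conf_thres=0.5)
```
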


@plt_settings()
def plot_results(file='path/to/results.csv', dir='', segment=False, pose=False, classify=False, on_plot=None):
def plot_results(file="path/to/results.csv", dir="", segment=False, pose=False, classify=False, on_plot=None):
    """
    Plot training results from results CSV file.
    Plot training results from a results CSV file. The function supports various types of data including segmentation,
    pose estimation, and classification. Plots are saved as 'results.png' in the directory where the CSV is located.

    Args:
        file (str, optional): Path to the CSV file containing the training results. Defaults to 'path/to/results.csv'.
        dir (str, optional): Directory where the CSV file is located if 'file' is not provided. Defaults to ''.
        segment (bool, optional): Flag to indicate if the data is for segmentation. Defaults to False.
        pose (bool, optional): Flag to indicate if the data is for pose estimation. Defaults to False.
        classify (bool, optional): Flag to indicate if the data is for classification. Defaults to False.
        on_plot (callable, optional): Callback function to be executed after plotting. Takes filename as an argument.
            Defaults to None.

    Example:
        ```python
        from ultralytics.utils.plotting import plot_results

        plot_results('path/to/results.csv')
        plot_results('path/to/results.csv', segment=True)
        ```
    """
    import pandas as pd
    from scipy.ndimage import gaussian_filter1d

    save_dir = Path(file).parent if file else Path(dir)
    if classify:
        fig, ax = plt.subplots(2, 2, figsize=(6, 6), tight_layout=True)
@ -523,31 +877,121 @@ def plot_results(file='path/to/results.csv', dir='', segment=False, pose=False,
        fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)
        index = [1, 2, 3, 4, 5, 8, 9, 10, 6, 7]
    ax = ax.ravel()
    files = list(save_dir.glob('results*.csv'))
    assert len(files), f'No results.csv files found in {save_dir.resolve()}, nothing to plot.'
    files = list(save_dir.glob("results*.csv"))
    assert len(files), f"No results.csv files found in {save_dir.resolve()}, nothing to plot."
    for f in files:
        try:
            data = pd.read_csv(f)
            s = [x.strip() for x in data.columns]
            x = data.values[:, 0]
            for i, j in enumerate(index):
                y = data.values[:, j].astype('float')
                y = data.values[:, j].astype("float")
                # y[y == 0] = np.nan  # don't show zero values
                ax[i].plot(x, y, marker='.', label=f.stem, linewidth=2, markersize=8)  # actual results
                ax[i].plot(x, gaussian_filter1d(y, sigma=3), ':', label='smooth', linewidth=2)  # smoothing line
                ax[i].plot(x, y, marker=".", label=f.stem, linewidth=2, markersize=8)  # actual results
                ax[i].plot(x, gaussian_filter1d(y, sigma=3), ":", label="smooth", linewidth=2)  # smoothing line
                ax[i].set_title(s[j], fontsize=12)
                # if j in [8, 9, 10]:  # share train and val loss y axes
                #     ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
        except Exception as e:
            LOGGER.warning(f'WARNING: Plotting error for {f}: {e}')
            LOGGER.warning(f"WARNING: Plotting error for {f}: {e}")
    ax[1].legend()
    fname = save_dir / 'results.png'
    fname = save_dir / "results.png"
    fig.savefig(fname, dpi=200)
    plt.close()
    if on_plot:
        on_plot(fname)


def plt_color_scatter(v, f, bins=20, cmap="viridis", alpha=0.8, edgecolors="none"):
    """
    Plots a scatter plot with points colored based on a 2D histogram.

    Args:
        v (array-like): Values for the x-axis.
        f (array-like): Values for the y-axis.
        bins (int, optional): Number of bins for the histogram. Defaults to 20.
        cmap (str, optional): Colormap for the scatter plot. Defaults to 'viridis'.
        alpha (float, optional): Alpha for the scatter plot. Defaults to 0.8.
        edgecolors (str, optional): Edge colors for the scatter plot. Defaults to 'none'.

    Examples:
        >>> v = np.random.rand(100)
        >>> f = np.random.rand(100)
        >>> plt_color_scatter(v, f)
    """

    # Calculate 2D histogram and corresponding colors
    hist, xedges, yedges = np.histogram2d(v, f, bins=bins)
    colors = [
        hist[
            min(np.digitize(v[i], xedges, right=True) - 1, hist.shape[0] - 1),
            min(np.digitize(f[i], yedges, right=True) - 1, hist.shape[1] - 1),
        ]
        for i in range(len(v))
    ]

    # Scatter plot
    plt.scatter(v, f, c=colors, cmap=cmap, alpha=alpha, edgecolors=edgecolors)


def plot_tune_results(csv_file="tune_results.csv"):
    """
    Plot the evolution results stored in a 'tune_results.csv' file. The function generates a scatter plot for each key
    in the CSV, color-coded based on fitness scores. The best-performing configurations are highlighted on the plots.

    Args:
        csv_file (str, optional): Path to the CSV file containing the tuning results. Defaults to 'tune_results.csv'.

    Examples:
        >>> plot_tune_results('path/to/tune_results.csv')
    """

    import pandas as pd
    from scipy.ndimage import gaussian_filter1d

    # Scatter plots for each hyperparameter
    csv_file = Path(csv_file)
    data = pd.read_csv(csv_file)
    num_metrics_columns = 1
    keys = [x.strip() for x in data.columns][num_metrics_columns:]
    x = data.values
    fitness = x[:, 0]  # fitness
    j = np.argmax(fitness)  # max fitness index
    n = math.ceil(len(keys) ** 0.5)  # columns and rows in plot
    plt.figure(figsize=(10, 10), tight_layout=True)
    for i, k in enumerate(keys):
        v = x[:, i + num_metrics_columns]
        mu = v[j]  # best single result
        plt.subplot(n, n, i + 1)
        plt_color_scatter(v, fitness, cmap="viridis", alpha=0.8, edgecolors="none")
        plt.plot(mu, fitness.max(), "k+", markersize=15)
        plt.title(f"{k} = {mu:.3g}", fontdict={"size": 9})  # limit to 40 characters
        plt.tick_params(axis="both", labelsize=8)  # Set axis label size to 8
        if i % n != 0:
            plt.yticks([])

    file = csv_file.with_name("tune_scatter_plots.png")  # filename
    plt.savefig(file, dpi=200)
    plt.close()
    LOGGER.info(f"Saved {file}")

    # Fitness vs iteration
    x = range(1, len(fitness) + 1)
    plt.figure(figsize=(10, 6), tight_layout=True)
    plt.plot(x, fitness, marker="o", linestyle="none", label="fitness")
    plt.plot(x, gaussian_filter1d(fitness, sigma=3), ":", label="smoothed", linewidth=2)  # smoothing line
    plt.title("Fitness vs Iteration")
    plt.xlabel("Iteration")
    plt.ylabel("Fitness")
    plt.grid(True)
    plt.legend()

    file = csv_file.with_name("tune_fitness.png")  # filename
    plt.savefig(file, dpi=200)
    plt.close()
    LOGGER.info(f"Saved {file}")


def output_to_target(output, max_det=300):
    """Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting."""
    targets = []
@ -556,10 +1000,21 @@ def output_to_target(output, max_det=300):
        j = torch.full((conf.shape[0], 1), i)
        targets.append(torch.cat((j, cls, ops.xyxy2xywh(box), conf), 1))
    targets = torch.cat(targets, 0).numpy()
    return targets[:, 0], targets[:, 1], targets[:, 2:]
    return targets[:, 0], targets[:, 1], targets[:, 2:-1], targets[:, -1]


def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detect/exp')):
def output_to_rotated_target(output, max_det=300):
    """Convert model output to target format [batch_id, class_id, x, y, w, h, rotation, conf] for plotting."""
    targets = []
    for i, o in enumerate(output):
        box, conf, cls, angle = o[:max_det].cpu().split((4, 1, 1, 1), 1)
        j = torch.full((conf.shape[0], 1), i)
        targets.append(torch.cat((j, cls, box, angle, conf), 1))
    targets = torch.cat(targets, 0).numpy()
    return targets[:, 0], targets[:, 1], targets[:, 2:-1], targets[:, -1]
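
The return-shape change matters to callers: both converters now hand back confidences as a fourth value, which pairs with the new `confs` argument of `plot_images`. A runnable toy check, assuming the usual per-image prediction layout of `(x1, y1, x2, y2, conf, cls)` rows that this function splits upstream of the lines shown here:

```python
import torch

from ultralytics.utils.plotting import output_to_target

preds = [  # two fake per-image predictions; the (n, 6) layout is an assumption
    torch.tensor([[10.0, 10.0, 50.0, 60.0, 0.9, 0.0]]),
    torch.tensor([[20.0, 30.0, 80.0, 90.0, 0.4, 1.0]]),
]
batch_idx, cls, bboxes, confs = output_to_target(preds)  # four values now, not three
print(bboxes.shape, confs)  # (2, 4) xywh boxes, [0.9 0.4]
```
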


def feature_visualization(x, module_type, stage, n=32, save_dir=Path("runs/detect/exp")):
    """
    Visualize feature maps of a given model module during inference.

@ -570,23 +1025,23 @@ def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detec
        n (int, optional): Maximum number of feature maps to plot. Defaults to 32.
        save_dir (Path, optional): Directory to save results. Defaults to Path('runs/detect/exp').
    """
    for m in ['Detect', 'Pose', 'Segment']:
    for m in ["Detect", "Pose", "Segment"]:
        if m in module_type:
            return
    batch, channels, height, width = x.shape  # batch, channels, height, width
    _, channels, height, width = x.shape  # batch, channels, height, width
    if height > 1 and width > 1:
        f = save_dir / f"stage{stage}_{module_type.split('.')[-1]}_features.png"  # filename

        blocks = torch.chunk(x[0].cpu(), channels, dim=0)  # select batch index 0, block by channels
        n = min(n, channels)  # number of plots
        fig, ax = plt.subplots(math.ceil(n / 8), 8, tight_layout=True)  # n/8 rows x 8 cols
        _, ax = plt.subplots(math.ceil(n / 8), 8, tight_layout=True)  # n/8 rows x 8 cols
        ax = ax.ravel()
        plt.subplots_adjust(wspace=0.05, hspace=0.05)
        for i in range(n):
            ax[i].imshow(blocks[i].squeeze())  # cmap='gray'
            ax[i].axis('off')
            ax[i].axis("off")

        LOGGER.info(f'Saving {f}... ({n}/{channels})')
        plt.savefig(f, dpi=300, bbox_inches='tight')
        LOGGER.info(f"Saving {f}... ({n}/{channels})")
        plt.savefig(f, dpi=300, bbox_inches="tight")
        plt.close()
        np.save(str(f.with_suffix('.npy')), x[0].cpu().numpy())  # npy save
        np.save(str(f.with_suffix(".npy")), x[0].cpu().numpy())  # npy save
@ -4,65 +4,18 @@ import torch
import torch.nn as nn

from .checks import check_version
from .metrics import bbox_iou
from .metrics import bbox_iou, probiou
from .ops import xywhr2xyxyxyxy

TORCH_1_10 = check_version(torch.__version__, '1.10.0')


def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
    """
    Select the positive anchor center in gt.

    Args:
        xy_centers (Tensor): shape(h*w, 2)
        gt_bboxes (Tensor): shape(b, n_boxes, 4)

    Returns:
        (Tensor): shape(b, n_boxes, h*w)
    """
    n_anchors = xy_centers.shape[0]
    bs, n_boxes, _ = gt_bboxes.shape
    lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2)  # left-top, right-bottom
    bbox_deltas = torch.cat((xy_centers[None] - lt, rb - xy_centers[None]), dim=2).view(bs, n_boxes, n_anchors, -1)
    # return (bbox_deltas.min(3)[0] > eps).to(gt_bboxes.dtype)
    return bbox_deltas.amin(3).gt_(eps)


def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
    """
    If an anchor box is assigned to multiple gts, the one with the highest IoU will be selected.

    Args:
        mask_pos (Tensor): shape(b, n_max_boxes, h*w)
        overlaps (Tensor): shape(b, n_max_boxes, h*w)

    Returns:
        target_gt_idx (Tensor): shape(b, h*w)
        fg_mask (Tensor): shape(b, h*w)
        mask_pos (Tensor): shape(b, n_max_boxes, h*w)
    """
    # (b, n_max_boxes, h*w) -> (b, h*w)
    fg_mask = mask_pos.sum(-2)
    if fg_mask.max() > 1:  # one anchor is assigned to multiple gt_bboxes
        mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, n_max_boxes, -1)  # (b, n_max_boxes, h*w)
        max_overlaps_idx = overlaps.argmax(1)  # (b, h*w)

        is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device)
        is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1)

        mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos).float()  # (b, n_max_boxes, h*w)
        fg_mask = mask_pos.sum(-2)
    # Find which gt (index) each grid cell serves
    target_gt_idx = mask_pos.argmax(-2)  # (b, h*w)
    return target_gt_idx, fg_mask, mask_pos
TORCH_1_10 = check_version(torch.__version__, "1.10.0")


class TaskAlignedAssigner(nn.Module):
    """
    A task-aligned assigner for object detection.

    This class assigns ground-truth (gt) objects to anchors based on the task-aligned metric,
    which combines both classification and localization information.
    This class assigns ground-truth (gt) objects to anchors based on the task-aligned metric, which combines both
    classification and localization information.

    Attributes:
        topk (int): The number of top candidates to consider.
@ -85,8 +38,8 @@ class TaskAlignedAssigner(nn.Module):
    @torch.no_grad()
    def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
        """
        Compute the task-aligned assignment.
        Reference https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py
        Compute the task-aligned assignment. Reference code is available at
        https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py.

        Args:
            pd_scores (Tensor): shape(bs, num_total_anchors, num_classes)
@ -103,19 +56,24 @@ class TaskAlignedAssigner(nn.Module):
            fg_mask (Tensor): shape(bs, num_total_anchors)
            target_gt_idx (Tensor): shape(bs, num_total_anchors)
        """
        self.bs = pd_scores.size(0)
        self.n_max_boxes = gt_bboxes.size(1)
        self.bs = pd_scores.shape[0]
        self.n_max_boxes = gt_bboxes.shape[1]

        if self.n_max_boxes == 0:
            device = gt_bboxes.device
            return (torch.full_like(pd_scores[..., 0], self.bg_idx).to(device), torch.zeros_like(pd_bboxes).to(device),
                    torch.zeros_like(pd_scores).to(device), torch.zeros_like(pd_scores[..., 0]).to(device),
                    torch.zeros_like(pd_scores[..., 0]).to(device))
            return (
                torch.full_like(pd_scores[..., 0], self.bg_idx).to(device),
                torch.zeros_like(pd_bboxes).to(device),
                torch.zeros_like(pd_scores).to(device),
                torch.zeros_like(pd_scores[..., 0]).to(device),
                torch.zeros_like(pd_scores[..., 0]).to(device),
            )

        mask_pos, align_metric, overlaps = self.get_pos_mask(pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points,
                                                             mask_gt)
        mask_pos, align_metric, overlaps = self.get_pos_mask(
            pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt
        )

        target_gt_idx, fg_mask, mask_pos = select_highest_overlaps(mask_pos, overlaps, self.n_max_boxes)
        target_gt_idx, fg_mask, mask_pos = self.select_highest_overlaps(mask_pos, overlaps, self.n_max_boxes)

        # Assigned target
        target_labels, target_bboxes, target_scores = self.get_targets(gt_labels, gt_bboxes, target_gt_idx, fg_mask)
@ -131,7 +89,7 @@ class TaskAlignedAssigner(nn.Module):

    def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt):
        """Get in_gts mask, (b, max_num_obj, h*w)."""
        mask_in_gts = select_candidates_in_gts(anc_points, gt_bboxes)
        mask_in_gts = self.select_candidates_in_gts(anc_points, gt_bboxes)
        # Get anchor_align metric, (b, max_num_obj, h*w)
        align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts * mask_gt)
        # Get topk_metric mask, (b, max_num_obj, h*w)
@ -157,11 +115,15 @@ class TaskAlignedAssigner(nn.Module):
        # (b, max_num_obj, 1, 4), (b, 1, h*w, 4)
        pd_boxes = pd_bboxes.unsqueeze(1).expand(-1, self.n_max_boxes, -1, -1)[mask_gt]
        gt_boxes = gt_bboxes.unsqueeze(2).expand(-1, -1, na, -1)[mask_gt]
        overlaps[mask_gt] = bbox_iou(gt_boxes, pd_boxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0)
        overlaps[mask_gt] = self.iou_calculation(gt_boxes, pd_boxes)

        align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)
        return align_metric, overlaps

    def iou_calculation(self, gt_bboxes, pd_bboxes):
        """IoU calculation for horizontal bounding boxes."""
        return bbox_iou(gt_bboxes, pd_bboxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0)
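
This refactor turns the overlap metric into an overridable hook rather than hard-coding it in `get_box_metrics`. A minimal sketch of the pattern under that assumption (`GIoUTaskAlignedAssigner` is hypothetical; `bbox_iou` does accept a `GIoU` flag):

```python
from ultralytics.utils.metrics import bbox_iou
from ultralytics.utils.tal import TaskAlignedAssigner


class GIoUTaskAlignedAssigner(TaskAlignedAssigner):  # hypothetical subclass
    """Swap the alignment IoU for GIoU without touching the assignment logic."""

    def iou_calculation(self, gt_bboxes, pd_bboxes):
        return bbox_iou(gt_bboxes, pd_bboxes, xywh=False, GIoU=True).squeeze(-1).clamp_(0)
```
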

    def select_topk_candidates(self, metrics, largest=True, topk_mask=None):
        """
        Select the top-k candidates based on the given metrics.
@ -191,9 +153,9 @@ class TaskAlignedAssigner(nn.Module):
        ones = torch.ones_like(topk_idxs[:, :, :1], dtype=torch.int8, device=topk_idxs.device)
        for k in range(self.topk):
            # Expand topk_idxs for each value of k and add 1 at the specified positions
            count_tensor.scatter_add_(-1, topk_idxs[:, :, k:k + 1], ones)
            count_tensor.scatter_add_(-1, topk_idxs[:, :, k : k + 1], ones)
        # count_tensor.scatter_add_(-1, topk_idxs, torch.ones_like(topk_idxs, dtype=torch.int8, device=topk_idxs.device))
        # filter invalid bboxes
        # Filter invalid bboxes
        count_tensor.masked_fill_(count_tensor > 1, 0)

        return count_tensor.to(metrics.dtype)
@ -229,15 +191,17 @@ class TaskAlignedAssigner(nn.Module):
        target_labels = gt_labels.long().flatten()[target_gt_idx]  # (b, h*w)

        # Assigned target boxes, (b, max_num_obj, 4) -> (b, h*w, 4)
        target_bboxes = gt_bboxes.view(-1, 4)[target_gt_idx]
        target_bboxes = gt_bboxes.view(-1, gt_bboxes.shape[-1])[target_gt_idx]

        # Assigned target scores
        target_labels.clamp_(0)

        # 10x faster than F.one_hot()
        target_scores = torch.zeros((target_labels.shape[0], target_labels.shape[1], self.num_classes),
                                    dtype=torch.int64,
                                    device=target_labels.device)  # (b, h*w, 80)
        target_scores = torch.zeros(
            (target_labels.shape[0], target_labels.shape[1], self.num_classes),
            dtype=torch.int64,
            device=target_labels.device,
        )  # (b, h*w, 80)
        target_scores.scatter_(2, target_labels.unsqueeze(-1), 1)

        fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.num_classes)  # (b, h*w, 80)
@ -245,6 +209,87 @@ class TaskAlignedAssigner(nn.Module):

        return target_labels, target_bboxes, target_scores

    @staticmethod
    def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
        """
        Select the positive anchor center in gt.

        Args:
            xy_centers (Tensor): shape(h*w, 2)
            gt_bboxes (Tensor): shape(b, n_boxes, 4)

        Returns:
            (Tensor): shape(b, n_boxes, h*w)
        """
        n_anchors = xy_centers.shape[0]
        bs, n_boxes, _ = gt_bboxes.shape
        lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2)  # left-top, right-bottom
        bbox_deltas = torch.cat((xy_centers[None] - lt, rb - xy_centers[None]), dim=2).view(bs, n_boxes, n_anchors, -1)
        # return (bbox_deltas.min(3)[0] > eps).to(gt_bboxes.dtype)
        return bbox_deltas.amin(3).gt_(eps)

    @staticmethod
    def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
        """
        If an anchor box is assigned to multiple gts, the one with the highest IoU will be selected.

        Args:
            mask_pos (Tensor): shape(b, n_max_boxes, h*w)
            overlaps (Tensor): shape(b, n_max_boxes, h*w)

        Returns:
            target_gt_idx (Tensor): shape(b, h*w)
            fg_mask (Tensor): shape(b, h*w)
            mask_pos (Tensor): shape(b, n_max_boxes, h*w)
        """
        # (b, n_max_boxes, h*w) -> (b, h*w)
        fg_mask = mask_pos.sum(-2)
        if fg_mask.max() > 1:  # one anchor is assigned to multiple gt_bboxes
            mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, n_max_boxes, -1)  # (b, n_max_boxes, h*w)
            max_overlaps_idx = overlaps.argmax(1)  # (b, h*w)

            is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device)
            is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1)

            mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos).float()  # (b, n_max_boxes, h*w)
            fg_mask = mask_pos.sum(-2)
        # Find which gt (index) each grid cell serves
        target_gt_idx = mask_pos.argmax(-2)  # (b, h*w)
        return target_gt_idx, fg_mask, mask_pos
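
A worked toy example of the conflict resolution above (batch of 1, two gts, three anchors): anchor 1 is claimed by both gts, so the gt with the larger overlap (gt 0, since 0.8 > 0.6) keeps it.

```python
import torch

from ultralytics.utils.tal import TaskAlignedAssigner

mask_pos = torch.tensor([[[1.0, 1.0, 0.0], [0.0, 1.0, 1.0]]])  # (b=1, n_max_boxes=2, h*w=3)
overlaps = torch.tensor([[[0.7, 0.8, 0.0], [0.0, 0.6, 0.5]]])
idx, fg, pos = TaskAlignedAssigner.select_highest_overlaps(mask_pos, overlaps, n_max_boxes=2)
print(idx)  # tensor([[0, 0, 1]]) -> anchors 0 and 1 serve gt 0, anchor 2 serves gt 1
print(fg)   # tensor([[1., 1., 1.]]) -> every anchor is now foreground exactly once
```
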


class RotatedTaskAlignedAssigner(TaskAlignedAssigner):
    def iou_calculation(self, gt_bboxes, pd_bboxes):
        """IoU calculation for rotated bounding boxes."""
        return probiou(gt_bboxes, pd_bboxes).squeeze(-1).clamp_(0)

    @staticmethod
    def select_candidates_in_gts(xy_centers, gt_bboxes):
        """
        Select the positive anchor center in gt for rotated bounding boxes.

        Args:
            xy_centers (Tensor): shape(h*w, 2)
            gt_bboxes (Tensor): shape(b, n_boxes, 5)

        Returns:
            (Tensor): shape(b, n_boxes, h*w)
        """
        # (b, n_boxes, 5) --> (b, n_boxes, 4, 2)
        corners = xywhr2xyxyxyxy(gt_bboxes)
        # (b, n_boxes, 1, 2)
        a, b, _, d = corners.split(1, dim=-2)
        ab = b - a
        ad = d - a

        # (b, n_boxes, h*w, 2)
        ap = xy_centers - a
        norm_ab = (ab * ab).sum(dim=-1)
        norm_ad = (ad * ad).sum(dim=-1)
        ap_dot_ab = (ap * ab).sum(dim=-1)
        ap_dot_ad = (ap * ad).sum(dim=-1)
        return (ap_dot_ab >= 0) & (ap_dot_ab <= norm_ab) & (ap_dot_ad >= 0) & (ap_dot_ad <= norm_ad)  # is_in_box
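
A worked check of the projection test above, with illustrative numbers: a point P lies inside rectangle ABCD iff 0 ≤ AP·AB ≤ |AB|² and 0 ≤ AP·AD ≤ |AD|², which is exactly the final boolean expression.

```python
import torch

A, B, D = torch.tensor([0.0, 0.0]), torch.tensor([2.0, 0.0]), torch.tensor([0.0, 2.0])
for P in (torch.tensor([1.0, 1.0]), torch.tensor([3.0, 1.0])):
    AP, AB, AD = P - A, B - A, D - A
    ap_ab, ap_ad = float(AP @ AB), float(AP @ AD)
    inside = 0 <= ap_ab <= float(AB @ AB) and 0 <= ap_ad <= float(AD @ AD)
    print(P.tolist(), inside)  # [1.0, 1.0] True, then [3.0, 1.0] False
```
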


def make_anchors(feats, strides, grid_cell_offset=0.5):
    """Generate anchors from features."""
@ -255,7 +300,7 @@ def make_anchors(feats, strides, grid_cell_offset=0.5):
        _, _, h, w = feats[i].shape
        sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset  # shift x
        sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset  # shift y
        sy, sx = torch.meshgrid(sy, sx, indexing='ij') if TORCH_1_10 else torch.meshgrid(sy, sx)
        sy, sx = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx)
        anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2))
        stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
    return torch.cat(anchor_points), torch.cat(stride_tensor)
@ -263,7 +308,8 @@ def make_anchors(feats, strides, grid_cell_offset=0.5):

def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
    """Transform distance(ltrb) to box(xywh or xyxy)."""
    lt, rb = distance.chunk(2, dim)
    assert distance.shape[dim] == 4
    lt, rb = distance.split([2, 2], dim)
    x1y1 = anchor_points - lt
    x2y2 = anchor_points + rb
    if xywh:
@ -277,3 +323,23 @@ def bbox2dist(anchor_points, bbox, reg_max):
    """Transform bbox(xyxy) to dist(ltrb)."""
    x1y1, x2y2 = bbox.chunk(2, -1)
    return torch.cat((anchor_points - x1y1, x2y2 - anchor_points), -1).clamp_(0, reg_max - 0.01)  # dist (lt, rb)


def dist2rbox(pred_dist, pred_angle, anchor_points, dim=-1):
    """
    Decode predicted object bounding box coordinates from anchor points and distribution.

    Args:
        pred_dist (torch.Tensor): Predicted rotated distance, (bs, h*w, 4).
        pred_angle (torch.Tensor): Predicted angle, (bs, h*w, 1).
        anchor_points (torch.Tensor): Anchor points, (h*w, 2).

    Returns:
        (torch.Tensor): Predicted rotated bounding boxes, (bs, h*w, 4).
    """
    lt, rb = pred_dist.split(2, dim=dim)
    cos, sin = torch.cos(pred_angle), torch.sin(pred_angle)
    # (bs, h*w, 1)
    xf, yf = ((rb - lt) / 2).split(1, dim=dim)
    x, y = xf * cos - yf * sin, xf * sin + yf * cos
    xy = torch.cat([x, y], dim=dim) + anchor_points
    return torch.cat([xy, lt + rb], dim=dim)
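
A round-trip sanity check with illustrative values: `bbox2dist` inverts `dist2bbox` for in-range distances, and `dist2rbox` reduces to the same box (in xywh form) when the predicted angle is zero.

```python
import torch

from ultralytics.utils.tal import bbox2dist, dist2bbox, dist2rbox

anchor = torch.tensor([[5.0, 5.0]])
dist = torch.tensor([[2.0, 1.0, 3.0, 4.0]])  # (left, top, right, bottom)
box = dist2bbox(dist, anchor, xywh=False)    # tensor([[3., 4., 8., 9.]])
assert torch.equal(bbox2dist(anchor, box, reg_max=16), dist)
print(dist2rbox(dist, torch.zeros(1, 1), anchor))  # tensor([[5.5, 6.5, 5., 5.]]) xywh
```
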
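
With pred_angle = 0 the rotated decode reduces to a plain (cx, cy, w, h) box around the anchor; a nonzero angle rotates the (xf, yf) center offset by the usual 2D rotation. A quick zero-angle check with illustrative values:

```python
import torch
from ultralytics.utils.tal import dist2rbox

anchor = torch.tensor([[8.0, 8.0]])
dist = torch.tensor([[2.0, 3.0, 3.0, 4.0]])  # ltrb distances
angle = torch.zeros(1, 1)
print(dist2rbox(dist, angle, anchor))  # tensor([[8.5000, 8.5000, 5.0000, 7.0000]]) -> (cx, cy, w, h)
```
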
ultralytics/utils/torch_utils.py
@ -2,7 +2,6 @@

import math
import os
import platform
import random
import time
from contextlib import contextmanager
@ -15,17 +14,23 @@ import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torchvision

from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, __version__
from ultralytics.utils.checks import check_version
from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, __version__
from ultralytics.utils.checks import PYTHON_VERSION, check_version

try:
    import thop
except ImportError:
    thop = None

TORCH_1_9 = check_version(torch.__version__, '1.9.0')
TORCH_2_0 = check_version(torch.__version__, '2.0.0')
# Version checks (all default to version>=min_version)
TORCH_1_9 = check_version(torch.__version__, "1.9.0")
TORCH_1_13 = check_version(torch.__version__, "1.13.0")
TORCH_2_0 = check_version(torch.__version__, "2.0.0")
TORCHVISION_0_10 = check_version(torchvision.__version__, "0.10.0")
TORCHVISION_0_11 = check_version(torchvision.__version__, "0.11.0")
TORCHVISION_0_13 = check_version(torchvision.__version__, "0.13.0")


@contextmanager
@ -44,7 +49,10 @@ def smart_inference_mode():

    def decorate(fn):
        """Applies appropriate torch decorator for inference mode based on torch version."""
        return (torch.inference_mode if TORCH_1_9 else torch.no_grad)()(fn)
        if TORCH_1_9 and torch.is_inference_mode_enabled():
            return fn  # already in inference_mode, act as a pass-through
        else:
            return (torch.inference_mode if TORCH_1_9 else torch.no_grad)()(fn)

    return decorate
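
The new branch makes nested decoration a no-op: if the caller is already inside torch.inference_mode, the function is returned untouched instead of being re-wrapped. A minimal usage sketch:

```python
import torch
from ultralytics.utils.torch_utils import smart_inference_mode

@smart_inference_mode()
def double(x):
    """Forward-only work: runs under inference_mode on torch>=1.9, else no_grad."""
    return x * 2

x = torch.ones(3, requires_grad=True)
print(double(x).requires_grad)  # False: no autograd graph is recorded
```
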
@ -53,59 +61,102 @@ def get_cpu_info():
    """Return a string with system CPU information, i.e. 'Apple M2'."""
    import cpuinfo  # pip install py-cpuinfo

    k = 'brand_raw', 'hardware_raw', 'arch_string_raw'  # info keys sorted by preference (not all keys always available)
    k = "brand_raw", "hardware_raw", "arch_string_raw"  # info keys sorted by preference (not all keys always available)
    info = cpuinfo.get_cpu_info()  # info dict
    string = info.get(k[0] if k[0] in info else k[1] if k[1] in info else k[2], 'unknown')
    return string.replace('(R)', '').replace('CPU ', '').replace('@ ', '')
    string = info.get(k[0] if k[0] in info else k[1] if k[1] in info else k[2], "unknown")
    return string.replace("(R)", "").replace("CPU ", "").replace("@ ", "")


def select_device(device='', batch=0, newline=False, verbose=True):
    """Selects PyTorch Device. Options are device = None or 'cpu' or 0 or '0' or '0,1,2,3'."""
    s = f'Ultralytics YOLOv{__version__} 🚀 Python-{platform.python_version()} torch-{torch.__version__} '
def select_device(device="", batch=0, newline=False, verbose=True):
    """
    Selects the appropriate PyTorch device based on the provided arguments.

    The function takes a string specifying the device or a torch.device object and returns a torch.device object
    representing the selected device. The function also validates the number of available devices and raises an
    exception if the requested device(s) are not available.

    Args:
        device (str | torch.device, optional): Device string or torch.device object.
            Options are 'None', 'cpu', or 'cuda', or '0' or '0,1,2,3'. Defaults to an empty string, which auto-selects
            the first available GPU, or CPU if no GPU is available.
        batch (int, optional): Batch size being used in your model. Defaults to 0.
        newline (bool, optional): If True, adds a newline at the end of the log string. Defaults to False.
        verbose (bool, optional): If True, logs the device information. Defaults to True.

    Returns:
        (torch.device): Selected device.

    Raises:
        ValueError: If the specified device is not available or if the batch size is not a multiple of the number of
            devices when using multiple GPUs.

    Examples:
        >>> select_device('cuda:0')
        device(type='cuda', index=0)

        >>> select_device('cpu')
        device(type='cpu')

    Note:
        Sets the 'CUDA_VISIBLE_DEVICES' environment variable for specifying which GPUs to use.
    """

    if isinstance(device, torch.device):
        return device

    s = f"Ultralytics YOLOv{__version__} 🚀 Python-{PYTHON_VERSION} torch-{torch.__version__} "
    device = str(device).lower()
    for remove in 'cuda:', 'none', '(', ')', '[', ']', "'", ' ':
        device = device.replace(remove, '')  # to string, 'cuda:0' -> '0' and '(0, 1)' -> '0,1'
    cpu = device == 'cpu'
    mps = device == 'mps'  # Apple Metal Performance Shaders (MPS)
    for remove in "cuda:", "none", "(", ")", "[", "]", "'", " ":
        device = device.replace(remove, "")  # to string, 'cuda:0' -> '0' and '(0, 1)' -> '0,1'
    cpu = device == "cpu"
    mps = device in ("mps", "mps:0")  # Apple Metal Performance Shaders (MPS)
    if cpu or mps:
        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # force torch.cuda.is_available() = False
        os.environ["CUDA_VISIBLE_DEVICES"] = "-1"  # force torch.cuda.is_available() = False
    elif device:  # non-cpu device requested
        if device == 'cuda':
            device = '0'
        visible = os.environ.get('CUDA_VISIBLE_DEVICES', None)
        os.environ['CUDA_VISIBLE_DEVICES'] = device  # set environment variable - must be before assert is_available()
        if not (torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(',', ''))):
        if device == "cuda":
            device = "0"
        visible = os.environ.get("CUDA_VISIBLE_DEVICES", None)
        os.environ["CUDA_VISIBLE_DEVICES"] = device  # set environment variable - must be before assert is_available()
        if not (torch.cuda.is_available() and torch.cuda.device_count() >= len(device.split(","))):
            LOGGER.info(s)
            install = 'See https://pytorch.org/get-started/locally/ for up-to-date torch install instructions if no ' \
                      'CUDA devices are seen by torch.\n' if torch.cuda.device_count() == 0 else ''
            raise ValueError(f"Invalid CUDA 'device={device}' requested."
                             f" Use 'device=cpu' or pass valid CUDA device(s) if available,"
                             f" i.e. 'device=0' or 'device=0,1,2,3' for Multi-GPU.\n"
                             f'\ntorch.cuda.is_available(): {torch.cuda.is_available()}'
                             f'\ntorch.cuda.device_count(): {torch.cuda.device_count()}'
                             f"\nos.environ['CUDA_VISIBLE_DEVICES']: {visible}\n"
                             f'{install}')
            install = (
                "See https://pytorch.org/get-started/locally/ for up-to-date torch install instructions if no "
                "CUDA devices are seen by torch.\n"
                if torch.cuda.device_count() == 0
                else ""
            )
            raise ValueError(
                f"Invalid CUDA 'device={device}' requested."
                f" Use 'device=cpu' or pass valid CUDA device(s) if available,"
                f" i.e. 'device=0' or 'device=0,1,2,3' for Multi-GPU.\n"
                f"\ntorch.cuda.is_available(): {torch.cuda.is_available()}"
                f"\ntorch.cuda.device_count(): {torch.cuda.device_count()}"
                f"\nos.environ['CUDA_VISIBLE_DEVICES']: {visible}\n"
                f"{install}"
            )

    if not cpu and not mps and torch.cuda.is_available():  # prefer GPU if available
        devices = device.split(',') if device else '0'  # range(torch.cuda.device_count()) # i.e. 0,1,6,7
        devices = device.split(",") if device else "0"  # range(torch.cuda.device_count()) # i.e. 0,1,6,7
        n = len(devices)  # device count
        if n > 1 and batch > 0 and batch % n != 0:  # check batch_size is divisible by device_count
            raise ValueError(f"'batch={batch}' must be a multiple of GPU count {n}. Try 'batch={batch // n * n}' or "
                             f"'batch={batch // n * n + n}', the nearest batch sizes evenly divisible by {n}.")
        space = ' ' * (len(s) + 1)
            raise ValueError(
                f"'batch={batch}' must be a multiple of GPU count {n}. Try 'batch={batch // n * n}' or "
                f"'batch={batch // n * n + n}', the nearest batch sizes evenly divisible by {n}."
            )
        space = " " * (len(s) + 1)
        for i, d in enumerate(devices):
            p = torch.cuda.get_device_properties(i)
            s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / (1 << 20):.0f}MiB)\n"  # bytes to MB
        arg = 'cuda:0'
    elif mps and getattr(torch, 'has_mps', False) and torch.backends.mps.is_available() and TORCH_2_0:
        arg = "cuda:0"
    elif mps and TORCH_2_0 and torch.backends.mps.is_available():
        # Prefer MPS if available
        s += f'MPS ({get_cpu_info()})\n'
        arg = 'mps'
        s += f"MPS ({get_cpu_info()})\n"
        arg = "mps"
    else:  # revert to CPU
        s += f'CPU ({get_cpu_info()})\n'
        arg = 'cpu'
        s += f"CPU ({get_cpu_info()})\n"
        arg = "cpu"

    if verbose and RANK == -1:
    if verbose:
        LOGGER.info(s if newline else s.rstrip())
    return torch.device(arg)
@ -119,14 +170,20 @@ def time_sync():

def fuse_conv_and_bn(conv, bn):
    """Fuse Conv2d() and BatchNorm2d() layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/."""
    fusedconv = nn.Conv2d(conv.in_channels,
                          conv.out_channels,
                          kernel_size=conv.kernel_size,
                          stride=conv.stride,
                          padding=conv.padding,
                          dilation=conv.dilation,
                          groups=conv.groups,
                          bias=True).requires_grad_(False).to(conv.weight.device)
    fusedconv = (
        nn.Conv2d(
            conv.in_channels,
            conv.out_channels,
            kernel_size=conv.kernel_size,
            stride=conv.stride,
            padding=conv.padding,
            dilation=conv.dilation,
            groups=conv.groups,
            bias=True,
        )
        .requires_grad_(False)
        .to(conv.weight.device)
    )

    # Prepare filters
    w_conv = conv.weight.clone().view(conv.out_channels, -1)
@ -134,7 +191,7 @@ def fuse_conv_and_bn(conv, bn):
    fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))

    # Prepare spatial bias
    b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
    b_conv = torch.zeros(conv.weight.shape[0], device=conv.weight.device) if conv.bias is None else conv.bias
    b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
    fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
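
The fusion folds BatchNorm's affine transform into the convolution (W_fused = W_bn · W_conv, b_fused = W_bn · b_conv + b_bn), so the fused layer reproduces conv followed by BN in eval mode. A quick equivalence check; the randomized running stats are only there to make the test non-trivial:

```python
import torch
import torch.nn as nn
from ultralytics.utils.torch_utils import fuse_conv_and_bn

conv = nn.Conv2d(3, 8, 3, padding=1, bias=False)
bn = nn.BatchNorm2d(8).eval()  # fusion assumes frozen (eval-mode) BN statistics
bn.running_mean.uniform_(-1, 1)
bn.running_var.uniform_(0.5, 1.5)

x = torch.randn(1, 3, 16, 16)
fused = fuse_conv_and_bn(conv, bn)
print(torch.allclose(bn(conv(x)), fused(x), atol=1e-5))  # True
```
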
@ -143,15 +200,21 @@ def fuse_conv_and_bn(conv, bn):

def fuse_deconv_and_bn(deconv, bn):
    """Fuse ConvTranspose2d() and BatchNorm2d() layers."""
    fuseddconv = nn.ConvTranspose2d(deconv.in_channels,
                                    deconv.out_channels,
                                    kernel_size=deconv.kernel_size,
                                    stride=deconv.stride,
                                    padding=deconv.padding,
                                    output_padding=deconv.output_padding,
                                    dilation=deconv.dilation,
                                    groups=deconv.groups,
                                    bias=True).requires_grad_(False).to(deconv.weight.device)
    fuseddconv = (
        nn.ConvTranspose2d(
            deconv.in_channels,
            deconv.out_channels,
            kernel_size=deconv.kernel_size,
            stride=deconv.stride,
            padding=deconv.padding,
            output_padding=deconv.output_padding,
            dilation=deconv.dilation,
            groups=deconv.groups,
            bias=True,
        )
        .requires_grad_(False)
        .to(deconv.weight.device)
    )

    # Prepare filters
    w_deconv = deconv.weight.clone().view(deconv.out_channels, -1)
@ -159,7 +222,7 @@ def fuse_deconv_and_bn(deconv, bn):
    fuseddconv.weight.copy_(torch.mm(w_bn, w_deconv).view(fuseddconv.weight.shape))

    # Prepare spatial bias
    b_conv = torch.zeros(deconv.weight.size(1), device=deconv.weight.device) if deconv.bias is None else deconv.bias
    b_conv = torch.zeros(deconv.weight.shape[1], device=deconv.weight.device) if deconv.bias is None else deconv.bias
    b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
    fuseddconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
@ -167,7 +230,11 @@ def fuse_deconv_and_bn(deconv, bn):

def model_info(model, detailed=False, verbose=True, imgsz=640):
    """Model information. imgsz may be int or list, i.e. imgsz=640 or imgsz=[640, 320]."""
    """
    Model information.

    imgsz may be int or list, i.e. imgsz=640 or imgsz=[640, 320].
    """
    if not verbose:
        return
    n_p = get_num_params(model)  # number of parameters
@ -175,18 +242,21 @@ def model_info(model, detailed=False, verbose=True, imgsz=640):
    n_l = len(list(model.modules()))  # number of layers
    if detailed:
        LOGGER.info(
            f"{'layer':>5} {'name':>40} {'gradient':>9} {'parameters':>12} {'shape':>20} {'mu':>10} {'sigma':>10}")
            f"{'layer':>5} {'name':>40} {'gradient':>9} {'parameters':>12} {'shape':>20} {'mu':>10} {'sigma':>10}"
        )
        for i, (name, p) in enumerate(model.named_parameters()):
            name = name.replace('module_list.', '')
            LOGGER.info('%5g %40s %9s %12g %20s %10.3g %10.3g %10s' %
                        (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std(), p.dtype))
            name = name.replace("module_list.", "")
            LOGGER.info(
                "%5g %40s %9s %12g %20s %10.3g %10.3g %10s"
                % (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std(), p.dtype)
            )

    flops = get_flops(model, imgsz)
    fused = ' (fused)' if getattr(model, 'is_fused', lambda: False)() else ''
    fs = f', {flops:.1f} GFLOPs' if flops else ''
    yaml_file = getattr(model, 'yaml_file', '') or getattr(model, 'yaml', {}).get('yaml_file', '')
    model_name = Path(yaml_file).stem.replace('yolo', 'YOLO') or 'Model'
    LOGGER.info(f'{model_name} summary{fused}: {n_l} layers, {n_p} parameters, {n_g} gradients{fs}')
    fused = " (fused)" if getattr(model, "is_fused", lambda: False)() else ""
    fs = f", {flops:.1f} GFLOPs" if flops else ""
    yaml_file = getattr(model, "yaml_file", "") or getattr(model, "yaml", {}).get("yaml_file", "")
    model_name = Path(yaml_file).stem.replace("yolo", "YOLO") or "Model"
    LOGGER.info(f"{model_name} summary{fused}: {n_l} layers, {n_p} parameters, {n_g} gradients{fs}")
    return n_l, n_p, n_g, flops
@ -204,37 +274,53 @@ def model_info_for_loggers(trainer):
    """
    Return model info dict with useful model information.

    Example for YOLOv8n:
        {'model/parameters': 3151904,
         'model/GFLOPs': 8.746,
         'model/speed_ONNX(ms)': 41.244,
         'model/speed_TensorRT(ms)': 3.211,
         'model/speed_PyTorch(ms)': 18.755}
    Example:
        YOLOv8n info for loggers
        ```python
        results = {'model/parameters': 3151904,
                   'model/GFLOPs': 8.746,
                   'model/speed_ONNX(ms)': 41.244,
                   'model/speed_TensorRT(ms)': 3.211,
                   'model/speed_PyTorch(ms)': 18.755}
        ```
    """
    if trainer.args.profile:  # profile ONNX and TensorRT times
        from ultralytics.utils.benchmarks import ProfileModels

        results = ProfileModels([trainer.last], device=trainer.device).profile()[0]
        results.pop('model/name')
        results.pop("model/name")
    else:  # only return PyTorch times from most recent validation
        results = {
            'model/parameters': get_num_params(trainer.model),
            'model/GFLOPs': round(get_flops(trainer.model), 3)}
        results['model/speed_PyTorch(ms)'] = round(trainer.validator.speed['inference'], 3)
            "model/parameters": get_num_params(trainer.model),
            "model/GFLOPs": round(get_flops(trainer.model), 3),
        }
        results["model/speed_PyTorch(ms)"] = round(trainer.validator.speed["inference"], 3)
    return results


def get_flops(model, imgsz=640):
    """Return a YOLO model's FLOPs."""
    if not thop:
        return 0.0  # if not installed return 0.0 GFLOPs

    try:
        model = de_parallel(model)
        p = next(model.parameters())
        stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32  # max stride
        im = torch.empty((1, p.shape[1], stride, stride), device=p.device)  # input image in BCHW format
        flops = thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1E9 * 2 if thop else 0  # stride GFLOPs
        imgsz = imgsz if isinstance(imgsz, list) else [imgsz, imgsz]  # expand if int/float
        return flops * imgsz[0] / stride * imgsz[1] / stride  # 640x640 GFLOPs
        if not isinstance(imgsz, list):
            imgsz = [imgsz, imgsz]  # expand if int/float
        try:
            # Use stride size for input tensor
            # stride = max(int(model.stride.max()), 32) if hasattr(model, "stride") else 32  # max stride
            # im = torch.empty((1, p.shape[1], stride, stride), device=p.device)  # input image in BCHW format
            # flops = thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1e9 * 2  # stride GFLOPs
            # return flops * imgsz[0] / stride * imgsz[1] / stride  # imgsz GFLOPs
            raise Exception
        except Exception:
            # Use actual image size for input tensor (i.e. required for RTDETR models)
            im = torch.empty((1, p.shape[1], *imgsz), device=p.device)  # input image in BCHW format
            return thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1e9 * 2  # imgsz GFLOPs
    except Exception:
        return 0
        return 0.0
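
get_flops now profiles a tensor at the requested imgsz instead of extrapolating from a stride-sized probe, which RT-DETR-style models require. An illustrative call on a toy module (thop must be installed; the exact number will vary):

```python
import torch.nn as nn
from ultralytics.utils.torch_utils import get_flops

model = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU())
print(f"{get_flops(model, imgsz=640):.2f} GFLOPs at 640x640")
```
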
def get_flops_with_torch_profiler(model, imgsz=640):
@ -242,11 +328,11 @@ def get_flops_with_torch_profiler(model, imgsz=640):
    if TORCH_2_0:
        model = de_parallel(model)
        p = next(model.parameters())
        stride = (max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32) * 2  # max stride
        stride = (max(int(model.stride.max()), 32) if hasattr(model, "stride") else 32) * 2  # max stride
        im = torch.zeros((1, p.shape[1], stride, stride), device=p.device)  # input image in BCHW format
        with torch.profiler.profile(with_flops=True) as prof:
            model(im)
        flops = sum(x.flops for x in prof.key_averages()) / 1E9
        flops = sum(x.flops for x in prof.key_averages()) / 1e9
        imgsz = imgsz if isinstance(imgsz, list) else [imgsz, imgsz]  # expand if int/float
        flops = flops * imgsz[0] / stride * imgsz[1] / stride  # 640x640 GFLOPs
        return flops
@ -266,13 +352,15 @@ def initialize_weights(model):
            m.inplace = True


def scale_img(img, ratio=1.0, same_shape=False, gs=32):  # img(16,3,256,416)
    # Scales img(bs,3,y,x) by ratio constrained to gs-multiple
def scale_img(img, ratio=1.0, same_shape=False, gs=32):
    """Scales and pads an image tensor of shape img(bs,3,y,x) based on given ratio and grid size gs, optionally
    retaining the original shape.
    """
    if ratio == 1.0:
        return img
    h, w = img.shape[2:]
    s = (int(h * ratio), int(w * ratio))  # new size
    img = F.interpolate(img, size=s, mode='bilinear', align_corners=False)  # resize
    img = F.interpolate(img, size=s, mode="bilinear", align_corners=False)  # resize
    if not same_shape:  # pad/crop img
        h, w = (math.ceil(x * ratio / gs) * gs for x in (h, w))
    return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447)  # value = imagenet mean
@ -288,7 +376,7 @@ def make_divisible(x, divisor):
def copy_attr(a, b, include=(), exclude=()):
    """Copies attributes from object 'b' to object 'a', with options to include/exclude certain attributes."""
    for k, v in b.__dict__.items():
        if (len(include) and k not in include) or k.startswith('_') or k in exclude:
        if (len(include) and k not in include) or k.startswith("_") or k in exclude:
            continue
        else:
            setattr(a, k, v)
@ -296,7 +384,7 @@ def copy_attr(a, b, include=(), exclude=()):

def get_latest_opset():
    """Return second-most (for maturity) recently supported ONNX opset by this version of torch."""
    return max(int(k[14:]) for k in vars(torch.onnx) if 'symbolic_opset' in k) - 1  # opset
    return max(int(k[14:]) for k in vars(torch.onnx) if "symbolic_opset" in k) - 1  # opset
def intersect_dicts(da, db, exclude=()):
@ -316,7 +404,7 @@ def de_parallel(model):

def one_cycle(y1=0.0, y2=1.0, steps=100):
    """Returns a lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf."""
    return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
    return lambda x: max((1 - math.cos(x * math.pi / steps)) / 2, 0) * (y2 - y1) + y1
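
one_cycle is the cosine half-wave used for LR scheduling; the added max(..., 0) appears to guard against the factor dipping fractionally below zero from floating-point error. Sampling it:

```python
from ultralytics.utils.torch_utils import one_cycle

lf = one_cycle(y1=1.0, y2=0.01, steps=100)  # cosine ramp from 1.0 down to 0.01
print(lf(0), lf(50), lf(100))  # 1.0, ~0.505, 0.01
```
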
def init_seeds(seed=0, deterministic=False):
@ -331,10 +419,10 @@ def init_seeds(seed=0, deterministic=False):
        if TORCH_2_0:
            torch.use_deterministic_algorithms(True, warn_only=True)  # warn if deterministic is not possible
            torch.backends.cudnn.deterministic = True
            os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
            os.environ['PYTHONHASHSEED'] = str(seed)
            os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
            os.environ["PYTHONHASHSEED"] = str(seed)
        else:
            LOGGER.warning('WARNING ⚠️ Upgrade to torch>=2.0.0 for deterministic training.')
            LOGGER.warning("WARNING ⚠️ Upgrade to torch>=2.0.0 for deterministic training.")
    else:
        torch.use_deterministic_algorithms(False)
        torch.backends.cudnn.deterministic = False
@ -369,13 +457,13 @@ class ModelEMA:
                    v += (1 - d) * msd[k].detach()
                    # assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype}, model {msd[k].dtype}'

    def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
    def update_attr(self, model, include=(), exclude=("process_group", "reducer")):
        """Updates attributes and saves stripped model with optimizer removed."""
        if self.enabled:
            copy_attr(self.ema, model, include, exclude)


def strip_optimizer(f: Union[str, Path] = 'best.pt', s: str = '') -> None:
def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "") -> None:
    """
    Strip optimizer from 'f' to finalize training, optionally save as 's'.

@ -395,32 +483,26 @@ def strip_optimizer(f: Union[str, Path] = 'best.pt', s: str = '') -> None:
        strip_optimizer(f)
        ```
    """
    # Use dill (if exists) to serialize the lambda functions where pickle does not do this
    try:
        import dill as pickle
    except ImportError:
        import pickle

    x = torch.load(f, map_location=torch.device('cpu'))
    if 'model' not in x:
        LOGGER.info(f'Skipping {f}, not a valid Ultralytics model.')
    x = torch.load(f, map_location=torch.device("cpu"))
    if "model" not in x:
        LOGGER.info(f"Skipping {f}, not a valid Ultralytics model.")
        return

    if hasattr(x['model'], 'args'):
        x['model'].args = dict(x['model'].args)  # convert from IterableSimpleNamespace to dict
    args = {**DEFAULT_CFG_DICT, **x['train_args']} if 'train_args' in x else None  # combine args
    if x.get('ema'):
        x['model'] = x['ema']  # replace model with ema
    for k in 'optimizer', 'best_fitness', 'ema', 'updates':  # keys
    if hasattr(x["model"], "args"):
        x["model"].args = dict(x["model"].args)  # convert from IterableSimpleNamespace to dict
    args = {**DEFAULT_CFG_DICT, **x["train_args"]} if "train_args" in x else None  # combine args
    if x.get("ema"):
        x["model"] = x["ema"]  # replace model with ema
    for k in "optimizer", "best_fitness", "ema", "updates":  # keys
        x[k] = None
    x['epoch'] = -1
    x['model'].half()  # to FP16
    for p in x['model'].parameters():
    x["epoch"] = -1
    x["model"].half()  # to FP16
    for p in x["model"].parameters():
        p.requires_grad = False
    x['train_args'] = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS}  # strip non-default keys
    x["train_args"] = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS}  # strip non-default keys
    # x['model'].args = x['train_args']
    torch.save(x, s or f, pickle_module=pickle)
    mb = os.path.getsize(s or f) / 1E6  # filesize
    torch.save(x, s or f)
    mb = os.path.getsize(s or f) / 1e6  # file size
    LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
@ -441,18 +523,20 @@ def profile(input, ops, n=10, device=None):
    results = []
    if not isinstance(device, torch.device):
        device = select_device(device)
    LOGGER.info(f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
                f"{'input':>24s}{'output':>24s}")
    LOGGER.info(
        f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
        f"{'input':>24s}{'output':>24s}"
    )

    for x in input if isinstance(input, list) else [input]:
        x = x.to(device)
        x.requires_grad = True
        for m in ops if isinstance(ops, list) else [ops]:
            m = m.to(device) if hasattr(m, 'to') else m  # device
            m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m
            m = m.to(device) if hasattr(m, "to") else m  # device
            m = m.half() if hasattr(m, "half") and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m
            tf, tb, t = 0, 0, [0, 0, 0]  # dt forward, backward
            try:
                flops = thop.profile(m, inputs=[x], verbose=False)[0] / 1E9 * 2 if thop else 0  # GFLOPs
                flops = thop.profile(m, inputs=[x], verbose=False)[0] / 1e9 * 2 if thop else 0  # GFLOPs
            except Exception:
                flops = 0

@ -466,13 +550,13 @@ def profile(input, ops, n=10, device=None):
                        t[2] = time_sync()
                    except Exception:  # no backward method
                        # print(e)  # for debug
                        t[2] = float('nan')
                        t[2] = float("nan")
                    tf += (t[1] - t[0]) * 1000 / n  # ms per op forward
                    tb += (t[2] - t[1]) * 1000 / n  # ms per op backward
                mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0  # (GB)
                s_in, s_out = (tuple(x.shape) if isinstance(x, torch.Tensor) else 'list' for x in (x, y))  # shapes
                mem = torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0  # (GB)
                s_in, s_out = (tuple(x.shape) if isinstance(x, torch.Tensor) else "list" for x in (x, y))  # shapes
                p = sum(x.numel() for x in m.parameters()) if isinstance(m, nn.Module) else 0  # parameters
                LOGGER.info(f'{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}')
                LOGGER.info(f"{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}")
                results.append([p, flops, mem, tf, tb, s_in, s_out])
            except Exception as e:
                LOGGER.info(e)
@ -482,25 +566,23 @@ def profile(input, ops, n=10, device=None):
class EarlyStopping:
    """
    Early stopping class that stops training when a specified number of epochs have passed without improvement.
    """
    """Early stopping class that stops training when a specified number of epochs have passed without improvement."""

    def __init__(self, patience=50):
        """
        Initialize early stopping object
        Initialize early stopping object.

        Args:
            patience (int, optional): Number of epochs to wait after fitness stops improving before stopping.
        """
        self.best_fitness = 0.0  # i.e. mAP
        self.best_epoch = 0
        self.patience = patience or float('inf')  # epochs to wait after fitness stops improving to stop
        self.patience = patience or float("inf")  # epochs to wait after fitness stops improving to stop
        self.possible_stop = False  # possible stop may occur next epoch

    def __call__(self, epoch, fitness):
        """
        Check whether to stop training
        Check whether to stop training.

        Args:
            epoch (int): Current epoch of training
@ -519,8 +601,10 @@ class EarlyStopping:
        self.possible_stop = delta >= (self.patience - 1)  # possible stop may occur next epoch
        stop = delta >= self.patience  # stop training if patience exceeded
        if stop:
            LOGGER.info(f'Stopping training early as no improvement observed in last {self.patience} epochs. '
                        f'Best results observed at epoch {self.best_epoch}, best model saved as best.pt.\n'
                        f'To update EarlyStopping(patience={self.patience}) pass a new patience value, '
                        f'i.e. `patience=300` or use `patience=0` to disable EarlyStopping.')
            LOGGER.info(
                f"Stopping training early as no improvement observed in last {self.patience} epochs. "
                f"Best results observed at epoch {self.best_epoch}, best model saved as best.pt.\n"
                f"To update EarlyStopping(patience={self.patience}) pass a new patience value, "
                f"i.e. `patience=300` or use `patience=0` to disable EarlyStopping."
            )
        return stop
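
In the trainer this object is called once per epoch with the current fitness (e.g. mAP) and fires once patience epochs pass without a new best. A standalone sketch with an illustrative fitness curve:

```python
from ultralytics.utils.torch_utils import EarlyStopping

stopper = EarlyStopping(patience=3)
for epoch, fitness in enumerate([0.50, 0.52, 0.51, 0.50, 0.49]):  # best at epoch 1, then decay
    if stopper(epoch, fitness):
        print(f"stopped at epoch {epoch}")  # stopped at epoch 4, i.e. patience epochs after the best
        break
```
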
92 ultralytics/utils/triton.py Normal file
@ -0,0 +1,92 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

from typing import List
from urllib.parse import urlsplit

import numpy as np


class TritonRemoteModel:
    """
    Client for interacting with a remote Triton Inference Server model.

    Attributes:
        endpoint (str): The name of the model on the Triton server.
        url (str): The URL of the Triton server.
        triton_client: The Triton client (either HTTP or gRPC).
        InferInput: The input class for the Triton client.
        InferRequestedOutput: The output request class for the Triton client.
        input_formats (List[str]): The data types of the model inputs.
        np_input_formats (List[type]): The numpy data types of the model inputs.
        input_names (List[str]): The names of the model inputs.
        output_names (List[str]): The names of the model outputs.
    """

    def __init__(self, url: str, endpoint: str = "", scheme: str = ""):
        """
        Initialize the TritonRemoteModel.

        Arguments may be provided individually or parsed from a collective 'url' argument of the form
            <scheme>://<netloc>/<endpoint>/<task_name>

        Args:
            url (str): The URL of the Triton server.
            endpoint (str): The name of the model on the Triton server.
            scheme (str): The communication scheme ('http' or 'grpc').
        """
        if not endpoint and not scheme:  # Parse all args from URL string
            splits = urlsplit(url)
            endpoint = splits.path.strip("/").split("/")[0]
            scheme = splits.scheme
            url = splits.netloc

        self.endpoint = endpoint
        self.url = url

        # Choose the Triton client based on the communication scheme
        if scheme == "http":
            import tritonclient.http as client  # noqa

            self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False)
            config = self.triton_client.get_model_config(endpoint)
        else:
            import tritonclient.grpc as client  # noqa

            self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False)
            config = self.triton_client.get_model_config(endpoint, as_json=True)["config"]

        # Sort output names alphabetically, i.e. 'output0', 'output1', etc.
        config["output"] = sorted(config["output"], key=lambda x: x.get("name"))

        # Define model attributes
        type_map = {"TYPE_FP32": np.float32, "TYPE_FP16": np.float16, "TYPE_UINT8": np.uint8}
        self.InferRequestedOutput = client.InferRequestedOutput
        self.InferInput = client.InferInput
        self.input_formats = [x["data_type"] for x in config["input"]]
        self.np_input_formats = [type_map[x] for x in self.input_formats]
        self.input_names = [x["name"] for x in config["input"]]
        self.output_names = [x["name"] for x in config["output"]]

    def __call__(self, *inputs: np.ndarray) -> List[np.ndarray]:
        """
        Call the model with the given inputs.

        Args:
            *inputs (List[np.ndarray]): Input data to the model.

        Returns:
            (List[np.ndarray]): Model outputs.
        """
        infer_inputs = []
        input_format = inputs[0].dtype
        for i, x in enumerate(inputs):
            if x.dtype != self.np_input_formats[i]:
                x = x.astype(self.np_input_formats[i])
            infer_input = self.InferInput(self.input_names[i], [*x.shape], self.input_formats[i].replace("TYPE_", ""))
            infer_input.set_data_from_numpy(x)
            infer_inputs.append(infer_input)

        infer_outputs = [self.InferRequestedOutput(output_name) for output_name in self.output_names]
        outputs = self.triton_client.infer(model_name=self.endpoint, inputs=infer_inputs, outputs=infer_outputs)

        return [outputs.as_numpy(output_name).astype(input_format) for output_name in self.output_names]
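
A hypothetical round trip against a local Triton server over HTTP; the URL, port, and model name 'yolov8n' are illustrative, and tritonclient must be installed:

```python
import numpy as np
from ultralytics.utils.triton import TritonRemoteModel

# Equivalent collective form: TritonRemoteModel("http://localhost:8000/yolov8n")
model = TritonRemoteModel(url="localhost:8000", endpoint="yolov8n", scheme="http")
outputs = model(np.random.rand(1, 3, 640, 640).astype(np.float32))
print([o.shape for o in outputs])  # one numpy array per model output
```
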
ultralytics/utils/tuner.py
@ -2,16 +2,13 @@

import subprocess

from ultralytics.cfg import TASK2DATA, TASK2METRIC
from ultralytics.utils import DEFAULT_CFG_DICT, LOGGER, NUM_THREADS
from ultralytics.cfg import TASK2DATA, TASK2METRIC, get_save_dir
from ultralytics.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, NUM_THREADS, checks


def run_ray_tune(model,
                 space: dict = None,
                 grace_period: int = 10,
                 gpu_per_trial: int = None,
                 max_samples: int = 10,
                 **train_args):
def run_ray_tune(
    model, space: dict = None, grace_period: int = 10, gpu_per_trial: int = None, max_samples: int = 10, **train_args
):
    """
    Runs hyperparameter tuning using Ray Tune.

@ -37,49 +34,59 @@ def run_ray_tune(model,
        result_grid = model.tune(data='coco8.yaml', use_ray=True)
        ```
    """

    LOGGER.info("💡 Learn about RayTune at https://docs.ultralytics.com/integrations/ray-tune")
    if train_args is None:
        train_args = {}

    try:
        subprocess.run('pip install ray[tune]'.split(), check=True)
        subprocess.run("pip install ray[tune]<=2.9.3".split(), check=True)  # do not add single quotes here

        import ray
        from ray import tune
        from ray.air import RunConfig
        from ray.air.integrations.wandb import WandbLoggerCallback
        from ray.tune.schedulers import ASHAScheduler
    except ImportError:
        raise ModuleNotFoundError('Tuning hyperparameters requires Ray Tune. Install with: pip install "ray[tune]"')
        raise ModuleNotFoundError('Ray Tune required but not found. To install run: pip install "ray[tune]<=2.9.3"')

    try:
        import wandb

        assert hasattr(wandb, '__version__')
        assert hasattr(wandb, "__version__")
    except (ImportError, AssertionError):
        wandb = False

    checks.check_version(ray.__version__, "<=2.9.3", "ray")
    default_space = {
        # 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),
        'lr0': tune.uniform(1e-5, 1e-1),
        'lrf': tune.uniform(0.01, 1.0),  # final OneCycleLR learning rate (lr0 * lrf)
        'momentum': tune.uniform(0.6, 0.98),  # SGD momentum/Adam beta1
        'weight_decay': tune.uniform(0.0, 0.001),  # optimizer weight decay 5e-4
        'warmup_epochs': tune.uniform(0.0, 5.0),  # warmup epochs (fractions ok)
        'warmup_momentum': tune.uniform(0.0, 0.95),  # warmup initial momentum
        'box': tune.uniform(0.02, 0.2),  # box loss gain
        'cls': tune.uniform(0.2, 4.0),  # cls loss gain (scale with pixels)
        'hsv_h': tune.uniform(0.0, 0.1),  # image HSV-Hue augmentation (fraction)
        'hsv_s': tune.uniform(0.0, 0.9),  # image HSV-Saturation augmentation (fraction)
        'hsv_v': tune.uniform(0.0, 0.9),  # image HSV-Value augmentation (fraction)
        'degrees': tune.uniform(0.0, 45.0),  # image rotation (+/- deg)
        'translate': tune.uniform(0.0, 0.9),  # image translation (+/- fraction)
        'scale': tune.uniform(0.0, 0.9),  # image scale (+/- gain)
        'shear': tune.uniform(0.0, 10.0),  # image shear (+/- deg)
        'perspective': tune.uniform(0.0, 0.001),  # image perspective (+/- fraction), range 0-0.001
        'flipud': tune.uniform(0.0, 1.0),  # image flip up-down (probability)
        'fliplr': tune.uniform(0.0, 1.0),  # image flip left-right (probability)
        'mosaic': tune.uniform(0.0, 1.0),  # image mixup (probability)
        'mixup': tune.uniform(0.0, 1.0),  # image mixup (probability)
        'copy_paste': tune.uniform(0.0, 1.0)}  # segment copy-paste (probability)
        "lr0": tune.uniform(1e-5, 1e-1),
        "lrf": tune.uniform(0.01, 1.0),  # final OneCycleLR learning rate (lr0 * lrf)
        "momentum": tune.uniform(0.6, 0.98),  # SGD momentum/Adam beta1
        "weight_decay": tune.uniform(0.0, 0.001),  # optimizer weight decay 5e-4
        "warmup_epochs": tune.uniform(0.0, 5.0),  # warmup epochs (fractions ok)
        "warmup_momentum": tune.uniform(0.0, 0.95),  # warmup initial momentum
        "box": tune.uniform(0.02, 0.2),  # box loss gain
        "cls": tune.uniform(0.2, 4.0),  # cls loss gain (scale with pixels)
        "hsv_h": tune.uniform(0.0, 0.1),  # image HSV-Hue augmentation (fraction)
        "hsv_s": tune.uniform(0.0, 0.9),  # image HSV-Saturation augmentation (fraction)
        "hsv_v": tune.uniform(0.0, 0.9),  # image HSV-Value augmentation (fraction)
        "degrees": tune.uniform(0.0, 45.0),  # image rotation (+/- deg)
        "translate": tune.uniform(0.0, 0.9),  # image translation (+/- fraction)
        "scale": tune.uniform(0.0, 0.9),  # image scale (+/- gain)
        "shear": tune.uniform(0.0, 10.0),  # image shear (+/- deg)
        "perspective": tune.uniform(0.0, 0.001),  # image perspective (+/- fraction), range 0-0.001
        "flipud": tune.uniform(0.0, 1.0),  # image flip up-down (probability)
        "fliplr": tune.uniform(0.0, 1.0),  # image flip left-right (probability)
        "bgr": tune.uniform(0.0, 1.0),  # image channel BGR (probability)
        "mosaic": tune.uniform(0.0, 1.0),  # image mixup (probability)
        "mixup": tune.uniform(0.0, 1.0),  # image mixup (probability)
        "copy_paste": tune.uniform(0.0, 1.0),  # segment copy-paste (probability)
    }

    # Put the model in ray store
    task = model.task
    model_in_store = ray.put(model)

    def _tune(config):
        """
@ -89,42 +96,50 @@ def run_ray_tune(model,
            config (dict): A dictionary of hyperparameters to use for training.

        Returns:
            None.
            None
        """
        model._reset_callbacks()
        model_to_train = ray.get(model_in_store)  # get the model from ray store for tuning
        model_to_train.reset_callbacks()
        config.update(train_args)
        model.train(**config)
        results = model_to_train.train(**config)
        return results.results_dict

    # Get search space
    if not space:
        space = default_space
        LOGGER.warning('WARNING ⚠️ search space not provided, using default search space.')
        LOGGER.warning("WARNING ⚠️ search space not provided, using default search space.")

    # Get dataset
    data = train_args.get('data', TASK2DATA[model.task])
    space['data'] = data
    if 'data' not in train_args:
    data = train_args.get("data", TASK2DATA[task])
    space["data"] = data
    if "data" not in train_args:
        LOGGER.warning(f'WARNING ⚠️ data not provided, using default "data={data}".')

    # Define the trainable function with allocated resources
    trainable_with_resources = tune.with_resources(_tune, {'cpu': NUM_THREADS, 'gpu': gpu_per_trial or 0})
    trainable_with_resources = tune.with_resources(_tune, {"cpu": NUM_THREADS, "gpu": gpu_per_trial or 0})

    # Define the ASHA scheduler for hyperparameter search
    asha_scheduler = ASHAScheduler(time_attr='epoch',
                                   metric=TASK2METRIC[model.task],
                                   mode='max',
                                   max_t=train_args.get('epochs') or DEFAULT_CFG_DICT['epochs'] or 100,
                                   grace_period=grace_period,
                                   reduction_factor=3)
    asha_scheduler = ASHAScheduler(
        time_attr="epoch",
        metric=TASK2METRIC[task],
        mode="max",
        max_t=train_args.get("epochs") or DEFAULT_CFG_DICT["epochs"] or 100,
        grace_period=grace_period,
        reduction_factor=3,
    )

    # Define the callbacks for the hyperparameter search
    tuner_callbacks = [WandbLoggerCallback(project='YOLOv8-tune')] if wandb else []
    tuner_callbacks = [WandbLoggerCallback(project="YOLOv8-tune")] if wandb else []

    # Create the Ray Tune hyperparameter search tuner
    tuner = tune.Tuner(trainable_with_resources,
                       param_space=space,
                       tune_config=tune.TuneConfig(scheduler=asha_scheduler, num_samples=max_samples),
                       run_config=RunConfig(callbacks=tuner_callbacks, storage_path='./runs/tune'))
    tune_dir = get_save_dir(DEFAULT_CFG, name="tune").resolve()  # must be absolute dir
    tune_dir.mkdir(parents=True, exist_ok=True)
    tuner = tune.Tuner(
        trainable_with_resources,
        param_space=space,
        tune_config=tune.TuneConfig(scheduler=asha_scheduler, num_samples=max_samples),
        run_config=RunConfig(callbacks=tuner_callbacks, storage_path=tune_dir),
    )

    # Run the hyperparameter search
    tuner.fit()
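
End to end, this helper is usually reached through model.tune(use_ray=True) as in the docstring above; a direct call looks like the sketch below, where the dataset, epochs, and sample counts are illustrative:

```python
from ultralytics import YOLO
from ultralytics.utils.tuner import run_ray_tune

model = YOLO("yolov8n.pt")
result_grid = run_ray_tune(model, grace_period=10, gpu_per_trial=0, max_samples=10, epochs=30, data="coco8.yaml")
```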