Add YOLOv10 and modify pipeline

王庆刚
2025-03-28 13:19:54 +08:00
parent 183299c06b
commit 798c596acc
471 changed files with 19109 additions and 7342 deletions

View File

@ -2,4 +2,4 @@
from .base import add_integration_callbacks, default_callbacks, get_default_callbacks
__all__ = 'add_integration_callbacks', 'default_callbacks', 'get_default_callbacks'
__all__ = "add_integration_callbacks", "default_callbacks", "get_default_callbacks"

View File

@ -1,11 +1,10 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
Base callbacks
"""
"""Base callbacks."""
from collections import defaultdict
from copy import deepcopy
# Trainer callbacks ----------------------------------------------------------------------------------------------------
@ -145,37 +144,35 @@ def on_export_end(exporter):
default_callbacks = {
# Run in trainer
'on_pretrain_routine_start': [on_pretrain_routine_start],
'on_pretrain_routine_end': [on_pretrain_routine_end],
'on_train_start': [on_train_start],
'on_train_epoch_start': [on_train_epoch_start],
'on_train_batch_start': [on_train_batch_start],
'optimizer_step': [optimizer_step],
'on_before_zero_grad': [on_before_zero_grad],
'on_train_batch_end': [on_train_batch_end],
'on_train_epoch_end': [on_train_epoch_end],
'on_fit_epoch_end': [on_fit_epoch_end], # fit = train + val
'on_model_save': [on_model_save],
'on_train_end': [on_train_end],
'on_params_update': [on_params_update],
'teardown': [teardown],
"on_pretrain_routine_start": [on_pretrain_routine_start],
"on_pretrain_routine_end": [on_pretrain_routine_end],
"on_train_start": [on_train_start],
"on_train_epoch_start": [on_train_epoch_start],
"on_train_batch_start": [on_train_batch_start],
"optimizer_step": [optimizer_step],
"on_before_zero_grad": [on_before_zero_grad],
"on_train_batch_end": [on_train_batch_end],
"on_train_epoch_end": [on_train_epoch_end],
"on_fit_epoch_end": [on_fit_epoch_end], # fit = train + val
"on_model_save": [on_model_save],
"on_train_end": [on_train_end],
"on_params_update": [on_params_update],
"teardown": [teardown],
# Run in validator
'on_val_start': [on_val_start],
'on_val_batch_start': [on_val_batch_start],
'on_val_batch_end': [on_val_batch_end],
'on_val_end': [on_val_end],
"on_val_start": [on_val_start],
"on_val_batch_start": [on_val_batch_start],
"on_val_batch_end": [on_val_batch_end],
"on_val_end": [on_val_end],
# Run in predictor
'on_predict_start': [on_predict_start],
'on_predict_batch_start': [on_predict_batch_start],
'on_predict_postprocess_end': [on_predict_postprocess_end],
'on_predict_batch_end': [on_predict_batch_end],
'on_predict_end': [on_predict_end],
"on_predict_start": [on_predict_start],
"on_predict_batch_start": [on_predict_batch_start],
"on_predict_postprocess_end": [on_predict_postprocess_end],
"on_predict_batch_end": [on_predict_batch_end],
"on_predict_end": [on_predict_end],
# Run in exporter
'on_export_start': [on_export_start],
'on_export_end': [on_export_end]}
"on_export_start": [on_export_start],
"on_export_end": [on_export_end],
}
def get_default_callbacks():
@ -199,10 +196,11 @@ def add_integration_callbacks(instance):
# Load HUB callbacks
from .hub import callbacks as hub_cb
callbacks_list = [hub_cb]
# Load training callbacks
if 'Trainer' in instance.__class__.__name__:
if "Trainer" in instance.__class__.__name__:
from .clearml import callbacks as clear_cb
from .comet import callbacks as comet_cb
from .dvc import callbacks as dvc_cb
@ -211,12 +209,8 @@ def add_integration_callbacks(instance):
from .raytune import callbacks as tune_cb
from .tensorboard import callbacks as tb_cb
from .wb import callbacks as wb_cb
callbacks_list.extend([clear_cb, comet_cb, dvc_cb, mlflow_cb, neptune_cb, tune_cb, tb_cb, wb_cb])
# Load export callbacks (patch to avoid CoreML protobuf error)
if 'Exporter' in instance.__class__.__name__:
from .tensorboard import callbacks as tb_cb
callbacks_list.append(tb_cb)
callbacks_list.extend([clear_cb, comet_cb, dvc_cb, mlflow_cb, neptune_cb, tune_cb, tb_cb, wb_cb])
# Add the callbacks to the callbacks dictionary
for callbacks in callbacks_list:
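
For orientation, here is a minimal sketch of how these default hooks are typically consumed from the public API. It assumes the standard YOLO entry point and its documented add_callback method; the hook body, weights file, and dataset name are illustrative only.

from ultralytics import YOLO

def log_epoch(trainer):
    # Illustrative custom hook: fires after every training epoch alongside the defaults above
    print(f"finished epoch {trainer.epoch}")

model = YOLO("yolov8n.pt")
model.add_callback("on_train_epoch_end", log_epoch)  # appended next to the on_train_epoch_end defaults
model.train(data="coco128.yaml", epochs=3)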

View File

@ -4,19 +4,19 @@ from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING
try:
assert not TESTS_RUNNING # do not log pytest
assert SETTINGS['clearml'] is True # verify integration is enabled
assert SETTINGS["clearml"] is True # verify integration is enabled
import clearml
from clearml import Task
from clearml.binding.frameworks.pytorch_bind import PatchPyTorchModelIO
from clearml.binding.matplotlib_bind import PatchedMatplotlib
assert hasattr(clearml, '__version__') # verify package is not directory
assert hasattr(clearml, "__version__") # verify package is not directory
except (ImportError, AssertionError):
clearml = None
def _log_debug_samples(files, title='Debug Samples') -> None:
def _log_debug_samples(files, title="Debug Samples") -> None:
"""
Log files (images) as debug samples in the ClearML task.
@ -29,12 +29,11 @@ def _log_debug_samples(files, title='Debug Samples') -> None:
if task := Task.current_task():
for f in files:
if f.exists():
it = re.search(r'_batch(\d+)', f.name)
it = re.search(r"_batch(\d+)", f.name)
iteration = int(it.groups()[0]) if it else 0
task.get_logger().report_image(title=title,
series=f.name.replace(it.group(), ''),
local_path=str(f),
iteration=iteration)
task.get_logger().report_image(
title=title, series=f.name.replace(it.group(), ""), local_path=str(f), iteration=iteration
)
def _log_plot(title, plot_path) -> None:
@ -50,13 +49,12 @@ def _log_plot(title, plot_path) -> None:
img = mpimg.imread(plot_path)
fig = plt.figure()
ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect='auto', xticks=[], yticks=[]) # no ticks
ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect="auto", xticks=[], yticks=[]) # no ticks
ax.imshow(img)
Task.current_task().get_logger().report_matplotlib_figure(title=title,
series='',
figure=fig,
report_interactive=False)
Task.current_task().get_logger().report_matplotlib_figure(
title=title, series="", figure=fig, report_interactive=False
)
def on_pretrain_routine_start(trainer):
@ -68,19 +66,21 @@ def on_pretrain_routine_start(trainer):
PatchPyTorchModelIO.update_current_task(None)
PatchedMatplotlib.update_current_task(None)
else:
task = Task.init(project_name=trainer.args.project or 'YOLOv8',
task_name=trainer.args.name,
tags=['YOLOv8'],
output_uri=True,
reuse_last_task_id=False,
auto_connect_frameworks={
'pytorch': False,
'matplotlib': False})
LOGGER.warning('ClearML Initialized a new task. If you want to run remotely, '
'please add clearml-init and connect your arguments before initializing YOLO.')
task.connect(vars(trainer.args), name='General')
task = Task.init(
project_name=trainer.args.project or "YOLOv8",
task_name=trainer.args.name,
tags=["YOLOv8"],
output_uri=True,
reuse_last_task_id=False,
auto_connect_frameworks={"pytorch": False, "matplotlib": False},
)
LOGGER.warning(
"ClearML Initialized a new task. If you want to run remotely, "
"please add clearml-init and connect your arguments before initializing YOLO."
)
task.connect(vars(trainer.args), name="General")
except Exception as e:
LOGGER.warning(f'WARNING ⚠️ ClearML installed but not initialized correctly, not logging this run. {e}')
LOGGER.warning(f"WARNING ⚠️ ClearML installed but not initialized correctly, not logging this run. {e}")
def on_train_epoch_end(trainer):
@ -88,22 +88,26 @@ def on_train_epoch_end(trainer):
if task := Task.current_task():
# Log debug samples
if trainer.epoch == 1:
_log_debug_samples(sorted(trainer.save_dir.glob('train_batch*.jpg')), 'Mosaic')
_log_debug_samples(sorted(trainer.save_dir.glob("train_batch*.jpg")), "Mosaic")
# Report the current training progress
for k, v in trainer.validator.metrics.results_dict.items():
task.get_logger().report_scalar('train', k, v, iteration=trainer.epoch)
for k, v in trainer.label_loss_items(trainer.tloss, prefix="train").items():
task.get_logger().report_scalar("train", k, v, iteration=trainer.epoch)
for k, v in trainer.lr.items():
task.get_logger().report_scalar("lr", k, v, iteration=trainer.epoch)
def on_fit_epoch_end(trainer):
"""Reports model information to logger at the end of an epoch."""
if task := Task.current_task():
# You should have access to the validation bboxes under jdict
task.get_logger().report_scalar(title='Epoch Time',
series='Epoch Time',
value=trainer.epoch_time,
iteration=trainer.epoch)
task.get_logger().report_scalar(
title="Epoch Time", series="Epoch Time", value=trainer.epoch_time, iteration=trainer.epoch
)
for k, v in trainer.metrics.items():
task.get_logger().report_scalar("val", k, v, iteration=trainer.epoch)
if trainer.epoch == 0:
from ultralytics.utils.torch_utils import model_info_for_loggers
for k, v in model_info_for_loggers(trainer).items():
task.get_logger().report_single_value(k, v)
@ -112,7 +116,7 @@ def on_val_end(validator):
"""Logs validation results including labels and predictions."""
if Task.current_task():
# Log val_labels and val_pred
_log_debug_samples(sorted(validator.save_dir.glob('val*.jpg')), 'Validation')
_log_debug_samples(sorted(validator.save_dir.glob("val*.jpg")), "Validation")
def on_train_end(trainer):
@ -120,8 +124,11 @@ def on_train_end(trainer):
if task := Task.current_task():
# Log final results, CM matrix + PR plots
files = [
'results.png', 'confusion_matrix.png', 'confusion_matrix_normalized.png',
*(f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R'))]
"results.png",
"confusion_matrix.png",
"confusion_matrix_normalized.png",
*(f"{x}_curve.png" for x in ("F1", "PR", "P", "R")),
]
files = [(trainer.save_dir / f) for f in files if (trainer.save_dir / f).exists()] # filter
for f in files:
_log_plot(title=f.stem, plot_path=f)
@ -132,9 +139,14 @@ def on_train_end(trainer):
task.update_output_model(model_path=str(trainer.best), model_name=trainer.args.name, auto_delete_file=False)
callbacks = {
'on_pretrain_routine_start': on_pretrain_routine_start,
'on_train_epoch_end': on_train_epoch_end,
'on_fit_epoch_end': on_fit_epoch_end,
'on_val_end': on_val_end,
'on_train_end': on_train_end} if clearml else {}
callbacks = (
{
"on_pretrain_routine_start": on_pretrain_routine_start,
"on_train_epoch_end": on_train_epoch_end,
"on_fit_epoch_end": on_fit_epoch_end,
"on_val_end": on_val_end,
"on_train_end": on_train_end,
}
if clearml
else {}
)
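
The assert SETTINGS["clearml"] is True guard above makes each integration opt-in through the Ultralytics settings file. As a rough sketch (assuming the SettingsManager update() behaviour used elsewhere in the codebase), the flag can be toggled programmatically; the same thing can be done from the CLI in the style referenced elsewhere in this commit, e.g. "yolo settings dvc=False".

from ultralytics.utils import SETTINGS

SETTINGS.update({"clearml": True})  # enable the ClearML callbacks for future runs (assumed persistent)
print(SETTINGS["clearml"])          # when False, the callbacks dict above resolves to {}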

View File

@ -4,20 +4,20 @@ from ultralytics.utils import LOGGER, RANK, SETTINGS, TESTS_RUNNING, ops
try:
assert not TESTS_RUNNING # do not log pytest
assert SETTINGS['comet'] is True # verify integration is enabled
assert SETTINGS["comet"] is True # verify integration is enabled
import comet_ml
assert hasattr(comet_ml, '__version__') # verify package is not directory
assert hasattr(comet_ml, "__version__") # verify package is not directory
import os
from pathlib import Path
# Ensures certain logging functions only run for supported tasks
COMET_SUPPORTED_TASKS = ['detect']
COMET_SUPPORTED_TASKS = ["detect"]
# Names of plots created by YOLOv8 that are logged to Comet
EVALUATION_PLOT_NAMES = 'F1_curve', 'P_curve', 'R_curve', 'PR_curve', 'confusion_matrix'
LABEL_PLOT_NAMES = 'labels', 'labels_correlogram'
EVALUATION_PLOT_NAMES = "F1_curve", "P_curve", "R_curve", "PR_curve", "confusion_matrix"
LABEL_PLOT_NAMES = "labels", "labels_correlogram"
_comet_image_prediction_count = 0
@ -26,37 +26,44 @@ except (ImportError, AssertionError):
def _get_comet_mode():
return os.getenv('COMET_MODE', 'online')
"""Returns the mode of comet set in the environment variables, defaults to 'online' if not set."""
return os.getenv("COMET_MODE", "online")
def _get_comet_model_name():
return os.getenv('COMET_MODEL_NAME', 'YOLOv8')
"""Returns the model name for Comet from the environment variable 'COMET_MODEL_NAME' or defaults to 'YOLOv8'."""
return os.getenv("COMET_MODEL_NAME", "YOLOv8")
def _get_eval_batch_logging_interval():
return int(os.getenv('COMET_EVAL_BATCH_LOGGING_INTERVAL', 1))
"""Get the evaluation batch logging interval from environment variable or use default value 1."""
return int(os.getenv("COMET_EVAL_BATCH_LOGGING_INTERVAL", 1))
def _get_max_image_predictions_to_log():
return int(os.getenv('COMET_MAX_IMAGE_PREDICTIONS', 100))
"""Get the maximum number of image predictions to log from the environment variables."""
return int(os.getenv("COMET_MAX_IMAGE_PREDICTIONS", 100))
def _scale_confidence_score(score):
scale = float(os.getenv('COMET_MAX_CONFIDENCE_SCORE', 100.0))
"""Scales the given confidence score by a factor specified in an environment variable."""
scale = float(os.getenv("COMET_MAX_CONFIDENCE_SCORE", 100.0))
return score * scale
def _should_log_confusion_matrix():
return os.getenv('COMET_EVAL_LOG_CONFUSION_MATRIX', 'false').lower() == 'true'
"""Determines if the confusion matrix should be logged based on the environment variable settings."""
return os.getenv("COMET_EVAL_LOG_CONFUSION_MATRIX", "false").lower() == "true"
def _should_log_image_predictions():
return os.getenv('COMET_EVAL_LOG_IMAGE_PREDICTIONS', 'true').lower() == 'true'
"""Determines whether to log image predictions based on a specified environment variable."""
return os.getenv("COMET_EVAL_LOG_IMAGE_PREDICTIONS", "true").lower() == "true"
def _get_experiment_type(mode, project_name):
"""Return an experiment based on mode and project name."""
if mode == 'offline':
if mode == "offline":
return comet_ml.OfflineExperiment(project_name=project_name)
return comet_ml.Experiment(project_name=project_name)
@ -68,18 +75,21 @@ def _create_experiment(args):
return
try:
comet_mode = _get_comet_mode()
_project_name = os.getenv('COMET_PROJECT_NAME', args.project)
_project_name = os.getenv("COMET_PROJECT_NAME", args.project)
experiment = _get_experiment_type(comet_mode, _project_name)
experiment.log_parameters(vars(args))
experiment.log_others({
'eval_batch_logging_interval': _get_eval_batch_logging_interval(),
'log_confusion_matrix_on_eval': _should_log_confusion_matrix(),
'log_image_predictions': _should_log_image_predictions(),
'max_image_predictions': _get_max_image_predictions_to_log(), })
experiment.log_other('Created from', 'yolov8')
experiment.log_others(
{
"eval_batch_logging_interval": _get_eval_batch_logging_interval(),
"log_confusion_matrix_on_eval": _should_log_confusion_matrix(),
"log_image_predictions": _should_log_image_predictions(),
"max_image_predictions": _get_max_image_predictions_to_log(),
}
)
experiment.log_other("Created from", "yolov8")
except Exception as e:
LOGGER.warning(f'WARNING ⚠️ Comet installed but not initialized correctly, not logging this run. {e}')
LOGGER.warning(f"WARNING ⚠️ Comet installed but not initialized correctly, not logging this run. {e}")
def _fetch_trainer_metadata(trainer):
@ -95,18 +105,14 @@ def _fetch_trainer_metadata(trainer):
save_interval = curr_epoch % save_period == 0
save_assets = save and save_period > 0 and save_interval and not final_epoch
return dict(
curr_epoch=curr_epoch,
curr_step=curr_step,
save_assets=save_assets,
final_epoch=final_epoch,
)
return dict(curr_epoch=curr_epoch, curr_step=curr_step, save_assets=save_assets, final_epoch=final_epoch)
def _scale_bounding_box_to_original_image_shape(box, resized_image_shape, original_image_shape, ratio_pad):
"""YOLOv8 resizes images during training and the label values
are normalized based on this resized shape. This function rescales the
bounding box labels to the original image shape.
"""
YOLOv8 resizes images during training and the label values are normalized based on this resized shape.
This function rescales the bounding box labels to the original image shape.
"""
resized_image_height, resized_image_width = resized_image_shape
@ -126,29 +132,32 @@ def _scale_bounding_box_to_original_image_shape(box, resized_image_shape, origin
def _format_ground_truth_annotations_for_detection(img_idx, image_path, batch, class_name_map=None):
"""Format ground truth annotations for detection."""
indices = batch['batch_idx'] == img_idx
bboxes = batch['bboxes'][indices]
indices = batch["batch_idx"] == img_idx
bboxes = batch["bboxes"][indices]
if len(bboxes) == 0:
LOGGER.debug(f'COMET WARNING: Image: {image_path} has no bounding boxes labels')
LOGGER.debug(f"COMET WARNING: Image: {image_path} has no bounding boxes labels")
return None
cls_labels = batch['cls'][indices].squeeze(1).tolist()
cls_labels = batch["cls"][indices].squeeze(1).tolist()
if class_name_map:
cls_labels = [str(class_name_map[label]) for label in cls_labels]
original_image_shape = batch['ori_shape'][img_idx]
resized_image_shape = batch['resized_shape'][img_idx]
ratio_pad = batch['ratio_pad'][img_idx]
original_image_shape = batch["ori_shape"][img_idx]
resized_image_shape = batch["resized_shape"][img_idx]
ratio_pad = batch["ratio_pad"][img_idx]
data = []
for box, label in zip(bboxes, cls_labels):
box = _scale_bounding_box_to_original_image_shape(box, resized_image_shape, original_image_shape, ratio_pad)
data.append({
'boxes': [box],
'label': f'gt_{label}',
'score': _scale_confidence_score(1.0), })
data.append(
{
"boxes": [box],
"label": f"gt_{label}",
"score": _scale_confidence_score(1.0),
}
)
return {'name': 'ground_truth', 'data': data}
return {"name": "ground_truth", "data": data}
def _format_prediction_annotations_for_detection(image_path, metadata, class_label_map=None):
@ -158,31 +167,34 @@ def _format_prediction_annotations_for_detection(image_path, metadata, class_lab
predictions = metadata.get(image_id)
if not predictions:
LOGGER.debug(f'COMET WARNING: Image: {image_path} has no bounding boxes predictions')
LOGGER.debug(f"COMET WARNING: Image: {image_path} has no bounding boxes predictions")
return None
data = []
for prediction in predictions:
boxes = prediction['bbox']
score = _scale_confidence_score(prediction['score'])
cls_label = prediction['category_id']
boxes = prediction["bbox"]
score = _scale_confidence_score(prediction["score"])
cls_label = prediction["category_id"]
if class_label_map:
cls_label = str(class_label_map[cls_label])
data.append({'boxes': [boxes], 'label': cls_label, 'score': score})
data.append({"boxes": [boxes], "label": cls_label, "score": score})
return {'name': 'prediction', 'data': data}
return {"name": "prediction", "data": data}
def _fetch_annotations(img_idx, image_path, batch, prediction_metadata_map, class_label_map):
"""Join the ground truth and prediction annotations if they exist."""
ground_truth_annotations = _format_ground_truth_annotations_for_detection(img_idx, image_path, batch,
class_label_map)
prediction_annotations = _format_prediction_annotations_for_detection(image_path, prediction_metadata_map,
class_label_map)
ground_truth_annotations = _format_ground_truth_annotations_for_detection(
img_idx, image_path, batch, class_label_map
)
prediction_annotations = _format_prediction_annotations_for_detection(
image_path, prediction_metadata_map, class_label_map
)
annotations = [
annotation for annotation in [ground_truth_annotations, prediction_annotations] if annotation is not None]
annotation for annotation in [ground_truth_annotations, prediction_annotations] if annotation is not None
]
return [annotations] if annotations else None
@ -190,8 +202,8 @@ def _create_prediction_metadata_map(model_predictions):
"""Create metadata map for model predictions by groupings them based on image ID."""
pred_metadata_map = {}
for prediction in model_predictions:
pred_metadata_map.setdefault(prediction['image_id'], [])
pred_metadata_map[prediction['image_id']].append(prediction)
pred_metadata_map.setdefault(prediction["image_id"], [])
pred_metadata_map[prediction["image_id"]].append(prediction)
return pred_metadata_map
@ -199,13 +211,9 @@ def _create_prediction_metadata_map(model_predictions):
def _log_confusion_matrix(experiment, trainer, curr_step, curr_epoch):
"""Log the confusion matrix to Comet experiment."""
conf_mat = trainer.validator.confusion_matrix.matrix
names = list(trainer.data['names'].values()) + ['background']
names = list(trainer.data["names"].values()) + ["background"]
experiment.log_confusion_matrix(
matrix=conf_mat,
labels=names,
max_categories=len(names),
epoch=curr_epoch,
step=curr_step,
matrix=conf_mat, labels=names, max_categories=len(names), epoch=curr_epoch, step=curr_step
)
@ -243,7 +251,7 @@ def _log_image_predictions(experiment, validator, curr_step):
if (batch_idx + 1) % batch_logging_interval != 0:
continue
image_paths = batch['im_file']
image_paths = batch["im_file"]
for img_idx, image_path in enumerate(image_paths):
if _comet_image_prediction_count >= max_image_predictions:
return
@ -267,28 +275,23 @@ def _log_image_predictions(experiment, validator, curr_step):
def _log_plots(experiment, trainer):
"""Logs evaluation plots and label plots for the experiment."""
plot_filenames = [trainer.save_dir / f'{plots}.png' for plots in EVALUATION_PLOT_NAMES]
plot_filenames = [trainer.save_dir / f"{plots}.png" for plots in EVALUATION_PLOT_NAMES]
_log_images(experiment, plot_filenames, None)
label_plot_filenames = [trainer.save_dir / f'{labels}.jpg' for labels in LABEL_PLOT_NAMES]
label_plot_filenames = [trainer.save_dir / f"{labels}.jpg" for labels in LABEL_PLOT_NAMES]
_log_images(experiment, label_plot_filenames, None)
def _log_model(experiment, trainer):
"""Log the best-trained model to Comet.ml."""
model_name = _get_comet_model_name()
experiment.log_model(
model_name,
file_or_folder=str(trainer.best),
file_name='best.pt',
overwrite=True,
)
experiment.log_model(model_name, file_or_folder=str(trainer.best), file_name="best.pt", overwrite=True)
def on_pretrain_routine_start(trainer):
"""Creates or resumes a CometML experiment at the start of a YOLO pre-training routine."""
experiment = comet_ml.get_global_experiment()
is_alive = getattr(experiment, 'alive', False)
is_alive = getattr(experiment, "alive", False)
if not experiment or not is_alive:
_create_experiment(trainer.args)
@ -300,17 +303,13 @@ def on_train_epoch_end(trainer):
return
metadata = _fetch_trainer_metadata(trainer)
curr_epoch = metadata['curr_epoch']
curr_step = metadata['curr_step']
curr_epoch = metadata["curr_epoch"]
curr_step = metadata["curr_step"]
experiment.log_metrics(
trainer.label_loss_items(trainer.tloss, prefix='train'),
step=curr_step,
epoch=curr_epoch,
)
experiment.log_metrics(trainer.label_loss_items(trainer.tloss, prefix="train"), step=curr_step, epoch=curr_epoch)
if curr_epoch == 1:
_log_images(experiment, trainer.save_dir.glob('train_batch*.jpg'), curr_step)
_log_images(experiment, trainer.save_dir.glob("train_batch*.jpg"), curr_step)
def on_fit_epoch_end(trainer):
@ -320,14 +319,15 @@ def on_fit_epoch_end(trainer):
return
metadata = _fetch_trainer_metadata(trainer)
curr_epoch = metadata['curr_epoch']
curr_step = metadata['curr_step']
save_assets = metadata['save_assets']
curr_epoch = metadata["curr_epoch"]
curr_step = metadata["curr_step"]
save_assets = metadata["save_assets"]
experiment.log_metrics(trainer.metrics, step=curr_step, epoch=curr_epoch)
experiment.log_metrics(trainer.lr, step=curr_step, epoch=curr_epoch)
if curr_epoch == 1:
from ultralytics.utils.torch_utils import model_info_for_loggers
experiment.log_metrics(model_info_for_loggers(trainer), step=curr_step, epoch=curr_epoch)
if not save_assets:
@ -347,8 +347,8 @@ def on_train_end(trainer):
return
metadata = _fetch_trainer_metadata(trainer)
curr_epoch = metadata['curr_epoch']
curr_step = metadata['curr_step']
curr_epoch = metadata["curr_epoch"]
curr_step = metadata["curr_step"]
plots = trainer.args.plots
_log_model(experiment, trainer)
@ -363,8 +363,13 @@ def on_train_end(trainer):
_comet_image_prediction_count = 0
callbacks = {
'on_pretrain_routine_start': on_pretrain_routine_start,
'on_train_epoch_end': on_train_epoch_end,
'on_fit_epoch_end': on_fit_epoch_end,
'on_train_end': on_train_end} if comet_ml else {}
callbacks = (
{
"on_pretrain_routine_start": on_pretrain_routine_start,
"on_train_epoch_end": on_train_epoch_end,
"on_fit_epoch_end": on_fit_epoch_end,
"on_train_end": on_train_end,
}
if comet_ml
else {}
)
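
Because the Comet behaviour above is driven entirely by environment variables read through the _get_*/_should_* helpers, a hedged sketch of a typical configuration set before launching training follows; the variable names are the ones read above, the values are examples only.

import os

os.environ["COMET_MODE"] = "offline"                    # _get_comet_mode() -> OfflineExperiment
os.environ["COMET_MODEL_NAME"] = "my-yolov8"            # name used when logging the best model
os.environ["COMET_EVAL_BATCH_LOGGING_INTERVAL"] = "2"   # log every 2nd eval batch
os.environ["COMET_MAX_IMAGE_PREDICTIONS"] = "50"        # cap the number of logged prediction images
os.environ["COMET_EVAL_LOG_CONFUSION_MATRIX"] = "true"  # opt in to confusion-matrix logging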

View File

@ -1,26 +1,18 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING
from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, checks
try:
assert not TESTS_RUNNING # do not log pytest
assert SETTINGS['dvc'] is True # verify integration is enabled
assert SETTINGS["dvc"] is True # verify integration is enabled
import dvclive
assert hasattr(dvclive, '__version__') # verify package is not directory
assert checks.check_version("dvclive", "2.11.0", verbose=True)
import os
import re
from importlib.metadata import version
from pathlib import Path
import pkg_resources as pkg
ver = version('dvclive')
if pkg.parse_version(ver) < pkg.parse_version('2.11.0'):
LOGGER.debug(f'DVCLive is detected but version {ver} is incompatible (>=2.11 required).')
dvclive = None # noqa: F811
# DVCLive logger instance
live = None
_processed_plots = {}
@ -33,108 +25,121 @@ except (ImportError, AssertionError, TypeError):
dvclive = None
def _log_images(path, prefix=''):
def _log_images(path, prefix=""):
"""Logs images at specified path with an optional prefix using DVCLive."""
if live:
name = path.name
# Group images by batch to enable sliders in UI
if m := re.search(r'_batch(\d+)', name):
if m := re.search(r"_batch(\d+)", name):
ni = m[1]
new_stem = re.sub(r'_batch(\d+)', '_batch', path.stem)
new_stem = re.sub(r"_batch(\d+)", "_batch", path.stem)
name = (Path(new_stem) / ni).with_suffix(path.suffix)
live.log_image(os.path.join(prefix, name), path)
def _log_plots(plots, prefix=''):
def _log_plots(plots, prefix=""):
"""Logs plot images for training progress if they have not been previously processed."""
for name, params in plots.items():
timestamp = params['timestamp']
timestamp = params["timestamp"]
if _processed_plots.get(name) != timestamp:
_log_images(name, prefix)
_processed_plots[name] = timestamp
def _log_confusion_matrix(validator):
"""Logs the confusion matrix for the given validator using DVCLive."""
targets = []
preds = []
matrix = validator.confusion_matrix.matrix
names = list(validator.names.values())
if validator.confusion_matrix.task == 'detect':
names += ['background']
if validator.confusion_matrix.task == "detect":
names += ["background"]
for ti, pred in enumerate(matrix.T.astype(int)):
for pi, num in enumerate(pred):
targets.extend([names[ti]] * num)
preds.extend([names[pi]] * num)
live.log_sklearn_plot('confusion_matrix', targets, preds, name='cf.json', normalized=True)
live.log_sklearn_plot("confusion_matrix", targets, preds, name="cf.json", normalized=True)
def on_pretrain_routine_start(trainer):
"""Initializes DVCLive logger for training metadata during pre-training routine."""
try:
global live
live = dvclive.Live(save_dvc_exp=True, cache_images=True)
LOGGER.info(
f'DVCLive is detected and auto logging is enabled (can be disabled in the {SETTINGS.file} with `dvc: false`).'
)
LOGGER.info("DVCLive is detected and auto logging is enabled (run 'yolo settings dvc=False' to disable).")
except Exception as e:
LOGGER.warning(f'WARNING ⚠️ DVCLive installed but not initialized correctly, not logging this run. {e}')
LOGGER.warning(f"WARNING ⚠️ DVCLive installed but not initialized correctly, not logging this run. {e}")
def on_pretrain_routine_end(trainer):
_log_plots(trainer.plots, 'train')
"""Logs plots related to the training process at the end of the pretraining routine."""
_log_plots(trainer.plots, "train")
def on_train_start(trainer):
"""Logs the training parameters if DVCLive logging is active."""
if live:
live.log_params(trainer.args)
def on_train_epoch_start(trainer):
"""Sets the global variable _training_epoch value to True at the start of training each epoch."""
global _training_epoch
_training_epoch = True
def on_fit_epoch_end(trainer):
"""Logs training metrics and model info, and advances to next step on the end of each fit epoch."""
global _training_epoch
if live and _training_epoch:
all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics, **trainer.lr}
all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix="train"), **trainer.metrics, **trainer.lr}
for metric, value in all_metrics.items():
live.log_metric(metric, value)
if trainer.epoch == 0:
from ultralytics.utils.torch_utils import model_info_for_loggers
for metric, value in model_info_for_loggers(trainer).items():
live.log_metric(metric, value, plot=False)
_log_plots(trainer.plots, 'train')
_log_plots(trainer.validator.plots, 'val')
_log_plots(trainer.plots, "train")
_log_plots(trainer.validator.plots, "val")
live.next_step()
_training_epoch = False
def on_train_end(trainer):
"""Logs the best metrics, plots, and confusion matrix at the end of training if DVCLive is active."""
if live:
# At the end log the best metrics. It runs validator on the best model internally.
all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics, **trainer.lr}
all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix="train"), **trainer.metrics, **trainer.lr}
for metric, value in all_metrics.items():
live.log_metric(metric, value, plot=False)
_log_plots(trainer.plots, 'val')
_log_plots(trainer.validator.plots, 'val')
_log_plots(trainer.plots, "val")
_log_plots(trainer.validator.plots, "val")
_log_confusion_matrix(trainer.validator)
if trainer.best.exists():
live.log_artifact(trainer.best, copy=True, type='model')
live.log_artifact(trainer.best, copy=True, type="model")
live.end()
callbacks = {
'on_pretrain_routine_start': on_pretrain_routine_start,
'on_pretrain_routine_end': on_pretrain_routine_end,
'on_train_start': on_train_start,
'on_train_epoch_start': on_train_epoch_start,
'on_fit_epoch_end': on_fit_epoch_end,
'on_train_end': on_train_end} if dvclive else {}
callbacks = (
{
"on_pretrain_routine_start": on_pretrain_routine_start,
"on_pretrain_routine_end": on_pretrain_routine_end,
"on_train_start": on_train_start,
"on_train_epoch_start": on_train_epoch_start,
"on_fit_epoch_end": on_fit_epoch_end,
"on_train_end": on_train_end,
}
if dvclive
else {}
)
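
The _training_epoch flag above ensures metrics are written once per fit epoch before live.next_step() advances the DVCLive step counter. Stripped of the callback plumbing, the underlying DVCLive pattern is roughly the following sketch, assuming the dvclive>=2.11 API required above; the metric values are placeholders.

from dvclive import Live

live = Live(save_dvc_exp=True)
for epoch in range(3):
    live.log_metric("train/loss", 1.0 / (epoch + 1))  # same call the callback issues per metric
    live.next_step()                                  # one DVCLive step per fit epoch
live.end()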

View File

@ -9,51 +9,67 @@ from ultralytics.utils import LOGGER, SETTINGS
def on_pretrain_routine_end(trainer):
"""Logs info before starting timer for upload rate limit."""
session = getattr(trainer, 'hub_session', None)
session = getattr(trainer, "hub_session", None)
if session:
# Start timer for upload rate limit
LOGGER.info(f'{PREFIX}View model at {HUB_WEB_ROOT}/models/{session.model_id} 🚀')
session.timers = {'metrics': time(), 'ckpt': time()} # start timer on session.rate_limit
session.timers = {
"metrics": time(),
"ckpt": time(),
} # start timer on session.rate_limit
def on_fit_epoch_end(trainer):
"""Uploads training progress metrics at the end of each epoch."""
session = getattr(trainer, 'hub_session', None)
session = getattr(trainer, "hub_session", None)
if session:
# Upload metrics after val end
all_plots = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics}
all_plots = {
**trainer.label_loss_items(trainer.tloss, prefix="train"),
**trainer.metrics,
}
if trainer.epoch == 0:
from ultralytics.utils.torch_utils import model_info_for_loggers
all_plots = {**all_plots, **model_info_for_loggers(trainer)}
session.metrics_queue[trainer.epoch] = json.dumps(all_plots)
if time() - session.timers['metrics'] > session.rate_limits['metrics']:
# If any metrics fail to upload, add them to the queue to attempt uploading again.
if session.metrics_upload_failed_queue:
session.metrics_queue.update(session.metrics_upload_failed_queue)
if time() - session.timers["metrics"] > session.rate_limits["metrics"]:
session.upload_metrics()
session.timers['metrics'] = time() # reset timer
session.timers["metrics"] = time() # reset timer
session.metrics_queue = {} # reset queue
def on_model_save(trainer):
"""Saves checkpoints to Ultralytics HUB with rate limiting."""
session = getattr(trainer, 'hub_session', None)
session = getattr(trainer, "hub_session", None)
if session:
# Upload checkpoints with rate limiting
is_best = trainer.best_fitness == trainer.fitness
if time() - session.timers['ckpt'] > session.rate_limits['ckpt']:
LOGGER.info(f'{PREFIX}Uploading checkpoint {HUB_WEB_ROOT}/models/{session.model_id}')
if time() - session.timers["ckpt"] > session.rate_limits["ckpt"]:
LOGGER.info(f"{PREFIX}Uploading checkpoint {HUB_WEB_ROOT}/models/{session.model.id}")
session.upload_model(trainer.epoch, trainer.last, is_best)
session.timers['ckpt'] = time() # reset timer
session.timers["ckpt"] = time() # reset timer
def on_train_end(trainer):
"""Upload final model and metrics to Ultralytics HUB at the end of training."""
session = getattr(trainer, 'hub_session', None)
session = getattr(trainer, "hub_session", None)
if session:
# Upload final model and metrics with exponential backoff
LOGGER.info(f'{PREFIX}Syncing final model...')
session.upload_model(trainer.epoch, trainer.best, map=trainer.metrics.get('metrics/mAP50-95(B)', 0), final=True)
LOGGER.info(f"{PREFIX}Syncing final model...")
session.upload_model(
trainer.epoch,
trainer.best,
map=trainer.metrics.get("metrics/mAP50-95(B)", 0),
final=True,
)
session.alive = False # stop heartbeats
LOGGER.info(f'{PREFIX}Done ✅\n'
f'{PREFIX}View model at {HUB_WEB_ROOT}/models/{session.model_id} 🚀')
LOGGER.info(f"{PREFIX}Done ✅\n" f"{PREFIX}View model at {session.model_url} 🚀")
def on_train_start(trainer):
@ -76,12 +92,17 @@ def on_export_start(exporter):
events(exporter.args)
callbacks = {
'on_pretrain_routine_end': on_pretrain_routine_end,
'on_fit_epoch_end': on_fit_epoch_end,
'on_model_save': on_model_save,
'on_train_end': on_train_end,
'on_train_start': on_train_start,
'on_val_start': on_val_start,
'on_predict_start': on_predict_start,
'on_export_start': on_export_start} if SETTINGS['hub'] is True else {} # verify enabled
callbacks = (
{
"on_pretrain_routine_end": on_pretrain_routine_end,
"on_fit_epoch_end": on_fit_epoch_end,
"on_model_save": on_model_save,
"on_train_end": on_train_end,
"on_train_start": on_train_start,
"on_val_start": on_val_start,
"on_predict_start": on_predict_start,
"on_export_start": on_export_start,
}
if SETTINGS["hub"] is True
else {}
) # verify enabled

View File

@ -1,70 +1,133 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
MLflow Logging for Ultralytics YOLO.
from ultralytics.utils import LOGGER, ROOT, SETTINGS, TESTS_RUNNING, colorstr
This module enables MLflow logging for Ultralytics YOLO. It logs metrics, parameters, and model artifacts.
For setting up, a tracking URI should be specified. The logging can be customized using environment variables.
Commands:
1. To set a project name:
`export MLFLOW_EXPERIMENT_NAME=<your_experiment_name>` or use the project=<project> argument
2. To set a run name:
`export MLFLOW_RUN=<your_run_name>` or use the name=<name> argument
3. To start a local MLflow server:
mlflow server --backend-store-uri runs/mlflow
It will by default start a local server at http://127.0.0.1:5000.
To specify a different URI, set the MLFLOW_TRACKING_URI environment variable.
4. To kill all running MLflow server instances:
ps aux | grep 'mlflow' | grep -v 'grep' | awk '{print $2}' | xargs kill -9
"""
from ultralytics.utils import LOGGER, RUNS_DIR, SETTINGS, TESTS_RUNNING, colorstr
try:
assert not TESTS_RUNNING # do not log pytest
assert SETTINGS['mlflow'] is True # verify integration is enabled
import os
assert not TESTS_RUNNING or "test_mlflow" in os.environ.get("PYTEST_CURRENT_TEST", "") # do not log pytest
assert SETTINGS["mlflow"] is True # verify integration is enabled
import mlflow
assert hasattr(mlflow, '__version__') # verify package is not directory
assert hasattr(mlflow, "__version__") # verify package is not directory
from pathlib import Path
import os
import re
PREFIX = colorstr("MLflow: ")
SANITIZE = lambda x: {k.replace("(", "").replace(")", ""): float(v) for k, v in x.items()}
except (ImportError, AssertionError):
mlflow = None
def on_pretrain_routine_end(trainer):
"""Logs training parameters to MLflow."""
global mlflow, run, experiment_name
"""
Log training parameters to MLflow at the end of the pretraining routine.
if os.environ.get('MLFLOW_TRACKING_URI') is None:
mlflow = None
This function sets up MLflow logging based on environment variables and trainer arguments. It sets the tracking URI,
experiment name, and run name, then starts the MLflow run if not already active. It finally logs the parameters
from the trainer.
Args:
trainer (ultralytics.engine.trainer.BaseTrainer): The training object with arguments and parameters to log.
Global:
mlflow: The imported mlflow module to use for logging.
Environment Variables:
MLFLOW_TRACKING_URI: The URI for MLflow tracking. If not set, defaults to 'runs/mlflow'.
MLFLOW_EXPERIMENT_NAME: The name of the MLflow experiment. If not set, defaults to trainer.args.project.
MLFLOW_RUN: The name of the MLflow run. If not set, defaults to trainer.args.name.
MLFLOW_KEEP_RUN_ACTIVE: Boolean indicating whether to keep the MLflow run active after the end of the training phase.
"""
global mlflow
uri = os.environ.get("MLFLOW_TRACKING_URI") or str(RUNS_DIR / "mlflow")
LOGGER.debug(f"{PREFIX} tracking uri: {uri}")
mlflow.set_tracking_uri(uri)
# Set experiment and run names
experiment_name = os.environ.get("MLFLOW_EXPERIMENT_NAME") or trainer.args.project or "/Shared/YOLOv8"
run_name = os.environ.get("MLFLOW_RUN") or trainer.args.name
mlflow.set_experiment(experiment_name)
mlflow.autolog()
try:
active_run = mlflow.active_run() or mlflow.start_run(run_name=run_name)
LOGGER.info(f"{PREFIX}logging run_id({active_run.info.run_id}) to {uri}")
if Path(uri).is_dir():
LOGGER.info(f"{PREFIX}view at http://127.0.0.1:5000 with 'mlflow server --backend-store-uri {uri}'")
LOGGER.info(f"{PREFIX}disable with 'yolo settings mlflow=False'")
mlflow.log_params(dict(trainer.args))
except Exception as e:
LOGGER.warning(f"{PREFIX}WARNING ⚠️ Failed to initialize: {e}\n" f"{PREFIX}WARNING ⚠️ Not tracking this run")
def on_train_epoch_end(trainer):
"""Log training metrics at the end of each train epoch to MLflow."""
if mlflow:
mlflow_location = os.environ['MLFLOW_TRACKING_URI'] # "http://192.168.xxx.xxx:5000"
mlflow.set_tracking_uri(mlflow_location)
experiment_name = os.environ.get('MLFLOW_EXPERIMENT_NAME') or trainer.args.project or '/Shared/YOLOv8'
run_name = os.environ.get('MLFLOW_RUN') or trainer.args.name
experiment = mlflow.get_experiment_by_name(experiment_name)
if experiment is None:
mlflow.create_experiment(experiment_name)
mlflow.set_experiment(experiment_name)
prefix = colorstr('MLFlow: ')
try:
run, active_run = mlflow, mlflow.active_run()
if not active_run:
active_run = mlflow.start_run(experiment_id=experiment.experiment_id, run_name=run_name)
LOGGER.info(f'{prefix}Using run_id({active_run.info.run_id}) at {mlflow_location}')
run.log_params(vars(trainer.model.args))
except Exception as err:
LOGGER.error(f'{prefix}Failing init - {repr(err)}')
LOGGER.warning(f'{prefix}Continuing without Mlflow')
mlflow.log_metrics(
metrics={
**SANITIZE(trainer.lr),
**SANITIZE(trainer.label_loss_items(trainer.tloss, prefix="train")),
},
step=trainer.epoch,
)
def on_fit_epoch_end(trainer):
"""Logs training metrics to Mlflow."""
"""Log training metrics at the end of each fit epoch to MLflow."""
if mlflow:
metrics_dict = {f"{re.sub('[()]', '', k)}": float(v) for k, v in trainer.metrics.items()}
run.log_metrics(metrics=metrics_dict, step=trainer.epoch)
mlflow.log_metrics(metrics=SANITIZE(trainer.metrics), step=trainer.epoch)
def on_train_end(trainer):
"""Called at end of train loop to log model artifact info."""
"""Log model artifacts at the end of the training."""
if mlflow:
run.log_artifact(trainer.last)
run.log_artifact(trainer.best)
run.pyfunc.log_model(artifact_path=experiment_name,
code_path=[str(ROOT.parent)],
artifacts={'model_path': str(trainer.save_dir)},
python_model=run.pyfunc.PythonModel())
mlflow.log_artifact(str(trainer.best.parent)) # log save_dir/weights directory with best.pt and last.pt
for f in trainer.save_dir.glob("*"): # log all other files in save_dir
if f.suffix in {".png", ".jpg", ".csv", ".pt", ".yaml"}:
mlflow.log_artifact(str(f))
keep_run_active = os.environ.get("MLFLOW_KEEP_RUN_ACTIVE", "False").lower() in ("true")
if keep_run_active:
LOGGER.info(f"{PREFIX}mlflow run still alive, remember to close it using mlflow.end_run()")
else:
mlflow.end_run()
LOGGER.debug(f"{PREFIX}mlflow run ended")
LOGGER.info(
f"{PREFIX}results logged to {mlflow.get_tracking_uri()}\n"
f"{PREFIX}disable with 'yolo settings mlflow=False'"
)
callbacks = {
'on_pretrain_routine_end': on_pretrain_routine_end,
'on_fit_epoch_end': on_fit_epoch_end,
'on_train_end': on_train_end} if mlflow else {}
callbacks = (
{
"on_pretrain_routine_end": on_pretrain_routine_end,
"on_train_epoch_end": on_train_epoch_end,
"on_fit_epoch_end": on_fit_epoch_end,
"on_train_end": on_train_end,
}
if mlflow
else {}
)
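
Putting the module docstring above into practice, a hedged end-to-end sketch follows. The environment variables and the mlflow server command come from the docstring; the experiment and run names are placeholders, and the weights/dataset files are the usual Ultralytics examples.

import os
from ultralytics import YOLO

os.environ["MLFLOW_EXPERIMENT_NAME"] = "my-experiment"  # or pass project=... to train()
os.environ["MLFLOW_RUN"] = "baseline"                    # or pass name=... to train()
# MLFLOW_TRACKING_URI is optional; on_pretrain_routine_end falls back to RUNS_DIR / "mlflow"

YOLO("yolov8n.pt").train(data="coco128.yaml", epochs=1)
# Inspect locally afterwards with: mlflow server --backend-store-uri runs/mlflow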

View File

@ -4,11 +4,11 @@ from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING
try:
assert not TESTS_RUNNING # do not log pytest
assert SETTINGS['neptune'] is True # verify integration is enabled
assert SETTINGS["neptune"] is True # verify integration is enabled
import neptune
from neptune.types import File
assert hasattr(neptune, '__version__')
assert hasattr(neptune, "__version__")
run = None # NeptuneAI experiment logger instance
@ -23,55 +23,55 @@ def _log_scalars(scalars, step=0):
run[k].append(value=v, step=step)
def _log_images(imgs_dict, group=''):
def _log_images(imgs_dict, group=""):
"""Log scalars to the NeptuneAI experiment logger."""
if run:
for k, v in imgs_dict.items():
run[f'{group}/{k}'].upload(File(v))
run[f"{group}/{k}"].upload(File(v))
def _log_plot(title, plot_path):
"""Log plots to the NeptuneAI experiment logger."""
"""
Log image as plot in the plot section of NeptuneAI
Log plots to the NeptuneAI experiment logger.
arguments:
title (str) Title of the plot
plot_path (PosixPath or str) Path to the saved image file
"""
Args:
title (str): Title of the plot.
plot_path (PosixPath | str): Path to the saved image file.
"""
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
img = mpimg.imread(plot_path)
fig = plt.figure()
ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect='auto', xticks=[], yticks=[]) # no ticks
ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect="auto", xticks=[], yticks=[]) # no ticks
ax.imshow(img)
run[f'Plots/{title}'].upload(fig)
run[f"Plots/{title}"].upload(fig)
def on_pretrain_routine_start(trainer):
"""Callback function called before the training routine starts."""
try:
global run
run = neptune.init_run(project=trainer.args.project or 'YOLOv8', name=trainer.args.name, tags=['YOLOv8'])
run['Configuration/Hyperparameters'] = {k: '' if v is None else v for k, v in vars(trainer.args).items()}
run = neptune.init_run(project=trainer.args.project or "YOLOv8", name=trainer.args.name, tags=["YOLOv8"])
run["Configuration/Hyperparameters"] = {k: "" if v is None else v for k, v in vars(trainer.args).items()}
except Exception as e:
LOGGER.warning(f'WARNING ⚠️ NeptuneAI installed but not initialized correctly, not logging this run. {e}')
LOGGER.warning(f"WARNING ⚠️ NeptuneAI installed but not initialized correctly, not logging this run. {e}")
def on_train_epoch_end(trainer):
"""Callback function called at end of each training epoch."""
_log_scalars(trainer.label_loss_items(trainer.tloss, prefix='train'), trainer.epoch + 1)
_log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1)
_log_scalars(trainer.lr, trainer.epoch + 1)
if trainer.epoch == 1:
_log_images({f.stem: str(f) for f in trainer.save_dir.glob('train_batch*.jpg')}, 'Mosaic')
_log_images({f.stem: str(f) for f in trainer.save_dir.glob("train_batch*.jpg")}, "Mosaic")
def on_fit_epoch_end(trainer):
"""Callback function called at end of each fit (train+val) epoch."""
if run and trainer.epoch == 0:
from ultralytics.utils.torch_utils import model_info_for_loggers
run['Configuration/Model'] = model_info_for_loggers(trainer)
run["Configuration/Model"] = model_info_for_loggers(trainer)
_log_scalars(trainer.metrics, trainer.epoch + 1)
@ -79,7 +79,7 @@ def on_val_end(validator):
"""Callback function called at end of each validation."""
if run:
# Log val_labels and val_pred
_log_images({f.stem: str(f) for f in validator.save_dir.glob('val*.jpg')}, 'Validation')
_log_images({f.stem: str(f) for f in validator.save_dir.glob("val*.jpg")}, "Validation")
def on_train_end(trainer):
@ -87,19 +87,26 @@ def on_train_end(trainer):
if run:
# Log final results, CM matrix + PR plots
files = [
'results.png', 'confusion_matrix.png', 'confusion_matrix_normalized.png',
*(f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R'))]
"results.png",
"confusion_matrix.png",
"confusion_matrix_normalized.png",
*(f"{x}_curve.png" for x in ("F1", "PR", "P", "R")),
]
files = [(trainer.save_dir / f) for f in files if (trainer.save_dir / f).exists()] # filter
for f in files:
_log_plot(title=f.stem, plot_path=f)
# Log the final model
run[f'weights/{trainer.args.name or trainer.args.task}/{str(trainer.best.name)}'].upload(File(str(
trainer.best)))
run[f"weights/{trainer.args.name or trainer.args.task}/{trainer.best.name}"].upload(File(str(trainer.best)))
callbacks = {
'on_pretrain_routine_start': on_pretrain_routine_start,
'on_train_epoch_end': on_train_epoch_end,
'on_fit_epoch_end': on_fit_epoch_end,
'on_val_end': on_val_end,
'on_train_end': on_train_end} if neptune else {}
callbacks = (
{
"on_pretrain_routine_start": on_pretrain_routine_start,
"on_train_epoch_end": on_train_epoch_end,
"on_fit_epoch_end": on_fit_epoch_end,
"on_val_end": on_val_end,
"on_train_end": on_train_end,
}
if neptune
else {}
)

View File

@ -3,7 +3,7 @@
from ultralytics.utils import SETTINGS
try:
assert SETTINGS['raytune'] is True # verify integration is enabled
assert SETTINGS["raytune"] is True # verify integration is enabled
import ray
from ray import tune
from ray.air import session
@ -16,9 +16,14 @@ def on_fit_epoch_end(trainer):
"""Sends training metrics to Ray Tune at end of each epoch."""
if ray.tune.is_session_enabled():
metrics = trainer.metrics
metrics['epoch'] = trainer.epoch
metrics["epoch"] = trainer.epoch
session.report(metrics)
callbacks = {
'on_fit_epoch_end': on_fit_epoch_end, } if tune else {}
callbacks = (
{
"on_fit_epoch_end": on_fit_epoch_end,
}
if tune
else {}
)

View File

@ -1,17 +1,25 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import contextlib
from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, colorstr
try:
# WARNING: do not move import due to protobuf issue in https://github.com/ultralytics/ultralytics/pull/4674
# WARNING: do not move SummaryWriter import due to protobuf bug https://github.com/ultralytics/ultralytics/pull/4674
from torch.utils.tensorboard import SummaryWriter
assert not TESTS_RUNNING # do not log pytest
assert SETTINGS['tensorboard'] is True # verify integration is enabled
assert SETTINGS["tensorboard"] is True # verify integration is enabled
WRITER = None # TensorBoard SummaryWriter instance
PREFIX = colorstr("TensorBoard: ")
except (ImportError, AssertionError, TypeError):
# Imports below only required if TensorBoard enabled
import warnings
from copy import deepcopy
from ultralytics.utils.torch_utils import de_parallel, torch
except (ImportError, AssertionError, TypeError, AttributeError):
# TypeError for handling 'Descriptors cannot not be created directly.' protobuf errors in Windows
# AttributeError: module 'tensorflow' has no attribute 'io' if 'tensorflow' not installed
SummaryWriter = None
@ -24,20 +32,38 @@ def _log_scalars(scalars, step=0):
def _log_tensorboard_graph(trainer):
"""Log model graph to TensorBoard."""
try:
import warnings
from ultralytics.utils.torch_utils import de_parallel, torch
# Input image
imgsz = trainer.args.imgsz
imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz
p = next(trainer.model.parameters()) # for device, type
im = torch.zeros((1, 3, *imgsz), device=p.device, dtype=p.dtype) # input image (must be zeros, not empty)
imgsz = trainer.args.imgsz
imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz
p = next(trainer.model.parameters()) # for device, type
im = torch.zeros((1, 3, *imgsz), device=p.device, dtype=p.dtype) # input image (must be zeros, not empty)
with warnings.catch_warnings():
warnings.simplefilter('ignore', category=UserWarning) # suppress jit trace warning
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning) # suppress jit trace warning
warnings.simplefilter("ignore", category=torch.jit.TracerWarning) # suppress jit trace warning
# Try simple method first (YOLO)
with contextlib.suppress(Exception):
trainer.model.eval() # place in .eval() mode to avoid BatchNorm statistics changes
WRITER.add_graph(torch.jit.trace(de_parallel(trainer.model), im, strict=False), [])
except Exception as e:
LOGGER.warning(f'WARNING ⚠️ TensorBoard graph visualization failure {e}')
LOGGER.info(f"{PREFIX}model graph visualization added ✅")
return
# Fallback to TorchScript export steps (RTDETR)
try:
model = deepcopy(de_parallel(trainer.model))
model.eval()
model = model.fuse(verbose=False)
for m in model.modules():
if hasattr(m, "export"): # Detect, RTDETRDecoder (Segment and Pose use Detect base class)
m.export = True
m.format = "torchscript"
model(im) # dry run
WRITER.add_graph(torch.jit.trace(model, im, strict=False), [])
LOGGER.info(f"{PREFIX}model graph visualization added ✅")
except Exception as e:
LOGGER.warning(f"{PREFIX}WARNING ⚠️ TensorBoard graph visualization failure {e}")
def on_pretrain_routine_start(trainer):
@ -46,10 +72,9 @@ def on_pretrain_routine_start(trainer):
try:
global WRITER
WRITER = SummaryWriter(str(trainer.save_dir))
prefix = colorstr('TensorBoard: ')
LOGGER.info(f"{prefix}Start with 'tensorboard --logdir {trainer.save_dir}', view at http://localhost:6006/")
LOGGER.info(f"{PREFIX}Start with 'tensorboard --logdir {trainer.save_dir}', view at http://localhost:6006/")
except Exception as e:
LOGGER.warning(f'WARNING ⚠️ TensorBoard not initialized correctly, not logging this run. {e}')
LOGGER.warning(f"{PREFIX}WARNING ⚠️ TensorBoard not initialized correctly, not logging this run. {e}")
def on_train_start(trainer):
@ -58,9 +83,10 @@ def on_train_start(trainer):
_log_tensorboard_graph(trainer)
def on_batch_end(trainer):
"""Logs scalar statistics at the end of a training batch."""
_log_scalars(trainer.label_loss_items(trainer.tloss, prefix='train'), trainer.epoch + 1)
def on_train_epoch_end(trainer):
"""Logs scalar statistics at the end of a training epoch."""
_log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1)
_log_scalars(trainer.lr, trainer.epoch + 1)
def on_fit_epoch_end(trainer):
@ -68,8 +94,13 @@ def on_fit_epoch_end(trainer):
_log_scalars(trainer.metrics, trainer.epoch + 1)
callbacks = {
'on_pretrain_routine_start': on_pretrain_routine_start,
'on_train_start': on_train_start,
'on_fit_epoch_end': on_fit_epoch_end,
'on_batch_end': on_batch_end} if SummaryWriter else {}
callbacks = (
{
"on_pretrain_routine_start": on_pretrain_routine_start,
"on_train_start": on_train_start,
"on_fit_epoch_end": on_fit_epoch_end,
"on_train_epoch_end": on_train_epoch_end,
}
if SummaryWriter
else {}
)

View File

@ -5,10 +5,13 @@ from ultralytics.utils.torch_utils import model_info_for_loggers
try:
assert not TESTS_RUNNING # do not log pytest
assert SETTINGS['wandb'] is True # verify integration is enabled
assert SETTINGS["wandb"] is True # verify integration is enabled
import wandb as wb
assert hasattr(wb, '__version__')
assert hasattr(wb, "__version__") # verify package is not directory
import numpy as np
import pandas as pd
_processed_plots = {}
@ -16,9 +19,89 @@ except (ImportError, AssertionError):
wb = None
def _custom_table(x, y, classes, title="Precision Recall Curve", x_title="Recall", y_title="Precision"):
"""
Create and log a custom metric visualization to wandb.plot.pr_curve.
This function crafts a custom metric visualization that mimics the behavior of wandb's default precision-recall
curve while allowing for enhanced customization. The visual metric is useful for monitoring model performance across
different classes.
Args:
x (List): Values for the x-axis; expected to have length N.
y (List): Corresponding values for the y-axis; also expected to have length N.
classes (List): Labels identifying the class of each point; length N.
title (str, optional): Title for the plot; defaults to 'Precision Recall Curve'.
x_title (str, optional): Label for the x-axis; defaults to 'Recall'.
y_title (str, optional): Label for the y-axis; defaults to 'Precision'.
Returns:
(wandb.Object): A wandb object suitable for logging, showcasing the crafted metric visualization.
"""
df = pd.DataFrame({"class": classes, "y": y, "x": x}).round(3)
fields = {"x": "x", "y": "y", "class": "class"}
string_fields = {"title": title, "x-axis-title": x_title, "y-axis-title": y_title}
return wb.plot_table(
"wandb/area-under-curve/v0", wb.Table(dataframe=df), fields=fields, string_fields=string_fields
)
def _plot_curve(
x,
y,
names=None,
id="precision-recall",
title="Precision Recall Curve",
x_title="Recall",
y_title="Precision",
num_x=100,
only_mean=False,
):
"""
Log a metric curve visualization.
This function generates a metric curve based on input data and logs the visualization to wandb.
The curve can represent aggregated data (mean) or individual class data, depending on the 'only_mean' flag.
Args:
x (np.ndarray): Data points for the x-axis with length N.
y (np.ndarray): Corresponding data points for the y-axis with shape CxN, where C is the number of classes.
names (list, optional): Names of the classes corresponding to the y-axis data; length C. Defaults to [].
id (str, optional): Unique identifier for the logged data in wandb. Defaults to 'precision-recall'.
title (str, optional): Title for the visualization plot. Defaults to 'Precision Recall Curve'.
x_title (str, optional): Label for the x-axis. Defaults to 'Recall'.
y_title (str, optional): Label for the y-axis. Defaults to 'Precision'.
num_x (int, optional): Number of interpolated data points for visualization. Defaults to 100.
only_mean (bool, optional): Flag to indicate if only the mean curve should be plotted. Defaults to False.
Note:
The function leverages the '_custom_table' function to generate the actual visualization.
"""
# Create new x
if names is None:
names = []
x_new = np.linspace(x[0], x[-1], num_x).round(5)
# Create arrays for logging
x_log = x_new.tolist()
y_log = np.interp(x_new, x, np.mean(y, axis=0)).round(3).tolist()
if only_mean:
table = wb.Table(data=list(zip(x_log, y_log)), columns=[x_title, y_title])
wb.run.log({title: wb.plot.line(table, x_title, y_title, title=title)})
else:
classes = ["mean"] * len(x_log)
for i, yi in enumerate(y):
x_log.extend(x_new) # add new x
y_log.extend(np.interp(x_new, x, yi)) # interpolate y to new x
classes.extend([names[i]] * len(x_new)) # add class names
wb.log({id: _custom_table(x_log, y_log, classes, title, x_title, y_title)}, commit=False)
def _log_plots(plots, step):
"""Logs plots from the input dictionary if they haven't been logged already at the specified step."""
for name, params in plots.items():
timestamp = params['timestamp']
timestamp = params["timestamp"]
if _processed_plots.get(name) != timestamp:
wb.run.log({name.stem: wb.Image(str(name))}, step=step)
_processed_plots[name] = timestamp
@ -26,7 +109,7 @@ def _log_plots(plots, step):
def on_pretrain_routine_start(trainer):
"""Initiate and start project if module is present."""
wb.run or wb.init(project=trainer.args.project or 'YOLOv8', name=trainer.args.name, config=vars(trainer.args))
wb.run or wb.init(project=trainer.args.project or "YOLOv8", name=trainer.args.name, config=vars(trainer.args))
def on_fit_epoch_end(trainer):
@ -40,7 +123,7 @@ def on_fit_epoch_end(trainer):
def on_train_epoch_end(trainer):
"""Log metrics and save images at the end of each training epoch."""
wb.run.log(trainer.label_loss_items(trainer.tloss, prefix='train'), step=trainer.epoch + 1)
wb.run.log(trainer.label_loss_items(trainer.tloss, prefix="train"), step=trainer.epoch + 1)
wb.run.log(trainer.lr, step=trainer.epoch + 1)
if trainer.epoch == 1:
_log_plots(trainer.plots, step=trainer.epoch + 1)
@ -50,14 +133,31 @@ def on_train_end(trainer):
"""Save the best model as an artifact at end of training."""
_log_plots(trainer.validator.plots, step=trainer.epoch + 1)
_log_plots(trainer.plots, step=trainer.epoch + 1)
art = wb.Artifact(type='model', name=f'run_{wb.run.id}_model')
art = wb.Artifact(type="model", name=f"run_{wb.run.id}_model")
if trainer.best.exists():
art.add_file(trainer.best)
wb.run.log_artifact(art, aliases=['best'])
wb.run.log_artifact(art, aliases=["best"])
for curve_name, curve_values in zip(trainer.validator.metrics.curves, trainer.validator.metrics.curves_results):
x, y, x_title, y_title = curve_values
_plot_curve(
x,
y,
names=list(trainer.validator.metrics.names.values()),
id=f"curves/{curve_name}",
title=curve_name,
x_title=x_title,
y_title=y_title,
)
wb.run.finish() # required or run continues on dashboard
callbacks = {
'on_pretrain_routine_start': on_pretrain_routine_start,
'on_train_epoch_end': on_train_epoch_end,
'on_fit_epoch_end': on_fit_epoch_end,
'on_train_end': on_train_end} if wb else {}
callbacks = (
{
"on_pretrain_routine_start": on_pretrain_routine_start,
"on_train_epoch_end": on_train_epoch_end,
"on_fit_epoch_end": on_fit_epoch_end,
"on_train_end": on_train_end,
}
if wb
else {}
)
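
To make the shape conventions in the _plot_curve docstring concrete, a small illustrative call with synthetic data: x has length N, y has shape C x N, and names has length C. _plot_curve is an internal helper and needs an active wandb run, so this is for illustration only; the class names are placeholders.

import numpy as np

x = np.linspace(0.0, 1.0, 50)            # N = 50 recall points
y = np.stack([1.0 - x, (1.0 - x) ** 2])  # C = 2 classes, shape (2, 50)
_plot_curve(x, y, names=["person", "car"], id="curves/example-PR", only_mean=False)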