commit e474ab5f9f
Author: lee
Date: 2025-06-18 14:35:43 +08:00
529 changed files with 80523 additions and 0 deletions

View File

@@ -0,0 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from .base import add_integration_callbacks, default_callbacks, get_default_callbacks
__all__ = "add_integration_callbacks", "default_callbacks", "get_default_callbacks"

View File

@@ -0,0 +1,219 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""Base callbacks."""
from collections import defaultdict
from copy import deepcopy
# Trainer callbacks ----------------------------------------------------------------------------------------------------
def on_pretrain_routine_start(trainer):
"""Called before the pretraining routine starts."""
pass
def on_pretrain_routine_end(trainer):
"""Called after the pretraining routine ends."""
pass
def on_train_start(trainer):
"""Called when the training starts."""
pass
def on_train_epoch_start(trainer):
"""Called at the start of each training epoch."""
pass
def on_train_batch_start(trainer):
"""Called at the start of each training batch."""
pass
def optimizer_step(trainer):
"""Called when the optimizer takes a step."""
pass
def on_before_zero_grad(trainer):
"""Called before the gradients are set to zero."""
pass
def on_train_batch_end(trainer):
"""Called at the end of each training batch."""
pass
def on_train_epoch_end(trainer):
"""Called at the end of each training epoch."""
pass
def on_fit_epoch_end(trainer):
"""Called at the end of each fit epoch (train + val)."""
pass
def on_model_save(trainer):
"""Called when the model is saved."""
pass
def on_train_end(trainer):
"""Called when the training ends."""
pass
def on_params_update(trainer):
"""Called when the model parameters are updated."""
pass
def teardown(trainer):
"""Called during the teardown of the training process."""
pass
# Validator callbacks --------------------------------------------------------------------------------------------------
def on_val_start(validator):
"""Called when the validation starts."""
pass
def on_val_batch_start(validator):
"""Called at the start of each validation batch."""
pass
def on_val_batch_end(validator):
"""Called at the end of each validation batch."""
pass
def on_val_end(validator):
"""Called when the validation ends."""
pass
# Predictor callbacks --------------------------------------------------------------------------------------------------
def on_predict_start(predictor):
"""Called when the prediction starts."""
pass
def on_predict_batch_start(predictor):
"""Called at the start of each prediction batch."""
pass
def on_predict_batch_end(predictor):
"""Called at the end of each prediction batch."""
pass
def on_predict_postprocess_end(predictor):
"""Called after the post-processing of the prediction ends."""
pass
def on_predict_end(predictor):
"""Called when the prediction ends."""
pass
# Exporter callbacks ---------------------------------------------------------------------------------------------------
def on_export_start(exporter):
"""Called when the model export starts."""
pass
def on_export_end(exporter):
"""Called when the model export ends."""
pass
default_callbacks = {
# Run in trainer
"on_pretrain_routine_start": [on_pretrain_routine_start],
"on_pretrain_routine_end": [on_pretrain_routine_end],
"on_train_start": [on_train_start],
"on_train_epoch_start": [on_train_epoch_start],
"on_train_batch_start": [on_train_batch_start],
"optimizer_step": [optimizer_step],
"on_before_zero_grad": [on_before_zero_grad],
"on_train_batch_end": [on_train_batch_end],
"on_train_epoch_end": [on_train_epoch_end],
"on_fit_epoch_end": [on_fit_epoch_end], # fit = train + val
"on_model_save": [on_model_save],
"on_train_end": [on_train_end],
"on_params_update": [on_params_update],
"teardown": [teardown],
# Run in validator
"on_val_start": [on_val_start],
"on_val_batch_start": [on_val_batch_start],
"on_val_batch_end": [on_val_batch_end],
"on_val_end": [on_val_end],
# Run in predictor
"on_predict_start": [on_predict_start],
"on_predict_batch_start": [on_predict_batch_start],
"on_predict_postprocess_end": [on_predict_postprocess_end],
"on_predict_batch_end": [on_predict_batch_end],
"on_predict_end": [on_predict_end],
# Run in exporter
"on_export_start": [on_export_start],
"on_export_end": [on_export_end],
}
def get_default_callbacks():
"""
Return a copy of the default_callbacks dictionary with lists as default values.
Returns:
(defaultdict): A defaultdict with keys from default_callbacks and empty lists as default values.
"""
return defaultdict(list, deepcopy(default_callbacks))
def add_integration_callbacks(instance):
"""
Add integration callbacks from various sources to the instance's callbacks.
Args:
instance (Trainer, Predictor, Validator, Exporter): An object with a 'callbacks' attribute that is a dictionary
of callback lists.
"""
# Load HUB callbacks
from .hub import callbacks as hub_cb
callbacks_list = [hub_cb]
# Load training callbacks
if "Trainer" in instance.__class__.__name__:
from .clearml import callbacks as clear_cb
from .comet import callbacks as comet_cb
from .dvc import callbacks as dvc_cb
from .mlflow import callbacks as mlflow_cb
from .neptune import callbacks as neptune_cb
from .raytune import callbacks as tune_cb
from .tensorboard import callbacks as tb_cb
from .wb import callbacks as wb_cb
callbacks_list.extend([clear_cb, comet_cb, dvc_cb, mlflow_cb, neptune_cb, tune_cb, tb_cb, wb_cb])
# Add the callbacks to the callbacks dictionary
for callbacks in callbacks_list:
for k, v in callbacks.items():
if v not in instance.callbacks[k]:
instance.callbacks[k].append(v)
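
The two helpers above define the whole extension mechanism: each event name maps to a list of hook functions, and extra hooks are appended rather than replacing the defaults. As an illustration (not part of the committed file, and assuming the package lives at ultralytics.utils.callbacks as the imports suggest), a minimal sketch that registers a hypothetical custom hook:

from ultralytics.utils.callbacks.base import get_default_callbacks

def my_on_train_end(trainer):
    """Hypothetical custom hook: report where results were saved."""
    print(f"Training finished, results in {trainer.save_dir}")

callbacks = get_default_callbacks()                 # defaultdict: event name -> list of hook functions
callbacks["on_train_end"].append(my_on_train_end)   # append, never replace, so the defaults still run
# a trainer then fires an event roughly as: for f in callbacks["on_train_end"]: f(trainer)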

View File

@@ -0,0 +1,152 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING
try:
assert not TESTS_RUNNING # do not log pytest
assert SETTINGS["clearml"] is True # verify integration is enabled
import clearml
from clearml import Task
from clearml.binding.frameworks.pytorch_bind import PatchPyTorchModelIO
from clearml.binding.matplotlib_bind import PatchedMatplotlib
assert hasattr(clearml, "__version__") # verify package is not directory
except (ImportError, AssertionError):
clearml = None
def _log_debug_samples(files, title="Debug Samples") -> None:
"""
Log files (images) as debug samples in the ClearML task.
Args:
files (list): A list of file paths in PosixPath format.
title (str): A title that groups together images with the same values.
"""
import re
if task := Task.current_task():
for f in files:
if f.exists():
it = re.search(r"_batch(\d+)", f.name)
iteration = int(it.groups()[0]) if it else 0
task.get_logger().report_image(
title=title, series=f.name.replace(it.group(), ""), local_path=str(f), iteration=iteration
)
def _log_plot(title, plot_path) -> None:
"""
Log an image as a plot in the plot section of ClearML.
Args:
title (str): The title of the plot.
plot_path (str): The path to the saved image file.
"""
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
img = mpimg.imread(plot_path)
fig = plt.figure()
ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect="auto", xticks=[], yticks=[]) # no ticks
ax.imshow(img)
Task.current_task().get_logger().report_matplotlib_figure(
title=title, series="", figure=fig, report_interactive=False
)
def on_pretrain_routine_start(trainer):
"""Runs at start of pretraining routine; initializes and connects/ logs task to ClearML."""
try:
if task := Task.current_task():
# Make sure the automatic pytorch and matplotlib bindings are disabled!
# We are logging these plots and model files manually in the integration
PatchPyTorchModelIO.update_current_task(None)
PatchedMatplotlib.update_current_task(None)
else:
task = Task.init(
project_name=trainer.args.project or "YOLOv8",
task_name=trainer.args.name,
tags=["YOLOv8"],
output_uri=True,
reuse_last_task_id=False,
auto_connect_frameworks={"pytorch": False, "matplotlib": False},
)
LOGGER.warning(
"ClearML Initialized a new task. If you want to run remotely, "
"please add clearml-init and connect your arguments before initializing YOLO."
)
task.connect(vars(trainer.args), name="General")
except Exception as e:
LOGGER.warning(f"WARNING ⚠️ ClearML installed but not initialized correctly, not logging this run. {e}")
def on_train_epoch_end(trainer):
"""Logs debug samples for the first epoch of YOLO training and report current training progress."""
if task := Task.current_task():
# Log debug samples
if trainer.epoch == 1:
_log_debug_samples(sorted(trainer.save_dir.glob("train_batch*.jpg")), "Mosaic")
# Report the current training progress
for k, v in trainer.label_loss_items(trainer.tloss, prefix="train").items():
task.get_logger().report_scalar("train", k, v, iteration=trainer.epoch)
for k, v in trainer.lr.items():
task.get_logger().report_scalar("lr", k, v, iteration=trainer.epoch)
def on_fit_epoch_end(trainer):
"""Reports model information to logger at the end of an epoch."""
if task := Task.current_task():
# You should have access to the validation bboxes under jdict
task.get_logger().report_scalar(
title="Epoch Time", series="Epoch Time", value=trainer.epoch_time, iteration=trainer.epoch
)
for k, v in trainer.metrics.items():
task.get_logger().report_scalar("val", k, v, iteration=trainer.epoch)
if trainer.epoch == 0:
from ultralytics.utils.torch_utils import model_info_for_loggers
for k, v in model_info_for_loggers(trainer).items():
task.get_logger().report_single_value(k, v)
def on_val_end(validator):
"""Logs validation results including labels and predictions."""
if Task.current_task():
# Log val_labels and val_pred
_log_debug_samples(sorted(validator.save_dir.glob("val*.jpg")), "Validation")
def on_train_end(trainer):
"""Logs final model and its name on training completion."""
if task := Task.current_task():
# Log final results, CM matrix + PR plots
files = [
"results.png",
"confusion_matrix.png",
"confusion_matrix_normalized.png",
*(f"{x}_curve.png" for x in ("F1", "PR", "P", "R")),
]
files = [(trainer.save_dir / f) for f in files if (trainer.save_dir / f).exists()] # filter
for f in files:
_log_plot(title=f.stem, plot_path=f)
# Report final metrics
for k, v in trainer.validator.metrics.results_dict.items():
task.get_logger().report_single_value(k, v)
# Log the final model
task.update_output_model(model_path=str(trainer.best), model_name=trainer.args.name, auto_delete_file=False)
callbacks = (
{
"on_pretrain_routine_start": on_pretrain_routine_start,
"on_train_epoch_end": on_train_epoch_end,
"on_fit_epoch_end": on_fit_epoch_end,
"on_val_end": on_val_end,
"on_train_end": on_train_end,
}
if clearml
else {}
)
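
For reference, a standalone sketch (with a hypothetical filename) of the `_batch` parsing that `_log_debug_samples` uses to turn debug-image filenames into ClearML iterations and series names:

import re
from pathlib import Path

f = Path("train_batch2.jpg")                     # hypothetical debug-sample file
it = re.search(r"_batch(\d+)", f.name)
iteration = int(it.groups()[0]) if it else 0     # -> 2, reported as the ClearML iteration
series = f.name.replace(it.group(), "")          # -> "train.jpg", used as the series name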

View File

@@ -0,0 +1,375 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from ultralytics.utils import LOGGER, RANK, SETTINGS, TESTS_RUNNING, ops
try:
assert not TESTS_RUNNING # do not log pytest
assert SETTINGS["comet"] is True # verify integration is enabled
import comet_ml
assert hasattr(comet_ml, "__version__") # verify package is not directory
import os
from pathlib import Path
# Ensures certain logging functions only run for supported tasks
COMET_SUPPORTED_TASKS = ["detect"]
# Names of plots created by YOLOv8 that are logged to Comet
EVALUATION_PLOT_NAMES = "F1_curve", "P_curve", "R_curve", "PR_curve", "confusion_matrix"
LABEL_PLOT_NAMES = "labels", "labels_correlogram"
_comet_image_prediction_count = 0
except (ImportError, AssertionError):
comet_ml = None
def _get_comet_mode():
"""Returns the mode of comet set in the environment variables, defaults to 'online' if not set."""
return os.getenv("COMET_MODE", "online")
def _get_comet_model_name():
"""Returns the model name for Comet from the environment variable 'COMET_MODEL_NAME' or defaults to 'YOLOv8'."""
return os.getenv("COMET_MODEL_NAME", "YOLOv8")
def _get_eval_batch_logging_interval():
"""Get the evaluation batch logging interval from environment variable or use default value 1."""
return int(os.getenv("COMET_EVAL_BATCH_LOGGING_INTERVAL", 1))
def _get_max_image_predictions_to_log():
"""Get the maximum number of image predictions to log from the environment variables."""
return int(os.getenv("COMET_MAX_IMAGE_PREDICTIONS", 100))
def _scale_confidence_score(score):
"""Scales the given confidence score by a factor specified in an environment variable."""
scale = float(os.getenv("COMET_MAX_CONFIDENCE_SCORE", 100.0))
return score * scale
def _should_log_confusion_matrix():
"""Determines if the confusion matrix should be logged based on the environment variable settings."""
return os.getenv("COMET_EVAL_LOG_CONFUSION_MATRIX", "false").lower() == "true"
def _should_log_image_predictions():
"""Determines whether to log image predictions based on a specified environment variable."""
return os.getenv("COMET_EVAL_LOG_IMAGE_PREDICTIONS", "true").lower() == "true"
def _get_experiment_type(mode, project_name):
"""Return an experiment based on mode and project name."""
if mode == "offline":
return comet_ml.OfflineExperiment(project_name=project_name)
return comet_ml.Experiment(project_name=project_name)
def _create_experiment(args):
"""Ensures that the experiment object is only created in a single process during distributed training."""
if RANK not in (-1, 0):
return
try:
comet_mode = _get_comet_mode()
_project_name = os.getenv("COMET_PROJECT_NAME", args.project)
experiment = _get_experiment_type(comet_mode, _project_name)
experiment.log_parameters(vars(args))
experiment.log_others(
{
"eval_batch_logging_interval": _get_eval_batch_logging_interval(),
"log_confusion_matrix_on_eval": _should_log_confusion_matrix(),
"log_image_predictions": _should_log_image_predictions(),
"max_image_predictions": _get_max_image_predictions_to_log(),
}
)
experiment.log_other("Created from", "yolov8")
except Exception as e:
LOGGER.warning(f"WARNING ⚠️ Comet installed but not initialized correctly, not logging this run. {e}")
def _fetch_trainer_metadata(trainer):
"""Returns metadata for YOLO training including epoch and asset saving status."""
curr_epoch = trainer.epoch + 1
train_num_steps_per_epoch = len(trainer.train_loader.dataset) // trainer.batch_size
curr_step = curr_epoch * train_num_steps_per_epoch
final_epoch = curr_epoch == trainer.epochs
save = trainer.args.save
save_period = trainer.args.save_period
save_interval = curr_epoch % save_period == 0
save_assets = save and save_period > 0 and save_interval and not final_epoch
return dict(curr_epoch=curr_epoch, curr_step=curr_step, save_assets=save_assets, final_epoch=final_epoch)
def _scale_bounding_box_to_original_image_shape(box, resized_image_shape, original_image_shape, ratio_pad):
"""
YOLOv8 resizes images during training and the label values are normalized based on this resized shape.
This function rescales the bounding box labels to the original image shape.
"""
resized_image_height, resized_image_width = resized_image_shape
# Convert normalized xywh format predictions to xyxy in resized scale format
box = ops.xywhn2xyxy(box, h=resized_image_height, w=resized_image_width)
# Scale box predictions from resized image scale back to original image scale
box = ops.scale_boxes(resized_image_shape, box, original_image_shape, ratio_pad)
# Convert bounding box format from xyxy to xywh for Comet logging
box = ops.xyxy2xywh(box)
# Adjust xy center to correspond top-left corner
box[:2] -= box[2:] / 2
box = box.tolist()
return box
def _format_ground_truth_annotations_for_detection(img_idx, image_path, batch, class_name_map=None):
"""Format ground truth annotations for detection."""
indices = batch["batch_idx"] == img_idx
bboxes = batch["bboxes"][indices]
if len(bboxes) == 0:
LOGGER.debug(f"COMET WARNING: Image: {image_path} has no bounding boxes labels")
return None
cls_labels = batch["cls"][indices].squeeze(1).tolist()
if class_name_map:
cls_labels = [str(class_name_map[label]) for label in cls_labels]
original_image_shape = batch["ori_shape"][img_idx]
resized_image_shape = batch["resized_shape"][img_idx]
ratio_pad = batch["ratio_pad"][img_idx]
data = []
for box, label in zip(bboxes, cls_labels):
box = _scale_bounding_box_to_original_image_shape(box, resized_image_shape, original_image_shape, ratio_pad)
data.append(
{
"boxes": [box],
"label": f"gt_{label}",
"score": _scale_confidence_score(1.0),
}
)
return {"name": "ground_truth", "data": data}
def _format_prediction_annotations_for_detection(image_path, metadata, class_label_map=None):
"""Format YOLO predictions for object detection visualization."""
stem = image_path.stem
image_id = int(stem) if stem.isnumeric() else stem
predictions = metadata.get(image_id)
if not predictions:
LOGGER.debug(f"COMET WARNING: Image: {image_path} has no bounding boxes predictions")
return None
data = []
for prediction in predictions:
boxes = prediction["bbox"]
score = _scale_confidence_score(prediction["score"])
cls_label = prediction["category_id"]
if class_label_map:
cls_label = str(class_label_map[cls_label])
data.append({"boxes": [boxes], "label": cls_label, "score": score})
return {"name": "prediction", "data": data}
def _fetch_annotations(img_idx, image_path, batch, prediction_metadata_map, class_label_map):
"""Join the ground truth and prediction annotations if they exist."""
ground_truth_annotations = _format_ground_truth_annotations_for_detection(
img_idx, image_path, batch, class_label_map
)
prediction_annotations = _format_prediction_annotations_for_detection(
image_path, prediction_metadata_map, class_label_map
)
annotations = [
annotation for annotation in [ground_truth_annotations, prediction_annotations] if annotation is not None
]
return [annotations] if annotations else None
def _create_prediction_metadata_map(model_predictions):
"""Create metadata map for model predictions by groupings them based on image ID."""
pred_metadata_map = {}
for prediction in model_predictions:
pred_metadata_map.setdefault(prediction["image_id"], [])
pred_metadata_map[prediction["image_id"]].append(prediction)
return pred_metadata_map
def _log_confusion_matrix(experiment, trainer, curr_step, curr_epoch):
"""Log the confusion matrix to Comet experiment."""
conf_mat = trainer.validator.confusion_matrix.matrix
names = list(trainer.data["names"].values()) + ["background"]
experiment.log_confusion_matrix(
matrix=conf_mat, labels=names, max_categories=len(names), epoch=curr_epoch, step=curr_step
)
def _log_images(experiment, image_paths, curr_step, annotations=None):
"""Logs images to the experiment with optional annotations."""
if annotations:
for image_path, annotation in zip(image_paths, annotations):
experiment.log_image(image_path, name=image_path.stem, step=curr_step, annotations=annotation)
else:
for image_path in image_paths:
experiment.log_image(image_path, name=image_path.stem, step=curr_step)
def _log_image_predictions(experiment, validator, curr_step):
"""Logs predicted boxes for a single image during training."""
global _comet_image_prediction_count
task = validator.args.task
if task not in COMET_SUPPORTED_TASKS:
return
jdict = validator.jdict
if not jdict:
return
predictions_metadata_map = _create_prediction_metadata_map(jdict)
dataloader = validator.dataloader
class_label_map = validator.names
batch_logging_interval = _get_eval_batch_logging_interval()
max_image_predictions = _get_max_image_predictions_to_log()
for batch_idx, batch in enumerate(dataloader):
if (batch_idx + 1) % batch_logging_interval != 0:
continue
image_paths = batch["im_file"]
for img_idx, image_path in enumerate(image_paths):
if _comet_image_prediction_count >= max_image_predictions:
return
image_path = Path(image_path)
annotations = _fetch_annotations(
img_idx,
image_path,
batch,
predictions_metadata_map,
class_label_map,
)
_log_images(
experiment,
[image_path],
curr_step,
annotations=annotations,
)
_comet_image_prediction_count += 1
def _log_plots(experiment, trainer):
"""Logs evaluation plots and label plots for the experiment."""
plot_filenames = [trainer.save_dir / f"{plots}.png" for plots in EVALUATION_PLOT_NAMES]
_log_images(experiment, plot_filenames, None)
label_plot_filenames = [trainer.save_dir / f"{labels}.jpg" for labels in LABEL_PLOT_NAMES]
_log_images(experiment, label_plot_filenames, None)
def _log_model(experiment, trainer):
"""Log the best-trained model to Comet.ml."""
model_name = _get_comet_model_name()
experiment.log_model(model_name, file_or_folder=str(trainer.best), file_name="best_gift_v10n.pt", overwrite=True)
def on_pretrain_routine_start(trainer):
"""Creates or resumes a CometML experiment at the start of a YOLO pre-training routine."""
experiment = comet_ml.get_global_experiment()
is_alive = getattr(experiment, "alive", False)
if not experiment or not is_alive:
_create_experiment(trainer.args)
def on_train_epoch_end(trainer):
"""Log metrics and save batch images at the end of training epochs."""
experiment = comet_ml.get_global_experiment()
if not experiment:
return
metadata = _fetch_trainer_metadata(trainer)
curr_epoch = metadata["curr_epoch"]
curr_step = metadata["curr_step"]
experiment.log_metrics(trainer.label_loss_items(trainer.tloss, prefix="train"), step=curr_step, epoch=curr_epoch)
if curr_epoch == 1:
_log_images(experiment, trainer.save_dir.glob("train_batch*.jpg"), curr_step)
def on_fit_epoch_end(trainer):
"""Logs model assets at the end of each epoch."""
experiment = comet_ml.get_global_experiment()
if not experiment:
return
metadata = _fetch_trainer_metadata(trainer)
curr_epoch = metadata["curr_epoch"]
curr_step = metadata["curr_step"]
save_assets = metadata["save_assets"]
experiment.log_metrics(trainer.metrics, step=curr_step, epoch=curr_epoch)
experiment.log_metrics(trainer.lr, step=curr_step, epoch=curr_epoch)
if curr_epoch == 1:
from ultralytics.utils.torch_utils import model_info_for_loggers
experiment.log_metrics(model_info_for_loggers(trainer), step=curr_step, epoch=curr_epoch)
if not save_assets:
return
_log_model(experiment, trainer)
if _should_log_confusion_matrix():
_log_confusion_matrix(experiment, trainer, curr_step, curr_epoch)
if _should_log_image_predictions():
_log_image_predictions(experiment, trainer.validator, curr_step)
def on_train_end(trainer):
"""Perform operations at the end of training."""
experiment = comet_ml.get_global_experiment()
if not experiment:
return
metadata = _fetch_trainer_metadata(trainer)
curr_epoch = metadata["curr_epoch"]
curr_step = metadata["curr_step"]
plots = trainer.args.plots
_log_model(experiment, trainer)
if plots:
_log_plots(experiment, trainer)
_log_confusion_matrix(experiment, trainer, curr_step, curr_epoch)
_log_image_predictions(experiment, trainer.validator, curr_step)
experiment.end()
global _comet_image_prediction_count
_comet_image_prediction_count = 0
callbacks = (
{
"on_pretrain_routine_start": on_pretrain_routine_start,
"on_train_epoch_end": on_train_epoch_end,
"on_fit_epoch_end": on_fit_epoch_end,
"on_train_end": on_train_end,
}
if comet_ml
else {}
)
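
To make the prediction grouping concrete, a tiny worked example of `_create_prediction_metadata_map` with hypothetical COCO-style entries from `validator.jdict`:

predictions = [  # hypothetical jdict entries
    {"image_id": 1, "bbox": [10.0, 20.0, 30.0, 40.0], "score": 0.9, "category_id": 0},
    {"image_id": 1, "bbox": [50.0, 60.0, 20.0, 20.0], "score": 0.4, "category_id": 2},
    {"image_id": 2, "bbox": [0.0, 0.0, 15.0, 15.0], "score": 0.7, "category_id": 0},
]
metadata_map = _create_prediction_metadata_map(predictions)
# metadata_map == {1: [<first two entries>], 2: [<last entry>]}
# _format_prediction_annotations_for_detection then looks images up by int(image_path.stem)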

View File

@@ -0,0 +1,145 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, checks
try:
assert not TESTS_RUNNING # do not log pytest
assert SETTINGS["dvc"] is True # verify integration is enabled
import dvclive
assert checks.check_version("dvclive", "2.11.0", verbose=True)
import os
import re
from pathlib import Path
# DVCLive logger instance
live = None
_processed_plots = {}
# `on_fit_epoch_end` is also called on the final validation (probably needs to be fixed); for now this is how we
# distinguish the final evaluation of the best model from last-epoch validation
_training_epoch = False
except (ImportError, AssertionError, TypeError):
dvclive = None
def _log_images(path, prefix=""):
"""Logs images at specified path with an optional prefix using DVCLive."""
if live:
name = path.name
# Group images by batch to enable sliders in UI
if m := re.search(r"_batch(\d+)", name):
ni = m[1]
new_stem = re.sub(r"_batch(\d+)", "_batch", path.stem)
name = (Path(new_stem) / ni).with_suffix(path.suffix)
live.log_image(os.path.join(prefix, name), path)
def _log_plots(plots, prefix=""):
"""Logs plot images for training progress if they have not been previously processed."""
for name, params in plots.items():
timestamp = params["timestamp"]
if _processed_plots.get(name) != timestamp:
_log_images(name, prefix)
_processed_plots[name] = timestamp
def _log_confusion_matrix(validator):
"""Logs the confusion matrix for the given validator using DVCLive."""
targets = []
preds = []
matrix = validator.confusion_matrix.matrix
names = list(validator.names.values())
if validator.confusion_matrix.task == "detect":
names += ["background"]
for ti, pred in enumerate(matrix.T.astype(int)):
for pi, num in enumerate(pred):
targets.extend([names[ti]] * num)
preds.extend([names[pi]] * num)
live.log_sklearn_plot("confusion_matrix", targets, preds, name="cf.json", normalized=True)
def on_pretrain_routine_start(trainer):
"""Initializes DVCLive logger for training metadata during pre-training routine."""
try:
global live
live = dvclive.Live(save_dvc_exp=True, cache_images=True)
LOGGER.info("DVCLive is detected and auto logging is enabled (run 'yolo settings dvc=False' to disable).")
except Exception as e:
LOGGER.warning(f"WARNING ⚠️ DVCLive installed but not initialized correctly, not logging this run. {e}")
def on_pretrain_routine_end(trainer):
"""Logs plots related to the training process at the end of the pretraining routine."""
_log_plots(trainer.plots, "train")
def on_train_start(trainer):
"""Logs the training parameters if DVCLive logging is active."""
if live:
live.log_params(trainer.args)
def on_train_epoch_start(trainer):
"""Sets the global variable _training_epoch value to True at the start of training each epoch."""
global _training_epoch
_training_epoch = True
def on_fit_epoch_end(trainer):
"""Logs training metrics and model info, and advances to next step on the end of each fit epoch."""
global _training_epoch
if live and _training_epoch:
all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix="train"), **trainer.metrics, **trainer.lr}
for metric, value in all_metrics.items():
live.log_metric(metric, value)
if trainer.epoch == 0:
from ultralytics.utils.torch_utils import model_info_for_loggers
for metric, value in model_info_for_loggers(trainer).items():
live.log_metric(metric, value, plot=False)
_log_plots(trainer.plots, "train")
_log_plots(trainer.validator.plots, "val")
live.next_step()
_training_epoch = False
def on_train_end(trainer):
"""Logs the best metrics, plots, and confusion matrix at the end of training if DVCLive is active."""
if live:
# At the end log the best metrics. It runs validator on the best model internally.
all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix="train"), **trainer.metrics, **trainer.lr}
for metric, value in all_metrics.items():
live.log_metric(metric, value, plot=False)
_log_plots(trainer.plots, "val")
_log_plots(trainer.validator.plots, "val")
_log_confusion_matrix(trainer.validator)
if trainer.best.exists():
live.log_artifact(trainer.best, copy=True, type="model")
live.end()
callbacks = (
{
"on_pretrain_routine_start": on_pretrain_routine_start,
"on_pretrain_routine_end": on_pretrain_routine_end,
"on_train_start": on_train_start,
"on_train_epoch_start": on_train_epoch_start,
"on_fit_epoch_end": on_fit_epoch_end,
"on_train_end": on_train_end,
}
if dvclive
else {}
)
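
The confusion-matrix flattening above expands each cell count into that many (target, prediction) label pairs, which is the per-sample format `live.log_sklearn_plot` expects; a small worked example with a hypothetical matrix:

import numpy as np

matrix = np.array([[3, 1, 0],                      # hypothetical 2-class (+background) matrix
                   [0, 2, 1],
                   [1, 0, 0]])
names = ["cat", "dog", "background"]
targets, preds = [], []
for ti, pred in enumerate(matrix.T.astype(int)):   # ti indexes the target (ground-truth) class
    for pi, num in enumerate(pred):                # pi indexes the predicted class
        targets.extend([names[ti]] * num)
        preds.extend([names[pi]] * num)
# targets and preds are the flat label lists passed to live.log_sklearn_plot("confusion_matrix", ...)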

View File

@@ -0,0 +1,108 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import json
from time import time
from ultralytics.hub.utils import HUB_WEB_ROOT, PREFIX, events
from ultralytics.utils import LOGGER, SETTINGS
def on_pretrain_routine_end(trainer):
"""Logs info before starting timer for upload rate limit."""
session = getattr(trainer, "hub_session", None)
if session:
# Start timer for upload rate limit
session.timers = {
"metrics": time(),
"ckpt": time(),
} # start timer on session.rate_limit
def on_fit_epoch_end(trainer):
"""Uploads training progress metrics at the end of each epoch."""
session = getattr(trainer, "hub_session", None)
if session:
# Upload metrics after val end
all_plots = {
**trainer.label_loss_items(trainer.tloss, prefix="train"),
**trainer.metrics,
}
if trainer.epoch == 0:
from ultralytics.utils.torch_utils import model_info_for_loggers
all_plots = {**all_plots, **model_info_for_loggers(trainer)}
session.metrics_queue[trainer.epoch] = json.dumps(all_plots)
# If any metrics fail to upload, add them to the queue to attempt uploading again.
if session.metrics_upload_failed_queue:
session.metrics_queue.update(session.metrics_upload_failed_queue)
if time() - session.timers["metrics"] > session.rate_limits["metrics"]:
session.upload_metrics()
session.timers["metrics"] = time() # reset timer
session.metrics_queue = {} # reset queue
def on_model_save(trainer):
"""Saves checkpoints to Ultralytics HUB with rate limiting."""
session = getattr(trainer, "hub_session", None)
if session:
# Upload checkpoints with rate limiting
is_best = trainer.best_fitness == trainer.fitness
if time() - session.timers["ckpt"] > session.rate_limits["ckpt"]:
LOGGER.info(f"{PREFIX}Uploading checkpoint {HUB_WEB_ROOT}/models/{session.model.id}")
session.upload_model(trainer.epoch, trainer.last, is_best)
session.timers["ckpt"] = time() # reset timer
def on_train_end(trainer):
"""Upload final model and metrics to Ultralytics HUB at the end of training."""
session = getattr(trainer, "hub_session", None)
if session:
# Upload final model and metrics with exponential backoff
LOGGER.info(f"{PREFIX}Syncing final model...")
session.upload_model(
trainer.epoch,
trainer.best,
map=trainer.metrics.get("metrics/mAP50-95(B)", 0),
final=True,
)
session.alive = False # stop heartbeats
LOGGER.info(f"{PREFIX}Done ✅\n" f"{PREFIX}View model at {session.model_url} 🚀")
def on_train_start(trainer):
"""Run events on train start."""
events(trainer.args)
def on_val_start(validator):
"""Runs events on validation start."""
events(validator.args)
def on_predict_start(predictor):
"""Run events on predict start."""
events(predictor.args)
def on_export_start(exporter):
"""Run events on export start."""
events(exporter.args)
callbacks = (
{
"on_pretrain_routine_end": on_pretrain_routine_end,
"on_fit_epoch_end": on_fit_epoch_end,
"on_model_save": on_model_save,
"on_train_end": on_train_end,
"on_train_start": on_train_start,
"on_val_start": on_val_start,
"on_predict_start": on_predict_start,
"on_export_start": on_export_start,
}
if SETTINGS["hub"] is True
else {}
) # verify enabled
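
The metric and checkpoint uploads above are throttled with a simple elapsed-time check; a minimal sketch of the pattern, with hypothetical rate limits in seconds:

from time import time

timers = {"metrics": time(), "ckpt": time()}   # started in on_pretrain_routine_end
rate_limits = {"metrics": 300, "ckpt": 900}    # hypothetical seconds between uploads

def maybe_upload(kind, upload_fn):
    """Run upload_fn only if the rate limit for this kind has elapsed, then restart its timer."""
    if time() - timers[kind] > rate_limits[kind]:
        upload_fn()
        timers[kind] = time()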

View File

@@ -0,0 +1,133 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
MLflow Logging for Ultralytics YOLO.
This module enables MLflow logging for Ultralytics YOLO. It logs metrics, parameters, and model artifacts.
To set it up, specify a tracking URI. The logging can be customized further using environment variables.
Commands:
1. To set a project name:
`export MLFLOW_EXPERIMENT_NAME=<your_experiment_name>` or use the project=<project> argument
2. To set a run name:
`export MLFLOW_RUN=<your_run_name>` or use the name=<name> argument
3. To start a local MLflow server:
mlflow server --backend-store-uri runs/mlflow
It will by default start a local server at http://127.0.0.1:5000.
To specify a different URI, set the MLFLOW_TRACKING_URI environment variable.
4. To kill all running MLflow server instances:
ps aux | grep 'mlflow' | grep -v 'grep' | awk '{print $2}' | xargs kill -9
"""
from ultralytics.utils import LOGGER, RUNS_DIR, SETTINGS, TESTS_RUNNING, colorstr
try:
import os
assert not TESTS_RUNNING or "test_mlflow" in os.environ.get("PYTEST_CURRENT_TEST", "") # do not log pytest
assert SETTINGS["mlflow"] is True # verify integration is enabled
import mlflow
assert hasattr(mlflow, "__version__") # verify package is not directory
from pathlib import Path
PREFIX = colorstr("MLflow: ")
SANITIZE = lambda x: {k.replace("(", "").replace(")", ""): float(v) for k, v in x.items()}
except (ImportError, AssertionError):
mlflow = None
def on_pretrain_routine_end(trainer):
"""
Log training parameters to MLflow at the end of the pretraining routine.
This function sets up MLflow logging based on environment variables and trainer arguments. It sets the tracking URI,
experiment name, and run name, then starts the MLflow run if not already active. It finally logs the parameters
from the trainer.
Args:
trainer (ultralytics.engine.trainer.BaseTrainer): The training object with arguments and parameters to log.
Global:
mlflow: The imported mlflow module to use for logging.
Environment Variables:
MLFLOW_TRACKING_URI: The URI for MLflow tracking. If not set, defaults to 'runs/mlflow'.
MLFLOW_EXPERIMENT_NAME: The name of the MLflow experiment. If not set, defaults to trainer.args.project.
MLFLOW_RUN: The name of the MLflow run. If not set, defaults to trainer.args.name.
MLFLOW_KEEP_RUN_ACTIVE: Boolean indicating whether to keep the MLflow run active after the end of the training phase.
"""
global mlflow
uri = os.environ.get("MLFLOW_TRACKING_URI") or str(RUNS_DIR / "mlflow")
LOGGER.debug(f"{PREFIX} tracking uri: {uri}")
mlflow.set_tracking_uri(uri)
# Set experiment and run names
experiment_name = os.environ.get("MLFLOW_EXPERIMENT_NAME") or trainer.args.project or "/Shared/YOLOv8"
run_name = os.environ.get("MLFLOW_RUN") or trainer.args.name
mlflow.set_experiment(experiment_name)
mlflow.autolog()
try:
active_run = mlflow.active_run() or mlflow.start_run(run_name=run_name)
LOGGER.info(f"{PREFIX}logging run_id({active_run.info.run_id}) to {uri}")
if Path(uri).is_dir():
LOGGER.info(f"{PREFIX}view at http://127.0.0.1:5000 with 'mlflow server --backend-store-uri {uri}'")
LOGGER.info(f"{PREFIX}disable with 'yolo settings mlflow=False'")
mlflow.log_params(dict(trainer.args))
except Exception as e:
LOGGER.warning(f"{PREFIX}WARNING ⚠️ Failed to initialize: {e}\n" f"{PREFIX}WARNING ⚠️ Not tracking this run")
def on_train_epoch_end(trainer):
"""Log training metrics at the end of each train epoch to MLflow."""
if mlflow:
mlflow.log_metrics(
metrics={
**SANITIZE(trainer.lr),
**SANITIZE(trainer.label_loss_items(trainer.tloss, prefix="train")),
},
step=trainer.epoch,
)
def on_fit_epoch_end(trainer):
"""Log training metrics at the end of each fit epoch to MLflow."""
if mlflow:
mlflow.log_metrics(metrics=SANITIZE(trainer.metrics), step=trainer.epoch)
def on_train_end(trainer):
"""Log model artifacts at the end of the training."""
if mlflow:
mlflow.log_artifact(str(trainer.best.parent)) # log save_dir/weights directory with best_gift_v10n.pt and last.pt
for f in trainer.save_dir.glob("*"): # log all other files in save_dir
if f.suffix in {".png", ".jpg", ".csv", ".pt", ".yaml"}:
mlflow.log_artifact(str(f))
keep_run_active = os.environ.get("MLFLOW_KEEP_RUN_ACTIVE", "False").lower() == "true"
if keep_run_active:
LOGGER.info(f"{PREFIX}mlflow run still alive, remember to close it using mlflow.end_run()")
else:
mlflow.end_run()
LOGGER.debug(f"{PREFIX}mlflow run ended")
LOGGER.info(
f"{PREFIX}results logged to {mlflow.get_tracking_uri()}\n"
f"{PREFIX}disable with 'yolo settings mlflow=False'"
)
callbacks = (
{
"on_pretrain_routine_end": on_pretrain_routine_end,
"on_train_epoch_end": on_train_epoch_end,
"on_fit_epoch_end": on_fit_epoch_end,
"on_train_end": on_train_end,
}
if mlflow
else {}
)
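
The `SANITIZE` lambda above strips parentheses from metric keys (MLflow restricts the characters allowed in metric names) and coerces the values to float; a quick illustration with hypothetical YOLO metric names:

SANITIZE = lambda x: {k.replace("(", "").replace(")", ""): float(v) for k, v in x.items()}

print(SANITIZE({"metrics/mAP50(B)": 0.52, "metrics/precision(B)": 0.81}))
# -> {'metrics/mAP50B': 0.52, 'metrics/precisionB': 0.81}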

View File

@@ -0,0 +1,112 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING
try:
assert not TESTS_RUNNING # do not log pytest
assert SETTINGS["neptune"] is True # verify integration is enabled
import neptune
from neptune.types import File
assert hasattr(neptune, "__version__")
run = None # NeptuneAI experiment logger instance
except (ImportError, AssertionError):
neptune = None
def _log_scalars(scalars, step=0):
"""Log scalars to the NeptuneAI experiment logger."""
if run:
for k, v in scalars.items():
run[k].append(value=v, step=step)
def _log_images(imgs_dict, group=""):
"""Log scalars to the NeptuneAI experiment logger."""
if run:
for k, v in imgs_dict.items():
run[f"{group}/{k}"].upload(File(v))
def _log_plot(title, plot_path):
"""
Log plots to the NeptuneAI experiment logger.
Args:
title (str): Title of the plot.
plot_path (PosixPath | str): Path to the saved image file.
"""
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
img = mpimg.imread(plot_path)
fig = plt.figure()
ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect="auto", xticks=[], yticks=[]) # no ticks
ax.imshow(img)
run[f"Plots/{title}"].upload(fig)
def on_pretrain_routine_start(trainer):
"""Callback function called before the training routine starts."""
try:
global run
run = neptune.init_run(project=trainer.args.project or "YOLOv8", name=trainer.args.name, tags=["YOLOv8"])
run["Configuration/Hyperparameters"] = {k: "" if v is None else v for k, v in vars(trainer.args).items()}
except Exception as e:
LOGGER.warning(f"WARNING ⚠️ NeptuneAI installed but not initialized correctly, not logging this run. {e}")
def on_train_epoch_end(trainer):
"""Callback function called at end of each training epoch."""
_log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1)
_log_scalars(trainer.lr, trainer.epoch + 1)
if trainer.epoch == 1:
_log_images({f.stem: str(f) for f in trainer.save_dir.glob("train_batch*.jpg")}, "Mosaic")
def on_fit_epoch_end(trainer):
"""Callback function called at end of each fit (train+val) epoch."""
if run and trainer.epoch == 0:
from ultralytics.utils.torch_utils import model_info_for_loggers
run["Configuration/Model"] = model_info_for_loggers(trainer)
_log_scalars(trainer.metrics, trainer.epoch + 1)
def on_val_end(validator):
"""Callback function called at end of each validation."""
if run:
# Log val_labels and val_pred
_log_images({f.stem: str(f) for f in validator.save_dir.glob("val*.jpg")}, "Validation")
def on_train_end(trainer):
"""Callback function called at end of training."""
if run:
# Log final results, CM matrix + PR plots
files = [
"results.png",
"confusion_matrix.png",
"confusion_matrix_normalized.png",
*(f"{x}_curve.png" for x in ("F1", "PR", "P", "R")),
]
files = [(trainer.save_dir / f) for f in files if (trainer.save_dir / f).exists()] # filter
for f in files:
_log_plot(title=f.stem, plot_path=f)
# Log the final model
run[f"weights/{trainer.args.name or trainer.args.task}/{trainer.best.name}"].upload(File(str(trainer.best)))
callbacks = (
{
"on_pretrain_routine_start": on_pretrain_routine_start,
"on_train_epoch_end": on_train_epoch_end,
"on_fit_epoch_end": on_fit_epoch_end,
"on_val_end": on_val_end,
"on_train_end": on_train_end,
}
if neptune
else {}
)
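
A small illustration of the hyperparameter sanitization performed in on_pretrain_routine_start, with hypothetical trainer args (None is swapped for an empty string, presumably because Neptune fields do not accept None directly):

args = {"model": "yolov8n.pt", "resume": None, "epochs": 100}   # hypothetical trainer args
run["Configuration/Hyperparameters"] = {k: "" if v is None else v for k, v in args.items()}
# mirrors the dict comprehension above: None becomes "" before the values are assigned to the run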

View File

@@ -0,0 +1,29 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from ultralytics.utils import SETTINGS
try:
assert SETTINGS["raytune"] is True # verify integration is enabled
import ray
from ray import tune
from ray.air import session
except (ImportError, AssertionError):
tune = None
def on_fit_epoch_end(trainer):
"""Sends training metrics to Ray Tune at end of each epoch."""
if ray.tune.is_session_enabled():
metrics = trainer.metrics
metrics["epoch"] = trainer.epoch
session.report(metrics)
callbacks = (
{
"on_fit_epoch_end": on_fit_epoch_end,
}
if tune
else {}
)

View File

@@ -0,0 +1,106 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import contextlib
from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, colorstr
try:
# WARNING: do not move SummaryWriter import due to protobuf bug https://github.com/ultralytics/ultralytics/pull/4674
from torch.utils.tensorboard import SummaryWriter
assert not TESTS_RUNNING # do not log pytest
assert SETTINGS["tensorboard"] is True # verify integration is enabled
WRITER = None # TensorBoard SummaryWriter instance
PREFIX = colorstr("TensorBoard: ")
# Imports below only required if TensorBoard enabled
import warnings
from copy import deepcopy
from ultralytics.utils.torch_utils import de_parallel, torch
except (ImportError, AssertionError, TypeError, AttributeError):
# TypeError for handling 'Descriptors cannot not be created directly.' protobuf errors in Windows
# AttributeError: module 'tensorflow' has no attribute 'io' if 'tensorflow' not installed
SummaryWriter = None
def _log_scalars(scalars, step=0):
"""Logs scalar values to TensorBoard."""
if WRITER:
for k, v in scalars.items():
WRITER.add_scalar(k, v, step)
def _log_tensorboard_graph(trainer):
"""Log model graph to TensorBoard."""
# Input image
imgsz = trainer.args.imgsz
imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz
p = next(trainer.model.parameters()) # for device, type
im = torch.zeros((1, 3, *imgsz), device=p.device, dtype=p.dtype) # input image (must be zeros, not empty)
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning) # suppress jit trace warning
warnings.simplefilter("ignore", category=torch.jit.TracerWarning) # suppress jit trace warning
# Try simple method first (YOLO)
with contextlib.suppress(Exception):
trainer.model.eval() # place in .eval() mode to avoid BatchNorm statistics changes
WRITER.add_graph(torch.jit.trace(de_parallel(trainer.model), im, strict=False), [])
LOGGER.info(f"{PREFIX}model graph visualization added ✅")
return
# Fallback to TorchScript export steps (RTDETR)
try:
model = deepcopy(de_parallel(trainer.model))
model.eval()
model = model.fuse(verbose=False)
for m in model.modules():
if hasattr(m, "export"): # Detect, RTDETRDecoder (Segment and Pose use Detect base class)
m.export = True
m.format = "torchscript"
model(im) # dry run
WRITER.add_graph(torch.jit.trace(model, im, strict=False), [])
LOGGER.info(f"{PREFIX}model graph visualization added ✅")
except Exception as e:
LOGGER.warning(f"{PREFIX}WARNING ⚠️ TensorBoard graph visualization failure {e}")
def on_pretrain_routine_start(trainer):
"""Initialize TensorBoard logging with SummaryWriter."""
if SummaryWriter:
try:
global WRITER
WRITER = SummaryWriter(str(trainer.save_dir))
LOGGER.info(f"{PREFIX}Start with 'tensorboard --logdir {trainer.save_dir}', view at http://localhost:6006/")
except Exception as e:
LOGGER.warning(f"{PREFIX}WARNING ⚠️ TensorBoard not initialized correctly, not logging this run. {e}")
def on_train_start(trainer):
"""Log TensorBoard graph."""
if WRITER:
_log_tensorboard_graph(trainer)
def on_train_epoch_end(trainer):
"""Logs scalar statistics at the end of a training epoch."""
_log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1)
_log_scalars(trainer.lr, trainer.epoch + 1)
def on_fit_epoch_end(trainer):
"""Logs epoch metrics at end of training epoch."""
_log_scalars(trainer.metrics, trainer.epoch + 1)
callbacks = (
{
"on_pretrain_routine_start": on_pretrain_routine_start,
"on_train_start": on_train_start,
"on_fit_epoch_end": on_fit_epoch_end,
"on_train_epoch_end": on_train_epoch_end,
}
if SummaryWriter
else {}
)
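
A self-contained sketch of the graph-tracing approach used in `_log_tensorboard_graph`, with a toy model standing in for `trainer.model` and a hypothetical log directory:

import torch
from torch.utils.tensorboard import SummaryWriter

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()  # toy stand-in
im = torch.zeros((1, 3, 64, 64))               # zeros, matching the dummy input built above
writer = SummaryWriter("runs/graph_example")   # hypothetical log directory
writer.add_graph(torch.jit.trace(model, im, strict=False), [])
writer.close()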

View File

@@ -0,0 +1,163 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from ultralytics.utils import SETTINGS, TESTS_RUNNING
from ultralytics.utils.torch_utils import model_info_for_loggers
try:
assert not TESTS_RUNNING # do not log pytest
assert SETTINGS["wandb"] is True # verify integration is enabled
import wandb as wb
assert hasattr(wb, "__version__") # verify package is not directory
import numpy as np
import pandas as pd
_processed_plots = {}
except (ImportError, AssertionError):
wb = None
def _custom_table(x, y, classes, title="Precision Recall Curve", x_title="Recall", y_title="Precision"):
"""
Create and log a custom metric visualization to wandb.plot.pr_curve.
This function crafts a custom metric visualization that mimics the behavior of wandb's default precision-recall
curve while allowing for enhanced customization. The visual metric is useful for monitoring model performance across
different classes.
Args:
x (List): Values for the x-axis; expected to have length N.
y (List): Corresponding values for the y-axis; also expected to have length N.
classes (List): Labels identifying the class of each point; length N.
title (str, optional): Title for the plot; defaults to 'Precision Recall Curve'.
x_title (str, optional): Label for the x-axis; defaults to 'Recall'.
y_title (str, optional): Label for the y-axis; defaults to 'Precision'.
Returns:
(wandb.Object): A wandb object suitable for logging, showcasing the crafted metric visualization.
"""
df = pd.DataFrame({"class": classes, "y": y, "x": x}).round(3)
fields = {"x": "x", "y": "y", "class": "class"}
string_fields = {"title": title, "x-axis-title": x_title, "y-axis-title": y_title}
return wb.plot_table(
"wandb/area-under-curve/v0", wb.Table(dataframe=df), fields=fields, string_fields=string_fields
)
def _plot_curve(
x,
y,
names=None,
id="precision-recall",
title="Precision Recall Curve",
x_title="Recall",
y_title="Precision",
num_x=100,
only_mean=False,
):
"""
Log a metric curve visualization.
This function generates a metric curve based on input data and logs the visualization to wandb.
The curve can represent aggregated data (mean) or individual class data, depending on the 'only_mean' flag.
Args:
x (np.ndarray): Data points for the x-axis with length N.
y (np.ndarray): Corresponding data points for the y-axis with shape CxN, where C is the number of classes.
names (list, optional): Names of the classes corresponding to the y-axis data; length C. Defaults to [].
id (str, optional): Unique identifier for the logged data in wandb. Defaults to 'precision-recall'.
title (str, optional): Title for the visualization plot. Defaults to 'Precision Recall Curve'.
x_title (str, optional): Label for the x-axis. Defaults to 'Recall'.
y_title (str, optional): Label for the y-axis. Defaults to 'Precision'.
num_x (int, optional): Number of interpolated data points for visualization. Defaults to 100.
only_mean (bool, optional): Flag to indicate if only the mean curve should be plotted. Defaults to False.
Note:
The function leverages the '_custom_table' function to generate the actual visualization.
"""
# Create new x
if names is None:
names = []
x_new = np.linspace(x[0], x[-1], num_x).round(5)
# Create arrays for logging
x_log = x_new.tolist()
y_log = np.interp(x_new, x, np.mean(y, axis=0)).round(3).tolist()
if only_mean:
table = wb.Table(data=list(zip(x_log, y_log)), columns=[x_title, y_title])
wb.run.log({title: wb.plot.line(table, x_title, y_title, title=title)})
else:
classes = ["mean"] * len(x_log)
for i, yi in enumerate(y):
x_log.extend(x_new) # add new x
y_log.extend(np.interp(x_new, x, yi)) # interpolate y to new x
classes.extend([names[i]] * len(x_new)) # add class names
wb.log({id: _custom_table(x_log, y_log, classes, title, x_title, y_title)}, commit=False)
def _log_plots(plots, step):
"""Logs plots from the input dictionary if they haven't been logged already at the specified step."""
for name, params in plots.items():
timestamp = params["timestamp"]
if _processed_plots.get(name) != timestamp:
wb.run.log({name.stem: wb.Image(str(name))}, step=step)
_processed_plots[name] = timestamp
def on_pretrain_routine_start(trainer):
"""Initiate and start project if module is present."""
wb.run or wb.init(project=trainer.args.project or "YOLOv8", name=trainer.args.name, config=vars(trainer.args))
def on_fit_epoch_end(trainer):
"""Logs training metrics and model information at the end of an epoch."""
wb.run.log(trainer.metrics, step=trainer.epoch + 1)
_log_plots(trainer.plots, step=trainer.epoch + 1)
_log_plots(trainer.validator.plots, step=trainer.epoch + 1)
if trainer.epoch == 0:
wb.run.log(model_info_for_loggers(trainer), step=trainer.epoch + 1)
def on_train_epoch_end(trainer):
"""Log metrics and save images at the end of each training epoch."""
wb.run.log(trainer.label_loss_items(trainer.tloss, prefix="train"), step=trainer.epoch + 1)
wb.run.log(trainer.lr, step=trainer.epoch + 1)
if trainer.epoch == 1:
_log_plots(trainer.plots, step=trainer.epoch + 1)
def on_train_end(trainer):
"""Save the best model as an artifact at end of training."""
_log_plots(trainer.validator.plots, step=trainer.epoch + 1)
_log_plots(trainer.plots, step=trainer.epoch + 1)
art = wb.Artifact(type="model", name=f"run_{wb.run.id}_model")
if trainer.best.exists():
art.add_file(trainer.best)
wb.run.log_artifact(art, aliases=["best"])
for curve_name, curve_values in zip(trainer.validator.metrics.curves, trainer.validator.metrics.curves_results):
x, y, x_title, y_title = curve_values
_plot_curve(
x,
y,
names=list(trainer.validator.metrics.names.values()),
id=f"curves/{curve_name}",
title=curve_name,
x_title=x_title,
y_title=y_title,
)
wb.run.finish() # required or run continues on dashboard
callbacks = (
{
"on_pretrain_routine_start": on_pretrain_routine_start,
"on_train_epoch_end": on_train_epoch_end,
"on_fit_epoch_end": on_fit_epoch_end,
"on_train_end": on_train_end,
}
if wb
else {}
)
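
To clarify the data layout that `_plot_curve` builds before handing off to `_custom_table`, a short worked sketch with two hypothetical classes:

import numpy as np

x = np.linspace(0.0, 1.0, 1000)          # e.g. recall values
y = np.stack([x ** 2, np.sqrt(x)])       # per-class precision curves, shape (C, N) with C=2
names = ["cat", "dog"]                   # hypothetical class names

x_new = np.linspace(x[0], x[-1], 100).round(5)
x_log = x_new.tolist()
y_log = np.interp(x_new, x, np.mean(y, axis=0)).round(3).tolist()
classes = ["mean"] * len(x_log)
for i, yi in enumerate(y):
    x_log.extend(x_new)                  # repeat the x grid for each class
    y_log.extend(np.interp(x_new, x, yi))
    classes.extend([names[i]] * len(x_new))
# x_log / y_log / classes are exactly the columns _custom_table turns into a wandb table plot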