initial project version!

This commit is contained in:
王庆刚
2024-05-20 20:01:06 +08:00
commit d6f3693d3f
483 changed files with 60345 additions and 0 deletions

View File

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

File diff suppressed because it is too large Load Diff

436
ultralytics/engine/model.py Normal file
View File

@ -0,0 +1,436 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import inspect
import sys
from pathlib import Path
from typing import Union
from ultralytics.cfg import TASK2DATA, get_cfg, get_save_dir
from ultralytics.hub.utils import HUB_WEB_ROOT
from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load
from ultralytics.utils import ASSETS, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, callbacks, emojis, yaml_load
from ultralytics.utils.checks import check_file, check_imgsz, check_pip_update_available, check_yaml
from ultralytics.utils.downloads import GITHUB_ASSETS_STEMS
from ultralytics.utils.torch_utils import smart_inference_mode
class Model:
"""
A base model class to unify apis for all the models.
Args:
model (str, Path): Path to the model file to load or create.
task (Any, optional): Task type for the YOLO model. Defaults to None.
Attributes:
predictor (Any): The predictor object.
model (Any): The model object.
trainer (Any): The trainer object.
task (str): The type of model task.
ckpt (Any): The checkpoint object if the model loaded from *.pt file.
cfg (str): The model configuration if loaded from *.yaml file.
ckpt_path (str): The checkpoint file path.
overrides (dict): Overrides for the trainer object.
metrics (Any): The data for metrics.
Methods:
__call__(source=None, stream=False, **kwargs):
Alias for the predict method.
_new(cfg:str, verbose:bool=True) -> None:
Initializes a new model and infers the task type from the model definitions.
_load(weights:str, task:str='') -> None:
Initializes a new model and infers the task type from the model head.
_check_is_pytorch_model() -> None:
Raises TypeError if the model is not a PyTorch model.
reset() -> None:
Resets the model modules.
info(verbose:bool=False) -> None:
Logs the model info.
fuse() -> None:
Fuses the model for faster inference.
predict(source=None, stream=False, **kwargs) -> List[ultralytics.engine.results.Results]:
Performs prediction using the YOLO model.
Returns:
list(ultralytics.engine.results.Results): The prediction results.
"""
def __init__(self, model: Union[str, Path] = 'yolov8n.pt', task=None) -> None:
"""
Initializes the YOLO model.
Args:
model (Union[str, Path], optional): Path or name of the model to load or create. Defaults to 'yolov8n.pt'.
task (Any, optional): Task type for the YOLO model. Defaults to None.
"""
self.callbacks = callbacks.get_default_callbacks()
self.predictor = None # reuse predictor
self.model = None # model object
self.trainer = None # trainer object
self.ckpt = None # if loaded from *.pt
self.cfg = None # if loaded from *.yaml
self.ckpt_path = None
self.overrides = {} # overrides for trainer object
self.metrics = None # validation/training metrics
self.session = None # HUB session
self.task = task # task type
model = str(model).strip() # strip spaces
# Check if Ultralytics HUB model from https://hub.ultralytics.com
if self.is_hub_model(model):
from ultralytics.hub.session import HUBTrainingSession
self.session = HUBTrainingSession(model)
model = self.session.model_file
# Load or create new YOLO model
suffix = Path(model).suffix
if not suffix and Path(model).stem in GITHUB_ASSETS_STEMS:
model, suffix = Path(model).with_suffix('.pt'), '.pt' # add suffix, i.e. yolov8n -> yolov8n.pt
if suffix in ('.yaml', '.yml'):
self._new(model, task)
else:
self._load(model, task)
def __call__(self, source=None, stream=False, **kwargs):
"""Calls the 'predict' function with given arguments to perform object detection."""
return self.predict(source, stream, **kwargs)
@staticmethod
def is_hub_model(model):
"""Check if the provided model is a HUB model."""
return any((
model.startswith(f'{HUB_WEB_ROOT}/models/'), # i.e. https://hub.ultralytics.com/models/MODEL_ID
[len(x) for x in model.split('_')] == [42, 20], # APIKEY_MODELID
len(model) == 20 and not Path(model).exists() and all(x not in model for x in './\\'))) # MODELID
def _new(self, cfg: str, task=None, model=None, verbose=True):
"""
Initializes a new model and infers the task type from the model definitions.
Args:
cfg (str): model configuration file
task (str | None): model task
model (BaseModel): Customized model.
verbose (bool): display model info on load
"""
cfg_dict = yaml_model_load(cfg)
self.cfg = cfg
self.task = task or guess_model_task(cfg_dict)
self.model = (model or self.smart_load('model'))(cfg_dict, verbose=verbose and RANK == -1) # build model
self.overrides['model'] = self.cfg
self.overrides['task'] = self.task
# Below added to allow export from YAMLs
args = {**DEFAULT_CFG_DICT, **self.overrides} # combine model and default args, preferring model args
self.model.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # attach args to model
self.model.task = self.task
def _load(self, weights: str, task=None):
"""
Initializes a new model and infers the task type from the model head.
Args:
weights (str): model checkpoint to be loaded
task (str | None): model task
"""
suffix = Path(weights).suffix
if suffix == '.pt':
self.model, self.ckpt = attempt_load_one_weight(weights)
self.task = self.model.args['task']
self.overrides = self.model.args = self._reset_ckpt_args(self.model.args)
self.ckpt_path = self.model.pt_path
else:
weights = check_file(weights)
self.model, self.ckpt = weights, None
self.task = task or guess_model_task(weights)
self.ckpt_path = weights
self.overrides['model'] = weights
self.overrides['task'] = self.task
def _check_is_pytorch_model(self):
"""
Raises TypeError is model is not a PyTorch model
"""
pt_str = isinstance(self.model, (str, Path)) and Path(self.model).suffix == '.pt'
pt_module = isinstance(self.model, nn.Module)
if not (pt_module or pt_str):
raise TypeError(f"model='{self.model}' must be a *.pt PyTorch model, but is a different type. "
f'PyTorch models can be used to train, val, predict and export, i.e. '
f"'yolo export model=yolov8n.pt', but exported formats like ONNX, TensorRT etc. only "
f"support 'predict' and 'val' modes, i.e. 'yolo predict model=yolov8n.onnx'.")
@smart_inference_mode()
def reset_weights(self):
"""
Resets the model modules parameters to randomly initialized values, losing all training information.
"""
self._check_is_pytorch_model()
for m in self.model.modules():
if hasattr(m, 'reset_parameters'):
m.reset_parameters()
for p in self.model.parameters():
p.requires_grad = True
return self
@smart_inference_mode()
def load(self, weights='yolov8n.pt'):
"""
Transfers parameters with matching names and shapes from 'weights' to model.
"""
self._check_is_pytorch_model()
if isinstance(weights, (str, Path)):
weights, self.ckpt = attempt_load_one_weight(weights)
self.model.load(weights)
return self
def info(self, detailed=False, verbose=True):
"""
Logs model info.
Args:
detailed (bool): Show detailed information about model.
verbose (bool): Controls verbosity.
"""
self._check_is_pytorch_model()
return self.model.info(detailed=detailed, verbose=verbose)
def fuse(self):
"""Fuse PyTorch Conv2d and BatchNorm2d layers."""
self._check_is_pytorch_model()
self.model.fuse()
@smart_inference_mode()
def predict(self, source=None, stream=False, predictor=None, **kwargs):
"""
Perform prediction using the YOLO model.
Args:
source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
Accepts all source types accepted by the YOLO model.
stream (bool): Whether to stream the predictions or not. Defaults to False.
predictor (BasePredictor): Customized predictor.
**kwargs : Additional keyword arguments passed to the predictor.
Check the 'configuration' section in the documentation for all available options.
Returns:
(List[ultralytics.engine.results.Results]): The prediction results.
"""
if source is None:
source = ASSETS
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
is_cli = (sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics')) and any(
x in sys.argv for x in ('predict', 'track', 'mode=predict', 'mode=track'))
custom = {'conf': 0.25, 'save': is_cli} # method defaults
args = {**self.overrides, **custom, **kwargs, 'mode': 'predict'} # highest priority args on the right
prompts = args.pop('prompts', None) # for SAM-type models
if not self.predictor:
self.predictor = (predictor or self.smart_load('predictor'))(overrides=args, _callbacks=self.callbacks)
self.predictor.setup_model(model=self.model, verbose=is_cli)
else: # only update args if predictor is already setup
self.predictor.args = get_cfg(self.predictor.args, args)
if 'project' in args or 'name' in args:
self.predictor.save_dir = get_save_dir(self.predictor.args)
if prompts and hasattr(self.predictor, 'set_prompts'): # for SAM-type models
self.predictor.set_prompts(prompts)
return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream)
def track(self, source=None, stream=False, persist=False, **kwargs):
"""
Perform object tracking on the input source using the registered trackers.
Args:
source (str, optional): The input source for object tracking. Can be a file path or a video stream.
stream (bool, optional): Whether the input source is a video stream. Defaults to False.
persist (bool, optional): Whether to persist the trackers if they already exist. Defaults to False.
**kwargs (optional): Additional keyword arguments for the tracking process.
Returns:
(List[ultralytics.engine.results.Results]): The tracking results.
"""
if not hasattr(self.predictor, 'trackers'):
from ultralytics.trackers import register_tracker
register_tracker(self, persist)
# ByteTrack-based method needs low confidence predictions as input
kwargs['conf'] = kwargs.get('conf') or 0.1
kwargs['mode'] = 'track'
return self.predict(source=source, stream=stream, **kwargs)
@smart_inference_mode()
def val(self, validator=None, **kwargs):
"""
Validate a model on a given dataset.
Args:
validator (BaseValidator): Customized validator.
**kwargs : Any other args accepted by the validators. To see all args check 'configuration' section in docs
"""
custom = {'rect': True} # method defaults
args = {**self.overrides, **custom, **kwargs, 'mode': 'val'} # highest priority args on the right
args['imgsz'] = check_imgsz(args['imgsz'], max_dim=1)
validator = (validator or self.smart_load('validator'))(args=args, _callbacks=self.callbacks)
validator(model=self.model)
self.metrics = validator.metrics
return validator.metrics
@smart_inference_mode()
def benchmark(self, **kwargs):
"""
Benchmark a model on all export formats.
Args:
**kwargs : Any other args accepted by the validators. To see all args check 'configuration' section in docs
"""
self._check_is_pytorch_model()
from ultralytics.utils.benchmarks import benchmark
custom = {'verbose': False} # method defaults
args = {**DEFAULT_CFG_DICT, **self.model.args, **custom, **kwargs, 'mode': 'benchmark'}
return benchmark(
model=self,
data=kwargs.get('data'), # if no 'data' argument passed set data=None for default datasets
imgsz=args['imgsz'],
half=args['half'],
int8=args['int8'],
device=args['device'],
verbose=kwargs.get('verbose'))
def export(self, **kwargs):
"""
Export model.
Args:
**kwargs : Any other args accepted by the Exporter. To see all args check 'configuration' section in docs.
"""
self._check_is_pytorch_model()
from .exporter import Exporter
custom = {'imgsz': self.model.args['imgsz'], 'batch': 1, 'data': None, 'verbose': False} # method defaults
args = {**self.overrides, **custom, **kwargs, 'mode': 'export'} # highest priority args on the right
return Exporter(overrides=args, _callbacks=self.callbacks)(model=self.model)
def train(self, trainer=None, **kwargs):
"""
Trains the model on a given dataset.
Args:
trainer (BaseTrainer, optional): Customized trainer.
**kwargs (Any): Any number of arguments representing the training configuration.
"""
self._check_is_pytorch_model()
if self.session: # Ultralytics HUB session
if any(kwargs):
LOGGER.warning('WARNING ⚠️ using HUB training arguments, ignoring local training arguments.')
kwargs = self.session.train_args
check_pip_update_available()
overrides = yaml_load(check_yaml(kwargs['cfg'])) if kwargs.get('cfg') else self.overrides
custom = {'data': TASK2DATA[self.task]} # method defaults
args = {**overrides, **custom, **kwargs, 'mode': 'train'} # highest priority args on the right
if args.get('resume'):
args['resume'] = self.ckpt_path
self.trainer = (trainer or self.smart_load('trainer'))(overrides=args, _callbacks=self.callbacks)
if not args.get('resume'): # manually set model only if not resuming
self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
self.model = self.trainer.model
self.trainer.hub_session = self.session # attach optional HUB session
self.trainer.train()
# Update model and cfg after training
if RANK in (-1, 0):
ckpt = self.trainer.best if self.trainer.best.exists() else self.trainer.last
self.model, _ = attempt_load_one_weight(ckpt)
self.overrides = self.model.args
self.metrics = getattr(self.trainer.validator, 'metrics', None) # TODO: no metrics returned by DDP
return self.metrics
def tune(self, use_ray=False, iterations=10, *args, **kwargs):
"""
Runs hyperparameter tuning, optionally using Ray Tune. See ultralytics.utils.tuner.run_ray_tune for Args.
Returns:
(dict): A dictionary containing the results of the hyperparameter search.
"""
self._check_is_pytorch_model()
if use_ray:
from ultralytics.utils.tuner import run_ray_tune
return run_ray_tune(self, max_samples=iterations, *args, **kwargs)
else:
from .tuner import Tuner
custom = {'plots': False, 'save': False} # method defaults
args = {**self.overrides, **custom, **kwargs, 'mode': 'train'} # highest priority args on the right
return Tuner(args=args, _callbacks=self.callbacks)(model=self, iterations=iterations)
def to(self, device):
"""
Sends the model to the given device.
Args:
device (str): device
"""
self._check_is_pytorch_model()
self.model.to(device)
return self
@property
def names(self):
"""Returns class names of the loaded model."""
return self.model.names if hasattr(self.model, 'names') else None
@property
def device(self):
"""Returns device if PyTorch model."""
return next(self.model.parameters()).device if isinstance(self.model, nn.Module) else None
@property
def transforms(self):
"""Returns transform of the loaded model."""
return self.model.transforms if hasattr(self.model, 'transforms') else None
def add_callback(self, event: str, func):
"""Add a callback."""
self.callbacks[event].append(func)
def clear_callback(self, event: str):
"""Clear all event callbacks."""
self.callbacks[event] = []
@staticmethod
def _reset_ckpt_args(args):
"""Reset arguments when loading a PyTorch model."""
include = {'imgsz', 'data', 'task', 'single_cls'} # only remember these arguments when loading a PyTorch model
return {k: v for k, v in args.items() if k in include}
def _reset_callbacks(self):
"""Reset all registered callbacks."""
for event in callbacks.default_callbacks.keys():
self.callbacks[event] = [callbacks.default_callbacks[event][0]]
def __getattr__(self, attr):
"""Raises error if object has no requested attribute."""
name = self.__class__.__name__
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
def smart_load(self, key):
"""Load model/trainer/validator/predictor."""
try:
return self.task_map[self.task][key]
except Exception as e:
name = self.__class__.__name__
mode = inspect.stack()[1][3] # get the function name.
raise NotImplementedError(
emojis(f"WARNING ⚠️ '{name}' model does not support '{mode}' mode for '{self.task}' task yet.")) from e
@property
def task_map(self):
"""
Map head to model, trainer, validator, and predictor classes.
Returns:
task_map (dict): The map of model task to mode classes.
"""
raise NotImplementedError('Please provide task map for your model!')

View File

@ -0,0 +1,358 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
Run prediction on images, videos, directories, globs, YouTube, webcam, streams, etc.
Usage - sources:
$ yolo mode=predict model=yolov8n.pt source=0 # webcam
img.jpg # image
vid.mp4 # video
screen # screenshot
path/ # directory
list.txt # list of images
list.streams # list of streams
'path/*.jpg' # glob
'https://youtu.be/Zgi9g1ksQHc' # YouTube
'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP stream
Usage - formats:
$ yolo mode=predict model=yolov8n.pt # PyTorch
yolov8n.torchscript # TorchScript
yolov8n.onnx # ONNX Runtime or OpenCV DNN with dnn=True
yolov8n_openvino_model # OpenVINO
yolov8n.engine # TensorRT
yolov8n.mlpackage # CoreML (macOS-only)
yolov8n_saved_model # TensorFlow SavedModel
yolov8n.pb # TensorFlow GraphDef
yolov8n.tflite # TensorFlow Lite
yolov8n_edgetpu.tflite # TensorFlow Edge TPU
yolov8n_paddle_model # PaddlePaddle
"""
import platform
from pathlib import Path
import cv2
import numpy as np
import torch
from ultralytics.cfg import get_cfg, get_save_dir
from ultralytics.data import load_inference_source
from ultralytics.data.augment import LetterBox, classify_transforms
from ultralytics.nn.autobackend import AutoBackend
from ultralytics.utils import DEFAULT_CFG, LOGGER, MACOS, WINDOWS, callbacks, colorstr, ops
from ultralytics.utils.checks import check_imgsz, check_imshow
from ultralytics.utils.files import increment_path
from ultralytics.utils.torch_utils import select_device, smart_inference_mode
STREAM_WARNING = """
WARNING ⚠️ inference results will accumulate in RAM unless `stream=True` is passed, causing potential out-of-memory
errors for large sources or long-running streams and videos. See https://docs.ultralytics.com/modes/predict/ for help.
Example:
results = model(source=..., stream=True) # generator of Results objects
for r in results:
boxes = r.boxes # Boxes object for bbox outputs
masks = r.masks # Masks object for segment masks outputs
probs = r.probs # Class probabilities for classification outputs
"""
class BasePredictor:
"""
BasePredictor
A base class for creating predictors.
Attributes:
args (SimpleNamespace): Configuration for the predictor.
save_dir (Path): Directory to save results.
done_warmup (bool): Whether the predictor has finished setup.
model (nn.Module): Model used for prediction.
data (dict): Data configuration.
device (torch.device): Device used for prediction.
dataset (Dataset): Dataset used for prediction.
vid_path (str): Path to video file.
vid_writer (cv2.VideoWriter): Video writer for saving video output.
data_path (str): Path to data.
"""
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""
Initializes the BasePredictor class.
Args:
cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
overrides (dict, optional): Configuration overrides. Defaults to None.
"""
self.args = get_cfg(cfg, overrides)
self.save_dir = get_save_dir(self.args)
if self.args.conf is None:
self.args.conf = 0.25 # default conf=0.25
self.done_warmup = False
if self.args.show:
self.args.show = check_imshow(warn=True)
# Usable if setup is done
self.model = None
self.data = self.args.data # data_dict
self.imgsz = None
self.device = None
self.dataset = None
self.vid_path, self.vid_writer = None, None
self.plotted_img = None
self.data_path = None
self.source_type = None
self.batch = None
self.results = None
self.transforms = None
self.callbacks = _callbacks or callbacks.get_default_callbacks()
self.txt_path = None
callbacks.add_integration_callbacks(self)
def preprocess(self, im):
"""Prepares input image before inference.
Args:
im (torch.Tensor | List(np.ndarray)): BCHW for tensor, [(HWC) x B] for list.
"""
not_tensor = not isinstance(im, torch.Tensor)
if not_tensor:
im = np.stack(self.pre_transform(im))
im = im[..., ::-1].transpose((0, 3, 1, 2)) # BGR to RGB, BHWC to BCHW, (n, 3, h, w)
im = np.ascontiguousarray(im) # contiguous
im = torch.from_numpy(im)
im = im.to(self.device)
im = im.half() if self.model.fp16 else im.float() # uint8 to fp16/32
if not_tensor:
im /= 255 # 0 - 255 to 0.0 - 1.0
return im
def inference(self, im, *args, **kwargs):
visualize = increment_path(self.save_dir / Path(self.batch[0][0]).stem,
mkdir=True) if self.args.visualize and (not self.source_type.tensor) else False
return self.model(im, augment=self.args.augment, visualize=visualize)
def pre_transform(self, im):
"""
Pre-transform input image before inference.
Args:
im (List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.
Returns:
(list): A list of transformed images.
"""
same_shapes = all(x.shape == im[0].shape for x in im)
letterbox = LetterBox(self.imgsz, auto=same_shapes and self.model.pt, stride=self.model.stride)
return [letterbox(image=x) for x in im]
def write_results(self, idx, results, batch):
"""Write inference results to a file or directory."""
p, im, _ = batch
log_string = ''
if len(im.shape) == 3:
im = im[None] # expand for batch dim
if self.source_type.webcam or self.source_type.from_img or self.source_type.tensor: # batch_size >= 1
log_string += f'{idx}: '
frame = self.dataset.count
else:
frame = getattr(self.dataset, 'frame', 0)
self.data_path = p
self.txt_path = str(self.save_dir / 'labels' / p.stem) + ('' if self.dataset.mode == 'image' else f'_{frame}')
log_string += '%gx%g ' % im.shape[2:] # print string
result = results[idx]
log_string += result.verbose()
if self.args.save or self.args.show: # Add bbox to image
plot_args = {
'line_width': self.args.line_width,
'boxes': self.args.boxes,
'conf': self.args.show_conf,
'labels': self.args.show_labels}
if not self.args.retina_masks:
plot_args['im_gpu'] = im[idx]
self.plotted_img = result.plot(**plot_args)
# Write
if self.args.save_txt:
result.save_txt(f'{self.txt_path}.txt', save_conf=self.args.save_conf)
if self.args.save_crop:
result.save_crop(save_dir=self.save_dir / 'crops',
file_name=self.data_path.stem + ('' if self.dataset.mode == 'image' else f'_{frame}'))
return log_string
def postprocess(self, preds, img, orig_imgs):
"""Post-processes predictions for an image and returns them."""
return preds
def __call__(self, source=None, model=None, stream=False, *args, **kwargs):
"""Performs inference on an image or stream."""
self.stream = stream
if stream:
return self.stream_inference(source, model, *args, **kwargs)
else:
return list(self.stream_inference(source, model, *args, **kwargs)) # merge list of Result into one
def predict_cli(self, source=None, model=None):
"""Method used for CLI prediction. It uses always generator as outputs as not required by CLI mode."""
gen = self.stream_inference(source, model)
for _ in gen: # running CLI inference without accumulating any outputs (do not modify)
pass
def setup_source(self, source):
"""Sets up source and inference mode."""
self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2) # check image size
self.transforms = getattr(self.model.model, 'transforms', classify_transforms(
self.imgsz[0])) if self.args.task == 'classify' else None
self.dataset = load_inference_source(source=source,
imgsz=self.imgsz,
vid_stride=self.args.vid_stride,
stream_buffer=self.args.stream_buffer)
self.source_type = self.dataset.source_type
if not getattr(self, 'stream', True) and (self.dataset.mode == 'stream' or # streams
len(self.dataset) > 1000 or # images
any(getattr(self.dataset, 'video_flag', [False]))): # videos
LOGGER.warning(STREAM_WARNING)
self.vid_path, self.vid_writer = [None] * self.dataset.bs, [None] * self.dataset.bs
@smart_inference_mode()
def stream_inference(self, source=None, model=None, *args, **kwargs):
"""Streams real-time inference on camera feed and saves results to file."""
if self.args.verbose:
LOGGER.info('')
# Setup model
if not self.model:
self.setup_model(model)
# Setup source every time predict is called
self.setup_source(source if source is not None else self.args.source)
# Check if save_dir/ label file exists
if self.args.save or self.args.save_txt:
(self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
# Warmup model
if not self.done_warmup:
self.model.warmup(imgsz=(1 if self.model.pt or self.model.triton else self.dataset.bs, 3, *self.imgsz))
self.done_warmup = True
self.seen, self.windows, self.batch, profilers = 0, [], None, (ops.Profile(), ops.Profile(), ops.Profile())
self.run_callbacks('on_predict_start')
for batch in self.dataset:
self.run_callbacks('on_predict_batch_start')
self.batch = batch
path, im0s, vid_cap, s = batch
# Preprocess
with profilers[0]:
im = self.preprocess(im0s)
# Inference
with profilers[1]:
preds = self.inference(im, *args, **kwargs)
# Postprocess
with profilers[2]:
self.results = self.postprocess(preds, im, im0s)
self.run_callbacks('on_predict_postprocess_end')
# Visualize, save, write results
n = len(im0s)
for i in range(n):
self.seen += 1
self.results[i].speed = {
'preprocess': profilers[0].dt * 1E3 / n,
'inference': profilers[1].dt * 1E3 / n,
'postprocess': profilers[2].dt * 1E3 / n}
p, im0 = path[i], None if self.source_type.tensor else im0s[i].copy()
p = Path(p)
if self.args.verbose or self.args.save or self.args.save_txt or self.args.show:
s += self.write_results(i, self.results, (p, im, im0))
if self.args.save or self.args.save_txt:
self.results[i].save_dir = self.save_dir.__str__()
if self.args.show and self.plotted_img is not None:
self.show(p)
if self.args.save and self.plotted_img is not None:
self.save_preds(vid_cap, i, str(self.save_dir / p.name))
self.run_callbacks('on_predict_batch_end')
yield from self.results
# Print time (inference-only)
if self.args.verbose:
LOGGER.info(f'{s}{profilers[1].dt * 1E3:.1f}ms')
# Release assets
if isinstance(self.vid_writer[-1], cv2.VideoWriter):
self.vid_writer[-1].release() # release final video writer
# Print results
if self.args.verbose and self.seen:
t = tuple(x.t / self.seen * 1E3 for x in profilers) # speeds per image
LOGGER.info(f'Speed: %.1fms preprocess, %.1fms inference, %.1fms postprocess per image at shape '
f'{(1, 3, *im.shape[2:])}' % t)
if self.args.save or self.args.save_txt or self.args.save_crop:
nl = len(list(self.save_dir.glob('labels/*.txt'))) # number of labels
s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else ''
LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}")
self.run_callbacks('on_predict_end')
def setup_model(self, model, verbose=True):
"""Initialize YOLO model with given parameters and set it to evaluation mode."""
self.model = AutoBackend(model or self.args.model,
device=select_device(self.args.device, verbose=verbose),
dnn=self.args.dnn,
data=self.args.data,
fp16=self.args.half,
fuse=True,
verbose=verbose)
self.device = self.model.device # update device
self.args.half = self.model.fp16 # update half
self.model.eval()
def show(self, p):
"""Display an image in a window using OpenCV imshow()."""
im0 = self.plotted_img
if platform.system() == 'Linux' and p not in self.windows:
self.windows.append(p)
cv2.namedWindow(str(p), cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO) # allow window resize (Linux)
cv2.resizeWindow(str(p), im0.shape[1], im0.shape[0])
cv2.imshow(str(p), im0)
cv2.waitKey(500 if self.batch[3].startswith('image') else 1) # 1 millisecond
def save_preds(self, vid_cap, idx, save_path):
"""Save video predictions as mp4 at specified path."""
im0 = self.plotted_img
# Save imgs
if self.dataset.mode == 'image':
cv2.imwrite(save_path, im0)
else: # 'video' or 'stream'
if self.vid_path[idx] != save_path: # new video
self.vid_path[idx] = save_path
if isinstance(self.vid_writer[idx], cv2.VideoWriter):
self.vid_writer[idx].release() # release previous video writer
if vid_cap: # video
fps = int(vid_cap.get(cv2.CAP_PROP_FPS)) # integer required, floats produce error in MP4 codec
w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
else: # stream
fps, w, h = 30, im0.shape[1], im0.shape[0]
suffix, fourcc = ('.mp4', 'avc1') if MACOS else ('.avi', 'WMV2') if WINDOWS else ('.avi', 'MJPG')
save_path = str(Path(save_path).with_suffix(suffix))
self.vid_writer[idx] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*fourcc), fps, (w, h))
self.vid_writer[idx].write(im0)
def run_callbacks(self, event: str):
"""Runs all registered callbacks for a specific event."""
for callback in self.callbacks.get(event, []):
callback(self)
def add_callback(self, event: str, func):
"""
Add callback
"""
self.callbacks[event].append(func)

View File

@ -0,0 +1,593 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
Ultralytics Results, Boxes and Masks classes for handling inference results
Usage: See https://docs.ultralytics.com/modes/predict/
"""
from copy import deepcopy
from functools import lru_cache
from pathlib import Path
import numpy as np
import torch
from ultralytics.data.augment import LetterBox
from ultralytics.utils import LOGGER, SimpleClass, deprecation_warn, ops
from ultralytics.utils.plotting import Annotator, colors, save_one_box
class BaseTensor(SimpleClass):
"""
Base tensor class with additional methods for easy manipulation and device handling.
"""
def __init__(self, data, orig_shape) -> None:
"""Initialize BaseTensor with data and original shape.
Args:
data (torch.Tensor | np.ndarray): Predictions, such as bboxes, masks and keypoints.
orig_shape (tuple): Original shape of image.
"""
assert isinstance(data, (torch.Tensor, np.ndarray))
self.data = data
self.orig_shape = orig_shape
@property
def shape(self):
"""Return the shape of the data tensor."""
return self.data.shape
def cpu(self):
"""Return a copy of the tensor on CPU memory."""
return self if isinstance(self.data, np.ndarray) else self.__class__(self.data.cpu(), self.orig_shape)
def numpy(self):
"""Return a copy of the tensor as a numpy array."""
return self if isinstance(self.data, np.ndarray) else self.__class__(self.data.numpy(), self.orig_shape)
def cuda(self):
"""Return a copy of the tensor on GPU memory."""
return self.__class__(torch.as_tensor(self.data).cuda(), self.orig_shape)
def to(self, *args, **kwargs):
"""Return a copy of the tensor with the specified device and dtype."""
return self.__class__(torch.as_tensor(self.data).to(*args, **kwargs), self.orig_shape)
def __len__(self): # override len(results)
"""Return the length of the data tensor."""
return len(self.data)
def __getitem__(self, idx):
"""Return a BaseTensor with the specified index of the data tensor."""
return self.__class__(self.data[idx], self.orig_shape)
class Results(SimpleClass):
"""
A class for storing and manipulating inference results.
Args:
orig_img (numpy.ndarray): The original image as a numpy array.
path (str): The path to the image file.
names (dict): A dictionary of class names.
boxes (torch.tensor, optional): A 2D tensor of bounding box coordinates for each detection.
masks (torch.tensor, optional): A 3D tensor of detection masks, where each mask is a binary image.
probs (torch.tensor, optional): A 1D tensor of probabilities of each class for classification task.
keypoints (List[List[float]], optional): A list of detected keypoints for each object.
Attributes:
orig_img (numpy.ndarray): The original image as a numpy array.
orig_shape (tuple): The original image shape in (height, width) format.
boxes (Boxes, optional): A Boxes object containing the detection bounding boxes.
masks (Masks, optional): A Masks object containing the detection masks.
probs (Probs, optional): A Probs object containing probabilities of each class for classification task.
keypoints (Keypoints, optional): A Keypoints object containing detected keypoints for each object.
speed (dict): A dictionary of preprocess, inference, and postprocess speeds in milliseconds per image.
names (dict): A dictionary of class names.
path (str): The path to the image file.
_keys (tuple): A tuple of attribute names for non-empty attributes.
"""
def __init__(self, orig_img, path, names, boxes=None, masks=None, probs=None, keypoints=None) -> None:
"""Initialize the Results class."""
self.orig_img = orig_img
self.orig_shape = orig_img.shape[:2]
self.boxes = Boxes(boxes, self.orig_shape) if boxes is not None else None # native size boxes
self.masks = Masks(masks, self.orig_shape) if masks is not None else None # native size or imgsz masks
self.probs = Probs(probs) if probs is not None else None
self.keypoints = Keypoints(keypoints, self.orig_shape) if keypoints is not None else None
self.speed = {'preprocess': None, 'inference': None, 'postprocess': None} # milliseconds per image
self.names = names
self.path = path
self.save_dir = None
self._keys = 'boxes', 'masks', 'probs', 'keypoints'
def __getitem__(self, idx):
"""Return a Results object for the specified index."""
return self._apply('__getitem__', idx)
def __len__(self):
"""Return the number of detections in the Results object."""
for k in self._keys:
v = getattr(self, k)
if v is not None:
return len(v)
def update(self, boxes=None, masks=None, probs=None):
"""Update the boxes, masks, and probs attributes of the Results object."""
if boxes is not None:
ops.clip_boxes(boxes, self.orig_shape) # clip boxes
self.boxes = Boxes(boxes, self.orig_shape)
if masks is not None:
self.masks = Masks(masks, self.orig_shape)
if probs is not None:
self.probs = probs
def _apply(self, fn, *args, **kwargs):
r = self.new()
for k in self._keys:
v = getattr(self, k)
if v is not None:
setattr(r, k, getattr(v, fn)(*args, **kwargs))
return r
def cpu(self):
"""Return a copy of the Results object with all tensors on CPU memory."""
return self._apply('cpu')
def numpy(self):
"""Return a copy of the Results object with all tensors as numpy arrays."""
return self._apply('numpy')
def cuda(self):
"""Return a copy of the Results object with all tensors on GPU memory."""
return self._apply('cuda')
def to(self, *args, **kwargs):
"""Return a copy of the Results object with tensors on the specified device and dtype."""
return self._apply('to', *args, **kwargs)
def new(self):
"""Return a new Results object with the same image, path, and names."""
return Results(orig_img=self.orig_img, path=self.path, names=self.names)
def plot(
self,
conf=True,
line_width=None,
font_size=None,
font='Arial.ttf',
pil=False,
img=None,
im_gpu=None,
kpt_radius=5,
kpt_line=True,
labels=True,
boxes=True,
masks=True,
probs=True,
**kwargs # deprecated args TODO: remove support in 8.2
):
"""
Plots the detection results on an input RGB image. Accepts a numpy array (cv2) or a PIL Image.
Args:
conf (bool): Whether to plot the detection confidence score.
line_width (float, optional): The line width of the bounding boxes. If None, it is scaled to the image size.
font_size (float, optional): The font size of the text. If None, it is scaled to the image size.
font (str): The font to use for the text.
pil (bool): Whether to return the image as a PIL Image.
img (numpy.ndarray): Plot to another image. if not, plot to original image.
im_gpu (torch.Tensor): Normalized image in gpu with shape (1, 3, 640, 640), for faster mask plotting.
kpt_radius (int, optional): Radius of the drawn keypoints. Default is 5.
kpt_line (bool): Whether to draw lines connecting keypoints.
labels (bool): Whether to plot the label of bounding boxes.
boxes (bool): Whether to plot the bounding boxes.
masks (bool): Whether to plot the masks.
probs (bool): Whether to plot classification probability
Returns:
(numpy.ndarray): A numpy array of the annotated image.
Example:
```python
from PIL import Image
from ultralytics import YOLO
model = YOLO('yolov8n.pt')
results = model('bus.jpg') # results list
for r in results:
im_array = r.plot() # plot a BGR numpy array of predictions
im = Image.fromarray(im_array[..., ::-1]) # RGB PIL image
im.show() # show image
im.save('results.jpg') # save image
```
"""
if img is None and isinstance(self.orig_img, torch.Tensor):
img = (self.orig_img[0].detach().permute(1, 2, 0).contiguous() * 255).to(torch.uint8).cpu().numpy()
# Deprecation warn TODO: remove in 8.2
if 'show_conf' in kwargs:
deprecation_warn('show_conf', 'conf')
conf = kwargs['show_conf']
assert isinstance(conf, bool), '`show_conf` should be of boolean type, i.e, show_conf=True/False'
if 'line_thickness' in kwargs:
deprecation_warn('line_thickness', 'line_width')
line_width = kwargs['line_thickness']
assert isinstance(line_width, int), '`line_width` should be of int type, i.e, line_width=3'
names = self.names
pred_boxes, show_boxes = self.boxes, boxes
pred_masks, show_masks = self.masks, masks
pred_probs, show_probs = self.probs, probs
annotator = Annotator(
deepcopy(self.orig_img if img is None else img),
line_width,
font_size,
font,
pil or (pred_probs is not None and show_probs), # Classify tasks default to pil=True
example=names)
# Plot Segment results
if pred_masks and show_masks:
if im_gpu is None:
img = LetterBox(pred_masks.shape[1:])(image=annotator.result())
im_gpu = torch.as_tensor(img, dtype=torch.float16, device=pred_masks.data.device).permute(
2, 0, 1).flip(0).contiguous() / 255
idx = pred_boxes.cls if pred_boxes else range(len(pred_masks))
annotator.masks(pred_masks.data, colors=[colors(x, True) for x in idx], im_gpu=im_gpu)
# Plot Detect results
if pred_boxes and show_boxes:
for d in reversed(pred_boxes):
c, conf, id = int(d.cls), float(d.conf) if conf else None, None if d.id is None else int(d.id.item())
name = ('' if id is None else f'id:{id} ') + names[c]
label = (f'{name} {conf:.2f}' if conf else name) if labels else None
annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True))
# Plot Classify results
if pred_probs is not None and show_probs:
text = ',\n'.join(f'{names[j] if names else j} {pred_probs.data[j]:.2f}' for j in pred_probs.top5)
x = round(self.orig_shape[0] * 0.03)
annotator.text([x, x], text, txt_color=(255, 255, 255)) # TODO: allow setting colors
# Plot Pose results
if self.keypoints is not None:
for k in reversed(self.keypoints.data):
annotator.kpts(k, self.orig_shape, radius=kpt_radius, kpt_line=kpt_line)
return annotator.result()
def verbose(self):
"""
Return log string for each task.
"""
log_string = ''
probs = self.probs
boxes = self.boxes
if len(self) == 0:
return log_string if probs is not None else f'{log_string}(no detections), '
if probs is not None:
log_string += f"{', '.join(f'{self.names[j]} {probs.data[j]:.2f}' for j in probs.top5)}, "
if boxes:
for c in boxes.cls.unique():
n = (boxes.cls == c).sum() # detections per class
log_string += f"{n} {self.names[int(c)]}{'s' * (n > 1)}, "
return log_string
def save_txt(self, txt_file, save_conf=False):
"""
Save predictions into txt file.
Args:
txt_file (str): txt file path.
save_conf (bool): save confidence score or not.
"""
boxes = self.boxes
masks = self.masks
probs = self.probs
kpts = self.keypoints
texts = []
if probs is not None:
# Classify
[texts.append(f'{probs.data[j]:.2f} {self.names[j]}') for j in probs.top5]
elif boxes:
# Detect/segment/pose
for j, d in enumerate(boxes):
c, conf, id = int(d.cls), float(d.conf), None if d.id is None else int(d.id.item())
line = (c, *d.xywhn.view(-1))
if masks:
seg = masks[j].xyn[0].copy().reshape(-1) # reversed mask.xyn, (n,2) to (n*2)
line = (c, *seg)
if kpts is not None:
kpt = torch.cat((kpts[j].xyn, kpts[j].conf[..., None]), 2) if kpts[j].has_visible else kpts[j].xyn
line += (*kpt.reshape(-1).tolist(), )
line += (conf, ) * save_conf + (() if id is None else (id, ))
texts.append(('%g ' * len(line)).rstrip() % line)
if texts:
Path(txt_file).parent.mkdir(parents=True, exist_ok=True) # make directory
with open(txt_file, 'a') as f:
f.writelines(text + '\n' for text in texts)
def save_crop(self, save_dir, file_name=Path('im.jpg')):
"""
Save cropped predictions to `save_dir/cls/file_name.jpg`.
Args:
save_dir (str | pathlib.Path): Save path.
file_name (str | pathlib.Path): File name.
"""
if self.probs is not None:
LOGGER.warning('WARNING ⚠️ Classify task do not support `save_crop`.')
return
for d in self.boxes:
save_one_box(d.xyxy,
self.orig_img.copy(),
file=Path(save_dir) / self.names[int(d.cls)] / f'{Path(file_name).stem}.jpg',
BGR=True)
def tojson(self, normalize=False):
"""Convert the object to JSON format."""
if self.probs is not None:
LOGGER.warning('Warning: Classify task do not support `tojson` yet.')
return
import json
# Create list of detection dictionaries
results = []
data = self.boxes.data.cpu().tolist()
h, w = self.orig_shape if normalize else (1, 1)
for i, row in enumerate(data): # xyxy, track_id if tracking, conf, class_id
box = {'x1': row[0] / w, 'y1': row[1] / h, 'x2': row[2] / w, 'y2': row[3] / h}
conf = row[-2]
class_id = int(row[-1])
name = self.names[class_id]
result = {'name': name, 'class': class_id, 'confidence': conf, 'box': box}
if self.boxes.is_track:
result['track_id'] = int(row[-3]) # track ID
if self.masks:
x, y = self.masks.xy[i][:, 0], self.masks.xy[i][:, 1] # numpy array
result['segments'] = {'x': (x / w).tolist(), 'y': (y / h).tolist()}
if self.keypoints is not None:
x, y, visible = self.keypoints[i].data[0].cpu().unbind(dim=1) # torch Tensor
result['keypoints'] = {'x': (x / w).tolist(), 'y': (y / h).tolist(), 'visible': visible.tolist()}
results.append(result)
# Convert detections to JSON
return json.dumps(results, indent=2)
class Boxes(BaseTensor):
"""
A class for storing and manipulating detection boxes.
Args:
boxes (torch.Tensor | numpy.ndarray): A tensor or numpy array containing the detection boxes,
with shape (num_boxes, 6) or (num_boxes, 7). The last two columns contain confidence and class values.
If present, the third last column contains track IDs.
orig_shape (tuple): Original image size, in the format (height, width).
Attributes:
xyxy (torch.Tensor | numpy.ndarray): The boxes in xyxy format.
conf (torch.Tensor | numpy.ndarray): The confidence values of the boxes.
cls (torch.Tensor | numpy.ndarray): The class values of the boxes.
id (torch.Tensor | numpy.ndarray): The track IDs of the boxes (if available).
xywh (torch.Tensor | numpy.ndarray): The boxes in xywh format.
xyxyn (torch.Tensor | numpy.ndarray): The boxes in xyxy format normalized by original image size.
xywhn (torch.Tensor | numpy.ndarray): The boxes in xywh format normalized by original image size.
data (torch.Tensor): The raw bboxes tensor (alias for `boxes`).
Methods:
cpu(): Move the object to CPU memory.
numpy(): Convert the object to a numpy array.
cuda(): Move the object to CUDA memory.
to(*args, **kwargs): Move the object to the specified device.
"""
def __init__(self, boxes, orig_shape) -> None:
"""Initialize the Boxes class."""
if boxes.ndim == 1:
boxes = boxes[None, :]
n = boxes.shape[-1]
assert n in (6, 7), f'expected `n` in [6, 7], but got {n}' # xyxy, track_id, conf, cls
super().__init__(boxes, orig_shape)
self.is_track = n == 7
self.orig_shape = orig_shape
@property
def xyxy(self):
"""Return the boxes in xyxy format."""
return self.data[:, :4]
@property
def conf(self):
"""Return the confidence values of the boxes."""
return self.data[:, -2]
@property
def cls(self):
"""Return the class values of the boxes."""
return self.data[:, -1]
@property
def id(self):
"""Return the track IDs of the boxes (if available)."""
return self.data[:, -3] if self.is_track else None
@property
@lru_cache(maxsize=2) # maxsize 1 should suffice
def xywh(self):
"""Return the boxes in xywh format."""
return ops.xyxy2xywh(self.xyxy)
@property
@lru_cache(maxsize=2)
def xyxyn(self):
"""Return the boxes in xyxy format normalized by original image size."""
xyxy = self.xyxy.clone() if isinstance(self.xyxy, torch.Tensor) else np.copy(self.xyxy)
xyxy[..., [0, 2]] /= self.orig_shape[1]
xyxy[..., [1, 3]] /= self.orig_shape[0]
return xyxy
@property
@lru_cache(maxsize=2)
def xywhn(self):
"""Return the boxes in xywh format normalized by original image size."""
xywh = ops.xyxy2xywh(self.xyxy)
xywh[..., [0, 2]] /= self.orig_shape[1]
xywh[..., [1, 3]] /= self.orig_shape[0]
return xywh
@property
def boxes(self):
"""Return the raw bboxes tensor (deprecated)."""
LOGGER.warning("WARNING ⚠️ 'Boxes.boxes' is deprecated. Use 'Boxes.data' instead.")
return self.data
class Masks(BaseTensor):
"""
A class for storing and manipulating detection masks.
Attributes:
segments (list): Deprecated property for segments (normalized).
xy (list): A list of segments in pixel coordinates.
xyn (list): A list of normalized segments.
Methods:
cpu(): Returns the masks tensor on CPU memory.
numpy(): Returns the masks tensor as a numpy array.
cuda(): Returns the masks tensor on GPU memory.
to(device, dtype): Returns the masks tensor with the specified device and dtype.
"""
def __init__(self, masks, orig_shape) -> None:
"""Initialize the Masks class with the given masks tensor and original image shape."""
if masks.ndim == 2:
masks = masks[None, :]
super().__init__(masks, orig_shape)
@property
@lru_cache(maxsize=1)
def segments(self):
"""Return segments (normalized). Deprecated; use xyn property instead."""
LOGGER.warning(
"WARNING ⚠️ 'Masks.segments' is deprecated. Use 'Masks.xyn' for segments (normalized) and 'Masks.xy' for segments (pixels) instead."
)
return self.xyn
@property
@lru_cache(maxsize=1)
def xyn(self):
"""Return normalized segments."""
return [
ops.scale_coords(self.data.shape[1:], x, self.orig_shape, normalize=True)
for x in ops.masks2segments(self.data)]
@property
@lru_cache(maxsize=1)
def xy(self):
"""Return segments in pixel coordinates."""
return [
ops.scale_coords(self.data.shape[1:], x, self.orig_shape, normalize=False)
for x in ops.masks2segments(self.data)]
@property
def masks(self):
"""Return the raw masks tensor. Deprecated; use data attribute instead."""
LOGGER.warning("WARNING ⚠️ 'Masks.masks' is deprecated. Use 'Masks.data' instead.")
return self.data
class Keypoints(BaseTensor):
"""
A class for storing and manipulating detection keypoints.
Attributes:
xy (torch.Tensor): A collection of keypoints containing x, y coordinates for each detection.
xyn (torch.Tensor): A normalized version of xy with coordinates in the range [0, 1].
conf (torch.Tensor): Confidence values associated with keypoints if available, otherwise None.
Methods:
cpu(): Returns a copy of the keypoints tensor on CPU memory.
numpy(): Returns a copy of the keypoints tensor as a numpy array.
cuda(): Returns a copy of the keypoints tensor on GPU memory.
to(device, dtype): Returns a copy of the keypoints tensor with the specified device and dtype.
"""
def __init__(self, keypoints, orig_shape) -> None:
"""Initializes the Keypoints object with detection keypoints and original image size."""
if keypoints.ndim == 2:
keypoints = keypoints[None, :]
super().__init__(keypoints, orig_shape)
self.has_visible = self.data.shape[-1] == 3
@property
@lru_cache(maxsize=1)
def xy(self):
"""Returns x, y coordinates of keypoints."""
return self.data[..., :2]
@property
@lru_cache(maxsize=1)
def xyn(self):
"""Returns normalized x, y coordinates of keypoints."""
xy = self.xy.clone() if isinstance(self.xy, torch.Tensor) else np.copy(self.xy)
xy[..., 0] /= self.orig_shape[1]
xy[..., 1] /= self.orig_shape[0]
return xy
@property
@lru_cache(maxsize=1)
def conf(self):
"""Returns confidence values of keypoints if available, else None."""
return self.data[..., 2] if self.has_visible else None
class Probs(BaseTensor):
"""
A class for storing and manipulating classification predictions.
Attributes:
top1 (int): Index of the top 1 class.
top5 (list[int]): Indices of the top 5 classes.
top1conf (torch.Tensor): Confidence of the top 1 class.
top5conf (torch.Tensor): Confidences of the top 5 classes.
Methods:
cpu(): Returns a copy of the probs tensor on CPU memory.
numpy(): Returns a copy of the probs tensor as a numpy array.
cuda(): Returns a copy of the probs tensor on GPU memory.
to(): Returns a copy of the probs tensor with the specified device and dtype.
"""
def __init__(self, probs, orig_shape=None) -> None:
super().__init__(probs, orig_shape)
@property
@lru_cache(maxsize=1)
def top1(self):
"""Return the index of top 1."""
return int(self.data.argmax())
@property
@lru_cache(maxsize=1)
def top5(self):
"""Return the indices of top 5."""
return (-self.data).argsort(0)[:5].tolist() # this way works with both torch and numpy.
@property
@lru_cache(maxsize=1)
def top1conf(self):
"""Return the confidence of top 1."""
return self.data[self.top1]
@property
@lru_cache(maxsize=1)
def top5conf(self):
"""Return the confidences of top 5."""
return self.data[self.top5]

View File

@ -0,0 +1,689 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
Train a model on a dataset
Usage:
$ yolo mode=train model=yolov8n.pt data=coco128.yaml imgsz=640 epochs=100 batch=16
"""
import math
import os
import subprocess
import time
import warnings
from copy import deepcopy
from datetime import datetime, timedelta
from pathlib import Path
import numpy as np
import torch
from torch import distributed as dist
from torch import nn, optim
from torch.cuda import amp
from torch.nn.parallel import DistributedDataParallel as DDP
from ultralytics.cfg import get_cfg, get_save_dir
from ultralytics.data.utils import check_cls_dataset, check_det_dataset
from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights
from ultralytics.utils import (DEFAULT_CFG, LOGGER, RANK, TQDM, __version__, callbacks, clean_url, colorstr, emojis,
yaml_save)
from ultralytics.utils.autobatch import check_train_batch_size
from ultralytics.utils.checks import check_amp, check_file, check_imgsz, print_args
from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command
from ultralytics.utils.files import get_latest_run
from ultralytics.utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, init_seeds, one_cycle, select_device,
strip_optimizer)
class BaseTrainer:
"""
BaseTrainer
A base class for creating trainers.
Attributes:
args (SimpleNamespace): Configuration for the trainer.
check_resume (method): Method to check if training should be resumed from a saved checkpoint.
validator (BaseValidator): Validator instance.
model (nn.Module): Model instance.
callbacks (defaultdict): Dictionary of callbacks.
save_dir (Path): Directory to save results.
wdir (Path): Directory to save weights.
last (Path): Path to the last checkpoint.
best (Path): Path to the best checkpoint.
save_period (int): Save checkpoint every x epochs (disabled if < 1).
batch_size (int): Batch size for training.
epochs (int): Number of epochs to train for.
start_epoch (int): Starting epoch for training.
device (torch.device): Device to use for training.
amp (bool): Flag to enable AMP (Automatic Mixed Precision).
scaler (amp.GradScaler): Gradient scaler for AMP.
data (str): Path to data.
trainset (torch.utils.data.Dataset): Training dataset.
testset (torch.utils.data.Dataset): Testing dataset.
ema (nn.Module): EMA (Exponential Moving Average) of the model.
lf (nn.Module): Loss function.
scheduler (torch.optim.lr_scheduler._LRScheduler): Learning rate scheduler.
best_fitness (float): The best fitness value achieved.
fitness (float): Current fitness value.
loss (float): Current loss value.
tloss (float): Total loss value.
loss_names (list): List of loss names.
csv (Path): Path to results CSV file.
"""
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""
Initializes the BaseTrainer class.
Args:
cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
overrides (dict, optional): Configuration overrides. Defaults to None.
"""
self.args = get_cfg(cfg, overrides)
self.check_resume(overrides)
self.device = select_device(self.args.device, self.args.batch)
self.validator = None
self.model = None
self.metrics = None
self.plots = {}
init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)
# Dirs
self.save_dir = get_save_dir(self.args)
self.wdir = self.save_dir / 'weights' # weights dir
if RANK in (-1, 0):
self.wdir.mkdir(parents=True, exist_ok=True) # make dir
self.args.save_dir = str(self.save_dir)
yaml_save(self.save_dir / 'args.yaml', vars(self.args)) # save run args
self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt' # checkpoint paths
self.save_period = self.args.save_period
self.batch_size = self.args.batch
self.epochs = self.args.epochs
self.start_epoch = 0
if RANK == -1:
print_args(vars(self.args))
# Device
if self.device.type in ('cpu', 'mps'):
self.args.workers = 0 # faster CPU training as time dominated by inference, not dataloading
# Model and Dataset
self.model = self.args.model
try:
if self.args.task == 'classify':
self.data = check_cls_dataset(self.args.data)
elif self.args.data.split('.')[-1] in ('yaml', 'yml') or self.args.task in ('detect', 'segment', 'pose'):
self.data = check_det_dataset(self.args.data)
if 'yaml_file' in self.data:
self.args.data = self.data['yaml_file'] # for validating 'yolo train data=url.zip' usage
except Exception as e:
raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e
self.trainset, self.testset = self.get_dataset(self.data)
self.ema = None
# Optimization utils init
self.lf = None
self.scheduler = None
# Epoch level metrics
self.best_fitness = None
self.fitness = None
self.loss = None
self.tloss = None
self.loss_names = ['Loss']
self.csv = self.save_dir / 'results.csv'
self.plot_idx = [0, 1, 2]
# Callbacks
self.callbacks = _callbacks or callbacks.get_default_callbacks()
if RANK in (-1, 0):
callbacks.add_integration_callbacks(self)
def add_callback(self, event: str, callback):
"""
Appends the given callback.
"""
self.callbacks[event].append(callback)
def set_callback(self, event: str, callback):
"""
Overrides the existing callbacks with the given callback.
"""
self.callbacks[event] = [callback]
def run_callbacks(self, event: str):
"""Run all existing callbacks associated with a particular event."""
for callback in self.callbacks.get(event, []):
callback(self)
def train(self):
"""Allow device='', device=None on Multi-GPU systems to default to device=0."""
if isinstance(self.args.device, str) and len(self.args.device): # i.e. device='0' or device='0,1,2,3'
world_size = len(self.args.device.split(','))
elif isinstance(self.args.device, (tuple, list)): # i.e. device=[0, 1, 2, 3] (multi-GPU from CLI is list)
world_size = len(self.args.device)
elif torch.cuda.is_available(): # i.e. device=None or device='' or device=number
world_size = 1 # default to device 0
else: # i.e. device='cpu' or 'mps'
world_size = 0
# Run subprocess if DDP training, else train normally
if world_size > 1 and 'LOCAL_RANK' not in os.environ:
# Argument checks
if self.args.rect:
LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with Multi-GPU training, setting 'rect=False'")
self.args.rect = False
if self.args.batch == -1:
LOGGER.warning("WARNING ⚠️ 'batch=-1' for AutoBatch is incompatible with Multi-GPU training, setting "
"default 'batch=16'")
self.args.batch = 16
# Command
cmd, file = generate_ddp_command(world_size, self)
try:
LOGGER.info(f'{colorstr("DDP:")} debug command {" ".join(cmd)}')
subprocess.run(cmd, check=True)
except Exception as e:
raise e
finally:
ddp_cleanup(self, str(file))
else:
self._do_train(world_size)
def _setup_ddp(self, world_size):
"""Initializes and sets the DistributedDataParallel parameters for training."""
torch.cuda.set_device(RANK)
self.device = torch.device('cuda', RANK)
# LOGGER.info(f'DDP info: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}')
os.environ['NCCL_BLOCKING_WAIT'] = '1' # set to enforce timeout
dist.init_process_group(
'nccl' if dist.is_nccl_available() else 'gloo',
timeout=timedelta(seconds=10800), # 3 hours
rank=RANK,
world_size=world_size)
def _setup_train(self, world_size):
"""
Builds dataloaders and optimizer on correct rank process.
"""
# Model
self.run_callbacks('on_pretrain_routine_start')
ckpt = self.setup_model()
self.model = self.model.to(self.device)
self.set_model_attributes()
# Freeze layers
freeze_list = self.args.freeze if isinstance(
self.args.freeze, list) else range(self.args.freeze) if isinstance(self.args.freeze, int) else []
always_freeze_names = ['.dfl'] # always freeze these layers
freeze_layer_names = [f'model.{x}.' for x in freeze_list] + always_freeze_names
for k, v in self.model.named_parameters():
# v.register_hook(lambda x: torch.nan_to_num(x)) # NaN to 0 (commented for erratic training results)
if any(x in k for x in freeze_layer_names):
LOGGER.info(f"Freezing layer '{k}'")
v.requires_grad = False
elif not v.requires_grad:
LOGGER.info(f"WARNING ⚠️ setting 'requires_grad=True' for frozen layer '{k}'. "
'See ultralytics.engine.trainer for customization of frozen layers.')
v.requires_grad = True
# Check AMP
self.amp = torch.tensor(self.args.amp).to(self.device) # True or False
if self.amp and RANK in (-1, 0): # Single-GPU and DDP
callbacks_backup = callbacks.default_callbacks.copy() # backup callbacks as check_amp() resets them
self.amp = torch.tensor(check_amp(self.model), device=self.device)
callbacks.default_callbacks = callbacks_backup # restore callbacks
if RANK > -1 and world_size > 1: # DDP
dist.broadcast(self.amp, src=0) # broadcast the tensor from rank 0 to all other ranks (returns None)
self.amp = bool(self.amp) # as boolean
self.scaler = amp.GradScaler(enabled=self.amp)
if world_size > 1:
self.model = DDP(self.model, device_ids=[RANK])
# Check imgsz
gs = max(int(self.model.stride.max() if hasattr(self.model, 'stride') else 32), 32) # grid size (max stride)
self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs, max_dim=1)
# Batch size
if self.batch_size == -1 and RANK == -1: # single-GPU only, estimate best batch size
self.args.batch = self.batch_size = check_train_batch_size(self.model, self.args.imgsz, self.amp)
# Dataloaders
batch_size = self.batch_size // max(world_size, 1)
self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode='train')
if RANK in (-1, 0):
self.test_loader = self.get_dataloader(self.testset, batch_size=batch_size * 2, rank=-1, mode='val')
self.validator = self.get_validator()
metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix='val')
self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))
self.ema = ModelEMA(self.model)
if self.args.plots:
self.plot_training_labels()
# Optimizer
self.accumulate = max(round(self.args.nbs / self.batch_size), 1) # accumulate loss before optimizing
weight_decay = self.args.weight_decay * self.batch_size * self.accumulate / self.args.nbs # scale weight_decay
iterations = math.ceil(len(self.train_loader.dataset) / max(self.batch_size, self.args.nbs)) * self.epochs
self.optimizer = self.build_optimizer(model=self.model,
name=self.args.optimizer,
lr=self.args.lr0,
momentum=self.args.momentum,
decay=weight_decay,
iterations=iterations)
# Scheduler
if self.args.cos_lr:
self.lf = one_cycle(1, self.args.lrf, self.epochs) # cosine 1->hyp['lrf']
else:
self.lf = lambda x: (1 - x / self.epochs) * (1.0 - self.args.lrf) + self.args.lrf # linear
self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
self.stopper, self.stop = EarlyStopping(patience=self.args.patience), False
self.resume_training(ckpt)
self.scheduler.last_epoch = self.start_epoch - 1 # do not move
self.run_callbacks('on_pretrain_routine_end')
def _do_train(self, world_size=1):
"""Train completed, evaluate and plot if specified by arguments."""
if world_size > 1:
self._setup_ddp(world_size)
self._setup_train(world_size)
self.epoch_time = None
self.epoch_time_start = time.time()
self.train_time_start = time.time()
nb = len(self.train_loader) # number of batches
nw = max(round(self.args.warmup_epochs * nb), 100) if self.args.warmup_epochs > 0 else -1 # warmup iterations
last_opt_step = -1
self.run_callbacks('on_train_start')
LOGGER.info(f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n'
f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n'
f"Logging results to {colorstr('bold', self.save_dir)}\n"
f'Starting training for {self.epochs} epochs...')
if self.args.close_mosaic:
base_idx = (self.epochs - self.args.close_mosaic) * nb
self.plot_idx.extend([base_idx, base_idx + 1, base_idx + 2])
epoch = self.epochs # predefine for resume fully trained model edge cases
for epoch in range(self.start_epoch, self.epochs):
self.epoch = epoch
self.run_callbacks('on_train_epoch_start')
self.model.train()
if RANK != -1:
self.train_loader.sampler.set_epoch(epoch)
pbar = enumerate(self.train_loader)
# Update dataloader attributes (optional)
if epoch == (self.epochs - self.args.close_mosaic):
LOGGER.info('Closing dataloader mosaic')
if hasattr(self.train_loader.dataset, 'mosaic'):
self.train_loader.dataset.mosaic = False
if hasattr(self.train_loader.dataset, 'close_mosaic'):
self.train_loader.dataset.close_mosaic(hyp=self.args)
self.train_loader.reset()
if RANK in (-1, 0):
LOGGER.info(self.progress_string())
pbar = TQDM(enumerate(self.train_loader), total=nb)
self.tloss = None
self.optimizer.zero_grad()
for i, batch in pbar:
self.run_callbacks('on_train_batch_start')
# Warmup
ni = i + nb * epoch
if ni <= nw:
xi = [0, nw] # x interp
self.accumulate = max(1, np.interp(ni, xi, [1, self.args.nbs / self.batch_size]).round())
for j, x in enumerate(self.optimizer.param_groups):
# Bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
x['lr'] = np.interp(
ni, xi, [self.args.warmup_bias_lr if j == 0 else 0.0, x['initial_lr'] * self.lf(epoch)])
if 'momentum' in x:
x['momentum'] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum])
# Forward
with torch.cuda.amp.autocast(self.amp):
batch = self.preprocess_batch(batch)
self.loss, self.loss_items = self.model(batch)
if RANK != -1:
self.loss *= world_size
self.tloss = (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None \
else self.loss_items
# Backward
self.scaler.scale(self.loss).backward()
# Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
if ni - last_opt_step >= self.accumulate:
self.optimizer_step()
last_opt_step = ni
# Log
mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB)
loss_len = self.tloss.shape[0] if len(self.tloss.size()) else 1
losses = self.tloss if loss_len > 1 else torch.unsqueeze(self.tloss, 0)
if RANK in (-1, 0):
pbar.set_description(
('%11s' * 2 + '%11.4g' * (2 + loss_len)) %
(f'{epoch + 1}/{self.epochs}', mem, *losses, batch['cls'].shape[0], batch['img'].shape[-1]))
self.run_callbacks('on_batch_end')
if self.args.plots and ni in self.plot_idx:
self.plot_training_samples(batch, ni)
self.run_callbacks('on_train_batch_end')
self.lr = {f'lr/pg{ir}': x['lr'] for ir, x in enumerate(self.optimizer.param_groups)} # for loggers
with warnings.catch_warnings():
warnings.simplefilter('ignore') # suppress 'Detected lr_scheduler.step() before optimizer.step()'
self.scheduler.step()
self.run_callbacks('on_train_epoch_end')
if RANK in (-1, 0):
# Validation
self.ema.update_attr(self.model, include=['yaml', 'nc', 'args', 'names', 'stride', 'class_weights'])
final_epoch = (epoch + 1 == self.epochs) or self.stopper.possible_stop
if self.args.val or final_epoch:
self.metrics, self.fitness = self.validate()
self.save_metrics(metrics={**self.label_loss_items(self.tloss), **self.metrics, **self.lr})
self.stop = self.stopper(epoch + 1, self.fitness)
# Save model
if self.args.save or (epoch + 1 == self.epochs):
self.save_model()
self.run_callbacks('on_model_save')
tnow = time.time()
self.epoch_time = tnow - self.epoch_time_start
self.epoch_time_start = tnow
self.run_callbacks('on_fit_epoch_end')
torch.cuda.empty_cache() # clears GPU vRAM at end of epoch, can help with out of memory errors
# Early Stopping
if RANK != -1: # if DDP training
broadcast_list = [self.stop if RANK == 0 else None]
dist.broadcast_object_list(broadcast_list, 0) # broadcast 'stop' to all ranks
if RANK != 0:
self.stop = broadcast_list[0]
if self.stop:
break # must break all DDP ranks
if RANK in (-1, 0):
# Do final val with best.pt
LOGGER.info(f'\n{epoch - self.start_epoch + 1} epochs completed in '
f'{(time.time() - self.train_time_start) / 3600:.3f} hours.')
self.final_eval()
if self.args.plots:
self.plot_metrics()
self.run_callbacks('on_train_end')
torch.cuda.empty_cache()
self.run_callbacks('teardown')
def save_model(self):
"""Save model checkpoints based on various conditions."""
ckpt = {
'epoch': self.epoch,
'best_fitness': self.best_fitness,
'model': deepcopy(de_parallel(self.model)).half(),
'ema': deepcopy(self.ema.ema).half(),
'updates': self.ema.updates,
'optimizer': self.optimizer.state_dict(),
'train_args': vars(self.args), # save as dict
'date': datetime.now().isoformat(),
'version': __version__}
# Use dill (if exists) to serialize the lambda functions where pickle does not do this
try:
import dill as pickle
except ImportError:
import pickle
# Save last, best and delete
torch.save(ckpt, self.last, pickle_module=pickle)
if self.best_fitness == self.fitness:
torch.save(ckpt, self.best, pickle_module=pickle)
if (self.epoch > 0) and (self.save_period > 0) and (self.epoch % self.save_period == 0):
torch.save(ckpt, self.wdir / f'epoch{self.epoch}.pt', pickle_module=pickle)
del ckpt
@staticmethod
def get_dataset(data):
"""
Get train, val path from data dict if it exists. Returns None if data format is not recognized.
"""
return data['train'], data.get('val') or data.get('test')
def setup_model(self):
"""
load/create/download model for any task.
"""
if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
return
model, weights = self.model, None
ckpt = None
if str(model).endswith('.pt'):
weights, ckpt = attempt_load_one_weight(model)
cfg = ckpt['model'].yaml
else:
cfg = model
self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1) # calls Model(cfg, weights)
return ckpt
def optimizer_step(self):
"""Perform a single step of the training optimizer with gradient clipping and EMA update."""
self.scaler.unscale_(self.optimizer) # unscale gradients
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0) # clip gradients
self.scaler.step(self.optimizer)
self.scaler.update()
self.optimizer.zero_grad()
if self.ema:
self.ema.update(self.model)
def preprocess_batch(self, batch):
"""
Allows custom preprocessing model inputs and ground truths depending on task type.
"""
return batch
def validate(self):
"""
Runs validation on test set using self.validator. The returned dict is expected to contain "fitness" key.
"""
metrics = self.validator(self)
fitness = metrics.pop('fitness', -self.loss.detach().cpu().numpy()) # use loss as fitness measure if not found
if not self.best_fitness or self.best_fitness < fitness:
self.best_fitness = fitness
return metrics, fitness
def get_model(self, cfg=None, weights=None, verbose=True):
"""Get model and raise NotImplementedError for loading cfg files."""
raise NotImplementedError("This task trainer doesn't support loading cfg files")
def get_validator(self):
"""Returns a NotImplementedError when the get_validator function is called."""
raise NotImplementedError('get_validator function not implemented in trainer')
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
"""
Returns dataloader derived from torch.data.Dataloader.
"""
raise NotImplementedError('get_dataloader function not implemented in trainer')
def build_dataset(self, img_path, mode='train', batch=None):
"""Build dataset"""
raise NotImplementedError('build_dataset function not implemented in trainer')
def label_loss_items(self, loss_items=None, prefix='train'):
"""
Returns a loss dict with labelled training loss items tensor
"""
# Not needed for classification but necessary for segmentation & detection
return {'loss': loss_items} if loss_items is not None else ['loss']
def set_model_attributes(self):
"""
To set or update model parameters before training.
"""
self.model.names = self.data['names']
def build_targets(self, preds, targets):
"""Builds target tensors for training YOLO model."""
pass
def progress_string(self):
"""Returns a string describing training progress."""
return ''
# TODO: may need to put these following functions into callback
def plot_training_samples(self, batch, ni):
"""Plots training samples during YOLOv5 training."""
pass
def plot_training_labels(self):
"""Plots training labels for YOLO model."""
pass
def save_metrics(self, metrics):
"""Saves training metrics to a CSV file."""
keys, vals = list(metrics.keys()), list(metrics.values())
n = len(metrics) + 1 # number of cols
s = '' if self.csv.exists() else (('%23s,' * n % tuple(['epoch'] + keys)).rstrip(',') + '\n') # header
with open(self.csv, 'a') as f:
f.write(s + ('%23.5g,' * n % tuple([self.epoch + 1] + vals)).rstrip(',') + '\n')
def plot_metrics(self):
"""Plot and display metrics visually."""
pass
def on_plot(self, name, data=None):
"""Registers plots (e.g. to be consumed in callbacks)"""
path = Path(name)
self.plots[path] = {'data': data, 'timestamp': time.time()}
def final_eval(self):
"""Performs final evaluation and validation for object detection YOLO model."""
for f in self.last, self.best:
if f.exists():
strip_optimizer(f) # strip optimizers
if f is self.best:
LOGGER.info(f'\nValidating {f}...')
self.validator.args.plots = self.args.plots
self.metrics = self.validator(model=f)
self.metrics.pop('fitness', None)
self.run_callbacks('on_fit_epoch_end')
def check_resume(self, overrides):
"""Check if resume checkpoint exists and update arguments accordingly."""
resume = self.args.resume
if resume:
try:
exists = isinstance(resume, (str, Path)) and Path(resume).exists()
last = Path(check_file(resume) if exists else get_latest_run())
# Check that resume data YAML exists, otherwise strip to force re-download of dataset
ckpt_args = attempt_load_weights(last).args
if not Path(ckpt_args['data']).exists():
ckpt_args['data'] = self.args.data
resume = True
self.args = get_cfg(ckpt_args)
self.args.model = str(last) # reinstate model
for k in 'imgsz', 'batch': # allow arg updates to reduce memory on resume if crashed due to CUDA OOM
if k in overrides:
setattr(self.args, k, overrides[k])
except Exception as e:
raise FileNotFoundError('Resume checkpoint not found. Please pass a valid checkpoint to resume from, '
"i.e. 'yolo train resume model=path/to/last.pt'") from e
self.resume = resume
def resume_training(self, ckpt):
"""Resume YOLO training from given epoch and best fitness."""
if ckpt is None:
return
best_fitness = 0.0
start_epoch = ckpt['epoch'] + 1
if ckpt['optimizer'] is not None:
self.optimizer.load_state_dict(ckpt['optimizer']) # optimizer
best_fitness = ckpt['best_fitness']
if self.ema and ckpt.get('ema'):
self.ema.ema.load_state_dict(ckpt['ema'].float().state_dict()) # EMA
self.ema.updates = ckpt['updates']
if self.resume:
assert start_epoch > 0, \
f'{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n' \
f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'"
LOGGER.info(
f'Resuming training from {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs')
if self.epochs < start_epoch:
LOGGER.info(
f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs.")
self.epochs += ckpt['epoch'] # finetune additional epochs
self.best_fitness = best_fitness
self.start_epoch = start_epoch
if start_epoch > (self.epochs - self.args.close_mosaic):
LOGGER.info('Closing dataloader mosaic')
if hasattr(self.train_loader.dataset, 'mosaic'):
self.train_loader.dataset.mosaic = False
if hasattr(self.train_loader.dataset, 'close_mosaic'):
self.train_loader.dataset.close_mosaic(hyp=self.args)
def build_optimizer(self, model, name='auto', lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
"""
Constructs an optimizer for the given model, based on the specified optimizer name, learning rate,
momentum, weight decay, and number of iterations.
Args:
model (torch.nn.Module): The model for which to build an optimizer.
name (str, optional): The name of the optimizer to use. If 'auto', the optimizer is selected
based on the number of iterations. Default: 'auto'.
lr (float, optional): The learning rate for the optimizer. Default: 0.001.
momentum (float, optional): The momentum factor for the optimizer. Default: 0.9.
decay (float, optional): The weight decay for the optimizer. Default: 1e-5.
iterations (float, optional): The number of iterations, which determines the optimizer if
name is 'auto'. Default: 1e5.
Returns:
(torch.optim.Optimizer): The constructed optimizer.
"""
g = [], [], [] # optimizer parameter groups
bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k) # normalization layers, i.e. BatchNorm2d()
if name == 'auto':
nc = getattr(model, 'nc', 10) # number of classes
lr_fit = round(0.002 * 5 / (4 + nc), 6) # lr0 fit equation to 6 decimal places
name, lr, momentum = ('SGD', 0.01, 0.9) if iterations > 10000 else ('AdamW', lr_fit, 0.9)
self.args.warmup_bias_lr = 0.0 # no higher than 0.01 for Adam
for module_name, module in model.named_modules():
for param_name, param in module.named_parameters(recurse=False):
fullname = f'{module_name}.{param_name}' if module_name else param_name
if 'bias' in fullname: # bias (no decay)
g[2].append(param)
elif isinstance(module, bn): # weight (no decay)
g[1].append(param)
else: # weight (with decay)
g[0].append(param)
if name in ('Adam', 'Adamax', 'AdamW', 'NAdam', 'RAdam'):
optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
elif name == 'RMSProp':
optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum)
elif name == 'SGD':
optimizer = optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
else:
raise NotImplementedError(
f"Optimizer '{name}' not found in list of available optimizers "
f'[Adam, AdamW, NAdam, RAdam, RMSProp, SGD, auto].'
'To request support for addition optimizers please visit https://github.com/ultralytics/ultralytics.')
optimizer.add_param_group({'params': g[0], 'weight_decay': decay}) # add g0 with weight_decay
optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0}) # add g1 (BatchNorm2d weights)
LOGGER.info(
f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}, momentum={momentum}) with parameter groups "
f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)')
return optimizer

205
ultralytics/engine/tuner.py Normal file
View File

@ -0,0 +1,205 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
This module provides functionalities for hyperparameter tuning of the Ultralytics YOLO models for object detection,
instance segmentation, image classification, pose estimation, and multi-object tracking.
Hyperparameter tuning is the process of systematically searching for the optimal set of hyperparameters
that yield the best model performance. This is particularly crucial in deep learning models like YOLO,
where small changes in hyperparameters can lead to significant differences in model accuracy and efficiency.
Example:
Tune hyperparameters for YOLOv8n on COCO8 at imgsz=640 and epochs=30 for 300 tuning iterations.
```python
from ultralytics import YOLO
model = YOLO('yolov8n.pt')
model.tune(data='coco8.yaml', imgsz=640, epochs=100, iterations=10)
```
"""
import random
import time
from copy import deepcopy
import numpy as np
from ultralytics import YOLO
from ultralytics.cfg import get_cfg, get_save_dir
from ultralytics.utils import DEFAULT_CFG, LOGGER, callbacks, colorstr, yaml_print, yaml_save
class Tuner:
"""
Class responsible for hyperparameter tuning of YOLO models.
The class evolves YOLO model hyperparameters over a given number of iterations
by mutating them according to the search space and retraining the model to evaluate their performance.
Attributes:
space (dict): Hyperparameter search space containing bounds and scaling factors for mutation.
tune_dir (Path): Directory where evolution logs and results will be saved.
evolve_csv (Path): Path to the CSV file where evolution logs are saved.
Methods:
_mutate(hyp: dict) -> dict:
Mutates the given hyperparameters within the bounds specified in `self.space`.
__call__():
Executes the hyperparameter evolution across multiple iterations.
Example:
Tune hyperparameters for YOLOv8n on COCO8 at imgsz=640 and epochs=30 for 300 tuning iterations.
```python
from ultralytics import YOLO
model = YOLO('yolov8n.pt')
model.tune(data='coco8.yaml', imgsz=640, epochs=100, iterations=10, val=False, cache=True)
```
"""
def __init__(self, args=DEFAULT_CFG, _callbacks=None):
"""
Initialize the Tuner with configurations.
Args:
args (dict, optional): Configuration for hyperparameter evolution.
"""
self.args = get_cfg(overrides=args)
self.space = { # key: (min, max, gain(optionaL))
# 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),
'lr0': (1e-5, 1e-1),
'lrf': (0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
'momentum': (0.6, 0.98, 0.3), # SGD momentum/Adam beta1
'weight_decay': (0.0, 0.001), # optimizer weight decay 5e-4
'warmup_epochs': (0.0, 5.0), # warmup epochs (fractions ok)
'warmup_momentum': (0.0, 0.95), # warmup initial momentum
'box': (0.02, 0.2), # box loss gain
'cls': (0.2, 4.0), # cls loss gain (scale with pixels)
'hsv_h': (0.0, 0.1), # image HSV-Hue augmentation (fraction)
'hsv_s': (0.0, 0.9), # image HSV-Saturation augmentation (fraction)
'hsv_v': (0.0, 0.9), # image HSV-Value augmentation (fraction)
'degrees': (0.0, 45.0), # image rotation (+/- deg)
'translate': (0.0, 0.9), # image translation (+/- fraction)
'scale': (0.0, 0.9), # image scale (+/- gain)
'shear': (0.0, 10.0), # image shear (+/- deg)
'perspective': (0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
'flipud': (0.0, 1.0), # image flip up-down (probability)
'fliplr': (0.0, 1.0), # image flip left-right (probability)
'mosaic': (0.0, 1.0), # image mixup (probability)
'mixup': (0.0, 1.0), # image mixup (probability)
'copy_paste': (0.0, 1.0)} # segment copy-paste (probability)
self.tune_dir = get_save_dir(self.args, name='_tune')
self.evolve_csv = self.tune_dir / 'evolve.csv'
self.callbacks = _callbacks or callbacks.get_default_callbacks()
callbacks.add_integration_callbacks(self)
LOGGER.info(f"Initialized Tuner instance with 'tune_dir={self.tune_dir}'.")
def _mutate(self, parent='single', n=5, mutation=0.8, sigma=0.2):
"""
Mutates the hyperparameters based on bounds and scaling factors specified in `self.space`.
Args:
parent (str): Parent selection method: 'single' or 'weighted'.
n (int): Number of parents to consider.
mutation (float): Probability of a parameter mutation in any given iteration.
sigma (float): Standard deviation for Gaussian random number generator.
Returns:
(dict): A dictionary containing mutated hyperparameters.
"""
if self.evolve_csv.exists(): # if evolve.csv exists: select best hyps and mutate
# Select parent(s)
x = np.loadtxt(self.evolve_csv, ndmin=2, delimiter=',', skiprows=1)
fitness = x[:, 0] # first column
n = min(n, len(x)) # number of previous results to consider
x = x[np.argsort(-fitness)][:n] # top n mutations
w = x[:, 0] - x[:, 0].min() + 1E-6 # weights (sum > 0)
if parent == 'single' or len(x) == 1:
# x = x[random.randint(0, n - 1)] # random selection
x = x[random.choices(range(n), weights=w)[0]] # weighted selection
elif parent == 'weighted':
x = (x * w.reshape(n, 1)).sum(0) / w.sum() # weighted combination
# Mutate
r = np.random # method
r.seed(int(time.time()))
g = np.array([v[2] if len(v) == 3 else 1.0 for k, v in self.space.items()]) # gains 0-1
ng = len(self.space)
v = np.ones(ng)
while all(v == 1): # mutate until a change occurs (prevent duplicates)
v = (g * (r.random(ng) < mutation) * r.randn(ng) * r.random() * sigma + 1).clip(0.3, 3.0)
hyp = {k: float(x[i + 1] * v[i]) for i, k in enumerate(self.space.keys())}
else:
hyp = {k: getattr(self.args, k) for k in self.space.keys()}
# Constrain to limits
for k, v in self.space.items():
hyp[k] = max(hyp[k], v[0]) # lower limit
hyp[k] = min(hyp[k], v[1]) # upper limit
hyp[k] = round(hyp[k], 5) # significant digits
return hyp
def __call__(self, model=None, iterations=10, prefix=colorstr('Tuner:')):
"""
Executes the hyperparameter evolution process when the Tuner instance is called.
This method iterates through the number of iterations, performing the following steps in each iteration:
1. Load the existing hyperparameters or initialize new ones.
2. Mutate the hyperparameters using the `mutate` method.
3. Train a YOLO model with the mutated hyperparameters.
4. Log the fitness score and mutated hyperparameters to a CSV file.
Args:
model (Model): A pre-initialized YOLO model to be used for training.
iterations (int): The number of generations to run the evolution for.
Note:
The method utilizes the `self.evolve_csv` Path object to read and log hyperparameters and fitness scores.
Ensure this path is set correctly in the Tuner instance.
"""
t0 = time.time()
best_save_dir, best_metrics = None, None
self.tune_dir.mkdir(parents=True, exist_ok=True)
for i in range(iterations):
# Mutate hyperparameters
mutated_hyp = self._mutate()
LOGGER.info(f'{prefix} Starting iteration {i + 1}/{iterations} with hyperparameters: {mutated_hyp}')
try:
# Train YOLO model with mutated hyperparameters
train_args = {**vars(self.args), **mutated_hyp}
results = (deepcopy(model) or YOLO(self.args.model)).train(**train_args)
fitness = results.fitness
except Exception as e:
LOGGER.warning(f'WARNING ❌️ training failure for hyperparameter tuning iteration {i}\n{e}')
fitness = 0.0
# Save results and mutated_hyp to evolve_csv
log_row = [round(fitness, 5)] + [mutated_hyp[k] for k in self.space.keys()]
headers = '' if self.evolve_csv.exists() else (','.join(['fitness_score'] + list(self.space.keys())) + '\n')
with open(self.evolve_csv, 'a') as f:
f.write(headers + ','.join(map(str, log_row)) + '\n')
# Print tuning results
x = np.loadtxt(self.evolve_csv, ndmin=2, delimiter=',', skiprows=1)
fitness = x[:, 0] # first column
best_idx = fitness.argmax()
best_is_current = best_idx == i
if best_is_current:
best_save_dir = results.save_dir
best_metrics = {k: round(v, 5) for k, v in results.results_dict.items()}
header = (f'{prefix} {i + 1} iterations complete ✅ ({time.time() - t0:.2f}s)\n'
f'{prefix} Results saved to {colorstr("bold", self.tune_dir)}\n'
f'{prefix} Best fitness={fitness[best_idx]} observed at iteration {best_idx + 1}\n'
f'{prefix} Best fitness metrics are {best_metrics}\n'
f'{prefix} Best fitness model is {best_save_dir}\n'
f'{prefix} Best fitness hyperparameters are printed below.\n')
LOGGER.info('\n' + header)
# Save turning results
data = {k: float(x[0, i + 1]) for i, k in enumerate(self.space.keys())}
header = header.replace(prefix, '#').replace('/', '').replace('', '') + '\n'
yaml_save(self.tune_dir / 'best.yaml', data=data, header=header)
yaml_print(self.tune_dir / 'best.yaml')

View File

@ -0,0 +1,325 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
Check a model's accuracy on a test or val split of a dataset.
Usage:
$ yolo mode=val model=yolov8n.pt data=coco128.yaml imgsz=640
Usage - formats:
$ yolo mode=val model=yolov8n.pt # PyTorch
yolov8n.torchscript # TorchScript
yolov8n.onnx # ONNX Runtime or OpenCV DNN with dnn=True
yolov8n_openvino_model # OpenVINO
yolov8n.engine # TensorRT
yolov8n.mlpackage # CoreML (macOS-only)
yolov8n_saved_model # TensorFlow SavedModel
yolov8n.pb # TensorFlow GraphDef
yolov8n.tflite # TensorFlow Lite
yolov8n_edgetpu.tflite # TensorFlow Edge TPU
yolov8n_paddle_model # PaddlePaddle
"""
import json
import time
from pathlib import Path
import numpy as np
import torch
from ultralytics.cfg import get_cfg, get_save_dir
from ultralytics.data.utils import check_cls_dataset, check_det_dataset
from ultralytics.nn.autobackend import AutoBackend
from ultralytics.utils import LOGGER, TQDM, callbacks, colorstr, emojis
from ultralytics.utils.checks import check_imgsz
from ultralytics.utils.ops import Profile
from ultralytics.utils.torch_utils import de_parallel, select_device, smart_inference_mode
class BaseValidator:
"""
BaseValidator
A base class for creating validators.
Attributes:
args (SimpleNamespace): Configuration for the validator.
dataloader (DataLoader): Dataloader to use for validation.
pbar (tqdm): Progress bar to update during validation.
model (nn.Module): Model to validate.
data (dict): Data dictionary.
device (torch.device): Device to use for validation.
batch_i (int): Current batch index.
training (bool): Whether the model is in training mode.
names (dict): Class names.
seen: Records the number of images seen so far during validation.
stats: Placeholder for statistics during validation.
confusion_matrix: Placeholder for a confusion matrix.
nc: Number of classes.
iouv: (torch.Tensor): IoU thresholds from 0.50 to 0.95 in spaces of 0.05.
jdict (dict): Dictionary to store JSON validation results.
speed (dict): Dictionary with keys 'preprocess', 'inference', 'loss', 'postprocess' and their respective
batch processing times in milliseconds.
save_dir (Path): Directory to save results.
plots (dict): Dictionary to store plots for visualization.
callbacks (dict): Dictionary to store various callback functions.
"""
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
"""
Initializes a BaseValidator instance.
Args:
dataloader (torch.utils.data.DataLoader): Dataloader to be used for validation.
save_dir (Path, optional): Directory to save results.
pbar (tqdm.tqdm): Progress bar for displaying progress.
args (SimpleNamespace): Configuration for the validator.
_callbacks (dict): Dictionary to store various callback functions.
"""
self.args = get_cfg(overrides=args)
self.dataloader = dataloader
self.pbar = pbar
self.model = None
self.data = None
self.device = None
self.batch_i = None
self.training = True
self.names = None
self.seen = None
self.stats = None
self.confusion_matrix = None
self.nc = None
self.iouv = None
self.jdict = None
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
self.save_dir = save_dir or get_save_dir(self.args)
(self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
if self.args.conf is None:
self.args.conf = 0.001 # default conf=0.001
self.plots = {}
self.callbacks = _callbacks or callbacks.get_default_callbacks()
@smart_inference_mode()
def __call__(self, trainer=None, model=None):
"""
Supports validation of a pre-trained model if passed or a model being trained if trainer is passed (trainer
gets priority).
"""
self.training = trainer is not None
augment = self.args.augment and (not self.training)
if self.training:
self.device = trainer.device
self.data = trainer.data
self.args.half = self.device.type != 'cpu' # force FP16 val during training
model = trainer.ema.ema or trainer.model
model = model.half() if self.args.half else model.float()
# self.model = model
self.loss = torch.zeros_like(trainer.loss_items, device=trainer.device)
self.args.plots &= trainer.stopper.possible_stop or (trainer.epoch == trainer.epochs - 1)
model.eval()
else:
callbacks.add_integration_callbacks(self)
self.run_callbacks('on_val_start')
model = AutoBackend(model or self.args.model,
device=select_device(self.args.device, self.args.batch),
dnn=self.args.dnn,
data=self.args.data,
fp16=self.args.half)
# self.model = model
self.device = model.device # update device
self.args.half = model.fp16 # update half
stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine
imgsz = check_imgsz(self.args.imgsz, stride=stride)
if engine:
self.args.batch = model.batch_size
elif not pt and not jit:
self.args.batch = 1 # export.py models default to batch-size 1
LOGGER.info(f'Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')
if isinstance(self.args.data, str) and self.args.data.split('.')[-1] in ('yaml', 'yml'):
self.data = check_det_dataset(self.args.data)
elif self.args.task == 'classify':
self.data = check_cls_dataset(self.args.data, split=self.args.split)
else:
raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' for task={self.args.task} not found ❌"))
if self.device.type in ('cpu', 'mps'):
self.args.workers = 0 # faster CPU val as time dominated by inference, not dataloading
if not pt:
self.args.rect = False
self.dataloader = self.dataloader or self.get_dataloader(self.data.get(self.args.split), self.args.batch)
model.eval()
model.warmup(imgsz=(1 if pt else self.args.batch, 3, imgsz, imgsz)) # warmup
dt = Profile(), Profile(), Profile(), Profile()
bar = TQDM(self.dataloader, desc=self.get_desc(), total=len(self.dataloader))
self.init_metrics(de_parallel(model))
self.jdict = [] # empty before each val
for batch_i, batch in enumerate(bar):
self.run_callbacks('on_val_batch_start')
self.batch_i = batch_i
# Preprocess
with dt[0]:
batch = self.preprocess(batch)
# Inference
with dt[1]:
preds = model(batch['img'], augment=augment)
# Loss
with dt[2]:
if self.training:
self.loss += model.loss(batch, preds)[1]
# Postprocess
with dt[3]:
preds = self.postprocess(preds)
self.update_metrics(preds, batch)
if self.args.plots and batch_i < 3:
self.plot_val_samples(batch, batch_i)
self.plot_predictions(batch, preds, batch_i)
self.run_callbacks('on_val_batch_end')
stats = self.get_stats()
self.check_stats(stats)
self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1E3 for x in dt)))
self.finalize_metrics()
self.print_results()
self.run_callbacks('on_val_end')
if self.training:
model.float()
results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix='val')}
return {k: round(float(v), 5) for k, v in results.items()} # return results as 5 decimal place floats
else:
LOGGER.info('Speed: %.1fms preprocess, %.1fms inference, %.1fms loss, %.1fms postprocess per image' %
tuple(self.speed.values()))
if self.args.save_json and self.jdict:
with open(str(self.save_dir / 'predictions.json'), 'w') as f:
LOGGER.info(f'Saving {f.name}...')
json.dump(self.jdict, f) # flatten and save
stats = self.eval_json(stats) # update stats
if self.args.plots or self.args.save_json:
LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
return stats
def match_predictions(self, pred_classes, true_classes, iou, use_scipy=False):
"""
Matches predictions to ground truth objects (pred_classes, true_classes) using IoU.
Args:
pred_classes (torch.Tensor): Predicted class indices of shape(N,).
true_classes (torch.Tensor): Target class indices of shape(M,).
iou (torch.Tensor): An NxM tensor containing the pairwise IoU values for predictions and ground of truth
use_scipy (bool): Whether to use scipy for matching (more precise).
Returns:
(torch.Tensor): Correct tensor of shape(N,10) for 10 IoU thresholds.
"""
# Dx10 matrix, where D - detections, 10 - IoU thresholds
correct = np.zeros((pred_classes.shape[0], self.iouv.shape[0])).astype(bool)
# LxD matrix where L - labels (rows), D - detections (columns)
correct_class = true_classes[:, None] == pred_classes
iou = iou * correct_class # zero out the wrong classes
iou = iou.cpu().numpy()
for i, threshold in enumerate(self.iouv.cpu().tolist()):
if use_scipy:
# WARNING: known issue that reduces mAP in https://github.com/ultralytics/ultralytics/pull/4708
import scipy # scope import to avoid importing for all commands
cost_matrix = iou * (iou >= threshold)
if cost_matrix.any():
labels_idx, detections_idx = scipy.optimize.linear_sum_assignment(cost_matrix, maximize=True)
valid = cost_matrix[labels_idx, detections_idx] > 0
if valid.any():
correct[detections_idx[valid], i] = True
else:
matches = np.nonzero(iou >= threshold) # IoU > threshold and classes match
matches = np.array(matches).T
if matches.shape[0]:
if matches.shape[0] > 1:
matches = matches[iou[matches[:, 0], matches[:, 1]].argsort()[::-1]]
matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
# matches = matches[matches[:, 2].argsort()[::-1]]
matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
correct[matches[:, 1].astype(int), i] = True
return torch.tensor(correct, dtype=torch.bool, device=pred_classes.device)
def add_callback(self, event: str, callback):
"""Appends the given callback."""
self.callbacks[event].append(callback)
def run_callbacks(self, event: str):
"""Runs all callbacks associated with a specified event."""
for callback in self.callbacks.get(event, []):
callback(self)
def get_dataloader(self, dataset_path, batch_size):
"""Get data loader from dataset path and batch size."""
raise NotImplementedError('get_dataloader function not implemented for this validator')
def build_dataset(self, img_path):
"""Build dataset"""
raise NotImplementedError('build_dataset function not implemented in validator')
def preprocess(self, batch):
"""Preprocesses an input batch."""
return batch
def postprocess(self, preds):
"""Describes and summarizes the purpose of 'postprocess()' but no details mentioned."""
return preds
def init_metrics(self, model):
"""Initialize performance metrics for the YOLO model."""
pass
def update_metrics(self, preds, batch):
"""Updates metrics based on predictions and batch."""
pass
def finalize_metrics(self, *args, **kwargs):
"""Finalizes and returns all metrics."""
pass
def get_stats(self):
"""Returns statistics about the model's performance."""
return {}
def check_stats(self, stats):
"""Checks statistics."""
pass
def print_results(self):
"""Prints the results of the model's predictions."""
pass
def get_desc(self):
"""Get description of the YOLO model."""
pass
@property
def metric_keys(self):
"""Returns the metric keys used in YOLO training/validation."""
return []
def on_plot(self, name, data=None):
"""Registers plots (e.g. to be consumed in callbacks)"""
self.plots[Path(name)] = {'data': data, 'timestamp': time.time()}
# TODO: may need to put these following functions into callback
def plot_val_samples(self, batch, ni):
"""Plots validation samples during training."""
pass
def plot_predictions(self, batch, preds, ni):
"""Plots YOLO model predictions on batch images."""
pass
def pred_to_json(self, preds, batch):
"""Convert predictions to JSON format."""
pass
def eval_json(self, stats):
"""Evaluate and return JSON format of prediction statistics."""
pass