add yolo v10 and modify pipeline
This commit is contained in:
@ -0,0 +1 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
Binary file not shown.
Binary file not shown.
BIN
ultralytics/engine/__pycache__/exporter.cpython-39.pyc
Normal file
BIN
ultralytics/engine/__pycache__/exporter.cpython-39.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -11,8 +11,8 @@ Usage - sources:
|
||||
list.txt # list of images
|
||||
list.streams # list of streams
|
||||
'path/*.jpg' # glob
|
||||
'https://youtu.be/Zgi9g1ksQHc' # YouTube
|
||||
'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP stream
|
||||
'https://youtu.be/LNwODJXcvt4' # YouTube
|
||||
'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP, TCP stream
|
||||
|
||||
Usage - formats:
|
||||
$ yolo mode=predict model=yolov8n.pt # PyTorch
|
||||
@ -26,8 +26,12 @@ Usage - formats:
|
||||
yolov8n.tflite # TensorFlow Lite
|
||||
yolov8n_edgetpu.tflite # TensorFlow Edge TPU
|
||||
yolov8n_paddle_model # PaddlePaddle
|
||||
yolov8n_ncnn_model # NCNN
|
||||
"""
|
||||
|
||||
import platform
|
||||
import re
|
||||
import threading
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
@ -58,7 +62,7 @@ Example:
|
||||
|
||||
class BasePredictor:
|
||||
"""
|
||||
BasePredictor
|
||||
BasePredictor.
|
||||
|
||||
A base class for creating predictors.
|
||||
|
||||
@ -70,9 +74,7 @@ class BasePredictor:
|
||||
data (dict): Data configuration.
|
||||
device (torch.device): Device used for prediction.
|
||||
dataset (Dataset): Dataset used for prediction.
|
||||
vid_path (str): Path to video file.
|
||||
vid_writer (cv2.VideoWriter): Video writer for saving video output.
|
||||
data_path (str): Path to data.
|
||||
vid_writer (dict): Dictionary of {save_path: video_writer, ...} writer for saving video output.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
||||
@ -97,19 +99,22 @@ class BasePredictor:
|
||||
self.imgsz = None
|
||||
self.device = None
|
||||
self.dataset = None
|
||||
self.vid_path, self.vid_writer = None, None
|
||||
self.vid_writer = {} # dict of {save_path: video_writer, ...}
|
||||
self.plotted_img = None
|
||||
self.data_path = None
|
||||
self.source_type = None
|
||||
self.seen = 0
|
||||
self.windows = []
|
||||
self.batch = None
|
||||
self.results = None
|
||||
self.transforms = None
|
||||
self.callbacks = _callbacks or callbacks.get_default_callbacks()
|
||||
self.txt_path = None
|
||||
self._lock = threading.Lock() # for automatic thread-safe inference
|
||||
callbacks.add_integration_callbacks(self)
|
||||
|
||||
def preprocess(self, im):
|
||||
"""Prepares input image before inference.
|
||||
"""
|
||||
Prepares input image before inference.
|
||||
|
||||
Args:
|
||||
im (torch.Tensor | List(np.ndarray)): BCHW for tensor, [(HWC) x B] for list.
|
||||
@ -128,9 +133,13 @@ class BasePredictor:
|
||||
return im
|
||||
|
||||
def inference(self, im, *args, **kwargs):
|
||||
visualize = increment_path(self.save_dir / Path(self.batch[0][0]).stem,
|
||||
mkdir=True) if self.args.visualize and (not self.source_type.tensor) else False
|
||||
return self.model(im, augment=self.args.augment, visualize=visualize)
|
||||
"""Runs inference on a given image using the specified model and arguments."""
|
||||
visualize = (
|
||||
increment_path(self.save_dir / Path(self.batch[0][0]).stem, mkdir=True)
|
||||
if self.args.visualize and (not self.source_type.tensor)
|
||||
else False
|
||||
)
|
||||
return self.model(im, augment=self.args.augment, visualize=visualize, embed=self.args.embed, *args, **kwargs)
|
||||
|
||||
def pre_transform(self, im):
|
||||
"""
|
||||
@ -142,45 +151,10 @@ class BasePredictor:
|
||||
Returns:
|
||||
(list): A list of transformed images.
|
||||
"""
|
||||
same_shapes = all(x.shape == im[0].shape for x in im)
|
||||
same_shapes = len({x.shape for x in im}) == 1
|
||||
letterbox = LetterBox(self.imgsz, auto=same_shapes and self.model.pt, stride=self.model.stride)
|
||||
return [letterbox(image=x) for x in im]
|
||||
|
||||
def write_results(self, idx, results, batch):
|
||||
"""Write inference results to a file or directory."""
|
||||
p, im, _ = batch
|
||||
log_string = ''
|
||||
if len(im.shape) == 3:
|
||||
im = im[None] # expand for batch dim
|
||||
if self.source_type.webcam or self.source_type.from_img or self.source_type.tensor: # batch_size >= 1
|
||||
log_string += f'{idx}: '
|
||||
frame = self.dataset.count
|
||||
else:
|
||||
frame = getattr(self.dataset, 'frame', 0)
|
||||
self.data_path = p
|
||||
self.txt_path = str(self.save_dir / 'labels' / p.stem) + ('' if self.dataset.mode == 'image' else f'_{frame}')
|
||||
log_string += '%gx%g ' % im.shape[2:] # print string
|
||||
result = results[idx]
|
||||
log_string += result.verbose()
|
||||
|
||||
if self.args.save or self.args.show: # Add bbox to image
|
||||
plot_args = {
|
||||
'line_width': self.args.line_width,
|
||||
'boxes': self.args.boxes,
|
||||
'conf': self.args.show_conf,
|
||||
'labels': self.args.show_labels}
|
||||
if not self.args.retina_masks:
|
||||
plot_args['im_gpu'] = im[idx]
|
||||
self.plotted_img = result.plot(**plot_args)
|
||||
# Write
|
||||
if self.args.save_txt:
|
||||
result.save_txt(f'{self.txt_path}.txt', save_conf=self.args.save_conf)
|
||||
if self.args.save_crop:
|
||||
result.save_crop(save_dir=self.save_dir / 'crops',
|
||||
file_name=self.data_path.stem + ('' if self.dataset.mode == 'image' else f'_{frame}'))
|
||||
|
||||
return log_string
|
||||
|
||||
def postprocess(self, preds, img, orig_imgs):
|
||||
"""Post-processes predictions for an image and returns them."""
|
||||
return preds
|
||||
@ -194,157 +168,224 @@ class BasePredictor:
|
||||
return list(self.stream_inference(source, model, *args, **kwargs)) # merge list of Result into one
|
||||
|
||||
def predict_cli(self, source=None, model=None):
|
||||
"""Method used for CLI prediction. It uses always generator as outputs as not required by CLI mode."""
|
||||
"""
|
||||
Method used for CLI prediction.
|
||||
|
||||
It uses always generator as outputs as not required by CLI mode.
|
||||
"""
|
||||
gen = self.stream_inference(source, model)
|
||||
for _ in gen: # running CLI inference without accumulating any outputs (do not modify)
|
||||
for _ in gen: # noqa, running CLI inference without accumulating any outputs (do not modify)
|
||||
pass
|
||||
|
||||
def setup_source(self, source):
|
||||
"""Sets up source and inference mode."""
|
||||
self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2) # check image size
|
||||
self.transforms = getattr(self.model.model, 'transforms', classify_transforms(
|
||||
self.imgsz[0])) if self.args.task == 'classify' else None
|
||||
self.dataset = load_inference_source(source=source,
|
||||
imgsz=self.imgsz,
|
||||
vid_stride=self.args.vid_stride,
|
||||
stream_buffer=self.args.stream_buffer)
|
||||
self.transforms = (
|
||||
getattr(
|
||||
self.model.model,
|
||||
"transforms",
|
||||
classify_transforms(self.imgsz[0], crop_fraction=self.args.crop_fraction),
|
||||
)
|
||||
if self.args.task == "classify"
|
||||
else None
|
||||
)
|
||||
self.dataset = load_inference_source(
|
||||
source=source,
|
||||
batch=self.args.batch,
|
||||
vid_stride=self.args.vid_stride,
|
||||
buffer=self.args.stream_buffer,
|
||||
)
|
||||
self.source_type = self.dataset.source_type
|
||||
if not getattr(self, 'stream', True) and (self.dataset.mode == 'stream' or # streams
|
||||
len(self.dataset) > 1000 or # images
|
||||
any(getattr(self.dataset, 'video_flag', [False]))): # videos
|
||||
if not getattr(self, "stream", True) and (
|
||||
self.source_type.stream
|
||||
or self.source_type.screenshot
|
||||
or len(self.dataset) > 1000 # many images
|
||||
or any(getattr(self.dataset, "video_flag", [False]))
|
||||
): # videos
|
||||
LOGGER.warning(STREAM_WARNING)
|
||||
self.vid_path, self.vid_writer = [None] * self.dataset.bs, [None] * self.dataset.bs
|
||||
self.vid_writer = {}
|
||||
|
||||
@smart_inference_mode()
|
||||
def stream_inference(self, source=None, model=None, *args, **kwargs):
|
||||
"""Streams real-time inference on camera feed and saves results to file."""
|
||||
if self.args.verbose:
|
||||
LOGGER.info('')
|
||||
LOGGER.info("")
|
||||
|
||||
# Setup model
|
||||
if not self.model:
|
||||
self.setup_model(model)
|
||||
|
||||
# Setup source every time predict is called
|
||||
self.setup_source(source if source is not None else self.args.source)
|
||||
with self._lock: # for thread-safe inference
|
||||
# Setup source every time predict is called
|
||||
self.setup_source(source if source is not None else self.args.source)
|
||||
|
||||
# Check if save_dir/ label file exists
|
||||
if self.args.save or self.args.save_txt:
|
||||
(self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
|
||||
# Check if save_dir/ label file exists
|
||||
if self.args.save or self.args.save_txt:
|
||||
(self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Warmup model
|
||||
if not self.done_warmup:
|
||||
self.model.warmup(imgsz=(1 if self.model.pt or self.model.triton else self.dataset.bs, 3, *self.imgsz))
|
||||
self.done_warmup = True
|
||||
# Warmup model
|
||||
if not self.done_warmup:
|
||||
self.model.warmup(imgsz=(1 if self.model.pt or self.model.triton else self.dataset.bs, 3, *self.imgsz))
|
||||
self.done_warmup = True
|
||||
|
||||
self.seen, self.windows, self.batch, profilers = 0, [], None, (ops.Profile(), ops.Profile(), ops.Profile())
|
||||
self.run_callbacks('on_predict_start')
|
||||
for batch in self.dataset:
|
||||
self.run_callbacks('on_predict_batch_start')
|
||||
self.batch = batch
|
||||
path, im0s, vid_cap, s = batch
|
||||
self.seen, self.windows, self.batch = 0, [], None
|
||||
profilers = (
|
||||
ops.Profile(device=self.device),
|
||||
ops.Profile(device=self.device),
|
||||
ops.Profile(device=self.device),
|
||||
)
|
||||
self.run_callbacks("on_predict_start")
|
||||
for self.batch in self.dataset:
|
||||
self.run_callbacks("on_predict_batch_start")
|
||||
paths, im0s, s = self.batch
|
||||
|
||||
# Preprocess
|
||||
with profilers[0]:
|
||||
im = self.preprocess(im0s)
|
||||
# Preprocess
|
||||
with profilers[0]:
|
||||
im = self.preprocess(im0s)
|
||||
|
||||
# Inference
|
||||
with profilers[1]:
|
||||
preds = self.inference(im, *args, **kwargs)
|
||||
# Inference
|
||||
with profilers[1]:
|
||||
preds = self.inference(im, *args, **kwargs)
|
||||
if self.args.embed:
|
||||
yield from [preds] if isinstance(preds, torch.Tensor) else preds # yield embedding tensors
|
||||
continue
|
||||
|
||||
# Postprocess
|
||||
with profilers[2]:
|
||||
self.results = self.postprocess(preds, im, im0s)
|
||||
self.run_callbacks('on_predict_postprocess_end')
|
||||
# Postprocess
|
||||
with profilers[2]:
|
||||
self.results = self.postprocess(preds, im, im0s)
|
||||
self.run_callbacks("on_predict_postprocess_end")
|
||||
|
||||
# Visualize, save, write results
|
||||
n = len(im0s)
|
||||
for i in range(n):
|
||||
self.seen += 1
|
||||
self.results[i].speed = {
|
||||
'preprocess': profilers[0].dt * 1E3 / n,
|
||||
'inference': profilers[1].dt * 1E3 / n,
|
||||
'postprocess': profilers[2].dt * 1E3 / n}
|
||||
p, im0 = path[i], None if self.source_type.tensor else im0s[i].copy()
|
||||
p = Path(p)
|
||||
# Visualize, save, write results
|
||||
n = len(im0s)
|
||||
for i in range(n):
|
||||
self.seen += 1
|
||||
self.results[i].speed = {
|
||||
"preprocess": profilers[0].dt * 1e3 / n,
|
||||
"inference": profilers[1].dt * 1e3 / n,
|
||||
"postprocess": profilers[2].dt * 1e3 / n,
|
||||
}
|
||||
if self.args.verbose or self.args.save or self.args.save_txt or self.args.show:
|
||||
s[i] += self.write_results(i, Path(paths[i]), im, s)
|
||||
|
||||
if self.args.verbose or self.args.save or self.args.save_txt or self.args.show:
|
||||
s += self.write_results(i, self.results, (p, im, im0))
|
||||
if self.args.save or self.args.save_txt:
|
||||
self.results[i].save_dir = self.save_dir.__str__()
|
||||
if self.args.show and self.plotted_img is not None:
|
||||
self.show(p)
|
||||
if self.args.save and self.plotted_img is not None:
|
||||
self.save_preds(vid_cap, i, str(self.save_dir / p.name))
|
||||
# Print batch results
|
||||
if self.args.verbose:
|
||||
LOGGER.info("\n".join(s))
|
||||
|
||||
self.run_callbacks('on_predict_batch_end')
|
||||
yield from self.results
|
||||
|
||||
# Print time (inference-only)
|
||||
if self.args.verbose:
|
||||
LOGGER.info(f'{s}{profilers[1].dt * 1E3:.1f}ms')
|
||||
self.run_callbacks("on_predict_batch_end")
|
||||
yield from self.results
|
||||
|
||||
# Release assets
|
||||
if isinstance(self.vid_writer[-1], cv2.VideoWriter):
|
||||
self.vid_writer[-1].release() # release final video writer
|
||||
for v in self.vid_writer.values():
|
||||
if isinstance(v, cv2.VideoWriter):
|
||||
v.release()
|
||||
|
||||
# Print results
|
||||
# Print final results
|
||||
if self.args.verbose and self.seen:
|
||||
t = tuple(x.t / self.seen * 1E3 for x in profilers) # speeds per image
|
||||
LOGGER.info(f'Speed: %.1fms preprocess, %.1fms inference, %.1fms postprocess per image at shape '
|
||||
f'{(1, 3, *im.shape[2:])}' % t)
|
||||
t = tuple(x.t / self.seen * 1e3 for x in profilers) # speeds per image
|
||||
LOGGER.info(
|
||||
f"Speed: %.1fms preprocess, %.1fms inference, %.1fms postprocess per image at shape "
|
||||
f"{(min(self.args.batch, self.seen), 3, *im.shape[2:])}" % t
|
||||
)
|
||||
if self.args.save or self.args.save_txt or self.args.save_crop:
|
||||
nl = len(list(self.save_dir.glob('labels/*.txt'))) # number of labels
|
||||
s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else ''
|
||||
nl = len(list(self.save_dir.glob("labels/*.txt"))) # number of labels
|
||||
s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else ""
|
||||
LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}")
|
||||
|
||||
self.run_callbacks('on_predict_end')
|
||||
self.run_callbacks("on_predict_end")
|
||||
|
||||
def setup_model(self, model, verbose=True):
|
||||
"""Initialize YOLO model with given parameters and set it to evaluation mode."""
|
||||
self.model = AutoBackend(model or self.args.model,
|
||||
device=select_device(self.args.device, verbose=verbose),
|
||||
dnn=self.args.dnn,
|
||||
data=self.args.data,
|
||||
fp16=self.args.half,
|
||||
fuse=True,
|
||||
verbose=verbose)
|
||||
self.model = AutoBackend(
|
||||
weights=model or self.args.model,
|
||||
device=select_device(self.args.device, verbose=verbose),
|
||||
dnn=self.args.dnn,
|
||||
data=self.args.data,
|
||||
fp16=self.args.half,
|
||||
batch=self.args.batch,
|
||||
fuse=True,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
self.device = self.model.device # update device
|
||||
self.args.half = self.model.fp16 # update half
|
||||
self.model.eval()
|
||||
|
||||
def show(self, p):
|
||||
"""Display an image in a window using OpenCV imshow()."""
|
||||
im0 = self.plotted_img
|
||||
if platform.system() == 'Linux' and p not in self.windows:
|
||||
self.windows.append(p)
|
||||
cv2.namedWindow(str(p), cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO) # allow window resize (Linux)
|
||||
cv2.resizeWindow(str(p), im0.shape[1], im0.shape[0])
|
||||
cv2.imshow(str(p), im0)
|
||||
cv2.waitKey(500 if self.batch[3].startswith('image') else 1) # 1 millisecond
|
||||
def write_results(self, i, p, im, s):
|
||||
"""Write inference results to a file or directory."""
|
||||
string = "" # print string
|
||||
if len(im.shape) == 3:
|
||||
im = im[None] # expand for batch dim
|
||||
if self.source_type.stream or self.source_type.from_img or self.source_type.tensor: # batch_size >= 1
|
||||
string += f"{i}: "
|
||||
frame = self.dataset.count
|
||||
else:
|
||||
match = re.search(r"frame (\d+)/", s[i])
|
||||
frame = int(match.group(1)) if match else None # 0 if frame undetermined
|
||||
|
||||
def save_preds(self, vid_cap, idx, save_path):
|
||||
self.txt_path = self.save_dir / "labels" / (p.stem + ("" if self.dataset.mode == "image" else f"_{frame}"))
|
||||
string += "%gx%g " % im.shape[2:]
|
||||
result = self.results[i]
|
||||
result.save_dir = self.save_dir.__str__() # used in other locations
|
||||
string += result.verbose() + f"{result.speed['inference']:.1f}ms"
|
||||
|
||||
# Add predictions to image
|
||||
if self.args.save or self.args.show:
|
||||
self.plotted_img = result.plot(
|
||||
line_width=self.args.line_width,
|
||||
boxes=self.args.show_boxes,
|
||||
conf=self.args.show_conf,
|
||||
labels=self.args.show_labels,
|
||||
im_gpu=None if self.args.retina_masks else im[i],
|
||||
)
|
||||
|
||||
# Save results
|
||||
if self.args.save_txt:
|
||||
result.save_txt(f"{self.txt_path}.txt", save_conf=self.args.save_conf)
|
||||
if self.args.save_crop:
|
||||
result.save_crop(save_dir=self.save_dir / "crops", file_name=self.txt_path.stem)
|
||||
if self.args.show:
|
||||
self.show(str(p))
|
||||
if self.args.save:
|
||||
self.save_predicted_images(str(self.save_dir / (p.name or "tmp.jpg")), frame)
|
||||
|
||||
return string
|
||||
|
||||
def save_predicted_images(self, save_path="", frame=0):
|
||||
"""Save video predictions as mp4 at specified path."""
|
||||
im0 = self.plotted_img
|
||||
# Save imgs
|
||||
if self.dataset.mode == 'image':
|
||||
cv2.imwrite(save_path, im0)
|
||||
else: # 'video' or 'stream'
|
||||
if self.vid_path[idx] != save_path: # new video
|
||||
self.vid_path[idx] = save_path
|
||||
if isinstance(self.vid_writer[idx], cv2.VideoWriter):
|
||||
self.vid_writer[idx].release() # release previous video writer
|
||||
if vid_cap: # video
|
||||
fps = int(vid_cap.get(cv2.CAP_PROP_FPS)) # integer required, floats produce error in MP4 codec
|
||||
w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
else: # stream
|
||||
fps, w, h = 30, im0.shape[1], im0.shape[0]
|
||||
suffix, fourcc = ('.mp4', 'avc1') if MACOS else ('.avi', 'WMV2') if WINDOWS else ('.avi', 'MJPG')
|
||||
save_path = str(Path(save_path).with_suffix(suffix))
|
||||
self.vid_writer[idx] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*fourcc), fps, (w, h))
|
||||
self.vid_writer[idx].write(im0)
|
||||
im = self.plotted_img
|
||||
|
||||
# Save videos and streams
|
||||
if self.dataset.mode in {"stream", "video"}:
|
||||
fps = self.dataset.fps if self.dataset.mode == "video" else 30
|
||||
frames_path = f'{save_path.split(".", 1)[0]}_frames/'
|
||||
if save_path not in self.vid_writer: # new video
|
||||
if self.args.save_frames:
|
||||
Path(frames_path).mkdir(parents=True, exist_ok=True)
|
||||
suffix, fourcc = (".mp4", "avc1") if MACOS else (".avi", "WMV2") if WINDOWS else (".avi", "MJPG")
|
||||
self.vid_writer[save_path] = cv2.VideoWriter(
|
||||
filename=str(Path(save_path).with_suffix(suffix)),
|
||||
fourcc=cv2.VideoWriter_fourcc(*fourcc),
|
||||
fps=fps, # integer required, floats produce error in MP4 codec
|
||||
frameSize=(im.shape[1], im.shape[0]), # (width, height)
|
||||
)
|
||||
|
||||
# Save video
|
||||
self.vid_writer[save_path].write(im)
|
||||
if self.args.save_frames:
|
||||
cv2.imwrite(f"{frames_path}{frame}.jpg", im)
|
||||
|
||||
# Save images
|
||||
else:
|
||||
cv2.imwrite(save_path, im)
|
||||
|
||||
def show(self, p=""):
|
||||
"""Display an image in a window using OpenCV imshow()."""
|
||||
im = self.plotted_img
|
||||
if platform.system() == "Linux" and p not in self.windows:
|
||||
self.windows.append(p)
|
||||
cv2.namedWindow(p, cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO) # allow window resize (Linux)
|
||||
cv2.resizeWindow(p, im.shape[1], im.shape[0]) # (width, height)
|
||||
cv2.imshow(p, im)
|
||||
cv2.waitKey(300 if self.dataset.mode == "image" else 1) # 1 millisecond
|
||||
|
||||
def run_callbacks(self, event: str):
|
||||
"""Runs all registered callbacks for a specific event."""
|
||||
@ -352,7 +393,5 @@ class BasePredictor:
|
||||
callback(self)
|
||||
|
||||
def add_callback(self, event: str, func):
|
||||
"""
|
||||
Add callback
|
||||
"""
|
||||
"""Add callback."""
|
||||
self.callbacks[event].append(func)
|
||||
|
@ -1,6 +1,6 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
"""
|
||||
Ultralytics Results, Boxes and Masks classes for handling inference results
|
||||
Ultralytics Results, Boxes and Masks classes for handling inference results.
|
||||
|
||||
Usage: See https://docs.ultralytics.com/modes/predict/
|
||||
"""
|
||||
@ -13,17 +13,17 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from ultralytics.data.augment import LetterBox
|
||||
from ultralytics.utils import LOGGER, SimpleClass, deprecation_warn, ops
|
||||
from ultralytics.utils import LOGGER, SimpleClass, ops
|
||||
from ultralytics.utils.plotting import Annotator, colors, save_one_box
|
||||
from ultralytics.utils.torch_utils import smart_inference_mode
|
||||
|
||||
|
||||
class BaseTensor(SimpleClass):
|
||||
"""
|
||||
Base tensor class with additional methods for easy manipulation and device handling.
|
||||
"""
|
||||
"""Base tensor class with additional methods for easy manipulation and device handling."""
|
||||
|
||||
def __init__(self, data, orig_shape) -> None:
|
||||
"""Initialize BaseTensor with data and original shape.
|
||||
"""
|
||||
Initialize BaseTensor with data and original shape.
|
||||
|
||||
Args:
|
||||
data (torch.Tensor | np.ndarray): Predictions, such as bboxes, masks and keypoints.
|
||||
@ -67,45 +67,63 @@ class Results(SimpleClass):
|
||||
"""
|
||||
A class for storing and manipulating inference results.
|
||||
|
||||
Args:
|
||||
orig_img (numpy.ndarray): The original image as a numpy array.
|
||||
path (str): The path to the image file.
|
||||
names (dict): A dictionary of class names.
|
||||
boxes (torch.tensor, optional): A 2D tensor of bounding box coordinates for each detection.
|
||||
masks (torch.tensor, optional): A 3D tensor of detection masks, where each mask is a binary image.
|
||||
probs (torch.tensor, optional): A 1D tensor of probabilities of each class for classification task.
|
||||
keypoints (List[List[float]], optional): A list of detected keypoints for each object.
|
||||
|
||||
Attributes:
|
||||
orig_img (numpy.ndarray): The original image as a numpy array.
|
||||
orig_shape (tuple): The original image shape in (height, width) format.
|
||||
boxes (Boxes, optional): A Boxes object containing the detection bounding boxes.
|
||||
masks (Masks, optional): A Masks object containing the detection masks.
|
||||
probs (Probs, optional): A Probs object containing probabilities of each class for classification task.
|
||||
keypoints (Keypoints, optional): A Keypoints object containing detected keypoints for each object.
|
||||
speed (dict): A dictionary of preprocess, inference, and postprocess speeds in milliseconds per image.
|
||||
names (dict): A dictionary of class names.
|
||||
path (str): The path to the image file.
|
||||
_keys (tuple): A tuple of attribute names for non-empty attributes.
|
||||
orig_img (numpy.ndarray): Original image as a numpy array.
|
||||
orig_shape (tuple): Original image shape in (height, width) format.
|
||||
boxes (Boxes, optional): Object containing detection bounding boxes.
|
||||
masks (Masks, optional): Object containing detection masks.
|
||||
probs (Probs, optional): Object containing class probabilities for classification tasks.
|
||||
keypoints (Keypoints, optional): Object containing detected keypoints for each object.
|
||||
speed (dict): Dictionary of preprocess, inference, and postprocess speeds (ms/image).
|
||||
names (dict): Dictionary of class names.
|
||||
path (str): Path to the image file.
|
||||
|
||||
Methods:
|
||||
update(boxes=None, masks=None, probs=None, obb=None): Updates object attributes with new detection results.
|
||||
cpu(): Returns a copy of the Results object with all tensors on CPU memory.
|
||||
numpy(): Returns a copy of the Results object with all tensors as numpy arrays.
|
||||
cuda(): Returns a copy of the Results object with all tensors on GPU memory.
|
||||
to(*args, **kwargs): Returns a copy of the Results object with tensors on a specified device and dtype.
|
||||
new(): Returns a new Results object with the same image, path, and names.
|
||||
plot(...): Plots detection results on an input image, returning an annotated image.
|
||||
show(): Show annotated results to screen.
|
||||
save(filename): Save annotated results to file.
|
||||
verbose(): Returns a log string for each task, detailing detections and classifications.
|
||||
save_txt(txt_file, save_conf=False): Saves detection results to a text file.
|
||||
save_crop(save_dir, file_name=Path("im.jpg")): Saves cropped detection images.
|
||||
tojson(normalize=False): Converts detection results to JSON format.
|
||||
"""
|
||||
|
||||
def __init__(self, orig_img, path, names, boxes=None, masks=None, probs=None, keypoints=None) -> None:
|
||||
"""Initialize the Results class."""
|
||||
def __init__(self, orig_img, path, names, boxes=None, masks=None, probs=None, keypoints=None, obb=None) -> None:
|
||||
"""
|
||||
Initialize the Results class.
|
||||
|
||||
Args:
|
||||
orig_img (numpy.ndarray): The original image as a numpy array.
|
||||
path (str): The path to the image file.
|
||||
names (dict): A dictionary of class names.
|
||||
boxes (torch.tensor, optional): A 2D tensor of bounding box coordinates for each detection.
|
||||
masks (torch.tensor, optional): A 3D tensor of detection masks, where each mask is a binary image.
|
||||
probs (torch.tensor, optional): A 1D tensor of probabilities of each class for classification task.
|
||||
keypoints (torch.tensor, optional): A 2D tensor of keypoint coordinates for each detection.
|
||||
obb (torch.tensor, optional): A 2D tensor of oriented bounding box coordinates for each detection.
|
||||
"""
|
||||
self.orig_img = orig_img
|
||||
self.orig_shape = orig_img.shape[:2]
|
||||
self.boxes = Boxes(boxes, self.orig_shape) if boxes is not None else None # native size boxes
|
||||
self.masks = Masks(masks, self.orig_shape) if masks is not None else None # native size or imgsz masks
|
||||
self.probs = Probs(probs) if probs is not None else None
|
||||
self.keypoints = Keypoints(keypoints, self.orig_shape) if keypoints is not None else None
|
||||
self.speed = {'preprocess': None, 'inference': None, 'postprocess': None} # milliseconds per image
|
||||
self.obb = OBB(obb, self.orig_shape) if obb is not None else None
|
||||
self.speed = {"preprocess": None, "inference": None, "postprocess": None} # milliseconds per image
|
||||
self.names = names
|
||||
self.path = path
|
||||
self.save_dir = None
|
||||
self._keys = 'boxes', 'masks', 'probs', 'keypoints'
|
||||
self._keys = "boxes", "masks", "probs", "keypoints", "obb"
|
||||
|
||||
def __getitem__(self, idx):
|
||||
"""Return a Results object for the specified index."""
|
||||
return self._apply('__getitem__', idx)
|
||||
return self._apply("__getitem__", idx)
|
||||
|
||||
def __len__(self):
|
||||
"""Return the number of detections in the Results object."""
|
||||
@ -114,17 +132,30 @@ class Results(SimpleClass):
|
||||
if v is not None:
|
||||
return len(v)
|
||||
|
||||
def update(self, boxes=None, masks=None, probs=None):
|
||||
def update(self, boxes=None, masks=None, probs=None, obb=None):
|
||||
"""Update the boxes, masks, and probs attributes of the Results object."""
|
||||
if boxes is not None:
|
||||
ops.clip_boxes(boxes, self.orig_shape) # clip boxes
|
||||
self.boxes = Boxes(boxes, self.orig_shape)
|
||||
self.boxes = Boxes(ops.clip_boxes(boxes, self.orig_shape), self.orig_shape)
|
||||
if masks is not None:
|
||||
self.masks = Masks(masks, self.orig_shape)
|
||||
if probs is not None:
|
||||
self.probs = probs
|
||||
if obb is not None:
|
||||
self.obb = OBB(obb, self.orig_shape)
|
||||
|
||||
def _apply(self, fn, *args, **kwargs):
|
||||
"""
|
||||
Applies a function to all non-empty attributes and returns a new Results object with modified attributes. This
|
||||
function is internally called by methods like .to(), .cuda(), .cpu(), etc.
|
||||
|
||||
Args:
|
||||
fn (str): The name of the function to apply.
|
||||
*args: Variable length argument list to pass to the function.
|
||||
**kwargs: Arbitrary keyword arguments to pass to the function.
|
||||
|
||||
Returns:
|
||||
Results: A new Results object with attributes modified by the applied function.
|
||||
"""
|
||||
r = self.new()
|
||||
for k in self._keys:
|
||||
v = getattr(self, k)
|
||||
@ -134,40 +165,42 @@ class Results(SimpleClass):
|
||||
|
||||
def cpu(self):
|
||||
"""Return a copy of the Results object with all tensors on CPU memory."""
|
||||
return self._apply('cpu')
|
||||
return self._apply("cpu")
|
||||
|
||||
def numpy(self):
|
||||
"""Return a copy of the Results object with all tensors as numpy arrays."""
|
||||
return self._apply('numpy')
|
||||
return self._apply("numpy")
|
||||
|
||||
def cuda(self):
|
||||
"""Return a copy of the Results object with all tensors on GPU memory."""
|
||||
return self._apply('cuda')
|
||||
return self._apply("cuda")
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
"""Return a copy of the Results object with tensors on the specified device and dtype."""
|
||||
return self._apply('to', *args, **kwargs)
|
||||
return self._apply("to", *args, **kwargs)
|
||||
|
||||
def new(self):
|
||||
"""Return a new Results object with the same image, path, and names."""
|
||||
return Results(orig_img=self.orig_img, path=self.path, names=self.names)
|
||||
|
||||
def plot(
|
||||
self,
|
||||
conf=True,
|
||||
line_width=None,
|
||||
font_size=None,
|
||||
font='Arial.ttf',
|
||||
pil=False,
|
||||
img=None,
|
||||
im_gpu=None,
|
||||
kpt_radius=5,
|
||||
kpt_line=True,
|
||||
labels=True,
|
||||
boxes=True,
|
||||
masks=True,
|
||||
probs=True,
|
||||
**kwargs # deprecated args TODO: remove support in 8.2
|
||||
self,
|
||||
conf=True,
|
||||
line_width=None,
|
||||
font_size=None,
|
||||
font="Arial.ttf",
|
||||
pil=False,
|
||||
img=None,
|
||||
im_gpu=None,
|
||||
kpt_radius=5,
|
||||
kpt_line=True,
|
||||
labels=True,
|
||||
boxes=True,
|
||||
masks=True,
|
||||
probs=True,
|
||||
show=False,
|
||||
save=False,
|
||||
filename=None,
|
||||
):
|
||||
"""
|
||||
Plots the detection results on an input RGB image. Accepts a numpy array (cv2) or a PIL Image.
|
||||
@ -186,6 +219,9 @@ class Results(SimpleClass):
|
||||
boxes (bool): Whether to plot the bounding boxes.
|
||||
masks (bool): Whether to plot the masks.
|
||||
probs (bool): Whether to plot classification probability
|
||||
show (bool): Whether to display the annotated image directly.
|
||||
save (bool): Whether to save the annotated image to `filename`.
|
||||
filename (str): Filename to save image to if save is True.
|
||||
|
||||
Returns:
|
||||
(numpy.ndarray): A numpy array of the annotated image.
|
||||
@ -207,19 +243,9 @@ class Results(SimpleClass):
|
||||
if img is None and isinstance(self.orig_img, torch.Tensor):
|
||||
img = (self.orig_img[0].detach().permute(1, 2, 0).contiguous() * 255).to(torch.uint8).cpu().numpy()
|
||||
|
||||
# Deprecation warn TODO: remove in 8.2
|
||||
if 'show_conf' in kwargs:
|
||||
deprecation_warn('show_conf', 'conf')
|
||||
conf = kwargs['show_conf']
|
||||
assert isinstance(conf, bool), '`show_conf` should be of boolean type, i.e, show_conf=True/False'
|
||||
|
||||
if 'line_thickness' in kwargs:
|
||||
deprecation_warn('line_thickness', 'line_width')
|
||||
line_width = kwargs['line_thickness']
|
||||
assert isinstance(line_width, int), '`line_width` should be of int type, i.e, line_width=3'
|
||||
|
||||
names = self.names
|
||||
pred_boxes, show_boxes = self.boxes, boxes
|
||||
is_obb = self.obb is not None
|
||||
pred_boxes, show_boxes = self.obb if is_obb else self.boxes, boxes
|
||||
pred_masks, show_masks = self.masks, masks
|
||||
pred_probs, show_probs = self.probs, probs
|
||||
annotator = Annotator(
|
||||
@ -228,28 +254,35 @@ class Results(SimpleClass):
|
||||
font_size,
|
||||
font,
|
||||
pil or (pred_probs is not None and show_probs), # Classify tasks default to pil=True
|
||||
example=names)
|
||||
example=names,
|
||||
)
|
||||
|
||||
# Plot Segment results
|
||||
if pred_masks and show_masks:
|
||||
if im_gpu is None:
|
||||
img = LetterBox(pred_masks.shape[1:])(image=annotator.result())
|
||||
im_gpu = torch.as_tensor(img, dtype=torch.float16, device=pred_masks.data.device).permute(
|
||||
2, 0, 1).flip(0).contiguous() / 255
|
||||
im_gpu = (
|
||||
torch.as_tensor(img, dtype=torch.float16, device=pred_masks.data.device)
|
||||
.permute(2, 0, 1)
|
||||
.flip(0)
|
||||
.contiguous()
|
||||
/ 255
|
||||
)
|
||||
idx = pred_boxes.cls if pred_boxes else range(len(pred_masks))
|
||||
annotator.masks(pred_masks.data, colors=[colors(x, True) for x in idx], im_gpu=im_gpu)
|
||||
|
||||
# Plot Detect results
|
||||
if pred_boxes and show_boxes:
|
||||
if pred_boxes is not None and show_boxes:
|
||||
for d in reversed(pred_boxes):
|
||||
c, conf, id = int(d.cls), float(d.conf) if conf else None, None if d.id is None else int(d.id.item())
|
||||
name = ('' if id is None else f'id:{id} ') + names[c]
|
||||
label = (f'{name} {conf:.2f}' if conf else name) if labels else None
|
||||
annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True))
|
||||
name = ("" if id is None else f"id:{id} ") + names[c]
|
||||
label = (f"{name} {conf:.2f}" if conf else name) if labels else None
|
||||
box = d.xyxyxyxy.reshape(-1, 4, 2).squeeze() if is_obb else d.xyxy.squeeze()
|
||||
annotator.box_label(box, label, color=colors(c, True), rotated=is_obb)
|
||||
|
||||
# Plot Classify results
|
||||
if pred_probs is not None and show_probs:
|
||||
text = ',\n'.join(f'{names[j] if names else j} {pred_probs.data[j]:.2f}' for j in pred_probs.top5)
|
||||
text = ",\n".join(f"{names[j] if names else j} {pred_probs.data[j]:.2f}" for j in pred_probs.top5)
|
||||
x = round(self.orig_shape[0] * 0.03)
|
||||
annotator.text([x, x], text, txt_color=(255, 255, 255)) # TODO: allow setting colors
|
||||
|
||||
@ -258,17 +291,34 @@ class Results(SimpleClass):
|
||||
for k in reversed(self.keypoints.data):
|
||||
annotator.kpts(k, self.orig_shape, radius=kpt_radius, kpt_line=kpt_line)
|
||||
|
||||
# Show results
|
||||
if show:
|
||||
annotator.show(self.path)
|
||||
|
||||
# Save results
|
||||
if save:
|
||||
annotator.save(filename)
|
||||
|
||||
return annotator.result()
|
||||
|
||||
def show(self, *args, **kwargs):
|
||||
"""Show annotated results image."""
|
||||
self.plot(show=True, *args, **kwargs)
|
||||
|
||||
def save(self, filename=None, *args, **kwargs):
|
||||
"""Save annotated results image."""
|
||||
if not filename:
|
||||
filename = f"results_{Path(self.path).name}"
|
||||
self.plot(save=True, filename=filename, *args, **kwargs)
|
||||
return filename
|
||||
|
||||
def verbose(self):
|
||||
"""
|
||||
Return log string for each task.
|
||||
"""
|
||||
log_string = ''
|
||||
"""Return log string for each task."""
|
||||
log_string = ""
|
||||
probs = self.probs
|
||||
boxes = self.boxes
|
||||
if len(self) == 0:
|
||||
return log_string if probs is not None else f'{log_string}(no detections), '
|
||||
return log_string if probs is not None else f"{log_string}(no detections), "
|
||||
if probs is not None:
|
||||
log_string += f"{', '.join(f'{self.names[j]} {probs.data[j]:.2f}' for j in probs.top5)}, "
|
||||
if boxes:
|
||||
@ -285,34 +335,35 @@ class Results(SimpleClass):
|
||||
txt_file (str): txt file path.
|
||||
save_conf (bool): save confidence score or not.
|
||||
"""
|
||||
boxes = self.boxes
|
||||
is_obb = self.obb is not None
|
||||
boxes = self.obb if is_obb else self.boxes
|
||||
masks = self.masks
|
||||
probs = self.probs
|
||||
kpts = self.keypoints
|
||||
texts = []
|
||||
if probs is not None:
|
||||
# Classify
|
||||
[texts.append(f'{probs.data[j]:.2f} {self.names[j]}') for j in probs.top5]
|
||||
[texts.append(f"{probs.data[j]:.2f} {self.names[j]}") for j in probs.top5]
|
||||
elif boxes:
|
||||
# Detect/segment/pose
|
||||
for j, d in enumerate(boxes):
|
||||
c, conf, id = int(d.cls), float(d.conf), None if d.id is None else int(d.id.item())
|
||||
line = (c, *d.xywhn.view(-1))
|
||||
line = (c, *(d.xyxyxyxyn.view(-1) if is_obb else d.xywhn.view(-1)))
|
||||
if masks:
|
||||
seg = masks[j].xyn[0].copy().reshape(-1) # reversed mask.xyn, (n,2) to (n*2)
|
||||
line = (c, *seg)
|
||||
if kpts is not None:
|
||||
kpt = torch.cat((kpts[j].xyn, kpts[j].conf[..., None]), 2) if kpts[j].has_visible else kpts[j].xyn
|
||||
line += (*kpt.reshape(-1).tolist(), )
|
||||
line += (conf, ) * save_conf + (() if id is None else (id, ))
|
||||
texts.append(('%g ' * len(line)).rstrip() % line)
|
||||
line += (*kpt.reshape(-1).tolist(),)
|
||||
line += (conf,) * save_conf + (() if id is None else (id,))
|
||||
texts.append(("%g " * len(line)).rstrip() % line)
|
||||
|
||||
if texts:
|
||||
Path(txt_file).parent.mkdir(parents=True, exist_ok=True) # make directory
|
||||
with open(txt_file, 'a') as f:
|
||||
f.writelines(text + '\n' for text in texts)
|
||||
with open(txt_file, "a") as f:
|
||||
f.writelines(text + "\n" for text in texts)
|
||||
|
||||
def save_crop(self, save_dir, file_name=Path('im.jpg')):
|
||||
def save_crop(self, save_dir, file_name=Path("im.jpg")):
|
||||
"""
|
||||
Save cropped predictions to `save_dir/cls/file_name.jpg`.
|
||||
|
||||
@ -321,79 +372,105 @@ class Results(SimpleClass):
|
||||
file_name (str | pathlib.Path): File name.
|
||||
"""
|
||||
if self.probs is not None:
|
||||
LOGGER.warning('WARNING ⚠️ Classify task do not support `save_crop`.')
|
||||
LOGGER.warning("WARNING ⚠️ Classify task do not support `save_crop`.")
|
||||
return
|
||||
if self.obb is not None:
|
||||
LOGGER.warning("WARNING ⚠️ OBB task do not support `save_crop`.")
|
||||
return
|
||||
for d in self.boxes:
|
||||
save_one_box(d.xyxy,
|
||||
self.orig_img.copy(),
|
||||
file=Path(save_dir) / self.names[int(d.cls)] / f'{Path(file_name).stem}.jpg',
|
||||
BGR=True)
|
||||
save_one_box(
|
||||
d.xyxy,
|
||||
self.orig_img.copy(),
|
||||
file=Path(save_dir) / self.names[int(d.cls)] / f"{Path(file_name)}.jpg",
|
||||
BGR=True,
|
||||
)
|
||||
|
||||
def tojson(self, normalize=False):
|
||||
"""Convert the object to JSON format."""
|
||||
def summary(self, normalize=False, decimals=5):
|
||||
"""Convert the results to a summarized format."""
|
||||
if self.probs is not None:
|
||||
LOGGER.warning('Warning: Classify task do not support `tojson` yet.')
|
||||
LOGGER.warning("Warning: Classify results do not support the `summary()` method yet.")
|
||||
return
|
||||
|
||||
import json
|
||||
|
||||
# Create list of detection dictionaries
|
||||
results = []
|
||||
data = self.boxes.data.cpu().tolist()
|
||||
h, w = self.orig_shape if normalize else (1, 1)
|
||||
for i, row in enumerate(data): # xyxy, track_id if tracking, conf, class_id
|
||||
box = {'x1': row[0] / w, 'y1': row[1] / h, 'x2': row[2] / w, 'y2': row[3] / h}
|
||||
conf = row[-2]
|
||||
box = {
|
||||
"x1": round(row[0] / w, decimals),
|
||||
"y1": round(row[1] / h, decimals),
|
||||
"x2": round(row[2] / w, decimals),
|
||||
"y2": round(row[3] / h, decimals),
|
||||
}
|
||||
conf = round(row[-2], decimals)
|
||||
class_id = int(row[-1])
|
||||
name = self.names[class_id]
|
||||
result = {'name': name, 'class': class_id, 'confidence': conf, 'box': box}
|
||||
result = {"name": self.names[class_id], "class": class_id, "confidence": conf, "box": box}
|
||||
if self.boxes.is_track:
|
||||
result['track_id'] = int(row[-3]) # track ID
|
||||
result["track_id"] = int(row[-3]) # track ID
|
||||
if self.masks:
|
||||
x, y = self.masks.xy[i][:, 0], self.masks.xy[i][:, 1] # numpy array
|
||||
result['segments'] = {'x': (x / w).tolist(), 'y': (y / h).tolist()}
|
||||
result["segments"] = {
|
||||
"x": (self.masks.xy[i][:, 0] / w).round(decimals).tolist(),
|
||||
"y": (self.masks.xy[i][:, 1] / h).round(decimals).tolist(),
|
||||
}
|
||||
if self.keypoints is not None:
|
||||
x, y, visible = self.keypoints[i].data[0].cpu().unbind(dim=1) # torch Tensor
|
||||
result['keypoints'] = {'x': (x / w).tolist(), 'y': (y / h).tolist(), 'visible': visible.tolist()}
|
||||
result["keypoints"] = {
|
||||
"x": (x / w).numpy().round(decimals).tolist(), # decimals named argument required
|
||||
"y": (y / h).numpy().round(decimals).tolist(),
|
||||
"visible": visible.numpy().round(decimals).tolist(),
|
||||
}
|
||||
results.append(result)
|
||||
|
||||
# Convert detections to JSON
|
||||
return json.dumps(results, indent=2)
|
||||
return results
|
||||
|
||||
def tojson(self, normalize=False, decimals=5):
|
||||
"""Convert the results to JSON format."""
|
||||
import json
|
||||
|
||||
return json.dumps(self.summary(normalize=normalize, decimals=decimals), indent=2)
|
||||
|
||||
|
||||
class Boxes(BaseTensor):
|
||||
"""
|
||||
A class for storing and manipulating detection boxes.
|
||||
|
||||
Args:
|
||||
boxes (torch.Tensor | numpy.ndarray): A tensor or numpy array containing the detection boxes,
|
||||
with shape (num_boxes, 6) or (num_boxes, 7). The last two columns contain confidence and class values.
|
||||
If present, the third last column contains track IDs.
|
||||
orig_shape (tuple): Original image size, in the format (height, width).
|
||||
Manages detection boxes, providing easy access and manipulation of box coordinates, confidence scores, class
|
||||
identifiers, and optional tracking IDs. Supports multiple formats for box coordinates, including both absolute and
|
||||
normalized forms.
|
||||
|
||||
Attributes:
|
||||
xyxy (torch.Tensor | numpy.ndarray): The boxes in xyxy format.
|
||||
conf (torch.Tensor | numpy.ndarray): The confidence values of the boxes.
|
||||
cls (torch.Tensor | numpy.ndarray): The class values of the boxes.
|
||||
id (torch.Tensor | numpy.ndarray): The track IDs of the boxes (if available).
|
||||
xywh (torch.Tensor | numpy.ndarray): The boxes in xywh format.
|
||||
xyxyn (torch.Tensor | numpy.ndarray): The boxes in xyxy format normalized by original image size.
|
||||
xywhn (torch.Tensor | numpy.ndarray): The boxes in xywh format normalized by original image size.
|
||||
data (torch.Tensor): The raw bboxes tensor (alias for `boxes`).
|
||||
data (torch.Tensor): The raw tensor containing detection boxes and their associated data.
|
||||
orig_shape (tuple): The original image size as a tuple (height, width), used for normalization.
|
||||
is_track (bool): Indicates whether tracking IDs are included in the box data.
|
||||
|
||||
Properties:
|
||||
xyxy (torch.Tensor | numpy.ndarray): Boxes in [x1, y1, x2, y2] format.
|
||||
conf (torch.Tensor | numpy.ndarray): Confidence scores for each box.
|
||||
cls (torch.Tensor | numpy.ndarray): Class labels for each box.
|
||||
id (torch.Tensor | numpy.ndarray, optional): Tracking IDs for each box, if available.
|
||||
xywh (torch.Tensor | numpy.ndarray): Boxes in [x, y, width, height] format, calculated on demand.
|
||||
xyxyn (torch.Tensor | numpy.ndarray): Normalized [x1, y1, x2, y2] boxes, relative to `orig_shape`.
|
||||
xywhn (torch.Tensor | numpy.ndarray): Normalized [x, y, width, height] boxes, relative to `orig_shape`.
|
||||
|
||||
Methods:
|
||||
cpu(): Move the object to CPU memory.
|
||||
numpy(): Convert the object to a numpy array.
|
||||
cuda(): Move the object to CUDA memory.
|
||||
to(*args, **kwargs): Move the object to the specified device.
|
||||
cpu(): Moves the boxes to CPU memory.
|
||||
numpy(): Converts the boxes to a numpy array format.
|
||||
cuda(): Moves the boxes to CUDA (GPU) memory.
|
||||
to(device, dtype=None): Moves the boxes to the specified device.
|
||||
"""
|
||||
|
||||
def __init__(self, boxes, orig_shape) -> None:
|
||||
"""Initialize the Boxes class."""
|
||||
"""
|
||||
Initialize the Boxes class.
|
||||
|
||||
Args:
|
||||
boxes (torch.Tensor | numpy.ndarray): A tensor or numpy array containing the detection boxes, with
|
||||
shape (num_boxes, 6) or (num_boxes, 7). The last two columns contain confidence and class values.
|
||||
If present, the third last column contains track IDs.
|
||||
orig_shape (tuple): Original image size, in the format (height, width).
|
||||
"""
|
||||
if boxes.ndim == 1:
|
||||
boxes = boxes[None, :]
|
||||
n = boxes.shape[-1]
|
||||
assert n in (6, 7), f'expected `n` in [6, 7], but got {n}' # xyxy, track_id, conf, cls
|
||||
assert n in (6, 7), f"expected 6 or 7 values but got {n}" # xyxy, track_id, conf, cls
|
||||
super().__init__(boxes, orig_shape)
|
||||
self.is_track = n == 7
|
||||
self.orig_shape = orig_shape
|
||||
@ -442,19 +519,12 @@ class Boxes(BaseTensor):
|
||||
xywh[..., [1, 3]] /= self.orig_shape[0]
|
||||
return xywh
|
||||
|
||||
@property
|
||||
def boxes(self):
|
||||
"""Return the raw bboxes tensor (deprecated)."""
|
||||
LOGGER.warning("WARNING ⚠️ 'Boxes.boxes' is deprecated. Use 'Boxes.data' instead.")
|
||||
return self.data
|
||||
|
||||
|
||||
class Masks(BaseTensor):
|
||||
"""
|
||||
A class for storing and manipulating detection masks.
|
||||
|
||||
Attributes:
|
||||
segments (list): Deprecated property for segments (normalized).
|
||||
xy (list): A list of segments in pixel coordinates.
|
||||
xyn (list): A list of normalized segments.
|
||||
|
||||
@ -471,22 +541,14 @@ class Masks(BaseTensor):
|
||||
masks = masks[None, :]
|
||||
super().__init__(masks, orig_shape)
|
||||
|
||||
@property
|
||||
@lru_cache(maxsize=1)
|
||||
def segments(self):
|
||||
"""Return segments (normalized). Deprecated; use xyn property instead."""
|
||||
LOGGER.warning(
|
||||
"WARNING ⚠️ 'Masks.segments' is deprecated. Use 'Masks.xyn' for segments (normalized) and 'Masks.xy' for segments (pixels) instead."
|
||||
)
|
||||
return self.xyn
|
||||
|
||||
@property
|
||||
@lru_cache(maxsize=1)
|
||||
def xyn(self):
|
||||
"""Return normalized segments."""
|
||||
return [
|
||||
ops.scale_coords(self.data.shape[1:], x, self.orig_shape, normalize=True)
|
||||
for x in ops.masks2segments(self.data)]
|
||||
for x in ops.masks2segments(self.data)
|
||||
]
|
||||
|
||||
@property
|
||||
@lru_cache(maxsize=1)
|
||||
@ -494,13 +556,8 @@ class Masks(BaseTensor):
|
||||
"""Return segments in pixel coordinates."""
|
||||
return [
|
||||
ops.scale_coords(self.data.shape[1:], x, self.orig_shape, normalize=False)
|
||||
for x in ops.masks2segments(self.data)]
|
||||
|
||||
@property
|
||||
def masks(self):
|
||||
"""Return the raw masks tensor. Deprecated; use data attribute instead."""
|
||||
LOGGER.warning("WARNING ⚠️ 'Masks.masks' is deprecated. Use 'Masks.data' instead.")
|
||||
return self.data
|
||||
for x in ops.masks2segments(self.data)
|
||||
]
|
||||
|
||||
|
||||
class Keypoints(BaseTensor):
|
||||
@ -519,10 +576,14 @@ class Keypoints(BaseTensor):
|
||||
to(device, dtype): Returns a copy of the keypoints tensor with the specified device and dtype.
|
||||
"""
|
||||
|
||||
@smart_inference_mode() # avoid keypoints < conf in-place error
|
||||
def __init__(self, keypoints, orig_shape) -> None:
|
||||
"""Initializes the Keypoints object with detection keypoints and original image size."""
|
||||
if keypoints.ndim == 2:
|
||||
keypoints = keypoints[None, :]
|
||||
if keypoints.shape[2] == 3: # x, y, conf
|
||||
mask = keypoints[..., 2] < 0.5 # points with conf < 0.5 (not visible)
|
||||
keypoints[..., :2][mask] = 0
|
||||
super().__init__(keypoints, orig_shape)
|
||||
self.has_visible = self.data.shape[-1] == 3
|
||||
|
||||
@ -566,6 +627,7 @@ class Probs(BaseTensor):
|
||||
"""
|
||||
|
||||
def __init__(self, probs, orig_shape=None) -> None:
|
||||
"""Initialize the Probs class with classification probabilities and optional original shape of the image."""
|
||||
super().__init__(probs, orig_shape)
|
||||
|
||||
@property
|
||||
@ -591,3 +653,91 @@ class Probs(BaseTensor):
|
||||
def top5conf(self):
|
||||
"""Return the confidences of top 5."""
|
||||
return self.data[self.top5]
|
||||
|
||||
|
||||
class OBB(BaseTensor):
|
||||
"""
|
||||
A class for storing and manipulating Oriented Bounding Boxes (OBB).
|
||||
|
||||
Args:
|
||||
boxes (torch.Tensor | numpy.ndarray): A tensor or numpy array containing the detection boxes,
|
||||
with shape (num_boxes, 7) or (num_boxes, 8). The last two columns contain confidence and class values.
|
||||
If present, the third last column contains track IDs, and the fifth column from the left contains rotation.
|
||||
orig_shape (tuple): Original image size, in the format (height, width).
|
||||
|
||||
Attributes:
|
||||
xywhr (torch.Tensor | numpy.ndarray): The boxes in [x_center, y_center, width, height, rotation] format.
|
||||
conf (torch.Tensor | numpy.ndarray): The confidence values of the boxes.
|
||||
cls (torch.Tensor | numpy.ndarray): The class values of the boxes.
|
||||
id (torch.Tensor | numpy.ndarray): The track IDs of the boxes (if available).
|
||||
xyxyxyxyn (torch.Tensor | numpy.ndarray): The rotated boxes in xyxyxyxy format normalized by orig image size.
|
||||
xyxyxyxy (torch.Tensor | numpy.ndarray): The rotated boxes in xyxyxyxy format.
|
||||
xyxy (torch.Tensor | numpy.ndarray): The horizontal boxes in xyxyxyxy format.
|
||||
data (torch.Tensor): The raw OBB tensor (alias for `boxes`).
|
||||
|
||||
Methods:
|
||||
cpu(): Move the object to CPU memory.
|
||||
numpy(): Convert the object to a numpy array.
|
||||
cuda(): Move the object to CUDA memory.
|
||||
to(*args, **kwargs): Move the object to the specified device.
|
||||
"""
|
||||
|
||||
def __init__(self, boxes, orig_shape) -> None:
|
||||
"""Initialize the Boxes class."""
|
||||
if boxes.ndim == 1:
|
||||
boxes = boxes[None, :]
|
||||
n = boxes.shape[-1]
|
||||
assert n in (7, 8), f"expected 7 or 8 values but got {n}" # xywh, rotation, track_id, conf, cls
|
||||
super().__init__(boxes, orig_shape)
|
||||
self.is_track = n == 8
|
||||
self.orig_shape = orig_shape
|
||||
|
||||
@property
|
||||
def xywhr(self):
|
||||
"""Return the rotated boxes in xywhr format."""
|
||||
return self.data[:, :5]
|
||||
|
||||
@property
|
||||
def conf(self):
|
||||
"""Return the confidence values of the boxes."""
|
||||
return self.data[:, -2]
|
||||
|
||||
@property
|
||||
def cls(self):
|
||||
"""Return the class values of the boxes."""
|
||||
return self.data[:, -1]
|
||||
|
||||
@property
|
||||
def id(self):
|
||||
"""Return the track IDs of the boxes (if available)."""
|
||||
return self.data[:, -3] if self.is_track else None
|
||||
|
||||
@property
|
||||
@lru_cache(maxsize=2)
|
||||
def xyxyxyxy(self):
|
||||
"""Return the boxes in xyxyxyxy format, (N, 4, 2)."""
|
||||
return ops.xywhr2xyxyxyxy(self.xywhr)
|
||||
|
||||
@property
|
||||
@lru_cache(maxsize=2)
|
||||
def xyxyxyxyn(self):
|
||||
"""Return the boxes in xyxyxyxy format, (N, 4, 2)."""
|
||||
xyxyxyxyn = self.xyxyxyxy.clone() if isinstance(self.xyxyxyxy, torch.Tensor) else np.copy(self.xyxyxyxy)
|
||||
xyxyxyxyn[..., 0] /= self.orig_shape[1]
|
||||
xyxyxyxyn[..., 1] /= self.orig_shape[0]
|
||||
return xyxyxyxyn
|
||||
|
||||
@property
|
||||
@lru_cache(maxsize=2)
|
||||
def xyxy(self):
|
||||
"""
|
||||
Return the horizontal boxes in xyxy format, (N, 4).
|
||||
|
||||
Accepts both torch and numpy boxes.
|
||||
"""
|
||||
x1 = self.xyxyxyxy[..., 0].min(1).values
|
||||
x2 = self.xyxyxyxy[..., 0].max(1).values
|
||||
y1 = self.xyxyxyxy[..., 1].min(1).values
|
||||
y2 = self.xyxyxyxy[..., 1].max(1).values
|
||||
xyxy = [x1, y1, x2, y2]
|
||||
return np.stack(xyxy, axis=-1) if isinstance(self.data, np.ndarray) else torch.stack(xyxy, dim=-1)
|
||||
|
@ -1,6 +1,6 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
"""
|
||||
Train a model on a dataset
|
||||
Train a model on a dataset.
|
||||
|
||||
Usage:
|
||||
$ yolo mode=train model=yolov8n.pt data=coco128.yaml imgsz=640 epochs=100 batch=16
|
||||
@ -19,31 +19,45 @@ import numpy as np
|
||||
import torch
|
||||
from torch import distributed as dist
|
||||
from torch import nn, optim
|
||||
from torch.cuda import amp
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
from ultralytics.cfg import get_cfg, get_save_dir
|
||||
from ultralytics.data.utils import check_cls_dataset, check_det_dataset
|
||||
from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights
|
||||
from ultralytics.utils import (DEFAULT_CFG, LOGGER, RANK, TQDM, __version__, callbacks, clean_url, colorstr, emojis,
|
||||
yaml_save)
|
||||
from ultralytics.utils import (
|
||||
DEFAULT_CFG,
|
||||
LOGGER,
|
||||
RANK,
|
||||
TQDM,
|
||||
__version__,
|
||||
callbacks,
|
||||
clean_url,
|
||||
colorstr,
|
||||
emojis,
|
||||
yaml_save,
|
||||
)
|
||||
from ultralytics.utils.autobatch import check_train_batch_size
|
||||
from ultralytics.utils.checks import check_amp, check_file, check_imgsz, print_args
|
||||
from ultralytics.utils.checks import check_amp, check_file, check_imgsz, check_model_file_from_stem, print_args
|
||||
from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command
|
||||
from ultralytics.utils.files import get_latest_run
|
||||
from ultralytics.utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, init_seeds, one_cycle, select_device,
|
||||
strip_optimizer)
|
||||
from ultralytics.utils.torch_utils import (
|
||||
EarlyStopping,
|
||||
ModelEMA,
|
||||
de_parallel,
|
||||
init_seeds,
|
||||
one_cycle,
|
||||
select_device,
|
||||
strip_optimizer,
|
||||
)
|
||||
|
||||
|
||||
class BaseTrainer:
|
||||
"""
|
||||
BaseTrainer
|
||||
BaseTrainer.
|
||||
|
||||
A base class for creating trainers.
|
||||
|
||||
Attributes:
|
||||
args (SimpleNamespace): Configuration for the trainer.
|
||||
check_resume (method): Method to check if training should be resumed from a saved checkpoint.
|
||||
validator (BaseValidator): Validator instance.
|
||||
model (nn.Module): Model instance.
|
||||
callbacks (defaultdict): Dictionary of callbacks.
|
||||
@ -62,6 +76,7 @@ class BaseTrainer:
|
||||
trainset (torch.utils.data.Dataset): Training dataset.
|
||||
testset (torch.utils.data.Dataset): Testing dataset.
|
||||
ema (nn.Module): EMA (Exponential Moving Average) of the model.
|
||||
resume (bool): Resume training from a checkpoint.
|
||||
lf (nn.Module): Loss function.
|
||||
scheduler (torch.optim.lr_scheduler._LRScheduler): Learning rate scheduler.
|
||||
best_fitness (float): The best fitness value achieved.
|
||||
@ -84,19 +99,19 @@ class BaseTrainer:
|
||||
self.check_resume(overrides)
|
||||
self.device = select_device(self.args.device, self.args.batch)
|
||||
self.validator = None
|
||||
self.model = None
|
||||
self.metrics = None
|
||||
self.plots = {}
|
||||
init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)
|
||||
|
||||
# Dirs
|
||||
self.save_dir = get_save_dir(self.args)
|
||||
self.wdir = self.save_dir / 'weights' # weights dir
|
||||
self.args.name = self.save_dir.name # update name for loggers
|
||||
self.wdir = self.save_dir / "weights" # weights dir
|
||||
if RANK in (-1, 0):
|
||||
self.wdir.mkdir(parents=True, exist_ok=True) # make dir
|
||||
self.args.save_dir = str(self.save_dir)
|
||||
yaml_save(self.save_dir / 'args.yaml', vars(self.args)) # save run args
|
||||
self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt' # checkpoint paths
|
||||
yaml_save(self.save_dir / "args.yaml", vars(self.args)) # save run args
|
||||
self.last, self.best = self.wdir / "last.pt", self.wdir / "best.pt" # checkpoint paths
|
||||
self.save_period = self.args.save_period
|
||||
|
||||
self.batch_size = self.args.batch
|
||||
@ -106,18 +121,23 @@ class BaseTrainer:
|
||||
print_args(vars(self.args))
|
||||
|
||||
# Device
|
||||
if self.device.type in ('cpu', 'mps'):
|
||||
if self.device.type in ("cpu", "mps"):
|
||||
self.args.workers = 0 # faster CPU training as time dominated by inference, not dataloading
|
||||
|
||||
# Model and Dataset
|
||||
self.model = self.args.model
|
||||
self.model = check_model_file_from_stem(self.args.model) # add suffix, i.e. yolov8n -> yolov8n.pt
|
||||
try:
|
||||
if self.args.task == 'classify':
|
||||
if self.args.task == "classify":
|
||||
self.data = check_cls_dataset(self.args.data)
|
||||
elif self.args.data.split('.')[-1] in ('yaml', 'yml') or self.args.task in ('detect', 'segment', 'pose'):
|
||||
elif self.args.data.split(".")[-1] in ("yaml", "yml") or self.args.task in (
|
||||
"detect",
|
||||
"segment",
|
||||
"pose",
|
||||
"obb",
|
||||
):
|
||||
self.data = check_det_dataset(self.args.data)
|
||||
if 'yaml_file' in self.data:
|
||||
self.args.data = self.data['yaml_file'] # for validating 'yolo train data=url.zip' usage
|
||||
if "yaml_file" in self.data:
|
||||
self.args.data = self.data["yaml_file"] # for validating 'yolo train data=url.zip' usage
|
||||
except Exception as e:
|
||||
raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e
|
||||
|
||||
@ -133,8 +153,8 @@ class BaseTrainer:
|
||||
self.fitness = None
|
||||
self.loss = None
|
||||
self.tloss = None
|
||||
self.loss_names = ['Loss']
|
||||
self.csv = self.save_dir / 'results.csv'
|
||||
self.loss_names = ["Loss"]
|
||||
self.csv = self.save_dir / "results.csv"
|
||||
self.plot_idx = [0, 1, 2]
|
||||
|
||||
# Callbacks
|
||||
@ -143,15 +163,11 @@ class BaseTrainer:
|
||||
callbacks.add_integration_callbacks(self)
|
||||
|
||||
def add_callback(self, event: str, callback):
|
||||
"""
|
||||
Appends the given callback.
|
||||
"""
|
||||
"""Appends the given callback."""
|
||||
self.callbacks[event].append(callback)
|
||||
|
||||
def set_callback(self, event: str, callback):
|
||||
"""
|
||||
Overrides the existing callbacks with the given callback.
|
||||
"""
|
||||
"""Overrides the existing callbacks with the given callback."""
|
||||
self.callbacks[event] = [callback]
|
||||
|
||||
def run_callbacks(self, event: str):
|
||||
@ -162,7 +178,7 @@ class BaseTrainer:
|
||||
def train(self):
|
||||
"""Allow device='', device=None on Multi-GPU systems to default to device=0."""
|
||||
if isinstance(self.args.device, str) and len(self.args.device): # i.e. device='0' or device='0,1,2,3'
|
||||
world_size = len(self.args.device.split(','))
|
||||
world_size = len(self.args.device.split(","))
|
||||
elif isinstance(self.args.device, (tuple, list)): # i.e. device=[0, 1, 2, 3] (multi-GPU from CLI is list)
|
||||
world_size = len(self.args.device)
|
||||
elif torch.cuda.is_available(): # i.e. device=None or device='' or device=number
|
||||
@ -171,14 +187,16 @@ class BaseTrainer:
|
||||
world_size = 0
|
||||
|
||||
# Run subprocess if DDP training, else train normally
|
||||
if world_size > 1 and 'LOCAL_RANK' not in os.environ:
|
||||
if world_size > 1 and "LOCAL_RANK" not in os.environ:
|
||||
# Argument checks
|
||||
if self.args.rect:
|
||||
LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with Multi-GPU training, setting 'rect=False'")
|
||||
self.args.rect = False
|
||||
if self.args.batch == -1:
|
||||
LOGGER.warning("WARNING ⚠️ 'batch=-1' for AutoBatch is incompatible with Multi-GPU training, setting "
|
||||
"default 'batch=16'")
|
||||
LOGGER.warning(
|
||||
"WARNING ⚠️ 'batch=-1' for AutoBatch is incompatible with Multi-GPU training, setting "
|
||||
"default 'batch=16'"
|
||||
)
|
||||
self.args.batch = 16
|
||||
|
||||
# Command
|
||||
@ -194,42 +212,56 @@ class BaseTrainer:
|
||||
else:
|
||||
self._do_train(world_size)
|
||||
|
||||
def _setup_scheduler(self):
|
||||
"""Initialize training learning rate scheduler."""
|
||||
if self.args.cos_lr:
|
||||
self.lf = one_cycle(1, self.args.lrf, self.epochs) # cosine 1->hyp['lrf']
|
||||
else:
|
||||
self.lf = lambda x: max(1 - x / self.epochs, 0) * (1.0 - self.args.lrf) + self.args.lrf # linear
|
||||
self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
|
||||
|
||||
def _setup_ddp(self, world_size):
|
||||
"""Initializes and sets the DistributedDataParallel parameters for training."""
|
||||
torch.cuda.set_device(RANK)
|
||||
self.device = torch.device('cuda', RANK)
|
||||
self.device = torch.device("cuda", RANK)
|
||||
# LOGGER.info(f'DDP info: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}')
|
||||
os.environ['NCCL_BLOCKING_WAIT'] = '1' # set to enforce timeout
|
||||
os.environ["NCCL_BLOCKING_WAIT"] = "1" # set to enforce timeout
|
||||
dist.init_process_group(
|
||||
'nccl' if dist.is_nccl_available() else 'gloo',
|
||||
backend="nccl" if dist.is_nccl_available() else "gloo",
|
||||
timeout=timedelta(seconds=10800), # 3 hours
|
||||
rank=RANK,
|
||||
world_size=world_size)
|
||||
world_size=world_size,
|
||||
)
|
||||
|
||||
def _setup_train(self, world_size):
|
||||
"""
|
||||
Builds dataloaders and optimizer on correct rank process.
|
||||
"""
|
||||
"""Builds dataloaders and optimizer on correct rank process."""
|
||||
|
||||
# Model
|
||||
self.run_callbacks('on_pretrain_routine_start')
|
||||
self.run_callbacks("on_pretrain_routine_start")
|
||||
ckpt = self.setup_model()
|
||||
self.model = self.model.to(self.device)
|
||||
self.set_model_attributes()
|
||||
|
||||
# Freeze layers
|
||||
freeze_list = self.args.freeze if isinstance(
|
||||
self.args.freeze, list) else range(self.args.freeze) if isinstance(self.args.freeze, int) else []
|
||||
always_freeze_names = ['.dfl'] # always freeze these layers
|
||||
freeze_layer_names = [f'model.{x}.' for x in freeze_list] + always_freeze_names
|
||||
freeze_list = (
|
||||
self.args.freeze
|
||||
if isinstance(self.args.freeze, list)
|
||||
else range(self.args.freeze)
|
||||
if isinstance(self.args.freeze, int)
|
||||
else []
|
||||
)
|
||||
always_freeze_names = [".dfl"] # always freeze these layers
|
||||
freeze_layer_names = [f"model.{x}." for x in freeze_list] + always_freeze_names
|
||||
for k, v in self.model.named_parameters():
|
||||
# v.register_hook(lambda x: torch.nan_to_num(x)) # NaN to 0 (commented for erratic training results)
|
||||
if any(x in k for x in freeze_layer_names):
|
||||
LOGGER.info(f"Freezing layer '{k}'")
|
||||
v.requires_grad = False
|
||||
elif not v.requires_grad:
|
||||
LOGGER.info(f"WARNING ⚠️ setting 'requires_grad=True' for frozen layer '{k}'. "
|
||||
'See ultralytics.engine.trainer for customization of frozen layers.')
|
||||
elif not v.requires_grad and v.dtype.is_floating_point: # only floating point Tensor can require gradients
|
||||
LOGGER.info(
|
||||
f"WARNING ⚠️ setting 'requires_grad=True' for frozen layer '{k}'. "
|
||||
"See ultralytics.engine.trainer for customization of frozen layers."
|
||||
)
|
||||
v.requires_grad = True
|
||||
|
||||
# Check AMP
|
||||
@ -241,13 +273,14 @@ class BaseTrainer:
|
||||
if RANK > -1 and world_size > 1: # DDP
|
||||
dist.broadcast(self.amp, src=0) # broadcast the tensor from rank 0 to all other ranks (returns None)
|
||||
self.amp = bool(self.amp) # as boolean
|
||||
self.scaler = amp.GradScaler(enabled=self.amp)
|
||||
self.scaler = torch.cuda.amp.GradScaler(enabled=self.amp)
|
||||
if world_size > 1:
|
||||
self.model = DDP(self.model, device_ids=[RANK])
|
||||
self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[RANK])
|
||||
|
||||
# Check imgsz
|
||||
gs = max(int(self.model.stride.max() if hasattr(self.model, 'stride') else 32), 32) # grid size (max stride)
|
||||
gs = max(int(self.model.stride.max() if hasattr(self.model, "stride") else 32), 32) # grid size (max stride)
|
||||
self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs, max_dim=1)
|
||||
self.stride = gs # for multiscale training
|
||||
|
||||
# Batch size
|
||||
if self.batch_size == -1 and RANK == -1: # single-GPU only, estimate best batch size
|
||||
@ -255,11 +288,14 @@ class BaseTrainer:
|
||||
|
||||
# Dataloaders
|
||||
batch_size = self.batch_size // max(world_size, 1)
|
||||
self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode='train')
|
||||
self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode="train")
|
||||
if RANK in (-1, 0):
|
||||
self.test_loader = self.get_dataloader(self.testset, batch_size=batch_size * 2, rank=-1, mode='val')
|
||||
# Note: When training DOTA dataset, double batch size could get OOM on images with >2000 objects.
|
||||
self.test_loader = self.get_dataloader(
|
||||
self.testset, batch_size=batch_size if self.args.task == "obb" else batch_size * 2, rank=-1, mode="val"
|
||||
)
|
||||
self.validator = self.get_validator()
|
||||
metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix='val')
|
||||
metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix="val")
|
||||
self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))
|
||||
self.ema = ModelEMA(self.model)
|
||||
if self.args.plots:
|
||||
@ -269,22 +305,20 @@ class BaseTrainer:
|
||||
self.accumulate = max(round(self.args.nbs / self.batch_size), 1) # accumulate loss before optimizing
|
||||
weight_decay = self.args.weight_decay * self.batch_size * self.accumulate / self.args.nbs # scale weight_decay
|
||||
iterations = math.ceil(len(self.train_loader.dataset) / max(self.batch_size, self.args.nbs)) * self.epochs
|
||||
self.optimizer = self.build_optimizer(model=self.model,
|
||||
name=self.args.optimizer,
|
||||
lr=self.args.lr0,
|
||||
momentum=self.args.momentum,
|
||||
decay=weight_decay,
|
||||
iterations=iterations)
|
||||
self.optimizer = self.build_optimizer(
|
||||
model=self.model,
|
||||
name=self.args.optimizer,
|
||||
lr=self.args.lr0,
|
||||
momentum=self.args.momentum,
|
||||
decay=weight_decay,
|
||||
iterations=iterations,
|
||||
)
|
||||
# Scheduler
|
||||
if self.args.cos_lr:
|
||||
self.lf = one_cycle(1, self.args.lrf, self.epochs) # cosine 1->hyp['lrf']
|
||||
else:
|
||||
self.lf = lambda x: (1 - x / self.epochs) * (1.0 - self.args.lrf) + self.args.lrf # linear
|
||||
self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
|
||||
self._setup_scheduler()
|
||||
self.stopper, self.stop = EarlyStopping(patience=self.args.patience), False
|
||||
self.resume_training(ckpt)
|
||||
self.scheduler.last_epoch = self.start_epoch - 1 # do not move
|
||||
self.run_callbacks('on_pretrain_routine_end')
|
||||
self.run_callbacks("on_pretrain_routine_end")
|
||||
|
||||
def _do_train(self, world_size=1):
|
||||
"""Train completed, evaluate and plot if specified by arguments."""
|
||||
@ -292,35 +326,33 @@ class BaseTrainer:
|
||||
self._setup_ddp(world_size)
|
||||
self._setup_train(world_size)
|
||||
|
||||
self.epoch_time = None
|
||||
self.epoch_time_start = time.time()
|
||||
self.train_time_start = time.time()
|
||||
nb = len(self.train_loader) # number of batches
|
||||
nw = max(round(self.args.warmup_epochs * nb), 100) if self.args.warmup_epochs > 0 else -1 # warmup iterations
|
||||
last_opt_step = -1
|
||||
self.run_callbacks('on_train_start')
|
||||
LOGGER.info(f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n'
|
||||
f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n'
|
||||
f"Logging results to {colorstr('bold', self.save_dir)}\n"
|
||||
f'Starting training for {self.epochs} epochs...')
|
||||
self.epoch_time = None
|
||||
self.epoch_time_start = time.time()
|
||||
self.train_time_start = time.time()
|
||||
self.run_callbacks("on_train_start")
|
||||
LOGGER.info(
|
||||
f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n'
|
||||
f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n'
|
||||
f"Logging results to {colorstr('bold', self.save_dir)}\n"
|
||||
f'Starting training for ' + (f"{self.args.time} hours..." if self.args.time else f"{self.epochs} epochs...")
|
||||
)
|
||||
if self.args.close_mosaic:
|
||||
base_idx = (self.epochs - self.args.close_mosaic) * nb
|
||||
self.plot_idx.extend([base_idx, base_idx + 1, base_idx + 2])
|
||||
epoch = self.epochs # predefine for resume fully trained model edge cases
|
||||
for epoch in range(self.start_epoch, self.epochs):
|
||||
epoch = self.start_epoch
|
||||
while True:
|
||||
self.epoch = epoch
|
||||
self.run_callbacks('on_train_epoch_start')
|
||||
self.run_callbacks("on_train_epoch_start")
|
||||
self.model.train()
|
||||
if RANK != -1:
|
||||
self.train_loader.sampler.set_epoch(epoch)
|
||||
pbar = enumerate(self.train_loader)
|
||||
# Update dataloader attributes (optional)
|
||||
if epoch == (self.epochs - self.args.close_mosaic):
|
||||
LOGGER.info('Closing dataloader mosaic')
|
||||
if hasattr(self.train_loader.dataset, 'mosaic'):
|
||||
self.train_loader.dataset.mosaic = False
|
||||
if hasattr(self.train_loader.dataset, 'close_mosaic'):
|
||||
self.train_loader.dataset.close_mosaic(hyp=self.args)
|
||||
self._close_dataloader_mosaic()
|
||||
self.train_loader.reset()
|
||||
|
||||
if RANK in (-1, 0):
|
||||
@ -329,18 +361,19 @@ class BaseTrainer:
|
||||
self.tloss = None
|
||||
self.optimizer.zero_grad()
|
||||
for i, batch in pbar:
|
||||
self.run_callbacks('on_train_batch_start')
|
||||
self.run_callbacks("on_train_batch_start")
|
||||
# Warmup
|
||||
ni = i + nb * epoch
|
||||
if ni <= nw:
|
||||
xi = [0, nw] # x interp
|
||||
self.accumulate = max(1, np.interp(ni, xi, [1, self.args.nbs / self.batch_size]).round())
|
||||
self.accumulate = max(1, int(np.interp(ni, xi, [1, self.args.nbs / self.batch_size]).round()))
|
||||
for j, x in enumerate(self.optimizer.param_groups):
|
||||
# Bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
|
||||
x['lr'] = np.interp(
|
||||
ni, xi, [self.args.warmup_bias_lr if j == 0 else 0.0, x['initial_lr'] * self.lf(epoch)])
|
||||
if 'momentum' in x:
|
||||
x['momentum'] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum])
|
||||
x["lr"] = np.interp(
|
||||
ni, xi, [self.args.warmup_bias_lr if j == 0 else 0.0, x["initial_lr"] * self.lf(epoch)]
|
||||
)
|
||||
if "momentum" in x:
|
||||
x["momentum"] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum])
|
||||
|
||||
# Forward
|
||||
with torch.cuda.amp.autocast(self.amp):
|
||||
@ -348,8 +381,9 @@ class BaseTrainer:
|
||||
self.loss, self.loss_items = self.model(batch)
|
||||
if RANK != -1:
|
||||
self.loss *= world_size
|
||||
self.tloss = (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None \
|
||||
else self.loss_items
|
||||
self.tloss = (
|
||||
(self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None else self.loss_items
|
||||
)
|
||||
|
||||
# Backward
|
||||
self.scaler.scale(self.loss).backward()
|
||||
@ -359,115 +393,137 @@ class BaseTrainer:
|
||||
self.optimizer_step()
|
||||
last_opt_step = ni
|
||||
|
||||
# Timed stopping
|
||||
if self.args.time:
|
||||
self.stop = (time.time() - self.train_time_start) > (self.args.time * 3600)
|
||||
if RANK != -1: # if DDP training
|
||||
broadcast_list = [self.stop if RANK == 0 else None]
|
||||
dist.broadcast_object_list(broadcast_list, 0) # broadcast 'stop' to all ranks
|
||||
self.stop = broadcast_list[0]
|
||||
if self.stop: # training time exceeded
|
||||
break
|
||||
|
||||
# Log
|
||||
mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB)
|
||||
loss_len = self.tloss.shape[0] if len(self.tloss.size()) else 1
|
||||
mem = f"{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G" # (GB)
|
||||
loss_len = self.tloss.shape[0] if len(self.tloss.shape) else 1
|
||||
losses = self.tloss if loss_len > 1 else torch.unsqueeze(self.tloss, 0)
|
||||
if RANK in (-1, 0):
|
||||
pbar.set_description(
|
||||
('%11s' * 2 + '%11.4g' * (2 + loss_len)) %
|
||||
(f'{epoch + 1}/{self.epochs}', mem, *losses, batch['cls'].shape[0], batch['img'].shape[-1]))
|
||||
self.run_callbacks('on_batch_end')
|
||||
("%11s" * 2 + "%11.4g" * (2 + loss_len))
|
||||
% (f"{epoch + 1}/{self.epochs}", mem, *losses, batch["cls"].shape[0], batch["img"].shape[-1])
|
||||
)
|
||||
self.run_callbacks("on_batch_end")
|
||||
if self.args.plots and ni in self.plot_idx:
|
||||
self.plot_training_samples(batch, ni)
|
||||
|
||||
self.run_callbacks('on_train_batch_end')
|
||||
|
||||
self.lr = {f'lr/pg{ir}': x['lr'] for ir, x in enumerate(self.optimizer.param_groups)} # for loggers
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore') # suppress 'Detected lr_scheduler.step() before optimizer.step()'
|
||||
self.scheduler.step()
|
||||
self.run_callbacks('on_train_epoch_end')
|
||||
self.run_callbacks("on_train_batch_end")
|
||||
|
||||
self.lr = {f"lr/pg{ir}": x["lr"] for ir, x in enumerate(self.optimizer.param_groups)} # for loggers
|
||||
self.run_callbacks("on_train_epoch_end")
|
||||
if RANK in (-1, 0):
|
||||
final_epoch = epoch + 1 == self.epochs
|
||||
self.ema.update_attr(self.model, include=["yaml", "nc", "args", "names", "stride", "class_weights"])
|
||||
|
||||
# Validation
|
||||
self.ema.update_attr(self.model, include=['yaml', 'nc', 'args', 'names', 'stride', 'class_weights'])
|
||||
final_epoch = (epoch + 1 == self.epochs) or self.stopper.possible_stop
|
||||
|
||||
if self.args.val or final_epoch:
|
||||
if (self.args.val and (((epoch+1) % self.args.val_period == 0) or (self.epochs - epoch) <= 10)) \
|
||||
or final_epoch or self.stopper.possible_stop or self.stop:
|
||||
self.metrics, self.fitness = self.validate()
|
||||
self.save_metrics(metrics={**self.label_loss_items(self.tloss), **self.metrics, **self.lr})
|
||||
self.stop = self.stopper(epoch + 1, self.fitness)
|
||||
self.stop |= self.stopper(epoch + 1, self.fitness) or final_epoch
|
||||
if self.args.time:
|
||||
self.stop |= (time.time() - self.train_time_start) > (self.args.time * 3600)
|
||||
|
||||
# Save model
|
||||
if self.args.save or (epoch + 1 == self.epochs):
|
||||
if self.args.save or final_epoch:
|
||||
self.save_model()
|
||||
self.run_callbacks('on_model_save')
|
||||
self.run_callbacks("on_model_save")
|
||||
|
||||
tnow = time.time()
|
||||
self.epoch_time = tnow - self.epoch_time_start
|
||||
self.epoch_time_start = tnow
|
||||
self.run_callbacks('on_fit_epoch_end')
|
||||
torch.cuda.empty_cache() # clears GPU vRAM at end of epoch, can help with out of memory errors
|
||||
# Scheduler
|
||||
t = time.time()
|
||||
self.epoch_time = t - self.epoch_time_start
|
||||
self.epoch_time_start = t
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore") # suppress 'Detected lr_scheduler.step() before optimizer.step()'
|
||||
if self.args.time:
|
||||
mean_epoch_time = (t - self.train_time_start) / (epoch - self.start_epoch + 1)
|
||||
self.epochs = self.args.epochs = math.ceil(self.args.time * 3600 / mean_epoch_time)
|
||||
self._setup_scheduler()
|
||||
self.scheduler.last_epoch = self.epoch # do not move
|
||||
self.stop |= epoch >= self.epochs # stop if exceeded epochs
|
||||
self.scheduler.step()
|
||||
self.run_callbacks("on_fit_epoch_end")
|
||||
torch.cuda.empty_cache() # clear GPU memory at end of epoch, may help reduce CUDA out of memory errors
|
||||
|
||||
# Early Stopping
|
||||
if RANK != -1: # if DDP training
|
||||
broadcast_list = [self.stop if RANK == 0 else None]
|
||||
dist.broadcast_object_list(broadcast_list, 0) # broadcast 'stop' to all ranks
|
||||
if RANK != 0:
|
||||
self.stop = broadcast_list[0]
|
||||
self.stop = broadcast_list[0]
|
||||
if self.stop:
|
||||
break # must break all DDP ranks
|
||||
epoch += 1
|
||||
|
||||
if RANK in (-1, 0):
|
||||
# Do final val with best.pt
|
||||
LOGGER.info(f'\n{epoch - self.start_epoch + 1} epochs completed in '
|
||||
f'{(time.time() - self.train_time_start) / 3600:.3f} hours.')
|
||||
LOGGER.info(
|
||||
f"\n{epoch - self.start_epoch + 1} epochs completed in "
|
||||
f"{(time.time() - self.train_time_start) / 3600:.3f} hours."
|
||||
)
|
||||
self.final_eval()
|
||||
if self.args.plots:
|
||||
self.plot_metrics()
|
||||
self.run_callbacks('on_train_end')
|
||||
self.run_callbacks("on_train_end")
|
||||
torch.cuda.empty_cache()
|
||||
self.run_callbacks('teardown')
|
||||
self.run_callbacks("teardown")
|
||||
|
||||
def save_model(self):
|
||||
"""Save model checkpoints based on various conditions."""
|
||||
"""Save model training checkpoints with additional metadata."""
|
||||
import pandas as pd # scope for faster startup
|
||||
|
||||
metrics = {**self.metrics, **{"fitness": self.fitness}}
|
||||
results = {k.strip(): v for k, v in pd.read_csv(self.csv).to_dict(orient="list").items()}
|
||||
ckpt = {
|
||||
'epoch': self.epoch,
|
||||
'best_fitness': self.best_fitness,
|
||||
'model': deepcopy(de_parallel(self.model)).half(),
|
||||
'ema': deepcopy(self.ema.ema).half(),
|
||||
'updates': self.ema.updates,
|
||||
'optimizer': self.optimizer.state_dict(),
|
||||
'train_args': vars(self.args), # save as dict
|
||||
'date': datetime.now().isoformat(),
|
||||
'version': __version__}
|
||||
"epoch": self.epoch,
|
||||
"best_fitness": self.best_fitness,
|
||||
"model": deepcopy(de_parallel(self.model)).half(),
|
||||
"ema": deepcopy(self.ema.ema).half(),
|
||||
"updates": self.ema.updates,
|
||||
"optimizer": self.optimizer.state_dict(),
|
||||
"train_args": vars(self.args), # save as dict
|
||||
"train_metrics": metrics,
|
||||
"train_results": results,
|
||||
"date": datetime.now().isoformat(),
|
||||
"version": __version__,
|
||||
"license": "AGPL-3.0 (https://ultralytics.com/license)",
|
||||
"docs": "https://docs.ultralytics.com",
|
||||
}
|
||||
|
||||
# Use dill (if exists) to serialize the lambda functions where pickle does not do this
|
||||
try:
|
||||
import dill as pickle
|
||||
except ImportError:
|
||||
import pickle
|
||||
|
||||
# Save last, best and delete
|
||||
torch.save(ckpt, self.last, pickle_module=pickle)
|
||||
# Save last and best
|
||||
torch.save(ckpt, self.last)
|
||||
if self.best_fitness == self.fitness:
|
||||
torch.save(ckpt, self.best, pickle_module=pickle)
|
||||
if (self.epoch > 0) and (self.save_period > 0) and (self.epoch % self.save_period == 0):
|
||||
torch.save(ckpt, self.wdir / f'epoch{self.epoch}.pt', pickle_module=pickle)
|
||||
del ckpt
|
||||
torch.save(ckpt, self.best)
|
||||
if (self.save_period > 0) and (self.epoch > 0) and (self.epoch % self.save_period == 0):
|
||||
torch.save(ckpt, self.wdir / f"epoch{self.epoch}.pt")
|
||||
|
||||
@staticmethod
|
||||
def get_dataset(data):
|
||||
"""
|
||||
Get train, val path from data dict if it exists. Returns None if data format is not recognized.
|
||||
Get train, val path from data dict if it exists.
|
||||
|
||||
Returns None if data format is not recognized.
|
||||
"""
|
||||
return data['train'], data.get('val') or data.get('test')
|
||||
return data["train"], data.get("val") or data.get("test")
|
||||
|
||||
def setup_model(self):
|
||||
"""
|
||||
load/create/download model for any task.
|
||||
"""
|
||||
"""Load/create/download model for any task."""
|
||||
if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
|
||||
return
|
||||
|
||||
model, weights = self.model, None
|
||||
ckpt = None
|
||||
if str(model).endswith('.pt'):
|
||||
if str(model).endswith(".pt"):
|
||||
weights, ckpt = attempt_load_one_weight(model)
|
||||
cfg = ckpt['model'].yaml
|
||||
cfg = ckpt["model"].yaml
|
||||
else:
|
||||
cfg = model
|
||||
self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1) # calls Model(cfg, weights)
|
||||
@ -484,17 +540,17 @@ class BaseTrainer:
|
||||
self.ema.update(self.model)
|
||||
|
||||
def preprocess_batch(self, batch):
|
||||
"""
|
||||
Allows custom preprocessing model inputs and ground truths depending on task type.
|
||||
"""
|
||||
"""Allows custom preprocessing model inputs and ground truths depending on task type."""
|
||||
return batch
|
||||
|
||||
def validate(self):
|
||||
"""
|
||||
Runs validation on test set using self.validator. The returned dict is expected to contain "fitness" key.
|
||||
Runs validation on test set using self.validator.
|
||||
|
||||
The returned dict is expected to contain "fitness" key.
|
||||
"""
|
||||
metrics = self.validator(self)
|
||||
fitness = metrics.pop('fitness', -self.loss.detach().cpu().numpy()) # use loss as fitness measure if not found
|
||||
fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy()) # use loss as fitness measure if not found
|
||||
if not self.best_fitness or self.best_fitness < fitness:
|
||||
self.best_fitness = fitness
|
||||
return metrics, fitness
|
||||
@ -505,30 +561,28 @@ class BaseTrainer:
|
||||
|
||||
def get_validator(self):
|
||||
"""Returns a NotImplementedError when the get_validator function is called."""
|
||||
raise NotImplementedError('get_validator function not implemented in trainer')
|
||||
raise NotImplementedError("get_validator function not implemented in trainer")
|
||||
|
||||
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
|
||||
"""
|
||||
Returns dataloader derived from torch.data.Dataloader.
|
||||
"""
|
||||
raise NotImplementedError('get_dataloader function not implemented in trainer')
|
||||
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
|
||||
"""Returns dataloader derived from torch.data.Dataloader."""
|
||||
raise NotImplementedError("get_dataloader function not implemented in trainer")
|
||||
|
||||
def build_dataset(self, img_path, mode='train', batch=None):
|
||||
"""Build dataset"""
|
||||
raise NotImplementedError('build_dataset function not implemented in trainer')
|
||||
def build_dataset(self, img_path, mode="train", batch=None):
|
||||
"""Build dataset."""
|
||||
raise NotImplementedError("build_dataset function not implemented in trainer")
|
||||
|
||||
def label_loss_items(self, loss_items=None, prefix='train'):
|
||||
def label_loss_items(self, loss_items=None, prefix="train"):
|
||||
"""
|
||||
Returns a loss dict with labelled training loss items tensor
|
||||
Returns a loss dict with labelled training loss items tensor.
|
||||
|
||||
Note:
|
||||
This is not needed for classification but necessary for segmentation & detection
|
||||
"""
|
||||
# Not needed for classification but necessary for segmentation & detection
|
||||
return {'loss': loss_items} if loss_items is not None else ['loss']
|
||||
return {"loss": loss_items} if loss_items is not None else ["loss"]
|
||||
|
||||
def set_model_attributes(self):
|
||||
"""
|
||||
To set or update model parameters before training.
|
||||
"""
|
||||
self.model.names = self.data['names']
|
||||
"""To set or update model parameters before training."""
|
||||
self.model.names = self.data["names"]
|
||||
|
||||
def build_targets(self, preds, targets):
|
||||
"""Builds target tensors for training YOLO model."""
|
||||
@ -536,11 +590,11 @@ class BaseTrainer:
|
||||
|
||||
def progress_string(self):
|
||||
"""Returns a string describing training progress."""
|
||||
return ''
|
||||
return ""
|
||||
|
||||
# TODO: may need to put these following functions into callback
|
||||
def plot_training_samples(self, batch, ni):
|
||||
"""Plots training samples during YOLOv5 training."""
|
||||
"""Plots training samples during YOLO training."""
|
||||
pass
|
||||
|
||||
def plot_training_labels(self):
|
||||
@ -551,9 +605,9 @@ class BaseTrainer:
|
||||
"""Saves training metrics to a CSV file."""
|
||||
keys, vals = list(metrics.keys()), list(metrics.values())
|
||||
n = len(metrics) + 1 # number of cols
|
||||
s = '' if self.csv.exists() else (('%23s,' * n % tuple(['epoch'] + keys)).rstrip(',') + '\n') # header
|
||||
with open(self.csv, 'a') as f:
|
||||
f.write(s + ('%23.5g,' * n % tuple([self.epoch + 1] + vals)).rstrip(',') + '\n')
|
||||
s = "" if self.csv.exists() else (("%23s," * n % tuple(["epoch"] + keys)).rstrip(",") + "\n") # header
|
||||
with open(self.csv, "a") as f:
|
||||
f.write(s + ("%23.5g," * n % tuple([self.epoch + 1] + vals)).rstrip(",") + "\n")
|
||||
|
||||
def plot_metrics(self):
|
||||
"""Plot and display metrics visually."""
|
||||
@ -562,7 +616,7 @@ class BaseTrainer:
|
||||
def on_plot(self, name, data=None):
|
||||
"""Registers plots (e.g. to be consumed in callbacks)"""
|
||||
path = Path(name)
|
||||
self.plots[path] = {'data': data, 'timestamp': time.time()}
|
||||
self.plots[path] = {"data": data, "timestamp": time.time()}
|
||||
|
||||
def final_eval(self):
|
||||
"""Performs final evaluation and validation for object detection YOLO model."""
|
||||
@ -570,11 +624,11 @@ class BaseTrainer:
|
||||
if f.exists():
|
||||
strip_optimizer(f) # strip optimizers
|
||||
if f is self.best:
|
||||
LOGGER.info(f'\nValidating {f}...')
|
||||
LOGGER.info(f"\nValidating {f}...")
|
||||
self.validator.args.plots = self.args.plots
|
||||
self.metrics = self.validator(model=f)
|
||||
self.metrics.pop('fitness', None)
|
||||
self.run_callbacks('on_fit_epoch_end')
|
||||
self.metrics.pop("fitness", None)
|
||||
self.run_callbacks("on_fit_epoch_end")
|
||||
|
||||
def check_resume(self, overrides):
|
||||
"""Check if resume checkpoint exists and update arguments accordingly."""
|
||||
@ -586,56 +640,62 @@ class BaseTrainer:
|
||||
|
||||
# Check that resume data YAML exists, otherwise strip to force re-download of dataset
|
||||
ckpt_args = attempt_load_weights(last).args
|
||||
if not Path(ckpt_args['data']).exists():
|
||||
ckpt_args['data'] = self.args.data
|
||||
if not Path(ckpt_args["data"]).exists():
|
||||
ckpt_args["data"] = self.args.data
|
||||
|
||||
resume = True
|
||||
self.args = get_cfg(ckpt_args)
|
||||
self.args.model = str(last) # reinstate model
|
||||
for k in 'imgsz', 'batch': # allow arg updates to reduce memory on resume if crashed due to CUDA OOM
|
||||
self.args.model = self.args.resume = str(last) # reinstate model
|
||||
for k in "imgsz", "batch", "device": # allow arg updates to reduce memory or update device on resume
|
||||
if k in overrides:
|
||||
setattr(self.args, k, overrides[k])
|
||||
|
||||
except Exception as e:
|
||||
raise FileNotFoundError('Resume checkpoint not found. Please pass a valid checkpoint to resume from, '
|
||||
"i.e. 'yolo train resume model=path/to/last.pt'") from e
|
||||
raise FileNotFoundError(
|
||||
"Resume checkpoint not found. Please pass a valid checkpoint to resume from, "
|
||||
"i.e. 'yolo train resume model=path/to/last.pt'"
|
||||
) from e
|
||||
self.resume = resume
|
||||
|
||||
def resume_training(self, ckpt):
|
||||
"""Resume YOLO training from given epoch and best fitness."""
|
||||
if ckpt is None:
|
||||
if ckpt is None or not self.resume:
|
||||
return
|
||||
best_fitness = 0.0
|
||||
start_epoch = ckpt['epoch'] + 1
|
||||
if ckpt['optimizer'] is not None:
|
||||
self.optimizer.load_state_dict(ckpt['optimizer']) # optimizer
|
||||
best_fitness = ckpt['best_fitness']
|
||||
if self.ema and ckpt.get('ema'):
|
||||
self.ema.ema.load_state_dict(ckpt['ema'].float().state_dict()) # EMA
|
||||
self.ema.updates = ckpt['updates']
|
||||
if self.resume:
|
||||
assert start_epoch > 0, \
|
||||
f'{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n' \
|
||||
f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'"
|
||||
LOGGER.info(
|
||||
f'Resuming training from {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs')
|
||||
start_epoch = ckpt["epoch"] + 1
|
||||
if ckpt["optimizer"] is not None:
|
||||
self.optimizer.load_state_dict(ckpt["optimizer"]) # optimizer
|
||||
best_fitness = ckpt["best_fitness"]
|
||||
if self.ema and ckpt.get("ema"):
|
||||
self.ema.ema.load_state_dict(ckpt["ema"].float().state_dict()) # EMA
|
||||
self.ema.updates = ckpt["updates"]
|
||||
assert start_epoch > 0, (
|
||||
f"{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n"
|
||||
f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'"
|
||||
)
|
||||
LOGGER.info(f"Resuming training {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs")
|
||||
if self.epochs < start_epoch:
|
||||
LOGGER.info(
|
||||
f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs.")
|
||||
self.epochs += ckpt['epoch'] # finetune additional epochs
|
||||
f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs."
|
||||
)
|
||||
self.epochs += ckpt["epoch"] # finetune additional epochs
|
||||
self.best_fitness = best_fitness
|
||||
self.start_epoch = start_epoch
|
||||
if start_epoch > (self.epochs - self.args.close_mosaic):
|
||||
LOGGER.info('Closing dataloader mosaic')
|
||||
if hasattr(self.train_loader.dataset, 'mosaic'):
|
||||
self.train_loader.dataset.mosaic = False
|
||||
if hasattr(self.train_loader.dataset, 'close_mosaic'):
|
||||
self.train_loader.dataset.close_mosaic(hyp=self.args)
|
||||
self._close_dataloader_mosaic()
|
||||
|
||||
def build_optimizer(self, model, name='auto', lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
|
||||
def _close_dataloader_mosaic(self):
|
||||
"""Update dataloaders to stop using mosaic augmentation."""
|
||||
if hasattr(self.train_loader.dataset, "mosaic"):
|
||||
self.train_loader.dataset.mosaic = False
|
||||
if hasattr(self.train_loader.dataset, "close_mosaic"):
|
||||
LOGGER.info("Closing dataloader mosaic")
|
||||
self.train_loader.dataset.close_mosaic(hyp=self.args)
|
||||
|
||||
def build_optimizer(self, model, name="auto", lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
|
||||
"""
|
||||
Constructs an optimizer for the given model, based on the specified optimizer name, learning rate,
|
||||
momentum, weight decay, and number of iterations.
|
||||
Constructs an optimizer for the given model, based on the specified optimizer name, learning rate, momentum,
|
||||
weight decay, and number of iterations.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): The model for which to build an optimizer.
|
||||
@ -652,38 +712,45 @@ class BaseTrainer:
|
||||
"""
|
||||
|
||||
g = [], [], [] # optimizer parameter groups
|
||||
bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k) # normalization layers, i.e. BatchNorm2d()
|
||||
if name == 'auto':
|
||||
nc = getattr(model, 'nc', 10) # number of classes
|
||||
bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k) # normalization layers, i.e. BatchNorm2d()
|
||||
if name == "auto":
|
||||
LOGGER.info(
|
||||
f"{colorstr('optimizer:')} 'optimizer=auto' found, "
|
||||
f"ignoring 'lr0={self.args.lr0}' and 'momentum={self.args.momentum}' and "
|
||||
f"determining best 'optimizer', 'lr0' and 'momentum' automatically... "
|
||||
)
|
||||
nc = getattr(model, "nc", 10) # number of classes
|
||||
lr_fit = round(0.002 * 5 / (4 + nc), 6) # lr0 fit equation to 6 decimal places
|
||||
name, lr, momentum = ('SGD', 0.01, 0.9) if iterations > 10000 else ('AdamW', lr_fit, 0.9)
|
||||
name, lr, momentum = ("SGD", 0.01, 0.9) if iterations > 10000 else ("AdamW", lr_fit, 0.9)
|
||||
self.args.warmup_bias_lr = 0.0 # no higher than 0.01 for Adam
|
||||
|
||||
for module_name, module in model.named_modules():
|
||||
for param_name, param in module.named_parameters(recurse=False):
|
||||
fullname = f'{module_name}.{param_name}' if module_name else param_name
|
||||
if 'bias' in fullname: # bias (no decay)
|
||||
fullname = f"{module_name}.{param_name}" if module_name else param_name
|
||||
if "bias" in fullname: # bias (no decay)
|
||||
g[2].append(param)
|
||||
elif isinstance(module, bn): # weight (no decay)
|
||||
g[1].append(param)
|
||||
else: # weight (with decay)
|
||||
g[0].append(param)
|
||||
|
||||
if name in ('Adam', 'Adamax', 'AdamW', 'NAdam', 'RAdam'):
|
||||
if name in ("Adam", "Adamax", "AdamW", "NAdam", "RAdam"):
|
||||
optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
|
||||
elif name == 'RMSProp':
|
||||
elif name == "RMSProp":
|
||||
optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum)
|
||||
elif name == 'SGD':
|
||||
elif name == "SGD":
|
||||
optimizer = optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Optimizer '{name}' not found in list of available optimizers "
|
||||
f'[Adam, AdamW, NAdam, RAdam, RMSProp, SGD, auto].'
|
||||
'To request support for addition optimizers please visit https://github.com/ultralytics/ultralytics.')
|
||||
f"[Adam, AdamW, NAdam, RAdam, RMSProp, SGD, auto]."
|
||||
"To request support for addition optimizers please visit https://github.com/ultralytics/ultralytics."
|
||||
)
|
||||
|
||||
optimizer.add_param_group({'params': g[0], 'weight_decay': decay}) # add g0 with weight_decay
|
||||
optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0}) # add g1 (BatchNorm2d weights)
|
||||
optimizer.add_param_group({"params": g[0], "weight_decay": decay}) # add g0 with weight_decay
|
||||
optimizer.add_param_group({"params": g[1], "weight_decay": 0.0}) # add g1 (BatchNorm2d weights)
|
||||
LOGGER.info(
|
||||
f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}, momentum={momentum}) with parameter groups "
|
||||
f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)')
|
||||
f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)'
|
||||
)
|
||||
return optimizer
|
||||
|
@ -13,48 +13,59 @@ Example:
|
||||
from ultralytics import YOLO
|
||||
|
||||
model = YOLO('yolov8n.pt')
|
||||
model.tune(data='coco8.yaml', imgsz=640, epochs=100, iterations=10)
|
||||
model.tune(data='coco8.yaml', epochs=10, iterations=300, optimizer='AdamW', plots=False, save=False, val=False)
|
||||
```
|
||||
"""
|
||||
|
||||
import random
|
||||
import shutil
|
||||
import subprocess
|
||||
import time
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ultralytics import YOLO
|
||||
from ultralytics.cfg import get_cfg, get_save_dir
|
||||
from ultralytics.utils import DEFAULT_CFG, LOGGER, callbacks, colorstr, yaml_print, yaml_save
|
||||
from ultralytics.utils import DEFAULT_CFG, LOGGER, callbacks, colorstr, remove_colorstr, yaml_print, yaml_save
|
||||
from ultralytics.utils.plotting import plot_tune_results
|
||||
|
||||
|
||||
class Tuner:
|
||||
"""
|
||||
Class responsible for hyperparameter tuning of YOLO models.
|
||||
Class responsible for hyperparameter tuning of YOLO models.
|
||||
|
||||
The class evolves YOLO model hyperparameters over a given number of iterations
|
||||
by mutating them according to the search space and retraining the model to evaluate their performance.
|
||||
The class evolves YOLO model hyperparameters over a given number of iterations
|
||||
by mutating them according to the search space and retraining the model to evaluate their performance.
|
||||
|
||||
Attributes:
|
||||
space (dict): Hyperparameter search space containing bounds and scaling factors for mutation.
|
||||
tune_dir (Path): Directory where evolution logs and results will be saved.
|
||||
evolve_csv (Path): Path to the CSV file where evolution logs are saved.
|
||||
Attributes:
|
||||
space (dict): Hyperparameter search space containing bounds and scaling factors for mutation.
|
||||
tune_dir (Path): Directory where evolution logs and results will be saved.
|
||||
tune_csv (Path): Path to the CSV file where evolution logs are saved.
|
||||
|
||||
Methods:
|
||||
_mutate(hyp: dict) -> dict:
|
||||
Mutates the given hyperparameters within the bounds specified in `self.space`.
|
||||
Methods:
|
||||
_mutate(hyp: dict) -> dict:
|
||||
Mutates the given hyperparameters within the bounds specified in `self.space`.
|
||||
|
||||
__call__():
|
||||
Executes the hyperparameter evolution across multiple iterations.
|
||||
__call__():
|
||||
Executes the hyperparameter evolution across multiple iterations.
|
||||
|
||||
Example:
|
||||
Tune hyperparameters for YOLOv8n on COCO8 at imgsz=640 and epochs=30 for 300 tuning iterations.
|
||||
```python
|
||||
from ultralytics import YOLO
|
||||
Example:
|
||||
Tune hyperparameters for YOLOv8n on COCO8 at imgsz=640 and epochs=30 for 300 tuning iterations.
|
||||
```python
|
||||
from ultralytics import YOLO
|
||||
|
||||
model = YOLO('yolov8n.pt')
|
||||
model.tune(data='coco8.yaml', imgsz=640, epochs=100, iterations=10, val=False, cache=True)
|
||||
```
|
||||
"""
|
||||
model = YOLO('yolov8n.pt')
|
||||
model.tune(data='coco8.yaml', epochs=10, iterations=300, optimizer='AdamW', plots=False, save=False, val=False)
|
||||
```
|
||||
|
||||
Tune with custom search space.
|
||||
```python
|
||||
from ultralytics import YOLO
|
||||
|
||||
model = YOLO('yolov8n.pt')
|
||||
model.tune(space={key1: val1, key2: val2}) # custom search space dictionary
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, args=DEFAULT_CFG, _callbacks=None):
|
||||
"""
|
||||
@ -63,37 +74,44 @@ class Tuner:
|
||||
Args:
|
||||
args (dict, optional): Configuration for hyperparameter evolution.
|
||||
"""
|
||||
self.args = get_cfg(overrides=args)
|
||||
self.space = { # key: (min, max, gain(optionaL))
|
||||
self.space = args.pop("space", None) or { # key: (min, max, gain(optional))
|
||||
# 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),
|
||||
'lr0': (1e-5, 1e-1),
|
||||
'lrf': (0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
|
||||
'momentum': (0.6, 0.98, 0.3), # SGD momentum/Adam beta1
|
||||
'weight_decay': (0.0, 0.001), # optimizer weight decay 5e-4
|
||||
'warmup_epochs': (0.0, 5.0), # warmup epochs (fractions ok)
|
||||
'warmup_momentum': (0.0, 0.95), # warmup initial momentum
|
||||
'box': (0.02, 0.2), # box loss gain
|
||||
'cls': (0.2, 4.0), # cls loss gain (scale with pixels)
|
||||
'hsv_h': (0.0, 0.1), # image HSV-Hue augmentation (fraction)
|
||||
'hsv_s': (0.0, 0.9), # image HSV-Saturation augmentation (fraction)
|
||||
'hsv_v': (0.0, 0.9), # image HSV-Value augmentation (fraction)
|
||||
'degrees': (0.0, 45.0), # image rotation (+/- deg)
|
||||
'translate': (0.0, 0.9), # image translation (+/- fraction)
|
||||
'scale': (0.0, 0.9), # image scale (+/- gain)
|
||||
'shear': (0.0, 10.0), # image shear (+/- deg)
|
||||
'perspective': (0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
|
||||
'flipud': (0.0, 1.0), # image flip up-down (probability)
|
||||
'fliplr': (0.0, 1.0), # image flip left-right (probability)
|
||||
'mosaic': (0.0, 1.0), # image mixup (probability)
|
||||
'mixup': (0.0, 1.0), # image mixup (probability)
|
||||
'copy_paste': (0.0, 1.0)} # segment copy-paste (probability)
|
||||
self.tune_dir = get_save_dir(self.args, name='_tune')
|
||||
self.evolve_csv = self.tune_dir / 'evolve.csv'
|
||||
"lr0": (1e-5, 1e-1), # initial learning rate (i.e. SGD=1E-2, Adam=1E-3)
|
||||
"lrf": (0.0001, 0.1), # final OneCycleLR learning rate (lr0 * lrf)
|
||||
"momentum": (0.7, 0.98, 0.3), # SGD momentum/Adam beta1
|
||||
"weight_decay": (0.0, 0.001), # optimizer weight decay 5e-4
|
||||
"warmup_epochs": (0.0, 5.0), # warmup epochs (fractions ok)
|
||||
"warmup_momentum": (0.0, 0.95), # warmup initial momentum
|
||||
"box": (1.0, 20.0), # box loss gain
|
||||
"cls": (0.2, 4.0), # cls loss gain (scale with pixels)
|
||||
"dfl": (0.4, 6.0), # dfl loss gain
|
||||
"hsv_h": (0.0, 0.1), # image HSV-Hue augmentation (fraction)
|
||||
"hsv_s": (0.0, 0.9), # image HSV-Saturation augmentation (fraction)
|
||||
"hsv_v": (0.0, 0.9), # image HSV-Value augmentation (fraction)
|
||||
"degrees": (0.0, 45.0), # image rotation (+/- deg)
|
||||
"translate": (0.0, 0.9), # image translation (+/- fraction)
|
||||
"scale": (0.0, 0.95), # image scale (+/- gain)
|
||||
"shear": (0.0, 10.0), # image shear (+/- deg)
|
||||
"perspective": (0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
|
||||
"flipud": (0.0, 1.0), # image flip up-down (probability)
|
||||
"fliplr": (0.0, 1.0), # image flip left-right (probability)
|
||||
"bgr": (0.0, 1.0), # image channel bgr (probability)
|
||||
"mosaic": (0.0, 1.0), # image mixup (probability)
|
||||
"mixup": (0.0, 1.0), # image mixup (probability)
|
||||
"copy_paste": (0.0, 1.0), # segment copy-paste (probability)
|
||||
}
|
||||
self.args = get_cfg(overrides=args)
|
||||
self.tune_dir = get_save_dir(self.args, name="tune")
|
||||
self.tune_csv = self.tune_dir / "tune_results.csv"
|
||||
self.callbacks = _callbacks or callbacks.get_default_callbacks()
|
||||
self.prefix = colorstr("Tuner: ")
|
||||
callbacks.add_integration_callbacks(self)
|
||||
LOGGER.info(f"Initialized Tuner instance with 'tune_dir={self.tune_dir}'.")
|
||||
LOGGER.info(
|
||||
f"{self.prefix}Initialized Tuner instance with 'tune_dir={self.tune_dir}'\n"
|
||||
f"{self.prefix}💡 Learn about tuning at https://docs.ultralytics.com/guides/hyperparameter-tuning"
|
||||
)
|
||||
|
||||
def _mutate(self, parent='single', n=5, mutation=0.8, sigma=0.2):
|
||||
def _mutate(self, parent="single", n=5, mutation=0.8, sigma=0.2):
|
||||
"""
|
||||
Mutates the hyperparameters based on bounds and scaling factors specified in `self.space`.
|
||||
|
||||
@ -106,17 +124,17 @@ class Tuner:
|
||||
Returns:
|
||||
(dict): A dictionary containing mutated hyperparameters.
|
||||
"""
|
||||
if self.evolve_csv.exists(): # if evolve.csv exists: select best hyps and mutate
|
||||
if self.tune_csv.exists(): # if CSV file exists: select best hyps and mutate
|
||||
# Select parent(s)
|
||||
x = np.loadtxt(self.evolve_csv, ndmin=2, delimiter=',', skiprows=1)
|
||||
x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=",", skiprows=1)
|
||||
fitness = x[:, 0] # first column
|
||||
n = min(n, len(x)) # number of previous results to consider
|
||||
x = x[np.argsort(-fitness)][:n] # top n mutations
|
||||
w = x[:, 0] - x[:, 0].min() + 1E-6 # weights (sum > 0)
|
||||
if parent == 'single' or len(x) == 1:
|
||||
w = x[:, 0] - x[:, 0].min() + 1e-6 # weights (sum > 0)
|
||||
if parent == "single" or len(x) == 1:
|
||||
# x = x[random.randint(0, n - 1)] # random selection
|
||||
x = x[random.choices(range(n), weights=w)[0]] # weighted selection
|
||||
elif parent == 'weighted':
|
||||
elif parent == "weighted":
|
||||
x = (x * w.reshape(n, 1)).sum(0) / w.sum() # weighted combination
|
||||
|
||||
# Mutate
|
||||
@ -139,7 +157,7 @@ class Tuner:
|
||||
|
||||
return hyp
|
||||
|
||||
def __call__(self, model=None, iterations=10, prefix=colorstr('Tuner:')):
|
||||
def __call__(self, model=None, iterations=10, cleanup=True):
|
||||
"""
|
||||
Executes the hyperparameter evolution process when the Tuner instance is called.
|
||||
|
||||
@ -152,54 +170,73 @@ class Tuner:
|
||||
Args:
|
||||
model (Model): A pre-initialized YOLO model to be used for training.
|
||||
iterations (int): The number of generations to run the evolution for.
|
||||
cleanup (bool): Whether to delete iteration weights to reduce storage space used during tuning.
|
||||
|
||||
Note:
|
||||
The method utilizes the `self.evolve_csv` Path object to read and log hyperparameters and fitness scores.
|
||||
The method utilizes the `self.tune_csv` Path object to read and log hyperparameters and fitness scores.
|
||||
Ensure this path is set correctly in the Tuner instance.
|
||||
"""
|
||||
|
||||
t0 = time.time()
|
||||
best_save_dir, best_metrics = None, None
|
||||
self.tune_dir.mkdir(parents=True, exist_ok=True)
|
||||
(self.tune_dir / "weights").mkdir(parents=True, exist_ok=True)
|
||||
for i in range(iterations):
|
||||
# Mutate hyperparameters
|
||||
mutated_hyp = self._mutate()
|
||||
LOGGER.info(f'{prefix} Starting iteration {i + 1}/{iterations} with hyperparameters: {mutated_hyp}')
|
||||
LOGGER.info(f"{self.prefix}Starting iteration {i + 1}/{iterations} with hyperparameters: {mutated_hyp}")
|
||||
|
||||
metrics = {}
|
||||
train_args = {**vars(self.args), **mutated_hyp}
|
||||
save_dir = get_save_dir(get_cfg(train_args))
|
||||
weights_dir = save_dir / "weights"
|
||||
try:
|
||||
# Train YOLO model with mutated hyperparameters
|
||||
train_args = {**vars(self.args), **mutated_hyp}
|
||||
results = (deepcopy(model) or YOLO(self.args.model)).train(**train_args)
|
||||
fitness = results.fitness
|
||||
# Train YOLO model with mutated hyperparameters (run in subprocess to avoid dataloader hang)
|
||||
cmd = ["yolo", "train", *(f"{k}={v}" for k, v in train_args.items())]
|
||||
return_code = subprocess.run(cmd, check=True).returncode
|
||||
ckpt_file = weights_dir / ("best.pt" if (weights_dir / "best.pt").exists() else "last.pt")
|
||||
metrics = torch.load(ckpt_file)["train_metrics"]
|
||||
assert return_code == 0, "training failed"
|
||||
|
||||
except Exception as e:
|
||||
LOGGER.warning(f'WARNING ❌️ training failure for hyperparameter tuning iteration {i}\n{e}')
|
||||
fitness = 0.0
|
||||
LOGGER.warning(f"WARNING ❌️ training failure for hyperparameter tuning iteration {i + 1}\n{e}")
|
||||
|
||||
# Save results and mutated_hyp to evolve_csv
|
||||
# Save results and mutated_hyp to CSV
|
||||
fitness = metrics.get("fitness", 0.0)
|
||||
log_row = [round(fitness, 5)] + [mutated_hyp[k] for k in self.space.keys()]
|
||||
headers = '' if self.evolve_csv.exists() else (','.join(['fitness_score'] + list(self.space.keys())) + '\n')
|
||||
with open(self.evolve_csv, 'a') as f:
|
||||
f.write(headers + ','.join(map(str, log_row)) + '\n')
|
||||
headers = "" if self.tune_csv.exists() else (",".join(["fitness"] + list(self.space.keys())) + "\n")
|
||||
with open(self.tune_csv, "a") as f:
|
||||
f.write(headers + ",".join(map(str, log_row)) + "\n")
|
||||
|
||||
# Print tuning results
|
||||
x = np.loadtxt(self.evolve_csv, ndmin=2, delimiter=',', skiprows=1)
|
||||
# Get best results
|
||||
x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=",", skiprows=1)
|
||||
fitness = x[:, 0] # first column
|
||||
best_idx = fitness.argmax()
|
||||
best_is_current = best_idx == i
|
||||
if best_is_current:
|
||||
best_save_dir = results.save_dir
|
||||
best_metrics = {k: round(v, 5) for k, v in results.results_dict.items()}
|
||||
header = (f'{prefix} {i + 1} iterations complete ✅ ({time.time() - t0:.2f}s)\n'
|
||||
f'{prefix} Results saved to {colorstr("bold", self.tune_dir)}\n'
|
||||
f'{prefix} Best fitness={fitness[best_idx]} observed at iteration {best_idx + 1}\n'
|
||||
f'{prefix} Best fitness metrics are {best_metrics}\n'
|
||||
f'{prefix} Best fitness model is {best_save_dir}\n'
|
||||
f'{prefix} Best fitness hyperparameters are printed below.\n')
|
||||
best_save_dir = save_dir
|
||||
best_metrics = {k: round(v, 5) for k, v in metrics.items()}
|
||||
for ckpt in weights_dir.glob("*.pt"):
|
||||
shutil.copy2(ckpt, self.tune_dir / "weights")
|
||||
elif cleanup:
|
||||
shutil.rmtree(ckpt_file.parent) # remove iteration weights/ dir to reduce storage space
|
||||
|
||||
LOGGER.info('\n' + header)
|
||||
# Plot tune results
|
||||
plot_tune_results(self.tune_csv)
|
||||
|
||||
# Save turning results
|
||||
data = {k: float(x[0, i + 1]) for i, k in enumerate(self.space.keys())}
|
||||
header = header.replace(prefix, '#').replace('[1m/', '').replace('[0m', '') + '\n'
|
||||
yaml_save(self.tune_dir / 'best.yaml', data=data, header=header)
|
||||
yaml_print(self.tune_dir / 'best.yaml')
|
||||
# Save and print tune results
|
||||
header = (
|
||||
f'{self.prefix}{i + 1}/{iterations} iterations complete ✅ ({time.time() - t0:.2f}s)\n'
|
||||
f'{self.prefix}Results saved to {colorstr("bold", self.tune_dir)}\n'
|
||||
f'{self.prefix}Best fitness={fitness[best_idx]} observed at iteration {best_idx + 1}\n'
|
||||
f'{self.prefix}Best fitness metrics are {best_metrics}\n'
|
||||
f'{self.prefix}Best fitness model is {best_save_dir}\n'
|
||||
f'{self.prefix}Best fitness hyperparameters are printed below.\n'
|
||||
)
|
||||
LOGGER.info("\n" + header)
|
||||
data = {k: float(x[best_idx, i + 1]) for i, k in enumerate(self.space.keys())}
|
||||
yaml_save(
|
||||
self.tune_dir / "best_hyperparameters.yaml",
|
||||
data=data,
|
||||
header=remove_colorstr(header.replace(self.prefix, "# ")) + "\n",
|
||||
)
|
||||
yaml_print(self.tune_dir / "best_hyperparameters.yaml")
|
||||
|
@ -17,7 +17,9 @@ Usage - formats:
|
||||
yolov8n.tflite # TensorFlow Lite
|
||||
yolov8n_edgetpu.tflite # TensorFlow Edge TPU
|
||||
yolov8n_paddle_model # PaddlePaddle
|
||||
yolov8n_ncnn_model # NCNN
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
@ -36,7 +38,7 @@ from ultralytics.utils.torch_utils import de_parallel, select_device, smart_infe
|
||||
|
||||
class BaseValidator:
|
||||
"""
|
||||
BaseValidator
|
||||
BaseValidator.
|
||||
|
||||
A base class for creating validators.
|
||||
|
||||
@ -77,7 +79,7 @@ class BaseValidator:
|
||||
self.args = get_cfg(overrides=args)
|
||||
self.dataloader = dataloader
|
||||
self.pbar = pbar
|
||||
self.model = None
|
||||
self.stride = None
|
||||
self.data = None
|
||||
self.device = None
|
||||
self.batch_i = None
|
||||
@ -89,20 +91,20 @@ class BaseValidator:
|
||||
self.nc = None
|
||||
self.iouv = None
|
||||
self.jdict = None
|
||||
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
|
||||
self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
|
||||
|
||||
self.save_dir = save_dir or get_save_dir(self.args)
|
||||
(self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
|
||||
(self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
|
||||
if self.args.conf is None:
|
||||
self.args.conf = 0.001 # default conf=0.001
|
||||
self.args.imgsz = check_imgsz(self.args.imgsz, max_dim=1)
|
||||
|
||||
self.plots = {}
|
||||
self.callbacks = _callbacks or callbacks.get_default_callbacks()
|
||||
|
||||
@smart_inference_mode()
|
||||
def __call__(self, trainer=None, model=None):
|
||||
"""
|
||||
Supports validation of a pre-trained model if passed or a model being trained if trainer is passed (trainer
|
||||
"""Supports validation of a pre-trained model if passed or a model being trained if trainer is passed (trainer
|
||||
gets priority).
|
||||
"""
|
||||
self.training = trainer is not None
|
||||
@ -110,7 +112,7 @@ class BaseValidator:
|
||||
if self.training:
|
||||
self.device = trainer.device
|
||||
self.data = trainer.data
|
||||
self.args.half = self.device.type != 'cpu' # force FP16 val during training
|
||||
# self.args.half = self.device.type != "cpu" # force FP16 val during training
|
||||
model = trainer.ema.ema or trainer.model
|
||||
model = model.half() if self.args.half else model.float()
|
||||
# self.model = model
|
||||
@ -119,12 +121,13 @@ class BaseValidator:
|
||||
model.eval()
|
||||
else:
|
||||
callbacks.add_integration_callbacks(self)
|
||||
self.run_callbacks('on_val_start')
|
||||
model = AutoBackend(model or self.args.model,
|
||||
device=select_device(self.args.device, self.args.batch),
|
||||
dnn=self.args.dnn,
|
||||
data=self.args.data,
|
||||
fp16=self.args.half)
|
||||
model = AutoBackend(
|
||||
weights=model or self.args.model,
|
||||
device=select_device(self.args.device, self.args.batch),
|
||||
dnn=self.args.dnn,
|
||||
data=self.args.data,
|
||||
fp16=self.args.half,
|
||||
)
|
||||
# self.model = model
|
||||
self.device = model.device # update device
|
||||
self.args.half = model.fp16 # update half
|
||||
@ -134,30 +137,37 @@ class BaseValidator:
|
||||
self.args.batch = model.batch_size
|
||||
elif not pt and not jit:
|
||||
self.args.batch = 1 # export.py models default to batch-size 1
|
||||
LOGGER.info(f'Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')
|
||||
LOGGER.info(f"Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models")
|
||||
|
||||
if isinstance(self.args.data, str) and self.args.data.split('.')[-1] in ('yaml', 'yml'):
|
||||
if str(self.args.data).split(".")[-1] in ("yaml", "yml"):
|
||||
self.data = check_det_dataset(self.args.data)
|
||||
elif self.args.task == 'classify':
|
||||
elif self.args.task == "classify":
|
||||
self.data = check_cls_dataset(self.args.data, split=self.args.split)
|
||||
else:
|
||||
raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' for task={self.args.task} not found ❌"))
|
||||
|
||||
if self.device.type in ('cpu', 'mps'):
|
||||
if self.device.type in ("cpu", "mps"):
|
||||
self.args.workers = 0 # faster CPU val as time dominated by inference, not dataloading
|
||||
if not pt:
|
||||
self.args.rect = False
|
||||
self.stride = model.stride # used in get_dataloader() for padding
|
||||
self.dataloader = self.dataloader or self.get_dataloader(self.data.get(self.args.split), self.args.batch)
|
||||
|
||||
model.eval()
|
||||
model.warmup(imgsz=(1 if pt else self.args.batch, 3, imgsz, imgsz)) # warmup
|
||||
|
||||
dt = Profile(), Profile(), Profile(), Profile()
|
||||
self.run_callbacks("on_val_start")
|
||||
dt = (
|
||||
Profile(device=self.device),
|
||||
Profile(device=self.device),
|
||||
Profile(device=self.device),
|
||||
Profile(device=self.device),
|
||||
)
|
||||
bar = TQDM(self.dataloader, desc=self.get_desc(), total=len(self.dataloader))
|
||||
self.init_metrics(de_parallel(model))
|
||||
self.jdict = [] # empty before each val
|
||||
for batch_i, batch in enumerate(bar):
|
||||
self.run_callbacks('on_val_batch_start')
|
||||
self.run_callbacks("on_val_batch_start")
|
||||
self.batch_i = batch_i
|
||||
# Preprocess
|
||||
with dt[0]:
|
||||
@ -165,7 +175,7 @@ class BaseValidator:
|
||||
|
||||
# Inference
|
||||
with dt[1]:
|
||||
preds = model(batch['img'], augment=augment)
|
||||
preds = model(batch["img"], augment=augment)
|
||||
|
||||
# Loss
|
||||
with dt[2]:
|
||||
@ -181,23 +191,32 @@ class BaseValidator:
|
||||
self.plot_val_samples(batch, batch_i)
|
||||
self.plot_predictions(batch, preds, batch_i)
|
||||
|
||||
self.run_callbacks('on_val_batch_end')
|
||||
self.run_callbacks("on_val_batch_end")
|
||||
stats = self.get_stats()
|
||||
self.check_stats(stats)
|
||||
self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1E3 for x in dt)))
|
||||
self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1e3 for x in dt)))
|
||||
self.finalize_metrics()
|
||||
self.print_results()
|
||||
self.run_callbacks('on_val_end')
|
||||
if not (self.args.save_json and self.is_coco and len(self.jdict)):
|
||||
self.print_results()
|
||||
self.run_callbacks("on_val_end")
|
||||
if self.training:
|
||||
model.float()
|
||||
results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix='val')}
|
||||
if self.args.save_json and self.jdict:
|
||||
with open(str(self.save_dir / "predictions.json"), "w") as f:
|
||||
LOGGER.info(f"Saving {f.name}...")
|
||||
json.dump(self.jdict, f) # flatten and save
|
||||
stats = self.eval_json(stats) # update stats
|
||||
stats['fitness'] = stats['metrics/mAP50-95(B)']
|
||||
results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix="val")}
|
||||
return {k: round(float(v), 5) for k, v in results.items()} # return results as 5 decimal place floats
|
||||
else:
|
||||
LOGGER.info('Speed: %.1fms preprocess, %.1fms inference, %.1fms loss, %.1fms postprocess per image' %
|
||||
tuple(self.speed.values()))
|
||||
LOGGER.info(
|
||||
"Speed: %.1fms preprocess, %.1fms inference, %.1fms loss, %.1fms postprocess per image"
|
||||
% tuple(self.speed.values())
|
||||
)
|
||||
if self.args.save_json and self.jdict:
|
||||
with open(str(self.save_dir / 'predictions.json'), 'w') as f:
|
||||
LOGGER.info(f'Saving {f.name}...')
|
||||
with open(str(self.save_dir / "predictions.json"), "w") as f:
|
||||
LOGGER.info(f"Saving {f.name}...")
|
||||
json.dump(self.jdict, f) # flatten and save
|
||||
stats = self.eval_json(stats) # update stats
|
||||
if self.args.plots or self.args.save_json:
|
||||
@ -227,6 +246,7 @@ class BaseValidator:
|
||||
if use_scipy:
|
||||
# WARNING: known issue that reduces mAP in https://github.com/ultralytics/ultralytics/pull/4708
|
||||
import scipy # scope import to avoid importing for all commands
|
||||
|
||||
cost_matrix = iou * (iou >= threshold)
|
||||
if cost_matrix.any():
|
||||
labels_idx, detections_idx = scipy.optimize.linear_sum_assignment(cost_matrix, maximize=True)
|
||||
@ -256,11 +276,11 @@ class BaseValidator:
|
||||
|
||||
def get_dataloader(self, dataset_path, batch_size):
|
||||
"""Get data loader from dataset path and batch size."""
|
||||
raise NotImplementedError('get_dataloader function not implemented for this validator')
|
||||
raise NotImplementedError("get_dataloader function not implemented for this validator")
|
||||
|
||||
def build_dataset(self, img_path):
|
||||
"""Build dataset"""
|
||||
raise NotImplementedError('build_dataset function not implemented in validator')
|
||||
"""Build dataset."""
|
||||
raise NotImplementedError("build_dataset function not implemented in validator")
|
||||
|
||||
def preprocess(self, batch):
|
||||
"""Preprocesses an input batch."""
|
||||
@ -305,7 +325,7 @@ class BaseValidator:
|
||||
|
||||
def on_plot(self, name, data=None):
|
||||
"""Registers plots (e.g. to be consumed in callbacks)"""
|
||||
self.plots[Path(name)] = {'data': data, 'timestamp': time.time()}
|
||||
self.plots[Path(name)] = {"data": data, "timestamp": time.time()}
|
||||
|
||||
# TODO: may need to put these following functions into callback
|
||||
def plot_val_samples(self, batch, ni):
|
||||
|
Reference in New Issue
Block a user