add yolo v10 and modify pipeline
This commit is contained in:
@ -4,4 +4,4 @@ from ultralytics.models.yolo.classify.predict import ClassificationPredictor
|
||||
from ultralytics.models.yolo.classify.train import ClassificationTrainer
|
||||
from ultralytics.models.yolo.classify.val import ClassificationValidator
|
||||
|
||||
__all__ = 'ClassificationPredictor', 'ClassificationTrainer', 'ClassificationValidator'
|
||||
__all__ = "ClassificationPredictor", "ClassificationTrainer", "ClassificationValidator"
|
||||
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -1,6 +1,8 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from ultralytics.engine.predictor import BasePredictor
|
||||
from ultralytics.engine.results import Results
|
||||
@ -26,13 +28,23 @@ class ClassificationPredictor(BasePredictor):
|
||||
"""
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
||||
"""Initializes ClassificationPredictor setting the task to 'classify'."""
|
||||
super().__init__(cfg, overrides, _callbacks)
|
||||
self.args.task = 'classify'
|
||||
self.args.task = "classify"
|
||||
self._legacy_transform_name = "ultralytics.yolo.data.augment.ToTensor"
|
||||
|
||||
def preprocess(self, img):
|
||||
"""Converts input image to model-compatible data type."""
|
||||
if not isinstance(img, torch.Tensor):
|
||||
img = torch.stack([self.transforms(im) for im in img], dim=0)
|
||||
is_legacy_transform = any(
|
||||
self._legacy_transform_name in str(transform) for transform in self.transforms.transforms
|
||||
)
|
||||
if is_legacy_transform: # to handle legacy transforms
|
||||
img = torch.stack([self.transforms(im) for im in img], dim=0)
|
||||
else:
|
||||
img = torch.stack(
|
||||
[self.transforms(Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))) for im in img], dim=0
|
||||
)
|
||||
img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device)
|
||||
return img.half() if self.model.fp16 else img.float() # uint8 to fp16/32
|
||||
|
||||
|
@ -33,23 +33,23 @@ class ClassificationTrainer(BaseTrainer):
|
||||
"""Initialize a ClassificationTrainer object with optional configuration overrides and callbacks."""
|
||||
if overrides is None:
|
||||
overrides = {}
|
||||
overrides['task'] = 'classify'
|
||||
if overrides.get('imgsz') is None:
|
||||
overrides['imgsz'] = 224
|
||||
overrides["task"] = "classify"
|
||||
if overrides.get("imgsz") is None:
|
||||
overrides["imgsz"] = 224
|
||||
super().__init__(cfg, overrides, _callbacks)
|
||||
|
||||
def set_model_attributes(self):
|
||||
"""Set the YOLO model's class names from the loaded dataset."""
|
||||
self.model.names = self.data['names']
|
||||
self.model.names = self.data["names"]
|
||||
|
||||
def get_model(self, cfg=None, weights=None, verbose=True):
|
||||
"""Returns a modified PyTorch model configured for training YOLO."""
|
||||
model = ClassificationModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1)
|
||||
model = ClassificationModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
|
||||
if weights:
|
||||
model.load(weights)
|
||||
|
||||
for m in model.modules():
|
||||
if not self.args.pretrained and hasattr(m, 'reset_parameters'):
|
||||
if not self.args.pretrained and hasattr(m, "reset_parameters"):
|
||||
m.reset_parameters()
|
||||
if isinstance(m, torch.nn.Dropout) and self.args.dropout:
|
||||
m.p = self.args.dropout # set dropout
|
||||
@ -64,31 +64,32 @@ class ClassificationTrainer(BaseTrainer):
|
||||
|
||||
model, ckpt = str(self.model), None
|
||||
# Load a YOLO model locally, from torchvision, or from Ultralytics assets
|
||||
if model.endswith('.pt'):
|
||||
self.model, ckpt = attempt_load_one_weight(model, device='cpu')
|
||||
if model.endswith(".pt"):
|
||||
self.model, ckpt = attempt_load_one_weight(model, device="cpu")
|
||||
for p in self.model.parameters():
|
||||
p.requires_grad = True # for training
|
||||
elif model.split('.')[-1] in ('yaml', 'yml'):
|
||||
elif model.split(".")[-1] in ("yaml", "yml"):
|
||||
self.model = self.get_model(cfg=model)
|
||||
elif model in torchvision.models.__dict__:
|
||||
self.model = torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if self.args.pretrained else None)
|
||||
self.model = torchvision.models.__dict__[model](weights="IMAGENET1K_V1" if self.args.pretrained else None)
|
||||
else:
|
||||
FileNotFoundError(f'ERROR: model={model} not found locally or online. Please check model name.')
|
||||
ClassificationModel.reshape_outputs(self.model, self.data['nc'])
|
||||
raise FileNotFoundError(f"ERROR: model={model} not found locally or online. Please check model name.")
|
||||
ClassificationModel.reshape_outputs(self.model, self.data["nc"])
|
||||
|
||||
return ckpt
|
||||
|
||||
def build_dataset(self, img_path, mode='train', batch=None):
|
||||
return ClassificationDataset(root=img_path, args=self.args, augment=mode == 'train', prefix=mode)
|
||||
def build_dataset(self, img_path, mode="train", batch=None):
|
||||
"""Creates a ClassificationDataset instance given an image path, and mode (train/test etc.)."""
|
||||
return ClassificationDataset(root=img_path, args=self.args, augment=mode == "train", prefix=mode)
|
||||
|
||||
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
|
||||
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
|
||||
"""Returns PyTorch DataLoader with transforms to preprocess images for inference."""
|
||||
with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
|
||||
dataset = self.build_dataset(dataset_path, mode)
|
||||
|
||||
loader = build_dataloader(dataset, batch_size, self.args.workers, rank=rank)
|
||||
# Attach inference transforms
|
||||
if mode != 'train':
|
||||
if mode != "train":
|
||||
if is_parallel(self.model):
|
||||
self.model.module.transforms = loader.dataset.torch_transforms
|
||||
else:
|
||||
@ -97,26 +98,32 @@ class ClassificationTrainer(BaseTrainer):
|
||||
|
||||
def preprocess_batch(self, batch):
|
||||
"""Preprocesses a batch of images and classes."""
|
||||
batch['img'] = batch['img'].to(self.device)
|
||||
batch['cls'] = batch['cls'].to(self.device)
|
||||
batch["img"] = batch["img"].to(self.device)
|
||||
batch["cls"] = batch["cls"].to(self.device)
|
||||
return batch
|
||||
|
||||
def progress_string(self):
|
||||
"""Returns a formatted string showing training progress."""
|
||||
return ('\n' + '%11s' * (4 + len(self.loss_names))) % \
|
||||
('Epoch', 'GPU_mem', *self.loss_names, 'Instances', 'Size')
|
||||
return ("\n" + "%11s" * (4 + len(self.loss_names))) % (
|
||||
"Epoch",
|
||||
"GPU_mem",
|
||||
*self.loss_names,
|
||||
"Instances",
|
||||
"Size",
|
||||
)
|
||||
|
||||
def get_validator(self):
|
||||
"""Returns an instance of ClassificationValidator for validation."""
|
||||
self.loss_names = ['loss']
|
||||
return yolo.classify.ClassificationValidator(self.test_loader, self.save_dir)
|
||||
self.loss_names = ["loss"]
|
||||
return yolo.classify.ClassificationValidator(self.test_loader, self.save_dir, _callbacks=self.callbacks)
|
||||
|
||||
def label_loss_items(self, loss_items=None, prefix='train'):
|
||||
def label_loss_items(self, loss_items=None, prefix="train"):
|
||||
"""
|
||||
Returns a loss dict with labelled training loss items tensor. Not needed for classification but necessary for
|
||||
segmentation & detection
|
||||
Returns a loss dict with labelled training loss items tensor.
|
||||
|
||||
Not needed for classification but necessary for segmentation & detection
|
||||
"""
|
||||
keys = [f'{prefix}/{x}' for x in self.loss_names]
|
||||
keys = [f"{prefix}/{x}" for x in self.loss_names]
|
||||
if loss_items is None:
|
||||
return keys
|
||||
loss_items = [round(float(loss_items), 5)]
|
||||
@ -132,19 +139,20 @@ class ClassificationTrainer(BaseTrainer):
|
||||
if f.exists():
|
||||
strip_optimizer(f) # strip optimizers
|
||||
if f is self.best:
|
||||
LOGGER.info(f'\nValidating {f}...')
|
||||
LOGGER.info(f"\nValidating {f}...")
|
||||
self.validator.args.data = self.args.data
|
||||
self.validator.args.plots = self.args.plots
|
||||
self.metrics = self.validator(model=f)
|
||||
self.metrics.pop('fitness', None)
|
||||
self.run_callbacks('on_fit_epoch_end')
|
||||
self.metrics.pop("fitness", None)
|
||||
self.run_callbacks("on_fit_epoch_end")
|
||||
LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
|
||||
|
||||
def plot_training_samples(self, batch, ni):
|
||||
"""Plots training samples with their annotations."""
|
||||
plot_images(
|
||||
images=batch['img'],
|
||||
batch_idx=torch.arange(len(batch['img'])),
|
||||
cls=batch['cls'].view(-1), # warning: use .view(), not .squeeze() for Classify models
|
||||
fname=self.save_dir / f'train_batch{ni}.jpg',
|
||||
on_plot=self.on_plot)
|
||||
images=batch["img"],
|
||||
batch_idx=torch.arange(len(batch["img"])),
|
||||
cls=batch["cls"].view(-1), # warning: use .view(), not .squeeze() for Classify models
|
||||
fname=self.save_dir / f"train_batch{ni}.jpg",
|
||||
on_plot=self.on_plot,
|
||||
)
|
||||
|
@ -31,43 +31,42 @@ class ClassificationValidator(BaseValidator):
|
||||
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
|
||||
self.targets = None
|
||||
self.pred = None
|
||||
self.args.task = 'classify'
|
||||
self.args.task = "classify"
|
||||
self.metrics = ClassifyMetrics()
|
||||
|
||||
def get_desc(self):
|
||||
"""Returns a formatted string summarizing classification metrics."""
|
||||
return ('%22s' + '%11s' * 2) % ('classes', 'top1_acc', 'top5_acc')
|
||||
return ("%22s" + "%11s" * 2) % ("classes", "top1_acc", "top5_acc")
|
||||
|
||||
def init_metrics(self, model):
|
||||
"""Initialize confusion matrix, class names, and top-1 and top-5 accuracy."""
|
||||
self.names = model.names
|
||||
self.nc = len(model.names)
|
||||
self.confusion_matrix = ConfusionMatrix(nc=self.nc, conf=self.args.conf, task='classify')
|
||||
self.confusion_matrix = ConfusionMatrix(nc=self.nc, conf=self.args.conf, task="classify")
|
||||
self.pred = []
|
||||
self.targets = []
|
||||
|
||||
def preprocess(self, batch):
|
||||
"""Preprocesses input batch and returns it."""
|
||||
batch['img'] = batch['img'].to(self.device, non_blocking=True)
|
||||
batch['img'] = batch['img'].half() if self.args.half else batch['img'].float()
|
||||
batch['cls'] = batch['cls'].to(self.device)
|
||||
batch["img"] = batch["img"].to(self.device, non_blocking=True)
|
||||
batch["img"] = batch["img"].half() if self.args.half else batch["img"].float()
|
||||
batch["cls"] = batch["cls"].to(self.device)
|
||||
return batch
|
||||
|
||||
def update_metrics(self, preds, batch):
|
||||
"""Updates running metrics with model predictions and batch targets."""
|
||||
n5 = min(len(self.names), 5)
|
||||
self.pred.append(preds.argsort(1, descending=True)[:, :n5])
|
||||
self.targets.append(batch['cls'])
|
||||
self.targets.append(batch["cls"])
|
||||
|
||||
def finalize_metrics(self, *args, **kwargs):
|
||||
"""Finalizes metrics of the model such as confusion_matrix and speed."""
|
||||
self.confusion_matrix.process_cls_preds(self.pred, self.targets)
|
||||
if self.args.plots:
|
||||
for normalize in True, False:
|
||||
self.confusion_matrix.plot(save_dir=self.save_dir,
|
||||
names=self.names.values(),
|
||||
normalize=normalize,
|
||||
on_plot=self.on_plot)
|
||||
self.confusion_matrix.plot(
|
||||
save_dir=self.save_dir, names=self.names.values(), normalize=normalize, on_plot=self.on_plot
|
||||
)
|
||||
self.metrics.speed = self.speed
|
||||
self.metrics.confusion_matrix = self.confusion_matrix
|
||||
self.metrics.save_dir = self.save_dir
|
||||
@ -78,6 +77,7 @@ class ClassificationValidator(BaseValidator):
|
||||
return self.metrics.results_dict
|
||||
|
||||
def build_dataset(self, img_path):
|
||||
"""Creates and returns a ClassificationDataset instance using given image path and preprocessing parameters."""
|
||||
return ClassificationDataset(root=img_path, args=self.args, augment=False, prefix=self.args.split)
|
||||
|
||||
def get_dataloader(self, dataset_path, batch_size):
|
||||
@ -87,24 +87,27 @@ class ClassificationValidator(BaseValidator):
|
||||
|
||||
def print_results(self):
|
||||
"""Prints evaluation metrics for YOLO object detection model."""
|
||||
pf = '%22s' + '%11.3g' * len(self.metrics.keys) # print format
|
||||
LOGGER.info(pf % ('all', self.metrics.top1, self.metrics.top5))
|
||||
pf = "%22s" + "%11.3g" * len(self.metrics.keys) # print format
|
||||
LOGGER.info(pf % ("all", self.metrics.top1, self.metrics.top5))
|
||||
|
||||
def plot_val_samples(self, batch, ni):
|
||||
"""Plot validation image samples."""
|
||||
plot_images(
|
||||
images=batch['img'],
|
||||
batch_idx=torch.arange(len(batch['img'])),
|
||||
cls=batch['cls'].view(-1), # warning: use .view(), not .squeeze() for Classify models
|
||||
fname=self.save_dir / f'val_batch{ni}_labels.jpg',
|
||||
images=batch["img"],
|
||||
batch_idx=torch.arange(len(batch["img"])),
|
||||
cls=batch["cls"].view(-1), # warning: use .view(), not .squeeze() for Classify models
|
||||
fname=self.save_dir / f"val_batch{ni}_labels.jpg",
|
||||
names=self.names,
|
||||
on_plot=self.on_plot)
|
||||
on_plot=self.on_plot,
|
||||
)
|
||||
|
||||
def plot_predictions(self, batch, preds, ni):
|
||||
"""Plots predicted bounding boxes on input images and saves the result."""
|
||||
plot_images(batch['img'],
|
||||
batch_idx=torch.arange(len(batch['img'])),
|
||||
cls=torch.argmax(preds, dim=1),
|
||||
fname=self.save_dir / f'val_batch{ni}_pred.jpg',
|
||||
names=self.names,
|
||||
on_plot=self.on_plot) # pred
|
||||
plot_images(
|
||||
batch["img"],
|
||||
batch_idx=torch.arange(len(batch["img"])),
|
||||
cls=torch.argmax(preds, dim=1),
|
||||
fname=self.save_dir / f"val_batch{ni}_pred.jpg",
|
||||
names=self.names,
|
||||
on_plot=self.on_plot,
|
||||
) # pred
|
||||
|
Reference in New Issue
Block a user