add yolo v10 and modify pipeline
This commit is contained in:
@ -6,20 +6,32 @@ import torch.nn.functional as F
|
||||
|
||||
from ultralytics.utils.loss import FocalLoss, VarifocalLoss
|
||||
from ultralytics.utils.metrics import bbox_iou
|
||||
|
||||
from .ops import HungarianMatcher
|
||||
|
||||
|
||||
class DETRLoss(nn.Module):
|
||||
"""
|
||||
DETR (DEtection TRansformer) Loss class. This class calculates and returns the different loss components for the
|
||||
DETR object detection model. It computes classification loss, bounding box loss, GIoU loss, and optionally auxiliary
|
||||
losses.
|
||||
|
||||
def __init__(self,
|
||||
nc=80,
|
||||
loss_gain=None,
|
||||
aux_loss=True,
|
||||
use_fl=True,
|
||||
use_vfl=False,
|
||||
use_uni_match=False,
|
||||
uni_match_ind=0):
|
||||
Attributes:
|
||||
nc (int): The number of classes.
|
||||
loss_gain (dict): Coefficients for different loss components.
|
||||
aux_loss (bool): Whether to compute auxiliary losses.
|
||||
use_fl (bool): Use FocalLoss or not.
|
||||
use_vfl (bool): Use VarifocalLoss or not.
|
||||
use_uni_match (bool): Whether to use a fixed layer to assign labels for the auxiliary branch.
|
||||
uni_match_ind (int): The fixed indices of a layer to use if `use_uni_match` is True.
|
||||
matcher (HungarianMatcher): Object to compute matching cost and indices.
|
||||
fl (FocalLoss or None): Focal Loss object if `use_fl` is True, otherwise None.
|
||||
vfl (VarifocalLoss or None): Varifocal Loss object if `use_vfl` is True, otherwise None.
|
||||
device (torch.device): Device on which tensors are stored.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, nc=80, loss_gain=None, aux_loss=True, use_fl=True, use_vfl=False, use_uni_match=False, uni_match_ind=0
|
||||
):
|
||||
"""
|
||||
DETR loss function.
|
||||
|
||||
@ -34,9 +46,9 @@ class DETRLoss(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
if loss_gain is None:
|
||||
loss_gain = {'class': 1, 'bbox': 5, 'giou': 2, 'no_object': 0.1, 'mask': 1, 'dice': 1}
|
||||
loss_gain = {"class": 1, "bbox": 5, "giou": 2, "no_object": 0.1, "mask": 1, "dice": 1}
|
||||
self.nc = nc
|
||||
self.matcher = HungarianMatcher(cost_gain={'class': 2, 'bbox': 5, 'giou': 2})
|
||||
self.matcher = HungarianMatcher(cost_gain={"class": 2, "bbox": 5, "giou": 2})
|
||||
self.loss_gain = loss_gain
|
||||
self.aux_loss = aux_loss
|
||||
self.fl = FocalLoss() if use_fl else None
|
||||
@ -46,9 +58,10 @@ class DETRLoss(nn.Module):
|
||||
self.uni_match_ind = uni_match_ind
|
||||
self.device = None
|
||||
|
||||
def _get_loss_class(self, pred_scores, targets, gt_scores, num_gts, postfix=''):
|
||||
# logits: [b, query, num_classes], gt_class: list[[n, 1]]
|
||||
name_class = f'loss_class{postfix}'
|
||||
def _get_loss_class(self, pred_scores, targets, gt_scores, num_gts, postfix=""):
|
||||
"""Computes the classification loss based on predictions, target values, and ground truth scores."""
|
||||
# Logits: [b, query, num_classes], gt_class: list[[n, 1]]
|
||||
name_class = f"loss_class{postfix}"
|
||||
bs, nq = pred_scores.shape[:2]
|
||||
# one_hot = F.one_hot(targets, self.nc + 1)[..., :-1] # (bs, num_queries, num_classes)
|
||||
one_hot = torch.zeros((bs, nq, self.nc + 1), dtype=torch.int64, device=targets.device)
|
||||
@ -63,25 +76,28 @@ class DETRLoss(nn.Module):
|
||||
loss_cls = self.fl(pred_scores, one_hot.float())
|
||||
loss_cls /= max(num_gts, 1) / nq
|
||||
else:
|
||||
loss_cls = nn.BCEWithLogitsLoss(reduction='none')(pred_scores, gt_scores).mean(1).sum() # YOLO CLS loss
|
||||
loss_cls = nn.BCEWithLogitsLoss(reduction="none")(pred_scores, gt_scores).mean(1).sum() # YOLO CLS loss
|
||||
|
||||
return {name_class: loss_cls.squeeze() * self.loss_gain['class']}
|
||||
return {name_class: loss_cls.squeeze() * self.loss_gain["class"]}
|
||||
|
||||
def _get_loss_bbox(self, pred_bboxes, gt_bboxes, postfix=''):
|
||||
# boxes: [b, query, 4], gt_bbox: list[[n, 4]]
|
||||
name_bbox = f'loss_bbox{postfix}'
|
||||
name_giou = f'loss_giou{postfix}'
|
||||
def _get_loss_bbox(self, pred_bboxes, gt_bboxes, postfix=""):
|
||||
"""Calculates and returns the bounding box loss and GIoU loss for the predicted and ground truth bounding
|
||||
boxes.
|
||||
"""
|
||||
# Boxes: [b, query, 4], gt_bbox: list[[n, 4]]
|
||||
name_bbox = f"loss_bbox{postfix}"
|
||||
name_giou = f"loss_giou{postfix}"
|
||||
|
||||
loss = {}
|
||||
if len(gt_bboxes) == 0:
|
||||
loss[name_bbox] = torch.tensor(0., device=self.device)
|
||||
loss[name_giou] = torch.tensor(0., device=self.device)
|
||||
loss[name_bbox] = torch.tensor(0.0, device=self.device)
|
||||
loss[name_giou] = torch.tensor(0.0, device=self.device)
|
||||
return loss
|
||||
|
||||
loss[name_bbox] = self.loss_gain['bbox'] * F.l1_loss(pred_bboxes, gt_bboxes, reduction='sum') / len(gt_bboxes)
|
||||
loss[name_bbox] = self.loss_gain["bbox"] * F.l1_loss(pred_bboxes, gt_bboxes, reduction="sum") / len(gt_bboxes)
|
||||
loss[name_giou] = 1.0 - bbox_iou(pred_bboxes, gt_bboxes, xywh=True, GIoU=True)
|
||||
loss[name_giou] = loss[name_giou].sum() / len(gt_bboxes)
|
||||
loss[name_giou] = self.loss_gain['giou'] * loss[name_giou]
|
||||
loss[name_giou] = self.loss_gain["giou"] * loss[name_giou]
|
||||
return {k: v.squeeze() for k, v in loss.items()}
|
||||
|
||||
# This function is for future RT-DETR Segment models
|
||||
@ -115,50 +131,57 @@ class DETRLoss(nn.Module):
|
||||
# loss = 1 - (numerator + 1) / (denominator + 1)
|
||||
# return loss.sum() / num_gts
|
||||
|
||||
def _get_loss_aux(self,
|
||||
pred_bboxes,
|
||||
pred_scores,
|
||||
gt_bboxes,
|
||||
gt_cls,
|
||||
gt_groups,
|
||||
match_indices=None,
|
||||
postfix='',
|
||||
masks=None,
|
||||
gt_mask=None):
|
||||
"""Get auxiliary losses"""
|
||||
def _get_loss_aux(
|
||||
self,
|
||||
pred_bboxes,
|
||||
pred_scores,
|
||||
gt_bboxes,
|
||||
gt_cls,
|
||||
gt_groups,
|
||||
match_indices=None,
|
||||
postfix="",
|
||||
masks=None,
|
||||
gt_mask=None,
|
||||
):
|
||||
"""Get auxiliary losses."""
|
||||
# NOTE: loss class, bbox, giou, mask, dice
|
||||
loss = torch.zeros(5 if masks is not None else 3, device=pred_bboxes.device)
|
||||
if match_indices is None and self.use_uni_match:
|
||||
match_indices = self.matcher(pred_bboxes[self.uni_match_ind],
|
||||
pred_scores[self.uni_match_ind],
|
||||
gt_bboxes,
|
||||
gt_cls,
|
||||
gt_groups,
|
||||
masks=masks[self.uni_match_ind] if masks is not None else None,
|
||||
gt_mask=gt_mask)
|
||||
match_indices = self.matcher(
|
||||
pred_bboxes[self.uni_match_ind],
|
||||
pred_scores[self.uni_match_ind],
|
||||
gt_bboxes,
|
||||
gt_cls,
|
||||
gt_groups,
|
||||
masks=masks[self.uni_match_ind] if masks is not None else None,
|
||||
gt_mask=gt_mask,
|
||||
)
|
||||
for i, (aux_bboxes, aux_scores) in enumerate(zip(pred_bboxes, pred_scores)):
|
||||
aux_masks = masks[i] if masks is not None else None
|
||||
loss_ = self._get_loss(aux_bboxes,
|
||||
aux_scores,
|
||||
gt_bboxes,
|
||||
gt_cls,
|
||||
gt_groups,
|
||||
masks=aux_masks,
|
||||
gt_mask=gt_mask,
|
||||
postfix=postfix,
|
||||
match_indices=match_indices)
|
||||
loss[0] += loss_[f'loss_class{postfix}']
|
||||
loss[1] += loss_[f'loss_bbox{postfix}']
|
||||
loss[2] += loss_[f'loss_giou{postfix}']
|
||||
loss_ = self._get_loss(
|
||||
aux_bboxes,
|
||||
aux_scores,
|
||||
gt_bboxes,
|
||||
gt_cls,
|
||||
gt_groups,
|
||||
masks=aux_masks,
|
||||
gt_mask=gt_mask,
|
||||
postfix=postfix,
|
||||
match_indices=match_indices,
|
||||
)
|
||||
loss[0] += loss_[f"loss_class{postfix}"]
|
||||
loss[1] += loss_[f"loss_bbox{postfix}"]
|
||||
loss[2] += loss_[f"loss_giou{postfix}"]
|
||||
# if masks is not None and gt_mask is not None:
|
||||
# loss_ = self._get_loss_mask(aux_masks, gt_mask, match_indices, postfix)
|
||||
# loss[3] += loss_[f'loss_mask{postfix}']
|
||||
# loss[4] += loss_[f'loss_dice{postfix}']
|
||||
|
||||
loss = {
|
||||
f'loss_class_aux{postfix}': loss[0],
|
||||
f'loss_bbox_aux{postfix}': loss[1],
|
||||
f'loss_giou_aux{postfix}': loss[2]}
|
||||
f"loss_class_aux{postfix}": loss[0],
|
||||
f"loss_bbox_aux{postfix}": loss[1],
|
||||
f"loss_giou_aux{postfix}": loss[2],
|
||||
}
|
||||
# if masks is not None and gt_mask is not None:
|
||||
# loss[f'loss_mask_aux{postfix}'] = loss[3]
|
||||
# loss[f'loss_dice_aux{postfix}'] = loss[4]
|
||||
@ -166,39 +189,45 @@ class DETRLoss(nn.Module):
|
||||
|
||||
@staticmethod
|
||||
def _get_index(match_indices):
|
||||
"""Returns batch indices, source indices, and destination indices from provided match indices."""
|
||||
batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(match_indices)])
|
||||
src_idx = torch.cat([src for (src, _) in match_indices])
|
||||
dst_idx = torch.cat([dst for (_, dst) in match_indices])
|
||||
return (batch_idx, src_idx), dst_idx
|
||||
|
||||
def _get_assigned_bboxes(self, pred_bboxes, gt_bboxes, match_indices):
|
||||
pred_assigned = torch.cat([
|
||||
t[I] if len(I) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
|
||||
for t, (I, _) in zip(pred_bboxes, match_indices)])
|
||||
gt_assigned = torch.cat([
|
||||
t[J] if len(J) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
|
||||
for t, (_, J) in zip(gt_bboxes, match_indices)])
|
||||
"""Assigns predicted bounding boxes to ground truth bounding boxes based on the match indices."""
|
||||
pred_assigned = torch.cat(
|
||||
[
|
||||
t[i] if len(i) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
|
||||
for t, (i, _) in zip(pred_bboxes, match_indices)
|
||||
]
|
||||
)
|
||||
gt_assigned = torch.cat(
|
||||
[
|
||||
t[j] if len(j) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
|
||||
for t, (_, j) in zip(gt_bboxes, match_indices)
|
||||
]
|
||||
)
|
||||
return pred_assigned, gt_assigned
|
||||
|
||||
def _get_loss(self,
|
||||
pred_bboxes,
|
||||
pred_scores,
|
||||
gt_bboxes,
|
||||
gt_cls,
|
||||
gt_groups,
|
||||
masks=None,
|
||||
gt_mask=None,
|
||||
postfix='',
|
||||
match_indices=None):
|
||||
"""Get losses"""
|
||||
def _get_loss(
|
||||
self,
|
||||
pred_bboxes,
|
||||
pred_scores,
|
||||
gt_bboxes,
|
||||
gt_cls,
|
||||
gt_groups,
|
||||
masks=None,
|
||||
gt_mask=None,
|
||||
postfix="",
|
||||
match_indices=None,
|
||||
):
|
||||
"""Get losses."""
|
||||
if match_indices is None:
|
||||
match_indices = self.matcher(pred_bboxes,
|
||||
pred_scores,
|
||||
gt_bboxes,
|
||||
gt_cls,
|
||||
gt_groups,
|
||||
masks=masks,
|
||||
gt_mask=gt_mask)
|
||||
match_indices = self.matcher(
|
||||
pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups, masks=masks, gt_mask=gt_mask
|
||||
)
|
||||
|
||||
idx, gt_idx = self._get_index(match_indices)
|
||||
pred_bboxes, gt_bboxes = pred_bboxes[idx], gt_bboxes[gt_idx]
|
||||
@ -218,7 +247,7 @@ class DETRLoss(nn.Module):
|
||||
# loss.update(self._get_loss_mask(masks, gt_mask, match_indices, postfix))
|
||||
return loss
|
||||
|
||||
def forward(self, pred_bboxes, pred_scores, batch, postfix='', **kwargs):
|
||||
def forward(self, pred_bboxes, pred_scores, batch, postfix="", **kwargs):
|
||||
"""
|
||||
Args:
|
||||
pred_bboxes (torch.Tensor): [l, b, query, 4]
|
||||
@ -230,43 +259,62 @@ class DETRLoss(nn.Module):
|
||||
postfix (str): postfix of loss name.
|
||||
"""
|
||||
self.device = pred_bboxes.device
|
||||
match_indices = kwargs.get('match_indices', None)
|
||||
gt_cls, gt_bboxes, gt_groups = batch['cls'], batch['bboxes'], batch['gt_groups']
|
||||
match_indices = kwargs.get("match_indices", None)
|
||||
gt_cls, gt_bboxes, gt_groups = batch["cls"], batch["bboxes"], batch["gt_groups"]
|
||||
|
||||
total_loss = self._get_loss(pred_bboxes[-1],
|
||||
pred_scores[-1],
|
||||
gt_bboxes,
|
||||
gt_cls,
|
||||
gt_groups,
|
||||
postfix=postfix,
|
||||
match_indices=match_indices)
|
||||
total_loss = self._get_loss(
|
||||
pred_bboxes[-1], pred_scores[-1], gt_bboxes, gt_cls, gt_groups, postfix=postfix, match_indices=match_indices
|
||||
)
|
||||
|
||||
if self.aux_loss:
|
||||
total_loss.update(
|
||||
self._get_loss_aux(pred_bboxes[:-1], pred_scores[:-1], gt_bboxes, gt_cls, gt_groups, match_indices,
|
||||
postfix))
|
||||
self._get_loss_aux(
|
||||
pred_bboxes[:-1], pred_scores[:-1], gt_bboxes, gt_cls, gt_groups, match_indices, postfix
|
||||
)
|
||||
)
|
||||
|
||||
return total_loss
|
||||
|
||||
|
||||
class RTDETRDetectionLoss(DETRLoss):
|
||||
"""
|
||||
Real-Time DeepTracker (RT-DETR) Detection Loss class that extends the DETRLoss.
|
||||
|
||||
This class computes the detection loss for the RT-DETR model, which includes the standard detection loss as well as
|
||||
an additional denoising training loss when provided with denoising metadata.
|
||||
"""
|
||||
|
||||
def forward(self, preds, batch, dn_bboxes=None, dn_scores=None, dn_meta=None):
|
||||
"""
|
||||
Forward pass to compute the detection loss.
|
||||
|
||||
Args:
|
||||
preds (tuple): Predicted bounding boxes and scores.
|
||||
batch (dict): Batch data containing ground truth information.
|
||||
dn_bboxes (torch.Tensor, optional): Denoising bounding boxes. Default is None.
|
||||
dn_scores (torch.Tensor, optional): Denoising scores. Default is None.
|
||||
dn_meta (dict, optional): Metadata for denoising. Default is None.
|
||||
|
||||
Returns:
|
||||
(dict): Dictionary containing the total loss and, if applicable, the denoising loss.
|
||||
"""
|
||||
pred_bboxes, pred_scores = preds
|
||||
total_loss = super().forward(pred_bboxes, pred_scores, batch)
|
||||
|
||||
# Check for denoising metadata to compute denoising training loss
|
||||
if dn_meta is not None:
|
||||
dn_pos_idx, dn_num_group = dn_meta['dn_pos_idx'], dn_meta['dn_num_group']
|
||||
assert len(batch['gt_groups']) == len(dn_pos_idx)
|
||||
dn_pos_idx, dn_num_group = dn_meta["dn_pos_idx"], dn_meta["dn_num_group"]
|
||||
assert len(batch["gt_groups"]) == len(dn_pos_idx)
|
||||
|
||||
# Denoising match indices
|
||||
match_indices = self.get_dn_match_indices(dn_pos_idx, dn_num_group, batch['gt_groups'])
|
||||
# Get the match indices for denoising
|
||||
match_indices = self.get_dn_match_indices(dn_pos_idx, dn_num_group, batch["gt_groups"])
|
||||
|
||||
# Compute denoising training loss
|
||||
dn_loss = super().forward(dn_bboxes, dn_scores, batch, postfix='_dn', match_indices=match_indices)
|
||||
# Compute the denoising training loss
|
||||
dn_loss = super().forward(dn_bboxes, dn_scores, batch, postfix="_dn", match_indices=match_indices)
|
||||
total_loss.update(dn_loss)
|
||||
else:
|
||||
total_loss.update({f'{k}_dn': torch.tensor(0., device=self.device) for k in total_loss.keys()})
|
||||
# If no denoising metadata is provided, set denoising loss to zero
|
||||
total_loss.update({f"{k}_dn": torch.tensor(0.0, device=self.device) for k in total_loss.keys()})
|
||||
|
||||
return total_loss
|
||||
|
||||
@ -276,12 +324,12 @@ class RTDETRDetectionLoss(DETRLoss):
|
||||
Get the match indices for denoising.
|
||||
|
||||
Args:
|
||||
dn_pos_idx (List[torch.Tensor]): A list includes positive indices of denoising.
|
||||
dn_num_group (int): The number of groups of denoising.
|
||||
gt_groups (List(int)): a list of batch size length includes the number of gts of each image.
|
||||
dn_pos_idx (List[torch.Tensor]): List of tensors containing positive indices for denoising.
|
||||
dn_num_group (int): Number of denoising groups.
|
||||
gt_groups (List[int]): List of integers representing the number of ground truths for each image.
|
||||
|
||||
Returns:
|
||||
dn_match_indices (List(tuple)): Matched indices.
|
||||
(List[tuple]): List of tuples containing matched indices for denoising.
|
||||
"""
|
||||
dn_match_indices = []
|
||||
idx_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)
|
||||
@ -289,8 +337,8 @@ class RTDETRDetectionLoss(DETRLoss):
|
||||
if num_gt > 0:
|
||||
gt_idx = torch.arange(end=num_gt, dtype=torch.long) + idx_groups[i]
|
||||
gt_idx = gt_idx.repeat(dn_num_group)
|
||||
assert len(dn_pos_idx[i]) == len(gt_idx), 'Expected the same length, '
|
||||
f'but got {len(dn_pos_idx[i])} and {len(gt_idx)} respectively.'
|
||||
assert len(dn_pos_idx[i]) == len(gt_idx), "Expected the same length, "
|
||||
f"but got {len(dn_pos_idx[i])} and {len(gt_idx)} respectively."
|
||||
dn_match_indices.append((dn_pos_idx[i], gt_idx))
|
||||
else:
|
||||
dn_match_indices.append((torch.zeros([0], dtype=torch.long), torch.zeros([0], dtype=torch.long)))
|
||||
|
Reference in New Issue
Block a user