Skip to content

Reference for ultralytics/models/utils/loss.py

Note

This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/utils/loss.py. If you spot a problem, please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!


ultralytics.models.utils.loss.DETRLoss

DETRLoss(
    nc=80,
    loss_gain=None,
    aux_loss=True,
    use_fl=True,
    use_vfl=False,
    use_uni_match=False,
    uni_match_ind=0,
)

Bases: Module

DETR (DEtection TRansformer) Loss class. This class calculates and returns the different loss components for the DETR object detection model. It computes classification loss, bounding box loss, GIoU loss, and optionally auxiliary losses.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| nc | int | The number of classes. |
| loss_gain | dict | Coefficients for different loss components. |
| aux_loss | bool | Whether to compute auxiliary losses. |
| use_fl | bool | Whether to use FocalLoss. |
| use_vfl | bool | Whether to use VarifocalLoss. |
| use_uni_match | bool | Whether to use a fixed layer to assign labels for the auxiliary branch. |
| uni_match_ind | int | Index of the fixed layer to use if use_uni_match is True. |
| matcher | HungarianMatcher | Object to compute matching cost and indices. |
| fl | FocalLoss or None | Focal Loss object if use_fl is True, otherwise None. |
| vfl | VarifocalLoss or None | Varifocal Loss object if use_vfl is True, otherwise None. |
| device | torch.device or None | Device on which tensors are stored (None until forward is called). |

Uses default loss_gain if not provided. Initializes HungarianMatcher with preset cost gains. Supports auxiliary losses and various loss types.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| nc | int | Number of classes. | 80 |
| loss_gain | dict | Coefficients for different loss components. | None |
| aux_loss | bool | Use auxiliary losses from each decoder layer. | True |
| use_fl | bool | Use FocalLoss. | True |
| use_vfl | bool | Use VarifocalLoss. | False |
| use_uni_match | bool | Use a fixed layer for auxiliary branch label assignment. | False |
| uni_match_ind | int | Index of the fixed layer for uni_match. | 0 |
Source code in ultralytics/models/utils/loss.py
def __init__(
    self, nc=80, loss_gain=None, aux_loss=True, use_fl=True, use_vfl=False, use_uni_match=False, uni_match_ind=0
):
    """
    Set up the DETR loss: matcher, optional focal/varifocal criteria and per-component loss gains.

    If ``loss_gain`` is None, a built-in gain table is used. The HungarianMatcher is always created with its own
    fixed cost gains (class 2, bbox 5, giou 2), independent of ``loss_gain``.

    Args:
        nc (int): Number of classes.
        loss_gain (dict, optional): Coefficients for different loss components. Defaults to
            {"class": 1, "bbox": 5, "giou": 2, "no_object": 0.1, "mask": 1, "dice": 1}.
        aux_loss (bool): Use auxiliary losses from each decoder layer.
        use_fl (bool): Create a FocalLoss criterion (stored in ``self.fl``, otherwise None).
        use_vfl (bool): Create a VarifocalLoss criterion (stored in ``self.vfl``, otherwise None).
        use_uni_match (bool): Use a fixed layer for auxiliary branch label assignment.
        uni_match_ind (int): Index of the fixed layer for uni_match.
    """
    super().__init__()

    # Plain configuration values.
    self.nc = nc
    self.aux_loss = aux_loss
    self.use_uni_match = use_uni_match
    self.uni_match_ind = uni_match_ind
    default_gain = {"class": 1, "bbox": 5, "giou": 2, "no_object": 0.1, "mask": 1, "dice": 1}
    self.loss_gain = default_gain if loss_gain is None else loss_gain

    # Sub-components, created in the same order as before (matcher, then focal, then varifocal).
    self.matcher = HungarianMatcher(cost_gain={"class": 2, "bbox": 5, "giou": 2})
    self.fl = FocalLoss() if use_fl else None
    self.vfl = VarifocalLoss() if use_vfl else None

    # Populated by forward() from the prediction tensors.
    self.device = None

forward

forward(pred_bboxes, pred_scores, batch, postfix='', **kwargs)

Calculate loss for predicted bounding boxes and scores.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| pred_bboxes | Tensor | Predicted bounding boxes, shape [l, b, query, 4]. | required |
| pred_scores | Tensor | Predicted class scores, shape [l, b, query, num_classes]. | required |
| batch | dict | Batch information containing: `cls` (torch.Tensor): ground truth classes, shape [num_gts]; `bboxes` (torch.Tensor): ground truth bounding boxes, shape [num_gts, 4]; `gt_groups` (List[int]): number of ground truths for each image in the batch. | required |
| postfix | str | Postfix for loss names. | '' |
| `**kwargs` | Any | Additional arguments; may include 'match_indices'. | {} |

Returns:

| Type | Description |
| --- | --- |
| dict | Computed losses, including main and auxiliary (if enabled). |

Note

Uses the last element of pred_bboxes and pred_scores for the main loss, and the remaining elements for auxiliary losses if self.aux_loss is True.

Source code in ultralytics/models/utils/loss.py
def forward(self, pred_bboxes, pred_scores, batch, postfix="", **kwargs):
    """
    Compute the DETR losses for a stack of per-decoder-layer predictions.

    The final entry along the first axis (index -1) produces the main loss. All earlier entries are passed to the
    auxiliary loss when ``self.aux_loss`` is truthy.

    Args:
        pred_bboxes (torch.Tensor): Predicted bounding boxes, shape [l, b, query, 4].
        pred_scores (torch.Tensor): Predicted class scores, shape [l, b, query, num_classes].
        batch (dict): Ground-truth batch with keys:
            cls (torch.Tensor): Ground truth classes, shape [num_gts].
            bboxes (torch.Tensor): Ground truth bounding boxes, shape [num_gts, 4].
            gt_groups (List[int]): Number of ground truths for each image in the batch.
        postfix (str): Suffix appended to the loss names.
        **kwargs (Any): Extra options; only 'match_indices' is read (None when absent).

    Returns:
        (dict): Losses for the final prediction, merged with the auxiliary losses when enabled.
    """
    # Cache the prediction device on the instance for later use.
    self.device = pred_bboxes.device
    match_indices = kwargs.get("match_indices")
    gt_cls = batch["cls"]
    gt_bboxes = batch["bboxes"]
    gt_groups = batch["gt_groups"]

    final_bboxes, final_scores = pred_bboxes[-1], pred_scores[-1]
    losses = self._get_loss(
        final_bboxes, final_scores, gt_bboxes, gt_cls, gt_groups, postfix=postfix, match_indices=match_indices
    )

    if not self.aux_loss:
        return losses

    # Every layer except the last contributes to the auxiliary loss.
    aux_losses = self._get_loss_aux(
        pred_bboxes[:-1], pred_scores[:-1], gt_bboxes, gt_cls, gt_groups, match_indices, postfix
    )
    losses.update(aux_losses)
    return losses





ultralytics.models.utils.loss.RTDETRDetectionLoss

RTDETRDetectionLoss(
    nc=80,
    loss_gain=None,
    aux_loss=True,
    use_fl=True,
    use_vfl=False,
    use_uni_match=False,
    uni_match_ind=0,
)

Bases: DETRLoss

Real-Time DEtection TRansformer (RT-DETR) Detection Loss class that extends DETRLoss.

This class computes the detection loss for the RT-DETR model, which includes the standard detection loss as well as an additional denoising training loss when provided with denoising metadata.

Uses default loss_gain if not provided. Initializes HungarianMatcher with preset cost gains. Supports auxiliary losses and various loss types.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| nc | int | Number of classes. | 80 |
| loss_gain | dict | Coefficients for different loss components. | None |
| aux_loss | bool | Use auxiliary losses from each decoder layer. | True |
| use_fl | bool | Use FocalLoss. | True |
| use_vfl | bool | Use VarifocalLoss. | False |
| use_uni_match | bool | Use a fixed layer for auxiliary branch label assignment. | False |
| uni_match_ind | int | Index of the fixed layer for uni_match. | 0 |
Source code in ultralytics/models/utils/loss.py
def __init__(
    self, nc=80, loss_gain=None, aux_loss=True, use_fl=True, use_vfl=False, use_uni_match=False, uni_match_ind=0
):
    """
    Set up the DETR loss: matcher, optional focal/varifocal criteria and per-component loss gains.

    If ``loss_gain`` is None, a built-in gain table is used. The HungarianMatcher is always created with its own
    fixed cost gains (class 2, bbox 5, giou 2), independent of ``loss_gain``.

    Args:
        nc (int): Number of classes.
        loss_gain (dict, optional): Coefficients for different loss components. Defaults to
            {"class": 1, "bbox": 5, "giou": 2, "no_object": 0.1, "mask": 1, "dice": 1}.
        aux_loss (bool): Use auxiliary losses from each decoder layer.
        use_fl (bool): Create a FocalLoss criterion (stored in ``self.fl``, otherwise None).
        use_vfl (bool): Create a VarifocalLoss criterion (stored in ``self.vfl``, otherwise None).
        use_uni_match (bool): Use a fixed layer for auxiliary branch label assignment.
        uni_match_ind (int): Index of the fixed layer for uni_match.
    """
    super().__init__()

    # Plain configuration values.
    self.nc = nc
    self.aux_loss = aux_loss
    self.use_uni_match = use_uni_match
    self.uni_match_ind = uni_match_ind
    default_gain = {"class": 1, "bbox": 5, "giou": 2, "no_object": 0.1, "mask": 1, "dice": 1}
    self.loss_gain = default_gain if loss_gain is None else loss_gain

    # Sub-components, created in the same order as before (matcher, then focal, then varifocal).
    self.matcher = HungarianMatcher(cost_gain={"class": 2, "bbox": 5, "giou": 2})
    self.fl = FocalLoss() if use_fl else None
    self.vfl = VarifocalLoss() if use_vfl else None

    # Populated by forward() from the prediction tensors.
    self.device = None

forward

forward(preds, batch, dn_bboxes=None, dn_scores=None, dn_meta=None)

Forward pass to compute the detection loss.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| preds | tuple | Predicted bounding boxes and scores. | required |
| batch | dict | Batch data containing ground truth information. | required |
| dn_bboxes | Tensor | Denoising bounding boxes. | None |
| dn_scores | Tensor | Denoising scores. | None |
| dn_meta | dict | Metadata for denoising. | None |

Returns:

| Type | Description |
| --- | --- |
| dict | Dictionary containing the total loss and the denoising loss entries (zero-valued when no denoising metadata is provided). |

Source code in ultralytics/models/utils/loss.py
def forward(self, preds, batch, dn_bboxes=None, dn_scores=None, dn_meta=None):
    """
    Compute the RT-DETR detection loss, plus the denoising loss when denoising metadata is supplied.

    Args:
        preds (tuple): Predicted bounding boxes and scores, unpacked as ``(pred_bboxes, pred_scores)``.
        batch (dict): Batch data containing ground truth information (must include "gt_groups").
        dn_bboxes (torch.Tensor, optional): Denoising bounding boxes. Default is None.
        dn_scores (torch.Tensor, optional): Denoising scores. Default is None.
        dn_meta (dict, optional): Denoising metadata with "dn_pos_idx" and "dn_num_group". Default is None.

    Returns:
        (dict): The detection losses together with "<name>_dn" entries. Without ``dn_meta`` every "_dn" entry is a
            zero tensor on ``self.device``.
    """
    pred_bboxes, pred_scores = preds
    total_loss = super().forward(pred_bboxes, pred_scores, batch)

    if dn_meta is None:
        # Mirror every detection loss with a zero-valued "_dn" entry (presumably so the key set matches the
        # denoising case).
        zero_dn = {f"{name}_dn": torch.tensor(0.0, device=self.device) for name in total_loss}
        total_loss.update(zero_dn)
        return total_loss

    dn_pos_idx = dn_meta["dn_pos_idx"]
    dn_num_group = dn_meta["dn_num_group"]
    gt_groups = batch["gt_groups"]
    assert len(gt_groups) == len(dn_pos_idx)

    # Precomputed match indices for the denoising queries are forwarded to the base loss.
    match_indices = self.get_dn_match_indices(dn_pos_idx, dn_num_group, gt_groups)
    total_loss.update(super().forward(dn_bboxes, dn_scores, batch, postfix="_dn", match_indices=match_indices))
    return total_loss

get_dn_match_indices staticmethod

get_dn_match_indices(dn_pos_idx, dn_num_group, gt_groups)

Get the match indices for denoising.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| dn_pos_idx | List[Tensor] | List of tensors containing positive indices for denoising. | required |
| dn_num_group | int | Number of denoising groups. | required |
| gt_groups | List[int] | List of integers representing the number of ground truths for each image. | required |

Returns:

| Type | Description |
| --- | --- |
| List[tuple] | List of tuples containing matched indices for denoising. |

Source code in ultralytics/models/utils/loss.py
@staticmethod
def get_dn_match_indices(dn_pos_idx, dn_num_group, gt_groups):
    """
    Get the match indices for denoising.

    Ground truths for the whole batch are indexed in one flat sequence, so each image's ground-truth indices are
    offset by the number of ground truths in all preceding images. Those indices are repeated once per denoising
    group and paired with the image's positive denoising indices.

    Args:
        dn_pos_idx (List[torch.Tensor]): List of tensors containing positive indices for denoising, one per image.
        dn_num_group (int): Number of denoising groups.
        gt_groups (List[int]): List of integers representing the number of ground truths for each image.

    Returns:
        (List[tuple]): One ``(dn_pos_idx[i], gt_idx)`` tuple per image. Images without ground truths get a pair of
            empty long tensors.

    Raises:
        AssertionError: If an image's positive indices do not have length ``gt_groups[i] * dn_num_group``.
    """
    dn_match_indices = []
    # Start offset of each image's ground truths in the flattened batch (exclusive cumulative sum of gt_groups).
    idx_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)
    for i, num_gt in enumerate(gt_groups):
        if num_gt > 0:
            gt_idx = torch.arange(end=num_gt, dtype=torch.long) + idx_groups[i]
            # Tile the same ground-truth indices once per denoising group.
            gt_idx = gt_idx.repeat(dn_num_group)
            # The message is a single parenthesized f-string so it is actually attached to the assert
            # (a dangling f-string on the next line would be a separate no-op statement).
            assert len(dn_pos_idx[i]) == len(gt_idx), (
                f"Expected the same length, but got {len(dn_pos_idx[i])} and {len(gt_idx)} respectively."
            )
            dn_match_indices.append((dn_pos_idx[i], gt_idx))
        else:
            dn_match_indices.append((torch.zeros([0], dtype=torch.long), torch.zeros([0], dtype=torch.long)))
    return dn_match_indices



📅 Created 11 months ago ✏️ Updated 1 month ago