Skip to content

Reference for ultralytics/utils/loss.py

Note

This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/utils/loss.py. If you spot a problem please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!


ultralytics.utils.loss.VarifocalLoss

VarifocalLoss(gamma=2.0, alpha=0.75)

Bases: Module

Varifocal loss by Zhang et al.

https://arxiv.org/abs/2008.13367.

Parameters:

Name Type Description Default
gamma float

The focusing parameter that controls how much the loss focuses on hard-to-classify examples.

2.0
alpha float

The balancing factor used to address class imbalance.

0.75
Source code in ultralytics/utils/loss.py
27
28
29
30
31
def __init__(self, gamma=2.0, alpha=0.75):
    """
    Store the varifocal loss hyperparameters.

    Args:
        gamma (float): Exponent applied to sigmoid(pred_score) in the negative-sample weight.
        alpha (float): Scale applied to the negative-sample weight.
    """
    super().__init__()
    self.gamma, self.alpha = gamma, alpha

forward

forward(pred_score, gt_score, label)

Compute varifocal loss between predictions and ground truth.

Source code in ultralytics/utils/loss.py
33
34
35
36
37
38
39
40
41
42
def forward(self, pred_score, gt_score, label):
    """
    Compute varifocal loss between predictions and ground truth.

    Args:
        pred_score (torch.Tensor): Predicted logits.
        gt_score (torch.Tensor): Target scores, used both as the BCE targets and as the weight where `label` is set.
        label (torch.Tensor): Mask selecting positives (weighted by `gt_score`) versus negatives (weighted by
            alpha * sigmoid(pred_score) ** gamma); presumably 0/1 valued.

    Returns:
        (torch.Tensor): Weighted BCE averaged over dim 1 and summed over the remaining dims.
    """
    neg_weight = self.alpha * pred_score.sigmoid().pow(self.gamma) * (1 - label)
    weight = neg_weight + gt_score * label
    with autocast(enabled=False):
        # BCE is evaluated on float32 copies with autocast turned off (presumably to avoid low-precision logits).
        per_elem = F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction="none")
        loss = (per_elem * weight).mean(1).sum()
    return loss





ultralytics.utils.loss.FocalLoss

FocalLoss(gamma=1.5, alpha=0.25)

Bases: Module

Focal loss computed on top of binary cross-entropy with logits, i.e. criteria = FocalLoss(gamma=1.5, alpha=0.25).

Parameters:

Name Type Description Default
gamma float

The focusing parameter that controls how much the loss focuses on hard-to-classify examples.

1.5
alpha float

The balancing factor used to address class imbalance.

0.25
Source code in ultralytics/utils/loss.py
54
55
56
57
58
def __init__(self, gamma=1.5, alpha=0.25):
    """
    Initialize FocalLoss with its focusing and balancing parameters.

    Args:
        gamma (float): Exponent of the modulating factor (1 - p_t) ** gamma that down-weights easy examples.
        alpha (float): Weight for positive targets; negatives are weighted by 1 - alpha. A value <= 0 disables
            alpha balancing entirely.
    """
    super().__init__()
    self.gamma = gamma
    self.alpha = alpha

forward

forward(pred, label)

Calculate focal loss with modulating factors for class imbalance.

Source code in ultralytics/utils/loss.py
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
def forward(self, pred, label):
    """
    Calculate focal loss from logits.

    Follows the TensorFlow Addons formulation
    (https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py).

    Args:
        pred (torch.Tensor): Predicted logits; dim 1 is averaged in the reduction.
        label (torch.Tensor): Targets with the same shape as `pred`.

    Returns:
        (torch.Tensor): Scalar loss, i.e. the per-element focal loss averaged over dim 1 and summed over the rest.
    """
    bce = F.binary_cross_entropy_with_logits(pred, label, reduction="none")
    prob = pred.sigmoid()
    # p_t is the probability the model assigns to the target outcome.
    p_t = label * prob + (1 - label) * (1 - prob)
    bce *= (1.0 - p_t) ** self.gamma
    if self.alpha > 0:
        bce *= label * self.alpha + (1 - label) * (1 - self.alpha)
    return bce.mean(1).sum()





ultralytics.utils.loss.DFLoss

DFLoss(reg_max=16)

Bases: Module

Criterion class for computing Distribution Focal Loss (DFL).

Source code in ultralytics/utils/loss.py
80
81
82
83
def __init__(self, reg_max=16) -> None:
    """
    Set up the Distribution Focal Loss criterion.

    Args:
        reg_max (int): Number of distribution bins per box side; targets are clamped just below reg_max - 1.
    """
    super().__init__()
    self.reg_max = reg_max

__call__

__call__(pred_dist, target)

Return sum of left and right DFL losses from https://ieeexplore.ieee.org/document/9792391.

Source code in ultralytics/utils/loss.py
85
86
87
88
89
90
91
92
93
94
95
def __call__(self, pred_dist, target):
    """
    Return sum of left and right DFL losses from https://ieeexplore.ieee.org/document/9792391.

    Each continuous target is split between its two neighbouring integer bins. The cross-entropy against each
    bin is weighted by the target's proximity to that bin.

    Args:
        pred_dist (torch.Tensor): Distribution logits of shape (M, reg_max), where M == target.numel().
        target (torch.Tensor): Continuous targets; they are clamped to [0, reg_max - 1 - 0.01]. The caller's
            tensor is not modified.

    Returns:
        (torch.Tensor): Loss of shape (*target.shape[:-1], 1), i.e. averaged over the last dim of `target`.
    """
    # Out-of-place clamp: the previous in-place `clamp_` silently modified the caller's tensor.
    target = target.clamp(0, self.reg_max - 1 - 0.01)
    tl = target.long()  # target left
    tr = tl + 1  # target right
    wl = tr - target  # weight left
    wr = 1 - wl  # weight right
    return (
        F.cross_entropy(pred_dist, tl.view(-1), reduction="none").view(tl.shape) * wl
        + F.cross_entropy(pred_dist, tr.view(-1), reduction="none").view(tl.shape) * wr
    ).mean(-1, keepdim=True)





ultralytics.utils.loss.BboxLoss

BboxLoss(reg_max=16)

Bases: Module

Criterion class for computing training losses for bounding boxes.

Source code in ultralytics/utils/loss.py
101
102
103
104
def __init__(self, reg_max=16):
    """
    Set up the box criterion.

    Args:
        reg_max (int): Number of DFL distribution bins. With one bin or fewer there is no distribution to
            learn, so the DFL term is disabled (stored as None).
    """
    super().__init__()
    dfl = DFLoss(reg_max) if reg_max > 1 else None
    self.dfl_loss = dfl

forward

forward(
    pred_dist,
    pred_bboxes,
    anchor_points,
    target_bboxes,
    target_scores,
    target_scores_sum,
    fg_mask,
)

Compute IoU and DFL losses for bounding boxes.

Source code in ultralytics/utils/loss.py
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
    """
    Compute IoU and DFL losses for bounding boxes.

    Args:
        pred_dist (torch.Tensor): Per-anchor distribution logits; foreground rows are viewed as (-1, reg_max).
        pred_bboxes (torch.Tensor): Predicted boxes in xyxy format (CIoU is computed with xywh=False).
        anchor_points (torch.Tensor): Anchor centres used to express target boxes as per-side distances.
        target_bboxes (torch.Tensor): Assigned ground-truth boxes in xyxy format.
        target_scores (torch.Tensor): Assigned scores; summed over the last dim to weight each foreground anchor.
        target_scores_sum (torch.Tensor | int): Normalizer for both losses.
        fg_mask (torch.Tensor): Boolean mask selecting foreground anchors.

    Returns:
        (tuple[torch.Tensor, torch.Tensor]): (loss_iou, loss_dfl). loss_dfl is a zero tensor when DFL is disabled.
    """
    weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
    iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True)
    loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum

    # DFL loss: explicit None check instead of relying on nn.Module truthiness
    if self.dfl_loss is not None:
        # presumably left/top/right/bottom distances from each anchor, capped at reg_max - 1
        target_ltrb = bbox2dist(anchor_points, target_bboxes, self.dfl_loss.reg_max - 1)
        loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight
        loss_dfl = loss_dfl.sum() / target_scores_sum
    else:
        # Allocate directly on the target device rather than creating on CPU and copying.
        loss_dfl = torch.tensor(0.0, device=pred_dist.device)

    return loss_iou, loss_dfl





ultralytics.utils.loss.RotatedBboxLoss

RotatedBboxLoss(reg_max)

Bases: BboxLoss

Criterion class for computing training losses for rotated bounding boxes.

Source code in ultralytics/utils/loss.py
126
127
128
def __init__(self, reg_max):
    """
    Initialize the RotatedBboxLoss module.

    Args:
        reg_max (int): Number of DFL distribution bins, forwarded to BboxLoss; values <= 1 disable the DFL term.
    """
    super().__init__(reg_max)

forward

forward(
    pred_dist,
    pred_bboxes,
    anchor_points,
    target_bboxes,
    target_scores,
    target_scores_sum,
    fg_mask,
)

Compute IoU and DFL losses for rotated bounding boxes.

Source code in ultralytics/utils/loss.py
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
    """
    Compute IoU and DFL losses for rotated bounding boxes.

    Args:
        pred_dist (torch.Tensor): Per-anchor distribution logits; foreground rows are viewed as (-1, reg_max).
        pred_bboxes (torch.Tensor): Predicted rotated boxes, compared to targets with probiou.
        anchor_points (torch.Tensor): Anchor centres used to express target boxes as per-side distances.
        target_bboxes (torch.Tensor): Assigned rotated ground-truth boxes. The first 4 values are converted with
            xywh2xyxy, so presumably the layout is (x, y, w, h, angle) — TODO confirm.
        target_scores (torch.Tensor): Assigned scores; summed over the last dim to weight each foreground anchor.
        target_scores_sum (torch.Tensor | int): Normalizer for both losses.
        fg_mask (torch.Tensor): Boolean mask selecting foreground anchors.

    Returns:
        (tuple[torch.Tensor, torch.Tensor]): (loss_iou, loss_dfl). loss_dfl is a zero tensor when DFL is disabled.
    """
    weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
    iou = probiou(pred_bboxes[fg_mask], target_bboxes[fg_mask])
    loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum

    # DFL loss: explicit None check instead of relying on nn.Module truthiness
    if self.dfl_loss is not None:
        # DFL ignores the angle: only the axis-aligned part of the target box is used.
        target_ltrb = bbox2dist(anchor_points, xywh2xyxy(target_bboxes[..., :4]), self.dfl_loss.reg_max - 1)
        loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight
        loss_dfl = loss_dfl.sum() / target_scores_sum
    else:
        # Allocate directly on the target device rather than creating on CPU and copying.
        loss_dfl = torch.tensor(0.0, device=pred_dist.device)

    return loss_iou, loss_dfl





ultralytics.utils.loss.KeypointLoss

KeypointLoss(sigmas)

Bases: Module

Criterion class for computing keypoint losses.

Source code in ultralytics/utils/loss.py
150
151
152
153
def __init__(self, sigmas) -> None:
    """
    Store the per-keypoint sigmas.

    Args:
        sigmas (torch.Tensor): Per-keypoint scale factors; forward() uses (2 * sigmas) ** 2 as the OKS variances.
    """
    super().__init__()
    self.sigmas = sigmas

forward

forward(pred_kpts, gt_kpts, kpt_mask, area)

Calculate keypoint loss factor and Euclidean distance loss for keypoints.

Source code in ultralytics/utils/loss.py
155
156
157
158
159
160
161
def forward(self, pred_kpts, gt_kpts, kpt_mask, area):
    """
    Compute an OKS-style keypoint loss.

    Args:
        pred_kpts (torch.Tensor): Predicted keypoints; [..., 0] and [..., 1] are the x and y coordinates.
        gt_kpts (torch.Tensor): Ground-truth keypoints with the same layout as `pred_kpts`.
        kpt_mask (torch.Tensor): Per-keypoint visibility mask of shape (N, K); non-zero entries count as visible.
        area (torch.Tensor): Object areas, broadcast against the per-keypoint terms.

    Returns:
        (torch.Tensor): Scalar mean of (1 - exp(-e)) over masked keypoints. It is rescaled per object by
            K / number_of_visible_keypoints so that objects with few visible keypoints are not under-weighted.
    """
    dx = pred_kpts[..., 0] - gt_kpts[..., 0]
    dy = pred_kpts[..., 1] - gt_kpts[..., 1]
    sq_dist = dx.pow(2) + dy.pow(2)
    visible = torch.sum(kpt_mask != 0, dim=1)
    kpt_loss_factor = kpt_mask.shape[1] / (visible + 1e-9)
    # Exponent as in cocoeval: d / (vars * (area + eps) * 2) with vars = (2 * sigma) ** 2.
    # (The paper formula would be d / (2 * (area * sigma) ** 2 + eps).)
    e = sq_dist / ((2 * self.sigmas).pow(2) * (area + 1e-9) * 2)
    return (kpt_loss_factor.view(-1, 1) * ((1 - torch.exp(-e)) * kpt_mask)).mean()





ultralytics.utils.loss.v8DetectionLoss

v8DetectionLoss(model, tal_topk=10)

Criterion class for computing training losses for YOLOv8 object detection.

Source code in ultralytics/utils/loss.py
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
def __init__(self, model, tal_topk=10):  # model must be de-paralleled
    """
    Set up the detection criterion from a de-paralleled model.

    Args:
        model: Model exposing `parameters()`, `args` (hyperparameters) and `model[-1]` (the Detect head).
        tal_topk (int): `topk` passed to the TaskAlignedAssigner.
    """
    device = next(model.parameters()).device  # device the model weights live on
    hyp = model.args  # hyperparameters
    head = model.model[-1]  # Detect() module
    nc, reg_max = head.nc, head.reg_max

    self.bce = nn.BCEWithLogitsLoss(reduction="none")
    self.hyp = hyp
    self.stride = head.stride  # model strides
    self.nc = nc  # number of classes
    # channels per anchor: 4 * reg_max box-distribution bins followed by nc class logits
    self.no = nc + reg_max * 4
    self.reg_max = reg_max
    self.device = device
    self.use_dfl = reg_max > 1

    self.assigner = TaskAlignedAssigner(topk=tal_topk, num_classes=nc, alpha=0.5, beta=6.0)
    self.bbox_loss = BboxLoss(reg_max).to(device)
    # bin indices 0..reg_max-1, used to turn a softmax over bins into an expected distance
    self.proj = torch.arange(reg_max, dtype=torch.float, device=device)

__call__

__call__(preds, batch)

Calculate the sum of the loss for box, cls and dfl multiplied by batch size.

Source code in ultralytics/utils/loss.py
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
def __call__(self, preds, batch):
    """
    Calculate the sum of the loss for box, cls and dfl multiplied by batch size.

    Args:
        preds: Either a list of per-level feature maps, or a tuple whose element [1] is that list. Each map is
            viewed as (batch, self.no, -1) and split into 4 * reg_max distribution channels followed by nc
            class-logit channels.
        batch (dict): Must provide "batch_idx", "cls" and "bboxes". The boxes are scaled by the image size and
            converted with xywh2xyxy in preprocess, so presumably they arrive as normalized xywh — TODO confirm.

    Returns:
        (tuple[torch.Tensor, torch.Tensor]): (loss * batch_size, loss.detach()), where loss is a 3-element tensor
            ordered (box, cls, dfl) after the hyp.box / hyp.cls / hyp.dfl gains are applied.
    """
    loss = torch.zeros(3, device=self.device)  # box, cls, dfl
    feats = preds[1] if isinstance(preds, tuple) else preds
    # Flatten every level to (b, no, anchors_in_level), concatenate along anchors, then split distribution/class.
    pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
        (self.reg_max * 4, self.nc), 1
    )

    # Anchors-last -> anchors-second: (b, anchors, channels)
    pred_scores = pred_scores.permute(0, 2, 1).contiguous()
    pred_distri = pred_distri.permute(0, 2, 1).contiguous()

    dtype = pred_scores.dtype
    batch_size = pred_scores.shape[0]
    # Input image size recovered from the first (highest-resolution) level and its stride.
    imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
    anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)

    # Targets: rows [batch_idx, cls, box] -> padded per-image tensor (b, max_labels, 5) of [cls, x1, y1, x2, y2]
    targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
    targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
    gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
    # Padding rows produced by preprocess are all-zero; only rows with a positive coordinate sum are real labels.
    mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)

    # Pboxes
    pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)
    # dfl_conf = pred_distri.view(batch_size, -1, 4, self.reg_max).detach().softmax(-1)
    # dfl_conf = (dfl_conf.amax(-1).mean(-1) + dfl_conf.amax(-1).amin(-1)) / 2

    # The assigner works in image pixels, so predicted boxes and anchors are multiplied by their stride.
    _, target_bboxes, target_scores, fg_mask, _ = self.assigner(
        # pred_scores.detach().sigmoid() * 0.8 + dfl_conf.unsqueeze(-1) * 0.2,
        pred_scores.detach().sigmoid(),
        (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
        anchor_points * stride_tensor,
        gt_labels,
        gt_bboxes,
        mask_gt,
    )

    # At least 1 so the normalizations below never divide by zero when nothing is assigned.
    target_scores_sum = max(target_scores.sum(), 1)

    # Cls loss
    # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
    loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE

    # Bbox loss (only when at least one anchor is foreground)
    if fg_mask.sum():
        # Back from pixels to grid units so targets match pred_bboxes / anchor_points.
        target_bboxes /= stride_tensor
        loss[0], loss[2] = self.bbox_loss(
            pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
        )

    loss[0] *= self.hyp.box  # box gain
    loss[1] *= self.hyp.cls  # cls gain
    loss[2] *= self.hyp.dfl  # dfl gain

    return loss * batch_size, loss.detach()  # loss(box, cls, dfl)

bbox_decode

bbox_decode(anchor_points, pred_dist)

Decode predicted object bounding box coordinates from anchor points and distribution.

Source code in ultralytics/utils/loss.py
204
205
206
207
208
209
210
211
def bbox_decode(self, anchor_points, pred_dist):
    """
    Decode predicted object bounding box coordinates from anchor points and distribution.

    When DFL is enabled, each side's reg_max logits are turned into an expected distance: softmax over the bins,
    then a dot product with the bin indices stored in `self.proj`. Otherwise `pred_dist` is used as-is.

    Args:
        anchor_points (torch.Tensor): Anchor centres.
        pred_dist (torch.Tensor): Per-anchor predictions of shape (batch, anchors, channels).

    Returns:
        (torch.Tensor): Boxes in xyxy format, as produced by dist2bbox(..., xywh=False).
    """
    if not self.use_dfl:
        return dist2bbox(pred_dist, anchor_points, xywh=False)
    b, a, c = pred_dist.shape  # batch, anchors, channels
    bin_probs = pred_dist.view(b, a, 4, c // 4).softmax(3)
    expected = bin_probs.matmul(self.proj.type(pred_dist.dtype))
    return dist2bbox(expected, anchor_points, xywh=False)

preprocess

preprocess(targets, batch_size, scale_tensor)

Preprocess targets by converting to tensor format and scaling coordinates.

Source code in ultralytics/utils/loss.py
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
def preprocess(self, targets, batch_size, scale_tensor):
    """
    Pack flat label rows into a zero-padded per-image tensor with scaled xyxy boxes.

    Args:
        targets (torch.Tensor): Rows of shape (N, ne); column 0 is the image index in the batch, columns 1.. are the
            label payload (class followed by box).
        batch_size (int): Number of images in the batch.
        scale_tensor (torch.Tensor): Multiplier applied to the 4 box columns before xywh2xyxy conversion.

    Returns:
        (torch.Tensor): Shape (batch_size, max_labels_per_image, ne - 1), zero-padded. Unused rows are all zeros.
    """
    num_labels, num_cols = targets.shape
    if num_labels == 0:
        return torch.zeros(batch_size, 0, num_cols - 1, device=self.device)

    img_idx = targets[:, 0]  # image index
    counts = img_idx.unique(return_counts=True)[1].to(dtype=torch.int32)
    out = torch.zeros(batch_size, counts.max(), num_cols - 1, device=self.device)
    for b in range(batch_size):
        sel = img_idx == b
        n = sel.sum()
        if n:
            out[b, :n] = targets[sel, 1:]
    out[..., 1:5] = xywh2xyxy(out[..., 1:5].mul_(scale_tensor))
    return out





ultralytics.utils.loss.v8SegmentationLoss

v8SegmentationLoss(model)

Bases: v8DetectionLoss

Criterion class for computing training losses for YOLOv8 segmentation.

Source code in ultralytics/utils/loss.py
273
274
275
276
def __init__(self, model):  # model must be de-paralleled
    """
    Set up the segmentation criterion on top of the detection criterion.

    Args:
        model: De-paralleled model; `model.args.overlap_mask` selects how ground-truth masks are encoded.
    """
    super().__init__(model)
    overlap_mask = model.args.overlap_mask
    self.overlap = overlap_mask

__call__

__call__(preds, batch)

Calculate and return the combined loss for detection and segmentation.

Source code in ultralytics/utils/loss.py
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
def __call__(self, preds, batch):
    """
    Calculate and return the combined loss for detection and segmentation.

    Args:
        preds: Either (feats, pred_masks, proto) directly (when it has length 3) or a container whose element [1]
            is that triple. proto is read as (batch, n_prototypes, mask_h, mask_w).
        batch (dict): Must provide "batch_idx", "cls", "bboxes" and "masks".

    Returns:
        (tuple[torch.Tensor, torch.Tensor]): (loss * batch_size, loss.detach()), where loss is a 4-element tensor
            ordered (box, seg, cls, dfl) after the gains are applied.

    Raises:
        TypeError: When building the targets raises a RuntimeError, which typically indicates a non-segment dataset.
    """
    loss = torch.zeros(4, device=self.device)  # box, seg, cls, dfl
    feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
    batch_size, _, mask_h, mask_w = proto.shape  # batch size, number of mask prototypes, mask height, mask width
    pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
        (self.reg_max * 4, self.nc), 1
    )

    # B, grids, ..
    pred_scores = pred_scores.permute(0, 2, 1).contiguous()
    pred_distri = pred_distri.permute(0, 2, 1).contiguous()
    pred_masks = pred_masks.permute(0, 2, 1).contiguous()

    dtype = pred_scores.dtype
    imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
    anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)

    # Targets (a RuntimeError here is re-raised as a dataset-format TypeError)
    try:
        batch_idx = batch["batch_idx"].view(-1, 1)
        targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
        targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
        gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
        mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
    except RuntimeError as e:
        raise TypeError(
            "ERROR ❌ segment dataset incorrectly formatted or not a segment dataset.\n"
            "This error can occur when incorrectly training a 'segment' model on a 'detect' dataset, "
            "i.e. 'yolo train model=yolo11n-seg.pt data=coco8.yaml'.\nVerify your dataset is a "
            "correctly formatted 'segment' dataset using 'data=coco8-seg.yaml' "
            "as an example.\nSee https://docs.ultralytics.com/datasets/segment/ for help."
        ) from e

    # Pboxes
    pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)

    _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
        pred_scores.detach().sigmoid(),
        (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
        anchor_points * stride_tensor,
        gt_labels,
        gt_bboxes,
        mask_gt,
    )

    # At least 1 so the normalizations below never divide by zero.
    target_scores_sum = max(target_scores.sum(), 1)

    # Cls loss (stored in slot 2 here; the commented VFL line below uses the detection-loss slot numbering)
    # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
    loss[2] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE

    if fg_mask.sum():
        # Bbox loss (targets divided back to grid units; target_bboxes itself stays in pixels for the mask loss)
        loss[0], loss[3] = self.bbox_loss(
            pred_distri,
            pred_bboxes,
            anchor_points,
            target_bboxes / stride_tensor,
            target_scores,
            target_scores_sum,
            fg_mask,
        )
        # Masks loss
        masks = batch["masks"].to(self.device).float()
        if tuple(masks.shape[-2:]) != (mask_h, mask_w):  # downsample
            masks = F.interpolate(masks[None], (mask_h, mask_w), mode="nearest")[0]

        loss[1] = self.calculate_segmentation_loss(
            fg_mask, masks, target_gt_idx, target_bboxes, batch_idx, proto, pred_masks, imgsz, self.overlap
        )

    # WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
    else:
        loss[1] += (proto * 0).sum() + (pred_masks * 0).sum()  # inf sums may lead to nan loss

    loss[0] *= self.hyp.box  # box gain
    loss[1] *= self.hyp.box  # seg gain (reuses hyp.box; no separate seg gain is read here)
    loss[2] *= self.hyp.cls  # cls gain
    loss[3] *= self.hyp.dfl  # dfl gain

    return loss * batch_size, loss.detach()  # loss(box, seg, cls, dfl)

calculate_segmentation_loss

calculate_segmentation_loss(
    fg_mask: Tensor,
    masks: Tensor,
    target_gt_idx: Tensor,
    target_bboxes: Tensor,
    batch_idx: Tensor,
    proto: Tensor,
    pred_masks: Tensor,
    imgsz: Tensor,
    overlap: bool,
) -> torch.Tensor

Calculate the loss for instance segmentation.

Parameters:

Name Type Description Default
fg_mask Tensor

A binary tensor of shape (BS, N_anchors) indicating which anchors are positive.

required
masks Tensor

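Ground truth masks. If overlap is True: shape (BS, H, W), where pixels equal to k + 1 belong to the image's ground-truth instance k. If overlap is False: shape (N_labels_in_batch, H, W), with one binary mask per label, aligned with batch_idx.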

required
target_gt_idx Tensor

Indexes of ground truth objects for each anchor of shape (BS, N_anchors).

required
target_bboxes Tensor

Ground truth bounding boxes for each anchor of shape (BS, N_anchors, 4).

required
batch_idx Tensor

Batch indices of shape (N_labels_in_batch, 1).

required
proto Tensor

Prototype masks of shape (BS, 32, H, W).

required
pred_masks Tensor

Predicted masks for each anchor of shape (BS, N_anchors, 32).

required
imgsz Tensor

Size of the input image as a tensor of shape (2), i.e., (H, W).

required
overlap bool

Whether the masks in masks tensor overlap.

required

Returns:

Type Description
Tensor

The calculated loss for instance segmentation.

Notes

The batch loss can be computed for improved speed at higher memory usage. For example, pred_mask can be computed as follows: pred_mask = torch.einsum('in,nhw->ihw', pred, proto) # (i, 32) @ (32, 160, 160) -> (i, 160, 160)

Source code in ultralytics/utils/loss.py
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
def calculate_segmentation_loss(
    self,
    fg_mask: torch.Tensor,
    masks: torch.Tensor,
    target_gt_idx: torch.Tensor,
    target_bboxes: torch.Tensor,
    batch_idx: torch.Tensor,
    proto: torch.Tensor,
    pred_masks: torch.Tensor,
    imgsz: torch.Tensor,
    overlap: bool,
) -> torch.Tensor:
    """
    Calculate the loss for instance segmentation.

    Args:
        fg_mask (torch.Tensor): A binary tensor of shape (BS, N_anchors) indicating which anchors are positive.
        masks (torch.Tensor): Ground truth masks. If `overlap` is True, shape (BS, H, W), where pixels equal to
            k + 1 belong to the image's ground-truth instance k. If `overlap` is False, shape
            (N_labels_in_batch, H, W), with one binary mask per label, aligned with `batch_idx`.
        target_gt_idx (torch.Tensor): Indexes of ground truth objects for each anchor of shape (BS, N_anchors).
        target_bboxes (torch.Tensor): Ground truth bounding boxes for each anchor of shape (BS, N_anchors, 4).
        batch_idx (torch.Tensor): Batch indices of shape (N_labels_in_batch, 1).
        proto (torch.Tensor): Prototype masks of shape (BS, 32, H, W).
        pred_masks (torch.Tensor): Predicted masks for each anchor of shape (BS, N_anchors, 32).
        imgsz (torch.Tensor): Size of the input image as a tensor of shape (2), i.e., (H, W).
        overlap (bool): Whether the masks in `masks` tensor overlap.

    Returns:
        (torch.Tensor): The calculated loss for instance segmentation, normalized by the number of positive anchors.

    Notes:
        The batch loss can be computed for improved speed at higher memory usage.
        For example, pred_mask can be computed as follows:
            pred_mask = torch.einsum('in,nhw->ihw', pred, proto)  # (i, 32) @ (32, 160, 160) -> (i, 160, 160)
    """
    _, _, mask_h, mask_w = proto.shape
    loss = 0

    # Normalize to 0-1
    target_bboxes_normalized = target_bboxes / imgsz[[1, 0, 1, 0]]

    # Areas of target bboxes (as a fraction of the image area)
    marea = xyxy2xywh(target_bboxes_normalized)[..., 2:].prod(2)

    # Normalize to mask size
    mxyxy = target_bboxes_normalized * torch.tensor([mask_w, mask_h, mask_w, mask_h], device=proto.device)

    # Iterate per image. `masks` is deliberately NOT zipped in: when overlap is False its first dimension is the
    # number of labels, not the batch size. zip() would then stop early and silently skip images whenever the
    # batch holds fewer labels than images.
    per_image = zip(fg_mask, target_gt_idx, pred_masks, proto, mxyxy, marea)
    for i, (fg_mask_i, target_gt_idx_i, pred_masks_i, proto_i, mxyxy_i, marea_i) in enumerate(per_image):
        if fg_mask_i.any():
            mask_idx = target_gt_idx_i[fg_mask_i]
            if overlap:
                # One index map per image: instance k is stored as value k + 1.
                gt_mask = (masks[i] == (mask_idx + 1).view(-1, 1, 1)).float()
            else:
                # One mask per label: select this image's labels, then the assigned ones.
                gt_mask = masks[batch_idx.view(-1) == i][mask_idx]

            loss += self.single_mask_loss(
                gt_mask, pred_masks_i[fg_mask_i], proto_i, mxyxy_i[fg_mask_i], marea_i[fg_mask_i]
            )

        # WARNING: lines below prevents Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
        else:
            loss += (proto * 0).sum() + (pred_masks * 0).sum()  # inf sums may lead to nan loss

    return loss / fg_mask.sum()

single_mask_loss staticmethod

single_mask_loss(
    gt_mask: Tensor, pred: Tensor, proto: Tensor, xyxy: Tensor, area: Tensor
) -> torch.Tensor

Compute the instance segmentation loss for a single image.

Parameters:

Name Type Description Default
gt_mask Tensor

Ground truth mask of shape (n, H, W), where n is the number of objects.

required
pred Tensor

Predicted mask coefficients of shape (n, 32).

required
proto Tensor

Prototype masks of shape (32, H, W).

required
xyxy Tensor

Ground truth bounding boxes in xyxy format, in pixel coordinates of the mask/prototype grid (H, W), of shape (n, 4).

required
area Tensor

Area of each ground truth bounding box of shape (n,).

required

Returns:

Type Description
Tensor

The calculated mask loss for a single image.

Notes

The function uses the equation pred_mask = torch.einsum('in,nhw->ihw', pred, proto) to produce the predicted masks from the prototype masks and predicted mask coefficients.

Source code in ultralytics/utils/loss.py
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
@staticmethod
def single_mask_loss(
    gt_mask: torch.Tensor, pred: torch.Tensor, proto: torch.Tensor, xyxy: torch.Tensor, area: torch.Tensor
) -> torch.Tensor:
    """
    Compute the instance segmentation loss for a single image.

    Args:
        gt_mask (torch.Tensor): Ground truth mask of shape (n, H, W), where n is the number of objects.
        pred (torch.Tensor): Predicted mask coefficients of shape (n, 32).
        proto (torch.Tensor): Prototype masks of shape (32, H, W).
        xyxy (torch.Tensor): Ground truth bounding boxes in xyxy format, of shape (n, 4), in pixel coordinates of
            the (H, W) mask grid. They are not normalized to [0, 1]; calculate_segmentation_loss scales them by the
            mask size before calling. They are used to crop the per-pixel loss.
        area (torch.Tensor): Area of each ground truth bounding box of shape (n,). When called from
            calculate_segmentation_loss this is the normalized area (fraction of the image).

    Returns:
        (torch.Tensor): The calculated mask loss for a single image: the mean cropped per-pixel BCE of each
            object divided by its area, summed over objects.

    Notes:
        The function uses the equation pred_mask = torch.einsum('in,nhw->ihw', pred, proto) to produce the
        predicted masks from the prototype masks and predicted mask coefficients.
    """
    pred_mask = torch.einsum("in,nhw->ihw", pred, proto)  # (n, 32) @ (32, H, W) -> (n, H, W)
    loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction="none")
    return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).sum()





ultralytics.utils.loss.v8PoseLoss

v8PoseLoss(model)

Bases: v8DetectionLoss

Criterion class for computing training losses for YOLOv8 pose estimation.

Source code in ultralytics/utils/loss.py
456
457
458
459
460
461
462
463
464
def __init__(self, model):  # model must be de-paralleled
    """
    Set up the pose criterion on top of the detection criterion.

    Args:
        model: De-paralleled model whose last module exposes `kpt_shape` (num_keypoints, dims).
    """
    super().__init__(model)
    self.kpt_shape = model.model[-1].kpt_shape
    self.bce_pose = nn.BCEWithLogitsLoss()
    nkpt = self.kpt_shape[0]  # number of keypoints
    if self.kpt_shape == [17, 3]:
        # presumably the COCO 17-keypoint layout, for which the OKS_SIGMA table applies
        sigmas = torch.from_numpy(OKS_SIGMA).to(self.device)
    else:
        # any other layout: uniform sigmas
        sigmas = torch.ones(nkpt, device=self.device) / nkpt
    self.keypoint_loss = KeypointLoss(sigmas=sigmas)

__call__

__call__(preds, batch)

Calculate the total loss and detach it for pose estimation.

Source code in ultralytics/utils/loss.py
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
def __call__(self, preds, batch):
    """
    Calculate the total loss and detach it for pose estimation.

    Args:
        preds: Either (feats, pred_kpts) directly (when preds[0] is a list) or a container whose element [1] is
            that pair. feats are the per-level detection maps.
        batch (dict): Must provide "batch_idx", "cls", "bboxes" and "keypoints". Keypoint x/y are multiplied by
            the image width/height here, so presumably they arrive normalized to [0, 1] — TODO confirm.

    Returns:
        (tuple[torch.Tensor, torch.Tensor]): (loss * batch_size, loss.detach()), where loss is a 5-element tensor
            ordered (box, kpt_location, kpt_visibility, cls, dfl) after the hyperparameter gains are applied.
    """
    loss = torch.zeros(5, device=self.device)  # box, kpt_location (pose), kpt_visibility (kobj), cls, dfl
    feats, pred_kpts = preds if isinstance(preds[0], list) else preds[1]
    pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
        (self.reg_max * 4, self.nc), 1
    )

    # B, grids, ..
    pred_scores = pred_scores.permute(0, 2, 1).contiguous()
    pred_distri = pred_distri.permute(0, 2, 1).contiguous()
    pred_kpts = pred_kpts.permute(0, 2, 1).contiguous()

    dtype = pred_scores.dtype
    imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
    anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)

    # Targets
    batch_size = pred_scores.shape[0]
    batch_idx = batch["batch_idx"].view(-1, 1)
    targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
    targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
    gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
    mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)

    # Pboxes
    pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)
    # input viewed as (b, h*w, *kpt_shape), e.g. (b, h*w, 17, 3); presumably kpts_decode keeps that shape
    pred_kpts = self.kpts_decode(anchor_points, pred_kpts.view(batch_size, -1, *self.kpt_shape))  # (b, h*w, 17, 3)

    _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
        pred_scores.detach().sigmoid(),
        (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
        anchor_points * stride_tensor,
        gt_labels,
        gt_bboxes,
        mask_gt,
    )

    # At least 1 so the normalizations below never divide by zero.
    target_scores_sum = max(target_scores.sum(), 1)

    # Cls loss (stored in slot 3 here; the commented VFL line below uses the detection-loss slot numbering)
    # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
    loss[3] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE

    # Bbox loss
    if fg_mask.sum():
        target_bboxes /= stride_tensor
        loss[0], loss[4] = self.bbox_loss(
            pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
        )
        # clone() so the in-place scaling below can never write into batch["keypoints"] when .to()/.float()
        # return the same tensor.
        keypoints = batch["keypoints"].to(self.device).float().clone()
        keypoints[..., 0] *= imgsz[1]
        keypoints[..., 1] *= imgsz[0]

        loss[1], loss[2] = self.calculate_keypoints_loss(
            fg_mask, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes, pred_kpts
        )

    loss[0] *= self.hyp.box  # box gain
    loss[1] *= self.hyp.pose  # pose gain
    loss[2] *= self.hyp.kobj  # kobj gain
    loss[3] *= self.hyp.cls  # cls gain
    loss[4] *= self.hyp.dfl  # dfl gain

    return loss * batch_size, loss.detach()  # loss(box, kpt_location, kpt_visibility, cls, dfl)

calculate_keypoints_loss

calculate_keypoints_loss(
    masks,
    target_gt_idx,
    keypoints,
    batch_idx,
    stride_tensor,
    target_bboxes,
    pred_kpts,
)

Calculate the keypoints loss for the model.

This function calculates the keypoints loss and keypoints object loss for a given batch. The keypoints loss is based on the difference between the predicted keypoints and ground truth keypoints. The keypoints object loss is a binary classification loss that classifies whether a keypoint is present or not.

Parameters:

Name Type Description Default
masks Tensor

Binary mask tensor indicating object presence, shape (BS, N_anchors).

required
target_gt_idx Tensor

Index tensor mapping anchors to ground truth objects, shape (BS, N_anchors).

required
keypoints Tensor

Ground truth keypoints, shape (N_kpts_in_batch, N_kpts_per_object, kpts_dim).

required
batch_idx Tensor

Batch index tensor for keypoints, shape (N_kpts_in_batch, 1).

required
stride_tensor Tensor

Stride tensor for anchors, shape (N_anchors, 1).

required
target_bboxes Tensor

Ground truth boxes in (x1, y1, x2, y2) format, shape (BS, N_anchors, 4).

required
pred_kpts Tensor

Predicted keypoints, shape (BS, N_anchors, N_kpts_per_object, kpts_dim).

required

Returns:

Name Type Description
kpts_loss Tensor

The keypoints loss.

kpts_obj_loss Tensor

The keypoints object loss.

Source code in ultralytics/utils/loss.py
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
def calculate_keypoints_loss(
    self, masks, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes, pred_kpts
):
    """
    Calculate the keypoints loss for the model.

    This function calculates the keypoints loss and keypoints object loss for a given batch. The keypoints loss is
    based on the difference between the predicted keypoints and ground truth keypoints. The keypoints object loss is
    a binary classification loss that classifies whether a keypoint is present or not.

    Args:
        masks (torch.Tensor): Binary mask tensor indicating object presence, shape (BS, N_anchors).
        target_gt_idx (torch.Tensor): Index tensor mapping anchors to ground truth objects, shape (BS, N_anchors).
        keypoints (torch.Tensor): Ground truth keypoints, shape (N_kpts_in_batch, N_kpts_per_object, kpts_dim).
            Each row is one annotated instance; ``batch_idx`` gives the image each row belongs to.
        batch_idx (torch.Tensor): Batch index tensor for keypoints, shape (N_kpts_in_batch, 1).
        stride_tensor (torch.Tensor): Stride tensor for anchors, shape (N_anchors, 1).
        target_bboxes (torch.Tensor): Ground truth boxes in (x1, y1, x2, y2) format, shape (BS, N_anchors, 4).
        pred_kpts (torch.Tensor): Predicted keypoints, shape (BS, N_anchors, N_kpts_per_object, kpts_dim).

    Returns:
        kpts_loss (torch.Tensor | int): The keypoints loss (``0`` when no anchor is foreground).
        kpts_obj_loss (torch.Tensor | int): The keypoints object loss (``0`` when no anchor is foreground, or when
            predictions carry no visibility channel).
    """
    kpts_loss = 0
    kpts_obj_loss = 0

    # No foreground anchors means nothing to supervise. Returning early skips the batching/gather work below and
    # avoids calling ``.max()`` on an empty counts tensor when the batch has no keypoint targets at all.
    if not masks.any():
        return kpts_loss, kpts_obj_loss

    batch_idx = batch_idx.flatten()
    batch_size = len(masks)

    # Largest number of annotated instances (rows of ``keypoints``) found in a single image
    max_kpts = torch.unique(batch_idx, return_counts=True)[1].max()

    # Per-image padded container: (BS, max_instances, N_kpts_per_object, kpts_dim)
    batched_keypoints = torch.zeros(
        (batch_size, max_kpts, keypoints.shape[1], keypoints.shape[2]), device=keypoints.device
    )

    # TODO: any idea how to vectorize this?
    # Fill batched_keypoints with keypoints based on batch_idx (row order within each image is preserved)
    for i in range(batch_size):
        keypoints_i = keypoints[batch_idx == i]
        batched_keypoints[i, : keypoints_i.shape[0]] = keypoints_i

    # Expand dimensions of target_gt_idx to match the shape of batched_keypoints
    target_gt_idx_expanded = target_gt_idx.unsqueeze(-1).unsqueeze(-1)

    # For every anchor, pick the keypoints of the instance it was assigned to: (BS, N_anchors, N_kpts, kpts_dim)
    selected_keypoints = batched_keypoints.gather(
        1, target_gt_idx_expanded.expand(-1, -1, keypoints.shape[1], keypoints.shape[2])
    )

    # Divide coordinates by stride so targets live in the same per-level units as the predictions
    selected_keypoints[..., :2] /= stride_tensor.view(1, -1, 1, 1)

    gt_kpt = selected_keypoints[masks]
    area = xyxy2xywh(target_bboxes[masks])[:, 2:].prod(1, keepdim=True)
    pred_kpt = pred_kpts[masks]
    # With a third (visibility) channel, keypoints whose visibility is 0 are masked out; otherwise all count
    kpt_mask = gt_kpt[..., 2] != 0 if gt_kpt.shape[-1] == 3 else torch.full_like(gt_kpt[..., 0], True)
    kpts_loss = self.keypoint_loss(pred_kpt, gt_kpt, kpt_mask, area)  # pose loss

    if pred_kpt.shape[-1] == 3:
        kpts_obj_loss = self.bce_pose(pred_kpt[..., 2], kpt_mask.float())  # keypoint obj loss

    return kpts_loss, kpts_obj_loss

kpts_decode staticmethod

kpts_decode(anchor_points, pred_kpts)

Decode predicted keypoints to image coordinates.

Source code in ultralytics/utils/loss.py
532
533
534
535
536
537
538
539
@staticmethod
def kpts_decode(anchor_points, pred_kpts):
    """
    Map raw keypoint offsets onto the anchor grid.

    The x/y channels are doubled and shifted by ``anchor - 0.5``; any further channels (e.g. visibility) are
    copied unchanged. The input tensor is not modified.

    Args:
        anchor_points (torch.Tensor): Anchor centres, one (x, y) row per anchor.
        pred_kpts (torch.Tensor): Raw keypoint predictions whose second-to-last axes align with the anchors.

    Returns:
        (torch.Tensor): A new tensor with decoded keypoint coordinates.
    """
    decoded = pred_kpts.clone()
    decoded[..., :2] *= 2.0
    x_shift = anchor_points[:, [0]] - 0.5
    y_shift = anchor_points[:, [1]] - 0.5
    decoded[..., 0] += x_shift
    decoded[..., 1] += y_shift
    return decoded





ultralytics.utils.loss.v8ClassificationLoss

Criterion class for computing training losses for classification.

__call__

__call__(preds, batch)

Compute the classification loss between predictions and true labels.

Source code in ultralytics/utils/loss.py
611
612
613
614
615
616
def __call__(self, preds, batch):
    """
    Return the mean cross-entropy between class logits and ``batch["cls"]``.

    When ``preds`` is a list or tuple, its second element is taken as the logits.

    Returns:
        (tuple): ``(loss, loss.detach())``.
    """
    if isinstance(preds, (list, tuple)):
        preds = preds[1]
    loss = F.cross_entropy(preds, batch["cls"], reduction="mean")
    return loss, loss.detach()





ultralytics.utils.loss.v8OBBLoss

v8OBBLoss(model)

Bases: v8DetectionLoss

Calculates losses for object detection, classification, and box distribution in rotated YOLO models.

Source code in ultralytics/utils/loss.py
622
623
624
625
626
def __init__(self, model):
    """
    Build the oriented-box criterion on top of the detection loss (model must be de-paralleled).

    After the parent setup, installs a rotated task-aligned assigner and a rotated box-regression loss.
    """
    super().__init__(model)
    num_classes = self.nc
    reg_max = self.reg_max
    self.assigner = RotatedTaskAlignedAssigner(topk=10, num_classes=num_classes, alpha=0.5, beta=6.0)
    self.bbox_loss = RotatedBboxLoss(reg_max).to(self.device)

__call__

__call__(preds, batch)

Calculate and return the loss for oriented bounding box detection.

Source code in ultralytics/utils/loss.py
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
def __call__(self, preds, batch):
    """
    Calculate and return the loss for oriented bounding box detection.

    Args:
        preds (list | tuple): Either ``(feats, pred_angle)`` directly (when ``preds[0]`` is a list) or a container
            whose second element is ``(feats, pred_angle)``.
        batch (dict): Batch with ``"batch_idx"``, ``"cls"`` and ``"bboxes"`` tensors. ``"bboxes"`` is viewed as 5
            values per target (x, y, w, h, r); x/y/w/h are multiplied by the image size below, so they are expected
            to be normalized.

    Returns:
        (torch.Tensor): Loss vector (box, cls, dfl) multiplied by the batch size.
        (torch.Tensor): Detached loss vector (box, cls, dfl) for logging.

    Raises:
        TypeError: If the batch cannot be assembled into OBB targets (e.g. a non-OBB dataset).
    """
    loss = torch.zeros(3, device=self.device)  # box, cls, dfl
    feats, pred_angle = preds if isinstance(preds[0], list) else preds[1]
    batch_size = pred_angle.shape[0]  # batch size
    pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
        (self.reg_max * 4, self.nc), 1
    )

    # b, grids, ..
    pred_scores = pred_scores.permute(0, 2, 1).contiguous()
    pred_distri = pred_distri.permute(0, 2, 1).contiguous()
    pred_angle = pred_angle.permute(0, 2, 1).contiguous()

    dtype = pred_scores.dtype
    imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
    anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)

    # targets
    try:
        batch_idx = batch["batch_idx"].view(-1, 1)
        # Columns: 0 image index, 1 class, 2 x, 3 y, 4 w, 5 h, 6 rotation
        targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"].view(-1, 5)), 1)
        # imgsz is (h, w): widths scale with imgsz[1] and heights with imgsz[0], matching scale_tensor below
        rw, rh = targets[:, 4] * imgsz[1].item(), targets[:, 5] * imgsz[0].item()
        targets = targets[(rw >= 2) & (rh >= 2)]  # filter rboxes of tiny size to stabilize training
        targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
        gt_labels, gt_bboxes = targets.split((1, 5), 2)  # cls, xywhr
        mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)  # zero-padded rows are not real targets
    except RuntimeError as e:
        raise TypeError(
            "ERROR ❌ OBB dataset incorrectly formatted or not a OBB dataset.\n"
            "This error can occur when incorrectly training a 'OBB' model on a 'detect' dataset, "
            "i.e. 'yolo train model=yolo11n-obb.pt data=dota8.yaml'.\nVerify your dataset is a "
            "correctly formatted 'OBB' dataset using 'data=dota8.yaml' "
            "as an example.\nSee https://docs.ultralytics.com/datasets/obb/ for help."
        ) from e

    # Pboxes
    pred_bboxes = self.bbox_decode(anchor_points, pred_distri, pred_angle)  # rotated boxes + angle, (b, h*w, 5)

    bboxes_for_assigner = pred_bboxes.clone().detach()
    # Only the first four elements need to be scaled (the angle is scale-free)
    bboxes_for_assigner[..., :4] *= stride_tensor
    _, target_bboxes, target_scores, fg_mask, _ = self.assigner(
        pred_scores.detach().sigmoid(),
        bboxes_for_assigner.type(gt_bboxes.dtype),
        anchor_points * stride_tensor,
        gt_labels,
        gt_bboxes,
        mask_gt,
    )

    target_scores_sum = max(target_scores.sum(), 1)

    # Cls loss
    # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
    loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE

    # Bbox loss
    if fg_mask.sum():
        target_bboxes[..., :4] /= stride_tensor
        loss[0], loss[2] = self.bbox_loss(
            pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
        )
    else:
        # Keep pred_angle in the graph so it receives (zero) gradients when there are no positives
        loss[0] += (pred_angle * 0).sum()

    loss[0] *= self.hyp.box  # box gain
    loss[1] *= self.hyp.cls  # cls gain
    loss[2] *= self.hyp.dfl  # dfl gain

    return loss * batch_size, loss.detach()  # loss(box, cls, dfl)

bbox_decode

bbox_decode(anchor_points, pred_dist, pred_angle)

Decode predicted object bounding box coordinates from anchor points and distribution.

Parameters:

Name Type Description Default
anchor_points Tensor

Anchor points, (h*w, 2).

required
pred_dist Tensor

Predicted rotated distance, (bs, h*w, 4).

required
pred_angle Tensor

Predicted angle, (bs, h*w, 1).

required

Returns:

Type Description
Tensor

Predicted rotated bounding boxes with angles, (bs, h*w, 5).

Source code in ultralytics/utils/loss.py
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
def bbox_decode(self, anchor_points, pred_dist, pred_angle):
    """
    Turn per-anchor distance predictions and angles into rotated boxes.

    Args:
        anchor_points (torch.Tensor): Anchor points, (h*w, 2).
        pred_dist (torch.Tensor): Predicted rotated distance, (bs, h*w, 4) — or DFL bin logits when ``use_dfl``.
        pred_angle (torch.Tensor): Predicted angle, (bs, h*w, 1).

    Returns:
        (torch.Tensor): Predicted rotated bounding boxes with angles, (bs, h*w, 5).
    """
    if self.use_dfl:
        # Softmax over each side's bins, then reduce them with the projection vector ``self.proj``
        proj = self.proj.type(pred_dist.dtype)
        n_batch, n_anchor, n_chan = pred_dist.shape
        bins = pred_dist.view(n_batch, n_anchor, 4, n_chan // 4).softmax(3)
        pred_dist = bins.matmul(proj)
    rotated = dist2rbox(pred_dist, pred_angle, anchor_points)
    return torch.cat((rotated, pred_angle), dim=-1)

preprocess

preprocess(targets, batch_size, scale_tensor)

Preprocess targets for oriented bounding box detection.

Source code in ultralytics/utils/loss.py
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
def preprocess(self, targets, batch_size, scale_tensor):
    """
    Group flat OBB targets into a zero-padded per-image tensor.

    Args:
        targets (torch.Tensor): Rows of (image_idx, cls, x, y, w, h, r).
        batch_size (int): Number of images in the batch.
        scale_tensor (torch.Tensor): Multiplier applied to the x, y, w, h columns.

    Returns:
        (torch.Tensor): Shape (batch_size, max_targets_per_image, 6) holding (cls, x, y, w, h, r).
    """
    if not targets.shape[0]:
        return torch.zeros(batch_size, 0, 6, device=self.device)

    img_idx = targets[:, 0]
    per_image = img_idx.unique(return_counts=True)[1].to(dtype=torch.int32)
    out = torch.zeros(batch_size, per_image.max(), 6, device=self.device)
    for b in range(batch_size):
        sel = img_idx == b
        count = sel.sum()
        if not count:
            continue
        boxes = targets[sel, 2:]  # boolean indexing copies, so the in-place scale leaves ``targets`` untouched
        boxes[..., :4].mul_(scale_tensor)
        out[b, :count] = torch.cat([targets[sel, 1:2], boxes], dim=-1)
    return out





ultralytics.utils.loss.E2EDetectLoss

E2EDetectLoss(model)

Criterion class for computing training losses for end-to-end detection.

Source code in ultralytics/utils/loss.py
738
739
740
741
def __init__(self, model):
    """Create two detection criteria on the same model, differing only in ``tal_topk`` (10 vs. 1)."""
    self.one2many, self.one2one = (v8DetectionLoss(model, tal_topk=k) for k in (10, 1))

__call__

__call__(preds, batch)

Calculate the combined loss by summing the one-to-many and one-to-one detection losses (box, cls and dfl, each multiplied by batch size).

Source code in ultralytics/utils/loss.py
743
744
745
746
747
748
749
750
def __call__(self, preds, batch):
    """
    Sum the one-to-many and one-to-one detection losses.

    ``preds`` is a dict with ``"one2many"`` and ``"one2one"`` entries, optionally wrapped as the second item of a
    tuple. Both the differentiable losses and the logging items are added element-wise.
    """
    if isinstance(preds, tuple):
        preds = preds[1]
    many = self.one2many(preds["one2many"], batch)
    one = self.one2one(preds["one2one"], batch)
    return many[0] + one[0], many[1] + one[1]





ultralytics.utils.loss.TVPDetectLoss

TVPDetectLoss(model)

Criterion class for computing training losses for text-visual prompt detection.

Source code in ultralytics/utils/loss.py
756
757
758
759
760
761
762
def __init__(self, model):
    """Set up the visual-prompt detection criterion and snapshot its original head configuration."""
    criterion = v8DetectionLoss(model)
    self.vp_criterion = criterion
    # NOTE: nc/no/reg_max on the criterion may change during __call__, so keep the original values here
    self.ori_nc, self.ori_no, self.ori_reg_max = criterion.nc, criterion.no, criterion.reg_max

__call__

__call__(preds, batch)

Calculate the loss for text-visual prompt detection.

Source code in ultralytics/utils/loss.py
764
765
766
767
768
769
770
771
772
773
774
775
776
def __call__(self, preds, batch):
    """
    Calculate the loss for text-visual prompt detection.

    Args:
        preds (tuple | list): Feature maps, or a tuple whose second element holds the feature maps.
        batch (dict): Training batch forwarded to the visual-prompt criterion.

    Returns:
        (torch.Tensor): The classification term of the visual-prompt loss, or a zero vector of length 3 when the
            features have exactly the original head's channel count.
        (torch.Tensor): Loss items from the visual-prompt criterion (or the detached zero vector).
    """
    feats = preds[1] if isinstance(preds, tuple) else preds
    assert self.ori_reg_max == self.vp_criterion.reg_max  # TODO: remove it

    # Exactly reg_max * 4 + nc channels: presumably no extra visual-prompt channels to supervise
    if self.ori_reg_max * 4 + self.ori_nc == feats[0].shape[1]:
        loss = torch.zeros(3, device=self.vp_criterion.device, requires_grad=True)
        return loss, loss.detach()

    vp_feats = self._get_vp_features(feats)
    vp_loss = self.vp_criterion(vp_feats, batch)
    # Detection losses are ordered (box, cls, dfl); index 1 is the classification term
    cls_loss = vp_loss[0][1]
    return cls_loss, vp_loss[1]





ultralytics.utils.loss.TVPSegmentLoss

TVPSegmentLoss(model)

Bases: TVPDetectLoss

Criterion class for computing training losses for text-visual prompt segmentation.

Source code in ultralytics/utils/loss.py
795
796
797
def __init__(self, model):
    """
    Initialize TVPSegmentLoss with task-prompt and visual-prompt criteria using the provided model.

    The parent initializer must run first. It records ``ori_nc``, ``ori_no`` and ``ori_reg_max``, which the
    visual-prompt handling presumably relies on (including the inherited ``_get_vp_features``). The detection
    criterion it creates is then replaced by a segmentation criterion.
    """
    super().__init__(model)
    self.vp_criterion = v8SegmentationLoss(model)

__call__

__call__(preds, batch)

Calculate the loss for text-visual prompt segmentation.

Source code in ultralytics/utils/loss.py
799
800
801
802
803
804
805
806
807
808
809
810
811
def __call__(self, preds, batch):
    """
    Calculate the loss for text-visual prompt segmentation.

    Uses the original head configuration (``ori_reg_max``, ``ori_nc``) and ``vp_criterion``, as
    ``TVPDetectLoss.__call__`` does. No ``tp_criterion`` attribute exists on these classes.

    Args:
        preds (tuple): ``(feats, pred_masks, proto)``, or a 2-tuple whose second element is that triple.
        batch (dict): Training batch forwarded to the visual-prompt segmentation criterion.

    Returns:
        (torch.Tensor): The classification term of the segmentation loss, or a zero vector of length 4 when the
            features have exactly the original head's channel count.
        (torch.Tensor): Loss items from the visual-prompt criterion (or the detached zero vector).
    """
    feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
    assert self.ori_reg_max == self.vp_criterion.reg_max  # TODO: remove it

    # Exactly reg_max * 4 + nc channels: presumably no extra visual-prompt channels to supervise
    if self.ori_reg_max * 4 + self.ori_nc == feats[0].shape[1]:
        loss = torch.zeros(4, device=self.vp_criterion.device, requires_grad=True)
        return loss, loss.detach()

    vp_feats = self._get_vp_features(feats)
    vp_loss = self.vp_criterion((vp_feats, pred_masks, proto), batch)
    # Segmentation losses are ordered (box, seg, cls, dfl); index 2 is the classification term
    cls_loss = vp_loss[0][2]
    return cls_loss, vp_loss[1]





📅 Created 1 year ago ✏️ Updated 25 days ago