Reference for ultralytics/utils/tal.py

Note

This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/utils/tal.py. If you spot a problem, please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!


ultralytics.utils.tal.TaskAlignedAssigner

TaskAlignedAssigner(topk=13, num_classes=80, alpha=1.0, beta=6.0, eps=1e-09)

Bases: Module

A task-aligned assigner for object detection.

This class assigns ground-truth (gt) objects to anchors based on the task-aligned metric, which combines both classification and localization information.

Attributes:

topk (int): The number of top candidates to consider.
num_classes (int): The number of object classes.
bg_idx (int): Background class index, set to num_classes.
alpha (float): Exponent applied to the classification score in the task-aligned metric.
beta (float): Exponent applied to the IoU (localization) term in the task-aligned metric.
eps (float): A small value to prevent division by zero.

Source code in ultralytics/utils/tal.py
def __init__(self, topk=13, num_classes=80, alpha=1.0, beta=6.0, eps=1e-9):
    """Initialize a TaskAlignedAssigner object with customizable hyperparameters."""
    super().__init__()
    self.topk = topk
    self.num_classes = num_classes
    self.bg_idx = num_classes
    self.alpha = alpha
    self.beta = beta
    self.eps = eps

forward

forward(pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt)

Compute the task-aligned assignment.

Parameters:

pd_scores (Tensor, required): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
pd_bboxes (Tensor, required): Predicted bounding boxes with shape (bs, num_total_anchors, 4).
anc_points (Tensor, required): Anchor points with shape (num_total_anchors, 2).
gt_labels (Tensor, required): Ground truth labels with shape (bs, n_max_boxes, 1).
gt_bboxes (Tensor, required): Ground truth boxes with shape (bs, n_max_boxes, 4).
mask_gt (Tensor, required): Mask for valid ground truth boxes with shape (bs, n_max_boxes, 1).

Returns:

target_labels (Tensor): Target labels with shape (bs, num_total_anchors).
target_bboxes (Tensor): Target bounding boxes with shape (bs, num_total_anchors, 4).
target_scores (Tensor): Target scores with shape (bs, num_total_anchors, num_classes).
fg_mask (Tensor): Foreground mask with shape (bs, num_total_anchors).
target_gt_idx (Tensor): Target ground truth indices with shape (bs, num_total_anchors).

References

https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py

Source code in ultralytics/utils/tal.py
@torch.no_grad()
def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
    """
    Compute the task-aligned assignment.

    Args:
        pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
        pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4).
        anc_points (torch.Tensor): Anchor points with shape (num_total_anchors, 2).
        gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1).
        gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4).
        mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, 1).

    Returns:
        target_labels (torch.Tensor): Target labels with shape (bs, num_total_anchors).
        target_bboxes (torch.Tensor): Target bounding boxes with shape (bs, num_total_anchors, 4).
        target_scores (torch.Tensor): Target scores with shape (bs, num_total_anchors, num_classes).
        fg_mask (torch.Tensor): Foreground mask with shape (bs, num_total_anchors).
        target_gt_idx (torch.Tensor): Target ground truth indices with shape (bs, num_total_anchors).

    References:
        https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py
    """
    self.bs = pd_scores.shape[0]
    self.n_max_boxes = gt_bboxes.shape[1]
    device = gt_bboxes.device

    if self.n_max_boxes == 0:
        return (
            torch.full_like(pd_scores[..., 0], self.bg_idx),
            torch.zeros_like(pd_bboxes),
            torch.zeros_like(pd_scores),
            torch.zeros_like(pd_scores[..., 0]),
            torch.zeros_like(pd_scores[..., 0]),
        )

    try:
        return self._forward(pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt)
    except torch.cuda.OutOfMemoryError:
        # Move tensors to CPU, compute, then move back to original device
        LOGGER.warning("CUDA OutOfMemoryError in TaskAlignedAssigner, using CPU")
        cpu_tensors = [t.cpu() for t in (pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt)]
        result = self._forward(*cpu_tensors)
        return tuple(t.to(device) for t in result)
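
The try/except in `forward` follows a common "fast path first, slower path on failure" pattern. The sketch below illustrates only that control flow in plain Python. The names `run_with_fallback`, `gpu_sum`, and `cpu_sum` are hypothetical, and the built-in `MemoryError` stands in for `torch.cuda.OutOfMemoryError`.

```python
def run_with_fallback(primary, fallback, *args):
    """Try the fast path first; if it runs out of memory, use the fallback."""
    try:
        return primary(*args)
    except MemoryError:  # stand-in for torch.cuda.OutOfMemoryError
        return fallback(*args)


def gpu_sum(values):
    """Hypothetical fast path that fails on inputs it considers too large."""
    if len(values) > 3:
        raise MemoryError("simulated out-of-memory")
    return sum(values)


def cpu_sum(values):
    """Hypothetical slower path that always succeeds."""
    return sum(values)


small = run_with_fallback(gpu_sum, cpu_sum, [1, 2, 3])        # fast path succeeds
large = run_with_fallback(gpu_sum, cpu_sum, [1, 2, 3, 4, 5])  # falls back
```

In the source above, the fallback additionally moves every input to the CPU before the retry and moves the results back to the original device afterwards.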

get_box_metrics

get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_gt)

Compute alignment metric given predicted and ground truth bounding boxes.

Parameters:

pd_scores (Tensor, required): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
pd_bboxes (Tensor, required): Predicted bounding boxes with shape (bs, num_total_anchors, 4).
gt_labels (Tensor, required): Ground truth labels with shape (bs, n_max_boxes, 1).
gt_bboxes (Tensor, required): Ground truth boxes with shape (bs, n_max_boxes, 4).
mask_gt (Tensor, required): Mask for valid ground truth boxes with shape (bs, n_max_boxes, h*w).

Returns:

align_metric (Tensor): Alignment metric combining classification and localization.
overlaps (Tensor): IoU overlaps between predicted and ground truth boxes.

Source code in ultralytics/utils/tal.py
def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_gt):
    """
    Compute alignment metric given predicted and ground truth bounding boxes.

    Args:
        pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
        pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4).
        gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1).
        gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4).
        mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, h*w).

    Returns:
        align_metric (torch.Tensor): Alignment metric combining classification and localization.
        overlaps (torch.Tensor): IoU overlaps between predicted and ground truth boxes.
    """
    na = pd_bboxes.shape[-2]
    mask_gt = mask_gt.bool()  # b, max_num_obj, h*w
    overlaps = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_bboxes.dtype, device=pd_bboxes.device)
    bbox_scores = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_scores.dtype, device=pd_scores.device)

    ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long)  # 2, b, max_num_obj
    ind[0] = torch.arange(end=self.bs).view(-1, 1).expand(-1, self.n_max_boxes)  # b, max_num_obj
    ind[1] = gt_labels.squeeze(-1)  # b, max_num_obj
    # Get the scores of each grid for each gt cls
    bbox_scores[mask_gt] = pd_scores[ind[0], :, ind[1]][mask_gt]  # b, max_num_obj, h*w

    # (b, max_num_obj, 1, 4), (b, 1, h*w, 4)
    pd_boxes = pd_bboxes.unsqueeze(1).expand(-1, self.n_max_boxes, -1, -1)[mask_gt]
    gt_boxes = gt_bboxes.unsqueeze(2).expand(-1, -1, na, -1)[mask_gt]
    overlaps[mask_gt] = self.iou_calculation(gt_boxes, pd_boxes)

    align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)
    return align_metric, overlaps
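
The last step of `get_box_metrics` multiplies the class score raised to `alpha` by the IoU raised to `beta`. As a rough illustration, the plain-Python sketch below evaluates that product for three made-up (score, IoU) pairs using the class defaults `alpha=1.0` and `beta=6.0`. The helper name `align_metric` and the candidate values are assumptions for this example only.

```python
def align_metric(score, iou, alpha=1.0, beta=6.0):
    """Task-aligned metric for one (ground truth, anchor) pair: score**alpha * iou**beta."""
    return (score ** alpha) * (iou ** beta)


# (score for the gt class, IoU with the gt box) for three hypothetical anchors
candidates = [(0.9, 0.5), (0.6, 0.9), (0.3, 0.95)]
metrics = [align_metric(s, u) for s, u in candidates]

# Index of the anchor with the highest alignment metric
best = max(range(len(metrics)), key=metrics.__getitem__)
```

With `beta=6.0` the IoU term dominates: the first anchor has the highest class score but its low IoU pushes its metric far below the other two.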

get_pos_mask

get_pos_mask(pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt)

Get positive mask for each ground truth box.

Parameters:

pd_scores (Tensor, required): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
pd_bboxes (Tensor, required): Predicted bounding boxes with shape (bs, num_total_anchors, 4).
gt_labels (Tensor, required): Ground truth labels with shape (bs, n_max_boxes, 1).
gt_bboxes (Tensor, required): Ground truth boxes with shape (bs, n_max_boxes, 4).
anc_points (Tensor, required): Anchor points with shape (num_total_anchors, 2).
mask_gt (Tensor, required): Mask for valid ground truth boxes with shape (bs, n_max_boxes, 1).

Returns:

mask_pos (Tensor): Positive mask with shape (bs, max_num_obj, h*w).
align_metric (Tensor): Alignment metric with shape (bs, max_num_obj, h*w).
overlaps (Tensor): Overlaps between predicted and ground truth boxes with shape (bs, max_num_obj, h*w).

Source code in ultralytics/utils/tal.py
def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt):
    """
    Get positive mask for each ground truth box.

    Args:
        pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
        pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4).
        gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1).
        gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4).
        anc_points (torch.Tensor): Anchor points with shape (num_total_anchors, 2).
        mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, 1).

    Returns:
        mask_pos (torch.Tensor): Positive mask with shape (bs, max_num_obj, h*w).
        align_metric (torch.Tensor): Alignment metric with shape (bs, max_num_obj, h*w).
        overlaps (torch.Tensor): Overlaps between predicted and ground truth boxes with shape (bs, max_num_obj, h*w).
    """
    mask_in_gts = self.select_candidates_in_gts(anc_points, gt_bboxes)
    # Get anchor_align metric, (b, max_num_obj, h*w)
    align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts * mask_gt)
    # Get topk_metric mask, (b, max_num_obj, h*w)
    mask_topk = self.select_topk_candidates(align_metric, topk_mask=mask_gt.expand(-1, -1, self.topk).bool())
    # Merge all mask to a final mask, (b, max_num_obj, h*w)
    mask_pos = mask_topk * mask_in_gts * mask_gt

    return mask_pos, align_metric, overlaps
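
The final mask is an element-wise product of three conditions, so an anchor is positive for a ground-truth box only when all three hold. A minimal plain-Python sketch for a single ground-truth box and five anchors, with made-up mask values and a hypothetical helper `merge_masks`:

```python
def merge_masks(mask_topk, mask_in_gts, mask_gt):
    """Combine per-anchor masks for one gt box; mask_gt is 1 for a real box, 0 for padding."""
    return [t * g * mask_gt for t, g in zip(mask_topk, mask_in_gts)]


mask_topk = [1, 1, 0, 1, 0]    # anchor is among the top-k by alignment metric
mask_in_gts = [1, 0, 1, 1, 0]  # anchor center lies inside the gt box

mask_pos_real = merge_masks(mask_topk, mask_in_gts, mask_gt=1)
mask_pos_padded = merge_masks(mask_topk, mask_in_gts, mask_gt=0)
```

A padded ground-truth slot (`mask_gt=0`) therefore never produces positive anchors, whatever the other two masks contain.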

get_targets

get_targets(gt_labels, gt_bboxes, target_gt_idx, fg_mask)

Compute target labels, target bounding boxes, and target scores for the positive anchor points.

Parameters:

gt_labels (Tensor, required): Ground truth labels of shape (b, max_num_obj, 1), where b is the batch size and max_num_obj is the maximum number of objects.
gt_bboxes (Tensor, required): Ground truth bounding boxes of shape (b, max_num_obj, 4).
target_gt_idx (Tensor, required): Indices of the assigned ground truth objects for positive anchor points, with shape (b, h*w), where h*w is the total number of anchor points.
fg_mask (Tensor, required): A boolean tensor of shape (b, h*w) indicating the positive (foreground) anchor points.

Returns:

target_labels (Tensor): Shape (b, h*w), containing the target labels for positive anchor points.
target_bboxes (Tensor): Shape (b, h*w, 4), containing the target bounding boxes for positive anchor points.
target_scores (Tensor): Shape (b, h*w, num_classes), containing the target scores for positive anchor points.

Source code in ultralytics/utils/tal.py
def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask):
    """
    Compute target labels, target bounding boxes, and target scores for the positive anchor points.

    Args:
        gt_labels (torch.Tensor): Ground truth labels of shape (b, max_num_obj, 1), where b is the
                            batch size and max_num_obj is the maximum number of objects.
        gt_bboxes (torch.Tensor): Ground truth bounding boxes of shape (b, max_num_obj, 4).
        target_gt_idx (torch.Tensor): Indices of the assigned ground truth objects for positive
                                anchor points, with shape (b, h*w), where h*w is the total
                                number of anchor points.
        fg_mask (torch.Tensor): A boolean tensor of shape (b, h*w) indicating the positive
                          (foreground) anchor points.

    Returns:
        target_labels (torch.Tensor): Shape (b, h*w), containing the target labels for positive anchor points.
        target_bboxes (torch.Tensor): Shape (b, h*w, 4), containing the target bounding boxes for positive
                                      anchor points.
        target_scores (torch.Tensor): Shape (b, h*w, num_classes), containing the target scores for positive
                                      anchor points.
    """
    # Assigned target labels, (b, 1)
    batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]
    target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes  # (b, h*w)
    target_labels = gt_labels.long().flatten()[target_gt_idx]  # (b, h*w)

    # Assigned target boxes, (b, max_num_obj, 4) -> (b, h*w, 4)
    target_bboxes = gt_bboxes.view(-1, gt_bboxes.shape[-1])[target_gt_idx]

    # Assigned target scores
    target_labels.clamp_(0)

    # 10x faster than F.one_hot()
    target_scores = torch.zeros(
        (target_labels.shape[0], target_labels.shape[1], self.num_classes),
        dtype=torch.int64,
        device=target_labels.device,
    )  # (b, h*w, 80)
    target_scores.scatter_(2, target_labels.unsqueeze(-1), 1)

    fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.num_classes)  # (b, h*w, 80)
    target_scores = torch.where(fg_scores_mask > 0, target_scores, 0)

    return target_labels, target_bboxes, target_scores
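
Two ideas in `get_targets` are easier to follow on small lists: offsetting each image's indices by `batch_index * n_max_boxes` so that one flat lookup serves the whole batch, and zeroing the one-hot scores of background anchors. The sketch below is a plain-Python approximation with nested lists in place of tensors. The helper `gather_targets` and its inputs are illustrative assumptions, and box gathering is omitted.

```python
def gather_targets(gt_labels, target_gt_idx, fg_mask, num_classes):
    """gt_labels: [b][n_max_boxes]; target_gt_idx and fg_mask: [b][n_anchors]."""
    n_max_boxes = len(gt_labels[0])
    flat_labels = [lab for row in gt_labels for lab in row]  # length b * n_max_boxes
    target_labels, target_scores = [], []
    for b, (idx_row, fg_row) in enumerate(zip(target_gt_idx, fg_mask)):
        # Shift per-image gt indices so they address the flattened label list
        labels = [flat_labels[i + b * n_max_boxes] for i in idx_row]
        scores = []
        for lab, fg in zip(labels, fg_row):
            one_hot = [0] * num_classes
            if fg:  # background anchors keep an all-zero score row
                one_hot[lab] = 1
            scores.append(one_hot)
        target_labels.append(labels)
        target_scores.append(scores)
    return target_labels, target_scores


gt_labels = [[2, 0], [1, 1]]            # 2 images, 2 gt slots each
target_gt_idx = [[0, 1, 0], [1, 0, 0]]  # assigned gt slot per anchor
fg_mask = [[1, 0, 1], [0, 1, 0]]        # 1 marks a foreground anchor
labels, scores = gather_targets(gt_labels, target_gt_idx, fg_mask, num_classes=3)
```

As in the source above, background anchors still receive a label from whichever ground-truth index they point at; only their score rows are zeroed.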

iou_calculation

iou_calculation(gt_bboxes, pd_bboxes)

Calculate IoU for horizontal (axis-aligned) bounding boxes. The source calls bbox_iou with CIoU=True and clamps the result to be non-negative.

Parameters:

gt_bboxes (Tensor, required): Ground truth boxes.
pd_bboxes (Tensor, required): Predicted boxes.

Returns:

(Tensor): IoU values between each pair of boxes.

Source code in ultralytics/utils/tal.py
def iou_calculation(self, gt_bboxes, pd_bboxes):
    """
    Calculate IoU for horizontal bounding boxes.

    Args:
        gt_bboxes (torch.Tensor): Ground truth boxes.
        pd_bboxes (torch.Tensor): Predicted boxes.

    Returns:
        (torch.Tensor): IoU values between each pair of boxes.
    """
    return bbox_iou(gt_bboxes, pd_bboxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0)
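
For orientation, here is a plain-Python sketch of ordinary IoU between two boxes in (x_min, y_min, x_max, y_max) format. It is not the `bbox_iou` implementation: the source above requests the CIoU variant, which adds penalty terms that this sketch leaves out. The helper name `box_iou_xyxy` and its `eps` default are assumptions.

```python
def box_iou_xyxy(a, b, eps=1e-7):
    """Plain intersection-over-union of two axis-aligned boxes in xyxy format."""
    ax1, ay1, ax2, ay2 = a
    bx1, by1, bx2, by2 = b
    # Width and height of the overlap region, zero when the boxes are disjoint
    iw = max(0.0, min(ax2, bx2) - max(ax1, bx1))
    ih = max(0.0, min(ay2, by2) - max(ay1, by1))
    inter = iw * ih
    union = (ax2 - ax1) * (ay2 - ay1) + (bx2 - bx1) * (by2 - by1) - inter + eps
    return inter / union


iou_overlap = box_iou_xyxy((0, 0, 2, 2), (1, 1, 3, 3))   # intersection 1, union 7
iou_disjoint = box_iou_xyxy((0, 0, 1, 1), (2, 2, 3, 3))  # no overlap
```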

select_candidates_in_gts staticmethod

select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-09)

Select positive anchor centers within ground truth bounding boxes.

Parameters:

xy_centers (Tensor, required): Anchor center coordinates, shape (h*w, 2).
gt_bboxes (Tensor, required): Ground truth bounding boxes, shape (b, n_boxes, 4).
eps (float, default 1e-9): Small value for numerical stability.

Returns:

(Tensor): Boolean mask of positive anchors, shape (b, n_boxes, h*w).

Note

b: batch size, n_boxes: number of ground truth boxes, h: height, w: width. Bounding box format: [x_min, y_min, x_max, y_max].

Source code in ultralytics/utils/tal.py
@staticmethod
def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
    """
    Select positive anchor centers within ground truth bounding boxes.

    Args:
        xy_centers (torch.Tensor): Anchor center coordinates, shape (h*w, 2).
        gt_bboxes (torch.Tensor): Ground truth bounding boxes, shape (b, n_boxes, 4).
        eps (float, optional): Small value for numerical stability. Defaults to 1e-9.

    Returns:
        (torch.Tensor): Boolean mask of positive anchors, shape (b, n_boxes, h*w).

    Note:
        b: batch size, n_boxes: number of ground truth boxes, h: height, w: width.
        Bounding box format: [x_min, y_min, x_max, y_max].
    """
    n_anchors = xy_centers.shape[0]
    bs, n_boxes, _ = gt_bboxes.shape
    lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2)  # left-top, right-bottom
    bbox_deltas = torch.cat((xy_centers[None] - lt, rb - xy_centers[None]), dim=2).view(bs, n_boxes, n_anchors, -1)
    return bbox_deltas.amin(3).gt_(eps)
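
The test reduces to one comparison per anchor: the smallest of the four signed distances from the center to the box edges must exceed `eps`. A plain-Python sketch for a single box, with a hypothetical helper `centers_in_box` and made-up coordinates:

```python
def centers_in_box(xy_centers, box, eps=1e-9):
    """Return one bool per center: True when it lies strictly inside the xyxy box."""
    x1, y1, x2, y2 = box
    # Distances to the left, top, right, and bottom edges; all must be > eps
    return [min(x - x1, y - y1, x2 - x, y2 - y) > eps for x, y in xy_centers]


centers = [(0.5, 0.5), (1.5, 0.5), (2.5, 0.5), (1.0, 1.0)]
mask = centers_in_box(centers, box=(1.0, 0.0, 3.0, 2.0))
```

The last center sits exactly on the left edge, so its smallest distance is 0 and it is rejected, which matches the strict `> eps` comparison in the source above.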

select_highest_overlaps staticmethod

select_highest_overlaps(mask_pos, overlaps, n_max_boxes)

Select anchor boxes with highest IoU when assigned to multiple ground truths.

Parameters:

mask_pos (Tensor, required): Positive mask, shape (b, n_max_boxes, h*w).
overlaps (Tensor, required): IoU overlaps, shape (b, n_max_boxes, h*w).
n_max_boxes (int, required): Maximum number of ground truth boxes.

Returns:

target_gt_idx (Tensor): Indices of assigned ground truths, shape (b, h*w).
fg_mask (Tensor): Foreground mask, shape (b, h*w).
mask_pos (Tensor): Updated positive mask, shape (b, n_max_boxes, h*w).

Source code in ultralytics/utils/tal.py
@staticmethod
def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
    """
    Select anchor boxes with highest IoU when assigned to multiple ground truths.

    Args:
        mask_pos (torch.Tensor): Positive mask, shape (b, n_max_boxes, h*w).
        overlaps (torch.Tensor): IoU overlaps, shape (b, n_max_boxes, h*w).
        n_max_boxes (int): Maximum number of ground truth boxes.

    Returns:
        target_gt_idx (torch.Tensor): Indices of assigned ground truths, shape (b, h*w).
        fg_mask (torch.Tensor): Foreground mask, shape (b, h*w).
        mask_pos (torch.Tensor): Updated positive mask, shape (b, n_max_boxes, h*w).
    """
    # Convert (b, n_max_boxes, h*w) -> (b, h*w)
    fg_mask = mask_pos.sum(-2)
    if fg_mask.max() > 1:  # one anchor is assigned to multiple gt_bboxes
        mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, n_max_boxes, -1)  # (b, n_max_boxes, h*w)
        max_overlaps_idx = overlaps.argmax(1)  # (b, h*w)

        is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device)
        is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1)

        mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos).float()  # (b, n_max_boxes, h*w)
        fg_mask = mask_pos.sum(-2)
    # Find each grid serve which gt(index)
    target_gt_idx = mask_pos.argmax(-2)  # (b, h*w)
    return target_gt_idx, fg_mask, mask_pos
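
When one anchor is positive for several ground-truth boxes, only the box with the highest overlap keeps it. The plain-Python sketch below mirrors that rule for a single image, using nested lists indexed as [box][anchor]. The helper name `resolve_conflicts` and the example values are assumptions.

```python
def resolve_conflicts(mask_pos, overlaps):
    """mask_pos and overlaps are [n_boxes][n_anchors] lists for one image."""
    n_boxes, n_anchors = len(mask_pos), len(mask_pos[0])
    resolved = [row[:] for row in mask_pos]
    for a in range(n_anchors):
        if sum(mask_pos[g][a] for g in range(n_boxes)) > 1:  # claimed by several boxes
            best = max(range(n_boxes), key=lambda g: overlaps[g][a])
            for g in range(n_boxes):
                resolved[g][a] = 1 if g == best else 0
    fg_mask = [sum(resolved[g][a] for g in range(n_boxes)) for a in range(n_anchors)]
    # Index of the box each anchor serves; background anchors fall back to index 0
    target_gt_idx = [max(range(n_boxes), key=lambda g: resolved[g][a]) for a in range(n_anchors)]
    return target_gt_idx, fg_mask, resolved


mask_pos = [[1, 1, 0], [0, 1, 0]]              # anchor 1 is claimed by both boxes
overlaps = [[0.8, 0.3, 0.0], [0.0, 0.6, 0.1]]
target_gt_idx, fg_mask, resolved = resolve_conflicts(mask_pos, overlaps)
```

Because background anchors also get index 0, `target_gt_idx` is only meaningful where `fg_mask` is non-zero.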

select_topk_candidates

select_topk_candidates(metrics, largest=True, topk_mask=None)

Select the top-k candidates based on the given metrics.

Parameters:

metrics (Tensor, required): A tensor of shape (b, max_num_obj, h*w), where b is the batch size, max_num_obj is the maximum number of objects, and h*w represents the total number of anchor points.
largest (bool, default True): If True, select the largest values; otherwise, select the smallest values.
topk_mask (Tensor, default None): An optional boolean tensor of shape (b, max_num_obj, topk), where topk is the number of top candidates to consider. If not provided, the top-k values are automatically computed based on the given metrics.

Returns:

(Tensor): A tensor of shape (b, max_num_obj, h*w) containing the selected top-k candidates.

Source code in ultralytics/utils/tal.py
def select_topk_candidates(self, metrics, largest=True, topk_mask=None):
    """
    Select the top-k candidates based on the given metrics.

    Args:
        metrics (torch.Tensor): A tensor of shape (b, max_num_obj, h*w), where b is the batch size,
                          max_num_obj is the maximum number of objects, and h*w represents the
                          total number of anchor points.
        largest (bool): If True, select the largest values; otherwise, select the smallest values.
        topk_mask (torch.Tensor): An optional boolean tensor of shape (b, max_num_obj, topk), where
                            topk is the number of top candidates to consider. If not provided,
                            the top-k values are automatically computed based on the given metrics.

    Returns:
        (torch.Tensor): A tensor of shape (b, max_num_obj, h*w) containing the selected top-k candidates.
    """
    # (b, max_num_obj, topk)
    topk_metrics, topk_idxs = torch.topk(metrics, self.topk, dim=-1, largest=largest)
    if topk_mask is None:
        topk_mask = (topk_metrics.max(-1, keepdim=True)[0] > self.eps).expand_as(topk_idxs)
    # (b, max_num_obj, topk)
    topk_idxs.masked_fill_(~topk_mask, 0)

    # (b, max_num_obj, topk, h*w) -> (b, max_num_obj, h*w)
    count_tensor = torch.zeros(metrics.shape, dtype=torch.int8, device=topk_idxs.device)
    ones = torch.ones_like(topk_idxs[:, :, :1], dtype=torch.int8, device=topk_idxs.device)
    for k in range(self.topk):
        # Expand topk_idxs for each value of k and add 1 at the specified positions
        count_tensor.scatter_add_(-1, topk_idxs[:, :, k : k + 1], ones)
    # Filter invalid bboxes
    count_tensor.masked_fill_(count_tensor > 1, 0)

    return count_tensor.to(metrics.dtype)
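
The core of the selection can be sketched for one ground-truth box in plain Python: rank the anchors by metric and mark the first `topk` of them. This simplified version (the helper name `topk_mask_for_box` is an assumption) leaves out the handling of padded ground-truth slots, which the source above does by filling masked indices with 0 and then zeroing any count greater than 1.

```python
def topk_mask_for_box(metrics, topk):
    """metrics: alignment metric of one gt box against every anchor."""
    order = sorted(range(len(metrics)), key=lambda i: metrics[i], reverse=True)
    keep = set(order[:topk])  # indices of the topk largest metrics
    return [1 if i in keep else 0 for i in range(len(metrics))]


mask = topk_mask_for_box([0.05, 0.40, 0.00, 0.32, 0.10], topk=2)
```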





ultralytics.utils.tal.RotatedTaskAlignedAssigner

RotatedTaskAlignedAssigner(
    topk=13, num_classes=80, alpha=1.0, beta=6.0, eps=1e-09
)

Bases: TaskAlignedAssigner

Assigns ground-truth objects to rotated bounding boxes using a task-aligned metric.

Source code in ultralytics/utils/tal.py
def __init__(self, topk=13, num_classes=80, alpha=1.0, beta=6.0, eps=1e-9):
    """Initialize a TaskAlignedAssigner object with customizable hyperparameters."""
    super().__init__()
    self.topk = topk
    self.num_classes = num_classes
    self.bg_idx = num_classes
    self.alpha = alpha
    self.beta = beta
    self.eps = eps

iou_calculation

iou_calculation(gt_bboxes, pd_bboxes)

Calculate IoU for rotated bounding boxes using probiou, clamped to be non-negative.

Source code in ultralytics/utils/tal.py
def iou_calculation(self, gt_bboxes, pd_bboxes):
    """Calculate IoU for rotated bounding boxes."""
    return probiou(gt_bboxes, pd_bboxes).squeeze(-1).clamp_(0)

select_candidates_in_gts staticmethod

select_candidates_in_gts(xy_centers, gt_bboxes)

Select positive anchor centers that lie inside rotated ground truth bounding boxes.

Parameters:

xy_centers (Tensor, required): Anchor center coordinates with shape (h*w, 2).
gt_bboxes (Tensor, required): Ground truth bounding boxes with shape (b, n_boxes, 5).

Returns:

(Tensor): Boolean mask of positive anchors with shape (b, n_boxes, h*w).

Source code in ultralytics/utils/tal.py
@staticmethod
def select_candidates_in_gts(xy_centers, gt_bboxes):
    """
    Select the positive anchor center in gt for rotated bounding boxes.

    Args:
        xy_centers (torch.Tensor): Anchor center coordinates with shape (h*w, 2).
        gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape (b, n_boxes, 5).

    Returns:
        (torch.Tensor): Boolean mask of positive anchors with shape (b, n_boxes, h*w).
    """
    # (b, n_boxes, 5) --> (b, n_boxes, 4, 2)
    corners = xywhr2xyxyxyxy(gt_bboxes)
    # (b, n_boxes, 1, 2)
    a, b, _, d = corners.split(1, dim=-2)
    ab = b - a
    ad = d - a

    # (b, n_boxes, h*w, 2)
    ap = xy_centers - a
    norm_ab = (ab * ab).sum(dim=-1)
    norm_ad = (ad * ad).sum(dim=-1)
    ap_dot_ab = (ap * ab).sum(dim=-1)
    ap_dot_ad = (ap * ad).sum(dim=-1)
    return (ap_dot_ab >= 0) & (ap_dot_ab <= norm_ab) & (ap_dot_ad >= 0) & (ap_dot_ad <= norm_ad)  # is_in_box
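
The dot-product test projects the vector from corner `a` to the point onto the two edges that leave `a`; the point is inside when both projections fall between 0 and the squared edge length. A plain-Python sketch for one point and one box, where the helper name `point_in_rotated_box` and the diamond-shaped example are assumptions:

```python
def point_in_rotated_box(p, a, b, d):
    """a, b, d are corners of the rectangle, with b and d both adjacent to a."""
    ab = (b[0] - a[0], b[1] - a[1])
    ad = (d[0] - a[0], d[1] - a[1])
    ap = (p[0] - a[0], p[1] - a[1])
    norm_ab = ab[0] * ab[0] + ab[1] * ab[1]  # squared length of edge ab
    norm_ad = ad[0] * ad[0] + ad[1] * ad[1]  # squared length of edge ad
    ap_dot_ab = ap[0] * ab[0] + ap[1] * ab[1]
    ap_dot_ad = ap[0] * ad[0] + ap[1] * ad[1]
    return 0 <= ap_dot_ab <= norm_ab and 0 <= ap_dot_ad <= norm_ad


# A square rotated by 45 degrees: corners (0,-1), (1,0), (0,1), (-1,0)
a, b, d = (0.0, -1.0), (1.0, 0.0), (-1.0, 0.0)
center_inside = point_in_rotated_box((0.0, 0.0), a, b, d)
corner_region = point_in_rotated_box((0.9, 0.9), a, b, d)
```

The second point lies inside the axis-aligned bounding square of the diamond but outside the diamond itself, which is the case an axis-aligned test would get wrong.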





ultralytics.utils.tal.make_anchors

make_anchors(feats, strides, grid_cell_offset=0.5)

Generate anchor points and a matching per-point stride tensor from feature maps.

Source code in ultralytics/utils/tal.py
def make_anchors(feats, strides, grid_cell_offset=0.5):
    """Generate anchors from features."""
    anchor_points, stride_tensor = [], []
    assert feats is not None
    dtype, device = feats[0].dtype, feats[0].device
    for i, stride in enumerate(strides):
        h, w = feats[i].shape[2:] if isinstance(feats, list) else (int(feats[i][0]), int(feats[i][1]))
        sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset  # shift x
        sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset  # shift y
        sy, sx = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx)
        anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2))
        stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
    return torch.cat(anchor_points), torch.cat(stride_tensor)
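
For a single feature map, the anchor points are the cell centers of an h-by-w grid, expressed in grid units and listed row by row with x varying fastest. A plain-Python sketch of one level, assuming a hypothetical helper `make_anchor_points` and lists in place of tensors:

```python
def make_anchor_points(h, w, stride, offset=0.5):
    """Return (points, strides) for one h-by-w feature map."""
    # Row-major order: x varies fastest, matching the flattened meshgrid
    points = [(x + offset, y + offset) for y in range(h) for x in range(w)]
    strides = [stride] * (h * w)
    return points, strides


points, strides = make_anchor_points(h=2, w=3, stride=8)
```

Multiplying a point by its stride gives the corresponding location in input-image pixels, for example (0.5, 0.5) at stride 8 maps to (4.0, 4.0). `make_anchors` repeats this per level and concatenates the results.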





ultralytics.utils.tal.dist2bbox

dist2bbox(distance, anchor_points, xywh=True, dim=-1)

Transform distance (ltrb) to box (xywh or xyxy).

Source code in ultralytics/utils/tal.py
def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
    """Transform distance(ltrb) to box(xywh or xyxy)."""
    lt, rb = distance.chunk(2, dim)
    x1y1 = anchor_points - lt
    x2y2 = anchor_points + rb
    if xywh:
        c_xy = (x1y1 + x2y2) / 2
        wh = x2y2 - x1y1
        return torch.cat((c_xy, wh), dim)  # xywh bbox
    return torch.cat((x1y1, x2y2), dim)  # xyxy bbox
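
A worked example for one anchor may help: the four distances are measured from the anchor point to the left, top, right, and bottom edges. The sketch uses plain tuples and a hypothetical helper `dist_to_box`.

```python
def dist_to_box(dist, anchor, xywh=True):
    """dist = (left, top, right, bottom) distances from the anchor point."""
    left, top, right, bottom = dist
    ax, ay = anchor
    x1, y1 = ax - left, ay - top
    x2, y2 = ax + right, ay + bottom
    if xywh:
        return ((x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1)  # center and size
    return (x1, y1, x2, y2)  # corner format


box_xyxy = dist_to_box((1.0, 2.0, 3.0, 4.0), anchor=(5.0, 5.0), xywh=False)
box_xywh = dist_to_box((1.0, 2.0, 3.0, 4.0), anchor=(5.0, 5.0), xywh=True)
```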





ultralytics.utils.tal.bbox2dist

bbox2dist(anchor_points, bbox, reg_max)

Transform bbox (xyxy) to dist (ltrb).

Source code in ultralytics/utils/tal.py
def bbox2dist(anchor_points, bbox, reg_max):
    """Transform bbox(xyxy) to dist(ltrb)."""
    x1y1, x2y2 = bbox.chunk(2, -1)
    return torch.cat((anchor_points - x1y1, x2y2 - anchor_points), -1).clamp_(0, reg_max - 0.01)  # dist (lt, rb)
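
This is the inverse direction, with one extra step: each distance is clamped to the range [0, reg_max - 0.01]. A plain-Python sketch for a single anchor and box, where the helper name `box_to_dist` and the values are made up:

```python
def box_to_dist(anchor, box, reg_max):
    """Return clamped (left, top, right, bottom) distances from anchor to an xyxy box."""
    ax, ay = anchor
    x1, y1, x2, y2 = box
    upper = reg_max - 0.01
    raw = (ax - x1, ay - y1, x2 - ax, y2 - ay)
    return tuple(min(max(d, 0.0), upper) for d in raw)


dist = box_to_dist((5.0, 5.0), (4.0, 3.0, 8.0, 30.0), reg_max=16)
outside = box_to_dist((0.0, 0.0), (1.0, 1.0, 2.0, 2.0), reg_max=16)
```

The bottom distance of 25 exceeds the limit and is cut to 15.99. For an anchor outside the box, the negative left and top distances are clamped to 0.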





ultralytics.utils.tal.dist2rbox

dist2rbox(pred_dist, pred_angle, anchor_points, dim=-1)

Decode predicted rotated bounding box coordinates from anchor points and distribution.

Parameters:

pred_dist (Tensor, required): Predicted rotated distance with shape (bs, h*w, 4).
pred_angle (Tensor, required): Predicted angle with shape (bs, h*w, 1).
anchor_points (Tensor, required): Anchor points with shape (h*w, 2).
dim (int, default -1): Dimension along which to split.

Returns:

(Tensor): Predicted rotated bounding boxes with shape (bs, h*w, 4).

Source code in ultralytics/utils/tal.py
def dist2rbox(pred_dist, pred_angle, anchor_points, dim=-1):
    """
    Decode predicted rotated bounding box coordinates from anchor points and distribution.

    Args:
        pred_dist (torch.Tensor): Predicted rotated distance with shape (bs, h*w, 4).
        pred_angle (torch.Tensor): Predicted angle with shape (bs, h*w, 1).
        anchor_points (torch.Tensor): Anchor points with shape (h*w, 2).
        dim (int, optional): Dimension along which to split. Defaults to -1.

    Returns:
        (torch.Tensor): Predicted rotated bounding boxes with shape (bs, h*w, 4).
    """
    lt, rb = pred_dist.split(2, dim=dim)
    cos, sin = torch.cos(pred_angle), torch.sin(pred_angle)
    # (bs, h*w, 1)
    xf, yf = ((rb - lt) / 2).split(1, dim=dim)
    x, y = xf * cos - yf * sin, xf * sin + yf * cos
    xy = torch.cat([x, y], dim=dim) + anchor_points
    return torch.cat([xy, lt + rb], dim=dim)
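
The decoding can be followed for one anchor: the offset of the box center from the anchor is computed in the box's own frame, rotated by the predicted angle, and added to the anchor, while width and height are the sums of opposite distances. A plain-Python sketch, assuming a hypothetical helper `dist_to_rbox` and an angle in radians:

```python
import math


def dist_to_rbox(dist, angle, anchor):
    """Return (cx, cy, w, h) for one anchor; dist = (left, top, right, bottom)."""
    left, top, right, bottom = dist
    # Center offset in the unrotated box frame
    xf, yf = (right - left) / 2, (bottom - top) / 2
    cos, sin = math.cos(angle), math.sin(angle)
    # Rotate the offset by the angle, then shift by the anchor point
    cx = xf * cos - yf * sin + anchor[0]
    cy = xf * sin + yf * cos + anchor[1]
    return (cx, cy, left + right, top + bottom)


unrotated = dist_to_rbox((1.0, 2.0, 3.0, 4.0), 0.0, (5.0, 5.0))
quarter_turn = dist_to_rbox((1.0, 2.0, 3.0, 4.0), math.pi / 2, (5.0, 5.0))
```

With a zero angle the result matches the xywh output of an unrotated decode. A quarter turn moves the center offset from (1, 1) to (-1, 1) while the width and height stay the same.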




