Saltar al contenido

Referencia para ultralytics/utils/tal.py

Nota

Este archivo está disponible en https://github.com/ultralytics/ultralytics/blob/main/ultralytics/utils/tal.py. Si detectas algún problema, por favor, ayuda a solucionarlo contribuyendo con una Pull Request 🛠️. ¡Gracias 🙏!



ultralytics.utils.tal.TaskAlignedAssigner

Bases: Module

Un asignador alineado con la tarea (task-aligned) para la detección de objetos.

Esta clase asigna objetos verdaderos (gt) a los anclajes basándose en la métrica alineada con la tarea, que combina la información de clasificación y la de localización.

Atributos:

Nombre Tipo Descripción
topk int

El número de candidatos principales a tener en cuenta.

num_classes int

El número de clases de objetos.

alpha float

El parámetro alfa para el componente de clasificación de la métrica alineada con la tarea.

beta float

El parámetro beta para el componente de localización de la métrica alineada con la tarea.

eps float

Un valor pequeño para evitar la división por cero.

Código fuente en ultralytics/utils/tal.py
class TaskAlignedAssigner(nn.Module):
    """
    A task-aligned assigner for object detection.

    This class assigns ground-truth (gt) objects to anchors based on the task-aligned metric, which combines both
    classification and localization information.

    Attributes:
        topk (int): The number of top candidates to consider.
        num_classes (int): The number of object classes.
        bg_idx (int): Label used for background anchors; set equal to ``num_classes``.
        alpha (float): The alpha parameter for the classification component of the task-aligned metric.
        beta (float): The beta parameter for the localization component of the task-aligned metric.
        eps (float): A small value to prevent division by zero.
    """

    def __init__(self, topk=13, num_classes=80, alpha=1.0, beta=6.0, eps=1e-9):
        """Initialize a TaskAlignedAssigner object with customizable hyperparameters."""
        super().__init__()
        self.topk = topk
        self.num_classes = num_classes
        self.bg_idx = num_classes
        self.alpha = alpha
        self.beta = beta
        self.eps = eps

    @torch.no_grad()
    def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
        """
        Compute the task-aligned assignment. Reference code is available at
        https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py.

        Args:
            pd_scores (Tensor): shape(bs, num_total_anchors, num_classes)
            pd_bboxes (Tensor): shape(bs, num_total_anchors, 4)
            anc_points (Tensor): shape(num_total_anchors, 2)
            gt_labels (Tensor): shape(bs, n_max_boxes, 1)
            gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
            mask_gt (Tensor): shape(bs, n_max_boxes, 1)

        Returns:
            target_labels (Tensor): shape(bs, num_total_anchors), integer (long) labels
            target_bboxes (Tensor): shape(bs, num_total_anchors, 4)
            target_scores (Tensor): shape(bs, num_total_anchors, num_classes)
            fg_mask (Tensor): shape(bs, num_total_anchors), boolean
            target_gt_idx (Tensor): shape(bs, num_total_anchors), integer (long) indices
        """
        # Cached on the instance because the helper methods below read them.
        self.bs = pd_scores.shape[0]
        self.n_max_boxes = gt_bboxes.shape[1]

        if self.n_max_boxes == 0:
            # No ground truth in the batch: every anchor is background. The dtypes mirror the regular path
            # (long labels, bool foreground mask, long gt indices) so callers see consistent types.
            device = gt_bboxes.device
            return (
                torch.full_like(pd_scores[..., 0], self.bg_idx, dtype=torch.long).to(device),
                torch.zeros_like(pd_bboxes).to(device),
                torch.zeros_like(pd_scores).to(device),
                torch.zeros_like(pd_scores[..., 0], dtype=torch.bool).to(device),
                torch.zeros_like(pd_scores[..., 0], dtype=torch.long).to(device),
            )

        mask_pos, align_metric, overlaps = self.get_pos_mask(
            pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt
        )

        # Resolve anchors claimed by more than one gt.
        target_gt_idx, fg_mask, mask_pos = self.select_highest_overlaps(mask_pos, overlaps, self.n_max_boxes)

        # Assigned target
        target_labels, target_bboxes, target_scores = self.get_targets(gt_labels, gt_bboxes, target_gt_idx, fg_mask)

        # Normalize: rescale each gt's metric so that its best anchor scores that gt's best overlap.
        align_metric *= mask_pos
        pos_align_metrics = align_metric.amax(dim=-1, keepdim=True)  # b, max_num_obj
        pos_overlaps = (overlaps * mask_pos).amax(dim=-1, keepdim=True)  # b, max_num_obj
        norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).amax(-2).unsqueeze(-1)
        target_scores = target_scores * norm_align_metric

        return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx

    def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt):
        """
        Build the positive-candidate mask together with the alignment metric and overlaps.

        Returns:
            mask_pos (Tensor): shape(b, max_num_obj, h*w), product of the top-k, in-gt and valid-gt masks.
            align_metric (Tensor): shape(b, max_num_obj, h*w)
            overlaps (Tensor): shape(b, max_num_obj, h*w)
        """
        mask_in_gts = self.select_candidates_in_gts(anc_points, gt_bboxes)
        # Get anchor_align metric, (b, max_num_obj, h*w)
        align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts * mask_gt)
        # Get topk_metric mask, (b, max_num_obj, h*w)
        mask_topk = self.select_topk_candidates(align_metric, topk_mask=mask_gt.expand(-1, -1, self.topk).bool())
        # Merge all mask to a final mask, (b, max_num_obj, h*w)
        mask_pos = mask_topk * mask_in_gts * mask_gt

        return mask_pos, align_metric, overlaps

    def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_gt):
        """
        Compute alignment metric given predicted and ground truth bounding boxes.

        The metric is ``score ** alpha * overlap ** beta`` and is only evaluated where ``mask_gt`` is set;
        all other entries stay zero.

        Returns:
            align_metric (Tensor): shape(b, max_num_obj, h*w)
            overlaps (Tensor): shape(b, max_num_obj, h*w)
        """
        na = pd_bboxes.shape[-2]
        mask_gt = mask_gt.bool()  # b, max_num_obj, h*w
        overlaps = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_bboxes.dtype, device=pd_bboxes.device)
        bbox_scores = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_scores.dtype, device=pd_scores.device)

        # ind[0] holds the batch index and ind[1] the class index of every gt slot.
        ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long)  # 2, b, max_num_obj
        ind[0] = torch.arange(end=self.bs).view(-1, 1).expand(-1, self.n_max_boxes)  # b, max_num_obj
        ind[1] = gt_labels.squeeze(-1)  # b, max_num_obj
        # Get the scores of each grid for each gt cls
        bbox_scores[mask_gt] = pd_scores[ind[0], :, ind[1]][mask_gt]  # b, max_num_obj, h*w

        # (b, max_num_obj, 1, 4), (b, 1, h*w, 4)
        pd_boxes = pd_bboxes.unsqueeze(1).expand(-1, self.n_max_boxes, -1, -1)[mask_gt]
        gt_boxes = gt_bboxes.unsqueeze(2).expand(-1, -1, na, -1)[mask_gt]
        overlaps[mask_gt] = self.iou_calculation(gt_boxes, pd_boxes)

        align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)
        return align_metric, overlaps

    def iou_calculation(self, gt_bboxes, pd_bboxes):
        """IoU calculation for horizontal bounding boxes (CIoU variant, negative values clamped to 0)."""
        return bbox_iou(gt_bboxes, pd_bboxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0)

    def select_topk_candidates(self, metrics, largest=True, topk_mask=None):
        """
        Select the top-k candidates based on the given metrics.

        Args:
            metrics (Tensor): A tensor of shape (b, max_num_obj, h*w), where b is the batch size,
                              max_num_obj is the maximum number of objects, and h*w represents the
                              total number of anchor points.
            largest (bool): If True, select the largest values; otherwise, select the smallest values.
            topk_mask (Tensor): An optional boolean tensor of shape (b, max_num_obj, topk), where
                                topk is the number of top candidates to consider. If not provided,
                                the top-k values are automatically computed based on the given metrics.

        Returns:
            (Tensor): A tensor of shape (b, max_num_obj, h*w) containing the selected top-k candidates.
        """

        # (b, max_num_obj, topk)
        topk_metrics, topk_idxs = torch.topk(metrics, self.topk, dim=-1, largest=largest)
        if topk_mask is None:
            topk_mask = (topk_metrics.max(-1, keepdim=True)[0] > self.eps).expand_as(topk_idxs)
        # (b, max_num_obj, topk); rows that are masked out all point at anchor 0
        topk_idxs.masked_fill_(~topk_mask, 0)

        # (b, max_num_obj, topk, h*w) -> (b, max_num_obj, h*w)
        count_tensor = torch.zeros(metrics.shape, dtype=torch.int8, device=topk_idxs.device)
        ones = torch.ones_like(topk_idxs[:, :, :1], dtype=torch.int8, device=topk_idxs.device)
        for k in range(self.topk):
            # Expand topk_idxs for each value of k and add 1 at the specified positions
            count_tensor.scatter_add_(-1, topk_idxs[:, :, k : k + 1], ones)
        # Filter invalid bboxes: an anchor counted more than once comes from the masked rows above
        count_tensor.masked_fill_(count_tensor > 1, 0)

        return count_tensor.to(metrics.dtype)

    def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask):
        """
        Compute target labels, target bounding boxes, and target scores for the positive anchor points.

        Args:
            gt_labels (Tensor): Ground truth labels of shape (b, max_num_obj, 1), where b is the
                                batch size and max_num_obj is the maximum number of objects.
            gt_bboxes (Tensor): Ground truth bounding boxes of shape (b, max_num_obj, 4).
            target_gt_idx (Tensor): Indices of the assigned ground truth objects for positive
                                    anchor points, with shape (b, h*w), where h*w is the total
                                    number of anchor points.
            fg_mask (Tensor): A tensor of shape (b, h*w) whose positive entries mark the
                              (foreground) anchor points.

        Returns:
            (Tuple[Tensor, Tensor, Tensor]): A tuple containing the following tensors:
                - target_labels (Tensor): Shape (b, h*w), containing the target labels for
                                          positive anchor points.
                - target_bboxes (Tensor): Shape (b, h*w, 4), containing the target bounding boxes
                                          for positive anchor points.
                - target_scores (Tensor): Shape (b, h*w, num_classes), containing the target scores
                                          for positive anchor points, where num_classes is the number
                                          of object classes.
        """

        # Offset per-image gt indices so they address the flattened (b * max_num_obj) gt arrays
        batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]
        target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes  # (b, h*w)
        target_labels = gt_labels.long().flatten()[target_gt_idx]  # (b, h*w)

        # Assigned target boxes, (b, max_num_obj, 4) -> (b, h*w, 4)
        target_bboxes = gt_bboxes.view(-1, gt_bboxes.shape[-1])[target_gt_idx]

        # Assigned target scores; negative labels are clamped so they are valid scatter indices
        target_labels.clamp_(0)

        # 10x faster than F.one_hot()
        target_scores = torch.zeros(
            (target_labels.shape[0], target_labels.shape[1], self.num_classes),
            dtype=torch.int64,
            device=target_labels.device,
        )  # (b, h*w, 80)
        target_scores.scatter_(2, target_labels.unsqueeze(-1), 1)

        # Zero the one-hot rows of background anchors
        fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.num_classes)  # (b, h*w, 80)
        target_scores = torch.where(fg_scores_mask > 0, target_scores, 0)

        return target_labels, target_bboxes, target_scores

    @staticmethod
    def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
        """
        Select the positive anchor center in gt.

        Args:
            xy_centers (Tensor): shape(h*w, 2)
            gt_bboxes (Tensor): shape(b, n_boxes, 4)
            eps (float): Minimum distance to every box edge for an anchor to count as inside.

        Returns:
            (Tensor): shape(b, n_boxes, h*w), 1 where the anchor center lies strictly inside the box, else 0.
        """
        n_anchors = xy_centers.shape[0]
        bs, n_boxes, _ = gt_bboxes.shape
        lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2)  # left-top, right-bottom
        # Distances from each anchor to the four box edges; all positive means the anchor is inside
        bbox_deltas = torch.cat((xy_centers[None] - lt, rb - xy_centers[None]), dim=2).view(bs, n_boxes, n_anchors, -1)
        return bbox_deltas.amin(3).gt_(eps)

    @staticmethod
    def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
        """
        If an anchor box is assigned to multiple gts, the one with the highest IoU will be selected.

        Args:
            mask_pos (Tensor): shape(b, n_max_boxes, h*w)
            overlaps (Tensor): shape(b, n_max_boxes, h*w)
            n_max_boxes (int): Size of the gt dimension of ``mask_pos`` and ``overlaps``.

        Returns:
            target_gt_idx (Tensor): shape(b, h*w)
            fg_mask (Tensor): shape(b, h*w)
            mask_pos (Tensor): shape(b, n_max_boxes, h*w)
        """
        # (b, n_max_boxes, h*w) -> (b, h*w)
        fg_mask = mask_pos.sum(-2)
        if fg_mask.max() > 1:  # one anchor is assigned to multiple gt_bboxes
            mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, n_max_boxes, -1)  # (b, n_max_boxes, h*w)
            max_overlaps_idx = overlaps.argmax(1)  # (b, h*w)

            is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device)
            is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1)

            mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos).float()  # (b, n_max_boxes, h*w)
            fg_mask = mask_pos.sum(-2)
        # Find each grid serve which gt(index)
        target_gt_idx = mask_pos.argmax(-2)  # (b, h*w)
        return target_gt_idx, fg_mask, mask_pos

__init__(topk=13, num_classes=80, alpha=1.0, beta=6.0, eps=1e-09)

Inicializa un objeto TaskAlignedAssigner con hiperparámetros personalizables.

Código fuente en ultralytics/utils/tal.py
def __init__(self, topk=13, num_classes=80, alpha=1.0, beta=6.0, eps=1e-9):
    """
    Set up the assigner and record its hyperparameters as attributes.

    Args:
        topk (int): Number of top candidates to consider.
        num_classes (int): Number of object classes; also stored as the background index ``bg_idx``.
        alpha (float): Stored as ``self.alpha``.
        beta (float): Stored as ``self.beta``.
        eps (float): Stored as ``self.eps``.
    """
    super().__init__()
    # Metric parameters and numerical guard.
    self.alpha, self.beta, self.eps = alpha, beta, eps
    # Candidate count and class bookkeeping; bg_idx takes the same value as num_classes.
    self.topk = topk
    self.num_classes = self.bg_idx = num_classes

forward(pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt)

Calcula la asignación de tareas alineadas. El código de referencia está disponible en https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py.

Parámetros:

Nombre Tipo Descripción Por defecto
pd_scores Tensor

forma(bs, num_total_anclas, num_clases)

necesario
pd_bboxes Tensor

forma(bs, num_total_anclas, 4)

necesario
anc_points Tensor

forma(num_total_anclas, 2)

necesario
gt_labels Tensor

forma(bs, n_cajas_máx, 1)

necesario
gt_bboxes Tensor

forma(bs, n_cajas_máx, 4)

necesario
mask_gt Tensor

forma(bs, n_cajas_máx, 1)

necesario

Devuelve:

Nombre Tipo Descripción
target_labels Tensor

forma(bs, num_total_anclas)

target_bboxes Tensor

forma(bs, num_total_anclas, 4)

target_scores Tensor

forma(bs, num_total_anclas, num_clases)

fg_mask Tensor

forma(bs, num_total_anclas)

target_gt_idx Tensor

forma(bs, num_total_anclas)

Código fuente en ultralytics/utils/tal.py
@torch.no_grad()
def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
    """
    Compute the task-aligned assignment. Reference code is available at
    https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py.

    Args:
        pd_scores (Tensor): shape(bs, num_total_anchors, num_classes)
        pd_bboxes (Tensor): shape(bs, num_total_anchors, 4)
        anc_points (Tensor): shape(num_total_anchors, 2)
        gt_labels (Tensor): shape(bs, n_max_boxes, 1)
        gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
        mask_gt (Tensor): shape(bs, n_max_boxes, 1)

    Returns:
        target_labels (Tensor): shape(bs, num_total_anchors)
        target_bboxes (Tensor): shape(bs, num_total_anchors, 4)
        target_scores (Tensor): shape(bs, num_total_anchors, num_classes)
        fg_mask (Tensor): shape(bs, num_total_anchors), boolean
        target_gt_idx (Tensor): shape(bs, num_total_anchors)
    """
    # Cached on the instance for the helper methods called below.
    self.bs = pd_scores.shape[0]
    self.n_max_boxes = gt_bboxes.shape[1]

    if self.n_max_boxes == 0:
        # No ground truth in the batch: every anchor is background. Labels and gt indices are
        # integer tensors and the foreground mask is boolean, matching the `fg_mask.bool()` returned
        # by the regular path instead of inheriting the float dtype of `pd_scores`.
        device = gt_bboxes.device
        return (
            torch.full_like(pd_scores[..., 0], self.bg_idx, dtype=torch.long).to(device),
            torch.zeros_like(pd_bboxes).to(device),
            torch.zeros_like(pd_scores).to(device),
            torch.zeros_like(pd_scores[..., 0], dtype=torch.bool).to(device),
            torch.zeros_like(pd_scores[..., 0], dtype=torch.long).to(device),
        )

    mask_pos, align_metric, overlaps = self.get_pos_mask(
        pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt
    )

    target_gt_idx, fg_mask, mask_pos = self.select_highest_overlaps(mask_pos, overlaps, self.n_max_boxes)

    # Assigned target
    target_labels, target_bboxes, target_scores = self.get_targets(gt_labels, gt_bboxes, target_gt_idx, fg_mask)

    # Normalize: rescale each gt's metric so that its best anchor scores that gt's best overlap.
    align_metric *= mask_pos
    pos_align_metrics = align_metric.amax(dim=-1, keepdim=True)  # b, max_num_obj
    pos_overlaps = (overlaps * mask_pos).amax(dim=-1, keepdim=True)  # b, max_num_obj
    norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).amax(-2).unsqueeze(-1)
    target_scores = target_scores * norm_align_metric

    return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx

get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_gt)

Calcula la métrica de alineación a partir de las cajas delimitadoras predichas y las de referencia (ground truth).

Código fuente en ultralytics/utils/tal.py
def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_gt):
    """
    Evaluate the alignment metric between predictions and ground-truth boxes.

    Scores and overlaps are only filled in where ``mask_gt`` is set; everything else stays zero.
    The metric is ``score ** alpha * overlap ** beta``.

    Returns:
        align_metric (Tensor): shape(b, max_num_obj, h*w)
        overlaps (Tensor): shape(b, max_num_obj, h*w)
    """
    num_anchors = pd_bboxes.shape[-2]
    valid = mask_gt.bool()  # b, max_num_obj, h*w
    grid_shape = [self.bs, self.n_max_boxes, num_anchors]
    overlaps = torch.zeros(grid_shape, dtype=pd_bboxes.dtype, device=pd_bboxes.device)
    bbox_scores = torch.zeros(grid_shape, dtype=pd_scores.dtype, device=pd_scores.device)

    # Per-gt lookup indices: which image each gt slot belongs to and which class it carries.
    batch_idx = torch.arange(end=self.bs).view(-1, 1).expand(-1, self.n_max_boxes)  # b, max_num_obj
    class_idx = torch.zeros([self.bs, self.n_max_boxes], dtype=torch.long)  # b, max_num_obj
    class_idx[:] = gt_labels.squeeze(-1)

    # Predicted score of each anchor for the class of each gt
    per_gt_scores = pd_scores[batch_idx, :, class_idx]  # b, max_num_obj, h*w
    bbox_scores[valid] = per_gt_scores[valid]

    # Pair every gt box with every predicted box, keeping only the valid (gt, anchor) pairs
    paired_pred = pd_bboxes.unsqueeze(1).expand(-1, self.n_max_boxes, -1, -1)[valid]
    paired_gt = gt_bboxes.unsqueeze(2).expand(-1, -1, num_anchors, -1)[valid]
    overlaps[valid] = self.iou_calculation(paired_gt, paired_pred)

    return bbox_scores.pow(self.alpha) * overlaps.pow(self.beta), overlaps

get_pos_mask(pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt)

Obtener máscara in_gts, (b, max_num_obj, h*w).

Código fuente en ultralytics/utils/tal.py
def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt):
    """
    Build the positive-candidate mask, plus the alignment metric and overlaps it is based on.

    Returns:
        mask_pos (Tensor): shape(b, max_num_obj, h*w), product of the top-k, in-gt and valid-gt masks.
        align_metric (Tensor): shape(b, max_num_obj, h*w)
        overlaps (Tensor): shape(b, max_num_obj, h*w)
    """
    # Anchors whose centre falls inside each gt box, (b, max_num_obj, h*w)
    in_gts = self.select_candidates_in_gts(anc_points, gt_bboxes)

    # Metrics are evaluated only for anchors inside valid gts
    valid_in_gts = in_gts * mask_gt
    align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes, valid_in_gts)

    # Top-k anchors per gt by alignment metric, (b, max_num_obj, h*w)
    valid_topk = mask_gt.expand(-1, -1, self.topk).bool()
    in_topk = self.select_topk_candidates(align_metric, topk_mask=valid_topk)

    # An anchor is positive only when all three conditions hold
    return in_topk * in_gts * mask_gt, align_metric, overlaps

get_targets(gt_labels, gt_bboxes, target_gt_idx, fg_mask)

Calcula las etiquetas de los objetivos, los cuadros delimitadores de los objetivos y las puntuaciones de los objetivos para los puntos de anclaje positivos.

Parámetros:

Nombre Tipo Descripción Por defecto
gt_labels Tensor

Etiquetas de la verdad sobre el terreno de forma (b, max_num_obj, 1), donde b es el tamaño del lote y max_num_obj es el número máximo de objetos.

necesario
gt_bboxes Tensor

Cajas delimitadoras de la verdad sobre el terreno de forma (b, max_num_obj, 4).

necesario
target_gt_idx Tensor

Índices de los objetos reales asignados a los puntos de anclaje positivos, con forma (b, h*w), donde h*w es el número total de puntos de anclaje.

necesario
fg_mask Tensor

Un tensor booleano de forma (b, h*w) que indica los puntos de anclaje positivos (primer plano).

necesario

Devuelve:

Tipo Descripción
Tuple[Tensor, Tensor, Tensor]

Una tupla que contiene los siguientes tensores: - target_labels (Tensor): forma (b, h*w), que contiene las etiquetas objetivo para los puntos de anclaje positivos. - target_bboxes (Tensor): forma (b, h*w, 4), que contiene las cajas delimitadoras objetivo de los puntos de anclaje positivos. - target_scores (Tensor): forma (b, h*w, num_classes), que contiene las puntuaciones objetivo para los puntos de anclaje positivos, donde num_classes es el número de clases de objetos.

Código fuente en ultralytics/utils/tal.py
def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask):
    """
    Gather per-anchor labels, boxes and one-hot class scores from the assigned ground truths.

    Args:
        gt_labels (Tensor): Ground truth labels, shape (b, max_num_obj, 1).
        gt_bboxes (Tensor): Ground truth boxes, shape (b, max_num_obj, 4).
        target_gt_idx (Tensor): Index of the gt assigned to each anchor, shape (b, h*w).
        fg_mask (Tensor): Shape (b, h*w); entries greater than zero mark foreground anchors.

    Returns:
        (Tuple[Tensor, Tensor, Tensor]):
            - target_labels (Tensor): Shape (b, h*w), label of the assigned gt (clamped to be >= 0).
            - target_bboxes (Tensor): Shape (b, h*w, 4), box of the assigned gt.
            - target_scores (Tensor): Shape (b, h*w, num_classes), int64 one-hot of the label,
                                      all zeros for anchors that are not foreground.
    """
    # Turn per-image gt indices into indices over the flattened (b * max_num_obj) gt arrays.
    image_offset = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]
    flat_gt_idx = target_gt_idx + image_offset * self.n_max_boxes  # (b, h*w)

    flat_labels = gt_labels.long().flatten()
    flat_boxes = gt_bboxes.view(-1, gt_bboxes.shape[-1])
    target_labels = flat_labels[flat_gt_idx]  # (b, h*w)
    target_bboxes = flat_boxes[flat_gt_idx]  # (b, h*w, 4)

    # Negative labels would be invalid scatter indices below.
    target_labels.clamp_(0)

    # One-hot encoding through scatter (10x faster than F.one_hot()).
    one_hot_shape = (target_labels.shape[0], target_labels.shape[1], self.num_classes)
    one_hot = torch.zeros(one_hot_shape, dtype=torch.int64, device=target_labels.device)
    one_hot.scatter_(2, target_labels.unsqueeze(-1), 1)

    # Keep the one-hot rows of foreground anchors only.
    is_foreground = fg_mask[:, :, None].repeat(1, 1, self.num_classes) > 0  # (b, h*w, num_classes)
    target_scores = torch.where(is_foreground, one_hot, 0)

    return target_labels, target_bboxes, target_scores

iou_calculation(gt_bboxes, pd_bboxes)

Cálculo del IoU para cajas delimitadoras horizontales.

Código fuente en ultralytics/utils/tal.py
def iou_calculation(self, gt_bboxes, pd_bboxes):
    """
    Overlap between ground-truth and predicted horizontal boxes.

    Calls ``bbox_iou`` with ``xywh=False, CIoU=True``, drops the trailing dimension and clamps
    negative values to zero in place.
    """
    ciou = bbox_iou(gt_bboxes, pd_bboxes, xywh=False, CIoU=True)
    return ciou.squeeze(-1).clamp_(0)

select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-09) staticmethod

Selecciona el centro de anclaje positivo en gt.

Parámetros:

Nombre Tipo Descripción Por defecto
xy_centers Tensor

forma(h*w, 2)

necesario
gt_bboxes Tensor

forma(b, n_cajas, 4)

necesario

Devuelve:

Tipo Descripción
Tensor

forma(b, n_cajas, h*w)

Código fuente en ultralytics/utils/tal.py
@staticmethod
def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
    """
    Mark which anchor centres lie inside each ground-truth box.

    Args:
        xy_centers (Tensor): shape(h*w, 2)
        gt_bboxes (Tensor): shape(b, n_boxes, 4)
        eps (float): Minimum distance to every box edge for an anchor to count as inside.

    Returns:
        (Tensor): shape(b, n_boxes, h*w), 1 where the anchor is inside the box and 0 elsewhere.
    """
    bs, n_boxes, _ = gt_bboxes.shape
    n_anchors = xy_centers.shape[0]
    centers = xy_centers[None]

    # One row per gt box, split into its left-top and right-bottom corners.
    lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2)

    # Signed distances from each centre to the four edges; the smallest decides membership.
    edge_dists = torch.cat((centers - lt, rb - centers), dim=2).view(bs, n_boxes, n_anchors, -1)
    closest_edge = edge_dists.amin(3)
    return closest_edge.gt_(eps)

select_highest_overlaps(mask_pos, overlaps, n_max_boxes) staticmethod

Si un cuadro de anclaje está asignado a varios gts, se seleccionará el que tenga el IoU más alto.

Parámetros:

Nombre Tipo Descripción Por defecto
mask_pos Tensor

forma(b, n_cajas_máx, h*w)

necesario
overlaps Tensor

forma(b, n_cajas_máx, h*w)

necesario

Devuelve:

Nombre Tipo Descripción
target_gt_idx Tensor

forma(b, h*w)

fg_mask Tensor

forma(b, h*w)

mask_pos Tensor

forma(b, n_cajas_máx, h*w)

Código fuente en ultralytics/utils/tal.py
@staticmethod
def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
    """
    Resolve anchors claimed by several gts by keeping the gt with the highest IoU.

    Args:
        mask_pos (Tensor): shape(b, n_max_boxes, h*w)
        overlaps (Tensor): shape(b, n_max_boxes, h*w)
        n_max_boxes (int): Size of the gt dimension used when expanding the conflict mask.

    Returns:
        target_gt_idx (Tensor): shape(b, h*w)
        fg_mask (Tensor): shape(b, h*w)
        mask_pos (Tensor): shape(b, n_max_boxes, h*w)
    """
    # Number of gts claiming each anchor, (b, h*w)
    claims = mask_pos.sum(-2)

    if claims.max() > 1:
        # Anchors with more than one claim, broadcast over the gt dimension
        contested = (claims.unsqueeze(1) > 1).expand(-1, n_max_boxes, -1)  # (b, n_max_boxes, h*w)

        # One-hot over the gt dimension marking the best-overlapping gt per anchor
        best_gt = overlaps.argmax(1)  # (b, h*w)
        best_one_hot = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device)
        best_one_hot.scatter_(1, best_gt.unsqueeze(1), 1)

        # Contested anchors take the one-hot choice; all other anchors keep their assignment
        mask_pos = torch.where(contested, best_one_hot, mask_pos).float()
        claims = mask_pos.sum(-2)

    # Index of the gt each anchor is assigned to
    return mask_pos.argmax(-2), claims, mask_pos

select_topk_candidates(metrics, largest=True, topk_mask=None)

Selecciona los k mejores candidatos en función de las métricas dadas.

Parámetros:

Nombre Tipo Descripción Por defecto
metrics Tensor

Un tensor de forma (b, max_num_obj, h*w), donde b es el tamaño del lote, max_num_obj es el número máximo de objetos y h*w representa el número total de puntos de anclaje.

necesario
largest bool

Si es Verdadero, selecciona los valores más grandes; si no, selecciona los valores más pequeños.

True
topk_mask Tensor

Un tensor booleano opcional de forma (b, max_num_obj, topk), donde topk es el número de candidatos principales a considerar. Si no se proporciona, los valores top-k se calculan automáticamente basándose en las métricas dadas.

None

Devuelve:

Tipo Descripción
Tensor

Un tensor de forma (b, max_num_obj, h*w) que contiene los candidatos top-k seleccionados.

Código fuente en ultralytics/utils/tal.py
def select_topk_candidates(self, metrics, largest=True, topk_mask=None):
    """
    Mark, for every gt, the ``self.topk`` anchors with the best metric.

    Args:
        metrics (Tensor): Shape (b, max_num_obj, h*w): batch size, maximum number of objects and
                          total number of anchor points.
        largest (bool): Select the largest values when True, the smallest otherwise.
        topk_mask (Tensor): Optional boolean tensor of shape (b, max_num_obj, topk) flagging the rows
                            to keep. When omitted, a row is kept if its best metric exceeds ``self.eps``.

    Returns:
        (Tensor): Shape (b, max_num_obj, h*w), 1 at the selected anchors and 0 elsewhere, in the
                  dtype of ``metrics``.
    """
    # (b, max_num_obj, topk)
    best_vals, best_idxs = torch.topk(metrics, self.topk, dim=-1, largest=largest)
    if topk_mask is None:
        row_is_valid = best_vals.max(-1, keepdim=True)[0] > self.eps
        topk_mask = row_is_valid.expand_as(best_idxs)

    # Rows that are masked out all point at anchor 0, so that anchor is hit several times below.
    best_idxs.masked_fill_(~topk_mask, 0)

    # Count how often each anchor was selected, one top-k column at a time.
    hits = torch.zeros(metrics.shape, dtype=torch.int8, device=best_idxs.device)
    unit = torch.ones_like(best_idxs[:, :, :1], dtype=torch.int8, device=best_idxs.device)
    for col in range(self.topk):
        hits.scatter_add_(-1, best_idxs[..., col : col + 1], unit)

    # Anchors hit more than once come from masked-out rows; discard them.
    hits.masked_fill_(hits > 1, 0)

    return hits.to(metrics.dtype)



ultralytics.utils.tal.RotatedTaskAlignedAssigner

Bases: TaskAlignedAssigner

Código fuente en ultralytics/utils/tal.py
class RotatedTaskAlignedAssigner(TaskAlignedAssigner):
    """Task-aligned assigner for rotated boxes; overrides the IoU and the anchor-in-box test."""

    def iou_calculation(self, gt_bboxes, pd_bboxes):
        """Overlap for rotated boxes: ``probiou`` with the trailing dim dropped and negatives clamped to 0."""
        rotated_iou = probiou(gt_bboxes, pd_bboxes)
        return rotated_iou.squeeze(-1).clamp_(0)

    @staticmethod
    def select_candidates_in_gts(xy_centers, gt_bboxes):
        """
        Mark which anchor centres lie inside each rotated ground-truth box.

        Args:
            xy_centers (Tensor): shape(h*w, 2)
            gt_bboxes (Tensor): shape(b, n_boxes, 5)

        Returns:
            (Tensor): shape(b, n_boxes, h*w), boolean
        """
        # (b, n_boxes, 5) --> (b, n_boxes, 4, 2); the third corner is not needed
        corner_a, corner_b, _, corner_d = xywhr2xyxyxyxy(gt_bboxes).split(1, dim=-2)  # each (b, n_boxes, 1, 2)

        # The two edges that start at corner a
        edge_ab = corner_b - corner_a
        edge_ad = corner_d - corner_a

        # Vector from corner a to every anchor centre, (b, n_boxes, h*w, 2)
        to_point = xy_centers - corner_a

        # Project onto both edges; inside means 0 <= projection <= squared edge length on each
        proj_ab = (to_point * edge_ab).sum(dim=-1)
        proj_ad = (to_point * edge_ad).sum(dim=-1)
        len2_ab = (edge_ab * edge_ab).sum(dim=-1)
        len2_ad = (edge_ad * edge_ad).sum(dim=-1)

        within_ab = (proj_ab >= 0) & (proj_ab <= len2_ab)
        within_ad = (proj_ad >= 0) & (proj_ad <= len2_ad)
        return within_ab & within_ad  # is_in_box

iou_calculation(gt_bboxes, pd_bboxes)

Cálculo del IoU para cajas delimitadoras rotadas.

Código fuente en ultralytics/utils/tal.py
def iou_calculation(self, gt_bboxes, pd_bboxes):
    """
    Overlap between ground-truth and predicted rotated boxes.

    Calls ``probiou``, drops the trailing dimension and clamps negative values to zero in place.
    """
    rotated_iou = probiou(gt_bboxes, pd_bboxes)
    return rotated_iou.squeeze(-1).clamp_(0)

select_candidates_in_gts(xy_centers, gt_bboxes) staticmethod

Selecciona el centro de anclaje positivo en gt para las cajas delimitadoras rotadas.

Parámetros:

Nombre Tipo Descripción Por defecto
xy_centers Tensor

forma(h*w, 2)

necesario
gt_bboxes Tensor

forma(b, n_cajas, 5)

necesario

Devuelve:

Tipo Descripción
Tensor

forma(b, n_cajas, h*w)

Código fuente en ultralytics/utils/tal.py
@staticmethod
def select_candidates_in_gts(xy_centers, gt_bboxes):
    """
    Mark which anchor centres lie inside each rotated ground-truth box.

    Args:
        xy_centers (Tensor): shape(h*w, 2)
        gt_bboxes (Tensor): shape(b, n_boxes, 5)

    Returns:
        (Tensor): shape(b, n_boxes, h*w), boolean
    """
    # (b, n_boxes, 5) --> (b, n_boxes, 4, 2); the third corner is not needed
    corner_a, corner_b, _, corner_d = xywhr2xyxyxyxy(gt_bboxes).split(1, dim=-2)  # each (b, n_boxes, 1, 2)

    # The two edges that start at corner a
    edge_ab = corner_b - corner_a
    edge_ad = corner_d - corner_a

    # Vector from corner a to every anchor centre, (b, n_boxes, h*w, 2)
    to_point = xy_centers - corner_a

    # Project onto both edges; inside means 0 <= projection <= squared edge length on each
    proj_ab = (to_point * edge_ab).sum(dim=-1)
    proj_ad = (to_point * edge_ad).sum(dim=-1)
    len2_ab = (edge_ab * edge_ab).sum(dim=-1)
    len2_ad = (edge_ad * edge_ad).sum(dim=-1)

    within_ab = (proj_ab >= 0) & (proj_ab <= len2_ab)
    within_ad = (proj_ad >= 0) & (proj_ad <= len2_ad)
    return within_ab & within_ad  # is_in_box



ultralytics.utils.tal.make_anchors(feats, strides, grid_cell_offset=0.5)

Generar anclajes a partir de características.

Código fuente en ultralytics/utils/tal.py
def make_anchors(feats, strides, grid_cell_offset=0.5):
    """
    Build anchor points and their strides for a list of feature maps.

    Args:
        feats: Indexable collection of 4-D feature maps; the last two dims are read as (h, w).
        strides: Iterable with one stride value per feature map.
        grid_cell_offset (float): Offset added to every integer grid coordinate.

    Returns:
        (Tuple[Tensor, Tensor]): Anchor points as (x, y) pairs of shape (sum(h*w), 2) and the
            matching strides of shape (sum(h*w), 1), concatenated over all levels.
    """
    assert feats is not None
    dtype, device = feats[0].dtype, feats[0].device

    points_per_level, strides_per_level = [], []
    for level, stride in enumerate(strides):
        _batch, _channels, h, w = feats[level].shape
        xs = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset  # shift x
        ys = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset  # shift y
        # The `indexing` keyword is only passed when the TORCH_1_10 flag is set.
        if TORCH_1_10:
            grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij")
        else:
            grid_y, grid_x = torch.meshgrid(ys, xs)
        points_per_level.append(torch.stack((grid_x, grid_y), -1).view(-1, 2))
        strides_per_level.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))

    return torch.cat(points_per_level), torch.cat(strides_per_level)



ultralytics.utils.tal.dist2bbox(distance, anchor_points, xywh=True, dim=-1)

Transforma distancia(ltrb) en caja(xywh o xyxy).

Código fuente en ultralytics/utils/tal.py
def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
    """
    Convert distances (left, top, right, bottom) from anchor points into boxes.

    Args:
        distance (Tensor): Distances, split into two equal halves (lt, rb) along ``dim``.
        anchor_points (Tensor): Points the distances are measured from.
        xywh (bool): Return (centre, size) boxes when True, (x1y1, x2y2) corner boxes otherwise.
        dim (int): Dimension holding the four distance components.

    Returns:
        (Tensor): Boxes concatenated along ``dim``.
    """
    lt, rb = distance.chunk(2, dim)
    top_left = anchor_points - lt
    bottom_right = anchor_points + rb
    if not xywh:
        return torch.cat((top_left, bottom_right), dim)  # xyxy bbox
    center = (top_left + bottom_right) / 2
    size = bottom_right - top_left
    return torch.cat((center, size), dim)  # xywh bbox



ultralytics.utils.tal.bbox2dist(anchor_points, bbox, reg_max)

Transforma bbox(xyxy) en dist(ltrb).

Código fuente en ultralytics/utils/tal.py
def bbox2dist(anchor_points, bbox, reg_max):
    """
    Convert corner boxes (xyxy) into distances (left, top, right, bottom) from anchor points.

    Args:
        anchor_points (Tensor): Points the distances are measured from.
        bbox (Tensor): Boxes whose last dimension is split into two halves (x1y1, x2y2).
        reg_max: Upper bound; distances are clamped to the range [0, reg_max - 0.01].

    Returns:
        (Tensor): Clamped distances concatenated along the last dimension.
    """
    top_left, bottom_right = bbox.chunk(2, -1)
    ltrb = torch.cat((anchor_points - top_left, bottom_right - anchor_points), -1)
    return ltrb.clamp_(0, reg_max - 0.01)  # dist (lt, rb)



ultralytics.utils.tal.dist2rbox(pred_dist, pred_angle, anchor_points, dim=-1)

Decodifica las coordenadas del cuadro delimitador del objeto predicho a partir de los puntos de anclaje y la distribución.

Parámetros:

Nombre Tipo Descripción Por defecto
pred_dist Tensor

Distancia rotada prevista, (bs, h*w, 4).

necesario
pred_angle Tensor

Ángulo previsto, (bs, h*w, 1).

necesario
anchor_points Tensor

Puntos de anclaje, (h*w, 2).

necesario

Devuelve: (torch.Tensor): Cajas delimitadoras rotadas predichas, (bs, h*w, 4).

Código fuente en ultralytics/utils/tal.py
def dist2rbox(pred_dist, pred_angle, anchor_points, dim=-1):
    """
    Decode rotated boxes from predicted distances, angles and anchor points.

    Args:
        pred_dist (torch.Tensor): Predicted rotated distance, (bs, h*w, 4).
        pred_angle (torch.Tensor): Predicted angle, (bs, h*w, 1).
        anchor_points (torch.Tensor): Anchor points, (h*w, 2).
        dim (int): Dimension holding the distance components.

    Returns:
        (torch.Tensor): Predicted rotated bounding boxes as (centre xy, size wh), (bs, h*w, 4).
    """
    lt, rb = pred_dist.split(2, dim=dim)

    # Offset of the box centre from the anchor, expressed in the box's own (unrotated) frame
    half_dx, half_dy = ((rb - lt) / 2).split(1, dim=dim)  # each (bs, h*w, 1)

    # Rotate that offset by the predicted angle
    cos_a = torch.cos(pred_angle)
    sin_a = torch.sin(pred_angle)
    offset_x = half_dx * cos_a - half_dy * sin_a
    offset_y = half_dx * sin_a + half_dy * cos_a

    center = torch.cat([offset_x, offset_y], dim=dim) + anchor_points
    size = lt + rb
    return torch.cat([center, size], dim=dim)





Creado 2023-11-12, Actualizado 2024-05-08
Autores: Burhan-Q (1), glenn-jocher (4), Laughing-q (1)