Reference for ultralytics/nn/modules/transformer.py

Improvements

This page is sourced from https://github.com/ultralytics/ultralytics/blob/main/ultralytics/nn/modules/transformer.py. Have an improvement or example to add? Open a Pull Request — thank you! 🙏


class ultralytics.nn.modules.transformer.TransformerEncoderLayer

def __init__(
    self,
    c1: int,
    cm: int = 2048,
    num_heads: int = 8,
    dropout: float = 0.0,
    act: nn.Module = nn.GELU(),
    normalize_before: bool = False,
)

Bases: nn.Module

A single layer of the transformer encoder.

This class implements a standard transformer encoder layer with multi-head attention and feedforward network, supporting both pre-normalization and post-normalization configurations.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| c1 | int | Input dimension. | required |
| cm | int | Hidden dimension in the feedforward network. | 2048 |
| num_heads | int | Number of attention heads. | 8 |
| dropout | float | Dropout probability. | 0.0 |
| act | nn.Module | Activation function. | nn.GELU() |
| normalize_before | bool | Whether to apply normalization before attention and feedforward. | False |

Attributes

| Name | Type | Description |
| --- | --- | --- |
| ma | nn.MultiheadAttention | Multi-head attention module. |
| fc1 | nn.Linear | First linear layer in the feedforward network. |
| fc2 | nn.Linear | Second linear layer in the feedforward network. |
| norm1 | nn.LayerNorm | Layer normalization after attention. |
| norm2 | nn.LayerNorm | Layer normalization after feedforward network. |
| dropout | nn.Dropout | Dropout layer for the feedforward network. |
| dropout1 | nn.Dropout | Dropout layer after attention. |
| dropout2 | nn.Dropout | Dropout layer after feedforward network. |
| act | nn.Module | Activation function. |
| normalize_before | bool | Whether to apply normalization before attention and feedforward. |

Methods

| Name | Description |
| --- | --- |
| forward | Forward propagate the input through the encoder module. |
| forward_post | Perform forward pass with post-normalization. |
| forward_pre | Perform forward pass with pre-normalization. |
| with_pos_embed | Add position embeddings to the tensor if provided. |

Source code in ultralytics/nn/modules/transformer.py
class TransformerEncoderLayer(nn.Module):
    """A single layer of the transformer encoder.

    This class implements a standard transformer encoder layer with multi-head attention and feedforward network,
    supporting both pre-normalization and post-normalization configurations.

    Attributes:
        ma (nn.MultiheadAttention): Multi-head attention module.
        fc1 (nn.Linear): First linear layer in the feedforward network.
        fc2 (nn.Linear): Second linear layer in the feedforward network.
        norm1 (nn.LayerNorm): Layer normalization after attention.
        norm2 (nn.LayerNorm): Layer normalization after feedforward network.
        dropout (nn.Dropout): Dropout layer for the feedforward network.
        dropout1 (nn.Dropout): Dropout layer after attention.
        dropout2 (nn.Dropout): Dropout layer after feedforward network.
        act (nn.Module): Activation function.
        normalize_before (bool): Whether to apply normalization before attention and feedforward.
    """

    def __init__(
        self,
        c1: int,
        cm: int = 2048,
        num_heads: int = 8,
        dropout: float = 0.0,
        act: nn.Module = nn.GELU(),
        normalize_before: bool = False,
    ):
        """Initialize the TransformerEncoderLayer with specified parameters.

        Args:
            c1 (int): Input dimension.
            cm (int): Hidden dimension in the feedforward network.
            num_heads (int): Number of attention heads.
            dropout (float): Dropout probability.
            act (nn.Module): Activation function.
            normalize_before (bool): Whether to apply normalization before attention and feedforward.
        """
        super().__init__()
        from ...utils.torch_utils import TORCH_1_9

        if not TORCH_1_9:
            raise ModuleNotFoundError(
                "TransformerEncoderLayer() requires torch>=1.9 to use nn.MultiheadAttention(batch_first=True)."
            )
        self.ma = nn.MultiheadAttention(c1, num_heads, dropout=dropout, batch_first=True)
        # Implementation of Feedforward model
        self.fc1 = nn.Linear(c1, cm)
        self.fc2 = nn.Linear(cm, c1)

        self.norm1 = nn.LayerNorm(c1)
        self.norm2 = nn.LayerNorm(c1)
        self.dropout = nn.Dropout(dropout)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.act = act
        self.normalize_before = normalize_before
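
Example usage (a minimal sketch; the tensor sizes are illustrative assumptions, not values from the source):

```python
import torch

from ultralytics.nn.modules.transformer import TransformerEncoderLayer

# Encode a batch of 100 tokens with 256-dim features (batch_first layout).
layer = TransformerEncoderLayer(c1=256, cm=2048, num_heads=8, dropout=0.1)
src = torch.randn(2, 100, 256)  # [batch, sequence, channels]
out = layer(src)
assert out.shape == src.shape  # the encoder layer preserves shape

# Pre-normalization variant: only the LayerNorm placement changes,
# i.e. x + sublayer(norm(x)) instead of norm(x + sublayer(x)).
pre = TransformerEncoderLayer(c1=256, normalize_before=True)
assert pre(src).shape == src.shape
```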


method ultralytics.nn.modules.transformer.TransformerEncoderLayer.forward

def forward(
    self,
    src: torch.Tensor,
    src_mask: torch.Tensor | None = None,
    src_key_padding_mask: torch.Tensor | None = None,
    pos: torch.Tensor | None = None,
) -> torch.Tensor

Forward propagate the input through the encoder module.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| src | torch.Tensor | Input tensor. | required |
| src_mask | torch.Tensor, optional | Mask for the src sequence. | None |
| src_key_padding_mask | torch.Tensor, optional | Mask for the src keys per batch. | None |
| pos | torch.Tensor, optional | Positional encoding. | None |

Returns

| Type | Description |
| --- | --- |
| torch.Tensor | Output tensor after transformer encoder layer. |

Source code in ultralytics/nn/modules/transformer.py
def forward(
    self,
    src: torch.Tensor,
    src_mask: torch.Tensor | None = None,
    src_key_padding_mask: torch.Tensor | None = None,
    pos: torch.Tensor | None = None,
) -> torch.Tensor:
    """Forward propagate the input through the encoder module.

    Args:
        src (torch.Tensor): Input tensor.
        src_mask (torch.Tensor, optional): Mask for the src sequence.
        src_key_padding_mask (torch.Tensor, optional): Mask for the src keys per batch.
        pos (torch.Tensor, optional): Positional encoding.

    Returns:
        (torch.Tensor): Output tensor after transformer encoder layer.
    """
    if self.normalize_before:
        return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
    return self.forward_post(src, src_mask, src_key_padding_mask, pos)


method ultralytics.nn.modules.transformer.TransformerEncoderLayer.forward_post

def forward_post(
    self,
    src: torch.Tensor,
    src_mask: torch.Tensor | None = None,
    src_key_padding_mask: torch.Tensor | None = None,
    pos: torch.Tensor | None = None,
) -> torch.Tensor

Perform forward pass with post-normalization.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| src | torch.Tensor | Input tensor. | required |
| src_mask | torch.Tensor, optional | Mask for the src sequence. | None |
| src_key_padding_mask | torch.Tensor, optional | Mask for the src keys per batch. | None |
| pos | torch.Tensor, optional | Positional encoding. | None |

Returns

| Type | Description |
| --- | --- |
| torch.Tensor | Output tensor after attention and feedforward. |

Source code in ultralytics/nn/modules/transformer.py
def forward_post(
    self,
    src: torch.Tensor,
    src_mask: torch.Tensor | None = None,
    src_key_padding_mask: torch.Tensor | None = None,
    pos: torch.Tensor | None = None,
) -> torch.Tensor:
    """Perform forward pass with post-normalization.

    Args:
        src (torch.Tensor): Input tensor.
        src_mask (torch.Tensor, optional): Mask for the src sequence.
        src_key_padding_mask (torch.Tensor, optional): Mask for the src keys per batch.
        pos (torch.Tensor, optional): Positional encoding.

    Returns:
        (torch.Tensor): Output tensor after attention and feedforward.
    """
    q = k = self.with_pos_embed(src, pos)
    src2 = self.ma(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
    src = src + self.dropout1(src2)
    src = self.norm1(src)
    src2 = self.fc2(self.dropout(self.act(self.fc1(src))))
    src = src + self.dropout2(src2)
    return self.norm2(src)


method ultralytics.nn.modules.transformer.TransformerEncoderLayer.forward_pre

def forward_pre(
    self,
    src: torch.Tensor,
    src_mask: torch.Tensor | None = None,
    src_key_padding_mask: torch.Tensor | None = None,
    pos: torch.Tensor | None = None,
) -> torch.Tensor

Perform forward pass with pre-normalization.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| src | torch.Tensor | Input tensor. | required |
| src_mask | torch.Tensor, optional | Mask for the src sequence. | None |
| src_key_padding_mask | torch.Tensor, optional | Mask for the src keys per batch. | None |
| pos | torch.Tensor, optional | Positional encoding. | None |

Returns

| Type | Description |
| --- | --- |
| torch.Tensor | Output tensor after attention and feedforward. |

Source code in ultralytics/nn/modules/transformer.py
def forward_pre(
    self,
    src: torch.Tensor,
    src_mask: torch.Tensor | None = None,
    src_key_padding_mask: torch.Tensor | None = None,
    pos: torch.Tensor | None = None,
) -> torch.Tensor:
    """Perform forward pass with pre-normalization.

    Args:
        src (torch.Tensor): Input tensor.
        src_mask (torch.Tensor, optional): Mask for the src sequence.
        src_key_padding_mask (torch.Tensor, optional): Mask for the src keys per batch.
        pos (torch.Tensor, optional): Positional encoding.

    Returns:
        (torch.Tensor): Output tensor after attention and feedforward.
    """
    src2 = self.norm1(src)
    q = k = self.with_pos_embed(src2, pos)
    src2 = self.ma(q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
    src = src + self.dropout1(src2)
    src2 = self.norm2(src)
    src2 = self.fc2(self.dropout(self.act(self.fc1(src2))))
    return src + self.dropout2(src2)


method ultralytics.nn.modules.transformer.TransformerEncoderLayer.with_pos_embed

def with_pos_embed(tensor: torch.Tensor, pos: torch.Tensor | None = None) -> torch.Tensor

Add position embeddings to the tensor if provided.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| tensor | torch.Tensor | Input tensor. | required |
| pos | torch.Tensor \| None | Position embeddings to add, if provided. | None |

Source code in ultralytics/nn/modules/transformer.py
@staticmethod
def with_pos_embed(tensor: torch.Tensor, pos: torch.Tensor | None = None) -> torch.Tensor:
    """Add position embeddings to the tensor if provided."""
    return tensor if pos is None else tensor + pos





class ultralytics.nn.modules.transformer.AIFI

def __init__(
    self,
    c1: int,
    cm: int = 2048,
    num_heads: int = 8,
    dropout: float = 0,
    act: nn.Module = nn.GELU(),
    normalize_before: bool = False,
)

Bases: TransformerEncoderLayer

AIFI transformer layer for 2D data with positional embeddings.

This class extends TransformerEncoderLayer to work with 2D feature maps by adding 2D sine-cosine positional embeddings and handling the spatial dimensions appropriately.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| c1 | int | Input dimension. | required |
| cm | int | Hidden dimension in the feedforward network. | 2048 |
| num_heads | int | Number of attention heads. | 8 |
| dropout | float | Dropout probability. | 0 |
| act | nn.Module | Activation function. | nn.GELU() |
| normalize_before | bool | Whether to apply normalization before attention and feedforward. | False |

Methods

| Name | Description |
| --- | --- |
| build_2d_sincos_position_embedding | Build 2D sine-cosine position embedding. |
| forward | Forward pass for the AIFI transformer layer. |

Source code in ultralytics/nn/modules/transformer.py
class AIFI(TransformerEncoderLayer):
    """AIFI transformer layer for 2D data with positional embeddings.

    This class extends TransformerEncoderLayer to work with 2D feature maps by adding 2D sine-cosine positional
    embeddings and handling the spatial dimensions appropriately.
    """

    def __init__(
        self,
        c1: int,
        cm: int = 2048,
        num_heads: int = 8,
        dropout: float = 0,
        act: nn.Module = nn.GELU(),
        normalize_before: bool = False,
    ):
        """Initialize the AIFI instance with specified parameters.

        Args:
            c1 (int): Input dimension.
            cm (int): Hidden dimension in the feedforward network.
            num_heads (int): Number of attention heads.
            dropout (float): Dropout probability.
            act (nn.Module): Activation function.
            normalize_before (bool): Whether to apply normalization before attention and feedforward.
        """
        super().__init__(c1, cm, num_heads, dropout, act, normalize_before)


method ultralytics.nn.modules.transformer.AIFI.build_2d_sincos_position_embedding

def build_2d_sincos_position_embedding(
    w: int, h: int, embed_dim: int = 256, temperature: float = 10000.0
) -> torch.Tensor

Build 2D sine-cosine position embedding.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| w | int | Width of the feature map. | required |
| h | int | Height of the feature map. | required |
| embed_dim | int | Embedding dimension. | 256 |
| temperature | float | Temperature for the sine/cosine functions. | 10000.0 |

Returns

| Type | Description |
| --- | --- |
| torch.Tensor | Position embedding with shape [1, h*w, embed_dim]. |

Source code in ultralytics/nn/modules/transformer.py
@staticmethod
def build_2d_sincos_position_embedding(
    w: int, h: int, embed_dim: int = 256, temperature: float = 10000.0
) -> torch.Tensor:
    """Build 2D sine-cosine position embedding.

    Args:
        w (int): Width of the feature map.
        h (int): Height of the feature map.
        embed_dim (int): Embedding dimension.
        temperature (float): Temperature for the sine/cosine functions.

    Returns:
        (torch.Tensor): Position embedding with shape [1, h*w, embed_dim].
    """
    assert embed_dim % 4 == 0, "Embed dimension must be divisible by 4 for 2D sin-cos position embedding"
    grid_w = torch.arange(w, dtype=torch.float32)
    grid_h = torch.arange(h, dtype=torch.float32)
    grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij") if TORCH_1_11 else torch.meshgrid(grid_w, grid_h)
    pos_dim = embed_dim // 4
    omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
    omega = 1.0 / (temperature**omega)

    out_w = grid_w.flatten()[..., None] @ omega[None]
    out_h = grid_h.flatten()[..., None] @ omega[None]

    return torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], 1)[None]
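
A quick shape check (a sketch; note the embedding comes out in [1, h*w, embed_dim] layout, matching the [B, HxW, C] sequence layout used by AIFI.forward):

```python
from ultralytics.nn.modules.transformer import AIFI

# embed_dim must be divisible by 4 for the 2D sin-cos construction.
pe = AIFI.build_2d_sincos_position_embedding(w=20, h=20, embed_dim=256)
print(pe.shape)  # torch.Size([1, 400, 256]) -> [1, h*w, embed_dim]
```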


method ultralytics.nn.modules.transformer.AIFI.forward

def forward(self, x: torch.Tensor) -> torch.Tensor

Forward pass for the AIFI transformer layer.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | torch.Tensor | Input tensor with shape [B, C, H, W]. | required |

Returns

| Type | Description |
| --- | --- |
| torch.Tensor | Output tensor with shape [B, C, H, W]. |

Source code in ultralytics/nn/modules/transformer.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Forward pass for the AIFI transformer layer.

    Args:
        x (torch.Tensor): Input tensor with shape [B, C, H, W].

    Returns:
        (torch.Tensor): Output tensor with shape [B, C, H, W].
    """
    c, h, w = x.shape[1:]
    pos_embed = self.build_2d_sincos_position_embedding(w, h, c)
    # Flatten [B, C, H, W] to [B, HxW, C]
    x = super().forward(x.flatten(2).permute(0, 2, 1), pos=pos_embed.to(device=x.device, dtype=x.dtype))
    return x.permute(0, 2, 1).view([-1, c, h, w]).contiguous()
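
Example usage on a CNN feature map (a minimal sketch; the 20x20 map size is an illustrative assumption):

```python
import torch

from ultralytics.nn.modules.transformer import AIFI

aifi = AIFI(c1=256, cm=1024, num_heads=8)
x = torch.randn(1, 256, 20, 20)  # [B, C, H, W] feature map
y = aifi(x)
assert y.shape == x.shape  # spatial layout is restored after attention
```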





class ultralytics.nn.modules.transformer.TransformerLayer

TransformerLayer(self, c: int, num_heads: int)

Bases: nn.Module

Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance).

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| c | int | Input and output channel dimension. | required |
| num_heads | int | Number of attention heads. | required |

Methods

| Name | Description |
| --- | --- |
| forward | Apply a transformer block to the input x and return the output. |

Source code in ultralytics/nn/modules/transformer.py
class TransformerLayer(nn.Module):
    """Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance)."""

    def __init__(self, c: int, num_heads: int):
        """Initialize a self-attention mechanism using linear transformations and multi-head attention.

        Args:
            c (int): Input and output channel dimension.
            num_heads (int): Number of attention heads.
        """
        super().__init__()
        self.q = nn.Linear(c, c, bias=False)
        self.k = nn.Linear(c, c, bias=False)
        self.v = nn.Linear(c, c, bias=False)
        self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads)
        self.fc1 = nn.Linear(c, c, bias=False)
        self.fc2 = nn.Linear(c, c, bias=False)


method ultralytics.nn.modules.transformer.TransformerLayer.forward

def forward(self, x: torch.Tensor) -> torch.Tensor

Apply a transformer block to the input x and return the output.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | torch.Tensor | Input tensor. | required |

Returns

| Type | Description |
| --- | --- |
| torch.Tensor | Output tensor after transformer layer. |

Source code in ultralytics/nn/modules/transformer.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Apply a transformer block to the input x and return the output.

    Args:
        x (torch.Tensor): Input tensor.

    Returns:
        (torch.Tensor): Output tensor after transformer layer.
    """
    x = self.ma(self.q(x), self.k(x), self.v(x))[0] + x
    return self.fc2(self.fc1(x)) + x
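
Example usage (a sketch; since nn.MultiheadAttention is constructed here without batch_first=True, inputs use the [sequence, batch, channels] layout):

```python
import torch

from ultralytics.nn.modules.transformer import TransformerLayer

layer = TransformerLayer(c=128, num_heads=4)
x = torch.randn(196, 2, 128)  # [sequence, batch, channels]
y = layer(x)
assert y.shape == x.shape
```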





class ultralytics.nn.modules.transformer.TransformerBlock

TransformerBlock(self, c1: int, c2: int, num_heads: int, num_layers: int)

Bases: nn.Module

Vision Transformer block based on https://arxiv.org/abs/2010.11929.

This class implements a complete transformer block with optional convolution layer for channel adjustment, learnable position embedding, and multiple transformer layers.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| c1 | int | Input channel dimension. | required |
| c2 | int | Output channel dimension. | required |
| num_heads | int | Number of attention heads. | required |
| num_layers | int | Number of transformer layers. | required |

Attributes

| Name | Type | Description |
| --- | --- | --- |
| conv | Conv, optional | Convolution layer if input and output channels differ. |
| linear | nn.Linear | Learnable position embedding. |
| tr | nn.Sequential | Sequential container of transformer layers. |
| c2 | int | Output channel dimension. |

Methods

| Name | Description |
| --- | --- |
| forward | Forward propagate the input through the transformer block. |

Source code in ultralytics/nn/modules/transformer.py
class TransformerBlock(nn.Module):
    """Vision Transformer block based on https://arxiv.org/abs/2010.11929.

    This class implements a complete transformer block with optional convolution layer for channel adjustment, learnable
    position embedding, and multiple transformer layers.

    Attributes:
        conv (Conv, optional): Convolution layer if input and output channels differ.
        linear (nn.Linear): Learnable position embedding.
        tr (nn.Sequential): Sequential container of transformer layers.
        c2 (int): Output channel dimension.
    """

    def __init__(self, c1: int, c2: int, num_heads: int, num_layers: int):
        """Initialize a Transformer module with position embedding and specified number of heads and layers.

        Args:
            c1 (int): Input channel dimension.
            c2 (int): Output channel dimension.
            num_heads (int): Number of attention heads.
            num_layers (int): Number of transformer layers.
        """
        super().__init__()
        self.conv = None
        if c1 != c2:
            self.conv = Conv(c1, c2)
        self.linear = nn.Linear(c2, c2)  # learnable position embedding
        self.tr = nn.Sequential(*(TransformerLayer(c2, num_heads) for _ in range(num_layers)))
        self.c2 = c2


method ultralytics.nn.modules.transformer.TransformerBlock.forward

def forward(self, x: torch.Tensor) -> torch.Tensor

Forward propagate the input through the transformer block.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | torch.Tensor | Input tensor with shape [b, c1, w, h]. | required |

Returns

| Type | Description |
| --- | --- |
| torch.Tensor | Output tensor with shape [b, c2, w, h]. |

Source code in ultralytics/nn/modules/transformer.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Forward propagate the input through the transformer block.

    Args:
        x (torch.Tensor): Input tensor with shape [b, c1, w, h].

    Returns:
        (torch.Tensor): Output tensor with shape [b, c2, w, h].
    """
    if self.conv is not None:
        x = self.conv(x)
    b, _, w, h = x.shape
    p = x.flatten(2).permute(2, 0, 1)
    return self.tr(p + self.linear(p)).permute(1, 2, 0).reshape(b, self.c2, w, h)
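
Example usage (a sketch with assumed sizes); when c1 != c2 the leading Conv adjusts channels before the transformer layers:

```python
import torch

from ultralytics.nn.modules.transformer import TransformerBlock

block = TransformerBlock(c1=128, c2=256, num_heads=8, num_layers=2)
x = torch.randn(1, 128, 16, 16)     # [b, c1, h, w]
y = block(x)                        # Conv maps 128 -> 256 channels first
assert y.shape == (1, 256, 16, 16)  # spatial size preserved
```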





class ultralytics.nn.modules.transformer.MLPBlock

MLPBlock(self, embedding_dim: int, mlp_dim: int, act=nn.GELU)

Bases: nn.Module

A single block of a multi-layer perceptron.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| embedding_dim | int | Input and output dimension. | required |
| mlp_dim | int | Hidden dimension. | required |
| act | nn.Module | Activation function. | nn.GELU |

Methods

| Name | Description |
| --- | --- |
| forward | Forward pass for the MLPBlock. |

Source code in ultralytics/nn/modules/transformer.py
class MLPBlock(nn.Module):
    """A single block of a multi-layer perceptron."""

    def __init__(self, embedding_dim: int, mlp_dim: int, act=nn.GELU):
        """Initialize the MLPBlock with specified embedding dimension, MLP dimension, and activation function.

        Args:
            embedding_dim (int): Input and output dimension.
            mlp_dim (int): Hidden dimension.
            act (nn.Module): Activation function.
        """
        super().__init__()
        self.lin1 = nn.Linear(embedding_dim, mlp_dim)
        self.lin2 = nn.Linear(mlp_dim, embedding_dim)
        self.act = act()


method ultralytics.nn.modules.transformer.MLPBlock.forward

def forward(self, x: torch.Tensor) -> torch.Tensor

Forward pass for the MLPBlock.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | torch.Tensor | Input tensor. | required |

Returns

| Type | Description |
| --- | --- |
| torch.Tensor | Output tensor after MLP block. |

Source code in ultralytics/nn/modules/transformer.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Forward pass for the MLPBlock.

    Args:
        x (torch.Tensor): Input tensor.

    Returns:
        (torch.Tensor): Output tensor after MLP block.
    """
    return self.lin2(self.act(self.lin1(x)))
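
Example usage (a minimal sketch): the block expands to mlp_dim and projects back, so input and output shapes match:

```python
import torch

from ultralytics.nn.modules.transformer import MLPBlock

mlp = MLPBlock(embedding_dim=256, mlp_dim=1024)
x = torch.randn(4, 64, 256)
assert mlp(x).shape == x.shape
```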





class ultralytics.nn.modules.transformer.MLP

MLP(self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, act=nn.ReLU, sigmoid: bool = False)

Bases: nn.Module

A simple multi-layer perceptron (also called FFN).

This class implements a configurable MLP with multiple linear layers, activation functions, and optional sigmoid output activation.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| input_dim | int | Input dimension. | required |
| hidden_dim | int | Hidden dimension. | required |
| output_dim | int | Output dimension. | required |
| num_layers | int | Number of layers. | required |
| act | nn.Module | Activation function. | nn.ReLU |
| sigmoid | bool | Whether to apply sigmoid to the output. | False |

Attributes

| Name | Type | Description |
| --- | --- | --- |
| num_layers | int | Number of layers in the MLP. |
| layers | nn.ModuleList | List of linear layers. |
| sigmoid | bool | Whether to apply sigmoid to the output. |
| act | nn.Module | Activation function. |

Methods

| Name | Description |
| --- | --- |
| forward | Forward pass for the entire MLP. |

Source code in ultralytics/nn/modules/transformer.py
class MLP(nn.Module):
    """A simple multi-layer perceptron (also called FFN).

    This class implements a configurable MLP with multiple linear layers, activation functions, and optional sigmoid
    output activation.

    Attributes:
        num_layers (int): Number of layers in the MLP.
        layers (nn.ModuleList): List of linear layers.
        sigmoid (bool): Whether to apply sigmoid to the output.
        act (nn.Module): Activation function.
    """

    def __init__(
        self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, act=nn.ReLU, sigmoid: bool = False
    ):
        """Initialize the MLP with specified input, hidden, output dimensions and number of layers.

        Args:
            input_dim (int): Input dimension.
            hidden_dim (int): Hidden dimension.
            output_dim (int): Output dimension.
            num_layers (int): Number of layers.
            act (nn.Module): Activation function.
            sigmoid (bool): Whether to apply sigmoid to the output.
        """
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim, *h], [*h, output_dim]))
        self.sigmoid = sigmoid
        self.act = act()


method ultralytics.nn.modules.transformer.MLP.forward

def forward(self, x: torch.Tensor) -> torch.Tensor

Forward pass for the entire MLP.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | torch.Tensor | Input tensor. | required |

Returns

| Type | Description |
| --- | --- |
| torch.Tensor | Output tensor after MLP. |

Source code in ultralytics/nn/modules/transformer.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Forward pass for the entire MLP.

    Args:
        x (torch.Tensor): Input tensor.

    Returns:
        (torch.Tensor): Output tensor after MLP.
    """
    for i, layer in enumerate(self.layers):
        x = getattr(self, "act", nn.ReLU())(layer(x)) if i < self.num_layers - 1 else layer(x)
    return x.sigmoid() if getattr(self, "sigmoid", False) else x
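
Example usage (a sketch resembling a box-regression head; the dimensions are illustrative assumptions): num_layers=3 builds Linear layers 256->256->256->4, with the activation applied between all but the last:

```python
import torch

from ultralytics.nn.modules.transformer import MLP

head = MLP(input_dim=256, hidden_dim=256, output_dim=4, num_layers=3, sigmoid=True)
x = torch.randn(1, 300, 256)       # e.g. 300 query embeddings
boxes = head(x)                    # [1, 300, 4], squashed to (0, 1) by sigmoid
assert boxes.shape == (1, 300, 4) and boxes.min() >= 0
```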





class ultralytics.nn.modules.transformer.LayerNorm2d

LayerNorm2d(self, num_channels: int, eps: float = 1e-6)

Bases: nn.Module

2D Layer Normalization module inspired by Detectron2 and ConvNeXt implementations.

This class implements layer normalization for 2D feature maps, normalizing across the channel dimension while preserving spatial dimensions.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| num_channels | int | Number of channels in the input. | required |
| eps | float | Small constant for numerical stability. | 1e-6 |

Attributes

| Name | Type | Description |
| --- | --- | --- |
| weight | nn.Parameter | Learnable scale parameter. |
| bias | nn.Parameter | Learnable bias parameter. |
| eps | float | Small constant for numerical stability. |

References

https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py
https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py

Methods

| Name | Description |
| --- | --- |
| forward | Perform forward pass for 2D layer normalization. |

Source code in ultralytics/nn/modules/transformer.py
class LayerNorm2d(nn.Module):
    """2D Layer Normalization module inspired by Detectron2 and ConvNeXt implementations.

    This class implements layer normalization for 2D feature maps, normalizing across the channel dimension while
    preserving spatial dimensions.

    Attributes:
        weight (nn.Parameter): Learnable scale parameter.
        bias (nn.Parameter): Learnable bias parameter.
        eps (float): Small constant for numerical stability.

    References:
        https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py
        https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py
    """

    def __init__(self, num_channels: int, eps: float = 1e-6):
        """Initialize LayerNorm2d with the given parameters.

        Args:
            num_channels (int): Number of channels in the input.
            eps (float): Small constant for numerical stability.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps


method ultralytics.nn.modules.transformer.LayerNorm2d.forward

def forward(self, x: torch.Tensor) -> torch.Tensor

Perform forward pass for 2D layer normalization.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | torch.Tensor | Input tensor. | required |

Returns

| Type | Description |
| --- | --- |
| torch.Tensor | Normalized output tensor. |

Source code in ultralytics/nn/modules/transformer.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Perform forward pass for 2D layer normalization.

    Args:
        x (torch.Tensor): Input tensor.

    Returns:
        (torch.Tensor): Normalized output tensor.
    """
    u = x.mean(1, keepdim=True)
    s = (x - u).pow(2).mean(1, keepdim=True)
    x = (x - u) / torch.sqrt(s + self.eps)
    return self.weight[:, None, None] * x + self.bias[:, None, None]
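
Example usage (a sketch): unlike nn.LayerNorm, this module takes [B, C, H, W] input directly and normalizes each spatial position across its C channels:

```python
import torch

from ultralytics.nn.modules.transformer import LayerNorm2d

ln = LayerNorm2d(64)
x = torch.randn(2, 64, 32, 32)
y = ln(x)
assert y.shape == x.shape
# Each (b, :, h, w) vector now has ~zero mean and ~unit variance
# (before the learnable affine scale and bias take effect).
```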





class ultralytics.nn.modules.transformer.MSDeformAttn

MSDeformAttn(self, d_model: int = 256, n_levels: int = 4, n_heads: int = 8, n_points: int = 4)

Bases: nn.Module

Multiscale Deformable Attention Module based on Deformable-DETR and PaddleDetection implementations.

This module implements multiscale deformable attention that can attend to features at multiple scales with learnable sampling locations and attention weights.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| d_model | int | Model dimension. | 256 |
| n_levels | int | Number of feature levels. | 4 |
| n_heads | int | Number of attention heads. | 8 |
| n_points | int | Number of sampling points per attention head per feature level. | 4 |

Attributes

| Name | Type | Description |
| --- | --- | --- |
| im2col_step | int | Step size for im2col operations. |
| d_model | int | Model dimension. |
| n_levels | int | Number of feature levels. |
| n_heads | int | Number of attention heads. |
| n_points | int | Number of sampling points per attention head per feature level. |
| sampling_offsets | nn.Linear | Linear layer for generating sampling offsets. |
| attention_weights | nn.Linear | Linear layer for generating attention weights. |
| value_proj | nn.Linear | Linear layer for projecting values. |
| output_proj | nn.Linear | Linear layer for projecting output. |

References

https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/modules/ms_deform_attn.py

Methods

| Name | Description |
| --- | --- |
| _reset_parameters | Reset module parameters. |
| forward | Perform forward pass for multiscale deformable attention. |

Source code in ultralytics/nn/modules/transformer.py
class MSDeformAttn(nn.Module):
    """Multiscale Deformable Attention Module based on Deformable-DETR and PaddleDetection implementations.

    This module implements multiscale deformable attention that can attend to features at multiple scales with learnable
    sampling locations and attention weights.

    Attributes:
        im2col_step (int): Step size for im2col operations.
        d_model (int): Model dimension.
        n_levels (int): Number of feature levels.
        n_heads (int): Number of attention heads.
        n_points (int): Number of sampling points per attention head per feature level.
        sampling_offsets (nn.Linear): Linear layer for generating sampling offsets.
        attention_weights (nn.Linear): Linear layer for generating attention weights.
        value_proj (nn.Linear): Linear layer for projecting values.
        output_proj (nn.Linear): Linear layer for projecting output.

    References:
        https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/modules/ms_deform_attn.py
    """

    def __init__(self, d_model: int = 256, n_levels: int = 4, n_heads: int = 8, n_points: int = 4):
        """Initialize MSDeformAttn with the given parameters.

        Args:
            d_model (int): Model dimension.
            n_levels (int): Number of feature levels.
            n_heads (int): Number of attention heads.
            n_points (int): Number of sampling points per attention head per feature level.
        """
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError(f"d_model must be divisible by n_heads, but got {d_model} and {n_heads}")
        _d_per_head = d_model // n_heads
        # Better to set _d_per_head to a power of 2 which is more efficient in a CUDA implementation
        assert _d_per_head * n_heads == d_model, "`d_model` must be divisible by `n_heads`"

        self.im2col_step = 64

        self.d_model = d_model
        self.n_levels = n_levels
        self.n_heads = n_heads
        self.n_points = n_points

        self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
        self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
        self.value_proj = nn.Linear(d_model, d_model)
        self.output_proj = nn.Linear(d_model, d_model)

        self._reset_parameters()


method ultralytics.nn.modules.transformer.MSDeformAttn._reset_parameters

def _reset_parameters(self)

Reset module parameters.

Source code in ultralytics/nn/modules/transformer.py
def _reset_parameters(self):
    """Reset module parameters."""
    constant_(self.sampling_offsets.weight.data, 0.0)
    thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
    grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
    grid_init = (
        (grid_init / grid_init.abs().max(-1, keepdim=True)[0])
        .view(self.n_heads, 1, 1, 2)
        .repeat(1, self.n_levels, self.n_points, 1)
    )
    for i in range(self.n_points):
        grid_init[:, :, i, :] *= i + 1
    with torch.no_grad():
        self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
    constant_(self.attention_weights.weight.data, 0.0)
    constant_(self.attention_weights.bias.data, 0.0)
    xavier_uniform_(self.value_proj.weight.data)
    constant_(self.value_proj.bias.data, 0.0)
    xavier_uniform_(self.output_proj.weight.data)
    constant_(self.output_proj.bias.data, 0.0)


method ultralytics.nn.modules.transformer.MSDeformAttn.forward

def forward(
    self,
    query: torch.Tensor,
    refer_bbox: torch.Tensor,
    value: torch.Tensor,
    value_shapes: list,
    value_mask: torch.Tensor | None = None,
) -> torch.Tensor

Perform forward pass for multiscale deformable attention.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| query | torch.Tensor | Query tensor with shape [bs, query_length, C]. | required |
| refer_bbox | torch.Tensor | Reference bounding boxes with shape [bs, query_length, n_levels, 2], range in [0, 1], top-left (0, 0), bottom-right (1, 1), including padding area. | required |
| value | torch.Tensor | Value tensor with shape [bs, value_length, C]. | required |
| value_shapes | list | List with shape [n_levels, 2], [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]. | required |
| value_mask | torch.Tensor, optional | Mask tensor with shape [bs, value_length], True for non-padding elements, False for padding elements. | None |

Returns

| Type | Description |
| --- | --- |
| torch.Tensor | Output tensor with shape [bs, Length_{query}, C]. |

References

https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/transformers/deformable_transformer.py

Source code in ultralytics/nn/modules/transformer.py
def forward(
    self,
    query: torch.Tensor,
    refer_bbox: torch.Tensor,
    value: torch.Tensor,
    value_shapes: list,
    value_mask: torch.Tensor | None = None,
) -> torch.Tensor:
    """Perform forward pass for multiscale deformable attention.

    Args:
        query (torch.Tensor): Query tensor with shape [bs, query_length, C].
        refer_bbox (torch.Tensor): Reference bounding boxes with shape [bs, query_length, n_levels, 2], range in [0,
            1], top-left (0,0), bottom-right (1, 1), including padding area.
        value (torch.Tensor): Value tensor with shape [bs, value_length, C].
        value_shapes (list): List with shape [n_levels, 2], [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})].
        value_mask (torch.Tensor, optional): Mask tensor with shape [bs, value_length], True for non-padding
            elements, False for padding elements.

    Returns:
        (torch.Tensor): Output tensor with shape [bs, Length_{query}, C].

    References:
        https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/transformers/deformable_transformer.py
    """
    bs, len_q = query.shape[:2]
    len_v = value.shape[1]
    assert sum(s[0] * s[1] for s in value_shapes) == len_v

    value = self.value_proj(value)
    if value_mask is not None:
        value = value.masked_fill(value_mask[..., None], float(0))
    value = value.view(bs, len_v, self.n_heads, self.d_model // self.n_heads)
    sampling_offsets = self.sampling_offsets(query).view(bs, len_q, self.n_heads, self.n_levels, self.n_points, 2)
    attention_weights = self.attention_weights(query).view(bs, len_q, self.n_heads, self.n_levels * self.n_points)
    attention_weights = F.softmax(attention_weights, -1).view(bs, len_q, self.n_heads, self.n_levels, self.n_points)
    # N, Len_q, n_heads, n_levels, n_points, 2
    num_points = refer_bbox.shape[-1]
    if num_points == 2:
        offset_normalizer = torch.as_tensor(value_shapes, dtype=query.dtype, device=query.device).flip(-1)
        add = sampling_offsets / offset_normalizer[None, None, None, :, None, :]
        sampling_locations = refer_bbox[:, :, None, :, None, :] + add
    elif num_points == 4:
        add = sampling_offsets / self.n_points * refer_bbox[:, :, None, :, None, 2:] * 0.5
        sampling_locations = refer_bbox[:, :, None, :, None, :2] + add
    else:
        raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {num_points}.")
    output = multi_scale_deformable_attn_pytorch(value, value_shapes, sampling_locations, attention_weights)
    return self.output_proj(output)
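
Example usage (a sketch under assumed sizes: two feature levels, 300 queries, and reference points given as normalized (x, y) centers):

```python
import torch

from ultralytics.nn.modules.transformer import MSDeformAttn

attn = MSDeformAttn(d_model=256, n_levels=2, n_heads=8, n_points=4)
shapes = [(32, 32), (16, 16)]                    # (H, W) per feature level
value = torch.randn(1, 32 * 32 + 16 * 16, 256)   # flattened multiscale features
query = torch.randn(1, 300, 256)
refer_bbox = torch.rand(1, 300, 2, 2)            # [bs, queries, n_levels, 2] in [0, 1]
out = attn(query, refer_bbox, value, shapes)
assert out.shape == (1, 300, 256)
```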





class ultralytics.nn.modules.transformer.DeformableTransformerDecoderLayer

def __init__(
    self,
    d_model: int = 256,
    n_heads: int = 8,
    d_ffn: int = 1024,
    dropout: float = 0.0,
    act: nn.Module = nn.ReLU(),
    n_levels: int = 4,
    n_points: int = 4,
)

Bases: nn.Module

Deformable Transformer Decoder Layer inspired by PaddleDetection and Deformable-DETR implementations.

This class implements a single decoder layer with self-attention, cross-attention using multiscale deformable attention, and a feedforward network.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| d_model | int | Model dimension. | 256 |
| n_heads | int | Number of attention heads. | 8 |
| d_ffn | int | Dimension of the feedforward network. | 1024 |
| dropout | float | Dropout probability. | 0.0 |
| act | nn.Module | Activation function. | nn.ReLU() |
| n_levels | int | Number of feature levels. | 4 |
| n_points | int | Number of sampling points. | 4 |

Attributes

| Name | Type | Description |
| --- | --- | --- |
| self_attn | nn.MultiheadAttention | Self-attention module. |
| dropout1 | nn.Dropout | Dropout after self-attention. |
| norm1 | nn.LayerNorm | Layer normalization after self-attention. |
| cross_attn | MSDeformAttn | Cross-attention module. |
| dropout2 | nn.Dropout | Dropout after cross-attention. |
| norm2 | nn.LayerNorm | Layer normalization after cross-attention. |
| linear1 | nn.Linear | First linear layer in the feedforward network. |
| act | nn.Module | Activation function. |
| dropout3 | nn.Dropout | Dropout in the feedforward network. |
| linear2 | nn.Linear | Second linear layer in the feedforward network. |
| dropout4 | nn.Dropout | Dropout after the feedforward network. |
| norm3 | nn.LayerNorm | Layer normalization after the feedforward network. |

References

https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/transformers/deformable_transformer.py
https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/deformable_transformer.py

Methods

| Name | Description |
| --- | --- |
| forward | Perform the forward pass through the entire decoder layer. |
| forward_ffn | Perform forward pass through the Feed-Forward Network part of the layer. |
| with_pos_embed | Add positional embeddings to the input tensor, if provided. |

Source code in ultralytics/nn/modules/transformer.py
class DeformableTransformerDecoderLayer(nn.Module):
    """Deformable Transformer Decoder Layer inspired by PaddleDetection and Deformable-DETR implementations.

    This class implements a single decoder layer with self-attention, cross-attention using multiscale deformable
    attention, and a feedforward network.

    Attributes:
        self_attn (nn.MultiheadAttention): Self-attention module.
        dropout1 (nn.Dropout): Dropout after self-attention.
        norm1 (nn.LayerNorm): Layer normalization after self-attention.
        cross_attn (MSDeformAttn): Cross-attention module.
        dropout2 (nn.Dropout): Dropout after cross-attention.
        norm2 (nn.LayerNorm): Layer normalization after cross-attention.
        linear1 (nn.Linear): First linear layer in the feedforward network.
        act (nn.Module): Activation function.
        dropout3 (nn.Dropout): Dropout in the feedforward network.
        linear2 (nn.Linear): Second linear layer in the feedforward network.
        dropout4 (nn.Dropout): Dropout after the feedforward network.
        norm3 (nn.LayerNorm): Layer normalization after the feedforward network.

    References:
        https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/transformers/deformable_transformer.py
        https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/deformable_transformer.py
    """

    def __init__(
        self,
        d_model: int = 256,
        n_heads: int = 8,
        d_ffn: int = 1024,
        dropout: float = 0.0,
        act: nn.Module = nn.ReLU(),
        n_levels: int = 4,
        n_points: int = 4,
    ):
        """Initialize the DeformableTransformerDecoderLayer with the given parameters.

        Args:
            d_model (int): Model dimension.
            n_heads (int): Number of attention heads.
            d_ffn (int): Dimension of the feedforward network.
            dropout (float): Dropout probability.
            act (nn.Module): Activation function.
            n_levels (int): Number of feature levels.
            n_points (int): Number of sampling points.
        """
        super().__init__()

        # Self attention
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # Cross attention
        self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
        self.dropout2 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)

        # FFN
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.act = act
        self.dropout3 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ffn, d_model)
        self.dropout4 = nn.Dropout(dropout)
        self.norm3 = nn.LayerNorm(d_model)


method ultralytics.nn.modules.transformer.DeformableTransformerDecoderLayer.forward

def forward(
    self,
    embed: torch.Tensor,
    refer_bbox: torch.Tensor,
    feats: torch.Tensor,
    shapes: list,
    padding_mask: torch.Tensor | None = None,
    attn_mask: torch.Tensor | None = None,
    query_pos: torch.Tensor | None = None,
) -> torch.Tensor

Perform the forward pass through the entire decoder layer.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| embed | torch.Tensor | Input embeddings. | required |
| refer_bbox | torch.Tensor | Reference bounding boxes. | required |
| feats | torch.Tensor | Feature maps. | required |
| shapes | list | Feature shapes. | required |
| padding_mask | torch.Tensor, optional | Padding mask. | None |
| attn_mask | torch.Tensor, optional | Attention mask. | None |
| query_pos | torch.Tensor, optional | Query position embeddings. | None |

Returns

| Type | Description |
| --- | --- |
| torch.Tensor | Output tensor after decoder layer. |

Source code in ultralytics/nn/modules/transformer.py
def forward(
    self,
    embed: torch.Tensor,
    refer_bbox: torch.Tensor,
    feats: torch.Tensor,
    shapes: list,
    padding_mask: torch.Tensor | None = None,
    attn_mask: torch.Tensor | None = None,
    query_pos: torch.Tensor | None = None,
) -> torch.Tensor:
    """Perform the forward pass through the entire decoder layer.

    Args:
        embed (torch.Tensor): Input embeddings.
        refer_bbox (torch.Tensor): Reference bounding boxes.
        feats (torch.Tensor): Feature maps.
        shapes (list): Feature shapes.
        padding_mask (torch.Tensor, optional): Padding mask.
        attn_mask (torch.Tensor, optional): Attention mask.
        query_pos (torch.Tensor, optional): Query position embeddings.

    Returns:
        (torch.Tensor): Output tensor after decoder layer.
    """
    # Self attention
    q = k = self.with_pos_embed(embed, query_pos)
    tgt = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), embed.transpose(0, 1), attn_mask=attn_mask)[
        0
    ].transpose(0, 1)
    embed = embed + self.dropout1(tgt)
    embed = self.norm1(embed)

    # Cross attention
    tgt = self.cross_attn(
        self.with_pos_embed(embed, query_pos), refer_bbox.unsqueeze(2), feats, shapes, padding_mask
    )
    embed = embed + self.dropout2(tgt)
    embed = self.norm2(embed)

    # FFN
    return self.forward_ffn(embed)
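
Example usage (a sketch with assumed sizes; refer_bbox carries normalized (cx, cy, w, h) boxes per query, which the layer unsqueezes to add a level dimension for cross-attention):

```python
import torch

from ultralytics.nn.modules.transformer import DeformableTransformerDecoderLayer

layer = DeformableTransformerDecoderLayer(d_model=256, n_heads=8, d_ffn=1024, n_levels=2)
shapes = [(32, 32), (16, 16)]
feats = torch.randn(1, 32 * 32 + 16 * 16, 256)   # flattened encoder features
embed = torch.randn(1, 300, 256)                 # decoder query embeddings
refer_bbox = torch.rand(1, 300, 4)               # normalized (cx, cy, w, h)
out = layer(embed, refer_bbox, feats, shapes)
assert out.shape == embed.shape
```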


method ultralytics.nn.modules.transformer.DeformableTransformerDecoderLayer.forward_ffn

def forward_ffn(self, tgt: torch.Tensor) -> torch.Tensor

Perform forward pass through the Feed-Forward Network part of the layer.

Args

NameTypeDescriptionDefault
tgttorch.TensorInput tensor.required

Returns

TypeDescription
torch.TensorOutput tensor after FFN.
Source code in ultralytics/nn/modules/transformer.pyView on GitHub
def forward_ffn(self, tgt: torch.Tensor) -> torch.Tensor:
    """Perform forward pass through the Feed-Forward Network part of the layer.

    Args:
        tgt (torch.Tensor): Input tensor.

    Returns:
        (torch.Tensor): Output tensor after FFN.
    """
    tgt2 = self.linear2(self.dropout3(self.act(self.linear1(tgt))))
    tgt = tgt + self.dropout4(tgt2)
    return self.norm3(tgt)


method ultralytics.nn.modules.transformer.DeformableTransformerDecoderLayer.with_pos_embed

def with_pos_embed(tensor: torch.Tensor, pos: torch.Tensor | None) -> torch.Tensor

Add positional embeddings to the input tensor, if provided.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| tensor | torch.Tensor | Input tensor. | required |
| pos | torch.Tensor \| None | Positional embeddings to add, if provided. | required |

Source code in ultralytics/nn/modules/transformer.py
@staticmethod
def with_pos_embed(tensor: torch.Tensor, pos: torch.Tensor | None) -> torch.Tensor:
    """Add positional embeddings to the input tensor, if provided."""
    return tensor if pos is None else tensor + pos





class ultralytics.nn.modules.transformer.DeformableTransformerDecoder

DeformableTransformerDecoder(self, hidden_dim: int, decoder_layer: nn.Module, num_layers: int, eval_idx: int = -1)

Bases: nn.Module

Deformable Transformer Decoder based on PaddleDetection implementation.

This class implements a complete deformable transformer decoder with multiple decoder layers and prediction heads for bounding box regression and classification.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| hidden_dim | int | Hidden dimension. | required |
| decoder_layer | nn.Module | Decoder layer module. | required |
| num_layers | int | Number of decoder layers. | required |
| eval_idx | int | Index of the layer to use during evaluation. | -1 |

Attributes

| Name | Type | Description |
| --- | --- | --- |
| layers | nn.ModuleList | List of decoder layers. |
| num_layers | int | Number of decoder layers. |
| hidden_dim | int | Hidden dimension. |
| eval_idx | int | Index of the layer to use during evaluation. |

References

https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/transformers/deformable_transformer.py

Methods

| Name | Description |
| --- | --- |
| forward | Perform the forward pass through the entire decoder. |

Source code in ultralytics/nn/modules/transformer.py
class DeformableTransformerDecoder(nn.Module):
    """Deformable Transformer Decoder based on PaddleDetection implementation.

    This class implements a complete deformable transformer decoder with multiple decoder layers and prediction heads
    for bounding box regression and classification.

    Attributes:
        layers (nn.ModuleList): List of decoder layers.
        num_layers (int): Number of decoder layers.
        hidden_dim (int): Hidden dimension.
        eval_idx (int): Index of the layer to use during evaluation.

    References:
        https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/transformers/deformable_transformer.py
    """

    def __init__(self, hidden_dim: int, decoder_layer: nn.Module, num_layers: int, eval_idx: int = -1):
        """Initialize the DeformableTransformerDecoder with the given parameters.

        Args:
            hidden_dim (int): Hidden dimension.
            decoder_layer (nn.Module): Decoder layer module.
            num_layers (int): Number of decoder layers.
            eval_idx (int): Index of the layer to use during evaluation.
        """
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.eval_idx = eval_idx if eval_idx >= 0 else num_layers + eval_idx


method ultralytics.nn.modules.transformer.DeformableTransformerDecoder.forward

def forward(
    self,
    embed: torch.Tensor,  # decoder embeddings
    refer_bbox: torch.Tensor,  # anchor
    feats: torch.Tensor,  # image features
    shapes: list,  # feature shapes
    bbox_head: nn.Module,
    score_head: nn.Module,
    pos_mlp: nn.Module,
    attn_mask: torch.Tensor | None = None,
    padding_mask: torch.Tensor | None = None,
)

Perform the forward pass through the entire decoder.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| embed | torch.Tensor | Decoder embeddings. | required |
| refer_bbox | torch.Tensor | Reference bounding boxes. | required |
| feats | torch.Tensor | Image features. | required |
| shapes | list | Feature shapes. | required |
| bbox_head | nn.Module | Bounding box prediction head. | required |
| score_head | nn.Module | Score prediction head. | required |
| pos_mlp | nn.Module | Position MLP. | required |
| attn_mask | torch.Tensor, optional | Attention mask. | None |
| padding_mask | torch.Tensor, optional | Padding mask. | None |

Returns

| Type | Description |
| --- | --- |
| dec_bboxes (torch.Tensor) | Decoded bounding boxes. |
| dec_cls (torch.Tensor) | Decoded classification scores. |

Source code in ultralytics/nn/modules/transformer.py
def forward(
    self,
    embed: torch.Tensor,  # decoder embeddings
    refer_bbox: torch.Tensor,  # anchor
    feats: torch.Tensor,  # image features
    shapes: list,  # feature shapes
    bbox_head: nn.Module,
    score_head: nn.Module,
    pos_mlp: nn.Module,
    attn_mask: torch.Tensor | None = None,
    padding_mask: torch.Tensor | None = None,
):
    """Perform the forward pass through the entire decoder.

    Args:
        embed (torch.Tensor): Decoder embeddings.
        refer_bbox (torch.Tensor): Reference bounding boxes.
        feats (torch.Tensor): Image features.
        shapes (list): Feature shapes.
        bbox_head (nn.Module): Bounding box prediction head.
        score_head (nn.Module): Score prediction head.
        pos_mlp (nn.Module): Position MLP.
        attn_mask (torch.Tensor, optional): Attention mask.
        padding_mask (torch.Tensor, optional): Padding mask.

    Returns:
        dec_bboxes (torch.Tensor): Decoded bounding boxes.
        dec_cls (torch.Tensor): Decoded classification scores.
    """
    output = embed
    dec_bboxes = []
    dec_cls = []
    last_refined_bbox = None
    refer_bbox = refer_bbox.sigmoid()
    for i, layer in enumerate(self.layers):
        output = layer(output, refer_bbox, feats, shapes, padding_mask, attn_mask, pos_mlp(refer_bbox))

        bbox = bbox_head[i](output)
        refined_bbox = torch.sigmoid(bbox + inverse_sigmoid(refer_bbox))

        if self.training:
            dec_cls.append(score_head[i](output))
            if i == 0:
                dec_bboxes.append(refined_bbox)
            else:
                dec_bboxes.append(torch.sigmoid(bbox + inverse_sigmoid(last_refined_bbox)))
        elif i == self.eval_idx:
            dec_cls.append(score_head[i](output))
            dec_bboxes.append(refined_bbox)
            break

        last_refined_bbox = refined_bbox
        refer_bbox = refined_bbox.detach() if self.training else refined_bbox

    return torch.stack(dec_bboxes), torch.stack(dec_cls)
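
Example usage (a sketch; the per-layer prediction heads and position MLP mimic how RTDETRDecoder wires this module, with assumed sizes and an assumed 80-class head):

```python
import torch
import torch.nn as nn

from ultralytics.nn.modules.transformer import (
    MLP,
    DeformableTransformerDecoder,
    DeformableTransformerDecoderLayer,
)

num_layers, nq, nc = 2, 300, 80
decoder = DeformableTransformerDecoder(
    hidden_dim=256,
    decoder_layer=DeformableTransformerDecoderLayer(d_model=256, n_levels=2),
    num_layers=num_layers,
)
shapes = [(32, 32), (16, 16)]
feats = torch.randn(1, sum(h * w for h, w in shapes), 256)
embed = torch.randn(1, nq, 256)
refer_bbox = torch.randn(1, nq, 4)  # logits; the decoder applies sigmoid internally
bbox_head = nn.ModuleList(MLP(256, 256, 4, num_layers=3) for _ in range(num_layers))
score_head = nn.ModuleList(nn.Linear(256, nc) for _ in range(num_layers))
pos_mlp = MLP(4, 2 * 256, 256, num_layers=2)
dec_bboxes, dec_cls = decoder(embed, refer_bbox, feats, shapes, bbox_head, score_head, pos_mlp)
# In training mode the stacks cover every layer: [num_layers, 1, nq, 4] and
# [num_layers, 1, nq, nc]; in eval mode only the eval_idx layer is returned.
```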




