
Reference for ultralytics/models/sam/modules/transformer.py

Note

This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/modules/transformer.py. If you spot a problem please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!


ultralytics.models.sam.modules.transformer.TwoWayTransformer

TwoWayTransformer(
    depth: int,
    embedding_dim: int,
    num_heads: int,
    mlp_dim: int,
    activation: Type[Module] = nn.ReLU,
    attention_downsample_rate: int = 2,
)

Bases: Module

A Two-Way Transformer module for simultaneous attention to image and query points.

This class implements a specialized transformer decoder that attends to an input image using queries with supplied positional embeddings. It's useful for tasks like object detection, image segmentation, and point cloud processing.

Attributes:

    depth (int): Number of layers in the transformer.
    embedding_dim (int): Channel dimension for input embeddings.
    num_heads (int): Number of heads for multihead attention.
    mlp_dim (int): Internal channel dimension for the MLP block.
    layers (ModuleList): List of TwoWayAttentionBlock layers composing the transformer.
    final_attn_token_to_image (Attention): Final attention layer from queries to image.
    norm_final_attn (LayerNorm): Layer normalization applied to final queries.

Methods:

    forward: Processes image and point embeddings through the transformer.

Examples:

>>> transformer = TwoWayTransformer(depth=6, embedding_dim=256, num_heads=8, mlp_dim=2048)
>>> image_embedding = torch.randn(1, 256, 32, 32)
>>> image_pe = torch.randn(1, 256, 32, 32)
>>> point_embedding = torch.randn(1, 100, 256)
>>> output_queries, output_image = transformer(image_embedding, image_pe, point_embedding)
>>> print(output_queries.shape, output_image.shape)
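
With these example shapes, the printed sizes would be expected to be torch.Size([1, 100, 256]) for output_queries and torch.Size([1, 1024, 256]) for output_image, since the 32x32 image embedding is flattened to 1024 tokens.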

Parameters:

    depth (int, required): Number of layers in the transformer.
    embedding_dim (int, required): Channel dimension for input embeddings.
    num_heads (int, required): Number of heads for multihead attention. Must divide embedding_dim.
    mlp_dim (int, required): Internal channel dimension for the MLP block.
    activation (Type[Module], default ReLU): Activation function to use in the MLP block.
    attention_downsample_rate (int, default 2): Downsampling rate for attention mechanism.

Source code in ultralytics/models/sam/modules/transformer.py (lines 41-81)
def __init__(
    self,
    depth: int,
    embedding_dim: int,
    num_heads: int,
    mlp_dim: int,
    activation: Type[nn.Module] = nn.ReLU,
    attention_downsample_rate: int = 2,
) -> None:
    """
    Initialize a Two-Way Transformer for simultaneous attention to image and query points.

    Args:
        depth (int): Number of layers in the transformer.
        embedding_dim (int): Channel dimension for input embeddings.
        num_heads (int): Number of heads for multihead attention. Must divide embedding_dim.
        mlp_dim (int): Internal channel dimension for the MLP block.
        activation (Type[nn.Module]): Activation function to use in the MLP block.
        attention_downsample_rate (int): Downsampling rate for attention mechanism.
    """
    super().__init__()
    self.depth = depth
    self.embedding_dim = embedding_dim
    self.num_heads = num_heads
    self.mlp_dim = mlp_dim
    self.layers = nn.ModuleList()

    for i in range(depth):
        self.layers.append(
            TwoWayAttentionBlock(
                embedding_dim=embedding_dim,
                num_heads=num_heads,
                mlp_dim=mlp_dim,
                activation=activation,
                attention_downsample_rate=attention_downsample_rate,
                skip_first_layer_pe=(i == 0),
            )
        )

    self.final_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
    self.norm_final_attn = nn.LayerNorm(embedding_dim)

forward

forward(
    image_embedding: Tensor, image_pe: Tensor, point_embedding: Tensor
) -> Tuple[Tensor, Tensor]

Process image and point embeddings through the Two-Way Transformer.

Parameters:

    image_embedding (Tensor, required): Image to attend to, with shape (B, embedding_dim, H, W).
    image_pe (Tensor, required): Positional encoding to add to the image, with same shape as image_embedding.
    point_embedding (Tensor, required): Embedding to add to query points, with shape (B, N_points, embedding_dim).

Returns:

    queries (Tensor): Processed point embeddings with shape (B, N_points, embedding_dim).
    keys (Tensor): Processed image embeddings with shape (B, H*W, embedding_dim).

Source code in ultralytics/models/sam/modules/transformer.py (lines 83-125)
def forward(
    self,
    image_embedding: Tensor,
    image_pe: Tensor,
    point_embedding: Tensor,
) -> Tuple[Tensor, Tensor]:
    """
    Process image and point embeddings through the Two-Way Transformer.

    Args:
        image_embedding (Tensor): Image to attend to, with shape (B, embedding_dim, H, W).
        image_pe (Tensor): Positional encoding to add to the image, with same shape as image_embedding.
        point_embedding (Tensor): Embedding to add to query points, with shape (B, N_points, embedding_dim).

    Returns:
        queries (Tensor): Processed point embeddings with shape (B, N_points, embedding_dim).
        keys (Tensor): Processed image embeddings with shape (B, H*W, embedding_dim).
    """
    # BxCxHxW -> BxHWxC == B x N_image_tokens x C
    image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
    image_pe = image_pe.flatten(2).permute(0, 2, 1)

    # Prepare queries
    queries = point_embedding
    keys = image_embedding

    # Apply transformer blocks and final layernorm
    for layer in self.layers:
        queries, keys = layer(
            queries=queries,
            keys=keys,
            query_pe=point_embedding,
            key_pe=image_pe,
        )

    # Apply the final attention layer from the points to the image
    q = queries + point_embedding
    k = keys + image_pe
    attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
    queries = queries + attn_out
    queries = self.norm_final_attn(queries)

    return queries, keys
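
The key step above is the flattening: the (B, C, H, W) image embedding and its positional encoding become sequences of H*W tokens of width C before any attention runs, while the point embeddings stay as (B, N_points, C) queries throughout. Below is a minimal shape-check sketch of that flow, using the import path shown at the top of this page; the shapes mirror the docstring example and are illustrative only.

>>> import torch
>>> from ultralytics.models.sam.modules.transformer import TwoWayTransformer
>>> B, C, H, W, N_points = 1, 256, 32, 32, 100
>>> transformer = TwoWayTransformer(depth=2, embedding_dim=C, num_heads=8, mlp_dim=2048)
>>> image_embedding = torch.randn(B, C, H, W)  # dense image features
>>> image_pe = torch.randn(B, C, H, W)  # positional encoding, same shape as the image embedding
>>> point_embedding = torch.randn(B, N_points, C)  # sparse query tokens
>>> queries, keys = transformer(image_embedding, image_pe, point_embedding)
>>> assert queries.shape == (B, N_points, C)  # one output token per query point
>>> assert keys.shape == (B, H * W, C)  # image flattened to H*W tokens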





ultralytics.models.sam.modules.transformer.TwoWayAttentionBlock

TwoWayAttentionBlock(
    embedding_dim: int,
    num_heads: int,
    mlp_dim: int = 2048,
    activation: Type[Module] = nn.ReLU,
    attention_downsample_rate: int = 2,
    skip_first_layer_pe: bool = False,
)

Bases: Module

A two-way attention block for simultaneous attention to image and query points.

This class implements a specialized transformer block with four main layers: self-attention on sparse inputs, cross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention of dense inputs to sparse inputs.

Attributes:

    self_attn (Attention): Self-attention layer for queries.
    norm1 (LayerNorm): Layer normalization after self-attention.
    cross_attn_token_to_image (Attention): Cross-attention layer from queries to keys.
    norm2 (LayerNorm): Layer normalization after token-to-image attention.
    mlp (MLPBlock): MLP block for transforming query embeddings.
    norm3 (LayerNorm): Layer normalization after MLP block.
    norm4 (LayerNorm): Layer normalization after image-to-token attention.
    cross_attn_image_to_token (Attention): Cross-attention layer from keys to queries.
    skip_first_layer_pe (bool): Whether to skip positional encoding in the first layer.

Methods:

    forward: Applies self-attention and cross-attention to queries and keys.

Examples:

>>> embedding_dim, num_heads = 256, 8
>>> block = TwoWayAttentionBlock(embedding_dim, num_heads)
>>> queries = torch.randn(1, 100, embedding_dim)
>>> keys = torch.randn(1, 1000, embedding_dim)
>>> query_pe = torch.randn(1, 100, embedding_dim)
>>> key_pe = torch.randn(1, 1000, embedding_dim)
>>> processed_queries, processed_keys = block(queries, keys, query_pe, key_pe)

This block implements a specialized transformer layer with four main components: self-attention on sparse inputs, cross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention of dense inputs to sparse inputs.

Parameters:

    embedding_dim (int, required): Channel dimension of the embeddings.
    num_heads (int, required): Number of attention heads in the attention layers.
    mlp_dim (int, default 2048): Hidden dimension of the MLP block.
    activation (Type[Module], default ReLU): Activation function for the MLP block.
    attention_downsample_rate (int, default 2): Downsampling rate for the attention mechanism.
    skip_first_layer_pe (bool, default False): Whether to skip positional encoding in the first layer.

Source code in ultralytics/models/sam/modules/transformer.py (lines 160-197)
def __init__(
    self,
    embedding_dim: int,
    num_heads: int,
    mlp_dim: int = 2048,
    activation: Type[nn.Module] = nn.ReLU,
    attention_downsample_rate: int = 2,
    skip_first_layer_pe: bool = False,
) -> None:
    """
    Initialize a TwoWayAttentionBlock for simultaneous attention to image and query points.

    This block implements a specialized transformer layer with four main components: self-attention on sparse
    inputs, cross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention
    of dense inputs to sparse inputs.

    Args:
        embedding_dim (int): Channel dimension of the embeddings.
        num_heads (int): Number of attention heads in the attention layers.
        mlp_dim (int): Hidden dimension of the MLP block.
        activation (Type[nn.Module]): Activation function for the MLP block.
        attention_downsample_rate (int): Downsampling rate for the attention mechanism.
        skip_first_layer_pe (bool): Whether to skip positional encoding in the first layer.
    """
    super().__init__()
    self.self_attn = Attention(embedding_dim, num_heads)
    self.norm1 = nn.LayerNorm(embedding_dim)

    self.cross_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
    self.norm2 = nn.LayerNorm(embedding_dim)

    self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
    self.norm3 = nn.LayerNorm(embedding_dim)

    self.norm4 = nn.LayerNorm(embedding_dim)
    self.cross_attn_image_to_token = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)

    self.skip_first_layer_pe = skip_first_layer_pe

forward

forward(
    queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
) -> Tuple[Tensor, Tensor]

Apply two-way attention to process query and key embeddings in a transformer block.

Parameters:

    queries (Tensor, required): Query embeddings with shape (B, N_queries, embedding_dim).
    keys (Tensor, required): Key embeddings with shape (B, N_keys, embedding_dim).
    query_pe (Tensor, required): Positional encodings for queries with same shape as queries.
    key_pe (Tensor, required): Positional encodings for keys with same shape as keys.

Returns:

    queries (Tensor): Processed query embeddings with shape (B, N_queries, embedding_dim).
    keys (Tensor): Processed key embeddings with shape (B, N_keys, embedding_dim).

Source code in ultralytics/models/sam/modules/transformer.py (lines 199-241)
def forward(self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor) -> Tuple[Tensor, Tensor]:
    """
    Apply two-way attention to process query and key embeddings in a transformer block.

    Args:
        queries (Tensor): Query embeddings with shape (B, N_queries, embedding_dim).
        keys (Tensor): Key embeddings with shape (B, N_keys, embedding_dim).
        query_pe (Tensor): Positional encodings for queries with same shape as queries.
        key_pe (Tensor): Positional encodings for keys with same shape as keys.

    Returns:
        queries (Tensor): Processed query embeddings with shape (B, N_queries, embedding_dim).
        keys (Tensor): Processed key embeddings with shape (B, N_keys, embedding_dim).
    """
    # Self attention block
    if self.skip_first_layer_pe:
        queries = self.self_attn(q=queries, k=queries, v=queries)
    else:
        q = queries + query_pe
        attn_out = self.self_attn(q=q, k=q, v=queries)
        queries = queries + attn_out
    queries = self.norm1(queries)

    # Cross attention block, tokens attending to image embedding
    q = queries + query_pe
    k = keys + key_pe
    attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
    queries = queries + attn_out
    queries = self.norm2(queries)

    # MLP block
    mlp_out = self.mlp(queries)
    queries = queries + mlp_out
    queries = self.norm3(queries)

    # Cross attention block, image embedding attending to tokens
    q = queries + query_pe
    k = keys + key_pe
    attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
    keys = keys + attn_out
    keys = self.norm4(keys)

    return queries, keys
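
Note the role swap in the final cross-attention above: the image tokens are passed as q and the point tokens as k and v, so the last update flows from the sparse queries back into the dense keys. Below is a minimal sketch verifying that both streams keep their shapes, mirroring the docstring example; the shapes are illustrative only.

>>> import torch
>>> from ultralytics.models.sam.modules.transformer import TwoWayAttentionBlock
>>> block = TwoWayAttentionBlock(embedding_dim=256, num_heads=8)
>>> queries = torch.randn(1, 100, 256)  # sparse tokens, e.g. prompt points
>>> keys = torch.randn(1, 1024, 256)  # dense tokens, e.g. a flattened 32x32 image embedding
>>> query_pe = torch.randn(1, 100, 256)
>>> key_pe = torch.randn(1, 1024, 256)
>>> out_queries, out_keys = block(queries, keys, query_pe, key_pe)
>>> assert out_queries.shape == queries.shape  # queries remain (B, N_queries, embedding_dim)
>>> assert out_keys.shape == keys.shape  # keys remain (B, N_keys, embedding_dim)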





ultralytics.models.sam.modules.transformer.Attention

Attention(
    embedding_dim: int,
    num_heads: int,
    downsample_rate: int = 1,
    kv_in_dim: int = None,
)

Bases: Module

An attention layer with downscaling capability for embedding size after projection.

This class implements a multi-head attention mechanism with the option to downsample the internal dimension of queries, keys, and values.

Attributes:

    embedding_dim (int): Dimensionality of input embeddings.
    kv_in_dim (int): Dimensionality of key and value inputs.
    internal_dim (int): Internal dimension after downsampling.
    num_heads (int): Number of attention heads.
    q_proj (Linear): Linear projection for queries.
    k_proj (Linear): Linear projection for keys.
    v_proj (Linear): Linear projection for values.
    out_proj (Linear): Linear projection for output.

Methods:

    _separate_heads: Separates input tensor into attention heads.
    _recombine_heads: Recombines separated attention heads.
    forward: Computes attention output for given query, key, and value tensors.

Examples:

>>> attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)
>>> q = torch.randn(1, 100, 256)
>>> k = v = torch.randn(1, 50, 256)
>>> output = attn(q, k, v)
>>> print(output.shape)
torch.Size([1, 100, 256])

Parameters:

    embedding_dim (int, required): Dimensionality of input embeddings.
    num_heads (int, required): Number of attention heads.
    downsample_rate (int, default 1): Factor by which internal dimensions are downsampled.
    kv_in_dim (int | None, default None): Dimensionality of key and value inputs. If None, uses embedding_dim.

Raises:

    AssertionError: If num_heads does not evenly divide the internal dim (embedding_dim / downsample_rate).

Source code in ultralytics/models/sam/modules/transformer.py (lines 275-304)
def __init__(
    self,
    embedding_dim: int,
    num_heads: int,
    downsample_rate: int = 1,
    kv_in_dim: int = None,
) -> None:
    """
    Initialize the Attention module with specified dimensions and settings.

    Args:
        embedding_dim (int): Dimensionality of input embeddings.
        num_heads (int): Number of attention heads.
        downsample_rate (int): Factor by which internal dimensions are downsampled.
        kv_in_dim (int | None): Dimensionality of key and value inputs. If None, uses embedding_dim.

    Raises:
        AssertionError: If num_heads does not evenly divide the internal dim (embedding_dim / downsample_rate).
    """
    super().__init__()
    self.embedding_dim = embedding_dim
    self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim
    self.internal_dim = embedding_dim // downsample_rate
    self.num_heads = num_heads
    assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."

    self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
    self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
    self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
    self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
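
The assertion above checks divisibility of the downsampled width, not of embedding_dim itself: for example, embedding_dim=256 with downsample_rate=2 gives internal_dim=128, so num_heads must divide 128. A quick illustration with hypothetical values that violate the constraint:

>>> from ultralytics.models.sam.modules.transformer import Attention
>>> Attention(embedding_dim=256, num_heads=7, downsample_rate=2)  # 128 % 7 != 0
Traceback (most recent call last):
    ...
AssertionError: num_heads must divide embedding_dim.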

forward

forward(q: Tensor, k: Tensor, v: Tensor) -> Tensor

Apply multi-head attention to query, key, and value tensors with optional downsampling.

Parameters:

    q (Tensor, required): Query tensor with shape (B, N_q, embedding_dim).
    k (Tensor, required): Key tensor with shape (B, N_k, embedding_dim).
    v (Tensor, required): Value tensor with shape (B, N_k, embedding_dim).

Returns:

    (Tensor): Output tensor after attention with shape (B, N_q, embedding_dim).

Source code in ultralytics/models/sam/modules/transformer.py (lines 320-351)
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
    """
    Apply multi-head attention to query, key, and value tensors with optional downsampling.

    Args:
        q (Tensor): Query tensor with shape (B, N_q, embedding_dim).
        k (Tensor): Key tensor with shape (B, N_k, embedding_dim).
        v (Tensor): Value tensor with shape (B, N_k, embedding_dim).

    Returns:
        (Tensor): Output tensor after attention with shape (B, N_q, embedding_dim).
    """
    # Input projections
    q = self.q_proj(q)
    k = self.k_proj(k)
    v = self.v_proj(v)

    # Separate into heads
    q = self._separate_heads(q, self.num_heads)
    k = self._separate_heads(k, self.num_heads)
    v = self._separate_heads(v, self.num_heads)

    # Attention
    _, _, _, c_per_head = q.shape
    attn = q @ k.permute(0, 1, 3, 2)  # B x N_heads x N_tokens x N_tokens
    attn = attn / math.sqrt(c_per_head)
    attn = torch.softmax(attn, dim=-1)

    # Get output
    out = attn @ v
    out = self._recombine_heads(out)
    return self.out_proj(out)
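
The downsampling above happens entirely in the projections: with embedding_dim=256 and downsample_rate=2, q_proj, k_proj, and v_proj map into internal_dim=128, the attention itself runs at that width (16 channels per head with 8 heads), and out_proj maps the result back to 256. A small sketch of that arithmetic, using only the attributes shown in the constructor above; the values are illustrative.

>>> import torch
>>> from ultralytics.models.sam.modules.transformer import Attention
>>> attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)
>>> attn.internal_dim  # attention runs at embedding_dim // downsample_rate channels
128
>>> attn.internal_dim // attn.num_heads  # channels per head inside the softmax
16
>>> q = torch.randn(1, 100, 256)
>>> k = v = torch.randn(1, 50, 256)
>>> attn(q, k, v).shape  # out_proj restores the full embedding_dim
torch.Size([1, 100, 256])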


