Reference for ultralytics/models/sam/modules/memory_attention.py

Note

This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/modules/memory_attention.py. If you spot a problem, please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!


ultralytics.models.sam.modules.memory_attention.MemoryAttentionLayer

MemoryAttentionLayer(
    d_model: int = 256,
    dim_feedforward: int = 2048,
    dropout: float = 0.1,
    pos_enc_at_attn: bool = False,
    pos_enc_at_cross_attn_keys: bool = True,
    pos_enc_at_cross_attn_queries: bool = False,
)

Bases: Module

Implements a memory attention layer with self-attention and cross-attention mechanisms for neural networks.

This class combines self-attention, cross-attention, and feedforward components to process input tensors and generate memory-based attention outputs.
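
In rough terms, the layer applies three pre-norm residual blocks in sequence: self-attention over the current tokens, cross-attention from those tokens to the memory, and a two-layer feedforward block. The sketch below is illustrative only; it assumes the default positional-encoding flags and simply reuses the layer's own sub-modules (see the _forward_sa and _forward_ca helpers in the source file for the exact behavior).

def memory_attention_layer_flow(tgt, memory, pos, layer):
    """Illustrative data flow of MemoryAttentionLayer with default flags (not the library API)."""
    # 1) Self-attention residual block (pre-norm)
    t = layer.norm1(tgt)
    tgt = tgt + layer.dropout1(layer.self_attn(q=t, k=t, v=t))
    # 2) Cross-attention residual block: queries from tgt, keys/values from the memory features
    t = layer.norm2(tgt)
    tgt = tgt + layer.dropout2(layer.cross_attn_image(q=t, k=memory + pos, v=memory))
    # 3) Feedforward residual block
    t = layer.norm3(tgt)
    return tgt + layer.dropout3(layer.linear2(layer.dropout(layer.activation(layer.linear1(t)))))

With the default flags, query_pos only enters the computation when pos_enc_at_attn or pos_enc_at_cross_attn_queries is enabled, which is why it is omitted from this sketch.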

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| d_model | int | Dimensionality of the model. |
| dim_feedforward | int | Dimensionality of the feedforward network. |
| dropout_value | float | Dropout rate for regularization. |
| self_attn | RoPEAttention | Self-attention mechanism using RoPE (Rotary Position Embedding). |
| cross_attn_image | RoPEAttention | Cross-attention mechanism for image processing. |
| linear1 | Linear | First linear layer of the feedforward network. |
| linear2 | Linear | Second linear layer of the feedforward network. |
| norm1 | LayerNorm | Layer normalization for self-attention output. |
| norm2 | LayerNorm | Layer normalization for cross-attention output. |
| norm3 | LayerNorm | Layer normalization for feedforward network output. |
| dropout1 | Dropout | Dropout layer after self-attention. |
| dropout2 | Dropout | Dropout layer after cross-attention. |
| dropout3 | Dropout | Dropout layer after feedforward network. |
| activation | ReLU | Activation function for the feedforward network. |
| pos_enc_at_attn | bool | Flag to add positional encoding at attention. |
| pos_enc_at_cross_attn_queries | bool | Flag to add positional encoding to cross-attention queries. |
| pos_enc_at_cross_attn_keys | bool | Flag to add positional encoding to cross-attention keys. |

Methods:

| Name | Description |
| --- | --- |
| forward | Performs the full memory attention operation on input tensors. |
| _forward_sa | Performs self-attention on input tensor. |
| _forward_ca | Performs cross-attention between target and memory tensors. |

Examples:

>>> layer = MemoryAttentionLayer(d_model=256, dim_feedforward=2048, dropout=0.1)
>>> tgt = torch.randn(1, 100, 256)
>>> memory = torch.randn(1, 100, 64)
>>> pos = torch.randn(1, 100, 64)  # matches the memory channel dim (cross-attention kv_in_dim=64)
>>> query_pos = torch.randn(1, 100, 256)
>>> output = layer(tgt, memory, pos, query_pos)
>>> print(output.shape)
torch.Size([1, 100, 256])

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| d_model | int | Dimensionality of the model. | 256 |
| dim_feedforward | int | Dimensionality of the feedforward network. | 2048 |
| dropout | float | Dropout rate for regularization. | 0.1 |
| pos_enc_at_attn | bool | Whether to add positional encoding at attention. | False |
| pos_enc_at_cross_attn_keys | bool | Whether to add positional encoding to cross-attention keys. | True |
| pos_enc_at_cross_attn_queries | bool | Whether to add positional encoding to cross-attention queries. | False |

Source code in ultralytics/models/sam/modules/memory_attention.py, lines 54-104
def __init__(
    self,
    d_model: int = 256,
    dim_feedforward: int = 2048,
    dropout: float = 0.1,
    pos_enc_at_attn: bool = False,
    pos_enc_at_cross_attn_keys: bool = True,
    pos_enc_at_cross_attn_queries: bool = False,
):
    """
    Initialize a memory attention layer with self-attention, cross-attention, and feedforward components.

    Args:
        d_model (int): Dimensionality of the model.
        dim_feedforward (int): Dimensionality of the feedforward network.
        dropout (float): Dropout rate for regularization.
        pos_enc_at_attn (bool): Whether to add positional encoding at attention.
        pos_enc_at_cross_attn_keys (bool): Whether to add positional encoding to cross-attention keys.
        pos_enc_at_cross_attn_queries (bool): Whether to add positional encoding to cross-attention queries.
    """
    super().__init__()
    self.d_model = d_model
    self.dim_feedforward = dim_feedforward
    self.dropout_value = dropout
    self.self_attn = RoPEAttention(embedding_dim=256, num_heads=1, downsample_rate=1)
    self.cross_attn_image = RoPEAttention(
        rope_k_repeat=True,
        embedding_dim=256,
        num_heads=1,
        downsample_rate=1,
        kv_in_dim=64,
    )

    # Implementation of Feedforward model
    self.linear1 = nn.Linear(d_model, dim_feedforward)
    self.dropout = nn.Dropout(dropout)
    self.linear2 = nn.Linear(dim_feedforward, d_model)

    self.norm1 = nn.LayerNorm(d_model)
    self.norm2 = nn.LayerNorm(d_model)
    self.norm3 = nn.LayerNorm(d_model)
    self.dropout1 = nn.Dropout(dropout)
    self.dropout2 = nn.Dropout(dropout)
    self.dropout3 = nn.Dropout(dropout)

    self.activation = nn.ReLU()

    # Where to add pos enc
    self.pos_enc_at_attn = pos_enc_at_attn
    self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries
    self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys

forward

forward(
    tgt: Tensor,
    memory: Tensor,
    pos: Optional[Tensor] = None,
    query_pos: Optional[Tensor] = None,
    num_k_exclude_rope: int = 0,
) -> torch.Tensor

Process input tensors through self-attention, cross-attention, and feedforward network layers.
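
Beyond the arguments shown in the class example, forward also accepts num_k_exclude_rope, which tells the RoPE cross-attention to skip rotary position encoding for that many trailing memory keys (SAM 2 appends non-spatial object-pointer tokens at the end of the memory sequence). A hedged call sketch with illustrative sizes; the 64x64 token grid and the 16 pointer tokens are assumptions, not requirements of the API:

>>> layer = MemoryAttentionLayer()  # defaults: d_model=256, cross-attention kv_in_dim=64
>>> tgt = torch.randn(1, 4096, 256)  # current-frame tokens (64x64 grid)
>>> memory = torch.randn(1, 4096 + 16, 64)  # spatial memory features + 16 object-pointer tokens
>>> pos = torch.randn(1, 4096 + 16, 64)  # positional encoding for the memory keys
>>> query_pos = torch.randn(1, 4096, 256)
>>> out = layer(tgt, memory, pos, query_pos, num_k_exclude_rope=16)
>>> out.shape
torch.Size([1, 4096, 256])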

Source code in ultralytics/models/sam/modules/memory_attention.py, lines 139-154
def forward(
    self,
    tgt: Tensor,
    memory: Tensor,
    pos: Optional[Tensor] = None,
    query_pos: Optional[Tensor] = None,
    num_k_exclude_rope: int = 0,
) -> torch.Tensor:
    """Process input tensors through self-attention, cross-attention, and feedforward network layers."""
    tgt = self._forward_sa(tgt, query_pos)
    tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope)
    # MLP
    tgt2 = self.norm3(tgt)
    tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
    tgt = tgt + self.dropout3(tgt2)
    return tgt





ultralytics.models.sam.modules.memory_attention.MemoryAttention

MemoryAttention(
    d_model: int,
    pos_enc_at_input: bool,
    layer: Module,
    num_layers: int,
    batch_first: bool = True,
)

Bases: Module

Memory attention module for processing sequential data with self and cross-attention mechanisms.

This class implements a multi-layer attention mechanism that combines self-attention and cross-attention for processing sequential data, particularly useful in transformer-like architectures.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| d_model | int | The dimension of the model's hidden state. |
| layers | ModuleList | A list of MemoryAttentionLayer modules. |
| num_layers | int | The number of attention layers. |
| norm | LayerNorm | Layer normalization applied to the output. |
| pos_enc_at_input | bool | Whether to apply positional encoding at the input. |
| batch_first | bool | Whether the input tensors are in batch-first format. |

Methods:

| Name | Description |
| --- | --- |
| forward | Processes input tensors through the attention layers. |

Examples:

>>> d_model = 256
>>> layer = MemoryAttentionLayer(d_model)
>>> attention = MemoryAttention(d_model, pos_enc_at_input=True, layer=layer, num_layers=3)
>>> curr = torch.randn(10, 32, d_model)  # (seq_len, batch_size, d_model)
>>> memory = torch.randn(20, 32, 64)  # (mem_len, batch_size, kv_in_dim of the layer's cross-attention)
>>> curr_pos = torch.randn(10, 32, d_model)
>>> memory_pos = torch.randn(20, 32, 64)
>>> output = attention(curr, memory, curr_pos, memory_pos)
>>> print(output.shape)
torch.Size([10, 32, 256])

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| d_model | int | The dimension of the model's hidden state. | required |
| pos_enc_at_input | bool | Whether to apply positional encoding at the input. | required |
| layer | Module | The attention layer to be used in the module. | required |
| num_layers | int | The number of attention layers. | required |
| batch_first | bool | Whether the input tensors are in batch-first format. | True |

Source code in ultralytics/models/sam/modules/memory_attention.py, lines 188-227
def __init__(
    self,
    d_model: int,
    pos_enc_at_input: bool,
    layer: nn.Module,
    num_layers: int,
    batch_first: bool = True,  # Do layers expect batch first input?
):
    """
    Initialize MemoryAttention with specified layers and normalization for sequential data processing.

    This class implements a multi-layer attention mechanism that combines self-attention and cross-attention
    for processing sequential data, particularly useful in transformer-like architectures.

    Args:
        d_model (int): The dimension of the model's hidden state.
        pos_enc_at_input (bool): Whether to apply positional encoding at the input.
        layer (nn.Module): The attention layer to be used in the module.
        num_layers (int): The number of attention layers.
        batch_first (bool): Whether the input tensors are in batch-first format.

    Examples:
        >>> d_model = 256
        >>> layer = MemoryAttentionLayer(d_model)
        >>> attention = MemoryAttention(d_model, pos_enc_at_input=True, layer=layer, num_layers=3)
        >>> curr = torch.randn(10, 32, d_model)  # (seq_len, batch_size, d_model)
        >>> memory = torch.randn(20, 32, d_model)  # (mem_len, batch_size, d_model)
        >>> curr_pos = torch.randn(10, 32, d_model)
        >>> memory_pos = torch.randn(20, 32, d_model)
        >>> output = attention(curr, memory, curr_pos, memory_pos)
        >>> print(output.shape)
        torch.Size([10, 32, 256])
    """
    super().__init__()
    self.d_model = d_model
    self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_layers)])
    self.num_layers = num_layers
    self.norm = nn.LayerNorm(d_model)
    self.pos_enc_at_input = pos_enc_at_input
    self.batch_first = batch_first
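
Because each of the num_layers entries is created with copy.deepcopy of the template layer, the stacked layers hold independent parameters rather than sharing a single module. A quick usage sketch:

>>> layer = MemoryAttentionLayer(d_model=256)
>>> attention = MemoryAttention(d_model=256, pos_enc_at_input=True, layer=layer, num_layers=3)
>>> len(attention.layers)
3
>>> attention.layers[0] is attention.layers[1]  # independent copies, not a shared module
False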

forward

forward(
    curr: Tensor,
    memory: Tensor,
    curr_pos: Optional[Tensor] = None,
    memory_pos: Optional[Tensor] = None,
    num_obj_ptr_tokens: int = 0,
) -> torch.Tensor

Process inputs through attention layers, applying self and cross-attention with positional encoding.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| curr | Tensor | Self-attention input tensor, representing the current state. | required |
| memory | Tensor | Cross-attention input tensor, representing memory information. | required |
| curr_pos | Optional[Tensor] | Positional encoding for self-attention inputs. | None |
| memory_pos | Optional[Tensor] | Positional encoding for cross-attention inputs. | None |
| num_obj_ptr_tokens | int | Number of object pointer tokens to exclude from rotary position embedding. | 0 |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Processed output tensor after applying attention layers and normalization. |

Examples:

>>> d_model = 256
>>> layer = MemoryAttentionLayer(d_model)
>>> attention = MemoryAttention(d_model, pos_enc_at_input=True, layer=layer, num_layers=3)
>>> curr = torch.randn(10, 32, d_model)  # (seq_len, batch_size, d_model)
>>> memory = torch.randn(20, 32, 64)  # (mem_len, batch_size, kv_in_dim of the layer's cross-attention)
>>> curr_pos = torch.randn(10, 32, d_model)
>>> memory_pos = torch.randn(20, 32, 64)
>>> output = attention(curr, memory, curr_pos, memory_pos)
>>> print(output.shape)
torch.Size([10, 32, 256])
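
In SAM 2 itself, the memory sequence typically concatenates several frames of spatial memory features, whose channel width matches the cross-attention kv_in_dim (64) of the default MemoryAttentionLayer, plus a few object-pointer tokens that are excluded from rotary encoding via num_obj_ptr_tokens. The sizes below are illustrative assumptions rather than values fixed by the API:

>>> layer = MemoryAttentionLayer()  # default layer: d_model=256, cross-attention kv_in_dim=64
>>> attention = MemoryAttention(d_model=256, pos_enc_at_input=True, layer=layer, num_layers=4)
>>> curr = torch.randn(4096, 1, 256)  # 64x64 current-frame tokens, batch of 1
>>> curr_pos = torch.randn(4096, 1, 256)
>>> memory = torch.randn(2 * 4096 + 16, 1, 64)  # two memory frames + 16 object-pointer tokens
>>> memory_pos = torch.randn(2 * 4096 + 16, 1, 64)
>>> out = attention(curr, memory, curr_pos, memory_pos, num_obj_ptr_tokens=16)
>>> out.shape
torch.Size([4096, 1, 256])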
Source code in ultralytics/models/sam/modules/memory_attention.py, lines 229-299
def forward(
    self,
    curr: torch.Tensor,  # self-attention inputs
    memory: torch.Tensor,  # cross-attention inputs
    curr_pos: Optional[Tensor] = None,  # pos_enc for self-attention inputs
    memory_pos: Optional[Tensor] = None,  # pos_enc for cross-attention inputs
    num_obj_ptr_tokens: int = 0,  # number of object pointer *tokens*
) -> torch.Tensor:
    """
    Process inputs through attention layers, applying self and cross-attention with positional encoding.

    Args:
        curr (torch.Tensor): Self-attention input tensor, representing the current state.
        memory (torch.Tensor): Cross-attention input tensor, representing memory information.
        curr_pos (Optional[Tensor]): Positional encoding for self-attention inputs.
        memory_pos (Optional[Tensor]): Positional encoding for cross-attention inputs.
        num_obj_ptr_tokens (int): Number of object pointer tokens to exclude from rotary position embedding.

    Returns:
        (torch.Tensor): Processed output tensor after applying attention layers and normalization.

    Examples:
        >>> d_model = 256
        >>> layer = MemoryAttentionLayer(d_model)
        >>> attention = MemoryAttention(d_model, pos_enc_at_input=True, layer=layer, num_layers=3)
        >>> curr = torch.randn(10, 32, d_model)  # (seq_len, batch_size, d_model)
        >>> memory = torch.randn(20, 32, d_model)  # (mem_len, batch_size, d_model)
        >>> curr_pos = torch.randn(10, 32, d_model)
        >>> memory_pos = torch.randn(20, 32, d_model)
        >>> output = attention(curr, memory, curr_pos, memory_pos)
        >>> print(output.shape)
        torch.Size([10, 32, 256])
    """
    if isinstance(curr, list):
        assert isinstance(curr_pos, list)
        assert len(curr) == len(curr_pos) == 1
        curr, curr_pos = curr[0], curr_pos[0]

    assert curr.shape[1] == memory.shape[1], "Batch size must be the same for curr and memory"

    output = curr
    if self.pos_enc_at_input and curr_pos is not None:
        output = output + 0.1 * curr_pos

    if self.batch_first:
        # Convert to batch first
        output = output.transpose(0, 1)
        curr_pos = curr_pos.transpose(0, 1)
        memory = memory.transpose(0, 1)
        memory_pos = memory_pos.transpose(0, 1)

    for layer in self.layers:
        kwds = {}
        if isinstance(layer.cross_attn_image, RoPEAttention):
            kwds = {"num_k_exclude_rope": num_obj_ptr_tokens}

        output = layer(
            tgt=output,
            memory=memory,
            pos=memory_pos,
            query_pos=curr_pos,
            **kwds,
        )
    normed_output = self.norm(output)

    if self.batch_first:
        # Convert back to seq first
        normed_output = normed_output.transpose(0, 1)
        curr_pos = curr_pos.transpose(0, 1)

    return normed_output




