Skip to content

Reference for ultralytics/models/sam/modules/utils.py

Note

This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/modules/utils.py. If you spot a problem please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!


ultralytics.models.sam.modules.utils.select_closest_cond_frames

select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num)

Select the closest conditioning frames to a given frame index.

Parameters:

Name Type Description Default
frame_idx int

Current frame index.

required
cond_frame_outputs Dict[int, Any]

Dictionary of conditioning frame outputs keyed by frame indices.

required
max_cond_frame_num int

Maximum number of conditioning frames to select.

required

Returns:

Name Type Description
selected_outputs Dict[int, Any]

Selected items from cond_frame_outputs.

unselected_outputs Dict[int, Any]

Items not selected from cond_frame_outputs.

Examples:

>>> frame_idx = 5
>>> cond_frame_outputs = {1: "a", 3: "b", 7: "c", 9: "d"}
>>> max_cond_frame_num = 2
>>> selected, unselected = select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num)
>>> print(selected)
{3: 'b', 7: 'c'}
>>> print(unselected)
{1: 'a', 9: 'd'}
Source code in ultralytics/models/sam/modules/utils.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num):
    """
    Select the closest conditioning frames to a given frame index.

    Args:
        frame_idx (int): Current frame index.
        cond_frame_outputs (Dict[int, Any]): Dictionary of conditioning frame outputs keyed by frame indices.
        max_cond_frame_num (int): Maximum number of conditioning frames to select; -1 means no limit.

    Returns:
        selected_outputs (Dict[int, Any]): Selected items from cond_frame_outputs.
        unselected_outputs (Dict[int, Any]): Items not selected from cond_frame_outputs.

    Examples:
        >>> frame_idx = 5
        >>> cond_frame_outputs = {1: "a", 3: "b", 7: "c", 9: "d"}
        >>> max_cond_frame_num = 2
        >>> selected, unselected = select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num)
        >>> print(selected)
        {3: 'b', 7: 'c'}
        >>> print(unselected)
        {1: 'a', 9: 'd'}
    """
    # No filtering needed when the limit is disabled (-1) or everything already fits.
    if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num:
        return cond_frame_outputs, {}

    assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames"

    selected = {}
    # Nearest conditioning frame strictly before `frame_idx` (if any).
    before = max((t for t in cond_frame_outputs if t < frame_idx), default=None)
    # Nearest conditioning frame at or after `frame_idx` (if any).
    after = min((t for t in cond_frame_outputs if t >= frame_idx), default=None)
    for anchor in (before, after):
        if anchor is not None:
            selected[anchor] = cond_frame_outputs[anchor]

    # Fill remaining slots with the temporally closest of the leftover frames.
    budget = max_cond_frame_num - len(selected)
    closest_leftovers = sorted(
        (t for t in cond_frame_outputs if t not in selected),
        key=lambda t: abs(t - frame_idx),
    )[:budget]
    for t in closest_leftovers:
        selected[t] = cond_frame_outputs[t]

    unselected = {t: v for t, v in cond_frame_outputs.items() if t not in selected}
    return selected, unselected





ultralytics.models.sam.modules.utils.get_1d_sine_pe

get_1d_sine_pe(pos_inds, dim, temperature=10000)

Generate 1D sinusoidal positional embeddings for given positions and dimensions.

Parameters:

Name Type Description Default
pos_inds Tensor

Position indices for which to generate embeddings.

required
dim int

Dimension of the positional embeddings. Should be an even number.

required
temperature float

Scaling factor for the frequency of the sinusoidal functions.

10000

Returns:

Type Description
Tensor

Sinusoidal positional embeddings with shape (pos_inds.shape, dim).

Examples:

>>> pos = torch.tensor([0, 1, 2, 3])
>>> embeddings = get_1d_sine_pe(pos, 128)
>>> embeddings.shape
torch.Size([4, 128])
Source code in ultralytics/models/sam/modules/utils.py
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
def get_1d_sine_pe(pos_inds, dim, temperature=10000):
    """
    Generate 1D sinusoidal positional embeddings for given positions and dimensions.

    Args:
        pos_inds (torch.Tensor): Position indices for which to generate embeddings.
        dim (int): Dimension of the positional embeddings. Should be an even number.
        temperature (float): Scaling factor for the frequency of the sinusoidal functions.

    Returns:
        (torch.Tensor): Sinusoidal positional embeddings with shape (pos_inds.shape, dim).

    Examples:
        >>> pos = torch.tensor([0, 1, 2, 3])
        >>> embeddings = get_1d_sine_pe(pos, 128)
        >>> embeddings.shape
        torch.Size([4, 128])
    """
    half_dim = dim // 2
    # Geometric frequency spectrum: pairs of channels share the same frequency.
    freq = torch.arange(half_dim, dtype=torch.float32, device=pos_inds.device)
    freq = temperature ** (2 * (freq // 2) / half_dim)

    # Phase angles per (position, frequency); sin and cos halves are concatenated.
    angles = pos_inds.unsqueeze(-1) / freq
    return torch.cat([angles.sin(), angles.cos()], dim=-1)





ultralytics.models.sam.modules.utils.init_t_xy

init_t_xy(end_x: int, end_y: int)

Initialize 1D and 2D coordinate tensors for a grid of specified dimensions.

This function creates coordinate tensors for a grid with dimensions end_x × end_y. It generates a linear index tensor and corresponding x and y coordinate tensors.

Parameters:

Name Type Description Default
end_x int

Width of the grid (number of columns).

required
end_y int

Height of the grid (number of rows).

required

Returns:

Name Type Description
t_x Tensor

X-coordinates for each position, with shape (end_x * end_y).

t_y Tensor

Y-coordinates for each position, with shape (end_x * end_y).

Examples:

>>> t_x, t_y = init_t_xy(3, 2)
>>> print(t_x)
tensor([0., 1., 2., 0., 1., 2.])
>>> print(t_y)
tensor([0., 0., 0., 1., 1., 1.])
Source code in ultralytics/models/sam/modules/utils.py
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
def init_t_xy(end_x: int, end_y: int):
    """
    Initialize 2D coordinate tensors for a grid of specified dimensions.

    This function enumerates the end_x × end_y grid positions in row-major order and decomposes each linear
    index into its x (column) and y (row) coordinate.

    Args:
        end_x (int): Width of the grid (number of columns).
        end_y (int): Height of the grid (number of rows).

    Returns:
        t_x (torch.Tensor): X-coordinates for each position, with shape (end_x * end_y,).
        t_y (torch.Tensor): Y-coordinates for each position, with shape (end_x * end_y,).

    Examples:
        >>> t_x, t_y = init_t_xy(3, 2)
        >>> print(t_x)
        tensor([0., 1., 2., 0., 1., 2.])
        >>> print(t_y)
        tensor([0., 0., 0., 1., 1., 1.])
    """
    t = torch.arange(end_x * end_y, dtype=torch.float32)
    t_x = (t % end_x).float()  # column index cycles every end_x positions
    t_y = torch.div(t, end_x, rounding_mode="floor").float()  # row index advances every end_x positions
    return t_x, t_y





ultralytics.models.sam.modules.utils.compute_axial_cis

compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0)

Compute axial complex exponential positional encodings for 2D spatial positions in a grid.

This function generates complex exponential positional encodings for a 2D grid of spatial positions, using separate frequency components for the x and y dimensions.

Parameters:

Name Type Description Default
dim int

Dimension of the positional encoding.

required
end_x int

Width of the 2D grid.

required
end_y int

Height of the 2D grid.

required
theta float

Scaling factor for frequency computation.

10000.0

Returns:

Name Type Description
freqs_cis Tensor

Concatenated complex exponential positional encodings for the x and y axes, with shape (end_x * end_y, dim // 2); the first dim // 4 channels encode x and the last dim // 4 encode y.

Examples:

>>> dim, end_x, end_y = 128, 8, 8
>>> freqs_cis = compute_axial_cis(dim, end_x, end_y)
>>> freqs_cis.shape
torch.Size([64, 64])
Source code in ultralytics/models/sam/modules/utils.py
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
    """
    Compute axial complex exponential positional encodings for 2D spatial positions in a grid.

    This function generates complex exponential positional encodings for a 2D grid of spatial positions,
    using separate frequency components for the x and y dimensions, and concatenates them along the last axis.

    Args:
        dim (int): Dimension of the positional encoding.
        end_x (int): Width of the 2D grid.
        end_y (int): Height of the 2D grid.
        theta (float, optional): Scaling factor for frequency computation.

    Returns:
        (torch.Tensor): Concatenated complex exponential positional encodings for the x and y axes with shape
            (end_x * end_y, dim // 2); the first dim // 4 channels encode x and the last dim // 4 encode y.

    Examples:
        >>> dim, end_x, end_y = 128, 8, 8
        >>> freqs_cis = compute_axial_cis(dim, end_x, end_y)
        >>> freqs_cis.shape
        torch.Size([64, 64])
    """
    # The same geometric frequency spectrum (dim // 4 frequencies) is used for both axes.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))

    t_x, t_y = init_t_xy(end_x, end_y)
    freqs_x = torch.outer(t_x, freqs)
    freqs_y = torch.outer(t_y, freqs)
    # Unit-magnitude complex exponentials e^{i*angle} for each (position, frequency) pair.
    freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
    freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y)
    return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1)





ultralytics.models.sam.modules.utils.reshape_for_broadcast

reshape_for_broadcast(freqs_cis: Tensor, x: Tensor)

Reshape frequency tensor for broadcasting with input tensor.

Reshapes a frequency tensor to ensure dimensional compatibility for broadcasting with an input tensor. This function is typically used in positional encoding operations.

Parameters:

Name Type Description Default
freqs_cis Tensor

Frequency tensor with shape matching the last two dimensions of x.

required
x Tensor

Input tensor to broadcast with.

required

Returns:

Type Description
Tensor

Reshaped frequency tensor ready for broadcasting with the input tensor.

Raises:

Type Description
AssertionError

If the shape of freqs_cis doesn't match the last two dimensions of x.

Source code in ultralytics/models/sam/modules/utils.py
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """
    Reshape frequency tensor for broadcasting with input tensor.

    Reshapes a frequency tensor to ensure dimensional compatibility for broadcasting with an input tensor.
    This function is typically used in positional encoding operations.

    Args:
        freqs_cis (torch.Tensor): Frequency tensor with shape matching the last two dimensions of x.
        x (torch.Tensor): Input tensor to broadcast with.

    Returns:
        (torch.Tensor): Reshaped frequency tensor ready for broadcasting with the input tensor.

    Raises:
        AssertionError: If the shape of freqs_cis doesn't match the last two dimensions of x.
    """
    ndim = x.ndim
    assert ndim > 1
    assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
    # Collapse every leading dimension to size 1 so broadcasting aligns on the last two dims.
    target_shape = [1] * (ndim - 2) + list(x.shape[-2:])
    return freqs_cis.view(*target_shape)





ultralytics.models.sam.modules.utils.apply_rotary_enc

apply_rotary_enc(
    xq: Tensor, xk: Tensor, freqs_cis: Tensor, repeat_freqs_k: bool = False
)

Apply rotary positional encoding to query and key tensors.

This function applies rotary positional encoding (RoPE) to query and key tensors using complex-valued frequency components. RoPE is a technique that injects relative position information into self-attention mechanisms.

Parameters:

Name Type Description Default
xq Tensor

Query tensor to encode with positional information.

required
xk Tensor

Key tensor to encode with positional information.

required
freqs_cis Tensor

Complex-valued frequency components for rotary encoding with shape matching the last two dimensions of xq.

required
repeat_freqs_k bool

Whether to repeat frequency components along sequence length dimension to match key sequence length.

False

Returns:

Name Type Description
xq_out Tensor

Query tensor with rotary positional encoding applied.

xk_out Tensor

Key tensor with rotary positional encoding applied, or original xk if xk is empty.

Examples:

>>> import torch
>>> xq = torch.randn(2, 8, 16, 64)  # [batch, heads, seq_len, dim]
>>> xk = torch.randn(2, 8, 16, 64)
>>> freqs_cis = compute_axial_cis(64, 4, 4)  # For a 4x4 spatial grid with dim=64
>>> q_encoded, k_encoded = apply_rotary_enc(xq, xk, freqs_cis)
Source code in ultralytics/models/sam/modules/utils.py
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
def apply_rotary_enc(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
    repeat_freqs_k: bool = False,
):
    """
    Apply rotary positional encoding to query and key tensors.

    This function applies rotary positional encoding (RoPE) to query and key tensors using complex-valued frequency
    components. RoPE is a technique that injects relative position information into self-attention mechanisms.

    Args:
        xq (torch.Tensor): Query tensor to encode with positional information.
        xk (torch.Tensor): Key tensor to encode with positional information.
        freqs_cis (torch.Tensor): Complex-valued frequency components for rotary encoding with shape matching the
            last two dimensions of xq.
        repeat_freqs_k (bool, optional): Whether to repeat frequency components along sequence length dimension
            to match key sequence length.

    Returns:
        xq_out (torch.Tensor): Query tensor with rotary positional encoding applied.
        xk_out (torch.Tensor): Key tensor with rotary positional encoding applied, or original xk if xk is empty.

    Examples:
        >>> import torch
        >>> xq = torch.randn(2, 8, 16, 64)  # [batch, heads, seq_len, dim]
        >>> xk = torch.randn(2, 8, 16, 64)
        >>> freqs_cis = compute_axial_cis(64, 4, 4)  # For a 4x4 spatial grid with dim=64
        >>> q_encoded, k_encoded = apply_rotary_enc(xq, xk, freqs_cis)
    """
    # View the last dim as adjacent (real, imag) pairs: (..., d) -> complex (..., d // 2).
    q_complex = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    k_complex = None
    if xk.shape[-2] != 0:
        k_complex = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, q_complex)
    # Complex multiplication rotates each pair by its positional phase angle.
    q_rotated = torch.view_as_real(q_complex * freqs_cis).flatten(3)
    if k_complex is None:
        # No keys to rotate, due to dropout
        return q_rotated.type_as(xq).to(xq.device), xk
    if repeat_freqs_k:
        # Tile frequencies along the sequence axis so they cover the longer key sequence.
        ratio = k_complex.shape[-2] // q_complex.shape[-2]
        freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), ratio, 1)
    k_rotated = torch.view_as_real(k_complex * freqs_cis).flatten(3)
    return q_rotated.type_as(xq).to(xq.device), k_rotated.type_as(xk).to(xk.device)





ultralytics.models.sam.modules.utils.window_partition

window_partition(x, window_size)

Partition input tensor into non-overlapping windows with padding if needed.

Parameters:

Name Type Description Default
x Tensor

Input tensor with shape (B, H, W, C).

required
window_size int

Size of each window.

required

Returns:

Name Type Description
windows Tensor

Partitioned windows with shape (B * num_windows, window_size, window_size, C).

padded_h_w Tuple[int, int]

Padded height and width before partition.

Examples:

>>> x = torch.randn(1, 16, 16, 3)
>>> windows, (Hp, Wp) = window_partition(x, window_size=4)
>>> print(windows.shape, Hp, Wp)
torch.Size([16, 4, 4, 3]) 16 16
Source code in ultralytics/models/sam/modules/utils.py
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
def window_partition(x, window_size):
    """
    Partition input tensor into non-overlapping windows with padding if needed.

    Args:
        x (torch.Tensor): Input tensor with shape (B, H, W, C).
        window_size (int): Size of each window.

    Returns:
        windows (torch.Tensor): Partitioned windows with shape (B * num_windows, window_size, window_size, C).
        padded_h_w (Tuple[int, int]): Padded height and width before partition.

    Examples:
        >>> x = torch.randn(1, 16, 16, 3)
        >>> windows, (Hp, Wp) = window_partition(x, window_size=4)
        >>> print(windows.shape, Hp, Wp)
        torch.Size([16, 4, 4, 3]) 16 16
    """
    B, H, W, C = x.shape

    # Pad bottom/right so both spatial dims become divisible by window_size.
    pad_h = (window_size - H % window_size) % window_size
    pad_w = (window_size - W % window_size) % window_size
    if pad_h or pad_w:
        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
    Hp, Wp = H + pad_h, W + pad_w

    # Reshape into a grid of windows, then flatten the grid into the batch dim.
    grid_h, grid_w = Hp // window_size, Wp // window_size
    grid = x.view(B, grid_h, window_size, grid_w, window_size, C)
    windows = grid.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows, (Hp, Wp)





ultralytics.models.sam.modules.utils.window_unpartition

window_unpartition(windows, window_size, pad_hw, hw)

Unpartition windowed sequences into original sequences and remove padding.

This function reverses the windowing process, reconstructing the original input from windowed segments and removing any padding that was added during the windowing process.

Parameters:

Name Type Description Default
windows Tensor

Input tensor of windowed sequences with shape (B * num_windows, window_size, window_size, C), where B is the batch size, num_windows is the number of windows, window_size is the size of each window, and C is the number of channels.

required
window_size int

Size of each window.

required
pad_hw Tuple[int, int]

Padded height and width (Hp, Wp) of the input before windowing.

required
hw Tuple[int, int]

Original height and width (H, W) of the input before padding and windowing.

required

Returns:

Type Description
Tensor

Unpartitioned sequences with shape (B, H, W, C), where B is the batch size, H and W are the original height and width, and C is the number of channels.

Examples:

>>> windows = torch.rand(32, 8, 8, 64)  # 32 windows of size 8x8 with 64 channels
>>> pad_hw = (16, 16)  # Padded height and width
>>> hw = (15, 14)  # Original height and width
>>> x = window_unpartition(windows, window_size=8, pad_hw=pad_hw, hw=hw)
>>> print(x.shape)
torch.Size([1, 15, 14, 64])
Source code in ultralytics/models/sam/modules/utils.py
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
def window_unpartition(windows, window_size, pad_hw, hw):
    """
    Unpartition windowed sequences into original sequences and remove padding.

    This function reverses the windowing process, reconstructing the original input from windowed segments
    and removing any padding that was added during the windowing process.

    Args:
        windows (torch.Tensor): Input tensor of windowed sequences with shape (B * num_windows, window_size,
            window_size, C), where B is the batch size, num_windows is the number of windows, window_size is
            the size of each window, and C is the number of channels.
        window_size (int): Size of each window.
        pad_hw (Tuple[int, int]): Padded height and width (Hp, Wp) of the input before windowing.
        hw (Tuple[int, int]): Original height and width (H, W) of the input before padding and windowing.

    Returns:
        (torch.Tensor): Unpartitioned sequences with shape (B, H, W, C), where B is the batch size, H and W
            are the original height and width, and C is the number of channels.

    Examples:
        >>> windows = torch.rand(32, 8, 8, 64)  # 32 windows of size 8x8 with 64 channels
        >>> pad_hw = (16, 16)  # Padded height and width
        >>> hw = (15, 14)  # Original height and width
        >>> x = window_unpartition(windows, window_size=8, pad_hw=pad_hw, hw=hw)
        >>> print(x.shape)
        torch.Size([1, 15, 14, 64])
    """
    Hp, Wp = pad_hw
    H, W = hw
    B = windows.shape[0] // (Hp * Wp // window_size // window_size)
    x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)

    if Hp > H or Wp > W:
        x = x[:, :H, :W, :].contiguous()
    return x





ultralytics.models.sam.modules.utils.get_rel_pos

get_rel_pos(q_size: int, k_size: int, rel_pos: Tensor) -> torch.Tensor

Extract relative positional embeddings based on query and key sizes.

Parameters:

Name Type Description Default
q_size int

Size of the query.

required
k_size int

Size of the key.

required
rel_pos Tensor

Relative position embeddings with shape (L, C), where L is the maximum relative distance and C is the embedding dimension.

required

Returns:

Type Description
Tensor

Extracted positional embeddings according to relative positions, with shape (q_size, k_size, C).

Examples:

>>> q_size, k_size = 8, 16
>>> rel_pos = torch.randn(31, 64)  # 31 = 2 * max(8, 16) - 1
>>> extracted_pos = get_rel_pos(q_size, k_size, rel_pos)
>>> print(extracted_pos.shape)
torch.Size([8, 16, 64])
Source code in ultralytics/models/sam/modules/utils.py
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
    """
    Extract relative positional embeddings based on query and key sizes.

    Args:
        q_size (int): Size of the query.
        k_size (int): Size of the key.
        rel_pos (torch.Tensor): Relative position embeddings with shape (L, C), where L is the maximum relative
            distance and C is the embedding dimension.

    Returns:
        (torch.Tensor): Extracted positional embeddings according to relative positions, with shape (q_size,
            k_size, C).

    Examples:
        >>> q_size, k_size = 8, 16
        >>> rel_pos = torch.randn(31, 64)  # 31 = 2 * max(8, 16) - 1
        >>> extracted_pos = get_rel_pos(q_size, k_size, rel_pos)
        >>> print(extracted_pos.shape)
        torch.Size([8, 16, 64])
    """
    max_rel_dist = int(2 * max(q_size, k_size) - 1)
    if rel_pos.shape[0] == max_rel_dist:
        rel_pos_resized = rel_pos
    else:
        # Linearly resample the embedding table to the required number of relative offsets.
        rel_pos_resized = F.interpolate(
            rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
            size=max_rel_dist,
            mode="linear",
        )
        rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)

    # When q and k sizes differ, scale coordinates by the long/short side ratio so they align.
    q_scale = max(k_size / q_size, 1.0)
    k_scale = max(q_size / k_size, 1.0)
    q_coords = torch.arange(q_size)[:, None] * q_scale
    k_coords = torch.arange(k_size)[None, :] * k_scale
    # Shift so the smallest relative offset maps to index 0.
    relative_coords = (q_coords - k_coords) + (k_size - 1) * k_scale

    return rel_pos_resized[relative_coords.long()]





ultralytics.models.sam.modules.utils.add_decomposed_rel_pos

add_decomposed_rel_pos(
    attn: Tensor,
    q: Tensor,
    rel_pos_h: Tensor,
    rel_pos_w: Tensor,
    q_size: Tuple[int, int],
    k_size: Tuple[int, int],
) -> torch.Tensor

Add decomposed Relative Positional Embeddings to the attention map.

This function calculates and applies decomposed Relative Positional Embeddings as described in the MVITv2 paper. It enhances the attention mechanism by incorporating spatial relationships between query and key positions.

Parameters:

Name Type Description Default
attn Tensor

Attention map with shape (B, q_h * q_w, k_h * k_w).

required
q Tensor

Query tensor in the attention layer with shape (B, q_h * q_w, C).

required
rel_pos_h Tensor

Relative position embeddings for height axis with shape (Lh, C).

required
rel_pos_w Tensor

Relative position embeddings for width axis with shape (Lw, C).

required
q_size Tuple[int, int]

Spatial sequence size of query q as (q_h, q_w).

required
k_size Tuple[int, int]

Spatial sequence size of key k as (k_h, k_w).

required

Returns:

Type Description
Tensor

Updated attention map with added relative positional embeddings, shape (B, q_h * q_w, k_h * k_w).

Examples:

>>> B, C, q_h, q_w, k_h, k_w = 1, 64, 8, 8, 8, 8
>>> attn = torch.rand(B, q_h * q_w, k_h * k_w)
>>> q = torch.rand(B, q_h * q_w, C)
>>> rel_pos_h = torch.rand(2 * max(q_h, k_h) - 1, C)
>>> rel_pos_w = torch.rand(2 * max(q_w, k_w) - 1, C)
>>> q_size, k_size = (q_h, q_w), (k_h, k_w)
>>> updated_attn = add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size)
>>> print(updated_attn.shape)
torch.Size([1, 64, 64])
References

https://github.com/facebookresearch/mvit/blob/main/mvit/models/attention.py

Source code in ultralytics/models/sam/modules/utils.py
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
def add_decomposed_rel_pos(
    attn: torch.Tensor,
    q: torch.Tensor,
    rel_pos_h: torch.Tensor,
    rel_pos_w: torch.Tensor,
    q_size: Tuple[int, int],
    k_size: Tuple[int, int],
) -> torch.Tensor:
    """
    Add decomposed Relative Positional Embeddings to the attention map.

    This function calculates and applies decomposed Relative Positional Embeddings as described in the MVITv2
    paper. It enhances the attention mechanism by incorporating spatial relationships between query and key
    positions.

    Args:
        attn (torch.Tensor): Attention map with shape (B, q_h * q_w, k_h * k_w).
        q (torch.Tensor): Query tensor in the attention layer with shape (B, q_h * q_w, C).
        rel_pos_h (torch.Tensor): Relative position embeddings for height axis with shape (Lh, C).
        rel_pos_w (torch.Tensor): Relative position embeddings for width axis with shape (Lw, C).
        q_size (Tuple[int, int]): Spatial sequence size of query q as (q_h, q_w).
        k_size (Tuple[int, int]): Spatial sequence size of key k as (k_h, k_w).

    Returns:
        (torch.Tensor): Updated attention map with added relative positional embeddings, shape
            (B, q_h * q_w, k_h * k_w).

    Examples:
        >>> B, C, q_h, q_w, k_h, k_w = 1, 64, 8, 8, 8, 8
        >>> attn = torch.rand(B, q_h * q_w, k_h * k_w)
        >>> q = torch.rand(B, q_h * q_w, C)
        >>> rel_pos_h = torch.rand(2 * max(q_h, k_h) - 1, C)
        >>> rel_pos_w = torch.rand(2 * max(q_w, k_w) - 1, C)
        >>> q_size, k_size = (q_h, q_w), (k_h, k_w)
        >>> updated_attn = add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size)
        >>> print(updated_attn.shape)
        torch.Size([1, 64, 64])

    References:
        https://github.com/facebookresearch/mvit/blob/main/mvit/models/attention.py
    """
    q_h, q_w = q_size
    k_h, k_w = k_size
    # Per-axis relative position embeddings: (q_h, k_h, C) and (q_w, k_w, C).
    Rh = get_rel_pos(q_h, k_h, rel_pos_h)
    Rw = get_rel_pos(q_w, k_w, rel_pos_w)

    B, _, dim = q.shape
    r_q = q.reshape(B, q_h, q_w, dim)
    # Project queries onto the axial embeddings: biases of shape (B, q_h, q_w, k_h) and (B, q_h, q_w, k_w).
    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)

    # Broadcast-add the two axial biases over a 5D view of attn, then flatten back.
    attn_5d = attn.view(B, q_h, q_w, k_h, k_w)
    attn_5d = attn_5d + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
    return attn_5d.view(B, q_h * q_w, k_h * k_w)





📅 Created 8 months ago ✏️ Updated 7 months ago