Skip to content

Reference for ultralytics/models/sam/sam3/vitdet.py

Improvements

This page is sourced from https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/vitdet.py. Have an improvement or example to add? Open a Pull Request — thank you! 🙏


class ultralytics.models.sam.sam3.vitdet.Attention

def __init__(
    self,
    dim: int,
    num_heads: int = 8,
    qkv_bias: bool = True,
    use_rel_pos: bool = False,
    rel_pos_zero_init: bool = True,
    input_size: tuple[int, int] | None = None,
    cls_token: bool = False,
    use_rope: bool = False,
    rope_theta: float = 10000.0,
    rope_pt_size: tuple[int, int] | None = None,
    rope_interp: bool = False,
)

Bases: nn.Module

Multi-head Attention block with relative position embeddings and 2d-rope.

dim (int): Number of input channels. num_heads (int): Number of attention heads. qkv_bias (bool): If True, add a learnable bias to query, key, value. use_rel_pos (bool): If True, add relative positional embeddings to the attention map. rel_pos_zero_init (bool): If True, zero-initialize the relative positional parameters. input_size (tuple[int, int] or None): Input resolution used to size the relative positional parameters or the rope frequencies. cls_token (bool): Whether a cls token is present. use_rope (bool): Whether to use 2D rope (independent of use_rel_pos; both can be used together). rope_theta (float): Controls the rope frequencies. rope_pt_size (tuple[int, int] or None): Rope size from a previous training stage, needed for interpolation or tiling. rope_interp (bool): Whether to interpolate (or extrapolate) rope to match the input size.

Args

NameTypeDescriptionDefault
dimintrequired
num_headsint8
qkv_biasboolTrue
use_rel_posboolFalse
rel_pos_zero_initboolTrue
input_sizetuple[int, int] | NoneNone
cls_tokenboolFalse
use_ropeboolFalse
rope_thetafloat10000.0
rope_pt_sizetuple[int, int] | NoneNone
rope_interpboolFalse

Methods

NameDescription
_apply_ropeApply 2d-rope to q and k.
_setup_rel_posSetup relative positional embeddings.
_setup_rope_freqsSetup 2d-rope frequencies.
forwardForward pass of attention block.
Source code in ultralytics/models/sam/sam3/vitdet.pyView on GitHub
class Attention(nn.Module):
    """Multi-head Attention block with relative position embeddings and 2d-rope."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        input_size: tuple[int, int] | None = None,
        cls_token: bool = False,
        use_rope: bool = False,
        rope_theta: float = 10000.0,
        rope_pt_size: tuple[int, int] | None = None,
        rope_interp: bool = False,
    ):
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            input_size (int or None): Input resolution for calculating the relative positional parameter size or rope
                size.
            attn_type: Type of attention operation, e.g. "vanilla", "vanilla-xformer".
            cls_token: whether a cls_token is present.
            use_rope: whether to use rope 2d (indep of use_rel_pos, as it can be used together)
            use_rel_pos: whether to use relative positional embeddings
            rope_theta: control frequencies of rope
            rope_pt_size: size of rope in previous stage of training, needed for interpolation or tiling
            rope_interp: whether to interpolate (or extrapolate) rope to match input size.
        """
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.cls_token = cls_token

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

        # rel_pos embeddings and rope
        self.use_rel_pos = use_rel_pos
        self.input_size = input_size

        self.use_rope = use_rope
        self.rope_theta = rope_theta
        self.rope_pt_size = rope_pt_size
        self.rope_interp = rope_interp

        # init rel_pos embeddings and rope
        self._setup_rel_pos(rel_pos_zero_init, input_size)
        self._setup_rope_freqs(input_size)


method ultralytics.models.sam.sam3.vitdet.Attention._apply_rope

def _apply_rope(self, q, k) -> tuple[Tensor, Tensor]

Apply 2d-rope to q and k.

Args

NameTypeDescriptionDefault
qrequired
krequired
Source code in ultralytics/models/sam/sam3/vitdet.pyView on GitHub
def _apply_rope(self, q, k) -> tuple[Tensor, Tensor]:
    """Apply 2d-rope to q and k."""
    if not self.use_rope:
        return q, k

    assert self.freqs_cis is not None
    return apply_rotary_enc(q, k, freqs_cis=self.freqs_cis.to(q.device))


method ultralytics.models.sam.sam3.vitdet.Attention._setup_rel_pos

def _setup_rel_pos(self, rel_pos_zero_init: bool = True, input_size: tuple[int, int] | None = None) -> None

Setup relative positional embeddings.

Args

NameTypeDescriptionDefault
rel_pos_zero_initboolTrue
input_sizetuple[int, int] | NoneNone
Source code in ultralytics/models/sam/sam3/vitdet.pyView on GitHub
def _setup_rel_pos(self, rel_pos_zero_init: bool = True, input_size: tuple[int, int] | None = None) -> None:
    """Setup relative positional embeddings."""
    if not self.use_rel_pos:
        self.rel_pos_h = None
        self.rel_pos_w = None
        return

    assert input_size is not None
    assert self.cls_token is False, "not supported"
    # initialize relative positional embeddings
    self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, self.head_dim))
    self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, self.head_dim))

    if not rel_pos_zero_init:
        nn.init.trunc_normal_(self.rel_pos_h, std=0.02)
        nn.init.trunc_normal_(self.rel_pos_w, std=0.02)

    # Precompute the relative coords
    H, W = input_size
    q_coords = torch.arange(H)[:, None]
    k_coords = torch.arange(W)[None, :]
    relative_coords = (q_coords - k_coords) + (H - 1)
    self.relative_coords = relative_coords.long()


method ultralytics.models.sam.sam3.vitdet.Attention._setup_rope_freqs

def _setup_rope_freqs(self, input_size: tuple[int, int] | None = None) -> None

Setup 2d-rope frequencies.

Args

NameTypeDescriptionDefault
input_sizetuple[int, int] | NoneNone
Source code in ultralytics/models/sam/sam3/vitdet.pyView on GitHub
def _setup_rope_freqs(self, input_size: tuple[int, int] | None = None) -> None:
    """Setup 2d-rope frequencies."""
    if not self.use_rope:
        self.freqs_cis = None
        return

    assert input_size is not None
    # determine rope input size
    if self.rope_pt_size is None:
        self.rope_pt_size = input_size

    # initialize 2d rope freqs
    self.compute_cis = partial(
        compute_axial_cis,
        dim=self.head_dim,
        theta=self.rope_theta,
    )

    # interpolate rope
    scale_pos = 1.0
    if self.rope_interp:
        scale_pos = self.rope_pt_size[0] / input_size[0]
    # get scaled freqs_cis
    freqs_cis = self.compute_cis(
        end_x=input_size[0],
        end_y=input_size[1],
        scale_pos=scale_pos,
    )
    if self.cls_token:
        t = torch.zeros(
            self.head_dim // 2,
            dtype=torch.float32,
            device=freqs_cis.device,
        )
        cls_freqs_cis = torch.polar(torch.ones_like(t), t)[None, :]
        freqs_cis = torch.cat([cls_freqs_cis, freqs_cis], dim=0)

    self.freqs_cis = freqs_cis


method ultralytics.models.sam.sam3.vitdet.Attention.forward

def forward(self, x: Tensor) -> Tensor

Forward pass of attention block.

Args

NameTypeDescriptionDefault
xTensorrequired
Source code in ultralytics/models/sam/sam3/vitdet.pyView on GitHub
def forward(self, x: Tensor) -> Tensor:
    """Forward pass of the attention block.

    Args:
        x (Tensor): Either a spatial tensor of shape (B, H, W, C) or a token sequence of shape (B, L, C). A sequence
            may start with a cls token when ``self.cls_token`` is set; for sequence input the remaining tokens are
            treated as a square grid when relative position embeddings are used.

    Returns:
        (Tensor): Attention output (after the output projection) with the same layout as ``x``.
    """
    s = 1 if self.cls_token else 0  # number of leading cls tokens (excluded from the spatial grid)
    if x.ndim == 4:
        B, H, W, _ = x.shape
        assert s == 0  # cls token is only supported for sequence input
        L = H * W
        ndim = 4
    else:
        assert x.ndim == 3
        B, L, _ = x.shape
        ndim = 3
        # Integer side length of the (assumed square) token grid. math.sqrt would return a float, which breaks the
        # H * W reshape in the relative-position branch below.
        H = W = math.isqrt(L - s)

    # qkv: (B, L, 3, nHead, C) -> permute -> (3, B, nHead, L, C); q, k, v each (B, nHead, L, C)
    qkv = self.qkv(x).reshape(B, L, 3, self.num_heads, -1)
    q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)

    # handle rope (no-op when disabled) and rel pos embeddings
    q, k = self._apply_rope(q, k)
    if self.use_rel_pos:
        q, k = concat_rel_pos(
            q.flatten(0, 1),
            k.flatten(0, 1),
            (H, W),
            # Self-attention: the key grid equals the query grid. (x.shape[1:3] would be (L, C) for 3-D input.)
            (H, W),
            self.rel_pos_h,
            self.rel_pos_w,
            rescale=True,
            relative_coords=self.relative_coords,
        )

        # sdpa expects [B, nheads, H*W, C] so we reshape back
        q = q.reshape(B, self.num_heads, H * W, -1)
        k = k.reshape(B, self.num_heads, H * W, -1)

    x = F.scaled_dot_product_attention(q, k, v)

    # Merge heads back into the channel dim, restoring the input layout.
    if ndim == 4:
        x = x.view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
    else:
        x = x.view(B, self.num_heads, L, -1).permute(0, 2, 1, 3).reshape(B, L, -1)

    x = self.proj(x)

    return x





class ultralytics.models.sam.sam3.vitdet.Block

def __init__(
    self,
    dim: int,
    num_heads: int,
    mlp_ratio: float = 4.0,
    qkv_bias: bool = True,
    drop_path: float = 0.0,
    norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
    act_layer: Callable[..., nn.Module] = nn.GELU,
    use_rel_pos: bool = False,
    rel_pos_zero_init: bool = True,
    window_size: int = 0,
    input_size: tuple[int, int] | None = None,
    use_rope: bool = False,
    rope_pt_size: tuple[int, int] | None = None,
    rope_interp: bool = False,
    cls_token: bool = False,
    dropout: float = 0.0,
    init_values: float | None = None,
)

Bases: nn.Module

Transformer blocks with support of window attention.

dim (int): Number of input channels. num_heads (int): Number of attention heads in each ViT block. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. qkv_bias (bool): If True, add a learnable bias to query, key, value. drop_path (float): Stochastic depth rate. norm_layer (nn.Module): Normalization layer. act_layer (nn.Module): Activation layer. use_rel_pos (bool): If True, add relative positional embeddings to the attention map. rel_pos_zero_init (bool): If True, zero-initialize the relative positional parameters. window_size (int): Window size for window attention blocks; if 0, window attention is not used. input_size (tuple[int, int] or None): Input resolution used to size the relative positional parameters. dropout (float): Dropout rate. cls_token (bool): Whether a cls token is present. use_rope (bool): Whether to use 2D rope (independent of use_rel_pos; both can be used together). rope_pt_size (tuple[int, int] or None): Rope size from a previous training stage, needed for interpolation or tiling. rope_interp (bool): Whether to interpolate (or extrapolate) rope to match the target input size; the source size is expected in rope_pt_size. init_values (float or None): Layer-scale init value; None disables layer scale.

Args

NameTypeDescriptionDefault
dimintrequired
num_headsintrequired
mlp_ratiofloat4.0
qkv_biasboolTrue
drop_pathfloat0.0
norm_layerCallable[..., nn.Module]nn.LayerNorm
act_layerCallable[..., nn.Module]nn.GELU
use_rel_posboolFalse
rel_pos_zero_initboolTrue
window_sizeint0
input_sizetuple[int, int] | NoneNone
use_ropeboolFalse
rope_pt_sizetuple[int, int] | NoneNone
rope_interpboolFalse
cls_tokenboolFalse
dropoutfloat0.0
init_valuesfloat | NoneNone

Methods

NameDescription
forwardForward pass of the transformer block.
Source code in ultralytics/models/sam/sam3/vitdet.pyView on GitHub
class Block(nn.Module):
    """Transformer block (norm -> attention -> residual, norm -> MLP -> residual) with optional window attention."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        drop_path: float = 0.0,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        window_size: int = 0,
        input_size: tuple[int, int] | None = None,
        use_rope: bool = False,
        rope_pt_size: tuple[int, int] | None = None,
        rope_interp: bool = False,
        cls_token: bool = False,
        dropout: float = 0.0,
        init_values: float | None = None,
    ):
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            drop_path (float): Stochastic depth rate; 0 disables it.
            norm_layer (Callable[..., nn.Module]): Normalization layer factory, called with ``dim``.
            act_layer (Callable[..., nn.Module]): Activation layer used inside the MLP.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window attention. 0 means global (non-windowed) attention.
            input_size (tuple[int, int] | None): Token-grid size used to size the positional tables of a global
                attention block. Ignored for windowed blocks, whose tables use (window_size, window_size).
            use_rope (bool): Whether to use 2d rope (independent of use_rel_pos; both can be used together).
            rope_pt_size (tuple[int, int] | None): Rope size in a previous training stage, needed for interpolation
                or tiling.
            rope_interp (bool): Whether to interpolate (or extrapolate) rope to the target input size; the source size
                is expected in ``rope_pt_size``.
            cls_token (bool): Whether a cls token is present.
            dropout (float): Dropout rate applied to both residual branches and after the MLP hidden layer.
            init_values (float | None): Layer-scale init value; None (or 0) disables layer scale.
        """
        super().__init__()

        check_requirements("timm")
        from timm.layers import DropPath, Mlp

        # Windowed blocks attend inside window_size x window_size tiles, so their positional tables are per-window.
        attn_grid = input_size if window_size == 0 else (window_size, window_size)
        layer_scale_on = bool(init_values)

        # NOTE: submodules are created in a fixed order; keep it, since construction order drives RNG consumption
        # and module-traversal order for weight init.
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            use_rel_pos=use_rel_pos,
            rel_pos_zero_init=rel_pos_zero_init,
            input_size=attn_grid,
            use_rope=use_rope,
            rope_pt_size=rope_pt_size,
            rope_interp=rope_interp,
            cls_token=cls_token,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if layer_scale_on else nn.Identity()
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        hidden = int(dim * mlp_ratio)
        # timm Mlp takes per-layer dropout probabilities: here dropout after the hidden layer only.
        self.mlp = Mlp(in_features=dim, hidden_features=hidden, act_layer=act_layer, drop=(dropout, 0.0))
        self.ls2 = LayerScale(dim, init_values=init_values) if layer_scale_on else nn.Identity()
        self.dropout = nn.Dropout(dropout)
        self.window_size = window_size


method ultralytics.models.sam.sam3.vitdet.Block.forward

def forward(self, x: Tensor) -> Tensor

Forward pass of the transformer block.

Args

NameTypeDescriptionDefault
xTensorrequired
Source code in ultralytics/models/sam/sam3/vitdet.pyView on GitHub
def forward(self, x: Tensor) -> Tensor:
    """Apply attention and MLP sub-layers, each wrapped in a residual connection.

    If ``self.window_size`` > 0, the normalized input is split into windows (shape read from dims 1 and 2)
    before attention and merged back afterwards.
    """
    residual = x
    h = self.norm1(x)
    windowed = self.window_size > 0
    if windowed:
        orig_hw = (h.shape[1], h.shape[2])
        h, pad_hw = window_partition(h, self.window_size)

    h = self.ls1(self.attn(h))
    if windowed:
        h = window_unpartition(h, self.window_size, pad_hw, orig_hw)

    x = residual + self.dropout(self.drop_path(h))
    mlp_branch = self.ls2(self.mlp(self.norm2(x)))
    return x + self.dropout(self.drop_path(mlp_branch))





class ultralytics.models.sam.sam3.vitdet.ViT

def __init__(
    self,
    img_size: int = 1024,
    patch_size: int = 16,
    in_chans: int = 3,
    embed_dim: int = 768,
    depth: int = 12,
    num_heads: int = 12,
    mlp_ratio: float = 4.0,
    qkv_bias: bool = True,
    drop_path_rate: float = 0.0,
    norm_layer: Callable[..., nn.Module] | str = "LayerNorm",
    act_layer: Callable[..., nn.Module] = nn.GELU,
    use_abs_pos: bool = True,
    tile_abs_pos: bool = True,
    rel_pos_blocks: tuple[int, ...] | bool = (2, 5, 8, 11),
    rel_pos_zero_init: bool = True,
    window_size: int = 14,
    global_att_blocks: tuple[int, ...] = (2, 5, 8, 11),
    use_rope: bool = False,
    rope_pt_size: int | None = None,
    use_interp_rope: bool = False,
    pretrain_img_size: int = 224,
    pretrain_use_cls_token: bool = True,
    retain_cls_token: bool = True,
    dropout: float = 0.0,
    return_interm_layers: bool = False,
    init_values: float | None = None,  # for layerscale
    ln_pre: bool = False,
    ln_post: bool = False,
    bias_patch_embed: bool = True,
    compile_mode: str | None = None,
    use_act_checkpoint: bool = True,
)

Bases: nn.Module

This module implements the Vision Transformer (ViT) backbone from :paper:`vitdet`: "Exploring Plain Vision Transformer Backbones for Object Detection", https://arxiv.org/abs/2203.16527.

img_size (int): Input image size. Only relevant for rel pos or rope. patch_size (int): Patch size. in_chans (int): Number of input image channels. embed_dim (int): Patch embedding dimension. depth (int): Depth of ViT. num_heads (int): Number of attention heads in each ViT block. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. qkv_bias (bool): If True, add a learnable bias to query, key, value. drop_path_rate (float): Stochastic depth rate. norm_layer (nn.Module or str): Normalization layer, or the name of a torch.nn layer. act_layer (nn.Module): Activation layer. use_abs_pos (bool): If True, use absolute positional embeddings. tile_abs_pos (bool): If True, tile absolute positional embeddings instead of interpolating them. rel_pos_blocks (tuple[int, ...] or bool): Indices of blocks that have rel pos embeddings, or a bool applying to all blocks. rel_pos_zero_init (bool): If True, zero-initialize the relative positional parameters. window_size (int): Window size for window attention blocks. global_att_blocks (tuple[int, ...]): Indexes of blocks using global attention (other blocks use window attention). use_rope (bool): Whether to use 2D rope (independent of rel_pos_blocks; both can be used together). rope_pt_size (int): Rope size from a previous training stage, needed for interpolation or tiling. use_interp_rope (bool): Whether to interpolate (or extrapolate) rope to match the target input size; the source size is expected in rope_pt_size. use_act_checkpoint (bool): If True, use activation checkpointing. pretrain_img_size (int): Input image size used for pretraining. pretrain_use_cls_token (bool): If True, the pretrained model uses a class token. retain_cls_token (bool): Whether the cls token should be retained. dropout (float): Dropout rate, applied in the attention and MLP residual branches and inside the MLP. return_interm_layers (bool): Whether to return intermediate layers (all global attention blocks). init_values (float or None): Layer-scale init value; None disables layer scale. ln_pre (bool): If True, apply layer norm before the transformer blocks. ln_post (bool): If True, apply layer norm after the transformer blocks.
bias_patch_embed (bool): Whether the patch-embedding convolution uses a bias. compile_mode (str or None): torch.compile mode for the forward pass; None disables compilation.

Args

NameTypeDescriptionDefault
img_sizeint1024
patch_sizeint16
in_chansint3
embed_dimint768
depthint12
num_headsint12
mlp_ratiofloat4.0
qkv_biasboolTrue
drop_path_ratefloat0.0
norm_layerCallable[..., nn.Module] | str"LayerNorm"
act_layerCallable[..., nn.Module]nn.GELU
use_abs_posboolTrue
tile_abs_posboolTrue
rel_pos_blockstuple[int, ...] | bool(2, 5, 8, 11)
rel_pos_zero_initboolTrue
window_sizeint14
global_att_blockstuple[int, ...](2, 5, 8, 11)
use_ropeboolFalse
rope_pt_sizeint | NoneNone
use_interp_ropeboolFalse
pretrain_img_sizeint224
pretrain_use_cls_tokenboolTrue
retain_cls_tokenboolTrue
dropoutfloat0.0
return_interm_layersboolFalse
init_valuesfloat | NoneNone
ln_preboolFalse
ln_postboolFalse
bias_patch_embedboolTrue
compile_modestr | NoneNone
use_act_checkpointboolTrue

Methods

NameDescription
_init_weightsInitialize the weights.
forwardVit forward path and get feature maps.
set_imgszSetup rel pos embeddings and rope freqs for a new input image size.
Source code in ultralytics/models/sam/sam3/vitdet.pyView on GitHub
class ViT(nn.Module):
    """This module implements Vision Transformer (ViT) backbone in :paper:`vitdet`. "Exploring Plain Vision Transformer
    Backbones for Object Detection", https://arxiv.org/abs/2203.16527.
    """

    def __init__(
        self,
        img_size: int = 1024,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        drop_path_rate: float = 0.0,
        norm_layer: Callable[..., nn.Module] | str = "LayerNorm",
        act_layer: Callable[..., nn.Module] = nn.GELU,
        use_abs_pos: bool = True,
        tile_abs_pos: bool = True,
        rel_pos_blocks: tuple[int, ...] | bool = (2, 5, 8, 11),
        rel_pos_zero_init: bool = True,
        window_size: int = 14,
        global_att_blocks: tuple[int, ...] = (2, 5, 8, 11),
        use_rope: bool = False,
        rope_pt_size: int | None = None,
        use_interp_rope: bool = False,
        pretrain_img_size: int = 224,
        pretrain_use_cls_token: bool = True,
        retain_cls_token: bool = True,
        dropout: float = 0.0,
        return_interm_layers: bool = False,
        init_values: float | None = None,  # for layerscale
        ln_pre: bool = False,
        ln_post: bool = False,
        bias_patch_embed: bool = True,
        compile_mode: str | None = None,
        use_act_checkpoint: bool = True,
    ):
        """
        Args:
            img_size (int): Input image size. Only relevant for rel pos or rope.
            patch_size (int): Patch size.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
            depth (int): Depth of ViT.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            drop_path_rate (float): Stochastic depth rate.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_abs_pos (bool): If True, use absolute positional embeddings.
            tile_abs_pos (bool): If True, tile absolute positional embeddings instead of interpolation.
            rel_pos_blocks (list): Blocks which have rel pos embeddings.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window attention blocks.
            global_att_blocks (list): Indexes for blocks using global attention (other blocks use window attention).
            use_rope (bool): whether to use rope 2d (indep of rel_pos_blocks, as it can be used together).
            rope_pt_size (int): size of rope in previous stage of training, needed for interpolation or tiling.
            use_interp_rope: whether to interpolate (or extrapolate) rope to match target input size, expected to
                specify source size as rope_pt_size.
            use_act_checkpoint (bool): If True, use activation checkpointing.
            pretrain_img_size (int): input image size for pretraining models.
            pretrain_use_cls_token (bool): If True, pretraining models use class token.
            retain_cls_token: whether cls_token should be retained.
            dropout (float): Dropout rate. Applied in residual blocks of attn, mlp and inside the mlp.
            return_interm_layers (bool): Whether to return intermediate layers (all global attention blocks).
            init_values: layer scale init, None for no layer scale.
            ln_pre (bool): If True, apply layer norm before transformer blocks.
            ln_post (bool): If True, apply layer norm after transformer blocks.
            bias_patch_embed (bool): bias in conv for patch embed?
            compile_mode (str): mode to compile the forward.
        """
        super().__init__()
        self.pretrain_use_cls_token = pretrain_use_cls_token

        window_block_indexes = [i for i in range(depth) if i not in global_att_blocks]
        self.full_attn_ids = list(global_att_blocks)
        self.rel_pos_blocks = [False] * depth
        if isinstance(rel_pos_blocks, bool) and rel_pos_blocks:
            self.rel_pos_blocks = [True] * depth
        else:
            for i in rel_pos_blocks:
                self.rel_pos_blocks[i] = True

        self.retain_cls_token = retain_cls_token
        if self.retain_cls_token:
            assert pretrain_use_cls_token
            assert len(window_block_indexes) == 0, "windowing not supported with cls token"

            assert sum(self.rel_pos_blocks) == 0, "rel pos not supported with cls token"

            scale = embed_dim**-0.5
            self.class_embedding = nn.Parameter(scale * torch.randn(1, 1, embed_dim))

        if isinstance(norm_layer, str):
            norm_layer = partial(getattr(nn, norm_layer), eps=1e-5)

        self.patch_embed = PatchEmbed(
            kernel_size=(patch_size, patch_size),
            stride=(patch_size, patch_size),
            in_chans=in_chans,
            embed_dim=embed_dim,
            bias=bias_patch_embed,
        )

        # Handle absolute positional embedding
        self.tile_abs_pos = tile_abs_pos
        self.use_abs_pos = use_abs_pos
        if self.tile_abs_pos:
            assert self.use_abs_pos

        if self.use_abs_pos:
            # Initialize absolute positional embedding with pretrain image size.
            num_patches = (pretrain_img_size // patch_size) * (pretrain_img_size // patch_size)
            num_positions = (num_patches + 1) if pretrain_use_cls_token else num_patches
            self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, embed_dim))
        else:
            self.pos_embed = None

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]

        self.patch_size = patch_size
        self.window_size = window_size
        self.blocks = nn.ModuleList()
        cur_stage = 1
        for i in range(depth):
            block = Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                use_rel_pos=self.rel_pos_blocks[i],
                rel_pos_zero_init=rel_pos_zero_init,
                window_size=window_size if i in window_block_indexes else 0,
                input_size=(img_size // patch_size, img_size // patch_size),
                use_rope=use_rope,
                rope_pt_size=((window_size, window_size) if rope_pt_size is None else (rope_pt_size, rope_pt_size)),
                rope_interp=use_interp_rope,
                cls_token=self.retain_cls_token,
                dropout=dropout,
                init_values=init_values,
            )

            if i not in window_block_indexes:
                cur_stage += 1

            self.use_act_checkpoint = use_act_checkpoint

            self.blocks.append(block)

        self.return_interm_layers = return_interm_layers
        self.channel_list = [embed_dim] * len(self.full_attn_ids) if return_interm_layers else [embed_dim]

        if self.pos_embed is not None:
            nn.init.trunc_normal_(self.pos_embed, std=0.02)

        self.ln_pre = norm_layer(embed_dim) if ln_pre else nn.Identity()
        self.ln_post = norm_layer(embed_dim) if ln_post else nn.Identity()

        self.apply(self._init_weights)

        if compile_mode is not None:
            self.forward = torch.compile(self.forward, mode=compile_mode, fullgraph=True)
            if self.use_act_checkpoint and self.training:
                torch._dynamo.config.optimize_ddp = False


method ultralytics.models.sam.sam3.vitdet.ViT._init_weights

def _init_weights(self, m: nn.Module) -> None

Initialize the weights.

Args

NameTypeDescriptionDefault
mnn.Modulerequired
Source code in ultralytics/models/sam/sam3/vitdet.pyView on GitHub
def _init_weights(self, m: nn.Module) -> None:
    """Initialize the weights."""
    if isinstance(m, nn.Linear):
        nn.init.trunc_normal_(m.weight, std=0.02)
        if isinstance(m, nn.Linear) and m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        nn.init.constant_(m.bias, 0)
        nn.init.constant_(m.weight, 1.0)


method ultralytics.models.sam.sam3.vitdet.ViT.forward

def forward(self, x: torch.Tensor) -> list[torch.Tensor]

Vit forward path and get feature maps.

Args

NameTypeDescriptionDefault
xtorch.Tensorrequired
Source code in ultralytics/models/sam/sam3/vitdet.pyView on GitHub
def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
    """Run the ViT backbone and collect feature maps.

    Args:
        x (torch.Tensor): Input passed directly to ``self.patch_embed``.

    Returns:
        (list[torch.Tensor]): Channel-first feature maps of shape (B, C, h, w). When ``return_interm_layers`` is
            True there is one map per global-attention block; otherwise only the output of the last
            global-attention block (after ``ln_post``) is returned.
    """
    x = self.patch_embed(x)
    # Patch grid, read from dims 1 and 2 of the channels-last (B, h, w, C) patch embedding.
    h, w = x.shape[1], x.shape[2]

    s = 0  # number of leading cls tokens to strip from the outputs
    if self.retain_cls_token:
        # If cls_token is retained, we don't maintain spatial shape: tokens become (B, 1 + h*w, C).
        # expand() broadcasts the (1, 1, C) embedding to the batch; torch.cat would fail for batch > 1 otherwise.
        cls = self.class_embedding.expand(x.shape[0], -1, -1)
        x = torch.cat([cls, x.flatten(1, 2)], dim=1)
        s = 1

    if self.pos_embed is not None:
        x = x + get_abs_pos(
            self.pos_embed,
            self.pretrain_use_cls_token,
            (h, w),
            self.retain_cls_token,
            tiling=self.tile_abs_pos,
        )

    x = self.ln_pre(x)

    outputs = []
    for i, blk in enumerate(self.blocks):
        if self.use_act_checkpoint and self.training:
            x = checkpoint.checkpoint(blk, x, use_reentrant=False)
        else:
            x = blk(x)

        is_last = i == self.full_attn_ids[-1]
        if is_last or (self.return_interm_layers and i in self.full_attn_ids):
            if is_last:
                x = self.ln_post(x)

            feats = x[:, s:]
            if feats.ndim == 4:
                feats = feats.permute(0, 3, 1, 2)
            else:
                assert feats.ndim == 3
                # Restore the (h, w) patch grid captured above; works for non-square grids and keeps integer sizes
                # (math.sqrt produced a float, which reshape rejects).
                feats = feats.reshape(feats.shape[0], h, w, feats.shape[-1]).permute(0, 3, 1, 2)

            outputs.append(feats)

    return outputs


method ultralytics.models.sam.sam3.vitdet.ViT.set_imgsz

def set_imgsz(self, imgsz: list[int] = [1008, 1008])

Setup rel pos embeddings and rope freqs for a new input image size.

Args

NameTypeDescriptionDefault
imgszlist[int][1008, 1008]
Source code in ultralytics/models/sam/sam3/vitdet.pyView on GitHub
def set_imgsz(self, imgsz: list[int] = [1008, 1008]):
    """Setup rel pos embeddings and rope freqs for a new input image size."""
    for block in self.blocks:
        if block.window_size != 0:
            continue
        block.attn._setup_rel_pos(input_size=(imgsz[0] // self.patch_size, imgsz[1] // self.patch_size))
        block.attn._setup_rope_freqs(input_size=(imgsz[0] // self.patch_size, imgsz[1] // self.patch_size))





📅 Created 0 days ago ✏️ Updated 0 days ago
Laughing-q