
Reference for ultralytics/models/sam/sam3/text_encoder_ve.py

Improvements

This page is sourced from https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/text_encoder_ve.py. Have an improvement or example to add? Open a Pull Request — thank you! 🙏


class ultralytics.models.sam.sam3.text_encoder_ve.ResidualAttentionBlock

def __init__(
    self,
    d_model: int,
    n_head: int,
    mlp_ratio: float = 4.0,
    ls_init_value: float | None = None,
    act_layer: Callable[[], nn.Module] = nn.GELU,
    norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
)

Bases: nn.Module

Transformer block with multi-head attention, layer normalization, and MLP feed-forward network.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `d_model` | `int` | | required |
| `n_head` | `int` | | required |
| `mlp_ratio` | `float` | | `4.0` |
| `ls_init_value` | `float \| None` | | `None` |
| `act_layer` | `Callable[[], nn.Module]` | | `nn.GELU` |
| `norm_layer` | `Callable[[int], nn.Module]` | | `nn.LayerNorm` |

Methods

| Name | Description |
| --- | --- |
| `attention` | Compute multi-head attention with optional cross-attention support and masking. |
| `forward` | Apply residual attention with layer normalization and MLP, supporting optional cross-attention. |

Source code in ultralytics/models/sam/sam3/text_encoder_ve.py (View on GitHub)
class ResidualAttentionBlock(nn.Module):
    """Transformer block with multi-head attention, layer normalization, and MLP feed-forward network."""

    def __init__(
        self,
        d_model: int,
        n_head: int,
        mlp_ratio: float = 4.0,
        ls_init_value: float | None = None,
        act_layer: Callable[[], nn.Module] = nn.GELU,
        norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
    ):
        """Initialize residual attention block with configurable dimensions and normalization."""
        super().__init__()
        # Attention
        self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=True)

        # LayerNorm, LayerScale
        self.ln_1 = norm_layer(d_model)
        self.ln_2 = norm_layer(d_model)

        self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
        self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()

        # MLP
        mlp_width = int(d_model * mlp_ratio)
        self.mlp = nn.Sequential(
            OrderedDict(
                [
                    ("c_fc", nn.Linear(d_model, mlp_width)),
                    ("gelu", act_layer()),
                    ("c_proj", nn.Linear(mlp_width, d_model)),
                ]
            )
        )


method ultralytics.models.sam.sam3.text_encoder_ve.ResidualAttentionBlock.attention

def attention(
    self, q_x: torch.Tensor, k_x: torch.Tensor = None, v_x: torch.Tensor = None, attn_mask: torch.Tensor = None
) -> torch.Tensor

Compute multi-head attention with optional cross-attention support and masking.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `q_x` | `torch.Tensor` | | required |
| `k_x` | `torch.Tensor` | | `None` |
| `v_x` | `torch.Tensor` | | `None` |
| `attn_mask` | `torch.Tensor` | | `None` |

Source code in ultralytics/models/sam/sam3/text_encoder_ve.py (View on GitHub)
def attention(
    self, q_x: torch.Tensor, k_x: torch.Tensor = None, v_x: torch.Tensor = None, attn_mask: torch.Tensor = None
) -> torch.Tensor:
    """Compute multi-head attention with optional cross-attention support and masking."""
    k_x = k_x if k_x is not None else q_x
    v_x = v_x if v_x is not None else q_x
    if attn_mask is not None:
        # Leave boolean masks as is
        if not attn_mask.dtype == torch.bool:
            attn_mask = attn_mask.to(q_x.dtype)

    return self.attn(q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask)[0]
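
A minimal sketch of the mask handling, assuming the `sam3` module is importable from the installed `ultralytics` package: an additive float mask is cast to the query dtype before the call to `nn.MultiheadAttention`, while a boolean mask (where `True` marks positions that may not be attended to) is passed through unchanged.

```python
import torch
from ultralytics.models.sam.sam3.text_encoder_ve import ResidualAttentionBlock

block = ResidualAttentionBlock(d_model=256, n_head=4).eval()
q = torch.randn(2, 10, 256)  # (batch, seq_len, d_model)

# Additive float mask: -inf above the diagonal blocks attention to future tokens;
# attention() casts it to q.dtype before handing it to nn.MultiheadAttention.
float_mask = torch.full((10, 10), float("-inf")).triu_(1)
out_float = block.attention(q_x=q, attn_mask=float_mask)

# Boolean mask: True marks positions that must not be attended to; left as is.
bool_mask = torch.triu(torch.ones(10, 10, dtype=torch.bool), diagonal=1)
out_bool = block.attention(q_x=q, attn_mask=bool_mask)
print(out_float.shape, out_bool.shape)  # torch.Size([2, 10, 256]) torch.Size([2, 10, 256])
```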


method ultralytics.models.sam.sam3.text_encoder_ve.ResidualAttentionBlock.forward

def forward(
    self, q_x: torch.Tensor, k_x: torch.Tensor = None, v_x: torch.Tensor = None, attn_mask: torch.Tensor = None
) -> torch.Tensor

Apply residual attention with layer normalization and MLP, supporting optional cross-attention.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `q_x` | `torch.Tensor` | | required |
| `k_x` | `torch.Tensor` | | `None` |
| `v_x` | `torch.Tensor` | | `None` |
| `attn_mask` | `torch.Tensor` | | `None` |

Source code in ultralytics/models/sam/sam3/text_encoder_ve.py (View on GitHub)
def forward(
    self, q_x: torch.Tensor, k_x: torch.Tensor = None, v_x: torch.Tensor = None, attn_mask: torch.Tensor = None
) -> torch.Tensor:
    """Apply residual attention with layer normalization and MLP, supporting optional cross-attention."""
    k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None
    v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None
    x = q_x + self.ls_1(self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask))
    x = x + self.ls_2(self.mlp(self.ln_2(x)))
    return x
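
A short usage sketch, assuming the `sam3` module is importable from the installed `ultralytics` package. With the default `ls_init_value=None` the LayerScale branches collapse to `nn.Identity`, and `forward` runs plain pre-norm self-attention since `k_x`/`v_x` are reset to `None` when no `ln_1_kv` layer is present.

```python
import torch
from ultralytics.models.sam.sam3.text_encoder_ve import ResidualAttentionBlock

block = ResidualAttentionBlock(d_model=512, n_head=8).eval()
x = torch.randn(2, 77, 512)  # (batch, seq_len, d_model); nn.MultiheadAttention is batch_first
with torch.no_grad():
    y = block(x)  # residual self-attention followed by the residual MLP
print(y.shape)  # torch.Size([2, 77, 512])
```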





class ultralytics.models.sam.sam3.text_encoder_ve.Transformer

def __init__(
    self,
    width: int,
    layers: int,
    heads: int,
    mlp_ratio: float = 4.0,
    ls_init_value: float | None = None,
    act_layer: Callable[[], nn.Module] = nn.GELU,
    norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
    compile_mode: str | None = None,
    use_act_checkpoint: bool = False,
)

Bases: nn.Module

Stack of residual attention blocks forming a transformer encoder with optional gradient checkpointing.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `width` | `int` | | required |
| `layers` | `int` | | required |
| `heads` | `int` | | required |
| `mlp_ratio` | `float` | | `4.0` |
| `ls_init_value` | `float \| None` | | `None` |
| `act_layer` | `Callable[[], nn.Module]` | | `nn.GELU` |
| `norm_layer` | `Callable[[int], nn.Module]` | | `nn.LayerNorm` |
| `compile_mode` | `str \| None` | | `None` |
| `use_act_checkpoint` | `bool` | | `False` |

Methods

| Name | Description |
| --- | --- |
| `forward` | Process input through all transformer blocks with optional gradient checkpointing during training. |

Source code in ultralytics/models/sam/sam3/text_encoder_ve.py (View on GitHub)
class Transformer(nn.Module):
    """Stack of residual attention blocks forming a transformer encoder with optional gradient checkpointing."""

    def __init__(
        self,
        width: int,
        layers: int,
        heads: int,
        mlp_ratio: float = 4.0,
        ls_init_value: float | None = None,
        act_layer: Callable[[], nn.Module] = nn.GELU,
        norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
        compile_mode: str | None = None,
        use_act_checkpoint: bool = False,
    ):
        """Initialize transformer with configurable depth, width, and optional compilation/checkpointing."""
        super().__init__()
        self.width = width
        self.layers = layers
        self.grad_checkpointing = use_act_checkpoint
        self.resblocks = nn.ModuleList(
            [
                ResidualAttentionBlock(
                    width,
                    heads,
                    mlp_ratio,
                    ls_init_value=ls_init_value,
                    act_layer=act_layer,
                    norm_layer=norm_layer,
                )
                for _ in range(layers)
            ]
        )

        if compile_mode is not None:
            self.forward = torch.compile(self.forward, mode=compile_mode, fullgraph=True)
            if self.grad_checkpointing:
                torch._dynamo.config.optimize_ddp = False


method ultralytics.models.sam.sam3.text_encoder_ve.Transformer.forward

def forward(self, x: torch.Tensor, attn_mask: torch.Tensor = None) -> torch.Tensor

Process input through all transformer blocks with optional gradient checkpointing during training.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `torch.Tensor` | | required |
| `attn_mask` | `torch.Tensor` | | `None` |

Source code in ultralytics/models/sam/sam3/text_encoder_ve.py (View on GitHub)
def forward(self, x: torch.Tensor, attn_mask: torch.Tensor = None) -> torch.Tensor:
    """Process input through all transformer blocks with optional gradient checkpointing during training."""
    for _, r in enumerate(self.resblocks):
        if self.grad_checkpointing and not torch.jit.is_scripting() and self.training:
            x = checkpoint(r, x, None, None, attn_mask, use_reentrant=False)
        else:
            x = r(x, attn_mask=attn_mask)
    return x
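
A minimal sketch of running the stack with an additive causal mask, again assuming the module import path is available; `compile_mode` and activation checkpointing are left at their defaults.

```python
import torch
from ultralytics.models.sam.sam3.text_encoder_ve import Transformer

encoder = Transformer(width=256, layers=2, heads=4).eval()
x = torch.randn(2, 16, 256)  # (batch, seq_len, width)
causal = torch.full((16, 16), float("-inf")).triu_(1)  # additive causal mask
with torch.no_grad():
    y = encoder(x, attn_mask=causal)
print(y.shape)  # torch.Size([2, 16, 256])
```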





class ultralytics.models.sam.sam3.text_encoder_ve.TextTransformer

def __init__(
    self,
    context_length: int = 77,
    vocab_size: int = 49408,
    width: int = 512,
    heads: int = 8,
    layers: int = 12,
    mlp_ratio: float = 4.0,
    ls_init_value: float | None = None,
    output_dim: int = 512,
    no_causal_mask: bool = False,
    pool_type: str = "none",  # no pooling
    proj_bias: bool = False,
    act_layer: Callable = nn.GELU,
    norm_layer: Callable = nn.LayerNorm,
    output_tokens: bool = False,
    use_ln_post: bool = True,
    compile_mode: str | None = None,
    use_act_checkpoint: bool = False,
)

Bases: nn.Module

Text transformer encoder with causal masking and flexible pooling strategies.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `context_length` | `int` | | `77` |
| `vocab_size` | `int` | | `49408` |
| `width` | `int` | | `512` |
| `heads` | `int` | | `8` |
| `layers` | `int` | | `12` |
| `mlp_ratio` | `float` | | `4.0` |
| `ls_init_value` | `float \| None` | | `None` |
| `output_dim` | `int` | | `512` |
| `no_causal_mask` | `bool` | | `False` |
| `pool_type` | `str` | | `"none"` |
| `proj_bias` | `bool` | | `False` |
| `act_layer` | `Callable` | | `nn.GELU` |
| `norm_layer` | `Callable` | | `nn.LayerNorm` |
| `output_tokens` | `bool` | | `False` |
| `use_ln_post` | `bool` | | `True` |
| `compile_mode` | `str \| None` | | `None` |
| `use_act_checkpoint` | `bool` | | `False` |

Methods

| Name | Description |
| --- | --- |
| `build_causal_mask` | Create a causal attention mask to prevent attention to future tokens. |
| `forward` | Forward pass through the text transformer, returning pooled output and optionally token embeddings. |

Source code in ultralytics/models/sam/sam3/text_encoder_ve.py (View on GitHub)
class TextTransformer(nn.Module):
    """Text transformer encoder with causal masking and flexible pooling strategies."""

    def __init__(
        self,
        context_length: int = 77,
        vocab_size: int = 49408,
        width: int = 512,
        heads: int = 8,
        layers: int = 12,
        mlp_ratio: float = 4.0,
        ls_init_value: float | None = None,
        output_dim: int = 512,
        no_causal_mask: bool = False,
        pool_type: str = "none",  # no pooling
        proj_bias: bool = False,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = nn.LayerNorm,
        output_tokens: bool = False,
        use_ln_post: bool = True,
        compile_mode: str | None = None,
        use_act_checkpoint: bool = False,
    ):
        """Initialize text transformer with embedding layers, transformer blocks, and pooling options."""
        super().__init__()
        assert pool_type in ("first", "last", "argmax", "none")
        self.output_tokens = output_tokens
        self.num_pos = self.context_length = context_length
        self.vocab_size = vocab_size
        self.width = width
        self.output_dim = output_dim
        self.heads = heads
        self.pool_type = pool_type

        self.token_embedding = nn.Embedding(self.vocab_size, width)
        self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width))
        self.transformer = Transformer(
            width=width,
            layers=layers,
            heads=heads,
            mlp_ratio=mlp_ratio,
            ls_init_value=ls_init_value,
            act_layer=act_layer,
            norm_layer=norm_layer,
            compile_mode=compile_mode,
            use_act_checkpoint=use_act_checkpoint,
        )
        self.ln_final = norm_layer(width) if use_ln_post else nn.Identity()
        if no_causal_mask:
            self.attn_mask = None
        else:
            self.register_buffer("attn_mask", self.build_causal_mask(), persistent=False)
        if proj_bias:
            self.text_projection = nn.Linear(width, output_dim)
        else:
            self.text_projection = nn.Parameter(torch.empty(width, output_dim))


method ultralytics.models.sam.sam3.text_encoder_ve.TextTransformer.build_causal_mask

def build_causal_mask(self) -> torch.Tensor

Create a causal attention mask to prevent attention to future tokens.

Source code in ultralytics/models/sam/sam3/text_encoder_ve.py (View on GitHub)
def build_causal_mask(self) -> torch.Tensor:
    """Create a causal attention mask to prevent attention to future tokens."""
    # lazily create causal attention mask, with full attention between the tokens
    # pytorch uses additive attention mask; fill with -inf
    mask = torch.empty(self.num_pos, self.num_pos)
    mask.fill_(float("-inf"))
    mask.triu_(1)  # zero out the lower diagonal
    return mask
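
For intuition, the same construction for a 4-token context yields an additive mask whose row i allows (zero) every position up to i and blocks (-inf) every future position:

```python
import torch

mask = torch.full((4, 4), float("-inf")).triu_(1)  # same recipe as build_causal_mask with num_pos=4
print(mask)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])
```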


method ultralytics.models.sam.sam3.text_encoder_ve.TextTransformer.forward

def forward(self, text: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]

Forward pass through the text transformer, returning pooled output and optionally token embeddings.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `text` | `torch.Tensor` | | required |

Source code in ultralytics/models/sam/sam3/text_encoder_ve.py (View on GitHub)
def forward(self, text: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """Forward pass through the text transformer, returning pooled output and optionally token embeddings."""
    seq_len = text.shape[1]
    x = self.token_embedding(text)  # [batch_size, n_ctx, d_model]

    attn_mask = self.attn_mask
    if attn_mask is not None:
        attn_mask = attn_mask[:seq_len, :seq_len]

    x = x + self.positional_embedding[:seq_len]
    x = self.transformer(x, attn_mask=attn_mask)

    x = self.ln_final(x)
    pooled, tokens = text_global_pool(x, text, pool_type=self.pool_type)
    if self.text_projection is not None:
        if isinstance(self.text_projection, nn.Linear):
            pooled = self.text_projection(pooled)
        else:
            pooled = pooled @ self.text_projection
    if self.output_tokens:
        return pooled, tokens
    return pooled
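
A small usage sketch, assuming the module import path is available. Note that `positional_embedding` and `text_projection` are created with `torch.empty` and are expected to be filled from a checkpoint, so this demo initializes them explicitly; `pool_type="argmax"` pools the feature at the highest token id (the EOT token under CLIP-style tokenization).

```python
import torch
from ultralytics.models.sam.sam3.text_encoder_ve import TextTransformer

model = TextTransformer(width=512, heads=8, layers=2, pool_type="argmax", output_tokens=True).eval()
torch.nn.init.normal_(model.positional_embedding, std=0.01)  # normally loaded from a checkpoint
torch.nn.init.normal_(model.text_projection, std=0.01)
tokens = torch.randint(1, model.vocab_size, (2, 77))  # stand-in for tokenized text
with torch.no_grad():
    pooled, token_feats = model(tokens)
print(pooled.shape, token_feats.shape)  # torch.Size([2, 512]) torch.Size([2, 77, 512])
```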





class ultralytics.models.sam.sam3.text_encoder_ve.VETextEncoder

def __init__(
    self,
    d_model: int,
    tokenizer: Callable,
    width: int = 1024,
    heads: int = 16,
    layers: int = 24,
    context_length: int = 32,
    vocab_size: int = 49408,
    use_ln_post: bool = True,
    compile_mode: str | None = None,
    use_act_checkpoint: bool = True,
)

Bases: nn.Module

Text encoder for Vision Encoder (VE) models, combining a text transformer and a linear resizer.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `d_model` | `int` | | required |
| `tokenizer` | `Callable` | | required |
| `width` | `int` | | `1024` |
| `heads` | `int` | | `16` |
| `layers` | `int` | | `24` |
| `context_length` | `int` | | `32` |
| `vocab_size` | `int` | | `49408` |
| `use_ln_post` | `bool` | | `True` |
| `compile_mode` | `str \| None` | | `None` |
| `use_act_checkpoint` | `bool` | | `True` |

Methods

| Name | Description |
| --- | --- |
| `forward` | Encode text input, either raw strings or pre-encoded tensors, and resize to match decoder dimensions. |

Source code in ultralytics/models/sam/sam3/text_encoder_ve.py (View on GitHub)
class VETextEncoder(nn.Module):
    """Text encoder for Vision Encoder (VE) models, combining a text transformer and a linear resizer."""

    def __init__(
        self,
        d_model: int,
        tokenizer: Callable,
        width: int = 1024,
        heads: int = 16,
        layers: int = 24,
        context_length: int = 32,
        vocab_size: int = 49408,
        use_ln_post: bool = True,
        compile_mode: str | None = None,
        use_act_checkpoint: bool = True,
    ):
        """Initialize VE text encoder with a text transformer and a linear resizer to match decoder dimensions."""
        super().__init__()
        self.context_length = context_length
        self.use_ln_post = use_ln_post
        self.tokenizer = tokenizer

        self.encoder = TextTransformer(
            context_length=self.context_length,
            vocab_size=vocab_size,
            width=width,
            heads=heads,
            layers=layers,
            # we want the tokens, not just the pooled output
            output_tokens=True,
            use_ln_post=use_ln_post,
            compile_mode=compile_mode,
            use_act_checkpoint=use_act_checkpoint,
        )
        self.resizer = nn.Linear(self.encoder.width, d_model)


method ultralytics.models.sam.sam3.text_encoder_ve.VETextEncoder.forward

def forward(
    self, text: list[str] | tuple[torch.Tensor, torch.Tensor, dict], input_boxes: list | None = None
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]

Encode text input, either raw strings or pre-encoded tensors, and resize to match decoder dimensions.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `text` | `list[str] \| tuple[torch.Tensor, torch.Tensor, dict]` | | required |
| `input_boxes` | `list \| None` | | `None` |

Source code in ultralytics/models/sam/sam3/text_encoder_ve.py (View on GitHub)
def forward(
    self, text: list[str] | tuple[torch.Tensor, torch.Tensor, dict], input_boxes: list | None = None
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Encode text input, either raw strings or pre-encoded tensors, and resize to match decoder dimensions."""
    if isinstance(text[0], str):
        # no use case for this
        assert input_boxes is None or len(input_boxes) == 0, "not supported"

        # Encode the text
        tokenized = self.tokenizer(text, context_length=self.context_length).to(
            self.resizer.weight.device
        )  # [b, seq_len]
        text_attention_mask = (tokenized != 0).bool()

        # manually embed the tokens
        inputs_embeds = self.encoder.token_embedding(tokenized)  # [b, seq_len, d=1024]
        _, text_memory = self.encoder(tokenized)  # [b, seq_len, d=1024]

        assert text_memory.shape[1] == inputs_embeds.shape[1]
        # Invert attention mask because its the opposite in pytorch transformer
        text_attention_mask = text_attention_mask.ne(1)
        # Transpose memory because pytorch's attention expects sequence first
        text_memory = text_memory.transpose(0, 1)
        # Resize the encoder hidden states to be of the same d_model as the decoder
        text_memory_resized = self.resizer(text_memory)
    else:
        # The text is already encoded, use as is.
        text_attention_mask, text_memory_resized, tokenized = text
        inputs_embeds = tokenized["inputs_embeds"]
        assert input_boxes is None or len(input_boxes) == 0, "Can't replace boxes in text if it's already encoded"

    # Note that the input_embeds are returned in pytorch's convention (sequence first)
    return (
        text_attention_mask,
        text_memory_resized,
        inputs_embeds.transpose(0, 1),
    )
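
A usage sketch for the raw-string path. The tokenizer below is a hypothetical stand-in (the real model presumably expects a CLIP-style tokenizer given the 49408-entry vocabulary, returning padded id tensors), the empty-initialized encoder parameters are filled manually since they would normally come from a checkpoint, and sizes are reduced from the 1024-wide, 24-layer defaults to keep the demo light.

```python
import torch
from ultralytics.models.sam.sam3.text_encoder_ve import VETextEncoder

def toy_tokenizer(texts, context_length=32):
    """Hypothetical stand-in tokenizer: character codes, truncated/padded to context_length (0 = pad)."""
    rows = [[min(ord(c), 49407) for c in t][:context_length] for t in texts]
    return torch.tensor([r + [0] * (context_length - len(r)) for r in rows], dtype=torch.long)

enc = VETextEncoder(d_model=256, tokenizer=toy_tokenizer, width=512, heads=8, layers=2).eval()
torch.nn.init.normal_(enc.encoder.positional_embedding, std=0.01)  # normally loaded from a checkpoint
torch.nn.init.normal_(enc.encoder.text_projection, std=0.01)

with torch.no_grad():
    attn_mask, text_memory, inputs_embeds = enc(["a photo of a cat", "a dog"])
print(attn_mask.shape, text_memory.shape, inputs_embeds.shape)
# torch.Size([2, 32]) torch.Size([32, 2, 256]) torch.Size([32, 2, 512])
```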





function ultralytics.models.sam.sam3.text_encoder_ve.text_global_pool

def text_global_pool(
    x: torch.Tensor, text: torch.Tensor = None, pool_type: str = "argmax"
) -> tuple[torch.Tensor, torch.Tensor]

Extract pooled representation and tokens from text embeddings using the specified pooling strategy (first/last/argmax/none).

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `torch.Tensor` | | required |
| `text` | `torch.Tensor` | | `None` |
| `pool_type` | `str` | | `"argmax"` |

Source code in ultralytics/models/sam/sam3/text_encoder_ve.py (View on GitHub)
def text_global_pool(
    x: torch.Tensor, text: torch.Tensor = None, pool_type: str = "argmax"
) -> tuple[torch.Tensor, torch.Tensor]:
    """Extract pooled representation and tokens from text embeddings using specified pooling strategy
    (first/last/argmax/none).
    """
    if pool_type == "first":
        pooled, tokens = x[:, 0], x[:, 1:]
    elif pool_type == "last":
        pooled, tokens = x[:, -1], x[:, :-1]
    elif pool_type == "argmax":
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        assert text is not None
        pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x
    else:
        pooled = tokens = x
    return pooled, tokens
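
A small illustration of the argmax strategy, which selects the feature at the position of the highest token id in each sequence (the EOT token under CLIP-style tokenization); it assumes the function is importable from the installed `ultralytics` package.

```python
import torch
from ultralytics.models.sam.sam3.text_encoder_ve import text_global_pool

x = torch.randn(2, 5, 8)                 # (batch, seq_len, dim) token features
text = torch.tensor([[1, 5, 9, 0, 0],    # highest id (9) sits at index 2 ...
                     [1, 7, 3, 9, 0]])   # ... and at index 3 in the second row
pooled, tokens = text_global_pool(x, text, pool_type="argmax")
assert torch.equal(pooled[0], x[0, 2]) and torch.equal(pooled[1], x[1, 3])
print(pooled.shape, tokens.shape)        # torch.Size([2, 8]) torch.Size([2, 5, 8])
```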




