
Reference for ultralytics/models/sam/sam3/decoder.py

Improvements

This page is sourced from https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/decoder.py. Have an improvement or example to add? Open a Pull Request — thank you! 🙏


class ultralytics.models.sam.sam3.decoder.TransformerDecoderLayer

def __init__(
    self,
    d_model: int,
    dim_feedforward: int,
    dropout: float,
    cross_attention: nn.Module,
    n_heads: int,
    use_text_cross_attention: bool = False,
)

Bases: nn.Module

TransformerDecoderLayer is made up of self-attn, cross-attn, and feedforward network (FFN).

Args

d_model (int): required
dim_feedforward (int): required
dropout (float): required
cross_attention (nn.Module): required
n_heads (int): required
use_text_cross_attention (bool): default False

Methods

_apply_self_attention: Apply self-attention with optional DAC splitting.
forward: Run self-attention, optional text cross-attention, image cross-attention, and the FFN over the queries.
forward_ffn: Feedforward network forward pass.
with_pos_embed: Add positional embedding to the tensor.
Source code in ultralytics/models/sam/sam3/decoder.py
class TransformerDecoderLayer(nn.Module):
    """TransformerDecoderLayer is made up of self-attn, cross-attn, and feedforward network (FFN)."""

    def __init__(
        self,
        d_model: int,
        dim_feedforward: int,
        dropout: float,
        cross_attention: nn.Module,
        n_heads: int,
        use_text_cross_attention: bool = False,
    ):
        """Initialize the TransformerDecoderLayer."""
        super().__init__()
        # cross attention
        self.cross_attn = cross_attention
        self.dropout1 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.norm1 = nn.LayerNorm(d_model)

        # cross attention text
        self.use_text_cross_attention = use_text_cross_attention
        if use_text_cross_attention:
            self.ca_text = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
            self.catext_dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
            self.catext_norm = nn.LayerNorm(d_model)

        # self attention
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
        self.dropout2 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.norm2 = nn.LayerNorm(d_model)

        # ffn
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.activation = nn.ReLU()
        self.dropout3 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.dropout4 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.norm3 = nn.LayerNorm(d_model)


method ultralytics.models.sam.sam3.decoder.TransformerDecoderLayer._apply_self_attention

def _apply_self_attention(self, tgt, tgt_query_pos, dac, dac_use_selfatt_ln, presence_token, self_attn_mask)

Apply self-attention with optional DAC splitting.

Args

tgt: required
tgt_query_pos: required
dac: required
dac_use_selfatt_ln: required
presence_token: required
self_attn_mask: required
Source code in ultralytics/models/sam/sam3/decoder.py
def _apply_self_attention(self, tgt, tgt_query_pos, dac, dac_use_selfatt_ln, presence_token, self_attn_mask):
    """Apply self-attention with optional DAC splitting."""
    if self.self_attn is None:
        return tgt

    if dac:
        # Split queries for DAC (detect-and-classify)
        assert tgt.shape[0] % 2 == 0, "DAC requires even number of queries"
        num_o2o_queries = tgt.shape[0] // 2
        tgt_o2o = tgt[:num_o2o_queries]
        tgt_query_pos_o2o = tgt_query_pos[:num_o2o_queries]
        tgt_o2m = tgt[num_o2o_queries:]
    else:
        tgt_o2o = tgt
        tgt_query_pos_o2o = tgt_query_pos

    # Handle presence token
    if presence_token is not None:
        tgt_o2o = torch.cat([presence_token, tgt_o2o], dim=0)
        tgt_query_pos_o2o = torch.cat([torch.zeros_like(presence_token), tgt_query_pos_o2o], dim=0).to(
            tgt_o2o.dtype
        )
        tgt_query_pos = torch.cat([torch.zeros_like(presence_token), tgt_query_pos], dim=0)

    # Self-attention
    q = k = self.with_pos_embed(tgt_o2o, tgt_query_pos_o2o)
    tgt2 = self.self_attn(q, k, tgt_o2o, attn_mask=self_attn_mask)[0].to(tgt.dtype)
    tgt_o2o = tgt_o2o + self.dropout2(tgt2)

    # Recombine and normalize
    if dac:
        if not dac_use_selfatt_ln:
            tgt_o2o = self.norm2(tgt_o2o)
        tgt = torch.cat((tgt_o2o, tgt_o2m), dim=0)
        if dac_use_selfatt_ln:
            tgt = self.norm2(tgt)
    else:
        tgt = tgt_o2o
        tgt = self.norm2(tgt)

    return tgt, tgt_query_pos
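
Example

A minimal sketch of the DAC (detect-and-classify) split used above, with hypothetical shapes: the query set is doubled, the first half (one-to-one queries) goes through self-attention, and the second half (one-to-many queries) bypasses it and is concatenated back around the final LayerNorm (before or after it, depending on dac_use_selfatt_ln).

import torch

nq, bs, d_model = 200, 2, 256            # e.g. 2 * 100 queries after DAC tiling
tgt = torch.randn(nq, bs, d_model)

num_o2o = nq // 2                        # DAC requires an even number of queries
tgt_o2o, tgt_o2m = tgt[:num_o2o], tgt[num_o2o:]
print(tgt_o2o.shape, tgt_o2m.shape)      # torch.Size([100, 2, 256]) for both halves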


method ultralytics.models.sam.sam3.decoder.TransformerDecoderLayer.forward

def forward(
    self,
    # for tgt
    tgt: torch.Tensor,  # nq, bs, d_model
    tgt_query_pos: torch.Tensor = None,  # pos for query. MLP(Sine(pos))
    memory_text: torch.Tensor = None,  # num_token, bs, d_model
    text_attention_mask: torch.Tensor = None,  # bs, num_token
    # for memory
    memory: torch.Tensor = None,  # hw, bs, d_model
    memory_key_padding_mask: torch.Tensor = None,
    memory_pos: torch.Tensor = None,  # pos for memory
    # sa
    self_attn_mask: torch.Tensor = None,  # mask used for self-attention
    cross_attn_mask: torch.Tensor = None,  # mask used for cross-attention
    # dac
    dac=False,
    dac_use_selfatt_ln=True,
    presence_token=None,
    # skip inside deformable attn
    **kwargs,  # additional kwargs for compatibility
)

Run the decoder layer on the queries: self-attention (with optional DAC splitting and presence token), optional cross-attention to text, cross-attention to the image memory, and the FFN. Expected shapes: tgt and tgt_query_pos are (nq, bs, d_model), memory is (hw, bs, d_model), and memory_text is (num_token, bs, d_model). Returns the updated queries and, if a presence token was passed, its updated embedding.

Args

tgt (torch.Tensor): required
tgt_query_pos (torch.Tensor): default None
memory_text (torch.Tensor): default None
text_attention_mask (torch.Tensor): default None
memory (torch.Tensor): default None
memory_key_padding_mask (torch.Tensor): default None
memory_pos (torch.Tensor): default None
self_attn_mask (torch.Tensor): default None
cross_attn_mask (torch.Tensor): default None
dac: default False
dac_use_selfatt_ln: default True
presence_token: default None
**kwargs: required
Source code in ultralytics/models/sam/sam3/decoder.py
def forward(
    self,
    # for tgt
    tgt: torch.Tensor,  # nq, bs, d_model
    tgt_query_pos: torch.Tensor = None,  # pos for query. MLP(Sine(pos))
    memory_text: torch.Tensor = None,  # num_token, bs, d_model
    text_attention_mask: torch.Tensor = None,  # bs, num_token
    # for memory
    memory: torch.Tensor = None,  # hw, bs, d_model
    memory_key_padding_mask: torch.Tensor = None,
    memory_pos: torch.Tensor = None,  # pos for memory
    # sa
    self_attn_mask: torch.Tensor = None,  # mask used for self-attention
    cross_attn_mask: torch.Tensor = None,  # mask used for cross-attention
    # dac
    dac=False,
    dac_use_selfatt_ln=True,
    presence_token=None,
    # skip inside deformable attn
    **kwargs,  # additional kwargs for compatibility
):
    """Input: - tgt/tgt_query_pos: nq, bs, d_model. -."""
    # self attention
    tgt, tgt_query_pos = self._apply_self_attention(
        tgt, tgt_query_pos, dac, dac_use_selfatt_ln, presence_token, self_attn_mask
    )

    if self.use_text_cross_attention:
        tgt2 = self.ca_text(
            self.with_pos_embed(tgt, tgt_query_pos),
            memory_text.to(tgt.dtype),
            memory_text.to(tgt.dtype),
            key_padding_mask=text_attention_mask,
        )[0]
        tgt = tgt + self.catext_dropout(tgt2)
        tgt = self.catext_norm(tgt)

    if presence_token is not None:
        presence_token_mask = torch.zeros_like(cross_attn_mask[:, :1, :])
        cross_attn_mask = torch.cat([presence_token_mask, cross_attn_mask], dim=1)  # (bs*nheads, 1+nq, hw)

    # Cross attention to image
    tgt2 = self.cross_attn(
        query=self.with_pos_embed(tgt, tgt_query_pos),
        key=self.with_pos_embed(memory, memory_pos),
        value=memory,
        attn_mask=cross_attn_mask,
        key_padding_mask=(memory_key_padding_mask.transpose(0, 1) if memory_key_padding_mask is not None else None),
        need_weights=False,
    )[0]

    tgt = tgt + self.dropout1(tgt2)
    tgt = self.norm1(tgt)

    # ffn
    tgt = self.forward_ffn(tgt.to(memory.dtype))

    presence_token_out = None
    if presence_token is not None:
        presence_token_out = tgt[:1]
        tgt = tgt[1:]

    return tgt, presence_token_out
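
Example

A minimal, hedged usage sketch for the layer. It assumes a plain nn.MultiheadAttention as the cross_attention module (the actual model config may supply a different attention implementation) and random tensors with the shapes documented in the signature comments.

import torch
import torch.nn as nn

from ultralytics.models.sam.sam3.decoder import TransformerDecoderLayer

d_model, n_heads, nq, hw, bs = 256, 8, 100, 64 * 64, 2

layer = TransformerDecoderLayer(
    d_model=d_model,
    dim_feedforward=2048,
    dropout=0.1,
    cross_attention=nn.MultiheadAttention(d_model, n_heads, dropout=0.1),  # assumed stand-in
    n_heads=n_heads,
).eval()

tgt = torch.randn(nq, bs, d_model)         # queries
query_pos = torch.randn(nq, bs, d_model)   # positional embedding for the queries
memory = torch.randn(hw, bs, d_model)      # flattened image features
memory_pos = torch.randn(hw, bs, d_model)  # positional embedding for the memory

with torch.no_grad():
    out, presence = layer(tgt=tgt, tgt_query_pos=query_pos, memory=memory, memory_pos=memory_pos)
print(out.shape, presence)                 # torch.Size([100, 2, 256]) None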


method ultralytics.models.sam.sam3.decoder.TransformerDecoderLayer.forward_ffn

def forward_ffn(self, tgt)

Feedforward network forward pass.

Args

tgt: required
Source code in ultralytics/models/sam/sam3/decoder.py
def forward_ffn(self, tgt):
    """Feedforward network forward pass."""
    tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
    tgt = tgt + self.dropout4(tgt2)
    tgt = self.norm3(tgt)
    return tgt
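
Example

The block above is a standard post-norm residual FFN; a minimal equivalent sketch (dropout omitted for brevity):

import torch
import torch.nn as nn

d_model, dim_feedforward = 256, 2048
ffn = nn.Sequential(nn.Linear(d_model, dim_feedforward), nn.ReLU(), nn.Linear(dim_feedforward, d_model))
norm = nn.LayerNorm(d_model)

x = torch.randn(100, 2, d_model)
out = norm(x + ffn(x))   # residual connection followed by LayerNorm
print(out.shape)         # torch.Size([100, 2, 256])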


method ultralytics.models.sam.sam3.decoder.TransformerDecoderLayer.with_pos_embed

def with_pos_embed(tensor, pos)

Add positional embedding to the tensor.

Args

tensor: required
pos: required
Source code in ultralytics/models/sam/sam3/decoder.py
@staticmethod
def with_pos_embed(tensor, pos):
    """Add positional embedding to the tensor."""
    return tensor if pos is None else tensor + pos
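
Example

Because this is a static helper, it can be exercised directly on the class; the only behavior is element-wise addition, with None treated as "no positional embedding".

import torch

from ultralytics.models.sam.sam3.decoder import TransformerDecoderLayer

x = torch.randn(100, 2, 256)
pos = torch.randn(100, 2, 256)
assert torch.equal(TransformerDecoderLayer.with_pos_embed(x, pos), x + pos)
assert TransformerDecoderLayer.with_pos_embed(x, None) is x  # no-op when pos is None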





class ultralytics.models.sam.sam3.decoder.TransformerDecoder

def __init__(
    self,
    d_model: int,
    frozen: bool,
    interaction_layer,
    layer,
    num_layers: int,
    num_queries: int,
    return_intermediate: bool,
    box_refine: bool = False,
    num_o2m_queries: int = 0,
    dac: bool = False,
    boxRPB: str = "none",
    # Experimental: An object query for SAM 2 tasks
    instance_query: bool = False,
    # Defines the number of additional instance queries,
    # 1 or 4 are the most likely for single vs multi mask support
    num_instances: int = 1,  # Irrelevant if instance_query is False
    dac_use_selfatt_ln: bool = True,
    use_act_checkpoint: bool = False,
    compile_mode=None,
    presence_token: bool = False,
    clamp_presence_logits: bool = True,
    clamp_presence_logit_max_val: float = 10.0,
    use_normed_output_consistently: bool = True,
    separate_box_head_instance: bool = False,
    separate_norm_instance: bool = False,
)

Bases: nn.Module

Transformer Decoder consisting of multiple layers.

Args

d_model (int): required
frozen (bool): required
interaction_layer: required
layer: required
num_layers (int): required
num_queries (int): required
return_intermediate (bool): required
box_refine (bool): default False
num_o2m_queries (int): default 0
dac (bool): default False
boxRPB (str): default "none"
instance_query (bool): default False
num_instances (int): default 1
dac_use_selfatt_ln (bool): default True
use_act_checkpoint (bool): default False
compile_mode: default None
presence_token (bool): default False
clamp_presence_logits (bool): default True
clamp_presence_logit_max_val (float): default 10.0
use_normed_output_consistently (bool): default True
separate_box_head_instance (bool): default False
separate_norm_instance (bool): default False

Methods

_get_coords: Get normalized coordinates for height and width.
_get_rpb_matrix: Get the relative position bias (RPB) matrix for box-relative position bias.
forward: Forward pass of the TransformerDecoder.
Source code in ultralytics/models/sam/sam3/decoder.py
class TransformerDecoder(nn.Module):
    """Transformer Decoder consisting of multiple layers."""

    def __init__(
        self,
        d_model: int,
        frozen: bool,
        interaction_layer,
        layer,
        num_layers: int,
        num_queries: int,
        return_intermediate: bool,
        box_refine: bool = False,
        num_o2m_queries: int = 0,
        dac: bool = False,
        boxRPB: str = "none",
        # Experimental: An object query for SAM 2 tasks
        instance_query: bool = False,
        # Defines the number of additional instance queries,
        # 1 or 4 are the most likely for single vs multi mask support
        num_instances: int = 1,  # Irrelevant if instance_query is False
        dac_use_selfatt_ln: bool = True,
        use_act_checkpoint: bool = False,
        compile_mode=None,
        presence_token: bool = False,
        clamp_presence_logits: bool = True,
        clamp_presence_logit_max_val: float = 10.0,
        use_normed_output_consistently: bool = True,
        separate_box_head_instance: bool = False,
        separate_norm_instance: bool = False,
    ):
        """Initialize the TransformerDecoder."""
        super().__init__()
        self.d_model = d_model
        self.layers = _get_clones(layer, num_layers)
        self.fine_layers = (
            _get_clones(interaction_layer, num_layers) if interaction_layer is not None else [None] * num_layers
        )
        self.num_layers = num_layers
        self.num_queries = num_queries
        self.dac = dac
        if dac:
            self.num_o2m_queries = num_queries
            tot_num_queries = num_queries
        else:
            self.num_o2m_queries = num_o2m_queries
            tot_num_queries = num_queries + num_o2m_queries
        self.norm = nn.LayerNorm(d_model)
        self.return_intermediate = return_intermediate
        self.bbox_embed = MLP(d_model, d_model, 4, 3)
        self.query_embed = nn.Embedding(tot_num_queries, d_model)
        self.instance_query_embed = None
        self.instance_query_reference_points = None
        self.use_instance_query = instance_query
        self.num_instances = num_instances
        self.use_normed_output_consistently = use_normed_output_consistently

        self.instance_norm = nn.LayerNorm(d_model) if separate_norm_instance else None
        self.instance_bbox_embed = None
        if separate_box_head_instance:
            self.instance_bbox_embed = MLP(d_model, d_model, 4, 3)
        if instance_query:
            self.instance_query_embed = nn.Embedding(num_instances, d_model)
        self.box_refine = box_refine
        if box_refine:
            nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0)
            nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0)

            self.reference_points = nn.Embedding(num_queries, 4)
            if instance_query:
                self.instance_reference_points = nn.Embedding(num_instances, 4)

        assert boxRPB in ["none", "log", "linear", "both"]
        self.boxRPB = boxRPB
        if boxRPB != "none":
            try:
                nheads = self.layers[0].cross_attn_image.num_heads
            except AttributeError:
                nheads = self.layers[0].cross_attn.num_heads

            n_input = 4 if boxRPB == "both" else 2
            self.boxRPB_embed_x = MLP(n_input, d_model, nheads, 2)
            self.boxRPB_embed_y = MLP(n_input, d_model, nheads, 2)
            self.compilable_cord_cache = None
            self.compilable_stored_size = None
            self.coord_cache = {}

        self.roi_pooler = (
            RoIAlign(output_size=7, spatial_scale=1, sampling_ratio=-1, aligned=True)
            if interaction_layer is not None
            else None
        )
        if frozen:
            for p in self.parameters():
                p.requires_grad_(False)

        self.presence_token = None
        self.clamp_presence_logits = clamp_presence_logits
        self.clamp_presence_logit_max_val = clamp_presence_logit_max_val
        if presence_token:
            self.presence_token = nn.Embedding(1, d_model)
            self.presence_token_head = MLP(d_model, d_model, 1, 3)
            self.presence_token_out_norm = nn.LayerNorm(d_model)

        self.ref_point_head = MLP(2 * self.d_model, self.d_model, self.d_model, 2)
        self.dac_use_selfatt_ln = dac_use_selfatt_ln
        self.use_act_checkpoint = use_act_checkpoint

        nn.init.normal_(self.query_embed.weight.data)
        if self.instance_query_embed is not None:
            nn.init.normal_(self.instance_query_embed.weight.data)

        assert self.roi_pooler is None
        assert self.return_intermediate, "support return_intermediate only"
        assert self.box_refine, "support box refine only"

        self.compile_mode = compile_mode
        self.compiled = False
        # We defer compilation till after the first forward, to first warm-up the boxRPB cache

        # assign layer index to each layer so that some layers can decide what to do
        # based on which layer index they are (e.g. cross attention to memory bank only
        # in selected layers)
        for layer_idx, layer in enumerate(self.layers):
            layer.layer_idx = layer_idx


method ultralytics.models.sam.sam3.decoder.TransformerDecoder._get_coords

def _get_coords(H, W, device, dtype)

Get normalized coordinates for height and width.

Args

H: required
W: required
device: required
dtype: required
Source code in ultralytics/models/sam/sam3/decoder.py
@staticmethod
def _get_coords(H, W, device, dtype):
    """Get normalized coordinates for height and width."""
    coords_h = torch.arange(0, H, dtype=dtype, device=device) / H
    coords_w = torch.arange(0, W, dtype=dtype, device=device) / W
    return coords_h, coords_w
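
Example

A quick numeric check of the normalization: the coordinates are arange(H) / H and arange(W) / W (a sketch using small H and W).

import torch

from ultralytics.models.sam.sam3.decoder import TransformerDecoder

coords_h, coords_w = TransformerDecoder._get_coords(4, 8, torch.device("cpu"), torch.float32)
print(coords_h)  # tensor([0.0000, 0.2500, 0.5000, 0.7500])
print(coords_w)  # tensor([0.0000, 0.1250, 0.2500, 0.3750, 0.5000, 0.6250, 0.7500, 0.8750])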


method ultralytics.models.sam.sam3.decoder.TransformerDecoder._get_rpb_matrix

def _get_rpb_matrix(self, reference_boxes, feat_size)

Get the relative position bias (RPB) matrix for box-relative position bias.

Args

reference_boxes: required
feat_size: required
Source code in ultralytics/models/sam/sam3/decoder.py
def _get_rpb_matrix(self, reference_boxes, feat_size):
    """Get the relative position bias (RPB) matrix for box-relative position bias."""
    H, W = feat_size
    boxes_xyxy = xywh2xyxy(reference_boxes).transpose(0, 1)
    bs, num_queries, _ = boxes_xyxy.shape
    if self.compilable_cord_cache is None:
        self.compilable_cord_cache = self._get_coords(H, W, reference_boxes.device, reference_boxes.dtype)
        self.compilable_stored_size = (H, W)

    if torch.compiler.is_dynamo_compiling() or self.compilable_stored_size == (
        H,
        W,
    ):
        # good, hitting the cache, will be compilable
        coords_h, coords_w = self.compilable_cord_cache
    else:
        # cache miss, will create compilation issue
        # In case we're not compiling, we'll still rely on the dict-based cache
        if feat_size not in self.coord_cache:
            self.coord_cache[feat_size] = self._get_coords(H, W, reference_boxes.device, reference_boxes.dtype)
        coords_h, coords_w = self.coord_cache[feat_size]

        assert coords_h.shape == (H,)
        assert coords_w.shape == (W,)

    deltas_y = coords_h.view(1, -1, 1) - boxes_xyxy.reshape(-1, 1, 4)[:, :, 1:4:2]
    deltas_y = deltas_y.view(bs, num_queries, -1, 2)
    deltas_x = coords_w.view(1, -1, 1) - boxes_xyxy.reshape(-1, 1, 4)[:, :, 0:3:2]
    deltas_x = deltas_x.view(bs, num_queries, -1, 2)

    if self.boxRPB in ["log", "both"]:
        deltas_x_log = deltas_x * 8  # normalize to -8, 8
        deltas_x_log = torch.sign(deltas_x_log) * torch.log2(torch.abs(deltas_x_log) + 1.0) / np.log2(8)

        deltas_y_log = deltas_y * 8  # normalize to -8, 8
        deltas_y_log = torch.sign(deltas_y_log) * torch.log2(torch.abs(deltas_y_log) + 1.0) / np.log2(8)
        if self.boxRPB == "log":
            deltas_x = deltas_x_log
            deltas_y = deltas_y_log
        else:
            deltas_x = torch.cat([deltas_x, deltas_x_log], dim=-1)
            deltas_y = torch.cat([deltas_y, deltas_y_log], dim=-1)

    if self.training:
        assert self.use_act_checkpoint, "activation ckpt not enabled in decoder"
    deltas_x = self.boxRPB_embed_x(x=deltas_x)  # bs, num_queries, W, n_heads
    deltas_y = self.boxRPB_embed_y(x=deltas_y)  # bs, num_queries, H, n_heads

    if not torch.compiler.is_dynamo_compiling():
        assert deltas_x.shape[:3] == (bs, num_queries, W)
        assert deltas_y.shape[:3] == (bs, num_queries, H)

    B = deltas_y.unsqueeze(3) + deltas_x.unsqueeze(2)  # bs, num_queries, H, W, n_heads
    if not torch.compiler.is_dynamo_compiling():
        assert B.shape[:4] == (bs, num_queries, H, W)
    B = B.flatten(2, 3)  # bs, num_queries, H*W, n_heads
    B = B.permute(0, 3, 1, 2)  # bs, n_heads, num_queries, H*W
    B = B.contiguous()  # memeff attn likes ordered strides
    if not torch.compiler.is_dynamo_compiling():
        assert B.shape[2:] == (num_queries, H * W)
    return B
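
Example

The core ingredient of the bias is the signed distance from every feature row/column to the two box edges on that axis; these deltas are fed through the boxRPB MLPs and broadcast into a (bs, n_heads, num_queries, H*W) attention bias. A stripped-down sketch of the x-axis deltas for a single hypothetical box (no MLP, no batching):

import torch

W = 4
coords_w = torch.arange(W, dtype=torch.float32) / W     # normalized x positions: 0.00, 0.25, 0.50, 0.75
box_x1x2 = torch.tensor([0.25, 0.75])                   # one box's left/right edges in xyxy format

deltas_x = coords_w.view(-1, 1) - box_x1x2.view(1, -1)  # (W, 2): distance of each column to both edges
print(deltas_x)
# tensor([[-0.2500, -0.7500],
#         [ 0.0000, -0.5000],
#         [ 0.2500, -0.2500],
#         [ 0.5000,  0.0000]])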


method ultralytics.models.sam.sam3.decoder.TransformerDecoder.forward

def forward(
    self,
    tgt,
    memory,
    tgt_mask: torch.Tensor = None,
    memory_mask: torch.Tensor = None,
    memory_key_padding_mask: torch.Tensor = None,
    pos: torch.Tensor = None,
    reference_boxes: torch.Tensor = None,  # num_queries, bs, 4
    # for memory
    spatial_shapes: torch.Tensor = None,  # bs, num_levels, 2
    valid_ratios: torch.Tensor = None,
    # for text
    memory_text: torch.Tensor = None,
    text_attention_mask: torch.Tensor = None,
    # if `apply_dac` is None, it will default to `self.dac`
    apply_dac: bool | None = None,
    is_instance_prompt=False,
    decoder_extra_kwargs: dict | None = None,
    # ROI memory bank
    obj_roi_memory_feat=None,
    obj_roi_memory_mask=None,
    box_head_trk=None,
)

Forward pass of the TransformerDecoder.

Args

tgt: required
memory: required
tgt_mask (torch.Tensor): default None
memory_mask (torch.Tensor): default None
memory_key_padding_mask (torch.Tensor): default None
pos (torch.Tensor): default None
reference_boxes (torch.Tensor): default None
spatial_shapes (torch.Tensor): default None
valid_ratios (torch.Tensor): default None
memory_text (torch.Tensor): default None
text_attention_mask (torch.Tensor): default None
apply_dac (bool | None): default None
is_instance_prompt: default False
decoder_extra_kwargs (dict | None): default None
obj_roi_memory_feat: default None
obj_roi_memory_mask: default None
box_head_trk: default None
Source code in ultralytics/models/sam/sam3/decoder.py
def forward(
    self,
    tgt,
    memory,
    tgt_mask: torch.Tensor = None,
    memory_mask: torch.Tensor = None,
    memory_key_padding_mask: torch.Tensor = None,
    pos: torch.Tensor = None,
    reference_boxes: torch.Tensor = None,  # num_queries, bs, 4
    # for memory
    spatial_shapes: torch.Tensor = None,  # bs, num_levels, 2
    valid_ratios: torch.Tensor = None,
    # for text
    memory_text: torch.Tensor = None,
    text_attention_mask: torch.Tensor = None,
    # if `apply_dac` is None, it will default to `self.dac`
    apply_dac: bool | None = None,
    is_instance_prompt=False,
    decoder_extra_kwargs: dict | None = None,
    # ROI memory bank
    obj_roi_memory_feat=None,
    obj_roi_memory_mask=None,
    box_head_trk=None,
):
    """Forward pass of the TransformerDecoder."""
    if memory_mask is not None:
        assert self.boxRPB == "none", (
            "inputting a memory_mask in the presence of boxRPB is unexpected/not implemented"
        )

    apply_dac = apply_dac if apply_dac is not None else self.dac
    if apply_dac:
        assert (tgt.shape[0] == self.num_queries) or (
            self.use_instance_query and (tgt.shape[0] == self.instance_query_embed.num_embeddings)
        )

        tgt = tgt.repeat(2, 1, 1)
        # note that we don't tile tgt_mask, since DAC doesn't
        # use self-attention in o2m queries
        if reference_boxes is not None:
            assert (reference_boxes.shape[0] == self.num_queries) or (
                self.use_instance_query and (reference_boxes.shape[0] == self.instance_query_embed.num_embeddings)
            )
            reference_boxes = reference_boxes.repeat(2, 1, 1)

    bs = tgt.shape[1]
    intermediate = []
    intermediate_presence_logits = []
    presence_feats = None

    if self.box_refine:
        if reference_boxes is None:
            # In this case, we're in a one-stage model, so we generate the reference boxes
            reference_boxes = self.reference_points.weight.unsqueeze(1)
            reference_boxes = reference_boxes.repeat(2, bs, 1) if apply_dac else reference_boxes.repeat(1, bs, 1)
            reference_boxes = reference_boxes.sigmoid()
        intermediate_ref_boxes = [reference_boxes]
    else:
        reference_boxes = None
        intermediate_ref_boxes = None

    output = tgt
    presence_out = None
    if self.presence_token is not None and is_instance_prompt is False:
        # expand to batch dim
        presence_out = self.presence_token.weight[None].expand(1, bs, -1)

    box_head = self.bbox_embed
    if is_instance_prompt and self.instance_bbox_embed is not None:
        box_head = self.instance_bbox_embed

    out_norm = self.norm
    if is_instance_prompt and self.instance_norm is not None:
        out_norm = self.instance_norm

    for layer_idx, layer in enumerate(self.layers):
        reference_points_input = (
            reference_boxes[:, :, None] * torch.cat([valid_ratios, valid_ratios], -1)[None, :]
        )  # nq, bs, nlevel, 4

        query_sine_embed = gen_sineembed_for_position(
            reference_points_input[:, :, 0, :], self.d_model
        )  # nq, bs, d_model*2

        # conditional query
        query_pos = self.ref_point_head(query_sine_embed)  # nq, bs, d_model

        if self.boxRPB != "none" and reference_boxes is not None:
            assert spatial_shapes.shape[0] == 1, "only single scale support implemented"
            memory_mask = self._get_rpb_matrix(
                reference_boxes,
                (spatial_shapes[0, 0], spatial_shapes[0, 1]),
            )
            memory_mask = memory_mask.flatten(0, 1)  # (bs*n_heads, nq, H*W)
        if self.training:
            assert self.use_act_checkpoint, "Activation checkpointing not enabled in the decoder"
        output, presence_out = layer(
            tgt=output,
            tgt_query_pos=query_pos,
            memory_text=memory_text,
            text_attention_mask=text_attention_mask,
            memory=memory,
            memory_key_padding_mask=memory_key_padding_mask,
            memory_pos=pos,
            self_attn_mask=tgt_mask,
            cross_attn_mask=memory_mask,
            dac=apply_dac,
            dac_use_selfatt_ln=self.dac_use_selfatt_ln,
            presence_token=presence_out,
            **(decoder_extra_kwargs or {}),
            # ROI memory bank
            obj_roi_memory_feat=obj_roi_memory_feat,
            obj_roi_memory_mask=obj_roi_memory_mask,
        )

        # iter update
        if self.box_refine:
            reference_before_sigmoid = inverse_sigmoid(reference_boxes)
            if box_head_trk is None:
                # delta_unsig = self.bbox_embed(output)
                if not self.use_normed_output_consistently:
                    delta_unsig = box_head(output)
                else:
                    delta_unsig = box_head(out_norm(output))
            else:
                # box_head_trk use a separate box head for tracking queries
                Q_det = decoder_extra_kwargs["Q_det"]
                assert output.size(0) >= Q_det
                delta_unsig_det = self.bbox_embed(output[:Q_det])
                delta_unsig_trk = box_head_trk(output[Q_det:])
                delta_unsig = torch.cat([delta_unsig_det, delta_unsig_trk], dim=0)
            outputs_unsig = delta_unsig + reference_before_sigmoid
            new_reference_points = outputs_unsig.sigmoid()

            reference_boxes = new_reference_points.detach()
            if layer_idx != self.num_layers - 1:
                intermediate_ref_boxes.append(new_reference_points)
        else:
            raise NotImplementedError("not implemented yet")

        intermediate.append(out_norm(output))
        if self.presence_token is not None and is_instance_prompt is False:
            # norm, mlp head
            intermediate_layer_presence_logits = self.presence_token_head(
                self.presence_token_out_norm(presence_out)
            ).squeeze(-1)

            # clamp to mitigate numerical issues
            if self.clamp_presence_logits:
                intermediate_layer_presence_logits = intermediate_layer_presence_logits.clamp(
                    min=-self.clamp_presence_logit_max_val,
                    max=self.clamp_presence_logit_max_val,
                )

            intermediate_presence_logits.append(intermediate_layer_presence_logits)
            presence_feats = presence_out.clone()

    if not self.compiled and self.compile_mode is not None:
        self.forward = torch.compile(self.forward, mode=self.compile_mode, fullgraph=True)
        self.compiled = True

    return (
        torch.stack(intermediate),
        torch.stack(intermediate_ref_boxes),
        (
            torch.stack(intermediate_presence_logits)
            if self.presence_token is not None and is_instance_prompt is False
            else None
        ),
        presence_feats,
    )
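
Example

A hedged end-to-end sketch of the decoder on a single-scale feature map. It assumes a plain nn.MultiheadAttention as the layer's cross-attention module, no interaction layer, and the default boxRPB="none"; note that the constructor asserts require return_intermediate=True and box_refine=True. Shapes follow the comments in the signature.

import torch
import torch.nn as nn

from ultralytics.models.sam.sam3.decoder import TransformerDecoder, TransformerDecoderLayer

d_model, n_heads, num_queries, bs = 256, 8, 100, 2
H = W = 64

layer = TransformerDecoderLayer(
    d_model=d_model,
    dim_feedforward=2048,
    dropout=0.0,
    cross_attention=nn.MultiheadAttention(d_model, n_heads),  # assumed stand-in
    n_heads=n_heads,
)
decoder = TransformerDecoder(
    d_model=d_model,
    frozen=False,
    interaction_layer=None,
    layer=layer,
    num_layers=2,
    num_queries=num_queries,
    return_intermediate=True,  # required by the constructor asserts
    box_refine=True,           # required by the constructor asserts
).eval()

tgt = decoder.query_embed.weight.unsqueeze(1).repeat(1, bs, 1)  # (num_queries, bs, d_model)
memory = torch.randn(H * W, bs, d_model)                        # flattened image features
spatial_shapes = torch.tensor([[H, W]])                         # single feature level (unused when boxRPB="none")
valid_ratios = torch.ones(bs, 1, 2)                             # no padding

with torch.no_grad():
    hs, ref_boxes, presence_logits, presence_feats = decoder(
        tgt=tgt,
        memory=memory,
        spatial_shapes=spatial_shapes,
        valid_ratios=valid_ratios,
    )
print(hs.shape, ref_boxes.shape)  # (num_layers, num_queries, bs, d_model) and (num_layers, num_queries, bs, 4)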




