Reference for ultralytics/models/sam/sam3/decoder.py
Improvements
This page is sourced from https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/decoder.py. Have an improvement or example to add? Open a Pull Request — thank you! 🙏
Summary
class ultralytics.models.sam.sam3.decoder.TransformerDecoderLayer
def __init__(
self,
d_model: int,
dim_feedforward: int,
dropout: float,
cross_attention: nn.Module,
n_heads: int,
use_text_cross_attention: bool = False,
)
Bases: nn.Module
TransformerDecoderLayer is made up of self-attn, cross-attn, and feedforward network (FFN).
Args
| Name | Type | Description | Default |
|---|---|---|---|
| d_model | int | Embedding dimension of the queries and memory. | required |
| dim_feedforward | int | Hidden dimension of the feedforward network. | required |
| dropout | float | Dropout probability used in the attention and FFN branches. | required |
| cross_attention | nn.Module | Module used for cross-attention from queries to the image memory. | required |
| n_heads | int | Number of attention heads for self-attention and text cross-attention. | required |
| use_text_cross_attention | bool | Whether to add a cross-attention block over text memory. | False |
Methods
| Name | Description |
|---|---|
| _apply_self_attention | Apply self-attention with optional DAC splitting. |
| forward | Apply self-attention, optional text cross-attention, image cross-attention, and the FFN. |
| forward_ffn | Feedforward network forward pass. |
| with_pos_embed | Add positional embedding to the tensor. |
Source code in ultralytics/models/sam/sam3/decoder.py
class TransformerDecoderLayer(nn.Module):
"""TransformerDecoderLayer is made up of self-attn, cross-attn, and feedforward network (FFN)."""
def __init__(
self,
d_model: int,
dim_feedforward: int,
dropout: float,
cross_attention: nn.Module,
n_heads: int,
use_text_cross_attention: bool = False,
):
"""Initialize the TransformerDecoderLayer."""
super().__init__()
# cross attention
self.cross_attn = cross_attention
self.dropout1 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
self.norm1 = nn.LayerNorm(d_model)
# cross attention text
self.use_text_cross_attention = use_text_cross_attention
if use_text_cross_attention:
self.ca_text = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
self.catext_dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
self.catext_norm = nn.LayerNorm(d_model)
# self attention
self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
self.dropout2 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
self.norm2 = nn.LayerNorm(d_model)
# ffn
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.activation = nn.ReLU()
self.dropout3 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.dropout4 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
self.norm3 = nn.LayerNorm(d_model)
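A minimal construction sketch (not from the repository): the layer only needs a cross-attention module injected, so a plain nn.MultiheadAttention is assumed here as a stand-in for the cross_attention module that SAM 3 normally wires in, and all sizes are illustrative rather than the model defaults.

```python
import torch.nn as nn

from ultralytics.models.sam.sam3.decoder import TransformerDecoderLayer

# Illustrative configuration; SAM 3 supplies its own cross_attention module and sizes.
layer = TransformerDecoderLayer(
    d_model=256,
    dim_feedforward=2048,
    dropout=0.1,
    cross_attention=nn.MultiheadAttention(256, 8),  # assumed stand-in for the model's cross-attention
    n_heads=8,
    use_text_cross_attention=True,  # also builds the ca_text / catext_norm branch
)
```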
method ultralytics.models.sam.sam3.decoder.TransformerDecoderLayer._apply_self_attention
def _apply_self_attention(self, tgt, tgt_query_pos, dac, dac_use_selfatt_ln, presence_token, self_attn_mask)
Apply self-attention with optional DAC splitting.
Args
| Name | Type | Description | Default |
|---|---|---|---|
| tgt | | Query tensor of shape (nq, bs, d_model). | required |
| tgt_query_pos | | Positional embedding for the queries. | required |
| dac | | Whether to split the queries for DAC (detect-and-classify). | required |
| dac_use_selfatt_ln | | Whether to apply the self-attention LayerNorm to the recombined DAC queries. | required |
| presence_token | | Optional presence token prepended to the one-to-one queries. | required |
| self_attn_mask | | Optional attention mask for self-attention. | required |
Source code in ultralytics/models/sam/sam3/decoder.py
def _apply_self_attention(self, tgt, tgt_query_pos, dac, dac_use_selfatt_ln, presence_token, self_attn_mask):
"""Apply self-attention with optional DAC splitting."""
if self.self_attn is None:
return tgt
if dac:
# Split queries for DAC (detect-and-classify)
assert tgt.shape[0] % 2 == 0, "DAC requires even number of queries"
num_o2o_queries = tgt.shape[0] // 2
tgt_o2o = tgt[:num_o2o_queries]
tgt_query_pos_o2o = tgt_query_pos[:num_o2o_queries]
tgt_o2m = tgt[num_o2o_queries:]
else:
tgt_o2o = tgt
tgt_query_pos_o2o = tgt_query_pos
# Handle presence token
if presence_token is not None:
tgt_o2o = torch.cat([presence_token, tgt_o2o], dim=0)
tgt_query_pos_o2o = torch.cat([torch.zeros_like(presence_token), tgt_query_pos_o2o], dim=0).to(
tgt_o2o.dtype
)
tgt_query_pos = torch.cat([torch.zeros_like(presence_token), tgt_query_pos], dim=0)
# Self-attention
q = k = self.with_pos_embed(tgt_o2o, tgt_query_pos_o2o)
tgt2 = self.self_attn(q, k, tgt_o2o, attn_mask=self_attn_mask)[0].to(tgt.dtype)
tgt_o2o = tgt_o2o + self.dropout2(tgt2)
# Recombine and normalize
if dac:
if not dac_use_selfatt_ln:
tgt_o2o = self.norm2(tgt_o2o)
tgt = torch.cat((tgt_o2o, tgt_o2m), dim=0)
if dac_use_selfatt_ln:
tgt = self.norm2(tgt)
else:
tgt = tgt_o2o
tgt = self.norm2(tgt)
return tgt, tgt_query_pos
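When dac is True, the stacked query tensor is split in half: only the first (one-to-one) half runs through self-attention, while the second (one-to-many) half bypasses it. A shape-only sketch with illustrative sizes:

```python
import torch

nq, bs, d_model = 200, 1, 256
tgt = torch.randn(2 * nq, bs, d_model)  # o2o queries stacked on top of o2m queries (even count required)
tgt_o2o, tgt_o2m = tgt[:nq], tgt[nq:]   # only tgt_o2o is fed to self-attention
# An optional presence token is prepended to tgt_o2o before attention and split off again afterwards.
```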
method ultralytics.models.sam.sam3.decoder.TransformerDecoderLayer.forward
def forward(
self,
# for tgt
tgt: torch.Tensor, # nq, bs, d_model
tgt_query_pos: torch.Tensor = None, # pos for query. MLP(Sine(pos))
memory_text: torch.Tensor = None, # num_token, bs, d_model
text_attention_mask: torch.Tensor = None, # bs, num_token
# for memory
memory: torch.Tensor = None, # hw, bs, d_model
memory_key_padding_mask: torch.Tensor = None,
memory_pos: torch.Tensor = None, # pos for memory
# sa
self_attn_mask: torch.Tensor = None, # mask used for self-attention
cross_attn_mask: torch.Tensor = None, # mask used for cross-attention
# dac
dac=False,
dac_use_selfatt_ln=True,
presence_token=None,
# skip inside deformable attn
**kwargs, # additional kwargs for compatibility
)
Forward pass of the decoder layer: self-attention, optional text cross-attention, cross-attention to the image memory, and the FFN. tgt/tgt_query_pos have shape (nq, bs, d_model); memory has shape (hw, bs, d_model).
Args
| Name | Type | Description | Default |
|---|---|---|---|
| tgt | torch.Tensor | Query tensor of shape (nq, bs, d_model). | required |
| tgt_query_pos | torch.Tensor | Positional embedding for the queries, MLP(Sine(pos)). | None |
| memory_text | torch.Tensor | Text memory of shape (num_token, bs, d_model). | None |
| text_attention_mask | torch.Tensor | Text key padding mask of shape (bs, num_token). | None |
| memory | torch.Tensor | Flattened image memory of shape (hw, bs, d_model). | None |
| memory_key_padding_mask | torch.Tensor | Key padding mask for the image memory. | None |
| memory_pos | torch.Tensor | Positional embedding for the image memory. | None |
| self_attn_mask | torch.Tensor | Attention mask used for self-attention. | None |
| cross_attn_mask | torch.Tensor | Attention mask used for cross-attention. | None |
| dac | | Whether to apply DAC query splitting. | False |
| dac_use_selfatt_ln | | Whether to apply the self-attention LayerNorm to the recombined DAC queries. | True |
| presence_token | | Optional presence token prepended to the queries. | None |
| **kwargs | | Additional keyword arguments kept for compatibility. | required |
Source code in ultralytics/models/sam/sam3/decoder.py
def forward(
self,
# for tgt
tgt: torch.Tensor, # nq, bs, d_model
tgt_query_pos: torch.Tensor = None, # pos for query. MLP(Sine(pos))
memory_text: torch.Tensor = None, # num_token, bs, d_model
text_attention_mask: torch.Tensor = None, # bs, num_token
# for memory
memory: torch.Tensor = None, # hw, bs, d_model
memory_key_padding_mask: torch.Tensor = None,
memory_pos: torch.Tensor = None, # pos for memory
# sa
self_attn_mask: torch.Tensor = None, # mask used for self-attention
cross_attn_mask: torch.Tensor = None, # mask used for cross-attention
# dac
dac=False,
dac_use_selfatt_ln=True,
presence_token=None,
# skip inside deformable attn
**kwargs, # additional kwargs for compatibility
):
"""Input: - tgt/tgt_query_pos: nq, bs, d_model. -."""
# self attention
tgt, tgt_query_pos = self._apply_self_attention(
tgt, tgt_query_pos, dac, dac_use_selfatt_ln, presence_token, self_attn_mask
)
if self.use_text_cross_attention:
tgt2 = self.ca_text(
self.with_pos_embed(tgt, tgt_query_pos),
memory_text.to(tgt.dtype),
memory_text.to(tgt.dtype),
key_padding_mask=text_attention_mask,
)[0]
tgt = tgt + self.catext_dropout(tgt2)
tgt = self.catext_norm(tgt)
if presence_token is not None:
presence_token_mask = torch.zeros_like(cross_attn_mask[:, :1, :])
cross_attn_mask = torch.cat([presence_token_mask, cross_attn_mask], dim=1) # (bs*nheads, 1+nq, hw)
# Cross attention to image
tgt2 = self.cross_attn(
query=self.with_pos_embed(tgt, tgt_query_pos),
key=self.with_pos_embed(memory, memory_pos),
value=memory,
attn_mask=cross_attn_mask,
key_padding_mask=(memory_key_padding_mask.transpose(0, 1) if memory_key_padding_mask is not None else None),
need_weights=False,
)[0]
tgt = tgt + self.dropout1(tgt2)
tgt = self.norm1(tgt)
# ffn
tgt = self.forward_ffn(tgt.to(memory.dtype))
presence_token_out = None
if presence_token is not None:
presence_token_out = tgt[:1]
tgt = tgt[1:]
return tgt, presence_token_out
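A minimal call sketch (shapes and values are illustrative, and a plain nn.MultiheadAttention is again assumed for cross_attention): with no text cross-attention, DAC, or presence token, forward reduces to self-attention, image cross-attention, and the FFN, and returns the updated queries plus None for the presence output.

```python
import torch
import torch.nn as nn

from ultralytics.models.sam.sam3.decoder import TransformerDecoderLayer

layer = TransformerDecoderLayer(
    d_model=256, dim_feedforward=1024, dropout=0.0,
    cross_attention=nn.MultiheadAttention(256, 8), n_heads=8,
)

tgt = torch.randn(100, 2, 256)         # (nq, bs, d_model) object queries
query_pos = torch.randn(100, 2, 256)   # positional embedding for the queries
memory = torch.randn(64 * 64, 2, 256)  # (hw, bs, d_model) flattened image features

out, presence = layer(tgt=tgt, tgt_query_pos=query_pos, memory=memory)
print(out.shape, presence)  # torch.Size([100, 2, 256]) None
```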
method ultralytics.models.sam.sam3.decoder.TransformerDecoderLayer.forward_ffn
def forward_ffn(self, tgt)
Feedforward network forward pass.
Args
| Name | Type | Description | Default |
|---|---|---|---|
| tgt | | Input tensor to the feedforward network. | required |
Source code in ultralytics/models/sam/sam3/decoder.py
def forward_ffn(self, tgt):
"""Feedforward network forward pass."""
tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
tgt = tgt + self.dropout4(tgt2)
tgt = self.norm3(tgt)
return tgt
method ultralytics.models.sam.sam3.decoder.TransformerDecoderLayer.with_pos_embed
def with_pos_embed(tensor, pos)
Add positional embedding to the tensor.
Args
| Name | Type | Description | Default |
|---|---|---|---|
| tensor | | Input tensor. | required |
| pos | | Positional embedding to add, or None. | required |
Source code in ultralytics/models/sam/sam3/decoder.py
@staticmethod
def with_pos_embed(tensor, pos):
"""Add positional embedding to the tensor."""
return tensor if pos is None else tensor + pos
class ultralytics.models.sam.sam3.decoder.TransformerDecoder
def __init__(
self,
d_model: int,
frozen: bool,
interaction_layer,
layer,
num_layers: int,
num_queries: int,
return_intermediate: bool,
box_refine: bool = False,
num_o2m_queries: int = 0,
dac: bool = False,
boxRPB: str = "none",
# Experimental: An object query for SAM 2 tasks
instance_query: bool = False,
# Defines the number of additional instance queries,
# 1 or 4 are the most likely for single vs multi mask support
num_instances: int = 1, # Irrelevant if instance_query is False
dac_use_selfatt_ln: bool = True,
use_act_checkpoint: bool = False,
compile_mode=None,
presence_token: bool = False,
clamp_presence_logits: bool = True,
clamp_presence_logit_max_val: float = 10.0,
use_normed_output_consistently: bool = True,
separate_box_head_instance: bool = False,
separate_norm_instance: bool = False,
)
Bases: nn.Module
Transformer Decoder consisting of multiple layers.
Args
| Name | Type | Description | Default |
|---|---|---|---|
| d_model | int | Embedding dimension of the decoder. | required |
| frozen | bool | Whether to freeze all decoder parameters. | required |
| interaction_layer | | Optional interaction layer cloned per decoder layer; enables the RoI pooler. | required |
| layer | | Decoder layer module cloned num_layers times. | required |
| num_layers | int | Number of decoder layers. | required |
| num_queries | int | Number of object queries. | required |
| return_intermediate | bool | Whether to return outputs from every layer (must be True). | required |
| box_refine | bool | Whether to iteratively refine reference boxes (must be True). | False |
| num_o2m_queries | int | Number of one-to-many queries (ignored when dac is True). | 0 |
| dac | bool | Whether to use DAC (detect-and-classify) query duplication. | False |
| boxRPB | str | Box-relative position bias mode: "none", "log", "linear", or "both". | "none" |
| instance_query | bool | Experimental: add an object query for SAM 2 tasks. | False |
| num_instances | int | Number of additional instance queries (irrelevant if instance_query is False). | 1 |
| dac_use_selfatt_ln | bool | Whether to apply the self-attention LayerNorm to recombined DAC queries. | True |
| use_act_checkpoint | bool | Whether activation checkpointing is enabled (asserted during training). | False |
| compile_mode | | Optional torch.compile mode applied after the first forward pass. | None |
| presence_token | bool | Whether to learn a presence token and presence prediction head. | False |
| clamp_presence_logits | bool | Whether to clamp presence logits for numerical stability. | True |
| clamp_presence_logit_max_val | float | Maximum absolute value for clamped presence logits. | 10.0 |
| use_normed_output_consistently | bool | Whether to apply the output norm before the box head. | True |
| separate_box_head_instance | bool | Whether to use a separate box head for instance prompts. | False |
| separate_norm_instance | bool | Whether to use a separate output norm for instance prompts. | False |
Methods
| Name | Description |
|---|---|
| _get_coords | Get normalized coordinates for height and width. |
| _get_rpb_matrix | Get the relative position bias (RPB) matrix for box-relative position bias. |
| forward | Forward pass of the TransformerDecoder. |
Source code in ultralytics/models/sam/sam3/decoder.py
class TransformerDecoder(nn.Module):
"""Transformer Decoder consisting of multiple layers."""
def __init__(
self,
d_model: int,
frozen: bool,
interaction_layer,
layer,
num_layers: int,
num_queries: int,
return_intermediate: bool,
box_refine: bool = False,
num_o2m_queries: int = 0,
dac: bool = False,
boxRPB: str = "none",
# Experimental: An object query for SAM 2 tasks
instance_query: bool = False,
# Defines the number of additional instance queries,
# 1 or 4 are the most likely for single vs multi mask support
num_instances: int = 1, # Irrelevant if instance_query is False
dac_use_selfatt_ln: bool = True,
use_act_checkpoint: bool = False,
compile_mode=None,
presence_token: bool = False,
clamp_presence_logits: bool = True,
clamp_presence_logit_max_val: float = 10.0,
use_normed_output_consistently: bool = True,
separate_box_head_instance: bool = False,
separate_norm_instance: bool = False,
):
"""Initialize the TransformerDecoder."""
super().__init__()
self.d_model = d_model
self.layers = _get_clones(layer, num_layers)
self.fine_layers = (
_get_clones(interaction_layer, num_layers) if interaction_layer is not None else [None] * num_layers
)
self.num_layers = num_layers
self.num_queries = num_queries
self.dac = dac
if dac:
self.num_o2m_queries = num_queries
tot_num_queries = num_queries
else:
self.num_o2m_queries = num_o2m_queries
tot_num_queries = num_queries + num_o2m_queries
self.norm = nn.LayerNorm(d_model)
self.return_intermediate = return_intermediate
self.bbox_embed = MLP(d_model, d_model, 4, 3)
self.query_embed = nn.Embedding(tot_num_queries, d_model)
self.instance_query_embed = None
self.instance_query_reference_points = None
self.use_instance_query = instance_query
self.num_instances = num_instances
self.use_normed_output_consistently = use_normed_output_consistently
self.instance_norm = nn.LayerNorm(d_model) if separate_norm_instance else None
self.instance_bbox_embed = None
if separate_box_head_instance:
self.instance_bbox_embed = MLP(d_model, d_model, 4, 3)
if instance_query:
self.instance_query_embed = nn.Embedding(num_instances, d_model)
self.box_refine = box_refine
if box_refine:
nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0)
nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0)
self.reference_points = nn.Embedding(num_queries, 4)
if instance_query:
self.instance_reference_points = nn.Embedding(num_instances, 4)
assert boxRPB in ["none", "log", "linear", "both"]
self.boxRPB = boxRPB
if boxRPB != "none":
try:
nheads = self.layers[0].cross_attn_image.num_heads
except AttributeError:
nheads = self.layers[0].cross_attn.num_heads
n_input = 4 if boxRPB == "both" else 2
self.boxRPB_embed_x = MLP(n_input, d_model, nheads, 2)
self.boxRPB_embed_y = MLP(n_input, d_model, nheads, 2)
self.compilable_cord_cache = None
self.compilable_stored_size = None
self.coord_cache = {}
self.roi_pooler = (
RoIAlign(output_size=7, spatial_scale=1, sampling_ratio=-1, aligned=True)
if interaction_layer is not None
else None
)
if frozen:
for p in self.parameters():
p.requires_grad_(False)
self.presence_token = None
self.clamp_presence_logits = clamp_presence_logits
self.clamp_presence_logit_max_val = clamp_presence_logit_max_val
if presence_token:
self.presence_token = nn.Embedding(1, d_model)
self.presence_token_head = MLP(d_model, d_model, 1, 3)
self.presence_token_out_norm = nn.LayerNorm(d_model)
self.ref_point_head = MLP(2 * self.d_model, self.d_model, self.d_model, 2)
self.dac_use_selfatt_ln = dac_use_selfatt_ln
self.use_act_checkpoint = use_act_checkpoint
nn.init.normal_(self.query_embed.weight.data)
if self.instance_query_embed is not None:
nn.init.normal_(self.instance_query_embed.weight.data)
assert self.roi_pooler is None
assert self.return_intermediate, "support return_intermediate only"
assert self.box_refine, "support box refine only"
self.compile_mode = compile_mode
self.compiled = False
# We defer compilation till after the first forward, to first warm-up the boxRPB cache
# assign layer index to each layer so that some layers can decide what to do
# based on which layer index they are (e.g. cross attention to memory bank only
# in selected layers)
for layer_idx, layer in enumerate(self.layers):
layer.layer_idx = layer_idx
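A minimal construction sketch, assuming TransformerDecoderLayer is used as the cloned layer and both classes are imported from this module; the values are illustrative, and the flags are chosen to satisfy the constructor's asserts (return_intermediate and box_refine must both be True, while interaction_layer=None keeps the RoI pooler disabled).

```python
import torch.nn as nn

from ultralytics.models.sam.sam3.decoder import TransformerDecoder, TransformerDecoderLayer

decoder_layer = TransformerDecoderLayer(
    d_model=256, dim_feedforward=2048, dropout=0.0,
    cross_attention=nn.MultiheadAttention(256, 8), n_heads=8,
)
decoder = TransformerDecoder(
    d_model=256,
    frozen=False,
    interaction_layer=None,    # keeps roi_pooler disabled, as asserted
    layer=decoder_layer,       # cloned num_layers times
    num_layers=6,
    num_queries=300,
    return_intermediate=True,  # asserted by the implementation
    box_refine=True,           # asserted by the implementation
)
```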
method ultralytics.models.sam.sam3.decoder.TransformerDecoder._get_coords
def _get_coords(H, W, device, dtype)
Get normalized coordinates for height and width.
Args
| Name | Type | Description | Default |
|---|---|---|---|
| H | | Feature map height. | required |
| W | | Feature map width. | required |
| device | | Device on which to create the coordinate tensors. | required |
| dtype | | Data type of the coordinate tensors. | required |
Source code in ultralytics/models/sam/sam3/decoder.py
@staticmethod
def _get_coords(H, W, device, dtype):
"""Get normalized coordinates for height and width."""
coords_h = torch.arange(0, H, dtype=dtype, device=device) / H
coords_w = torch.arange(0, W, dtype=dtype, device=device) / W
return coords_h, coords_w
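The coordinates are simply grid indices divided by the feature size, so they lie in [0, 1); a tiny illustration with assumed sizes:

```python
import torch

from ultralytics.models.sam.sam3.decoder import TransformerDecoder

coords_h, coords_w = TransformerDecoder._get_coords(4, 4, torch.device("cpu"), torch.float32)
print(coords_h)  # tensor([0.0000, 0.2500, 0.5000, 0.7500])
```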
method ultralytics.models.sam.sam3.decoder.TransformerDecoder._get_rpb_matrix
def _get_rpb_matrix(self, reference_boxes, feat_size)
Get the relative position bias (RPB) matrix for box-relative position bias.
Args
| Name | Type | Description | Default |
|---|---|---|---|
| reference_boxes | | Reference boxes in normalized xywh format, shape (num_queries, bs, 4). | required |
| feat_size | | Feature map size as (H, W). | required |
Source code in ultralytics/models/sam/sam3/decoder.py
def _get_rpb_matrix(self, reference_boxes, feat_size):
"""Get the relative position bias (RPB) matrix for box-relative position bias."""
H, W = feat_size
boxes_xyxy = xywh2xyxy(reference_boxes).transpose(0, 1)
bs, num_queries, _ = boxes_xyxy.shape
if self.compilable_cord_cache is None:
self.compilable_cord_cache = self._get_coords(H, W, reference_boxes.device, reference_boxes.dtype)
self.compilable_stored_size = (H, W)
if torch.compiler.is_dynamo_compiling() or self.compilable_stored_size == (
H,
W,
):
# good, hitting the cache, will be compilable
coords_h, coords_w = self.compilable_cord_cache
else:
# cache miss, will create compilation issue
# In case we're not compiling, we'll still rely on the dict-based cache
if feat_size not in self.coord_cache:
self.coord_cache[feat_size] = self._get_coords(H, W, reference_boxes.device, reference_boxes.dtype)
coords_h, coords_w = self.coord_cache[feat_size]
assert coords_h.shape == (H,)
assert coords_w.shape == (W,)
deltas_y = coords_h.view(1, -1, 1) - boxes_xyxy.reshape(-1, 1, 4)[:, :, 1:4:2]
deltas_y = deltas_y.view(bs, num_queries, -1, 2)
deltas_x = coords_w.view(1, -1, 1) - boxes_xyxy.reshape(-1, 1, 4)[:, :, 0:3:2]
deltas_x = deltas_x.view(bs, num_queries, -1, 2)
if self.boxRPB in ["log", "both"]:
deltas_x_log = deltas_x * 8 # normalize to -8, 8
deltas_x_log = torch.sign(deltas_x_log) * torch.log2(torch.abs(deltas_x_log) + 1.0) / np.log2(8)
deltas_y_log = deltas_y * 8 # normalize to -8, 8
deltas_y_log = torch.sign(deltas_y_log) * torch.log2(torch.abs(deltas_y_log) + 1.0) / np.log2(8)
if self.boxRPB == "log":
deltas_x = deltas_x_log
deltas_y = deltas_y_log
else:
deltas_x = torch.cat([deltas_x, deltas_x_log], dim=-1)
deltas_y = torch.cat([deltas_y, deltas_y_log], dim=-1)
if self.training:
assert self.use_act_checkpoint, "activation ckpt not enabled in decoder"
deltas_x = self.boxRPB_embed_x(x=deltas_x) # bs, num_queries, W, n_heads
deltas_y = self.boxRPB_embed_y(x=deltas_y) # bs, num_queries, H, n_heads
if not torch.compiler.is_dynamo_compiling():
assert deltas_x.shape[:3] == (bs, num_queries, W)
assert deltas_y.shape[:3] == (bs, num_queries, H)
B = deltas_y.unsqueeze(3) + deltas_x.unsqueeze(2) # bs, num_queries, H, W, n_heads
if not torch.compiler.is_dynamo_compiling():
assert B.shape[:4] == (bs, num_queries, H, W)
B = B.flatten(2, 3) # bs, num_queries, H*W, n_heads
B = B.permute(0, 3, 1, 2) # bs, n_heads, num_queries, H*W
B = B.contiguous() # memeff attn likes ordered strides
if not torch.compiler.is_dynamo_compiling():
assert B.shape[2:] == (num_queries, H * W)
return B
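With boxRPB set to "log" or "both", the box-to-pixel offsets are scaled by 8 and passed through a signed log2 curve, which compresses large offsets while keeping small ones roughly linear. A small standalone illustration of that transform (values are illustrative and independent of the decoder):

```python
import torch

# Normalized offsets between pixel coordinates and box edges.
delta = torch.tensor([-1.0, -0.25, 0.0, 0.25, 1.0])
scaled = torch.sign(delta * 8) * torch.log2(torch.abs(delta * 8) + 1.0) / torch.log2(torch.tensor(8.0))
print(scaled)  # approximately [-1.06, -0.53, 0.00, 0.53, 1.06]
```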
method ultralytics.models.sam.sam3.decoder.TransformerDecoder.forward
def forward(
self,
tgt,
memory,
tgt_mask: torch.Tensor = None,
memory_mask: torch.Tensor = None,
memory_key_padding_mask: torch.Tensor = None,
pos: torch.Tensor = None,
reference_boxes: torch.Tensor = None, # num_queries, bs, 4
# for memory
spatial_shapes: torch.Tensor = None, # bs, num_levels, 2
valid_ratios: torch.Tensor = None,
# for text
memory_text: torch.Tensor = None,
text_attention_mask: torch.Tensor = None,
# if `apply_dac` is None, it will default to `self.dac`
apply_dac: bool | None = None,
is_instance_prompt=False,
decoder_extra_kwargs: dict | None = None,
# ROI memory bank
obj_roi_memory_feat=None,
obj_roi_memory_mask=None,
box_head_trk=None,
)
Forward pass of the TransformerDecoder.
Args
| Name | Type | Description | Default |
|---|---|---|---|
| tgt | | Query tensor of shape (num_queries, bs, d_model). | required |
| memory | | Flattened image memory. | required |
| tgt_mask | torch.Tensor | Attention mask for self-attention among queries. | None |
| memory_mask | torch.Tensor | Attention mask for cross-attention to the image memory. | None |
| memory_key_padding_mask | torch.Tensor | Key padding mask for the image memory. | None |
| pos | torch.Tensor | Positional embedding for the image memory. | None |
| reference_boxes | torch.Tensor | Initial reference boxes of shape (num_queries, bs, 4). | None |
| spatial_shapes | torch.Tensor | Spatial shapes of the memory feature levels, shape (bs, num_levels, 2). | None |
| valid_ratios | torch.Tensor | Valid ratios of the feature levels. | None |
| memory_text | torch.Tensor | Text memory features. | None |
| text_attention_mask | torch.Tensor | Text key padding mask. | None |
| apply_dac | bool or None | Whether to apply DAC; defaults to self.dac when None. | None |
| is_instance_prompt | | Whether the queries come from an instance prompt (selects the instance box head/norm). | False |
| decoder_extra_kwargs | dict or None | Extra keyword arguments forwarded to each decoder layer. | None |
| obj_roi_memory_feat | | ROI memory bank features. | None |
| obj_roi_memory_mask | | ROI memory bank mask. | None |
| box_head_trk | | Optional separate box head applied to tracking queries. | None |
Source code in ultralytics/models/sam/sam3/decoder.py
def forward(
self,
tgt,
memory,
tgt_mask: torch.Tensor = None,
memory_mask: torch.Tensor = None,
memory_key_padding_mask: torch.Tensor = None,
pos: torch.Tensor = None,
reference_boxes: torch.Tensor = None, # num_queries, bs, 4
# for memory
spatial_shapes: torch.Tensor = None, # bs, num_levels, 2
valid_ratios: torch.Tensor = None,
# for text
memory_text: torch.Tensor = None,
text_attention_mask: torch.Tensor = None,
# if `apply_dac` is None, it will default to `self.dac`
apply_dac: bool | None = None,
is_instance_prompt=False,
decoder_extra_kwargs: dict | None = None,
# ROI memory bank
obj_roi_memory_feat=None,
obj_roi_memory_mask=None,
box_head_trk=None,
):
"""Forward pass of the TransformerDecoder."""
if memory_mask is not None:
assert self.boxRPB == "none", (
"inputting a memory_mask in the presence of boxRPB is unexpected/not implemented"
)
apply_dac = apply_dac if apply_dac is not None else self.dac
if apply_dac:
assert (tgt.shape[0] == self.num_queries) or (
self.use_instance_query and (tgt.shape[0] == self.instance_query_embed.num_embeddings)
)
tgt = tgt.repeat(2, 1, 1)
# note that we don't tile tgt_mask, since DAC doesn't
# use self-attention in o2m queries
if reference_boxes is not None:
assert (reference_boxes.shape[0] == self.num_queries) or (
self.use_instance_query and (reference_boxes.shape[0] == self.instance_query_embed.num_embeddings)
)
reference_boxes = reference_boxes.repeat(2, 1, 1)
bs = tgt.shape[1]
intermediate = []
intermediate_presence_logits = []
presence_feats = None
if self.box_refine:
if reference_boxes is None:
# In this case, we're in a one-stage model, so we generate the reference boxes
reference_boxes = self.reference_points.weight.unsqueeze(1)
reference_boxes = reference_boxes.repeat(2, bs, 1) if apply_dac else reference_boxes.repeat(1, bs, 1)
reference_boxes = reference_boxes.sigmoid()
intermediate_ref_boxes = [reference_boxes]
else:
reference_boxes = None
intermediate_ref_boxes = None
output = tgt
presence_out = None
if self.presence_token is not None and is_instance_prompt is False:
# expand to batch dim
presence_out = self.presence_token.weight[None].expand(1, bs, -1)
box_head = self.bbox_embed
if is_instance_prompt and self.instance_bbox_embed is not None:
box_head = self.instance_bbox_embed
out_norm = self.norm
if is_instance_prompt and self.instance_norm is not None:
out_norm = self.instance_norm
for layer_idx, layer in enumerate(self.layers):
reference_points_input = (
reference_boxes[:, :, None] * torch.cat([valid_ratios, valid_ratios], -1)[None, :]
) # nq, bs, nlevel, 4
query_sine_embed = gen_sineembed_for_position(
reference_points_input[:, :, 0, :], self.d_model
) # nq, bs, d_model*2
# conditional query
query_pos = self.ref_point_head(query_sine_embed) # nq, bs, d_model
if self.boxRPB != "none" and reference_boxes is not None:
assert spatial_shapes.shape[0] == 1, "only single scale support implemented"
memory_mask = self._get_rpb_matrix(
reference_boxes,
(spatial_shapes[0, 0], spatial_shapes[0, 1]),
)
memory_mask = memory_mask.flatten(0, 1) # (bs*n_heads, nq, H*W)
if self.training:
assert self.use_act_checkpoint, "Activation checkpointing not enabled in the decoder"
output, presence_out = layer(
tgt=output,
tgt_query_pos=query_pos,
memory_text=memory_text,
text_attention_mask=text_attention_mask,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
memory_pos=pos,
self_attn_mask=tgt_mask,
cross_attn_mask=memory_mask,
dac=apply_dac,
dac_use_selfatt_ln=self.dac_use_selfatt_ln,
presence_token=presence_out,
**(decoder_extra_kwargs or {}),
# ROI memory bank
obj_roi_memory_feat=obj_roi_memory_feat,
obj_roi_memory_mask=obj_roi_memory_mask,
)
# iter update
if self.box_refine:
reference_before_sigmoid = inverse_sigmoid(reference_boxes)
if box_head_trk is None:
# delta_unsig = self.bbox_embed(output)
if not self.use_normed_output_consistently:
delta_unsig = box_head(output)
else:
delta_unsig = box_head(out_norm(output))
else:
# box_head_trk use a separate box head for tracking queries
Q_det = decoder_extra_kwargs["Q_det"]
assert output.size(0) >= Q_det
delta_unsig_det = self.bbox_embed(output[:Q_det])
delta_unsig_trk = box_head_trk(output[Q_det:])
delta_unsig = torch.cat([delta_unsig_det, delta_unsig_trk], dim=0)
outputs_unsig = delta_unsig + reference_before_sigmoid
new_reference_points = outputs_unsig.sigmoid()
reference_boxes = new_reference_points.detach()
if layer_idx != self.num_layers - 1:
intermediate_ref_boxes.append(new_reference_points)
else:
raise NotImplementedError("not implemented yet")
intermediate.append(out_norm(output))
if self.presence_token is not None and is_instance_prompt is False:
# norm, mlp head
intermediate_layer_presence_logits = self.presence_token_head(
self.presence_token_out_norm(presence_out)
).squeeze(-1)
# clamp to mitigate numerical issues
if self.clamp_presence_logits:
intermediate_layer_presence_logits = intermediate_layer_presence_logits.clamp(
min=-self.clamp_presence_logit_max_val,
max=self.clamp_presence_logit_max_val,
)
intermediate_presence_logits.append(intermediate_layer_presence_logits)
presence_feats = presence_out.clone()
if not self.compiled and self.compile_mode is not None:
self.forward = torch.compile(self.forward, mode=self.compile_mode, fullgraph=True)
self.compiled = True
return (
torch.stack(intermediate),
torch.stack(intermediate_ref_boxes),
(
torch.stack(intermediate_presence_logits)
if self.presence_token is not None and is_instance_prompt is False
else None
),
presence_feats,
)
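The forward pass returns a 4-tuple: the stacked per-layer query features (after the output norm), the stacked per-layer reference boxes, the stacked per-layer presence logits (or None when no presence token is used or when is_instance_prompt is True), and the presence-token features from the last layer (or None). When compile_mode is set, forward is wrapped with torch.compile only after its first call, so the boxRPB coordinate cache is warmed up before compilation.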