
Reference for ultralytics/models/sam/modules/decoders.py

Note

This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/modules/decoders.py. If you spot a problem, please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!


ultralytics.models.sam.modules.decoders.MaskDecoder

MaskDecoder(
    transformer_dim: int,
    transformer: Module,
    num_multimask_outputs: int = 3,
    activation: Type[Module] = nn.GELU,
    iou_head_depth: int = 3,
    iou_head_hidden_dim: int = 256,
)

Bases: Module

Decoder module for generating masks and their associated quality scores using a transformer architecture.

This class predicts masks given image and prompt embeddings, utilizing a transformer to process the inputs and generate mask predictions along with their quality scores.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `transformer_dim` | `int` | Channel dimension for the transformer module. |
| `transformer` | `Module` | Transformer module used for mask prediction. |
| `num_multimask_outputs` | `int` | Number of masks to predict for disambiguating masks. |
| `iou_token` | `Embedding` | Embedding for the IoU token. |
| `num_mask_tokens` | `int` | Number of mask tokens. |
| `mask_tokens` | `Embedding` | Embedding for the mask tokens. |
| `output_upscaling` | `Sequential` | Neural network sequence for upscaling the output. |
| `output_hypernetworks_mlps` | `ModuleList` | Hypernetwork MLPs for generating masks. |
| `iou_prediction_head` | `Module` | MLP for predicting mask quality. |

Methods:

| Name | Description |
| --- | --- |
| `forward` | Predicts masks given image and prompt embeddings. |
| `predict_masks` | Internal method for mask prediction. |

Examples:

>>> decoder = MaskDecoder(transformer_dim=256, transformer=transformer_module)
>>> masks, iou_pred = decoder(
...     image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, multimask_output=True
... )
>>> print(f"Predicted masks shape: {masks.shape}, IoU predictions shape: {iou_pred.shape}")

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `transformer_dim` | `int` | Channel dimension for the transformer module. | *required* |
| `transformer` | `Module` | Transformer module used for mask prediction. | *required* |
| `num_multimask_outputs` | `int` | Number of masks to predict for disambiguating masks. | `3` |
| `activation` | `Type[Module]` | Type of activation to use when upscaling masks. | `GELU` |
| `iou_head_depth` | `int` | Depth of the MLP used to predict mask quality. | `3` |
| `iou_head_hidden_dim` | `int` | Hidden dimension of the MLP used to predict mask quality. | `256` |

Examples:

>>> transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=256, nhead=8), num_layers=6)
>>> decoder = MaskDecoder(transformer_dim=256, transformer=transformer)
>>> print(decoder)
Source code in ultralytics/models/sam/modules/decoders.py
def __init__(
    self,
    transformer_dim: int,
    transformer: nn.Module,
    num_multimask_outputs: int = 3,
    activation: Type[nn.Module] = nn.GELU,
    iou_head_depth: int = 3,
    iou_head_hidden_dim: int = 256,
) -> None:
    """
    Initialize the MaskDecoder module for generating masks and their associated quality scores.

    Args:
        transformer_dim (int): Channel dimension for the transformer module.
        transformer (nn.Module): Transformer module used for mask prediction.
        num_multimask_outputs (int): Number of masks to predict for disambiguating masks.
        activation (Type[nn.Module]): Type of activation to use when upscaling masks.
        iou_head_depth (int): Depth of the MLP used to predict mask quality.
        iou_head_hidden_dim (int): Hidden dimension of the MLP used to predict mask quality.

    Examples:
        >>> transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=256, nhead=8), num_layers=6)
        >>> decoder = MaskDecoder(transformer_dim=256, transformer=transformer)
        >>> print(decoder)
    """
    super().__init__()
    self.transformer_dim = transformer_dim
    self.transformer = transformer

    self.num_multimask_outputs = num_multimask_outputs

    self.iou_token = nn.Embedding(1, transformer_dim)
    self.num_mask_tokens = num_multimask_outputs + 1
    self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)

    self.output_upscaling = nn.Sequential(
        nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
        LayerNorm2d(transformer_dim // 4),
        activation(),
        nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
        activation(),
    )
    self.output_hypernetworks_mlps = nn.ModuleList(
        [MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for _ in range(self.num_mask_tokens)]
    )

    self.iou_prediction_head = MLP(transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth)
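
The examples above construct the decoder with a generic placeholder transformer. For a fully runnable end-to-end check, the minimal sketch below instead pairs MaskDecoder with the TwoWayTransformer from ultralytics.models.sam.modules.transformer (the mask-decoder transformer SAM uses); the 2-layer, 2048-dim configuration and the import paths are assumptions for illustration, not part of this reference.

# Minimal sketch (assumed imports and sizes): build a MaskDecoder around a
# TwoWayTransformer and run a forward pass on random embeddings.
import torch

from ultralytics.models.sam.modules.decoders import MaskDecoder
from ultralytics.models.sam.modules.transformer import TwoWayTransformer

transformer = TwoWayTransformer(depth=2, embedding_dim=256, mlp_dim=2048, num_heads=8)
decoder = MaskDecoder(transformer_dim=256, transformer=transformer)

image_emb = torch.rand(1, 256, 64, 64)  # image encoder output (B, C, H, W)
image_pe = torch.rand(1, 256, 64, 64)   # dense positional encoding, same shape
sparse_emb = torch.rand(1, 2, 256)      # e.g. two point/box prompt tokens
dense_emb = torch.rand(1, 256, 64, 64)  # mask prompt embedding

masks, iou_pred = decoder(image_emb, image_pe, sparse_emb, dense_emb, multimask_output=True)
print(masks.shape, iou_pred.shape)  # expected: (1, 3, 256, 256) and (1, 3)

Each stride-2 transposed convolution in output_upscaling doubles the spatial resolution, so the 64x64 image embedding yields 256x256 mask logits, and multimask_output=True keeps the three disambiguation masks while dropping the single-mask token.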

forward

forward(
    image_embeddings: Tensor,
    image_pe: Tensor,
    sparse_prompt_embeddings: Tensor,
    dense_prompt_embeddings: Tensor,
    multimask_output: bool,
) -> Tuple[torch.Tensor, torch.Tensor]

Predict masks given image and prompt embeddings.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `image_embeddings` | `Tensor` | Embeddings from the image encoder. | *required* |
| `image_pe` | `Tensor` | Positional encoding with the shape of image_embeddings. | *required* |
| `sparse_prompt_embeddings` | `Tensor` | Embeddings of the points and boxes. | *required* |
| `dense_prompt_embeddings` | `Tensor` | Embeddings of the mask inputs. | *required* |
| `multimask_output` | `bool` | Whether to return multiple masks or a single mask. | *required* |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `masks` | `Tensor` | Batched predicted masks. |
| `iou_pred` | `Tensor` | Batched predictions of mask quality. |

Examples:

>>> decoder = MaskDecoder(transformer_dim=256, transformer=transformer_module)
>>> image_emb = torch.rand(1, 256, 64, 64)
>>> image_pe = torch.rand(1, 256, 64, 64)
>>> sparse_emb = torch.rand(1, 2, 256)
>>> dense_emb = torch.rand(1, 256, 64, 64)
>>> masks, iou_pred = decoder(image_emb, image_pe, sparse_emb, dense_emb, multimask_output=True)
>>> print(f"Masks shape: {masks.shape}, IoU predictions shape: {iou_pred.shape}")
Source code in ultralytics/models/sam/modules/decoders.py
def forward(
    self,
    image_embeddings: torch.Tensor,
    image_pe: torch.Tensor,
    sparse_prompt_embeddings: torch.Tensor,
    dense_prompt_embeddings: torch.Tensor,
    multimask_output: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Predict masks given image and prompt embeddings.

    Args:
        image_embeddings (torch.Tensor): Embeddings from the image encoder.
        image_pe (torch.Tensor): Positional encoding with the shape of image_embeddings.
        sparse_prompt_embeddings (torch.Tensor): Embeddings of the points and boxes.
        dense_prompt_embeddings (torch.Tensor): Embeddings of the mask inputs.
        multimask_output (bool): Whether to return multiple masks or a single mask.

    Returns:
        masks (torch.Tensor): Batched predicted masks.
        iou_pred (torch.Tensor): Batched predictions of mask quality.

    Examples:
        >>> decoder = MaskDecoder(transformer_dim=256, transformer=transformer_module)
        >>> image_emb = torch.rand(1, 256, 64, 64)
        >>> image_pe = torch.rand(1, 256, 64, 64)
        >>> sparse_emb = torch.rand(1, 2, 256)
        >>> dense_emb = torch.rand(1, 256, 64, 64)
        >>> masks, iou_pred = decoder(image_emb, image_pe, sparse_emb, dense_emb, multimask_output=True)
        >>> print(f"Masks shape: {masks.shape}, IoU predictions shape: {iou_pred.shape}")
    """
    masks, iou_pred = self.predict_masks(
        image_embeddings=image_embeddings,
        image_pe=image_pe,
        sparse_prompt_embeddings=sparse_prompt_embeddings,
        dense_prompt_embeddings=dense_prompt_embeddings,
    )

    # Select the correct mask or masks for output
    mask_slice = slice(1, None) if multimask_output else slice(0, 1)
    masks = masks[:, mask_slice, :, :]
    iou_pred = iou_pred[:, mask_slice]

    # Prepare output
    return masks, iou_pred

predict_masks

predict_masks(
    image_embeddings: Tensor,
    image_pe: Tensor,
    sparse_prompt_embeddings: Tensor,
    dense_prompt_embeddings: Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]

Predict masks and quality scores using image and prompt embeddings via transformer architecture.

Source code in ultralytics/models/sam/modules/decoders.py
def predict_masks(
    self,
    image_embeddings: torch.Tensor,
    image_pe: torch.Tensor,
    sparse_prompt_embeddings: torch.Tensor,
    dense_prompt_embeddings: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Predict masks and quality scores using image and prompt embeddings via transformer architecture."""
    # Concatenate output tokens
    output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
    output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.shape[0], -1, -1)
    tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)

    # Expand per-image data in batch direction to be per-mask
    src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
    src = src + dense_prompt_embeddings
    pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
    b, c, h, w = src.shape

    # Run the transformer
    hs, src = self.transformer(src, pos_src, tokens)
    iou_token_out = hs[:, 0, :]
    mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]

    # Upscale mask embeddings and predict masks using the mask tokens
    src = src.transpose(1, 2).view(b, c, h, w)
    upscaled_embedding = self.output_upscaling(src)
    hyper_in_list: List[torch.Tensor] = [
        self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) for i in range(self.num_mask_tokens)
    ]
    hyper_in = torch.stack(hyper_in_list, dim=1)
    b, c, h, w = upscaled_embedding.shape
    masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)

    # Generate mask quality predictions
    iou_pred = self.iou_prediction_head(iou_token_out)

    return masks, iou_pred
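
At the end of predict_masks, each mask is produced as a dot product between a per-token hypernetwork output and every spatial location of the upscaled embedding. The standalone sketch below reproduces just that shape arithmetic with dummy tensors; the sizes assume transformer_dim=256 and a 64x64 image embedding and are illustrative only.

# Shape walk-through of the hypernetwork mask product (illustrative sizes).
import torch

b, num_mask_tokens = 1, 4  # one prompt; 3 multimask outputs + 1 single-mask token
c, h, w = 32, 256, 256     # upscaled embedding: transformer_dim // 8 channels, 4x spatial

hyper_in = torch.rand(b, num_mask_tokens, c)  # one weight vector per mask token (from the MLPs)
upscaled_embedding = torch.rand(b, c, h, w)   # output of output_upscaling

masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
print(masks.shape)  # torch.Size([1, 4, 256, 256]): one logit map per mask token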





ultralytics.models.sam.modules.decoders.SAM2MaskDecoder

SAM2MaskDecoder(
    transformer_dim: int,
    transformer: Module,
    num_multimask_outputs: int = 3,
    activation: Type[Module] = nn.GELU,
    iou_head_depth: int = 3,
    iou_head_hidden_dim: int = 256,
    use_high_res_features: bool = False,
    iou_prediction_use_sigmoid=False,
    dynamic_multimask_via_stability=False,
    dynamic_multimask_stability_delta=0.05,
    dynamic_multimask_stability_thresh=0.98,
    pred_obj_scores: bool = False,
    pred_obj_scores_mlp: bool = False,
    use_multimask_token_for_obj_ptr: bool = False,
)

Bases: Module

Transformer-based decoder for predicting instance segmentation masks from image and prompt embeddings.

This class extends the functionality of the MaskDecoder, incorporating additional features such as high-resolution feature processing, dynamic multimask output, and object score prediction.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `transformer_dim` | `int` | Channel dimension of the transformer. |
| `transformer` | `Module` | Transformer used to predict masks. |
| `num_multimask_outputs` | `int` | Number of masks to predict when disambiguating masks. |
| `iou_token` | `Embedding` | Embedding for the IoU token. |
| `num_mask_tokens` | `int` | Total number of mask tokens. |
| `mask_tokens` | `Embedding` | Embedding for the mask tokens. |
| `pred_obj_scores` | `bool` | Whether to predict object scores. |
| `obj_score_token` | `Embedding` | Embedding for the object score token. |
| `use_multimask_token_for_obj_ptr` | `bool` | Whether to use the multimask tokens for the object pointer. |
| `output_upscaling` | `Sequential` | Upscaling layers for the output. |
| `use_high_res_features` | `bool` | Whether to use high-resolution features. |
| `conv_s0` | `Conv2d` | Convolutional layer for high-resolution features (s0). |
| `conv_s1` | `Conv2d` | Convolutional layer for high-resolution features (s1). |
| `output_hypernetworks_mlps` | `ModuleList` | List of MLPs for the output hypernetworks. |
| `iou_prediction_head` | `MLP` | MLP for IoU prediction. |
| `pred_obj_score_head` | `Linear` or `MLP` | Linear layer or MLP for object score prediction. |
| `dynamic_multimask_via_stability` | `bool` | Whether to dynamically select the mask output via stability scores. |
| `dynamic_multimask_stability_delta` | `float` | Delta value for dynamic multimask stability. |
| `dynamic_multimask_stability_thresh` | `float` | Threshold for dynamic multimask stability. |

Methods:

| Name | Description |
| --- | --- |
| `forward` | Predicts masks given image and prompt embeddings. |
| `predict_masks` | Predicts instance segmentation masks from image and prompt embeddings. |
| `_get_stability_scores` | Computes mask stability scores from the IoU between high- and low-threshold versions of each mask. |
| `_dynamic_multimask_via_stability` | Dynamically selects the most stable mask output. |

Examples:

>>> image_embeddings = torch.rand(1, 256, 64, 64)
>>> image_pe = torch.rand(1, 256, 64, 64)
>>> sparse_prompt_embeddings = torch.rand(1, 2, 256)
>>> dense_prompt_embeddings = torch.rand(1, 256, 64, 64)
>>> decoder = SAM2MaskDecoder(256, transformer)
>>> masks, iou_pred, sam_tokens_out, obj_score_logits = decoder.forward(
...     image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, True, False
... )

This decoder extends the functionality of MaskDecoder, incorporating additional features such as high-resolution feature processing, dynamic multimask output, and object score prediction.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `transformer_dim` | `int` | Channel dimension of the transformer. | *required* |
| `transformer` | `Module` | Transformer used to predict masks. | *required* |
| `num_multimask_outputs` | `int` | Number of masks to predict when disambiguating masks. | `3` |
| `activation` | `Type[Module]` | Type of activation to use when upscaling masks. | `GELU` |
| `iou_head_depth` | `int` | Depth of the MLP used to predict mask quality. | `3` |
| `iou_head_hidden_dim` | `int` | Hidden dimension of the MLP used to predict mask quality. | `256` |
| `use_high_res_features` | `bool` | Whether to use high-resolution features. | `False` |
| `iou_prediction_use_sigmoid` | `bool` | Whether to use sigmoid for IoU prediction. | `False` |
| `dynamic_multimask_via_stability` | `bool` | Whether to use dynamic multimask via stability. | `False` |
| `dynamic_multimask_stability_delta` | `float` | Delta value for dynamic multimask stability. | `0.05` |
| `dynamic_multimask_stability_thresh` | `float` | Threshold for dynamic multimask stability. | `0.98` |
| `pred_obj_scores` | `bool` | Whether to predict object scores. | `False` |
| `pred_obj_scores_mlp` | `bool` | Whether to use an MLP for object score prediction. | `False` |
| `use_multimask_token_for_obj_ptr` | `bool` | Whether to use the multimask tokens for the object pointer. | `False` |

Examples:

>>> transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=256, nhead=8), num_layers=6)
>>> decoder = SAM2MaskDecoder(transformer_dim=256, transformer=transformer)
>>> print(decoder)
Source code in ultralytics/models/sam/modules/decoders.py
def __init__(
    self,
    transformer_dim: int,
    transformer: nn.Module,
    num_multimask_outputs: int = 3,
    activation: Type[nn.Module] = nn.GELU,
    iou_head_depth: int = 3,
    iou_head_hidden_dim: int = 256,
    use_high_res_features: bool = False,
    iou_prediction_use_sigmoid=False,
    dynamic_multimask_via_stability=False,
    dynamic_multimask_stability_delta=0.05,
    dynamic_multimask_stability_thresh=0.98,
    pred_obj_scores: bool = False,
    pred_obj_scores_mlp: bool = False,
    use_multimask_token_for_obj_ptr: bool = False,
) -> None:
    """
    Initialize the SAM2MaskDecoder module for predicting instance segmentation masks.

    This decoder extends the functionality of MaskDecoder, incorporating additional features such as
    high-resolution feature processing, dynamic multimask output, and object score prediction.

    Args:
        transformer_dim (int): Channel dimension of the transformer.
        transformer (nn.Module): Transformer used to predict masks.
        num_multimask_outputs (int): Number of masks to predict when disambiguating masks.
        activation (Type[nn.Module]): Type of activation to use when upscaling masks.
        iou_head_depth (int): Depth of the MLP used to predict mask quality.
        iou_head_hidden_dim (int): Hidden dimension of the MLP used to predict mask quality.
        use_high_res_features (bool): Whether to use high-resolution features.
        iou_prediction_use_sigmoid (bool): Whether to use sigmoid for IOU prediction.
        dynamic_multimask_via_stability (bool): Whether to use dynamic multimask via stability.
        dynamic_multimask_stability_delta (float): Delta value for dynamic multimask stability.
        dynamic_multimask_stability_thresh (float): Threshold for dynamic multimask stability.
        pred_obj_scores (bool): Whether to predict object scores.
        pred_obj_scores_mlp (bool): Whether to use MLP for object score prediction.
        use_multimask_token_for_obj_ptr (bool): Whether to use multimask token for object pointer.

    Examples:
        >>> transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=256, nhead=8), num_layers=6)
        >>> decoder = SAM2MaskDecoder(transformer_dim=256, transformer=transformer)
        >>> print(decoder)
    """
    super().__init__()
    self.transformer_dim = transformer_dim
    self.transformer = transformer

    self.num_multimask_outputs = num_multimask_outputs

    self.iou_token = nn.Embedding(1, transformer_dim)
    self.num_mask_tokens = num_multimask_outputs + 1
    self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)

    self.pred_obj_scores = pred_obj_scores
    if self.pred_obj_scores:
        self.obj_score_token = nn.Embedding(1, transformer_dim)
    self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr

    self.output_upscaling = nn.Sequential(
        nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
        LayerNorm2d(transformer_dim // 4),
        activation(),
        nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
        activation(),
    )
    self.use_high_res_features = use_high_res_features
    if use_high_res_features:
        self.conv_s0 = nn.Conv2d(transformer_dim, transformer_dim // 8, kernel_size=1, stride=1)
        self.conv_s1 = nn.Conv2d(transformer_dim, transformer_dim // 4, kernel_size=1, stride=1)

    self.output_hypernetworks_mlps = nn.ModuleList(
        [MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for _ in range(self.num_mask_tokens)]
    )

    self.iou_prediction_head = MLP(
        transformer_dim,
        iou_head_hidden_dim,
        self.num_mask_tokens,
        iou_head_depth,
        sigmoid=iou_prediction_use_sigmoid,
    )
    if self.pred_obj_scores:
        self.pred_obj_score_head = nn.Linear(transformer_dim, 1)
        if pred_obj_scores_mlp:
            self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3)

    # When outputting a single mask, optionally we can dynamically fall back to the best
    # multimask output token if the single mask output token gives low stability scores.
    self.dynamic_multimask_via_stability = dynamic_multimask_via_stability
    self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta
    self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh
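
When use_high_res_features=True, forward additionally expects two skip-connection feature maps whose channel counts and resolutions must match the two upscaling stages (conv_s0 and conv_s1 are sized to produce exactly those widths, transformer_dim // 8 and transformer_dim // 4). The sketch below shows the expected shapes for transformer_dim=256 and a 64x64 image embedding; the TwoWayTransformer configuration and import paths are assumptions for illustration.

# Minimal sketch (assumed imports and sizes) of the high-res feature path.
import torch

from ultralytics.models.sam.modules.decoders import SAM2MaskDecoder
from ultralytics.models.sam.modules.transformer import TwoWayTransformer

transformer = TwoWayTransformer(depth=2, embedding_dim=256, mlp_dim=2048, num_heads=8)
decoder = SAM2MaskDecoder(transformer_dim=256, transformer=transformer, use_high_res_features=True)

image_emb = torch.rand(1, 256, 64, 64)
image_pe = torch.rand(1, 256, 64, 64)
sparse_emb = torch.rand(1, 2, 256)
dense_emb = torch.rand(1, 256, 64, 64)

# high_res_features = [feat_s0, feat_s1]:
#   feat_s1 is added after the first transposed conv  -> (B, C // 4, 128, 128)
#   feat_s0 is added after the second transposed conv -> (B, C // 8, 256, 256)
high_res_features = [torch.rand(1, 32, 256, 256), torch.rand(1, 64, 128, 128)]

masks, iou_pred, sam_tokens_out, obj_logits = decoder(
    image_emb, image_pe, sparse_emb, dense_emb,
    multimask_output=True, repeat_image=False, high_res_features=high_res_features,
)
print(masks.shape, iou_pred.shape, sam_tokens_out.shape, obj_logits.shape)
# expected: (1, 3, 256, 256), (1, 3), (1, 1, 256), (1, 1)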

forward

forward(
    image_embeddings: Tensor,
    image_pe: Tensor,
    sparse_prompt_embeddings: Tensor,
    dense_prompt_embeddings: Tensor,
    multimask_output: bool,
    repeat_image: bool,
    high_res_features: Optional[List[Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]

Predict masks given image and prompt embeddings.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `image_embeddings` | `Tensor` | Embeddings from the image encoder with shape (B, C, H, W). | *required* |
| `image_pe` | `Tensor` | Positional encoding with the shape of image_embeddings (B, C, H, W). | *required* |
| `sparse_prompt_embeddings` | `Tensor` | Embeddings of the points and boxes with shape (B, N, C). | *required* |
| `dense_prompt_embeddings` | `Tensor` | Embeddings of the mask inputs with shape (B, C, H, W). | *required* |
| `multimask_output` | `bool` | Whether to return multiple masks or a single mask. | *required* |
| `repeat_image` | `bool` | Flag to repeat the image embeddings. | *required* |
| `high_res_features` | `List[Tensor]` or `None` | Optional high-resolution features. | `None` |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `masks` | `Tensor` | Batched predicted masks with shape (B, N, H, W). |
| `iou_pred` | `Tensor` | Batched predictions of mask quality with shape (B, N). |
| `sam_tokens_out` | `Tensor` | Batched SAM token for mask output with shape (B, N, C). |
| `object_score_logits` | `Tensor` | Batched object score logits with shape (B, 1). |

Examples:

>>> image_embeddings = torch.rand(1, 256, 64, 64)
>>> image_pe = torch.rand(1, 256, 64, 64)
>>> sparse_prompt_embeddings = torch.rand(1, 2, 256)
>>> dense_prompt_embeddings = torch.rand(1, 256, 64, 64)
>>> decoder = SAM2MaskDecoder(256, transformer)
>>> masks, iou_pred, sam_tokens_out, obj_score_logits = decoder.forward(
...     image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, True, False
... )
Source code in ultralytics/models/sam/modules/decoders.py
def forward(
    self,
    image_embeddings: torch.Tensor,
    image_pe: torch.Tensor,
    sparse_prompt_embeddings: torch.Tensor,
    dense_prompt_embeddings: torch.Tensor,
    multimask_output: bool,
    repeat_image: bool,
    high_res_features: Optional[List[torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Predict masks given image and prompt embeddings.

    Args:
        image_embeddings (torch.Tensor): Embeddings from the image encoder with shape (B, C, H, W).
        image_pe (torch.Tensor): Positional encoding with the shape of image_embeddings (B, C, H, W).
        sparse_prompt_embeddings (torch.Tensor): Embeddings of the points and boxes with shape (B, N, C).
        dense_prompt_embeddings (torch.Tensor): Embeddings of the mask inputs with shape (B, C, H, W).
        multimask_output (bool): Whether to return multiple masks or a single mask.
        repeat_image (bool): Flag to repeat the image embeddings.
        high_res_features (List[torch.Tensor] | None): Optional high-resolution features.

    Returns:
        masks (torch.Tensor): Batched predicted masks with shape (B, N, H, W).
        iou_pred (torch.Tensor): Batched predictions of mask quality with shape (B, N).
        sam_tokens_out (torch.Tensor): Batched SAM token for mask output with shape (B, N, C).
        object_score_logits (torch.Tensor): Batched object score logits with shape (B, 1).

    Examples:
        >>> image_embeddings = torch.rand(1, 256, 64, 64)
        >>> image_pe = torch.rand(1, 256, 64, 64)
        >>> sparse_prompt_embeddings = torch.rand(1, 2, 256)
        >>> dense_prompt_embeddings = torch.rand(1, 256, 64, 64)
        >>> decoder = SAM2MaskDecoder(256, transformer)
        >>> masks, iou_pred, sam_tokens_out, obj_score_logits = decoder.forward(
        ...     image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, True, False
        ... )
    """
    masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks(
        image_embeddings=image_embeddings,
        image_pe=image_pe,
        sparse_prompt_embeddings=sparse_prompt_embeddings,
        dense_prompt_embeddings=dense_prompt_embeddings,
        repeat_image=repeat_image,
        high_res_features=high_res_features,
    )

    # Select the correct mask or masks for output
    if multimask_output:
        masks = masks[:, 1:, :, :]
        iou_pred = iou_pred[:, 1:]
    elif self.dynamic_multimask_via_stability and not self.training:
        masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
    else:
        masks = masks[:, 0:1, :, :]
        iou_pred = iou_pred[:, 0:1]

    if multimask_output and self.use_multimask_token_for_obj_ptr:
        sam_tokens_out = mask_tokens_out[:, 1:]  # [b, 3, c] shape
    else:
        # Take the mask output token. Here we *always* use the token for single mask output.
        # At test time, even if we track after 1-click (and using multimask_output=True),
        # we still take the single mask token here. The rationale is that we always track
        # after multiple clicks during training, so the past tokens seen during training
        # are always the single mask token (and we'll let it be the object-memory token).
        sam_tokens_out = mask_tokens_out[:, 0:1]  # [b, 1, c] shape

    # Prepare output
    return masks, iou_pred, sam_tokens_out, object_score_logits

predict_masks

predict_masks(
    image_embeddings: Tensor,
    image_pe: Tensor,
    sparse_prompt_embeddings: Tensor,
    dense_prompt_embeddings: Tensor,
    repeat_image: bool,
    high_res_features: Optional[List[Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]

Predict instance segmentation masks from image and prompt embeddings using a transformer.

Source code in ultralytics/models/sam/modules/decoders.py
def predict_masks(
    self,
    image_embeddings: torch.Tensor,
    image_pe: torch.Tensor,
    sparse_prompt_embeddings: torch.Tensor,
    dense_prompt_embeddings: torch.Tensor,
    repeat_image: bool,
    high_res_features: Optional[List[torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Predict instance segmentation masks from image and prompt embeddings using a transformer."""
    # Concatenate output tokens
    s = 0
    if self.pred_obj_scores:
        output_tokens = torch.cat(
            [
                self.obj_score_token.weight,
                self.iou_token.weight,
                self.mask_tokens.weight,
            ],
            dim=0,
        )
        s = 1
    else:
        output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
    output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
    tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)

    # Expand per-image data in batch direction to be per-mask
    if repeat_image:
        src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
    else:
        assert image_embeddings.shape[0] == tokens.shape[0]
        src = image_embeddings
    src = src + dense_prompt_embeddings
    assert image_pe.size(0) == 1, "image_pe should have size 1 in batch dim (from `get_dense_pe()`)"
    pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
    b, c, h, w = src.shape

    # Run the transformer
    hs, src = self.transformer(src, pos_src, tokens)
    iou_token_out = hs[:, s, :]
    mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :]

    # Upscale mask embeddings and predict masks using the mask tokens
    src = src.transpose(1, 2).view(b, c, h, w)
    if not self.use_high_res_features:
        upscaled_embedding = self.output_upscaling(src)
    else:
        dc1, ln1, act1, dc2, act2 = self.output_upscaling
        feat_s0, feat_s1 = high_res_features
        upscaled_embedding = act1(ln1(dc1(src) + feat_s1))
        upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0)

    hyper_in_list: List[torch.Tensor] = [
        self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) for i in range(self.num_mask_tokens)
    ]
    hyper_in = torch.stack(hyper_in_list, dim=1)
    b, c, h, w = upscaled_embedding.shape
    masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)

    # Generate mask quality predictions
    iou_pred = self.iou_prediction_head(iou_token_out)
    if self.pred_obj_scores:
        assert s == 1
        object_score_logits = self.pred_obj_score_head(hs[:, 0, :])
    else:
        # Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1
        object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1)

    return masks, iou_pred, mask_tokens_out, object_score_logits
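
The stability-based fallback enabled by dynamic_multimask_via_stability (implemented in _get_stability_scores and _dynamic_multimask_via_stability, whose source is not reproduced on this page) scores each mask by how little it changes between a high and a low logit threshold, and swaps in the best multimask output when the single-mask output scores below dynamic_multimask_stability_thresh. The sketch below is an illustrative reimplementation of that idea, not the library code.

# Illustrative sketch of stability scoring and the single-mask fallback (not the exact library code).
import torch


def stability_scores(mask_logits: torch.Tensor, delta: float = 0.05) -> torch.Tensor:
    """IoU-style ratio between the mask thresholded at +delta and at -delta."""
    area_hi = (mask_logits > delta).flatten(-2).sum(-1).float()   # area of the stricter mask
    area_lo = (mask_logits > -delta).flatten(-2).sum(-1).float()  # area of the looser mask
    return torch.where(area_lo > 0, area_hi / area_lo, torch.ones_like(area_lo))


def select_stable_mask(masks, iou_pred, delta=0.05, thresh=0.98):
    """Keep the single-mask output (index 0) if stable, else take the best multimask output."""
    single_stability = stability_scores(masks[:, 0:1], delta).squeeze(-1)  # (B,)
    best_multimask = iou_pred[:, 1:].argmax(dim=-1) + 1                    # (B,) index into masks
    idx = torch.where(single_stability >= thresh, torch.zeros_like(best_multimask), best_multimask)
    batch = torch.arange(masks.shape[0], device=masks.device)
    return masks[batch, idx].unsqueeze(1), iou_pred[batch, idx].unsqueeze(1)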




