Reference for ultralytics/models/sam/modules/sam.py

Note

This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/modules/sam.py. If you spot a problem, please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!


ultralytics.models.sam.modules.sam.SAMModel

SAMModel(
    image_encoder: ImageEncoderViT,
    prompt_encoder: PromptEncoder,
    mask_decoder: MaskDecoder,
    pixel_mean: List[float] = (123.675, 116.28, 103.53),
    pixel_std: List[float] = (58.395, 57.12, 57.375),
)

Bases: Module

Segment Anything Model (SAM) for object segmentation tasks.

This class combines image encoders, prompt encoders, and mask decoders to predict object masks from images and input prompts.

Attributes:

    mask_threshold (float): Threshold value for mask prediction.
    image_encoder (ImageEncoderViT): Backbone for encoding images into embeddings.
    prompt_encoder (PromptEncoder): Encoder for various types of input prompts.
    mask_decoder (MaskDecoder): Predicts object masks from image and prompt embeddings.
    pixel_mean (Tensor): Mean values for normalizing pixels in the input image.
    pixel_std (Tensor): Standard deviation values for normalizing pixels in the input image.

Methods:

    set_imgsz: Set image size to make model compatible with different image sizes.

Examples:

>>> image_encoder = ImageEncoderViT(...)
>>> prompt_encoder = PromptEncoder(...)
>>> mask_decoder = MaskDecoder(...)
>>> sam_model = SAMModel(image_encoder, prompt_encoder, mask_decoder)
>>> # Further usage depends on SAMPredictor class
Notes

All forward() operations are implemented in the SAMPredictor class.

Parameters:

    image_encoder (ImageEncoderViT): The backbone used to encode the image into image embeddings. Required.
    prompt_encoder (PromptEncoder): Encodes various types of input prompts. Required.
    mask_decoder (MaskDecoder): Predicts masks from the image embeddings and encoded prompts. Required.
    pixel_mean (List[float]): Mean values for normalizing pixels in the input image. Default: (123.675, 116.28, 103.53).
    pixel_std (List[float]): Standard deviation values for normalizing pixels in the input image. Default: (58.395, 57.12, 57.375).

Examples:

>>> image_encoder = ImageEncoderViT(...)
>>> prompt_encoder = PromptEncoder(...)
>>> mask_decoder = MaskDecoder(...)
>>> sam_model = SAMModel(image_encoder, prompt_encoder, mask_decoder)
>>> # Further usage depends on SAMPredictor class
Notes

All forward() operations moved to SAMPredictor.
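
In practice the model is usually driven through SAMPredictor or the high-level SAM wrapper rather than called directly. A minimal sketch of that workflow, assuming a local "sam_b.pt" checkpoint and an example image path (both placeholders here):

from ultralytics import SAM

# Load a SAM checkpoint; the wrapper builds the SAMModel and its predictor internally
model = SAM("sam_b.pt")

# Prompt with a foreground point (x, y); box prompts work similarly via bboxes=[[x1, y1, x2, y2]]
results = model("path/to/image.jpg", points=[[500, 375]], labels=[1])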

Source code in ultralytics/models/sam/modules/sam.py
def __init__(
    self,
    image_encoder: ImageEncoderViT,
    prompt_encoder: PromptEncoder,
    mask_decoder: MaskDecoder,
    pixel_mean: List[float] = (123.675, 116.28, 103.53),
    pixel_std: List[float] = (58.395, 57.12, 57.375),
) -> None:
    """
    Initialize the SAMModel class to predict object masks from an image and input prompts.

    Args:
        image_encoder (ImageEncoderViT): The backbone used to encode the image into image embeddings.
        prompt_encoder (PromptEncoder): Encodes various types of input prompts.
        mask_decoder (MaskDecoder): Predicts masks from the image embeddings and encoded prompts.
        pixel_mean (List[float]): Mean values for normalizing pixels in the input image.
        pixel_std (List[float]): Standard deviation values for normalizing pixels in the input image.

    Examples:
        >>> image_encoder = ImageEncoderViT(...)
        >>> prompt_encoder = PromptEncoder(...)
        >>> mask_decoder = MaskDecoder(...)
        >>> sam_model = SAMModel(image_encoder, prompt_encoder, mask_decoder)
        >>> # Further usage depends on SAMPredictor class

    Notes:
        All forward() operations moved to SAMPredictor.
    """
    super().__init__()
    self.image_encoder = image_encoder
    self.prompt_encoder = prompt_encoder
    self.mask_decoder = mask_decoder
    self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
    self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
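
The pixel_mean and pixel_std buffers registered above move with the model across devices and are used to normalize raw pixel values before encoding. A minimal sketch of that normalization, assuming sam_model is an initialized SAMModel and omitting the resizing and padding that SAMPredictor performs:

import torch

# Hypothetical (3, H, W) image tensor with values in [0, 255]
img = torch.randint(0, 256, (3, 1024, 1024)).float()

# Broadcast the (C, 1, 1) buffers over the spatial dimensions
x = (img - sam_model.pixel_mean) / sam_model.pixel_std

# Encode the normalized batch into image embeddings
embeddings = sam_model.image_encoder(x.unsqueeze(0))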

set_imgsz

set_imgsz(imgsz)

Set image size to make model compatible with different image sizes.

Source code in ultralytics/models/sam/modules/sam.py
def set_imgsz(self, imgsz):
    """Set image size to make model compatible with different image sizes."""
    if hasattr(self.image_encoder, "set_imgsz"):
        self.image_encoder.set_imgsz(imgsz)
    self.prompt_encoder.input_image_size = imgsz
    self.prompt_encoder.image_embedding_size = [x // 16 for x in imgsz]  # 16 is fixed as patch size of ViT model
    self.image_encoder.img_size = imgsz[0]
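
Because the ViT patch size is fixed at 16, the prompt encoder's embedding grid is simply each side of imgsz divided by 16. For example, assuming sam_model is an initialized SAMModel:

>>> sam_model.set_imgsz((1024, 1024))
>>> sam_model.prompt_encoder.image_embedding_size
[64, 64]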





ultralytics.models.sam.modules.sam.SAM2Model

SAM2Model(
    image_encoder,
    memory_attention,
    memory_encoder,
    num_maskmem=7,
    image_size=512,
    backbone_stride=16,
    sigmoid_scale_for_mem_enc=1.0,
    sigmoid_bias_for_mem_enc=0.0,
    binarize_mask_from_pts_for_mem_enc=False,
    use_mask_input_as_output_without_sam=False,
    max_cond_frames_in_attn=-1,
    directly_add_no_mem_embed=False,
    use_high_res_features_in_sam=False,
    multimask_output_in_sam=False,
    multimask_min_pt_num=1,
    multimask_max_pt_num=1,
    multimask_output_for_tracking=False,
    use_multimask_token_for_obj_ptr: bool = False,
    iou_prediction_use_sigmoid=False,
    memory_temporal_stride_for_eval=1,
    non_overlap_masks_for_mem_enc=False,
    use_obj_ptrs_in_encoder=False,
    max_obj_ptrs_in_encoder=16,
    add_tpos_enc_to_obj_ptrs=True,
    proj_tpos_enc_in_obj_ptrs=False,
    use_signed_tpos_enc_to_obj_ptrs=False,
    only_obj_ptrs_in_the_past_for_eval=False,
    pred_obj_scores: bool = False,
    pred_obj_scores_mlp: bool = False,
    fixed_no_obj_ptr: bool = False,
    soft_no_obj_ptr: bool = False,
    use_mlp_for_obj_ptr_proj: bool = False,
    no_obj_embed_spatial: bool = False,
    sam_mask_decoder_extra_args=None,
    compile_image_encoder: bool = False,
)

Bases: Module

SAM2Model class for Segment Anything Model 2 with memory-based video object segmentation capabilities.

This class extends the functionality of SAM to handle video sequences, incorporating memory mechanisms for temporal consistency and efficient tracking of objects across frames.

Attributes:

    mask_threshold (float): Threshold value for mask prediction.
    image_encoder (ImageEncoderViT): Visual encoder for extracting image features.
    memory_attention (Module): Module for attending to memory features.
    memory_encoder (Module): Encoder for generating memory representations.
    num_maskmem (int): Number of accessible memory frames.
    image_size (int): Size of input images.
    backbone_stride (int): Stride of the backbone network output.
    sam_prompt_embed_dim (int): Dimension of SAM prompt embeddings.
    sam_image_embedding_size (int): Size of SAM image embeddings.
    sam_prompt_encoder (PromptEncoder): Encoder for processing input prompts.
    sam_mask_decoder (SAM2MaskDecoder): Decoder for generating object masks.
    obj_ptr_proj (Module): Projection layer for object pointers.
    obj_ptr_tpos_proj (Module): Projection for temporal positional encoding in object pointers.
    hidden_dim (int): Hidden dimension of the model.
    mem_dim (int): Memory dimension for encoding features.
    use_high_res_features_in_sam (bool): Whether to use high-resolution feature maps in the SAM mask decoder.
    use_obj_ptrs_in_encoder (bool): Whether to cross-attend to object pointers from other frames in the encoder.
    max_obj_ptrs_in_encoder (int): Maximum number of object pointers from other frames in encoder cross-attention.
    add_tpos_enc_to_obj_ptrs (bool): Whether to add temporal positional encoding to object pointers.
    proj_tpos_enc_in_obj_ptrs (bool): Whether to add an extra linear projection layer for temporal positional encoding in object pointers.
    use_signed_tpos_enc_to_obj_ptrs (bool): Whether to use signed distance in temporal positional encoding.
    only_obj_ptrs_in_the_past_for_eval (bool): Whether to only attend to object pointers in the past during evaluation.
    pred_obj_scores (bool): Whether to predict if there is an object in the frame.
    pred_obj_scores_mlp (bool): Whether to use an MLP to predict object scores.
    fixed_no_obj_ptr (bool): Whether to have a fixed no-object pointer when there is no object present.
    soft_no_obj_ptr (bool): Whether to mix in no-object pointer softly for easier recovery and error mitigation.
    use_mlp_for_obj_ptr_proj (bool): Whether to use MLP for object pointer projection.
    no_obj_embed_spatial (Tensor | None): No-object embedding for spatial frames.
    max_cond_frames_in_attn (int): Maximum number of conditioning frames to participate in memory attention.
    directly_add_no_mem_embed (bool): Whether to directly add no-memory embedding to image feature on the first frame.
    multimask_output_in_sam (bool): Whether to output multiple masks for the first click on initial conditioning frames.
    multimask_min_pt_num (int): Minimum number of clicks to use multimask output in SAM.
    multimask_max_pt_num (int): Maximum number of clicks to use multimask output in SAM.
    multimask_output_for_tracking (bool): Whether to use multimask output for tracking.
    use_multimask_token_for_obj_ptr (bool): Whether to use multimask tokens for object pointers.
    iou_prediction_use_sigmoid (bool): Whether to use sigmoid to restrict IoU prediction to [0-1].
    memory_temporal_stride_for_eval (int): Memory bank's temporal stride during evaluation.
    non_overlap_masks_for_mem_enc (bool): Whether to apply non-overlapping constraints on object masks in the memory encoder during evaluation.
    sigmoid_scale_for_mem_enc (float): Scale factor for mask sigmoid probability.
    sigmoid_bias_for_mem_enc (float): Bias factor for mask sigmoid probability.
    binarize_mask_from_pts_for_mem_enc (bool): Whether to binarize sigmoid mask logits on interacted frames with clicks during evaluation.
    use_mask_input_as_output_without_sam (bool): Whether to directly output the input mask without using the SAM prompt encoder and mask decoder on frames with mask input.

Methods:

    forward_image: Process image batch through encoder to extract multi-level features.
    track_step: Perform a single tracking step, updating object masks and memory features.
    set_binarize: Set binarize for VideoPredictor.
    set_imgsz: Set image size to make model compatible with different image sizes.

Examples:

>>> model = SAM2Model(image_encoder, memory_attention, memory_encoder)
>>> image_batch = torch.rand(1, 3, 512, 512)
>>> features = model.forward_image(image_batch)
>>> track_results = model.track_step(0, True, features, None, None, None, {})

Parameters:

    image_encoder (Module): Visual encoder for extracting image features. Required.
    memory_attention (Module): Module for attending to memory features. Required.
    memory_encoder (Module): Encoder for generating memory representations. Required.
    num_maskmem (int): Number of accessible memory frames. Default: 7.
    image_size (int): Size of input images. Default: 512.
    backbone_stride (int): Stride of the image backbone output. Default: 16.
    sigmoid_scale_for_mem_enc (float): Scale factor for mask sigmoid probability. Default: 1.0.
    sigmoid_bias_for_mem_enc (float): Bias factor for mask sigmoid probability. Default: 0.0.
    binarize_mask_from_pts_for_mem_enc (bool): Whether to binarize sigmoid mask logits on interacted frames with clicks during evaluation. Default: False.
    use_mask_input_as_output_without_sam (bool): Whether to directly output the input mask without using the SAM prompt encoder and mask decoder on frames with mask input. Default: False.
    max_cond_frames_in_attn (int): Maximum number of conditioning frames to participate in memory attention. Default: -1.
    directly_add_no_mem_embed (bool): Whether to directly add no-memory embedding to image feature on the first frame. Default: False.
    use_high_res_features_in_sam (bool): Whether to use high-resolution feature maps in the SAM mask decoder. Default: False.
    multimask_output_in_sam (bool): Whether to output multiple masks for the first click on initial conditioning frames. Default: False.
    multimask_min_pt_num (int): Minimum number of clicks to use multimask output in SAM. Default: 1.
    multimask_max_pt_num (int): Maximum number of clicks to use multimask output in SAM. Default: 1.
    multimask_output_for_tracking (bool): Whether to use multimask output for tracking. Default: False.
    use_multimask_token_for_obj_ptr (bool): Whether to use multimask tokens for object pointers. Default: False.
    iou_prediction_use_sigmoid (bool): Whether to use sigmoid to restrict IoU prediction to [0-1]. Default: False.
    memory_temporal_stride_for_eval (int): Memory bank's temporal stride during evaluation. Default: 1.
    non_overlap_masks_for_mem_enc (bool): Whether to apply non-overlapping constraints on object masks in the memory encoder during evaluation. Default: False.
    use_obj_ptrs_in_encoder (bool): Whether to cross-attend to object pointers from other frames in the encoder. Default: False.
    max_obj_ptrs_in_encoder (int): Maximum number of object pointers from other frames in encoder cross-attention. Default: 16.
    add_tpos_enc_to_obj_ptrs (bool): Whether to add temporal positional encoding to object pointers in the encoder. Default: True.
    proj_tpos_enc_in_obj_ptrs (bool): Whether to add an extra linear projection layer for temporal positional encoding in object pointers. Default: False.
    use_signed_tpos_enc_to_obj_ptrs (bool): Whether to use signed distance in the temporal positional encoding in the object pointers. Default: False.
    only_obj_ptrs_in_the_past_for_eval (bool): Whether to only attend to object pointers in the past during evaluation. Default: False.
    pred_obj_scores (bool): Whether to predict if there is an object in the frame. Default: False.
    pred_obj_scores_mlp (bool): Whether to use an MLP to predict object scores. Default: False.
    fixed_no_obj_ptr (bool): Whether to have a fixed no-object pointer when there is no object present. Default: False.
    soft_no_obj_ptr (bool): Whether to mix in no-object pointer softly for easier recovery and error mitigation. Default: False.
    use_mlp_for_obj_ptr_proj (bool): Whether to use MLP for object pointer projection. Default: False.
    no_obj_embed_spatial (bool): Whether to add a no-object embedding to spatial frames. Default: False.
    sam_mask_decoder_extra_args (dict | None): Extra arguments for constructing the SAM mask decoder. Default: None.
    compile_image_encoder (bool): Whether to compile the image encoder for faster inference. Default: False.

Examples:

>>> image_encoder = ImageEncoderViT(...)
>>> memory_attention = SAM2TwoWayTransformer(...)
>>> memory_encoder = nn.Sequential(...)
>>> model = SAM2Model(image_encoder, memory_attention, memory_encoder)
>>> image_batch = torch.rand(1, 3, 512, 512)
>>> features = model.forward_image(image_batch)
>>> track_results = model.track_step(0, True, features, None, None, None, {})
Source code in ultralytics/models/sam/modules/sam.py
def __init__(
    self,
    image_encoder,
    memory_attention,
    memory_encoder,
    num_maskmem=7,
    image_size=512,
    backbone_stride=16,
    sigmoid_scale_for_mem_enc=1.0,
    sigmoid_bias_for_mem_enc=0.0,
    binarize_mask_from_pts_for_mem_enc=False,
    use_mask_input_as_output_without_sam=False,
    max_cond_frames_in_attn=-1,
    directly_add_no_mem_embed=False,
    use_high_res_features_in_sam=False,
    multimask_output_in_sam=False,
    multimask_min_pt_num=1,
    multimask_max_pt_num=1,
    multimask_output_for_tracking=False,
    use_multimask_token_for_obj_ptr: bool = False,
    iou_prediction_use_sigmoid=False,
    memory_temporal_stride_for_eval=1,
    non_overlap_masks_for_mem_enc=False,
    use_obj_ptrs_in_encoder=False,
    max_obj_ptrs_in_encoder=16,
    add_tpos_enc_to_obj_ptrs=True,
    proj_tpos_enc_in_obj_ptrs=False,
    use_signed_tpos_enc_to_obj_ptrs=False,
    only_obj_ptrs_in_the_past_for_eval=False,
    pred_obj_scores: bool = False,
    pred_obj_scores_mlp: bool = False,
    fixed_no_obj_ptr: bool = False,
    soft_no_obj_ptr: bool = False,
    use_mlp_for_obj_ptr_proj: bool = False,
    no_obj_embed_spatial: bool = False,
    sam_mask_decoder_extra_args=None,
    compile_image_encoder: bool = False,
):
    """
    Initialize the SAM2Model for video object segmentation with memory-based tracking.

    Args:
        image_encoder (nn.Module): Visual encoder for extracting image features.
        memory_attention (nn.Module): Module for attending to memory features.
        memory_encoder (nn.Module): Encoder for generating memory representations.
        num_maskmem (int): Number of accessible memory frames.
        image_size (int): Size of input images.
        backbone_stride (int): Stride of the image backbone output.
        sigmoid_scale_for_mem_enc (float): Scale factor for mask sigmoid probability.
        sigmoid_bias_for_mem_enc (float): Bias factor for mask sigmoid probability.
        binarize_mask_from_pts_for_mem_enc (bool): Whether to binarize sigmoid mask logits on interacted frames
            with clicks during evaluation.
        use_mask_input_as_output_without_sam (bool): Whether to directly output the input mask without using SAM
            prompt encoder and mask decoder on frames with mask input.
        max_cond_frames_in_attn (int): Maximum number of conditioning frames to participate in memory attention.
        directly_add_no_mem_embed (bool): Whether to directly add no-memory embedding to image feature on the
            first frame.
        use_high_res_features_in_sam (bool): Whether to use high-resolution feature maps in the SAM mask decoder.
        multimask_output_in_sam (bool): Whether to output multiple masks for the first click on initial
            conditioning frames.
        multimask_min_pt_num (int): Minimum number of clicks to use multimask output in SAM.
        multimask_max_pt_num (int): Maximum number of clicks to use multimask output in SAM.
        multimask_output_for_tracking (bool): Whether to use multimask output for tracking.
        use_multimask_token_for_obj_ptr (bool): Whether to use multimask tokens for object pointers.
        iou_prediction_use_sigmoid (bool): Whether to use sigmoid to restrict IoU prediction to [0-1].
        memory_temporal_stride_for_eval (int): Memory bank's temporal stride during evaluation.
        non_overlap_masks_for_mem_enc (bool): Whether to apply non-overlapping constraints on object masks in
            memory encoder during evaluation.
        use_obj_ptrs_in_encoder (bool): Whether to cross-attend to object pointers from other frames in the encoder.
        max_obj_ptrs_in_encoder (int): Maximum number of object pointers from other frames in encoder
            cross-attention.
        add_tpos_enc_to_obj_ptrs (bool): Whether to add temporal positional encoding to object pointers in
            the encoder.
        proj_tpos_enc_in_obj_ptrs (bool): Whether to add an extra linear projection layer for temporal positional
            encoding in object pointers.
        use_signed_tpos_enc_to_obj_ptrs (bool): Whether to use signed distance in the temporal positional encoding
            in the object pointers.
        only_obj_ptrs_in_the_past_for_eval (bool): Whether to only attend to object pointers in the past
            during evaluation.
        pred_obj_scores (bool): Whether to predict if there is an object in the frame.
        pred_obj_scores_mlp (bool): Whether to use an MLP to predict object scores.
        fixed_no_obj_ptr (bool): Whether to have a fixed no-object pointer when there is no object present.
        soft_no_obj_ptr (bool): Whether to mix in no-object pointer softly for easier recovery and error mitigation.
        use_mlp_for_obj_ptr_proj (bool): Whether to use MLP for object pointer projection.
        no_obj_embed_spatial (bool): Whether to add a no-object embedding to spatial frames.
        sam_mask_decoder_extra_args (dict | None): Extra arguments for constructing the SAM mask decoder.
        compile_image_encoder (bool): Whether to compile the image encoder for faster inference.

    Examples:
        >>> image_encoder = ImageEncoderViT(...)
        >>> memory_attention = SAM2TwoWayTransformer(...)
        >>> memory_encoder = nn.Sequential(...)
        >>> model = SAM2Model(image_encoder, memory_attention, memory_encoder)
        >>> image_batch = torch.rand(1, 3, 512, 512)
        >>> features = model.forward_image(image_batch)
        >>> track_results = model.track_step(0, True, features, None, None, None, {})
    """
    super().__init__()

    # Part 1: the image backbone
    self.image_encoder = image_encoder
    # Use level 0, 1, 2 for high-res setting, or just level 2 for the default setting
    self.use_high_res_features_in_sam = use_high_res_features_in_sam
    self.num_feature_levels = 3 if use_high_res_features_in_sam else 1
    self.use_obj_ptrs_in_encoder = use_obj_ptrs_in_encoder
    self.max_obj_ptrs_in_encoder = max_obj_ptrs_in_encoder
    if use_obj_ptrs_in_encoder:
        # A conv layer to downsample the mask prompt to stride 4 (the same stride as
        # low-res SAM mask logits) and to change its scales from 0~1 to SAM logit scale,
        # so that it can be fed into the SAM mask decoder to generate a pointer.
        self.mask_downsample = torch.nn.Conv2d(1, 1, kernel_size=4, stride=4)
    self.add_tpos_enc_to_obj_ptrs = add_tpos_enc_to_obj_ptrs
    if proj_tpos_enc_in_obj_ptrs:
        assert add_tpos_enc_to_obj_ptrs  # these options need to be used together
    self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs
    self.use_signed_tpos_enc_to_obj_ptrs = use_signed_tpos_enc_to_obj_ptrs
    self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval

    # Part 2: memory attention to condition current frame's visual features
    # with memories (and obj ptrs) from past frames
    self.memory_attention = memory_attention
    self.hidden_dim = memory_attention.d_model

    # Part 3: memory encoder for the previous frame's outputs
    self.memory_encoder = memory_encoder
    self.mem_dim = self.hidden_dim
    if hasattr(self.memory_encoder, "out_proj") and hasattr(self.memory_encoder.out_proj, "weight"):
        # if there is compression of memories along channel dim
        self.mem_dim = self.memory_encoder.out_proj.weight.shape[0]
    self.num_maskmem = num_maskmem  # Number of memories accessible
    # Temporal encoding of the memories
    self.maskmem_tpos_enc = torch.nn.Parameter(torch.zeros(num_maskmem, 1, 1, self.mem_dim))
    trunc_normal_(self.maskmem_tpos_enc, std=0.02)
    # a single token to indicate no memory embedding from previous frames
    self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
    self.no_mem_pos_enc = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
    trunc_normal_(self.no_mem_embed, std=0.02)
    trunc_normal_(self.no_mem_pos_enc, std=0.02)
    self.directly_add_no_mem_embed = directly_add_no_mem_embed
    # Apply sigmoid to the output raw mask logits (to turn them from
    # range (-inf, +inf) to range (0, 1)) before feeding them into the memory encoder
    self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc
    self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc
    self.binarize_mask_from_pts_for_mem_enc = binarize_mask_from_pts_for_mem_enc
    self.non_overlap_masks_for_mem_enc = non_overlap_masks_for_mem_enc
    self.memory_temporal_stride_for_eval = memory_temporal_stride_for_eval
    # On frames with mask input, whether to directly output the input mask without
    # using a SAM prompt encoder + mask decoder
    self.use_mask_input_as_output_without_sam = use_mask_input_as_output_without_sam
    self.multimask_output_in_sam = multimask_output_in_sam
    self.multimask_min_pt_num = multimask_min_pt_num
    self.multimask_max_pt_num = multimask_max_pt_num
    self.multimask_output_for_tracking = multimask_output_for_tracking
    self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
    self.iou_prediction_use_sigmoid = iou_prediction_use_sigmoid

    # Part 4: SAM-style prompt encoder (for both mask and point inputs)
    # and SAM-style mask decoder for the final mask output
    self.image_size = image_size
    self.backbone_stride = backbone_stride
    self.sam_mask_decoder_extra_args = sam_mask_decoder_extra_args
    self.pred_obj_scores = pred_obj_scores
    self.pred_obj_scores_mlp = pred_obj_scores_mlp
    self.fixed_no_obj_ptr = fixed_no_obj_ptr
    self.soft_no_obj_ptr = soft_no_obj_ptr
    if self.fixed_no_obj_ptr:
        assert self.pred_obj_scores
        assert self.use_obj_ptrs_in_encoder
    if self.pred_obj_scores and self.use_obj_ptrs_in_encoder:
        self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim))
        trunc_normal_(self.no_obj_ptr, std=0.02)
    self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj
    self.no_obj_embed_spatial = None
    if no_obj_embed_spatial:
        self.no_obj_embed_spatial = torch.nn.Parameter(torch.zeros(1, self.mem_dim))
        trunc_normal_(self.no_obj_embed_spatial, std=0.02)

    self._build_sam_heads()
    self.max_cond_frames_in_attn = max_cond_frames_in_attn

    # Model compilation
    if compile_image_encoder:
        # Compile the forward function (not the full module) to allow loading checkpoints.
        LOGGER.info("Image encoder compilation is enabled. First forward pass will be slow.")
        self.image_encoder.forward = torch.compile(
            self.image_encoder.forward,
            mode="max-autotune",
            fullgraph=True,
            dynamic=False,
        )
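
The sigmoid_scale_for_mem_enc and sigmoid_bias_for_mem_enc values stored above rescale mask probabilities before they reach the memory encoder, as described in the inline comments. A minimal, illustrative sketch of that transform in isolation (the defaults shown match the constructor; the full memory-encoding path involves additional steps not reproduced here):

import torch

def mask_for_memory(mask_logits: torch.Tensor, scale: float = 1.0, bias: float = 0.0) -> torch.Tensor:
    """Map raw mask logits from (-inf, +inf) to (0, 1), then apply the scale and bias used for memory encoding."""
    return torch.sigmoid(mask_logits) * scale + bias

high_res_mask_logits = torch.randn(1, 1, 512, 512)  # hypothetical predicted mask logits
mem_input = mask_for_memory(high_res_mask_logits)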

device property

device

Return the device on which the model's parameters are stored.

forward

forward(*args, **kwargs)

Process image and prompt inputs to generate object masks and scores in video sequences.

Source code in ultralytics/models/sam/modules/sam.py
def forward(self, *args, **kwargs):
    """Process image and prompt inputs to generate object masks and scores in video sequences."""
    raise NotImplementedError(
        "Please use the corresponding methods in SAM2VideoPredictor for inference. "
        "See notebooks/video_predictor_example.ipynb for an example."
    )

forward_image

forward_image(img_batch: Tensor)

Process image batch through encoder to extract multi-level features for SAM model.

Source code in ultralytics/models/sam/modules/sam.py
def forward_image(self, img_batch: torch.Tensor):
    """Process image batch through encoder to extract multi-level features for SAM model."""
    backbone_out = self.image_encoder(img_batch)
    if self.use_high_res_features_in_sam:
        # precompute projected level 0 and level 1 features in SAM decoder
        # to avoid running it again on every SAM click
        backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(backbone_out["backbone_fpn"][0])
        backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(backbone_out["backbone_fpn"][1])
    return backbone_out
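
The returned backbone_out dict carries the multi-level feature maps under the "backbone_fpn" key; with use_high_res_features_in_sam enabled, levels 0 and 1 are already projected for the mask decoder as shown above. A short usage sketch (shapes depend on the encoder configuration and are illustrative):

>>> image_batch = torch.rand(1, 3, 512, 512)
>>> backbone_out = model.forward_image(image_batch)
>>> [f.shape for f in backbone_out["backbone_fpn"]]  # one feature map per level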

set_binarize

set_binarize(binarize=False)

Set binarize for VideoPredictor.

Source code in ultralytics/models/sam/modules/sam.py
def set_binarize(self, binarize=False):
    """Set binarize for VideoPredictor."""
    self.binarize_mask_from_pts_for_mem_enc = binarize

set_imgsz

set_imgsz(imgsz)

Set image size to make model compatible with different image sizes.

Source code in ultralytics/models/sam/modules/sam.py
def set_imgsz(self, imgsz):
    """Set image size to make model compatible with different image sizes."""
    self.image_size = imgsz[0]
    self.sam_prompt_encoder.input_image_size = imgsz
    self.sam_prompt_encoder.image_embedding_size = [x // 16 for x in imgsz]  # fixed ViT patch size of 16
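
As with SAMModel.set_imgsz, the prompt-encoder grid follows the fixed patch size of 16. For example, assuming model is an initialized SAM2Model:

>>> model.set_imgsz((1024, 1024))
>>> model.image_size, model.sam_prompt_encoder.image_embedding_size
(1024, [64, 64])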

track_step

track_step(
    frame_idx,
    is_init_cond_frame,
    current_vision_feats,
    current_vision_pos_embeds,
    feat_sizes,
    point_inputs,
    mask_inputs,
    output_dict,
    num_frames,
    track_in_reverse=False,
    run_mem_encoder=True,
    prev_sam_mask_logits=None,
)

Perform a single tracking step, updating object masks and memory features based on current frame inputs.

Source code in ultralytics/models/sam/modules/sam.py
def track_step(
    self,
    frame_idx,
    is_init_cond_frame,
    current_vision_feats,
    current_vision_pos_embeds,
    feat_sizes,
    point_inputs,
    mask_inputs,
    output_dict,
    num_frames,
    track_in_reverse=False,  # tracking in reverse time order (for demo usage)
    # Whether to run the memory encoder on the predicted masks. Sometimes we might want
    # to skip the memory encoder with `run_mem_encoder=False`. For example,
    # in demo we might call `track_step` multiple times for each user click,
    # and only encode the memory when the user finalizes their clicks. And in ablation
    # settings like SAM training on static images, we don't need the memory encoder.
    run_mem_encoder=True,
    # The previously predicted SAM mask logits (which can be fed together with new clicks in demo).
    prev_sam_mask_logits=None,
):
    """Perform a single tracking step, updating object masks and memory features based on current frame inputs."""
    current_out, sam_outputs, _, _ = self._track_step(
        frame_idx,
        is_init_cond_frame,
        current_vision_feats,
        current_vision_pos_embeds,
        feat_sizes,
        point_inputs,
        mask_inputs,
        output_dict,
        num_frames,
        track_in_reverse,
        prev_sam_mask_logits,
    )
    _, _, _, low_res_masks, high_res_masks, obj_ptr, object_score_logits = sam_outputs

    current_out["pred_masks"] = low_res_masks
    current_out["pred_masks_high_res"] = high_res_masks
    current_out["obj_ptr"] = obj_ptr
    if not self.training:
        # Only add this in inference (to avoid unused param in activation checkpointing;
        # it's mainly used in the demo to encode spatial memories w/ consolidated masks)
        current_out["object_score_logits"] = object_score_logits

    # Run memory encoder on the predicted mask to encode it into a new memory feature (for use in future frames)
    self._encode_memory_in_output(
        current_vision_feats,
        feat_sizes,
        point_inputs,
        run_mem_encoder,
        high_res_masks,
        object_score_logits,
        current_out,
    )

    return current_out
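
The returned current_out dict carries the per-frame predictions assigned above. A hedged sketch of how a caller might read it, with feats, pos_embeds, and feat_sizes standing in for features prepared from forward_image (argument values are illustrative):

>>> out = model.track_step(0, True, feats, pos_embeds, feat_sizes, None, None, {}, 1)
>>> out["pred_masks"].shape  # low-resolution mask logits
>>> out["pred_masks_high_res"].shape  # upsampled masks
>>> out["obj_ptr"].shape  # object pointer used for cross-frame attention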




