Skip to content

Reference for ultralytics/models/sam/modules/sam.py

Note

This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/modules/sam.py. If you spot a problem please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!


ultralytics.models.sam.modules.sam.SAMModel

SAMModel(
    image_encoder: ImageEncoderViT,
    prompt_encoder: PromptEncoder,
    mask_decoder: MaskDecoder,
    pixel_mean: List[float] = (123.675, 116.28, 103.53),
    pixel_std: List[float] = (58.395, 57.12, 57.375),
)

Bases: Module

Segment Anything Model (SAM) for object segmentation tasks.

This class combines image encoders, prompt encoders, and mask decoders to predict object masks from images and input prompts.

Attributes:

Name Type Description
mask_threshold float

Threshold value for mask prediction.

image_encoder ImageEncoderViT

Backbone for encoding images into embeddings.

prompt_encoder PromptEncoder

Encoder for various types of input prompts.

mask_decoder MaskDecoder

Predicts object masks from image and prompt embeddings.

pixel_mean Tensor

Mean pixel values for image normalization, shape (3, 1, 1).

pixel_std Tensor

Standard deviation values for image normalization, shape (3, 1, 1).

Methods:

Name Description

Examples:

>>> image_encoder = ImageEncoderViT(...)
>>> prompt_encoder = PromptEncoder(...)
>>> mask_decoder = MaskDecoder(...)
>>> sam_model = SAMModel(image_encoder, prompt_encoder, mask_decoder)
>>> # Further usage depends on SAMPredictor class
Notes

All forward() operations are implemented in the SAMPredictor class.

Parameters:

Name Type Description Default
image_encoder ImageEncoderViT

The backbone used to encode the image into image embeddings.

required
prompt_encoder PromptEncoder

Encodes various types of input prompts.

required
mask_decoder MaskDecoder

Predicts masks from the image embeddings and encoded prompts.

required
pixel_mean List[float]

Mean values for normalizing pixels in the input image.

(123.675, 116.28, 103.53)
pixel_std List[float]

Std values for normalizing pixels in the input image.

(58.395, 57.12, 57.375)

Examples:

>>> image_encoder = ImageEncoderViT(...)
>>> prompt_encoder = PromptEncoder(...)
>>> mask_decoder = MaskDecoder(...)
>>> sam_model = SAMModel(image_encoder, prompt_encoder, mask_decoder)
>>> # Further usage depends on SAMPredictor class
Notes

All forward() operations moved to SAMPredictor.

Source code in ultralytics/models/sam/modules/sam.py
def __init__(
    self,
    image_encoder: ImageEncoderViT,
    prompt_encoder: PromptEncoder,
    mask_decoder: MaskDecoder,
    pixel_mean: List[float] = (123.675, 116.28, 103.53),
    pixel_std: List[float] = (58.395, 57.12, 57.375),
) -> None:
    """
    Initialize the SAMModel class to predict object masks from an image and input prompts.

    The three sub-modules are stored as attributes in the order image encoder, prompt encoder, mask decoder. The
    per-channel normalization statistics are registered as non-persistent buffers shaped (C, 1, 1), so they follow
    the module across devices but are not written to its state_dict.

    Args:
        image_encoder (ImageEncoderViT): The backbone used to encode the image into image embeddings.
        prompt_encoder (PromptEncoder): Encodes various types of input prompts.
        mask_decoder (MaskDecoder): Predicts masks from the image embeddings and encoded prompts.
        pixel_mean (List[float]): Mean values for normalizing pixels in the input image.
        pixel_std (List[float]): Std values for normalizing pixels in the input image.

    Examples:
        >>> image_encoder = ImageEncoderViT(...)
        >>> prompt_encoder = PromptEncoder(...)
        >>> mask_decoder = MaskDecoder(...)
        >>> sam_model = SAMModel(image_encoder, prompt_encoder, mask_decoder)
        >>> # Further usage depends on SAMPredictor class

    Notes:
        All forward() operations moved to SAMPredictor.
    """
    super().__init__()
    # Tuple assignment runs left to right, so sub-modules are registered in the same order as before.
    self.image_encoder, self.prompt_encoder, self.mask_decoder = image_encoder, prompt_encoder, mask_decoder

    # Reshape each statistics vector to (C, 1, 1) so it broadcasts across an image's spatial dimensions.
    for buffer_name, stats in (("pixel_mean", pixel_mean), ("pixel_std", pixel_std)):
        self.register_buffer(buffer_name, torch.Tensor(stats).view(-1, 1, 1), persistent=False)

set_imgsz

set_imgsz(imgsz)

Set image size to make model compatible with different image sizes.

Parameters:

Name Type Description Default
imgsz Tuple[int, int]

The size of the input image.

required
Source code in ultralytics/models/sam/modules/sam.py
def set_imgsz(self, imgsz):
    """
    Set image size to make model compatible with different image sizes.

    Propagates the new size to the image encoder (through its own ``set_imgsz`` when it defines one, and through
    its ``img_size`` attribute). It also updates the prompt encoder's expected input size and its image-embedding
    grid, which is the input size divided by the ViT patch size of 16.

    Args:
        imgsz (int | Tuple[int, int]): The size of the input image. A single int is treated as the square size
            ``(imgsz, imgsz)``.
    """
    if isinstance(imgsz, int):
        # A bare int used to fail in the comprehension below, after the encoders had already been partially
        # updated; normalize it up front so the model is never left half-configured.
        imgsz = (imgsz, imgsz)
    if hasattr(self.image_encoder, "set_imgsz"):
        self.image_encoder.set_imgsz(imgsz)
    self.prompt_encoder.input_image_size = imgsz
    self.prompt_encoder.image_embedding_size = [x // 16 for x in imgsz]  # 16 is fixed as patch size of ViT model
    self.image_encoder.img_size = imgsz[0]





ultralytics.models.sam.modules.sam.SAM2Model

SAM2Model(
    image_encoder,
    memory_attention,
    memory_encoder,
    num_maskmem=7,
    image_size=512,
    backbone_stride=16,
    sigmoid_scale_for_mem_enc=1.0,
    sigmoid_bias_for_mem_enc=0.0,
    binarize_mask_from_pts_for_mem_enc=False,
    use_mask_input_as_output_without_sam=False,
    max_cond_frames_in_attn=-1,
    directly_add_no_mem_embed=False,
    use_high_res_features_in_sam=False,
    multimask_output_in_sam=False,
    multimask_min_pt_num=1,
    multimask_max_pt_num=1,
    multimask_output_for_tracking=False,
    use_multimask_token_for_obj_ptr: bool = False,
    iou_prediction_use_sigmoid=False,
    memory_temporal_stride_for_eval=1,
    add_all_frames_to_correct_as_cond=False,
    non_overlap_masks_for_mem_enc=False,
    use_obj_ptrs_in_encoder=False,
    max_obj_ptrs_in_encoder=16,
    add_tpos_enc_to_obj_ptrs=True,
    proj_tpos_enc_in_obj_ptrs=False,
    only_obj_ptrs_in_the_past_for_eval=False,
    pred_obj_scores: bool = False,
    pred_obj_scores_mlp: bool = False,
    fixed_no_obj_ptr: bool = False,
    soft_no_obj_ptr: bool = False,
    use_mlp_for_obj_ptr_proj: bool = False,
    sam_mask_decoder_extra_args=None,
    compile_image_encoder: bool = False,
)

Bases: Module

SAM2Model class for Segment Anything Model 2 with memory-based video object segmentation capabilities.

This class extends the functionality of SAM to handle video sequences, incorporating memory mechanisms for temporal consistency and efficient tracking of objects across frames.

Attributes:

Name Type Description
mask_threshold float

Threshold value for mask prediction.

image_encoder ImageEncoderViT

Visual encoder for extracting image features.

memory_attention Module

Module for attending to memory features.

memory_encoder Module

Encoder for generating memory representations.

num_maskmem int

Number of accessible memory frames.

image_size int

Size of input images.

backbone_stride int

Stride of the backbone network output.

sam_prompt_embed_dim int

Dimension of SAM prompt embeddings.

sam_image_embedding_size int

Size of SAM image embeddings.

sam_prompt_encoder PromptEncoder

Encoder for processing input prompts.

sam_mask_decoder SAM2MaskDecoder

Decoder for generating object masks.

obj_ptr_proj Module

Projection layer for object pointers.

obj_ptr_tpos_proj Module

Projection for temporal positional encoding in object pointers.

Methods:

Name Description
forward_image

Processes image batch through encoder to extract multi-level features.

track_step

Performs a single tracking step, updating object masks and memory features.

Examples:

>>> model = SAM2Model(image_encoder, memory_attention, memory_encoder)
>>> image_batch = torch.rand(1, 3, 512, 512)
>>> features = model.forward_image(image_batch)
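>>> # track_step needs flattened (HW)BC per-level features, their positional embeddings and feature sizes,
>>> # prepared from `features` (shown here as vision_feats, vision_pos_embeds, feat_sizes), plus num_frames
>>> track_results = model.track_step(0, True, vision_feats, vision_pos_embeds, feat_sizes, None, None, {}, 1)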

Parameters:

Name Type Description Default
image_encoder Module

Visual encoder for extracting image features.

required
memory_attention Module

Module for attending to memory features.

required
memory_encoder Module

Encoder for generating memory representations.

required
num_maskmem int

Number of accessible memory frames. Default is 7 (1 input frame + 6 previous frames).

7
image_size int

Size of input images.

512
backbone_stride int

Stride of the image backbone output.

16
sigmoid_scale_for_mem_enc float

Scale factor for mask sigmoid probability.

1.0
sigmoid_bias_for_mem_enc float

Bias factor for mask sigmoid probability.

0.0
binarize_mask_from_pts_for_mem_enc bool

Whether to binarize sigmoid mask logits on interacted frames with clicks during evaluation.

False
use_mask_input_as_output_without_sam bool

Whether to directly output the input mask without using SAM prompt encoder and mask decoder on frames with mask input.

False
max_cond_frames_in_attn int

Maximum number of conditioning frames to participate in memory attention. -1 means no limit.

-1
directly_add_no_mem_embed bool

Whether to directly add no-memory embedding to image feature on the first frame.

False
use_high_res_features_in_sam bool

Whether to use high-resolution feature maps in the SAM mask decoder.

False
multimask_output_in_sam bool

Whether to output multiple (3) masks for the first click on initial conditioning frames.

False
multimask_min_pt_num int

Minimum number of clicks to use multimask output in SAM.

1
multimask_max_pt_num int

Maximum number of clicks to use multimask output in SAM.

1
multimask_output_for_tracking bool

Whether to use multimask output for tracking.

False
use_multimask_token_for_obj_ptr bool

Whether to use multimask tokens for object pointers.

False
iou_prediction_use_sigmoid bool

Whether to use a sigmoid to restrict IoU predictions to the range [0, 1].

False
memory_temporal_stride_for_eval int

Memory bank's temporal stride during evaluation.

1
add_all_frames_to_correct_as_cond bool

Whether to append frames with correction clicks to conditioning frame list.

False
non_overlap_masks_for_mem_enc bool

Whether to apply non-overlapping constraints on object masks in memory encoder during evaluation.

False
use_obj_ptrs_in_encoder bool

Whether to cross-attend to object pointers from other frames in the encoder.

False
max_obj_ptrs_in_encoder int

Maximum number of object pointers from other frames in encoder cross-attention.

16
add_tpos_enc_to_obj_ptrs bool

Whether to add temporal positional encoding to object pointers in the encoder.

True
proj_tpos_enc_in_obj_ptrs bool

Whether to add an extra linear projection layer for temporal positional encoding in object pointers.

False
only_obj_ptrs_in_the_past_for_eval bool

Whether to only attend to object pointers in the past during evaluation.

False
pred_obj_scores bool

Whether to predict if there is an object in the frame.

False
pred_obj_scores_mlp bool

Whether to use an MLP to predict object scores.

False
fixed_no_obj_ptr bool

Whether to have a fixed no-object pointer when there is no object present.

False
soft_no_obj_ptr bool

Whether to mix in no-object pointer softly for easier recovery and error mitigation.

False
use_mlp_for_obj_ptr_proj bool

Whether to use MLP for object pointer projection.

False
sam_mask_decoder_extra_args Dict | None

Extra arguments for constructing the SAM mask decoder.

None
compile_image_encoder bool

Whether to compile the image encoder for faster inference.

False

Examples:

>>> image_encoder = ImageEncoderViT(...)
>>> memory_attention = SAM2TwoWayTransformer(...)
>>> memory_encoder = nn.Sequential(...)
>>> model = SAM2Model(image_encoder, memory_attention, memory_encoder)
>>> image_batch = torch.rand(1, 3, 512, 512)
>>> features = model.forward_image(image_batch)
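>>> # track_step needs flattened (HW)BC per-level features, their positional embeddings and feature sizes,
>>> # prepared from `features` (shown here as vision_feats, vision_pos_embeds, feat_sizes), plus num_frames
>>> track_results = model.track_step(0, True, vision_feats, vision_pos_embeds, feat_sizes, None, None, {}, 1)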
Source code in ultralytics/models/sam/modules/sam.py
def __init__(
    self,
    image_encoder,
    memory_attention,
    memory_encoder,
    num_maskmem=7,
    image_size=512,
    backbone_stride=16,
    sigmoid_scale_for_mem_enc=1.0,
    sigmoid_bias_for_mem_enc=0.0,
    binarize_mask_from_pts_for_mem_enc=False,
    use_mask_input_as_output_without_sam=False,
    max_cond_frames_in_attn=-1,
    directly_add_no_mem_embed=False,
    use_high_res_features_in_sam=False,
    multimask_output_in_sam=False,
    multimask_min_pt_num=1,
    multimask_max_pt_num=1,
    multimask_output_for_tracking=False,
    use_multimask_token_for_obj_ptr: bool = False,
    iou_prediction_use_sigmoid=False,
    memory_temporal_stride_for_eval=1,
    add_all_frames_to_correct_as_cond=False,
    non_overlap_masks_for_mem_enc=False,
    use_obj_ptrs_in_encoder=False,
    max_obj_ptrs_in_encoder=16,
    add_tpos_enc_to_obj_ptrs=True,
    proj_tpos_enc_in_obj_ptrs=False,
    only_obj_ptrs_in_the_past_for_eval=False,
    pred_obj_scores: bool = False,
    pred_obj_scores_mlp: bool = False,
    fixed_no_obj_ptr: bool = False,
    soft_no_obj_ptr: bool = False,
    use_mlp_for_obj_ptr_proj: bool = False,
    sam_mask_decoder_extra_args=None,
    compile_image_encoder: bool = False,
):
    """
    Initializes the SAM2Model for video object segmentation with memory-based tracking.

    Args:
        image_encoder (nn.Module): Visual encoder for extracting image features.
        memory_attention (nn.Module): Module for attending to memory features. Must expose a `d_model` attribute,
            which becomes the model's hidden dimension.
        memory_encoder (nn.Module): Encoder for generating memory representations. If it has an `out_proj` with a
            `weight`, the first dimension of that weight sets the memory channel dimension.
        num_maskmem (int): Number of accessible memory frames. Default is 7 (1 input frame + 6 previous frames).
        image_size (int): Size of input images.
        backbone_stride (int): Stride of the image backbone output.
        sigmoid_scale_for_mem_enc (float): Scale factor for mask sigmoid probability.
        sigmoid_bias_for_mem_enc (float): Bias factor for mask sigmoid probability.
        binarize_mask_from_pts_for_mem_enc (bool): Whether to binarize sigmoid mask logits on interacted frames
            with clicks during evaluation.
        use_mask_input_as_output_without_sam (bool): Whether to directly output the input mask without using SAM
            prompt encoder and mask decoder on frames with mask input.
        max_cond_frames_in_attn (int): Maximum number of conditioning frames to participate in memory attention.
            -1 means no limit.
        directly_add_no_mem_embed (bool): Whether to directly add no-memory embedding to image feature on the
            first frame.
        use_high_res_features_in_sam (bool): Whether to use high-resolution feature maps in the SAM mask decoder.
        multimask_output_in_sam (bool): Whether to output multiple (3) masks for the first click on initial
            conditioning frames.
        multimask_min_pt_num (int): Minimum number of clicks to use multimask output in SAM.
        multimask_max_pt_num (int): Maximum number of clicks to use multimask output in SAM.
        multimask_output_for_tracking (bool): Whether to use multimask output for tracking.
        use_multimask_token_for_obj_ptr (bool): Whether to use multimask tokens for object pointers.
        iou_prediction_use_sigmoid (bool): Whether to use a sigmoid to restrict IoU predictions to [0, 1].
        memory_temporal_stride_for_eval (int): Memory bank's temporal stride during evaluation.
        add_all_frames_to_correct_as_cond (bool): Whether to append frames with correction clicks to conditioning
            frame list.
        non_overlap_masks_for_mem_enc (bool): Whether to apply non-overlapping constraints on object masks in
            memory encoder during evaluation.
        use_obj_ptrs_in_encoder (bool): Whether to cross-attend to object pointers from other frames in the encoder.
        max_obj_ptrs_in_encoder (int): Maximum number of object pointers from other frames in encoder
            cross-attention.
        add_tpos_enc_to_obj_ptrs (bool): Whether to add temporal positional encoding to object pointers in
            the encoder.
        proj_tpos_enc_in_obj_ptrs (bool): Whether to add an extra linear projection layer for temporal positional
            encoding in object pointers. Requires `add_tpos_enc_to_obj_ptrs`.
        only_obj_ptrs_in_the_past_for_eval (bool): Whether to only attend to object pointers in the past
            during evaluation.
        pred_obj_scores (bool): Whether to predict if there is an object in the frame.
        pred_obj_scores_mlp (bool): Whether to use an MLP to predict object scores.
        fixed_no_obj_ptr (bool): Whether to have a fixed no-object pointer when there is no object present.
            Requires both `pred_obj_scores` and `use_obj_ptrs_in_encoder`.
        soft_no_obj_ptr (bool): Whether to mix in no-object pointer softly for easier recovery and error mitigation.
        use_mlp_for_obj_ptr_proj (bool): Whether to use MLP for object pointer projection.
        sam_mask_decoder_extra_args (Dict | None): Extra arguments for constructing the SAM mask decoder.
        compile_image_encoder (bool): Whether to compile the image encoder's forward with `torch.compile`
            (mode="max-autotune", fullgraph=True, dynamic=False). The first forward pass is slow.

    Raises:
        AssertionError: If `proj_tpos_enc_in_obj_ptrs` is set without `add_tpos_enc_to_obj_ptrs`, or if
            `fixed_no_obj_ptr` is set without both `pred_obj_scores` and `use_obj_ptrs_in_encoder`.

    Examples:
        >>> image_encoder = ImageEncoderViT(...)
        >>> memory_attention = SAM2TwoWayTransformer(...)
        >>> memory_encoder = nn.Sequential(...)
        >>> model = SAM2Model(image_encoder, memory_attention, memory_encoder)
        >>> image_batch = torch.rand(1, 3, 512, 512)
        >>> features = model.forward_image(image_batch)
        >>> # track_step needs flattened (HW)BC per-level features, their positional embeddings and feature sizes,
        >>> # prepared from `features` (shown here as vision_feats, vision_pos_embeds, feat_sizes), plus num_frames
        >>> track_results = model.track_step(0, True, vision_feats, vision_pos_embeds, feat_sizes, None, None, {}, 1)
    """
    super().__init__()

    # Part 1: the image backbone
    self.image_encoder = image_encoder
    # Use level 0, 1, 2 for high-res setting, or just level 2 for the default setting
    self.use_high_res_features_in_sam = use_high_res_features_in_sam
    self.num_feature_levels = 3 if use_high_res_features_in_sam else 1
    self.use_obj_ptrs_in_encoder = use_obj_ptrs_in_encoder
    self.max_obj_ptrs_in_encoder = max_obj_ptrs_in_encoder
    if use_obj_ptrs_in_encoder:
        # A conv layer to downsample the mask prompt to stride 4 (the same stride as
        # low-res SAM mask logits) and to change its scales from 0~1 to SAM logit scale,
        # so that it can be fed into the SAM mask decoder to generate a pointer.
        self.mask_downsample = torch.nn.Conv2d(1, 1, kernel_size=4, stride=4)
    self.add_tpos_enc_to_obj_ptrs = add_tpos_enc_to_obj_ptrs
    if proj_tpos_enc_in_obj_ptrs:
        # NOTE(review): config validation via `assert` is skipped under `python -O`.
        assert add_tpos_enc_to_obj_ptrs  # these options need to be used together
    self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs
    self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval

    # Part 2: memory attention to condition current frame's visual features
    # with memories (and obj ptrs) from past frames
    self.memory_attention = memory_attention
    # The attention module's model width becomes the hidden dim used by the no-memory and no-object tokens below.
    self.hidden_dim = memory_attention.d_model

    # Part 3: memory encoder for the previous frame's outputs
    self.memory_encoder = memory_encoder
    # Memory channels default to the hidden dim unless the encoder projects to a different width.
    self.mem_dim = self.hidden_dim
    if hasattr(self.memory_encoder, "out_proj") and hasattr(self.memory_encoder.out_proj, "weight"):
        # if there is compression of memories along channel dim
        self.mem_dim = self.memory_encoder.out_proj.weight.shape[0]
    self.num_maskmem = num_maskmem  # Number of memories accessible
    # NOTE(review): the learnable tensors below are initialized with trunc_normal_ (imported elsewhere), presumably
    # drawing from the global RNG, so their creation order affects initial values; keep the order unchanged.
    # Temporal encoding of the memories
    self.maskmem_tpos_enc = torch.nn.Parameter(torch.zeros(num_maskmem, 1, 1, self.mem_dim))
    trunc_normal_(self.maskmem_tpos_enc, std=0.02)
    # a single token to indicate no memory embedding from previous frames
    self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
    self.no_mem_pos_enc = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
    trunc_normal_(self.no_mem_embed, std=0.02)
    trunc_normal_(self.no_mem_pos_enc, std=0.02)
    self.directly_add_no_mem_embed = directly_add_no_mem_embed
    # Apply sigmoid to the output raw mask logits (to turn them from
    # range (-inf, +inf) to range (0, 1)) before feeding them into the memory encoder
    self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc
    self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc
    self.binarize_mask_from_pts_for_mem_enc = binarize_mask_from_pts_for_mem_enc
    self.non_overlap_masks_for_mem_enc = non_overlap_masks_for_mem_enc
    self.memory_temporal_stride_for_eval = memory_temporal_stride_for_eval
    # On frames with mask input, whether to directly output the input mask without
    # using a SAM prompt encoder + mask decoder
    self.use_mask_input_as_output_without_sam = use_mask_input_as_output_without_sam
    self.multimask_output_in_sam = multimask_output_in_sam
    self.multimask_min_pt_num = multimask_min_pt_num
    self.multimask_max_pt_num = multimask_max_pt_num
    self.multimask_output_for_tracking = multimask_output_for_tracking
    self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
    self.iou_prediction_use_sigmoid = iou_prediction_use_sigmoid

    # Part 4: SAM-style prompt encoder (for both mask and point inputs)
    # and SAM-style mask decoder for the final mask output
    self.image_size = image_size
    self.backbone_stride = backbone_stride
    self.sam_mask_decoder_extra_args = sam_mask_decoder_extra_args
    self.pred_obj_scores = pred_obj_scores
    self.pred_obj_scores_mlp = pred_obj_scores_mlp
    self.fixed_no_obj_ptr = fixed_no_obj_ptr
    self.soft_no_obj_ptr = soft_no_obj_ptr
    if self.fixed_no_obj_ptr:
        # A fixed no-object pointer only makes sense when object scores are predicted and pointers are used.
        assert self.pred_obj_scores
        assert self.use_obj_ptrs_in_encoder
    if self.pred_obj_scores and self.use_obj_ptrs_in_encoder:
        # Learnable pointer used in place of an object pointer when no object is present.
        self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim))
        trunc_normal_(self.no_obj_ptr, std=0.02)
    self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj

    # Builds the SAM prompt encoder / mask decoder (defined elsewhere); presumably reads the attributes set
    # above, so it must stay after them.
    self._build_sam_heads()
    self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond
    self.max_cond_frames_in_attn = max_cond_frames_in_attn

    # Model compilation
    if compile_image_encoder:
        # Compile the forward function (not the full module) to allow loading checkpoints.
        # NOTE(review): uses print rather than a logger.
        print("Image encoder compilation is enabled. First forward pass will be slow.")
        self.image_encoder.forward = torch.compile(
            self.image_encoder.forward,
            mode="max-autotune",
            fullgraph=True,
            dynamic=False,
        )

device property

device

Returns the device on which the model's parameters are stored.

forward

forward(*args, **kwargs)

Not implemented for SAM2Model: it always raises NotImplementedError. Use the corresponding methods in SAM2VideoPredictor for inference.

Source code in ultralytics/models/sam/modules/sam.py
def forward(self, *args, **kwargs):
    """
    Not supported: SAM2Model has no single-call forward pass.

    Use the corresponding methods in SAM2VideoPredictor for inference instead.

    Args:
        *args (Any): Ignored.
        **kwargs (Any): Ignored.

    Raises:
        NotImplementedError: Always.
    """
    # The two literals are implicitly concatenated, so the first needs a trailing space between the sentences.
    raise NotImplementedError(
        "Please use the corresponding methods in SAM2VideoPredictor for inference. "
        "See notebooks/video_predictor_example.ipynb for an example."
    )

forward_image

forward_image(img_batch: torch.Tensor)

Processes image batch through encoder to extract multi-level features for SAM model.

Source code in ultralytics/models/sam/modules/sam.py
def forward_image(self, img_batch: torch.Tensor):
    """
    Processes image batch through encoder to extract multi-level features for SAM model.

    In the high-res setting, FPN levels 0 and 1 are projected in place through the mask decoder's
    ``conv_s0``/``conv_s1`` here. This is done once per image, so the projection is not repeated on every click.

    Args:
        img_batch (torch.Tensor): Batch of input images passed to the image encoder.

    Returns:
        (Dict): The image encoder's output dict; its "backbone_fpn" list is modified in place when high-res features
            are enabled.
    """
    features = self.image_encoder(img_batch)
    if not self.use_high_res_features_in_sam:
        return features

    fpn_levels = features["backbone_fpn"]
    decoder = self.sam_mask_decoder
    fpn_levels[0] = decoder.conv_s0(fpn_levels[0])
    fpn_levels[1] = decoder.conv_s1(fpn_levels[1])
    return features

set_imgsz

set_imgsz(imgsz)

Set image size to make model compatible with different image sizes.

Parameters:

Name Type Description Default
imgsz Tuple[int, int]

The size of the input image.

required
Source code in ultralytics/models/sam/modules/sam.py
def set_imgsz(self, imgsz):
    """
    Set image size to make model compatible with different image sizes.

    Updates the model's ``image_size`` and the SAM prompt encoder's expected input size and image-embedding grid.
    The grid is the input size divided by the backbone output stride. The model-level ``sam_image_embedding_size``
    is refreshed as well, so it stays consistent with the new resolution.

    Args:
        imgsz (int | Tuple[int, int]): The size of the input image. A single int is treated as the square size
            ``(imgsz, imgsz)``.
    """
    if isinstance(imgsz, int):
        # A bare int used to fail after `image_size` had already been overwritten; normalize it first.
        imgsz = (imgsz, imgsz)
    self.image_size = imgsz[0]
    self.sam_prompt_encoder.input_image_size = imgsz
    # Divide by the configured backbone stride (16 by default) rather than a hard-coded 16, so a model built with
    # a different stride gets the correct embedding grid.
    self.sam_prompt_encoder.image_embedding_size = [x // self.backbone_stride for x in imgsz]
    # Keep the model-level (square) embedding size in sync with the new image size.
    self.sam_image_embedding_size = self.image_size // self.backbone_stride

track_step

track_step(
    frame_idx,
    is_init_cond_frame,
    current_vision_feats,
    current_vision_pos_embeds,
    feat_sizes,
    point_inputs,
    mask_inputs,
    output_dict,
    num_frames,
    track_in_reverse=False,
    run_mem_encoder=True,
    prev_sam_mask_logits=None,
)

Performs a single tracking step, updating object masks and memory features based on current frame inputs.

Source code in ultralytics/models/sam/modules/sam.py
def track_step(
    self,
    frame_idx,
    is_init_cond_frame,
    current_vision_feats,
    current_vision_pos_embeds,
    feat_sizes,
    point_inputs,
    mask_inputs,
    output_dict,
    num_frames,
    track_in_reverse=False,
    run_mem_encoder=True,
    prev_sam_mask_logits=None,
):
    """
    Performs a single tracking step, updating object masks and memory features based on current frame inputs.

    Args:
        frame_idx (int): Index of the frame being processed.
        is_init_cond_frame (bool): Whether this frame is an initial conditioning frame.
        current_vision_feats (List[torch.Tensor]): Per-level vision features in (HW)BC layout. Every level except
            the last is reshaped to BCHW and used as a high-resolution feature map for the SAM head.
        current_vision_pos_embeds (List[torch.Tensor]): Positional embeddings matching ``current_vision_feats``.
        feat_sizes (List[Tuple[int, int]]): Spatial (H, W) size of each feature level.
        point_inputs: Point prompts for this frame, or None.
        mask_inputs: Mask prompt for this frame, or None.
        output_dict (Dict): Outputs of previously processed frames, forwarded to memory conditioning.
        num_frames (int): Number of frames in the sequence; forwarded to memory conditioning.
        track_in_reverse (bool): Tracking in reverse time order (for demo usage).
        run_mem_encoder (bool): Whether to run the memory encoder on the predicted masks. It can be disabled when,
            e.g., a demo calls this once per user click and only encodes memory after the clicks are finalized. It
            can also be disabled when no memory encoder is needed, as in SAM training on static images.
        prev_sam_mask_logits: Previously predicted SAM mask logits that can be fed together with new clicks (demo
            usage). Requires ``point_inputs`` and no ``mask_inputs``.

    Returns:
        (Dict): Frame output with keys "point_inputs", "mask_inputs", "pred_masks" (low-res masks),
            "pred_masks_high_res", "obj_ptr", "maskmem_features" and "maskmem_pos_enc". The last two are None when
            the memory encoder is not run.
    """
    frame_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}

    # All levels except the last feed the SAM head as high-res maps; reshape each from (HW)BC to BCHW.
    high_res_features = None
    if len(current_vision_feats) > 1:
        high_res_features = []
        for level_feat, level_size in zip(current_vision_feats[:-1], feat_sizes[:-1]):
            batch, channels = level_feat.size(1), level_feat.size(2)
            high_res_features.append(level_feat.permute(1, 2, 0).view(batch, channels, *level_size))

    if mask_inputs is not None and self.use_mask_input_as_output_without_sam:
        # Treat the mask input as ground truth and output it directly, bypassing the SAM prompt encoder + decoder.
        lowest_res_feat = current_vision_feats[-1].permute(1, 2, 0).view(-1, self.hidden_dim, *feat_sizes[-1])
        sam_outputs = self._use_mask_as_output(lowest_res_feat, high_res_features, mask_inputs)
    else:
        # Fuse the last (lowest-res) feature level with previous memory features in the memory bank.
        pix_feat_with_mem = self._prepare_memory_conditioned_features(
            frame_idx=frame_idx,
            is_init_cond_frame=is_init_cond_frame,
            current_vision_feats=current_vision_feats[-1:],
            current_vision_pos_embeds=current_vision_pos_embeds[-1:],
            feat_sizes=feat_sizes[-1:],
            output_dict=output_dict,
            num_frames=num_frames,
            track_in_reverse=track_in_reverse,
        )
        # Earlier low-res SAM logits (e.g. from prior demo interaction) stand in for the mask prompt; they are only
        # allowed together with point inputs and without an explicit mask input.
        if prev_sam_mask_logits is not None:
            assert point_inputs is not None and mask_inputs is None
            mask_inputs = prev_sam_mask_logits
        want_multimask = self._use_multimask(is_init_cond_frame, point_inputs)
        sam_outputs = self._forward_sam_heads(
            backbone_features=pix_feat_with_mem,
            point_inputs=point_inputs,
            mask_inputs=mask_inputs,
            high_res_features=high_res_features,
            multimask_output=want_multimask,
        )

    # The SAM heads return a 7-tuple; only the low/high-res masks and the object pointer are kept here.
    _, _, _, low_res_masks, high_res_masks, obj_ptr, _ = sam_outputs
    frame_out["pred_masks"] = low_res_masks
    frame_out["pred_masks_high_res"] = high_res_masks
    frame_out["obj_ptr"] = obj_ptr

    if not (run_mem_encoder and self.num_maskmem > 0):
        frame_out["maskmem_features"] = None
        frame_out["maskmem_pos_enc"] = None
        return frame_out

    # Encode the predicted mask into a new memory feature usable by future frames.
    frame_out["maskmem_features"], frame_out["maskmem_pos_enc"] = self._encode_new_memory(
        current_vision_feats=current_vision_feats,
        feat_sizes=feat_sizes,
        pred_masks_high_res=high_res_masks,
        is_mask_from_pts=(point_inputs is not None),
    )
    return frame_out




📅 Created 10 months ago ✏️ Updated 1 month ago