Skip to content

Reference for ultralytics/models/sam/build_sam3.py

Improvements

This page is sourced from https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/build_sam3.py. Have an improvement or example to add? Open a Pull Request — thank you! 🙏


function ultralytics.models.sam.build_sam3._create_vision_backbone

def _create_vision_backbone(compile_mode = None, enable_inst_interactivity = True) -> Sam3DualViTDetNeck

Create SAM3 visual backbone with ViT and neck.

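Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `compile_mode` | | Compile mode forwarded to the ViT trunk | `None` |
| `enable_inst_interactivity` | | Passed to the neck as `add_sam2_neck` | `True` |

Source code in `ultralytics/models/sam/build_sam3.py` · View on GitHub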
def _create_vision_backbone(compile_mode=None, enable_inst_interactivity=True) -> Sam3DualViTDetNeck:
    """Create SAM3 visual backbone with ViT and neck.

    Builds a fixed-configuration ``ViT`` trunk and wraps it, together with a sine position encoding, in a
    ``Sam3DualViTDetNeck``.

    Args:
        compile_mode: Forwarded unchanged to the ``ViT`` trunk as its ``compile_mode`` argument.
        enable_inst_interactivity (bool): Forwarded to the neck as ``add_sam2_neck``.

    Returns:
        (Sam3DualViTDetNeck): Neck module holding the ViT trunk.
    """
    # Sine position encoding handed to the neck; built first to keep the original construction order.
    sine_pos = PositionEmbeddingSine(num_pos_feats=256, normalize=True, scale=None, temperature=10000)

    # Fixed trunk hyper-parameters; only ``compile_mode`` varies between callers.
    trunk_kwargs = dict(
        img_size=1008,
        pretrain_img_size=336,
        patch_size=14,
        embed_dim=1024,
        depth=32,
        num_heads=16,
        mlp_ratio=4.625,
        norm_layer="LayerNorm",
        drop_path_rate=0.1,
        qkv_bias=True,
        use_abs_pos=True,
        tile_abs_pos=True,
        global_att_blocks=(7, 15, 23, 31),
        rel_pos_blocks=(),
        use_rope=True,
        use_interp_rope=True,
        window_size=24,
        pretrain_use_cls_token=True,
        retain_cls_token=False,
        ln_pre=True,
        ln_post=False,
        return_interm_layers=False,
        bias_patch_embed=False,
    )
    trunk = ViT(compile_mode=compile_mode, **trunk_kwargs)

    neck = Sam3DualViTDetNeck(
        trunk=trunk,
        position_encoding=sine_pos,
        d_model=256,
        scale_factors=[4.0, 2.0, 1.0, 0.5],
        add_sam2_neck=enable_inst_interactivity,
    )
    return neck





function ultralytics.models.sam.build_sam3._create_sam3_transformer

def _create_sam3_transformer() -> TransformerWrapper

Create SAM3 detector encoder and decoder.

Source code in `ultralytics/models/sam/build_sam3.py` · View on GitHub
def _create_sam3_transformer() -> TransformerWrapper:
    """Create SAM3 detector encoder and decoder.

    Encoder and decoder both use width 256, 8 attention heads, a 2048-wide feed-forward, dropout 0.1 and 6 layers.
    The pair is returned inside a ``TransformerWrapper``.

    Returns:
        (TransformerWrapper): Wrapper holding the fusion encoder and the decoder.
    """
    width, heads, ffn_dim, drop = 256, 8, 2048, 0.1

    # Encoder template layer: both attentions are batch-first.
    # Sub-modules are built in the same order as nested-call evaluation would build them.
    enc_self_attn = nn.MultiheadAttention(num_heads=heads, dropout=drop, embed_dim=width, batch_first=True)
    enc_cross_attn = nn.MultiheadAttention(num_heads=heads, dropout=drop, embed_dim=width, batch_first=True)
    enc_layer = TransformerEncoderLayer(
        d_model=width,
        dim_feedforward=ffn_dim,
        dropout=drop,
        pos_enc_at_attn=True,
        pos_enc_at_cross_attn_keys=False,
        pos_enc_at_cross_attn_queries=False,
        pre_norm=True,
        self_attention=enc_self_attn,
        cross_attention=enc_cross_attn,
    )
    encoder = TransformerEncoderFusion(
        layer=enc_layer,
        num_layers=6,
        d_model=width,
        num_feature_levels=1,
        frozen=False,
        use_act_checkpoint=True,
        add_pooled_text_to_img_feat=False,
        pool_text_with_mask=True,
    )

    # Decoder template layer: its cross-attention does not set batch_first (nn.MultiheadAttention default).
    dec_cross_attn = nn.MultiheadAttention(num_heads=heads, dropout=drop, embed_dim=width)
    dec_layer = TransformerDecoderLayer(
        d_model=width,
        dim_feedforward=ffn_dim,
        dropout=drop,
        cross_attention=dec_cross_attn,
        n_heads=heads,
        use_text_cross_attention=True,
    )
    decoder = TransformerDecoder(
        layer=dec_layer,
        num_layers=6,
        num_queries=200,
        return_intermediate=True,
        box_refine=True,
        num_o2m_queries=0,
        dac=True,
        boxRPB="log",
        d_model=width,
        frozen=False,
        interaction_layer=None,
        dac_use_selfatt_ln=True,
        use_act_checkpoint=True,
        presence_token=True,
    )

    return TransformerWrapper(encoder=encoder, decoder=decoder, d_model=width)





function ultralytics.models.sam.build_sam3.build_sam3_image_model

def build_sam3_image_model(checkpoint_path: str, bpe_path: str, enable_segmentation: bool = True, compile: bool = False)

Build SAM3 image model.

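Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `checkpoint_path` | `str` | Optional path to model checkpoint | *required* |
| `bpe_path` | `str` | Path to the BPE tokenizer vocabulary | *required* |
| `enable_segmentation` | `bool` | Whether to enable segmentation head | `True` |
| `compile` | `bool` | Whether to enable compilation; `True` selects compile mode `"default"` | `False` |

Returns

| Type | Description |
| --- | --- |
| `SAM3SemanticModel` | A SAM3 image model |

Source code in `ultralytics/models/sam/build_sam3.py` · View on GitHub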
def build_sam3_image_model(
    checkpoint_path: str, bpe_path: str, enable_segmentation: bool = True, compile: bool = False
):
    """Build SAM3 image model.

    Args:
        checkpoint_path (str | None): Optional path to a model checkpoint. When ``None``, no weights are loaded and
            the freshly constructed model is returned.
        bpe_path (str): Path to the BPE tokenizer vocabulary.
        enable_segmentation (bool): Whether to build the segmentation head. When False, ``segmentation_head`` is
            ``None``.
        compile (bool): Whether to enable compilation. ``True`` selects compile mode ``"default"``, which is
            forwarded to the vision backbone and, if built, the pixel decoder; ``False`` disables compilation.

    Returns:
        (SAM3SemanticModel): The SAM3 image model, switched to eval mode.
    """
    # Create visual components
    compile_mode = "default" if compile else None
    vision_encoder = _create_vision_backbone(compile_mode=compile_mode, enable_inst_interactivity=True)

    # Create text components
    text_encoder = VETextEncoder(
        tokenizer=SimpleTokenizer(bpe_path=bpe_path),
        d_model=256,
        width=1024,
        heads=16,
        layers=24,
    )

    # Create visual-language backbone
    backbone = SAM3VLBackbone(visual=vision_encoder, text=text_encoder, scalp=1)

    # Create transformer components
    transformer = _create_sam3_transformer()

    # Create dot product scoring
    dot_prod_scoring = DotProductScoring(
        d_model=256,
        d_proj=256,
        prompt_mlp=MLP(
            input_dim=256,
            hidden_dim=2048,
            output_dim=256,
            num_layers=2,
            residual=True,
            out_norm=nn.LayerNorm(256),
        ),
    )

    # Create segmentation head if enabled
    segmentation_head = (
        UniversalSegmentationHead(
            hidden_dim=256,
            upsampling_stages=3,
            aux_masks=False,
            presence_head=False,
            dot_product_scorer=None,
            act_ckpt=True,
            cross_attend_prompt=nn.MultiheadAttention(
                num_heads=8,
                dropout=0,
                embed_dim=256,
            ),
            pixel_decoder=PixelDecoder(
                num_upsampling_stages=3,
                interpolation_mode="nearest",
                hidden_dim=256,
                compile_mode=compile_mode,
            ),
        )
        if enable_segmentation
        else None
    )

    # Create geometry encoder
    input_geometry_encoder = SequenceGeometryEncoder(
        pos_enc=PositionEmbeddingSine(
            num_pos_feats=256,
            normalize=True,
            scale=None,
            temperature=10000,
        ),
        encode_boxes_as_points=False,
        boxes_direct_project=True,
        boxes_pool=True,
        boxes_pos_enc=True,
        d_model=256,
        num_layers=3,
        layer=TransformerEncoderLayer(
            d_model=256,
            dim_feedforward=2048,
            dropout=0.1,
            pos_enc_at_attn=False,
            pre_norm=True,
            pos_enc_at_cross_attn_queries=False,
            pos_enc_at_cross_attn_keys=True,
        ),
        use_act_ckpt=True,
        add_cls=True,
        add_post_encode_proj=True,
    )

    # Create the SAM3SemanticModel model
    model = SAM3SemanticModel(
        backbone=backbone,
        transformer=transformer,
        input_geometry_encoder=input_geometry_encoder,
        segmentation_head=segmentation_head,
        num_feature_levels=1,
        o2m_mask_predict=True,
        dot_prod_scoring=dot_prod_scoring,
        use_instance_query=False,
        multimask_output=True,
    )

    # The checkpoint is optional: only load weights when a path was given
    # (previously ``None`` crashed inside ``open(None)``).
    if checkpoint_path is not None:
        model = _load_checkpoint(model, checkpoint_path)
    model.eval()
    return model





function ultralytics.models.sam.build_sam3.build_interactive_sam3

def build_interactive_sam3(checkpoint_path: str, compile = None, with_backbone = True) -> SAM3Model

Build the SAM3 Tracker module for video tracking.

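Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `checkpoint_path` | `str` | Path to the SAM3 checkpoint | *required* |
| `compile` | | Compile mode forwarded to the vision backbone | `None` |
| `with_backbone` | | Whether to build the vision backbone (otherwise `image_encoder` is `None`) | `True` |

Returns

| Type | Description |
| --- | --- |
| `SAM3Model` | SAM3 Tracker model |

Source code in `ultralytics/models/sam/build_sam3.py` · View on GitHub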
def build_interactive_sam3(checkpoint_path: str, compile=None, with_backbone=True) -> SAM3Model:
    """Build the SAM3 Tracker module for video tracking.

    Args:
        checkpoint_path (str | None): Path to a SAM3 checkpoint. Its weights are remapped into the tracker layout by
            ``_load_checkpoint(..., interactive=True)``. When ``None``, no weights are loaded.
        compile (str | bool | None): Compile mode forwarded to the vision backbone. ``True`` is treated as
            ``"default"``, the mode ``build_sam3_image_model`` uses; any other value is passed through unchanged.
        with_backbone (bool): Whether to build the vision backbone. When False, ``image_encoder`` is ``None``.

    Returns:
        (SAM3Model): The SAM3 tracker model, switched to eval mode.
    """
    # Accept a boolean flag like build_sam3_image_model does. Without this, True would reach the trunk as a mode
    # value; other inputs keep their previous meaning.
    compile_mode = "default" if compile is True else compile

    # Create model components
    memory_encoder = MemoryEncoder(out_dim=64, interpol_size=[1152, 1152])
    memory_attention = MemoryAttention(
        batch_first=True,
        d_model=256,
        pos_enc_at_input=True,
        layer=MemoryAttentionLayer(
            dim_feedforward=2048,
            dropout=0.1,
            pos_enc_at_attn=False,
            pos_enc_at_cross_attn_keys=True,
            pos_enc_at_cross_attn_queries=False,
            self_attn=RoPEAttention(
                embedding_dim=256,
                num_heads=1,
                downsample_rate=1,
                rope_theta=10000.0,
                feat_sizes=[72, 72],
            ),
            d_model=256,
            cross_attn=RoPEAttention(
                embedding_dim=256,
                num_heads=1,
                downsample_rate=1,
                kv_in_dim=64,
                rope_theta=10000.0,
                feat_sizes=[72, 72],
                rope_k_repeat=True,
            ),
        ),
        num_layers=4,
    )

    backbone = (
        SAM3VLBackbone(scalp=1, visual=_create_vision_backbone(compile_mode=compile_mode), text=None)
        if with_backbone
        else None
    )
    model = SAM3Model(
        image_size=1008,
        image_encoder=backbone,
        memory_attention=memory_attention,
        memory_encoder=memory_encoder,
        backbone_stride=14,
        num_maskmem=7,
        sigmoid_scale_for_mem_enc=20.0,
        sigmoid_bias_for_mem_enc=-10.0,
        use_mask_input_as_output_without_sam=True,
        directly_add_no_mem_embed=True,
        use_high_res_features_in_sam=True,
        multimask_output_in_sam=True,
        iou_prediction_use_sigmoid=True,
        use_obj_ptrs_in_encoder=True,
        add_tpos_enc_to_obj_ptrs=True,
        only_obj_ptrs_in_the_past_for_eval=True,
        pred_obj_scores=True,
        pred_obj_scores_mlp=True,
        fixed_no_obj_ptr=True,
        multimask_output_for_tracking=True,
        use_multimask_token_for_obj_ptr=True,
        multimask_min_pt_num=0,
        multimask_max_pt_num=1,
        use_mlp_for_obj_ptr_proj=True,
        compile_image_encoder=False,
        no_obj_embed_spatial=True,
        proj_tpos_enc_in_obj_ptrs=True,
        use_signed_tpos_enc_to_obj_ptrs=True,
        sam_mask_decoder_extra_args=dict(
            dynamic_multimask_via_stability=True,
            dynamic_multimask_stability_delta=0.05,
            dynamic_multimask_stability_thresh=0.98,
        ),
    )

    # Load checkpoint only if one was provided
    if checkpoint_path is not None:
        model = _load_checkpoint(model, checkpoint_path, interactive=True)

    # Switch to inference mode (no device placement happens here)
    model.eval()
    return model





function ultralytics.models.sam.build_sam3._load_checkpoint

def _load_checkpoint(model, checkpoint, interactive = False)

Load SAM3 model checkpoint from file.

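Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | | Model to load the weights into | *required* |
| `checkpoint` | | Path to the checkpoint file | *required* |
| `interactive` | | Whether to add the tracker-specific key remappings | `False` |

Source code in `ultralytics/models/sam/build_sam3.py` · View on GitHub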
def _load_checkpoint(model, checkpoint, interactive=False):
    """Load SAM3 model checkpoint from file.

    Keys containing ``"detector"`` are kept, with every ``"detector."`` substring removed. With
    ``interactive=True``, extra remapped copies are added for the tracker layout, in this order:

    - detector keys containing ``backbone.vision_backbone`` are copied with that part renamed to
      ``image_encoder.vision_backbone``;
    - raw keys containing ``tracker.transformer`` are added with ``tracker.transformer.encoder`` replaced by
      ``memory_attention``;
    - raw keys containing ``tracker.maskmem_backbone`` are added with that part replaced by ``memory_encoder``;
    - raw keys containing ``tracker.`` are added with ``tracker.`` removed.

    Later additions overwrite earlier ones on key collisions. The result is loaded with ``strict=False``, and the
    returned missing/unexpected-key report is discarded.

    Args:
        model: Object exposing ``load_state_dict`` (a SAM3 model in this file) that receives the weights.
        checkpoint: Path to the checkpoint file, opened in binary mode.
        interactive (bool): Whether to add the tracker-specific key remappings.

    Returns:
        The same ``model`` object after loading.
    """
    with open(checkpoint, "rb") as fh:
        raw = torch_load(fh)
    # Unwrap a nested "model" entry when present.
    if "model" in raw and isinstance(raw["model"], dict):
        raw = raw["model"]

    state = {}
    for key, value in raw.items():
        if "detector" in key:
            state[key.replace("detector.", "")] = value

    if interactive:
        # Copy vision-backbone weights from the detector-derived dict under the tracker's prefix.
        # The comprehension is fully built before update() mutates ``state``.
        state.update(
            {
                key.replace("backbone.vision_backbone", "image_encoder.vision_backbone"): value
                for key, value in state.items()
                if "backbone.vision_backbone" in key
            }
        )
        # (match substring, substring to replace, replacement), applied to the raw keys in this exact order.
        tracker_rules = (
            ("tracker.transformer", "tracker.transformer.encoder", "memory_attention"),
            ("tracker.maskmem_backbone", "tracker.maskmem_backbone", "memory_encoder"),
            ("tracker.", "tracker.", ""),
        )
        for needle, old, new in tracker_rules:
            state.update({key.replace(old, new): value for key, value in raw.items() if needle in key})

    model.load_state_dict(state, strict=False)
    return model





📅 Created 0 days ago ✏️ Updated 0 days ago
Laughing-q