Skip to content

Reference for ultralytics/models/sam/build.py

Note

Full source code for this file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/build.py. Help us fix any issues you see by submitting a Pull Request 🛠️. Thank you 🙏!


ultralytics.models.sam.build.build_sam_vit_h(checkpoint=None)

Build and return a Segment Anything Model (SAM) h-size model.

Source code in ultralytics/models/sam/build.py
def build_sam_vit_h(checkpoint=None):
    """Build and return the huge (ViT-H) Segment Anything Model (SAM).

    Args:
        checkpoint (str, optional): Weights forwarded to `_build_sam`; when None no weights are loaded.

    Returns:
        The SAM model assembled by `_build_sam`.
    """
    # ViT-H encoder: 32 blocks, with global attention on every 8th block (0-based indexes 7, 15, 23, 31).
    vit_h_encoder = dict(
        encoder_embed_dim=1280,
        encoder_depth=32,
        encoder_num_heads=16,
        encoder_global_attn_indexes=[7, 15, 23, 31],
    )
    return _build_sam(checkpoint=checkpoint, **vit_h_encoder)




ultralytics.models.sam.build.build_sam_vit_l(checkpoint=None)

Build and return a Segment Anything Model (SAM) l-size model.

Source code in ultralytics/models/sam/build.py
def build_sam_vit_l(checkpoint=None):
    """Build and return the large (ViT-L) Segment Anything Model (SAM).

    Args:
        checkpoint (str, optional): Weights forwarded to `_build_sam`; when None no weights are loaded.

    Returns:
        The SAM model assembled by `_build_sam`.
    """
    # ViT-L encoder: 24 blocks, with global attention on every 6th block (0-based indexes 5, 11, 17, 23).
    vit_l_encoder = dict(
        encoder_embed_dim=1024,
        encoder_depth=24,
        encoder_num_heads=16,
        encoder_global_attn_indexes=[5, 11, 17, 23],
    )
    return _build_sam(checkpoint=checkpoint, **vit_l_encoder)




ultralytics.models.sam.build.build_sam_vit_b(checkpoint=None)

Build and return a Segment Anything Model (SAM) b-size model.

Source code in ultralytics/models/sam/build.py
def build_sam_vit_b(checkpoint=None):
    """Build and return the base (ViT-B) Segment Anything Model (SAM).

    Args:
        checkpoint (str, optional): Weights forwarded to `_build_sam`; when None no weights are loaded.

    Returns:
        The SAM model assembled by `_build_sam`.
    """
    # ViT-B encoder: 12 blocks, with global attention on every 3rd block (0-based indexes 2, 5, 8, 11).
    vit_b_encoder = dict(
        encoder_embed_dim=768,
        encoder_depth=12,
        encoder_num_heads=12,
        encoder_global_attn_indexes=[2, 5, 8, 11],
    )
    return _build_sam(checkpoint=checkpoint, **vit_b_encoder)




ultralytics.models.sam.build.build_mobile_sam(checkpoint=None)

Build and return a Mobile Segment Anything Model (Mobile-SAM).

Source code in ultralytics/models/sam/build.py
def build_mobile_sam(checkpoint=None):
    """Build and return a Mobile Segment Anything Model (Mobile-SAM).

    Args:
        checkpoint (str, optional): Weights forwarded to `_build_sam`; when None no weights are loaded.

    Returns:
        The SAM model assembled by `_build_sam` with a TinyViT image encoder.
    """
    # Per-stage settings for the four TinyViT stages (one list entry per stage).
    tiny_vit_stages = dict(
        encoder_embed_dim=[64, 128, 160, 320],
        encoder_depth=[2, 2, 6, 2],
        encoder_num_heads=[2, 4, 5, 10],
    )
    # The TinyViT branch of _build_sam does not use global-attention indexes.
    return _build_sam(
        encoder_global_attn_indexes=None,
        mobile_sam=True,
        checkpoint=checkpoint,
        **tiny_vit_stages,
    )




ultralytics.models.sam.build._build_sam(encoder_embed_dim, encoder_depth, encoder_num_heads, encoder_global_attn_indexes, checkpoint=None, mobile_sam=False)

Builds the selected SAM model architecture.

Source code in ultralytics/models/sam/build.py
def _build_sam(encoder_embed_dim,
               encoder_depth,
               encoder_num_heads,
               encoder_global_attn_indexes,
               checkpoint=None,
               mobile_sam=False):
    """Build the selected SAM model architecture and optionally load pretrained weights.

    Args:
        encoder_embed_dim (int | list[int]): Image-encoder embedding width. For Mobile-SAM (TinyViT) this is
            a per-stage list.
        encoder_depth (int | list[int]): Number of encoder blocks. For Mobile-SAM this is a per-stage list.
        encoder_num_heads (int | list[int]): Attention heads. For Mobile-SAM this is a per-stage list.
        encoder_global_attn_indexes (list[int] | None): Block indexes passed to ImageEncoderViT as
            `global_attn_indexes`. Unused when `mobile_sam` is True.
        checkpoint (str, optional): Weights file or asset name. It is resolved through
            `attempt_download_asset` and loaded with `load_state_dict` (strict by default). If None, the
            model keeps its freshly initialized weights.
        mobile_sam (bool): If True, use a TinyViT image encoder (Mobile-SAM) instead of ImageEncoderViT.

    Returns:
        (Sam): The assembled model, switched to eval mode.
    """
    prompt_embed_dim = 256
    image_size = 1024
    vit_patch_size = 16
    image_embedding_size = image_size // vit_patch_size  # side length of the embedding grid (1024 // 16 = 64)

    if mobile_sam:
        image_encoder = TinyViT(
            img_size=image_size,  # keep in sync with the prompt encoder's input_image_size below
            in_chans=3,
            num_classes=1000,
            embed_dims=encoder_embed_dim,
            depths=encoder_depth,
            num_heads=encoder_num_heads,
            window_sizes=[7, 7, 14, 7],
            mlp_ratio=4.0,
            drop_rate=0.0,
            drop_path_rate=0.0,
            use_checkpoint=False,
            mbconv_expand_ratio=4.0,
            local_conv_size=3,
            layer_lr_decay=0.8,
        )
    else:
        image_encoder = ImageEncoderViT(
            depth=encoder_depth,
            embed_dim=encoder_embed_dim,
            img_size=image_size,
            mlp_ratio=4,
            norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
            num_heads=encoder_num_heads,
            patch_size=vit_patch_size,
            qkv_bias=True,
            use_rel_pos=True,
            global_attn_indexes=encoder_global_attn_indexes,
            window_size=14,
            out_chans=prompt_embed_dim,
        )

    sam = Sam(
        image_encoder=image_encoder,
        prompt_encoder=PromptEncoder(
            embed_dim=prompt_embed_dim,
            image_embedding_size=(image_embedding_size, image_embedding_size),
            input_image_size=(image_size, image_size),
            mask_in_chans=16,
        ),
        mask_decoder=MaskDecoder(
            num_multimask_outputs=3,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=prompt_embed_dim,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=prompt_embed_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
        ),
        pixel_mean=[123.675, 116.28, 103.53],
        pixel_std=[58.395, 57.12, 57.375],
    )

    if checkpoint is not None:
        checkpoint = attempt_download_asset(checkpoint)
        with open(checkpoint, 'rb') as f:
            # The model above is created on CPU, so deserialize tensors onto CPU as well. Without
            # map_location, checkpoints saved from GPU tensors fail to load on CPU-only machines.
            # NOTE(review): torch.load unpickles arbitrary objects - only load trusted checkpoints.
            state_dict = torch.load(f, map_location='cpu')
        sam.load_state_dict(state_dict)
    sam.eval()
    return sam




ultralytics.models.sam.build.build_sam(ckpt='sam_b.pt')

Build the SAM model whose architecture is selected by the checkpoint name ckpt (e.g. 'sam_b.pt').

Source code in ultralytics/models/sam/build.py
def build_sam(ckpt='sam_b.pt'):
    """Build a SAM model whose architecture is selected by the checkpoint name.

    Args:
        ckpt (str | pathlib.Path): Checkpoint path or name. It must end with one of the keys of
            `sam_model_map`.

    Returns:
        The model produced by the matching builder, which is called with `ckpt` as a string.

    Raises:
        FileNotFoundError: If `ckpt` does not end with any supported model name.
    """
    ckpt_name = str(ckpt)  # accept pathlib.Path too, which has no .endswith()
    matches = [name for name in sam_model_map if ckpt_name.endswith(name)]
    if not matches:
        raise FileNotFoundError(f'{ckpt} is not a supported sam model. Available models are: \n {sam_model_map.keys()}')

    # Prefer the most specific (longest) matching suffix so overlapping keys resolve deterministically
    # instead of depending on dict iteration order.
    model_builder = sam_model_map[max(matches, key=len)]
    return model_builder(ckpt_name)




Created 2023-07-16, Updated 2023-08-07
Authors: glenn-jocher (5), Laughing-q (1)