Skip to content

Reference for ultralytics/models/sam/model.py

Note

This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/model.py. If you spot a problem please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!


ultralytics.models.sam.model.SAM

SAM(model='sam_b.pt')

Bases: Model

SAM (Segment Anything Model) interface class for real-time image segmentation tasks.

This class provides an interface to the Segment Anything Model (SAM) from Ultralytics, designed for promptable segmentation with versatility in image analysis. It supports various prompts such as bounding boxes, points, or labels, and features zero-shot performance capabilities.

Attributes:

NameTypeDescription
modelModule

The loaded SAM model.

is_sam2bool

Indicates whether the model is SAM2 variant.

taskstr

The task type, set to "segment" for SAM models.

Methods:

NameDescription
predict

Performs segmentation prediction on the given image or video source.

info

Logs information about the SAM model.

Examples:

>>> sam = SAM("sam_b.pt")
>>> results = sam.predict("image.jpg", points=[[500, 375]])
>>> for r in results:
>>>     print(f"Detected {len(r.masks)} masks")

Parameters:

NameTypeDescriptionDefault
modelstr

Path to the pre-trained SAM model file. File should have a .pt or .pth extension.

'sam_b.pt'

Raises:

TypeDescription
NotImplementedError

If the model file extension is not .pt or .pth.

Examples:

>>> sam = SAM("sam_b.pt")
>>> print(sam.is_sam2)
Source code in ultralytics/models/sam/model.py
def __init__(self, model="sam_b.pt") -> None:
    """
    Initializes the SAM (Segment Anything Model) instance.

    Args:
        model (str): Path to the pre-trained SAM model file. File should have a .pt or .pth extension.

    Raises:
        NotImplementedError: If the model file extension is not .pt or .pth.

    Examples:
        >>> sam = SAM("sam_b.pt")
        >>> print(sam.is_sam2)
    """
    if model and Path(model).suffix not in {".pt", ".pth"}:
        raise NotImplementedError("SAM prediction requires pre-trained *.pt or *.pth model.")
    self.is_sam2 = "sam2" in Path(model).stem
    super().__init__(model=model, task="segment")

task_map property

task_map

Provides a mapping from the 'segment' task to its corresponding 'Predictor'.

Returns:

TypeDescription
Dict[str, Type[Predictor]]

A dictionary mapping the 'segment' task to its corresponding Predictor class. For SAM2 models, it maps to SAM2Predictor, otherwise to the standard Predictor.

Examples:

>>> sam = SAM("sam_b.pt")
>>> task_map = sam.task_map
>>> print(task_map)
{'segment': <class 'ultralytics.models.sam.predict.Predictor'>}

__call__

__call__(
    source=None, stream=False, bboxes=None, points=None, labels=None, **kwargs
)

Performs segmentation prediction on the given image or video source.

This method is an alias for the 'predict' method, providing a convenient way to call the SAM model for segmentation tasks.

Parameters:

NameTypeDescriptionDefault
sourcestr | Image | ndarray | None

Path to the image or video file, or a PIL.Image object, or a numpy.ndarray object.

None
streambool

If True, enables real-time streaming.

False
bboxesList[List[float]] | None

List of bounding box coordinates for prompted segmentation.

None
pointsList[List[float]] | None

List of points for prompted segmentation.

None
labelsList[int] | None

List of labels for prompted segmentation.

None
**kwargsAny

Additional keyword arguments to be passed to the predict method.

{}

Returns:

TypeDescription
List

The model predictions, typically containing segmentation masks and other relevant information.

Examples:

>>> sam = SAM("sam_b.pt")
>>> results = sam("image.jpg", points=[[500, 375]])
>>> print(f"Detected {len(results[0].masks)} masks")
Source code in ultralytics/models/sam/model.py
def __call__(self, source=None, stream=False, bboxes=None, points=None, labels=None, **kwargs):
    """
    Performs segmentation prediction on the given image or video source.

    This method is an alias for the 'predict' method, providing a convenient way to call the SAM model
    for segmentation tasks.

    Args:
        source (str | PIL.Image | numpy.ndarray | None): Path to the image or video file, or a PIL.Image
            object, or a numpy.ndarray object.
        stream (bool): If True, enables real-time streaming.
        bboxes (List[List[float]] | None): List of bounding box coordinates for prompted segmentation.
        points (List[List[float]] | None): List of points for prompted segmentation.
        labels (List[int] | None): List of labels for prompted segmentation.
        **kwargs (Any): Additional keyword arguments to be passed to the predict method.

    Returns:
        (List): The model predictions, typically containing segmentation masks and other relevant information.

    Examples:
        >>> sam = SAM("sam_b.pt")
        >>> results = sam("image.jpg", points=[[500, 375]])
        >>> print(f"Detected {len(results[0].masks)} masks")
    """
    return self.predict(source, stream, bboxes, points, labels, **kwargs)

info

info(detailed=False, verbose=True)

Logs information about the SAM model.

This method provides details about the Segment Anything Model (SAM), including its architecture, parameters, and computational requirements.

Parameters:

NameTypeDescriptionDefault
detailedbool

If True, displays detailed information about the model layers and operations.

False
verbosebool

If True, prints the information to the console.

True

Returns:

TypeDescription
Tuple

A tuple containing the model's information (string representations of the model).

Examples:

>>> sam = SAM("sam_b.pt")
>>> info = sam.info()
>>> print(info[0])  # Print summary information
Source code in ultralytics/models/sam/model.py
def info(self, detailed=False, verbose=True):
    """
    Logs information about the SAM model.

    This method provides details about the Segment Anything Model (SAM), including its architecture,
    parameters, and computational requirements.

    Args:
        detailed (bool): If True, displays detailed information about the model layers and operations.
        verbose (bool): If True, prints the information to the console.

    Returns:
        (Tuple): A tuple containing the model's information (string representations of the model).

    Examples:
        >>> sam = SAM("sam_b.pt")
        >>> info = sam.info()
        >>> print(info[0])  # Print summary information
    """
    return model_info(self.model, detailed=detailed, verbose=verbose)

predict

predict(source, stream=False, bboxes=None, points=None, labels=None, **kwargs)

Performs segmentation prediction on the given image or video source.

Parameters:

NameTypeDescriptionDefault
sourcestr | Image | ndarray

Path to the image or video file, or a PIL.Image object, or a numpy.ndarray object.

required
streambool

If True, enables real-time streaming.

False
bboxesList[List[float]] | None

List of bounding box coordinates for prompted segmentation.

None
pointsList[List[float]] | None

List of points for prompted segmentation.

None
labelsList[int] | None

List of labels for prompted segmentation.

None
**kwargsAny

Additional keyword arguments for prediction.

{}

Returns:

TypeDescription
List

The model predictions.

Examples:

>>> sam = SAM("sam_b.pt")
>>> results = sam.predict("image.jpg", points=[[500, 375]])
>>> for r in results:
...     print(f"Detected {len(r.masks)} masks")
Source code in ultralytics/models/sam/model.py
def predict(self, source, stream=False, bboxes=None, points=None, labels=None, **kwargs):
    """
    Performs segmentation prediction on the given image or video source.

    Args:
        source (str | PIL.Image | numpy.ndarray): Path to the image or video file, or a PIL.Image object, or
            a numpy.ndarray object.
        stream (bool): If True, enables real-time streaming.
        bboxes (List[List[float]] | None): List of bounding box coordinates for prompted segmentation.
        points (List[List[float]] | None): List of points for prompted segmentation.
        labels (List[int] | None): List of labels for prompted segmentation.
        **kwargs (Any): Additional keyword arguments for prediction.

    Returns:
        (List): The model predictions.

    Examples:
        >>> sam = SAM("sam_b.pt")
        >>> results = sam.predict("image.jpg", points=[[500, 375]])
        >>> for r in results:
        ...     print(f"Detected {len(r.masks)} masks")
    """
    overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=1024)
    kwargs = {**overrides, **kwargs}
    prompts = dict(bboxes=bboxes, points=points, labels=labels)
    return super().predict(source, stream, prompts=prompts, **kwargs)



📅 Created 11 months ago ✏️ Updated 1 month ago