Skip to content

Reference for ultralytics/models/sam/model.py

Note

This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/model.py. If you spot a problem please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!



ultralytics.models.sam.model.SAM

Bases: Model

SAM (Segment Anything Model) interface class.

SAM is designed for promptable real-time image segmentation. It can be used with a variety of prompts such as bounding boxes, points, or labels. The model has capabilities for zero-shot performance and is trained on the SA-1B dataset.

Source code in ultralytics/models/sam/model.py
class SAM(Model):
    """
    SAM (Segment Anything Model) interface class.

    SAM is designed for promptable real-time image segmentation. It can be used with a variety of prompts such as
    bounding boxes, points, or labels. The model has capabilities for zero-shot performance and is trained on the SA-1B
    dataset.
    """

    def __init__(self, model="sam_b.pt") -> None:
        """
        Initializes the SAM model with a pre-trained model file.

        Args:
            model (str): Path to the pre-trained SAM model file. File should have a .pt or .pth extension.

        Raises:
            NotImplementedError: If the model file extension is not .pt or .pth.
        """
        if model and Path(model).suffix not in {".pt", ".pth"}:
            raise NotImplementedError("SAM prediction requires pre-trained *.pt or *.pth model.")
        super().__init__(model=model, task="segment")

    def _load(self, weights: str, task=None):
        """
        Loads the specified weights into the SAM model.

        Args:
            weights (str): Path to the weights file.
            task (str, optional): Task name. Defaults to None.
        """
        self.model = build_sam(weights)

    def predict(self, source, stream=False, bboxes=None, points=None, labels=None, **kwargs):
        """
        Performs segmentation prediction on the given image or video source.

        Args:
            source (str): Path to the image or video file, or a PIL.Image object, or a numpy.ndarray object.
            stream (bool, optional): If True, enables real-time streaming. Defaults to False.
            bboxes (list, optional): List of bounding box coordinates for prompted segmentation. Defaults to None.
            points (list, optional): List of points for prompted segmentation. Defaults to None.
            labels (list, optional): List of labels for prompted segmentation. Defaults to None.

        Returns:
            (list): The model predictions.
        """
        overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=1024)
        kwargs.update(overrides)
        prompts = dict(bboxes=bboxes, points=points, labels=labels)
        return super().predict(source, stream, prompts=prompts, **kwargs)

    def __call__(self, source=None, stream=False, bboxes=None, points=None, labels=None, **kwargs):
        """
        Alias for the 'predict' method.

        Args:
            source (str): Path to the image or video file, or a PIL.Image object, or a numpy.ndarray object.
            stream (bool, optional): If True, enables real-time streaming. Defaults to False.
            bboxes (list, optional): List of bounding box coordinates for prompted segmentation. Defaults to None.
            points (list, optional): List of points for prompted segmentation. Defaults to None.
            labels (list, optional): List of labels for prompted segmentation. Defaults to None.

        Returns:
            (list): The model predictions.
        """
        return self.predict(source, stream, bboxes, points, labels, **kwargs)

    def info(self, detailed=False, verbose=True):
        """
        Logs information about the SAM model.

        Args:
            detailed (bool, optional): If True, displays detailed information about the model. Defaults to False.
            verbose (bool, optional): If True, displays information on the console. Defaults to True.

        Returns:
            (tuple): A tuple containing the model's information.
        """
        return model_info(self.model, detailed=detailed, verbose=verbose)

    @property
    def task_map(self):
        """
        Provides a mapping from the 'segment' task to its corresponding 'Predictor'.

        Returns:
            (dict): A dictionary mapping the 'segment' task to its corresponding 'Predictor'.
        """
        return {"segment": {"predictor": Predictor}}

task_map property

Provides a mapping from the 'segment' task to its corresponding 'Predictor'.

Returns:

Type Description
dict

A dictionary mapping the 'segment' task to its corresponding 'Predictor'.

__call__(source=None, stream=False, bboxes=None, points=None, labels=None, **kwargs)

Alias for the 'predict' method.

Parameters:

Name Type Description Default
source str

Path to the image or video file, or a PIL.Image object, or a numpy.ndarray object.

None
stream bool

If True, enables real-time streaming. Defaults to False.

False
bboxes list

List of bounding box coordinates for prompted segmentation. Defaults to None.

None
points list

List of points for prompted segmentation. Defaults to None.

None
labels list

List of labels for prompted segmentation. Defaults to None.

None

Returns:

Type Description
list

The model predictions.

Source code in ultralytics/models/sam/model.py
def __call__(self, source=None, stream=False, bboxes=None, points=None, labels=None, **kwargs):
    """
    Alias for the 'predict' method.

    Args:
        source (str): Path to the image or video file, or a PIL.Image object, or a numpy.ndarray object.
        stream (bool, optional): If True, enables real-time streaming. Defaults to False.
        bboxes (list, optional): List of bounding box coordinates for prompted segmentation. Defaults to None.
        points (list, optional): List of points for prompted segmentation. Defaults to None.
        labels (list, optional): List of labels for prompted segmentation. Defaults to None.

    Returns:
        (list): The model predictions.
    """
    return self.predict(source, stream, bboxes, points, labels, **kwargs)

__init__(model='sam_b.pt')

Initializes the SAM model with a pre-trained model file.

Parameters:

Name Type Description Default
model str

Path to the pre-trained SAM model file. File should have a .pt or .pth extension.

'sam_b.pt'

Raises:

Type Description
NotImplementedError

If the model file extension is not .pt or .pth.

Source code in ultralytics/models/sam/model.py
def __init__(self, model="sam_b.pt") -> None:
    """
    Initializes the SAM model with a pre-trained model file.

    Args:
        model (str): Path to the pre-trained SAM model file. File should have a .pt or .pth extension.

    Raises:
        NotImplementedError: If the model file extension is not .pt or .pth.
    """
    if model and Path(model).suffix not in {".pt", ".pth"}:
        raise NotImplementedError("SAM prediction requires pre-trained *.pt or *.pth model.")
    super().__init__(model=model, task="segment")

info(detailed=False, verbose=True)

Logs information about the SAM model.

Parameters:

Name Type Description Default
detailed bool

If True, displays detailed information about the model. Defaults to False.

False
verbose bool

If True, displays information on the console. Defaults to True.

True

Returns:

Type Description
tuple

A tuple containing the model's information.

Source code in ultralytics/models/sam/model.py
def info(self, detailed=False, verbose=True):
    """
    Logs information about the SAM model.

    Args:
        detailed (bool, optional): If True, displays detailed information about the model. Defaults to False.
        verbose (bool, optional): If True, displays information on the console. Defaults to True.

    Returns:
        (tuple): A tuple containing the model's information.
    """
    return model_info(self.model, detailed=detailed, verbose=verbose)

predict(source, stream=False, bboxes=None, points=None, labels=None, **kwargs)

Performs segmentation prediction on the given image or video source.

Parameters:

Name Type Description Default
source str

Path to the image or video file, or a PIL.Image object, or a numpy.ndarray object.

required
stream bool

If True, enables real-time streaming. Defaults to False.

False
bboxes list

List of bounding box coordinates for prompted segmentation. Defaults to None.

None
points list

List of points for prompted segmentation. Defaults to None.

None
labels list

List of labels for prompted segmentation. Defaults to None.

None

Returns:

Type Description
list

The model predictions.

Source code in ultralytics/models/sam/model.py
def predict(self, source, stream=False, bboxes=None, points=None, labels=None, **kwargs):
    """
    Performs segmentation prediction on the given image or video source.

    Args:
        source (str): Path to the image or video file, or a PIL.Image object, or a numpy.ndarray object.
        stream (bool, optional): If True, enables real-time streaming. Defaults to False.
        bboxes (list, optional): List of bounding box coordinates for prompted segmentation. Defaults to None.
        points (list, optional): List of points for prompted segmentation. Defaults to None.
        labels (list, optional): List of labels for prompted segmentation. Defaults to None.

    Returns:
        (list): The model predictions.
    """
    overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=1024)
    kwargs.update(overrides)
    prompts = dict(bboxes=bboxes, points=points, labels=labels)
    return super().predict(source, stream, prompts=prompts, **kwargs)





Created 2023-11-12, Updated 2024-05-08
Authors: Burhan-Q (1), glenn-jocher (3)