Reference for ultralytics/models/sam/model.py
This page is sourced from https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/model.py.
class ultralytics.models.sam.model.SAM
SAM(self, model: str = "sam_b.pt") -> None
Bases: Model
SAM (Segment Anything Model) interface class for real-time image segmentation tasks.
This class provides an interface to the Segment Anything Model (SAM) from Ultralytics, designed for promptable segmentation with versatility in image analysis. It supports various prompts such as bounding boxes, points, or labels, and offers zero-shot performance.
Args
| Name | Type | Description | Default |
|---|---|---|---|
model | str | Path to the pre-trained SAM model file. File should have a .pt or .pth extension. | "sam_b.pt" |
Attributes
| Name | Type | Description |
|---|---|---|
model | torch.nn.Module | The loaded SAM model. |
is_sam2 | bool | Indicates whether the model is a SAM2 variant. |
task | str | The task type, set to "segment" for SAM models. |
Methods
| Name | Description |
|---|---|
task_map | Provide a mapping from the 'segment' task to its corresponding 'Predictor'. |
__call__ | Perform segmentation prediction on the given image or video source. |
_load | Load the specified weights into the SAM model. |
info | Log information about the SAM model. |
predict | Perform segmentation prediction on the given image or video source. |
Examples
>>> sam = SAM("sam_b.pt")
>>> results = sam.predict("image.jpg", points=[[500, 375]])
>>> for r in results:
...     print(f"Detected {len(r.masks)} masks")
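Box prompts follow the same pattern; a minimal sketch, with illustrative [x1, y1, x2, y2] coordinates:
>>> results = sam.predict("image.jpg", bboxes=[[100, 100, 400, 400]])
>>> for r in results:
...     print(f"Detected {len(r.masks)} masks")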
Raises
| Type | Description |
|---|---|
NotImplementedError | If the model file extension is not .pt or .pth. |
Source code in ultralytics/models/sam/model.py
class SAM(Model):
"""SAM (Segment Anything Model) interface class for real-time image segmentation tasks.
This class provides an interface to the Segment Anything Model (SAM) from Ultralytics, designed for promptable
segmentation with versatility in image analysis. It supports various prompts such as bounding boxes, points, or
labels, and features zero-shot performance capabilities.
Attributes:
model (torch.nn.Module): The loaded SAM model.
is_sam2 (bool): Indicates whether the model is a SAM2 variant.
task (str): The task type, set to "segment" for SAM models.
Methods:
predict: Perform segmentation prediction on the given image or video source.
info: Log information about the SAM model.
Examples:
>>> sam = SAM("sam_b.pt")
>>> results = sam.predict("image.jpg", points=[[500, 375]])
>>> for r in results:
...     print(f"Detected {len(r.masks)} masks")
"""
def __init__(self, model: str = "sam_b.pt") -> None:
"""Initialize the SAM (Segment Anything Model) instance.
Args:
model (str): Path to the pre-trained SAM model file. File should have a .pt or .pth extension.
Raises:
NotImplementedError: If the model file extension is not .pt or .pth.
"""
if model and Path(model).suffix not in {".pt", ".pth"}:
raise NotImplementedError("SAM prediction requires pre-trained *.pt or *.pth model.")
self.is_sam2 = "sam2" in Path(model).stem
super().__init__(model=model, task="segment")
property ultralytics.models.sam.model.SAM.task_map
def task_map(self) -> dict[str, dict[str, type[Predictor]]]
Provide a mapping from the 'segment' task to its corresponding 'Predictor'.
Returns
| Type | Description |
|---|---|
dict[str, dict[str, type[Predictor]]] | A dictionary mapping the 'segment' task to its corresponding Predictor class. For SAM2 models, it maps to SAM2Predictor, otherwise to the standard Predictor. |
Examples
>>> sam = SAM("sam_b.pt")
>>> task_map = sam.task_map
>>> print(task_map)
{'segment': {'predictor': <class 'ultralytics.models.sam.predict.Predictor'>}}
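For a SAM2 checkpoint the property maps to SAM2Predictor instead; a sketch, assuming a SAM2 weights file such as "sam2_b.pt" is available:
>>> sam2 = SAM("sam2_b.pt")
>>> print(sam2.task_map)
{'segment': {'predictor': <class 'ultralytics.models.sam.predict.SAM2Predictor'>}}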
Source code in ultralytics/models/sam/model.py
@property
def task_map(self) -> dict[str, dict[str, type[Predictor]]]:
"""Provide a mapping from the 'segment' task to its corresponding 'Predictor'.
Returns:
(dict[str, dict[str, Type[Predictor]]]): A dictionary mapping the 'segment' task to its corresponding
Predictor class. For SAM2 models, it maps to SAM2Predictor, otherwise to the standard Predictor.
Examples:
>>> sam = SAM("sam_b.pt")
>>> task_map = sam.task_map
>>> print(task_map)
{'segment': {'predictor': <class 'ultralytics.models.sam.predict.Predictor'>}}
"""
return {"segment": {"predictor": SAM2Predictor if self.is_sam2 else Predictor}}
method ultralytics.models.sam.model.SAM.__call__
def __call__(self, source=None, stream: bool = False, bboxes=None, points=None, labels=None, **kwargs)
Perform segmentation prediction on the given image or video source.
This method is an alias for the 'predict' method, providing a convenient way to call the SAM model for segmentation tasks.
Args
| Name | Type | Description | Default |
|---|---|---|---|
source | str \| PIL.Image \| np.ndarray \| None | Path to the image or video file, or a PIL.Image object, or a np.ndarray object. | None |
stream | bool | If True, enables real-time streaming. | False |
bboxes | list[list[float]] \| None | List of bounding box coordinates for prompted segmentation. | None |
points | list[list[float]] \| None | List of points for prompted segmentation. | None |
labels | list[int] \| None | List of labels for prompted segmentation. | None |
**kwargs | Any | Additional keyword arguments to be passed to the predict method. | {} |
Returns
| Type | Description |
|---|---|
list | The model predictions, typically containing segmentation masks and other relevant information. |
Examples
>>> sam = SAM("sam_b.pt")
>>> results = sam("image.jpg", points=[[500, 375]])
>>> print(f"Detected {len(results[0].masks)} masks")
Source code in ultralytics/models/sam/model.py
def __call__(self, source=None, stream: bool = False, bboxes=None, points=None, labels=None, **kwargs):
"""Perform segmentation prediction on the given image or video source.
This method is an alias for the 'predict' method, providing a convenient way to call the SAM model for
segmentation tasks.
Args:
source (str | PIL.Image | np.ndarray | None): Path to the image or video file, or a PIL.Image object, or a
np.ndarray object.
stream (bool): If True, enables real-time streaming.
bboxes (list[list[float]] | None): List of bounding box coordinates for prompted segmentation.
points (list[list[float]] | None): List of points for prompted segmentation.
labels (list[int] | None): List of labels for prompted segmentation.
**kwargs (Any): Additional keyword arguments to be passed to the predict method.
Returns:
(list): The model predictions, typically containing segmentation masks and other relevant information.
Examples:
>>> sam = SAM("sam_b.pt")
>>> results = sam("image.jpg", points=[[500, 375]])
>>> print(f"Detected {len(results[0].masks)} masks")
"""
return self.predict(source, stream, bboxes, points, labels, **kwargs)
method ultralytics.models.sam.model.SAM._load
def _load(self, weights: str, task=None)
Load the specified weights into the SAM model.
Args
| Name | Type | Description | Default |
|---|---|---|---|
weights | str | Path to the weights file. Should be a .pt or .pth file containing the model parameters. | required |
task | str \| None | Task name. If provided, it specifies the particular task the model is being loaded for. | None |
Examples
>>> sam = SAM("sam_b.pt")
>>> sam._load("path/to/custom_weights.pt")
Source code in ultralytics/models/sam/model.py
def _load(self, weights: str, task=None):
"""Load the specified weights into the SAM model.
Args:
weights (str): Path to the weights file. Should be a .pt or .pth file containing the model parameters.
task (str | None): Task name. If provided, it specifies the particular task the model is being loaded for.
Examples:
>>> sam = SAM("sam_b.pt")
>>> sam._load("path/to/custom_weights.pt")
"""
from .build import build_sam # slow import
self.model = build_sam(weights)
method ultralytics.models.sam.model.SAM.info
def info(self, detailed: bool = False, verbose: bool = True)
Log information about the SAM model.
Args
| Name | Type | Description | Default |
|---|---|---|---|
detailed | bool | If True, displays detailed information about the model layers and operations. | False |
verbose | bool | If True, prints the information to the console. | True |
Returns
| Type | Description |
|---|---|
tuple | A tuple containing the model's information (string representations of the model). |
Examples
>>> sam = SAM("sam_b.pt")
>>> info = sam.info()
>>> print(info[0]) # Print summary information
Source code in ultralytics/models/sam/model.py
def info(self, detailed: bool = False, verbose: bool = True):
"""Log information about the SAM model.
Args:
detailed (bool): If True, displays detailed information about the model layers and operations.
verbose (bool): If True, prints the information to the console.
Returns:
(tuple): A tuple containing the model's information (string representations of the model).
Examples:
>>> sam = SAM("sam_b.pt")
>>> info = sam.info()
>>> print(info[0]) # Print summary information
"""
return model_info(self.model, detailed=detailed, verbose=verbose)
method ultralytics.models.sam.model.SAM.predict
def predict(self, source, stream: bool = False, bboxes=None, points=None, labels=None, **kwargs)
Perform segmentation prediction on the given image or video source.
Args
| Name | Type | Description | Default |
|---|---|---|---|
source | str \| PIL.Image \| np.ndarray | Path to the image or video file, or a PIL.Image object, or a np.ndarray object. | required |
stream | bool | If True, enables real-time streaming. | False |
bboxes | list[list[float]] \| None | List of bounding box coordinates for prompted segmentation. | None |
points | list[list[float]] \| None | List of points for prompted segmentation. | None |
labels | list[int] \| None | List of labels for prompted segmentation. | None |
**kwargs | Any | Additional keyword arguments for prediction. | {} |
Returns
| Type | Description |
|---|---|
list | The model predictions. |
Examples
>>> sam = SAM("sam_b.pt")
>>> results = sam.predict("image.jpg", points=[[500, 375]])
>>> for r in results:
...     print(f"Detected {len(r.masks)} masks")
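Point prompts can be paired with labels, where by SAM convention 1 marks a foreground point and 0 a background point; the coordinates here are illustrative:
>>> results = sam.predict("image.jpg", points=[[500, 375], [600, 400]], labels=[1, 0])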
Source code in ultralytics/models/sam/model.py
def predict(self, source, stream: bool = False, bboxes=None, points=None, labels=None, **kwargs):
"""Perform segmentation prediction on the given image or video source.
Args:
source (str | PIL.Image | np.ndarray): Path to the image or video file, or a PIL.Image object, or a
np.ndarray object.
stream (bool): If True, enables real-time streaming.
bboxes (list[list[float]] | None): List of bounding box coordinates for prompted segmentation.
points (list[list[float]] | None): List of points for prompted segmentation.
labels (list[int] | None): List of labels for prompted segmentation.
**kwargs (Any): Additional keyword arguments for prediction.
Returns:
(list): The model predictions.
Examples:
>>> sam = SAM("sam_b.pt")
>>> results = sam.predict("image.jpg", points=[[500, 375]])
>>> for r in results:
...     print(f"Detected {len(r.masks)} masks")
"""
overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=1024)
kwargs = {**overrides, **kwargs}
prompts = dict(bboxes=bboxes, points=points, labels=labels)
return super().predict(source, stream, prompts=prompts, **kwargs)
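Note that user kwargs are merged after the defaults ({**overrides, **kwargs}), so settings such as conf can be overridden per call; for example, raising the confidence threshold (value illustrative):
>>> results = sam.predict("image.jpg", points=[[500, 375]], conf=0.5)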