跳至内容

参考资料 ultralytics/models/sam/model.py

备注

该文件可在https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/ sam/model .py。如果您发现问题,请通过提交 Pull Request🛠️ 帮助修复。谢谢🙏!



ultralytics.models.sam.model.SAM

垒球 Model

SAM (分段 Anything Model)接口类。

SAM 是专为可提示的实时图像分割而设计的。它可与各种提示一起使用,如 边框、点或标签。该模型具有零镜头性能,并在 SA-1B 数据集进行训练。

源代码 ultralytics/models/sam/model.py
class SAM(Model):
    """
    SAM (Segment Anything Model) interface class.

    SAM is designed for promptable real-time image segmentation. It can be used with a variety of prompts such as
    bounding boxes, points, or labels. The model has capabilities for zero-shot performance and is trained on the SA-1B
    dataset.
    """

    def __init__(self, model="sam_b.pt") -> None:
        """
        Initializes the SAM model with a pre-trained model file.

        Args:
            model (str): Path to the pre-trained SAM model file. File should have a .pt or .pth extension.

        Raises:
            NotImplementedError: If the model file extension is not .pt or .pth.
        """
        if model and Path(model).suffix not in (".pt", ".pth"):
            raise NotImplementedError("SAM prediction requires pre-trained *.pt or *.pth model.")
        super().__init__(model=model, task="segment")

    def _load(self, weights: str, task=None):
        """
        Loads the specified weights into the SAM model.

        Args:
            weights (str): Path to the weights file.
            task (str, optional): Task name. Defaults to None.
        """
        self.model = build_sam(weights)

    def predict(self, source, stream=False, bboxes=None, points=None, labels=None, **kwargs):
        """
        Performs segmentation prediction on the given image or video source.

        Args:
            source (str): Path to the image or video file, or a PIL.Image object, or a numpy.ndarray object.
            stream (bool, optional): If True, enables real-time streaming. Defaults to False.
            bboxes (list, optional): List of bounding box coordinates for prompted segmentation. Defaults to None.
            points (list, optional): List of points for prompted segmentation. Defaults to None.
            labels (list, optional): List of labels for prompted segmentation. Defaults to None.

        Returns:
            (list): The model predictions.
        """
        overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=1024)
        kwargs.update(overrides)
        prompts = dict(bboxes=bboxes, points=points, labels=labels)
        return super().predict(source, stream, prompts=prompts, **kwargs)

    def __call__(self, source=None, stream=False, bboxes=None, points=None, labels=None, **kwargs):
        """
        Alias for the 'predict' method.

        Args:
            source (str): Path to the image or video file, or a PIL.Image object, or a numpy.ndarray object.
            stream (bool, optional): If True, enables real-time streaming. Defaults to False.
            bboxes (list, optional): List of bounding box coordinates for prompted segmentation. Defaults to None.
            points (list, optional): List of points for prompted segmentation. Defaults to None.
            labels (list, optional): List of labels for prompted segmentation. Defaults to None.

        Returns:
            (list): The model predictions.
        """
        return self.predict(source, stream, bboxes, points, labels, **kwargs)

    def info(self, detailed=False, verbose=True):
        """
        Logs information about the SAM model.

        Args:
            detailed (bool, optional): If True, displays detailed information about the model. Defaults to False.
            verbose (bool, optional): If True, displays information on the console. Defaults to True.

        Returns:
            (tuple): A tuple containing the model's information.
        """
        return model_info(self.model, detailed=detailed, verbose=verbose)

    @property
    def task_map(self):
        """
        Provides a mapping from the 'segment' task to its corresponding 'Predictor'.

        Returns:
            (dict): A dictionary mapping the 'segment' task to its corresponding 'Predictor'.
        """
        return {"segment": {"predictor": Predictor}}

task_map property

提供从 "分段 "任务到相应 "预测器 "的映射。

返回:

类型 说明
dict

将 "分段 "任务映射到相应 "预测器 "的字典。

__call__(source=None, stream=False, bboxes=None, points=None, labels=None, **kwargs)

预测 "方法的别名。

参数

名称 类型 说明 默认值
source str

图像或视频文件的路径,或 PIL.Image 对象,或 numpy.ndarray 对象。

None
stream bool

如果为 True,则启用实时流媒体。默认为 "假"。

False
bboxes list

提示分割的边界框坐标列表。默认为 "无"。

None
points list

提示分割的点列表。默认为 "无"。

None
labels list

提示分割的标签列表。默认为 "无"。

None

返回:

类型 说明
list

模型预测

源代码 ultralytics/models/sam/model.py
def __call__(self, source=None, stream=False, bboxes=None, points=None, labels=None, **kwargs):
    """
    Alias for the 'predict' method.

    Args:
        source (str): Path to the image or video file, or a PIL.Image object, or a numpy.ndarray object.
        stream (bool, optional): If True, enables real-time streaming. Defaults to False.
        bboxes (list, optional): List of bounding box coordinates for prompted segmentation. Defaults to None.
        points (list, optional): List of points for prompted segmentation. Defaults to None.
        labels (list, optional): List of labels for prompted segmentation. Defaults to None.

    Returns:
        (list): The model predictions.
    """
    return self.predict(source, stream, bboxes, points, labels, **kwargs)

__init__(model='sam_b.pt')

使用预训练模型文件初始化SAM 模型。

参数

名称 类型 说明 默认值
model str

预训练SAM 模型文件的路径。文件扩展名应为 .pt 或 .pth。

'sam_b.pt'

加薪:

类型 说明
NotImplementedError

如果模型文件扩展名不是 .pt 或 .pth。

源代码 ultralytics/models/sam/model.py
def __init__(self, model="sam_b.pt") -> None:
    """
    Initializes the SAM model with a pre-trained model file.

    Args:
        model (str): Path to the pre-trained SAM model file. File should have a .pt or .pth extension.

    Raises:
        NotImplementedError: If the model file extension is not .pt or .pth.
    """
    if model and Path(model).suffix not in (".pt", ".pth"):
        raise NotImplementedError("SAM prediction requires pre-trained *.pt or *.pth model.")
    super().__init__(model=model, task="segment")

info(detailed=False, verbose=True)

记录有关SAM 型号的信息。

参数

名称 类型 说明 默认值
detailed bool

如果为 True,则显示模型的详细信息。默认为 "假"。

False
verbose bool

如果为 True,则在控制台上显示信息。默认为 True。

True

返回:

类型 说明
tuple

包含模型信息的元组。

源代码 ultralytics/models/sam/model.py
def info(self, detailed=False, verbose=True):
    """
    Logs information about the SAM model.

    Args:
        detailed (bool, optional): If True, displays detailed information about the model. Defaults to False.
        verbose (bool, optional): If True, displays information on the console. Defaults to True.

    Returns:
        (tuple): A tuple containing the model's information.
    """
    return model_info(self.model, detailed=detailed, verbose=verbose)

predict(source, stream=False, bboxes=None, points=None, labels=None, **kwargs)

对给定的图像或视频源执行分割预测。

参数

名称 类型 说明 默认值
source str

图像或视频文件的路径,或 PIL.Image 对象,或 numpy.ndarray 对象。

所需
stream bool

如果为 True,则启用实时流媒体。默认为 "假"。

False
bboxes list

提示分割的边界框坐标列表。默认为 "无"。

None
points list

提示分割的点列表。默认为 "无"。

None
labels list

提示分割的标签列表。默认为 "无"。

None

返回:

类型 说明
list

模型预测

源代码 ultralytics/models/sam/model.py
def predict(self, source, stream=False, bboxes=None, points=None, labels=None, **kwargs):
    """
    Performs segmentation prediction on the given image or video source.

    Args:
        source (str): Path to the image or video file, or a PIL.Image object, or a numpy.ndarray object.
        stream (bool, optional): If True, enables real-time streaming. Defaults to False.
        bboxes (list, optional): List of bounding box coordinates for prompted segmentation. Defaults to None.
        points (list, optional): List of points for prompted segmentation. Defaults to None.
        labels (list, optional): List of labels for prompted segmentation. Defaults to None.

    Returns:
        (list): The model predictions.
    """
    overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=1024)
    kwargs.update(overrides)
    prompts = dict(bboxes=bboxes, points=points, labels=labels)
    return super().predict(source, stream, prompts=prompts, **kwargs)





创建于 2023-11-12,更新于 2023-11-25
作者:glenn-jocher(3)