
Reference for ultralytics/models/fastsam/predict.py

Note

This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/fastsam/predict.py. If you spot a problem, please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!



ultralytics.models.fastsam.predict.FastSAMPredictor

Bases: DetectionPredictor

FastSAMPredictor is specialized for fast SAM (Segment Anything Model) segmentation prediction tasks in the Ultralytics YOLO framework.

This class extends DetectionPredictor, customizing the prediction pipeline specifically for fast SAM. It adjusts post-processing steps to incorporate mask prediction and non-max suppression while optimizing for single-class segmentation.
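
In practice, this predictor is usually driven through the high-level FastSAM model interface rather than instantiated by hand. A minimal usage sketch (assuming the published "FastSAM-s.pt" checkpoint; adjust paths for your setup):

```python
from ultralytics import FastSAM

# Load a FastSAM checkpoint; segmentation inference routes through FastSAMPredictor
model = FastSAM("FastSAM-s.pt")

# conf and iou feed the non-max suppression performed in postprocess()
results = model("path/to/image.jpg", conf=0.4, iou=0.9)
```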

Attributes

| Name | Type | Description |
| --- | --- | --- |
| cfg | dict | Configuration parameters for prediction. |
| overrides | dict | Optional parameter overrides for custom behavior. |
| _callbacks | dict | Optional list of callback functions to be invoked during prediction. |

Source code in ultralytics/models/fastsam/predict.py
class FastSAMPredictor(DetectionPredictor):
    """
    FastSAMPredictor is specialized for fast SAM (Segment Anything Model) segmentation prediction tasks in Ultralytics
    YOLO framework.

    This class extends the DetectionPredictor, customizing the prediction pipeline specifically for fast SAM.
    It adjusts post-processing steps to incorporate mask prediction and non-max suppression while optimizing
    for single-class segmentation.

    Attributes:
        cfg (dict): Configuration parameters for prediction.
        overrides (dict, optional): Optional parameter overrides for custom behavior.
        _callbacks (dict, optional): Optional list of callback functions to be invoked during prediction.
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """
        Initializes the FastSAMPredictor class, inheriting from DetectionPredictor and setting the task to 'segment'.

        Args:
            cfg (dict): Configuration parameters for prediction.
            overrides (dict, optional): Optional parameter overrides for custom behavior.
            _callbacks (dict, optional): Optional list of callback functions to be invoked during prediction.
        """
        super().__init__(cfg, overrides, _callbacks)
        self.args.task = "segment"

    def postprocess(self, preds, img, orig_imgs):
        """
        Perform post-processing steps on predictions, including non-max suppression and scaling boxes to original image
        size, and returns the final results.

        Args:
            preds (list): The raw output predictions from the model.
            img (torch.Tensor): The processed image tensor.
            orig_imgs (list | torch.Tensor): The original image or list of images.

        Returns:
            (list): A list of Results objects, each containing processed boxes, masks, and other metadata.
        """
        p = ops.non_max_suppression(
            preds[0],
            self.args.conf,
            self.args.iou,
            agnostic=self.args.agnostic_nms,
            max_det=self.args.max_det,
            nc=1,  # set to 1 class since SAM has no class predictions
            classes=self.args.classes,
        )
        full_box = torch.zeros(p[0].shape[1], device=p[0].device)
        full_box[2], full_box[3], full_box[4], full_box[6:] = img.shape[3], img.shape[2], 1.0, 1.0
        full_box = full_box.view(1, -1)
        critical_iou_index = bbox_iou(full_box[0][:4], p[0][:, :4], iou_thres=0.9, image_shape=img.shape[2:])
        if critical_iou_index.numel() != 0:
            full_box[0][4] = p[0][critical_iou_index][:, 4]
            full_box[0][6:] = p[0][critical_iou_index][:, 6:]
            p[0][critical_iou_index] = full_box

        if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)

        results = []
        proto = preds[1][-1] if len(preds[1]) == 3 else preds[1]  # second output is len 3 if pt, but only 1 if exported
        for i, pred in enumerate(p):
            orig_img = orig_imgs[i]
            img_path = self.batch[0][i]
            if not len(pred):  # save empty boxes
                masks = None
            elif self.args.retina_masks:
                pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
                masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2])  # HWC
            else:
                masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True)  # HWC
                pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
            results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks))
        return results

__init__(cfg=DEFAULT_CFG, overrides=None, _callbacks=None)

Initializes the FastSAMPredictor class, inheriting from DetectionPredictor and setting the task to 'segment'.

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| cfg | dict | Configuration parameters for prediction. | DEFAULT_CFG |
| overrides | dict | Optional parameter overrides for custom behavior. | None |
| _callbacks | dict | Optional list of callback functions to be invoked during prediction. | None |
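
The predictor can also be constructed directly, with runtime options passed through the overrides dict as in other Ultralytics predictors. A minimal sketch (the model path and threshold values are illustrative):

```python
from ultralytics.models.fastsam.predict import FastSAMPredictor

# overrides adjust the inherited DetectionPredictor arguments;
# __init__ forces the task to "segment" regardless of what is passed
predictor = FastSAMPredictor(overrides=dict(model="FastSAM-s.pt", conf=0.4, iou=0.9))
assert predictor.args.task == "segment"
```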
Source code in ultralytics/models/fastsam/predict.py
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
    """
    Initializes the FastSAMPredictor class, inheriting from DetectionPredictor and setting the task to 'segment'.

    Args:
        cfg (dict): Configuration parameters for prediction.
        overrides (dict, optional): Optional parameter overrides for custom behavior.
        _callbacks (dict, optional): Optional list of callback functions to be invoked during prediction.
    """
    super().__init__(cfg, overrides, _callbacks)
    self.args.task = "segment"

postprocess(preds, img, orig_imgs)

Perform post-processing steps on predictions, including non-max suppression and scaling boxes to the original image size, and return the final results.

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| preds | list | The raw output predictions from the model. | required |
| img | Tensor | The processed image tensor. | required |
| orig_imgs | list \| Tensor | The original image or list of images. | required |

Returns:

| Type | Description |
| --- | --- |
| list | A list of Results objects, each containing processed boxes, masks, and other metadata. |
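
Each Results object bundles the scaled boxes and masks for one image. A brief sketch of consuming the returned list (assuming results came from a FastSAM prediction as shown earlier):

```python
for r in results:
    print(r.boxes.xyxy)  # (N, 4) boxes scaled to the original image size
    print(r.boxes.conf)  # (N,) confidence scores
    if r.masks is not None:
        print(r.masks.data.shape)  # (N, H, W) instance masks
```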

Source code in ultralytics/models/fastsam/predict.py
def postprocess(self, preds, img, orig_imgs):
    """
    Perform post-processing steps on predictions, including non-max suppression and scaling boxes to original image
    size, and returns the final results.

    Args:
        preds (list): The raw output predictions from the model.
        img (torch.Tensor): The processed image tensor.
        orig_imgs (list | torch.Tensor): The original image or list of images.

    Returns:
        (list): A list of Results objects, each containing processed boxes, masks, and other metadata.
    """
    p = ops.non_max_suppression(
        preds[0],
        self.args.conf,
        self.args.iou,
        agnostic=self.args.agnostic_nms,
        max_det=self.args.max_det,
        nc=1,  # set to 1 class since SAM has no class predictions
        classes=self.args.classes,
    )
    full_box = torch.zeros(p[0].shape[1], device=p[0].device)
    full_box[2], full_box[3], full_box[4], full_box[6:] = img.shape[3], img.shape[2], 1.0, 1.0
    full_box = full_box.view(1, -1)
    critical_iou_index = bbox_iou(full_box[0][:4], p[0][:, :4], iou_thres=0.9, image_shape=img.shape[2:])
    if critical_iou_index.numel() != 0:
        full_box[0][4] = p[0][critical_iou_index][:, 4]
        full_box[0][6:] = p[0][critical_iou_index][:, 6:]
        p[0][critical_iou_index] = full_box

    if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
        orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)

    results = []
    proto = preds[1][-1] if len(preds[1]) == 3 else preds[1]  # second output is len 3 if pt, but only 1 if exported
    for i, pred in enumerate(p):
        orig_img = orig_imgs[i]
        img_path = self.batch[0][i]
        if not len(pred):  # save empty boxes
            masks = None
        elif self.args.retina_masks:
            pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
            masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2])  # HWC
        else:
            masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True)  # HWC
            pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
        results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks))
    return results
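
A notable detail above is the full-image box handling: after NMS, a box spanning the whole input is compared against every prediction, and any prediction overlapping it with IoU above 0.9 is snapped to the exact image extent while keeping its own confidence and mask coefficients. A standalone toy sketch of that snapping idea (not the library's bbox_iou implementation):

```python
import torch

img_h, img_w = 640, 640
# Toy predictions: [x1, y1, x2, y2, conf, cls, mask_coef]
preds = torch.tensor([
    [2.0, 3.0, 638.0, 637.0, 0.88, 0.0, 0.5],      # nearly covers the image
    [100.0, 100.0, 200.0, 200.0, 0.70, 0.0, 0.2],  # small object
])

def iou_with_full_image(box, w, h):
    """IoU of one xyxy box with the full-image box [0, 0, w, h]."""
    inter = (box[2].clamp(max=w) - box[0].clamp(min=0)) * (box[3].clamp(max=h) - box[1].clamp(min=0))
    union = w * h + (box[2] - box[0]) * (box[3] - box[1]) - inter
    return inter / union

for i in range(len(preds)):
    if iou_with_full_image(preds[i], img_w, img_h) > 0.9:
        preds[i, :4] = torch.tensor([0.0, 0.0, img_w, img_h])  # snap to full image

print(preds[0, :4])  # tensor([  0.,   0., 640., 640.])
```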





Created 2023-11-12, Updated 2023-11-25
Authors: glenn-jocher (3)