Reference for ultralytics/models/yolo/segment/predict.py

Improvements

This page is sourced from https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/segment/predict.py. Have an improvement or example to add? Open a Pull Request — thank you! 🙏


class ultralytics.models.yolo.segment.predict.SegmentationPredictor

SegmentationPredictor(self, cfg = DEFAULT_CFG, overrides = None, _callbacks = None)

Bases: DetectionPredictor

A class extending the DetectionPredictor class for prediction based on a segmentation model.

This class specializes in processing segmentation model outputs, handling both bounding boxes and masks in the prediction results.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| cfg | dict | Configuration for the predictor. | DEFAULT_CFG |
| overrides | dict, optional | Configuration overrides that take precedence over cfg. | None |
| _callbacks | list, optional | List of callback functions to be invoked during prediction. | None |
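
The precedence rule for overrides can be pictured as a dictionary merge. The following is a minimal sketch of that idea, not the library's own merging code; `merge_config` and the sample keys and values are hypothetical and used only for illustration.

```python
def merge_config(cfg, overrides=None):
    """Return a new dict in which keys from `overrides` replace those in `cfg`."""
    merged = dict(cfg)  # copy so the defaults are not mutated
    merged.update(overrides or {})  # later values win on key collisions
    return merged


# Hypothetical defaults and overrides, chosen only to show the precedence.
defaults = {"conf": 0.25, "retina_masks": False}
merged = merge_config(defaults, {"retina_masks": True})
print(merged)
```

Keys absent from the overrides keep their default values, and the defaults dictionary itself is left unchanged.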

Attributes

| Name | Type | Description |
| --- | --- | --- |
| args | dict | Configuration arguments for the predictor. |
| model | torch.nn.Module | The loaded YOLO segmentation model. |
| batch | list | Current batch of images being processed. |

Methods

| Name | Description |
| --- | --- |
| construct_result | Construct a single result object from the prediction. |
| construct_results | Construct a list of result objects from the predictions. |
| postprocess | Apply non-max suppression and process segmentation detections for each image in the input batch. |

Examples

>>> from ultralytics.utils import ASSETS
>>> from ultralytics.models.yolo.segment import SegmentationPredictor
>>> args = dict(model="yolo11n-seg.pt", source=ASSETS)
>>> predictor = SegmentationPredictor(overrides=args)
>>> predictor.predict_cli()
Source code in ultralytics/models/yolo/segment/predict.py
class SegmentationPredictor(DetectionPredictor):
    """A class extending the DetectionPredictor class for prediction based on a segmentation model.

    This class specializes in processing segmentation model outputs, handling both bounding boxes and masks in the
    prediction results.

    Attributes:
        args (dict): Configuration arguments for the predictor.
        model (torch.nn.Module): The loaded YOLO segmentation model.
        batch (list): Current batch of images being processed.

    Methods:
        postprocess: Apply non-max suppression and process segmentation detections.
        construct_results: Construct a list of result objects from predictions.
        construct_result: Construct a single result object from a prediction.

    Examples:
        >>> from ultralytics.utils import ASSETS
        >>> from ultralytics.models.yolo.segment import SegmentationPredictor
        >>> args = dict(model="yolo11n-seg.pt", source=ASSETS)
        >>> predictor = SegmentationPredictor(overrides=args)
        >>> predictor.predict_cli()
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """Initialize the SegmentationPredictor with configuration, overrides, and callbacks.

        This class specializes in processing segmentation model outputs, handling both bounding boxes and masks in the
        prediction results.

        Args:
            cfg (dict): Configuration for the predictor.
            overrides (dict, optional): Configuration overrides that take precedence over cfg.
            _callbacks (list, optional): List of callback functions to be invoked during prediction.
        """
        super().__init__(cfg, overrides, _callbacks)
        self.args.task = "segment"


method ultralytics.models.yolo.segment.predict.SegmentationPredictor.construct_result

def construct_result(self, pred, img, orig_img, img_path, proto)

Construct a single result object from the prediction.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| pred | torch.Tensor | The predicted bounding boxes, scores, and masks. | required |
| img | torch.Tensor | The image after preprocessing. | required |
| orig_img | np.ndarray | The original image before preprocessing. | required |
| img_path | str | The path to the original image. | required |
| proto | torch.Tensor | The prototype masks. | required |

Returns

| Type | Description |
| --- | --- |
| Results | Result object containing the original image, image path, class names, bounding boxes, and masks. |
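
To make the roles of pred and proto more concrete, the sketch below splits one prediction row into its box, score, class, and mask-coefficient parts, then forms a mask as a weighted sum of prototypes. The column layout follows the slicing used in construct_result (`pred[:, :6]` for boxes, `pred[:, 6:]` for coefficients). This is a simplified pure-Python illustration that assumes a prototype-based mask head: nested lists stand in for tensors, `split_prediction` and `combine_prototypes` are hypothetical helpers, all values are made up, and the cropping and resizing performed by the real mask-processing functions are omitted.

```python
def split_prediction(row):
    """Split one post-NMS row into (box, score, class_id, mask_coefficients)."""
    return row[:4], row[4], int(row[5]), row[6:]


def combine_prototypes(coeffs, protos):
    """Weight each prototype by its coefficient, sum them, and threshold at 0."""
    height, width = len(protos[0]), len(protos[0][0])
    mask = [[0.0] * width for _ in range(height)]
    for coeff, proto in zip(coeffs, protos):
        for y in range(height):
            for x in range(width):
                mask[y][x] += coeff * proto[y][x]
    # A positive sum marks the pixel as part of the object.
    return [[1 if value > 0 else 0 for value in line] for line in mask]


# Two 2x2 prototypes and one prediction row carrying two coefficients.
protos = [
    [[1.0, -1.0], [1.0, -1.0]],
    [[-1.0, -1.0], [1.0, 1.0]],
]
row = [10.0, 20.0, 50.0, 80.0, 0.9, 0.0, 1.0, 0.5]

box, score, class_id, coeffs = split_prediction(row)
mask = combine_prototypes(coeffs, protos)
print(box, score, class_id, mask)
```

Each detection carries only a short coefficient vector, while the prototypes are shared across all detections in the image, which is why construct_result takes both arguments.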
Source code in ultralytics/models/yolo/segment/predict.py
def construct_result(self, pred, img, orig_img, img_path, proto):
    """Construct a single result object from the prediction.

    Args:
        pred (torch.Tensor): The predicted bounding boxes, scores, and masks.
        img (torch.Tensor): The image after preprocessing.
        orig_img (np.ndarray): The original image before preprocessing.
        img_path (str): The path to the original image.
        proto (torch.Tensor): The prototype masks.

    Returns:
        (Results): Result object containing the original image, image path, class names, bounding boxes, and masks.
    """
    if pred.shape[0] == 0:  # save empty boxes
        masks = None
    elif self.args.retina_masks:
        pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
        masks = ops.process_mask_native(proto, pred[:, 6:], pred[:, :4], orig_img.shape[:2])  # HWC
    else:
        masks = ops.process_mask(proto, pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True)  # HWC
        pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
    if masks is not None:
        keep = masks.amax((-2, -1)) > 0  # only keep predictions with masks
        if not all(keep):  # most predictions have masks
            pred, masks = pred[keep], masks[keep]  # indexing is slow
    return Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks)
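
The final filtering step drops detections whose mask contains no positive pixel. A rough pure-Python equivalent of `masks.amax((-2, -1)) > 0` followed by boolean indexing is sketched below. Nested lists stand in for tensors, and `drop_empty_masks` is a hypothetical helper name, not part of the library.

```python
def drop_empty_masks(preds, masks):
    """Keep only the (prediction, mask) pairs whose mask has a positive pixel."""
    keep = [max(max(line) for line in mask) > 0 for mask in masks]
    if all(keep):  # common case: nothing to filter, so skip the copy
        return preds, masks
    kept_preds = [p for p, k in zip(preds, keep) if k]
    kept_masks = [m for m, k in zip(masks, keep) if k]
    return kept_preds, kept_masks


# Placeholder predictions paired with two tiny 2x2 masks.
preds = [["box_a"], ["box_b"]]
masks = [
    [[0, 1], [0, 0]],  # one positive pixel, so this pair is kept
    [[0, 0], [0, 0]],  # empty mask, so its prediction is dropped
]
kept_preds, kept_masks = drop_empty_masks(preds, masks)
print(kept_preds)
```

Checking `all(keep)` first mirrors the original code's shortcut: indexing is skipped when every detection already has a non-empty mask.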


method ultralytics.models.yolo.segment.predict.SegmentationPredictor.construct_results

def construct_results(self, preds, img, orig_imgs, protos)

Construct a list of result objects from the predictions.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| preds | list[torch.Tensor] | List of predicted bounding boxes, scores, and masks. | required |
| img | torch.Tensor | The image after preprocessing. | required |
| orig_imgs | list[np.ndarray] | List of original images before preprocessing. | required |
| protos | list[torch.Tensor] | List of prototype masks. | required |

Returns

| Type | Description |
| --- | --- |
| list[Results] | List of result objects containing the original images, image paths, class names, bounding boxes, and masks. |
Source code in ultralytics/models/yolo/segment/predict.py
def construct_results(self, preds, img, orig_imgs, protos):
    """Construct a list of result objects from the predictions.

    Args:
        preds (list[torch.Tensor]): List of predicted bounding boxes, scores, and masks.
        img (torch.Tensor): The image after preprocessing.
        orig_imgs (list[np.ndarray]): List of original images before preprocessing.
        protos (list[torch.Tensor]): List of prototype masks.

    Returns:
        (list[Results]): List of result objects containing the original images, image paths, class names, bounding
            boxes, and masks.
    """
    return [
        self.construct_result(pred, img, orig_img, img_path, proto)
        for pred, orig_img, img_path, proto in zip(preds, orig_imgs, self.batch[0], protos)
    ]
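
The list comprehension pairs the i-th prediction with the i-th original image, image path, and prototype entry. The sketch below mirrors that pairing with plain strings. `fake_construct_result` is a hypothetical stand-in for construct_result, and treating `batch[0]` as the list of image paths is an assumption read off the `self.batch[0]` argument in the source.

```python
def fake_construct_result(pred, orig_img, img_path, proto):
    """Hypothetical stand-in that just records which inputs were paired."""
    return {"pred": pred, "image": orig_img, "path": img_path, "proto": proto}


# Placeholder values for a batch of two images.
preds = ["pred0", "pred1"]
orig_imgs = ["img0", "img1"]
batch = (["a.jpg", "b.jpg"],)  # assumed layout: first element holds the paths
protos = ["proto0", "proto1"]

results = [
    fake_construct_result(pred, orig_img, img_path, proto)
    for pred, orig_img, img_path, proto in zip(preds, orig_imgs, batch[0], protos)
]
print(results[1])
```

Because `zip` stops at the shortest input, the four sequences are expected to have one entry per image in the batch.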


method ultralytics.models.yolo.segment.predict.SegmentationPredictor.postprocess

def postprocess(self, preds, img, orig_imgs)

Apply non-max suppression and process segmentation detections for each image in the input batch.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| preds | tuple | Model predictions, containing bounding boxes, scores, classes, and mask coefficients. | required |
| img | torch.Tensor | Input image tensor in model format, with shape (B, C, H, W). | required |
| orig_imgs | list \| torch.Tensor \| np.ndarray | Original image or batch of images. | required |

Returns

| Type | Description |
| --- | --- |
| list | List of Results objects containing the segmentation predictions for each image in the batch. Each Results object includes both bounding boxes and segmentation masks. |

Examples

>>> predictor = SegmentationPredictor(overrides=dict(model="yolo11n-seg.pt"))
>>> results = predictor.postprocess(preds, img, orig_imgs)
Source code in ultralytics/models/yolo/segment/predict.py
def postprocess(self, preds, img, orig_imgs):
    """Apply non-max suppression and process segmentation detections for each image in the input batch.

    Args:
        preds (tuple): Model predictions, containing bounding boxes, scores, classes, and mask coefficients.
        img (torch.Tensor): Input image tensor in model format, with shape (B, C, H, W).
        orig_imgs (list | torch.Tensor | np.ndarray): Original image or batch of images.

    Returns:
        (list): List of Results objects containing the segmentation predictions for each image in the batch. Each
            Results object includes both bounding boxes and segmentation masks.

    Examples:
        >>> predictor = SegmentationPredictor(overrides=dict(model="yolo11n-seg.pt"))
        >>> results = predictor.postprocess(preds, img, orig_imgs)
    """
    # Extract protos - tuple if PyTorch model or array if exported
    protos = preds[1][-1] if isinstance(preds[1], tuple) else preds[1]
    return super().postprocess(preds[0], img, orig_imgs, protos=protos)
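
The prototype lookup handles two shapes of `preds[1]`: a tuple whose last element holds the prototypes, or the prototypes directly. A minimal sketch of that branch, using strings as placeholders for tensors and a hypothetical helper name `extract_protos`, could look like this:

```python
def extract_protos(preds):
    """Return the prototype entry from a (detections, extras) pair."""
    extras = preds[1]
    # Tuple: take the last element. Anything else: use the value as-is.
    return extras[-1] if isinstance(extras, tuple) else extras


# Placeholder outputs for the two cases described in the source comment.
pytorch_style = ("detections", ("aux_a", "aux_b", "protos"))
exported_style = ("detections", "protos")

print(extract_protos(pytorch_style))
print(extract_protos(exported_style))
```

Both calls return the same prototype entry, so the parent postprocess receives the detections and `protos` in a single form regardless of which case applied.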





📅 Created 2 years ago ✏️ Updated 2 days ago
glenn-jocher, jk4e, Burhan-Q