Skip to content

Reference for ultralytics/models/yolo/detect/predict.py

Note

This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/detect/predict.py. If you spot a problem, please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!


ultralytics.models.yolo.detect.predict.DetectionPredictor

DetectionPredictor(cfg=DEFAULT_CFG, overrides=None, _callbacks=None)

Bases: BasePredictor

A class extending the BasePredictor class for prediction based on a detection model.

This predictor specializes in object detection tasks, processing model outputs into meaningful detection results with bounding boxes and class predictions.

Attributes:

Name Type Description
args namespace

Configuration arguments for the predictor.

model Module

The detection model used for inference.

batch list

Batch of images and metadata for processing.

Methods:

Name Description
postprocess

Process raw model predictions into detection results.

construct_results

Build Results objects from processed predictions.

construct_result

Create a single Results object from a prediction.

Examples:

>>> from ultralytics.utils import ASSETS
>>> from ultralytics.models.yolo.detect import DetectionPredictor
>>> args = dict(model="yolo11n.pt", source=ASSETS)
>>> predictor = DetectionPredictor(overrides=args)
>>> predictor.predict_cli()
Source code in ultralytics/engine/predictor.py
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
    """
    Set up a BasePredictor from a configuration and optional overrides.

    Resolves the configuration, derives the save directory from it, fills in a default
    confidence threshold when none is configured, and initializes all runtime state to
    empty placeholders.

    Args:
        cfg (str | dict): Path to a configuration file or a configuration dictionary.
        overrides (dict | None): Configuration overrides.
        _callbacks (dict | None): Dictionary of callback functions; when falsy, the default
            callbacks are used instead.
    """
    # Resolve configuration first; everything below reads from it.
    settings = get_cfg(cfg, overrides)
    self.args = settings
    self.save_dir = get_save_dir(settings)
    if settings.conf is None:
        settings.conf = 0.25  # default conf=0.25
    if settings.show:
        # Replace the flag with the result of the display-capability check.
        settings.show = check_imshow(warn=True)
    self.data = settings.data  # data_dict

    # Runtime state, populated later once setup is done.
    self.done_warmup = False
    self.model = None
    self.imgsz = None
    self.device = None
    self.dataset = None
    self.source_type = None
    self.transforms = None
    self.batch = None
    self.results = None
    self.seen = 0

    # Output and visualization bookkeeping.
    self.vid_writer = {}  # dict of {save_path: video_writer, ...}
    self.plotted_img = None
    self.windows = []
    self.txt_path = None

    self._lock = threading.Lock()  # for automatic thread-safe inference

    # Callbacks: a falsy value falls back to the defaults, then integrations are attached.
    self.callbacks = _callbacks if _callbacks else callbacks.get_default_callbacks()
    callbacks.add_integration_callbacks(self)

construct_result

construct_result(pred, img, orig_img, img_path)

Construct a single Results object from one image prediction.

Parameters:

Name Type Description Default
pred Tensor

Predicted boxes and scores with shape (N, 6) where N is the number of detections.

required
img Tensor

Preprocessed image tensor used for inference.

required
orig_img ndarray

Original image before preprocessing.

required
img_path str

Path to the original image file.

required

Returns:

Type Description
Results

Results object containing the original image, image path, class names, and scaled bounding boxes.

Source code in ultralytics/models/yolo/detect/predict.py
def construct_result(self, pred, img, orig_img, img_path):
    """
    Build the Results object for a single image.

    The first four columns of `pred` are overwritten in place with the output of
    `ops.scale_boxes`, which is given the spatial size of the preprocessed image and the
    shape of the original image.

    Args:
        pred (torch.Tensor): Predicted boxes and scores with shape (N, 6) where N is the number of detections.
        img (torch.Tensor): Preprocessed image tensor used for inference.
        orig_img (np.ndarray): Original image before preprocessing.
        img_path (str): Path to the original image file.

    Returns:
        (Results): Results object holding the original image, its path, the model's class names, and the
            first six columns of `pred` as boxes.
    """
    inference_hw = img.shape[2:]  # dimensions after the first two axes of the preprocessed tensor
    raw_boxes = pred[:, :4]
    pred[:, :4] = ops.scale_boxes(inference_hw, raw_boxes, orig_img.shape)

    detections = pred[:, :6]
    class_names = self.model.names
    return Results(orig_img, path=img_path, names=class_names, boxes=detections)

construct_results

construct_results(preds, img, orig_imgs)

Construct a list of Results objects from model predictions.

Parameters:

Name Type Description Default
preds List[Tensor]

List of predicted bounding boxes and scores for each image.

required
img Tensor

Batch of preprocessed images used for inference.

required
orig_imgs List[ndarray]

List of original images before preprocessing.

required

Returns:

Type Description
List[Results]

List of Results objects containing detection information for each image.

Source code in ultralytics/models/yolo/detect/predict.py
def construct_results(self, preds, img, orig_imgs):
    """
    Build one Results object per image by delegating to `construct_result`.

    Predictions, original images, and the paths stored in `self.batch[0]` are walked in
    lockstep; iteration stops at the shortest of the three.

    Args:
        preds (List[torch.Tensor]): List of predicted bounding boxes and scores for each image.
        img (torch.Tensor): Batch of preprocessed images used for inference.
        orig_imgs (List[np.ndarray]): List of original images before preprocessing.

    Returns:
        (List[Results]): One entry per image, in input order, as returned by `construct_result`.
    """
    image_paths = self.batch[0]
    collected = []
    for single_pred, original, path in zip(preds, orig_imgs, image_paths):
        # The full preprocessed batch `img` is passed through unchanged for every image.
        collected.append(self.construct_result(single_pred, img, original, path))
    return collected

postprocess

postprocess(preds, img, orig_imgs, **kwargs)

Post-processes predictions and returns a list of Results objects.

Source code in ultralytics/models/yolo/detect/predict.py
def postprocess(self, preds, img, orig_imgs, **kwargs):
    """
    Apply non-maximum suppression to raw predictions and wrap them in Results objects.

    Args:
        preds: Raw model predictions, passed to `ops.non_max_suppression`.
        img (torch.Tensor): Batch of preprocessed images, forwarded to `construct_results`.
        orig_imgs (list | torch.Tensor): Original images. Anything that is not a list is converted with
            `ops.convert_torch2numpy_batch` before results are built.
        **kwargs: Extra keyword arguments forwarded unchanged to `construct_results`.

    Returns:
        The value returned by `construct_results` (a list of Results objects).
    """
    cfg = self.args
    net = self.model
    filtered = ops.non_max_suppression(
        preds,
        cfg.conf,
        cfg.iou,
        cfg.classes,
        cfg.agnostic_nms,
        max_det=cfg.max_det,
        nc=len(net.names),
        end2end=getattr(net, "end2end", False),  # models without the attribute are treated as not end2end
        rotated=cfg.task == "obb",
    )

    # Input images that are a torch.Tensor, not a list, are converted to a numpy batch.
    originals = orig_imgs if isinstance(orig_imgs, list) else ops.convert_torch2numpy_batch(orig_imgs)

    return self.construct_results(filtered, img, originals, **kwargs)



📅 Created 1 year ago ✏️ Updated 6 months ago