Skip to content

Reference for ultralytics/engine/predictor.py

Note

This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/engine/predictor.py. If you spot a problem please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!


ultralytics.engine.predictor.BasePredictor

BasePredictor(cfg=DEFAULT_CFG, overrides=None, _callbacks=None)

BasePredictor.

A base class for creating predictors.

Attributes:

Name Type Description
args SimpleNamespace

Configuration for the predictor.

save_dir Path

Directory to save results.

done_warmup bool

Whether the predictor has finished setup.

model Module

Model used for prediction.

data dict

Data configuration.

device device

Device used for prediction.

dataset Dataset

Dataset used for prediction.

vid_writer dict

Dictionary of {save_path: video_writer, ...} writer for saving video output.

Parameters:

Name Type Description Default
cfg str

Path to a configuration file. Defaults to DEFAULT_CFG.

DEFAULT_CFG
overrides dict

Configuration overrides. Defaults to None.

None
Source code in ultralytics/engine/predictor.py
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
    """
    Build the predictor configuration and reset every piece of runtime state.

    Args:
        cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
        overrides (dict, optional): Configuration overrides. Defaults to None.
        _callbacks (optional): Callbacks to use; when falsy, callbacks.get_default_callbacks() supplies them.
    """
    args = get_cfg(cfg, overrides)
    self.args = args
    self.save_dir = get_save_dir(args)
    if args.conf is None:
        args.conf = 0.25  # default conf=0.25
    self.done_warmup = False
    if args.show:
        args.show = check_imshow(warn=True)

    # State below only becomes usable once setup has been run
    self.model = self.imgsz = self.device = self.dataset = None
    self.data = args.data  # data_dict
    self.vid_writer = {}  # dict of {save_path: video_writer, ...}
    self.plotted_img = self.source_type = None
    self.seen = 0
    self.windows = []
    self.batch = self.results = self.transforms = None
    self.callbacks = _callbacks or callbacks.get_default_callbacks()
    self.txt_path = None
    self._lock = threading.Lock()  # for automatic thread-safe inference
    callbacks.add_integration_callbacks(self)

__call__

__call__(source=None, model=None, stream=False, *args, **kwargs)

Performs inference on an image or stream.

Source code in ultralytics/engine/predictor.py
def __call__(self, source=None, model=None, stream=False, *args, **kwargs):
    """Run inference on a source; return the lazy generator when `stream` is truthy, else a list of results."""
    self.stream = stream
    results = self.stream_inference(source, model, *args, **kwargs)
    # Non-streaming callers get every result materialised into one list
    return results if stream else list(results)

add_callback

add_callback(event: str, func)

Add callback.

Source code in ultralytics/engine/predictor.py
def add_callback(self, event: str, func):
    """Append `func` to the callbacks stored under `event`."""
    registered = self.callbacks[event]
    registered.append(func)

inference

inference(im, *args, **kwargs)

Runs inference on a given image using the specified model and arguments.

Source code in ultralytics/engine/predictor.py
def inference(self, im, *args, **kwargs):
    """Call the model on `im`, forwarding augment/visualize/embed settings plus any extra arguments."""
    visualize = False
    if self.args.visualize and not self.source_type.tensor:
        # Visualisation output goes into a directory named after the first path of the current batch
        visualize = increment_path(self.save_dir / Path(self.batch[0][0]).stem, mkdir=True)
    return self.model(im, *args, augment=self.args.augment, visualize=visualize, embed=self.args.embed, **kwargs)

postprocess

postprocess(preds, img, orig_imgs)

Post-processes predictions for an image and returns them.

Source code in ultralytics/engine/predictor.py
def postprocess(self, preds, img, orig_imgs):
    """
    Post-processes predictions for an image and returns them.

    This base implementation returns `preds` unchanged; `img` and `orig_imgs` are accepted but not used here.
    """
    return preds

pre_transform

pre_transform(im)

Pre-transform input image before inference.

Parameters:

Name Type Description Default
im List(np.ndarray)

(N, 3, h, w) for tensor, [(h, w, 3) x N] for list.

required

Returns:

Type Description
list

A list of transformed images.

Source code in ultralytics/engine/predictor.py
def pre_transform(self, im):
    """
    Apply LetterBox to every input image before inference.

    Args:
        im (List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.

    Returns:
        (list): A list of transformed images.
    """
    distinct_shapes = {img.shape for img in im}
    # `auto` is only enabled when all images share one shape and the model reports `pt`
    use_auto = len(distinct_shapes) == 1 and self.model.pt
    resize = LetterBox(self.imgsz, auto=use_auto, stride=self.model.stride)
    return [resize(image=img) for img in im]

predict_cli

predict_cli(source=None, model=None)

Method used for Command Line Interface (CLI) prediction.

This function is designed to run predictions using the CLI. It sets up the source and model, then processes the inputs in a streaming manner. This method ensures that no outputs accumulate in memory by consuming the generator without storing results.

Note

Do not modify this function or remove the generator. The generator ensures that no outputs are accumulated in memory, which is critical for preventing memory issues during long-running predictions.

Source code in ultralytics/engine/predictor.py
def predict_cli(self, source=None, model=None):
    """
    Method used for Command Line Interface (CLI) prediction.

    This function is designed to run predictions using the CLI. It sets up the source and model, then processes
    the inputs in a streaming manner. This method ensures that no outputs accumulate in memory by consuming the
    generator without storing results.

    Args:
        source: Inference source, passed through to `stream_inference`. Defaults to None.
        model: Model, passed through to `stream_inference`. Defaults to None.

    Note:
        Do not modify this function or remove the generator. The generator ensures that no outputs are
        accumulated in memory, which is critical for preventing memory issues during long-running predictions.
    """
    gen = self.stream_inference(source, model)
    # Drain the generator and discard every yielded item so nothing is retained
    for _ in gen:  # sourcery skip: remove-empty-nested-block, noqa
        pass

preprocess

preprocess(im)

Prepares input image before inference.

Parameters:

Name Type Description Default
im torch.Tensor | List(np.ndarray)

BCHW for tensor, [(HWC) x B] for list.

required
Source code in ultralytics/engine/predictor.py
def preprocess(self, im):
    """
    Convert the input batch into a tensor on the prediction device, ready for the model.

    Args:
        im (torch.Tensor | List(np.ndarray)): BCHW for tensor, [(HWC) x B] for list.

    Returns:
        (torch.Tensor): Batch on `self.device`, in half precision when `self.model.fp16` is set, otherwise float.
            Only non-tensor inputs are divided by 255.
    """
    from_arrays = not isinstance(im, torch.Tensor)
    if from_arrays:
        batch = np.stack(self.pre_transform(im))
        # BGR to RGB, BHWC to BCHW, (n, 3, h, w), then make the memory contiguous
        batch = np.ascontiguousarray(batch[..., ::-1].transpose((0, 3, 1, 2)))
        im = torch.from_numpy(batch)

    im = im.to(self.device)
    if self.model.fp16:
        im = im.half()  # uint8 to fp16
    else:
        im = im.float()  # uint8 to fp32
    if from_arrays:
        im /= 255  # 0 - 255 to 0.0 - 1.0
    return im

run_callbacks

run_callbacks(event: str)

Runs all registered callbacks for a specific event.

Source code in ultralytics/engine/predictor.py
def run_callbacks(self, event: str):
    """Invoke, in order, every callback stored under `event`, passing this predictor to each one."""
    hooks = self.callbacks.get(event) or ()
    for hook in hooks:
        hook(self)

save_predicted_images

save_predicted_images(save_path='', frame=0)

Save predictions at the specified path: as a video (plus optional per-frame JPEGs) for video and stream sources, otherwise as a single image.

Source code in ultralytics/engine/predictor.py
def save_predicted_images(self, save_path="", frame=0):
    """
    Save the current plotted image as a video frame or as a still image.

    When the dataset mode is "stream" or "video", the image is appended to a video writer keyed by `save_path`
    (created on first use) and, if `self.args.save_frames` is set, also written as `<frame>.jpg` inside a
    `<stem>_frames` directory next to `save_path`. In any other mode the image is written directly to `save_path`.

    Args:
        save_path (str): Destination path. For videos the suffix is replaced by a platform-specific one.
        frame (int): Frame number used to name the per-frame JPEG.
    """
    im = self.plotted_img

    # Save videos and streams
    if self.dataset.mode in {"stream", "video"}:
        fps = self.dataset.fps if self.dataset.mode == "video" else 30
        # Derive the frames directory from the file stem only. Splitting the whole path on its first "."
        # would truncate it at any dot in a parent directory or in the file stem itself.
        target = Path(save_path)
        frames_dir = target.parent / f"{target.stem}_frames"
        if save_path not in self.vid_writer:  # new video
            if self.args.save_frames:
                frames_dir.mkdir(parents=True, exist_ok=True)
            suffix, fourcc = (".mp4", "avc1") if MACOS else (".avi", "WMV2") if WINDOWS else (".avi", "MJPG")
            self.vid_writer[save_path] = cv2.VideoWriter(
                filename=str(target.with_suffix(suffix)),
                fourcc=cv2.VideoWriter_fourcc(*fourcc),
                fps=fps,  # integer required, floats produce error in MP4 codec
                frameSize=(im.shape[1], im.shape[0]),  # (width, height)
            )

        # Save video
        self.vid_writer[save_path].write(im)
        if self.args.save_frames:
            cv2.imwrite(str(frames_dir / f"{frame}.jpg"), im)

    # Save images
    else:
        cv2.imwrite(save_path, im)

setup_model

setup_model(model, verbose=True)

Initialize YOLO model with given parameters and set it to evaluation mode.

Source code in ultralytics/engine/predictor.py
def setup_model(self, model, verbose=True):
    """Create the AutoBackend model from the given weights (or `args.model`) and put it in evaluation mode."""
    cfg = self.args
    backend = AutoBackend(
        weights=model or cfg.model,
        device=select_device(cfg.device, verbose=verbose),
        dnn=cfg.dnn,
        data=cfg.data,
        fp16=cfg.half,
        batch=cfg.batch,
        fuse=True,
        verbose=verbose,
    )
    self.model = backend

    # Mirror what the backend actually settled on back onto the predictor
    self.device = backend.device  # update device
    cfg.half = backend.fp16  # update half
    backend.eval()

setup_source

setup_source(source)

Sets up source and inference mode.

Source code in ultralytics/engine/predictor.py
def setup_source(self, source):
    """Prepare image size, optional classification transforms and the dataset for a new source."""
    cfg = self.args
    self.imgsz = check_imgsz(cfg.imgsz, stride=self.model.stride, min_dim=2)  # check image size

    if cfg.task == "classify":
        inner_model = self.model.model
        fallback = classify_transforms(self.imgsz[0], crop_fraction=cfg.crop_fraction)
        # Prefer transforms attached to the model; otherwise use the freshly built ones
        self.transforms = getattr(inner_model, "transforms", fallback)
    else:
        self.transforms = None

    self.dataset = load_inference_source(
        source=source,
        batch=cfg.batch,
        vid_stride=cfg.vid_stride,
        buffer=cfg.stream_buffer,
    )
    kind = self.dataset.source_type
    self.source_type = kind

    # Warn only when results are being collected (stream explicitly off) for a long-running or large source
    if not getattr(self, "stream", True):
        heavy_source = (
            kind.stream
            or kind.screenshot
            or len(self.dataset) > 1000  # many images
            or any(getattr(self.dataset, "video_flag", [False]))  # videos
        )
        if heavy_source:
            LOGGER.warning(STREAM_WARNING)
    self.vid_writer = {}

show

show(p='')

Display an image in a window using OpenCV imshow().

Source code in ultralytics/engine/predictor.py
def show(self, p=""):
    """Show the current plotted image in an OpenCV window titled `p`."""
    frame = self.plotted_img
    needs_window = platform.system() == "Linux" and p not in self.windows
    if needs_window:
        self.windows.append(p)
        cv2.namedWindow(p, cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)  # allow window resize (Linux)
        height, width = frame.shape[0], frame.shape[1]
        cv2.resizeWindow(p, width, height)  # (width, height)
    cv2.imshow(p, frame)
    delay_ms = 300 if self.dataset.mode == "image" else 1  # 1 millisecond unless showing still images
    cv2.waitKey(delay_ms)

stream_inference

stream_inference(source=None, model=None, *args, **kwargs)

Streams real-time inference on camera feed and saves results to file.

Source code in ultralytics/engine/predictor.py
@smart_inference_mode()
def stream_inference(self, source=None, model=None, *args, **kwargs):
    """
    Run inference over a source batch by batch, yielding results as they are produced.

    The model is set up on first use, then (while holding `self._lock`) the source is set up, the model is warmed
    up once, and each batch from the dataset is preprocessed, run through the model and postprocessed. Callbacks
    are run at the start, around each batch, after postprocessing and at the end.

    Args:
        source: Inference source handed to `setup_source`; when None, `self.args.source` is used.
        model: Model handed to `setup_model`; only used when `self.model` is not already set.
        *args: Extra positional arguments forwarded to `self.inference`.
        **kwargs: Extra keyword arguments forwarded to `self.inference`.

    Yields:
        The items of `self.results` for each batch. When `self.args.embed` is set, the output of `self.inference`
        is yielded instead (a tensor as a single item, anything else item by item) and postprocessing is skipped.
    """
    if self.args.verbose:
        LOGGER.info("")

    # Setup model
    if not self.model:
        self.setup_model(model)

    with self._lock:  # for thread-safe inference
        # Setup source every time predict is called
        self.setup_source(source if source is not None else self.args.source)

        # Check if save_dir/ label file exists
        if self.args.save or self.args.save_txt:
            (self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)

        # Warmup model
        if not self.done_warmup:
            self.model.warmup(imgsz=(1 if self.model.pt or self.model.triton else self.dataset.bs, 3, *self.imgsz))
            self.done_warmup = True

        # Reset per-call counters and create one profiler each for preprocess, inference and postprocess
        self.seen, self.windows, self.batch = 0, [], None
        profilers = (
            ops.Profile(device=self.device),
            ops.Profile(device=self.device),
            ops.Profile(device=self.device),
        )
        self.run_callbacks("on_predict_start")
        for self.batch in self.dataset:
            self.run_callbacks("on_predict_batch_start")
            paths, im0s, s = self.batch

            # Preprocess
            with profilers[0]:
                im = self.preprocess(im0s)

            # Inference
            with profilers[1]:
                preds = self.inference(im, *args, **kwargs)
                if self.args.embed:
                    yield from [preds] if isinstance(preds, torch.Tensor) else preds  # yield embedding tensors
                    continue  # embeddings bypass postprocessing, result writing and the `seen` counter

            # Postprocess
            with profilers[2]:
                self.results = self.postprocess(preds, im, im0s)
            self.run_callbacks("on_predict_postprocess_end")

            # Visualize, save, write results
            n = len(im0s)
            for i in range(n):
                self.seen += 1
                # Per-image speed in ms: each profiler's batch time split evenly across the batch
                self.results[i].speed = {
                    "preprocess": profilers[0].dt * 1e3 / n,
                    "inference": profilers[1].dt * 1e3 / n,
                    "postprocess": profilers[2].dt * 1e3 / n,
                }
                if self.args.verbose or self.args.save or self.args.save_txt or self.args.show:
                    s[i] += self.write_results(i, Path(paths[i]), im, s)

            # Print batch results
            if self.args.verbose:
                LOGGER.info("\n".join(s))

            self.run_callbacks("on_predict_batch_end")
            yield from self.results

    # NOTE(review): everything below runs only once the dataset loop has finished, i.e. after the caller has
    # consumed the generator to the end; a caller that stops iterating early never reaches this cleanup.
    # Release assets
    for v in self.vid_writer.values():
        if isinstance(v, cv2.VideoWriter):
            v.release()

    # Print final results
    if self.args.verbose and self.seen:
        t = tuple(x.t / self.seen * 1e3 for x in profilers)  # speeds per image
        LOGGER.info(
            f"Speed: %.1fms preprocess, %.1fms inference, %.1fms postprocess per image at shape "
            f"{(min(self.args.batch, self.seen), 3, *im.shape[2:])}" % t
        )
    if self.args.save or self.args.save_txt or self.args.save_crop:
        nl = len(list(self.save_dir.glob("labels/*.txt")))  # number of labels
        s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else ""
        LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}")
    self.run_callbacks("on_predict_end")

write_results

write_results(i, p, im, s)

Write inference results to a file or directory.

Source code in ultralytics/engine/predictor.py
def write_results(self, i, p, im, s):
    """
    Write inference results to a file or directory.

    Builds the log string for result `i`, records the label path in `self.txt_path`, and, depending on
    `self.args`, plots the result, saves labels and crops, shows the image and saves the plotted image.

    Args:
        i (int): Index of the result within the current batch.
        p (Path): Path of the source this result belongs to; its stem and name determine output file names.
        im: Preprocessed batch; its `shape` is logged and `im[i]` is passed to plotting. A 3-dimensional input
            is expanded with a leading batch dimension.
        s (list[str]): Per-image log prefixes; `s[i]` is searched for "frame N/" to find the frame number.

    Returns:
        (str): Log string with the batch index (for stream/image/tensor sources), the image shape, the
            result summary and the inference time.
    """
    string = ""  # print string
    if len(im.shape) == 3:
        im = im[None]  # expand for batch dim
    if self.source_type.stream or self.source_type.from_img or self.source_type.tensor:  # batch_size >= 1
        string += f"{i}: "
        frame = self.dataset.count
    else:
        match = re.search(r"frame (\d+)/", s[i])
        # Fall back to 0 rather than None so file names never contain the literal text "None"
        frame = int(match[1]) if match else 0  # 0 if frame undetermined

    self.txt_path = self.save_dir / "labels" / (p.stem + ("" if self.dataset.mode == "image" else f"_{frame}"))
    string += "%gx%g " % im.shape[2:]
    result = self.results[i]
    result.save_dir = str(self.save_dir)  # used in other locations
    string += f"{result.verbose()}{result.speed['inference']:.1f}ms"

    # Add predictions to image
    if self.args.save or self.args.show:
        self.plotted_img = result.plot(
            line_width=self.args.line_width,
            boxes=self.args.show_boxes,
            conf=self.args.show_conf,
            labels=self.args.show_labels,
            im_gpu=None if self.args.retina_masks else im[i],
        )

    # Save results
    if self.args.save_txt:
        result.save_txt(f"{self.txt_path}.txt", save_conf=self.args.save_conf)
    if self.args.save_crop:
        result.save_crop(save_dir=self.save_dir / "crops", file_name=self.txt_path.stem)
    if self.args.show:
        self.show(str(p))
    if self.args.save:
        self.save_predicted_images(str(self.save_dir / p.name), frame)

    return string





Created 2023-11-12, Updated 2024-07-21
Authors: glenn-jocher (6), Burhan-Q (1)