Model Prediction with Ultralytics YOLO

Introduction

In the world of machine learning and computer vision, the process of making sense out of visual data is called 'inference' or 'prediction'. Ultralytics YOLO11 offers a powerful feature known as predict mode that is tailored for high-performance, real-time inference on a wide range of data sources.




Real-world Applications

| Manufacturing | Sports | Safety |
|---|---|---|
| Vehicle Spare Parts Detection | Football Player Detection | People Fall Detection |

Why Use Ultralytics YOLO for Inference?

Here's why you should consider YOLO11's predict mode for your various inference needs:

  • Versatility: Capable of making inferences on images, videos, and even live streams.
  • Performance: Engineered for real-time, high-speed processing without sacrificing accuracy.
  • Ease of Use: Intuitive Python and CLI interfaces for rapid deployment and testing.
  • Highly Customizable: Various settings and parameters to tune the model's inference behavior according to your specific requirements.

Key Features of Predict Mode

YOLO11's predict mode is designed to be robust and versatile, featuring:

  • Multiple Data Source Compatibility: Whether your data is in the form of individual images, a collection of images, video files, or real-time video streams, predict mode has you covered.
  • Streaming Mode: Set stream=True when calling the model to get a memory-efficient generator of Results objects instead of a list.
  • Batch Processing: Process multiple images or video frames in a single batch to speed up inference.
  • Integration Friendly: Easily integrate with existing data pipelines and other software components, thanks to its flexible API.

Ultralytics YOLO models return either a Python list of Results objects, or a memory-efficient Python generator of Results objects when stream=True is passed to the model during inference:

Predict

from ultralytics import YOLO

# Load a model
model = YOLO("yolo11n.pt")  # pretrained YOLO11n model

# Run batched inference on a list of images
results = model(["image1.jpg", "image2.jpg"])  # return a list of Results objects

# Process results list
for i, result in enumerate(results):
    boxes = result.boxes  # Boxes object for bounding box outputs
    masks = result.masks  # Masks object for segmentation masks outputs
    keypoints = result.keypoints  # Keypoints object for pose outputs
    probs = result.probs  # Probs object for classification outputs
    obb = result.obb  # Oriented boxes object for OBB outputs
    result.show()  # display to screen
    result.save(filename=f"result{i}.jpg")  # save to disk, one file per image

Passing stream=True instead returns a generator of Results objects:

from ultralytics import YOLO

# Load a model
model = YOLO("yolo11n.pt")  # pretrained YOLO11n model

# Run batched inference on a list of images
results = model(["image1.jpg", "image2.jpg"], stream=True)  # return a generator of Results objects

# Process results generator
for i, result in enumerate(results):
    boxes = result.boxes  # Boxes object for bounding box outputs
    masks = result.masks  # Masks object for segmentation masks outputs
    keypoints = result.keypoints  # Keypoints object for pose outputs
    probs = result.probs  # Probs object for classification outputs
    obb = result.obb  # Oriented boxes object for OBB outputs
    result.show()  # display to screen
    result.save(filename=f"result{i}.jpg")  # save to disk, one file per image

Inference Sources

YOLO11 can process different types of input sources for inference, as shown in the table below. The sources include static images, video streams, and various data formats. Sources marked ✅ can be used in streaming mode with stream=True. Streaming mode is beneficial for processing videos or live streams because it yields results from a generator instead of keeping the results for every frame in memory.

Tip

Use stream=True for processing long videos or large datasets to efficiently manage memory. When stream=False, the results for all frames or data points are stored in memory, which can quickly add up and cause out-of-memory errors for large inputs. In contrast, stream=True utilizes a generator, which only keeps the results of the current frame or data point in memory, significantly reducing memory consumption and preventing out-of-memory issues.
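
The memory difference comes from ordinary Python semantics: a list holds every element at once, while a generator produces one element per iteration. The sketch below illustrates this with a hypothetical fake_predict function standing in for model inference; it makes no Ultralytics calls and only mimics the list-versus-generator behavior described above.

```python
from typing import Iterator


def fake_predict(frames: list[str]) -> Iterator[dict]:
    """Hypothetical stand-in for inference: yields one result per frame, lazily."""
    for i, frame in enumerate(frames):
        # In a real pipeline, this is where the model would run on the frame
        yield {"index": i, "frame": frame, "num_detections": len(frame) % 3}


frames = [f"frame_{i:04d}" for i in range(1000)]

# Eager (like stream=False): every result is materialized and kept in memory
all_results = list(fake_predict(frames))

# Lazy (like stream=True): each result is produced, used, and then discarded
total_detections = 0
for result in fake_predict(frames):
    total_detections += result["num_detections"]
```

A generator can only be consumed once, so it is usually simplest to process the results of stream=True inside a single loop rather than storing the generator and iterating over it twice.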

| Source | Example | Type | Notes |
|---|---|---|---|
| image | 'image.jpg' | str or Path | Single image file. |
| URL | 'https://ultralytics.com/images/bus.jpg' | str | URL to an image. |
| screenshot | 'screen' | str | Capture a screenshot. |
| PIL | Image.open('image.jpg') | PIL.Image | HWC format with RGB channels. |
| OpenCV | cv2.imread('image.jpg') | np.ndarray | HWC format with BGR channels, uint8 (0-255). |
| numpy | np.zeros((640, 1280, 3), dtype=np.uint8) | np.ndarray | HWC format with BGR channels, uint8 (0-255). |
| torch | torch.zeros(16, 3, 320, 640) | torch.Tensor | BCHW format with RGB channels, float32 (0.0-1.0). |
| CSV | 'sources.csv' | str or Path | CSV file containing paths to images, videos, or directories. |
| video ✅ | 'video.mp4' | str or Path | Video file in formats like MP4, AVI, etc. |
| directory ✅ | 'path/' | str or Path | Path to a directory containing images or videos. |
| glob ✅ | 'path/*.jpg' | str | Glob pattern to match multiple files. Use the * character as a wildcard. |
| YouTube ✅ | 'https://youtu.be/LNwODJXcvt4' | str | URL to a YouTube video. |
| stream ✅ | 'rtsp://example.com/media.mp4' | str | URL for streaming protocols such as RTSP, RTMP, TCP, or an IP address. |
| multi-stream ✅ | 'list.streams' | str or Path | *.streams text file with one stream URL per row, e.g. 8 streams will run at batch size 8. |
| webcam ✅ | 0 | int | Index of the connected camera device to run inference on. |

Below are code examples for using each source type:

Prediction sources

Run inference on an image file.

from ultralytics import YOLO

# Load a pretrained YOLO11n model
model = YOLO("yolo11n.pt")

# Define path to the image file
source = "path/to/image.jpg"

# Run inference on the source
results = model(source)  # list of Results objects

Run inference on the current screen content as a screenshot.

from ultralytics import YOLO

# Load a pretrained YOLO11n model
model = YOLO("yolo11n.pt")

# Define current screenshot as source
source = "screen"

# Run inference on the source
results = model(source)  # list of Results objects

Run inference on an image or video hosted remotely via URL.

from ultralytics import YOLO

# Load a pretrained YOLO11n model
model = YOLO("yolo11n.pt")

# Define remote image or video URL
source = "https://ultralytics.com/images/bus.jpg"

# Run inference on the source
results = model(source)  # list of Results objects

Run inference on an image opened with Python Imaging Library (PIL).

from PIL import Image

from ultralytics import YOLO

# Load a pretrained YOLO11n model
model = YOLO("yolo11n.pt")

# Open an image using PIL
source = Image.open("path/to/image.jpg")

# Run inference on the source
results = model(source)  # list of Results objects

Run inference on an image read with OpenCV.

import cv2

from ultralytics import YOLO

# Load a pretrained YOLO11n model
model = YOLO("yolo11n.pt")

# Read an image using OpenCV
source = cv2.imread("path/to/image.jpg")

# Run inference on the source
results = model(source)  # list of Results objects

Run inference on an image represented as a numpy array.

import numpy as np

from ultralytics import YOLO

# Load a pretrained YOLO11n model
model = YOLO("yolo11n.pt")

# Create a random numpy array of HWC shape (640, 640, 3) with values in range [0, 255] and type uint8
source = np.random.randint(low=0, high=256, size=(640, 640, 3), dtype="uint8")  # high is exclusive

# Run inference on the source
results = model(source)  # list of Results objects

Run inference on an image represented as a PyTorch tensor.

import torch

from ultralytics import YOLO

# Load a pretrained YOLO11n model
model = YOLO("yolo11n.pt")

# Create a random torch tensor of BCHW shape (1, 3, 640, 640) with values in range [0, 1] and type float32
source = torch.rand(1, 3, 640, 640, dtype=torch.float32)

# Run inference on the source
results = model(source)  # list of Results objects

Run inference on a collection of images, URLs, videos and directories listed in a CSV file.

from ultralytics import YOLO

# Load a pretrained YOLO11n model
model = YOLO("yolo11n.pt")

# Define a path to a CSV file with images, URLs, videos and directories
source = "path/to/file.csv"

# Run inference on the source
results = model(source)  # list of Results objects

Run inference on a video file. By using stream=True, you can create a generator of Results objects to reduce memory usage.

from ultralytics import YOLO

# Load a pretrained YOLO11n model
model = YOLO("yolo11n.pt")

# Define path to video file
source = "path/to/video.mp4"

# Run inference on the source
results = model(source, stream=True)  # generator of Results objects

Run inference on all images and videos in a directory. To also include images and videos in subdirectories, use a recursive glob pattern, e.g. path/to/dir/**/*.

from ultralytics import YOLO

# Load a pretrained YOLO11n model
model = YOLO("yolo11n.pt")

# Define path to directory containing images and videos for inference
source = "path/to/dir"

# Run inference on the source
results = model(source, stream=True)  # generator of Results objects

Run inference on all images and videos that match a glob expression with * characters.

from ultralytics import YOLO

# Load a pretrained YOLO11n model
model = YOLO("yolo11n.pt")

# Define a glob search for all JPG files in a directory
source = "path/to/dir/*.jpg"

# OR define a recursive glob search for all JPG files including subdirectories
source = "path/to/dir/**/*.jpg"

# Run inference on the source
results = model(source, stream=True)  # generator of Results objects
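
The difference between * and ** follows standard glob semantics. As an illustration only, the sketch below uses Python's built-in glob module with recursive=True to show which files the two patterns from the example would match in a small, hypothetical directory tree. It assumes Ultralytics resolves these patterns with the same recursive-glob semantics.

```python
import glob
import os
import tempfile

with tempfile.TemporaryDirectory() as root:
    # Build a small hypothetical tree: two JPGs at the top level, one in a subdirectory
    for rel in ["a.jpg", "b.jpg", os.path.join("sub", "c.jpg"), os.path.join("sub", "notes.txt")]:
        path = os.path.join(root, rel)
        os.makedirs(os.path.dirname(path), exist_ok=True)
        open(path, "w").close()

    # "path/to/dir/*.jpg": only files directly inside the directory
    flat = sorted(os.path.relpath(p, root) for p in glob.glob(os.path.join(root, "*.jpg")))

    # "path/to/dir/**/*.jpg": ** also matches zero or more nested directories
    recursive = sorted(
        os.path.relpath(p, root) for p in glob.glob(os.path.join(root, "**", "*.jpg"), recursive=True)
    )

print(flat)  # top-level JPGs only
print(recursive)  # JPGs at every depth, including the top level
```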

Run inference on a YouTube video. By using stream=True, you can create a generator of Results objects to reduce memory usage for long videos.

from ultralytics import YOLO

# Load a pretrained YOLO11n model
model = YOLO("yolo11n.pt")

# Define source as YouTube video URL
source = "https://youtu.be/LNwODJXcvt4"

# Run inference on the source
results = model(source, stream=True)  # generator of Results objects

Use the stream mode to run inference on live video streams using RTSP, RTMP, TCP, or IP address protocols. If a single stream is provided, the model runs inference with a batch size of 1. For multiple streams, a .streams text file can be used to perform batched inference, where the batch size is determined by the number of streams provided (e.g., batch-size 8 for 8 streams).

from ultralytics import YOLO

# Load a pretrained YOLO11n model
model = YOLO("yolo11n.pt")

# Single stream with batch-size 1 inference
source = "rtsp://example.com/media.mp4"  # RTSP, RTMP, TCP, or IP streaming address

# Run inference on the source
results = model(source, stream=True)  # generator of Results objects

For single stream usage, the batch size is set to 1 by default, allowing efficient real-time processing of the video feed.

To handle multiple video streams simultaneously, use a .streams text file containing the streaming sources. The model will run batched inference where the batch size equals the number of streams. This setup enables efficient processing of multiple feeds concurrently.

from ultralytics import YOLO

# Load a pretrained YOLO11n model
model = YOLO("yolo11n.pt")

# Multiple streams with batched inference (e.g., batch-size 8 for 8 streams)
source = "path/to/list.streams"  # *.streams text file with one streaming address per line

# Run inference on the source
results = model(source, stream=True)  # generator of Results objects

Example .streams text file:

rtsp://example.com/media1.mp4
rtsp://example.com/media2.mp4
rtmp://example2.com/live
tcp://192.168.1.100:554
...

Each row in the file represents a streaming source, allowing you to monitor and perform inference on several video streams at once.
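
Because the batch size equals the number of streams, it can help to check a .streams file before starting inference. The following is a minimal sketch using a hypothetical read_streams_file helper (not part of the Ultralytics API). It assumes one URL per line and ignores blank lines.

```python
import tempfile
from pathlib import Path


def read_streams_file(path: Path) -> list[str]:
    """Hypothetical helper: return the non-empty, stripped lines of a *.streams file."""
    lines = path.read_text().splitlines()
    return [line.strip() for line in lines if line.strip()]


with tempfile.TemporaryDirectory() as tmp:
    streams_file = Path(tmp) / "list.streams"
    streams_file.write_text(
        "rtsp://example.com/media1.mp4\n"
        "rtsp://example.com/media2.mp4\n"
        "\n"  # a blank line, skipped by the helper
        "rtmp://example2.com/live\n"
    )
    sources = read_streams_file(streams_file)

batch_size = len(sources)  # one batch slot per stream
print(f"{batch_size} streams -> batch size {batch_size}")
```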

You can run inference on a connected camera device by passing the index of that particular camera to source.

from ultralytics import YOLO

# Load a pretrained YOLO11n model
model = YOLO("yolo11n.pt")

# Run inference on the source
results = model(source=0, stream=True)  # generator of Results objects

Inference Arguments

model.predict() accepts multiple arguments that can be passed at inference time to override defaults:

Example

from ultralytics import YOLO

# Load a pretrained YOLO11n model
model = YOLO("yolo11n.pt")

# Run inference on 'bus.jpg' with arguments
model.predict("bus.jpg", save=True, imgsz=320, conf=0.5)

Inference arguments:

| Argument | Type | Default | Description |
|---|---|---|---|
| source | str | 'ultralytics/assets' | Specifies the data source for inference. Can be an image path, video file, directory, URL, or device ID for live feeds. Supports a wide range of formats and sources, enabling flexible application across different types of input. |
| conf | float | 0.25 | Sets the minimum confidence threshold for detections. Objects detected with confidence below this threshold are disregarded. Raising this value can help reduce false positives. |
| iou | float | 0.7 | Intersection over Union (IoU) threshold for Non-Maximum Suppression (NMS). Lower values result in fewer detections because overlapping boxes are suppressed more aggressively, which is useful for reducing duplicates. |
| imgsz | int or tuple | 640 | Defines the image size for inference. Can be a single integer (e.g. 640) for square resizing or a (height, width) tuple. Proper sizing can improve detection accuracy and processing speed. |
| half | bool | False | Enables half-precision (FP16) inference, which can speed up model inference on supported GPUs with minimal impact on accuracy. |
| device | str | None | Specifies the device for inference (e.g., cpu, cuda:0 or 0). Allows users to select between CPU, a specific GPU, or other compute devices for model execution. |
| max_det | int | 300 | Maximum number of detections allowed per image. Limits the total number of objects the model can detect in a single inference, preventing excessive outputs in dense scenes. |
| vid_stride | int | 1 | Frame stride for video inputs. Allows skipping frames in videos to speed up processing at the cost of temporal resolution. A value of 1 processes every frame; higher values skip frames. |
| stream_buffer | bool | False | Determines whether to queue incoming frames for video streams. If False, old frames are dropped to accommodate new frames (optimized for real-time applications). If True, new frames are queued in a buffer so that no frames are skipped, but latency builds up if inference FPS is lower than stream FPS. |
| visualize | bool | False | Activates visualization of model features during inference, providing insights into what the model is "seeing". Useful for debugging and model interpretation. |
| augment | bool | False | Enables test-time augmentation (TTA) for predictions, potentially improving detection robustness at the cost of inference speed. |
| agnostic_nms | bool | False | Enables class-agnostic Non-Maximum Suppression (NMS), which suppresses overlapping boxes even when they belong to different classes. Useful in multi-class detection scenarios where class overlap is common. |
| classes | list[int] | None | Filters predictions to a set of class IDs. Only detections belonging to the specified classes are returned. Useful for focusing on relevant objects in multi-class detection tasks. |
| retina_masks | bool | False | Returns high-resolution segmentation masks. If enabled, the returned masks (masks.data) match the original image size; if disabled, they have the image size used during inference. |
| embed | list[int] | None | Specifies the layers from which to extract feature vectors or embeddings. Useful for downstream tasks like clustering or similarity search. |
| project | str | None | Name of the project directory where prediction outputs are saved if save is enabled. |
| name | str | None | Name of the prediction run. Used for creating a subdirectory within the project folder, where prediction outputs are stored if save is enabled. |
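
The following sketch shows how conf, iou, classes, agnostic_nms, and max_det interact. It is a simplified, pure-Python version of confidence filtering followed by greedy NMS, written under these assumptions: boxes are in xyxy pixel coordinates, detections are plain (x1, y1, x2, y2, conf, cls) tuples, and a box is suppressed when its IoU with a kept box exceeds the threshold. It illustrates the general technique and is not Ultralytics' internal implementation, which operates on batched tensors.

```python
def iou(a, b):
    """Intersection over Union of two boxes given as (x1, y1, x2, y2, ...)."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def filter_and_nms(dets, conf=0.25, iou_thres=0.7, classes=None, agnostic=False, max_det=300):
    """Confidence/class filtering followed by greedy NMS on (x1, y1, x2, y2, conf, cls) tuples."""
    # 1. Drop low-confidence detections and, optionally, classes that were not requested
    dets = [d for d in dets if d[4] >= conf and (classes is None or d[5] in classes)]
    # 2. Greedy NMS: visit detections from most to least confident
    kept = []
    for d in sorted(dets, key=lambda d: d[4], reverse=True):
        if len(kept) >= max_det:
            break
        # Class-aware NMS only compares against kept boxes of the same class
        rivals = kept if agnostic else [k for k in kept if k[5] == d[5]]
        if all(iou(d, k) <= iou_thres for k in rivals):
            kept.append(d)
    return kept


detections = [
    (0, 0, 10, 10, 0.90, 0),
    (0, 0, 10, 11, 0.80, 0),  # near-duplicate of the first box, same class
    (0, 0, 10, 11, 0.85, 1),  # same region, different class
    (50, 50, 60, 60, 0.20, 0),  # below the default conf threshold
]
class_aware = filter_and_nms(detections)  # keeps one box per class
class_agnostic = filter_and_nms(detections, agnostic=True)  # keeps one box overall
```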

Visualization arguments:

| Argument | Type | Default | Description |
|---|---|---|---|
| show | bool | False | If True, displays the annotated images or videos in a window. Useful for immediate visual feedback during development or testing. |
| save | bool | False or True | Enables saving of the annotated images or videos to file. Useful for documentation, further analysis, or sharing results. Defaults to True when using the CLI and False when used in Python. |
| save_frames | bool | False | When processing videos, saves individual frames as images. Useful for extracting specific frames or for detailed frame-by-frame analysis. |
| save_txt | bool | False | Saves detection results in a text file, following the format [class] [x_center] [y_center] [width] [height] [confidence], with coordinates normalized to 0-1; [confidence] is written only when save_conf is enabled. Useful for integration with other analysis tools. |
| save_conf | bool | False | Includes confidence scores in the saved text files. Enhances the detail available for post-processing and analysis. |
| save_crop | bool | False | Saves cropped images of detections. Useful for dataset augmentation, analysis, or creating focused datasets for specific objects. |
| show_labels | bool | True | Displays labels for each detection in the visual output. Provides immediate understanding of detected objects. |
| show_conf | bool | True | Displays the confidence score for each detection alongside the label. Gives insight into the model's certainty for each detection. |
| show_boxes | bool | True | Draws bounding boxes around detected objects. Essential for visual identification and location of objects in images or video frames. |
| line_width | None or int | None | Specifies the line width of bounding boxes. If None, the line width is automatically adjusted based on the image size. Provides visual customization for clarity. |
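
The save_txt format uses normalized center/size coordinates. The sketch below produces one such line from a pixel-space xyxy box. The helper name to_txt_line and the six-decimal precision are assumptions made for illustration; they do not necessarily match what Ultralytics writes.

```python
def to_txt_line(cls, xyxy, img_w, img_h, conf=None):
    """Hypothetical helper: format one detection as 'class x_center y_center width height [conf]'."""
    x1, y1, x2, y2 = xyxy
    # Normalize center and size by the image dimensions so values fall in [0, 1]
    xc = (x1 + x2) / 2 / img_w
    yc = (y1 + y2) / 2 / img_h
    w = (x2 - x1) / img_w
    h = (y2 - y1) / img_h
    fields = [str(cls)] + [f"{v:.6f}" for v in (xc, yc, w, h)]
    if conf is not None:  # mirrors the effect of save_conf=True
        fields.append(f"{conf:.6f}")
    return " ".join(fields)


line = to_txt_line(0, (100, 50, 300, 250), img_w=640, img_h=480)
line_with_conf = to_txt_line(0, (100, 50, 300, 250), img_w=640, img_h=480, conf=0.87)
```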

Image and Video Formats

YOLO11 supports various image and video formats, as specified in ultralytics/data/utils.py. See the tables below for the valid suffixes and example predict commands.

Images

The table below lists the valid Ultralytics image formats.

Note

HEIC images are supported for inference only, not for training.

| Image Suffixes | Example Predict Command | Reference |
|---|---|---|
| .bmp | yolo predict source=image.bmp | Microsoft BMP File Format |
| .dng | yolo predict source=image.dng | Adobe DNG |
| .jpeg | yolo predict source=image.jpeg | JPEG |
| .jpg | yolo predict source=image.jpg | JPEG |
| .mpo | yolo predict source=image.mpo | Multi Picture Object |
| .png | yolo predict source=image.png | Portable Network Graphics |
| .tif | yolo predict source=image.tif | Tag Image File Format |
| .tiff | yolo predict source=image.tiff | Tag Image File Format |
| .webp | yolo predict source=image.webp | WebP |
| .pfm | yolo predict source=image.pfm | Portable FloatMap |
| .HEIC | yolo predict source=image.HEIC | High Efficiency Image Format |

Videos

The table below lists the valid Ultralytics video formats.

| Video Suffixes | Example Predict Command | Reference |
|---|---|---|
| .asf | yolo predict source=video.asf | Advanced Systems Format |
| .avi | yolo predict source=video.avi | Audio Video Interleave |
| .gif | yolo predict source=video.gif | Graphics Interchange Format |
| .m4v | yolo predict source=video.m4v | MPEG-4 Part 14 |
| .mkv | yolo predict source=video.mkv | Matroska |
| .mov | yolo predict source=video.mov | QuickTime File Format |
| .mp4 | yolo predict source=video.mp4 | MPEG-4 Part 14 |
| .mpeg | yolo predict source=video.mpeg | MPEG-1 Part 2 |
| .mpg | yolo predict source=video.mpg | MPEG-1 Part 2 |
| .ts | yolo predict source=video.ts | MPEG Transport Stream |
| .wmv | yolo predict source=video.wmv | Windows Media Video |
| .webm | yolo predict source=video.webm | WebM Project |
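
When preparing a directory or CSV of sources, it can be useful to check file suffixes up front. The sketch below classifies file names using suffix sets copied by hand from the two tables above. Ultralytics keeps its own lists in ultralytics/data/utils.py, and those may change between versions. Case-insensitive matching is an assumption of this sketch.

```python
from pathlib import Path

# Suffixes copied from the image and video tables above (lower-cased, without the dot)
IMAGE_SUFFIXES = {"bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm", "heic"}
VIDEO_SUFFIXES = {"asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv", "webm"}


def source_kind(filename: str) -> str:
    """Return 'image', 'video', or 'unsupported' based on the file suffix."""
    suffix = Path(filename).suffix.lower().lstrip(".")
    if suffix in IMAGE_SUFFIXES:
        return "image"
    if suffix in VIDEO_SUFFIXES:
        return "video"
    return "unsupported"


files = ["bus.jpg", "clip.MP4", "photo.HEIC", "notes.txt"]
kinds = {name: source_kind(name) for name in files}
```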

Working with Results

All Ultralytics predict() calls return a list of Results objects, or a generator of Results objects when stream=True:

Results

from ultralytics import YOLO

# Load a pretrained YOLO11n model
model = YOLO("yolo11n.pt")

# Run inference on an image
results = model("bus.jpg")  # list of 1 Results object
results = model(["bus.jpg", "zidane.jpg"])  # list of 2 Results objects

Results objects have the following attributes:

| Attribute | Type | Description |
|---|---|---|
| orig_img | numpy.ndarray | The original image as a numpy array. |
| orig_shape | tuple | The original image shape in (height, width) format. |
| boxes | Boxes, optional | A Boxes object containing the detection bounding boxes. |
| masks | Masks, optional | A Masks object containing the detection masks. |
| probs | Probs, optional | A Probs object containing probabilities of each class for classification task. |
| keypoints | Keypoints, optional | A Keypoints object containing detected keypoints for each object. |
| obb | OBB, optional | An OBB object containing oriented bounding boxes. |
| speed | dict | A dictionary of preprocess, inference, and postprocess speeds in milliseconds per image. |
| names | dict | A dictionary of class names. |
| path | str | The path to the image file. |
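
Because the speed attribute reports per-image timings in milliseconds, it can be used to estimate throughput. The sketch below uses a hypothetical speed dict with made-up timings; real values depend on hardware, model, and image size. It also assumes the three stages run back to back.

```python
# Hypothetical timings shaped like Results.speed (milliseconds per image)
speed = {"preprocess": 1.5, "inference": 6.0, "postprocess": 2.5}

total_ms = sum(speed.values())  # end-to-end latency per image
fps = 1000.0 / total_ms  # images per second if stages run back to back
```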

Results objects have the following methods:

| Method | Return Type | Description |
|---|---|---|
| update() | None | Update the boxes, masks, and probs attributes of the Results object. |
| cpu() | Results | Return a copy of the Results object with all tensors on CPU memory. |
| numpy() | Results | Return a copy of the Results object with all tensors as numpy arrays. |
| cuda() | Results | Return a copy of the Results object with all tensors on GPU memory. |
| to() | Results | Return a copy of the Results object with tensors on the specified device and dtype. |
| new() | Results | Return a new Results object with the same image, path, and names. |
| plot() | numpy.ndarray | Plots the detection results. Returns a numpy array of the annotated image. |
| show() | None | Show annotated results to screen. |
| save() | None | Save annotated results to file. |
| verbose() | str | Return log string for each task. |
| save_txt() | None | Save predictions into a txt file. |
| save_crop() | None | Save cropped predictions to save_dir/cls/file_name.jpg. |
| tojson() | str | Convert the object to JSON format. |

For more details see the Results class documentation.

Boxes

The Boxes object can be used to index, manipulate, and convert bounding boxes to different formats.

Boxes

from ultralytics import YOLO

# Load a pretrained YOLO11n model
model = YOLO("yolo11n.pt")

# Run inference on an image
results = model("bus.jpg")  # results list

# View results
for r in results:
    print(r.boxes)  # print the Boxes object containing the detection bounding boxes

Here is a table for the Boxes class methods and properties, including their name, type, and description:

| Name | Type | Description |
|---|---|---|
| cpu() | Method | Move the object to CPU memory. |
| numpy() | Method | Convert the object to a numpy array. |
| cuda() | Method | Move the object to CUDA memory. |
| to() | Method | Move the object to the specified device. |
| xyxy | Property (torch.Tensor) | Return the boxes in xyxy format. |
| conf | Property (torch.Tensor) | Return the confidence values of the boxes. |
| cls | Property (torch.Tensor) | Return the class values of the boxes. |
| id | Property (torch.Tensor) | Return the track IDs of the boxes (if available). |
| xywh | Property (torch.Tensor) | Return the boxes in xywh format. |
| xyxyn | Property (torch.Tensor) | Return the boxes in xyxy format normalized by original image size. |
| xywhn | Property (torch.Tensor) | Return the boxes in xywh format normalized by original image size. |
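
These box formats differ only in parameterization. xyxy stores the corner coordinates, xywh stores the center and size, and the n variants divide x values by the original image width and y values by its height. The pure-Python sketch below converts a single box; the function names are illustrative and are not Ultralytics APIs.

```python
def xyxy_to_xywh(x1, y1, x2, y2):
    """Corner format -> (x_center, y_center, width, height)."""
    return ((x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1)


def normalize(box, img_w, img_h):
    """Divide x-like values by width and y-like values by height (works for xyxy and xywh)."""
    x_a, y_a, x_b, y_b = box
    return (x_a / img_w, y_a / img_h, x_b / img_w, y_b / img_h)


xyxy = (100, 50, 300, 250)  # hypothetical box in pixels on a 640x480 image
xywh = xyxy_to_xywh(*xyxy)
xyxyn = normalize(xyxy, img_w=640, img_h=480)
xywhn = normalize(xywh, img_w=640, img_h=480)
```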

For more details see the Boxes class documentation.

Masks

The Masks object can be used to index, manipulate, and convert masks to segments.

Masks

from ultralytics import YOLO

# Load a pretrained YOLO11n-seg Segment model
model = YOLO("yolo11n-seg.pt")

# Run inference on an image
results = model("bus.jpg")  # results list

# View results
for r in results:
    print(r.masks)  # print the Masks object containing the detected instance masks

Here is a table for the Masks class methods and properties, including their name, type, and description:

| Name | Type | Description |
|---|---|---|
| cpu() | Method | Returns the masks tensor on CPU memory. |
| numpy() | Method | Returns the masks tensor as a numpy array. |
| cuda() | Method | Returns the masks tensor on GPU memory. |
| to() | Method | Returns the masks tensor with the specified device and dtype. |
| xyn | Property (torch.Tensor) | A list of normalized segments represented as tensors. |
| xy | Property (torch.Tensor) | A list of segments in pixel coordinates represented as tensors. |
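
Because xyn holds the same polygons as xy divided by the image size, one can be recovered from the other. The following sketch denormalizes a hypothetical segment, then applies the shoelace formula as one possible downstream use: it approximates the mask area in pixels by the area of its polygon. Both helpers are illustrative and not part of the Ultralytics API.

```python
def denormalize_polygon(points_n, img_w, img_h):
    """Scale normalized (x, y) polygon vertices back to pixel coordinates."""
    return [(x * img_w, y * img_h) for x, y in points_n]


def polygon_area(points):
    """Shoelace formula: area enclosed by a simple polygon given as ordered vertices."""
    n = len(points)
    s = 0.0
    for i in range(n):
        x1, y1 = points[i]
        x2, y2 = points[(i + 1) % n]  # wrap around to close the polygon
        s += x1 * y2 - x2 * y1
    return abs(s) / 2


# Hypothetical normalized segment: a rectangle spanning a quarter of each image dimension
segment_n = [(0.25, 0.25), (0.5, 0.25), (0.5, 0.5), (0.25, 0.5)]
segment_px = denormalize_polygon(segment_n, img_w=640, img_h=480)
area_px = polygon_area(segment_px)
```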

For more details see the Masks class documentation.

Keypoints

The Keypoints object can be used to index, manipulate, and normalize keypoint coordinates.

Keypoints

from ultralytics import YOLO

# Load a pretrained YOLO11n-pose Pose model
model = YOLO("yolo11n-pose.pt")

# Run inference on an image
results = model("bus.jpg")  # results list

# View results
for r in results:
    print(r.keypoints)  # print the Keypoints object containing the detected keypoints

Here is a table for the Keypoints class methods and properties, including their name, type, and description:

| Name | Type | Description |
|---|---|---|
| cpu() | Method | Returns the keypoints tensor on CPU memory. |
| numpy() | Method | Returns the keypoints tensor as a numpy array. |
| cuda() | Method | Returns the keypoints tensor on GPU memory. |
| to() | Method | Returns the keypoints tensor with the specified device and dtype. |
| xyn | Property (torch.Tensor) | A list of normalized keypoints represented as tensors. |
| xy | Property (torch.Tensor) | A list of keypoints in pixel coordinates represented as tensors. |
| conf | Property (torch.Tensor) | Returns confidence values of keypoints if available, else None. |
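
Keypoint confidences are commonly used to hide or ignore joints the model is unsure about. The sketch below assumes one person's keypoints are given as plain (x, y) pixel pairs with a parallel list of confidences. The 0.5 threshold is an arbitrary choice for illustration.

```python
def visible_keypoints(xy, conf, threshold=0.5):
    """Return (index, (x, y)) for keypoints whose confidence meets the threshold."""
    return [(i, point) for i, (point, c) in enumerate(zip(xy, conf)) if c >= threshold]


# Hypothetical keypoints for one person (e.g. nose, left eye, right eye)
person_xy = [(320.0, 110.0), (330.0, 100.0), (310.0, 100.0)]
person_conf = [0.95, 0.40, 0.80]

visible = visible_keypoints(person_xy, person_conf)
```

With real results, the per-person lists could come from r.keypoints.xy[i].tolist() and r.keypoints.conf[i].tolist(), provided conf is not None for the model in use.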

For more details see the Keypoints class documentation.

Probs

The Probs object can be used to index and to retrieve the top-1 and top-5 class indices and scores of a classification.

Probs

from ultralytics import YOLO

# Load a pretrained YOLO11n-cls Classify model
model = YOLO("yolo11n-cls.pt")

# Run inference on an image
results = model("bus.jpg")  # results list

# View results
for r in results:
    print(r.probs)  # print the Probs object containing the detected class probabilities

Here's a table summarizing the methods and properties for the Probs class:

| Name | Type | Description |
|---|---|---|
| cpu() | Method | Returns a copy of the probs tensor on CPU memory. |
| numpy() | Method | Returns a copy of the probs tensor as a numpy array. |
| cuda() | Method | Returns a copy of the probs tensor on GPU memory. |
| to() | Method | Returns a copy of the probs tensor with the specified device and dtype. |
| top1 | Property (int) | Index of the top 1 class. |
| top5 | Property (list[int]) | Indices of the top 5 classes. |
| top1conf | Property (torch.Tensor) | Confidence of the top 1 class. |
| top5conf | Property (torch.Tensor) | Confidences of the top 5 classes. |
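
top1 and top5 are the indices of the highest class probabilities. The stdlib sketch below computes them for a hypothetical probability vector, then maps the indices to labels through a dict shaped like Results.names. The values and class names are made up.

```python
import heapq

# Hypothetical class probabilities (index -> probability) and class names
probs = [0.03, 0.10, 0.55, 0.05, 0.20, 0.03, 0.04]
names = {0: "cat", 1: "dog", 2: "bus", 3: "car", 4: "truck", 5: "bike", 6: "train"}

top5 = heapq.nlargest(5, range(len(probs)), key=probs.__getitem__)  # indices, best first
top1 = top5[0]
top5conf = [probs[i] for i in top5]
labels = [names[i] for i in top5]
```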

For more details see the Probs class documentation.

OBB

The OBB object can be used to index, manipulate, and convert oriented bounding boxes to different formats.

OBB

from ultralytics import YOLO

# Load a pretrained YOLO11n-obb OBB model
model = YOLO("yolo11n-obb.pt")

# Run inference on an image
results = model("boats.jpg")  # results list

# View results
for r in results:
    print(r.obb)  # print the OBB object containing the oriented detection bounding boxes

Here is a table for the OBB class methods and properties, including their name, type, and description:

| Name | Type | Description |
|---|---|---|
| cpu() | Method | Move the object to CPU memory. |
| numpy() | Method | Convert the object to a numpy array. |
| cuda() | Method | Move the object to CUDA memory. |
| to() | Method | Move the object to the specified device. |
| conf | Property (torch.Tensor) | Return the confidence values of the boxes. |
| cls | Property (torch.Tensor) | Return the class values of the boxes. |
| id | Property (torch.Tensor) | Return the track IDs of the boxes (if available). |
| xyxy | Property (torch.Tensor) | Return the horizontal boxes in xyxy format. |
| xywhr | Property (torch.Tensor) | Return the rotated boxes in xywhr format. |
| xyxyxyxy | Property (torch.Tensor) | Return the rotated boxes in xyxyxyxy format. |
| xyxyxyxyn | Property (torch.Tensor) | Return the rotated boxes in xyxyxyxy format normalized by image size. |
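
The xywhr and xyxyxyxy formats describe the same rotated rectangle. xywhr gives the center, size, and rotation angle r in radians, and xyxyxyxy gives the four corner points. The sketch below converts between them with plain trigonometry. The corner ordering and angle convention are assumptions, and the horizontal xyxy box is taken as the axis-aligned bounds of the corners.

```python
import math


def xywhr_to_corners(cx, cy, w, h, r):
    """Rotate the four corners of a w x h box by r radians around its center (cx, cy)."""
    cos_r, sin_r = math.cos(r), math.sin(r)
    # Half-extent vectors along the box's own width and height axes
    vx, vy = w / 2 * cos_r, w / 2 * sin_r
    ux, uy = -h / 2 * sin_r, h / 2 * cos_r
    return [
        (cx + vx + ux, cy + vy + uy),
        (cx + vx - ux, cy + vy - uy),
        (cx - vx - ux, cy - vy - uy),
        (cx - vx + ux, cy - vy + uy),
    ]


def corners_to_xyxy(corners):
    """Axis-aligned bounding box enclosing the rotated corners."""
    xs = [x for x, _ in corners]
    ys = [y for _, y in corners]
    return (min(xs), min(ys), max(xs), max(ys))


upright = xywhr_to_corners(10, 20, 4, 2, 0.0)
rotated = xywhr_to_corners(10, 20, 4, 2, math.pi / 2)  # a quarter turn swaps the extents
```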

For more details see the OBB class documentation.

Plotting Results

The plot() method in Results objects facilitates visualization of predictions by overlaying detected objects (such as bounding boxes, masks, keypoints, and probabilities) onto the original image. This method returns the annotated image as a NumPy array, allowing for easy display or saving.

Plotting

from PIL import Image

from ultralytics import YOLO

# Load a pretrained YOLO11n model
model = YOLO("yolo11n.pt")

# Run inference on 'bus.jpg' and 'zidane.jpg'
results = model(["bus.jpg", "zidane.jpg"])  # results list

# Visualize the results
for i, r in enumerate(results):
    # Plot results image
    im_bgr = r.plot()  # BGR-order numpy array
    im_rgb = Image.fromarray(im_bgr[..., ::-1])  # RGB-order PIL image

    # Show results to screen (in supported environments)
    r.show()

    # Save results to disk
    r.save(filename=f"results{i}.jpg")

plot() Method Parameters

The plot() method supports various arguments to customize the output:

| Argument | Type | Description | Default |
|---|---|---|---|
| conf | bool | Include detection confidence scores. | True |
| line_width | float | Line width of bounding boxes. Scales with image size if None. | None |
| font_size | float | Text font size. Scales with image size if None. | None |
| font | str | Font name for text annotations. | 'Arial.ttf' |
| pil | bool | Return image as a PIL Image object. | False |
| img | numpy.ndarray | Alternative image for plotting. Uses the original image if None. | None |
| im_gpu | torch.Tensor | GPU-accelerated image for faster mask plotting. Shape: (1, 3, 640, 640). | None |
| kpt_radius | int | Radius for drawn keypoints. | 5 |
| kpt_line | bool | Connect keypoints with lines. | True |
| labels | bool | Include class labels in annotations. | True |
| boxes | bool | Overlay bounding boxes on the image. | True |
| masks | bool | Overlay masks on the image. | True |
| probs | bool | Include classification probabilities. | True |
| show | bool | Display the annotated image directly using the default image viewer. | False |
| save | bool | Save the annotated image to a file specified by filename. | False |
| filename | str | Path and name of the file to save the annotated image if save is True. | None |
| color_mode | str | Specify the color mode, e.g., 'instance' or 'class'. | 'class' |

Thread-Safe Inference

Ensuring thread safety during inference is crucial when you are running multiple YOLO models in parallel across different threads. Thread-safe inference guarantees that each thread's predictions are isolated and do not interfere with one another, avoiding race conditions and ensuring consistent and reliable outputs.

When using YOLO models in a multi-threaded application, it's important to instantiate separate model objects for each thread or employ thread-local storage to prevent conflicts:

Thread-Safe Inference

Instantiate a single model inside each thread for thread-safe inference:

from threading import Thread

from ultralytics import YOLO


def thread_safe_predict(model_path, image_path):
    """Performs thread-safe prediction on an image using a locally instantiated YOLO model."""
    model = YOLO(model_path)  # each thread builds its own model instance
    results = model.predict(image_path)
    # Process results here (e.g., inspect results[0].boxes)


# Start threads that each have their own model instance, then wait for them to finish
threads = [
    Thread(target=thread_safe_predict, args=("yolo11n.pt", "image1.jpg")),
    Thread(target=thread_safe_predict, args=("yolo11n.pt", "image2.jpg")),
]
for t in threads:
    t.start()
for t in threads:
    t.join()

For an in-depth look at thread-safe inference with YOLO models, including step-by-step instructions and common pitfalls to avoid, see the YOLO Thread-Safe Inference Guide.
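
The thread-local storage option mentioned above can be sketched with Python's threading.local. Each worker thread lazily creates its own model on first use and then reuses it. So that the sketch runs without downloading weights, DummyModel is a hypothetical stand-in; in practice, get_model() would construct YOLO("yolo11n.pt") instead.

```python
import threading
from concurrent.futures import ThreadPoolExecutor

_local = threading.local()  # each thread sees its own attributes on this object


class DummyModel:
    """Hypothetical stand-in for YOLO("yolo11n.pt") so the sketch runs anywhere."""

    def predict(self, image_path):
        return f"prediction for {image_path}"


def get_model():
    """Create one model per thread on first use, then reuse it within that thread."""
    if not hasattr(_local, "model"):
        _local.model = DummyModel()
    return _local.model


def worker(image_path):
    model = get_model()
    # Return the thread ID and model identity too, so per-thread reuse can be checked
    return threading.get_ident(), id(model), model.predict(image_path)


images = [f"image{i}.jpg" for i in range(8)]
with ThreadPoolExecutor(max_workers=3) as pool:
    outputs = list(pool.map(worker, images))  # results come back in input order
```

Compared with constructing a new model on every call, this pattern loads a model once per worker thread while still never sharing an instance between threads.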

Streaming Source for-loop

Here's a Python script using OpenCV (cv2) and YOLO to run inference on video frames. This script assumes you have already installed the necessary packages (opencv-python and ultralytics).

Streaming for-loop

import cv2

from ultralytics import YOLO

# Load the YOLO model
model = YOLO("yolo11n.pt")

# Open the video file
video_path = "path/to/your/video/file.mp4"
cap = cv2.VideoCapture(video_path)

# Loop through the video frames
while cap.isOpened():
    # Read a frame from the video
    success, frame = cap.read()

    if success:
        # Run YOLO inference on the frame
        results = model(frame)

        # Visualize the results on the frame
        annotated_frame = results[0].plot()

        # Display the annotated frame
        cv2.imshow("YOLO Inference", annotated_frame)

        # Break the loop if 'q' is pressed
        if cv2.waitKey(1) & 0xFF == ord("q"):
            break
    else:
        # Break the loop if the end of the video is reached
        break

# Release the video capture object and close the display window
cap.release()
cv2.destroyAllWindows()

This script will run predictions on each frame of the video, visualize the results, and display them in a window. The loop can be exited by pressing 'q'.
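
The loop above processes every frame. To approximate the vid_stride argument in such a manual loop, pass only every n-th frame to the model. The stdlib sketch below shows the selection logic on a stand-in sequence of frame indices; the strided helper is hypothetical.

```python
from itertools import islice


def strided(frames, stride):
    """Return an iterator over every `stride`-th item, starting with the first (stride=1 keeps all)."""
    if stride < 1:
        raise ValueError("stride must be >= 1")
    return islice(frames, 0, None, stride)


# Stand-in for decoded video frames: 10 frame indices
selected = list(strided(range(10), 3))
```

With OpenCV, frames you intend to skip can be advanced with cap.grab(), which moves to the next frame without returning an image and is typically cheaper than cap.read(). The frames you keep can then be fetched with cap.retrieve().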

FAQ

What is Ultralytics YOLO and its predict mode for real-time inference?

Ultralytics YOLO is a state-of-the-art model for real-time object detection, segmentation, and classification. Its predict mode allows users to perform high-speed inference on various data sources such as images, videos, and live streams. Designed for performance and versatility, it also offers batch processing and streaming modes. For more details on its features, check out the Ultralytics YOLO predict mode.

How can I run inference using Ultralytics YOLO on different data sources?

Ultralytics YOLO can process a wide range of data sources, including individual images, videos, directories, URLs, and streams. You can specify the data source in the model.predict() call. For example, use 'image.jpg' for a local image or 'https://ultralytics.com/images/bus.jpg' for a URL. Check out the detailed examples for various inference sources in the documentation.

How do I optimize YOLO inference speed and memory usage?

To optimize inference speed and manage memory efficiently, you can use the streaming mode by setting stream=True in the predictor's call method. The streaming mode generates a memory-efficient generator of Results objects instead of loading all frames into memory. For processing long videos or large datasets, streaming mode is particularly useful. Learn more about streaming mode.

What inference arguments does Ultralytics YOLO support?

The model.predict() method in YOLO supports various arguments such as conf, iou, imgsz, device, and more. These arguments allow you to customize the inference process, setting parameters like confidence thresholds, image size, and the device used for computation. Detailed descriptions of these arguments can be found in the inference arguments section.

How can I visualize and save the results of YOLO predictions?

After running inference with YOLO, the Results objects contain methods for displaying and saving annotated images. You can use methods like result.show() and result.save(filename="result.jpg") to visualize and save the results. For a comprehensive list of these methods, refer to the working with results section.
