
TensorRT Export for YOLOv8 Models

Deploying computer vision models in high-performance environments can require a format that maximizes speed and efficiency. This is especially true when you are deploying your model on NVIDIA GPUs.

By using the TensorRT export format, you can enhance your Ultralytics YOLOv8 models for swift and efficient inference on NVIDIA hardware. This guide will give you easy-to-follow steps for the conversion process and help you make the most of NVIDIA's advanced technology in your deep learning projects.


TensorRT Overview

TensorRT, developed by NVIDIA, is an advanced software development kit (SDK) designed for high-speed deep learning inference. It's well-suited for real-time applications like object detection.

This toolkit optimizes deep learning models for NVIDIA GPUs, resulting in faster and more efficient inference. The TensorRT optimization process includes techniques such as layer fusion, precision calibration (INT8 and FP16), dynamic tensor memory management, and kernel auto-tuning. Converting deep learning models to the TensorRT format allows developers to fully realize the potential of NVIDIA GPUs.

TensorRT is compatible with models from various frameworks, including TensorFlow and PyTorch, as well as the ONNX format, giving developers a flexible way to integrate and optimize models from different sources. This versatility enables efficient model deployment across diverse hardware and software environments.

Key Features of TensorRT Models

TensorRT models offer a range of key features that contribute to their efficiency and effectiveness in high-speed deep learning inference:

  • Precision Calibration: TensorRT supports precision calibration, allowing a model's numerical precision to be tuned to specific accuracy requirements. This includes support for reduced-precision formats like INT8 and FP16, which can further boost inference speed while maintaining acceptable accuracy.

  • Layer Fusion: The TensorRT optimization process includes layer fusion, where multiple layers of a neural network are combined into a single operation. This reduces computational overhead and improves inference speed by minimizing memory access and computation.


  • Dynamic Tensor Memory Management: TensorRT efficiently manages tensor memory usage during inference, reducing memory overhead and optimizing memory allocation. This results in more efficient GPU memory utilization.

  • Automatic Kernel Tuning: TensorRT applies automatic kernel tuning to select the most optimized GPU kernel for each layer of the model. This adaptive approach ensures that the model takes full advantage of the GPU's computational power.
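
Layer fusion is easiest to see with a small numeric case. The sketch below folds a batch-normalization step into the preceding convolution-like weights and bias, a common fusion pattern, so that one operation replaces two. It uses plain Python lists for a single output channel and hypothetical parameter values; it illustrates the idea only and is not TensorRT's internal implementation.

```python
import math


def conv_then_bn(x, w, b, gamma, beta, mean, var, eps=1e-5):
    """Unfused path: a dot product plus bias (one conv output), then batch norm."""
    y = sum(xi * wi for xi, wi in zip(x, w)) + b
    return gamma * (y - mean) / math.sqrt(var + eps) + beta


def fuse_conv_bn(w, b, gamma, beta, mean, var, eps=1e-5):
    """Fold the batch-norm parameters into new weights and a new bias."""
    scale = gamma / math.sqrt(var + eps)
    fused_w = [wi * scale for wi in w]
    fused_b = (b - mean) * scale + beta
    return fused_w, fused_b


def fused_conv(x, fused_w, fused_b):
    """Fused path: a single dot product plus bias, with no separate normalization op."""
    return sum(xi * wi for xi, wi in zip(x, fused_w)) + fused_b


# Hypothetical values for one output channel with a 3-element receptive field
x = [0.5, -1.0, 2.0]
w = [0.2, 0.4, -0.1]
b = 0.05
gamma, beta, mean, var = 1.5, -0.3, 0.1, 0.8

unfused = conv_then_bn(x, w, b, gamma, beta, mean, var)
fw, fb = fuse_conv_bn(w, b, gamma, beta, mean, var)
fused = fused_conv(x, fw, fb)
print(f"unfused={unfused:.6f} fused={fused:.6f}")
```

Both paths produce the same value, but the fused path reads the weights once and skips the extra normalization pass, which is where the memory-access savings come from.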

Deployment Options in TensorRT

Before we look at the code for exporting YOLOv8 models to the TensorRT format, let's understand where TensorRT models are normally used.

TensorRT offers several deployment options, and each option balances ease of integration, performance optimization, and flexibility differently:

  • Deploying within TensorFlow: This method (TF-TRT) integrates TensorRT into TensorFlow, allowing optimized models to run in a familiar TensorFlow environment. It's useful for models with a mix of supported and unsupported layers, because TF-TRT optimizes the supported parts with TensorRT and leaves the unsupported layers to run in TensorFlow.


  • Standalone TensorRT Runtime API: Offers granular control, ideal for performance-critical applications. It's more complex but allows for custom implementation of unsupported operators.

  • NVIDIA Triton Inference Server: An option that supports models from various frameworks. Particularly suited for cloud or edge inference, it provides features like concurrent model execution and model analysis.

Exporting YOLOv8 Models to TensorRT

You can improve execution efficiency and optimize performance by converting YOLOv8 models to TensorRT format.

Installation

To install the required package, run:


# Install the required package for YOLOv8
pip install ultralytics

For detailed instructions and best practices related to the installation process, check our YOLOv8 Installation guide. While installing the required packages for YOLOv8, if you encounter any difficulties, consult our Common Issues guide for solutions and tips.

Usage

Before diving into the usage instructions, be sure to check out the range of YOLOv8 models offered by Ultralytics. This will help you choose the most appropriate model for your project requirements.


from ultralytics import YOLO

# Load the YOLOv8 model
model = YOLO('yolov8n.pt')

# Export the model to TensorRT format
model.export(format='engine')  # creates 'yolov8n.engine'

# Load the exported TensorRT model
tensorrt_model = YOLO('yolov8n.engine')

# Run inference
results = tensorrt_model('https://ultralytics.com/images/bus.jpg')

The equivalent CLI commands are:

# Export a YOLOv8n PyTorch model to TensorRT format
yolo export model=yolov8n.pt format=engine  # creates 'yolov8n.engine'

# Run inference with the exported model
yolo predict model=yolov8n.engine source='https://ultralytics.com/images/bus.jpg'

For more details about the export process, visit the Ultralytics documentation page on exporting.

Exporting TensorRT with INT8 Quantization

Exporting Ultralytics YOLO models using TensorRT with INT8 precision performs post-training quantization (PTQ). For PTQ, TensorRT uses calibration: it runs the YOLO model on representative input data, measures the distribution of activations within each activation tensor, and then uses that distribution to estimate a scale value for each tensor. Every activation tensor that is a candidate for quantization has an associated scale deduced by this calibration process.
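
As a simplified illustration of how a per-tensor scale can be derived from observed activations, the sketch below uses symmetric max-abs (min-max) calibration. This is deliberately a simpler technique than the entropy calibration TensorRT uses for these exports, and the activation values are made up.

```python
def max_abs_scale(activations, num_bits=8):
    """Symmetric scale so the largest observed |value| maps to the INT8 max (127)."""
    qmax = 2 ** (num_bits - 1) - 1
    amax = max(abs(a) for a in activations)
    return amax / qmax if amax > 0 else 1.0


def quantize(value, scale, num_bits=8):
    """Round to the nearest integer step and clamp to the symmetric range [-127, 127]."""
    qmax = 2 ** (num_bits - 1) - 1
    q = round(value / scale)
    return max(-qmax, min(qmax, q))


def dequantize(q, scale):
    """Map an integer code back to an approximate real value."""
    return q * scale


# Hypothetical activations observed for one tensor while running calibration images
calib_acts = [0.02, -1.7, 3.1, 0.45, -2.2, 1.05, 2.54]
scale = max_abs_scale(calib_acts)
q_vals = [quantize(a, scale) for a in calib_acts]
recovered = [dequantize(q, scale) for q in q_vals]
print(f"scale={scale:.5f}", q_vals)
```

Max-abs calibration is sensitive to rare outliers, which stretch the scale and waste integer steps; entropy-based calibration instead chooses a clipping threshold that better preserves the overall distribution, which is one reason representative calibration data matters.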

When processing implicitly quantized networks, TensorRT uses INT8 opportunistically to optimize layer execution time. If a layer runs faster in INT8 and has quantization scales assigned to its data inputs and outputs, a kernel with INT8 precision is assigned to that layer. Otherwise, TensorRT selects either FP32 or FP16 precision for the kernel, whichever results in faster execution for that layer.
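
The per-layer decision described above can be sketched as a simple rule: INT8 is a candidate only when the layer has quantization scales on its inputs and outputs, and the fastest remaining candidate wins. The layer names and timings below are hypothetical; the real TensorRT builder measures candidate kernels on the target GPU rather than reading a table.

```python
def choose_precision(timings_ms, has_int8_scales):
    """Pick the fastest allowed precision for one layer.

    timings_ms: maps a precision ("FP32", "FP16", "INT8") to a kernel time in ms.
    has_int8_scales: True if the layer's inputs and outputs have calibration scales.
    """
    candidates = {
        p: t for p, t in timings_ms.items() if p != "INT8" or has_int8_scales
    }
    return min(candidates, key=candidates.get)


# Hypothetical per-layer kernel timings (ms) and whether INT8 scales are available
layers = {
    "conv_1": ({"FP32": 0.40, "FP16": 0.22, "INT8": 0.15}, True),
    "conv_2": ({"FP32": 0.35, "FP16": 0.18, "INT8": 0.12}, False),  # no scales: INT8 not allowed
    "concat": ({"FP32": 0.05, "FP16": 0.03, "INT8": 0.04}, True),   # INT8 slower: FP16 wins
}

plan = {name: choose_precision(t, scales) for name, (t, scales) in layers.items()}
print(plan)
```

This is why an "INT8" engine is usually mixed precision: layers without scales, or layers where INT8 is not faster, stay in FP16 or FP32.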

Tip

Export with INT8 precision on the same device that will run the TensorRT model in deployment, because calibration results can vary across devices.

Configuring INT8 Export

The arguments provided when using export for an Ultralytics YOLO model can significantly affect the performance of the exported model. They should also be chosen based on the available device resources; however, the default arguments should work for most Ampere (or newer) NVIDIA discrete GPUs. The calibration algorithm used is "ENTROPY_CALIBRATION_2", and you can read more about the available options in the TensorRT Developer Guide. Ultralytics testing found that "ENTROPY_CALIBRATION_2" was the best choice, so exports are fixed to this algorithm.

  • workspace : Controls the size (in GiB) of the device memory allocation while converting the model weights.

    • Aim to use the minimum workspace value required, as this prevents the TensorRT builder from considering algorithms that need more workspace. Setting a higher value for workspace may make calibration and export take considerably longer.

    • The default is workspace=4 (GiB). This value may need to be increased if calibration crashes (exits without warning).

    • TensorRT will report UNSUPPORTED_STATE during export if the value for workspace is larger than the memory available to the device, in which case workspace should be lowered.

    • If workspace is set to its maximum value and calibration still fails or crashes, consider reducing imgsz and batch to lower memory requirements.

    • Remember that INT8 calibration is specific to each device: borrowing a "high-end" GPU for calibration might result in poor performance when inference is run on another device.

  • batch : The maximum batch size that will be used for inference. Smaller batches can be used during inference, but inference will not accept batches larger than the specified value.

    Note

    During calibration, twice the provided batch size is used. Small batches can lead to inaccurate scaling during calibration, because the process adjusts based on the data it sees and small batches might not capture the full range of values; to avoid this, the batch size is doubled automatically. If no batch size is specified (batch=1), calibration runs at batch = 1 × 2 = 2 to reduce calibration scaling errors.
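
The batch and workspace guidance above can be condensed into a small pre-flight check. The helper below is hypothetical and not part of the Ultralytics API; it only mirrors the rules stated in this section: calibration runs at twice the requested batch, and a workspace larger than the device memory should be lowered.

```python
def int8_export_plan(batch=1, workspace_gib=4, device_mem_gib=None):
    """Hypothetical pre-flight check for an INT8 TensorRT export.

    Returns the max inference batch, the calibration batch (2 x batch),
    and any warnings derived from the workspace guidance.
    """
    if batch < 1:
        raise ValueError("batch must be >= 1")
    warnings = []
    if device_mem_gib is not None and workspace_gib > device_mem_gib:
        warnings.append(
            f"workspace={workspace_gib} GiB exceeds device memory ({device_mem_gib} GiB); "
            "expect UNSUPPORTED_STATE and lower workspace"
        )
    return {"max_batch": batch, "calibration_batch": 2 * batch, "warnings": warnings}


print(int8_export_plan(batch=8, workspace_gib=4, device_mem_gib=24))
```

The same batch and workspace values would then be passed to model.export(format="engine", int8=True, batch=..., workspace=..., data=...).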

Experimentation by NVIDIA led them to recommend using at least 500 calibration images that are representative of your model's data for INT8 quantization calibration. This is a guideline rather than a hard requirement, and you will need to experiment to find what performs well for your dataset. Because TensorRT requires calibration data for INT8 calibration, make sure to pass the data argument when int8=True, for example data="my_dataset.yaml"; the images from the dataset's validation split are then used for calibration. If no value is passed for data when exporting to TensorRT with INT8 quantization, the export falls back to one of the "small" example datasets for the model's task instead of raising an error.
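
Before calibrating, it can help to confirm that the validation split referenced by your dataset YAML holds enough images. The helper below is a hypothetical, standard-library-only check that counts image files in a directory and compares the count with NVIDIA's suggested minimum of 500; the directory path and the extension list are assumptions to adapt to your dataset.

```python
from pathlib import Path

# Assumed set of image extensions; adjust for your dataset
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp"}


def count_calibration_images(image_dir, recommended=500):
    """Count image files (recursively) under image_dir.

    Returns (count, meets_guideline), where meets_guideline is True when the
    count reaches the recommended minimum.
    """
    n = sum(
        1
        for p in Path(image_dir).rglob("*")
        if p.is_file() and p.suffix.lower() in IMAGE_EXTS
    )
    return n, n >= recommended
```

For example, count_calibration_images("datasets/my_dataset/images/val") (a hypothetical path) returns the number of images and whether it meets the 500-image guideline.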

Example

from ultralytics import YOLO

model = YOLO("yolov8n.pt")
model.export(
    format="engine",
    dynamic=True, #(1)!
    batch=8, #(2)!
    workspace=4, #(3)!
    int8=True,
    data="coco.yaml", #(4)!
)

model = YOLO("yolov8n.engine", task="detect")  # load the model

  1. Exports with dynamic axes. This is enabled by default when exporting with int8=True, even when not explicitly set. See export arguments for additional information.
  2. Sets a max batch size of 8 for the exported model, which calibrates with batch = 2 × 8 = 16 to avoid scaling errors during calibration.
  3. Allocates 4 GiB of memory instead of allocating the entire device for the conversion process.
  4. Uses the COCO dataset for calibration, specifically the images used for validation (5,000 total).

Calibration Cache

TensorRT generates a calibration .cache file that can be reused to speed up export of future model weights using the same data. However, reusing it may result in poor calibration when the data is vastly different or the batch value changes drastically. In these circumstances, the existing .cache should be renamed and moved to a different directory, or deleted entirely.
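
If you need to discard a stale cache, moving it aside (rather than deleting it) keeps it available for comparison. The sketch below is hypothetical: the cache's location and file name depend on your export, so pass the path of the .cache file you actually find, and note that the archive folder name is an arbitrary choice.

```python
import shutil
from pathlib import Path


def archive_calibration_cache(cache_path, archive_dir="old_calibration_caches"):
    """Move an existing calibration .cache into a sibling archive folder so it is not reused.

    Returns the new path, or None if no cache file exists at cache_path.
    An earlier archive with the same name is overwritten.
    """
    cache = Path(cache_path)
    if not cache.is_file():
        return None
    dest_dir = cache.parent / archive_dir
    dest_dir.mkdir(parents=True, exist_ok=True)
    dest = dest_dir / f"{cache.stem}.old{cache.suffix}"  # e.g. yolov8n.old.cache
    return Path(shutil.move(str(cache), str(dest)))
```

Moving the file out of its original location means the next INT8 export cannot pick it up, so calibration runs again on the new data.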

Advantages of using YOLO with TensorRT INT8

  • Reduced model size: Quantization from FP32 to INT8 can reduce the model size by up to 4x (on disk or in memory), leading to faster download times, lower storage requirements, and a reduced memory footprint when deploying a model.

  • Lower power consumption: Reduced-precision operations in INT8-exported YOLO models can consume less power than FP32 models, which is especially valuable for battery-powered devices.

  • Improved inference speeds: TensorRT optimizes the model for the target hardware, potentially leading to faster inference speeds on GPUs, embedded devices, and accelerators.
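
The 4x figure follows from bytes per value: FP32 stores each weight in 4 bytes, FP16 in 2, and INT8 in 1. The sketch below applies that to a hypothetical parameter count; real engine files also contain metadata and may keep some layers at higher precision, so actual savings are usually somewhat smaller.

```python
BYTES_PER_VALUE = {"FP32": 4, "FP16": 2, "INT8": 1}


def weight_storage_mib(num_params, precision):
    """Approximate raw weight storage in MiB (ignores metadata and mixed precision)."""
    return num_params * BYTES_PER_VALUE[precision] / 2**20


params = 3_000_000  # hypothetical parameter count
for p in ("FP32", "FP16", "INT8"):
    print(p, round(weight_storage_mib(params, p), 2), "MiB")
```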

Note on Inference Speeds

The first few inference calls with a model exported to TensorRT INT8 can be expected to have longer than usual preprocessing, inference, and/or postprocessing times. This may also occur when imgsz changes during inference, especially when imgsz differs from the value specified during export (the export imgsz is set as the TensorRT "optimal" profile).
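
When benchmarking, a common practice is to run a few warm-up calls and exclude them from the statistics. The sketch below times any callable with time.perf_counter; the default numbers of warm-up and timed calls are assumptions, and in practice you would wrap a call such as model.predict(...) on the exported engine.

```python
import time


def time_calls(fn, n_warmup=5, n_timed=50):
    """Call fn n_warmup times without timing, then return per-call latencies in ms."""
    for _ in range(n_warmup):
        fn()  # warm-up: absorbs one-time setup costs
    latencies = []
    for _ in range(n_timed):
        start = time.perf_counter()
        fn()
        latencies.append((time.perf_counter() - start) * 1000.0)
    return latencies
```

For example, time_calls(lambda: model.predict([img] * 8, verbose=False)) returns only steady-state timings. If you change imgsz, repeat the warm-up for the new size.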

Drawbacks of using YOLO with TensorRT INT8

  • Decreases in evaluation metrics: Using a lower precision means that mAP, Precision, Recall, or any other metric used to evaluate model performance is likely to be somewhat worse. See the Performance results section to compare the differences in mAP50 and mAP50-95 when exporting with INT8 on a small sample of devices.

  • Increased development times: Finding the "optimal" INT8 calibration settings for a given dataset and device can take a significant amount of testing.

  • Hardware dependency: Calibration and performance gains can be highly hardware-dependent, and the exported model weights are less transferable between devices.
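
To put the trade-off in numbers, the snippet below computes the relative mAP50-95 change and the speedup for the NVIDIA A100 detection results reported in this guide (FP32 vs INT8, yolov8n.engine). It is only arithmetic on those reported values, not a new measurement.

```python
def relative_change(baseline, new):
    """Fractional change from baseline to new (negative means a decrease)."""
    return (new - baseline) / baseline


# NVIDIA A100 detection values reported in this guide
fp32_map50_95, int8_map50_95 = 0.37, 0.33
fp32_predict_ms, int8_predict_ms = 0.52, 0.28  # mean predict time, batch=8

map_change = relative_change(fp32_map50_95, int8_map50_95)
speedup = fp32_predict_ms / int8_predict_ms
print(f"mAP50-95 change: {map_change:.1%}, speedup: {speedup:.2f}x")
```

On those numbers, INT8 gives up roughly 10.8% of mAP50-95 (relative) for a mean predict time about 1.86x faster than FP32.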

Ultralytics YOLO TensorRT Export Performance

NVIDIA A100

Performance

Tested with Ubuntu 22.04.3 LTS, python 3.10.12, ultralytics==8.2.4, tensorrt==8.6.1.post1

See Detection Docs for usage examples with these models trained on COCO, which include 80 pre-trained classes.

Note

Inference times shown for mean, min (fastest), and max (slowest) for each test using pre-trained weights yolov8n.engine

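| Precision | Eval test | mean (ms) | min (ms) | max (ms) | mAPval 50(B) | mAPval 50-95(B) | batch | size (pixels) |
|---|---|---|---|---|---|---|---|---|
| FP32 | Predict | 0.52 | 0.51 | 0.56 | - | - | 8 | 640 |
| FP32 | COCOval | 0.52 | - | - | 0.52 | 0.37 | 1 | 640 |
| FP16 | Predict | 0.34 | 0.34 | 0.41 | - | - | 8 | 640 |
| FP16 | COCOval | 0.33 | - | - | 0.52 | 0.37 | 1 | 640 |
| INT8 | Predict | 0.28 | 0.27 | 0.31 | - | - | 8 | 640 |
| INT8 | COCOval | 0.29 | - | - | 0.47 | 0.33 | 1 | 640 |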

See Segmentation Docs for usage examples with these models trained on COCO, which include 80 pre-trained classes.

Note

Inference times shown for mean, min (fastest), and max (slowest) for each test using pre-trained weights yolov8n-seg.engine

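| Precision | Eval test | mean (ms) | min (ms) | max (ms) | mAPval 50(B) | mAPval 50-95(B) | mAPval 50(M) | mAPval 50-95(M) | batch | size (pixels) |
|---|---|---|---|---|---|---|---|---|---|---|
| FP32 | Predict | 0.62 | 0.61 | 0.68 | - | - | - | - | 8 | 640 |
| FP32 | COCOval | 0.63 | - | - | 0.52 | 0.36 | 0.49 | 0.31 | 1 | 640 |
| FP16 | Predict | 0.40 | 0.39 | 0.44 | - | - | - | - | 8 | 640 |
| FP16 | COCOval | 0.43 | - | - | 0.52 | 0.36 | 0.49 | 0.30 | 1 | 640 |
| INT8 | Predict | 0.34 | 0.33 | 0.37 | - | - | - | - | 8 | 640 |
| INT8 | COCOval | 0.36 | - | - | 0.46 | 0.32 | 0.43 | 0.27 | 1 | 640 |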

See Classification Docs for usage examples with these models trained on ImageNet, which include 1000 pre-trained classes.

Note

Inference times shown for mean, min (fastest), and max (slowest) for each test using pre-trained weights yolov8n-cls.engine

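| Precision | Eval test | mean (ms) | min (ms) | max (ms) | top-1 | top-5 | batch | size (pixels) |
|---|---|---|---|---|---|---|---|---|
| FP32 | Predict | 0.26 | 0.25 | 0.28 | 0.35 | 0.61 | 8 | 640 |
| FP32 | ImageNetval | 0.26 | - | - | - | - | 1 | 640 |
| FP16 | Predict | 0.18 | 0.17 | 0.19 | 0.35 | 0.61 | 8 | 640 |
| FP16 | ImageNetval | 0.18 | - | - | - | - | 1 | 640 |
| INT8 | Predict | 0.16 | 0.15 | 0.57 | 0.32 | 0.59 | 8 | 640 |
| INT8 | ImageNetval | 0.15 | - | - | - | - | 1 | 640 |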

See Pose Estimation Docs for usage examples with these models trained on COCO, which include 1 pre-trained class, "person".

Note

Inference times shown for mean, min (fastest), and max (slowest) for each test using pre-trained weights yolov8n-pose.engine

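| Precision | Eval test | mean (ms) | min (ms) | max (ms) | mAPval 50(B) | mAPval 50-95(B) | mAPval 50(P) | mAPval 50-95(P) | batch | size (pixels) |
|---|---|---|---|---|---|---|---|---|---|---|
| FP32 | Predict | 0.54 | 0.53 | 0.58 | - | - | - | - | 8 | 640 |
| FP32 | COCOval | 0.55 | - | - | 0.91 | 0.69 | 0.80 | 0.51 | 1 | 640 |
| FP16 | Predict | 0.37 | 0.35 | 0.41 | - | - | - | - | 8 | 640 |
| FP16 | COCOval | 0.36 | - | - | 0.91 | 0.69 | 0.80 | 0.51 | 1 | 640 |
| INT8 | Predict | 0.29 | 0.28 | 0.33 | - | - | - | - | 8 | 640 |
| INT8 | COCOval | 0.30 | - | - | 0.90 | 0.68 | 0.78 | 0.47 | 1 | 640 |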

See Oriented Detection Docs for usage examples with these models trained on DOTAv1, which include 15 pre-trained classes.

Note

Inference times shown for mean, min (fastest), and max (slowest) for each test using pre-trained weights yolov8n-obb.engine

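| Precision | Eval test | mean (ms) | min (ms) | max (ms) | mAPval 50(B) | mAPval 50-95(B) | batch | size (pixels) |
|---|---|---|---|---|---|---|---|---|
| FP32 | Predict | 0.52 | 0.51 | 0.59 | - | - | 8 | 640 |
| FP32 | DOTAv1val | 0.76 | - | - | 0.50 | 0.36 | 1 | 640 |
| FP16 | Predict | 0.34 | 0.33 | 0.42 | - | - | 8 | 640 |
| FP16 | DOTAv1val | 0.59 | - | - | 0.50 | 0.36 | 1 | 640 |
| INT8 | Predict | 0.29 | 0.28 | 0.33 | - | - | 8 | 640 |
| INT8 | DOTAv1val | 0.32 | - | - | 0.45 | 0.32 | 1 | 640 |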

Consumer GPUs

Detection Performance (COCO)

Tested with Windows 10.0.19045, python 3.10.9, ultralytics==8.2.4, tensorrt==10.0.0b6

Note

Inference times shown for mean, min (fastest), and max (slowest) for each test using pre-trained weights yolov8n.engine

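| Precision | Eval test | mean (ms) | min (ms) | max (ms) | mAPval 50(B) | mAPval 50-95(B) | batch | size (pixels) |
|---|---|---|---|---|---|---|---|---|
| FP32 | Predict | 1.06 | 0.75 | 1.88 | - | - | 8 | 640 |
| FP32 | COCOval | 1.37 | - | - | 0.52 | 0.37 | 1 | 640 |
| FP16 | Predict | 0.62 | 0.75 | 1.13 | - | - | 8 | 640 |
| FP16 | COCOval | 0.85 | - | - | 0.52 | 0.37 | 1 | 640 |
| INT8 | Predict | 0.52 | 0.38 | 1.00 | - | - | 8 | 640 |
| INT8 | COCOval | 0.74 | - | - | 0.47 | 0.33 | 1 | 640 |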

Tested with Windows 10.0.22631, python 3.11.9, ultralytics==8.2.4, tensorrt==10.0.1

Note

Inference times shown for mean, min (fastest), and max (slowest) for each test using pre-trained weights yolov8n.engine

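| Precision | Eval test | mean (ms) | min (ms) | max (ms) | mAPval 50(B) | mAPval 50-95(B) | batch | size (pixels) |
|---|---|---|---|---|---|---|---|---|
| FP32 | Predict | 1.76 | 1.69 | 1.87 | - | - | 8 | 640 |
| FP32 | COCOval | 1.94 | - | - | 0.52 | 0.37 | 1 | 640 |
| FP16 | Predict | 0.86 | 0.75 | 1.00 | - | - | 8 | 640 |
| FP16 | COCOval | 1.43 | - | - | 0.52 | 0.37 | 1 | 640 |
| INT8 | Predict | 0.80 | 0.75 | 1.00 | - | - | 8 | 640 |
| INT8 | COCOval | 1.35 | - | - | 0.47 | 0.33 | 1 | 640 |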

Tested with Pop!_OS 22.04 LTS, python 3.10.12, ultralytics==8.2.4, tensorrt==8.6.1.post1

Note

Inference times shown for mean, min (fastest), and max (slowest) for each test using pre-trained weights yolov8n.engine

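| Precision | Eval test | mean (ms) | min (ms) | max (ms) | mAPval 50(B) | mAPval 50-95(B) | batch | size (pixels) |
|---|---|---|---|---|---|---|---|---|
| FP32 | Predict | 2.84 | 2.84 | 2.85 | - | - | 8 | 640 |
| FP32 | COCOval | 2.94 | - | - | 0.52 | 0.37 | 1 | 640 |
| FP16 | Predict | 1.09 | 1.09 | 1.10 | - | - | 8 | 640 |
| FP16 | COCOval | 1.20 | - | - | 0.52 | 0.37 | 1 | 640 |
| INT8 | Predict | 0.75 | 0.74 | 0.75 | - | - | 8 | 640 |
| INT8 | COCOval | 0.76 | - | - | 0.47 | 0.33 | 1 | 640 |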

Embedded Devices

Detection Performance (COCO)

Tested with JetPack 5.1.3 (L4T 35.5.0) Ubuntu 20.04.6, python 3.8.10, ultralytics==8.2.4, tensorrt==8.5.2.2

Note

Inference times shown for mean, min (fastest), and max (slowest) for each test using pre-trained weights yolov8n.engine

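| Precision | Eval test | mean (ms) | min (ms) | max (ms) | mAPval 50(B) | mAPval 50-95(B) | batch | size (pixels) |
|---|---|---|---|---|---|---|---|---|
| FP32 | Predict | 6.90 | 6.89 | 6.93 | - | - | 8 | 640 |
| FP32 | COCOval | 6.97 | - | - | 0.52 | 0.37 | 1 | 640 |
| FP16 | Predict | 3.36 | 3.35 | 3.39 | - | - | 8 | 640 |
| FP16 | COCOval | 3.39 | - | - | 0.52 | 0.37 | 1 | 640 |
| INT8 | Predict | 2.32 | 2.32 | 2.34 | - | - | 8 | 640 |
| INT8 | COCOval | 2.33 | - | - | 0.47 | 0.33 | 1 | 640 |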

Info

See our quickstart guide on NVIDIA Jetson with Ultralytics YOLO to learn more about setup and configuration.

Evaluation methods

The sections below describe how these models were exported and tested.

Export configurations

See export mode for details regarding export configuration arguments.

from ultralytics import YOLO

model = YOLO("yolov8n.pt")

# TensorRT FP32
out = model.export(
    format="engine",
    imgsz=640,
    dynamic=True,
    verbose=False,
    batch=8,
    workspace=2,
)

# TensorRT FP16
out = model.export(
    format="engine",
    imgsz=640,
    dynamic=True,
    verbose=False,
    batch=8,
    workspace=2,
    half=True,
)

# TensorRT INT8
out = model.export(
    format="engine",
    imgsz=640,
    dynamic=True,
    verbose=False,
    batch=8,
    workspace=2,
    int8=True,
    data="data.yaml",  # COCO, ImageNet, or DOTAv1 for appropriate model task
)

Predict loop

See predict mode for additional information.

import cv2
from ultralytics import YOLO

model = YOLO("yolov8n.engine")
img = cv2.imread("path/to/image.jpg")

for _ in range(100):
    result = model.predict(
        [img] * 8,  # batch=8 of the same image
        verbose=False,
        device="cuda"
    )
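
The tables report mean, min, and max times. One way to collect comparable numbers from the predict loop is to read the per-image speed dictionary that Ultralytics Results objects expose (preprocess, inference, and postprocess times in milliseconds) and summarize it. The summarize_speeds helper is hypothetical, and this is not necessarily how the published numbers were computed.

```python
from statistics import mean


def summarize_speeds(speed_dicts, key="inference"):
    """Return mean/min/max (ms) of one timing key across Results.speed-style dicts."""
    values = [s[key] for s in speed_dicts]
    if not values:
        raise ValueError("no timings to summarize")
    return {"mean": mean(values), "min": min(values), "max": max(values)}


# Usage inside the predict loop (requires a GPU and the exported engine):
#   speeds = []
#   for _ in range(100):
#       results = model.predict([img] * 8, verbose=False, device="cuda")
#       speeds.extend(r.speed for r in results)
#   print(summarize_speeds(speeds))
```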

Validation configuration

See val mode to learn more about validation configuration arguments.

from ultralytics import YOLO

model = YOLO("yolov8n.engine")
results = model.val(
    data="data.yaml",  # COCO, ImageNet, or DOTAv1 for appropriate model task
    batch=1,
    imgsz=640,
    verbose=False,
    device="cuda"
)

Deploying Exported YOLOv8 TensorRT Models

Having successfully exported your Ultralytics YOLOv8 models to TensorRT format, you're now ready to deploy them. For in-depth instructions on deploying your TensorRT models in various settings, take a look at the following resources:

Summary

In this guide, we focused on converting Ultralytics YOLOv8 models to NVIDIA's TensorRT model format. This conversion step is crucial for improving the efficiency and speed of YOLOv8 models, making them more effective and suitable for diverse deployment environments.

For more information on usage details, take a look at the TensorRT official documentation.

If you're curious about additional Ultralytics YOLOv8 integrations, our integration guide page provides an extensive selection of informative resources and insights.



Created 2024-01-28, Updated 2024-05-08
Authors: Burhan-Q (1), glenn-jocher (1), abirami-vina (1)
