Quick Start Guide: NVIDIA DGX Spark with Ultralytics YOLO11

This comprehensive guide provides a detailed walkthrough for deploying Ultralytics YOLO11 on NVIDIA DGX Spark, NVIDIA's compact desktop AI supercomputer. Additionally, it showcases performance benchmarks to demonstrate the capabilities of YOLO11 on this powerful system.

NVIDIA DGX Spark

Note

This guide has been tested with NVIDIA DGX Spark Founders Edition running DGX OS based on Ubuntu. It is expected to work with the latest DGX OS releases.

What is NVIDIA DGX Spark?

NVIDIA DGX Spark is a compact desktop AI supercomputer powered by the NVIDIA GB10 Grace Blackwell Superchip. It delivers up to 1 petaFLOP of AI computing performance with FP4 precision, making it ideal for developers, researchers, and data scientists who need powerful AI capabilities in a desktop form factor.

Key Specifications

| Specification | Details |
|---|---|
| AI Performance | Up to 1 PFLOP (FP4) |
| GPU | NVIDIA Blackwell architecture with 5th-generation Tensor Cores, 4th-generation RT Cores |
| CPU | 20-core Arm processor (10 Cortex-X925 + 10 Cortex-A725) |
| Memory | 128 GB LPDDR5x unified system memory, 256-bit interface, 4266 MHz, 273 GB/s bandwidth |
| Storage | 1 TB or 4 TB NVMe M.2 with self-encryption |
| Network | 1x RJ-45 (10 GbE), ConnectX-7 Smart NIC, Wi-Fi 7, Bluetooth 5.4 |
| Connectivity | 4x USB Type-C, 1x HDMI 2.1a, HDMI multichannel audio |
| Video Processing | 1x NVENC, 1x NVDEC |

DGX OS

NVIDIA DGX OS is a customized Linux distribution that provides a stable, tested, and supported operating system foundation for running AI, machine learning, and analytics applications on DGX systems. It includes:

  • A robust Linux foundation optimized for AI workloads
  • Pre-configured drivers and system settings for NVIDIA hardware
  • Security updates and system maintenance capabilities
  • Compatibility with the broader NVIDIA software ecosystem

DGX OS follows a regular release schedule with updates typically provided twice per year (around February and August), with additional security patches provided between major releases.

DGX Dashboard

DGX Spark comes with a built-in DGX Dashboard that provides:

  • Real-time System Monitoring: Overview of the system's current operational metrics
  • System Updates: Ability to apply updates directly from the dashboard
  • System Settings: Change device name and other configurations
  • Integrated JupyterLab: Access local Jupyter Notebooks for development

NVIDIA DGX Dashboard

Accessing the Dashboard

Click the "Show Apps" button in the bottom left corner of the Ubuntu desktop, then select "DGX Dashboard" to open it in your browser.

# Open an SSH tunnel
ssh -L 11000:localhost:11000 <username>@<IP or spark-abcd.local>

# Then open in browser
# http://localhost:11000

After connecting with NVIDIA Sync, click the "DGX Dashboard" button to open the dashboard at http://localhost:11000.

Integrated JupyterLab

The dashboard includes an integrated JupyterLab instance that automatically creates a virtual environment and installs recommended packages when started. Each user account is assigned a dedicated port for JupyterLab access.

Quick Start with Docker

The fastest way to get started with Ultralytics YOLO11 on NVIDIA DGX Spark is to use the pre-built Docker image. The same image that supports Jetson AGX Thor (JetPack 7.0) works on DGX Spark with DGX OS.

t=ultralytics/ultralytics:latest-nvidia-arm64
sudo docker pull $t && sudo docker run -it --ipc=host --runtime=nvidia --gpus all $t

After this is done, skip to the Use TensorRT on NVIDIA DGX Spark section.

Start with Native Installation

For a native installation without Docker, follow these steps.

Install Ultralytics Package

Here we will install the Ultralytics package on DGX Spark with optional dependencies so that we can export the PyTorch models to other formats. We will mainly focus on NVIDIA TensorRT exports, because TensorRT delivers the maximum performance on the DGX Spark.

  1. Update the packages list, install pip, and upgrade it to the latest version

    sudo apt update
    sudo apt install python3-pip -y
    pip install -U pip
    
  2. Install the ultralytics pip package with optional dependencies

    pip install ultralytics[export]
    
  3. Reboot the device

    sudo reboot
    

Install PyTorch and Torchvision

The above ultralytics installation will install Torch and Torchvision. However, the default PyPI wheels may not be fully optimized for the DGX Spark's ARM64 architecture with CUDA 13. Therefore, we recommend installing the CUDA 13-compatible versions:

pip install torch torchvision --index-url https://download.pytorch.org/whl/cu130
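
To confirm the CUDA-enabled build is active, a quick check like the following can help (a minimal sketch; on DGX Spark the GPU should report as NVIDIA GB10 with compute capability 12.1):

import torch

# Verify the installed build and that the GB10 GPU is visible to PyTorch
print(torch.__version__)                    # installed PyTorch version
print(torch.version.cuda)                   # CUDA version the wheel was built for
print(torch.cuda.is_available())            # should print True
print(torch.cuda.get_device_name(0))        # should report the GB10 GPU
print(torch.cuda.get_device_capability(0))  # e.g. (12, 1)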

Info

When running PyTorch 2.9.1 on NVIDIA DGX Spark, you may encounter the following UserWarning when initializing CUDA (e.g., when running yolo checks, yolo predict, etc.):

UserWarning: Found GPU0 NVIDIA GB10 which is of cuda capability 12.1.
Minimum and Maximum cuda capability supported by this version of PyTorch is (8.0) - (12.0)

This warning can be safely ignored. A permanent fix has been submitted in PyTorch PR #164590 and will be included in the PyTorch 2.10 release.

Install onnxruntime-gpu

The onnxruntime-gpu package hosted on PyPI does not provide aarch64 binaries for ARM64 systems, so we need to install this package manually. It is required for some of the export formats.

Here we will download and install onnxruntime-gpu 1.24.0 with Python 3.12 support.

pip install https://github.com/ultralytics/assets/releases/download/v0.0.0/onnxruntime_gpu-1.24.0-cp312-cp312-linux_aarch64.whl
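
To verify the installation, you can list the available execution providers (a minimal sketch; a working GPU install should include CUDAExecutionProvider):

import onnxruntime as ort

# The CUDA provider should appear if the GPU wheel installed correctly
print(ort.get_available_providers())  # e.g. ['CUDAExecutionProvider', 'CPUExecutionProvider']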

Use TensorRT on NVIDIA DGX Spark

Among all the model export formats supported by Ultralytics, TensorRT offers the highest inference performance on NVIDIA DGX Spark, making it our top recommendation for deployments. For setup instructions and advanced usage, see our dedicated TensorRT integration guide.

Convert Model to TensorRT and Run Inference

The YOLO11n model in PyTorch format is converted to TensorRT to run inference with the exported model.

Example

Python

from ultralytics import YOLO

# Load a YOLO11n PyTorch model
model = YOLO("yolo11n.pt")

# Export the model to TensorRT
model.export(format="engine")  # creates 'yolo11n.engine'

# Load the exported TensorRT model
trt_model = YOLO("yolo11n.engine")

# Run inference
results = trt_model("https://ultralytics.com/images/bus.jpg")

CLI

# Export a YOLO11n PyTorch model to TensorRT format
yolo export model=yolo11n.pt format=engine # creates 'yolo11n.engine'

# Run inference with the exported model
yolo predict model=yolo11n.engine source='https://ultralytics.com/images/bus.jpg'

Note

Visit the Export page for additional arguments available when exporting models to different formats.
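
For instance, commonly used export arguments can be combined in a single call (a minimal sketch; imgsz and half are standard export arguments covered on the Export page):

from ultralytics import YOLO

# Load a YOLO11n PyTorch model
model = YOLO("yolo11n.pt")

# Export a TensorRT engine at FP16 precision with a fixed 640x640 input size
model.export(format="engine", half=True, imgsz=640)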

NVIDIA DGX Spark YOLO11 Benchmarks

YOLO11 benchmarks were run by the Ultralytics team on multiple model formats, measuring speed and accuracy: PyTorch, TorchScript, ONNX, OpenVINO, TensorRT, TF SavedModel, TF GraphDef, TF Lite, MNN, NCNN, and ExecuTorch. Benchmarks were run on NVIDIA DGX Spark at FP32 precision with a default input image size of 640.

Detailed Comparison Table

The tables below present the benchmark results for five different models (YOLO11n, YOLO11s, YOLO11m, YOLO11l, YOLO11x) across multiple formats, giving the status, size on disk, mAP50-95(B) metric, and inference time for each combination.

Performance

YOLO11n

| Format | Status | Size on disk (MB) | mAP50-95(B) | Inference time (ms/im) |
|---|---|---|---|---|
| PyTorch | ✅ | 5.4 | 0.5071 | 2.67 |
| TorchScript | ✅ | 10.5 | 0.5083 | 2.62 |
| ONNX | ✅ | 10.2 | 0.5074 | 5.92 |
| OpenVINO | ✅ | 10.4 | 0.5058 | 14.95 |
| TensorRT (FP32) | ✅ | 12.8 | 0.5085 | 1.95 |
| TensorRT (FP16) | ✅ | 7.0 | 0.5068 | 1.01 |
| TensorRT (INT8) | ✅ | 18.6 | 0.4880 | 1.62 |
| TF SavedModel | ✅ | 25.7 | 0.5076 | 36.39 |
| TF GraphDef | ✅ | 10.3 | 0.5076 | 41.06 |
| TF Lite | ✅ | 10.3 | 0.5075 | 64.36 |
| MNN | ✅ | 10.1 | 0.5075 | 12.14 |
| NCNN | ✅ | 10.2 | 0.5041 | 12.31 |
| ExecuTorch | ✅ | 10.2 | 0.5075 | 27.61 |

YOLO11s

| Format | Status | Size on disk (MB) | mAP50-95(B) | Inference time (ms/im) |
|---|---|---|---|---|
| PyTorch | ✅ | 18.4 | 0.5767 | 5.38 |
| TorchScript | ✅ | 36.5 | 0.5781 | 5.48 |
| ONNX | ✅ | 36.3 | 0.5784 | 8.17 |
| OpenVINO | ✅ | 36.4 | 0.5809 | 27.12 |
| TensorRT (FP32) | ✅ | 39.8 | 0.5783 | 3.59 |
| TensorRT (FP16) | ✅ | 20.1 | 0.5800 | 1.85 |
| TensorRT (INT8) | ✅ | 17.5 | 0.5664 | 1.88 |
| TF SavedModel | ✅ | 90.8 | 0.5782 | 66.63 |
| TF GraphDef | ✅ | 36.3 | 0.5782 | 71.67 |
| TF Lite | ✅ | 36.3 | 0.5782 | 187.36 |
| MNN | ✅ | 36.2 | 0.5775 | 27.05 |
| NCNN | ✅ | 36.2 | 0.5806 | 26.26 |
| ExecuTorch | ✅ | 36.2 | 0.5782 | 54.73 |

YOLO11m

| Format | Status | Size on disk (MB) | mAP50-95(B) | Inference time (ms/im) |
|---|---|---|---|---|
| PyTorch | ✅ | 38.8 | 0.6254 | 11.14 |
| TorchScript | ✅ | 77.3 | 0.6304 | 12.00 |
| ONNX | ✅ | 76.9 | 0.6304 | 13.83 |
| OpenVINO | ✅ | 77.1 | 0.6284 | 62.44 |
| TensorRT (FP32) | ✅ | 79.9 | 0.6305 | 6.96 |
| TensorRT (FP16) | ✅ | 40.6 | 0.6313 | 3.14 |
| TensorRT (INT8) | ✅ | 26.6 | 0.6204 | 3.30 |
| TF SavedModel | ✅ | 192.4 | 0.6306 | 139.85 |
| TF GraphDef | ✅ | 76.9 | 0.6306 | 146.76 |
| TF Lite | ✅ | 76.9 | 0.6306 | 568.18 |
| MNN | ✅ | 76.8 | 0.6306 | 67.67 |
| NCNN | ✅ | 76.8 | 0.6308 | 60.49 |
| ExecuTorch | ✅ | 76.9 | 0.6306 | 120.37 |

YOLO11l

| Format | Status | Size on disk (MB) | mAP50-95(B) | Inference time (ms/im) |
|---|---|---|---|---|
| PyTorch | ✅ | 49.0 | 0.6366 | 13.95 |
| TorchScript | ✅ | 97.6 | 0.6399 | 15.67 |
| ONNX | ✅ | 97.0 | 0.6399 | 16.62 |
| OpenVINO | ✅ | 97.3 | 0.6377 | 78.80 |
| TensorRT (FP32) | ✅ | 99.2 | 0.6407 | 8.86 |
| TensorRT (FP16) | ✅ | 50.8 | 0.6350 | 3.85 |
| TensorRT (INT8) | ✅ | 32.5 | 0.6224 | 4.52 |
| TF SavedModel | ✅ | 242.7 | 0.6409 | 187.45 |
| TF GraphDef | ✅ | 97.0 | 0.6409 | 193.92 |
| TF Lite | ✅ | 97.0 | 0.6409 | 728.61 |
| MNN | ✅ | 96.9 | 0.6369 | 85.21 |
| NCNN | ✅ | 96.9 | 0.6373 | 77.62 |
| ExecuTorch | ✅ | 97.0 | 0.6409 | 153.56 |

YOLO11x

| Format | Status | Size on disk (MB) | mAP50-95(B) | Inference time (ms/im) |
|---|---|---|---|---|
| PyTorch | ✅ | 109.3 | 0.6992 | 23.19 |
| TorchScript | ✅ | 218.1 | 0.6900 | 25.75 |
| ONNX | ✅ | 217.5 | 0.6900 | 27.43 |
| OpenVINO | ✅ | 217.8 | 0.6872 | 149.44 |
| TensorRT (FP32) | ✅ | 222.7 | 0.6902 | 13.87 |
| TensorRT (FP16) | ✅ | 111.1 | 0.6883 | 6.19 |
| TensorRT (INT8) | ✅ | 62.9 | 0.6793 | 6.62 |
| TF SavedModel | ✅ | 543.9 | 0.6900 | 335.10 |
| TF GraphDef | ✅ | 217.5 | 0.6900 | 348.86 |
| TF Lite | ✅ | 217.5 | 0.6900 | 1578.66 |
| MNN | ✅ | 217.3 | 0.6874 | 168.95 |
| NCNN | ✅ | 217.4 | 0.6901 | 132.13 |
| ExecuTorch | ✅ | 217.4 | 0.6900 | 297.17 |

Benchmarked with Ultralytics 8.3.249

Reproduce Our Results

To reproduce the above Ultralytics benchmarks on all export formats, run this code:

Example

Python

from ultralytics import YOLO

# Load a YOLO11n PyTorch model
model = YOLO("yolo11n.pt")

# Benchmark YOLO11n speed and accuracy on the COCO128 dataset for all export formats
results = model.benchmark(data="coco128.yaml", imgsz=640)

CLI

# Benchmark YOLO11n speed and accuracy on the COCO128 dataset for all export formats
yolo benchmark model=yolo11n.pt data=coco128.yaml imgsz=640

Note that benchmarking results may vary based on the exact hardware and software configuration of a system, as well as its current workload at the time the benchmarks are run. For the most reliable results, use a dataset with a large number of images, e.g., data='coco.yaml' (5000 validation images).
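
A full-COCO run might look like this (a minimal sketch; data="coco.yaml" downloads the COCO validation set on first use):

from ultralytics import YOLO

# Load a YOLO11n PyTorch model
model = YOLO("yolo11n.pt")

# Benchmark against the 5000-image COCO validation set for more stable numbers
results = model.benchmark(data="coco.yaml", imgsz=640)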

Best Practices for NVIDIA DGX Spark

When using NVIDIA DGX Spark, there are a few best practices to follow in order to get maximum performance when running YOLO11.

  1. Monitor System Performance

    Use NVIDIA's monitoring tools to track GPU and CPU utilization:

    nvidia-smi
    
  2. Optimize Memory Usage

With 128 GB of unified memory, DGX Spark can handle large batch sizes and models. Consider increasing the batch size for improved throughput:

    from ultralytics import YOLO
    
    model = YOLO("yolo11n.engine")
    results = model.predict(source="path/to/images", batch=16)
    
  3. Use TensorRT with FP16 or INT8

    For best performance, export models with FP16 or INT8 precision; a Python sketch covering INT8 calibration follows this list:

    yolo export model=yolo11n.pt format=engine half=True # FP16
    yolo export model=yolo11n.pt format=engine int8=True # INT8
    
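INT8 export requires a calibration dataset, supplied through the data argument (a minimal sketch; coco128.yaml is used here only as an example calibration set):

from ultralytics import YOLO

# Load a YOLO11n PyTorch model
model = YOLO("yolo11n.pt")

# Export an INT8 TensorRT engine, calibrating on COCO128 images
model.export(format="engine", int8=True, data="coco128.yaml")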

System Updates (Founders Edition)

Keeping your DGX Spark Founders Edition up to date is crucial for performance and security. NVIDIA provides two primary methods for updating the system OS, drivers, and firmware.

The DGX Dashboard is the recommended way to perform system updates, as it ensures compatibility. It allows you to:

  • View available system updates
  • Install security patches and system updates
  • Manage NVIDIA driver and firmware updates

Manual System Updates

For advanced users, updates can be performed manually via the terminal:

sudo apt update
sudo apt dist-upgrade
sudo fwupdmgr refresh
sudo fwupdmgr upgrade
sudo reboot

Warning

Ensure your system is connected to a stable power source and you have backed up critical data before performing updates.

Next Steps

For further learning and support, see the Ultralytics YOLO11 Docs.

FAQ

How do I deploy Ultralytics YOLO11 on NVIDIA DGX Spark?

Deploying Ultralytics YOLO11 on NVIDIA DGX Spark is straightforward. You can use the pre-built Docker image for quick setup or manually install the required packages. Detailed steps for each approach can be found in the Quick Start with Docker and Start with Native Installation sections above.

What performance can I expect from YOLO11 on NVIDIA DGX Spark?

YOLO11 models deliver excellent performance on DGX Spark thanks to the GB10 Grace Blackwell Superchip. The TensorRT format provides the best inference performance. Check the Detailed Comparison Table section for specific benchmark results across different model sizes and formats.

Why should I use TensorRT for YOLO11 on DGX Spark?

TensorRT is highly recommended for deploying YOLO11 models on DGX Spark due to its optimal performance. It accelerates inference by leveraging the Blackwell GPU capabilities, ensuring maximum efficiency and speed. Learn more in the Use TensorRT on NVIDIA DGX Spark section.

How does DGX Spark compare to Jetson devices for YOLO11?

DGX Spark and Jetson devices target different deployment scenarios. DGX Spark delivers up to 1 PFLOP (FP4) of AI performance with 128 GB of unified memory, while the top-end Jetson AGX Thor offers up to 2070 TFLOPS (FP4) with 128 GB of memory. DGX Spark is designed as a desktop AI supercomputer, while Jetson devices are embedded systems optimized for edge deployment.

Can I use the same Docker image for DGX Spark and Jetson AGX Thor?

Yes! The ultralytics/ultralytics:latest-nvidia-arm64 Docker image supports both NVIDIA DGX Spark (with DGX OS) and Jetson AGX Thor (with JetPack 7.0), as both use ARM64 architecture with CUDA 13 and similar software stacks.


