YOLOE: Real-Time Seeing Anything

Introduction

YOLOE Prompting Options

YOLOE (Real-Time Seeing Anything) is a new advancement in zero-shot, promptable YOLO models, designed for open-vocabulary detection and segmentation. Unlike previous YOLO models limited to fixed categories, YOLOE uses text, image, or internal vocabulary prompts, enabling real-time detection of any object class. Built upon YOLOv10 and inspired by YOLO-World, YOLOE achieves state-of-the-art zero-shot performance with minimal impact on speed and accuracy.



Watch: How to use YOLOE with Ultralytics Python package: Open Vocabulary & Real-Time Seeing Anything 🚀

Compared to earlier YOLO models, YOLOE significantly boosts efficiency and accuracy. It improves by +3.5 AP over YOLO-Worldv2 on LVIS while using just a third of the training resources and achieving 1.4× faster inference speeds. Fine-tuned on COCO, YOLOE-v8-large surpasses YOLOv8-L by 0.1 mAP, using nearly 4× less training time. This demonstrates YOLOE's exceptional balance of accuracy, efficiency, and versatility. The sections below explore YOLOE's architecture, benchmark comparisons, and integration with the Ultralytics framework.

Architecture Overview

YOLOE Architecture

YOLOE retains the standard YOLO structure—a convolutional backbone (e.g., CSP-Darknet) for feature extraction, a neck (e.g., PAN-FPN) for multi-scale fusion, and an anchor-free, decoupled detection head (as in YOLOv8/YOLO11) predicting objectness, classes, and boxes independently. YOLOE introduces three novel modules enabling open-vocabulary detection:

  • Re-parameterizable Region-Text Alignment (RepRTA): Supports text-prompted detection by refining text embeddings (e.g., from CLIP) via a small auxiliary network. At inference, this network is folded into the main model, ensuring zero overhead. YOLOE thus detects arbitrary text-labeled objects (e.g., unseen "traffic light") without runtime penalties.

  • Semantic-Activated Visual Prompt Encoder (SAVPE): Enables visual-prompted detection via a lightweight embedding branch. Given a reference image, SAVPE encodes semantic and activation features, conditioning the model to detect visually similar objects—a one-shot detection capability useful for logos or specific parts.

  • Lazy Region-Prompt Contrast (LRPC): In prompt-free mode, YOLOE performs open-set recognition using internal embeddings trained on large vocabularies (1200+ categories from LVIS and Objects365). Without external prompts or encoders, YOLOE identifies objects via embedding similarity lookup, efficiently handling large label spaces at inference.

Additionally, YOLOE integrates real-time instance segmentation by extending the detection head with a mask prediction branch (similar to YOLACT or YOLOv8-Seg), adding minimal overhead.

Crucially, YOLOE's open-world modules introduce no inference cost when used as a regular closed-set YOLO. Post-training, YOLOE parameters can be re-parameterized into a standard YOLO head, preserving identical FLOPs and speed (e.g., matching YOLO11 exactly).
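
The classification step in all three prompting modes can be viewed as a similarity lookup between region embeddings produced by the head and prompt embeddings (from RepRTA text prompts, SAVPE visual prompts, or the internal vocabulary). The snippet below is a minimal conceptual sketch of that idea only, using NumPy with made-up dimensions and random data; it is not YOLOE's actual implementation.

import numpy as np

# Hypothetical sizes: 100 candidate regions, 512-dim embeddings, 3 prompted classes
region_feats = np.random.randn(100, 512)  # region embeddings from the detection head
prompt_embeds = np.random.randn(3, 512)  # text, visual, or internal-vocabulary prompt embeddings

# L2-normalize so the dot product becomes cosine similarity
region_feats /= np.linalg.norm(region_feats, axis=1, keepdims=True)
prompt_embeds /= np.linalg.norm(prompt_embeds, axis=1, keepdims=True)

# Class scores are similarities between each region and each prompt embedding
scores = region_feats @ prompt_embeds.T  # shape (100, 3)
pred_classes = scores.argmax(axis=1)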

Available Models, Supported Tasks, and Operating Modes

This section details the models available with their specific pre-trained weights, the tasks they support, and their compatibility with various operating modes such as Inference, Validation, Training, and Export, denoted by ✅ for supported modes and ❌ for unsupported modes.

Text/Visual Prompt models

| Model Type | Pre-trained Weights | Tasks Supported       | Inference | Validation | Training | Export |
| ---------- | ------------------- | --------------------- | --------- | ---------- | -------- | ------ |
| YOLOE-11S  | yoloe-11s-seg.pt    | Instance Segmentation | ✅        | ✅         | ✅       | ✅     |
| YOLOE-11M  | yoloe-11m-seg.pt    | Instance Segmentation | ✅        | ✅         | ✅       | ✅     |
| YOLOE-11L  | yoloe-11l-seg.pt    | Instance Segmentation | ✅        | ✅         | ✅       | ✅     |
| YOLOE-v8S  | yoloe-v8s-seg.pt    | Instance Segmentation | ✅        | ✅         | ✅       | ✅     |
| YOLOE-v8M  | yoloe-v8m-seg.pt    | Instance Segmentation | ✅        | ✅         | ✅       | ✅     |
| YOLOE-v8L  | yoloe-v8l-seg.pt    | Instance Segmentation | ✅        | ✅         | ✅       | ✅     |

Prompt Free models

| Model Type   | Pre-trained Weights  | Tasks Supported       | Inference | Validation | Training | Export |
| ------------ | -------------------- | --------------------- | --------- | ---------- | -------- | ------ |
| YOLOE-11S-PF | yoloe-11s-seg-pf.pt  | Instance Segmentation | ✅        | ✅         | ✅       | ✅     |
| YOLOE-11M-PF | yoloe-11m-seg-pf.pt  | Instance Segmentation | ✅        | ✅         | ✅       | ✅     |
| YOLOE-11L-PF | yoloe-11l-seg-pf.pt  | Instance Segmentation | ✅        | ✅         | ✅       | ✅     |
| YOLOE-v8S-PF | yoloe-v8s-seg-pf.pt  | Instance Segmentation | ✅        | ✅         | ✅       | ✅     |
| YOLOE-v8M-PF | yoloe-v8m-seg-pf.pt  | Instance Segmentation | ✅        | ✅         | ✅       | ✅     |
| YOLOE-v8L-PF | yoloe-v8l-seg-pf.pt  | Instance Segmentation | ✅        | ✅         | ✅       | ✅     |

Usage Examples

The YOLOE models are easy to integrate into your Python applications. Ultralytics provides a user-friendly Python API and CLI commands to streamline development.

Train Usage

Fine-tuning on a custom dataset

Example

from ultralytics import YOLOE
from ultralytics.models.yolo.yoloe import YOLOEPESegTrainer

# Load a pre-trained YOLOE segmentation model
model = YOLOE("yoloe-11s-seg.pt")

# Fine-tune on the COCO128-seg example dataset
model.train(
    data="coco128-seg.yaml",
    epochs=80,
    close_mosaic=10,
    batch=128,
    optimizer="AdamW",
    lr0=1e-3,
    warmup_bias_lr=0.0,
    weight_decay=0.025,
    momentum=0.9,
    workers=4,
    device="0",
    trainer=YOLOEPESegTrainer,
)

Alternatively, you can run linear probing, which freezes the backbone, neck, and most of the head so that only the final classification layers are trained:

from ultralytics import YOLOE
from ultralytics.models.yolo.yoloe import YOLOEPESegTrainer

model = YOLOE("yoloe-11s-seg.pt")

# Freeze the backbone and neck, plus every head branch except the classification branch (cv3)
head_index = len(model.model.model) - 1
freeze = [str(f) for f in range(0, head_index)]
for name, child in model.model.model[-1].named_children():
    if "cv3" not in name:
        freeze.append(f"{head_index}.{name}")

# Within cv3, also freeze the first two blocks of each branch, leaving only the final layer trainable
freeze.extend(
    [
        f"{head_index}.cv3.0.0",
        f"{head_index}.cv3.0.1",
        f"{head_index}.cv3.1.0",
        f"{head_index}.cv3.1.1",
        f"{head_index}.cv3.2.0",
        f"{head_index}.cv3.2.1",
    ]
)

model.train(
    data="coco128-seg.yaml",
    epochs=2,
    close_mosaic=0,
    batch=16,
    optimizer="AdamW",
    lr0=1e-3,
    warmup_bias_lr=0.0,
    weight_decay=0.025,
    momentum=0.9,
    workers=4,
    device="0",
    trainer=YOLOEPESegTrainer,
    freeze=freeze,
)

Predict Usage

YOLOE supports both text-based and visual prompting. Using prompts is straightforward—just pass them through the predict method as shown below:

Example

Text prompts allow you to specify the classes that you wish to detect through textual descriptions. The following code shows how you can use YOLOE to detect people and buses in an image:

from ultralytics import YOLOE

# Initialize a YOLOE model
model = YOLOE("yoloe-11l-seg.pt")  # or select yoloe-11s/m-seg.pt for different sizes

# Set text prompt to detect person and bus. You only need to do this once after you load the model.
names = ["person", "bus"]
model.set_classes(names, model.get_text_pe(names))

# Run detection on the given image
results = model.predict("path/to/image.jpg")

# Show results
results[0].show()

Visual prompts allow you to guide the model by showing it visual examples of the target classes, rather than describing them in text.

The visual_prompts argument takes a dictionary with two keys: bboxes and cls. Each bounding box in bboxes should tightly enclose an example of the object you want the model to detect, and the corresponding entry in cls specifies the class label for that box. This pairing tells the model, "This is what class X looks like—now find more like it."

Class IDs (cls) in visual_prompts are used to associate each bounding box with a specific category within your prompt. They aren't fixed labels, but temporary identifiers you assign to each example. The only requirement is that class IDs must be sequential, starting from 0. This helps the model correctly associate each box with its respective class.

You can provide visual prompts directly within the same image you want to run inference on. For example:

import numpy as np

from ultralytics import YOLOE
from ultralytics.models.yolo.yoloe import YOLOEVPSegPredictor

# Initialize a YOLOE model
model = YOLOE("yoloe-11l-seg.pt")

# Define visual prompts using bounding boxes and their corresponding class IDs.
# Each box highlights an example of the object you want the model to detect.
visual_prompts = dict(
    bboxes=np.array(
        [
            [221.52, 405.8, 344.98, 857.54],  # Box enclosing person
            [120, 425, 160, 445],  # Box enclosing glasses
        ],
    ),
    cls=np.array(
        [
            0,  # ID to be assigned for person
            1,  # ID to be assigned for glasses
        ]
    ),
)

# Run inference on an image, using the provided visual prompts as guidance
results = model.predict(
    "ultralytics/assets/bus.jpg",
    visual_prompts=visual_prompts,
    predictor=YOLOEVPSegPredictor,
)

# Show results
results[0].show()

Or you can provide examples from a separate reference image using the refer_image argument. In that case, the bboxes and cls in visual_prompts should describe objects in the reference image, not the target image you're making predictions on:

Note

If source is a video or stream, the model automatically uses the first frame as the refer_image. This means your visual_prompts are applied to that initial frame to help the model understand what to look for in the rest of the video. Alternatively, you can explicitly pass any specific frame as the refer_image to control which visual examples the model uses as reference.

import numpy as np

from ultralytics import YOLOE
from ultralytics.models.yolo.yoloe import YOLOEVPSegPredictor

# Initialize a YOLOE model
model = YOLOE("yoloe-11l-seg.pt")

# Define visual prompts based on a separate reference image
visual_prompts = dict(
    bboxes=np.array([[221.52, 405.8, 344.98, 857.54]]),  # Box enclosing person
    cls=np.array([0]),  # ID to be assigned for person
)

# Run prediction on a different image, using reference image to guide what to look for
results = model.predict(
    "ultralytics/assets/zidane.jpg",  # Target image for detection
    refer_image="ultralytics/assets/bus.jpg",  # Reference image used to get visual prompts
    visual_prompts=visual_prompts,
    predictor=YOLOEVPSegPredictor,
)

# Show results
results[0].show()
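
As mentioned in the note above, visual prompts also work on video sources, where they are matched against the first frame unless refer_image is given explicitly. The following is a minimal sketch under the assumption that a local video file video.mp4 and a saved reference frame first_frame.jpg exist; the box coordinates are placeholders.

import numpy as np

from ultralytics import YOLOE
from ultralytics.models.yolo.yoloe import YOLOEVPSegPredictor

# Initialize a YOLOE model
model = YOLOE("yoloe-11l-seg.pt")

# Visual prompts describing an object in the reference frame (placeholder box)
visual_prompts = dict(
    bboxes=np.array([[221.52, 405.8, 344.98, 857.54]]),
    cls=np.array([0]),
)

# Run on the whole video; the prompts are resolved against the explicit reference frame
results = model.predict(
    "video.mp4",  # hypothetical video path
    refer_image="first_frame.jpg",  # hypothetical reference frame
    visual_prompts=visual_prompts,
    predictor=YOLOEVPSegPredictor,
)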

You can also pass multiple target images to run prediction on:

import numpy as np

from ultralytics import YOLOE
from ultralytics.models.yolo.yoloe import YOLOEVPSegPredictor

# Initialize a YOLOE model
model = YOLOE("yoloe-11l-seg.pt")

# Define visual prompts using bounding boxes and their corresponding class IDs.
# Each box highlights an example of the object you want the model to detect.
visual_prompts = dict(
    bboxes=[
        np.array(
            [
                [221.52, 405.8, 344.98, 857.54],  # Box enclosing person
                [120, 425, 160, 445],  # Box enclosing glasses
            ],
        ),
        np.array([[150, 200, 1150, 700]]),
    ],
    cls=[
        np.array(
            [
                0,  # ID to be assigned for person
                1,  # ID to be assigned for glasses
            ]
        ),
        np.array([0]),
    ],
)

# Run inference on multiple images, using the provided visual prompts as guidance
results = model.predict(
    ["ultralytics/assets/bus.jpg", "ultralytics/assets/zidane.jpg"],
    visual_prompts=visual_prompts,
    predictor=YOLOEVPSegPredictor,
)

# Show results
results[0].show()

YOLOE also includes prompt-free variants that come with a built-in vocabulary. These models don't require any prompts and work like traditional YOLO models. Instead of relying on user-provided labels or visual examples, they detect objects from a predefined list of 4,585 classes based on the tag set used by the Recognize Anything Model Plus (RAM++).

from ultralytics import YOLOE

# Initialize a YOLOE model
model = YOLOE("yoloe-11l-seg-pf.pt")

# Run prediction. No prompts required.
results = model.predict("path/to/image.jpg")

# Show results
results[0].show()

Val Usage

Example

Model validation on a dataset is streamlined as follows:

from ultralytics import YOLOE

# Create a YOLOE model
model = YOLOE("yoloe-11l-seg.pt")  # or select yoloe-m/l-seg.pt for different sizes

# Conduct model validation on the COCO128-seg example dataset
metrics = model.val(data="coco128-seg.yaml")

By default, the provided dataset is used to extract visual embeddings for each category:

from ultralytics import YOLOE

# Create a YOLOE model
model = YOLOE("yoloe-11l-seg.pt")  # or select yoloe-m/l-seg.pt for different sizes

# Conduct model validation on the COCO128-seg example dataset
metrics = model.val(data="coco128-seg.yaml", load_vp=True)

Alternatively, another dataset can be used as the reference dataset for extracting the visual embeddings for each category. Note that the reference dataset must contain exactly the same categories as the provided dataset:

from ultralytics import YOLOE

# Create a YOLOE model
model = YOLOE("yoloe-11l-seg.pt")  # or select yoloe-m/l-seg.pt for different sizes

# Conduct model validation on the COCO128-seg example dataset
metrics = model.val(data="coco128-seg.yaml", load_vp=True, refer_data="coco.yaml")

Train Official Models

Prepare datasets

Note

Training the official YOLOE models requires segmentation annotations for the training data. The official team provides a script, powered by SAM2.1 models, that converts detection datasets to segmentation annotations. Alternatively, you can download the Processed Segment Annotations listed in the table below, also provided by the official team.

  • Train data

| Dataset | Type | Samples | Boxes | Raw Detection Annotations | Processed Segment Annotations |
| ------- | ---- | ------- | ----- | ------------------------- | ----------------------------- |
| Objects365v1 | Detection | 609k | 9621k | objects365_train.json | objects365_train_segm.json |
| GQA | Grounding | 621k | 3681k | final_mixed_train_no_coco.json | final_mixed_train_no_coco_segm.json |
| Flickr30k | Grounding | 149k | 641k | final_flickr_separateGT_train.json | final_flickr_separateGT_train_segm.json |

  • Val data

| Dataset | Type | Annotation Files |
| ------- | ---- | ---------------- |
| LVIS minival | Detection | minival.txt |
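
If you prefer to generate the segment annotations yourself instead of downloading the processed files, the core idea of the conversion can be approximated with the Ultralytics SAM interface: prompt a SAM 2.1 model with each existing detection box and keep the returned mask. This is only a rough sketch, not the official conversion script; the weights name, image path, and box coordinates are assumptions.

from ultralytics import SAM

# Load a SAM 2.1 checkpoint (name assumed; use whichever SAM2 weights you have)
sam_model = SAM("sam2.1_b.pt")

# Prompt SAM with an existing detection box to get a segmentation mask for that object
results = sam_model("path/to/image.jpg", bboxes=[100, 150, 400, 500])

# Polygon coordinates that could be written into a segment annotation file
polygons = results[0].masks.xy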

Launching training from scratch

Note

Visual prompt models are fine-tuned from well-trained text prompt models.

Example

from ultralytics import YOLOE
from ultralytics.models.yolo.yoloe import YOLOESegTrainerFromScratch

data = dict(
    train=dict(
        yolo_data=["Objects365.yaml"],
        grounding_data=[
            dict(
                img_path="../datasets/flickr/full_images/",
                json_file="../datasets/flickr/annotations/final_flickr_separateGT_train_segm.json",
            ),
            dict(
                img_path="../datasets/mixed_grounding/gqa/images",
                json_file="../datasets/mixed_grounding/annotations/final_mixed_train_no_coco_segm.json",
            ),
        ],
    ),
    val=dict(yolo_data=["lvis.yaml"]),
)

model = YOLOE("yoloe-11l-seg.yaml")
model.train(
    data=data,
    batch=128,
    epochs=30,
    close_mosaic=2,
    optimizer="AdamW",
    lr0=2e-3,
    warmup_bias_lr=0.0,
    weight_decay=0.025,
    momentum=0.9,
    workers=4,
    trainer=YOLOESegTrainerFromScratch,
    device="0,1,2,3,4,5,6,7",
)

Since only the SAVPE module needs to be updated during training, you can convert the well-trained text prompt model to a detection model and adopt the detection pipeline, which lowers the training cost. Note that this step is optional; you can also start directly from the segmentation model.

import torch

from ultralytics import YOLOE

# Build a detection-only model and load it with the weights of the trained segmentation checkpoint
det_model = YOLOE("yoloe-11l.yaml")
state = torch.load("yoloe-11l-seg.pt")
det_model.load(state["model"])
det_model.save("yoloe-11l-seg-det.pt")

Start training:

from ultralytics import YOLOE
from ultralytics.models.yolo.yoloe import YOLOEVPTrainer

data = dict(
    train=dict(
        yolo_data=["Objects365.yaml"],
        grounding_data=[
            dict(
                img_path="../datasets/flickr/full_images/",
                json_file="../datasets/flickr/annotations/final_flickr_separateGT_train_segm.json",
            ),
            dict(
                img_path="../datasets/mixed_grounding/gqa/images",
                json_file="../datasets/mixed_grounding/annotations/final_mixed_train_no_coco_segm.json",
            ),
        ],
    ),
    val=dict(yolo_data=["lvis.yaml"]),
)

model = YOLOE("yoloe-11l-seg.pt")
# replace to yoloe-11l-seg-det.pt if converted to detection model
# model = YOLOE("yoloe-11l-seg-det.pt")

# Freeze every layer except the SAVPE module
head_index = len(model.model.model) - 1
freeze = list(range(0, head_index))
for name, child in model.model.model[-1].named_children():
    if "savpe" not in name:
        freeze.append(f"{head_index}.{name}")

model.train(
    data=data,
    batch=128,
    epochs=2,
    close_mosaic=2,
    optimizer="AdamW",
    lr0=16e-3,
    warmup_bias_lr=0.0,
    weight_decay=0.025,
    momentum=0.9,
    workers=4,
    trainer=YOLOEVPTrainer,
    device="0,1,2,3,4,5,6,7",
    freeze=freeze,
)

Convert back to a segmentation model after training. This is only needed if you converted the segmentation model to a detection model before training.

from copy import deepcopy

from ultralytics import YOLOE

model = YOLOE("yoloe-11l-seg.yaml")
model.load("yoloe-11l-seg.pt")

# Copy the trained SAVPE module from the visual prompt checkpoint into the segmentation model
vp_model = YOLOE("yoloe-11l-vp.pt")
model.model.model[-1].savpe = deepcopy(vp_model.model.model[-1].savpe)
model.eval()
model.save("yoloe-11l-seg.pt")

Similar to visual prompt training, for the prompt-free model only the specialized prompt embedding needs to be updated during training. You can again convert the well-trained text prompt model to a detection model and adopt the detection pipeline to lower the training cost. Note that this step is optional; you can also start directly from the segmentation model.

import torch

from ultralytics import YOLOE

det_model = YOLOE("yoloe-11l.yaml")
state = torch.load("yoloe-11l-seg.pt")
det_model.load(state["model"])
det_model.save("yoloe-11l-seg-det.pt")

Start training:

from ultralytics import YOLOE
from ultralytics.models.yolo.yoloe import YOLOEPEFreeTrainer

data = dict(
    train=dict(
        yolo_data=["Objects365.yaml"],
        grounding_data=[
            dict(
                img_path="../datasets/flickr/full_images/",
                json_file="../datasets/flickr/annotations/final_flickr_separateGT_train_segm.json",
            ),
            dict(
                img_path="../datasets/mixed_grounding/gqa/images",
                json_file="../datasets/mixed_grounding/annotations/final_mixed_train_no_coco_segm.json",
            ),
        ],
    ),
    val=dict(yolo_data=["lvis.yaml"]),
)

model = YOLOE("yoloe-11l-seg.pt")
# replace to yoloe-11l-seg-det.pt if converted to detection model
# model = YOLOE("yoloe-11l-seg-det.pt")

# Freeze everything except the specialized prompt embedding layers (cv3.*.2) in the head
head_index = len(model.model.model) - 1
freeze = [str(f) for f in range(0, head_index)]
for name, child in model.model.model[-1].named_children():
    if "cv3" not in name:
        freeze.append(f"{head_index}.{name}")

freeze.extend(
    [
        f"{head_index}.cv3.0.0",
        f"{head_index}.cv3.0.1",
        f"{head_index}.cv3.1.0",
        f"{head_index}.cv3.1.1",
        f"{head_index}.cv3.2.0",
        f"{head_index}.cv3.2.1",
    ]
)

model.train(
    data=data,
    batch=128,
    epochs=1,
    close_mosaic=1,
    optimizer="AdamW",
    lr0=2e-3,
    warmup_bias_lr=0.0,
    weight_decay=0.025,
    momentum=0.9,
    workers=4,
    trainer=YOLOEPEFreeTrainer,
    device="0,1,2,3,4,5,6,7",
    freeze=freeze,
    single_cls=True,  # this is needed
)

Convert back to a segmentation model after training. This is only needed if you converted the segmentation model to a detection model before training.

from copy import deepcopy

from ultralytics import YOLOE

model = YOLOE("yoloe-11l-seg.pt")
model.eval()

# Fuse a single generic "object" text embedding into the text-prompt model's head
pf_model = YOLOE("yoloe-11l-seg-pf.pt")
names = ["object"]
tpe = model.get_text_pe(names)
model.set_classes(names, tpe)
model.model.model[-1].fuse(model.model.pe)

# Copy the trained prompt-free classification layers (cv3.*.2) from the prompt-free checkpoint
model.model.model[-1].cv3[0][2] = deepcopy(pf_model.model.model[-1].cv3[0][2]).requires_grad_(True)
model.model.model[-1].cv3[1][2] = deepcopy(pf_model.model.model[-1].cv3[1][2]).requires_grad_(True)
model.model.model[-1].cv3[2][2] = deepcopy(pf_model.model.model[-1].cv3[2][2]).requires_grad_(True)
del model.model.pe
model.save("yoloe-11l-seg-pf.pt")

YOLOE Performance Comparison

YOLOE matches or exceeds the accuracy of closed-set YOLO models on standard benchmarks like COCO, without compromising speed or model size. The table below compares YOLOE-L (built on YOLO11) against corresponding YOLOv8 and YOLO11 models:

| Model | COCO mAP50-95 | Inference Speed (T4) | Parameters | GFLOPs (640px) |
| ----- | ------------- | -------------------- | ---------- | -------------- |
| YOLOv8-L (closed-set) | 52.9% | 9.06 ms (110 FPS) | 43.7 M | 165.2 B |
| YOLO11-L (closed-set) | 53.5% | 6.2 ms (130 FPS) | 26.2 M | 86.9 B |
| YOLOE-L (open-vocab) | 52.6% | 6.2 ms (130 FPS) | 26.2 M | 86.9 B |

YOLO11-L and YOLOE-L have identical architectures (prompt modules disabled in YOLO11-L), resulting in identical inference speed and similar GFLOPs estimates.

YOLOE-L achieves 52.6% mAP, nearly matching YOLOv8-L (52.9%) while using roughly 40% fewer parameters (26.2 M vs. 43.7 M). It processes 640×640 images in 6.2 ms (161 FPS) compared to YOLOv8-L's 9.06 ms (110 FPS), highlighting YOLO11's efficiency. Crucially, YOLOE's open-vocabulary modules incur no inference cost, so the added open-world capability comes without a speed or accuracy penalty.

For zero-shot and transfer tasks, YOLOE excels: on LVIS, YOLOE-small improves over YOLO-Worldv2 by +3.5 AP using 3× less training resources. Fine-tuning YOLOE-L from LVIS to COCO also required 4× less training time than YOLOv8-L, underscoring its efficiency and adaptability. YOLOE further maintains YOLO's hallmark speed, achieving 300+ FPS on a T4 GPU and ~64 FPS on iPhone 12 via CoreML, ideal for edge and mobile deployments.

Note

Benchmark conditions: YOLOE results are from models pre-trained on Objects365, GoldG, and LVIS, then fine-tuned or evaluated on COCO. YOLOE's slight mAP advantage over YOLOv8 comes from extensive pre-training. Without this open-vocab training, YOLOE matches similar-sized YOLO models, affirming its SOTA accuracy and open-world flexibility without performance penalties.

Comparison with Previous Models

YOLOE introduces notable advancements over prior YOLO models and open-vocabulary detectors:

  • YOLOE vs YOLOv5:
    YOLOv5 offered good speed-accuracy balance but required retraining for new classes and used anchor-based heads. In contrast, YOLOE is anchor-free and dynamically detects new classes. YOLOE, building on YOLOv8's improvements, achieves higher accuracy (52.6% vs. YOLOv5's ~50% mAP on COCO) and integrates instance segmentation, unlike YOLOv5.

  • YOLOE vs YOLOv8:
    YOLOE extends YOLOv8's redesigned architecture, achieving similar or superior accuracy (52.6% mAP with ~26M parameters vs. YOLOv8-L's 52.9% with ~44M parameters). It significantly reduces training time due to stronger pre-training. The key advancement is YOLOE's open-world capability, detecting unseen objects (e.g., "bird scooter" or "peace symbol") via prompts, unlike YOLOv8's closed-set design.

  • YOLOE vs YOLO11:
    YOLO11 improves upon YOLOv8 with enhanced efficiency and fewer parameters (~22% reduction). YOLOE inherits these gains directly, matching YOLO11's inference speed and parameter count (~26M parameters), while adding open-vocabulary detection and segmentation. In closed-set scenarios, YOLOE is equivalent to YOLO11, but crucially adds adaptability to detect unseen classes, achieving YOLO11 + open-world capability without compromising speed.

  • YOLOE vs previous open-vocabulary detectors:
    Earlier open-vocab models (GLIP, OWL-ViT, YOLO-World) relied heavily on vision-language transformers, leading to slow inference. YOLOE surpasses these in zero-shot accuracy (e.g., +3.5 AP vs. YOLO-Worldv2) while running 1.4× faster with significantly lower training resources. Compared to transformer-based approaches (e.g., GLIP), YOLOE offers orders-of-magnitude faster inference, effectively bridging the accuracy-efficiency gap in open-set detection.

In summary, YOLOE maintains YOLO's renowned speed and efficiency, surpasses predecessors in accuracy, integrates segmentation, and introduces powerful open-world detection, making it uniquely versatile and practical.

Use Cases and Applications

YOLOE's open-vocabulary detection and segmentation enable diverse applications beyond traditional fixed-class models:

  • Open-World Object Detection:
    Ideal for dynamic scenarios like robotics, where robots recognize previously unseen objects using prompts, or security systems quickly adapting to new threats (e.g., hazardous items) without retraining.

  • Few-Shot and One-Shot Detection:
    Using visual prompts (SAVPE), YOLOE rapidly learns new objects from single reference images—perfect for industrial inspection (identifying parts or defects instantly) or custom surveillance, enabling visual searches with minimal setup.

  • Large-Vocabulary & Long-Tail Recognition:
    Equipped with a vocabulary of 1000+ classes, YOLOE excels in tasks like biodiversity monitoring (detecting rare species), museum collections, retail inventory, or e-commerce, reliably identifying many classes without extensive per-class training.

  • Interactive Detection and Segmentation:
    YOLOE supports real-time interactive applications such as searchable video/image retrieval, augmented reality (AR), and intuitive image editing, driven by natural inputs (text or visual prompts). Users can dynamically isolate, identify, or edit objects precisely using segmentation masks.

  • Automated Data Labeling and Bootstrapping:
    YOLOE facilitates rapid dataset creation by providing initial bounding box and segmentation annotations, significantly reducing human labeling efforts. Particularly valuable in analytics of large media collections, where it can auto-identify objects present, assisting in building specialized models faster.

  • Segmentation for Any Object:
    Extends segmentation capabilities to arbitrary objects through prompts—particularly beneficial for medical imaging, microscopy, or satellite imagery analysis, automatically identifying and precisely segmenting structures without specialized pre-trained models. Unlike models like SAM, YOLOE simultaneously recognizes and segments objects automatically, aiding in tasks like content creation or scene understanding.

Across all these use cases, YOLOE's core advantage is versatility, providing a unified model for detection, recognition, and segmentation across dynamic scenarios. Its efficiency ensures real-time performance on resource-constrained devices, ideal for robotics, autonomous driving, defense, and beyond.

Tip

Choose YOLOE's mode based on your needs:

  • Closed-set mode: For fixed-class tasks (max speed and accuracy).
  • Prompted mode: Add new objects quickly via text or visual prompts.
  • Prompt-free open-set mode: General detection across many categories (ideal for cataloging and discovery).

Often, combining modes—such as prompt-free discovery followed by targeted prompts—leverages YOLOE's full potential.
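
One possible way to combine modes, sketched below with the checkpoints used earlier on this page: run the prompt-free model to discover what is present, then query a text-prompted model with only the classes of interest. The image path and the class filter are placeholders, and the sketch assumes at least one filtered class is found.

from ultralytics import YOLOE

# Stage 1: prompt-free discovery using the built-in vocabulary
pf_model = YOLOE("yoloe-11l-seg-pf.pt")
discovered = pf_model.predict("path/to/image.jpg")
found_names = {pf_model.names[int(c)] for c in discovered[0].boxes.cls}

# Stage 2: targeted detection with text prompts for the classes you care about
targets = [n for n in found_names if n in {"person", "bus"}]  # example filter, assumed non-empty
model = YOLOE("yoloe-11l-seg.pt")
model.set_classes(targets, model.get_text_pe(targets))
results = model.predict("path/to/image.jpg")
results[0].show()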

Training and Inference

YOLOE integrates seamlessly with the Ultralytics Python API and CLI, similar to other YOLO models (YOLOv8, YOLO-World). Here's how to quickly get started:

Training and inference with YOLOE

from ultralytics import YOLO

# Load pre-trained YOLOE model and train on custom data
model = YOLO("yoloe-11s.pt")
model.train(data="path/to/data.yaml", epochs=50, imgsz=640)

# Run inference using text prompts ("person", "bus")
model.set_classes(["person", "bus"])
results = model.predict(source="test_images/street.jpg")
results[0].save()  # save annotated output

Here, YOLOE behaves like a standard detector by default but easily switches to prompted detection by specifying classes (set_classes). Results contain bounding boxes, masks, and labels.
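
Since YOLOE returns standard Ultralytics Results objects, boxes, labels, confidences, and masks can be read directly from them. A brief sketch using the standard Results attributes (the image path is a placeholder):

from ultralytics import YOLOE

model = YOLOE("yoloe-11l-seg.pt")
names = ["person", "bus"]
model.set_classes(names, model.get_text_pe(names))

results = model.predict("path/to/image.jpg")
r = results[0]

boxes = r.boxes.xyxy.cpu().numpy()  # (N, 4) box coordinates
labels = [r.names[int(c)] for c in r.boxes.cls]  # class name per detection
confs = r.boxes.conf.cpu().numpy()  # confidence scores
masks = r.masks.data.cpu().numpy() if r.masks is not None else None  # (N, H, W) masks
print(labels, confs)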

# Training YOLOE on custom dataset
yolo train model=yoloe-11s.pt data=path/to/data.yaml epochs=50 imgsz=640

# Inference with text prompts
yolo predict model=yoloe-11s.pt source="test_images/street.jpg" classes="person,bus"

CLI prompts (classes) guide YOLOE similarly to Python's set_classes. Visual prompting (image-based queries) currently requires the Python API.

Other Supported Tasks

  • Validation: Evaluate accuracy easily with model.val() or yolo val.
  • Export: Export YOLOE models (model.export()) to ONNX, TensorRT, etc., facilitating deployment (see the sketch after this list).
  • Tracking: YOLOE supports object tracking (yolo track) when integrated, useful for tracking prompted classes in videos.
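
A short sketch of the export and tracking calls listed above, using the standard Ultralytics API; the export format and video path are placeholders.

from ultralytics import YOLOE

model = YOLOE("yoloe-11l-seg.pt")
names = ["person", "bus"]
model.set_classes(names, model.get_text_pe(names))

# Export the prompted model for deployment (format is an example)
model.export(format="onnx")

# Track the prompted classes across a video (path is a placeholder)
results = model.track(source="path/to/video.mp4")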

Note

YOLOE automatically includes segmentation masks in inference results (results[0].masks), simplifying pixel-precise tasks like object extraction or measurement without needing separate models.

Getting Started

Quickly set up YOLOE with Ultralytics by following these steps:

  1. Installation: Install or update the Ultralytics package:

    pip install -U ultralytics
    
  2. Download YOLOE Weights: Pre-trained YOLOE models (e.g., YOLOE-v8-S/L, YOLOE-11 variants) are available from the YOLOE GitHub releases. Simply download your desired .pt file to load into the Ultralytics YOLO class.

  3. Hardware Requirements:

    • Inference: An NVIDIA GPU with ≥4-8 GB VRAM is recommended. Small models run efficiently on edge GPUs (e.g., Jetson) or on CPUs at lower resolutions.
    • Training: Fine-tuning YOLOE on custom data typically requires just one GPU. The extensive open-vocabulary pre-training (LVIS/Objects365) used by the authors required substantial compute (8× RTX 4090 GPUs).
  4. Configuration: YOLOE configurations use standard Ultralytics YAML files. Default configs (e.g., yoloe-s.yaml) typically suffice, but you can modify backbone, classes, or image size as needed.

  5. Running YOLOE:

    • Quick inference (prompt-free):
      yolo predict model=yoloe-s.pt source="image.jpg"
      
    • Prompted detection (text prompt example):

      yolo predict model=yoloe-s.pt source="kitchen.jpg" classes="bowl,apple"
      

      In Python:

      from ultralytics import YOLO
      
      model = YOLO("yoloe-s.pt")
      model.set_classes(["bowl", "apple"])
      results = model.predict("kitchen.jpg")
      results[0].save()
      
  6. Integration Tips:

    • Class names: Default YOLOE outputs use LVIS categories; use set_classes() to specify your own labels.
    • Speed: YOLOE has no overhead unless using prompts. Text prompts have minimal impact; visual prompts slightly more.
    • Batch inference: Supported directly (model.predict([img1, img2])). For image-specific prompts, run images individually.

The Ultralytics documentation provides further resources. YOLOE lets you easily explore powerful open-world capabilities within the familiar YOLO ecosystem.

Tip

Pro Tip: To maximize YOLOE's zero-shot accuracy, fine-tune from provided checkpoints rather than training from scratch. Use prompt words aligning with common training labels (see LVIS categories) to improve detection accuracy.

Citations and Acknowledgements

If YOLOE has contributed to your research or project, please cite the original paper by Ao Wang, Lihao Liu, Hui Chen, Zijia Lin, Jungong Han, and Guiguang Ding from Tsinghua University:

@misc{wang2025yoloerealtimeseeing,
      title={YOLOE: Real-Time Seeing Anything},
      author={Ao Wang and Lihao Liu and Hui Chen and Zijia Lin and Jungong Han and Guiguang Ding},
      year={2025},
      eprint={2503.07465},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2503.07465},
}

For further reading, the original YOLOE paper is available on arXiv. The project's source code and additional resources can be accessed via their GitHub repository.

FAQ

How does YOLOE differ from YOLO-World?

While both YOLOE and YOLO-World enable open-vocabulary detection, YOLOE offers several advantages. YOLOE achieves +3.5 AP higher accuracy on LVIS while using 3× less training resources and running 1.4× faster than YOLO-Worldv2. YOLOE also supports three prompting modes (text, visual, and internal vocabulary), whereas YOLO-World primarily focuses on text prompts. Additionally, YOLOE includes built-in instance segmentation capabilities, providing pixel-precise masks for detected objects without additional overhead.

Can I use YOLOE as a regular YOLO model?

Yes, YOLOE can function exactly like a standard YOLO model with no performance penalty. When used in closed-set mode (without prompts), YOLOE's open-vocabulary modules are re-parameterized into the standard detection head, resulting in identical speed and accuracy to equivalent YOLO11 models. This makes YOLOE extremely versatile—you can use it as a traditional detector for maximum speed and then switch to open-vocabulary mode only when needed.

What types of prompts can I use with YOLOE?

YOLOE supports three types of prompts:

  1. Text prompts: Specify object classes using natural language (e.g., "person", "traffic light", "bird scooter")
  2. Visual prompts: Provide reference images of objects you want to detect
  3. Internal vocabulary: Use YOLOE's built-in vocabulary of 1200+ categories without external prompts

This flexibility allows you to adapt YOLOE to various scenarios without retraining the model, making it particularly useful for dynamic environments where detection requirements change frequently.

How does YOLOE handle instance segmentation?

YOLOE integrates instance segmentation directly into its architecture by extending the detection head with a mask prediction branch. This approach is similar to YOLOv8-Seg but works for any prompted object class. Segmentation masks are automatically included in inference results and can be accessed via results[0].masks. This unified approach eliminates the need for separate detection and segmentation models, streamlining workflows for applications requiring pixel-precise object boundaries.
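
For example, the masks from a YOLOE prediction can be pulled out for simple pixel-level measurements. A small sketch using standard Results and NumPy operations (the image path is a placeholder):

from ultralytics import YOLOE

model = YOLOE("yoloe-11l-seg.pt")
names = ["person"]
model.set_classes(names, model.get_text_pe(names))

results = model.predict("path/to/image.jpg")
masks = results[0].masks

if masks is not None:
    binary = masks.data.cpu().numpy() > 0.5  # (N, H, W) boolean masks
    areas = binary.sum(axis=(1, 2))  # pixel area of each detected instance
    print("Instance areas (pixels):", areas)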

How does YOLOE handle inference with custom prompts?

Similar to YOLO-World, YOLOE supports a "prompt-then-detect" strategy that utilizes an offline vocabulary to enhance efficiency. Custom prompts like captions or specific object categories are pre-encoded and stored as offline vocabulary embeddings. This approach streamlines the detection process without requiring retraining. You can dynamically set these prompts within the model to tailor it to specific detection tasks:

from ultralytics import YOLO

# Initialize a YOLOE model
model = YOLO("yoloe-s.pt")

# Define custom classes
model.set_classes(["person", "bus"])

# Execute prediction on an image
results = model.predict("path/to/image.jpg")

# Show results
results[0].show()