Skip to content

Reference for ultralytics/models/yolo/yoloe/train.py

Improvements

This page is sourced from https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/yoloe/train.py. Have an improvement or example to add? Open a Pull Request — thank you! 🙏


class ultralytics.models.yolo.yoloe.train.YOLOETrainer

YOLOETrainer(self, cfg = DEFAULT_CFG, overrides: dict | None = None, _callbacks = None)

Bases: DetectionTrainer

A trainer class for YOLOE object detection models.

This class extends DetectionTrainer to provide specialized training functionality for YOLOE models, including custom model initialization, validation, and dataset building with multi-modal support.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `cfg` | `dict` | Configuration dictionary with default training settings from `DEFAULT_CFG`. | `DEFAULT_CFG` |
| `overrides` | `dict`, optional | Dictionary of parameter overrides for the default configuration. | `None` |
| `_callbacks` | `list`, optional | List of callback functions to be applied during training. | `None` |

Attributes

| Name | Type | Description |
| --- | --- | --- |
| `loss_names` | `tuple` | Names of loss components used during training. |

Methods

| Name | Description |
| --- | --- |
| `build_dataset` | Build YOLO Dataset. |
| `get_model` | Return a YOLOEModel initialized with the specified configuration and weights. |
| `get_validator` | Return a YOLOEDetectValidator for YOLOE model validation. |

Source code in `ultralytics/models/yolo/yoloe/train.py` · View on GitHub
class YOLOETrainer(DetectionTrainer):
    """A trainer class for YOLOE object detection models.

    This class extends DetectionTrainer to provide specialized training functionality for YOLOE models, including custom
    model initialization, validation, and dataset building with multi-modal support.

    Attributes:
        loss_names (tuple): Names of loss components used during training.

    Methods:
        get_model: Initialize and return a YOLOEModel with specified configuration.
        get_validator: Return a YOLOEDetectValidator for model validation.
        build_dataset: Build YOLO dataset with multi-modal support for training.
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides: dict | None = None, _callbacks=None):
        """Initialize the YOLOE Trainer with specified configurations.

        Args:
            cfg (dict): Configuration dictionary with default training settings from DEFAULT_CFG.
            overrides (dict, optional): Dictionary of parameter overrides for the default configuration. The dict is
                copied before `overlap_mask` is forced, so the caller's dict is left untouched.
            _callbacks (list, optional): List of callback functions to be applied during training.

        Raises:
            AssertionError: If `overrides` enables `compile`, which YOLOE training requires to be off.
        """
        # Copy so that forcing `overlap_mask` below does not leak into the caller's dict.
        overrides = {} if overrides is None else dict(overrides)
        # `.get("model")` keeps the message from raising KeyError when no 'model' override was given.
        assert not overrides.get("compile"), f"Training with 'model={overrides.get('model')}' requires 'compile=False'"
        overrides["overlap_mask"] = False
        super().__init__(cfg, overrides, _callbacks)


method ultralytics.models.yolo.yoloe.train.YOLOETrainer.build_dataset

def build_dataset(self, img_path: str, mode: str = "train", batch: int | None = None)

Build YOLO Dataset.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `img_path` | `str` | Path to the folder containing images. | required |
| `mode` | `str` | `'train'` or `'val'` mode; users can customize different augmentations for each mode. | `"train"` |
| `batch` | `int`, optional | Batch size, used for rectangular training. | `None` |

Returns

| Type | Description |
| --- | --- |
| `Dataset` | YOLO dataset configured for training or validation. |

Source code in `ultralytics/models/yolo/yoloe/train.py` · View on GitHub
def build_dataset(self, img_path: str, mode: str = "train", batch: int | None = None):
    """Construct the YOLO dataset used by YOLOE for one data split.

    Args:
        img_path (str): Directory that holds the images.
        mode (str): Either 'train' or 'val'; augmentations can be customized per mode.
        batch (int, optional): Batch size, relevant for rectangular batching.

    Returns:
        (Dataset): Dataset prepared for training or validation.
    """
    # Grid size is the model's largest stride, never below 32 (0 is used while no model exists yet).
    largest_stride = unwrap_model(self.model).stride.max() if self.model else 0
    grid_size = max(int(largest_stride), 32)
    is_val = mode == "val"
    is_train = mode == "train"
    return build_yolo_dataset(
        self.args,
        img_path,
        batch,
        self.data,
        mode=mode,
        rect=is_val,  # rectangular batches only for validation
        stride=grid_size,
        multi_modal=is_train,  # multi-modal samples only when training
    )


method ultralytics.models.yolo.yoloe.train.YOLOETrainer.get_model

def get_model(self, cfg = None, weights = None, verbose: bool = True)

Return a YOLOEModel initialized with the specified configuration and weights.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `cfg` | `dict \| str`, optional | Model configuration. Can be a dictionary containing a `'yaml_file'` key, a direct path to a YAML file, or `None` to use the default configuration. | `None` |
| `weights` | `str \| Path`, optional | Path to a pretrained weights file to load into the model. | `None` |
| `verbose` | `bool` | Whether to display model information during initialization. | `True` |

Returns

| Type | Description |
| --- | --- |
| `YOLOEModel` | The initialized YOLOE model. |

Notes

  • The number of classes (nc) is hard-coded to a maximum of 80 following the official configuration.
  • The nc parameter here represents the maximum number of different text samples in one image, rather than the actual number of classes.
Source code in `ultralytics/models/yolo/yoloe/train.py` · View on GitHub
def get_model(self, cfg=None, weights=None, verbose: bool = True):
    """Build a YOLOEModel from a configuration and optionally load pretrained weights into it.

    Args:
        cfg (dict | str, optional): Model configuration. Either a dict carrying a 'yaml_file' key, a direct path to a
            YAML file, or None to use the default configuration.
        weights (str | Path, optional): Pretrained weights file to load; skipped when falsy.
        verbose (bool): Whether to print model information (only honored when RANK == -1).

    Returns:
        (YOLOEModel): The constructed YOLOE model.

    Notes:
        - `nc` is capped at 80, following the official configuration.
        - Here `nc` is the maximum number of distinct text samples in one image, not the true class count.
    """
    # NOTE: `nc` is the max number of distinct text samples per image rather than the actual class count,
    # and following the official config it is capped at 80 for now.
    model_cfg = cfg["yaml_file"] if isinstance(cfg, dict) else cfg
    channels = self.data["channels"]
    max_text_samples = min(self.data["nc"], 80)
    show_info = verbose and RANK == -1
    model = YOLOEModel(model_cfg, ch=channels, nc=max_text_samples, verbose=show_info)
    if weights:
        model.load(weights)
    return model


method ultralytics.models.yolo.yoloe.train.YOLOETrainer.get_validator

def get_validator(self)

Return a YOLOEDetectValidator for YOLOE model validation.

Source code in `ultralytics/models/yolo/yoloe/train.py` · View on GitHub
def get_validator(self):
    """Set the tracked loss names and create the YOLOEDetectValidator used during training."""
    self.loss_names = ("box", "cls", "dfl")
    validator_args = copy(self.args)  # shallow copy so the validator cannot alter the trainer's args object
    return YOLOEDetectValidator(
        self.test_loader,
        save_dir=self.save_dir,
        args=validator_args,
        _callbacks=self.callbacks,
    )





class ultralytics.models.yolo.yoloe.train.YOLOEPETrainer

YOLOEPETrainer()

Bases: DetectionTrainer

Fine-tune YOLOE model using linear probing approach.

This trainer freezes most model layers and only trains specific projection layers for efficient fine-tuning on new datasets while preserving pretrained features.

Methods

| Name | Description |
| --- | --- |
| `get_model` | Return a YOLOEModel initialized with the specified config and weights. |

Source code in `ultralytics/models/yolo/yoloe/train.py` · View on GitHub
class YOLOEPETrainer(DetectionTrainer):


method ultralytics.models.yolo.yoloe.train.YOLOEPETrainer.get_model

def get_model(self, cfg = None, weights = None, verbose: bool = True)

Return YOLOEModel initialized with specified config and weights.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `cfg` | `dict \| str`, optional | Model configuration. | `None` |
| `weights` | `str`, optional | Path to pretrained weights. | `None` |
| `verbose` | `bool` | Whether to display model information. | `True` |

Returns

| Type | Description |
| --- | --- |
| `YOLOEModel` | Initialized model with frozen layers except for specific projection layers. |

Source code in `ultralytics/models/yolo/yoloe/train.py` · View on GitHub
def get_model(self, cfg=None, weights=None, verbose: bool = True):
    """Return YOLOEModel initialized with specified config and weights, prepared for linear probing.

    The model is built with the dataset's class count, pretrained weights are loaded, and text embeddings for the
    dataset's class names are fused into the classification head. The final layer of every `cv3` branch is then
    replaced by a copy with gradients enabled, so that those layers can be trained.

    Args:
        cfg (dict | str, optional): Model configuration; a dict is expected to carry a 'yaml_file' key.
        weights (str, optional): Path to pretrained weights. Required: must be non-empty.
        verbose (bool): Whether to display model information (only honored when RANK == -1).

    Returns:
        (YOLOEModel): Model whose last `cv3` projection layers are trainable copies. Freezing of the remaining layers
            is presumably handled by the trainer's freeze settings (not visible here).

    Raises:
        AssertionError: If `weights` is None or empty.
    """
    # NOTE: unlike YOLOETrainer.get_model, `nc` here is the dataset's actual class count (no cap at 80),
    # since the full list of dataset class names is set on the model below.
    model = YOLOEModel(
        cfg["yaml_file"] if isinstance(cfg, dict) else cfg,
        ch=self.data["channels"],
        nc=self.data["nc"],
        verbose=verbose and RANK == -1,
    )

    del model.model[-1].savpe

    # A truthiness check (not `is not None`) so an empty path cannot silently skip loading pretrained weights.
    assert weights, "Pretrained weights must be provided for linear probing."
    model.load(weights)

    model.eval()
    names = list(self.data["names"].values())
    # NOTE: `get_text_pe` depends on the text model and YOLOEDetect.reprta;
    # results are correct as long as proper pretrained weights were loaded above.
    tpe = model.get_text_pe(names)
    model.set_classes(names, tpe)
    head = model.model[-1]
    head.fuse(model.pe)  # fuse text embeddings into the classification head
    # Replace the final layer of each cv3 branch with a trainable copy (the linear-probing parameters).
    for branch in head.cv3:
        branch[2] = deepcopy(branch[2]).requires_grad_(True)
    del model.pe
    model.train()

    return model





class ultralytics.models.yolo.yoloe.train.YOLOETrainerFromScratch

YOLOETrainerFromScratch()

Bases: YOLOETrainer, WorldTrainerFromScratch

Train YOLOE models from scratch with text embedding support.

This trainer combines YOLOE training capabilities with world training features, enabling training from scratch with text embeddings and grounding datasets.

Methods

| Name | Description |
| --- | --- |
| `build_dataset` | Build YOLO Dataset for training or validation. |
| `generate_text_embeddings` | Generate text embeddings for a list of text samples. |

Source code in `ultralytics/models/yolo/yoloe/train.py` · View on GitHub
class YOLOETrainerFromScratch(YOLOETrainer, WorldTrainerFromScratch):


method ultralytics.models.yolo.yoloe.train.YOLOETrainerFromScratch.build_dataset

def build_dataset(self, img_path: list[str] | str, mode: str = "train", batch: int | None = None)

Build YOLO Dataset for training or validation.

This method constructs appropriate datasets based on the mode and input paths, handling both standard YOLO datasets and grounding datasets with different formats.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `img_path` | `list[str] \| str` | Path to the folder containing images, or a list of such paths. | required |
| `mode` | `str` | `'train'` or `'val'` mode, allowing customized augmentations for each mode. | `"train"` |
| `batch` | `int`, optional | Batch size, used for rectangular training/validation. | `None` |

Returns

| Type | Description |
| --- | --- |
| `YOLOConcatDataset \| Dataset` | The constructed dataset for training or validation. |

Source code in `ultralytics/models/yolo/yoloe/train.py` · View on GitHub
def build_dataset(self, img_path: list[str] | str, mode: str = "train", batch: int | None = None):
    """Construct the training or validation dataset through WorldTrainerFromScratch's builder.

    The world trainer's builder handles both standard YOLO datasets and grounding datasets, choosing based on the
    mode and the given path(s).

    Args:
        img_path (list[str] | str): Image folder, or a list of such paths.
        mode (str): Either 'train' or 'val'; augmentations can be customized per mode.
        batch (int, optional): Batch size, used for rectangular training/validation.

    Returns:
        (YOLOConcatDataset | Dataset): The dataset built for the requested mode.
    """
    # Called explicitly rather than via super(): with bases (YOLOETrainer, WorldTrainerFromScratch),
    # super() would resolve to YOLOETrainer.build_dataset first.
    world_builder = WorldTrainerFromScratch.build_dataset
    return world_builder(self, img_path, mode, batch)


method ultralytics.models.yolo.yoloe.train.YOLOETrainerFromScratch.generate_text_embeddings

def generate_text_embeddings(self, texts: list[str], batch: int, cache_dir: Path)

Generate text embeddings for a list of text samples.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `texts` | `list[str]` | List of text samples to encode. | required |
| `batch` | `int` | Batch size for processing. | required |
| `cache_dir` | `Path` | Directory to save/load cached embeddings. | required |

Returns

| Type | Description |
| --- | --- |
| `dict` | Dictionary mapping text samples to their embeddings. |

Source code in `ultralytics/models/yolo/yoloe/train.py` · View on GitHub
def generate_text_embeddings(self, texts: list[str], batch: int, cache_dir: Path):
    """Generate text embeddings for a list of text samples, reusing an on-disk cache when it matches.

    The cache file lives in `cache_dir` and is named after the text model. It is reused only when its keys are
    exactly the unique entries of `texts`. Otherwise embeddings are recomputed and the cache is overwritten.

    Args:
        texts (list[str]): List of text samples to encode.
        batch (int): Batch size for processing.
        cache_dir (Path): Directory to save/load cached embeddings.

    Returns:
        (dict): Dictionary mapping each unique text sample to its embedding.
    """
    model = "mobileclip:blt"
    cache_path = cache_dir / f"text_embeddings_{model.replace(':', '_').replace('/', '_')}.pt"
    if cache_path.exists():
        LOGGER.info(f"Reading existed cache from '{cache_path}'")
        # NOTE(review): torch.load unpickles the file; only load caches this trainer wrote itself.
        txt_map = torch.load(cache_path, map_location=self.device)
        # Compare as sets: the cached dict has unique keys, so a sorted-list comparison would never match
        # (forcing a needless re-encode) whenever `texts` contains duplicates.
        if set(txt_map) == set(texts):
            return txt_map
    LOGGER.info(f"Caching text embeddings to '{cache_path}'")
    assert self.model is not None
    txt_feats = unwrap_model(self.model).get_text_pe(texts, batch, without_reprta=True, cache_clip_model=False)
    txt_map = dict(zip(texts, txt_feats.squeeze(0)))
    torch.save(txt_map, cache_path)
    return txt_map





class ultralytics.models.yolo.yoloe.train.YOLOEPEFreeTrainer

YOLOEPEFreeTrainer()

Bases: YOLOEPETrainer, YOLOETrainerFromScratch

Train prompt-free YOLOE model.

This trainer combines linear probing capabilities with from-scratch training for prompt-free YOLOE models that don't require text prompts during inference.

Methods

| Name | Description |
| --- | --- |
| `get_validator` | Return a DetectionValidator for YOLO model validation. |
| `preprocess_batch` | Preprocess a batch of images for YOLOE training, adjusting formatting and dimensions as needed. |
| `set_text_embeddings` | Set text embeddings for datasets to accelerate training by caching category names. |

Source code in `ultralytics/models/yolo/yoloe/train.py` · View on GitHub
class YOLOEPEFreeTrainer(YOLOEPETrainer, YOLOETrainerFromScratch):


method ultralytics.models.yolo.yoloe.train.YOLOEPEFreeTrainer.get_validator

def get_validator(self)

Return a DetectionValidator for YOLO model validation.

Source code in `ultralytics/models/yolo/yoloe/train.py` · View on GitHub
def get_validator(self):
    """Set the tracked loss names and create a plain DetectionValidator (not the YOLOE-specific one)."""
    self.loss_names = ("box", "cls", "dfl")
    validator_args = copy(self.args)  # shallow copy so the validator cannot alter the trainer's args object
    return DetectionValidator(
        self.test_loader,
        save_dir=self.save_dir,
        args=validator_args,
        _callbacks=self.callbacks,
    )


method ultralytics.models.yolo.yoloe.train.YOLOEPEFreeTrainer.preprocess_batch

def preprocess_batch(self, batch)

Preprocess a batch of images for YOLOE training, adjusting formatting and dimensions as needed.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `batch` | | Batch of data to preprocess. | required |

Source code in `ultralytics/models/yolo/yoloe/train.py` · View on GitHub
def preprocess_batch(self, batch):
    """Preprocess a YOLOE training batch with the plain DetectionTrainer pipeline.

    DetectionTrainer's implementation is invoked directly, so any `preprocess_batch` override in intermediate base
    classes is bypassed. This presumably skips text-embedding handling from the from-scratch/world trainers.
    """
    detection_preprocess = DetectionTrainer.preprocess_batch
    return detection_preprocess(self, batch)


method ultralytics.models.yolo.yoloe.train.YOLOEPEFreeTrainer.set_text_embeddings

def set_text_embeddings(self, datasets, batch: int)

Set text embeddings for datasets to accelerate training by caching category names.

This method collects unique category names from all datasets, generates text embeddings for them, and caches these embeddings to improve training efficiency. The embeddings are stored in a file in the parent directory of the first dataset's image path.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `datasets` | `list[Dataset]` | List of datasets containing category names to process. | required |
| `batch` | `int` | Batch size for processing text embeddings. | required |

Notes

The method creates a dictionary mapping text samples to their embeddings and stores it at the path specified by `cache_path`. If the cache file already exists, it is loaded instead of regenerating the embeddings. Note: in `YOLOEPEFreeTrainer` this override is a no-op (its body is `pass`), so none of this caching is performed for prompt-free training.

Source code in `ultralytics/models/yolo/yoloe/train.py` · View on GitHub
def set_text_embeddings(self, datasets, batch: int):
    """Do nothing: prompt-free training does not cache text embeddings.

    The inherited implementation collects unique category names from the datasets, generates their text
    embeddings and caches them on disk. This override intentionally performs none of that, because prompt-free
    YOLOE models do not rely on text prompts. No embeddings are generated, loaded, or written.

    Args:
        datasets (list[Dataset]): Datasets whose category names would otherwise be embedded (unused).
        batch (int): Batch size that would be used for text-embedding generation (unused).
    """
    pass





class ultralytics.models.yolo.yoloe.train.YOLOEVPTrainer

YOLOEVPTrainer()

Bases: YOLOETrainerFromScratch

Train YOLOE model with visual prompts.

This trainer extends YOLOETrainerFromScratch to support visual prompt-based training, where visual cues are provided alongside images to guide the detection process.

Methods

| Name | Description |
| --- | --- |
| `_close_dataloader_mosaic` | Close mosaic augmentation and add visual prompt loading to the training dataset. |
| `build_dataset` | Build YOLO Dataset for training or validation with visual prompts. |

Source code in `ultralytics/models/yolo/yoloe/train.py` · View on GitHub
class YOLOEVPTrainer(YOLOETrainerFromScratch):


method ultralytics.models.yolo.yoloe.train.YOLOEVPTrainer._close_dataloader_mosaic

def _close_dataloader_mosaic(self)

Close mosaic augmentation and add visual prompt loading to the training dataset.

Source code in `ultralytics/models/yolo/yoloe/train.py` · View on GitHub
def _close_dataloader_mosaic(self):
    """Turn off mosaic augmentation, then attach a LoadVisualPrompt transform to every training dataset."""
    super()._close_dataloader_mosaic()
    # Re-attached here, presumably because closing mosaic rebuilds the dataset transforms — TODO confirm.
    train_dataset = self.train_loader.dataset
    # A concatenated dataset wraps several sub-datasets, each with its own transform pipeline.
    targets = train_dataset.datasets if isinstance(train_dataset, YOLOConcatDataset) else [train_dataset]
    for target in targets:
        target.transforms.append(LoadVisualPrompt())


method ultralytics.models.yolo.yoloe.train.YOLOEVPTrainer.build_dataset

def build_dataset(self, img_path: list[str] | str, mode: str = "train", batch: int | None = None)

Build YOLO Dataset for training or validation with visual prompts.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `img_path` | `list[str] \| str` | Path to the folder containing images, or a list of such paths. | required |
| `mode` | `str` | `'train'` or `'val'` mode, allowing customized augmentations for each mode. | `"train"` |
| `batch` | `int`, optional | Batch size, used for rectangular training/validation. | `None` |

Returns

| Type | Description |
| --- | --- |
| `Dataset` | YOLO dataset configured for training or validation, with a `LoadVisualPrompt` transform appended (in every mode, not only training). |

Source code in `ultralytics/models/yolo/yoloe/train.py` · View on GitHub
def build_dataset(self, img_path: list[str] | str, mode: str = "train", batch: int | None = None):
    """Build YOLO Dataset for training or validation with visual prompts.

    The parent builder creates the dataset. A LoadVisualPrompt transform is then appended to it, or to each member
    of a YOLOConcatDataset. This happens in every mode, not only 'train'.

    Args:
        img_path (list[str] | str): Path to the folder containing images or list of paths.
        mode (str): 'train' mode or 'val' mode, allowing customized augmentations for each mode.
        batch (int, optional): Size of batches, used for rectangular training/validation.

    Returns:
        (Dataset): YOLO dataset for training or validation with a LoadVisualPrompt transform appended.
    """
    dataset = super().build_dataset(img_path, mode, batch)
    # A concatenated dataset wraps several sub-datasets, each with its own transform pipeline.
    members = dataset.datasets if isinstance(dataset, YOLOConcatDataset) else [dataset]
    for member in members:
        member.transforms.append(LoadVisualPrompt())
    return dataset





📅 Created 8 months ago ✏️ Updated 2 days ago
glenn-jocher, RizwanMunawar