Skip to content

Reference for ultralytics/models/yolo/yoloe/train.py

Note

This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/yoloe/train.py. If you spot a problem please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!


ultralytics.models.yolo.yoloe.train.YOLOETrainer

YOLOETrainer(cfg=DEFAULT_CFG, overrides: dict | None = None, _callbacks=None)

Bases: DetectionTrainer

A trainer class for YOLOE object detection models.

This class extends DetectionTrainer to provide specialized training functionality for YOLOE models, including custom model initialization, validation, and dataset building with multi-modal support.

Attributes:

Name Type Description
loss_names tuple

Names of loss components used during training.

Methods:

Name Description
get_model

Initialize and return a YOLOEModel with specified configuration.

get_validator

Return a YOLOEDetectValidator for model validation.

build_dataset

Build YOLO dataset with multi-modal support for training.

Parameters:

Name Type Description Default
cfg dict

Configuration dictionary with default training settings from DEFAULT_CFG.

DEFAULT_CFG
overrides dict

Dictionary of parameter overrides for the default configuration.

None
_callbacks list

List of callback functions to be applied during training.

None
Source code in ultralytics/models/yolo/yoloe/train.py
37
38
39
40
41
42
43
44
45
46
47
48
49
50
def __init__(self, cfg=DEFAULT_CFG, overrides: dict | None = None, _callbacks=None):
    """
    Initialize the YOLOE Trainer with specified configurations.

    Args:
        cfg (dict): Configuration dictionary with default training settings from DEFAULT_CFG.
        overrides (dict, optional): Dictionary of parameter overrides for the default configuration. The caller's
            dict is left untouched; a shallow copy with ``overlap_mask=False`` is forwarded instead.
        _callbacks (list, optional): List of callback functions to be applied during training.

    Raises:
        AssertionError: If ``overrides`` enables ``compile``, which this trainer requires to be False.
    """
    # Shallow-copy so that forcing `overlap_mask` below does not leak back into the caller's dict.
    overrides = dict(overrides) if overrides else {}
    # `.get('model')` keeps the failure an AssertionError even when no 'model' key was supplied
    # (indexing would raise KeyError while building the message).
    assert not overrides.get("compile"), f"Training with 'model={overrides.get('model')}' requires 'compile=False'"
    overrides["overlap_mask"] = False
    super().__init__(cfg, overrides, _callbacks)

build_dataset

build_dataset(img_path: str, mode: str = 'train', batch: int | None = None)

Build YOLO Dataset.

Parameters:

Name Type Description Default
img_path str

Path to the folder containing images.

required
mode str

'train' mode or 'val' mode, users are able to customize different augmentations for each mode.

'train'
batch int

Size of batches, this is for rectangular training.

None

Returns:

Type Description
Dataset

YOLO dataset configured for training or validation.

Source code in ultralytics/models/yolo/yoloe/train.py
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
def build_dataset(self, img_path: str, mode: str = "train", batch: int | None = None):
    """
    Create the YOLO dataset for the requested split.

    Args:
        img_path (str): Folder holding the images.
        mode (str): Either 'train' or 'val'; each mode may use its own augmentations.
        batch (int, optional): Batch size, relevant for rectangular batching.

    Returns:
        (Dataset): Dataset prepared for training or validation.
    """
    model_stride = int(unwrap_model(self.model).stride.max() if self.model else 0)
    # The stride handed to the dataset is never smaller than 32.
    grid_stride = max(model_stride, 32)
    use_rect = mode == "val"
    use_multi_modal = mode == "train"
    return build_yolo_dataset(
        self.args,
        img_path,
        batch,
        self.data,
        mode=mode,
        rect=use_rect,
        stride=grid_stride,
        multi_modal=use_multi_modal,
    )

get_model

get_model(cfg=None, weights=None, verbose: bool = True)

Return a YOLOEModel initialized with the specified configuration and weights.

Parameters:

Name Type Description Default
cfg dict | str

Model configuration. Can be a dictionary containing a 'yaml_file' key, a direct path to a YAML file, or None to use default configuration.

None
weights str | Path

Path to pretrained weights file to load into the model.

None
verbose bool

Whether to display model information during initialization.

True

Returns:

Type Description
YOLOEModel

The initialized YOLOE model.

Notes
  • The number of classes (nc) is hard-coded to a maximum of 80 following the official configuration.
  • The nc parameter here represents the maximum number of different text samples in one image, rather than the actual number of classes.
Source code in ultralytics/models/yolo/yoloe/train.py
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
def get_model(self, cfg=None, weights=None, verbose: bool = True):
    """
    Build a YOLOEModel from the given configuration and optionally load weights into it.

    Args:
        cfg (dict | str, optional): Either a dict holding a 'yaml_file' entry, a path to a YAML file, or None.
        weights (str | Path, optional): Pretrained weights to load; skipped when falsy.
        verbose (bool): Print model information (only honoured when RANK == -1).

    Returns:
        (YOLOEModel): The constructed YOLOE model.

    Notes:
        - ``nc`` is capped at 80, mirroring the official configuration.
        - ``nc`` here is the upper bound on distinct text samples per image, not the dataset's true class count.
    """
    yaml_source = cfg["yaml_file"] if isinstance(cfg, dict) else cfg
    in_channels = self.data["channels"]
    # Cap of 80 follows the official config; this is a per-image text-sample limit, not the real `nc`.
    max_text_samples = min(self.data["nc"], 80)
    show_info = verbose and RANK == -1
    model = YOLOEModel(yaml_source, ch=in_channels, nc=max_text_samples, verbose=show_info)

    if not weights:
        return model
    model.load(weights)
    return model

get_validator

get_validator()

Return a YOLOEDetectValidator for YOLOE model validation.

Source code in ultralytics/models/yolo/yoloe/train.py
83
84
85
86
87
88
def get_validator(self):
    """Create the YOLOEDetectValidator that evaluates the YOLOE model during training."""
    self.loss_names = ("box", "cls", "dfl")
    # The validator gets its own shallow copy of the training arguments.
    validator_args = copy(self.args)
    return YOLOEDetectValidator(
        self.test_loader,
        save_dir=self.save_dir,
        args=validator_args,
        _callbacks=self.callbacks,
    )





ultralytics.models.yolo.yoloe.train.YOLOEPETrainer

YOLOEPETrainer(
    cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None
)

Bases: DetectionTrainer

Fine-tune YOLOE model using linear probing approach.

This trainer freezes most model layers and only trains specific projection layers for efficient fine-tuning on new datasets while preserving pretrained features.

Methods:

Name Description
get_model

Initialize YOLOEModel with frozen layers except projection layers.

Source code in ultralytics/models/yolo/detect/train.py
56
57
58
59
60
61
62
63
64
65
def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
    """
    Initialize a DetectionTrainer object for training YOLO object detection models.

    Args:
        cfg (dict, optional): Default configuration dictionary containing training parameters.
        overrides (dict, optional): Dictionary of parameter overrides for the default configuration.
        _callbacks (list, optional): List of callback functions to be executed during training.
    """
    # All setup is delegated to the parent trainer; arguments are forwarded unchanged.
    super().__init__(cfg, overrides, _callbacks)

get_model

get_model(cfg=None, weights=None, verbose: bool = True)

Return YOLOEModel initialized with specified config and weights.

Parameters:

Name Type Description Default
cfg dict | str

Model configuration.

None
weights str

Path to pretrained weights.

None
verbose bool

Whether to display model information.

True

Returns:

Type Description
YOLOEModel

Initialized model with frozen layers except for specific projection layers.

Source code in ultralytics/models/yolo/yoloe/train.py
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
def get_model(self, cfg=None, weights=None, verbose: bool = True):
    """
    Return YOLOEModel initialized with specified config and weights.

    Builds the model, loads the pretrained weights and embeds the dataset class names with the model's text
    encoder. It then fuses those embeddings into the detection head and replaces layer index 2 of every
    ``cv3`` branch with a deep copy that has ``requires_grad`` enabled.

    Args:
        cfg (dict | str, optional): Model configuration. A dict is expected to carry a 'yaml_file' key.
        weights (str, optional): Path to pretrained weights. Must not be None (linear probing needs them).
        verbose (bool): Whether to display model information (only honoured when RANK == -1).

    Returns:
        (YOLOEModel): Initialized model with frozen layers except for specific projection layers.

    Raises:
        AssertionError: If ``weights`` is None.
    """
    # Fail fast, before spending time constructing a model that could never be used.
    assert weights is not None, "Pretrained weights must be provided for linear probing."

    # NOTE: here `nc` is the dataset's actual class count (no cap), because the class names themselves are
    # embedded and fused into the head below.
    model = YOLOEModel(
        cfg["yaml_file"] if isinstance(cfg, dict) else cfg,
        ch=self.data["channels"],
        nc=self.data["nc"],
        verbose=verbose and RANK == -1,
    )

    # NOTE(review): `savpe` is presumably the visual-prompt embedding branch, unused when fusing text
    # embeddings — confirm.
    del model.model[-1].savpe

    if weights:
        model.load(weights)

    model.eval()
    names = list(self.data["names"].values())
    # NOTE: `get_text_pe` related to text model and YOLOEDetect.reprta,
    # it'd get correct results as long as loading proper pretrained weights.
    tpe = model.get_text_pe(names)
    model.set_classes(names, tpe)
    model.model[-1].fuse(model.pe)  # fuse text embeddings to classify head
    head = model.model[-1]
    # Re-enable gradients on layer index 2 of every cv3 branch, however many detection levels the head has.
    # The original hard-coded three branches, which raised IndexError for fewer and silently left extra ones
    # frozen. The deep copy gives each branch its own independent module.
    for branch in head.cv3:
        branch[2] = deepcopy(branch[2]).requires_grad_(True)
    del model.pe
    model.train()

    return model





ultralytics.models.yolo.yoloe.train.YOLOETrainerFromScratch

YOLOETrainerFromScratch(
    cfg=DEFAULT_CFG, overrides: dict | None = None, _callbacks=None
)

Bases: YOLOETrainer, WorldTrainerFromScratch

Train YOLOE models from scratch with text embedding support.

This trainer combines YOLOE training capabilities with world training features, enabling training from scratch with text embeddings and grounding datasets.

Methods:

Name Description
build_dataset

Build datasets for training with grounding support.

generate_text_embeddings

Generate and cache text embeddings for training.

Source code in ultralytics/models/yolo/yoloe/train.py
37
38
39
40
41
42
43
44
45
46
47
48
49
50
def __init__(self, cfg=DEFAULT_CFG, overrides: dict | None = None, _callbacks=None):
    """
    Initialize the YOLOE Trainer with specified configurations.

    Args:
        cfg (dict): Configuration dictionary with default training settings from DEFAULT_CFG.
        overrides (dict, optional): Dictionary of parameter overrides for the default configuration. The caller's
            dict is left untouched; a shallow copy with ``overlap_mask=False`` is forwarded instead.
        _callbacks (list, optional): List of callback functions to be applied during training.

    Raises:
        AssertionError: If ``overrides`` enables ``compile``, which this trainer requires to be False.
    """
    # Shallow-copy so that forcing `overlap_mask` below does not leak back into the caller's dict.
    overrides = dict(overrides) if overrides else {}
    # `.get('model')` keeps the failure an AssertionError even when no 'model' key was supplied
    # (indexing would raise KeyError while building the message).
    assert not overrides.get("compile"), f"Training with 'model={overrides.get('model')}' requires 'compile=False'"
    overrides["overlap_mask"] = False
    super().__init__(cfg, overrides, _callbacks)

build_dataset

build_dataset(
    img_path: list[str] | str, mode: str = "train", batch: int | None = None
)

Build YOLO Dataset for training or validation.

This method constructs appropriate datasets based on the mode and input paths, handling both standard YOLO datasets and grounding datasets with different formats.

Parameters:

Name Type Description Default
img_path list[str] | str

Path to the folder containing images or list of paths.

required
mode str

'train' mode or 'val' mode, allowing customized augmentations for each mode.

'train'
batch int

Size of batches, used for rectangular training/validation.

None

Returns:

Type Description
YOLOConcatDataset | Dataset

The constructed dataset for training or validation.

Source code in ultralytics/models/yolo/yoloe/train.py
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
def build_dataset(self, img_path: list[str] | str, mode: str = "train", batch: int | None = None):
    """
    Create the training or validation dataset using the world-trainer (grounding-aware) logic.

    Standard YOLO datasets and grounding datasets are both handled; the actual construction lives in
    ``WorldTrainerFromScratch.build_dataset``.

    Args:
        img_path (list[str] | str): A single image folder or a list of such paths.
        mode (str): Either 'train' or 'val'; each mode may use its own augmentations.
        batch (int, optional): Batch size, relevant for rectangular training/validation.

    Returns:
        (YOLOConcatDataset | Dataset): The dataset that was built.
    """
    # Call WorldTrainerFromScratch explicitly. YOLOETrainer is listed first among the bases, so plain
    # attribute lookup would otherwise resolve to YOLOETrainer.build_dataset.
    world_builder = WorldTrainerFromScratch.build_dataset
    dataset = world_builder(self, img_path, mode, batch)
    return dataset

generate_text_embeddings

generate_text_embeddings(texts: list[str], batch: int, cache_dir: Path)

Generate text embeddings for a list of text samples.

Parameters:

Name Type Description Default
texts list[str]

List of text samples to encode.

required
batch int

Batch size for processing.

required
cache_dir Path

Directory to save/load cached embeddings.

required

Returns:

Type Description
dict

Dictionary mapping text samples to their embeddings.

Source code in ultralytics/models/yolo/yoloe/train.py
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
def generate_text_embeddings(self, texts: list[str], batch: int, cache_dir: Path):
    """
    Return text embeddings for ``texts``, reusing an on-disk cache when it matches exactly.

    A cache file named after the text model lives in ``cache_dir``. It is reused only if its keys are
    exactly the requested texts. Otherwise the embeddings are recomputed with the current model and the
    cache file is overwritten.

    Args:
        texts (list[str]): Text samples to encode.
        batch (int): Batch size used when encoding.
        cache_dir (Path): Directory where the cache file is read from / written to.

    Returns:
        (dict): Mapping from each text sample to its embedding.
    """
    clip_id = "mobileclip:blt"
    file_tag = clip_id.replace(":", "_").replace("/", "_")
    cache_path = cache_dir / f"text_embeddings_{file_tag}.pt"

    if cache_path.exists():
        LOGGER.info(f"Reading existed cache from '{cache_path}'")
        cached = torch.load(cache_path, map_location=self.device)
        # Only a cache covering exactly the requested texts is valid.
        if sorted(cached) == sorted(texts):
            return cached

    LOGGER.info(f"Caching text embeddings to '{cache_path}'")
    assert self.model is not None
    encoder = unwrap_model(self.model)
    feats = encoder.get_text_pe(texts, batch, without_reprta=True, cache_clip_model=False)
    embeddings = dict(zip(texts, feats.squeeze(0)))
    torch.save(embeddings, cache_path)
    return embeddings





ultralytics.models.yolo.yoloe.train.YOLOEPEFreeTrainer

YOLOEPEFreeTrainer(
    cfg=DEFAULT_CFG, overrides: dict | None = None, _callbacks=None
)

Bases: YOLOEPETrainer, YOLOETrainerFromScratch

Train prompt-free YOLOE model.

This trainer combines linear probing capabilities with from-scratch training for prompt-free YOLOE models that don't require text prompts during inference.

Methods:

Name Description
get_validator

Return standard DetectionValidator for validation.

preprocess_batch

Preprocess batches without text features.

set_text_embeddings

Set text embeddings for datasets (no-op for prompt-free).

Source code in ultralytics/models/yolo/yoloe/train.py
37
38
39
40
41
42
43
44
45
46
47
48
49
50
def __init__(self, cfg=DEFAULT_CFG, overrides: dict | None = None, _callbacks=None):
    """
    Initialize the YOLOE Trainer with specified configurations.

    Args:
        cfg (dict): Configuration dictionary with default training settings from DEFAULT_CFG.
        overrides (dict, optional): Dictionary of parameter overrides for the default configuration. The caller's
            dict is left untouched; a shallow copy with ``overlap_mask=False`` is forwarded instead.
        _callbacks (list, optional): List of callback functions to be applied during training.

    Raises:
        AssertionError: If ``overrides`` enables ``compile``, which this trainer requires to be False.
    """
    # Shallow-copy so that forcing `overlap_mask` below does not leak back into the caller's dict.
    overrides = dict(overrides) if overrides else {}
    # `.get('model')` keeps the failure an AssertionError even when no 'model' key was supplied
    # (indexing would raise KeyError while building the message).
    assert not overrides.get("compile"), f"Training with 'model={overrides.get('model')}' requires 'compile=False'"
    overrides["overlap_mask"] = False
    super().__init__(cfg, overrides, _callbacks)

get_validator

get_validator()

Return a DetectionValidator for YOLO model validation.

Source code in ultralytics/models/yolo/yoloe/train.py
231
232
233
234
235
236
def get_validator(self):
    """Create a standard DetectionValidator (no YOLOE-specific validation) for the model being trained."""
    self.loss_names = ("box", "cls", "dfl")
    # Hand the validator its own shallow copy of the training arguments.
    validator_args = copy(self.args)
    return DetectionValidator(
        self.test_loader,
        save_dir=self.save_dir,
        args=validator_args,
        _callbacks=self.callbacks,
    )

preprocess_batch

preprocess_batch(batch)

Preprocess a batch of images using the standard DetectionTrainer preprocessing, without adding text features.

Source code in ultralytics/models/yolo/yoloe/train.py
238
239
240
def preprocess_batch(self, batch):
    """
    Preprocess a batch using plain detection-trainer logic.

    The call goes directly to ``DetectionTrainer.preprocess_batch``, bypassing the preprocessing of the
    intermediate YOLOE base classes, so no text features are attached to the batch.
    """
    detection_preprocess = DetectionTrainer.preprocess_batch
    return detection_preprocess(self, batch)

set_text_embeddings

set_text_embeddings(datasets, batch: int)

Set text embeddings for datasets; in this prompt-free trainer the method is intentionally a no-op.

Unlike the inherited implementation, this method does not collect category names from the datasets, generate text embeddings for them, or cache those embeddings in the parent directory of the first dataset's image path, because prompt-free training does not use text embeddings.

Parameters:

Name Type Description Default
datasets list[Dataset]

List of datasets containing category names to process.

required
batch int

Batch size for processing text embeddings.

required
Notes

The inherited implementation creates a dictionary mapping text samples to their embeddings and stores it at the path specified by 'cache_path', loading an existing cache file instead of regenerating the embeddings. This prompt-free override does none of that and simply returns None.

Source code in ultralytics/models/yolo/yoloe/train.py
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
def set_text_embeddings(self, datasets, batch: int):
    """
    Skip text-embedding setup for prompt-free training.

    This override is intentionally a no-op: prompt-free training does not use text embeddings. No category
    names are collected, no embeddings are generated, and no cache file is read or written.

    Args:
        datasets (list[Dataset]): Datasets whose category names would otherwise be embedded (unused).
        batch (int): Batch size that would otherwise be used for text encoding (unused).

    Returns:
        (None)
    """
    return None





ultralytics.models.yolo.yoloe.train.YOLOEVPTrainer

YOLOEVPTrainer(cfg=DEFAULT_CFG, overrides: dict | None = None, _callbacks=None)

Bases: YOLOETrainerFromScratch

Train YOLOE model with visual prompts.

This trainer extends YOLOETrainerFromScratch to support visual prompt-based training, where visual cues are provided alongside images to guide the detection process.

Methods:

Name Description
build_dataset

Build dataset with visual prompt loading transforms.

Source code in ultralytics/models/yolo/yoloe/train.py
37
38
39
40
41
42
43
44
45
46
47
48
49
50
def __init__(self, cfg=DEFAULT_CFG, overrides: dict | None = None, _callbacks=None):
    """
    Initialize the YOLOE Trainer with specified configurations.

    Args:
        cfg (dict): Configuration dictionary with default training settings from DEFAULT_CFG.
        overrides (dict, optional): Dictionary of parameter overrides for the default configuration. The caller's
            dict is left untouched; a shallow copy with ``overlap_mask=False`` is forwarded instead.
        _callbacks (list, optional): List of callback functions to be applied during training.

    Raises:
        AssertionError: If ``overrides`` enables ``compile``, which this trainer requires to be False.
    """
    # Shallow-copy so that forcing `overlap_mask` below does not leak back into the caller's dict.
    overrides = dict(overrides) if overrides else {}
    # `.get('model')` keeps the failure an AssertionError even when no 'model' key was supplied
    # (indexing would raise KeyError while building the message).
    assert not overrides.get("compile"), f"Training with 'model={overrides.get('model')}' requires 'compile=False'"
    overrides["overlap_mask"] = False
    super().__init__(cfg, overrides, _callbacks)

build_dataset

build_dataset(
    img_path: list[str] | str, mode: str = "train", batch: int | None = None
)

Build YOLO Dataset for training or validation with visual prompts.

Parameters:

Name Type Description Default
img_path list[str] | str

Path to the folder containing images or list of paths.

required
mode str

'train' mode or 'val' mode, allowing customized augmentations for each mode.

'train'
batch int

Size of batches, used for rectangular training/validation.

None

Returns:

Type Description
Dataset

YOLO dataset configured for training or validation, with a LoadVisualPrompt transform appended in every mode (train and val).

Source code in ultralytics/models/yolo/yoloe/train.py
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
def build_dataset(self, img_path: list[str] | str, mode: str = "train", batch: int | None = None):
    """
    Build the dataset via the parent trainer, then attach a visual-prompt loading transform.

    A ``LoadVisualPrompt`` transform is appended in every mode (train and val alike). A
    ``YOLOConcatDataset`` receives one on each of its sub-datasets.

    Args:
        img_path (list[str] | str): A single image folder or a list of such paths.
        mode (str): Either 'train' or 'val'; each mode may use its own augmentations.
        batch (int, optional): Batch size, relevant for rectangular training/validation.

    Returns:
        (Dataset): The parent-built dataset with ``LoadVisualPrompt`` appended to its transforms.
    """
    dataset = super().build_dataset(img_path, mode, batch)
    # Concatenated datasets hold several sub-datasets; each gets its own transform instance.
    targets = dataset.datasets if isinstance(dataset, YOLOConcatDataset) else (dataset,)
    for sub_dataset in targets:
        sub_dataset.transforms.append(LoadVisualPrompt())
    return dataset





📅 Created 5 months ago ✏️ Updated 5 months ago