Skip to content

Reference for ultralytics/models/yolo/world/train.py

Note

This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/world/train.py. If you spot a problem please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!


ultralytics.models.yolo.world.train.WorldTrainer

WorldTrainer(cfg=DEFAULT_CFG, overrides=None, _callbacks=None)

Bases: DetectionTrainer

A class to fine-tune a world model on a closed-set dataset.

This trainer extends the DetectionTrainer to support training YOLO World models, which combine visual and textual features for improved object detection and understanding.

Attributes:

Name Type Description
clip module

The CLIP module for text-image understanding.

text_model module

The text encoder model from CLIP.

model WorldModel

The YOLO World model being trained.

data dict

Dataset configuration containing class information.

args dict

Training arguments and configuration.

Examples:

>>> from ultralytics.models.yolo.world import WorldTrainer
>>> args = dict(model="yolov8s-world.pt", data="coco8.yaml", epochs=3)
>>> trainer = WorldTrainer(overrides=args)
>>> trainer.train()

Parameters:

Name Type Description Default
cfg dict

Configuration for the trainer.

DEFAULT_CFG
overrides dict

Configuration overrides.

None
_callbacks list

List of callback functions.

None
Source code in ultralytics/models/yolo/world/train.py
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
    """
    Set up the trainer and expose the CLIP package on the instance as `self.clip`.

    Args:
        cfg (dict): Base configuration for the trainer.
        overrides (dict, optional): Configuration overrides; an empty dict is substituted when None.
        _callbacks (list, optional): Callback functions, passed through to the parent initializer.
    """
    super().__init__(cfg, {} if overrides is None else overrides, _callbacks)

    # CLIP may not be installed: on ImportError hand the requirement to `checks.check_requirements`, then retry.
    try:
        import clip as clip_module
    except ImportError:
        checks.check_requirements("git+https://github.com/ultralytics/CLIP.git")
        import clip as clip_module
    self.clip = clip_module

build_dataset

build_dataset(img_path, mode='train', batch=None)

Build YOLO Dataset for training or validation.

Parameters:

Name Type Description Default
img_path str

Path to the folder containing images.

required
mode str

Either train mode or val mode; users can customize different augmentations for each mode.

'train'
batch int

Size of batches; this is used for rect.

None

Returns:

Type Description
Dataset

YOLO dataset configured for training or validation.

Source code in ultralytics/models/yolo/world/train.py
def build_dataset(self, img_path, mode="train", batch=None):
    """
    Create the YOLO dataset used for training or validation.

    Args:
        img_path (str): Path to the folder containing images.
        mode (str): `train` or `val`; `rect` is enabled only for `val` and `multi_modal` only for `train`.
        batch (int, optional): Batch size, forwarded to the dataset builder (used for `rect`).

    Returns:
        (Dataset): Whatever `build_yolo_dataset` returns for the given settings.
    """
    # Use the model's largest stride when a model is available, but never go below 32.
    if self.model:
        max_stride = int(de_parallel(self.model).stride.max())
    else:
        max_stride = 0
    grid_size = max(max_stride, 32)

    return build_yolo_dataset(
        self.args,
        img_path,
        batch,
        self.data,
        mode=mode,
        rect=mode == "val",
        stride=grid_size,
        multi_modal=mode == "train",
    )

get_model

get_model(cfg=None, weights=None, verbose=True)

Return WorldModel initialized with specified config and weights.

Parameters:

Name Type Description Default
cfg Dict | str

Model configuration.

None
weights str

Path to pretrained weights.

None
verbose bool

Whether to display model info.

True

Returns:

Type Description
WorldModel

Initialized WorldModel.

Source code in ultralytics/models/yolo/world/train.py
def get_model(self, cfg=None, weights=None, verbose=True):
    """
    Build a WorldModel from a config and optionally load pretrained weights into it.

    Args:
        cfg (Dict | str, optional): Model configuration; when a dict is given, its "yaml_file" entry is used.
        weights (str, optional): Path to pretrained weights, loaded only when truthy.
        verbose (bool): Whether to display model info; only passed as True when RANK is -1.

    Returns:
        (WorldModel): The constructed model.
    """
    yaml_cfg = cfg["yaml_file"] if isinstance(cfg, dict) else cfg
    # NOTE: `nc` here is the max number of different text samples in one image, rather than the actual `nc`.
    # NOTE: Following the official config, it is capped at 80 for now.
    max_texts = min(self.data["nc"], 80)
    show_info = verbose and RANK == -1

    world_model = WorldModel(yaml_cfg, ch=3, nc=max_texts, verbose=show_info)
    if weights:
        world_model.load(weights)

    # Register the module-level `on_pretrain_routine_end` function under the event of the same name.
    self.add_callback("on_pretrain_routine_end", on_pretrain_routine_end)
    return world_model

preprocess_batch

preprocess_batch(batch)

Preprocess a batch of images and text for YOLOWorld training.

Source code in ultralytics/models/yolo/world/train.py
def preprocess_batch(self, batch):
    """Run the parent preprocessing, then attach L2-normalized text features as `batch["txt_feats"]`."""
    batch = super().preprocess_batch(batch)

    per_image_texts = batch["texts"]
    images = batch["img"]

    # Flatten the per-image text lists so all prompts are tokenized and encoded in a single pass.
    flat_texts = list(itertools.chain.from_iterable(per_image_texts))
    tokens = self.clip.tokenize(flat_texts).to(images.device)
    feats = self.text_model.encode_text(tokens).to(dtype=images.dtype)  # match the image dtype
    feats = feats / feats.norm(p=2, dim=-1, keepdim=True)

    # Regroup the features so the leading dimension is again one entry per item of batch["texts"].
    batch["txt_feats"] = feats.reshape(len(per_image_texts), -1, feats.shape[-1])
    return batch





ultralytics.models.yolo.world.train.on_pretrain_routine_end

on_pretrain_routine_end(trainer)

Callback to set up model classes and text encoder at the end of the pretrain routine.

Source code in ultralytics/models/yolo/world/train.py
def on_pretrain_routine_end(trainer):
    """Callback that sets evaluation class names on the EMA model and attaches a frozen CLIP text encoder."""
    if RANK in (-1, 0):
        # Keep only the part of each dataset name before the first "/" as the evaluation class name.
        dataset_names = trainer.test_loader.dataset.data["names"]
        class_names = [full_name.split("/")[0] for full_name in dataset_names.values()]
        de_parallel(trainer.ema.ema).set_classes(class_names, cache_clip_model=False)

    # Load the text encoder on the same device as the model parameters and disable gradients on it.
    model_device = next(trainer.model.parameters()).device
    trainer.text_model, _ = trainer.clip.load("ViT-B/32", device=model_device)
    for param in trainer.text_model.parameters():
        param.requires_grad_(False)



📅 Created 12 months ago ✏️ Updated 6 months ago