
Reference for ultralytics/models/yolo/world/train.py

Note

This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/world/train.py. If you spot a problem, please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!



ultralytics.models.yolo.world.train.WorldTrainer

Bases: DetectionTrainer

A class to fine-tune a world model on a close-set dataset.

Example
from ultralytics.models.yolo.world import WorldTrainer

args = dict(model='yolov8s-world.pt', data='coco8.yaml', epochs=3)
trainer = WorldTrainer(overrides=args)
trainer.train()
Source code in ultralytics/models/yolo/world/train.py
class WorldTrainer(yolo.detect.DetectionTrainer):
    """
    A class to fine-tune a world model on a close-set dataset.

    Example:
        ```python
        from ultralytics.models.yolo.world import WorldTrainer

        args = dict(model='yolov8s-world.pt', data='coco8.yaml', epochs=3)
        trainer = WorldTrainer(overrides=args)
        trainer.train()
        ```
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """Initialize a WorldTrainer object with given arguments."""
        if overrides is None:
            overrides = {}
        super().__init__(cfg, overrides, _callbacks)

        # Import and assign clip
        try:
            import clip
        except ImportError:
            checks.check_requirements("git+https://github.com/ultralytics/CLIP.git")
            import clip
        self.clip = clip

    def get_model(self, cfg=None, weights=None, verbose=True):
        """Return WorldModel initialized with specified config and weights."""
        # NOTE: This `nc` here is the max number of different text samples in one image, rather than the actual `nc`.
        # NOTE: Following the official config, nc hard-coded to 80 for now.
        model = WorldModel(
            cfg["yaml_file"] if isinstance(cfg, dict) else cfg,
            ch=3,
            nc=min(self.data["nc"], 80),
            verbose=verbose and RANK == -1,
        )
        if weights:
            model.load(weights)
        self.add_callback("on_pretrain_routine_end", on_pretrain_routine_end)

        return model

    def build_dataset(self, img_path, mode="train", batch=None):
        """
        Build YOLO Dataset.

        Args:
            img_path (str): Path to the folder containing images.
            mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
            batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
        """
        gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
        return build_yolo_dataset(
            self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs, multi_modal=mode == "train"
        )

    def preprocess_batch(self, batch):
        """Preprocesses a batch of images for YOLOWorld training, adjusting formatting and dimensions as needed."""
        batch = super().preprocess_batch(batch)

        # NOTE: add text features
        texts = list(itertools.chain(*batch["texts"]))
        text_token = self.clip.tokenize(texts).to(batch["img"].device)
        txt_feats = self.text_model.encode_text(text_token).to(dtype=batch["img"].dtype)  # torch.float32
        txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)
        batch["txt_feats"] = txt_feats.reshape(len(batch["texts"]), -1, txt_feats.shape[-1])
        return batch
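For a higher-level entry point, the trainer can also be supplied to the standard YOLO API. The following is a minimal sketch, assuming `yolov8s-world.pt` and `coco8.yaml` resolve as in the example above; passing `trainer=` explicitly is only needed if the loaded model does not already map to `WorldTrainer`.

```python
from ultralytics import YOLO
from ultralytics.models.yolo.world import WorldTrainer

# Load a pretrained YOLO-World checkpoint and fine-tune it with WorldTrainer
model = YOLO("yolov8s-world.pt")
model.train(data="coco8.yaml", epochs=3, trainer=WorldTrainer)
```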

__init__(cfg=DEFAULT_CFG, overrides=None, _callbacks=None)

Initialize a WorldTrainer object with the given arguments.

Source code in ultralytics/models/yolo/world/train.py
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
    """Initialize a WorldTrainer object with given arguments."""
    if overrides is None:
        overrides = {}
    super().__init__(cfg, overrides, _callbacks)

    # Import and assign clip
    try:
        import clip
    except ImportError:
        checks.check_requirements("git+https://github.com/ultralytics/CLIP.git")
        import clip
    self.clip = clip
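The constructor installs Ultralytics' CLIP fork lazily on first use via check_requirements. If you prefer to pre-install it, here is a minimal sketch, assuming pip is available in the current environment:

```python
import subprocess
import sys

# Install the same CLIP fork that check_requirements would otherwise pull in lazily
subprocess.run(
    [sys.executable, "-m", "pip", "install", "git+https://github.com/ultralytics/CLIP.git"],
    check=True,
)
```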

build_dataset(img_path, mode='train', batch=None)

Build YOLO Dataset.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `img_path` | `str` | Path to the folder containing images. | required |
| `mode` | `str` | `train` mode or `val` mode; users are able to customize different augmentations for each mode. | `'train'` |
| `batch` | `int` | Size of batches, this is for `rect`. | `None` |
Source code in ultralytics/models/yolo/world/train.py
def build_dataset(self, img_path, mode="train", batch=None):
    """
    Build YOLO Dataset.

    Args:
        img_path (str): Path to the folder containing images.
        mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
        batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
    """
    gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
    return build_yolo_dataset(
        self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs, multi_modal=mode == "train"
    )
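The two mode-dependent flags are the notable part: rectangular batching is enabled only for validation, and multi-modal (image + text) sampling only for training. A standalone sketch of that logic, independent of the trainer:

```python
def dataset_flags(mode: str) -> dict:
    """Illustrative helper mirroring the flags build_dataset passes to build_yolo_dataset."""
    return {"rect": mode == "val", "multi_modal": mode == "train"}


print(dataset_flags("train"))  # {'rect': False, 'multi_modal': True}
print(dataset_flags("val"))    # {'rect': True, 'multi_modal': False}
```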

get_model(cfg=None, weights=None, verbose=True)

Return WorldModel initialized with the specified config and weights.

Source code in ultralytics/models/yolo/world/train.py
def get_model(self, cfg=None, weights=None, verbose=True):
    """Return WorldModel initialized with specified config and weights."""
    # NOTE: This `nc` here is the max number of different text samples in one image, rather than the actual `nc`.
    # NOTE: Following the official config, nc hard-coded to 80 for now.
    model = WorldModel(
        cfg["yaml_file"] if isinstance(cfg, dict) else cfg,
        ch=3,
        nc=min(self.data["nc"], 80),
        verbose=verbose and RANK == -1,
    )
    if weights:
        model.load(weights)
    self.add_callback("on_pretrain_routine_end", on_pretrain_routine_end)

    return model
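Note that `nc` here is the maximum number of text samples per image (capped at 80 to match the official config), not the dataset's class count. A hedged sketch constructing a WorldModel the same way, assuming a hypothetical dataset dict and the `yolov8s-world.yaml` config:

```python
from ultralytics.nn.tasks import WorldModel

data = {"nc": 365}  # hypothetical dataset metadata, e.g. an Objects365-sized class count
model = WorldModel("yolov8s-world.yaml", ch=3, nc=min(data["nc"], 80), verbose=False)
```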

preprocess_batch(batch)

Preprocess a batch of images for YOLOWorld training, adjusting formatting and dimensions as needed.

Source code in ultralytics/models/yolo/world/train.py
def preprocess_batch(self, batch):
    """Preprocesses a batch of images for YOLOWorld training, adjusting formatting and dimensions as needed."""
    batch = super().preprocess_batch(batch)

    # NOTE: add text features
    texts = list(itertools.chain(*batch["texts"]))
    text_token = self.clip.tokenize(texts).to(batch["img"].device)
    txt_feats = self.text_model.encode_text(text_token).to(dtype=batch["img"].dtype)  # torch.float32
    txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)
    batch["txt_feats"] = txt_feats.reshape(len(batch["texts"]), -1, txt_feats.shape[-1])
    return batch
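The text features are L2-normalized and reshaped to `(batch, texts_per_image, embed_dim)` before being attached to the batch. A standalone sketch of that post-processing, using random tensors in place of CLIP embeddings:

```python
import torch

batch_size, texts_per_image, embed_dim = 2, 80, 512  # illustrative sizes
txt_feats = torch.randn(batch_size * texts_per_image, embed_dim)
txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)  # unit-norm rows
txt_feats = txt_feats.reshape(batch_size, -1, embed_dim)
print(txt_feats.shape)  # torch.Size([2, 80, 512])
```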



ultralytics.models.yolo.world.train.on_pretrain_routine_end(trainer)

Callback that sets the class names used for evaluation and loads a frozen CLIP text encoder once the pretraining routine ends.

Source code in ultralytics/models/yolo/world/train.py
def on_pretrain_routine_end(trainer):
    """Callback."""
    if RANK in {-1, 0}:
        # NOTE: for evaluation
        names = [name.split("/")[0] for name in list(trainer.test_loader.dataset.data["names"].values())]
        de_parallel(trainer.ema.ema).set_classes(names, cache_clip_model=False)
    device = next(trainer.model.parameters()).device
    trainer.text_model, _ = trainer.clip.load("ViT-B/32", device=device)
    for p in trainer.text_model.parameters():
        p.requires_grad_(False)
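A minimal sketch of the frozen CLIP text encoder this callback attaches to the trainer, assuming Ultralytics' CLIP fork is installed:

```python
import clip
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
text_model, _ = clip.load("ViT-B/32", device=device)
for p in text_model.parameters():
    p.requires_grad_(False)  # the text encoder stays frozen during training

tokens = clip.tokenize(["person", "bus"]).to(device)
with torch.no_grad():
    feats = text_model.encode_text(tokens)
print(feats.shape)  # torch.Size([2, 512])
```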





Created 2024-03-31, Updated 2024-05-18
Authors: glenn-jocher (1), Burhan-Q (1), Laughing-q (1)