Перейти к содержимому

Ссылка для ultralytics/models/yolo/world/train.py

Примечание

Этот файл доступен по адресу https://github.com/ultralytics/ ultralytics/blob/main/ ultralytics/models/ yolo/world/train .py. Если ты заметил проблему, пожалуйста, помоги исправить ее, отправив Pull Request 🛠️. Спасибо 🙏!



ultralytics.models.yolo.world.train.WorldTrainer

Базы: DetectionTrainer

Класс для тонкой настройки модели мира на близком наборе данных.

Пример
from ultralytics.models.yolo.world import WorldModel

args = dict(model='yolov8s-world.pt', data='coco8.yaml', epochs=3)
trainer = WorldTrainer(overrides=args)
trainer.train()
Исходный код в ultralytics/models/yolo/world/train.py
class WorldTrainer(yolo.detect.DetectionTrainer):
    """
    A class to fine-tune a world model on a close-set dataset.

    Example:
        ```python
        from ultralytics.models.yolo.world import WorldModel

        args = dict(model='yolov8s-world.pt', data='coco8.yaml', epochs=3)
        trainer = WorldTrainer(overrides=args)
        trainer.train()
        ```
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """Initialize a WorldTrainer object with given arguments."""
        if overrides is None:
            overrides = {}
        super().__init__(cfg, overrides, _callbacks)

        # Import and assign clip
        try:
            import clip
        except ImportError:
            checks.check_requirements("git+https://github.com/ultralytics/CLIP.git")
            import clip
        self.clip = clip

    def get_model(self, cfg=None, weights=None, verbose=True):
        """Return WorldModel initialized with specified config and weights."""
        # NOTE: This `nc` here is the max number of different text samples in one image, rather than the actual `nc`.
        # NOTE: Following the official config, nc hard-coded to 80 for now.
        model = WorldModel(
            cfg["yaml_file"] if isinstance(cfg, dict) else cfg,
            ch=3,
            nc=min(self.data["nc"], 80),
            verbose=verbose and RANK == -1,
        )
        if weights:
            model.load(weights)
        self.add_callback("on_pretrain_routine_end", on_pretrain_routine_end)

        return model

    def build_dataset(self, img_path, mode="train", batch=None):
        """
        Build YOLO Dataset.

        Args:
            img_path (str): Path to the folder containing images.
            mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
            batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
        """
        gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
        return build_yolo_dataset(
            self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs, multi_modal=mode == "train"
        )

    def preprocess_batch(self, batch):
        """Preprocesses a batch of images for YOLOWorld training, adjusting formatting and dimensions as needed."""
        batch = super().preprocess_batch(batch)

        # NOTE: add text features
        texts = list(itertools.chain(*batch["texts"]))
        text_token = self.clip.tokenize(texts).to(batch["img"].device)
        txt_feats = self.text_model.encode_text(text_token).to(dtype=batch["img"].dtype)  # torch.float32
        txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)
        batch["txt_feats"] = txt_feats.reshape(len(batch["texts"]), -1, txt_feats.shape[-1])
        return batch

__init__(cfg=DEFAULT_CFG, overrides=None, _callbacks=None)

Инициализируй объект WorldTrainer с заданными аргументами.

Исходный код в ultralytics/models/yolo/world/train.py
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
    """Initialize a WorldTrainer object with given arguments."""
    if overrides is None:
        overrides = {}
    super().__init__(cfg, overrides, _callbacks)

    # Import and assign clip
    try:
        import clip
    except ImportError:
        checks.check_requirements("git+https://github.com/ultralytics/CLIP.git")
        import clip
    self.clip = clip

build_dataset(img_path, mode='train', batch=None)

Построй YOLO Dataset.

Параметры:

Имя Тип Описание По умолчанию
img_path str

Путь к папке, содержащей изображения.

требуется
mode str

train режим или val Пользователи могут настраивать различные дополнения для каждого режима.

'train'
batch int

Размер партий, это для rect. По умолчанию установлено значение "Нет".

None
Исходный код в ultralytics/models/yolo/world/train.py
def build_dataset(self, img_path, mode="train", batch=None):
    """
    Build YOLO Dataset.

    Args:
        img_path (str): Path to the folder containing images.
        mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
        batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
    """
    gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
    return build_yolo_dataset(
        self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs, multi_modal=mode == "train"
    )

get_model(cfg=None, weights=None, verbose=True)

Возвращает WorldModel, инициализированную с указанным конфигом и весами.

Исходный код в ultralytics/models/yolo/world/train.py
def get_model(self, cfg=None, weights=None, verbose=True):
    """Return WorldModel initialized with specified config and weights."""
    # NOTE: This `nc` here is the max number of different text samples in one image, rather than the actual `nc`.
    # NOTE: Following the official config, nc hard-coded to 80 for now.
    model = WorldModel(
        cfg["yaml_file"] if isinstance(cfg, dict) else cfg,
        ch=3,
        nc=min(self.data["nc"], 80),
        verbose=verbose and RANK == -1,
    )
    if weights:
        model.load(weights)
    self.add_callback("on_pretrain_routine_end", on_pretrain_routine_end)

    return model

preprocess_batch(batch)

Предварительно обрабатывает партию изображений для тренировки YOLOWorld, корректируя форматирование и размеры по мере необходимости.

Исходный код в ultralytics/models/yolo/world/train.py
def preprocess_batch(self, batch):
    """Preprocesses a batch of images for YOLOWorld training, adjusting formatting and dimensions as needed."""
    batch = super().preprocess_batch(batch)

    # NOTE: add text features
    texts = list(itertools.chain(*batch["texts"]))
    text_token = self.clip.tokenize(texts).to(batch["img"].device)
    txt_feats = self.text_model.encode_text(text_token).to(dtype=batch["img"].dtype)  # torch.float32
    txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)
    batch["txt_feats"] = txt_feats.reshape(len(batch["texts"]), -1, txt_feats.shape[-1])
    return batch



ultralytics.models.yolo.world.train.on_pretrain_routine_end(trainer)

Обратный звонок.

Исходный код в ultralytics/models/yolo/world/train.py
def on_pretrain_routine_end(trainer):
    """Callback."""
    if RANK in {-1, 0}:
        # NOTE: for evaluation
        names = [name.split("/")[0] for name in list(trainer.test_loader.dataset.data["names"].values())]
        de_parallel(trainer.ema.ema).set_classes(names, cache_clip_model=False)
    device = next(trainer.model.parameters()).device
    trainer.text_model, _ = trainer.clip.load("ViT-B/32", device=device)
    for p in trainer.text_model.parameters():
        p.requires_grad_(False)





Создано 2024-03-31, Обновлено 2024-03-31
Авторы: Laughing-q (1)