Перейти к содержимому

Продвинутая настройка

И интерфейсы Ultralytics YOLO командной строки, и Python - это просто высокоуровневая абстракция на базовых исполнителях движка. Давай посмотрим на движок Trainer.



Смотри: Mastering Ultralytics YOLOv8 : Advanced Customization

BaseTrainer

BaseTrainer содержит общую шаблонную процедуру обучения. Его можно настроить под любую задачу, переопределив нужные функции или операции, при условии соблюдения правильных форматов. Например, ты можешь поддерживать свою собственную модель и dataloader, просто переопределив эти функции:

  • get_model(cfg, weights) - Функция, которая строит обучаемую модель
  • get_dataloader() - Функция, которая строит dataloader Более подробную информацию и исходный код можно найти в BaseTrainer Ссылка

DetectionTrainer

Вот как ты можешь использовать YOLOv8 DetectionTrainer и настроить его под себя.

from ultralytics.models.yolo.detect import DetectionTrainer

trainer = DetectionTrainer(overrides={...})
trainer.train()
trained_model = trainer.best  # get best model

Настройка тренажера DetectionTrainer

Давай настроим тренера чтобы обучить пользовательскую модель обнаружения которые не поддерживаются напрямую. Ты можешь сделать это, просто перегрузив существующий get_model функциональность:

from ultralytics.models.yolo.detect import DetectionTrainer


class CustomTrainer(DetectionTrainer):
    def get_model(self, cfg, weights):
        ...


trainer = CustomTrainer(overrides={...})
trainer.train()

Теперь ты понимаешь, что тебе нужно еще больше настроить тренера под себя:

  • Настройте loss function.
  • Добавь callback который загружает модель на твой Google Drive после каждых 10 epochs Вот как ты можешь это сделать:
from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.nn.tasks import DetectionModel


class MyCustomModel(DetectionModel):
    def init_criterion(self):
        ...


class CustomTrainer(DetectionTrainer):
    def get_model(self, cfg, weights):
        return MyCustomModel(...)


# callback to upload model weights
def log_model(trainer):
    last_weight_path = trainer.last
    print(last_weight_path)


trainer = CustomTrainer(overrides={...})
trainer.add_callback("on_train_epoch_end", log_model)  # Adds to existing callback
trainer.train()

Чтобы узнать больше о событиях, вызывающих обратный вызов, и точке входа, ознакомься с нашим руководством по обратным вызовам.

Другие компоненты двигателя

Есть и другие компоненты, которые можно настроить аналогичным образом, например Validators и Predictors. Подробнее о них читай в разделе Reference.



Создано 2023-11-12, Обновлено 2024-02-03
Авторы: glenn-jocher (4), RizwanMunawar (1), AyushExel (1), Laughing-q (1)

Комментарии