高级自定义

Ultralytics YOLO 的命令行界面和 Python 界面都是构建在基础引擎执行器之上的高层抽象。本指南重点介绍 Trainer 引擎,讲解如何根据你的特定需求进行自定义。



Watch: Mastering Ultralytics YOLO: Advanced Customization
提示

有关常见训练器自定义(如自定义指标、类加权损失、模型保存、主干冻结和分层学习率)的实际示例,请参阅 自定义训练器 指南。

BaseTrainer

BaseTrainer 类提供了一个通用的训练流程,可适应各种任务。你可以通过重写特定函数或操作来进行自定义,同时需遵循所需的格式。例如,通过重写以下函数来集成你自己的自定义模型和数据加载器:

  • get_model(cfg, weights):构建待训练的模型。
  • get_dataloader():构建数据加载器。

有关更多详细信息和源代码,请参阅 BaseTrainer 参考

DetectionTrainer

以下是如何使用和自定义 Ultralytics YOLO DetectionTrainer 的方法:

from ultralytics.models.yolo.detect import DetectionTrainer

trainer = DetectionTrainer(overrides={...})
trainer.train()
trained_model = trainer.best  # Get the best model

自定义 DetectionTrainer

要训练一个未直接支持的自定义检测模型,请重载现有的 get_model 功能:

from ultralytics.models.yolo.detect import DetectionTrainer

class CustomTrainer(DetectionTrainer):
    def get_model(self, cfg=None, weights=None, verbose=True):
        """Loads a custom detection model given configuration and weight files."""
        ...

trainer = CustomTrainer(overrides={...})
trainer.train()

通过修改 损失函数 或添加 回调函数 以每 10 个 epoch 将模型上传到 Google Drive,来进一步自定义训练器。以下是一个示例:

from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.nn.tasks import DetectionModel

class MyCustomModel(DetectionModel):
    def init_criterion(self):
        """Initializes the loss function and adds a callback for uploading the model to Google Drive every 10 epochs."""
        ...

class CustomTrainer(DetectionTrainer):
    def get_model(self, cfg=None, weights=None, verbose=True):
        """Returns a customized detection model instance configured with specified config and weights."""
        return MyCustomModel(...)

# Callback to upload model weights
def log_model(trainer):
    """Logs the path of the last model weight used by the trainer."""
    last_weight_path = trainer.last
    print(last_weight_path)

trainer = CustomTrainer(overrides={...})
trainer.add_callback("on_train_epoch_end", log_model)  # Adds to existing callbacks
trainer.train()

有关回调触发事件和入口点的更多信息,请参阅 回调指南

其他引擎组件

你可以类似地自定义 ValidatorsPredictors 等其他组件。有关更多信息,请参考 验证器预测器 的文档。

将 YOLO 与自定义训练器结合使用

YOLO 模型类为训练器类提供了一个高层包装器。你可以利用这种架构在机器学习工作流中获得更大的灵活性:

from ultralytics import YOLO
from ultralytics.models.yolo.detect import DetectionTrainer

# Create a custom trainer
class MyCustomTrainer(DetectionTrainer):
    def get_model(self, cfg=None, weights=None, verbose=True):
        """Custom code implementation."""
        ...

# Initialize YOLO model
model = YOLO("yolo26n.pt")

# Train with custom trainer
results = model.train(trainer=MyCustomTrainer, data="coco8.yaml", epochs=3)

这种方法允许你在保持 YOLO 界面简洁性的同时,自定义底层的训练过程,以满足你的特定需求。

常见问题

我该如何针对特定任务自定义 Ultralytics YOLO DetectionTrainer?

通过重写方法来适配你的自定义模型和数据加载器,从而针对特定任务自定义 DetectionTrainer。首先继承 DetectionTrainer 并重定义如 get_model 等方法来实现自定义功能。以下是一个示例:

from ultralytics.models.yolo.detect import DetectionTrainer

class CustomTrainer(DetectionTrainer):
    def get_model(self, cfg=None, weights=None, verbose=True):
        """Loads a custom detection model given configuration and weight files."""
        ...

trainer = CustomTrainer(overrides={...})
trainer.train()
trained_model = trainer.best  # Get the best model

如需进一步自定义(例如更改 损失函数 或添加 回调函数),请参考 回调指南

Ultralytics YOLO 中 BaseTrainer 的关键组件有哪些?

BaseTrainer 是训练流程的基础,可通过重写其通用方法来适应各种任务。其关键组件包括:

  • get_model(cfg, weights):构建待训练的模型。
  • get_dataloader():构建数据加载器。
  • preprocess_batch():在模型前向传播前处理批次数据。
  • set_model_attributes():根据数据集信息设置模型属性。
  • get_validator():返回用于模型评估的验证器。

有关自定义和源代码的更多详细信息,请参阅 BaseTrainer 参考

我该如何向 Ultralytics YOLO DetectionTrainer 添加回调函数?

DetectionTrainer 中添加回调函数以监控和修改训练过程。以下是如何添加一个在每个训练 epoch 后记录模型权重的回调函数的示例:

from ultralytics.models.yolo.detect import DetectionTrainer

# Callback to upload model weights
def log_model(trainer):
    """Logs the path of the last model weight used by the trainer."""
    last_weight_path = trainer.last
    print(last_weight_path)

trainer = DetectionTrainer(overrides={...})
trainer.add_callback("on_train_epoch_end", log_model)  # Adds to existing callbacks
trainer.train()

有关回调事件和入口点的更多详细信息,请参考 回调指南

为什么我应该使用 Ultralytics YOLO 进行模型训练?

Ultralytics YOLO 为强大的引擎执行器提供了高层抽象,使其成为快速开发和自定义的理想选择。其主要优势包括:

  • 易用性:命令行和 Python 界面都简化了复杂任务。
  • 性能:针对实时 目标检测 和各种视觉 AI 应用进行了优化。
  • 自定义:可轻松扩展以支持自定义模型、损失函数 和数据加载器。
  • 模块化:组件可以独立修改,而不影响整个流水线。
  • 集成:可与机器学习生态系统中的常用框架和工具无缝协作。

浏览主要的 Ultralytics YOLO 页面,了解更多关于 YOLO 的功能。

我可以使用 Ultralytics YOLO DetectionTrainer 来处理非标准模型吗?

是的,DetectionTrainer 具有高度的灵活性,可以针对非标准模型进行自定义。继承 DetectionTrainer 并重载方法以支持你的特定模型需求。这是一个简单的示例:

from ultralytics.models.yolo.detect import DetectionTrainer

class CustomDetectionTrainer(DetectionTrainer):
    def get_model(self, cfg=None, weights=None, verbose=True):
        """Loads a custom detection model."""
        ...

trainer = CustomDetectionTrainer(overrides={...})
trainer.train()

有关详细说明和示例,请查看 DetectionTrainer 参考

评论