高级自定义

Ultralytics YOLO 的命令行界面和 Python 接口都是构建在基础引擎执行器之上的高层抽象。本指南重点介绍 Trainer 引擎,并说明如何根据你的特定需求对其进行自定义。



Watch: Mastering Ultralytics YOLO: Advanced Customization
提示

有关常见训练器自定义的实用示例——如自定义指标、类加权损失、模型保存、主干冻结和分层学习率——请参阅 自定义训练器 指南。

BaseTrainer

BaseTrainer 类提供了一个适用于各种任务的通用训练例程。通过重写特定的函数或操作,同时遵循要求的格式,即可对其进行自定义。例如,通过重写以下函数来集成你自己的自定义模型和数据加载器:

  • get_model(cfg, weights):构建要训练的模型。
  • get_dataloader():构建数据加载器。

有关更多详细信息和源代码,请参阅 BaseTrainer 参考文档

DetectionTrainer

以下是如何使用和自定义 Ultralytics YOLO DetectionTrainer 的方法:

from ultralytics.models.yolo.detect import DetectionTrainer

trainer = DetectionTrainer(overrides={...})
trainer.train()
trained_model = trainer.best  # Get the best model

自定义 DetectionTrainer

要训练一个未直接支持的自定义检测模型,请重载现有的 get_model 功能:

from ultralytics.models.yolo.detect import DetectionTrainer

class CustomTrainer(DetectionTrainer):
    def get_model(self, cfg, weights):
        """Loads a custom detection model given configuration and weight files."""
        ...

trainer = CustomTrainer(overrides={...})
trainer.train()

通过修改 损失函数 或添加 回调函数 以便每 10 个 epoch 将模型上传到 Google Drive,来进一步自定义训练器。这是一个示例:

from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.nn.tasks import DetectionModel

class MyCustomModel(DetectionModel):
    def init_criterion(self):
        """Initializes the loss function and adds a callback for uploading the model to Google Drive every 10 epochs."""
        ...

class CustomTrainer(DetectionTrainer):
    def get_model(self, cfg, weights):
        """Returns a customized detection model instance configured with specified config and weights."""
        return MyCustomModel(...)

# Callback to upload model weights
def log_model(trainer):
    """Logs the path of the last model weight used by the trainer."""
    last_weight_path = trainer.last
    print(last_weight_path)

trainer = CustomTrainer(overrides={...})
trainer.add_callback("on_train_epoch_end", log_model)  # Adds to existing callbacks
trainer.train()

有关回调触发事件和入口点的更多信息,请参阅 回调指南

其他引擎组件

你可以类似地自定义 ValidatorsPredictors 等其他组件。有关更多信息,请参考 ValidatorsPredictors 的文档。

将 YOLO 与自定义训练器结合使用

YOLO 模型类为训练器类提供了一个高层包装器。你可以利用这种架构在机器学习工作流中获得更大的灵活性:

from ultralytics import YOLO
from ultralytics.models.yolo.detect import DetectionTrainer

# Create a custom trainer
class MyCustomTrainer(DetectionTrainer):
    def get_model(self, cfg, weights):
        """Custom code implementation."""
        ...

# Initialize YOLO model
model = YOLO("yolo26n.pt")

# Train with custom trainer
results = model.train(trainer=MyCustomTrainer, data="coco8.yaml", epochs=3)

这种方法允许你在保持 YOLO 接口简洁性的同时,自定义底层的训练过程,以满足你的特定需求。

常见问题 (FAQ)

如何为特定任务自定义 Ultralytics YOLO DetectionTrainer?

通过重写 DetectionTrainer 的方法以适配你的自定义模型和数据加载器,从而针对特定任务进行自定义。从继承 DetectionTrainer 开始,并重新定义诸如 get_model 等方法来实现自定义功能。这是一个示例:

from ultralytics.models.yolo.detect import DetectionTrainer

class CustomTrainer(DetectionTrainer):
    def get_model(self, cfg, weights):
        """Loads a custom detection model given configuration and weight files."""
        ...

trainer = CustomTrainer(overrides={...})
trainer.train()
trained_model = trainer.best  # Get the best model

如需进一步自定义,例如更改 损失函数 或添加 回调函数,请参考 回调指南

Ultralytics YOLO 中 BaseTrainer 的关键组件是什么?

BaseTrainer 是训练例程的基础,通过重写其通用方法,可以为各种任务进行自定义。关键组件包括:

  • get_model(cfg, weights):构建要训练的模型。
  • get_dataloader():构建数据加载器。
  • preprocess_batch():处理模型前向传播之前的批次预处理。
  • set_model_attributes():根据数据集信息设置模型属性。
  • get_validator():返回一个用于模型评估的验证器。

有关自定义和源代码的更多详细信息,请参阅 BaseTrainer 参考文档

如何向 Ultralytics YOLO DetectionTrainer 添加回调函数?

添加回调函数以监控和修改 DetectionTrainer 中的训练过程。以下是如何添加回调函数,以便在每个训练 epoch 后记录模型权重的示例:

from ultralytics.models.yolo.detect import DetectionTrainer

# Callback to upload model weights
def log_model(trainer):
    """Logs the path of the last model weight used by the trainer."""
    last_weight_path = trainer.last
    print(last_weight_path)

trainer = DetectionTrainer(overrides={...})
trainer.add_callback("on_train_epoch_end", log_model)  # Adds to existing callbacks
trainer.train()

有关回调事件和入口点的更多详细信息,请参考 回调指南

为什么要使用 Ultralytics YOLO 进行模型训练?

Ultralytics YOLO 对强大的引擎执行器提供了高层抽象,使其非常适合快速开发和自定义。主要优势包括:

  • 易用性:命令行界面和 Python 接口都简化了复杂的任务。
  • 性能:针对实时 目标检测 和各种视觉 AI 应用进行了优化。
  • 自定义:可轻松扩展以支持自定义模型、损失函数 和数据加载器。
  • 模块化:组件可以独立修改,不会影响整个管道。
  • 集成:与 ML 生态系统中的主流框架和工具无缝协作。

访问主要的 Ultralytics YOLO 页面,了解更多关于 YOLO 功能的信息。

我可以使用 Ultralytics YOLO DetectionTrainer 处理非标准模型吗?

可以,DetectionTrainer 具有高度的灵活性,可以针对非标准模型进行自定义。继承 DetectionTrainer 并重载方法以支持你特定模型的需求。这是一个简单的示例:

from ultralytics.models.yolo.detect import DetectionTrainer

class CustomDetectionTrainer(DetectionTrainer):
    def get_model(self, cfg, weights):
        """Loads a custom detection model."""
        ...

trainer = CustomDetectionTrainer(overrides={...})
trainer.train()

如需详细的说明和示例,请查看 DetectionTrainer 参考文档

评论