跳转至内容

高级自定义

Ultralytics YOLO 命令行和 python 接口都是构建在基础引擎执行器之上的高级抽象。本指南侧重于 Trainer 引擎,解释如何根据您的特定需求进行定制。



观看: 掌握 Ultralytics YOLO:高级自定义

BaseTrainer

字段 BaseTrainer 类提供了一个通用的训练程序,适用于各种任务。通过覆盖特定的函数或操作来定制它,同时遵守所需的格式。例如,通过覆盖以下函数来集成您自己的自定义模型和数据加载器:

  • get_model(cfg, weights): 构建要训练的模型。
  • get_dataloader(): 构建 dataloader。

有关更多详细信息和源代码,请参见 BaseTrainer 参考.

DetectionTrainer

以下是如何使用和自定义 Ultralytics YOLO DetectionTrainer:

from ultralytics.models.yolo.detect import DetectionTrainer

trainer = DetectionTrainer(overrides={...})
trainer.train()
trained_model = trainer.best  # Get the best model

自定义 DetectionTrainer

要训练未直接支持的自定义检测模型,请重载现有的 get_model 功能来实现:

from ultralytics.models.yolo.detect import DetectionTrainer


class CustomTrainer(DetectionTrainer):
    def get_model(self, cfg, weights):
        """Loads a custom detection model given configuration and weight files."""
        ...


trainer = CustomTrainer(overrides={...})
trainer.train()

通过修改损失函数或添加回调以每 10 个epochs将模型上传到 Google Drive,从而进一步自定义训练器。 这是一个例子:

from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.nn.tasks import DetectionModel


class MyCustomModel(DetectionModel):
    def init_criterion(self):
        """Initializes the loss function and adds a callback for uploading the model to Google Drive every 10 epochs."""
        ...


class CustomTrainer(DetectionTrainer):
    def get_model(self, cfg, weights):
        """Returns a customized detection model instance configured with specified config and weights."""
        return MyCustomModel(...)


# Callback to upload model weights
def log_model(trainer):
    """Logs the path of the last model weight used by the trainer."""
    last_weight_path = trainer.last
    print(last_weight_path)


trainer = CustomTrainer(overrides={...})
trainer.add_callback("on_train_epoch_end", log_model)  # Adds to existing callbacks
trainer.train()

有关回调触发事件和入口点的更多信息,请参阅回调指南

其他引擎组件

自定义其他组件,例如 ValidatorsPredictors 同样。有关更多信息,请参阅以下文档: 验证器预测器.

将 YOLO 与自定义训练器结合使用

字段 YOLO model 类为 Trainer 类提供了一个高级封装器。您可以利用此架构在机器学习工作流程中获得更大的灵活性:

from ultralytics import YOLO
from ultralytics.models.yolo.detect import DetectionTrainer


# Create a custom trainer
class MyCustomTrainer(DetectionTrainer):
    def get_model(self, cfg, weights):
        """Custom code implementation."""
        ...


# Initialize YOLO model
model = YOLO("yolo11n.pt")

# Train with custom trainer
results = model.train(trainer=MyCustomTrainer, data="coco8.yaml", epochs=3)

这种方法允许您保持 YOLO 接口的简单性,同时自定义底层训练过程以满足您的特定需求。

常见问题

如何为特定任务自定义 Ultralytics YOLO DetectionTrainer?

自定义 DetectionTrainer 针对特定任务,通过重写其方法来适应您的自定义模型和数据加载器。首先从以下继承: DetectionTrainer 并重新定义方法,例如 get_model 以实现自定义功能。这是一个例子:

from ultralytics.models.yolo.detect import DetectionTrainer


class CustomTrainer(DetectionTrainer):
    def get_model(self, cfg, weights):
        """Loads a custom detection model given configuration and weight files."""
        ...


trainer = CustomTrainer(overrides={...})
trainer.train()
trained_model = trainer.best  # Get the best model

如需进一步的自定义,例如更改损失函数或添加回调,请参阅回调指南

Ultralytics YOLO 中 BaseTrainer 的关键组件是什么?

字段 BaseTrainer 是训练程序的基础,通过覆盖其通用方法,可以针对各种任务进行定制。主要组件包括:

  • get_model(cfg, weights): 构建要训练的模型。
  • get_dataloader(): 构建 dataloader。
  • preprocess_batch(): 在模型正向传递之前处理批量预处理。
  • set_model_attributes(): 根据数据集信息设置模型属性。
  • get_validator(): 返回用于模型评估的验证器。

有关自定义和源代码的更多详细信息,请参见 BaseTrainer 参考.

如何向 Ultralytics YOLO DetectionTrainer 添加回调?

在以下位置添加回调以监控和修改训练过程 DetectionTrainer。以下是如何添加回调以在每次训练后记录模型权重: epoch:

from ultralytics.models.yolo.detect import DetectionTrainer


# Callback to upload model weights
def log_model(trainer):
    """Logs the path of the last model weight used by the trainer."""
    last_weight_path = trainer.last
    print(last_weight_path)


trainer = DetectionTrainer(overrides={...})
trainer.add_callback("on_train_epoch_end", log_model)  # Adds to existing callbacks
trainer.train()

有关回调事件和入口点的更多详细信息,请参阅回调指南

为什么我应该使用 Ultralytics YOLO 进行模型训练?

Ultralytics YOLO 在强大的引擎执行器上提供了一个高级抽象,使其成为快速开发和自定义的理想选择。主要优势包括:

  • 易于使用:命令行和 python 接口都简化了复杂的任务。
  • 性能:针对实时对象检测和各种视觉 AI 应用进行了优化。
  • 自定义:易于扩展,适用于自定义模型、损失函数和数据加载器。
  • 模块化: 组件可以独立修改,而不会影响整个流程。
  • 集成: 与 ML 生态系统中流行的框架和工具无缝协作。

通过浏览主要的 Ultralytics YOLO 页面,了解更多关于 YOLO 功能的信息。

我可以将 Ultralytics YOLO DetectionTrainer 用于非标准模型吗?

是的, DetectionTrainer 对于非标准模型,具有高度的灵活性和可定制性。继承自 DetectionTrainer 并重载方法以支持您的特定模型的需求。 这是一个简单的例子:

from ultralytics.models.yolo.detect import DetectionTrainer


class CustomDetectionTrainer(DetectionTrainer):
    def get_model(self, cfg, weights):
        """Loads a custom detection model."""
        ...


trainer = CustomDetectionTrainer(overrides={...})
trainer.train()

有关全面的说明和示例,请查看 DetectionTrainer 参考.



📅 创建于 1 年前 ✏️ 更新于 5 个月前

评论