跳至内容

高级定制

Ultralytics YOLO 命令行接口和Python 接口都是建立在基础引擎执行器之上的高级抽象。本指南重点介绍 Trainer 引擎,并解释如何根据您的具体需求对其进行定制。



观看: 掌握Ultralytics YOLO :高级定制

基础培训师

"(《世界人权宣言》) BaseTrainer 类为各种任务提供了通用的训练例程。通过重载特定的函数或操作来定制它,同时遵守所需的格式。例如,通过重载这些函数,整合您自己的自定义模型和数据加载器:

  • get_model(cfg, weights):建立要训练的模型。
  • get_dataloader():构建数据加载器。

有关详细信息和源代码,请参见 BaseTrainer 参考资料.

探测训练器

下面介绍如何使用和定制Ultralytics YOLO DetectionTrainer:

from ultralytics.models.yolo.detect import DetectionTrainer

trainer = DetectionTrainer(overrides={...})
trainer.train()
trained_model = trainer.best  # Get the best model

自定义检测训练器

要训练不直接支持的自定义检测模型,请重载现有的 get_model 功能性:

from ultralytics.models.yolo.detect import DetectionTrainer


class CustomTrainer(DetectionTrainer):
    def get_model(self, cfg, weights):
        """Loads a custom detection model given configuration and weight files."""
        ...


trainer = CustomTrainer(overrides={...})
trainer.train()

通过修改损失函数或添加回调,每 10 个周期将模型上传到Google Drive,进一步定制训练器。下面是一个例子:

from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.nn.tasks import DetectionModel


class MyCustomModel(DetectionModel):
    def init_criterion(self):
        """Initializes the loss function and adds a callback for uploading the model to Google Drive every 10 epochs."""
        ...


class CustomTrainer(DetectionTrainer):
    def get_model(self, cfg, weights):
        """Returns a customized detection model instance configured with specified config and weights."""
        return MyCustomModel(...)


# Callback to upload model weights
def log_model(trainer):
    """Logs the path of the last model weight used by the trainer."""
    last_weight_path = trainer.last
    print(last_weight_path)


trainer = CustomTrainer(overrides={...})
trainer.add_callback("on_train_epoch_end", log_model)  # Adds to existing callbacks
trainer.train()

有关回调触发事件和入口点的更多信息,请参阅回调指南

其他发动机部件

自定义其他组件,如 ValidatorsPredictors 类似地。更多信息,请参阅 验证器预测因素.

与定制培训师一起使用YOLO

"(《世界人权宣言》) YOLO 模型类为 Trainer 类提供了一个高级封装。您可以利用这一架构,在机器学习工作流中获得更大的灵活性:

from ultralytics import YOLO
from ultralytics.models.yolo.detect import DetectionTrainer


# Create a custom trainer
class MyCustomTrainer(DetectionTrainer):
    def get_model(self, cfg, weights):
        """Custom code implementation."""
        ...


# Initialize YOLO model
model = YOLO("yolo11n.pt")

# Train with custom trainer
results = model.train(trainer=MyCustomTrainer, data="coco8.yaml", epochs=3)

这种方法使您既能保持YOLO 界面的简洁性,又能根据您的具体要求定制基础培训流程。

常见问题

如何为特定任务定制Ultralytics YOLO DetectionTrainer?

自定义 DetectionTrainer 方法,以适应您的自定义模型和数据加载器。首先从 DetectionTrainer 等方法,并重新定义 get_model 来实现自定义功能。下面是一个例子:

from ultralytics.models.yolo.detect import DetectionTrainer


class CustomTrainer(DetectionTrainer):
    def get_model(self, cfg, weights):
        """Loads a custom detection model given configuration and weight files."""
        ...


trainer = CustomTrainer(overrides={...})
trainer.train()
trained_model = trainer.best  # Get the best model

如需进一步定制,如更改损失函数或添加回调,请参阅回调 指南

Ultralytics YOLO 中的 BaseTrainer 有哪些关键组件?

"(《世界人权宣言》) BaseTrainer 作为训练程序的基础,可通过覆盖其通用方法为各种任务进行定制。主要组件包括

  • get_model(cfg, weights):建立要训练的模型。
  • get_dataloader():构建数据加载器。
  • preprocess_batch():处理模型前向传递前的批量预处理。
  • set_model_attributes():根据数据集信息设置模型属性。
  • get_validator():返回模型评估的验证器。

有关定制和源代码的更多详情,请参阅 BaseTrainer 参考资料.

如何向Ultralytics YOLO DetectionTrainer 添加回调?

中添加回调,以监控和修改训练过程。 DetectionTrainer.以下是如何在每次训练后添加回调以记录模型权重的方法 纪元:

from ultralytics.models.yolo.detect import DetectionTrainer


# Callback to upload model weights
def log_model(trainer):
    """Logs the path of the last model weight used by the trainer."""
    last_weight_path = trainer.last
    print(last_weight_path)


trainer = DetectionTrainer(overrides={...})
trainer.add_callback("on_train_epoch_end", log_model)  # Adds to existing callbacks
trainer.train()

有关回调事件和入口点的更多详情,请参阅《回调指南》

为什么要使用Ultralytics YOLO 进行模型训练?

Ultralytics YOLO 为强大的引擎执行器提供了高级抽象,是快速开发和定制的理想选择。主要优势包括

  • 易用性:命令行和Python 界面均可简化复杂任务。
  • 性能针对实时物体检测和各种视觉人工智能应用进行了优化。
  • 自定义:可轻松扩展自定义模型、损失函数和数据加载器。
  • 模块化:可独立修改组件,而不会影响整个管道。
  • 集成:与 ML 生态系统中流行的框架和工具无缝协作。

了解有关YOLO 功能的更多信息,请浏览主页面 Ultralytics YOLO页面。

我可以将Ultralytics YOLO DetectionTrainer 用于非标准模型吗?

是的 DetectionTrainer 高度灵活,可为非标准模型定制。继承自 DetectionTrainer 和重载方法,以支持您的特定模型需求。下面是一个简单的例子:

from ultralytics.models.yolo.detect import DetectionTrainer


class CustomDetectionTrainer(DetectionTrainer):
    def get_model(self, cfg, weights):
        """Loads a custom detection model."""
        ...


trainer = CustomDetectionTrainer(overrides={...})
trainer.train()

有关全面的说明和示例,请查阅 DetectionTrainer 参考资料.

📅创建于 1 年前 ✏️已更新 5 天前

评论