高级自定义
Ultralytics YOLO 命令行和 python 接口都是构建在基础引擎执行器之上的高级抽象。本指南侧重于 Trainer
引擎,解释如何根据您的特定需求进行定制。
观看: 掌握 Ultralytics YOLO:高级自定义
BaseTrainer
字段 BaseTrainer
类提供了一个通用的训练程序,适用于各种任务。通过覆盖特定的函数或操作来定制它,同时遵守所需的格式。例如,通过覆盖以下函数来集成您自己的自定义模型和数据加载器:
get_model(cfg, weights)
: 构建要训练的模型。get_dataloader()
: 构建 dataloader。
有关更多详细信息和源代码,请参见 BaseTrainer
参考.
DetectionTrainer
以下是如何使用和自定义 Ultralytics YOLO DetectionTrainer
:
from ultralytics.models.yolo.detect import DetectionTrainer
trainer = DetectionTrainer(overrides={...})
trainer.train()
trained_model = trainer.best # Get the best model
自定义 DetectionTrainer
要训练未直接支持的自定义检测模型,请重载现有的 get_model
功能来实现:
from ultralytics.models.yolo.detect import DetectionTrainer
class CustomTrainer(DetectionTrainer):
def get_model(self, cfg, weights):
"""Loads a custom detection model given configuration and weight files."""
...
trainer = CustomTrainer(overrides={...})
trainer.train()
通过修改损失函数或添加回调以每 10 个epochs将模型上传到 Google Drive,从而进一步自定义训练器。 这是一个例子:
from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.nn.tasks import DetectionModel
class MyCustomModel(DetectionModel):
def init_criterion(self):
"""Initializes the loss function and adds a callback for uploading the model to Google Drive every 10 epochs."""
...
class CustomTrainer(DetectionTrainer):
def get_model(self, cfg, weights):
"""Returns a customized detection model instance configured with specified config and weights."""
return MyCustomModel(...)
# Callback to upload model weights
def log_model(trainer):
"""Logs the path of the last model weight used by the trainer."""
last_weight_path = trainer.last
print(last_weight_path)
trainer = CustomTrainer(overrides={...})
trainer.add_callback("on_train_epoch_end", log_model) # Adds to existing callbacks
trainer.train()
有关回调触发事件和入口点的更多信息,请参阅回调指南。
其他引擎组件
自定义其他组件,例如 Validators
和 Predictors
同样。有关更多信息,请参阅以下文档: 验证器 和 预测器.
将 YOLO 与自定义训练器结合使用
字段 YOLO
model 类为 Trainer 类提供了一个高级封装器。您可以利用此架构在机器学习工作流程中获得更大的灵活性:
from ultralytics import YOLO
from ultralytics.models.yolo.detect import DetectionTrainer
# Create a custom trainer
class MyCustomTrainer(DetectionTrainer):
def get_model(self, cfg, weights):
"""Custom code implementation."""
...
# Initialize YOLO model
model = YOLO("yolo11n.pt")
# Train with custom trainer
results = model.train(trainer=MyCustomTrainer, data="coco8.yaml", epochs=3)
这种方法允许您保持 YOLO 接口的简单性,同时自定义底层训练过程以满足您的特定需求。
常见问题
如何为特定任务自定义 Ultralytics YOLO DetectionTrainer?
自定义 DetectionTrainer
针对特定任务,通过重写其方法来适应您的自定义模型和数据加载器。首先从以下继承: DetectionTrainer
并重新定义方法,例如 get_model
以实现自定义功能。这是一个例子:
from ultralytics.models.yolo.detect import DetectionTrainer
class CustomTrainer(DetectionTrainer):
def get_model(self, cfg, weights):
"""Loads a custom detection model given configuration and weight files."""
...
trainer = CustomTrainer(overrides={...})
trainer.train()
trained_model = trainer.best # Get the best model
如需进一步的自定义,例如更改损失函数或添加回调,请参阅回调指南。
Ultralytics YOLO 中 BaseTrainer 的关键组件是什么?
字段 BaseTrainer
是训练程序的基础,通过覆盖其通用方法,可以针对各种任务进行定制。主要组件包括:
get_model(cfg, weights)
: 构建要训练的模型。get_dataloader()
: 构建 dataloader。preprocess_batch()
: 在模型正向传递之前处理批量预处理。set_model_attributes()
: 根据数据集信息设置模型属性。get_validator()
: 返回用于模型评估的验证器。
有关自定义和源代码的更多详细信息,请参见 BaseTrainer
参考.
如何向 Ultralytics YOLO DetectionTrainer 添加回调?
在以下位置添加回调以监控和修改训练过程 DetectionTrainer
。以下是如何添加回调以在每次训练后记录模型权重: epoch:
from ultralytics.models.yolo.detect import DetectionTrainer
# Callback to upload model weights
def log_model(trainer):
"""Logs the path of the last model weight used by the trainer."""
last_weight_path = trainer.last
print(last_weight_path)
trainer = DetectionTrainer(overrides={...})
trainer.add_callback("on_train_epoch_end", log_model) # Adds to existing callbacks
trainer.train()
有关回调事件和入口点的更多详细信息,请参阅回调指南。
为什么我应该使用 Ultralytics YOLO 进行模型训练?
Ultralytics YOLO 在强大的引擎执行器上提供了一个高级抽象,使其成为快速开发和自定义的理想选择。主要优势包括:
- 易于使用:命令行和 python 接口都简化了复杂的任务。
- 性能:针对实时对象检测和各种视觉 AI 应用进行了优化。
- 自定义:易于扩展,适用于自定义模型、损失函数和数据加载器。
- 模块化: 组件可以独立修改,而不会影响整个流程。
- 集成: 与 ML 生态系统中流行的框架和工具无缝协作。
通过浏览主要的 Ultralytics YOLO 页面,了解更多关于 YOLO 功能的信息。
我可以将 Ultralytics YOLO DetectionTrainer 用于非标准模型吗?
是的, DetectionTrainer
对于非标准模型,具有高度的灵活性和可定制性。继承自 DetectionTrainer
并重载方法以支持您的特定模型的需求。 这是一个简单的例子:
from ultralytics.models.yolo.detect import DetectionTrainer
class CustomDetectionTrainer(DetectionTrainer):
def get_model(self, cfg, weights):
"""Loads a custom detection model."""
...
trainer = CustomDetectionTrainer(overrides={...})
trainer.train()
有关全面的说明和示例,请查看 DetectionTrainer
参考.