高级自定义
Ultralytics YOLO 的命令行界面和 Python 界面都是构建在基础引擎执行器之上的高层抽象。本指南重点介绍 Trainer 引擎,讲解如何根据你的特定需求进行自定义。
Watch: Mastering Ultralytics YOLO: Advanced Customization
有关常见训练器自定义(如自定义指标、类加权损失、模型保存、主干冻结和分层学习率)的实际示例,请参阅 自定义训练器 指南。
BaseTrainer
BaseTrainer 类提供了一个通用的训练流程,可适应各种任务。你可以通过重写特定函数或操作来进行自定义,同时需遵循所需的格式。例如,通过重写以下函数来集成你自己的自定义模型和数据加载器:
get_model(cfg, weights):构建待训练的模型。get_dataloader():构建数据加载器。
有关更多详细信息和源代码,请参阅 BaseTrainer 参考。
DetectionTrainer
以下是如何使用和自定义 Ultralytics YOLO DetectionTrainer 的方法:
from ultralytics.models.yolo.detect import DetectionTrainer
trainer = DetectionTrainer(overrides={...})
trainer.train()
trained_model = trainer.best # Get the best model自定义 DetectionTrainer
要训练一个未直接支持的自定义检测模型,请重载现有的 get_model 功能:
from ultralytics.models.yolo.detect import DetectionTrainer
class CustomTrainer(DetectionTrainer):
def get_model(self, cfg=None, weights=None, verbose=True):
"""Loads a custom detection model given configuration and weight files."""
...
trainer = CustomTrainer(overrides={...})
trainer.train()通过修改 损失函数 或添加 回调函数 以每 10 个 epoch 将模型上传到 Google Drive,来进一步自定义训练器。以下是一个示例:
from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.nn.tasks import DetectionModel
class MyCustomModel(DetectionModel):
def init_criterion(self):
"""Initializes the loss function and adds a callback for uploading the model to Google Drive every 10 epochs."""
...
class CustomTrainer(DetectionTrainer):
def get_model(self, cfg=None, weights=None, verbose=True):
"""Returns a customized detection model instance configured with specified config and weights."""
return MyCustomModel(...)
# Callback to upload model weights
def log_model(trainer):
"""Logs the path of the last model weight used by the trainer."""
last_weight_path = trainer.last
print(last_weight_path)
trainer = CustomTrainer(overrides={...})
trainer.add_callback("on_train_epoch_end", log_model) # Adds to existing callbacks
trainer.train()有关回调触发事件和入口点的更多信息,请参阅 回调指南。
其他引擎组件
你可以类似地自定义 Validators 和 Predictors 等其他组件。有关更多信息,请参考 验证器 和 预测器 的文档。
将 YOLO 与自定义训练器结合使用
YOLO 模型类为训练器类提供了一个高层包装器。你可以利用这种架构在机器学习工作流中获得更大的灵活性:
from ultralytics import YOLO
from ultralytics.models.yolo.detect import DetectionTrainer
# Create a custom trainer
class MyCustomTrainer(DetectionTrainer):
def get_model(self, cfg=None, weights=None, verbose=True):
"""Custom code implementation."""
...
# Initialize YOLO model
model = YOLO("yolo26n.pt")
# Train with custom trainer
results = model.train(trainer=MyCustomTrainer, data="coco8.yaml", epochs=3)这种方法允许你在保持 YOLO 界面简洁性的同时,自定义底层的训练过程,以满足你的特定需求。
常见问题
我该如何针对特定任务自定义 Ultralytics YOLO DetectionTrainer?
通过重写方法来适配你的自定义模型和数据加载器,从而针对特定任务自定义 DetectionTrainer。首先继承 DetectionTrainer 并重定义如 get_model 等方法来实现自定义功能。以下是一个示例:
from ultralytics.models.yolo.detect import DetectionTrainer
class CustomTrainer(DetectionTrainer):
def get_model(self, cfg=None, weights=None, verbose=True):
"""Loads a custom detection model given configuration and weight files."""
...
trainer = CustomTrainer(overrides={...})
trainer.train()
trained_model = trainer.best # Get the best model如需进一步自定义(例如更改 损失函数 或添加 回调函数),请参考 回调指南。
Ultralytics YOLO 中 BaseTrainer 的关键组件有哪些?
BaseTrainer 是训练流程的基础,可通过重写其通用方法来适应各种任务。其关键组件包括:
get_model(cfg, weights):构建待训练的模型。get_dataloader():构建数据加载器。preprocess_batch():在模型前向传播前处理批次数据。set_model_attributes():根据数据集信息设置模型属性。get_validator():返回用于模型评估的验证器。
有关自定义和源代码的更多详细信息,请参阅 BaseTrainer 参考。
我该如何向 Ultralytics YOLO DetectionTrainer 添加回调函数?
在 DetectionTrainer 中添加回调函数以监控和修改训练过程。以下是如何添加一个在每个训练 epoch 后记录模型权重的回调函数的示例:
from ultralytics.models.yolo.detect import DetectionTrainer
# Callback to upload model weights
def log_model(trainer):
"""Logs the path of the last model weight used by the trainer."""
last_weight_path = trainer.last
print(last_weight_path)
trainer = DetectionTrainer(overrides={...})
trainer.add_callback("on_train_epoch_end", log_model) # Adds to existing callbacks
trainer.train()有关回调事件和入口点的更多详细信息,请参考 回调指南。
为什么我应该使用 Ultralytics YOLO 进行模型训练?
Ultralytics YOLO 为强大的引擎执行器提供了高层抽象,使其成为快速开发和自定义的理想选择。其主要优势包括:
- 易用性:命令行和 Python 界面都简化了复杂任务。
- 性能:针对实时 目标检测 和各种视觉 AI 应用进行了优化。
- 自定义:可轻松扩展以支持自定义模型、损失函数 和数据加载器。
- 模块化:组件可以独立修改,而不影响整个流水线。
- 集成:可与机器学习生态系统中的常用框架和工具无缝协作。
浏览主要的 Ultralytics YOLO 页面,了解更多关于 YOLO 的功能。
我可以使用 Ultralytics YOLO DetectionTrainer 来处理非标准模型吗?
是的,DetectionTrainer 具有高度的灵活性,可以针对非标准模型进行自定义。继承 DetectionTrainer 并重载方法以支持你的特定模型需求。这是一个简单的示例:
from ultralytics.models.yolo.detect import DetectionTrainer
class CustomDetectionTrainer(DetectionTrainer):
def get_model(self, cfg=None, weights=None, verbose=True):
"""Loads a custom detection model."""
...
trainer = CustomDetectionTrainer(overrides={...})
trainer.train()有关详细说明和示例,请查看 DetectionTrainer 参考。