高级定制
Ultralytics YOLO 命令行接口和Python 接口都是建立在基础引擎执行器之上的高级抽象。本指南重点介绍 Trainer
引擎,并解释如何根据您的具体需求对其进行定制。
观看: 掌握Ultralytics YOLO :高级定制
基础培训师
"(《世界人权宣言》) BaseTrainer
类为各种任务提供了通用的训练例程。通过重载特定的函数或操作来定制它,同时遵守所需的格式。例如,通过重载这些函数,整合您自己的自定义模型和数据加载器:
get_model(cfg, weights)
:建立要训练的模型。get_dataloader()
:构建数据加载器。
有关详细信息和源代码,请参见 BaseTrainer
参考资料.
探测训练器
下面介绍如何使用和定制Ultralytics YOLO DetectionTrainer
:
from ultralytics.models.yolo.detect import DetectionTrainer
trainer = DetectionTrainer(overrides={...})
trainer.train()
trained_model = trainer.best # Get the best model
自定义检测训练器
要训练不直接支持的自定义检测模型,请重载现有的 get_model
功能性:
from ultralytics.models.yolo.detect import DetectionTrainer
class CustomTrainer(DetectionTrainer):
def get_model(self, cfg, weights):
"""Loads a custom detection model given configuration and weight files."""
...
trainer = CustomTrainer(overrides={...})
trainer.train()
通过修改损失函数或添加回调,每 10 个周期将模型上传到Google Drive,进一步定制训练器。下面是一个例子:
from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.nn.tasks import DetectionModel
class MyCustomModel(DetectionModel):
def init_criterion(self):
"""Initializes the loss function and adds a callback for uploading the model to Google Drive every 10 epochs."""
...
class CustomTrainer(DetectionTrainer):
def get_model(self, cfg, weights):
"""Returns a customized detection model instance configured with specified config and weights."""
return MyCustomModel(...)
# Callback to upload model weights
def log_model(trainer):
"""Logs the path of the last model weight used by the trainer."""
last_weight_path = trainer.last
print(last_weight_path)
trainer = CustomTrainer(overrides={...})
trainer.add_callback("on_train_epoch_end", log_model) # Adds to existing callbacks
trainer.train()
有关回调触发事件和入口点的更多信息,请参阅回调指南。
其他发动机部件
自定义其他组件,如 Validators
和 Predictors
类似地。更多信息,请参阅 验证器 和 预测因素.
与定制培训师一起使用YOLO
"(《世界人权宣言》) YOLO
模型类为 Trainer 类提供了一个高级封装。您可以利用这一架构,在机器学习工作流中获得更大的灵活性:
from ultralytics import YOLO
from ultralytics.models.yolo.detect import DetectionTrainer
# Create a custom trainer
class MyCustomTrainer(DetectionTrainer):
def get_model(self, cfg, weights):
"""Custom code implementation."""
...
# Initialize YOLO model
model = YOLO("yolo11n.pt")
# Train with custom trainer
results = model.train(trainer=MyCustomTrainer, data="coco8.yaml", epochs=3)
这种方法使您既能保持YOLO 界面的简洁性,又能根据您的具体要求定制基础培训流程。
常见问题
如何为特定任务定制Ultralytics YOLO DetectionTrainer?
自定义 DetectionTrainer
方法,以适应您的自定义模型和数据加载器。首先从 DetectionTrainer
等方法,并重新定义 get_model
来实现自定义功能。下面是一个例子:
from ultralytics.models.yolo.detect import DetectionTrainer
class CustomTrainer(DetectionTrainer):
def get_model(self, cfg, weights):
"""Loads a custom detection model given configuration and weight files."""
...
trainer = CustomTrainer(overrides={...})
trainer.train()
trained_model = trainer.best # Get the best model
如需进一步定制,如更改损失函数或添加回调,请参阅回调 指南。
Ultralytics YOLO 中的 BaseTrainer 有哪些关键组件?
"(《世界人权宣言》) BaseTrainer
作为训练程序的基础,可通过覆盖其通用方法为各种任务进行定制。主要组件包括
get_model(cfg, weights)
:建立要训练的模型。get_dataloader()
:构建数据加载器。preprocess_batch()
:处理模型前向传递前的批量预处理。set_model_attributes()
:根据数据集信息设置模型属性。get_validator()
:返回模型评估的验证器。
有关定制和源代码的更多详情,请参阅 BaseTrainer
参考资料.
如何向Ultralytics YOLO DetectionTrainer 添加回调?
中添加回调,以监控和修改训练过程。 DetectionTrainer
.以下是如何在每次训练后添加回调以记录模型权重的方法 纪元:
from ultralytics.models.yolo.detect import DetectionTrainer
# Callback to upload model weights
def log_model(trainer):
"""Logs the path of the last model weight used by the trainer."""
last_weight_path = trainer.last
print(last_weight_path)
trainer = DetectionTrainer(overrides={...})
trainer.add_callback("on_train_epoch_end", log_model) # Adds to existing callbacks
trainer.train()
有关回调事件和入口点的更多详情,请参阅《回调指南》。
为什么要使用Ultralytics YOLO 进行模型训练?
Ultralytics YOLO 为强大的引擎执行器提供了高级抽象,是快速开发和定制的理想选择。主要优势包括
- 易用性:命令行和Python 界面均可简化复杂任务。
- 性能针对实时物体检测和各种视觉人工智能应用进行了优化。
- 自定义:可轻松扩展自定义模型、损失函数和数据加载器。
- 模块化:可独立修改组件,而不会影响整个管道。
- 集成:与 ML 生态系统中流行的框架和工具无缝协作。
了解有关YOLO 功能的更多信息,请浏览主页面 Ultralytics YOLO页面。
我可以将Ultralytics YOLO DetectionTrainer 用于非标准模型吗?
是的 DetectionTrainer
高度灵活,可为非标准模型定制。继承自 DetectionTrainer
和重载方法,以支持您的特定模型需求。下面是一个简单的例子:
from ultralytics.models.yolo.detect import DetectionTrainer
class CustomDetectionTrainer(DetectionTrainer):
def get_model(self, cfg, weights):
"""Loads a custom detection model."""
...
trainer = CustomDetectionTrainer(overrides={...})
trainer.train()
有关全面的说明和示例,请查阅 DetectionTrainer
参考资料.