回调函数 (Callbacks)

Ultralytics 框架支持回调函数,它们作为 trainvalexportpredict 模式中关键阶段的入口点。每个回调函数都会接收一个 TrainerValidatorPredictor 对象,具体取决于操作类型。这些对象的所有属性都在文档的 参考章节 中有详细说明。



Watch: How to use Ultralytics Callbacks | Predict, Train, Validate and Export Callbacks | Ultralytics YOLO🚀

示例

在预测时返回额外信息

在此示例中,我们演示了如何随每个结果对象一起返回原始帧:

from ultralytics import YOLO

def on_predict_batch_end(predictor):
    """Combine prediction results with corresponding frames."""
    _, image, _, _ = predictor.batch

    # Ensure that image is a list
    image = image if isinstance(image, list) else [image]

    # Combine the prediction results with the corresponding frames
    predictor.results = zip(predictor.results, image)

# Create a YOLO model instance
model = YOLO("yolo26n.pt")

# Add the custom callback to the model
model.add_callback("on_predict_batch_end", on_predict_batch_end)

# Iterate through the results and frames
for result, frame in model.predict():  # or model.track()
    pass

使用 on_model_save 回调访问模型指标

此示例展示了如何在保存检查点后使用 on_model_save 回调来获取训练详情,例如 best_fitness 分数、total_loss 以及其他指标。

from ultralytics import YOLO

# Load a YOLO model
model = YOLO("yolo26n.pt")

def print_checkpoint_metrics(trainer):
    """Print trainer metrics and loss details after each checkpoint is saved."""
    print(
        f"Model details\n"
        f"Best fitness: {trainer.best_fitness}, "
        f"Loss names: {trainer.loss_names}, "  # List of loss names
        f"Metrics: {trainer.metrics}, "
        f"Total loss: {trainer.tloss}"  # Total loss value
    )

if __name__ == "__main__":
    # Add on_model_save callback.
    model.add_callback("on_model_save", print_checkpoint_metrics)

    # Run model training on custom dataset.
    results = model.train(data="coco8.yaml", epochs=3)

所有回调函数

以下是所有受支持的回调函数。有关更多详细信息,请参阅回调 源代码

训练器 (Trainer) 回调

回调函数描述
on_pretrain_routine_start在预训练例程开始时、数据加载和模型设置之前触发。
on_pretrain_routine_end在预训练例程结束时、数据加载和模型设置完成后触发。
on_train_start在训练开始时、第一个 epoch 开始之前触发。
on_train_epoch_start在每个训练 epoch 开始时、批次迭代之前触发。
on_train_batch_start在每个训练批次开始时、前向传播之前触发。
optimizer_step在优化器步骤期间触发。保留用于自定义集成;默认训练循环不会调用它。
on_before_zero_grad在梯度清零之前触发。保留用于自定义集成;默认训练循环不会调用它。
on_train_batch_end在每个训练批次结束时、反向传播之后触发。由于梯度累积,优化器步骤可能会被推迟。
on_train_epoch_end在每个训练 epoch 结束时、所有批次处理完但 验证之前触发。此时验证指标和 fitness 可能尚不可用。
on_model_save在保存模型检查点时、验证之后触发。
on_fit_epoch_end在每个拟合 epoch (train + val) 结束时、验证和任何检查点保存 之后 触发。此时验证指标可用,对于每个 epoch 的训练调用,fitness 也是可用的。此回调也会在最终最佳模型评估期间调用,此时不会保存检查点,且可能不存在 fitness。
on_train_end在训练过程结束时、最佳模型最终评估之后触发。
on_params_update在模型参数更新时触发。保留用于自定义集成;默认训练循环不会调用它。
teardown在训练过程进行清理时触发。

验证器 (Validator) 回调

回调函数描述
on_val_start在验证开始时触发。
on_val_batch_start在每个验证批次开始时触发。
on_val_batch_end在每个验证批次结束时触发。
on_val_end在验证结束时触发。

预测器 (Predictor) 回调

回调函数描述
on_predict_start在预测过程开始时触发。
on_predict_batch_start在每个预测批次开始时触发。
on_predict_postprocess_end在预测后处理结束时触发。
on_predict_batch_end在每个预测批次结束时触发。
on_predict_end在预测过程结束时触发。

导出器 (Exporter) 回调

回调函数描述
on_export_start在导出过程开始时触发。
on_export_end在导出过程结束时触发。

常见问题 (FAQ)

什么是 Ultralytics 回调函数,我该如何使用它们?

Ultralytics 回调函数是专门的入口点,在模型操作的关键阶段(如训练、验证、导出和预测)触发。这些回调函数可以在流程中的特定点启用自定义功能,从而允许对工作流程进行增强和修改。每个回调函数都会接收一个 TrainerValidatorPredictor 对象,具体取决于操作类型。有关这些对象的详细属性,请参阅 参考章节

要使用回调函数,请定义一个函数并使用 model.add_callback() 方法将其添加到模型中。以下是一个在预测期间返回额外信息的示例:

from ultralytics import YOLO

def on_predict_batch_end(predictor):
    """Handle prediction batch end by combining results with corresponding frames; modifies predictor results."""
    _, image, _, _ = predictor.batch
    image = image if isinstance(image, list) else [image]
    predictor.results = zip(predictor.results, image)

model = YOLO("yolo26n.pt")
model.add_callback("on_predict_batch_end", on_predict_batch_end)
for result, frame in model.predict():
    pass

如何使用回调函数自定义 Ultralytics 训练例程?

通过在训练过程的特定阶段注入逻辑,来自定义你的 Ultralytics 训练例程。Ultralytics YOLO 提供了各种训练回调函数,例如 on_train_starton_train_endon_train_batch_end,它们允许你添加自定义指标、处理流程或日志记录。

以下是如何在使用回调函数冻结层时冻结 BatchNorm 统计信息的方法:

from ultralytics import YOLO

# Add a callback to put the frozen layers in eval mode to prevent BN values from changing
def put_in_eval_mode(trainer):
    n_layers = trainer.args.freeze
    if not isinstance(n_layers, int):
        return

    for i, (name, module) in enumerate(trainer.model.named_modules()):
        if name.endswith("bn") and int(name.split(".")[1]) < n_layers:
            module.eval()
            module.track_running_stats = False

model = YOLO("yolo26n.pt")
model.add_callback("on_train_epoch_start", put_in_eval_mode)
model.train(data="coco.yaml", epochs=10)

有关有效使用训练回调的更多详细信息,请参阅 训练指南

为什么要在 Ultralytics YOLO 的验证过程中使用回调函数?

在 Ultralytics YOLO 的验证过程中使用回调函数,通过启用自定义处理、日志记录或指标计算来增强模型评估。诸如 on_val_starton_val_batch_endon_val_end 之类的回调函数提供了注入自定义逻辑的入口点,确保了详细且全面的验证过程。

例如,如果要绘制所有验证批次而不仅仅是前三个:

import inspect

from ultralytics import YOLO

def plot_samples(validator):
    frame = inspect.currentframe().f_back.f_back
    v = frame.f_locals
    validator.plot_val_samples(v["batch"], v["batch_i"])
    validator.plot_predictions(v["batch"], v["preds"], v["batch_i"])

model = YOLO("yolo26n.pt")
model.add_callback("on_val_batch_end", plot_samples)
model.val(data="coco.yaml")

有关将回调函数整合到验证过程中的更多见解,请参阅 验证指南

如何为 Ultralytics YOLO 中的预测模式附加自定义回调函数?

要为 Ultralytics YOLO 中的预测模式附加自定义回调函数,请定义一个回调函数并将其注册到预测流程中。常见的预测回调函数包括 on_predict_starton_predict_batch_endon_predict_end。它们允许修改预测输出并集成额外的功能,例如数据日志记录或结果转换。

以下是一个示例,其中自定义回调函数根据是否存在特定类的对象来保存预测结果:

from ultralytics import YOLO

model = YOLO("yolo26n.pt")

class_id = 2

def save_on_object(predictor):
    r = predictor.results[0]
    if class_id in r.boxes.cls:
        predictor.args.save = True
    else:
        predictor.args.save = False

model.add_callback("on_predict_postprocess_end", save_on_object)
results = model("pedestrians.mp4", stream=True, save=True)

for results in results:
    pass

有关更全面的使用说明,请参阅 预测指南,其中包含详细说明和额外的自定义选项。

使用 Ultralytics YOLO 中的回调函数有哪些实际示例?

Ultralytics YOLO 支持各种回调函数的实际实现,以增强和自定义训练、验证和预测等不同阶段。一些实际示例包括:

  • 记录自定义指标:在不同阶段记录额外的指标,例如在训练或验证 epochs 结束时。
  • 数据增强:在预测或训练批次期间实施自定义数据转换或增强。
  • 中间结果:保存中间结果(如预测或帧),以便进行进一步的分析或可视化。

示例:在预测期间使用 on_predict_batch_end 将帧与预测结果合并:

from ultralytics import YOLO

def on_predict_batch_end(predictor):
    """Combine prediction results with frames."""
    _, image, _, _ = predictor.batch
    image = image if isinstance(image, list) else [image]
    predictor.results = zip(predictor.results, image)

model = YOLO("yolo26n.pt")
model.add_callback("on_predict_batch_end", on_predict_batch_end)
for result, frame in model.predict():
    pass

浏览 回调源代码 以获取更多选项和示例。

评论