
Customizing the Trainer

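The Ultralytics training pipeline is built around BaseTrainer and task-specific trainers such as DetectionTrainer. These classes handle the training loop, validation, checkpointing, and logging out of the box. When you need more control, such as tracking custom metrics, adjusting loss weighting, or implementing learning rate schedules, you can subclass a trainer and override specific methods.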

This guide covers seven common customizations:

  1. Logging custom metrics (F1 score) at the end of each epoch
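  2. Adding class weights to handle class imbalance
  3. Saving the best model based on a different metric
  4. Freezing the backbone for the first N epochs, then unfreezing it
  5. Specifying per-layer learning rates
  6. Synchronizing BatchNorm across GPUs in multi-GPU training
  7. Configuring gradient clipping for stability tuning

Prerequisites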

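Before reading this guide, make sure you are familiar with the basics of YOLO model training and with the Advanced Customization page, which covers the BaseTrainer architecture.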

How Custom Trainers Work

The YOLO model class accepts a trainer parameter in the train() method. This allows you to pass your own trainer class that extends the default behavior:

from ultralytics import YOLO
from ultralytics.models.yolo.detect import DetectionTrainer

class CustomTrainer(DetectionTrainer):
    """A custom trainer that extends DetectionTrainer with additional functionality."""

    pass  # Add your customizations here

model = YOLO("yolo26n.pt")
model.train(data="coco8.yaml", epochs=10, trainer=CustomTrainer)

Your custom trainer inherits all functionality from DetectionTrainer, so you only need to override the specific methods you want to customize.

Logging Custom Metrics

The validation step computes precision, recall, and mAP. If you need additional metrics, such as per-class F1 score, override validate():

import numpy as np

from ultralytics import YOLO
from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.utils import LOGGER

class MetricsTrainer(DetectionTrainer):
    """Custom trainer that computes and logs F1 score at the end of each epoch."""

    def validate(self):
        """Run validation and compute per-class F1 scores."""
        metrics, fitness = super().validate()
        if metrics is None:
            return metrics, fitness

        if hasattr(self.validator, "metrics") and hasattr(self.validator.metrics, "box"):
            box = self.validator.metrics.box
            f1_per_class = box.f1
            class_indices = box.ap_class_index
            names = self.validator.names

            valid_f1 = f1_per_class[f1_per_class > 0]
            mean_f1 = np.mean(valid_f1) if len(valid_f1) > 0 else 0.0

            LOGGER.info(f"Mean F1 Score: {mean_f1:.4f}")
            per_class_str = [
                f"{names[i]}: {f1_per_class[j]:.3f}" for j, i in enumerate(class_indices) if f1_per_class[j] > 0
            ]
            LOGGER.info(f"Per-class F1: {per_class_str}")

        return metrics, fitness

model = YOLO("yolo26n.pt")
model.train(data="coco8.yaml", epochs=5, trainer=MetricsTrainer)

This logs the mean F1 score across all classes, plus a per-class breakdown, after each validation run.

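Available metrics

The validator exposes many metrics through self.validator.metrics.box:

| Attribute | Description |
| --- | --- |
| f1 | F1 score per class |
| image_metrics | Per-image metrics dictionary containing precision, recall, F1, TP, FP, and FN |
| p | Precision per class |
| r | Recall per class |
| ap50 | AP at IoU 0.5 per class |
| ap | AP at IoU 0.5:0.95 per class |
| mp, mr | Mean precision and mean recall |
| map50, map | Mean AP metrics |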
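
For reference, each per-class F1 value is the harmonic mean of that class's precision and recall. The sketch below works through that arithmetic in plain Python. The helper `f1_score`, the class names, and the sample values are illustrative assumptions, not part of Ultralytics:

```python
def f1_score(precision: float, recall: float, eps: float = 1e-16) -> float:
    """Harmonic mean of precision and recall; eps avoids division by zero."""
    return 2 * precision * recall / (precision + recall + eps)


# Hypothetical per-class (precision, recall) pairs
per_class = {"scratch": (0.80, 0.60), "dent": (0.50, 0.50), "crack": (0.0, 0.0)}
f1 = {name: f1_score(p, r) for name, (p, r) in per_class.items()}

# Mirror the trainer above: average only the classes with a non-zero F1
valid = [v for v in f1.values() if v > 0]
mean_f1 = sum(valid) / len(valid) if valid else 0.0
print({k: round(v, 3) for k, v in f1.items()}, round(mean_f1, 3))
```

Because classes with an F1 of zero are excluded, the reported mean can be higher than a plain average over all classes.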

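Adding Class Weights

If your dataset has imbalanced classes (for example, a rare defect in manufacturing inspection), you can upweight underrepresented classes in the loss function. This makes the model penalize misclassifications of rare classes more heavily.

To customize the loss, subclass the loss classes, the model, and the trainer: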

import torch
from torch import nn

from ultralytics import YOLO
from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.nn.tasks import DetectionModel
from ultralytics.utils import RANK
from ultralytics.utils.loss import E2ELoss, v8DetectionLoss

class WeightedDetectionLoss(v8DetectionLoss):
    """Detection loss with class weights applied to BCE classification loss."""

    def __init__(self, model, class_weights=None, tal_topk=10, tal_topk2=None):
        """Initialize loss with optional per-class weights for BCE."""
        super().__init__(model, tal_topk=tal_topk, tal_topk2=tal_topk2)
        if class_weights is not None:
            self.bce = nn.BCEWithLogitsLoss(
                pos_weight=class_weights.to(self.device),
                reduction="none",
            )

class WeightedE2ELoss(E2ELoss):
    """E2E Loss with class weights for YOLO26."""

    def __init__(self, model, class_weights=None):
        """Initialize E2E loss with weighted detection loss."""

        def weighted_loss_fn(model, tal_topk=10, tal_topk2=None):
            return WeightedDetectionLoss(model, class_weights=class_weights, tal_topk=tal_topk, tal_topk2=tal_topk2)

        super().__init__(model, loss_fn=weighted_loss_fn)

class WeightedDetectionModel(DetectionModel):
    """Detection model that uses class-weighted loss."""

    def init_criterion(self):
        """Initialize weighted loss criterion with per-class weights."""
        class_weights = torch.ones(self.nc)
        class_weights[0] = 2.0  # upweight class 0
        class_weights[1] = 3.0  # upweight rare class 1
        return WeightedE2ELoss(self, class_weights=class_weights)

class WeightedTrainer(DetectionTrainer):
    """Trainer that returns a WeightedDetectionModel."""

    def get_model(self, cfg=None, weights=None, verbose=True):
        """Return a WeightedDetectionModel."""
        model = WeightedDetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
        if weights:
            model.load(weights)
        return model

model = YOLO("yolo26n.pt")
model.train(data="coco8.yaml", epochs=10, trainer=WeightedTrainer)
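
To see what pos_weight does numerically, here is a minimal sketch of the per-element BCE-with-logits formula in plain Python. `weighted_bce` is an illustrative helper, not an Ultralytics or PyTorch function. It follows the formula documented for nn.BCEWithLogitsLoss, where pos_weight scales only the positive-target term:

```python
import math


def weighted_bce(logit: float, target: float, pos_weight: float = 1.0) -> float:
    """Per-element BCE with logits; pos_weight multiplies the positive-target term only."""
    p = 1.0 / (1.0 + math.exp(-logit))  # sigmoid
    return -(pos_weight * target * math.log(p) + (1.0 - target) * math.log(1.0 - p))


# A missed positive (target 1, negative logit) costs three times more with pos_weight=3.0
base = weighted_bce(-2.0, 1.0)
weighted = weighted_bce(-2.0, 1.0, pos_weight=3.0)

# Negatives (target 0) are unaffected by pos_weight
neg_base = weighted_bce(-2.0, 0.0)
neg_weighted = weighted_bce(-2.0, 0.0, pos_weight=3.0)
print(round(weighted / base, 2), neg_base == neg_weighted)
```

Only the positive term is scaled, so a higher weight mainly raises the cost of missing that class rather than the cost of false alarms on it.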
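
Computing weights from the dataset

You can compute class weights automatically from the label distribution of your dataset. A common approach is inverse-frequency weighting: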

import numpy as np

# class_counts: number of instances per class
class_counts = np.array([5000, 200, 3000])
# Inverse frequency: rarer classes get higher weight
class_weights = max(class_counts) / class_counts
# Result: [1.0, 25.0, 1.67]
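
The class_counts array above is hard-coded. One way to obtain the counts is to tally the class IDs in YOLO-format label lines, where each line starts with the integer class ID. The sketch below does this on an in-memory sample. The helpers `count_classes` and `inverse_frequency_weights`, the `max_weight` cap, and the sample labels are hypothetical, and reading your actual label files is left out:

```python
from collections import Counter


def count_classes(label_lines, num_classes):
    """Count instances per class from YOLO-format lines: 'class x_center y_center width height'."""
    counts = Counter(int(line.split()[0]) for line in label_lines if line.strip())
    return [counts.get(c, 0) for c in range(num_classes)]


def inverse_frequency_weights(counts, max_weight=50.0):
    """Weight each class by max_count / count, capped; classes with no instances get 1.0."""
    top = max(counts)
    return [min(top / c, max_weight) if c > 0 else 1.0 for c in counts]


labels = [
    "0 0.5 0.5 0.2 0.2",
    "0 0.1 0.1 0.1 0.1",
    "2 0.3 0.3 0.2 0.1",
    "0 0.7 0.2 0.1 0.3",
    "",  # blank lines are skipped
]
counts = count_classes(labels, num_classes=3)
weights = inverse_frequency_weights(counts)
print(counts, weights)
```

The cap is a design choice that keeps a very rare class from dominating the loss. Before passing the result as class_weights, convert it with torch.tensor(weights, dtype=torch.float32).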
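
Loading models with custom classes

Custom classes such as WeightedDetectionModel are stored by reference in the checkpoint. When they are defined in the training script, they belong to the __main__ module, so loading best.pt from a different script raises AttributeError: Can't get attribute 'WeightedDetectionModel' on <module '__main__'>.

Define custom classes in a dedicated module so that they are importable, and make sure that module is on your PYTHONPATH at load time.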

# weighted_model.py
from ultralytics.nn.tasks import DetectionModel

class WeightedDetectionModel(DetectionModel):
    """Detection model that uses class-weighted loss."""

    ...

# inference script
from weighted_model import WeightedDetectionModel  # noqa: F401 - must be importable at checkpoint load time

from ultralytics import YOLO

model = YOLO("runs/detect/train/weights/best.pt")
metrics = model.val()

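Saving the Best Model by a Custom Metric

The trainer saves best.pt based on fitness, which for detection defaults to mAP@0.5:0.95 (weights [0.0, 0.0, 0.0, 1.0] applied to [P, R, mAP@0.5, mAP@0.5:0.95]). To use a different metric, such as mAP@0.5 or recall, override validate() and return your chosen metric as the fitness value. The built-in save_model() then uses it automatically: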

from ultralytics import YOLO
from ultralytics.models.yolo.detect import DetectionTrainer

class CustomSaveTrainer(DetectionTrainer):
    """Trainer that saves the best model based on mAP@0.5 instead of default fitness."""

    def validate(self):
        """Override fitness to use mAP@0.5 for best model selection."""
        metrics, fitness = super().validate()
        if metrics:
            fitness = metrics.get("metrics/mAP50(B)", fitness)
            if self.best_fitness is None or fitness > self.best_fitness:
                self.best_fitness = fitness
        return metrics, fitness

model = YOLO("yolo26n.pt")
model.train(data="coco8.yaml", epochs=20, trainer=CustomSaveTrainer)
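
Available metrics

Common metrics available in self.metrics after validation include:

| Key | Description |
| --- | --- |
| metrics/precision(B) | Precision |
| metrics/recall(B) | Recall |
| metrics/mAP50(B) | mAP at IoU 0.5 |
| metrics/mAP50-95(B) | mAP at IoU 0.5:0.95 |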
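
The default fitness described in this section is a weighted sum of the four metrics. The sketch below works through that arithmetic in plain Python. The function `fitness` and the sample values are illustrative assumptions, and the default weight vector is the one quoted above:

```python
def fitness(p, r, map50, map50_95, weights=(0.0, 0.0, 0.0, 1.0)):
    """Weighted sum over [P, R, mAP@0.5, mAP@0.5:0.95]."""
    return sum(w * m for w, m in zip(weights, (p, r, map50, map50_95)))


# Hypothetical validation results
sample = dict(p=0.70, r=0.60, map50=0.65, map50_95=0.45)

default = fitness(**sample)  # equals mAP@0.5:0.95
map50_only = fitness(**sample, weights=(0.0, 0.0, 1.0, 0.0))  # the metric CustomSaveTrainer selects on
recall_blend = fitness(**sample, weights=(0.0, 0.5, 0.0, 0.5))  # equal mix of recall and mAP@0.5:0.95
print(default, map50_only, recall_blend)
```

A blended value like the last one can also be returned from validate() when no single metric captures what "best" means for your application.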

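Freezing and Unfreezing the Backbone

Transfer learning workflows often benefit from freezing the pretrained backbone for the first N epochs, which lets the detection head adapt before the whole network is fine-tuned. Ultralytics provides a freeze parameter to freeze layers at the start of training, and you can use a callback to unfreeze them after N epochs: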

from ultralytics import YOLO
from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.utils import LOGGER

FREEZE_EPOCHS = 5

def unfreeze_backbone(trainer):
    """Callback to unfreeze all layers after FREEZE_EPOCHS."""
    if trainer.epoch == FREEZE_EPOCHS:
        LOGGER.info(f"Epoch {trainer.epoch}: Unfreezing all layers for fine-tuning")
        for name, param in trainer.model.named_parameters():
            if not param.requires_grad:
                param.requires_grad = True
                LOGGER.info(f"  Unfroze: {name}")
        trainer.freeze_layer_names = [".dfl"]

class FreezingTrainer(DetectionTrainer):
    """Trainer with backbone freezing for first N epochs."""

    def __init__(self, *args, **kwargs):
        """Initialize and register the unfreeze callback."""
        super().__init__(*args, **kwargs)
        self.add_callback("on_train_epoch_start", unfreeze_backbone)

model = YOLO("yolo26n.pt")
model.train(data="coco8.yaml", epochs=20, freeze=10, trainer=FreezingTrainer)

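The freeze=10 argument freezes the first 10 layers (the backbone) at the start of training. The on_train_epoch_start callback fires at the start of each epoch and unfreezes all parameters once the freeze period is over.

Choosing what to freeze
  • freeze=10 freezes the first 10 layers (typically the backbone in YOLO architectures)
  • freeze=[0, 1, 2, 3] freezes specific layers by index
  • A higher FREEZE_EPOCHS value gives the detection head more time to adapt before the backbone changes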
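
As a rough illustration of what the two freeze forms select, the sketch below expands a freeze value into layer-name prefixes and matches them against parameter names. It assumes parameters are named model.<index>.…, the same convention the per-layer learning rate example in this guide relies on. The helpers `freeze_prefixes` and `is_frozen` and the sample names are hypothetical, and this is not the trainer's actual implementation:

```python
def freeze_prefixes(freeze):
    """Expand freeze=N (first N layers) or freeze=[i, j, ...] into 'model.<i>.' prefixes."""
    indices = range(freeze) if isinstance(freeze, int) else freeze
    return [f"model.{i}." for i in indices]


def is_frozen(param_name, prefixes):
    """A parameter counts as frozen when its name starts with one of the prefixes."""
    return any(param_name.startswith(p) for p in prefixes)


# Hypothetical parameter names following the model.<index>. convention
names = [
    "model.0.conv.weight",
    "model.9.cv1.conv.weight",
    "model.10.m.0.attn.qkv.conv.weight",
    "model.23.cv2.0.0.conv.weight",
]
frozen_first_ten = [n for n in names if is_frozen(n, freeze_prefixes(10))]
frozen_by_index = [n for n in names if is_frozen(n, freeze_prefixes([0, 1, 2, 3]))]
print(frozen_first_ten)
print(frozen_by_index)
```

The trailing dot in each prefix matters: without it, the prefix for layer 1 would also match layer 10 and above.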

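Per-Layer Learning Rates

Different parts of the network can benefit from different learning rates. A common strategy is to use a lower learning rate for the pretrained backbone to preserve its learned features, while letting the detection head adapt faster with a higher rate: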

import torch

from ultralytics import YOLO
from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.utils import LOGGER
from ultralytics.utils.torch_utils import unwrap_model

class PerLayerLRTrainer(DetectionTrainer):
    """Trainer with different learning rates for backbone and head."""

    def build_optimizer(self, model, name="auto", lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
        """Build optimizer with separate learning rates for backbone and head."""
        backbone_params = []
        head_params = []

        for k, v in unwrap_model(model).named_parameters():
            if not v.requires_grad:
                continue
            is_backbone = any(k.startswith(f"model.{i}.") for i in range(10))
            if is_backbone:
                backbone_params.append(v)
            else:
                head_params.append(v)

        backbone_lr = lr * 0.1

        optimizer = torch.optim.AdamW(
            [
                {"params": backbone_params, "lr": backbone_lr, "weight_decay": decay},
                {"params": head_params, "lr": lr, "weight_decay": decay},
            ],
        )

        LOGGER.info(
            f"PerLayerLR optimizer: backbone ({len(backbone_params)} params, lr={backbone_lr}) "
            f"| head ({len(head_params)} params, lr={lr})"
        )
        return optimizer

model = YOLO("yolo26n.pt")
model.train(data="coco8.yaml", epochs=20, trainer=PerLayerLRTrainer)

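RT-DETR Variant

For RT-DETR, the pattern is the same with two refinements. First, the backbone length is read from model.yaml["backbone"], so the same trainer works across RT-DETR variants (RT-DETR-L, RT-DETR-X, ResNet-50/101 backbones) without hard-coding the layer count. Second, the parameters within each section are further split into weight, BatchNorm, and bias groups, so weight decay is not applied to BatchNorm parameters or biases, matching the default trainer's policy. This is particularly useful for RT-DETR fine-tuning, because the decoder head is often randomly initialized while the backbone carries pretrained features that benefit from a lower learning rate: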

import torch
from torch import nn

from ultralytics import RTDETR
from ultralytics.models.rtdetr.train import RTDETRTrainer
from ultralytics.utils import LOGGER, colorstr
from ultralytics.utils.torch_utils import unwrap_model

class RTDETRBackboneLRTrainer(RTDETRTrainer):
    """RT-DETR trainer with a lower learning rate for backbone parameters."""

    backbone_lr_ratio = 0.1  # backbone learning rate as a fraction of head learning rate

    def build_optimizer(self, model, name="auto", lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
        """Build an AdamW optimizer with six param groups: head and backbone x {weight, bn, bias}."""
        # Resolve optimizer name; "auto" maps to AdamW with RT-DETR-style defaults
        canonical = {"Adam", "Adamax", "AdamW", "NAdam", "RAdam", "auto"}
        name = {x.lower(): x for x in canonical}.get(name.lower(), name)
        if name == "auto":
            name, lr, momentum = "AdamW", 1e-4, 0.9
        self.args.warmup_bias_lr = 0.0  # RT-DETR warms biases from 0, unlike YOLO's 0.1
        if name not in {"Adam", "Adamax", "AdamW", "NAdam", "RAdam"}:
            raise NotImplementedError(f"This trainer only supports AdamW-family optimizers; got {name}")

        # Identify backbone parameters from model.yaml and route each param into a (section, kind) group
        unwrapped = unwrap_model(model)
        backbone_len = len(unwrapped.yaml["backbone"])
        norm_types = tuple(v for k, v in nn.__dict__.items() if "Norm" in k)
        groups = {f"{s}_{k}": [] for s in ("head", "backbone") for k in ("weight", "bn", "bias")}

        for module_name, module in unwrapped.named_modules():
            for param_name, param in module.named_parameters(recurse=False):
                if not param.requires_grad:
                    continue
                fullname = f"{module_name}.{param_name}" if module_name else param_name
                parts = fullname.split(".")
                section = (
                    "backbone"
                    if len(parts) > 1 and parts[0] == "model" and parts[1].isdigit() and int(parts[1]) < backbone_len
                    else "head"
                )
                if "bias" in param_name:
                    kind = "bias"
                elif isinstance(module, norm_types) or "logit_scale" in fullname:
                    kind = "bn"
                else:
                    kind = "weight"
                groups[f"{section}_{kind}"].append(param)

        # Build the optimizer with per-group lr and weight decay; backbone groups use lr * backbone_lr_ratio
        backbone_lr = lr * self.backbone_lr_ratio
        param_groups = [
            {"params": groups["head_weight"], "lr": lr, "weight_decay": decay, "param_group": "weight"},
            {"params": groups["head_bn"], "lr": lr, "weight_decay": 0.0, "param_group": "bn"},
            {"params": groups["head_bias"], "lr": lr, "weight_decay": 0.0, "param_group": "bias"},
            {"params": groups["backbone_weight"], "lr": backbone_lr, "weight_decay": decay, "param_group": "weight"},
            {"params": groups["backbone_bn"], "lr": backbone_lr, "weight_decay": 0.0, "param_group": "bn"},
            {"params": groups["backbone_bias"], "lr": backbone_lr, "weight_decay": 0.0, "param_group": "bias"},
        ]
        param_groups = [pg for pg in param_groups if pg["params"]]  # drop empty groups
        optimizer = getattr(torch.optim, name)(param_groups, betas=(momentum, 0.999))

        LOGGER.info(
            f"{colorstr('optimizer:')} {name}(lr={lr}, backbone_lr={backbone_lr}) with parameter groups\n"
            f"  Head:     {len(groups['head_bn'])} bn, {len(groups['head_weight'])} weight(decay={decay}), "
            f"{len(groups['head_bias'])} bias (lr={lr})\n"
            f"  Backbone: {len(groups['backbone_bn'])} bn, {len(groups['backbone_weight'])} weight(decay={decay}), "
            f"{len(groups['backbone_bias'])} bias (lr={backbone_lr})"
        )
        return optimizer

model = RTDETR("rtdetr-l.pt")
model.train(data="coco8.yaml", epochs=20, trainer=RTDETRBackboneLRTrainer)
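
Choosing `backbone_lr_ratio`

A common starting point is backbone_lr_ratio = 0.1, which matches the setting used by the original RT-DETR with its HGNetV2 backbone. The literature suggests scaling the ratio inversely with backbone size and pretraining data scale: large backbones pretrained on very large datasets (for example, ViT-L/H trained with DINO, CLIP, or MAE on hundreds of millions of images) typically use a ratio of 0.01 or smaller to preserve learned features, while smaller backbones with lighter pretraining tolerate ratios of 0.5 or higher.

Learning rate scheduler

The built-in learning rate schedulers (cosine, linear) still apply on top of each group's base learning rate. The backbone and head learning rates follow the same decay schedule, so the ratio between them is maintained throughout training.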
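
To illustrate why the ratio is preserved, the sketch below applies one multiplicative schedule factor to two base learning rates. It assumes a linear decay of the form max(1 - epoch / epochs, 0) * (1 - lrf) + lrf, with lrf as the final learning-rate fraction. The helper `linear_factor` and the numbers are illustrative assumptions:

```python
def linear_factor(epoch, epochs, lrf=0.01):
    """Multiplier that decays linearly from 1.0 at epoch 0 to lrf at the final epoch."""
    return max(1 - epoch / epochs, 0) * (1.0 - lrf) + lrf


head_lr, backbone_lr = 1e-3, 1e-4  # base learning rates with a 0.1 ratio
epochs = 20
for epoch in (0, 10, 20):
    factor = linear_factor(epoch, epochs)
    h, b = head_lr * factor, backbone_lr * factor
    print(f"epoch {epoch:2d}: head={h:.2e} backbone={b:.2e} ratio={b / h:.2f}")
```

Because both groups are multiplied by the same factor at every epoch, the factor cancels out of the ratio.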

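Combining techniques

These customizations can be combined in a single trainer class by overriding multiple methods and adding callbacks as needed.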

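Synchronized BatchNorm for Multi-GPU Training

When training on multiple GPUs with DistributedDataParallel, the default BatchNorm2d layers compute statistics independently on each GPU. For RT-DETR fine-tuning and other recipes that use a small per-GPU batch size, the per-GPU batch statistics can be noisy. PyTorch's SyncBatchNorm synchronizes the mean and variance across all processes to obtain a single global batch statistic, which usually improves convergence at the cost of a small amount of inter-GPU communication.

The conversion must happen after the model is placed on the GPU and before it is wrapped in DDP. The cleanest hook is set_model_attributes(), which BaseTrainer calls in exactly that window: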

from torch import nn

from ultralytics import RTDETR
from ultralytics.models.rtdetr.train import RTDETRTrainer

class SyncBNTrainer(RTDETRTrainer):
    """RT-DETR trainer that converts BatchNorm to SyncBatchNorm for multi-GPU training."""

    def set_model_attributes(self):
        """Run the parent setup, then convert BN to SyncBatchNorm when training on multiple GPUs."""
        super().set_model_attributes()
        if self.world_size > 1:
            self.model = nn.SyncBatchNorm.convert_sync_batchnorm(self.model)

model = RTDETR("rtdetr-l.pt")
model.train(data="coco8.yaml", epochs=20, device=[0, 1], trainer=SyncBNTrainer)

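The world_size > 1 guard makes the trainer safe for single-GPU runs as well: on a single GPU the conversion is skipped and training continues with ordinary BatchNorm2d. The same pattern works for YOLO if you switch the parent class to DetectionTrainer.

When to use SyncBatchNorm

| Scenario | Recommendation |
| --- | --- |
| Multi-GPU training, small per-GPU batch (≤ 16) | Enable |
| Multi-GPU training, large per-GPU batch (≥ 32) | Optional; smaller benefit |
| Single-GPU training | Not applicable (skipped) |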
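
The effect of synchronizing statistics can be illustrated without any GPUs. The sketch below splits one global batch of activations across two hypothetical devices and compares each device's local mean and variance with the pooled statistics that synchronization would produce. The values and the helper `mean_var` are made up for illustration:

```python
def mean_var(values):
    """Return the mean and the biased (population) variance used to normalize a batch."""
    m = sum(values) / len(values)
    return m, sum((v - m) ** 2 for v in values) / len(values)


# One channel's activations for a global batch of 8, split across two devices (4 each)
gpu0 = [0.2, 0.4, 0.1, 0.3]
gpu1 = [1.8, 2.2, 2.0, 1.6]

local0, local1 = mean_var(gpu0), mean_var(gpu1)  # what plain BatchNorm sees per device
synced = mean_var(gpu0 + gpu1)  # statistics over the whole global batch
print(local0, local1, synced)
```

With only four samples per device, the two local means differ widely (about 0.25 and 1.9), so each replica would normalize the same channel differently. The pooled statistics give every replica one shared estimate.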

Configurable Gradient Clipping

The default trainer clips gradients to max_norm=10.0 in optimizer_step(), a loose value tuned for YOLO models where gradients rarely exceed it. DETR-family detectors (RT-DETR, DEIM, DINO) typically use much tighter values such as 0.1 to stabilize the decoder's cross-attention layers, where gradient magnitudes can spike. To override the clip value, subclass the trainer and override optimizer_step():

import torch

from ultralytics import RTDETR
from ultralytics.models.rtdetr.train import RTDETRTrainer

class CustomClipTrainer(RTDETRTrainer):
    """RT-DETR trainer with configurable gradient clipping."""

    clip_grad_norm = 0.1  # max gradient norm; set to 0 to disable clipping

    def optimizer_step(self):
        """Run an optimizer step with a configurable gradient-norm clip."""
        self.scaler.unscale_(self.optimizer)
        if self.clip_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.clip_grad_norm)
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad()
        if self.ema:
            self.ema.update(self.model)

model = RTDETR("rtdetr-l.pt")
model.train(data="coco8.yaml", epochs=20, trainer=CustomClipTrainer)

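The same trainer can be applied to YOLO by switching the parent class to DetectionTrainer (from ultralytics.models.yolo.detect import DetectionTrainer) and loading a YOLO checkpoint with YOLO("yolo26n.pt"). The optimizer_step body stays the same.

Typical `clip_grad_norm` values

| Architecture family | Typical max_norm |
| --- | --- |
| RT-DETR / DEIM / DETR family | 0.1 |
| YOLO (Ultralytics default) | 10.0 |
| Clipping disabled | 0 |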
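
For intuition about what max_norm does, here is a plain-Python sketch of norm-based clipping: compute the global L2 norm over all gradients and, if it exceeds max_norm, scale every gradient by the same factor. It follows the rule documented for torch.nn.utils.clip_grad_norm_. The helper `clip_by_global_norm` and the sample gradients are illustrative assumptions:

```python
import math


def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Scale all gradients by min(1, max_norm / (total_norm + eps)); return (clipped, total_norm)."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    coef = min(1.0, max_norm / (total_norm + eps))
    return [g * coef for g in grads], total_norm


grads = [3.0, 4.0]  # global L2 norm is 5.0
tight, norm = clip_by_global_norm(grads, max_norm=0.1)  # DETR-style: rescaled to a norm of about 0.1
loose, _ = clip_by_global_norm(grads, max_norm=10.0)  # YOLO default: left unchanged
print(norm, tight, loose)
```

Clipping rescales the whole gradient vector, so its direction is preserved and only the step size shrinks.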

FAQ

How do I pass a custom trainer to YOLO?

Pass your custom trainer class (not an instance) to the trainer parameter in model.train():

from ultralytics import YOLO

model = YOLO("yolo26n.pt")
model.train(data="coco8.yaml", trainer=MyCustomTrainer)

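The YOLO class handles trainer instantiation internally. See the Advanced Customization page for more details on the trainer architecture.

Which BaseTrainer methods can I override?

Key methods available for customization:

| Method | Purpose |
| --- | --- |
| validate() | Run validation and return metrics |
| build_optimizer() | Build the optimizer |
| save_model() | Save training checkpoints |
| get_model() | Return the model instance |
| get_validator() | Return the validator instance |
| get_dataloader() | Build the dataloader |
| preprocess_batch() | Preprocess the input batch |
| label_loss_items() | Format loss items for logging |

For the full API reference, see the BaseTrainer documentation.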

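Can I use callbacks instead of subclassing the trainer?

Yes. For simpler customizations, callbacks are often enough. Available callback events include on_train_start, on_train_epoch_start, on_train_epoch_end, on_fit_epoch_end, and on_model_save. These let you hook into the training loop without subclassing. The backbone freezing example above demonstrates this approach.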

How do I customize the loss function without subclassing the model?

If your change is simple, such as adjusting the loss gains, you can modify the hyperparameters directly:

model.train(data="coco8.yaml", box=10.0, cls=1.5, dfl=2.0)

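For structural changes to the loss, such as adding class weights, you need to subclass the loss and the model as shown in the class weights section.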
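
As a rough picture of what the gains do, the sketch below scales three loss components by their gains before summing them. It assumes the detection loss multiplies its box, classification, and DFL components by box, cls, and dfl respectively. The helper `total_loss` and the component values are hypothetical:

```python
def total_loss(components, gains):
    """Multiply each loss component by its gain and sum the results."""
    return sum(gains[name] * value for name, value in components.items())


components = {"box": 0.9, "cls": 1.2, "dfl": 1.0}  # hypothetical unweighted loss values
custom = {"box": 10.0, "cls": 1.5, "dfl": 2.0}  # the gains passed to model.train() above
print(total_loss(components, custom))
```

Raising one gain relative to the others shifts the optimizer's attention toward that component without changing how the component itself is computed.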
