
Reference for ultralytics/nn/distill_model.py

Improvements

This page is sourced from https://github.com/ultralytics/ultralytics/blob/main/ultralytics/nn/distill_model.py. Have an improvement or example to add? Open a Pull Request — thank you! 🙏



Class ultralytics.nn.distill_model.FeatureHook

FeatureHook(self, feat_dict: dict, idx: int) -> None

Picklable forward hook that stores layer output into a shared dict.

Args

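| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `feat_dict` | `dict` | Shared feature dict that layer outputs are stored in. | required |
| `idx` | `int` | Layer index to store outputs under. | required |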

Methods

| Name | Description |
| --- | --- |
| `__call__` | Store the layer's forward output into the shared feature dict under its index. |

Source code in ultralytics/nn/distill_model.py

class FeatureHook:
    """Picklable forward hook that stores layer output into a shared dict."""

    def __init__(self, feat_dict: dict, idx: int) -> None:
        """Initialize the hook with the shared feature dict and the layer index to store outputs under."""
        self.feat_dict = feat_dict
        self.idx = idx

Method ultralytics.nn.distill_model.FeatureHook.__call__

def __call__(self, module: nn.Module, inputs: tuple, output) -> None

Store the layer's forward output into the shared feature dict under its index.

The output is a tensor for neck layers but a tuple/dict for the Detect head, so it is left untyped.

Args

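| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `module` | `nn.Module` | Layer the hook is registered on. | required |
| `inputs` | `tuple` | Positional inputs passed to the layer's forward call. | required |
| `output` | | Forward output of the layer: a tensor for neck layers, a tuple/dict for the Detect head. | required |

Source code in ultralytics/nn/distill_model.py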

def __call__(self, module: nn.Module, inputs: tuple, output) -> None:
    """Store the layer's forward output into the shared feature dict under its index.

    The output is a tensor for neck layers but a tuple/dict for the Detect head, so it is left untyped.
    """
    self.feat_dict[self.idx] = output
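
The hook pattern above can be tried on a toy network. This is a minimal sketch: the two-layer `nn.Sequential` and the indices `0` and `1` are stand-ins chosen for illustration, not a real YOLO layer list, and the class is re-declared locally so the snippet runs on its own.

```python
import torch
import torch.nn as nn


class FeatureHook:
    """Local copy of the callable-object hook shown above."""

    def __init__(self, feat_dict: dict, idx: int) -> None:
        self.feat_dict = feat_dict
        self.idx = idx

    def __call__(self, module: nn.Module, inputs: tuple, output) -> None:
        self.feat_dict[self.idx] = output


# Toy network standing in for a model's layer list (assumption for illustration).
net = nn.Sequential(nn.Conv2d(3, 8, kernel_size=1), nn.Conv2d(8, 16, kernel_size=1))

feats: dict = {}  # one dict shared by every hook
handles = [net[i].register_forward_hook(FeatureHook(feats, i)) for i in (0, 1)]

with torch.no_grad():
    net(torch.zeros(2, 3, 4, 4))  # each hooked layer writes its output into feats

shapes = {idx: tuple(t.shape) for idx, t in feats.items()}
print(shapes)

for handle in handles:  # detach the hooks again
    handle.remove()
```

Because every hook writes into the same dict, a single forward pass leaves one entry per hooked layer, keyed by layer index.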





Class ultralytics.nn.distill_model.DistillationModel

DistillationModel(self, teacher_model: str | Path | nn.Module, student_model: nn.Module)

Bases: nn.Module

YOLO knowledge distillation model.

This class wraps a teacher-student pair for knowledge distillation training. Features are extracted from both models via forward hooks for distillation.

Args

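| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `teacher_model` | `str \| Path \| nn.Module` | Teacher model checkpoint path or module. | required |
| `student_model` | `nn.Module` | Student model module to be trained. | required |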

Attributes

| Name | Type | Description |
| --- | --- | --- |
| `teacher_model` | `nn.Module` | Frozen teacher model providing features. |
| `student_model` | `nn.Module` | Trainable student model being distilled. |
| `feats_idx` | `list` | Layer indices for feature extraction. |
| `projector` | `nn.ModuleList` | MLP projector aligning student features to teacher dimensions. |
| `dis` | `float` | Distillation loss weight factor. |

Methods

| Name | Description |
| --- | --- |
| `criterion` | Get the criterion from the student model. |
| `end2end` | Expose student end-to-end mode for validator/predictor control. |
| `__getstate__` | Return a copy of state for pickling without captured features or hook handles. |
| `__setstate__` | Clear stale features and hooks, and re-register forward hooks after unpickling. |
| `_clear_feature_hooks` | Remove any FeatureHook instances from a module's forward hooks. |
| `_freeze_teacher` | Keep teacher fixed for distillation. |
| `_register_feature_hooks` | Register feature-capture hooks, removing stale FeatureHook instances first. |
| `_remove_feature_hooks` | Remove any previously registered feature-capture hooks. |
| `criterion` | Set value for student criterion. |
| `decouple_outputs` | Decouple outputs for teacher/student models. |
| `end2end` | Forward end-to-end mode update to the student model. |
| `forward` | Forward pass through the student model. |
| `get_distill_layers` | Auto-detect distillation feature layers from the model's Detect head. |
| `init_criterion` | Initialize the loss criterion via the student model. |
| `loss` | Compute loss. |
| `loss_sl2` | Compute score-weighted L2 distillation loss for a feature pair. |
| `set_head_attr` | Forward head-attribute updates (e.g. max_det, agnostic_nms, end2end) to the student model. |
| `train` | Set model train mode while keeping teacher frozen in eval mode. |

Examples

Train a student model with knowledge distillation from a larger teacher (the trainer builds the
DistillationModel internally when the ``distill_model`` argument is set)
>>> from ultralytics import YOLO
>>> model = YOLO("yolo26n.pt")
>>> model.train(data="coco8.yaml", distill_model="yolo26s.pt")
Source code in ultralytics/nn/distill_model.py


class DistillationModel(nn.Module):
    """YOLO knowledge distillation model.

    This class wraps a teacher-student pair for knowledge distillation training. Features are extracted from both models
    via forward hooks for distillation.

    Attributes:
        teacher_model (nn.Module): Frozen teacher model providing features.
        student_model (nn.Module): Trainable student model being distilled.
        feats_idx (list): Layer indices for feature extraction.
        projector (nn.ModuleList): MLP projector aligning student features to teacher dimensions.
        dis (float): Distillation loss weight factor.

    Methods:
        get_distill_layers: Auto-detect distillation feature layers from the Detect head.
        forward: Run the student model, or compute the combined loss when given a training batch.
        loss: Compute combined detection and distillation loss.
        loss_sl2: Compute score-weighted L2 distillation loss for a feature pair.
        decouple_outputs: Normalize teacher/student head outputs across train/val formats.
        train: Set training mode while keeping teacher frozen.

    Examples:
        Train a student model with knowledge distillation from a larger teacher (the trainer builds the
        DistillationModel internally when the ``distill_model`` argument is set)
        >>> from ultralytics import YOLO
        >>> model = YOLO("yolo26n.pt")
        >>> model.train(data="coco8.yaml", distill_model="yolo26s.pt")
    """

    def __init__(self, teacher_model: str | Path | nn.Module, student_model: nn.Module):
        """Initialize the distillation model with teacher, student, and feature extraction hooks.

        Args:
            teacher_model (str | Path | nn.Module): Teacher model checkpoint path or module.
            student_model (nn.Module): Student model module to be trained.
        """
        super().__init__()
        if isinstance(teacher_model, (str, Path)):
            teacher_model = load_checkpoint(teacher_model)[0]
        device = next(student_model.parameters()).device
        self.teacher_model = teacher_model.to(device)
        self._freeze_teacher()
        self.student_model = student_model
        self.feats_idx = self.get_distill_layers(student_model)

        # Hook-based feature capture: identical for teacher and student
        self._teacher_feats: dict[int, torch.Tensor] = {}
        self._student_feats: dict[int, torch.Tensor] = {}
        self._teacher_hooks: list = []
        self._student_hooks: list = []
        self._register_feature_hooks()

        # Get feature dimensions via dummy forward pass (hooks capture outputs)
        imgsz = student_model.args.imgsz
        student_model.eval()
        with torch.no_grad():
            teacher_model(torch.zeros(2, 3, imgsz, imgsz).to(device))
            student_model(torch.zeros(2, 3, imgsz, imgsz).to(device))
        student_model.train()
        teacher_output = [self._teacher_feats[idx] for idx in self.feats_idx]
        student_output = [self._student_feats[idx] for idx in self.feats_idx]

        copy_attr(self, student_model)
        self.dis = self.student_model.args.dis
        projectors = []
        for student_out, teacher_out in zip(student_output[:-1], teacher_output[:-1]):
            student_dim = self.decouple_outputs(student_out).shape[1]
            teacher_dim = self.decouple_outputs(teacher_out).shape[1]
            projectors.append(
                nn.Sequential(
                    nn.Conv2d(student_dim, teacher_dim, kernel_size=1, stride=1, padding=0),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(teacher_dim, teacher_dim, kernel_size=1, stride=1, padding=0),
                )
            )
        self.projector = nn.ModuleList(projectors).to(device)
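
The projector built at the end of `__init__` is a two-layer 1x1 convolutional MLP that only changes the channel dimension. A minimal sketch of one such projector follows; the channel widths (64 and 128) and the 20x20 feature map are hypothetical values picked for illustration.

```python
import torch
import torch.nn as nn

student_dim, teacher_dim = 64, 128  # hypothetical channel widths

# Same layer layout as in __init__: 1x1 conv -> ReLU -> 1x1 conv.
projector = nn.Sequential(
    nn.Conv2d(student_dim, teacher_dim, kernel_size=1, stride=1, padding=0),
    nn.ReLU(inplace=True),
    nn.Conv2d(teacher_dim, teacher_dim, kernel_size=1, stride=1, padding=0),
)

student_feat = torch.randn(2, student_dim, 20, 20)  # (N, C_student, H, W)
with torch.no_grad():
    projected = projector(student_feat)

print(tuple(projected.shape))  # channels now match the teacher; H and W are unchanged
```

Since the kernels are 1x1, spatial size is preserved, which is what lets the projected student feature be compared location by location with the teacher feature.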

Property ultralytics.nn.distill_model.DistillationModel.criterion

def criterion(self)

Get the criterion from the student model.

Source code in ultralytics/nn/distill_model.py


@property
def criterion(self):
    """Get the criterion from the student model."""
    return self.student_model.criterion

Property ultralytics.nn.distill_model.DistillationModel.end2end

def end2end(self)

Expose student end-to-end mode for validator/predictor control.

Source code in ultralytics/nn/distill_model.py


@property
def end2end(self):
    """Expose student end-to-end mode for validator/predictor control."""
    return getattr(self.student_model, "end2end", False)

Method ultralytics.nn.distill_model.DistillationModel.__getstate__

def __getstate__(self)

Return a copy of state for pickling without captured features or hook handles.

Clears the feature dicts in place (rather than replacing the attributes) because the registered FeatureHooks share these exact dict objects; otherwise a deepcopy/pickle of a mid-training model would still reach the hook-held tensors (which carry grad_fn and cannot be deep-copied).

Source code in ultralytics/nn/distill_model.py


def __getstate__(self):
    """Return a copy of state for pickling without captured features or hook handles.

    Clears the feature dicts in place (rather than replacing the attributes) because the registered
    FeatureHooks share these exact dict objects; otherwise a deepcopy/pickle of a mid-training model would
    still reach the hook-held tensors (which carry grad_fn and cannot be deep-copied).
    """
    self._teacher_feats.clear()
    self._student_feats.clear()
    state = self.__dict__.copy()
    state["_teacher_hooks"] = []
    state["_student_hooks"] = []
    return state
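
The reason for clearing in place rather than rebinding can be seen without PyTorch. In this sketch `Hook` and `Owner` are hypothetical stand-ins for `FeatureHook` and `DistillationModel`; only the object-sharing behaviour is modelled.

```python
class Hook:
    """Hypothetical stand-in for FeatureHook: keeps a reference to a shared dict."""

    def __init__(self, feat_dict: dict) -> None:
        self.feat_dict = feat_dict


class Owner:
    """Hypothetical stand-in for the model: owns the dict and a hook that shares it."""

    def __init__(self) -> None:
        self.feats: dict = {}
        self.hook = Hook(self.feats)


# Case 1: clear in place. The hook sees the same, now empty, dict.
a = Owner()
a.feats[0] = "captured tensor"
a.feats.clear()
cleared_in_place = a.hook.feat_dict

# Case 2: rebind the attribute. The hook still holds the old dict and its contents.
b = Owner()
b.feats[0] = "captured tensor"
b.feats = {}
still_held = b.hook.feat_dict

print(cleared_in_place, still_held)
```

In the second case anything that walks the object graph, such as `copy.deepcopy` or `pickle`, would still reach the captured value through the hook.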

Method ultralytics.nn.distill_model.DistillationModel.__setstate__

def __setstate__(self, state)

Clear stale features and hooks, and re-register forward hooks after unpickling.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `state` | | | required |

Source code in ultralytics/nn/distill_model.py

def __setstate__(self, state):
    """Clear stale features and hooks, and re-register forward hooks after unpickling."""
    self.__dict__.update(state)
    self._teacher_feats = {}
    self._student_feats = {}
    self._register_feature_hooks()

Method ultralytics.nn.distill_model.DistillationModel._clear_feature_hooks

def _clear_feature_hooks(module: nn.Module) -> None

Remove any FeatureHook instances from a module's forward hooks.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `module` | `nn.Module` | | required |

Source code in ultralytics/nn/distill_model.py

@staticmethod
def _clear_feature_hooks(module: nn.Module) -> None:
    """Remove any FeatureHook instances from a module's forward hooks."""
    for handle_id, hook in list(module._forward_hooks.items()):
        if isinstance(hook, FeatureHook):
            del module._forward_hooks[handle_id]

Method ultralytics.nn.distill_model.DistillationModel._freeze_teacher

def _freeze_teacher(self)

Keep teacher fixed for distillation.

Source code in ultralytics/nn/distill_model.py


def _freeze_teacher(self):
    """Keep teacher fixed for distillation."""
    if self.teacher_model is None:
        return
    self.teacher_model.eval()
    for v in self.teacher_model.parameters():
        if v.requires_grad:
            v.requires_grad = False

Method ultralytics.nn.distill_model.DistillationModel._register_feature_hooks

def _register_feature_hooks(self) -> None

Register feature-capture hooks, removing stale FeatureHook instances first.

Source code in ultralytics/nn/distill_model.py


def _register_feature_hooks(self) -> None:
    """Register feature-capture hooks, removing stale FeatureHook instances first."""
    self._remove_feature_hooks()
    for idx in self.feats_idx:
        self._clear_feature_hooks(self.student_model.model[idx])
        self._student_hooks.append(
            self.student_model.model[idx].register_forward_hook(FeatureHook(self._student_feats, idx))
        )
        if self.teacher_model is not None:
            self._clear_feature_hooks(self.teacher_model.model[idx])
            self._teacher_hooks.append(
                self.teacher_model.model[idx].register_forward_hook(FeatureHook(self._teacher_feats, idx))
            )

Method ultralytics.nn.distill_model.DistillationModel._remove_feature_hooks

def _remove_feature_hooks(self) -> None

Remove any previously registered feature-capture hooks.

Source code in ultralytics/nn/distill_model.py


def _remove_feature_hooks(self) -> None:
    """Remove any previously registered feature-capture hooks."""
    for handle in self._student_hooks:
        handle.remove()
    self._student_hooks.clear()
    if self.teacher_model is not None:
        for handle in self._teacher_hooks:
            handle.remove()
        self._teacher_hooks.clear()
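
The handles stored in `_student_hooks` and `_teacher_hooks` are what `register_forward_hook` returns. A small sketch of the register and remove cycle, using a plain `nn.Linear` layer and a hypothetical `record` function in place of a real model and `FeatureHook`:

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 2)
calls = []  # records the output shape each time the hook fires


def record(module, inputs, output):
    """Hypothetical hook used only for this illustration."""
    calls.append(tuple(output.shape))


handle = layer.register_forward_hook(record)
layer(torch.zeros(1, 4))  # hook fires once
handle.remove()           # detach the hook through its handle
layer(torch.zeros(1, 4))  # hook no longer fires

print(calls)
```

Keeping the handles in a list, as the class does, makes removal a matter of calling `remove()` on each one and clearing the list.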

Method ultralytics.nn.distill_model.DistillationModel.criterion

def criterion(self, value) -> None

Set value for student criterion.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `value` | | | required |

Source code in ultralytics/nn/distill_model.py

@criterion.setter
def criterion(self, value) -> None:
    """Set value for student criterion."""
    self.student_model.criterion = value

Method ultralytics.nn.distill_model.DistillationModel.decouple_outputs

def decouple_outputs(self, preds, branch: str = "one2one")

Decouple outputs for teacher/student models.

This method handles different output formats from YOLO models, including tuple outputs (train/val mode), dict outputs with branches (one2one/one2many), and direct tensor outputs.

Args

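| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `preds` | `torch.Tensor \| tuple \| dict` | Model predictions in various formats. | required |
| `branch` | `str` | Which branch to extract from dict outputs ("one2one" or "one2many"). | `"one2one"` |

Returns

| Type | Description |
| --- | --- |
| `torch.Tensor \| dict` | The decoupled predictions. |

Source code in ultralytics/nn/distill_model.py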

def decouple_outputs(self, preds, branch: str = "one2one"):
    """Decouple outputs for teacher/student models.

    This method handles different output formats from YOLO models, including
    tuple outputs (train/val mode), dict outputs with branches (one2one/one2many),
    and direct tensor outputs.

    Args:
        preds (torch.Tensor | tuple | dict): Model predictions in various formats.
        branch (str): Which branch to extract from dict outputs ("one2one" or "one2many").

    Returns:
        (torch.Tensor | dict): The decoupled predictions.
    """
    if isinstance(preds, tuple):  # decouple for val mode
        preds = preds[1]
    if isinstance(preds, dict):
        if branch in preds:
            preds = preds[branch]
    return preds
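
The three input formats can be traced with placeholder values. This sketch copies the branching logic into a standalone function and uses plain strings and dicts instead of tensors; the key names other than `one2one` and `one2many` are made up for illustration.

```python
def decouple_outputs(preds, branch: str = "one2one"):
    """Standalone copy of the method's logic, for illustration only."""
    if isinstance(preds, tuple):  # tuple output: keep the second element
        preds = preds[1]
    if isinstance(preds, dict) and branch in preds:
        preds = preds[branch]
    return preds


# Placeholder head output with both branches (strings stand in for tensors).
head = {"one2one": {"scores": "o2o"}, "one2many": {"scores": "o2m"}}

from_tuple = decouple_outputs(("decoded", head))            # tuple, then default branch
from_dict = decouple_outputs(head, branch="one2many")       # explicit branch
no_branch = decouple_outputs({"boxes": "b", "scores": "s"})  # dict without branch keys
passthrough = decouple_outputs("neck tensor")                # anything else is returned as-is

print(from_tuple, from_dict, no_branch, passthrough)
```

A dict that does not contain the requested branch is returned unchanged, which is why the return type is documented as a tensor or a dict.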

Method ultralytics.nn.distill_model.DistillationModel.end2end

def end2end(self, value)

Forward end-to-end mode update to the student model.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `value` | | | required |

Source code in ultralytics/nn/distill_model.py

@end2end.setter
def end2end(self, value):
    """Forward end-to-end mode update to the student model."""
    self.student_model.end2end = value

Method ultralytics.nn.distill_model.DistillationModel.forward

def forward(self, x, *args, **kwargs)

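Forward pass through the student model. When x is a dict (a batch during training, or during validation while training), the call is routed to loss(); otherwise x is passed to student_model.predict().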

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | | | required |
| `*args` | | | required |
| `**kwargs` | | | required |

Source code in ultralytics/nn/distill_model.py

def forward(self, x, *args, **kwargs):
    """Forward pass through the student model."""
    if isinstance(x, dict):  # for cases of training and validating while training.
        return self.loss(x, *args, **kwargs)
    return self.student_model.predict(x, *args, **kwargs)

Method ultralytics.nn.distill_model.DistillationModel.get_distill_layers

def get_distill_layers(model: nn.Module) -> list[int]

Auto-detect distillation feature layers from the model's Detect head.

Returns the Detect head's input layer indices plus the head layer index itself. E.g. YOLO26 -> [16, 19, 22, 23], YOLOv8 -> [15, 18, 21, 22].

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `nn.Module` | | required |

Source code in ultralytics/nn/distill_model.py

@staticmethod
def get_distill_layers(model: nn.Module) -> list[int]:
    """Auto-detect distillation feature layers from the model's Detect head.

    Returns the Detect head's input layer indices plus the head layer index itself.
    E.g. YOLO26 -> [16, 19, 22, 23], YOLOv8 -> [15, 18, 21, 22].
    """
    for m in model.model:
        if isinstance(m, Detect):
            return [*list(m.f), m.i]
    raise ValueError("No Detect head found in model")
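
The YOLO26 example in the docstring can be reproduced with stand-in objects. In this sketch `Detect` and `Layer` are hypothetical dummies (the real `Detect` is an `nn.Module`), and the function takes the layer list directly rather than a model with a `.model` attribute.

```python
class Layer:
    """Hypothetical placeholder for any non-head layer."""


class Detect:
    """Hypothetical stand-in for the Detect head: f = input layer indices, i = own index."""

    def __init__(self, f: list, i: int) -> None:
        self.f = f
        self.i = i


def get_distill_layers(layers) -> list:
    """Standalone copy of the lookup logic, operating on a plain list of layers."""
    for m in layers:
        if isinstance(m, Detect):
            return [*list(m.f), m.i]
    raise ValueError("No Detect head found in model")


# 23 ordinary layers followed by a head at index 23 that reads from layers 16, 19 and 22.
layers = [Layer() for _ in range(23)] + [Detect(f=[16, 19, 22], i=23)]
indices = get_distill_layers(layers)
print(indices)
```

The last entry is always the head itself, which is why the constructor and `loss` slice `feats_idx[:-1]` when they need only the neck features.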

Method ultralytics.nn.distill_model.DistillationModel.init_criterion

def init_criterion(self)

Initialize the loss criterion via the student model.

Source code in ultralytics/nn/distill_model.py


def init_criterion(self):
    """Initialize the loss criterion via the student model."""
    return self.student_model.init_criterion()

Method ultralytics.nn.distill_model.DistillationModel.loss

def loss(self, batch, preds = None)

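Compute the combined detection and distillation loss. In eval mode only the student's regular loss is computed and a zero distillation term is appended. In train mode the teacher and student both run on batch["img"], and the score-weighted feature loss from each neck level is accumulated and scaled by dis.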

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `batch` | `dict` | Batch to compute loss on. | required |
| `preds` | `torch.Tensor \| list[torch.Tensor]`, optional | Predictions. | `None` |

Source code in ultralytics/nn/distill_model.py

def loss(self, batch, preds=None):
    """Compute loss.

    Args:
        batch (dict): Batch to compute loss on.
        preds (torch.Tensor | list[torch.Tensor], optional): Predictions.
    """
    loss_distill = torch.zeros(1, device=batch["img"].device)
    if not self.training:  # for loss calculation during validation while training
        if preds is None:
            preds = self.student_model(batch["img"])
        regular_loss, regular_loss_detach = self.student_model.loss(batch, preds)
        return torch.cat([regular_loss, loss_distill]), torch.cat([regular_loss_detach, loss_distill])

    # Clear feature dicts before forward passes
    self._teacher_feats.clear()
    self._student_feats.clear()

    with torch.no_grad():
        self.teacher_model(batch["img"])  # hooks capture teacher features
    preds = self.student_model(batch["img"])  # hooks capture student features

    regular_loss, regular_loss_detach = self.student_model.loss(batch, preds)
    teacher_head_feat = self._teacher_feats[self.feats_idx[-1]]
    teacher_scores = (
        self.decouple_outputs(teacher_head_feat, branch="one2many")["scores"]
        + self.decouple_outputs(teacher_head_feat, branch="one2one")["scores"]
    ) / 2
    # neck feature sizes vary per batch (e.g. multi_scale), so split scores by the live teacher feats
    neck_feats = [self._teacher_feats[idx] for idx in self.feats_idx[:-1]]
    parts = torch.split(teacher_scores, [f.shape[-2] * f.shape[-1] for f in neck_feats], dim=-1)
    teacher_scores = tuple(p.sigmoid().max(dim=1, keepdim=True).values for p in parts)
    for i, feat_idx in enumerate(self.feats_idx[:-1]):
        teacher_feat = self.decouple_outputs(self._teacher_feats[feat_idx])
        student_feat = self.projector[i](self.decouple_outputs(self._student_feats[feat_idx]))
        loss_distill += (
            self.loss_sl2(student_feat, teacher_feat, feat_idx=i, teacher_scores=teacher_scores) * self.dis
        )

    distill_loss_detach = loss_distill.detach()
    loss_distill = loss_distill * batch["img"].shape[0]
    return torch.cat([regular_loss, loss_distill]), torch.cat([regular_loss_detach, distill_loss_detach])
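
The score-splitting step in the training branch can be followed on random data. This is a sketch under stated assumptions: the scores are laid out as (N, num_classes, total_locations), which is what the split on the last dimension and the max over `dim=1` imply, and the batch size, class count, and feature-map sizes are hypothetical.

```python
import torch

n, num_classes = 2, 80                # hypothetical batch size and class count
sizes = [(8, 8), (4, 4), (2, 2)]      # hypothetical (H, W) of the three neck feature maps
locations = [h * w for h, w in sizes]  # [64, 16, 4] locations per level

# Stand-in for the averaged one2many/one2one teacher class logits.
teacher_scores = torch.randn(n, num_classes, sum(locations))

# Split the flat location axis back into one chunk per feature level ...
parts = torch.split(teacher_scores, locations, dim=-1)
# ... then keep the highest class probability at each location as its weight.
weights = tuple(p.sigmoid().max(dim=1, keepdim=True).values for p in parts)

shapes = [tuple(w.shape) for w in weights]
print(shapes)
```

Each weight tensor has shape (N, 1, H*W), so it broadcasts over the channel axis of a feature map flattened to (N, C, H*W).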

Method ultralytics.nn.distill_model.DistillationModel.loss_sl2

def loss_sl2(
    self, student_feat: torch.Tensor, teacher_feat: torch.Tensor, feat_idx: int, teacher_scores: tuple
) -> torch.Tensor

Compute score-weighted L2 distillation loss for a feature pair.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `student_feat` | `torch.Tensor` | Student feature tensor of shape (N, C, H, W). | required |
| `teacher_feat` | `torch.Tensor` | Teacher feature tensor of shape (N, C, H, W). | required |
| `feat_idx` | `int` | Index of the feature level for selecting teacher scores. | required |
| `teacher_scores` | `tuple` | Tuple of score tensors for each feature level. | required |

Returns

| Type | Description |
| --- | --- |
| `torch.Tensor` | The computed score-weighted L2 loss. |

Source code in ultralytics/nn/distill_model.py

def loss_sl2(
    self, student_feat: torch.Tensor, teacher_feat: torch.Tensor, feat_idx: int, teacher_scores: tuple
) -> torch.Tensor:
    """Compute score-weighted L2 distillation loss for a feature pair.

    Args:
        student_feat (torch.Tensor): Student feature tensor of shape (N, C, H, W).
        teacher_feat (torch.Tensor): Teacher feature tensor of shape (N, C, H, W).
        feat_idx (int): Index of the feature level for selecting teacher scores.
        teacher_scores (tuple): Tuple of score tensors for each feature level.

    Returns:
        (torch.Tensor): The computed score-weighted L2 loss.
    """
    teacher_score = teacher_scores[feat_idx]
    n, c = student_feat.shape[:2]
    student_feat = student_feat.view(n, c, -1)
    teacher_feat = teacher_feat.view(n, c, -1)
    mse = F.mse_loss(student_feat, teacher_feat, reduction="none")
    weighted_mse = (mse * teacher_score).sum() / (teacher_score.sum() * c + 1e-9)
    return weighted_mse
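
Two limiting cases help when reading the normalization. The sketch below reimplements the computation as a free function that takes a single score tensor (rather than a tuple plus index) and checks it on random data; the tensor sizes are arbitrary.

```python
import torch
import torch.nn.functional as F


def loss_sl2(student_feat, teacher_feat, teacher_score):
    """Standalone copy of the computation, taking the per-level score tensor directly."""
    n, c = student_feat.shape[:2]
    student_feat = student_feat.view(n, c, -1)
    teacher_feat = teacher_feat.view(n, c, -1)
    mse = F.mse_loss(student_feat, teacher_feat, reduction="none")
    return (mse * teacher_score).sum() / (teacher_score.sum() * c + 1e-9)


torch.manual_seed(0)
s = torch.randn(2, 4, 3, 3)  # student feature (N, C, H, W)
t = torch.randn(2, 4, 3, 3)  # teacher feature (N, C, H, W)

# Uniform weights: the result matches the ordinary mean squared error.
uniform = torch.ones(2, 1, 9)
loss_uniform = loss_sl2(s, t, uniform)
plain_mse = F.mse_loss(s, t)

# One-hot weights: only the first spatial location contributes.
one_hot = torch.zeros(2, 1, 9)
one_hot[:, :, 0] = 1.0
loss_one_hot = loss_sl2(s, t, one_hot)
first_location_mse = ((s - t) ** 2).view(2, 4, -1)[:, :, 0].mean()

print(loss_uniform.item(), plain_mse.item())
```

Dividing by the score sum times C keeps the loss on the scale of a per-element mean regardless of how many locations the teacher considers confident.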

Method ultralytics.nn.distill_model.DistillationModel.set_head_attr

def set_head_attr(self, **kwargs)

Forward head-attribute updates (e.g. max_det, agnostic_nms, end2end) to the student model.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `**kwargs` | | | required |

Source code in ultralytics/nn/distill_model.py

def set_head_attr(self, **kwargs):
    """Forward head-attribute updates (e.g. max_det, agnostic_nms, end2end) to the student model."""
    self.student_model.set_head_attr(**kwargs)

Method ultralytics.nn.distill_model.DistillationModel.train

def train(self, mode: bool = True)

Set model train mode while keeping teacher frozen in eval mode.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `mode` | `bool` | | `True` |

Source code in ultralytics/nn/distill_model.py

def train(self, mode: bool = True):
    """Set model train mode while keeping teacher frozen in eval mode."""
    super().train(mode)
    self._freeze_teacher()
    return self
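
The override matters because `nn.Module.train()` switches every submodule, the teacher included. A minimal sketch with a hypothetical `Wrapper` class and toy submodules shows the effect of re-freezing afterwards:

```python
import torch.nn as nn


class Wrapper(nn.Module):
    """Hypothetical two-model wrapper mirroring the train()/_freeze_teacher() pattern."""

    def __init__(self) -> None:
        super().__init__()
        self.teacher = nn.Sequential(nn.Linear(4, 4), nn.Dropout(0.5))
        self.student = nn.Sequential(nn.Linear(4, 4), nn.Dropout(0.5))
        self._freeze_teacher()

    def _freeze_teacher(self) -> None:
        self.teacher.eval()
        for p in self.teacher.parameters():
            p.requires_grad = False

    def train(self, mode: bool = True):
        super().train(mode)      # would put the teacher in train mode too
        self._freeze_teacher()   # so switch it straight back to eval
        return self


model = Wrapper().train()
print(model.student.training, model.teacher.training)
```

Without the re-freeze, layers such as dropout or batch normalization inside the teacher would behave differently during training and change the features being distilled.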


