Link to this sectionReference for ultralytics/nn/distill_model.py#
This page is sourced from https://github.com/ultralytics/ultralytics/blob/main/ultralytics/nn/distill_model.py. Have an improvement or example to add? Open a Pull Request — thank you! 🙏
Link to this section ultralytics.nn.distill_model.FeatureHook#
FeatureHook(self, feat_dict: dict, idx: int) -> NonePicklable forward hook that stores layer output into a shared dict.
Args
| Name | Type | Description | Default |
|---|---|---|---|
feat_dict | dict | required | |
idx | int | required |
Methods
| Name | Description |
|---|---|
__call__ | Store the layer's forward output into the shared feature dict under its index. |
Source code in ultralytics/nn/distill_model.py
class FeatureHook:
"""Picklable forward hook that stores layer output into a shared dict."""
def __init__(self, feat_dict: dict, idx: int) -> None:
"""Initialize the hook with the shared feature dict and the layer index to store outputs under."""
self.feat_dict = feat_dict
self.idx = idxLink to this section ultralytics.nn.distill_model.FeatureHook.__call__#
def __call__(self, module: nn.Module, inputs: tuple, output) -> NoneStore the layer's forward output into the shared feature dict under its index.
The output is a tensor for neck layers but a tuple/dict for the Detect head, so it is left untyped.
Args
| Name | Type | Description | Default |
|---|---|---|---|
module | nn.Module | required | |
inputs | tuple | required | |
output | required |
Source code in ultralytics/nn/distill_model.py
def __call__(self, module: nn.Module, inputs: tuple, output) -> None:
"""Store the layer's forward output into the shared feature dict under its index.
The output is a tensor for neck layers but a tuple/dict for the Detect head, so it is left untyped.
"""
self.feat_dict[self.idx] = outputLink to this section ultralytics.nn.distill_model.DistillationModel#
DistillationModel(self, teacher_model: str | Path | nn.Module, student_model: nn.Module)Bases: nn.Module
YOLO knowledge distillation model.
This class wraps a teacher-student pair for knowledge distillation training. Features are extracted from both models via forward hooks for distillation.
Args
| Name | Type | Description | Default |
|---|---|---|---|
teacher_model | `str | Path | nn.Module` |
student_model | nn.Module | Student model module to be trained. | required |
Attributes
| Name | Type | Description |
|---|---|---|
teacher_model | nn.Module | Frozen teacher model providing features. |
student_model | nn.Module | Trainable student model being distilled. |
feats_idx | list | Layer indices for feature extraction. |
projector | nn.ModuleList | MLP projector aligning student features to teacher dimensions. |
dis | float | Distillation loss weight factor. |
Methods
| Name | Description |
|---|---|
criterion | Get the criterion from the student model. |
end2end | Expose student end-to-end mode for validator/predictor control. |
__getstate__ | Return a copy of state for pickling without captured features or hook handles. |
__setstate__ | Clear stale features and hooks, and re-register forward hooks after unpickling. |
_clear_feature_hooks | Remove any FeatureHook instances from a module's forward hooks. |
_freeze_teacher | Keep teacher fixed for distillation. |
_register_feature_hooks | Register feature-capture hooks, removing stale FeatureHook instances first. |
_remove_feature_hooks | Remove any previously registered feature-capture hooks. |
criterion | Set value for student criterion. |
decouple_outputs | Decouple outputs for teacher/student models. |
end2end | Forward end-to-end mode update to the student model. |
forward | Forward pass through the student model. |
get_distill_layers | Auto-detect distillation feature layers from the model's Detect head. |
init_criterion | Initialize the loss criterion via the student model. |
loss | Compute loss. |
loss_sl2 | Compute score-weighted L2 distillation loss for a feature pair. |
set_head_attr | Forward head-attribute updates (e.g. max_det, agnostic_nms, end2end) to the student model. |
train | Set model train mode while keeping teacher frozen in eval mode. |
Examples
Train a student model with knowledge distillation from a larger teacher (the trainer builds the
DistillationModel internally when the ``distill_model`` argument is set)
>>> from ultralytics import YOLO
>>> model = YOLO("yolo26n.pt")
>>> model.train(data="coco8.yaml", distill_model="yolo26s.pt")Source code in ultralytics/nn/distill_model.py
class DistillationModel(nn.Module):
"""YOLO knowledge distillation model.
This class wraps a teacher-student pair for knowledge distillation training. Features are extracted from both models
via forward hooks for distillation.
Attributes:
teacher_model (nn.Module): Frozen teacher model providing features.
student_model (nn.Module): Trainable student model being distilled.
feats_idx (list): Layer indices for feature extraction.
projector (nn.ModuleList): MLP projector aligning student features to teacher dimensions.
dis (float): Distillation loss weight factor.
Methods:
get_distill_layers: Auto-detect distillation feature layers from the Detect head.
forward: Run the student model, or compute the combined loss when given a training batch.
loss: Compute combined detection and distillation loss.
loss_sl2: Compute score-weighted L2 distillation loss for a feature pair.
decouple_outputs: Normalize teacher/student head outputs across train/val formats.
train: Set training mode while keeping teacher frozen.
Examples:
Train a student model with knowledge distillation from a larger teacher (the trainer builds the
DistillationModel internally when the ``distill_model`` argument is set)
>>> from ultralytics import YOLO
>>> model = YOLO("yolo26n.pt")
>>> model.train(data="coco8.yaml", distill_model="yolo26s.pt")
"""
def __init__(self, teacher_model: str | Path | nn.Module, student_model: nn.Module):
"""Initialize the distillation model with teacher, student, and feature extraction hooks.
Args:
teacher_model (str | Path | nn.Module): Teacher model checkpoint path or module.
student_model (nn.Module): Student model module to be trained.
"""
super().__init__()
if isinstance(teacher_model, (str, Path)):
teacher_model = load_checkpoint(teacher_model)[0]
device = next(student_model.parameters()).device
self.teacher_model = teacher_model.to(device)
self._freeze_teacher()
self.student_model = student_model
self.feats_idx = self.get_distill_layers(student_model)
# Hook-based feature capture: identical for teacher and student
self._teacher_feats: dict[int, torch.Tensor] = {}
self._student_feats: dict[int, torch.Tensor] = {}
self._teacher_hooks: list = []
self._student_hooks: list = []
self._register_feature_hooks()
# Get feature dimensions via dummy forward pass (hooks capture outputs)
imgsz = student_model.args.imgsz
student_model.eval()
with torch.no_grad():
teacher_model(torch.zeros(2, 3, imgsz, imgsz).to(device))
student_model(torch.zeros(2, 3, imgsz, imgsz).to(device))
student_model.train()
teacher_output = [self._teacher_feats[idx] for idx in self.feats_idx]
student_output = [self._student_feats[idx] for idx in self.feats_idx]
copy_attr(self, student_model)
self.dis = self.student_model.args.dis
projectors = []
for student_out, teacher_out in zip(student_output[:-1], teacher_output[:-1]):
student_dim = self.decouple_outputs(student_out).shape[1]
teacher_dim = self.decouple_outputs(teacher_out).shape[1]
projectors.append(
nn.Sequential(
nn.Conv2d(student_dim, teacher_dim, kernel_size=1, stride=1, padding=0),
nn.ReLU(inplace=True),
nn.Conv2d(teacher_dim, teacher_dim, kernel_size=1, stride=1, padding=0),
)
)
self.projector = nn.ModuleList(projectors).to(device)Link to this section ultralytics.nn.distill_model.DistillationModel.criterion#
def criterion(self)Get the criterion from the student model.
Source code in ultralytics/nn/distill_model.py
@property
def criterion(self):
"""Get the criterion from the student model."""
return self.student_model.criterionLink to this section ultralytics.nn.distill_model.DistillationModel.end2end#
def end2end(self)Expose student end-to-end mode for validator/predictor control.
Source code in ultralytics/nn/distill_model.py
@property
def end2end(self):
"""Expose student end-to-end mode for validator/predictor control."""
return getattr(self.student_model, "end2end", False)Link to this section ultralytics.nn.distill_model.DistillationModel.__getstate__#
def __getstate__(self)Return a copy of state for pickling without captured features or hook handles.
Clears the feature dicts in place (rather than replacing the attributes) because the registered FeatureHooks share these exact dict objects; otherwise a deepcopy/pickle of a mid-training model would still reach the hook-held tensors (which carry grad_fn and cannot be deep-copied).
Source code in ultralytics/nn/distill_model.py
def __getstate__(self):
"""Return a copy of state for pickling without captured features or hook handles.
Clears the feature dicts in place (rather than replacing the attributes) because the registered
FeatureHooks share these exact dict objects; otherwise a deepcopy/pickle of a mid-training model would
still reach the hook-held tensors (which carry grad_fn and cannot be deep-copied).
"""
self._teacher_feats.clear()
self._student_feats.clear()
state = self.__dict__.copy()
state["_teacher_hooks"] = []
state["_student_hooks"] = []
return stateLink to this section ultralytics.nn.distill_model.DistillationModel.__setstate__#
def __setstate__(self, state)Clear stale features and hooks, and re-register forward hooks after unpickling.
Args
| Name | Type | Description | Default |
|---|---|---|---|
state | required |
Source code in ultralytics/nn/distill_model.py
def __setstate__(self, state):
"""Clear stale features and hooks, and re-register forward hooks after unpickling."""
self.__dict__.update(state)
self._teacher_feats = {}
self._student_feats = {}
self._register_feature_hooks()Link to this section ultralytics.nn.distill_model.DistillationModel._clear_feature_hooks#
def _clear_feature_hooks(module: nn.Module) -> NoneRemove any FeatureHook instances from a module's forward hooks.
Args
| Name | Type | Description | Default |
|---|---|---|---|
module | nn.Module | required |
Source code in ultralytics/nn/distill_model.py
@staticmethod
def _clear_feature_hooks(module: nn.Module) -> None:
"""Remove any FeatureHook instances from a module's forward hooks."""
for handle_id, hook in list(module._forward_hooks.items()):
if isinstance(hook, FeatureHook):
del module._forward_hooks[handle_id]Link to this section ultralytics.nn.distill_model.DistillationModel._freeze_teacher#
def _freeze_teacher(self)Keep teacher fixed for distillation.
Source code in ultralytics/nn/distill_model.py
def _freeze_teacher(self):
"""Keep teacher fixed for distillation."""
if self.teacher_model is None:
return
self.teacher_model.eval()
for v in self.teacher_model.parameters():
if v.requires_grad:
v.requires_grad = FalseLink to this section ultralytics.nn.distill_model.DistillationModel._register_feature_hooks#
def _register_feature_hooks(self) -> NoneRegister feature-capture hooks, removing stale FeatureHook instances first.
Source code in ultralytics/nn/distill_model.py
def _register_feature_hooks(self) -> None:
"""Register feature-capture hooks, removing stale FeatureHook instances first."""
self._remove_feature_hooks()
for idx in self.feats_idx:
self._clear_feature_hooks(self.student_model.model[idx])
self._student_hooks.append(
self.student_model.model[idx].register_forward_hook(FeatureHook(self._student_feats, idx))
)
if self.teacher_model is not None:
self._clear_feature_hooks(self.teacher_model.model[idx])
self._teacher_hooks.append(
self.teacher_model.model[idx].register_forward_hook(FeatureHook(self._teacher_feats, idx))
)Link to this section ultralytics.nn.distill_model.DistillationModel._remove_feature_hooks#
def _remove_feature_hooks(self) -> NoneRemove any previously registered feature-capture hooks.
Source code in ultralytics/nn/distill_model.py
def _remove_feature_hooks(self) -> None:
"""Remove any previously registered feature-capture hooks."""
for handle in self._student_hooks:
handle.remove()
self._student_hooks.clear()
if self.teacher_model is not None:
for handle in self._teacher_hooks:
handle.remove()
self._teacher_hooks.clear()Link to this section ultralytics.nn.distill_model.DistillationModel.criterion#
def criterion(self, value) -> NoneSet value for student criterion.
Args
| Name | Type | Description | Default |
|---|---|---|---|
value | required |
Source code in ultralytics/nn/distill_model.py
@criterion.setter
def criterion(self, value) -> None:
"""Set value for student criterion."""
self.student_model.criterion = valueLink to this section ultralytics.nn.distill_model.DistillationModel.decouple_outputs#
def decouple_outputs(self, preds, branch: str = "one2one")Decouple outputs for teacher/student models.
This method handles different output formats from YOLO models, including tuple outputs (train/val mode), dict outputs with branches (one2one/one2many), and direct tensor outputs.
Args
| Name | Type | Description | Default |
|---|---|---|---|
preds | `torch.Tensor | tuple | dict` |
branch | str | Which branch to extract from dict outputs ("one2one" or "one2many"). | "one2one" |
Returns
| Type | Description |
|---|---|
| `torch.Tensor | dict` |
Source code in ultralytics/nn/distill_model.py
def decouple_outputs(self, preds, branch: str = "one2one"):
"""Decouple outputs for teacher/student models.
This method handles different output formats from YOLO models, including
tuple outputs (train/val mode), dict outputs with branches (one2one/one2many),
and direct tensor outputs.
Args:
preds (torch.Tensor | tuple | dict): Model predictions in various formats.
branch (str): Which branch to extract from dict outputs ("one2one" or "one2many").
Returns:
(torch.Tensor | dict): The decoupled predictions.
"""
if isinstance(preds, tuple): # decouple for val mode
preds = preds[1]
if isinstance(preds, dict):
if branch in preds:
preds = preds[branch]
return predsLink to this section ultralytics.nn.distill_model.DistillationModel.end2end#
def end2end(self, value)Forward end-to-end mode update to the student model.
Args
| Name | Type | Description | Default |
|---|---|---|---|
value | required |
Source code in ultralytics/nn/distill_model.py
@end2end.setter
def end2end(self, value):
"""Forward end-to-end mode update to the student model."""
self.student_model.end2end = valueLink to this section ultralytics.nn.distill_model.DistillationModel.forward#
def forward(self, x, *args, **kwargs)Forward pass through the student model.
Args
| Name | Type | Description | Default |
|---|---|---|---|
x | required | ||
*args | required | ||
**kwargs | required |
Source code in ultralytics/nn/distill_model.py
def forward(self, x, *args, **kwargs):
"""Forward pass through the student model."""
if isinstance(x, dict): # for cases of training and validating while training.
return self.loss(x, *args, **kwargs)
return self.student_model.predict(x, *args, **kwargs)Link to this section ultralytics.nn.distill_model.DistillationModel.get_distill_layers#
def get_distill_layers(model: nn.Module) -> list[int]Auto-detect distillation feature layers from the model's Detect head.
Returns the Detect head's input layer indices plus the head layer index itself. E.g. YOLO26 -> [16, 19, 22, 23], YOLOv8 -> [15, 18, 21, 22].
Args
| Name | Type | Description | Default |
|---|---|---|---|
model | nn.Module | required |
Source code in ultralytics/nn/distill_model.py
@staticmethod
def get_distill_layers(model: nn.Module) -> list[int]:
"""Auto-detect distillation feature layers from the model's Detect head.
Returns the Detect head's input layer indices plus the head layer index itself.
E.g. YOLO26 -> [16, 19, 22, 23], YOLOv8 -> [15, 18, 21, 22].
"""
for m in model.model:
if isinstance(m, Detect):
return [*list(m.f), m.i]
raise ValueError("No Detect head found in model")Link to this section ultralytics.nn.distill_model.DistillationModel.init_criterion#
def init_criterion(self)Initialize the loss criterion via the student model.
Source code in ultralytics/nn/distill_model.py
def init_criterion(self):
"""Initialize the loss criterion via the student model."""
return self.student_model.init_criterion()Link to this section ultralytics.nn.distill_model.DistillationModel.loss#
def loss(self, batch, preds = None)Compute loss.
Args
| Name | Type | Description | Default |
|---|---|---|---|
batch | dict | Batch to compute loss on. | required |
preds | `torch.Tensor | list[torch.Tensor], optional` | Predictions. |
Source code in ultralytics/nn/distill_model.py
def loss(self, batch, preds=None):
"""Compute loss.
Args:
batch (dict): Batch to compute loss on.
preds (torch.Tensor | list[torch.Tensor], optional): Predictions.
"""
loss_distill = torch.zeros(1, device=batch["img"].device)
if not self.training: # for loss calculation during validation while training
if preds is None:
preds = self.student_model(batch["img"])
regular_loss, regular_loss_detach = self.student_model.loss(batch, preds)
return torch.cat([regular_loss, loss_distill]), torch.cat([regular_loss_detach, loss_distill])
# Clear feature dicts before forward passes
self._teacher_feats.clear()
self._student_feats.clear()
with torch.no_grad():
self.teacher_model(batch["img"]) # hooks capture teacher features
preds = self.student_model(batch["img"]) # hooks capture student features
regular_loss, regular_loss_detach = self.student_model.loss(batch, preds)
teacher_head_feat = self._teacher_feats[self.feats_idx[-1]]
teacher_scores = (
self.decouple_outputs(teacher_head_feat, branch="one2many")["scores"]
+ self.decouple_outputs(teacher_head_feat, branch="one2one")["scores"]
) / 2
# neck feature sizes vary per batch (e.g. multi_scale), so split scores by the live teacher feats
neck_feats = [self._teacher_feats[idx] for idx in self.feats_idx[:-1]]
parts = torch.split(teacher_scores, [f.shape[-2] * f.shape[-1] for f in neck_feats], dim=-1)
teacher_scores = tuple(p.sigmoid().max(dim=1, keepdim=True).values for p in parts)
for i, feat_idx in enumerate(self.feats_idx[:-1]):
teacher_feat = self.decouple_outputs(self._teacher_feats[feat_idx])
student_feat = self.projector[i](self.decouple_outputs(self._student_feats[feat_idx]))
loss_distill += (
self.loss_sl2(student_feat, teacher_feat, feat_idx=i, teacher_scores=teacher_scores) * self.dis
)
distill_loss_detach = loss_distill.detach()
loss_distill = loss_distill * batch["img"].shape[0]
return torch.cat([regular_loss, loss_distill]), torch.cat([regular_loss_detach, distill_loss_detach])Link to this section ultralytics.nn.distill_model.DistillationModel.loss_sl2#
def loss_sl2(
self, student_feat: torch.Tensor, teacher_feat: torch.Tensor, feat_idx: int, teacher_scores: tuple
) -> torch.TensorCompute score-weighted L2 distillation loss for a feature pair.
Args
| Name | Type | Description | Default |
|---|---|---|---|
student_feat | torch.Tensor | Student feature tensor of shape (N, C, H, W). | required |
teacher_feat | torch.Tensor | Teacher feature tensor of shape (N, C, H, W). | required |
feat_idx | int | Index of the feature level for selecting teacher scores. | required |
teacher_scores | tuple | Tuple of score tensors for each feature level. | required |
Returns
| Type | Description |
|---|---|
torch.Tensor | The computed score-weighted L2 loss. |
Source code in ultralytics/nn/distill_model.py
def loss_sl2(
self, student_feat: torch.Tensor, teacher_feat: torch.Tensor, feat_idx: int, teacher_scores: tuple
) -> torch.Tensor:
"""Compute score-weighted L2 distillation loss for a feature pair.
Args:
student_feat (torch.Tensor): Student feature tensor of shape (N, C, H, W).
teacher_feat (torch.Tensor): Teacher feature tensor of shape (N, C, H, W).
feat_idx (int): Index of the feature level for selecting teacher scores.
teacher_scores (tuple): Tuple of score tensors for each feature level.
Returns:
(torch.Tensor): The computed score-weighted L2 loss.
"""
teacher_score = teacher_scores[feat_idx]
n, c = student_feat.shape[:2]
student_feat = student_feat.view(n, c, -1)
teacher_feat = teacher_feat.view(n, c, -1)
mse = F.mse_loss(student_feat, teacher_feat, reduction="none")
weighted_mse = (mse * teacher_score).sum() / (teacher_score.sum() * c + 1e-9)
return weighted_mseLink to this section ultralytics.nn.distill_model.DistillationModel.set_head_attr#
def set_head_attr(self, **kwargs)Forward head-attribute updates (e.g. max_det, agnostic_nms, end2end) to the student model.
Args
| Name | Type | Description | Default |
|---|---|---|---|
**kwargs | required |
Source code in ultralytics/nn/distill_model.py
def set_head_attr(self, **kwargs):
"""Forward head-attribute updates (e.g. max_det, agnostic_nms, end2end) to the student model."""
self.student_model.set_head_attr(**kwargs)Link to this section ultralytics.nn.distill_model.DistillationModel.train#
def train(self, mode: bool = True)Set model train mode while keeping teacher frozen in eval mode.
Args
| Name | Type | Description | Default |
|---|---|---|---|
mode | bool | True |
Source code in ultralytics/nn/distill_model.py
def train(self, mode: bool = True):
"""Set model train mode while keeping teacher frozen in eval mode."""
super().train(mode)
self._freeze_teacher()
return self