
Transfer Learning with Frozen Layers in YOLOv5

📚 This guide explains how to freeze YOLOv5 🚀 layers when implementing transfer learning. Transfer learning is a powerful technique that lets you quickly retrain a model on new data without retraining the entire network. By freezing the initial layers and updating only the remaining ones, you can significantly reduce computational resources and training time, though this approach may slightly reduce final model accuracy.

Before You Start

Clone the repo and install requirements.txt in a Python>=3.8.0 environment, including PyTorch>=1.8. Models and datasets download automatically from the latest YOLOv5 release.

git clone https://github.com/ultralytics/yolov5 # clone
cd yolov5
pip install -r requirements.txt # install

How Layer Freezing Works

When you freeze layers in a neural network, you mark their parameters as non-trainable. With requires_grad set to False, gradients are no longer computed for these layers, so their weights are never updated during backpropagation. YOLOv5's training script implements this as follows:

# Freeze
freeze = [f"model.{x}." for x in range(freeze)]  # parameter name prefixes to freeze, e.g. --freeze 10 -> model.0. through model.9.
for k, v in model.named_parameters():
    v.requires_grad = True  # train all layers
    if any(x in k for x in freeze):
        print(f"freezing {k}")
        v.requires_grad = False
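
To sanity-check that the freeze took effect, you can count trainable versus frozen parameters after running the loop above. The snippet below is a minimal, self-contained sketch; the torch.hub load and the num_freeze value are illustrative assumptions, not part of train.py:

import torch

# Load a pretrained YOLOv5 model (assumes internet access for the PyTorch Hub download).
# Note: hub-loaded checkpoints may wrap the network, so parameter names look like
# "model.model.0.conv.weight"; the substring match below still works.
model = torch.hub.load("ultralytics/yolov5", "yolov5s", autoshape=False)

num_freeze = 10  # illustrative value, equivalent to --freeze 10
freeze = [f"model.{x}." for x in range(num_freeze)]  # prefixes model.0. through model.9.

for k, v in model.named_parameters():
    v.requires_grad = not any(x in k for x in freeze)  # freeze matching layers, train the rest

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
print(f"trainable params: {trainable:,}  frozen params: {frozen:,}")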

Exploring Model Architecture

To effectively freeze specific parts of the model, it's helpful to understand the layer structure. You can view all parameter names with:

for k, v in model.named_parameters():
    print(k)

"""Output:
model.0.conv.conv.weight
model.0.conv.bn.weight
model.0.conv.bn.bias
model.1.conv.weight
model.1.bn.weight
model.1.bn.bias
model.2.cv1.conv.weight
model.2.cv1.bn.weight
...
"""

The YOLOv5 architecture consists of a backbone (layers 0-9) and a head (layers 10-24, ending in the Detect module):

# YOLOv5 v6.0 backbone
backbone:
    # [from, number, module, args]
    - [-1, 1, Conv, [64, 6, 2, 2]] # 0-P1/2
    - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
    - [-1, 3, C3, [128]]
    - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
    - [-1, 6, C3, [256]]
    - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
    - [-1, 9, C3, [512]]
    - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
    - [-1, 3, C3, [1024]]
    - [-1, 1, SPPF, [1024, 5]] # 9

# YOLOv5 v6.0 head
head:
    - [-1, 1, Conv, [512, 1, 1]]
    - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
    - [[-1, 6], 1, Concat, [1]] # cat backbone P4
    - [-1, 3, C3, [512, False]] # 13
    # ... remaining head layers
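
If you prefer to read the layer counts straight from the model configuration, the backbone length in the YAML tells you which --freeze value freezes exactly the backbone. A small sketch, assuming it runs from the cloned yolov5 directory (models/yolov5s.yaml ships with the repo):

import yaml

# Count backbone and head entries in the model definition
with open("models/yolov5s.yaml") as f:
    cfg = yaml.safe_load(f)

n_backbone = len(cfg["backbone"])  # 10 -> --freeze 10 freezes the whole backbone
n_total = n_backbone + len(cfg["head"])  # 25 -> --freeze 24 leaves only the Detect layer trainable
print(f"backbone modules: {n_backbone}, total modules: {n_total}")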

Freezing Options

Freeze Backbone Only

To freeze only the backbone (layers 0-9), which is useful for adapting the model to new classes while retaining learned feature extraction capabilities:

python train.py --freeze 10

This approach is particularly effective when your new dataset shares similar low-level features with the original training data but has different classes or objects.

Freeze All Except Detection Layers

To freeze the entire model except for the final output convolution layers in the Detect module:

python train.py --freeze 24

This approach is ideal when you want to maintain most of the model's learned features but need to adapt it to detect a different number of classes.
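
To see exactly which module prefixes each --freeze value targets, you can reproduce the list comprehension from the training code in isolation; the following lines are purely illustrative:

# Which parameter-name prefixes each --freeze value matches
for n in (10, 24):
    prefixes = [f"model.{x}." for x in range(n)]
    print(f"--freeze {n}: {prefixes[0]} ... {prefixes[-1]} ({len(prefixes)} modules frozen)")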

Performance Comparison

We trained YOLOv5m on the VOC dataset using different freezing strategies, starting from the official COCO pretrained weights:

python train.py --batch 48 --weights yolov5m.pt --data voc.yaml --epochs 50 --cache --img 512 --hyp hyp.finetune.yaml

Accuracy Results

The results demonstrate that freezing layers accelerates training but slightly reduces final accuracy:

(Figures: mAP50 and mAP50-95 training curves, plus a results table, comparing the freezing strategies.)

Resource Utilization

Freezing more layers reduces GPU memory requirements and utilization, making this technique valuable for training larger models or using higher resolution images:

(Figures: GPU memory allocated and GPU memory utilization during training for each freezing strategy.)

When to Use Layer Freezing

Layer freezing in transfer learning is particularly beneficial in scenarios such as:

  1. Limited computational resources: When GPU memory or processing power is constrained
  2. Small datasets: When your new dataset is too small to train a full model without overfitting
  3. Quick adaptation: When you need to rapidly adapt a model to a new domain
  4. Fine-tuning for specific tasks: When adapting a general model to a specialized application

For more information on transfer learning techniques and their applications, see the transfer learning glossary entry.

Supported Environments

Ultralytics provides a range of ready-to-use environments, each pre-installed with essential dependencies such as CUDA, CUDNN, Python, and PyTorch, to kickstart your projects.

Project Status

YOLOv5 CI

This badge indicates that all YOLOv5 GitHub Actions Continuous Integration (CI) tests are successfully passing. These CI tests rigorously check the functionality and performance of YOLOv5 across various key aspects: training, validation, inference, export, and benchmarks. They ensure consistent and reliable operation on macOS, Windows, and Ubuntu, with tests conducted every 24 hours and upon each new commit.
