Reference for ultralytics/nn/modules/conv.py
Improvements
This page is sourced from https://github.com/ultralytics/ultralytics/blob/main/ultralytics/nn/modules/conv.py. Have an improvement or example to add? Open a Pull Request — thank you! 🙏
Summary
Conv.forward, Conv.forward_fuse, Conv2.forward, Conv2.forward_fuse, Conv2.fuse_convs, LightConv.forward, ConvTranspose.forward, ConvTranspose.forward_fuse, Focus.forward, GhostConv.forward, RepConv.forward_fuse, RepConv.forward, RepConv.get_equivalent_kernel_bias, RepConv._pad_1x1_to_3x3_tensor, RepConv._fuse_bn_tensor, RepConv.fuse_convs, ChannelAttention.forward, SpatialAttention.forward, CBAM.forward, Concat.forward, Index.forward
class ultralytics.nn.modules.conv.Conv
Conv(self, c1, c2, k = 1, s = 1, p = None, g = 1, d = 1, act = True)
Bases: nn.Module
Standard convolution module with batch normalization and activation.
Args
| Name | Type | Description | Default |
|---|---|---|---|
c1 | int | Number of input channels. | required |
c2 | int | Number of output channels. | required |
k | int | Kernel size. | 1 |
s | int | Stride. | 1 |
p | int, optional | Padding. | None |
g | int | Groups. | 1 |
d | int | Dilation. | 1 |
act | bool | nn.Module | Activation function. | True |
Attributes
| Name | Type | Description |
|---|---|---|
conv | nn.Conv2d | Convolutional layer. |
bn | nn.BatchNorm2d | Batch normalization layer. |
act | nn.Module | Activation function layer. |
default_act | nn.Module | Default activation function (SiLU). |
Methods
| Name | Description |
|---|---|
forward | Apply convolution, batch normalization and activation to input tensor. |
forward_fuse | Apply convolution and activation without batch normalization. |
Source code in ultralytics/nn/modules/conv.py
class Conv(nn.Module):
"""Standard convolution module with batch normalization and activation.
Attributes:
conv (nn.Conv2d): Convolutional layer.
bn (nn.BatchNorm2d): Batch normalization layer.
act (nn.Module): Activation function layer.
default_act (nn.Module): Default activation function (SiLU).
"""
default_act = nn.SiLU() # default activation
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
"""Initialize Conv layer with given parameters.
Args:
c1 (int): Number of input channels.
c2 (int): Number of output channels.
k (int): Kernel size.
s (int): Stride.
p (int, optional): Padding.
g (int): Groups.
d (int): Dilation.
act (bool | nn.Module): Activation function.
"""
super().__init__()
self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
self.bn = nn.BatchNorm2d(c2)
self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
method ultralytics.nn.modules.conv.Conv.forward
def forward(self, x)
Apply convolution, batch normalization and activation to input tensor.
Args
| Name | Type | Description | Default |
|---|---|---|---|
x | torch.Tensor | Input tensor. | required |
Returns
| Type | Description |
|---|---|
torch.Tensor | Output tensor. |
Source code in ultralytics/nn/modules/conv.py
def forward(self, x):
"""Apply convolution, batch normalization and activation to input tensor.
Args:
x (torch.Tensor): Input tensor.
Returns:
(torch.Tensor): Output tensor.
"""
return self.act(self.bn(self.conv(x)))
method ultralytics.nn.modules.conv.Conv.forward_fuse
def forward_fuse(self, x)
Apply convolution and activation without batch normalization.
Args
| Name | Type | Description | Default |
|---|---|---|---|
x | torch.Tensor | Input tensor. | required |
Returns
| Type | Description |
|---|---|
torch.Tensor | Output tensor. |
Source code in ultralytics/nn/modules/conv.py
def forward_fuse(self, x):
"""Apply convolution and activation without batch normalization.
Args:
x (torch.Tensor): Input tensor.
Returns:
(torch.Tensor): Output tensor.
"""
return self.act(self.conv(x))
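A minimal usage sketch (not part of the source); it assumes `torch` and `ultralytics` are installed, and the tensor sizes are illustrative:

```python
import torch
from ultralytics.nn.modules.conv import Conv

conv = Conv(c1=3, c2=16, k=3, s=2)  # Conv2d -> BatchNorm2d -> SiLU, padding set by autopad
x = torch.randn(1, 3, 64, 64)       # (B, C, H, W)
print(conv(x).shape)                # torch.Size([1, 16, 32, 32]) due to stride 2
```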
class ultralytics.nn.modules.conv.Conv2
Conv2(self, c1, c2, k = 3, s = 1, p = None, g = 1, d = 1, act = True)
Bases: Conv
Simplified RepConv module with Conv fusing.
Args
| Name | Type | Description | Default |
|---|---|---|---|
c1 | int | Number of input channels. | required |
c2 | int | Number of output channels. | required |
k | int | Kernel size. | 3 |
s | int | Stride. | 1 |
p | int, optional | Padding. | None |
g | int | Groups. | 1 |
d | int | Dilation. | 1 |
act | bool | nn.Module | Activation function. | True |
Attributes
| Name | Type | Description |
|---|---|---|
conv | nn.Conv2d | Main 3x3 convolutional layer. |
cv2 | nn.Conv2d | Additional 1x1 convolutional layer. |
bn | nn.BatchNorm2d | Batch normalization layer. |
act | nn.Module | Activation function layer. |
Methods
| Name | Description |
|---|---|
forward | Apply convolution, batch normalization and activation to input tensor. |
forward_fuse | Apply fused convolution, batch normalization and activation to input tensor. |
fuse_convs | Fuse parallel convolutions. |
Source code in ultralytics/nn/modules/conv.py
class Conv2(Conv):
"""Simplified RepConv module with Conv fusing.
Attributes:
conv (nn.Conv2d): Main 3x3 convolutional layer.
cv2 (nn.Conv2d): Additional 1x1 convolutional layer.
bn (nn.BatchNorm2d): Batch normalization layer.
act (nn.Module): Activation function layer.
"""
def __init__(self, c1, c2, k=3, s=1, p=None, g=1, d=1, act=True):
"""Initialize Conv2 layer with given parameters.
Args:
c1 (int): Number of input channels.
c2 (int): Number of output channels.
k (int): Kernel size.
s (int): Stride.
p (int, optional): Padding.
g (int): Groups.
d (int): Dilation.
act (bool | nn.Module): Activation function.
"""
super().__init__(c1, c2, k, s, p, g=g, d=d, act=act)
self.cv2 = nn.Conv2d(c1, c2, 1, s, autopad(1, p, d), groups=g, dilation=d, bias=False) # add 1x1 conv
method ultralytics.nn.modules.conv.Conv2.forward
def forward(self, x)
Apply convolution, batch normalization and activation to input tensor.
Args
| Name | Type | Description | Default |
|---|---|---|---|
x | torch.Tensor | Input tensor. | required |
Returns
| Type | Description |
|---|---|
torch.Tensor | Output tensor. |
Source code in ultralytics/nn/modules/conv.py
def forward(self, x):
"""Apply convolution, batch normalization and activation to input tensor.
Args:
x (torch.Tensor): Input tensor.
Returns:
(torch.Tensor): Output tensor.
"""
return self.act(self.bn(self.conv(x) + self.cv2(x)))
method ultralytics.nn.modules.conv.Conv2.forward_fuse
def forward_fuse(self, x)
Apply fused convolution, batch normalization and activation to input tensor.
Args
| Name | Type | Description | Default |
|---|---|---|---|
x | torch.Tensor | Input tensor. | required |
Returns
| Type | Description |
|---|---|
torch.Tensor | Output tensor. |
Source code in ultralytics/nn/modules/conv.py
def forward_fuse(self, x):
"""Apply fused convolution, batch normalization and activation to input tensor.
Args:
x (torch.Tensor): Input tensor.
Returns:
(torch.Tensor): Output tensor.
"""
return self.act(self.bn(self.conv(x)))
method ultralytics.nn.modules.conv.Conv2.fuse_convs
def fuse_convs(self)
Fuse parallel convolutions.
Source code in ultralytics/nn/modules/conv.py
def fuse_convs(self):
"""Fuse parallel convolutions."""
w = torch.zeros_like(self.conv.weight.data)
i = [x // 2 for x in w.shape[2:]]
w[:, :, i[0] : i[0] + 1, i[1] : i[1] + 1] = self.cv2.weight.data.clone()
self.conv.weight.data += w
self.__delattr__("cv2")
self.forward = self.forward_fuse
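A hedged sketch of the fusion workflow (tensor sizes are illustrative): after `fuse_convs()` the 1x1 branch is folded into the center of the 3x3 kernel, so the parallel and fused outputs should agree within floating-point tolerance.

```python
import torch
from ultralytics.nn.modules.conv import Conv2

m = Conv2(16, 16, k=3).eval()  # eval() so BatchNorm2d uses its running statistics
x = torch.randn(1, 16, 32, 32)
with torch.no_grad():
    y_parallel = m(x)          # act(bn(conv(x) + cv2(x)))
    m.fuse_convs()             # embed the 1x1 weights at the center of the 3x3 kernel
    y_fused = m(x)             # forward is now forward_fuse: act(bn(conv(x)))
print(torch.allclose(y_parallel, y_fused, atol=1e-5))  # expected: True
```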
class ultralytics.nn.modules.conv.LightConv
LightConv(self, c1, c2, k = 1, act = nn.ReLU())
Bases: nn.Module
Light convolution module with 1x1 and depthwise convolutions.
This implementation is based on the PaddleDetection HGNetV2 backbone.
Args
| Name | Type | Description | Default |
|---|---|---|---|
c1 | int | Number of input channels. | required |
c2 | int | Number of output channels. | required |
k | int | Kernel size for depthwise convolution. | 1 |
act | nn.Module | Activation function. | nn.ReLU() |
Attributes
| Name | Type | Description |
|---|---|---|
conv1 | Conv | 1x1 convolution layer. |
conv2 | DWConv | Depthwise convolution layer. |
Methods
| Name | Description |
|---|---|
forward | Apply 2 convolutions to input tensor. |
Source code in ultralytics/nn/modules/conv.py
class LightConv(nn.Module):
"""Light convolution module with 1x1 and depthwise convolutions.
This implementation is based on the PaddleDetection HGNetV2 backbone.
Attributes:
conv1 (Conv): 1x1 convolution layer.
conv2 (DWConv): Depthwise convolution layer.
"""
def __init__(self, c1, c2, k=1, act=nn.ReLU()):
"""Initialize LightConv layer with given parameters.
Args:
c1 (int): Number of input channels.
c2 (int): Number of output channels.
k (int): Kernel size for depthwise convolution.
act (nn.Module): Activation function.
"""
super().__init__()
self.conv1 = Conv(c1, c2, 1, act=False)
self.conv2 = DWConv(c2, c2, k, act=act)
method ultralytics.nn.modules.conv.LightConv.forward
def forward(self, x)
Apply 2 convolutions to input tensor.
Args
| Name | Type | Description | Default |
|---|---|---|---|
x | torch.Tensor | Input tensor. | required |
Returns
| Type | Description |
|---|---|
torch.Tensor | Output tensor. |
Source code in ultralytics/nn/modules/conv.py
def forward(self, x):
"""Apply 2 convolutions to input tensor.
Args:
x (torch.Tensor): Input tensor.
Returns:
(torch.Tensor): Output tensor.
"""
return self.conv2(self.conv1(x))
class ultralytics.nn.modules.conv.DWConv
DWConv(self, c1, c2, k = 1, s = 1, d = 1, act = True)
Bases: Conv
Depth-wise convolution module.
Args
| Name | Type | Description | Default |
|---|---|---|---|
c1 | int | Number of input channels. | required |
c2 | int | Number of output channels. | required |
k | int | Kernel size. | 1 |
s | int | Stride. | 1 |
d | int | Dilation. | 1 |
act | bool | nn.Module | Activation function. | True |
Source code in ultralytics/nn/modules/conv.py
class DWConv(Conv):
"""Depth-wise convolution module."""
def __init__(self, c1, c2, k=1, s=1, d=1, act=True):
"""Initialize depth-wise convolution with given parameters.
Args:
c1 (int): Number of input channels.
c2 (int): Number of output channels.
k (int): Kernel size.
s (int): Stride.
d (int): Dilation.
act (bool | nn.Module): Activation function.
"""
super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), d=d, act=act)
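A small illustrative example (not from the source): when `c1 == c2`, `groups = gcd(c1, c2)` makes this a true depthwise convolution with one filter per channel.

```python
import torch
from ultralytics.nn.modules.conv import DWConv

dw = DWConv(32, 32, k=3)        # groups = gcd(32, 32) = 32
x = torch.randn(1, 32, 40, 40)
print(dw(x).shape)              # torch.Size([1, 32, 40, 40])
print(dw.conv.groups)           # 32
```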
class ultralytics.nn.modules.conv.DWConvTranspose2d
DWConvTranspose2d(self, c1, c2, k = 1, s = 1, p1 = 0, p2 = 0)
Bases: nn.ConvTranspose2d
Depth-wise transpose convolution module.
Args
| Name | Type | Description | Default |
|---|---|---|---|
c1 | int | Number of input channels. | required |
c2 | int | Number of output channels. | required |
k | int | Kernel size. | 1 |
s | int | Stride. | 1 |
p1 | int | Padding. | 0 |
p2 | int | Output padding. | 0 |
Source code in ultralytics/nn/modules/conv.py
class DWConvTranspose2d(nn.ConvTranspose2d):
"""Depth-wise transpose convolution module."""
def __init__(self, c1, c2, k=1, s=1, p1=0, p2=0):
"""Initialize depth-wise transpose convolution with given parameters.
Args:
c1 (int): Number of input channels.
c2 (int): Number of output channels.
k (int): Kernel size.
s (int): Stride.
p1 (int): Padding.
p2 (int): Output padding.
"""
super().__init__(c1, c2, k, s, p1, p2, groups=math.gcd(c1, c2))
class ultralytics.nn.modules.conv.ConvTranspose
ConvTranspose(self, c1, c2, k = 2, s = 2, p = 0, bn = True, act = True)
Bases: nn.Module
Convolution transpose module with optional batch normalization and activation.
Args
| Name | Type | Description | Default |
|---|---|---|---|
c1 | int | Number of input channels. | required |
c2 | int | Number of output channels. | required |
k | int | Kernel size. | 2 |
s | int | Stride. | 2 |
p | int | Padding. | 0 |
bn | bool | Use batch normalization. | True |
act | bool | nn.Module | Activation function. | True |
Attributes
| Name | Type | Description |
|---|---|---|
conv_transpose | nn.ConvTranspose2d | Transposed convolution layer. |
bn | nn.BatchNorm2d | nn.Identity | Batch normalization layer. |
act | nn.Module | Activation function layer. |
default_act | nn.Module | Default activation function (SiLU). |
Methods
| Name | Description |
|---|---|
forward | Apply transposed convolution, batch normalization and activation to input. |
forward_fuse | Apply activation and convolution transpose operation to input. |
Source code in ultralytics/nn/modules/conv.py
class ConvTranspose(nn.Module):
"""Convolution transpose module with optional batch normalization and activation.
Attributes:
conv_transpose (nn.ConvTranspose2d): Transposed convolution layer.
bn (nn.BatchNorm2d | nn.Identity): Batch normalization layer.
act (nn.Module): Activation function layer.
default_act (nn.Module): Default activation function (SiLU).
"""
default_act = nn.SiLU() # default activation
def __init__(self, c1, c2, k=2, s=2, p=0, bn=True, act=True):
"""Initialize ConvTranspose layer with given parameters.
Args:
c1 (int): Number of input channels.
c2 (int): Number of output channels.
k (int): Kernel size.
s (int): Stride.
p (int): Padding.
bn (bool): Use batch normalization.
act (bool | nn.Module): Activation function.
"""
super().__init__()
self.conv_transpose = nn.ConvTranspose2d(c1, c2, k, s, p, bias=not bn)
self.bn = nn.BatchNorm2d(c2) if bn else nn.Identity()
self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
method ultralytics.nn.modules.conv.ConvTranspose.forward
def forward(self, x)
Apply transposed convolution, batch normalization and activation to input.
Args
| Name | Type | Description | Default |
|---|---|---|---|
x | torch.Tensor | Input tensor. | required |
Returns
| Type | Description |
|---|---|
torch.Tensor | Output tensor. |
Source code in ultralytics/nn/modules/conv.py
def forward(self, x):
"""Apply transposed convolution, batch normalization and activation to input.
Args:
x (torch.Tensor): Input tensor.
Returns:
(torch.Tensor): Output tensor.
"""
return self.act(self.bn(self.conv_transpose(x)))
method ultralytics.nn.modules.conv.ConvTranspose.forward_fuse
def forward_fuse(self, x)
Apply activation and convolution transpose operation to input.
Args
| Name | Type | Description | Default |
|---|---|---|---|
x | torch.Tensor | Input tensor. | required |
Returns
| Type | Description |
|---|---|
torch.Tensor | Output tensor. |
Source code in ultralytics/nn/modules/conv.py
def forward_fuse(self, x):
"""Apply activation and convolution transpose operation to input.
Args:
x (torch.Tensor): Input tensor.
Returns:
(torch.Tensor): Output tensor.
"""
return self.act(self.conv_transpose(x))
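An illustrative sketch (sizes assumed): with the default k=2, s=2 the module doubles the spatial resolution.

```python
import torch
from ultralytics.nn.modules.conv import ConvTranspose

up = ConvTranspose(64, 32)      # k=2, s=2: 2x spatial upsampling with BN and SiLU
x = torch.randn(1, 64, 20, 20)
print(up(x).shape)              # torch.Size([1, 32, 40, 40])
```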
class ultralytics.nn.modules.conv.Focus
Focus(self, c1, c2, k = 1, s = 1, p = None, g = 1, act = True)
Bases: nn.Module
Focus module for concentrating feature information.
Slices input tensor into 4 parts and concatenates them in the channel dimension.
Args
| Name | Type | Description | Default |
|---|---|---|---|
c1 | int | Number of input channels. | required |
c2 | int | Number of output channels. | required |
k | int | Kernel size. | 1 |
s | int | Stride. | 1 |
p | int, optional | Padding. | None |
g | int | Groups. | 1 |
act | bool | nn.Module | Activation function. | True |
Attributes
| Name | Type | Description |
|---|---|---|
conv | Conv | Convolution layer. |
Methods
| Name | Description |
|---|---|
forward | Apply Focus operation and convolution to input tensor. |
Source code in ultralytics/nn/modules/conv.py
class Focus(nn.Module):
"""Focus module for concentrating feature information.
Slices input tensor into 4 parts and concatenates them in the channel dimension.
Attributes:
conv (Conv): Convolution layer.
"""
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):
"""Initialize Focus module with given parameters.
Args:
c1 (int): Number of input channels.
c2 (int): Number of output channels.
k (int): Kernel size.
s (int): Stride.
p (int, optional): Padding.
g (int): Groups.
act (bool | nn.Module): Activation function.
"""
super().__init__()
self.conv = Conv(c1 * 4, c2, k, s, p, g, act=act)
method ultralytics.nn.modules.conv.Focus.forward
def forward(self, x)
Apply Focus operation and convolution to input tensor.
The slicing step turns a (B, C, H, W) input into (B, 4C, H/2, W/2), which the convolution then projects to c2 output channels.
Args
| Name | Type | Description | Default |
|---|---|---|---|
x | torch.Tensor | Input tensor. | required |
Returns
| Type | Description |
|---|---|
torch.Tensor | Output tensor. |
Source code in ultralytics/nn/modules/conv.py
def forward(self, x):
"""Apply Focus operation and convolution to input tensor.
Input shape is (B, C, W, H) and output shape is (B, 4C, W/2, H/2).
Args:
x (torch.Tensor): Input tensor.
Returns:
(torch.Tensor): Output tensor.
"""
return self.conv(torch.cat((x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]), 1))
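An illustrative example (tensor sizes are assumptions): slicing quadruples the channels and halves each spatial dimension before the convolution maps the result to c2 channels.

```python
import torch
from ultralytics.nn.modules.conv import Focus

focus = Focus(3, 32, k=3)        # internal Conv sees 3 * 4 = 12 input channels
x = torch.randn(1, 3, 64, 64)
print(focus(x).shape)            # torch.Size([1, 32, 32, 32])
```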
class ultralytics.nn.modules.conv.GhostConv
GhostConv(self, c1, c2, k = 1, s = 1, g = 1, act = True)
Bases: nn.Module
Ghost Convolution module.
Generates more features with fewer parameters by using cheap operations.
Args
| Name | Type | Description | Default |
|---|---|---|---|
c1 | int | Number of input channels. | required |
c2 | int | Number of output channels. | required |
k | int | Kernel size. | 1 |
s | int | Stride. | 1 |
g | int | Groups. | 1 |
act | bool | nn.Module | Activation function. | True |
Attributes
| Name | Type | Description |
|---|---|---|
cv1 | Conv | Primary convolution. |
cv2 | Conv | Cheap operation convolution. |
References
https://github.com/huawei-noah/Efficient-AI-Backbones
Methods
| Name | Description |
|---|---|
forward | Apply Ghost Convolution to input tensor. |
Source code in ultralytics/nn/modules/conv.py
class GhostConv(nn.Module):
"""Ghost Convolution module.
Generates more features with fewer parameters by using cheap operations.
Attributes:
cv1 (Conv): Primary convolution.
cv2 (Conv): Cheap operation convolution.
References:
https://github.com/huawei-noah/Efficient-AI-Backbones
"""
def __init__(self, c1, c2, k=1, s=1, g=1, act=True):
"""Initialize Ghost Convolution module with given parameters.
Args:
c1 (int): Number of input channels.
c2 (int): Number of output channels.
k (int): Kernel size.
s (int): Stride.
g (int): Groups.
act (bool | nn.Module): Activation function.
"""
super().__init__()
c_ = c2 // 2 # hidden channels
self.cv1 = Conv(c1, c_, k, s, None, g, act=act)
self.cv2 = Conv(c_, c_, 5, 1, None, c_, act=act)
method ultralytics.nn.modules.conv.GhostConv.forward
def forward(self, x)
Apply Ghost Convolution to input tensor.
Args
| Name | Type | Description | Default |
|---|---|---|---|
x | torch.Tensor | Input tensor. | required |
Returns
| Type | Description |
|---|---|
torch.Tensor | Output tensor with concatenated features. |
Source code in ultralytics/nn/modules/conv.py
def forward(self, x):
"""Apply Ghost Convolution to input tensor.
Args:
x (torch.Tensor): Input tensor.
Returns:
(torch.Tensor): Output tensor with concatenated features.
"""
y = self.cv1(x)
return torch.cat((y, self.cv2(y)), 1)
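A short illustrative example (not from the source): half of the output channels come from the primary convolution, the other half from the cheap 5x5 depthwise operation applied to that result.

```python
import torch
from ultralytics.nn.modules.conv import GhostConv

ghost = GhostConv(16, 32)        # cv1: 16 -> 16 channels, cv2 derives 16 more cheaply
x = torch.randn(1, 16, 40, 40)
print(ghost(x).shape)            # torch.Size([1, 32, 40, 40])
```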
class ultralytics.nn.modules.conv.RepConv
RepConv(self, c1, c2, k = 3, s = 1, p = 1, g = 1, d = 1, act = True, bn = False, deploy = False)
Bases: nn.Module
RepConv module with training and deploy modes.
This module is used in RT-DETR and can fuse convolutions during inference for efficiency.
Args
| Name | Type | Description | Default |
|---|---|---|---|
c1 | int | Number of input channels. | required |
c2 | int | Number of output channels. | required |
k | int | Kernel size. | 3 |
s | int | Stride. | 1 |
p | int | Padding. | 1 |
g | int | Groups. | 1 |
d | int | Dilation. | 1 |
act | bool | nn.Module | Activation function. | True |
bn | bool | Use batch normalization for identity branch. | False |
deploy | bool | Deploy mode for inference. | False |
Attributes
| Name | Type | Description |
|---|---|---|
conv1 | Conv | 3x3 convolution. |
conv2 | Conv | 1x1 convolution. |
bn | nn.BatchNorm2d, optional | Batch normalization for identity branch. |
act | nn.Module | Activation function. |
default_act | nn.Module | Default activation function (SiLU). |
References
https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py
Methods
| Name | Description |
|---|---|
_fuse_bn_tensor | Fuse batch normalization with convolution weights. |
_pad_1x1_to_3x3_tensor | Pad a 1x1 kernel to 3x3 size. |
forward | Forward pass for training mode. |
forward_fuse | Forward pass for deploy mode. |
fuse_convs | Fuse convolutions for inference by creating a single equivalent convolution. |
get_equivalent_kernel_bias | Calculate equivalent kernel and bias by fusing convolutions. |
Source code in ultralytics/nn/modules/conv.py
class RepConv(nn.Module):
"""RepConv module with training and deploy modes.
This module is used in RT-DETR and can fuse convolutions during inference for efficiency.
Attributes:
conv1 (Conv): 3x3 convolution.
conv2 (Conv): 1x1 convolution.
bn (nn.BatchNorm2d, optional): Batch normalization for identity branch.
act (nn.Module): Activation function.
default_act (nn.Module): Default activation function (SiLU).
References:
https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py
"""
default_act = nn.SiLU() # default activation
def __init__(self, c1, c2, k=3, s=1, p=1, g=1, d=1, act=True, bn=False, deploy=False):
"""Initialize RepConv module with given parameters.
Args:
c1 (int): Number of input channels.
c2 (int): Number of output channels.
k (int): Kernel size.
s (int): Stride.
p (int): Padding.
g (int): Groups.
d (int): Dilation.
act (bool | nn.Module): Activation function.
bn (bool): Use batch normalization for identity branch.
deploy (bool): Deploy mode for inference.
"""
super().__init__()
assert k == 3 and p == 1
self.g = g
self.c1 = c1
self.c2 = c2
self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
self.bn = nn.BatchNorm2d(num_features=c1) if bn and c2 == c1 and s == 1 else None
self.conv1 = Conv(c1, c2, k, s, p=p, g=g, act=False)
self.conv2 = Conv(c1, c2, 1, s, p=(p - k // 2), g=g, act=False)
method ultralytics.nn.modules.conv.RepConv._fuse_bn_tensor
def _fuse_bn_tensor(self, branch)
Fuse batch normalization with convolution weights.
Args
| Name | Type | Description | Default |
|---|---|---|---|
branch | Conv | nn.BatchNorm2d | None | Branch to fuse. | required |
Returns
| Type | Description |
|---|---|
kernel (torch.Tensor) | Fused kernel. |
bias (torch.Tensor) | Fused bias. |
Source code in ultralytics/nn/modules/conv.py
def _fuse_bn_tensor(self, branch):
"""Fuse batch normalization with convolution weights.
Args:
branch (Conv | nn.BatchNorm2d | None): Branch to fuse.
Returns:
kernel (torch.Tensor): Fused kernel.
bias (torch.Tensor): Fused bias.
"""
if branch is None:
return 0, 0
if isinstance(branch, Conv):
kernel = branch.conv.weight
running_mean = branch.bn.running_mean
running_var = branch.bn.running_var
gamma = branch.bn.weight
beta = branch.bn.bias
eps = branch.bn.eps
elif isinstance(branch, nn.BatchNorm2d):
if not hasattr(self, "id_tensor"):
input_dim = self.c1 // self.g
kernel_value = np.zeros((self.c1, input_dim, 3, 3), dtype=np.float32)
for i in range(self.c1):
kernel_value[i, i % input_dim, 1, 1] = 1
self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
kernel = self.id_tensor
running_mean = branch.running_mean
running_var = branch.running_var
gamma = branch.weight
beta = branch.bias
eps = branch.eps
std = (running_var + eps).sqrt()
t = (gamma / std).reshape(-1, 1, 1, 1)
return kernel * t, beta - running_mean * gamma / std
method ultralytics.nn.modules.conv.RepConv._pad_1x1_to_3x3_tensor
def _pad_1x1_to_3x3_tensor(kernel1x1)
Pad a 1x1 kernel to 3x3 size.
Args
| Name | Type | Description | Default |
|---|---|---|---|
kernel1x1 | torch.Tensor | 1x1 convolution kernel. | required |
Returns
| Type | Description |
|---|---|
torch.Tensor | Padded 3x3 kernel. |
Source code in ultralytics/nn/modules/conv.py
@staticmethod
def _pad_1x1_to_3x3_tensor(kernel1x1):
"""Pad a 1x1 kernel to 3x3 size.
Args:
kernel1x1 (torch.Tensor): 1x1 convolution kernel.
Returns:
(torch.Tensor): Padded 3x3 kernel.
"""
if kernel1x1 is None:
return 0
else:
return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])
method ultralytics.nn.modules.conv.RepConv.forward
def forward(self, x)
Forward pass for training mode.
Args
| Name | Type | Description | Default |
|---|---|---|---|
x | torch.Tensor | Input tensor. | required |
Returns
| Type | Description |
|---|---|
torch.Tensor | Output tensor. |
Source code in ultralytics/nn/modules/conv.py
def forward(self, x):
"""Forward pass for training mode.
Args:
x (torch.Tensor): Input tensor.
Returns:
(torch.Tensor): Output tensor.
"""
id_out = 0 if self.bn is None else self.bn(x)
return self.act(self.conv1(x) + self.conv2(x) + id_out)
method ultralytics.nn.modules.conv.RepConv.forward_fuse
def forward_fuse(self, x)
Forward pass for deploy mode.
Args
| Name | Type | Description | Default |
|---|---|---|---|
x | torch.Tensor | Input tensor. | required |
Returns
| Type | Description |
|---|---|
torch.Tensor | Output tensor. |
Source code in ultralytics/nn/modules/conv.py
def forward_fuse(self, x):
"""Forward pass for deploy mode.
Args:
x (torch.Tensor): Input tensor.
Returns:
(torch.Tensor): Output tensor.
"""
return self.act(self.conv(x))
method ultralytics.nn.modules.conv.RepConv.fuse_convs
def fuse_convs(self)
Fuse convolutions for inference by creating a single equivalent convolution.
Source code in ultralytics/nn/modules/conv.py
def fuse_convs(self):
"""Fuse convolutions for inference by creating a single equivalent convolution."""
if hasattr(self, "conv"):
return
kernel, bias = self.get_equivalent_kernel_bias()
self.conv = nn.Conv2d(
in_channels=self.conv1.conv.in_channels,
out_channels=self.conv1.conv.out_channels,
kernel_size=self.conv1.conv.kernel_size,
stride=self.conv1.conv.stride,
padding=self.conv1.conv.padding,
dilation=self.conv1.conv.dilation,
groups=self.conv1.conv.groups,
bias=True,
).requires_grad_(False)
self.conv.weight.data = kernel
self.conv.bias.data = bias
for para in self.parameters():
para.detach_()
self.__delattr__("conv1")
self.__delattr__("conv2")
if hasattr(self, "nm"):
self.__delattr__("nm")
if hasattr(self, "bn"):
self.__delattr__("bn")
if hasattr(self, "id_tensor"):
self.__delattr__("id_tensor")
method ultralytics.nn.modules.conv.RepConv.get_equivalent_kernel_bias
def get_equivalent_kernel_bias(self)
Calculate equivalent kernel and bias by fusing convolutions.
Returns
| Type | Description |
|---|---|
torch.Tensor | Equivalent kernel |
torch.Tensor | Equivalent bias |
Source code in ultralytics/nn/modules/conv.py
def get_equivalent_kernel_bias(self):
"""Calculate equivalent kernel and bias by fusing convolutions.
Returns:
(torch.Tensor): Equivalent kernel
(torch.Tensor): Equivalent bias
"""
kernel3x3, bias3x3 = self._fuse_bn_tensor(self.conv1)
kernel1x1, bias1x1 = self._fuse_bn_tensor(self.conv2)
kernelid, biasid = self._fuse_bn_tensor(self.bn)
return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid
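A hedged sketch of the re-parameterization workflow (sizes are illustrative): in eval mode the multi-branch forward and the fused single-conv deploy path should agree within floating-point tolerance.

```python
import torch
from ultralytics.nn.modules.conv import RepConv

m = RepConv(32, 32, k=3, s=1).eval()  # eval() so branch BatchNorms use running stats
x = torch.randn(1, 32, 24, 24)
with torch.no_grad():
    y_train = m(x)                    # conv1 (3x3) + conv2 (1x1) [+ identity BN if enabled]
    m.fuse_convs()                    # collapse all branches into a single 3x3 conv with bias
    y_deploy = m.forward_fuse(x)      # deploy path: act(conv(x))
print(torch.allclose(y_train, y_deploy, atol=1e-5))  # expected: True
```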
class ultralytics.nn.modules.conv.ChannelAttention
ChannelAttention(self, channels: int) -> None
Bases: nn.Module
Channel-attention module for feature recalibration.
Applies attention weights to channels based on global average pooling.
Args
| Name | Type | Description | Default |
|---|---|---|---|
channels | int | Number of input channels. | required |
Attributes
| Name | Type | Description |
|---|---|---|
pool | nn.AdaptiveAvgPool2d | Global average pooling. |
fc | nn.Conv2d | Fully connected layer implemented as 1x1 convolution. |
act | nn.Sigmoid | Sigmoid activation for attention weights. |
References
https://github.com/open-mmlab/mmdetection/tree/v3.0.0rc1/configs/rtmdet
Methods
| Name | Description |
|---|---|
forward | Apply channel attention to input tensor. |
Source code in ultralytics/nn/modules/conv.py
class ChannelAttention(nn.Module):
"""Channel-attention module for feature recalibration.
Applies attention weights to channels based on global average pooling.
Attributes:
pool (nn.AdaptiveAvgPool2d): Global average pooling.
fc (nn.Conv2d): Fully connected layer implemented as 1x1 convolution.
act (nn.Sigmoid): Sigmoid activation for attention weights.
References:
https://github.com/open-mmlab/mmdetection/tree/v3.0.0rc1/configs/rtmdet
"""
def __init__(self, channels: int) -> None:
"""Initialize Channel-attention module.
Args:
channels (int): Number of input channels.
"""
super().__init__()
self.pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Conv2d(channels, channels, 1, 1, 0, bias=True)
self.act = nn.Sigmoid()
method ultralytics.nn.modules.conv.ChannelAttention.forward
def forward(self, x: torch.Tensor) -> torch.Tensor
Apply channel attention to input tensor.
Args
| Name | Type | Description | Default |
|---|---|---|---|
x | torch.Tensor | Input tensor. | required |
Returns
| Type | Description |
|---|---|
torch.Tensor | Channel-attended output tensor. |
Source code in ultralytics/nn/modules/conv.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Apply channel attention to input tensor.
Args:
x (torch.Tensor): Input tensor.
Returns:
(torch.Tensor): Channel-attended output tensor.
"""
return x * self.act(self.fc(self.pool(x)))
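A small illustrative example (assumed sizes): the module produces one sigmoid weight per channel and rescales the input with it.

```python
import torch
from ultralytics.nn.modules.conv import ChannelAttention

ca = ChannelAttention(64)
x = torch.randn(1, 64, 20, 20)
weights = ca.act(ca.fc(ca.pool(x)))  # per-channel attention weights in (0, 1)
print(weights.shape)                 # torch.Size([1, 64, 1, 1])
print(ca(x).shape)                   # torch.Size([1, 64, 20, 20])
```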
class ultralytics.nn.modules.conv.SpatialAttention
SpatialAttention(self, kernel_size = 7)
Bases: nn.Module
Spatial-attention module for feature recalibration.
Applies attention weights to spatial dimensions based on channel statistics.
Args
| Name | Type | Description | Default |
|---|---|---|---|
kernel_size | int | Size of the convolutional kernel (3 or 7). | 7 |
Attributes
| Name | Type | Description |
|---|---|---|
cv1 | nn.Conv2d | Convolution layer for spatial attention. |
act | nn.Sigmoid | Sigmoid activation for attention weights. |
Methods
| Name | Description |
|---|---|
forward | Apply spatial attention to input tensor. |
Source code in ultralytics/nn/modules/conv.py
class SpatialAttention(nn.Module):
"""Spatial-attention module for feature recalibration.
Applies attention weights to spatial dimensions based on channel statistics.
Attributes:
cv1 (nn.Conv2d): Convolution layer for spatial attention.
act (nn.Sigmoid): Sigmoid activation for attention weights.
"""
def __init__(self, kernel_size=7):
"""Initialize Spatial-attention module.
Args:
kernel_size (int): Size of the convolutional kernel (3 or 7).
"""
super().__init__()
assert kernel_size in {3, 7}, "kernel size must be 3 or 7"
padding = 3 if kernel_size == 7 else 1
self.cv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
self.act = nn.Sigmoid()
method ultralytics.nn.modules.conv.SpatialAttention.forward
def forward(self, x)
Apply spatial attention to input tensor.
Args
| Name | Type | Description | Default |
|---|---|---|---|
x | torch.Tensor | Input tensor. | required |
Returns
| Type | Description |
|---|---|
torch.Tensor | Spatial-attended output tensor. |
Source code in ultralytics/nn/modules/conv.py
def forward(self, x):
"""Apply spatial attention to input tensor.
Args:
x (torch.Tensor): Input tensor.
Returns:
(torch.Tensor): Spatial-attended output tensor.
"""
return x * self.act(self.cv1(torch.cat([torch.mean(x, 1, keepdim=True), torch.max(x, 1, keepdim=True)[0]], 1)))
class ultralytics.nn.modules.conv.CBAM
CBAM(self, c1, kernel_size = 7)
Bases: nn.Module
Convolutional Block Attention Module.
Combines channel and spatial attention mechanisms for comprehensive feature refinement.
Args
| Name | Type | Description | Default |
|---|---|---|---|
c1 | int | Number of input channels. | required |
kernel_size | int | Size of the convolutional kernel for spatial attention. | 7 |
Attributes
| Name | Type | Description |
|---|---|---|
channel_attention | ChannelAttention | Channel attention module. |
spatial_attention | SpatialAttention | Spatial attention module. |
Methods
| Name | Description |
|---|---|
forward | Apply channel and spatial attention sequentially to input tensor. |
Source code in ultralytics/nn/modules/conv.py
class CBAM(nn.Module):
"""Convolutional Block Attention Module.
Combines channel and spatial attention mechanisms for comprehensive feature refinement.
Attributes:
channel_attention (ChannelAttention): Channel attention module.
spatial_attention (SpatialAttention): Spatial attention module.
"""
def __init__(self, c1, kernel_size=7):
"""Initialize CBAM with given parameters.
Args:
c1 (int): Number of input channels.
kernel_size (int): Size of the convolutional kernel for spatial attention.
"""
super().__init__()
self.channel_attention = ChannelAttention(c1)
self.spatial_attention = SpatialAttention(kernel_size)
method ultralytics.nn.modules.conv.CBAM.forward
def forward(self, x)
Apply channel and spatial attention sequentially to input tensor.
Args
| Name | Type | Description | Default |
|---|---|---|---|
x | torch.Tensor | Input tensor. | required |
Returns
| Type | Description |
|---|---|
torch.Tensor | Attended output tensor. |
Source code in ultralytics/nn/modules/conv.py
def forward(self, x):
"""Apply channel and spatial attention sequentially to input tensor.
Args:
x (torch.Tensor): Input tensor.
Returns:
(torch.Tensor): Attended output tensor.
"""
return self.spatial_attention(self.channel_attention(x))
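An illustrative example (not from the source): CBAM preserves the input shape, applying channel attention first and spatial attention second.

```python
import torch
from ultralytics.nn.modules.conv import CBAM

cbam = CBAM(64, kernel_size=7)
x = torch.randn(1, 64, 20, 20)
print(cbam(x).shape)             # torch.Size([1, 64, 20, 20]), same shape as the input
```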
class ultralytics.nn.modules.conv.Concat
Concat(self, dimension = 1)
Bases: nn.Module
Concatenate a list of tensors along specified dimension.
Args
| Name | Type | Description | Default |
|---|---|---|---|
dimension | int | Dimension along which to concatenate tensors. | 1 |
Attributes
| Name | Type | Description |
|---|---|---|
d | int | Dimension along which to concatenate tensors. |
Methods
| Name | Description |
|---|---|
forward | Concatenate input tensors along specified dimension. |
Source code in ultralytics/nn/modules/conv.py
class Concat(nn.Module):
"""Concatenate a list of tensors along specified dimension.
Attributes:
d (int): Dimension along which to concatenate tensors.
"""
def __init__(self, dimension=1):
"""Initialize Concat module.
Args:
dimension (int): Dimension along which to concatenate tensors.
"""
super().__init__()
self.d = dimension
method ultralytics.nn.modules.conv.Concat.forward
def forward(self, x: list[torch.Tensor])
Concatenate input tensors along specified dimension.
Args
| Name | Type | Description | Default |
|---|---|---|---|
x | list[torch.Tensor] | List of input tensors. | required |
Returns
| Type | Description |
|---|---|
torch.Tensor | Concatenated tensor. |
Source code in ultralytics/nn/modules/conv.py
def forward(self, x: list[torch.Tensor]):
"""Concatenate input tensors along specified dimension.
Args:
x (list[torch.Tensor]): List of input tensors.
Returns:
(torch.Tensor): Concatenated tensor.
"""
return torch.cat(x, self.d)
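A quick illustrative example (assumed sizes): with the default dimension=1, tensors are joined along the channel axis.

```python
import torch
from ultralytics.nn.modules.conv import Concat

cat = Concat(dimension=1)
a = torch.randn(1, 16, 20, 20)
b = torch.randn(1, 32, 20, 20)
print(cat([a, b]).shape)         # torch.Size([1, 48, 20, 20])
```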
class ultralytics.nn.modules.conv.Index
Index(self, index = 0)
Bases: nn.Module
Returns a particular index of the input.
Args
| Name | Type | Description | Default |
|---|---|---|---|
index | int | Index to select from input. | 0 |
Attributes
| Name | Type | Description |
|---|---|---|
index | int | Index to select from input. |
Methods
| Name | Description |
|---|---|
forward | Select and return a particular index from input. |
Source code in ultralytics/nn/modules/conv.py
class Index(nn.Module):
"""Returns a particular index of the input.
Attributes:
index (int): Index to select from input.
"""
def __init__(self, index=0):
"""Initialize Index module.
Args:
index (int): Index to select from input.
"""
super().__init__()
self.index = index
method ultralytics.nn.modules.conv.Index.forward
def forward(self, x: list[torch.Tensor])
Select and return a particular index from input.
Args
| Name | Type | Description | Default |
|---|---|---|---|
x | list[torch.Tensor] | List of input tensors. | required |
Returns
| Type | Description |
|---|---|
torch.Tensor | Selected tensor. |
Source code in ultralytics/nn/modules/conv.py
def forward(self, x: list[torch.Tensor]):
"""Select and return a particular index from input.
Args:
x (list[torch.Tensor]): List of input tensors.
Returns:
(torch.Tensor): Selected tensor.
"""
return x[self.index]
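An illustrative example (not from the source), typical of picking one feature map from a list of intermediate outputs.

```python
import torch
from ultralytics.nn.modules.conv import Index

pick = Index(index=1)
feats = [torch.randn(1, 8, 40, 40), torch.randn(1, 16, 20, 20)]
print(pick(feats).shape)         # torch.Size([1, 16, 20, 20])
```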
function ultralytics.nn.modules.conv.autopad
def autopad(k, p = None, d = 1)
Pad to 'same' shape outputs.
Args
| Name | Type | Description | Default |
|---|---|---|---|
k | int | list[int] | Kernel size. | required
p | int | list[int], optional | Padding. | None
d | int | Dilation. | 1
Source code in ultralytics/nn/modules/conv.py
def autopad(k, p=None, d=1): # kernel, padding, dilation
"""Pad to 'same' shape outputs."""
if d > 1:
k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k] # actual kernel-size
if p is None:
p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
return p
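A few illustrative calls (values follow from the source logic above):

```python
from ultralytics.nn.modules.conv import autopad

print(autopad(3))       # 1, keeps spatial size for a 3x3 kernel at stride 1
print(autopad(5))       # 2
print(autopad(3, d=2))  # 2, dilation makes the effective kernel size 5
print(autopad(3, p=0))  # 0, an explicit padding value is returned unchanged
```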