Skip to content

Reference for ultralytics/nn/modules/conv.py

Improvements

This page is sourced from https://github.com/ultralytics/ultralytics/blob/main/ultralytics/nn/modules/conv.py. Have an improvement or example to add? Open a Pull Request — thank you! 🙏


class ultralytics.nn.modules.conv.Conv

Conv(self, c1, c2, k = 1, s = 1, p = None, g = 1, d = 1, act = True)

Bases: nn.Module

Standard convolution module with batch normalization and activation.

Args

NameTypeDescriptionDefault
c1intNumber of input channels.required
c2intNumber of output channels.required
kintKernel size.1
sintStride.1
pint, optionalPadding.None
gintGroups.1
dintDilation.1
actbool | nn.ModuleActivation function.True

Attributes

NameTypeDescription
convnn.Conv2dConvolutional layer.
bnnn.BatchNorm2dBatch normalization layer.
actnn.ModuleActivation function layer.
default_actnn.ModuleDefault activation function (SiLU).

Methods

NameDescription
forwardApply convolution, batch normalization and activation to input tensor.
forward_fuseApply convolution and activation without batch normalization.
Source code in ultralytics/nn/modules/conv.pyView on GitHub
class Conv(nn.Module):
    """Convolution -> batch normalization -> activation building block.

    Attributes:
        conv (nn.Conv2d): Bias-free convolutional layer.
        bn (nn.BatchNorm2d): Batch normalization layer over the output channels.
        act (nn.Module): Activation function layer.
        default_act (nn.Module): Default activation function (SiLU); a single class-level instance.
    """

    default_act = nn.SiLU()  # default activation

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
        """Build the convolution, normalization and activation layers.

        Args:
            c1 (int): Number of input channels.
            c2 (int): Number of output channels.
            k (int): Kernel size.
            s (int): Stride.
            p (int, optional): Padding; passed to ``autopad`` together with ``k`` and ``d`` to obtain
                the padding actually used.
            g (int): Groups.
            d (int): Dilation.
            act (bool | nn.Module): ``True`` selects ``default_act``, an ``nn.Module`` is used as-is,
                and any other value disables the activation (``nn.Identity``).
        """
        super().__init__()
        padding = autopad(k, p, d)
        # No conv bias: the following BatchNorm provides the per-channel shift.
        self.conv = nn.Conv2d(c1, c2, k, s, padding, groups=g, dilation=d, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        if act is True:
            self.act = self.default_act
        elif isinstance(act, nn.Module):
            self.act = act
        else:
            self.act = nn.Identity()


method ultralytics.nn.modules.conv.Conv.forward

def forward(self, x)

Apply convolution, batch normalization and activation to input tensor.

Args

NameTypeDescriptionDefault
xtorch.TensorInput tensor.required

Returns

TypeDescription
torch.TensorOutput tensor.
Source code in ultralytics/nn/modules/conv.pyView on GitHub
def forward(self, x):
    """Run the input through convolution, then batch normalization, then activation.

    Args:
        x (torch.Tensor): Input tensor.

    Returns:
        (torch.Tensor): Output tensor.
    """
    features = self.conv(x)
    normalized = self.bn(features)
    return self.act(normalized)


method ultralytics.nn.modules.conv.Conv.forward_fuse

def forward_fuse(self, x)

Apply convolution and activation without batch normalization.

Args

NameTypeDescriptionDefault
xtorch.TensorInput tensor.required

Returns

TypeDescription
torch.TensorOutput tensor.
Source code in ultralytics/nn/modules/conv.pyView on GitHub
def forward_fuse(self, x):
    """Run convolution and activation only (batch normalization already folded into the conv).

    Args:
        x (torch.Tensor): Input tensor.

    Returns:
        (torch.Tensor): Output tensor.
    """
    features = self.conv(x)
    return self.act(features)





class ultralytics.nn.modules.conv.Conv2

Conv2(self, c1, c2, k = 3, s = 1, p = None, g = 1, d = 1, act = True)

Bases: Conv

Simplified RepConv module with Conv fusing.

Args

NameTypeDescriptionDefault
c1intNumber of input channels.required
c2intNumber of output channels.required
kintKernel size.3
sintStride.1
pint, optionalPadding.None
gintGroups.1
dintDilation.1
actbool | nn.ModuleActivation function.True

Attributes

NameTypeDescription
convnn.Conv2dMain 3x3 convolutional layer.
cv2nn.Conv2dAdditional 1x1 convolutional layer.
bnnn.BatchNorm2dBatch normalization layer.
actnn.ModuleActivation function layer.

Methods

NameDescription
forwardApply convolution, batch normalization and activation to input tensor.
forward_fuseApply fused convolution, batch normalization and activation to input tensor.
fuse_convsFuse parallel convolutions.
Source code in ultralytics/nn/modules/conv.pyView on GitHub
class Conv2(Conv):
    """Simplified RepConv: a main convolution plus a parallel 1x1 convolution sharing one BN.

    Attributes:
        conv (nn.Conv2d): Main ``k`` x ``k`` convolutional layer (3x3 by default).
        cv2 (nn.Conv2d): Additional 1x1 convolutional layer; removed by ``fuse_convs``.
        bn (nn.BatchNorm2d): Batch normalization applied to the sum of both branches.
        act (nn.Module): Activation function layer.
    """

    def __init__(self, c1, c2, k=3, s=1, p=None, g=1, d=1, act=True):
        """Build the main Conv layers and the extra parallel 1x1 convolution.

        Args:
            c1 (int): Number of input channels.
            c2 (int): Number of output channels.
            k (int): Kernel size of the main convolution.
            s (int): Stride (shared by both branches).
            p (int, optional): Padding, passed to ``autopad`` for each branch.
            g (int): Groups (shared by both branches).
            d (int): Dilation (shared by both branches).
            act (bool | nn.Module): Activation function.
        """
        super().__init__(c1, c2, k, s, p, g=g, d=d, act=act)
        # Parallel 1x1 branch with the same stride, groups and dilation as the main convolution.
        pad_1x1 = autopad(1, p, d)
        self.cv2 = nn.Conv2d(c1, c2, 1, s, pad_1x1, groups=g, dilation=d, bias=False)


method ultralytics.nn.modules.conv.Conv2.forward

def forward(self, x)

Apply convolution, batch normalization and activation to input tensor.

Args

NameTypeDescriptionDefault
xtorch.TensorInput tensor.required

Returns

TypeDescription
torch.TensorOutput tensor.
Source code in ultralytics/nn/modules/conv.pyView on GitHub
def forward(self, x):
    """Sum the main and 1x1 branches, then apply batch normalization and activation.

    Args:
        x (torch.Tensor): Input tensor.

    Returns:
        (torch.Tensor): Output tensor.
    """
    main_branch = self.conv(x)
    side_branch = self.cv2(x)
    merged = self.bn(main_branch + side_branch)
    return self.act(merged)


method ultralytics.nn.modules.conv.Conv2.forward_fuse

def forward_fuse(self, x)

Apply fused convolution, batch normalization and activation to input tensor.

Args

NameTypeDescriptionDefault
xtorch.TensorInput tensor.required

Returns

TypeDescription
torch.TensorOutput tensor.
Source code in ultralytics/nn/modules/conv.pyView on GitHub
def forward_fuse(self, x):
    """Apply the fused main convolution (1x1 branch already merged), BN and activation.

    Args:
        x (torch.Tensor): Input tensor.

    Returns:
        (torch.Tensor): Output tensor.
    """
    normalized = self.bn(self.conv(x))
    return self.act(normalized)


method ultralytics.nn.modules.conv.Conv2.fuse_convs

def fuse_convs(self)

Fuse parallel convolutions.

Source code in ultralytics/nn/modules/conv.pyView on GitHub
def fuse_convs(self):
    """Merge the parallel 1x1 ``cv2`` kernel into the centre tap of the main ``conv`` kernel.

    The 1x1 weights are copied into the centre position of a zero tensor shaped like
    ``conv.weight``. That tensor is added to ``conv.weight`` in place. Afterwards ``cv2`` is
    deleted and ``forward`` is rebound to ``forward_fuse``.
    """
    centre_only = torch.zeros_like(self.conv.weight.data)
    row, col = (size // 2 for size in centre_only.shape[2:])
    centre_only[:, :, row : row + 1, col : col + 1] = self.cv2.weight.data.clone()
    self.conv.weight.data += centre_only
    del self.cv2
    self.forward = self.forward_fuse





class ultralytics.nn.modules.conv.LightConv

LightConv(self, c1, c2, k = 1, act = nn.ReLU())

Bases: nn.Module

Light convolution module with 1x1 and depthwise convolutions.

This implementation is based on the PaddleDetection HGNetV2 backbone.

Args

NameTypeDescriptionDefault
c1intNumber of input channels.required
c2intNumber of output channels.required
kintKernel size for depthwise convolution.1
actnn.ModuleActivation function.nn.ReLU()

Attributes

NameTypeDescription
conv1Conv1x1 convolution layer.
conv2DWConvDepthwise convolution layer.

Methods

NameDescription
forwardApply 2 convolutions to input tensor.
Source code in ultralytics/nn/modules/conv.pyView on GitHub
class LightConv(nn.Module):
    """Light convolution: a 1x1 pointwise Conv followed by a depthwise Conv.

    This implementation is based on the PaddleDetection HGNetV2 backbone.

    Attributes:
        conv1 (Conv): 1x1 convolution layer without activation.
        conv2 (DWConv): Depthwise convolution layer carrying the activation.
    """

    def __init__(self, c1, c2, k=1, act=nn.ReLU()):
        """Build the pointwise and depthwise stages.

        Args:
            c1 (int): Number of input channels.
            c2 (int): Number of output channels.
            k (int): Kernel size of the depthwise convolution.
            act (nn.Module): Activation applied after the depthwise convolution only.
        """
        super().__init__()
        # The pointwise stage is linear; the activation is applied once, after the depthwise stage.
        self.conv1 = Conv(c1, c2, k=1, act=False)
        self.conv2 = DWConv(c2, c2, k=k, act=act)


method ultralytics.nn.modules.conv.LightConv.forward

def forward(self, x)

Apply 2 convolutions to input tensor.

Args

NameTypeDescriptionDefault
xtorch.TensorInput tensor.required

Returns

TypeDescription
torch.TensorOutput tensor.
Source code in ultralytics/nn/modules/conv.pyView on GitHub
def forward(self, x):
    """Apply the pointwise convolution, then the depthwise convolution.

    Args:
        x (torch.Tensor): Input tensor.

    Returns:
        (torch.Tensor): Output tensor.
    """
    pointwise = self.conv1(x)
    return self.conv2(pointwise)





class ultralytics.nn.modules.conv.DWConv

DWConv(self, c1, c2, k = 1, s = 1, d = 1, act = True)

Bases: Conv

Depth-wise convolution module.

Args

NameTypeDescriptionDefault
c1intNumber of input channels.required
c2intNumber of output channels.required
kintKernel size.1
sintStride.1
dintDilation.1
actbool | nn.ModuleActivation function.True
Source code in ultralytics/nn/modules/conv.pyView on GitHub
class DWConv(Conv):
    """Depth-wise convolution: a Conv whose group count is gcd(c1, c2)."""

    def __init__(self, c1, c2, k=1, s=1, d=1, act=True):
        """Initialize a depth-wise Conv.

        Args:
            c1 (int): Number of input channels.
            c2 (int): Number of output channels.
            k (int): Kernel size.
            s (int): Stride.
            d (int): Dilation.
            act (bool | nn.Module): Activation function.
        """
        # gcd(c1, c2) divides both channel counts, which grouped convolution requires.
        groups = math.gcd(c1, c2)
        super().__init__(c1, c2, k, s, g=groups, d=d, act=act)





class ultralytics.nn.modules.conv.DWConvTranspose2d

DWConvTranspose2d(self, c1, c2, k = 1, s = 1, p1 = 0, p2 = 0)

Bases: nn.ConvTranspose2d

Depth-wise transpose convolution module.

Args

NameTypeDescriptionDefault
c1intNumber of input channels.required
c2intNumber of output channels.required
kintKernel size.1
sintStride.1
p1intPadding.0
p2intOutput padding.0
Source code in ultralytics/nn/modules/conv.pyView on GitHub
class DWConvTranspose2d(nn.ConvTranspose2d):
    """Depth-wise transposed convolution: an ``nn.ConvTranspose2d`` grouped by gcd(c1, c2)."""

    def __init__(self, c1, c2, k=1, s=1, p1=0, p2=0):
        """Initialize the grouped transposed convolution.

        Args:
            c1 (int): Number of input channels.
            c2 (int): Number of output channels.
            k (int): Kernel size.
            s (int): Stride.
            p1 (int): Padding.
            p2 (int): Output padding.
        """
        groups = math.gcd(c1, c2)
        super().__init__(
            c1,
            c2,
            kernel_size=k,
            stride=s,
            padding=p1,
            output_padding=p2,
            groups=groups,
        )





class ultralytics.nn.modules.conv.ConvTranspose

ConvTranspose(self, c1, c2, k = 2, s = 2, p = 0, bn = True, act = True)

Bases: nn.Module

Convolution transpose module with optional batch normalization and activation.

Args

NameTypeDescriptionDefault
c1intNumber of input channels.required
c2intNumber of output channels.required
kintKernel size.2
sintStride.2
pintPadding.0
bnboolUse batch normalization.True
actbool | nn.ModuleActivation function.True

Attributes

NameTypeDescription
conv_transposenn.ConvTranspose2dTransposed convolution layer.
bnnn.BatchNorm2d | nn.IdentityBatch normalization layer.
actnn.ModuleActivation function layer.
default_actnn.ModuleDefault activation function (SiLU).

Methods

NameDescription
forwardApply transposed convolution, batch normalization and activation to input.
forward_fuseApply activation and convolution transpose operation to input.
Source code in ultralytics/nn/modules/conv.pyView on GitHub
class ConvTranspose(nn.Module):
    """Transposed convolution with optional batch normalization and activation.

    Attributes:
        conv_transpose (nn.ConvTranspose2d): Transposed convolution layer (bias only when BN is off).
        bn (nn.BatchNorm2d | nn.Identity): Batch normalization layer, or identity when disabled.
        act (nn.Module): Activation function layer.
        default_act (nn.Module): Default activation function (SiLU).
    """

    default_act = nn.SiLU()  # default activation

    def __init__(self, c1, c2, k=2, s=2, p=0, bn=True, act=True):
        """Build the transposed convolution, optional BN and activation.

        Args:
            c1 (int): Number of input channels.
            c2 (int): Number of output channels.
            k (int): Kernel size.
            s (int): Stride.
            p (int): Padding.
            bn (bool): Whether to add batch normalization.
            act (bool | nn.Module): ``True`` selects ``default_act``, an ``nn.Module`` is used as-is,
                and any other value disables the activation.
        """
        super().__init__()
        # Give the transposed conv its own bias only when no BatchNorm follows it.
        self.conv_transpose = nn.ConvTranspose2d(c1, c2, kernel_size=k, stride=s, padding=p, bias=not bn)
        if bn:
            self.bn = nn.BatchNorm2d(c2)
        else:
            self.bn = nn.Identity()
        if act is True:
            self.act = self.default_act
        elif isinstance(act, nn.Module):
            self.act = act
        else:
            self.act = nn.Identity()


method ultralytics.nn.modules.conv.ConvTranspose.forward

def forward(self, x)

Apply transposed convolution, batch normalization and activation to input.

Args

NameTypeDescriptionDefault
xtorch.TensorInput tensor.required

Returns

TypeDescription
torch.TensorOutput tensor.
Source code in ultralytics/nn/modules/conv.pyView on GitHub
def forward(self, x):
    """Apply the transposed convolution, then batch normalization, then activation.

    Args:
        x (torch.Tensor): Input tensor.

    Returns:
        (torch.Tensor): Output tensor.
    """
    upsampled = self.conv_transpose(x)
    normalized = self.bn(upsampled)
    return self.act(normalized)


method ultralytics.nn.modules.conv.ConvTranspose.forward_fuse

def forward_fuse(self, x)

Apply activation and convolution transpose operation to input.

Args

NameTypeDescriptionDefault
xtorch.TensorInput tensor.required

Returns

TypeDescription
torch.TensorOutput tensor.
Source code in ultralytics/nn/modules/conv.pyView on GitHub
def forward_fuse(self, x):
    """Apply the transposed convolution followed directly by the activation (no BN).

    Args:
        x (torch.Tensor): Input tensor.

    Returns:
        (torch.Tensor): Output tensor.
    """
    upsampled = self.conv_transpose(x)
    return self.act(upsampled)





class ultralytics.nn.modules.conv.Focus

Focus(self, c1, c2, k = 1, s = 1, p = None, g = 1, act = True)

Bases: nn.Module

Focus module for concentrating feature information.

Slices input tensor into 4 parts and concatenates them in the channel dimension.

Args

NameTypeDescriptionDefault
c1intNumber of input channels.required
c2intNumber of output channels.required
kintKernel size.1
sintStride.1
pint, optionalPadding.None
gintGroups.1
actbool | nn.ModuleActivation function.True

Attributes

NameTypeDescription
convConvConvolution layer.

Methods

NameDescription
forwardApply Focus operation and convolution to input tensor.
Source code in ultralytics/nn/modules/conv.pyView on GitHub
class Focus(nn.Module):
    """Focus (space-to-depth) module for concentrating feature information.

    Slices the input into four pixel-interleaved parts, stacks them along the channel dimension,
    and convolves the result.

    Attributes:
        conv (Conv): Convolution applied to the stacked slices (``4 * c1`` input channels).
    """

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):
        """Initialize the Focus convolution.

        Args:
            c1 (int): Number of input channels (before slicing).
            c2 (int): Number of output channels.
            k (int): Kernel size.
            s (int): Stride.
            p (int, optional): Padding.
            g (int): Groups.
            act (bool | nn.Module): Activation function.
        """
        super().__init__()
        # forward() stacks four slices of the input, so the conv sees four times the channels.
        self.conv = Conv(4 * c1, c2, k=k, s=s, p=p, g=g, act=act)


method ultralytics.nn.modules.conv.Focus.forward

def forward(self, x)

Apply Focus operation and convolution to input tensor.

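The input of shape (B, C, H, W) is sliced into a (B, 4C, H/2, W/2) tensor, which is then passed through the convolution, so the final output has c2 channels. H and W must be even.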

Args

NameTypeDescriptionDefault
xtorch.TensorInput tensor.required

Returns

TypeDescription
torch.TensorOutput tensor.
Source code in ultralytics/nn/modules/conv.pyView on GitHub
def forward(self, x):
    """Apply Focus (space-to-depth) slicing followed by convolution.

    The input, of shape (B, C, H, W), is split into four pixel-interleaved sub-images. These are
    concatenated along the channel axis into a (B, 4C, H/2, W/2) tensor, which is passed through
    ``self.conv``. The output therefore has the channel count produced by ``self.conv`` (``c2``),
    not 4C.

    H and W must be even. Otherwise the even- and odd-indexed slices differ in size and
    ``torch.cat`` raises.

    Args:
        x (torch.Tensor): Input tensor of shape (B, C, H, W).

    Returns:
        (torch.Tensor): Output of ``self.conv`` applied to the (B, 4C, H/2, W/2) stacked slices.
    """
    # Channel order: top-left, bottom-left, top-right, bottom-right pixel of every 2x2 patch.
    patches = (x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2])
    return self.conv(torch.cat(patches, 1))





class ultralytics.nn.modules.conv.GhostConv

GhostConv(self, c1, c2, k = 1, s = 1, g = 1, act = True)

Bases: nn.Module

Ghost Convolution module.

Generates more features with fewer parameters by using cheap operations.

Args

NameTypeDescriptionDefault
c1intNumber of input channels.required
c2intNumber of output channels.required
kintKernel size.1
sintStride.1
gintGroups.1
actbool | nn.ModuleActivation function.True

Attributes

NameTypeDescription
cv1ConvPrimary convolution.
cv2ConvCheap operation convolution.
References
https://github.com/huawei-noah/Efficient-AI-Backbones

Methods

NameDescription
forwardApply Ghost Convolution to input tensor.
Source code in ultralytics/nn/modules/conv.pyView on GitHub
class GhostConv(nn.Module):
    """Ghost Convolution module.

    Produces half of the output channels with a regular Conv. The other half comes from a cheap
    5x5 depthwise Conv applied to that first half.

    Attributes:
        cv1 (Conv): Primary convolution producing ``c2 // 2`` channels.
        cv2 (Conv): Cheap depthwise 5x5 convolution (groups equal to its channel count).

    References:
        https://github.com/huawei-noah/Efficient-AI-Backbones
    """

    def __init__(self, c1, c2, k=1, s=1, g=1, act=True):
        """Initialize the primary and cheap convolutions.

        Args:
            c1 (int): Number of input channels.
            c2 (int): Number of output channels. The module emits ``2 * (c2 // 2)`` channels, so
                an odd ``c2`` yields ``c2 - 1`` channels.
            k (int): Kernel size of the primary convolution.
            s (int): Stride of the primary convolution.
            g (int): Groups of the primary convolution.
            act (bool | nn.Module): Activation function for both convolutions.
        """
        super().__init__()
        half = c2 // 2  # hidden channels
        self.cv1 = Conv(c1, half, k=k, s=s, p=None, g=g, act=act)
        self.cv2 = Conv(half, half, k=5, s=1, p=None, g=half, act=act)


method ultralytics.nn.modules.conv.GhostConv.forward

def forward(self, x)

Apply Ghost Convolution to input tensor.

Args

NameTypeDescriptionDefault
xtorch.TensorInput tensor.required

Returns

TypeDescription
torch.TensorOutput tensor with concatenated features.
Source code in ultralytics/nn/modules/conv.pyView on GitHub
def forward(self, x):
    """Concatenate the primary features with their cheap-operation counterparts.

    Args:
        x (torch.Tensor): Input tensor.

    Returns:
        (torch.Tensor): Output tensor with concatenated features.
    """
    primary = self.cv1(x)
    cheap = self.cv2(primary)
    return torch.cat([primary, cheap], dim=1)





class ultralytics.nn.modules.conv.RepConv

RepConv(self, c1, c2, k = 3, s = 1, p = 1, g = 1, d = 1, act = True, bn = False, deploy = False)

Bases: nn.Module

RepConv module with training and deploy modes.

This module is used in RT-DETR and can fuse convolutions during inference for efficiency.

Args

NameTypeDescriptionDefault
c1intNumber of input channels.required
c2intNumber of output channels.required
kintKernel size.3
sintStride.1
pintPadding.1
gintGroups.1
dintDilation (currently unused; not passed to either branch convolution).1
actbool | nn.ModuleActivation function.True
bnboolUse batch normalization for identity branch.False
deployboolDeploy mode for inference.False

Attributes

NameTypeDescription
conv1Conv3x3 convolution.
conv2Conv1x1 convolution.
bnnn.BatchNorm2d, optionalBatch normalization for identity branch.
actnn.ModuleActivation function.
default_actnn.ModuleDefault activation function (SiLU).
References
https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py

Methods

NameDescription
_fuse_bn_tensorFuse batch normalization with convolution weights.
_pad_1x1_to_3x3_tensorPad a 1x1 kernel to 3x3 size.
forwardForward pass for training mode.
forward_fuseForward pass for deploy mode.
fuse_convsFuse convolutions for inference by creating a single equivalent convolution.
get_equivalent_kernel_biasCalculate equivalent kernel and bias by fusing convolutions.
Source code in ultralytics/nn/modules/conv.pyView on GitHub
class RepConv(nn.Module):
    """RepConv module with training and deploy modes.

    In training mode the block adds up to three branches:
    - a 3x3 Conv+BN branch,
    - a 1x1 Conv+BN branch,
    - optionally, a BatchNorm-only identity branch.

    The activation is applied to the sum. ``fuse_convs`` collapses the branches into a single 3x3
    convolution for inference. This module is used in RT-DETR.

    Attributes:
        conv1 (Conv): 3x3 convolution branch (with BN, no activation).
        conv2 (Conv): 1x1 convolution branch (with BN, no activation).
        bn (nn.BatchNorm2d | None): Identity-branch batch normalization. It is only created when
            ``bn=True``, ``c1 == c2`` and ``s == 1``; otherwise it is None.
        act (nn.Module): Activation applied after summing the branches.
        default_act (nn.Module): Default activation function (SiLU).

    References:
        https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py
    """

    default_act = nn.SiLU()  # default activation

    def __init__(self, c1, c2, k=3, s=1, p=1, g=1, d=1, act=True, bn=False, deploy=False):
        """Initialize RepConv module with given parameters.

        Args:
            c1 (int): Number of input channels.
            c2 (int): Number of output channels.
            k (int): Kernel size. Must be 3.
            s (int): Stride.
            p (int): Padding. Must be 1.
            g (int): Groups.
            d (int): Dilation. Currently unused: it is not passed to either branch convolution.
            act (bool | nn.Module): Activation function.
            bn (bool): Use batch normalization for identity branch.
            deploy (bool): Deploy mode for inference (not used by this constructor).

        Raises:
            AssertionError: If ``k != 3`` or ``p != 1``.
        """
        super().__init__()
        assert k == 3 and p == 1, f"RepConv requires k=3 and p=1, got k={k}, p={p}"
        self.g = g
        self.c1 = c1
        self.c2 = c2
        self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()

        # The identity branch only makes sense when input and output tensors have the same shape.
        self.bn = nn.BatchNorm2d(num_features=c1) if bn and c2 == c1 and s == 1 else None
        self.conv1 = Conv(c1, c2, k, s, p=p, g=g, act=False)
        # p - k // 2 == 0: an unpadded 1x1 branch keeps the same spatial size as the padded 3x3 branch.
        self.conv2 = Conv(c1, c2, 1, s, p=(p - k // 2), g=g, act=False)


method ultralytics.nn.modules.conv.RepConv._fuse_bn_tensor

def _fuse_bn_tensor(self, branch)

Fuse batch normalization with convolution weights.

Args

NameTypeDescriptionDefault
branchConv | nn.BatchNorm2d | NoneBranch to fuse.required

Returns

TypeDescription
kernel (torch.Tensor)Fused kernel.
bias (torch.Tensor)Fused bias.
Source code in ultralytics/nn/modules/conv.pyView on GitHub
def _fuse_bn_tensor(self, branch):
    """Fold a branch's batch normalization into an equivalent kernel and bias.

    For a ``Conv`` branch, the BN statistics are folded into its convolution weight. For a bare
    ``nn.BatchNorm2d`` (identity branch), they are folded into a 3x3 identity kernel. That kernel
    is built once and cached as ``self.id_tensor``.

    Args:
        branch (Conv | nn.BatchNorm2d | None): Branch to fuse.

    Returns:
        kernel (torch.Tensor): Fused kernel (``0`` when ``branch`` is None).
        bias (torch.Tensor): Fused bias (``0`` when ``branch`` is None).
    """
    if branch is None:
        return 0, 0
    if isinstance(branch, Conv):
        norm = branch.bn
        kernel = branch.conv.weight
    elif isinstance(branch, nn.BatchNorm2d):
        norm = branch
        if not hasattr(self, "id_tensor"):
            # Put a 1 at the centre tap linking each output channel to its own input channel.
            # The input index is taken within the channel's group.
            per_group = self.c1 // self.g
            eye = np.zeros((self.c1, per_group, 3, 3), dtype=np.float32)
            for ch in range(self.c1):
                eye[ch, ch % per_group, 1, 1] = 1
            self.id_tensor = torch.from_numpy(eye).to(branch.weight.device)
        kernel = self.id_tensor
    std = (norm.running_var + norm.eps).sqrt()
    scale = (norm.weight / std).reshape(-1, 1, 1, 1)
    return kernel * scale, norm.bias - norm.running_mean * norm.weight / std


method ultralytics.nn.modules.conv.RepConv._pad_1x1_to_3x3_tensor

def _pad_1x1_to_3x3_tensor(kernel1x1)

Pad a 1x1 kernel to 3x3 size.

Args

NameTypeDescriptionDefault
kernel1x1torch.Tensor1x1 convolution kernel.required

Returns

TypeDescription
torch.TensorPadded 3x3 kernel.
Source code in ultralytics/nn/modules/conv.pyView on GitHub
@staticmethod
def _pad_1x1_to_3x3_tensor(kernel1x1):
    """Zero-pad a 1x1 kernel by one pixel on every spatial side so it becomes 3x3.

    Args:
        kernel1x1 (torch.Tensor | None): 1x1 convolution kernel.

    Returns:
        (torch.Tensor | int): Padded 3x3 kernel, or ``0`` when ``kernel1x1`` is None.
    """
    return 0 if kernel1x1 is None else torch.nn.functional.pad(kernel1x1, (1, 1, 1, 1))


method ultralytics.nn.modules.conv.RepConv.forward

def forward(self, x)

Forward pass for training mode.

Args

NameTypeDescriptionDefault
xtorch.TensorInput tensor.required

Returns

TypeDescription
torch.TensorOutput tensor.
Source code in ultralytics/nn/modules/conv.pyView on GitHub
def forward(self, x):
    """Training-mode pass: sum the 3x3, 1x1 and optional identity branches, then activate.

    Args:
        x (torch.Tensor): Input tensor.

    Returns:
        (torch.Tensor): Output tensor.
    """
    identity = self.bn(x) if self.bn is not None else 0
    branches = self.conv1(x) + self.conv2(x)
    return self.act(branches + identity)


method ultralytics.nn.modules.conv.RepConv.forward_fuse

def forward_fuse(self, x)

Forward pass for deploy mode.

Args

NameTypeDescriptionDefault
xtorch.TensorInput tensor.required

Returns

TypeDescription
torch.TensorOutput tensor.
Source code in ultralytics/nn/modules/conv.pyView on GitHub
def forward_fuse(self, x):
    """Deploy-mode pass through the single fused convolution and the activation.

    Args:
        x (torch.Tensor): Input tensor.

    Returns:
        (torch.Tensor): Output tensor.
    """
    fused = self.conv(x)
    return self.act(fused)


method ultralytics.nn.modules.conv.RepConv.fuse_convs

def fuse_convs(self)

Fuse convolutions for inference by creating a single equivalent convolution.

Source code in ultralytics/nn/modules/conv.pyView on GitHub
def fuse_convs(self):
    """Replace the training branches with one equivalent, frozen ``nn.Conv2d`` stored as ``self.conv``.

    This is a no-op if ``self.conv`` already exists. The new conv copies its geometry from
    ``conv1.conv`` and takes its weights from ``get_equivalent_kernel_bias``. The branch
    attributes are then deleted.
    """
    if hasattr(self, "conv"):
        return  # already fused
    kernel, bias = self.get_equivalent_kernel_bias()
    reference = self.conv1.conv
    self.conv = nn.Conv2d(
        in_channels=reference.in_channels,
        out_channels=reference.out_channels,
        kernel_size=reference.kernel_size,
        stride=reference.stride,
        padding=reference.padding,
        dilation=reference.dilation,
        groups=reference.groups,
        bias=True,
    ).requires_grad_(False)
    self.conv.weight.data = kernel
    self.conv.bias.data = bias
    for param in self.parameters():
        param.detach_()
    # Both branch convolutions always exist before fusion.
    for name in ("conv1", "conv2"):
        self.__delattr__(name)
    # These may or may not be present, so remove them only if set.
    for name in ("nm", "bn", "id_tensor"):
        if hasattr(self, name):
            self.__delattr__(name)


method ultralytics.nn.modules.conv.RepConv.get_equivalent_kernel_bias

def get_equivalent_kernel_bias(self)

Calculate equivalent kernel and bias by fusing convolutions.

Returns

TypeDescription
torch.TensorEquivalent kernel
torch.TensorEquivalent bias
Source code in ultralytics/nn/modules/conv.pyView on GitHub
def get_equivalent_kernel_bias(self):
    """Combine the fused 3x3, padded 1x1 and identity branches into one kernel and bias.

    Returns:
        (torch.Tensor): Equivalent kernel
        (torch.Tensor): Equivalent bias
    """
    k3, b3 = self._fuse_bn_tensor(self.conv1)
    k1, b1 = self._fuse_bn_tensor(self.conv2)
    k_id, b_id = self._fuse_bn_tensor(self.bn)
    kernel = k3 + self._pad_1x1_to_3x3_tensor(k1) + k_id
    bias = b3 + b1 + b_id
    return kernel, bias





class ultralytics.nn.modules.conv.ChannelAttention

ChannelAttention(self, channels: int) -> None

Bases: nn.Module

Channel-attention module for feature recalibration.

Applies attention weights to channels based on global average pooling.

Args

NameTypeDescriptionDefault
channelsintNumber of input channels.required

Attributes

NameTypeDescription
poolnn.AdaptiveAvgPool2dGlobal average pooling.
fcnn.Conv2dFully connected layer implemented as 1x1 convolution.
actnn.SigmoidSigmoid activation for attention weights.
References
https://github.com/open-mmlab/mmdetection/tree/v3.0.0rc1/configs/rtmdet

Methods

NameDescription
forwardApply channel attention to input tensor.
Source code in ultralytics/nn/modules/conv.pyView on GitHub
class ChannelAttention(nn.Module):
    """Channel-attention module for feature recalibration.

    Each channel is rescaled by a sigmoid gate computed from globally average-pooled features.

    Attributes:
        pool (nn.AdaptiveAvgPool2d): Global average pooling to 1x1.
        fc (nn.Conv2d): Fully connected layer implemented as a 1x1 convolution with bias.
        act (nn.Sigmoid): Sigmoid activation producing the attention weights.

    References:
        https://github.com/open-mmlab/mmdetection/tree/v3.0.0rc1/configs/rtmdet
    """

    def __init__(self, channels: int) -> None:
        """Build the pooling, projection and gating layers.

        Args:
            channels (int): Number of input (and output) channels.
        """
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0, bias=True)
        self.act = nn.Sigmoid()


method ultralytics.nn.modules.conv.ChannelAttention.forward

def forward(self, x: torch.Tensor) -> torch.Tensor

Apply channel attention to input tensor.

Args

NameTypeDescriptionDefault
xtorch.TensorInput tensor.required

Returns

TypeDescription
torch.TensorChannel-attended output tensor.
Source code in ultralytics/nn/modules/conv.pyView on GitHub
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Scale the input channel-wise by learned sigmoid gates.

    Args:
        x (torch.Tensor): Input tensor.

    Returns:
        (torch.Tensor): Channel-attended output tensor.
    """
    gates = self.act(self.fc(self.pool(x)))
    return x * gates





class ultralytics.nn.modules.conv.SpatialAttention

SpatialAttention(self, kernel_size = 7)

Bases: nn.Module

Spatial-attention module for feature recalibration.

Applies attention weights to spatial dimensions based on channel statistics.

Args

NameTypeDescriptionDefault
kernel_sizeintSize of the convolutional kernel (3 or 7).7

Attributes

NameTypeDescription
cv1nn.Conv2dConvolution layer for spatial attention.
actnn.SigmoidSigmoid activation for attention weights.

Methods

NameDescription
forwardApply spatial attention to input tensor.
Source code in ultralytics/nn/modules/conv.pyView on GitHub
class SpatialAttention(nn.Module):
    """Spatial-attention module for feature recalibration.

    Each spatial position is rescaled by a sigmoid gate computed from the channel-wise mean and max.

    Attributes:
        cv1 (nn.Conv2d): Bias-free 2->1 channel convolution producing the attention map.
        act (nn.Sigmoid): Sigmoid activation producing the attention weights.
    """

    def __init__(self, kernel_size=7):
        """Build the attention convolution.

        Args:
            kernel_size (int): Size of the convolutional kernel; must be 3 or 7.
        """
        super().__init__()
        assert kernel_size in {3, 7}, "kernel size must be 3 or 7"
        if kernel_size == 7:
            padding = 3
        else:
            padding = 1
        self.cv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.act = nn.Sigmoid()


method ultralytics.nn.modules.conv.SpatialAttention.forward

def forward(self, x)

Apply spatial attention to input tensor.

Args

NameTypeDescriptionDefault
xtorch.TensorInput tensor.required

Returns

TypeDescription
torch.TensorSpatial-attended output tensor.
Source code in ultralytics/nn/modules/conv.pyView on GitHub
def forward(self, x):
    """Gate every spatial position using channel-wise mean and max statistics.

    Args:
        x (torch.Tensor): Input tensor.

    Returns:
        (torch.Tensor): Spatial-attended output tensor.
    """
    avg_map = torch.mean(x, 1, keepdim=True)
    max_map = x.max(dim=1, keepdim=True).values
    attention = self.act(self.cv1(torch.cat([avg_map, max_map], 1)))
    return x * attention





class ultralytics.nn.modules.conv.CBAM

CBAM(self, c1, kernel_size = 7)

Bases: nn.Module

Convolutional Block Attention Module.

Combines channel and spatial attention mechanisms for comprehensive feature refinement.

Args

NameTypeDescriptionDefault
c1intNumber of input channels.required
kernel_sizeintSize of the convolutional kernel for spatial attention.7

Attributes

NameTypeDescription
channel_attentionChannelAttentionChannel attention module.
spatial_attentionSpatialAttentionSpatial attention module.

Methods

NameDescription
forwardApply channel and spatial attention sequentially to input tensor.
Source code in ultralytics/nn/modules/conv.pyView on GitHub
class CBAM(nn.Module):
    """Convolutional Block Attention Module (CBAM).

    Refines features with a channel-attention submodule followed by a spatial-attention submodule.

    Attributes:
        channel_attention (ChannelAttention): Channel attention submodule, built from ``c1``.
        spatial_attention (SpatialAttention): Spatial attention submodule, built from ``kernel_size``.
    """

    def __init__(self, c1, kernel_size=7):
        """Create the two attention submodules.

        Args:
            c1 (int): Number of input channels, forwarded to ChannelAttention.
            kernel_size (int): Kernel size forwarded to SpatialAttention.
        """
        super().__init__()
        # Build and register in the original order (channel first, then spatial) so that submodule
        # registration order, and therefore state_dict key order, is unchanged.
        channel_branch = ChannelAttention(c1)
        spatial_branch = SpatialAttention(kernel_size)
        self.channel_attention = channel_branch
        self.spatial_attention = spatial_branch


method ultralytics.nn.modules.conv.CBAM.forward

def forward(self, x)

Apply channel and spatial attention sequentially to input tensor.

Args

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `x` | `torch.Tensor` | Input tensor. | required |

Returns

| Type | Description |
|------|-------------|
| `torch.Tensor` | Attended output tensor. |

Source code in `ultralytics/nn/modules/conv.py` · View on GitHub
def forward(self, x):
    """Run channel attention, then spatial attention, on ``x``.

    Args:
        x (torch.Tensor): Input tensor.

    Returns:
        (torch.Tensor): Output of the spatial-attention stage applied to the channel-refined input.
    """
    channel_refined = self.channel_attention(x)
    return self.spatial_attention(channel_refined)





class ultralytics.nn.modules.conv.Concat

Concat(self, dimension = 1)

Bases: nn.Module

Concatenate a list of tensors along specified dimension.

Args

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `dimension` | `int` | Dimension along which to concatenate tensors. | `1` |

Attributes

| Name | Type | Description |
|------|------|-------------|
| `d` | `int` | Dimension along which to concatenate tensors. |

Methods

| Name | Description |
|------|-------------|
| `forward` | Concatenate input tensors along the specified dimension. |

Source code in `ultralytics/nn/modules/conv.py` · View on GitHub
class Concat(nn.Module):
    """Join a list of tensors along one fixed dimension.

    Attributes:
        d (int): Dimension passed to ``torch.cat`` when the inputs are joined.
    """

    def __init__(self, dimension=1):
        """Record the concatenation dimension.

        Args:
            dimension (int): Dimension along which inputs are concatenated.
        """
        super().__init__()
        self.d = dimension


method ultralytics.nn.modules.conv.Concat.forward

def forward(self, x: list[torch.Tensor])

Concatenate input tensors along specified dimension.

Args

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `x` | `list[torch.Tensor]` | List of input tensors. | required |

Returns

| Type | Description |
|------|-------------|
| `torch.Tensor` | Concatenated tensor. |

Source code in `ultralytics/nn/modules/conv.py` · View on GitHub
def forward(self, x: list[torch.Tensor]):
    """Join the input tensors along ``self.d``.

    Args:
        x (list[torch.Tensor]): Tensors to join; all sizes except dimension ``self.d`` must match.

    Returns:
        (torch.Tensor): The concatenated tensor.
    """
    return torch.cat(x, dim=self.d)





class ultralytics.nn.modules.conv.Index

Index(self, index = 0)

Bases: nn.Module

Returns a particular index of the input.

Args

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `index` | `int` | Index to select from input. | `0` |

Attributes

| Name | Type | Description |
|------|------|-------------|
| `index` | `int` | Index to select from input. |

Methods

| Name | Description |
|------|-------------|
| `forward` | Select and return a particular index from the input. |

Source code in `ultralytics/nn/modules/conv.py` · View on GitHub
class Index(nn.Module):
    """Select the element at a fixed position of a sequence-like input.

    Attributes:
        index (int): Position of the element that ``forward`` returns.
    """

    def __init__(self, index=0):
        """Remember which position to select.

        Args:
            index (int): Position to select from the input.
        """
        super().__init__()
        self.index = index


method ultralytics.nn.modules.conv.Index.forward

def forward(self, x: list[torch.Tensor])

Select and return a particular index from input.

Args

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `x` | `list[torch.Tensor]` | List of input tensors. | required |

Returns

| Type | Description |
|------|-------------|
| `torch.Tensor` | Selected tensor. |

Source code in `ultralytics/nn/modules/conv.py` · View on GitHub
def forward(self, x: list[torch.Tensor]):
    """Return the entry of ``x`` at position ``self.index``.

    Args:
        x (list[torch.Tensor]): Sequence of input tensors.

    Returns:
        (torch.Tensor): The element stored at ``self.index`` (negative positions count from the end).
    """
    position = self.index
    selected = x[position]
    return selected





function ultralytics.nn.modules.conv.autopad

def autopad(k, p = None, d = 1)

Pad to 'same' shape outputs.

Args

NameTypeDescriptionDefault
krequired
pNone
d1
Source code in ultralytics/nn/modules/conv.pyView on GitHub
def autopad(k, p=None, d=1):  # kernel, padding, dilation
    """Pad to 'same' shape outputs.

    Compute the padding that keeps a stride-1 convolution's output at the input's spatial size,
    taking dilation into account. The result is exactly 'same' only for odd effective kernel
    sizes; for even sizes ``k // 2`` yields an output one element larger per dimension.

    Args:
        k (int | Sequence[int]): Kernel size, as one int or one value per spatial dimension.
        p (int | Sequence[int], optional): Explicit padding. When not None it is returned unchanged.
        d (int | Sequence[int]): Dilation, as one int or a list/tuple with one value per spatial
            dimension (the same forms ``nn.Conv2d`` accepts).

    Returns:
        (int | list[int]): ``p`` if given. Otherwise an int when both ``k`` and ``d`` are scalars,
            else a list with one padding value per spatial dimension.

    Raises:
        ValueError: If ``d`` is a list/tuple and ``k`` is a sequence of a different length.
    """
    if p is not None:
        return p  # explicit padding wins; no need to compute the effective kernel
    if isinstance(d, (list, tuple)):
        # Per-dimension dilation: broadcast a scalar kernel size, then dilate each dimension separately.
        ks = [k] * len(d) if isinstance(k, int) else list(k)
        if len(ks) != len(d):
            raise ValueError(f"kernel size {k} and dilation {d} must have the same number of dimensions")
        k = [di * (ki - 1) + 1 for ki, di in zip(ks, d)]  # actual kernel-size per dimension
    elif d > 1:
        k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k]  # actual kernel-size
    return k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad





📅 Created 2 years ago ✏️ Updated 18 days ago
Contributors: glenn-jocher, Y-T-G, jk4e, Burhan-Q