
Reference for ultralytics/nn/modules/conv.py

Note

This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/nn/modules/conv.py. If you spot a problem, please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!


ultralytics.nn.modules.conv.Conv

Conv(c1, c2, k=1, s=1, p=None, g=1, d=1, act=True)

Bases: Module

Standard convolution module with batch normalization and activation.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| conv | Conv2d | Convolutional layer. |
| bn | BatchNorm2d | Batch normalization layer. |
| act | Module | Activation function layer. |
| default_act | Module | Default activation function (SiLU). |

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| c1 | int | Number of input channels. | required |
| c2 | int | Number of output channels. | required |
| k | int | Kernel size. | 1 |
| s | int | Stride. | 1 |
| p | int | Padding. | None |
| g | int | Groups. | 1 |
| d | int | Dilation. | 1 |
| act | bool \| Module | Activation function. | True |
Source code in ultralytics/nn/modules/conv.py
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
    """
    Initialize Conv layer with given parameters.

    Args:
        c1 (int): Number of input channels.
        c2 (int): Number of output channels.
        k (int): Kernel size.
        s (int): Stride.
        p (int, optional): Padding.
        g (int): Groups.
        d (int): Dilation.
        act (bool | nn.Module): Activation function.
    """
    super().__init__()
    self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
    self.bn = nn.BatchNorm2d(c2)
    self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()

forward

forward(x)

Apply convolution, batch normalization and activation to input tensor.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Tensor | Input tensor. | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Output tensor. |

Source code in ultralytics/nn/modules/conv.py
def forward(self, x):
    """
    Apply convolution, batch normalization and activation to input tensor.

    Args:
        x (torch.Tensor): Input tensor.

    Returns:
        (torch.Tensor): Output tensor.
    """
    return self.act(self.bn(self.conv(x)))

forward_fuse

forward_fuse(x)

Apply convolution and activation without batch normalization.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Tensor | Input tensor. | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Output tensor. |

Source code in ultralytics/nn/modules/conv.py
def forward_fuse(self, x):
    """
    Apply convolution and activation without batch normalization.

    Args:
        x (torch.Tensor): Input tensor.

    Returns:
        (torch.Tensor): Output tensor.
    """
    return self.act(self.conv(x))
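
Example: a minimal usage sketch (tensor sizes are arbitrary, chosen for illustration):

import torch
from ultralytics.nn.modules.conv import Conv

conv = Conv(3, 16, k=3, s=2)          # 3 -> 16 channels, 3x3 kernel, stride 2, auto 'same' padding
x = torch.randn(1, 3, 64, 64)
print(conv(x).shape)                  # torch.Size([1, 16, 32, 32])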





ultralytics.nn.modules.conv.Conv2

Conv2(c1, c2, k=3, s=1, p=None, g=1, d=1, act=True)

Bases: Conv

Simplified RepConv module with Conv fusing.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| conv | Conv2d | Main 3x3 convolutional layer. |
| cv2 | Conv2d | Additional 1x1 convolutional layer. |
| bn | BatchNorm2d | Batch normalization layer. |
| act | Module | Activation function layer. |

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| c1 | int | Number of input channels. | required |
| c2 | int | Number of output channels. | required |
| k | int | Kernel size. | 3 |
| s | int | Stride. | 1 |
| p | int | Padding. | None |
| g | int | Groups. | 1 |
| d | int | Dilation. | 1 |
| act | bool \| Module | Activation function. | True |
Source code in ultralytics/nn/modules/conv.py
def __init__(self, c1, c2, k=3, s=1, p=None, g=1, d=1, act=True):
    """
    Initialize Conv2 layer with given parameters.

    Args:
        c1 (int): Number of input channels.
        c2 (int): Number of output channels.
        k (int): Kernel size.
        s (int): Stride.
        p (int, optional): Padding.
        g (int): Groups.
        d (int): Dilation.
        act (bool | nn.Module): Activation function.
    """
    super().__init__(c1, c2, k, s, p, g=g, d=d, act=act)
    self.cv2 = nn.Conv2d(c1, c2, 1, s, autopad(1, p, d), groups=g, dilation=d, bias=False)  # add 1x1 conv

forward

forward(x)

Apply convolution, batch normalization and activation to input tensor.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Tensor | Input tensor. | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Output tensor. |

Source code in ultralytics/nn/modules/conv.py
def forward(self, x):
    """
    Apply convolution, batch normalization and activation to input tensor.

    Args:
        x (torch.Tensor): Input tensor.

    Returns:
        (torch.Tensor): Output tensor.
    """
    return self.act(self.bn(self.conv(x) + self.cv2(x)))

forward_fuse

forward_fuse(x)

Apply fused convolution, batch normalization and activation to input tensor.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Tensor | Input tensor. | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Output tensor. |

Source code in ultralytics/nn/modules/conv.py
def forward_fuse(self, x):
    """
    Apply fused convolution, batch normalization and activation to input tensor.

    Args:
        x (torch.Tensor): Input tensor.

    Returns:
        (torch.Tensor): Output tensor.
    """
    return self.act(self.bn(self.conv(x)))

fuse_convs

fuse_convs()

Fuse parallel convolutions.

Source code in ultralytics/nn/modules/conv.py
def fuse_convs(self):
    """Fuse parallel convolutions."""
    w = torch.zeros_like(self.conv.weight.data)
    i = [x // 2 for x in w.shape[2:]]
    w[:, :, i[0] : i[0] + 1, i[1] : i[1] + 1] = self.cv2.weight.data.clone()
    self.conv.weight.data += w
    self.__delattr__("cv2")
    self.forward = self.forward_fuse
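
Example: a small sketch (shapes arbitrary) checking that fuse_convs preserves the module's output in eval mode, where BatchNorm is deterministic:

import torch
from ultralytics.nn.modules.conv import Conv2

m = Conv2(8, 8).eval()                # eval mode so BatchNorm uses running statistics
x = torch.randn(1, 8, 32, 32)
y = m(x)                              # 3x3 branch + 1x1 branch
m.fuse_convs()                        # fold the 1x1 kernel into the center of the 3x3 kernel
print(torch.allclose(y, m(x), atol=1e-5))  # True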





ultralytics.nn.modules.conv.LightConv

LightConv(c1, c2, k=1, act=nn.ReLU())

Bases: Module

Light convolution module with 1x1 and depthwise convolutions.

This implementation is based on the PaddleDetection HGNetV2 backbone.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| conv1 | Conv | 1x1 convolution layer. |
| conv2 | DWConv | Depthwise convolution layer. |

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| c1 | int | Number of input channels. | required |
| c2 | int | Number of output channels. | required |
| k | int | Kernel size for depthwise convolution. | 1 |
| act | Module | Activation function. | ReLU() |
Source code in ultralytics/nn/modules/conv.py
def __init__(self, c1, c2, k=1, act=nn.ReLU()):
    """
    Initialize LightConv layer with given parameters.

    Args:
        c1 (int): Number of input channels.
        c2 (int): Number of output channels.
        k (int): Kernel size for depthwise convolution.
        act (nn.Module): Activation function.
    """
    super().__init__()
    self.conv1 = Conv(c1, c2, 1, act=False)
    self.conv2 = DWConv(c2, c2, k, act=act)

forward

forward(x)

Apply 2 convolutions to input tensor.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Tensor | Input tensor. | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Output tensor. |

Source code in ultralytics/nn/modules/conv.py
def forward(self, x):
    """
    Apply 2 convolutions to input tensor.

    Args:
        x (torch.Tensor): Input tensor.

    Returns:
        (torch.Tensor): Output tensor.
    """
    return self.conv2(self.conv1(x))
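
Example: a quick usage sketch (sizes arbitrary):

import torch
from ultralytics.nn.modules.conv import LightConv

m = LightConv(64, 64, k=3)            # 1x1 conv (no activation) followed by 3x3 depth-wise conv (ReLU)
x = torch.randn(1, 64, 40, 40)
print(m(x).shape)                     # torch.Size([1, 64, 40, 40])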





ultralytics.nn.modules.conv.DWConv

DWConv(c1, c2, k=1, s=1, d=1, act=True)

Bases: Conv

Depth-wise convolution module.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| c1 | int | Number of input channels. | required |
| c2 | int | Number of output channels. | required |
| k | int | Kernel size. | 1 |
| s | int | Stride. | 1 |
| d | int | Dilation. | 1 |
| act | bool \| Module | Activation function. | True |
Source code in ultralytics/nn/modules/conv.py
def __init__(self, c1, c2, k=1, s=1, d=1, act=True):
    """
    Initialize depth-wise convolution with given parameters.

    Args:
        c1 (int): Number of input channels.
        c2 (int): Number of output channels.
        k (int): Kernel size.
        s (int): Stride.
        d (int): Dilation.
        act (bool | nn.Module): Activation function.
    """
    super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), d=d, act=act)
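
Example: since groups is set to gcd(c1, c2), the convolution is fully depth-wise only when the channel counts share a large common divisor (illustrative sketch):

from ultralytics.nn.modules.conv import DWConv

m = DWConv(32, 32, k=3)
print(m.conv.groups)                  # 32: one filter per channel, fully depth-wise
m2 = DWConv(24, 36, k=3)
print(m2.conv.groups)                 # 12: gcd(24, 36), grouped rather than strictly depth-wise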





ultralytics.nn.modules.conv.DWConvTranspose2d

DWConvTranspose2d(c1, c2, k=1, s=1, p1=0, p2=0)

Bases: ConvTranspose2d

Depth-wise transpose convolution module.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| c1 | int | Number of input channels. | required |
| c2 | int | Number of output channels. | required |
| k | int | Kernel size. | 1 |
| s | int | Stride. | 1 |
| p1 | int | Padding. | 0 |
| p2 | int | Output padding. | 0 |
Source code in ultralytics/nn/modules/conv.py
def __init__(self, c1, c2, k=1, s=1, p1=0, p2=0):
    """
    Initialize depth-wise transpose convolution with given parameters.

    Args:
        c1 (int): Number of input channels.
        c2 (int): Number of output channels.
        k (int): Kernel size.
        s (int): Stride.
        p1 (int): Padding.
        p2 (int): Output padding.
    """
    super().__init__(c1, c2, k, s, p1, p2, groups=math.gcd(c1, c2))
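
Example: an illustrative sketch (sizes arbitrary) of the common 2x upsampling configuration:

import torch
from ultralytics.nn.modules.conv import DWConvTranspose2d

m = DWConvTranspose2d(32, 32, k=2, s=2)   # groups = gcd(32, 32) = 32
x = torch.randn(1, 32, 16, 16)
print(m(x).shape)                         # torch.Size([1, 32, 32, 32]): spatial size doubled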





ultralytics.nn.modules.conv.ConvTranspose

ConvTranspose(c1, c2, k=2, s=2, p=0, bn=True, act=True)

Bases: Module

Convolution transpose module with optional batch normalization and activation.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| conv_transpose | ConvTranspose2d | Transposed convolution layer. |
| bn | BatchNorm2d \| Identity | Batch normalization layer. |
| act | Module | Activation function layer. |
| default_act | Module | Default activation function (SiLU). |

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| c1 | int | Number of input channels. | required |
| c2 | int | Number of output channels. | required |
| k | int | Kernel size. | 2 |
| s | int | Stride. | 2 |
| p | int | Padding. | 0 |
| bn | bool | Use batch normalization. | True |
| act | bool \| Module | Activation function. | True |
Source code in ultralytics/nn/modules/conv.py
def __init__(self, c1, c2, k=2, s=2, p=0, bn=True, act=True):
    """
    Initialize ConvTranspose layer with given parameters.

    Args:
        c1 (int): Number of input channels.
        c2 (int): Number of output channels.
        k (int): Kernel size.
        s (int): Stride.
        p (int): Padding.
        bn (bool): Use batch normalization.
        act (bool | nn.Module): Activation function.
    """
    super().__init__()
    self.conv_transpose = nn.ConvTranspose2d(c1, c2, k, s, p, bias=not bn)
    self.bn = nn.BatchNorm2d(c2) if bn else nn.Identity()
    self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()

forward

forward(x)

Apply transposed convolution, batch normalization and activation to input.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Tensor | Input tensor. | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Output tensor. |

Source code in ultralytics/nn/modules/conv.py
def forward(self, x):
    """
    Apply transposed convolution, batch normalization and activation to input.

    Args:
        x (torch.Tensor): Input tensor.

    Returns:
        (torch.Tensor): Output tensor.
    """
    return self.act(self.bn(self.conv_transpose(x)))

forward_fuse

forward_fuse(x)

Apply transposed convolution and activation without batch normalization.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Tensor | Input tensor. | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Output tensor. |

Source code in ultralytics/nn/modules/conv.py
def forward_fuse(self, x):
    """
    Apply transposed convolution and activation without batch normalization.

    Args:
        x (torch.Tensor): Input tensor.

    Returns:
        (torch.Tensor): Output tensor.
    """
    return self.act(self.conv_transpose(x))
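
Example: a minimal usage sketch (sizes arbitrary):

import torch
from ultralytics.nn.modules.conv import ConvTranspose

up = ConvTranspose(16, 8)             # defaults k=2, s=2 double the spatial resolution
x = torch.randn(1, 16, 20, 20)
print(up(x).shape)                    # torch.Size([1, 8, 40, 40])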





ultralytics.nn.modules.conv.Focus

Focus(c1, c2, k=1, s=1, p=None, g=1, act=True)

Bases: Module

Focus module for concentrating feature information.

Slices input tensor into 4 parts and concatenates them in the channel dimension.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| conv | Conv | Convolution layer. |

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| c1 | int | Number of input channels. | required |
| c2 | int | Number of output channels. | required |
| k | int | Kernel size. | 1 |
| s | int | Stride. | 1 |
| p | int | Padding. | None |
| g | int | Groups. | 1 |
| act | bool \| Module | Activation function. | True |
Source code in ultralytics/nn/modules/conv.py
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):
    """
    Initialize Focus module with given parameters.

    Args:
        c1 (int): Number of input channels.
        c2 (int): Number of output channels.
        k (int): Kernel size.
        s (int): Stride.
        p (int, optional): Padding.
        g (int): Groups.
        act (bool | nn.Module): Activation function.
    """
    super().__init__()
    self.conv = Conv(c1 * 4, c2, k, s, p, g, act=act)

forward

forward(x)

Apply Focus operation and convolution to input tensor.

Input shape is (b,c,w,h) and output shape is (b,4c,w/2,h/2).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Tensor | Input tensor. | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Output tensor. |

Source code in ultralytics/nn/modules/conv.py
def forward(self, x):
    """
    Apply Focus operation and convolution to input tensor.

    Input shape is (b,c,w,h) and output shape is (b,4c,w/2,h/2).

    Args:
        x (torch.Tensor): Input tensor.

    Returns:
        (torch.Tensor): Output tensor.
    """
    return self.conv(torch.cat((x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]), 1))
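
Example: the (b, 4c, w/2, h/2) shape describes the concatenated slices before the convolution; the final channel count is c2. An illustrative sketch:

import torch
from ultralytics.nn.modules.conv import Focus

m = Focus(3, 32, k=3)                 # slicing yields 12 channels internally, conv maps them to 32
x = torch.randn(1, 3, 64, 64)
print(m(x).shape)                     # torch.Size([1, 32, 32, 32])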





ultralytics.nn.modules.conv.GhostConv

GhostConv(c1, c2, k=1, s=1, g=1, act=True)

Bases: Module

Ghost Convolution module.

Generates more features with fewer parameters by using cheap operations.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| cv1 | Conv | Primary convolution. |
| cv2 | Conv | Cheap operation convolution. |

References

https://github.com/huawei-noah/ghostnet

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| c1 | int | Number of input channels. | required |
| c2 | int | Number of output channels. | required |
| k | int | Kernel size. | 1 |
| s | int | Stride. | 1 |
| g | int | Groups. | 1 |
| act | bool \| Module | Activation function. | True |
Source code in ultralytics/nn/modules/conv.py
def __init__(self, c1, c2, k=1, s=1, g=1, act=True):
    """
    Initialize Ghost Convolution module with given parameters.

    Args:
        c1 (int): Number of input channels.
        c2 (int): Number of output channels.
        k (int): Kernel size.
        s (int): Stride.
        g (int): Groups.
        act (bool | nn.Module): Activation function.
    """
    super().__init__()
    c_ = c2 // 2  # hidden channels
    self.cv1 = Conv(c1, c_, k, s, None, g, act=act)
    self.cv2 = Conv(c_, c_, 5, 1, None, c_, act=act)

forward

forward(x)

Apply Ghost Convolution to input tensor.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Tensor | Input tensor. | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Output tensor with concatenated features. |

Source code in ultralytics/nn/modules/conv.py
def forward(self, x):
    """
    Apply Ghost Convolution to input tensor.

    Args:
        x (torch.Tensor): Input tensor.

    Returns:
        (torch.Tensor): Output tensor with concatenated features.
    """
    y = self.cv1(x)
    return torch.cat((y, self.cv2(y)), 1)
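
Example: a quick sketch (sizes arbitrary) showing that half the output channels come from the primary convolution and half from the cheap 5x5 depth-wise convolution:

import torch
from ultralytics.nn.modules.conv import GhostConv

m = GhostConv(16, 32)                 # cv1 produces 16 channels; cv2 cheaply generates the other 16
x = torch.randn(1, 16, 40, 40)
print(m(x).shape)                     # torch.Size([1, 32, 40, 40])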





ultralytics.nn.modules.conv.RepConv

RepConv(c1, c2, k=3, s=1, p=1, g=1, d=1, act=True, bn=False, deploy=False)

Bases: Module

RepConv module with training and deploy modes.

This module is used in RT-DETR and can fuse convolutions during inference for efficiency.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| conv1 | Conv | 3x3 convolution. |
| conv2 | Conv | 1x1 convolution. |
| bn | BatchNorm2d | Batch normalization for identity branch. |
| act | Module | Activation function. |
| default_act | Module | Default activation function (SiLU). |

References

https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| c1 | int | Number of input channels. | required |
| c2 | int | Number of output channels. | required |
| k | int | Kernel size. | 3 |
| s | int | Stride. | 1 |
| p | int | Padding. | 1 |
| g | int | Groups. | 1 |
| d | int | Dilation. | 1 |
| act | bool \| Module | Activation function. | True |
| bn | bool | Use batch normalization for identity branch. | False |
| deploy | bool | Deploy mode for inference. | False |
Source code in ultralytics/nn/modules/conv.py
def __init__(self, c1, c2, k=3, s=1, p=1, g=1, d=1, act=True, bn=False, deploy=False):
    """
    Initialize RepConv module with given parameters.

    Args:
        c1 (int): Number of input channels.
        c2 (int): Number of output channels.
        k (int): Kernel size.
        s (int): Stride.
        p (int): Padding.
        g (int): Groups.
        d (int): Dilation.
        act (bool | nn.Module): Activation function.
        bn (bool): Use batch normalization for identity branch.
        deploy (bool): Deploy mode for inference.
    """
    super().__init__()
    assert k == 3 and p == 1
    self.g = g
    self.c1 = c1
    self.c2 = c2
    self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()

    self.bn = nn.BatchNorm2d(num_features=c1) if bn and c2 == c1 and s == 1 else None
    self.conv1 = Conv(c1, c2, k, s, p=p, g=g, act=False)
    self.conv2 = Conv(c1, c2, 1, s, p=(p - k // 2), g=g, act=False)

_fuse_bn_tensor

_fuse_bn_tensor(branch)

Fuse batch normalization with convolution weights.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| branch | Conv \| BatchNorm2d \| None | Branch to fuse. | required |

Returns:

| Type | Description |
| --- | --- |
| tuple | Tuple containing the fused kernel (torch.Tensor) and fused bias (torch.Tensor). |

Source code in ultralytics/nn/modules/conv.py
def _fuse_bn_tensor(self, branch):
    """
    Fuse batch normalization with convolution weights.

    Args:
        branch (Conv | nn.BatchNorm2d | None): Branch to fuse.

    Returns:
        (tuple): Tuple containing:
            - Fused kernel (torch.Tensor)
            - Fused bias (torch.Tensor)
    """
    if branch is None:
        return 0, 0
    if isinstance(branch, Conv):
        kernel = branch.conv.weight
        running_mean = branch.bn.running_mean
        running_var = branch.bn.running_var
        gamma = branch.bn.weight
        beta = branch.bn.bias
        eps = branch.bn.eps
    elif isinstance(branch, nn.BatchNorm2d):
        if not hasattr(self, "id_tensor"):
            input_dim = self.c1 // self.g
            kernel_value = np.zeros((self.c1, input_dim, 3, 3), dtype=np.float32)
            for i in range(self.c1):
                kernel_value[i, i % input_dim, 1, 1] = 1
            self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
        kernel = self.id_tensor
        running_mean = branch.running_mean
        running_var = branch.running_var
        gamma = branch.weight
        beta = branch.bias
        eps = branch.eps
    std = (running_var + eps).sqrt()
    t = (gamma / std).reshape(-1, 1, 1, 1)
    return kernel * t, beta - running_mean * gamma / std

_pad_1x1_to_3x3_tensor staticmethod

_pad_1x1_to_3x3_tensor(kernel1x1)

Pad a 1x1 kernel to 3x3 size.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| kernel1x1 | Tensor | 1x1 convolution kernel. | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Padded 3x3 kernel. |

Source code in ultralytics/nn/modules/conv.py
@staticmethod
def _pad_1x1_to_3x3_tensor(kernel1x1):
    """
    Pad a 1x1 kernel to 3x3 size.

    Args:
        kernel1x1 (torch.Tensor): 1x1 convolution kernel.

    Returns:
        (torch.Tensor): Padded 3x3 kernel.
    """
    if kernel1x1 is None:
        return 0
    else:
        return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])

forward

forward(x)

Forward pass for training mode.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Tensor | Input tensor. | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Output tensor. |

Source code in ultralytics/nn/modules/conv.py
def forward(self, x):
    """
    Forward pass for training mode.

    Args:
        x (torch.Tensor): Input tensor.

    Returns:
        (torch.Tensor): Output tensor.
    """
    id_out = 0 if self.bn is None else self.bn(x)
    return self.act(self.conv1(x) + self.conv2(x) + id_out)

forward_fuse

forward_fuse(x)

Forward pass for deploy mode.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Tensor | Input tensor. | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Output tensor. |

Source code in ultralytics/nn/modules/conv.py
def forward_fuse(self, x):
    """
    Forward pass for deploy mode.

    Args:
        x (torch.Tensor): Input tensor.

    Returns:
        (torch.Tensor): Output tensor.
    """
    return self.act(self.conv(x))

fuse_convs

fuse_convs()

Fuse convolutions for inference by creating a single equivalent convolution.

Source code in ultralytics/nn/modules/conv.py
def fuse_convs(self):
    """Fuse convolutions for inference by creating a single equivalent convolution."""
    if hasattr(self, "conv"):
        return
    kernel, bias = self.get_equivalent_kernel_bias()
    self.conv = nn.Conv2d(
        in_channels=self.conv1.conv.in_channels,
        out_channels=self.conv1.conv.out_channels,
        kernel_size=self.conv1.conv.kernel_size,
        stride=self.conv1.conv.stride,
        padding=self.conv1.conv.padding,
        dilation=self.conv1.conv.dilation,
        groups=self.conv1.conv.groups,
        bias=True,
    ).requires_grad_(False)
    self.conv.weight.data = kernel
    self.conv.bias.data = bias
    for para in self.parameters():
        para.detach_()
    self.__delattr__("conv1")
    self.__delattr__("conv2")
    if hasattr(self, "nm"):
        self.__delattr__("nm")
    if hasattr(self, "bn"):
        self.__delattr__("bn")
    if hasattr(self, "id_tensor"):
        self.__delattr__("id_tensor")

get_equivalent_kernel_bias

get_equivalent_kernel_bias()

Calculate equivalent kernel and bias by fusing convolutions.

Returns:

| Type | Description |
| --- | --- |
| tuple | Tuple containing the equivalent kernel (torch.Tensor) and equivalent bias (torch.Tensor). |

Source code in ultralytics/nn/modules/conv.py
def get_equivalent_kernel_bias(self):
    """
    Calculate equivalent kernel and bias by fusing convolutions.

    Returns:
        (tuple): Tuple containing:
            - Equivalent kernel (torch.Tensor)
            - Equivalent bias (torch.Tensor)
    """
    kernel3x3, bias3x3 = self._fuse_bn_tensor(self.conv1)
    kernel1x1, bias1x1 = self._fuse_bn_tensor(self.conv2)
    kernelid, biasid = self._fuse_bn_tensor(self.bn)
    return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid
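
Example: a sketch (shapes arbitrary) checking in eval mode that the fused convolution reproduces the three-branch output. Note that fuse_convs deletes conv1 and conv2, so the fused path is exercised via forward_fuse:

import torch
from ultralytics.nn.modules.conv import RepConv

m = RepConv(16, 16, bn=True).eval()   # identity BN branch is active since c1 == c2 and s == 1
x = torch.randn(1, 16, 32, 32)
y = m(x)                              # conv1(x) + conv2(x) + bn(x), then activation
m.fuse_convs()                        # collapse all three branches into a single self.conv
print(torch.allclose(y, m.forward_fuse(x), atol=1e-5))  # True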





ultralytics.nn.modules.conv.ChannelAttention

ChannelAttention(channels: int)

Bases: Module

Channel-attention module for feature recalibration.

Applies attention weights to channels based on global average pooling.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| pool | AdaptiveAvgPool2d | Global average pooling. |
| fc | Conv2d | Fully connected layer implemented as 1x1 convolution. |
| act | Sigmoid | Sigmoid activation for attention weights. |

References

https://github.com/open-mmlab/mmdetection/tree/v3.0.0rc1/configs/rtmdet

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| channels | int | Number of input channels. | required |
Source code in ultralytics/nn/modules/conv.py
def __init__(self, channels: int) -> None:
    """
    Initialize Channel-attention module.

    Args:
        channels (int): Number of input channels.
    """
    super().__init__()
    self.pool = nn.AdaptiveAvgPool2d(1)
    self.fc = nn.Conv2d(channels, channels, 1, 1, 0, bias=True)
    self.act = nn.Sigmoid()

forward

forward(x: Tensor) -> torch.Tensor

Apply channel attention to input tensor.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Tensor | Input tensor. | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Channel-attended output tensor. |

Source code in ultralytics/nn/modules/conv.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Apply channel attention to input tensor.

    Args:
        x (torch.Tensor): Input tensor.

    Returns:
        (torch.Tensor): Channel-attended output tensor.
    """
    return x * self.act(self.fc(self.pool(x)))
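
Example: a minimal sketch (sizes arbitrary); the output keeps the input shape, with each channel scaled by a per-channel sigmoid weight:

import torch
from ultralytics.nn.modules.conv import ChannelAttention

ca = ChannelAttention(64)
x = torch.randn(1, 64, 20, 20)
print(ca(x).shape)                    # torch.Size([1, 64, 20, 20])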





ultralytics.nn.modules.conv.SpatialAttention

SpatialAttention(kernel_size=7)

Bases: Module

Spatial-attention module for feature recalibration.

Applies attention weights to spatial dimensions based on channel statistics.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| cv1 | Conv2d | Convolution layer for spatial attention. |
| act | Sigmoid | Sigmoid activation for attention weights. |

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| kernel_size | int | Size of the convolutional kernel (3 or 7). | 7 |
Source code in ultralytics/nn/modules/conv.py
def __init__(self, kernel_size=7):
    """
    Initialize Spatial-attention module.

    Args:
        kernel_size (int): Size of the convolutional kernel (3 or 7).
    """
    super().__init__()
    assert kernel_size in {3, 7}, "kernel size must be 3 or 7"
    padding = 3 if kernel_size == 7 else 1
    self.cv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
    self.act = nn.Sigmoid()

forward

forward(x)

Apply spatial attention to input tensor.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Tensor | Input tensor. | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Spatial-attended output tensor. |

Source code in ultralytics/nn/modules/conv.py
def forward(self, x):
    """
    Apply spatial attention to input tensor.

    Args:
        x (torch.Tensor): Input tensor.

    Returns:
        (torch.Tensor): Spatial-attended output tensor.
    """
    return x * self.act(self.cv1(torch.cat([torch.mean(x, 1, keepdim=True), torch.max(x, 1, keepdim=True)[0]], 1)))
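
Example: a minimal sketch (sizes arbitrary); the channel-wise mean and max statistics produce a single-channel attention map that rescales every spatial position:

import torch
from ultralytics.nn.modules.conv import SpatialAttention

sa = SpatialAttention(kernel_size=7)
x = torch.randn(1, 64, 20, 20)
print(sa(x).shape)                    # torch.Size([1, 64, 20, 20]); the attention map itself is (1, 1, 20, 20)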





ultralytics.nn.modules.conv.CBAM

CBAM(c1, kernel_size=7)

Bases: Module

Convolutional Block Attention Module.

Combines channel and spatial attention mechanisms for comprehensive feature refinement.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| channel_attention | ChannelAttention | Channel attention module. |
| spatial_attention | SpatialAttention | Spatial attention module. |

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| c1 | int | Number of input channels. | required |
| kernel_size | int | Size of the convolutional kernel for spatial attention. | 7 |
Source code in ultralytics/nn/modules/conv.py
def __init__(self, c1, kernel_size=7):
    """
    Initialize CBAM with given parameters.

    Args:
        c1 (int): Number of input channels.
        kernel_size (int): Size of the convolutional kernel for spatial attention.
    """
    super().__init__()
    self.channel_attention = ChannelAttention(c1)
    self.spatial_attention = SpatialAttention(kernel_size)

forward

forward(x)

Apply channel and spatial attention sequentially to input tensor.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Tensor | Input tensor. | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Attended output tensor. |

Source code in ultralytics/nn/modules/conv.py
def forward(self, x):
    """
    Apply channel and spatial attention sequentially to input tensor.

    Args:
        x (torch.Tensor): Input tensor.

    Returns:
        (torch.Tensor): Attended output tensor.
    """
    return self.spatial_attention(self.channel_attention(x))
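
Example: a quick usage sketch (sizes arbitrary):

import torch
from ultralytics.nn.modules.conv import CBAM

m = CBAM(64, kernel_size=7)           # channel attention, then spatial attention
x = torch.randn(1, 64, 20, 20)
print(m(x).shape)                     # torch.Size([1, 64, 20, 20])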





ultralytics.nn.modules.conv.Concat

Concat(dimension=1)

Bases: Module

Concatenate a list of tensors along specified dimension.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| d | int | Dimension along which to concatenate tensors. |

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| dimension | int | Dimension along which to concatenate tensors. | 1 |
Source code in ultralytics/nn/modules/conv.py
def __init__(self, dimension=1):
    """
    Initialize Concat module.

    Args:
        dimension (int): Dimension along which to concatenate tensors.
    """
    super().__init__()
    self.d = dimension

forward

forward(x)

Concatenate input tensors along specified dimension.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | List[Tensor] | List of input tensors. | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Concatenated tensor. |

Source code in ultralytics/nn/modules/conv.py
def forward(self, x):
    """
    Concatenate input tensors along specified dimension.

    Args:
        x (List[torch.Tensor]): List of input tensors.

    Returns:
        (torch.Tensor): Concatenated tensor.
    """
    return torch.cat(x, self.d)
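
Example: a minimal sketch (sizes arbitrary) of channel-wise concatenation:

import torch
from ultralytics.nn.modules.conv import Concat

cat = Concat(dimension=1)
a, b = torch.randn(1, 8, 20, 20), torch.randn(1, 16, 20, 20)
print(cat([a, b]).shape)              # torch.Size([1, 24, 20, 20])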





ultralytics.nn.modules.conv.Index

Index(index=0)

Bases: Module

Returns a particular index of the input.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| index | int | Index to select from input. |

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| index | int | Index to select from input. | 0 |
Source code in ultralytics/nn/modules/conv.py
def __init__(self, index=0):
    """
    Initialize Index module.

    Args:
        index (int): Index to select from input.
    """
    super().__init__()
    self.index = index

forward

forward(x)

Select and return a particular index from input.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | List[Tensor] | List of input tensors. | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Selected tensor. |

Source code in ultralytics/nn/modules/conv.py
def forward(self, x):
    """
    Select and return a particular index from input.

    Args:
        x (List[torch.Tensor]): List of input tensors.

    Returns:
        (torch.Tensor): Selected tensor.
    """
    return x[self.index]
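
Example: a minimal sketch (sizes arbitrary) selecting one tensor from a feature list:

import torch
from ultralytics.nn.modules.conv import Index

idx = Index(index=1)
feats = [torch.randn(1, 8, 40, 40), torch.randn(1, 16, 20, 20)]
print(idx(feats).shape)               # torch.Size([1, 16, 20, 20])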





ultralytics.nn.modules.conv.autopad

autopad(k, p=None, d=1)

Pad to 'same' shape outputs.

Source code in ultralytics/nn/modules/conv.py
def autopad(k, p=None, d=1):  # kernel, padding, dilation
    """Pad to 'same' shape outputs."""
    if d > 1:
        k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k]  # actual kernel-size
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
    return p
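
Example: a few illustrative values showing how the padding is derived:

from ultralytics.nn.modules.conv import autopad

print(autopad(3))                     # 1: k // 2
print(autopad(5))                     # 2
print(autopad(3, d=2))                # 2: dilated kernel has effective size 2*(3-1)+1 = 5
print(autopad((3, 5)))                # [1, 2]: per-dimension padding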


