Skip to content

Reference for ultralytics/nn/modules/head.py

Note

This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/nn/modules/head.py. If you spot a problem, please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!


ultralytics.nn.modules.head.Detect

Detect(nc=80, ch=())

Bases: Module

YOLO Detect head for detection models.

Source code in ultralytics/nn/modules/head.py
def __init__(self, nc=80, ch=()):
    """
    Initializes the YOLO detection layer with specified number of classes and channels.

    Builds, for every input level in `ch`, a box-regression branch (`cv2`, 4 * reg_max outputs) and a
    classification branch (`cv3`, nc outputs), plus the DFL decoder. When `end2end` is set, it also builds
    independent deep copies of both branches for the one-to-one head.

    Args:
        nc (int): Number of classes. Defaults to 80.
        ch (tuple): Input channels of each detection level. Must be non-empty, since `ch[0]` sizes the branches.

    NOTE(review): `self.legacy` and `self.end2end` are read here but never assigned in this method;
    presumably they are class-level attributes defined elsewhere — confirm. The module construction order
    below determines parameter registration order, so keep it stable when editing.
    """
    super().__init__()
    self.nc = nc  # number of classes
    self.nl = len(ch)  # number of detection layers
    self.reg_max = 16  # DFL channels; fixed at 16 (historical note: ch[0] // 16 would scale 4/8/12/16/20 for n/s/m/l/x)
    self.no = nc + self.reg_max * 4  # number of outputs per anchor
    self.stride = torch.zeros(self.nl)  # strides computed during build
    # c2: box-branch width = largest of 16, ch[0] // 4 and 4 * reg_max.
    # c3: class-branch width = largest of ch[0] and min(nc, 100).
    c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100))  # channels
    # Box branch per level: two 3x3 Convs, then a 1x1 conv with 4 * reg_max outputs (consumed by `self.dfl`).
    self.cv2 = nn.ModuleList(
        nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch
    )
    # Class branch per level, ending in nc outputs. The legacy variant uses two 3x3 Convs; the default variant
    # uses two DWConv(3x3) + Conv(1x1) pairs (presumably a lighter depthwise-separable design).
    self.cv3 = (
        nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)
        if self.legacy
        else nn.ModuleList(
            nn.Sequential(
                nn.Sequential(DWConv(x, x, 3), Conv(x, c3, 1)),
                nn.Sequential(DWConv(c3, c3, 3), Conv(c3, c3, 1)),
                nn.Conv2d(c3, self.nc, 1),
            )
            for x in ch
        )
    )
    # DFL decoding only makes sense with more than one bin; otherwise pass the box outputs through unchanged.
    self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()

    if self.end2end:
        # One-to-one head: same architecture as the one-to-many branches, but with independent weights.
        self.one2one_cv2 = copy.deepcopy(self.cv2)
        self.one2one_cv3 = copy.deepcopy(self.cv3)

bias_init

bias_init()

Initialize Detect() biases, WARNING: requires stride availability.

Source code in ultralytics/nn/modules/head.py
def bias_init(self):
    """
    Initialize Detect() biases, WARNING: requires stride availability.

    For every detection level, the last layer of the box branch gets a bias of 1.0. The first `nc` biases of the
    last class layer are set to log(5 / nc / (640 / stride) ** 2), a class prior that shrinks as the level's
    grid (640 / stride cells per side of a 640-px image) gets finer. When `end2end` is enabled, the one-to-one
    branches receive the same initialization.
    """
    branch_pairs = [(self.cv2, self.cv3)]
    if self.end2end:
        branch_pairs.append((self.one2one_cv2, self.one2one_cv3))

    for box_heads, cls_heads in branch_pairs:
        for box_head, cls_head, level_stride in zip(box_heads, cls_heads, self.stride):
            box_head[-1].bias.data[:] = 1.0  # box
            cls_head[-1].bias.data[: self.nc] = math.log(5 / self.nc / (640 / level_stride) ** 2)  # cls

decode_bboxes

decode_bboxes(bboxes, anchors)

Decode bounding boxes.

Source code in ultralytics/nn/modules/head.py
def decode_bboxes(self, bboxes, anchors):
    """
    Decode bounding boxes.

    Delegates to `dist2bbox` along dim 1. `xywh` is True for the standard head and False when
    `end2end` is set (in which case `dist2bbox` presumably returns corner-format boxes — confirm).
    """
    want_xywh = not self.end2end
    return dist2bbox(bboxes, anchors, xywh=want_xywh, dim=1)

forward

forward(x)

Concatenates and returns predicted bounding boxes and class probabilities.

Source code in ultralytics/nn/modules/head.py
def forward(self, x):
    """
    Concatenates and returns predicted bounding boxes and class probabilities.

    Each of the first `nl` entries of `x` is replaced in place by the channel-wise concatenation of its box
    branch (`cv2`) and class branch (`cv3`) outputs.

    Returns:
        The end-to-end result when `end2end` is set. In training mode, the list `x` itself. Otherwise the
        decoded predictions from `_inference`, alone when exporting or paired with `x`.
    """
    if self.end2end:
        return self.forward_end2end(x)

    for level in range(self.nl):
        feat = x[level]
        x[level] = torch.cat((self.cv2[level](feat), self.cv3[level](feat)), 1)

    if self.training:  # Training path
        return x

    preds = self._inference(x)
    if self.export:
        return preds
    return preds, x

forward_end2end

forward_end2end(x)

Performs the end-to-end forward pass of the Detect head, running both the one-to-many and the one-to-one branches.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| x | tensor | Input tensor. | required |

Returns:

| Type | Description |
|------|-------------|
| (dict, tensor) | In training mode, returns a dictionary with the raw outputs of the one2many and one2one branches. Otherwise, returns the post-processed one2one detections (a tensor) when exporting, or a tuple of those detections and the dictionary of raw one2many and one2one outputs. |

Source code in ultralytics/nn/modules/head.py
def forward_end2end(self, x):
    """
    Run both the one-to-many and the one-to-one detection branches.

    The one-to-one branch receives detached copies of the input features, so gradients from that branch do not
    flow back into them. The one-to-many branch uses `x` directly and overwrites its entries in place.

    Args:
        x (list[torch.Tensor]): Per-level feature maps; modified in place.

    Returns:
        (dict | torch.Tensor | tuple): In training mode, a dict with keys "one2many" and "one2one" holding the
            per-level raw outputs. Otherwise the one-to-one outputs are decoded with `_inference` and filtered
            with `postprocess`. The result is returned alone when exporting, or paired with the raw-output dict.
    """
    detached = [feat.detach() for feat in x]
    one2one = [
        torch.cat((self.one2one_cv2[lvl](detached[lvl]), self.one2one_cv3[lvl](detached[lvl])), 1)
        for lvl in range(self.nl)
    ]
    for lvl in range(self.nl):
        feat = x[lvl]
        x[lvl] = torch.cat((self.cv2[lvl](feat), self.cv3[lvl](feat)), 1)

    raw = {"one2many": x, "one2one": one2one}
    if self.training:  # Training path
        return raw

    decoded = self._inference(one2one)
    dets = self.postprocess(decoded.permute(0, 2, 1), self.max_det, self.nc)
    if self.export:
        return dets
    return dets, raw

postprocess staticmethod

postprocess(preds: torch.Tensor, max_det: int, nc: int = 80)

Post-processes YOLO model predictions.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| preds | Tensor | Raw predictions with shape (batch_size, num_anchors, 4 + nc) with last dimension format [x, y, w, h, class_probs]. | required |
| max_det | int | Maximum detections per image. | required |
| nc | int | Number of classes. Default: 80. | 80 |

Returns:

| Type | Description |
|------|-------------|
| Tensor | Processed predictions with shape (batch_size, min(max_det, num_anchors), 6) and last dimension format [x, y, w, h, max_class_prob, class_index]. |

Source code in ultralytics/nn/modules/head.py
@staticmethod
def postprocess(preds: torch.Tensor, max_det: int, nc: int = 80):
    """
    Post-processes YOLO model predictions.

    This is a two-stage top-k with k = min(max_det, num_anchors). First, keep the k anchors whose best class
    score is highest. Then rank every (anchor, class) pair of that shortlist jointly and keep the k best. The
    same anchor can therefore appear more than once with different classes.

    Args:
        preds (torch.Tensor): Raw predictions with shape (batch_size, num_anchors, 4 + nc) with last dimension
            format [x, y, w, h, class_probs].
        max_det (int): Maximum detections per image.
        nc (int, optional): Number of classes. Default: 80.

    Returns:
        (torch.Tensor): Processed predictions with shape (batch_size, min(max_det, num_anchors), 6) and last
            dimension format [x, y, w, h, max_class_prob, class_index].
    """
    bs, num_anchors, _ = preds.shape  # e.g. (16, 8400, 84)
    k = min(max_det, num_anchors)
    box_part, cls_part = preds.split([4, nc], dim=-1)

    # Stage 1: shortlist anchors by their best class score.
    keep = cls_part.amax(dim=-1).topk(k)[1].unsqueeze(-1)  # (bs, k, 1)
    box_part = box_part.gather(dim=1, index=keep.repeat(1, 1, 4))
    cls_part = cls_part.gather(dim=1, index=keep.repeat(1, 1, nc))

    # Stage 2: rank the shortlist's (anchor, class) pairs jointly.
    best_scores, flat_idx = cls_part.flatten(1).topk(k)
    batch_idx = torch.arange(bs).unsqueeze(-1)  # broadcasts against (bs, k)
    anchor_idx = flat_idx // nc
    class_idx = flat_idx % nc
    return torch.cat(
        [box_part[batch_idx, anchor_idx], best_scores.unsqueeze(-1), class_idx.unsqueeze(-1).float()],
        dim=-1,
    )





ultralytics.nn.modules.head.Segment

Segment(nc=80, nm=32, npr=256, ch=())

Bases: Detect

YOLO Segment head for segmentation models.

Source code in ultralytics/nn/modules/head.py
def __init__(self, nc=80, nm=32, npr=256, ch=()):
    """
    Initialize the YOLO model attributes such as the number of masks, prototypes, and the convolution layers.

    Args:
        nc (int): Number of classes.
        nm (int): Number of masks, i.e. mask coefficients predicted per anchor by `cv4`.
        npr (int): Number of protos, passed to the `Proto` module built on the first level's channels.
        ch (tuple): Input channels of each detection level.
    """
    super().__init__(nc, ch)
    self.nm, self.npr = nm, npr  # number of masks, number of protos
    self.proto = Proto(ch[0], npr, nm)  # protos

    width = max(ch[0] // 4, nm)
    self.cv4 = nn.ModuleList(
        nn.Sequential(Conv(cin, width, 3), Conv(width, width, 3), nn.Conv2d(width, nm, 1)) for cin in ch
    )

forward

forward(x)

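In training mode, returns the raw detection outputs, the mask coefficients, and the mask prototypes. Otherwise, returns the detections concatenated with the mask coefficients, together with the prototypes (and, when not exporting, the raw outputs).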

Source code in ultralytics/nn/modules/head.py
def forward(self, x):
    """
    Run the segmentation head.

    Mask prototypes and coefficients are computed from the input features before `Detect.forward`
    overwrites the entries of `x` in place.

    Returns:
        In training mode, (raw detection outputs, mask coefficients, prototypes). When exporting,
        (detections concatenated with mask coefficients along dim 1, prototypes). Otherwise, (that
        concatenation, (raw outputs, mask coefficients, prototypes)).
    """
    protos = self.proto(x[0])  # mask protos
    batch = protos.shape[0]

    coeff_levels = [self.cv4[i](x[i]).view(batch, self.nm, -1) for i in range(self.nl)]
    coeffs = torch.cat(coeff_levels, 2)  # mask coefficients
    det = Detect.forward(self, x)

    if self.training:
        return det, coeffs, protos
    if self.export:
        return torch.cat([det, coeffs], 1), protos
    return torch.cat([det[0], coeffs], 1), (det[1], coeffs, protos)





ultralytics.nn.modules.head.OBB

OBB(nc=80, ne=1, ch=())

Bases: Detect

YOLO OBB detection head for detection with rotation models.

Source code in ultralytics/nn/modules/head.py
def __init__(self, nc=80, ne=1, ch=()):
    """
    Initialize OBB with number of classes `nc` and layer channels `ch`.

    Args:
        nc (int): Number of classes.
        ne (int): Number of extra per-anchor parameters predicted by `cv4` (used as angle logits in `forward`).
        ch (tuple): Input channels of each detection level.
    """
    super().__init__(nc, ch)
    self.ne = ne  # number of extra parameters

    width = max(ch[0] // 4, ne)
    self.cv4 = nn.ModuleList(
        nn.Sequential(Conv(cin, width, 3), Conv(width, width, 3), nn.Conv2d(width, ne, 1)) for cin in ch
    )

decode_bboxes

decode_bboxes(bboxes, anchors)

Decode rotated bounding boxes.

Source code in ultralytics/nn/modules/head.py
def decode_bboxes(self, bboxes, anchors):
    """
    Decode rotated bounding boxes.

    Reads the angles that `forward` caches on `self.angle` (it only does so outside training), so this
    must run after such a forward call.
    """
    cached_angle = self.angle
    return dist2rbox(bboxes, cached_angle, anchors, dim=1)

forward

forward(x)

Concatenates and returns predicted bounding boxes, class probabilities, and rotation angles.

Source code in ultralytics/nn/modules/head.py
def forward(self, x):
    """
    Concatenates and returns predicted bounding boxes, class probabilities and rotation angles.

    Angles are computed from the input features before `Detect.forward` overwrites `x` in place. Outside
    training they are also cached on `self.angle` so that `decode_bboxes` can read them.

    Returns:
        In training mode, (raw detection outputs, angles). When exporting, detections concatenated with
        angles along dim 1. Otherwise, (that concatenation, (raw outputs, angles)).
    """
    bs = x[0].shape[0]  # batch size
    theta_logits = torch.cat([self.cv4[i](x[i]).view(bs, self.ne, -1) for i in range(self.nl)], 2)
    # sigmoid gives (0, 1); shifting by 0.25 and scaling by pi maps it to (-pi/4, 3pi/4).
    angle = (theta_logits.sigmoid() - 0.25) * math.pi
    # angle = angle.sigmoid() * math.pi / 2  # [0, pi/2]
    if not self.training:
        self.angle = angle  # read by `decode_bboxes` during inference decoding

    det = Detect.forward(self, x)
    if self.training:
        return det, angle
    if self.export:
        return torch.cat([det, angle], 1)
    return torch.cat([det[0], angle], 1), (det[1], angle)





ultralytics.nn.modules.head.Pose

Pose(nc=80, kpt_shape=(17, 3), ch=())

Bases: Detect

YOLO Pose head for keypoints models.

Source code in ultralytics/nn/modules/head.py
def __init__(self, nc=80, kpt_shape=(17, 3), ch=()):
    """
    Initialize YOLO network with default parameters and Convolutional Layers.

    Args:
        nc (int): Number of classes.
        kpt_shape (tuple): (number of keypoints, dims per keypoint: 2 for x,y or 3 for x,y,visible).
        ch (tuple): Input channels of each detection level.
    """
    super().__init__(nc, ch)
    self.kpt_shape = kpt_shape
    num_kpts, kpt_dims = kpt_shape[0], kpt_shape[1]
    self.nk = num_kpts * kpt_dims  # total keypoint values predicted per anchor

    width = max(ch[0] // 4, self.nk)
    self.cv4 = nn.ModuleList(
        nn.Sequential(Conv(cin, width, 3), Conv(width, width, 3), nn.Conv2d(width, self.nk, 1)) for cin in ch
    )

forward

forward(x)

Perform forward pass through YOLO model and return predictions.

Source code in ultralytics/nn/modules/head.py
def forward(self, x):
    """
    Perform forward pass through YOLO model and return predictions.

    Raw keypoints, shape (bs, nk, total anchors), are computed before `Detect.forward` overwrites `x` in place.
    They are decoded afterwards, presumably because `kpts_decode` relies on the anchors/strides that the
    detection pass prepares.

    Returns:
        In training mode, (raw detection outputs, raw keypoints). When exporting, detections concatenated with
        decoded keypoints along dim 1. Otherwise, (that concatenation, (raw outputs, raw keypoints)).
    """
    bs = x[0].shape[0]  # batch size
    raw_kpt = torch.cat([self.cv4[i](x[i]).view(bs, self.nk, -1) for i in range(self.nl)], -1)
    det = Detect.forward(self, x)
    if self.training:
        return det, raw_kpt

    decoded = self.kpts_decode(bs, raw_kpt)
    if self.export:
        return torch.cat([det, decoded], 1)
    return torch.cat([det[0], decoded], 1), (det[1], raw_kpt)

kpts_decode

kpts_decode(bs, kpts)

Decodes keypoints.

Source code in ultralytics/nn/modules/head.py
def kpts_decode(self, bs, kpts):
    """
    Decodes keypoints.

    Each x/y value becomes (raw * 2 + anchor - 0.5) * stride. When keypoints carry a third (visibility) value,
    it is passed through a sigmoid. The input tensor is never modified.

    Args:
        bs (int): Batch size.
        kpts (torch.Tensor): Raw keypoints of shape (bs, nk, num_anchors), interleaved per keypoint.

    Returns:
        (torch.Tensor): Decoded keypoints with the same shape as `kpts`.
    """
    ndim = self.kpt_shape[1]
    if not self.export:
        decoded = kpts.clone()
        if ndim == 3:
            decoded[:, 2::3] = decoded[:, 2::3].sigmoid()  # sigmoid (WARNING: inplace .sigmoid_() Apple MPS bug)
        for axis in (0, 1):  # x coordinates, then y coordinates
            decoded[:, axis::ndim] = (decoded[:, axis::ndim] * 2.0 + (self.anchors[axis] - 0.5)) * self.strides
        return decoded

    # Export path: reshape-based formulation, required for TFLite to avoid the 'PLACEHOLDER_FOR_GREATER_OP_CODES' bug.
    grid = kpts.view(bs, *self.kpt_shape, -1)
    xy = (grid[:, :, :2] * 2.0 + (self.anchors - 0.5)) * self.strides
    if ndim == 3:
        xy = torch.cat((xy, grid[:, :, 2:3].sigmoid()), 2)
    return xy.view(bs, self.nk, -1)





ultralytics.nn.modules.head.Classify

Classify(c1, c2, k=1, s=1, p=None, g=1)

Bases: Module

YOLO classification head, i.e. x(b,c1,20,20) to x(b,c2).

Source code in ultralytics/nn/modules/head.py
def __init__(self, c1, c2, k=1, s=1, p=None, g=1):
    """
    Initializes YOLO classification head to transform input tensor from (b,c1,20,20) to (b,c2) shape.

    Args:
        c1 (int): Input channels.
        c2 (int): Output width of the final linear layer.
        k, s, p, g: Forwarded unchanged to `Conv` (presumably kernel size, stride, padding and groups).
    """
    super().__init__()
    hidden = 1280  # efficientnet_b0 size
    self.conv = Conv(c1, hidden, k, s, p, g)
    self.pool = nn.AdaptiveAvgPool2d(1)  # -> (b, hidden, 1, 1)
    self.drop = nn.Dropout(p=0.0, inplace=True)
    self.linear = nn.Linear(hidden, c2)  # -> (b, c2)

forward

forward(x)

Performs a forward pass of the YOLO model on input image data.

Source code in ultralytics/nn/modules/head.py
def forward(self, x):
    """
    Performs a forward pass of the YOLO model on input image data.

    A list input is first concatenated along the channel dimension. Returns logits in training mode
    and softmax probabilities over dim 1 otherwise.
    """
    if isinstance(x, list):
        x = torch.cat(x, 1)
    pooled = self.pool(self.conv(x)).flatten(1)
    logits = self.linear(self.drop(pooled))
    if self.training:
        return logits
    return logits.softmax(1)





ultralytics.nn.modules.head.WorldDetect

WorldDetect(nc=80, embed=512, with_bn=False, ch=())

Bases: Detect

Head for integrating YOLO detection models with semantic understanding from text embeddings.

Source code in ultralytics/nn/modules/head.py
def __init__(self, nc=80, embed=512, with_bn=False, ch=()):
    """
    Initialize YOLO detection layer with nc classes and layer channels ch.

    Replaces the parent's class branch `cv3` with one that outputs `embed` channels per level. Also adds one
    contrastive head per level in `cv4`: `BNContrastiveHead(embed)` when `with_bn` is True, else `ContrastiveHead()`.

    Args:
        nc (int): Number of classes.
        embed (int): Output channels of each `cv3` branch.
        with_bn (bool): Selects the contrastive head type.
        ch (tuple): Input channels of each detection level.
    """
    super().__init__(nc, ch)
    c3 = max(ch[0], min(self.nc, 100))
    embed_heads = [nn.Sequential(Conv(cin, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, embed, 1)) for cin in ch]
    self.cv3 = nn.ModuleList(embed_heads)
    contrastive_heads = [BNContrastiveHead(embed) if with_bn else ContrastiveHead() for _ in ch]
    self.cv4 = nn.ModuleList(contrastive_heads)

bias_init

bias_init()

Initialize Detect() biases, WARNING: requires stride availability.

Source code in ultralytics/nn/modules/head.py
def bias_init(self):
    """
    Initialize Detect() biases, WARNING: requires stride availability.

    Only the box branch is initialized: the final `cv2` bias of each level is set to 1.0, and the class branch
    `cv3` is left untouched. The loop still zips `cv2` with `cv3` and `stride`, so it stops at the shortest.
    """
    for box_head, _cls_head, _stride in zip(self.cv2, self.cv3, self.stride):
        box_head[-1].bias.data[:] = 1.0  # box

forward

forward(x, text)

Concatenates and returns predicted bounding boxes and class probabilities.

Source code in ultralytics/nn/modules/head.py
def forward(self, x, text):
    """
    Concatenates and returns predicted bounding boxes and class probabilities.

    For each level, the class part is `cv4[i]` applied to the `cv3[i]` embedding together with `text`.

    Args:
        x (list[torch.Tensor]): Per-level feature maps; entries are overwritten in place with raw head outputs.
        text (torch.Tensor): Passed to every contrastive head. Its shape is not established here — presumably
            (batch, num_classes, embed); TODO confirm.

    Returns:
        In training mode, the list `x` of raw per-level outputs. Otherwise `y`, the decoded boxes concatenated with
        sigmoid class scores along dim 1, returned alone when exporting or as `(y, x)`.
    """
    for i in range(self.nl):
        x[i] = torch.cat((self.cv2[i](x[i]), self.cv4[i](self.cv3[i](x[i]), text)), 1)
    if self.training:
        return x

    # Inference path
    shape = x[0].shape  # BCHW
    # Flatten every level to (B, 4 * reg_max + nc, H_i * W_i) and join along the anchor axis. This view assumes
    # the contrastive heads produced exactly `self.nc` channels.
    x_cat = torch.cat([xi.view(shape[0], self.nc + self.reg_max * 4, -1) for xi in x], 2)
    # Anchors/strides are cached and recomputed only when the first level's shape changes or `dynamic` is set.
    if self.dynamic or self.shape != shape:
        self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
        self.shape = shape

    # Split channels into the box distribution (4 * reg_max) and class logits (the rest).
    if self.export and self.format in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}:  # avoid TF FlexSplitV ops
        box = x_cat[:, : self.reg_max * 4]
        cls = x_cat[:, self.reg_max * 4 :]
    else:
        box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)

    if self.export and self.format in {"tflite", "edgetpu"}:
        # Precompute normalization factor to increase numerical stability
        # See https://github.com/ultralytics/ultralytics/issues/7371
        # Grid size is taken from the first (level-0) feature map.
        grid_h = shape[2]
        grid_w = shape[3]
        grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1)
        norm = self.strides / (self.stride[0] * grid_size)
        dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2])
    else:
        # Decode in grid units, then scale to pixels by each anchor's stride.
        dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides

    y = torch.cat((dbox, cls.sigmoid()), 1)
    return y if self.export else (y, x)





ultralytics.nn.modules.head.RTDETRDecoder

RTDETRDecoder(
    nc=80,
    ch=(512, 1024, 2048),
    hd=256,
    nq=300,
    ndp=4,
    nh=8,
    ndl=6,
    d_ffn=1024,
    dropout=0.0,
    act=nn.ReLU(),
    eval_idx=-1,
    nd=100,
    label_noise_ratio=0.5,
    box_noise_scale=1.0,
    learnt_init_query=False,
)

Bases: Module

Real-Time Deformable Transformer Decoder (RTDETRDecoder) module for object detection.

This decoder module utilizes Transformer architecture along with deformable convolutions to predict bounding boxes and class labels for objects in an image. It integrates features from multiple layers and runs through a series of Transformer decoder layers to output the final predictions.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| nc | int | Number of classes. Default is 80. | 80 |
| ch | tuple | Channels in the backbone feature maps. Default is (512, 1024, 2048). | (512, 1024, 2048) |
| hd | int | Dimension of hidden layers. Default is 256. | 256 |
| nq | int | Number of query points. Default is 300. | 300 |
| ndp | int | Number of decoder points. Default is 4. | 4 |
| nh | int | Number of heads in multi-head attention. Default is 8. | 8 |
| ndl | int | Number of decoder layers. Default is 6. | 6 |
| d_ffn | int | Dimension of the feed-forward networks. Default is 1024. | 1024 |
| dropout | float | Dropout rate. Default is 0. | 0.0 |
| act | Module | Activation function. Default is nn.ReLU. | ReLU() |
| eval_idx | int | Evaluation index. Default is -1. | -1 |
| nd | int | Number of denoising. Default is 100. | 100 |
| label_noise_ratio | float | Label noise ratio. Default is 0.5. | 0.5 |
| box_noise_scale | float | Box noise scale. Default is 1.0. | 1.0 |
| learnt_init_query | bool | Whether to learn initial query embeddings. Default is False. | False |
Source code in ultralytics/nn/modules/head.py
def __init__(
    self,
    nc=80,
    ch=(512, 1024, 2048),
    hd=256,  # hidden dim
    nq=300,  # num queries
    ndp=4,  # num decoder points
    nh=8,  # num head
    ndl=6,  # num decoder layers
    d_ffn=1024,  # dim of feedforward
    dropout=0.0,
    act=nn.ReLU(),  # evaluated once at definition time: every instance using the default shares this module
    eval_idx=-1,
    # Training args
    nd=100,  # num denoising
    label_noise_ratio=0.5,
    box_noise_scale=1.0,
    learnt_init_query=False,
):
    """
    Initializes the RTDETRDecoder module with the given parameters.

    Args:
        nc (int): Number of classes. Default is 80.
        ch (tuple): Channels in the backbone feature maps. Default is (512, 1024, 2048).
        hd (int): Dimension of hidden layers. Default is 256.
        nq (int): Number of query points. Default is 300.
        ndp (int): Number of decoder points. Default is 4.
        nh (int): Number of heads in multi-head attention. Default is 8.
        ndl (int): Number of decoder layers. Default is 6.
        d_ffn (int): Dimension of the feed-forward networks. Default is 1024.
        dropout (float): Dropout rate. Default is 0.
        act (nn.Module): Activation function. Default is nn.ReLU.
        eval_idx (int): Evaluation index. Default is -1.
        nd (int): Number of denoising. Default is 100.
        label_noise_ratio (float): Label noise ratio. Default is 0.5.
        box_noise_scale (float): Box noise scale. Default is 1.0.
        learnt_init_query (bool): Whether to learn initial query embeddings. Default is False.

    NOTE(review): the default `act` instance is shared across decoders built with the default. `nn.ReLU` holds
    no parameters, so this is harmless for the default, but pass a fresh module for stateful activations.
    Module construction order below determines parameter registration order; keep it stable when editing.
    """
    super().__init__()
    self.hidden_dim = hd
    self.nhead = nh
    self.nl = len(ch)  # num level
    self.nc = nc
    self.num_queries = nq
    self.num_decoder_layers = ndl

    # Backbone feature projection: one 1x1 conv (no bias) + BatchNorm per input level, mapping ch[i] -> hd.
    self.input_proj = nn.ModuleList(nn.Sequential(nn.Conv2d(x, hd, 1, bias=False), nn.BatchNorm2d(hd)) for x in ch)
    # NOTE: simplified version but it's not consistent with .pt weights.
    # self.input_proj = nn.ModuleList(Conv(x, hd, act=False) for x in ch)

    # Transformer module
    decoder_layer = DeformableTransformerDecoderLayer(hd, nh, d_ffn, dropout, act, self.nl, ndp)
    self.decoder = DeformableTransformerDecoder(hd, decoder_layer, ndl, eval_idx)

    # Denoising part: one hd-dim embedding per class; its weight is handed to `get_cdn_group` in `forward`.
    self.denoising_class_embed = nn.Embedding(nc, hd)
    self.num_denoising = nd
    self.label_noise_ratio = label_noise_ratio
    self.box_noise_scale = box_noise_scale

    # Decoder embedding
    self.learnt_init_query = learnt_init_query
    if learnt_init_query:
        self.tgt_embed = nn.Embedding(nq, hd)  # one learnable embedding per query
    # Maps a 4-value input (presumably a reference box) to an hd-dim positional embedding.
    self.query_pos_head = MLP(4, 2 * hd, hd, num_layers=2)

    # Encoder head
    self.enc_output = nn.Sequential(nn.Linear(hd, hd), nn.LayerNorm(hd))
    self.enc_score_head = nn.Linear(hd, nc)
    self.enc_bbox_head = MLP(hd, hd, 4, num_layers=3)

    # Decoder head: a separate score head and box head for each of the ndl decoder layers.
    self.dec_score_head = nn.ModuleList([nn.Linear(hd, nc) for _ in range(ndl)])
    self.dec_bbox_head = nn.ModuleList([MLP(hd, hd, 4, num_layers=3) for _ in range(ndl)])

    # Defined elsewhere; presumably applies custom weight/bias initialization to the modules above.
    self._reset_parameters()

forward

forward(x, batch=None)

Runs the forward pass of the module, returning bounding box and classification scores for the input.

Source code in ultralytics/nn/modules/head.py
def forward(self, x, batch=None):
    """
    Runs the forward pass of the module, returning bounding box and classification scores for the input.

    Args:
        x: Multi-level backbone features, consumed by `_get_encoder_input`.
        batch (dict | None): Forwarded to `get_cdn_group` together with `self.training` to build the
            denoising queries; presumably only needed during training — confirm.

    Returns:
        In training mode, the tuple (dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta). Otherwise `y`, the
        decoder boxes concatenated with sigmoid class scores on the last axis (noted as (bs, 300, 4+nc)).
        `y` is returned alone when exporting, or as `(y, raw_tuple)`.
    """
    from ultralytics.models.utils.ops import get_cdn_group

    # Input projection and embedding
    feats, shapes = self._get_encoder_input(x)

    # Prepare denoising training
    dn_embed, dn_bbox, attn_mask, dn_meta = get_cdn_group(
        batch,
        self.nc,
        self.num_queries,
        self.denoising_class_embed.weight,
        self.num_denoising,
        self.label_noise_ratio,
        self.box_noise_scale,
        self.training,
    )
    embed, refer_bbox, enc_bboxes, enc_scores = self._get_decoder_input(feats, shapes, dn_embed, dn_bbox)

    # Decoder
    dec_bboxes, dec_scores = self.decoder(
        embed,
        refer_bbox,
        feats,
        shapes,
        self.dec_bbox_head,
        self.dec_score_head,
        self.query_pos_head,
        attn_mask=attn_mask,
    )
    raw = (dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta)
    if self.training:
        return raw

    boxes = dec_bboxes.squeeze(0)
    probs = dec_scores.squeeze(0).sigmoid()
    y = torch.cat((boxes, probs), -1)  # (bs, 300, 4+nc)
    if self.export:
        return y
    return y, raw





ultralytics.nn.modules.head.v10Detect

v10Detect(nc=80, ch=())

Bases: Detect

v10 Detection head from https://arxiv.org/pdf/2405.14458.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| nc | int | Number of classes. | 80 |
| ch | tuple | Tuple of channel sizes. | () |

Attributes:

| Name | Type | Description |
|------|------|-------------|
| max_det | int | Maximum number of detections. |

Methods:

| Name | Description |
|------|-------------|
| forward | Performs forward pass of the v10Detect module. |
| bias_init | Initializes biases of the Detect module. |

Source code in ultralytics/nn/modules/head.py
def __init__(self, nc=80, ch=()):
    """
    Initializes the v10Detect object with the specified number of classes and input channels.

    Replaces the parent's class branch with a light head. For each level it applies two (grouped 3x3 Conv with
    g equal to its channel count, then 1x1 Conv) stages, followed by a 1x1 conv to `nc` outputs.
    `one2one_cv3` is an independent deep copy of that branch.

    Args:
        nc (int): Number of classes.
        ch (tuple): Input channels of each detection level.
    """
    super().__init__(nc, ch)
    c3 = max(ch[0], min(self.nc, 100))  # channels
    # Light cls head
    light_heads = []
    for cin in ch:
        light_heads.append(
            nn.Sequential(
                nn.Sequential(Conv(cin, cin, 3, g=cin), Conv(cin, c3, 1)),
                nn.Sequential(Conv(c3, c3, 3, g=c3), Conv(c3, c3, 1)),
                nn.Conv2d(c3, self.nc, 1),
            )
        )
    self.cv3 = nn.ModuleList(light_heads)
    self.one2one_cv3 = copy.deepcopy(self.cv3)



📅 Created 11 months ago ✏️ Updated 1 month ago