Skip to content

Reference for ultralytics/nn/tasks.py

Note

This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/nn/tasks.py. If you spot a problem please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!


ultralytics.nn.tasks.BaseModel

Bases: Module

The BaseModel class serves as a base class for all the models in the Ultralytics YOLO family.

forward

forward(x, *args, **kwargs)

Perform forward pass of the model for either training or inference.

If x is a dict, calculates and returns the loss for training. Otherwise, returns predictions for inference.

Parameters:

Name Type Description Default
x Tensor | dict

Input tensor for inference, or dict with image tensor and labels for training.

required
*args Any

Variable length argument list.

()
**kwargs Any

Arbitrary keyword arguments.

{}

Returns:

Type Description
Tensor

Loss if x is a dict (training), or network predictions (inference).

Source code in ultralytics/nn/tasks.py
def forward(self, x, *args, **kwargs):
    """
    Dispatch a call either to the loss computation or to plain inference.

    A dict input is treated as a training batch and routed to `self.loss`; any other input is treated as an image
    tensor and routed to `self.predict`. Extra positional and keyword arguments are forwarded unchanged.

    Args:
        x (torch.Tensor | dict): Image tensor for inference, or a batch dict (image tensor plus labels) for training.
        *args (Any): Extra positional arguments passed through to the selected method.
        **kwargs (Any): Extra keyword arguments passed through to the selected method.

    Returns:
        (Any): Whatever `self.loss` (dict input) or `self.predict` (any other input) returns.
    """
    # Dict inputs occur during training and during validation-while-training.
    handler = self.loss if isinstance(x, dict) else self.predict
    return handler(x, *args, **kwargs)

fuse

fuse(verbose=True)

Fuse the Conv2d() and BatchNorm2d() layers of the model into a single layer for improved computation efficiency.

Returns:

Type Description
Module

The fused model is returned.

Source code in ultralytics/nn/tasks.py
def fuse(self, verbose=True):
    """
    Fuse the `Conv2d()` and `BatchNorm2d()` layers of the model into a single layer for improved computation
    efficiency.

    The model is modified in place. If `self.is_fused()` already reports the model as fused, nothing is changed and
    no model information is printed.

    Args:
        verbose (bool): Passed to `self.info()` after fusing.

    Returns:
        (torch.nn.Module): The fused model is returned (this same instance, `self`).
    """
    if not self.is_fused():
        for m in self.model.modules():
            if isinstance(m, (Conv, Conv2, DWConv)) and hasattr(m, "bn"):
                # Conv2 calls its own fuse_convs() first (presumably merging its extra conv branch into m.conv —
                # confirm in Conv2) so the BatchNorm fold below sees a single conv.
                if isinstance(m, Conv2):
                    m.fuse_convs()
                m.conv = fuse_conv_and_bn(m.conv, m.bn)  # update conv
                delattr(m, "bn")  # remove batchnorm
                # Instance attribute shadows the class-level forward; forward_fuse presumably skips the removed bn.
                m.forward = m.forward_fuse  # update forward
            if isinstance(m, ConvTranspose) and hasattr(m, "bn"):
                m.conv_transpose = fuse_deconv_and_bn(m.conv_transpose, m.bn)
                delattr(m, "bn")  # remove batchnorm
                m.forward = m.forward_fuse  # update forward
            if isinstance(m, RepConv):
                m.fuse_convs()
                m.forward = m.forward_fuse  # update forward
            if isinstance(m, RepVGGDW):
                m.fuse()
                m.forward = m.forward_fuse
        self.info(verbose=verbose)

    return self

info

info(detailed=False, verbose=True, imgsz=640)

Print model information.

Parameters:

Name Type Description Default
detailed bool

If True, prints out detailed information about the model.

False
verbose bool

If True, prints out the model information.

True
imgsz int

The size of the image that the model will be trained on.

640
Source code in ultralytics/nn/tasks.py
def info(self, detailed=False, verbose=True, imgsz=640):
    """
    Report a summary of the model via `model_info`.

    Args:
        detailed (bool): If True, print detailed information about the model.
        verbose (bool): If True, print the model information.
        imgsz (int): Image size forwarded to `model_info`.

    Returns:
        (Any): Whatever `model_info` returns.
    """
    summary = model_info(self, imgsz=imgsz, verbose=verbose, detailed=detailed)
    return summary

init_criterion

init_criterion()

Placeholder for the BaseModel loss criterion: it always raises NotImplementedError, so task-specific subclasses must override it.

Source code in ultralytics/nn/tasks.py
def init_criterion(self):
    """
    Initialize the loss criterion for the BaseModel.

    The base class defines no loss of its own. Every task-specific subclass must override this method and return a
    callable that `loss` invokes as `criterion(preds, batch)`.

    Raises:
        NotImplementedError: Always, because BaseModel does not define a loss criterion.
    """
    raise NotImplementedError("init_criterion() needs to be implemented by task heads")

is_fused

is_fused(thresh=10)

Check whether the model has fewer than `thresh` normalization layers (any `torch.nn` class whose name contains "Norm", e.g. BatchNorm2d).

Parameters:

Name Type Description Default
thresh int

The threshold number of BatchNorm layers.

10

Returns:

Type Description
bool

True if the number of BatchNorm layers in the model is less than the threshold, False otherwise.

Source code in ultralytics/nn/tasks.py
def is_fused(self, thresh=10):
    """
    Report whether the model contains fewer than `thresh` normalization layers.

    Every class exported by `torch.nn` whose name contains "Norm" (e.g. BatchNorm2d, LayerNorm) is counted.

    Args:
        thresh (int, optional): Count at or above which the model is considered not fused.

    Returns:
        (bool): True if the number of normalization layers is below `thresh`, False otherwise.
    """
    norm_types = tuple(cls for attr_name, cls in vars(torch.nn).items() if "Norm" in attr_name)
    norm_count = sum(1 for module in self.modules() if isinstance(module, norm_types))
    return norm_count < thresh

load

load(weights, verbose=True)

Load weights into the model.

Parameters:

Name Type Description Default
weights dict | Module

The pre-trained weights to be loaded.

required
verbose bool

Whether to log the transfer progress.

True
Source code in ultralytics/nn/tasks.py
def load(self, weights, verbose=True):
    """
    Copy compatible pretrained weights into this model.

    Only the checkpoint entries kept by `intersect_dicts` against this model's own state dict are loaded, and they
    are loaded non-strictly. Note that `.float()` is called on the source module before its state dict is read.

    Args:
        weights (dict | torch.nn.Module): A checkpoint dict holding the module under "model", or the module itself
            (e.g. a torchvision model).
        verbose (bool, optional): Whether to log how many items were transferred.
    """
    if isinstance(weights, dict):  # checkpoint dict: the module is stored under "model"
        source = weights["model"]
    else:  # already a module, e.g. a torchvision model
        source = weights
    csd = intersect_dicts(source.float().state_dict(), self.state_dict())  # FP32 checkpoint entries we can accept
    self.load_state_dict(csd, strict=False)
    if verbose:
        LOGGER.info(f"Transferred {len(csd)}/{len(self.model.state_dict())} items from pretrained weights")

loss

loss(batch, preds=None)

Compute loss.

Parameters:

Name Type Description Default
batch dict

Batch to compute loss on.

required
preds Tensor | List[Tensor]

Predictions.

None
Source code in ultralytics/nn/tasks.py
def loss(self, batch, preds=None):
    """
    Compute the training loss for a batch.

    The criterion is built lazily via `init_criterion` on first use. It is also rebuilt when the attribute exists
    but is None.

    Args:
        batch (dict): Batch to compute loss on; must provide "img" when `preds` is None.
        preds (torch.Tensor | List[torch.Tensor], optional): Precomputed predictions; computed with
            `self.forward(batch["img"])` when None.

    Returns:
        (Any): The result of `self.criterion(preds, batch)`.
    """
    needs_criterion = getattr(self, "criterion", None) is None
    if needs_criterion:
        self.criterion = self.init_criterion()

    if preds is None:
        preds = self.forward(batch["img"])
    return self.criterion(preds, batch)

predict

predict(x, profile=False, visualize=False, augment=False, embed=None)

Perform a forward pass through the network.

Parameters:

Name Type Description Default
x Tensor

The input tensor to the model.

required
profile bool

Print the computation time of each layer if True.

False
visualize bool

Save the feature maps of the model if True.

False
augment bool

Augment image during prediction.

False
embed list

A list of feature vectors/embeddings to return.

None

Returns:

Type Description
Tensor

The last output of the model.

Source code in ultralytics/nn/tasks.py
def predict(self, x, profile=False, visualize=False, augment=False, embed=None):
    """
    Run the network on `x`, optionally with test-time augmentation.

    Args:
        x (torch.Tensor): The input tensor to the model.
        profile (bool): Print the computation time of each layer if True.
        visualize (bool): Save the feature maps of the model if True.
        augment (bool): If truthy, delegate to `self._predict_augment(x)`; the other options are then not used.
        embed (list, optional): A list of feature vectors/embeddings to return.

    Returns:
        (torch.Tensor): The last output of the model.
    """
    if not augment:
        return self._predict_once(x, profile, visualize, embed)
    return self._predict_augment(x)





ultralytics.nn.tasks.DetectionModel

DetectionModel(cfg='yolo11n.yaml', ch=3, nc=None, verbose=True)

Bases: BaseModel

YOLO detection model.

Parameters:

Name Type Description Default
cfg str | dict

Model configuration file path or dictionary.

'yolo11n.yaml'
ch int

Number of input channels.

3
nc int

Number of classes.

None
verbose bool

Whether to display model information.

True
Source code in ultralytics/nn/tasks.py
def __init__(self, cfg="yolo11n.yaml", ch=3, nc=None, verbose=True):  # model, input channels, number of classes
    """
    Initialize the YOLO detection model with the given config and parameters.

    The layer stack is built from the config with `parse_model`. If the last layer is a `Detect` (or a subclass),
    its strides are inferred from a dummy forward pass and its biases are initialized. Otherwise a default stride of
    32 is used.

    Args:
        cfg (str | dict): Model configuration file path or dictionary. A dict is used as-is, so the `ch`/`nc`
            overrides below are written into it. Anything else is loaded with `yaml_model_load`.
        ch (int): Number of input channels, used only if the config does not define `ch`.
        nc (int, optional): Number of classes. If truthy and different from the config's `nc`, it replaces it.
        verbose (bool): Whether to display model information.
    """
    super().__init__()
    self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg)  # cfg dict
    # Older YOLOv9 configs start the backbone with a `Silence` module; substitute nn.Identity for it.
    if self.yaml["backbone"][0][2] == "Silence":
        LOGGER.warning(
            "WARNING ⚠️ YOLOv9 `Silence` module is deprecated in favor of torch.nn.Identity. "
            "Please delete local *.pt file and re-download the latest model checkpoint."
        )
        self.yaml["backbone"][0][2] = "nn.Identity"

    # Define model
    # The config's `ch` wins; the `ch` argument is only the fallback.
    ch = self.yaml["ch"] = self.yaml.get("ch", ch)  # input channels
    if nc and nc != self.yaml["nc"]:
        LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
        self.yaml["nc"] = nc  # override YAML value
    # parse_model receives a deep copy, so self.yaml is unaffected by anything it does to its input.
    self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose)  # model, savelist
    self.names = {i: f"{i}" for i in range(self.yaml["nc"])}  # default names dict: {0: "0", 1: "1", ...}
    self.inplace = self.yaml.get("inplace", True)
    # Read from the head; False when the head has no `end2end` attribute.
    self.end2end = getattr(self.model[-1], "end2end", False)

    # Build strides
    m = self.model[-1]  # Detect()
    if isinstance(m, Detect):  # includes all Detect subclasses like Segment, Pose, OBB, WorldDetect
        s = 256  # 2x min stride
        m.inplace = self.inplace

        def _forward(x):
            """Perform a forward pass through the model, handling different Detect subclass types accordingly."""
            # End-to-end heads produce an output indexed by "one2many" (presumably a dict); use that branch here.
            if self.end2end:
                return self.forward(x)["one2many"]
            # For Segment/Pose/OBB, element [0] of the output holds the per-level detection outputs.
            return self.forward(x)[0] if isinstance(m, (Segment, Pose, OBB)) else self.forward(x)

        # Stride of each output level = input height s / that level's feature-map height, from a blank s x s image.
        m.stride = torch.tensor([s / x.shape[-2] for x in _forward(torch.zeros(1, ch, s, s))])  # forward
        self.stride = m.stride
        m.bias_init()  # only run once
    else:
        self.stride = torch.Tensor([32])  # default stride for i.e. RTDETR

    # Init weights, biases
    initialize_weights(self)
    if verbose:
        self.info()
        LOGGER.info("")

init_criterion

init_criterion()

Initialize the loss criterion for the DetectionModel.

Source code in ultralytics/nn/tasks.py
def init_criterion(self):
    """
    Create the detection loss for this model.

    Returns:
        (E2EDetectLoss | v8DetectionLoss): `E2EDetectLoss` when the model is end-to-end (`self.end2end` truthy),
            otherwise `v8DetectionLoss`.
    """
    if getattr(self, "end2end", False):
        return E2EDetectLoss(self)
    return v8DetectionLoss(self)





ultralytics.nn.tasks.OBBModel

OBBModel(cfg='yolo11n-obb.yaml', ch=3, nc=None, verbose=True)

Bases: DetectionModel

YOLO Oriented Bounding Box (OBB) model.

Parameters:

Name Type Description Default
cfg str | dict

Model configuration file path or dictionary.

'yolo11n-obb.yaml'
ch int

Number of input channels.

3
nc int

Number of classes.

None
verbose bool

Whether to display model information.

True
Source code in ultralytics/nn/tasks.py
def __init__(self, cfg="yolo11n-obb.yaml", ch=3, nc=None, verbose=True):
    """
    Build a YOLO oriented-bounding-box (OBB) model.

    Construction is fully delegated to `DetectionModel.__init__`; this constructor only supplies a different default
    config file.

    Args:
        cfg (str | dict): Model configuration file path or dictionary.
        ch (int): Number of input channels.
        nc (int, optional): Number of classes.
        verbose (bool): Whether to display model information.
    """
    super().__init__(verbose=verbose, nc=nc, ch=ch, cfg=cfg)

init_criterion

init_criterion()

Initialize the loss criterion for the model.

Source code in ultralytics/nn/tasks.py
def init_criterion(self):
    """Create the oriented-box loss (`v8OBBLoss`) used to train this OBB model."""
    criterion = v8OBBLoss(self)
    return criterion





ultralytics.nn.tasks.SegmentationModel

SegmentationModel(cfg='yolo11n-seg.yaml', ch=3, nc=None, verbose=True)

Bases: DetectionModel

YOLO segmentation model.

Parameters:

Name Type Description Default
cfg str | dict

Model configuration file path or dictionary.

'yolo11n-seg.yaml'
ch int

Number of input channels.

3
nc int

Number of classes.

None
verbose bool

Whether to display model information.

True
Source code in ultralytics/nn/tasks.py
def __init__(self, cfg="yolo11n-seg.yaml", ch=3, nc=None, verbose=True):
    """
    Build a YOLO instance-segmentation model.

    Construction is fully delegated to `DetectionModel.__init__`; this constructor only supplies a different default
    config file.

    Args:
        cfg (str | dict): Model configuration file path or dictionary.
        ch (int): Number of input channels.
        nc (int, optional): Number of classes.
        verbose (bool): Whether to display model information.
    """
    super().__init__(verbose=verbose, nc=nc, ch=ch, cfg=cfg)

init_criterion

init_criterion()

Initialize the loss criterion for the SegmentationModel.

Source code in ultralytics/nn/tasks.py
def init_criterion(self):
    """Create the segmentation loss (`v8SegmentationLoss`) used to train this model."""
    criterion = v8SegmentationLoss(self)
    return criterion





ultralytics.nn.tasks.PoseModel

PoseModel(
    cfg="yolo11n-pose.yaml",
    ch=3,
    nc=None,
    data_kpt_shape=(None, None),
    verbose=True,
)

Bases: DetectionModel

YOLO pose model.

Parameters:

Name Type Description Default
cfg str | dict

Model configuration file path or dictionary.

'yolo11n-pose.yaml'
ch int

Number of input channels.

3
nc int

Number of classes.

None
data_kpt_shape tuple

Shape of keypoints data.

(None, None)
verbose bool

Whether to display model information.

True
Source code in ultralytics/nn/tasks.py
def __init__(self, cfg="yolo11n-pose.yaml", ch=3, nc=None, data_kpt_shape=(None, None), verbose=True):
    """
    Build a YOLO pose (keypoint) model, optionally adapting the keypoint shape to the dataset.

    If `data_kpt_shape` contains any truthy entry and differs from the config's `kpt_shape`, the config value is
    overwritten (in place when `cfg` is a dict) before the detection model is built.

    Args:
        cfg (str | dict): Model configuration file path or dictionary.
        ch (int): Number of input channels.
        nc (int, optional): Number of classes.
        data_kpt_shape (tuple): Keypoint shape of the dataset; `(None, None)` keeps the config value.
        verbose (bool): Whether to display model information.
    """
    cfg = cfg if isinstance(cfg, dict) else yaml_model_load(cfg)  # load model YAML when given a path
    # Short-circuit: cfg["kpt_shape"] is only read when a dataset keypoint shape was supplied.
    override_kpt_shape = any(data_kpt_shape) and list(data_kpt_shape) != list(cfg["kpt_shape"])
    if override_kpt_shape:
        LOGGER.info(f"Overriding model.yaml kpt_shape={cfg['kpt_shape']} with kpt_shape={data_kpt_shape}")
        cfg["kpt_shape"] = data_kpt_shape
    super().__init__(verbose=verbose, nc=nc, ch=ch, cfg=cfg)

init_criterion

init_criterion()

Initialize the loss criterion for the PoseModel.

Source code in ultralytics/nn/tasks.py
def init_criterion(self):
    """Create the pose/keypoint loss (`v8PoseLoss`) used to train this model."""
    criterion = v8PoseLoss(self)
    return criterion





ultralytics.nn.tasks.ClassificationModel

ClassificationModel(cfg='yolo11n-cls.yaml', ch=3, nc=None, verbose=True)

Bases: BaseModel

YOLO classification model.

Parameters:

Name Type Description Default
cfg str | dict

Model configuration file path or dictionary.

'yolo11n-cls.yaml'
ch int

Number of input channels.

3
nc int

Number of classes.

None
verbose bool

Whether to display model information.

True
Source code in ultralytics/nn/tasks.py
def __init__(self, cfg="yolo11n-cls.yaml", ch=3, nc=None, verbose=True):
    """
    Initialize ClassificationModel with YAML, channels, number of classes, verbose flag.

    All construction work is delegated to `self._from_yaml` (defined elsewhere), which receives the arguments
    positionally.

    Args:
        cfg (str | dict): Model configuration file path or dictionary.
        ch (int): Number of input channels.
        nc (int, optional): Number of classes.
        verbose (bool): Whether to display model information.
    """
    super().__init__()
    self._from_yaml(cfg, ch, nc, verbose)

init_criterion

init_criterion()

Initialize the loss criterion for the ClassificationModel.

Source code in ultralytics/nn/tasks.py
def init_criterion(self):
    """Create the classification loss (`v8ClassificationLoss`); unlike the other tasks it takes no model argument."""
    criterion = v8ClassificationLoss()
    return criterion

reshape_outputs staticmethod

reshape_outputs(model, nc)

Update a TorchVision classification model to class count `nc` if required.

Parameters:

Name Type Description Default
model Module

Model to update.

required
nc int

New number of classes.

required
Source code in ultralytics/nn/tasks.py
@staticmethod
def reshape_outputs(model, nc):
    """
    Update a TorchVision classification model to class count `nc` if required.

    The last child module of `model` (or of `model.model`, when that attribute exists) is inspected:

    - YOLO `Classify` head: its `linear` layer is replaced.
    - bare `torch.nn.Linear` (e.g. a ResNet `fc`): replaced on the container that owns it.
    - `torch.nn.Sequential`: the last `Linear` inside it is replaced, or, if it has none, the last `Conv2d`.

    Replacement layers are newly constructed and therefore freshly initialized. A layer that already produces `nc`
    outputs is left untouched.

    Args:
        model (torch.nn.Module): Model to update in place.
        nc (int): New number of classes.
    """
    container = model.model if hasattr(model, "model") else model
    name, m = list(container.named_children())[-1]  # last module
    if isinstance(m, Classify):  # YOLO Classify() head
        if m.linear.out_features != nc:
            m.linear = torch.nn.Linear(m.linear.in_features, nc)
    elif isinstance(m, torch.nn.Linear):  # bare Linear head, e.g. ResNet
        if m.out_features != nc:
            # Replace the layer on `container`, the module that actually owns it; setting it on a wrapper `model`
            # would leave `model.model` unchanged and attach an unused layer to the wrapper.
            setattr(container, name, torch.nn.Linear(m.in_features, nc))
    elif isinstance(m, torch.nn.Sequential):
        types = [type(x) for x in m]
        if torch.nn.Linear in types:
            i = len(types) - 1 - types[::-1].index(torch.nn.Linear)  # last torch.nn.Linear index
            if m[i].out_features != nc:
                m[i] = torch.nn.Linear(m[i].in_features, nc)
        elif torch.nn.Conv2d in types:
            i = len(types) - 1 - types[::-1].index(torch.nn.Conv2d)  # last torch.nn.Conv2d index
            conv = m[i]
            if conv.out_channels != nc:
                # Preserve the original geometry so the output spatial size does not change.
                m[i] = torch.nn.Conv2d(
                    conv.in_channels,
                    nc,
                    conv.kernel_size,
                    conv.stride,
                    padding=conv.padding,
                    dilation=conv.dilation,
                    bias=conv.bias is not None,
                    padding_mode=conv.padding_mode,
                )





ultralytics.nn.tasks.RTDETRDetectionModel

RTDETRDetectionModel(cfg='rtdetr-l.yaml', ch=3, nc=None, verbose=True)

Bases: DetectionModel

RT-DETR (Real-Time DEtection TRansformer) detection model class.

This class is responsible for constructing the RT-DETR architecture, defining loss functions, and facilitating both the training and inference processes. RT-DETR is a transformer-based real-time object detection model that extends the DetectionModel base class.

Methods:

Name Description
init_criterion

Initializes the criterion used for loss calculation.

loss

Computes and returns the loss during training.

predict

Performs a forward pass through the network and returns the output.

Parameters:

Name Type Description Default
cfg str | dict

Configuration file name or path.

'rtdetr-l.yaml'
ch int

Number of input channels.

3
nc int

Number of classes.

None
verbose bool

Print additional information during initialization.

True
Source code in ultralytics/nn/tasks.py
def __init__(self, cfg="rtdetr-l.yaml", ch=3, nc=None, verbose=True):
    """
    Build an RT-DETR detection model.

    Construction is fully delegated to `DetectionModel.__init__`; this constructor only supplies the RT-DETR default
    config file.

    Args:
        cfg (str | dict): Configuration file name or path, or a config dictionary.
        ch (int): Number of input channels.
        nc (int, optional): Number of classes.
        verbose (bool): Print additional information during initialization.
    """
    super().__init__(verbose=verbose, nc=nc, ch=ch, cfg=cfg)

init_criterion

init_criterion()

Initialize the loss criterion for the RTDETRDetectionModel.

Source code in ultralytics/nn/tasks.py
def init_criterion(self):
    """Create the RT-DETR loss (`RTDETRDetectionLoss` with `use_vfl=True`) for `self.nc` classes."""
    # Imported locally, only at the point a criterion is actually built.
    from ultralytics.models.utils.loss import RTDETRDetectionLoss

    return RTDETRDetectionLoss(use_vfl=True, nc=self.nc)

loss

loss(batch, preds=None)

Compute the loss for the given batch of data.

Parameters:

Name Type Description Default
batch dict

Dictionary containing image and label data.

required
preds Tensor

Precomputed model predictions.

None

Returns:

Type Description
tuple

A tuple containing the total loss and main three losses in a tensor.

Source code in ultralytics/nn/tasks.py
def loss(self, batch, preds=None):
    """
    Compute the loss for the given batch of data.

    The batch labels are repackaged into the target dict consumed by the criterion, including the number of
    ground-truth boxes per image. Encoder outputs are prepended to the decoder outputs. When `dn_meta` is present,
    the denoising part is split off the decoder outputs and passed to the criterion separately.

    Args:
        batch (dict): Dictionary containing image and label data ("img", "batch_idx", "cls", "bboxes").
        preds (tuple, optional): Precomputed model predictions. In training mode this is the 5-tuple
            (dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta); otherwise element [1] holds that tuple.
            Computed with `self.predict` when None.

    Returns:
        (tuple): The total loss, and a tensor of the main three losses (giou, class, bbox).
    """
    # Use `getattr(..., None) is None` (as BaseModel.loss does) rather than `hasattr`, so a criterion that was
    # cleared to None is rebuilt instead of being called.
    if getattr(self, "criterion", None) is None:
        self.criterion = self.init_criterion()

    img = batch["img"]
    # NOTE: preprocess gt_bbox and gt_labels to list.
    bs = len(img)
    batch_idx = batch["batch_idx"]
    gt_groups = [(batch_idx == i).sum().item() for i in range(bs)]  # number of GT boxes per image
    targets = {
        "cls": batch["cls"].to(img.device, dtype=torch.long).view(-1),
        "bboxes": batch["bboxes"].to(device=img.device),
        "batch_idx": batch_idx.to(img.device, dtype=torch.long).view(-1),
        "gt_groups": gt_groups,
    }

    preds = self.predict(img, batch=targets) if preds is None else preds
    dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta = preds if self.training else preds[1]
    if dn_meta is None:
        dn_bboxes, dn_scores = None, None
    else:
        # Split the denoising queries off the regular decoder queries along dim 2.
        dn_bboxes, dec_bboxes = torch.split(dec_bboxes, dn_meta["dn_num_split"], dim=2)
        dn_scores, dec_scores = torch.split(dec_scores, dn_meta["dn_num_split"], dim=2)

    dec_bboxes = torch.cat([enc_bboxes.unsqueeze(0), dec_bboxes])  # (7, bs, 300, 4)
    dec_scores = torch.cat([enc_scores.unsqueeze(0), dec_scores])

    loss = self.criterion(
        (dec_bboxes, dec_scores), targets, dn_bboxes=dn_bboxes, dn_scores=dn_scores, dn_meta=dn_meta
    )
    # NOTE: There are like 12 losses in RTDETR, backward with all losses but only show the main three losses.
    return sum(loss.values()), torch.as_tensor(
        [loss[k].detach() for k in ["loss_giou", "loss_class", "loss_bbox"]], device=img.device
    )

predict

predict(
    x, profile=False, visualize=False, batch=None, augment=False, embed=None
)

Perform a forward pass through the model.

Parameters:

Name Type Description Default
x Tensor

The input tensor.

required
profile bool

If True, profile the computation time for each layer.

False
visualize bool

If True, save feature maps for visualization.

False
batch dict

Ground truth data for evaluation.

None
augment bool

Accepted for signature compatibility; RT-DETR's predict does not use it.

False
embed list

A list of feature vectors/embeddings to return.

None

Returns:

Type Description
Tensor

Model's output tensor.

Source code in ultralytics/nn/tasks.py
def predict(self, x, profile=False, visualize=False, batch=None, augment=False, embed=None):
    """
    Perform a forward pass through the model.

    Every layer except the last is run in sequence. The head is then called with the saved outputs of the layers
    listed in `head.f`, plus `batch`.

    Args:
        x (torch.Tensor): The input tensor.
        profile (bool): If True, profile the computation time for each layer.
        visualize (bool): If truthy, save feature maps for visualization; the value is passed as `save_dir`.
        batch (dict, optional): Ground truth data, passed to the head as its second argument (`loss` passes its
            target dict here).
        augment (bool): Accepted for signature compatibility; not used by this method.
        embed (list, optional): Layer indices (`m.i`) whose pooled features are collected. When the largest index
            is reached, those embeddings are returned and the head is not run.

    Returns:
        (torch.Tensor | tuple): The head's output, or, when `embed` is given, a tuple of embedding tensors
            (one per element along dim 0).
    """
    y, dt, embeddings = [], [], []  # outputs
    for m in self.model[:-1]:  # except the head part
        if m.f != -1:  # if not from previous layer
            x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
        if profile:
            self._profile_one_layer(m, x, dt)
        x = m(x)  # run
        y.append(x if m.i in self.save else None)  # save output
        if visualize:
            feature_visualization(x, m.type, m.i, save_dir=visualize)
        if embed and m.i in embed:
            embeddings.append(torch.nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1))  # flatten
            if m.i == max(embed):
                return torch.unbind(torch.cat(embeddings, 1), dim=0)
    head = self.model[-1]
    x = head([y[j] for j in head.f], batch)  # head inference
    return x





ultralytics.nn.tasks.WorldModel

WorldModel(cfg='yolov8s-world.yaml', ch=3, nc=None, verbose=True)

Bases: DetectionModel

YOLOv8 World Model.

Parameters:

Name Type Description Default
cfg str | dict

Model configuration file path or dictionary.

'yolov8s-world.yaml'
ch int

Number of input channels.

3
nc int

Number of classes.

None
verbose bool

Whether to display model information.

True
Source code in ultralytics/nn/tasks.py
def __init__(self, cfg="yolov8s-world.yaml", ch=3, nc=None, verbose=True):
    """
    Initialize YOLOv8 world model with given config and parameters.

    Args:
        cfg (str | dict): Model configuration file path or dictionary.
        ch (int): Number of input channels.
        nc (int, optional): Number of classes; also sets the placeholder text-feature count (80 when None).
        verbose (bool): Whether to display model information.
    """
    # These attributes are set BEFORE super().__init__(): the parent constructor may run a dummy forward pass to
    # compute strides, and `predict` reads `self.txt_feats`.
    # Random placeholder of shape (1, nc or 80, 512); 512 presumably matches the CLIP text-embedding width — confirm.
    self.txt_feats = torch.randn(1, nc or 80, 512)  # features placeholder
    self.clip_model = None  # CLIP model placeholder
    super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)

loss

loss(batch, preds=None)

Compute loss.

Parameters:

Name Type Description Default
batch dict

Batch to compute loss on.

required
preds Tensor | List[Tensor]

Predictions.

None
Source code in ultralytics/nn/tasks.py
def loss(self, batch, preds=None):
    """
    Compute loss.

    The criterion is created lazily via `init_criterion` on first use. It is also re-created when the attribute
    exists but is None, matching `BaseModel.loss`.

    Args:
        batch (dict): Batch to compute loss on; when `preds` is None it must provide "img" and "txt_feats".
        preds (torch.Tensor | List[torch.Tensor], optional): Predictions; computed from the batch when None.

    Returns:
        (Any): The result of `self.criterion(preds, batch)`.
    """
    # `getattr(..., None) is None` rather than `hasattr`, so a criterion cleared to None is rebuilt, not called.
    if getattr(self, "criterion", None) is None:
        self.criterion = self.init_criterion()

    if preds is None:
        preds = self.forward(batch["img"], txt_feats=batch["txt_feats"])
    return self.criterion(preds, batch)

predict

predict(
    x, profile=False, visualize=False, txt_feats=None, augment=False, embed=None
)

Perform a forward pass through the model.

Parameters:

Name Type Description Default
x Tensor

The input tensor.

required
profile bool

If True, profile the computation time for each layer.

False
visualize bool

If True, save feature maps for visualization.

False
txt_feats Tensor

The text features, use it if it's given.

None
augment bool

Accepted for signature compatibility; WorldModel's predict does not use it.

False
embed list

A list of feature vectors/embeddings to return.

None

Returns:

Type Description
Tensor

Model's output tensor.

Source code in ultralytics/nn/tasks.py
def predict(self, x, profile=False, visualize=False, txt_feats=None, augment=False, embed=None):
    """
    Perform a forward pass through the model.

    Text features are threaded through the network alongside `x`:
    - `C2fAttn` layers receive the current text features.
    - `ImagePoolingAttn` layers update the text features and leave `x` unchanged.
    - `WorldDetect` receives a copy of the text features taken before any such update.

    Args:
        x (torch.Tensor): The input tensor.
        profile (bool): If True, profile the computation time for each layer.
        visualize (bool): If truthy, save feature maps for visualization; the value is passed as `save_dir`.
        txt_feats (torch.Tensor, optional): Text features to use instead of the stored `self.txt_feats`.
        augment (bool): Accepted for signature compatibility; not used by this method.
        embed (list, optional): Layer indices (`m.i`) whose pooled features are collected. The pass stops and
            returns them once the largest index is reached.

    Returns:
        (torch.Tensor | tuple): Output of the last layer, or, when `embed` is given, a tuple of embedding tensors
            (one per element along dim 0).
    """
    # Use the stored text features unless some were passed in, matched to the input's device and dtype.
    txt_feats = (self.txt_feats if txt_feats is None else txt_feats).to(device=x.device, dtype=x.dtype)
    # Expand along dim 0 to the batch size when the lengths differ, and always when the head is in export mode.
    if len(txt_feats) != len(x) or self.model[-1].export:
        txt_feats = txt_feats.expand(x.shape[0], -1, -1)
    ori_txt_feats = txt_feats.clone()  # snapshot for WorldDetect, unaffected by ImagePoolingAttn updates below
    y, dt, embeddings = [], [], []  # outputs
    for m in self.model:  # all layers, including the head
        if m.f != -1:  # if not from previous layer
            x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
        if profile:
            self._profile_one_layer(m, x, dt)
        if isinstance(m, C2fAttn):
            x = m(x, txt_feats)
        elif isinstance(m, WorldDetect):
            x = m(x, ori_txt_feats)
        elif isinstance(m, ImagePoolingAttn):
            txt_feats = m(x, txt_feats)  # updates the text features only; x is passed through unchanged
        else:
            x = m(x)  # run

        y.append(x if m.i in self.save else None)  # save output
        if visualize:
            feature_visualization(x, m.type, m.i, save_dir=visualize)
        if embed and m.i in embed:
            embeddings.append(torch.nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1))  # flatten
            if m.i == max(embed):
                return torch.unbind(torch.cat(embeddings, 1), dim=0)
    return x

set_classes

set_classes(text, batch=80, cache_clip_model=True)

Set the classes in advance so that the model can run offline inference without the CLIP model.

Parameters:

Name Type Description Default
text List[str]

List of class names.

required
batch int

Batch size for processing text tokens.

80
cache_clip_model bool

Whether to cache the CLIP model.

True
Source code in ultralytics/nn/tasks.py
def set_classes(self, text, batch=80, cache_clip_model=True):
    """
    Precompute CLIP text embeddings for `text` so the model can run offline inference without CLIP.

    The embeddings are L2-normalized and stored in `self.txt_feats`, and the head's `nc` is set to `len(text)`.

    Args:
        text (List[str]): List of class names.
        batch (int): Number of tokenized names encoded per CLIP forward call.
        cache_clip_model (bool): If True, load the CLIP model once and keep it in `self.clip_model`. If False, load
            a fresh one for this call only.
    """
    try:
        import clip
    except ImportError:
        check_requirements("git+https://github.com/ultralytics/CLIP.git")
        import clip

    # Older checkpoints may lack the `clip_model` attribute entirely, hence the getattr probe.
    if cache_clip_model and not getattr(self, "clip_model", None):
        self.clip_model = clip.load("ViT-B/32")[0]
    encoder = self.clip_model if cache_clip_model else clip.load("ViT-B/32")[0]

    device = next(encoder.parameters()).device
    tokens = clip.tokenize(text).to(device)
    chunks = [encoder.encode_text(chunk).detach() for chunk in tokens.split(batch)]
    feats = chunks[0] if len(chunks) == 1 else torch.cat(chunks, dim=0)
    feats = feats / feats.norm(p=2, dim=-1, keepdim=True)  # unit-length embeddings

    self.txt_feats = feats.reshape(-1, len(text), feats.shape[-1])
    self.model[-1].nc = len(text)





ultralytics.nn.tasks.Ensemble

Ensemble()

Bases: ModuleList

Ensemble of models.

Source code in ultralytics/nn/tasks.py
def __init__(self):
    """Initialize an empty ensemble (a ModuleList); member models are expected to be added after construction."""
    super().__init__()

forward

forward(x, augment=False, profile=False, visualize=False)

Run every model in the ensemble on the input and concatenate their predictions.

Parameters:

Name Type Description Default
x Tensor

Input tensor.

required
augment bool

Whether to augment the input.

False
profile bool

Whether to profile the model.

False
visualize bool

Whether to visualize the features.

False

Returns:

Type Description
tuple

Tuple containing the concatenated predictions and None.

Source code in ultralytics/nn/tasks.py
def forward(self, x, augment=False, profile=False, visualize=False):
    """
    Run every member model on `x` and merge their predictions for a joint NMS.

    Args:
        x (torch.Tensor): Input tensor.
        augment (bool): Passed positionally to each member model.
        profile (bool): Passed positionally to each member model.
        visualize (bool): Passed positionally to each member model.

    Returns:
        (tuple): `(merged, None)`, where `merged` concatenates the first output of each member along dim 2.
    """
    per_model = []
    for member in self:
        per_model.append(member(x, augment, profile, visualize)[0])
    # Concatenate rather than max/mean-reduce, so every member's candidates survive until NMS.
    merged = torch.cat(per_model, 2)
    return merged, None  # inference, train output





ultralytics.nn.tasks.SafeClass

SafeClass(*args, **kwargs)

A placeholder class to replace unknown classes during unpickling.

Source code in ultralytics/nn/tasks.py
def __init__(self, *args, **kwargs):
    """Initialize SafeClass instance, ignoring all arguments so any constructor call made during unpickling succeeds."""
    pass

__call__

__call__(*args, **kwargs)

Run SafeClass instance, ignoring all arguments.

Source code in ultralytics/nn/tasks.py
def __call__(self, *args, **kwargs):
    """Run SafeClass instance, ignoring all arguments and returning None, so calling a placeholder is a no-op."""
    pass





ultralytics.nn.tasks.SafeUnpickler

Bases: Unpickler

Custom Unpickler that replaces unknown classes with SafeClass.

find_class

find_class(module, name)

Attempt to find a class, returning SafeClass if not among safe modules.

Parameters:

Name Type Description Default
module str

Module name.

required
name str

Class name.

required

Returns:

Type Description
type

Found class or SafeClass.

Source code in ultralytics/nn/tasks.py
def find_class(self, module, name):
    """
    Resolve a pickled global, substituting SafeClass for anything outside the allow-list.

    Only globals whose module name exactly matches an allow-listed module are resolved by the regular unpickler.
    Every other lookup yields SafeClass, including submodules of allowed packages (e.g. "torch.nn.modules.conv" is
    not "torch").

    Args:
        module (str): Module name.
        name (str): Class name.

    Returns:
        (type): Found class or SafeClass.
    """
    # Exact-match allow-list; add other modules considered safe here.
    # NOTE(review): "builtins" is allowed for every name, which also admits callables such as eval/exec/getattr,
    # so this allow-list on its own does not make loading an untrusted pickle safe. Consider a per-name allow-list.
    trusted = frozenset({"torch", "collections", "collections.abc", "builtins", "math", "numpy"})
    if module not in trusted:
        return SafeClass
    return super().find_class(module, name)





ultralytics.nn.tasks.temporary_modules

temporary_modules(modules=None, attributes=None)

Context manager for temporarily adding or modifying modules in Python's module cache (sys.modules).

This function can be used to change the module paths during runtime. It's useful when refactoring code, where you've moved a module from one location to another, but you still want to support the old import paths for backwards compatibility.

Parameters:

Name Type Description Default
modules dict

A dictionary mapping old module paths to new module paths.

None
attributes dict

A dictionary mapping old module attributes to new module attributes.

None

Examples:

>>> with temporary_modules({"old.module": "new.module"}, {"old.module.attribute": "new.module.attribute"}):
>>> import old.module  # this will now import new.module
>>> from old.module import attribute  # this will now import new.module.attribute
Note

The changes are only in effect inside the context manager and are undone once the context manager exits. Be aware that directly manipulating sys.modules can lead to unpredictable results, especially in larger applications or libraries. Use this function with caution.

Source code in ultralytics/nn/tasks.py
@contextlib.contextmanager
def temporary_modules(modules=None, attributes=None):
    """
    Context manager for temporarily adding or modifying modules in Python's module cache (`sys.modules`).

    This function can be used to change the module paths during runtime. It's useful when refactoring code,
    where you've moved a module from one location to another, but you still want to support the old import
    paths for backwards compatibility.

    On exit, including exit through an exception, every change is reverted:
    - attributes that did not exist before are deleted;
    - overwritten attributes get their previous value back;
    - aliased `sys.modules` entries are removed, or restored to the module they replaced.

    Args:
        modules (dict, optional): A dictionary mapping old module paths to new module paths.
        attributes (dict, optional): A dictionary mapping old module attributes to new module attributes.

    Examples:
        >>> with temporary_modules({"old.module": "new.module"}, {"old.module.attribute": "new.module.attribute"}):
        >>> import old.module  # this will now import new.module
        >>> from old.module import attribute  # this will now import new.module.attribute

    Note:
        The changes are only in effect inside the context manager and are undone once the context manager exits.
        Be aware that directly manipulating `sys.modules` can lead to unpredictable results, especially in larger
        applications or libraries. Use this function with caution.
    """
    if modules is None:
        modules = {}
    if attributes is None:
        attributes = {}
    import sys
    from importlib import import_module

    missing = object()  # sentinel meaning "nothing was there before we patched it"
    saved_attrs = []  # (module object, attribute name, previous value or sentinel), in patch order
    saved_modules = {}  # old module path -> previous sys.modules entry or sentinel

    try:
        # Set attributes under their old name, remembering what was there before
        for old, new in attributes.items():
            old_module, old_attr = old.rsplit(".", 1)
            new_module, new_attr = new.rsplit(".", 1)
            target = import_module(old_module)
            value = getattr(import_module(new_module), new_attr)
            saved_attrs.append((target, old_attr, getattr(target, old_attr, missing)))
            setattr(target, old_attr, value)

        # Set modules in sys.modules under their old name, remembering any entry we overwrite
        for old, new in modules.items():
            replacement = import_module(new)
            saved_modules.setdefault(old, sys.modules.get(old, missing))
            sys.modules[old] = replacement

        yield
    finally:
        # Restore sys.modules: drop aliases we added, put back any entry we replaced
        for old, previous in saved_modules.items():
            if previous is missing:
                sys.modules.pop(old, None)
            else:
                sys.modules[old] = previous

        # Undo attribute patches in reverse order so repeated patches of the same attribute unwind correctly
        for target, attr, previous in reversed(saved_attrs):
            if previous is missing:
                with contextlib.suppress(AttributeError):
                    delattr(target, attr)
            else:
                setattr(target, attr, previous)





ultralytics.nn.tasks.torch_safe_load

torch_safe_load(weight, safe_only=False)

Attempts to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised, it catches the error, logs a warning message, and attempts to install the missing module via the check_requirements() function. After installation, the function again attempts to load the model using torch.load().

Parameters:

Name Type Description Default
weight str

The file path of the PyTorch model.

required
safe_only bool

If True, replace unknown classes with SafeClass during loading.

False

Returns:

Name Type Description
ckpt dict

The loaded model checkpoint.

file str

The loaded filename.

Examples:

>>> from ultralytics.nn.tasks import torch_safe_load
>>> ckpt, file = torch_safe_load("path/to/best.pt", safe_only=True)
Source code in ultralytics/nn/tasks.py
def torch_safe_load(weight, safe_only=False):
    """
    Attempts to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised, it catches the
    error, logs a warning message, and attempts to install the missing module via the check_requirements() function.
    After installation, the function again attempts to load the model using torch.load().

    Both the first attempt and the retry use the same legacy module/attribute aliases and the same `safe_only`
    setting, and both map tensors to the CPU.

    Args:
        weight (str): The file path of the PyTorch model.
        safe_only (bool): If True, replace unknown classes with SafeClass during loading.

    Returns:
        ckpt (dict): The loaded model checkpoint.
        file (str): The loaded filename.

    Raises:
        TypeError: If loading fails because a top-level `models` module is missing, which indicates a checkpoint
            trained with https://github.com/ultralytics/yolov5.

    Examples:
        >>> from ultralytics.nn.tasks import torch_safe_load
        >>> ckpt, file = torch_safe_load("path/to/best.pt", safe_only=True)
    """
    from ultralytics.utils.downloads import attempt_download_asset

    check_suffix(file=weight, suffix=".pt")
    file = attempt_download_asset(weight)  # search online if missing locally

    def _load():
        """Load `file` with legacy module/attribute aliases active, honouring `safe_only`."""
        with temporary_modules(
            modules={
                "ultralytics.yolo.utils": "ultralytics.utils",
                "ultralytics.yolo.v8": "ultralytics.models.yolo",
                "ultralytics.yolo.data": "ultralytics.data",
            },
            attributes={
                "ultralytics.nn.modules.block.Silence": "torch.nn.Identity",  # YOLOv9e
                "ultralytics.nn.tasks.YOLOv10DetectionModel": "ultralytics.nn.tasks.DetectionModel",  # YOLOv10
                "ultralytics.utils.loss.v10DetectLoss": "ultralytics.utils.loss.E2EDetectLoss",  # YOLOv10
            },
        ):
            if safe_only:
                # Load via custom pickle module
                safe_pickle = types.ModuleType("safe_pickle")
                safe_pickle.Unpickler = SafeUnpickler
                safe_pickle.load = lambda file_obj: SafeUnpickler(file_obj).load()
                with open(file, "rb") as f:
                    return torch.load(f, map_location="cpu", pickle_module=safe_pickle)
            return torch.load(file, map_location="cpu")

    try:
        ckpt = _load()
    except ModuleNotFoundError as e:  # e.name is missing module name
        if e.name == "models":
            raise TypeError(
                emojis(
                    f"ERROR ❌️ {weight} appears to be an Ultralytics YOLOv5 model originally trained "
                    f"with https://github.com/ultralytics/yolov5.\nThis model is NOT forwards compatible with "
                    f"YOLOv8 at https://github.com/ultralytics/ultralytics."
                    f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "
                    f"run a command with an official Ultralytics model, i.e. 'yolo predict model=yolo11n.pt'"
                )
            ) from e
        LOGGER.warning(
            f"WARNING ⚠️ {weight} appears to require '{e.name}', which is not in Ultralytics requirements."
            f"\nAutoInstall will run now for '{e.name}' but this feature will be removed in the future."
            f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "
            f"run a command with an official Ultralytics model, i.e. 'yolo predict model=yolo11n.pt'"
        )
        check_requirements(e.name)  # install missing module
        ckpt = _load()  # retry with the same aliases and safe_only setting as the first attempt

    if not isinstance(ckpt, dict):
        # File is likely a YOLO instance saved with i.e. torch.save(model, "saved_model.pt")
        LOGGER.warning(
            f"WARNING ⚠️ The file '{weight}' appears to be improperly saved or formatted. "
            f"For optimal results, use model.save('filename.pt') to correctly save YOLO models."
        )
        ckpt = {"model": ckpt.model}

    return ckpt, file





ultralytics.nn.tasks.attempt_load_weights

attempt_load_weights(weights, device=None, inplace=True, fuse=False)

Load an ensemble of models (weights=[a, b, c]) or a single model (weights=[a] or weights=a).

Parameters:

Name Type Description Default
weights str | List[str]

Model weights path(s).

required
device device

Device to load model to.

None
inplace bool

Whether to do inplace operations.

True
fuse bool

Whether to fuse model.

False

Returns:

Type Description
Module

Loaded model.

Source code in ultralytics/nn/tasks.py
def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
    """
    Load one or more checkpoints and, when several are given, combine them into an Ensemble.

    A single path (``weights=a``) or a one-element list (``weights=[a]``) returns that model directly. A longer list
    (``weights=[a, b, c]``) returns an Ensemble that:
    - takes ``names``, ``nc`` and ``yaml`` from its first member;
    - takes ``stride`` from the member with the largest maximum stride.

    Args:
        weights (str | List[str]): Model weights path(s).
        device (torch.device, optional): Device to load model to.
        inplace (bool): Whether to do inplace operations.
        fuse (bool): Whether to fuse model.

    Returns:
        (torch.nn.Module): Loaded model, or an Ensemble when more than one model was loaded.

    Raises:
        AssertionError: If the ensembled models do not all have the same number of classes.
    """
    paths = weights if isinstance(weights, list) else [weights]
    ensemble = Ensemble()
    for path in paths:
        ckpt, path = torch_safe_load(path)  # load ckpt
        if "train_args" in ckpt:
            combined_args = {**DEFAULT_CFG_DICT, **ckpt["train_args"]}  # checkpoint args override defaults
        else:
            combined_args = None
        model = (ckpt.get("ema") or ckpt["model"]).to(device).float()  # FP32 model

        # Model compatibility updates
        model.args = combined_args  # attach args to model
        model.pt_path = path  # attach *.pt file path to model
        model.task = guess_model_task(model)
        if not hasattr(model, "stride"):
            model.stride = torch.tensor([32.0])

        # Optionally fuse, then append in eval mode
        if fuse and hasattr(model, "fuse"):
            model = model.fuse()
        ensemble.append(model.eval())

    # Module updates
    for module in ensemble.modules():
        if hasattr(module, "inplace"):
            module.inplace = inplace
        elif isinstance(module, torch.nn.Upsample) and not hasattr(module, "recompute_scale_factor"):
            module.recompute_scale_factor = None  # torch 1.11.0 compatibility

    # A single model is returned as-is
    if len(ensemble) == 1:
        return ensemble[-1]

    # Otherwise configure and return the ensemble
    LOGGER.info(f"Ensemble created with {weights}\n")
    first = ensemble[0]
    for attr in ("names", "nc", "yaml"):
        setattr(ensemble, attr, getattr(first, attr))
    max_strides = torch.tensor([member.stride.max() for member in ensemble])
    ensemble.stride = ensemble[int(torch.argmax(max_strides))].stride
    class_counts_match = all(first.nc == member.nc for member in ensemble)
    assert class_counts_match, f"Models differ in class counts {[member.nc for member in ensemble]}"
    return ensemble





ultralytics.nn.tasks.attempt_load_one_weight

attempt_load_one_weight(weight, device=None, inplace=True, fuse=False)

Load a single model's weights.

Parameters:

Name Type Description Default
weight str

Model weight path.

required
device device

Device to load model to.

None
inplace bool

Whether to do inplace operations.

True
fuse bool

Whether to fuse model.

False

Returns:

Type Description
tuple

Tuple containing the model and checkpoint.

Source code in ultralytics/nn/tasks.py
def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
    """
    Load a single model's weights and prepare the model for inference.

    Training args stored in the checkpoint take precedence over the defaults. Only keys present in the default
    config are attached to the model.

    Args:
        weight (str): Model weight path.
        device (torch.device, optional): Device to load model to.
        inplace (bool): Whether to do inplace operations.
        fuse (bool): Whether to fuse model.

    Returns:
        (tuple): Tuple containing the model and checkpoint.
    """
    ckpt, weight = torch_safe_load(weight)  # load ckpt
    merged = {**DEFAULT_CFG_DICT, **(ckpt.get("train_args", {}))}  # model args win over defaults
    model = (ckpt.get("ema") or ckpt["model"]).to(device).float()  # FP32 model

    # Model compatibility updates
    model.args = {key: value for key, value in merged.items() if key in DEFAULT_CFG_KEYS}  # attach args to model
    model.pt_path = weight  # attach *.pt file path to model
    model.task = guess_model_task(model)
    if not hasattr(model, "stride"):
        model.stride = torch.tensor([32.0])

    # Optionally fuse, then switch to eval mode
    if fuse and hasattr(model, "fuse"):
        model = model.fuse()
    model = model.eval()

    # Module updates
    for module in model.modules():
        if hasattr(module, "inplace"):
            module.inplace = inplace
        elif isinstance(module, torch.nn.Upsample) and not hasattr(module, "recompute_scale_factor"):
            module.recompute_scale_factor = None  # torch 1.11.0 compatibility

    return model, ckpt





ultralytics.nn.tasks.parse_model

parse_model(d, ch, verbose=True)

Parse a YOLO model.yaml dictionary into a PyTorch model.

Parameters:

Name Type Description Default
d dict

Model dictionary.

required
ch int

Input channels.

required
verbose bool

Whether to print model details.

True

Returns:

Type Description
tuple

Tuple containing the PyTorch model and sorted list of output layers.

Source code in ultralytics/nn/tasks.py
def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
    """
    Parse a YOLO model.yaml dictionary into a PyTorch model.

    Each row of ``d["backbone"] + d["head"]`` is read as ``[from, number, module, args]``. String entries in ``args``
    are resolved as follows:
    - a string naming a local variable of this function (e.g. ``"nc"``) is replaced by that variable's value;
    - any other string is passed through ``ast.literal_eval``, and left unchanged if that raises ValueError.

    Args:
        d (dict): Model dictionary.
        ch (int): Input channels.
        verbose (bool): Whether to print model details.

    Returns:
        (tuple): Tuple containing the PyTorch model and sorted list of output layers.
    """
    import ast

    # Args
    legacy = True  # backward compatibility for v3/v5/v8/v9 models
    max_channels = float("inf")
    nc, act, scales = (d.get(x) for x in ("nc", "activation", "scales"))
    depth, width, kpt_shape = (d.get(x, 1.0) for x in ("depth_multiple", "width_multiple", "kpt_shape"))
    # Always bind `scale`. The C3k2/A2C2f branches below read it even when the YAML has no `scales` table,
    # which previously raised UnboundLocalError.
    scale = d.get("scale")
    if scales:
        if not scale:
            scale = tuple(scales.keys())[0]
            LOGGER.warning(f"WARNING ⚠️ no model scale passed. Assuming scale='{scale}'.")
        depth, width, max_channels = scales[scale]

    if act:
        # NOTE(review): eval() executes the YAML "activation" string as code; only load trusted model YAMLs.
        Conv.default_act = eval(act)  # redefine default activation, i.e. Conv.default_act = torch.nn.SiLU()
        if verbose:
            LOGGER.info(f"{colorstr('activation:')} {act}")  # print

    if verbose:
        LOGGER.info(f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10}  {'module':<45}{'arguments':<30}")
    ch = [ch]
    layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
    base_modules = frozenset(
        {
            Classify,
            Conv,
            ConvTranspose,
            GhostConv,
            Bottleneck,
            GhostBottleneck,
            SPP,
            SPPF,
            C2fPSA,
            C2PSA,
            DWConv,
            Focus,
            BottleneckCSP,
            C1,
            C2,
            C2f,
            C3k2,
            RepNCSPELAN4,
            ELAN1,
            ADown,
            AConv,
            SPPELAN,
            C2fAttn,
            C3,
            C3TR,
            C3Ghost,
            torch.nn.ConvTranspose2d,
            DWConvTranspose2d,
            C3x,
            RepC3,
            PSA,
            SCDown,
            C2fCIB,
            A2C2f,
        }
    )
    repeat_modules = frozenset(  # modules with 'repeat' arguments
        {
            BottleneckCSP,
            C1,
            C2,
            C2f,
            C3k2,
            C2fAttn,
            C3,
            C3TR,
            C3Ghost,
            C3x,
            RepC3,
            C2fPSA,
            C2fCIB,
            C2PSA,
            A2C2f,
        }
    )
    for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]):  # from, number, module, args
        m = (
            getattr(torch.nn, m[3:])
            if "nn." in m
            else getattr(__import__("torchvision").ops, m[16:])
            if "torchvision.ops." in m
            else globals()[m]
        )  # get module
        for j, a in enumerate(args):
            if isinstance(a, str):
                with contextlib.suppress(ValueError):
                    # Names are looked up in this function's locals (e.g. "nc"), so renaming locals here
                    # changes how YAML args are interpreted.
                    args[j] = locals()[a] if a in locals() else ast.literal_eval(a)
        n = n_ = max(round(n * depth), 1) if n > 1 else n  # depth gain
        if m in base_modules:
            c1, c2 = ch[f], args[0]
            if c2 != nc:  # if c2 not equal to number of classes (i.e. for Classify() output)
                c2 = make_divisible(min(c2, max_channels) * width, 8)
            if m is C2fAttn:  # set 1) embed channels and 2) num heads
                args[1] = make_divisible(min(args[1], max_channels // 2) * width, 8)
                args[2] = int(max(round(min(args[2], max_channels // 2 // 32)) * width, 1) if args[2] > 1 else args[2])

            args = [c1, c2, *args[1:]]
            if m in repeat_modules:
                args.insert(2, n)  # number of repeats
                n = 1
            if m is C3k2:  # for M/L/X sizes
                legacy = False
                if scale and scale in "mlx":  # `scale` may be None/"" when no scale is known
                    args[3] = True
            if m is A2C2f:
                legacy = False
                if scale and scale in "lx":  # for L/X sizes
                    args.extend((True, 1.2))
        elif m is AIFI:
            args = [ch[f], *args]
        elif m in frozenset({HGStem, HGBlock}):
            c1, cm, c2 = ch[f], args[0], args[1]
            args = [c1, cm, c2, *args[2:]]
            if m is HGBlock:
                args.insert(4, n)  # number of repeats
                n = 1
        elif m is ResNetLayer:
            c2 = args[1] if args[3] else args[1] * 4
        elif m is torch.nn.BatchNorm2d:
            args = [ch[f]]
        elif m is Concat:
            c2 = sum(ch[x] for x in f)
        elif m in frozenset({Detect, WorldDetect, Segment, Pose, OBB, ImagePoolingAttn, v10Detect}):
            args.append([ch[x] for x in f])
            if m is Segment:
                args[2] = make_divisible(min(args[2], max_channels) * width, 8)
            if m in {Detect, Segment, Pose, OBB}:
                m.legacy = legacy
        elif m is RTDETRDecoder:  # special case, channels arg must be passed in index 1
            args.insert(1, [ch[x] for x in f])
        elif m is CBLinear:
            c2 = args[0]
            c1 = ch[f]
            args = [c1, c2, *args[1:]]
        elif m is CBFuse:
            c2 = ch[f[-1]]
        elif m in frozenset({TorchVision, Index}):
            c2 = args[0]
            c1 = ch[f]
            args = [*args[1:]]
        else:
            c2 = ch[f]

        m_ = torch.nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)  # module
        t = str(m)[8:-2].replace("__main__.", "")  # module type
        m_.np = sum(x.numel() for x in m_.parameters())  # number params
        m_.i, m_.f, m_.type = i, f, t  # attach index, 'from' index, type
        if verbose:
            LOGGER.info(f"{i:>3}{str(f):>20}{n_:>3}{m_.np:10.0f}  {t:<45}{str(args):<30}")  # print
        save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
        layers.append(m_)
        if i == 0:
            ch = []
        ch.append(c2)
    return torch.nn.Sequential(*layers), sorted(save)





ultralytics.nn.tasks.yaml_model_load

yaml_model_load(path)

Load a YOLOv8 model from a YAML file.

Parameters:

Name Type Description Default
path str | Path

Path to the YAML file.

required

Returns:

Type Description
dict

Model dictionary.

Source code in ultralytics/nn/tasks.py
def yaml_model_load(path):
    """
    Load a YOLO model dictionary from a YAML file.

    Resolution steps:
    - Legacy P6 names such as ``yolov8x6`` are first renamed to the ``-p6`` form.
    - A scale-specific file name such as ``yolov8n.yaml`` is then resolved to its scale-agnostic config
      (``yolov8.yaml``) in the same directory.
    - If that config cannot be found, the (possibly renamed) path is used as-is.

    Args:
        path (str | Path): Path to the YAML file.

    Returns:
        (dict): Model dictionary, with ``scale`` (guessed from the file name) and ``yaml_file`` (the possibly renamed
            path) added.
    """
    path = Path(path)
    if path.stem in (f"yolov{d}{x}6" for x in "nsmlx" for d in (5, 8)):
        new_stem = re.sub(r"(\d+)([nslmx])6(.+)?$", r"\1\2-p6\3", path.stem)
        LOGGER.warning(f"WARNING ⚠️ Ultralytics YOLO P6 models now use -p6 suffix. Renaming {path.stem} to {new_stem}.")
        path = path.with_name(new_stem + path.suffix)

    # Strip the scale letter from the file *name* only, i.e. yolov8x.yaml -> yolov8.yaml. Matching against the whole
    # path would also rewrite directory names containing a digit followed by n/s/m/l/x (e.g. "v5models/").
    unified_name = re.sub(r"(\d+)([nslmx])(.+)?$", r"\1\3", path.name)
    unified_path = str(path.parent / unified_name)
    yaml_file = check_yaml(unified_path, hard=False) or check_yaml(path)
    d = yaml_load(yaml_file)  # model dict
    d["scale"] = guess_model_scale(path)
    d["yaml_file"] = str(path)
    return d





ultralytics.nn.tasks.guess_model_scale

guess_model_scale(model_path)

Extract the size character n, s, m, l, or x of the model's scale from the model path.

Parameters:

Name Type Description Default
model_path str | Path

The path to the YOLO model's YAML file.

required

Returns:

Type Description
str

The size character of the model's scale (n, s, m, l, or x).

Source code in ultralytics/nn/tasks.py
def guess_model_scale(model_path):
    """
    Return the scale letter (n, s, m, l or x) embedded in a YOLO model file name.

    The letter is the one directly following the version number in the file stem, e.g. ``yolov8n.yaml`` -> ``"n"``
    and ``yolo11x-seg.pt`` -> ``"x"``. Stems without such a letter yield an empty string.

    Args:
        model_path (str | Path): The path to the YOLO model's YAML file.

    Returns:
        (str): The size character of the model's scale (n, s, m, l, or x), or "" if none is found.
    """
    found = re.search(r"yolo[v]?\d+([nslmx])", Path(model_path).stem)
    return found.group(1) if found is not None else ""





ultralytics.nn.tasks.guess_model_task

guess_model_task(model)

Guess the task of a PyTorch model from its architecture or configuration.

Parameters:

Name Type Description Default
model Module | dict

PyTorch model or model configuration in YAML format.

required

Returns:

Type Description
str

Task of the model ('detect', 'segment', 'classify', 'pose', 'obb').

Source code in ultralytics/nn/tasks.py
def guess_model_task(model):
    """
    Guess the task of a PyTorch model from its architecture or configuration.

    Sources are tried in this order; a source that yields no recognised task falls through to the next:
    1. a YAML config dict;
    2. `args["task"]` on the model or its nested `.model` attributes;
    3. the model's `yaml` config;
    4. the types of the model's head modules;
    5. the file name.
    If none of these yields a task, a warning is logged and "detect" is returned.

    Args:
        model (torch.nn.Module | dict): PyTorch model or model configuration in YAML format.

    Returns:
        (str): Task of the model ('detect', 'segment', 'classify', 'pose', 'obb').
    """

    def cfg2task(cfg):
        """Guess from YAML dictionary; return None if the output head is not recognised."""
        m = cfg["head"][-1][-2].lower()  # output module name
        if m in {"classify", "classifier", "cls", "fc"}:
            return "classify"
        if "detect" in m:
            return "detect"
        if m == "segment":
            return "segment"
        if m == "pose":
            return "pose"
        if m == "obb":
            return "obb"
        return None

    def resolve(obj, attrs):
        """Follow a chain of attribute names from obj, e.g. ('model', 'args') -> obj.model.args."""
        for attr in attrs:
            obj = getattr(obj, attr)
        return obj

    # Guess from model cfg (only accept a recognised task; otherwise keep trying)
    if isinstance(model, dict):
        with contextlib.suppress(Exception):
            task = cfg2task(model)
            if task:
                return task

    # Guess from PyTorch model
    if isinstance(model, torch.nn.Module):  # PyTorch model
        prefixes = ((), ("model",), ("model", "model"))  # model, model.model, model.model.model
        for prefix in prefixes:
            with contextlib.suppress(Exception):
                task = resolve(model, (*prefix, "args"))["task"]
                if task:
                    return task
        for prefix in prefixes:
            with contextlib.suppress(Exception):
                task = cfg2task(resolve(model, (*prefix, "yaml")))
                if task:
                    return task
        for m in model.modules():
            # Segment/Pose/OBB are checked before the generic Detect types
            if isinstance(m, Segment):
                return "segment"
            elif isinstance(m, Classify):
                return "classify"
            elif isinstance(m, Pose):
                return "pose"
            elif isinstance(m, OBB):
                return "obb"
            elif isinstance(m, (Detect, WorldDetect, v10Detect)):
                return "detect"

    # Guess from model filename
    if isinstance(model, (str, Path)):
        model = Path(model)
        if "-seg" in model.stem or "segment" in model.parts:
            return "segment"
        elif "-cls" in model.stem or "classify" in model.parts:
            return "classify"
        elif "-pose" in model.stem or "pose" in model.parts:
            return "pose"
        elif "-obb" in model.stem or "obb" in model.parts:
            return "obb"
        elif "detect" in model.parts:
            return "detect"

    # Unable to determine task from model
    LOGGER.warning(
        "WARNING ⚠️ Unable to automatically guess model task, assuming 'task=detect'. "
        "Explicitly define task for your model, i.e. 'task=detect', 'segment', 'classify','pose' or 'obb'."
    )
    return "detect"  # assume detect



📅 Created 1 year ago ✏️ Updated 6 months ago