Skip to content

Reference for ultralytics/nn/tasks.py

Improvements

This page is sourced from https://github.com/ultralytics/ultralytics/blob/main/ultralytics/nn/tasks.py. Have an improvement or example to add? Open a Pull Request — thank you! 🙏


class ultralytics.nn.tasks.BaseModel

BaseModel()

Bases: torch.nn.Module

Base class for all YOLO models in the Ultralytics family.

This class provides common functionality for YOLO models including forward pass handling, model fusion, information display, and weight loading capabilities.

Attributes

NameTypeDescription
modeltorch.nn.ModuleThe neural network model.
savelistList of layer indices to save outputs from.
stridetorch.TensorModel stride values.

Methods

NameDescription
_applyApply a function to all tensors in the model that are not parameters or registered buffers.
_predict_augmentPerform augmentations on input image x and return augmented inference.
_predict_oncePerform a forward pass through the network.
_profile_one_layerProfile the computation time and FLOPs of a single layer of the model on a given input.
forwardPerform forward pass of the model for either training or inference.
fuseFuse the Conv2d() and BatchNorm2d() layers of the model into a single layer for improved computation efficiency.
infoPrint model information.
init_criterionInitialize the loss criterion for the BaseModel.
is_fusedCheck whether the model has fewer normalization layers (e.g. BatchNorm) than a given threshold.
loadLoad weights into the model.
lossCompute loss.
predictPerform a forward pass through the network.

Examples

Create a BaseModel instance
>>> model = BaseModel()
>>> model.info()  # Display model information
Source code in ultralytics/nn/tasks.pyView on GitHub
class BaseModel(torch.nn.Module):


method ultralytics.nn.tasks.BaseModel._apply

def _apply(self, fn)

Apply a function to all tensors in the model that are not parameters or registered buffers.

Args

NameTypeDescriptionDefault
fnfunctionThe function to apply to the model.required

Returns

TypeDescription
BaseModelAn updated BaseModel object.
Source code in ultralytics/nn/tasks.pyView on GitHub
def _apply(self, fn):
    """Apply ``fn`` to the model, also converting Detect-head tensors that are not parameters or buffers.

    ``torch.nn.Module._apply`` only visits parameters and registered buffers, so the ``stride``, ``anchors`` and
    ``strides`` attributes of a final Detect head are passed through ``fn`` explicitly here.

    Args:
        fn (function): The function to apply to the model.

    Returns:
        (BaseModel): An updated BaseModel object.
    """
    self = super()._apply(fn)
    head = self.model[-1]  # Detect() head, when the model has one
    # isinstance covers all Detect subclasses like Segment, Pose, OBB, WorldDetect, YOLOEDetect, YOLOESegment
    if not isinstance(head, Detect):
        return self
    for attr_name in ("stride", "anchors", "strides"):
        setattr(head, attr_name, fn(getattr(head, attr_name)))
    return self


method ultralytics.nn.tasks.BaseModel._predict_augment

def _predict_augment(self, x)

Perform augmentations on input image x and return augmented inference.

Args

NameTypeDescriptionDefault
xtorch.TensorThe input image tensor.required
Source code in ultralytics/nn/tasks.pyView on GitHub
def _predict_augment(self, x):
    """Fall back to single-scale inference, since this model type has no test-time augmentation.

    A warning naming the concrete model class is logged before delegating to ``self._predict_once``.

    Args:
        x (torch.Tensor): Input image tensor.

    Returns:
        The output of ``self._predict_once(x)``.
    """
    model_name = type(self).__name__
    message = (
        f"{model_name} does not support 'augment=True' prediction. "
        "Reverting to single-scale prediction."
    )
    LOGGER.warning(message)
    return self._predict_once(x)


method ultralytics.nn.tasks.BaseModel._predict_once

def _predict_once(self, x, profile = False, visualize = False, embed = None)

Perform a forward pass through the network.

Args

NameTypeDescriptionDefault
xtorch.TensorThe input tensor to the model.required
profileboolPrint the computation time of each layer if True.False
visualizeboolSave the feature maps of the model if True.False
embedlist, optionalA list of feature vectors/embeddings to return.None

Returns

TypeDescription
torch.Tensor | tupleThe last output of the model, or a tuple of per-image embedding tensors when embed layer indices are given.
Source code in ultralytics/nn/tasks.pyView on GitHub
def _predict_once(self, x, profile=False, visualize=False, embed=None):
    """Perform a forward pass through the network.

    Args:
        x (torch.Tensor): The input tensor to the model.
        profile (bool): Print the computation time of each layer if True.
        visualize (bool): Save the feature maps of the model if True.
        embed (list, optional): A list of feature vectors/embeddings to return.

    Returns:
        (torch.Tensor): The last output of the model.
    """
    y, dt, embeddings = [], [], []  # outputs
    embed = frozenset(embed) if embed is not None else {-1}
    max_idx = max(embed)
    for m in self.model:
        if m.f != -1:  # if not from previous layer
            x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
        if profile:
            self._profile_one_layer(m, x, dt)
        x = m(x)  # run
        y.append(x if m.i in self.save else None)  # save output
        if visualize:
            feature_visualization(x, m.type, m.i, save_dir=visualize)
        if m.i in embed:
            embeddings.append(torch.nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1))  # flatten
            if m.i == max_idx:
                return torch.unbind(torch.cat(embeddings, 1), dim=0)
    return x


method ultralytics.nn.tasks.BaseModel._profile_one_layer

def _profile_one_layer(self, m, x, dt)

Profile the computation time and FLOPs of a single layer of the model on a given input.

Args

NameTypeDescriptionDefault
mtorch.nn.ModuleThe layer to be profiled.required
xtorch.TensorThe input data to the layer.required
dtlistA list to store the computation time of the layer.required
Source code in ultralytics/nn/tasks.pyView on GitHub
def _profile_one_layer(self, m, x, dt):
    """Log the run time, GFLOPs and parameter count of a single layer.

    GFLOPs are measured with ``thop`` when it is importable (0 otherwise). The layer is then run 10 times and the
    time is appended to ``dt``. A header row is logged for the first layer and a total row for a final layer that
    receives a list input.

    Args:
        m (torch.nn.Module): The layer to be profiled.
        x (torch.Tensor): The input data to the layer.
        dt (list): Running list of per-layer times; this layer's time in ms is appended.
    """
    try:
        import thop
    except ImportError:
        thop = None  # conda support without 'ultralytics-thop' installed

    # final layer with a list input: hand each call its own shallow copy of the list (in-place fix)
    is_last_with_list = m == self.model[-1] and isinstance(x, list)

    def fresh_input():
        return x.copy() if is_last_with_list else x

    flops = 0
    if thop:
        flops = thop.profile(m, inputs=[fresh_input()], verbose=False)[0] / 1e9 * 2  # GFLOPs

    start = time_sync()
    for _ in range(10):
        m(fresh_input())
    dt.append((time_sync() - start) * 100)  # 10 runs: seconds * 1000 / 10 -> ms per run

    if m == self.model[0]:
        LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s}  module")
    LOGGER.info(f"{dt[-1]:10.2f} {flops:10.2f} {m.np:10.0f}  {m.type}")
    if is_last_with_list:
        LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s}  Total")


method ultralytics.nn.tasks.BaseModel.forward

def forward(self, x, *args, **kwargs)

Perform forward pass of the model for either training or inference.

If x is a dict, calculates and returns the loss for training. Otherwise, returns predictions for inference.

Args

NameTypeDescriptionDefault
xtorch.Tensor | dictInput tensor for inference, or dict with image tensor and labels for training.required
*argsAnyAdditional positional arguments passed on to loss() or predict().()
**kwargsAnyAdditional keyword arguments passed on to loss() or predict().{}

Returns

TypeDescription
torch.TensorLoss if x is a dict (training), or network predictions (inference).
Source code in ultralytics/nn/tasks.pyView on GitHub
def forward(self, x, *args, **kwargs):
    """Route the input to the training loss or to inference depending on its type.

    A dict input is treated as a batch for training (or validation during training) and sent to ``self.loss``.
    Any other input is sent to ``self.predict``.

    Args:
        x (torch.Tensor | dict): Input tensor for inference, or dict with image tensor and labels for training.
        *args (Any): Extra positional arguments forwarded to the selected method.
        **kwargs (Any): Extra keyword arguments forwarded to the selected method.

    Returns:
        (torch.Tensor): Loss if x is a dict (training), or network predictions (inference).
    """
    handler = self.loss if isinstance(x, dict) else self.predict
    return handler(x, *args, **kwargs)


method ultralytics.nn.tasks.BaseModel.fuse

def fuse(self, verbose = True)

Fuse the Conv2d() and BatchNorm2d() layers of the model into a single layer for improved computation efficiency.

Args

NameTypeDescriptionDefault
verboseboolWhether to print model information after fusing.True

Returns

TypeDescription
torch.nn.ModuleThe fused model is returned.
Source code in ultralytics/nn/tasks.pyView on GitHub
def fuse(self, verbose=True):
    """Fuse the `Conv2d()` and `BatchNorm2d()` layers of the model into a single layer for improved computation
    efficiency.

    Conv/Conv2/DWConv and ConvTranspose modules that still own a ``bn`` get it folded into their convolution.
    RepConv, RepVGGDW and v10Detect modules call their own fuse helpers. Fused modules switch to ``forward_fuse``.
    Nothing happens if ``self.is_fused()`` already reports the model as fused.

    Args:
        verbose (bool): Forwarded to ``self.info`` after fusing to control printing of model information.

    Returns:
        (torch.nn.Module): The fused model is returned.
    """
    if self.is_fused():
        return self

    for module in self.model.modules():
        if isinstance(module, (Conv, Conv2, DWConv)) and hasattr(module, "bn"):
            if isinstance(module, Conv2):
                module.fuse_convs()
            module.conv = fuse_conv_and_bn(module.conv, module.bn)  # update conv
            delattr(module, "bn")  # remove batchnorm
            module.forward = module.forward_fuse  # update forward
        if isinstance(module, ConvTranspose) and hasattr(module, "bn"):
            module.conv_transpose = fuse_deconv_and_bn(module.conv_transpose, module.bn)
            delattr(module, "bn")  # remove batchnorm
            module.forward = module.forward_fuse  # update forward
        if isinstance(module, RepConv):
            module.fuse_convs()
            module.forward = module.forward_fuse  # update forward
        if isinstance(module, RepVGGDW):
            module.fuse()
            module.forward = module.forward_fuse
        if isinstance(module, v10Detect):
            module.fuse()  # remove one2many head
    self.info(verbose=verbose)
    return self


method ultralytics.nn.tasks.BaseModel.info

def info(self, detailed = False, verbose = True, imgsz = 640)

Print model information.

Args

NameTypeDescriptionDefault
detailedboolIf True, prints out detailed information about the model.False
verboseboolIf True, prints out the model information.True
imgszintThe size of the image that the model will be trained on.640
Source code in ultralytics/nn/tasks.pyView on GitHub
def info(self, detailed=False, verbose=True, imgsz=640):
    """Print model information by delegating to ``model_info``.

    Args:
        detailed (bool): If True, prints out detailed information about the model.
        verbose (bool): If True, prints out the model information.
        imgsz (int): Image size forwarded to ``model_info``.

    Returns:
        Whatever ``model_info`` returns for this model.
    """
    options = dict(detailed=detailed, verbose=verbose, imgsz=imgsz)
    return model_info(self, **options)


method ultralytics.nn.tasks.BaseModel.init_criterion

def init_criterion(self)

Initialize the loss criterion for the BaseModel.

Source code in ultralytics/nn/tasks.pyView on GitHub
def init_criterion(self):
    """Initialize the loss criterion for the BaseModel.

    Task-specific subclasses must override this and return a callable invoked as ``criterion(preds, batch)``.

    Raises:
        NotImplementedError: Always, because BaseModel has no task head to build a loss for.
    """
    # The message names the method that actually has to be overridden (it previously said `compute_loss()`)
    raise NotImplementedError("init_criterion() needs to be implemented by task heads")


method ultralytics.nn.tasks.BaseModel.is_fused

def is_fused(self, thresh = 10)

Check whether the model has fewer normalization layers (e.g. BatchNorm) than a given threshold.

Args

NameTypeDescriptionDefault
threshint, optionalThe threshold number of normalization layers.10

Returns

TypeDescription
boolTrue if the number of normalization layers in the model is less than the threshold, False otherwise.
Source code in ultralytics/nn/tasks.pyView on GitHub
def is_fused(self, thresh=10):
    """Report whether the model looks fused, i.e. holds fewer normalization layers than ``thresh``.

    Every ``torch.nn`` attribute whose name contains "Norm" (e.g. BatchNorm2d) counts as a normalization layer.

    Args:
        thresh (int, optional): The threshold number of normalization layers.

    Returns:
        (bool): True if the model contains fewer than ``thresh`` normalization layers, False otherwise.
    """
    norm_types = tuple(cls for name, cls in torch.nn.__dict__.items() if "Norm" in name)
    norm_count = 0
    for module in self.modules():
        if isinstance(module, norm_types):
            norm_count += 1
    return norm_count < thresh


method ultralytics.nn.tasks.BaseModel.load

def load(self, weights, verbose = True)

Load weights into the model.

Args

NameTypeDescriptionDefault
weightsdict | torch.nn.ModuleThe pre-trained weights to be loaded.required
verbosebool, optionalWhether to log the transfer progress.True
Source code in ultralytics/nn/tasks.pyView on GitHub
def load(self, weights, verbose=True):
    """Load weights into the model.

    Entries returned by ``intersect_dicts`` are loaded non-strictly. If the first conv weight was not among them,
    but both the model and the checkpoint have it with matching kernel sizes, the overlapping channel slice is
    copied. This is mostly used to boost multi-channel training.

    Args:
        weights (dict | torch.nn.Module): The pre-trained weights to be loaded, either a dict holding the module
            under "model" or the module itself.
        verbose (bool, optional): Whether to log the transfer progress.
    """
    model = weights["model"] if isinstance(weights, dict) else weights  # torchvision models are not dicts
    csd = model.float().state_dict()  # checkpoint state_dict as FP32
    updated_csd = intersect_dicts(csd, self.state_dict())  # intersect
    self.load_state_dict(updated_csd, strict=False)  # load
    len_updated_csd = len(updated_csd)
    first_conv = "model.0.conv.weight"  # hard-coded to yolo models for now
    # mostly used to boost multi-channel training
    state_dict = self.state_dict()
    # `first_conv in csd`: checkpoints with a different layout lack this key and previously raised KeyError here
    if first_conv not in updated_csd and first_conv in state_dict and first_conv in csd:
        c1, c2, h, w = state_dict[first_conv].shape
        cc1, cc2, ch, cw = csd[first_conv].shape
        if ch == h and cw == w:  # only copy when kernel sizes agree; channel counts may differ
            c1, c2 = min(c1, cc1), min(c2, cc2)
            # relies on state_dict() tensors sharing storage with the parameters so the write lands in the model
            state_dict[first_conv][:c1, :c2] = csd[first_conv][:c1, :c2]
            len_updated_csd += 1
    if verbose:
        LOGGER.info(f"Transferred {len_updated_csd}/{len(self.model.state_dict())} items from pretrained weights")


method ultralytics.nn.tasks.BaseModel.loss

def loss(self, batch, preds = None)

Compute loss.

Args

NameTypeDescriptionDefault
batchdictBatch to compute loss on.required
predstorch.Tensor | list[torch.Tensor], optionalPredictions.None
Source code in ultralytics/nn/tasks.pyView on GitHub
def loss(self, batch, preds=None):
    """Compute the loss for a batch, creating the criterion lazily on first use.

    Args:
        batch (dict): Batch to compute loss on; ``batch["img"]`` is fed to ``self.forward`` when ``preds`` is None.
        preds (torch.Tensor | list[torch.Tensor], optional): Precomputed predictions.

    Returns:
        Whatever ``self.criterion(preds, batch)`` returns.
    """
    needs_criterion = getattr(self, "criterion", None) is None
    if needs_criterion:
        self.criterion = self.init_criterion()
    predictions = preds if preds is not None else self.forward(batch["img"])
    return self.criterion(predictions, batch)


method ultralytics.nn.tasks.BaseModel.predict

def predict(self, x, profile = False, visualize = False, augment = False, embed = None)

Perform a forward pass through the network.

Args

NameTypeDescriptionDefault
xtorch.TensorThe input tensor to the model.required
profileboolPrint the computation time of each layer if True.False
visualizeboolSave the feature maps of the model if True.False
augmentboolAugment image during prediction.False
embedlist, optionalA list of feature vectors/embeddings to return.None

Returns

TypeDescription
torch.Tensor | tupleThe last output of the model, or a tuple of per-image embedding tensors when embed layer indices are given.
Source code in ultralytics/nn/tasks.pyView on GitHub
def predict(self, x, profile=False, visualize=False, augment=False, embed=None):
    """Run inference, optionally with test-time augmentation.

    Args:
        x (torch.Tensor): The input tensor to the model.
        profile (bool): Print the computation time of each layer if True.
        visualize (bool): Save the feature maps of the model if True.
        augment (bool): Use ``self._predict_augment`` (profile/visualize/embed are then ignored).
        embed (list, optional): A list of feature vectors/embeddings to return.

    Returns:
        (torch.Tensor): The last output of the model.
    """
    if not augment:
        return self._predict_once(x, profile, visualize, embed)
    return self._predict_augment(x)





class ultralytics.nn.tasks.DetectionModel

DetectionModel(self, cfg = "yolo11n.yaml", ch = 3, nc = None, verbose = True)

Bases: BaseModel

YOLO detection model.

This class implements the YOLO detection architecture, handling model initialization, forward pass, augmented inference, and loss computation for object detection tasks.

Args

NameTypeDescriptionDefault
cfgstr | dictModel configuration file path or dictionary."yolo11n.yaml"
chintNumber of input channels.3
ncint, optionalNumber of classes.None
verboseboolWhether to display model information.True

Attributes

NameTypeDescription
yamldictModel configuration dictionary.
modeltorch.nn.SequentialThe neural network model.
savelistList of layer indices to save outputs from.
namesdictClass names dictionary.
inplaceboolWhether to use inplace operations.
end2endboolWhether the model uses end-to-end detection.
stridetorch.TensorModel stride values.

Methods

NameDescription
_clip_augmentedClip YOLO augmented inference tails.
_descale_predDe-scale predictions following augmented inference (inverse operation).
_predict_augmentPerform augmentations on input image x and return augmented inference and train outputs.
init_criterionInitialize the loss criterion for the DetectionModel.

Examples

Initialize a detection model
>>> model = DetectionModel("yolo11n.yaml", ch=3, nc=80)
>>> results = model.predict(image_tensor)
Source code in ultralytics/nn/tasks.pyView on GitHub
class DetectionModel(BaseModel):
    """YOLO detection model.

    This class implements the YOLO detection architecture, handling model initialization, forward pass, augmented
    inference, and loss computation for object detection tasks.

    Attributes:
        yaml (dict): Model configuration dictionary.
        model (torch.nn.Sequential): The neural network model.
        save (list): List of layer indices to save outputs from.
        names (dict): Class names dictionary.
        inplace (bool): Whether to use inplace operations.
        end2end (bool): Whether the model uses end-to-end detection.
        stride (torch.Tensor): Model stride values.

    Methods:
        __init__: Initialize the YOLO detection model.
        _predict_augment: Perform augmented inference.
        _descale_pred: De-scale predictions following augmented inference.
        _clip_augmented: Clip YOLO augmented inference tails.
        init_criterion: Initialize the loss criterion.

    Examples:
        Initialize a detection model
        >>> model = DetectionModel("yolo11n.yaml", ch=3, nc=80)
        >>> results = model.predict(image_tensor)
    """

    def __init__(self, cfg="yolo11n.yaml", ch=3, nc=None, verbose=True):
        """Initialize the YOLO detection model with the given config and parameters.

        Loads the config, applies overrides, builds the layers with ``parse_model``, probes the head strides with a
        dummy forward pass and initializes weights.

        Args:
            cfg (str | dict): Model configuration file path or dictionary. NOTE(review): a dict is used directly as
                ``self.yaml`` and is mutated in place ("Silence" replacement, "channels", "nc"), so a caller that
                keeps a reference to it will observe those changes.
            ch (int): Number of input channels.
            nc (int, optional): Number of classes; when truthy and different from the YAML "nc" it overrides it.
            verbose (bool): Whether to display model information.
        """
        super().__init__()
        self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg)  # cfg dict
        # Configs whose first backbone module is the deprecated YOLOv9 `Silence` get it rewritten to nn.Identity
        if self.yaml["backbone"][0][2] == "Silence":
            LOGGER.warning(
                "YOLOv9 `Silence` module is deprecated in favor of torch.nn.Identity. "
                "Please delete local *.pt file and re-download the latest model checkpoint."
            )
            self.yaml["backbone"][0][2] = "nn.Identity"

        # Define model
        self.yaml["channels"] = ch  # save channels
        if nc and nc != self.yaml["nc"]:  # a falsy nc (None/0) keeps the YAML value
            LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
            self.yaml["nc"] = nc  # override YAML value
        # parse_model receives a deep copy, so self.yaml is not affected by anything parse_model changes
        self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose)  # model, savelist
        self.names = {i: f"{i}" for i in range(self.yaml["nc"])}  # default names dict
        self.inplace = self.yaml.get("inplace", True)
        self.end2end = getattr(self.model[-1], "end2end", False)

        # Build strides
        m = self.model[-1]  # Detect()
        if isinstance(m, Detect):  # includes all Detect subclasses like Segment, Pose, OBB, YOLOEDetect, YOLOESegment
            s = 256  # 2x min stride
            m.inplace = self.inplace

            def _forward(x):
                """Forward ``x`` and return the head outputs used to measure strides.

                End-to-end models use the "one2many" entry of the output; Segment, YOLOESegment, Pose and OBB heads
                use the first element of the output; other heads use the output as-is.
                """
                if self.end2end:
                    return self.forward(x)["one2many"]
                return self.forward(x)[0] if isinstance(m, (Segment, YOLOESegment, Pose, OBB)) else self.forward(x)

            self.model.eval()  # Avoid changing batch statistics until training begins
            m.training = True  # Setting it to True to properly return strides
            # Probe with a zero s x s image: each returned tensor's stride is s / its height (shape[-2])
            m.stride = torch.tensor([s / x.shape[-2] for x in _forward(torch.zeros(1, ch, s, s))])  # forward
            self.stride = m.stride
            self.model.train()  # Set model back to training(default) mode
            m.bias_init()  # only run once
        else:
            self.stride = torch.Tensor([32])  # default stride for i.e. RTDETR

        # Init weights, biases
        initialize_weights(self)
        if verbose:
            self.info()
            LOGGER.info("")


method ultralytics.nn.tasks.DetectionModel._clip_augmented

def _clip_augmented(self, y)

Clip YOLO augmented inference tails.

Args

NameTypeDescriptionDefault
ylist[torch.Tensor]List of detection tensors.required

Returns

TypeDescription
list[torch.Tensor]Clipped detection tensors.
Source code in ultralytics/nn/tasks.pyView on GitHub
def _clip_augmented(self, y):
    """Clip YOLO augmented inference tails.

    Args:
        y (list[torch.Tensor]): List of detection tensors.

    Returns:
        (list[torch.Tensor]): Clipped detection tensors.
    """
    nl = self.model[-1].nl  # number of detection layers (P3-P5)
    g = sum(4**x for x in range(nl))  # grid points
    e = 1  # exclude layer count
    i = (y[0].shape[-1] // g) * sum(4**x for x in range(e))  # indices
    y[0] = y[0][..., :-i]  # large
    i = (y[-1].shape[-1] // g) * sum(4 ** (nl - 1 - x) for x in range(e))  # indices
    y[-1] = y[-1][..., i:]  # small
    return y


method ultralytics.nn.tasks.DetectionModel._descale_pred

def _descale_pred(p, flips, scale, img_size, dim = 1)

De-scale predictions following augmented inference (inverse operation).

Args

NameTypeDescriptionDefault
ptorch.TensorPredictions tensor.required
flipsintFlip type (0=none, 2=ud, 3=lr).required
scalefloatScale factor.required
img_sizetupleOriginal image size (height, width).required
dimintDimension to split at.1

Returns

TypeDescription
torch.TensorDe-scaled predictions.
Source code in ultralytics/nn/tasks.pyView on GitHub
@staticmethod
def _descale_pred(p, flips, scale, img_size, dim=1):
    """De-scale predictions following augmented inference (inverse operation).

    Args:
        p (torch.Tensor): Predictions tensor.
        flips (int): Flip type (0=none, 2=ud, 3=lr).
        scale (float): Scale factor.
        img_size (tuple): Original image size (height, width).
        dim (int): Dimension to split at.

    Returns:
        (torch.Tensor): De-scaled predictions.
    """
    p[:, :4] /= scale  # de-scale
    x, y, wh, cls = p.split((1, 1, 2, p.shape[dim] - 4), dim)
    if flips == 2:
        y = img_size[0] - y  # de-flip ud
    elif flips == 3:
        x = img_size[1] - x  # de-flip lr
    return torch.cat((x, y, wh, cls), dim)


method ultralytics.nn.tasks.DetectionModel._predict_augment

def _predict_augment(self, x)

Perform augmentations on input image x and return augmented inference and train outputs.

Args

NameTypeDescriptionDefault
xtorch.TensorInput image tensor.required

Returns

TypeDescription
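tupleA tuple of the augmented inference output and None (placeholder for train outputs); unsupported models instead return the single-scale prediction output.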
Source code in ultralytics/nn/tasks.pyView on GitHub
def _predict_augment(self, x):
    """Run test-time augmentation over three scales, mirroring the middle one left-right.

    End-to-end models, and any class whose name is not exactly "DetectionModel", fall back to single-scale
    prediction with a warning.

    Args:
        x (torch.Tensor): Input image tensor.

    Returns:
        (tuple): ``(augmented_output, None)``. The augmented output is the clipped, de-scaled predictions of every
            pass concatenated on the last dimension. In the fallback case the result of ``self._predict_once(x)``
            is returned instead.
    """
    unsupported = getattr(self, "end2end", False) or self.__class__.__name__ != "DetectionModel"
    if unsupported:
        LOGGER.warning("Model does not support 'augment=True', reverting to single-scale prediction.")
        return self._predict_once(x)

    img_size = x.shape[-2:]  # height, width
    grid_size = int(self.stride.max())
    outputs = []
    # (scale, flip) pairs; flip 3 flips dim 3 (lr), None leaves the image as-is
    for scale, flip in ((1, None), (0.83, 3), (0.67, None)):
        image = x.flip(flip) if flip else x
        pred = super().predict(scale_img(image, scale, gs=grid_size))[0]  # forward
        outputs.append(self._descale_pred(pred, flip, scale, img_size))
    return torch.cat(self._clip_augmented(outputs), -1), None  # augmented inference, train


method ultralytics.nn.tasks.DetectionModel.init_criterion

def init_criterion(self)

Initialize the loss criterion for the DetectionModel.

Source code in ultralytics/nn/tasks.pyView on GitHub
def init_criterion(self):
    """Build the detection loss: E2EDetectLoss for end-to-end models, v8DetectionLoss otherwise."""
    if getattr(self, "end2end", False):
        return E2EDetectLoss(self)
    return v8DetectionLoss(self)





class ultralytics.nn.tasks.OBBModel

OBBModel(self, cfg = "yolo11n-obb.yaml", ch = 3, nc = None, verbose = True)

Bases: DetectionModel

YOLO Oriented Bounding Box (OBB) model.

This class extends DetectionModel to handle oriented bounding box detection tasks, providing specialized loss computation for rotated object detection.

Args

NameTypeDescriptionDefault
cfgstr | dictModel configuration file path or dictionary."yolo11n-obb.yaml"
chintNumber of input channels.3
ncint, optionalNumber of classes.None
verboseboolWhether to display model information.True

Methods

NameDescription
init_criterionInitialize the loss criterion for the model.

Examples

Initialize an OBB model
>>> model = OBBModel("yolo11n-obb.yaml", ch=3, nc=80)
>>> results = model.predict(image_tensor)
Source code in ultralytics/nn/tasks.pyView on GitHub
class OBBModel(DetectionModel):
    """YOLO model for oriented (rotated) bounding box detection.

    Construction is inherited unchanged from DetectionModel, with an OBB config as the default. The rotated-box loss
    is provided by the class's ``init_criterion``.

    Methods:
        __init__: Initialize YOLO OBB model.
        init_criterion: Initialize the loss criterion for OBB detection.

    Examples:
        Initialize an OBB model
        >>> model = OBBModel("yolo11n-obb.yaml", ch=3, nc=80)
        >>> results = model.predict(image_tensor)
    """

    def __init__(self, cfg="yolo11n-obb.yaml", ch=3, nc=None, verbose=True):
        """Build the OBB model by delegating to DetectionModel.

        Args:
            cfg (str | dict): Model configuration file path or dictionary.
            ch (int): Number of input channels.
            nc (int, optional): Number of classes.
            verbose (bool): Whether to display model information.
        """
        super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)


method ultralytics.nn.tasks.OBBModel.init_criterion

def init_criterion(self)

Initialize the loss criterion for the model.

Source code in ultralytics/nn/tasks.pyView on GitHub
def init_criterion(self):
    """Create the oriented-bounding-box loss (v8OBBLoss) for this model."""
    criterion = v8OBBLoss(self)
    return criterion





class ultralytics.nn.tasks.SegmentationModel

SegmentationModel(self, cfg = "yolo11n-seg.yaml", ch = 3, nc = None, verbose = True)

Bases: DetectionModel

YOLO segmentation model.

This class extends DetectionModel to handle instance segmentation tasks, providing specialized loss computation for pixel-level object detection and segmentation.

Args

NameTypeDescriptionDefault
cfgstr | dictModel configuration file path or dictionary."yolo11n-seg.yaml"
chintNumber of input channels.3
ncint, optionalNumber of classes.None
verboseboolWhether to display model information.True

Methods

NameDescription
init_criterionInitialize the loss criterion for the SegmentationModel.

Examples

Initialize a segmentation model
>>> model = SegmentationModel("yolo11n-seg.yaml", ch=3, nc=80)
>>> results = model.predict(image_tensor)
Source code in ultralytics/nn/tasks.pyView on GitHub
class SegmentationModel(DetectionModel):
    """YOLO model for instance segmentation.

    Construction is inherited unchanged from DetectionModel, with a segmentation config as the default. The
    segmentation loss is provided by the class's ``init_criterion``.

    Methods:
        __init__: Initialize YOLO segmentation model.
        init_criterion: Initialize the loss criterion for segmentation.

    Examples:
        Initialize a segmentation model
        >>> model = SegmentationModel("yolo11n-seg.yaml", ch=3, nc=80)
        >>> results = model.predict(image_tensor)
    """

    def __init__(self, cfg="yolo11n-seg.yaml", ch=3, nc=None, verbose=True):
        """Build the segmentation model by delegating to DetectionModel.

        Args:
            cfg (str | dict): Model configuration file path or dictionary.
            ch (int): Number of input channels.
            nc (int, optional): Number of classes.
            verbose (bool): Whether to display model information.
        """
        super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)


method ultralytics.nn.tasks.SegmentationModel.init_criterion

def init_criterion(self)

Initialize the loss criterion for the SegmentationModel.

Source code in ultralytics/nn/tasks.pyView on GitHub
def init_criterion(self):
    """Create the segmentation loss (v8SegmentationLoss) for this model."""
    criterion = v8SegmentationLoss(self)
    return criterion





class ultralytics.nn.tasks.PoseModel

PoseModel(self, cfg = "yolo11n-pose.yaml", ch = 3, nc = None, data_kpt_shape = (None, None), verbose = True)

Bases: DetectionModel

YOLO pose model.

This class extends DetectionModel to handle human pose estimation tasks, providing specialized loss computation for keypoint detection and pose estimation.

Args

NameTypeDescriptionDefault
cfgstr | dictModel configuration file path or dictionary."yolo11n-pose.yaml"
chintNumber of input channels.3
ncint, optionalNumber of classes.None
data_kpt_shapetupleShape of keypoints data.(None, None)
verboseboolWhether to display model information.True

Attributes

NameTypeDescription
kpt_shapetupleShape of keypoints data (num_keypoints, num_dimensions).

Methods

NameDescription
init_criterionInitialize the loss criterion for the PoseModel.

Examples

Initialize a pose model
>>> model = PoseModel("yolo11n-pose.yaml", ch=3, nc=1, data_kpt_shape=(17, 3))
>>> results = model.predict(image_tensor)
Source code in ultralytics/nn/tasks.pyView on GitHub
class PoseModel(DetectionModel):
    """YOLO model for keypoint-based pose estimation.

    Builds on DetectionModel, allowing the config's keypoint shape to be overridden from the dataset. The keypoint
    loss is provided by the class's ``init_criterion``.

    Attributes:
        kpt_shape (tuple): Shape of keypoints data (num_keypoints, num_dimensions).

    Methods:
        __init__: Initialize YOLO pose model.
        init_criterion: Initialize the loss criterion for pose estimation.

    Examples:
        Initialize a pose model
        >>> model = PoseModel("yolo11n-pose.yaml", ch=3, nc=1, data_kpt_shape=(17, 3))
        >>> results = model.predict(image_tensor)
    """

    def __init__(self, cfg="yolo11n-pose.yaml", ch=3, nc=None, data_kpt_shape=(None, None), verbose=True):
        """Build the pose model, overriding the config's ``kpt_shape`` when the data specifies a different one.

        Args:
            cfg (str | dict): Model configuration file path or dictionary (a dict is modified in place).
            ch (int): Number of input channels.
            nc (int, optional): Number of classes.
            data_kpt_shape (tuple): Shape of keypoints data; ignored while all of its entries are falsy.
            verbose (bool): Whether to display model information.
        """
        cfg = cfg if isinstance(cfg, dict) else yaml_model_load(cfg)  # load model YAML
        override_kpts = any(data_kpt_shape) and list(data_kpt_shape) != list(cfg["kpt_shape"])
        if override_kpts:
            LOGGER.info(f"Overriding model.yaml kpt_shape={cfg['kpt_shape']} with kpt_shape={data_kpt_shape}")
            cfg["kpt_shape"] = data_kpt_shape
        super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)


method ultralytics.nn.tasks.PoseModel.init_criterion

def init_criterion(self)

Initialize the loss criterion for the PoseModel.

Source code in ultralytics/nn/tasks.pyView on GitHub
def init_criterion(self):
    """Create the pose-estimation loss (v8PoseLoss) bound to this model."""
    criterion = v8PoseLoss(self)
    return criterion





class ultralytics.nn.tasks.ClassificationModel

ClassificationModel(self, cfg = "yolo11n-cls.yaml", ch = 3, nc = None, verbose = True)

Bases: BaseModel

YOLO classification model.

This class implements the YOLO classification architecture for image classification tasks, providing model initialization, configuration, and output reshaping capabilities.

Args

NameTypeDescriptionDefault
cfgstr | dictModel configuration file path or dictionary."yolo11n-cls.yaml"
chintNumber of input channels.3
ncint, optionalNumber of classes.None
verboseboolWhether to display model information.True

Attributes

NameTypeDescription
yamldictModel configuration dictionary.
modeltorch.nn.SequentialThe neural network model.
stridetorch.TensorModel stride values.
namesdictClass names dictionary.

Methods

NameDescription
_from_yamlSet Ultralytics YOLO model configurations and define the model architecture.
init_criterionInitialize the loss criterion for the ClassificationModel.
reshape_outputsUpdate a TorchVision classification model to class count 'n' if required.

Examples

Initialize a classification model
>>> model = ClassificationModel("yolo11n-cls.yaml", ch=3, nc=1000)
>>> results = model.predict(image_tensor)
Source code in ultralytics/nn/tasks.pyView on GitHub
class ClassificationModel(BaseModel):
    """YOLO image-classification model.

    Builds a YOLO classification network from a YAML configuration (via `_from_yaml`). It also offers `reshape_outputs`
    for resizing the final layer of classification models to a new class count.

    Attributes:
        yaml (dict): Model configuration dictionary.
        model (torch.nn.Sequential): The neural network model.
        stride (torch.Tensor): Model stride values.
        names (dict): Class names dictionary.

    Methods:
        __init__: Initialize ClassificationModel.
        _from_yaml: Set model configurations and define architecture.
        reshape_outputs: Update model to specified class count.
        init_criterion: Initialize the loss criterion.

    Examples:
        Initialize a classification model
        >>> model = ClassificationModel("yolo11n-cls.yaml", ch=3, nc=1000)
        >>> results = model.predict(image_tensor)
    """

    def __init__(self, cfg="yolo11n-cls.yaml", ch=3, nc=None, verbose=True):
        """Create the classification model from a YAML configuration.

        Args:
            cfg (str | dict): Model configuration file path or dictionary.
            ch (int): Number of input channels.
            nc (int, optional): Number of classes.
            verbose (bool): Whether to display model information.
        """
        super().__init__()
        # YAML loading, nc resolution and layer construction all happen in _from_yaml
        self._from_yaml(cfg, ch, nc, verbose)


method ultralytics.nn.tasks.ClassificationModel._from_yaml

def _from_yaml(self, cfg, ch, nc, verbose)

Set Ultralytics YOLO model configurations and define the model architecture.

Args

NameTypeDescriptionDefault
cfgstr | dictModel configuration file path or dictionary.required
chintNumber of input channels.required
ncint, optionalNumber of classes.required
verboseboolWhether to display model information.required
Source code in ultralytics/nn/tasks.pyView on GitHub
def _from_yaml(self, cfg, ch, nc, verbose):
    """Set Ultralytics YOLO model configurations and define the model architecture.

    Loads the YAML configuration (unless a dict is given), resolves the input channel count and number of classes,
    builds the layers with `parse_model`, and initializes `stride` and default `names`.

    Args:
        cfg (str | dict): Model configuration file path or dictionary. A dict is used (and updated) in place.
        ch (int): Number of input channels, used only when the configuration has no ``channels`` entry.
        nc (int, optional): Number of classes. Overrides the configuration value when given and different.
        verbose (bool): Whether to display model information. It is passed to `parse_model` and also controls
            `self.info()`.

    Raises:
        ValueError: If ``nc`` is neither passed nor present (truthy) in the configuration.
    """
    self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg)  # cfg dict

    # Define model
    ch = self.yaml["channels"] = self.yaml.get("channels", ch)  # input channels
    yaml_nc = self.yaml.get("nc")  # .get(): an explicit nc must not KeyError on a YAML without "nc"
    if nc and nc != yaml_nc:
        LOGGER.info(f"Overriding model.yaml nc={yaml_nc} with nc={nc}")
        self.yaml["nc"] = nc  # override YAML value
    elif not nc and not yaml_nc:
        raise ValueError("nc not specified. Must specify nc in model.yaml or function arguments.")
    self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose)  # model, savelist
    self.stride = torch.Tensor([1])  # no stride constraints
    self.names = {i: f"{i}" for i in range(self.yaml["nc"])}  # default names dict
    if verbose:  # honor the verbose flag instead of always printing model info
        self.info()


method ultralytics.nn.tasks.ClassificationModel.init_criterion

def init_criterion(self)

Initialize the loss criterion for the ClassificationModel.

Source code in ultralytics/nn/tasks.pyView on GitHub
def init_criterion(self):
    """Create the classification loss (v8ClassificationLoss) used by the ClassificationModel."""
    criterion = v8ClassificationLoss()
    return criterion


method ultralytics.nn.tasks.ClassificationModel.reshape_outputs

def reshape_outputs(model, nc)

Update a TorchVision classification model to class count 'nc' if required.

Args

NameTypeDescriptionDefault
modeltorch.nn.ModuleModel to update.required
ncintNew number of classes.required
Source code in ultralytics/nn/tasks.pyView on GitHub
@staticmethod
def reshape_outputs(model, nc):
    """Update a TorchVision or YOLO classification model to class count 'nc' if required.

    Only the last child module of ``model`` is inspected, or of ``model.model`` when that attribute exists.

    Supported heads:
    - A YOLO ``Classify`` module: its ``linear`` layer is replaced.
    - A ``torch.nn.Linear`` layer: it is replaced.
    - A ``torch.nn.Sequential``: its last ``Linear`` is replaced, or, if it has none, its last ``Conv2d``.

    Other head types are left unchanged. Replacement layers are freshly initialized.

    Args:
        model (torch.nn.Module): Model to update in place.
        nc (int): New number of classes.
    """
    container = model.model if hasattr(model, "model") else model  # module that actually owns the head
    name, m = list(container.named_children())[-1]  # last module
    if isinstance(m, Classify):  # YOLO Classify() head
        if m.linear.out_features != nc:
            m.linear = torch.nn.Linear(m.linear.in_features, nc)
    elif isinstance(m, torch.nn.Linear):  # ResNet, EfficientNet
        if m.out_features != nc:
            # Replace on the owning container; setting it on `model` would leave `model.model`'s head untouched
            setattr(container, name, torch.nn.Linear(m.in_features, nc))
    elif isinstance(m, torch.nn.Sequential):
        types = [type(x) for x in m]
        if torch.nn.Linear in types:
            i = len(types) - 1 - types[::-1].index(torch.nn.Linear)  # last torch.nn.Linear index
            if m[i].out_features != nc:
                m[i] = torch.nn.Linear(m[i].in_features, nc)
        elif torch.nn.Conv2d in types:
            i = len(types) - 1 - types[::-1].index(torch.nn.Conv2d)  # last torch.nn.Conv2d index
            conv = m[i]
            if conv.out_channels != nc:
                # Keep the original spatial geometry (padding/dilation) so the output size is unchanged
                m[i] = torch.nn.Conv2d(
                    conv.in_channels,
                    nc,
                    conv.kernel_size,
                    conv.stride,
                    padding=conv.padding,
                    dilation=conv.dilation,
                    bias=conv.bias is not None,
                )





class ultralytics.nn.tasks.RTDETRDetectionModel

RTDETRDetectionModel(self, cfg = "rtdetr-l.yaml", ch = 3, nc = None, verbose = True)

Bases: DetectionModel

RT-DETR (Real-Time DEtection TRansformer) detection model class.

This class is responsible for constructing the RT-DETR architecture, defining loss functions, and facilitating both the training and inference processes. RT-DETR is a real-time, transformer-based object detection model, and this class extends the DetectionModel base class.

Args

NameTypeDescriptionDefault
cfgstr | dictConfiguration file name or path."rtdetr-l.yaml"
chintNumber of input channels.3
ncint, optionalNumber of classes.None
verboseboolPrint additional information during initialization.True

Attributes

NameTypeDescription
ncintNumber of classes for detection.
criterionRTDETRDetectionLossLoss function for training.

Methods

NameDescription
_applyApply a function to all tensors in the model that are not parameters or registered buffers.
init_criterionInitialize the loss criterion for the RTDETRDetectionModel.
lossCompute the loss for the given batch of data.
predictPerform a forward pass through the model.

Examples

Initialize an RTDETR model
>>> model = RTDETRDetectionModel("rtdetr-l.yaml", ch=3, nc=80)
>>> results = model.predict(image_tensor)
Source code in ultralytics/nn/tasks.pyView on GitHub
class RTDETRDetectionModel(DetectionModel):
    """RT-DETR (Real-Time DEtection TRansformer) detection model.

    Builds the RT-DETR architecture on top of DetectionModel and supplies its training loss and its inference pass.

    Attributes:
        nc (int): Number of classes for detection.
        criterion (RTDETRDetectionLoss): Loss function for training.

    Methods:
        __init__: Initialize the RTDETRDetectionModel.
        init_criterion: Initialize the loss criterion.
        loss: Compute loss for training.
        predict: Perform forward pass through the model.

    Examples:
        Initialize an RTDETR model
        >>> model = RTDETRDetectionModel("rtdetr-l.yaml", ch=3, nc=80)
        >>> results = model.predict(image_tensor)
    """

    def __init__(self, cfg="rtdetr-l.yaml", ch=3, nc=None, verbose=True):
        """Create an RT-DETR detection model.

        Args:
            cfg (str | dict): Configuration file name or path.
            ch (int): Number of input channels.
            nc (int, optional): Number of classes.
            verbose (bool): Print additional information during initialization.
        """
        super().__init__(verbose=verbose, nc=nc, ch=ch, cfg=cfg)


method ultralytics.nn.tasks.RTDETRDetectionModel._apply

def _apply(self, fn)

Apply a function to all tensors in the model that are not parameters or registered buffers.

Args

NameTypeDescriptionDefault
fnfunctionThe function to apply to the model.required

Returns

TypeDescription
RTDETRDetectionModelAn updated BaseModel object.
Source code in ultralytics/nn/tasks.pyView on GitHub
def _apply(self, fn):
    """Apply ``fn`` to the model tensors, including the head's ``anchors`` and ``valid_mask``.

    Those two attributes are not parameters or registered buffers, so the parent ``_apply`` does not move or cast them.
    They are updated here explicitly.

    Args:
        fn (function): The function to apply to the model.

    Returns:
        (RTDETRDetectionModel): The model returned by the parent ``_apply``, with the head attributes updated.
    """
    model = super()._apply(fn)
    decoder = model.model[-1]
    for attr in ("anchors", "valid_mask"):
        setattr(decoder, attr, fn(getattr(decoder, attr)))
    return model


method ultralytics.nn.tasks.RTDETRDetectionModel.init_criterion

def init_criterion(self)

Initialize the loss criterion for the RTDETRDetectionModel.

Source code in ultralytics/nn/tasks.pyView on GitHub
def init_criterion(self):
    """Create the RT-DETR detection loss for ``self.nc`` classes, with ``use_vfl=True``."""
    from ultralytics.models.utils.loss import RTDETRDetectionLoss

    loss_cls = RTDETRDetectionLoss
    return loss_cls(use_vfl=True, nc=self.nc)


method ultralytics.nn.tasks.RTDETRDetectionModel.loss

def loss(self, batch, preds = None)

Compute the loss for the given batch of data.

Args

NameTypeDescriptionDefault
batchdictDictionary containing image and label data.required
predstorch.Tensor, optionalPrecomputed model predictions.None

Returns

TypeDescription
loss_sum (torch.Tensor)Total loss value.
loss_items (torch.Tensor)Main three losses in a tensor.
Source code in ultralytics/nn/tasks.pyView on GitHub
def loss(self, batch, preds=None):
    """Compute the loss for the given batch of data.

    Steps:
    1. Create the criterion lazily via `init_criterion` on first use.
    2. Move ground-truth labels to the image device and count them per image.
    3. If predictions are not supplied, compute them with `predict`.
    4. Prepend the encoder outputs to the decoder outputs.
    5. Evaluate the criterion.

    Args:
        batch (dict): Dictionary containing image and label data. The "img", "batch_idx", "cls" and "bboxes" keys
            are read.
        preds (tuple, optional): Precomputed model predictions. In training mode this is the tuple
            (dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta). Otherwise that tuple is expected at
            ``preds[1]``.

    Returns:
        loss_sum (torch.Tensor): Total loss value (the sum of every loss term returned by the criterion).
        loss_items (torch.Tensor): Detached [loss_giou, loss_class, loss_bbox] values on the image device.
    """
    if not hasattr(self, "criterion"):
        self.criterion = self.init_criterion()

    img = batch["img"]
    # NOTE: preprocess gt_bbox and gt_labels to list.
    bs = img.shape[0]
    batch_idx = batch["batch_idx"]
    # Number of ground-truth objects belonging to each image in the batch
    gt_groups = [(batch_idx == i).sum().item() for i in range(bs)]
    targets = {
        "cls": batch["cls"].to(img.device, dtype=torch.long).view(-1),
        "bboxes": batch["bboxes"].to(device=img.device),
        "batch_idx": batch_idx.to(img.device, dtype=torch.long).view(-1),
        "gt_groups": gt_groups,
    }

    if preds is None:
        # Targets are forwarded so the head can use them during its forward pass
        preds = self.predict(img, batch=targets)
    # Outside training mode the tuple of interest sits at index 1 of the predictions
    dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta = preds if self.training else preds[1]
    if dn_meta is None:
        dn_bboxes, dn_scores = None, None
    else:
        # Separate the dn_* (presumably denoising) queries from the regular queries along dim 2
        dn_bboxes, dec_bboxes = torch.split(dec_bboxes, dn_meta["dn_num_split"], dim=2)
        dn_scores, dec_scores = torch.split(dec_scores, dn_meta["dn_num_split"], dim=2)

    # Prepend the encoder output as an extra leading entry, e.g. (7, bs, 300, 4).
    # NOTE(review): 7/300 assume 6 decoder layers and 300 queries.
    dec_bboxes = torch.cat([enc_bboxes.unsqueeze(0), dec_bboxes])  # (7, bs, 300, 4)
    dec_scores = torch.cat([enc_scores.unsqueeze(0), dec_scores])

    loss = self.criterion(
        (dec_bboxes, dec_scores), targets, dn_bboxes=dn_bboxes, dn_scores=dn_scores, dn_meta=dn_meta
    )
    # NOTE: There are like 12 losses in RTDETR, backward with all losses but only show the main three losses.
    return sum(loss.values()), torch.as_tensor(
        [loss[k].detach() for k in ["loss_giou", "loss_class", "loss_bbox"]], device=img.device
    )


method ultralytics.nn.tasks.RTDETRDetectionModel.predict

def predict(self, x, profile = False, visualize = False, batch = None, augment = False, embed = None)

Perform a forward pass through the model.

Args

NameTypeDescriptionDefault
xtorch.TensorThe input tensor.required
profileboolIf True, profile the computation time for each layer.False
visualizeboolIf True, save feature maps for visualization.False
batchdict, optionalGround truth data for evaluation.None
augmentboolIf True, perform data augmentation during inference.False
embedlist, optionalA list of feature vectors/embeddings to return.None

Returns

TypeDescription
torch.TensorModel's output tensor.
Source code in ultralytics/nn/tasks.pyView on GitHub
def predict(self, x, profile=False, visualize=False, batch=None, augment=False, embed=None):
    """Perform a forward pass through the model.

    All layers except the last are run in sequence. The last layer (the head) is then called with the saved outputs of
    the layers listed in its ``f`` attribute, together with ``batch``.

    Args:
        x (torch.Tensor): The input tensor.
        profile (bool): If True, profile the computation time for each layer.
        visualize (bool): If True, save feature maps for visualization.
        batch (dict, optional): Ground truth data, forwarded to the head.
        augment (bool): Accepted for interface compatibility; not used by this method.
        embed (list, optional): Indices of non-head layers whose outputs are global-average-pooled and returned instead
            of running the head. An empty list behaves like None.

    Returns:
        (torch.Tensor | tuple[torch.Tensor, ...]): The head output. When an index in ``embed`` is reached, it is
            instead the concatenated pooled features unbound along dim 0.
    """
    y, dt, embeddings = [], [], []  # outputs
    embed = frozenset(embed) if embed is not None else {-1}
    max_idx = max(embed, default=-1)  # default keeps embed=[] from raising ValueError
    for m in self.model[:-1]:  # except the head part
        if m.f != -1:  # if not from previous layer
            x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
        if profile:
            self._profile_one_layer(m, x, dt)
        x = m(x)  # run
        y.append(x if m.i in self.save else None)  # save output
        if visualize:
            feature_visualization(x, m.type, m.i, save_dir=visualize)
        if m.i in embed:
            embeddings.append(torch.nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1))  # flatten
            if m.i == max_idx:  # last requested layer reached: stop early
                return torch.unbind(torch.cat(embeddings, 1), dim=0)
    head = self.model[-1]
    x = head([y[j] for j in head.f], batch)  # head inference
    return x





class ultralytics.nn.tasks.WorldModel

WorldModel(self, cfg = "yolov8s-world.yaml", ch = 3, nc = None, verbose = True)

Bases: DetectionModel

YOLOv8 World Model.

This class implements the YOLOv8 World model for open-vocabulary object detection, supporting text-based class specification and CLIP model integration for zero-shot detection capabilities.

Args

NameTypeDescriptionDefault
cfgstr | dictModel configuration file path or dictionary."yolov8s-world.yaml"
chintNumber of input channels.3
ncint, optionalNumber of classes.None
verboseboolWhether to display model information.True

Attributes

NameTypeDescription
txt_featstorch.TensorText feature embeddings for classes.
clip_modeltorch.nn.ModuleCLIP model for text encoding.

Methods

NameDescription
get_text_peEncode class names into text embeddings with the CLIP text model.
lossCompute loss.
predictPerform a forward pass through the model.
set_classesSet classes in advance so that the model can run offline inference without the CLIP model.

Examples

Initialize a world model
>>> model = WorldModel("yolov8s-world.yaml", ch=3, nc=80)
>>> model.set_classes(["person", "car", "bicycle"])
>>> results = model.predict(image_tensor)
Source code in ultralytics/nn/tasks.pyView on GitHub
class WorldModel(DetectionModel):
    """YOLOv8-World open-vocabulary detection model.

    Detects objects whose classes are specified as text. Class names are encoded with a CLIP text model, which enables
    zero-shot detection. Precomputed text features allow inference without the CLIP model.

    Attributes:
        txt_feats (torch.Tensor): Text feature embeddings for classes.
        clip_model (torch.nn.Module): CLIP model for text encoding.

    Methods:
        __init__: Initialize YOLOv8 world model.
        set_classes: Set classes for offline inference.
        get_text_pe: Encode class names into text embeddings.
        predict: Perform forward pass with text features.
        loss: Compute loss with text features.

    Examples:
        Initialize a world model
        >>> model = WorldModel("yolov8s-world.yaml", ch=3, nc=80)
        >>> model.set_classes(["person", "car", "bicycle"])
        >>> results = model.predict(image_tensor)
    """

    def __init__(self, cfg="yolov8s-world.yaml", ch=3, nc=None, verbose=True):
        """Create the world model with placeholder text features and no CLIP model loaded.

        Args:
            cfg (str | dict): Model configuration file path or dictionary.
            ch (int): Number of input channels.
            nc (int, optional): Number of classes; also sizes the placeholder features (80 when not given).
            verbose (bool): Whether to display model information.
        """
        num_placeholder_classes = nc or 80
        self.txt_feats = torch.randn(1, num_placeholder_classes, 512)  # random placeholder features
        self.clip_model = None  # CLIP model is loaded lazily by get_text_pe
        super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)


method ultralytics.nn.tasks.WorldModel.get_text_pe

def get_text_pe(self, text, batch = 80, cache_clip_model = True)

Encode class names into text embeddings with the CLIP text model. set_classes stores them so that the model can run offline inference without the CLIP model.

Args

NameTypeDescriptionDefault
textlist[str]List of class names.required
batchintBatch size for processing text tokens.80
cache_clip_modelboolWhether to cache the CLIP model.True

Returns

TypeDescription
torch.TensorText positional embeddings.
Source code in ultralytics/nn/tasks.pyView on GitHub
def get_text_pe(self, text, batch=80, cache_clip_model=True):
    """Encode class names into text embeddings with the CLIP ViT-B/32 text model.

    This method only computes and returns the embeddings. `set_classes` stores them so the model can run offline
    inference without the CLIP model.

    Args:
        text (list[str]): List of class names.
        batch (int): Batch size for processing text tokens.
        cache_clip_model (bool): Whether to cache the CLIP model on ``self.clip_model`` and reuse it on later calls.

    Returns:
        (torch.Tensor): Text embeddings reshaped to ``(-1, len(text), embed_dim)``.
    """
    from ultralytics.nn.text_model import build_text_model

    device = next(self.model.parameters()).device
    if cache_clip_model and getattr(self, "clip_model", None) is None:
        # For backwards compatibility of models lacking clip_model attribute
        self.clip_model = build_text_model("clip:ViT-B/32", device=device)
    model = self.clip_model if cache_clip_model else build_text_model("clip:ViT-B/32", device=device)
    text_token = model.tokenize(text)
    txt_feats = [model.encode_text(token).detach() for token in text_token.split(batch)]  # encode in chunks
    txt_feats = txt_feats[0] if len(txt_feats) == 1 else torch.cat(txt_feats, dim=0)
    return txt_feats.reshape(-1, len(text), txt_feats.shape[-1])


method ultralytics.nn.tasks.WorldModel.loss

def loss(self, batch, preds = None)

Compute loss.

Args

NameTypeDescriptionDefault
batchdictBatch to compute loss on.required
predstorch.Tensor | list[torch.Tensor], optionalPredictions.None
Source code in ultralytics/nn/tasks.pyView on GitHub
def loss(self, batch, preds=None):
    """Compute the training loss, running the model with the batch's text features when needed.

    Args:
        batch (dict): Batch to compute loss on. "img" and "txt_feats" are read when ``preds`` is None.
        preds (torch.Tensor | list[torch.Tensor], optional): Predictions; computed via `forward` when omitted.

    Returns:
        Whatever the criterion returns for ``(preds, batch)``.
    """
    if not hasattr(self, "criterion"):
        self.criterion = self.init_criterion()  # built lazily on first use
    outputs = preds if preds is not None else self.forward(batch["img"], txt_feats=batch["txt_feats"])
    return self.criterion(outputs, batch)


method ultralytics.nn.tasks.WorldModel.predict

def predict(self, x, profile = False, visualize = False, txt_feats = None, augment = False, embed = None)

Perform a forward pass through the model.

Args

NameTypeDescriptionDefault
xtorch.TensorThe input tensor.required
profileboolIf True, profile the computation time for each layer.False
visualizeboolIf True, save feature maps for visualization.False
txt_featstorch.Tensor, optionalThe text features, use it if it's given.None
augmentboolIf True, perform data augmentation during inference.False
embedlist, optionalA list of feature vectors/embeddings to return.None

Returns

TypeDescription
torch.TensorModel's output tensor.
Source code in ultralytics/nn/tasks.pyView on GitHub
def predict(self, x, profile=False, visualize=False, txt_feats=None, augment=False, embed=None):
    """Perform a forward pass through the model, conditioning layers on text features.

    Args:
        x (torch.Tensor): The input tensor.
        profile (bool): If True, profile the computation time for each layer.
        visualize (bool): If True, save feature maps for visualization.
        txt_feats (torch.Tensor, optional): Text features to condition on. Defaults to ``self.txt_feats``.
        augment (bool): Accepted for interface compatibility; not used by this method.
        embed (list, optional): Layer indices whose outputs are global-average-pooled and returned as embeddings. An
            empty list behaves like None.

    Returns:
        (torch.Tensor | tuple[torch.Tensor, ...]): The last layer's output. When an index in ``embed`` is reached, it
            is instead the concatenated pooled features unbound along dim 0.
    """
    txt_feats = (self.txt_feats if txt_feats is None else txt_feats).to(device=x.device, dtype=x.dtype)
    if txt_feats.shape[0] != x.shape[0] or self.model[-1].export:
        # Broadcast text features over the image batch (expand needs a leading dim of 1 or x.shape[0])
        txt_feats = txt_feats.expand(x.shape[0], -1, -1)
    ori_txt_feats = txt_feats.clone()  # WorldDetect receives these, not the ImagePoolingAttn-updated features
    y, dt, embeddings = [], [], []  # outputs
    embed = frozenset(embed) if embed is not None else {-1}
    max_idx = max(embed, default=-1)  # default keeps embed=[] from raising ValueError
    for m in self.model:  # every layer, head included
        if m.f != -1:  # if not from previous layer
            x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
        if profile:
            self._profile_one_layer(m, x, dt)
        if isinstance(m, C2fAttn):
            x = m(x, txt_feats)  # consumes the current text features
        elif isinstance(m, WorldDetect):
            x = m(x, ori_txt_feats)
        elif isinstance(m, ImagePoolingAttn):
            txt_feats = m(x, txt_feats)  # updates txt_feats from x; x itself is left unchanged
        else:
            x = m(x)  # run

        y.append(x if m.i in self.save else None)  # save output
        if visualize:
            feature_visualization(x, m.type, m.i, save_dir=visualize)
        if m.i in embed:
            embeddings.append(torch.nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1))  # flatten
            if m.i == max_idx:  # last requested layer reached: stop early
                return torch.unbind(torch.cat(embeddings, 1), dim=0)
    return x


method ultralytics.nn.tasks.WorldModel.set_classes

def set_classes(self, text, batch = 80, cache_clip_model = True)

Set classes in advance so that the model can run offline inference without the CLIP model.

Args

NameTypeDescriptionDefault
textlist[str]List of class names.required
batchintBatch size for processing text tokens.80
cache_clip_modelboolWhether to cache the CLIP model.True
Source code in ultralytics/nn/tasks.pyView on GitHub
def set_classes(self, text, batch=80, cache_clip_model=True):
    """Precompute text features for ``text`` so the model can run offline inference without the CLIP model.

    Stores the embeddings from `get_text_pe` in ``self.txt_feats`` and sets the head's ``nc`` to ``len(text)``.

    Args:
        text (list[str]): List of class names.
        batch (int): Batch size for processing text tokens.
        cache_clip_model (bool): Whether to cache the CLIP model.
    """
    embeddings = self.get_text_pe(text, batch=batch, cache_clip_model=cache_clip_model)
    self.txt_feats = embeddings
    head = self.model[-1]
    head.nc = len(text)





class ultralytics.nn.tasks.YOLOEModel

YOLOEModel(self, cfg = "yoloe-v8s.yaml", ch = 3, nc = None, verbose = True)

Bases: DetectionModel

YOLOE detection model.

This class implements the YOLOE architecture for efficient object detection with text and visual prompts, supporting both prompt-based and prompt-free inference modes.

Args

NameTypeDescriptionDefault
cfgstr | dictModel configuration file path or dictionary."yoloe-v8s.yaml"
chintNumber of input channels.3
ncint, optionalNumber of classes.None
verboseboolWhether to display model information.True

Attributes

NameTypeDescription
petorch.TensorPrompt embeddings for classes.
clip_modeltorch.nn.ModuleCLIP model for text encoding.

Methods

NameDescription
get_cls_peGet class positional embeddings.
get_text_peEncode class names into text embeddings with the MobileCLIP text model.
get_visual_peGet visual embeddings.
get_vocabGet fused vocabulary layer from the model.
lossCompute loss.
predictPerform a forward pass through the model.
set_classesSet classes in advance so that the model can run offline inference without the CLIP model.
set_vocabSet vocabulary for the prompt-free model.

Examples

Initialize a YOLOE model
>>> model = YOLOEModel("yoloe-v8s.yaml", ch=3, nc=80)
>>> results = model.predict(image_tensor, tpe=text_embeddings)
Source code in ultralytics/nn/tasks.pyView on GitHub
class YOLOEModel(DetectionModel):
    """YOLOE detection model with text and visual prompts.

    Supports prompt-based inference (text or visual prompt embeddings) and a prompt-free mode that uses a fused
    vocabulary.

    Attributes:
        pe (torch.Tensor): Prompt embeddings for classes.
        clip_model (torch.nn.Module): CLIP model for text encoding.

    Methods:
        __init__: Initialize YOLOE model.
        get_text_pe: Encode class names into text embeddings.
        get_visual_pe: Get visual embeddings.
        set_vocab: Set vocabulary for prompt-free model.
        get_vocab: Get fused vocabulary layer.
        set_classes: Set classes for offline inference.
        get_cls_pe: Get class positional embeddings.
        predict: Perform forward pass with prompts.
        loss: Compute loss with prompts.

    Examples:
        Initialize a YOLOE model
        >>> model = YOLOEModel("yoloe-v8s.yaml", ch=3, nc=80)
        >>> results = model.predict(image_tensor, tpe=text_embeddings)
    """

    def __init__(self, cfg="yoloe-v8s.yaml", ch=3, nc=None, verbose=True):
        """Create a YOLOE model; construction is delegated entirely to DetectionModel.

        Args:
            cfg (str | dict): Model configuration file path or dictionary.
            ch (int): Number of input channels.
            nc (int, optional): Number of classes.
            verbose (bool): Whether to display model information.
        """
        super().__init__(verbose=verbose, nc=nc, ch=ch, cfg=cfg)


method ultralytics.nn.tasks.YOLOEModel.get_cls_pe

def get_cls_pe(self, tpe, vpe)

Get class positional embeddings.

Args

NameTypeDescriptionDefault
tpetorch.Tensor, optionalText positional embeddings.required
vpetorch.Tensor, optionalVisual positional embeddings.required

Returns

TypeDescription
torch.TensorClass positional embeddings.
Source code in ultralytics/nn/tasks.pyView on GitHub
def get_cls_pe(self, tpe, vpe):
    """Get class positional embeddings by concatenating the available prompt embeddings.

    Args:
        tpe (torch.Tensor, optional): Text positional embeddings (must be 3-D).
        vpe (torch.Tensor, optional): Visual positional embeddings (must be 3-D).

    Returns:
        (torch.Tensor): ``tpe`` and ``vpe`` concatenated along dim 1, text first. When both are None, this is
            ``self.pe``, or a ``(1, 80, 512)`` zero tensor if that attribute is missing.

    Raises:
        AssertionError: If a provided embedding is not 3-dimensional.
    """
    all_pe = []
    if tpe is not None:
        assert tpe.ndim == 3, f"expected 3-D text embeddings, got shape {tuple(tpe.shape)}"
        all_pe.append(tpe)
    if vpe is not None:
        assert vpe.ndim == 3, f"expected 3-D visual embeddings, got shape {tuple(vpe.shape)}"
        all_pe.append(vpe)
    if not all_pe:
        # Allocate the zero placeholder only when `pe` is missing (a getattr default is built on every call)
        all_pe.append(self.pe if hasattr(self, "pe") else torch.zeros(1, 80, 512))
    return torch.cat(all_pe, dim=1)


method ultralytics.nn.tasks.YOLOEModel.get_text_pe

def get_text_pe(self, text, batch = 80, cache_clip_model = False, without_reprta = False)

Encode class names into text embeddings with the MobileCLIP text model, optionally passing them through the RepRTA auxiliary text head.

Args

NameTypeDescriptionDefault
textlist[str]List of class names.required
batchintBatch size for processing text tokens.80
cache_clip_modelboolWhether to cache the CLIP model.False
without_reprtaboolIf True, return the raw text-encoder embeddings without applying the RepRTA auxiliary text head.False

Returns

TypeDescription
torch.TensorText positional embeddings.
Source code in ultralytics/nn/tasks.pyView on GitHub
@smart_inference_mode()
def get_text_pe(self, text, batch=80, cache_clip_model=False, without_reprta=False):
    """Encode class names into text embeddings with the MobileCLIP text model.

    This method only computes and returns the embeddings. The result can be passed to `set_classes` so the model can
    run offline inference without the CLIP model.

    Args:
        text (list[str]): List of class names.
        batch (int): Batch size for processing text tokens.
        cache_clip_model (bool): Whether to cache the text model on ``self.clip_model`` and reuse it on later calls.
        without_reprta (bool): If True, return the raw text-encoder embeddings. Otherwise they are passed through the
            YOLOEDetect head's ``get_tpe`` (the RepRTA auxiliary text head).

    Returns:
        (torch.Tensor): Text positional embeddings.
    """
    from ultralytics.nn.text_model import build_text_model

    device = next(self.model.parameters()).device
    if cache_clip_model and getattr(self, "clip_model", None) is None:
        # For backwards compatibility of models lacking clip_model attribute
        self.clip_model = build_text_model("mobileclip:blt", device=device)

    model = self.clip_model if cache_clip_model else build_text_model("mobileclip:blt", device=device)
    text_token = model.tokenize(text)
    txt_feats = [model.encode_text(token).detach() for token in text_token.split(batch)]  # encode in chunks
    txt_feats = txt_feats[0] if len(txt_feats) == 1 else torch.cat(txt_feats, dim=0)
    txt_feats = txt_feats.reshape(-1, len(text), txt_feats.shape[-1])
    if without_reprta:
        return txt_feats

    head = self.model[-1]
    assert isinstance(head, YOLOEDetect)
    return head.get_tpe(txt_feats)  # run auxiliary text head


method ultralytics.nn.tasks.YOLOEModel.get_visual_pe

def get_visual_pe(self, img, visual)

Get visual embeddings.

Args

NameTypeDescriptionDefault
imgtorch.TensorInput image tensor.required
visualtorch.TensorVisual features.required

Returns

TypeDescription
torch.TensorVisual positional embeddings.
Source code in ultralytics/nn/tasks.pyView on GitHub
@smart_inference_mode()
def get_visual_pe(self, img, visual):
    """Compute visual prompt embeddings for ``img``.

    Calls the model itself with ``vpe=visual`` and ``return_vpe=True``.

    Args:
        img (torch.Tensor): Input image tensor.
        visual (torch.Tensor): Visual features, passed as ``vpe``.

    Returns:
        (torch.Tensor): Visual positional embeddings.
    """
    visual_pe = self(img, vpe=visual, return_vpe=True)
    return visual_pe


method ultralytics.nn.tasks.YOLOEModel.get_vocab

def get_vocab(self, names)

Get fused vocabulary layer from the model.

Args

NameTypeDescriptionDefault
nameslistList of class names.required

Returns

TypeDescription
nn.ModuleListList of vocabulary modules.
Source code in ultralytics/nn/tasks.pyView on GitHub
def get_vocab(self, names):
    """Get fused vocabulary layer from the model.

    Steps:
    1. Encode ``names`` into text embeddings and register them with `set_classes`.
    2. Fuse the prompt embeddings ``self.pe`` into the YOLOEDetect head in place.
    3. Return the last layer of each ``head.cv3`` branch.

    Args:
        names (list): List of class names.

    Returns:
        (torch.nn.ModuleList): The final layer of every ``head.cv3`` branch, after fusion.

    Note:
        The model must be in eval mode and the head must not already be fused. The head is modified by this call.
    """
    assert not self.training, "get_vocab() requires the model to be in eval mode"
    head = self.model[-1]
    assert isinstance(head, YOLOEDetect), f"expected a YOLOEDetect head, got {type(head).__name__}"
    assert not head.is_fused, "the YOLOEDetect head is already fused"

    tpe = self.get_text_pe(names)
    self.set_classes(names, tpe)  # presumably stores the embeddings as self.pe (read below) — confirm
    device = next(self.model.parameters()).device
    head.fuse(self.pe.to(device))  # fuse prompt embeddings to classify head

    # Fully-qualified torch.nn names, consistent with the rest of the file (no reliance on an `nn` alias)
    vocab = torch.nn.ModuleList()
    for cls_head in head.cv3:
        assert isinstance(cls_head, torch.nn.Sequential)
        vocab.append(cls_head[-1])
    return vocab


method ultralytics.nn.tasks.YOLOEModel.loss

def loss(self, batch, preds = None)

Compute loss.

Args

NameTypeDescriptionDefault
batchdictBatch to compute loss on.required
predstorch.Tensor | list[torch.Tensor], optionalPredictions.None
Source code in ultralytics/nn/tasks.pyView on GitHub
def loss(self, batch, preds=None):
    """Compute the detection loss for ``batch``, creating the criterion lazily on first use.

    The criterion is chosen once, from the first batch seen. ``TVPDetectLoss`` is used when that batch carries a
    "visuals" entry; otherwise the default criterion from ``init_criterion()`` is used. When ``preds`` is omitted,
    a forward pass with the batch's text/visual prompts produces them.

    Args:
        batch (dict): Batch to compute loss on.
        preds (torch.Tensor | list[torch.Tensor], optional): Predictions.
    """
    if not hasattr(self, "criterion"):
        from ultralytics.utils.loss import TVPDetectLoss

        if batch.get("visuals", None) is None:  # TODO
            self.criterion = self.init_criterion()
        else:
            self.criterion = TVPDetectLoss(self)

    outputs = preds
    if outputs is None:
        text_feats = batch.get("txt_feats", None)
        visual_feats = batch.get("visuals", None)
        outputs = self.forward(batch["img"], tpe=text_feats, vpe=visual_feats)
    return self.criterion(outputs, batch)


method ultralytics.nn.tasks.YOLOEModel.predict

def predict(
    self, x, profile=False, visualize=False, tpe=None, augment=False, embed=None, vpe=None, return_vpe=False
)

Perform a forward pass through the model.

Args

NameTypeDescriptionDefault
xtorch.TensorThe input tensor.required
profileboolIf True, profile the computation time for each layer.False
visualizeboolIf True, save feature maps for visualization.False
tpetorch.Tensor, optionalText positional embeddings.None
augmentboolIf True, perform data augmentation during inference.False
embedlist, optionalA list of feature vectors/embeddings to return.None
vpetorch.Tensor, optionalVisual positional embeddings.None
return_vpeboolIf True, return visual positional embeddings.False

Returns

TypeDescription
torch.TensorModel's output tensor.
Source code in ultralytics/nn/tasks.pyView on GitHub
def predict(
    self, x, profile=False, visualize=False, tpe=None, augment=False, embed=None, vpe=None, return_vpe=False
):
    """Perform a forward pass through the model.

    Layers are run in order, feeding each from the previous layer or from saved earlier outputs (``m.f``). The
    ``YOLOEDetect`` head additionally receives class prompt embeddings built from the text (``tpe``) and/or visual
    (``vpe``) prompts.

    Args:
        x (torch.Tensor): The input tensor.
        profile (bool): If True, profile the computation time for each layer.
        visualize (bool): If True, save feature maps for visualization.
        tpe (torch.Tensor, optional): Text positional embeddings.
        augment (bool): If True, perform data augmentation during inference.
            NOTE(review): this parameter is accepted but not read anywhere in this body.
        embed (list, optional): A list of feature vectors/embeddings to return. These are layer indices; an empty
            list makes ``max(embed)`` raise ValueError.
        vpe (torch.Tensor, optional): Visual positional embeddings.
        return_vpe (bool): If True, return visual positional embeddings.

    Returns:
        (torch.Tensor): Model's output tensor. When ``return_vpe`` is True, the visual prompt embeddings computed by the
            head are returned instead. When ``embed`` is given, the globally average-pooled outputs of those layers are
            concatenated and returned as a tuple split along dim 0 (``torch.unbind``).
    """
    y, dt, embeddings = [], [], []  # outputs
    b = x.shape[0]
    embed = frozenset(embed) if embed is not None else {-1}
    max_idx = max(embed)  # stop as soon as the deepest requested embedding layer has run
    for m in self.model:  # all layers, including the head (the YOLOEDetect head is handled specially below)
        if m.f != -1:  # if not from previous layer
            x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
        if profile:
            self._profile_one_layer(m, x, dt)
        if isinstance(m, YOLOEDetect):
            vpe = m.get_vpe(x, vpe) if vpe is not None else None
            if return_vpe:
                assert vpe is not None
                assert not self.training
                return vpe
            # Combine text and visual prompts into per-class embeddings on the features' device/dtype
            cls_pe = self.get_cls_pe(m.get_tpe(tpe), vpe).to(device=x[0].device, dtype=x[0].dtype)
            if cls_pe.shape[0] != b or m.export:
                cls_pe = cls_pe.expand(b, -1, -1)  # broadcast the prompt embeddings to the batch size
            x = m(x, cls_pe)
        else:
            x = m(x)  # run

        y.append(x if m.i in self.save else None)  # save output
        if visualize:
            feature_visualization(x, m.type, m.i, save_dir=visualize)
        if m.i in embed:
            embeddings.append(torch.nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1))  # flatten
            if m.i == max_idx:
                return torch.unbind(torch.cat(embeddings, 1), dim=0)
    return x


method ultralytics.nn.tasks.YOLOEModel.set_classes

def set_classes(self, names, embeddings)

Set classes in advance so that the model can run offline inference without the CLIP model.

Args

NameTypeDescriptionDefault
nameslist[str]List of class names.required
embeddingstorch.TensorEmbeddings tensor.required
Source code in ultralytics/nn/tasks.pyView on GitHub
def set_classes(self, names, embeddings):
    """Store class names and their prompt embeddings so inference can run without the CLIP model.

    Prompt-free heads (those with an ``lrpc`` attribute) are rejected. The embeddings are kept on ``self.pe``,
    and the head's class count and the model's names are updated.

    Args:
        names (list[str]): List of class names.
        embeddings (torch.Tensor): Embeddings tensor.
    """
    head = self.model[-1]
    assert not hasattr(head, "lrpc"), (
        "Prompt-free model does not support setting classes. Please try with Text/Visual prompt models."
    )
    assert embeddings.ndim == 3
    self.pe = embeddings
    head.nc = len(names)
    self.names = check_class_names(names)


method ultralytics.nn.tasks.YOLOEModel.set_vocab

def set_vocab(self, vocab, names)

Set vocabulary for the prompt-free model.

Args

NameTypeDescriptionDefault
vocabnn.ModuleListList of vocabulary items.required
nameslist[str]List of class names.required
Source code in ultralytics/nn/tasks.pyView on GitHub
def set_vocab(self, vocab, names):
    """Turn the model into a prompt-free model using a precomputed vocabulary.

    First, a warmup forward pass is run so the head caches its anchors. Next, the last layer of every box branch
    (``cv2``) and class branch (``cv3``) is wrapped, together with the matching vocabulary module, into an
    ``LRPCHead``. Finally, those last layers are removed from their branches.

    Args:
        vocab (nn.ModuleList): List of vocabulary items.
        names (list[str]): List of class names.
    """
    assert not self.training
    head = self.model[-1]
    assert isinstance(head, YOLOEDetect)

    # Cache anchors for head with a dummy square input
    imgsz = self.args["imgsz"]
    device = next(self.parameters()).device
    self(torch.empty(1, 3, imgsz, imgsz).to(device))  # warmup

    # re-parameterization for prompt-free model (the third scale level is created with enabled=False)
    lrpc_heads = []
    for level, (vocab_layer, cls_branch, loc_branch) in enumerate(zip(vocab, head.cv3, head.cv2)):
        lrpc_heads.append(LRPCHead(vocab_layer, cls_branch[-1], loc_branch[-1], enabled=level != 2))
    head.lrpc = nn.ModuleList(lrpc_heads)

    for loc_branch, cls_branch in zip(head.cv2, head.cv3):
        assert isinstance(loc_branch, nn.Sequential)
        assert isinstance(cls_branch, nn.Sequential)
        del loc_branch[-1]
        del cls_branch[-1]
    head.nc = len(names)
    self.names = check_class_names(names)





class ultralytics.nn.tasks.YOLOESegModel

YOLOESegModel(self, cfg = "yoloe-v8s-seg.yaml", ch = 3, nc = None, verbose = True)

Bases: YOLOEModel, SegmentationModel

YOLOE segmentation model.

This class extends YOLOEModel to handle instance segmentation tasks with text and visual prompts, providing specialized loss computation for pixel-level object detection and segmentation.

Args

NameTypeDescriptionDefault
cfgstr | dictModel configuration file path or dictionary."yoloe-v8s-seg.yaml"
chintNumber of input channels.3
ncint, optionalNumber of classes.None
verboseboolWhether to display model information.True

Methods

NameDescription
lossCompute loss.

Examples

Initialize a YOLOE segmentation model
>>> model = YOLOESegModel("yoloe-v8s-seg.yaml", ch=3, nc=80)
>>> results = model.predict(image_tensor, tpe=text_embeddings)
Source code in ultralytics/nn/tasks.pyView on GitHub
class YOLOESegModel(YOLOEModel, SegmentationModel):
    """YOLOE model for instance segmentation driven by text or visual prompts.

    Prompt handling comes from ``YOLOEModel`` and the segmentation layout from ``SegmentationModel``. A
    prompt-aware segmentation loss is provided by the ``loss`` method.

    Methods:
        __init__: Initialize YOLOE segmentation model.
        loss: Compute loss with prompts for segmentation.

    Examples:
        Initialize a YOLOE segmentation model
        >>> model = YOLOESegModel("yoloe-v8s-seg.yaml", ch=3, nc=80)
        >>> results = model.predict(image_tensor, tpe=text_embeddings)
    """

    def __init__(self, cfg="yoloe-v8s-seg.yaml", ch=3, nc=None, verbose=True):
        """Build the YOLOE segmentation model.

        All arguments are forwarded unchanged, by keyword, to the next constructor in the MRO (starting with
        ``YOLOEModel``).

        Args:
            cfg (str | dict): Model configuration file path or dictionary.
            ch (int): Number of input channels.
            nc (int, optional): Number of classes.
            verbose (bool): Whether to display model information.
        """
        super().__init__(verbose=verbose, nc=nc, ch=ch, cfg=cfg)


method ultralytics.nn.tasks.YOLOESegModel.loss

def loss(self, batch, preds = None)

Compute loss.

Args

NameTypeDescriptionDefault
batchdictBatch to compute loss on.required
predstorch.Tensor | list[torch.Tensor], optionalPredictions.None
Source code in ultralytics/nn/tasks.pyView on GitHub
def loss(self, batch, preds=None):
    """Return the segmentation loss for ``batch``, creating the criterion lazily on the first call.

    The criterion is fixed by the first batch seen. ``TVPSegmentLoss`` is used when that batch has a "visuals"
    entry; otherwise the default from ``init_criterion()`` is used. Predictions come from a forward pass when not
    supplied.

    Args:
        batch (dict): Batch to compute loss on.
        preds (torch.Tensor | list[torch.Tensor], optional): Predictions.
    """
    if not hasattr(self, "criterion"):
        from ultralytics.utils.loss import TVPSegmentLoss

        uses_visual_prompts = batch.get("visuals", None) is not None  # TODO
        if uses_visual_prompts:
            self.criterion = TVPSegmentLoss(self)
        else:
            self.criterion = self.init_criterion()

    if preds is None:
        prompt_kwargs = {"tpe": batch.get("txt_feats", None), "vpe": batch.get("visuals", None)}
        preds = self.forward(batch["img"], **prompt_kwargs)
    return self.criterion(preds, batch)





class ultralytics.nn.tasks.Ensemble

Ensemble(self)

Bases: torch.nn.ModuleList

Ensemble of models.

This class allows combining multiple YOLO models into an ensemble for improved performance through model averaging or other ensemble techniques.

Methods

NameDescription
forwardGenerate predictions from all models in the ensemble.

Examples

Create an ensemble of models
>>> ensemble = Ensemble()
>>> ensemble.append(model1)
>>> ensemble.append(model2)
>>> results = ensemble(image_tensor)
Source code in ultralytics/nn/tasks.pyView on GitHub
class Ensemble(torch.nn.ModuleList):
    """Container that runs several YOLO models on the same input and merges their outputs.

    Models are added like list items (``append``). Calling the ensemble runs each member and combines the
    predictions, which can improve accuracy through model averaging or similar techniques.

    Methods:
        __init__: Initialize an ensemble of models.
        forward: Generate predictions from all models in the ensemble.

    Examples:
        Create an ensemble of models
        >>> ensemble = Ensemble()
        >>> ensemble.append(model1)
        >>> ensemble.append(model2)
        >>> results = ensemble(image_tensor)
    """

    def __init__(self):
        """Create an empty ensemble; member models are added afterwards with ``append``."""
        super().__init__()


method ultralytics.nn.tasks.Ensemble.forward

def forward(self, x, augment = False, profile = False, visualize = False)

Generate predictions from all models in the ensemble and concatenate them for NMS.

Args

NameTypeDescriptionDefault
xtorch.TensorInput tensor.required
augmentboolWhether to augment the input.False
profileboolWhether to profile the model.False
visualizeboolWhether to visualize the features.False

Returns

TypeDescription
y (torch.Tensor)Concatenated predictions from all models.
train_out (None)Always None for ensemble inference.
Source code in ultralytics/nn/tasks.pyView on GitHub
def forward(self, x, augment=False, profile=False, visualize=False):
    """Run every model in the ensemble on ``x`` and concatenate their predictions for NMS.

    Args:
        x (torch.Tensor): Input tensor.
        augment (bool): Whether to augment the input.
        profile (bool): Whether to profile the model.
        visualize (bool): Whether to visualize the features.

    Returns:
        y (torch.Tensor): Inference outputs of all models concatenated along dim 2.
        train_out (None): Always None for ensemble inference.
    """
    # Pass the options by keyword. Member models' predict() signatures order these parameters differently
    # (e.g. predict(x, profile, visualize, tpe, augment, ...)), so positional passing would mis-assign them.
    y = [module(x, augment=augment, profile=profile, visualize=visualize)[0] for module in self]
    # y = torch.stack(y).max(0)[0]  # max ensemble
    # y = torch.stack(y).mean(0)  # mean ensemble
    # nms ensemble: stack all members' candidates along dim 2 so a single NMS pass can merge them
    # (NOTE(review): assumes outputs are laid out as (B, C, N) with N candidates — confirm per head type)
    y = torch.cat(y, 2)
    return y, None  # inference, train output





class ultralytics.nn.tasks.SafeClass

SafeClass(self, *args, **kwargs)

A placeholder class to replace unknown classes during unpickling.

Args

NameTypeDescriptionDefault
*argsrequired
**kwargsrequired

Methods

NameDescription
__call__Run SafeClass instance, ignoring all arguments.
Source code in ultralytics/nn/tasks.pyView on GitHub
class SafeClass:
    """Inert stand-in that replaces classes which must not be resolved during unpickling.

    Instances can be built with any arguments; all of them are discarded.
    """

    def __init__(self, *args, **kwargs):
        """Accept arbitrary constructor arguments and ignore them."""


method ultralytics.nn.tasks.SafeClass.__call__

def __call__(self, *args, **kwargs)

Run SafeClass instance, ignoring all arguments.

Args

NameTypeDescriptionDefault
*argsrequired
**kwargsrequired
Source code in ultralytics/nn/tasks.pyView on GitHub
def __call__(self, *args, **kwargs):
    """Swallow any call: positional and keyword arguments are discarded and ``None`` is returned."""
    return None





class ultralytics.nn.tasks.SafeUnpickler

SafeUnpickler()

Bases: pickle.Unpickler

Custom Unpickler that replaces unknown classes with SafeClass.

Methods

NameDescription
find_classAttempt to find a class, returning SafeClass if not among safe modules.
Source code in ultralytics/nn/tasks.pyView on GitHub
class SafeUnpickler(pickle.Unpickler):
    """Custom Unpickler that replaces unknown classes with SafeClass."""


method ultralytics.nn.tasks.SafeUnpickler.find_class

def find_class(self, module, name)

Attempt to find a class, returning SafeClass if not among safe modules.

Args

NameTypeDescriptionDefault
modulestrModule name.required
namestrClass name.required

Returns

TypeDescription
typeFound class or SafeClass.
Source code in ultralytics/nn/tasks.pyView on GitHub
def find_class(self, module, name):
    """Attempt to find a class, returning SafeClass if not among safe modules.

    A global is resolved only when all three conditions hold:
    - its module is in the allow-list below;
    - the name is not a dotted path;
    - the name is not one of the explicitly blocked callables of that module.

    Dotted names are refused because, for pickle protocol 4+, the base ``Unpickler.find_class`` walks them
    attribute by attribute. That would let a crafted file reach objects merely reachable from an allowed module,
    such as modules the package imported internally. The blocked names are callables that could execute code,
    import modules, reach arbitrary attributes, or read/write files. Refused globals become SafeClass instead of
    raising, so loading degrades gracefully.

    Args:
        module (str): Module name.
        name (str): Class name.

    Returns:
        (type): Found class or SafeClass.
    """
    safe_modules = (
        "torch",
        "collections",
        "collections.abc",
        "builtins",
        "math",
        "numpy",
        # Add other modules considered safe
    )
    # Dangerous callables inside otherwise-allowed modules
    unsafe_names = {
        "builtins": frozenset(
            {
                "__build_class__",
                "__import__",
                "breakpoint",
                "compile",
                "delattr",
                "eval",
                "exec",
                "exit",
                "getattr",
                "globals",
                "help",
                "input",
                "locals",
                "open",
                "quit",
                "setattr",
                "vars",
            }
        ),
        "torch": frozenset({"load", "save"}),
        "numpy": frozenset(
            {"fromfile", "genfromtxt", "load", "loadtxt", "memmap", "save", "savetxt", "savez", "savez_compressed"}
        ),
    }
    if module in safe_modules and "." not in name and name not in unsafe_names.get(module, frozenset()):
        return super().find_class(module, name)
    return SafeClass





function ultralytics.nn.tasks.temporary_modules

def temporary_modules(modules = None, attributes = None)

Context manager for temporarily adding or modifying modules in Python's module cache (sys.modules).

This function can be used to change the module paths during runtime. It's useful when refactoring code, where you've moved a module from one location to another, but you still want to support the old import paths for backwards compatibility.

Args

NameTypeDescriptionDefault
modulesdict, optionalA dictionary mapping old module paths to new module paths.None
attributesdict, optionalA dictionary mapping old module attributes to new module attributes.None

Examples

>>> with temporary_modules({"old.module": "new.module"}, {"old.module.attribute": "new.module.attribute"}):
>>> import old.module  # this will now import new.module
>>> from old.module import attribute  # this will now import new.module.attribute

Notes

The changes are only in effect inside the context manager and are undone once the context manager exits. Be aware that directly manipulating sys.modules can lead to unpredictable results, especially in larger applications or libraries. Use this function with caution.

Source code in ultralytics/nn/tasks.pyView on GitHub
@contextlib.contextmanager
def temporary_modules(modules=None, attributes=None):
    """Context manager for temporarily adding or modifying modules in Python's module cache (`sys.modules`).

    This function can be used to change the module paths during runtime. It's useful when refactoring code, where you've
    moved a module from one location to another, but you still want to support the old import paths for backwards
    compatibility.

    Every change is recorded before it is made and reverted on exit, including when the body raises. ``sys.modules``
    entries and module attributes that existed beforehand get their previous values back; entries and attributes
    that did not exist are removed again.

    Args:
        modules (dict, optional): A dictionary mapping old module paths to new module paths.
        attributes (dict, optional): A dictionary mapping old module attributes to new module attributes.

    Examples:
        >>> with temporary_modules({"old.module": "new.module"}, {"old.module.attribute": "new.module.attribute"}):
        >>> import old.module  # this will now import new.module
        >>> from old.module import attribute  # this will now import new.module.attribute

    Notes:
        The changes are only in effect inside the context manager and are undone once the context manager exits.
        Be aware that directly manipulating `sys.modules` can lead to unpredictable results, especially in larger
        applications or libraries. Use this function with caution.
    """
    if modules is None:
        modules = {}
    if attributes is None:
        attributes = {}
    import sys
    from importlib import import_module

    missing = object()  # sentinel meaning "was not present before this context changed it"
    saved_attributes = []  # (module object, attribute name, previous value or `missing`)
    saved_modules = []  # (sys.modules key, previous entry or `missing`)
    try:
        # Set attributes in sys.modules under their old name
        for old, new in attributes.items():
            old_module, old_attr = old.rsplit(".", 1)
            new_module, new_attr = new.rsplit(".", 1)
            target = import_module(old_module)
            value = getattr(import_module(new_module), new_attr)
            saved_attributes.append((target, old_attr, vars(target).get(old_attr, missing)))
            setattr(target, old_attr, value)

        # Set modules in sys.modules under their old name
        for old, new in modules.items():
            replacement = import_module(new)
            saved_modules.append((old, sys.modules.get(old, missing)))
            sys.modules[old] = replacement

        yield
    finally:
        # Undo in reverse order so repeated or overlapping changes unwind back to the original state
        for old, previous in reversed(saved_modules):
            if previous is missing:
                sys.modules.pop(old, None)
            else:
                sys.modules[old] = previous
        for target, attr, previous in reversed(saved_attributes):
            if previous is missing:
                vars(target).pop(attr, None)
            else:
                setattr(target, attr, previous)





function ultralytics.nn.tasks.torch_safe_load

def torch_safe_load(weight, safe_only = False)

Attempt to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised, it catches the error, logs a warning message, and attempts to install the missing module via the check_requirements() function. After installation, the function again attempts to load the model using torch.load().

Args

NameTypeDescriptionDefault
weightstrThe file path of the PyTorch model.required
safe_onlyboolIf True, replace unknown classes with SafeClass during loading.False

Returns

TypeDescription
ckpt (dict)The loaded model checkpoint.
file (str)The loaded filename.

Examples

>>> from ultralytics.nn.tasks import torch_safe_load
>>> ckpt, file = torch_safe_load("path/to/best.pt", safe_only=True)
Source code in ultralytics/nn/tasks.pyView on GitHub
def torch_safe_load(weight, safe_only=False):
    """Attempt to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised, it catches
    the error, logs a warning message, and attempts to install the missing module via the check_requirements()
    function. After installation, the function again attempts to load the model using torch.load().

    Both load attempts use the same loader: legacy Ultralytics import paths are aliased to their current locations,
    tensors are mapped to CPU, and ``safe_only`` is honoured. As a result, the retry can also unpickle checkpoints
    saved by older versions.

    Args:
        weight (str): The file path of the PyTorch model.
        safe_only (bool): If True, replace unknown classes with SafeClass during loading.

    Returns:
        ckpt (dict): The loaded model checkpoint.
        file (str): The loaded filename.

    Raises:
        TypeError: If the checkpoint needs a top-level ``models`` module (a YOLOv5-repository model).
        ModuleNotFoundError: If the checkpoint needs ``numpy._core`` but an older numpy is installed.

    Examples:
        >>> from ultralytics.nn.tasks import torch_safe_load
        >>> ckpt, file = torch_safe_load("path/to/best.pt", safe_only=True)
    """
    from ultralytics.utils.downloads import attempt_download_asset

    check_suffix(file=weight, suffix=".pt")
    file = attempt_download_asset(weight)  # search online if missing locally

    def _load():
        """Load ``file`` on CPU with legacy import paths temporarily aliased to their current locations."""
        with temporary_modules(
            modules={
                "ultralytics.yolo.utils": "ultralytics.utils",
                "ultralytics.yolo.v8": "ultralytics.models.yolo",
                "ultralytics.yolo.data": "ultralytics.data",
            },
            attributes={
                "ultralytics.nn.modules.block.Silence": "torch.nn.Identity",  # YOLOv9e
                "ultralytics.nn.tasks.YOLOv10DetectionModel": "ultralytics.nn.tasks.DetectionModel",  # YOLOv10
                "ultralytics.utils.loss.v10DetectLoss": "ultralytics.utils.loss.E2EDetectLoss",  # YOLOv10
            },
        ):
            if safe_only:
                # Load via custom pickle module that swaps disallowed globals for SafeClass
                safe_pickle = types.ModuleType("safe_pickle")
                safe_pickle.Unpickler = SafeUnpickler
                safe_pickle.load = lambda file_obj: SafeUnpickler(file_obj).load()
                with open(file, "rb") as f:
                    return torch_load(f, pickle_module=safe_pickle, map_location="cpu")
            return torch_load(file, map_location="cpu")

    try:
        ckpt = _load()
    except ModuleNotFoundError as e:  # e.name is missing module name
        if e.name == "models":
            raise TypeError(
                emojis(
                    f"ERROR ❌️ {weight} appears to be an Ultralytics YOLOv5 model originally trained "
                    f"with https://github.com/ultralytics/yolov5.\nThis model is NOT forwards compatible with "
                    f"YOLOv8 at https://github.com/ultralytics/ultralytics."
                    f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "
                    f"run a command with an official Ultralytics model, i.e. 'yolo predict model=yolo11n.pt'"
                )
            ) from e
        elif e.name == "numpy._core":
            raise ModuleNotFoundError(
                emojis(
                    f"ERROR ❌️ {weight} requires numpy>=1.26.1, however numpy=={__import__('numpy').__version__} is installed."
                )
            ) from e
        LOGGER.warning(
            f"{weight} appears to require '{e.name}', which is not in Ultralytics requirements."
            f"\nAutoInstall will run now for '{e.name}' but this feature will be removed in the future."
            f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "
            f"run a command with an official Ultralytics model, i.e. 'yolo predict model=yolo11n.pt'"
        )
        check_requirements(e.name)  # install missing module
        ckpt = _load()  # retry with the same aliases and loader settings as the first attempt

    if not isinstance(ckpt, dict):
        # File is likely a YOLO instance saved with i.e. torch.save(model, "saved_model.pt")
        LOGGER.warning(
            f"The file '{weight}' appears to be improperly saved or formatted. "
            f"For optimal results, use model.save('filename.pt') to correctly save YOLO models."
        )
        ckpt = {"model": ckpt.model}

    return ckpt, file





function ultralytics.nn.tasks.load_checkpoint

def load_checkpoint(weight, device = None, inplace = True, fuse = False)

Load a single model's weights.

Args

NameTypeDescriptionDefault
weightstr | PathModel weight path.required
devicetorch.device, optionalDevice to load model to.None
inplaceboolWhether to do inplace operations.True
fuseboolWhether to fuse model.False

Returns

TypeDescription
model (torch.nn.Module)Loaded model.
ckpt (dict)Model checkpoint dictionary.
Source code in ultralytics/nn/tasks.pyView on GitHub
def load_checkpoint(weight, device=None, inplace=True, fuse=False):
    """Load a single model's weights.

    The EMA model is preferred over the raw ``model`` entry, converted to FP32, annotated with the merged training
    args, the checkpoint path, a task and a stride, optionally fused, and moved to ``device`` in eval mode.

    Args:
        weight (str | Path): Model weight path.
        device (torch.device, optional): Device to load model to.
        inplace (bool): Whether to do inplace operations.
        fuse (bool): Whether to fuse model.

    Returns:
        model (torch.nn.Module): Loaded model.
        ckpt (dict): Model checkpoint dictionary.
    """
    ckpt, weight = torch_safe_load(weight)  # load ckpt
    # Combine model and default args, preferring model args; a stored ``train_args`` of None counts as empty
    args = {**DEFAULT_CFG_DICT, **(ckpt.get("train_args") or {})}
    model = (ckpt.get("ema") or ckpt["model"]).float()  # FP32 model

    # Model compatibility updates
    model.args = args  # attach args to model
    model.pt_path = weight  # attach *.pt file path to model
    if not hasattr(model, "task"):
        # Guess only when needed; a getattr() default would run guess_model_task() even when task is already set
        model.task = guess_model_task(model)
    if not hasattr(model, "stride"):
        model.stride = torch.tensor([32.0])

    model = (model.fuse() if fuse and hasattr(model, "fuse") else model).eval().to(device)  # model in eval mode

    # Module updates
    for m in model.modules():
        if hasattr(m, "inplace"):
            m.inplace = inplace
        elif isinstance(m, torch.nn.Upsample) and not hasattr(m, "recompute_scale_factor"):
            m.recompute_scale_factor = None  # torch 1.11.0 compatibility

    # Return model and ckpt
    return model, ckpt





function ultralytics.nn.tasks.parse_model

def parse_model(d, ch, verbose = True)

Parse a YOLO model.yaml dictionary into a PyTorch model.

Args

NameTypeDescriptionDefault
ddictModel dictionary.required
chintInput channels.required
verboseboolWhether to print model details.True

Returns

TypeDescription
model (torch.nn.Sequential)PyTorch model.
save (list)Sorted list of output layers.
Source code in ultralytics/nn/tasks.pyView on GitHub
def parse_model(d, ch, verbose=True):
    """Parse a YOLO model.yaml dictionary into a PyTorch model.

    Each row of ``d["backbone"] + d["head"]`` is a ``(from, number, module, args)`` entry. The module name is
    resolved in one of three places:
    - ``torch.nn``, for names containing ``"nn."``;
    - ``torchvision.ops``, for names containing ``"torchvision.ops."``;
    - this module's globals, otherwise.

    The args are then rewritten with the input/output channel counts and the depth/width scaling, and the layers
    are built in order.

    Side effects: when ``d`` has an ``activation`` entry, the class attribute ``Conv.default_act`` is replaced
    (presumably affecting Conv layers built afterwards). ``legacy`` is also written as a class attribute onto the
    detection/segmentation/pose/OBB head classes.

    Args:
        d (dict): Model dictionary. Keys read here: "backbone", "head", "nc", "activation", "scales", "scale",
            "depth_multiple", "width_multiple", "kpt_shape".
        ch (int): Input channels.
        verbose (bool): Whether to print model details.

    Returns:
        model (torch.nn.Sequential): PyTorch model.
        save (list): Sorted list of output layers.
    """
    import ast

    # Args
    legacy = True  # backward compatibility for v3/v5/v8/v9 models
    max_channels = float("inf")
    nc, act, scales = (d.get(x) for x in ("nc", "activation", "scales"))
    depth, width, kpt_shape = (d.get(x, 1.0) for x in ("depth_multiple", "width_multiple", "kpt_shape"))
    scale = d.get("scale")
    if scales:
        if not scale:
            scale = next(iter(scales.keys()))
            LOGGER.warning(f"no model scale passed. Assuming scale='{scale}'.")
        depth, width, max_channels = scales[scale]  # each scale entry is [depth, width, max_channels]

    if act:
        # NOTE(review): eval() of a string taken from the model YAML; only parse trusted YAML files.
        Conv.default_act = eval(act)  # redefine default activation, i.e. Conv.default_act = torch.nn.SiLU()
        if verbose:
            LOGGER.info(f"{colorstr('activation:')} {act}")  # print

    if verbose:
        LOGGER.info(f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10}  {'module':<45}{'arguments':<30}")
    ch = [ch]
    layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
    # Modules whose YAML args start with the output channel count; c1 is taken from the input layer
    base_modules = frozenset(
        {
            Classify,
            Conv,
            ConvTranspose,
            GhostConv,
            Bottleneck,
            GhostBottleneck,
            SPP,
            SPPF,
            C2fPSA,
            C2PSA,
            DWConv,
            Focus,
            BottleneckCSP,
            C1,
            C2,
            C2f,
            C3k2,
            RepNCSPELAN4,
            ELAN1,
            ADown,
            AConv,
            SPPELAN,
            C2fAttn,
            C3,
            C3TR,
            C3Ghost,
            torch.nn.ConvTranspose2d,
            DWConvTranspose2d,
            C3x,
            RepC3,
            PSA,
            SCDown,
            C2fCIB,
            A2C2f,
        }
    )
    repeat_modules = frozenset(  # modules with 'repeat' arguments
        {
            BottleneckCSP,
            C1,
            C2,
            C2f,
            C3k2,
            C2fAttn,
            C3,
            C3TR,
            C3Ghost,
            C3x,
            RepC3,
            C2fPSA,
            C2fCIB,
            C2PSA,
            A2C2f,
        }
    )
    for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]):  # from, number, module, args
        # m[3:] / m[16:] assume the name starts with "nn." / "torchvision.ops." respectively
        m = (
            getattr(torch.nn, m[3:])
            if "nn." in m
            else getattr(__import__("torchvision").ops, m[16:])
            if "torchvision.ops." in m
            else globals()[m]
        )  # get module
        for j, a in enumerate(args):
            if isinstance(a, str):
                # Resolve names of local variables (e.g. "nc", "kpt_shape") or Python literals; strings for which
                # literal_eval raises ValueError are kept as-is. NOTE(review): literal_eval raises SyntaxError
                # (not suppressed here) for strings that are not valid Python expressions.
                with contextlib.suppress(ValueError):
                    args[j] = locals()[a] if a in locals() else ast.literal_eval(a)
        n = n_ = max(round(n * depth), 1) if n > 1 else n  # depth gain
        if m in base_modules:
            c1, c2 = ch[f], args[0]
            if c2 != nc:  # if c2 not equal to number of classes (i.e. for Classify() output)
                c2 = make_divisible(min(c2, max_channels) * width, 8)
            if m is C2fAttn:  # set 1) embed channels and 2) num heads
                args[1] = make_divisible(min(args[1], max_channels // 2) * width, 8)
                args[2] = int(max(round(min(args[2], max_channels // 2 // 32)) * width, 1) if args[2] > 1 else args[2])

            args = [c1, c2, *args[1:]]
            if m in repeat_modules:
                args.insert(2, n)  # number of repeats
                n = 1
            # NOTE(review): the `scale in "..."` checks below assume scale is a one-letter string. When the YAML has
            # no "scales", scale may be "" (guess_model_scale's fallback), making the check True, or None, which
            # raises TypeError.
            if m is C3k2:  # for M/L/X sizes
                legacy = False
                if scale in "mlx":
                    args[3] = True
            if m is A2C2f:
                legacy = False
                if scale in "lx":  # for L/X sizes
                    args.extend((True, 1.2))
            if m is C2fCIB:
                legacy = False
        elif m is AIFI:
            args = [ch[f], *args]
        elif m in frozenset({HGStem, HGBlock}):
            c1, cm, c2 = ch[f], args[0], args[1]
            args = [c1, cm, c2, *args[2:]]
            if m is HGBlock:
                args.insert(4, n)  # number of repeats
                n = 1
        elif m is ResNetLayer:
            c2 = args[1] if args[3] else args[1] * 4
        elif m is torch.nn.BatchNorm2d:
            args = [ch[f]]
        elif m is Concat:
            c2 = sum(ch[x] for x in f)
        elif m in frozenset(
            {Detect, WorldDetect, YOLOEDetect, Segment, YOLOESegment, Pose, OBB, ImagePoolingAttn, v10Detect}
        ):
            args.append([ch[x] for x in f])  # heads receive the channel counts of all their input layers
            if m is Segment or m is YOLOESegment:
                args[2] = make_divisible(min(args[2], max_channels) * width, 8)
            if m in {Detect, YOLOEDetect, Segment, YOLOESegment, Pose, OBB}:
                # Class attribute; False once a C3k2/A2C2f/C2fCIB layer appeared earlier in this model
                m.legacy = legacy
        elif m is RTDETRDecoder:  # special case, channels arg must be passed in index 1
            args.insert(1, [ch[x] for x in f])
        elif m is CBLinear:
            c2 = args[0]
            c1 = ch[f]
            args = [c1, c2, *args[1:]]
        elif m is CBFuse:
            c2 = ch[f[-1]]
        elif m in frozenset({TorchVision, Index}):
            c2 = args[0]
            c1 = ch[f]
            args = [*args[1:]]
        else:
            c2 = ch[f]

        m_ = torch.nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)  # module
        t = str(m)[8:-2].replace("__main__.", "")  # module type (strips "<class '" and "'>")
        m_.np = sum(x.numel() for x in m_.parameters())  # number params
        m_.i, m_.f, m_.type = i, f, t  # attach index, 'from' index, type
        if verbose:
            LOGGER.info(f"{i:>3}{f!s:>20}{n_:>3}{m_.np:10.0f}  {t:<45}{args!s:<30}")  # print
        # Record absolute indices of earlier layers whose outputs must be kept (-1 = previous layer, always available)
        save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
        layers.append(m_)
        if i == 0:
            ch = []  # drop the input-channel entry so ch[k] is the output channel count of layer k
        ch.append(c2)
    return torch.nn.Sequential(*layers), sorted(save)





function ultralytics.nn.tasks.yaml_model_load

def yaml_model_load(path)

Load a YOLO model configuration dictionary from a YAML file.

Args

NameTypeDescriptionDefault
pathstr | PathPath to the YAML file.required

Returns

TypeDescription
dictModel dictionary.
Source code in ultralytics/nn/tasks.pyView on GitHub
def yaml_model_load(path):
    """Load a YOLO model configuration dictionary from a YAML file.

    Legacy P6 names (e.g. ``yolov8n6``) are renamed to the ``-p6`` form. To locate the shared config, the scale
    letter is stripped from the file name (``yolov8x.yaml`` -> ``yolov8.yaml``); if no such file is found, the
    original path is used.

    Args:
        path (str | Path): Path to the YAML file.

    Returns:
        (dict): Model dictionary, with ``scale`` (guessed from the file name) and ``yaml_file`` (the resolved path)
            added.
    """
    path = Path(path)
    if path.stem in (f"yolov{d}{x}6" for x in "nsmlx" for d in (5, 8)):
        new_stem = re.sub(r"(\d+)([nslmx])6(.+)?$", r"\1\2-p6\3", path.stem)
        LOGGER.warning(f"Ultralytics YOLO P6 models now use -p6 suffix. Renaming {path.stem} to {new_stem}.")
        path = path.with_name(new_stem + path.suffix)

    # Strip the scale letter from the file name only, i.e. yolov8x.yaml -> yolov8.yaml. Applying the regex to the
    # whole path string could rewrite digit+letter sequences in parent directories (e.g. "runs2s/yolov8n.yaml").
    unified_path = str(path.parent / re.sub(r"(\d+)([nslmx])(.+)?$", r"\1\3", path.name))
    yaml_file = check_yaml(unified_path, hard=False) or check_yaml(path)
    d = YAML.load(yaml_file)  # model dict
    d["scale"] = guess_model_scale(path)
    d["yaml_file"] = str(path)
    return d





function ultralytics.nn.tasks.guess_model_scale

def guess_model_scale(model_path)

Extract the size character n, s, m, l, or x of the model's scale from the model path.

Args

NameTypeDescriptionDefault
model_pathstr | PathThe path to the YOLO model's YAML file.required

Returns

TypeDescription
strThe size character of the model's scale (n, s, m, l, or x).
Source code in ultralytics/nn/tasks.pyView on GitHub
def guess_model_scale(model_path):
    """Return the scale letter (n, s, m, l or x) encoded in a YOLO model file name.

    The file stem is searched for ``yolo`` optionally followed by ``e-`` and/or ``v``, then digits and a scale letter
    (e.g. ``yolov8n``, ``yolo11x-seg``, ``yoloe-v8l-seg``).

    Args:
        model_path (str | Path): The path to the YOLO model's YAML file.

    Returns:
        (str): The size character of the model's scale (n, s, m, l, or x), or an empty string if the stem does not
            contain that pattern.
    """
    stem = Path(model_path).stem
    found = re.search(r"yolo(e-)?[v]?\d+([nslmx])", stem)
    if found is None:
        return ""
    return found.group(2)





function ultralytics.nn.tasks.guess_model_task

def guess_model_task(model)

Guess the task of a PyTorch model from its architecture or configuration.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | torch.nn.Module \| dict \| str \| Path | PyTorch model, model configuration in YAML format, or model file path. | required |

Returns

| Type | Description |
| --- | --- |
| str | Task of the model ('detect', 'segment', 'classify', 'pose', 'obb'). |
Source code in `ultralytics/nn/tasks.py` · View on GitHub
def guess_model_task(model):
    """Guess the task of a PyTorch model from its architecture or configuration.

    Sources are tried in order and the first one that yields a task wins:
    - a YAML config dict;
    - for a ``torch.nn.Module``, a ``task`` entry in ``args`` on the model or up to two nested ``.model`` wrappers;
    - then the ``yaml`` config at those same depths;
    - then the head module types found in ``model.modules()``;
    - for a path, keywords in its stem or directory names.

    A source that raises, or that does not recognize the task (e.g. an unknown head module or ``task=None``), is
    skipped instead of being returned.

    Args:
        model (torch.nn.Module | dict | str | Path): PyTorch model, model configuration in YAML format, or model path.

    Returns:
        (str): Task of the model ('detect', 'segment', 'classify', 'pose', 'obb'). Falls back to 'detect' with a
            warning when no source identifies the task.
    """

    def cfg2task(cfg):
        """Guess from YAML dictionary, returning None when the output module name is not recognized."""
        m = cfg["head"][-1][-2].lower()  # output module name
        if m in {"classify", "classifier", "cls", "fc"}:
            return "classify"
        if "detect" in m:
            return "detect"
        if "segment" in m:
            return "segment"
        if m == "pose":
            return "pose"
        if m == "obb":
            return "obb"
        return None

    def best_effort(fn, *args):
        """Return fn(*args), or None if it raises; every guess below is deliberately best-effort."""
        with contextlib.suppress(Exception):
            return fn(*args)
        return None

    def unwrap(depth):
        """Return `model` after following its `.model` attribute `depth` times (model, model.model, ...)."""
        obj = model
        for _ in range(depth):
            obj = obj.model
        return obj

    def task_from_args(depth):
        """Read the 'task' entry of the `args` mapping found at the given wrapper depth."""
        return unwrap(depth).args["task"]

    def task_from_yaml(depth):
        """Guess the task from the `yaml` config found at the given wrapper depth."""
        return cfg2task(unwrap(depth).yaml)

    # Guess from model cfg
    if isinstance(model, dict):
        task = best_effort(cfg2task, model)
        if task:
            return task
    # Guess from PyTorch model
    if isinstance(model, torch.nn.Module):  # PyTorch model
        # Every `args` depth is tried before any `yaml` depth, same order as the former eval() string lists
        for guess in (task_from_args, task_from_yaml):
            for depth in range(3):  # model, model.model, model.model.model
                task = best_effort(guess, depth)
                if task:
                    return task
        for m in model.modules():
            if isinstance(m, (Segment, YOLOESegment)):
                return "segment"
            elif isinstance(m, Classify):
                return "classify"
            elif isinstance(m, Pose):
                return "pose"
            elif isinstance(m, OBB):
                return "obb"
            elif isinstance(m, (Detect, WorldDetect, YOLOEDetect, v10Detect)):
                return "detect"

    # Guess from model filename
    if isinstance(model, (str, Path)):
        model = Path(model)
        if "-seg" in model.stem or "segment" in model.parts:
            return "segment"
        elif "-cls" in model.stem or "classify" in model.parts:
            return "classify"
        elif "-pose" in model.stem or "pose" in model.parts:
            return "pose"
        elif "-obb" in model.stem or "obb" in model.parts:
            return "obb"
        elif "detect" in model.parts:
            return "detect"

    # Unable to determine task from model
    LOGGER.warning(
        "Unable to automatically guess model task, assuming 'task=detect'. "
        "Explicitly define task for your model, i.e. 'task=detect', 'segment', 'classify','pose' or 'obb'."
    )
    return "detect"  # assume detect





📅 Created 2 years ago ✏️ Updated 2 days ago
Contributors: glenn-jocher, RizwanMunawar, jk4e, Burhan-Q