Skip to content

Reference for ultralytics/utils/torch_utils.py

Note

This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/utils/torch_utils.py. If you spot a problem please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!


ultralytics.utils.torch_utils.ModelEMA

ModelEMA(model, decay=0.9999, tau=2000, updates=0)

Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models.

Keeps a moving average of everything in the model state_dict (parameters and buffers). For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage

To disable EMA set the enabled attribute to False.

Attributes:

Name Type Description
ema Module

Copy of the model in evaluation mode.

updates int

Number of EMA updates.

decay function

Decay function that determines the EMA weight.

enabled bool

Whether EMA is enabled.

Parameters:

Name Type Description Default
model Module

Model to create EMA for.

required
decay float

Maximum EMA decay rate. Defaults to 0.9999.

0.9999
tau int

EMA decay time constant. Defaults to 2000.

2000
updates int

Initial number of updates. Defaults to 0.

0
Source code in ultralytics/utils/torch_utils.py
def __init__(self, model, decay=0.9999, tau=2000, updates=0):
    """
    Create an exponential-moving-average copy of `model`.

    The per-update EMA weight ramps from 0 towards `decay` following `decay * (1 - exp(-updates / tau))`, so the
    average tracks the raw weights closely during early training and becomes smoother later.

    Args:
        model (nn.Module): Model to create EMA for.
        decay (float, optional): Maximum EMA decay rate. Defaults to 0.9999.
        tau (int, optional): EMA decay time constant. Defaults to 2000.
        updates (int, optional): Initial number of updates. Defaults to 0.
    """
    max_decay, time_constant = decay, tau

    def _ramp(n):
        """Return the EMA decay to apply at update count `n`, rising exponentially toward `max_decay`."""
        return max_decay * (1 - math.exp(-n / time_constant))

    # Deep copy of the model after `de_parallel` (presumably unwraps DP/DDP); original notes it as the FP32 EMA.
    self.ema = deepcopy(de_parallel(model)).eval()
    self.updates = updates  # running count of EMA updates
    self.decay = _ramp
    for param in self.ema.parameters():
        param.requires_grad_(False)  # EMA weights are only ever blended, never trained
    self.enabled = True

update

update(model)

Update EMA parameters.

Parameters:

Name Type Description Default
model Module

Model to update EMA from.

required
Source code in ultralytics/utils/torch_utils.py
def update(self, model):
    """
    Blend the current weights of `model` into the EMA state in place.

    Each floating-point entry of the EMA state_dict becomes `d * ema + (1 - d) * model`, where
    `d = self.decay(self.updates)` after incrementing `self.updates`. Non-floating entries are left untouched.
    Does nothing when `self.enabled` is False.

    Args:
        model (nn.Module): Model to update EMA from.
    """
    if not self.enabled:
        return
    self.updates += 1
    weight = self.decay(self.updates)

    source_state = de_parallel(model).state_dict()
    for name, ema_tensor in self.ema.state_dict().items():
        if not ema_tensor.dtype.is_floating_point:  # true for FP16 and FP32; skip everything else
            continue
        # state_dict tensors share storage with the EMA model, so these in-place ops update it directly
        ema_tensor *= weight
        ema_tensor += (1 - weight) * source_state[name].detach()

update_attr

update_attr(model, include=(), exclude=('process_group', 'reducer'))

Copies attributes from the given model onto the EMA model (respecting include/exclude) when EMA is enabled.

Parameters:

Name Type Description Default
model Module

Model to update attributes from.

required
include tuple

Attributes to include. Defaults to ().

()
exclude tuple

Attributes to exclude. Defaults to ("process_group", "reducer").

('process_group', 'reducer')
Source code in ultralytics/utils/torch_utils.py
def update_attr(self, model, include=(), exclude=("process_group", "reducer")):
    """
    Copy attributes from `model` onto the EMA model via `copy_attr`, honouring `include` / `exclude`.

    Does nothing when `self.enabled` is False.

    Args:
        model (nn.Module): Model to update attributes from.
        include (tuple, optional): Attributes to include. Defaults to ().
        exclude (tuple, optional): Attributes to exclude. Defaults to ("process_group", "reducer").
    """
    if not self.enabled:
        return
    copy_attr(self.ema, model, include, exclude)





ultralytics.utils.torch_utils.EarlyStopping

EarlyStopping(patience=50)

Early stopping class that stops training when a specified number of epochs have passed without improvement.

Attributes:

Name Type Description
best_fitness float

Best fitness value observed.

best_epoch int

Epoch where best fitness was observed.

patience int

Number of epochs to wait after fitness stops improving before stopping.

possible_stop bool

Flag indicating if stopping may occur next epoch.

Parameters:

Name Type Description Default
patience int

Number of epochs to wait after fitness stops improving before stopping.

50
Source code in ultralytics/utils/torch_utils.py
def __init__(self, patience=50):
    """
    Set up the early-stopping tracker.

    Args:
        patience (int, optional): Number of epochs to wait after fitness stops improving before stopping. A falsy
            value (0 or None) disables early stopping by using an infinite patience.
    """
    self.best_fitness = 0.0  # best fitness seen so far (e.g. mAP)
    self.best_epoch = 0  # epoch at which best_fitness was recorded
    self.patience = patience if patience else float("inf")
    self.possible_stop = False  # becomes True when a stop may occur next epoch

__call__

__call__(epoch, fitness)

Check whether to stop training.

Parameters:

Name Type Description Default
epoch int

Current epoch of training

required
fitness float

Fitness value of current epoch

required

Returns:

Type Description
bool

True if training should stop, False otherwise

Source code in ultralytics/utils/torch_utils.py
def __call__(self, epoch, fitness):
    """
    Record this epoch's fitness and decide whether training should stop.

    Fitness counts as an improvement when it beats the best so far, or whenever the best is still exactly 0 (early
    zero-fitness stage). Training stops once `epoch - best_epoch` reaches `patience`.

    Args:
        epoch (int): Current epoch of training
        fitness (float): Fitness value of current epoch; None (e.g. when val=False) never triggers a stop.

    Returns:
        (bool): True if training should stop, False otherwise
    """
    if fitness is None:  # no validation metric available, nothing to compare
        return False

    improved = fitness > self.best_fitness
    if improved or self.best_fitness == 0:
        self.best_fitness, self.best_epoch = fitness, epoch

    stale_epochs = epoch - self.best_epoch  # epochs without improvement
    self.possible_stop = stale_epochs >= (self.patience - 1)
    should_stop = stale_epochs >= self.patience
    if should_stop:
        tag = colorstr("EarlyStopping: ")
        LOGGER.info(
            f"{tag}Training stopped early as no improvement observed in last {self.patience} epochs. "
            f"Best results observed at epoch {self.best_epoch}, best model saved as best.pt.\n"
            f"To update EarlyStopping(patience={self.patience}) pass a new patience value, "
            f"i.e. `patience=300` or use `patience=0` to disable EarlyStopping."
        )
    return should_stop





ultralytics.utils.torch_utils.FXModel

FXModel(model)

Bases: Module

A custom model class for torch.fx compatibility.

This class extends torch.nn.Module and is designed to ensure compatibility with torch.fx for tracing and graph manipulation. It copies attributes from an existing model and explicitly sets the model attribute to ensure proper copying.

Attributes:

Name Type Description
model Module

The original model's layers.

Parameters:

Name Type Description Default
model Module

The original model to wrap for torch.fx compatibility.

required
Source code in ultralytics/utils/torch_utils.py
def __init__(self, model):
    """
    Wrap `model` so its layers can be traced with torch.fx.

    Args:
        model (nn.Module): The original model to wrap for torch.fx compatibility. It must expose a `model` attribute
            holding the layer sequence that `forward` iterates over.
    """
    super().__init__()
    copy_attr(self, model)  # mirror the source model's attributes onto this wrapper
    layers = model.model
    self.model = layers  # assigned explicitly: per the original note, `copy_attr` does not carry it over

forward

forward(x)

Forward pass through the model.

This method performs the forward pass through the model, handling the dependencies between layers and saving intermediate outputs.

Parameters:

Name Type Description Default
x Tensor

The input tensor to the model.

required

Returns:

Type Description
Tensor

The output tensor from the model.

Source code in ultralytics/utils/torch_utils.py
def forward(self, x):
    """
    Run `x` through every layer of `self.model`, routing inputs according to each layer's `f` attribute.

    `f == -1` feeds the previous layer's output; an int `f` feeds that earlier layer's saved output; a sequence `f`
    feeds a list built from saved outputs (with -1 meaning the current value).

    Args:
        x (torch.Tensor): The input tensor to the model.

    Returns:
        (torch.Tensor): The output tensor from the model.
    """
    outputs = []  # every layer's output, indexed by layer position
    for layer in self.model:
        src = layer.f
        if src != -1:  # input comes from earlier layer(s), not just the previous one
            if isinstance(src, int):
                x = outputs[src]
            else:
                x = [x if idx == -1 else outputs[idx] for idx in src]
        x = layer(x)
        outputs.append(x)
    return x





ultralytics.utils.torch_utils.torch_distributed_zero_first

torch_distributed_zero_first(local_rank: int)

Ensures all processes in distributed training wait for the local master (rank 0) to complete a task first.

Source code in ultralytics/utils/torch_utils.py
@contextmanager
def torch_distributed_zero_first(local_rank: int):
    """
    Ensures all processes in distributed training wait for the local master (rank 0) to complete a task first.

    Ranks other than -1 and 0 block on a barrier before entering the body; rank 0 enters immediately and reaches the
    matching barrier on exit, releasing the others. Without an initialized process group no barrier is used.

    Args:
        local_rank (int): Local process rank. -1 (conventionally "not distributed") triggers no barrier at all.
    """
    in_process_group = dist.is_available() and dist.is_initialized()
    is_master = local_rank == 0
    waits_for_master = in_process_group and local_rank not in {-1, 0}

    if waits_for_master:
        dist.barrier(device_ids=[local_rank])
    # NOTE(review): if the body raises on rank 0, the exit barrier below is never reached and the other ranks keep
    # waiting; consider try/finally if that matters for callers.
    yield
    if in_process_group and is_master:
        dist.barrier(device_ids=[local_rank])





ultralytics.utils.torch_utils.smart_inference_mode

smart_inference_mode()

Applies torch.inference_mode() decorator if torch>=1.9.0 else torch.no_grad() decorator.

Source code in ultralytics/utils/torch_utils.py
def smart_inference_mode():
    """Return a decorator that applies torch.inference_mode() on torch>=1.9.0, otherwise torch.no_grad()."""

    def decorate(fn):
        """Wrap `fn` in the best available no-grad context, or return it unchanged if inference mode is already on."""
        if TORCH_1_9 and torch.is_inference_mode_enabled():
            return fn  # already in inference_mode, act as a pass-through
        grad_ctx = torch.inference_mode if TORCH_1_9 else torch.no_grad
        return grad_ctx()(fn)

    return decorate





ultralytics.utils.torch_utils.autocast

autocast(enabled: bool, device: str = 'cuda')

Get the appropriate autocast context manager based on PyTorch version and AMP setting.

This function returns a context manager for automatic mixed precision (AMP) training that is compatible with both older and newer versions of PyTorch. It handles the differences in the autocast API between PyTorch versions.

Parameters:

Name Type Description Default
enabled bool

Whether to enable automatic mixed precision.

required
device str

The device to use for autocast. Defaults to 'cuda'.

'cuda'

Returns:

Type Description
autocast

The appropriate autocast context manager.

Notes
  • For PyTorch versions 1.13 and newer, it uses torch.amp.autocast.
  • For older versions, it uses torch.cuda.autocast.

Examples:

>>> with autocast(enabled=True):
...     # Your mixed precision operations here
...     pass
Source code in ultralytics/utils/torch_utils.py
def autocast(enabled: bool, device: str = "cuda"):
    """
    Get the appropriate autocast context manager based on PyTorch version and AMP setting.

    On torch>=1.13 this is `torch.amp.autocast(device, enabled=enabled)`. Older versions fall back to the CUDA-only
    `torch.cuda.amp.autocast(enabled)`, in which case `device` is ignored.

    Args:
        enabled (bool): Whether to enable automatic mixed precision.
        device (str, optional): The device to use for autocast. Defaults to 'cuda'. Ignored on torch<1.13.

    Returns:
        (torch.amp.autocast): The appropriate autocast context manager.

    Examples:
        >>> with autocast(enabled=True):
        ...     # Your mixed precision operations here
        ...     pass
    """
    if not TORCH_1_13:
        return torch.cuda.amp.autocast(enabled)  # legacy API takes no device argument
    return torch.amp.autocast(device, enabled=enabled)





ultralytics.utils.torch_utils.get_cpu_info

get_cpu_info()

Return a string with system CPU information, e.g. 'Apple M2'.

Source code in ultralytics/utils/torch_utils.py
def get_cpu_info():
    """Return a string with system CPU information, e.g. 'Apple M2' (cached; 'unknown' if it cannot be determined)."""
    from ultralytics.utils import PERSISTENT_CACHE  # avoid circular import error

    cache_key = "cpu_info"
    if cache_key not in PERSISTENT_CACHE:
        try:
            import cpuinfo  # pip install py-cpuinfo

            info = cpuinfo.get_cpu_info()  # info dict
            # Use the first preferred key that is present, otherwise fall back to the architecture string key.
            for key in ("brand_raw", "hardware_raw"):
                if key in info:
                    break
            else:
                key = "arch_string_raw"
            name = info.get(key, "unknown")
            for token in ("(R)", "CPU ", "@ "):
                name = name.replace(token, "")
            PERSISTENT_CACHE[cache_key] = name
        except Exception:
            pass  # best effort: the lookup below falls back to "unknown"
    return PERSISTENT_CACHE.get(cache_key, "unknown")





ultralytics.utils.torch_utils.get_gpu_info

get_gpu_info(index)

Return a string with system GPU information, e.g. 'Tesla T4, 15102MiB'.

Source code in ultralytics/utils/torch_utils.py
def get_gpu_info(index):
    """Return a string with system GPU information, e.g. 'Tesla T4, 15102MiB'."""
    props = torch.cuda.get_device_properties(index)
    mebibytes = props.total_memory / (1 << 20)  # bytes -> MiB
    return f"{props.name}, {mebibytes:.0f}MiB"





ultralytics.utils.torch_utils.select_device

select_device(device='', batch=0, newline=False, verbose=True)

Select the appropriate PyTorch device based on the provided arguments.

The function takes a string specifying the device or a torch.device object and returns a torch.device object representing the selected device. The function also validates the number of available devices and raises an exception if the requested device(s) are not available.

Parameters:

Name Type Description Default
device str | device

Device string or torch.device object. Options are 'None', 'cpu', or 'cuda', or '0' or '0,1,2,3'. Defaults to an empty string, which auto-selects the first available GPU, or CPU if no GPU is available.

''
batch int

Batch size being used in your model. Defaults to 0.

0
newline bool

If True, adds a newline at the end of the log string. Defaults to False.

False
verbose bool

If True, logs the device information. Defaults to True.

True

Returns:

Type Description
device

Selected device.

Raises:

Type Description
ValueError

If the specified device is not available or if the batch size is not a multiple of the number of devices when using multiple GPUs.

Examples:

>>> select_device("cuda:0")
device(type='cuda', index=0)
>>> select_device("cpu")
device(type='cpu')
Note

Sets the 'CUDA_VISIBLE_DEVICES' environment variable for specifying which GPUs to use.

Source code in ultralytics/utils/torch_utils.py
def select_device(device="", batch=0, newline=False, verbose=True):
    """
    Select the appropriate PyTorch device based on the provided arguments.

    The function takes a string specifying the device or a torch.device object and returns a torch.device object
    representing the selected device. The function also validates the number of available devices and raises an
    exception if the requested device(s) are not available.

    Args:
        device (str | torch.device, optional): Device string or torch.device object.
            Options are 'None', 'cpu', or 'cuda', or '0' or '0,1,2,3'. Defaults to an empty string, which auto-selects
            the first available GPU, or CPU if no GPU is available. A torch.device, or any value whose string form
            starts with 'tpu', is returned unchanged without logging. An 'mps' request falls back to CPU when MPS is
            unavailable or torch<2.0.
        batch (int, optional): Batch size being used in your model. Defaults to 0. Only validated for multi-GPU
            requests, where it must be >= 1 and divisible by the number of GPUs.
        newline (bool, optional): If True, adds a newline at the end of the log string. Defaults to False.
        verbose (bool, optional): If True, logs the device information. Defaults to True.

    Returns:
        (torch.device): Selected device. For CUDA this is always 'cuda:0'; the requested GPUs are exposed through
            CUDA_VISIBLE_DEVICES, which (per CUDA semantics) renumbers them starting from 0.

    Raises:
        ValueError: If the specified device is not available or if the batch size is not a multiple of the number of
            devices when using multiple GPUs.

    Examples:
        >>> select_device("cuda:0")
        device(type='cuda', index=0)

        >>> select_device("cpu")
        device(type='cpu')

    Note:
        Sets the 'CUDA_VISIBLE_DEVICES' environment variable for specifying which GPUs to use. For 'cpu' and 'mps'
        requests it is set to '-1' to hide all CUDA devices.
    """
    # Already-resolved devices and TPU identifiers are passed straight through.
    if isinstance(device, torch.device) or str(device).startswith("tpu"):
        return device

    s = f"Ultralytics {__version__} 🚀 Python-{PYTHON_VERSION} torch-{torch.__version__} "
    # Normalize the request string, e.g. 'cuda:0' -> '0', '(0, 1)' -> '0,1', 'None' -> ''.
    device = str(device).lower()
    for remove in "cuda:", "none", "(", ")", "[", "]", "'", " ":
        device = device.replace(remove, "")  # to string, 'cuda:0' -> '0' and '(0, 1)' -> '0,1'
    cpu = device == "cpu"
    mps = device in {"mps", "mps:0"}  # Apple Metal Performance Shaders (MPS)
    if cpu or mps:
        os.environ["CUDA_VISIBLE_DEVICES"] = "-1"  # force torch.cuda.is_available() = False
    elif device:  # non-cpu device requested
        if device == "cuda":
            device = "0"  # a bare 'cuda' means the first GPU
        if "," in device:
            device = ",".join([x for x in device.split(",") if x])  # remove sequential commas, i.e. "0,,1" -> "0,1"
        visible = os.environ.get("CUDA_VISIBLE_DEVICES", None)  # previous value, kept only for the error message
        os.environ["CUDA_VISIBLE_DEVICES"] = device  # set environment variable - must be before assert is_available()
        if not (torch.cuda.is_available() and torch.cuda.device_count() >= len(device.split(","))):
            LOGGER.info(s)
            install = (
                "See https://pytorch.org/get-started/locally/ for up-to-date torch install instructions if no "
                "CUDA devices are seen by torch.\n"
                if torch.cuda.device_count() == 0
                else ""
            )
            raise ValueError(
                f"Invalid CUDA 'device={device}' requested."
                f" Use 'device=cpu' or pass valid CUDA device(s) if available,"
                f" i.e. 'device=0' or 'device=0,1,2,3' for Multi-GPU.\n"
                f"\ntorch.cuda.is_available(): {torch.cuda.is_available()}"
                f"\ntorch.cuda.device_count(): {torch.cuda.device_count()}"
                f"\nos.environ['CUDA_VISIBLE_DEVICES']: {visible}\n"
                f"{install}"
            )

    if not cpu and not mps and torch.cuda.is_available():  # prefer GPU if available
        # NOTE: with no explicit device this is the string "0" (not a list); iterating it still yields the id "0".
        devices = device.split(",") if device else "0"  # i.e. "0,1" -> ["0", "1"]
        n = len(devices)  # device count
        if n > 1:  # multi-GPU
            if batch < 1:
                raise ValueError(
                    "AutoBatch with batch<1 not supported for Multi-GPU training, "
                    "please specify a valid batch size, i.e. batch=16."
                )
            # `batch >= 0` is always true here because batch < 1 was rejected just above.
            if batch >= 0 and batch % n != 0:  # check batch_size is divisible by device_count
                raise ValueError(
                    f"'batch={batch}' must be a multiple of GPU count {n}. Try 'batch={batch // n * n}' or "
                    f"'batch={batch // n * n + n}', the nearest batch sizes evenly divisible by {n}."
                )
        space = " " * (len(s) + 1)  # aligns extra GPU lines under the first one in the log
        for i, d in enumerate(devices):
            # Properties are queried by position i (relies on CUDA_VISIBLE_DEVICES renumbering from 0), labelled by id d
            s += f"{'' if i == 0 else space}CUDA:{d} ({get_gpu_info(i)})\n"  # bytes to MB
        arg = "cuda:0"
    elif mps and TORCH_2_0 and torch.backends.mps.is_available():
        # Prefer MPS if available
        s += f"MPS ({get_cpu_info()})\n"
        arg = "mps"
    else:  # revert to CPU
        s += f"CPU ({get_cpu_info()})\n"
        arg = "cpu"

    if arg in {"cpu", "mps"}:
        torch.set_num_threads(NUM_THREADS)  # reset OMP_NUM_THREADS for cpu training
    if verbose:
        LOGGER.info(s if newline else s.rstrip())
    return torch.device(arg)





ultralytics.utils.torch_utils.time_sync

time_sync()

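Return the current time (time.time()), first synchronizing CUDA when it is available so that timings of asynchronous GPU work are accurate.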

Source code in ultralytics/utils/torch_utils.py
def time_sync():
    """Return `time.time()` after waiting for pending CUDA work (when CUDA is available) so GPU timings are accurate."""
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.synchronize()  # block until queued GPU work has finished
    return time.time()





ultralytics.utils.torch_utils.fuse_conv_and_bn

fuse_conv_and_bn(conv, bn)

Fuse Conv2d() and BatchNorm2d() layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/.

Source code in ultralytics/utils/torch_utils.py
def fuse_conv_and_bn(conv, bn):
    """
    Fuse Conv2d() and BatchNorm2d() layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/.

    Builds a new bias-enabled Conv2d (gradients disabled, placed on the conv weight's device) that reproduces
    `bn(conv(x))` with `bn` using its running statistics.

    Args:
        conv (nn.Conv2d): Convolution layer to fuse.
        bn (nn.BatchNorm2d): Batch-norm layer applied after `conv`.

    Returns:
        (nn.Conv2d): The fused convolution.
    """
    layer_kwargs = dict(
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        groups=conv.groups,
        bias=True,
    )
    fused = nn.Conv2d(conv.in_channels, conv.out_channels, **layer_kwargs).requires_grad_(False)
    fused = fused.to(conv.weight.device)

    # Weights: scale each output filter by gamma / sqrt(running_var + eps)
    flat_filters = conv.weight.view(conv.out_channels, -1)
    bn_scale = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
    fused.weight.copy_(torch.mm(bn_scale, flat_filters).view(fused.weight.shape))

    # Bias: fold the conv bias (zeros if absent) together with the BN shift
    if conv.bias is None:
        conv_bias = torch.zeros(conv.weight.shape[0], device=conv.weight.device)
    else:
        conv_bias = conv.bias
    bn_shift = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
    fused.bias.copy_(torch.mm(bn_scale, conv_bias.reshape(-1, 1)).reshape(-1) + bn_shift)

    return fused





ultralytics.utils.torch_utils.fuse_deconv_and_bn

fuse_deconv_and_bn(deconv, bn)

Fuse ConvTranspose2d() and BatchNorm2d() layers.

Source code in ultralytics/utils/torch_utils.py
def fuse_deconv_and_bn(deconv, bn):
    """
    Fuse ConvTranspose2d() and BatchNorm2d() layers.

    ConvTranspose2d stores its weight as (in_channels, out_channels // groups, kH, kW), so the per-output-channel
    BatchNorm scale must be applied along dim 1 within each group. Viewing the weight as (out_channels, -1) and
    scaling rows (as Conv2d fusion does) would scale the wrong elements.

    Args:
        deconv (nn.ConvTranspose2d): Transposed convolution to fuse.
        bn (nn.BatchNorm2d): Batch-norm layer applied after `deconv`.

    Returns:
        (nn.ConvTranspose2d): A new bias-enabled transposed convolution (gradients disabled, on the deconv weight's
            device) equivalent to `bn(deconv(x))` with `bn` using its running statistics.
    """
    fuseddconv = (
        nn.ConvTranspose2d(
            deconv.in_channels,
            deconv.out_channels,
            kernel_size=deconv.kernel_size,
            stride=deconv.stride,
            padding=deconv.padding,
            output_padding=deconv.output_padding,
            dilation=deconv.dilation,
            groups=deconv.groups,
            bias=True,
        )
        .requires_grad_(False)
        .to(deconv.weight.device)
    )

    # Per-output-channel BN scale: gamma / sqrt(running_var + eps), length out_channels
    scale = bn.weight.div(torch.sqrt(bn.eps + bn.running_var))

    # Prepare filters: view weight as (groups, in/groups, out/groups, kH, kW) and scale along the out-channel axis,
    # where output channel index = group * (out/groups) + j
    g = deconv.groups
    w = deconv.weight
    w_grouped = w.reshape(g, w.shape[0] // g, w.shape[1], *w.shape[2:])
    scale_grouped = scale.reshape(g, 1, -1, *([1] * (w.dim() - 2)))
    fuseddconv.weight.copy_((w_grouped * scale_grouped).reshape(w.shape))

    # Prepare spatial bias (one entry per output channel, zeros if the deconv has no bias)
    if deconv.bias is None:
        b_deconv = torch.zeros(deconv.out_channels, device=deconv.weight.device, dtype=deconv.weight.dtype)
    else:
        b_deconv = deconv.bias
    b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
    fuseddconv.bias.copy_(b_deconv * scale + b_bn)

    return fuseddconv





ultralytics.utils.torch_utils.model_info

model_info(model, detailed=False, verbose=True, imgsz=640)

Print and return detailed model information layer by layer.

Parameters:

Name Type Description Default
model Module

Model to analyze.

required
detailed bool

Whether to print detailed layer information. Defaults to False.

False
verbose bool

Whether to print model information. Defaults to True.

True
imgsz int | List

Input image size. Defaults to 640.

640

Returns:

Type Description
Tuple[int, int, int, float]

Number of layers, parameters, gradients, and GFLOPs.

Source code in ultralytics/utils/torch_utils.py
def model_info(model, detailed=False, verbose=True, imgsz=640):
    """
    Print and return detailed model information layer by layer.

    Args:
        model (nn.Module): Model to analyze.
        detailed (bool, optional): Whether to print detailed layer information. Defaults to False.
        verbose (bool, optional): Whether to print model information. Defaults to True. When False the function
            returns None immediately without computing anything (this skips the FLOPs profiling, which deep-copies
            and runs the model).
        imgsz (int | List, optional): Input image size. Defaults to 640.

    Returns:
        (Tuple[int, int, int, float] | None): Number of layers, parameters, gradients, and GFLOPs, or None when
            `verbose` is False.
    """
    if not verbose:
        return None
    n_p = get_num_params(model)  # number of parameters
    n_g = get_num_gradients(model)  # number of gradients
    # Leaf modules only (no children); dicts preserve insertion order, so layers keep named_modules() order
    layers = {n: m for n, m in model.named_modules() if len(m._modules) == 0}
    n_l = len(layers)  # number of layers
    if detailed:
        # Header now includes the dtype column that every row below prints
        h = (
            f"{'layer':>5}{'name':>40}{'type':>20}{'gradient':>10}{'parameters':>12}{'shape':>20}"
            f"{'mu':>10}{'sigma':>10}{'dtype':>15}"
        )
        LOGGER.info(h)
        for i, (mn, m) in enumerate(layers.items()):
            mn = mn.replace("module_list.", "")
            mt = m.__class__.__name__
            if len(m._parameters):
                for pn, p in m.named_parameters():
                    LOGGER.info(
                        f"{i:>5g}{f'{mn}.{pn}':>40}{mt:>20}{p.requires_grad!r:>10}{p.numel():>12g}{str(list(p.shape)):>20}{p.mean():>10.3g}{p.std():>10.3g}{str(p.dtype).replace('torch.', ''):>15}"
                    )
            else:  # layers with no learnable params
                LOGGER.info(f"{i:>5g}{mn:>40}{mt:>20}{False!r:>10}{0:>12g}{str([]):>20}{'-':>10}{'-':>10}{'-':>15}")

    flops = get_flops(model, imgsz)  # imgsz may be int or list, i.e. imgsz=640 or imgsz=[640, 320]
    fused = " (fused)" if getattr(model, "is_fused", lambda: False)() else ""
    fs = f", {flops:.1f} GFLOPs" if flops else ""
    yaml_file = getattr(model, "yaml_file", "") or getattr(model, "yaml", {}).get("yaml_file", "")
    model_name = Path(yaml_file).stem.replace("yolo", "YOLO") or "Model"
    LOGGER.info(f"{model_name} summary{fused}: {n_l:,} layers, {n_p:,} parameters, {n_g:,} gradients{fs}")
    return n_l, n_p, n_g, flops





ultralytics.utils.torch_utils.get_num_params

get_num_params(model)

Return the total number of parameters in a YOLO model.

Source code in ultralytics/utils/torch_utils.py
def get_num_params(model):
    """Return the total number of parameters in a YOLO model (works for any nn.Module)."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total





ultralytics.utils.torch_utils.get_num_gradients

get_num_gradients(model)

Return the total number of parameters with gradients in a YOLO model.

Source code in ultralytics/utils/torch_utils.py
def get_num_gradients(model):
    """Return the total number of trainable (requires_grad) parameters in a YOLO model."""
    trainable = filter(lambda param: param.requires_grad, model.parameters())
    return sum(param.numel() for param in trainable)





ultralytics.utils.torch_utils.model_info_for_loggers

model_info_for_loggers(trainer)

Return model info dict with useful model information.

Parameters:

Name Type Description Default
trainer BaseTrainer

The trainer object containing model and validation data.

required

Returns:

Type Description
dict

Dictionary containing model parameters, GFLOPs, and inference speeds.

Examples:

YOLOv8n info for loggers

>>> results = {
...    "model/parameters": 3151904,
...    "model/GFLOPs": 8.746,
...    "model/speed_ONNX(ms)": 41.244,
...    "model/speed_TensorRT(ms)": 3.211,
...    "model/speed_PyTorch(ms)": 18.755,
...}
Source code in ultralytics/utils/torch_utils.py
def model_info_for_loggers(trainer):
    """
    Return model info dict with useful model information.

    With `trainer.args.profile` set, ONNX/TensorRT timings come from `ProfileModels` run on `trainer.last`;
    otherwise only the parameter count and GFLOPs are reported. In both cases the PyTorch inference speed of the most
    recent validation run is added.

    Args:
        trainer (ultralytics.engine.trainer.BaseTrainer): The trainer object containing model and validation data.

    Returns:
        (dict): Dictionary containing model parameters, GFLOPs, and inference speeds.

    Examples:
        YOLOv8n info for loggers
        >>> results = {
        ...    "model/parameters": 3151904,
        ...    "model/GFLOPs": 8.746,
        ...    "model/speed_ONNX(ms)": 41.244,
        ...    "model/speed_TensorRT(ms)": 3.211,
        ...    "model/speed_PyTorch(ms)": 18.755,
        ...}
    """
    if not trainer.args.profile:  # only return PyTorch times from most recent validation
        info = {
            "model/parameters": get_num_params(trainer.model),
            "model/GFLOPs": round(get_flops(trainer.model), 3),
        }
    else:  # profile ONNX and TensorRT times
        from ultralytics.utils.benchmarks import ProfileModels

        info = ProfileModels([trainer.last], device=trainer.device).profile()[0]
        info.pop("model/name")
    info["model/speed_PyTorch(ms)"] = round(trainer.validator.speed["inference"], 3)
    return info





ultralytics.utils.torch_utils.get_flops

get_flops(model, imgsz=640)

Return a YOLO model's FLOPs.

Parameters:

Name Type Description Default
model Module

The model to calculate FLOPs for.

required
imgsz int | List[int]

Input image size. Defaults to 640.

640

Returns:

Type Description
float

The model's FLOPs in billions.

Source code in ultralytics/utils/torch_utils.py
def get_flops(model, imgsz=640):
    """
    Return a YOLO model's FLOPs.

    FLOPs are first measured on a small square input (the model's max stride, at least 32) and scaled up to `imgsz`.
    If that fails (e.g. RTDETR models), the model is profiled directly at `imgsz`. Any other failure yields 0.0.

    Args:
        model (nn.Module): The model to calculate FLOPs for.
        imgsz (int | List[int] | Tuple[int, ...], optional): Input image size, either a single int or (h, w).
            A one-element sequence is treated as square. Defaults to 640.

    Returns:
        (float): The model's FLOPs in billions (0.0 if thop is not installed or profiling fails).
    """
    if not thop:
        return 0.0  # if not installed return 0.0 GFLOPs

    try:
        model = de_parallel(model)
        p = next(model.parameters())
        # Normalize imgsz to [h, w]; previously a tuple was wrapped whole and silently produced 0.0 GFLOPs
        if isinstance(imgsz, (list, tuple)):
            imgsz = list(imgsz) * 2 if len(imgsz) == 1 else list(imgsz)
        else:
            imgsz = [imgsz, imgsz]  # expand if int/float
        try:
            # Use stride size for input tensor
            stride = max(int(model.stride.max()), 32) if hasattr(model, "stride") else 32  # max stride
            im = torch.empty((1, p.shape[1], stride, stride), device=p.device)  # input image in BCHW format
            # thop reports multiply-accumulates; x2 converts them to FLOPs
            flops = thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1e9 * 2  # stride GFLOPs
            return flops * imgsz[0] / stride * imgsz[1] / stride  # imgsz GFLOPs
        except Exception:
            # Use actual image size for input tensor (i.e. required for RTDETR models)
            im = torch.empty((1, p.shape[1], *imgsz), device=p.device)  # input image in BCHW format
            return thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1e9 * 2  # imgsz GFLOPs
    except Exception:
        return 0.0  # best effort: FLOPs are informational only





ultralytics.utils.torch_utils.get_flops_with_torch_profiler

get_flops_with_torch_profiler(model, imgsz=640)

Compute model FLOPs using torch profiler (alternative to thop package, but 2-10x slower).

Parameters:

Name Type Description Default
model Module

The model to calculate FLOPs for.

required
imgsz int | List[int]

Input image size. Defaults to 640.

640

Returns:

Type Description
float

The model's FLOPs in billions.

Source code in ultralytics/utils/torch_utils.py
def get_flops_with_torch_profiler(model, imgsz=640):
    """
    Compute model FLOPs using torch profiler (alternative to thop package, but 2-10x slower).

    FLOPs are first measured on a square input of twice the model's max stride (at least 32) and scaled up to
    `imgsz`; if that fails (e.g. RTDETR models), the model is profiled directly at `imgsz`.

    Args:
        model (nn.Module): The model to calculate FLOPs for.
        imgsz (int | List[int] | Tuple[int, ...], optional): Input image size, either a single int or (h, w).
            A one-element sequence is treated as square. Defaults to 640.

    Returns:
        (float): The model's FLOPs in billions (0.0 on torch<2.0).
    """
    if not TORCH_2_0:  # torch profiler implemented in torch>=2.0
        return 0.0
    model = de_parallel(model)
    p = next(model.parameters())
    # Normalize imgsz to [h, w]; a tuple was previously wrapped whole, making both profiling paths fail
    if isinstance(imgsz, (list, tuple)):
        imgsz = list(imgsz) * 2 if len(imgsz) == 1 else list(imgsz)
    else:
        imgsz = [imgsz, imgsz]  # expand if int/float
    try:
        # Use stride size for input tensor
        stride = (max(int(model.stride.max()), 32) if hasattr(model, "stride") else 32) * 2  # max stride
        im = torch.empty((1, p.shape[1], stride, stride), device=p.device)  # input image in BCHW format
        with torch.profiler.profile(with_flops=True) as prof:
            model(im)
        flops = sum(x.flops for x in prof.key_averages()) / 1e9
        flops = flops * imgsz[0] / stride * imgsz[1] / stride  # scale to imgsz GFLOPs
    except Exception:
        # Use actual image size for input tensor (i.e. required for RTDETR models)
        im = torch.empty((1, p.shape[1], *imgsz), device=p.device)  # input image in BCHW format
        with torch.profiler.profile(with_flops=True) as prof:
            model(im)
        flops = sum(x.flops for x in prof.key_averages()) / 1e9
    return flops





ultralytics.utils.torch_utils.initialize_weights

initialize_weights(model)

Initialize model weights to random values.

Source code in ultralytics/utils/torch_utils.py
def initialize_weights(model):
    """
    Apply default layer settings to every submodule of `model`, in place.

    Despite the name, no weights are re-drawn. nn.Conv2d layers are left unchanged. nn.BatchNorm2d layers get
    eps=1e-3 and momentum=0.03. Hardswish/LeakyReLU/ReLU/ReLU6/SiLU activations are switched to in-place mode.
    Matching uses the exact type, so subclasses of these layers are not modified.

    Args:
        model (nn.Module): Model whose submodules are updated.
    """
    inplace_activations = {nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU}
    for module in model.modules():
        layer_type = type(module)
        if layer_type is nn.BatchNorm2d:
            module.eps = 1e-3
            module.momentum = 0.03
        elif layer_type in inplace_activations:
            module.inplace = True
        # nn.Conv2d (and every other layer type) is intentionally left untouched





ultralytics.utils.torch_utils.scale_img

scale_img(img, ratio=1.0, same_shape=False, gs=32)

Scales an image tensor by `ratio` (preserving aspect ratio), then pads or crops it either back to its original shape or up to the next multiple of gs.

Parameters:

Name Type Description Default
img Tensor

Input image tensor.

required
ratio float

Scaling ratio. Defaults to 1.0.

1.0
same_shape bool

Whether to pad/crop the result back to the input's original height and width instead of padding to a gs multiple. Defaults to False.

False
gs int

Grid size for padding. Defaults to 32.

32

Returns:

Type Description
Tensor

Scaled and padded image tensor.

Source code in ultralytics/utils/torch_utils.py
def scale_img(img, ratio=1.0, same_shape=False, gs=32):
    """
    Resize an image tensor by `ratio`, then pad (or crop) it to a target shape.

    Height and width are both multiplied by `ratio`, so the aspect ratio is always (approximately) preserved. The
    target shape is either the original input height/width (`same_shape=True`) or the scaled size rounded up to the
    next multiple of `gs` (`same_shape=False`). Padding is added on the bottom/right with value 0.447. A negative pad
    amount (only possible with `same_shape=True` and `ratio > 1`) crops instead.

    Args:
        img (torch.Tensor): Input image tensor in (B, C, H, W) layout.
        ratio (float, optional): Scaling ratio. Defaults to 1.0.
        same_shape (bool, optional): If True, pad/crop the resized image back to the input height and width.
            If False, pad it up to a multiple of `gs`. Defaults to False.
        gs (int, optional): Grid size that output height/width are padded to when `same_shape=False`. Defaults to 32.

    Returns:
        (torch.Tensor): Scaled and padded (or cropped) image tensor. `img` itself is returned if `ratio == 1.0`.
    """
    if ratio == 1.0:
        return img
    h, w = img.shape[2:]
    s = (int(h * ratio), int(w * ratio))  # new size
    img = F.interpolate(img, size=s, mode="bilinear", align_corners=False)  # resize
    if not same_shape:  # target = scaled size rounded up to a gs multiple (else keep original h, w)
        h, w = (math.ceil(x * ratio / gs) * gs for x in (h, w))
    # Pad right/bottom up to (h, w); negative amounts crop instead
    return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447)  # value = imagenet mean





ultralytics.utils.torch_utils.copy_attr

copy_attr(a, b, include=(), exclude=())

Copies attributes from object 'b' to object 'a', with options to include/exclude certain attributes.

Parameters:

Name Type Description Default
a object

Destination object to copy attributes to.

required
b object

Source object to copy attributes from.

required
include tuple

Attributes to include. If empty, all attributes are included. Defaults to ().

()
exclude tuple

Attributes to exclude. Defaults to ().

()
Source code in ultralytics/utils/torch_utils.py
def copy_attr(a, b, include=(), exclude=()):
    """
    Copy public instance attributes from object 'b' onto object 'a'.

    Names starting with an underscore are never copied. If `include` is non-empty, only the listed names are
    considered. Any name listed in `exclude` is skipped.

    Args:
        a (object): Destination object to copy attributes to.
        b (object): Source object whose `__dict__` entries are copied.
        include (tuple, optional): Attribute names to copy; empty means all public attributes. Defaults to ().
        exclude (tuple, optional): Attribute names to skip. Defaults to ().
    """
    for name, value in b.__dict__.items():
        not_selected = len(include) > 0 and name not in include
        is_private = name.startswith("_")
        if not_selected or is_private or name in exclude:
            continue
        setattr(a, name, value)





ultralytics.utils.torch_utils.get_latest_opset

get_latest_opset()

Return the second-most recent ONNX opset version supported by this version of PyTorch, adjusted for maturity.

Returns:

Type Description
int

The ONNX opset version.

Source code in ultralytics/utils/torch_utils.py
def get_latest_opset():
    """
    Return the second-most recent ONNX opset version supported by this version of PyTorch, adjusted for maturity.

    For torch>=1.13, the highest N among `torch.onnx.symbolic_opsetN` attributes is found and N - 1 is returned. For
    older torch, or if no such attributes are found, a predefined table keyed on the torch "major.minor" version is
    used. Unknown versions fall back to opset 12.

    Returns:
        (int): The ONNX opset version.
    """
    if TORCH_1_13:
        # Collect N from every 'symbolic_opsetN' attribute of torch.onnx, ignoring names with a non-numeric suffix
        prefix = "symbolic_opset"
        opsets = [
            int(k[len(prefix) :]) for k in vars(torch.onnx) if k.startswith(prefix) and k[len(prefix) :].isdigit()
        ]
        if opsets:
            return max(opsets) - 1  # latest minus one
    # Otherwise (PyTorch<=1.12, or no opset modules discovered) return the corresponding predefined opset
    version = torch.onnx.producer_version.rsplit(".", 1)[0]  # e.g. '1.12'
    return {"1.12": 15, "1.11": 14, "1.10": 13, "1.9": 12, "1.8": 12}.get(version, 12)





ultralytics.utils.torch_utils.intersect_dicts

intersect_dicts(da, db, exclude=())

Returns a dictionary of intersecting keys with matching shapes, excluding 'exclude' keys, using da values.

Parameters:

Name Type Description Default
da dict

First dictionary.

required
db dict

Second dictionary.

required
exclude tuple

Keys to exclude. Defaults to ().

()

Returns:

Type Description
dict

Dictionary of intersecting keys with matching shapes.

Source code in ultralytics/utils/torch_utils.py
def intersect_dicts(da, db, exclude=()):
    """
    Keep the entries of `da` whose key also exists in `db` with a value of the same shape.

    Keys containing any substring listed in `exclude` are dropped. Values are taken from `da` and insertion order
    follows `da`.

    Args:
        da (dict): First dictionary; values in the result come from here.
        db (dict): Second dictionary; only its keys and value shapes are consulted.
        exclude (tuple, optional): Substrings; a key containing any of them is excluded. Defaults to ().

    Returns:
        (dict): Dictionary of intersecting keys with matching shapes.
    """
    matched = {}
    for key, value in da.items():
        if key not in db:
            continue
        if any(token in key for token in exclude):
            continue
        if value.shape != db[key].shape:
            continue
        matched[key] = value
    return matched





ultralytics.utils.torch_utils.is_parallel

is_parallel(model)

Returns True if model is of type DP or DDP.

Parameters:

Name Type Description Default
model Module

Model to check.

required

Returns:

Type Description
bool

True if model is DataParallel or DistributedDataParallel.

Source code in ultralytics/utils/torch_utils.py
def is_parallel(model):
    """
    Check whether `model` is wrapped in DataParallel (DP) or DistributedDataParallel (DDP).

    Args:
        model (nn.Module): Model to check.

    Returns:
        (bool): True if model is an instance of nn.parallel.DataParallel or nn.parallel.DistributedDataParallel.
    """
    wrapper_types = (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
    return isinstance(model, wrapper_types)





ultralytics.utils.torch_utils.de_parallel

de_parallel(model)

De-parallelize a model: returns single-GPU model if model is of type DP or DDP.

Parameters:

Name Type Description Default
model Module

Model to de-parallelize.

required

Returns:

Type Description
Module

De-parallelized model.

Source code in ultralytics/utils/torch_utils.py
def de_parallel(model):
    """
    Unwrap a DP/DDP-wrapped model, returning the module it holds; any other model is returned as-is.

    Args:
        model (nn.Module): Model to de-parallelize.

    Returns:
        (nn.Module): `model.module` if `model` is a DP/DDP wrapper, otherwise `model` itself.
    """
    if is_parallel(model):
        return model.module  # the underlying module held by the wrapper
    return model





ultralytics.utils.torch_utils.one_cycle

one_cycle(y1=0.0, y2=1.0, steps=100)

Returns a lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf.

Parameters:

Name Type Description Default
y1 float

Initial value. Defaults to 0.0.

0.0
y2 float

Final value. Defaults to 1.0.

1.0
steps int

Number of steps. Defaults to 100.

100

Returns:

Type Description
function

Lambda function for computing the sinusoidal ramp.

Source code in ultralytics/utils/torch_utils.py
def one_cycle(y1=0.0, y2=1.0, steps=100):
    """
    Build a half-cosine ramp from y1 (at x=0) to y2 (at x=steps), see https://arxiv.org/pdf/1812.01187.pdf.

    Args:
        y1 (float, optional): Initial value. Defaults to 0.0.
        y2 (float, optional): Final value. Defaults to 1.0.
        steps (int, optional): Number of steps over which the ramp goes from y1 to y2. Defaults to 100.

    Returns:
        (function): Lambda mapping a step x to y1 + (y2 - y1) * (1 - cos(pi * x / steps)) / 2.
    """
    delta = y2 - y1  # total change across the ramp
    return lambda x: max((1 - math.cos(x * math.pi / steps)) / 2, 0) * delta + y1





ultralytics.utils.torch_utils.init_seeds

init_seeds(seed=0, deterministic=False)

Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html.

Parameters:

Name Type Description Default
seed int

Random seed. Defaults to 0.

0
deterministic bool

Whether to set deterministic algorithms. Defaults to False.

False
Source code in ultralytics/utils/torch_utils.py
def init_seeds(seed=0, deterministic=False):
    """
    Seed the Python, NumPy and PyTorch RNGs and configure deterministic mode.

    See https://pytorch.org/docs/stable/notes/randomness.html. When `deterministic` is False, any previous
    deterministic configuration is undone via `unset_deterministic()`. When it is True and torch>=2.0, deterministic
    algorithms (warn-only) and cuDNN determinism are enabled, and CUBLAS_WORKSPACE_CONFIG / PYTHONHASHSEED are set.
    On older torch a warning is logged instead.

    Args:
        seed (int, optional): Random seed. Defaults to 0.
        deterministic (bool, optional): Whether to set deterministic algorithms. Defaults to False.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # for Multi-GPU, exception safe
    # cudnn.benchmark deliberately not enabled: AutoBatch problem https://github.com/ultralytics/yolov5/issues/9287

    if not deterministic:
        unset_deterministic()
        return
    if not TORCH_2_0:
        LOGGER.warning("WARNING ⚠️ Upgrade to torch>=2.0.0 for deterministic training.")
        return
    torch.use_deterministic_algorithms(True, warn_only=True)  # warn if deterministic is not possible
    torch.backends.cudnn.deterministic = True
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    os.environ["PYTHONHASHSEED"] = str(seed)





ultralytics.utils.torch_utils.unset_deterministic

unset_deterministic()

Unsets all the configurations applied for deterministic training.

Source code in ultralytics/utils/torch_utils.py
def unset_deterministic():
    """
    Undo the deterministic-training configuration.

    Disables deterministic algorithms and cuDNN determinism, and removes the CUBLAS_WORKSPACE_CONFIG and
    PYTHONHASHSEED environment variables if present. Safe to call when nothing was set.
    """
    torch.use_deterministic_algorithms(False)
    torch.backends.cudnn.deterministic = False
    for env_var in ("CUBLAS_WORKSPACE_CONFIG", "PYTHONHASHSEED"):
        os.environ.pop(env_var, None)





ultralytics.utils.torch_utils.strip_optimizer

strip_optimizer(
    f: Union[str, Path] = "best.pt", s: str = "", updates: dict = None
) -> dict

Strip optimizer from 'f' to finalize training, optionally save as 's'.

Parameters:

Name Type Description Default
f str | Path

File path to model to strip the optimizer from. Defaults to 'best.pt'.

'best.pt'
s str

File path to save the model with stripped optimizer to. If not provided, 'f' will be overwritten.

''
updates dict

A dictionary of updates to overlay onto the checkpoint before saving.

None

Returns:

Type Description
dict

The combined checkpoint dictionary.

Examples:

>>> from pathlib import Path
>>> from ultralytics.utils.torch_utils import strip_optimizer
>>> for f in Path("path/to/model/checkpoints").rglob("*.pt"):
...     strip_optimizer(f)
Source code in ultralytics/utils/torch_utils.py
def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "", updates: dict = None) -> dict:
    """
    Strip optimizer from 'f' to finalize training, optionally save as 's'.

    The checkpoint's model is replaced by its EMA copy when present, its args are converted to a plain dict, its loss
    criterion is dropped, and it is cast to FP16 with gradients disabled. The 'optimizer', 'best_fitness', 'ema' and
    'updates' keys are set to None, 'epoch' to -1, and 'train_args' is reduced to default-config keys. Metadata
    (date, version, license, docs) and `updates` are merged in before saving.

    Args:
        f (str | Path): File path to model to strip the optimizer from. Defaults to 'best.pt'.
        s (str, optional): File path to save the model with stripped optimizer to. If not provided, 'f' will be overwritten.
        updates (dict, optional): A dictionary of updates to overlay onto the checkpoint before saving.

    Returns:
        (dict): The combined checkpoint dictionary, or an empty dict if 'f' is not a valid Ultralytics checkpoint.

    Examples:
        >>> from pathlib import Path
        >>> from ultralytics.utils.torch_utils import strip_optimizer
        >>> for f in Path("path/to/model/checkpoints").rglob("*.pt"):
        ...     strip_optimizer(f)
    """
    try:
        # torch>=2.6 defaults to weights_only=True, which rejects the pickled nn.Module stored under 'model'.
        # SECURITY NOTE: full unpickling can execute arbitrary code - only strip trusted checkpoints.
        load_kwargs = {"weights_only": False} if TORCH_1_13 else {}  # argument exists since torch 1.13
        x = torch.load(f, map_location=torch.device("cpu"), **load_kwargs)
        # Explicit checks instead of `assert`, which is stripped under `python -O`
        if not isinstance(x, dict):
            raise ValueError("checkpoint is not a Python dictionary")
        if "model" not in x:
            raise ValueError("'model' missing from checkpoint")
    except Exception as e:
        LOGGER.warning(f"WARNING ⚠️ Skipping {f}, not a valid Ultralytics model: {e}")
        return {}

    metadata = {
        "date": datetime.now().isoformat(),
        "version": __version__,
        "license": "AGPL-3.0 License (https://ultralytics.com/license)",
        "docs": "https://docs.ultralytics.com",
    }

    # Update model
    if x.get("ema"):
        x["model"] = x["ema"]  # replace model with EMA
    if hasattr(x["model"], "args"):
        x["model"].args = dict(x["model"].args)  # convert from IterableSimpleNamespace to dict
    if hasattr(x["model"], "criterion"):
        x["model"].criterion = None  # strip loss criterion
    x["model"].half()  # to FP16
    for p in x["model"].parameters():
        p.requires_grad = False

    # Update other keys
    args = {**DEFAULT_CFG_DICT, **(x.get("train_args") or {})}  # combine args; tolerate a missing or None entry
    for k in "optimizer", "best_fitness", "ema", "updates":  # keys
        x[k] = None
    x["epoch"] = -1
    x["train_args"] = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS}  # strip non-default keys

    # Save
    combined = {**metadata, **x, **(updates or {})}  # combine dicts (prefer to the right)
    torch.save(combined, s or f)
    mb = os.path.getsize(s or f) / 1e6  # file size
    LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
    return combined





ultralytics.utils.torch_utils.convert_optimizer_state_dict_to_fp16

convert_optimizer_state_dict_to_fp16(state_dict)

Converts the state_dict of a given optimizer to FP16, focusing on the 'state' key for tensor conversions.

Parameters:

Name Type Description Default
state_dict dict

Optimizer state dictionary.

required

Returns:

Type Description
dict

Converted optimizer state dictionary with FP16 tensors.

Source code in ultralytics/utils/torch_utils.py
def convert_optimizer_state_dict_to_fp16(state_dict):
    """
    Cast the FP32 tensors in an optimizer state_dict's per-parameter 'state' entries to FP16, in place.

    'step' entries, non-tensor values and tensors of any other dtype are left as they are. Other top-level keys
    (e.g. 'param_groups') are not touched.

    Args:
        state_dict (dict): Optimizer state dictionary.

    Returns:
        (dict): The same `state_dict` object, now holding FP16 tensors where conversion applied.
    """
    for param_state in state_dict["state"].values():
        fp32_keys = [
            key
            for key, tensor in param_state.items()
            if key != "step" and isinstance(tensor, torch.Tensor) and tensor.dtype is torch.float32
        ]
        for key in fp32_keys:
            param_state[key] = param_state[key].half()

    return state_dict





ultralytics.utils.torch_utils.cuda_memory_usage

cuda_memory_usage(device=None)

Monitor and manage CUDA memory usage.

This function checks if CUDA is available and, if so, empties the CUDA cache to free up unused memory. It then yields a dictionary containing memory usage information, which can be updated by the caller. Finally, it updates the dictionary with the amount of memory reserved by CUDA on the specified device.

Parameters:

Name Type Description Default
device device

The CUDA device to query memory usage for. Defaults to None.

None

Yields:

Type Description
dict

A dictionary with a key 'memory' initialized to 0, which will be updated with the reserved memory.

Source code in ultralytics/utils/torch_utils.py
@contextmanager
def cuda_memory_usage(device=None):
    """
    Context manager that reports CUDA memory reserved after the managed block.

    Yields a dict `{"memory": 0}`. If CUDA is available, the cache is emptied on entry, and on exit (even on error)
    "memory" is set to `torch.cuda.memory_reserved(device)` in bytes. Without CUDA the dict is yielded unchanged.

    Args:
        device (torch.device, optional): The CUDA device to query memory usage for. Defaults to None.

    Yields:
        (dict): A dictionary with a key 'memory' initialized to 0, which will be updated with the reserved memory.
    """
    info = {"memory": 0}
    if not torch.cuda.is_available():
        yield info
        return
    torch.cuda.empty_cache()
    try:
        yield info
    finally:
        info["memory"] = torch.cuda.memory_reserved(device)





ultralytics.utils.torch_utils.profile

profile(input, ops, n=10, device=None, max_num_obj=0)

Ultralytics speed, memory and FLOPs profiler.

Parameters:

Name Type Description Default
input Tensor | List[Tensor]

Input tensor(s) to profile.

required
ops Module | List[Module]

Model or list of operations to profile.

required
n int

Number of iterations to average. Defaults to 10.

10
device str | device

Device to profile on. Defaults to None.

None
max_num_obj int

Maximum number of objects for simulation. Defaults to 0.

0

Returns:

Type Description
list

Profile results for each operation.

Examples:

>>> from ultralytics.utils.torch_utils import profile
>>> input = torch.randn(16, 3, 640, 640)
>>> m1 = lambda x: x * torch.sigmoid(x)
>>> m2 = nn.SiLU()
>>> profile(input, [m1, m2], n=100)  # profile over 100 iterations
Source code in ultralytics/utils/torch_utils.py
def profile(input, ops, n=10, device=None, max_num_obj=0):
    """
    Ultralytics speed, memory and FLOPs profiler.

    For every (input, op) pair, the op is moved to `device` and cast to FP16 if the input is FP16. Its GFLOPs are
    estimated with thop (0 if thop is unavailable or fails). The op is then run forward and backward `n` times.
    Forward/backward times and reserved CUDA memory are averaged over the `n` iterations and logged as one table row.

    Args:
        input (torch.Tensor | List[torch.Tensor]): Input tensor(s) to profile.
        ops (nn.Module | List[nn.Module]): Model or list of operations to profile.
        n (int, optional): Number of iterations to average. Defaults to 10.
        device (str | torch.device, optional): Device to profile on. Defaults to None.
        max_num_obj (int, optional): Maximum number of objects for simulation. Defaults to 0.

    Returns:
        (list): Profile results for each (input, op) pair. Each entry is
            [params, GFLOPs, GPU_mem (GB), forward (ms), backward (ms), input shape, output shape], or None on failure.

    Examples:
        >>> from ultralytics.utils.torch_utils import profile
        >>> input = torch.randn(16, 3, 640, 640)
        >>> m1 = lambda x: x * torch.sigmoid(x)
        >>> m2 = nn.SiLU()
        >>> profile(input, [m1, m2], n=100)  # profile over 100 iterations
    """
    results = []
    if not isinstance(device, torch.device):
        device = select_device(device)
    LOGGER.info(
        f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
        f"{'input':>24s}{'output':>24s}"
    )
    gc.collect()  # attempt to free unused memory
    torch.cuda.empty_cache()
    for x in input if isinstance(input, list) else [input]:
        x = x.to(device)
        x.requires_grad = True
        for m in ops if isinstance(ops, list) else [ops]:
            m = m.to(device) if hasattr(m, "to") else m  # device
            m = m.half() if hasattr(m, "half") and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m
            tf, tb, t = 0, 0, [0, 0, 0]  # dt forward, backward
            try:
                flops = thop.profile(deepcopy(m), inputs=[x], verbose=False)[0] / 1e9 * 2 if thop else 0  # GFLOPs
            except Exception:
                flops = 0

            try:
                mem = 0
                for _ in range(n):
                    with cuda_memory_usage(device) as cuda_info:
                        t[0] = time_sync()
                        y = m(x)
                        t[1] = time_sync()
                        try:
                            (sum(yi.sum() for yi in y) if isinstance(y, list) else y).sum().backward()
                            t[2] = time_sync()
                        except Exception:  # no backward method
                            t[2] = float("nan")
                    mem += cuda_info["memory"] / 1e9  # (GB)
                    tf += (t[1] - t[0]) * 1000 / n  # ms per op forward
                    tb += (t[2] - t[1]) * 1000 / n  # ms per op backward
                    if max_num_obj:  # simulate training with predictions per image grid (for AutoBatch)
                        with cuda_memory_usage(device) as cuda_info:
                            torch.randn(
                                x.shape[0],
                                max_num_obj,
                                int(sum((x.shape[-1] / s) * (x.shape[-2] / s) for s in m.stride.tolist())),
                                device=device,
                                dtype=torch.float32,
                            )
                        mem += cuda_info["memory"] / 1e9  # (GB)
                mem /= n  # report per-iteration memory like tf/tb, not the sum over n runs (unchanged for n=1)
                s_in, s_out = (tuple(x.shape) if isinstance(x, torch.Tensor) else "list" for x in (x, y))  # shapes
                p = sum(x.numel() for x in m.parameters()) if isinstance(m, nn.Module) else 0  # parameters
                LOGGER.info(f"{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}")
                results.append([p, flops, mem, tf, tb, s_in, s_out])
            except Exception as e:
                LOGGER.info(e)
                results.append(None)
            finally:
                gc.collect()  # attempt to free unused memory
                torch.cuda.empty_cache()
    return results



📅 Created 1 year ago ✏️ Updated 1 month ago