Skip to content

Reference for ultralytics/engine/trainer.py

Note

This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/engine/trainer.py. If you spot a problem please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!


ultralytics.engine.trainer.BaseTrainer

BaseTrainer(cfg=DEFAULT_CFG, overrides=None, _callbacks=None)

BaseTrainer.

A base class for creating trainers.

Attributes:

Name Type Description
args SimpleNamespace

Configuration for the trainer.

validator BaseValidator

Validator instance.

model Module

Model instance.

callbacks defaultdict

Dictionary of callbacks.

save_dir Path

Directory to save results.

wdir Path

Directory to save weights.

last Path

Path to the last checkpoint.

best Path

Path to the best checkpoint.

save_period int

Save checkpoint every x epochs (disabled if < 1).

batch_size int

Batch size for training.

epochs int

Number of epochs to train for.

start_epoch int

Starting epoch for training.

device device

Device to use for training.

amp bool

Flag to enable AMP (Automatic Mixed Precision).

scaler GradScaler

Gradient scaler for AMP.

data str

Path to data.

trainset Dataset

Training dataset.

testset Dataset

Testing dataset.

ema Module

EMA (Exponential Moving Average) of the model.

resume bool

Resume training from a checkpoint.

lf Module

Loss function.

scheduler _LRScheduler

Learning rate scheduler.

best_fitness float

The best fitness value achieved.

fitness float

Current fitness value.

loss float

Current loss value.

tloss float

Total loss value.

loss_names list

List of loss names.

csv Path

Path to results CSV file.

Parameters:

Name Type Description Default
cfg str

Path to a configuration file. Defaults to DEFAULT_CFG.

DEFAULT_CFG
overrides dict

Configuration overrides. Defaults to None.

None
Source code in ultralytics/engine/trainer.py
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
    """
    Initializes the BaseTrainer class.

    Builds the run configuration, selects the device, initializes seeds, prepares the save and weights directories
    and the checkpoint paths, resolves the train/test datasets via `get_dataset()` and registers callbacks.

    Args:
        cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
        overrides (dict, optional): Configuration overrides. Defaults to None.
        _callbacks (optional): Callbacks to use. When falsy, the result of `callbacks.get_default_callbacks()`
            is used instead. Defaults to None.
    """
    self.args = get_cfg(cfg, overrides)
    # May replace self.args with the arguments stored in the checkpoint when resuming; also sets self.resume
    self.check_resume(overrides)
    self.device = select_device(self.args.device, self.args.batch)
    self.validator = None
    self.metrics = None
    self.plots = {}
    # Seed is offset by the process rank
    init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)

    # Dirs
    self.save_dir = get_save_dir(self.args)
    self.args.name = self.save_dir.name  # update name for loggers
    self.wdir = self.save_dir / "weights"  # weights dir
    if RANK in {-1, 0}:  # only ranks -1 and 0 create the weights dir and write args.yaml
        self.wdir.mkdir(parents=True, exist_ok=True)  # make dir
        self.args.save_dir = str(self.save_dir)
        yaml_save(self.save_dir / "args.yaml", vars(self.args))  # save run args
    self.last, self.best = self.wdir / "last.pt", self.wdir / "best.pt"  # checkpoint paths
    self.save_period = self.args.save_period

    self.batch_size = self.args.batch
    self.epochs = self.args.epochs
    self.start_epoch = 0
    if RANK == -1:
        print_args(vars(self.args))

    # Device
    if self.device.type in {"cpu", "mps"}:
        self.args.workers = 0  # faster CPU training as time dominated by inference, not dataloading

    # Model and Dataset
    self.model = check_model_file_from_stem(self.args.model)  # add suffix, i.e. yolov8n -> yolov8n.pt
    with torch_distributed_zero_first(RANK):  # avoid auto-downloading dataset multiple times
        self.trainset, self.testset = self.get_dataset()
    self.ema = None

    # Optimization utils init
    self.lf = None
    self.scheduler = None

    # Epoch level metrics
    self.best_fitness = None
    self.fitness = None
    self.loss = None
    self.tloss = None
    self.loss_names = ["Loss"]
    self.csv = self.save_dir / "results.csv"
    self.plot_idx = [0, 1, 2]

    # HUB
    self.hub_session = None

    # Callbacks
    self.callbacks = _callbacks or callbacks.get_default_callbacks()
    if RANK in {-1, 0}:  # integration callbacks are only added on ranks -1 and 0
        callbacks.add_integration_callbacks(self)

add_callback

add_callback(event: str, callback)

Appends the given callback.

Source code in ultralytics/engine/trainer.py
def add_callback(self, event: str, callback):
    """Register `callback` as an additional handler for `event`, after any handlers already registered."""
    registered = self.callbacks[event]
    registered.append(callback)

build_dataset

build_dataset(img_path, mode='train', batch=None)

Build dataset.

Source code in ultralytics/engine/trainer.py
def build_dataset(self, img_path, mode="train", batch=None):
    """Dataset builder placeholder; this base implementation always raises NotImplementedError."""
    message = "build_dataset function not implemented in trainer"
    raise NotImplementedError(message)

build_optimizer

build_optimizer(model, name='auto', lr=0.001, momentum=0.9, decay=1e-05, iterations=100000.0)

Constructs an optimizer for the given model, based on the specified optimizer name, learning rate, momentum, weight decay, and number of iterations.

Parameters:

Name Type Description Default
model Module

The model for which to build an optimizer.

required
name str

The name of the optimizer to use. If 'auto', the optimizer is selected based on the number of iterations. Default: 'auto'.

'auto'
lr float

The learning rate for the optimizer. Default: 0.001.

0.001
momentum float

The momentum factor for the optimizer. Default: 0.9.

0.9
decay float

The weight decay for the optimizer. Default: 1e-5.

1e-05
iterations float

The number of iterations, which determines the optimizer if name is 'auto'. Default: 1e5.

100000.0

Returns:

Type Description
Optimizer

The constructed optimizer.

Source code in ultralytics/engine/trainer.py
def build_optimizer(self, model, name="auto", lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
    """
    Constructs an optimizer for the given model, based on the specified optimizer name, learning rate, momentum,
    weight decay, and number of iterations.

    Parameters are split into three groups: biases (no decay), weights of normalization layers (no decay) and all
    other weights (decayed by `decay`).

    Args:
        model (torch.nn.Module): The model for which to build an optimizer.
        name (str, optional): The name of the optimizer to use. If 'auto', the optimizer is selected
            based on the number of iterations. Default: 'auto'.
        lr (float, optional): The learning rate for the optimizer. Default: 0.001.
        momentum (float, optional): The momentum factor for the optimizer. Default: 0.9.
        decay (float, optional): The weight decay for the optimizer. Default: 1e-5.
        iterations (float, optional): The number of iterations, which determines the optimizer if
            name is 'auto'. Default: 1e5.

    Returns:
        (torch.optim.Optimizer): The constructed optimizer.

    Raises:
        NotImplementedError: If `name` is not one of the supported optimizer names.
    """

    g = [], [], []  # optimizer parameter groups: g[0] decayed weights, g[1] norm-layer weights, g[2] biases
    bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k)  # normalization layers, i.e. BatchNorm2d()
    if name == "auto":
        LOGGER.info(
            f"{colorstr('optimizer:')} 'optimizer=auto' found, "
            f"ignoring 'lr0={self.args.lr0}' and 'momentum={self.args.momentum}' and "
            f"determining best 'optimizer', 'lr0' and 'momentum' automatically... "
        )
        nc = getattr(model, "nc", 10)  # number of classes
        lr_fit = round(0.002 * 5 / (4 + nc), 6)  # lr0 fit equation to 6 decimal places
        # More than 10000 iterations selects SGD, otherwise AdamW with the fitted learning rate
        name, lr, momentum = ("SGD", 0.01, 0.9) if iterations > 10000 else ("AdamW", lr_fit, 0.9)
        self.args.warmup_bias_lr = 0.0  # no higher than 0.01 for Adam

    for module_name, module in model.named_modules():
        for param_name, param in module.named_parameters(recurse=False):
            fullname = f"{module_name}.{param_name}" if module_name else param_name
            if "bias" in fullname:  # bias (no decay)
                g[2].append(param)
            elif isinstance(module, bn):  # weight (no decay)
                g[1].append(param)
            else:  # weight (with decay)
                g[0].append(param)

    # The optimizer is created with the bias group; the two weight groups are appended afterwards
    if name in {"Adam", "Adamax", "AdamW", "NAdam", "RAdam"}:
        optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
    elif name == "RMSProp":
        optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum)
    elif name == "SGD":
        optimizer = optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
    else:
        raise NotImplementedError(
            f"Optimizer '{name}' not found in list of available optimizers "
            "[Adam, Adamax, AdamW, NAdam, RAdam, RMSProp, SGD, auto]. "
            "To request support for additional optimizers please visit https://github.com/ultralytics/ultralytics."
        )

    optimizer.add_param_group({"params": g[0], "weight_decay": decay})  # add g0 with weight_decay
    optimizer.add_param_group({"params": g[1], "weight_decay": 0.0})  # add g1 (BatchNorm2d weights)
    LOGGER.info(
        f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}, momentum={momentum}) with parameter groups "
        f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)'
    )
    return optimizer

build_targets

build_targets(preds, targets)

Builds target tensors for training YOLO model.

Source code in ultralytics/engine/trainer.py
def build_targets(self, preds, targets):
    """Target-building hook for YOLO training; this base implementation does nothing and returns None."""
    return None

check_resume

check_resume(overrides)

Check if resume checkpoint exists and update arguments accordingly.

Source code in ultralytics/engine/trainer.py
def check_resume(self, overrides):
    """
    Check if resume checkpoint exists and update arguments accordingly.

    When `self.args.resume` is truthy, the checkpoint is taken from that path if it exists, otherwise from
    `get_latest_run()`. `self.args` is then rebuilt from the arguments stored in the checkpoint, and the
    'imgsz', 'batch' and 'device' values from `overrides` (if given) are re-applied on top.

    Args:
        overrides (dict | None): Configuration overrides; None is treated as an empty dict.

    Raises:
        FileNotFoundError: If anything fails while locating or loading the resume checkpoint.
    """
    resume = self.args.resume
    if resume:
        # The constructor passes overrides=None by default; without this guard `k in overrides` would raise a
        # TypeError that is then misreported below as a missing checkpoint.
        overrides = overrides or {}
        try:
            exists = isinstance(resume, (str, Path)) and Path(resume).exists()
            last = Path(check_file(resume) if exists else get_latest_run())

            # Check that resume data YAML exists, otherwise strip to force re-download of dataset
            ckpt_args = attempt_load_weights(last).args
            if not Path(ckpt_args["data"]).exists():
                ckpt_args["data"] = self.args.data

            resume = True
            self.args = get_cfg(ckpt_args)
            self.args.model = self.args.resume = str(last)  # reinstate model
            for k in "imgsz", "batch", "device":  # allow arg updates to reduce memory or update device on resume
                if k in overrides:
                    setattr(self.args, k, overrides[k])

        except Exception as e:
            raise FileNotFoundError(
                "Resume checkpoint not found. Please pass a valid checkpoint to resume from, "
                "i.e. 'yolo train resume model=path/to/last.pt'"
            ) from e
    self.resume = resume

final_eval

final_eval()

Performs final evaluation and validation for object detection YOLO model.

Source code in ultralytics/engine/trainer.py
def final_eval(self):
    """
    Strip optimizer state from the saved last/best checkpoints and validate the best one.

    Checkpoints that do not exist on disk are skipped. After validating the best checkpoint, the "fitness" entry
    is removed from the resulting metrics and the "on_fit_epoch_end" callbacks are run.
    """
    for ckpt_path in (self.last, self.best):
        if not ckpt_path.exists():
            continue
        strip_optimizer(ckpt_path)  # strip optimizers
        if ckpt_path is not self.best:
            continue
        LOGGER.info(f"\nValidating {ckpt_path}...")
        self.validator.args.plots = self.args.plots
        self.metrics = self.validator(model=ckpt_path)
        self.metrics.pop("fitness", None)
        self.run_callbacks("on_fit_epoch_end")

get_dataloader

get_dataloader(dataset_path, batch_size=16, rank=0, mode='train')

Returns dataloader derived from torch.utils.data.DataLoader.

Source code in ultralytics/engine/trainer.py
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
    """Dataloader factory placeholder; this base implementation always raises NotImplementedError."""
    message = "get_dataloader function not implemented in trainer"
    raise NotImplementedError(message)

get_dataset

get_dataset()

Get train, val path from data dict if it exists.

Returns None if data format is not recognized.

Source code in ultralytics/engine/trainer.py
def get_dataset(self):
    """
    Resolve the dataset named by `self.args.data` and return its train and val/test entries.

    Classification tasks are checked with `check_cls_dataset`; YAML files and detect/segment/pose/obb tasks are
    checked with `check_det_dataset`. The resolved data dict is stored on `self.data`.

    Returns:
        (tuple): `data["train"]` and the "val" entry of the data dict, falling back to its "test" entry.

    Raises:
        RuntimeError: If the dataset check fails or the data format is not recognized for the task.
    """
    try:
        if self.args.task == "classify":
            data = check_cls_dataset(self.args.data)
        elif self.args.data.split(".")[-1] in {"yaml", "yml"} or self.args.task in {
            "detect",
            "segment",
            "pose",
            "obb",
        }:
            data = check_det_dataset(self.args.data)
            if "yaml_file" in data:
                self.args.data = data["yaml_file"]  # for validating 'yolo train data=url.zip' usage
        else:
            # Previously fell through with `data` unbound and crashed with an UnboundLocalError below
            raise ValueError(f"unrecognized data format for task '{self.args.task}'")
    except Exception as e:
        raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e
    self.data = data
    return data["train"], data.get("val") or data.get("test")

get_model

get_model(cfg=None, weights=None, verbose=True)

Get model and raise NotImplementedError for loading cfg files.

Source code in ultralytics/engine/trainer.py
def get_model(self, cfg=None, weights=None, verbose=True):
    """Model factory placeholder; this base implementation always raises NotImplementedError."""
    message = "This task trainer doesn't support loading cfg files"
    raise NotImplementedError(message)

get_validator

get_validator()

Raises a NotImplementedError when the get_validator function is called.

Source code in ultralytics/engine/trainer.py
def get_validator(self):
    """Validator factory placeholder; this base implementation always raises NotImplementedError."""
    message = "get_validator function not implemented in trainer"
    raise NotImplementedError(message)

label_loss_items

label_loss_items(loss_items=None, prefix='train')

Returns a loss dict with labelled training loss items tensor.

Note

This is not needed for classification but necessary for segmentation & detection

Source code in ultralytics/engine/trainer.py
def label_loss_items(self, loss_items=None, prefix="train"):
    """
    Label the training loss items.

    Returns the list of loss keys (`["loss"]`) when `loss_items` is None, otherwise a dict mapping "loss" to
    `loss_items`. The `prefix` argument is not used by this base implementation.

    Note:
        This is not needed for classification but necessary for segmentation & detection.
    """
    if loss_items is None:
        return ["loss"]
    return {"loss": loss_items}

on_plot

on_plot(name, data=None)

Registers plots (e.g. to be consumed in callbacks)

Source code in ultralytics/engine/trainer.py
def on_plot(self, name, data=None):
    """Record a plot under its path in `self.plots`, with its data and the current time (e.g. for callbacks)."""
    key = Path(name)
    record = {"data": data, "timestamp": time.time()}
    self.plots[key] = record

optimizer_step

optimizer_step()

Perform a single step of the training optimizer with gradient clipping and EMA update.

Source code in ultralytics/engine/trainer.py
def optimizer_step(self):
    """
    Apply one optimizer update.

    Gradients are unscaled through the gradient scaler, clipped to a maximum norm of 10.0, and applied via the
    scaler; the scaler is then updated and gradients are zeroed. If an EMA is set, it is updated from the model.
    """
    scaler, optimizer = self.scaler, self.optimizer
    scaler.unscale_(optimizer)  # unscale gradients
    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0)  # clip gradients
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()
    ema = self.ema
    if ema:
        ema.update(self.model)

plot_metrics

plot_metrics()

Plot and display metrics visually.

Source code in ultralytics/engine/trainer.py
def plot_metrics(self):
    """Metrics plotting hook; this base implementation does nothing and returns None."""
    return None

plot_training_labels

plot_training_labels()

Plots training labels for YOLO model.

Source code in ultralytics/engine/trainer.py
def plot_training_labels(self):
    """Training-label plotting hook; this base implementation does nothing and returns None."""
    return None

plot_training_samples

plot_training_samples(batch, ni)

Plots training samples during YOLO training.

Source code in ultralytics/engine/trainer.py
def plot_training_samples(self, batch, ni):
    """Training-sample plotting hook; this base implementation does nothing and returns None."""
    return None

preprocess_batch

preprocess_batch(batch)

Allows custom preprocessing model inputs and ground truths depending on task type.

Source code in ultralytics/engine/trainer.py
def preprocess_batch(self, batch):
    """Hook for task-specific preprocessing of a batch; this base implementation returns `batch` unchanged."""
    unchanged = batch
    return unchanged

progress_string

progress_string()

Returns a string describing training progress.

Source code in ultralytics/engine/trainer.py
def progress_string(self):
    """Return the text describing training progress; this base implementation returns an empty string."""
    description = ""
    return description

resume_training

resume_training(ckpt)

Resume YOLO training from given epoch and best fitness.

Source code in ultralytics/engine/trainer.py
def resume_training(self, ckpt):
    """
    Resume YOLO training from given epoch and best fitness.

    Does nothing when `ckpt` is None or resuming is disabled. Otherwise restores the optimizer state and best
    fitness (when the checkpoint holds an optimizer), the EMA weights and update count (when both the trainer
    and the checkpoint have an EMA), and sets `self.start_epoch` to the epoch after the checkpoint's.

    Args:
        ckpt (dict | None): Loaded checkpoint; read keys are "epoch", "optimizer", "best_fitness", "ema"
            and "updates".

    Raises:
        AssertionError: If the checkpoint holds no resumable epoch (start epoch is not positive).
    """
    if ckpt is None or not self.resume:
        return
    best_fitness = 0.0
    start_epoch = ckpt.get("epoch", -1) + 1
    if ckpt.get("optimizer", None) is not None:
        self.optimizer.load_state_dict(ckpt["optimizer"])  # optimizer
        best_fitness = ckpt["best_fitness"]
    if self.ema and ckpt.get("ema"):
        self.ema.ema.load_state_dict(ckpt["ema"].float().state_dict())  # EMA
        self.ema.updates = ckpt["updates"]
    # Explicit raise instead of `assert` so the check still runs under `python -O`
    if start_epoch <= 0:
        raise AssertionError(
            f"{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n"
            f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'"
        )
    LOGGER.info(f"Resuming training {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs")
    if self.epochs < start_epoch:
        LOGGER.info(
            f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs."
        )
        self.epochs += ckpt["epoch"]  # finetune additional epochs
    self.best_fitness = best_fitness
    self.start_epoch = start_epoch
    # Resuming inside the final `close_mosaic` epochs: disable mosaic right away
    if start_epoch > (self.epochs - self.args.close_mosaic):
        self._close_dataloader_mosaic()

run_callbacks

run_callbacks(event: str)

Run all existing callbacks associated with a particular event.

Source code in ultralytics/engine/trainer.py
def run_callbacks(self, event: str):
    """Invoke, in order, every callback registered for `event`, passing the trainer itself to each."""
    handlers = self.callbacks.get(event, [])
    for handler in handlers:
        handler(self)

save_metrics

save_metrics(metrics)

Saves training metrics to a CSV file.

Source code in ultralytics/engine/trainer.py
def save_metrics(self, metrics):
    """
    Append one row of metrics for the current epoch to the results CSV.

    A header row ("epoch" followed by the metric names) is written first when the CSV file does not exist yet.
    The epoch is written 1-based (`self.epoch + 1`).
    """
    names = list(metrics.keys())
    values = list(metrics.values())
    ncols = len(metrics) + 1  # number of cols
    if self.csv.exists():
        header = ""
    else:
        header = ("%23s," * ncols % tuple(["epoch"] + names)).rstrip(",") + "\n"
    with open(self.csv, "a") as f:
        row = ("%23.5g," * ncols % tuple([self.epoch + 1] + values)).rstrip(",") + "\n"
        f.write(header + row)

save_model

save_model()

Save model training checkpoints with additional metadata.

Source code in ultralytics/engine/trainer.py
def save_model(self):
    """
    Write the training checkpoint(s) for the current epoch.

    The checkpoint (EMA weights in half precision, optimizer state, training args, metrics, CSV results and
    metadata) is serialized once into memory and the same bytes are written to `last.pt`, to `best.pt` when the
    current fitness equals the best fitness, and to `epoch{N}.pt` when periodic saving is enabled and due.
    """
    import io

    import pandas as pd  # scope for faster 'import ultralytics'

    checkpoint = {
        "epoch": self.epoch,
        "best_fitness": self.best_fitness,
        "model": None,  # resume and final checkpoints derive from EMA
        "ema": deepcopy(self.ema.ema).half(),
        "updates": self.ema.updates,
        "optimizer": convert_optimizer_state_dict_to_fp16(deepcopy(self.optimizer.state_dict())),
        "train_args": vars(self.args),  # save as dict
        "train_metrics": {**self.metrics, "fitness": self.fitness},
        "train_results": {k.strip(): v for k, v in pd.read_csv(self.csv).to_dict(orient="list").items()},
        "date": datetime.now().isoformat(),
        "version": __version__,
        "license": "AGPL-3.0 (https://ultralytics.com/license)",
        "docs": "https://docs.ultralytics.com",
    }

    # Serialize once to memory; reusing the bytes is faster than repeated torch.save() calls
    stream = io.BytesIO()
    torch.save(checkpoint, stream)
    payload = stream.getvalue()

    self.last.write_bytes(payload)  # save last.pt
    if self.best_fitness == self.fitness:
        self.best.write_bytes(payload)  # save best.pt
    periodic_due = self.save_period > 0 and self.epoch % self.save_period == 0
    if periodic_due:
        (self.wdir / f"epoch{self.epoch}.pt").write_bytes(payload)  # save epoch, i.e. 'epoch3.pt'

set_callback

set_callback(event: str, callback)

Overrides the existing callbacks with the given callback.

Source code in ultralytics/engine/trainer.py
def set_callback(self, event: str, callback):
    """Replace every callback registered for `event` with the single given `callback`."""
    replacement = [callback]
    self.callbacks[event] = replacement

set_model_attributes

set_model_attributes()

To set or update model parameters before training.

Source code in ultralytics/engine/trainer.py
def set_model_attributes(self):
    """Copy the "names" entry of the dataset dict onto the model before training."""
    names = self.data["names"]
    self.model.names = names

setup_model

setup_model()

Load/create/download model for any task.

Source code in ultralytics/engine/trainer.py
def setup_model(self):
    """
    Load/create/download model for any task.

    Returns None immediately when `self.model` is already a `torch.nn.Module`. Otherwise `self.model` is replaced
    by the result of `get_model()`, built from the weights' YAML when `self.model` ends in ".pt", or from
    `self.model` as the cfg with optional weights loaded from `self.args.pretrained`.

    Returns:
        The checkpoint loaded from a ".pt" model file, or None in every other case.
    """
    if isinstance(self.model, torch.nn.Module):  # if model is loaded beforehand. No setup needed
        return None

    model_cfg = self.model
    loaded_weights = None
    checkpoint = None
    if str(self.model).endswith(".pt"):
        loaded_weights, checkpoint = attempt_load_one_weight(self.model)
        model_cfg = loaded_weights.yaml
    elif isinstance(self.args.pretrained, (str, Path)):
        loaded_weights, _ = attempt_load_one_weight(self.args.pretrained)
    # calls Model(cfg, weights)
    self.model = self.get_model(cfg=model_cfg, weights=loaded_weights, verbose=RANK == -1)
    return checkpoint

train

train()

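Runs training, launching a DDP subprocess when more than one device is requested and training in-process otherwise. device='' or device=None defaults to device=0 when CUDA is available.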

Source code in ultralytics/engine/trainer.py
def train(self):
    """
    Run training, launching a DDP subprocess for multi-GPU runs and training in-process otherwise.

    The world size is derived from `self.args.device`: the number of comma-separated entries of a non-empty
    string, the length of a tuple/list, 1 for any other value when CUDA is available (so device='' or
    device=None default to device 0), and 0 otherwise.

    When more than one device is requested and 'LOCAL_RANK' is not set in the environment, the command from
    `generate_ddp_command` is run as a subprocess and `ddp_cleanup` is always called afterwards. In every other
    case `self._do_train(world_size)` is called directly.
    """
    if isinstance(self.args.device, str) and len(self.args.device):  # i.e. device='0' or device='0,1,2,3'
        world_size = len(self.args.device.split(","))
    elif isinstance(self.args.device, (tuple, list)):  # i.e. device=[0, 1, 2, 3] (multi-GPU from CLI is list)
        world_size = len(self.args.device)
    elif torch.cuda.is_available():  # i.e. device=None or device='' or device=number
        world_size = 1  # default to device 0
    else:  # no device string/list given and CUDA is unavailable
        world_size = 0

    # Run subprocess if DDP training, else train normally
    if world_size > 1 and "LOCAL_RANK" not in os.environ:
        # Argument checks
        if self.args.rect:
            LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with Multi-GPU training, setting 'rect=False'")
            self.args.rect = False
        if self.args.batch < 1.0:
            LOGGER.warning(
                "WARNING ⚠️ 'batch<1' for AutoBatch is incompatible with Multi-GPU training, setting "
                "default 'batch=16'"
            )
            self.args.batch = 16

        # Command
        cmd, file = generate_ddp_command(world_size, self)
        try:
            LOGGER.info(f'{colorstr("DDP:")} debug command {" ".join(cmd)}')
            subprocess.run(cmd, check=True)
        finally:  # errors propagate unchanged; cleanup runs either way
            ddp_cleanup(self, str(file))

    else:
        self._do_train(world_size)

validate

validate()

Runs validation on test set using self.validator.

The returned dict is expected to contain "fitness" key.

Source code in ultralytics/engine/trainer.py
def validate(self):
    """
    Run the validator on the trainer and track the best fitness seen so far.

    The "fitness" entry is removed from the validator's metrics; when it is missing, the negated current loss is
    used instead. `self.best_fitness` is replaced when it is falsy or lower than the new fitness.

    Returns:
        (tuple): The remaining metrics dict and the fitness value.
    """
    metrics = self.validator(self)
    loss_based_fitness = -self.loss.detach().cpu().numpy()  # use loss as fitness measure if not found
    fitness = metrics.pop("fitness", loss_based_fitness)
    current_best = self.best_fitness
    if not current_best or current_best < fitness:
        self.best_fitness = fitness
    return metrics, fitness





Created 2023-11-12, Updated 2024-07-21
Authors: glenn-jocher (6), Burhan-Q (1)