Reference for ultralytics/engine/trainer.py

Improvements

This page is sourced from https://github.com/ultralytics/ultralytics/blob/main/ultralytics/engine/trainer.py. Have an improvement or example to add? Open a Pull Request — thank you! 🙏


class ultralytics.engine.trainer.BaseTrainer

BaseTrainer(self, cfg = DEFAULT_CFG, overrides = None, _callbacks = None)

A base class for creating trainers.

This class provides the foundation for training YOLO models, handling the training loop, validation, checkpointing, and various training utilities. It supports both single-GPU and multi-GPU distributed training.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| cfg | str, optional | Path to a configuration file. | DEFAULT_CFG |
| overrides | dict, optional | Configuration overrides. | None |
| _callbacks | list, optional | List of callback functions. | None |

Attributes

| Name | Type | Description |
| --- | --- | --- |
| args | SimpleNamespace | Configuration for the trainer. |
| validator | BaseValidator | Validator instance. |
| model | nn.Module | Model instance. |
| callbacks | defaultdict | Dictionary of callbacks. |
| save_dir | Path | Directory to save results. |
| wdir | Path | Directory to save weights. |
| last | Path | Path to the last checkpoint. |
| best | Path | Path to the best checkpoint. |
| save_period | int | Save checkpoint every x epochs (disabled if < 1). |
| batch_size | int | Batch size for training. |
| epochs | int | Number of epochs to train for. |
| start_epoch | int | Starting epoch for training. |
| device | torch.device | Device to use for training. |
| amp | bool | Flag to enable AMP (Automatic Mixed Precision). |
| scaler | amp.GradScaler | Gradient scaler for AMP. |
| data | str | Path to data. |
| ema | nn.Module | EMA (Exponential Moving Average) of the model. |
| resume | bool | Resume training from a checkpoint. |
| lf | nn.Module | Loss function. |
| scheduler | torch.optim.lr_scheduler._LRScheduler | Learning rate scheduler. |
| best_fitness | float | The best fitness value achieved. |
| fitness | float | Current fitness value. |
| loss | float | Current loss value. |
| tloss | float | Total loss value. |
| loss_names | list | List of loss names. |
| csv | Path | Path to results CSV file. |
| metrics | dict | Dictionary of metrics. |
| plots | dict | Dictionary of plots. |

Methods

| Name | Description |
| --- | --- |
| _clear_memory | Clear accelerator memory by calling garbage collector and emptying cache. |
| _close_dataloader_mosaic | Update dataloaders to stop using mosaic augmentation. |
| _do_train | Train the model with the specified world size. |
| _get_memory | Get accelerator memory utilization in GB or as a fraction of total memory. |
| _handle_nan_recovery | Detect and recover from NaN/Inf loss and fitness collapse by loading last checkpoint. |
| _load_checkpoint_state | Load optimizer, scaler, EMA, and best_fitness from checkpoint. |
| _model_train | Set model in training mode. |
| _setup_ddp | Initialize and set the DistributedDataParallel parameters for training. |
| _setup_scheduler | Initialize training learning rate scheduler. |
| _setup_train | Build dataloaders and optimizer on correct rank process. |
| add_callback | Append the given callback to the event's callback list. |
| auto_batch | Calculate optimal batch size based on model and device memory constraints. |
| build_dataset | Build dataset. |
| build_optimizer | Construct an optimizer for the given model. |
| build_targets | Build target tensors for training YOLO model. |
| check_resume | Check if resume checkpoint exists and update arguments accordingly. |
| final_eval | Perform final evaluation and validation for object detection YOLO model. |
| get_dataloader | Return dataloader derived from torch.data.Dataloader. |
| get_dataset | Get train and validation datasets from data dictionary. |
| get_model | Get model and raise NotImplementedError for loading cfg files. |
| get_validator | Return a NotImplementedError when the get_validator function is called. |
| label_loss_items | Return a loss dict with labeled training loss items tensor. |
| on_plot | Register plots (e.g. to be consumed in callbacks). |
| optimizer_step | Perform a single step of the training optimizer with gradient clipping and EMA update. |
| plot_metrics | Plot metrics from a CSV file. |
| plot_training_labels | Plot training labels for YOLO model. |
| plot_training_samples | Plot training samples during YOLO training. |
| preprocess_batch | Allow custom preprocessing model inputs and ground truths depending on task type. |
| progress_string | Return a string describing training progress. |
| read_results_csv | Read results.csv into a dictionary using polars. |
| resume_training | Resume YOLO training from given epoch and best fitness. |
| run_callbacks | Run all existing callbacks associated with a particular event. |
| save_metrics | Save training metrics to a CSV file. |
| save_model | Save model training checkpoints with additional metadata. |
| set_callback | Override the existing callbacks with the given callback for the specified event. |
| set_model_attributes | Set or update model parameters before training. |
| setup_model | Load, create, or download model for any task. |
| train | Allow device='', device=None on Multi-GPU systems to default to device=0. |
| validate | Run validation on val set using self.validator. |

Examples

Initialize a trainer and start training
>>> trainer = BaseTrainer(cfg="config.yaml")
>>> trainer.train()
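
Because several methods below (build_dataset, get_dataloader, get_model, get_validator) deliberately raise NotImplementedError, end-to-end training normally goes through a task-specific subclass. A minimal sketch, assuming the detection subclass ultralytics.models.yolo.detect.DetectionTrainer and example override values:

# Sketch only: DetectionTrainer implements the abstract hooks of BaseTrainer for detection.
from ultralytics.models.yolo.detect import DetectionTrainer

overrides = {"model": "yolo11n.pt", "data": "coco8.yaml", "epochs": 3, "imgsz": 640}  # assumed values
trainer = DetectionTrainer(overrides=overrides)
trainer.train()
print(trainer.best)  # path to best.pt under trainer.save_dir / "weights"
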
Source code in ultralytics/engine/trainer.pyView on GitHub
class BaseTrainer:
    """A base class for creating trainers.

    This class provides the foundation for training YOLO models, handling the training loop, validation, checkpointing,
    and various training utilities. It supports both single-GPU and multi-GPU distributed training.

    Attributes:
        args (SimpleNamespace): Configuration for the trainer.
        validator (BaseValidator): Validator instance.
        model (nn.Module): Model instance.
        callbacks (defaultdict): Dictionary of callbacks.
        save_dir (Path): Directory to save results.
        wdir (Path): Directory to save weights.
        last (Path): Path to the last checkpoint.
        best (Path): Path to the best checkpoint.
        save_period (int): Save checkpoint every x epochs (disabled if < 1).
        batch_size (int): Batch size for training.
        epochs (int): Number of epochs to train for.
        start_epoch (int): Starting epoch for training.
        device (torch.device): Device to use for training.
        amp (bool): Flag to enable AMP (Automatic Mixed Precision).
        scaler (amp.GradScaler): Gradient scaler for AMP.
        data (str): Path to data.
        ema (nn.Module): EMA (Exponential Moving Average) of the model.
        resume (bool): Resume training from a checkpoint.
        lf (nn.Module): Loss function.
        scheduler (torch.optim.lr_scheduler._LRScheduler): Learning rate scheduler.
        best_fitness (float): The best fitness value achieved.
        fitness (float): Current fitness value.
        loss (float): Current loss value.
        tloss (float): Total loss value.
        loss_names (list): List of loss names.
        csv (Path): Path to results CSV file.
        metrics (dict): Dictionary of metrics.
        plots (dict): Dictionary of plots.

    Methods:
        train: Execute the training process.
        validate: Run validation on the test set.
        save_model: Save model training checkpoints.
        get_dataset: Get train and validation datasets.
        setup_model: Load, create, or download model.
        build_optimizer: Construct an optimizer for the model.

    Examples:
        Initialize a trainer and start training
        >>> trainer = BaseTrainer(cfg="config.yaml")
        >>> trainer.train()
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """Initialize the BaseTrainer class.

        Args:
            cfg (str, optional): Path to a configuration file.
            overrides (dict, optional): Configuration overrides.
            _callbacks (list, optional): List of callback functions.
        """
        self.hub_session = overrides.pop("session", None)  # HUB
        self.args = get_cfg(cfg, overrides)
        self.check_resume(overrides)
        self.device = select_device(self.args.device)
        # Update "-1" devices so post-training val does not repeat search
        self.args.device = os.getenv("CUDA_VISIBLE_DEVICES") if "cuda" in str(self.device) else str(self.device)
        self.validator = None
        self.metrics = None
        self.plots = {}
        init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)

        # Dirs
        self.save_dir = get_save_dir(self.args)
        self.args.name = self.save_dir.name  # update name for loggers
        self.wdir = self.save_dir / "weights"  # weights dir
        if RANK in {-1, 0}:
            self.wdir.mkdir(parents=True, exist_ok=True)  # make dir
            self.args.save_dir = str(self.save_dir)
            # Save run args, serializing augmentations as reprs for resume compatibility
            args_dict = vars(self.args).copy()
            if args_dict.get("augmentations") is not None:
                # Serialize Albumentations transforms as their repr strings for checkpoint compatibility
                args_dict["augmentations"] = [repr(t) for t in args_dict["augmentations"]]
            YAML.save(self.save_dir / "args.yaml", args_dict)  # save run args
        self.last, self.best = self.wdir / "last.pt", self.wdir / "best.pt"  # checkpoint paths
        self.save_period = self.args.save_period

        self.batch_size = self.args.batch
        self.epochs = self.args.epochs or 100  # in case users accidentally pass epochs=None with timed training
        self.start_epoch = 0
        if RANK == -1:
            print_args(vars(self.args))

        # Device
        if self.device.type in {"cpu", "mps"}:
            self.args.workers = 0  # faster CPU training as time dominated by inference, not dataloading

        # Model and Dataset
        self.model = check_model_file_from_stem(self.args.model)  # add suffix, i.e. yolo11n -> yolo11n.pt
        with torch_distributed_zero_first(LOCAL_RANK):  # avoid auto-downloading dataset multiple times
            self.data = self.get_dataset()

        self.ema = None

        # Optimization utils init
        self.lf = None
        self.scheduler = None

        # Epoch level metrics
        self.best_fitness = None
        self.fitness = None
        self.loss = None
        self.tloss = None
        self.loss_names = ["Loss"]
        self.csv = self.save_dir / "results.csv"
        if self.csv.exists() and not self.args.resume:
            self.csv.unlink()
        self.plot_idx = [0, 1, 2]
        self.nan_recovery_attempts = 0

        # Callbacks
        self.callbacks = _callbacks or callbacks.get_default_callbacks()

        if isinstance(self.args.device, str) and len(self.args.device):  # i.e. device='0' or device='0,1,2,3'
            world_size = len(self.args.device.split(","))
        elif isinstance(self.args.device, (tuple, list)):  # i.e. device=[0, 1, 2, 3] (multi-GPU from CLI is list)
            world_size = len(self.args.device)
        elif self.args.device in {"cpu", "mps"}:  # i.e. device='cpu' or 'mps'
            world_size = 0
        elif torch.cuda.is_available():  # i.e. device=None or device='' or device=number
            world_size = 1  # default to device 0
        else:  # i.e. device=None or device=''
            world_size = 0

        self.ddp = world_size > 1 and "LOCAL_RANK" not in os.environ
        self.world_size = world_size
        # Run subprocess if DDP training, else train normally
        if RANK in {-1, 0} and not self.ddp:
            callbacks.add_integration_callbacks(self)
            # Start console logging immediately at trainer initialization
            self.run_callbacks("on_pretrain_routine_start")


method ultralytics.engine.trainer.BaseTrainer._clear_memory

def _clear_memory(self, threshold: float | None = None)

Clear accelerator memory by calling garbage collector and emptying cache.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| threshold | float \| None | Memory-utilization fraction above which the cache is cleared; if None, always clear. | None |
Source code in ultralytics/engine/trainer.pyView on GitHub
def _clear_memory(self, threshold: float | None = None):
    """Clear accelerator memory by calling garbage collector and emptying cache."""
    if threshold:
        assert 0 <= threshold <= 1, "Threshold must be between 0 and 1."
        if self._get_memory(fraction=True) <= threshold:
            return
    gc.collect()
    if self.device.type == "mps":
        torch.mps.empty_cache()
    elif self.device.type == "cpu":
        return
    else:
        torch.cuda.empty_cache()


method ultralytics.engine.trainer.BaseTrainer._close_dataloader_mosaic

def _close_dataloader_mosaic(self)

Update dataloaders to stop using mosaic augmentation.

Source code in ultralytics/engine/trainer.pyView on GitHub
def _close_dataloader_mosaic(self):
    """Update dataloaders to stop using mosaic augmentation."""
    if hasattr(self.train_loader.dataset, "mosaic"):
        self.train_loader.dataset.mosaic = False
    if hasattr(self.train_loader.dataset, "close_mosaic"):
        LOGGER.info("Closing dataloader mosaic")
        self.train_loader.dataset.close_mosaic(hyp=copy(self.args))


method ultralytics.engine.trainer.BaseTrainer._do_train

def _do_train(self)

Train the model with the specified world size.

Source code in ultralytics/engine/trainer.pyView on GitHub
def _do_train(self):
    """Train the model with the specified world size."""
    if self.world_size > 1:
        self._setup_ddp()
    self._setup_train()

    nb = len(self.train_loader)  # number of batches
    nw = max(round(self.args.warmup_epochs * nb), 100) if self.args.warmup_epochs > 0 else -1  # warmup iterations
    last_opt_step = -1
    self.epoch_time = None
    self.epoch_time_start = time.time()
    self.train_time_start = time.time()
    self.run_callbacks("on_train_start")
    LOGGER.info(
        f"Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n"
        f"Using {self.train_loader.num_workers * (self.world_size or 1)} dataloader workers\n"
        f"Logging results to {colorstr('bold', self.save_dir)}\n"
        f"Starting training for " + (f"{self.args.time} hours..." if self.args.time else f"{self.epochs} epochs...")
    )
    if self.args.close_mosaic:
        base_idx = (self.epochs - self.args.close_mosaic) * nb
        self.plot_idx.extend([base_idx, base_idx + 1, base_idx + 2])
    epoch = self.start_epoch
    self.optimizer.zero_grad()  # zero any resumed gradients to ensure stability on train start
    while True:
        self.epoch = epoch
        self.run_callbacks("on_train_epoch_start")
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")  # suppress 'Detected lr_scheduler.step() before optimizer.step()'
            self.scheduler.step()

        self._model_train()
        if RANK != -1:
            self.train_loader.sampler.set_epoch(epoch)
        pbar = enumerate(self.train_loader)
        # Update dataloader attributes (optional)
        if epoch == (self.epochs - self.args.close_mosaic):
            self._close_dataloader_mosaic()
            self.train_loader.reset()

        if RANK in {-1, 0}:
            LOGGER.info(self.progress_string())
            pbar = TQDM(enumerate(self.train_loader), total=nb)
        self.tloss = None
        for i, batch in pbar:
            self.run_callbacks("on_train_batch_start")
            # Warmup
            ni = i + nb * epoch
            if ni <= nw:
                xi = [0, nw]  # x interp
                self.accumulate = max(1, int(np.interp(ni, xi, [1, self.args.nbs / self.batch_size]).round()))
                for j, x in enumerate(self.optimizer.param_groups):
                    # Bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                    x["lr"] = np.interp(
                        ni, xi, [self.args.warmup_bias_lr if j == 0 else 0.0, x["initial_lr"] * self.lf(epoch)]
                    )
                    if "momentum" in x:
                        x["momentum"] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum])

            # Forward
            with autocast(self.amp):
                batch = self.preprocess_batch(batch)
                if self.args.compile:
                    # Decouple inference and loss calculations for improved compile performance
                    preds = self.model(batch["img"])
                    loss, self.loss_items = unwrap_model(self.model).loss(batch, preds)
                else:
                    loss, self.loss_items = self.model(batch)
                self.loss = loss.sum()
                if RANK != -1:
                    self.loss *= self.world_size
                self.tloss = self.loss_items if self.tloss is None else (self.tloss * i + self.loss_items) / (i + 1)

            # Backward
            self.scaler.scale(self.loss).backward()
            if ni - last_opt_step >= self.accumulate:
                self.optimizer_step()
                last_opt_step = ni

                # Timed stopping
                if self.args.time:
                    self.stop = (time.time() - self.train_time_start) > (self.args.time * 3600)
                    if RANK != -1:  # if DDP training
                        broadcast_list = [self.stop if RANK == 0 else None]
                        dist.broadcast_object_list(broadcast_list, 0)  # broadcast 'stop' to all ranks
                        self.stop = broadcast_list[0]
                    if self.stop:  # training time exceeded
                        break

            # Log
            if RANK in {-1, 0}:
                loss_length = self.tloss.shape[0] if len(self.tloss.shape) else 1
                pbar.set_description(
                    ("%11s" * 2 + "%11.4g" * (2 + loss_length))
                    % (
                        f"{epoch + 1}/{self.epochs}",
                        f"{self._get_memory():.3g}G",  # (GB) GPU memory util
                        *(self.tloss if loss_length > 1 else torch.unsqueeze(self.tloss, 0)),  # losses
                        batch["cls"].shape[0],  # batch size, i.e. 8
                        batch["img"].shape[-1],  # imgsz, i.e 640
                    )
                )
                self.run_callbacks("on_batch_end")
                if self.args.plots and ni in self.plot_idx:
                    self.plot_training_samples(batch, ni)

            self.run_callbacks("on_train_batch_end")

        self.lr = {f"lr/pg{ir}": x["lr"] for ir, x in enumerate(self.optimizer.param_groups)}  # for loggers

        self.run_callbacks("on_train_epoch_end")
        if RANK in {-1, 0}:
            self.ema.update_attr(self.model, include=["yaml", "nc", "args", "names", "stride", "class_weights"])

        # Validation
        final_epoch = epoch + 1 >= self.epochs
        if self.args.val or final_epoch or self.stopper.possible_stop or self.stop:
            self._clear_memory(threshold=0.5)  # prevent VRAM spike
            self.metrics, self.fitness = self.validate()

        # NaN recovery
        if self._handle_nan_recovery(epoch):
            continue

        self.nan_recovery_attempts = 0
        if RANK in {-1, 0}:
            self.save_metrics(metrics={**self.label_loss_items(self.tloss), **self.metrics, **self.lr})
            self.stop |= self.stopper(epoch + 1, self.fitness) or final_epoch
            if self.args.time:
                self.stop |= (time.time() - self.train_time_start) > (self.args.time * 3600)

            # Save model
            if self.args.save or final_epoch:
                self.save_model()
                self.run_callbacks("on_model_save")

        # Scheduler
        t = time.time()
        self.epoch_time = t - self.epoch_time_start
        self.epoch_time_start = t
        if self.args.time:
            mean_epoch_time = (t - self.train_time_start) / (epoch - self.start_epoch + 1)
            self.epochs = self.args.epochs = math.ceil(self.args.time * 3600 / mean_epoch_time)
            self._setup_scheduler()
            self.scheduler.last_epoch = self.epoch  # do not move
            self.stop |= epoch >= self.epochs  # stop if exceeded epochs
        self.run_callbacks("on_fit_epoch_end")
        self._clear_memory(0.5)  # clear if memory utilization > 50%

        # Early Stopping
        if RANK != -1:  # if DDP training
            broadcast_list = [self.stop if RANK == 0 else None]
            dist.broadcast_object_list(broadcast_list, 0)  # broadcast 'stop' to all ranks
            self.stop = broadcast_list[0]
        if self.stop:
            break  # must break all DDP ranks
        epoch += 1

    seconds = time.time() - self.train_time_start
    LOGGER.info(f"\n{epoch - self.start_epoch + 1} epochs completed in {seconds / 3600:.3f} hours.")
    # Do final val with best.pt
    self.final_eval()
    if RANK in {-1, 0}:
        if self.args.plots:
            self.plot_metrics()
        self.run_callbacks("on_train_end")
    self._clear_memory()
    unset_deterministic()
    self.run_callbacks("teardown")
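
The warmup block above ramps gradient accumulation, per-group learning rates, and momentum with np.interp over the first nw iterations. A self-contained sketch of that ramp, using assumed hyperparameter values and taking the scheduler factor self.lf(epoch) as 1 for simplicity:

import numpy as np

nb, warmup_epochs = 100, 3.0                  # batches per epoch and warmup epochs (assumed)
nbs, batch_size = 64, 16                      # nominal vs. actual batch size (assumed)
lr0, warmup_bias_lr = 0.01, 0.1
warmup_momentum, momentum = 0.8, 0.937

nw = max(round(warmup_epochs * nb), 100)      # warmup iterations, as in _do_train
for ni in (0, nw // 2, nw):                   # start, middle, and end of warmup
    accumulate = max(1, int(np.interp(ni, [0, nw], [1, nbs / batch_size]).round()))
    bias_lr = np.interp(ni, [0, nw], [warmup_bias_lr, lr0])   # bias lr falls 0.1 -> lr0
    other_lr = np.interp(ni, [0, nw], [0.0, lr0])             # other groups rise 0.0 -> lr0
    mom = np.interp(ni, [0, nw], [warmup_momentum, momentum])
    print(f"ni={ni:4d} accumulate={accumulate} bias_lr={bias_lr:.4f} lr={other_lr:.4f} momentum={mom:.3f}")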


method ultralytics.engine.trainer.BaseTrainer._get_memory

def _get_memory(self, fraction = False)

Get accelerator memory utilization in GB or as a fraction of total memory.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| fraction | bool | Return memory utilization as a fraction of total memory instead of GB. | False |
Source code in ultralytics/engine/trainer.pyView on GitHub
def _get_memory(self, fraction=False):
    """Get accelerator memory utilization in GB or as a fraction of total memory."""
    memory, total = 0, 0
    if self.device.type == "mps":
        memory = torch.mps.driver_allocated_memory()
        if fraction:
            return __import__("psutil").virtual_memory().percent / 100
    elif self.device.type != "cpu":
        memory = torch.cuda.memory_reserved()
        if fraction:
            total = torch.cuda.get_device_properties(self.device).total_memory
    return ((memory / total) if total > 0 else 0) if fraction else (memory / 2**30)


method ultralytics.engine.trainer.BaseTrainer._handle_nan_recovery

def _handle_nan_recovery(self, epoch)

Detect and recover from NaN/Inf loss and fitness collapse by loading last checkpoint.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| epoch | int | Current epoch, used to decide whether recovery from last.pt is possible. | required |
Source code in ultralytics/engine/trainer.pyView on GitHub
def _handle_nan_recovery(self, epoch):
    """Detect and recover from NaN/Inf loss and fitness collapse by loading last checkpoint."""
    loss_nan = self.loss is not None and not self.loss.isfinite()
    fitness_nan = self.fitness is not None and not np.isfinite(self.fitness)
    fitness_collapse = self.best_fitness and self.best_fitness > 0 and self.fitness == 0
    corrupted = RANK in {-1, 0} and loss_nan and (fitness_nan or fitness_collapse)
    reason = "Loss NaN/Inf" if loss_nan else "Fitness NaN/Inf" if fitness_nan else "Fitness collapse"
    if RANK != -1:  # DDP: broadcast to all ranks
        broadcast_list = [corrupted if RANK == 0 else None]
        dist.broadcast_object_list(broadcast_list, 0)
        corrupted = broadcast_list[0]
    if not corrupted:
        return False
    if epoch == self.start_epoch or not self.last.exists():
        LOGGER.warning(f"{reason} detected but can not recover from last.pt...")
        return False  # Cannot recover on first epoch, let training continue
    self.nan_recovery_attempts += 1
    if self.nan_recovery_attempts > 3:
        raise RuntimeError(f"Training failed: NaN persisted for {self.nan_recovery_attempts} epochs")
    LOGGER.warning(f"{reason} detected (attempt {self.nan_recovery_attempts}/3), recovering from last.pt...")
    self._model_train()  # set model to train mode before loading checkpoint to avoid inference tensor errors
    _, ckpt = load_checkpoint(self.last)
    ema_state = ckpt["ema"].float().state_dict()
    if not all(torch.isfinite(v).all() for v in ema_state.values() if isinstance(v, torch.Tensor)):
        raise RuntimeError(f"Checkpoint {self.last} is corrupted with NaN/Inf weights")
    unwrap_model(self.model).load_state_dict(ema_state)  # Load EMA weights into model
    self._load_checkpoint_state(ckpt)  # Load optimizer/scaler/EMA/best_fitness
    del ckpt, ema_state
    self.scheduler.last_epoch = epoch - 1
    return True


method ultralytics.engine.trainer.BaseTrainer._load_checkpoint_state

def _load_checkpoint_state(self, ckpt)

Load optimizer, scaler, EMA, and best_fitness from checkpoint.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| ckpt | dict | Checkpoint containing optimizer, scaler, EMA, and best_fitness state. | required |
Source code in ultralytics/engine/trainer.pyView on GitHub
def _load_checkpoint_state(self, ckpt):
    """Load optimizer, scaler, EMA, and best_fitness from checkpoint."""
    if ckpt.get("optimizer") is not None:
        self.optimizer.load_state_dict(ckpt["optimizer"])
    if ckpt.get("scaler") is not None:
        self.scaler.load_state_dict(ckpt["scaler"])
    if self.ema and ckpt.get("ema"):
        self.ema = ModelEMA(self.model)  # validation with EMA creates inference tensors that can't be updated
        self.ema.ema.load_state_dict(ckpt["ema"].float().state_dict())
        self.ema.updates = ckpt["updates"]
    self.best_fitness = ckpt.get("best_fitness", 0.0)


method ultralytics.engine.trainer.BaseTrainer._model_train

def _model_train(self)

Set model in training mode.

Source code in ultralytics/engine/trainer.pyView on GitHub
def _model_train(self):
    """Set model in training mode."""
    self.model.train()
    # Freeze BN stat
    for n, m in self.model.named_modules():
        if any(filter(lambda f: f in n, self.freeze_layer_names)) and isinstance(m, nn.BatchNorm2d):
            m.eval()


method ultralytics.engine.trainer.BaseTrainer._setup_ddp

def _setup_ddp(self)

Initialize and set the DistributedDataParallel parameters for training.

Source code in ultralytics/engine/trainer.pyView on GitHub
def _setup_ddp(self):
    """Initialize and set the DistributedDataParallel parameters for training."""
    torch.cuda.set_device(RANK)
    self.device = torch.device("cuda", RANK)
    os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1"  # set to enforce timeout
    dist.init_process_group(
        backend="nccl" if dist.is_nccl_available() else "gloo",
        timeout=timedelta(seconds=10800),  # 3 hours
        rank=RANK,
        world_size=self.world_size,
    )


method ultralytics.engine.trainer.BaseTrainer._setup_scheduler

def _setup_scheduler(self)

Initialize training learning rate scheduler.

Source code in ultralytics/engine/trainer.pyView on GitHub
def _setup_scheduler(self):
    """Initialize training learning rate scheduler."""
    if self.args.cos_lr:
        self.lf = one_cycle(1, self.args.lrf, self.epochs)  # cosine 1->hyp['lrf']
    else:
        self.lf = lambda x: max(1 - x / self.epochs, 0) * (1.0 - self.args.lrf) + self.args.lrf  # linear
    self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
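
The two branches above only differ in the shape of the multiplier applied to each parameter group's initial_lr. A short sketch comparing them, with assumed epochs and lrf values (the cosine curve has the same shape as one_cycle(1, lrf, epochs)):

import math

epochs, lrf = 100, 0.01  # assumed values

def linear(x):
    return max(1 - x / epochs, 0) * (1.0 - lrf) + lrf                   # same form as the lambda above

def cosine(x):
    return ((1 - math.cos(x * math.pi / epochs)) / 2) * (lrf - 1) + 1   # one_cycle-style decay 1 -> lrf

for e in (0, 25, 50, 99):
    print(f"epoch {e:3d}: linear x{linear(e):.4f}, cosine x{cosine(e):.4f}")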


method ultralytics.engine.trainer.BaseTrainer._setup_train

def _setup_train(self)

Build dataloaders and optimizer on correct rank process.

Source code in ultralytics/engine/trainer.pyView on GitHub
def _setup_train(self):
    """Build dataloaders and optimizer on correct rank process."""
    ckpt = self.setup_model()
    self.model = self.model.to(self.device)
    self.set_model_attributes()

    # Compile model
    self.model = attempt_compile(self.model, device=self.device, mode=self.args.compile)

    # Freeze layers
    freeze_list = (
        self.args.freeze
        if isinstance(self.args.freeze, list)
        else range(self.args.freeze)
        if isinstance(self.args.freeze, int)
        else []
    )
    always_freeze_names = [".dfl"]  # always freeze these layers
    freeze_layer_names = [f"model.{x}." for x in freeze_list] + always_freeze_names
    self.freeze_layer_names = freeze_layer_names
    for k, v in self.model.named_parameters():
        # v.register_hook(lambda x: torch.nan_to_num(x))  # NaN to 0 (commented for erratic training results)
        if any(x in k for x in freeze_layer_names):
            LOGGER.info(f"Freezing layer '{k}'")
            v.requires_grad = False
        elif not v.requires_grad and v.dtype.is_floating_point:  # only floating point Tensor can require gradients
            LOGGER.warning(
                f"setting 'requires_grad=True' for frozen layer '{k}'. "
                "See ultralytics.engine.trainer for customization of frozen layers."
            )
            v.requires_grad = True

    # Check AMP
    self.amp = torch.tensor(self.args.amp).to(self.device)  # True or False
    if self.amp and RANK in {-1, 0}:  # Single-GPU and DDP
        callbacks_backup = callbacks.default_callbacks.copy()  # backup callbacks as check_amp() resets them
        self.amp = torch.tensor(check_amp(self.model), device=self.device)
        callbacks.default_callbacks = callbacks_backup  # restore callbacks
    if RANK > -1 and self.world_size > 1:  # DDP
        dist.broadcast(self.amp.int(), src=0)  # broadcast from rank 0 to all other ranks; gloo errors with boolean
    self.amp = bool(self.amp)  # as boolean
    self.scaler = (
        torch.amp.GradScaler("cuda", enabled=self.amp) if TORCH_2_4 else torch.cuda.amp.GradScaler(enabled=self.amp)
    )
    if self.world_size > 1:
        self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[RANK], find_unused_parameters=True)

    # Check imgsz
    gs = max(int(self.model.stride.max() if hasattr(self.model, "stride") else 32), 32)  # grid size (max stride)
    self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs, max_dim=1)
    self.stride = gs  # for multiscale training

    # Batch size
    if self.batch_size < 1 and RANK == -1:  # single-GPU only, estimate best batch size
        self.args.batch = self.batch_size = self.auto_batch()

    # Dataloaders
    batch_size = self.batch_size // max(self.world_size, 1)
    self.train_loader = self.get_dataloader(
        self.data["train"], batch_size=batch_size, rank=LOCAL_RANK, mode="train"
    )
    # Note: When training DOTA dataset, double batch size could get OOM on images with >2000 objects.
    self.test_loader = self.get_dataloader(
        self.data.get("val") or self.data.get("test"),
        batch_size=batch_size if self.args.task == "obb" else batch_size * 2,
        rank=LOCAL_RANK,
        mode="val",
    )
    self.validator = self.get_validator()
    self.ema = ModelEMA(self.model)
    if RANK in {-1, 0}:
        metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix="val")
        self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))
        if self.args.plots:
            self.plot_training_labels()

    # Optimizer
    self.accumulate = max(round(self.args.nbs / self.batch_size), 1)  # accumulate loss before optimizing
    weight_decay = self.args.weight_decay * self.batch_size * self.accumulate / self.args.nbs  # scale weight_decay
    iterations = math.ceil(len(self.train_loader.dataset) / max(self.batch_size, self.args.nbs)) * self.epochs
    self.optimizer = self.build_optimizer(
        model=self.model,
        name=self.args.optimizer,
        lr=self.args.lr0,
        momentum=self.args.momentum,
        decay=weight_decay,
        iterations=iterations,
    )
    # Scheduler
    self._setup_scheduler()
    self.stopper, self.stop = EarlyStopping(patience=self.args.patience), False
    self.resume_training(ckpt)
    self.scheduler.last_epoch = self.start_epoch - 1  # do not move
    self.run_callbacks("on_pretrain_routine_end")
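
One detail worth calling out from the freeze logic above: an integer freeze value expands to the first N layer-name prefixes, and '.dfl' is always frozen. A tiny sketch of that expansion (example freeze values only):

def freeze_prefixes(freeze):
    """Mirror the prefix expansion used in _setup_train."""
    freeze_list = freeze if isinstance(freeze, list) else range(freeze) if isinstance(freeze, int) else []
    return [f"model.{x}." for x in freeze_list] + [".dfl"]

print(freeze_prefixes(2))       # ['model.0.', 'model.1.', '.dfl']
print(freeze_prefixes([4, 7]))  # ['model.4.', 'model.7.', '.dfl']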


method ultralytics.engine.trainer.BaseTrainer.add_callback

def add_callback(self, event: str, callback)

Append the given callback to the event's callback list.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| event | str | Name of the event to attach the callback to. | required |
| callback | callable | Callback function to append. | required |
Source code in ultralytics/engine/trainer.pyView on GitHub
def add_callback(self, event: str, callback):
    """Append the given callback to the event's callback list."""
    self.callbacks[event].append(callback)
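
A minimal sketch of registering a custom callback on an existing trainer instance (the function name is arbitrary; "on_fit_epoch_end" is one of the events fired in _do_train above):

def log_epoch(trainer):
    """Print a one-line summary after each fit epoch."""
    print(f"epoch {trainer.epoch + 1} done, fitness={trainer.fitness}")

trainer.add_callback("on_fit_epoch_end", log_epoch)  # appended alongside the default callbacks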


method ultralytics.engine.trainer.BaseTrainer.auto_batch

def auto_batch(self, max_num_obj = 0)

Calculate optimal batch size based on model and device memory constraints.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| max_num_obj | int | Maximum number of objects per image, passed to the batch-size estimator. | 0 |
Source code in ultralytics/engine/trainer.pyView on GitHub
def auto_batch(self, max_num_obj=0):
    """Calculate optimal batch size based on model and device memory constraints."""
    return check_train_batch_size(
        model=self.model,
        imgsz=self.args.imgsz,
        amp=self.amp,
        batch=self.batch_size,
        max_num_obj=max_num_obj,
    )  # returns batch size


method ultralytics.engine.trainer.BaseTrainer.build_dataset

def build_dataset(self, img_path, mode = "train", batch = None)

Build dataset.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| img_path | str | Path to the folder containing images. | required |
| mode | str | Dataset mode, 'train' or 'val'. | "train" |
| batch | int, optional | Batch size. | None |
Source code in ultralytics/engine/trainer.pyView on GitHub
def build_dataset(self, img_path, mode="train", batch=None):
    """Build dataset."""
    raise NotImplementedError("build_dataset function not implemented in trainer")


method ultralytics.engine.trainer.BaseTrainer.build_optimizer

def build_optimizer(self, model, name = "auto", lr = 0.001, momentum = 0.9, decay = 1e-5, iterations = 1e5)

Construct an optimizer for the given model.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | torch.nn.Module | The model for which to build an optimizer. | required |
| name | str, optional | The name of the optimizer to use. If 'auto', the optimizer is selected based on the number of iterations. | "auto" |
| lr | float, optional | The learning rate for the optimizer. | 0.001 |
| momentum | float, optional | The momentum factor for the optimizer. | 0.9 |
| decay | float, optional | The weight decay for the optimizer. | 1e-5 |
| iterations | float, optional | The number of iterations, which determines the optimizer if name is 'auto'. | 1e5 |

Returns

| Type | Description |
| --- | --- |
| torch.optim.Optimizer | The constructed optimizer. |
Source code in ultralytics/engine/trainer.pyView on GitHub
def build_optimizer(self, model, name="auto", lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
    """Construct an optimizer for the given model.

    Args:
        model (torch.nn.Module): The model for which to build an optimizer.
        name (str, optional): The name of the optimizer to use. If 'auto', the optimizer is selected based on the
            number of iterations.
        lr (float, optional): The learning rate for the optimizer.
        momentum (float, optional): The momentum factor for the optimizer.
        decay (float, optional): The weight decay for the optimizer.
        iterations (float, optional): The number of iterations, which determines the optimizer if name is 'auto'.

    Returns:
        (torch.optim.Optimizer): The constructed optimizer.
    """
    g = [], [], []  # optimizer parameter groups
    bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k)  # normalization layers, i.e. BatchNorm2d()
    if name == "auto":
        LOGGER.info(
            f"{colorstr('optimizer:')} 'optimizer=auto' found, "
            f"ignoring 'lr0={self.args.lr0}' and 'momentum={self.args.momentum}' and "
            f"determining best 'optimizer', 'lr0' and 'momentum' automatically... "
        )
        nc = self.data.get("nc", 10)  # number of classes
        lr_fit = round(0.002 * 5 / (4 + nc), 6)  # lr0 fit equation to 6 decimal places
        name, lr, momentum = ("SGD", 0.01, 0.9) if iterations > 10000 else ("AdamW", lr_fit, 0.9)
        self.args.warmup_bias_lr = 0.0  # no higher than 0.01 for Adam

    for module_name, module in model.named_modules():
        for param_name, param in module.named_parameters(recurse=False):
            fullname = f"{module_name}.{param_name}" if module_name else param_name
            if "bias" in fullname:  # bias (no decay)
                g[2].append(param)
            elif isinstance(module, bn) or "logit_scale" in fullname:  # weight (no decay)
                # ContrastiveHead and BNContrastiveHead included here with 'logit_scale'
                g[1].append(param)
            else:  # weight (with decay)
                g[0].append(param)

    optimizers = {"Adam", "Adamax", "AdamW", "NAdam", "RAdam", "RMSProp", "SGD", "auto"}
    name = {x.lower(): x for x in optimizers}.get(name.lower())
    if name in {"Adam", "Adamax", "AdamW", "NAdam", "RAdam"}:
        optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
    elif name == "RMSProp":
        optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum)
    elif name == "SGD":
        optimizer = optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
    else:
        raise NotImplementedError(
            f"Optimizer '{name}' not found in list of available optimizers {optimizers}. "
            "Request support for addition optimizers at https://github.com/ultralytics/ultralytics."
        )

    optimizer.add_param_group({"params": g[0], "weight_decay": decay})  # add g0 with weight_decay
    optimizer.add_param_group({"params": g[1], "weight_decay": 0.0})  # add g1 (BatchNorm2d weights)
    LOGGER.info(
        f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}, momentum={momentum}) with parameter groups "
        f"{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)"
    )
    return optimizer
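
The 'auto' branch above picks SGD for long runs and AdamW with a class-count-dependent learning rate for short ones. The selection rule in isolation, with assumed nc and iterations values:

nc, iterations = 80, 5000                                    # assumed example values
lr_fit = round(0.002 * 5 / (4 + nc), 6)                      # 0.000119 for nc=80
name, lr0, momentum = ("SGD", 0.01, 0.9) if iterations > 10000 else ("AdamW", lr_fit, 0.9)
print(name, lr0, momentum)                                   # AdamW 0.000119 0.9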


method ultralytics.engine.trainer.BaseTrainer.build_targets

def build_targets(self, preds, targets)

Build target tensors for training YOLO model.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| preds |  | Model predictions. | required |
| targets |  | Ground-truth targets. | required |
Source code in ultralytics/engine/trainer.pyView on GitHub
def build_targets(self, preds, targets):
    """Build target tensors for training YOLO model."""
    pass


method ultralytics.engine.trainer.BaseTrainer.check_resume

def check_resume(self, overrides)

Check if resume checkpoint exists and update arguments accordingly.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| overrides | dict | Configuration overrides passed at trainer initialization. | required |
Source code in ultralytics/engine/trainer.pyView on GitHub
def check_resume(self, overrides):
    """Check if resume checkpoint exists and update arguments accordingly."""
    resume = self.args.resume
    if resume:
        try:
            exists = isinstance(resume, (str, Path)) and Path(resume).exists()
            last = Path(check_file(resume) if exists else get_latest_run())

            # Check that resume data YAML exists, otherwise strip to force re-download of dataset
            ckpt_args = load_checkpoint(last)[0].args
            if not isinstance(ckpt_args["data"], dict) and not Path(ckpt_args["data"]).exists():
                ckpt_args["data"] = self.args.data

            resume = True
            self.args = get_cfg(ckpt_args)
            self.args.model = self.args.resume = str(last)  # reinstate model
            for k in (
                "imgsz",
                "batch",
                "device",
                "close_mosaic",
                "augmentations",
            ):  # allow arg updates to reduce memory or update device on resume
                if k in overrides:
                    setattr(self.args, k, overrides[k])

            # Handle augmentations parameter for resume: check if user provided custom augmentations
            if ckpt_args.get("augmentations") is not None:
                # Augmentations were saved in checkpoint as reprs but can't be restored automatically
                LOGGER.warning(
                    "Custom Albumentations transforms were used in the original training run but are not "
                    "being restored. To preserve custom augmentations when resuming, you need to pass the "
                    "'augmentations' parameter again to get expected results. Example: \n"
                    f"model.train(resume=True, augmentations={ckpt_args['augmentations']})"
                )

        except Exception as e:
            raise FileNotFoundError(
                "Resume checkpoint not found. Please pass a valid checkpoint to resume from, "
                "i.e. 'yolo train resume model=path/to/last.pt'"
            ) from e
    self.resume = resume
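
In practice resuming is usually driven through the model API or CLI rather than by calling check_resume directly. A sketch, with a placeholder checkpoint path:

from ultralytics import YOLO

model = YOLO("runs/detect/train/weights/last.pt")  # placeholder path to a previous run's last.pt
model.train(resume=True)                           # args are restored from the checkpoint as above
# CLI equivalent: yolo train resume model=runs/detect/train/weights/last.pt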


method ultralytics.engine.trainer.BaseTrainer.final_eval

def final_eval(self)

Perform final evaluation and validation for object detection YOLO model.

Source code in ultralytics/engine/trainer.pyView on GitHub
def final_eval(self):
    """Perform final evaluation and validation for object detection YOLO model."""
    model = self.best if self.best.exists() else None
    with torch_distributed_zero_first(LOCAL_RANK):  # strip only on GPU 0; other GPUs should wait
        if RANK in {-1, 0}:
            ckpt = strip_optimizer(self.last) if self.last.exists() else {}
            if model:
                # update best.pt train_metrics from last.pt
                strip_optimizer(self.best, updates={"train_results": ckpt.get("train_results")})
    if model:
        LOGGER.info(f"\nValidating {model}...")
        self.validator.args.plots = self.args.plots
        self.validator.args.compile = False  # disable final val compile as too slow
        self.metrics = self.validator(model=model)
        self.metrics.pop("fitness", None)
        self.run_callbacks("on_fit_epoch_end")


method ultralytics.engine.trainer.BaseTrainer.get_dataloader

def get_dataloader(self, dataset_path, batch_size = 16, rank = 0, mode = "train")

Return dataloader derived from torch.data.Dataloader.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| dataset_path | str | Path to the dataset split to load. | required |
| batch_size | int | Number of images per batch. | 16 |
| rank | int | Process rank for distributed training. | 0 |
| mode | str | Dataloader mode, 'train' or 'val'. | "train" |
Source code in ultralytics/engine/trainer.pyView on GitHub
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
    """Return dataloader derived from torch.data.Dataloader."""
    raise NotImplementedError("get_dataloader function not implemented in trainer")


method ultralytics.engine.trainer.BaseTrainer.get_dataset

def get_dataset(self)

Get train and validation datasets from data dictionary.

Returns

| Type | Description |
| --- | --- |
| dict | A dictionary containing the training/validation/test dataset and category names. |
Source code in ultralytics/engine/trainer.pyView on GitHub
def get_dataset(self):
    """Get train and validation datasets from data dictionary.

    Returns:
        (dict): A dictionary containing the training/validation/test dataset and category names.
    """
    try:
        if self.args.task == "classify":
            data = check_cls_dataset(self.args.data)
        elif str(self.args.data).rsplit(".", 1)[-1] == "ndjson":
            # Convert NDJSON to YOLO format
            import asyncio

            from ultralytics.data.converter import convert_ndjson_to_yolo

            yaml_path = asyncio.run(convert_ndjson_to_yolo(self.args.data))
            self.args.data = str(yaml_path)
            data = check_det_dataset(self.args.data)
        elif str(self.args.data).rsplit(".", 1)[-1] in {"yaml", "yml"} or self.args.task in {
            "detect",
            "segment",
            "pose",
            "obb",
        }:
            data = check_det_dataset(self.args.data)
            if "yaml_file" in data:
                self.args.data = data["yaml_file"]  # for validating 'yolo train data=url.zip' usage
    except Exception as e:
        raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e
    if self.args.single_cls:
        LOGGER.info("Overriding class names with single class.")
        data["names"] = {0: "item"}
        data["nc"] = 1
    return data
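
The keys consumed elsewhere in the trainer (self.data["train"], self.data.get("val"), data["nc"], data["names"]) imply a dictionary roughly of this shape; the values below are purely illustrative:

data = {
    "train": "../datasets/coco8/images/train",  # illustrative path
    "val": "../datasets/coco8/images/val",
    "nc": 80,
    "names": {0: "person", 1: "bicycle"},       # truncated mapping, illustrative only
}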


method ultralytics.engine.trainer.BaseTrainer.get_model

def get_model(self, cfg = None, weights = None, verbose = True)

Get model and raise NotImplementedError for loading cfg files.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| cfg |  | Model configuration. | None |
| weights |  | Pretrained weights to load. | None |
| verbose | bool | Whether to print model info. | True |
Source code in ultralytics/engine/trainer.pyView on GitHub
def get_model(self, cfg=None, weights=None, verbose=True):
    """Get model and raise NotImplementedError for loading cfg files."""
    raise NotImplementedError("This task trainer doesn't support loading cfg files")


method ultralytics.engine.trainer.BaseTrainer.get_validator

def get_validator(self)

Return a NotImplementedError when the get_validator function is called.

Source code in ultralytics/engine/trainer.pyView on GitHub
def get_validator(self):
    """Return a NotImplementedError when the get_validator function is called."""
    raise NotImplementedError("get_validator function not implemented in trainer")


method ultralytics.engine.trainer.BaseTrainer.label_loss_items

def label_loss_items(self, loss_items = None, prefix = "train")

Return a loss dict with labeled training loss items tensor.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| loss_items |  | Loss values to label. | None |
| prefix | str | Prefix for the loss keys, e.g. 'train' or 'val'. | "train" |

Notes

This is not needed for classification but necessary for segmentation & detection

Source code in ultralytics/engine/trainer.pyView on GitHub
def label_loss_items(self, loss_items=None, prefix="train"):
    """Return a loss dict with labeled training loss items tensor.

    Notes:
        This is not needed for classification but necessary for segmentation & detection
    """
    return {"loss": loss_items} if loss_items is not None else ["loss"]


method ultralytics.engine.trainer.BaseTrainer.on_plot

def on_plot(self, name, data = None)

Register plots (e.g. to be consumed in callbacks).

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| name |  | Name or path of the plot to register. | required |
| data |  | Plot data stored alongside the path. | None |
Source code in ultralytics/engine/trainer.pyView on GitHub
def on_plot(self, name, data=None):
    """Register plots (e.g. to be consumed in callbacks)."""
    path = Path(name)
    self.plots[path] = {"data": data, "timestamp": time.time()}


method ultralytics.engine.trainer.BaseTrainer.optimizer_step

def optimizer_step(self)

Perform a single step of the training optimizer with gradient clipping and EMA update.

Source code in ultralytics/engine/trainer.pyView on GitHub
def optimizer_step(self):
    """Perform a single step of the training optimizer with gradient clipping and EMA update."""
    self.scaler.unscale_(self.optimizer)  # unscale gradients
    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0)
    self.scaler.step(self.optimizer)
    self.scaler.update()
    self.optimizer.zero_grad()
    if self.ema:
        self.ema.update(self.model)


method ultralytics.engine.trainer.BaseTrainer.plot_metrics

def plot_metrics(self)

Plot metrics from a CSV file.

Source code in ultralytics/engine/trainer.pyView on GitHub
def plot_metrics(self):
    """Plot metrics from a CSV file."""
    plot_results(file=self.csv, on_plot=self.on_plot)  # save results.png


method ultralytics.engine.trainer.BaseTrainer.plot_training_labels

def plot_training_labels(self)

Plot training labels for YOLO model.

Source code in ultralytics/engine/trainer.pyView on GitHub
def plot_training_labels(self):
    """Plot training labels for YOLO model."""
    pass


method ultralytics.engine.trainer.BaseTrainer.plot_training_samples

def plot_training_samples(self, batch, ni)

Plot training samples during YOLO training.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| batch |  | Batch of training data to plot. | required |
| ni |  | Current training iteration index. | required |
Source code in ultralytics/engine/trainer.pyView on GitHub
def plot_training_samples(self, batch, ni):
    """Plot training samples during YOLO training."""
    pass


method ultralytics.engine.trainer.BaseTrainer.preprocess_batch

def preprocess_batch(self, batch)

Allow custom preprocessing model inputs and ground truths depending on task type.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| batch |  | Batch of model inputs and ground truths to preprocess. | required |
Source code in ultralytics/engine/trainer.pyView on GitHub
def preprocess_batch(self, batch):
    """Allow custom preprocessing model inputs and ground truths depending on task type."""
    return batch


method ultralytics.engine.trainer.BaseTrainer.progress_string

def progress_string(self)

Return a string describing training progress.

Source code in ultralytics/engine/trainer.pyView on GitHub
def progress_string(self):
    """Return a string describing training progress."""
    return ""


method ultralytics.engine.trainer.BaseTrainer.read_results_csv

def read_results_csv(self)

Read results.csv into a dictionary using polars.

Source code in ultralytics/engine/trainer.pyView on GitHub
def read_results_csv(self):
    """Read results.csv into a dictionary using polars."""
    import polars as pl  # scope for faster 'import ultralytics'

    try:
        return pl.read_csv(self.csv, infer_schema_length=None).to_dict(as_series=False)
    except Exception:
        return {}


method ultralytics.engine.trainer.BaseTrainer.resume_training

def resume_training(self, ckpt)

Resume YOLO training from given epoch and best fitness.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| ckpt |  | Checkpoint dictionary to resume from, or None. | required |
Source code in ultralytics/engine/trainer.pyView on GitHub
def resume_training(self, ckpt):
    """Resume YOLO training from given epoch and best fitness."""
    if ckpt is None or not self.resume:
        return
    start_epoch = ckpt.get("epoch", -1) + 1
    assert start_epoch > 0, (
        f"{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n"
        f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'"
    )
    LOGGER.info(f"Resuming training {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs")
    if self.epochs < start_epoch:
        LOGGER.info(
            f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs."
        )
        self.epochs += ckpt["epoch"]  # finetune additional epochs
    self._load_checkpoint_state(ckpt)
    self.start_epoch = start_epoch
    if start_epoch > (self.epochs - self.args.close_mosaic):
        self._close_dataloader_mosaic()


method ultralytics.engine.trainer.BaseTrainer.run_callbacks

def run_callbacks(self, event: str)

Run all existing callbacks associated with a particular event.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| event | str | Name of the event whose callbacks to run. | required |
Source code in ultralytics/engine/trainer.pyView on GitHub
def run_callbacks(self, event: str):
    """Run all existing callbacks associated with a particular event."""
    for callback in self.callbacks.get(event, []):
        callback(self)


method ultralytics.engine.trainer.BaseTrainer.save_metrics

def save_metrics(self, metrics)

Save training metrics to a CSV file.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| metrics |  | Dictionary of metrics to append to results.csv. | required |
Source code in ultralytics/engine/trainer.pyView on GitHub
def save_metrics(self, metrics):
    """Save training metrics to a CSV file."""
    keys, vals = list(metrics.keys()), list(metrics.values())
    n = len(metrics) + 2  # number of cols
    t = time.time() - self.train_time_start
    self.csv.parent.mkdir(parents=True, exist_ok=True)  # ensure parent directory exists
    s = "" if self.csv.exists() else ("%s," * n % ("epoch", "time", *keys)).rstrip(",") + "\n"
    with open(self.csv, "a", encoding="utf-8") as f:
        f.write(s + ("%.6g," * n % (self.epoch + 1, t, *vals)).rstrip(",") + "\n")
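
The same formatting expressions, run standalone with an assumed metrics dict, show exactly what lands in results.csv:

metrics = {"train/loss": 1.234, "metrics/mAP50(B)": 0.456}   # assumed example metrics
n = len(metrics) + 2
header = ("%s," * n % ("epoch", "time", *metrics.keys())).rstrip(",")
row = ("%.6g," * n % (3, 120.5, *metrics.values())).rstrip(",")
print(header)  # epoch,time,train/loss,metrics/mAP50(B)
print(row)     # 3,120.5,1.234,0.456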


method ultralytics.engine.trainer.BaseTrainer.save_model

def save_model(self)

Save model training checkpoints with additional metadata.

Source code in ultralytics/engine/trainer.pyView on GitHub
def save_model(self):
    """Save model training checkpoints with additional metadata."""
    import io

    # Serialize ckpt to a byte buffer once (faster than repeated torch.save() calls)
    buffer = io.BytesIO()
    torch.save(
        {
            "epoch": self.epoch,
            "best_fitness": self.best_fitness,
            "model": None,  # resume and final checkpoints derive from EMA
            "ema": deepcopy(unwrap_model(self.ema.ema)).half(),
            "updates": self.ema.updates,
            "optimizer": convert_optimizer_state_dict_to_fp16(deepcopy(self.optimizer.state_dict())),
            "scaler": self.scaler.state_dict(),
            "train_args": vars(self.args),  # save as dict
            "train_metrics": {**self.metrics, **{"fitness": self.fitness}},
            "train_results": self.read_results_csv(),
            "date": datetime.now().isoformat(),
            "version": __version__,
            "git": {
                "root": str(GIT.root),
                "branch": GIT.branch,
                "commit": GIT.commit,
                "origin": GIT.origin,
            },
            "license": "AGPL-3.0 (https://ultralytics.com/license)",
            "docs": "https://docs.ultralytics.com",
        },
        buffer,
    )
    serialized_ckpt = buffer.getvalue()  # get the serialized content to save

    # Save checkpoints
    self.wdir.mkdir(parents=True, exist_ok=True)  # ensure weights directory exists
    self.last.write_bytes(serialized_ckpt)  # save last.pt
    if self.best_fitness == self.fitness:
        self.best.write_bytes(serialized_ckpt)  # save best.pt
    if (self.save_period > 0) and (self.epoch % self.save_period == 0):
        (self.wdir / f"epoch{self.epoch}.pt").write_bytes(serialized_ckpt)  # save epoch, i.e. 'epoch3.pt'
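
A saved checkpoint can be inspected directly with torch.load; note the 'ema' entry is a pickled module, so the ultralytics package must be importable, and recent PyTorch versions need weights_only=False. The path below is a placeholder:

import torch

ckpt = torch.load("runs/detect/train/weights/last.pt", map_location="cpu", weights_only=False)
print(sorted(ckpt.keys()))                 # best_fitness, date, docs, ema, epoch, git, license, ...
print(ckpt["epoch"], ckpt["best_fitness"])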


method ultralytics.engine.trainer.BaseTrainer.set_callback

def set_callback(self, event: str, callback)

Override the existing callbacks with the given callback for the specified event.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| event | str | Name of the event. | required |
| callback | callable | Callback that replaces all existing callbacks for the event. | required |
Source code in ultralytics/engine/trainer.pyView on GitHub
def set_callback(self, event: str, callback):
    """Override the existing callbacks with the given callback for the specified event."""
    self.callbacks[event] = [callback]


method ultralytics.engine.trainer.BaseTrainer.set_model_attributes

def set_model_attributes(self)

Set or update model parameters before training.

Source code in ultralytics/engine/trainer.pyView on GitHub
def set_model_attributes(self):
    """Set or update model parameters before training."""
    self.model.names = self.data["names"]


method ultralytics.engine.trainer.BaseTrainer.setup_model

def setup_model(self)

Load, create, or download model for any task.

Returns

| Type | Description |
| --- | --- |
| dict | Optional checkpoint to resume training from. |
Source code in ultralytics/engine/trainer.pyView on GitHub
def setup_model(self):
    """Load, create, or download model for any task.

    Returns:
        (dict): Optional checkpoint to resume training from.
    """
    if isinstance(self.model, torch.nn.Module):  # if model is loaded beforehand. No setup needed
        return

    cfg, weights = self.model, None
    ckpt = None
    if str(self.model).endswith(".pt"):
        weights, ckpt = load_checkpoint(self.model)
        cfg = weights.yaml
    elif isinstance(self.args.pretrained, (str, Path)):
        weights, _ = load_checkpoint(self.args.pretrained)
    self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1)  # calls Model(cfg, weights)
    return ckpt


method ultralytics.engine.trainer.BaseTrainer.train

def train(self)

Allow device='', device=None on Multi-GPU systems to default to device=0.

Source code in ultralytics/engine/trainer.pyView on GitHub
def train(self):
    """Allow device='', device=None on Multi-GPU systems to default to device=0."""
    # Run subprocess if DDP training, else train normally
    if self.ddp:
        # Argument checks
        if self.args.rect:
            LOGGER.warning("'rect=True' is incompatible with Multi-GPU training, setting 'rect=False'")
            self.args.rect = False
        if self.args.batch < 1.0:
            raise ValueError(
                "AutoBatch with batch<1 not supported for Multi-GPU training, "
                f"please specify a valid batch size multiple of GPU count {self.world_size}, i.e. batch={self.world_size * 8}."
            )

        # Command
        cmd, file = generate_ddp_command(self)
        try:
            LOGGER.info(f"{colorstr('DDP:')} debug command {' '.join(cmd)}")
            subprocess.run(cmd, check=True)
        except Exception as e:
            raise e
        finally:
            ddp_cleanup(self, str(file))

    else:
        self._do_train()


method ultralytics.engine.trainer.BaseTrainer.validate

def validate(self)

Run validation on val set using self.validator.

Returns

| Type | Description |
| --- | --- |
| metrics (dict) | Dictionary of validation metrics. |
| fitness (float) | Fitness score for the validation. |
Source code in ultralytics/engine/trainer.pyView on GitHub
def validate(self):
    """Run validation on val set using self.validator.

    Returns:
        metrics (dict): Dictionary of validation metrics.
        fitness (float): Fitness score for the validation.
    """
    if self.ema and self.world_size > 1:
        # Sync EMA buffers from rank 0 to all ranks
        for buffer in self.ema.ema.buffers():
            dist.broadcast(buffer, src=0)
    metrics = self.validator(self)
    if metrics is None:
        return None, None
    fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy())  # use loss as fitness measure if not found
    if not self.best_fitness or self.best_fitness < fitness:
        self.best_fitness = fitness
    return metrics, fitness




