Reference for ultralytics/engine/trainer.py
Improvements
This page is sourced from https://github.com/ultralytics/ultralytics/blob/main/ultralytics/engine/trainer.py. Have an improvement or example to add? Open a Pull Request — thank you! 🙏
Summary
BaseTrainer.add_callback, BaseTrainer.set_callback, BaseTrainer.run_callbacks, BaseTrainer.train, BaseTrainer._setup_scheduler, BaseTrainer._setup_ddp, BaseTrainer._setup_train, BaseTrainer._do_train, BaseTrainer.auto_batch, BaseTrainer._get_memory, BaseTrainer._clear_memory, BaseTrainer.read_results_csv, BaseTrainer._model_train, BaseTrainer.save_model, BaseTrainer.get_dataset, BaseTrainer.setup_model, BaseTrainer.optimizer_step, BaseTrainer.preprocess_batch, BaseTrainer.validate, BaseTrainer.get_model, BaseTrainer.get_validator, BaseTrainer.get_dataloader, BaseTrainer.build_dataset, BaseTrainer.label_loss_items, BaseTrainer.set_model_attributes, BaseTrainer.build_targets, BaseTrainer.progress_string, BaseTrainer.plot_training_samples, BaseTrainer.plot_training_labels, BaseTrainer.save_metrics, BaseTrainer.plot_metrics, BaseTrainer.on_plot, BaseTrainer.final_eval, BaseTrainer.check_resume, BaseTrainer._load_checkpoint_state, BaseTrainer._handle_nan_recovery, BaseTrainer.resume_training, BaseTrainer._close_dataloader_mosaic, BaseTrainer.build_optimizer
class ultralytics.engine.trainer.BaseTrainer
BaseTrainer(self, cfg = DEFAULT_CFG, overrides = None, _callbacks = None)
A base class for creating trainers.
This class provides the foundation for training YOLO models, handling the training loop, validation, checkpointing, and various training utilities. It supports both single-GPU and multi-GPU distributed training.
Args
| Name | Type | Description | Default |
|---|---|---|---|
cfg | str, optional | Path to a configuration file. | DEFAULT_CFG |
overrides | dict, optional | Configuration overrides. | None |
_callbacks | list, optional | List of callback functions. | None |
Attributes
| Name | Type | Description |
|---|---|---|
args | SimpleNamespace | Configuration for the trainer. |
validator | BaseValidator | Validator instance. |
model | nn.Module | Model instance. |
callbacks | defaultdict | Dictionary of callbacks. |
save_dir | Path | Directory to save results. |
wdir | Path | Directory to save weights. |
last | Path | Path to the last checkpoint. |
best | Path | Path to the best checkpoint. |
save_period | int | Save checkpoint every x epochs (disabled if < 1). |
batch_size | int | Batch size for training. |
epochs | int | Number of epochs to train for. |
start_epoch | int | Starting epoch for training. |
device | torch.device | Device to use for training. |
amp | bool | Flag to enable AMP (Automatic Mixed Precision). |
scaler | amp.GradScaler | Gradient scaler for AMP. |
data | str | Path to data. |
ema | nn.Module | EMA (Exponential Moving Average) of the model. |
resume | bool | Resume training from a checkpoint. |
lf | nn.Module | Loss function. |
scheduler | torch.optim.lr_scheduler._LRScheduler | Learning rate scheduler. |
best_fitness | float | The best fitness value achieved. |
fitness | float | Current fitness value. |
loss | float | Current loss value. |
tloss | float | Total loss value. |
loss_names | list | List of loss names. |
csv | Path | Path to results CSV file. |
metrics | dict | Dictionary of metrics. |
plots | dict | Dictionary of plots. |
Methods
| Name | Description |
|---|---|
_clear_memory | Clear accelerator memory by calling garbage collector and emptying cache. |
_close_dataloader_mosaic | Update dataloaders to stop using mosaic augmentation. |
_do_train | Train the model with the specified world size. |
_get_memory | Get accelerator memory utilization in GB or as a fraction of total memory. |
_handle_nan_recovery | Detect and recover from NaN/Inf loss and fitness collapse by loading last checkpoint. |
_load_checkpoint_state | Load optimizer, scaler, EMA, and best_fitness from checkpoint. |
_model_train | Set model in training mode. |
_setup_ddp | Initialize and set the DistributedDataParallel parameters for training. |
_setup_scheduler | Initialize training learning rate scheduler. |
_setup_train | Build dataloaders and optimizer on correct rank process. |
add_callback | Append the given callback to the event's callback list. |
auto_batch | Calculate optimal batch size based on model and device memory constraints. |
build_dataset | Build dataset. |
build_optimizer | Construct an optimizer for the given model. |
build_targets | Build target tensors for training YOLO model. |
check_resume | Check if resume checkpoint exists and update arguments accordingly. |
final_eval | Perform final evaluation and validation for object detection YOLO model. |
get_dataloader | Return dataloader derived from torch.utils.data.DataLoader. |
get_dataset | Get train and validation datasets from data dictionary. |
get_model | Get model and raise NotImplementedError for loading cfg files. |
get_validator | Return a NotImplementedError when the get_validator function is called. |
label_loss_items | Return a loss dict with labeled training loss items tensor. |
on_plot | Register plots (e.g. to be consumed in callbacks). |
optimizer_step | Perform a single step of the training optimizer with gradient clipping and EMA update. |
plot_metrics | Plot metrics from a CSV file. |
plot_training_labels | Plot training labels for YOLO model. |
plot_training_samples | Plot training samples during YOLO training. |
preprocess_batch | Allow custom preprocessing model inputs and ground truths depending on task type. |
progress_string | Return a string describing training progress. |
read_results_csv | Read results.csv into a dictionary using polars. |
resume_training | Resume YOLO training from given epoch and best fitness. |
run_callbacks | Run all existing callbacks associated with a particular event. |
save_metrics | Save training metrics to a CSV file. |
save_model | Save model training checkpoints with additional metadata. |
set_callback | Override the existing callbacks with the given callback for the specified event. |
set_model_attributes | Set or update model parameters before training. |
setup_model | Load, create, or download model for any task. |
train | Allow device='', device=None on Multi-GPU systems to default to device=0. |
validate | Run validation on val set using self.validator. |
Examples
Initialize a trainer and start training
>>> trainer = BaseTrainer(cfg="config.yaml")
>>> trainer.train()
Source code in ultralytics/engine/trainer.py
class BaseTrainer:
"""A base class for creating trainers.
This class provides the foundation for training YOLO models, handling the training loop, validation, checkpointing,
and various training utilities. It supports both single-GPU and multi-GPU distributed training.
Attributes:
args (SimpleNamespace): Configuration for the trainer.
validator (BaseValidator): Validator instance.
model (nn.Module): Model instance.
callbacks (defaultdict): Dictionary of callbacks.
save_dir (Path): Directory to save results.
wdir (Path): Directory to save weights.
last (Path): Path to the last checkpoint.
best (Path): Path to the best checkpoint.
save_period (int): Save checkpoint every x epochs (disabled if < 1).
batch_size (int): Batch size for training.
epochs (int): Number of epochs to train for.
start_epoch (int): Starting epoch for training.
device (torch.device): Device to use for training.
amp (bool): Flag to enable AMP (Automatic Mixed Precision).
scaler (amp.GradScaler): Gradient scaler for AMP.
data (str): Path to data.
ema (nn.Module): EMA (Exponential Moving Average) of the model.
resume (bool): Resume training from a checkpoint.
lf (nn.Module): Loss function.
scheduler (torch.optim.lr_scheduler._LRScheduler): Learning rate scheduler.
best_fitness (float): The best fitness value achieved.
fitness (float): Current fitness value.
loss (float): Current loss value.
tloss (float): Total loss value.
loss_names (list): List of loss names.
csv (Path): Path to results CSV file.
metrics (dict): Dictionary of metrics.
plots (dict): Dictionary of plots.
Methods:
train: Execute the training process.
validate: Run validation on the test set.
save_model: Save model training checkpoints.
get_dataset: Get train and validation datasets.
setup_model: Load, create, or download model.
build_optimizer: Construct an optimizer for the model.
Examples:
Initialize a trainer and start training
>>> trainer = BaseTrainer(cfg="config.yaml")
>>> trainer.train()
"""
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""Initialize the BaseTrainer class.
Args:
cfg (str, optional): Path to a configuration file.
overrides (dict, optional): Configuration overrides.
_callbacks (list, optional): List of callback functions.
"""
self.hub_session = overrides.pop("session", None) # HUB
self.args = get_cfg(cfg, overrides)
self.check_resume(overrides)
self.device = select_device(self.args.device)
# Update "-1" devices so post-training val does not repeat search
self.args.device = os.getenv("CUDA_VISIBLE_DEVICES") if "cuda" in str(self.device) else str(self.device)
self.validator = None
self.metrics = None
self.plots = {}
init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)
# Dirs
self.save_dir = get_save_dir(self.args)
self.args.name = self.save_dir.name # update name for loggers
self.wdir = self.save_dir / "weights" # weights dir
if RANK in {-1, 0}:
self.wdir.mkdir(parents=True, exist_ok=True) # make dir
self.args.save_dir = str(self.save_dir)
# Save run args, serializing augmentations as reprs for resume compatibility
args_dict = vars(self.args).copy()
if args_dict.get("augmentations") is not None:
# Serialize Albumentations transforms as their repr strings for checkpoint compatibility
args_dict["augmentations"] = [repr(t) for t in args_dict["augmentations"]]
YAML.save(self.save_dir / "args.yaml", args_dict) # save run args
self.last, self.best = self.wdir / "last.pt", self.wdir / "best.pt" # checkpoint paths
self.save_period = self.args.save_period
self.batch_size = self.args.batch
self.epochs = self.args.epochs or 100 # in case users accidentally pass epochs=None with timed training
self.start_epoch = 0
if RANK == -1:
print_args(vars(self.args))
# Device
if self.device.type in {"cpu", "mps"}:
self.args.workers = 0 # faster CPU training as time dominated by inference, not dataloading
# Model and Dataset
self.model = check_model_file_from_stem(self.args.model) # add suffix, i.e. yolo11n -> yolo11n.pt
with torch_distributed_zero_first(LOCAL_RANK): # avoid auto-downloading dataset multiple times
self.data = self.get_dataset()
self.ema = None
# Optimization utils init
self.lf = None
self.scheduler = None
# Epoch level metrics
self.best_fitness = None
self.fitness = None
self.loss = None
self.tloss = None
self.loss_names = ["Loss"]
self.csv = self.save_dir / "results.csv"
if self.csv.exists() and not self.args.resume:
self.csv.unlink()
self.plot_idx = [0, 1, 2]
self.nan_recovery_attempts = 0
# Callbacks
self.callbacks = _callbacks or callbacks.get_default_callbacks()
if isinstance(self.args.device, str) and len(self.args.device): # i.e. device='0' or device='0,1,2,3'
world_size = len(self.args.device.split(","))
elif isinstance(self.args.device, (tuple, list)): # i.e. device=[0, 1, 2, 3] (multi-GPU from CLI is list)
world_size = len(self.args.device)
elif self.args.device in {"cpu", "mps"}: # i.e. device='cpu' or 'mps'
world_size = 0
elif torch.cuda.is_available(): # i.e. device=None or device='' or device=number
world_size = 1 # default to device 0
else: # i.e. device=None or device=''
world_size = 0
self.ddp = world_size > 1 and "LOCAL_RANK" not in os.environ
self.world_size = world_size
# Run subprocess if DDP training, else train normally
if RANK in {-1, 0} and not self.ddp:
callbacks.add_integration_callbacks(self)
# Start console logging immediately at trainer initialization
self.run_callbacks("on_pretrain_routine_start")
method ultralytics.engine.trainer.BaseTrainer._clear_memory
def _clear_memory(self, threshold: float | None = None)
Clear accelerator memory by calling garbage collector and emptying cache.
Args
| Name | Type | Description | Default |
|---|---|---|---|
threshold | float, optional | Memory utilization fraction (0-1); the cache is only cleared if current utilization exceeds this value. If None, always clear. | None |
Source code in ultralytics/engine/trainer.py
def _clear_memory(self, threshold: float | None = None):
"""Clear accelerator memory by calling garbage collector and emptying cache."""
if threshold:
assert 0 <= threshold <= 1, "Threshold must be between 0 and 1."
if self._get_memory(fraction=True) <= threshold:
return
gc.collect()
if self.device.type == "mps":
torch.mps.empty_cache()
elif self.device.type == "cpu":
return
else:
torch.cuda.empty_cache()
method ultralytics.engine.trainer.BaseTrainer._close_dataloader_mosaic
def _close_dataloader_mosaic(self)
Update dataloaders to stop using mosaic augmentation.
Source code in ultralytics/engine/trainer.py
def _close_dataloader_mosaic(self):
"""Update dataloaders to stop using mosaic augmentation."""
if hasattr(self.train_loader.dataset, "mosaic"):
self.train_loader.dataset.mosaic = False
if hasattr(self.train_loader.dataset, "close_mosaic"):
LOGGER.info("Closing dataloader mosaic")
self.train_loader.dataset.close_mosaic(hyp=copy(self.args))
method ultralytics.engine.trainer.BaseTrainer._do_train
def _do_train(self)
Train the model with the specified world size.
Source code in ultralytics/engine/trainer.py
def _do_train(self):
"""Train the model with the specified world size."""
if self.world_size > 1:
self._setup_ddp()
self._setup_train()
nb = len(self.train_loader) # number of batches
nw = max(round(self.args.warmup_epochs * nb), 100) if self.args.warmup_epochs > 0 else -1 # warmup iterations
last_opt_step = -1
self.epoch_time = None
self.epoch_time_start = time.time()
self.train_time_start = time.time()
self.run_callbacks("on_train_start")
LOGGER.info(
f"Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n"
f"Using {self.train_loader.num_workers * (self.world_size or 1)} dataloader workers\n"
f"Logging results to {colorstr('bold', self.save_dir)}\n"
f"Starting training for " + (f"{self.args.time} hours..." if self.args.time else f"{self.epochs} epochs...")
)
if self.args.close_mosaic:
base_idx = (self.epochs - self.args.close_mosaic) * nb
self.plot_idx.extend([base_idx, base_idx + 1, base_idx + 2])
epoch = self.start_epoch
self.optimizer.zero_grad() # zero any resumed gradients to ensure stability on train start
while True:
self.epoch = epoch
self.run_callbacks("on_train_epoch_start")
with warnings.catch_warnings():
warnings.simplefilter("ignore") # suppress 'Detected lr_scheduler.step() before optimizer.step()'
self.scheduler.step()
self._model_train()
if RANK != -1:
self.train_loader.sampler.set_epoch(epoch)
pbar = enumerate(self.train_loader)
# Update dataloader attributes (optional)
if epoch == (self.epochs - self.args.close_mosaic):
self._close_dataloader_mosaic()
self.train_loader.reset()
if RANK in {-1, 0}:
LOGGER.info(self.progress_string())
pbar = TQDM(enumerate(self.train_loader), total=nb)
self.tloss = None
for i, batch in pbar:
self.run_callbacks("on_train_batch_start")
# Warmup
ni = i + nb * epoch
if ni <= nw:
xi = [0, nw] # x interp
self.accumulate = max(1, int(np.interp(ni, xi, [1, self.args.nbs / self.batch_size]).round()))
for j, x in enumerate(self.optimizer.param_groups):
# Bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
x["lr"] = np.interp(
ni, xi, [self.args.warmup_bias_lr if j == 0 else 0.0, x["initial_lr"] * self.lf(epoch)]
)
if "momentum" in x:
x["momentum"] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum])
# Forward
with autocast(self.amp):
batch = self.preprocess_batch(batch)
if self.args.compile:
# Decouple inference and loss calculations for improved compile performance
preds = self.model(batch["img"])
loss, self.loss_items = unwrap_model(self.model).loss(batch, preds)
else:
loss, self.loss_items = self.model(batch)
self.loss = loss.sum()
if RANK != -1:
self.loss *= self.world_size
self.tloss = self.loss_items if self.tloss is None else (self.tloss * i + self.loss_items) / (i + 1)
# Backward
self.scaler.scale(self.loss).backward()
if ni - last_opt_step >= self.accumulate:
self.optimizer_step()
last_opt_step = ni
# Timed stopping
if self.args.time:
self.stop = (time.time() - self.train_time_start) > (self.args.time * 3600)
if RANK != -1: # if DDP training
broadcast_list = [self.stop if RANK == 0 else None]
dist.broadcast_object_list(broadcast_list, 0) # broadcast 'stop' to all ranks
self.stop = broadcast_list[0]
if self.stop: # training time exceeded
break
# Log
if RANK in {-1, 0}:
loss_length = self.tloss.shape[0] if len(self.tloss.shape) else 1
pbar.set_description(
("%11s" * 2 + "%11.4g" * (2 + loss_length))
% (
f"{epoch + 1}/{self.epochs}",
f"{self._get_memory():.3g}G", # (GB) GPU memory util
*(self.tloss if loss_length > 1 else torch.unsqueeze(self.tloss, 0)), # losses
batch["cls"].shape[0], # batch size, i.e. 8
batch["img"].shape[-1], # imgsz, i.e 640
)
)
self.run_callbacks("on_batch_end")
if self.args.plots and ni in self.plot_idx:
self.plot_training_samples(batch, ni)
self.run_callbacks("on_train_batch_end")
self.lr = {f"lr/pg{ir}": x["lr"] for ir, x in enumerate(self.optimizer.param_groups)} # for loggers
self.run_callbacks("on_train_epoch_end")
if RANK in {-1, 0}:
self.ema.update_attr(self.model, include=["yaml", "nc", "args", "names", "stride", "class_weights"])
# Validation
final_epoch = epoch + 1 >= self.epochs
if self.args.val or final_epoch or self.stopper.possible_stop or self.stop:
self._clear_memory(threshold=0.5) # prevent VRAM spike
self.metrics, self.fitness = self.validate()
# NaN recovery
if self._handle_nan_recovery(epoch):
continue
self.nan_recovery_attempts = 0
if RANK in {-1, 0}:
self.save_metrics(metrics={**self.label_loss_items(self.tloss), **self.metrics, **self.lr})
self.stop |= self.stopper(epoch + 1, self.fitness) or final_epoch
if self.args.time:
self.stop |= (time.time() - self.train_time_start) > (self.args.time * 3600)
# Save model
if self.args.save or final_epoch:
self.save_model()
self.run_callbacks("on_model_save")
# Scheduler
t = time.time()
self.epoch_time = t - self.epoch_time_start
self.epoch_time_start = t
if self.args.time:
mean_epoch_time = (t - self.train_time_start) / (epoch - self.start_epoch + 1)
self.epochs = self.args.epochs = math.ceil(self.args.time * 3600 / mean_epoch_time)
self._setup_scheduler()
self.scheduler.last_epoch = self.epoch # do not move
self.stop |= epoch >= self.epochs # stop if exceeded epochs
self.run_callbacks("on_fit_epoch_end")
self._clear_memory(0.5) # clear if memory utilization > 50%
# Early Stopping
if RANK != -1: # if DDP training
broadcast_list = [self.stop if RANK == 0 else None]
dist.broadcast_object_list(broadcast_list, 0) # broadcast 'stop' to all ranks
self.stop = broadcast_list[0]
if self.stop:
break # must break all DDP ranks
epoch += 1
seconds = time.time() - self.train_time_start
LOGGER.info(f"\n{epoch - self.start_epoch + 1} epochs completed in {seconds / 3600:.3f} hours.")
# Do final val with best.pt
self.final_eval()
if RANK in {-1, 0}:
if self.args.plots:
self.plot_metrics()
self.run_callbacks("on_train_end")
self._clear_memory()
unset_deterministic()
self.run_callbacks("teardown")
method ultralytics.engine.trainer.BaseTrainer._get_memory
def _get_memory(self, fraction = False)
Get accelerator memory utilization in GB or as a fraction of total memory.
Args
| Name | Type | Description | Default |
|---|---|---|---|
fraction | bool, optional | Whether to return memory utilization as a fraction of total memory instead of GB. | False |
Source code in ultralytics/engine/trainer.py
def _get_memory(self, fraction=False):
"""Get accelerator memory utilization in GB or as a fraction of total memory."""
memory, total = 0, 0
if self.device.type == "mps":
memory = torch.mps.driver_allocated_memory()
if fraction:
return __import__("psutil").virtual_memory().percent / 100
elif self.device.type != "cpu":
memory = torch.cuda.memory_reserved()
if fraction:
total = torch.cuda.get_device_properties(self.device).total_memory
return ((memory / total) if total > 0 else 0) if fraction else (memory / 2**30)
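An illustrative sketch of both modes (returned values depend on the device and current workload):
>>> trainer._get_memory()  # reserved accelerator memory in GB, e.g. 3.2
>>> trainer._get_memory(fraction=True)  # as a fraction of total memory, e.g. 0.41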
method ultralytics.engine.trainer.BaseTrainer._handle_nan_recovery
def _handle_nan_recovery(self, epoch)
Detect and recover from NaN/Inf loss and fitness collapse by loading last checkpoint.
Args
| Name | Type | Description | Default |
|---|---|---|---|
epoch | int | Current epoch number, used to check whether recovery from last.pt is possible. | required |
Source code in ultralytics/engine/trainer.py
def _handle_nan_recovery(self, epoch):
"""Detect and recover from NaN/Inf loss and fitness collapse by loading last checkpoint."""
loss_nan = self.loss is not None and not self.loss.isfinite()
fitness_nan = self.fitness is not None and not np.isfinite(self.fitness)
fitness_collapse = self.best_fitness and self.best_fitness > 0 and self.fitness == 0
corrupted = RANK in {-1, 0} and loss_nan and (fitness_nan or fitness_collapse)
reason = "Loss NaN/Inf" if loss_nan else "Fitness NaN/Inf" if fitness_nan else "Fitness collapse"
if RANK != -1: # DDP: broadcast to all ranks
broadcast_list = [corrupted if RANK == 0 else None]
dist.broadcast_object_list(broadcast_list, 0)
corrupted = broadcast_list[0]
if not corrupted:
return False
if epoch == self.start_epoch or not self.last.exists():
LOGGER.warning(f"{reason} detected but can not recover from last.pt...")
return False # Cannot recover on first epoch, let training continue
self.nan_recovery_attempts += 1
if self.nan_recovery_attempts > 3:
raise RuntimeError(f"Training failed: NaN persisted for {self.nan_recovery_attempts} epochs")
LOGGER.warning(f"{reason} detected (attempt {self.nan_recovery_attempts}/3), recovering from last.pt...")
self._model_train() # set model to train mode before loading checkpoint to avoid inference tensor errors
_, ckpt = load_checkpoint(self.last)
ema_state = ckpt["ema"].float().state_dict()
if not all(torch.isfinite(v).all() for v in ema_state.values() if isinstance(v, torch.Tensor)):
raise RuntimeError(f"Checkpoint {self.last} is corrupted with NaN/Inf weights")
unwrap_model(self.model).load_state_dict(ema_state) # Load EMA weights into model
self._load_checkpoint_state(ckpt) # Load optimizer/scaler/EMA/best_fitness
del ckpt, ema_state
self.scheduler.last_epoch = epoch - 1
return True
method ultralytics.engine.trainer.BaseTrainer._load_checkpoint_state
def _load_checkpoint_state(self, ckpt)
Load optimizer, scaler, EMA, and best_fitness from checkpoint.
Args
| Name | Type | Description | Default |
|---|---|---|---|
ckpt | dict | Checkpoint dictionary containing optimizer, scaler, EMA, and best_fitness state. | required |
Source code in ultralytics/engine/trainer.py
def _load_checkpoint_state(self, ckpt):
"""Load optimizer, scaler, EMA, and best_fitness from checkpoint."""
if ckpt.get("optimizer") is not None:
self.optimizer.load_state_dict(ckpt["optimizer"])
if ckpt.get("scaler") is not None:
self.scaler.load_state_dict(ckpt["scaler"])
if self.ema and ckpt.get("ema"):
self.ema = ModelEMA(self.model) # validation with EMA creates inference tensors that can't be updated
self.ema.ema.load_state_dict(ckpt["ema"].float().state_dict())
self.ema.updates = ckpt["updates"]
self.best_fitness = ckpt.get("best_fitness", 0.0)
method ultralytics.engine.trainer.BaseTrainer._model_train
def _model_train(self)
Set model in training mode.
Source code in ultralytics/engine/trainer.py
def _model_train(self):
"""Set model in training mode."""
self.model.train()
# Freeze BN stat
for n, m in self.model.named_modules():
if any(filter(lambda f: f in n, self.freeze_layer_names)) and isinstance(m, nn.BatchNorm2d):
m.eval()
method ultralytics.engine.trainer.BaseTrainer._setup_ddp
def _setup_ddp(self)
Initialize and set the DistributedDataParallel parameters for training.
Source code in ultralytics/engine/trainer.py
def _setup_ddp(self):
"""Initialize and set the DistributedDataParallel parameters for training."""
torch.cuda.set_device(RANK)
self.device = torch.device("cuda", RANK)
os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1" # set to enforce timeout
dist.init_process_group(
backend="nccl" if dist.is_nccl_available() else "gloo",
timeout=timedelta(seconds=10800), # 3 hours
rank=RANK,
world_size=self.world_size,
)
method ultralytics.engine.trainer.BaseTrainer._setup_scheduler
def _setup_scheduler(self)
Initialize training learning rate scheduler.
Source code in ultralytics/engine/trainer.py
def _setup_scheduler(self):
"""Initialize training learning rate scheduler."""
if self.args.cos_lr:
self.lf = one_cycle(1, self.args.lrf, self.epochs) # cosine 1->hyp['lrf']
else:
self.lf = lambda x: max(1 - x / self.epochs, 0) * (1.0 - self.args.lrf) + self.args.lrf # linear
self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
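To see the linear schedule shape, a standalone sketch of the lambda with assumed values lrf=0.01 and epochs=100:
>>> lrf, epochs = 0.01, 100
>>> lf = lambda x: max(1 - x / epochs, 0) * (1.0 - lrf) + lrf
>>> lf(0), lf(50), lf(100)  # ≈ (1.0, 0.505, 0.01): decays linearly from 1.0 to lrf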
method ultralytics.engine.trainer.BaseTrainer._setup_train
def _setup_train(self)
Build dataloaders and optimizer on correct rank process.
Source code in ultralytics/engine/trainer.py
def _setup_train(self):
"""Build dataloaders and optimizer on correct rank process."""
ckpt = self.setup_model()
self.model = self.model.to(self.device)
self.set_model_attributes()
# Compile model
self.model = attempt_compile(self.model, device=self.device, mode=self.args.compile)
# Freeze layers
freeze_list = (
self.args.freeze
if isinstance(self.args.freeze, list)
else range(self.args.freeze)
if isinstance(self.args.freeze, int)
else []
)
always_freeze_names = [".dfl"] # always freeze these layers
freeze_layer_names = [f"model.{x}." for x in freeze_list] + always_freeze_names
self.freeze_layer_names = freeze_layer_names
for k, v in self.model.named_parameters():
# v.register_hook(lambda x: torch.nan_to_num(x)) # NaN to 0 (commented for erratic training results)
if any(x in k for x in freeze_layer_names):
LOGGER.info(f"Freezing layer '{k}'")
v.requires_grad = False
elif not v.requires_grad and v.dtype.is_floating_point: # only floating point Tensor can require gradients
LOGGER.warning(
f"setting 'requires_grad=True' for frozen layer '{k}'. "
"See ultralytics.engine.trainer for customization of frozen layers."
)
v.requires_grad = True
# Check AMP
self.amp = torch.tensor(self.args.amp).to(self.device) # True or False
if self.amp and RANK in {-1, 0}: # Single-GPU and DDP
callbacks_backup = callbacks.default_callbacks.copy() # backup callbacks as check_amp() resets them
self.amp = torch.tensor(check_amp(self.model), device=self.device)
callbacks.default_callbacks = callbacks_backup # restore callbacks
if RANK > -1 and self.world_size > 1: # DDP
dist.broadcast(self.amp.int(), src=0) # broadcast from rank 0 to all other ranks; gloo errors with boolean
self.amp = bool(self.amp) # as boolean
self.scaler = (
torch.amp.GradScaler("cuda", enabled=self.amp) if TORCH_2_4 else torch.cuda.amp.GradScaler(enabled=self.amp)
)
if self.world_size > 1:
self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[RANK], find_unused_parameters=True)
# Check imgsz
gs = max(int(self.model.stride.max() if hasattr(self.model, "stride") else 32), 32) # grid size (max stride)
self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs, max_dim=1)
self.stride = gs # for multiscale training
# Batch size
if self.batch_size < 1 and RANK == -1: # single-GPU only, estimate best batch size
self.args.batch = self.batch_size = self.auto_batch()
# Dataloaders
batch_size = self.batch_size // max(self.world_size, 1)
self.train_loader = self.get_dataloader(
self.data["train"], batch_size=batch_size, rank=LOCAL_RANK, mode="train"
)
# Note: When training DOTA dataset, double batch size could get OOM on images with >2000 objects.
self.test_loader = self.get_dataloader(
self.data.get("val") or self.data.get("test"),
batch_size=batch_size if self.args.task == "obb" else batch_size * 2,
rank=LOCAL_RANK,
mode="val",
)
self.validator = self.get_validator()
self.ema = ModelEMA(self.model)
if RANK in {-1, 0}:
metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix="val")
self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))
if self.args.plots:
self.plot_training_labels()
# Optimizer
self.accumulate = max(round(self.args.nbs / self.batch_size), 1) # accumulate loss before optimizing
weight_decay = self.args.weight_decay * self.batch_size * self.accumulate / self.args.nbs # scale weight_decay
iterations = math.ceil(len(self.train_loader.dataset) / max(self.batch_size, self.args.nbs)) * self.epochs
self.optimizer = self.build_optimizer(
model=self.model,
name=self.args.optimizer,
lr=self.args.lr0,
momentum=self.args.momentum,
decay=weight_decay,
iterations=iterations,
)
# Scheduler
self._setup_scheduler()
self.stopper, self.stop = EarlyStopping(patience=self.args.patience), False
self.resume_training(ckpt)
self.scheduler.last_epoch = self.start_epoch - 1 # do not move
self.run_callbacks("on_pretrain_routine_end")
method ultralytics.engine.trainer.BaseTrainer.add_callback
def add_callback(self, event: str, callback)
Append the given callback to the event's callback list.
Args
| Name | Type | Description | Default |
|---|---|---|---|
event | str | Name of the event to register the callback for. | required |
callback | callable | Callback function to append to the event's list. | required |
Source code in ultralytics/engine/trainer.py
def add_callback(self, event: str, callback):
"""Append the given callback to the event's callback list."""
self.callbacks[event].append(callback)
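A sketch registering a custom callback for on_train_start, an event emitted at the top of _do_train (log_start is a hypothetical user function):
>>> def log_start(trainer):
...     print(f"Training started, saving to {trainer.save_dir}")
>>> trainer.add_callback("on_train_start", log_start)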
method ultralytics.engine.trainer.BaseTrainer.auto_batch
def auto_batch(self, max_num_obj = 0)
Calculate optimal batch size based on model and device memory constraints.
Args
| Name | Type | Description | Default |
|---|---|---|---|
max_num_obj | int, optional | Maximum number of objects per image, used when estimating batch size. | 0 |
Source code in ultralytics/engine/trainer.py
def auto_batch(self, max_num_obj=0):
"""Calculate optimal batch size based on model and device memory constraints."""
return check_train_batch_size(
model=self.model,
imgsz=self.args.imgsz,
amp=self.amp,
batch=self.batch_size,
max_num_obj=max_num_obj,
) # returns batch size
method ultralytics.engine.trainer.BaseTrainer.build_dataset
def build_dataset(self, img_path, mode = "train", batch = None)
Build dataset.
Args
| Name | Type | Description | Default |
|---|---|---|---|
img_path | str | Path to the folder containing images. | required |
mode | str, optional | Dataset mode, 'train' or 'val'. | "train" |
batch | int, optional | Batch size. | None |
Source code in ultralytics/engine/trainer.py
def build_dataset(self, img_path, mode="train", batch=None):
"""Build dataset."""
raise NotImplementedError("build_dataset function not implemented in trainer")
method ultralytics.engine.trainer.BaseTrainer.build_optimizer
def build_optimizer(self, model, name = "auto", lr = 0.001, momentum = 0.9, decay = 1e-5, iterations = 1e5)
Construct an optimizer for the given model.
Args
| Name | Type | Description | Default |
|---|---|---|---|
model | torch.nn.Module | The model for which to build an optimizer. | required |
name | str, optional | The name of the optimizer to use. If 'auto', the optimizer is selected based on the number of iterations. | "auto" |
lr | float, optional | The learning rate for the optimizer. | 0.001 |
momentum | float, optional | The momentum factor for the optimizer. | 0.9 |
decay | float, optional | The weight decay for the optimizer. | 1e-5 |
iterations | float, optional | The number of iterations, which determines the optimizer if name is 'auto'. | 1e5 |
Returns
| Type | Description |
|---|---|
torch.optim.Optimizer | The constructed optimizer. |
Source code in ultralytics/engine/trainer.py
def build_optimizer(self, model, name="auto", lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
"""Construct an optimizer for the given model.
Args:
model (torch.nn.Module): The model for which to build an optimizer.
name (str, optional): The name of the optimizer to use. If 'auto', the optimizer is selected based on the
number of iterations.
lr (float, optional): The learning rate for the optimizer.
momentum (float, optional): The momentum factor for the optimizer.
decay (float, optional): The weight decay for the optimizer.
iterations (float, optional): The number of iterations, which determines the optimizer if name is 'auto'.
Returns:
(torch.optim.Optimizer): The constructed optimizer.
"""
g = [], [], [] # optimizer parameter groups
bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k) # normalization layers, i.e. BatchNorm2d()
if name == "auto":
LOGGER.info(
f"{colorstr('optimizer:')} 'optimizer=auto' found, "
f"ignoring 'lr0={self.args.lr0}' and 'momentum={self.args.momentum}' and "
f"determining best 'optimizer', 'lr0' and 'momentum' automatically... "
)
nc = self.data.get("nc", 10) # number of classes
lr_fit = round(0.002 * 5 / (4 + nc), 6) # lr0 fit equation to 6 decimal places
name, lr, momentum = ("SGD", 0.01, 0.9) if iterations > 10000 else ("AdamW", lr_fit, 0.9)
self.args.warmup_bias_lr = 0.0 # no higher than 0.01 for Adam
for module_name, module in model.named_modules():
for param_name, param in module.named_parameters(recurse=False):
fullname = f"{module_name}.{param_name}" if module_name else param_name
if "bias" in fullname: # bias (no decay)
g[2].append(param)
elif isinstance(module, bn) or "logit_scale" in fullname: # weight (no decay)
# ContrastiveHead and BNContrastiveHead included here with 'logit_scale'
g[1].append(param)
else: # weight (with decay)
g[0].append(param)
optimizers = {"Adam", "Adamax", "AdamW", "NAdam", "RAdam", "RMSProp", "SGD", "auto"}
name = {x.lower(): x for x in optimizers}.get(name.lower())
if name in {"Adam", "Adamax", "AdamW", "NAdam", "RAdam"}:
optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
elif name == "RMSProp":
optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum)
elif name == "SGD":
optimizer = optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
else:
raise NotImplementedError(
f"Optimizer '{name}' not found in list of available optimizers {optimizers}. "
"Request support for addition optimizers at https://github.com/ultralytics/ultralytics."
)
optimizer.add_param_group({"params": g[0], "weight_decay": decay}) # add g0 with weight_decay
optimizer.add_param_group({"params": g[1], "weight_decay": 0.0}) # add g1 (BatchNorm2d weights)
LOGGER.info(
f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}, momentum={momentum}) with parameter groups "
f"{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)"
)
return optimizer
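A usage sketch with explicit, illustrative hyperparameters; the three resulting parameter groups are decayed weights, norm-layer weights, and biases:
>>> optimizer = trainer.build_optimizer(trainer.model, name="SGD", lr=0.01, momentum=0.937, decay=0.0005, iterations=10000)
>>> len(optimizer.param_groups)
3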
method ultralytics.engine.trainer.BaseTrainer.build_targets
def build_targets(self, preds, targets)
Build target tensors for training YOLO model.
Args
| Name | Type | Description | Default |
|---|---|---|---|
preds | torch.Tensor | Model predictions. | required |
targets | torch.Tensor | Ground-truth targets. | required |
Source code in ultralytics/engine/trainer.py
def build_targets(self, preds, targets):
"""Build target tensors for training YOLO model."""
pass
method ultralytics.engine.trainer.BaseTrainer.check_resume
def check_resume(self, overrides)
Check if resume checkpoint exists and update arguments accordingly.
Args
| Name | Type | Description | Default |
|---|---|---|---|
overrides | dict | Configuration overrides passed to the trainer. | required |
Source code in ultralytics/engine/trainer.py
def check_resume(self, overrides):
"""Check if resume checkpoint exists and update arguments accordingly."""
resume = self.args.resume
if resume:
try:
exists = isinstance(resume, (str, Path)) and Path(resume).exists()
last = Path(check_file(resume) if exists else get_latest_run())
# Check that resume data YAML exists, otherwise strip to force re-download of dataset
ckpt_args = load_checkpoint(last)[0].args
if not isinstance(ckpt_args["data"], dict) and not Path(ckpt_args["data"]).exists():
ckpt_args["data"] = self.args.data
resume = True
self.args = get_cfg(ckpt_args)
self.args.model = self.args.resume = str(last) # reinstate model
for k in (
"imgsz",
"batch",
"device",
"close_mosaic",
"augmentations",
): # allow arg updates to reduce memory or update device on resume
if k in overrides:
setattr(self.args, k, overrides[k])
# Handle augmentations parameter for resume: check if user provided custom augmentations
if ckpt_args.get("augmentations") is not None:
# Augmentations were saved in checkpoint as reprs but can't be restored automatically
LOGGER.warning(
"Custom Albumentations transforms were used in the original training run but are not "
"being restored. To preserve custom augmentations when resuming, you need to pass the "
"'augmentations' parameter again to get expected results. Example: \n"
f"model.train(resume=True, augmentations={ckpt_args['augmentations']})"
)
except Exception as e:
raise FileNotFoundError(
"Resume checkpoint not found. Please pass a valid checkpoint to resume from, "
"i.e. 'yolo train resume model=path/to/last.pt'"
) from e
self.resume = resume
method ultralytics.engine.trainer.BaseTrainer.final_eval
def final_eval(self)
Perform final evaluation and validation for object detection YOLO model.
Source code in ultralytics/engine/trainer.py
def final_eval(self):
"""Perform final evaluation and validation for object detection YOLO model."""
model = self.best if self.best.exists() else None
with torch_distributed_zero_first(LOCAL_RANK): # strip only on GPU 0; other GPUs should wait
if RANK in {-1, 0}:
ckpt = strip_optimizer(self.last) if self.last.exists() else {}
if model:
# update best.pt train_metrics from last.pt
strip_optimizer(self.best, updates={"train_results": ckpt.get("train_results")})
if model:
LOGGER.info(f"\nValidating {model}...")
self.validator.args.plots = self.args.plots
self.validator.args.compile = False # disable final val compile as too slow
self.metrics = self.validator(model=model)
self.metrics.pop("fitness", None)
self.run_callbacks("on_fit_epoch_end")
method ultralytics.engine.trainer.BaseTrainer.get_dataloader
def get_dataloader(self, dataset_path, batch_size = 16, rank = 0, mode = "train")
Return dataloader derived from torch.utils.data.DataLoader.
Args
| Name | Type | Description | Default |
|---|---|---|---|
dataset_path | str | Path to the dataset. | required |
batch_size | int, optional | Number of images per batch. | 16 |
rank | int, optional | Process rank in distributed training. | 0 |
mode | str, optional | Dataloader mode, 'train' or 'val'. | "train" |
Source code in ultralytics/engine/trainer.py
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
"""Return dataloader derived from torch.data.Dataloader."""
raise NotImplementedError("get_dataloader function not implemented in trainer")
method ultralytics.engine.trainer.BaseTrainer.get_dataset
def get_dataset(self)
Get train and validation datasets from data dictionary.
Returns
| Type | Description |
|---|---|
dict | A dictionary containing the training/validation/test dataset and category names. |
Source code in ultralytics/engine/trainer.py
def get_dataset(self):
"""Get train and validation datasets from data dictionary.
Returns:
(dict): A dictionary containing the training/validation/test dataset and category names.
"""
try:
if self.args.task == "classify":
data = check_cls_dataset(self.args.data)
elif str(self.args.data).rsplit(".", 1)[-1] == "ndjson":
# Convert NDJSON to YOLO format
import asyncio
from ultralytics.data.converter import convert_ndjson_to_yolo
yaml_path = asyncio.run(convert_ndjson_to_yolo(self.args.data))
self.args.data = str(yaml_path)
data = check_det_dataset(self.args.data)
elif str(self.args.data).rsplit(".", 1)[-1] in {"yaml", "yml"} or self.args.task in {
"detect",
"segment",
"pose",
"obb",
}:
data = check_det_dataset(self.args.data)
if "yaml_file" in data:
self.args.data = data["yaml_file"] # for validating 'yolo train data=url.zip' usage
except Exception as e:
raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e
if self.args.single_cls:
LOGGER.info("Overriding class names with single class.")
data["names"] = {0: "item"}
data["nc"] = 1
return data
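A sketch of the returned dictionary for a detection dataset (keys and values depend on the dataset YAML):
>>> data = trainer.get_dataset()
>>> data["nc"], data["names"][0]  # e.g. (80, 'person') for COCO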
method ultralytics.engine.trainer.BaseTrainer.get_model
def get_model(self, cfg = None, weights = None, verbose = True)
Get model and raise NotImplementedError for loading cfg files.
Args
| Name | Type | Description | Default |
|---|---|---|---|
cfg | str, optional | Model configuration file. | None |
weights | str, optional | Model weights to load. | None |
verbose | bool, optional | Whether to print model information. | True |
Source code in ultralytics/engine/trainer.py
def get_model(self, cfg=None, weights=None, verbose=True):
"""Get model and raise NotImplementedError for loading cfg files."""
raise NotImplementedError("This task trainer doesn't support loading cfg files")
method ultralytics.engine.trainer.BaseTrainer.get_validator
def get_validator(self)
Return a NotImplementedError when the get_validator function is called.
Source code in ultralytics/engine/trainer.py
def get_validator(self):
"""Return a NotImplementedError when the get_validator function is called."""
raise NotImplementedError("get_validator function not implemented in trainer")
method ultralytics.engine.trainer.BaseTrainer.label_loss_items
def label_loss_items(self, loss_items = None, prefix = "train")
Return a loss dict with labeled training loss items tensor.
Args
| Name | Type | Description | Default |
|---|---|---|---|
loss_items | torch.Tensor, optional | Loss values to label. | None |
prefix | str, optional | Prefix to prepend to loss keys. | "train" |
Notes
This is not needed for classification but necessary for segmentation & detection
Source code in ultralytics/engine/trainer.py
def label_loss_items(self, loss_items=None, prefix="train"):
"""Return a loss dict with labeled training loss items tensor.
Notes:
This is not needed for classification but necessary for segmentation & detection
"""
return {"loss": loss_items} if loss_items is not None else ["loss"]
method ultralytics.engine.trainer.BaseTrainer.on_plot
def on_plot(self, name, data = None)
Register plots (e.g. to be consumed in callbacks).
Args
| Name | Type | Description | Default |
|---|---|---|---|
name | str or Path | Name or path of the plot. | required |
data | any, optional | Data associated with the plot. | None |
Source code in ultralytics/engine/trainer.py
def on_plot(self, name, data=None):
"""Register plots (e.g. to be consumed in callbacks)."""
path = Path(name)
self.plots[path] = {"data": data, "timestamp": time.time()}
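A sketch of registering a plot so logger callbacks can pick it up ('confusion_matrix.png' is an illustrative name):
>>> from pathlib import Path
>>> trainer.on_plot("confusion_matrix.png")
>>> trainer.plots[Path("confusion_matrix.png")]["timestamp"]  # registration time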
method ultralytics.engine.trainer.BaseTrainer.optimizer_step
def optimizer_step(self)
Perform a single step of the training optimizer with gradient clipping and EMA update.
Source code in ultralytics/engine/trainer.py
def optimizer_step(self):
"""Perform a single step of the training optimizer with gradient clipping and EMA update."""
self.scaler.unscale_(self.optimizer) # unscale gradients
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0)
self.scaler.step(self.optimizer)
self.scaler.update()
self.optimizer.zero_grad()
if self.ema:
self.ema.update(self.model)
method ultralytics.engine.trainer.BaseTrainer.plot_metrics
def plot_metrics(self)
Plot metrics from a CSV file.
Source code in ultralytics/engine/trainer.py
def plot_metrics(self):
"""Plot metrics from a CSV file."""
plot_results(file=self.csv, on_plot=self.on_plot) # save results.png
method ultralytics.engine.trainer.BaseTrainer.plot_training_labels
def plot_training_labels(self)
Plot training labels for YOLO model.
Source code in ultralytics/engine/trainer.py
def plot_training_labels(self):
"""Plot training labels for YOLO model."""
pass
method ultralytics.engine.trainer.BaseTrainer.plot_training_samples
def plot_training_samples(self, batch, ni)
Plot training samples during YOLO training.
Args
| Name | Type | Description | Default |
|---|---|---|---|
batch | dict | Batch of training data. | required |
ni | int | Number of iterations (global batch counter). | required |
Source code in ultralytics/engine/trainer.py
def plot_training_samples(self, batch, ni):
"""Plot training samples during YOLO training."""
pass
method ultralytics.engine.trainer.BaseTrainer.preprocess_batch
def preprocess_batch(self, batch)
Allow custom preprocessing model inputs and ground truths depending on task type.
Args
| Name | Type | Description | Default |
|---|---|---|---|
batch | dict | Batch of model inputs and ground truths to preprocess. | required |
Source code in ultralytics/engine/trainer.py
def preprocess_batch(self, batch):
"""Allow custom preprocessing model inputs and ground truths depending on task type."""
return batch
method ultralytics.engine.trainer.BaseTrainer.progress_string
def progress_string(self)
Return a string describing training progress.
Source code in ultralytics/engine/trainer.py
def progress_string(self):
"""Return a string describing training progress."""
return ""
method ultralytics.engine.trainer.BaseTrainer.read_results_csv
def read_results_csv(self)
Read results.csv into a dictionary using polars.
Source code in ultralytics/engine/trainer.py
def read_results_csv(self):
"""Read results.csv into a dictionary using polars."""
import polars as pl # scope for faster 'import ultralytics'
try:
return pl.read_csv(self.csv, infer_schema_length=None).to_dict(as_series=False)
except Exception:
return {}
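A sketch of reading results back; column names vary by task, and an empty dict is returned if the CSV is missing or unreadable:
>>> results = trainer.read_results_csv()
>>> list(results)[:2]  # e.g. ['epoch', 'time'], the first columns written by save_metrics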
method ultralytics.engine.trainer.BaseTrainer.resume_training
def resume_training(self, ckpt)
Resume YOLO training from given epoch and best fitness.
Args
| Name | Type | Description | Default |
|---|---|---|---|
ckpt | dict | Checkpoint dictionary to resume training from. | required |
Source code in ultralytics/engine/trainer.py
def resume_training(self, ckpt):
"""Resume YOLO training from given epoch and best fitness."""
if ckpt is None or not self.resume:
return
start_epoch = ckpt.get("epoch", -1) + 1
assert start_epoch > 0, (
f"{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n"
f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'"
)
LOGGER.info(f"Resuming training {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs")
if self.epochs < start_epoch:
LOGGER.info(
f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs."
)
self.epochs += ckpt["epoch"] # finetune additional epochs
self._load_checkpoint_state(ckpt)
self.start_epoch = start_epoch
if start_epoch > (self.epochs - self.args.close_mosaic):
self._close_dataloader_mosaic()
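Resuming is normally driven through the model API rather than by calling this method directly; a typical sketch (the checkpoint path is illustrative):
>>> from ultralytics import YOLO
>>> model = YOLO("runs/detect/train/weights/last.pt")
>>> model.train(resume=True)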
method ultralytics.engine.trainer.BaseTrainer.run_callbacks
def run_callbacks(self, event: str)
Run all existing callbacks associated with a particular event.
Args
| Name | Type | Description | Default |
|---|---|---|---|
event | str | Name of the event whose callbacks are run. | required |
Source code in ultralytics/engine/trainer.py
def run_callbacks(self, event: str):
"""Run all existing callbacks associated with a particular event."""
for callback in self.callbacks.get(event, []):
callback(self)
method ultralytics.engine.trainer.BaseTrainer.save_metrics
def save_metrics(self, metrics)
Save training metrics to a CSV file.
Args
| Name | Type | Description | Default |
|---|---|---|---|
metrics | dict | Dictionary of metrics to save. | required |
Source code in ultralytics/engine/trainer.py
def save_metrics(self, metrics):
"""Save training metrics to a CSV file."""
keys, vals = list(metrics.keys()), list(metrics.values())
n = len(metrics) + 2 # number of cols
t = time.time() - self.train_time_start
self.csv.parent.mkdir(parents=True, exist_ok=True) # ensure parent directory exists
s = "" if self.csv.exists() else ("%s," * n % ("epoch", "time", *keys)).rstrip(",") + "\n"
with open(self.csv, "a", encoding="utf-8") as f:
f.write(s + ("%.6g," * n % (self.epoch + 1, t, *vals)).rstrip(",") + "\n")
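The first call writes a header row ('epoch' and 'time' plus the metric keys); later calls append one row per epoch. A sketch with illustrative keys:
>>> trainer.save_metrics({"train/box_loss": 1.234, "metrics/mAP50(B)": 0.456})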
method ultralytics.engine.trainer.BaseTrainer.save_model
def save_model(self)
Save model training checkpoints with additional metadata.
Source code in ultralytics/engine/trainer.py
def save_model(self):
"""Save model training checkpoints with additional metadata."""
import io
# Serialize ckpt to a byte buffer once (faster than repeated torch.save() calls)
buffer = io.BytesIO()
torch.save(
{
"epoch": self.epoch,
"best_fitness": self.best_fitness,
"model": None, # resume and final checkpoints derive from EMA
"ema": deepcopy(unwrap_model(self.ema.ema)).half(),
"updates": self.ema.updates,
"optimizer": convert_optimizer_state_dict_to_fp16(deepcopy(self.optimizer.state_dict())),
"scaler": self.scaler.state_dict(),
"train_args": vars(self.args), # save as dict
"train_metrics": {**self.metrics, **{"fitness": self.fitness}},
"train_results": self.read_results_csv(),
"date": datetime.now().isoformat(),
"version": __version__,
"git": {
"root": str(GIT.root),
"branch": GIT.branch,
"commit": GIT.commit,
"origin": GIT.origin,
},
"license": "AGPL-3.0 (https://ultralytics.com/license)",
"docs": "https://docs.ultralytics.com",
},
buffer,
)
serialized_ckpt = buffer.getvalue() # get the serialized content to save
# Save checkpoints
self.wdir.mkdir(parents=True, exist_ok=True) # ensure weights directory exists
self.last.write_bytes(serialized_ckpt) # save last.pt
if self.best_fitness == self.fitness:
self.best.write_bytes(serialized_ckpt) # save best.pt
if (self.save_period > 0) and (self.epoch % self.save_period == 0):
(self.wdir / f"epoch{self.epoch}.pt").write_bytes(serialized_ckpt) # save epoch, i.e. 'epoch3.pt'
method ultralytics.engine.trainer.BaseTrainer.set_callback
def set_callback(self, event: str, callback)
Override the existing callbacks with the given callback for the specified event.
Args
| Name | Type | Description | Default |
|---|---|---|---|
event | str | required | |
callback | required |
Source code in ultralytics/engine/trainer.py
def set_callback(self, event: str, callback):
"""Override the existing callbacks with the given callback for the specified event."""
self.callbacks[event] = [callback]
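Unlike add_callback, this replaces every handler registered for the event; a sketch (my_epoch_hook is a hypothetical user function):
>>> trainer.set_callback("on_train_epoch_end", my_epoch_hook)  # now the only handler for this event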
method ultralytics.engine.trainer.BaseTrainer.set_model_attributes
def set_model_attributes(self)
Set or update model parameters before training.
Source code in ultralytics/engine/trainer.py
def set_model_attributes(self):
"""Set or update model parameters before training."""
self.model.names = self.data["names"]
method ultralytics.engine.trainer.BaseTrainer.setup_model
def setup_model(self)
Load, create, or download model for any task.
Returns
| Type | Description |
|---|---|
dict | Optional checkpoint to resume training from. |
Source code in ultralytics/engine/trainer.py
def setup_model(self):
"""Load, create, or download model for any task.
Returns:
(dict): Optional checkpoint to resume training from.
"""
if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
return
cfg, weights = self.model, None
ckpt = None
if str(self.model).endswith(".pt"):
weights, ckpt = load_checkpoint(self.model)
cfg = weights.yaml
elif isinstance(self.args.pretrained, (str, Path)):
weights, _ = load_checkpoint(self.args.pretrained)
self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1) # calls Model(cfg, weights)
return ckpt
method ultralytics.engine.trainer.BaseTrainer.train
def train(self)
Allow device='', device=None on Multi-GPU systems to default to device=0.
Source code in ultralytics/engine/trainer.py
def train(self):
"""Allow device='', device=None on Multi-GPU systems to default to device=0."""
# Run subprocess if DDP training, else train normally
if self.ddp:
# Argument checks
if self.args.rect:
LOGGER.warning("'rect=True' is incompatible with Multi-GPU training, setting 'rect=False'")
self.args.rect = False
if self.args.batch < 1.0:
raise ValueError(
"AutoBatch with batch<1 not supported for Multi-GPU training, "
f"please specify a valid batch size multiple of GPU count {self.world_size}, i.e. batch={self.world_size * 8}."
)
# Command
cmd, file = generate_ddp_command(self)
try:
LOGGER.info(f"{colorstr('DDP:')} debug command {' '.join(cmd)}")
subprocess.run(cmd, check=True)
except Exception as e:
raise e
finally:
ddp_cleanup(self, str(file))
else:
self._do_train()
method ultralytics.engine.trainer.BaseTrainer.validate
def validate(self)
Run validation on val set using self.validator.
Returns
| Type | Description |
|---|---|
metrics (dict) | Dictionary of validation metrics. |
fitness (float) | Fitness score for the validation. |
Source code in ultralytics/engine/trainer.py
def validate(self):
"""Run validation on val set using self.validator.
Returns:
metrics (dict): Dictionary of validation metrics.
fitness (float): Fitness score for the validation.
"""
if self.ema and self.world_size > 1:
# Sync EMA buffers from rank 0 to all ranks
for buffer in self.ema.ema.buffers():
dist.broadcast(buffer, src=0)
metrics = self.validator(self)
if metrics is None:
return None, None
fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy()) # use loss as fitness measure if not found
if not self.best_fitness or self.best_fitness < fitness:
self.best_fitness = fitness
return metrics, fitness