Reference for ultralytics/engine/trainer.py
Note
This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/engine/trainer.py. If you spot a problem, please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!
ultralytics.engine.trainer.BaseTrainer
A base class for creating trainers.
Attributes:
Name | Type | Description |
---|---|---|
args | SimpleNamespace | Configuration for the trainer. |
validator | BaseValidator | Validator instance. |
model | Module | Model instance. |
callbacks | defaultdict | Dictionary of callbacks. |
save_dir | Path | Directory to save results. |
wdir | Path | Directory to save weights. |
last | Path | Path to the last checkpoint. |
best | Path | Path to the best checkpoint. |
save_period | int | Save checkpoint every x epochs (disabled if < 1). |
batch_size | int | Batch size for training. |
epochs | int | Number of epochs to train for. |
start_epoch | int | Starting epoch for training. |
device | device | Device to use for training. |
amp | bool | Flag to enable AMP (Automatic Mixed Precision). |
scaler | GradScaler | Gradient scaler for AMP. |
data | str | Path to data. |
trainset | Dataset | Training dataset. |
testset | Dataset | Testing dataset. |
ema | Module | EMA (Exponential Moving Average) of the model. |
resume | bool | Resume training from a checkpoint. |
lf | Callable | Learning rate lambda function used by the scheduler. |
scheduler | _LRScheduler | Learning rate scheduler. |
best_fitness | float | The best fitness value achieved. |
fitness | float | Current fitness value. |
loss | float | Current loss value. |
tloss | float | Total loss value. |
loss_names | list | List of loss names. |
csv | Path | Path to results CSV file. |
Parameters:
Name | Type | Description | Default |
---|---|---|---|
cfg | str | Path to a configuration file. Defaults to DEFAULT_CFG. | DEFAULT_CFG |
overrides | dict | Configuration overrides. Defaults to None. | None |
Source code in ultralytics/engine/trainer.py
add_callback
auto_batch
Get the batch size by calculating the model's memory occupation.
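The memory-based estimate can be illustrated by profiling memory at a few batch sizes and extrapolating. The sketch below fits a straight line in plain Python and solves for the batch size that would use a target fraction of total memory. The helper names `fit_line` and `estimate_batch_size`, the 0.6 target fraction, and the memory figures are all hypothetical; real profiling would measure GPU memory through PyTorch.

```python
def fit_line(xs, ys):
    """Ordinary least-squares fit of y = slope * x + intercept."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    var_x = sum((x - mean_x) ** 2 for x in xs)
    cov_xy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    slope = cov_xy / var_x
    return slope, mean_y - slope * mean_x


def estimate_batch_size(batch_sizes, mem_used_gb, total_gb, fraction=0.6):
    """Extrapolate the batch size whose predicted memory use equals fraction * total_gb.

    Hypothetical helper: mem_used_gb[i] is the memory measured at batch_sizes[i].
    """
    slope, intercept = fit_line(batch_sizes, mem_used_gb)
    if slope <= 0:
        raise ValueError("memory must grow with batch size to extrapolate")
    return (total_gb * fraction - intercept) / slope


# Hypothetical profile: 0.5 GB fixed overhead plus 0.25 GB per image.
sizes = [1, 2, 4, 8, 16]
mem = [0.5 + 0.25 * b for b in sizes]
raw = estimate_batch_size(sizes, mem, total_gb=10.0)
batch = max(1, int(raw))  # truncate to a whole batch, never below 1
```

A linear fit is the simplest model of memory growth; if memory does not grow linearly with batch size, the extrapolation becomes less reliable.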
build_dataset
build_optimizer
Constructs an optimizer for the given model, based on the specified optimizer name, learning rate, momentum, weight decay, and number of iterations.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model | Module | The model for which to build an optimizer. | required |
name | str | The name of the optimizer to use. If 'auto', the optimizer is selected based on the number of iterations. Default: 'auto'. | 'auto' |
lr | float | The learning rate for the optimizer. Default: 0.001. | 0.001 |
momentum | float | The momentum factor for the optimizer. Default: 0.9. | 0.9 |
decay | float | The weight decay for the optimizer. Default: 1e-5. | 1e-05 |
iterations | float | The number of iterations, which determines the optimizer if name is 'auto'. Default: 1e5. | 100000.0 |
Returns:
Type | Description |
---|---|
Optimizer | The constructed optimizer. |
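The `name='auto'` rule can be sketched in plain Python. This sketch assumes the heuristic found in recent upstream releases: SGD with lr 0.01 when `iterations` exceeds 10,000, otherwise AdamW with a learning rate fitted to the class count `nc`. The helper name `choose_optimizer`, the `nc` argument, and the constants are assumptions to check against the source. The real method also builds PyTorch parameter groups, which this sketch omits.

```python
def choose_optimizer(name="auto", lr=0.001, momentum=0.9, iterations=1e5, nc=10):
    """Resolve an optimizer request to (name, lr, momentum).

    Hypothetical helper; the thresholds mirror an assumed upstream heuristic.
    """
    if name != "auto":
        return name, lr, momentum  # explicit choices pass through unchanged
    # Fit lr0 to the class count: fewer classes give a larger rate.
    lr_fit = round(0.002 * 5 / (4 + nc), 6)
    if iterations > 10000:
        return "SGD", 0.01, 0.9  # long schedules: SGD with momentum
    return "AdamW", lr_fit, 0.9  # short schedules: adaptive optimizer
```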
build_targets
check_resume
Check if resume checkpoint exists and update arguments accordingly.
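One way to locate a resume checkpoint is to take the most recently modified `last*.pt` file under a runs directory when `resume` is `True`, or to use the given path otherwise. The sketch below does this with `pathlib`. The function names and the search pattern are assumptions for illustration, not the upstream helpers.

```python
from pathlib import Path


def find_latest_checkpoint(search_dir="."):
    """Return the most recently modified last*.pt under search_dir, or None.

    Hypothetical helper for illustration.
    """
    candidates = list(Path(search_dir).rglob("last*.pt"))
    if not candidates:
        return None
    return max(candidates, key=lambda p: p.stat().st_mtime)


def resolve_resume(resume, search_dir="."):
    """Map a resume argument (True or a path) to an existing checkpoint path."""
    ckpt = find_latest_checkpoint(search_dir) if resume is True else Path(resume)
    if ckpt is None or not ckpt.exists():
        raise FileNotFoundError(f"Resume checkpoint not found: {ckpt}")
    return ckpt
```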
final_eval
Performs final evaluation and validation of the trained YOLO object detection model.
get_dataloader
Returns a dataloader derived from torch.utils.data.DataLoader.
get_dataset
Get the train and val paths from the data dict, if it exists.
Returns None if the data format is not recognized.
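The behavior described above can be sketched as follows. This sketch assumes the data dict uses `train` and `val` keys. Falling back to `test` when `val` is missing is also an assumption. The function name is hypothetical.

```python
def get_train_val_paths(data):
    """Return (train, val) paths from a data dict, or None if the format is unrecognized.

    Hypothetical helper; the fallback to 'test' when 'val' is missing is an assumption.
    """
    if not isinstance(data, dict) or "train" not in data:
        return None
    return data["train"], data.get("val") or data.get("test")
```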
get_model
Get the model. The base implementation raises NotImplementedError because it does not support loading cfg files; task trainers override this method.
get_validator
Raises NotImplementedError; task trainers must implement get_validator.
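Because the base class leaves these hooks unimplemented, each task trainer must supply its own model and validator. The sketch below shows that template-method pattern with stand-in classes. `BaseTrainerSketch`, `TinyTrainer`, `DummyModel`, and `DummyValidator` are hypothetical and not part of the library.

```python
class BaseTrainerSketch:
    """Stand-in for a base trainer: subclasses must provide the task hooks."""

    def get_model(self, cfg=None, weights=None, verbose=True):
        raise NotImplementedError("subclasses must implement get_model")

    def get_validator(self):
        raise NotImplementedError("subclasses must implement get_validator")

    def setup(self):
        # Shared training code only calls the hooks; it never builds models itself.
        self.model = self.get_model()
        self.validator = self.get_validator()
        return self


class DummyModel:
    pass


class DummyValidator:
    pass


class TinyTrainer(BaseTrainerSketch):
    """Hypothetical task trainer that overrides both hooks."""

    def get_model(self, cfg=None, weights=None, verbose=True):
        return DummyModel()

    def get_validator(self):
        return DummyValidator()
```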
label_loss_items
Returns a loss dict that labels each item of the training loss tensor.
Note
This is not needed for classification but is necessary for segmentation and detection.
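For detection-style trainers, the labelled dict pairs each entry of the loss tensor with a name such as `box_loss`. The sketch below uses plain floats instead of a tensor. The `prefix` argument, the default loss names, and rounding to five decimals are assumptions modeled on the detection trainer.

```python
def label_loss_items(loss_items=None, prefix="train", loss_names=("box_loss", "cls_loss", "dfl_loss")):
    """Return {prefix/name: value} for loss_items, or only the keys when loss_items is None."""
    keys = [f"{prefix}/{name}" for name in loss_names]
    if loss_items is None:
        return keys  # the keys alone are useful, e.g. as CSV column names
    return dict(zip(keys, (round(float(x), 5) for x in loss_items)))
```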
on_plot
optimizer_step
Perform a single step of the training optimizer with gradient clipping and EMA update.
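This step combines two ideas: clipping the gradient by its global norm, then blending the updated weights into an exponential moving average. Both can be sketched on plain lists of floats. The max norm of 10.0 and the EMA ramp `decay * (1 - exp(-updates / tau))` with `decay=0.9999` and `tau=2000` are assumed values. The plain gradient-descent update stands in for the real optimizer, and a real implementation operates on PyTorch tensors (for example via `torch.nn.utils.clip_grad_norm_`).

```python
import math


def clip_by_global_norm(grads, max_norm=10.0):
    """Scale all gradients together so their combined L2 norm is at most max_norm."""
    total = math.sqrt(sum(g * g for g in grads))
    if total <= max_norm:
        return list(grads)
    scale = max_norm / total
    return [g * scale for g in grads]


def sgd_update(params, grads, lr):
    """Plain gradient-descent step (a stand-in for optimizer.step())."""
    return [p - lr * g for p, g in zip(params, grads)]


def ema_update(ema, params, updates, decay=0.9999, tau=2000):
    """Blend params into ema; d ramps up from 0, so the early EMA tracks the model closely."""
    d = decay * (1 - math.exp(-updates / tau))
    return [d * e + (1 - d) * p for e, p in zip(ema, params)]


def optimizer_step(params, grads, ema, updates, lr=0.01):
    """One step: clip gradients, update params, then update the EMA copy."""
    grads = clip_by_global_norm(grads)
    params = sgd_update(params, grads, lr)
    updates += 1
    ema = ema_update(ema, params, updates)
    return params, ema, updates
```

Clipping by the global norm preserves the direction of the whole gradient vector. Clipping each component separately would not.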
plot_metrics
plot_training_labels
plot_training_samples
preprocess_batch
progress_string
read_results_csv
resume_training
Resume YOLO training from the given epoch and best fitness.
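Resuming restores the counters stored in the checkpoint and continues from the next epoch. The sketch below works on a plain checkpoint dict. The keys `epoch` and `best_fitness`, the 0-based epoch convention, and the "already finished" check are assumptions modeled on upstream behavior. The helper name is hypothetical.

```python
def resume_state(ckpt, epochs):
    """Return (start_epoch, best_fitness) for resuming from a checkpoint dict.

    Hypothetical helper: ckpt["epoch"] is assumed to be the last completed epoch (0-based).
    """
    start_epoch = ckpt.get("epoch", -1) + 1
    if start_epoch >= epochs:
        raise ValueError(f"Training to {epochs} epochs is already finished; nothing to resume.")
    return start_epoch, ckpt.get("best_fitness", 0.0)
```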
run_callbacks
save_metrics
Saves training metrics to a CSV file.
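The metrics file grows by one row per epoch, and the header is written only when the file is first created. The sketch below uses the standard `csv` module. The leading `epoch` and `time` columns are an assumption modeled on upstream's results CSV, and the function signature is hypothetical.

```python
import csv
from pathlib import Path


def save_metrics(csv_path, epoch, elapsed, metrics):
    """Append one epoch of metrics to csv_path, writing a header on first use."""
    csv_path = Path(csv_path)
    write_header = not csv_path.exists()
    with csv_path.open("a", newline="") as f:
        writer = csv.writer(f)
        if write_header:
            writer.writerow(["epoch", "time", *metrics.keys()])
        writer.writerow([epoch, round(elapsed, 2), *metrics.values()])
```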
save_model
Save model training checkpoints with additional metadata.
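Which checkpoint files get written each epoch depends on whether the current fitness is the best so far and on `save_period`. The sketch below shows that decision only, without serialization. The `last.pt` and `best.pt` names follow the `last` and `best` attributes listed above, and the rule that periodic saves are disabled when `save_period` is below 1 comes from the attribute table. The `epoch{N}.pt` naming and the helper name are assumptions.

```python
def checkpoint_files(epoch, fitness, best_fitness, save_period):
    """List the checkpoint filenames to write after `epoch`.

    Hypothetical helper: 'last.pt' every epoch, 'best.pt' when fitness matches the best so far,
    and a periodic copy when save_period > 0 and epoch is a multiple of it.
    """
    files = ["last.pt"]
    if best_fitness is not None and fitness == best_fitness:
        files.append("best.pt")
    if save_period > 0 and epoch % save_period == 0:
        files.append(f"epoch{epoch}.pt")
    return files
```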
set_callback
set_model_attributes
setup_model
Load/create/download model for any task.
train
Train the model. On multi-GPU systems, device='' or device=None defaults to device=0.
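Before training starts, the device argument determines how many devices (the world size) will be used. A sketch of that counting follows. The function name and the `cuda_available` flag are hypothetical, and this is a simplified rule, not the upstream logic. Only the fallback, in which an empty or missing device means a single GPU when CUDA is available, comes from the note above.

```python
def world_size_from_device(device, cuda_available):
    """Count the training devices implied by a device argument.

    Hypothetical helper: '' or None falls back to a single GPU (device=0) when CUDA exists.
    """
    if isinstance(device, str) and device not in ("", "cpu", "mps"):
        return len(device.split(","))  # e.g. "0,1,2" -> 3 GPUs
    if isinstance(device, (list, tuple)):
        return len(device)
    if device in ("cpu", "mps"):
        return 0  # no CUDA devices
    return 1 if cuda_available else 0  # '' or None
```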
validate
Runs validation on the test set using self.validator.
The returned dict is expected to contain a "fitness" key.
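The `fitness` key drives best-checkpoint tracking. A sketch of that bookkeeping follows. The function name is hypothetical, and falling back to the negated loss when the validator omits `fitness` is an assumption modeled on upstream.

```python
def update_best_fitness(metrics, best_fitness, loss):
    """Pop 'fitness' from metrics and return (fitness, new_best_fitness).

    Hypothetical helper: falls back to -loss when the validator omits 'fitness'.
    """
    fitness = metrics.pop("fitness", -loss)
    if best_fitness is None or fitness > best_fitness:
        best_fitness = fitness
    return fitness, best_fitness
```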