Reference for ultralytics/engine/trainer.py
Note
This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/engine/trainer.py. If you spot a problem please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!
ultralytics.engine.trainer.BaseTrainer
BaseTrainer(cfg=DEFAULT_CFG, overrides=None, _callbacks=None)
A base class for creating trainers.
Attributes:
Name | Type | Description |
---|---|---|
args | SimpleNamespace | Configuration for the trainer. |
validator | BaseValidator | Validator instance. |
model | Module | Model instance. |
callbacks | defaultdict | Dictionary of callbacks. |
save_dir | Path | Directory to save results. |
wdir | Path | Directory to save weights. |
last | Path | Path to the last checkpoint. |
best | Path | Path to the best checkpoint. |
save_period | int | Save checkpoint every x epochs (disabled if < 1). |
batch_size | int | Batch size for training. |
epochs | int | Number of epochs to train for. |
start_epoch | int | Starting epoch for training. |
device | device | Device to use for training. |
amp | bool | Flag to enable AMP (Automatic Mixed Precision). |
scaler | GradScaler | Gradient scaler for AMP. |
data | str | Path to data. |
ema | Module | EMA (Exponential Moving Average) of the model. |
resume | bool | Resume training from a checkpoint. |
lf | Module | Loss function. |
scheduler | _LRScheduler | Learning rate scheduler. |
best_fitness | float | The best fitness value achieved. |
fitness | float | Current fitness value. |
loss | float | Current loss value. |
tloss | float | Total loss value. |
loss_names | list | List of loss names. |
csv | Path | Path to results CSV file. |
metrics | dict | Dictionary of metrics. |
plots | dict | Dictionary of plots. |
Parameters:
Name | Type | Description | Default |
---|---|---|---|
cfg | str | Path to a configuration file. Defaults to DEFAULT_CFG. | DEFAULT_CFG |
overrides | dict | Configuration overrides. Defaults to None. | None |
_callbacks | list | List of callback functions. Defaults to None. | None |
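As a rough illustration of how cfg and overrides relate to the args attribute, the sketch below merges a dictionary of defaults with a dictionary of overrides into a SimpleNamespace. The DEFAULTS dictionary and the build_args helper are hypothetical stand-ins invented for this example; they are not the library's own configuration loader.

```python
from types import SimpleNamespace

# Hypothetical defaults standing in for the contents of a configuration file.
DEFAULTS = {"epochs": 100, "batch": 16, "device": None, "save_period": -1}


def build_args(cfg=None, overrides=None):
    """Merge a base configuration with overrides; override values win."""
    merged = dict(cfg if cfg is not None else DEFAULTS)  # copy, do not mutate
    merged.update(overrides or {})
    return SimpleNamespace(**merged)


args = build_args(overrides={"epochs": 3, "batch": 8})
print(args.epochs, args.batch, args.save_period)
```

Values that are not overridden keep their defaults, and the defaults themselves are left untouched because the helper works on a copy.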
add_callback
add_callback(event: str, callback)
Append the given callback to the event's callback list.
auto_batch
auto_batch(max_num_obj=0)
Calculate optimal batch size based on model and device memory constraints.
build_dataset
build_dataset(img_path, mode='train', batch=None)
Build dataset.
build_optimizer
build_optimizer(model, name='auto', lr=0.001, momentum=0.9, decay=1e-05, iterations=100000.0)
Construct an optimizer for the given model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model | Module | The model for which to build an optimizer. | required |
name | str | The name of the optimizer to use. If 'auto', the optimizer is selected based on the number of iterations. Default: 'auto'. | 'auto' |
lr | float | The learning rate for the optimizer. Default: 0.001. | 0.001 |
momentum | float | The momentum factor for the optimizer. Default: 0.9. | 0.9 |
decay | float | The weight decay for the optimizer. Default: 1e-5. | 1e-05 |
iterations | float | The number of iterations, which determines the optimizer if name is 'auto'. Default: 1e5. | 100000.0 |
Returns:
Type | Description |
---|---|
Optimizer | The constructed optimizer. |
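The 'auto' behaviour described for name can be sketched without PyTorch. In the sketch below, the 10,000-iteration threshold and the two optimizer names are assumptions made for illustration, and pick_optimizer_name is a hypothetical helper; consult the source file for the values the trainer actually uses.

```python
def pick_optimizer_name(name="auto", iterations=1e5, threshold=10_000):
    """Return an optimizer name; 'auto' chooses one from the iteration count.

    `threshold` and the names "SGD" and "AdamW" are illustrative assumptions.
    """
    if name != "auto":
        return name  # an explicit choice is passed through unchanged
    return "SGD" if iterations > threshold else "AdamW"


long_run = pick_optimizer_name()                  # default iterations=1e5
short_run = pick_optimizer_name(iterations=500)
explicit = pick_optimizer_name("Adam", iterations=500)
```

The point of the sketch is only the control flow: an explicit name always wins, and the iteration count matters solely when name is 'auto'.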
build_targets
build_targets(preds, targets)
Builds target tensors for training YOLO model.
check_resume
check_resume(overrides)
Check if resume checkpoint exists and update arguments accordingly.
final_eval
final_eval()
Perform final evaluation and validation for object detection YOLO model.
get_dataloader
get_dataloader(dataset_path, batch_size=16, rank=0, mode='train')
Returns a dataloader derived from torch.utils.data.DataLoader.
get_dataset
get_dataset()
Get train and validation datasets from data dictionary.
Returns:
Type | Description |
---|---|
dict | A dictionary containing the training/validation/test dataset and category names. |
get_model
get_model(cfg=None, weights=None, verbose=True)
Get the model. The base implementation raises NotImplementedError because it does not support loading cfg files.
get_validator
get_validator()
Raises NotImplementedError when called on the base trainer.
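The descriptions of get_model and get_validator suggest a common pattern: the base class declares a hook that fails loudly, and a task-specific subclass supplies the real behaviour. A minimal sketch of that pattern follows; both classes and the placeholder return value are hypothetical and do not reproduce any class in the library.

```python
class SketchBaseTrainer:
    """Hypothetical stand-in for a base trainer with an unimplemented hook."""

    def get_validator(self):
        raise NotImplementedError("get_validator is not implemented in the base trainer")


class SketchDetectionTrainer(SketchBaseTrainer):
    """Hypothetical task trainer that overrides the hook."""

    def get_validator(self):
        return "detection-validator"  # placeholder for a real validator object


try:
    SketchBaseTrainer().get_validator()
    base_error = None
except NotImplementedError as err:
    base_error = str(err)  # the base class refuses to provide a validator

validator = SketchDetectionTrainer().get_validator()
```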
label_loss_items
label_loss_items(loss_items=None, prefix='train')
Returns a loss dict with labelled training loss items tensor.
Note
This is not needed for classification but is necessary for segmentation and detection.
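One way to picture a labelled loss dict is to pair each loss name with its value under a shared prefix. The sketch below is a standalone function, not the trainer method: it takes loss_names as an explicit argument, and the names box_loss and cls_loss, the rounding to five decimals, and the behaviour when loss_items is None are all assumptions made for illustration.

```python
def label_loss_items(loss_names, loss_items=None, prefix="train"):
    """Pair prefixed loss names with values, or return only the keys."""
    keys = [f"{prefix}/{name}" for name in loss_names]
    if loss_items is None:
        return keys  # assumption: without values, only the labels are returned
    return dict(zip(keys, (round(float(x), 5) for x in loss_items)))


names = ("box_loss", "cls_loss")  # hypothetical loss names
labelled = label_loss_items(names, [1.2, 0.35])
keys_only = label_loss_items(names, prefix="val")
```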
on_plot
on_plot(name, data=None)
Registers plots (e.g. to be consumed in callbacks).
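Registering a plot can be thought of as recording it in the plots dictionary so that a callback can pick it up later. The sketch below assumes each entry stores the optional data and a timestamp, keyed by the plot's path; those stored fields and the module-level plots dictionary are assumptions for illustration, not a description of the library's internals.

```python
import time
from pathlib import Path

plots = {}  # stand-in for the trainer's `plots` attribute


def on_plot(name, data=None):
    """Record a plot under its path together with the registration time."""
    plots[Path(name)] = {"data": data, "timestamp": time.time()}


on_plot("results.png")
on_plot("labels.jpg", data={"kind": "labels"})
```

A callback could then iterate over plots and, for example, upload only entries registered since it last ran.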
optimizer_step
optimizer_step()
Perform a single step of the training optimizer with gradient clipping and EMA update.
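The three ingredients named here (an optimizer step, gradient clipping, and an EMA update) can be shown in plain Python on lists of floats. This is a simplified stand-in, not the trainer's code: it uses plain SGD instead of a PyTorch optimizer, omits AMP gradient scaling entirely, and the values of lr, max_norm, and decay are illustrative assumptions.

```python
import math


def clip_by_global_norm(grads, max_norm):
    """Scale gradients so their global L2 norm does not exceed max_norm."""
    total = math.sqrt(sum(g * g for g in grads))
    scale = min(1.0, max_norm / (total + 1e-6))
    return [g * scale for g in grads]


def optimizer_step(weights, grads, ema, lr=0.1, max_norm=10.0, decay=0.9):
    """Clip gradients, take one SGD step, then update the EMA weights."""
    grads = clip_by_global_norm(grads, max_norm)
    weights = [w - lr * g for w, g in zip(weights, grads)]
    ema = [decay * e + (1.0 - decay) * w for e, w in zip(ema, weights)]
    return weights, ema


# Gradient norm is 50, so clipping rescales it to roughly 10 before the step.
weights, ema = optimizer_step([1.0, 2.0], [30.0, 40.0], ema=[1.0, 2.0])
```

Clipping happens before the step so that a single large gradient cannot move the weights far, and the EMA is updated last so that it tracks the weights after the step.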
plot_metrics
plot_metrics()
Plot and display metrics visually.
plot_training_labels
plot_training_labels()
Plots training labels for YOLO model.
plot_training_samples
plot_training_samples(batch, ni)
Plots training samples during YOLO training.
preprocess_batch
preprocess_batch(batch)
Allows custom preprocessing of model inputs and ground truths, depending on the task type.
progress_string
progress_string()
Returns a string describing training progress.
read_results_csv
read_results_csv()
Read results.csv into a dictionary using pandas.
resume_training
resume_training(ckpt)
Resume YOLO training from given epoch and best fitness.
run_callbacks
run_callbacks(event: str)
Run all existing callbacks associated with a particular event.
save_metrics
save_metrics(metrics)
Save training metrics to a CSV file.
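A per-epoch metrics log of this kind is typically an append-only CSV whose header is written once. The sketch below pairs such a writer with a reader that returns a dictionary of columns. It uses the standard csv module rather than pandas, the column names are hypothetical, and neither function reproduces the trainer's own methods.

```python
import csv
import tempfile
from pathlib import Path


def save_metrics(csv_path, epoch, metrics):
    """Append one row of metrics, writing the header when the file is new."""
    row = {"epoch": epoch, **metrics}
    is_new = not csv_path.exists()
    with csv_path.open("a", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=list(row))
        if is_new:
            writer.writeheader()
        writer.writerow(row)


def read_results_csv(csv_path):
    """Read the CSV back as a dict mapping column name to a list of floats."""
    with csv_path.open(newline="") as f:
        rows = list(csv.DictReader(f))
    if not rows:
        return {}
    return {key: [float(r[key]) for r in rows] for key in rows[0]}


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "results.csv"
    save_metrics(path, 1, {"train/loss": 1.5, "metrics/mAP50": 0.25})
    save_metrics(path, 2, {"train/loss": 1.1, "metrics/mAP50": 0.4})
    results = read_results_csv(path)
```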
save_model
save_model()
Save model training checkpoints with additional metadata.
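The save_period attribute is described as saving a checkpoint every x epochs and being disabled when it is below 1. A minimal sketch of that rule is shown here; should_save_periodic is a hypothetical helper, and how epochs are counted (from 0 in this sketch) is an assumption.

```python
def should_save_periodic(epoch, save_period):
    """True when a periodic checkpoint is due; never when save_period < 1."""
    return save_period >= 1 and epoch % save_period == 0


# With save_period=5, epochs 0, 5, 10, ... trigger a periodic save.
due = [e for e in range(12) if should_save_periodic(e, 5)]
```

Checking save_period first also avoids a modulo by zero when the period is 0.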
set_callback
set_callback(event: str, callback)
Override the existing callbacks with the given callback for the specified event.
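The difference between add_callback (append), set_callback (override), and run_callbacks (invoke everything registered for an event) can be sketched with a defaultdict of lists, matching the type given for the callbacks attribute. CallbackRegistry is a hypothetical stand-in, the event names are used only as examples, and passing the registry itself to each callback is an assumption of this sketch.

```python
from collections import defaultdict


class CallbackRegistry:
    """Hypothetical stand-in for the trainer's callback handling."""

    def __init__(self):
        self.callbacks = defaultdict(list)

    def add_callback(self, event, callback):
        """Append a callback to the event's list."""
        self.callbacks[event].append(callback)

    def set_callback(self, event, callback):
        """Replace every callback for the event with this one."""
        self.callbacks[event] = [callback]

    def run_callbacks(self, event):
        """Call each callback registered for the event, in order."""
        for callback in self.callbacks.get(event, []):
            callback(self)


log = []
registry = CallbackRegistry()
registry.add_callback("on_train_start", lambda trainer: log.append("first"))
registry.add_callback("on_train_start", lambda trainer: log.append("second"))
registry.run_callbacks("on_train_start")  # runs both, in registration order

registry.set_callback("on_train_start", lambda trainer: log.append("only"))
registry.run_callbacks("on_train_start")  # runs just the replacement
registry.run_callbacks("on_train_end")    # nothing registered, nothing runs
```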
set_model_attributes
set_model_attributes()
Set or update model parameters before training.
setup_model
setup_model()
Load, create, or download model for any task.
Returns:
Type | Description |
---|---|
dict | Optional checkpoint to resume training from. |
train
train()
Run training. On multi-GPU systems, device='' or device=None defaults to device=0.
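To make the device handling concrete, the sketch below estimates how many devices a device argument refers to and falls back to a single default device when the argument is empty or None. The helper world_size_from_device, its cuda_available parameter, and the treatment of "cpu" and "mps" are assumptions made for illustration and are not taken from the trainer's source.

```python
def world_size_from_device(device, cuda_available):
    """Estimate how many devices a `device` argument refers to."""
    if isinstance(device, str) and device.strip():
        if device.lower() in {"cpu", "mps"}:
            return 0  # assumption: non-CUDA devices count as zero GPUs
        return len(device.split(","))  # e.g. "0,1,2,3" names four GPUs
    if isinstance(device, (list, tuple)):
        return len(device)
    # device='' or device=None: use one default device if CUDA is available
    return 1 if cuda_available else 0


four_gpus = world_size_from_device("0,1,2,3", cuda_available=True)
default_gpu = world_size_from_device("", cuda_available=True)
no_gpu = world_size_from_device(None, cuda_available=False)
```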
validate
validate()
Run validation on test set using self.validator.
Returns:
Type | Description |
---|---|
tuple | A tuple containing metrics dictionary and fitness score. |
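The returned pair of metrics and fitness relates naturally to the fitness and best_fitness attributes: after each validation pass, the current fitness is compared with the best value seen so far. A minimal sketch of that bookkeeping follows; update_best_fitness is a hypothetical helper, and the assumption that the fitness score arrives under a "fitness" key in the metrics dictionary is made only for this example.

```python
def update_best_fitness(best_fitness, metrics):
    """Return (metrics, fitness, best_fitness) after one validation pass."""
    metrics = dict(metrics)           # copy so the caller's dict is untouched
    fitness = metrics.pop("fitness")  # assumed key for the fitness score
    if best_fitness is None or fitness > best_fitness:
        best_fitness = fitness
    return metrics, fitness, best_fitness


best = None
history = []
for epoch_metrics in ({"mAP50": 0.30, "fitness": 0.25},
                      {"mAP50": 0.45, "fitness": 0.40},
                      {"mAP50": 0.42, "fitness": 0.38}):
    metrics, fitness, best = update_best_fitness(best, epoch_metrics)
    history.append((fitness, best))
```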