Reference for ultralytics/models/yolo/classify/train.py
Note
Full source code for this file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/classify/train.py. Help us fix any issues you see by submitting a Pull Request 🛠️. Thank you 🙏!
ultralytics.models.yolo.classify.train.ClassificationTrainer
Bases: BaseTrainer
A trainer class that extends BaseTrainer to train classification models.
Notes
- Torchvision classification models can also be passed to the 'model' argument, e.g. model='resnet18'.
Example
Source code in ultralytics/models/yolo/classify/train.py
__init__(cfg=DEFAULT_CFG, overrides=None, _callbacks=None)
Initialize a ClassificationTrainer object with optional configuration overrides and callbacks.
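As a rough illustration of what "optional configuration overrides" can mean for a classification trainer, here is a minimal stdlib sketch. It assumes the trainer forces `task='classify'` and falls back to a 224-pixel image size when `imgsz` is not overridden; the helper name `prepare_overrides` is hypothetical and not part of the Ultralytics API.

```python
def prepare_overrides(overrides=None):
    """Hypothetical helper: normalize classification overrides before they reach BaseTrainer."""
    # Copy so the caller's dict is not mutated (a design choice of this sketch).
    overrides = dict(overrides or {})
    # Assumption: a classification trainer always runs the 'classify' task.
    overrides["task"] = "classify"
    # Assumption: default to 224x224 inputs when no image size was given.
    if overrides.get("imgsz") is None:
        overrides["imgsz"] = 224
    return overrides
```

For example, `prepare_overrides({"epochs": 3})` keeps `epochs` and adds the task and image-size defaults, while an explicit `imgsz` is left untouched.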
final_eval()
Evaluate the trained model and save the validation results.
get_dataloader(dataset_path, batch_size=16, rank=0, mode='train')
Returns a PyTorch DataLoader for the given dataset path and mode, with transforms that preprocess the images.
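The real loader is a PyTorch DataLoader; the stdlib sketch below only mimics how `batch_size` and `mode` could shape iteration over a dataset. The helper `iterate_batches` and the rule "shuffle only in train mode" are assumptions made for illustration, not a statement of what the library does.

```python
import random


def iterate_batches(samples, batch_size=16, mode="train", seed=0):
    """Hypothetical stand-in for a DataLoader: yield lists of at most batch_size samples."""
    indices = list(range(len(samples)))
    # Assumption: shuffle only when training, keep evaluation order deterministic.
    if mode == "train":
        random.Random(seed).shuffle(indices)
    # The final batch may be smaller than batch_size (nothing is dropped).
    for start in range(0, len(indices), batch_size):
        yield [samples[i] for i in indices[start:start + batch_size]]
```

With 10 samples and `batch_size=4`, this yields batches of sizes 4, 4 and 2 in either mode; only the order differs.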
get_model(cfg=None, weights=None, verbose=True)
Returns a PyTorch classification model configured for YOLO training, optionally loading the given weights.
get_validator()
Returns an instance of ClassificationValidator for validation.
label_loss_items(loss_items=None, prefix='train')
Returns a dict of training loss items, keyed by labels built from the given prefix. Not needed for classification, but necessary for segmentation and detection.
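A minimal sketch of how such labelling could work for a single-loss classifier. The `loss_names=("loss",)` default, the `<prefix>/<name>` key format and the 5-decimal rounding are assumptions of this sketch rather than a copy of the library code.

```python
def label_loss_items(loss_items=None, prefix="train", loss_names=("loss",)):
    """Sketch: attach '<prefix>/<name>' labels to loss values."""
    keys = [f"{prefix}/{name}" for name in loss_names]
    # With no values, return only the labels (useful for building log headers).
    if loss_items is None:
        return keys
    # Classification has a single scalar loss; round it for compact logging.
    values = [round(float(loss_items), 5)]
    return dict(zip(keys, values))
```

Calling it without arguments gives the label list, while passing a scalar loss gives a ready-to-log dict such as `{"train/loss": 0.12346}` for `0.123456`.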
plot_metrics()
plot_training_samples(batch, ni)
Plots training samples with their annotations.
preprocess_batch(batch)
Preprocesses a batch of images and classes.
progress_string()
Returns a formatted string showing training progress.
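One plausible way to build such a progress header is printf-style formatting with a fixed column width. In the sketch below, the column names (`Epoch`, `GPU_mem`, `Instances`, `Size`) and the 11-character width are assumptions chosen for illustration.

```python
def progress_string(loss_names=("loss",)):
    """Sketch: build a header row with one right-aligned 11-character column per field."""
    columns = ("Epoch", "GPU_mem", *loss_names, "Instances", "Size")
    # '%11s' right-aligns each name in an 11-character field; the leading newline
    # starts the header on its own line.
    return ("\n" + "%11s" * len(columns)) % columns
```

Fixed-width columns keep the per-epoch values printed underneath aligned with their headers, and extra loss names simply add columns.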
set_model_attributes()
setup_model()
Load, create, or download a model for any task.
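Since the class notes allow both YOLO model files and torchvision names such as 'resnet18', model setup has to decide where a model comes from. The sketch below shows one way to dispatch on the model string; the function `resolve_model_source`, the `TORCHVISION_NAMES` subset and the returned labels are hypothetical, and a real implementation would load or build the model rather than return a label.

```python
from pathlib import Path

# Hypothetical subset of torchvision model names, used only in this sketch.
TORCHVISION_NAMES = {"resnet18", "resnet50", "mobilenet_v3_small"}


def resolve_model_source(model, torchvision_names=TORCHVISION_NAMES):
    """Sketch: classify a model spec as a torchvision name, checkpoint, or config file."""
    model = str(model)
    if model in torchvision_names:
        return "torchvision"  # e.g. model='resnet18', as in the class notes
    suffix = Path(model).suffix
    if suffix == ".pt":
        return "checkpoint"  # load saved weights from disk or a download
    if suffix in (".yaml", ".yml"):
        return "config"  # build a fresh model from a config file
    raise FileNotFoundError(f"model={model!r} not found locally or online")
```

Checking explicit names before file suffixes keeps bare torchvision identifiers from being mistaken for paths.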