Reference for ultralytics/models/yolo/classify/train.py
Note
This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/classify/train.py. If you spot a problem please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!
ultralytics.models.yolo.classify.train.ClassificationTrainer
Bases: BaseTrainer
A class extending BaseTrainer for training classification models.
Notes
- Torchvision classification models can also be passed to the 'model' argument, e.g. model='resnet18'.
Example
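The example body is missing from this page. The following is a minimal sketch of typical usage, not the original example. It assumes the trainer accepts an `overrides` dictionary; the checkpoint name `yolov8n-cls.pt`, the dataset name `imagenet10`, and the helper names `build_overrides` and `run_training` are illustrative placeholders.

```python
def build_overrides(model="yolov8n-cls.pt", data="imagenet10", epochs=3):
    """Collect training overrides; the default values are illustrative placeholders."""
    return {"model": model, "data": data, "epochs": epochs}


def run_training(overrides):
    """Train a classifier with the given overrides (requires ultralytics)."""
    # Imported lazily so build_overrides can be used without ultralytics installed.
    from ultralytics.models.yolo.classify import ClassificationTrainer

    trainer = ClassificationTrainer(overrides=overrides)
    trainer.train()
    return trainer


# Overrides for a YOLO classification checkpoint (placeholder name).
yolo_args = build_overrides()
# Overrides for a torchvision architecture passed by name, as the note above allows.
resnet_args = build_overrides(model="resnet18")
```

Calling `run_training(yolo_args)` starts a training run and may download weights and data, so it is left out of the block itself.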
Source code in ultralytics/models/yolo/classify/train.py
build_dataset
Creates a ClassificationDataset instance given an image path and a mode (train, test, etc.).
final_eval
Evaluate trained model and save validation results.
get_dataloader
Returns PyTorch DataLoader with transforms to preprocess images for inference.
get_model
Returns a modified PyTorch model configured for training YOLO.
get_validator
Returns an instance of ClassificationValidator for validation.
label_loss_items
Returns a loss dict with labelled training loss items tensor.
Not needed for classification but necessary for segmentation and detection.
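To illustrate what a labelled loss dict looks like, here is a standalone sketch. It is not the library's implementation: the `loss_names` parameter, the `prefix/name` key format, and the rounding to five decimals are assumptions made for the illustration.

```python
def label_loss_items(loss_items=None, prefix="train", loss_names=("loss",)):
    """Pair loss names with loss values under a common prefix (hypothetical sketch)."""
    # Build keys such as "train/loss" from the prefix and each loss name.
    keys = [f"{prefix}/{name}" for name in loss_names]
    if loss_items is None:
        # Without values, return only the labels.
        return keys
    # Convert each value to a plain float and round it for compact logging.
    values = [round(float(item), 5) for item in loss_items]
    return dict(zip(keys, values))


labels_only = label_loss_items()
labelled = label_loss_items([0.123456789], prefix="val")
```

With a single classification loss the dict has one entry; detection or segmentation trainers would pass several names and values.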
plot_metrics
plot_training_samples
Plots training samples with their annotations.
preprocess_batch
progress_string
Returns a formatted string showing training progress.
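A progress string of this kind is typically a fixed-width header row printed above the per-epoch values. The sketch below shows one way to build such a row. The column names, the 11-character width, and the `loss_names` parameter are assumptions for illustration, not confirmed by this page.

```python
def progress_string(loss_names=("loss",), width=11):
    """Build a right-aligned header row for training progress (hypothetical sketch)."""
    # Assumed column layout: fixed columns surrounding one column per loss name.
    columns = ("Epoch", "GPU_mem", *loss_names, "Instances", "Size")
    # Start on a new line and right-align every column to the same width.
    return "\n" + "".join(f"{column:>{width}}" for column in columns)


header = progress_string()
```

Keeping every column the same width lets the values logged on later lines align under their headers.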
set_model_attributes
setup_model
Load, create or download model for any task.