Reference for ultralytics/models/yolo/classify/train.py
Note
This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/classify/train.py. If you spot a problem, please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!
ultralytics.models.yolo.classify.train.ClassificationTrainer
Bases: BaseTrainer
A trainer class that extends BaseTrainer to train classification models.
This trainer handles the training process for image classification tasks, supporting both YOLO classification models and torchvision models.
Attributes:
Name | Type | Description |
---|---|---|
model | ClassificationModel | The classification model to be trained. |
data | dict | Dictionary containing dataset information including class names and number of classes. |
loss_names | List[str] | Names of the loss functions used during training. |
validator | ClassificationValidator | Validator instance for model evaluation. |
Methods:
Name | Description |
---|---|
set_model_attributes | Set the model's class names from the loaded dataset. |
get_model | Return a modified PyTorch model configured for training. |
setup_model | Load, create or download model for classification. |
build_dataset | Create a ClassificationDataset instance. |
get_dataloader | Return PyTorch DataLoader with transforms for image preprocessing. |
preprocess_batch | Preprocess a batch of images and classes. |
progress_string | Return a formatted string showing training progress. |
get_validator | Return an instance of ClassificationValidator. |
label_loss_items | Return a loss dict with labelled training loss items. |
plot_metrics | Plot metrics from a CSV file. |
final_eval | Evaluate trained model and save validation results. |
plot_training_samples | Plot training samples with their annotations. |
Examples:
>>> from ultralytics.models.yolo.classify import ClassificationTrainer
>>> args = dict(model="yolo11n-cls.pt", data="imagenet10", epochs=3)
>>> trainer = ClassificationTrainer(overrides=args)
>>> trainer.train()
Source code in ultralytics/models/yolo/classify/train.py
build_dataset
Create a ClassificationDataset instance given an image path and mode.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
img_path | str | Path to the dataset images. | required |
mode | str | Dataset mode ('train', 'val', or 'test'). | 'train' |
batch | Any | Batch information (unused in this implementation). | None |
Returns:
Type | Description |
---|---|
ClassificationDataset | Dataset for the specified mode. |
final_eval
Evaluate trained model and save validation results.
get_dataloader
Return PyTorch DataLoader with transforms to preprocess images.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset_path | str | Path to the dataset. | required |
batch_size | int | Number of images per batch. | 16 |
rank | int | Process rank for distributed training. | 0 |
mode | str | 'train', 'val', or 'test' mode. | 'train' |
Returns:
Type | Description |
---|---|
DataLoader | DataLoader for the specified dataset and mode. |
get_model
Return a modified PyTorch model configured for training YOLO.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
cfg | Any | Model configuration. | None |
weights | Any | Pre-trained model weights. | None |
verbose | bool | Whether to display model information. | True |
Returns:
Type | Description |
---|---|
ClassificationModel | Configured PyTorch model for classification. |
get_validator
Return an instance of ClassificationValidator for validation.
label_loss_items
Return a loss dict with labelled training loss items.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
loss_items | Tensor | Loss tensor items. | None |
prefix | str | Prefix to prepend to loss names. | 'train' |
Returns:
Type | Description |
---|---|
Dict[str, float] \| List[str] | Dictionary of loss items or list of loss keys if loss_items is None. |
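The two return shapes described above can be illustrated with a small standalone sketch. It is not the library implementation: the function name `label_loss_items_sketch`, the explicit `loss_names` argument, the `/` separator between prefix and loss name, and the rounding to five decimals are all assumptions made for illustration.

```python
from typing import Dict, List, Optional, Union


def label_loss_items_sketch(
    loss_names: List[str],
    loss_items: Optional[float] = None,
    prefix: str = "train",
) -> Union[Dict[str, float], List[str]]:
    """Hypothetical stand-in showing the keys-only and keys-plus-values cases."""
    # Assumption: each key is the prefix joined to the loss name with "/".
    keys = [f"{prefix}/{name}" for name in loss_names]
    if loss_items is None:
        # No values supplied: return only the labelled keys.
        return keys
    # Assumption: a single scalar loss, rounded for readability.
    return dict(zip(keys, [round(float(loss_items), 5)]))


print(label_loss_items_sketch(["loss"]))
print(label_loss_items_sketch(["loss"], 0.5, prefix="val"))
```

Returning only the keys when no values are given lets a caller build a header (for a log or CSV file, say) before any loss has been computed.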
plot_metrics
Plot metrics from a CSV file.
plot_training_samples
Plot training samples with their annotations.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
batch | Dict[str, Tensor] | Batch containing images and class labels. | required |
ni | int | Number of iterations. | required |
preprocess_batch
Preprocess a batch of images and classes.
progress_string
Return a formatted string showing training progress.
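One way to picture the progress string is as a header row of fixed-width, right-aligned columns printed above the per-epoch values. The sketch below is illustrative only: the function name `progress_string_sketch`, the column labels (`Epoch`, `GPU_mem`, `Instances`, `Size`), and the width of 11 characters are assumptions, not details confirmed by this page.

```python
from typing import Sequence


def progress_string_sketch(loss_names: Sequence[str]) -> str:
    """Hypothetical header builder: one right-aligned column per field."""
    # Assumed column layout: fixed fields around the trainer's loss names.
    columns = ("Epoch", "GPU_mem", *loss_names, "Instances", "Size")
    # A leading newline separates the header from earlier log output.
    return "\n" + "".join(f"{column:>11}" for column in columns)


print(progress_string_sketch(["loss"]))
```

Fixed-width columns keep the header aligned with the numeric values logged for each batch, whatever the length of the loss names (up to the column width).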
set_model_attributes
Set the model's class names from the loaded dataset.
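As the Methods table notes, this method copies the dataset's class names onto the model. A minimal sketch of that idea, using a plain namespace object in place of a real model, might look like the following. The `names` key in the data dictionary and the `names` attribute on the model are assumptions here, and `set_model_attributes_sketch` is a hypothetical name.

```python
from types import SimpleNamespace
from typing import Any, Dict


def set_model_attributes_sketch(model: Any, data: Dict[str, Any]) -> Any:
    """Hypothetical stand-in: attach dataset class names to the model."""
    # Assumption: class names live under the "names" key of the data dict.
    model.names = data["names"]
    return model


model = SimpleNamespace()  # placeholder for a real classification model
data = {"names": {0: "cat", 1: "dog"}, "nc": 2}
set_model_attributes_sketch(model, data)
print(model.names)
```

Storing the names on the model means later code, such as plotting or prediction output, can map class indices back to labels without access to the dataset.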
setup_model
Load, create or download model for classification tasks.
Returns:
Type | Description |
---|---|
Any | Model checkpoint if applicable, otherwise None. |
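Since the trainer supports both YOLO classification models and torchvision models, the "checkpoint if applicable, otherwise None" return value can be read as a dispatch on where the model comes from. The sketch below shows one plausible shape for that dispatch, under stated assumptions: `TORCHVISION_REGISTRY`, `setup_model_sketch`, and `fake_yolo_loader` are hypothetical names, the registry is a toy stand-in for torchvision's model constructors, and strings stand in for real model objects.

```python
from typing import Any, Callable, Dict, Optional, Tuple

# Hypothetical stand-in for a lookup of torchvision model constructors.
TORCHVISION_REGISTRY: Dict[str, Callable[[], str]] = {
    "resnet18": lambda: "torchvision:resnet18",
}


def setup_model_sketch(
    model_name: str,
    load_yolo: Callable[[str], Tuple[str, Dict[str, Any]]],
) -> Tuple[str, Optional[Dict[str, Any]]]:
    """Return (model, checkpoint); checkpoint is None for torchvision models."""
    if model_name in TORCHVISION_REGISTRY:
        # Assumed branch: torchvision models are built directly, with no checkpoint.
        return TORCHVISION_REGISTRY[model_name](), None
    # Assumed branch: anything else is handled by the YOLO loading path.
    return load_yolo(model_name)


def fake_yolo_loader(name: str) -> Tuple[str, Dict[str, Any]]:
    """Toy loader that pretends to load a YOLO model and its checkpoint."""
    return f"yolo:{name}", {"epoch": 0}


print(setup_model_sketch("resnet18", fake_yolo_loader))
print(setup_model_sketch("yolo11n-cls.pt", fake_yolo_loader))
```

Returning the checkpoint separately from the model lets the caller decide whether there is training state to resume from.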