Reference for ultralytics/models/yolo/classify/val.py
Note
This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/classify/val.py.
ultralytics.models.yolo.classify.val.ClassificationValidator
Bases: BaseValidator
A class extending the BaseValidator class for validation based on a classification model.
This validator handles the validation process for classification models, including metrics calculation, confusion matrix generation, and visualization of results.
Attributes:
Name | Type | Description |
---|---|---|
targets | List[Tensor] | Ground truth class labels. |
pred | List[Tensor] | Model predictions. |
metrics | ClassifyMetrics | Object to calculate and store classification metrics. |
names | dict | Mapping of class indices to class names. |
nc | int | Number of classes. |
confusion_matrix | ConfusionMatrix | Matrix to evaluate model performance across classes. |
Methods:
Name | Description |
---|---|
get_desc | Return a formatted string summarizing classification metrics. |
init_metrics | Initialize confusion matrix, class names, and tracking containers. |
preprocess | Preprocess input batch by moving data to device. |
update_metrics | Update running metrics with model predictions and batch targets. |
finalize_metrics | Finalize metrics including confusion matrix and processing speed. |
postprocess | Extract the primary prediction from model output. |
get_stats | Calculate and return a dictionary of metrics. |
build_dataset | Create a ClassificationDataset instance for validation. |
get_dataloader | Build and return a data loader for classification validation. |
print_results | Print evaluation metrics for the classification model. |
plot_val_samples | Plot validation image samples with their ground truth labels. |
plot_predictions | Plot images with their predicted class labels. |
Examples:
>>> from ultralytics.models.yolo.classify import ClassificationValidator
>>> args = dict(model="yolo11n-cls.pt", data="imagenet10")
>>> validator = ClassificationValidator(args=args)
>>> validator()
Notes
Torchvision classification models can also be passed to the 'model' argument, e.g. model='resnet18'.
build_dataset
Create a ClassificationDataset instance for validation.
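A classification dataset is typically read from a directory tree with one subdirectory per class. The sketch below illustrates only the indexing step, assuming an ImageFolder-style layout (`root/<class_name>/<image>`) where class indices follow the sorted order of subdirectory names. `scan_image_folder` and the suffix list are hypothetical and are not part of Ultralytics; `ClassificationDataset` itself is not reproduced here.

```python
from pathlib import Path

# Hypothetical suffix list; the real dataset may accept a different set of formats.
IMG_SUFFIXES = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}


def scan_image_folder(root):
    """Index an ImageFolder-style tree (root/<class_name>/<image>).

    Returns (names, samples): names maps class index -> class name,
    samples is a list of (image_path, class_index) tuples.
    """
    root = Path(root)
    # Assumption: class indices follow the sorted order of subdirectory names.
    class_dirs = sorted(p for p in root.iterdir() if p.is_dir())
    names = {i: d.name for i, d in enumerate(class_dirs)}
    samples = []
    for i, d in enumerate(class_dirs):
        for f in sorted(d.rglob("*")):
            # Skip non-image files such as notes or metadata.
            if f.is_file() and f.suffix.lower() in IMG_SUFFIXES:
                samples.append((str(f), i))
    return names, samples
```

The resulting `names` dictionary has the same shape as the validator's `names` attribute (class index to class name).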
finalize_metrics
Finalize metrics including confusion matrix and processing speed.
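For classification, each sample adds one count to the confusion matrix at the cell given by its top-1 predicted class and its true class. The pure-Python sketch below shows that counting step plus per-class recall derived from it. Both helpers are hypothetical, and the row/column orientation (rows = predicted, columns = true) is chosen for illustration rather than taken from `ConfusionMatrix`.

```python
def build_confusion_matrix(top1_preds, targets, nc):
    """Count (predicted, true) pairs into an nc x nc matrix indexed as matrix[pred][true]."""
    matrix = [[0] * nc for _ in range(nc)]
    for p, t in zip(top1_preds, targets):
        matrix[p][t] += 1
    return matrix


def per_class_recall(matrix):
    """Diagonal count divided by the column sum, i.e. all samples whose true class is c."""
    nc = len(matrix)
    recalls = []
    for c in range(nc):
        total = sum(matrix[r][c] for r in range(nc))
        recalls.append(matrix[c][c] / total if total else 0.0)
    return recalls
```

For example, predictions `[0, 1, 1, 2]` against targets `[0, 1, 2, 2]` put one off-diagonal count in row 1, column 2, so class 2 has a recall of 0.5.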
get_dataloader
Build and return a data loader for classification validation.
get_desc
Return a formatted string summarizing classification metrics.
get_stats
Calculate and return a dictionary of metrics by processing targets and predictions.
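Top-1 accuracy counts a sample as correct when its highest-ranked predicted class equals the target; top-k accuracy counts it as correct when the target appears anywhere among the k highest-ranked classes. The sketch below assumes predictions are stored as per-sample lists of class indices sorted by descending score. `topk_accuracy` is hypothetical, and its dictionary keys are illustrative, not necessarily the keys returned by `ClassifyMetrics`.

```python
def topk_accuracy(ranked_preds, targets, k=5):
    """Return {"top1": ..., "top<k>": ...} accuracies.

    ranked_preds: per-sample lists of class indices sorted by descending score.
    targets: the true class index for each sample.
    """
    n = len(targets)
    if n == 0:
        return {"top1": 0.0, f"top{k}": 0.0}
    # A sample is top-1 correct if its first-ranked class matches the target.
    top1 = sum(r[0] == t for r, t in zip(ranked_preds, targets)) / n
    # A sample is top-k correct if the target appears among its first k classes.
    topk = sum(t in r[:k] for r, t in zip(ranked_preds, targets)) / n
    return {"top1": top1, f"top{k}": topk}
```

Top-k accuracy is never lower than top-1 accuracy, since the first-ranked class is always among the top k.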
init_metrics
Initialize confusion matrix, class names, and tracking containers for predictions and targets.
plot_predictions
Plot images with their predicted class labels and save the visualization.
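Before anything is drawn, each image's predicted label is the class with the highest score, looked up in a `names` mapping like the validator's `names` attribute. The hypothetical helper below sketches only that label-derivation step for plain Python score lists; the plotting and saving are omitted.

```python
def predicted_labels(scores, names):
    """Return the class name with the highest score for each per-image score vector."""
    labels = []
    for row in scores:
        best = max(range(len(row)), key=row.__getitem__)  # argmax over classes
        labels.append(names[best])
    return labels
```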
plot_val_samples
Plot validation image samples with their ground truth labels.
postprocess
Extract the primary prediction from model output if it's in a list or tuple format.
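The behavior described above amounts to a conditional unwrap: if the model output is a list or tuple, its first element is taken as the primary prediction; otherwise the output is returned unchanged. A standalone sketch, with a hypothetical function name:

```python
def unwrap_primary(preds):
    """Return preds[0] when the model output is a list or tuple, else return it unchanged."""
    return preds[0] if isinstance(preds, (list, tuple)) else preds
```

A tensor output, which is neither a list nor a tuple, passes through untouched.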
preprocess
Preprocess input batch by moving data to device and converting to appropriate dtype.
print_results
Print evaluation metrics for the classification model.
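The header from `get_desc` and the result row from `print_results` are meant to line up as columns. A sketch using printf-style formatting follows; the 22-character first column, 11-character metric columns, and the column titles are illustrative assumptions, not values confirmed by this page.

```python
HEADER_FMT = "%22s" + "%11s" * 2  # 22-char first column, then two 11-char columns
ROW_FMT = "%22s" + "%11.3g" * 2   # same widths; metrics shown with 3 significant digits


def format_header():
    """Column titles for the summary table (titles are illustrative)."""
    return HEADER_FMT % ("classes", "top1_acc", "top5_acc")


def format_row(top1, top5):
    """One aligned result row for all classes."""
    return ROW_FMT % ("all", top1, top5)
```

Because both format strings share the same widths, `format_header()` and `format_row(0.7512, 0.9301)` are both 44 characters long and their columns align when printed one above the other.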
update_metrics
Update running metrics with model predictions and batch targets.
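One way to keep running state for top-1 and top-5 accuracy is to store, for each sample, only the indices of its highest-scoring classes (at most five, fewer when there are fewer classes) alongside its target. The class below is a hypothetical pure-Python sketch of that bookkeeping, not the validator's own code.

```python
class RunningTopK:
    """Accumulate per-sample top-k class indices and targets across batches (hypothetical)."""

    def __init__(self, nc, k=5):
        self.k = min(nc, k)  # cannot rank more classes than exist
        self.pred = []       # one list of ranked class indices per sample
        self.targets = []    # one true class index per sample

    def update(self, batch_scores, batch_targets):
        """Add one batch of per-sample score vectors and their true class indices."""
        for scores, target in zip(batch_scores, batch_targets):
            # Sort class indices by descending score and keep the first k.
            ranked = sorted(range(len(scores)), key=lambda c: scores[c], reverse=True)
            self.pred.append(ranked[: self.k])
            self.targets.append(int(target))
```

After the last batch, `pred` and `targets` hold everything needed to compute top-1 and top-k accuracy in a single pass.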