Reference for ultralytics/models/yolo/classify/predict.py
Note
This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/classify/predict.py. If you spot a problem, please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!
ClassificationPredictor(cfg=DEFAULT_CFG, overrides=None, _callbacks=None)
Bases: BasePredictor
A class extending the BasePredictor class for prediction based on a classification model.
This predictor handles the specific requirements of classification models, including preprocessing images and postprocessing predictions to generate classification results.
Attributes:
| Name | Type | Description | 
|---|---|---|
| args | dict | Configuration arguments for the predictor. | 
Methods:
| Name | Description | 
|---|---|
| preprocess | Convert input images to model-compatible format. | 
| postprocess | Process model predictions into Results objects. | 
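The split between a generic base class and task-specific hooks can be illustrated with a minimal sketch. `MiniBasePredictor` and `MiniClassifier` are hypothetical toy classes, not Ultralytics code. They assume only what the methods table suggests: the base class fixes the order of the stages, and the subclass supplies `preprocess` and `postprocess`.

```python
from typing import Any, Callable, List


class MiniBasePredictor:
    """Hypothetical stand-in for a base predictor: it owns the stage order."""

    def __init__(self, model: Callable[[Any], Any]):
        self.model = model

    def preprocess(self, imgs: List[Any]) -> Any:
        raise NotImplementedError

    def postprocess(self, preds: Any, imgs: List[Any]) -> List[Any]:
        raise NotImplementedError

    def __call__(self, imgs: List[Any]) -> List[Any]:
        batch = self.preprocess(imgs)         # task-specific input conversion
        preds = self.model(batch)             # shared inference step
        return self.postprocess(preds, imgs)  # task-specific output wrapping


class MiniClassifier(MiniBasePredictor):
    """Hypothetical subclass: only the task-specific hooks are overridden."""

    def preprocess(self, imgs):
        # Scale 8-bit pixel values to floats in [0, 1]
        return [[p / 255.0 for p in img] for img in imgs]

    def postprocess(self, preds, imgs):
        # Pair each image's class scores with the index of its highest score
        return [{"probs": p, "top1": max(range(len(p)), key=p.__getitem__)} for p in preds]


# Toy "model" that returns the same two-class scores for every input
results = MiniClassifier(lambda batch: [[0.2, 0.8] for _ in batch])([[0, 255], [128, 128]])
```

Keeping the stage order in one place means a new task only has to describe how its inputs and outputs differ.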
Notes
- Torchvision classification models can also be passed to the 'model' argument, e.g. model='resnet18'.
Examples:
>>> from ultralytics.utils import ASSETS
>>> from ultralytics.models.yolo.classify import ClassificationPredictor
>>> args = dict(model="yolo11n-cls.pt", source=ASSETS)
>>> predictor = ClassificationPredictor(overrides=args)
>>> predictor.predict_cli()
__init__(cfg=DEFAULT_CFG, overrides=None, _callbacks=None)
This constructor initializes a ClassificationPredictor instance, which extends BasePredictor for classification tasks. It ensures the task is set to 'classify' regardless of the input configuration.
Parameters:
| Name | Type | Description | Default | 
|---|---|---|---|
| cfg | dict | Default configuration dictionary containing prediction settings. | DEFAULT_CFG | 
| overrides | dict | Configuration overrides that take precedence over cfg. | None | 
| _callbacks | list | List of callback functions to be executed during prediction. | None | 
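The precedence rule in the table (overrides take priority over cfg) and the forced 'classify' task can be sketched with plain dictionaries. `DEFAULTS` and `build_classify_args` are hypothetical names; the real DEFAULT_CFG holds many more settings, and this sketch does not claim to reproduce the exact order in which the library applies the task override.

```python
from typing import Any, Dict, Optional

# Hypothetical defaults standing in for DEFAULT_CFG (illustrative keys only)
DEFAULTS: Dict[str, Any] = {"task": "detect", "imgsz": 640, "conf": 0.25}


def build_classify_args(cfg: Optional[Dict[str, Any]] = None,
                        overrides: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Merge overrides on top of cfg, then pin the task to 'classify'."""
    base = DEFAULTS if cfg is None else cfg
    args = {**base, **(overrides or {})}  # overrides win on key collisions
    args["task"] = "classify"             # forced regardless of cfg or overrides
    return args


# Even an explicit task override is replaced by 'classify'
args = build_classify_args(overrides={"imgsz": 224, "task": "detect"})
```

Building a new dictionary rather than updating `DEFAULTS` in place keeps the shared defaults unchanged between predictor instances.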
postprocess(preds, img, orig_imgs)
Process predictions to return Results objects with classification probabilities.
Parameters:
| Name | Type | Description | Default | 
|---|---|---|---|
| preds | Tensor | Raw predictions from the model. | required | 
| img | Tensor | Input images after preprocessing. | required | 
| orig_imgs | list[ndarray] or Tensor | Original images before preprocessing. | required | 
Returns:
| Type | Description | 
|---|---|
| list[Results] | List of Results objects containing classification results for each image. | 
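The one-result-per-image contract above can be sketched without the library. This assumes `preds` already holds one probability vector per image, as the "classification probabilities" wording implies. `ClsResult` and `postprocess_sketch` are hypothetical stand-ins, not the Ultralytics Results class, and the `top1`/`top5` field names are chosen for illustration.

```python
from dataclasses import dataclass, field
from typing import List, Sequence


@dataclass
class ClsResult:
    """Hypothetical stand-in for a Results object holding class probabilities."""
    path: str
    probs: List[float]
    top1: int = field(init=False)
    top5: List[int] = field(init=False)

    def __post_init__(self) -> None:
        # Rank class indices by probability, highest first
        order = sorted(range(len(self.probs)), key=lambda i: self.probs[i], reverse=True)
        self.top1 = order[0]
        self.top5 = order[:5]


def postprocess_sketch(preds: Sequence[Sequence[float]], paths: Sequence[str]) -> List[ClsResult]:
    """Wrap each image's probability row into its own result record."""
    if len(preds) != len(paths):
        raise ValueError("expected one prediction row per image")
    return [ClsResult(path=p, probs=list(row)) for row, p in zip(preds, paths)]


results = postprocess_sketch([[0.1, 0.7, 0.2], [0.5, 0.3, 0.2]], ["a.jpg", "b.jpg"])
```

With only three classes, `top5` simply returns all indices in ranked order.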
preprocess(img)
Convert input images to model-compatible tensor format with appropriate normalization.
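A rough NumPy sketch of the "model-compatible tensor format" step follows. It assumes OpenCV-style HWC uint8 BGR inputs that have already been resized to a common shape, and identity mean/std as placeholder normalization constants; `preprocess_sketch` is a hypothetical helper, and the real method works with torch tensors and the predictor's own transforms.

```python
import numpy as np


def preprocess_sketch(imgs, mean=(0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0)):
    """Stack HWC uint8 BGR images into a normalized NCHW float32 batch."""
    batch = np.stack([img[..., ::-1] for img in imgs])   # BGR -> RGB, shape (N, H, W, 3)
    batch = batch.astype(np.float32) / 255.0             # [0, 255] -> [0, 1]
    mean_arr = np.asarray(mean, dtype=np.float32)
    std_arr = np.asarray(std, dtype=np.float32)
    batch = (batch - mean_arr) / std_arr                 # per-channel normalization
    return np.ascontiguousarray(batch.transpose(0, 3, 1, 2))  # NHWC -> NCHW


# A 2x3 image whose BGR blue channel is saturated
demo = np.zeros((2, 3, 3), dtype=np.uint8)
demo[..., 0] = 255
batch = preprocess_sketch([demo, demo])
```

After the channel flip, the saturated blue values land in the last RGB channel, which is why the test checks channel index 2.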
setup_source(source)
Set up the input source and inference mode, and configure the classification transforms.
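A common classification transform resizes the shortest image side to the target size and then takes a centered square crop. Whether the classify transforms configured here use exactly this geometry is an assumption; `resize_then_center_crop` is a hypothetical helper that only computes the sizes and crop box.

```python
from typing import Tuple


def resize_then_center_crop(h: int, w: int, size: int) -> Tuple[Tuple[int, int], Tuple[int, int, int, int]]:
    """Return the resized (h, w) and the (top, left, bottom, right) crop box."""
    scale = size / min(h, w)                 # shortest side becomes `size`
    new_h, new_w = round(h * scale), round(w * scale)
    top = (new_h - size) // 2                # center the square window vertically
    left = (new_w - size) // 2               # and horizontally
    return (new_h, new_w), (top, left, top + size, left + size)


# A 480x640 frame resized for a 224x224 classifier input
resized, box = resize_then_center_crop(480, 640, 224)
```

Cropping after a shortest-side resize preserves the aspect ratio and discards only the edges of the longer side.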