Reference for ultralytics/models/yolo/classify/predict.py
Note
This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/classify/predict.py. If you spot a problem please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!
ultralytics.models.yolo.classify.predict.ClassificationPredictor
ClassificationPredictor(cfg=DEFAULT_CFG, overrides=None, _callbacks=None)
Bases: BasePredictor
A class extending BasePredictor to run prediction with a classification model.
This predictor handles the specific requirements of classification models, including preprocessing images and postprocessing predictions to generate classification results.
Attributes:
Name | Type | Description |
---|---|---|
args | dict | Configuration arguments for the predictor. |
_legacy_transform_name | str | Name of the legacy transform class for backward compatibility. |
Methods:
Name | Description |
---|---|
preprocess | Convert input images to model-compatible format. |
postprocess | Process model predictions into Results objects. |
Notes
- Torchvision classification models can also be passed to the 'model' argument, e.g. model='resnet18'.
Examples:
>>> from ultralytics.utils import ASSETS
>>> from ultralytics.models.yolo.classify import ClassificationPredictor
>>> args = dict(model="yolo11n-cls.pt", source=ASSETS)
>>> predictor = ClassificationPredictor(overrides=args)
>>> predictor.predict_cli()
This constructor initializes a ClassificationPredictor instance, which extends BasePredictor for classification tasks. It ensures the task is set to 'classify' regardless of input configuration.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
cfg | dict | Default configuration dictionary containing prediction settings. Defaults to DEFAULT_CFG. | DEFAULT_CFG |
overrides | dict | Configuration overrides that take precedence over cfg. | None |
_callbacks | list | List of callback functions to be executed during prediction. | None |
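The precedence rule described for the constructor can be illustrated with a short stand-alone sketch. It does not use ultralytics: resolve_classify_args and DEFAULT_CFG_SKETCH are hypothetical names, and the default values are illustrative placeholders rather than the real contents of DEFAULT_CFG. It assumes a plain dict merge in which overrides replace keys from cfg, followed by forcing task to 'classify', which is the behaviour the constructor description states.

```python
from typing import Any, Dict, Optional

# Hypothetical stand-in for DEFAULT_CFG; these values are placeholders, not the real defaults.
DEFAULT_CFG_SKETCH: Dict[str, Any] = {"task": "detect", "imgsz": 640, "half": False}


def resolve_classify_args(
    cfg: Optional[Dict[str, Any]] = None,
    overrides: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Merge cfg and overrides (overrides win), then force task='classify'."""
    # Copy so the shared defaults are never mutated.
    merged = dict(cfg if cfg is not None else DEFAULT_CFG_SKETCH)
    # Keys in overrides take precedence over the same keys in cfg.
    merged.update(overrides or {})
    # The task is set last, so no input configuration can change it.
    merged["task"] = "classify"
    return merged


args = resolve_classify_args(overrides={"imgsz": 224, "task": "segment"})
print(args)  # {'task': 'classify', 'imgsz': 224, 'half': False}
```

Because the task is assigned after the merge, even an explicit task='segment' in overrides ends up as 'classify'.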
postprocess
postprocess(preds, img, orig_imgs)
Process predictions to return Results objects with classification probabilities.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
preds | Tensor | Raw predictions from the model. | required |
img | Tensor | Input images after preprocessing. | required |
orig_imgs | List[ndarray] \| Tensor | Original images before preprocessing. | required |
Returns:
Type | Description |
---|---|
List[Results] | List of Results objects containing classification results for each image. |
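As a rough illustration of what postprocess produces, the sketch below turns a batch of per-class probability vectors into one record per image. It is a stdlib-only approximation, not the ultralytics implementation: postprocess_sketch is a hypothetical name, plain lists stand in for tensors, dicts stand in for Results objects, and the rule that a tuple output carries the probabilities in its first element is an assumption. The top1, top5 and top1conf keys are chosen to mirror the attributes of the Probs object attached to ultralytics Results, where top1 and top5 are class indices.

```python
from typing import Dict, List, Sequence, Tuple, Union

ProbVector = List[float]  # one probability per class for a single image


def postprocess_sketch(
    preds: Union[List[ProbVector], Tuple[List[ProbVector], ...]],
    paths: Sequence[str],
    names: Dict[int, str],
) -> List[Dict[str, object]]:
    """Convert a batch of class-probability vectors into one result record per image."""
    # Assumption: a tuple output carries the probability batch in its first element.
    if isinstance(preds, tuple):
        preds = preds[0]
    records: List[Dict[str, object]] = []
    for probs, path in zip(preds, paths):
        # Class indices ordered from most to least probable.
        ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
        records.append(
            {
                "path": path,
                "top1": ranked[0],  # index of the most probable class
                "top1conf": probs[ranked[0]],  # its probability
                "top5": ranked[:5],  # up to five most probable class indices
                "name": names[ranked[0]],  # human-readable label for top1
            }
        )
    return records


names = {0: "cat", 1: "dog", 2: "bird"}
batch = [[0.1, 0.7, 0.2], [0.6, 0.3, 0.1]]
for rec in postprocess_sketch((batch,), ["a.jpg", "b.jpg"], names):
    print(rec["path"], rec["name"], rec["top5"])
# a.jpg dog [1, 2, 0]
# b.jpg cat [0, 1, 2]
```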
preprocess
preprocess(img)
Convert input images to model-compatible tensor format with appropriate normalization.
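The channel-order and scaling part of this conversion can be sketched without any dependencies. This is an illustration under stated assumptions rather than the library's code: it assumes images arrive as OpenCV-style BGR, height x width x 3 arrays of 0-255 integers, and that normalization means scaling to [0, 1]; it skips the resize and crop transforms and any half-precision cast a real pipeline may apply. to_chw_rgb and preprocess_sketch are hypothetical names, and nested lists stand in for tensors.

```python
from typing import List

Image = List[List[List[int]]]  # H x W x 3, values 0-255, BGR channel order (assumed)
Chw = List[List[List[float]]]  # C x H x W, RGB channel order, values in [0, 1]


def to_chw_rgb(img: Image) -> Chw:
    """Reorder BGR HWC pixels into an RGB CHW float array scaled to [0, 1]."""
    h, w = len(img), len(img[0])
    # Output channel c (0=R, 1=G, 2=B) is read from BGR index 2 - c.
    return [
        [[img[y][x][2 - c] / 255.0 for x in range(w)] for y in range(h)]
        for c in range(3)
    ]


def preprocess_sketch(batch: List[Image]) -> List[Chw]:
    """Convert every image in the batch, giving an N x C x H x W nested list."""
    return [to_chw_rgb(im) for im in batch]


# A 1 x 2 image: a red pixel followed by a blue pixel, both in BGR order.
img = [[[0, 0, 255], [255, 0, 0]]]
out = preprocess_sketch([img])
print(out[0][0])  # R channel: [[1.0, 0.0]]
print(out[0][2])  # B channel: [[0.0, 1.0]]
```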