跳转至内容

图像分类

图像分类示例

图像分类是三个任务中最简单的,它涉及将整个图像分类到一组预定义的类别中。

图像分类器的输出是单个类别标签和一个置信度分数。当您只需要知道图像属于哪个类别,而不需要知道该类别的对象位于何处或其确切形状时,图像分类非常有用。



观看: 探索 Ultralytics YOLO 任务:使用 Ultralytics HUB 进行图像分类

提示

YOLO11 Classify 模型使用 -cls 后缀,例如 yolo11n-cls.pt 并且在以下数据集上进行了预训练 ImageNet.

模型

此处显示了YOLO11预训练的分类模型。检测、分割和姿势估计模型在COCO数据集上进行预训练,而分类模型在ImageNet数据集上进行预训练。

模型 首次使用时会自动从最新的 Ultralytics 版本 下载。

模型 尺寸
(像素)
acc
top1
acc
top5
速度
CPU ONNX
(毫秒)
速度
T4 TensorRT10
(毫秒)
参数
(M)
FLOPs
(B) at 224
YOLO11n-cls 224 70.0 89.4 5.0 ± 0.3 1.1 ± 0.0 1.6 0.5
YOLO11s-cls 224 75.4 92.7 7.9 ± 0.2 1.3 ± 0.0 5.5 1.6
YOLO11m-cls 224 77.3 93.9 17.2 ± 0.4 2.0 ± 0.0 10.4 5.0
YOLO11l-cls 224 78.3 94.3 23.2 ± 0.3 2.8 ± 0.0 12.9 6.2
YOLO11x-cls 224 79.5 94.9 41.4 ± 0.9 3.8 ± 0.0 28.4 13.7
  • acc 值是模型在 ImageNet 数据集验证集上的模型准确性。
    如需重现结果,请通过 yolo val classify data=path/to/ImageNet device=0
  • 速度 在 ImageNet val 图像上取平均值,使用 Amazon EC2 P4d 实例进行平均计算得出的。
    如需重现结果,请通过 yolo val classify data=path/to/ImageNet batch=1 device=0|cpu

训练

在 MNIST160 数据集上训练 YOLO11n-cls 100 个 epochs,图像大小为 64。有关可用参数的完整列表,请参见配置页面。

示例

from ultralytics import YOLO

# Load a model
model = YOLO("yolo11n-cls.yaml")  # build a new model from YAML
model = YOLO("yolo11n-cls.pt")  # load a pretrained model (recommended for training)
model = YOLO("yolo11n-cls.yaml").load("yolo11n-cls.pt")  # build from YAML and transfer weights

# Train the model
results = model.train(data="mnist160", epochs=100, imgsz=64)
# Build a new model from YAML and start training from scratch
yolo classify train data=mnist160 model=yolo11n-cls.yaml epochs=100 imgsz=64

# Start training from a pretrained *.pt model
yolo classify train data=mnist160 model=yolo11n-cls.pt epochs=100 imgsz=64

# Build a new model from YAML, transfer pretrained weights to it and start training
yolo classify train data=mnist160 model=yolo11n-cls.yaml pretrained=yolo11n-cls.pt epochs=100 imgsz=64

提示

Ultralytics YOLO 分类使用 torchvision.transforms.RandomResizedCrop 用于训练和 torchvision.transforms.CenterCrop 用于验证和推理。 这些基于裁剪的转换假定输入为正方形,并且可能会无意中裁剪掉具有极端纵横比的图像中的重要区域,从而可能导致训练期间关键视觉信息的丢失。 为了在保持图像比例的同时保留完整图像,请考虑使用 torchvision.transforms.Resize 而不是裁剪变换。

您可以通过自定义的增强流水线来实现这一点 ClassificationDatasetClassificationTrainer.

import torch
import torchvision.transforms as T

from ultralytics import YOLO
from ultralytics.data.dataset import ClassificationDataset
from ultralytics.models.yolo.classify import ClassificationTrainer, ClassificationValidator


class CustomizedDataset(ClassificationDataset):
    """A customized dataset class for image classification with enhanced data augmentation transforms."""

    def __init__(self, root: str, args, augment: bool = False, prefix: str = ""):
        """Initialize a customized classification dataset with enhanced data augmentation transforms."""
        super().__init__(root, args, augment, prefix)

        # Add your custom training transforms here
        train_transforms = T.Compose(
            [
                T.Resize((args.imgsz, args.imgsz)),
                T.RandomHorizontalFlip(p=args.fliplr),
                T.RandomVerticalFlip(p=args.flipud),
                T.RandAugment(interpolation=T.InterpolationMode.BILINEAR),
                T.ColorJitter(brightness=args.hsv_v, contrast=args.hsv_v, saturation=args.hsv_s, hue=args.hsv_h),
                T.ToTensor(),
                T.Normalize(mean=torch.tensor(0), std=torch.tensor(1)),
                T.RandomErasing(p=args.erasing, inplace=True),
            ]
        )

        # Add your custom validation transforms here
        val_transforms = T.Compose(
            [
                T.Resize((args.imgsz, args.imgsz)),
                T.ToTensor(),
                T.Normalize(mean=torch.tensor(0), std=torch.tensor(1)),
            ]
        )
        self.torch_transforms = train_transforms if augment else val_transforms


class CustomizedTrainer(ClassificationTrainer):
    """A customized trainer class for YOLO classification models with enhanced dataset handling."""

    def build_dataset(self, img_path: str, mode: str = "train", batch=None):
        """Build a customized dataset for classification training and the validation during training."""
        return CustomizedDataset(root=img_path, args=self.args, augment=mode == "train", prefix=mode)


class CustomizedValidator(ClassificationValidator):
    """A customized validator class for YOLO classification models with enhanced dataset handling."""

    def build_dataset(self, img_path: str, mode: str = "train"):
        """Build a customized dataset for classification standalone validation."""
        return CustomizedDataset(root=img_path, args=self.args, augment=mode == "train", prefix=self.args.split)


model = YOLO("yolo11n-cls.pt")
model.train(data="imagenet1000", trainer=CustomizedTrainer, epochs=10, imgsz=224, batch=64)
model.val(data="imagenet1000", validator=CustomizedValidator, imgsz=224, batch=64)

数据集格式

YOLO 分类数据集格式的详细信息可以在数据集指南中找到。

验证

验证训练后的 YOLO11n-cls 模型 准确性 在 MNIST160 数据集上。由于 model 保留其训练 data 和参数作为模型属性,因此无需任何参数。

示例

from ultralytics import YOLO

# Load a model
model = YOLO("yolo11n-cls.pt")  # load an official model
model = YOLO("path/to/best.pt")  # load a custom model

# Validate the model
metrics = model.val()  # no arguments needed, dataset and settings remembered
metrics.top1  # top1 accuracy
metrics.top5  # top5 accuracy
yolo classify val model=yolo11n-cls.pt  # val official model
yolo classify val model=path/to/best.pt # val custom model

提示

正如在以下内容中提到的 训练部分,您可以通过使用自定义的 ClassificationTrainer。您需要应用相同的方法,通过实施自定义的 ClassificationValidator 在调用 val() 方法。请参阅中的完整代码示例 训练部分 有关实施细节。

预测

使用训练好的 YOLO11n-cls 模型来运行图像预测。

示例

from ultralytics import YOLO

# Load a model
model = YOLO("yolo11n-cls.pt")  # load an official model
model = YOLO("path/to/best.pt")  # load a custom model

# Predict with the model
results = model("https://ultralytics.com/images/bus.jpg")  # predict on an image
yolo classify predict model=yolo11n-cls.pt source='https://ultralytics.com/images/bus.jpg'  # predict with official model
yolo classify predict model=path/to/best.pt source='https://ultralytics.com/images/bus.jpg' # predict with custom model

查看完整 predict 模式的详细信息,请参阅 预测 页面。

导出

将 YOLO11n-cls 模型导出为其他格式,如 ONNX、CoreML 等。

示例

from ultralytics import YOLO

# Load a model
model = YOLO("yolo11n-cls.pt")  # load an official model
model = YOLO("path/to/best.pt")  # load a custom trained model

# Export the model
model.export(format="onnx")
yolo export model=yolo11n-cls.pt format=onnx  # export official model
yolo export model=path/to/best.pt format=onnx # export custom trained model

下表列出了可用的 YOLO11-cls 导出格式。您可以使用以下命令导出为任何格式 format 参数导出为任何格式,例如 format='onnx'format='engine'。您可以直接在导出的模型上进行预测或验证,例如 yolo predict model=yolo11n-cls.onnx。导出完成后,将显示您的模型的使用示例。

格式 format 参数 模型 元数据 参数
PyTorch - yolo11n-cls.pt -
TorchScript torchscript yolo11n-cls.torchscript imgsz, half, dynamic, optimize, nms, batch, device
ONNX onnx yolo11n-cls.onnx imgsz, half, dynamic, simplify, opset, nms, batch, device
OpenVINO openvino yolo11n-cls_openvino_model/ imgsz, half, dynamic, int8, nms, batch, data, fraction, device
TensorRT engine yolo11n-cls.engine imgsz, half, dynamic, simplify, workspace, int8, nms, batch, data, fraction, device
CoreML coreml yolo11n-cls.mlpackage imgsz, half, int8, nms, batch, device
TF SavedModel saved_model yolo11n-cls_saved_model/ imgsz, keras, int8, nms, batch, device
TF GraphDef pb yolo11n-cls.pb imgsz, batch, device
TF Lite tflite yolo11n-cls.tflite imgsz, half, int8, nms, batch, data, fraction, device
TF Edge TPU edgetpu yolo11n-cls_edgetpu.tflite imgsz, device
TF.js tfjs yolo11n-cls_web_model/ imgsz, half, int8, nms, batch, device
PaddlePaddle paddle yolo11n-cls_paddle_model/ imgsz, batch, device
MNN mnn yolo11n-cls.mnn imgsz, batch, int8, half, device
NCNN ncnn yolo11n-cls_ncnn_model/ imgsz, half, batch, device
IMX500 imx yolo11n-cls_imx_model/ imgsz, int8, data, fraction, device
RKNN rknn yolo11n-cls_rknn_model/ imgsz, batch, name, device

查看完整 export 详情请参见 导出 页面。

常见问题

YOLO11 在图像分类中的目的是什么?

YOLO11 模型,例如 yolo11n-cls.pt,专为高效图像分类而设计。它们为整个图像分配单个类别标签以及置信度分数。这对于知道图像的特定类别就足够的应用特别有用,而不是识别图像中物体的具体位置或形状。

如何训练用于图像分类的 YOLO11 模型?

要训练 YOLO11 模型,您可以使用 python 或 CLI 命令。例如,要训练一个 yolo11n-cls 模型在 MNIST160 数据集上进行 100 个 epochs,图像大小为 64:

示例

from ultralytics import YOLO

# Load a model
model = YOLO("yolo11n-cls.pt")  # load a pretrained model (recommended for training)

# Train the model
results = model.train(data="mnist160", epochs=100, imgsz=64)
yolo classify train data=mnist160 model=yolo11n-cls.pt epochs=100 imgsz=64

有关更多配置选项,请访问配置页面。

在哪里可以找到预训练的 YOLO11 分类模型?

预训练的 YOLO11 分类模型可以在 模型 部分找到。例如 yolo11n-cls.pt, yolo11s-cls.pt, yolo11m-cls.pt等模型,均在 ImageNet 数据集上进行了预训练,可以轻松下载并用于各种图像分类任务。

如何将训练好的 YOLO11 模型导出为不同的格式?

您可以使用 python 或 CLI 命令将训练好的 YOLO11 模型导出为各种格式。例如,要将模型导出为 ONNX 格式:

示例

from ultralytics import YOLO

# Load a model
model = YOLO("yolo11n-cls.pt")  # load the trained model

# Export the model to ONNX
model.export(format="onnx")
yolo export model=yolo11n-cls.pt format=onnx # export the trained model to ONNX format

有关详细的导出选项,请参阅导出页面。

如何验证训练好的 YOLO11 分类模型?

要在 MNIST160 等数据集上验证已训练模型的准确性,可以使用以下 python 或 CLI 命令:

示例

from ultralytics import YOLO

# Load a model
model = YOLO("yolo11n-cls.pt")  # load the trained model

# Validate the model
metrics = model.val()  # no arguments needed, uses the dataset and settings from training
metrics.top1  # top1 accuracy
metrics.top5  # top5 accuracy
yolo classify val model=yolo11n-cls.pt # validate the trained model

有关更多信息,请访问验证部分。



📅 创建于 1 年前 ✏️ 更新于 6 天前

评论