Bỏ để qua phần nội dung

Tài liệu tham khảo cho ultralytics/models/yolo/classify/train.py

Ghi

Tệp này có sẵn tại https://github.com/ultralytics/ultralytics/blob/main/ultralytics/Mô hình/yolo/phân loại/train.py. Nếu bạn phát hiện ra một vấn đề, vui lòng giúp khắc phục nó bằng cách đóng góp Yêu cầu 🛠️ kéo. Cảm ơn bạn 🙏 !



ultralytics.models.yolo.classify.train.ClassificationTrainer

Căn cứ: BaseTrainer

Một lớp mở rộng lớp BaseTrainer để đào tạo dựa trên mô hình phân loại.

Ghi chú
  • Các mô hình phân loại Torchvision cũng có thể được chuyển đến đối số 'model', tức là model = 'resnet18'.
Ví dụ
from ultralytics.models.yolo.classify import ClassificationTrainer

args = dict(model='yolov8n-cls.pt', data='imagenet10', epochs=3)
trainer = ClassificationTrainer(overrides=args)
trainer.train()
Mã nguồn trong ultralytics/models/yolo/classify/train.py
class ClassificationTrainer(BaseTrainer):
    """
    A class extending the BaseTrainer class for training based on a classification model.

    Notes:
        - Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.

    Example:
        ```python
        from ultralytics.models.yolo.classify import ClassificationTrainer

        args = dict(model='yolov8n-cls.pt', data='imagenet10', epochs=3)
        trainer = ClassificationTrainer(overrides=args)
        trainer.train()
        ```
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """Initialize a ClassificationTrainer object with optional configuration overrides and callbacks."""
        if overrides is None:
            overrides = {}
        overrides["task"] = "classify"
        if overrides.get("imgsz") is None:
            overrides["imgsz"] = 224
        super().__init__(cfg, overrides, _callbacks)

    def set_model_attributes(self):
        """Set the YOLO model's class names from the loaded dataset."""
        self.model.names = self.data["names"]

    def get_model(self, cfg=None, weights=None, verbose=True):
        """Returns a modified PyTorch model configured for training YOLO."""
        model = ClassificationModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
        if weights:
            model.load(weights)

        for m in model.modules():
            if not self.args.pretrained and hasattr(m, "reset_parameters"):
                m.reset_parameters()
            if isinstance(m, torch.nn.Dropout) and self.args.dropout:
                m.p = self.args.dropout  # set dropout
        for p in model.parameters():
            p.requires_grad = True  # for training
        return model

    def setup_model(self):
        """Load, create or download model for any task."""
        import torchvision  # scope for faster 'import ultralytics'

        if isinstance(self.model, torch.nn.Module):  # if model is loaded beforehand. No setup needed
            return

        model, ckpt = str(self.model), None
        # Load a YOLO model locally, from torchvision, or from Ultralytics assets
        if model.endswith(".pt"):
            self.model, ckpt = attempt_load_one_weight(model, device="cpu")
            for p in self.model.parameters():
                p.requires_grad = True  # for training
        elif model.split(".")[-1] in {"yaml", "yml"}:
            self.model = self.get_model(cfg=model)
        elif model in torchvision.models.__dict__:
            self.model = torchvision.models.__dict__[model](weights="IMAGENET1K_V1" if self.args.pretrained else None)
        else:
            raise FileNotFoundError(f"ERROR: model={model} not found locally or online. Please check model name.")
        ClassificationModel.reshape_outputs(self.model, self.data["nc"])

        return ckpt

    def build_dataset(self, img_path, mode="train", batch=None):
        """Creates a ClassificationDataset instance given an image path, and mode (train/test etc.)."""
        return ClassificationDataset(root=img_path, args=self.args, augment=mode == "train", prefix=mode)

    def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
        """Returns PyTorch DataLoader with transforms to preprocess images for inference."""
        with torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP
            dataset = self.build_dataset(dataset_path, mode)

        loader = build_dataloader(dataset, batch_size, self.args.workers, rank=rank)
        # Attach inference transforms
        if mode != "train":
            if is_parallel(self.model):
                self.model.module.transforms = loader.dataset.torch_transforms
            else:
                self.model.transforms = loader.dataset.torch_transforms
        return loader

    def preprocess_batch(self, batch):
        """Preprocesses a batch of images and classes."""
        batch["img"] = batch["img"].to(self.device)
        batch["cls"] = batch["cls"].to(self.device)
        return batch

    def progress_string(self):
        """Returns a formatted string showing training progress."""
        return ("\n" + "%11s" * (4 + len(self.loss_names))) % (
            "Epoch",
            "GPU_mem",
            *self.loss_names,
            "Instances",
            "Size",
        )

    def get_validator(self):
        """Returns an instance of ClassificationValidator for validation."""
        self.loss_names = ["loss"]
        return yolo.classify.ClassificationValidator(self.test_loader, self.save_dir, _callbacks=self.callbacks)

    def label_loss_items(self, loss_items=None, prefix="train"):
        """
        Returns a loss dict with labelled training loss items tensor.

        Not needed for classification but necessary for segmentation & detection
        """
        keys = [f"{prefix}/{x}" for x in self.loss_names]
        if loss_items is None:
            return keys
        loss_items = [round(float(loss_items), 5)]
        return dict(zip(keys, loss_items))

    def plot_metrics(self):
        """Plots metrics from a CSV file."""
        plot_results(file=self.csv, classify=True, on_plot=self.on_plot)  # save results.png

    def final_eval(self):
        """Evaluate trained model and save validation results."""
        for f in self.last, self.best:
            if f.exists():
                strip_optimizer(f)  # strip optimizers
                if f is self.best:
                    LOGGER.info(f"\nValidating {f}...")
                    self.validator.args.data = self.args.data
                    self.validator.args.plots = self.args.plots
                    self.metrics = self.validator(model=f)
                    self.metrics.pop("fitness", None)
                    self.run_callbacks("on_fit_epoch_end")
        LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")

    def plot_training_samples(self, batch, ni):
        """Plots training samples with their annotations."""
        plot_images(
            images=batch["img"],
            batch_idx=torch.arange(len(batch["img"])),
            cls=batch["cls"].view(-1),  # warning: use .view(), not .squeeze() for Classify models
            fname=self.save_dir / f"train_batch{ni}.jpg",
            on_plot=self.on_plot,
        )

__init__(cfg=DEFAULT_CFG, overrides=None, _callbacks=None)

Khởi tạo một đối tượng ClassificationTrainer với các ghi đè cấu hình tùy chọn và gọi lại.

Mã nguồn trong ultralytics/models/yolo/classify/train.py
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
    """Initialize a ClassificationTrainer object with optional configuration overrides and callbacks."""
    if overrides is None:
        overrides = {}
    overrides["task"] = "classify"
    if overrides.get("imgsz") is None:
        overrides["imgsz"] = 224
    super().__init__(cfg, overrides, _callbacks)

build_dataset(img_path, mode='train', batch=None)

Tạo một phiên bản ClassificationDataset được cung cấp một đường dẫn hình ảnh và chế độ (đào tạo / kiểm tra, v.v.).

Mã nguồn trong ultralytics/models/yolo/classify/train.py
def build_dataset(self, img_path, mode="train", batch=None):
    """Creates a ClassificationDataset instance given an image path, and mode (train/test etc.)."""
    return ClassificationDataset(root=img_path, args=self.args, augment=mode == "train", prefix=mode)

final_eval()

Đánh giá mô hình được đào tạo và lưu kết quả xác nhận.

Mã nguồn trong ultralytics/models/yolo/classify/train.py
def final_eval(self):
    """Evaluate trained model and save validation results."""
    for f in self.last, self.best:
        if f.exists():
            strip_optimizer(f)  # strip optimizers
            if f is self.best:
                LOGGER.info(f"\nValidating {f}...")
                self.validator.args.data = self.args.data
                self.validator.args.plots = self.args.plots
                self.metrics = self.validator(model=f)
                self.metrics.pop("fitness", None)
                self.run_callbacks("on_fit_epoch_end")
    LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")

get_dataloader(dataset_path, batch_size=16, rank=0, mode='train')

Trở lại PyTorch DataLoader với các biến đổi để xử lý trước hình ảnh để suy luận.

Mã nguồn trong ultralytics/models/yolo/classify/train.py
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
    """Returns PyTorch DataLoader with transforms to preprocess images for inference."""
    with torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP
        dataset = self.build_dataset(dataset_path, mode)

    loader = build_dataloader(dataset, batch_size, self.args.workers, rank=rank)
    # Attach inference transforms
    if mode != "train":
        if is_parallel(self.model):
            self.model.module.transforms = loader.dataset.torch_transforms
        else:
            self.model.transforms = loader.dataset.torch_transforms
    return loader

get_model(cfg=None, weights=None, verbose=True)

Trả về đã sửa đổi PyTorch Mô hình được cấu hình để đào tạo YOLO.

Mã nguồn trong ultralytics/models/yolo/classify/train.py
def get_model(self, cfg=None, weights=None, verbose=True):
    """Returns a modified PyTorch model configured for training YOLO."""
    model = ClassificationModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
    if weights:
        model.load(weights)

    for m in model.modules():
        if not self.args.pretrained and hasattr(m, "reset_parameters"):
            m.reset_parameters()
        if isinstance(m, torch.nn.Dropout) and self.args.dropout:
            m.p = self.args.dropout  # set dropout
    for p in model.parameters():
        p.requires_grad = True  # for training
    return model

get_validator()

Trả về một phiên bản của ClassificationValidator để xác thực.

Mã nguồn trong ultralytics/models/yolo/classify/train.py
def get_validator(self):
    """Returns an instance of ClassificationValidator for validation."""
    self.loss_names = ["loss"]
    return yolo.classify.ClassificationValidator(self.test_loader, self.save_dir, _callbacks=self.callbacks)

label_loss_items(loss_items=None, prefix='train')

Trả về lệnh tổn thất với các mục tổn thất huấn luyện được gắn nhãn tensor.

Không cần thiết để phân loại nhưng cần thiết để phân đoạn và phát hiện

Mã nguồn trong ultralytics/models/yolo/classify/train.py
def label_loss_items(self, loss_items=None, prefix="train"):
    """
    Returns a loss dict with labelled training loss items tensor.

    Not needed for classification but necessary for segmentation & detection
    """
    keys = [f"{prefix}/{x}" for x in self.loss_names]
    if loss_items is None:
        return keys
    loss_items = [round(float(loss_items), 5)]
    return dict(zip(keys, loss_items))

plot_metrics()

Biểu thị số liệu từ tệp CSV.

Mã nguồn trong ultralytics/models/yolo/classify/train.py
def plot_metrics(self):
    """Plots metrics from a CSV file."""
    plot_results(file=self.csv, classify=True, on_plot=self.on_plot)  # save results.png

plot_training_samples(batch, ni)

Vẽ mẫu đào tạo với chú thích của họ.

Mã nguồn trong ultralytics/models/yolo/classify/train.py
def plot_training_samples(self, batch, ni):
    """Plots training samples with their annotations."""
    plot_images(
        images=batch["img"],
        batch_idx=torch.arange(len(batch["img"])),
        cls=batch["cls"].view(-1),  # warning: use .view(), not .squeeze() for Classify models
        fname=self.save_dir / f"train_batch{ni}.jpg",
        on_plot=self.on_plot,
    )

preprocess_batch(batch)

Xử lý trước một loạt các hình ảnh và lớp.

Mã nguồn trong ultralytics/models/yolo/classify/train.py
def preprocess_batch(self, batch):
    """Preprocesses a batch of images and classes."""
    batch["img"] = batch["img"].to(self.device)
    batch["cls"] = batch["cls"].to(self.device)
    return batch

progress_string()

Trả về một chuỗi được định dạng hiển thị tiến trình đào tạo.

Mã nguồn trong ultralytics/models/yolo/classify/train.py
def progress_string(self):
    """Returns a formatted string showing training progress."""
    return ("\n" + "%11s" * (4 + len(self.loss_names))) % (
        "Epoch",
        "GPU_mem",
        *self.loss_names,
        "Instances",
        "Size",
    )

set_model_attributes()

Đặt YOLO Tên lớp của model từ tập dữ liệu đã tải.

Mã nguồn trong ultralytics/models/yolo/classify/train.py
def set_model_attributes(self):
    """Set the YOLO model's class names from the loaded dataset."""
    self.model.names = self.data["names"]

setup_model()

Tải, tạo hoặc tải xuống mô hình cho bất kỳ tác vụ nào.

Mã nguồn trong ultralytics/models/yolo/classify/train.py
def setup_model(self):
    """Load, create or download model for any task."""
    import torchvision  # scope for faster 'import ultralytics'

    if isinstance(self.model, torch.nn.Module):  # if model is loaded beforehand. No setup needed
        return

    model, ckpt = str(self.model), None
    # Load a YOLO model locally, from torchvision, or from Ultralytics assets
    if model.endswith(".pt"):
        self.model, ckpt = attempt_load_one_weight(model, device="cpu")
        for p in self.model.parameters():
            p.requires_grad = True  # for training
    elif model.split(".")[-1] in {"yaml", "yml"}:
        self.model = self.get_model(cfg=model)
    elif model in torchvision.models.__dict__:
        self.model = torchvision.models.__dict__[model](weights="IMAGENET1K_V1" if self.args.pretrained else None)
    else:
        raise FileNotFoundError(f"ERROR: model={model} not found locally or online. Please check model name.")
    ClassificationModel.reshape_outputs(self.model, self.data["nc"])

    return ckpt





Created 2023-11-12, Updated 2024-06-02
Authors: glenn-jocher (5), Burhan-Q (1)