सामग्री पर जाएं

ultralytics/models/yolo/classify/train.py के लिए संदर्भ

नोट

यह फ़ाइल https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/classify/train.py पर उपलब्ध है। यदि आपको कोई समस्या दिखाई दे, तो कृपया पुल अनुरोध (pull request) भेजकर इसे ठीक करने में मदद करें 🛠️। 🙏 धन्यवाद!



ultralytics.models.yolo.classify.train.ClassificationTrainer

आधार वर्ग: BaseTrainer

वर्गीकरण (classification) मॉडल के प्रशिक्षण के लिए BaseTrainer वर्ग का विस्तार करने वाला वर्ग।

नोट्स
  • Torchvision वर्गीकरण मॉडल भी 'model' तर्क में पारित किए जा सकते हैं, जैसे model='resnet18'।
उदाहरण
from ultralytics.models.yolo.classify import ClassificationTrainer

# Train the yolov8n-cls checkpoint on the imagenet10 dataset for 3 epochs.
args = {"model": "yolov8n-cls.pt", "data": "imagenet10", "epochs": 3}
trainer = ClassificationTrainer(overrides=args)
trainer.train()
ultralytics/models/yolo/classify/train.py में स्रोत कोड
class ClassificationTrainer(BaseTrainer):
    """
    A class extending the BaseTrainer class for training based on a classification model.

    Notes:
        - Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.

    Example:
        ```python
        from ultralytics.models.yolo.classify import ClassificationTrainer

        args = dict(model='yolov8n-cls.pt', data='imagenet10', epochs=3)
        trainer = ClassificationTrainer(overrides=args)
        trainer.train()
        ```
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """
        Initialize a ClassificationTrainer object with optional configuration overrides and callbacks.

        Args:
            cfg: Base configuration forwarded to BaseTrainer (defaults to DEFAULT_CFG).
            overrides (dict | None): Configuration overrides. The mapping is copied, so the caller's dict is
                never modified.
            _callbacks: Optional callbacks forwarded to BaseTrainer.

        Notes:
            ``task`` is always forced to "classify", and ``imgsz`` defaults to 224 when it is missing or None.
        """
        # Work on a copy: writing "task"/"imgsz" into the caller's dict would leak these defaults back into
        # an object the caller may reuse for another run.
        overrides = {} if overrides is None else dict(overrides)
        overrides["task"] = "classify"
        if overrides.get("imgsz") is None:
            overrides["imgsz"] = 224
        super().__init__(cfg, overrides, _callbacks)

    def set_model_attributes(self):
        """Set the YOLO model's class names (``self.model.names``) from the loaded dataset (``self.data["names"]``)."""
        self.model.names = self.data["names"]

    def get_model(self, cfg=None, weights=None, verbose=True):
        """
        Return a ClassificationModel configured for training.

        Args:
            cfg: Model configuration forwarded to ClassificationModel.
            weights: Optional weights, loaded with ``model.load`` when truthy.
            verbose (bool): Verbosity flag forwarded to ClassificationModel; only takes effect when RANK == -1.

        Returns:
            The model built with ``self.data["nc"]`` classes. Parameters are re-initialized when
            ``self.args.pretrained`` is falsy, Dropout probabilities are overridden when ``self.args.dropout``
            is truthy, and every parameter requires gradients.
        """
        model = ClassificationModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
        if weights:
            model.load(weights)

        for m in model.modules():
            if not self.args.pretrained and hasattr(m, "reset_parameters"):
                m.reset_parameters()  # not pretrained: re-initialize this module's parameters
            if isinstance(m, torch.nn.Dropout) and self.args.dropout:
                m.p = self.args.dropout  # set dropout
        for p in model.parameters():
            p.requires_grad = True  # for training
        return model

    def setup_model(self):
        """
        Load, create or download model for any task.

        ``self.model`` is resolved as follows:
          - already a ``torch.nn.Module``: nothing to do, returns None;
          - ``*.pt``: loaded on CPU with attempt_load_one_weight and made fully trainable;
          - ``*.yaml`` / ``*.yml``: built from the config via ``self.get_model``;
          - a name in ``torchvision.models``: instantiated, with IMAGENET1K_V1 weights if ``self.args.pretrained``.
        In the last three cases ``ClassificationModel.reshape_outputs`` is applied with ``self.data["nc"]``.

        Returns:
            The checkpoint returned by attempt_load_one_weight for ``*.pt`` files, otherwise None.

        Raises:
            FileNotFoundError: If ``self.model`` matches none of the cases above.
        """
        import torchvision  # scope for faster 'import ultralytics'

        if isinstance(self.model, torch.nn.Module):  # if model is loaded beforehand. No setup needed
            return

        model, ckpt = str(self.model), None
        # Load a YOLO model locally, from torchvision, or from Ultralytics assets
        if model.endswith(".pt"):
            self.model, ckpt = attempt_load_one_weight(model, device="cpu")
            for p in self.model.parameters():
                p.requires_grad = True  # for training
        elif model.split(".")[-1] in {"yaml", "yml"}:
            self.model = self.get_model(cfg=model)
        elif model in torchvision.models.__dict__:
            self.model = torchvision.models.__dict__[model](weights="IMAGENET1K_V1" if self.args.pretrained else None)
        else:
            raise FileNotFoundError(f"ERROR: model={model} not found locally or online. Please check model name.")
        ClassificationModel.reshape_outputs(self.model, self.data["nc"])

        return ckpt

    def build_dataset(self, img_path, mode="train", batch=None):
        """
        Create a ClassificationDataset for ``img_path``.

        Augmentation is enabled only when ``mode == "train"``; ``mode`` is also passed as the prefix.
        ``batch`` is not used here.
        """
        return ClassificationDataset(root=img_path, args=self.args, augment=mode == "train", prefix=mode)

    def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
        """
        Return a PyTorch DataLoader for ``dataset_path``.

        For any mode other than "train", the dataset's ``torch_transforms`` are attached to the model (to
        ``self.model.module`` when the model is parallel) for use at inference time.
        """
        with torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP
            dataset = self.build_dataset(dataset_path, mode)

        loader = build_dataloader(dataset, batch_size, self.args.workers, rank=rank)
        # Attach inference transforms
        if mode != "train":
            if is_parallel(self.model):
                self.model.module.transforms = loader.dataset.torch_transforms
            else:
                self.model.transforms = loader.dataset.torch_transforms
        return loader

    def preprocess_batch(self, batch):
        """Move ``batch["img"]`` and ``batch["cls"]`` to ``self.device``; the dict is updated in place and returned."""
        batch["img"] = batch["img"].to(self.device)
        batch["cls"] = batch["cls"].to(self.device)
        return batch

    def progress_string(self):
        """
        Return the header line for the training progress display.

        Columns are "Epoch", "GPU_mem", each name in ``self.loss_names``, "Instances" and "Size", each
        right-aligned to 11 characters and preceded by a newline.
        """
        return ("\n" + "%11s" * (4 + len(self.loss_names))) % (
            "Epoch",
            "GPU_mem",
            *self.loss_names,
            "Instances",
            "Size",
        )

    def get_validator(self):
        """Set ``self.loss_names`` to ["loss"] and return a ClassificationValidator over ``self.test_loader``."""
        self.loss_names = ["loss"]
        return yolo.classify.ClassificationValidator(self.test_loader, self.save_dir, _callbacks=self.callbacks)

    def label_loss_items(self, loss_items=None, prefix="train"):
        """
        Returns a loss dict with labelled training loss items tensor.

        Not needed for classification but necessary for segmentation & detection.

        Args:
            loss_items: A single loss value (anything ``float()`` accepts), or None.
            prefix (str): Joined to each name in ``self.loss_names`` as "{prefix}/{name}".

        Returns:
            The list of prefixed keys when ``loss_items`` is None; otherwise a dict pairing the first key with
            the loss rounded to 5 decimal places.
        """
        keys = [f"{prefix}/{x}" for x in self.loss_names]
        if loss_items is None:
            return keys
        loss_items = [round(float(loss_items), 5)]
        return dict(zip(keys, loss_items))

    def plot_metrics(self):
        """Plot metrics from ``self.csv`` using the classification layout."""
        plot_results(file=self.csv, classify=True, on_plot=self.on_plot)  # save results.png

    def final_eval(self):
        """
        Strip optimizer state from the saved checkpoints and validate the best one.

        ``self.last`` and ``self.best`` are processed in that order and missing files are skipped. Only
        ``self.best`` is validated; ``self.metrics`` is then updated (without "fitness") and the
        "on_fit_epoch_end" callbacks are run.
        """
        for f in self.last, self.best:
            if f.exists():
                strip_optimizer(f)  # strip optimizers
                if f is self.best:
                    LOGGER.info(f"\nValidating {f}...")
                    self.validator.args.data = self.args.data
                    self.validator.args.plots = self.args.plots
                    self.metrics = self.validator(model=f)
                    self.metrics.pop("fitness", None)
                    self.run_callbacks("on_fit_epoch_end")
        LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")

    def plot_training_samples(self, batch, ni):
        """Plot a training batch with its class labels to ``self.save_dir / f"train_batch{ni}.jpg"``."""
        plot_images(
            images=batch["img"],
            batch_idx=torch.arange(len(batch["img"])),
            cls=batch["cls"].view(-1),  # warning: use .view(), not .squeeze() for Classify models
            fname=self.save_dir / f"train_batch{ni}.jpg",
            on_plot=self.on_plot,
        )

__init__(cfg=DEFAULT_CFG, overrides=None, _callbacks=None)

वैकल्पिक कॉन्फ़िगरेशन ओवरराइड और कॉलबैक के साथ एक ClassificationTrainer ऑब्जेक्ट प्रारंभ करें।

ultralytics/models/yolo/classify/train.py में स्रोत कोड
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
    """
    Initialize a ClassificationTrainer object with optional configuration overrides and callbacks.

    Args:
        cfg: Base configuration forwarded to BaseTrainer (defaults to DEFAULT_CFG).
        overrides (dict | None): Configuration overrides. The mapping is copied, so the caller's dict is
            never modified.
        _callbacks: Optional callbacks forwarded to BaseTrainer.

    Notes:
        ``task`` is always forced to "classify", and ``imgsz`` defaults to 224 when it is missing or None.
    """
    # Work on a copy: writing "task"/"imgsz" into the caller's dict would leak these defaults back into
    # an object the caller may reuse for another run.
    overrides = {} if overrides is None else dict(overrides)
    overrides["task"] = "classify"
    if overrides.get("imgsz") is None:
        overrides["imgsz"] = 224
    super().__init__(cfg, overrides, _callbacks)

build_dataset(img_path, mode='train', batch=None)

दिए गए छवि पथ और मोड (train/test आदि) के आधार पर एक ClassificationDataset इंस्टेंस बनाता है।

ultralytics/models/yolo/classify/train.py में स्रोत कोड
def build_dataset(self, img_path, mode="train", batch=None):
    """
    Build a ClassificationDataset rooted at ``img_path``.

    Augmentation is switched on only for the "train" mode, and ``mode`` doubles as the prefix.
    The ``batch`` argument is accepted but not used.
    """
    use_augmentation = mode == "train"
    return ClassificationDataset(root=img_path, args=self.args, augment=use_augmentation, prefix=mode)

final_eval()

प्रशिक्षित मॉडल का मूल्यांकन करें और सत्यापन परिणामों को सहेजें।

ultralytics/models/yolo/classify/train.py में स्रोत कोड
def final_eval(self):
    """
    Strip optimizer state from the saved checkpoints and re-validate the best one.

    ``self.last`` and ``self.best`` are visited in that order and missing files are skipped. Only
    ``self.best`` is validated; ``self.metrics`` is then refreshed (minus "fitness") and the
    "on_fit_epoch_end" callbacks are fired.
    """
    for ckpt_path in (self.last, self.best):
        if not ckpt_path.exists():
            continue
        strip_optimizer(ckpt_path)  # drop optimizer state from the saved file
        if ckpt_path is not self.best:
            continue
        LOGGER.info(f"\nValidating {ckpt_path}...")
        validator_args = self.validator.args
        validator_args.data = self.args.data
        validator_args.plots = self.args.plots
        self.metrics = self.validator(model=ckpt_path)
        self.metrics.pop("fitness", None)
        self.run_callbacks("on_fit_epoch_end")
    LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")

get_dataloader(dataset_path, batch_size=16, rank=0, mode='train')

अनुमान (inference) के लिए छवियों को प्रीप्रोसेस करने वाले ट्रांसफ़ॉर्म के साथ PyTorch DataLoader लौटाता है।

ultralytics/models/yolo/classify/train.py में स्रोत कोड
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
    """
    Build a PyTorch DataLoader over ``dataset_path``.

    For every mode except "train", the dataset's ``torch_transforms`` are attached to the model (to the
    wrapped ``module`` when the model is parallel) so they can be reused for inference preprocessing.
    """
    # build the dataset (and its *.cache) only once under DDP
    with torch_distributed_zero_first(rank):
        dataset = self.build_dataset(dataset_path, mode)

    loader = build_dataloader(dataset, batch_size, self.args.workers, rank=rank)
    if mode == "train":
        return loader

    target = self.model.module if is_parallel(self.model) else self.model
    target.transforms = loader.dataset.torch_transforms
    return loader

get_model(cfg=None, weights=None, verbose=True)

YOLO प्रशिक्षण के लिए कॉन्फ़िगर किया गया एक संशोधित PyTorch मॉडल लौटाता है।

ultralytics/models/yolo/classify/train.py में स्रोत कोड
def get_model(self, cfg=None, weights=None, verbose=True):
    """
    Build a ClassificationModel with ``self.data["nc"]`` classes and prepare it for training.

    Args:
        cfg: Model configuration forwarded to ClassificationModel.
        weights: Optional weights, loaded via ``model.load`` when truthy.
        verbose (bool): Verbosity flag forwarded to ClassificationModel; only honoured when RANK == -1.

    Returns:
        The model, with parameters re-initialized when ``self.args.pretrained`` is falsy, Dropout
        probabilities overridden when ``self.args.dropout`` is truthy, and all parameters trainable.
    """
    model = ClassificationModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
    if weights:
        model.load(weights)

    reinit = not self.args.pretrained
    dropout = self.args.dropout
    for module in model.modules():
        if reinit and hasattr(module, "reset_parameters"):
            module.reset_parameters()
        if dropout and isinstance(module, torch.nn.Dropout):
            module.p = dropout
    for param in model.parameters():
        param.requires_grad = True  # make every weight trainable
    return model

get_validator()

सत्यापन के लिए ClassificationValidator का एक इंस्टेंस लौटाता है।

ultralytics/models/yolo/classify/train.py में स्रोत कोड
def get_validator(self):
    """Record the single tracked loss name and return a ClassificationValidator over ``self.test_loader``."""
    self.loss_names = ["loss"]
    validator_cls = yolo.classify.ClassificationValidator
    return validator_cls(self.test_loader, self.save_dir, _callbacks=self.callbacks)

label_loss_items(loss_items=None, prefix='train')

लेबल किए गए प्रशिक्षण हानि आइटम (tensor) के साथ हानि डिक्ट लौटाता है।

वर्गीकरण के लिए आवश्यक नहीं है, लेकिन विभाजन (segmentation) और पहचान (detection) के लिए आवश्यक है।

ultralytics/models/yolo/classify/train.py में स्रोत कोड
def label_loss_items(self, loss_items=None, prefix="train"):
    """
    Map a loss value onto prefixed loss names.

    Not needed for classification but necessary for segmentation & detection.

    Args:
        loss_items: A single loss value (anything ``float()`` accepts), or None.
        prefix (str): Joined to each entry of ``self.loss_names`` as "{prefix}/{name}".

    Returns:
        The list of prefixed names when ``loss_items`` is None; otherwise a dict pairing the first
        name with the loss rounded to 5 decimal places.
    """
    labels = [f"{prefix}/{name}" for name in self.loss_names]
    if loss_items is not None:
        rounded = round(float(loss_items), 5)
        return dict(zip(labels, [rounded]))
    return labels

plot_metrics()

CSV फ़ाइल से मीट्रिक प्लॉट करता है।

ultralytics/models/yolo/classify/train.py में स्रोत कोड
def plot_metrics(self):
    """Render the metrics logged in ``self.csv`` (classification layout) into results.png."""
    plot_results(on_plot=self.on_plot, classify=True, file=self.csv)

plot_training_samples(batch, ni)

प्रशिक्षण नमूनों को उनके एनोटेशन के साथ प्लॉट करता है।

ultralytics/models/yolo/classify/train.py में स्रोत कोड
def plot_training_samples(self, batch, ni):
    """
    Plot a training batch together with its class labels.

    The image is written to ``self.save_dir / f"train_batch{ni}.jpg"``; every sample gets its own
    batch index (0..len-1).
    """
    images = batch["img"]
    plot_images(
        images=images,
        batch_idx=torch.arange(len(images)),
        # use .view(-1), not .squeeze(): squeeze() would collapse a single-sample batch to a 0-d tensor
        cls=batch["cls"].view(-1),
        fname=self.save_dir / f"train_batch{ni}.jpg",
        on_plot=self.on_plot,
    )

preprocess_batch(batch)

छवियों और वर्गों (classes) के एक बैच को प्रीप्रोसेस करता है।

ultralytics/models/yolo/classify/train.py में स्रोत कोड
def preprocess_batch(self, batch):
    """Move the "img" and "cls" entries of ``batch`` onto ``self.device``; the dict is updated in place and returned."""
    for key in ("img", "cls"):
        batch[key] = batch[key].to(self.device)
    return batch

progress_string()

प्रशिक्षण प्रगति दिखाने वाली एक स्वरूपित स्ट्रिंग लौटाता है।

ultralytics/models/yolo/classify/train.py में स्रोत कोड
def progress_string(self):
    """
    Return the formatted header line for the training progress display.

    Columns are "Epoch", "GPU_mem", each name in ``self.loss_names``, "Instances" and "Size", each
    right-aligned in an 11-character field, preceded by a newline.
    """
    columns = ("Epoch", "GPU_mem", *self.loss_names, "Instances", "Size")
    return "\n" + "".join("%11s" % (col,) for col in columns)

set_model_attributes()

लोड किए गए डेटासेट से YOLO मॉडल के वर्ग नाम (class names) सेट करता है।

ultralytics/models/yolo/classify/train.py में स्रोत कोड
def set_model_attributes(self):
    """Copy the dataset's class names (``self.data["names"]``) onto the YOLO model as ``self.model.names``."""
    dataset_names = self.data["names"]
    self.model.names = dataset_names

setup_model()

किसी भी कार्य के लिए मॉडल लोड, बनाएं या डाउनलोड करें।

ultralytics/models/yolo/classify/train.py में स्रोत कोड
def setup_model(self):
    """
    Load, create or download model for any task.

    ``self.model`` (as a string) is resolved in this order:
      - already a ``torch.nn.Module``: nothing to do, returns None;
      - ``*.pt``: loaded on CPU with attempt_load_one_weight and made fully trainable;
      - ``*.yaml`` / ``*.yml``: built from the config via ``self.get_model``;
      - a name in ``torchvision.models``: instantiated, with IMAGENET1K_V1 weights if ``self.args.pretrained``.
    In the last three cases ``ClassificationModel.reshape_outputs`` is applied with ``self.data["nc"]``.

    Returns:
        The checkpoint from attempt_load_one_weight for ``*.pt`` files, otherwise None.

    Raises:
        FileNotFoundError: If ``self.model`` matches none of the cases above.
    """
    import torchvision  # scope for faster 'import ultralytics'

    if isinstance(self.model, torch.nn.Module):  # model already built, no setup needed
        return

    ckpt = None
    name = str(self.model)
    suffix = name.split(".")[-1]
    zoo = torchvision.models.__dict__
    if name.endswith(".pt"):
        self.model, ckpt = attempt_load_one_weight(name, device="cpu")
        for param in self.model.parameters():
            param.requires_grad = True  # make every weight trainable
    elif suffix in {"yaml", "yml"}:
        self.model = self.get_model(cfg=name)
    elif name in zoo:
        weights = "IMAGENET1K_V1" if self.args.pretrained else None
        self.model = zoo[name](weights=weights)
    else:
        raise FileNotFoundError(f"ERROR: model={name} not found locally or online. Please check model name.")
    ClassificationModel.reshape_outputs(self.model, self.data["nc"])
    return ckpt





Created 2023-11-12, Updated 2024-06-02
Authors: glenn-jocher (5), Burhan-Q (1)