सामग्री पर जाएं

के लिए संदर्भ ultralytics/models/yolo/pose/train.py

नोट

यह फ़ाइल यहाँ उपलब्ध है https://github.com/ultralytics/ultralytics/बूँद/मुख्य/ultralytics/मॉडल/yolo/मुद्रा/ट्रेन.py। यदि आप कोई समस्या देखते हैं तो कृपया पुल अनुरोध का योगदान करके इसे ठीक करने में मदद करें 🛠️। 🙏 धन्यवाद !



ultralytics.models.yolo.pose.train.PoseTrainer

का रूप: DetectionTrainer

एक मुद्रा मॉडल के आधार पर प्रशिक्षण के लिए डिटेक्शनट्रेनर वर्ग का विस्तार करने वाला एक वर्ग।

उदाहरण
from ultralytics.models.yolo.pose import PoseTrainer

args = dict(model='yolov8n-pose.pt', data='coco8-pose.yaml', epochs=3)
trainer = PoseTrainer(overrides=args)
trainer.train()
में स्रोत कोड ultralytics/models/yolo/pose/train.py
11बांग्लादेश 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36373839 404142434445 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 686970717273747576 7778 79
class PoseTrainer(yolo.detect.DetectionTrainer):
    """
    A class extending the DetectionTrainer class for training based on a pose model.

    Example:
        ```python
        from ultralytics.models.yolo.pose import PoseTrainer

        args = dict(model='yolov8n-pose.pt', data='coco8-pose.yaml', epochs=3)
        trainer = PoseTrainer(overrides=args)
        trainer.train()
        ```
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """Initialize a PoseTrainer object with specified configurations and overrides."""
        if overrides is None:
            overrides = {}
        overrides["task"] = "pose"
        super().__init__(cfg, overrides, _callbacks)

        if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
            LOGGER.warning(
                "WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
                "See https://github.com/ultralytics/ultralytics/issues/4031."
            )

    def get_model(self, cfg=None, weights=None, verbose=True):
        """Get pose estimation model with specified configuration and weights."""
        model = PoseModel(cfg, ch=3, nc=self.data["nc"], data_kpt_shape=self.data["kpt_shape"], verbose=verbose)
        if weights:
            model.load(weights)

        return model

    def set_model_attributes(self):
        """Sets keypoints shape attribute of PoseModel."""
        super().set_model_attributes()
        self.model.kpt_shape = self.data["kpt_shape"]

    def get_validator(self):
        """Returns an instance of the PoseValidator class for validation."""
        self.loss_names = "box_loss", "pose_loss", "kobj_loss", "cls_loss", "dfl_loss"
        return yolo.pose.PoseValidator(
            self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
        )

    def plot_training_samples(self, batch, ni):
        """Plot a batch of training samples with annotated class labels, bounding boxes, and keypoints."""
        images = batch["img"]
        kpts = batch["keypoints"]
        cls = batch["cls"].squeeze(-1)
        bboxes = batch["bboxes"]
        paths = batch["im_file"]
        batch_idx = batch["batch_idx"]
        plot_images(
            images,
            batch_idx,
            cls,
            bboxes,
            kpts=kpts,
            paths=paths,
            fname=self.save_dir / f"train_batch{ni}.jpg",
            on_plot=self.on_plot,
        )

    def plot_metrics(self):
        """Plots training/val metrics."""
        plot_results(file=self.csv, pose=True, on_plot=self.on_plot)  # save results.png

__init__(cfg=DEFAULT_CFG, overrides=None, _callbacks=None)

निर्दिष्ट कॉन्फ़िगरेशन और ओवरराइड के साथ एक PoseTrainer ऑब्जेक्ट को इनिशियलाइज़ करें।

में स्रोत कोड ultralytics/models/yolo/pose/train.py
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
    """Initialize a PoseTrainer object with specified configurations and overrides."""
    if overrides is None:
        overrides = {}
    overrides["task"] = "pose"
    super().__init__(cfg, overrides, _callbacks)

    if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
        LOGGER.warning(
            "WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
            "See https://github.com/ultralytics/ultralytics/issues/4031."
        )

get_model(cfg=None, weights=None, verbose=True)

निर्दिष्ट कॉन्फ़िगरेशन और वजन के साथ मुद्रा अनुमान मॉडल प्राप्त करें।

में स्रोत कोड ultralytics/models/yolo/pose/train.py
def get_model(self, cfg=None, weights=None, verbose=True):
    """Get pose estimation model with specified configuration and weights."""
    model = PoseModel(cfg, ch=3, nc=self.data["nc"], data_kpt_shape=self.data["kpt_shape"], verbose=verbose)
    if weights:
        model.load(weights)

    return model

get_validator()

सत्यापन के लिए PoseValidator वर्ग का एक उदाहरण देता है।

में स्रोत कोड ultralytics/models/yolo/pose/train.py
def get_validator(self):
    """Returns an instance of the PoseValidator class for validation."""
    self.loss_names = "box_loss", "pose_loss", "kobj_loss", "cls_loss", "dfl_loss"
    return yolo.pose.PoseValidator(
        self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
    )

plot_metrics()

भूखंड प्रशिक्षण / वैल मेट्रिक्स।

में स्रोत कोड ultralytics/models/yolo/pose/train.py
def plot_metrics(self):
    """Plots training/val metrics."""
    plot_results(file=self.csv, pose=True, on_plot=self.on_plot)  # save results.png

plot_training_samples(batch, ni)

एनोटेट वर्ग लेबल, बाउंडिंग बॉक्स और कीपॉइंट के साथ प्रशिक्षण नमूनों के एक बैच को प्लॉट करें।

में स्रोत कोड ultralytics/models/yolo/pose/train.py
58 59 60 61 62 63 64 65 6667 6869 70 71 72 7374 75
def plot_training_samples(self, batch, ni):
    """Plot a batch of training samples with annotated class labels, bounding boxes, and keypoints."""
    images = batch["img"]
    kpts = batch["keypoints"]
    cls = batch["cls"].squeeze(-1)
    bboxes = batch["bboxes"]
    paths = batch["im_file"]
    batch_idx = batch["batch_idx"]
    plot_images(
        images,
        batch_idx,
        cls,
        bboxes,
        kpts=kpts,
        paths=paths,
        fname=self.save_dir / f"train_batch{ni}.jpg",
        on_plot=self.on_plot,
    )

set_model_attributes()

PoseModel की कीपॉइंट आकृति विशेषता सेट करता है।

में स्रोत कोड ultralytics/models/yolo/pose/train.py
def set_model_attributes(self):
    """Sets keypoints shape attribute of PoseModel."""
    super().set_model_attributes()
    self.model.kpt_shape = self.data["kpt_shape"]





2023-11-12 बनाया गया, अपडेट किया गया 2023-11-25
लेखक: ग्लेन-जोचर (3)