انتقل إلى المحتوى

مرجع ل ultralytics/models/yolo/pose/train.py

ملاحظه

هذا الملف متاح في https://github.com/ultralytics/ultralytics/ نقطة / الرئيسية /ultralytics/نماذج/yolo/ تشكل / قطار .py. إذا اكتشفت مشكلة ، فيرجى المساعدة في إصلاحها من خلال المساهمة في طلب 🛠️ سحب. شكرا لك 🙏!



ultralytics.models.yolo.pose.train.PoseTrainer

قواعد: DetectionTrainer

فئة تمدد فئة DetectionTrainer للتدريب بناء على نموذج الوضع.

مثل
from ultralytics.models.yolo.pose import PoseTrainer

args = dict(model='yolov8n-pose.pt', data='coco8-pose.yaml', epochs=3)
trainer = PoseTrainer(overrides=args)
trainer.train()
شفرة المصدر في ultralytics/models/yolo/pose/train.py
11 12 13 14 15 16 17 18 1920 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 3637383940 4142 4344 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 686970717273747576 7778 79
class PoseTrainer(yolo.detect.DetectionTrainer):
    """
    A class extending the DetectionTrainer class for training based on a pose model.

    Example:
        ```python
        from ultralytics.models.yolo.pose import PoseTrainer

        args = dict(model='yolov8n-pose.pt', data='coco8-pose.yaml', epochs=3)
        trainer = PoseTrainer(overrides=args)
        trainer.train()
        ```
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """Initialize a PoseTrainer object with specified configurations and overrides."""
        if overrides is None:
            overrides = {}
        overrides["task"] = "pose"
        super().__init__(cfg, overrides, _callbacks)

        if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
            LOGGER.warning(
                "WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
                "See https://github.com/ultralytics/ultralytics/issues/4031."
            )

    def get_model(self, cfg=None, weights=None, verbose=True):
        """Get pose estimation model with specified configuration and weights."""
        model = PoseModel(cfg, ch=3, nc=self.data["nc"], data_kpt_shape=self.data["kpt_shape"], verbose=verbose)
        if weights:
            model.load(weights)

        return model

    def set_model_attributes(self):
        """Sets keypoints shape attribute of PoseModel."""
        super().set_model_attributes()
        self.model.kpt_shape = self.data["kpt_shape"]

    def get_validator(self):
        """Returns an instance of the PoseValidator class for validation."""
        self.loss_names = "box_loss", "pose_loss", "kobj_loss", "cls_loss", "dfl_loss"
        return yolo.pose.PoseValidator(
            self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
        )

    def plot_training_samples(self, batch, ni):
        """Plot a batch of training samples with annotated class labels, bounding boxes, and keypoints."""
        images = batch["img"]
        kpts = batch["keypoints"]
        cls = batch["cls"].squeeze(-1)
        bboxes = batch["bboxes"]
        paths = batch["im_file"]
        batch_idx = batch["batch_idx"]
        plot_images(
            images,
            batch_idx,
            cls,
            bboxes,
            kpts=kpts,
            paths=paths,
            fname=self.save_dir / f"train_batch{ni}.jpg",
            on_plot=self.on_plot,
        )

    def plot_metrics(self):
        """Plots training/val metrics."""
        plot_results(file=self.csv, pose=True, on_plot=self.on_plot)  # save results.png

__init__(cfg=DEFAULT_CFG, overrides=None, _callbacks=None)

تهيئة كائن PoseTrainer مع تكوينات وتجاوزات محددة.

شفرة المصدر في ultralytics/models/yolo/pose/train.py
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
    """Initialize a PoseTrainer object with specified configurations and overrides."""
    if overrides is None:
        overrides = {}
    overrides["task"] = "pose"
    super().__init__(cfg, overrides, _callbacks)

    if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
        LOGGER.warning(
            "WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
            "See https://github.com/ultralytics/ultralytics/issues/4031."
        )

get_model(cfg=None, weights=None, verbose=True)

احصل على نموذج تقدير الوضع مع التكوين والأوزان المحددة.

شفرة المصدر في ultralytics/models/yolo/pose/train.py
def get_model(self, cfg=None, weights=None, verbose=True):
    """Get pose estimation model with specified configuration and weights."""
    model = PoseModel(cfg, ch=3, nc=self.data["nc"], data_kpt_shape=self.data["kpt_shape"], verbose=verbose)
    if weights:
        model.load(weights)

    return model

get_validator()

إرجاع مثيل من الفئة بوز فاليكاتور للتحقق من الصحة.

شفرة المصدر في ultralytics/models/yolo/pose/train.py
def get_validator(self):
    """Returns an instance of the PoseValidator class for validation."""
    self.loss_names = "box_loss", "pose_loss", "kobj_loss", "cls_loss", "dfl_loss"
    return yolo.pose.PoseValidator(
        self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
    )

plot_metrics()

مؤامرات التدريب / مقاييس فال.

شفرة المصدر في ultralytics/models/yolo/pose/train.py
def plot_metrics(self):
    """Plots training/val metrics."""
    plot_results(file=self.csv, pose=True, on_plot=self.on_plot)  # save results.png

plot_training_samples(batch, ni)

ارسم مجموعة من عينات التدريب باستخدام تسميات الفصل المشروحة والمربعات المحيطة والنقاط الرئيسية.

شفرة المصدر في ultralytics/models/yolo/pose/train.py
58 59 60 61 62 63 64 65 66676869 707172737475
def plot_training_samples(self, batch, ni):
    """Plot a batch of training samples with annotated class labels, bounding boxes, and keypoints."""
    images = batch["img"]
    kpts = batch["keypoints"]
    cls = batch["cls"].squeeze(-1)
    bboxes = batch["bboxes"]
    paths = batch["im_file"]
    batch_idx = batch["batch_idx"]
    plot_images(
        images,
        batch_idx,
        cls,
        bboxes,
        kpts=kpts,
        paths=paths,
        fname=self.save_dir / f"train_batch{ni}.jpg",
        on_plot=self.on_plot,
    )

set_model_attributes()

يضبط سمة شكل النقاط الرئيسية ل PoseModel.

شفرة المصدر في ultralytics/models/yolo/pose/train.py
def set_model_attributes(self):
    """Sets keypoints shape attribute of PoseModel."""
    super().set_model_attributes()
    self.model.kpt_shape = self.data["kpt_shape"]





تم إنشاء 2023-11-12, اخر تحديث 2023-11-25
المؤلفون: جلين جوشر (3)