Skip to content

Reference for ultralytics/models/yolo/pose/train.py

Note

Full source code for this file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/pose/train.py. Help us fix any issues you see by submitting a Pull Request 🛠️. Thank you 🙏!


ultralytics.models.yolo.pose.train.PoseTrainer

Bases: DetectionTrainer

A class extending the DetectionTrainer class for training pose estimation models.

Example
from ultralytics.models.yolo.pose import PoseTrainer

# Training configuration handed to the trainer as argument overrides.
args = {'model': 'yolov8n-pose.pt', 'data': 'coco8-pose.yaml', 'epochs': 3}
trainer = PoseTrainer(overrides=args)
trainer.train()  # start the training run
Source code in ultralytics/models/yolo/pose/train.py
class PoseTrainer(yolo.detect.DetectionTrainer):
    """
    A class extending the DetectionTrainer class for training pose estimation models.

    Example:
        ```python
        from ultralytics.models.yolo.pose import PoseTrainer

        args = dict(model='yolov8n-pose.pt', data='coco8-pose.yaml', epochs=3)
        trainer = PoseTrainer(overrides=args)
        trainer.train()
        ```
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """
        Initialize a PoseTrainer object with specified configurations and overrides.

        Args:
            cfg: Base configuration forwarded to DetectionTrainer (defaults to ``DEFAULT_CFG``).
            overrides (dict | None): Argument overrides. A shallow copy is made before ``'task'`` is
                forced to ``'pose'``, so the caller's dictionary is left unmodified.
            _callbacks: Optional callbacks forwarded to DetectionTrainer.
        """
        # Copy rather than mutate in place so the caller's dict keeps its original contents.
        overrides = {} if overrides is None else dict(overrides)
        overrides['task'] = 'pose'
        super().__init__(cfg, overrides, _callbacks)

        # Pose models have a known Apple MPS issue (see the linked issue); suggest CPU instead.
        if isinstance(self.args.device, str) and self.args.device.lower() == 'mps':
            LOGGER.warning("WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
                           'See https://github.com/ultralytics/ultralytics/issues/4031.')

    def get_model(self, cfg=None, weights=None, verbose=True):
        """
        Get pose estimation model with specified configuration and weights.

        Args:
            cfg: Model configuration forwarded to PoseModel.
            weights: Optional weights; loaded via ``model.load`` only when truthy.
            verbose (bool): Forwarded to PoseModel.

        Returns:
            The constructed PoseModel, built with 3 input channels and the dataset's
            ``nc`` and ``kpt_shape`` taken from ``self.data``.
        """
        model = PoseModel(cfg, ch=3, nc=self.data['nc'], data_kpt_shape=self.data['kpt_shape'], verbose=verbose)
        if weights:
            model.load(weights)

        return model

    def set_model_attributes(self):
        """Set parent model attributes, then copy the dataset's keypoint shape onto ``self.model.kpt_shape``."""
        super().set_model_attributes()
        self.model.kpt_shape = self.data['kpt_shape']

    def get_validator(self):
        """
        Returns an instance of the PoseValidator class for validation.

        Side effect: sets ``self.loss_names`` to the five pose loss component names.
        The validator receives ``self.test_loader``, ``self.save_dir`` and a copy of ``self.args``.
        """
        self.loss_names = 'box_loss', 'pose_loss', 'kobj_loss', 'cls_loss', 'dfl_loss'
        return yolo.pose.PoseValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))

    def plot_training_samples(self, batch, ni):
        """
        Plot a batch of training samples with annotated class labels, bounding boxes, and keypoints.

        Args:
            batch (dict): Must provide the keys 'img', 'keypoints', 'cls', 'bboxes', 'im_file'
                and 'batch_idx'; 'cls' is squeezed along its last dimension before plotting.
            ni (int): Index used in the output filename ``train_batch{ni}.jpg`` under ``self.save_dir``.
        """
        images = batch['img']
        kpts = batch['keypoints']
        cls = batch['cls'].squeeze(-1)
        bboxes = batch['bboxes']
        paths = batch['im_file']
        batch_idx = batch['batch_idx']
        plot_images(images,
                    batch_idx,
                    cls,
                    bboxes,
                    kpts=kpts,
                    paths=paths,
                    fname=self.save_dir / f'train_batch{ni}.jpg',
                    on_plot=self.on_plot)

    def plot_metrics(self):
        """Plots training/val metrics from ``self.csv`` with pose-specific plotting enabled."""
        plot_results(file=self.csv, pose=True, on_plot=self.on_plot)  # save results.png

__init__(cfg=DEFAULT_CFG, overrides=None, _callbacks=None)

Initialize a PoseTrainer object with specified configurations and overrides.

Source code in ultralytics/models/yolo/pose/train.py
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
    """
    Initialize a PoseTrainer object with specified configurations and overrides.

    Args:
        cfg: Base configuration forwarded to the parent trainer (defaults to ``DEFAULT_CFG``).
        overrides (dict | None): Argument overrides. A shallow copy is made before ``'task'`` is
            forced to ``'pose'``, so the caller's dictionary is left unmodified.
        _callbacks: Optional callbacks forwarded to the parent trainer.
    """
    # Copy rather than mutate in place so the caller's dict keeps its original contents.
    overrides = {} if overrides is None else dict(overrides)
    overrides['task'] = 'pose'
    super().__init__(cfg, overrides, _callbacks)

    # Pose models have a known Apple MPS issue (see the linked issue); suggest CPU instead.
    if isinstance(self.args.device, str) and self.args.device.lower() == 'mps':
        LOGGER.warning("WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
                       'See https://github.com/ultralytics/ultralytics/issues/4031.')

get_model(cfg=None, weights=None, verbose=True)

Get pose estimation model with specified configuration and weights.

Source code in ultralytics/models/yolo/pose/train.py
def get_model(self, cfg=None, weights=None, verbose=True):
    """
    Build a pose estimation model and optionally load weights into it.

    Args:
        cfg: Model configuration forwarded to PoseModel.
        weights: Optional weights; loaded with ``model.load`` only when truthy.
        verbose (bool): Forwarded to PoseModel.

    Returns:
        The PoseModel instance (3 input channels, ``nc`` and ``kpt_shape`` from ``self.data``).
    """
    dataset_info = self.data
    pose_model = PoseModel(cfg,
                           ch=3,
                           nc=dataset_info['nc'],
                           data_kpt_shape=dataset_info['kpt_shape'],
                           verbose=verbose)
    if not weights:
        return pose_model
    pose_model.load(weights)
    return pose_model

get_validator()

Returns an instance of the PoseValidator class for validation.

Source code in ultralytics/models/yolo/pose/train.py
def get_validator(self):
    """
    Create the PoseValidator used during training.

    Side effect: records the five loss component names on ``self.loss_names``.

    Returns:
        A ``yolo.pose.PoseValidator`` built from ``self.test_loader``, ``self.save_dir``
        and ``copy(self.args)``.
    """
    self.loss_names = ('box_loss', 'pose_loss', 'kobj_loss', 'cls_loss', 'dfl_loss')
    validator_args = copy(self.args)
    validator_cls = yolo.pose.PoseValidator
    return validator_cls(self.test_loader, save_dir=self.save_dir, args=validator_args)

plot_metrics()

Plots training/val metrics.

Source code in ultralytics/models/yolo/pose/train.py
def plot_metrics(self):
    """
    Plot training/validation metrics.

    Passes ``self.csv`` to ``plot_results`` with pose-specific plotting enabled.
    """
    results_csv = self.csv
    plot_results(file=results_csv, pose=True, on_plot=self.on_plot)  # save results.png

plot_training_samples(batch, ni)

Plot a batch of training samples with annotated class labels, bounding boxes, and keypoints.

Source code in ultralytics/models/yolo/pose/train.py
def plot_training_samples(self, batch, ni):
    """
    Render one training batch with class labels, boxes and keypoints drawn on the images.

    Args:
        batch (dict): Must contain 'img', 'keypoints', 'cls', 'bboxes', 'im_file' and 'batch_idx';
            'cls' is squeezed along its last dimension before plotting.
        ni (int): Index used to name the output file ``train_batch{ni}.jpg`` inside ``self.save_dir``.
    """
    images, kpts = batch['img'], batch['keypoints']
    cls = batch['cls'].squeeze(-1)
    bboxes, paths, batch_idx = batch['bboxes'], batch['im_file'], batch['batch_idx']
    output_file = self.save_dir / f'train_batch{ni}.jpg'
    plot_images(images, batch_idx, cls, bboxes, kpts=kpts, paths=paths, fname=output_file, on_plot=self.on_plot)

set_model_attributes()

Sets the keypoint shape attribute of the PoseModel.

Source code in ultralytics/models/yolo/pose/train.py
def set_model_attributes(self):
    """
    Apply the parent class's model attributes, then attach the dataset keypoint shape.

    After the parent method runs, ``self.model.kpt_shape`` is set from ``self.data['kpt_shape']``.
    """
    super().set_model_attributes()
    keypoint_shape = self.data['kpt_shape']
    self.model.kpt_shape = keypoint_shape




Created 2023-07-16, Updated 2023-08-20
Authors: glenn-jocher (6), Laughing-q (1)