Skip to content

Reference for ultralytics/models/yolo/pose/train.py

Note

This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/pose/train.py. If you spot a problem please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!


ultralytics.models.yolo.pose.train.PoseTrainer

PoseTrainer(cfg=DEFAULT_CFG, overrides=None, _callbacks=None)

Bases: DetectionTrainer

A class extending the DetectionTrainer class for training YOLO pose estimation models.

This trainer specializes in handling pose estimation tasks, managing model training, validation, and visualization of pose keypoints alongside bounding boxes.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| args | namespace | Configuration arguments for training (accessed by attribute, e.g. `args.device`). |
| model | PoseModel | The pose estimation model being trained. |
| data | dict | Dataset configuration including keypoint shape information. |
| loss_names | Tuple[str, ...] | Names of the loss components used in training. |

Methods:

| Name | Description |
| --- | --- |
| get_model | Retrieves a pose estimation model with specified configuration. |
| set_model_attributes | Sets keypoints shape attribute on the model. |
| get_validator | Creates a validator instance for model evaluation. |
| plot_training_samples | Visualizes training samples with keypoints. |
| plot_metrics | Generates and saves training/validation metric plots. |

Examples:

>>> from ultralytics.models.yolo.pose import PoseTrainer
>>> args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml", epochs=3)
>>> trainer = PoseTrainer(overrides=args)
>>> trainer.train()
Source code in ultralytics/models/yolo/pose/train.py
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
    """
    Initialize a PoseTrainer object with specified configurations and overrides.

    Args:
        cfg: Base training configuration, forwarded unchanged to DetectionTrainer. Defaults to DEFAULT_CFG.
        overrides (dict, optional): Configuration overrides. A shallow copy is taken and its "task" entry is
            forced to "pose"; the caller's dict is left unmodified.
        _callbacks (optional): Callbacks forwarded unchanged to DetectionTrainer.
    """
    # Copy instead of mutating in place: writing "task" into the caller's dict would leak task="pose"
    # into any later reuse of that same dict (e.g. passing it to a DetectionTrainer afterwards).
    overrides = {} if overrides is None else dict(overrides)
    overrides["task"] = "pose"
    super().__init__(cfg, overrides, _callbacks)

    # Only a warning: the user's device choice is respected, they are just pointed at the known issue.
    if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
        LOGGER.warning(
            "WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
            "See https://github.com/ultralytics/ultralytics/issues/4031."
        )

get_model

get_model(cfg=None, weights=None, verbose=True)

Get pose estimation model with specified configuration and weights.

Source code in ultralytics/models/yolo/pose/train.py
def get_model(self, cfg=None, weights=None, verbose=True):
    """
    Get pose estimation model with specified configuration and weights.

    Args:
        cfg (optional): Model configuration passed straight to PoseModel.
        weights (optional): If truthy, loaded into the freshly built model via ``PoseModel.load``.
        verbose (bool): Forwarded to PoseModel.

    Returns:
        (PoseModel): The constructed model, with ``nc`` and ``data_kpt_shape`` taken from ``self.data``.
    """
    pose_model = PoseModel(
        cfg,
        ch=3,
        nc=self.data["nc"],
        data_kpt_shape=self.data["kpt_shape"],
        verbose=verbose,
    )
    if not weights:
        return pose_model
    pose_model.load(weights)
    return pose_model

get_validator

get_validator()

Return a PoseValidator instance for validation. Also sets `loss_names` to the five pose-training loss component names.

Source code in ultralytics/models/yolo/pose/train.py
def get_validator(self):
    """
    Returns an instance of the PoseValidator class for validation.

    Side effect: sets ``self.loss_names`` to the five loss component names used for pose training.

    Returns:
        (PoseValidator): Validator built on ``self.test_loader`` with a shallow copy of ``self.args``.
    """
    # NOTE(review): this order presumably mirrors the loss vector produced by the pose loss — confirm.
    self.loss_names = ("box_loss", "pose_loss", "kobj_loss", "cls_loss", "dfl_loss")
    validator_args = copy(self.args)  # shallow copy so the validator can't mutate the trainer's args object
    return yolo.pose.PoseValidator(
        self.test_loader,
        save_dir=self.save_dir,
        args=validator_args,
        _callbacks=self.callbacks,
    )

plot_metrics

plot_metrics()

Plot training and validation metrics.

Source code in ultralytics/models/yolo/pose/train.py
def plot_metrics(self):
    """
    Plot training and validation metrics recorded in ``self.csv``.

    Calls ``plot_results`` with ``pose=True`` and forwards ``self.on_plot``; the output image is results.png.
    """
    metrics_csv = self.csv
    plot_results(file=metrics_csv, pose=True, on_plot=self.on_plot)  # writes results.png

plot_training_samples

plot_training_samples(batch, ni)

Plot a batch of training samples with annotated class labels, bounding boxes, and keypoints.

Source code in ultralytics/models/yolo/pose/train.py
def plot_training_samples(self, batch, ni):
    """
    Plot a batch of training samples with annotated class labels, bounding boxes, and keypoints.

    Args:
        batch (dict): Training batch; the "img", "keypoints", "cls", "bboxes", "im_file" and
            "batch_idx" entries are read.
        ni (int): Iteration index, used in the output file name ``train_batch{ni}.jpg`` under ``self.save_dir``.
    """
    out_file = self.save_dir / f"train_batch{ni}.jpg"
    plot_images(
        batch["img"],
        batch["batch_idx"],
        # squeeze(-1) drops the trailing axis; presumably cls arrives as (N, 1) — confirm against the dataloader
        batch["cls"].squeeze(-1),
        batch["bboxes"],
        kpts=batch["keypoints"],
        paths=batch["im_file"],
        fname=out_file,
        on_plot=self.on_plot,
    )

set_model_attributes

set_model_attributes()

Sets keypoints shape attribute of PoseModel.

Source code in ultralytics/models/yolo/pose/train.py
def set_model_attributes(self):
    """
    Sets keypoints shape attribute of PoseModel.

    Runs the parent class's attribute setup first, then copies ``self.data["kpt_shape"]`` onto ``self.model``.
    """
    super().set_model_attributes()
    keypoint_shape = self.data["kpt_shape"]
    self.model.kpt_shape = keypoint_shape



📅 Created 1 year ago ✏️ Updated 6 months ago