Reference for ultralytics/models/yolo/pose/train.py

Note

This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/pose/train.py. If you spot a problem, please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!


ultralytics.models.yolo.pose.train.PoseTrainer

PoseTrainer(cfg=DEFAULT_CFG, overrides=None, _callbacks=None)

Bases: DetectionTrainer

A class extending the DetectionTrainer class for training YOLO pose estimation models.

This trainer specializes in handling pose estimation tasks, managing model training, validation, and visualization of pose keypoints alongside bounding boxes.

Attributes:

args (dict): Configuration arguments for training.
model (PoseModel): The pose estimation model being trained.
data (dict): Dataset configuration including keypoint shape information.
loss_names (Tuple[str]): Names of the loss components used in training.

Methods:

get_model: Retrieves a pose estimation model with specified configuration.
set_model_attributes: Sets keypoints shape attribute on the model.
get_validator: Creates a validator instance for model evaluation.
plot_training_samples: Visualizes training samples with keypoints.
plot_metrics: Generates and saves training/validation metric plots.

Examples:

>>> from ultralytics.models.yolo.pose import PoseTrainer
>>> args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml", epochs=3)
>>> trainer = PoseTrainer(overrides=args)
>>> trainer.train()

This initializes a trainer specialized for pose estimation tasks, setting the task to 'pose' and handling specific configurations needed for keypoint detection models.

Parameters:

cfg (dict): Default configuration dictionary containing training parameters. Default: DEFAULT_CFG
overrides (dict): Dictionary of parameter overrides for the default configuration. Default: None
_callbacks (list): List of callback functions to be executed during training. Default: None

Notes

This trainer always sets the task to 'pose', overwriting any task provided in overrides. When the device is the string 'mps' (Apple MPS), it logs a warning about a known bug with pose models and recommends device=cpu; see https://github.com/ultralytics/ultralytics/issues/4031.
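
The two behaviors described in this note can be sketched without Ultralytics installed. The helper names below (`force_pose_task`, `is_mps_device`) are hypothetical and only mirror the logic visible in `__init__`; they are not part of the library. Unlike the original, which writes into the `overrides` dict it receives, this sketch returns a copy.

```python
def force_pose_task(overrides=None):
    """Return a copy of overrides with task forced to 'pose' (hypothetical helper)."""
    merged = dict(overrides or {})
    merged["task"] = "pose"  # any user-supplied task is overwritten
    return merged


def is_mps_device(device):
    """Mirror the check that triggers the warning: a string equal to 'mps', ignoring case."""
    return isinstance(device, str) and device.lower() == "mps"


overrides = force_pose_task({"task": "detect", "epochs": 3})
print(overrides["task"])  # the user-supplied 'detect' has been replaced
print(is_mps_device("MPS"), is_mps_device(0))
```

Because the check requires a string, a non-string device such as an integer GPU index never triggers the warning.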

Source code in ultralytics/models/yolo/pose/train.py
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
    """
    Initialize a PoseTrainer object for training YOLO pose estimation models.

    This initializes a trainer specialized for pose estimation tasks, setting the task to 'pose' and
    handling specific configurations needed for keypoint detection models.

    Args:
        cfg (dict, optional): Default configuration dictionary containing training parameters.
        overrides (dict, optional): Dictionary of parameter overrides for the default configuration.
        _callbacks (list, optional): List of callback functions to be executed during training.

    Notes:
        This trainer will automatically set the task to 'pose' regardless of what is provided in overrides.
        A warning is issued when using Apple MPS device due to known bugs with pose models.

    Examples:
        >>> from ultralytics.models.yolo.pose import PoseTrainer
        >>> args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml", epochs=3)
        >>> trainer = PoseTrainer(overrides=args)
        >>> trainer.train()
    """
    if overrides is None:
        overrides = {}
    overrides["task"] = "pose"
    super().__init__(cfg, overrides, _callbacks)

    if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
        LOGGER.warning(
            "Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
            "See https://github.com/ultralytics/ultralytics/issues/4031."
        )

get_dataset

get_dataset()

Retrieves the dataset and ensures it contains the required kpt_shape key.

Returns:

dict: A dictionary containing the training/validation/test dataset and category names.

Raises:

KeyError: If the kpt_shape key is not present in the dataset.
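
A minimal sketch of this check on plain dictionaries follows. The function name `require_kpt_shape` and the sample dictionaries are illustrative assumptions; in the real trainer the dictionary comes from the parent class's `get_dataset()`.

```python
def require_kpt_shape(data, source="dataset.yaml"):
    """Return data unchanged if it defines kpt_shape, otherwise raise KeyError (hypothetical helper)."""
    if "kpt_shape" not in data:
        raise KeyError(f"No `kpt_shape` in the {source}. See https://docs.ultralytics.com/datasets/pose/")
    return data


# Illustrative dataset dictionaries, not loaded from any real YAML file.
pose_data = {"nc": 1, "names": {0: "person"}, "kpt_shape": [17, 3]}
detect_data = {"nc": 1, "names": {0: "person"}}

print(require_kpt_shape(pose_data)["kpt_shape"])
try:
    require_kpt_shape(detect_data, source="my-detect.yaml")
except KeyError as err:
    print("rejected:", err)
```

Failing early here means a detection-only dataset is rejected before the model is built, rather than surfacing later as a missing key in `get_model`.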

Source code in ultralytics/models/yolo/pose/train.py
def get_dataset(self):
    """
    Retrieves the dataset and ensures it contains the required `kpt_shape` key.

    Returns:
        (dict): A dictionary containing the training/validation/test dataset and category names.

    Raises:
        KeyError: If the `kpt_shape` key is not present in the dataset.
    """
    data = super().get_dataset()
    if "kpt_shape" not in data:
        raise KeyError(f"No `kpt_shape` in the {self.args.data}. See https://docs.ultralytics.com/datasets/pose/")
    return data

get_model

get_model(cfg=None, weights=None, verbose=True)

Get pose estimation model with specified configuration and weights.

Parameters:

cfg (str | Path | dict | None): Model configuration file path or dictionary. Default: None
weights (str | Path | None): Path to the model weights file. Default: None
verbose (bool): Whether to display model information. Default: True

Returns:

PoseModel: Initialized pose estimation model.

Source code in ultralytics/models/yolo/pose/train.py
def get_model(self, cfg=None, weights=None, verbose=True):
    """
    Get pose estimation model with specified configuration and weights.

    Args:
        cfg (str | Path | dict | None): Model configuration file path or dictionary.
        weights (str | Path | None): Path to the model weights file.
        verbose (bool): Whether to display model information.

    Returns:
        (PoseModel): Initialized pose estimation model.
    """
    model = PoseModel(
        cfg, nc=self.data["nc"], ch=self.data["channels"], data_kpt_shape=self.data["kpt_shape"], verbose=verbose
    )
    if weights:
        model.load(weights)

    return model

get_validator

get_validator()

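Returns an instance of the PoseValidator class for validation. As a side effect, it sets loss_names to the five pose loss components: box_loss, pose_loss, kobj_loss, cls_loss, and dfl_loss.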

Source code in ultralytics/models/yolo/pose/train.py
def get_validator(self):
    """Returns an instance of the PoseValidator class for validation."""
    self.loss_names = "box_loss", "pose_loss", "kobj_loss", "cls_loss", "dfl_loss"
    return yolo.pose.PoseValidator(
        self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
    )
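
The `loss_names` tuple set above gives each component of the loss a stable label. As a rough sketch of how such names could be paired with per-component values for logging, consider the helper below. The function `label_losses`, the `prefix/name` key format, and the sample numbers are assumptions made for illustration, not the library's confirmed behavior.

```python
LOSS_NAMES = ("box_loss", "pose_loss", "kobj_loss", "cls_loss", "dfl_loss")


def label_losses(values, prefix="train"):
    """Pair each loss value with its name, in order (hypothetical helper)."""
    if len(values) != len(LOSS_NAMES):
        raise ValueError(f"expected {len(LOSS_NAMES)} loss values, got {len(values)}")
    return {f"{prefix}/{name}": round(float(v), 5) for name, v in zip(LOSS_NAMES, values)}


# Made-up numbers, used only to show the pairing.
labeled = label_losses([1.2, 3.4, 0.5, 0.9, 1.1])
for key, value in labeled.items():
    print(key, value)
```

The order of the names matters: each value is matched to a name purely by position.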

plot_metrics

plot_metrics()

Plots training/validation metrics from the trainer's results CSV with pose metrics enabled, saving the figure as results.png.

Source code in ultralytics/models/yolo/pose/train.py
def plot_metrics(self):
    """Plots training/val metrics."""
    plot_results(file=self.csv, pose=True, on_plot=self.on_plot)  # save results.png

plot_training_samples

plot_training_samples(batch, ni)

Plot a batch of training samples with annotated class labels, bounding boxes, and keypoints.

Parameters:

batch (dict): Dictionary containing batch data with the following keys:
- img (torch.Tensor): Batch of images
- keypoints (torch.Tensor): Keypoints coordinates for pose estimation
- cls (torch.Tensor): Class labels
- bboxes (torch.Tensor): Bounding box coordinates
- im_file (list): List of image file paths
- batch_idx (torch.Tensor): Batch indices for each instance
Required.
ni (int): Current training iteration number used for filename. Required.

The function saves the plotted batch as an image in the trainer's save directory with the filename 'train_batch{ni}.jpg', where ni is the iteration number.
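
Two details of this method can be sketched with the standard library alone: how the output path is built, and how `batch_idx` relates instances to images. Both helper names and the `runs/pose/train` directory are hypothetical. Reading `batch_idx` as "the index of the image each instance belongs to" is an assumption based on the parameter description above.

```python
from pathlib import Path


def train_batch_filename(save_dir, ni):
    """Build the output path the same way the method does: save_dir / f'train_batch{ni}.jpg'."""
    return Path(save_dir) / f"train_batch{ni}.jpg"


def instances_per_image(batch_idx, num_images):
    """Count how many labeled instances fall in each image of the batch."""
    counts = [0] * num_images
    for i in batch_idx:
        counts[int(i)] += 1
    return counts


print(train_batch_filename("runs/pose/train", 0).name)
# Four instances spread over a batch of four images; image 2 has no labels.
print(instances_per_image([0, 0, 1, 3], num_images=4))
```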

Source code in ultralytics/models/yolo/pose/train.py
def plot_training_samples(self, batch, ni):
    """
    Plot a batch of training samples with annotated class labels, bounding boxes, and keypoints.

    Args:
        batch (dict): Dictionary containing batch data with the following keys:
            - img (torch.Tensor): Batch of images
            - keypoints (torch.Tensor): Keypoints coordinates for pose estimation
            - cls (torch.Tensor): Class labels
            - bboxes (torch.Tensor): Bounding box coordinates
            - im_file (list): List of image file paths
            - batch_idx (torch.Tensor): Batch indices for each instance
        ni (int): Current training iteration number used for filename

    The function saves the plotted batch as an image in the trainer's save directory with the filename
    'train_batch{ni}.jpg', where ni is the iteration number.
    """
    images = batch["img"]
    kpts = batch["keypoints"]
    cls = batch["cls"].squeeze(-1)
    bboxes = batch["bboxes"]
    paths = batch["im_file"]
    batch_idx = batch["batch_idx"]
    plot_images(
        images,
        batch_idx,
        cls,
        bboxes,
        kpts=kpts,
        paths=paths,
        fname=self.save_dir / f"train_batch{ni}.jpg",
        on_plot=self.on_plot,
    )

set_model_attributes

set_model_attributes()

Sets the kpt_shape attribute of the PoseModel from the dataset configuration, after applying the parent class's model attributes.

Source code in ultralytics/models/yolo/pose/train.py
def set_model_attributes(self):
    """Sets keypoints shape attribute of PoseModel."""
    super().set_model_attributes()
    self.model.kpt_shape = self.data["kpt_shape"]
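
For readers unfamiliar with `kpt_shape`, the sketch below shows one way to interpret it. It assumes the convention of `[number of keypoints, values per keypoint]`, where 2 values mean (x, y) and 3 mean (x, y, visibility). The helper name is hypothetical, and the restriction to 2 or 3 values is an assumption of this sketch.

```python
def keypoint_values_per_instance(kpt_shape):
    """Return how many numbers describe the keypoints of one instance (hypothetical helper)."""
    num_kpts, num_dims = kpt_shape
    if num_dims not in (2, 3):
        raise ValueError(f"expected 2 (x, y) or 3 (x, y, visibility) values per keypoint, got {num_dims}")
    return num_kpts * num_dims


print(keypoint_values_per_instance([17, 3]))  # 17 keypoints with x, y, visibility
print(keypoint_values_per_instance([17, 2]))  # 17 keypoints with x, y only
```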





📅 Created 1 year ago ✏️ Updated 8 months ago