Skip to content

Reference for ultralytics/models/yolo/pose/train.py

Note

This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/pose/train.py. If you spot a problem please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!


ultralytics.models.yolo.pose.train.PoseTrainer

PoseTrainer(
    cfg=DEFAULT_CFG, overrides: Optional[Dict[str, Any]] = None, _callbacks=None
)

Bases: DetectionTrainer

A class extending the DetectionTrainer class for training YOLO pose estimation models.

This trainer specializes in handling pose estimation tasks, managing model training, validation, and visualization of pose keypoints alongside bounding boxes.

Attributes:

Name Type Description
args dict

Configuration arguments for training.

model PoseModel

The pose estimation model being trained.

data dict

Dataset configuration including keypoint shape information.

loss_names tuple

Names of the loss components used in training.

Methods:

Name Description
get_model

Retrieve a pose estimation model with specified configuration.

set_model_attributes

Set the keypoint shape attribute on the model.

get_validator

Create a validator instance for model evaluation.

plot_training_samples

Visualize training samples with keypoints.

plot_metrics

Generate and save training/validation metric plots.

get_dataset

Retrieve the dataset and ensure it contains the required `kpt_shape` key.

Examples:

>>> from ultralytics.models.yolo.pose import PoseTrainer
>>> args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml", epochs=3)
>>> trainer = PoseTrainer(overrides=args)
>>> trainer.train()

Initialize a PoseTrainer object for training YOLO pose estimation models. This initializes a trainer specialized for pose estimation tasks, setting the task to 'pose' and handling specific configurations needed for keypoint detection models.

Parameters:

Name Type Description Default
cfg dict

Default configuration dictionary containing training parameters.

DEFAULT_CFG
overrides dict

Dictionary of parameter overrides for the default configuration.

None
_callbacks list

List of callback functions to be executed during training.

None
Notes

This trainer will automatically set the task to 'pose' regardless of what is provided in overrides. A warning is issued when using an Apple MPS device due to known bugs with pose models.

Examples:

>>> from ultralytics.models.yolo.pose import PoseTrainer
>>> args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml", epochs=3)
>>> trainer = PoseTrainer(overrides=args)
>>> trainer.train()
Source code in ultralytics/models/yolo/pose/train.py (lines 41–72)
def __init__(self, cfg=DEFAULT_CFG, overrides: Optional[Dict[str, Any]] = None, _callbacks=None):
    """
    Initialize a PoseTrainer object for training YOLO pose estimation models.

    This initializes a trainer specialized for pose estimation tasks, setting the task to 'pose' and
    handling specific configurations needed for keypoint detection models.

    Args:
        cfg (dict, optional): Default configuration dictionary containing training parameters.
        overrides (dict, optional): Dictionary of parameter overrides for the default configuration. It is
            copied before 'task' is set, so the caller's dictionary is never modified.
        _callbacks (list, optional): List of callback functions to be executed during training.

    Notes:
        This trainer will automatically set the task to 'pose' regardless of what is provided in overrides.
        A warning is issued when using an Apple MPS device due to known bugs with pose models.

    Examples:
        >>> from ultralytics.models.yolo.pose import PoseTrainer
        >>> args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml", epochs=3)
        >>> trainer = PoseTrainer(overrides=args)
        >>> trainer.train()
    """
    # Build a new dict instead of writing into the caller's: assigning overrides["task"] in place would
    # leak task="pose" into a dict the caller may reuse (e.g. for another trainer of a different task).
    overrides = {**(overrides or {}), "task": "pose"}
    super().__init__(cfg, overrides, _callbacks)

    if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
        LOGGER.warning(
            "Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
            "See https://github.com/ultralytics/ultralytics/issues/4031."
        )

get_dataset

get_dataset() -> Dict[str, Any]

Retrieve the dataset and ensure it contains the required kpt_shape key.

Returns:

Type Description
dict

A dictionary containing the training/validation/test dataset and category names.

Raises:

Type Description
KeyError

If the kpt_shape key is not present in the dataset.

Source code in ultralytics/models/yolo/pose/train.py (lines 115–128)
def get_dataset(self) -> Dict[str, Any]:
    """
    Load the dataset via the parent trainer and verify that it defines `kpt_shape`.

    Returns:
        (dict): The dataset dictionary produced by the parent trainer (training/validation/test dataset and
            category names), returned unchanged.

    Raises:
        KeyError: If the dataset dictionary has no `kpt_shape` entry.
    """
    dataset = super().get_dataset()
    if "kpt_shape" in dataset:
        return dataset
    # The model and its attributes are built from data["kpt_shape"], so fail early with a pointer to the docs.
    raise KeyError(f"No `kpt_shape` in the {self.args.data}. See https://docs.ultralytics.com/datasets/pose/")

get_model

get_model(
    cfg: Optional[Union[str, Path, Dict[str, Any]]] = None,
    weights: Optional[Union[str, Path]] = None,
    verbose: bool = True,
) -> PoseModel

Get pose estimation model with specified configuration and weights.

Parameters:

Name Type Description Default
cfg str | Path | dict

Model configuration file path or dictionary.

None
weights str | Path

Path to the model weights file.

None
verbose bool

Whether to display model information.

True

Returns:

Type Description
PoseModel

Initialized pose estimation model.

Source code in ultralytics/models/yolo/pose/train.py (lines 74–97)
def get_model(
    self,
    cfg: Optional[Union[str, Path, Dict[str, Any]]] = None,
    weights: Optional[Union[str, Path]] = None,
    verbose: bool = True,
) -> PoseModel:
    """
    Build a PoseModel sized from the current dataset and optionally load weights into it.

    Args:
        cfg (str | Path | dict, optional): Model configuration file path or dictionary.
        weights (str | Path, optional): Path to the model weights file; nothing is loaded when falsy.
        verbose (bool): Whether to display model information.

    Returns:
        (PoseModel): Initialized pose estimation model.
    """
    dataset = self.data
    # Class count, input channels and keypoint layout all come from the dataset configuration.
    pose_model = PoseModel(
        cfg,
        nc=dataset["nc"],
        ch=dataset["channels"],
        data_kpt_shape=dataset["kpt_shape"],
        verbose=verbose,
    )
    if not weights:
        return pose_model
    pose_model.load(weights)
    return pose_model

get_validator

get_validator()

Return an instance of the PoseValidator class for validation.

Source code in ultralytics/models/yolo/pose/train.py (lines 104–109)
def get_validator(self):
    """
    Return an instance of the PoseValidator class for validation.

    Also sets `self.loss_names` to the labels of the five loss components tracked by this trainer.

    Returns:
        (PoseValidator): Validator built on `self.test_loader` that saves its output to `self.save_dir`.
    """
    self.loss_names = ("box_loss", "pose_loss", "kobj_loss", "cls_loss", "dfl_loss")
    # Give the validator its own (shallow) copy of the args, presumably so that any settings it changes are
    # not written back to the trainer's args.
    validator_args = copy(self.args)
    return yolo.pose.PoseValidator(
        self.test_loader,
        save_dir=self.save_dir,
        args=validator_args,
        _callbacks=self.callbacks,
    )

plot_metrics

plot_metrics()

Plot training/validation metrics.

Source code in ultralytics/models/yolo/pose/train.py (lines 111–113)
def plot_metrics(self):
    """Plot training/validation metrics read from `self.csv`, with pose plotting enabled."""
    # pose=True requests the pose variant of the results plot; the figure is saved as results.png.
    plot_results(pose=True, file=self.csv, on_plot=self.on_plot)

set_model_attributes

set_model_attributes()

Set the keypoint shape (`kpt_shape`) attribute of the PoseModel.

Source code in ultralytics/models/yolo/pose/train.py (lines 99–102)
def set_model_attributes(self):
    """Apply the parent trainer's model attributes, then copy the dataset's `kpt_shape` onto the model."""
    super().set_model_attributes()
    # Keep the keypoint layout on the model itself in addition to the dataset configuration.
    kpt_shape = self.data["kpt_shape"]
    self.model.kpt_shape = kpt_shape





📅 Created 1 year ago ✏️ Updated 11 months ago