Reference for ultralytics/models/yolo/pose/train.py
Note
This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/pose/train.py. If you spot a problem please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!
ultralytics.models.yolo.pose.train.PoseTrainer
PoseTrainer(cfg=DEFAULT_CFG, overrides=None, _callbacks=None)
Bases: DetectionTrainer
A class extending the DetectionTrainer class for training YOLO pose estimation models.
This trainer specializes in handling pose estimation tasks, managing model training, validation, and visualization of pose keypoints alongside bounding boxes.
Attributes:
Name | Type | Description |
---|---|---|
args | dict | Configuration arguments for training. |
model | PoseModel | The pose estimation model being trained. |
data | dict | Dataset configuration including keypoint shape information. |
loss_names | Tuple[str] | Names of the loss components used in training. |
Methods:
Name | Description |
---|---|
get_model | Retrieves a pose estimation model with specified configuration. |
set_model_attributes | Sets keypoints shape attribute on the model. |
get_validator | Creates a validator instance for model evaluation. |
plot_training_samples | Visualizes training samples with keypoints. |
plot_metrics | Generates and saves training/validation metric plots. |
Examples:
>>> from ultralytics.models.yolo.pose import PoseTrainer
>>> args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml", epochs=3)
>>> trainer = PoseTrainer(overrides=args)
>>> trainer.train()
The constructor initializes a trainer specialized for pose estimation: it sets the task to 'pose' and handles the configuration specific to keypoint detection models.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
cfg | dict | Default configuration dictionary containing training parameters. | DEFAULT_CFG |
overrides | dict | Dictionary of parameter overrides for the default configuration. | None |
_callbacks | list | List of callback functions to be executed during training. | None |
Notes
This trainer always sets the task to 'pose', regardless of what is provided in overrides. It issues a warning when an Apple MPS device is used, because of known bugs with pose models on that backend.
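The behaviour described in the note can be illustrated with a minimal, library-free sketch. The helper names `prepare_pose_overrides` and `warn_if_mps` are hypothetical and are not part of ultralytics; the device check (a case-insensitive string comparison against "mps") is an assumption about how such a check might look, not a statement of the actual implementation.

```python
import warnings


def prepare_pose_overrides(overrides=None):
    """Return a copy of the overrides with the task forced to 'pose'.

    Hypothetical helper for illustration only; not an ultralytics API.
    """
    merged = dict(overrides or {})
    merged["task"] = "pose"  # any user-supplied task value is replaced
    return merged


def warn_if_mps(device):
    """Warn when the requested device is Apple MPS; return True if a warning was issued.

    Assumption: the device is given as a string such as "mps", "cpu" or "0".
    """
    if isinstance(device, str) and device.lower() == "mps":
        warnings.warn("Apple MPS has known bugs with pose models.")
        return True
    return False


overrides = prepare_pose_overrides({"task": "detect", "epochs": 3})
print(overrides["task"])  # pose
```

Copying the overrides instead of mutating the caller's dictionary is a design choice of this sketch; it keeps the caller's arguments unchanged.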
get_dataset
get_dataset()
Retrieves the dataset and ensures it contains the required kpt_shape key.
Returns:
Type | Description |
---|---|
dict | A dictionary containing the training/validation/test dataset and category names. |
Raises:
Type | Description |
---|---|
KeyError | If the kpt_shape key is not present in the dataset. |
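The check that get_dataset is described as performing can be sketched without the library. The function name `require_kpt_shape` is hypothetical, and the example value `[17, 3]` (17 keypoints, each with x, y and visibility) is an illustrative COCO-style shape, not something this page specifies.

```python
def require_kpt_shape(data):
    """Return the dataset dictionary unchanged if it has a 'kpt_shape' entry.

    Hypothetical helper for illustration; raises KeyError when the entry is missing.
    """
    if "kpt_shape" not in data:
        raise KeyError("dataset configuration has no 'kpt_shape' entry")
    return data


# Illustrative dataset dictionary (assumed values).
data = {"names": {0: "person"}, "kpt_shape": [17, 3]}
checked = require_kpt_shape(data)
```

Failing early with a KeyError makes a missing keypoint shape visible when the dataset is loaded, before the model is built from it.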
get_model
get_model(cfg=None, weights=None, verbose=True)
Get pose estimation model with specified configuration and weights.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
cfg | str | Path | dict | None | Model configuration file path or dictionary. | None |
weights | str | Path | None | Path to the model weights file. | None |
verbose | bool | Whether to display model information. | True |
Returns:
Type | Description |
---|---|
PoseModel | Initialized pose estimation model. |
get_validator
get_validator()
Returns an instance of the PoseValidator class for validation.
plot_metrics
plot_metrics()
Plots training and validation metrics.
plot_training_samples
plot_training_samples(batch, ni)
Plot a batch of training samples with annotated class labels, bounding boxes, and keypoints.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
batch | dict | Dictionary containing batch data with the following keys: img (torch.Tensor): batch of images; keypoints (torch.Tensor): keypoint coordinates for pose estimation; cls (torch.Tensor): class labels; bboxes (torch.Tensor): bounding box coordinates; im_file (list): list of image file paths; batch_idx (torch.Tensor): batch indices for each instance. | required |
ni | int | Current training iteration number, used in the output filename. | required |
The function saves the plotted batch as an image in the trainer's save directory with the filename 'train_batch{ni}.jpg', where ni is the iteration number.
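As a small illustration of the naming scheme above, the output path can be built as follows. The helper `train_batch_path` and the directory `runs/pose/train` are assumptions made for this sketch; only the `train_batch{ni}.jpg` pattern comes from the description.

```python
from pathlib import Path


def train_batch_path(save_dir, ni):
    """Build the image path for iteration `ni` inside `save_dir`.

    Hypothetical helper; only the filename pattern is taken from the docs.
    """
    return Path(save_dir) / f"train_batch{ni}.jpg"


# Assumed save directory, for illustration only.
path = train_batch_path("runs/pose/train", 0)
print(path.name)  # train_batch0.jpg
```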
set_model_attributes
set_model_attributes()
Sets keypoints shape attribute of PoseModel.
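A minimal sketch of what setting this attribute amounts to, using a stand-in object in place of a real PoseModel. The helper `copy_kpt_shape` is hypothetical, and the assumption that the value comes from the dataset's `kpt_shape` entry is based on the attribute descriptions earlier on this page rather than on the source code.

```python
from types import SimpleNamespace


def copy_kpt_shape(model, data):
    """Copy the dataset's keypoint shape onto the model object.

    Hypothetical helper; `model` can be any object that accepts attributes.
    """
    model.kpt_shape = data["kpt_shape"]
    return model


# Stand-in for a PoseModel and an illustrative dataset dictionary.
model = SimpleNamespace()
copy_kpt_shape(model, {"kpt_shape": [17, 3]})
print(model.kpt_shape)  # [17, 3]
```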