Reference for ultralytics/models/yolo/pose/val.py

Note

This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/pose/val.py. If you spot a problem please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!


ultralytics.models.yolo.pose.val.PoseValidator

PoseValidator(dataloader=None, save_dir=None, args=None, _callbacks=None)

Bases: DetectionValidator

A class extending the DetectionValidator class for validation based on a pose model.

This validator is specifically designed for pose estimation tasks, handling keypoints and implementing specialized metrics for pose evaluation.

Attributes:

sigma (ndarray): Sigma values for OKS calculation, either OKS_SIGMA or ones divided by number of keypoints.
kpt_shape (List[int]): Shape of the keypoints, typically [17, 3] for COCO format.
args (dict): Arguments for the validator including task set to "pose".
metrics (PoseMetrics): Metrics object for pose evaluation.

Methods:

preprocess: Preprocess batch by converting keypoints data to float and moving it to the device.
get_desc: Return description of evaluation metrics in string format.
init_metrics: Initialize pose estimation metrics for YOLO model.
_prepare_batch: Prepare a batch for processing by converting keypoints to float and scaling to original dimensions.
_prepare_pred: Prepare and scale keypoints in predictions for pose processing.
_process_batch: Return correct prediction matrix by computing Intersection over Union (IoU) between detections and ground truth.
plot_val_samples: Plot and save validation set samples with ground truth bounding boxes and keypoints.
plot_predictions: Plot and save model predictions with bounding boxes and keypoints.
save_one_txt: Save YOLO pose detections to a text file in normalized coordinates.
pred_to_json: Convert YOLO predictions to COCO JSON format.
eval_json: Evaluate pose predictions (bounding boxes and keypoints) against COCO keypoint annotations using COCO JSON format.

Examples:

>>> from ultralytics.models.yolo.pose import PoseValidator
>>> args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml")
>>> validator = PoseValidator(args=args)
>>> validator()

Parameters:

dataloader (DataLoader, default: None): Dataloader to be used for validation.
save_dir (Path | str, default: None): Directory to save results.
args (dict, default: None): Arguments for the validator including task set to "pose".
_callbacks (list, default: None): List of callback functions to be executed during validation.

Notes

This class extends DetectionValidator with pose-specific functionality. It initializes with sigma values for OKS calculation and sets up PoseMetrics for evaluation. A warning is displayed when using Apple MPS due to a known bug with pose models.
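
To show what the sigma values are for, here is a minimal pure-Python sketch of Object Keypoint Similarity (OKS) following the standard COCO definition. It is an illustration, not the code path this validator uses; the function name `oks` and its arguments are hypothetical.

```python
import math


def oks(pred, gt, visible, sigmas, area):
    """Object Keypoint Similarity per the COCO definition (illustrative sketch).

    pred, gt: lists of (x, y) keypoints in pixels.
    visible:  list of bools, True where the ground-truth keypoint is labeled.
    sigmas:   per-keypoint sigma values.
    area:     object area in pixels squared (the scale term s**2).
    """
    total, count = 0.0, 0
    for (px, py), (gx, gy), v, sigma in zip(pred, gt, visible, sigmas):
        if not v:
            continue  # unlabeled keypoints do not contribute
        d2 = (px - gx) ** 2 + (py - gy) ** 2  # squared pixel distance
        k = 2 * sigma  # COCO per-keypoint constant
        total += math.exp(-d2 / (2 * area * k**2))
        count += 1
    return total / count if count else 0.0
```

A larger sigma makes the score more tolerant of error on that keypoint, which is why a uniform fallback is used when no dataset-specific sigmas are available.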

Source code in ultralytics/models/yolo/pose/val.py
def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
    """
    Initialize a PoseValidator object for pose estimation validation.

    This validator is specifically designed for pose estimation tasks, handling keypoints and implementing
    specialized metrics for pose evaluation.

    Args:
        dataloader (torch.utils.data.DataLoader, optional): Dataloader to be used for validation.
        save_dir (Path | str, optional): Directory to save results.
        args (dict, optional): Arguments for the validator including task set to "pose".
        _callbacks (list, optional): List of callback functions to be executed during validation.

    Examples:
        >>> from ultralytics.models.yolo.pose import PoseValidator
        >>> args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml")
        >>> validator = PoseValidator(args=args)
        >>> validator()

    Notes:
        This class extends DetectionValidator with pose-specific functionality. It initializes with sigma values
        for OKS calculation and sets up PoseMetrics for evaluation. A warning is displayed when using Apple MPS
        due to a known bug with pose models.
    """
    super().__init__(dataloader, save_dir, args, _callbacks)
    self.sigma = None
    self.kpt_shape = None
    self.args.task = "pose"
    self.metrics = PoseMetrics()
    if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
        LOGGER.warning(
            "Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
            "See https://github.com/ultralytics/ultralytics/issues/4031."
        )

eval_json

eval_json(stats: Dict[str, Any]) -> Dict[str, Any]

Evaluate pose predictions (bounding boxes and keypoints) against COCO keypoint annotations using COCO JSON format.

Source code in ultralytics/models/yolo/pose/val.py
def eval_json(self, stats: Dict[str, Any]) -> Dict[str, Any]:
    """Evaluate pose predictions (boxes and keypoints) against COCO keypoint annotations in COCO JSON format."""
    anno_json = self.data["path"] / "annotations/person_keypoints_val2017.json"  # annotations
    pred_json = self.save_dir / "predictions.json"  # predictions
    return super().coco_evaluate(stats, pred_json, anno_json, ["bbox", "keypoints"], suffix=["Box", "Pose"])

get_desc

get_desc() -> str

Return description of evaluation metrics in string format.

Source code in ultralytics/models/yolo/pose/val.py
def get_desc(self) -> str:
    """Return description of evaluation metrics in string format."""
    return ("%22s" + "%11s" * 10) % (
        "Class",
        "Images",
        "Instances",
        "Box(P",
        "R",
        "mAP50",
        "mAP50-95)",
        "Pose(P",
        "R",
        "mAP50",
        "mAP50-95)",
    )
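
The format string above can be tried on its own to see the column layout. This short sketch reuses the same format string and labels outside the class; the variable names `columns` and `header` are illustrative.

```python
# Same format string as get_desc: one 22-character column followed by ten 11-character columns.
columns = (
    "Class", "Images", "Instances",
    "Box(P", "R", "mAP50", "mAP50-95)",
    "Pose(P", "R", "mAP50", "mAP50-95)",
)
header = ("%22s" + "%11s" * 10) % columns  # each label is right-aligned in its column
print(header)
```

Every label is shorter than its column, so the header is 132 characters wide (22 + 10 × 11).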

init_metrics

init_metrics(model: Module) -> None

Initialize evaluation metrics for YOLO pose validation.

Parameters:

model (Module, required): Model to validate.

Source code in ultralytics/models/yolo/pose/val.py
def init_metrics(self, model: torch.nn.Module) -> None:
    """
    Initialize evaluation metrics for YOLO pose validation.

    Args:
        model (torch.nn.Module): Model to validate.
    """
    super().init_metrics(model)
    self.kpt_shape = self.data["kpt_shape"]
    is_pose = self.kpt_shape == [17, 3]
    nkpt = self.kpt_shape[0]
    self.sigma = OKS_SIGMA if is_pose else np.ones(nkpt) / nkpt
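
The sigma selection above can be sketched without NumPy. In this illustration `COCO_SIGMAS` is a hypothetical stand-in for `OKS_SIGMA`, filled with the 17 standard COCO person-keypoint sigmas as an assumption, and `select_sigma` is a hypothetical helper name.

```python
# Stand-in for OKS_SIGMA (assumption): the 17 COCO person-keypoint sigmas.
COCO_SIGMAS = [
    s / 10.0
    for s in (0.26, 0.25, 0.25, 0.35, 0.35, 0.79, 0.79, 0.72, 0.72,
              0.62, 0.62, 1.07, 1.07, 0.87, 0.87, 0.89, 0.89)
]


def select_sigma(kpt_shape):
    """Return COCO sigmas for the [17, 3] layout, else a uniform 1/nkpt per keypoint."""
    nkpt = kpt_shape[0]
    if list(kpt_shape) == [17, 3]:  # list() so a tuple shape also matches (differs from the source)
        return COCO_SIGMAS
    return [1.0 / nkpt] * nkpt
```

Note that only the exact shape [17, 3] selects the COCO values; a dataset with 17 keypoints but shape [17, 2] falls back to the uniform sigmas.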

postprocess

postprocess(preds: Tensor) -> Dict[str, torch.Tensor]

Postprocess YOLO predictions to extract and reshape keypoints for pose estimation.

This method extends the parent class postprocessing by extracting keypoints from the 'extra' field of predictions and reshaping them according to the keypoint shape configuration. The keypoints are reshaped from a flattened format to the proper dimensional structure (typically [N, 17, 3] for COCO pose format).

Parameters:

preds (Tensor, required): Raw prediction tensor from the YOLO pose model containing bounding boxes, confidence scores, class predictions, and keypoint data.

Returns:

(List[Dict[str, Tensor]]): List of processed prediction dictionaries, each containing:
    - 'bboxes': Bounding box coordinates
    - 'conf': Confidence scores
    - 'cls': Class predictions
    - 'keypoints': Reshaped keypoint coordinates with shape (-1, *self.kpt_shape)

Note

The keypoints are taken from the 'extra' field, which contains additional task-specific data beyond basic detection, and stored under 'keypoints'. The source listed for this method applies the reshape to every prediction and has no explicit skip for empty ones.

Source code in ultralytics/models/yolo/pose/val.py
def postprocess(self, preds: torch.Tensor) -> Dict[str, torch.Tensor]:
    """
    Postprocess YOLO predictions to extract and reshape keypoints for pose estimation.

    This method extends the parent class postprocessing by extracting keypoints from the 'extra'
    field of predictions and reshaping them according to the keypoint shape configuration.
    The keypoints are reshaped from a flattened format to the proper dimensional structure
    (typically [N, 17, 3] for COCO pose format).

    Args:
        preds (torch.Tensor): Raw prediction tensor from the YOLO pose model containing
            bounding boxes, confidence scores, class predictions, and keypoint data.

    Returns:
        (List[Dict[str, torch.Tensor]]): List of processed prediction dictionaries, each containing:
            - 'bboxes': Bounding box coordinates
            - 'conf': Confidence scores
            - 'cls': Class predictions
            - 'keypoints': Reshaped keypoint coordinates with shape (-1, *self.kpt_shape)

    Note:
        The keypoints are taken from the 'extra' field, which contains additional task-specific
        data beyond basic detection, and stored under 'keypoints'. The reshape is applied to
        every prediction; there is no explicit skip for empty ones.
    """
    preds = super().postprocess(preds)
    for pred in preds:
        pred["keypoints"] = pred.pop("extra").view(-1, *self.kpt_shape)  # move 'extra' into 'keypoints' as (N, *kpt_shape)
    return preds
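
The reshape step can be pictured with plain lists. This is a hedged pure-Python stand-in for `view(-1, *self.kpt_shape)` on a row-major tensor, not the tensor code itself; `reshape_keypoints` is a hypothetical helper name.

```python
def reshape_keypoints(extra, kpt_shape):
    """Group each flat row of nkpt * ndim values into nkpt keypoints of ndim values."""
    nkpt, ndim = kpt_shape
    out = []
    for row in extra:
        if len(row) != nkpt * ndim:
            raise ValueError(f"expected {nkpt * ndim} values per detection, got {len(row)}")
        # Consecutive values belong to the same keypoint, matching row-major layout.
        out.append([row[i * ndim : (i + 1) * ndim] for i in range(nkpt)])
    return out


# One detection with two keypoints of (x, y, visibility).
print(reshape_keypoints([[10, 20, 0.9, 30, 40, 0.8]], (2, 3)))
```

For the COCO layout, each detection's 51 flat values become 17 triples.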

pred_to_json

pred_to_json(predn: Dict[str, Tensor], filename: str) -> None

Convert YOLO predictions to COCO JSON format.

This method takes prediction tensors and a filename, converts the bounding boxes from YOLO format to COCO format, and appends one record per detection to the internal list of JSON records (self.jdict).

Parameters:

predn (Dict[str, Tensor], required): Prediction dictionary containing 'bboxes', 'conf', 'cls', and 'keypoints' tensors.
filename (str, required): Path to the image file for which predictions are being processed.

Notes

The method extracts the image ID from the filename stem (either as an integer if numeric, or as a string), converts bounding boxes from xyxy to xywh format, and adjusts coordinates from center to top-left corner before saving to the JSON dictionary.

Source code in ultralytics/models/yolo/pose/val.py
def pred_to_json(self, predn: Dict[str, torch.Tensor], filename: str) -> None:
    """
    Convert YOLO predictions to COCO JSON format.

    This method takes prediction tensors and a filename, converts the bounding boxes from YOLO format
    to COCO format, and appends one record per detection to the internal list of JSON records (self.jdict).

    Args:
        predn (Dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', 'cls',
            and 'keypoints' tensors.
        filename (str): Path to the image file for which predictions are being processed.

    Notes:
        The method extracts the image ID from the filename stem (either as an integer if numeric, or as a string),
        converts bounding boxes from xyxy to xywh format, and adjusts coordinates from center to top-left corner
        before saving to the JSON dictionary.
    """
    stem = Path(filename).stem
    image_id = int(stem) if stem.isnumeric() else stem
    box = ops.xyxy2xywh(predn["bboxes"])  # xywh
    box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
    for b, s, c, k in zip(
        box.tolist(),
        predn["conf"].tolist(),
        predn["cls"].tolist(),
        predn["keypoints"].flatten(1, 2).tolist(),
    ):
        self.jdict.append(
            {
                "image_id": image_id,
                "category_id": self.class_map[int(c)],
                "bbox": [round(x, 3) for x in b],
                "keypoints": k,
                "score": round(s, 5),
            }
        )
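
The record built for each detection can be sketched for a single box with plain Python values. This is an illustration under stated assumptions: `to_coco_record` is a hypothetical helper, the category ID is passed in directly rather than looked up through `self.class_map`, and the box is converted straight from corners to top-left xywh, which gives the same result as the two-step conversion above.

```python
from pathlib import Path


def to_coco_record(filename, xyxy, score, category_id, keypoints):
    """Build one COCO-style result record from a single detection (illustrative sketch)."""
    stem = Path(filename).stem
    image_id = int(stem) if stem.isnumeric() else stem  # numeric stems become integer IDs
    x1, y1, x2, y2 = xyxy
    bbox = [x1, y1, x2 - x1, y2 - y1]  # COCO boxes are [top-left x, top-left y, width, height]
    flat = [v for kpt in keypoints for v in kpt]  # flatten keypoints to a single list of values
    return {
        "image_id": image_id,
        "category_id": category_id,
        "bbox": [round(v, 3) for v in bbox],
        "keypoints": flat,
        "score": round(score, 5),
    }


record = to_coco_record("000000000139.jpg", (10.0, 20.0, 50.0, 80.0), 0.912345678, 1, [(1.0, 2.0, 0.9)])
print(record)
```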

preprocess

preprocess(batch: Dict[str, Any]) -> Dict[str, Any]

Preprocess batch by converting keypoints data to float and moving it to the device.

Source code in ultralytics/models/yolo/pose/val.py
def preprocess(self, batch: Dict[str, Any]) -> Dict[str, Any]:
    """Preprocess batch by converting keypoints data to float and moving it to the device."""
    batch = super().preprocess(batch)
    batch["keypoints"] = batch["keypoints"].to(self.device).float()
    return batch

save_one_txt

save_one_txt(
    predn: Dict[str, Tensor],
    save_conf: bool,
    shape: Tuple[int, int],
    file: Path,
) -> None

Save YOLO pose detections to a text file in normalized coordinates.

Parameters:

predn (Dict[str, Tensor], required): Dictionary containing predictions with keys 'bboxes', 'conf', 'cls' and 'keypoints'.
save_conf (bool, required): Whether to save confidence scores.
shape (Tuple[int, int], required): Shape of the original image (height, width).
file (Path, required): Output file path to save detections.

Notes

The output format is: class_id x_center y_center width height confidence keypoints where keypoints are normalized (x, y, visibility) values for each point.

Source code in ultralytics/models/yolo/pose/val.py
def save_one_txt(self, predn: Dict[str, torch.Tensor], save_conf: bool, shape: Tuple[int, int], file: Path) -> None:
    """
    Save YOLO pose detections to a text file in normalized coordinates.

    Args:
        predn (Dict[str, torch.Tensor]): Dictionary containing predictions with keys 'bboxes', 'conf', 'cls' and 'keypoints'.
        save_conf (bool): Whether to save confidence scores.
        shape (Tuple[int, int]): Shape of the original image (height, width).
        file (Path): Output file path to save detections.

    Notes:
        The output format is: class_id x_center y_center width height confidence keypoints where keypoints are
        normalized (x, y, visibility) values for each point.
    """
    from ultralytics.engine.results import Results

    Results(
        np.zeros((shape[0], shape[1]), dtype=np.uint8),
        path=None,
        names=self.names,
        boxes=torch.cat([predn["bboxes"], predn["conf"].unsqueeze(-1), predn["cls"].unsqueeze(-1)], dim=1),
        keypoints=predn["keypoints"],
    ).save_txt(file, save_conf=save_conf)
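
To make "normalized coordinates" concrete, the sketch below formats one detection by hand for the case where no confidence is written (`save_conf=False`). It is not the `Results.save_txt` implementation; `format_pose_line` is a hypothetical helper, and the `%g`-style number formatting is an assumption.

```python
def format_pose_line(cls_id, xyxy, keypoints, shape):
    """Format one detection as 'class x_center y_center width height keypoints...' (sketch)."""
    h, w = shape  # original image shape is (height, width)
    x1, y1, x2, y2 = xyxy
    # Box as normalized center x, center y, width, height.
    box = [(x1 + x2) / 2 / w, (y1 + y2) / 2 / h, (x2 - x1) / w, (y2 - y1) / h]
    # Each keypoint as normalized x, normalized y, then its visibility value unchanged.
    kpts = [v for x, y, vis in keypoints for v in (x / w, y / h, vis)]
    return " ".join(f"{v:g}" for v in [cls_id, *box, *kpts])


line = format_pose_line(0, (0, 0, 50, 100), [(25, 50, 1)], (100, 200))
print(line)
```

Dividing x values by the image width and y values by the image height keeps the saved labels independent of image resolution.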





📅 Created 1 year ago ✏️ Updated 10 months ago