Skip to content

Reference for ultralytics/models/yolo/pose/val.py

Note

This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/pose/val.py. If you spot a problem please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!


ultralytics.models.yolo.pose.val.PoseValidator

PoseValidator(
    dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None
)

Bases: DetectionValidator

A class extending the DetectionValidator class for validation based on a pose model.

This validator is specifically designed for pose estimation tasks, handling keypoints and implementing specialized metrics for pose evaluation.

Attributes:

Name Type Description
sigma ndarray

Sigma values for OKS calculation, either OKS_SIGMA or ones divided by number of keypoints.

kpt_shape List[int]

Shape of the keypoints, typically [17, 3] for COCO format.

args dict

Arguments for the validator including task set to "pose".

metrics PoseMetrics

Metrics object for pose evaluation.

Methods:

Name Description
preprocess

Preprocesses batch data for pose validation.

get_desc

Returns description of evaluation metrics.

init_metrics

Initializes pose metrics for the model.

_prepare_batch

Prepares a batch for processing.

_prepare_pred

Prepares and scales predictions for evaluation.

update_metrics

Updates metrics with new predictions.

_process_batch

Processes batch to compute IoU between detections and ground truth.

plot_val_samples

Plots validation samples with ground truth annotations.

plot_predictions

Plots model predictions.

save_one_txt

Saves detections to a text file.

pred_to_json

Converts predictions to COCO JSON format.

eval_json

Evaluates model using COCO JSON format.

Examples:

>>> from ultralytics.models.yolo.pose import PoseValidator
>>> args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml")
>>> validator = PoseValidator(args=args)
>>> validator()

This validator is specifically designed for pose estimation tasks, handling keypoints and implementing specialized metrics for pose evaluation.

Parameters:

Name Type Description Default
dataloader DataLoader

Dataloader to be used for validation.

None
save_dir Path | str

Directory to save results.

None
pbar Any

Progress bar for displaying progress.

None
args dict

Arguments for the validator including task set to "pose".

None
_callbacks list

List of callback functions to be executed during validation.

None

Examples:

>>> from ultralytics.models.yolo.pose import PoseValidator
>>> args = dict(model="yolov8n-pose.pt", data="coco8-pose.yaml")
>>> validator = PoseValidator(args=args)
>>> validator()
Notes

This class extends DetectionValidator with pose-specific functionality. It initializes with sigma values for OKS calculation and sets up PoseMetrics for evaluation. A warning is displayed when using Apple MPS due to a known bug with pose models.

Source code in ultralytics/models/yolo/pose/val.py
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
    """
    Set up a validator for YOLO pose-estimation models.

    The parent DetectionValidator builds the shared validator state. This constructor then forces the task to
    "pose" and attaches a PoseMetrics instance. It leaves the keypoint-specific fields (``sigma``, ``kpt_shape``)
    unset until ``init_metrics`` fills them from the dataset config.

    Args:
        dataloader (torch.utils.data.DataLoader, optional): Dataloader used for validation.
        save_dir (Path | str, optional): Directory where results are written.
        pbar (Any, optional): Progress bar used to display progress.
        args (dict, optional): Validator arguments; the task is overridden to "pose".
        _callbacks (list, optional): Callback functions executed during validation.

    Examples:
        >>> from ultralytics.models.yolo.pose import PoseValidator
        >>> args = dict(model="yolov8n-pose.pt", data="coco8-pose.yaml")
        >>> validator = PoseValidator(args=args)
        >>> validator()

    Notes:
        A warning is logged when the device string is "mps" (any case), because of a known Apple MPS bug with
        pose models (https://github.com/ultralytics/ultralytics/issues/4031).
    """
    super().__init__(dataloader, save_dir, pbar, args, _callbacks)
    # Keypoint-specific settings stay empty until init_metrics() reads the dataset's kpt_shape.
    self.sigma, self.kpt_shape = None, None
    self.args.task = "pose"
    self.metrics = PoseMetrics(save_dir=self.save_dir)

    device = self.args.device
    on_mps = isinstance(device, str) and device.lower() == "mps"
    if on_mps:
        LOGGER.warning(
            "WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
            "See https://github.com/ultralytics/ultralytics/issues/4031."
        )

eval_json

eval_json(stats)

Evaluate the pose model's box and keypoint predictions with pycocotools using COCO JSON format.

Source code in ultralytics/models/yolo/pose/val.py
def eval_json(self, stats):
    """
    Evaluate saved pose predictions with pycocotools and update the COCO mAP values in ``stats``.

    Evaluation runs only when ``save_json`` is enabled, the dataset is COCO, and at least one prediction was
    collected in ``self.jdict``. Both the "bbox" and "keypoints" IoU types are evaluated. For each type, COCO
    mAP50-95 and mAP50 overwrite the corresponding entries of ``stats``. Any failure is logged as a warning
    instead of being raised, so validation still returns the internally computed metrics.

    Args:
        stats (dict): Metric dictionary keyed by the names in ``self.metrics.keys``.

    Returns:
        (dict): The same ``stats`` object, updated in place when pycocotools evaluation succeeds.
    """
    if not (self.args.save_json and self.is_coco and self.jdict):
        return stats

    anno_json = self.data["path"] / "annotations/person_keypoints_val2017.json"  # annotations
    pred_json = self.save_dir / "predictions.json"  # predictions
    LOGGER.info(f"\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...")
    try:  # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
        check_requirements("pycocotools>=2.0.6")
        from pycocotools.coco import COCO  # noqa
        from pycocotools.cocoeval import COCOeval  # noqa

        for json_file in anno_json, pred_json:
            assert json_file.is_file(), f"{json_file} file not found"
        anno = COCO(str(anno_json))  # init annotations api
        pred = anno.loadRes(str(pred_json))  # init predictions api (must pass string, not Path)
        # Restrict evaluation to the validated images. COCO image ids are their numeric file stems.
        # Computed once and shared by both IoU types.
        img_ids = [int(Path(f).stem) for f in self.dataloader.dataset.im_files]
        for i, iou_type in enumerate(("bbox", "keypoints")):
            coco_eval = COCOeval(anno, pred, iou_type)
            coco_eval.params.imgIds = img_ids  # im to eval
            coco_eval.evaluate()
            coco_eval.accumulate()
            coco_eval.summarize()
            # Keys are assumed to be [P, R, mAP50, mAP50-95] for box, then for pose (matching get_desc).
            # Hence the stride of 4 per IoU type, plus an offset of 2 to reach mAP50.
            idx = i * 4 + 2
            # stats[0] is mAP50-95 and stats[1] is mAP50.
            stats[self.metrics.keys[idx + 1]], stats[self.metrics.keys[idx]] = coco_eval.stats[:2]
    except Exception as e:  # best-effort: keep the internally computed metrics if pycocotools fails
        LOGGER.warning(f"pycocotools unable to run: {e}")
    return stats

get_desc

get_desc()

Return description of evaluation metrics in string format.

Source code in ultralytics/models/yolo/pose/val.py
def get_desc(self):
    """Build the header row naming the box and pose metric columns, right-aligned to the table widths."""
    columns = (
        "Images",
        "Instances",
        "Box(P",
        "R",
        "mAP50",
        "mAP50-95)",
        "Pose(P",
        "R",
        "mAP50",
        "mAP50-95)",
    )
    # The first column is 22 characters wide; every metric column is 11 characters wide.
    return f"{'Class':>22}" + "".join(f"{col:>11}" for col in columns)

init_metrics

init_metrics(model)

Initialize pose estimation metrics for YOLO model.

Source code in ultralytics/models/yolo/pose/val.py
def init_metrics(self, model):
    """
    Initialize pose estimation metrics for a YOLO model.

    Reads ``kpt_shape`` from the dataset config and chooses the OKS sigmas:
    - ``OKS_SIGMA`` for the [17, 3] COCO keypoint layout;
    - otherwise uniform sigmas equal to ``1 / nkpt``.

    It also resets the per-image statistics lists that ``update_metrics`` appends to.

    Args:
        model: Model being validated; passed through unchanged to the parent ``init_metrics``.
    """
    super().init_metrics(model)
    self.kpt_shape = self.data["kpt_shape"]
    # Compare as a list, so a tuple or array kpt_shape such as (17, 3) is still recognized as the COCO layout.
    # A direct `== [17, 3]` is False for tuples and ambiguous for numpy arrays.
    is_pose = list(self.kpt_shape) == [17, 3]
    nkpt = self.kpt_shape[0]
    self.sigma = OKS_SIGMA if is_pose else np.ones(nkpt) / nkpt
    self.stats = dict(tp_p=[], tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[])

plot_predictions

plot_predictions(batch, preds, ni)

Plot and save model predictions with bounding boxes and keypoints.

Parameters:

Name Type Description Default
batch dict

Dictionary containing batch data including images, file paths, and other metadata.

required
preds List[Tensor]

List of prediction tensors from the model, each containing bounding boxes, confidence scores, class predictions, and keypoints.

required
ni int

Batch index used for naming the output file.

required

The function extracts keypoints from predictions, converts predictions to target format, and plots them on the input images. The resulting visualization is saved to the specified save directory.

Source code in ultralytics/models/yolo/pose/val.py
def plot_predictions(self, batch, preds, ni):
    """
    Draw the model's predicted boxes and keypoints for one batch and save the image.

    Args:
        batch (dict): Batch data; the "img" and "im_file" entries are used here.
        preds (List[torch.Tensor]): Per-image prediction tensors. The first 6 columns hold the box, confidence
            and class; the remaining columns hold the flattened keypoints.
        ni (int): Batch index used in the output filename ``val_batch{ni}_pred.jpg``.

    The keypoint columns are reshaped to ``kpt_shape`` and the predictions are converted to target format.
    Both are then drawn on the batch images, and the result is saved under ``self.save_dir``.
    """
    # Reshape each image's trailing keypoint columns to (n, *kpt_shape), then stack all images together.
    kpts_per_image = [p[:, 6:].view(-1, *self.kpt_shape) for p in preds]
    pred_kpts = torch.cat(kpts_per_image, 0)
    targets = output_to_target(preds, max_det=self.args.max_det)
    out_file = self.save_dir / f"val_batch{ni}_pred.jpg"
    plot_images(
        batch["img"],
        *targets,
        kpts=pred_kpts,
        paths=batch["im_file"],
        fname=out_file,
        names=self.names,
        on_plot=self.on_plot,
    )  # pred

plot_val_samples

plot_val_samples(batch, ni)

Plot and save validation set samples with ground truth bounding boxes and keypoints.

Parameters:

Name Type Description Default
batch dict

Dictionary containing batch data with keys:

- img (torch.Tensor): Batch of images
- batch_idx (torch.Tensor): Batch indices for each image
- cls (torch.Tensor): Class labels
- bboxes (torch.Tensor): Bounding box coordinates
- keypoints (torch.Tensor): Keypoint coordinates
- im_file (list): List of image file paths

required
ni int

Batch index used for naming the output file

required
Source code in ultralytics/models/yolo/pose/val.py
def plot_val_samples(self, batch, ni):
    """
    Draw the ground-truth boxes and keypoints of a validation batch and save the image.

    Args:
        batch (dict): Batch data; the entries read are:
            - img (torch.Tensor): Batch of images
            - batch_idx (torch.Tensor): Batch indices for each image
            - cls (torch.Tensor): Class labels
            - bboxes (torch.Tensor): Bounding box coordinates
            - keypoints (torch.Tensor): Keypoint coordinates
            - im_file (list): List of image file paths
        ni (int): Batch index used in the output filename ``val_batch{ni}_labels.jpg``.
    """
    labels = batch["cls"].squeeze(-1)  # squeeze(-1) removes a trailing size-1 dimension, if present
    out_file = self.save_dir / f"val_batch{ni}_labels.jpg"
    plot_images(
        batch["img"],
        batch["batch_idx"],
        labels,
        batch["bboxes"],
        kpts=batch["keypoints"],
        paths=batch["im_file"],
        fname=out_file,
        names=self.names,
        on_plot=self.on_plot,
    )

pred_to_json

pred_to_json(predn, filename)

Convert YOLO predictions to COCO JSON format.

This method takes prediction tensors and a filename, converts the bounding boxes from YOLO format to COCO format, and appends the results to the internal JSON dictionary (self.jdict).

Parameters:

Name Type Description Default
predn Tensor

Prediction tensor containing bounding boxes, confidence scores, class IDs, and keypoints, with shape (N, 6+K) where N is the number of predictions and K is the flattened keypoints dimension.

required
filename str | Path

Path to the image file for which predictions are being processed.

required
Notes

The method extracts the image ID from the filename stem (either as an integer if numeric, or as a string), converts bounding boxes from xyxy to xywh format, and adjusts coordinates from center to top-left corner before saving to the JSON dictionary.

Source code in ultralytics/models/yolo/pose/val.py
def pred_to_json(self, predn, filename):
    """
    Convert YOLO pose predictions for one image to COCO JSON records.

    Each prediction row is appended to the internal JSON list (``self.jdict``). Each record holds the image id,
    the mapped category id, the box, the keypoints and the score.

    Args:
        predn (torch.Tensor): Prediction tensor of shape (N, 6+K). Each row holds an xyxy box, confidence and
            class id, followed by K flattened keypoint values.
        filename (str | Path): Path to the image file for which predictions are being processed.

    Notes:
        The image id is the filename stem. It is converted to ``int`` when it consists only of decimal digits,
        and kept as a string otherwise. Boxes are converted from xyxy to xywh. Their xy is then shifted from the
        center to the top-left corner, giving [x_min, y_min, width, height] boxes.
    """
    stem = Path(filename).stem
    # isdecimal() only accepts characters that int() can parse.
    # isnumeric() would also accept "²" or "½", making int() raise ValueError.
    image_id = int(stem) if stem.isdecimal() else stem
    box = ops.xyxy2xywh(predn[:, :4])  # xywh
    box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
    for p, b in zip(predn.tolist(), box.tolist()):
        self.jdict.append(
            {
                "image_id": image_id,
                "category_id": self.class_map[int(p[5])],
                "bbox": [round(x, 3) for x in b],
                "keypoints": p[6:],
                "score": round(p[4], 5),
            }
        )

preprocess

preprocess(batch)

Preprocess batch by converting keypoints data to float and moving it to the device.

Source code in ultralytics/models/yolo/pose/val.py
def preprocess(self, batch):
    """
    Apply the parent preprocessing, then move the keypoints to the validation device as float tensors.

    Args:
        batch (dict): Batch data that includes a "keypoints" tensor.

    Returns:
        (dict): The preprocessed batch.
    """
    batch = super().preprocess(batch)
    keypoints = batch["keypoints"]
    batch["keypoints"] = keypoints.to(self.device).float()
    return batch

save_one_txt

save_one_txt(predn, pred_kpts, save_conf, shape, file)

Save YOLO pose detections to a text file in normalized coordinates.

Parameters:

Name Type Description Default
predn Tensor

Prediction boxes and scores with shape (N, 6) for (x1, y1, x2, y2, conf, cls).

required
pred_kpts Tensor

Predicted keypoints with shape (N, K, D) where K is the number of keypoints and D is the dimension (typically 3 for x, y, visibility).

required
save_conf bool

Whether to save confidence scores.

required
shape tuple

Original image shape (height, width).

required
file Path

Output file path to save detections.

required
Notes

The output format is: class_id x_center y_center width height confidence keypoints, where the keypoints are normalized (x, y, visibility) values for each point. The confidence value is written only when save_conf is True.

Source code in ultralytics/models/yolo/pose/val.py
def save_one_txt(self, predn, pred_kpts, save_conf, shape, file):
    """
    Write one image's pose detections to a text file in normalized coordinates.

    The work is delegated to ``Results.save_txt``.

    Args:
        predn (torch.Tensor): Prediction boxes and scores; the first 6 columns are (x1, y1, x2, y2, conf, cls).
        pred_kpts (torch.Tensor): Predicted keypoints with shape (N, K, D). K is the number of keypoints and D
            is the dimension (typically 3 for x, y, visibility).
        save_conf (bool): Whether to save confidence scores.
        shape (tuple): Original image shape (height, width).
        file (Path): Output file path to save detections.
    """
    from ultralytics.engine.results import Results

    # A blank image of the original (height, width) stands in for the real image.
    # NOTE(review): presumably Results only uses its shape to normalize coordinates; confirm.
    blank = np.zeros((shape[0], shape[1]), dtype=np.uint8)
    result = Results(blank, path=None, names=self.names, boxes=predn[:, :6], keypoints=pred_kpts)
    result.save_txt(file, save_conf=save_conf)

update_metrics

update_metrics(preds, batch)

Update metrics with new predictions and ground truth data.

This method processes each prediction, compares it with ground truth, and updates various statistics for performance evaluation.

Parameters:

Name Type Description Default
preds List[Tensor]

List of prediction tensors from the model.

required
batch dict

Batch data containing images and ground truth annotations.

required
Source code in ultralytics/models/yolo/pose/val.py
def update_metrics(self, preds, batch):
    """
    Update metrics with new predictions and ground truth data.

    For each image in the batch, this builds a per-image ``stat`` dict and appends each of its entries to the
    matching list in ``self.stats``. The dict holds confidences, predicted classes, box and pose true-positive
    matrices, and target classes. The method also feeds the confusion matrix when ``args.plots`` is set, and
    writes JSON / TXT predictions when ``args.save_json`` / ``args.save_txt`` are set.

    Args:
        preds (List[torch.Tensor]): List of prediction tensors from the model, one per image in the batch.
        batch (dict): Batch data containing images and ground truth annotations.
    """
    for si, pred in enumerate(preds):
        self.seen += 1
        npr = len(pred)  # number of predictions for this image
        # Defaults cover the no-prediction case: empty conf/pred_cls and all-False TP matrices with one row per
        # prediction and self.niou columns (presumably one per IoU threshold).
        stat = dict(
            conf=torch.zeros(0, device=self.device),
            pred_cls=torch.zeros(0, device=self.device),
            tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
            tp_p=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
        )
        pbatch = self._prepare_batch(si, batch)
        # Pop the labels so pbatch keeps only the remaining fields (e.g. "kpts", "ori_shape") used below.
        cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox")
        nl = len(cls)  # number of ground-truth labels for this image
        stat["target_cls"] = cls
        stat["target_img"] = cls.unique()
        if npr == 0:
            # No predictions: record the targets only when there are some (presumably counted as misses);
            # images with neither predictions nor labels contribute nothing to self.stats.
            if nl:
                for k in self.stats.keys():
                    self.stats[k].append(stat[k])
                if self.args.plots:
                    self.confusion_matrix.process_batch(detections=None, gt_bboxes=bbox, gt_cls=cls)
            continue

        # Predictions
        if self.args.single_cls:
            pred[:, 5] = 0  # collapse every predicted class to 0 (modifies pred in place)
        predn, pred_kpts = self._prepare_pred(pred, pbatch)
        stat["conf"] = predn[:, 4]
        stat["pred_cls"] = predn[:, 5]

        # Evaluate
        if nl:
            # Box-level TP matrix.
            stat["tp"] = self._process_batch(predn, bbox, cls)
            # Keypoint-level TP matrix. Passing predicted and ground-truth keypoints changes the matching
            # criterion inside _process_batch (presumably OKS instead of box IoU; see _process_batch).
            stat["tp_p"] = self._process_batch(predn, bbox, cls, pred_kpts, pbatch["kpts"])
        if self.args.plots:
            self.confusion_matrix.process_batch(predn, bbox, cls)

        for k in self.stats.keys():
            self.stats[k].append(stat[k])

        # Save
        if self.args.save_json:
            self.pred_to_json(predn, batch["im_file"][si])
        if self.args.save_txt:
            self.save_one_txt(
                predn,
                pred_kpts,
                self.args.save_conf,
                pbatch["ori_shape"],
                self.save_dir / "labels" / f"{Path(batch['im_file'][si]).stem}.txt",
            )



📅 Created 1 year ago ✏️ Updated 7 months ago