
Reference for ultralytics/models/yolo/obb/val.py

Note

This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/obb/val.py. If you spot a problem, please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!


ultralytics.models.yolo.obb.val.OBBValidator

OBBValidator(
    dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None
)

Bases: DetectionValidator

A class extending the DetectionValidator class for validation based on an Oriented Bounding Box (OBB) model.

This validator specializes in evaluating models that predict rotated bounding boxes, commonly used for aerial and satellite imagery where objects can appear at various orientations.

Attributes:

Name     Type        Description
args     dict        Configuration arguments for the validator.
metrics  OBBMetrics  Metrics object for evaluating OBB model performance.
is_dota  bool        Flag indicating whether the validation dataset is in DOTA format.

Methods:

Name              Description
init_metrics      Initialize evaluation metrics for YOLO.
_process_batch    Process a batch of detections and ground-truth boxes to compute the IoU matrix.
_prepare_batch    Prepare batch data for OBB validation.
_prepare_pred     Prepare predictions with scaled and padded bounding boxes.
plot_predictions  Plot predicted bounding boxes on input images.
pred_to_json      Serialize YOLO predictions to COCO JSON format.
save_one_txt      Save YOLO detections to a txt file in normalized coordinates.
eval_json         Evaluate YOLO output in JSON format and return performance statistics.

Examples:

>>> from ultralytics.models.yolo.obb import OBBValidator
>>> args = dict(model="yolo11n-obb.pt", data="dota8.yaml")
>>> validator = OBBValidator(args=args)
>>> validator(model=args["model"])
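
In most workflows the validator is constructed internally by the high-level API rather than instantiated by hand. A minimal sketch (assuming a trained OBB checkpoint and the dota8.yaml dataset config bundled with Ultralytics):

>>> from ultralytics import YOLO
>>> model = YOLO("yolo11n-obb.pt")
>>> metrics = model.val(data="dota8.yaml")  # dispatches to OBBValidator because the model task is "obb"
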
Source code in ultralytics/models/yolo/obb/val.py
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
    """Initialize OBBValidator and set task to 'obb', metrics to OBBMetrics."""
    super().__init__(dataloader, save_dir, pbar, args, _callbacks)
    self.args.task = "obb"
    self.metrics = OBBMetrics(save_dir=self.save_dir, plot=True)

eval_json

eval_json(stats)

Evaluate YOLO output in JSON format and save predictions in DOTA format.

Source code in ultralytics/models/yolo/obb/val.py
def eval_json(self, stats):
    """Evaluate YOLO output in JSON format and save predictions in DOTA format."""
    if self.args.save_json and self.is_dota and len(self.jdict):
        import json
        import re
        from collections import defaultdict

        pred_json = self.save_dir / "predictions.json"  # predictions
        pred_txt = self.save_dir / "predictions_txt"  # predictions
        pred_txt.mkdir(parents=True, exist_ok=True)
        data = json.load(open(pred_json))
        # Save split results
        LOGGER.info(f"Saving predictions with DOTA format to {pred_txt}...")
        for d in data:
            image_id = d["image_id"]
            score = d["score"]
            classname = self.names[d["category_id"] - 1].replace(" ", "-")
            p = d["poly"]

            with open(f"{pred_txt / f'Task1_{classname}'}.txt", "a", encoding="utf-8") as f:
                f.writelines(f"{image_id} {score} {p[0]} {p[1]} {p[2]} {p[3]} {p[4]} {p[5]} {p[6]} {p[7]}\n")
        # Save merged results. This may give a slightly lower mAP than the official merging script
        # because of the probiou calculation.
        pred_merged_txt = self.save_dir / "predictions_merged_txt"  # predictions
        pred_merged_txt.mkdir(parents=True, exist_ok=True)
        merged_results = defaultdict(list)
        LOGGER.info(f"Saving merged predictions with DOTA format to {pred_merged_txt}...")
        for d in data:
            image_id = d["image_id"].split("__")[0]
            pattern = re.compile(r"\d+___\d+")
            x, y = (int(c) for c in re.findall(pattern, d["image_id"])[0].split("___"))
            bbox, score, cls = d["rbox"], d["score"], d["category_id"] - 1
            bbox[0] += x
            bbox[1] += y
            bbox.extend([score, cls])
            merged_results[image_id].append(bbox)
        for image_id, bbox in merged_results.items():
            bbox = torch.tensor(bbox)
            max_wh = torch.max(bbox[:, :2]).item() * 2
            c = bbox[:, 6:7] * max_wh  # classes
            scores = bbox[:, 5]  # scores
            b = bbox[:, :5].clone()
            b[:, :2] += c
            # An IoU threshold of 0.3 gives results close to the official merging script, sometimes slightly better.
            i = ops.nms_rotated(b, scores, 0.3)
            bbox = bbox[i]

            b = ops.xywhr2xyxyxyxy(bbox[:, :5]).view(-1, 8)
            for x in torch.cat([b, bbox[:, 5:7]], dim=-1).tolist():
                classname = self.names[int(x[-1])].replace(" ", "-")
                p = [round(i, 3) for i in x[:-2]]  # poly
                score = round(x[-2], 3)

                with open(f"{pred_merged_txt / f'Task1_{classname}'}.txt", "a", encoding="utf-8") as f:
                    f.writelines(f"{image_id} {score} {p[0]} {p[1]} {p[2]} {p[3]} {p[4]} {p[5]} {p[6]} {p[7]}\n")

    return stats
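
A minimal sketch of the path that reaches eval_json (assuming save_json=True and a dataset whose validation path contains "DOTA", e.g. DOTAv1.yaml; class and file names below are illustrative):

>>> from ultralytics.models.yolo.obb import OBBValidator
>>> args = dict(model="yolo11n-obb.pt", data="DOTAv1.yaml", save_json=True)
>>> validator = OBBValidator(args=args)
>>> stats = validator(model=args["model"])
>>> # With is_dota=True, per-class files such as Task1_plane.txt are written to
>>> # save_dir/predictions_txt and save_dir/predictions_merged_txt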

init_metrics

init_metrics(model)

Initialize evaluation metrics for YOLO.

Source code in ultralytics/models/yolo/obb/val.py
def init_metrics(self, model):
    """Initialize evaluation metrics for YOLO."""
    super().init_metrics(model)
    val = self.data.get(self.args.split, "")  # validation path
    self.is_dota = isinstance(val, str) and "DOTA" in val  # check if dataset is DOTA format
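
The is_dota flag is a simple substring check on the resolved split path; an illustration with hypothetical paths:

>>> "DOTA" in "../datasets/DOTAv1/images/val"
True
>>> "DOTA" in "../datasets/coco8/images/val"
False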

plot_predictions

plot_predictions(batch, preds, ni)

Plot predicted bounding boxes on input images and save the result.

Source code in ultralytics/models/yolo/obb/val.py
def plot_predictions(self, batch, preds, ni):
    """Plot predicted bounding boxes on input images and save the result."""
    plot_images(
        batch["img"],
        *output_to_rotated_target(preds, max_det=self.args.max_det),
        paths=batch["im_file"],
        fname=self.save_dir / f"val_batch{ni}_pred.jpg",
        names=self.names,
        on_plot=self.on_plot,
    )  # pred
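
A short usage sketch (assuming batch and preds come from the validation loop; names are illustrative):

>>> validator.plot_predictions(batch, preds, ni=0)  # writes save_dir / "val_batch0_pred.jpg"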

pred_to_json

pred_to_json(predn, filename)

Convert YOLO predictions to COCO JSON format with rotated bounding box information.

Source code in ultralytics/models/yolo/obb/val.py
def pred_to_json(self, predn, filename):
    """Convert YOLO predictions to COCO JSON format with rotated bounding box information."""
    stem = Path(filename).stem
    image_id = int(stem) if stem.isnumeric() else stem
    rbox = torch.cat([predn[:, :4], predn[:, -1:]], dim=-1)
    poly = ops.xywhr2xyxyxyxy(rbox).view(-1, 8)
    for i, (r, b) in enumerate(zip(rbox.tolist(), poly.tolist())):
        self.jdict.append(
            {
                "image_id": image_id,
                "category_id": self.class_map[int(predn[i, 5].item())],
                "score": round(predn[i, 4].item(), 5),
                "rbox": [round(x, 3) for x in r],
                "poly": [round(x, 3) for x in b],
            }
        )
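
Each appended entry pairs the rotated box (xywhr) with its 8-point polygon. A sketch of one entry with illustrative values, not real output:

>>> validator.jdict[-1]  # hypothetical
{'image_id': 'P0006__1024__0___0', 'category_id': 1, 'score': 0.89012,
 'rbox': [343.2, 512.7, 80.4, 35.1, 0.21],
 'poly': [378.9, 538.3, 300.3, 521.5, 307.6, 487.1, 386.2, 503.9]}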

save_one_txt

save_one_txt(predn, save_conf, shape, file)

Save YOLO detections to a txt file in normalized coordinates using the Results class.

Source code in ultralytics/models/yolo/obb/val.py
def save_one_txt(self, predn, save_conf, shape, file):
    """Save YOLO detections to a txt file in normalized coordinates using the Results class."""
    import numpy as np

    from ultralytics.engine.results import Results

    rboxes = torch.cat([predn[:, :4], predn[:, -1:]], dim=-1)
    # xywh, r, conf, cls
    obb = torch.cat([rboxes, predn[:, 4:6]], dim=-1)
    Results(
        np.zeros((shape[0], shape[1]), dtype=np.uint8),
        path=None,
        names=self.names,
        obb=obb,
    ).save_txt(file, save_conf=save_conf)
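
Results.save_txt writes one line per detection: the class index followed by eight normalized polygon coordinates (xyxyxyxyn), with the confidence appended when save_conf=True. An illustrative line, not real output:

0 0.382 0.521 0.463 0.507 0.470 0.541 0.389 0.555 0.912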


