
Reference for ultralytics/models/yolo/yoloe/val.py

Improvements

This page is sourced from https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/yoloe/val.py. Have an improvement or example to add? Open a Pull Request — thank you! 🙏


class ultralytics.models.yolo.yoloe.val.YOLOEDetectValidator

```python
YOLOEDetectValidator()
```

Bases: DetectionValidator

A validator class for YOLOE detection models that handles both text and visual prompt embeddings.

This class extends DetectionValidator to provide specialized validation functionality for YOLOE models. It supports validation using either text prompts or visual prompt embeddings extracted from training samples, enabling flexible evaluation strategies for prompt-based object detection.

Attributes

| Name | Type | Description |
| --- | --- | --- |
| `device` | `torch.device` | The device on which validation is performed. |
| `args` | `namespace` | Configuration arguments for validation. |
| `dataloader` | `DataLoader` | DataLoader for validation data. |

Methods

| Name | Description |
| --- | --- |
| `__call__` | Run validation on the model using either text or visual prompt embeddings. |
| `get_visual_pe` | Extract visual prompt embeddings from training samples. |
| `get_vpe_dataloader` | Create a dataloader for LVIS training visual prompt samples. |

Examples

Validate with text prompts

```python
>>> validator = YOLOEDetectValidator()
>>> stats = validator(model=model, load_vp=False)
```

Validate with visual prompts

```python
>>> stats = validator(model=model, refer_data="path/to/data.yaml", load_vp=True)
```
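
The following is a minimal standalone sketch, assuming a placeholder dataset YAML and checkpoint name (neither is defined in this module), that constructs the validator from configuration overrides and runs text-prompt validation:

```python
from ultralytics.models.yolo.yoloe import YOLOEDetectValidator

# Assumed dataset YAML and checkpoint name; replace with your own paths.
validator = YOLOEDetectValidator(args={"data": "coco128.yaml", "imgsz": 640, "batch": 16})

# load_vp=False encodes the dataset class names as text prompt embeddings.
stats = validator(model="yoloe-11s.pt", load_vp=False)
print(stats)
```
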
Source code in ultralytics/models/yolo/yoloe/val.py

```python
class YOLOEDetectValidator(DetectionValidator):
```


method ultralytics.models.yolo.yoloe.val.YOLOEDetectValidator.__call__

```python
def __call__(
    self,
    trainer: Any | None = None,
    model: YOLOEModel | str | None = None,
    refer_data: str | None = None,
    load_vp: bool = False,
) -> dict[str, Any]
```

Run validation on the model using either text or visual prompt embeddings.

This method validates the model using either text prompts or visual prompts, depending on the load_vp flag. It supports validation during training (using a trainer object) or standalone validation with a provided model. For visual prompts, reference data can be specified to extract embeddings from a different dataset.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `trainer` | `object`, optional | Trainer object containing the model and device. | `None` |
| `model` | `YOLOEModel` or `str`, optional | Model to validate. Required if trainer is not provided. | `None` |
| `refer_data` | `str`, optional | Path to reference data for visual prompts. | `None` |
| `load_vp` | `bool` | Whether to load visual prompts. If False, text prompts are used. | `False` |

Returns

| Type | Description |
| --- | --- |
| `dict` | Validation statistics containing metrics computed during validation. |

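As a hedged companion to the signature above, the calls below sketch the two standalone modes; the checkpoint and dataset YAML names are placeholders, not values defined in this module:

```python
# Visual-prompt mode: embeddings are extracted from refer_data (assumed path),
# while evaluation itself runs on the dataset configured in validator.args.data.
stats = validator(model="yoloe-11s.pt", refer_data="lvis-train.yaml", load_vp=True)

# Text-prompt mode: class names are encoded with the model's text encoder.
stats = validator(model="yoloe-11s.pt", load_vp=False)
```
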
Source code in ultralytics/models/yolo/yoloe/val.py

```python
@smart_inference_mode()
def __call__(
    self,
    trainer: Any | None = None,
    model: YOLOEModel | str | None = None,
    refer_data: str | None = None,
    load_vp: bool = False,
) -> dict[str, Any]:
    """Run validation on the model using either text or visual prompt embeddings.

    This method validates the model using either text prompts or visual prompts, depending on the load_vp flag. It
    supports validation during training (using a trainer object) or standalone validation with a provided model. For
    visual prompts, reference data can be specified to extract embeddings from a different dataset.

    Args:
        trainer (object, optional): Trainer object containing the model and device.
        model (YOLOEModel | str, optional): Model to validate. Required if trainer is not provided.
        refer_data (str, optional): Path to reference data for visual prompts.
        load_vp (bool): Whether to load visual prompts. If False, text prompts are used.

    Returns:
        (dict): Validation statistics containing metrics computed during validation.
    """
    if trainer is not None:
        self.device = trainer.device
        model = trainer.ema.ema
        names = [name.split("/", 1)[0] for name in list(self.dataloader.dataset.data["names"].values())]

        if load_vp:
            LOGGER.info("Validate using the visual prompt.")
            self.args.half = False
            # Directly use the same dataloader for visual embeddings extracted during training
            vpe = self.get_visual_pe(self.dataloader, model)
            model.set_classes(names, vpe)
        else:
            LOGGER.info("Validate using the text prompt.")
            tpe = model.get_text_pe(names)
            model.set_classes(names, tpe)
        stats = super().__call__(trainer, model)
    else:
        if refer_data is not None:
            assert load_vp, "Refer data is only used for visual prompt validation."
        self.device = select_device(self.args.device, verbose=False)

        if isinstance(model, (str, Path)):
            from ultralytics.nn.tasks import load_checkpoint

            model, _ = load_checkpoint(model, device=self.device)  # model, ckpt
        model.eval().to(self.device)
        data = check_det_dataset(refer_data or self.args.data)
        names = [name.split("/", 1)[0] for name in list(data["names"].values())]

        if load_vp:
            LOGGER.info("Validate using the visual prompt.")
            self.args.half = False
            # TODO: need to check if the names from refer data is consistent with the evaluated dataset
            # could use same dataset or refer to extract visual prompt embeddings
            dataloader = self.get_vpe_dataloader(data)
            vpe = self.get_visual_pe(dataloader, model)
            model.set_classes(names, vpe)
            stats = super().__call__(model=deepcopy(model))
        elif isinstance(model.model[-1], YOLOEDetect) and hasattr(model.model[-1], "lrpc"):  # prompt-free
            return super().__call__(trainer, model)
        else:
            LOGGER.info("Validate using the text prompt.")
            tpe = model.get_text_pe(names)
            model.set_classes(names, tpe)
            stats = super().__call__(model=deepcopy(model))
    return stats
```


method ultralytics.models.yolo.yoloe.val.YOLOEDetectValidator.get_visual_pe

```python
def get_visual_pe(self, dataloader: torch.utils.data.DataLoader, model: YOLOEModel) -> torch.Tensor
```

Extract visual prompt embeddings from training samples.

This method processes a dataloader to compute visual prompt embeddings for each class using a YOLOE model. It normalizes the embeddings and handles cases where no samples exist for a class by setting their embeddings to zero.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `dataloader` | `torch.utils.data.DataLoader` | The dataloader providing training samples. | required |
| `model` | `YOLOEModel` | The YOLOE model from which to extract visual prompt embeddings. | required |

Returns

| Type | Description |
| --- | --- |
| `torch.Tensor` | Visual prompt embeddings with shape (1, num_classes, embed_dim). |

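As a rough sketch of how the returned tensor is consumed (mirroring the flow inside `__call__`), the embeddings are registered on the model with `set_classes`; `vpe_loader` is assumed to be a prompt-aware dataloader such as the one produced by `get_vpe_dataloader` below:

```python
# Assumes `vpe_loader` yields batches containing "img", "cls", "batch_idx" and
# "visuals" keys, e.g. one built by get_vpe_dataloader (see next method).
vpe = validator.get_visual_pe(vpe_loader, model)  # shape (1, num_classes, embed_dim)

names = [n.split("/", 1)[0] for n in vpe_loader.dataset.data["names"].values()]
model.set_classes(names, vpe)  # attach the visual prompt embeddings to the model
```
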
Source code in ultralytics/models/yolo/yoloe/val.py

```python
@smart_inference_mode()
def get_visual_pe(self, dataloader: torch.utils.data.DataLoader, model: YOLOEModel) -> torch.Tensor:
    """Extract visual prompt embeddings from training samples.

    This method processes a dataloader to compute visual prompt embeddings for each class using a YOLOE model. It
    normalizes the embeddings and handles cases where no samples exist for a class by setting their embeddings to
    zero.

    Args:
        dataloader (torch.utils.data.DataLoader): The dataloader providing training samples.
        model (YOLOEModel): The YOLOE model from which to extract visual prompt embeddings.

    Returns:
        (torch.Tensor): Visual prompt embeddings with shape (1, num_classes, embed_dim).
    """
    assert isinstance(model, YOLOEModel)
    names = [name.split("/", 1)[0] for name in list(dataloader.dataset.data["names"].values())]
    visual_pe = torch.zeros(len(names), model.model[-1].embed, device=self.device)
    cls_visual_num = torch.zeros(len(names))

    desc = "Get visual prompt embeddings from samples"

    # Count samples per class
    for batch in dataloader:
        cls = batch["cls"].squeeze(-1).to(torch.int).unique()
        count = torch.bincount(cls, minlength=len(names))
        cls_visual_num += count

    cls_visual_num = cls_visual_num.to(self.device)

    # Extract visual prompt embeddings
    pbar = TQDM(dataloader, total=len(dataloader), desc=desc)
    for batch in pbar:
        batch = self.preprocess(batch)
        preds = model.get_visual_pe(batch["img"], visual=batch["visuals"])  # (B, max_n, embed_dim)

        batch_idx = batch["batch_idx"]
        for i in range(preds.shape[0]):
            cls = batch["cls"][batch_idx == i].squeeze(-1).to(torch.int).unique(sorted=True)
            pad_cls = torch.ones(preds.shape[1], device=self.device) * -1
            pad_cls[: cls.shape[0]] = cls
            for c in cls:
                visual_pe[c] += preds[i][pad_cls == c].sum(0) / cls_visual_num[c]

    # Normalize embeddings for classes with samples, set others to zero
    visual_pe[cls_visual_num != 0] = F.normalize(visual_pe[cls_visual_num != 0], dim=-1, p=2)
    visual_pe[cls_visual_num == 0] = 0
    return visual_pe.unsqueeze(0)
```


method ultralytics.models.yolo.yoloe.val.YOLOEDetectValidator.get_vpe_dataloader

```python
def get_vpe_dataloader(self, data: dict[str, Any]) -> torch.utils.data.DataLoader
```

Create a dataloader for LVIS training visual prompt samples.

This method prepares a dataloader for visual prompt embedding (VPE) extraction from the specified dataset. It applies the necessary transformations, including LoadVisualPrompt, and configures the dataset for validation.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `data` | `dict` | Dataset configuration dictionary containing paths and settings. | required |

Returns

| Type | Description |
| --- | --- |
| `torch.utils.data.DataLoader` | The dataloader for visual prompt samples. |

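A short sketch of building and inspecting the dataloader; the dataset YAML is an assumed placeholder, and the "visuals" key comes from the LoadVisualPrompt transform appended by this method:

```python
from ultralytics.data.utils import check_det_dataset

data = check_det_dataset("lvis.yaml")            # assumed dataset YAML
vpe_loader = validator.get_vpe_dataloader(data)  # LoadVisualPrompt appended to transforms

batch = next(iter(vpe_loader))
print(batch["img"].shape, batch["visuals"].shape)
```
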
Source code in ultralytics/models/yolo/yoloe/val.py

```python
def get_vpe_dataloader(self, data: dict[str, Any]) -> torch.utils.data.DataLoader:
    """Create a dataloader for LVIS training visual prompt samples.

    This method prepares a dataloader for visual prompt embeddings (VPE) using the specified dataset. It applies
    necessary transformations including LoadVisualPrompt and configurations to the dataset for validation purposes.

    Args:
        data (dict): Dataset configuration dictionary containing paths and settings.

    Returns:
        (torch.utils.data.DataLoader): The dataloader for visual prompt samples.
    """
    dataset = build_yolo_dataset(
        self.args,
        data.get(self.args.split, data.get("val")),
        self.args.batch,
        data,
        mode="val",
        rect=False,
    )
    if isinstance(dataset, YOLOConcatDataset):
        for d in dataset.datasets:
            d.transforms.append(LoadVisualPrompt())
    else:
        dataset.transforms.append(LoadVisualPrompt())
    return build_dataloader(
        dataset,
        self.args.batch,
        self.args.workers,
        shuffle=False,
        rank=-1,
    )
```





class ultralytics.models.yolo.yoloe.val.YOLOESegValidator

```python
YOLOESegValidator()
```

Bases: YOLOEDetectValidator, SegmentationValidator

YOLOE segmentation validator that supports both text and visual prompt embeddings.
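
No usage example is rendered for this class; a minimal sketch, assuming placeholder segmentation dataset and checkpoint names, mirrors the detection validator:

```python
from ultralytics.models.yolo.yoloe import YOLOESegValidator

# Assumed dataset YAML and checkpoint name; replace with your own paths.
seg_validator = YOLOESegValidator(args={"data": "coco128-seg.yaml", "imgsz": 640})
stats = seg_validator(model="yoloe-11s-seg.pt", load_vp=False)  # text-prompt validation
```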

Source code in ultralytics/models/yolo/yoloe/val.py

```python
class YOLOESegValidator(YOLOEDetectValidator, SegmentationValidator):
```




