Skip to content

Reference for ultralytics/models/yolo/yoloe/val.py

Note

This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/yoloe/val.py. If you spot a problem please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!


ultralytics.models.yolo.yoloe.val.YOLOEDetectValidator

YOLOEDetectValidator(
    dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None
)

Bases: DetectionValidator

A validator class for YOLOE models that handles both text and visual prompt embeddings.

This class provides functionality to validate YOLOE models using either text or visual prompt embeddings. It includes methods for extracting visual prompt embeddings from samples, preprocessing batches, and running validation with different prompt types. It is also combined with SegmentationValidator to form YOLOESegValidator.

Attributes:

- device (torch.device): The device on which validation is performed.
- args (namespace): Configuration arguments for validation.
- dataloader (DataLoader): DataLoader for validation data.

Source code in ultralytics/models/yolo/detect/val.py (constructor inherited from DetectionValidator)
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
    """
    Set up detection-specific validator state on top of the base validator.

    Args:
        dataloader (torch.utils.data.DataLoader, optional): Dataloader used for validation.
        save_dir (Path, optional): Directory where results are saved.
        pbar (Any, optional): Progress bar used to display progress.
        args (dict, optional): Arguments for the validator.
        _callbacks (list, optional): Callback functions to register.
    """
    super().__init__(dataloader, save_dir, pbar, args, _callbacks)
    # Target-count buffers (presumably per class / per image) start unset.
    self.nt_per_class = self.nt_per_image = None
    # Dataset-type flags and optional class-index remapping start disabled.
    self.is_coco = self.is_lvis = False
    self.class_map = None
    self.args.task = "detect"
    self.metrics = DetMetrics(save_dir=self.save_dir)
    # Ten evenly spaced IoU thresholds from 0.50 to 0.95, used for mAP@0.5:0.95.
    iou_thresholds = torch.linspace(0.5, 0.95, 10)
    self.iouv = iou_thresholds
    self.niou = iou_thresholds.numel()

__call__

__call__(trainer=None, model=None, refer_data=None, load_vp=False)

Run validation on the model using either text or visual prompt embeddings.

This method validates the model using either text prompts or visual prompts, depending on the load_vp flag. It supports validation during training (using a trainer object) or standalone validation with a provided model.

Parameters:

- trainer (object, optional): Trainer object containing the model and device. Default: None.
- model (YOLOEModel | str, optional): Model to validate, or a path to its weights. Required if trainer is not provided. Default: None.
- refer_data (str, optional): Path to reference data for visual prompts. Default: None.
- load_vp (bool): Whether to load visual prompts. If False, text prompts are used. Default: False.

Returns:

- dict: Validation statistics containing metrics computed during validation.

Source code in ultralytics/models/yolo/yoloe/val.py
@smart_inference_mode()
def __call__(self, trainer=None, model=None, refer_data=None, load_vp=False):
    """
    Run validation on the model using either text or visual prompt embeddings.

    This method validates the model using either text prompts or visual prompts, depending
    on the `load_vp` flag. It supports validation during training (using a trainer object)
    or standalone validation with a provided model.

    Args:
        trainer (object, optional): Trainer object containing the model and device. When given,
            the trainer's EMA model is validated and `model` / `refer_data` are ignored.
        model (YOLOEModel | str | os.PathLike, optional): Model to validate, or a path to its
            weights (str or path-like, e.g. `pathlib.Path`). Required if `trainer` is not provided.
        refer_data (str, optional): Path to reference data for visual prompts. Only allowed
            together with `load_vp=True`.
        load_vp (bool): Whether to load visual prompts. If False, text prompts are used.

    Returns:
        (dict): Validation statistics containing metrics computed during validation.

    Raises:
        AssertionError: If `refer_data` is given without `load_vp` in standalone mode.
    """
    import os  # local import: only needed to recognise path-like weights below

    if trainer is not None:
        self.device = trainer.device
        model = trainer.ema.ema
        # Keep only the text before the first "/" of each class name (presumably alias lists — confirm)
        names = [name.split("/")[0] for name in list(self.dataloader.dataset.data["names"].values())]

        if load_vp:
            LOGGER.info("Validate using the visual prompt.")
            self.args.half = False
            # Directly use the same dataloader for visual embeddings extracted during training
            vpe = self.get_visual_pe(self.dataloader, model)
            model.set_classes(names, vpe)
        else:
            LOGGER.info("Validate using the text prompt.")
            tpe = model.get_text_pe(names)
            model.set_classes(names, tpe)
        stats = super().__call__(trainer, model)
    else:
        if refer_data is not None:
            assert load_vp, "Refer data is only used for visual prompt validation."
        self.device = select_device(self.args.device)

        # Accept weights given as str or any path-like object (e.g. pathlib.Path); normalize to a
        # plain string so the loader sees exactly what it would for a str argument.
        if isinstance(model, (str, os.PathLike)):
            from ultralytics.nn.tasks import attempt_load_weights

            model = attempt_load_weights(os.fspath(model), device=self.device, inplace=True)
        model.eval().to(self.device)
        data = check_det_dataset(refer_data or self.args.data)
        names = [name.split("/")[0] for name in list(data["names"].values())]

        if load_vp:
            LOGGER.info("Validate using the visual prompt.")
            self.args.half = False
            # TODO: need to check if the names from refer data is consistent with the evaluated dataset
            # could use same dataset or refer to extract visual prompt embeddings
            dataloader = self.get_vpe_dataloader(data)
            vpe = self.get_visual_pe(dataloader, model)
            model.set_classes(names, vpe)
            stats = super().__call__(model=deepcopy(model))
        elif isinstance(model.model[-1], YOLOEDetect) and hasattr(model.model[-1], "lrpc"):  # prompt-free
            return super().__call__(trainer, model)
        else:
            LOGGER.info("Validate using the text prompt.")
            tpe = model.get_text_pe(names)
            model.set_classes(names, tpe)
            stats = super().__call__(model=deepcopy(model))
    return stats

get_visual_pe

get_visual_pe(dataloader, model)

Extract visual prompt embeddings from training samples.

This function processes a dataloader to compute visual prompt embeddings for each class using a YOLOE model. It normalizes the embeddings and handles cases where no samples exist for a class.

Parameters:

- dataloader (DataLoader): The dataloader providing training samples. Required.
- model (YOLOEModel): The YOLOE model from which to extract visual prompt embeddings. Required.

Returns:

- Tensor: Visual prompt embeddings with shape (1, num_classes, embed_dim).

Source code in ultralytics/models/yolo/yoloe/val.py
@smart_inference_mode()
def get_visual_pe(self, dataloader, model):
    """
    Extract visual prompt embeddings from samples.

    Runs `model.get_visual_pe` over every batch of `dataloader` in a single pass. For each class it
    sums the predicted embeddings, averages them over the number of images that contain the class,
    and L2-normalizes the result. Classes that never appear get an all-zero embedding.

    Args:
        dataloader (torch.utils.data.DataLoader): The dataloader providing samples; each batch must
            provide "img", "visuals", "cls" and "batch_idx".
        model (YOLOEModel): The YOLOE model from which to extract visual prompt embeddings.

    Returns:
        (torch.Tensor): Visual prompt embeddings with shape (1, num_classes, embed_dim).

    Raises:
        AssertionError: If `model` is not a `YOLOEModel`.
    """
    assert isinstance(model, YOLOEModel)
    names = [name.split("/")[0] for name in list(dataloader.dataset.data["names"].values())]
    visual_pe = torch.zeros(len(names), model.model[-1].embed, device=self.device)
    # Number of images in which each class appears. Counting in the same pass as the embedding
    # extraction avoids iterating (and decoding) the whole dataloader a second time.
    cls_visual_num = torch.zeros(len(names), device=self.device)

    desc = "Get visual prompt embeddings from samples"
    pbar = TQDM(dataloader, total=len(dataloader), desc=desc)
    for batch in pbar:
        batch = self.preprocess(batch)
        preds = model.get_visual_pe(batch["img"], visual=batch["visuals"])  # (B, max_n, embed_dim)

        batch_idx = batch["batch_idx"]
        for i in range(preds.shape[0]):
            cls = batch["cls"][batch_idx == i].squeeze(-1).to(torch.int).unique(sorted=True)
            # Map prompt slots to class ids, marking unused slots with -1. This assumes the visual
            # prompts of image i are ordered by sorted class id — TODO confirm against LoadVisualPrompt.
            pad_cls = torch.ones(preds.shape[1], device=self.device) * -1
            pad_cls[: len(cls)] = cls
            for c in cls:
                visual_pe[c] += preds[i][pad_cls == c].sum(0)
                cls_visual_num[c] += 1

    has_samples = cls_visual_num != 0
    # Mean over the contributing images, then unit-normalize each class embedding.
    mean_pe = visual_pe[has_samples] / cls_visual_num[has_samples].unsqueeze(-1)
    visual_pe[has_samples] = F.normalize(mean_pe, dim=-1, p=2)
    visual_pe[~has_samples] = 0
    return visual_pe.unsqueeze(0)

get_vpe_dataloader

get_vpe_dataloader(data)

Create a dataloader for visual prompt samples (for example, LVIS training images).

This function prepares a dataloader for visual prompt embeddings (VPE) from the dataset described by `data`, such as LVIS. It builds the dataset in validation mode, appends a visual-prompt loading transform, and returns a non-shuffled dataloader for validation.

Parameters:

- data (dict): Dataset configuration dictionary containing paths and settings. Required.

Returns:

- DataLoader: The DataLoader for visual prompt samples.

Source code in ultralytics/models/yolo/yoloe/val.py
def get_vpe_dataloader(self, data):
    """
    Build a dataloader that yields samples together with their visual prompts.

    The dataset is built from the entry of `data` named by `self.args.split`, falling back to the
    "val" entry. It is built in "val" mode with `rect=False`, and a `LoadVisualPrompt` transform is
    appended to it, or to every member of a `YOLOConcatDataset`.

    Args:
        data (dict): Dataset configuration dictionary containing paths and settings.

    Returns:
        (torch.utils.data.DataLoader): Non-shuffled dataloader over the visual prompt samples.
    """
    split_path = data.get(self.args.split, data.get("val"))
    dataset = build_yolo_dataset(self.args, split_path, self.args.batch, data, mode="val", rect=False)

    # A concatenated dataset needs the transform on each of its parts.
    parts = dataset.datasets if isinstance(dataset, YOLOConcatDataset) else [dataset]
    for part in parts:
        part.transforms.append(LoadVisualPrompt())

    return build_dataloader(dataset, self.args.batch, self.args.workers, shuffle=False, rank=-1)

preprocess

preprocess(batch)

Preprocess batch data, ensuring visuals are on the same device as images.

Source code in ultralytics/models/yolo/yoloe/val.py
def preprocess(self, batch):
    """Apply the parent preprocessing, then move any visual prompts onto the image tensor's device."""
    processed = super().preprocess(batch)
    if "visuals" not in processed:
        return processed
    target_device = processed["img"].device
    processed["visuals"] = processed["visuals"].to(target_device)
    return processed





ultralytics.models.yolo.yoloe.val.YOLOESegValidator

YOLOESegValidator(
    dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None
)

Bases: YOLOEDetectValidator, SegmentationValidator

YOLOE segmentation validator that supports both text and visual prompt embeddings.

Source code in ultralytics/models/yolo/segment/val.py (constructor inherited from SegmentationValidator)
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
    """
    Set up the segmentation validator: task 'segment' with SegmentMetrics.

    Args:
        dataloader (torch.utils.data.DataLoader, optional): Dataloader used for validation.
        save_dir (Path, optional): Directory where results are saved.
        pbar (Any, optional): Progress bar used to display progress.
        args (namespace, optional): Arguments for the validator.
        _callbacks (list, optional): Callback functions to register.
    """
    super().__init__(dataloader, save_dir, pbar, args, _callbacks)
    # Mask-plotting buffer and mask-processing function both start unset.
    self.plot_masks = self.process = None
    self.args.task = "segment"
    self.metrics = SegmentMetrics(save_dir=self.save_dir)



📅 Created 13 days ago ✏️ Updated 13 days ago