
Reference for ultralytics/models/yolo/yoloe/predict.py

Improvements

This page is sourced from https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/yoloe/predict.py. Have an improvement or example to add? Open a Pull Request — thank you! 🙏


class ultralytics.models.yolo.yoloe.predict.YOLOEVPDetectPredictor

YOLOEVPDetectPredictor()

Bases: DetectionPredictor

A mixin class for YOLO-EVP (Enhanced Visual Prompting) predictors.

This mixin provides common functionality for YOLO models that use visual prompting, including model setup, prompt handling, and preprocessing transformations.

Attributes

| Name | Type | Description |
| --- | --- | --- |
| `model` | `torch.nn.Module` | The YOLO model for inference. |
| `device` | `torch.device` | Device to run the model on (CPU or CUDA). |
| `prompts` | `dict \| torch.Tensor` | Visual prompts containing class indices and bounding boxes or masks. |

Methods

| Name | Description |
| --- | --- |
| `_process_single_image` | Process a single image by resizing bounding boxes or masks and generating visuals. |
| `get_vpe` | Process the source to get the visual prompt embeddings (VPE). |
| `inference` | Run inference with visual prompts. |
| `pre_transform` | Preprocess images and prompts before inference. |
| `set_prompts` | Set the visual prompts for the model. |
| `setup_model` | Set up the model for prediction. |

Source code in ultralytics/models/yolo/yoloe/predict.py
class YOLOEVPDetectPredictor(DetectionPredictor):


method ultralytics.models.yolo.yoloe.predict.YOLOEVPDetectPredictor._process_single_image

def _process_single_image(self, dst_shape, src_shape, category, bboxes=None, masks=None)

Process a single image by resizing bounding boxes or masks and generating visuals.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `dst_shape` | `tuple` | The target shape (height, width) of the image. | required |
| `src_shape` | `tuple` | The original shape (height, width) of the image. | required |
| `category` | `str` | The category of the image for visual prompts. | required |
| `bboxes` | `list \| np.ndarray`, optional | A list of bounding boxes in the format [x1, y1, x2, y2]. | `None` |
| `masks` | `np.ndarray`, optional | A list of masks corresponding to the image. | `None` |

Returns

| Type | Description |
| --- | --- |
| `torch.Tensor` | The processed visuals for the image. |

Raises

| Type | Description |
| --- | --- |
| `ValueError` | If neither `bboxes` nor `masks` are provided. |

Source code in ultralytics/models/yolo/yoloe/predict.py
def _process_single_image(self, dst_shape, src_shape, category, bboxes=None, masks=None):
    """Process a single image by resizing bounding boxes or masks and generating visuals.

    Args:
        dst_shape (tuple): The target shape (height, width) of the image.
        src_shape (tuple): The original shape (height, width) of the image.
        category (str): The category of the image for visual prompts.
        bboxes (list | np.ndarray, optional): A list of bounding boxes in the format [x1, y1, x2, y2].
        masks (np.ndarray, optional): A list of masks corresponding to the image.

    Returns:
        (torch.Tensor): The processed visuals for the image.

    Raises:
        ValueError: If neither `bboxes` nor `masks` are provided.
    """
    if bboxes is not None and len(bboxes):
        bboxes = np.array(bboxes, dtype=np.float32)
        if bboxes.ndim == 1:
            bboxes = bboxes[None, :]
        # Calculate scaling factor and adjust bounding boxes
        gain = min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1])  # gain = old / new
        bboxes *= gain
        bboxes[..., 0::2] += round((dst_shape[1] - src_shape[1] * gain) / 2 - 0.1)
        bboxes[..., 1::2] += round((dst_shape[0] - src_shape[0] * gain) / 2 - 0.1)
    elif masks is not None:
        # Resize and process masks
        resized_masks = super().pre_transform(masks)
        masks = np.stack(resized_masks)  # (N, H, W)
        masks[masks == 114] = 0  # Reset padding values to 0
    else:
        raise ValueError("Please provide valid bboxes or masks")

    # Generate visuals using the visual prompt loader
    return LoadVisualPrompt().get_visuals(category, dst_shape, bboxes, masks)
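The bbox branch above implements standard letterbox rescaling: boxes are scaled by the resize gain, then shifted by the horizontal and vertical padding offsets of the letterboxed image. A standalone numpy sketch of that arithmetic (hypothetical helper name; shapes chosen for illustration):

```python
import numpy as np

def scale_bboxes_to_letterbox(bboxes, src_shape, dst_shape):
    """Rescale [x1, y1, x2, y2] boxes from src_shape (h, w) into a
    letterboxed dst_shape (h, w), mirroring the bbox branch above."""
    bboxes = np.array(bboxes, dtype=np.float32)
    if bboxes.ndim == 1:
        bboxes = bboxes[None, :]  # promote a single box to shape (1, 4)
    gain = min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1])
    bboxes *= gain
    # Shift x coords by the left pad and y coords by the top pad
    # (the -0.1 matches the rounding convention used in the source above)
    bboxes[..., 0::2] += round((dst_shape[1] - src_shape[1] * gain) / 2 - 0.1)
    bboxes[..., 1::2] += round((dst_shape[0] - src_shape[0] * gain) / 2 - 0.1)
    return bboxes

# A 1280x720 image letterboxed into 640x640: gain = 0.5, so the resized
# content is 640x360 and 140 px of vertical padding is added on each side.
boxes = scale_bboxes_to_letterbox([100, 100, 300, 200], (720, 1280), (640, 640))
```

Here the box is halved by the gain and its y coordinates are shifted down by the 140-pixel top pad, while x coordinates need no shift because the width fills the target exactly.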


method ultralytics.models.yolo.yoloe.predict.YOLOEVPDetectPredictor.get_vpe

def get_vpe(self, source)

Process the source to get the visual prompt embeddings (VPE).

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `source` | `str \| Path \| int \| PIL.Image \| np.ndarray \| torch.Tensor \| list \| tuple` | The source of the image to make predictions on. Accepts various types including file paths, URLs, PIL images, numpy arrays, and torch tensors. | required |

Returns

| Type | Description |
| --- | --- |
| `torch.Tensor` | The visual prompt embeddings (VPE) from the model. |

Source code in ultralytics/models/yolo/yoloe/predict.py
def get_vpe(self, source):
    """Process the source to get the visual prompt embeddings (VPE).

    Args:
        source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | list | tuple): The source of the image to
            make predictions on. Accepts various types including file paths, URLs, PIL images, numpy arrays, and
            torch tensors.

    Returns:
        (torch.Tensor): The visual prompt embeddings (VPE) from the model.
    """
    self.setup_source(source)
    assert len(self.dataset) == 1, "get_vpe only supports one image!"
    for _, im0s, _ in self.dataset:
        im = self.preprocess(im0s)
        return self.model(im, vpe=self.prompts, return_vpe=True)


method ultralytics.models.yolo.yoloe.predict.YOLOEVPDetectPredictor.inference

def inference(self, im, *args, **kwargs)

Run inference with visual prompts.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `im` | `torch.Tensor` | Input image tensor. | required |
| `*args` | `Any` | Variable length argument list. | required |
| `**kwargs` | `Any` | Arbitrary keyword arguments. | required |

Returns

| Type | Description |
| --- | --- |
| `torch.Tensor` | Model prediction results. |

Source code in ultralytics/models/yolo/yoloe/predict.py
def inference(self, im, *args, **kwargs):
    """Run inference with visual prompts.

    Args:
        im (torch.Tensor): Input image tensor.
        *args (Any): Variable length argument list.
        **kwargs (Any): Arbitrary keyword arguments.

    Returns:
        (torch.Tensor): Model prediction results.
    """
    return super().inference(im, vpe=self.prompts, *args, **kwargs)


method ultralytics.models.yolo.yoloe.predict.YOLOEVPDetectPredictor.pre_transform

def pre_transform(self, im)

Preprocess images and prompts before inference.

This method applies letterboxing to the input image and transforms the visual prompts (bounding boxes or masks) accordingly.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `im` | `list` | List containing a single input image. | required |

Returns

| Type | Description |
| --- | --- |
| `list` | Preprocessed image ready for model inference. |

Raises

| Type | Description |
| --- | --- |
| `ValueError` | If neither valid bounding boxes nor masks are provided in the prompts. |

Source code in ultralytics/models/yolo/yoloe/predict.py
def pre_transform(self, im):
    """Preprocess images and prompts before inference.

    This method applies letterboxing to the input image and transforms the visual prompts (bounding boxes or masks)
    accordingly.

    Args:
        im (list): List containing a single input image.

    Returns:
        (list): Preprocessed image ready for model inference.

    Raises:
        ValueError: If neither valid bounding boxes nor masks are provided in the prompts.
    """
    img = super().pre_transform(im)
    bboxes = self.prompts.pop("bboxes", None)
    masks = self.prompts.pop("masks", None)
    category = self.prompts["cls"]
    if len(img) == 1:
        visuals = self._process_single_image(img[0].shape[:2], im[0].shape[:2], category, bboxes, masks)
        prompts = visuals.unsqueeze(0).to(self.device)  # (1, N, H, W)
    else:
        # NOTE: only supports bboxes as prompts for now
        assert bboxes is not None, f"Expected bboxes, but got {bboxes}!"
        # NOTE: needs list[np.ndarray]
        assert isinstance(bboxes, list) and all(isinstance(b, np.ndarray) for b in bboxes), (
            f"Expected list[np.ndarray], but got {bboxes}!"
        )
        assert isinstance(category, list) and all(isinstance(b, np.ndarray) for b in category), (
            f"Expected list[np.ndarray], but got {category}!"
        )
        assert len(im) == len(category) == len(bboxes), (
            f"Expected same length for all inputs, but got {len(im)}vs{len(category)}vs{len(bboxes)}!"
        )
        visuals = [
            self._process_single_image(img[i].shape[:2], im[i].shape[:2], category[i], bboxes[i])
            for i in range(len(img))
        ]
        prompts = torch.nn.utils.rnn.pad_sequence(visuals, batch_first=True).to(self.device)  # (B, N, H, W)
    self.prompts = prompts.half() if self.model.fp16 else prompts.float()
    return img
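In the batched branch, each image can carry a different number of prompts, so the per-image visual tensors are zero-padded to a common length before stacking. A numpy sketch of that padding step, mimicking what `torch.nn.utils.rnn.pad_sequence(..., batch_first=True)` does here (hypothetical helper name; illustrative shapes):

```python
import numpy as np

def pad_and_stack(visuals):
    """Zero-pad a list of (N_i, H, W) arrays to a common N_max and
    stack them into a single (B, N_max, H, W) batch."""
    n_max = max(v.shape[0] for v in visuals)
    out = np.zeros((len(visuals), n_max) + visuals[0].shape[1:], dtype=visuals[0].dtype)
    for i, v in enumerate(visuals):
        out[i, : v.shape[0]] = v  # copy real prompts; the tail stays zero
    return out

# Image 0 has 2 prompt masks, image 1 has 3; both pad to N_max = 3
batch = pad_and_stack([np.ones((2, 4, 4)), np.ones((3, 4, 4))])
```

The padded entries contribute no activation, so downstream prompt embedding can safely ignore them.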


method ultralytics.models.yolo.yoloe.predict.YOLOEVPDetectPredictor.set_prompts

def set_prompts(self, prompts)

Set the visual prompts for the model.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `prompts` | `dict` | Dictionary containing class indices and bounding boxes or masks. Must include a 'cls' key with class indices. | required |

Source code in ultralytics/models/yolo/yoloe/predict.py
def set_prompts(self, prompts):
    """Set the visual prompts for the model.

    Args:
        prompts (dict): Dictionary containing class indices and bounding boxes or masks. Must include a 'cls' key
            with class indices.
    """
    self.prompts = prompts
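The prompts dictionary should carry the same keys that `pre_transform` later consumes: a `cls` entry with class indices plus either `bboxes` or `masks`. A hypothetical single-image prompt dict, with illustrative values, might look like:

```python
import numpy as np

prompts = {
    "cls": np.array([0, 1]),  # one class index per visual prompt
    "bboxes": np.array(
        [[48, 60, 200, 310],      # [x1, y1, x2, y2] for prompt 0
         [400, 120, 560, 300]],   # [x1, y1, x2, y2] for prompt 1
        dtype=np.float32,
    ),
}
# predictor.set_prompts(prompts) would store this for the next predict() call
```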


method ultralytics.models.yolo.yoloe.predict.YOLOEVPDetectPredictor.setup_model

def setup_model(self, model, verbose: bool = True)

Set up the model for prediction.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `torch.nn.Module` | Model to load or use. | required |
| `verbose` | `bool`, optional | If True, provides detailed logging. | `True` |

Source code in ultralytics/models/yolo/yoloe/predict.py
def setup_model(self, model, verbose: bool = True):
    """Set up the model for prediction.

    Args:
        model (torch.nn.Module): Model to load or use.
        verbose (bool, optional): If True, provides detailed logging.
    """
    super().setup_model(model, verbose=verbose)
    self.done_warmup = True





class ultralytics.models.yolo.yoloe.predict.YOLOEVPSegPredictor

YOLOEVPSegPredictor()

Bases: YOLOEVPDetectPredictor, SegmentationPredictor

Predictor for YOLO-EVP segmentation tasks combining detection and segmentation capabilities.

Source code in ultralytics/models/yolo/yoloe/predict.py
class YOLOEVPSegPredictor(YOLOEVPDetectPredictor, SegmentationPredictor):
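Listing `YOLOEVPDetectPredictor` before `SegmentationPredictor` means Python's method resolution order tries the visual-prompt overrides first while still inheriting segmentation behavior through cooperative `super()` calls. A toy sketch of how that resolution order composes (hypothetical class names standing in for the real ones):

```python
class Base:  # plays the role of the shared predictor base
    def preprocess(self):
        return "base"

class PromptMixin(Base):  # plays the role of YOLOEVPDetectPredictor
    def preprocess(self):
        return "prompt+" + super().preprocess()

class Segmenter(Base):  # plays the role of SegmentationPredictor
    def preprocess(self):
        return "seg+" + super().preprocess()

class PromptSegmenter(PromptMixin, Segmenter):  # mirrors YOLOEVPSegPredictor
    pass

# MRO: PromptSegmenter -> PromptMixin -> Segmenter -> Base, so the
# prompt override runs first and chains into the segmentation override
result = PromptSegmenter().preprocess()
```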





📅 Created 8 months ago · ✏️ Updated 9 days ago · Authors: glenn-jocher, RizwanMunawar