
Reference for ultralytics/models/fastsam/predict.py

Improvements

This page is sourced from https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/fastsam/predict.py. Have an improvement or example to add? Open a Pull Request — thank you! 🙏


class ultralytics.models.fastsam.predict.FastSAMPredictor

FastSAMPredictor(cfg=DEFAULT_CFG, overrides=None, _callbacks=None)

Bases: SegmentationPredictor

FastSAMPredictor is specialized for FastSAM (Fast Segment Anything Model) segmentation prediction tasks.

This class extends SegmentationPredictor, customizing the prediction pipeline specifically for FastSAM: it adjusts post-processing to incorporate mask prediction and non-maximum suppression while optimizing for single-class segmentation, and supports prompting with bounding boxes, points, labels, and text.

Args

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `cfg` | `dict` | Configuration for the predictor. | `DEFAULT_CFG` |
| `overrides` | `dict, optional` | Configuration overrides. | `None` |
| `_callbacks` | `list, optional` | List of callback functions. | `None` |

Attributes

| Name | Type | Description |
| ---- | ---- | ----------- |
| `prompts` | `dict` | Dictionary containing prompt information for segmentation (bboxes, points, labels, texts). |
| `device` | `torch.device` | Device on which model and tensors are processed. |
| `clip_model` | `Any, optional` | CLIP model for text-based prompting, loaded on demand. |
| `clip_preprocess` | `Any, optional` | CLIP preprocessing function for images, loaded on demand. |

Methods

| Name | Description |
| ---- | ----------- |
| `_clip_inference` | Perform CLIP inference to calculate similarity between images and text prompts. |
| `postprocess` | Apply postprocessing to FastSAM predictions and handle prompts. |
| `prompt` | Perform image segmentation inference based on cues like bounding boxes, points, and text prompts. |
| `set_prompts` | Set prompts to be used during inference. |
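
FastSAMPredictor is normally driven through the high-level `FastSAM` model interface, which routes any prompts through this predictor. A minimal usage sketch (the weights filename and image path are assumptions):

```python
from ultralytics import FastSAM

# Load FastSAM weights (filename is an assumption; any FastSAM checkpoint works)
model = FastSAM("FastSAM-s.pt")

# Segment everything in the image, then keep only objects matching the prompts
results = model("ultralytics/assets/bus.jpg", bboxes=[[439, 437, 524, 709]])
results = model("ultralytics/assets/bus.jpg", texts="a photo of a bus")
```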
Source code in ultralytics/models/fastsam/predict.py

```python
class FastSAMPredictor(SegmentationPredictor):
    """FastSAMPredictor is specialized for fast SAM (Segment Anything Model) segmentation prediction tasks.

    This class extends the SegmentationPredictor, customizing the prediction pipeline specifically for fast SAM. It
    adjusts post-processing steps to incorporate mask prediction and non-maximum suppression while optimizing for
    single-class segmentation.

    Attributes:
        prompts (dict): Dictionary containing prompt information for segmentation (bboxes, points, labels, texts).
        device (torch.device): Device on which model and tensors are processed.
        clip_model (Any, optional): CLIP model for text-based prompting, loaded on demand.
        clip_preprocess (Any, optional): CLIP preprocessing function for images, loaded on demand.

    Methods:
        postprocess: Apply postprocessing to FastSAM predictions and handle prompts.
        prompt: Perform image segmentation inference based on various prompt types.
        set_prompts: Set prompts to be used during inference.
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """Initialize the FastSAMPredictor with configuration and callbacks.

        This initializes a predictor specialized for Fast SAM (Segment Anything Model) segmentation tasks. The predictor
        extends SegmentationPredictor with custom post-processing for mask prediction and non-maximum suppression
        optimized for single-class segmentation.

        Args:
            cfg (dict): Configuration for the predictor.
            overrides (dict, optional): Configuration overrides.
            _callbacks (list, optional): List of callback functions.
        """
        super().__init__(cfg, overrides, _callbacks)
        self.prompts = {}
```


method ultralytics.models.fastsam.predict.FastSAMPredictor._clip_inference

def _clip_inference(self, images, texts)

Perform CLIP inference to calculate similarity between images and text prompts.

Args

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `images` | `list[PIL.Image]` | List of source images, each should be a PIL.Image with RGB channel order. | required |
| `texts` | `list[str]` | List of prompt texts, each should be a string object. | required |

Returns

| Type | Description |
| ---- | ----------- |
| `torch.Tensor` | Similarity matrix between given images and texts with shape (M, N). |
Source code in ultralytics/models/fastsam/predict.py

```python
def _clip_inference(self, images, texts):
    """Perform CLIP inference to calculate similarity between images and text prompts.

    Args:
        images (list[PIL.Image]): List of source images, each should be PIL.Image with RGB channel order.
        texts (list[str]): List of prompt texts, each should be a string object.

    Returns:
        (torch.Tensor): Similarity matrix between given images and texts with shape (M, N).
    """
    from ultralytics.nn.text_model import CLIP

    if not hasattr(self, "clip"):
        self.clip = CLIP("ViT-B/32", device=self.device)
    images = torch.stack([self.clip.image_preprocess(image).to(self.device) for image in images])
    image_features = self.clip.encode_image(images)
    text_features = self.clip.encode_text(self.clip.tokenize(texts))
    return text_features @ image_features.T  # (M, N)
```
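
The similarity matrix has one row per text prompt and one column per image crop, so an `argmax` over the last dimension (as done in `prompt`) picks the best-matching crop for each text. A minimal shape sketch with random features (the embedding size is an assumption):

```python
import torch

M, N, D = 2, 3, 512  # 2 texts, 3 image crops, assumed embedding dim
text_features = torch.randn(M, D)
image_features = torch.randn(N, D)

similarity = text_features @ image_features.T  # (M, N)
best_crop_per_text = similarity.argmax(dim=-1)  # (M,) index of best crop per text
```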


method ultralytics.models.fastsam.predict.FastSAMPredictor.postprocess

def postprocess(self, preds, img, orig_imgs)

Apply postprocessing to FastSAM predictions and handle prompts.

Args

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `preds` | `list[torch.Tensor]` | Raw predictions from the model. | required |
| `img` | `torch.Tensor` | Input image tensor that was fed to the model. | required |
| `orig_imgs` | `list[np.ndarray]` | Original images before preprocessing. | required |

Returns

| Type | Description |
| ---- | ----------- |
| `list[Results]` | Processed results with prompts applied. |
Source code in ultralytics/models/fastsam/predict.py

```python
def postprocess(self, preds, img, orig_imgs):
    """Apply postprocessing to FastSAM predictions and handle prompts.

    Args:
        preds (list[torch.Tensor]): Raw predictions from the model.
        img (torch.Tensor): Input image tensor that was fed to the model.
        orig_imgs (list[np.ndarray]): Original images before preprocessing.

    Returns:
        (list[Results]): Processed results with prompts applied.
    """
    bboxes = self.prompts.pop("bboxes", None)
    points = self.prompts.pop("points", None)
    labels = self.prompts.pop("labels", None)
    texts = self.prompts.pop("texts", None)
    results = super().postprocess(preds, img, orig_imgs)
    for result in results:
        full_box = torch.tensor(
            [0, 0, result.orig_shape[1], result.orig_shape[0]], device=preds[0].device, dtype=torch.float32
        )
        boxes = adjust_bboxes_to_image_border(result.boxes.xyxy, result.orig_shape)
        idx = torch.nonzero(box_iou(full_box[None], boxes) > 0.9).flatten()
        if idx.numel() != 0:
            result.boxes.xyxy[idx] = full_box

    return self.prompt(results, bboxes=bboxes, points=points, labels=labels, texts=texts)
```
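
The loop above snaps near-full-frame detections to the exact image border: any box whose IoU with the whole-image box exceeds 0.9 is replaced by the full box. A standalone sketch of that test, reusing ultralytics' `box_iou` helper (the example boxes and image size are assumptions):

```python
import torch
from ultralytics.utils.metrics import box_iou

orig_h, orig_w = 480, 640
full_box = torch.tensor([[0.0, 0.0, orig_w, orig_h]])  # (1, 4) in xyxy format
boxes = torch.tensor([
    [2.0, 1.0, 638.0, 479.0],      # nearly full frame -> IoU > 0.9, snapped
    [100.0, 100.0, 200.0, 200.0],  # ordinary object -> kept as-is
])

iou = box_iou(full_box, boxes).squeeze(0)  # (2,) IoU of each box with the full box
boxes[iou > 0.9] = full_box
```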


method ultralytics.models.fastsam.predict.FastSAMPredictor.prompt

def prompt(self, results, bboxes=None, points=None, labels=None, texts=None)

Perform image segmentation inference based on cues like bounding boxes, points, and text prompts.

Args

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `results` | `Results \| list[Results]` | Original inference results from FastSAM models without any prompts. | required |
| `bboxes` | `np.ndarray \| list, optional` | Bounding boxes with shape (N, 4), in XYXY format. | `None` |
| `points` | `np.ndarray \| list, optional` | Points indicating object locations with shape (N, 2), in pixels. | `None` |
| `labels` | `np.ndarray \| list, optional` | Labels for point prompts, shape (N,). 1 = foreground, 0 = background. | `None` |
| `texts` | `str \| list[str], optional` | Textual prompts, a list containing string objects. | `None` |

Returns

| Type | Description |
| ---- | ----------- |
| `list[Results]` | Output results filtered and determined by the provided prompts. |
Source code in ultralytics/models/fastsam/predict.py

```python
def prompt(self, results, bboxes=None, points=None, labels=None, texts=None):
    """Perform image segmentation inference based on cues like bounding boxes, points, and text prompts.

    Args:
        results (Results | list[Results]): Original inference results from FastSAM models without any prompts.
        bboxes (np.ndarray | list, optional): Bounding boxes with shape (N, 4), in XYXY format.
        points (np.ndarray | list, optional): Points indicating object locations with shape (N, 2), in pixels.
        labels (np.ndarray | list, optional): Labels for point prompts, shape (N, ). 1 = foreground, 0 = background.
        texts (str | list[str], optional): Textual prompts, a list containing string objects.

    Returns:
        (list[Results]): Output results filtered and determined by the provided prompts.
    """
    if bboxes is None and points is None and texts is None:
        return results
    prompt_results = []
    if not isinstance(results, list):
        results = [results]
    for result in results:
        if len(result) == 0:
            prompt_results.append(result)
            continue
        masks = result.masks.data
        if masks.shape[1:] != result.orig_shape:
            masks = (scale_masks(masks[None].float(), result.orig_shape)[0] > 0.5).byte()
        # bboxes prompt
        idx = torch.zeros(len(result), dtype=torch.bool, device=self.device)
        if bboxes is not None:
            bboxes = torch.as_tensor(bboxes, dtype=torch.int32, device=self.device)
            bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes
            bbox_areas = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0])
            mask_areas = torch.stack([masks[:, b[1] : b[3], b[0] : b[2]].sum(dim=(1, 2)) for b in bboxes])
            full_mask_areas = torch.sum(masks, dim=(1, 2))

            union = bbox_areas[:, None] + full_mask_areas - mask_areas
            idx[torch.argmax(mask_areas / union, dim=1)] = True
        if points is not None:
            points = torch.as_tensor(points, dtype=torch.int32, device=self.device)
            points = points[None] if points.ndim == 1 else points
            if labels is None:
                labels = torch.ones(points.shape[0])
            labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)
            assert len(labels) == len(points), (
                f"Expected `labels` with same size as `points`, but got {len(labels)} and {len(points)}"
            )
            point_idx = (
                torch.ones(len(result), dtype=torch.bool, device=self.device)
                if labels.sum() == 0  # all negative points
                else torch.zeros(len(result), dtype=torch.bool, device=self.device)
            )
            for point, label in zip(points, labels):
                point_idx[torch.nonzero(masks[:, point[1], point[0]], as_tuple=True)[0]] = bool(label)
            idx |= point_idx
        if texts is not None:
            if isinstance(texts, str):
                texts = [texts]
            crop_ims, filter_idx = [], []
            for i, b in enumerate(result.boxes.xyxy.tolist()):
                x1, y1, x2, y2 = (int(x) for x in b)
                if (masks[i].sum() if TORCH_1_10 else masks[i].sum(0).sum()) <= 100:  # torch 1.9 bug workaround
                    filter_idx.append(i)
                    continue
                crop_ims.append(Image.fromarray(result.orig_img[y1:y2, x1:x2, ::-1]))
            similarity = self._clip_inference(crop_ims, texts)
            text_idx = torch.argmax(similarity, dim=-1)  # (M, )
            if len(filter_idx):
                text_idx += (torch.tensor(filter_idx, device=self.device)[None] <= int(text_idx)).sum(0)
            idx[text_idx] = True

        prompt_results.append(result[idx])

    return prompt_results
```
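
A usage sketch driving the predictor directly: run one unprompted "segment everything" pass, then filter the same results with different prompt types (the weights path and image are assumptions):

```python
from ultralytics.models.fastsam import FastSAMPredictor

# Build a predictor for segmentation inference (weights path is an assumption)
overrides = dict(conf=0.25, task="segment", mode="predict", model="FastSAM-s.pt", save=False, imgsz=1024)
predictor = FastSAMPredictor(overrides=overrides)

# Segment everything once, then filter by prompt
everything_results = predictor("ultralytics/assets/bus.jpg")
bbox_results = predictor.prompt(everything_results, bboxes=[[200, 200, 300, 300]])
point_results = predictor.prompt(everything_results, points=[200, 200], labels=[1])
text_results = predictor.prompt(everything_results, texts="a photo of a dog")
```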


method ultralytics.models.fastsam.predict.FastSAMPredictor.set_prompts

def set_prompts(self, prompts)

Set prompts to be used during inference.

Args

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `prompts` | `dict` | Dictionary of prompt information for segmentation (bboxes, points, labels, texts). | required |
Source code in ultralytics/models/fastsam/predict.py

```python
def set_prompts(self, prompts):
    """Set prompts to be used during inference."""
    self.prompts = prompts
```
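
Prompts set this way are consumed (popped) in `postprocess`, so they apply to the next call only and must be set again before each prediction. A minimal sketch (predictor as constructed above; the image path is an assumption):

```python
predictor.set_prompts({"texts": "a photo of a dog"})
results = predictor("path/to/image.jpg")  # prompts are applied, then cleared
```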




