Skip to content

Reference for ultralytics/models/fastsam/predict.py

Note

This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/fastsam/predict.py. If you spot a problem please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!


ultralytics.models.fastsam.predict.FastSAMPredictor

FastSAMPredictor(cfg=DEFAULT_CFG, overrides=None, _callbacks=None)

Bases: SegmentationPredictor

FastSAMPredictor is specialized for fast SAM (Segment Anything Model) segmentation prediction tasks in the Ultralytics YOLO framework.

This class extends the SegmentationPredictor, customizing the prediction pipeline specifically for fast SAM. It adjusts post-processing steps to incorporate mask prediction and non-max suppression while optimizing for single-class segmentation.

Source code in ultralytics/models/fastsam/predict.py
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
    """
    Create a FastSAMPredictor for fast SAM segmentation in the Ultralytics YOLO framework.

    Args:
        cfg: Predictor configuration, passed straight through to ``SegmentationPredictor``.
        overrides (optional): Configuration overrides, passed through to the parent initializer.
        _callbacks (optional): Callback definitions, passed through to the parent initializer.
    """
    super().__init__(cfg, overrides, _callbacks)
    # Pending prompts ("bboxes", "points", "labels", "texts"); set via set_prompts, popped by postprocess.
    self.prompts = dict()

postprocess

postprocess(preds, img, orig_imgs)

Applies box post-processing to FastSAM predictions.

Source code in ultralytics/models/fastsam/predict.py
def postprocess(self, preds, img, orig_imgs):
    """
    Apply box post-processing to FastSAM predictions, then filter them by any pending prompts.

    After the parent segmentation post-processing, every box whose IoU with the full-image box
    (computed after ``adjust_bboxes_to_image_border``) exceeds 0.9 is replaced by the exact
    full-image box. Prompts stored in ``self.prompts`` are popped here, so they apply to this
    call only.

    Args:
        preds: Raw model predictions, forwarded to the parent ``postprocess``.
        img: Preprocessed input batch, forwarded to the parent ``postprocess``.
        orig_imgs: Original images, forwarded to the parent ``postprocess``.

    Returns:
        The output of ``self.prompt`` applied to the post-processed results.
    """
    bboxes = self.prompts.pop("bboxes", None)
    points = self.prompts.pop("points", None)
    labels = self.prompts.pop("labels", None)
    texts = self.prompts.pop("texts", None)
    results = super().postprocess(preds, img, orig_imgs)
    for result in results:
        full_box = torch.tensor(
            [0, 0, result.orig_shape[1], result.orig_shape[0]], device=preds[0].device, dtype=torch.float32
        )
        boxes = adjust_bboxes_to_image_border(result.boxes.xyxy, result.orig_shape)
        # box_iou(full_box[None], boxes) is presumably a (1, N) pairwise IoU matrix. Flatten it BEFORE
        # nonzero: nonzero on a 2-D tensor returns (row, col) pairs, and flattening those would mix the
        # constant row index 0 into the box indices, wrongly overwriting box 0.
        iou = box_iou(full_box[None], boxes).flatten()
        idx = torch.nonzero(iou > 0.9).flatten()
        if idx.numel() != 0:
            result.boxes.xyxy[idx] = full_box

    return self.prompt(results, bboxes=bboxes, points=points, labels=labels, texts=texts)

prompt

prompt(results, bboxes=None, points=None, labels=None, texts=None)

Internal function that selects segmentation results based on prompts such as bounding boxes, points (with labels), and text. It provides prompt-based, real-time segmentation on top of FastSAM's prompt-free "segment everything" output.

Parameters:

Name Type Description Default
results Results | List[Results]

The original inference results from FastSAM models without any prompts.

required
bboxes ndarray | List

Bounding boxes with shape (N, 4), in XYXY format.

None
points ndarray | List

Points indicating object locations with shape (N, 2), in pixels.

None
labels ndarray | List

Labels for point prompts, shape (N, ). 1 = foreground, 0 = background.

None
texts str | List[str]

Textual prompts: either a single string or a list of strings.

None

Returns:

Type Description
List[Results]

The output results determined by prompts.

Source code in ultralytics/models/fastsam/predict.py
def prompt(self, results, bboxes=None, points=None, labels=None, texts=None):
    """
    Select the segmentation results that match box, point and/or text prompts.

    FastSAM first segments everything without prompts; this method keeps only the masks chosen by
    the prompts. Selections from the different prompt types are OR-ed together.

    Args:
        results (Results | List[Results]): The original inference results from FastSAM models without any prompts.
        bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format.
        points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixels.
        labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 1 = foreground, 0 = background.
            When omitted, every point is treated as foreground.
        texts (str | List[str], optional): Textual prompts, either a single string or a list of strings.

    Returns:
        (List[Results]): The output results determined by prompts. If no prompt is given,
            ``results`` is returned unchanged.

    Raises:
        AssertionError: If ``labels`` and ``points`` have different lengths.
    """
    if bboxes is None and points is None and texts is None:
        return results
    prompt_results = []
    if not isinstance(results, list):
        results = [results]
    for result in results:
        if len(result) == 0:
            prompt_results.append(result)
            continue
        masks = result.masks.data
        if masks.shape[1:] != result.orig_shape:
            # Prompt coordinates are in original-image pixels, so masks must match that resolution.
            masks = scale_masks(masks[None], result.orig_shape)[0]
        idx = torch.zeros(len(result), dtype=torch.bool, device=self.device)

        # bboxes prompt: for each box, select the mask with the highest IoU against that box.
        if bboxes is not None:
            bboxes = torch.as_tensor(bboxes, dtype=torch.int32, device=self.device)
            bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes
            bbox_areas = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0])
            mask_areas = torch.stack([masks[:, b[1] : b[3], b[0] : b[2]].sum(dim=(1, 2)) for b in bboxes])
            full_mask_areas = torch.sum(masks, dim=(1, 2))

            union = bbox_areas[:, None] + full_mask_areas - mask_areas
            idx[torch.argmax(mask_areas / union, dim=1)] = True

        # points prompt: positive points select the masks containing them, negative points deselect.
        if points is not None:
            points = torch.as_tensor(points, dtype=torch.int32, device=self.device)
            points = points[None] if points.ndim == 1 else points
            if labels is None:
                labels = torch.ones(points.shape[0])
            labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)
            assert len(labels) == len(points), (
                f"Expected `labels` got same size as `point`, but got {len(labels)} and {len(points)}"
            )
            point_idx = (
                torch.ones(len(result), dtype=torch.bool, device=self.device)
                if labels.sum() == 0  # all negative points: start from everything and carve out
                else torch.zeros(len(result), dtype=torch.bool, device=self.device)
            )
            for point, label in zip(points, labels):
                # points are (x, y), masks are indexed [n, y, x]
                point_idx[torch.nonzero(masks[:, point[1], point[0]], as_tuple=True)[0]] = bool(label)
            idx |= point_idx

        # texts prompt: crop each sufficiently large mask's box and let CLIP pick the best crop per text.
        if texts is not None:
            if isinstance(texts, str):
                texts = [texts]
            crop_ims, keep_idx = [], []
            for i, b in enumerate(result.boxes.xyxy.tolist()):
                x1, y1, x2, y2 = (int(x) for x in b)
                if masks[i].sum() <= 100:  # small masks are not sent to CLIP
                    continue
                crop_ims.append(Image.fromarray(result.orig_img[y1:y2, x1:x2, ::-1]))
                keep_idx.append(i)
            if crop_ims:  # nothing to compare against when every mask was filtered out
                similarity = self._clip_inference(crop_ims, texts)
                text_idx = torch.argmax(similarity, dim=-1)  # best crop position per text, (M, )
                # Map crop positions back to indices into `result` (crops skip the filtered masks).
                keep = torch.as_tensor(keep_idx, dtype=torch.long, device=self.device)
                idx[keep[text_idx.to(keep.device)]] = True

        prompt_results.append(result[idx])

    return prompt_results

set_prompts

set_prompts(prompts)

Set prompts in advance.

Source code in ultralytics/models/fastsam/predict.py
def set_prompts(self, prompts):
    """
    Set prompts in advance, to be applied by the next ``postprocess`` call.

    Args:
        prompts (dict | None): Mapping that may contain "bboxes", "points", "labels" and "texts"
            entries. ``None`` is treated as "no prompts".
    """
    # Keep a shallow copy: postprocess pops keys from self.prompts, which would otherwise mutate the
    # caller's dict. None becomes an empty dict so those later .pop calls do not fail.
    self.prompts = {} if prompts is None else dict(prompts)



📅 Created 1 year ago ✏️ Updated 4 months ago