
Reference for ultralytics/models/fastsam/predict.py

Note

This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/fastsam/predict.py. If you spot a problem, please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!


ultralytics.models.fastsam.predict.FastSAMPredictor

FastSAMPredictor(cfg=DEFAULT_CFG, overrides=None, _callbacks=None)

Bases: SegmentationPredictor

FastSAMPredictor is specialized for FastSAM (Fast Segment Anything Model) segmentation prediction tasks.

This class extends SegmentationPredictor, customizing the prediction pipeline specifically for FastSAM. It adjusts post-processing steps to incorporate mask prediction and non-maximum suppression while optimizing for single-class segmentation.
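A minimal usage sketch through the high-level FastSAM model API (the image path is illustrative; the prompt keyword arguments are forwarded to this predictor):

from ultralytics import FastSAM

# Load a FastSAM model (weights are downloaded on first use)
model = FastSAM("FastSAM-s.pt")

# Segment everything in the image
everything_results = model("path/to/image.jpg", retina_masks=True, imgsz=1024, conf=0.4, iou=0.9)

# Prompted inference: bboxes/points/labels/texts are routed to FastSAMPredictor
bbox_results = model("path/to/image.jpg", bboxes=[439, 437, 524, 709])
point_results = model("path/to/image.jpg", points=[[200, 200]], labels=[1])
text_results = model("path/to/image.jpg", texts="a photo of a dog")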

Attributes:

    prompts (dict): Dictionary containing prompt information for segmentation (bboxes, points, labels, texts).
    device (torch.device): Device on which the model and tensors are processed.
    clip_model (Any): CLIP model for text-based prompting, loaded on demand.
    clip_preprocess (Any): CLIP preprocessing function for images, loaded on demand.

Methods:

    postprocess: Applies box postprocessing for FastSAM predictions.
    prompt: Performs image segmentation inference based on various prompt types.
    _clip_inference: Performs CLIP inference to calculate similarity between images and text prompts.
    set_prompts: Sets prompts to be used during inference.

This initializes a predictor specialized for FastSAM (Fast Segment Anything Model) segmentation tasks. The predictor extends SegmentationPredictor with custom post-processing for mask prediction and non-maximum suppression optimized for single-class segmentation.

Parameters:

    cfg (dict): Configuration for the predictor. Default: DEFAULT_CFG.
    overrides (dict, optional): Configuration overrides. Default: None.
    _callbacks (list, optional): List of callback functions. Default: None.
Source code in ultralytics/models/fastsam/predict.py
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
    """
    Initialize the FastSAMPredictor with configuration and callbacks.

    This initializes a predictor specialized for Fast SAM (Segment Anything Model) segmentation tasks. The predictor
    extends SegmentationPredictor with custom post-processing for mask prediction and non-maximum suppression
    optimized for single-class segmentation.

    Args:
        cfg (dict): Configuration for the predictor. Defaults to Ultralytics DEFAULT_CFG.
        overrides (dict, optional): Configuration overrides.
        _callbacks (list, optional): List of callback functions.
    """
    super().__init__(cfg, overrides, _callbacks)
    self.prompts = {}
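The predictor can also be constructed directly; a short sketch (the overrides shown are illustrative, and any standard predictor override is accepted):

from ultralytics.models.fastsam import FastSAMPredictor

# Build the predictor with illustrative overrides; prompts start out empty
predictor = FastSAMPredictor(overrides={"model": "FastSAM-s.pt", "imgsz": 1024, "conf": 0.4})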

postprocess

postprocess(preds, img, orig_imgs)

Apply postprocessing to FastSAM predictions and handle prompts.

Parameters:

    preds (List[torch.Tensor]): Raw predictions from the model. Required.
    img (torch.Tensor): Input image tensor that was fed to the model. Required.
    orig_imgs (List[np.ndarray]): Original images before preprocessing. Required.

Returns:

    (List[Results]): Processed results with prompts applied.
Source code in ultralytics/models/fastsam/predict.py
def postprocess(self, preds, img, orig_imgs):
    """
    Apply postprocessing to FastSAM predictions and handle prompts.

    Args:
        preds (List[torch.Tensor]): Raw predictions from the model.
        img (torch.Tensor): Input image tensor that was fed to the model.
        orig_imgs (List[numpy.ndarray]): Original images before preprocessing.

    Returns:
        (List[Results]): Processed results with prompts applied.
    """
    bboxes = self.prompts.pop("bboxes", None)
    points = self.prompts.pop("points", None)
    labels = self.prompts.pop("labels", None)
    texts = self.prompts.pop("texts", None)
    results = super().postprocess(preds, img, orig_imgs)
    for result in results:
        full_box = torch.tensor(
            [0, 0, result.orig_shape[1], result.orig_shape[0]], device=preds[0].device, dtype=torch.float32
        )
        boxes = adjust_bboxes_to_image_border(result.boxes.xyxy, result.orig_shape)
        idx = torch.nonzero(box_iou(full_box[None], boxes) > 0.9).flatten()
        if idx.numel() != 0:
            result.boxes.xyxy[idx] = full_box

    return self.prompt(results, bboxes=bboxes, points=points, labels=labels, texts=texts)
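A small sketch of the full-box snap step above, in isolation (the box values are illustrative): detections whose IoU with the whole image exceeds 0.9 are replaced by the exact image extent.

import torch
from ultralytics.utils.metrics import box_iou

orig_h, orig_w = 480, 640  # illustrative image size
full_box = torch.tensor([[0.0, 0.0, orig_w, orig_h]])
boxes = torch.tensor(
    [[3.0, 2.0, 637.0, 479.0],      # nearly spans the image -> IoU ~0.98, snapped
     [100.0, 120.0, 300.0, 360.0]]  # ordinary detection -> kept as-is
)

keep = (box_iou(full_box, boxes) > 0.9).squeeze(0)  # boolean mask over detections
boxes[keep] = full_box  # snap near-full boxes to the exact image extent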

prompt

prompt(results, bboxes=None, points=None, labels=None, texts=None)

Perform image segmentation inference based on cues like bounding boxes, points, and text prompts.

Parameters:

    results (Results | List[Results]): Original inference results from FastSAM models without any prompts. Required.
    bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format. Default: None.
    points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixels. Default: None.
    labels (np.ndarray | List, optional): Labels for point prompts, shape (N,); 1 = foreground, 0 = background. Default: None.
    texts (str | List[str], optional): Textual prompts, a list containing string objects. Default: None.

Returns:

    (List[Results]): Output results filtered and determined by the provided prompts.
Source code in ultralytics/models/fastsam/predict.py
def prompt(self, results, bboxes=None, points=None, labels=None, texts=None):
    """
    Perform image segmentation inference based on cues like bounding boxes, points, and text prompts.

    Args:
        results (Results | List[Results]): Original inference results from FastSAM models without any prompts.
        bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format.
        points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixels.
        labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 1 = foreground, 0 = background.
        texts (str | List[str], optional): Textual prompts, a list containing string objects.

    Returns:
        (List[Results]): Output results filtered and determined by the provided prompts.
    """
    if bboxes is None and points is None and texts is None:
        return results
    prompt_results = []
    if not isinstance(results, list):
        results = [results]
    for result in results:
        if len(result) == 0:
            prompt_results.append(result)
            continue
        masks = result.masks.data
        if masks.shape[1:] != result.orig_shape:
            masks = scale_masks(masks[None], result.orig_shape)[0]
        # bboxes prompt
        idx = torch.zeros(len(result), dtype=torch.bool, device=self.device)
        if bboxes is not None:
            bboxes = torch.as_tensor(bboxes, dtype=torch.int32, device=self.device)
            bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes
            bbox_areas = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0])
            mask_areas = torch.stack([masks[:, b[1] : b[3], b[0] : b[2]].sum(dim=(1, 2)) for b in bboxes])
            full_mask_areas = torch.sum(masks, dim=(1, 2))

            union = bbox_areas[:, None] + full_mask_areas - mask_areas
            idx[torch.argmax(mask_areas / union, dim=1)] = True
        if points is not None:
            points = torch.as_tensor(points, dtype=torch.int32, device=self.device)
            points = points[None] if points.ndim == 1 else points
            if labels is None:
                labels = torch.ones(points.shape[0])
            labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)
            assert len(labels) == len(points), (
                f"Expected `labels` to have the same length as `points`, but got {len(labels)} and {len(points)}"
            )
            point_idx = (
                torch.ones(len(result), dtype=torch.bool, device=self.device)
                if labels.sum() == 0  # all negative points
                else torch.zeros(len(result), dtype=torch.bool, device=self.device)
            )
            for point, label in zip(points, labels):
                point_idx[torch.nonzero(masks[:, point[1], point[0]], as_tuple=True)[0]] = bool(label)
            idx |= point_idx
        if texts is not None:
            if isinstance(texts, str):
                texts = [texts]
            crop_ims, filter_idx = [], []
            for i, b in enumerate(result.boxes.xyxy.tolist()):
                x1, y1, x2, y2 = (int(x) for x in b)
                if masks[i].sum() <= 100:
                    filter_idx.append(i)
                    continue
                crop_ims.append(Image.fromarray(result.orig_img[y1:y2, x1:x2, ::-1]))
            similarity = self._clip_inference(crop_ims, texts)
            text_idx = torch.argmax(similarity, dim=-1)  # (M, )
            if len(filter_idx):
                text_idx += (torch.tensor(filter_idx, device=self.device)[None] <= int(text_idx)).sum(0)
            idx[text_idx] = True

        prompt_results.append(result[idx])

    return prompt_results
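Already-computed results can be re-filtered after the fact; a brief sketch (coordinates and text are illustrative):

# `predictor` and unprompted `results` are assumed from a previous run
point_results = predictor.prompt(results, points=[[200, 200]], labels=[1])

# Select the masks best matching a text prompt via CLIP similarity
text_results = predictor.prompt(results, texts="a photo of a dog")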

set_prompts

set_prompts(prompts)

Set prompts to be used during inference.

Source code in ultralytics/models/fastsam/predict.py
def set_prompts(self, prompts):
    """Set prompts to be used during inference."""
    self.prompts = prompts
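A sketch of the set-then-predict pattern (prompt values are illustrative): prompts stored here are popped and consumed by postprocess on the next inference call.

predictor.set_prompts({"bboxes": [[50, 50, 200, 200]], "texts": "a photo of a dog"})
results = predictor("path/to/image.jpg")  # prompts are applied during postprocessing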




