
Reference for ultralytics/models/fastsam/predict.py

Note

This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/fastsam/predict.py. If you spot a problem, please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!


ultralytics.models.fastsam.predict.FastSAMPredictor

FastSAMPredictor(cfg=DEFAULT_CFG, overrides=None, _callbacks=None)

Bases: SegmentationPredictor


```mermaid
flowchart TD
    ultralytics.engine.predictor.BasePredictor[BasePredictor] --> ultralytics.models.yolo.detect.predict.DetectionPredictor[DetectionPredictor]
    ultralytics.models.yolo.detect.predict.DetectionPredictor --> ultralytics.models.yolo.segment.predict.SegmentationPredictor[SegmentationPredictor]
    ultralytics.models.yolo.segment.predict.SegmentationPredictor --> ultralytics.models.fastsam.predict.FastSAMPredictor[FastSAMPredictor]
```

FastSAMPredictor is specialized for Fast SAM (Segment Anything Model) segmentation prediction tasks.

This class extends the SegmentationPredictor, customizing the prediction pipeline specifically for Fast SAM. It adjusts post-processing steps to incorporate mask prediction and non-maximum suppression while optimizing for single-class segmentation.
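For orientation, a minimal usage sketch (the `FastSAM-s.pt` checkpoint and `bus.jpg` image are assumptions; in recent Ultralytics versions the `FastSAM` model class forwards prompt keyword arguments to this predictor):

```python
from ultralytics import FastSAM

# FastSAM builds a FastSAMPredictor internally for inference.
model = FastSAM("FastSAM-s.pt")

# Everything mode: segment all objects, no prompts.
results = model("bus.jpg", imgsz=640, conf=0.4, iou=0.9)

# Prompted modes: bboxes/points/labels/texts are routed into
# FastSAMPredictor.set_prompts() and consumed during postprocess().
results = model("bus.jpg", bboxes=[439, 437, 524, 709])
results = model("bus.jpg", points=[[200, 200]], labels=[1])
results = model("bus.jpg", texts="a photo of a dog")
```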

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `prompts` | `dict` | Dictionary containing prompt information for segmentation (bboxes, points, labels, texts). |
| `device` | `device` | Device on which the model and tensors are processed. |
| `clip_model` | `Any` | CLIP model for text-based prompting, loaded on demand. |
| `clip_preprocess` | `Any` | CLIP preprocessing function for images, loaded on demand. |

Methods:

| Name | Description |
| --- | --- |
| `postprocess` | Apply postprocessing to FastSAM predictions and handle prompts. |
| `prompt` | Perform image segmentation inference based on various prompt types. |
| `set_prompts` | Set prompts to be used during inference. |

This initializes a predictor specialized for Fast SAM (Segment Anything Model) segmentation tasks. The predictor extends SegmentationPredictor with custom post-processing for mask prediction and non-maximum suppression optimized for single-class segmentation.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `cfg` | `dict` | Configuration for the predictor. | `DEFAULT_CFG` |
| `overrides` | `dict` | Configuration overrides. | `None` |
| `_callbacks` | `list` | List of callback functions. | `None` |
Source code in ultralytics/models/fastsam/predict.py
```python
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
    """Initialize the FastSAMPredictor with configuration and callbacks.

    This initializes a predictor specialized for Fast SAM (Segment Anything Model) segmentation tasks. The predictor
    extends SegmentationPredictor with custom post-processing for mask prediction and non-maximum suppression
    optimized for single-class segmentation.

    Args:
        cfg (dict): Configuration for the predictor.
        overrides (dict, optional): Configuration overrides.
        _callbacks (list, optional): List of callback functions.
    """
    super().__init__(cfg, overrides, _callbacks)
    self.prompts = {}
```
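The predictor can also be constructed directly, like any other Ultralytics predictor. A sketch, where the override keys are assumptions based on the standard predictor configuration:

```python
from ultralytics.models.fastsam import FastSAMPredictor

# Standard predictor overrides; "model" points at a FastSAM checkpoint (assumed local).
overrides = dict(task="segment", mode="predict", model="FastSAM-s.pt", imgsz=640)
predictor = FastSAMPredictor(overrides=overrides)

# Prompts set here are popped and applied during postprocess().
predictor.set_prompts({"texts": "a photo of a dog"})
results = predictor("bus.jpg")
```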

postprocess

postprocess(preds, img, orig_imgs)

Apply postprocessing to FastSAM predictions and handle prompts.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `preds` | `list[Tensor]` | Raw predictions from the model. | required |
| `img` | `Tensor` | Input image tensor that was fed to the model. | required |
| `orig_imgs` | `list[ndarray]` | Original images before preprocessing. | required |

Returns:

| Type | Description |
| --- | --- |
| `list[Results]` | Processed results with prompts applied. |

Source code in ultralytics/models/fastsam/predict.py
```python
def postprocess(self, preds, img, orig_imgs):
    """Apply postprocessing to FastSAM predictions and handle prompts.

    Args:
        preds (list[torch.Tensor]): Raw predictions from the model.
        img (torch.Tensor): Input image tensor that was fed to the model.
        orig_imgs (list[np.ndarray]): Original images before preprocessing.

    Returns:
        (list[Results]): Processed results with prompts applied.
    """
    bboxes = self.prompts.pop("bboxes", None)
    points = self.prompts.pop("points", None)
    labels = self.prompts.pop("labels", None)
    texts = self.prompts.pop("texts", None)
    results = super().postprocess(preds, img, orig_imgs)
    for result in results:
        full_box = torch.tensor(
            [0, 0, result.orig_shape[1], result.orig_shape[0]], device=preds[0].device, dtype=torch.float32
        )
        boxes = adjust_bboxes_to_image_border(result.boxes.xyxy, result.orig_shape)
        idx = torch.nonzero(box_iou(full_box[None], boxes) > 0.9).flatten()
        if idx.numel() != 0:
            result.boxes.xyxy[idx] = full_box

    return self.prompt(results, bboxes=bboxes, points=points, labels=labels, texts=texts)
```
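The loop above snaps near-full-frame boxes to the exact image extent. A standalone sketch of that step with hypothetical toy boxes:

```python
import torch

from ultralytics.utils.metrics import box_iou

orig_h, orig_w = 480, 640
full_box = torch.tensor([[0.0, 0.0, orig_w, orig_h]])  # whole image, xyxy

boxes = torch.tensor(
    [
        [2.0, 1.0, 638.0, 479.0],      # nearly full-frame, IoU ~0.99 -> snapped
        [100.0, 100.0, 200.0, 200.0],  # ordinary box, IoU ~0.03 -> untouched
    ]
)

iou = box_iou(full_box, boxes)[0]         # IoU of each box with the full image, shape (N,)
idx = torch.nonzero(iou > 0.9).flatten()  # indices of near-full-frame boxes
boxes[idx] = full_box                     # boxes[0] becomes exactly [0, 0, 640, 480]
```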

prompt

prompt(results, bboxes=None, points=None, labels=None, texts=None)

Perform image segmentation inference based on cues like bounding boxes, points, and text prompts.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `results` | `Results \| list[Results]` | Original inference results from FastSAM models without any prompts. | required |
| `bboxes` | `ndarray \| list` | Bounding boxes with shape (N, 4), in XYXY format. | `None` |
| `points` | `ndarray \| list` | Points indicating object locations with shape (N, 2), in pixels. | `None` |
| `labels` | `ndarray \| list` | Labels for point prompts, shape (N, ). 1 = foreground, 0 = background. | `None` |
| `texts` | `str \| list[str]` | Textual prompts, a list containing string objects. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `list[Results]` | Output results filtered and determined by the provided prompts. |

Source code in ultralytics/models/fastsam/predict.py
```python
def prompt(self, results, bboxes=None, points=None, labels=None, texts=None):
    """Perform image segmentation inference based on cues like bounding boxes, points, and text prompts.

    Args:
        results (Results | list[Results]): Original inference results from FastSAM models without any prompts.
        bboxes (np.ndarray | list, optional): Bounding boxes with shape (N, 4), in XYXY format.
        points (np.ndarray | list, optional): Points indicating object locations with shape (N, 2), in pixels.
        labels (np.ndarray | list, optional): Labels for point prompts, shape (N, ). 1 = foreground, 0 = background.
        texts (str | list[str], optional): Textual prompts, a list containing string objects.

    Returns:
        (list[Results]): Output results filtered and determined by the provided prompts.
    """
    if bboxes is None and points is None and texts is None:
        return results
    prompt_results = []
    if not isinstance(results, list):
        results = [results]
    for result in results:
        if len(result) == 0:
            prompt_results.append(result)
            continue
        masks = result.masks.data
        if masks.shape[1:] != result.orig_shape:
            masks = (scale_masks(masks[None].float(), result.orig_shape)[0] > 0.5).byte()
        # bboxes prompt
        idx = torch.zeros(len(result), dtype=torch.bool, device=self.device)
        if bboxes is not None:
            bboxes = torch.as_tensor(bboxes, dtype=torch.int32, device=self.device)
            bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes
            bbox_areas = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0])
            mask_areas = torch.stack([masks[:, b[1] : b[3], b[0] : b[2]].sum(dim=(1, 2)) for b in bboxes])
            full_mask_areas = torch.sum(masks, dim=(1, 2))

            union = bbox_areas[:, None] + full_mask_areas - mask_areas
            idx[torch.argmax(mask_areas / union, dim=1)] = True
        if points is not None:
            points = torch.as_tensor(points, dtype=torch.int32, device=self.device)
            points = points[None] if points.ndim == 1 else points
            if labels is None:
                labels = torch.ones(points.shape[0])
            labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)
            assert len(labels) == len(points), (
                f"Expected `labels` with same size as `point`, but got {len(labels)} and {len(points)}"
            )
            point_idx = (
                torch.ones(len(result), dtype=torch.bool, device=self.device)
                if labels.sum() == 0  # all negative points
                else torch.zeros(len(result), dtype=torch.bool, device=self.device)
            )
            for point, label in zip(points, labels):
                point_idx[torch.nonzero(masks[:, point[1], point[0]], as_tuple=True)[0]] = bool(label)
            idx |= point_idx
        if texts is not None:
            if isinstance(texts, str):
                texts = [texts]
            crop_ims, filter_idx = [], []
            for i, b in enumerate(result.boxes.xyxy.tolist()):
                x1, y1, x2, y2 = (int(x) for x in b)
                if (masks[i].sum() if TORCH_1_10 else masks[i].sum(0).sum()) <= 100:  # torch 1.9 bug workaround
                    filter_idx.append(i)
                    continue
                crop_ims.append(Image.fromarray(result.orig_img[y1:y2, x1:x2, ::-1]))
            similarity = self._clip_inference(crop_ims, texts)
            text_idx = torch.argmax(similarity, dim=-1)  # (M, )
            if len(filter_idx):
                text_idx += (torch.tensor(filter_idx, device=self.device)[None] <= int(text_idx)).sum(0)
            idx[text_idx] = True

        prompt_results.append(result[idx])

    return prompt_results
```
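A sketch of calling `prompt` directly on previously computed results (assuming `predictor` and `results` from an earlier unprompted run):

```python
# Keep the mask whose box best overlaps the given region (xyxy pixel coordinates).
filtered = predictor.prompt(results, bboxes=[[100, 100, 300, 300]])

# Keep masks covering the foreground point while excluding those covering the background point.
filtered = predictor.prompt(results, points=[[150, 150], [400, 50]], labels=[1, 0])

# Rank mask crops against a text prompt with CLIP and keep the best match.
filtered = predictor.prompt(results, texts="a photo of a dog")
```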

set_prompts

set_prompts(prompts)

Set prompts to be used during inference.

Source code in ultralytics/models/fastsam/predict.py
```python
def set_prompts(self, prompts):
    """Set prompts to be used during inference."""
    self.prompts = prompts
```
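A sketch of the prompt dictionary format, whose keys match those popped in `postprocess` (assumes a predictor that already has a model configured):

```python
predictor.set_prompts({"bboxes": [[50, 60, 300, 400]], "texts": ["a photo of a dog"]})
results = predictor("bus.jpg")  # prompts are consumed (popped) during postprocessing
```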




