Reference for ultralytics/models/fastsam/predict.py
Improvements
This page is sourced from https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/fastsam/predict.py. Have an improvement or example to add? Open a Pull Request — thank you! 🙏
Summary
class ultralytics.models.fastsam.predict.FastSAMPredictor
FastSAMPredictor(self, cfg = DEFAULT_CFG, overrides = None, _callbacks = None)
Bases: SegmentationPredictor
FastSAMPredictor is specialized for fast SAM (Segment Anything Model) segmentation prediction tasks.
This class extends the SegmentationPredictor, customizing the prediction pipeline specifically for fast SAM. It adjusts post-processing steps to incorporate mask prediction and non-maximum suppression while optimizing for single-class segmentation.
This initializes a predictor specialized for Fast SAM (Segment Anything Model) segmentation tasks. The predictor extends SegmentationPredictor with custom post-processing for mask prediction and non-maximum suppression optimized for single-class segmentation.
Args
| Name | Type | Description | Default |
|---|---|---|---|
cfg | dict | Configuration for the predictor. | DEFAULT_CFG |
overrides | dict, optional | Configuration overrides. | None |
_callbacks | list, optional | List of callback functions. | None |
Attributes
| Name | Type | Description |
|---|---|---|
prompts | dict | Dictionary containing prompt information for segmentation (bboxes, points, labels, texts). |
device | torch.device | Device on which model and tensors are processed. |
clip_model | Any, optional | CLIP model for text-based prompting, loaded on demand. |
clip_preprocess | Any, optional | CLIP preprocessing function for images, loaded on demand. |
Methods
| Name | Description |
|---|---|
_clip_inference | Perform CLIP inference to calculate similarity between images and text prompts. |
postprocess | Apply postprocessing to FastSAM predictions and handle prompts. |
prompt | Perform image segmentation inference based on cues like bounding boxes, points, and text prompts. |
set_prompts | Set prompts to be used during inference. |
Source code in ultralytics/models/fastsam/predict.py
class FastSAMPredictor(SegmentationPredictor):
"""FastSAMPredictor is specialized for fast SAM (Segment Anything Model) segmentation prediction tasks.
This class extends the SegmentationPredictor, customizing the prediction pipeline specifically for fast SAM. It
adjusts post-processing steps to incorporate mask prediction and non-maximum suppression while optimizing for
single-class segmentation.
Attributes:
prompts (dict): Dictionary containing prompt information for segmentation (bboxes, points, labels, texts).
device (torch.device): Device on which model and tensors are processed.
clip_model (Any, optional): CLIP model for text-based prompting, loaded on demand.
clip_preprocess (Any, optional): CLIP preprocessing function for images, loaded on demand.
Methods:
postprocess: Apply postprocessing to FastSAM predictions and handle prompts.
prompt: Perform image segmentation inference based on various prompt types.
set_prompts: Set prompts to be used during inference.
"""
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""Initialize the FastSAMPredictor with configuration and callbacks.
This initializes a predictor specialized for Fast SAM (Segment Anything Model) segmentation tasks. The predictor
extends SegmentationPredictor with custom post-processing for mask prediction and non-maximum suppression
optimized for single-class segmentation.
Args:
cfg (dict): Configuration for the predictor.
overrides (dict, optional): Configuration overrides.
_callbacks (list, optional): List of callback functions.
"""
super().__init__(cfg, overrides, _callbacks)
self.prompts = {}
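The snippet below is a minimal, hedged sketch of constructing the predictor directly with configuration overrides, mirroring the usual Ultralytics predictor pattern; the weight filename and image path are placeholders rather than values taken from this module.

```python
from ultralytics.models.fastsam import FastSAMPredictor

# Build a predictor with overrides (weight file and image path are placeholders).
overrides = dict(task="segment", mode="predict", model="FastSAM-s.pt", conf=0.4, save=False)
predictor = FastSAMPredictor(overrides=overrides)

# Prompts are stored on the instance and consumed later in postprocess().
predictor.set_prompts({"texts": "a photo of a dog"})
results = predictor("path/to/image.jpg")
```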
method ultralytics.models.fastsam.predict.FastSAMPredictor._clip_inference
def _clip_inference(self, images, texts)
Perform CLIP inference to calculate similarity between images and text prompts.
Args
| Name | Type | Description | Default |
|---|---|---|---|
images | list[PIL.Image] | List of source images, each should be PIL.Image with RGB channel order. | required |
texts | list[str] | List of prompt texts, each should be a string object. | required |
Returns
| Type | Description |
|---|---|
torch.Tensor | Similarity matrix between given images and texts with shape (M, N). |
Source code in ultralytics/models/fastsam/predict.py
def _clip_inference(self, images, texts):
"""Perform CLIP inference to calculate similarity between images and text prompts.
Args:
images (list[PIL.Image]): List of source images, each should be PIL.Image with RGB channel order.
texts (list[str]): List of prompt texts, each should be a string object.
Returns:
(torch.Tensor): Similarity matrix between given images and texts with shape (M, N).
"""
from ultralytics.nn.text_model import CLIP
if not hasattr(self, "clip"):
self.clip = CLIP("ViT-B/32", device=self.device)
images = torch.stack([self.clip.image_preprocess(image).to(self.device) for image in images])
image_features = self.clip.encode_image(images)
text_features = self.clip.encode_text(self.clip.tokenize(texts))
return text_features @ image_features.T # (M, N)
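As a hedged illustration of the returned (M, N) similarity matrix, the sketch below assumes `predictor` has already been set up (so `self.device` exists) and uses placeholder crops; `prompt()` applies the same argmax over the image axis to pick the best-matching crop for each text.

```python
import numpy as np
from PIL import Image

# Placeholder RGB crops; in practice these are object crops taken from result.orig_img.
crops = [Image.fromarray(np.zeros((64, 64, 3), dtype=np.uint8)) for _ in range(3)]
texts = ["a dog", "a cat"]

similarity = predictor._clip_inference(crops, texts)  # shape (2, 3): one row per text
best_crop_per_text = similarity.argmax(dim=-1)        # index of the best-matching crop per text
```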
method ultralytics.models.fastsam.predict.FastSAMPredictor.postprocess
def postprocess(self, preds, img, orig_imgs)
Apply postprocessing to FastSAM predictions and handle prompts.
Args
| Name | Type | Description | Default |
|---|---|---|---|
preds | list[torch.Tensor] | Raw predictions from the model. | required |
img | torch.Tensor | Input image tensor that was fed to the model. | required |
orig_imgs | list[np.ndarray] | Original images before preprocessing. | required |
Returns
| Type | Description |
|---|---|
list[Results] | Processed results with prompts applied. |
Source code in ultralytics/models/fastsam/predict.py
def postprocess(self, preds, img, orig_imgs):
"""Apply postprocessing to FastSAM predictions and handle prompts.
Args:
preds (list[torch.Tensor]): Raw predictions from the model.
img (torch.Tensor): Input image tensor that was fed to the model.
orig_imgs (list[np.ndarray]): Original images before preprocessing.
Returns:
(list[Results]): Processed results with prompts applied.
"""
bboxes = self.prompts.pop("bboxes", None)
points = self.prompts.pop("points", None)
labels = self.prompts.pop("labels", None)
texts = self.prompts.pop("texts", None)
results = super().postprocess(preds, img, orig_imgs)
for result in results:
full_box = torch.tensor(
[0, 0, result.orig_shape[1], result.orig_shape[0]], device=preds[0].device, dtype=torch.float32
)
boxes = adjust_bboxes_to_image_border(result.boxes.xyxy, result.orig_shape)
idx = torch.nonzero(box_iou(full_box[None], boxes) > 0.9).flatten()
if idx.numel() != 0:
result.boxes.xyxy[idx] = full_box
return self.prompt(results, bboxes=bboxes, points=points, labels=labels, texts=texts)
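The full-box handling above snaps any detection whose IoU with the whole image exceeds 0.9 onto the exact image rectangle before prompts are applied. The standalone sketch below reproduces just that step with made-up boxes; it is an illustration, not the exact library code path.

```python
import torch
from ultralytics.utils.metrics import box_iou

orig_h, orig_w = 480, 640
full_box = torch.tensor([[0.0, 0.0, orig_w, orig_h]])  # the whole-image rectangle

boxes = torch.tensor(
    [
        [2.0, 3.0, 637.0, 478.0],      # nearly covers the frame -> snapped to full_box
        [100.0, 120.0, 300.0, 360.0],  # ordinary object box -> left unchanged
    ]
)
snap = (box_iou(full_box, boxes) > 0.9).squeeze(0)  # boolean mask over boxes
boxes[snap] = full_box
```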
method ultralytics.models.fastsam.predict.FastSAMPredictor.prompt
def prompt(self, results, bboxes = None, points = None, labels = None, texts = None)
Perform image segmentation inference based on cues like bounding boxes, points, and text prompts.
Args
| Name | Type | Description | Default |
|---|---|---|---|
results | Results \| list[Results] | Original inference results from FastSAM models without any prompts. | required
bboxes | np.ndarray \| list, optional | Bounding boxes with shape (N, 4), in XYXY format. | None
points | np.ndarray \| list, optional | Points indicating object locations with shape (N, 2), in pixels. | None
labels | np.ndarray \| list, optional | Labels for point prompts, shape (N,). 1 = foreground, 0 = background. | None
texts | str \| list[str], optional | Textual prompts, a list containing string objects. | None
Returns
| Type | Description |
|---|---|
list[Results] | Output results filtered and determined by the provided prompts. |
Source code in ultralytics/models/fastsam/predict.py
def prompt(self, results, bboxes=None, points=None, labels=None, texts=None):
"""Perform image segmentation inference based on cues like bounding boxes, points, and text prompts.
Args:
results (Results | list[Results]): Original inference results from FastSAM models without any prompts.
bboxes (np.ndarray | list, optional): Bounding boxes with shape (N, 4), in XYXY format.
points (np.ndarray | list, optional): Points indicating object locations with shape (N, 2), in pixels.
labels (np.ndarray | list, optional): Labels for point prompts, shape (N, ). 1 = foreground, 0 = background.
texts (str | list[str], optional): Textual prompts, a list containing string objects.
Returns:
(list[Results]): Output results filtered and determined by the provided prompts.
"""
if bboxes is None and points is None and texts is None:
return results
prompt_results = []
if not isinstance(results, list):
results = [results]
for result in results:
if len(result) == 0:
prompt_results.append(result)
continue
masks = result.masks.data
if masks.shape[1:] != result.orig_shape:
masks = (scale_masks(masks[None].float(), result.orig_shape)[0] > 0.5).byte()
# bboxes prompt
idx = torch.zeros(len(result), dtype=torch.bool, device=self.device)
if bboxes is not None:
bboxes = torch.as_tensor(bboxes, dtype=torch.int32, device=self.device)
bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes
bbox_areas = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0])
mask_areas = torch.stack([masks[:, b[1] : b[3], b[0] : b[2]].sum(dim=(1, 2)) for b in bboxes])
full_mask_areas = torch.sum(masks, dim=(1, 2))
union = bbox_areas[:, None] + full_mask_areas - mask_areas
idx[torch.argmax(mask_areas / union, dim=1)] = True
if points is not None:
points = torch.as_tensor(points, dtype=torch.int32, device=self.device)
points = points[None] if points.ndim == 1 else points
if labels is None:
labels = torch.ones(points.shape[0])
labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)
assert len(labels) == len(points), (
f"Expected `labels` with same size as `point`, but got {len(labels)} and {len(points)}"
)
point_idx = (
torch.ones(len(result), dtype=torch.bool, device=self.device)
if labels.sum() == 0 # all negative points
else torch.zeros(len(result), dtype=torch.bool, device=self.device)
)
for point, label in zip(points, labels):
point_idx[torch.nonzero(masks[:, point[1], point[0]], as_tuple=True)[0]] = bool(label)
idx |= point_idx
if texts is not None:
if isinstance(texts, str):
texts = [texts]
crop_ims, filter_idx = [], []
for i, b in enumerate(result.boxes.xyxy.tolist()):
x1, y1, x2, y2 = (int(x) for x in b)
if (masks[i].sum() if TORCH_1_10 else masks[i].sum(0).sum()) <= 100: # torch 1.9 bug workaround
filter_idx.append(i)
continue
crop_ims.append(Image.fromarray(result.orig_img[y1:y2, x1:x2, ::-1]))
similarity = self._clip_inference(crop_ims, texts)
text_idx = torch.argmax(similarity, dim=-1) # (M, )
if len(filter_idx):
text_idx += (torch.tensor(filter_idx, device=self.device)[None] <= int(text_idx)).sum(0)
idx[text_idx] = True
prompt_results.append(result[idx])
return prompt_results
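In practice these prompts are usually passed through the `FastSAM` model wrapper rather than by calling `prompt()` directly. The sketch below follows that pattern; the weight filename and image path are placeholders.

```python
from ultralytics import FastSAM

model = FastSAM("FastSAM-s.pt")  # placeholder weight file

# Box prompt (XYXY pixels), point prompt with a foreground label, and text prompt.
results = model("path/to/image.jpg", bboxes=[439, 437, 524, 709])
results = model("path/to/image.jpg", points=[[200, 200]], labels=[1])
results = model("path/to/image.jpg", texts="a photo of a dog")
```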
method ultralytics.models.fastsam.predict.FastSAMPredictor.set_prompts
def set_prompts(self, prompts)
Set prompts to be used during inference.
Args
| Name | Type | Description | Default |
|---|---|---|---|
prompts | dict | Prompts to use during inference (bboxes, points, labels, texts). | required
Source code in ultralytics/models/fastsam/predict.py
def set_prompts(self, prompts):
"""Set prompts to be used during inference."""
self.prompts = prompts
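A short hedged sketch of how the stored dict is typically shaped: `postprocess()` pops each key individually, so any key left out simply falls back to None. The image path is a placeholder.

```python
predictor.set_prompts(
    {
        "bboxes": [[50, 60, 200, 220]],  # XYXY pixel coordinates
        "points": [[120, 140]],          # (x, y) pixel coordinates
        "labels": [1],                   # 1 = foreground, 0 = background
        "texts": ["a photo of a dog"],
    }
)
results = predictor("path/to/image.jpg")  # prompts are consumed during postprocessing
```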