Skip to content

Reference for ultralytics/models/fastsam/prompt.py

Note

This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/fastsam/prompt.py. If you spot a problem please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!


ultralytics.models.fastsam.prompt.FastSAMPrompt

FastSAMPrompt(source, results, device='cuda')

Fast Segment Anything Model class for image annotation and visualization.

Attributes:

- device (str): Computing device ('cuda' or 'cpu').
- results: Object detection or segmentation results.
- source: Source image or image path.
- clip: CLIP module, used by `text_prompt` and `retrieve` to score cropped masks against a text prompt.

Source code in ultralytics/models/fastsam/prompt.py
def __init__(self, source, results, device="cuda") -> None:
    """
    Store the image source, its FastSAM results and the target device, then attach the CLIP module.

    Args:
        source (str | Path | object): Image path or in-memory image such as a PIL Image. Directories are rejected.
        results (list): FastSAM results; the prompt methods read and update ``results[0]``.
        device (str, optional): Device passed to CLIP when a text prompt is run. Defaults to "cuda".

    Raises:
        ValueError: If ``source`` is a str/Path that points to a directory.
    """
    # Only path-like sources can name a directory; other sources skip the filesystem check entirely.
    if isinstance(source, (str, Path)):
        if os.path.isdir(source):
            raise ValueError("FastSAM only accepts image paths and PIL Image sources, not directories.")
    self.device, self.results, self.source = device, results, source

    # CLIP is optional: if the import fails, let checks.check_requirements provide the Ultralytics CLIP
    # package and then retry the import.
    try:
        import clip
    except ImportError:
        checks.check_requirements("git+https://github.com/ultralytics/CLIP.git")
        import clip
    self.clip = clip

box_prompt

box_prompt(bbox)

Computes the IoU between every mask and the given bounding box, and keeps only the best-matching mask.

Source code in ultralytics/models/fastsam/prompt.py
def box_prompt(self, bbox):
    """
    Keep only the mask that best matches a bounding box, measured by IoU.

    The box is rescaled from original-image size to mask resolution when the two differ, clamped to the mask
    bounds, and compared against every mask. ``self.results[0].masks.data`` is then replaced with a single-mask
    CPU tensor holding the highest-IoU mask.

    Args:
        bbox (list | tuple): Box as [x1, y1, x2, y2] in original-image pixel coordinates. It is not modified.

    Returns:
        (list): ``self.results``, updated in place when masks are present.

    Raises:
        ValueError: If the box has zero or negative width/height after scaling and clamping to the mask.
    """
    if self.results[0].masks is not None:
        masks = self.results[0].masks.data
        target_height, target_width = self.results[0].orig_shape
        h, w = masks.shape[1], masks.shape[2]
        # Work on local copies so the caller's bbox list is never mutated.
        x1, y1, x2, y2 = bbox[0], bbox[1], bbox[2], bbox[3]
        if h != target_height or w != target_width:
            # Map coordinates from original-image pixels to mask pixels.
            x1 = int(x1 * w / target_width)
            y1 = int(y1 * h / target_height)
            x2 = int(x2 * w / target_width)
            y2 = int(y2 * h / target_height)
        x1 = max(round(x1), 0)
        y1 = max(round(y1), 0)
        x2 = min(round(x2), w)
        y2 = min(round(y2), h)
        # An explicit check (not assert, which -O strips): a degenerate box would give a zero or negative
        # area and a meaningless IoU.
        if x2 <= x1 or y2 <= y1:
            raise ValueError(f"Bounding box width and height should not be zero (got {bbox})")

        bbox_area = (y2 - y1) * (x2 - x1)
        masks_area = torch.sum(masks[:, y1:y2, x1:x2], dim=(1, 2))  # mask pixels inside the box
        orig_masks_area = torch.sum(masks, dim=(1, 2))

        union = bbox_area + orig_masks_area - masks_area
        iou = masks_area / union
        best = int(torch.argmax(iou))

        # New CPU tensor of shape (1, h, w), detached from the original mask storage.
        self.results[0].masks.data = masks[best : best + 1].detach().cpu().clone()
    return self.results

everything_prompt

everything_prompt()

Returns the processed results from the previous methods in the class.

Source code in ultralytics/models/fastsam/prompt.py
def everything_prompt(self):
    """
    Return the stored results object unchanged.

    Returns:
        The ``results`` held by this instance, including any mask replacements made by earlier prompt calls.
    """
    current = self.results
    return current

fast_show_mask staticmethod

fast_show_mask(annotation, ax, random_color=False, bbox=None, points=None, pointlabel=None, retinamask=True, target_height=960, target_width=960)

Quickly shows the mask annotations on the given matplotlib axis.

Parameters:

- annotation (array-like): Mask annotation. Required.
- ax (Axes): Matplotlib axis. Required.
- random_color (bool): Whether to use random color for masks. Default: False.
- bbox (list): Bounding box coordinates [x1, y1, x2, y2]. Default: None.
- points (list): Points to be plotted. Default: None.
- pointlabel (list): Labels for the points. Default: None.
- retinamask (bool): Whether to use retina mask. Default: True.
- target_height (int): Target height for resizing. Default: 960.
- target_width (int): Target width for resizing. Default: 960.
Source code in ultralytics/models/fastsam/prompt.py
@staticmethod
def fast_show_mask(
    annotation,
    ax,
    random_color=False,
    bbox=None,
    points=None,
    pointlabel=None,
    retinamask=True,
    target_height=960,
    target_width=960,
):
    """
    Quickly shows the mask annotations on the given matplotlib axis.

    Masks are sorted by ascending area, so at each pixel the smallest mask covering it sets the overlay colour
    (alpha 0.6). The overlay, box and points are all drawn onto ``ax``.

    Args:
        annotation (array-like): Mask annotation of shape (n, h, w); non-zero values mark mask pixels.
        ax (matplotlib.axes.Axes): Matplotlib axis to draw on.
        random_color (bool, optional): Whether to use random color for masks. Defaults to False.
        bbox (list, optional): Bounding box coordinates [x1, y1, x2, y2]. Defaults to None.
        points (list, optional): Points [x, y] to be plotted. Defaults to None.
        pointlabel (list, optional): Labels for the points; 1 is drawn yellow, 0 magenta. If None while
            ``points`` is given, every point is treated as label 1. Defaults to None.
        retinamask (bool, optional): Whether to use retina mask; if False the overlay is resized to
            (target_width, target_height). Defaults to True.
        target_height (int, optional): Target height for resizing. Defaults to 960.
        target_width (int, optional): Target width for resizing. Defaults to 960.
    """
    n, h, w = annotation.shape  # batch, height, width

    # Sort by area so the first non-zero mask along axis 0 is the smallest one covering each pixel.
    areas = np.sum(annotation, axis=(1, 2))
    annotation = annotation[np.argsort(areas)]
    index = (annotation != 0).argmax(axis=0)  # (h, w): which mask is shown at each pixel

    if random_color:
        color = np.random.random((n, 1, 1, 3))
    else:
        color = np.ones((n, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 1.0])
    transparency = np.ones((n, 1, 1, 1)) * 0.6
    visual = np.concatenate([color, transparency], axis=-1)
    mask_image = np.expand_dims(annotation, -1) * visual  # (n, h, w, 4) RGBA per mask

    # Gather each pixel's RGBA from its selected mask -> (h, w, 4); uncovered pixels stay all-zero.
    rows, cols = np.indices((h, w))
    show = mask_image[index, rows, cols]

    if bbox is not None:
        import matplotlib.pyplot as plt  # only needed to build the Rectangle patch

        x1, y1, x2, y2 = bbox
        ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1))
    # Draw points on ``ax`` (not the implicit current pyplot axes) so callers can target any subplot.
    if points is not None:
        labels = [1] * len(points) if pointlabel is None else pointlabel
        positive = [point for i, point in enumerate(points) if labels[i] == 1]
        negative = [point for i, point in enumerate(points) if labels[i] == 0]
        ax.scatter([p[0] for p in positive], [p[1] for p in positive], s=20, c="y")
        ax.scatter([p[0] for p in negative], [p[1] for p in negative], s=20, c="m")

    if not retinamask:
        show = cv2.resize(show, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
    ax.imshow(show)

plot

plot(annotations, output, bbox=None, points=None, point_label=None, mask_random_color=True, better_quality=True, retina=False, with_contours=True)

Plots annotations, bounding boxes, and points on images and saves the output.

Parameters:

- annotations (list): Annotations to be plotted. Required.
- output (str or Path): Output directory for saving the plots. Required.
- bbox (list): Bounding box coordinates [x1, y1, x2, y2]. Default: None.
- points (list): Points to be plotted. Default: None.
- point_label (list): Labels for the points. Default: None.
- mask_random_color (bool): Whether to use random color for masks. Default: True.
- better_quality (bool): Whether to apply morphological transformations for better mask quality. Default: True.
- retina (bool): Whether to use retina mask. Default: False.
- with_contours (bool): Whether to plot contours. Default: True.
Source code in ultralytics/models/fastsam/prompt.py
def plot(
    self,
    annotations,
    output,
    bbox=None,
    points=None,
    point_label=None,
    mask_random_color=True,
    better_quality=True,
    retina=False,
    with_contours=True,
):
    """
    Plots annotations, bounding boxes, and points on images and saves the output.

    One figure is drawn per result: the original image, the mask overlay and, optionally, mask contours. It is
    saved to ``output`` under the basename of the result's ``path``.

    Args:
        annotations (list): Annotations to be plotted.
        output (str or Path): Output directory for saving the plots.
        bbox (list, optional): Bounding box coordinates [x1, y1, x2, y2]. Defaults to None.
        points (list, optional): Points to be plotted. Defaults to None.
        point_label (list, optional): Labels for the points. Defaults to None.
        mask_random_color (bool, optional): Whether to use random color for masks. Defaults to True.
        better_quality (bool, optional): Whether to apply morphological transformations for better mask quality.
            Defaults to True.
        retina (bool, optional): Whether to use retina mask. Defaults to False.
        with_contours (bool, optional): Whether to plot contours. Defaults to True.
    """
    import matplotlib.pyplot as plt

    pbar = TQDM(annotations, total=len(annotations))
    for ann in pbar:
        result_name = os.path.basename(ann.path)
        image = ann.orig_img[..., ::-1]  # BGR to RGB
        original_h, original_w = ann.orig_shape
        # For macOS only
        # plt.switch_backend('TkAgg')
        plt.figure(figsize=(original_w / 100, original_h / 100))
        # Add subplot with no margin.
        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
        plt.margins(0, 0)
        plt.gca().xaxis.set_major_locator(plt.NullLocator())
        plt.gca().yaxis.set_major_locator(plt.NullLocator())
        plt.imshow(image)

        if ann.masks is not None:
            masks = ann.masks.data
            # The overlay and contour code use NumPy/OpenCV operations (e.g. ``.astype``), so always convert to
            # a NumPy array, whether or not better_quality is set. np.array copies, so the in-place smoothing
            # below never writes back into ``ann.masks.data``.
            if isinstance(masks, torch.Tensor):
                masks = masks.cpu()
            masks = np.array(masks)
            if better_quality:
                for i, mask in enumerate(masks):
                    mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
                    masks[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))

            self.fast_show_mask(
                masks,
                plt.gca(),
                random_color=mask_random_color,
                bbox=bbox,
                points=points,
                pointlabel=point_label,
                retinamask=retina,
                target_height=original_h,
                target_width=original_w,
            )

            if with_contours:
                contour_all = []
                temp = np.zeros((original_h, original_w, 1))
                for mask in masks:
                    mask = mask.astype(np.uint8)
                    if not retina:
                        mask = cv2.resize(mask, (original_w, original_h), interpolation=cv2.INTER_NEAREST)
                    contours, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
                    contour_all.extend(contours)
                cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
                color = np.array([0 / 255, 0 / 255, 1.0, 0.8])
                contour_mask = temp / 255 * color.reshape(1, 1, -1)
                plt.imshow(contour_mask)

        # Save the figure
        save_path = Path(output) / result_name
        save_path.parent.mkdir(exist_ok=True, parents=True)
        plt.axis("off")
        plt.savefig(save_path, bbox_inches="tight", pad_inches=0, transparent=True)
        plt.close()
        pbar.set_description(f"Saving {result_name} to {save_path}")

point_prompt

point_prompt(points, pointlabel)

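Merges the masks hit by the given points into one mask. Masks containing a positive point (label 1) are added and masks containing a negative point (label 0) are subtracted. Returns the results with this merged mask.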

Source code in ultralytics/models/fastsam/prompt.py
def point_prompt(self, points, pointlabel):  # numpy
    """
    Build a single mask from point prompts.

    Each mask that contains a point is added (label 1) or subtracted (label 0). Pixels whose accumulated count
    is >= 1 form the result, which replaces ``self.results[0].masks.data`` as a (1, h, w) boolean tensor.

    Args:
        points (list): Points as [x, y] in original-image pixel coordinates. Points that fall outside the mask
            after scaling are ignored.
        pointlabel (list): One label per point: 1 for foreground, 0 for background.

    Returns:
        (list): ``self.results``, updated in place when masks are present.
    """
    if self.results[0].masks is not None:
        masks = self._format_results(self.results[0], 0)
        if not masks:  # nothing to combine; avoid IndexError on masks[0]
            return self.results
        target_height, target_width = self.results[0].orig_shape
        h = masks[0]["segmentation"].shape[0]
        w = masks[0]["segmentation"].shape[1]
        if h != target_height or w != target_width:
            points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points]
        # Drop out-of-range points: negative indices would silently wrap to the opposite edge, and
        # coordinates >= size would raise IndexError.
        prompts = [
            (point, pointlabel[i]) for i, point in enumerate(points) if 0 <= point[0] < w and 0 <= point[1] < h
        ]
        onemask = np.zeros((h, w))
        for annotation in masks:
            mask = annotation["segmentation"] if isinstance(annotation, dict) else annotation
            for point, label in prompts:
                if mask[point[1], point[0]] != 1:
                    continue
                if label == 1:
                    onemask += mask
                elif label == 0:
                    onemask -= mask
        onemask = onemask >= 1
        self.results[0].masks.data = torch.tensor(np.array([onemask]))
    return self.results

retrieve

retrieve(model, preprocess, elements, search_text: str, device) -> Tensor

Encodes the images and the text with a model, computes their similarity, and returns softmax scores over the images.

Source code in ultralytics/models/fastsam/prompt.py
@torch.no_grad()
def retrieve(self, model, preprocess, elements, search_text: str, device) -> Tensor:
    """
    Score every image against a text query using the model's image and text embeddings.

    Args:
        model: Object exposing ``encode_image`` and ``encode_text``.
        preprocess (callable): Transform applied to each element; its output must support ``.to(device)``.
        elements (list): Images to score.
        search_text (str): Query text, tokenized with ``self.clip.tokenize``.
        device (str | torch.device): Device the preprocessed images and tokens are moved to.

    Returns:
        (torch.Tensor): One probability per element: softmax over 100 x the normalized dot products.
    """
    batch = [preprocess(element).to(device) for element in elements]
    tokens = self.clip.tokenize([search_text]).to(device)
    image_emb = model.encode_image(torch.stack(batch))
    text_emb = model.encode_text(tokens)
    # L2-normalize both embeddings in place so the product below is a cosine similarity.
    for emb in (image_emb, text_emb):
        emb /= emb.norm(dim=-1, keepdim=True)
    logits = 100.0 * image_emb @ text_emb.T
    return torch.softmax(logits[:, 0], dim=0)

text_prompt

text_prompt(text, clip_download_root=None)

Processes a text prompt, applies it to existing results and returns the updated results.

Source code in ultralytics/models/fastsam/prompt.py
def text_prompt(self, text, clip_download_root=None):
    """
    Keep only the mask whose image crop best matches a text prompt according to CLIP.

    Args:
        text (str): Text description of the target object.
        clip_download_root (str, optional): Passed to ``clip.load`` as ``download_root``. Defaults to None.

    Returns:
        (list): ``self.results``; when masks are present, ``masks.data`` is replaced by the single best-matching
            mask as a (1, h, w) tensor.
    """
    if self.results[0].masks is not None:
        format_results = self._format_results(self.results[0], 0)
        cropped_images, filter_id, annotations = self._crop_image(format_results)
        clip_model, preprocess = self.clip.load("ViT-B/32", download_root=clip_download_root, device=self.device)
        scores = self.retrieve(clip_model, preprocess, cropped_images, text, device=self.device)
        # Assumption: cropped_images holds one crop per annotation, in order, except those whose indices are in
        # filter_id. Map the winning crop back to its annotation through the list of kept indices. Shifting by
        # "number of filtered ids <= idx" once is wrong when several filtered ids precede the result.
        filtered = set(filter_id)
        kept = [i for i in range(len(annotations)) if i not in filtered]
        max_idx = kept[int(torch.argmax(scores))]
        self.results[0].masks.data = torch.tensor(np.array([annotations[max_idx]["segmentation"]]))
    return self.results





Created 2023-11-12, Updated 2024-07-21
Authors: glenn-jocher (6), Burhan-Q (1)