Skip to content

Reference for ultralytics/models/sam/amg.py

Note

This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/amg.py. If you spot a problem please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!


ultralytics.models.sam.amg.is_box_near_crop_edge

is_box_near_crop_edge(
    boxes: Tensor, crop_box: list[int], orig_box: list[int], atol: float = 20.0
) -> torch.Tensor

Determine if bounding boxes are near the edge of a cropped image region using a specified tolerance.

Parameters:

Name Type Description Default
boxes Tensor

Bounding boxes in XYXY format.

required
crop_box list[int]

Crop box coordinates in [x0, y0, x1, y1] format.

required
orig_box list[int]

Original image box coordinates in [x0, y0, x1, y1] format.

required
atol float

Absolute tolerance for edge proximity detection.

20.0

Returns:

Type Description
Tensor

Boolean tensor indicating which boxes are near crop edges.

Examples:

>>> boxes = torch.tensor([[10, 10, 50, 50], [100, 100, 150, 150]])
>>> crop_box = [0, 0, 200, 200]
>>> orig_box = [0, 0, 300, 300]
>>> near_edge = is_box_near_crop_edge(boxes, crop_box, orig_box, atol=20.0)
Source code in ultralytics/models/sam/amg.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
def is_box_near_crop_edge(
    boxes: torch.Tensor, crop_box: list[int], orig_box: list[int], atol: float = 20.0
) -> torch.Tensor:
    """
    Flag boxes that lie close to a crop boundary that is not also a boundary of the original image.

    Boxes are first shifted by the crop offset, then each coordinate is compared against the matching coordinate
    of the crop box and of the original image box. A coordinate counts only if it is within `atol` of the crop
    edge and not within `atol` of the image edge.

    Args:
        boxes (torch.Tensor): Bounding boxes in XYXY format.
        crop_box (list[int]): Crop box coordinates in [x0, y0, x1, y1] format.
        orig_box (list[int]): Original image box coordinates in [x0, y0, x1, y1] format.
        atol (float, optional): Absolute tolerance for edge proximity detection.

    Returns:
        (torch.Tensor): Boolean tensor, one entry per box, True where any coordinate is near a crop-only edge.

    Examples:
        >>> boxes = torch.tensor([[10, 10, 50, 50], [100, 100, 150, 150]])
        >>> crop_box = [0, 0, 200, 200]
        >>> orig_box = [0, 0, 300, 300]
        >>> near_edge = is_box_near_crop_edge(boxes, crop_box, orig_box, atol=20.0)
    """
    device = boxes.device
    crop_edges = torch.as_tensor(crop_box, dtype=torch.float, device=device).unsqueeze(0)
    image_edges = torch.as_tensor(orig_box, dtype=torch.float, device=device).unsqueeze(0)

    # Add the crop offset so boxes and edges are compared in the same coordinate frame.
    shifted = uncrop_boxes_xyxy(boxes, crop_box).float()

    touches_crop = torch.isclose(shifted, crop_edges, atol=atol, rtol=0)
    touches_image = torch.isclose(shifted, image_edges, atol=atol, rtol=0)

    # Edges shared with the full image are real image borders, not crop artifacts, so they are excluded.
    return (touches_crop & ~touches_image).any(dim=1)





ultralytics.models.sam.amg.batch_iterator

batch_iterator(batch_size: int, *args) -> Generator[list[Any]]

Yield batches of data from input arguments with specified batch size for efficient processing.

This function takes a batch size and any number of iterables, then yields batches of elements from those iterables. All input iterables must have the same length.

Parameters:

Name Type Description Default
batch_size int

Size of each batch to yield.

required
*args Any

Variable length input iterables to batch. All iterables must have the same length.

()

Yields:

Type Description
list[Any]

A list of batched elements from each input iterable.

Examples:

>>> data = [1, 2, 3, 4, 5]
>>> labels = ["a", "b", "c", "d", "e"]
>>> for batch in batch_iterator(2, data, labels):
...     print(batch)
[[1, 2], ['a', 'b']]
[[3, 4], ['c', 'd']]
[[5], ['e']]
Source code in ultralytics/models/sam/amg.py
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
def batch_iterator(batch_size: int, *args) -> Generator[list[Any]]:
    """
    Slice several equally sized sequences into aligned batches and yield them one batch at a time.

    Each yielded item is a list with one slice per input sequence. The final batch may be shorter than
    `batch_size` when the input length is not an exact multiple of it.

    Args:
        batch_size (int): Size of each batch to yield.
        *args (Any): Variable length input iterables to batch. All iterables must have the same length.

    Yields:
        (list[Any]): A list of batched elements from each input iterable.

    Examples:
        >>> data = [1, 2, 3, 4, 5]
        >>> labels = ["a", "b", "c", "d", "e"]
        >>> for batch in batch_iterator(2, data, labels):
        ...     print(batch)
        [[1, 2], ['a', 'b']]
        [[3, 4], ['c', 'd']]
        [[5], ['e']]
    """
    assert args and all(len(seq) == len(args[0]) for seq in args), "Batched iteration must have same-size inputs."
    total = len(args[0])
    # Ceiling division: a trailing partial batch still counts as one batch.
    batch_count = -(-total // batch_size)
    for index in range(batch_count):
        start = index * batch_size
        yield [seq[start : start + batch_size] for seq in args]





ultralytics.models.sam.amg.calculate_stability_score

calculate_stability_score(
    masks: Tensor, mask_threshold: float, threshold_offset: float
) -> torch.Tensor

Compute the stability score for a batch of masks.

The stability score is the IoU between binary masks obtained by thresholding the predicted mask logits at high and low values.

Parameters:

Name Type Description Default
masks Tensor

Batch of predicted mask logits.

required
mask_threshold float

Threshold value for creating binary masks.

required
threshold_offset float

Offset applied to the threshold for creating high and low binary masks.

required

Returns:

Type Description
Tensor

Stability scores for each mask in the batch.

Notes
  • One mask is always contained inside the other.
  • Memory is saved by preventing unnecessary cast to torch.int64.

Examples:

>>> masks = torch.rand(10, 256, 256)  # Batch of 10 masks
>>> mask_threshold = 0.5
>>> threshold_offset = 0.1
>>> stability_scores = calculate_stability_score(masks, mask_threshold, threshold_offset)
Source code in ultralytics/models/sam/amg.py
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, threshold_offset: float) -> torch.Tensor:
    """
    Score how stable each predicted mask is under a shift of the binarization threshold.

    The logits are thresholded at `mask_threshold + threshold_offset` and at `mask_threshold - threshold_offset`.
    The score is the ratio of the two resulting pixel counts, summed over the last two dimensions.

    Args:
        masks (torch.Tensor): Batch of predicted mask logits.
        mask_threshold (float): Threshold value for creating binary masks.
        threshold_offset (float): Offset applied to the threshold for creating high and low binary masks.

    Returns:
        (torch.Tensor): Stability scores for each mask in the batch.

    Notes:
        - The high-threshold mask is always a subset of the low-threshold mask, so the count ratio equals their IoU.
        - Counts are accumulated in int16 and then int32 instead of the default int64 to save memory.

    Examples:
        >>> masks = torch.rand(10, 256, 256)  # Batch of 10 masks
        >>> mask_threshold = 0.5
        >>> threshold_offset = 0.1
        >>> stability_scores = calculate_stability_score(masks, mask_threshold, threshold_offset)
    """
    strict_cutoff = mask_threshold + threshold_offset
    loose_cutoff = mask_threshold - threshold_offset

    def _pixel_count(binary: torch.Tensor) -> torch.Tensor:
        """Count True pixels over the last two dimensions using narrow integer accumulators."""
        per_row = binary.sum(-1, dtype=torch.int16)
        return per_row.sum(-1, dtype=torch.int32)

    inner_area = _pixel_count(masks > strict_cutoff)
    outer_area = _pixel_count(masks > loose_cutoff)
    return inner_area / outer_area





ultralytics.models.sam.amg.build_point_grid

build_point_grid(n_per_side: int) -> np.ndarray

Generate a 2D grid of evenly spaced points in the range [0,1]x[0,1] for image segmentation tasks.

Source code in ultralytics/models/sam/amg.py
103
104
105
106
107
108
109
def build_point_grid(n_per_side: int) -> np.ndarray:
    """
    Generate a 2D grid of evenly spaced points in the range [0,1]x[0,1] for image segmentation tasks.

    Points are inset from the border by half a grid step. The result has shape (n_per_side**2, 2) holding (x, y)
    pairs, with x varying fastest.
    """
    half_step = 1 / (2 * n_per_side)
    ticks = np.linspace(half_step, 1 - half_step, n_per_side)
    # Default 'xy' indexing: grid_x repeats ticks along rows, grid_y repeats them along columns.
    grid_x, grid_y = np.meshgrid(ticks, ticks)
    return np.stack([grid_x, grid_y], axis=-1).reshape(-1, 2)





ultralytics.models.sam.amg.build_all_layer_point_grids

build_all_layer_point_grids(
    n_per_side: int, n_layers: int, scale_per_layer: int
) -> list[np.ndarray]

Generate point grids for multiple crop layers with varying scales and densities.

Source code in ultralytics/models/sam/amg.py
112
113
114
def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer: int) -> list[np.ndarray]:
    """
    Generate point grids for multiple crop layers with varying scales and densities.

    Returns n_layers + 1 grids. Grid i uses int(n_per_side / scale_per_layer**i) points per side.
    """
    grids = []
    for layer in range(n_layers + 1):
        points_per_side = int(n_per_side / (scale_per_layer**layer))
        grids.append(build_point_grid(points_per_side))
    return grids





ultralytics.models.sam.amg.generate_crop_boxes

generate_crop_boxes(
    im_size: tuple[int, ...], n_layers: int, overlap_ratio: float
) -> tuple[list[list[int]], list[int]]

Generate crop boxes of varying sizes for multiscale image processing, with layered overlapping regions.

Parameters:

Name Type Description Default
im_size tuple[int, ...]

Height and width of the input image.

required
n_layers int

Number of layers to generate crop boxes for.

required
overlap_ratio float

Ratio of overlap between adjacent crop boxes.

required

Returns:

Name Type Description
crop_boxes list[list[int]]

List of crop boxes in [x0, y0, x1, y1] format.

layer_idxs list[int]

List of layer indices corresponding to each crop box.

Examples:

>>> im_size = (800, 1200)  # Height, width
>>> n_layers = 3
>>> overlap_ratio = 0.25
>>> crop_boxes, layer_idxs = generate_crop_boxes(im_size, n_layers, overlap_ratio)
Source code in ultralytics/models/sam/amg.py
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
def generate_crop_boxes(
    im_size: tuple[int, ...], n_layers: int, overlap_ratio: float
) -> tuple[list[list[int]], list[int]]:
    """
    Generate crop boxes of varying sizes for multiscale image processing, with layered overlapping regions.

    Layer 0 is always the full image. Layer k (k >= 1) tiles the image with 2**k crops per side, and neighbouring
    crops share an overlap derived from `overlap_ratio` and the shorter image side.

    Args:
        im_size (tuple[int, ...]): Height and width of the input image.
        n_layers (int): Number of layers to generate crop boxes for.
        overlap_ratio (float): Ratio of overlap between adjacent crop boxes.

    Returns:
        crop_boxes (list[list[int]]): List of crop boxes in [x0, y0, x1, y1] format.
        layer_idxs (list[int]): List of layer indices corresponding to each crop box.

    Examples:
        >>> im_size = (800, 1200)  # Height, width
        >>> n_layers = 3
        >>> overlap_ratio = 0.25
        >>> crop_boxes, layer_idxs = generate_crop_boxes(im_size, n_layers, overlap_ratio)
    """
    im_h, im_w = im_size
    short_side = min(im_h, im_w)

    def _crop_extent(full_len, n_crops, overlap):
        """Return the length of one crop so that n_crops crops with the given overlap span full_len."""
        return math.ceil((overlap * (n_crops - 1) + full_len) / n_crops)

    # Layer 0: the uncropped image.
    crop_boxes = [[0, 0, im_w, im_h]]
    layer_idxs = [0]

    for layer in range(1, n_layers + 1):
        per_side = 2**layer
        overlap = int(overlap_ratio * short_side * (2 / per_side))

        crop_w = _crop_extent(im_w, per_side, overlap)
        crop_h = _crop_extent(im_h, per_side, overlap)

        x_starts = [int((crop_w - overlap) * k) for k in range(per_side)]
        y_starts = [int((crop_h - overlap) * k) for k in range(per_side)]

        # Boxes are XYXY, with the far corner clipped to the image bounds.
        layer_boxes = [
            [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)] for x0, y0 in product(x_starts, y_starts)
        ]
        crop_boxes.extend(layer_boxes)
        layer_idxs.extend([layer] * len(layer_boxes))

    return crop_boxes, layer_idxs





ultralytics.models.sam.amg.uncrop_boxes_xyxy

uncrop_boxes_xyxy(boxes: Tensor, crop_box: list[int]) -> torch.Tensor

Uncrop bounding boxes by adding the crop box offset to their coordinates.

Source code in ultralytics/models/sam/amg.py
169
170
171
172
173
174
175
176
def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: list[int]) -> torch.Tensor:
    """
    Uncrop bounding boxes by adding the crop box offset to their coordinates.

    The crop's (x0, y0) origin is added to both corners of every XYXY box. 3-D inputs are handled by broadcasting
    the offset over the extra middle dimension.
    """
    left, top, _, _ = crop_box
    shift = torch.tensor([left, top, left, top], device=boxes.device)
    # (1, 4) broadcasts over 2-D inputs; (1, 1, 4) is needed when boxes carry a channel dimension.
    shift = shift.view(1, 1, 4) if boxes.ndim == 3 else shift.view(1, 4)
    return boxes + shift





ultralytics.models.sam.amg.uncrop_points

uncrop_points(points: Tensor, crop_box: list[int]) -> torch.Tensor

Uncrop points by adding the crop box offset to their coordinates.

Source code in ultralytics/models/sam/amg.py
179
180
181
182
183
184
185
186
def uncrop_points(points: torch.Tensor, crop_box: list[int]) -> torch.Tensor:
    """
    Uncrop points by adding the crop box offset to their coordinates.

    The crop's (x0, y0) origin is added to every (x, y) point. 3-D inputs are handled by broadcasting the offset
    over the extra middle dimension.
    """
    left, top, _, _ = crop_box
    shift = torch.tensor([left, top], device=points.device)
    # (1, 2) broadcasts over 2-D inputs; (1, 1, 2) is needed when points carry a channel dimension.
    shift = shift.view(1, 1, 2) if points.ndim == 3 else shift.view(1, 2)
    return points + shift





ultralytics.models.sam.amg.uncrop_masks

uncrop_masks(
    masks: Tensor, crop_box: list[int], orig_h: int, orig_w: int
) -> torch.Tensor

Uncrop masks by padding them to the original image size, handling coordinate transformations.

Source code in ultralytics/models/sam/amg.py
189
190
191
192
193
194
195
196
197
def uncrop_masks(masks: torch.Tensor, crop_box: list[int], orig_h: int, orig_w: int) -> torch.Tensor:
    """
    Uncrop masks by padding them to the original image size, handling coordinate transformations.

    When the crop already spans the whole image the input tensor is returned as is. Otherwise the last two
    dimensions are zero-padded so the crop content sits at its position inside an (orig_h, orig_w) canvas.
    """
    x0, y0, x1, y1 = crop_box
    if (x0, y0) == (0, 0) and (x1, y1) == (orig_w, orig_h):
        return masks

    # Padding order for F.pad is (left, right, top, bottom) over the last two dimensions.
    pad_left, pad_top = x0, y0
    pad_right = orig_w - x1
    pad_bottom = orig_h - y1
    return torch.nn.functional.pad(masks, (pad_left, pad_right, pad_top, pad_bottom), value=0)





ultralytics.models.sam.amg.remove_small_regions

remove_small_regions(
    mask: ndarray, area_thresh: float, mode: str
) -> tuple[np.ndarray, bool]

Remove small disconnected regions or holes in a mask based on area threshold and mode.

Parameters:

Name Type Description Default
mask ndarray

Binary mask to process.

required
area_thresh float

Area threshold below which regions will be removed.

required
mode str

Processing mode, either 'holes' to fill small holes or 'islands' to remove small disconnected regions.

required

Returns:

Name Type Description
processed_mask ndarray

Processed binary mask with small regions removed.

modified bool

Whether any regions were modified.

Examples:

>>> mask = np.zeros((100, 100), dtype=np.bool_)
>>> mask[40:60, 40:60] = True  # Create a square
>>> mask[45:55, 45:55] = False  # Create a hole
>>> processed_mask, modified = remove_small_regions(mask, 50, "holes")
Source code in ultralytics/models/sam/amg.py
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> tuple[np.ndarray, bool]:
    """
    Remove small disconnected regions or holes in a mask based on area threshold and mode.

    Connected components are computed with OpenCV's `connectedComponentsWithStats` using 8-connectivity.
    Components whose size is strictly below `area_thresh` are treated as small. If no component is small, the
    input mask is returned unchanged; otherwise a new boolean mask is built with `np.isin`.

    Args:
        mask (np.ndarray): Binary mask to process.
        area_thresh (float): Area threshold below which regions will be removed.
        mode (str): Processing mode, either 'holes' to fill small holes or 'islands' to remove small disconnected
            regions.

    Returns:
        processed_mask (np.ndarray): Processed binary mask with small regions removed.
        modified (bool): Whether any regions were modified.

    Raises:
        AssertionError: If `mode` is neither 'holes' nor 'islands'.

    Examples:
        >>> mask = np.zeros((100, 100), dtype=np.bool_)
        >>> mask[40:60, 40:60] = True  # Create a square
        >>> mask[45:55, 45:55] = False  # Create a hole
        >>> processed_mask, modified = remove_small_regions(mask, 50, "holes")
    """
    # Imported lazily so OpenCV is only required when this function is actually called.
    import cv2  # type: ignore

    assert mode in {"holes", "islands"}, f"Provided mode {mode} is invalid"
    correct_holes = mode == "holes"
    # XOR with True inverts the mask (holes become foreground to be labelled); XOR with False leaves it as is.
    # NOTE(review): assumes `mask` is boolean (or 0/1) so the XOR acts as a logical inversion — confirm with callers.
    working_mask = (correct_holes ^ mask).astype(np.uint8)
    n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
    sizes = stats[:, -1][1:]  # Row 0 is background label
    # Labels are 1-based after dropping the background row, hence the `i + 1`.
    small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
    if not small_regions:
        # Nothing below the threshold: hand back the original array untouched.
        return mask, False
    # In 'holes' mode, label 0 of the inverted mask is the original foreground; adding the small hole labels
    # fills those holes in the output.
    fill_labels = [0, *small_regions]
    if not correct_holes:
        # In 'islands' mode keep only the labels that are neither background nor small.
        # If every region is below threshold, keep largest
        fill_labels = [i for i in range(n_labels) if i not in fill_labels] or [int(np.argmax(sizes)) + 1]
    # Output is True wherever the component label is one of the labels selected to keep/fill.
    mask = np.isin(regions, fill_labels)
    return mask, True





ultralytics.models.sam.amg.batched_mask_to_box

batched_mask_to_box(masks: Tensor) -> torch.Tensor

Calculate bounding boxes in XYXY format around binary masks.

Parameters:

Name Type Description Default
masks Tensor

Binary masks with shape (B, H, W) or (B, C, H, W).

required

Returns:

Type Description
Tensor

Bounding boxes in XYXY format with shape (B, 4) or (B, C, 4).

Notes
  • Handles empty masks by returning zero boxes.
  • Preserves input tensor dimensions in the output.
Source code in ultralytics/models/sam/amg.py
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
    """
    Compute the tight XYXY bounding box of every binary mask in a batch.

    Args:
        masks (torch.Tensor): Binary masks with shape (B, H, W) or (B, C, H, W).

    Returns:
        (torch.Tensor): Bounding boxes in XYXY format with shape (B, 4) or (B, C, 4).

    Notes:
        - Masks without any True pixel produce the box [0, 0, 0, 0].
        - Leading dimensions of the input are preserved in the output.
        - A tensor with zero elements short-circuits to an all-zero result, since max-reductions fail on it.
    """
    # Reductions below raise on empty tensors, so return early with zero boxes.
    if masks.numel() == 0:
        return torch.zeros(*masks.shape[:-2], 4, device=masks.device)

    orig_shape = masks.shape
    height, width = orig_shape[-2:]
    has_leading_dims = len(orig_shape) > 2
    # Collapse all leading dimensions into one so the math below works on (N, H, W).
    flat = masks.flatten(0, -3) if has_leading_dims else masks.unsqueeze(0)

    # Vertical extent: rows containing at least one True pixel.
    row_hit = flat.max(dim=-1).values
    row_coords = row_hit * torch.arange(height, device=row_hit.device)[None, :]
    bottom = row_coords.max(dim=-1).values
    # Push unoccupied rows past the image so they never win the min.
    top = (row_coords + height * (~row_hit)).min(dim=-1).values

    # Horizontal extent: columns containing at least one True pixel.
    col_hit = flat.max(dim=-2).values
    col_coords = col_hit * torch.arange(width, device=col_hit.device)[None, :]
    right = col_coords.max(dim=-1).values
    left = (col_coords + width * (~col_hit)).min(dim=-1).values

    # An empty mask ends up with min edge > max edge; zero those boxes out.
    is_empty = (right < left) | (bottom < top)
    boxes = torch.stack([left, top, right, bottom], dim=-1)
    boxes = boxes * (~is_empty).unsqueeze(-1)

    # Restore the caller's leading dimensions.
    return boxes.reshape(*orig_shape[:-2], 4) if has_leading_dims else boxes[0]





📅 Created 1 year ago ✏️ Updated 1 year ago