Skip to content

Reference for ultralytics/utils/ops.py

Note

This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/utils/ops.py. If you spot a problem, please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!


ultralytics.utils.ops.Profile

Profile(t: float = 0.0, device: device | None = None)

Bases: ContextDecorator

Ultralytics Profile class for timing code execution.

Use as a decorator with @Profile() or as a context manager with 'with Profile():'. Provides accurate timing measurements with CUDA synchronization support for GPU operations.

Attributes:

Name Type Description
t float

Accumulated time in seconds.

device device

Device used for model inference.

cuda bool

Whether CUDA is being used for timing synchronization.

Examples:

Use as a context manager to time code execution

>>> with Profile(device=device) as dt:
...     pass  # slow operation here
>>> print(dt)  # prints "Elapsed time is 9.5367431640625e-07 s"

Use as a decorator to time function execution

>>> @Profile()
... def slow_function():
...     time.sleep(0.1)

Parameters:

Name Type Description Default
t float

Initial accumulated time in seconds.

0.0
device device

Device used for model inference to enable CUDA synchronization.

None
Source code in ultralytics/utils/ops.py
42
43
44
45
46
47
48
49
50
51
52
def __init__(self, t: float = 0.0, device: torch.device | None = None):
    """
    Set up a profiler with a starting time total and an optional device.

    Args:
        t (float): Seconds already accumulated before any timing starts.
        device (torch.device, optional): Inference device. When its string form begins with "cuda", CUDA
            synchronization is enabled for timestamps.
    """
    self.t = t  # running total of elapsed seconds
    self.device = device
    # Only enable CUDA synchronization for devices whose string form starts with "cuda"
    self.cuda = bool(device) and str(device).startswith("cuda")

__enter__

__enter__()

Start timing.

Source code in ultralytics/utils/ops.py
54
55
56
57
def __enter__(self):
    """Record the start timestamp and hand back this profiler for use in ``with ... as``."""
    started_at = self.time()
    self.start = started_at  # read by __exit__ to compute the interval
    return self

__exit__

__exit__(type, value, traceback)

Stop timing.

Source code in ultralytics/utils/ops.py
59
60
61
62
def __exit__(self, type, value, traceback):
    """Stop timing: store the latest interval in ``dt`` and add it to the running total ``t``."""
    elapsed = self.time() - self.start
    self.dt = elapsed  # duration of the most recent timed region
    self.t += elapsed  # accumulated duration over all timed regions

__str__

__str__()

Return a human-readable string representing the accumulated elapsed time.

Source code in ultralytics/utils/ops.py
64
65
66
def __str__(self):
    """Describe the accumulated elapsed time in seconds."""
    total = self.t
    return f"Elapsed time is {total} s"

time

time()

Get current time with CUDA synchronization if applicable.

Source code in ultralytics/utils/ops.py
68
69
70
71
72
def time(self):
    """Return ``time.perf_counter()``, synchronizing ``self.device`` first when CUDA timing is enabled."""
    if not self.cuda:
        return time.perf_counter()
    torch.cuda.synchronize(self.device)  # wait for queued GPU work so it counts toward the interval
    return time.perf_counter()





ultralytics.utils.ops.segment2box

segment2box(segment, width: int = 640, height: int = 640)

Convert segment coordinates to bounding box coordinates.

Converts a single segment label to a box label by finding the minimum and maximum x and y coordinates. Applies inside-image constraint and clips coordinates when necessary.

Parameters:

Name Type Description Default
segment ndarray

Segment coordinates in format (N, 2) where N is number of points.

required
width int

Width of the image in pixels.

640
height int

Height of the image in pixels.

640

Returns:

Type Description
ndarray

Bounding box coordinates in xyxy format [x1, y1, x2, y2].

Source code in ultralytics/utils/ops.py
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
def segment2box(segment, width: int = 640, height: int = 640):
    """
    Convert segment coordinates to bounding box coordinates.

    Converts a single segment label to a box label by finding the minimum and maximum x and y coordinates. If at
    least 3 of the 4 sides of the segment extend beyond the image, the points are first clipped to the image. Points
    that still lie outside the image are then discarded.

    Args:
        segment (np.ndarray): Segment coordinates in format (N, 2) where N is number of points.
        width (int): Width of the image in pixels.
        height (int): Height of the image in pixels.

    Returns:
        (np.ndarray): Bounding box coordinates in xyxy format [x1, y1, x2, y2] with the dtype of ``segment``, or all
            zeros if no point lies inside the image.
    """
    x, y = segment.T  # segment xy
    # Clip coordinates if 3 out of 4 sides are outside the image
    if np.array([x.min() < 0, y.min() < 0, x.max() > width, y.max() > height]).sum() >= 3:
        x = x.clip(0, width)
        y = y.clip(0, height)
    inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
    x = x[inside]
    y = y[inside]
    # Test the point count, not value truthiness: points lying on the x == 0 edge are valid
    if len(x) == 0:
        return np.zeros(4, dtype=segment.dtype)
    return np.array([x.min(), y.min(), x.max(), y.max()], dtype=segment.dtype)  # xyxy





ultralytics.utils.ops.scale_boxes

scale_boxes(
    img1_shape,
    boxes,
    img0_shape,
    ratio_pad=None,
    padding: bool = True,
    xywh: bool = False,
)

Rescale bounding boxes from one image shape to another.

Rescales bounding boxes from img1_shape to img0_shape, accounting for padding and aspect ratio changes. Supports both xyxy and xywh box formats.

Parameters:

Name Type Description Default
img1_shape tuple

Shape of the source image (height, width).

required
boxes Tensor

Bounding boxes to rescale in format (N, 4).

required
img0_shape tuple

Shape of the target image (height, width).

required
ratio_pad tuple

Tuple of (ratio, pad) for scaling. If None, calculated from image shapes.

None
padding bool

Whether boxes are based on YOLO-style augmented images with padding.

True
xywh bool

Whether box format is xywh (True) or xyxy (False).

False

Returns:

Type Description
Tensor

Rescaled bounding boxes in the same format as input.

Source code in ultralytics/utils/ops.py
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding: bool = True, xywh: bool = False):
    """
    Map bounding boxes from a (letterboxed) source image shape back onto a target image shape.

    The letterbox padding is removed (optionally), then coordinates are divided by the resize gain. xyxy boxes are
    also clipped to the target image; xywh boxes are returned unclipped.

    Args:
        img1_shape (tuple): Shape of the source image (height, width).
        boxes (torch.Tensor): Bounding boxes to rescale in format (N, 4); modified in place.
        img0_shape (tuple): Shape of the target image (height, width).
        ratio_pad (tuple, optional): (ratio, (pad_x, pad_y)); only ratio[0] and the padding are read. If None, both
            are derived from the two shapes.
        padding (bool): Whether boxes are based on YOLO-style augmented images with padding.
        xywh (bool): Whether box format is xywh (True) or xyxy (False).

    Returns:
        (torch.Tensor): Rescaled bounding boxes in the same format as input.
    """
    if ratio_pad is not None:
        gain = ratio_pad[0][0]
        pad_x, pad_y = ratio_pad[1]
    else:  # derive gain and centered padding from the two shapes
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain  = old / new
        pad_x = round((img1_shape[1] - img0_shape[1] * gain) / 2 - 0.1)
        pad_y = round((img1_shape[0] - img0_shape[0] * gain) / 2 - 0.1)

    if padding:
        # xywh only stores a position in columns 0-1; xyxy stores one in columns 0-1 and 2-3
        offsets = (pad_x, pad_y) if xywh else (pad_x, pad_y, pad_x, pad_y)
        for col, offset in enumerate(offsets):
            boxes[..., col] -= offset
    boxes[..., :4] /= gain
    if xywh:
        return boxes
    return clip_boxes(boxes, img0_shape)





ultralytics.utils.ops.make_divisible

make_divisible(x: int, divisor)

Return the smallest multiple of the given divisor that is greater than or equal to x (i.e. x rounded up).

Parameters:

Name Type Description Default
x int

The number to make divisible.

required
divisor int | Tensor

The divisor.

required

Returns:

Type Description
int

The smallest multiple of the divisor that is greater than or equal to x.

Source code in ultralytics/utils/ops.py
141
142
143
144
145
146
147
148
149
150
151
152
153
154
def make_divisible(x: int, divisor):
    """
    Round x up to the next multiple of divisor.

    Args:
        x (int): The number to round up.
        divisor (int | torch.Tensor): The step to align to; for a tensor, its largest element (cast to int) is used.

    Returns:
        (int): The smallest multiple of the divisor that is greater than or equal to x.
    """
    step = int(divisor.max()) if isinstance(divisor, torch.Tensor) else divisor
    return math.ceil(x / step) * step





ultralytics.utils.ops.clip_boxes

clip_boxes(boxes, shape)

Clip bounding boxes to image boundaries.

Parameters:

Name Type Description Default
boxes Tensor | ndarray

Bounding boxes to clip.

required
shape tuple

Image shape as HWC or HW (supports both).

required

Returns:

Type Description
Tensor | ndarray

Clipped bounding boxes.

Source code in ultralytics/utils/ops.py
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
def clip_boxes(boxes, shape):
    """
    Clamp box coordinates in place so every box lies within the image.

    Args:
        boxes (torch.Tensor | np.ndarray): Boxes whose first four columns are x1, y1, x2, y2.
        shape (tuple): Image shape as HWC or HW (supports both); only the first two entries are read.

    Returns:
        (torch.Tensor | np.ndarray): The same ``boxes`` object, with x values limited to [0, w] and y values to
            [0, h].
    """
    h, w = shape[:2]  # supports both HWC or HW shapes
    if not isinstance(boxes, torch.Tensor):  # np.array (faster grouped)
        boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, w)  # x1, x2
        boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, h)  # y1, y2
        return boxes
    # torch.Tensor: updating column by column is faster
    for col, limit in ((0, w), (1, h), (2, w), (3, h)):
        if NOT_MACOS14:
            boxes[..., col].clamp_(0, limit)
        else:  # Apple macOS14 MPS bug https://github.com/ultralytics/ultralytics/pull/21878
            boxes[..., col] = boxes[..., col].clamp(0, limit)
    return boxes





ultralytics.utils.ops.clip_coords

clip_coords(coords, shape)

Clip line coordinates to image boundaries.

Parameters:

Name Type Description Default
coords Tensor | ndarray

Line coordinates to clip.

required
shape tuple

Image shape as HWC or HW (supports both).

required

Returns:

Type Description
Tensor | ndarray

Clipped coordinates.

Source code in ultralytics/utils/ops.py
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
def clip_coords(coords, shape):
    """
    Clamp point coordinates in place so they lie within the image.

    Args:
        coords (torch.Tensor | np.ndarray): Points whose column 0 is x and column 1 is y.
        shape (tuple): Image shape as HWC or HW (supports both); only the first two entries are read.

    Returns:
        (torch.Tensor | np.ndarray): The same ``coords`` object, with x limited to [0, w] and y to [0, h].
    """
    h, w = shape[:2]  # supports both HWC or HW shapes
    limits = ((0, w), (1, h))  # (column, upper bound)
    if not isinstance(coords, torch.Tensor):  # np.array
        for col, upper in limits:
            coords[..., col] = coords[..., col].clip(0, upper)
        return coords
    for col, upper in limits:
        if NOT_MACOS14:
            coords[..., col].clamp_(0, upper)
        else:  # Apple macOS14 MPS bug https://github.com/ultralytics/ultralytics/pull/21878
            coords[..., col] = coords[..., col].clamp(0, upper)
    return coords





ultralytics.utils.ops.scale_image

scale_image(masks, im0_shape, ratio_pad=None)

Rescale masks to original image size.

Takes resized and padded masks and rescales them back to the original image dimensions, removing any padding that was applied during preprocessing.

Parameters:

Name Type Description Default
masks ndarray

Resized and padded masks with shape [H, W, N] or [H, W, 3].

required
im0_shape tuple

Original image shape as HWC or HW (supports both).

required
ratio_pad tuple

Ratio and padding values as (ratio, (pad_w, pad_h)); only the padding is used. If None, padding is calculated from the two shapes.

None

Returns:

Type Description
ndarray

Rescaled masks with shape [H, W, N] matching original image dimensions.

Source code in ultralytics/utils/ops.py
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
def scale_image(masks, im0_shape, ratio_pad=None):
    """
    Rescale masks to original image size.

    Takes resized and padded masks and rescales them back to the original image dimensions, removing any padding
    that was applied during preprocessing.

    Args:
        masks (np.ndarray): Resized and padded masks with shape [H, W, N], or [H, W] which is treated as [H, W, 1].
        im0_shape (tuple): Original image shape as HWC or HW (supports both).
        ratio_pad (tuple, optional): Ratio and padding values as (ratio, (pad_w, pad_h)); only the padding is used.
            If None, padding is calculated from the two shapes.

    Returns:
        (np.ndarray): Rescaled masks with shape [H, W, N] matching original image dimensions.

    Raises:
        ValueError: If masks is neither 2D nor 3D.
    """
    # Validate rank before unpacking the shape, so bad input gets this message instead of an unpacking error
    if len(masks.shape) not in {2, 3}:
        raise ValueError(f'"len of masks shape" should be 2 or 3, but got {len(masks.shape)}')
    if len(masks.shape) == 2:
        masks = masks[:, :, None]  # single mask -> [H, W, 1]

    # Rescale masks from im1_shape (padded input) to im0_shape (original image)
    im0_h, im0_w = im0_shape[:2]  # supports both HWC or HW shapes
    im1_h, im1_w, _ = masks.shape
    if im1_h == im0_h and im1_w == im0_w:
        return masks

    if ratio_pad is None:  # calculate from im0_shape
        gain = min(im1_h / im0_h, im1_w / im0_w)  # gain  = old / new
        pad = (im1_w - im0_w * gain) / 2, (im1_h - im0_h * gain) / 2  # wh padding
    else:
        pad = ratio_pad[1]

    pad_w, pad_h = pad
    top = round(pad_h - 0.1)
    left = round(pad_w - 0.1)
    bottom = im1_h - round(pad_h + 0.1)
    right = im1_w - round(pad_w + 0.1)

    masks = masks[top:bottom, left:right]
    # handle the cv2.resize 512 channels limitation: https://github.com/ultralytics/ultralytics/pull/21947
    chunks = np.array_split(masks, masks.shape[-1] // 512 + 1, axis=-1)
    masks = [cv2.resize(array, (im0_w, im0_h)) for array in chunks]
    masks = np.concatenate(masks, axis=-1) if len(masks) > 1 else masks[0]
    if len(masks.shape) == 2:  # cv2.resize can return 2D for single-channel input; restore the channel axis
        masks = masks[:, :, None]

    return masks





ultralytics.utils.ops.xyxy2xywh

xyxy2xywh(x)

Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height) format where (x1, y1) is the top-left corner and (x2, y2) is the bottom-right corner.

Parameters:

Name Type Description Default
x ndarray | Tensor

Input bounding box coordinates in (x1, y1, x2, y2) format.

required

Returns:

Type Description
ndarray | Tensor

Bounding box coordinates in (x, y, width, height) format.

Source code in ultralytics/utils/ops.py
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
def xyxy2xywh(x):
    """
    Convert corner boxes (x1, y1, x2, y2) into center boxes (x, y, width, height).

    (x1, y1) is the top-left corner and (x2, y2) the bottom-right corner; (x, y) in the output is the box center.

    Args:
        x (np.ndarray | torch.Tensor): Boxes in (x1, y1, x2, y2) format; the last dimension must be 4.

    Returns:
        (np.ndarray | torch.Tensor): A new array/tensor of the same shape in (x, y, width, height) format.
    """
    assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
    out = empty_like(x)  # faster than clone/copy
    left, top, right, bottom = (x[..., k] for k in range(4))
    out[..., 0] = (left + right) / 2  # center x
    out[..., 1] = (top + bottom) / 2  # center y
    out[..., 2] = right - left  # width
    out[..., 3] = bottom - top  # height
    return out





ultralytics.utils.ops.xywh2xyxy

xywh2xyxy(x)

Convert bounding box coordinates from (x, y, width, height) format to (x1, y1, x2, y2) format where (x1, y1) is the top-left corner and (x2, y2) is the bottom-right corner. Note: ops per 2 channels faster than per channel.

Parameters:

Name Type Description Default
x ndarray | Tensor

Input bounding box coordinates in (x, y, width, height) format.

required

Returns:

Type Description
ndarray | Tensor

Bounding box coordinates in (x1, y1, x2, y2) format.

Source code in ultralytics/utils/ops.py
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
def xywh2xyxy(x):
    """
    Convert center boxes (x, y, width, height) into corner boxes (x1, y1, x2, y2).

    (x1, y1) is the top-left corner and (x2, y2) the bottom-right corner. Columns are processed in pairs, which is
    faster than handling each channel separately.

    Args:
        x (np.ndarray | torch.Tensor): Boxes in (x, y, width, height) format; the last dimension must be 4.

    Returns:
        (np.ndarray | torch.Tensor): A new array/tensor of the same shape in (x1, y1, x2, y2) format.
    """
    assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
    out = empty_like(x)  # faster than clone/copy
    center = x[..., :2]
    half_size = x[..., 2:] / 2
    out[..., :2] = center - half_size  # top-left corner
    out[..., 2:] = center + half_size  # bottom-right corner
    return out





ultralytics.utils.ops.xywhn2xyxy

xywhn2xyxy(x, w: int = 640, h: int = 640, padw: int = 0, padh: int = 0)

Convert normalized bounding box coordinates to pixel coordinates.

Parameters:

Name Type Description Default
x ndarray | Tensor

Normalized bounding box coordinates in (x, y, w, h) format.

required
w int

Image width in pixels.

640
h int

Image height in pixels.

640
padw int

Padding width in pixels.

0
padh int

Padding height in pixels.

0

Returns:

Name Type Description
y ndarray | Tensor

The coordinates of the bounding box in the format [x1, y1, x2, y2] where x1,y1 is the top-left corner, x2,y2 is the bottom-right corner of the bounding box.

Source code in ultralytics/utils/ops.py
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
def xywhn2xyxy(x, w: int = 640, h: int = 640, padw: int = 0, padh: int = 0):
    """
    Convert normalized center boxes into pixel corner boxes, optionally shifted by padding.

    Args:
        x (np.ndarray | torch.Tensor): Normalized boxes in (x, y, w, h) format; the last dimension must be 4.
        w (int): Image width in pixels.
        h (int): Image height in pixels.
        padw (int): Horizontal offset in pixels added to both x coordinates.
        padh (int): Vertical offset in pixels added to both y coordinates.

    Returns:
        y (np.ndarray | torch.Tensor): Boxes as [x1, y1, x2, y2] in pixels, where x1,y1 is the top-left corner and
            x2,y2 is the bottom-right corner of the bounding box.
    """
    assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
    out = empty_like(x)  # faster than clone/copy
    cx, cy = x[..., 0], x[..., 1]
    half_w, half_h = x[..., 2] / 2, x[..., 3] / 2
    # normalized edges, then scaled to pixels and shifted by the padding
    left, right = cx - half_w, cx + half_w
    top, bottom = cy - half_h, cy + half_h
    out[..., 0] = w * left + padw  # top left x
    out[..., 1] = h * top + padh  # top left y
    out[..., 2] = w * right + padw  # bottom right x
    out[..., 3] = h * bottom + padh  # bottom right y
    return out





ultralytics.utils.ops.xyxy2xywhn

xyxy2xywhn(x, w: int = 640, h: int = 640, clip: bool = False, eps: float = 0.0)

Convert bounding box coordinates from (x1, y1, x2, y2) format to normalized (x, y, width, height) format, where (x, y) is the box center and all four values are divided by the image dimensions.

Parameters:

Name Type Description Default
x ndarray | Tensor

Input bounding box coordinates in (x1, y1, x2, y2) format.

required
w int

Image width in pixels.

640
h int

Image height in pixels.

640
clip bool

Whether to clip boxes to image boundaries.

False
eps float

Margin subtracted from the image width and height before clipping; only used when clip is True.

0.0

Returns:

Type Description
ndarray | Tensor

Normalized bounding box coordinates in (x, y, width, height) format.

Source code in ultralytics/utils/ops.py
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
def xyxy2xywhn(x, w: int = 640, h: int = 640, clip: bool = False, eps: float = 0.0):
    """
    Convert bounding box coordinates from (x1, y1, x2, y2) format to normalized (x, y, width, height) format, where
    (x, y) is the box center and all four values are divided by the image dimensions.

    Args:
        x (np.ndarray | torch.Tensor): Input bounding box coordinates in (x1, y1, x2, y2) format.
        w (int): Image width in pixels.
        h (int): Image height in pixels.
        clip (bool): Whether to clip boxes to the image boundaries before converting. Clipping modifies ``x`` in
            place.
        eps (float): Margin subtracted from the image width and height for clipping, so clipped coordinates stay
            below ``w`` and ``h`` when eps > 0. Only used when ``clip`` is True.

    Returns:
        (np.ndarray | torch.Tensor): Normalized bounding box coordinates in (x, y, width, height) format.
    """
    # Validate first so a malformed input is rejected before the in-place clipping touches it
    assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
    if clip:
        x = clip_boxes(x, (h - eps, w - eps))
    y = empty_like(x)  # faster than clone/copy
    x1, y1, x2, y2 = x[..., 0], x[..., 1], x[..., 2], x[..., 3]
    y[..., 0] = ((x1 + x2) / 2) / w  # x center
    y[..., 1] = ((y1 + y2) / 2) / h  # y center
    y[..., 2] = (x2 - x1) / w  # width
    y[..., 3] = (y2 - y1) / h  # height
    return y





ultralytics.utils.ops.xywh2ltwh

xywh2ltwh(x)

Convert bounding box format from [x, y, w, h] to [x1, y1, w, h] where x1, y1 are top-left coordinates.

Parameters:

Name Type Description Default
x ndarray | Tensor

Input bounding box coordinates in xywh format.

required

Returns:

Type Description
ndarray | Tensor

Bounding box coordinates in ltwh (left, top, width, height) format.

Source code in ultralytics/utils/ops.py
350
351
352
353
354
355
356
357
358
359
360
361
362
363
def xywh2ltwh(x):
    """
    Shift center boxes [x, y, w, h] to top-left anchored boxes [x1, y1, w, h].

    Args:
        x (np.ndarray | torch.Tensor): Boxes in xywh format (x, y is the center).

    Returns:
        (np.ndarray | torch.Tensor): A copy in ltwh format; width and height are unchanged.
    """
    out = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    half_w, half_h = x[..., 2] / 2, x[..., 3] / 2
    out[..., 0] = x[..., 0] - half_w  # left edge
    out[..., 1] = x[..., 1] - half_h  # top edge
    return out





ultralytics.utils.ops.xyxy2ltwh

xyxy2ltwh(x)

Convert bounding boxes from [x1, y1, x2, y2] to [x1, y1, w, h] format.

Parameters:

Name Type Description Default
x ndarray | Tensor

Input bounding box coordinates in xyxy format.

required

Returns:

Type Description
ndarray | Tensor

Bounding box coordinates in ltwh (left, top, width, height) format.

Source code in ultralytics/utils/ops.py
366
367
368
369
370
371
372
373
374
375
376
377
378
379
def xyxy2ltwh(x):
    """
    Turn corner boxes [x1, y1, x2, y2] into [x1, y1, w, h] boxes.

    Args:
        x (np.ndarray | torch.Tensor): Boxes in xyxy format.

    Returns:
        (np.ndarray | torch.Tensor): A copy in ltwh format; the top-left corner is unchanged.
    """
    out = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    left, top = x[..., 0], x[..., 1]
    out[..., 2] = x[..., 2] - left  # width
    out[..., 3] = x[..., 3] - top  # height
    return out





ultralytics.utils.ops.ltwh2xywh

ltwh2xywh(x)

Convert bounding boxes from [x1, y1, w, h] to [x, y, w, h] where xy1=top-left, xy=center.

Parameters:

Name Type Description Default
x ndarray | Tensor

Input bounding box coordinates.

required

Returns:

Type Description
ndarray | Tensor

Bounding box coordinates in xywh format.

Source code in ultralytics/utils/ops.py
382
383
384
385
386
387
388
389
390
391
392
393
394
395
def ltwh2xywh(x):
    """
    Move the anchor of [x1, y1, w, h] boxes from the top-left corner to the center, giving [x, y, w, h].

    Args:
        x (np.ndarray | torch.Tensor): Boxes in ltwh format.

    Returns:
        (np.ndarray | torch.Tensor): A copy in xywh format; width and height are unchanged.
    """
    out = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    half_w, half_h = x[..., 2] / 2, x[..., 3] / 2
    out[..., 0] = x[..., 0] + half_w  # center x
    out[..., 1] = x[..., 1] + half_h  # center y
    return out





ultralytics.utils.ops.xyxyxyxy2xywhr

xyxyxyxy2xywhr(x)

Convert batched Oriented Bounding Boxes (OBB) from [xy1, xy2, xy3, xy4] to [xywh, rotation] format.

Parameters:

Name Type Description Default
x ndarray | Tensor

Input box corners with shape (N, 8) in [xy1, xy2, xy3, xy4] format.

required

Returns:

Type Description
ndarray | Tensor

Converted data in [cx, cy, w, h, rotation] format with shape (N, 5). Rotation values are in radians from 0 to pi/2.

Source code in ultralytics/utils/ops.py
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
def xyxyxyxy2xywhr(x):
    """
    Convert batched Oriented Bounding Boxes (OBB) from [xy1, xy2, xy3, xy4] to [xywh, rotation] format.

    Args:
        x (np.ndarray | torch.Tensor): Input box corners with shape (N, 8) in [xy1, xy2, xy3, xy4] format.

    Returns:
        (np.ndarray | torch.Tensor): Converted data in [cx, cy, w, h, rotation] format with shape (N, 5). Rotation
            values are in radians from 0 to pi/2. An empty input (N == 0) returns an empty (0, 5) result.
    """
    is_torch = isinstance(x, torch.Tensor)
    points = x.cpu().numpy() if is_torch else x
    rboxes = []
    # reshape with -1 cannot infer a size for an empty batch, so only convert when boxes exist
    if len(points):
        for pts in points.reshape(len(points), -1, 2):
            # NOTE: Use cv2.minAreaRect to get accurate xywhr,
            # especially some objects are cut off by augmentations in dataloader.
            (cx, cy), (w, h), angle = cv2.minAreaRect(pts)
            rboxes.append([cx, cy, w, h, angle / 180 * np.pi])
    # reshape(-1, 5) is a no-op for non-empty results and yields (0, 5) for an empty batch
    if is_torch:
        return torch.tensor(rboxes, device=x.device, dtype=x.dtype).reshape(-1, 5)
    return np.asarray(rboxes).reshape(-1, 5)





ultralytics.utils.ops.xywhr2xyxyxyxy

xywhr2xyxyxyxy(x)

Convert batched Oriented Bounding Boxes (OBB) from [xywh, rotation] to [xy1, xy2, xy3, xy4] format.

Parameters:

Name Type Description Default
x ndarray | Tensor

Boxes in [cx, cy, w, h, rotation] format with shape (N, 5) or (B, N, 5). Rotation values should be in radians from 0 to pi/2.

required

Returns:

Type Description
ndarray | Tensor

Converted corner points with shape (N, 4, 2) or (B, N, 4, 2).

Source code in ultralytics/utils/ops.py
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
def xywhr2xyxyxyxy(x):
    """
    Expand rotated boxes [cx, cy, w, h, rotation] into their four corner points.

    Args:
        x (np.ndarray | torch.Tensor): Boxes with shape (N, 5) or (B, N, 5); rotation is in radians
            (expected range 0 to pi/2).

    Returns:
        (np.ndarray | torch.Tensor): Corner points with shape (N, 4, 2) or (B, N, 4, 2).
    """
    if isinstance(x, torch.Tensor):
        cos, sin, cat, stack = torch.cos, torch.sin, torch.cat, torch.stack
    else:
        cos, sin, cat, stack = np.cos, np.sin, np.concatenate, np.stack

    center = x[..., :2]
    half_w = x[..., 2:3] / 2
    half_h = x[..., 3:4] / 2
    theta = x[..., 4:5]
    c, s = cos(theta), sin(theta)
    # half-extent vectors along the box's width and height axes
    along_w = cat([half_w * c, half_w * s], -1)
    along_h = cat([-half_h * s, half_h * c], -1)
    corners = [
        center + along_w + along_h,
        center + along_w - along_h,
        center - along_w - along_h,
        center - along_w + along_h,
    ]
    return stack(corners, -2)





ultralytics.utils.ops.ltwh2xyxy

ltwh2xyxy(x)

Convert bounding box from [x1, y1, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right.

Parameters:

Name Type Description Default
x ndarray | Tensor

Input bounding box coordinates.

required

Returns:

Type Description
ndarray | Tensor

Bounding box coordinates in xyxy format.

Source code in ultralytics/utils/ops.py
452
453
454
455
456
457
458
459
460
461
462
463
464
465
def ltwh2xyxy(x):
    """
    Turn [x1, y1, w, h] boxes into corner boxes [x1, y1, x2, y2] (top-left and bottom-right corners).

    Args:
        x (np.ndarray | torch.Tensor): Boxes in ltwh format.

    Returns:
        (np.ndarray | torch.Tensor): A copy in xyxy format; the top-left corner is unchanged.
    """
    out = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    left, top = x[..., 0], x[..., 1]
    out[..., 2] = left + x[..., 2]  # x2 = left + width
    out[..., 3] = top + x[..., 3]  # y2 = top + height
    return out





ultralytics.utils.ops.segments2boxes

segments2boxes(segments)

Convert segment labels to box labels, i.e. (xy1, xy2, ...) points to an xywh box per segment.

Parameters:

Name Type Description Default
segments list

List of segments where each segment is a list of points, each point is [x, y] coordinates.

required

Returns:

Type Description
ndarray

Bounding box coordinates in xywh format.

Source code in ultralytics/utils/ops.py
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
def segments2boxes(segments):
    """
    Convert segment labels to box labels, i.e. (xy1, xy2, ...) points to an xywh box per segment.

    Args:
        segments (list): List of segments, each an array of shape (N, 2) holding [x, y] point coordinates.

    Returns:
        (np.ndarray): Bounding boxes in xywh format with shape (M, 4), where M is the number of segments. An empty
            list gives an empty (0, 4) result.
    """
    boxes = []
    for s in segments:
        x, y = s.T  # segment xy
        boxes.append([x.min(), y.min(), x.max(), y.max()])  # xyxy
    # reshape keeps an empty input 2D so the (..., 4) shape check in xyxy2xywh accepts it
    return xyxy2xywh(np.array(boxes).reshape(-1, 4))  # xywh





ultralytics.utils.ops.resample_segments

resample_segments(segments, n: int = 1000)

Resample segments to n points each using linear interpolation.

Parameters:

Name Type Description Default
segments list

List of (N, 2) arrays where N is the number of points in each segment.

required
n int

Number of points to resample each segment to.

1000

Returns:

Type Description
list

Resampled segments with n points each.

Source code in ultralytics/utils/ops.py
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
def resample_segments(segments, n: int = 1000):
    """
    Resample every segment to exactly n points by linear interpolation along its closed outline.

    Segments that already have n points are left untouched. Other entries of ``segments`` are replaced in place.

    Args:
        segments (list): List of (N, 2) arrays where N is the number of points in each segment.
        n (int): Number of points to resample each segment to.

    Returns:
        (list): The same list, with each resampled segment as a float32 (n, 2) array.
    """
    for idx, seg in enumerate(segments):
        if len(seg) == n:
            continue  # already the requested length
        closed = np.concatenate((seg, seg[0:1, :]), axis=0)  # repeat the first point to close the outline
        count = len(closed)
        xp = np.arange(count)
        if count < n:
            # sample n - count new positions, then merge in every original vertex position so none are lost
            x = np.linspace(0, count - 1, n - count)
            x = np.insert(x, np.searchsorted(x, xp), xp)
        else:
            x = np.linspace(0, count - 1, n)
        coords = [np.interp(x, xp, closed[:, axis]) for axis in range(2)]
        segments[idx] = np.concatenate(coords, dtype=np.float32).reshape(2, -1).T  # segment xy
    return segments





ultralytics.utils.ops.crop_mask

crop_mask(masks, boxes)

Crop masks to bounding box regions.

Parameters:

Name Type Description Default
masks Tensor

Masks with shape (N, H, W).

required
boxes Tensor

Bounding box coordinates with shape (N, 4) as (x1, y1, x2, y2) in the pixel coordinates of the mask grid.

required

Returns:

Type Description
Tensor

Cropped masks.

Source code in ultralytics/utils/ops.py
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
def crop_mask(masks, boxes):
    """
    Crop masks to bounding box regions by zeroing every pixel outside each box.

    Boxes that extend past the mask border, including negative coordinates, are handled safely. Negative bounds
    are clamped to 0 in the small-N path, so they can never act as Python negative indices.

    Args:
        masks (torch.Tensor): Masks with shape (N, H, W).
        boxes (torch.Tensor): Bounding boxes with shape (N, 4) in xyxy format, expressed in the pixel coordinates
            of the masks.

    Returns:
        (torch.Tensor): Cropped masks with shape (N, H, W). In the small-N path the input ``masks`` tensor is
            modified in place and returned.
    """
    n, h, w = masks.shape
    if n < 50:  # faster for fewer masks (predict)
        # Clamp at 0: an unclamped bound such as x1 = -1 would turn `masks[i, :, :x1] = 0` into
        # `masks[i, :, :-1] = 0` and wipe out nearly the entire mask. Upper bounds > H/W are harmless in slicing.
        for i, (x1, y1, x2, y2) in enumerate(boxes.round().int().clamp(min=0)):
            masks[i, :y1] = 0
            masks[i, y2:] = 0
            masks[i, :, :x1] = 0
            masks[i, :, x2:] = 0
        return masks
    else:  # faster for more masks (val)
        x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1)  # x1 shape(n,1,1)
        r = torch.arange(w, device=masks.device, dtype=x1.dtype)[None, None, :]  # rows shape(1,1,w)
        c = torch.arange(h, device=masks.device, dtype=x1.dtype)[None, :, None]  # cols shape(1,h,1)
        return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2))





ultralytics.utils.ops.process_mask

process_mask(protos, masks_in, bboxes, shape, upsample: bool = False)

Apply masks to bounding boxes using mask head output.

Parameters:

Name Type Description Default
protos Tensor

Mask prototypes with shape (mask_dim, mask_h, mask_w).

required
masks_in Tensor

Mask coefficients with shape (N, mask_dim) where N is number of masks after NMS.

required
bboxes Tensor

Bounding boxes with shape (N, 4) where N is number of masks after NMS.

required
shape tuple

Input image size as (height, width).

required
upsample bool

Whether to upsample masks to original image size.

False

Returns:

Type Description
Tensor

A binary (uint8) mask tensor of shape [n, h, w], where n is the number of masks after NMS. When upsample is True, h and w are the height and width of the input image; otherwise they are the mask prototype height and width. The masks are cropped to the bounding boxes.

Source code in ultralytics/utils/ops.py
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
def process_mask(protos, masks_in, bboxes, shape, upsample: bool = False):
    """
    Build per-detection binary masks from prototype masks and mask coefficients, cropped to their boxes.

    Args:
        protos (torch.Tensor): Mask prototypes with shape (mask_dim, mask_h, mask_w).
        masks_in (torch.Tensor): Mask coefficients with shape (N, mask_dim) where N is number of masks after NMS.
        bboxes (torch.Tensor): Bounding boxes with shape (N, 4) where N is number of masks after NMS, in input-image
            pixel coordinates (they are rescaled to prototype resolution before cropping).
        shape (tuple): Input image size as (height, width).
        upsample (bool): Whether to upsample masks to original image size.

    Returns:
        (torch.Tensor): A uint8 mask tensor with values 0/1. Its shape is (N, height, width) of ``shape`` when
            ``upsample`` is True, and (N, mask_h, mask_w) otherwise.
    """
    num_ch, proto_h, proto_w = protos.shape  # CHW
    flat_protos = protos.float().view(num_ch, -1)
    masks = (masks_in @ flat_protos).view(-1, proto_h, proto_w)  # (N, mask_h, mask_w)

    # Bring boxes from image pixels down to prototype resolution: x scaled by width, y by height.
    sx = proto_w / shape[1]
    sy = proto_h / shape[0]
    box_scale = torch.tensor([[sx, sy, sx, sy]], device=bboxes.device)

    masks = crop_mask(masks, boxes=bboxes * box_scale)
    if upsample:
        masks = F.interpolate(masks[None], shape, mode="bilinear")[0]
    return masks.gt_(0.0).byte()





ultralytics.utils.ops.process_mask_native

process_mask_native(protos, masks_in, bboxes, shape)

Apply masks to bounding boxes using mask head output with native upsampling.

Parameters:

Name Type Description Default
protos Tensor

Mask prototypes with shape (mask_dim, mask_h, mask_w).

required
masks_in Tensor

Mask coefficients with shape (N, mask_dim) where N is number of masks after NMS.

required
bboxes Tensor

Bounding boxes with shape (N, 4) where N is number of masks after NMS.

required
shape tuple

Input image size as (height, width).

required

Returns:

Type Description
Tensor

Binary mask tensor with shape (N, H, W), where (H, W) is the input image size.

Source code in ultralytics/utils/ops.py
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
def process_mask_native(protos, masks_in, bboxes, shape):
    """
    Apply masks to bounding boxes using mask head output with native upsampling.

    Unlike ``process_mask``, the raw masks are first rescaled to the full input size via ``scale_masks``. They
    are then cropped with ``bboxes`` at that resolution.

    Args:
        protos (torch.Tensor): Mask prototypes with shape (mask_dim, mask_h, mask_w).
        masks_in (torch.Tensor): Mask coefficients with shape (N, mask_dim) where N is number of masks after NMS.
        bboxes (torch.Tensor): Bounding boxes with shape (N, 4) where N is number of masks after NMS, in the pixel
            coordinates of ``shape``.
        shape (tuple): Input image size as (height, width).

    Returns:
        (torch.Tensor): Binary uint8 mask tensor (values 0/1) with shape (N, H, W), where (H, W) is ``shape``.
    """
    c, mh, mw = protos.shape  # CHW
    masks = (masks_in @ protos.float().view(c, -1)).view(-1, mh, mw)  # (N, mh, mw)
    masks = scale_masks(masks[None], shape)[0]  # add/remove a batch dim -> (N, H, W)
    masks = crop_mask(masks, bboxes)  # (N, H, W)
    return masks.gt_(0.0).byte()





ultralytics.utils.ops.scale_masks

scale_masks(masks, shape, padding: bool = True)

Rescale segment masks to target shape.

Parameters:

Name Type Description Default
masks Tensor

Masks with shape (N, C, H, W).

required
shape tuple

Target height and width as (height, width).

required
padding bool

Whether masks are based on YOLO-style augmented images with padding.

True

Returns:

Type Description
Tensor

Rescaled masks.

Source code in ultralytics/utils/ops.py
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
def scale_masks(masks, shape, padding: bool = True):
    """
    Remove letterbox padding from masks and resize them to a target shape.

    Args:
        masks (torch.Tensor): Masks with shape (N, C, H, W).
        shape (tuple): Target height and width as (height, width).
        padding (bool): Whether masks are based on YOLO-style augmented images with padding. When True, the
            padding is assumed to be split evenly between opposite sides. When False, it is assumed to sit
            entirely on the bottom/right.

    Returns:
        (torch.Tensor): Rescaled masks with shape (N, C, height, width).
    """
    src_h, src_w = masks.shape[2:]
    gain = min(src_h / shape[0], src_w / shape[1])  # gain  = old / new
    pad_x = src_w - shape[1] * gain
    pad_y = src_h - shape[0] * gain
    if padding:
        # Centered letterbox: half of the padding lives on each side.
        pad_x /= 2
        pad_y /= 2
        top, left = round(pad_y - 0.1), round(pad_x - 0.1)
    else:
        top, left = 0, 0
    bottom = src_h - round(pad_y + 0.1)
    right = src_w - round(pad_x + 0.1)
    unpadded = masks[..., top:bottom, left:right]
    return F.interpolate(unpadded, shape, mode="bilinear")  # NCHW masks





ultralytics.utils.ops.scale_coords

scale_coords(
    img1_shape,
    coords,
    img0_shape,
    ratio_pad=None,
    normalize: bool = False,
    padding: bool = True,
)

Rescale segment coordinates from img1_shape to img0_shape.

Parameters:

Name Type Description Default
img1_shape tuple

Source image shape as HWC or HW (supports both).

required
coords Tensor

Coordinates to scale with shape (N, 2).

required
img0_shape tuple

Image 0 shape as HWC or HW (supports both).

required
ratio_pad tuple

Ratio and padding values as ((ratio_h, ratio_w), (pad_w, pad_h)); the padding is applied as (x, y).

None
normalize bool

Whether to normalize coordinates to range [0, 1].

False
padding bool

Whether coordinates are based on YOLO-style augmented images with padding.

True

Returns:

Type Description
Tensor

Scaled coordinates.

Source code in ultralytics/utils/ops.py
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize: bool = False, padding: bool = True):
    """
    Rescale segment coordinates from img1_shape to img0_shape.

    Args:
        img1_shape (tuple): Shape of the image the coordinates currently refer to, as HWC or HW (supports both).
        coords (torch.Tensor): Coordinates to scale with shape (N, 2) as (x, y). The padding and gain steps
            modify this tensor in place.
        img0_shape (tuple): Target image shape as HWC or HW (supports both).
        ratio_pad (tuple, optional): Ratio and padding values as ((ratio_h, ratio_w), (pad_w, pad_h)). Only the
            first ratio entry is used as the gain. The padding is applied as (x, y), i.e. ``pad[0]`` is
            subtracted from x and ``pad[1]`` from y. When None, both are computed from the two shapes.
        normalize (bool): Whether to normalize coordinates to range [0, 1].
        padding (bool): Whether coordinates are based on YOLO-style augmented images with padding.

    Returns:
        (torch.Tensor): Scaled coordinates.
    """
    img0_h, img0_w = img0_shape[:2]  # supports both HWC or HW shapes
    if ratio_pad is None:  # calculate from img0_shape
        img1_h, img1_w = img1_shape[:2]  # supports both HWC or HW shapes
        gain = min(img1_h / img0_h, img1_w / img0_w)  # gain  = old / new
        pad = (img1_w - img0_w * gain) / 2, (img1_h - img0_h * gain) / 2  # (x, y) i.e. (w, h) padding
    else:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]  # (pad_w, pad_h)

    if padding:
        coords[..., 0] -= pad[0]  # x padding
        coords[..., 1] -= pad[1]  # y padding
    coords[..., 0] /= gain
    coords[..., 1] /= gain
    coords = clip_coords(coords, img0_shape)
    if normalize:
        coords[..., 0] /= img0_w  # width
        coords[..., 1] /= img0_h  # height
    return coords





ultralytics.utils.ops.regularize_rboxes

regularize_rboxes(rboxes)

Regularize rotated bounding boxes so that their angles lie in the range [0, pi/2).

Parameters:

Name Type Description Default
rboxes Tensor

Input rotated boxes with shape (N, 5) in xywhr format.

required

Returns:

Type Description
Tensor

Regularized rotated boxes.

Source code in ultralytics/utils/ops.py
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
def regularize_rboxes(rboxes):
    """
    Normalize rotated boxes so every angle lies in [0, pi/2).

    When the angle (taken modulo pi) is at least pi/2, the width and height are swapped so the box keeps the
    same footprint after the angle is folded back into the range.

    Args:
        rboxes (torch.Tensor): Input rotated boxes with shape (N, 5) in xywhr format.

    Returns:
        (torch.Tensor): Regularized rotated boxes with the same shape as the input.
    """
    cx, cy, bw, bh, theta = rboxes.unbind(dim=-1)
    half_pi = math.pi / 2
    # Swap edge if theta >= pi/2 while not being symmetrically opposite
    needs_swap = theta % math.pi >= half_pi
    new_w = torch.where(needs_swap, bh, bw)
    new_h = torch.where(needs_swap, bw, bh)
    new_theta = theta % half_pi
    return torch.stack([cx, cy, new_w, new_h, new_theta], dim=-1)  # regularized boxes





ultralytics.utils.ops.masks2segments

masks2segments(masks, strategy: str = 'all')

Convert masks to segments using contour detection.

Parameters:

Name Type Description Default
masks Tensor

Binary masks with shape (batch_size, 160, 160).

required
strategy str

Segmentation strategy, either 'all' or 'largest'.

'all'

Returns:

Type Description
list

List of segments, one per mask, each a float32 array of (x, y) points with shape (M, 2).

Source code in ultralytics/utils/ops.py
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
def masks2segments(masks, strategy: str = "all"):
    """
    Convert masks to polygon segments using contour detection.

    Args:
        masks (torch.Tensor): Binary masks with shape (N, H, W), e.g. (batch_size, 160, 160).
        strategy (str): Segmentation strategy. 'all' merges all external contours of a mask into one segment;
            'largest' keeps only the contour with the most points.

    Returns:
        (list): One float32 array of (x, y) points with shape (M, 2) per mask. Masks without any contour yield
            an empty (0, 2) array.

    Raises:
        ValueError: If ``strategy`` is not 'all' or 'largest'.
    """
    # Validate up front: an unknown strategy used to fall through and fail later with an opaque
    # AttributeError (tuple has no `astype`), or silently return empty arrays for blank masks.
    if strategy not in {"all", "largest"}:
        raise ValueError(f"Invalid strategy '{strategy}', expected 'all' or 'largest'.")

    from ultralytics.data.converter import merge_multi_segment

    segments = []
    for mask in masks.byte().cpu().numpy():
        contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
        if not contours:
            seg = np.zeros((0, 2))  # no segments found
        elif strategy == "all":  # merge and concatenate all segments
            seg = (
                np.concatenate(merge_multi_segment([cnt.reshape(-1, 2) for cnt in contours]))
                if len(contours) > 1
                else contours[0].reshape(-1, 2)
            )
        else:  # "largest": contour with the most points
            seg = np.array(contours[np.array([len(cnt) for cnt in contours]).argmax()]).reshape(-1, 2)
        segments.append(seg.astype("float32"))
    return segments





ultralytics.utils.ops.convert_torch2numpy_batch

convert_torch2numpy_batch(batch: Tensor) -> np.ndarray

Convert a batch of FP32 torch tensors to NumPy uint8 arrays, changing from BCHW to BHWC layout.

Parameters:

Name Type Description Default
batch Tensor

Input tensor batch with shape (Batch, Channels, Height, Width) and dtype torch.float32.

required

Returns:

Type Description
ndarray

Output NumPy array batch with shape (Batch, Height, Width, Channels) and dtype uint8.

Source code in ultralytics/utils/ops.py
694
695
696
697
698
699
700
701
702
703
704
def convert_torch2numpy_batch(batch: torch.Tensor) -> np.ndarray:
    """
    Turn a BCHW float tensor batch into a BHWC uint8 NumPy array.

    Values are multiplied by 255, clamped to [0, 255] and truncated to uint8, so inputs are presumably
    normalized to [0, 1].

    Args:
        batch (torch.Tensor): Input tensor batch with shape (Batch, Channels, Height, Width) and dtype torch.float32.

    Returns:
        (np.ndarray): Output NumPy array batch with shape (Batch, Height, Width, Channels) and dtype uint8.
    """
    channels_last = batch.permute(0, 2, 3, 1).contiguous()  # BCHW -> BHWC, materialized contiguously
    pixel_values = (channels_last * 255).clamp(0, 255)
    return pixel_values.byte().cpu().numpy()





ultralytics.utils.ops.clean_str

clean_str(s)

Clean a string by replacing special characters with '_' character.

Parameters:

Name Type Description Default
s str

A string needing special characters replaced.

required

Returns:

Type Description
str

A string with special characters replaced by an underscore _.

Source code in ultralytics/utils/ops.py
707
708
709
710
711
712
713
714
715
716
717
def clean_str(s):
    """
    Replace every special character in a string with an underscore.

    Args:
        s (str): A string needing special characters replaced.

    Returns:
        (str): A copy of ``s`` where each character from the special-character class is replaced by '_'.
    """
    # NOTE(review): some characters in this class look mis-encoded (e.g. "Âż" may originally have been "¿");
    # verify against upstream before changing it, since the pattern is runtime behavior.
    special_chars = "[|@#!¥·$€%&()=?Âż^*;:,šŽ><+]"
    return re.sub(special_chars, "_", s)





ultralytics.utils.ops.empty_like

empty_like(x)

Create empty torch.Tensor or np.ndarray with same shape as input and float32 dtype.

Source code in ultralytics/utils/ops.py
720
721
722
723
724
def empty_like(x):
    """Create empty torch.Tensor or np.ndarray with same shape as input and float32 dtype."""
    if isinstance(x, torch.Tensor):
        return torch.empty_like(x, dtype=torch.float32)
    # anything that is not a torch tensor goes through NumPy
    return np.empty_like(x, dtype=np.float32)





📅 Created 1 year ago ✏ Updated 1 month ago