Reference for ultralytics/models/sam/predict.py
Note
This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/predict.py. If you spot a problem please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!
ultralytics.models.sam.predict.Predictor
Bases: BasePredictor
Predictor class for the Segment Anything Model (SAM), extending BasePredictor.
The class provides an interface for model inference tailored to image segmentation tasks. With advanced architecture and promptable segmentation capabilities, it facilitates flexible and real-time mask generation. The class is capable of working with various types of prompts such as bounding boxes, points, and low-resolution masks.
Attributes:
| Name | Type | Description |
| --- | --- | --- |
| `cfg` | `dict` | Configuration dictionary specifying model and task-related parameters. |
| `overrides` | `dict` | Dictionary containing values that override the default configuration. |
| `_callbacks` | `dict` | Dictionary of user-defined callback functions to augment behavior. |
| `args` | `namespace` | Namespace holding command-line arguments or other operational variables. |
| `im` | `Tensor` | Preprocessed input image tensor. |
| `features` | `Tensor` | Extracted image features used for inference. |
| `prompts` | `dict` | Collection of various prompt types, such as bounding boxes and points. |
| `segment_all` | `bool` | Flag controlling whether to segment all objects in the image or only specified ones. |
The method sets up the Predictor object and applies any configuration overrides or callbacks provided. It initializes task-specific settings for SAM, such as retina_masks being set to True for optimal results.
Parameters:
| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `cfg` | `dict` | Configuration dictionary. | `DEFAULT_CFG` |
| `overrides` | `dict` | Dictionary of values to override the default configuration. | `None` |
| `_callbacks` | `dict` | Dictionary of callback functions to customize behavior. | `None` |
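The override mechanism can be pictured as a simple dictionary merge. This is a minimal sketch, not the real implementation; `DEFAULT_CFG` below is a hypothetical stand-in for the actual Ultralytics default configuration object.

```python
# Minimal sketch of how overrides layer on top of defaults.
# DEFAULT_CFG here is a hypothetical stand-in for the real
# Ultralytics default configuration.
DEFAULT_CFG = {"task": "segment", "mode": "predict", "imgsz": 1024, "retina_masks": False}


def merge_cfg(cfg, overrides=None):
    """Return a new config dict with overrides applied on top of cfg."""
    merged = dict(cfg)
    merged.update(overrides or {})
    return merged


# SAM sets retina_masks=True for optimal results, as noted above.
args = merge_cfg(DEFAULT_CFG, {"model": "sam_b.pt", "retina_masks": True})
```

Note that the merge leaves the defaults untouched, so a single `DEFAULT_CFG` can safely back multiple Predictor instances.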
Source code in ultralytics/models/sam/predict.py
generate
generate(im, crop_n_layers=0, crop_overlap_ratio=512 / 1500, crop_downscale_factor=1, point_grids=None, points_stride=32, points_batch_size=64, conf_thres=0.88, stability_score_thresh=0.95, stability_score_offset=0.95, crop_nms_thresh=0.7)
Perform image segmentation using the Segment Anything Model (SAM).
This function segments an entire image into constituent parts by leveraging SAM's advanced architecture and real-time performance capabilities. It can optionally work on image crops for finer segmentation.
Parameters:
| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `im` | `Tensor` | Input tensor representing the preprocessed image, with shape (N, C, H, W). | required |
| `crop_n_layers` | `int` | Number of layers for additional mask predictions on image crops. Each layer produces 2**i_layer image crops. | `0` |
| `crop_overlap_ratio` | `float` | Overlap between crops, scaled down in subsequent layers. | `512 / 1500` |
| `crop_downscale_factor` | `int` | Scaling factor for the number of points sampled per side in each layer. | `1` |
| `point_grids` | `list[ndarray]` | Custom grids for point sampling, normalized to [0, 1]. Used in the nth crop layer. | `None` |
| `points_stride` | `int` | Number of points to sample along each side of the image. Mutually exclusive with `point_grids`. | `32` |
| `points_batch_size` | `int` | Batch size for the number of points processed simultaneously. | `64` |
| `conf_thres` | `float` | Confidence threshold in [0, 1] for filtering based on the model's mask quality prediction. | `0.88` |
| `stability_score_thresh` | `float` | Stability threshold in [0, 1] for mask filtering based on mask stability. | `0.95` |
| `stability_score_offset` | `float` | Offset value for calculating the stability score. | `0.95` |
| `crop_nms_thresh` | `float` | IoU cutoff for NMS to remove duplicate masks between crops. | `0.7` |
Returns:
| Type | Description |
| --- | --- |
| `tuple` | A tuple containing segmented masks, confidence scores, and bounding boxes. |
Source code in ultralytics/models/sam/predict.py
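The `point_grids` format above (points normalized to [0, 1]) can be produced with a small helper. This sketch assumes an evenly spaced, cell-centered grid, matching what the default `points_stride` of 32 implies.

```python
import numpy as np


def build_point_grid(n_per_side):
    """Uniform (n*n, 2) grid of cell-centered points in [0, 1] x [0, 1]."""
    offset = 1 / (2 * n_per_side)
    coords = np.linspace(offset, 1 - offset, n_per_side)
    xs, ys = np.meshgrid(coords, coords)
    return np.stack([xs.ravel(), ys.ravel()], axis=-1)


grid = build_point_grid(32)  # matches the default points_stride
```

A grid like this could be passed via `point_grids=[grid]` when custom sampling is needed instead of the built-in stride.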
inference
inference(im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False, *args, **kwargs)
Perform image segmentation inference based on the given input cues, using the currently loaded image. This method leverages SAM's (Segment Anything Model) architecture consisting of image encoder, prompt encoder, and mask decoder for real-time and promptable segmentation tasks.
Parameters:
| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `im` | `Tensor` | The preprocessed input image in tensor format, with shape (N, C, H, W). | required |
| `bboxes` | `ndarray \| List` | Bounding boxes with shape (N, 4), in XYXY format. | `None` |
| `points` | `ndarray \| List` | Points indicating object locations, with shape (N, 2), in pixels. | `None` |
| `labels` | `ndarray \| List` | Labels for point prompts, with shape (N,). 1 = foreground, 0 = background. | `None` |
| `masks` | `ndarray` | Low-resolution masks from previous predictions, with shape (N, H, W). For SAM, H = W = 256. | `None` |
| `multimask_output` | `bool` | Flag to return multiple masks. Helpful for ambiguous prompts. | `False` |
Returns:
| Type | Description |
| --- | --- |
| `tuple` | A tuple of three elements: the output masks with shape (C, H, W), where C is the number of generated masks; an array of length C with the model's predicted quality score for each mask; and low-resolution logits with shape (C, H, W) for subsequent inference, where H = W = 256. |
Source code in ultralytics/models/sam/predict.py
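The prompt shapes in the table can be illustrated with plain NumPy arrays. The final `predictor.inference(...)` call is shown commented out because it requires a loaded model and a set image; the array construction itself is the point.

```python
import numpy as np

# Prompt formats expected by inference(), per the shapes in the table above.
bboxes = np.array([[100, 150, 300, 400]], dtype=np.float32)  # (N, 4), XYXY
points = np.array([[200, 275]], dtype=np.float32)            # (N, 2), pixel coords
labels = np.array([1], dtype=np.int32)                       # (N,), 1 = foreground

# With an initialized Predictor that has an image set:
# masks, scores, logits = predictor.inference(im, bboxes=bboxes, points=points, labels=labels)
```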
postprocess
Post-processes SAM's inference outputs to generate object detection masks and bounding boxes.
The method scales masks and boxes to the original image size and applies a threshold to the mask predictions. The SAM model uses advanced architecture and promptable segmentation tasks to achieve real-time performance.
Parameters:
| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `preds` | `tuple` | The output from SAM model inference, containing masks, scores, and optional bounding boxes. | required |
| `img` | `Tensor` | The processed input image tensor. | required |
| `orig_imgs` | `list \| Tensor` | The original, unprocessed images. | required |
Returns:
| Type | Description |
| --- | --- |
| `list` | List of `Results` objects containing detection masks, bounding boxes, and other metadata. |
Source code in ultralytics/models/sam/predict.py
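The thresholding step can be sketched with NumPy: mask logits are binarized and a bounding box is read off the foreground pixels. This is an illustrative simplification, not the actual implementation, which also rescales masks and boxes to the original image size.

```python
import numpy as np


def threshold_mask_to_box(mask_logits, thresh=0.0):
    """Binarize mask logits and derive an XYXY box from foreground pixels."""
    mask = mask_logits > thresh
    ys, xs = np.nonzero(mask)
    if xs.size == 0:
        return mask, (0, 0, 0, 0)
    return mask, (int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max()))


logits = np.full((8, 8), -5.0)
logits[2:5, 3:7] = 5.0  # one 3x4 foreground blob
mask, box = threshold_mask_to_box(logits)
```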
pre_transform
Perform initial transformations on the input image for preprocessing.
The method applies transformations such as resizing to prepare the image for further preprocessing. Currently, batched inference is not supported; hence the list length should be 1.
Parameters:
| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `im` | `List[ndarray]` | List containing images in HWC numpy array format. | required |
Returns:
| Type | Description |
| --- | --- |
| `List[ndarray]` | List of transformed images. |
Source code in ultralytics/models/sam/predict.py
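The resize described above typically scales the longest image side to the model's input size while preserving aspect ratio. A sketch of that shape computation, assuming SAM's usual 1024-pixel long side:

```python
def resized_shape(h, w, long_side=1024):
    """Target (h, w) after scaling the longest side to long_side."""
    scale = long_side / max(h, w)
    return round(h * scale), round(w * scale)


# A 480x640 image scales up by 1024/640 = 1.6 to 768x1024.
shape = resized_shape(480, 640)
```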
preprocess
Preprocess the input image for model inference.
The method prepares the input image by applying transformations and normalization. It supports both torch.Tensor and list of np.ndarray as input formats.
Parameters:
| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `im` | `Tensor \| List[ndarray]` | BCHW tensor format or list of HWC numpy arrays. | required |
Returns:
| Type | Description |
| --- | --- |
| `Tensor` | The preprocessed image tensor. |
Source code in ultralytics/models/sam/predict.py
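The normalization can be sketched in NumPy: HWC images are stacked, transposed to BCHW, and standardized with per-channel pixel statistics. The mean/std values below come from the original SAM release and are an assumption here; the real method uses torch tensors on the model's device.

```python
import numpy as np

# Per-channel pixel statistics from the original SAM release (assumed).
MEAN = np.array([123.675, 116.28, 103.53]).reshape(1, 3, 1, 1)
STD = np.array([58.395, 57.12, 57.375]).reshape(1, 3, 1, 1)


def preprocess_sketch(images):
    """List of HWC uint8 arrays -> normalized BCHW float array."""
    batch = np.stack(images).astype(np.float32)  # (B, H, W, C)
    batch = batch.transpose(0, 3, 1, 2)          # (B, C, H, W)
    return (batch - MEAN) / STD


out = preprocess_sketch([np.zeros((4, 4, 3), dtype=np.uint8)])
```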
prompt_inference
Internal function for image segmentation inference based on cues like bounding boxes, points, and masks. Leverages SAM's specialized architecture for prompt-based, real-time segmentation.
Parameters:
| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `im` | `Tensor` | The preprocessed input image in tensor format, with shape (N, C, H, W). | required |
| `bboxes` | `ndarray \| List` | Bounding boxes with shape (N, 4), in XYXY format. | `None` |
| `points` | `ndarray \| List` | Points indicating object locations, with shape (N, 2), in pixels. | `None` |
| `labels` | `ndarray \| List` | Labels for point prompts, with shape (N,). 1 = foreground, 0 = background. | `None` |
| `masks` | `ndarray` | Low-resolution masks from previous predictions, with shape (N, H, W). For SAM, H = W = 256. | `None` |
| `multimask_output` | `bool` | Flag to return multiple masks. Helpful for ambiguous prompts. | `False` |
Returns:
| Type | Description |
| --- | --- |
| `tuple` | A tuple of three elements: the output masks with shape (C, H, W), where C is the number of generated masks; an array of length C with the model's predicted quality score for each mask; and low-resolution logits with shape (C, H, W) for subsequent inference, where H = W = 256. |
Source code in ultralytics/models/sam/predict.py
remove_small_regions
staticmethod
Perform post-processing on segmentation masks generated by the Segment Anything Model (SAM). Specifically, this function removes small disconnected regions and holes from the input masks, and then performs Non-Maximum Suppression (NMS) to eliminate any newly created duplicate boxes.
Parameters:
| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `masks` | `Tensor` | A tensor containing the masks to be processed, with shape (N, H, W), where N is the number of masks, H is height, and W is width. | required |
| `min_area` | `int` | The minimum area below which disconnected regions and holes are removed. | `0` |
| `nms_thresh` | `float` | The IoU threshold for the NMS algorithm. | `0.7` |
Returns:
| Type | Description |
| --- | --- |
| `tuple(Tensor, List[int])` | The processed masks with small regions removed, and the indices of the masks kept after NMS. |
Source code in ultralytics/models/sam/predict.py
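The "remove small disconnected regions" step can be illustrated with a pure-Python connected-components pass. The real implementation relies on OpenCV; this 4-connectivity BFS is an assumption for illustration only.

```python
import numpy as np
from collections import deque


def remove_small_islands(mask, min_area):
    """Zero out 4-connected foreground regions with area < min_area."""
    out = mask.copy()
    seen = np.zeros(mask.shape, dtype=bool)
    h, w = mask.shape
    for sy in range(h):
        for sx in range(w):
            if out[sy, sx] and not seen[sy, sx]:
                # Flood-fill one foreground region, collecting its pixels.
                queue, region = deque([(sy, sx)]), [(sy, sx)]
                seen[sy, sx] = True
                while queue:
                    y, x = queue.popleft()
                    for ny, nx in ((y - 1, x), (y + 1, x), (y, x - 1), (y, x + 1)):
                        if 0 <= ny < h and 0 <= nx < w and out[ny, nx] and not seen[ny, nx]:
                            seen[ny, nx] = True
                            queue.append((ny, nx))
                            region.append((ny, nx))
                if len(region) < min_area:
                    for y, x in region:
                        out[y, x] = 0
    return out


mask = np.zeros((6, 6), dtype=np.uint8)
mask[1:3, 1:3] = 1  # 4-pixel region, kept
mask[4, 4] = 1      # 1-pixel island, removed with min_area=2
cleaned = remove_small_islands(mask, min_area=2)
```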
reset_image
set_image
Preprocesses and sets a single image for inference.
This function sets up the model if not already initialized, configures the data source to the specified image, and preprocesses the image for feature extraction. Only one image can be set at a time.
Parameters:
| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `image` | `str \| ndarray` | Image file path as a string, or a np.ndarray image read by cv2. | required |
Raises:
| Type | Description |
| --- | --- |
| `AssertionError` | If more than one image is set. |
Source code in ultralytics/models/sam/predict.py
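The one-image-at-a-time contract behind the `AssertionError` above can be sketched as a simple guard. This is a hypothetical helper, not the real code:

```python
def take_single_image(dataset):
    """Return the sole image from a source, enforcing set_image's contract."""
    assert len(dataset) == 1, "set_image only supports setting one image!"
    return dataset[0]


image = take_single_image(["frame_0.jpg"])
```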
set_prompts
setup_model
Initializes the Segment Anything Model (SAM) for inference.
This method sets up the SAM model by allocating it to the appropriate device and initializing the necessary parameters for image normalization and other Ultralytics compatibility settings.
Parameters:
| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `Module` | A pre-trained SAM model. If None, a model is built based on the configuration. | required |
| `verbose` | `bool` | If True, prints selected device information. | `True` |
Attributes:
| Name | Type | Description |
| --- | --- | --- |
| `model` | `Module` | The SAM model allocated to the chosen device for inference. |
| `device` | `device` | The device to which the model and tensors are allocated. |
| `mean` | `Tensor` | The mean values for image normalization. |
| `std` | `Tensor` | The standard deviation values for image normalization. |
Source code in ultralytics/models/sam/predict.py
setup_source
Sets up the data source for inference.
This method configures the data source from which images will be fetched for inference. The source could be a directory, a video file, or other types of image data sources.
Parameters:
| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `source` | `str \| Path` | The path to the image data source for inference. | required |