
Reference for ultralytics/models/nas/predict.py

Note

This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/nas/predict.py. If you spot a problem, please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!


ultralytics.models.nas.predict.NASPredictor

NASPredictor(
    cfg=DEFAULT_CFG,
    overrides: dict[str, Any] | None = None,
    _callbacks: dict[str, list[callable]] | None = None,
)

Bases: DetectionPredictor

Ultralytics YOLO NAS Predictor for object detection.

This class extends DetectionPredictor and is responsible for post-processing the raw predictions generated by YOLO NAS models. It converts the NAS outputs into the standard detection layout and then relies on the parent post-processing step for non-maximum suppression and for scaling the bounding boxes back to the original image dimensions.
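The two operations named above can be illustrated with a small NumPy sketch. This is a generic, class-agnostic greedy NMS and a simplified letterbox-undo step, not the Ultralytics implementation (which runs on torch tensors, handles per-class NMS and rounds the padding). The helper names `iou_one_to_many`, `greedy_nms` and `scale_to_original` are hypothetical; the `conf`, `iou` and `max_det` defaults mirror the predictor defaults (0.25, 0.7, 300).

```python
import numpy as np


def iou_one_to_many(box, boxes):
    """IoU between one xyxy box of shape (4,) and an (M, 4) array of xyxy boxes."""
    x1 = np.maximum(box[0], boxes[:, 0])
    y1 = np.maximum(box[1], boxes[:, 1])
    x2 = np.minimum(box[2], boxes[:, 2])
    y2 = np.minimum(box[3], boxes[:, 3])
    inter = np.clip(x2 - x1, 0, None) * np.clip(y2 - y1, 0, None)
    area = (box[2] - box[0]) * (box[3] - box[1])
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    return inter / (area + areas - inter + 1e-9)


def greedy_nms(boxes, scores, conf=0.25, iou=0.7, max_det=300):
    """Hypothetical class-agnostic NMS: returns indices of kept boxes, best score first."""
    candidates = np.flatnonzero(scores > conf)  # confidence filter first
    order = candidates[np.argsort(-scores[candidates], kind="stable")]  # highest score first
    keep = []
    while order.size > 0 and len(keep) < max_det:
        best, rest = order[0], order[1:]
        keep.append(int(best))
        # Drop every remaining box that overlaps the kept box by more than `iou`
        order = rest[iou_one_to_many(boxes[best], boxes[rest]) <= iou]
    return keep


def scale_to_original(boxes, model_shape, orig_shape):
    """Hypothetical, simplified letterbox undo: map xyxy boxes from model (h, w) to original (h, w)."""
    gain = min(model_shape[0] / orig_shape[0], model_shape[1] / orig_shape[1])
    pad_x = (model_shape[1] - orig_shape[1] * gain) / 2  # horizontal padding added by letterboxing
    pad_y = (model_shape[0] - orig_shape[0] * gain) / 2  # vertical padding added by letterboxing
    out = np.asarray(boxes, dtype=float).copy()
    out[:, [0, 2]] = (out[:, [0, 2]] - pad_x) / gain
    out[:, [1, 3]] = (out[:, [1, 3]] - pad_y) / gain
    out[:, [0, 2]] = out[:, [0, 2]].clip(0, orig_shape[1])  # clip to original width
    out[:, [1, 3]] = out[:, [1, 3]].clip(0, orig_shape[0])  # clip to original height
    return out
```

For example, two heavily overlapping boxes scored 0.9 and 0.8 collapse to the higher-scoring one, while a distant third box survives; a box spanning a 480x640 image letterboxed into 640x640 maps back after removing the 80-pixel vertical padding.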

Attributes:

args (Namespace): Namespace containing various configurations for post-processing, including confidence threshold, IoU threshold, agnostic NMS flag, maximum detections, and class filtering options.
model (Module): The YOLO NAS model used for inference.
batch (list): Batch of inputs for processing.

Examples:

>>> from ultralytics import NAS
>>> model = NAS("yolo_nas_s")
>>> predictor = model.predictor
>>> # Assume that raw_preds, img, orig_imgs are available
>>> results = predictor.postprocess(raw_preds, img, orig_imgs)

Notes

Typically, this class is not instantiated directly. It is used internally within the NAS class.

Source code in ultralytics/engine/predictor.py
def __init__(
    self,
    cfg=DEFAULT_CFG,
    overrides: dict[str, Any] | None = None,
    _callbacks: dict[str, list[callable]] | None = None,
):
    """
    Initialize the BasePredictor class.

    Args:
        cfg (str | dict): Path to a configuration file or a configuration dictionary.
        overrides (dict, optional): Configuration overrides.
        _callbacks (dict, optional): Dictionary of callback functions.
    """
    self.args = get_cfg(cfg, overrides)
    self.save_dir = get_save_dir(self.args)
    if self.args.conf is None:
        self.args.conf = 0.25  # default conf=0.25
    self.done_warmup = False
    if self.args.show:
        self.args.show = check_imshow(warn=True)

    # Usable if setup is done
    self.model = None
    self.data = self.args.data  # data_dict
    self.imgsz = None
    self.device = None
    self.dataset = None
    self.vid_writer = {}  # dict of {save_path: video_writer, ...}
    self.plotted_img = None
    self.source_type = None
    self.seen = 0
    self.windows = []
    self.batch = None
    self.results = None
    self.transforms = None
    self.callbacks = _callbacks or callbacks.get_default_callbacks()
    self.txt_path = None
    self._lock = threading.Lock()  # for automatic thread-safe inference
    callbacks.add_integration_callbacks(self)
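The constructor resolves its arguments by merging `overrides` on top of the base configuration and then falling back to `conf=0.25` when no confidence threshold was set. The sketch below illustrates only that precedence rule with a plain dictionary merge; `HYPOTHETICAL_DEFAULTS` and `merge_cfg` are illustrative stand-ins, not `DEFAULT_CFG` or `get_cfg`, and it assumes the base configuration leaves `conf` unset.

```python
from types import SimpleNamespace

# Hypothetical stand-in for a tiny subset of DEFAULT_CFG (conf assumed unset by default)
HYPOTHETICAL_DEFAULTS = {"conf": None, "iou": 0.7, "max_det": 300, "show": False}


def merge_cfg(defaults, overrides=None):
    """Return a namespace in which keys from `overrides` take priority over `defaults`."""
    merged = {**defaults, **(overrides or {})}  # later dict wins on duplicate keys
    args = SimpleNamespace(**merged)
    if args.conf is None:
        args.conf = 0.25  # mirrors the fallback in BasePredictor.__init__
    return args


args = merge_cfg(HYPOTHETICAL_DEFAULTS, {"iou": 0.5})
print(args.conf, args.iou)  # 0.25 0.5
```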

postprocess

postprocess(preds_in, img, orig_imgs)

Postprocess NAS model predictions to generate final detection results.

This method takes raw predictions from a YOLO NAS model, converts bounding box formats, and applies post-processing operations to generate the final detection results compatible with Ultralytics result visualization and analysis tools.
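The box-format conversion performed first (via `ops.xyxy2xywh`) turns corner coordinates into center-based coordinates. A minimal NumPy equivalent, written as a hypothetical helper `xyxy_to_xywh` for illustration, looks like this:

```python
import numpy as np


def xyxy_to_xywh(boxes):
    """Convert (..., 4) boxes from corners (x1, y1, x2, y2) to (x_center, y_center, width, height)."""
    boxes = np.asarray(boxes, dtype=float)
    out = np.empty_like(boxes)
    out[..., 0] = (boxes[..., 0] + boxes[..., 2]) / 2  # x center
    out[..., 1] = (boxes[..., 1] + boxes[..., 3]) / 2  # y center
    out[..., 2] = boxes[..., 2] - boxes[..., 0]  # width
    out[..., 3] = boxes[..., 3] - boxes[..., 1]  # height
    return out


print(xyxy_to_xywh([[10, 20, 50, 80]]))  # [[30. 50. 40. 60.]]
```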

Parameters:

preds_in (list): Raw predictions from the NAS model, typically containing bounding boxes and class scores. Required.
img (Tensor): Input image tensor that was fed to the model, with shape (B, C, H, W). Required.
orig_imgs (list | Tensor | ndarray): Original images before preprocessing, used for scaling coordinates back to original dimensions. Required.

Returns:

(list): List of Results objects containing the processed predictions for each image in the batch.

Examples:

>>> predictor = NAS("yolo_nas_s").predictor
>>> results = predictor.postprocess(raw_preds, img, orig_imgs)
Source code in ultralytics/models/nas/predict.py
def postprocess(self, preds_in, img, orig_imgs):
    """
    Postprocess NAS model predictions to generate final detection results.

    This method takes raw predictions from a YOLO NAS model, converts bounding box formats, and applies
    post-processing operations to generate the final detection results compatible with Ultralytics
    result visualization and analysis tools.

    Args:
        preds_in (list): Raw predictions from the NAS model, typically containing bounding boxes and class scores.
        img (torch.Tensor): Input image tensor that was fed to the model, with shape (B, C, H, W).
        orig_imgs (list | torch.Tensor | np.ndarray): Original images before preprocessing, used for scaling
            coordinates back to original dimensions.

    Returns:
        (list): List of Results objects containing the processed predictions for each image in the batch.

    Examples:
        >>> predictor = NAS("yolo_nas_s").predictor
        >>> results = predictor.postprocess(raw_preds, img, orig_imgs)
    """
    boxes = ops.xyxy2xywh(preds_in[0][0])  # Convert bounding boxes from xyxy to center-based xywh format
    preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)  # Concatenate boxes with class scores: (B, N, 4 + nc) -> (B, 4 + nc, N)
    return super().postprocess(preds, img, orig_imgs)
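The concatenate-and-permute step is easiest to follow through its shapes. The sketch below reproduces it with NumPy (`np.concatenate` and `transpose` standing in for `torch.cat` and `permute`) on dummy arrays; the sizes are hypothetical, and the assumption that `preds_in[0][0]` has shape (B, N, 4) and `preds_in[0][1]` has shape (B, N, nc) is inferred from the concatenation along the last axis. The resulting (B, 4 + nc, N) layout appears to be the one the parent DetectionPredictor expects before NMS.

```python
import numpy as np


def to_detection_layout(boxes_xywh, class_scores):
    """NumPy analogue of torch.cat((boxes, scores), -1).permute(0, 2, 1)."""
    merged = np.concatenate((boxes_xywh, class_scores), axis=-1)  # (B, N, 4 + nc)
    return merged.transpose(0, 2, 1)  # (B, 4 + nc, N)


# Hypothetical sizes for illustration: 2 images, 5 candidate boxes, 3 classes
B, N, NC = 2, 5, 3
boxes = np.arange(B * N * 4, dtype=float).reshape(B, N, 4)
scores = np.full((B, N, NC), 0.5)
out = to_detection_layout(boxes, scores)
print(out.shape)  # (2, 7, 5)
```

After the permute, each column `out[b, :, n]` holds one candidate: its four box values followed by its per-class scores.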




