
Reference for ultralytics/models/nas/predict.py

Note

This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/nas/predict.py. If you spot a problem, please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!


ultralytics.models.nas.predict.NASPredictor

NASPredictor(cfg=DEFAULT_CFG, overrides=None, _callbacks=None)

Bases: DetectionPredictor

Ultralytics YOLO NAS Predictor for object detection.

This class extends DetectionPredictor and is responsible for post-processing the raw predictions generated by YOLO NAS models. It applies operations such as non-maximum suppression (NMS) and scales the bounding boxes to fit the original image dimensions.
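The box-scaling step can be illustrated in plain Python. This is a minimal sketch, assuming the model input was letterboxed (resized with the aspect ratio preserved, then padded equally on both sides); `scale_box_to_original` is a hypothetical helper written for this page, not the Ultralytics implementation.

```python
def scale_box_to_original(box, model_shape, orig_shape):
    """Map an (x1, y1, x2, y2) box from letterboxed model-input coordinates
    back to original-image coordinates.

    Shapes are (height, width). Symmetric letterbox padding is an assumption.
    """
    mh, mw = model_shape
    oh, ow = orig_shape
    gain = min(mh / oh, mw / ow)  # resize factor applied to the original image
    pad_x = (mw - ow * gain) / 2  # horizontal padding added on each side
    pad_y = (mh - oh * gain) / 2  # vertical padding added on each side
    x1, y1, x2, y2 = box
    # Undo the padding first, then undo the resize
    x1, x2 = (x1 - pad_x) / gain, (x2 - pad_x) / gain
    y1, y2 = (y1 - pad_y) / gain, (y2 - pad_y) / gain
    # Clip to the original image bounds
    x1, x2 = (min(max(v, 0.0), ow) for v in (x1, x2))
    y1, y2 = (min(max(v, 0.0), oh) for v in (y1, y2))
    return x1, y1, x2, y2
```

For a 480x640 (height x width) image letterboxed into 640x640, the gain is 1.0 and 80 rows of padding sit above the image, so a box at y=180 in model space maps to y=100 in the original image.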

Attributes:

args (Namespace): Namespace containing various configurations for post-processing, including the confidence threshold, IoU threshold, agnostic NMS flag, maximum detections, and class filtering options.
model (torch.nn.Module): The YOLO NAS model used for inference.
batch (list): Batch of inputs for processing.
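The post-processing options carried in `args` can be illustrated with a small pure-Python filter. This is a sketch, assuming detections arrive as `(box, score, class)` tuples with xyxy boxes; `iou` and `filter_detections` are hypothetical helpers, and the greedy per-class NMS here is a simplified stand-in for the tensor-based NMS Ultralytics actually runs. The field names `conf`, `iou`, `agnostic_nms`, `max_det`, and `classes` mirror the Ultralytics configuration keys.

```python
from types import SimpleNamespace


def iou(a, b):
    """Intersection-over-union of two (x1, y1, x2, y2) boxes."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def filter_detections(dets, args):
    """Apply conf threshold, class filter, greedy NMS and max_det to (box, score, cls) tuples."""
    # 1. Drop low-confidence detections and, if requested, classes not in args.classes
    kept = [d for d in dets if d[1] > args.conf and (args.classes is None or d[2] in args.classes)]
    # 2. Greedy NMS: visit candidates from highest to lowest score
    kept.sort(key=lambda d: d[1], reverse=True)
    selected = []
    for box, score, cls in kept:
        # Class-aware by default; agnostic_nms lets boxes of any class suppress each other
        suppressed = any(
            (args.agnostic_nms or cls == s_cls) and iou(box, s_box) > args.iou
            for s_box, _, s_cls in selected
        )
        if not suppressed:
            selected.append((box, score, cls))
        if len(selected) == args.max_det:  # 3. Cap the number of detections
            break
    return selected
```

With `agnostic_nms=False`, two heavily overlapping boxes of different classes both survive; switching it to `True` keeps only the higher-scoring one.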

Examples:

>>> from ultralytics import NAS
>>> model = NAS("yolo_nas_s")
>>> predictor = model.predictor
>>> # Assume that raw_preds, img, orig_imgs are available
>>> results = predictor.postprocess(raw_preds, img, orig_imgs)

Notes:

Typically, this class is not instantiated directly. It is used internally within the NAS class.

Source code in ultralytics/engine/predictor.py
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
    """
    Initialize the BasePredictor class.

    Args:
        cfg (str | dict): Path to a configuration file or a configuration dictionary.
        overrides (dict | None): Configuration overrides.
        _callbacks (dict | None): Dictionary of callback functions.
    """
    self.args = get_cfg(cfg, overrides)
    self.save_dir = get_save_dir(self.args)
    if self.args.conf is None:
        self.args.conf = 0.25  # default conf=0.25
    self.done_warmup = False
    if self.args.show:
        self.args.show = check_imshow(warn=True)

    # Usable if setup is done
    self.model = None
    self.data = self.args.data  # data_dict
    self.imgsz = None
    self.device = None
    self.dataset = None
    self.vid_writer = {}  # dict of {save_path: video_writer, ...}
    self.plotted_img = None
    self.source_type = None
    self.seen = 0
    self.windows = []
    self.batch = None
    self.results = None
    self.transforms = None
    self.callbacks = _callbacks or callbacks.get_default_callbacks()
    self.txt_path = None
    self._lock = threading.Lock()  # for automatic thread-safe inference
    callbacks.add_integration_callbacks(self)
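Two patterns in this `__init__` are worth isolating: the fallback to `conf=0.25` when no confidence threshold is configured, and the `threading.Lock` created "for automatic thread-safe inference". The `TinyPredictor` class below is hypothetical and not part of Ultralytics; it is a minimal sketch of just these two patterns, with a `SimpleNamespace` standing in for the configuration object.

```python
import threading
from types import SimpleNamespace


class TinyPredictor:
    """Hypothetical stand-in illustrating two BasePredictor.__init__ patterns."""

    def __init__(self, args):
        self.args = args
        if self.args.conf is None:
            self.args.conf = 0.25  # same fallback value BasePredictor uses
        self.seen = 0  # shared mutable state, like BasePredictor.seen
        self._lock = threading.Lock()

    def __call__(self, batch):
        # Only one thread at a time may run this block and update shared state
        with self._lock:
            self.seen += len(batch)
            return self.seen
```

A lock like this serializes concurrent calls so their updates to shared state do not interleave; it does not make them run in parallel.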

postprocess

postprocess(preds_in, img, orig_imgs)

Postprocess predictions and return a list of Results objects.

Source code in ultralytics/models/nas/predict.py
def postprocess(self, preds_in, img, orig_imgs):
    """Postprocess predictions and return a list of Results objects."""
    # Convert boxes from xyxy to xywh format
    boxes = ops.xyxy2xywh(preds_in[0][0])
    # Concatenate boxes with class scores, then permute (batch, boxes, 4 + nc) -> (batch, 4 + nc, boxes)
    preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)
    return super().postprocess(preds, img, orig_imgs)
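The reshaping in `postprocess` can be followed without PyTorch. In the method above, `preds_in[0][0]` holds xyxy boxes and `preds_in[0][1]` holds per-class scores; they are joined along the last axis and transposed, presumably so the result matches the channels-first `(batch, 4 + nc, num_boxes)` layout the parent DetectionPredictor expects. The sketch below mimics that on nested lists; `xyxy_to_xywh` and `to_channels_first` are hypothetical helpers, not Ultralytics functions.

```python
def xyxy_to_xywh(box):
    """Convert (x1, y1, x2, y2) to (x_center, y_center, width, height)."""
    x1, y1, x2, y2 = box
    return ((x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1)


def to_channels_first(boxes_xyxy, class_scores):
    """Mimic the NAS postprocess reshaping on nested lists.

    boxes_xyxy:   [B][N][4]   boxes in xyxy format
    class_scores: [B][N][nc]  per-class scores
    returns:      [B][4 + nc][N], like torch.cat(..., -1).permute(0, 2, 1)
    """
    out = []
    for img_boxes, img_scores in zip(boxes_xyxy, class_scores):
        # Concatenate along the last axis: one row of 4 + nc values per candidate box
        rows = [list(xyxy_to_xywh(b)) + list(s) for b, s in zip(img_boxes, img_scores)]
        # Transpose [N][4 + nc] -> [4 + nc][N]
        out.append([list(col) for col in zip(*rows)])
    return out
```

For one image with two candidate boxes and two classes, the output has six rows (four box coordinates plus two class scores), each holding two values, one per candidate.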


