Bases: BasePredictor
Ultralytics YOLO NAS Predictor for object detection.
This class extends the `BasePredictor` from the Ultralytics engine and is responsible for post-processing the
raw predictions generated by YOLO NAS models. It applies operations such as non-maximum suppression and
scales the bounding boxes to fit the original image dimensions.
Attributes:

| Name | Type | Description |
|------|------|-------------|
| `args` | `Namespace` | Namespace containing various configurations for post-processing, including the confidence threshold, IoU threshold, agnostic NMS flag, maximum detections, and class-filtering options. |
| `model` | `Module` | The YOLO NAS model used for inference. |
| `batch` | `List` | Batch of inputs for processing. |
Examples:
>>> from ultralytics import NAS
>>> model = NAS("yolo_nas_s")
>>> predictor = model.predictor
>>> # Assume that raw_preds, img, orig_imgs are available
>>> results = predictor.postprocess(raw_preds, img, orig_imgs)
Notes:
    Typically, this class is not instantiated directly; it is used internally within the `NAS` class.
Source code in ultralytics/engine/predictor.py
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
    """
    Build the predictor's configuration and reset all runtime state.

    The configuration comes from ``get_cfg(cfg, overrides)``. If no confidence threshold
    is configured, it is set to 0.25. If ``show`` is requested, it is replaced by the
    result of ``check_imshow(warn=True)``. Model, data and per-run attributes start out
    empty; the original notes that they are usable only once setup is done.

    Args:
        cfg (str | dict): Path to a configuration file or a configuration dictionary.
        overrides (dict | None): Configuration overrides.
        _callbacks (dict | None): Callback functions to use; when falsy, the default
            callbacks from ``callbacks.get_default_callbacks()`` are used.
    """
    args = get_cfg(cfg, overrides)
    self.args = args
    self.save_dir = get_save_dir(args)
    if args.conf is None:
        args.conf = 0.25  # default confidence threshold
    self.done_warmup = False
    if args.show:
        # Only keep `show` enabled if check_imshow reports that display is possible
        args.show = check_imshow(warn=True)

    # Placeholders: usable only once setup has been done
    self.model = None
    self.data = args.data  # data_dict
    self.imgsz = self.device = self.dataset = None
    self.vid_writer = {}  # maps save_path -> video_writer
    self.plotted_img = self.source_type = None
    self.seen = 0
    self.windows = []
    self.batch = self.results = self.transforms = None
    self.txt_path = None

    self.callbacks = _callbacks if _callbacks else callbacks.get_default_callbacks()
    self._lock = threading.Lock()  # guards inference for automatic thread safety
    # Registered last so integrations see a fully initialized predictor
    callbacks.add_integration_callbacks(self)
|
postprocess
postprocess(preds_in, img, orig_imgs)
Postprocess predictions and return a list of Results objects.
Source code in ultralytics/models/nas/predict.py
def postprocess(self, preds_in, img, orig_imgs):
    """
    Postprocess raw YOLO NAS predictions and return a list of Results objects.

    The predicted boxes are converted from xyxy to xywh, concatenated with the class
    scores, filtered with non-maximum suppression using the thresholds in ``self.args``,
    and rescaled from the model input size to each original image's size.

    Args:
        preds_in: Raw model output. ``preds_in[0][0]`` holds the xyxy boxes and
            ``preds_in[0][1]`` the class scores (layout inferred from the indexing; confirm).
        img (torch.Tensor): Preprocessed input batch. ``img.shape[2:]`` is used as the
            model input size when rescaling boxes.
        orig_imgs (list | torch.Tensor): Original images. A non-list input is converted
            with ``ops.convert_torch2numpy_batch``.

    Returns:
        (list[Results]): One Results object for each (detection set, original image, path)
            triple. Paths come from ``self.batch[0]``; iteration stops at the shortest sequence.
    """
    raw_boxes, raw_scores = preds_in[0][0], preds_in[0][1]
    # Join boxes and scores on the last axis, then swap the last two axes, presumably
    # into the layout ops.non_max_suppression expects (confirm).
    merged = torch.cat((ops.xyxy2xywh(raw_boxes), raw_scores), -1).permute(0, 2, 1)
    detections = ops.non_max_suppression(
        merged,
        self.args.conf,
        self.args.iou,
        agnostic=self.args.agnostic_nms,
        max_det=self.args.max_det,
        classes=self.args.classes,
    )

    if not isinstance(orig_imgs, list):  # a batched torch.Tensor rather than a list of images
        orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)

    output = []
    for det, source_img, source_path in zip(detections, orig_imgs, self.batch[0]):
        # Map boxes from model-input coordinates back onto the original image
        det[:, :4] = ops.scale_boxes(img.shape[2:], det[:, :4], source_img.shape)
        output.append(Results(source_img, path=source_path, names=self.model.names, boxes=det))
    return output
|