
Reference for ultralytics/models/rtdetr/predict.py

This page is sourced from https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/rtdetr/predict.py.


class ultralytics.models.rtdetr.predict.RTDETRPredictor

RTDETRPredictor()

Bases: BasePredictor

Predictor for RT-DETR (Real-Time Detection Transformer) models, extending the BasePredictor class.

RT-DETR is a transformer-based detector that provides real-time object detection while maintaining high accuracy. Its key features include an efficient hybrid encoder and IoU-aware query selection. This predictor handles the image preprocessing before, and the postprocessing after, the model's forward pass.

Attributes

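| Name | Type | Description |
| --- | --- | --- |
| imgsz | int | Image size for inference (must be square and scale-filled). |
| args | namespace | Predictor configuration built from the defaults plus any argument overrides (accessed as attributes, e.g. args.conf). |
| model | torch.nn.Module | The loaded RT-DETR model. |
| batch | list | Current batch from the data source; batch[0] holds the image paths used by postprocess. |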

Methods

| Name | Description |
| --- | --- |
| postprocess | Postprocess the raw predictions from the model to generate bounding boxes and confidence scores. |
| pre_transform | Pre-transform input images before feeding them into the model for inference. |

Examples

```python
>>> from ultralytics.utils import ASSETS
>>> from ultralytics.models.rtdetr import RTDETRPredictor
>>> args = dict(model="rtdetr-l.pt", source=ASSETS)
>>> predictor = RTDETRPredictor(overrides=args)
>>> predictor.predict_cli()
```

Source code in ultralytics/models/rtdetr/predict.py

```python
class RTDETRPredictor(BasePredictor):
```


method ultralytics.models.rtdetr.predict.RTDETRPredictor.postprocess

def postprocess(self, preds, img, orig_imgs)

Postprocess the raw predictions from the model to generate bounding boxes and confidence scores.

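For each image, the method keeps the queries whose highest class score exceeds self.args.conf and, if self.args.classes is set, whose predicted class is in that list. It then sorts the remaining detections by confidence, keeps at most self.args.max_det of them, and converts them to Results objects with bounding boxes scaled to the original image size.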
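A minimal NumPy sketch of the per-image filtering described above, written as a stand-in for the torch code. The helper name filter_detections and its default argument values are illustrative assumptions, not Ultralytics API: it takes each query's best class score, applies the strict conf threshold and the optional class list, sorts by confidence, and keeps at most max_det rows.

```python
import numpy as np


def filter_detections(bbox, score, conf=0.25, classes=None, max_det=300):
    """Filter one image's queries (hypothetical NumPy helper mirroring the torch logic).

    bbox: (Q, 4) boxes; score: (Q, nc) class scores.
    Returns a (K, 6) array of [box(4), confidence, class] rows, highest confidence first.
    """
    max_score = score.max(-1, keepdims=True)  # best class score per query, (Q, 1)
    cls = score.argmax(-1)[:, None].astype(score.dtype)  # index of that class, (Q, 1)
    idx = max_score[:, 0] > conf  # strict "greater than" threshold
    if classes is not None:
        # (Q, 1) == (C,) broadcasts to (Q, C); keep queries whose class is in the list
        idx &= (cls == np.asarray(classes)).any(1)
    pred = np.concatenate([bbox, max_score, cls], axis=-1)[idx]
    order = np.argsort(-pred[:, 4], kind="stable")  # descending confidence
    return pred[order][:max_det]


# Four queries, two classes (toy values).
score = np.array([[0.9, 0.1], [0.2, 0.3], [0.1, 0.6], [0.4, 0.5]], dtype=np.float32)
bbox = np.arange(16, dtype=np.float32).reshape(4, 4)
out = filter_detections(bbox, score, conf=0.25)
print(out[:, 4:])  # confidence and class columns, sorted by confidence
```

The class filter is combined with the confidence mask before indexing, so sorting and truncation only ever see rows that passed both tests.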

Args

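| Name | Type | Description | Default |
| --- | --- | --- | --- |
| preds | list \| tuple \| torch.Tensor | List of [predictions, extra] from the model, where predictions has shape (N, num_queries, 4 + num_classes): four box values followed by the class scores. A bare predictions tensor, as returned by exported models, is also accepted. | required |
| img | torch.Tensor | Processed input images with shape (N, 3, H, W). | required |
| orig_imgs | list \| torch.Tensor | Original, unprocessed images. | required |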
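The handling of preds can be sketched in NumPy as follows. split_predictions is a hypothetical helper, and the shapes (300 queries, 80 classes) are illustrative: it wraps a bare array into [predictions, None] and separates the first four values of the last axis (box coordinates) from the remaining nd - 4 class scores, like the torch split((4, nd - 4), dim=-1) call.

```python
import numpy as np


def split_predictions(preds):
    """Split raw RT-DETR-style output into boxes and class scores (hypothetical NumPy stand-in for torch)."""
    # Exported models may return a bare array instead of a [predictions, extra] pair.
    if not isinstance(preds, (list, tuple)):
        preds = [preds, None]
    # The last axis holds 4 box values followed by one score per class (nd - 4 scores).
    bboxes, scores = np.split(preds[0], [4], axis=-1)
    return bboxes, scores


# Illustrative shapes: 2 images, 300 queries, 80 classes.
raw = np.random.rand(2, 300, 4 + 80).astype(np.float32)
bboxes, scores = split_predictions(raw)  # bare-array input, as from an exported model
print(bboxes.shape, scores.shape)  # (2, 300, 4) (2, 300, 80)
```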

Returns

| Type | Description |
| --- | --- |
| list[Results] | A list of Results objects containing the post-processed bounding boxes, confidence scores, and class labels. |
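Each Results object receives rows of [x1, y1, x2, y2, confidence, class]. The boxes start as normalized center-format (cx, cy, w, h) values, as the source's call to ops.xywh2xyxy followed by multiplication with the original width and height implies. A NumPy sketch of that conversion, using a hypothetical helper named to_pixel_xyxy:

```python
import numpy as np


def to_pixel_xyxy(bbox_xywh, orig_shape):
    """Convert normalized (cx, cy, w, h) boxes to pixel (x1, y1, x2, y2) for an image of shape (H, W, ...).

    Hypothetical NumPy helper combining the xywh -> xyxy conversion and the rescaling step.
    """
    cx, cy, w, h = np.moveaxis(bbox_xywh, -1, 0)
    xyxy = np.stack([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], axis=-1)
    oh, ow = orig_shape[:2]
    xyxy[..., [0, 2]] *= ow  # x coordinates scale with the original width
    xyxy[..., [1, 3]] *= oh  # y coordinates scale with the original height
    return xyxy


# A box centred in a 640x480 (W x H) image, covering half its width and half its height.
box = np.array([[0.5, 0.5, 0.5, 0.5]])
print(to_pixel_xyxy(box, (480, 640, 3)))  # x1=160, y1=120, x2=480, y2=360
```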
Source code in ultralytics/models/rtdetr/predict.py

```python
def postprocess(self, preds, img, orig_imgs):
    """Postprocess the raw predictions from the model to generate bounding boxes and confidence scores.

    The method filters detections based on confidence and class if specified in `self.args`. It converts model
    predictions to Results objects containing properly scaled bounding boxes.

    Args:
        preds (list | tuple): List of [predictions, extra] from the model, where predictions contain bounding boxes
            and scores.
        img (torch.Tensor): Processed input images with shape (N, 3, H, W).
        orig_imgs (list | torch.Tensor): Original, unprocessed images.

    Returns:
        results (list[Results]): A list of Results objects containing the post-processed bounding boxes, confidence
            scores, and class labels.
    """
    if not isinstance(preds, (list, tuple)):  # PyTorch inference returns a list; exported models return a bare Tensor
        preds = [preds, None]

    nd = preds[0].shape[-1]
    bboxes, scores = preds[0].split((4, nd - 4), dim=-1)

    if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
        orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)[..., ::-1]

    results = []
    for bbox, score, orig_img, img_path in zip(bboxes, scores, orig_imgs, self.batch[0]):  # (300, 4)
        bbox = ops.xywh2xyxy(bbox)
        max_score, cls = score.max(-1, keepdim=True)  # (300, 1)
        idx = max_score.squeeze(-1) > self.args.conf  # (300, )
        if self.args.classes is not None:
            idx = (cls == torch.tensor(self.args.classes, device=cls.device)).any(1) & idx
        pred = torch.cat([bbox, max_score, cls], dim=-1)[idx]  # filter
        pred = pred[pred[:, 4].argsort(descending=True)][: self.args.max_det]
        oh, ow = orig_img.shape[:2]
        pred[..., [0, 2]] *= ow  # scale x coordinates to original width
        pred[..., [1, 3]] *= oh  # scale y coordinates to original height
        results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred))
    return results
```
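When orig_imgs arrives as a tensor, the code converts it to NumPy images and reverses the channel axis. A NumPy sketch of that branch, assuming (as the [..., ::-1] flip suggests) that the tensor holds RGB channels scaled to [0, 1], and that ops.convert_torch2numpy_batch permutes to (N, H, W, 3), scales by 255, clamps, and casts to uint8; tensor_batch_to_bgr is a hypothetical name.

```python
import numpy as np


def tensor_batch_to_bgr(batch):
    """NumPy stand-in for the tensor branch: (N, 3, H, W) floats in [0, 1] -> (N, H, W, 3) uint8 BGR."""
    nhwc = np.transpose(batch, (0, 2, 3, 1))  # NCHW -> NHWC
    as_uint8 = np.clip(nhwc * 255, 0, 255).astype(np.uint8)  # back to 0-255 pixel values
    return as_uint8[..., ::-1]  # reverse the channel axis: RGB -> BGR


batch = np.zeros((1, 3, 2, 2), dtype=np.float32)
batch[0, 0] = 1.0  # first (red) channel fully on
imgs = tensor_batch_to_bgr(batch)
print(imgs.shape)  # (1, 2, 2, 3)
```

The result stays a single array, which the loop in postprocess iterates over one image at a time, just as it would a list.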


method ultralytics.models.rtdetr.predict.RTDETRPredictor.pre_transform

def pre_transform(self, im)

Pre-transform input images before feeding them into the model for inference.

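The input images are resized with LetterBox(auto=False, scale_fill=True), which stretches each image to the square inference size (for example 640x640) without padding, so the original aspect ratio is not preserved. The inference size must be square and the images scale-filled.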
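A simplified sketch of scale-fill resizing, using nearest-neighbour sampling in NumPy instead of the OpenCV resize that LetterBox performs (scale_fill here is a hypothetical helper). Because each axis gets its own ratio and no padding is added, a point's normalized position is unchanged, which is why postprocess can rescale boxes by multiplying by the original width and height.

```python
import numpy as np


def scale_fill(image, size=640):
    """Stretch an (H, W, C) image to (size, size, C) with nearest-neighbour sampling, ignoring aspect ratio.

    Simplified stand-in for LetterBox(size, auto=False, scale_fill=True): no padding, one ratio per axis.
    """
    h, w = image.shape[:2]
    rows = np.arange(size) * h // size  # source row for each output row
    cols = np.arange(size) * w // size  # source column for each output column
    return image[rows[:, None], cols[None, :]]


img = np.zeros((480, 640, 3), dtype=np.uint8)
img[240, 320] = 255  # mark the centre pixel of the 640x480 image
out = scale_fill(img)
print(out.shape)  # (640, 640, 3)
```

The marked pixel sits at normalized position (0.5, 0.5) both before and after stretching.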

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| im | list[np.ndarray] \| torch.Tensor | Input images: shape (N, 3, H, W) for a tensor, or N arrays of shape (H, W, 3) for a list. | required |

Returns

| Type | Description |
| --- | --- |
| list | List of pre-transformed images ready for model inference. |
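Because every image comes out at the same square size, the list can be stacked into one batch. A NumPy sketch of that step, assuming the base predictor stacks the letterboxed images, flips BGR to RGB, reorders to (N, 3, H, W), and scales to [0, 1]; this reflects our understanding of BasePredictor.preprocess rather than anything defined in RTDETRPredictor, and stack_for_inference is a hypothetical name.

```python
import numpy as np


def stack_for_inference(images):
    """Stack same-size (H, W, 3) uint8 BGR images into an (N, 3, H, W) float32 RGB batch scaled to [0, 1]."""
    batch = np.stack(images)  # fails unless all images share one shape; scale-fill guarantees that
    batch = batch[..., ::-1].transpose(0, 3, 1, 2)  # BGR -> RGB, then NHWC -> NCHW
    return np.ascontiguousarray(batch).astype(np.float32) / 255.0


blue = np.zeros((4, 4, 3), dtype=np.uint8)
blue[..., 0] = 255  # channel 0 is blue in BGR order
x = stack_for_inference([blue, blue])
print(x.shape, x.dtype)  # (2, 3, 4, 4) float32
```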
Source code in ultralytics/models/rtdetr/predict.py

```python
def pre_transform(self, im):
    """Pre-transform input images before feeding them into the model for inference.

    The input images are letterboxed to ensure a square aspect ratio and scale-filled. The size must be square (640)
    and scale_filled.

    Args:
        im (list[np.ndarray] | torch.Tensor): Input images of shape (N, 3, H, W) for tensor, [(H, W, 3) x N] for
            list.

    Returns:
        (list): List of pre-transformed images ready for model inference.
    """
    letterbox = LetterBox(self.imgsz, auto=False, scale_fill=True)
    return [letterbox(image=x) for x in im]
```





📅 Created 2 years ago ✏️ Updated 18 days ago
Contributors: glenn-jocher, jk4e, Burhan-Q