Vai al contenuto

Riferimento per ultralytics/models/rtdetr/predict.py

Nota

Questo file è disponibile su https://github.com/ultralytics/ ultralytics/blob/main/ ultralytics/models/rtdetr/predict .py. Se riscontri un problema, contribuisci a risolverlo inviando una Pull Request 🛠️. Grazie 🙏!



ultralytics.models.rtdetr.predict.RTDETRPredictor

Basi: BasePredictor

RT-DETR (Real-Time Detection Transformer) Predictor che estende la classe BasePredictor per fare previsioni utilizzando il modello di il modello RT-DETR di Baidu.

Questa classe sfrutta la potenza dei Trasformatori di Visione per fornire il rilevamento di oggetti in tempo reale mantenendo un'elevata precisione. Supporta funzioni chiave come la codifica ibrida efficiente e la selezione delle query consapevole dell'IoU.

Esempio
from ultralytics.utils import ASSETS
from ultralytics.models.rtdetr import RTDETRPredictor

args = dict(model='rtdetr-l.pt', source=ASSETS)
predictor = RTDETRPredictor(overrides=args)
predictor.predict_cli()

Attributi:

Nome Tipo Descrizione
imgsz int

Dimensione dell'immagine per l'inferenza (deve essere quadrata e in scala).

args dict

Sovrascrittura degli argomenti per il predittore.

Codice sorgente in ultralytics/models/rtdetr/predict.py
class RTDETRPredictor(BasePredictor):
    """
    RT-DETR (Real-Time Detection Transformer) Predictor extending the BasePredictor class for making predictions using
    Baidu's RT-DETR model.

    This class leverages the power of Vision Transformers to provide real-time object detection while maintaining
    high accuracy. It supports key features like efficient hybrid encoding and IoU-aware query selection.

    Example:
        ```python
        from ultralytics.utils import ASSETS
        from ultralytics.models.rtdetr import RTDETRPredictor

        args = dict(model='rtdetr-l.pt', source=ASSETS)
        predictor = RTDETRPredictor(overrides=args)
        predictor.predict_cli()
        ```

    Attributes:
        imgsz (int): Image size for inference (must be square and scale-filled).
        args (dict): Argument overrides for the predictor.
    """

    def postprocess(self, preds, img, orig_imgs):
        """
        Postprocess the raw predictions from the model to generate bounding boxes and confidence scores.

        The method filters detections based on confidence and class if specified in `self.args`.

        Args:
            preds (list): List of [predictions, extra] from the model.
            img (torch.Tensor): Processed input images.
            orig_imgs (list or torch.Tensor): Original, unprocessed images.

        Returns:
            (list[Results]): A list of Results objects containing the post-processed bounding boxes, confidence scores,
                and class labels.
        """
        if not isinstance(preds, (list, tuple)):  # list for PyTorch inference but list[0] Tensor for export inference
            preds = [preds, None]

        nd = preds[0].shape[-1]
        bboxes, scores = preds[0].split((4, nd - 4), dim=-1)

        if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)

        results = []
        for i, bbox in enumerate(bboxes):  # (300, 4)
            bbox = ops.xywh2xyxy(bbox)
            score, cls = scores[i].max(-1, keepdim=True)  # (300, 1)
            idx = score.squeeze(-1) > self.args.conf  # (300, )
            if self.args.classes is not None:
                idx = (cls == torch.tensor(self.args.classes, device=cls.device)).any(1) & idx
            pred = torch.cat([bbox, score, cls], dim=-1)[idx]  # filter
            orig_img = orig_imgs[i]
            oh, ow = orig_img.shape[:2]
            pred[..., [0, 2]] *= ow
            pred[..., [1, 3]] *= oh
            img_path = self.batch[0][i]
            results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred))
        return results

    def pre_transform(self, im):
        """
        Pre-transforms the input images before feeding them into the model for inference. The input images are
        letterboxed to ensure a square aspect ratio and scale-filled. The size must be square(640) and scaleFilled.

        Args:
            im (list[np.ndarray] |torch.Tensor): Input images of shape (N,3,h,w) for tensor, [(h,w,3) x N] for list.

        Returns:
            (list): List of pre-transformed images ready for model inference.
        """
        letterbox = LetterBox(self.imgsz, auto=False, scaleFill=True)
        return [letterbox(image=x) for x in im]

postprocess(preds, img, orig_imgs)

Postelaborazione delle previsioni grezze del modello per generare caselle di delimitazione e punteggi di confidenza.

Il metodo filtra i rilevamenti in base alla fiducia e alla classe se specificato in self.args.

Parametri:

Nome Tipo Descrizione Predefinito
preds list

Elenco delle [previsioni, extra] del modello.

richiesto
img Tensor

Immagini di input elaborate.

richiesto
orig_imgs list or Tensor

Immagini originali e non elaborate.

richiesto

Restituzione:

Tipo Descrizione
list[Results]

Un elenco di oggetti Risultati contenenti i riquadri di delimitazione post-elaborati, i punteggi di confidenza, e le etichette delle classi.

Codice sorgente in ultralytics/models/rtdetr/predict.py
def postprocess(self, preds, img, orig_imgs):
    """
    Postprocess the raw predictions from the model to generate bounding boxes and confidence scores.

    The method filters detections based on confidence and class if specified in `self.args`.

    Args:
        preds (list): List of [predictions, extra] from the model.
        img (torch.Tensor): Processed input images.
        orig_imgs (list or torch.Tensor): Original, unprocessed images.

    Returns:
        (list[Results]): A list of Results objects containing the post-processed bounding boxes, confidence scores,
            and class labels.
    """
    if not isinstance(preds, (list, tuple)):  # list for PyTorch inference but list[0] Tensor for export inference
        preds = [preds, None]

    nd = preds[0].shape[-1]
    bboxes, scores = preds[0].split((4, nd - 4), dim=-1)

    if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
        orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)

    results = []
    for i, bbox in enumerate(bboxes):  # (300, 4)
        bbox = ops.xywh2xyxy(bbox)
        score, cls = scores[i].max(-1, keepdim=True)  # (300, 1)
        idx = score.squeeze(-1) > self.args.conf  # (300, )
        if self.args.classes is not None:
            idx = (cls == torch.tensor(self.args.classes, device=cls.device)).any(1) & idx
        pred = torch.cat([bbox, score, cls], dim=-1)[idx]  # filter
        orig_img = orig_imgs[i]
        oh, ow = orig_img.shape[:2]
        pred[..., [0, 2]] *= ow
        pred[..., [1, 3]] *= oh
        img_path = self.batch[0][i]
        results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred))
    return results

pre_transform(im)

Pre-trasforma le immagini in ingresso prima di inserirle nel modello per l'inferenza. Le immagini in ingresso vengono letterboxate per garantire un rapporto d'aspetto quadrato e riempite in scala. Le dimensioni devono essere quadrate(640) e scaleFilled.

Parametri:

Nome Tipo Descrizione Predefinito
im list[ndarray] | Tensor

Immagini di input di forma (N,3,h,w) per tensor, [(h,w,3) x N] per l'elenco.

richiesto

Restituzione:

Tipo Descrizione
list

Elenco di immagini pre-trasformate pronte per l'inferenza del modello.

Codice sorgente in ultralytics/models/rtdetr/predict.py
def pre_transform(self, im):
    """
    Pre-transforms the input images before feeding them into the model for inference. The input images are
    letterboxed to ensure a square aspect ratio and scale-filled. The size must be square(640) and scaleFilled.

    Args:
        im (list[np.ndarray] |torch.Tensor): Input images of shape (N,3,h,w) for tensor, [(h,w,3) x N] for list.

    Returns:
        (list): List of pre-transformed images ready for model inference.
    """
    letterbox = LetterBox(self.imgsz, auto=False, scaleFill=True)
    return [letterbox(image=x) for x in im]





Created 2023-11-12, Updated 2024-06-02
Authors: glenn-jocher (5), Burhan-Q (1)