
Reference for ultralytics/models/fastsam/predict.py

Note

This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/fastsam/predict.py. If you spot a problem, please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!



ultralytics.models.fastsam.predict.FastSAMPredictor

Bases: DetectionPredictor

FastSAMPredictor is specialized for fast SAM (Segment Anything Model) segmentation prediction tasks in the Ultralytics YOLO framework.

This class extends DetectionPredictor, customizing the prediction pipeline specifically for fast SAM. It adjusts post-processing steps to incorporate mask prediction and non-max suppression while optimizing for single-class segmentation.
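
For orientation, here is a minimal usage sketch. It goes through the high-level FastSAM model class, which routes segment-task inference through this predictor; the weights file name and image path are placeholders for illustration:

from ultralytics import FastSAM

# Load FastSAM weights (file name is a placeholder for illustration)
model = FastSAM("FastSAM-s.pt")

# Run segment-everything inference; FastSAMPredictor performs the
# post-processing (NMS + mask generation) described above
results = model("path/to/image.jpg", conf=0.4, iou=0.9)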

Attributes:

Name        Type   Description
cfg         dict   Configuration parameters for prediction.
overrides   dict   Optional parameter overrides for custom behavior.
_callbacks  dict   Optional list of callback functions to be invoked during prediction.

Source code in ultralytics/models/fastsam/predict.py
class FastSAMPredictor(DetectionPredictor):
    """
    FastSAMPredictor is specialized for fast SAM (Segment Anything Model) segmentation prediction tasks in Ultralytics
    YOLO framework.

    This class extends the DetectionPredictor, customizing the prediction pipeline specifically for fast SAM.
    It adjusts post-processing steps to incorporate mask prediction and non-max suppression while optimizing
    for single-class segmentation.

    Attributes:
        cfg (dict): Configuration parameters for prediction.
        overrides (dict, optional): Optional parameter overrides for custom behavior.
        _callbacks (dict, optional): Optional list of callback functions to be invoked during prediction.
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """
        Initializes the FastSAMPredictor class, inheriting from DetectionPredictor and setting the task to 'segment'.

        Args:
            cfg (dict): Configuration parameters for prediction.
            overrides (dict, optional): Optional parameter overrides for custom behavior.
            _callbacks (dict, optional): Optional list of callback functions to be invoked during prediction.
        """
        super().__init__(cfg, overrides, _callbacks)
        self.args.task = "segment"

    def postprocess(self, preds, img, orig_imgs):
        """
        Perform post-processing steps on predictions, including non-max suppression and scaling boxes to original image
        size, and returns the final results.

        Args:
            preds (list): The raw output predictions from the model.
            img (torch.Tensor): The processed image tensor.
            orig_imgs (list | torch.Tensor): The original image or list of images.

        Returns:
            (list): A list of Results objects, each containing processed boxes, masks, and other metadata.
        """
        p = ops.non_max_suppression(
            preds[0],
            self.args.conf,
            self.args.iou,
            agnostic=self.args.agnostic_nms,
            max_det=self.args.max_det,
            nc=1,  # set to 1 class since SAM has no class predictions
            classes=self.args.classes,
        )
        full_box = torch.zeros(p[0].shape[1], device=p[0].device)
        full_box[2], full_box[3], full_box[4], full_box[6:] = img.shape[3], img.shape[2], 1.0, 1.0
        full_box = full_box.view(1, -1)
        critical_iou_index = bbox_iou(full_box[0][:4], p[0][:, :4], iou_thres=0.9, image_shape=img.shape[2:])
        if critical_iou_index.numel() != 0:
            full_box[0][4] = p[0][critical_iou_index][:, 4]
            full_box[0][6:] = p[0][critical_iou_index][:, 6:]
            p[0][critical_iou_index] = full_box

        if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)

        results = []
        proto = preds[1][-1] if len(preds[1]) == 3 else preds[1]  # second output is len 3 if pt, but only 1 if exported
        for i, pred in enumerate(p):
            orig_img = orig_imgs[i]
            img_path = self.batch[0][i]
            if not len(pred):  # save empty boxes
                masks = None
            elif self.args.retina_masks:
                pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
                masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2])  # HWC
            else:
                masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True)  # HWC
                pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
            results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks))
        return results

__init__(cfg=DEFAULT_CFG, overrides=None, _callbacks=None)

Initializes the FastSAMPredictor class, inheriting from DetectionPredictor and setting the task to 'segment'.

Parameters:

Name        Type   Description                                                            Default
cfg         dict   Configuration parameters for prediction.                               DEFAULT_CFG
overrides   dict   Optional parameter overrides for custom behavior.                      None
_callbacks  dict   Optional list of callback functions to be invoked during prediction.   None
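
A hedged sketch of constructing the predictor directly, bypassing the FastSAM model wrapper (the override keys shown are standard predictor arguments):

from ultralytics.models.fastsam import FastSAMPredictor

# Build the predictor with a few common overrides; __init__ forces the
# task to "segment" regardless of what the overrides contain
predictor = FastSAMPredictor(overrides={"conf": 0.4, "iou": 0.9, "imgsz": 640})
print(predictor.args.task)  # -> "segment"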
Source code in ultralytics/models/fastsam/predict.py
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
    """
    Initializes the FastSAMPredictor class, inheriting from DetectionPredictor and setting the task to 'segment'.

    Args:
        cfg (dict): Configuration parameters for prediction.
        overrides (dict, optional): Optional parameter overrides for custom behavior.
        _callbacks (dict, optional): Optional list of callback functions to be invoked during prediction.
    """
    super().__init__(cfg, overrides, _callbacks)
    self.args.task = "segment"

postprocess(preds, img, orig_imgs)

Performs post-processing steps on predictions, including non-max suppression and scaling boxes to the original image size, and returns the final results.

Parameters:

Name       Type            Description                                   Default
preds      list            The raw output predictions from the model.   required
img        Tensor          The processed image tensor.                  required
orig_imgs  list | Tensor   The original image or list of images.        required

Returns:

Type   Description
list   A list of Results objects, each containing processed boxes, masks, and other metadata.
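
Beyond NMS and box scaling, the source below also injects a full-image box: any prediction whose IoU with the whole image exceeds 0.9 (the critical_iou_index lookup via bbox_iou) is overwritten with an exact full-image box. A self-contained sketch of that IoU test, using a plain xyxy IoU in place of the internal bbox_iou helper:

import torch

def iou_one_to_many(box, boxes):
    """IoU between one (x1, y1, x2, y2) box and an (N, 4) tensor of boxes."""
    x1 = torch.maximum(box[0], boxes[:, 0])
    y1 = torch.maximum(box[1], boxes[:, 1])
    x2 = torch.minimum(box[2], boxes[:, 2])
    y2 = torch.minimum(box[3], boxes[:, 3])
    inter = (x2 - x1).clamp(min=0) * (y2 - y1).clamp(min=0)
    area1 = (box[2] - box[0]) * (box[3] - box[1])
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    return inter / (area1 + areas - inter)

h, w = 480, 640  # processed image height and width, i.e. img.shape[2:]
full_box = torch.tensor([0.0, 0.0, w, h])  # box spanning the whole image
preds = torch.tensor([
    [3.0, 2.0, 638.0, 479.0],      # nearly full-image detection
    [100.0, 120.0, 300.0, 340.0],  # ordinary detection
])
iou = iou_one_to_many(full_box, preds)
critical = torch.nonzero(iou > 0.9).flatten()  # analogue of critical_iou_index
print(iou)       # ~tensor([0.9860, 0.1432])
print(critical)  # tensor([0]): only the near-full box would be overwritten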

Source code in ultralytics/models/fastsam/predict.py
def postprocess(self, preds, img, orig_imgs):
    """
    Perform post-processing steps on predictions, including non-max suppression and scaling boxes to original image
    size, and returns the final results.

    Args:
        preds (list): The raw output predictions from the model.
        img (torch.Tensor): The processed image tensor.
        orig_imgs (list | torch.Tensor): The original image or list of images.

    Returns:
        (list): A list of Results objects, each containing processed boxes, masks, and other metadata.
    """
    p = ops.non_max_suppression(
        preds[0],
        self.args.conf,
        self.args.iou,
        agnostic=self.args.agnostic_nms,
        max_det=self.args.max_det,
        nc=1,  # set to 1 class since SAM has no class predictions
        classes=self.args.classes,
    )
    full_box = torch.zeros(p[0].shape[1], device=p[0].device)
    full_box[2], full_box[3], full_box[4], full_box[6:] = img.shape[3], img.shape[2], 1.0, 1.0
    full_box = full_box.view(1, -1)
    critical_iou_index = bbox_iou(full_box[0][:4], p[0][:, :4], iou_thres=0.9, image_shape=img.shape[2:])
    if critical_iou_index.numel() != 0:
        full_box[0][4] = p[0][critical_iou_index][:, 4]
        full_box[0][6:] = p[0][critical_iou_index][:, 6:]
        p[0][critical_iou_index] = full_box

    if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
        orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)

    results = []
    proto = preds[1][-1] if len(preds[1]) == 3 else preds[1]  # second output is len 3 if pt, but only 1 if exported
    for i, pred in enumerate(p):
        orig_img = orig_imgs[i]
        img_path = self.batch[0][i]
        if not len(pred):  # save empty boxes
            masks = None
        elif self.args.retina_masks:
            pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
            masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2])  # HWC
        else:
            masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True)  # HWC
            pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
        results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks))
    return results
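
For completeness, a brief sketch of consuming the returned Results objects (attribute names follow the standard Ultralytics Results API; the image path is a placeholder):

# Assumes a predictor constructed with a model override, e.g.
# FastSAMPredictor(overrides={"model": "FastSAM-s.pt"})
results = predictor("path/to/image.jpg")  # runs preprocess, inference and this postprocess
for r in results:
    print(r.boxes.xyxy.shape)      # (N, 4) boxes scaled back to the original image
    if r.masks is not None:        # None when nothing was detected
        print(r.masks.data.shape)  # (N, H, W) segmentation masks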





Created 2023-11-12, Updated 2023-11-25
Authors: glenn-jocher (3)