Bỏ để qua phần nội dung

Tài liệu tham khảo cho ultralytics/models/fastsam/predict.py

Ghi

Tệp này có sẵn tại https://github.com/ultralytics/ultralytics/blob/main/ultralytics/Mô hình/fastsam/predict.py. Nếu bạn phát hiện ra một vấn đề, vui lòng giúp khắc phục nó bằng cách đóng góp Yêu cầu 🛠️ kéo. Cảm ơn bạn 🙏 !



ultralytics.models.fastsam.predict.FastSAMPredictor

Căn cứ: DetectionPredictor

FastSAMPredictor chuyên dùng để nhanh chóng SAM (Mô hình phân đoạn bất cứ điều gì) nhiệm vụ dự đoán phân đoạn trong Ultralytics YOLO khuôn khổ.

Lớp này mở rộng DetectionPredictor, tùy chỉnh quy trình dự đoán đặc biệt để nhanh chóng SAM. Nó điều chỉnh các bước xử lý hậu kỳ để kết hợp dự đoán mặt nạ và triệt tiêu không tối đa trong khi tối ưu hóa cho phân khúc một lớp.

Thuộc tính:

Tên Kiểu Sự miêu tả
cfg dict

Thông số cấu hình để dự đoán.

overrides dict

Ghi đè thông số tùy chọn cho hành vi tùy chỉnh.

_callbacks dict

Danh sách tùy chọn các hàm gọi lại sẽ được gọi trong quá trình dự đoán.

Mã nguồn trong ultralytics/models/fastsam/predict.py
class FastSAMPredictor(DetectionPredictor):
    """
    FastSAMPredictor is specialized for fast SAM (Segment Anything Model) segmentation prediction tasks in Ultralytics
    YOLO framework.

    This class extends the DetectionPredictor, customizing the prediction pipeline specifically for fast SAM.
    It adjusts post-processing steps to incorporate mask prediction and non-max suppression while optimizing
    for single-class segmentation.

    Attributes:
        cfg (dict): Configuration parameters for prediction.
        overrides (dict, optional): Optional parameter overrides for custom behavior.
        _callbacks (dict, optional): Optional list of callback functions to be invoked during prediction.
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """
        Initializes the FastSAMPredictor class, inheriting from DetectionPredictor and setting the task to 'segment'.

        Args:
            cfg (dict): Configuration parameters for prediction.
            overrides (dict, optional): Optional parameter overrides for custom behavior.
            _callbacks (dict, optional): Optional list of callback functions to be invoked during prediction.
        """
        super().__init__(cfg, overrides, _callbacks)
        self.args.task = "segment"

    def postprocess(self, preds, img, orig_imgs):
        """
        Perform post-processing steps on predictions, including non-max suppression and scaling boxes to original image
        size, and returns the final results.

        Args:
            preds (list): The raw output predictions from the model.
            img (torch.Tensor): The processed image tensor.
            orig_imgs (list | torch.Tensor): The original image or list of images.

        Returns:
            (list): A list of Results objects, each containing processed boxes, masks, and other metadata.
        """
        p = ops.non_max_suppression(
            preds[0],
            self.args.conf,
            self.args.iou,
            agnostic=self.args.agnostic_nms,
            max_det=self.args.max_det,
            nc=1,  # set to 1 class since SAM has no class predictions
            classes=self.args.classes,
        )
        full_box = torch.zeros(p[0].shape[1], device=p[0].device)
        full_box[2], full_box[3], full_box[4], full_box[6:] = img.shape[3], img.shape[2], 1.0, 1.0
        full_box = full_box.view(1, -1)
        critical_iou_index = bbox_iou(full_box[0][:4], p[0][:, :4], iou_thres=0.9, image_shape=img.shape[2:])
        if critical_iou_index.numel() != 0:
            full_box[0][4] = p[0][critical_iou_index][:, 4]
            full_box[0][6:] = p[0][critical_iou_index][:, 6:]
            p[0][critical_iou_index] = full_box

        if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)

        results = []
        proto = preds[1][-1] if len(preds[1]) == 3 else preds[1]  # second output is len 3 if pt, but only 1 if exported
        for i, pred in enumerate(p):
            orig_img = orig_imgs[i]
            img_path = self.batch[0][i]
            if not len(pred):  # save empty boxes
                masks = None
            elif self.args.retina_masks:
                pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
                masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2])  # HWC
            else:
                masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True)  # HWC
                pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
            results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks))
        return results

__init__(cfg=DEFAULT_CFG, overrides=None, _callbacks=None)

Khởi tạo lớp FastSAMPredictor, kế thừa từ DetectionPredictor và đặt tác vụ thành 'segment'.

Thông số:

Tên Kiểu Sự miêu tả Mặc định
cfg dict

Thông số cấu hình để dự đoán.

DEFAULT_CFG
overrides dict

Ghi đè thông số tùy chọn cho hành vi tùy chỉnh.

None
_callbacks dict

Danh sách tùy chọn các hàm gọi lại sẽ được gọi trong quá trình dự đoán.

None
Mã nguồn trong ultralytics/models/fastsam/predict.py
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
    """
    Initializes the FastSAMPredictor class, inheriting from DetectionPredictor and setting the task to 'segment'.

    Args:
        cfg (dict): Configuration parameters for prediction.
        overrides (dict, optional): Optional parameter overrides for custom behavior.
        _callbacks (dict, optional): Optional list of callback functions to be invoked during prediction.
    """
    super().__init__(cfg, overrides, _callbacks)
    self.args.task = "segment"

postprocess(preds, img, orig_imgs)

Thực hiện các bước xử lý hậu kỳ trên các dự đoán, bao gồm các hộp triệt tiêu và chia tỷ lệ không tối đa thành hình ảnh gốc kích thước và trả về kết quả cuối cùng.

Thông số:

Tên Kiểu Sự miêu tả Mặc định
preds list

Các dự đoán đầu ra thô từ mô hình.

bắt buộc
img Tensor

Hình ảnh đã xử lý tensor.

bắt buộc
orig_imgs list | Tensor

Hình ảnh gốc hoặc danh sách hình ảnh.

bắt buộc

Trở lại:

Kiểu Sự miêu tả
list

Danh sách các đối tượng Kết quả, mỗi đối tượng chứa các hộp đã xử lý, mặt nạ và siêu dữ liệu khác.

Mã nguồn trong ultralytics/models/fastsam/predict.py
def postprocess(self, preds, img, orig_imgs):
    """
    Perform post-processing steps on predictions, including non-max suppression and scaling boxes to original image
    size, and returns the final results.

    Args:
        preds (list): The raw output predictions from the model.
        img (torch.Tensor): The processed image tensor.
        orig_imgs (list | torch.Tensor): The original image or list of images.

    Returns:
        (list): A list of Results objects, each containing processed boxes, masks, and other metadata.
    """
    p = ops.non_max_suppression(
        preds[0],
        self.args.conf,
        self.args.iou,
        agnostic=self.args.agnostic_nms,
        max_det=self.args.max_det,
        nc=1,  # set to 1 class since SAM has no class predictions
        classes=self.args.classes,
    )
    full_box = torch.zeros(p[0].shape[1], device=p[0].device)
    full_box[2], full_box[3], full_box[4], full_box[6:] = img.shape[3], img.shape[2], 1.0, 1.0
    full_box = full_box.view(1, -1)
    critical_iou_index = bbox_iou(full_box[0][:4], p[0][:, :4], iou_thres=0.9, image_shape=img.shape[2:])
    if critical_iou_index.numel() != 0:
        full_box[0][4] = p[0][critical_iou_index][:, 4]
        full_box[0][6:] = p[0][critical_iou_index][:, 6:]
        p[0][critical_iou_index] = full_box

    if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
        orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)

    results = []
    proto = preds[1][-1] if len(preds[1]) == 3 else preds[1]  # second output is len 3 if pt, but only 1 if exported
    for i, pred in enumerate(p):
        orig_img = orig_imgs[i]
        img_path = self.batch[0][i]
        if not len(pred):  # save empty boxes
            masks = None
        elif self.args.retina_masks:
            pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
            masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2])  # HWC
        else:
            masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True)  # HWC
            pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
        results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks))
    return results





Created 2023-11-12, Updated 2024-06-02
Authors: glenn-jocher (5), Burhan-Q (1)