
Reference for ultralytics/models/nas/predict.py

Note

This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/nas/predict.py. If you spot a problem, please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!



ultralytics.models.nas.predict.NASPredictor

Bases: BasePredictor

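Ultralytics YOLO NAS Predictor for object detection.

This class extends the BasePredictor from the Ultralytics engine and is responsible for post-processing the raw predictions generated by YOLO NAS models. It applies operations such as non-maximum suppression and scaling the bounding boxes to fit the original image dimensions.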
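To make the non-maximum suppression step concrete, here is a minimal pure-Python sketch of greedy NMS. It is an illustration only, not the implementation behind `ops.non_max_suppression`: the helper names `iou` and `greedy_nms`, the plain-tuple box format, and the threshold values are all assumptions chosen for the example.

```python
from typing import List, Tuple

Box = Tuple[float, float, float, float]  # (x1, y1, x2, y2) corner format


def iou(a: Box, b: Box) -> float:
    """Intersection over union of two corner-format boxes."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def greedy_nms(boxes: List[Box], scores: List[float], conf: float, iou_thres: float) -> List[int]:
    """Return indices of the boxes that survive, highest score first."""
    # Drop low-confidence candidates, then visit the rest by descending score.
    order = sorted(
        (i for i, s in enumerate(scores) if s > conf),
        key=lambda i: scores[i],
        reverse=True,
    )
    keep: List[int] = []
    for i in order:
        # Keep a box only if it does not overlap too much with any box already kept.
        if all(iou(boxes[i], boxes[j]) <= iou_thres for j in keep):
            keep.append(i)
    return keep


# Hypothetical detections: two near-duplicates, one separate box, one low-score box.
boxes = [(10, 10, 50, 50), (12, 12, 52, 52), (100, 100, 140, 140), (0, 0, 5, 5)]
scores = [0.9, 0.8, 0.7, 0.1]
kept = greedy_nms(boxes, scores, conf=0.25, iou_thres=0.5)
print(kept)
```

In this sketch the second box is suppressed because it overlaps the higher-scoring first box, and the last box is dropped by the confidence threshold. The `conf` and `iou_thres` arguments play the roles that `self.args.conf` and `self.args.iou` play in `postprocess`.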

Attributes:

Name    Type    Description
args    Namespace    Namespace containing various configurations for post-processing.

Example
from ultralytics import NAS

model = NAS('yolo_nas_s')
predictor = model.predictor
# Assumes that raw_preds, img, orig_imgs are available
results = predictor.postprocess(raw_preds, img, orig_imgs)
Note

Typically, this class is not instantiated directly. It is used internally within the NAS class.

Source code in ultralytics/models/nas/predict.py
class NASPredictor(BasePredictor):
    """
    Ultralytics YOLO NAS Predictor for object detection.

    This class extends the `BasePredictor` from Ultralytics engine and is responsible for post-processing the
    raw predictions generated by the YOLO NAS models. It applies operations like non-maximum suppression and
    scaling the bounding boxes to fit the original image dimensions.

    Attributes:
        args (Namespace): Namespace containing various configurations for post-processing.

    Example:
        ```python
        from ultralytics import NAS

        model = NAS('yolo_nas_s')
        predictor = model.predictor
        # Assumes that raw_preds, img, orig_imgs are available
        results = predictor.postprocess(raw_preds, img, orig_imgs)
        ```

    Note:
        Typically, this class is not instantiated directly. It is used internally within the `NAS` class.
    """

    def postprocess(self, preds_in, img, orig_imgs):
        """Postprocess predictions and returns a list of Results objects."""

        # Cat boxes and class scores
        boxes = ops.xyxy2xywh(preds_in[0][0])
        preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)

        preds = ops.non_max_suppression(
            preds,
            self.args.conf,
            self.args.iou,
            agnostic=self.args.agnostic_nms,
            max_det=self.args.max_det,
            classes=self.args.classes,
        )

        if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)

        results = []
        for i, pred in enumerate(preds):
            orig_img = orig_imgs[i]
            pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
            img_path = self.batch[0][i]
            results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred))
        return results
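The first step of `postprocess` converts corner-format boxes to center-format with `ops.xyxy2xywh` before concatenating them with the class scores. The sketch below illustrates that conversion on plain tuples. The helper names `xyxy_to_xywh` and `xywh_to_xyxy` are hypothetical, and the library itself operates on tensors rather than tuples.

```python
from typing import Tuple

Box = Tuple[float, float, float, float]


def xyxy_to_xywh(box: Box) -> Box:
    """Convert (x1, y1, x2, y2) corners to (center_x, center_y, width, height)."""
    x1, y1, x2, y2 = box
    return ((x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1)


def xywh_to_xyxy(box: Box) -> Box:
    """Inverse conversion: (center_x, center_y, width, height) back to corners."""
    cx, cy, w, h = box
    return (cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2)


corner_box = (10.0, 20.0, 50.0, 80.0)
center_box = xyxy_to_xywh(corner_box)

# One prediction row in the layout the method builds: box values followed by class scores.
class_scores = (0.1, 0.7, 0.2)  # made-up scores for three classes
row = center_box + class_scores
print(center_box, row)
```

Each row produced this way holds four box values followed by one score per class, which matches the concatenation along the last dimension in the source above.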

postprocess(preds_in, img, orig_imgs)

Postprocess predictions and return a list of Results objects.

Source code in ultralytics/models/nas/predict.py
def postprocess(self, preds_in, img, orig_imgs):
    """Postprocess predictions and returns a list of Results objects."""

    # Cat boxes and class scores
    boxes = ops.xyxy2xywh(preds_in[0][0])
    preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)

    preds = ops.non_max_suppression(
        preds,
        self.args.conf,
        self.args.iou,
        agnostic=self.args.agnostic_nms,
        max_det=self.args.max_det,
        classes=self.args.classes,
    )

    if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
        orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)

    results = []
    for i, pred in enumerate(preds):
        orig_img = orig_imgs[i]
        pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
        img_path = self.batch[0][i]
        results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred))
    return results
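The loop at the end of `postprocess` maps each box from the model input size back to the original image with `ops.scale_boxes`. The following is a simplified sketch of that idea, assuming the input was resized with its aspect ratio preserved and padded symmetrically (letterboxed). The function name `scale_box_to_original` is hypothetical, and details such as the library's rounding of the padding are deliberately left out.

```python
from typing import Tuple

Box = Tuple[float, float, float, float]  # (x1, y1, x2, y2)


def scale_box_to_original(box: Box, model_hw: Tuple[int, int], orig_hw: Tuple[int, int]) -> Box:
    """Map a corner-format box from letterboxed model input to original image coordinates."""
    model_h, model_w = model_hw
    orig_h, orig_w = orig_hw
    # Resize factor applied to the original image so it fits inside the model input.
    gain = min(model_h / orig_h, model_w / orig_w)
    # Padding added on each side to fill the remaining space (assumed symmetric).
    pad_x = (model_w - orig_w * gain) / 2
    pad_y = (model_h - orig_h * gain) / 2

    x1, y1, x2, y2 = box
    # Undo the padding, then undo the resize.
    x1, x2 = (x1 - pad_x) / gain, (x2 - pad_x) / gain
    y1, y2 = (y1 - pad_y) / gain, (y2 - pad_y) / gain

    # Clip to the original image bounds.
    def clamp(value: float, upper: float) -> float:
        return min(max(value, 0.0), upper)

    return (clamp(x1, orig_w), clamp(y1, orig_h), clamp(x2, orig_w), clamp(y2, orig_h))


# A 1280x960 (width x height) image letterboxed into a 640x640 model input.
scaled = scale_box_to_original((100.0, 180.0, 200.0, 280.0), model_hw=(640, 640), orig_hw=(960, 1280))
clipped = scale_box_to_original((-10.0, 0.0, 700.0, 640.0), model_hw=(640, 640), orig_hw=(960, 1280))
print(scaled, clipped)
```

Here `model_hw` corresponds to `img.shape[2:]` and `orig_hw` to the first two entries of `orig_img.shape` in the source above.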





Created 2023-11-12, Updated 2023-11-25
Authors: glenn-jocher (3)