跳至内容

参考资料 ultralytics/models/nas/val.py

备注

该文件可在https://github.com/ultralytics/ultralytics/blob/main/ ultralytics/models/nas/val .py 上获取。如果您发现问题,请通过提交 Pull Request🛠️ 帮助修复。谢谢🙏!



ultralytics.models.nas.val.NASValidator

垒球 DetectionValidator

Ultralytics YOLO 用于对象检测的 NAS 验证器。

扩展 DetectionValidator Ultralytics 模型软件包,旨在对 NAS 模型生成的原始预测结果进行后处理。 YOLO NAS 模型生成的原始预测。它执行非最大抑制,以去除重叠和低置信度的方框、 最终产生最终探测结果。

属性

名称 类型 说明
args Namespace

名称空间包含用于后处理的各种配置,如置信度和 IoU 阈值。

lb Tensor

多标签 NMS 的可选tensor 。

示例
from ultralytics import NAS

model = NAS('yolo_nas_s')
validator = model.validator
# Assumes that raw_preds are available
final_preds = validator.postprocess(raw_preds)
备注

该类通常不会直接实例化,而是在 NAS 类。

源代码 ultralytics/models/nas/val.py
class NASValidator(DetectionValidator):
    """
    Ultralytics YOLO NAS Validator for object detection.

    Extends `DetectionValidator` from the Ultralytics models package and is designed to post-process the raw predictions
    generated by YOLO NAS models. It performs non-maximum suppression to remove overlapping and low-confidence boxes,
    ultimately producing the final detections.

    Attributes:
        args (Namespace): Namespace containing various configurations for post-processing, such as confidence and IoU thresholds.
        lb (torch.Tensor): Optional tensor for multilabel NMS.

    Example:
        ```python
        from ultralytics import NAS

        model = NAS('yolo_nas_s')
        validator = model.validator
        # Assumes that raw_preds are available
        final_preds = validator.postprocess(raw_preds)
        ```

    Note:
        This class is generally not instantiated directly but is used internally within the `NAS` class.
    """

    def postprocess(self, preds_in):
        """Apply Non-maximum suppression to prediction outputs."""
        boxes = ops.xyxy2xywh(preds_in[0][0])
        preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)
        return ops.non_max_suppression(
            preds,
            self.args.conf,
            self.args.iou,
            labels=self.lb,
            multi_label=False,
            agnostic=self.args.single_cls,
            max_det=self.args.max_det,
            max_time_img=0.5,
        )

postprocess(preds_in)

对预测输出进行非最大抑制。

源代码 ultralytics/models/nas/val.py
def postprocess(self, preds_in):
    """Apply Non-maximum suppression to prediction outputs."""
    boxes = ops.xyxy2xywh(preds_in[0][0])
    preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)
    return ops.non_max_suppression(
        preds,
        self.args.conf,
        self.args.iou,
        labels=self.lb,
        multi_label=False,
        agnostic=self.args.single_cls,
        max_det=self.args.max_det,
        max_time_img=0.5,
    )





创建于 2023-11-12,更新于 2023-11-25
作者:glenn-jocher(3)