跳至内容

参考资料 ultralytics/engine/predictor.py

备注

该文件可在 https://github.com/ultralytics/ultralytics/blob/main/ultralytics/engine/predictor.py 获取。如果您发现问题,请通过提交 Pull Request 🛠️ 帮助修复。谢谢 🙏!



ultralytics.engine.predictor.BasePredictor

BasePredictor.

用于创建预测器的基类。

属性

名称 类型 说明
args SimpleNamespace

预测器的配置。

save_dir Path

保存结果的目录。

done_warmup bool

预测器是否已完成设置。

model Module

用于预测的模型。

data dict

数据配置。

device device

用于预测的设备。

dataset Dataset

用于预测的数据集。

vid_writer dict

形如 {save_path: video_writer, ...} 的字典,保存用于写入视频输出的写入器。

源代码 ultralytics/engine/predictor.py
class BasePredictor:
    """
    BasePredictor.

    A base class for creating predictors.

    Attributes:
        args (SimpleNamespace): Configuration for the predictor.
        save_dir (Path): Directory to save results.
        done_warmup (bool): Whether the predictor has finished setup.
        model (nn.Module): Model used for prediction.
        data (dict): Data configuration.
        device (torch.device): Device used for prediction.
        dataset (Dataset): Dataset used for prediction.
        vid_writer (dict): Dictionary of {save_path: video_writer, ...} writer for saving video output.
        imgsz: Checked inference image size, set by `setup_source`.
        plotted_img: Most recent annotated image produced by `write_results`.
        source_type: Source type descriptor copied from the dataset in `setup_source`.
        seen (int): Number of images processed during the current prediction run.
        windows (list): Names of the display windows already created by `show`.
        batch: The batch currently being processed by `stream_inference`.
        results: Post-processed results of the current batch.
        transforms: Classification transforms (only populated when `args.task == "classify"`).
        callbacks (dict): Mapping of event name to a list of callback functions.
        txt_path (Path): Label file path (without extension) of the most recently written result.
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """
        Initializes the BasePredictor class.

        Args:
            cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
            overrides (dict, optional): Configuration overrides. Defaults to None.
            _callbacks (optional): Callbacks to use instead of the default ones. When falsy, the default
                callbacks are used.
        """
        self.args = get_cfg(cfg, overrides)
        self.save_dir = get_save_dir(self.args)
        if self.args.conf is None:
            self.args.conf = 0.25  # default conf=0.25
        self.done_warmup = False
        if self.args.show:
            self.args.show = check_imshow(warn=True)

        # Usable if setup is done
        self.model = None
        self.data = self.args.data  # data_dict
        self.imgsz = None
        self.device = None
        self.dataset = None
        self.vid_writer = {}  # dict of {save_path: video_writer, ...}
        self.plotted_img = None
        self.source_type = None
        self.seen = 0
        self.windows = []
        self.batch = None
        self.results = None
        self.transforms = None
        self.callbacks = _callbacks or callbacks.get_default_callbacks()
        self.txt_path = None
        self._lock = threading.Lock()  # for automatic thread-safe inference
        callbacks.add_integration_callbacks(self)

    def preprocess(self, im):
        """
        Prepares input image before inference.

        Args:
            im (torch.Tensor | List(np.ndarray)): BCHW for tensor, [(HWC) x B] for list.

        Returns:
            (torch.Tensor): Batch on `self.device` in fp16 or fp32. Only non-tensor input is scaled to 0.0 - 1.0;
                tensor input is passed through with a dtype/device change only.
        """
        not_tensor = not isinstance(im, torch.Tensor)
        if not_tensor:
            im = np.stack(self.pre_transform(im))
            im = im[..., ::-1].transpose((0, 3, 1, 2))  # BGR to RGB, BHWC to BCHW, (n, 3, h, w)
            im = np.ascontiguousarray(im)  # contiguous
            im = torch.from_numpy(im)

        im = im.to(self.device)
        im = im.half() if self.model.fp16 else im.float()  # uint8 to fp16/32
        if not_tensor:
            im /= 255  # 0 - 255 to 0.0 - 1.0
        return im

    def inference(self, im, *args, **kwargs):
        """
        Runs inference on a given image using the specified model and arguments.

        When `args.visualize` is set and the source is not a tensor, a fresh directory named after the first
        path of the current batch is created under `save_dir` and passed to the model as `visualize`.
        """
        visualize = (
            increment_path(self.save_dir / Path(self.batch[0][0]).stem, mkdir=True)
            if self.args.visualize and (not self.source_type.tensor)
            else False
        )
        # Positional extras are placed before the keywords; the call is equivalent but easier to read
        return self.model(im, *args, augment=self.args.augment, visualize=visualize, embed=self.args.embed, **kwargs)

    def pre_transform(self, im):
        """
        Pre-transform input image before inference.

        Args:
            im (List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.

        Returns:
            (list): A list of transformed images.
        """
        same_shapes = len({x.shape for x in im}) == 1
        letterbox = LetterBox(self.imgsz, auto=same_shapes and self.model.pt, stride=self.model.stride)
        return [letterbox(image=x) for x in im]

    def postprocess(self, preds, img, orig_imgs):
        """Post-processes predictions for an image and returns them (the base implementation is a pass-through)."""
        return preds

    def __call__(self, source=None, model=None, stream=False, *args, **kwargs):
        """
        Performs inference on an image or stream.

        Returns the `stream_inference` generator when `stream` is true, otherwise a list with all of its items.
        """
        self.stream = stream
        if stream:
            return self.stream_inference(source, model, *args, **kwargs)
        else:
            return list(self.stream_inference(source, model, *args, **kwargs))  # merge list of Result into one

    def predict_cli(self, source=None, model=None):
        """
        Method used for Command Line Interface (CLI) prediction.

        This function is designed to run predictions using the CLI. It sets up the source and model, then processes
        the inputs in a streaming manner. This method ensures that no outputs accumulate in memory by consuming the
        generator without storing results.

        Note:
            Do not modify this function or remove the generator. The generator ensures that no outputs are
            accumulated in memory, which is critical for preventing memory issues during long-running predictions.
        """
        gen = self.stream_inference(source, model)
        for _ in gen:  # sourcery skip: remove-empty-nested-block, noqa
            pass

    def setup_source(self, source):
        """Sets up source and inference mode."""
        self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2)  # check image size
        self.transforms = (
            getattr(
                self.model.model,
                "transforms",
                classify_transforms(self.imgsz[0], crop_fraction=self.args.crop_fraction),
            )
            if self.args.task == "classify"
            else None
        )
        self.dataset = load_inference_source(
            source=source,
            batch=self.args.batch,
            vid_stride=self.args.vid_stride,
            buffer=self.args.stream_buffer,
        )
        self.source_type = self.dataset.source_type
        if not getattr(self, "stream", True) and (
            self.source_type.stream
            or self.source_type.screenshot
            or len(self.dataset) > 1000  # many images
            or any(getattr(self.dataset, "video_flag", [False]))
        ):  # videos
            LOGGER.warning(STREAM_WARNING)
        self.vid_writer = {}

    @smart_inference_mode()
    def stream_inference(self, source=None, model=None, *args, **kwargs):
        """
        Streams inference over a source, yielding results batch by batch and optionally saving them.

        Args:
            source: Inference source; falls back to `args.source` when None.
            model: Model passed to `setup_model` when no model has been set up yet.
            *args: Extra positional arguments forwarded to `inference`.
            **kwargs: Extra keyword arguments forwarded to `inference`.

        Yields:
            One item per image from `self.results`, or the raw embedding output when `args.embed` is set.
        """
        if self.args.verbose:
            LOGGER.info("")

        # Setup model
        if not self.model:
            self.setup_model(model)

        try:
            with self._lock:  # for thread-safe inference
                # Setup source every time predict is called
                self.setup_source(source if source is not None else self.args.source)

                # Check if save_dir/ label file exists
                if self.args.save or self.args.save_txt:
                    (self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir(
                        parents=True, exist_ok=True
                    )

                # Warmup model
                if not self.done_warmup:
                    self.model.warmup(
                        imgsz=(1 if self.model.pt or self.model.triton else self.dataset.bs, 3, *self.imgsz)
                    )
                    self.done_warmup = True

                self.seen, self.windows, self.batch = 0, [], None
                profilers = (
                    ops.Profile(device=self.device),
                    ops.Profile(device=self.device),
                    ops.Profile(device=self.device),
                )
                self.run_callbacks("on_predict_start")
                for self.batch in self.dataset:
                    self.run_callbacks("on_predict_batch_start")
                    paths, im0s, s = self.batch

                    # Preprocess
                    with profilers[0]:
                        im = self.preprocess(im0s)

                    # Inference
                    with profilers[1]:
                        preds = self.inference(im, *args, **kwargs)
                        if self.args.embed:
                            yield from [preds] if isinstance(preds, torch.Tensor) else preds  # yield embedding tensors
                            continue

                    # Postprocess
                    with profilers[2]:
                        self.results = self.postprocess(preds, im, im0s)
                    self.run_callbacks("on_predict_postprocess_end")

                    # Visualize, save, write results
                    n = len(im0s)
                    for i in range(n):
                        self.seen += 1
                        self.results[i].speed = {
                            "preprocess": profilers[0].dt * 1e3 / n,
                            "inference": profilers[1].dt * 1e3 / n,
                            "postprocess": profilers[2].dt * 1e3 / n,
                        }
                        if self.args.verbose or self.args.save or self.args.save_txt or self.args.show:
                            s[i] += self.write_results(i, Path(paths[i]), im, s)

                    # Print batch results
                    if self.args.verbose:
                        LOGGER.info("\n".join(s))

                    self.run_callbacks("on_predict_batch_end")
                    yield from self.results
        finally:
            # Release assets even when the consumer stops iterating early or an error is raised,
            # so that partially written video files are always finalized
            for v in self.vid_writer.values():
                if isinstance(v, cv2.VideoWriter):
                    v.release()

        # Print final results
        if self.args.verbose and self.seen:
            t = tuple(x.t / self.seen * 1e3 for x in profilers)  # speeds per image
            LOGGER.info(
                f"Speed: %.1fms preprocess, %.1fms inference, %.1fms postprocess per image at shape "
                f"{(min(self.args.batch, self.seen), 3, *im.shape[2:])}" % t
            )
        if self.args.save or self.args.save_txt or self.args.save_crop:
            nl = len(list(self.save_dir.glob("labels/*.txt")))  # number of labels
            s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else ""
            LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}")
        self.run_callbacks("on_predict_end")

    def setup_model(self, model, verbose=True):
        """
        Initialize YOLO model with given parameters and set it to evaluation mode.

        Args:
            model: Weights handed to AutoBackend; falls back to `args.model` when falsy.
            verbose (bool): Forwarded to `select_device` and AutoBackend.
        """
        self.model = AutoBackend(
            weights=model or self.args.model,
            device=select_device(self.args.device, verbose=verbose),
            dnn=self.args.dnn,
            data=self.args.data,
            fp16=self.args.half,
            batch=self.args.batch,
            fuse=True,
            verbose=verbose,
        )

        self.device = self.model.device  # update device
        self.args.half = self.model.fp16  # update half
        self.model.eval()

    def write_results(self, i, p, im, s):
        """
        Write inference results to a file or directory.

        Args:
            i (int): Index of the image inside the current batch.
            p (Path): Source path of the image.
            im (torch.Tensor): Preprocessed batch (a 3-dim input gets a batch dimension prepended).
            s (list): Per-image log strings of the current batch; used to recover the video frame number.

        Returns:
            (str): Log string describing the image size, the result summary and the inference time.
        """
        string = ""  # print string
        if len(im.shape) == 3:
            im = im[None]  # expand for batch dim
        if self.source_type.stream or self.source_type.from_img or self.source_type.tensor:  # batch_size >= 1
            string += f"{i}: "
            frame = self.dataset.count
        else:
            match = re.search(r"frame (\d+)/", s[i])
            frame = int(match[1]) if match else None  # None if the frame number cannot be determined

        self.txt_path = self.save_dir / "labels" / (p.stem + ("" if self.dataset.mode == "image" else f"_{frame}"))
        string += "%gx%g " % im.shape[2:]
        result = self.results[i]
        result.save_dir = str(self.save_dir)  # used in other locations
        string += f"{result.verbose()}{result.speed['inference']:.1f}ms"

        # Add predictions to image
        if self.args.save or self.args.show:
            self.plotted_img = result.plot(
                line_width=self.args.line_width,
                boxes=self.args.show_boxes,
                conf=self.args.show_conf,
                labels=self.args.show_labels,
                im_gpu=None if self.args.retina_masks else im[i],
            )

        # Save results
        if self.args.save_txt:
            result.save_txt(f"{self.txt_path}.txt", save_conf=self.args.save_conf)
        if self.args.save_crop:
            result.save_crop(save_dir=self.save_dir / "crops", file_name=self.txt_path.stem)
        if self.args.show:
            self.show(str(p))
        if self.args.save:
            self.save_predicted_images(str(self.save_dir / p.name), frame)

        return string

    def save_predicted_images(self, save_path="", frame=0):
        """
        Save the last plotted image, either appended to a video file or written as a single image.

        Args:
            save_path (str): Output path. For video/stream datasets the suffix is replaced by a
                platform-dependent container suffix.
            frame (int): Frame number, used to name individual frames when `args.save_frames` is set.
        """
        im = self.plotted_img

        # Save videos and streams
        if self.dataset.mode in {"stream", "video"}:
            fps = self.dataset.fps if self.dataset.mode == "video" else 30
            # Strip only the file suffix: splitting on the first "." would also cut at dots in parent
            # directories (e.g. "../out/video.mp4" would collapse to "_frames/")
            frames_path = f"{Path(save_path).with_suffix('')}_frames/"
            if save_path not in self.vid_writer:  # new video
                if self.args.save_frames:
                    Path(frames_path).mkdir(parents=True, exist_ok=True)
                suffix, fourcc = (".mp4", "avc1") if MACOS else (".avi", "WMV2") if WINDOWS else (".avi", "MJPG")
                self.vid_writer[save_path] = cv2.VideoWriter(
                    filename=str(Path(save_path).with_suffix(suffix)),
                    fourcc=cv2.VideoWriter_fourcc(*fourcc),
                    fps=fps,  # integer required, floats produce error in MP4 codec
                    frameSize=(im.shape[1], im.shape[0]),  # (width, height)
                )

            # Save video
            self.vid_writer[save_path].write(im)
            if self.args.save_frames:
                cv2.imwrite(f"{frames_path}{frame}.jpg", im)

        # Save images
        else:
            cv2.imwrite(save_path, im)

    def show(self, p=""):
        """Display an image in a window using OpenCV imshow()."""
        im = self.plotted_img
        if platform.system() == "Linux" and p not in self.windows:
            self.windows.append(p)
            cv2.namedWindow(p, cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)  # allow window resize (Linux)
            cv2.resizeWindow(p, im.shape[1], im.shape[0])  # (width, height)
        cv2.imshow(p, im)
        cv2.waitKey(300 if self.dataset.mode == "image" else 1)  # 300 ms for still images, else 1 millisecond

    def run_callbacks(self, event: str):
        """Runs all registered callbacks for a specific event; events without callbacks are a no-op."""
        for callback in self.callbacks.get(event, []):
            callback(self)

    def add_callback(self, event: str, func):
        """
        Add callback.

        Args:
            event (str): Name of the event to attach the callback to.
            func (callable): Callback invoked with the predictor as its only argument.
        """
        # setdefault keeps this consistent with run_callbacks, which also tolerates unregistered events
        self.callbacks.setdefault(event, []).append(func)

__call__(source=None, model=None, stream=False, *args, **kwargs)

对图像或数据流执行推理。

源代码 ultralytics/engine/predictor.py
def __call__(self, source=None, model=None, stream=False, *args, **kwargs):
    """
    Run prediction on a source, either lazily or eagerly.

    The `stream` flag is remembered on the instance. With `stream=True` the generator returned by
    `stream_inference` is handed back untouched; otherwise it is exhausted and its items are returned as a list.
    """
    self.stream = stream
    outputs = self.stream_inference(source, model, *args, **kwargs)
    # Non-streaming callers receive every result collected into one list
    return outputs if stream else list(outputs)

__init__(cfg=DEFAULT_CFG, overrides=None, _callbacks=None)

初始化 BasePredictor 类。

参数

名称 类型 说明 默认值
cfg str

配置文件的路径。默认为 DEFAULT_CFG。

DEFAULT_CFG
overrides dict

配置覆盖。默认为 None。

None
源代码 ultralytics/engine/predictor.py
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
    """
    Build the predictor configuration and reset all run-time state.

    Args:
        cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
        overrides (dict, optional): Configuration overrides. Defaults to None.
        _callbacks (optional): Callbacks to use instead of the default ones; the defaults are used when falsy.
    """
    cfg_ns = get_cfg(cfg, overrides)
    self.args = cfg_ns
    self.save_dir = get_save_dir(cfg_ns)
    if cfg_ns.conf is None:
        cfg_ns.conf = 0.25  # fall back to the default confidence threshold
    self.done_warmup = False
    if cfg_ns.show:
        cfg_ns.show = check_imshow(warn=True)

    # State that only becomes meaningful once model and source have been set up
    self.model = self.device = self.dataset = None
    self.imgsz = self.transforms = None
    self.source_type = self.plotted_img = None
    self.batch = self.results = self.txt_path = None
    self.data = cfg_ns.data  # data_dict
    self.vid_writer = {}  # {save_path: video_writer, ...}
    self.windows = []
    self.seen = 0

    self.callbacks = _callbacks or callbacks.get_default_callbacks()
    self._lock = threading.Lock()  # serialises inference across threads
    callbacks.add_integration_callbacks(self)

add_callback(event, func)

添加回调。

源代码 ultralytics/engine/predictor.py
def add_callback(self, event: str, func):
    """
    Register a callback for an event.

    Args:
        event (str): Name of the event to attach the callback to.
        func (callable): Callback to append to the event's callback list.
    """
    # setdefault creates the list for events that are not registered yet, instead of raising KeyError
    self.callbacks.setdefault(event, []).append(func)

inference(im, *args, **kwargs)

使用指定的模型和参数对给定图像进行推理。

源代码 ultralytics/engine/predictor.py
def inference(self, im, *args, **kwargs):
    """
    Forward a preprocessed batch through the model.

    If `args.visualize` is enabled and the source is not a tensor, a new directory named after the stem of
    the first path in the current batch is created under `save_dir` and passed to the model as `visualize`;
    otherwise `visualize` is False. Extra positional and keyword arguments are forwarded to the model.
    """
    visualize = False
    if self.args.visualize and not self.source_type.tensor:
        first_path = Path(self.batch[0][0])
        visualize = increment_path(self.save_dir / first_path.stem, mkdir=True)

    model_kwargs = dict(augment=self.args.augment, visualize=visualize, embed=self.args.embed)
    return self.model(im, *args, **model_kwargs, **kwargs)

postprocess(preds, img, orig_imgs)

对图像进行后处理并返回预测结果。

源代码 ultralytics/engine/predictor.py
def postprocess(self, preds, img, orig_imgs):
    """
    Hook for turning raw model output into results.

    This implementation is a pass-through: `img` and `orig_imgs` are ignored and `preds` is returned unchanged.
    """
    return preds

pre_transform(im)

在推理之前对输入图像进行预变换。

参数

名称 类型 说明 默认值
im List(np.ndarray)

(N, 3, h, w) 代表tensor, [(h, w, 3) x N] 代表列表。

必需

返回:

类型 说明
list

转换后的图像列表。

源代码 ultralytics/engine/predictor.py
def pre_transform(self, im):
    """
    Letterbox every input image ahead of inference.

    Args:
        im (List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.

    Returns:
        (list): The images after the LetterBox transform, in input order.
    """
    distinct_shapes = {img.shape for img in im}
    uniform = len(distinct_shapes) == 1
    # `auto` is only enabled when all images share one shape and the model reports `pt`
    resize = LetterBox(self.imgsz, auto=uniform and self.model.pt, stride=self.model.stride)

    transformed = []
    for img in im:
        transformed.append(resize(image=img))
    return transformed

predict_cli(source=None, model=None)

Method used for Command Line Interface (CLI) prediction.

This function is designed to run predictions using the CLI. It sets up the source and model, then processes the inputs in a streaming manner. This method ensures that no outputs accumulate in memory by consuming the generator without storing results.

备注

Do not modify this function or remove the generator. The generator ensures that no outputs are accumulated in memory, which is critical for preventing memory issues during long-running predictions.

源代码 ultralytics/engine/predictor.py
def predict_cli(self, source=None, model=None):
    """
    Method used for Command Line Interface (CLI) prediction.

    This function is designed to run predictions using the CLI. It sets up the source and model, then processes
    the inputs in a streaming manner. This method ensures that no outputs accumulate in memory by consuming the
    generator without storing results.

    Args:
        source: Inference source, forwarded unchanged to `stream_inference`.
        model: Model, forwarded unchanged to `stream_inference`.

    Note:
        Do not modify this function or remove the generator. The generator ensures that no outputs are
        accumulated in memory, which is critical for preventing memory issues during long-running predictions.
    """
    gen = self.stream_inference(source, model)
    # Drain the generator and discard each item so nothing is retained
    for _ in gen:  # sourcery skip: remove-empty-nested-block, noqa
        pass

preprocess(im)

在推理前准备输入图像。

参数

名称 类型 说明 默认值
im torch.Tensor | List(np.ndarray)

BCHW 代表tensor, [(HWC) x B] 代表列表。

必需
源代码 ultralytics/engine/predictor.py
def preprocess(self, im):
    """
    Convert the raw input into a model-ready tensor.

    Args:
        im (torch.Tensor | List(np.ndarray)): BCHW for tensor, [(HWC) x B] for list.

    Returns:
        (torch.Tensor): Batch on `self.device`, cast to half precision when `self.model.fp16` is set and to
            float otherwise. Only list input is rescaled from 0 - 255 to 0.0 - 1.0; tensor input is not.
    """
    from_arrays = not isinstance(im, torch.Tensor)
    if from_arrays:
        stacked = np.stack(self.pre_transform(im))
        rgb = stacked[..., ::-1]  # BGR to RGB
        bchw = rgb.transpose((0, 3, 1, 2))  # BHWC to BCHW, (n, 3, h, w)
        im = torch.from_numpy(np.ascontiguousarray(bchw))

    im = im.to(self.device)
    if self.model.fp16:
        im = im.half()
    else:
        im = im.float()

    if from_arrays:
        im /= 255  # 0 - 255 to 0.0 - 1.0
    return im

run_callbacks(event)

运行特定事件的所有已注册回调。

源代码 ultralytics/engine/predictor.py
def run_callbacks(self, event: str):
    """Invoke, in registration order, every callback stored for `event`, passing the predictor to each."""
    handlers = self.callbacks.get(event, [])  # unknown events simply have no handlers
    for handler in handlers:
        handler(self)

save_predicted_images(save_path='', frame=0)

将预测的视频以 mp4 格式保存在指定路径下。

源代码 ultralytics/engine/predictor.py
def save_predicted_images(self, save_path="", frame=0):
    """
    Save the last plotted image, either appended to a video file or written as a single image.

    Args:
        save_path (str): Output path. For video/stream datasets the suffix is replaced by a
            platform-dependent container suffix.
        frame (int): Frame number, used to name individual frames when `args.save_frames` is set.
    """
    im = self.plotted_img

    # Save videos and streams
    if self.dataset.mode in {"stream", "video"}:
        fps = self.dataset.fps if self.dataset.mode == "video" else 30
        # Strip only the file suffix: splitting on the first "." would also cut at dots in parent
        # directories (e.g. "../out/video.mp4" would collapse to "_frames/")
        frames_path = f"{Path(save_path).with_suffix('')}_frames/"
        if save_path not in self.vid_writer:  # new video
            if self.args.save_frames:
                Path(frames_path).mkdir(parents=True, exist_ok=True)
            suffix, fourcc = (".mp4", "avc1") if MACOS else (".avi", "WMV2") if WINDOWS else (".avi", "MJPG")
            self.vid_writer[save_path] = cv2.VideoWriter(
                filename=str(Path(save_path).with_suffix(suffix)),
                fourcc=cv2.VideoWriter_fourcc(*fourcc),
                fps=fps,  # integer required, floats produce error in MP4 codec
                frameSize=(im.shape[1], im.shape[0]),  # (width, height)
            )

        # Save video
        self.vid_writer[save_path].write(im)
        if self.args.save_frames:
            cv2.imwrite(f"{frames_path}{frame}.jpg", im)

    # Save images
    else:
        cv2.imwrite(save_path, im)

setup_model(model, verbose=True)

使用给定参数初始化YOLO 模型,并将其设置为评估模式。

源代码 ultralytics/engine/predictor.py
def setup_model(self, model, verbose=True):
    """
    Create the AutoBackend model, sync device/half settings back from it and switch it to eval mode.

    Args:
        model: Weights to load; `args.model` is used when this is falsy.
        verbose (bool): Forwarded to `select_device` and AutoBackend.
    """
    cfg = self.args
    weights = model or cfg.model
    device = select_device(cfg.device, verbose=verbose)
    backend = AutoBackend(
        weights=weights,
        device=device,
        dnn=cfg.dnn,
        data=cfg.data,
        fp16=cfg.half,
        batch=cfg.batch,
        fuse=True,
        verbose=verbose,
    )
    self.model = backend

    # Mirror what the backend actually settled on
    self.device = backend.device
    cfg.half = backend.fp16
    backend.eval()

setup_source(source)

设置信号源和推理模式。

源代码 ultralytics/engine/predictor.py
def setup_source(self, source):
    """
    Prepare image size, transforms and dataset for a new source and reset the video writers.

    A STREAM_WARNING is logged when the predictor is not in stream mode and the source is a stream, a
    screenshot, has more than 1000 items or contains video.
    """
    self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2)  # check image size

    if self.args.task == "classify":
        # The fallback transforms are built eagerly and used only if the model has none of its own
        fallback = classify_transforms(self.imgsz[0], crop_fraction=self.args.crop_fraction)
        self.transforms = getattr(self.model.model, "transforms", fallback)
    else:
        self.transforms = None

    self.dataset = load_inference_source(
        source=source,
        batch=self.args.batch,
        vid_stride=self.args.vid_stride,
        buffer=self.args.stream_buffer,
    )
    self.source_type = self.dataset.source_type

    streaming = getattr(self, "stream", True)
    if not streaming:
        kind = self.source_type
        heavy_source = (
            kind.stream
            or kind.screenshot
            or len(self.dataset) > 1000  # many images
            or any(getattr(self.dataset, "video_flag", [False]))  # videos
        )
        if heavy_source:
            LOGGER.warning(STREAM_WARNING)

    self.vid_writer = {}

show(p='')

使用 OpenCV imshow() 在窗口中显示图像。

源代码 ultralytics/engine/predictor.py
def show(self, p=""):
    """
    Show the last plotted image in an OpenCV window named `p`.

    On Linux the window is created as resizable and sized to the image the first time a name is seen.
    The wait after drawing is 300 ms for image datasets and 1 ms otherwise.
    """
    frame = self.plotted_img
    on_linux = platform.system() == "Linux"
    if on_linux and p not in self.windows:
        self.windows.append(p)
        height, width = frame.shape[:2]
        cv2.namedWindow(p, cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)  # allow window resize (Linux)
        cv2.resizeWindow(p, width, height)
    cv2.imshow(p, frame)

    delay_ms = 300 if self.dataset.mode == "image" else 1
    cv2.waitKey(delay_ms)

stream_inference(source=None, model=None, *args, **kwargs)

实时推理摄像机画面,并将结果保存到文件中。

源代码 ultralytics/engine/predictor.py
@smart_inference_mode()
def stream_inference(self, source=None, model=None, *args, **kwargs):
    """
    Streams inference over a source, yielding results batch by batch and optionally saving them.

    Args:
        source: Inference source; falls back to `args.source` when None.
        model: Model passed to `setup_model` when no model has been set up yet.
        *args: Extra positional arguments forwarded to `inference`.
        **kwargs: Extra keyword arguments forwarded to `inference`.

    Yields:
        One item per image from `self.results`, or the raw embedding output when `args.embed` is set.
    """
    if self.args.verbose:
        LOGGER.info("")

    # Setup model
    if not self.model:
        self.setup_model(model)

    try:
        with self._lock:  # for thread-safe inference
            # Setup source every time predict is called
            self.setup_source(source if source is not None else self.args.source)

            # Check if save_dir/ label file exists
            if self.args.save or self.args.save_txt:
                (self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)

            # Warmup model
            if not self.done_warmup:
                self.model.warmup(imgsz=(1 if self.model.pt or self.model.triton else self.dataset.bs, 3, *self.imgsz))
                self.done_warmup = True

            self.seen, self.windows, self.batch = 0, [], None
            profilers = (
                ops.Profile(device=self.device),
                ops.Profile(device=self.device),
                ops.Profile(device=self.device),
            )
            self.run_callbacks("on_predict_start")
            for self.batch in self.dataset:
                self.run_callbacks("on_predict_batch_start")
                paths, im0s, s = self.batch

                # Preprocess
                with profilers[0]:
                    im = self.preprocess(im0s)

                # Inference
                with profilers[1]:
                    preds = self.inference(im, *args, **kwargs)
                    if self.args.embed:
                        yield from [preds] if isinstance(preds, torch.Tensor) else preds  # yield embedding tensors
                        continue

                # Postprocess
                with profilers[2]:
                    self.results = self.postprocess(preds, im, im0s)
                self.run_callbacks("on_predict_postprocess_end")

                # Visualize, save, write results
                n = len(im0s)
                for i in range(n):
                    self.seen += 1
                    self.results[i].speed = {
                        "preprocess": profilers[0].dt * 1e3 / n,
                        "inference": profilers[1].dt * 1e3 / n,
                        "postprocess": profilers[2].dt * 1e3 / n,
                    }
                    if self.args.verbose or self.args.save or self.args.save_txt or self.args.show:
                        s[i] += self.write_results(i, Path(paths[i]), im, s)

                # Print batch results
                if self.args.verbose:
                    LOGGER.info("\n".join(s))

                self.run_callbacks("on_predict_batch_end")
                yield from self.results
    finally:
        # Release assets even when the consumer stops iterating early or an error is raised,
        # so that partially written video files are always finalized
        for v in self.vid_writer.values():
            if isinstance(v, cv2.VideoWriter):
                v.release()

    # Print final results
    if self.args.verbose and self.seen:
        t = tuple(x.t / self.seen * 1e3 for x in profilers)  # speeds per image
        LOGGER.info(
            f"Speed: %.1fms preprocess, %.1fms inference, %.1fms postprocess per image at shape "
            f"{(min(self.args.batch, self.seen), 3, *im.shape[2:])}" % t
        )
    if self.args.save or self.args.save_txt or self.args.save_crop:
        nl = len(list(self.save_dir.glob("labels/*.txt")))  # number of labels
        s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else ""
        LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}")
    self.run_callbacks("on_predict_end")

write_results(i, p, im, s)

将推理结果写入文件或目录。

源代码 ultralytics/engine/predictor.py
def write_results(self, i, p, im, s):
    """
    Build the log line for one image of the batch and run the configured plot/save/show actions.

    Args:
        i (int): Index of the image inside the current batch.
        p (Path): Source path of the image.
        im: Preprocessed batch; a 3-dim input gets a batch dimension prepended.
        s (list): Per-image log strings of the current batch, searched for a "frame N/" marker.

    Returns:
        (str): Optional "<i>: " prefix, then image size, result summary and inference time in ms.
    """
    if len(im.shape) == 3:
        im = im[None]  # expand for batch dim

    kind = self.source_type
    if kind.stream or kind.from_img or kind.tensor:  # batch_size >= 1
        prefix = f"{i}: "
        frame = self.dataset.count
    else:
        prefix = ""
        found = re.search(r"frame (\d+)/", s[i])
        frame = int(found[1]) if found else None  # None when no frame marker is present

    # Non-image datasets get the frame number appended to the label file stem
    stem_suffix = "" if self.dataset.mode == "image" else f"_{frame}"
    self.txt_path = self.save_dir / "labels" / (p.stem + stem_suffix)

    size_text = "%gx%g " % im.shape[2:]
    result = self.results[i]
    result.save_dir = str(self.save_dir)  # used in other locations
    string = f"{prefix}{size_text}{result.verbose()}{result.speed['inference']:.1f}ms"

    opts = self.args

    # Add predictions to image
    if opts.save or opts.show:
        self.plotted_img = result.plot(
            line_width=opts.line_width,
            boxes=opts.show_boxes,
            conf=opts.show_conf,
            labels=opts.show_labels,
            im_gpu=None if opts.retina_masks else im[i],
        )

    # Save results
    if opts.save_txt:
        result.save_txt(f"{self.txt_path}.txt", save_conf=opts.save_conf)
    if opts.save_crop:
        result.save_crop(save_dir=self.save_dir / "crops", file_name=self.txt_path.stem)
    if opts.show:
        self.show(str(p))
    if opts.save:
        self.save_predicted_images(str(self.save_dir / p.name), frame)

    return string





Created 2023-11-12, Updated 2024-06-02
Authors: glenn-jocher (5), Burhan-Q (1)