Bỏ để qua phần nội dung

Tài liệu tham khảo cho ultralytics/utils/callbacks/wb.py

Ghi

Tệp này có sẵn tại https://github.com/ultralytics/ultralytics/blob/main/ultralytics/utils/callbacks/wb.py. Nếu bạn phát hiện ra một vấn đề, vui lòng giúp khắc phục nó bằng cách đóng góp Yêu cầu 🛠️ kéo. Cảm ơn bạn 🙏 !



ultralytics.utils.callbacks.wb._custom_table(x, y, classes, title='Precision Recall Curve', x_title='Recall', y_title='Precision')

Tạo và ghi nhật ký trực quan hóa số liệu tùy chỉnh để wandb.plot.pr_curve.

This function crafts a custom metric visualization that mimics the behavior of the default wandb precision-recall curve while allowing for enhanced customization. The visual metric is useful for monitoring model performance across different classes.

Thông số:

Tên Kiểu Sự miêu tả Mặc định
x List

Giá trị cho trục x; dự kiến có chiều dài N.

bắt buộc
y List

Các giá trị tương ứng cho trục y; cũng dự kiến có chiều dài N.

bắt buộc
classes List

Nhãn xác định lớp của từng điểm; chiều dài N.

bắt buộc
title str

Tiêu đề cho cốt truyện; mặc định là 'Đường cong thu hồi chính xác'.

'Precision Recall Curve'
x_title str

Nhãn cho trục x; mặc định là "Thu hồi".

'Recall'
y_title str

Nhãn cho trục y; mặc định là "Độ chính xác".

'Precision'

Trở lại:

Kiểu Sự miêu tả
Object

Một đối tượng đũa phép thích hợp để ghi nhật ký, thể hiện trực quan hóa số liệu được chế tạo.

Mã nguồn trong ultralytics/utils/callbacks/wb.py
def _custom_table(x, y, classes, title="Precision Recall Curve", x_title="Recall", y_title="Precision"):
    """
    Create and log a custom metric visualization to wandb.plot.pr_curve.

    This function crafts a custom metric visualization that mimics the behavior of the default wandb precision-recall
    curve while allowing for enhanced customization. The visual metric is useful for monitoring model performance across
    different classes.

    Args:
        x (List): Values for the x-axis; expected to have length N.
        y (List): Corresponding values for the y-axis; also expected to have length N.
        classes (List): Labels identifying the class of each point; length N.
        title (str, optional): Title for the plot; defaults to 'Precision Recall Curve'.
        x_title (str, optional): Label for the x-axis; defaults to 'Recall'.
        y_title (str, optional): Label for the y-axis; defaults to 'Precision'.

    Returns:
        (wandb.Object): A wandb object suitable for logging, showcasing the crafted metric visualization.
    """
    import pandas  # scope for faster 'import ultralytics'

    df = pandas.DataFrame({"class": classes, "y": y, "x": x}).round(3)
    fields = {"x": "x", "y": "y", "class": "class"}
    string_fields = {"title": title, "x-axis-title": x_title, "y-axis-title": y_title}
    return wb.plot_table(
        "wandb/area-under-curve/v0", wb.Table(dataframe=df), fields=fields, string_fields=string_fields
    )



ultralytics.utils.callbacks.wb._plot_curve(x, y, names=None, id='precision-recall', title='Precision Recall Curve', x_title='Recall', y_title='Precision', num_x=100, only_mean=False)

Ghi nhật ký trực quan hóa đường cong số liệu.

Hàm này tạo ra một đường cong số liệu dựa trên dữ liệu đầu vào và ghi nhật ký trực quan hóa thành wandb. Đường cong có thể đại diện cho dữ liệu tổng hợp (trung bình) hoặc dữ liệu lớp riêng lẻ, tùy thuộc vào cờ 'only_mean'.

Thông số:

Tên Kiểu Sự miêu tả Mặc định
x ndarray

Các điểm dữ liệu cho trục x có độ dài N.

bắt buộc
y ndarray

Các điểm dữ liệu tương ứng cho trục y có hình CxN, trong đó C là số lớp.

bắt buộc
names list

Tên của các lớp tương ứng với dữ liệu trục y; chiều dài C. Mặc định là [].

None
id str

Mã định danh duy nhất cho dữ liệu đã đăng nhập trong wandb. Mặc định là "thu hồi chính xác".

'precision-recall'
title str

Tiêu đề cho cốt truyện trực quan. Mặc định là 'Đường cong thu hồi chính xác'.

'Precision Recall Curve'
x_title str

Nhãn cho trục x. Mặc định là "Thu hồi".

'Recall'
y_title str

Nhãn cho trục y. Mặc định là 'Độ chính xác'.

'Precision'
num_x int

Số điểm dữ liệu nội suy để trực quan hóa. Mặc định là 100.

100
only_mean bool

Gắn cờ để cho biết nếu chỉ nên vẽ đường cong trung bình. Mặc định là True.

False
Ghi

Chức năng này tận dụng chức năng '_custom_table' để tạo ra hình ảnh trực quan thực tế.

Mã nguồn trong ultralytics/utils/callbacks/wb.py
def _plot_curve(
    x,
    y,
    names=None,
    id="precision-recall",
    title="Precision Recall Curve",
    x_title="Recall",
    y_title="Precision",
    num_x=100,
    only_mean=False,
):
    """
    Log a metric curve visualization.

    This function generates a metric curve based on input data and logs the visualization to wandb.
    The curve can represent aggregated data (mean) or individual class data, depending on the 'only_mean' flag.

    Args:
        x (np.ndarray): Data points for the x-axis with length N.
        y (np.ndarray): Corresponding data points for the y-axis with shape CxN, where C is the number of classes.
        names (list, optional): Names of the classes corresponding to the y-axis data; length C. Defaults to [].
        id (str, optional): Unique identifier for the logged data in wandb. Defaults to 'precision-recall'.
        title (str, optional): Title for the visualization plot. Defaults to 'Precision Recall Curve'.
        x_title (str, optional): Label for the x-axis. Defaults to 'Recall'.
        y_title (str, optional): Label for the y-axis. Defaults to 'Precision'.
        num_x (int, optional): Number of interpolated data points for visualization. Defaults to 100.
        only_mean (bool, optional): Flag to indicate if only the mean curve should be plotted. Defaults to True.

    Note:
        The function leverages the '_custom_table' function to generate the actual visualization.
    """
    import numpy as np

    # Create new x
    if names is None:
        names = []
    x_new = np.linspace(x[0], x[-1], num_x).round(5)

    # Create arrays for logging
    x_log = x_new.tolist()
    y_log = np.interp(x_new, x, np.mean(y, axis=0)).round(3).tolist()

    if only_mean:
        table = wb.Table(data=list(zip(x_log, y_log)), columns=[x_title, y_title])
        wb.run.log({title: wb.plot.line(table, x_title, y_title, title=title)})
    else:
        classes = ["mean"] * len(x_log)
        for i, yi in enumerate(y):
            x_log.extend(x_new)  # add new x
            y_log.extend(np.interp(x_new, x, yi))  # interpolate y to new x
            classes.extend([names[i]] * len(x_new))  # add class names
        wb.log({id: _custom_table(x_log, y_log, classes, title, x_title, y_title)}, commit=False)



ultralytics.utils.callbacks.wb._log_plots(plots, step)

Ghi nhật ký các biểu đồ từ từ điển nhập liệu nếu chúng chưa được ghi lại ở bước được chỉ định.

Mã nguồn trong ultralytics/utils/callbacks/wb.py
def _log_plots(plots, step):
    """Logs plots from the input dictionary if they haven't been logged already at the specified step."""
    for name, params in plots.copy().items():  # shallow copy to prevent plots dict changing during iteration
        timestamp = params["timestamp"]
        if _processed_plots.get(name) != timestamp:
            wb.run.log({name.stem: wb.Image(str(name))}, step=step)
            _processed_plots[name] = timestamp



ultralytics.utils.callbacks.wb.on_pretrain_routine_start(trainer)

Bắt đầu và bắt đầu dự án nếu có mô-đun.

Mã nguồn trong ultralytics/utils/callbacks/wb.py
def on_pretrain_routine_start(trainer):
    """Initiate and start project if module is present."""
    wb.run or wb.init(project=trainer.args.project or "YOLOv8", name=trainer.args.name, config=vars(trainer.args))



ultralytics.utils.callbacks.wb.on_fit_epoch_end(trainer)

Ghi nhật ký số liệu đào tạo và thông tin mô hình vào cuối kỷ nguyên.

Mã nguồn trong ultralytics/utils/callbacks/wb.py
def on_fit_epoch_end(trainer):
    """Logs training metrics and model information at the end of an epoch."""
    wb.run.log(trainer.metrics, step=trainer.epoch + 1)
    _log_plots(trainer.plots, step=trainer.epoch + 1)
    _log_plots(trainer.validator.plots, step=trainer.epoch + 1)
    if trainer.epoch == 0:
        wb.run.log(model_info_for_loggers(trainer), step=trainer.epoch + 1)



ultralytics.utils.callbacks.wb.on_train_epoch_end(trainer)

Ghi nhật ký số liệu và lưu hình ảnh vào cuối mỗi kỷ nguyên đào tạo.

Mã nguồn trong ultralytics/utils/callbacks/wb.py
def on_train_epoch_end(trainer):
    """Log metrics and save images at the end of each training epoch."""
    wb.run.log(trainer.label_loss_items(trainer.tloss, prefix="train"), step=trainer.epoch + 1)
    wb.run.log(trainer.lr, step=trainer.epoch + 1)
    if trainer.epoch == 1:
        _log_plots(trainer.plots, step=trainer.epoch + 1)



ultralytics.utils.callbacks.wb.on_train_end(trainer)

Lưu mô hình tốt nhất làm hiện vật khi kết thúc đào tạo.

Mã nguồn trong ultralytics/utils/callbacks/wb.py
def on_train_end(trainer):
    """Save the best model as an artifact at end of training."""
    _log_plots(trainer.validator.plots, step=trainer.epoch + 1)
    _log_plots(trainer.plots, step=trainer.epoch + 1)
    art = wb.Artifact(type="model", name=f"run_{wb.run.id}_model")
    if trainer.best.exists():
        art.add_file(trainer.best)
        wb.run.log_artifact(art, aliases=["best"])
    for curve_name, curve_values in zip(trainer.validator.metrics.curves, trainer.validator.metrics.curves_results):
        x, y, x_title, y_title = curve_values
        _plot_curve(
            x,
            y,
            names=list(trainer.validator.metrics.names.values()),
            id=f"curves/{curve_name}",
            title=curve_name,
            x_title=x_title,
            y_title=y_title,
        )
    wb.run.finish()  # required or run continues on dashboard





Created 2023-11-12, Updated 2024-06-02
Authors: glenn-jocher (5), Burhan-Q (1), Laughing-q (1)