
Reference for ultralytics/utils/callbacks/wb.py

Note

This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/utils/callbacks/wb.py. If you spot a problem, please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!


ultralytics.utils.callbacks.wb._custom_table

_custom_table(
    x,
    y,
    classes,
    title="Precision Recall Curve",
    x_title="Recall",
    y_title="Precision",
)

Create and log a custom metric visualization to wandb.plot.pr_curve.

This function crafts a custom metric visualization that mimics the behavior of the default wandb precision-recall curve while allowing for enhanced customization. The visual metric is useful for monitoring model performance across different classes.

Parameters:

x (list): Values for the x-axis; expected to have length N. Required.
y (list): Corresponding values for the y-axis; also expected to have length N. Required.
classes (list): Labels identifying the class of each point; length N. Required.
title (str): Title for the plot. Default: 'Precision Recall Curve'.
x_title (str): Label for the x-axis. Default: 'Recall'.
y_title (str): Label for the y-axis. Default: 'Precision'.

Returns:

Object: A wandb object suitable for logging, showcasing the crafted metric visualization.

Source code in ultralytics/utils/callbacks/wb.py (lines 18-44)
def _custom_table(x, y, classes, title="Precision Recall Curve", x_title="Recall", y_title="Precision"):
    """
    Create and log a custom metric visualization to wandb.plot.pr_curve.

    This function crafts a custom metric visualization that mimics the behavior of the default wandb precision-recall
    curve while allowing for enhanced customization. The visual metric is useful for monitoring model performance across
    different classes.

    Args:
        x (list): Values for the x-axis; expected to have length N.
        y (list): Corresponding values for the y-axis; also expected to have length N.
        classes (list): Labels identifying the class of each point; length N.
        title (str): Title for the plot; defaults to 'Precision Recall Curve'.
        x_title (str): Label for the x-axis; defaults to 'Recall'.
        y_title (str): Label for the y-axis; defaults to 'Precision'.

    Returns:
        (wandb.Object): A wandb object suitable for logging, showcasing the crafted metric visualization.
    """
    import pandas  # scope for faster 'import ultralytics'

    df = pandas.DataFrame({"class": classes, "y": y, "x": x}).round(3)
    fields = {"x": "x", "y": "y", "class": "class"}
    string_fields = {"title": title, "x-axis-title": x_title, "y-axis-title": y_title}
    return wb.plot_table(
        "wandb/area-under-curve/v0", wb.Table(dataframe=df), fields=fields, string_fields=string_fields
    )
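
Example (illustrative sketch with made-up values): the snippet below shows how this helper could be called directly. It assumes the module's wandb import succeeded (W&B installed and enabled in Ultralytics settings) and that a run is active; in practice the callbacks below only call it through _plot_curve.

from ultralytics.utils.callbacks.wb import _custom_table
import wandb

wandb.init(project="demo")  # hypothetical project
x = [0.0, 0.5, 1.0, 0.0, 0.5, 1.0]      # x-axis values, three points per class
y = [1.0, 0.8, 0.2, 0.9, 0.7, 0.3]      # matching y-axis values
classes = ["person"] * 3 + ["car"] * 3  # class label for every point
wandb.log({"curves/PR": _custom_table(x, y, classes)}, commit=False)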





ultralytics.utils.callbacks.wb._plot_curve

_plot_curve(
    x,
    y,
    names=None,
    id="precision-recall",
    title="Precision Recall Curve",
    x_title="Recall",
    y_title="Precision",
    num_x=100,
    only_mean=False,
)

Log a metric curve visualization.

This function generates a metric curve based on input data and logs the visualization to wandb. The curve can represent aggregated data (mean) or individual class data, depending on the 'only_mean' flag.

Parameters:

x (ndarray): Data points for the x-axis with length N. Required.
y (ndarray): Corresponding data points for the y-axis with shape (C, N), where C is the number of classes. Required.
names (list): Names of the classes corresponding to the y-axis data; length C. Default: None.
id (str): Unique identifier for the logged data in wandb. Default: 'precision-recall'.
title (str): Title for the visualization plot. Default: 'Precision Recall Curve'.
x_title (str): Label for the x-axis. Default: 'Recall'.
y_title (str): Label for the y-axis. Default: 'Precision'.
num_x (int): Number of interpolated data points for visualization. Default: 100.
only_mean (bool): Flag to indicate if only the mean curve should be plotted. Default: False.

Notes

The function leverages the '_custom_table' function to generate the actual visualization.

Source code in ultralytics/utils/callbacks/wb.py (lines 47-98)
def _plot_curve(
    x,
    y,
    names=None,
    id="precision-recall",
    title="Precision Recall Curve",
    x_title="Recall",
    y_title="Precision",
    num_x=100,
    only_mean=False,
):
    """
    Log a metric curve visualization.

    This function generates a metric curve based on input data and logs the visualization to wandb.
    The curve can represent aggregated data (mean) or individual class data, depending on the 'only_mean' flag.

    Args:
        x (np.ndarray): Data points for the x-axis with length N.
        y (np.ndarray): Corresponding data points for the y-axis with shape (C, N), where C is the number of classes.
        names (list): Names of the classes corresponding to the y-axis data; length C.
        id (str): Unique identifier for the logged data in wandb.
        title (str): Title for the visualization plot.
        x_title (str): Label for the x-axis.
        y_title (str): Label for the y-axis.
        num_x (int): Number of interpolated data points for visualization.
        only_mean (bool): Flag to indicate if only the mean curve should be plotted.

    Notes:
        The function leverages the '_custom_table' function to generate the actual visualization.
    """
    import numpy as np

    # Create new x
    if names is None:
        names = []
    x_new = np.linspace(x[0], x[-1], num_x).round(5)

    # Create arrays for logging
    x_log = x_new.tolist()
    y_log = np.interp(x_new, x, np.mean(y, axis=0)).round(3).tolist()

    if only_mean:
        table = wb.Table(data=list(zip(x_log, y_log)), columns=[x_title, y_title])
        wb.run.log({title: wb.plot.line(table, x_title, y_title, title=title)})
    else:
        classes = ["mean"] * len(x_log)
        for i, yi in enumerate(y):
            x_log.extend(x_new)  # add new x
            y_log.extend(np.interp(x_new, x, yi))  # interpolate y to new x
            classes.extend([names[i]] * len(x_new))  # add class names
        wb.log({id: _custom_table(x_log, y_log, classes, title, x_title, y_title)}, commit=False)
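
Example (illustrative sketch with synthetic arrays): class names, curve ids, and shapes below are made up, and an active wandb run is assumed because the function logs through the module-level wb handle.

from ultralytics.utils.callbacks.wb import _plot_curve
import numpy as np
import wandb

wandb.init(project="demo")                  # hypothetical project
x = np.linspace(0.0, 1.0, 1000)             # x-axis (e.g. recall), length N
y = np.stack([np.linspace(1.0, 0.2, 1000),  # y-axis per class, shape (C, N)
              np.linspace(0.9, 0.3, 1000)])
_plot_curve(x, y, names=["person", "car"], id="curves/PR(B)", title="Precision-Recall(B)")
_plot_curve(x, y, only_mean=True)           # or log only the mean curve as a simple line plot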





ultralytics.utils.callbacks.wb._log_plots

_log_plots(plots, step)

Log plots to WandB at a specific step if they haven't been logged already.

This function checks each plot in the input dictionary against previously processed plots and logs new or updated plots to WandB at the specified step.

Parameters:

plots (dict): Dictionary of plots to log, where keys are plot names and values are dictionaries containing plot metadata, including timestamps. Required.
step (int): The step/epoch at which to log the plots in the WandB run. Required.

Notes
  • The function uses a shallow copy of the plots dictionary to prevent modification during iteration
  • Plots are identified by their stem name (filename without extension)
  • Each plot is logged as a WandB Image object
Source code in ultralytics/utils/callbacks/wb.py (lines 101-122)
def _log_plots(plots, step):
    """
    Log plots to WandB at a specific step if they haven't been logged already.

    This function checks each plot in the input dictionary against previously processed plots and logs
    new or updated plots to WandB at the specified step.

    Args:
        plots (dict): Dictionary of plots to log, where keys are plot names and values are dictionaries
            containing plot metadata including timestamps.
        step (int): The step/epoch at which to log the plots in the WandB run.

    Notes:
        - The function uses a shallow copy of the plots dictionary to prevent modification during iteration
        - Plots are identified by their stem name (filename without extension)
        - Each plot is logged as a WandB Image object
    """
    for name, params in plots.copy().items():  # shallow copy to prevent plots dict changing during iteration
        timestamp = params["timestamp"]
        if _processed_plots.get(name) != timestamp:
            wb.run.log({name.stem: wb.Image(str(name))}, step=step)
            _processed_plots[name] = timestamp
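
Example (illustrative sketch): the path and timestamp below are hypothetical. Keys must behave like pathlib.Path objects because name.stem is used, each value needs a 'timestamp' entry, an active wandb run is assumed, and the image file must exist on disk for wb.Image to load it.

import time
from pathlib import Path
from ultralytics.utils.callbacks.wb import _log_plots

plots = {Path("runs/detect/train/results.png"): {"timestamp": time.time()}}
_log_plots(plots, step=5)  # logs the image once
_log_plots(plots, step=6)  # skipped: same timestamp already recorded in _processed_plots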





ultralytics.utils.callbacks.wb.on_pretrain_routine_start

on_pretrain_routine_start(trainer)

Initiate and start wandb project if module is present.

Source code in ultralytics/utils/callbacks/wb.py (lines 125-132)
def on_pretrain_routine_start(trainer):
    """Initiate and start wandb project if module is present."""
    if not wb.run:
        wb.init(
            project=str(trainer.args.project).replace("/", "-") if trainer.args.project else "Ultralytics",
            name=str(trainer.args.name).replace("/", "-"),
            config=vars(trainer.args),
        )
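
These callbacks are invoked by the Ultralytics trainer rather than called directly. An illustrative end-to-end sketch of enabling the integration follows (model, dataset, and project names are placeholders; assumes the documented Ultralytics settings API):

from ultralytics import YOLO, settings

settings.update({"wandb": True})  # enable the W&B integration so these callbacks are registered
model = YOLO("yolo11n.pt")
# Any "/" in project or name is replaced with "-" for the W&B project and run names
model.train(data="coco8.yaml", epochs=3, project="my-project", name="exp1")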





ultralytics.utils.callbacks.wb.on_fit_epoch_end

on_fit_epoch_end(trainer)

Log training metrics and model information at the end of an epoch.

Source code in ultralytics/utils/callbacks/wb.py (lines 135-141)
def on_fit_epoch_end(trainer):
    """Log training metrics and model information at the end of an epoch."""
    wb.run.log(trainer.metrics, step=trainer.epoch + 1)
    _log_plots(trainer.plots, step=trainer.epoch + 1)
    _log_plots(trainer.validator.plots, step=trainer.epoch + 1)
    if trainer.epoch == 0:
        wb.run.log(model_info_for_loggers(trainer), step=trainer.epoch + 1)
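
Illustrative sketch of the trainer attributes this callback reads (values are made up; an active wandb run and a successful module-level wandb import are assumed; epoch > 0 avoids the model_info_for_loggers call):

from types import SimpleNamespace
from ultralytics.utils.callbacks.wb import on_fit_epoch_end

# Hypothetical stand-in exposing only the attributes used above
trainer = SimpleNamespace(
    epoch=2,
    metrics={"metrics/mAP50(B)": 0.52, "metrics/mAP50-95(B)": 0.31},  # hypothetical values
    plots={},
    validator=SimpleNamespace(plots={}),
)
on_fit_epoch_end(trainer)  # logs the metrics dict at step=trainer.epoch + 1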





ultralytics.utils.callbacks.wb.on_train_epoch_end

on_train_epoch_end(trainer)

Log metrics and save images at the end of each training epoch.

Source code in ultralytics/utils/callbacks/wb.py (lines 144-149)
def on_train_epoch_end(trainer):
    """Log metrics and save images at the end of each training epoch."""
    wb.run.log(trainer.label_loss_items(trainer.tloss, prefix="train"), step=trainer.epoch + 1)
    wb.run.log(trainer.lr, step=trainer.epoch + 1)
    if trainer.epoch == 1:
        _log_plots(trainer.plots, step=trainer.epoch + 1)
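
Illustrative sketch of the trainer attributes this callback reads (values are made up; an active wandb run is assumed):

from types import SimpleNamespace
from ultralytics.utils.callbacks.wb import on_train_epoch_end

# Hypothetical stand-in exposing only the attributes used above
trainer = SimpleNamespace(
    epoch=0,
    tloss=None,  # passed straight through to label_loss_items
    lr={"lr/pg0": 0.01, "lr/pg1": 0.01, "lr/pg2": 0.01},
    plots={},
    label_loss_items=lambda tloss, prefix="train": {f"{prefix}/box_loss": 0.8, f"{prefix}/cls_loss": 0.5},
)
on_train_epoch_end(trainer)  # logs train losses and learning rates at step=1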





ultralytics.utils.callbacks.wb.on_train_end

on_train_end(trainer)

Save the best model as an artifact and log final plots at the end of training.

Source code in ultralytics/utils/callbacks/wb.py (lines 152-173)
def on_train_end(trainer):
    """Save the best model as an artifact and log final plots at the end of training."""
    _log_plots(trainer.validator.plots, step=trainer.epoch + 1)
    _log_plots(trainer.plots, step=trainer.epoch + 1)
    art = wb.Artifact(type="model", name=f"run_{wb.run.id}_model")
    if trainer.best.exists():
        art.add_file(trainer.best)
        wb.run.log_artifact(art, aliases=["best"])
    # Check if we actually have plots to save
    if trainer.args.plots and hasattr(trainer.validator.metrics, "curves_results"):
        for curve_name, curve_values in zip(trainer.validator.metrics.curves, trainer.validator.metrics.curves_results):
            x, y, x_title, y_title = curve_values
            _plot_curve(
                x,
                y,
                names=list(trainer.validator.metrics.names.values()),
                id=f"curves/{curve_name}",
                title=curve_name,
                x_title=x_title,
                y_title=y_title,
            )
    wb.run.finish()  # required or run continues on dashboard
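
After training, the logged model artifact can be retrieved with the public W&B API; a sketch follows (entity, project, and run id are placeholders, while the artifact name and the 'best' alias follow the code above):

import wandb

api = wandb.Api()
artifact = api.artifact("my-entity/my-project/run_<RUN_ID>_model:best")  # placeholder path
artifact.download()  # downloads the best checkpoint file added by the callback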




