Zum Inhalt springen

Referenz fĂŒr ultralytics/utils/callbacks/wb.py

Hinweis

Diese Datei ist verfĂŒgbar unter https://github.com/ultralytics/ ultralytics/blob/main/ ultralytics/utils/callbacks/wb .py. Wenn du ein Problem entdeckst, hilf bitte mit, es zu beheben, indem du einen Pull Request đŸ› ïž einreichst. Vielen Dank 🙏!



ultralytics.utils.callbacks.wb._custom_table(x, y, classes, title='Precision Recall Curve', x_title='Recall', y_title='Precision')

Erstelle und protokolliere eine benutzerdefinierte metrische Visualisierung in wandb.plot.pr_curve.

Diese Funktion erstellt eine benutzerdefinierte metrische Visualisierung, die das Verhalten der Standard-PrĂ€zisions-RĂŒckrufkurve der wandb nachahmt Kurve nachempfunden ist und gleichzeitig eine bessere Anpassung ermöglicht. Die visuelle Metrik ist nĂŒtzlich, um die Leistung des Modells ĂŒber verschiedenen Klassen.

Parameter:

Name Typ Beschreibung Standard
x List

Werte fĂŒr die x-Achse; erwartet wird die LĂ€nge N.

erforderlich
y List

Entsprechende Werte fĂŒr die y-Achse; auch hier wird die LĂ€nge N erwartet.

erforderlich
classes List

Etiketten, die die Klasse jedes Punktes angeben; LĂ€nge N.

erforderlich
title str

Titel fĂŒr das Diagramm; Standardwert ist "Precision Recall Curve".

'Precision Recall Curve'
x_title str

Beschriftung fĂŒr die x-Achse; Standardwert ist "Recall".

'Recall'
y_title str

Beschriftung fĂŒr die y-Achse; Standardwert ist "PrĂ€zision".

'Precision'

Retouren:

Typ Beschreibung
Object

Ein wandb-Objekt, das fĂŒr die Protokollierung geeignet ist und die erstellte Metrik visualisiert.

Quellcode in ultralytics/utils/callbacks/wb.py
def _custom_table(x, y, classes, title="Precision Recall Curve", x_title="Recall", y_title="Precision"):
    """
    Create and log a custom metric visualization to wandb.plot.pr_curve.

    This function crafts a custom metric visualization that mimics the behavior of wandb's default precision-recall
    curve while allowing for enhanced customization. The visual metric is useful for monitoring model performance across
    different classes.

    Args:
        x (List): Values for the x-axis; expected to have length N.
        y (List): Corresponding values for the y-axis; also expected to have length N.
        classes (List): Labels identifying the class of each point; length N.
        title (str, optional): Title for the plot; defaults to 'Precision Recall Curve'.
        x_title (str, optional): Label for the x-axis; defaults to 'Recall'.
        y_title (str, optional): Label for the y-axis; defaults to 'Precision'.

    Returns:
        (wandb.Object): A wandb object suitable for logging, showcasing the crafted metric visualization.
    """
    df = pd.DataFrame({"class": classes, "y": y, "x": x}).round(3)
    fields = {"x": "x", "y": "y", "class": "class"}
    string_fields = {"title": title, "x-axis-title": x_title, "y-axis-title": y_title}
    return wb.plot_table(
        "wandb/area-under-curve/v0", wb.Table(dataframe=df), fields=fields, string_fields=string_fields
    )



ultralytics.utils.callbacks.wb._plot_curve(x, y, names=None, id='precision-recall', title='Precision Recall Curve', x_title='Recall', y_title='Precision', num_x=100, only_mean=False)

Protokolliere eine metrische Kurvendarstellung.

Diese Funktion erzeugt eine metrische Kurve aus den Eingabedaten und protokolliert die Visualisierung in der wandb. Die Kurve kann aggregierte Daten (Mittelwert) oder einzelne Klassendaten darstellen, je nachdem, ob das Flag "only_mean" gesetzt ist.

Parameter:

Name Typ Beschreibung Standard
x ndarray

Datenpunkte auf der x-Achse mit der LĂ€nge N.

erforderlich
y ndarray

Entsprechende Datenpunkte auf der y-Achse mit der Form CxN, wobei C die Anzahl der Klassen ist.

erforderlich
names list

Namen der Klassen, die den Daten auf der y-Achse entsprechen; LĂ€nge C. Standardwert ist [].

None
id str

Eindeutiger Bezeichner fĂŒr die protokollierten Daten in der wandb. Der Standardwert ist "precision-recall".

'precision-recall'
title str

Titel fĂŒr die Visualisierungsgrafik. Der Standardwert ist "Precision Recall Curve".

'Precision Recall Curve'
x_title str

Beschriftung fĂŒr die x-Achse. Der Standardwert ist "RĂŒckruf".

'Recall'
y_title str

Beschriftung fĂŒr die y-Achse. Der Standardwert ist "PrĂ€zision".

'Precision'
num_x int

Anzahl der interpolierten Datenpunkte fĂŒr die Visualisierung. Der Standardwert ist 100.

100
only_mean bool

Flagge, die angibt, ob nur die mittlere Kurve gezeichnet werden soll. Der Standardwert ist True.

False
Hinweis

Die Funktion nutzt die Funktion "_custom_table", um die eigentliche Visualisierung zu erstellen.

Quellcode in ultralytics/utils/callbacks/wb.py
def _plot_curve(
    x,
    y,
    names=None,
    id="precision-recall",
    title="Precision Recall Curve",
    x_title="Recall",
    y_title="Precision",
    num_x=100,
    only_mean=False,
):
    """
    Log a metric curve visualization.

    This function generates a metric curve based on input data and logs the visualization to wandb.
    The curve can represent aggregated data (mean) or individual class data, depending on the 'only_mean' flag.

    Args:
        x (np.ndarray): Data points for the x-axis with length N.
        y (np.ndarray): Corresponding data points for the y-axis with shape CxN, where C is the number of classes.
        names (list, optional): Names of the classes corresponding to the y-axis data; length C. Defaults to [].
        id (str, optional): Unique identifier for the logged data in wandb. Defaults to 'precision-recall'.
        title (str, optional): Title for the visualization plot. Defaults to 'Precision Recall Curve'.
        x_title (str, optional): Label for the x-axis. Defaults to 'Recall'.
        y_title (str, optional): Label for the y-axis. Defaults to 'Precision'.
        num_x (int, optional): Number of interpolated data points for visualization. Defaults to 100.
        only_mean (bool, optional): Flag to indicate if only the mean curve should be plotted. Defaults to True.

    Note:
        The function leverages the '_custom_table' function to generate the actual visualization.
    """
    # Create new x
    if names is None:
        names = []
    x_new = np.linspace(x[0], x[-1], num_x).round(5)

    # Create arrays for logging
    x_log = x_new.tolist()
    y_log = np.interp(x_new, x, np.mean(y, axis=0)).round(3).tolist()

    if only_mean:
        table = wb.Table(data=list(zip(x_log, y_log)), columns=[x_title, y_title])
        wb.run.log({title: wb.plot.line(table, x_title, y_title, title=title)})
    else:
        classes = ["mean"] * len(x_log)
        for i, yi in enumerate(y):
            x_log.extend(x_new)  # add new x
            y_log.extend(np.interp(x_new, x, yi))  # interpolate y to new x
            classes.extend([names[i]] * len(x_new))  # add class names
        wb.log({id: _custom_table(x_log, y_log, classes, title, x_title, y_title)}, commit=False)



ultralytics.utils.callbacks.wb._log_plots(plots, step)

Protokolliert Plots aus dem Eingabewörterbuch, wenn sie nicht bereits im angegebenen Schritt protokolliert wurden.

Quellcode in ultralytics/utils/callbacks/wb.py
def _log_plots(plots, step):
    """Logs plots from the input dictionary if they haven't been logged already at the specified step."""
    for name, params in plots.items():
        timestamp = params["timestamp"]
        if _processed_plots.get(name) != timestamp:
            wb.run.log({name.stem: wb.Image(str(name))}, step=step)
            _processed_plots[name] = timestamp



ultralytics.utils.callbacks.wb.on_pretrain_routine_start(trainer)

Initiiere und starte das Projekt, wenn das Modul vorhanden ist.

Quellcode in ultralytics/utils/callbacks/wb.py
def on_pretrain_routine_start(trainer):
    """Initiate and start project if module is present."""
    wb.run or wb.init(project=trainer.args.project or "YOLOv8", name=trainer.args.name, config=vars(trainer.args))



ultralytics.utils.callbacks.wb.on_fit_epoch_end(trainer)

Protokolliert Trainingsmetriken und Modellinformationen am Ende einer Epoche.

Quellcode in ultralytics/utils/callbacks/wb.py
def on_fit_epoch_end(trainer):
    """Logs training metrics and model information at the end of an epoch."""
    wb.run.log(trainer.metrics, step=trainer.epoch + 1)
    _log_plots(trainer.plots, step=trainer.epoch + 1)
    _log_plots(trainer.validator.plots, step=trainer.epoch + 1)
    if trainer.epoch == 0:
        wb.run.log(model_info_for_loggers(trainer), step=trainer.epoch + 1)



ultralytics.utils.callbacks.wb.on_train_epoch_end(trainer)

Protokolliere Metriken und speichere Bilder am Ende jeder Trainingsepoche.

Quellcode in ultralytics/utils/callbacks/wb.py
def on_train_epoch_end(trainer):
    """Log metrics and save images at the end of each training epoch."""
    wb.run.log(trainer.label_loss_items(trainer.tloss, prefix="train"), step=trainer.epoch + 1)
    wb.run.log(trainer.lr, step=trainer.epoch + 1)
    if trainer.epoch == 1:
        _log_plots(trainer.plots, step=trainer.epoch + 1)



ultralytics.utils.callbacks.wb.on_train_end(trainer)

Speichere das beste Modell als Artefakt am Ende des Trainings.

Quellcode in ultralytics/utils/callbacks/wb.py
def on_train_end(trainer):
    """Save the best model as an artifact at end of training."""
    _log_plots(trainer.validator.plots, step=trainer.epoch + 1)
    _log_plots(trainer.plots, step=trainer.epoch + 1)
    art = wb.Artifact(type="model", name=f"run_{wb.run.id}_model")
    if trainer.best.exists():
        art.add_file(trainer.best)
        wb.run.log_artifact(art, aliases=["best"])
    for curve_name, curve_values in zip(trainer.validator.metrics.curves, trainer.validator.metrics.curves_results):
        x, y, x_title, y_title = curve_values
        _plot_curve(
            x,
            y,
            names=list(trainer.validator.metrics.names.values()),
            id=f"curves/{curve_name}",
            title=curve_name,
            x_title=x_title,
            y_title=y_title,
        )
    wb.run.finish()  # required or run continues on dashboard





Erstellt am 2023-11-12, Aktualisiert am 2023-11-25
Autoren: glenn-jocher (3), Laughing-q (1)