Skip to content

Référence pour ultralytics/utils/callbacks/wb.py

Note

Ce fichier est disponible à l'adresse https://github.com/ultralytics/ ultralytics/blob/main/ ultralytics/utils/callbacks/wb .py. Si tu repères un problème, aide à le corriger en contribuant à une Pull Request 🛠️. Merci 🙏 !



ultralytics.utils.callbacks.wb._custom_table(x, y, classes, title='Precision Recall Curve', x_title='Recall', y_title='Precision')

Crée et enregistre une visualisation métrique personnalisée dans wandb.plot.pr_curve.

This function crafts a custom metric visualization that mimics the behavior of the default wandb precision-recall curve while allowing for enhanced customization. The visual metric is useful for monitoring model performance across different classes.

Paramètres :

Nom Type Description DĂ©faut
x List

Valeurs pour l'axe des x ; on s'attend Ă  ce qu'elles aient une longueur N.

requis
y List

Valeurs correspondantes pour l'axe des ordonnées ; on s'attend également à ce qu'elles aient une longueur N.

requis
classes List

Étiquettes identifiant la classe de chaque point ; longueur N.

requis
title str

Titre du graphique ; la valeur par défaut est "Courbe de précision et de rappel".

'Precision Recall Curve'
x_title str

Étiquette pour l'axe des x ; la valeur par défaut est "Rappel".

'Recall'
y_title str

Étiquette pour l'axe des ordonnées ; la valeur par défaut est "Précision".

'Precision'

Retourne :

Type Description
Object

Un objet wandb adapté à la journalisation, présentant la visualisation de la métrique élaborée.

Code source dans ultralytics/utils/callbacks/wb.py
def _custom_table(x, y, classes, title="Precision Recall Curve", x_title="Recall", y_title="Precision"):
    """
    Create and log a custom metric visualization to wandb.plot.pr_curve.

    This function crafts a custom metric visualization that mimics the behavior of the default wandb precision-recall
    curve while allowing for enhanced customization. The visual metric is useful for monitoring model performance across
    different classes.

    Args:
        x (List): Values for the x-axis; expected to have length N.
        y (List): Corresponding values for the y-axis; also expected to have length N.
        classes (List): Labels identifying the class of each point; length N.
        title (str, optional): Title for the plot; defaults to 'Precision Recall Curve'.
        x_title (str, optional): Label for the x-axis; defaults to 'Recall'.
        y_title (str, optional): Label for the y-axis; defaults to 'Precision'.

    Returns:
        (wandb.Object): A wandb object suitable for logging, showcasing the crafted metric visualization.
    """
    import pandas  # scope for faster 'import ultralytics'

    df = pandas.DataFrame({"class": classes, "y": y, "x": x}).round(3)
    fields = {"x": "x", "y": "y", "class": "class"}
    string_fields = {"title": title, "x-axis-title": x_title, "y-axis-title": y_title}
    return wb.plot_table(
        "wandb/area-under-curve/v0", wb.Table(dataframe=df), fields=fields, string_fields=string_fields
    )



ultralytics.utils.callbacks.wb._plot_curve(x, y, names=None, id='precision-recall', title='Precision Recall Curve', x_title='Recall', y_title='Precision', num_x=100, only_mean=False)

Enregistre une visualisation de courbe métrique.

Cette fonction génère une courbe métrique basée sur les données d'entrée et enregistre la visualisation dans wandb. La courbe peut représenter des données agrégées (moyenne) ou des données de classes individuelles, en fonction de l'indicateur 'only_mean'.

Paramètres :

Nom Type Description DĂ©faut
x ndarray

Points de données pour l'axe des x de longueur N.

requis
y ndarray

Points de données correspondants pour l'axe des y avec la forme CxN, où C est le nombre de classes.

requis
names list

Noms des classes correspondant aux données de l'axe des ordonnées ; longueur C. La valeur par défaut est [].

None
id str

Identifiant unique pour les données enregistrées dans wandb. La valeur par défaut est "precision-recall".

'precision-recall'
title str

Titre du graphique de visualisation. La valeur par défaut est "Courbe de précision et de rappel".

'Precision Recall Curve'
x_title str

Étiquette pour l'axe des x. La valeur par défaut est "Rappel".

'Recall'
y_title str

Étiquette pour l'axe des ordonnées. La valeur par défaut est "Précision".

'Precision'
num_x int

Nombre de points de données interpolés pour la visualisation. La valeur par défaut est 100.

100
only_mean bool

Drapeau indiquant si seule la courbe de la moyenne doit être tracée. La valeur par défaut est True.

False
Note

Cette fonction s'appuie sur la fonction '_custom_table' pour générer la visualisation proprement dite.

Code source dans ultralytics/utils/callbacks/wb.py
def _plot_curve(
    x,
    y,
    names=None,
    id="precision-recall",
    title="Precision Recall Curve",
    x_title="Recall",
    y_title="Precision",
    num_x=100,
    only_mean=False,
):
    """
    Log a metric curve visualization.

    This function generates a metric curve based on input data and logs the visualization to wandb.
    The curve can represent aggregated data (mean) or individual class data, depending on the 'only_mean' flag.

    Args:
        x (np.ndarray): Data points for the x-axis with length N.
        y (np.ndarray): Corresponding data points for the y-axis with shape CxN, where C is the number of classes.
        names (list, optional): Names of the classes corresponding to the y-axis data; length C. Defaults to [].
        id (str, optional): Unique identifier for the logged data in wandb. Defaults to 'precision-recall'.
        title (str, optional): Title for the visualization plot. Defaults to 'Precision Recall Curve'.
        x_title (str, optional): Label for the x-axis. Defaults to 'Recall'.
        y_title (str, optional): Label for the y-axis. Defaults to 'Precision'.
        num_x (int, optional): Number of interpolated data points for visualization. Defaults to 100.
        only_mean (bool, optional): Flag to indicate if only the mean curve should be plotted. Defaults to True.

    Note:
        The function leverages the '_custom_table' function to generate the actual visualization.
    """
    import numpy as np

    # Create new x
    if names is None:
        names = []
    x_new = np.linspace(x[0], x[-1], num_x).round(5)

    # Create arrays for logging
    x_log = x_new.tolist()
    y_log = np.interp(x_new, x, np.mean(y, axis=0)).round(3).tolist()

    if only_mean:
        table = wb.Table(data=list(zip(x_log, y_log)), columns=[x_title, y_title])
        wb.run.log({title: wb.plot.line(table, x_title, y_title, title=title)})
    else:
        classes = ["mean"] * len(x_log)
        for i, yi in enumerate(y):
            x_log.extend(x_new)  # add new x
            y_log.extend(np.interp(x_new, x, yi))  # interpolate y to new x
            classes.extend([names[i]] * len(x_new))  # add class names
        wb.log({id: _custom_table(x_log, y_log, classes, title, x_title, y_title)}, commit=False)



ultralytics.utils.callbacks.wb._log_plots(plots, step)

Enregistre les parcelles du dictionnaire d'entrée si elles n'ont pas déjà été enregistrées à l'étape spécifiée.

Code source dans ultralytics/utils/callbacks/wb.py
def _log_plots(plots, step):
    """Logs plots from the input dictionary if they haven't been logged already at the specified step."""
    for name, params in plots.copy().items():  # shallow copy to prevent plots dict changing during iteration
        timestamp = params["timestamp"]
        if _processed_plots.get(name) != timestamp:
            wb.run.log({name.stem: wb.Image(str(name))}, step=step)
            _processed_plots[name] = timestamp



ultralytics.utils.callbacks.wb.on_pretrain_routine_start(trainer)

Initie et démarre le projet si le module est présent.

Code source dans ultralytics/utils/callbacks/wb.py
def on_pretrain_routine_start(trainer):
    """Initiate and start project if module is present."""
    wb.run or wb.init(project=trainer.args.project or "YOLOv8", name=trainer.args.name, config=vars(trainer.args))



ultralytics.utils.callbacks.wb.on_fit_epoch_end(trainer)

Enregistre les métriques d'entraînement et les informations sur le modèle à la fin d'une époque.

Code source dans ultralytics/utils/callbacks/wb.py
def on_fit_epoch_end(trainer):
    """Logs training metrics and model information at the end of an epoch."""
    wb.run.log(trainer.metrics, step=trainer.epoch + 1)
    _log_plots(trainer.plots, step=trainer.epoch + 1)
    _log_plots(trainer.validator.plots, step=trainer.epoch + 1)
    if trainer.epoch == 0:
        wb.run.log(model_info_for_loggers(trainer), step=trainer.epoch + 1)



ultralytics.utils.callbacks.wb.on_train_epoch_end(trainer)

Enregistre les métriques et sauvegarde les images à la fin de chaque période d'apprentissage.

Code source dans ultralytics/utils/callbacks/wb.py
def on_train_epoch_end(trainer):
    """Log metrics and save images at the end of each training epoch."""
    wb.run.log(trainer.label_loss_items(trainer.tloss, prefix="train"), step=trainer.epoch + 1)
    wb.run.log(trainer.lr, step=trainer.epoch + 1)
    if trainer.epoch == 1:
        _log_plots(trainer.plots, step=trainer.epoch + 1)



ultralytics.utils.callbacks.wb.on_train_end(trainer)

Sauvegarde le meilleur modèle en tant qu'artefact à la fin de la formation.

Code source dans ultralytics/utils/callbacks/wb.py
def on_train_end(trainer):
    """Save the best model as an artifact at end of training."""
    _log_plots(trainer.validator.plots, step=trainer.epoch + 1)
    _log_plots(trainer.plots, step=trainer.epoch + 1)
    art = wb.Artifact(type="model", name=f"run_{wb.run.id}_model")
    if trainer.best.exists():
        art.add_file(trainer.best)
        wb.run.log_artifact(art, aliases=["best"])
    for curve_name, curve_values in zip(trainer.validator.metrics.curves, trainer.validator.metrics.curves_results):
        x, y, x_title, y_title = curve_values
        _plot_curve(
            x,
            y,
            names=list(trainer.validator.metrics.names.values()),
            id=f"curves/{curve_name}",
            title=curve_name,
            x_title=x_title,
            y_title=y_title,
        )
    wb.run.finish()  # required or run continues on dashboard





Created 2023-11-12, Updated 2024-06-02
Authors: glenn-jocher (5), Burhan-Q (1), Laughing-q (1)