コンテンツへスキップ

参考 ultralytics/utils/callbacks/mlflow.py

備考

このファイルはhttps://github.com/ultralytics/ultralytics/blob/main/ ultralytics/utils/callbacks/mlflow .py にあります。もし問題を発見したら、Pull Request🛠️ を投稿して修正にご協力ください。ありがとうございます🙏!



ultralytics.utils.callbacks.mlflow.on_pretrain_routine_end(trainer)

プレトレーニングルーチンの最後にトレーニングパラメータをMLflowに記録する。

この関数は環境変数とトレーナの引数に基づいてMLflowのロギングを設定します。トラッキングURIを設定します、 実験名、実行名を設定し、まだアクティブでなければMLflowの実行を開始します。最後に、トレーナーからのパラメータ を記録します。

パラメーター

名称 タイプ 説明 デフォルト
trainer BaseTrainer

ログに記録する引数とパラメータを持つトレーニングオブジェクト。

必須
グローバル

mlflow:ロギングに使用するインポートされたmlflowモジュール。

環境変数

mlflow_tracking_uri:MLflowトラッキングのURI。設定されていない場合、デフォルトは 'runs/mlflow' です。 mlflow_experiment_name: MLflow の実験名:MLflow実験の名前。設定されていない場合、デフォルトは trainer.args.project です。 MLFLOW_RUN: MLflow ランの名前。設定されていない場合、デフォルトはtrainer.args.nameです。

ソースコード ultralytics/utils/callbacks/mlflow.py
def on_pretrain_routine_end(trainer):
    """
    Log training parameters to MLflow at the end of the pretraining routine.

    This function sets up MLflow logging based on environment variables and trainer arguments. It sets the tracking URI,
    experiment name, and run name, then starts the MLflow run if not already active. It finally logs the parameters
    from the trainer.

    Args:
        trainer (ultralytics.engine.trainer.BaseTrainer): The training object with arguments and parameters to log.

    Global:
        mlflow: The imported mlflow module to use for logging.

    Environment Variables:
        MLFLOW_TRACKING_URI: The URI for MLflow tracking. If not set, defaults to 'runs/mlflow'.
        MLFLOW_EXPERIMENT_NAME: The name of the MLflow experiment. If not set, defaults to trainer.args.project.
        MLFLOW_RUN: The name of the MLflow run. If not set, defaults to trainer.args.name.
    """
    global mlflow

    uri = os.environ.get("MLFLOW_TRACKING_URI") or str(RUNS_DIR / "mlflow")
    LOGGER.debug(f"{PREFIX} tracking uri: {uri}")
    mlflow.set_tracking_uri(uri)

    # Set experiment and run names
    experiment_name = os.environ.get("MLFLOW_EXPERIMENT_NAME") or trainer.args.project or "/Shared/YOLOv8"
    run_name = os.environ.get("MLFLOW_RUN") or trainer.args.name
    mlflow.set_experiment(experiment_name)

    mlflow.autolog()
    try:
        active_run = mlflow.active_run() or mlflow.start_run(run_name=run_name)
        LOGGER.info(f"{PREFIX}logging run_id({active_run.info.run_id}) to {uri}")
        if Path(uri).is_dir():
            LOGGER.info(f"{PREFIX}view at http://127.0.0.1:5000 with 'mlflow server --backend-store-uri {uri}'")
        LOGGER.info(f"{PREFIX}disable with 'yolo settings mlflow=False'")
        mlflow.log_params(dict(trainer.args))
    except Exception as e:
        LOGGER.warning(f"{PREFIX}WARNING ⚠️ Failed to initialize: {e}\n" f"{PREFIX}WARNING ⚠️ Not tracking this run")



ultralytics.utils.callbacks.mlflow.on_train_epoch_end(trainer)

各訓練エポック終了時に訓練メトリクスをMLflowに記録する。

ソースコード ultralytics/utils/callbacks/mlflow.py
def on_train_epoch_end(trainer):
    """Log training metrics at the end of each train epoch to MLflow."""
    if mlflow:
        mlflow.log_metrics(
            metrics={
                **SANITIZE(trainer.lr),
                **SANITIZE(trainer.label_loss_items(trainer.tloss, prefix="train")),
            },
            step=trainer.epoch,
        )



ultralytics.utils.callbacks.mlflow.on_fit_epoch_end(trainer)

各フィット・エポックの終了時にトレーニング・メトリクスをMLflowに記録する。

ソースコード ultralytics/utils/callbacks/mlflow.py
def on_fit_epoch_end(trainer):
    """Log training metrics at the end of each fit epoch to MLflow."""
    if mlflow:
        mlflow.log_metrics(metrics=SANITIZE(trainer.metrics), step=trainer.epoch)



ultralytics.utils.callbacks.mlflow.on_train_end(trainer)

トレーニング終了時にモデルの成果物を記録する。

ソースコード ultralytics/utils/callbacks/mlflow.py
def on_train_end(trainer):
    """Log model artifacts at the end of the training."""
    if mlflow:
        mlflow.log_artifact(str(trainer.best.parent))  # log save_dir/weights directory with best.pt and last.pt
        for f in trainer.save_dir.glob("*"):  # log all other files in save_dir
            if f.suffix in {".png", ".jpg", ".csv", ".pt", ".yaml"}:
                mlflow.log_artifact(str(f))

        mlflow.end_run()
        LOGGER.info(
            f"{PREFIX}results logged to {mlflow.get_tracking_uri()}\n"
            f"{PREFIX}disable with 'yolo settings mlflow=False'"
        )





作成日:2023-11-12 更新日:2023-12-01
作成者:glenn-jocher(4),Laughing-q(1)