
Reference for ultralytics/utils/callbacks/mlflow.py

Note

This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/utils/callbacks/mlflow.py. If you spot a problem, please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!



ultralytics.utils.callbacks.mlflow.on_pretrain_routine_end(trainer)

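Log training parameters to MLflow at the end of the pretraining routine.

This function sets up MLflow logging based on environment variables and trainer arguments. It sets the tracking URI, experiment name, and run name, then starts an MLflow run if one is not already active. Finally, it logs the parameters from the trainer.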

Parameters

Name      Type          Description                                                 Default
trainer   BaseTrainer   The training object with arguments and parameters to log.   required
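Global

mlflow: The imported mlflow module to use for logging.

Environment Variables

MLFLOW_TRACKING_URI: The URI for MLflow tracking. If not set, defaults to 'runs/mlflow'.
MLFLOW_EXPERIMENT_NAME: The name of the MLflow experiment. If not set, defaults to trainer.args.project, and to "/Shared/YOLOv8" when that is also empty.
MLFLOW_RUN: The name of the MLflow run. If not set, defaults to trainer.args.name.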

Source code in ultralytics/utils/callbacks/mlflow.py
def on_pretrain_routine_end(trainer):
    """
    Log training parameters to MLflow at the end of the pretraining routine.

    This function sets up MLflow logging based on environment variables and trainer arguments. It sets the tracking URI,
    experiment name, and run name, then starts the MLflow run if not already active. It finally logs the parameters
    from the trainer.

    Args:
        trainer (ultralytics.engine.trainer.BaseTrainer): The training object with arguments and parameters to log.

    Global:
        mlflow: The imported mlflow module to use for logging.

    Environment Variables:
        MLFLOW_TRACKING_URI: The URI for MLflow tracking. If not set, defaults to 'runs/mlflow'.
        MLFLOW_EXPERIMENT_NAME: The name of the MLflow experiment. If not set, defaults to trainer.args.project.
        MLFLOW_RUN: The name of the MLflow run. If not set, defaults to trainer.args.name.
    """
    global mlflow

    uri = os.environ.get("MLFLOW_TRACKING_URI") or str(RUNS_DIR / "mlflow")
    LOGGER.debug(f"{PREFIX} tracking uri: {uri}")
    mlflow.set_tracking_uri(uri)

    # Set experiment and run names
    experiment_name = os.environ.get("MLFLOW_EXPERIMENT_NAME") or trainer.args.project or "/Shared/YOLOv8"
    run_name = os.environ.get("MLFLOW_RUN") or trainer.args.name
    mlflow.set_experiment(experiment_name)

    mlflow.autolog()
    try:
        active_run = mlflow.active_run() or mlflow.start_run(run_name=run_name)
        LOGGER.info(f"{PREFIX}logging run_id({active_run.info.run_id}) to {uri}")
        if Path(uri).is_dir():
            LOGGER.info(f"{PREFIX}view at http://127.0.0.1:5000 with 'mlflow server --backend-store-uri {uri}'")
        LOGGER.info(f"{PREFIX}disable with 'yolo settings mlflow=False'")
        mlflow.log_params(dict(trainer.args))
    except Exception as e:
        LOGGER.warning(f"{PREFIX}WARNING ⚠️ Failed to initialize: {e}\n" f"{PREFIX}WARNING ⚠️ Not tracking this run")
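
The fallback order used by on_pretrain_routine_end can be exercised in isolation. The following is a minimal sketch, not part of the library: resolve_mlflow_settings is a hypothetical helper, its env argument is added only to make testing easy, and RUNS_DIR is a stand-in for the constant the module imports from Ultralytics.

```python
import os
from pathlib import Path

# Assumption: stand-in for the RUNS_DIR constant used by the source above;
# in Ultralytics the real value comes from the user settings.
RUNS_DIR = Path("runs")


def resolve_mlflow_settings(project, name, env=None):
    """Return (tracking_uri, experiment_name, run_name).

    Hypothetical helper that mirrors the `or` chains in the source:
    an unset OR empty environment variable falls through to the next value.
    """
    env = os.environ if env is None else env
    uri = env.get("MLFLOW_TRACKING_URI") or str(RUNS_DIR / "mlflow")
    experiment_name = env.get("MLFLOW_EXPERIMENT_NAME") or project or "/Shared/YOLOv8"
    run_name = env.get("MLFLOW_RUN") or name
    return uri, experiment_name, run_name


# No variables set and no project: every value comes from a default.
defaults = resolve_mlflow_settings(project=None, name="train", env={})

# Environment variables take precedence over trainer arguments.
overridden = resolve_mlflow_settings(
    project="my_project",
    name="train",
    env={"MLFLOW_TRACKING_URI": "http://127.0.0.1:5000", "MLFLOW_RUN": "baseline"},
)
```

Because the chain uses `or` rather than a default argument to `os.environ.get`, exporting a variable with an empty value behaves the same as leaving it unset.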



ultralytics.utils.callbacks.mlflow.on_train_epoch_end(trainer)

Log training metrics at the end of each train epoch to MLflow.

Source code in ultralytics/utils/callbacks/mlflow.py
def on_train_epoch_end(trainer):
    """Log training metrics at the end of each train epoch to MLflow."""
    if mlflow:
        mlflow.log_metrics(
            metrics={
                **SANITIZE(trainer.lr),
                **SANITIZE(trainer.label_loss_items(trainer.tloss, prefix="train")),
            },
            step=trainer.epoch,
        )



ultralytics.utils.callbacks.mlflow.on_fit_epoch_end(trainer)

Log training metrics at the end of each fit epoch to MLflow.

Source code in ultralytics/utils/callbacks/mlflow.py
def on_fit_epoch_end(trainer):
    """Log training metrics at the end of each fit epoch to MLflow."""
    if mlflow:
        mlflow.log_metrics(metrics=SANITIZE(trainer.metrics), step=trainer.epoch)
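
Both epoch callbacks above pass their dictionaries through SANITIZE, whose definition is not shown on this page. The sketch below assumes it removes parentheses from metric names and casts values to float, which would suit MLflow's restrictions on metric names. The function name sanitize and the sample keys are illustrative, not taken from the source.

```python
def sanitize(metrics):
    """Assumed behavior of SANITIZE: drop "(" and ")" from keys, cast values to float."""
    return {k.replace("(", "").replace(")", ""): float(v) for k, v in metrics.items()}


# Illustrative metric dictionary; real keys come from trainer.metrics and trainer.lr.
raw = {"metrics/precision(B)": 0.5, "metrics/mAP50-95(B)": "0.25", "lr/pg0": 0.01}
clean = sanitize(raw)
```

Keys without parentheses pass through unchanged, so the same helper can serve learning rates, losses, and validation metrics.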



ultralytics.utils.callbacks.mlflow.on_train_end(trainer)

Log model artifacts at the end of the training.

Source code in ultralytics/utils/callbacks/mlflow.py
def on_train_end(trainer):
    """Log model artifacts at the end of the training."""
    if mlflow:
        mlflow.log_artifact(str(trainer.best.parent))  # log save_dir/weights directory with best.pt and last.pt
        for f in trainer.save_dir.glob("*"):  # log all other files in save_dir
            if f.suffix in {".png", ".jpg", ".csv", ".pt", ".yaml"}:
                mlflow.log_artifact(str(f))

        mlflow.end_run()
        LOGGER.info(
            f"{PREFIX}results logged to {mlflow.get_tracking_uri()}\n"
            f"{PREFIX}disable with 'yolo settings mlflow=False'"
        )
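
The suffix filter in on_train_end only sees top-level entries of save_dir, which is presumably why the weights directory is logged separately. A minimal sketch of that selection logic follows, using a temporary directory with made-up file names. select_artifacts and LOGGED_SUFFIXES are hypothetical names, and no MLflow call is made.

```python
import tempfile
from pathlib import Path

# Same suffix allow-list as in the source above.
LOGGED_SUFFIXES = {".png", ".jpg", ".csv", ".pt", ".yaml"}


def select_artifacts(save_dir):
    """Return top-level entries of save_dir whose suffix is allow-listed, sorted by path."""
    return sorted(f for f in Path(save_dir).glob("*") if f.suffix in LOGGED_SUFFIXES)


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    # Illustrative run directory contents.
    for name in ("results.csv", "args.yaml", "labels.jpg", "notes.txt"):
        (root / name).touch()
    (root / "weights").mkdir()
    (root / "weights" / "best.pt").touch()

    # glob("*") is not recursive, so weights/best.pt is not matched here.
    selected = [f.name for f in select_artifacts(root)]
```

In this sketch notes.txt is skipped because of its suffix, and the weights directory is skipped because it has no suffix at all.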





Created 2023-11-12, Updated 2023-12-01
Authors: glenn-jocher (4), Laughing-q (1)