Skip to content

Reference for ultralytics/utils/tuner.py

Improvements

This page is sourced from https://github.com/ultralytics/ultralytics/blob/main/ultralytics/utils/tuner.py. Have an improvement or example to add? Open a Pull Request — thank you! 🙏


function ultralytics.utils.tuner._sanitize_tune_value

def _sanitize_tune_value(value: dict)

Convert NumPy-backed Tune values into native Python types for YAML serialization.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `value` | `dict` | The value to convert. Can be a dict, list, tuple, NumPy scalar, or NumPy array. | required |

Returns

| Type | Description |
| --- | --- |
| — | The converted value with NumPy types replaced by native Python types. |
Source code in ultralytics/utils/tuner.pyView on GitHub
def _sanitize_tune_value(value: dict):
    """Convert NumPy-backed Tune values into native Python types for YAML serialization.

    Args:
        value (dict): The value to convert. Can be a dict, list, tuple, NumPy scalar, or NumPy array.

    Returns:
        The converted value with NumPy types replaced by native Python types.
    """
    if isinstance(value, dict):
        return {k: _sanitize_tune_value(v) for k, v in value.items()}
    if isinstance(value, list):
        return [_sanitize_tune_value(v) for v in value]
    if isinstance(value, tuple):
        return tuple(_sanitize_tune_value(v) for v in value)
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, np.ndarray):
        return value.tolist()
    return value





function ultralytics.utils.tuner._get_ray_search_alg_kind

def _get_ray_search_alg_kind(search_alg)

Return the normalized Ray Tune search algorithm kind for known searcher objects.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `search_alg` | `str \| ray.tune.search.Searcher` | The search algorithm to identify. Can be None, a string, or a Ray Tune searcher object. | required |

Returns

| Type | Description |
| --- | --- |
| `str \| None` | The normalized search algorithm name, or None if not recognized. |
Source code in ultralytics/utils/tuner.pyView on GitHub
def _get_ray_search_alg_kind(search_alg):
    """Return the normalized Ray Tune search algorithm kind for known searcher objects.

    Args:
        search_alg (str | ray.tune.search.Searcher): The search algorithm to identify. Can be None, a string, or a Ray
            Tune searcher object.

    Returns:
        str | None: The normalized search algorithm name, or None if not recognized.
    """
    if search_alg is None:
        return None
    if isinstance(search_alg, str):
        normalized = search_alg.strip().lower()
        return normalized or None

    cls = search_alg.__class__
    module, name = cls.__module__, cls.__name__
    if name == "AxSearch" and module.startswith("ray.tune.search.ax"):
        return "ax"
    if name == "TuneBOHB" and module.startswith("ray.tune.search.bohb"):
        return "bohb"
    if name == "ZOOptSearch" and module.startswith("ray.tune.search.zoopt"):
        return "zoopt"
    return None





function ultralytics.utils.tuner._validate_ax_search_space

def _validate_ax_search_space(space)

Validate that a Tune search space can be consumed by Ax.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `space` | `dict` | The hyperparameter search space to validate. | required |

Returns

| Type | Description |
| --- | --- |
| `list` | The converted Ax parameters. |

Raises

| Type | Description |
| --- | --- |
| `ImportError` | If the required 'ax-platform' package is not installed. |
Source code in ultralytics/utils/tuner.pyView on GitHub
def _validate_ax_search_space(space):
    """Check that Ax can consume a Tune search space by running it through AxSearch's converter.

    Args:
        space (dict): Hyperparameter search space to check.

    Returns:
        (list): The Ax parameters produced by the conversion.

    Raises:
        ImportError: If the required 'ax-platform' package is not installed.
    """
    ax_requirements = RAY_SEARCH_ALG_REQUIREMENTS["ax"]
    checks.check_requirements(ax_requirements)

    # Imported only after the requirement check above has run
    from ray.tune.search.ax.ax_search import AxSearch

    ax_parameters = AxSearch.convert_search_space(space)
    return ax_parameters





def _create_ax_search(space, task)

Create an Ax searcher with an initialized experiment.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `space` | `dict` | The hyperparameter search space. | required |
| `task` | `str` | The task type (e.g., 'detect', 'segment', 'classify'). | required |

Returns

| Type | Description |
| --- | --- |
| `AxSearch (ray.tune.search.Searcher)` | The configured Ax search algorithm. |

Raises

| Type | Description |
| --- | --- |
| `ImportError` | If required Ax packages are not installed. |
Source code in ultralytics/utils/tuner.pyView on GitHub
def _create_ax_search(space, task):
    """Build an AxSearch searcher backed by an AxClient whose experiment is already created.

    Args:
        space (dict): Hyperparameter search space.
        task (str): Task type (e.g., 'detect', 'segment', 'classify'); selects the metric to maximize.

    Returns:
        AxSearch (ray.tune.search.Searcher): Searcher wrapping the prepared Ax client.

    Raises:
        ImportError: If required Ax packages are not installed.
    """
    # Convert the space first; this also runs the Ax requirement check before Ax itself is imported
    ax_parameters = _validate_ax_search_space(space)

    from ax.service.ax_client import AxClient
    from ax.service.utils.instantiation import ObjectiveProperties
    from ray.tune.search.ax.ax_search import AxSearch

    client = AxClient()
    # The task metric is registered as an objective to maximize (minimize=False)
    objectives = {TASK2METRIC[task]: ObjectiveProperties(minimize=False)}
    client.create_experiment(parameters=ax_parameters, objectives=objectives)
    return AxSearch(ax_client=client)





function ultralytics.utils.tuner._convert_bohb_search_space

def _convert_bohb_search_space(space)

Convert a Tune search space into BOHB-compatible ConfigSpace and fixed-only Tune param_space.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `space` | `dict` | The hyperparameter search space. | required |

Returns

| Type | Description |
| --- | --- |
| `tuple` | A tuple containing the ConfigSpace object and a dict of fixed parameters. |

Raises

| Type | Description |
| --- | --- |
| `ValueError` | If the search space contains grid search parameters or unsupported samplers. |
| `ImportError` | If required BOHB packages are not installed. |
Source code in ultralytics/utils/tuner.pyView on GitHub
def _convert_bohb_search_space(space):
    """Convert a Tune search space into BOHB-compatible ConfigSpace and fixed-only Tune param_space.

    The space is flattened and split into fixed values, sampled domains and grid-search entries. Sampled domains are
    translated into ConfigSpace hyperparameters; fixed values are returned separately, keyed by their '/'-joined path.

    Args:
        space (dict): The hyperparameter search space.

    Returns:
        (tuple): A tuple containing the ConfigSpace object and a dict of fixed parameters.

    Raises:
        ValueError: If the search space contains grid search parameters, quantized samplers, or any other
            domain/sampler combination that is not handled here.
        ImportError: If required BOHB packages are not installed.
    """
    checks.check_requirements(RAY_SEARCH_ALG_REQUIREMENTS["bohb"])

    import ConfigSpace
    from ray.tune.search.sample import Categorical, Float, Integer, LogUniform, Quantized, Uniform
    from ray.tune.search.variant_generator import parse_spec_vars
    from ray.tune.utils import flatten_dict

    resolved_space = flatten_dict(space, prevent_delimiter=True)
    resolved_vars, domain_vars, grid_vars = parse_spec_vars(resolved_space)
    if grid_vars:
        raise ValueError("Grid search parameters cannot be automatically converted to a TuneBOHB search space.")

    cs = ConfigSpace.ConfigurationSpace()
    for path, domain in domain_vars:
        par = "/".join(str(p) for p in path)
        sampler = domain.get_sampler()
        if isinstance(sampler, Quantized):
            raise ValueError("TuneBOHB does not support quantized search spaces with the current ConfigSpace version.")

        is_range_sampler = isinstance(sampler, (Uniform, LogUniform))
        use_log = isinstance(sampler, LogUniform)
        if isinstance(domain, Float) and is_range_sampler:
            cs.add(ConfigSpace.UniformFloatHyperparameter(par, lower=domain.lower, upper=domain.upper, log=use_log))
        elif isinstance(domain, Integer) and is_range_sampler:
            upper = domain.upper - 1  # Tune integer search spaces are exclusive on the upper bound
            cs.add(ConfigSpace.UniformIntegerHyperparameter(par, lower=domain.lower, upper=upper, log=use_log))
        elif isinstance(domain, Categorical) and isinstance(sampler, Uniform):
            cs.add(ConfigSpace.CategoricalHyperparameter(par, choices=domain.categories))
        else:
            # Report the sampler that was actually checked above; the raw `domain.sampler` attribute may be unset
            # when the domain relies on its default sampler, which would make the message misleading.
            raise ValueError(
                f"TuneBOHB does not support parameters of type {type(domain).__name__} "
                f"with sampler type {type(sampler).__name__}."
            )

    fixed_param_space = {"/".join(str(p) for p in path): value for path, value in resolved_vars}
    return cs, fixed_param_space





def _create_bohb_search(space, task)

Create a BOHB searcher using a ConfigSpace definition compatible with current ConfigSpace versions.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `space` | `dict` | The hyperparameter search space. | required |
| `task` | `str` | The task type (e.g., 'detect', 'segment', 'classify'). | required |

Returns

| Type | Description |
| --- | --- |
| `tuple` | A tuple containing the TuneBOHB searcher and fixed parameter space dict. |

Raises

| Type | Description |
| --- | --- |
| `ImportError` | If required BOHB packages are not installed. |
Source code in ultralytics/utils/tuner.pyView on GitHub
def _create_bohb_search(space, task):
    """Build a TuneBOHB searcher from a ConfigSpace definition derived from the Tune search space.

    Args:
        space (dict): Hyperparameter search space.
        task (str): Task type (e.g., 'detect', 'segment', 'classify'); selects the metric to maximize.

    Returns:
        (tuple): The TuneBOHB searcher followed by the dict of fixed (non-searched) parameters.

    Raises:
        ImportError: If required BOHB packages are not installed.
    """
    # Conversion runs first; it also performs the BOHB requirement check before TuneBOHB is imported
    config_space, fixed_param_space = _convert_bohb_search_space(space)

    from ray.tune.search.bohb.bohb_search import TuneBOHB

    searcher = TuneBOHB(space=config_space, metric=TASK2METRIC[task], mode="max")
    return searcher, fixed_param_space





def _create_nevergrad_search(task)

Create a Nevergrad searcher with a default optimizer.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `task` | `str` | The task type (e.g., 'detect', 'segment', 'classify'). | required |

Returns

| Type | Description |
| --- | --- |
| `NevergradSearch` | The configured Nevergrad search algorithm. |

Raises

| Type | Description |
| --- | --- |
| `ImportError` | If the 'nevergrad' package is not installed. |
Source code in ultralytics/utils/tuner.pyView on GitHub
def _create_nevergrad_search(task):
    """Build a NevergradSearch searcher that uses the OnePlusOne optimizer.

    Args:
        task (str): Task type (e.g., 'detect', 'segment', 'classify'); selects the metric to maximize.

    Returns:
        (NevergradSearch): Searcher configured to maximize the task metric.

    Raises:
        ImportError: If the 'nevergrad' package is not installed.
    """
    nevergrad_requirements = RAY_SEARCH_ALG_REQUIREMENTS["nevergrad"]
    checks.check_requirements(nevergrad_requirements)

    # Imported only after the requirement check above has run
    import nevergrad as ng
    from ray.tune.search.nevergrad import NevergradSearch

    searcher = NevergradSearch(
        optimizer=ng.optimizers.OnePlusOne,
        metric=TASK2METRIC[task],
        mode="max",
    )
    return searcher





function ultralytics.utils.tuner._convert_zoopt_search_space

def _convert_zoopt_search_space(space)

Convert a Tune search space into ZOOpt-compatible dimensions and fixed-only Tune param_space.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `space` | `dict` | The hyperparameter search space. | required |

Returns

| Type | Description |
| --- | --- |
| `tuple` | A tuple containing the ZOOpt dimension dict and fixed parameter space dict. |

Raises

| Type | Description |
| --- | --- |
| `ImportError` | If the 'zoopt' package is not installed. |
Source code in ultralytics/utils/tuner.pyView on GitHub
def _convert_zoopt_search_space(space):
    """Split a Tune search space into ZOOpt dimensions and a dict of fixed (non-searched) parameters.

    Args:
        space (dict): Hyperparameter search space.

    Returns:
        (tuple): The ZOOpt dimension dict followed by the fixed parameter space dict, whose keys are the '/'-joined
            paths of the fixed entries.

    Raises:
        ImportError: If the 'zoopt' package is not installed.
    """
    zoopt_requirements = RAY_SEARCH_ALG_REQUIREMENTS["zoopt"]
    checks.check_requirements(zoopt_requirements)

    from ray.tune.search.variant_generator import parse_spec_vars
    from ray.tune.search.zoopt import ZOOptSearch
    from ray.tune.utils import flatten_dict

    # Only the fixed entries of the flattened space are needed here; domains and grid entries are ignored
    flat_space = flatten_dict(space, prevent_delimiter=True)
    constants, _domains, _grids = parse_spec_vars(flat_space)
    fixed_param_space = {}
    for path, value in constants:
        key = "/".join(str(part) for part in path)
        fixed_param_space[key] = value

    # The dimension dict is derived from the original (unflattened) space
    dim_dict = ZOOptSearch.convert_search_space(space)
    return dim_dict, fixed_param_space





def _create_zoopt_search(space, task, iterations)

Create a ZOOpt searcher with required budget and converted search space.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `space` | `dict` | The hyperparameter search space. | required |
| `task` | `str` | The task type (e.g., 'detect', 'segment', 'classify'). | required |
| `iterations` | `int` | The maximum number of trials (budget) for ZOOpt. | required |

Returns

| Type | Description |
| --- | --- |
| `tuple` | A tuple containing the ZOOptSearch searcher and fixed parameter space dict. |

Raises

| Type | Description |
| --- | --- |
| `ImportError` | If the 'zoopt' package is not installed. |
Source code in ultralytics/utils/tuner.pyView on GitHub
def _create_zoopt_search(space, task, iterations):
    """Build a ZOOptSearch searcher from the converted search space with the given trial budget.

    Args:
        space (dict): Hyperparameter search space.
        task (str): Task type (e.g., 'detect', 'segment', 'classify'); selects the metric to maximize.
        iterations (int): Maximum number of trials, passed to ZOOpt as its budget.

    Returns:
        (tuple): The ZOOptSearch searcher followed by the fixed parameter space dict.

    Raises:
        ImportError: If the 'zoopt' package is not installed.
    """
    # Conversion runs first; it also performs the ZOOpt requirement check
    dim_dict, fixed_param_space = _convert_zoopt_search_space(space)

    from ray.tune.search.zoopt import ZOOptSearch

    searcher = ZOOptSearch(
        algo="asracos",
        budget=iterations,
        dim_dict=dim_dict,
        metric=TASK2METRIC[task],
        mode="max",
    )
    return searcher, fixed_param_space





function ultralytics.utils.tuner._resolve_ray_search_alg

def _resolve_ray_search_alg(search_alg, task, space, iterations)

Resolve search algorithms and normalize Tune param_space for known Ray Tune searchers.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `search_alg` | `str \| object \| None` | The search algorithm to use. Can be a string name, a pre-instantiated Ray Tune searcher object, or None for default behavior. | required |
| `task` | `str` | The task type (e.g., 'detect', 'segment', 'classify'). | required |
| `space` | `dict` | The hyperparameter search space. | required |
| `iterations` | `int` | The maximum number of trials to run. | required |

Returns

| Type | Description |
| --- | --- |
| `tuple` | A tuple containing (resolved_search_alg, tuner_param_space, resolved_search_alg_kind). |

Raises

| Type | Description |
| --- | --- |
| `ValueError` | If an unsupported search_alg string is provided. |
| `ModuleNotFoundError` | If required dependencies for the chosen algorithm are not installed. |
Source code in ultralytics/utils/tuner.pyView on GitHub
def _resolve_ray_search_alg(search_alg, task, space, iterations):
    """Resolve search algorithms and normalize Tune param_space for known Ray Tune searchers.

    Args:
        search_alg (str | object | None): The search algorithm to use. Can be a string name, a pre-instantiated Ray Tune
            searcher object, or None for default behavior.
        task (str): The task type (e.g., 'detect', 'segment', 'classify').
        space (dict): The hyperparameter search space.
        iterations (int): The maximum number of trials to run.

    Returns:
        (tuple): A tuple containing (resolved_search_alg, tuner_param_space, resolved_search_alg_kind).
            - resolved_search_alg: The configured searcher or None.
            - tuner_param_space: The normalized parameter space for the tuner.
            - resolved_search_alg_kind: The normalized algorithm name or None.

    Raises:
        ValueError: If an unsupported search_alg string is provided.
        ModuleNotFoundError: If required dependencies for the chosen algorithm are not installed.
    """
    if search_alg is None:
        return None, space, None

    normalized = _get_ray_search_alg_kind(search_alg)
    if isinstance(search_alg, str):
        if not normalized:
            return None, space, None
        if normalized not in RAY_SEARCH_ALG_REQUIREMENTS:
            supported = ", ".join(sorted(RAY_SEARCH_ALG_REQUIREMENTS))
            raise ValueError(f"Unsupported Ray Tune search_alg '{search_alg}'. Supported values: {supported}.")
        if normalized == "random":
            return None, space, normalized

    try:
        if normalized == "ax":
            if isinstance(search_alg, str):
                return _create_ax_search(space, task), {}, normalized
            _validate_ax_search_space(space)
            return search_alg, {}, normalized
        if normalized == "bohb":
            if isinstance(search_alg, str):
                resolved_search_alg, tuner_param_space = _create_bohb_search(space, task)
            else:
                _, tuner_param_space = _convert_bohb_search_space(space)
                resolved_search_alg = search_alg
            return resolved_search_alg, tuner_param_space, normalized
        if normalized == "nevergrad":
            return _create_nevergrad_search(task), space, normalized
        if normalized == "zoopt":
            if isinstance(search_alg, str):
                resolved_search_alg, tuner_param_space = _create_zoopt_search(space, task, iterations)
            else:
                _, tuner_param_space = _convert_zoopt_search_space(space)
                resolved_search_alg = search_alg
            return resolved_search_alg, tuner_param_space, normalized
        if not isinstance(search_alg, str):
            return search_alg, space, None

        requirements = RAY_SEARCH_ALG_REQUIREMENTS[normalized]
        if requirements:
            checks.check_requirements(requirements)

        from ray.tune.search import create_searcher

        return create_searcher(normalized, metric=TASK2METRIC[task], mode="max"), space, normalized
    except (ImportError, ModuleNotFoundError) as e:
        raise ModuleNotFoundError(
            f"Ray Tune search_alg '{search_alg}' requires additional dependencies. Original error: {e}"
        ) from e





function ultralytics.utils.tuner.run_ray_tune

def run_ray_tune(
    model,
    space: dict | None = None,
    grace_period: int = 10,
    gpu_per_trial: int | None = None,
    iterations: int = 10,
    search_alg=None,
    **train_args,
)

Run hyperparameter tuning using Ray Tune.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `YOLO` | Model to run the tuner on. | required |
| `space` | `dict, optional` | The hyperparameter search space. If not provided, uses default space. | `None` |
| `grace_period` | `int, optional` | The grace period in epochs of the ASHA scheduler. | `10` |
| `gpu_per_trial` | `int, optional` | The number of GPUs to allocate per trial. | `None` |
| `iterations` | `int, optional` | The maximum number of trials to run. | `10` |
| `search_alg` | `str \| ray.tune.search.Searcher \| ray.tune.search.SearchAlgorithm, optional` | Search algorithm to use. Strings are resolved to supported Ray Tune searchers. Pre-instantiated objects are reused, and known searchers with special Tune param_space requirements are normalized automatically. | `None` |
| `**train_args` | `Any` | Additional arguments to pass to the train() method. | required |

Returns

| Type | Description |
| --- | --- |
| `ray.tune.ResultGrid` | A ResultGrid containing the results of the hyperparameter search. |

Examples

>>> from ultralytics import YOLO
>>> model = YOLO("yolo26n.pt")  # Load a YOLO26n model

Start tuning hyperparameters for YOLO26n training on the COCO8 dataset
>>> result_grid = model.tune(data="coco8.yaml", use_ray=True)
Source code in ultralytics/utils/tuner.pyView on GitHub
def run_ray_tune(
    model,
    space: dict | None = None,
    grace_period: int = 10,
    gpu_per_trial: int | None = None,
    iterations: int = 10,
    search_alg=None,
    **train_args,
):
    """Run hyperparameter tuning using Ray Tune.

    Args:
        model (YOLO): Model to run the tuner on.
        space (dict, optional): The hyperparameter search space. If not provided, uses default space (unless
            `resume` is passed in `train_args`). The dict passed in is not modified.
        grace_period (int, optional): The grace period in epochs of the ASHA scheduler.
        gpu_per_trial (int, optional): The number of GPUs to allocate per trial.
        iterations (int, optional): The maximum number of trials to run.
        search_alg (str | ray.tune.search.Searcher | ray.tune.search.SearchAlgorithm, optional): Search algorithm to
            use. Strings are resolved to supported Ray Tune searchers. Pre-instantiated objects are reused, and known
            searchers with special Tune param_space requirements are normalized automatically.
        **train_args (Any): Additional arguments to pass to the `train()` method.

    Returns:
        (ray.tune.ResultGrid): A ResultGrid containing the results of the hyperparameter search.

    Raises:
        ModuleNotFoundError: If Ray Tune cannot be imported.

    Examples:
        >>> from ultralytics import YOLO
        >>> model = YOLO("yolo26n.pt")  # Load a YOLO26n model

        Start tuning hyperparameters for YOLO26n training on the COCO8 dataset
        >>> result_grid = model.tune(data="coco8.yaml", use_ray=True)
    """
    LOGGER.info("💡 Learn about RayTune at https://docs.ultralytics.com/integrations/ray-tune")
    try:
        checks.check_requirements("ray[tune]")

        import ray
        from ray import tune
        from ray.tune import RunConfig
        from ray.tune.schedulers import ASHAScheduler, HyperBandForBOHB
    except ImportError as e:
        # Chain the original error so the real import failure stays visible in the traceback
        raise ModuleNotFoundError('Ray Tune required but not found. To install run: pip install "ray[tune]"') from e

    # NOTE: `wandb` is not referenced again in this function; the import probe is kept as-is
    try:
        import wandb

        assert hasattr(wandb, "__version__")
    except (ImportError, AssertionError):
        wandb = False

    checks.check_version(ray.__version__, ">=2.0.0", "ray")
    default_space = {
        # 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),
        "lr0": tune.uniform(1e-5, 1e-2),  # initial learning rate (i.e. SGD=1E-2, Adam=1E-3)
        "lrf": tune.uniform(0.01, 1.0),  # final OneCycleLR learning rate (lr0 * lrf)
        "momentum": tune.uniform(0.7, 0.98),  # SGD momentum/Adam beta1
        "weight_decay": tune.uniform(0.0, 0.001),  # optimizer weight decay
        "warmup_epochs": tune.uniform(0.0, 5.0),  # warmup epochs (fractions ok)
        "warmup_momentum": tune.uniform(0.0, 0.95),  # warmup initial momentum
        "box": tune.uniform(1.0, 20.0),  # box loss gain
        "cls": tune.uniform(0.1, 4.0),  # cls loss gain (scale with pixels)
        "dfl": tune.uniform(0.4, 12.0),  # dfl loss gain
        "hsv_h": tune.uniform(0.0, 0.1),  # image HSV-Hue augmentation (fraction)
        "hsv_s": tune.uniform(0.0, 0.9),  # image HSV-Saturation augmentation (fraction)
        "hsv_v": tune.uniform(0.0, 0.9),  # image HSV-Value augmentation (fraction)
        "degrees": tune.uniform(0.0, 45.0),  # image rotation (+/- deg)
        "translate": tune.uniform(0.0, 0.9),  # image translation (+/- fraction)
        "scale": tune.uniform(0.0, 0.95),  # image scale (+/- gain)
        "shear": tune.uniform(0.0, 10.0),  # image shear (+/- deg)
        "perspective": tune.uniform(0.0, 0.001),  # image perspective (+/- fraction), range 0-0.001
        "flipud": tune.uniform(0.0, 1.0),  # image flip up-down (probability)
        "fliplr": tune.uniform(0.0, 1.0),  # image flip left-right (probability)
        "bgr": tune.uniform(0.0, 1.0),  # swap RGB↔BGR channels (probability)
        "mosaic": tune.uniform(0.0, 1.0),  # image mosaic (probability)
        "mixup": tune.uniform(0.0, 1.0),  # image mixup (probability)
        "cutmix": tune.uniform(0.0, 1.0),  # image cutmix (probability)
        "copy_paste": tune.uniform(0.0, 1.0),  # segment copy-paste (probability)
        "close_mosaic": tune.randint(0, 11),  # close dataloader mosaic (epochs)
    }

    # Put the model in ray store
    task = model.task
    model_in_store = ray.put(model)
    base_name = train_args.get("name", "tune")

    def _tune(config):
        """Train the YOLO model with the specified hyperparameters and return results."""
        model_to_train = ray.get(model_in_store)  # get the model from ray store for tuning
        model_to_train.trainer = None
        model_to_train.reset_callbacks()
        config = _sanitize_tune_value(dict(config))
        config.update(train_args)

        # Set trial-specific name for W&B logging
        try:
            trial_id = tune.get_trial_id()  # Get current trial ID (e.g., "2c2fc_00000")
            trial_suffix = trial_id.split("_")[-1] if "_" in trial_id else trial_id
            config["name"] = f"{base_name}_{trial_suffix}"
        except Exception:
            # Not in Ray Tune context or error getting trial ID, use base name
            config["name"] = base_name

        results = model_to_train.train(**config)
        return results.results_dict

    # Get search space
    if not space and not train_args.get("resume"):
        space = default_space
        LOGGER.warning("Search space not provided, using default search space.")

    # Get dataset
    data = train_args.get("data", TASK2DATA[task])
    # Build a fresh dict rather than assigning into `space`: it is still None when resuming without an explicit
    # space (item assignment would raise TypeError), and the caller's dict should not be modified.
    space = {**(space or {}), "data": data}
    if "data" not in train_args:
        LOGGER.warning(f'Data not provided, using default "data={data}".')

    resolved_search_alg, tuner_param_space, resolved_search_alg_kind = _resolve_ray_search_alg(
        search_alg, task, space, iterations
    )

    # Define the trainable function with allocated resources
    trainable_with_resources = tune.with_resources(_tune, {"cpu": NUM_THREADS, "gpu": gpu_per_trial or 0})

    # Define the scheduler for hyperparameter search (BOHB is paired with HyperBandForBOHB, everything else with ASHA)
    max_t = train_args.get("epochs") or DEFAULT_CFG_DICT["epochs"] or 100
    if resolved_search_alg_kind == "bohb":
        scheduler = HyperBandForBOHB(
            time_attr="epoch",
            metric=TASK2METRIC[task],
            mode="max",
            max_t=max_t,
            reduction_factor=3,
        )
    else:
        scheduler = ASHAScheduler(
            time_attr="epoch",
            metric=TASK2METRIC[task],
            mode="max",
            max_t=max_t,
            grace_period=min(grace_period, max_t),
            reduction_factor=3,
        )

    # Create the Ray Tune hyperparameter search tuner
    # NOTE: the two pops below remove 'resume' and 'name' from train_args, so `_tune` (which reads train_args when a
    # trial runs) does not forward them to train(); the per-trial name is set from `base_name` instead.
    tune_dir = get_save_dir(
        get_cfg(
            DEFAULT_CFG,
            {**train_args, **{"exist_ok": train_args.pop("resume", False)}},  # resume w/ same tune_dir
        ),
        name=train_args.pop("name", "tune"),  # runs/{task}/{tune_dir}
    )  # must be absolute dir
    tune_dir.mkdir(parents=True, exist_ok=True)
    if tune.Tuner.can_restore(tune_dir):
        LOGGER.info(f"{colorstr('Tuner: ')} Resuming tuning run {tune_dir}...")
        tuner = tune.Tuner.restore(str(tune_dir), trainable=trainable_with_resources, resume_errored=True)
    else:
        tuner = tune.Tuner(
            trainable_with_resources,
            param_space=tuner_param_space,
            tune_config=tune.TuneConfig(
                search_alg=resolved_search_alg,
                scheduler=scheduler,
                num_samples=iterations,
                trial_name_creator=lambda trial: f"{trial.trainable_name}_{trial.trial_id}",
                trial_dirname_creator=lambda trial: f"{trial.trainable_name}_{trial.trial_id}",
            ),
            run_config=RunConfig(storage_path=tune_dir.parent, name=tune_dir.name),
        )

    try:
        # Run the hyperparameter search
        tuner.fit()

        # Get the results of the hyperparameter search
        results = tuner.get_results()
    finally:
        # Shut down Ray to clean up workers, even if the search itself raised
        ray.shutdown()

    return results





📅 Created 2 years ago ✏️ Updated 1 day ago
glenn-jocherlmycrossjk4eBurhan-Q