Skip to content

Reference for ultralytics/hub/session.py

Note

This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/hub/session.py. If you spot a problem please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!


ultralytics.hub.session.HUBTrainingSession

HUBTrainingSession(identifier)

HUB training session for Ultralytics HUB YOLO models. Handles model initialization, heartbeats, and checkpointing.

Attributes:

Name Type Description
model_id str

Identifier for the YOLO model being trained.

model_url str

URL for the model in Ultralytics HUB.

rate_limits dict

Rate limits for different API calls (in seconds).

timers dict

Timers for rate limiting.

metrics_queue dict

Queue for the model's metrics.

model dict

Model data fetched from Ultralytics HUB.

Parameters:

Name Type Description Default
identifier str

Model identifier used to initialize the HUB training session. It can be a URL string or a model key with specific format.

required

Raises:

Type Description
ValueError

If the provided model identifier is invalid.

ConnectionError

If connecting with global API key is not supported.

ModuleNotFoundError

If hub-sdk package is not installed.

Source code in ultralytics/hub/session.py
def __init__(self, identifier):
    """
    Prepare session state, build the HUB client, and load (or create a placeholder for) the model.

    Args:
        identifier (str): Either a HUB model URL (optionally carrying an `api_key` query parameter) or a
            local filename ending in '.pt' or '.yaml', as understood by `_parse_identifier`.

    Raises:
        HUBModelError: If `_parse_identifier` does not recognize the identifier format.
        ModuleNotFoundError: If the hub-sdk package is not installed.

    Notes:
        Any exception raised while loading the model is caught here. A login hint is logged only when the
        identifier is a HUB model URL and the client is not authenticated; otherwise the failure is silent.
    """
    from hub_sdk import HUBClient

    # Attributes populated later by load_model / _set_train_args / create_model
    self.model = self.model_url = self.model_file = self.train_args = None

    # Bookkeeping shared with the HUB callbacks
    self.timers = {}  # holds timers in ultralytics/utils/callbacks/hub.py
    self.metrics_queue = {}  # holds metrics for each epoch until upload
    self.metrics_upload_failed_queue = {}  # holds metrics for each epoch if upload failed
    self.rate_limits = {"metrics": 3, "ckpt": 900, "heartbeat": 300}  # rate limits (seconds)

    # Split the identifier into its parts
    api_key, model_id, self.filename = self._parse_identifier(identifier)

    # An API key embedded in the identifier takes precedence over the one stored in SETTINGS
    active_key = api_key or SETTINGS.get("api_key")
    self.client = HUBClient({"api_key": active_key} if active_key else None)

    try:
        if not model_id:
            self.model = self.client.model()  # load empty model
        else:
            self.load_model(model_id)  # load existing model
    except Exception:
        is_hub_url = identifier.startswith(f"{HUB_WEB_ROOT}/models/")
        if is_hub_url and not self.client.authenticated:
            LOGGER.warning(
                f"{PREFIX}WARNING ⚠️ Please log in using 'yolo login API_KEY'. "
                "You can find your API Key at: https://hub.ultralytics.com/settings?tab=api+keys."
            )

_get_failure_message

_get_failure_message(response: requests.Response, retry: int, timeout: int)

Generate a retry message based on the response status code.

Parameters:

Name Type Description Default
response Response

The HTTP response object.

required
retry int

The number of retry attempts allowed.

required
timeout int

The maximum timeout duration.

required

Returns:

Type Description
str

The retry message.

Source code in ultralytics/hub/session.py
def _get_failure_message(self, response: requests.Response, retry: int, timeout: int):
    """
    Build the message describing why a request failed, based on the response status code.

    Args:
        response (requests.Response): The HTTP response of the failed request.
        retry (int): The number of retry attempts allowed.
        timeout (int): The maximum timeout duration.

    Returns:
        (str): A retry notice for retryable status codes (empty when `retry` is 0), a rate-limit notice for
            HTTP 429, otherwise the "message" field of the JSON body or a fallback text when the body
            cannot be read as a JSON object.
    """
    status = response.status_code
    if self._should_retry(status):
        return f"Retrying {retry}x for {timeout}s." if retry else ""

    if status == HTTPStatus.TOO_MANY_REQUESTS:  # rate limit
        headers = response.headers
        # Rate-limit headers may be absent; never let a missing header turn into a KeyError here
        remaining = headers.get("X-RateLimit-Remaining", "?")
        limit = headers.get("X-RateLimit-Limit", "?")
        retry_after = headers.get("Retry-After")
        wait_hint = f"Please retry after {retry_after}s." if retry_after is not None else "Please retry later."
        return f"Rate limit reached ({remaining}/{limit}). {wait_hint}"

    try:
        return response.json().get("message", "No JSON message.")
    except (AttributeError, ValueError):
        # ValueError: body is not valid JSON; AttributeError: JSON payload is not an object with .get()
        return "Unable to read JSON."

_iterate_content staticmethod

_iterate_content(response: requests.Response) -> None

Process the streamed HTTP response data.

Parameters:

Name Type Description Default
response Response

The response object from the file download request.

required

Returns:

Type Description
None

None

Source code in ultralytics/hub/session.py
@staticmethod
def _iterate_content(response: requests.Response) -> None:
    """
    Drain a streamed HTTP response, discarding every chunk that is read.

    Args:
        response (requests.Response): Response whose body is consumed through `iter_content(chunk_size=1024)`.

    Returns:
        None
    """
    exhausted = object()  # unique marker handed back by next() once the stream ends
    chunks = iter(response.iter_content(chunk_size=1024))
    while next(chunks, exhausted) is not exhausted:
        pass  # chunk content is intentionally ignored

_parse_identifier staticmethod

_parse_identifier(identifier)

Parses the given identifier to determine the type of identifier and extract relevant components.

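The method supports the following identifier formats: a HUB model URL (https://hub.ultralytics.com/models/MODEL), a HUB model URL with an API key (https://hub.ultralytics.com/models/MODEL?api_key=APIKEY), or a local filename that ends with '.pt' or '.yaml'.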

Parameters:

Name Type Description Default
identifier str

The identifier string to be parsed.

required

Returns:

Type Description
tuple

A tuple containing the API key, model ID, and filename as applicable.

Raises:

Type Description
HUBModelError

If the identifier format is not recognized.

Source code in ultralytics/hub/session.py
@staticmethod
def _parse_identifier(identifier):
    """
    Parse a model identifier and extract the API key, model ID, and filename it carries.

    Supported identifier formats:
        - A HUB model URL https://hub.ultralytics.com/models/MODEL
        - A HUB model URL with API Key https://hub.ultralytics.com/models/MODEL?api_key=APIKEY
        - A local filename that ends with '.pt' or '.yaml'

    Args:
        identifier (str): The identifier string to be parsed.

    Returns:
        (tuple): `(api_key, model_id, filename)`; components that do not apply to the identifier are None.

    Raises:
        HUBModelError: If the identifier format is not recognized.
    """
    api_key, model_id, filename = None, None, None
    if Path(identifier).suffix in {".pt", ".yaml"}:
        # The suffix check runs first, so anything ending in .pt/.yaml is treated as a filename
        filename = identifier
    elif identifier.startswith(f"{HUB_WEB_ROOT}/models/"):
        parsed_url = urlparse(identifier)
        model_id = Path(parsed_url.path).stem  # last path component; tolerates a trailing slash
        query_params = parse_qs(parsed_url.query)  # dictionary, i.e. {"api_key": ["API_KEY_HERE"]}
        api_key = query_params.get("api_key", [None])[0]
    else:
        raise HUBModelError(f"model='{identifier}' invalid, correct format is {HUB_WEB_ROOT}/models/MODEL_ID")
    return api_key, model_id, filename

_set_train_args

_set_train_args()

Initializes training arguments and creates a model entry on the Ultralytics HUB.

This method sets up training arguments based on the model's state and updates them with any additional arguments provided. It handles different states of the model, such as whether it's resumable, pretrained, or requires specific file setup.

Raises:

Type Description
ValueError

If the model is already trained, if required dataset information is missing, or if there are issues with the provided training arguments.

Source code in ultralytics/hub/session.py
def _set_train_args(self):
    """
    Populate `train_args`, `model_file` and `model_id` from the loaded HUB model.

    A resumable model trains from its "last" weights with `resume=True` and the model's dataset URL. Otherwise
    the training arguments stored in the model data are used, and the model file is the "parent" weights for a
    pretrained model or the model architecture for a model trained from scratch.

    Raises:
        ValueError: If no training arguments are available or they do not contain a "data" entry.
    """
    if self.model.is_resumable():
        # Model has saved weights
        self.train_args = {"data": self.model.get_dataset_url(), "resume": True}
        self.model_file = self.model.get_weights_url("last")
    else:
        # Model has no saved weights
        self.train_args = self.model.data.get("train_args")  # new response

        # Set the model file as either a *.pt or *.yaml file
        self.model_file = (
            self.model.get_weights_url("parent") if self.model.is_pretrained() else self.model.get_architecture()
        )

    # train_args may be None/empty when the model data carries none; treat that like a missing dataset
    # instead of failing with a TypeError on the membership test
    if not self.train_args or "data" not in self.train_args:
        # RF bug - datasets are sometimes not exported
        raise ValueError("Dataset may still be processing. Please wait a minute and try again.")

    self.model_file = checks.check_yolov5u_filename(self.model_file, verbose=False)  # YOLOv5->YOLOv5u
    self.model_id = self.model.id

_should_retry staticmethod

_should_retry(status_code)

Determines if a request should be retried based on the HTTP status code.

Source code in ultralytics/hub/session.py
@staticmethod
def _should_retry(status_code):
    """Return True when the HTTP status code is one of the codes for which a request is retried."""
    retryable = frozenset((HTTPStatus.REQUEST_TIMEOUT, HTTPStatus.BAD_GATEWAY, HTTPStatus.GATEWAY_TIMEOUT))
    return status_code in retryable

_show_upload_progress staticmethod

_show_upload_progress(content_length: int, response: requests.Response) -> None

Display a progress bar that advances as the streamed response content is consumed.

Parameters:

Name Type Description Default
content_length int

The total size of the content to be downloaded in bytes.

required
response Response

The response object from the file download request.

required

Returns:

Type Description
None

None

Source code in ultralytics/hub/session.py
@staticmethod
def _show_upload_progress(content_length: int, response: requests.Response) -> None:
    """
    Show a progress bar that advances as the streamed response content is consumed.

    Args:
        content_length (int): Total size in bytes used as the progress bar's total.
        response (requests.Response): Response whose content is read via `iter_content(chunk_size=1024)`.

    Returns:
        None
    """
    bar_options = {"total": content_length, "unit": "B", "unit_scale": True, "unit_divisor": 1024}
    with TQDM(**bar_options) as progress:
        for chunk in response.iter_content(chunk_size=1024):
            progress.update(len(chunk))  # advance by the number of bytes in this chunk

create_model

create_model(model_args)

Creates a new model on Ultralytics HUB from the given training arguments and, if the model is created, starts its heartbeat.

Source code in ultralytics/hub/session.py
def create_model(self, model_args):
    """
    Create a model entry on Ultralytics HUB from the given training arguments.

    Builds the creation payload from `model_args` and `self.filename`, submits it through the HUB model object,
    and, when the model receives an id, stores its URL and starts the heartbeat.

    Args:
        model_args (dict): Training arguments; the keys read are "batch", "epochs", "imgsz", "patience",
            "device", "cache" and "data".

    Returns:
        None
    """
    filename = self.filename

    config = {
        "batchSize": model_args.get("batch", -1),
        "epochs": model_args.get("epochs", 300),
        "imageSize": model_args.get("imgsz", 640),
        "patience": model_args.get("patience", 100),
        "device": str(model_args.get("device", "")),  # convert None to string
        "cache": str(model_args.get("cache", "ram")),  # convert True, False, None to string
    }
    architecture_name = filename.replace(".pt", "").replace(".yaml", "")
    # Only weight files (*.pt) are recorded as the parent of the new model
    parent = {"name": filename} if filename.endswith(".pt") else {}

    self.model.create_model(
        {
            "config": config,
            "dataset": {"name": model_args.get("data")},
            "lineage": {"architecture": {"name": architecture_name}, "parent": parent},
            "meta": {"name": filename},
        }
    )

    # Model could not be created
    # TODO: improve error handling
    if not self.model.id:
        return None

    self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}"

    # Start heartbeats for HUB to monitor agent
    self.model.start_heartbeat(self.rate_limits["heartbeat"])

    LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀")

create_session classmethod

create_session(identifier, args=None)

Class method to create an authenticated HUBTrainingSession or return None.

Source code in ultralytics/hub/session.py
@classmethod
def create_session(cls, identifier, args=None):
    """
    Create a HUBTrainingSession for the identifier, or return None when that is not possible.

    Args:
        identifier (str): Model identifier passed to the class constructor.
        args (dict, optional): Training arguments. When given and the identifier is not a HUB model URL, a new
            HUB model is created from them.

    Returns:
        (HUBTrainingSession | None): The session, or None if a PermissionError, ModuleNotFoundError or
            AssertionError is raised, or if the newly created model has no id.
    """
    try:
        session = cls(identifier)
        if args and not identifier.startswith(f"{HUB_WEB_ROOT}/models/"):  # not a HUB model URL
            session.create_model(args)
            # Explicit check instead of `assert`, which is stripped when Python runs with -O
            if not session.model.id:
                return None
        return session
    # PermissionError and ModuleNotFoundError indicate hub-sdk not installed
    except (PermissionError, ModuleNotFoundError, AssertionError):
        return None

load_model

load_model(model_id)

Loads an existing model from Ultralytics HUB using the provided model identifier.

Source code in ultralytics/hub/session.py
def load_model(self, model_id):
    """
    Load an existing model from Ultralytics HUB and prepare the session for it.

    For a model that is already trained, the "best" weights are fetched via `checks.check_file` into
    `<weights_dir>/hub/<model id>` and stored in `self.model_file`. For any other model the training
    arguments are set up and the heartbeat is started.

    Args:
        model_id (str): Identifier of the HUB model to load.

    Raises:
        ValueError: If the HUB client returns no data for the model.
    """
    self.model = self.client.model(model_id)
    if not self.model.data:  # then model does not exist
        raise ValueError(emojis("❌ The specified HUB model does not exist"))  # TODO: improve error handling

    self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}"
    if self.model.is_trained():
        # Routed through LOGGER (not print) so it follows the same logging configuration as other messages
        LOGGER.info(emojis(f"Loading trained HUB model {self.model_url} 🚀"))
        url = self.model.get_weights_url("best")  # download URL with auth
        self.model_file = checks.check_file(url, download_dir=Path(SETTINGS["weights_dir"]) / "hub" / self.model.id)
        return

    # Set training args and start heartbeats for HUB to monitor agent
    self._set_train_args()
    self.model.start_heartbeat(self.rate_limits["heartbeat"])
    LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀")

request_queue

request_queue(
    request_func,
    retry=3,
    timeout=30,
    thread=True,
    verbose=True,
    progress_total=None,
    stream_response=None,
    *args,
    **kwargs
)

Attempts to execute request_func with retries, timeout handling, optional threading, and progress.

Source code in ultralytics/hub/session.py
def request_queue(
    self,
    request_func,
    retry=3,
    timeout=30,
    thread=True,
    verbose=True,
    progress_total=None,
    stream_response=None,
    *args,
    **kwargs,
):
    """
    Execute `request_func` with retries, a total time budget, optional threading, and optional progress display.

    Args:
        request_func (callable): Function performing the request; called as `request_func(*args, **kwargs)`.
        retry (int): Number of retries after the first attempt, so at most `retry + 1` calls are made.
        timeout (int): Time budget in seconds measured from the first attempt; checked before each attempt.
        thread (bool): If True, run the request loop in a daemon thread and return immediately.
        verbose (bool): If True, log the failure message produced for the first failed attempt.
        progress_total (int, optional): When truthy, the response is consumed by `_show_upload_progress` with
            this value as the total.
        stream_response (bool, optional): When truthy and `progress_total` is falsy, the response is consumed
            by `_iterate_content`.
        *args: Positional arguments forwarded to `request_func`.
        **kwargs: Keyword arguments forwarded to `request_func`. A truthy "metrics" entry marks a metrics
            upload: the failed-metrics queue is cleared on success and extended when no response was obtained.

    Returns:
        The last response returned by `request_func` (possibly None) when `thread` is False; None when the
        request runs in a background thread.
    """

    def retry_request():
        """Run the request loop and return the last response obtained (None if there was none)."""
        t0 = time.time()  # Record the start time for the timeout
        response = None
        for i in range(retry + 1):
            if (time.time() - t0) > timeout:
                LOGGER.warning(f"{PREFIX}Timeout for request reached. {HELP_MSG}")
                break  # Timeout reached, exit loop

            response = request_func(*args, **kwargs)
            if response is None:
                LOGGER.warning(f"{PREFIX}Received no response from the request. {HELP_MSG}")
                time.sleep(2**i)  # Exponential backoff before retrying
                continue  # Skip further processing and retry

            # Consume the response body before inspecting the status code
            if progress_total:
                self._show_upload_progress(progress_total, response)
            elif stream_response:
                self._iterate_content(response)

            if HTTPStatus.OK <= response.status_code < HTTPStatus.MULTIPLE_CHOICES:
                # if request related to metrics upload
                if kwargs.get("metrics"):
                    self.metrics_upload_failed_queue = {}
                return response  # Success, no need to retry

            if i == 0:
                # Initial attempt, check status code and provide messages
                message = self._get_failure_message(response, retry, timeout)

                if verbose:
                    LOGGER.warning(f"{PREFIX}{message} {HELP_MSG} ({response.status_code})")

            if not self._should_retry(response.status_code):
                LOGGER.warning(f"{PREFIX}Request failed. {HELP_MSG} ({response.status_code})")
                break  # Not an error that should be retried, exit loop

            time.sleep(2**i)  # Exponential backoff for retries

        # if request related to metrics upload and exceed retries
        # NOTE(review): metrics are only re-queued when no response was received at all, not when the final
        # response carried an error status - confirm this is intended
        if response is None and kwargs.get("metrics"):
            self.metrics_upload_failed_queue.update(kwargs.get("metrics"))

        return response

    if thread:
        # Start a new thread to run the retry_request function
        threading.Thread(target=retry_request, daemon=True).start()
    else:
        # If running in the main thread, call retry_request directly
        return retry_request()

upload_metrics

upload_metrics()

Upload model metrics to Ultralytics HUB.

Source code in ultralytics/hub/session.py
def upload_metrics(self):
    """Queue a threaded upload of a copy of the currently collected metrics to Ultralytics HUB."""
    dispatch = self.request_queue
    uploader = self.model.upload_metrics
    snapshot = self.metrics_queue.copy()  # copy so later changes to the queue do not affect this upload
    return dispatch(uploader, metrics=snapshot, thread=True)

upload_model

upload_model(
    epoch: int,
    weights: str,
    is_best: bool = False,
    map: float = 0.0,
    final: bool = False,
) -> None

Upload a model checkpoint to Ultralytics HUB.

Parameters:

Name Type Description Default
epoch int

The current training epoch.

required
weights str

Path to the model weights file.

required
is_best bool

Indicates if the current model is the best one so far.

False
map float

Mean average precision of the model.

0.0
final bool

Indicates if the model is the final model after training.

False
Source code in ultralytics/hub/session.py
def upload_model(
    self,
    epoch: int,
    weights: str,
    is_best: bool = False,
    map: float = 0.0,
    final: bool = False,
) -> None:
    """
    Send a model checkpoint to Ultralytics HUB through the request queue.

    If the weights file is missing, the upload is skipped with a warning, except for a final upload where a
    sibling 'last' file with the same suffix exists: that file is copied to the requested path and uploaded.

    Args:
        epoch (int): The current training epoch.
        weights (str): Path to the model weights file.
        is_best (bool): Indicates if the current model is the best one so far.
        map (float): Mean average precision of the model.
        final (bool): Indicates if the model is the final model after training.
    """
    weights = Path(weights)
    if not weights.is_file():
        fallback = weights.with_name(f"last{weights.suffix}")
        if not (final and fallback.is_file()):
            LOGGER.warning(f"{PREFIX} WARNING ⚠️ Model upload issue. Missing model {weights}.")
            return

        LOGGER.warning(
            f"{PREFIX} WARNING ⚠️ Model 'best.pt' not found, copying 'last.pt' to 'best.pt' and uploading. "
            "This often happens when resuming training in transient environments like Google Colab. "
            "For more reliable training, consider using Ultralytics HUB Cloud. "
            "Learn more at https://docs.ultralytics.com/hub/cloud-training."
        )
        shutil.copy(fallback, weights)  # copy last.pt to best.pt

    upload_options = {
        "epoch": epoch,
        "weights": str(weights),
        "is_best": is_best,
        "map": map,
        "final": final,
        "retry": 10,
        "timeout": 3600,
        "thread": not final,  # the final upload runs in the calling thread
        "progress_total": weights.stat().st_size if final else None,  # only show progress if final
        "stream_response": True,
    }
    self.request_queue(self.model.upload_model, **upload_options)



📅 Created 1 year ago ✏️ Updated 5 months ago