Reference for hub_sdk/modules/models.py
Improvements
This page is sourced from https://github.com/ultralytics/hub-sdk/blob/main/hub_sdk/modules/models.py. Have an improvement or example to add? Open a Pull Request — thank you! 🙏
Summary
Models._reconstruct_data · Models.get_data · Models.create_model · Models.is_resumable · Models.has_best_weights · Models.is_pretrained · Models.is_trained · Models.is_custom · Models.get_architecture · Models.get_dataset_url · Models.get_weights_url · Models.delete · Models.update · Models.get_metrics · Models.upload_model · Models.upload_metrics · Models.start_heartbeat · Models.stop_heartbeat · Models.export · Models.predict
class hub_sdk.modules.models.Models
Models(self, model_id: str | None = None, headers: dict[str, Any] | None = None)
Bases: CRUDClient
A class representing a client for interacting with Models through CRUD operations.
This class extends the CRUDClient class and provides specific methods for working with Models, including creating, retrieving, updating, and deleting model resources, as well as uploading model weights and metrics.
Args
| Name | Type | Description | Default |
|---|---|---|---|
model_id | str, optional | The unique identifier of the model. | None |
headers | Dict[str, Any], optional | Headers to be included in API requests. | None |
Attributes
| Name | Type | Description |
|---|---|---|
base_endpoint | str | The base endpoint URL for the API, set to "models". |
hub_client | ModelUpload | An instance of ModelUpload used for interacting with model uploads. |
id | str \| None | The unique identifier of the model, if available. |
data | Dict | A dictionary to store model data. |
metrics | List[Dict] \| None | Model metrics data, populated after retrieval. |
Methods
| Name | Description |
|---|---|
_reconstruct_data | Reconstruct format of model data supported by ultralytics. |
create_model | Create a new model with the provided data and set the model ID for the current instance. |
delete | Delete the model resource represented by this instance. |
export | Export model to specified format via Ultralytics HUB. |
get_architecture | Get the architecture name of the model. |
get_data | Retrieve data for the current model instance. |
get_dataset_url | Get the dataset URL associated with the model. |
get_metrics | Get metrics of the model. |
get_weights_url | Get the URL of the model weights. |
has_best_weights | Check if the model has best weights saved from previous training. |
is_custom | Check if the model is a custom model rather than a standard one. |
is_pretrained | Check if the model is pretrained with initial weights. |
is_resumable | Check if the model training can be resumed based on the presence of last weights. |
is_trained | Check if the model has completed training and is in 'trained' status. |
predict | Run prediction using the model via Ultralytics HUB. |
start_heartbeat | Start sending heartbeat signals to a remote hub server. |
stop_heartbeat | Stop sending heartbeat signals to a remote hub server. |
update | Update the model resource represented by this instance. |
upload_metrics | Upload model metrics to Ultralytics HUB. |
upload_model | Upload a model checkpoint to Ultralytics HUB. |
Notes
The 'id' attribute is set during initialization and can be used to uniquely identify a model. The 'data' attribute is used to store model data fetched from the API.
Source code in hub_sdk/modules/models.py
class Models(CRUDClient):
    """CRUD client for model resources.

    Builds on CRUDClient and pairs it with a ModelUpload helper, so a single object can create, read, update and
    delete a model as well as push weights and metrics for it.

    Attributes:
        base_endpoint (str): The base endpoint for the API, set to "models".
        hub_client (ModelUpload): ModelUpload instance built from the same headers as this client.
        id (str | None): The unique identifier of the model, if available.
        data (Dict): A dictionary holding the model data; starts out empty.
        metrics (List[Dict] | None): Model metrics data; None until it has been retrieved.

    Notes:
        The 'id' attribute is assigned during initialization. When an id is supplied, the model data is fetched
        immediately through get_data().
    """

    def __init__(self, model_id: str | None = None, headers: dict[str, Any] | None = None):
        """Initialize a Models instance.

        Args:
            model_id (str, optional): The unique identifier of the model.
            headers (Dict[str, Any], optional): Headers to be included in API requests.
        """
        self.base_endpoint = "models"
        super().__init__(self.base_endpoint, "model", headers)
        self.hub_client = ModelUpload(headers)
        self.id, self.data, self.metrics = model_id, {}, None

        # Nothing to fetch until the instance is bound to a concrete model.
        if not model_id:
            return
        self.get_data()
method hub_sdk.modules.models.Models._reconstruct_data
def _reconstruct_data(data: dict) -> dict
Reconstruct format of model data supported by ultralytics.
Args
| Name | Type | Description | Default |
|---|---|---|---|
data | Dict | Original model data dictionary. | required |
Returns
| Type | Description |
|---|---|
Dict | Reconstructed data format with reorganized configuration. |
Source code in hub_sdk/modules/models.py
@staticmethod
def _reconstruct_data(data: dict) -> dict:
    """Reconstruct format of model data supported by ultralytics.

    The top-level entries "batch_size", "epochs", "imgsz", "patience", "device" and "cache" are removed and
    regrouped under a nested "config" dictionary with the keys "batchSize", "epochs", "imageSize", "patience",
    "device" and "cache". Entries missing from the input are stored as None. The input dictionary itself is not
    modified; a shallow copy is reorganized and returned.

    Args:
        data (Dict): Original model data dictionary.

    Returns:
        (Dict): Reconstructed data format with reorganized configuration. A falsy input (for example an empty
            dictionary) is returned unchanged.
    """
    if not data:
        return data

    # Work on a shallow copy so the caller's dictionary is not altered as a side effect of the pops below.
    reshaped = dict(data)
    reshaped["config"] = {
        "batchSize": reshaped.pop("batch_size", None),
        "epochs": reshaped.pop("epochs", None),
        "imageSize": reshaped.pop("imgsz", None),
        "patience": reshaped.pop("patience", None),
        "device": reshaped.pop("device", None),
        "cache": reshaped.pop("cache", None),
    }
    return reshaped
method hub_sdk.modules.models.Models.create_model
def create_model(self, model_data: dict) -> None
Create a new model with the provided data and set the model ID for the current instance.
Args
| Name | Type | Description | Default |
|---|---|---|---|
model_data | Dict | A dictionary containing the data for creating the model. | required |
Source code in hub_sdk/modules/models.py
def create_model(self, model_data: dict) -> None:
    """Create a model from the given data and bind this instance to the new model's ID.

    On success the ID is read from the "data" -> "id" entry of the response JSON, stored on the instance, and the
    model data is then loaded via get_data(). Every failure is logged through self.logger and never raised.

    Args:
        model_data (Dict): A dictionary containing the data for creating the model.
    """
    try:
        server_reply = super().create(model_data)

        if server_reply is None:
            self.logger.error("Received no response from the server while creating the model.")
        elif not hasattr(server_reply, "json"):
            # Anything without a .json() method cannot be decoded below.
            self.logger.error("Invalid response object received while creating the model.")
        else:
            payload = server_reply.json()
            if payload is None:
                self.logger.error("No data received in the response while creating the model.")
                return

            self.id = payload.get("data", {}).get("id")
            if self.id:
                self.get_data()
            else:
                self.logger.error("Model ID not found in the response data.")
    except Exception as e:
        self.logger.error(f"An error occurred while creating the model: {e!s}")
method hub_sdk.modules.models.Models.delete
def delete(self, hard: bool = False) -> Response | None
Delete the model resource represented by this instance.
Args
| Name | Type | Description | Default |
|---|---|---|---|
hard | bool, optional | If True, perform a hard (permanent) delete. | False |
Returns
| Type | Description |
|---|---|
Optional[Response] | Response object from the delete request, or None if delete fails. |
Notes
The 'hard' parameter determines whether to perform a soft delete (default) or a hard delete. In a soft delete, the model might be marked as deleted but retained in the system. In a hard delete, the model is permanently removed from the system.
Source code in hub_sdk/modules/models.py
def delete(self, hard: bool = False) -> Response | None:
    """Delete the model identified by this instance's ID.

    Args:
        hard (bool, optional): If True, perform a hard (permanent) delete.

    Returns:
        (Optional[Response]): Whatever the parent class's delete() returns for this model ID and 'hard' flag.

    Notes:
        The 'hard' flag is forwarded unchanged to the parent class; with the default (False) a soft delete is
        requested, with True a hard delete.
    """
    response = super().delete(self.id, hard)
    return response
method hub_sdk.modules.models.Models.export
def export(self, format: str) -> Response | None
Export model to specified format via Ultralytics HUB.
Args
| Name | Type | Description | Default |
|---|---|---|---|
format | str | Export format. Supported formats are available at https://docs.ultralytics.com/modes/export/#export-formats | required |
Returns
| Type | Description |
|---|---|
Optional[Response] | Response object from the export request, or None if export fails. |
Source code in hub_sdk/modules/models.py
def export(self, format: str) -> Response | None:
    """Request an export of this model in the given format.

    Args:
        format (str): Export format. Supported formats are available at
            https://docs.ultralytics.com/modes/export/#export-formats

    Returns:
        (Optional[Response]): The value returned by hub_client.export() for this model ID and format.
    """
    response = self.hub_client.export(self.id, format)
    return response
method hub_sdk.modules.models.Models.get_architecture
def get_architecture(self) -> str | None
Get the architecture name of the model.
Returns
| Type | Description |
|---|---|
Optional[str] | The architecture configuration path or None if not available. |
Source code in hub_sdk/modules/models.py
def get_architecture(self) -> str | None:
    """Return the architecture entry of the model data.

    Returns:
        (Optional[str]): The value stored under the "cfg" key of self.data, or None if the key is absent.
    """
    architecture = self.data.get("cfg")
    return architecture
method hub_sdk.modules.models.Models.get_data
def get_data(self) -> None
Retrieve data for the current model instance.
Fetches model data from the API if a valid model ID has been set and stores it in the instance. Logs appropriate error messages if the retrieval fails at any step.
Source code in hub_sdk/modules/models.py
def get_data(self) -> None:
    """Retrieve data for the current model instance.

    Reads the model through the parent class's read() using self.id, takes the "data" entry of the response JSON,
    reorganizes it with _reconstruct_data() and stores the result in self.data. Nothing is raised: a missing ID,
    a missing or invalid response, an empty JSON body and any exception are all reported through self.logger,
    and self.data is left as it was.
    """
    if not self.id:
        self.logger.error("No model id has been set. Update the model id or create a model.")
        return

    try:
        response = super().read(self.id)

        if response is None:
            self.logger.error(f"Received no response from the server for model ID: {self.id}")
            return

        # Check if the response has a .json() method (it should if it's a response object)
        if not hasattr(response, "json"):
            self.logger.error(f"Invalid response object received for model ID: {self.id}")
            return

        resp_data = response.json()
        if resp_data is None:
            self.logger.error(f"No data received in the response for model ID: {self.id}")
            return

        # The "data" key may be present with a null value; fall back to an empty dict so self.data always
        # stays a dictionary and the accessor methods can keep calling self.data.get(...).
        data = resp_data.get("data") or {}
        self.data = self._reconstruct_data(data)
        self.logger.debug(f"Model data retrieved for ID: {self.id}")
    except Exception as e:
        self.logger.error(f"An error occurred while retrieving data for model ID: {self.id}, {e!s}")
method hub_sdk.modules.models.Models.get_dataset_url
def get_dataset_url(self) -> str | None
Get the dataset URL associated with the model.
Returns
| Type | Description |
|---|---|
Optional[str] | The URL of the dataset or None if not available. |
Source code in hub_sdk/modules/models.py
def get_dataset_url(self) -> str | None:
    """Return the dataset entry of the model data.

    Returns:
        (Optional[str]): The value stored under the "data" key of self.data, or None if the key is absent.
    """
    dataset_url = self.data.get("data")
    return dataset_url
method hub_sdk.modules.models.Models.get_metrics
def get_metrics(self) -> list[dict[str, Any]] | None
Get metrics of the model.
Returns
| Type | Description |
|---|---|
Optional[List[Dict[str, Any]]] | The list of metrics objects, or None if retrieval fails. |
Source code in hub_sdk/modules/models.py
def get_metrics(self) -> list[dict[str, Any]] | None:
    """Get metrics of the model.

    Metrics already stored on the instance are returned directly. Otherwise they are requested from the
    "<HUB_API_ROOT>/v1/<base_endpoint>/<id>/metrics" endpoint, and the "data" entry of the response JSON is
    cached in self.metrics. An empty cached value counts as "not loaded" and triggers a new request.

    Returns:
        (Optional[List[Dict[str, Any]]]): The list of metrics objects, or None if no model ID is set or
            retrieval fails.
    """
    if self.metrics:
        return self.metrics

    # Without an ID the endpoint would be built as ".../models/None/metrics"; refuse early, as get_data() does.
    if not self.id:
        self.logger.error("No model id has been set. Update the model id or create a model.")
        return None

    endpoint = f"{HUB_API_ROOT}/v1/{self.base_endpoint}/{self.id}/metrics"
    try:
        results = self.get(endpoint)
        self.metrics = results.json().get("data")
        return self.metrics
    except Exception as e:
        self.logger.error(f"Model Metrics not found: {e}")
        return None
method hub_sdk.modules.models.Models.get_weights_url
def get_weights_url(self, weight: str = "best") -> str | None
Get the URL of the model weights.
Args
| Name | Type | Description | Default |
|---|---|---|---|
weight | str, optional | Type of weights to retrieve, either "best" or "last". | "best" |
Returns
| Type | Description |
|---|---|
Optional[str] | The URL of the specified weights or None if not available. |
Source code in hub_sdk/modules/models.py
def get_weights_url(self, weight: str = "best") -> str | None:
    """Return the weights entry of the model data for the requested kind of weights.

    Args:
        weight (str, optional): Type of weights to retrieve. "last" selects the "resume" entry of the model
            data; any other value selects the "weights" entry.

    Returns:
        (Optional[str]): The selected entry of self.data, or None if it is absent.
    """
    data_key = "resume" if weight == "last" else "weights"
    return self.data.get(data_key)
method hub_sdk.modules.models.Models.has_best_weights
def has_best_weights(self) -> bool
Check if the model has best weights saved from previous training.
Source code in hub_sdk/modules/models.py
def has_best_weights(self) -> bool:
    """Return the "has_best_weights" entry of the model data, or False when the entry is absent."""
    flag = self.data.get("has_best_weights", False)
    return flag
method hub_sdk.modules.models.Models.is_custom
def is_custom(self) -> bool
Check if the model is a custom model rather than a standard one.
Source code in hub_sdk/modules/models.py
def is_custom(self) -> bool:
    """Return the "is_custom" entry of the model data, or False when the entry is absent."""
    flag = self.data.get("is_custom", False)
    return flag
method hub_sdk.modules.models.Models.is_pretrained
def is_pretrained(self) -> bool
Check if the model is pretrained with initial weights.
Source code in hub_sdk/modules/models.py
def is_pretrained(self) -> bool:
    """Return the "is_pretrained" entry of the model data, or False when the entry is absent."""
    flag = self.data.get("is_pretrained", False)
    return flag
method hub_sdk.modules.models.Models.is_resumable
def is_resumable(self) -> bool
Check if the model training can be resumed based on the presence of last weights.
Source code in hub_sdk/modules/models.py
def is_resumable(self) -> bool:
    """Return the "has_last_weights" entry of the model data, or False when the entry is absent."""
    has_last = self.data.get("has_last_weights", False)
    return has_last
method hub_sdk.modules.models.Models.is_trained
def is_trained(self) -> bool
Check if the model has completed training and is in 'trained' status.
Source code in hub_sdk/modules/models.py
def is_trained(self) -> bool:
    """Return True when the "status" entry of the model data equals "trained"."""
    status = self.data.get("status")
    return status == "trained"
method hub_sdk.modules.models.Models.predict
def predict(self, image: str, config: dict[str, Any]) -> Response | None
Run prediction using the model via Ultralytics HUB.
Args
| Name | Type | Description | Default |
|---|---|---|---|
image | str | The path to the image file. | required |
config | Dict[str, Any] | A configuration dictionary for the prediction. | required |
Returns
| Type | Description |
|---|---|
Optional[Response] | Response object from the predict request, or None if prediction fails. |
Source code in hub_sdk/modules/models.py
def predict(self, image: str, config: dict[str, Any]) -> Response | None:
    """Request a prediction from this model for the given image.

    Args:
        image (str): The path to the image file.
        config (Dict[str, Any]): A configuration dictionary for the prediction.

    Returns:
        (Optional[Response]): The value returned by hub_client.predict() for this model ID, image and config.
    """
    response = self.hub_client.predict(self.id, image, config)
    return response
method hub_sdk.modules.models.Models.start_heartbeat
def start_heartbeat(self, interval: int = 60)
Start sending heartbeat signals to a remote hub server.
This method initiates the sending of heartbeat signals to a hub server to indicate the continued availability and health of the client.
Args
| Name | Type | Description | Default |
|---|---|---|---|
interval | int, optional | The time interval, in seconds, between consecutive heartbeats. | 60 |
Notes
Heartbeats are essential for maintaining a connection with the hub server and ensuring that the client remains active and responsive.
Source code in hub_sdk/modules/models.py
def start_heartbeat(self, interval: int = 60) -> None:
    """Start sending heartbeat signals to a remote hub server.

    Signal handlers are registered on the upload client first, then its heartbeat loop is started for this
    model's ID with the requested interval.

    Args:
        interval (int, optional): The time interval, in seconds, between consecutive heartbeats.

    Notes:
        Heartbeats are essential for maintaining a connection with the hub server
        and ensuring that the client remains active and responsive.
    """
    self.hub_client._register_signal_handlers()
    self.hub_client._start_heartbeats(self.id, interval)
method hub_sdk.modules.models.Models.stop_heartbeat
def stop_heartbeat(self) -> None
Stop sending heartbeat signals to a remote hub server.
This method terminates the sending of heartbeat signals to the hub server, effectively signaling that the client is no longer available or active.
Notes
Stopping heartbeats should be done carefully, as it may result in the hub server considering the client as disconnected or unavailable.
Source code in hub_sdk/modules/models.py
def stop_heartbeat(self) -> None:
    """Stop sending heartbeat signals to a remote hub server.

    Delegates to the upload client's _stop_heartbeats().

    Notes:
        Stopping heartbeats should be done carefully, as it may result in the hub server
        considering the client as disconnected or unavailable.
    """
    upload_client = self.hub_client
    upload_client._stop_heartbeats()
method hub_sdk.modules.models.Models.update
def update(self, data: dict) -> Response | None
Update the model resource represented by this instance.
Args
| Name | Type | Description | Default |
|---|---|---|---|
data | Dict | The updated data for the model resource. | required |
Returns
| Type | Description |
|---|---|
Optional[Response] | Response object from the update request, or None if update fails. |
Source code in hub_sdk/modules/models.py
def update(self, data: dict) -> Response | None:
    """Update the model identified by this instance's ID.

    Args:
        data (Dict): The updated data for the model resource.

    Returns:
        (Optional[Response]): Whatever the parent class's update() returns for this model ID and data.
    """
    response = super().update(self.id, data)
    return response
method hub_sdk.modules.models.Models.upload_metrics
def upload_metrics(self, metrics: dict) -> Response | None
Upload model metrics to Ultralytics HUB.
Args
| Name | Type | Description | Default |
|---|---|---|---|
metrics | Dict | Dictionary containing model metrics data. | required |
Returns
| Type | Description |
|---|---|
Optional[Response] | Response object from the upload metrics request, or None if it fails. |
Source code in hub_sdk/modules/models.py
def upload_metrics(self, metrics: dict) -> Response | None:
    """Send metrics for this model through the upload client.

    Args:
        metrics (Dict): Dictionary containing model metrics data.

    Returns:
        (Optional[Response]): The value returned by hub_client.upload_metrics() for this model ID and metrics.
    """
    response = self.hub_client.upload_metrics(self.id, metrics)
    return response
method hub_sdk.modules.models.Models.upload_model
def upload_model(
self,
epoch: int,
weights: str,
is_best: bool = False,
map: float = 0.0,
final: bool = False,
) -> Response | None
Upload a model checkpoint to Ultralytics HUB.
Args
| Name | Type | Description | Default |
|---|---|---|---|
epoch | int | The current training epoch. | required |
weights | str | Path to the model weights file. | required |
is_best | bool, optional | Indicates if the current model is the best one so far. | False |
map | float, optional | Mean average precision of the model. | 0.0 |
final | bool, optional | Indicates if the model is the final model after training. | False |
Returns
| Type | Description |
|---|---|
Optional[Response] | Response object from the upload request, or None if upload fails. |
Source code in hub_sdk/modules/models.py
def upload_model(
    self,
    epoch: int,
    weights: str,
    is_best: bool = False,
    map: float = 0.0,
    final: bool = False,
) -> Response | None:
    """Send a model checkpoint for this model through the upload client.

    Args:
        epoch (int): The current training epoch.
        weights (str): Path to the model weights file.
        is_best (bool, optional): Indicates if the current model is the best one so far.
        map (float, optional): Mean average precision of the model.
        final (bool, optional): Indicates if the model is the final model after training.

    Returns:
        (Optional[Response]): The value returned by hub_client.upload_model(); the model ID, epoch and weights
            are passed positionally, the remaining options as keyword arguments.
    """
    checkpoint_options = {"is_best": is_best, "map": map, "final": final}
    return self.hub_client.upload_model(self.id, epoch, weights, **checkpoint_options)
class hub_sdk.modules.models.ModelList
ModelList(self, page_size = None, public = None, headers = None)
Bases: PaginatedList
Provides a paginated list interface for managing and querying models from the Ultralytics HUB API.
Args
| Name | Type | Description | Default |
|---|---|---|---|
page_size | int, optional | The number of items to request per page. | None |
public | bool, optional | Whether the items should be publicly accessible. | None |
headers | Dict, optional | Headers to be included in API requests. | None |
Source code in hub_sdk/modules/models.py
class ModelList(PaginatedList):
    """Provides a paginated list interface for managing and querying models from the Ultralytics HUB API."""

    def __init__(
        self,
        page_size: int | None = None,
        public: bool | None = None,
        headers: dict[str, Any] | None = None,
    ):
        """Initialize a ModelList instance.

        All arguments are forwarded unchanged to PaginatedList together with the "models" endpoint and the
        "model" resource name.

        Args:
            page_size (int, optional): The number of items to request per page.
            public (bool, optional): Whether the items should be publicly accessible.
            headers (Dict, optional): Headers to be included in API requests.
        """
        base_endpoint = "models"
        super().__init__(base_endpoint, "model", page_size, public, headers)