Skip to content

Reference for ultralytics/models/nas/model.py

Improvements

This page is sourced from https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/nas/model.py. Have an improvement or example to add? Open a Pull Request — thank you! 🙏


Summary

class ultralytics.models.nas.model.NAS

NAS(self, model: str = "yolo_nas_s.pt") -> None

Bases: Model

YOLO-NAS model for object detection.

This class provides an interface for the YOLO-NAS models and extends the Model class from the Ultralytics engine. It is designed to facilitate the task of object detection using pre-trained or custom-trained YOLO-NAS models.

Args

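| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `str` | Path to the pre-trained model weights or a model name (YAML configuration files are not supported). | `"yolo_nas_s.pt"` |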

Attributes

| Name | Type | Description |
| --- | --- | --- |
| `model` | `torch.nn.Module` | The loaded YOLO-NAS model. |
| `task` | `str` | The task type for the model, defaults to `'detect'`. |
| `predictor` | `NASPredictor` | The predictor instance for making predictions. |
| `validator` | `NASValidator` | The validator instance for model validation. |

Methods

| Name | Description |
| --- | --- |
| `task_map` | Return a dictionary mapping tasks to respective predictor and validator classes. |
| `_load` | Load existing NAS model weights or create a new NAS model with pretrained weights. |
| `info` | Log model information. |

Examples

>>> from ultralytics import NAS
>>> model = NAS("yolo_nas_s")
>>> results = model.predict("ultralytics/assets/bus.jpg")

Notes

YOLO-NAS models only support pre-trained models. Do not provide YAML configuration files.

Source code in `ultralytics/models/nas/model.py` · View on GitHub
class NAS(Model):
    """YOLO-NAS model for object detection.

    This class provides an interface for the YOLO-NAS models and extends the `Model` class from the Ultralytics engine.
    It is designed to facilitate the task of object detection using pre-trained or custom-trained YOLO-NAS models.

    Attributes:
        model (torch.nn.Module): The loaded YOLO-NAS model.
        task (str): The task type for the model, defaults to 'detect'.
        predictor (NASPredictor): The predictor instance for making predictions.
        validator (NASValidator): The validator instance for model validation.

    Methods:
        task_map: Return a dictionary mapping tasks to respective predictor and validator classes.
        info: Log model information and return model details.

    Examples:
        >>> from ultralytics import NAS
        >>> model = NAS("yolo_nas_s")
        >>> results = model.predict("ultralytics/assets/bus.jpg")

    Notes:
        YOLO-NAS models only support pre-trained models. Do not provide YAML configuration files.
    """

    def __init__(self, model: str = "yolo_nas_s.pt") -> None:
        """Initialize the NAS model with the provided or default model.

        Args:
            model (str): Path to pre-trained weights or a model name such as ``"yolo_nas_s"``. YAML configuration
                files are rejected.

        Raises:
            AssertionError: If ``model`` has a ``.yaml``/``.yml`` suffix (compared case-insensitively).
        """
        # Explicit check instead of ``assert`` so the guard is not stripped under ``python -O``. AssertionError is
        # kept as the exception type for backward compatibility with callers that already catch it.
        if Path(model).suffix.lower() in {".yaml", ".yml"}:
            raise AssertionError("YOLO-NAS models only support pre-trained models.")
        super().__init__(model, task="detect")


property ultralytics.models.nas.model.NAS.task_map

def task_map(self) -> dict[str, dict[str, Any]]

Return a dictionary mapping tasks to respective predictor and validator classes.

Source code in `ultralytics/models/nas/model.py` · View on GitHub
@property
def task_map(self) -> dict[str, dict[str, Any]]:
    """Map each supported task name to the predictor and validator classes used for it.

    The mapping contains a single ``"detect"`` entry.

    Returns:
        (dict[str, dict[str, Any]]): ``{"detect": {"predictor": NASPredictor, "validator": NASValidator}}``.
    """
    detect_components = {"predictor": NASPredictor, "validator": NASValidator}
    return {"detect": detect_components}


method ultralytics.models.nas.model.NAS._load

def _load(self, weights: str, task=None) -> None

Load existing NAS model weights or create a new NAS model with pretrained weights.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `weights` | `str` | Path to the model weights file or model name. | *required* |
| `task` | `str`, optional | Task type for the model. | `None` |
Source code in `ultralytics/models/nas/model.py` · View on GitHub
def _load(self, weights: str, task=None) -> None:
    """Load existing NAS model weights or create a new NAS model with pretrained weights.

    A ``.pt`` path is loaded with ``torch_load`` after being resolved through ``attempt_download_asset``. A bare model
    name with no suffix (e.g. ``"yolo_nas_s"``) is built by ``super_gradients`` with COCO-pretrained weights. The
    loaded module is then patched with the attributes the rest of the engine reads (``fuse``, ``stride``, ``names``,
    ``task``, ``args``, ...).

    Args:
        weights (str): Path to the model weights file or model name.
        task (str, optional): Task type for the model. Unused here; the model task is always set to ``"detect"``.

    Raises:
        ValueError: If ``weights`` has a suffix other than ``.pt`` (case-insensitive) and is not a bare model name.
    """
    import types

    import super_gradients

    suffix = Path(weights).suffix.lower()
    if suffix == ".pt":
        self.model = torch_load(attempt_download_asset(weights))
    elif suffix == "":
        self.model = super_gradients.training.models.get(weights, pretrained_weights="coco")
    else:
        # Fail fast with a clear message instead of patching an unset/stale ``self.model`` below.
        raise ValueError(
            f"Unsupported YOLO-NAS weights '{weights}': expected a '.pt' file or a model name without a suffix."
        )

    model = self.model

    # Override the forward method to ignore additional arguments. The wrapper and ``fuse`` are bound to the module
    # itself (rather than closing over this wrapper's ``self``) so that if the module is deep-copied, the copy runs
    # and returns itself instead of silently delegating to the original module.
    model._original_forward = model.forward

    def new_forward(module, x, *args, **kwargs):
        """Ignore additional __call__ arguments and run the original forward on ``x`` only."""
        return module._original_forward(x)

    model.forward = types.MethodType(new_forward, model)
    model.fuse = types.MethodType(lambda module, verbose=True: module, model)  # no-op fuse, returns the module

    # Standardize model attributes for compatibility
    model.stride = torch.tensor([32])
    model.names = dict(enumerate(model._class_names))
    model.is_fused = lambda: False  # for info()
    model.yaml = {}  # for info()
    model.pt_path = weights  # for export()
    model.task = "detect"  # for export()
    model.args = {**DEFAULT_CFG_DICT, **self.overrides}  # for export()
    model.eval()


method ultralytics.models.nas.model.NAS.info

def info(self, detailed: bool = False, verbose: bool = True) -> dict[str, Any]

Log model information.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `detailed` | `bool` | Show detailed information about the model. | `False` |
| `verbose` | `bool` | Controls verbosity. | `True` |

Returns

| Type | Description |
| --- | --- |
| `dict[str, Any]` | Model information dictionary. |
Source code in `ultralytics/models/nas/model.py` · View on GitHub
def info(self, detailed: bool = False, verbose: bool = True) -> dict[str, Any]:
    """Log a summary of the wrapped YOLO-NAS model and return it.

    Args:
        detailed (bool): Show detailed information about the model.
        verbose (bool): Controls verbosity.

    Returns:
        (dict[str, Any]): Model information dictionary produced by ``model_info``.
    """
    # The summary is always computed with an image size of 640.
    summary_options = {"detailed": detailed, "verbose": verbose, "imgsz": 640}
    return model_info(self.model, **summary_options)





📅 Created 2 years ago ✏️ Updated 2 days ago
Contributors: glenn-jocher, jk4e, Burhan-Q