Skip to content

Référence pour ultralytics/models/rtdetr/train.py

Note

Ce fichier est disponible à l'adresse https://github.com/ultralytics/ ultralytics/blob/main/ ultralytics/models/rtdetr/train .py. Si tu repères un problème, aide à le corriger en contribuant à une Pull Request 🛠️. Merci 🙏 !



ultralytics.models.rtdetr.train.RTDETRTrainer

Bases : DetectionTrainer

Classe d'entraînement pour le modèle RT-DETR développé par Baidu pour la détection d'objets en temps réel. Étend la classe DetectionTrainer pour YOLO afin de l'adapter aux caractéristiques et à l'architecture spécifiques de RT-DETR. Ce modèle s'appuie sur Vision Transformers et possède des capacités telles que la sélection de requête consciente de l'interface utilisateur et la vitesse d'inférence adaptable.

Notes
  • F.grid_sample utilisĂ© dans RT-DETR ne prend pas en charge la fonction deterministic=True argument.
  • La formation AMP peut conduire Ă  des sorties NaN et peut produire des erreurs lors de l'appariement des graphes bipartites.
Exemple
from ultralytics.models.rtdetr.train import RTDETRTrainer

args = dict(model='rtdetr-l.yaml', data='coco8.yaml', imgsz=640, epochs=3)
trainer = RTDETRTrainer(overrides=args)
trainer.train()
Code source dans ultralytics/models/rtdetr/train.py
class RTDETRTrainer(DetectionTrainer):
    """
    Trainer class for the RT-DETR model developed by Baidu for real-time object detection. Extends the DetectionTrainer
    class for YOLO to adapt to the specific features and architecture of RT-DETR. This model leverages Vision
    Transformers and has capabilities like IoU-aware query selection and adaptable inference speed.

    Notes:
        - F.grid_sample used in RT-DETR does not support the `deterministic=True` argument.
        - AMP training can lead to NaN outputs and may produce errors during bipartite graph matching.

    Example:
        ```python
        from ultralytics.models.rtdetr.train import RTDETRTrainer

        args = dict(model='rtdetr-l.yaml', data='coco8.yaml', imgsz=640, epochs=3)
        trainer = RTDETRTrainer(overrides=args)
        trainer.train()
        ```
    """

    def get_model(self, cfg=None, weights=None, verbose=True):
        """
        Initialize and return an RT-DETR model for object detection tasks.

        Args:
            cfg (dict, optional): Model configuration. Defaults to None.
            weights (str, optional): Path to pre-trained model weights. Defaults to None.
            verbose (bool): Verbose logging if True. Defaults to True.

        Returns:
            (RTDETRDetectionModel): Initialized model.
        """
        model = RTDETRDetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
        if weights:
            model.load(weights)
        return model

    def build_dataset(self, img_path, mode="val", batch=None):
        """
        Build and return an RT-DETR dataset for training or validation.

        Args:
            img_path (str): Path to the folder containing images.
            mode (str): Dataset mode, either 'train' or 'val'.
            batch (int, optional): Batch size for rectangle training. Defaults to None.

        Returns:
            (RTDETRDataset): Dataset object for the specific mode.
        """
        return RTDETRDataset(
            img_path=img_path,
            imgsz=self.args.imgsz,
            batch_size=batch,
            augment=mode == "train",
            hyp=self.args,
            rect=False,
            cache=self.args.cache or None,
            prefix=colorstr(f"{mode}: "),
            data=self.data,
        )

    def get_validator(self):
        """
        Returns a DetectionValidator suitable for RT-DETR model validation.

        Returns:
            (RTDETRValidator): Validator object for model validation.
        """
        self.loss_names = "giou_loss", "cls_loss", "l1_loss"
        return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))

    def preprocess_batch(self, batch):
        """
        Preprocess a batch of images. Scales and converts the images to float format.

        Args:
            batch (dict): Dictionary containing a batch of images, bboxes, and labels.

        Returns:
            (dict): Preprocessed batch.
        """
        batch = super().preprocess_batch(batch)
        bs = len(batch["img"])
        batch_idx = batch["batch_idx"]
        gt_bbox, gt_class = [], []
        for i in range(bs):
            gt_bbox.append(batch["bboxes"][batch_idx == i].to(batch_idx.device))
            gt_class.append(batch["cls"][batch_idx == i].to(device=batch_idx.device, dtype=torch.long))
        return batch

build_dataset(img_path, mode='val', batch=None)

Construis et renvoie un ensemble de données RT-DETR pour la formation ou la validation.

Paramètres :

Nom Type Description DĂ©faut
img_path str

Chemin d'accès au dossier contenant les images.

requis
mode str

Mode de l'ensemble de données, soit "train", soit "val".

'val'
batch int

Taille du lot pour la formation des rectangles. La valeur par défaut est Aucun.

None

Retourne :

Type Description
RTDETRDataset

Objet de jeu de données pour le mode spécifique.

Code source dans ultralytics/models/rtdetr/train.py
def build_dataset(self, img_path, mode="val", batch=None):
    """
    Build and return an RT-DETR dataset for training or validation.

    Args:
        img_path (str): Path to the folder containing images.
        mode (str): Dataset mode, either 'train' or 'val'.
        batch (int, optional): Batch size for rectangle training. Defaults to None.

    Returns:
        (RTDETRDataset): Dataset object for the specific mode.
    """
    return RTDETRDataset(
        img_path=img_path,
        imgsz=self.args.imgsz,
        batch_size=batch,
        augment=mode == "train",
        hyp=self.args,
        rect=False,
        cache=self.args.cache or None,
        prefix=colorstr(f"{mode}: "),
        data=self.data,
    )

get_model(cfg=None, weights=None, verbose=True)

Initialise et renvoie un modèle RT-DETR pour les tâches de détection d'objets.

Paramètres :

Nom Type Description DĂ©faut
cfg dict

Configuration du modèle. La valeur par défaut est Aucun.

None
weights str

Chemin d'accès aux poids du modèle pré-entraîné. La valeur par défaut est Aucun.

None
verbose bool

Journalisation verbeuse si True. La valeur par défaut est True.

True

Retourne :

Type Description
RTDETRDetectionModel

Modèle initialisé.

Code source dans ultralytics/models/rtdetr/train.py
def get_model(self, cfg=None, weights=None, verbose=True):
    """
    Initialize and return an RT-DETR model for object detection tasks.

    Args:
        cfg (dict, optional): Model configuration. Defaults to None.
        weights (str, optional): Path to pre-trained model weights. Defaults to None.
        verbose (bool): Verbose logging if True. Defaults to True.

    Returns:
        (RTDETRDetectionModel): Initialized model.
    """
    model = RTDETRDetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
    if weights:
        model.load(weights)
    return model

get_validator()

Renvoie un DetectionValidator adapté à la validation du modèle RT-DETR .

Retourne :

Type Description
RTDETRValidator

Objet validateur pour la validation du modèle.

Code source dans ultralytics/models/rtdetr/train.py
def get_validator(self):
    """
    Returns a DetectionValidator suitable for RT-DETR model validation.

    Returns:
        (RTDETRValidator): Validator object for model validation.
    """
    self.loss_names = "giou_loss", "cls_loss", "l1_loss"
    return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))

preprocess_batch(batch)

Prétraite un lot d'images. Met les images à l'échelle et les convertit au format flottant.

Paramètres :

Nom Type Description DĂ©faut
batch dict

Dictionnaire contenant un lot d'images, de bboxes et d'Ă©tiquettes.

requis

Retourne :

Type Description
dict

Lot prétraité.

Code source dans ultralytics/models/rtdetr/train.py
def preprocess_batch(self, batch):
    """
    Preprocess a batch of images. Scales and converts the images to float format.

    Args:
        batch (dict): Dictionary containing a batch of images, bboxes, and labels.

    Returns:
        (dict): Preprocessed batch.
    """
    batch = super().preprocess_batch(batch)
    bs = len(batch["img"])
    batch_idx = batch["batch_idx"]
    gt_bbox, gt_class = [], []
    for i in range(bs):
        gt_bbox.append(batch["bboxes"][batch_idx == i].to(batch_idx.device))
        gt_class.append(batch["cls"][batch_idx == i].to(device=batch_idx.device, dtype=torch.long))
    return batch





Créé le 2023-11-12, Mis à jour le 2023-11-25
Auteurs : glenn-jocher (3)