
Reference for ultralytics/models/rtdetr/train.py

Note

This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/rtdetr/train.py. If you spot a problem, please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!



ultralytics.models.rtdetr.train.RTDETRTrainer

Bases: DetectionTrainer

Trainer class for the RT-DETR model developed by Baidu for real-time object detection. Extends the DetectionTrainer class for YOLO to adapt to the specific features and architecture of RT-DETR. This model leverages Vision Transformers and has capabilities like IoU-aware query selection and adaptable inference speed.

Notes
  • F.grid_sample, which is used in RT-DETR, does not support the deterministic=True argument.
  • AMP training can lead to NaN outputs and may produce errors during bipartite graph matching.
Example
from ultralytics.models.rtdetr.train import RTDETRTrainer

args = dict(model='rtdetr-l.yaml', data='coco8.yaml', imgsz=640, epochs=3)
trainer = RTDETRTrainer(overrides=args)
trainer.train()
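The two notes above suggest turning off mixed precision and deterministic mode when training RT-DETR. The following is a minimal sketch of such a configuration, assuming that `amp` and `deterministic` are accepted as training overrides in the same way as `imgsz` and `epochs`. The helper names `rtdetr_overrides` and `train_rtdetr` are hypothetical and exist only for this illustration.

```python
def rtdetr_overrides(model="rtdetr-l.yaml", data="coco8.yaml", imgsz=640, epochs=3):
    """Hypothetical helper: build trainer overrides that sidestep both caveats."""
    return dict(
        model=model,
        data=data,
        imgsz=imgsz,
        epochs=epochs,
        amp=False,            # mixed precision may yield NaN outputs and matching errors
        deterministic=False,  # F.grid_sample does not support deterministic=True
    )


def train_rtdetr():
    """Hypothetical helper: start training (requires the ultralytics package)."""
    from ultralytics.models.rtdetr.train import RTDETRTrainer

    trainer = RTDETRTrainer(overrides=rtdetr_overrides())
    trainer.train()
    return trainer


overrides = rtdetr_overrides()
print(overrides)
```

Calling `train_rtdetr()` would start an actual training run, so it is left uncalled here. Whether disabling AMP is needed in practice likely depends on hardware and dataset.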
Source code in ultralytics/models/rtdetr/train.py
class RTDETRTrainer(DetectionTrainer):
    """
    Trainer class for the RT-DETR model developed by Baidu for real-time object detection. Extends the DetectionTrainer
    class for YOLO to adapt to the specific features and architecture of RT-DETR. This model leverages Vision
    Transformers and has capabilities like IoU-aware query selection and adaptable inference speed.

    Notes:
        - F.grid_sample used in RT-DETR does not support the `deterministic=True` argument.
        - AMP training can lead to NaN outputs and may produce errors during bipartite graph matching.

    Example:
        ```python
        from ultralytics.models.rtdetr.train import RTDETRTrainer

        args = dict(model='rtdetr-l.yaml', data='coco8.yaml', imgsz=640, epochs=3)
        trainer = RTDETRTrainer(overrides=args)
        trainer.train()
        ```
    """

    def get_model(self, cfg=None, weights=None, verbose=True):
        """
        Initialize and return an RT-DETR model for object detection tasks.

        Args:
            cfg (dict, optional): Model configuration. Defaults to None.
            weights (str, optional): Path to pre-trained model weights. Defaults to None.
            verbose (bool): Verbose logging if True. Defaults to True.

        Returns:
            (RTDETRDetectionModel): Initialized model.
        """
        model = RTDETRDetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
        if weights:
            model.load(weights)
        return model

    def build_dataset(self, img_path, mode="val", batch=None):
        """
        Build and return an RT-DETR dataset for training or validation.

        Args:
            img_path (str): Path to the folder containing images.
            mode (str): Dataset mode, either 'train' or 'val'.
            batch (int, optional): Batch size for rectangle training. Defaults to None.

        Returns:
            (RTDETRDataset): Dataset object for the specific mode.
        """
        return RTDETRDataset(
            img_path=img_path,
            imgsz=self.args.imgsz,
            batch_size=batch,
            augment=mode == "train",
            hyp=self.args,
            rect=False,
            cache=self.args.cache or None,
            prefix=colorstr(f"{mode}: "),
            data=self.data,
        )

    def get_validator(self):
        """
        Returns a DetectionValidator suitable for RT-DETR model validation.

        Returns:
            (RTDETRValidator): Validator object for model validation.
        """
        self.loss_names = "giou_loss", "cls_loss", "l1_loss"
        return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))

    def preprocess_batch(self, batch):
        """
        Preprocess a batch of images. Scales and converts the images to float format.

        Args:
            batch (dict): Dictionary containing a batch of images, bboxes, and labels.

        Returns:
            (dict): Preprocessed batch.
        """
        batch = super().preprocess_batch(batch)
        bs = len(batch["img"])
        batch_idx = batch["batch_idx"]
        gt_bbox, gt_class = [], []
        for i in range(bs):
            gt_bbox.append(batch["bboxes"][batch_idx == i].to(batch_idx.device))
            gt_class.append(batch["cls"][batch_idx == i].to(device=batch_idx.device, dtype=torch.long))
        return batch

build_dataset(img_path, mode='val', batch=None)

Build and return an RT-DETR dataset for training or validation.

Parameters:

Name      Type  Description                                 Default
img_path  str   Path to the folder containing the images.   required
mode      str   Dataset mode, either 'train' or 'val'.      'val'
batch     int   Batch size for rectangle training.          None

Returns:

Type           Description
RTDETRDataset  Dataset object for the specific mode.

Source code in ultralytics/models/rtdetr/train.py
def build_dataset(self, img_path, mode="val", batch=None):
    """
    Build and return an RT-DETR dataset for training or validation.

    Args:
        img_path (str): Path to the folder containing images.
        mode (str): Dataset mode, either 'train' or 'val'.
        batch (int, optional): Batch size for rectangle training. Defaults to None.

    Returns:
        (RTDETRDataset): Dataset object for the specific mode.
    """
    return RTDETRDataset(
        img_path=img_path,
        imgsz=self.args.imgsz,
        batch_size=batch,
        augment=mode == "train",
        hyp=self.args,
        rect=False,
        cache=self.args.cache or None,
        prefix=colorstr(f"{mode}: "),
        data=self.data,
    )
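The way `build_dataset` derives its options from `mode` and the trainer arguments can be sketched in plain Python. The function `dataset_options` below is a hypothetical stand-in that only mirrors the keyword arguments shown in the listing. It does not construct an `RTDETRDataset`, and it omits the `colorstr` formatting of the prefix.

```python
def dataset_options(mode, imgsz=640, cache=False, batch=None):
    """Hypothetical helper mirroring the keyword arguments passed to RTDETRDataset."""
    return {
        "imgsz": imgsz,
        "batch_size": batch,
        "augment": mode == "train",  # augmentation is enabled only for training
        "rect": False,               # rectangular batching is always disabled
        "cache": cache or None,      # any falsy cache setting is normalized to None
        "prefix": f"{mode}: ",       # the real code wraps this in colorstr()
    }


train_opts = dataset_options("train", batch=16)
val_opts = dataset_options("val", cache="ram")
print(train_opts)
print(val_opts)
```

Note that `rect` is fixed to `False` regardless of `mode`, so the `batch` argument is passed through as `batch_size` without enabling rectangular batching.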

get_model(cfg=None, weights=None, verbose=True)

Initialize and return an RT-DETR model for object detection tasks.

Parameters:

Name     Type  Description                         Default
cfg      dict  Model configuration.                None
weights  str   Path to pre-trained model weights.  None
verbose  bool  Verbose logging if True.            True

Returns:

Type                  Description
RTDETRDetectionModel  Initialized model.

Source code in ultralytics/models/rtdetr/train.py
def get_model(self, cfg=None, weights=None, verbose=True):
    """
    Initialize and return an RT-DETR model for object detection tasks.

    Args:
        cfg (dict, optional): Model configuration. Defaults to None.
        weights (str, optional): Path to pre-trained model weights. Defaults to None.
        verbose (bool): Verbose logging if True. Defaults to True.

    Returns:
        (RTDETRDetectionModel): Initialized model.
    """
    model = RTDETRDetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
    if weights:
        model.load(weights)
    return model
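The expression `verbose and RANK == -1` in the listing limits verbose model output to one situation. The sketch below illustrates that condition under two assumptions not stated in this page: that `RANK` is `-1` for a single-process run and a non-negative index for distributed workers, and that it can be read from a `RANK` environment variable. The function `should_log_verbose` is hypothetical.

```python
import os


def should_log_verbose(verbose=True, rank=None):
    """Hypothetical helper: True only when verbose is set and rank is -1."""
    if rank is None:
        # Assumption: the process rank comes from the RANK environment variable.
        rank = int(os.getenv("RANK", -1))
    return verbose and rank == -1


print(should_log_verbose(True, rank=-1))  # single-process run
print(should_log_verbose(True, rank=0))   # distributed worker 0
print(should_log_verbose(False, rank=-1)) # verbose switched off
```

Under these assumptions, distributed workers would build the model silently even when `verbose=True` is requested.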

get_validator()

Returns a DetectionValidator suitable for RT-DETR model validation.

Returns:

Type             Description
RTDETRValidator  Validator object for model validation.

Source code in ultralytics/models/rtdetr/train.py
def get_validator(self):
    """
    Returns a DetectionValidator suitable for RT-DETR model validation.

    Returns:
        (RTDETRValidator): Validator object for model validation.
    """
    self.loss_names = "giou_loss", "cls_loss", "l1_loss"
    return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))

preprocess_batch(batch)

Preprocess a batch of images. Scales and converts the images to float format.

Parameters:

Name   Type  Description                                                   Default
batch  dict  Dictionary containing a batch of images, bboxes, and labels.  required

Returns:

Type  Description
dict  Preprocessed batch.

Source code in ultralytics/models/rtdetr/train.py
def preprocess_batch(self, batch):
    """
    Preprocess a batch of images. Scales and converts the images to float format.

    Args:
        batch (dict): Dictionary containing a batch of images, bboxes, and labels.

    Returns:
        (dict): Preprocessed batch.
    """
    batch = super().preprocess_batch(batch)
    bs = len(batch["img"])
    batch_idx = batch["batch_idx"]
    gt_bbox, gt_class = [], []
    for i in range(bs):
        gt_bbox.append(batch["bboxes"][batch_idx == i].to(batch_idx.device))
        gt_class.append(batch["cls"][batch_idx == i].to(device=batch_idx.device, dtype=torch.long))
    return batch
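The loop in `preprocess_batch` groups the flat ground-truth arrays by image using `batch_idx`. The following is a plain-Python analogue of that grouping, using lists instead of tensors so it runs without PyTorch. The function `split_by_image` and the sample values are hypothetical. In the listing above, the resulting `gt_bbox` and `gt_class` lists are built but not attached to the returned batch.

```python
def split_by_image(batch_idx, bboxes, classes, batch_size):
    """Hypothetical helper: group flat boxes and class ids by image index."""
    gt_bbox, gt_class = [], []
    for i in range(batch_size):
        # List analogue of the boolean mask `batch_idx == i` applied to a tensor.
        gt_bbox.append([box for idx, box in zip(batch_idx, bboxes) if idx == i])
        # Class ids are cast to int, mirroring the dtype=torch.long conversion.
        gt_class.append([int(c) for idx, c in zip(batch_idx, classes) if idx == i])
    return gt_bbox, gt_class


# Three boxes spread over a batch of three images; image 1 has no labels.
batch_idx = [0, 0, 2]
bboxes = [[0.5, 0.5, 0.2, 0.2], [0.3, 0.3, 0.1, 0.1], [0.7, 0.7, 0.4, 0.4]]
classes = [1.0, 3.0, 0.0]

gt_bbox, gt_class = split_by_image(batch_idx, bboxes, classes, batch_size=3)
print(gt_class)
```

An image without labels produces an empty list, so the outputs always have one entry per image in the batch.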





Created 2023-11-12, updated 2024-06-02
Authors: glenn-jocher (5), Burhan-Q (1)