Skip to content

Reference for ultralytics/models/rtdetr/train.py

Note

Full source code for this file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/rtdetr/train.py. Help us fix any issues you see by submitting a Pull Request 🛠️. Thank you 🙏!


ultralytics.models.rtdetr.train.RTDETRTrainer

Bases: DetectionTrainer

A class extending the DetectionTrainer class for training based on an RT-DETR detection model.

Notes
  • F.grid_sample used in rt-detr does not support the deterministic=True argument.
  • AMP training can lead to NaN outputs and may produce errors during bipartite graph matching.
Example
from ultralytics.models.rtdetr.train import RTDETRTrainer

args = dict(model='rtdetr-l.yaml', data='coco8.yaml', imgsz=640, epochs=3)
trainer = RTDETRTrainer(overrides=args)
trainer.train()
Source code in ultralytics/models/rtdetr/train.py
class RTDETRTrainer(DetectionTrainer):
    """
    A class extending the DetectionTrainer class for training based on an RT-DETR detection model.

    It overrides model construction, dataset construction, validator construction and batch
    preprocessing so that the RT-DETR specific model, dataset and validator classes are used.

    Notes:
        - F.grid_sample used in rt-detr does not support the `deterministic=True` argument.
        - AMP training can lead to NaN outputs and may produce errors during bipartite graph matching.

    Example:
        ```python
        from ultralytics.models.rtdetr.train import RTDETRTrainer

        args = dict(model='rtdetr-l.yaml', data='coco8.yaml', imgsz=640, epochs=3)
        trainer = RTDETRTrainer(overrides=args)
        trainer.train()
        ```
    """

    def get_model(self, cfg=None, weights=None, verbose=True):
        """
        Build an RT-DETR detection model and optionally load weights into it.

        Args:
            cfg: Model configuration forwarded unchanged to `RTDETRDetectionModel`. Defaults to None.
            weights: When truthy, passed to `model.load()` after construction. Defaults to None.
            verbose (bool): Verbose construction is requested only when this is True and `RANK == -1`.

        Returns:
            The constructed `RTDETRDetectionModel`, with `weights` loaded if they were given.
        """
        model = RTDETRDetectionModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1)
        if weights:
            model.load(weights)
        return model

    def build_dataset(self, img_path, mode='val', batch=None):
        """
        Build an RT-DETR dataset.

        Args:
            img_path (str): Path to the folder containing images.
            mode (str): `train` mode or `val` mode. Augmentation is enabled only when this equals `'train'`.
            batch (int, optional): Forwarded to the dataset as `batch_size`. Defaults to None.

        Returns:
            The constructed `RTDETRDataset`.
        """
        return RTDETRDataset(
            img_path=img_path,
            imgsz=self.args.imgsz,
            batch_size=batch,
            augment=mode == 'train',  # augment only when building the training set
            hyp=self.args,
            rect=False,  # rectangular batching is always disabled here
            cache=self.args.cache or None,  # falsy cache settings are normalised to None
            prefix=colorstr(f'{mode}: '),
            data=self.data)

    def get_validator(self):
        """
        Set the RT-DETR loss names and return a validator for the model.

        Side effect: assigns the three loss names to `self.loss_names`.

        Returns:
            An `RTDETRValidator` built from `self.test_loader`, `self.save_dir` and a copy of `self.args`.
        """
        self.loss_names = 'giou_loss', 'cls_loss', 'l1_loss'
        return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))

    def preprocess_batch(self, batch):
        """
        Preprocess a batch by delegating to the parent class implementation.

        Per-image ground-truth box and class lists used to be assembled here and then thrown away
        without being stored or returned; that dead work has been removed, so the result is exactly
        what the parent `preprocess_batch` returns.

        Args:
            batch: The batch to preprocess, forwarded unchanged to the parent implementation.

        Returns:
            The batch returned by the parent `preprocess_batch`.
        """
        return super().preprocess_batch(batch)

build_dataset(img_path, mode='val', batch=None)

Build RTDETR Dataset

Parameters:

Name Type Description Default
img_path str

Path to the folder containing images.

required
mode str

train mode or val mode; users can customize different augmentations for each mode.

'val'
batch int

Size of batches, this is for rect. Defaults to None.

None
Source code in ultralytics/models/rtdetr/train.py
def build_dataset(self, img_path, mode='val', batch=None):
    """
    Construct the RT-DETR dataset for the given image folder.

    Args:
        img_path (str): Path to the folder containing images.
        mode (str): `train` mode or `val` mode. Augmentation is switched on only for `'train'`.
        batch (int, optional): Forwarded to the dataset as `batch_size`. Defaults to None.

    Returns:
        The constructed `RTDETRDataset`.
    """
    # Keyword arguments are collected in the same order the dataset receives them.
    dataset_kwargs = {
        'img_path': img_path,
        'imgsz': self.args.imgsz,
        'batch_size': batch,
        'augment': mode == 'train',  # augment only the training set
        'hyp': self.args,
        'rect': False,  # rectangular batching is always off
        'cache': self.args.cache or None,  # falsy cache settings become None
        'prefix': colorstr(f'{mode}: '),
        'data': self.data,
    }
    return RTDETRDataset(**dataset_kwargs)

get_model(cfg=None, weights=None, verbose=True)

Return an RT-DETR detection model.

Source code in ultralytics/models/rtdetr/train.py
def get_model(self, cfg=None, weights=None, verbose=True):
    """
    Create an RT-DETR detection model, loading weights when they are supplied.

    Args:
        cfg: Model configuration forwarded unchanged to `RTDETRDetectionModel`. Defaults to None.
        weights: When truthy, passed to the model's `load()` method. Defaults to None.
        verbose (bool): Verbose construction is requested only when this is True and `RANK == -1`.

    Returns:
        The constructed `RTDETRDetectionModel`.
    """
    num_classes = self.data['nc']
    show_details = verbose and RANK == -1
    net = RTDETRDetectionModel(cfg, nc=num_classes, verbose=show_details)
    if not weights:
        return net
    net.load(weights)
    return net

get_validator()

Returns a DetectionValidator for RTDETR model validation.

Source code in ultralytics/models/rtdetr/train.py
def get_validator(self):
    """
    Record the RT-DETR loss names on the trainer and build its validator.

    Side effect: assigns the three loss names to `self.loss_names`.

    Returns:
        An `RTDETRValidator` built from `self.test_loader`, `self.save_dir` and a copy of `self.args`.
    """
    self.loss_names = ('giou_loss', 'cls_loss', 'l1_loss')
    validator = RTDETRValidator(
        self.test_loader,
        save_dir=self.save_dir,
        args=copy(self.args),  # the validator gets its own copy of the arguments
    )
    return validator

preprocess_batch(batch)

Preprocesses a batch of images by scaling and converting to float.

Source code in ultralytics/models/rtdetr/train.py
def preprocess_batch(self, batch):
    """
    Preprocess a batch by delegating to the parent class implementation.

    Per-image ground-truth box and class lists used to be assembled here and then thrown away
    without being stored or returned; that dead work has been removed, so the result is exactly
    what the parent `preprocess_batch` returns.

    Args:
        batch: The batch to preprocess, forwarded unchanged to the parent implementation.

    Returns:
        The batch returned by the parent `preprocess_batch`.
    """
    return super().preprocess_batch(batch)




Created 2023-07-16, Updated 2023-08-20
Authors: glenn-jocher (6), Laughing-q (1)