
Reference for ultralytics/models/rtdetr/train.py

This page is sourced from https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/rtdetr/train.py.


class ultralytics.models.rtdetr.train.RTDETRTrainer

RTDETRTrainer()

Bases: DetectionTrainer

Trainer class for the RT-DETR model developed by Baidu for real-time object detection.

This class extends the DetectionTrainer class for YOLO to adapt to the specific features and architecture of RT-DETR. The model leverages Vision Transformers and has capabilities like IoU-aware query selection and adaptable inference speed.

Attributes

| Name | Type | Description |
| --- | --- | --- |
| loss_names | tuple | Names of the loss components used for training. |
| data | dict | Dataset configuration containing class count and other parameters. |
| args | dict | Training arguments and hyperparameters. |
| save_dir | Path | Directory to save training results. |
| test_loader | DataLoader | DataLoader for validation/testing data. |

Methods

| Name | Description |
| --- | --- |
| build_dataset | Build and return an RT-DETR dataset for training or validation. |
| get_model | Initialize and return an RT-DETR model for object detection tasks. |
| get_validator | Return a DetectionValidator suitable for RT-DETR model validation. |

Examples

>>> from ultralytics.models.rtdetr.train import RTDETRTrainer
>>> args = dict(model="rtdetr-l.yaml", data="coco8.yaml", imgsz=640, epochs=3)
>>> trainer = RTDETRTrainer(overrides=args)
>>> trainer.train()

Notes

  • F.grid_sample used in RT-DETR does not support the deterministic=True argument.
  • AMP training can lead to NaN outputs and may produce errors during bipartite graph matching.
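
The two notes above suggest a cautious training configuration. As a minimal sketch, assuming the standard Ultralytics training arguments `deterministic` and `amp` are passed through `overrides` like any other setting, the helper below builds an override dict with both disabled. The names `rtdetr_overrides` and `train_rtdetr` are hypothetical and not part of the library, and whether either option needs to be disabled depends on your setup.

```python
def rtdetr_overrides(model="rtdetr-l.yaml", data="coco8.yaml", **extra):
    """Build a training-override dict that sidesteps the two caveats in the Notes.

    Hypothetical helper for illustration; not part of ultralytics.
    """
    overrides = dict(
        model=model,
        data=data,
        imgsz=640,
        epochs=3,
        deterministic=False,  # assumption: avoids requesting deterministic ops for F.grid_sample
        amp=False,  # assumption: full precision to avoid NaNs during bipartite matching
    )
    overrides.update(extra)  # caller-supplied values take precedence
    return overrides


def train_rtdetr(overrides):
    """Run training with the given overrides (requires ultralytics to be installed)."""
    # Imported lazily so that building the dict above does not need ultralytics.
    from ultralytics.models.rtdetr.train import RTDETRTrainer

    trainer = RTDETRTrainer(overrides=overrides)
    trainer.train()
    return trainer


overrides = rtdetr_overrides(epochs=1)
print(overrides)
```

Calling `train_rtdetr(overrides)` would then start training in the same way as the example above, with the two settings switched off.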


method ultralytics.models.rtdetr.train.RTDETRTrainer.build_dataset

def build_dataset(self, img_path: str, mode: str = "val", batch: int | None = None)

Build and return an RT-DETR dataset for training or validation.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| img_path | str | Path to the folder containing images. | required |
| mode | str | Dataset mode, either 'train' or 'val'. | "val" |
| batch | int, optional | Batch size for rectangle training. | None |

Returns

| Type | Description |
| --- | --- |
| RTDETRDataset | Dataset object for the specific mode. |

Source code in ultralytics/models/rtdetr/train.py
def build_dataset(self, img_path: str, mode: str = "val", batch: int | None = None):
    """Build and return an RT-DETR dataset for training or validation.

    Args:
        img_path (str): Path to the folder containing images.
        mode (str): Dataset mode, either 'train' or 'val'.
        batch (int, optional): Batch size for rectangle training.

    Returns:
        (RTDETRDataset): Dataset object for the specific mode.
    """
    return RTDETRDataset(
        img_path=img_path,
        imgsz=self.args.imgsz,
        batch_size=batch,
        augment=mode == "train",
        hyp=self.args,
        rect=False,
        cache=self.args.cache or None,
        single_cls=self.args.single_cls or False,
        prefix=colorstr(f"{mode}: "),
        classes=self.args.classes,
        data=self.data,
        fraction=self.args.fraction if mode == "train" else 1.0,
    )
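
To make the mode-dependent arguments in `build_dataset` easier to see, here is a small standalone sketch that reproduces only the argument resolution, not the dataset itself. The function `resolve_dataset_kwargs` is hypothetical, and a `SimpleNamespace` stands in for `self.args`; the real method passes further arguments (`hyp`, `prefix`, `classes`, `data`) that are omitted here.

```python
from types import SimpleNamespace


def resolve_dataset_kwargs(args, mode="val", batch=None):
    """Mirror how build_dataset derives its mode-dependent keyword arguments.

    Hypothetical helper for illustration; it builds a plain dict, not a dataset.
    """
    return dict(
        imgsz=args.imgsz,
        batch_size=batch,
        augment=mode == "train",  # augmentation only while training
        rect=False,  # rectangular batching is always off
        cache=args.cache or None,  # falsy cache settings collapse to None
        single_cls=args.single_cls or False,
        fraction=args.fraction if mode == "train" else 1.0,  # validation uses all data
    )


# Stand-in for the trainer's args object (values chosen for illustration).
args = SimpleNamespace(imgsz=640, cache=False, single_cls=False, fraction=0.5)

train_kwargs = resolve_dataset_kwargs(args, mode="train", batch=16)
val_kwargs = resolve_dataset_kwargs(args, mode="val", batch=16)
print(train_kwargs)
print(val_kwargs)
```

In this sketch only `augment` and `fraction` differ between the two modes; `rect` stays `False` in both, as in the method above.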


method ultralytics.models.rtdetr.train.RTDETRTrainer.get_model

def get_model(self, cfg: dict | None = None, weights: str | None = None, verbose: bool = True)

Initialize and return an RT-DETR model for object detection tasks.

Args

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| cfg | dict, optional | Model configuration. | None |
| weights | str, optional | Path to pre-trained model weights. | None |
| verbose | bool | Verbose logging if True. | True |

Returns

| Type | Description |
| --- | --- |
| RTDETRDetectionModel | Initialized model. |

Source code in ultralytics/models/rtdetr/train.py
def get_model(self, cfg: dict | None = None, weights: str | None = None, verbose: bool = True):
    """Initialize and return an RT-DETR model for object detection tasks.

    Args:
        cfg (dict, optional): Model configuration.
        weights (str, optional): Path to pre-trained model weights.
        verbose (bool): Verbose logging if True.

    Returns:
        (RTDETRDetectionModel): Initialized model.
    """
    model = RTDETRDetectionModel(cfg, nc=self.data["nc"], ch=self.data["channels"], verbose=verbose and RANK == -1)
    if weights:
        model.load(weights)
    return model
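
The expression `verbose and RANK == -1` means the model is built verbosely only when the process is not part of a distributed run. The sketch below illustrates that gating in isolation. It assumes `RANK` is derived from the `RANK` environment variable with a default of `-1`; the helpers `read_rank` and `model_verbose` are hypothetical names used only for this example.

```python
import os


def read_rank(env=None):
    """Return the process rank, or -1 when no RANK variable is set.

    Assumption: the rank comes from the RANK environment variable.
    """
    env = os.environ if env is None else env
    return int(env.get("RANK", -1))


def model_verbose(verbose, rank):
    """Mirror the `verbose and RANK == -1` gating used in get_model."""
    return verbose and rank == -1


# Single-process run: no RANK variable, so verbose output is allowed.
print(model_verbose(True, read_rank({})))
# Distributed worker with RANK=0: verbose output is suppressed.
print(model_verbose(True, read_rank({"RANK": "0"})))
```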


method ultralytics.models.rtdetr.train.RTDETRTrainer.get_validator

def get_validator(self)

Return an RTDETRValidator for RT-DETR model validation. As a side effect, the method sets self.loss_names to ("giou_loss", "cls_loss", "l1_loss").

Source code in ultralytics/models/rtdetr/train.py
def get_validator(self):
    """Return a DetectionValidator suitable for RT-DETR model validation."""
    self.loss_names = "giou_loss", "cls_loss", "l1_loss"
    return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
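
The assignment to `self.loss_names` creates a three-element tuple (the parentheses are optional in Python). As a rough illustration of how such names can be paired with per-component loss values for logging, the sketch below uses a hypothetical `label_losses` helper. The `prefix/name` key format and the rounding to five decimals are assumptions made for this example, not behaviour confirmed by this page.

```python
# Same tuple as assigned in get_validator; a bare comma-separated list is a tuple.
LOSS_NAMES = "giou_loss", "cls_loss", "l1_loss"


def label_losses(values, names=LOSS_NAMES, prefix="train"):
    """Pair each loss value with its name under a common prefix.

    Hypothetical helper for illustration; the key format is an assumption.
    """
    if len(values) != len(names):
        raise ValueError(f"expected {len(names)} loss values, got {len(values)}")
    return {f"{prefix}/{name}": round(float(value), 5) for name, value in zip(names, values)}


labeled = label_losses([0.41237, 1.05, 0.3])
print(labeled)
```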





📅 Created 2 years ago ✏️ Updated 2 days ago
glenn-jocher, jk4e, Burhan-Q