Reference for ultralytics/models/rtdetr/train.py

Note

This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/rtdetr/train.py. If you spot a problem please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!


ultralytics.models.rtdetr.train.RTDETRTrainer

RTDETRTrainer(
    cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None
)

Bases: DetectionTrainer

Trainer class for the RT-DETR model developed by Baidu for real-time object detection.

This class extends the DetectionTrainer class for YOLO to adapt to the specific features and architecture of RT-DETR. The model leverages Vision Transformers and has capabilities like IoU-aware query selection and adaptable inference speed.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| loss_names | tuple | Names of the loss components used for training. |
| data | dict | Dataset configuration containing class count and other parameters. |
| args | dict | Training arguments and hyperparameters. |
| save_dir | Path | Directory to save training results. |
| test_loader | DataLoader | DataLoader for validation/testing data. |

Methods:

| Name | Description |
| --- | --- |
| get_model | Initialize and return an RT-DETR model for object detection tasks. |
| build_dataset | Build and return an RT-DETR dataset for training or validation. |
| get_validator | Return an RTDETRValidator suitable for RT-DETR model validation. |

Notes
  • F.grid_sample used in RT-DETR does not support the deterministic=True argument.
  • AMP training can lead to NaN outputs and may produce errors during bipartite graph matching.

Examples:

>>> from ultralytics.models.rtdetr.train import RTDETRTrainer
>>> args = dict(model="rtdetr-l.yaml", data="coco8.yaml", imgsz=640, epochs=3)
>>> trainer = RTDETRTrainer(overrides=args)
>>> trainer.train()
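
The two notes above suggest switching off AMP and deterministic mode when they cause trouble. The following is a minimal sketch of how that might look, assuming the standard Ultralytics `amp` and `deterministic` training arguments are accepted as overrides. The helper names `rtdetr_overrides` and `run_training` are hypothetical, and whether these settings are actually needed depends on your hardware and dataset.

```python
from __future__ import annotations

from typing import Any


def rtdetr_overrides(**extra: Any) -> dict[str, Any]:
    """Build trainer overrides with AMP and deterministic mode switched off."""
    overrides: dict[str, Any] = {
        "model": "rtdetr-l.yaml",
        "data": "coco8.yaml",
        "imgsz": 640,
        "epochs": 3,
        "amp": False,  # assumption: sidesteps the NaN outputs mentioned in the notes
        "deterministic": False,  # assumption: F.grid_sample has no deterministic mode
    }
    overrides.update(extra)  # caller-supplied values take precedence
    return overrides


def run_training() -> None:
    """Train RT-DETR with the overrides above (requires ultralytics to be installed)."""
    from ultralytics.models.rtdetr.train import RTDETRTrainer

    trainer = RTDETRTrainer(overrides=rtdetr_overrides())
    trainer.train()
```

Calling `run_training()` starts an actual training run, so the import is kept inside the function and nothing is executed at import time.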
Source code in ultralytics/models/yolo/detect/train.py
def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
    """
    Initialize a DetectionTrainer object for training a YOLO object detection model.

    Args:
        cfg (dict, optional): Default configuration dictionary containing training parameters.
        overrides (dict, optional): Dictionary of parameter overrides for the default configuration.
        _callbacks (list, optional): List of callback functions to be executed during training.
    """
    super().__init__(cfg, overrides, _callbacks)

build_dataset

build_dataset(img_path: str, mode: str = 'val', batch: int | None = None)

Build and return an RT-DETR dataset for training or validation.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| img_path | str | Path to the folder containing images. | required |
| mode | str | Dataset mode, either 'train' or 'val'. | 'val' |
| batch | int | Batch size for rectangle training. | None |

Returns:

| Type | Description |
| --- | --- |
| RTDETRDataset | Dataset object for the specific mode. |

Source code in ultralytics/models/rtdetr/train.py
def build_dataset(self, img_path: str, mode: str = "val", batch: int | None = None):
    """
    Build and return an RT-DETR dataset for training or validation.

    Args:
        img_path (str): Path to the folder containing images.
        mode (str): Dataset mode, either 'train' or 'val'.
        batch (int, optional): Batch size for rectangle training.

    Returns:
        (RTDETRDataset): Dataset object for the specific mode.
    """
    return RTDETRDataset(
        img_path=img_path,
        imgsz=self.args.imgsz,
        batch_size=batch,
        augment=mode == "train",
        hyp=self.args,
        rect=False,
        cache=self.args.cache or None,
        single_cls=self.args.single_cls or False,
        prefix=colorstr(f"{mode}: "),
        classes=self.args.classes,
        data=self.data,
        fraction=self.args.fraction if mode == "train" else 1.0,
    )
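
Only three of the arguments passed to `RTDETRDataset` above depend on `mode` or are fixed by this method. The sketch below isolates that logic in plain Python so it can be run without Ultralytics installed. The function name `mode_dependent_kwargs` is hypothetical and is not part of the library.

```python
from __future__ import annotations

from typing import Any


def mode_dependent_kwargs(mode: str, fraction: float) -> dict[str, Any]:
    """Mirror the mode-dependent arguments passed to RTDETRDataset above."""
    return {
        "augment": mode == "train",  # augmentations are applied only while training
        "rect": False,  # rectangular batching is always disabled here
        "fraction": fraction if mode == "train" else 1.0,  # validation uses the full dataset
    }


train_kwargs = mode_dependent_kwargs("train", fraction=0.5)
val_kwargs = mode_dependent_kwargs("val", fraction=0.5)
print(train_kwargs)
print(val_kwargs)
```

In other words, a `fraction` below 1.0 shrinks only the training split, and the validation split is always used in full.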

get_model

get_model(
    cfg: dict | None = None, weights: str | None = None, verbose: bool = True
)

Initialize and return an RT-DETR model for object detection tasks.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| cfg | dict | Model configuration. | None |
| weights | str | Path to pre-trained model weights. | None |
| verbose | bool | Verbose logging if True. | True |

Returns:

| Type | Description |
| --- | --- |
| RTDETRDetectionModel | Initialized model. |

Source code in ultralytics/models/rtdetr/train.py
def get_model(self, cfg: dict | None = None, weights: str | None = None, verbose: bool = True):
    """
    Initialize and return an RT-DETR model for object detection tasks.

    Args:
        cfg (dict, optional): Model configuration.
        weights (str, optional): Path to pre-trained model weights.
        verbose (bool): Verbose logging if True.

    Returns:
        (RTDETRDetectionModel): Initialized model.
    """
    model = RTDETRDetectionModel(cfg, nc=self.data["nc"], ch=self.data["channels"], verbose=verbose and RANK == -1)
    if weights:
        model.load(weights)
    return model
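
The control flow of `get_model` can be traced with a small stand-in. This is a sketch only: `StubModel` and `get_model_sketch` are hypothetical names that replace `RTDETRDetectionModel` and the trainer method so the example runs without Ultralytics. It assumes, as is usual in Ultralytics, that `RANK` is -1 when training in a single process.

```python
from __future__ import annotations


class StubModel:
    """Hypothetical stand-in for RTDETRDetectionModel; it only records its inputs."""

    def __init__(self, cfg, nc: int, ch: int, verbose: bool):
        self.cfg, self.nc, self.ch, self.verbose = cfg, nc, ch, verbose
        self.loaded_from = None

    def load(self, weights: str) -> None:
        self.loaded_from = weights


def get_model_sketch(data: dict, rank: int, cfg=None, weights=None, verbose: bool = True) -> StubModel:
    """Follow the same steps as get_model, using the stub instead of the real model."""
    # Class count and input channels come from the dataset configuration.
    model = StubModel(cfg, nc=data["nc"], ch=data["channels"], verbose=verbose and rank == -1)
    if weights:  # None or an empty string skips weight loading
        model.load(weights)
    return model


data = {"nc": 80, "channels": 3}
single = get_model_sketch(data, rank=-1, weights="rtdetr-l.pt")
worker = get_model_sketch(data, rank=1)
print(single.verbose, single.loaded_from)
print(worker.verbose, worker.loaded_from)
```

The `verbose and rank == -1` expression keeps model summaries from being printed once per process when several workers build the model.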

get_validator

get_validator()

Return an RTDETRValidator suitable for RT-DETR model validation.

Source code in ultralytics/models/rtdetr/train.py
def get_validator(self):
    """Return an RTDETRValidator suitable for RT-DETR model validation."""
    self.loss_names = "giou_loss", "cls_loss", "l1_loss"
    return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
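
Two details of `get_validator` are easy to miss: it sets `loss_names` as a side effect, and it hands the validator a shallow copy of the arguments instead of the original object. The sketch below illustrates both with standard-library pieces only. It assumes the trainer's `args` behaves like a simple namespace, and `trainer_args` and `validator_args` are hypothetical names.

```python
from copy import copy
from types import SimpleNamespace

# Hypothetical stand-in for the trainer's args object.
trainer_args = SimpleNamespace(imgsz=640, conf=None)

validator_args = copy(trainer_args)  # shallow copy, as in get_validator
validator_args.conf = 0.001  # a change made on the validator side

# A bare comma-separated assignment creates a tuple, matching the documented type of loss_names.
loss_names = "giou_loss", "cls_loss", "l1_loss"

print(trainer_args.conf)
print(validator_args.conf)
print(type(loss_names).__name__)
```

Because the copy is shallow, reassigning an attribute on the validator's copy leaves the trainer's arguments untouched, while any mutable values would still be shared between the two.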




