Reference for ultralytics/models/rtdetr/train.py
Note
This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/rtdetr/train.py. If you spot a problem please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!
ultralytics.models.rtdetr.train.RTDETRTrainer
RTDETRTrainer(
cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None
)
Bases: DetectionTrainer
Trainer class for the RT-DETR model developed by Baidu for real-time object detection.
This class extends the YOLO DetectionTrainer to accommodate the specific features and architecture of RT-DETR. The model is built on Vision Transformers and offers IoU-aware query selection and adaptable inference speed.
Attributes:
Name | Type | Description |
---|---|---|
loss_names | tuple | Names of the loss components used for training. |
data | dict | Dataset configuration containing class count and other parameters. |
args | dict | Training arguments and hyperparameters. |
save_dir | Path | Directory to save training results. |
test_loader | DataLoader | DataLoader for validation/testing data. |
Methods:
Name | Description |
---|---|
get_model | Initialize and return an RT-DETR model for object detection tasks. |
build_dataset | Build and return an RT-DETR dataset for training or validation. |
get_validator | Return a DetectionValidator suitable for RT-DETR model validation. |
Notes
- F.grid_sample used in RT-DETR does not support the deterministic=True argument.
- AMP training can lead to NaN outputs and may produce errors during bipartite graph matching.
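The AMP note concerns the bipartite matching step that pairs predicted queries with ground-truth boxes. The sketch below is a hypothetical, brute-force stand-in, not the library's matcher: `match_queries` and its cost layout (rows are targets, columns are queries) are assumptions for illustration. It rejects non-finite costs up front, because a NaN compares False against every number and would otherwise leave the "minimum-cost" assignment undefined.

```python
from __future__ import annotations

import itertools
import math


def match_queries(cost: list[list[float]]) -> list[tuple[int, int]]:
    """Brute-force minimum-cost bipartite matching (illustration only).

    Rows are ground-truth boxes, columns are predicted queries; each row is
    assigned a distinct column so that the summed cost is minimal.
    """
    if not cost:
        return []
    # NaN/inf costs (e.g. from NaN predictions) make the minimum meaningless.
    if any(not math.isfinite(c) for row in cost for c in row):
        raise ValueError("cost matrix contains NaN or inf entries")
    n_rows, n_cols = len(cost), len(cost[0])
    if n_rows > n_cols:
        raise ValueError("need at least as many queries as targets")
    best, best_cols = math.inf, None
    # Try every ordered choice of distinct columns for the rows (tiny inputs only).
    for cols in itertools.permutations(range(n_cols), n_rows):
        total = sum(cost[r][c] for r, c in enumerate(cols))
        if total < best:
            best, best_cols = total, cols
    return list(enumerate(best_cols))
```

Brute force is exponential in the number of queries; a practical matcher would use a polynomial-time assignment algorithm such as the Hungarian method, but the finiteness check matters either way.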
Examples:
>>> from ultralytics.models.rtdetr.train import RTDETRTrainer
>>> args = dict(model="rtdetr-l.yaml", data="coco8.yaml", imgsz=640, epochs=3)
>>> trainer = RTDETRTrainer(overrides=args)
>>> trainer.train()
Source code in ultralytics/models/yolo/detect/train.py
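Given the notes above, one option is to turn off deterministic mode and, by default, AMP when assembling the `overrides` dict. `rtdetr_safe_overrides` is a hypothetical helper, not part of Ultralytics; it only builds a plain dict and assumes the standard `deterministic` and `amp` training arguments.

```python
from __future__ import annotations

from typing import Any


def rtdetr_safe_overrides(**overrides: Any) -> dict[str, Any]:
    """Return trainer overrides with the settings flagged in the notes disabled."""
    args = dict(overrides)
    # F.grid_sample in RT-DETR does not support deterministic=True, so force it off.
    args["deterministic"] = False
    # AMP can produce NaN outputs; default it off but let callers opt back in.
    args.setdefault("amp", False)
    return args
```

The returned dict can be passed as `RTDETRTrainer(overrides=...)` in the same way as `args` in the example. `deterministic` is forced off because the note says it is unsupported, while `amp` is only defaulted off, so a caller who wants mixed precision can still request it explicitly.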
build_dataset
build_dataset(img_path: str, mode: str = 'val', batch: int | None = None)
Build and return an RT-DETR dataset for training or validation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
img_path | str | Path to the folder containing images. | required |
mode | str | Dataset mode, either 'train' or 'val'. | 'val' |
batch | int | Batch size for rectangular training. | None |
Returns:
Type | Description |
---|---|
RTDETRDataset | Dataset object for the specific mode. |
Source code in ultralytics/models/rtdetr/train.py
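To make the `mode` parameter concrete, the sketch below shows one way a trainer could translate these arguments into dataset options. `dataset_options`, its `imgsz` parameter, and its returned keys are hypothetical; enabling augmentation only for `mode='train'` is an assumption based on common training practice, not a statement of exactly what `RTDETRDataset` receives.

```python
from __future__ import annotations


def dataset_options(img_path: str, mode: str = "val", batch: int | None = None, imgsz: int = 640) -> dict:
    """Translate build_dataset-style arguments into dataset options (hypothetical sketch)."""
    if mode not in ("train", "val"):
        raise ValueError(f"mode must be 'train' or 'val', got {mode!r}")
    return {
        "img_path": img_path,
        "imgsz": imgsz,  # assumption: a trainer would take this from its training args
        "batch_size": batch,  # forwarded unchanged; may be None
        "augment": mode == "train",  # assumption: only training data is augmented
    }
```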
get_model
get_model(
cfg: dict | None = None, weights: str | None = None, verbose: bool = True
)
Initialize and return an RT-DETR model for object detection tasks.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
cfg | dict | Model configuration. | None |
weights | str | Path to pre-trained model weights. | None |
verbose | bool | Verbose logging if True. | True |
Returns:
Type | Description |
---|---|
RTDETRDetectionModel | Initialized model. |
Source code in ultralytics/models/rtdetr/train.py
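The parameters above suggest a simple flow: build the model from `cfg` with the class count taken from the trainer's `data` attribute, then load `weights` only if a path is given. The sketch uses a hypothetical `StubDetectionModel` in place of `RTDETRDetectionModel` so it runs without Ultralytics installed; the stub's `nc` keyword and `load` method, and the `get_model_sketch` function, are assumptions for illustration.

```python
from __future__ import annotations


class StubDetectionModel:
    """Hypothetical stand-in for RTDETRDetectionModel; it only records its inputs."""

    def __init__(self, cfg: dict | None, nc: int, verbose: bool = True):
        self.cfg, self.nc, self.verbose = cfg, nc, verbose
        self.loaded_weights: str | None = None

    def load(self, weights: str) -> None:
        # A real model would read the checkpoint here; the stub just remembers the path.
        self.loaded_weights = weights


def get_model_sketch(data: dict, cfg: dict | None = None, weights: str | None = None, verbose: bool = True):
    # The class count comes from the dataset configuration (the trainer's `data`).
    model = StubDetectionModel(cfg, nc=data["nc"], verbose=verbose)
    # Pre-trained weights are optional; load them only when a path is given.
    if weights:
        model.load(weights)
    return model
```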
get_validator
get_validator()
Return a DetectionValidator suitable for RT-DETR model validation.
Source code in ultralytics/models/rtdetr/train.py
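`get_validator` ties together several attributes listed in the class table (`loss_names`, `test_loader`, `save_dir`, `args`). The sketch below illustrates that wiring with hypothetical stand-ins (`TrainerSketch`, `StubValidator`); the three loss names are assumed for illustration, and copying `args` is shown as a design choice that keeps validator-side changes from altering the training configuration.

```python
from __future__ import annotations

from copy import copy
from dataclasses import dataclass
from pathlib import Path
from types import SimpleNamespace


@dataclass
class StubValidator:
    """Hypothetical stand-in for a DetectionValidator subclass."""

    dataloader: object
    save_dir: Path
    args: SimpleNamespace


class TrainerSketch:
    def __init__(self, test_loader, save_dir: Path, args: SimpleNamespace):
        self.test_loader, self.save_dir, self.args = test_loader, save_dir, args
        self.loss_names: tuple[str, ...] = ()

    def get_validator(self) -> StubValidator:
        # Assumed RT-DETR loss components; they label the logged training losses.
        self.loss_names = ("giou_loss", "cls_loss", "l1_loss")
        # Shallow-copy args so validator-side changes do not leak back into training.
        return StubValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
```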