Reference for ultralytics/models/rtdetr/train.py
Note
This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/rtdetr/train.py. If you spot a problem, please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!
ultralytics.models.rtdetr.train.RTDETRTrainer
Bases: DetectionTrainer
Trainer class for the RT-DETR model developed by Baidu for real-time object detection. Extends the DetectionTrainer class for YOLO to adapt to the specific features and architecture of RT-DETR. This model leverages Vision Transformers and has capabilities like IoU-aware query selection and adaptable inference speed.
Notes
- F.grid_sample used in RT-DETR does not support the deterministic=True argument.
- AMP training can lead to NaN outputs and may produce errors during bipartite graph matching.
Example
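A minimal usage sketch, assuming the trainer accepts an overrides dictionary as other Ultralytics trainers do. The helper names build_overrides and train_rtdetr are hypothetical, and the model and dataset file names are illustrative values. Setting deterministic=False and amp=False is one possible way to act on the notes above, not a documented requirement.

```python
def build_overrides(model="rtdetr-l.yaml", data="coco8.yaml", imgsz=640, epochs=3):
    """Collect trainer overrides in a plain dict (all values are illustrative)."""
    return dict(
        model=model,
        data=data,
        imgsz=imgsz,
        epochs=epochs,
        deterministic=False,  # assumption: avoid the unsupported deterministic mode
        amp=False,  # assumption: sidestep the NaN risk mentioned for AMP training
    )


def train_rtdetr(overrides):
    """Train RT-DETR with the given overrides; requires the ultralytics package."""
    # Imported lazily so build_overrides can be used without ultralytics installed.
    from ultralytics.models.rtdetr.train import RTDETRTrainer

    trainer = RTDETRTrainer(overrides=overrides)
    trainer.train()
    return trainer


overrides = build_overrides()
# train_rtdetr(overrides)  # uncomment to start training
```

Keeping the overrides in a separate helper makes it easy to inspect or adjust the settings before any training starts.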
Source code in ultralytics/models/rtdetr/train.py
build_dataset(img_path, mode='val', batch=None)
Build and return an RT-DETR dataset for training or validation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
img_path | str | Path to the folder containing images. | required |
mode | str | Dataset mode, either 'train' or 'val'. | 'val' |
batch | int | Batch size for rectangle training. Defaults to None. | None |
Returns:
Type | Description |
---|---|
RTDETRDataset | Dataset object for the specific mode. |
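As a hedged illustration of the call signature above, the sketch below builds one dataset per split from any trainer-like object that exposes build_dataset(img_path, mode, batch). The helper build_split_datasets, the StubTrainer class, and the image paths are hypothetical and exist only so the snippet runs without ultralytics installed; with the real library, an RTDETRTrainer instance would take the stub's place.

```python
def build_split_datasets(trainer, train_path, val_path, batch=None):
    """Build a dataset per split using the documented build_dataset signature."""
    return {
        "train": trainer.build_dataset(train_path, mode="train", batch=batch),
        "val": trainer.build_dataset(val_path, mode="val", batch=batch),
    }


class StubTrainer:
    """Hypothetical stand-in for RTDETRTrainer that records the call arguments."""

    def build_dataset(self, img_path, mode="val", batch=None):
        if mode not in ("train", "val"):
            raise ValueError(f"mode must be 'train' or 'val', got {mode!r}")
        return {"img_path": img_path, "mode": mode, "batch": batch}


datasets = build_split_datasets(StubTrainer(), "images/train", "images/val", batch=16)
```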
get_model(cfg=None, weights=None, verbose=True)
Initialize and return an RT-DETR model for object detection tasks.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
cfg | dict | Model configuration. Defaults to None. | None |
weights | str | Path to pre-trained model weights. Defaults to None. | None |
verbose | bool | Verbose logging if True. Defaults to True. | True |
Returns:
Type | Description |
---|---|
RTDETRDetectionModel | Initialized model. |
get_validator()
Return an RTDETRValidator suitable for RT-DETR model validation.
Returns:
Type | Description |
---|---|
RTDETRValidator | Validator object for model validation. |
preprocess_batch(batch)
Preprocess a batch of images. Scales and converts the images to float format.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
batch | dict | Dictionary containing a batch of images, bboxes, and labels. | required |
Returns:
Type | Description |
---|---|
dict | Preprocessed batch. |
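The scaling step described for preprocess_batch can be sketched in plain Python. This is not the library implementation, which operates on tensors; it assumes that scaling means dividing 8-bit pixel values by 255, and that the batch stores images under an "img" key. Both are assumptions made for illustration, as are the function names scale_images and preprocess_batch_sketch.

```python
def scale_images(images, max_value=255.0):
    """Convert nested lists of integer pixel values to floats in [0, 1]."""
    if isinstance(images, list):
        return [scale_images(item, max_value) for item in images]
    return float(images) / max_value


def preprocess_batch_sketch(batch):
    """Return a copy of the batch whose 'img' entry is scaled to floats."""
    out = dict(batch)  # shallow copy so the caller's dict is left untouched
    out["img"] = scale_images(batch["img"])  # assumed key name: "img"
    return out


batch = {
    "img": [[[0, 255], [51, 102]]],  # one tiny single-channel "image"
    "bboxes": [[0.5, 0.5, 0.2, 0.2]],
    "cls": [1],
}
processed = preprocess_batch_sketch(batch)
```

Only the image entry changes; boxes and labels pass through unmodified in this sketch.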