Bases: DetectionTrainer
Trainer class for the RT-DETR model developed by Baidu for real-time object detection. Extends the DetectionTrainer
class for YOLO to adapt to the specific features and architecture of RT-DETR. This model leverages Vision
Transformers and has capabilities like IoU-aware query selection and adaptable inference speed.
Notes
- `F.grid_sample`, which RT-DETR uses, does not support the `deterministic=True` argument.
- AMP training can lead to NaN outputs and may produce errors during bipartite graph matching.
Example
from ultralytics.models.rtdetr.train import RTDETRTrainer

# Overrides: rtdetr-l.yaml model, coco8.yaml data, image size 640, 3 epochs
args = {"model": "rtdetr-l.yaml", "data": "coco8.yaml", "imgsz": 640, "epochs": 3}
trainer = RTDETRTrainer(overrides=args)
trainer.train()
Parameters:
| Name | Type | Description | Default |
| --- | --- | --- | --- |
| cfg | str | Path to a configuration file. Defaults to DEFAULT_CFG. | DEFAULT_CFG |
| overrides | dict | Configuration overrides. Defaults to None. | None |
Source code in ultralytics/engine/trainer.py
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
    """
    Set up the trainer: configuration, device, output directories, dataset handles and callbacks.

    Args:
        cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
        overrides (dict, optional): Configuration overrides. Defaults to None.
        _callbacks (optional): Callbacks to use. When falsy, the result of
            callbacks.get_default_callbacks() is used instead.
    """
    # Configuration, resume check and device selection
    self.args = get_cfg(cfg, overrides)
    self.check_resume(overrides)
    self.device = select_device(self.args.device, self.args.batch)
    self.validator = None
    self.metrics = None
    self.plots = {}
    # The seed is offset by the process rank
    init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)

    # Output directories
    self.save_dir = get_save_dir(self.args)
    self.args.name = self.save_dir.name  # keep the run name in sync for loggers
    self.wdir = self.save_dir / "weights"
    if RANK in {-1, 0}:
        # Only ranks -1 and 0 create the weights directory and write the run arguments
        self.wdir.mkdir(parents=True, exist_ok=True)
        self.args.save_dir = str(self.save_dir)
        yaml_save(self.save_dir / "args.yaml", vars(self.args))

    # Checkpoint paths
    self.last = self.wdir / "last.pt"
    self.best = self.wdir / "best.pt"

    # Schedule settings copied from the resolved arguments
    self.save_period = self.args.save_period
    self.batch_size = self.args.batch
    self.epochs = self.args.epochs or 100  # falls back to 100 when epochs is None/0 (e.g. timed training)
    self.start_epoch = 0
    if RANK == -1:
        print_args(vars(self.args))

    # Device-specific adjustment: no dataloader workers on cpu/mps
    if self.device.type in {"cpu", "mps"}:
        self.args.workers = 0

    # Model reference and datasets
    self.model = check_model_file_from_stem(self.args.model)  # adds a suffix, i.e. yolov8n -> yolov8n.pt
    with torch_distributed_zero_first(LOCAL_RANK):  # avoid auto-downloading the dataset multiple times
        self.trainset, self.testset = self.get_dataset()
    self.ema = None

    # Optimisation helpers, populated later
    self.lf = None
    self.scheduler = None

    # Per-epoch metrics
    self.best_fitness = None
    self.fitness = None
    self.loss = None
    self.tloss = None
    self.loss_names = ["Loss"]
    self.csv = self.save_dir / "results.csv"
    self.plot_idx = [0, 1, 2]

    # HUB
    self.hub_session = None

    # Callbacks
    if _callbacks:
        self.callbacks = _callbacks
    else:
        self.callbacks = callbacks.get_default_callbacks()
    if RANK in {-1, 0}:
        callbacks.add_integration_callbacks(self)
|
build_dataset
build_dataset(img_path, mode='val', batch=None)
Build and return an RT-DETR dataset for training or validation.
Parameters:
| Name | Type | Description | Default |
| --- | --- | --- | --- |
| img_path | str | Path to the folder containing images. | required |
| mode | str | Dataset mode, either 'train' or 'val'. | 'val' |
| batch | int | Batch size for rectangle training. Defaults to None. | None |
Returns:
| Type | Description |
| --- | --- |
| RTDETRDataset | Dataset object for the specific mode. |
Source code in ultralytics/models/rtdetr/train.py
def build_dataset(self, img_path, mode="val", batch=None):
    """
    Construct the RT-DETR dataset used for training or validation.

    Augmentation and the configured data fraction apply only when mode is 'train';
    any other mode uses the full data without augmentation. Rectangular batching
    is always disabled (rect=False).

    Args:
        img_path (str): Path to the folder containing images.
        mode (str): Dataset mode, either 'train' or 'val'.
        batch (int, optional): Batch size forwarded to the dataset. Defaults to None.

    Returns:
        (RTDETRDataset): Dataset object for the specific mode.
    """
    cfg = self.args
    training = mode == "train"
    return RTDETRDataset(
        img_path=img_path,
        imgsz=cfg.imgsz,
        batch_size=batch,
        augment=training,
        hyp=cfg,
        rect=False,
        cache=cfg.cache or None,
        single_cls=cfg.single_cls or False,
        prefix=colorstr(f"{mode}: "),
        classes=cfg.classes,
        data=self.data,
        fraction=cfg.fraction if training else 1.0,
    )
|
get_model
get_model(cfg=None, weights=None, verbose=True)
Initialize and return an RT-DETR model for object detection tasks.
Parameters:
| Name | Type | Description | Default |
| --- | --- | --- | --- |
| cfg | dict | Model configuration. Defaults to None. | None |
| weights | str | Path to pre-trained model weights. Defaults to None. | None |
| verbose | bool | Verbose logging if True. Defaults to True. | True |
Returns:
| Type | Description |
| --- | --- |
| RTDETRDetectionModel | Initialized model. |
Source code in ultralytics/models/rtdetr/train.py
def get_model(self, cfg=None, weights=None, verbose=True):
    """
    Create the RT-DETR detection model and optionally load weights into it.

    Args:
        cfg (dict, optional): Model configuration. Defaults to None.
        weights (str, optional): Path to pre-trained model weights. Defaults to None.
        verbose (bool): Verbose logging if True. Defaults to True.

    Returns:
        (RTDETRDetectionModel): Initialized model.
    """
    # Verbose output is only enabled when RANK is -1
    log_details = verbose and RANK == -1
    detector = RTDETRDetectionModel(cfg, nc=self.data["nc"], verbose=log_details)
    if not weights:
        return detector
    detector.load(weights)
    return detector
|
get_validator
Returns a DetectionValidator suitable for RT-DETR model validation.
Returns:
| Type | Description |
| --- | --- |
| RTDETRValidator | Validator object for model validation. |
Source code in ultralytics/models/rtdetr/train.py
def get_validator(self):
    """
    Build the validator used to evaluate the RT-DETR model.

    Also sets self.loss_names to the three RT-DETR loss components as a side effect.

    Returns:
        (RTDETRValidator): Validator object for model validation.
    """
    self.loss_names = ("giou_loss", "cls_loss", "l1_loss")
    # The validator receives a copy of the trainer arguments, not the same object
    return RTDETRValidator(
        self.test_loader,
        save_dir=self.save_dir,
        args=copy(self.args),
    )
|
preprocess_batch
Preprocess a batch of images. Scales and converts the images to float format.
Parameters:
| Name | Type | Description | Default |
| --- | --- | --- | --- |
| batch | dict | Dictionary containing a batch of images, bboxes, and labels. | required |
Returns:
| Type | Description |
| --- | --- |
| dict | Preprocessed batch. |
Source code in ultralytics/models/rtdetr/train.py
def preprocess_batch(self, batch):
    """
    Preprocess a batch of images by delegating to the parent trainer's preprocess_batch.

    No RT-DETR-specific changes are made to the batch. Per-image grouping of the
    ground-truth boxes and classes is intentionally not performed here, because its
    result was never stored or returned; it only added per-image mask indexing and
    tensor copies on every batch.

    Args:
        batch (dict): Dictionary containing a batch of images, bboxes, and labels.

    Returns:
        (dict): Preprocessed batch, exactly as returned by the parent class.
    """
    return super().preprocess_batch(batch)
|