Bỏ để qua phần nội dung

Tài liệu tham khảo cho ultralytics/models/rtdetr/train.py

Ghi

Tệp này có sẵn tại https://github.com/ultralytics/ultralytics/blob/main/ultralytics/mô hình/rtdetr/train.py. Nếu bạn phát hiện ra một vấn đề, vui lòng giúp khắc phục nó bằng cách đóng góp Yêu cầu 🛠️ kéo. Cảm ơn bạn 🙏 !



ultralytics.models.rtdetr.train.RTDETRTrainer

Căn cứ: DetectionTrainer

Lớp huấn luyện viên cho RT-DETR mô hình được phát triển bởi Baidu để phát hiện đối tượng thời gian thực. Mở rộng DetectionTrainer lớp học cho YOLO để thích ứng với các tính năng và kiến trúc cụ thể của RT-DETR. Mô hình này tận dụng Tầm nhìn Transformers và có các khả năng như lựa chọn truy vấn nhận biết IoU và tốc độ suy luận thích ứng.

Ghi chú
  • F.grid_sample sử dụng trong RT-DETR không hỗ trợ deterministic=True lý lẽ.
  • Đào tạo AMP có thể dẫn đến đầu ra NaN và có thể tạo ra lỗi trong quá trình khớp đồ thị hai bên.
Ví dụ
from ultralytics.models.rtdetr.train import RTDETRTrainer

args = dict(model='rtdetr-l.yaml', data='coco8.yaml', imgsz=640, epochs=3)
trainer = RTDETRTrainer(overrides=args)
trainer.train()
Mã nguồn trong ultralytics/models/rtdetr/train.py
class RTDETRTrainer(DetectionTrainer):
    """
    Trainer class for the RT-DETR model developed by Baidu for real-time object detection. Extends the DetectionTrainer
    class for YOLO to adapt to the specific features and architecture of RT-DETR. This model leverages Vision
    Transformers and has capabilities like IoU-aware query selection and adaptable inference speed.

    Notes:
        - F.grid_sample used in RT-DETR does not support the `deterministic=True` argument.
        - AMP training can lead to NaN outputs and may produce errors during bipartite graph matching.

    Example:
        ```python
        from ultralytics.models.rtdetr.train import RTDETRTrainer

        args = dict(model='rtdetr-l.yaml', data='coco8.yaml', imgsz=640, epochs=3)
        trainer = RTDETRTrainer(overrides=args)
        trainer.train()
        ```
    """

    def get_model(self, cfg=None, weights=None, verbose=True):
        """
        Initialize and return an RT-DETR model for object detection tasks.

        Args:
            cfg (dict, optional): Model configuration. Defaults to None.
            weights (str, optional): Path to pre-trained model weights. Defaults to None.
            verbose (bool): Verbose logging if True. Defaults to True.

        Returns:
            (RTDETRDetectionModel): Initialized model.
        """
        model = RTDETRDetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
        if weights:
            model.load(weights)
        return model

    def build_dataset(self, img_path, mode="val", batch=None):
        """
        Build and return an RT-DETR dataset for training or validation.

        Args:
            img_path (str): Path to the folder containing images.
            mode (str): Dataset mode, either 'train' or 'val'.
            batch (int, optional): Batch size for rectangle training. Defaults to None.

        Returns:
            (RTDETRDataset): Dataset object for the specific mode.
        """
        return RTDETRDataset(
            img_path=img_path,
            imgsz=self.args.imgsz,
            batch_size=batch,
            augment=mode == "train",
            hyp=self.args,
            rect=False,
            cache=self.args.cache or None,
            prefix=colorstr(f"{mode}: "),
            data=self.data,
        )

    def get_validator(self):
        """
        Returns a DetectionValidator suitable for RT-DETR model validation.

        Returns:
            (RTDETRValidator): Validator object for model validation.
        """
        self.loss_names = "giou_loss", "cls_loss", "l1_loss"
        return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))

    def preprocess_batch(self, batch):
        """
        Preprocess a batch of images. Scales and converts the images to float format.

        Args:
            batch (dict): Dictionary containing a batch of images, bboxes, and labels.

        Returns:
            (dict): Preprocessed batch.
        """
        batch = super().preprocess_batch(batch)
        bs = len(batch["img"])
        batch_idx = batch["batch_idx"]
        gt_bbox, gt_class = [], []
        for i in range(bs):
            gt_bbox.append(batch["bboxes"][batch_idx == i].to(batch_idx.device))
            gt_class.append(batch["cls"][batch_idx == i].to(device=batch_idx.device, dtype=torch.long))
        return batch

build_dataset(img_path, mode='val', batch=None)

Xây dựng và trả về một RT-DETR tập dữ liệu để đào tạo hoặc xác nhận.

Thông số:

Tên Kiểu Sự miêu tả Mặc định
img_path str

Đường dẫn đến thư mục chứa hình ảnh.

bắt buộc
mode str

Chế độ tập dữ liệu, 'tàu' hoặc 'val'.

'val'
batch int

Kích thước hàng loạt để đào tạo hình chữ nhật. Mặc định là Không có.

None

Trở lại:

Kiểu Sự miêu tả
RTDETRDataset

Đối tượng tập dữ liệu cho chế độ cụ thể.

Mã nguồn trong ultralytics/models/rtdetr/train.py
def build_dataset(self, img_path, mode="val", batch=None):
    """
    Build and return an RT-DETR dataset for training or validation.

    Args:
        img_path (str): Path to the folder containing images.
        mode (str): Dataset mode, either 'train' or 'val'.
        batch (int, optional): Batch size for rectangle training. Defaults to None.

    Returns:
        (RTDETRDataset): Dataset object for the specific mode.
    """
    return RTDETRDataset(
        img_path=img_path,
        imgsz=self.args.imgsz,
        batch_size=batch,
        augment=mode == "train",
        hyp=self.args,
        rect=False,
        cache=self.args.cache or None,
        prefix=colorstr(f"{mode}: "),
        data=self.data,
    )

get_model(cfg=None, weights=None, verbose=True)

Khởi tạo và trả về một RT-DETR mô hình cho các nhiệm vụ phát hiện đối tượng.

Thông số:

Tên Kiểu Sự miêu tả Mặc định
cfg dict

Cấu hình mô hình. Mặc định là Không có.

None
weights str

Đường dẫn đến trọng lượng mô hình được đào tạo trước. Mặc định là Không có.

None
verbose bool

Ghi nhật ký chi tiết nếu True. Mặc định là True.

True

Trở lại:

Kiểu Sự miêu tả
RTDETRDetectionModel

Mô hình khởi tạo.

Mã nguồn trong ultralytics/models/rtdetr/train.py
def get_model(self, cfg=None, weights=None, verbose=True):
    """
    Initialize and return an RT-DETR model for object detection tasks.

    Args:
        cfg (dict, optional): Model configuration. Defaults to None.
        weights (str, optional): Path to pre-trained model weights. Defaults to None.
        verbose (bool): Verbose logging if True. Defaults to True.

    Returns:
        (RTDETRDetectionModel): Initialized model.
    """
    model = RTDETRDetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
    if weights:
        model.load(weights)
    return model

get_validator()

Trả về một DetectionValidator phù hợp với RT-DETR Xác thực mô hình.

Trở lại:

Kiểu Sự miêu tả
RTDETRValidator

Đối tượng xác thực để xác thực mô hình.

Mã nguồn trong ultralytics/models/rtdetr/train.py
def get_validator(self):
    """
    Returns a DetectionValidator suitable for RT-DETR model validation.

    Returns:
        (RTDETRValidator): Validator object for model validation.
    """
    self.loss_names = "giou_loss", "cls_loss", "l1_loss"
    return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))

preprocess_batch(batch)

Xử lý trước một loạt hình ảnh. Chia tỷ lệ và chuyển đổi hình ảnh sang định dạng nổi.

Thông số:

Tên Kiểu Sự miêu tả Mặc định
batch dict

Từ điển chứa một loạt hình ảnh, hộp thư và nhãn.

bắt buộc

Trở lại:

Kiểu Sự miêu tả
dict

Lô tiền xử lý.

Mã nguồn trong ultralytics/models/rtdetr/train.py
def preprocess_batch(self, batch):
    """
    Preprocess a batch of images. Scales and converts the images to float format.

    Args:
        batch (dict): Dictionary containing a batch of images, bboxes, and labels.

    Returns:
        (dict): Preprocessed batch.
    """
    batch = super().preprocess_batch(batch)
    bs = len(batch["img"])
    batch_idx = batch["batch_idx"]
    gt_bbox, gt_class = [], []
    for i in range(bs):
        gt_bbox.append(batch["bboxes"][batch_idx == i].to(batch_idx.device))
        gt_class.append(batch["cls"][batch_idx == i].to(device=batch_idx.device, dtype=torch.long))
    return batch





Created 2023-11-12, Updated 2024-06-02
Authors: glenn-jocher (5), Burhan-Q (1)