انتقل إلى المحتوى

مرجع ل ultralytics/models/rtdetr/train.py

ملاحظه

هذا الملف متاح في https://github.com/ultralytics/ultralytics/ نقطة / الرئيسية /ultralytics/نماذج/rtdetr/train.py. إذا اكتشفت مشكلة ، فيرجى المساعدة في إصلاحها من خلال المساهمة في طلب 🛠️ سحب. شكرا لك 🙏!



ultralytics.models.rtdetr.train.RTDETRTrainer

قواعد: DetectionTrainer

فئة المدرب ل RT-DETR نموذج طورته بايدو للكشف عن الأشياء في الوقت الفعلي. يوسع المدرب الكشف فئة ل YOLO للتكيف مع الميزات والبنية المحددة ل RT-DETR. هذا النموذج يعزز الرؤية المحولات ولديه إمكانات مثل اختيار الاستعلام المدرك لإنترنت الأشياء وسرعة الاستدلال القابلة للتكيف.

تلاحظ
  • F.grid_sample المستخدمة في RT-DETR لا يدعم deterministic=True جدال.
  • يمكن أن يؤدي تدريب AMP إلى مخرجات NaN وقد ينتج عنه أخطاء أثناء مطابقة الرسم البياني الثنائي.
مثل
from ultralytics.models.rtdetr.train import RTDETRTrainer

args = dict(model='rtdetr-l.yaml', data='coco8.yaml', imgsz=640, epochs=3)
trainer = RTDETRTrainer(overrides=args)
trainer.train()
شفرة المصدر في ultralytics/models/rtdetr/train.py
class RTDETRTrainer(DetectionTrainer):
    """
    Trainer class for the RT-DETR model developed by Baidu for real-time object detection. Extends the DetectionTrainer
    class for YOLO to adapt to the specific features and architecture of RT-DETR. This model leverages Vision
    Transformers and has capabilities like IoU-aware query selection and adaptable inference speed.

    Notes:
        - F.grid_sample used in RT-DETR does not support the `deterministic=True` argument.
        - AMP training can lead to NaN outputs and may produce errors during bipartite graph matching.

    Example:
        ```python
        from ultralytics.models.rtdetr.train import RTDETRTrainer

        args = dict(model='rtdetr-l.yaml', data='coco8.yaml', imgsz=640, epochs=3)
        trainer = RTDETRTrainer(overrides=args)
        trainer.train()
        ```
    """

    def get_model(self, cfg=None, weights=None, verbose=True):
        """
        Initialize and return an RT-DETR model for object detection tasks.

        Args:
            cfg (dict, optional): Model configuration. Defaults to None.
            weights (str, optional): Path to pre-trained model weights. Defaults to None.
            verbose (bool): Verbose logging if True. Defaults to True.

        Returns:
            (RTDETRDetectionModel): Initialized model.
        """
        model = RTDETRDetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
        if weights:
            model.load(weights)
        return model

    def build_dataset(self, img_path, mode="val", batch=None):
        """
        Build and return an RT-DETR dataset for training or validation.

        Args:
            img_path (str): Path to the folder containing images.
            mode (str): Dataset mode, either 'train' or 'val'.
            batch (int, optional): Batch size for rectangle training. Defaults to None.

        Returns:
            (RTDETRDataset): Dataset object for the specific mode.
        """
        return RTDETRDataset(
            img_path=img_path,
            imgsz=self.args.imgsz,
            batch_size=batch,
            augment=mode == "train",
            hyp=self.args,
            rect=False,
            cache=self.args.cache or None,
            prefix=colorstr(f"{mode}: "),
            data=self.data,
        )

    def get_validator(self):
        """
        Returns a DetectionValidator suitable for RT-DETR model validation.

        Returns:
            (RTDETRValidator): Validator object for model validation.
        """
        self.loss_names = "giou_loss", "cls_loss", "l1_loss"
        return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))

    def preprocess_batch(self, batch):
        """
        Preprocess a batch of images. Scales and converts the images to float format.

        Args:
            batch (dict): Dictionary containing a batch of images, bboxes, and labels.

        Returns:
            (dict): Preprocessed batch.
        """
        batch = super().preprocess_batch(batch)
        bs = len(batch["img"])
        batch_idx = batch["batch_idx"]
        gt_bbox, gt_class = [], []
        for i in range(bs):
            gt_bbox.append(batch["bboxes"][batch_idx == i].to(batch_idx.device))
            gt_class.append(batch["cls"][batch_idx == i].to(device=batch_idx.device, dtype=torch.long))
        return batch

build_dataset(img_path, mode='val', batch=None)

بناء وإرجاع RT-DETR مجموعة بيانات للتدريب أو التحقق من الصحة.

البارامترات:

اسم نوع وصف افتراضي
img_path str

المسار إلى المجلد الذي يحتوي على الصور.

مطلوب
mode str

وضع مجموعة البيانات ، إما "قطار" أو "فال".

'val'
batch int

حجم الدفعة للتدريب المستطيل. الإعدادات الافتراضية إلى لا شيء.

None

ارجاع:

نوع وصف
RTDETRDataset

كائن مجموعة البيانات للوضع المحدد.

شفرة المصدر في ultralytics/models/rtdetr/train.py
def build_dataset(self, img_path, mode="val", batch=None):
    """
    Build and return an RT-DETR dataset for training or validation.

    Args:
        img_path (str): Path to the folder containing images.
        mode (str): Dataset mode, either 'train' or 'val'.
        batch (int, optional): Batch size for rectangle training. Defaults to None.

    Returns:
        (RTDETRDataset): Dataset object for the specific mode.
    """
    return RTDETRDataset(
        img_path=img_path,
        imgsz=self.args.imgsz,
        batch_size=batch,
        augment=mode == "train",
        hyp=self.args,
        rect=False,
        cache=self.args.cache or None,
        prefix=colorstr(f"{mode}: "),
        data=self.data,
    )

get_model(cfg=None, weights=None, verbose=True)

تهيئة وإرجاع ملف RT-DETR نموذج لمهام الكشف عن الكائنات.

البارامترات:

اسم نوع وصف افتراضي
cfg dict

تكوين النموذج. الإعدادات الافتراضية إلى لا شيء.

None
weights str

الطريق إلى أوزان النموذج المدرب مسبقا. الإعدادات الافتراضية إلى لا شيء.

None
verbose bool

التسجيل المطول إذا كان صحيحا. الإعدادات الافتراضية إلى صواب.

True

ارجاع:

نوع وصف
RTDETRDetectionModel

نموذج مهيأ.

شفرة المصدر في ultralytics/models/rtdetr/train.py
def get_model(self, cfg=None, weights=None, verbose=True):
    """
    Initialize and return an RT-DETR model for object detection tasks.

    Args:
        cfg (dict, optional): Model configuration. Defaults to None.
        weights (str, optional): Path to pre-trained model weights. Defaults to None.
        verbose (bool): Verbose logging if True. Defaults to True.

    Returns:
        (RTDETRDetectionModel): Initialized model.
    """
    model = RTDETRDetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
    if weights:
        model.load(weights)
    return model

get_validator()

إرجاع أداة التحقق من صحة الكشف المناسبة ل RT-DETR التحقق من صحة النموذج.

ارجاع:

نوع وصف
RTDETRValidator

كائن المدقق للتحقق من صحة النموذج.

شفرة المصدر في ultralytics/models/rtdetr/train.py
def get_validator(self):
    """
    Returns a DetectionValidator suitable for RT-DETR model validation.

    Returns:
        (RTDETRValidator): Validator object for model validation.
    """
    self.loss_names = "giou_loss", "cls_loss", "l1_loss"
    return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))

preprocess_batch(batch)

المعالجة المسبقة لمجموعة من الصور. يقيس ويحول الصور إلى تنسيق عائم.

البارامترات:

اسم نوع وصف افتراضي
batch dict

قاموس يحتوي على مجموعة من الصور وصناديق البت والتسميات.

مطلوب

ارجاع:

نوع وصف
dict

دفعة معالجة مسبقا.

شفرة المصدر في ultralytics/models/rtdetr/train.py
def preprocess_batch(self, batch):
    """
    Preprocess a batch of images. Scales and converts the images to float format.

    Args:
        batch (dict): Dictionary containing a batch of images, bboxes, and labels.

    Returns:
        (dict): Preprocessed batch.
    """
    batch = super().preprocess_batch(batch)
    bs = len(batch["img"])
    batch_idx = batch["batch_idx"]
    gt_bbox, gt_class = [], []
    for i in range(bs):
        gt_bbox.append(batch["bboxes"][batch_idx == i].to(batch_idx.device))
        gt_class.append(batch["cls"][batch_idx == i].to(device=batch_idx.device, dtype=torch.long))
    return batch





Created 2023-11-12, Updated 2024-06-02
Authors: glenn-jocher (5), Burhan-Q (1)