सामग्री पर जाएं

के लिए संदर्भ ultralytics/models/rtdetr/train.py

नोट

यह फ़ाइल यहाँ उपलब्ध है https://github.com/ultralytics/ultralytics/बूँद/मुख्य/ultralytics/models/rtdetr/train.py. यदि आप कोई समस्या देखते हैं तो कृपया पुल अनुरोध का योगदान करके इसे ठीक करने में मदद करें 🛠️। 🙏 धन्यवाद !



ultralytics.models.rtdetr.train.RTDETRTrainer

का रूप: DetectionTrainer

के लिए ट्रेनर वर्ग RT-DETR वास्तविक समय वस्तु का पता लगाने के लिए Baidu द्वारा विकसित मॉडल। डिटेक्शनट्रेनर का विस्तार करता है के लिए वर्ग YOLO की विशिष्ट विशेषताओं और वास्तुकला के अनुकूल होने के लिए RT-DETR. यह मॉडल विजन का लाभ उठाता है ट्रांसफॉर्मर और IoU-जागरूक क्वेरी चयन और अनुकूलनीय अनुमान गति जैसी क्षमताएं हैं।

नोट्स
  • F.grid_sample में इस्तेमाल किया RT-DETR का समर्थन नहीं करता है deterministic=True युक्ति।
  • एएमपी प्रशिक्षण एनएएन आउटपुट का कारण बन सकता है और द्विपक्षीय ग्राफ मिलान के दौरान त्रुटियां उत्पन्न कर सकता है।
उदाहरण
from ultralytics.models.rtdetr.train import RTDETRTrainer

args = dict(model='rtdetr-l.yaml', data='coco8.yaml', imgsz=640, epochs=3)
trainer = RTDETRTrainer(overrides=args)
trainer.train()
में स्रोत कोड ultralytics/models/rtdetr/train.py
 13 बांग्लादेश 13 बांग्लादेश बांग्लादेश 13 बांग्लादेश 13 बांग्लादेश बांग्लादेश 13 बांग्लादेश बांग्लादेश 13 बांग्लादेश बांग्लादेश 13 बांग्लादेश बांग्लादेश 13 बांग्लादेश बांग्लादेश 13 बांग्लादेश बांग्लादेश 13 4 5            29       30    30 13   13           44  45 46 47  48 49 50  51 52 53 54  55 56 57 58  59 60 61 62 63 64 65 66 67   68  69  70 71 72 73 74 75    76 77   78           79 80 81 82 83 84 85 86 87 88 89 90 91  92 93 94 95 96 97 98         99 100  101 
class RTDETRTrainer(DetectionTrainer):
    """
    Trainer class for the RT-DETR model developed by Baidu for real-time object detection. Extends the DetectionTrainer
    class for YOLO to adapt to the specific features and architecture of RT-DETR. This model leverages Vision
    Transformers and has capabilities like IoU-aware query selection and adaptable inference speed.

    Notes:
        - F.grid_sample used in RT-DETR does not support the `deterministic=True` argument.
        - AMP training can lead to NaN outputs and may produce errors during bipartite graph matching.

    Example:
        ```python
        from ultralytics.models.rtdetr.train import RTDETRTrainer

        args = dict(model='rtdetr-l.yaml', data='coco8.yaml', imgsz=640, epochs=3)
        trainer = RTDETRTrainer(overrides=args)
        trainer.train()
        ```
    """

    def get_model(self, cfg=None, weights=None, verbose=True):
        """
        Initialize and return an RT-DETR model for object detection tasks.

        Args:
            cfg (dict, optional): Model configuration. Defaults to None.
            weights (str, optional): Path to pre-trained model weights. Defaults to None.
            verbose (bool): Verbose logging if True. Defaults to True.

        Returns:
            (RTDETRDetectionModel): Initialized model.
        """
        model = RTDETRDetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
        if weights:
            model.load(weights)
        return model

    def build_dataset(self, img_path, mode="val", batch=None):
        """
        Build and return an RT-DETR dataset for training or validation.

        Args:
            img_path (str): Path to the folder containing images.
            mode (str): Dataset mode, either 'train' or 'val'.
            batch (int, optional): Batch size for rectangle training. Defaults to None.

        Returns:
            (RTDETRDataset): Dataset object for the specific mode.
        """
        return RTDETRDataset(
            img_path=img_path,
            imgsz=self.args.imgsz,
            batch_size=batch,
            augment=mode == "train",
            hyp=self.args,
            rect=False,
            cache=self.args.cache or None,
            prefix=colorstr(f"{mode}: "),
            data=self.data,
        )

    def get_validator(self):
        """
        Returns a DetectionValidator suitable for RT-DETR model validation.

        Returns:
            (RTDETRValidator): Validator object for model validation.
        """
        self.loss_names = "giou_loss", "cls_loss", "l1_loss"
        return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))

    def preprocess_batch(self, batch):
        """
        Preprocess a batch of images. Scales and converts the images to float format.

        Args:
            batch (dict): Dictionary containing a batch of images, bboxes, and labels.

        Returns:
            (dict): Preprocessed batch.
        """
        batch = super().preprocess_batch(batch)
        bs = len(batch["img"])
        batch_idx = batch["batch_idx"]
        gt_bbox, gt_class = [], []
        for i in range(bs):
            gt_bbox.append(batch["bboxes"][batch_idx == i].to(batch_idx.device))
            gt_class.append(batch["cls"][batch_idx == i].to(device=batch_idx.device, dtype=torch.long))
        return batch

build_dataset(img_path, mode='val', batch=None)

बनाएँ और एक लौटाएं RT-DETR प्रशिक्षण या सत्यापन के लिए डेटासेट।

पैरामीटर:

नाम प्रकार विवरण: __________ चूक
img_path str

छवियों वाले फ़ोल्डर का पथ।

आवश्यक
mode str

डेटासेट मोड, या तो 'ट्रेन' या 'वैल'।

'val'
batch int

आयत प्रशिक्षण के लिए बैच आकार। कोई नहीं करने के लिए डिफ़ॉल्ट।

None

देता:

प्रकार विवरण: __________
RTDETRDataset

विशिष्ट मोड के लिए डेटासेट ऑब्जेक्ट।

में स्रोत कोड ultralytics/models/rtdetr/train.py
50 51 52 53 54 55 56 57 58 59 60 61 62 63646566676869 7071 72
def build_dataset(self, img_path, mode="val", batch=None):
    """
    Build and return an RT-DETR dataset for training or validation.

    Args:
        img_path (str): Path to the folder containing images.
        mode (str): Dataset mode, either 'train' or 'val'.
        batch (int, optional): Batch size for rectangle training. Defaults to None.

    Returns:
        (RTDETRDataset): Dataset object for the specific mode.
    """
    return RTDETRDataset(
        img_path=img_path,
        imgsz=self.args.imgsz,
        batch_size=batch,
        augment=mode == "train",
        hyp=self.args,
        rect=False,
        cache=self.args.cache or None,
        prefix=colorstr(f"{mode}: "),
        data=self.data,
    )

get_model(cfg=None, weights=None, verbose=True)

प्रारंभ करें और एक लौटाएं RT-DETR ऑब्जेक्ट डिटेक्शन कार्यों के लिए मॉडल।

पैरामीटर:

नाम प्रकार विवरण: __________ चूक
cfg dict

मॉडल कॉन्फ़िगरेशन। कोई नहीं करने के लिए डिफ़ॉल्ट।

None
weights str

पूर्व-प्रशिक्षित मॉडल वजन का मार्ग। कोई नहीं करने के लिए डिफ़ॉल्ट।

None
verbose bool

वर्बोज़ लॉगिंग अगर सही है। सही करने के लिए डिफ़ॉल्ट।

True

देता:

प्रकार विवरण: __________
RTDETRDetectionModel

प्रारंभ किया गया मॉडल।

में स्रोत कोड ultralytics/models/rtdetr/train.py
33 बांग्लादेश 34 35 36 37 38 3940 41 42 43 44454647 48
def get_model(self, cfg=None, weights=None, verbose=True):
    """
    Initialize and return an RT-DETR model for object detection tasks.

    Args:
        cfg (dict, optional): Model configuration. Defaults to None.
        weights (str, optional): Path to pre-trained model weights. Defaults to None.
        verbose (bool): Verbose logging if True. Defaults to True.

    Returns:
        (RTDETRDetectionModel): Initialized model.
    """
    model = RTDETRDetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
    if weights:
        model.load(weights)
    return model

get_validator()

के लिए उपयुक्त एक DetectionValidator देता है RT-DETR मॉडल सत्यापन।

देता:

प्रकार विवरण: __________
RTDETRValidator

मॉडल सत्यापन के लिए सत्यापनकर्ता ऑब्जेक्ट।

में स्रोत कोड ultralytics/models/rtdetr/train.py
def get_validator(self):
    """
    Returns a DetectionValidator suitable for RT-DETR model validation.

    Returns:
        (RTDETRValidator): Validator object for model validation.
    """
    self.loss_names = "giou_loss", "cls_loss", "l1_loss"
    return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))

preprocess_batch(batch)

छवियों के एक बैच को प्रीप्रोसेस करें। छवियों को स्केल करता है और फ्लोट प्रारूप में परिवर्तित करता है।

पैरामीटर:

नाम प्रकार विवरण: __________ चूक
batch dict

शब्दकोश जिसमें छवियों, bboxes, और लेबल्स का एक बैच होता है.

आवश्यक

देता:

प्रकार विवरण: __________
dict

प्रीप्रोसेस्ड बैच।

में स्रोत कोड ultralytics/models/rtdetr/train.py
def preprocess_batch(self, batch):
    """
    Preprocess a batch of images. Scales and converts the images to float format.

    Args:
        batch (dict): Dictionary containing a batch of images, bboxes, and labels.

    Returns:
        (dict): Preprocessed batch.
    """
    batch = super().preprocess_batch(batch)
    bs = len(batch["img"])
    batch_idx = batch["batch_idx"]
    gt_bbox, gt_class = [], []
    for i in range(bs):
        gt_bbox.append(batch["bboxes"][batch_idx == i].to(batch_idx.device))
        gt_class.append(batch["cls"][batch_idx == i].to(device=batch_idx.device, dtype=torch.long))
    return batch





2023-11-12 बनाया गया, अपडेट किया गया 2023-11-25
लेखक: ग्लेन-जोचर (3)