Skip to content

DetectionTrainer


Bases: BaseTrainer

Source code in ultralytics/yolo/v8/detect/train.py
class DetectionTrainer(BaseTrainer):
    """Trainer for YOLO detection models, providing the detection-specific hooks on top of BaseTrainer."""

    def build_dataset(self, img_path, mode='train', batch=None):
        """Construct the YOLO dataset for one split.

        Args:
            img_path (str): Path to the folder containing images.
            mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
            batch (int, optional): Size of batches, this is for `rect`. Defaults to None.

        Returns:
            Whatever `build_yolo_dataset` returns; `rect` is requested only when `mode` is 'val'.
        """
        # Take the model's largest stride when a model exists, but never go below 32.
        if self.model:
            max_stride = int(de_parallel(self.model).stride.max())
        else:
            max_stride = 0
        grid_size = max(max_stride, 32)
        use_rect = mode == 'val'
        return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=use_rect, stride=grid_size)

    def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
        """Create the dataloader for `dataset_path` in the given mode.

        When `self.args.v5loader` is set, the first element of the `create_dataloader` result is returned (this
        path is deprecated). Otherwise the dataset comes from `build_dataset` and is wrapped by `build_dataloader`.

        TODO: manage splits differently.
        """
        is_train = mode == 'train'

        if self.args.v5loader:
            LOGGER.warning("WARNING ⚠️ 'v5loader' feature is deprecated and will be removed soon. You can train using "
                           'the default YOLOv8 dataloader instead, no argument is needed.')
            # Largest model stride when a model exists, with a floor of 32.
            grid_size = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
            loader_kwargs = {
                'path': dataset_path,
                'imgsz': self.args.imgsz,
                'batch_size': batch_size,
                'stride': grid_size,
                'hyp': vars(self.args),
                'augment': is_train,
                'cache': self.args.cache,
                'pad': 0 if is_train else 0.5,
                'rect': self.args.rect or mode == 'val',
                'rank': rank,
                'workers': self.args.workers,
                'close_mosaic': self.args.close_mosaic != 0,
                'prefix': colorstr(f'{mode}: '),
                'shuffle': is_train,
                'seed': self.args.seed}
            return create_dataloader(**loader_kwargs)[0]

        assert mode in ('train', 'val')
        # Initialise the dataset *.cache only once when running under DDP.
        with torch_distributed_zero_first(rank):
            dataset = self.build_dataset(dataset_path, mode, batch_size)

        rect_dataset = getattr(dataset, 'rect', False)
        shuffle = is_train
        if rect_dataset and shuffle:
            LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False")
            shuffle = False

        # Non-training loaders get twice the configured worker count.
        if is_train:
            workers = self.args.workers
        else:
            workers = self.args.workers * 2
        return build_dataloader(dataset, batch_size, workers, shuffle, rank)

    def preprocess_batch(self, batch):
        """Move the batch images to `self.device`, cast them to float and divide by 255; returns the same dict."""
        images = batch['img'].to(self.device, non_blocking=True)
        batch['img'] = images.float() / 255
        return batch

    def set_model_attributes(self):
        """Attach the class count, class names (both from `self.data`) and `self.args` to `self.model`."""
        # NOTE: scaling of `self.args.box` / `self.args.cls` by the number of detection layers is disabled:
        # nl = de_parallel(self.model).model[-1].nl  # number of detection layers (to scale hyps)
        # self.args.box *= 3 / nl  # scale to layers
        # self.args.cls *= self.data["nc"] / 80 * 3 / nl  # scale to classes and layers
        # self.args.cls *= (self.args.imgsz / 640) ** 2 * 3 / nl  # scale to image size and layers
        model, dataset_info = self.model, self.data
        model.nc = dataset_info['nc']  # number of classes
        model.names = dataset_info['names']  # class names
        model.args = self.args  # hyperparameters
        # TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc

    def get_model(self, cfg=None, weights=None, verbose=True):
        """Build a DetectionModel from `cfg` and load `weights` into it when they are given."""
        num_classes = self.data['nc']
        show_info = verbose and RANK == -1  # only stay verbose when RANK is -1
        detector = DetectionModel(cfg, nc=num_classes, verbose=show_info)
        if weights:
            detector.load(weights)
        return detector

    def get_validator(self):
        """Set the detection loss names on the trainer and return a DetectionValidator for `self.test_loader`."""
        self.loss_names = ('box_loss', 'cls_loss', 'dfl_loss')
        validator_cls = v8.detect.DetectionValidator
        # The validator receives its own copy of the trainer arguments.
        return validator_cls(self.test_loader, save_dir=self.save_dir, args=copy(self.args))

    def label_loss_items(self, loss_items=None, prefix='train'):
        """
        Label loss values with `{prefix}/{loss_name}` keys.

        Returns the list of labelled keys when `loss_items` is None; otherwise a dict mapping each key to the
        matching loss value converted to a float rounded to 5 decimal places.
        """
        # Not needed for classification but necessary for segmentation & detection
        labelled = [f'{prefix}/{name}' for name in self.loss_names]
        if loss_items is None:
            return labelled
        rounded = list(map(lambda value: round(float(value), 5), loss_items))
        return dict(zip(labelled, rounded))

    def progress_string(self):
        """Return the progress header row: a newline followed by the column titles, each padded to width 11."""
        columns = ('Epoch', 'GPU_mem', *self.loss_names, 'Instances', 'Size')
        header_format = '\n' + '%11s' * len(columns)
        return header_format % columns

    def plot_training_samples(self, batch, ni):
        """Pass the batch images, labels, boxes and file paths to `plot_images`, targeting `train_batch{ni}.jpg`."""
        plot_kwargs = {
            'images': batch['img'],
            'batch_idx': batch['batch_idx'],
            'cls': batch['cls'].squeeze(-1),
            'bboxes': batch['bboxes'],
            'paths': batch['im_file'],
            'fname': self.save_dir / f'train_batch{ni}.jpg',
            'on_plot': self.on_plot}
        plot_images(**plot_kwargs)

    def plot_metrics(self):
        """Hand the trainer's CSV file (`self.csv`) to `plot_results`."""
        results_csv = self.csv
        plot_results(file=results_csv, on_plot=self.on_plot)  # save results.png

    def plot_training_labels(self):
        """Concatenate the boxes and classes of all training labels and pass them to `plot_labels`."""
        labels = self.train_loader.dataset.labels
        all_boxes = np.concatenate([entry['bboxes'] for entry in labels], 0)
        all_classes = np.concatenate([entry['cls'] for entry in labels], 0)
        plot_labels(all_boxes,
                    all_classes.squeeze(),
                    names=self.data['names'],
                    save_dir=self.save_dir,
                    on_plot=self.on_plot)

build_dataset(img_path, mode='train', batch=None)

Build YOLO Dataset

Parameters:

Name Type Description Default
img_path str

Path to the folder containing images.

required
mode str

train mode or val mode, users are able to customize different augmentations for each mode.

'train'
batch int

Size of batches, this is for rect. Defaults to None.

None
Source code in ultralytics/yolo/v8/detect/train.py
def build_dataset(self, img_path, mode='train', batch=None):
    """Construct the YOLO dataset for one split.

    Args:
        img_path (str): Path to the folder containing images.
        mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
        batch (int, optional): Size of batches, this is for `rect`. Defaults to None.

    Returns:
        Whatever `build_yolo_dataset` returns; `rect` is requested only when `mode` is 'val'.
    """
    # Take the model's largest stride when a model exists, but never go below 32.
    if self.model:
        max_stride = int(de_parallel(self.model).stride.max())
    else:
        max_stride = 0
    grid_size = max(max_stride, 32)
    use_rect = mode == 'val'
    return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=use_rect, stride=grid_size)

get_dataloader(dataset_path, batch_size=16, rank=0, mode='train')

Builds and returns the dataloader for the given dataset path and mode. TODO: manage splits differently.

Source code in ultralytics/yolo/v8/detect/train.py
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
    """Create the dataloader for `dataset_path` in the given mode.

    When `self.args.v5loader` is set, the first element of the `create_dataloader` result is returned (this
    path is deprecated). Otherwise the dataset comes from `build_dataset` and is wrapped by `build_dataloader`.

    TODO: manage splits differently.
    """
    is_train = mode == 'train'

    if self.args.v5loader:
        LOGGER.warning("WARNING ⚠️ 'v5loader' feature is deprecated and will be removed soon. You can train using "
                       'the default YOLOv8 dataloader instead, no argument is needed.')
        # Largest model stride when a model exists, with a floor of 32.
        grid_size = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
        loader_kwargs = {
            'path': dataset_path,
            'imgsz': self.args.imgsz,
            'batch_size': batch_size,
            'stride': grid_size,
            'hyp': vars(self.args),
            'augment': is_train,
            'cache': self.args.cache,
            'pad': 0 if is_train else 0.5,
            'rect': self.args.rect or mode == 'val',
            'rank': rank,
            'workers': self.args.workers,
            'close_mosaic': self.args.close_mosaic != 0,
            'prefix': colorstr(f'{mode}: '),
            'shuffle': is_train,
            'seed': self.args.seed}
        return create_dataloader(**loader_kwargs)[0]

    assert mode in ('train', 'val')
    # Initialise the dataset *.cache only once when running under DDP.
    with torch_distributed_zero_first(rank):
        dataset = self.build_dataset(dataset_path, mode, batch_size)

    rect_dataset = getattr(dataset, 'rect', False)
    shuffle = is_train
    if rect_dataset and shuffle:
        LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False")
        shuffle = False

    # Non-training loaders get twice the configured worker count.
    if is_train:
        workers = self.args.workers
    else:
        workers = self.args.workers * 2
    return build_dataloader(dataset, batch_size, workers, shuffle, rank)

get_model(cfg=None, weights=None, verbose=True)

Return a YOLO detection model.

Source code in ultralytics/yolo/v8/detect/train.py
def get_model(self, cfg=None, weights=None, verbose=True):
    """Build a DetectionModel from `cfg` and load `weights` into it when they are given."""
    num_classes = self.data['nc']
    show_info = verbose and RANK == -1  # only stay verbose when RANK is -1
    detector = DetectionModel(cfg, nc=num_classes, verbose=show_info)
    if weights:
        detector.load(weights)
    return detector

get_validator()

Returns a DetectionValidator for YOLO model validation.

Source code in ultralytics/yolo/v8/detect/train.py
def get_validator(self):
    """Set the detection loss names on the trainer and return a DetectionValidator for `self.test_loader`."""
    self.loss_names = ('box_loss', 'cls_loss', 'dfl_loss')
    validator_cls = v8.detect.DetectionValidator
    # The validator receives its own copy of the trainer arguments.
    return validator_cls(self.test_loader, save_dir=self.save_dir, args=copy(self.args))

label_loss_items(loss_items=None, prefix='train')

Returns a dict that maps labelled loss names to the given loss values, or just the list of labelled names when no loss items are given.

Source code in ultralytics/yolo/v8/detect/train.py
def label_loss_items(self, loss_items=None, prefix='train'):
    """
    Label loss values with `{prefix}/{loss_name}` keys.

    Returns the list of labelled keys when `loss_items` is None; otherwise a dict mapping each key to the
    matching loss value converted to a float rounded to 5 decimal places.
    """
    # Not needed for classification but necessary for segmentation & detection
    labelled = [f'{prefix}/{name}' for name in self.loss_names]
    if loss_items is None:
        return labelled
    rounded = list(map(lambda value: round(float(value), 5), loss_items))
    return dict(zip(labelled, rounded))

plot_metrics()

Plots metrics from a CSV file.

Source code in ultralytics/yolo/v8/detect/train.py
def plot_metrics(self):
    """Hand the trainer's CSV file (`self.csv`) to `plot_results`."""
    results_csv = self.csv
    plot_results(file=results_csv, on_plot=self.on_plot)  # save results.png

plot_training_labels()

Create a labeled training plot of the YOLO model.

Source code in ultralytics/yolo/v8/detect/train.py
def plot_training_labels(self):
    """Concatenate the boxes and classes of all training labels and pass them to `plot_labels`."""
    labels = self.train_loader.dataset.labels
    all_boxes = np.concatenate([entry['bboxes'] for entry in labels], 0)
    all_classes = np.concatenate([entry['cls'] for entry in labels], 0)
    plot_labels(all_boxes,
                all_classes.squeeze(),
                names=self.data['names'],
                save_dir=self.save_dir,
                on_plot=self.on_plot)

plot_training_samples(batch, ni)

Plots training samples with their annotations.

Source code in ultralytics/yolo/v8/detect/train.py
def plot_training_samples(self, batch, ni):
    """Pass the batch images, labels, boxes and file paths to `plot_images`, targeting `train_batch{ni}.jpg`."""
    plot_kwargs = {
        'images': batch['img'],
        'batch_idx': batch['batch_idx'],
        'cls': batch['cls'].squeeze(-1),
        'bboxes': batch['bboxes'],
        'paths': batch['im_file'],
        'fname': self.save_dir / f'train_batch{ni}.jpg',
        'on_plot': self.on_plot}
    plot_images(**plot_kwargs)

preprocess_batch(batch)

Preprocesses a batch of images by scaling and converting to float.

Source code in ultralytics/yolo/v8/detect/train.py
def preprocess_batch(self, batch):
    """Move the batch images to `self.device`, cast them to float and divide by 255; returns the same dict."""
    images = batch['img'].to(self.device, non_blocking=True)
    batch['img'] = images.float() / 255
    return batch

progress_string()

Returns a formatted string of training progress with epoch, GPU memory, loss, instances and size.

Source code in ultralytics/yolo/v8/detect/train.py
def progress_string(self):
    """Return the progress header row: a newline followed by the column titles, each padded to width 11."""
    columns = ('Epoch', 'GPU_mem', *self.loss_names, 'Instances', 'Size')
    header_format = '\n' + '%11s' * len(columns)
    return header_format % columns

set_model_attributes()

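Attaches the number of classes, the class names and the hyperparameters to the model. (Scaling the hyperparameters by the number of detection layers, `nl = de_parallel(self.model).model[-1].nl`, is currently disabled.)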

Source code in ultralytics/yolo/v8/detect/train.py
def set_model_attributes(self):
    """Attach the class count, class names (both from `self.data`) and `self.args` to `self.model`."""
    # NOTE: scaling of `self.args.box` / `self.args.cls` by the number of detection layers is disabled:
    # nl = de_parallel(self.model).model[-1].nl  # number of detection layers (to scale hyps)
    # self.args.box *= 3 / nl  # scale to layers
    # self.args.cls *= self.data["nc"] / 80 * 3 / nl  # scale to classes and layers
    # self.args.cls *= (self.args.imgsz / 640) ** 2 * 3 / nl  # scale to image size and layers
    model, dataset_info = self.model, self.data
    model.nc = dataset_info['nc']  # number of classes
    model.names = dataset_info['names']  # class names
    model.args = self.args  # hyperparameters



train


Train and optimize YOLO model given training data and device.

Source code in ultralytics/yolo/v8/detect/train.py
def train(cfg=DEFAULT_CFG, use_python=False):
    """Train a YOLO detection model using the model, data and device taken from `cfg`.

    Falls back to 'yolov8n.pt' for a missing model, 'coco128.yaml' for missing data and '' for a device of None.
    With `use_python` set, training goes through the `ultralytics.YOLO` interface; otherwise a
    DetectionTrainer is created directly with these values as overrides.
    """
    overrides = {
        'model': cfg.model or 'yolov8n.pt',
        'data': cfg.data or 'coco128.yaml',  # or yolo.ClassificationDataset("mnist")
        'device': '' if cfg.device is None else cfg.device}

    if not use_python:
        DetectionTrainer(overrides=overrides).train()
        return

    from ultralytics import YOLO
    YOLO(overrides['model']).train(**overrides)




Created 2023-04-16, Updated 2023-05-30
Authors: Glenn Jocher (4)