Skip to content

SegmentationTrainer


Bases: v8.detect.DetectionTrainer

Source code in ultralytics/yolo/v8/segment/train.py
class SegmentationTrainer(v8.detect.DetectionTrainer):
    """Trainer for YOLO instance-segmentation models.

    Extends the detection trainer by forcing ``task='segment'`` and supplying the
    segmentation model, validator, loss names and plotting helpers.
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """Initialize a SegmentationTrainer object with given arguments.

        Args:
            cfg: base configuration forwarded to the parent trainer.
            overrides (dict, optional): configuration overrides. A new dict is built,
                so the caller's dict is never mutated by the forced ``task`` key.
            _callbacks: optional callbacks forwarded to the parent trainer.
        """
        # Build a fresh dict rather than writing into the caller's: a reused overrides
        # dict would otherwise silently acquire task='segment'.
        overrides = {**(overrides or {}), 'task': 'segment'}
        super().__init__(cfg, overrides, _callbacks)

    def get_model(self, cfg=None, weights=None, verbose=True):
        """Return SegmentationModel initialized with specified config and weights.

        Args:
            cfg: model configuration passed to ``SegmentationModel``.
            weights: optional weights, loaded only when truthy.
            verbose (bool): request verbose model output; honoured only when RANK == -1.

        Returns:
            The constructed ``SegmentationModel`` (3 input channels, ``self.data['nc']`` classes).
        """
        model = SegmentationModel(cfg, ch=3, nc=self.data['nc'], verbose=verbose and RANK == -1)
        if weights:
            model.load(weights)

        return model

    def get_validator(self):
        """Return an instance of SegmentationValidator for validation of YOLO model.

        Side effect: sets ``self.loss_names`` to the four segmentation loss components.
        """
        self.loss_names = 'box_loss', 'seg_loss', 'cls_loss', 'dfl_loss'
        return v8.segment.SegmentationValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))

    def plot_training_samples(self, batch, ni):
        """Plot a batch of training images with class labels, boxes and masks.

        Args:
            batch (dict): training batch; reads 'img', 'batch_idx', 'cls', 'bboxes',
                'masks' and 'im_file'.
            ni (int): index used in the output filename ``train_batch{ni}.jpg``.
        """
        plot_images(batch['img'],
                    batch['batch_idx'],
                    batch['cls'].squeeze(-1),
                    batch['bboxes'],
                    batch['masks'],
                    paths=batch['im_file'],
                    fname=self.save_dir / f'train_batch{ni}.jpg',
                    on_plot=self.on_plot)

    def plot_metrics(self):
        """Plots training/val metrics read from ``self.csv`` (segmentation layout)."""
        plot_results(file=self.csv, segment=True, on_plot=self.on_plot)  # save results.png

__init__(cfg=DEFAULT_CFG, overrides=None, _callbacks=None)

Initialize a SegmentationTrainer object with given arguments.

Source code in ultralytics/yolo/v8/segment/train.py
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
    """Initialize a SegmentationTrainer object with given arguments.

    Args:
        cfg: base configuration forwarded to the parent trainer.
        overrides (dict, optional): configuration overrides. A new dict is built,
            so the caller's dict is never mutated by the forced ``task`` key.
        _callbacks: optional callbacks forwarded to the parent trainer.
    """
    # Build a fresh dict rather than writing into the caller's: a reused overrides
    # dict would otherwise silently acquire task='segment'.
    overrides = {**(overrides or {}), 'task': 'segment'}
    super().__init__(cfg, overrides, _callbacks)

get_model(cfg=None, weights=None, verbose=True)

Return SegmentationModel initialized with specified config and weights.

Source code in ultralytics/yolo/v8/segment/train.py
def get_model(self, cfg=None, weights=None, verbose=True):
    """Build the SegmentationModel used for training.

    Args:
        cfg: model configuration handed to ``SegmentationModel``.
        weights: optional weights; loaded into the model only when truthy.
        verbose (bool): request verbose output; effective only when RANK == -1.

    Returns:
        The constructed ``SegmentationModel``.
    """
    num_classes = self.data['nc']
    show_info = verbose and RANK == -1
    seg_model = SegmentationModel(cfg, ch=3, nc=num_classes, verbose=show_info)
    if weights:
        seg_model.load(weights)
    return seg_model

get_validator()

Return an instance of SegmentationValidator for validation of YOLO model.

Source code in ultralytics/yolo/v8/segment/train.py
def get_validator(self):
    """Create the SegmentationValidator run during training.

    Also records the names of the four segmentation loss components on
    ``self.loss_names``.
    """
    self.loss_names = ('box_loss', 'seg_loss', 'cls_loss', 'dfl_loss')
    validator_cls = v8.segment.SegmentationValidator
    return validator_cls(self.test_loader, save_dir=self.save_dir, args=copy(self.args))

plot_metrics()

Plots training and validation metrics.

Source code in ultralytics/yolo/v8/segment/train.py
def plot_metrics(self):
    """Render the training/validation metrics logged in ``self.csv`` to results.png."""
    plot_kwargs = dict(file=self.csv, segment=True, on_plot=self.on_plot)
    plot_results(**plot_kwargs)

plot_training_samples(batch, ni)

Creates a plot of training sample images with their class labels, bounding boxes, and segmentation masks.

Source code in ultralytics/yolo/v8/segment/train.py
def plot_training_samples(self, batch, ni):
    """Save a labelled image grid for one training batch.

    Args:
        batch (dict): reads 'img', 'batch_idx', 'cls', 'bboxes', 'masks' and 'im_file'.
        ni (int): index embedded in the output filename ``train_batch{ni}.jpg``.
    """
    images = batch['img']
    sample_idx = batch['batch_idx']
    classes = batch['cls'].squeeze(-1)
    boxes = batch['bboxes']
    masks = batch['masks']
    files = batch['im_file']
    out_path = self.save_dir / f'train_batch{ni}.jpg'
    plot_images(images, sample_idx, classes, boxes, masks,
                paths=files, fname=out_path, on_plot=self.on_plot)



train


Train a YOLO segmentation model based on passed arguments.

Source code in ultralytics/yolo/v8/segment/train.py
def train(cfg=DEFAULT_CFG, use_python=False):
    """Train a YOLO segmentation model based on passed arguments.

    Args:
        cfg: configuration object whose ``model``, ``data`` and ``device`` attributes
            are read; unset values fall back to 'yolov8n-seg.pt', 'coco128-seg.yaml'
            and '' respectively.
        use_python (bool): if True, train through the high-level ``YOLO`` API;
            otherwise construct and run a ``SegmentationTrainer`` directly.
    """
    model = cfg.model or 'yolov8n-seg.pt'
    data = cfg.data or 'coco128-seg.yaml'  # default segmentation dataset YAML
    device = cfg.device if cfg.device is not None else ''  # None -> '' (unset device)

    args = dict(model=model, data=data, device=device)
    if use_python:
        from ultralytics import YOLO
        YOLO(model).train(**args)
    else:
        trainer = SegmentationTrainer(overrides=args)
        trainer.train()




Created 2023-04-16, Updated 2023-05-30
Authors: Glenn Jocher (4)