Skip to content

Reference for ultralytics/models/yolo/detect/train.py

Note

This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/detect/train.py. If you spot a problem please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!


ultralytics.models.yolo.detect.train.DetectionTrainer

DetectionTrainer(cfg=DEFAULT_CFG, overrides=None, _callbacks=None)

Bases: BaseTrainer

A class extending the BaseTrainer class for training based on a detection model.

This trainer specializes in object detection tasks, handling the specific requirements for training YOLO models for object detection.

Attributes:

Name Type Description
model DetectionModel

The YOLO detection model being trained.

data dict

Dictionary containing dataset information including class names and number of classes.

loss_names Tuple[str]

Names of the loss components used in training (box_loss, cls_loss, dfl_loss).

Methods:

Name Description
build_dataset

Build YOLO dataset for training or validation.

get_dataloader

Construct and return dataloader for the specified mode.

preprocess_batch

Preprocess a batch of images by scaling and converting to float.

set_model_attributes

Set model attributes based on dataset information.

get_model

Return a YOLO detection model.

get_validator

Return a validator for model evaluation.

label_loss_items

Return a loss dictionary with labeled training loss items.

progress_string

Return a formatted string of training progress.

plot_training_samples

Plot training samples with their annotations.

plot_metrics

Plot metrics from a CSV file.

plot_training_labels

Create a labeled training plot of the YOLO model.

auto_batch

Calculate optimal batch size based on model memory requirements.

Examples:

>>> from ultralytics.models.yolo.detect import DetectionTrainer
>>> args = dict(model="yolo11n.pt", data="coco8.yaml", epochs=3)
>>> trainer = DetectionTrainer(overrides=args)
>>> trainer.train()
Source code in ultralytics/engine/trainer.py
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
    """
    Initialize the BaseTrainer class.

    Resolves the run configuration, selects the device, seeds the RNGs, prepares the save
    directory and checkpoint paths, loads the train/test dataset references and registers
    callbacks. Model-related state that is not available yet (EMA, scheduler, validator,
    metrics, losses) is initialised to ``None`` here.

    Args:
        cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
        overrides (dict, optional): Configuration overrides. Defaults to None.
        _callbacks (list, optional): List of callback functions. Defaults to None.
    """
    # Merge the base configuration with user overrides into a single args namespace.
    self.args = get_cfg(cfg, overrides)
    # NOTE(review): presumably adjusts self.args when resuming a previous run — confirm in check_resume.
    self.check_resume(overrides)
    self.device = select_device(self.args.device, self.args.batch)
    self.validator = None
    self.metrics = None
    self.plots = {}
    # The seed is offset by RANK, so processes with different ranks are seeded differently.
    init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)

    # Dirs
    self.save_dir = get_save_dir(self.args)
    self.args.name = self.save_dir.name  # update name for loggers
    self.wdir = self.save_dir / "weights"  # weights dir
    # Only ranks -1 and 0 touch the filesystem here (presumably single-process / main DDP process — confirm).
    if RANK in {-1, 0}:
        self.wdir.mkdir(parents=True, exist_ok=True)  # make dir
        self.args.save_dir = str(self.save_dir)
        yaml_save(self.save_dir / "args.yaml", vars(self.args))  # save run args
    self.last, self.best = self.wdir / "last.pt", self.wdir / "best.pt"  # checkpoint paths
    self.save_period = self.args.save_period

    self.batch_size = self.args.batch
    self.epochs = self.args.epochs or 100  # in case users accidentally pass epochs=None with timed training
    self.start_epoch = 0
    # Arguments are printed only when RANK is -1.
    if RANK == -1:
        print_args(vars(self.args))

    # Device
    if self.device.type in {"cpu", "mps"}:
        self.args.workers = 0  # faster CPU training as time dominated by inference, not dataloading

    # Model and Dataset
    # At this point self.model holds the value returned by the helper, not a constructed network
    # (presumably a checkpoint/config name — confirm against check_model_file_from_stem).
    self.model = check_model_file_from_stem(self.args.model)  # add suffix, i.e. yolo11n -> yolo11n.pt
    with torch_distributed_zero_first(LOCAL_RANK):  # avoid auto-downloading dataset multiple times
        self.trainset, self.testset = self.get_dataset()
    self.ema = None

    # Optimization utils init
    # Placeholders only; nothing in this constructor assigns real values to them.
    self.lf = None
    self.scheduler = None

    # Epoch level metrics
    self.best_fitness = None
    self.fitness = None
    self.loss = None
    self.tloss = None
    self.loss_names = ["Loss"]  # default loss label; subclasses may replace it
    self.csv = self.save_dir / "results.csv"  # results file inside the run's save directory
    self.plot_idx = [0, 1, 2]

    # HUB
    self.hub_session = None

    # Callbacks
    # Fall back to the default callbacks when none (or an empty/falsy value) is supplied.
    self.callbacks = _callbacks or callbacks.get_default_callbacks()
    if RANK in {-1, 0}:
        callbacks.add_integration_callbacks(self)

auto_batch

auto_batch()

Get optimal batch size by calculating memory occupation of model.

Returns:

Type Description
int

Optimal batch size.

Source code in ultralytics/models/yolo/detect/train.py
def auto_batch(self):
    """
    Estimate the optimal batch size from the memory occupation of the model.

    The training dataset is built once, the largest per-image object count is looked up,
    and four times that count is handed to the parent implementation.

    Returns:
        (int): Optimal batch size.
    """
    dataset = self.build_dataset(self.trainset, mode="train", batch=16)
    objects_per_image = [len(entry["cls"]) for entry in dataset.labels]
    mosaic_factor = 4  # 4 for mosaic augmentation
    return super().auto_batch(max(objects_per_image) * mosaic_factor)

build_dataset

build_dataset(img_path, mode='train', batch=None)

Build YOLO Dataset for training or validation.

Parameters:

Name Type Description Default
img_path str

Path to the folder containing images.

required
mode str

Either train or val mode; users can customize different augmentations for each mode.

'train'
batch int

Size of batches; this is used for rect.

None

Returns:

Type Description
Dataset

YOLO dataset object configured for the specified mode.

Source code in ultralytics/models/yolo/detect/train.py
def build_dataset(self, img_path, mode="train", batch=None):
    """
    Create the YOLO dataset used for training or validation.

    Args:
        img_path (str): Path to the folder containing images.
        mode (str): Either `train` or `val`; users can customize different augmentations for each mode.
        batch (int, optional): Size of batches; this is used for `rect`.

    Returns:
        (Dataset): YOLO dataset object configured for the specified mode.
    """
    # Use the model's largest stride when a model is available, but never go below 32.
    if self.model:
        model_stride = int(de_parallel(self.model).stride.max())
    else:
        model_stride = 0
    grid_size = max(model_stride, 32)

    use_rect = mode == "val"  # rectangular batches are requested for validation only
    return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=use_rect, stride=grid_size)

get_dataloader

get_dataloader(dataset_path, batch_size=16, rank=0, mode='train')

Construct and return dataloader for the specified mode.

Parameters:

Name Type Description Default
dataset_path str

Path to the dataset.

required
batch_size int

Number of images per batch.

16
rank int

Process rank for distributed training.

0
mode str

'train' for training dataloader, 'val' for validation dataloader.

'train'

Returns:

Type Description
DataLoader

PyTorch dataloader object.

Source code in ultralytics/models/yolo/detect/train.py
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
    """
    Build the dataset for the requested mode and wrap it in a dataloader.

    Args:
        dataset_path (str): Path to the dataset.
        batch_size (int): Number of images per batch.
        rank (int): Process rank for distributed training.
        mode (str): 'train' for training dataloader, 'val' for validation dataloader.

    Returns:
        (DataLoader): PyTorch dataloader object.
    """
    assert mode in {"train", "val"}, f"Mode must be 'train' or 'val', not {mode}."
    is_train = mode == "train"

    # init dataset *.cache only once if DDP
    with torch_distributed_zero_first(rank):
        dataset = self.build_dataset(dataset_path, mode, batch_size)

    # Shuffle only for training, and never together with rectangular batching.
    shuffle = is_train
    if shuffle and getattr(dataset, "rect", False):
        LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False")
        shuffle = False

    # Validation uses twice as many workers as training.
    num_workers = self.args.workers
    if not is_train:
        num_workers = num_workers * 2
    return build_dataloader(dataset, batch_size, num_workers, shuffle, rank)

get_model

get_model(cfg=None, weights=None, verbose=True)

Return a YOLO detection model.

Parameters:

Name Type Description Default
cfg str

Path to model configuration file.

None
weights str

Path to model weights.

None
verbose bool

Whether to display model information.

True

Returns:

Type Description
DetectionModel

YOLO detection model.

Source code in ultralytics/models/yolo/detect/train.py
def get_model(self, cfg=None, weights=None, verbose=True):
    """
    Construct a YOLO detection model, optionally loading weights into it.

    Args:
        cfg (str, optional): Path to model configuration file.
        weights (str, optional): Path to model weights.
        verbose (bool): Whether to display model information.

    Returns:
        (DetectionModel): YOLO detection model.
    """
    num_classes = self.data["nc"]
    show_info = verbose and RANK == -1  # model info is only shown when RANK is -1
    detector = DetectionModel(cfg, nc=num_classes, verbose=show_info)
    if not weights:
        return detector
    detector.load(weights)
    return detector

get_validator

get_validator()

Return a DetectionValidator for YOLO model validation.

Source code in ultralytics/models/yolo/detect/train.py
def get_validator(self):
    """Set the detection loss names and return a DetectionValidator for YOLO model validation."""
    self.loss_names = ("box_loss", "cls_loss", "dfl_loss")
    validator_cls = yolo.detect.DetectionValidator
    # The validator receives a copy of the args rather than the trainer's own object.
    return validator_cls(
        self.test_loader,
        save_dir=self.save_dir,
        args=copy(self.args),
        _callbacks=self.callbacks,
    )

label_loss_items

label_loss_items(loss_items=None, prefix='train')

Return a dictionary of labeled training loss items.

Parameters:

Name Type Description Default
loss_items List[float]

List of loss values.

None
prefix str

Prefix for keys in the returned dictionary.

'train'

Returns:

Type Description
Dict | List

Dictionary of labeled loss items if loss_items is provided, otherwise list of keys.

Source code in ultralytics/models/yolo/detect/train.py
def label_loss_items(self, loss_items=None, prefix="train"):
    """
    Label loss values with `<prefix>/<loss_name>` keys.

    Args:
        loss_items (List[float], optional): List of loss values.
        prefix (str): Prefix for keys in the returned dictionary.

    Returns:
        (Dict | List): Dictionary of labeled loss items if loss_items is provided, otherwise list of keys.
    """
    labels = [f"{prefix}/{name}" for name in self.loss_names]
    if loss_items is None:
        return labels
    # convert tensors to 5 decimal place floats
    values = [round(float(item), 5) for item in loss_items]
    return dict(zip(labels, values))

plot_metrics

plot_metrics()

Plot metrics from a CSV file.

Source code in ultralytics/models/yolo/detect/train.py
def plot_metrics(self):
    """Plot the metrics stored in the trainer's results CSV file."""
    results_file = self.csv
    plot_results(file=results_file, on_plot=self.on_plot)  # save results.png

plot_training_labels

plot_training_labels()

Create a labeled training plot of the YOLO model.

Source code in ultralytics/models/yolo/detect/train.py
def plot_training_labels(self):
    """Gather every box and class label from the training dataset and plot them."""
    labels = self.train_loader.dataset.labels
    all_boxes = np.concatenate([entry["bboxes"] for entry in labels], 0)
    all_classes = np.concatenate([entry["cls"] for entry in labels], 0)
    plot_labels(
        all_boxes,
        all_classes.squeeze(),
        names=self.data["names"],
        save_dir=self.save_dir,
        on_plot=self.on_plot,
    )

plot_training_samples

plot_training_samples(batch, ni)

Plot training samples with their annotations.

Parameters:

Name Type Description Default
batch dict

Dictionary containing batch data.

required
ni int

Number of iterations.

required
Source code in ultralytics/models/yolo/detect/train.py
def plot_training_samples(self, batch, ni):
    """
    Save an image of the given training batch with its annotations drawn.

    Args:
        batch (dict): Dictionary containing batch data.
        ni (int): Number of iterations.
    """
    plot_kwargs = {
        "images": batch["img"],
        "batch_idx": batch["batch_idx"],
        "cls": batch["cls"].squeeze(-1),
        "bboxes": batch["bboxes"],
        "paths": batch["im_file"],
        "fname": self.save_dir / f"train_batch{ni}.jpg",  # output file is named after the iteration
        "on_plot": self.on_plot,
    }
    plot_images(**plot_kwargs)

preprocess_batch

preprocess_batch(batch)

Preprocess a batch of images by scaling and converting to float.

Parameters:

Name Type Description Default
batch dict

Dictionary containing batch data with 'img' tensor.

required

Returns:

Type Description
dict

Preprocessed batch with normalized images.

Source code in ultralytics/models/yolo/detect/train.py
def preprocess_batch(self, batch):
    """
    Move batch images to the trainer device, convert to float and scale by 1/255.

    When `self.args.multi_scale` is enabled, the images are additionally resized to a
    randomly drawn, stride-aligned size.

    Args:
        batch (dict): Dictionary containing batch data with 'img' tensor.

    Returns:
        (dict): Preprocessed batch with normalized images.
    """
    batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255
    if not self.args.multi_scale:
        return batch

    images = batch["img"]
    stride = self.stride
    base_size = self.args.imgsz

    # Draw a target size between 0.5x and 1.5x of imgsz, floored to a multiple of the stride.
    lower = int(base_size * 0.5)
    upper = int(base_size * 1.5 + stride)
    target_size = random.randrange(lower, upper) // stride * stride

    scale = target_size / max(images.shape[2:])  # scale factor
    if scale != 1:
        # new shape (stretched to gs-multiple)
        new_shape = [math.ceil(dim * scale / stride) * stride for dim in images.shape[2:]]
        images = nn.functional.interpolate(images, size=new_shape, mode="bilinear", align_corners=False)
    batch["img"] = images
    return batch

progress_string

progress_string()

Return a formatted string of training progress with epoch, GPU memory, loss, instances and size.

Source code in ultralytics/models/yolo/detect/train.py
def progress_string(self):
    """Return the progress header line: epoch, GPU memory, each loss name, instances and size."""
    columns = ["Epoch", "GPU_mem", *self.loss_names, "Instances", "Size"]
    # Every column is right-aligned in an 11-character field; the header starts on a new line.
    return "\n" + "".join("%11s" % (column,) for column in columns)

set_model_attributes

set_model_attributes()

Set model attributes based on dataset information.

Source code in ultralytics/models/yolo/detect/train.py
def set_model_attributes(self):
    """Copy the class count, class names and hyperparameters onto the model."""
    # Nl = de_parallel(self.model).model[-1].nl  # number of detection layers (to scale hyps)
    # self.args.box *= 3 / nl  # scale to layers
    # self.args.cls *= self.data["nc"] / 80 * 3 / nl  # scale to classes and layers
    # self.args.cls *= (self.args.imgsz / 640) ** 2 * 3 / nl  # scale to image size and layers
    dataset_info = self.data
    self.model.nc = dataset_info["nc"]  # number of classes
    self.model.names = dataset_info["names"]  # class names
    self.model.args = self.args  # hyperparameters



📅 Created 1 year ago ✏️ Updated 6 months ago