Skip to content

Reference for ultralytics/models/yolo/detect/train.py

Note

This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/detect/train.py. If you spot a problem please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!


ultralytics.models.yolo.detect.train.DetectionTrainer

DetectionTrainer(cfg=DEFAULT_CFG, overrides=None, _callbacks=None)

Bases: BaseTrainer

A class extending the BaseTrainer class for training based on a detection model.

Example
from ultralytics.models.yolo.detect import DetectionTrainer

args = dict(model="yolo11n.pt", data="coco8.yaml", epochs=3)
trainer = DetectionTrainer(overrides=args)
trainer.train()

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| cfg | str | Path to a configuration file. Defaults to DEFAULT_CFG. | DEFAULT_CFG |
| overrides | dict | Configuration overrides. Defaults to None. | None |
Source code in ultralytics/engine/trainer.py
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
    """
    Initializes the BaseTrainer class.

    Resolves the run configuration, selects the device, seeds the RNGs, prepares the save/weights
    directories, loads the dataset and registers callbacks.

    Args:
        cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
        overrides (dict, optional): Configuration overrides. Defaults to None.
        _callbacks (optional): Callbacks to use; when falsy, the result of
            `callbacks.get_default_callbacks()` is used instead. Defaults to None.
    """
    self.args = get_cfg(cfg, overrides)
    self.check_resume(overrides)
    self.device = select_device(self.args.device, self.args.batch)
    self.validator = None
    self.metrics = None
    self.plots = {}
    # Seed is offset by RANK, so every rank value yields a different seed
    init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)

    # Dirs
    self.save_dir = get_save_dir(self.args)
    self.args.name = self.save_dir.name  # update name for loggers
    self.wdir = self.save_dir / "weights"  # weights dir
    # Only ranks -1 and 0 create the directory and save args
    # (presumably -1 = non-distributed run, 0 = main process; TODO confirm)
    if RANK in {-1, 0}:
        self.wdir.mkdir(parents=True, exist_ok=True)  # make dir
        self.args.save_dir = str(self.save_dir)
        yaml_save(self.save_dir / "args.yaml", vars(self.args))  # save run args
    self.last, self.best = self.wdir / "last.pt", self.wdir / "best.pt"  # checkpoint paths
    self.save_period = self.args.save_period

    self.batch_size = self.args.batch
    self.epochs = self.args.epochs
    self.start_epoch = 0
    if RANK == -1:
        print_args(vars(self.args))

    # Device
    if self.device.type in {"cpu", "mps"}:
        self.args.workers = 0  # faster CPU training as time dominated by inference, not dataloading

    # Model and Dataset
    self.model = check_model_file_from_stem(self.args.model)  # add suffix, i.e. yolov8n -> yolov8n.pt
    with torch_distributed_zero_first(LOCAL_RANK):  # avoid auto-downloading dataset multiple times
        self.trainset, self.testset = self.get_dataset()
    self.ema = None

    # Optimization utils init
    self.lf = None
    self.scheduler = None

    # Epoch level metrics
    self.best_fitness = None
    self.fitness = None
    self.loss = None
    self.tloss = None
    self.loss_names = ["Loss"]  # generic default name
    self.csv = self.save_dir / "results.csv"
    self.plot_idx = [0, 1, 2]

    # HUB
    self.hub_session = None

    # Callbacks
    self.callbacks = _callbacks or callbacks.get_default_callbacks()
    # Integration callbacks are added on ranks -1 and 0 only
    if RANK in {-1, 0}:
        callbacks.add_integration_callbacks(self)

build_dataset

build_dataset(img_path, mode='train', batch=None)

Build YOLO Dataset.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| img_path | str | Path to the folder containing images. | required |
| mode | str | `train` or `val` mode; users can customize different augmentations for each mode. | 'train' |
| batch | int | Size of batches, this is for rect. Defaults to None. | None |
Source code in ultralytics/models/yolo/detect/train.py
def build_dataset(self, img_path, mode="train", batch=None):
    """
    Create the YOLO dataset for the given image folder.

    Args:
        img_path (str): Path to the folder containing images.
        mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
        batch (int, optional): Size of batches, this is for `rect`. Defaults to None.

    Returns:
        The dataset object produced by `build_yolo_dataset`.
    """
    # Largest model stride when a model is set, otherwise 0; never go below 32
    model_stride = int(de_parallel(self.model).stride.max() if self.model else 0)
    grid_size = max(model_stride, 32)

    # Rectangular batching is requested for validation only
    use_rect = mode == "val"
    return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=use_rect, stride=grid_size)

get_dataloader

get_dataloader(dataset_path, batch_size=16, rank=0, mode='train')

Construct and return dataloader.

Source code in ultralytics/models/yolo/detect/train.py
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
    """
    Build the dataset for `mode` and wrap it in a dataloader.

    Args:
        dataset_path: Path forwarded to `build_dataset`.
        batch_size (int): Batch size forwarded to `build_dataset` and `build_dataloader`.
        rank (int): Rank forwarded to `torch_distributed_zero_first` and `build_dataloader`.
        mode (str): Either 'train' or 'val'.

    Returns:
        The dataloader returned by `build_dataloader`.

    Raises:
        AssertionError: If `mode` is neither 'train' nor 'val'.
    """
    assert mode in {"train", "val"}, f"Mode must be 'train' or 'val', not {mode}."
    is_train = mode == "train"

    # init dataset *.cache only once if DDP
    with torch_distributed_zero_first(rank):
        dataset = self.build_dataset(dataset_path, mode, batch_size)

    # Shuffle for training only, and never together with rectangular batching
    shuffle = is_train
    rect = getattr(dataset, "rect", False)
    if rect and shuffle:
        LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False")
        shuffle = False

    # Validation uses twice the configured number of workers
    if is_train:
        workers = self.args.workers
    else:
        workers = self.args.workers * 2
    return build_dataloader(dataset, batch_size, workers, shuffle, rank)

get_model

get_model(cfg=None, weights=None, verbose=True)

Return a YOLO detection model.

Source code in ultralytics/models/yolo/detect/train.py
def get_model(self, cfg=None, weights=None, verbose=True):
    """
    Build a YOLO detection model and optionally load weights into it.

    Args:
        cfg: Model configuration passed to `DetectionModel`. Defaults to None.
        weights: If truthy, passed to `model.load`. Defaults to None.
        verbose (bool): Verbose construction; only honoured when RANK == -1.

    Returns:
        The constructed `DetectionModel`.
    """
    be_verbose = verbose and RANK == -1
    detector = DetectionModel(cfg, nc=self.data["nc"], verbose=be_verbose)
    if weights:
        detector.load(weights)
    return detector

get_validator

get_validator()

Returns a DetectionValidator for YOLO model validation.

Source code in ultralytics/models/yolo/detect/train.py
def get_validator(self):
    """
    Create a DetectionValidator for YOLO model validation.

    Also sets `self.loss_names` to the box, cls and dfl loss names.
    """
    self.loss_names = ("box_loss", "cls_loss", "dfl_loss")
    validator_cls = yolo.detect.DetectionValidator
    # The validator receives a copy of the args rather than the trainer's own object
    return validator_cls(self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks)

label_loss_items

label_loss_items(loss_items=None, prefix='train')

Returns a loss dict with labelled training loss items tensor.

Not needed for classification but necessary for segmentation & detection

Source code in ultralytics/models/yolo/detect/train.py
def label_loss_items(self, loss_items=None, prefix="train"):
    """
    Label loss values with their `<prefix>/<loss_name>` keys.

    Not needed for classification but necessary for segmentation & detection.

    Args:
        loss_items (iterable, optional): Loss values convertible with `float()`. Defaults to None.
        prefix (str): Prefix placed before each loss name. Defaults to 'train'.

    Returns:
        A dict mapping each key to its loss rounded to 5 decimals, or just the list of keys
        when `loss_items` is None.
    """
    labels = [f"{prefix}/{name}" for name in self.loss_names]
    if loss_items is None:
        return labels

    # convert tensors to 5 decimal place floats
    values = [round(float(item), 5) for item in loss_items]
    return {label: value for label, value in zip(labels, values)}

plot_metrics

plot_metrics()

Plots metrics from a CSV file.

Source code in ultralytics/models/yolo/detect/train.py
def plot_metrics(self):
    """Plot the metrics stored in the results CSV file (`self.csv`)."""
    results_file = self.csv
    plot_results(file=results_file, on_plot=self.on_plot)  # save results.png

plot_training_labels

plot_training_labels()

Create a labeled training plot of the YOLO model.

Source code in ultralytics/models/yolo/detect/train.py
def plot_training_labels(self):
    """Plot the training set labels: all boxes and classes of the training dataset, concatenated."""
    labels = self.train_loader.dataset.labels
    all_boxes = np.concatenate([entry["bboxes"] for entry in labels], axis=0)
    all_cls = np.concatenate([entry["cls"] for entry in labels], axis=0)
    plot_labels(
        all_boxes,
        all_cls.squeeze(),
        names=self.data["names"],
        save_dir=self.save_dir,
        on_plot=self.on_plot,
    )

plot_training_samples

plot_training_samples(batch, ni)

Plots training samples with their annotations.

Source code in ultralytics/models/yolo/detect/train.py
def plot_training_samples(self, batch, ni):
    """
    Plot one batch of training samples with their annotations.

    Args:
        batch (dict): Batch providing the "img", "batch_idx", "cls", "bboxes" and "im_file" entries.
        ni: Value used in the output file name `train_batch{ni}.jpg` under `self.save_dir`.
    """
    plot_kwargs = {
        "images": batch["img"],
        "batch_idx": batch["batch_idx"],
        "cls": batch["cls"].squeeze(-1),
        "bboxes": batch["bboxes"],
        "paths": batch["im_file"],
        "fname": self.save_dir / f"train_batch{ni}.jpg",
        "on_plot": self.on_plot,
    }
    plot_images(**plot_kwargs)

preprocess_batch

preprocess_batch(batch)

Preprocesses a batch of images by scaling and converting to float.

Source code in ultralytics/models/yolo/detect/train.py
def preprocess_batch(self, batch):
    """
    Move the batch images to the training device, convert to float and scale by 1/255.

    When `self.args.multi_scale` is set, the images are additionally resized to a randomly
    chosen size that is a multiple of `self.stride`.

    Args:
        batch (dict): Batch holding the images under the "img" key.

    Returns:
        The same `batch` dict with its "img" entry replaced.
    """
    images = batch["img"].to(self.device, non_blocking=True).float() / 255
    batch["img"] = images
    if not self.args.multi_scale:
        return batch

    stride = self.stride
    base = self.args.imgsz
    # Random target size in [0.5 * imgsz, 1.5 * imgsz + stride), floored to a stride multiple
    target = random.randrange(int(base * 0.5), int(base * 1.5 + stride)) // stride * stride
    spatial = images.shape[2:]
    scale = target / max(spatial)  # scale factor
    if scale != 1:
        # new shape (stretched to gs-multiple)
        new_shape = [math.ceil(dim * scale / stride) * stride for dim in spatial]
        images = nn.functional.interpolate(images, size=new_shape, mode="bilinear", align_corners=False)
    batch["img"] = images
    return batch

progress_string

progress_string()

Returns a formatted string of training progress with epoch, GPU memory, loss, instances and size.

Source code in ultralytics/models/yolo/detect/train.py
def progress_string(self):
    """
    Build the header line for training progress output.

    Returns:
        A string starting with a newline, followed by the columns Epoch, GPU_mem, each loss
        name, Instances and Size, every one right-aligned in an 11-character field.
    """
    columns = ("Epoch", "GPU_mem", *self.loss_names, "Instances", "Size")
    # Four fixed columns plus one per loss name
    template = "\n" + "%11s" * (4 + len(self.loss_names))
    return template % columns

set_model_attributes

set_model_attributes()

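Attaches the number of classes, the class names and the training hyperparameters to the model. (The rendered docstring, "Nl = de_parallel(self.model).model[-1].nl # number of detection layers (to scale hyps).", is a leftover code comment rather than a description.)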

Source code in ultralytics/models/yolo/detect/train.py
def set_model_attributes(self):
    """Attach the number of classes, class names and hyperparameters to the model."""
    # Hyperparameter scaling below is currently disabled (kept for reference):
    # nl = de_parallel(self.model).model[-1].nl  # number of detection layers (to scale hyps)
    # self.args.box *= 3 / nl  # scale to layers
    # self.args.cls *= self.data["nc"] / 80 * 3 / nl  # scale to classes and layers
    # self.args.cls *= (self.args.imgsz / 640) ** 2 * 3 / nl  # scale to image size and layers
    model, data = self.model, self.data
    model.nc = data["nc"]  # number of classes
    model.names = data["names"]  # class names
    model.args = self.args  # hyperparameters



📅 Created 11 months ago ✏️ Updated 1 month ago