Reference for ultralytics/engine/trainer.py

Note

This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/engine/trainer.py. If you spot a problem please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!


ultralytics.engine.trainer.BaseTrainer

BaseTrainer(cfg=DEFAULT_CFG, overrides=None, _callbacks=None)

A base class for creating trainers.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| args | SimpleNamespace | Configuration for the trainer. |
| validator | BaseValidator | Validator instance. |
| model | Module | Model instance. |
| callbacks | defaultdict | Dictionary of callbacks. |
| save_dir | Path | Directory to save results. |
| wdir | Path | Directory to save weights. |
| last | Path | Path to the last checkpoint. |
| best | Path | Path to the best checkpoint. |
| save_period | int | Save checkpoint every x epochs (disabled if < 1). |
| batch_size | int | Batch size for training. |
| epochs | int | Number of epochs to train for. |
| start_epoch | int | Starting epoch for training. |
| device | device | Device to use for training. |
| amp | bool | Flag to enable AMP (Automatic Mixed Precision). |
| scaler | GradScaler | Gradient scaler for AMP. |
| data | str | Path to data. |
| ema | Module | EMA (Exponential Moving Average) of the model. |
| resume | bool | Resume training from a checkpoint. |
| lf | Module | Loss function. |
| scheduler | _LRScheduler | Learning rate scheduler. |
| best_fitness | float | The best fitness value achieved. |
| fitness | float | Current fitness value. |
| loss | float | Current loss value. |
| tloss | float | Total loss value. |
| loss_names | list | List of loss names. |
| csv | Path | Path to results CSV file. |
| metrics | dict | Dictionary of metrics. |
| plots | dict | Dictionary of plots. |

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| cfg | str | Path to a configuration file. | DEFAULT_CFG |
| overrides | dict | Configuration overrides. | None |
| _callbacks | list | List of callback functions. | None |
Source code in ultralytics/engine/trainer.py
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
    """
    Initialize the BaseTrainer class.

    Args:
        cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
        overrides (dict, optional): Configuration overrides. Defaults to None.
        _callbacks (list, optional): List of callback functions. Defaults to None.
    """
    self.args = get_cfg(cfg, overrides)
    self.check_resume(overrides)
    self.device = select_device(self.args.device, self.args.batch)
    # update "-1" devices so post-training val does not repeat search
    self.args.device = os.getenv("CUDA_VISIBLE_DEVICES") if "cuda" in str(self.device) else str(self.device)
    self.validator = None
    self.metrics = None
    self.plots = {}
    init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)

    # Dirs
    self.save_dir = get_save_dir(self.args)
    self.args.name = self.save_dir.name  # update name for loggers
    self.wdir = self.save_dir / "weights"  # weights dir
    if RANK in {-1, 0}:
        self.wdir.mkdir(parents=True, exist_ok=True)  # make dir
        self.args.save_dir = str(self.save_dir)
        YAML.save(self.save_dir / "args.yaml", vars(self.args))  # save run args
    self.last, self.best = self.wdir / "last.pt", self.wdir / "best.pt"  # checkpoint paths
    self.save_period = self.args.save_period

    self.batch_size = self.args.batch
    self.epochs = self.args.epochs or 100  # in case users accidentally pass epochs=None with timed training
    self.start_epoch = 0
    if RANK == -1:
        print_args(vars(self.args))

    # Device
    if self.device.type in {"cpu", "mps"}:
        self.args.workers = 0  # faster CPU training as time dominated by inference, not dataloading

    # Model and Dataset
    self.model = check_model_file_from_stem(self.args.model)  # add suffix, i.e. yolo11n -> yolo11n.pt
    with torch_distributed_zero_first(LOCAL_RANK):  # avoid auto-downloading dataset multiple times
        self.data = self.get_dataset()

    self.ema = None

    # Optimization utils init
    self.lf = None
    self.scheduler = None

    # Epoch level metrics
    self.best_fitness = None
    self.fitness = None
    self.loss = None
    self.tloss = None
    self.loss_names = ["Loss"]
    self.csv = self.save_dir / "results.csv"
    self.plot_idx = [0, 1, 2]

    # HUB
    self.hub_session = None

    # Callbacks
    self.callbacks = _callbacks or callbacks.get_default_callbacks()
    if RANK in {-1, 0}:
        callbacks.add_integration_callbacks(self)
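
BaseTrainer is not normally instantiated directly; task-specific trainers subclass it and supply the abstract hooks. As a minimal sketch of typical usage (assuming the DetectionTrainer subclass and the bundled coco8.yaml dataset):

```python
from ultralytics.models.yolo.detect import DetectionTrainer

# Overrides are merged into DEFAULT_CFG by get_cfg() in __init__ above
trainer = DetectionTrainer(overrides={"model": "yolo11n.pt", "data": "coco8.yaml", "epochs": 3})
trainer.train()
```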

add_callback

add_callback(event: str, callback)

Append the given callback to the event's callback list.

Source code in ultralytics/engine/trainer.py
def add_callback(self, event: str, callback):
    """Append the given callback to the event's callback list."""
    self.callbacks[event].append(callback)
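
For example, a custom function can be attached to one of the standard trainer events; the event name below assumes the default callback set, and every callback receives the trainer instance as its only argument:

```python
def log_epoch(trainer):
    """Illustrative callback: print the epoch that just finished."""
    print(f"Epoch {trainer.epoch + 1}/{trainer.epochs} done")

trainer.add_callback("on_train_epoch_end", log_epoch)  # appended to that event's callback list
```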

auto_batch

auto_batch(max_num_obj=0)

Calculate optimal batch size based on model and device memory constraints.

Source code in ultralytics/engine/trainer.py
def auto_batch(self, max_num_obj=0):
    """Calculate optimal batch size based on model and device memory constraints."""
    return check_train_batch_size(
        model=self.model,
        imgsz=self.args.imgsz,
        amp=self.amp,
        batch=self.batch_size,
        max_num_obj=max_num_obj,
    )  # returns batch size

build_dataset

build_dataset(img_path, mode='train', batch=None)

Build dataset.

Source code in ultralytics/engine/trainer.py
def build_dataset(self, img_path, mode="train", batch=None):
    """Build dataset."""
    raise NotImplementedError("build_dataset function not implemented in trainer")
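
build_dataset (like get_dataloader and get_validator below) is an abstract hook that task trainers override. A hedged sketch of an override, where MyDataset is a placeholder class and not an Ultralytics API:

```python
from torch.utils.data import Dataset

from ultralytics.engine.trainer import BaseTrainer


class MyDataset(Dataset):  # placeholder dataset, not an Ultralytics class
    def __init__(self, img_path, imgsz=640, augment=False):
        self.img_path, self.imgsz, self.augment = img_path, imgsz, augment

    def __len__(self):
        return 0  # a real dataset would also implement __getitem__


class MyTrainer(BaseTrainer):
    def build_dataset(self, img_path, mode="train", batch=None):
        """Override the abstract hook: return a dataset for 'train' or 'val' mode."""
        return MyDataset(img_path, imgsz=self.args.imgsz, augment=(mode == "train"))
```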

build_optimizer

build_optimizer(
    model, name="auto", lr=0.001, momentum=0.9, decay=1e-05, iterations=100000.0
)

Construct an optimizer for the given model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | Module | The model for which to build an optimizer. | required |
| name | str | The name of the optimizer to use. If 'auto', the optimizer is selected based on the number of iterations. | 'auto' |
| lr | float | The learning rate for the optimizer. | 0.001 |
| momentum | float | The momentum factor for the optimizer. | 0.9 |
| decay | float | The weight decay for the optimizer. | 1e-05 |
| iterations | float | The number of iterations, which determines the optimizer if name is 'auto'. | 100000.0 |

Returns:

| Type | Description |
| --- | --- |
| Optimizer | The constructed optimizer. |

Source code in ultralytics/engine/trainer.py
def build_optimizer(self, model, name="auto", lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
    """
    Construct an optimizer for the given model.

    Args:
        model (torch.nn.Module): The model for which to build an optimizer.
        name (str, optional): The name of the optimizer to use. If 'auto', the optimizer is selected
            based on the number of iterations. Default: 'auto'.
        lr (float, optional): The learning rate for the optimizer. Default: 0.001.
        momentum (float, optional): The momentum factor for the optimizer. Default: 0.9.
        decay (float, optional): The weight decay for the optimizer. Default: 1e-5.
        iterations (float, optional): The number of iterations, which determines the optimizer if
            name is 'auto'. Default: 1e5.

    Returns:
        (torch.optim.Optimizer): The constructed optimizer.
    """
    g = [], [], []  # optimizer parameter groups
    bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k)  # normalization layers, i.e. BatchNorm2d()
    if name == "auto":
        LOGGER.info(
            f"{colorstr('optimizer:')} 'optimizer=auto' found, "
            f"ignoring 'lr0={self.args.lr0}' and 'momentum={self.args.momentum}' and "
            f"determining best 'optimizer', 'lr0' and 'momentum' automatically... "
        )
        nc = self.data.get("nc", 10)  # number of classes
        lr_fit = round(0.002 * 5 / (4 + nc), 6)  # lr0 fit equation to 6 decimal places
        name, lr, momentum = ("SGD", 0.01, 0.9) if iterations > 10000 else ("AdamW", lr_fit, 0.9)
        self.args.warmup_bias_lr = 0.0  # no higher than 0.01 for Adam

    for module_name, module in model.named_modules():
        for param_name, param in module.named_parameters(recurse=False):
            fullname = f"{module_name}.{param_name}" if module_name else param_name
            if "bias" in fullname:  # bias (no decay)
                g[2].append(param)
            elif isinstance(module, bn) or "logit_scale" in fullname:  # weight (no decay)
                # ContrastiveHead and BNContrastiveHead included here with 'logit_scale'
                g[1].append(param)
            else:  # weight (with decay)
                g[0].append(param)

    optimizers = {"Adam", "Adamax", "AdamW", "NAdam", "RAdam", "RMSProp", "SGD", "auto"}
    name = {x.lower(): x for x in optimizers}.get(name.lower())
    if name in {"Adam", "Adamax", "AdamW", "NAdam", "RAdam"}:
        optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
    elif name == "RMSProp":
        optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum)
    elif name == "SGD":
        optimizer = optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
    else:
        raise NotImplementedError(
            f"Optimizer '{name}' not found in list of available optimizers {optimizers}. "
            "Request support for addition optimizers at https://github.com/ultralytics/ultralytics."
        )

    optimizer.add_param_group({"params": g[0], "weight_decay": decay})  # add g0 with weight_decay
    optimizer.add_param_group({"params": g[1], "weight_decay": 0.0})  # add g1 (BatchNorm2d weights)
    LOGGER.info(
        f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}, momentum={momentum}) with parameter groups "
        f"{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)"
    )
    return optimizer
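
For reference, the 'optimizer=auto' branch above reduces to a simple rule: long runs use SGD with lr0=0.01, while short runs fall back to AdamW with a class-count-dependent learning rate. With illustrative values:

```python
nc, iterations = 80, 1e5  # illustrative: COCO-sized class count, default iteration budget
lr_fit = round(0.002 * 5 / (4 + nc), 6)  # -> 0.000119
name, lr, momentum = ("SGD", 0.01, 0.9) if iterations > 10000 else ("AdamW", lr_fit, 0.9)
print(name, lr, momentum)  # SGD 0.01 0.9
```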

build_targets

build_targets(preds, targets)

Builds target tensors for training YOLO model.

Source code in ultralytics/engine/trainer.py
def build_targets(self, preds, targets):
    """Builds target tensors for training YOLO model."""
    pass

check_resume

check_resume(overrides)

Check if resume checkpoint exists and update arguments accordingly.

Source code in ultralytics/engine/trainer.py
def check_resume(self, overrides):
    """Check if resume checkpoint exists and update arguments accordingly."""
    resume = self.args.resume
    if resume:
        try:
            exists = isinstance(resume, (str, Path)) and Path(resume).exists()
            last = Path(check_file(resume) if exists else get_latest_run())

            # Check that resume data YAML exists, otherwise strip to force re-download of dataset
            ckpt_args = attempt_load_weights(last).args
            if not isinstance(ckpt_args["data"], dict) and not Path(ckpt_args["data"]).exists():
                ckpt_args["data"] = self.args.data

            resume = True
            self.args = get_cfg(ckpt_args)
            self.args.model = self.args.resume = str(last)  # reinstate model
            for k in (
                "imgsz",
                "batch",
                "device",
                "close_mosaic",
            ):  # allow arg updates to reduce memory or update device on resume
                if k in overrides:
                    setattr(self.args, k, overrides[k])

        except Exception as e:
            raise FileNotFoundError(
                "Resume checkpoint not found. Please pass a valid checkpoint to resume from, "
                "i.e. 'yolo train resume model=path/to/last.pt'"
            ) from e
    self.resume = resume
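
In practice resuming is usually triggered through the model API or CLI rather than by calling check_resume directly. A hedged sketch (the checkpoint path is illustrative):

```python
from ultralytics import YOLO

model = YOLO("runs/detect/train/weights/last.pt")  # illustrative path to a previous run's last.pt
model.train(resume=True)  # check_resume() above restores the saved train args from the checkpoint
```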

final_eval

final_eval()

Perform final evaluation and validation for object detection YOLO model.

Source code in ultralytics/engine/trainer.py
def final_eval(self):
    """Perform final evaluation and validation for object detection YOLO model."""
    ckpt = {}
    for f in self.last, self.best:
        if f.exists():
            if f is self.last:
                ckpt = strip_optimizer(f)
            elif f is self.best:
                k = "train_results"  # update best.pt train_metrics from last.pt
                strip_optimizer(f, updates={k: ckpt[k]} if k in ckpt else None)
                LOGGER.info(f"\nValidating {f}...")
                self.validator.args.plots = self.args.plots
                self.metrics = self.validator(model=f)
                self.metrics.pop("fitness", None)
                self.run_callbacks("on_fit_epoch_end")

get_dataloader

get_dataloader(dataset_path, batch_size=16, rank=0, mode='train')

Returns a dataloader derived from torch.utils.data.DataLoader.

Source code in ultralytics/engine/trainer.py
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
    """Returns dataloader derived from torch.data.Dataloader."""
    raise NotImplementedError("get_dataloader function not implemented in trainer")

get_dataset

get_dataset()

Get train and validation datasets from data dictionary.

Returns:

| Type | Description |
| --- | --- |
| dict | A dictionary containing the training/validation/test dataset and category names. |

Source code in ultralytics/engine/trainer.py
def get_dataset(self):
    """
    Get train and validation datasets from data dictionary.

    Returns:
        (dict): A dictionary containing the training/validation/test dataset and category names.
    """
    try:
        if self.args.task == "classify":
            data = check_cls_dataset(self.args.data)
        elif self.args.data.rsplit(".", 1)[-1] in {"yaml", "yml"} or self.args.task in {
            "detect",
            "segment",
            "pose",
            "obb",
        }:
            data = check_det_dataset(self.args.data)
            if "yaml_file" in data:
                self.args.data = data["yaml_file"]  # for validating 'yolo train data=url.zip' usage
    except Exception as e:
        raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e
    if self.args.single_cls:
        LOGGER.info("Overriding class names with single class.")
        data["names"] = {0: "item"}
        data["nc"] = 1
    return data
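
The returned dictionary mirrors the dataset YAML. An illustrative (not exhaustive) shape for a detection task, as produced by check_det_dataset:

```python
data = {
    "train": "path/to/train/images",  # training images
    "val": "path/to/val/images",      # validation images
    "nc": 2,                          # number of classes
    "names": {0: "cat", 1: "dog"},    # class index -> name
}
# With single_cls=True, get_dataset() overrides names to {0: "item"} and nc to 1
```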

get_model

get_model(cfg=None, weights=None, verbose=True)

Get model and raise NotImplementedError for loading cfg files.

Source code in ultralytics/engine/trainer.py
def get_model(self, cfg=None, weights=None, verbose=True):
    """Get model and raise NotImplementedError for loading cfg files."""
    raise NotImplementedError("This task trainer doesn't support loading cfg files")

get_validator

get_validator()

Returns a NotImplementedError when the get_validator function is called.

Source code in ultralytics/engine/trainer.py
def get_validator(self):
    """Returns a NotImplementedError when the get_validator function is called."""
    raise NotImplementedError("get_validator function not implemented in trainer")

label_loss_items

label_loss_items(loss_items=None, prefix='train')

Returns a loss dict with labelled training loss items tensor.

Note

This is not needed for classification but necessary for segmentation & detection

Source code in ultralytics/engine/trainer.py
def label_loss_items(self, loss_items=None, prefix="train"):
    """
    Returns a loss dict with labelled training loss items tensor.

    Note:
        This is not needed for classification but necessary for segmentation & detection
    """
    return {"loss": loss_items} if loss_items is not None else ["loss"]

on_plot

on_plot(name, data=None)

Registers plots (e.g. to be consumed in callbacks).

Source code in ultralytics/engine/trainer.py
def on_plot(self, name, data=None):
    """Registers plots (e.g. to be consumed in callbacks)."""
    path = Path(name)
    self.plots[path] = {"data": data, "timestamp": time.time()}

optimizer_step

optimizer_step()

Perform a single step of the training optimizer with gradient clipping and EMA update.

Source code in ultralytics/engine/trainer.py
def optimizer_step(self):
    """Perform a single step of the training optimizer with gradient clipping and EMA update."""
    self.scaler.unscale_(self.optimizer)  # unscale gradients
    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0)  # clip gradients
    self.scaler.step(self.optimizer)
    self.scaler.update()
    self.optimizer.zero_grad()
    if self.ema:
        self.ema.update(self.model)

plot_metrics

plot_metrics()

Plot and display metrics visually.

Source code in ultralytics/engine/trainer.py
def plot_metrics(self):
    """Plot and display metrics visually."""
    pass

plot_training_labels

plot_training_labels()

Plots training labels for YOLO model.

Source code in ultralytics/engine/trainer.py
def plot_training_labels(self):
    """Plots training labels for YOLO model."""
    pass

plot_training_samples

plot_training_samples(batch, ni)

Plots training samples during YOLO training.

Source code in ultralytics/engine/trainer.py
def plot_training_samples(self, batch, ni):
    """Plots training samples during YOLO training."""
    pass

preprocess_batch

preprocess_batch(batch)

Allows custom preprocessing of model inputs and ground truths depending on the task type.

Source code in ultralytics/engine/trainer.py
def preprocess_batch(self, batch):
    """Allows custom preprocessing model inputs and ground truths depending on task type."""
    return batch

progress_string

progress_string()

Returns a string describing training progress.

Source code in ultralytics/engine/trainer.py
def progress_string(self):
    """Returns a string describing training progress."""
    return ""

read_results_csv

read_results_csv()

Read results.csv into a dictionary using pandas.

Source code in ultralytics/engine/trainer.py
def read_results_csv(self):
    """Read results.csv into a dictionary using pandas."""
    import pandas as pd  # scope for faster 'import ultralytics'

    return pd.read_csv(self.csv).to_dict(orient="list")

resume_training

resume_training(ckpt)

Resume YOLO training from given epoch and best fitness.

Source code in ultralytics/engine/trainer.py
def resume_training(self, ckpt):
    """Resume YOLO training from given epoch and best fitness."""
    if ckpt is None or not self.resume:
        return
    best_fitness = 0.0
    start_epoch = ckpt.get("epoch", -1) + 1
    if ckpt.get("optimizer", None) is not None:
        self.optimizer.load_state_dict(ckpt["optimizer"])  # optimizer
        best_fitness = ckpt["best_fitness"]
    if self.ema and ckpt.get("ema"):
        self.ema.ema.load_state_dict(ckpt["ema"].float().state_dict())  # EMA
        self.ema.updates = ckpt["updates"]
    assert start_epoch > 0, (
        f"{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n"
        f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'"
    )
    LOGGER.info(f"Resuming training {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs")
    if self.epochs < start_epoch:
        LOGGER.info(
            f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs."
        )
        self.epochs += ckpt["epoch"]  # finetune additional epochs
    self.best_fitness = best_fitness
    self.start_epoch = start_epoch
    if start_epoch > (self.epochs - self.args.close_mosaic):
        self._close_dataloader_mosaic()

run_callbacks

run_callbacks(event: str)

Run all existing callbacks associated with a particular event.

Source code in ultralytics/engine/trainer.py
def run_callbacks(self, event: str):
    """Run all existing callbacks associated with a particular event."""
    for callback in self.callbacks.get(event, []):
        callback(self)

save_metrics

save_metrics(metrics)

Save training metrics to a CSV file.

Source code in ultralytics/engine/trainer.py
def save_metrics(self, metrics):
    """Save training metrics to a CSV file."""
    keys, vals = list(metrics.keys()), list(metrics.values())
    n = len(metrics) + 2  # number of cols
    s = "" if self.csv.exists() else (("%s," * n % tuple(["epoch", "time"] + keys)).rstrip(",") + "\n")  # header
    t = time.time() - self.train_time_start
    with open(self.csv, "a", encoding="utf-8") as f:
        f.write(s + ("%.6g," * n % tuple([self.epoch + 1, t] + vals)).rstrip(",") + "\n")

save_model

save_model()

Save model training checkpoints with additional metadata.

Source code in ultralytics/engine/trainer.py
def save_model(self):
    """Save model training checkpoints with additional metadata."""
    import io

    # Serialize ckpt to a byte buffer once (faster than repeated torch.save() calls)
    buffer = io.BytesIO()
    torch.save(
        {
            "epoch": self.epoch,
            "best_fitness": self.best_fitness,
            "model": None,  # resume and final checkpoints derive from EMA
            "ema": deepcopy(self.ema.ema).half(),
            "updates": self.ema.updates,
            "optimizer": convert_optimizer_state_dict_to_fp16(deepcopy(self.optimizer.state_dict())),
            "train_args": vars(self.args),  # save as dict
            "train_metrics": {**self.metrics, **{"fitness": self.fitness}},
            "train_results": self.read_results_csv(),
            "date": datetime.now().isoformat(),
            "version": __version__,
            "license": "AGPL-3.0 (https://ultralytics.com/license)",
            "docs": "https://docs.ultralytics.com",
        },
        buffer,
    )
    serialized_ckpt = buffer.getvalue()  # get the serialized content to save

    # Save checkpoints
    self.last.write_bytes(serialized_ckpt)  # save last.pt
    if self.best_fitness == self.fitness:
        self.best.write_bytes(serialized_ckpt)  # save best.pt
    if (self.save_period > 0) and (self.epoch % self.save_period == 0):
        (self.wdir / f"epoch{self.epoch}.pt").write_bytes(serialized_ckpt)  # save epoch, i.e. 'epoch3.pt'

set_callback

set_callback(event: str, callback)

Override the existing callbacks with the given callback for the specified event.

Source code in ultralytics/engine/trainer.py
def set_callback(self, event: str, callback):
    """Override the existing callbacks with the given callback for the specified event."""
    self.callbacks[event] = [callback]

set_model_attributes

set_model_attributes()

Set or update model parameters before training.

Source code in ultralytics/engine/trainer.py
def set_model_attributes(self):
    """Set or update model parameters before training."""
    self.model.names = self.data["names"]

setup_model

setup_model()

Load, create, or download model for any task.

Returns:

| Type | Description |
| --- | --- |
| dict | Optional checkpoint to resume training from. |

Source code in ultralytics/engine/trainer.py
def setup_model(self):
    """
    Load, create, or download model for any task.

    Returns:
        (dict): Optional checkpoint to resume training from.
    """
    if isinstance(self.model, torch.nn.Module):  # if model is loaded beforehand. No setup needed
        return

    cfg, weights = self.model, None
    ckpt = None
    if str(self.model).endswith(".pt"):
        weights, ckpt = attempt_load_one_weight(self.model)
        cfg = weights.yaml
    elif isinstance(self.args.pretrained, (str, Path)):
        weights, _ = attempt_load_one_weight(self.args.pretrained)
    self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1)  # calls Model(cfg, weights)
    return ckpt

train

train()

Allow device='', device=None on Multi-GPU systems to default to device=0.

Source code in ultralytics/engine/trainer.py
def train(self):
    """Allow device='', device=None on Multi-GPU systems to default to device=0."""
    if isinstance(self.args.device, str) and len(self.args.device):  # i.e. device='0' or device='0,1,2,3'
        world_size = len(self.args.device.split(","))
    elif isinstance(self.args.device, (tuple, list)):  # i.e. device=[0, 1, 2, 3] (multi-GPU from CLI is list)
        world_size = len(self.args.device)
    elif self.args.device in {"cpu", "mps"}:  # i.e. device='cpu' or 'mps'
        world_size = 0
    elif torch.cuda.is_available():  # i.e. device=None or device='' or device=number
        world_size = 1  # default to device 0
    else:  # i.e. device=None or device=''
        world_size = 0

    # Run subprocess if DDP training, else train normally
    if world_size > 1 and "LOCAL_RANK" not in os.environ:
        # Argument checks
        if self.args.rect:
            LOGGER.warning("'rect=True' is incompatible with Multi-GPU training, setting 'rect=False'")
            self.args.rect = False
        if self.args.batch < 1.0:
            LOGGER.warning(
                "'batch<1' for AutoBatch is incompatible with Multi-GPU training, setting default 'batch=16'"
            )
            self.args.batch = 16

        # Command
        cmd, file = generate_ddp_command(world_size, self)
        try:
            LOGGER.info(f"{colorstr('DDP:')} debug command {' '.join(cmd)}")
            subprocess.run(cmd, check=True)
        except Exception as e:
            raise e
        finally:
            ddp_cleanup(self, str(file))

    else:
        self._do_train(world_size)
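
The world_size logic above is usually exercised through the model API. A hedged sketch of how different device arguments map onto it:

```python
from ultralytics import YOLO

model = YOLO("yolo11n.pt")
model.train(data="coco8.yaml", epochs=3, device="0,1")  # two GPUs -> world_size=2, DDP subprocess path
model.train(data="coco8.yaml", epochs=3, device="cpu")  # world_size=0 -> single-process _do_train()
```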

validate

validate()

Run validation on test set using self.validator.

Returns:

| Type | Description |
| --- | --- |
| tuple | A tuple containing metrics dictionary and fitness score. |

Source code in ultralytics/engine/trainer.py
def validate(self):
    """
    Run validation on test set using self.validator.

    Returns:
        (tuple): A tuple containing metrics dictionary and fitness score.
    """
    metrics = self.validator(self)
    fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy())  # use loss as fitness measure if not found
    if not self.best_fitness or self.best_fitness < fitness:
        self.best_fitness = fitness
    return metrics, fitness




