class BaseTrainer:
    """
    BaseTrainer

    A base class for creating trainers. Task-specific subclasses override the hook methods
    (get_model, get_validator, get_dataloader, preprocess_batch, ...) that raise NotImplementedError here.

    Attributes:
        args (SimpleNamespace): Configuration for the trainer.
        check_resume (method): Method to check if training should be resumed from a saved checkpoint.
        resume (bool): Whether training is being resumed (set by check_resume()).
        validator (BaseValidator): Validator instance.
        model (nn.Module): Model instance (a model path/name until setup_model() builds it).
        callbacks (defaultdict): Dictionary of callbacks.
        save_dir (Path): Directory to save results.
        wdir (Path): Directory to save weights.
        last (Path): Path to the last checkpoint.
        best (Path): Path to the best checkpoint.
        save_period (int): Save checkpoint every x epochs (disabled if < 1).
        batch_size (int): Batch size for training.
        epochs (int): Number of epochs to train for.
        start_epoch (int): Starting epoch for training.
        device (torch.device): Device to use for training.
        amp (bool): Flag to enable AMP (Automatic Mixed Precision).
        scaler (amp.GradScaler): Gradient scaler for AMP.
        data (dict): Dataset information dict returned by the dataset checker.
        trainset: Training data source, as returned by get_dataset().
        testset: Validation/test data source, as returned by get_dataset().
        ema (ModelEMA): EMA (Exponential Moving Average) of the model.
        lf (callable): Learning-rate lambda (epoch -> LR multiplier) used by the LambdaLR scheduler.
        scheduler (torch.optim.lr_scheduler.LambdaLR): Learning rate scheduler.
        best_fitness (float): The best fitness value achieved.
        fitness (float): Current fitness value.
        loss: Loss of the most recent batch (as returned by the model).
        tloss: Running mean of the per-batch loss items within the current epoch.
        loss_names (list): List of loss names.
        csv (Path): Path to results CSV file.
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """
        Initializes the BaseTrainer class.

        Args:
            cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
            overrides (dict, optional): Configuration overrides. Defaults to None.
            _callbacks (dict, optional): Callback mapping to use instead of the default callbacks.
                Defaults to None.
        """
        self.args = get_cfg(cfg, overrides)
        self.check_resume(overrides)
        self.device = select_device(self.args.device, self.args.batch)
        self.validator = None
        self.model = None
        self.metrics = None
        self.plots = {}
        init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)

        # Dirs
        self.save_dir = get_save_dir(self.args)
        self.wdir = self.save_dir / 'weights'  # weights dir
        if RANK in (-1, 0):  # only the main process touches the filesystem
            self.wdir.mkdir(parents=True, exist_ok=True)  # make dir
            self.args.save_dir = str(self.save_dir)
            yaml_save(self.save_dir / 'args.yaml', vars(self.args))  # save run args
        self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt'  # checkpoint paths
        self.save_period = self.args.save_period

        self.batch_size = self.args.batch
        self.epochs = self.args.epochs
        self.start_epoch = 0
        if RANK == -1:
            print_args(vars(self.args))

        # Device
        if self.device.type in ('cpu', 'mps'):
            self.args.workers = 0  # faster CPU training as time dominated by inference, not dataloading

        # Model and Dataset
        self.model = self.args.model
        try:
            if self.args.task == 'classify':
                self.data = check_cls_dataset(self.args.data)
            elif self.args.data.split('.')[-1] in ('yaml', 'yml') or self.args.task in ('detect', 'segment', 'pose'):
                self.data = check_det_dataset(self.args.data)
                if 'yaml_file' in self.data:
                    self.args.data = self.data['yaml_file']  # for validating 'yolo train data=url.zip' usage
        except Exception as e:
            raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e

        self.trainset, self.testset = self.get_dataset(self.data)
        self.ema = None

        # Optimization utils init
        self.lf = None
        self.scheduler = None

        # Epoch level metrics
        self.best_fitness = None
        self.fitness = None
        self.loss = None
        self.tloss = None
        self.loss_names = ['Loss']
        self.csv = self.save_dir / 'results.csv'
        self.plot_idx = [0, 1, 2]

        # Callbacks
        self.callbacks = _callbacks or callbacks.get_default_callbacks()
        if RANK in (-1, 0):
            callbacks.add_integration_callbacks(self)

    def add_callback(self, event: str, callback):
        """
        Appends the given callback.
        """
        self.callbacks[event].append(callback)

    def set_callback(self, event: str, callback):
        """
        Overrides the existing callbacks with the given callback.
        """
        self.callbacks[event] = [callback]

    def run_callbacks(self, event: str):
        """Run all existing callbacks associated with a particular event."""
        for callback in self.callbacks.get(event, []):
            callback(self)

    def train(self):
        """
        Start training.

        The world size is derived from ``args.device``:
        - a non-empty device string counts its comma-separated entries;
        - a device list/tuple counts its items;
        - otherwise the world size is 1 if CUDA is available, else 0.

        With more than one device and no ``LOCAL_RANK`` in the environment, training is relaunched as a
        DDP subprocess. Otherwise it runs in this process.
        """
        if isinstance(self.args.device, str) and len(self.args.device):  # i.e. device='0' or device='0,1,2,3'
            world_size = len(self.args.device.split(','))
        elif isinstance(self.args.device, (tuple, list)):  # i.e. device=[0, 1, 2, 3] (multi-GPU from CLI is list)
            world_size = len(self.args.device)
        elif torch.cuda.is_available():  # i.e. device=None or device='' or device=number
            world_size = 1  # default to device 0
        else:  # i.e. device='cpu' or 'mps'
            world_size = 0

        # Run subprocess if DDP training, else train normally
        if world_size > 1 and 'LOCAL_RANK' not in os.environ:
            # Argument checks
            if self.args.rect:
                LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with Multi-GPU training, setting 'rect=False'")
                self.args.rect = False
            if self.args.batch == -1:
                LOGGER.warning("WARNING ⚠️ 'batch=-1' for AutoBatch is incompatible with Multi-GPU training, setting "
                               "default 'batch=16'")
                self.args.batch = 16

            # Command
            cmd, file = generate_ddp_command(world_size, self)
            try:
                LOGGER.info(f'{colorstr("DDP:")} debug command {" ".join(cmd)}')
                subprocess.run(cmd, check=True)
            finally:  # always remove the temporary DDP launch file, even if the subprocess failed
                ddp_cleanup(self, str(file))
        else:
            self._do_train(world_size)

    def _setup_ddp(self, world_size):
        """Initializes and sets the DistributedDataParallel parameters for training."""
        torch.cuda.set_device(RANK)
        self.device = torch.device('cuda', RANK)
        os.environ['NCCL_BLOCKING_WAIT'] = '1'  # set to enforce timeout
        dist.init_process_group(
            'nccl' if dist.is_nccl_available() else 'gloo',
            timeout=timedelta(seconds=10800),  # 3 hours
            rank=RANK,
            world_size=world_size)

    def _setup_train(self, world_size):
        """
        Builds dataloaders and optimizer on correct rank process.

        Also freezes layers, resolves AMP, wraps the model in DDP when needed, builds the scheduler and
        early stopper, and restores state from the checkpoint returned by setup_model() (if any).
        """

        # Model
        self.run_callbacks('on_pretrain_routine_start')
        ckpt = self.setup_model()
        self.model = self.model.to(self.device)
        self.set_model_attributes()

        # Freeze layers: args.freeze may be a list of layer indices or an int meaning "first N layers"
        if isinstance(self.args.freeze, list):
            freeze_list = self.args.freeze
        elif isinstance(self.args.freeze, int):
            freeze_list = range(self.args.freeze)
        else:
            freeze_list = []
        always_freeze_names = ['.dfl']  # always freeze these layers
        freeze_layer_names = [f'model.{x}.' for x in freeze_list] + always_freeze_names
        for k, v in self.model.named_parameters():
            # v.register_hook(lambda x: torch.nan_to_num(x))  # NaN to 0 (commented for erratic training results)
            if any(x in k for x in freeze_layer_names):
                LOGGER.info(f"Freezing layer '{k}'")
                v.requires_grad = False
            elif not v.requires_grad:
                LOGGER.warning(f"WARNING ⚠️ setting 'requires_grad=True' for frozen layer '{k}'. "
                               'See ultralytics.engine.trainer for customization of frozen layers.')
                v.requires_grad = True

        # Check AMP
        self.amp = torch.tensor(self.args.amp).to(self.device)  # True or False
        if self.amp and RANK in (-1, 0):  # Single-GPU and DDP
            callbacks_backup = callbacks.default_callbacks.copy()  # backup callbacks as check_amp() resets them
            self.amp = torch.tensor(check_amp(self.model), device=self.device)
            callbacks.default_callbacks = callbacks_backup  # restore callbacks
        if RANK > -1 and world_size > 1:  # DDP
            dist.broadcast(self.amp, src=0)  # broadcast the tensor from rank 0 to all other ranks (returns None)
        self.amp = bool(self.amp)  # as boolean
        self.scaler = amp.GradScaler(enabled=self.amp)
        if world_size > 1:
            self.model = DDP(self.model, device_ids=[RANK])

        # Check imgsz
        gs = max(int(self.model.stride.max() if hasattr(self.model, 'stride') else 32), 32)  # grid size (max stride)
        self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs, max_dim=1)

        # Batch size
        if self.batch_size == -1 and RANK == -1:  # single-GPU only, estimate best batch size
            self.args.batch = self.batch_size = check_train_batch_size(self.model, self.args.imgsz, self.amp)

        # Dataloaders: the global batch is split evenly across DDP ranks
        batch_size = self.batch_size // max(world_size, 1)
        self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode='train')
        if RANK in (-1, 0):
            self.test_loader = self.get_dataloader(self.testset, batch_size=batch_size * 2, rank=-1, mode='val')
            self.validator = self.get_validator()
            metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix='val')
            self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))
            self.ema = ModelEMA(self.model)
            if self.args.plots:
                self.plot_training_labels()

        # Optimizer
        self.accumulate = max(round(self.args.nbs / self.batch_size), 1)  # accumulate loss before optimizing
        weight_decay = self.args.weight_decay * self.batch_size * self.accumulate / self.args.nbs  # scale weight_decay
        iterations = math.ceil(len(self.train_loader.dataset) / max(self.batch_size, self.args.nbs)) * self.epochs
        self.optimizer = self.build_optimizer(model=self.model,
                                              name=self.args.optimizer,
                                              lr=self.args.lr0,
                                              momentum=self.args.momentum,
                                              decay=weight_decay,
                                              iterations=iterations)
        # Scheduler
        if self.args.cos_lr:
            self.lf = one_cycle(1, self.args.lrf, self.epochs)  # cosine 1->hyp['lrf']
        else:
            self.lf = lambda x: (1 - x / self.epochs) * (1.0 - self.args.lrf) + self.args.lrf  # linear
        self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
        self.stopper, self.stop = EarlyStopping(patience=self.args.patience), False
        self.resume_training(ckpt)
        self.scheduler.last_epoch = self.start_epoch - 1  # do not move
        self.run_callbacks('on_pretrain_routine_end')

    def _do_train(self, world_size=1):
        """
        Run the training loop in the current process.

        Sets up DDP (if world_size > 1) and the training state, iterates over epochs/batches with warmup,
        gradient accumulation, validation, checkpointing and early stopping. Finishes with final evaluation
        and plotting on the main process.
        """
        if world_size > 1:
            self._setup_ddp(world_size)
        self._setup_train(world_size)

        self.epoch_time = None
        self.epoch_time_start = time.time()
        self.train_time_start = time.time()
        nb = len(self.train_loader)  # number of batches
        nw = max(round(self.args.warmup_epochs * nb), 100) if self.args.warmup_epochs > 0 else -1  # warmup iterations
        last_opt_step = -1
        self.run_callbacks('on_train_start')
        LOGGER.info(f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n'
                    f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n'
                    f"Logging results to {colorstr('bold', self.save_dir)}\n"
                    f'Starting training for {self.epochs} epochs...')
        if self.args.close_mosaic:
            # also plot the first batches after mosaic is switched off
            base_idx = (self.epochs - self.args.close_mosaic) * nb
            self.plot_idx.extend([base_idx, base_idx + 1, base_idx + 2])
        epoch = self.epochs  # predefine for resume fully trained model edge cases
        for epoch in range(self.start_epoch, self.epochs):
            self.epoch = epoch
            self.run_callbacks('on_train_epoch_start')
            self.model.train()
            if RANK != -1:
                self.train_loader.sampler.set_epoch(epoch)
            pbar = enumerate(self.train_loader)
            # Update dataloader attributes (optional)
            if epoch == (self.epochs - self.args.close_mosaic):
                self._close_dataloader_mosaic()
                self.train_loader.reset()

            if RANK in (-1, 0):
                LOGGER.info(self.progress_string())
                pbar = TQDM(enumerate(self.train_loader), total=nb)
            self.tloss = None
            self.optimizer.zero_grad()
            for i, batch in pbar:
                self.run_callbacks('on_train_batch_start')
                # Warmup
                ni = i + nb * epoch  # global iteration index
                if ni <= nw:
                    xi = [0, nw]  # x interp
                    self.accumulate = max(1, np.interp(ni, xi, [1, self.args.nbs / self.batch_size]).round())
                    for j, x in enumerate(self.optimizer.param_groups):
                        # Bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                        x['lr'] = np.interp(
                            ni, xi, [self.args.warmup_bias_lr if j == 0 else 0.0, x['initial_lr'] * self.lf(epoch)])
                        if 'momentum' in x:
                            x['momentum'] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum])

                # Forward
                with torch.cuda.amp.autocast(self.amp):
                    batch = self.preprocess_batch(batch)
                    self.loss, self.loss_items = self.model(batch)
                    if RANK != -1:
                        self.loss *= world_size
                    self.tloss = (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None \
                        else self.loss_items

                # Backward
                self.scaler.scale(self.loss).backward()

                # Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
                if ni - last_opt_step >= self.accumulate:
                    self.optimizer_step()
                    last_opt_step = ni

                # Log
                mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G'  # (GB)
                loss_len = self.tloss.shape[0] if len(self.tloss.size()) else 1
                losses = self.tloss if loss_len > 1 else torch.unsqueeze(self.tloss, 0)
                if RANK in (-1, 0):
                    pbar.set_description(
                        ('%11s' * 2 + '%11.4g' * (2 + loss_len)) %
                        (f'{epoch + 1}/{self.epochs}', mem, *losses, batch['cls'].shape[0], batch['img'].shape[-1]))
                    self.run_callbacks('on_batch_end')
                    if self.args.plots and ni in self.plot_idx:
                        self.plot_training_samples(batch, ni)

                self.run_callbacks('on_train_batch_end')

            self.lr = {f'lr/pg{ir}': x['lr'] for ir, x in enumerate(self.optimizer.param_groups)}  # for loggers

            with warnings.catch_warnings():
                warnings.simplefilter('ignore')  # suppress 'Detected lr_scheduler.step() before optimizer.step()'
                self.scheduler.step()
            self.run_callbacks('on_train_epoch_end')

            if RANK in (-1, 0):

                # Validation
                self.ema.update_attr(self.model, include=['yaml', 'nc', 'args', 'names', 'stride', 'class_weights'])
                final_epoch = (epoch + 1 == self.epochs) or self.stopper.possible_stop

                if self.args.val or final_epoch:
                    self.metrics, self.fitness = self.validate()
                self.save_metrics(metrics={**self.label_loss_items(self.tloss), **self.metrics, **self.lr})
                self.stop = self.stopper(epoch + 1, self.fitness)

                # Save model
                if self.args.save or (epoch + 1 == self.epochs):
                    self.save_model()
                    self.run_callbacks('on_model_save')

            tnow = time.time()
            self.epoch_time = tnow - self.epoch_time_start
            self.epoch_time_start = tnow
            self.run_callbacks('on_fit_epoch_end')
            torch.cuda.empty_cache()  # clears GPU vRAM at end of epoch, can help with out of memory errors

            # Early Stopping
            if RANK != -1:  # if DDP training
                broadcast_list = [self.stop if RANK == 0 else None]
                dist.broadcast_object_list(broadcast_list, 0)  # broadcast 'stop' to all ranks
                if RANK != 0:
                    self.stop = broadcast_list[0]
            if self.stop:
                break  # must break all DDP ranks

        if RANK in (-1, 0):
            # Do final val with best.pt
            LOGGER.info(f'\n{epoch - self.start_epoch + 1} epochs completed in '
                        f'{(time.time() - self.train_time_start) / 3600:.3f} hours.')
            self.final_eval()
            if self.args.plots:
                self.plot_metrics()
            self.run_callbacks('on_train_end')
        torch.cuda.empty_cache()
        self.run_callbacks('teardown')

    def _close_dataloader_mosaic(self):
        """Disable mosaic augmentation on the training dataset, using whichever hooks the dataset exposes."""
        LOGGER.info('Closing dataloader mosaic')
        if hasattr(self.train_loader.dataset, 'mosaic'):
            self.train_loader.dataset.mosaic = False
        if hasattr(self.train_loader.dataset, 'close_mosaic'):
            self.train_loader.dataset.close_mosaic(hyp=self.args)

    def save_model(self):
        """Save model training checkpoints with additional metadata."""
        import pandas as pd  # scope for faster startup
        metrics = {**self.metrics, 'fitness': self.fitness}
        results = {k.strip(): v for k, v in pd.read_csv(self.csv).to_dict(orient='list').items()}
        ckpt = {
            'epoch': self.epoch,
            'best_fitness': self.best_fitness,
            'model': deepcopy(de_parallel(self.model)).half(),
            'ema': deepcopy(self.ema.ema).half(),
            'updates': self.ema.updates,
            'optimizer': self.optimizer.state_dict(),
            'train_args': vars(self.args),  # save as dict
            'train_metrics': metrics,
            'train_results': results,
            'date': datetime.now().isoformat(),
            'version': __version__}

        # Save last and best
        torch.save(ckpt, self.last)
        if self.best_fitness == self.fitness:
            torch.save(ckpt, self.best)
        if (self.save_period > 0) and (self.epoch > 0) and (self.epoch % self.save_period == 0):
            torch.save(ckpt, self.wdir / f'epoch{self.epoch}.pt')

    @staticmethod
    def get_dataset(data):
        """
        Get the train and val entries from the data dict.

        Returns:
            (tuple): ``(data['train'], data['val'] or data['test'])``; the second item is None when neither
                'val' nor 'test' is set. Raises KeyError if 'train' is missing.
        """
        return data['train'], data.get('val') or data.get('test')

    def setup_model(self):
        """
        Load, create or download the model for any task.

        Returns:
            (dict | None): The loaded checkpoint when the model is a '.pt' file, otherwise None.
        """
        if isinstance(self.model, torch.nn.Module):  # if model is loaded beforehand. No setup needed
            return None

        model, weights = self.model, None
        ckpt = None
        if str(model).endswith('.pt'):
            weights, ckpt = attempt_load_one_weight(model)
            cfg = ckpt['model'].yaml
        else:
            cfg = model
        self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1)  # calls Model(cfg, weights)
        return ckpt

    def optimizer_step(self):
        """Perform a single step of the training optimizer with gradient clipping and EMA update."""
        self.scaler.unscale_(self.optimizer)  # unscale gradients so clipping sees true magnitudes
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0)  # clip gradients
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad()
        if self.ema:
            self.ema.update(self.model)

    def preprocess_batch(self, batch):
        """
        Allows custom preprocessing model inputs and ground truths depending on task type.
        """
        return batch

    def validate(self):
        """
        Runs validation on test set using self.validator. The returned dict is expected to contain "fitness" key.

        Returns:
            (tuple): (metrics dict without 'fitness', fitness value). If the validator reports no fitness,
                the negated current loss is used instead.
        """
        metrics = self.validator(self)
        fitness = metrics.pop('fitness', -self.loss.detach().cpu().numpy())  # use loss as fitness measure if not found
        if not self.best_fitness or self.best_fitness < fitness:
            self.best_fitness = fitness
        return metrics, fitness

    def get_model(self, cfg=None, weights=None, verbose=True):
        """Get model and raise NotImplementedError for loading cfg files."""
        raise NotImplementedError("This task trainer doesn't support loading cfg files")

    def get_validator(self):
        """Raise NotImplementedError; task-specific trainers must return a validator."""
        raise NotImplementedError('get_validator function not implemented in trainer')

    def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
        """
        Return a torch.utils.data.DataLoader-like object; must be implemented by task-specific trainers.
        """
        raise NotImplementedError('get_dataloader function not implemented in trainer')

    def build_dataset(self, img_path, mode='train', batch=None):
        """Build dataset; must be implemented by task-specific trainers."""
        raise NotImplementedError('build_dataset function not implemented in trainer')

    def label_loss_items(self, loss_items=None, prefix='train'):
        """
        Returns a loss dict with labelled training loss items tensor, or the list of loss keys if no items are given.
        """
        # Not needed for classification but necessary for segmentation & detection
        return {'loss': loss_items} if loss_items is not None else ['loss']

    def set_model_attributes(self):
        """
        To set or update model parameters before training.
        """
        self.model.names = self.data['names']

    def build_targets(self, preds, targets):
        """Builds target tensors for training YOLO model."""
        pass

    def progress_string(self):
        """Returns a string describing training progress."""
        return ''

    # TODO: may need to put these following functions into callback
    def plot_training_samples(self, batch, ni):
        """Plots training samples during YOLOv5 training."""
        pass

    def plot_training_labels(self):
        """Plots training labels for YOLO model."""
        pass

    def save_metrics(self, metrics):
        """Append one row of metrics (prefixed by the 1-based epoch) to the results CSV, writing a header first."""
        keys, vals = list(metrics.keys()), list(metrics.values())
        n = len(metrics) + 1  # number of cols
        s = '' if self.csv.exists() else (('%23s,' * n % tuple(['epoch'] + keys)).rstrip(',') + '\n')  # header
        with open(self.csv, 'a') as f:
            f.write(s + ('%23.5g,' * n % tuple([self.epoch + 1] + vals)).rstrip(',') + '\n')

    def plot_metrics(self):
        """Plot and display metrics visually."""
        pass

    def on_plot(self, name, data=None):
        """Registers plots (e.g. to be consumed in callbacks)"""
        path = Path(name)
        self.plots[path] = {'data': data, 'timestamp': time.time()}

    def final_eval(self):
        """Strip optimizer state from saved checkpoints and run a final validation on best.pt."""
        for f in self.last, self.best:
            if f.exists():
                strip_optimizer(f)  # strip optimizers
                if f is self.best:
                    LOGGER.info(f'\nValidating {f}...')
                    self.validator.args.plots = self.args.plots
                    self.metrics = self.validator(model=f)
                    self.metrics.pop('fitness', None)
                    self.run_callbacks('on_fit_epoch_end')

    def check_resume(self, overrides):
        """
        Check if resume checkpoint exists and update arguments accordingly.

        Args:
            overrides (dict | None): User overrides; only 'imgsz' and 'batch' are re-applied on resume.

        Raises:
            FileNotFoundError: If resuming was requested but the checkpoint could not be loaded.
        """
        overrides = overrides or {}  # overrides may be None when resume is requested via cfg
        resume = self.args.resume
        if resume:
            try:
                exists = isinstance(resume, (str, Path)) and Path(resume).exists()
                last = Path(check_file(resume) if exists else get_latest_run())

                # Check that resume data YAML exists, otherwise strip to force re-download of dataset
                ckpt_args = attempt_load_weights(last).args
                if not Path(ckpt_args['data']).exists():
                    ckpt_args['data'] = self.args.data

                resume = True
                self.args = get_cfg(ckpt_args)
                self.args.model = str(last)  # reinstate model
                for k in 'imgsz', 'batch':  # allow arg updates to reduce memory on resume if crashed due to CUDA OOM
                    if k in overrides:
                        setattr(self.args, k, overrides[k])

            except Exception as e:
                raise FileNotFoundError('Resume checkpoint not found. Please pass a valid checkpoint to resume from, '
                                        "i.e. 'yolo train resume model=path/to/last.pt'") from e
        self.resume = resume

    def resume_training(self, ckpt):
        """
        Resume YOLO training from given epoch and best fitness.

        Restores optimizer and EMA state from ``ckpt`` when present. Sets ``start_epoch``/``best_fitness``,
        extends ``epochs`` for fine-tuning, and closes mosaic if the resume point is past the close-mosaic epoch.

        Raises:
            AssertionError: If resuming a run whose checkpoint marks training as finished.
        """
        if ckpt is None:
            return
        best_fitness = 0.0
        start_epoch = ckpt['epoch'] + 1
        if ckpt['optimizer'] is not None:
            self.optimizer.load_state_dict(ckpt['optimizer'])  # optimizer
            best_fitness = ckpt['best_fitness']
        if self.ema and ckpt.get('ema'):
            self.ema.ema.load_state_dict(ckpt['ema'].float().state_dict())  # EMA
            self.ema.updates = ckpt['updates']
        if self.resume:
            # explicit raise instead of `assert` so the check survives `python -O`
            if start_epoch <= 0:
                raise AssertionError(
                    f'{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n'
                    f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'")
            LOGGER.info(
                f'Resuming training from {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs')
        if self.epochs < start_epoch:
            LOGGER.info(
                f"{self.args.model} has been trained for {ckpt['epoch']} epochs. "
                f'Fine-tuning for {self.epochs} more epochs.')
            self.epochs += ckpt['epoch']  # finetune additional epochs
        self.best_fitness = best_fitness
        self.start_epoch = start_epoch
        if start_epoch > (self.epochs - self.args.close_mosaic):
            self._close_dataloader_mosaic()

    def build_optimizer(self, model, name='auto', lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
        """
        Constructs an optimizer for the given model, based on the specified optimizer name, learning rate,
        momentum, weight decay, and number of iterations.

        Parameters are split into three groups: biases (no decay), normalization-layer weights (no decay),
        and all other weights (with ``decay``).

        Args:
            model (torch.nn.Module): The model for which to build an optimizer.
            name (str, optional): The name of the optimizer to use. If 'auto', the optimizer is selected
                based on the number of iterations. Default: 'auto'.
            lr (float, optional): The learning rate for the optimizer. Default: 0.001.
            momentum (float, optional): The momentum factor for the optimizer. Default: 0.9.
            decay (float, optional): The weight decay for the optimizer. Default: 1e-5.
            iterations (float, optional): The number of iterations, which determines the optimizer if
                name is 'auto'. Default: 1e5.

        Returns:
            (torch.optim.Optimizer): The constructed optimizer.

        Raises:
            NotImplementedError: If ``name`` is not a supported optimizer.
        """

        g = [], [], []  # optimizer parameter groups
        bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)  # normalization layers, i.e. BatchNorm2d()
        if name == 'auto':
            LOGGER.info(f"{colorstr('optimizer:')} 'optimizer=auto' found, "
                        f"ignoring 'lr0={self.args.lr0}' and 'momentum={self.args.momentum}' and "
                        f"determining best 'optimizer', 'lr0' and 'momentum' automatically... ")
            nc = getattr(model, 'nc', 10)  # number of classes
            lr_fit = round(0.002 * 5 / (4 + nc), 6)  # lr0 fit equation to 6 decimal places
            name, lr, momentum = ('SGD', 0.01, 0.9) if iterations > 10000 else ('AdamW', lr_fit, 0.9)
            self.args.warmup_bias_lr = 0.0  # no higher than 0.01 for Adam

        for module_name, module in model.named_modules():
            for param_name, param in module.named_parameters(recurse=False):
                fullname = f'{module_name}.{param_name}' if module_name else param_name
                if 'bias' in fullname:  # bias (no decay)
                    g[2].append(param)
                elif isinstance(module, bn):  # weight (no decay)
                    g[1].append(param)
                else:  # weight (with decay)
                    g[0].append(param)

        if name in ('Adam', 'Adamax', 'AdamW', 'NAdam', 'RAdam'):
            optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
        elif name == 'RMSProp':
            optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum)
        elif name == 'SGD':
            optimizer = optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
        else:
            raise NotImplementedError(
                f"Optimizer '{name}' not found in list of available optimizers "
                f'[Adam, Adamax, AdamW, NAdam, RAdam, RMSProp, SGD, auto]. '
                'To request support for additional optimizers please visit https://github.com/ultralytics/ultralytics.')

        optimizer.add_param_group({'params': g[0], 'weight_decay': decay})  # add g0 with weight_decay
        optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0})  # add g1 (BatchNorm2d weights)
        LOGGER.info(
            f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}, momentum={momentum}) with parameter groups "
            f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)')
        return optimizer