def _log_images(path, prefix=""):
    """
    Send one image file to DVCLive, optionally nested under ``prefix``.

    A file whose name contains ``_batch<N>`` is logged as ``<stem with _batch<N> -> _batch>/<N><suffix>``.
    All batch images therefore share one directory, which DVCLive shows as a slider.
    Nothing happens when the module-level ``live`` logger is not set.

    Args:
        path: Path-like object exposing ``name``, ``stem`` and ``suffix`` (e.g. ``pathlib.Path``).
        prefix (str): Directory prepended to the logged image name. Defaults to "".
    """
    if not live:
        return

    log_name = path.name

    # Group images by batch so the DVCLive UI can render them as a slider
    batch_match = re.search(r"_batch(\d+)", log_name)
    if batch_match:
        batch_index = batch_match[1]
        grouped_stem = re.sub(r"_batch(\d+)", "_batch", path.stem)
        log_name = (Path(grouped_stem) / batch_index).with_suffix(path.suffix)

    live.log_image(os.path.join(prefix, log_name), path)
def _log_plots(plots, prefix=""):
    """
    Log each plot image whose timestamp differs from the one recorded when it was last logged.

    Args:
        plots (dict): Maps a plot image path to a dict containing at least a ``"timestamp"`` entry.
        prefix (str): Directory under which the images are logged. Defaults to "".
    """
    for plot_path, meta in plots.items():
        stamp = meta["timestamp"]
        # Unchanged timestamp -> this exact version was already logged; skip it
        if _processed_plots.get(plot_path) == stamp:
            continue
        _log_images(plot_path, prefix)
        _processed_plots[plot_path] = stamp
def _log_confusion_matrix(validator):
    """
    Log the validator's confusion matrix to DVCLive as ``cf.json``.

    ``live.log_sklearn_plot`` takes parallel lists of true and predicted labels.
    Each matrix cell count is therefore expanded into that many (target, prediction) label pairs.

    Args:
        validator: Object exposing ``confusion_matrix`` (with ``matrix`` and ``task``) and a
            ``names`` mapping whose values are the class labels.
    """
    cm = validator.confusion_matrix
    labels = list(validator.names.values())
    if cm.task == "detect":
        # Detection matrices carry one extra class, appended after the dataset classes
        labels += ["background"]

    targets, preds = [], []
    # After transposing, the outer index is used as the true label and the inner index as the prediction.
    # NOTE(review): assumes `matrix` is laid out as [predicted, true] — confirm against ConfusionMatrix
    for true_idx, row in enumerate(cm.matrix.T.astype(int)):
        for pred_idx, count in enumerate(row):
            targets.extend([labels[true_idx]] * count)
            preds.extend([labels[pred_idx]] * count)

    live.log_sklearn_plot("confusion_matrix", targets, preds, name="cf.json", normalized=True)
def on_pretrain_routine_start(trainer):
    """
    Create the module-level DVCLive logger before training begins.

    Any exception raised while constructing ``dvclive.Live`` is caught and reported as a warning,
    so this callback never raises. In that case ``live`` keeps its previous value.

    Args:
        trainer: The trainer invoking this callback (not used).
    """
    global live

    try:
        live = dvclive.Live(save_dvc_exp=True, cache_images=True)
        LOGGER.info("DVCLive is detected and auto logging is enabled (run 'yolo settings dvc=False' to disable).")
    except Exception as e:
        LOGGER.warning(f"WARNING ⚠️ DVCLive installed but not initialized correctly, not logging this run. {e}")
def on_pretrain_routine_end(trainer):
    """
    Log the plots the trainer holds at the end of the pretraining routine, under ``train``.

    Args:
        trainer: Trainer whose ``plots`` mapping is forwarded to ``_log_plots``.
    """
    trainer_plots = trainer.plots
    _log_plots(trainer_plots, prefix="train")
def on_train_epoch_start(trainer):
    """
    Raise the module-level ``_training_epoch`` flag at the start of every training epoch.

    ``on_fit_epoch_end`` only logs while this flag is set, and clears it after logging.

    Args:
        trainer: The trainer invoking this callback (not used).
    """
    global _training_epoch

    _training_epoch = True
def on_fit_epoch_end(trainer):
    """
    Log one epoch's metrics and plots to DVCLive, then advance the DVCLive step.

    Does nothing unless DVCLive is active and ``_training_epoch`` is set.
    After logging, the flag is cleared, so later calls are no-ops until ``on_train_epoch_start``
    sets it again.

    Args:
        trainer: Trainer providing losses, metrics, learning rates, epoch index and plots.
    """
    global _training_epoch

    if not (live and _training_epoch):
        return

    metrics = {
        **trainer.label_loss_items(trainer.tloss, prefix="train"),
        **trainer.metrics,
        **trainer.lr,
    }
    for key, val in metrics.items():
        live.log_metric(key, val)

    if trainer.epoch == 0:
        # Model info is recorded once, on the first epoch, as plain values (no plot series)
        from ultralytics.utils.torch_utils import model_info_for_loggers

        for key, val in model_info_for_loggers(trainer).items():
            live.log_metric(key, val, plot=False)

    _log_plots(trainer.plots, "train")
    _log_plots(trainer.validator.plots, "val")

    live.next_step()
    _training_epoch = False
def on_train_end(trainer):
    """
    Log final metrics, remaining plots, the confusion matrix and the best weights, then end the DVCLive run.

    Does nothing when the module-level ``live`` logger is not set.

    Args:
        trainer: Trainer providing final metrics, plots, its validator and the ``best`` weights path.
    """
    if not live:
        return

    # Final (best-model) metrics are logged as single values rather than plot series
    final_metrics = {**trainer.label_loss_items(trainer.tloss, prefix="train"), **trainer.metrics, **trainer.lr}
    for key, val in final_metrics.items():
        live.log_metric(key, val, plot=False)

    # NOTE(review): trainer.plots goes under "val" here but under "train" in the epoch callbacks — confirm intended
    _log_plots(trainer.plots, "val")
    _log_plots(trainer.validator.plots, "val")
    _log_confusion_matrix(trainer.validator)

    if trainer.best.exists():
        live.log_artifact(trainer.best, copy=True, type="model")

    live.end()