Skip to content

Reference for ultralytics/engine/validator.py

Note

This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/engine/validator.py. If you spot a problem please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!


ultralytics.engine.validator.BaseValidator

BaseValidator(dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None)

BaseValidator.

A base class for creating validators.

Attributes:

Name Type Description
args SimpleNamespace

Configuration for the validator.

dataloader DataLoader

Dataloader to use for validation.

pbar tqdm

Progress bar to update during validation.

model Module

Model to validate.

data dict

Data dictionary.

device device

Device to use for validation.

batch_i int

Current batch index.

training bool

Whether the model is in training mode.

names dict

Class names.

seen

Records the number of images seen so far during validation.

stats

Placeholder for statistics during validation.

confusion_matrix

Placeholder for a confusion matrix.

nc

Number of classes.

iouv

(torch.Tensor): IoU thresholds from 0.50 to 0.95 in steps of 0.05.

jdict dict

Dictionary to store JSON validation results.

speed dict

Dictionary with keys 'preprocess', 'inference', 'loss', 'postprocess' and their respective batch processing times in milliseconds.

save_dir Path

Directory to save results.

plots dict

Dictionary to store plots for visualization.

callbacks dict

Dictionary to store various callback functions.

Parameters:

Name Type Description Default
dataloader DataLoader

Dataloader to be used for validation.

None
save_dir Path

Directory to save results.

None
pbar tqdm

Progress bar for displaying progress.

None
args SimpleNamespace

Configuration for the validator.

None
_callbacks dict

Dictionary to store various callback functions.

None
Source code in ultralytics/engine/validator.py
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
    """
    Set up a BaseValidator with its configuration, output directory and callbacks.

    Args:
        dataloader (torch.utils.data.DataLoader, optional): Dataloader used for validation.
        save_dir (Path, optional): Directory where results are written; derived from the config when omitted.
        pbar (tqdm.tqdm, optional): Progress bar for displaying progress.
        args (SimpleNamespace, optional): Configuration overrides for the validator.
        _callbacks (dict, optional): Mapping of event names to callback lists; defaults are used when omitted.
    """
    cfg = get_cfg(overrides=args)
    self.args = cfg
    self.dataloader = dataloader
    self.pbar = pbar

    # Runtime state; filled in later (e.g. by __call__ / subclass hooks).
    self.stride = self.data = self.device = self.batch_i = None
    self.training = True
    self.names = self.seen = self.stats = self.confusion_matrix = None
    self.nc = self.iouv = self.jdict = None
    # Key order matters: __call__ zips these keys with its per-stage profilers.
    self.speed = dict.fromkeys(("preprocess", "inference", "loss", "postprocess"), 0.0)

    self.save_dir = save_dir or get_save_dir(cfg)
    output_dir = self.save_dir / "labels" if cfg.save_txt else self.save_dir
    output_dir.mkdir(parents=True, exist_ok=True)
    if cfg.conf is None:
        cfg.conf = 0.001  # default conf=0.001 when none is configured
    cfg.imgsz = check_imgsz(cfg.imgsz, max_dim=1)

    self.plots = {}
    self.callbacks = _callbacks if _callbacks else callbacks.get_default_callbacks()

metric_keys property

metric_keys

Returns the metric keys used in YOLO training/validation.

__call__

__call__(trainer=None, model=None)

Validates a pre-trained model if `model` is passed, or the model being trained if `trainer` is passed (`trainer` takes priority).

Source code in ultralytics/engine/validator.py
@smart_inference_mode()
def __call__(self, trainer=None, model=None):
    """
    Run validation and return the resulting statistics.

    Validates the model being trained when ``trainer`` is passed (``trainer`` takes priority). Otherwise it loads
    ``model`` (or ``self.args.model`` when ``model`` is None) through AutoBackend and validates it standalone.

    Args:
        trainer (optional): Trainer whose device, data, loss items and (EMA) model are used for validation.
        model (optional): Weights/model handed to AutoBackend in standalone mode; ignored when ``trainer`` is given.

    Returns:
        (dict): In training mode, the stats merged with the averaged validation losses (prefixed "val"), each
            rounded to 5 decimal places. Otherwise, the stats from ``get_stats()``, replaced by
            ``eval_json(stats)`` when ``save_json`` is set and JSON predictions were collected.
    """
    self.training = trainer is not None
    augment = self.args.augment and (not self.training)  # augmentation is only honored outside training
    if self.training:
        # Training mode: reuse the trainer's device, data and model (EMA weights when available).
        self.device = trainer.device
        self.data = trainer.data
        self.args.half = self.device.type != "cpu"  # force FP16 val during training
        model = trainer.ema.ema or trainer.model
        model = model.half() if self.args.half else model.float()
        # self.model = model
        self.loss = torch.zeros_like(trainer.loss_items, device=trainer.device)  # accumulator for val loss items
        # Keep plotting only if the stopper reports a possible stop or this is the last epoch.
        self.args.plots &= trainer.stopper.possible_stop or (trainer.epoch == trainer.epochs - 1)
        model.eval()
    else:
        # Standalone mode: load the model through AutoBackend and build data/dataloader from self.args.
        callbacks.add_integration_callbacks(self)
        model = AutoBackend(
            weights=model or self.args.model,
            device=select_device(self.args.device, self.args.batch),
            dnn=self.args.dnn,
            data=self.args.data,
            fp16=self.args.half,
        )
        # self.model = model
        self.device = model.device  # update device
        self.args.half = model.fp16  # update half
        stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine
        imgsz = check_imgsz(self.args.imgsz, stride=stride)
        if engine:
            self.args.batch = model.batch_size  # use the batch size reported by the engine backend
        elif not pt and not jit:
            self.args.batch = 1  # export.py models default to batch-size 1
            LOGGER.info(f"Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models")

        # Resolve the dataset: YAML files via check_det_dataset, classification via check_cls_dataset.
        if str(self.args.data).split(".")[-1] in {"yaml", "yml"}:
            self.data = check_det_dataset(self.args.data)
        elif self.args.task == "classify":
            self.data = check_cls_dataset(self.args.data, split=self.args.split)
        else:
            raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' for task={self.args.task} not found ❌"))

        if self.device.type in {"cpu", "mps"}:
            self.args.workers = 0  # faster CPU val as time dominated by inference, not dataloading
        if not pt:
            self.args.rect = False  # rect is disabled for non-PyTorch backends
        self.stride = model.stride  # used in get_dataloader() for padding
        self.dataloader = self.dataloader or self.get_dataloader(self.data.get(self.args.split), self.args.batch)

        model.eval()
        model.warmup(imgsz=(1 if pt else self.args.batch, 3, imgsz, imgsz))  # warmup

    self.run_callbacks("on_val_start")
    # One profiler per stage: preprocess, inference, loss, postprocess (same order as the self.speed keys).
    dt = (
        Profile(device=self.device),
        Profile(device=self.device),
        Profile(device=self.device),
        Profile(device=self.device),
    )
    bar = TQDM(self.dataloader, desc=self.get_desc(), total=len(self.dataloader))
    self.init_metrics(de_parallel(model))
    self.jdict = []  # empty before each val
    for batch_i, batch in enumerate(bar):
        self.run_callbacks("on_val_batch_start")
        self.batch_i = batch_i
        # Preprocess
        with dt[0]:
            batch = self.preprocess(batch)

        # Inference
        with dt[1]:
            preds = model(batch["img"], augment=augment)

        # Loss (only accumulated in training mode)
        with dt[2]:
            if self.training:
                self.loss += model.loss(batch, preds)[1]  # index 1 presumably holds the loss items — confirm

        # Postprocess
        with dt[3]:
            preds = self.postprocess(preds)

        self.update_metrics(preds, batch)
        if self.args.plots and batch_i < 3:  # plot only the first 3 batches
            self.plot_val_samples(batch, batch_i)
            self.plot_predictions(batch, preds, batch_i)

        self.run_callbacks("on_val_batch_end")
    stats = self.get_stats()
    self.check_stats(stats)
    # Per-image time for each stage in ms (total / dataset size * 1e3; assumes Profile.t accumulates seconds).
    self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1e3 for x in dt)))
    self.finalize_metrics()
    self.print_results()
    self.run_callbacks("on_val_end")
    if self.training:
        model.float()  # return the model to FP32 (it may have been cast to half above)
        # Validation losses are averaged over the number of batches.
        results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix="val")}
        return {k: round(float(v), 5) for k, v in results.items()}  # return results as 5 decimal place floats
    else:
        LOGGER.info(
            "Speed: %.1fms preprocess, %.1fms inference, %.1fms loss, %.1fms postprocess per image"
            % tuple(self.speed.values())
        )
        if self.args.save_json and self.jdict:
            with open(str(self.save_dir / "predictions.json"), "w") as f:
                LOGGER.info(f"Saving {f.name}...")
                json.dump(self.jdict, f)  # flatten and save
            stats = self.eval_json(stats)  # update stats
        if self.args.plots or self.args.save_json:
            LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
        return stats

add_callback

add_callback(event: str, callback)

Appends the given callback.

Source code in ultralytics/engine/validator.py
def add_callback(self, event: str, callback):
    """Register ``callback`` for ``event`` by appending it to that event's existing callback list."""
    handlers = self.callbacks[event]
    handlers.append(callback)

build_dataset

build_dataset(img_path)

Build dataset.

Source code in ultralytics/engine/validator.py
def build_dataset(self, img_path):
    """Build the dataset for ``img_path``; concrete validators must override this."""
    message = "build_dataset function not implemented in validator"
    raise NotImplementedError(message)

check_stats

check_stats(stats)

Checks statistics.

Source code in ultralytics/engine/validator.py
def check_stats(self, stats):
    """Hook for checking ``stats`` after they are gathered; the base implementation does nothing."""
    return None

eval_json

eval_json(stats)

Evaluate and return JSON format of prediction statistics.

Source code in ultralytics/engine/validator.py
def eval_json(self, stats):
    """
    Evaluate saved JSON predictions and return the (possibly updated) statistics.

    The base implementation performs no JSON evaluation and returns ``stats`` unchanged. This keeps
    ``stats = self.eval_json(stats)`` in ``__call__`` from replacing the statistics with ``None`` in
    validators that do not override this method.

    Args:
        stats (dict): Statistics produced by ``get_stats()``.

    Returns:
        (dict): The statistics, unchanged in the base class.
    """
    return stats

finalize_metrics

finalize_metrics(*args, **kwargs)

Finalizes and returns all metrics.

Source code in ultralytics/engine/validator.py
def finalize_metrics(self, *args, **kwargs):
    """Hook run after statistics are gathered to finalize metrics; a no-op in the base class."""
    return None

get_dataloader

get_dataloader(dataset_path, batch_size)

Get data loader from dataset path and batch size.

Source code in ultralytics/engine/validator.py
def get_dataloader(self, dataset_path, batch_size):
    """Return a dataloader for ``dataset_path`` with ``batch_size``; subclasses must implement this."""
    error = NotImplementedError("get_dataloader function not implemented for this validator")
    raise error

get_desc

get_desc()

Returns the description shown on the validation progress bar (None in the base class).

Source code in ultralytics/engine/validator.py
def get_desc(self):
    """Return the progress-bar description; the base class provides none and returns None."""
    return None

get_stats

get_stats()

Returns statistics about the model's performance.

Source code in ultralytics/engine/validator.py
def get_stats(self):
    """Return the model's performance statistics; the base class reports an empty dict."""
    return dict()

init_metrics

init_metrics(model)

Initialize performance metrics for the YOLO model.

Source code in ultralytics/engine/validator.py
def init_metrics(self, model):
    """Hook to initialize metrics for ``model`` before the validation loop; a no-op in the base class."""
    return None

match_predictions

match_predictions(pred_classes, true_classes, iou, use_scipy=False)

Matches predictions to ground truth objects (pred_classes, true_classes) using IoU.

Parameters:

Name Type Description Default
pred_classes Tensor

Predicted class indices of shape (N,).

required
true_classes Tensor

Target class indices of shape (M,).

required
iou Tensor

An MxN tensor containing the pairwise IoU values between ground-truth labels (rows) and predictions (columns).

required
use_scipy bool

Whether to use scipy for matching (more precise).

False

Returns:

Type Description
Tensor

Boolean tensor of shape (N, 10) marking, for each prediction, whether it is a correct match at each of the 10 IoU thresholds.

Source code in ultralytics/engine/validator.py
def match_predictions(self, pred_classes, true_classes, iou, use_scipy=False):
    """
    Matches predictions to ground truth objects (pred_classes, true_classes) using IoU.

    Args:
        pred_classes (torch.Tensor): Predicted class indices of shape (N,).
        true_classes (torch.Tensor): Target class indices of shape (M,).
        iou (torch.Tensor): An MxN tensor containing the pairwise IoU values between ground-truth labels (rows)
            and predictions (columns).
        use_scipy (bool): Whether to use scipy for matching (more precise).

    Returns:
        (torch.Tensor): Boolean tensor of shape (N, len(self.iouv)) marking, for each prediction, whether it is a
            correct match at each IoU threshold (10 thresholds by default).
    """
    # NxT matrix, where N - detections, T - IoU thresholds
    correct = np.zeros((pred_classes.shape[0], self.iouv.shape[0]), dtype=bool)
    # MxN matrix where M - labels (rows), N - detections (columns)
    correct_class = true_classes[:, None] == pred_classes
    iou = (iou * correct_class).cpu().numpy()  # zero out the wrong classes

    if use_scipy:
        # WARNING: known issue that reduces mAP in https://github.com/ultralytics/ultralytics/pull/4708
        # Import the submodule explicitly: a bare ``import scipy`` does not load ``scipy.optimize`` on older SciPy.
        import scipy.optimize  # scope import to avoid importing for all commands

    for i, threshold in enumerate(self.iouv.cpu().tolist()):
        if use_scipy:
            cost_matrix = iou * (iou >= threshold)
            if cost_matrix.any():
                labels_idx, detections_idx = scipy.optimize.linear_sum_assignment(cost_matrix, maximize=True)
                valid = cost_matrix[labels_idx, detections_idx] > 0
                if valid.any():
                    correct[detections_idx[valid], i] = True
        else:
            # (K, 2) array of [label, detection] pairs with IoU >= threshold and matching classes
            matches = np.argwhere(iou >= threshold)
            if matches.shape[0]:
                if matches.shape[0] > 1:
                    # Sort by IoU (highest first), then keep the best-IoU label for each detection.
                    matches = matches[iou[matches[:, 0], matches[:, 1]].argsort()[::-1]]
                    matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
                    # np.unique leaves rows ordered by detection index (not re-sorted by IoU), so the dedupe
                    # below keeps the lowest-index detection per label. NOTE(review): presumably detections
                    # arrive confidence-sorted, making this the highest-confidence match — confirm upstream.
                    matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
                correct[matches[:, 1], i] = True
    return torch.tensor(correct, dtype=torch.bool, device=pred_classes.device)

on_plot

on_plot(name, data=None)

Registers plots (e.g. to be consumed in callbacks).

Source code in ultralytics/engine/validator.py
def on_plot(self, name, data=None):
    """Record a plot under ``Path(name)`` with its data and creation timestamp (e.g. for use by callbacks)."""
    entry = {"data": data, "timestamp": time.time()}
    self.plots[Path(name)] = entry

plot_predictions

plot_predictions(batch, preds, ni)

Plots YOLO model predictions on batch images.

Source code in ultralytics/engine/validator.py
def plot_predictions(self, batch, preds, ni):
    """Hook to plot model predictions for batch ``ni``; does nothing in the base class."""
    return None

plot_val_samples

plot_val_samples(batch, ni)

Plots validation samples during training.

Source code in ultralytics/engine/validator.py
def plot_val_samples(self, batch, ni):
    """Hook to plot validation samples of batch ``ni``; does nothing in the base class."""
    return None

postprocess

postprocess(preds)

Postprocesses model predictions; the base implementation returns them unchanged.

Source code in ultralytics/engine/validator.py
def postprocess(self, preds):
    """Postprocess raw model predictions; the base implementation passes them through unchanged."""
    result = preds
    return result

pred_to_json

pred_to_json(preds, batch)

Convert predictions to JSON format.

Source code in ultralytics/engine/validator.py
def pred_to_json(self, preds, batch):
    """Hook to convert predictions to JSON format; a no-op in the base class."""
    return None

preprocess

preprocess(batch)

Preprocesses an input batch.

Source code in ultralytics/engine/validator.py
def preprocess(self, batch):
    """Preprocess an input batch; the base implementation returns it as-is."""
    prepared = batch
    return prepared

print_results

print_results()

Prints the results of the model's predictions.

Source code in ultralytics/engine/validator.py
def print_results(self):
    """Hook to print validation results; prints nothing in the base class."""
    return None

run_callbacks

run_callbacks(event: str)

Runs all callbacks associated with a specified event.

Source code in ultralytics/engine/validator.py
def run_callbacks(self, event: str):
    """Call each callback registered for ``event`` with this validator, in registration order."""
    registered = self.callbacks.get(event, [])
    for handler in registered:
        handler(self)

update_metrics

update_metrics(preds, batch)

Updates metrics based on predictions and batch.

Source code in ultralytics/engine/validator.py
def update_metrics(self, preds, batch):
    """Hook to update metrics from one batch of predictions; a no-op in the base class."""
    return None





Created 2023-11-12, Updated 2024-07-21
Authors: glenn-jocher (6), Burhan-Q (1)