Skip to content

Reference for ultralytics/engine/validator.py

Note

This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/engine/validator.py. If you spot a problem please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!


ultralytics.engine.validator.BaseValidator

BaseValidator(
    dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None
)

BaseValidator.

A base class for creating validators.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| args | SimpleNamespace | Configuration for the validator. |
| dataloader | DataLoader | Dataloader to use for validation. |
| pbar | tqdm | Progress bar to update during validation. |
| model | Module | Model to validate. |
| data | dict | Data dictionary. |
| device | device | Device to use for validation. |
| batch_i | int | Current batch index. |
| training | bool | Whether validation is being run from a trainer (i.e. during training). |
| names | dict | Class names. |
| seen | | Records the number of images seen so far during validation. |
| stats | | Placeholder for statistics during validation. |
| confusion_matrix | | Placeholder for a confusion matrix. |
| nc | | Number of classes. |
| iouv | torch.Tensor | IoU thresholds from 0.50 to 0.95 in steps of 0.05. |
| jdict | list | List to store JSON validation results. |
| speed | dict | Dictionary with keys 'preprocess', 'inference', 'loss', 'postprocess' and their respective average per-image processing times in milliseconds. |
| save_dir | Path | Directory to save results. |
| plots | dict | Dictionary to store plots for visualization. |
| callbacks | dict | Dictionary to store various callback functions. |

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| dataloader | DataLoader | Dataloader to be used for validation. | None |
| save_dir | Path | Directory to save results. | None |
| pbar | tqdm | Progress bar for displaying progress. | None |
| args | SimpleNamespace | Configuration for the validator. | None |
| _callbacks | dict | Dictionary to store various callback functions. | None |
Source code in ultralytics/engine/validator.py
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
    """
    Set up a BaseValidator: resolve its configuration, prepare the output directory and reset per-run state.

    Args:
        dataloader (torch.utils.data.DataLoader, optional): Dataloader to be used for validation.
        save_dir (Path, optional): Directory to save results; derived from the config when not given.
        pbar (tqdm.tqdm, optional): Progress bar for displaying progress.
        args (SimpleNamespace, optional): Configuration overrides for the validator.
        _callbacks (dict, optional): Dictionary of callback functions; defaults are used when not given.
    """
    cfg = get_cfg(overrides=args)
    self.args = cfg
    self.dataloader, self.pbar = dataloader, pbar

    # Per-run state, populated later (e.g. by __call__ or, presumably, by subclass metric hooks).
    self.stride = self.data = self.device = self.batch_i = None
    self.training = True
    self.names = self.seen = self.stats = self.confusion_matrix = None
    self.nc = self.iouv = self.jdict = None
    self.speed = dict.fromkeys(("preprocess", "inference", "loss", "postprocess"), 0.0)

    self.save_dir = save_dir or get_save_dir(cfg)
    # Create the labels/ sub-directory only when text labels will be written.
    out_dir = self.save_dir / "labels" if cfg.save_txt else self.save_dir
    out_dir.mkdir(parents=True, exist_ok=True)
    if cfg.conf is None:
        cfg.conf = 0.001  # validation default confidence threshold
    cfg.imgsz = check_imgsz(cfg.imgsz, max_dim=1)

    self.plots = {}
    self.callbacks = _callbacks or callbacks.get_default_callbacks()

metric_keys property

metric_keys

Returns the metric keys used in YOLO training/validation.

__call__

__call__(trainer=None, model=None)

Executes validation process, running inference on dataloader and computing performance metrics.

Source code in ultralytics/engine/validator.py
@smart_inference_mode()
def __call__(self, trainer=None, model=None):
    """
    Executes validation process, running inference on dataloader and computing performance metrics.

    Runs in one of two modes:

    * Training mode (``trainer`` is not None): device, data and model come from the trainer (its EMA model when
      that is truthy, otherwise ``trainer.model``). FP16 is used on non-CPU devices when the trainer uses AMP.
      Validation loss is accumulated per batch.
    * Standalone mode (``trainer`` is None): ``model`` (or ``self.args.model``) is wrapped in ``AutoBackend``.
      The dataset is resolved from ``self.args.data``, a dataloader is built if none was supplied, and the model
      is warmed up.

    Args:
        trainer (optional): Trainer driving validation during training, or None for standalone validation.
        model (optional): Model or weights to validate in standalone mode. In training mode this argument is
            overwritten by the trainer's model.

    Returns:
        (dict): In training mode, the stats from ``get_stats()`` merged with the validation losses averaged
            over batches (``prefix="val"``), each rounded to 5 decimal places. In standalone mode, the stats
            from ``get_stats()``. These are replaced by the return value of ``eval_json()`` when ``save_json``
            is set and predictions were collected in ``self.jdict``.

    Raises:
        FileNotFoundError: In standalone mode, if ``self.args.data`` is not a .yaml/.yml file and the task is
            not ``classify``.
    """
    self.training = trainer is not None
    augment = self.args.augment and (not self.training)  # test-time augmentation only in standalone mode
    if self.training:
        self.device = trainer.device
        self.data = trainer.data
        # use FP16 for validation during training when on a non-CPU device and the trainer uses AMP
        self.args.half = self.device.type != "cpu" and trainer.amp
        model = trainer.ema.ema or trainer.model
        model = model.half() if self.args.half else model.float()
        # self.model = model
        # accumulator for summed per-batch losses, shaped like the trainer's loss items
        self.loss = torch.zeros_like(trainer.loss_items, device=trainer.device)
        # only plot on the last epoch, or when the trainer's stopper reports a possible stop
        # (presumably early stopping)
        self.args.plots &= trainer.stopper.possible_stop or (trainer.epoch == trainer.epochs - 1)
        model.eval()
    else:
        if str(self.args.model).endswith(".yaml"):
            LOGGER.warning("WARNING ⚠️ validating an untrained model YAML will result in 0 mAP.")
        callbacks.add_integration_callbacks(self)
        model = AutoBackend(
            weights=model or self.args.model,
            device=select_device(self.args.device, self.args.batch),
            dnn=self.args.dnn,
            data=self.args.data,
            fp16=self.args.half,
        )
        # self.model = model
        self.device = model.device  # update device
        self.args.half = model.fp16  # update half
        stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine
        imgsz = check_imgsz(self.args.imgsz, stride=stride)
        if engine:
            # engine backends report their own batch size; adopt it
            self.args.batch = model.batch_size
        elif not pt and not jit:
            # other exported formats: take batch from model metadata, falling back to 1
            self.args.batch = model.metadata.get("batch", 1)  # export.py models default to batch-size 1
            LOGGER.info(f"Setting batch={self.args.batch} input of shape ({self.args.batch}, 3, {imgsz}, {imgsz})")

        # Resolve the dataset description: YAML files go through the detection checker, classify tasks
        # through the classification checker; anything else is rejected.
        if str(self.args.data).split(".")[-1] in {"yaml", "yml"}:
            self.data = check_det_dataset(self.args.data)
        elif self.args.task == "classify":
            self.data = check_cls_dataset(self.args.data, split=self.args.split)
        else:
            raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' for task={self.args.task} not found ❌"))

        if self.device.type in {"cpu", "mps"}:
            self.args.workers = 0  # faster CPU val as time dominated by inference, not dataloading
        if not pt:
            self.args.rect = False  # rect is only kept enabled for PyTorch (pt) backends
        self.stride = model.stride  # used in get_dataloader() for padding
        self.dataloader = self.dataloader or self.get_dataloader(self.data.get(self.args.split), self.args.batch)

        model.eval()
        model.warmup(imgsz=(1 if pt else self.args.batch, 3, imgsz, imgsz))  # warmup

    self.run_callbacks("on_val_start")
    # Timers for preprocess, inference, loss and postprocess; same order as the keys of self.speed
    dt = (
        Profile(device=self.device),
        Profile(device=self.device),
        Profile(device=self.device),
        Profile(device=self.device),
    )
    bar = TQDM(self.dataloader, desc=self.get_desc(), total=len(self.dataloader))
    self.init_metrics(de_parallel(model))
    self.jdict = []  # empty before each val
    for batch_i, batch in enumerate(bar):
        self.run_callbacks("on_val_batch_start")
        self.batch_i = batch_i
        # Preprocess
        with dt[0]:
            batch = self.preprocess(batch)

        # Inference
        with dt[1]:
            preds = model(batch["img"], augment=augment)

        # Loss (only accumulated when validating from a trainer)
        with dt[2]:
            if self.training:
                self.loss += model.loss(batch, preds)[1]

        # Postprocess
        with dt[3]:
            preds = self.postprocess(preds)

        self.update_metrics(preds, batch)
        # plot only the first three batches
        if self.args.plots and batch_i < 3:
            self.plot_val_samples(batch, batch_i)
            self.plot_predictions(batch, preds, batch_i)

        self.run_callbacks("on_val_batch_end")
    stats = self.get_stats()
    self.check_stats(stats)
    # average time per image, in milliseconds, for each stage
    self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1e3 for x in dt)))
    self.finalize_metrics()
    self.print_results()
    self.run_callbacks("on_val_end")
    if self.training:
        model.float()  # undo the possible .half() applied above before handing the model back
        # losses are averaged over the number of batches
        results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix="val")}
        return {k: round(float(v), 5) for k, v in results.items()}  # return results as 5 decimal place floats
    else:
        LOGGER.info(
            "Speed: {:.1f}ms preprocess, {:.1f}ms inference, {:.1f}ms loss, {:.1f}ms postprocess per image".format(
                *tuple(self.speed.values())
            )
        )
        if self.args.save_json and self.jdict:
            with open(str(self.save_dir / "predictions.json"), "w") as f:
                LOGGER.info(f"Saving {f.name}...")
                json.dump(self.jdict, f)  # flatten and save
            stats = self.eval_json(stats)  # update stats
        if self.args.plots or self.args.save_json:
            LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
        return stats

add_callback

add_callback(event: str, callback)

Appends the given callback.

Source code in ultralytics/engine/validator.py
def add_callback(self, event: str, callback):
    """
    Register ``callback`` to be run for ``event``.

    Args:
        event (str): Event name, used as a key into ``self.callbacks``.
        callback (callable): Function appended to that event's callback list.
    """
    handlers = self.callbacks[event]
    handlers.append(callback)

build_dataset

build_dataset(img_path)

Build dataset.

Source code in ultralytics/engine/validator.py
def build_dataset(self, img_path):
    """
    Build a dataset from ``img_path``. The base class has no implementation.

    Raises:
        NotImplementedError: Always, in the base class.
    """
    message = "build_dataset function not implemented in validator"
    raise NotImplementedError(message)

check_stats

check_stats(stats)

Checks statistics.

Source code in ultralytics/engine/validator.py
def check_stats(self, stats):
    """Hook to inspect the stats produced by ``get_stats()``. The base class does nothing."""
    return None

eval_json

eval_json(stats)

Evaluate and return JSON format of prediction statistics.

Source code in ultralytics/engine/validator.py
def eval_json(self, stats):
    """
    Evaluate predictions saved in JSON format and return the (possibly updated) statistics.

    Callers use this as ``stats = self.eval_json(stats)``. The base implementation therefore returns ``stats``
    unchanged instead of ``None``, so validators that don't override it keep their metrics when ``save_json``
    is enabled.

    Args:
        stats (dict): Statistics computed during validation.

    Returns:
        (dict): The statistics, unchanged in the base class.
    """
    return stats

finalize_metrics

finalize_metrics(*args, **kwargs)

Finalizes all metrics (no-op in the base class).

Source code in ultralytics/engine/validator.py
def finalize_metrics(self, *args, **kwargs):
    """Hook called once all batches are processed, to finalize metrics. The base class does nothing."""
    return None

get_dataloader

get_dataloader(dataset_path, batch_size)

Get data loader from dataset path and batch size.

Source code in ultralytics/engine/validator.py
def get_dataloader(self, dataset_path, batch_size):
    """
    Create a dataloader for ``dataset_path`` with the given ``batch_size``. The base class has no implementation.

    Raises:
        NotImplementedError: Always, in the base class.
    """
    message = "get_dataloader function not implemented for this validator"
    raise NotImplementedError(message)

get_desc

get_desc()

Get description of the YOLO model.

Source code in ultralytics/engine/validator.py
def get_desc(self):
    """Return the description shown on the validation progress bar. The base class provides none (None)."""
    return None

get_stats

get_stats()

Returns statistics about the model's performance.

Source code in ultralytics/engine/validator.py
def get_stats(self):
    """Return the model's performance statistics. The base class returns a fresh empty dict."""
    empty_stats = {}
    return empty_stats

init_metrics

init_metrics(model)

Initialize performance metrics for the YOLO model.

Source code in ultralytics/engine/validator.py
def init_metrics(self, model):
    """Hook to set up metric state for ``model`` before the validation loop. The base class does nothing."""
    return None

match_predictions

match_predictions(pred_classes, true_classes, iou, use_scipy=False)

Matches predictions to ground truth objects (pred_classes, true_classes) using IoU.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| pred_classes | Tensor | Predicted class indices of shape (N,). | required |
| true_classes | Tensor | Target class indices of shape (M,). | required |
| iou | Tensor | An MxN tensor containing the pairwise IoU values between ground truth objects (rows) and predictions (columns). | required |
| use_scipy | bool | Whether to use scipy for matching (more precise). | False |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Boolean tensor of shape (N, 10) marking which predictions are correct at each of the 10 IoU thresholds. |

Source code in ultralytics/engine/validator.py
def match_predictions(self, pred_classes, true_classes, iou, use_scipy=False):
    """
    Matches predictions to ground truth objects (pred_classes, true_classes) using IoU.

    For each IoU threshold in ``self.iouv``, ground truth objects and predictions of the same class whose IoU
    meets the threshold are paired one-to-one. Predictions that get paired are marked correct at that threshold.

    Args:
        pred_classes (torch.Tensor): Predicted class indices of shape (N,).
        true_classes (torch.Tensor): Target class indices of shape (M,).
        iou (torch.Tensor): An MxN tensor of pairwise IoU values. Rows index ground truth objects and columns
            index predictions, the same layout as ``true_classes[:, None] == pred_classes``.
        use_scipy (bool): Whether to use scipy's optimal assignment for matching (more precise). See the
            known-issue note in the body.

    Returns:
        (torch.Tensor): Boolean tensor of shape (N, T), where T = ``self.iouv.shape[0]`` (typically 10). It is
            placed on the same device as ``pred_classes``.
    """
    # NxT matrix: N predictions (rows) x T IoU thresholds (columns)
    correct = np.zeros((pred_classes.shape[0], self.iouv.shape[0]), dtype=bool)
    # MxN matrix: M labels (rows) x N predictions (columns); True where the classes agree
    correct_class = true_classes[:, None] == pred_classes
    iou = iou * correct_class  # zero out the wrong classes
    iou = iou.cpu().numpy()
    if use_scipy:
        # WARNING: known issue that reduces mAP in https://github.com/ultralytics/ultralytics/pull/4708
        # Import the submodule explicitly, and only once rather than per threshold. A bare ``import scipy`` is
        # not guaranteed to load ``scipy.optimize`` on older SciPy releases.
        from scipy.optimize import linear_sum_assignment  # scope import to avoid importing for all commands

    for i, threshold in enumerate(self.iouv.cpu().tolist()):
        if use_scipy:
            cost_matrix = iou * (iou >= threshold)
            if cost_matrix.any():
                labels_idx, detections_idx = linear_sum_assignment(cost_matrix, maximize=True)
                valid = cost_matrix[labels_idx, detections_idx] > 0
                if valid.any():
                    correct[detections_idx[valid], i] = True
        else:
            matches = np.nonzero(iou >= threshold)  # IoU > threshold and classes match
            matches = np.array(matches).T  # rows of (label_idx, prediction_idx)
            if matches.shape[0]:
                if matches.shape[0] > 1:
                    # Sort by IoU (descending) and keep each prediction's best-IoU label.
                    matches = matches[iou[matches[:, 0], matches[:, 1]].argsort()[::-1]]
                    matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
                    # matches = matches[matches[:, 2].argsort()[::-1]]
                    # Keep one prediction per label. Rows are now ordered by prediction index (np.unique output),
                    # so each label keeps its lowest-index surviving prediction.
                    matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
                correct[matches[:, 1].astype(int), i] = True
    return torch.tensor(correct, dtype=torch.bool, device=pred_classes.device)

on_plot

on_plot(name, data=None)

Registers plots (e.g. to be consumed in callbacks).

Source code in ultralytics/engine/validator.py
def on_plot(self, name, data=None):
    """
    Record a plot in ``self.plots`` so it can be consumed later (e.g. by callbacks).

    Args:
        name (str | Path): Plot name or path; stored as a ``Path`` key.
        data (optional): Arbitrary payload kept with the plot.
    """
    entry = {"data": data, "timestamp": time.time()}
    self.plots[Path(name)] = entry

plot_predictions

plot_predictions(batch, preds, ni)

Plots YOLO model predictions on batch images.

Source code in ultralytics/engine/validator.py
def plot_predictions(self, batch, preds, ni):
    """Hook to plot the model's predictions for batch ``ni``. The base class draws nothing."""
    return None

plot_val_samples

plot_val_samples(batch, ni)

Plots validation samples during training.

Source code in ultralytics/engine/validator.py
def plot_val_samples(self, batch, ni):
    """Hook to plot the ground-truth validation samples of batch ``ni``. The base class draws nothing."""
    return None

postprocess

postprocess(preds)

Postprocesses the predictions.

Source code in ultralytics/engine/validator.py
def postprocess(self, preds):
    """
    Postprocess raw model predictions. The base class returns them unchanged.

    Args:
        preds: Raw output of the model's forward pass.

    Returns:
        The predictions, unmodified in the base class.
    """
    return preds

pred_to_json

pred_to_json(preds, batch)

Convert predictions to JSON format.

Source code in ultralytics/engine/validator.py
def pred_to_json(self, preds, batch):
    """Hook to convert a batch's predictions to JSON format. The base class produces nothing."""
    return None

preprocess

preprocess(batch)

Preprocesses an input batch.

Source code in ultralytics/engine/validator.py
def preprocess(self, batch):
    """Prepare an input batch for inference. The base class passes it through unchanged."""
    processed = batch
    return processed

print_results

print_results()

Prints the results of the model's predictions.

Source code in ultralytics/engine/validator.py
def print_results(self):
    """Hook to report the validation results. The base class prints nothing."""
    return None

run_callbacks

run_callbacks(event: str)

Runs all callbacks associated with a specified event.

Source code in ultralytics/engine/validator.py
def run_callbacks(self, event: str):
    """
    Invoke, in registration order, every callback registered for ``event``, passing the validator itself.

    Events with no entry in ``self.callbacks`` are silently ignored.
    """
    registered = self.callbacks.get(event, [])
    for fn in registered:
        fn(self)

update_metrics

update_metrics(preds, batch)

Updates metrics based on predictions and batch.

Source code in ultralytics/engine/validator.py
def update_metrics(self, preds, batch):
    """Hook to fold one batch's predictions and targets into the running metrics. The base class does nothing."""
    return None



📅 Created 11 months ago ✏️ Updated 1 month ago