
Reference for ultralytics/engine/validator.py

Note

This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/engine/validator.py. If you spot a problem please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!


ultralytics.engine.validator.BaseValidator

BaseValidator(dataloader=None, save_dir=None, args=None, _callbacks=None)

A base class for creating validators.

This class provides the foundation for validation processes, including model evaluation, metric computation, and result visualization.

Attributes:

args (SimpleNamespace): Configuration for the validator.
dataloader (DataLoader): Dataloader to use for validation.
model (Module): Model to validate.
data (dict): Data dictionary containing dataset information.
device (device): Device to use for validation.
batch_i (int): Current batch index.
training (bool): Whether validation is running as part of training, i.e. a trainer was passed to __call__.
names (dict): Class names mapping.
seen (int): Number of images seen so far during validation.
stats (dict): Statistics collected during validation.
confusion_matrix: Confusion matrix for classification evaluation.
nc (int): Number of classes.
iouv (Tensor): IoU thresholds from 0.50 to 0.95 in steps of 0.05.
jdict (list): List to store JSON validation results.
speed (dict): Dictionary with keys 'preprocess', 'inference', 'loss', 'postprocess' and their respective per-image processing times in milliseconds.
save_dir (Path): Directory to save results.
plots (dict): Dictionary to store plots for visualization.
callbacks (dict): Dictionary to store various callback functions.
stride (int): Model stride for padding calculations.
loss (Tensor): Accumulated loss during training validation.
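
As a rough illustration of the `iouv` attribute described above, the ten thresholds can be generated as follows. This is a plain-Python sketch only: the library stores the thresholds as a tensor, and the helper name `iou_thresholds` is hypothetical.

```python
def iou_thresholds(start=0.50, stop=0.95, step=0.05):
    """Return IoU thresholds from start to stop (inclusive) in fixed steps."""
    count = round((stop - start) / step) + 1  # 10 thresholds for the defaults
    # Round to two decimals so floating-point drift does not leak into the values.
    return [round(start + i * step, 2) for i in range(count)]


iouv = iou_thresholds()
print(len(iouv), iouv[0], iouv[-1])
```

Each column of the tensor returned by `match_predictions` corresponds to one of these thresholds.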

Methods:

__call__: Execute validation process, running inference on dataloader and computing performance metrics.
match_predictions: Match predictions to ground truth objects using IoU.
add_callback: Append the given callback to the specified event.
run_callbacks: Run all callbacks associated with a specified event.
get_dataloader: Get data loader from dataset path and batch size.
build_dataset: Build dataset from image path.
preprocess: Preprocess an input batch.
postprocess: Postprocess the predictions.
init_metrics: Initialize performance metrics for the YOLO model.
update_metrics: Update metrics based on predictions and batch.
finalize_metrics: Finalize and return all metrics.
get_stats: Return statistics about the model's performance.
print_results: Print the results of the model's predictions.
get_desc: Get description of the YOLO model.
on_plot: Register plots for visualization.
plot_val_samples: Plot validation samples during training.
plot_predictions: Plot YOLO model predictions on batch images.
pred_to_json: Convert predictions to JSON format.
eval_json: Evaluate and return JSON format of prediction statistics.

Parameters:

dataloader (DataLoader, default: None): Dataloader to be used for validation.
save_dir (Path, default: None): Directory to save results.
args (SimpleNamespace, default: None): Configuration for the validator.
_callbacks (dict, default: None): Dictionary to store various callback functions.

Source code in ultralytics/engine/validator.py
def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None):
    """
    Initialize a BaseValidator instance.

    Args:
        dataloader (torch.utils.data.DataLoader, optional): Dataloader to be used for validation.
        save_dir (Path, optional): Directory to save results.
        args (SimpleNamespace, optional): Configuration for the validator.
        _callbacks (dict, optional): Dictionary to store various callback functions.
    """
    import torchvision  # noqa (import here so torchvision import time not recorded in postprocess time)

    self.args = get_cfg(overrides=args)
    self.dataloader = dataloader
    self.stride = None
    self.data = None
    self.device = None
    self.batch_i = None
    self.training = True
    self.names = None
    self.seen = None
    self.stats = None
    self.confusion_matrix = None
    self.nc = None
    self.iouv = None
    self.jdict = None
    self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}

    self.save_dir = save_dir or get_save_dir(self.args)
    (self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
    if self.args.conf is None:
        self.args.conf = 0.01 if self.args.task == "obb" else 0.001  # reduce OBB val memory usage
    self.args.imgsz = check_imgsz(self.args.imgsz, max_dim=1)

    self.plots = {}
    self.callbacks = _callbacks or callbacks.get_default_callbacks()
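
Two defaults set in the constructor above can be sketched in isolation: the task-dependent confidence threshold and the output directory. This is a minimal standard-library sketch, and the helper names `default_conf` and `prepare_save_dir` are hypothetical, not part of the library.

```python
import tempfile
from pathlib import Path


def default_conf(task, conf=None):
    """Keep an explicit confidence threshold, otherwise pick a task-dependent default."""
    if conf is None:
        return 0.01 if task == "obb" else 0.001
    return conf


def prepare_save_dir(save_dir, save_txt):
    """Create save_dir, plus a labels/ subdirectory when text labels will be saved."""
    target = save_dir / "labels" if save_txt else save_dir
    target.mkdir(parents=True, exist_ok=True)  # parents=True also creates save_dir itself
    return target


with tempfile.TemporaryDirectory() as tmp:
    run_dir = Path(tmp) / "val"
    created = prepare_save_dir(run_dir, save_txt=True)
    labels_exists = created.is_dir() and created.name == "labels"
    run_dir_exists = run_dir.is_dir()

obb_conf = default_conf("obb")
detect_conf = default_conf("detect")
explicit_conf = default_conf("detect", conf=0.25)
print(obb_conf, detect_conf, explicit_conf)
```

A threshold passed explicitly is left untouched; only a missing value is replaced.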

metric_keys property

metric_keys

Return the metric keys used in YOLO training/validation.

__call__

__call__(trainer=None, model=None)

Execute validation process, running inference on dataloader and computing performance metrics.

Parameters:

trainer (object, default: None): Trainer object that contains the model to validate.
model (Module, default: None): Model to validate if not using a trainer.

Returns:

dict: Dictionary containing validation statistics.
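
The validation call follows a template-method pattern: the base class fixes the order of the hooks, and subclasses override the hooks they need. The sketch below mirrors that order with a toy, standard-library-only class. `SketchValidator`, `AccuracyValidator`, and the data are hypothetical stand-ins, not the library's classes.

```python
class SketchValidator:
    """Hypothetical stand-in that mirrors the hook order of the validation loop."""

    def __init__(self, dataloader):
        self.dataloader = dataloader
        self.seen = 0
        self.correct = 0

    # Hooks with pass-through defaults; subclasses override what they need.
    def preprocess(self, batch):
        return batch

    def postprocess(self, preds):
        return preds

    def init_metrics(self):
        self.seen = 0
        self.correct = 0

    def update_metrics(self, preds, batch):
        raise NotImplementedError

    def get_stats(self):
        return {}

    def __call__(self, model):
        self.init_metrics()
        for batch in self.dataloader:
            batch = self.preprocess(batch)
            preds = model(batch["img"])
            preds = self.postprocess(preds)
            self.update_metrics(preds, batch)
        return self.get_stats()


class AccuracyValidator(SketchValidator):
    def preprocess(self, batch):
        # Scale raw 0-255 values to the 0-1 range.
        return {**batch, "img": [x / 255 for x in batch["img"]]}

    def postprocess(self, preds):
        # Turn scores into hard 0/1 labels.
        return [int(p >= 0.5) for p in preds]

    def update_metrics(self, preds, batch):
        self.seen += len(preds)
        self.correct += sum(p == t for p, t in zip(preds, batch["cls"]))

    def get_stats(self):
        return {"accuracy": self.correct / self.seen if self.seen else 0.0}


def identity_model(imgs):
    """Toy model: the score is simply the scaled pixel value."""
    return imgs


batches = [
    {"img": [255, 0, 204], "cls": [1, 0, 1]},
    {"img": [51, 230], "cls": [0, 0]},
]
stats = AccuracyValidator(batches)(identity_model)
print(stats)
```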

Source code in ultralytics/engine/validator.py
@smart_inference_mode()
def __call__(self, trainer=None, model=None):
    """
    Execute validation process, running inference on dataloader and computing performance metrics.

    Args:
        trainer (object, optional): Trainer object that contains the model to validate.
        model (nn.Module, optional): Model to validate if not using a trainer.

    Returns:
        (dict): Dictionary containing validation statistics.
    """
    self.training = trainer is not None
    augment = self.args.augment and (not self.training)
    if self.training:
        self.device = trainer.device
        self.data = trainer.data
        # Force FP16 val during training
        self.args.half = self.device.type != "cpu" and trainer.amp
        model = trainer.ema.ema or trainer.model
        if trainer.args.compile and hasattr(model, "_orig_mod"):
            model = model._orig_mod  # validate non-compiled original model to avoid issues
        model = model.half() if self.args.half else model.float()
        self.loss = torch.zeros_like(trainer.loss_items, device=trainer.device)
        self.args.plots &= trainer.stopper.possible_stop or (trainer.epoch == trainer.epochs - 1)
        model.eval()
    else:
        if str(self.args.model).endswith(".yaml") and model is None:
            LOGGER.warning("validating an untrained model YAML will result in 0 mAP.")
        callbacks.add_integration_callbacks(self)
        model = AutoBackend(
            model=model or self.args.model,
            device=select_device(self.args.device) if RANK == -1 else torch.device("cuda", RANK),
            dnn=self.args.dnn,
            data=self.args.data,
            fp16=self.args.half,
        )
        self.device = model.device  # update device
        self.args.half = model.fp16  # update half
        stride, pt, jit = model.stride, model.pt, model.jit
        imgsz = check_imgsz(self.args.imgsz, stride=stride)
        if not (pt or jit or getattr(model, "dynamic", False)):
            self.args.batch = model.metadata.get("batch", 1)  # export.py models default to batch-size 1
            LOGGER.info(f"Setting batch={self.args.batch} input of shape ({self.args.batch}, 3, {imgsz}, {imgsz})")

        if str(self.args.data).rsplit(".", 1)[-1] in {"yaml", "yml"}:
            self.data = check_det_dataset(self.args.data)
        elif self.args.task == "classify":
            self.data = check_cls_dataset(self.args.data, split=self.args.split)
        else:
            raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' for task={self.args.task} not found ❌"))

        if self.device.type in {"cpu", "mps"}:
            self.args.workers = 0  # faster CPU val as time dominated by inference, not dataloading
        if not (pt or (getattr(model, "dynamic", False) and not model.imx)):
            self.args.rect = False
        self.stride = model.stride  # used in get_dataloader() for padding
        self.dataloader = self.dataloader or self.get_dataloader(self.data.get(self.args.split), self.args.batch)

        model.eval()
        if self.args.compile:
            model = attempt_compile(model, device=self.device)
        model.warmup(imgsz=(1 if pt else self.args.batch, self.data["channels"], imgsz, imgsz))  # warmup

    self.run_callbacks("on_val_start")
    dt = (
        Profile(device=self.device),
        Profile(device=self.device),
        Profile(device=self.device),
        Profile(device=self.device),
    )
    bar = TQDM(self.dataloader, desc=self.get_desc(), total=len(self.dataloader))
    self.init_metrics(unwrap_model(model))
    self.jdict = []  # empty before each val
    for batch_i, batch in enumerate(bar):
        self.run_callbacks("on_val_batch_start")
        self.batch_i = batch_i
        # Preprocess
        with dt[0]:
            batch = self.preprocess(batch)

        # Inference
        with dt[1]:
            preds = model(batch["img"], augment=augment)

        # Loss
        with dt[2]:
            if self.training:
                self.loss += model.loss(batch, preds)[1]

        # Postprocess
        with dt[3]:
            preds = self.postprocess(preds)

        self.update_metrics(preds, batch)
        if self.args.plots and batch_i < 3 and RANK in {-1, 0}:
            self.plot_val_samples(batch, batch_i)
            self.plot_predictions(batch, preds, batch_i)

        self.run_callbacks("on_val_batch_end")

    stats = {}
    self.gather_stats()
    if RANK in {-1, 0}:
        stats = self.get_stats()
        self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1e3 for x in dt)))
        self.finalize_metrics()
        self.print_results()
        self.run_callbacks("on_val_end")

    if self.training:
        model.float()
        # Reduce loss across all GPUs
        loss = self.loss.clone().detach()
        if trainer.world_size > 1:
            dist.reduce(loss, dst=0, op=dist.ReduceOp.AVG)
        if RANK > 0:
            return
        results = {**stats, **trainer.label_loss_items(loss.cpu() / len(self.dataloader), prefix="val")}
        return {k: round(float(v), 5) for k, v in results.items()}  # return results as 5 decimal place floats
    else:
        if RANK > 0:
            return stats
        LOGGER.info(
            "Speed: {:.1f}ms preprocess, {:.1f}ms inference, {:.1f}ms loss, {:.1f}ms postprocess per image".format(
                *tuple(self.speed.values())
            )
        )
        if self.args.save_json and self.jdict:
            with open(str(self.save_dir / "predictions.json"), "w", encoding="utf-8") as f:
                LOGGER.info(f"Saving {f.name}...")
                json.dump(self.jdict, f)  # flatten and save
            stats = self.eval_json(stats)  # update stats
        if self.args.plots or self.args.save_json:
            LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
        return stats
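
The `speed` values are computed by dividing each accumulated stage time by the number of images in the dataset and converting seconds to milliseconds. A minimal sketch of that bookkeeping follows. `Timer` is a hypothetical stand-in for the profiler used in the source, assumed only to accumulate elapsed seconds in `.t`.

```python
import time


class Timer:
    """Hypothetical profiler stand-in: accumulates elapsed seconds in .t."""

    def __init__(self):
        self.t = 0.0

    def __enter__(self):
        self._start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc, tb):
        self.t += time.perf_counter() - self._start
        return False  # never swallow exceptions


def per_image_ms(timers, num_images):
    """Convert accumulated seconds per stage into milliseconds per image."""
    keys = ("preprocess", "inference", "loss", "postprocess")
    return {key: timer.t / num_images * 1e3 for key, timer in zip(keys, timers)}


timers = tuple(Timer() for _ in range(4))
for _ in range(3):  # three toy batches of four images each
    with timers[0]:
        data = list(range(4))
    with timers[1]:
        outputs = [x * 2 for x in data]
    with timers[3]:
        outputs = [x for x in outputs if x >= 0]

speed = per_image_ms(timers, num_images=12)
print(speed)
```

Because the divisor is the image count rather than the batch count, the reported numbers are per image even though timing wraps whole batches.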

add_callback

add_callback(event: str, callback)

Append the given callback to the specified event.

Source code in ultralytics/engine/validator.py
def add_callback(self, event: str, callback):
    """Append the given callback to the specified event."""
    self.callbacks[event].append(callback)

build_dataset

build_dataset(img_path)

Build dataset from image path.

Source code in ultralytics/engine/validator.py
def build_dataset(self, img_path):
    """Build dataset from image path."""
    raise NotImplementedError("build_dataset function not implemented in validator")

eval_json

eval_json(stats)

Evaluate and return JSON format of prediction statistics.

Source code in ultralytics/engine/validator.py
def eval_json(self, stats):
    """Evaluate and return JSON format of prediction statistics."""
    pass

finalize_metrics

finalize_metrics()

Finalize and return all metrics.

Source code in ultralytics/engine/validator.py
def finalize_metrics(self):
    """Finalize and return all metrics."""
    pass

gather_stats

gather_stats()

Gather statistics from all the GPUs during DDP training to GPU 0.

Source code in ultralytics/engine/validator.py
def gather_stats(self):
    """Gather statistics from all the GPUs during DDP training to GPU 0."""
    pass

get_dataloader

get_dataloader(dataset_path, batch_size)

Get data loader from dataset path and batch size.

Source code in ultralytics/engine/validator.py
def get_dataloader(self, dataset_path, batch_size):
    """Get data loader from dataset path and batch size."""
    raise NotImplementedError("get_dataloader function not implemented for this validator")

get_desc

get_desc()

Get description of the YOLO model.

Source code in ultralytics/engine/validator.py
def get_desc(self):
    """Get description of the YOLO model."""
    pass

get_stats

get_stats()

Return statistics about the model's performance.

Source code in ultralytics/engine/validator.py
def get_stats(self):
    """Return statistics about the model's performance."""
    return {}

init_metrics

init_metrics(model)

Initialize performance metrics for the YOLO model.

Source code in ultralytics/engine/validator.py
def init_metrics(self, model):
    """Initialize performance metrics for the YOLO model."""
    pass

match_predictions

match_predictions(
    pred_classes: Tensor,
    true_classes: Tensor,
    iou: Tensor,
    use_scipy: bool = False,
) -> torch.Tensor

Match predictions to ground truth objects using IoU.

Parameters:

pred_classes (Tensor, required): Predicted class indices of shape (N,).
true_classes (Tensor, required): Target class indices of shape (M,).
iou (Tensor, required): Tensor of shape (M, N) containing the pairwise IoU values, with ground truth objects as rows and predictions as columns.
use_scipy (bool, default: False): Whether to use scipy for matching (more precise).

Returns:

Tensor: Correct tensor of shape (N, 10) for 10 IoU thresholds.

Source code in ultralytics/engine/validator.py
def match_predictions(
    self, pred_classes: torch.Tensor, true_classes: torch.Tensor, iou: torch.Tensor, use_scipy: bool = False
) -> torch.Tensor:
    """
    Match predictions to ground truth objects using IoU.

    Args:
        pred_classes (torch.Tensor): Predicted class indices of shape (N,).
        true_classes (torch.Tensor): Target class indices of shape (M,).
        iou (torch.Tensor): An MxN tensor containing the pairwise IoU values, with ground truth as rows and predictions as columns.
        use_scipy (bool, optional): Whether to use scipy for matching (more precise).

    Returns:
        (torch.Tensor): Correct tensor of shape (N, 10) for 10 IoU thresholds.
    """
    # Dx10 matrix, where D - detections, 10 - IoU thresholds
    correct = np.zeros((pred_classes.shape[0], self.iouv.shape[0])).astype(bool)
    # LxD matrix where L - labels (rows), D - detections (columns)
    correct_class = true_classes[:, None] == pred_classes
    iou = iou * correct_class  # zero out the wrong classes
    iou = iou.cpu().numpy()
    for i, threshold in enumerate(self.iouv.cpu().tolist()):
        if use_scipy:
            # WARNING: known issue that reduces mAP in https://github.com/ultralytics/ultralytics/pull/4708
            import scipy  # scope import to avoid importing for all commands

            cost_matrix = iou * (iou >= threshold)
            if cost_matrix.any():
                labels_idx, detections_idx = scipy.optimize.linear_sum_assignment(cost_matrix)
                valid = cost_matrix[labels_idx, detections_idx] > 0
                if valid.any():
                    correct[detections_idx[valid], i] = True
        else:
            matches = np.nonzero(iou >= threshold)  # IoU >= threshold and classes match
            matches = np.array(matches).T
            if matches.shape[0]:
                if matches.shape[0] > 1:
                    matches = matches[iou[matches[:, 0], matches[:, 1]].argsort()[::-1]]
                    matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
                    matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
                correct[matches[:, 1].astype(int), i] = True
    return torch.tensor(correct, dtype=torch.bool, device=pred_classes.device)
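
To make the matching idea concrete, here is a simplified greedy one-to-one matcher in plain Python. It is not the library's implementation: the source deduplicates candidate pairs with `np.unique`, while this sketch walks the pairs in descending IoU order and skips any label or detection that is already taken. The function name `greedy_match` and the sample data are hypothetical.

```python
def greedy_match(iou, pred_classes, true_classes, threshold):
    """Return one boolean per prediction marking it as a true positive.

    iou is a list of rows, one row per ground-truth label and one column per prediction.
    """
    candidates = []
    for label_idx, row in enumerate(iou):
        for det_idx, value in enumerate(row):
            same_class = true_classes[label_idx] == pred_classes[det_idx]
            if same_class and value >= threshold:
                candidates.append((value, label_idx, det_idx))

    # Highest IoU first, so the best pairs claim their label and detection.
    candidates.sort(key=lambda c: c[0], reverse=True)

    used_labels, used_dets = set(), set()
    correct = [False] * len(pred_classes)
    for _, label_idx, det_idx in candidates:
        if label_idx in used_labels or det_idx in used_dets:
            continue  # each label and each detection is matched at most once
        used_labels.add(label_idx)
        used_dets.add(det_idx)
        correct[det_idx] = True
    return correct


iou = [
    [0.90, 0.60, 0.10],  # label 0 vs. predictions 0, 1, 2
    [0.20, 0.55, 0.00],  # label 1 vs. predictions 0, 1, 2
]
pred_classes = [0, 0, 1]
true_classes = [0, 0]

loose = greedy_match(iou, pred_classes, true_classes, threshold=0.5)
strict = greedy_match(iou, pred_classes, true_classes, threshold=0.75)
print(loose, strict)
```

Running the matcher once per threshold and stacking the results column by column gives a table with the same layout as the (N, 10) tensor described above.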

on_plot

on_plot(name, data=None)

Register plots for visualization.

Source code in ultralytics/engine/validator.py
def on_plot(self, name, data=None):
    """Register plots for visualization."""
    self.plots[Path(name)] = {"data": data, "timestamp": time.time()}
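
The registration above keys each plot by its path, so registering the same file twice overwrites the earlier entry and refreshes its timestamp. A small standalone sketch of that behavior, using a module-level `plots` dict and made-up file names:

```python
import time
from pathlib import Path

plots = {}


def on_plot(name, data=None):
    """Record a plot under its path, together with the registration time."""
    plots[Path(name)] = {"data": data, "timestamp": time.time()}


on_plot("runs/val/confusion_matrix.png")
on_plot("runs/val/PR_curve.png", data={"type": "pr"})
on_plot("runs/val/confusion_matrix.png")  # same key: replaces the first entry

print(sorted(str(path) for path in plots))
```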

plot_predictions

plot_predictions(batch, preds, ni)

Plot YOLO model predictions on batch images.

Source code in ultralytics/engine/validator.py
def plot_predictions(self, batch, preds, ni):
    """Plot YOLO model predictions on batch images."""
    pass

plot_val_samples

plot_val_samples(batch, ni)

Plot validation samples during training.

Source code in ultralytics/engine/validator.py
def plot_val_samples(self, batch, ni):
    """Plot validation samples during training."""
    pass

postprocess

postprocess(preds)

Postprocess the predictions.

Source code in ultralytics/engine/validator.py
def postprocess(self, preds):
    """Postprocess the predictions."""
    return preds

pred_to_json

pred_to_json(preds, batch)

Convert predictions to JSON format.

Source code in ultralytics/engine/validator.py
def pred_to_json(self, preds, batch):
    """Convert predictions to JSON format."""
    pass
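
The base class leaves this hook empty, so the record layout is up to each subclass. As one possible shape, the sketch below builds COCO-style result records (`image_id`, `category_id`, `bbox` as `[x, y, width, height]`, `score`) from corner-format boxes. The function name `to_coco_records`, its signature, and the rounding choices are assumptions made for illustration.

```python
import json


def to_coco_records(image_id, boxes_xyxy, scores, class_ids):
    """Convert corner-format boxes (x1, y1, x2, y2) into COCO-style result dicts."""
    records = []
    for (x1, y1, x2, y2), score, cls in zip(boxes_xyxy, scores, class_ids):
        records.append(
            {
                "image_id": image_id,
                "category_id": int(cls),
                # COCO boxes are top-left corner plus width and height.
                "bbox": [round(v, 3) for v in (x1, y1, x2 - x1, y2 - y1)],
                "score": round(float(score), 5),
            }
        )
    return records


jdict = []  # accumulated across batches, then dumped once at the end
jdict.extend(to_coco_records(42, [(10.0, 20.0, 110.0, 220.0)], [0.87654321], [3]))
payload = json.dumps(jdict)
print(payload)
```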

preprocess

preprocess(batch)

Preprocess an input batch.

Source code in ultralytics/engine/validator.py
def preprocess(self, batch):
    """Preprocess an input batch."""
    return batch

print_results

print_results()

Print the results of the model's predictions.

Source code in ultralytics/engine/validator.py
def print_results(self):
    """Print the results of the model's predictions."""
    pass

run_callbacks

run_callbacks(event: str)

Run all callbacks associated with a specified event.

Source code in ultralytics/engine/validator.py
def run_callbacks(self, event: str):
    """Run all callbacks associated with a specified event."""
    for callback in self.callbacks.get(event, []):
        callback(self)
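
The callback mechanism is a plain event-to-list mapping: `add_callback` appends to the list for an event, and `run_callbacks` calls every entry with the validator itself as the only argument. A minimal standalone sketch follows. `CallbackRegistry` is a hypothetical class, and the `defaultdict` is an assumption so that unknown event names can be registered without a `KeyError`.

```python
from collections import defaultdict


class CallbackRegistry:
    """Hypothetical stand-in for the validator's callback handling."""

    def __init__(self):
        self.callbacks = defaultdict(list)
        self.log = []

    def add_callback(self, event, callback):
        self.callbacks[event].append(callback)

    def run_callbacks(self, event):
        # .get() does not create missing keys, so unknown events are a no-op.
        for callback in self.callbacks.get(event, []):
            callback(self)


registry = CallbackRegistry()
registry.add_callback("on_val_start", lambda v: v.log.append("start"))
registry.add_callback("on_val_end", lambda v: v.log.append("end"))
registry.add_callback("on_val_end", lambda v: v.log.append("cleanup"))

registry.run_callbacks("on_val_start")
registry.run_callbacks("on_val_batch_end")  # nothing registered for this event
registry.run_callbacks("on_val_end")
print(registry.log)
```

Callbacks for one event run in registration order.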

update_metrics

update_metrics(preds, batch)

Update metrics based on predictions and batch.

Source code in ultralytics/engine/validator.py
def update_metrics(self, preds, batch):
    """Update metrics based on predictions and batch."""
    pass




