Reference for ultralytics/data/base.py

Note

This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/data/base.py. If you spot a problem, please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!


ultralytics.data.base.BaseDataset

BaseDataset(
    img_path,
    imgsz=640,
    cache=False,
    augment=True,
    hyp=DEFAULT_CFG,
    prefix="",
    rect=False,
    batch_size=16,
    stride=32,
    pad=0.5,
    single_cls=False,
    classes=None,
    fraction=1.0,
)

Bases: Dataset

Base dataset class for loading and processing image data.

This class provides core functionality for loading images, caching, and preparing data for training and inference in object detection tasks.

Attributes:

| Name | Type | Description |
|------|------|-------------|
| img_path | str | Path to the folder containing images. |
| imgsz | int | Target image size for resizing. |
| augment | bool | Whether to apply data augmentation. |
| single_cls | bool | Whether to treat all objects as a single class. |
| prefix | str | Prefix to print in log messages. |
| fraction | float | Fraction of dataset to utilize. |
| im_files | List[str] | List of image file paths. |
| labels | List[Dict] | List of label data dictionaries. |
| ni | int | Number of images in the dataset. |
| rect | bool | Whether to use rectangular training. |
| batch_size | int | Size of batches. |
| stride | int | Stride used in the model. |
| pad | float | Padding value. |
| buffer | list | Buffer for mosaic images. |
| max_buffer_length | int | Maximum buffer size. |
| ims | list | List of loaded images. |
| im_hw0 | list | List of original image dimensions (h, w). |
| im_hw | list | List of resized image dimensions (h, w). |
| npy_files | List[Path] | List of numpy file paths. |
| cache | str | Cache images to RAM or disk during training. |
| transforms | callable | Image transformation function. |

Methods:

| Name | Description |
|------|-------------|
| get_img_files | Read image files from the specified path. |
| update_labels | Update labels to include only specified classes. |
| load_image | Load an image from the dataset. |
| cache_images | Cache images to memory or disk. |
| cache_images_to_disk | Save an image as an *.npy file for faster loading. |
| check_cache_disk | Check image caching requirements vs available disk space. |
| check_cache_ram | Check image caching requirements vs available memory. |
| set_rectangle | Set the shape of bounding boxes as rectangles. |
| get_image_and_label | Get and return label information from the dataset. |
| update_labels_info | Custom label format method to be implemented by subclasses. |
| build_transforms | Build transformation pipeline to be implemented by subclasses. |
| get_labels | Get labels method to be implemented by subclasses. |

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| img_path | str | Path to the folder containing images. | required |
| imgsz | int | Image size for resizing. | 640 |
| cache | bool \| str | Cache images to RAM or disk during training. | False |
| augment | bool | If True, data augmentation is applied. | True |
| hyp | dict | Hyperparameters to apply data augmentation. | DEFAULT_CFG |
| prefix | str | Prefix to print in log messages. | '' |
| rect | bool | If True, rectangular training is used. | False |
| batch_size | int | Size of batches. | 16 |
| stride | int | Stride used in the model. | 32 |
| pad | float | Padding value. | 0.5 |
| single_cls | bool | If True, single class training is used. | False |
| classes | list | List of included classes. | None |
| fraction | float | Fraction of dataset to utilize. | 1.0 |

Source code in ultralytics/data/base.py
def __init__(
    self,
    img_path,
    imgsz=640,
    cache=False,
    augment=True,
    hyp=DEFAULT_CFG,
    prefix="",
    rect=False,
    batch_size=16,
    stride=32,
    pad=0.5,
    single_cls=False,
    classes=None,
    fraction=1.0,
):
    """
    Initialize BaseDataset with given configuration and options.

    Args:
        img_path (str): Path to the folder containing images.
        imgsz (int, optional): Image size for resizing.
        cache (bool | str, optional): Cache images to RAM or disk during training.
        augment (bool, optional): If True, data augmentation is applied.
        hyp (dict, optional): Hyperparameters to apply data augmentation.
        prefix (str, optional): Prefix to print in log messages.
        rect (bool, optional): If True, rectangular training is used.
        batch_size (int, optional): Size of batches.
        stride (int, optional): Stride used in the model.
        pad (float, optional): Padding value.
        single_cls (bool, optional): If True, single class training is used.
        classes (list, optional): List of included classes.
        fraction (float, optional): Fraction of dataset to utilize.
    """
    super().__init__()
    self.img_path = img_path
    self.imgsz = imgsz
    self.augment = augment
    self.single_cls = single_cls
    self.prefix = prefix
    self.fraction = fraction
    self.im_files = self.get_img_files(self.img_path)
    self.labels = self.get_labels()
    self.update_labels(include_class=classes)  # single_cls and include_class
    self.ni = len(self.labels)  # number of images
    self.rect = rect
    self.batch_size = batch_size
    self.stride = stride
    self.pad = pad
    if self.rect:
        assert self.batch_size is not None
        self.set_rectangle()

    # Buffer thread for mosaic images
    self.buffer = []  # buffer size = batch size
    self.max_buffer_length = min((self.ni, self.batch_size * 8, 1000)) if self.augment else 0

    # Cache images (options are cache = True, False, None, "ram", "disk")
    self.ims, self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni, [None] * self.ni
    self.npy_files = [Path(f).with_suffix(".npy") for f in self.im_files]
    self.cache = cache.lower() if isinstance(cache, str) else "ram" if cache is True else None
    if self.cache == "ram" and self.check_cache_ram():
        if hyp.deterministic:
            LOGGER.warning(
                "WARNING ⚠️ cache='ram' may produce non-deterministic training results. "
                "Consider cache='disk' as a deterministic alternative if your disk space allows."
            )
        self.cache_images()
    elif self.cache == "disk" and self.check_cache_disk():
        self.cache_images()

    # Transforms
    self.transforms = self.build_transforms(hyp=hyp)
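
BaseDataset leaves get_labels and build_transforms abstract, so it is used through a subclass. Below is a minimal, unofficial sketch of such a subclass (the class name, the "./images" folder, and the empty-annotation labels are placeholders) that can then be indexed like any other torch Dataset:

```python
import numpy as np

from ultralytics.data.base import BaseDataset


class PlainImageDataset(BaseDataset):
    """Toy dataset that yields images with no annotations (illustration only)."""

    def get_labels(self):
        # One label dict per image, using the keys documented in get_labels() below.
        return [
            dict(
                im_file=f,
                shape=None,  # (height, width); only required when rect=True
                cls=np.zeros((0, 1), dtype=np.float32),
                bboxes=np.zeros((0, 4), dtype=np.float32),
                segments=[],
                keypoints=None,
                normalized=True,
                bbox_format="xywh",
            )
            for f in self.im_files
        ]

    def build_transforms(self, hyp=None):
        # Identity transform: return the label dict unchanged.
        return lambda labels: labels


dataset = PlainImageDataset(img_path="./images", imgsz=640, augment=False)  # "./images" is a placeholder path
print(len(dataset), dataset[0]["ori_shape"])  # number of images and the first image's original (h, w)
```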

__getitem__

__getitem__(index)

Return transformed label information for given index.

Source code in ultralytics/data/base.py
def __getitem__(self, index):
    """Return transformed label information for given index."""
    return self.transforms(self.get_image_and_label(index))

__len__

__len__()

Return the length of the labels list for the dataset.

Source code in ultralytics/data/base.py
def __len__(self):
    """Return the length of the labels list for the dataset."""
    return len(self.labels)

build_transforms

build_transforms(hyp=None)

Users can customize augmentations here.

Examples:

>>> if self.augment:
...     # Training transforms
...     return Compose([])
... else:
...     # Val transforms
...     return Compose([])
Source code in ultralytics/data/base.py
def build_transforms(self, hyp=None):
    """
    Users can customize augmentations here.

    Examples:
        >>> if self.augment:
        ...     # Training transforms
        ...     return Compose([])
        ... else:
        ...     # Val transforms
        ...     return Compose([])
    """
    raise NotImplementedError
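
For reference, a hedged sketch of a typical override that composes transforms from ultralytics.data.augment (this mirrors how the detection datasets in this repo build their pipelines, but it is not the official implementation, and keyword arguments may differ across versions):

```python
from ultralytics.data.augment import Compose, Format, LetterBox


def build_transforms(self, hyp=None):
    """Letterbox-resize to self.imgsz, then convert labels to the training format."""
    return Compose(
        [
            LetterBox(new_shape=(self.imgsz, self.imgsz)),          # pad/resize while keeping aspect ratio
            Format(bbox_format="xywh", normalize=True, batch_idx=True),  # tensors + normalized boxes
        ]
    )
```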

cache_images

cache_images()

Cache images to memory or disk for faster training.

Source code in ultralytics/data/base.py
def cache_images(self):
    """Cache images to memory or disk for faster training."""
    b, gb = 0, 1 << 30  # bytes of cached images, bytes per gigabyte
    fcn, storage = (self.cache_images_to_disk, "Disk") if self.cache == "disk" else (self.load_image, "RAM")
    with ThreadPool(NUM_THREADS) as pool:
        results = pool.imap(fcn, range(self.ni))
        pbar = TQDM(enumerate(results), total=self.ni, disable=LOCAL_RANK > 0)
        for i, x in pbar:
            if self.cache == "disk":
                b += self.npy_files[i].stat().st_size
            else:  # 'ram'
                self.ims[i], self.im_hw0[i], self.im_hw[i] = x  # im, hw_orig, hw_resized = load_image(self, i)
                b += self.ims[i].nbytes
            pbar.desc = f"{self.prefix}Caching images ({b / gb:.1f}GB {storage})"
        pbar.close()
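
Caching is requested at construction time rather than by calling cache_images() directly; a hedged usage sketch reusing the hypothetical PlainImageDataset subclass from above:

```python
# cache may be True / "ram" (decode once into memory) or "disk" (save *.npy copies next to the images).
train_set = PlainImageDataset(img_path="./images", imgsz=640, cache="disk")  # "./images" is a placeholder path
```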

cache_images_to_disk

cache_images_to_disk(i)

Save an image as an *.npy file for faster loading.

Source code in ultralytics/data/base.py
def cache_images_to_disk(self, i):
    """Save an image as an *.npy file for faster loading."""
    f = self.npy_files[i]
    if not f.exists():
        np.save(f.as_posix(), cv2.imread(self.im_files[i]), allow_pickle=False)

check_cache_disk

check_cache_disk(safety_margin=0.5)

Check if there's enough disk space for caching images.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| safety_margin | float | Safety margin factor for disk space calculation. | 0.5 |

Returns:

| Type | Description |
|------|-------------|
| bool | True if there's enough disk space, False otherwise. |

Source code in ultralytics/data/base.py
def check_cache_disk(self, safety_margin=0.5):
    """
    Check if there's enough disk space for caching images.

    Args:
        safety_margin (float, optional): Safety margin factor for disk space calculation.

    Returns:
        (bool): True if there's enough disk space, False otherwise.
    """
    import shutil

    b, gb = 0, 1 << 30  # bytes of cached images, bytes per gigabyte
    n = min(self.ni, 30)  # extrapolate from 30 random images
    for _ in range(n):
        im_file = random.choice(self.im_files)
        im = cv2.imread(im_file)
        if im is None:
            continue
        b += im.nbytes
        if not os.access(Path(im_file).parent, os.W_OK):
            self.cache = None
            LOGGER.info(f"{self.prefix}Skipping caching images to disk, directory not writeable ⚠️")
            return False
    disk_required = b * self.ni / n * (1 + safety_margin)  # bytes required to cache dataset to disk
    total, used, free = shutil.disk_usage(Path(self.im_files[0]).parent)
    if disk_required > free:
        self.cache = None
        LOGGER.info(
            f"{self.prefix}{disk_required / gb:.1f}GB disk space required, "
            f"with {int(safety_margin * 100)}% safety margin but only "
            f"{free / gb:.1f}/{total / gb:.1f}GB free, not caching images to disk ⚠️"
        )
        return False
    return True

check_cache_ram

check_cache_ram(safety_margin=0.5)

Check if there's enough RAM for caching images.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| safety_margin | float | Safety margin factor for RAM calculation. | 0.5 |

Returns:

| Type | Description |
|------|-------------|
| bool | True if there's enough RAM, False otherwise. |

Source code in ultralytics/data/base.py
def check_cache_ram(self, safety_margin=0.5):
    """
    Check if there's enough RAM for caching images.

    Args:
        safety_margin (float, optional): Safety margin factor for RAM calculation.

    Returns:
        (bool): True if there's enough RAM, False otherwise.
    """
    b, gb = 0, 1 << 30  # bytes of cached images, bytes per gigabyte
    n = min(self.ni, 30)  # extrapolate from 30 random images
    for _ in range(n):
        im = cv2.imread(random.choice(self.im_files))  # sample image
        if im is None:
            continue
        ratio = self.imgsz / max(im.shape[0], im.shape[1])  # max(h, w)  # ratio
        b += im.nbytes * ratio**2
    mem_required = b * self.ni / n * (1 + safety_margin)  # bytes required to cache dataset into RAM
    mem = psutil.virtual_memory()
    if mem_required > mem.available:
        self.cache = None
        LOGGER.info(
            f"{self.prefix}{mem_required / gb:.1f}GB RAM required to cache images "
            f"with {int(safety_margin * 100)}% safety margin but only "
            f"{mem.available / gb:.1f}/{mem.total / gb:.1f}GB available, not caching images ⚠️"
        )
        return False
    return True
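
A back-of-envelope illustration of the estimate above, with hypothetical numbers (30 sampled images averaging about 3.1 MB each after the resize estimate, in a dataset of 118,000 images):

```python
# Hypothetical figures, only to show how the extrapolation scales.
sampled_bytes = 30 * 3.1e6                 # bytes measured over the 30-image sample (post-resize estimate)
ni, n, safety_margin = 118_000, 30, 0.5
mem_required = sampled_bytes * ni / n * (1 + safety_margin)
print(f"{mem_required / (1 << 30):.1f} GB of RAM needed")  # ~511 GB, so caching would be skipped on most machines
```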

get_image_and_label

get_image_and_label(index)

Get and return label information from the dataset.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| index | int | Index of the image to retrieve. | required |

Returns:

| Type | Description |
|------|-------------|
| dict | Label dictionary with image and metadata. |

Source code in ultralytics/data/base.py
def get_image_and_label(self, index):
    """
    Get and return label information from the dataset.

    Args:
        index (int): Index of the image to retrieve.

    Returns:
        (dict): Label dictionary with image and metadata.
    """
    label = deepcopy(self.labels[index])  # requires deepcopy() https://github.com/ultralytics/ultralytics/pull/1948
    label.pop("shape", None)  # shape is for rect, remove it
    label["img"], label["ori_shape"], label["resized_shape"] = self.load_image(index)
    label["ratio_pad"] = (
        label["resized_shape"][0] / label["ori_shape"][0],
        label["resized_shape"][1] / label["ori_shape"][1],
    )  # for evaluation
    if self.rect:
        label["rect_shape"] = self.batch_shapes[self.batch[index]]
    return self.update_labels_info(label)
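
A hedged sketch of inspecting the raw (pre-transform) sample this method produces, assuming dataset is a concrete BaseDataset subclass such as the PlainImageDataset sketched earlier:

```python
sample = dataset.get_image_and_label(0)
print(sample["img"].shape)       # loaded image as a numpy array (h, w, c)
print(sample["ori_shape"])       # original (h, w)
print(sample["resized_shape"])   # (h, w) after resizing
print(sample["ratio_pad"])       # per-axis resize ratios kept for evaluation
```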

get_img_files

get_img_files(img_path)

Read image files from the specified path.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| img_path | str \| List[str] | Path or list of paths to image directories or files. | required |

Returns:

| Type | Description |
|------|-------------|
| List[str] | List of image file paths. |

Raises:

| Type | Description |
|------|-------------|
| FileNotFoundError | If no images are found or the path doesn't exist. |

Source code in ultralytics/data/base.py
def get_img_files(self, img_path):
    """
    Read image files from the specified path.

    Args:
        img_path (str | List[str]): Path or list of paths to image directories or files.

    Returns:
        (List[str]): List of image file paths.

    Raises:
        FileNotFoundError: If no images are found or the path doesn't exist.
    """
    try:
        f = []  # image files
        for p in img_path if isinstance(img_path, list) else [img_path]:
            p = Path(p)  # os-agnostic
            if p.is_dir():  # dir
                f += glob.glob(str(p / "**" / "*.*"), recursive=True)
                # F = list(p.rglob('*.*'))  # pathlib
            elif p.is_file():  # file
                with open(p, encoding="utf-8") as t:
                    t = t.read().strip().splitlines()
                    parent = str(p.parent) + os.sep
                    f += [x.replace("./", parent) if x.startswith("./") else x for x in t]  # local to global path
                    # F += [p.parent / x.lstrip(os.sep) for x in t]  # local to global path (pathlib)
            else:
                raise FileNotFoundError(f"{self.prefix}{p} does not exist")
        im_files = sorted(x.replace("/", os.sep) for x in f if x.split(".")[-1].lower() in IMG_FORMATS)
        # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS])  # pathlib
        assert im_files, f"{self.prefix}No images found in {img_path}. {FORMATS_HELP_MSG}"
    except Exception as e:
        raise FileNotFoundError(f"{self.prefix}Error loading data from {img_path}\n{HELP_URL}") from e
    if self.fraction < 1:
        im_files = im_files[: round(len(im_files) * self.fraction)]  # retain a fraction of the dataset
    return im_files
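
The method accepts either a directory (searched recursively for files with a supported image extension) or a *.txt file listing one image path per line, where "./"-prefixed entries are resolved relative to the txt file's parent directory. A hedged sketch with placeholder paths, assuming dataset is an existing BaseDataset subclass instance:

```python
files_from_dir = dataset.get_img_files("datasets/coco8/images/train")    # recurse into a directory
files_from_txt = dataset.get_img_files(["datasets/coco/train2017.txt"])  # read paths from a txt file
```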

get_labels

get_labels()

Users can customize their own format here.

Note

Ensure output is a dictionary with the following keys:

dict(
    im_file=im_file,
    shape=shape,  # format: (height, width)
    cls=cls,
    bboxes=bboxes,  # xywh
    segments=segments,  # xy
    keypoints=keypoints,  # xy
    normalized=True,  # or False
    bbox_format="xyxy",  # or xywh, ltwh
)

Source code in ultralytics/data/base.py
def get_labels(self):
    """
    Users can customize their own format here.

    Note:
        Ensure output is a dictionary with the following keys:
        ```python
        dict(
            im_file=im_file,
            shape=shape,  # format: (height, width)
            cls=cls,
            bboxes=bboxes,  # xywh
            segments=segments,  # xy
            keypoints=keypoints,  # xy
            normalized=True,  # or False
            bbox_format="xyxy",  # or xywh, ltwh
        )
        ```
    """
    raise NotImplementedError

load_image

load_image(i, rect_mode=True)

Load an image from dataset index 'i'.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| i | int | Index of the image to load. | required |
| rect_mode | bool | Whether to use rectangular resizing. | True |

Returns:

| Type | Description |
|------|-------------|
| ndarray | Loaded image. |
| tuple | Original image dimensions (h, w). |
| tuple | Resized image dimensions (h, w). |

Raises:

| Type | Description |
|------|-------------|
| FileNotFoundError | If the image file is not found. |

Source code in ultralytics/data/base.py
def load_image(self, i, rect_mode=True):
    """
    Load an image from dataset index 'i'.

    Args:
        i (int): Index of the image to load.
        rect_mode (bool, optional): Whether to use rectangular resizing.

    Returns:
        (np.ndarray): Loaded image.
        (tuple): Original image dimensions (h, w).
        (tuple): Resized image dimensions (h, w).

    Raises:
        FileNotFoundError: If the image file is not found.
    """
    im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i]
    if im is None:  # not cached in RAM
        if fn.exists():  # load npy
            try:
                im = np.load(fn)
            except Exception as e:
                LOGGER.warning(f"{self.prefix}WARNING ⚠️ Removing corrupt *.npy image file {fn} due to: {e}")
                Path(fn).unlink(missing_ok=True)
                im = cv2.imread(f)  # BGR
        else:  # read image
            im = cv2.imread(f)  # BGR
        if im is None:
            raise FileNotFoundError(f"Image Not Found {f}")

        h0, w0 = im.shape[:2]  # orig hw
        if rect_mode:  # resize long side to imgsz while maintaining aspect ratio
            r = self.imgsz / max(h0, w0)  # ratio
            if r != 1:  # if sizes are not equal
                w, h = (min(math.ceil(w0 * r), self.imgsz), min(math.ceil(h0 * r), self.imgsz))
                im = cv2.resize(im, (w, h), interpolation=cv2.INTER_LINEAR)
        elif not (h0 == w0 == self.imgsz):  # resize by stretching image to square imgsz
            im = cv2.resize(im, (self.imgsz, self.imgsz), interpolation=cv2.INTER_LINEAR)

        # Add to buffer if training with augmentations
        if self.augment:
            self.ims[i], self.im_hw0[i], self.im_hw[i] = im, (h0, w0), im.shape[:2]  # im, hw_original, hw_resized
            self.buffer.append(i)
            if 1 < len(self.buffer) >= self.max_buffer_length:  # prevent empty buffer
                j = self.buffer.pop(0)
                if self.cache != "ram":
                    self.ims[j], self.im_hw0[j], self.im_hw[j] = None, None, None

        return im, (h0, w0), im.shape[:2]

    return self.ims[i], self.im_hw0[i], self.im_hw[i]
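
A hedged sketch comparing the two resize modes, assuming dataset is a concrete BaseDataset subclass built with imgsz=640 and augment=False (so the image is re-read on each call rather than served from the RAM cache):

```python
im, (h0, w0), (h, w) = dataset.load_image(0, rect_mode=True)   # long side scaled to 640, aspect ratio kept
im_sq, _, (hs, ws) = dataset.load_image(0, rect_mode=False)    # stretched to exactly 640x640
print((h0, w0), "->", (h, w), "vs", (hs, ws))
```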

set_rectangle

set_rectangle()

Set the shape of bounding boxes for YOLO detections as rectangles.

Source code in ultralytics/data/base.py
def set_rectangle(self):
    """Set the shape of bounding boxes for YOLO detections as rectangles."""
    bi = np.floor(np.arange(self.ni) / self.batch_size).astype(int)  # batch index
    nb = bi[-1] + 1  # number of batches

    s = np.array([x.pop("shape") for x in self.labels])  # hw
    ar = s[:, 0] / s[:, 1]  # aspect ratio
    irect = ar.argsort()
    self.im_files = [self.im_files[i] for i in irect]
    self.labels = [self.labels[i] for i in irect]
    ar = ar[irect]

    # Set training image shapes
    shapes = [[1, 1]] * nb
    for i in range(nb):
        ari = ar[bi == i]
        mini, maxi = ari.min(), ari.max()
        if maxi < 1:
            shapes[i] = [maxi, 1]
        elif mini > 1:
            shapes[i] = [1, 1 / mini]

    self.batch_shapes = np.ceil(np.array(shapes) * self.imgsz / self.stride + self.pad).astype(int) * self.stride
    self.batch = bi  # batch index of image
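
A worked illustration of the batch-shape rounding above, for a hypothetical batch in which every image is landscape with aspect ratio h/w = 0.75, using imgsz=640, stride=32, and pad=0.5 (so the batch shape chosen above is [maxi, 1] = [0.75, 1]):

```python
import numpy as np

imgsz, stride, pad = 640, 32, 0.5
shape = np.array([0.75, 1.0])  # (h, w) scale factors selected by set_rectangle for this batch
batch_shape = np.ceil(shape * imgsz / stride + pad).astype(int) * stride
print(batch_shape)  # [512 672]: images in this batch are trained at 512x672 instead of 640x640
```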

update_labels

update_labels(include_class: Optional[list])

Update labels to include only specified classes.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| include_class | list | List of classes to include. If None, all classes are included. | required |

Source code in ultralytics/data/base.py
def update_labels(self, include_class: Optional[list]):
    """
    Update labels to include only specified classes.

    Args:
        include_class (list, optional): List of classes to include. If None, all classes are included.
    """
    include_class_array = np.array(include_class).reshape(1, -1)
    for i in range(len(self.labels)):
        if include_class is not None:
            cls = self.labels[i]["cls"]
            bboxes = self.labels[i]["bboxes"]
            segments = self.labels[i]["segments"]
            keypoints = self.labels[i]["keypoints"]
            j = (cls == include_class_array).any(1)
            self.labels[i]["cls"] = cls[j]
            self.labels[i]["bboxes"] = bboxes[j]
            if segments:
                self.labels[i]["segments"] = [segments[si] for si, idx in enumerate(j) if idx]
            if keypoints is not None:
                self.labels[i]["keypoints"] = keypoints[j]
        if self.single_cls:
            self.labels[i]["cls"][:, 0] = 0

update_labels_info

update_labels_info(label)

Customize your label format here.

Source code in ultralytics/data/base.py
def update_labels_info(self, label):
    """Custom your label format here."""
    return label


