Skip to content

Reference for ultralytics/data/build.py

Note

Full source code for this file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/data/build.py. Help us fix any issues you see by submitting a Pull Request 🛠️. Thank you 🙏!


ultralytics.data.build.InfiniteDataLoader

Bases: DataLoader

DataLoader that reuses workers. Uses the same syntax as the vanilla DataLoader.

Source code in ultralytics/data/build.py
class InfiniteDataLoader(dataloader.DataLoader):
    """
    DataLoader that reuses its workers across epochs; constructed with the same arguments as a vanilla DataLoader.

    The batch sampler is wrapped in `_RepeatSampler` so the single iterator created at construction never runs dry.
    Each pass over the loader then pulls exactly `len(loader)` batches from that shared iterator instead of starting
    a new one.
    """

    def __init__(self, *args, **kwargs):
        """Build the DataLoader, make its batch sampler repeat forever, then create the persistent iterator."""
        super().__init__(*args, **kwargs)
        # DataLoader.__setattr__ (in PyTorch) refuses to reassign 'batch_sampler' after initialization, hence the bypass.
        repeating_sampler = _RepeatSampler(self.batch_sampler)
        object.__setattr__(self, 'batch_sampler', repeating_sampler)
        # Must come after the swap so the iterator draws from the repeating batch sampler.
        self.iterator = super().__iter__()

    def __len__(self):
        """Return the length of the original batch sampler (the one wrapped by `_RepeatSampler`)."""
        original_batch_sampler = self.batch_sampler.sampler
        return len(original_batch_sampler)

    def __iter__(self):
        """Yield one epoch of batches, i.e. `len(self)` items drawn from the persistent iterator."""
        remaining = len(self)
        while remaining > 0:
            remaining -= 1
            yield next(self.iterator)

    def reset(self):
        """
        Replace the persistent iterator with a freshly created one.

        This is useful when dataset settings are modified during training.
        """
        fresh_iterator = self._get_iterator()  # DataLoader's private iterator factory
        self.iterator = fresh_iterator

__init__(*args, **kwargs)

Initializes a DataLoader that infinitely recycles its workers; inherits from DataLoader.

Source code in ultralytics/data/build.py
def __init__(self, *args, **kwargs):
    """
    Dataloader that infinitely recycles workers, inherits from DataLoader.

    Builds the underlying DataLoader, wraps its batch sampler in a _RepeatSampler so it never runs out, then creates
    the single iterator that all later epochs draw from.
    """
    super().__init__(*args, **kwargs)
    # Bypass DataLoader.__setattr__, which (in PyTorch) refuses to reassign 'batch_sampler' after initialization.
    object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
    # Created after the swap so the iterator draws from the repeating batch sampler.
    self.iterator = super().__iter__()

__iter__()

Yields one epoch of batches (len(self) of them) from the persistent, infinitely repeating iterator.

Source code in ultralytics/data/build.py
def __iter__(self):
    """Creates a sampler that repeats indefinitely."""
    for _ in range(len(self)):
        yield next(self.iterator)

__len__()

Returns the length of the batch sampler's sampler.

Source code in ultralytics/data/build.py
def __len__(self):
    """Returns the length of the batch sampler's sampler."""
    return len(self.batch_sampler.sampler)

reset()

Resets the iterator. This is useful when dataset settings are modified during training.

Source code in ultralytics/data/build.py
def reset(self):
    """
    Discard the current iterator and start a fresh one.

    This is useful when dataset settings are modified during training.
    """
    fresh_iterator = self._get_iterator()  # DataLoader's private iterator factory
    self.iterator = fresh_iterator




ultralytics.data.build._RepeatSampler

Sampler that repeats forever.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| sampler | sampler | The sampler to repeat. | required |
Source code in ultralytics/data/build.py
class _RepeatSampler:
    """
    Sampler that repeats forever.

    Args:
        sampler (Dataset.sampler): The sampler to repeat.
    """

    def __init__(self, sampler):
        """Initializes an object that repeats a given sampler indefinitely."""
        self.sampler = sampler

    def __iter__(self):
        """Iterates over the 'sampler' and yields its contents."""
        while True:
            yield from iter(self.sampler)

__init__(sampler)

Initializes an object that repeats a given sampler indefinitely.

Source code in ultralytics/data/build.py
def __init__(self, sampler):
    """Initializes an object that repeats a given sampler indefinitely."""
    self.sampler = sampler

__iter__()

Iterates over the 'sampler' and yields its contents.

Source code in ultralytics/data/build.py
def __iter__(self):
    """Iterates over the 'sampler' and yields its contents."""
    while True:
        yield from iter(self.sampler)




ultralytics.data.build.seed_worker(worker_id)

Sets the dataloader worker seed (see https://pytorch.org/docs/stable/notes/randomness.html#dataloader).

Source code in ultralytics/data/build.py
def seed_worker(worker_id):  # noqa
    """
    Seed NumPy and Python `random` inside a dataloader worker from torch's initial seed.

    See https://pytorch.org/docs/stable/notes/randomness.html#dataloader. `worker_id` is accepted to match the
    `worker_init_fn` signature but is not used.
    """
    # The legacy np.random.seed only accepts seeds in [0, 2**32 - 1]; torch's seed may be wider.
    seed = torch.initial_seed() % (1 << 32)
    for rng in (np.random, random):
        rng.seed(seed)




ultralytics.data.build.build_yolo_dataset(cfg, img_path, batch, data, mode='train', rect=False, stride=32)

Builds a YOLO dataset.

Source code in ultralytics/data/build.py
def build_yolo_dataset(cfg, img_path, batch, data, mode='train', rect=False, stride=32):
    """
    Build a YOLODataset configured from `cfg` for the given mode.

    Args:
        cfg: Config object; reads imgsz, rect, cache, single_cls, task, classes and fraction, and is also passed
            through as the dataset's `hyp`.
        img_path: Image path(s) forwarded to YOLODataset.
        batch (int): Batch size.
        data: Dataset configuration forwarded as `data`.
        mode (str): 'train' enables augmentation, zero padding and `cfg.fraction`; any other mode uses 0.5 padding
            and the full dataset.
        rect (bool): Request rectangular batches (OR-ed with `cfg.rect`).
        stride (int | float): Model stride, cast to int.

    Returns:
        (YOLODataset): The constructed dataset.
    """
    training = mode == 'train'
    options = dict(
        img_path=img_path,
        imgsz=cfg.imgsz,
        batch_size=batch,
        augment=training,  # augmentation only when training
        hyp=cfg,  # TODO: probably add a get_hyps_from_cfg function
        rect=cfg.rect or rect,  # rectangular batches
        cache=cfg.cache or None,
        single_cls=cfg.single_cls or False,
        stride=int(stride),
        pad=0.0 if training else 0.5,
        prefix=colorstr(f'{mode}: '),
        use_segments=cfg.task == 'segment',
        use_keypoints=cfg.task == 'pose',
        classes=cfg.classes,
        data=data,
        fraction=cfg.fraction if training else 1.0,
    )
    return YOLODataset(**options)




ultralytics.data.build.build_dataloader(dataset, batch, workers, shuffle=True, rank=-1)

Returns an InfiniteDataLoader for the training or validation set.

Source code in ultralytics/data/build.py
def build_dataloader(dataset, batch, workers, shuffle=True, rank=-1):
    """
    Return an InfiniteDataLoader for a training or validation set.

    Args:
        dataset: Dataset to load. Must support len(); its `collate_fn` attribute, if present, is used for batching.
        batch (int): Requested batch size; clipped to len(dataset).
        workers (int): Upper bound on the number of dataloader worker processes.
        shuffle (bool): Whether to shuffle samples. When rank != -1, shuffling is delegated to the DistributedSampler.
        rank (int): Distributed rank of this process; -1 disables distributed sampling.

    Returns:
        (InfiniteDataLoader): A dataloader that reuses its iterator and workers across epochs.
    """
    batch = min(batch, len(dataset))
    nd = torch.cuda.device_count()  # number of CUDA devices
    # os.cpu_count() returns None when the CPU count cannot be determined; fall back to 1 instead of a TypeError
    cpus = os.cpu_count() or 1
    # Share CPUs across devices; batches of size <= 1 use 0 workers (loading happens in the main process)
    nw = min(cpus // max(nd, 1), batch if batch > 1 else 0, workers)  # number of workers
    sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
    generator = torch.Generator()
    generator.manual_seed(6148914691236517205 + RANK)  # fixed, rank-dependent seed for reproducible shuffling
    return InfiniteDataLoader(dataset=dataset,
                              batch_size=batch,
                              shuffle=shuffle and sampler is None,
                              num_workers=nw,
                              sampler=sampler,
                              pin_memory=PIN_MEMORY,
                              collate_fn=getattr(dataset, 'collate_fn', None),
                              worker_init_fn=seed_worker,
                              generator=generator)




ultralytics.data.build.check_source(source)

Checks the source type and returns the (possibly converted) source together with the corresponding flag values.

Source code in ultralytics/data/build.py
def check_source(source):
    """
    Check source type and return corresponding flag values.

    Args:
        source (str | int | Path | list | tuple | PIL.Image.Image | np.ndarray | torch.Tensor): Inference source, or an
            instance of one of the LOADERS types. Ints (local camera indices) and Paths are converted to str,
            lists/tuples are converted with autocast_list, and remote image/video file URLs are downloaded.

    Returns:
        (tuple): (source, webcam, screenshot, from_img, in_memory, tensor): the possibly converted source followed by
            five booleans describing its type.

    Raises:
        TypeError: If the source type is not supported.
    """
    webcam, screenshot, from_img, in_memory, tensor = False, False, False, False, False
    if isinstance(source, (str, int, Path)):  # int for local usb camera
        source = str(source)
        source_lower = source.lower()
        # Compare the extension case-insensitively so e.g. 'https://host/img.JPG' is treated as a file, not a stream
        # (assumes IMG_FORMATS / VID_FORMATS hold lowercase extensions)
        is_file = Path(source).suffix[1:].lower() in (IMG_FORMATS + VID_FORMATS)
        is_url = source_lower.startswith(('https://', 'http://', 'rtsp://', 'rtmp://'))
        webcam = source.isnumeric() or source.endswith('.streams') or (is_url and not is_file)
        screenshot = source_lower == 'screen'
        if is_url and is_file:
            source = check_file(source)  # download
    elif isinstance(source, LOADERS):
        in_memory = True
    elif isinstance(source, (list, tuple)):
        source = autocast_list(source)  # convert all list elements to PIL or np arrays
        from_img = True
    elif isinstance(source, (Image.Image, np.ndarray)):
        from_img = True
    elif isinstance(source, torch.Tensor):
        tensor = True
    else:
        raise TypeError('Unsupported image type. For supported types see https://docs.ultralytics.com/modes/predict')

    return source, webcam, screenshot, from_img, in_memory, tensor




ultralytics.data.build.load_inference_source(source=None, imgsz=640, vid_stride=1, buffer=False)

Loads an inference source for object detection and applies necessary transformations.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| source | (str, Path, Tensor, Image, ndarray) | The input source for inference. | None |
| imgsz | int | The size of the image for inference. Default is 640. | 640 |
| vid_stride | int | The frame interval for video sources. Default is 1. | 1 |
| buffer | bool | Determines whether stream frames will be buffered. Default is False. | False |

Returns:

| Name | Type | Description |
|------|------|-------------|
| dataset | Dataset | A dataset object for the specified input source. |

Source code in ultralytics/data/build.py
def load_inference_source(source=None, imgsz=640, vid_stride=1, buffer=False):
    """
    Loads an inference source for object detection and applies necessary transformations.

    Args:
        source (str, Path, Tensor, PIL.Image, np.ndarray): The input source for inference.
        imgsz (int, optional): The size of the image for inference. Default is 640.
        vid_stride (int, optional): The frame interval for video sources. Default is 1.
        buffer (bool, optional): Determines whether stream frames will be buffered. Default is False.

    Returns:
        dataset (Dataset): A dataset object for the specified input source, with a `source_type` attribute attached.

    Raises:
        TypeError: If `check_source` does not support the type of `source` (including the default None).
    """
    source, webcam, screenshot, from_img, in_memory, tensor = check_source(source)
    # In-memory loaders already carry their own source_type; otherwise derive it from the flags
    source_type = source.source_type if in_memory else SourceTypes(webcam, screenshot, from_img, tensor)

    # Dataloader: the first matching flag wins; plain file/directory paths fall through to LoadImages
    if tensor:
        dataset = LoadTensor(source)
    elif in_memory:
        dataset = source
    elif webcam:
        dataset = LoadStreams(source, imgsz=imgsz, vid_stride=vid_stride, buffer=buffer)
    elif screenshot:
        dataset = LoadScreenshots(source, imgsz=imgsz)
    elif from_img:
        dataset = LoadPilAndNumpy(source, imgsz=imgsz)
    else:
        dataset = LoadImages(source, imgsz=imgsz, vid_stride=vid_stride)

    # Attach source types to the dataset
    dataset.source_type = source_type

    return dataset




Created 2023-07-16, Updated 2023-08-07
Authors: glenn-jocher (5), Laughing-q (1)