Sharding, Parallel I/O, and DataLoader

WebDataset datasets are usually split into many shards; this serves two purposes: it permits parallel I/O, and it allows data to be shuffled at the shard level.

import torch
from torch.utils.data import IterableDataset
from torchvision import transforms
import webdataset as wds
from itertools import islice

Sets of shards can be given as a list of files, or they can be written using brace notation, as in openimages-train-{000000..000554}.tar. For example, the OpenImages dataset consists of 555 shards (numbered 000000 through 000554), each containing about 1 Gbyte of images. You can open the entire dataset as follows (note the explicit use of both shardshuffle=True, for shuffling the shards, and the .shuffle processor, for shuffling samples inline).

url = "http://storage.googleapis.com/nvdata-openimages/openimages-train-{000000..000554}.tar"
url = f"pipe:curl -L -s {url} || true"

normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225])

preproc = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])

dataset = (
    wds.WebDataset(url, shardshuffle=True)
    .shuffle(100)
    .decode("pil")
    .to_tuple("jpg;png", "json")
    .map_tuple(preproc)
)

x, y = next(iter(dataset))
print(x.shape, str(y)[:50])
torch.Size([3, 224, 224]) [{'ImageID': '19a7594f418fe39e', 'Source': 'xclick
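
As mentioned above, the shards can also be passed as an explicit Python list of URLs instead of brace notation. A small sketch of that variant, listing only the first two shards for illustration:

# Equivalent to brace notation: pass the expanded shard URLs as a list.
# Only the first two shards are listed here, for illustration.
urls = [
    f"pipe:curl -L -s http://storage.googleapis.com/nvdata-openimages/openimages-train-{i:06d}.tar || true"
    for i in range(2)
]
dataset2 = wds.WebDataset(urls, shardshuffle=True)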

When used with a standard Torch DataLoader, this would perform parallel I/O and preprocessing. However, the recommended way of using IterableDataset with DataLoader is to do the batching explicitly in the Dataset:

batch_size = 20
dataloader = torch.utils.data.DataLoader(dataset.batched(batch_size), num_workers=4, batch_size=None)
images, targets = next(iter(dataloader))
images.shape
torch.Size([20, 3, 224, 224])
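
Since the loader has no length, islice (imported above) is a convenient way to bound a loop over it. A small sketch, assuming the default collation leaves the JSON annotations as a list:

# Pull a few batches; islice bounds the loop without requiring
# the IterableDataset to have a length.
for i, (images, targets) in enumerate(islice(dataloader, 3)):
    print(i, images.shape, len(targets))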

Explicit Dataset Sizes

Ideally, you shouldn't use len(dataset) or len(loader) at all in your training loop; however, some code may still call len(.). WebDataset propagates such calls back through the chain of dataset processors. IterableDataset implementations generally don't have a size, but you can specify one explicitly using the length= argument to WebDataset.
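
For example (a minimal sketch of the length= argument described above; the value is a nominal sample count that is taken on faith, not verified against the shards):

# Declare an explicit nominal size so len(dataset) works downstream;
# the number is asserted, not checked against the actual shards.
sized_dataset = wds.WebDataset(url, shardshuffle=True, length=100000)
print(len(sized_dataset))  # -> 100000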

You can also use the ResizedDataset class to force an IterableDataset to have a specific epoch length and (if desired) set a separate nominal epoch length.
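
A minimal sketch of that usage; the keyword names length= and nominal= are assumptions based on the description above and may differ between webdataset versions:

# Sketch only: force an actual epoch of 10000 samples while reporting a
# separate nominal length of 100000 through len(.); keywords are assumed.
resized = wds.ResizedDataset(dataset, length=10000, nominal=100000)
print(len(resized))  # expected to report the nominal epoch length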