%pylab inline

import torch
from torch.utils.data import IterableDataset
from torchvision import transforms
import webdataset as wds
from itertools import islice
Populating the interactive namespace from numpy and matplotlib

Getting Started

WebDataset reads datasets that are stored as tar files, with the simple convention that files that belong together and make up a training sample share the same basename. WebDataset can read files from local disk or from any pipe, which allows it to access files using common cloud object stores.
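
To make the basename convention concrete, here is a minimal sketch that groups the members of a tar stream into samples using only the Python standard library. It is not the library's own reader; `make_demo_tar` and `group_by_basename` are hypothetical helpers, the archive contents are made up, and the sketch assumes the key is everything before the first dot in the file name.

```python
import io
import tarfile


def make_demo_tar():
    """Build a tiny in-memory tar archive holding two made-up samples."""
    buffer = io.BytesIO()
    with tarfile.open(fileobj=buffer, mode="w") as archive:
        for name, payload in [
            ("a001.jpg", b"fake-jpeg-bytes"),
            ("a001.json", b'{"label": 1}'),
            ("a002.jpg", b"more-fake-bytes"),
            ("a002.json", b'{"label": 2}'),
        ]:
            info = tarfile.TarInfo(name=name)
            info.size = len(payload)
            archive.addfile(info, io.BytesIO(payload))
    buffer.seek(0)
    return buffer


def group_by_basename(fileobj):
    """Yield one dict per sample; consecutive files with the same basename are merged."""
    current_key, sample = None, None
    # "r|" reads the archive as a forward-only stream, as a pipe would deliver it.
    with tarfile.open(fileobj=fileobj, mode="r|") as archive:
        for member in archive:
            if not member.isfile():
                continue
            directory, _, filename = member.name.rpartition("/")
            stem, _, extension = filename.partition(".")
            key = f"{directory}/{stem}" if directory else stem
            if key != current_key:
                if sample is not None:
                    yield sample
                current_key, sample = key, {"__key__": key}
            sample[extension] = archive.extractfile(member).read()
    if sample is not None:
        yield sample


samples = list(group_by_basename(make_demo_tar()))
for sample in samples:
    print(sorted(sample))
```

Because the grouping only looks at consecutive members, files that belong to one sample have to be stored next to each other in the archive, which is what the listing below shows.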

%%bash
curl -s http://storage.googleapis.com/nvdata-openimages/openimages-train-000000.tar | tar tf - | sed 10q
e39871fd9fd74f55.jpg
e39871fd9fd74f55.json
f18b91585c4d3f3e.jpg
f18b91585c4d3f3e.json
ede6e66b2fb59aab.jpg
ede6e66b2fb59aab.json
ed600d57fcee4f94.jpg
ed600d57fcee4f94.json
ff47e649b23f446d.jpg
ff47e649b23f446d.json
url = "http://storage.googleapis.com/nvdata-openimages/openimages-train-000000.tar"
url = f"pipe:curl -L -s {url} || true"

For starters, let's use the wds.WebDataset class to illustrate how the webdataset library works.

dataset = wds.WebDataset(url)

for sample in islice(dataset, 0, 3):
    for key, value in sample.items():
        print(key, repr(value)[:50])
    print()
__key__ 'e39871fd9fd74f55'
jpg b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x01\x01
json b'[{"ImageID": "e39871fd9fd74f55", "Source": "xcli

__key__ 'f18b91585c4d3f3e'
jpg b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00
json b'[{"ImageID": "f18b91585c4d3f3e", "Source": "acti

__key__ 'ede6e66b2fb59aab'
jpg b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x01\x00
json b'[{"ImageID": "ede6e66b2fb59aab", "Source": "acti

There are common processing stages you can add to a dataset to make it a drop-in replacement for any existing dataset. For convenience, common operations are available through a "fluent" interface (as chained method calls).

dataset = (
    wds.WebDataset(url)
    .shuffle(100)
    .decode("rgb")
    .to_tuple("jpg;png", "json")
)

for image, data in islice(dataset, 0, 3):
    print(image.shape, image.dtype, type(data))
(1024, 683, 3) float32 <class 'list'>
(660, 1024, 3) float32 <class 'list'>
(701, 1024, 3) float32 <class 'list'>
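
The "jpg;png" argument above lists alternative field names: the image is taken from whichever of the two extensions is present in the sample. The following sketch shows that selection logic on a plain dictionary. `pick_fields` is a hypothetical helper written for illustration, and the assumption that the first matching name wins is ours, not a statement about the library's internals.

```python
def pick_fields(sample, *specs):
    """Return a tuple with one value per spec; a spec is a ';'-separated list of alternatives."""
    values = []
    for spec in specs:
        for name in spec.split(";"):
            if name in sample:
                values.append(sample[name])
                break
        else:
            # No alternative matched, so the sample cannot be converted.
            raise KeyError(f"none of {spec!r} found in sample {sample.get('__key__')!r}")
    return tuple(values)


jpg_sample = {"__key__": "a001", "jpg": "image-a", "json": [1]}
png_sample = {"__key__": "a002", "png": "image-b", "json": [2]}

print(pick_fields(jpg_sample, "jpg;png", "json"))
print(pick_fields(png_sample, "jpg;png", "json"))
```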

The wds.WebDataset class has some common operations:

  • shuffle(n): shuffle the dataset with a buffer of size n; also shuffles shards (see below)
  • decode(decoder, ...): automatically decode files (most commonly, you can just specify "pil", "rgb", "rgb8", "rgbtorch", etc.)
  • rename(new="old1;old2", ...): rename fields
  • map(f): apply f to each sample
  • map_dict(key=f, ...): apply f to its corresponding key
  • map_tuple(f, g, ...): apply f, g, etc. to their corresponding values in the tuple
  • pipe(f): f should be a function that takes an iterator and returns a new iterator
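
As an illustration of two of these operations, here is a sketch of a buffer-based shuffle written as a function that takes an iterator and returns a new iterator, which is the shape of function that pipe(f) is described as expecting. This is a common way to shuffle a stream approximately with bounded memory; `buffered_shuffle` and its parameters are hypothetical names, and the library's own shuffle may differ in detail.

```python
import random


def buffered_shuffle(samples, bufsize=100, rng=None):
    """Approximately shuffle a stream while holding at most `bufsize` samples in memory."""
    rng = rng or random.Random()
    buffer = []
    for sample in samples:
        if len(buffer) < bufsize:
            # Fill the buffer before emitting anything.
            buffer.append(sample)
            continue
        # Swap the incoming sample with a random buffered one and emit the latter.
        index = rng.randrange(bufsize)
        buffer[index], sample = sample, buffer[index]
        yield sample
    # The input is exhausted: emit whatever is left in random order.
    rng.shuffle(buffer)
    yield from buffer


shuffled = list(buffered_shuffle(range(20), bufsize=5, rng=random.Random(0)))
print(shuffled)
```

A small buffer only mixes samples that arrive close together, so the buffer size trades memory for shuffle quality.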

Stages commonly take a handler= argument, which is a function that gets called when there is an exception; you can write whatever function you want, but common functions are:

  • webdataset.ignore_and_stop
  • webdataset.ignore_and_continue
  • webdataset.warn_and_stop
  • webdataset.warn_and_continue
  • webdataset.reraise_exception
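
A custom handler can be sketched as follows. The sketch assumes the convention suggested by the names above: the handler receives the exception and returns a true value to continue with the next sample or a false value to stop. `LimitedErrors` and `map_with_handler` are hypothetical and only mimic how a stage might consult its handler; they are not part of the library.

```python
class LimitedErrors:
    """Hypothetical handler: tolerate up to `max_errors` exceptions, then stop."""

    def __init__(self, max_errors=10):
        self.max_errors = max_errors
        self.seen = []

    def __call__(self, exn):
        self.seen.append(exn)
        # True means "skip this sample and continue", False means "stop iterating".
        return len(self.seen) <= self.max_errors


def map_with_handler(f, samples, handler):
    """Apply f to each sample, asking the handler what to do when f raises."""
    for sample in samples:
        try:
            result = f(sample)
        except Exception as exn:
            if handler(exn):
                continue
            break
        yield result


handler = LimitedErrors(max_errors=1)
parsed = list(map_with_handler(int, ["1", "x", "2", "y", "3"], handler))
print(parsed, len(handler.seen))
```

In this run the first bad value is skipped and the second one ends the iteration, so "3" is never reached.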

Data Augmentation

Here is an example that uses torchvision data augmentation the same way you might use it with a FileDataset.

def identity(x):
    return x

normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225])

preproc = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])

dataset = (
    wds.WebDataset(url)
    .shuffle(100)
    .decode("pil")
    .to_tuple("jpg;png", "json")
    .map_tuple(preproc, identity)
)

for image, data in islice(dataset, 0, 3):
    print(image.shape, image.dtype, type(data))
torch.Size([3, 224, 224]) torch.float32 <class 'list'>
torch.Size([3, 224, 224]) torch.float32 <class 'list'>
torch.Size([3, 224, 224]) torch.float32 <class 'list'>

WebDataset and DataLoader

When used with a standard Torch DataLoader, this dataset performs parallel I/O and preprocessing in the worker processes. However, the recommended way of using IterableDataset with DataLoader is to do the batching explicitly in the Dataset and to pass batch_size=None so that the DataLoader does not batch a second time:

batch_size = 20
dataloader = torch.utils.data.DataLoader(dataset.batched(batch_size), num_workers=4, batch_size=None)
images, targets = next(iter(dataloader))
images.shape
/home/tmb/proj/webdataset/docs/webdataset/dataset.py:85: UserWarning: num_workers 4 > num_shards 1
  warnings.warn(f"num_workers {num_workers} > num_shards {len(urls)}")

torch.Size([20, 3, 224, 224])
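
What batching inside the dataset amounts to can be sketched in plain Python: take a fixed number of (image, target) tuples from the stream and transpose them into one collection per field. The `batched` function below is a hypothetical stand-in for illustration; it returns lists and leaves out the step that stacks same-shaped items into tensors.

```python
from itertools import islice


def batched(samples, batch_size, partial=True):
    """Group a stream of tuples into batches, one list per tuple position."""
    iterator = iter(samples)
    while True:
        chunk = list(islice(iterator, batch_size))
        if not chunk or (len(chunk) < batch_size and not partial):
            return
        # zip(*chunk) transposes [(x1, y1), (x2, y2)] into (x1, x2), (y1, y2).
        yield tuple(list(column) for column in zip(*chunk))


stream = [(i, str(i)) for i in range(5)]
batches = list(batched(stream, batch_size=2))
print(batches)
```

Since each item coming out of the dataset is already a full batch, the DataLoader is given batch_size=None and only distributes the work across its workers.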

You can find the full PyTorch ImageNet sample code converted to WebDataset at tmbdev/pytorch-imagenet-wds