%pylab inline
import torch
from torch.utils.data import IterableDataset
from torchvision import transforms
import webdataset as wds
from itertools import islice
Populating the interactive namespace from numpy and matplotlib
Getting Started
WebDataset reads datasets that are stored as tar files, with the simple convention that files that belong together and make up a training sample share the same basename. WebDataset can read files from local disk or from any pipe, which allows it to access files in common cloud object stores.
%%bash
curl -s http://storage.googleapis.com/nvdata-openimages/openimages-train-000000.tar | tar tf - | sed 10q
e39871fd9fd74f55.jpg
e39871fd9fd74f55.json
f18b91585c4d3f3e.jpg
f18b91585c4d3f3e.json
ede6e66b2fb59aab.jpg
ede6e66b2fb59aab.json
ed600d57fcee4f94.jpg
ed600d57fcee4f94.json
ff47e649b23f446d.jpg
ff47e649b23f446d.json
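Conceptually, the grouping convention works like this: files are read in sequence from the tar stream, and consecutive files sharing a basename are collected into one sample keyed by `__key__`. A simplified pure-Python sketch of that idea, using the filenames listed above (the helper name is illustrative, not the library's internals, and the real library splits names slightly differently and stores file contents rather than names):

```python
import os

def group_by_basename(filenames):
    """Group consecutive tar member names that share a basename
    (the path minus the last extension) into one sample dict."""
    samples = []
    current_key, current = None, {}
    for name in filenames:
        key, ext = os.path.splitext(name)
        if key != current_key:
            if current:
                samples.append(current)
            current_key, current = key, {"__key__": key}
        # the real WebDataset stores the file *contents* under this key
        current[ext.lstrip(".")] = name
    if current:
        samples.append(current)
    return samples

files = ["e39871fd9fd74f55.jpg", "e39871fd9fd74f55.json",
         "f18b91585c4d3f3e.jpg", "f18b91585c4d3f3e.json"]
for s in group_by_basename(files):
    print(s["__key__"], sorted(k for k in s if k != "__key__"))
```

This mirrors the structure of the samples printed by the iteration example below: one dict per basename, with one entry per extension plus `__key__`.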
url = "http://storage.googleapis.com/nvdata-openimages/openimages-train-000000.tar"
url = f"pipe:curl -L -s {url} || true"
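The pipe: prefix tells WebDataset to run the rest of the string as a shell command and read the tar stream from the command's standard output (the `|| true` suppresses curl's nonzero exit status when the reader closes the pipe early). The mechanism can be reproduced with `subprocess` alone; this sketch runs a trivial command in place of curl:

```python
import subprocess

# WebDataset strips the "pipe:" prefix and runs the remainder as a shell
# command, reading its stdout as the tar stream.
cmd = "echo hello"  # stand-in for: curl -L -s <url> || true
with subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE) as proc:
    data = proc.stdout.read()
print(data)  # b'hello\n'
```

Any command that writes a tar archive to stdout works the same way, which is why the pipe scheme covers arbitrary object stores and transports.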
For starters, let's use the webdataset.WebDataset class to illustrate how the webdataset library works.
dataset = wds.WebDataset(url)
for sample in islice(dataset, 0, 3):
for key, value in sample.items():
print(key, repr(value)[:50])
print()
__key__ 'e39871fd9fd74f55'
jpg b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x01\x01
json b'[{"ImageID": "e39871fd9fd74f55", "Source": "xcli
__key__ 'f18b91585c4d3f3e'
jpg b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00
json b'[{"ImageID": "f18b91585c4d3f3e", "Source": "acti
__key__ 'ede6e66b2fb59aab'
jpg b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x01\x00
json b'[{"ImageID": "ede6e66b2fb59aab", "Source": "acti
There are common processing stages you can add to a dataset to make it a drop-in replacement for any existing dataset. For convenience, common operations are available through a "fluent" interface (as chained method calls).
dataset = (
wds.WebDataset(url)
.shuffle(100)
.decode("rgb")
.to_tuple("jpg;png", "json")
)
for image, data in islice(dataset, 0, 3):
print(image.shape, image.dtype, type(data))
(1024, 683, 3) float32 <class 'list'>
(660, 1024, 3) float32 <class 'list'>
(701, 1024, 3) float32 <class 'list'>
The webdataset.WebDataset class has some common operations:

- shuffle(n): shuffle the dataset with a buffer of size n; also shuffles shards (see below)
- decode(decoder, ...): automatically decode files (most commonly, you can just specify "pil", "rgb", "rgb8", "rgbtorch", etc.)
- rename(new="old1;old2", ...): rename fields
- map(f): apply f to each sample
- map_dict(key=f, ...): apply f to the value stored under the corresponding key
- map_tuple(f, g, ...): apply f, g, etc. to the corresponding values in the tuple
- pipe(f): f should be a function that takes an iterator and returns a new iterator
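Note that shuffle(n) does not load the whole dataset into memory; it keeps a rolling buffer of n samples and emits a random element as each new sample streams in, giving approximate shuffling in bounded memory. A pure-Python sketch of that idea (not the library's exact implementation):

```python
import random

def buffered_shuffle(iterable, bufsize, rng=random):
    """Approximate streaming shuffle: keep up to `bufsize` items buffered,
    emit a randomly chosen one each time a new item arrives."""
    buf = []
    for item in iterable:
        buf.append(item)
        if len(buf) >= bufsize:
            yield buf.pop(rng.randrange(len(buf)))
    # drain the remaining buffered items in random order
    while buf:
        yield buf.pop(rng.randrange(len(buf)))

rng = random.Random(0)
print(list(buffered_shuffle(range(10), 4, rng)))
```

Larger buffers approximate a full shuffle more closely at the cost of memory; that is the trade-off the n argument controls.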
Stages commonly take a handler= argument, which is a function that gets called when there is an exception; you can write whatever function you want, but common functions are:

- webdataset.ignore_and_stop
- webdataset.ignore_and_continue
- webdataset.warn_and_stop
- webdataset.warn_and_continue
- webdataset.reraise_exception
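A handler receives the exception and signals whether to skip the offending sample and continue (True) or stop iterating (False); reraise_exception simply re-raises. A minimal sketch of that contract, with a hypothetical pipeline stage showing how a handler might be invoked (illustrative only, not the library's source):

```python
import warnings

def warn_and_continue(exn):
    """Sketch of webdataset.warn_and_continue: warn, then skip the sample."""
    warnings.warn(repr(exn))
    return True   # True => skip this sample and keep iterating

def ignore_and_stop(exn):
    return False  # False => quietly stop the pipeline

def mapped(source, f, handler=warn_and_continue):
    """Hypothetical stage: apply f to each sample, deferring errors to handler."""
    for sample in source:
        try:
            yield f(sample)
        except Exception as exn:
            if handler(exn):
                continue  # handler chose to skip
            break         # handler chose to stop

# The failing middle sample (division by zero) is skipped with a warning.
result = list(mapped([1, 0, 3], lambda x: 10 // x))
print(result)  # [10, 3]
```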
Data Augmentation
Here is an example that uses torchvision data augmentation the same way you might use it with a FileDataset.
def identity(x):
return x
normalize = transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
preproc = transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
])
dataset = (
wds.WebDataset(url)
.shuffle(100)
.decode("pil")
.to_tuple("jpg;png", "json")
.map_tuple(preproc, identity)
)
for image, data in islice(dataset, 0, 3):
print(image.shape, image.dtype, type(data))
torch.Size([3, 224, 224]) torch.float32 <class 'list'>
torch.Size([3, 224, 224]) torch.float32 <class 'list'>
torch.Size([3, 224, 224]) torch.float32 <class 'list'>
WebDataset and DataLoader
When used with a standard Torch DataLoader, this would perform parallel I/O and preprocessing. However, the recommended way of using an IterableDataset with DataLoader is to do the batching explicitly in the dataset:
batch_size = 20
dataloader = torch.utils.data.DataLoader(dataset.batched(batch_size), num_workers=4, batch_size=None)
images, targets = next(iter(dataloader))
images.shape
/home/tmb/proj/webdataset/docs/webdataset/dataset.py:85: UserWarning: num_workers 4 > num_shards 1
warnings.warn(f"num_workers {num_workers} > num_shards {len(urls)}")
torch.Size([20, 3, 224, 224])
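Batching inside the dataset means each worker assembles complete batches and the DataLoader (with batch_size=None) just forwards them, avoiding a second collation pass. The core of what .batched() does can be sketched in pure Python (collation of samples into tensors is omitted here; assuming the real method also accepts a collation function):

```python
def batched(source, batch_size, partial=True):
    """Group samples from an iterator into lists of length batch_size.
    If partial is True, a final short batch is also emitted."""
    batch = []
    for sample in source:
        batch.append(sample)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch and partial:
        yield batch

print(list(batched(range(7), 3)))  # [[0, 1, 2], [3, 4, 5], [6]]
```

Emitting or dropping the final partial batch matters for distributed training, where all workers usually need batches of identical size.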
You can find the full PyTorch ImageNet sample code converted to WebDataset at tmbdev/pytorch-imagenet-wds.