import webdataset as wds
import braceexpand
from torch.utils.data import IterableDataset
from webdataset import gopen
Local and Remote Storage URLs
WebDataset refers to data sources using file paths or URLs. The following are all valid ways of referring to a data source:
dataset = wds.WebDataset("dataset-000.tar")
dataset = wds.WebDataset("file:dataset-000.tar")
dataset = wds.WebDataset("http://server/dataset-000.tar")
An additional way of referring to data is using the pipe: scheme, so the following is also equivalent to the above references:
dataset = wds.WebDataset("pipe:cat dataset-000.tar")
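Conceptually, a pipe: URL names a shell command whose standard output is the shard's byte stream. The following is a minimal sketch of that idea, not webdataset's own implementation; the function name open_pipe_url is hypothetical, and the sketch leaves checking the command's exit status to the caller.

```python
import subprocess


def open_pipe_url(url, bufsize=8192):
    """Hypothetical helper: start the command in a "pipe:" URL and return the process.

    The caller reads the shard bytes from proc.stdout and should call proc.wait()
    afterwards to find out whether the command succeeded.
    """
    prefix = "pipe:"
    if not url.startswith(prefix):
        raise ValueError(f"not a pipe: URL: {url!r}")
    command = url[len(prefix):]
    # shell=True because the URL carries a whole shell command line,
    # e.g. "cat dataset-000.tar" or "gsutil cat gs://somebucket/dataset-000.tar"
    return subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, bufsize=bufsize)
```

A caller would hand proc.stdout to a tar reader (for example tarfile.open(fileobj=proc.stdout, mode="r|")) and then call proc.wait() to detect a failed command.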
You can use the same notation for accessing data in cloud storage:
dataset = wds.WebDataset("pipe:gsutil cat gs://somebucket/dataset-000.tar")
Note that access to standard web schemes is implemented using curl. That is, http://server/dataset.tar is internally treated like pipe:curl -s -L 'http://server/dataset.tar'. Using curl to access Internet protocols is actually more efficient than using the built-in http library, because the download (including name resolution) runs in a separate curl process, asynchronously with respect to the Python code reading the data.
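The rewrite described above is purely textual. Here is a minimal sketch, with a hypothetical function name, that produces the equivalent pipe: URL. It quotes the URL with shlex.quote, which adds quotes only when the URL contains shell-special characters; the exact quoting webdataset uses internally may differ.

```python
import shlex


def web_url_to_pipe(url):
    """Hypothetical helper: rewrite an http(s) URL into an equivalent "pipe:curl" URL."""
    if not url.startswith(("http://", "https://")):
        raise ValueError(f"not an http(s) URL: {url!r}")
    # -s: silent (no progress meter), -L: follow redirects
    return "pipe:curl -s -L " + shlex.quote(url)


print(web_url_to_pipe("http://server/dataset.tar"))
# pipe:curl -s -L http://server/dataset.tar
print(web_url_to_pipe("https://server/dataset.tar?token=a&b=1"))
# pipe:curl -s -L 'https://server/dataset.tar?token=a&b=1'
```

Quoting matters as soon as a URL carries a query string: an unquoted & would otherwise be interpreted by the shell.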
File opening is handled by webdataset.gopen.gopen. This is a small function that just wraps standard Python file I/O and pipe capabilities.
You can define handlers for new schemes or override implementations for existing schemes by adding entries to the gopen_schemes dictionary in webdataset.gopen:
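from webdataset.gopen import gopen_schemes

def gopen_gs(url, mode="rb", bufsize=8192, **kw):
    ...  # return a binary file-like stream for the gs:// URL (body elided)

gopen_schemes["gs"] = gopen_gs  # "gs:" URLs are now opened by gopen_gs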
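As one possible way to fill in such a handler, here is a minimal sketch that shells out to gsutil, using the same commands that appear elsewhere on this page (gsutil cat for reading, gsutil cp - for writing). It assumes gsutil is installed and authenticated, and assumes a handler is called with the full URL, the mode and the buffer size and must return a binary file-like object. Unlike a production handler, it neither checks gsutil's exit status nor waits for an upload to finish. The helper gs_command is hypothetical.

```python
import subprocess


def gs_command(url, mode="rb"):
    """Hypothetical helper: build the gsutil command line for reading or writing a gs:// URL."""
    if mode.startswith("r"):
        return ["gsutil", "cat", url]  # stream the object to stdout
    if mode.startswith("w"):
        return ["gsutil", "cp", "-", url]  # upload whatever arrives on stdin
    raise ValueError(f"unsupported mode: {mode!r}")


def gopen_gs(url, mode="rb", bufsize=8192, **kw):
    """Sketch of a gs: scheme handler that returns a pipe to or from gsutil."""
    cmd = gs_command(url, mode)
    if mode.startswith("r"):
        proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, bufsize=bufsize)
        return proc.stdout  # the caller reads the object's bytes from here
    proc = subprocess.Popen(cmd, stdin=subprocess.PIPE, bufsize=bufsize)
    return proc.stdin  # the caller writes bytes here; closing it ends the upload
```

It would be registered in the same way as above, with gopen_schemes["gs"] = gopen_gs.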
Standard Input/Output
For the following examples, assume that we have a program called image-classifier that takes a WebDataset containing just JPEG files as input and produces a WebDataset containing JPEG files and their corresponding classifications in JSON format:
image-classifier input-shard.tar --output=output-shard.tar --model=some-model.pth
As a special case, the string "-" refers to standard input (reading) or standard output (writing). This allows code using WebDataset to be used as part of pipes, which is useful, for example, inside Kubernetes containers with limited local storage. Assume that you store shards in Google Cloud Storage and access them with gsutil. Using "-", you can simply write:
gsutil cat gs://input-bucket/data-000174.tar | image-classifier - -o - | gsutil cp - gs://output-bucket/output-000174.tar
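Inside a program like image-classifier, supporting "-" mostly comes down to two things: mapping it to sys.stdin.buffer or sys.stdout.buffer, and processing the tar archive as a stream (tarfile's "r|" and "w|" modes), since pipes cannot seek. The following is a minimal sketch under those assumptions. The names open_stream and classify_shard are hypothetical, and the classify callback stands in for the real model.

```python
import io
import json
import sys
import tarfile


def open_stream(path, mode):
    """Map "-" to stdin (for reading) or stdout (for writing); otherwise open the file."""
    if path == "-":
        return sys.stdin.buffer if "r" in mode else sys.stdout.buffer
    return open(path, mode)


def classify_shard(src, dst, classify):
    """Copy each .jpg from tar stream `src` to tar stream `dst`, adding a .json result.

    `classify` is a hypothetical callback (image bytes -> JSON-serializable result)
    standing in for the real model. Stream modes "r|" and "w|" never seek,
    so `src` and `dst` may be pipes.
    """
    with tarfile.open(fileobj=src, mode="r|") as tin, tarfile.open(fileobj=dst, mode="w|") as tout:
        for member in tin:
            if not member.isfile() or not member.name.endswith(".jpg"):
                continue
            data = tin.extractfile(member).read()
            tout.addfile(member, io.BytesIO(data))  # pass the image through unchanged
            payload = json.dumps(classify(data)).encode("utf-8")
            info = tarfile.TarInfo(member.name[: -len(".jpg")] + ".json")  # foo.jpg -> foo.json
            info.size = len(payload)
            tout.addfile(info, io.BytesIO(payload))


# A command-line entry point could then call, for example:
# classify_shard(open_stream("-", "rb"), open_stream("-", "wb"), my_model)
```

Because the output .json entry is written right after its .jpg, both files of a sample stay adjacent in the output tar, which is the layout WebDataset expects when grouping files into samples.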
It's also useful to create shards on the fly using tar and extract the result immediately; this lets you use shard-based programs directly on individual files. For example, for the image-classifier program above, you can write:
tar cf - *.jpg | image-classifier - -o - | tar xvf - --wildcards '*.json'
This is the rough equivalent of:
for fname in *.jpg; do
    image-classifier "$fname" > "$(basename "$fname" .jpg).json"
done
Multiple Shards and Mixing Datasets
The WebDataset and ShardList classes take either a string or a list of strings as an argument. When given a string, the string is expanded using braceexpand. Therefore, the following three datasets are equivalent:
dataset = wds.WebDataset(["dataset-000.tar", "dataset-001.tar", "dataset-002.tar", "dataset-003.tar"])
dataset = wds.WebDataset("dataset-{000..003}.tar")
dataset = wds.WebDataset("file:dataset-{000..003}.tar")
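To see what the string form expands to, here is a simplified stand-in for braceexpand (expand_shards is a hypothetical name). It handles only ascending numeric ranges such as {000..003} and keeps their zero padding; real code should call braceexpand.braceexpand, which also supports comma lists and other forms.

```python
import re

# Matches a numeric brace range such as {000..003}.
RANGE = re.compile(r"\{(\d+)\.\.(\d+)\}")


def expand_shards(spec):
    """Expand ascending numeric {a..b} ranges, zero-padding to the width of `a`."""
    m = RANGE.search(spec)
    if m is None:
        return [spec]
    start, end = m.group(1), m.group(2)
    prefix, suffix = spec[:m.start()], spec[m.end():]
    # Recurse on the suffix so that several ranges in one spec are also expanded.
    return [
        prefix + str(i).zfill(len(start)) + rest
        for i in range(int(start), int(end) + 1)
        for rest in expand_shards(suffix)
    ]


print(expand_shards("dataset-{000..003}.tar"))
# ['dataset-000.tar', 'dataset-001.tar', 'dataset-002.tar', 'dataset-003.tar']
```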
For complex training problems, you may want to mix multiple datasets, where each dataset consists of multiple shards. A good way is to expand each shard spec individually using braceexpand and concatenate the lists. Then you can pass the resulting list as an argument to WebDataset.
urls = (
    list(braceexpand.braceexpand("imagenet-{000000..000146}.tar")) +
    list(braceexpand.braceexpand("openimages-{000000..000547}.tar")) +
    list(braceexpand.braceexpand("custom-images-{000000..000999}.tar"))
)
print(len(urls))  # 1695 = 147 + 548 + 1000 shards
# shuffle the shard order, shuffle samples in a 10000-sample buffer, decode images to RGB torch tensors
dataset = wds.WebDataset(urls, shardshuffle=True).shuffle(10000).decode("torchrgb")
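Because the shards of all three datasets are pooled and shuffled together, each dataset's share of an epoch is proportional to its number of samples; assuming shards of roughly equal size, that is its share of shards. The arithmetic below uses the shard counts read off the brace ranges above and suggests that custom-images would supply about 59% of the samples and imagenet under 9%, which is one reason to control the mix more deliberately, as in the next section.

```python
# Shard counts taken from the brace ranges above, e.g. {000000..000146} -> 147 shards.
shard_counts = {"imagenet": 147, "openimages": 548, "custom-images": 1000}
total = sum(shard_counts.values())

# Expected fraction of an epoch drawn from each dataset, assuming equally sized shards.
fractions = {name: count / total for name, count in shard_counts.items()}

print(total)  # 1695
for name, fraction in fractions.items():
    print(f"{name}: {fraction:.1%}")
# imagenet: 8.7%, openimages: 32.3%, custom-images: 59.0%
```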
Mixing Datasets with a Custom IterableDataset Class
For more complex sampling problems, you can also write sample processors. For example, to sample equally from several datasets, you could write something like this (the Shorthands and Composable base classes just add some convenience methods):
class SampleEqually(IterableDataset, wds.Shorthands, wds.Composable):
    def __init__(self, datasets):
        super().__init__()
        self.datasets = datasets

    def __iter__(self):
        # Round-robin: take one sample from each dataset in turn.
        sources = [iter(ds) for ds in self.datasets]
        while True:
            for source in sources:
                try:
                    yield next(source)
                except StopIteration:
                    # Stop as soon as any dataset runs out, keeping the mix balanced.
                    return
Now we can mix samples from different sources in more complex ways:
dataset1 = wds.WebDataset("imagenet-{000000..000146}.tar", shardshuffle=True).shuffle(1000).decode("torchrgb")
dataset2 = wds.WebDataset("openimages-{000000..000547}.tar", shardshuffle=True).shuffle(1000).decode("torchrgb")
dataset3 = wds.WebDataset("custom-images-{000000..000999}.tar", shardshuffle=True).shuffle(1000).decode("torchrgb")
dataset = SampleEqually([dataset1, dataset2, dataset3]).shuffle(1000)
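The interleaving order of SampleEqually can be checked without PyTorch or WebDataset. The generator below is a plain-Python sketch of the same round-robin logic (sample_equally is a hypothetical name). It also returns immediately when given no sources, a case in which the while loop in the class above would spin forever without yielding.

```python
def sample_equally(*iterables):
    """Yield one item from each iterable in turn; stop when any of them is exhausted."""
    sources = [iter(it) for it in iterables]
    if not sources:
        return  # nothing to interleave (avoids an endless empty loop)
    while True:
        for source in sources:
            try:
                yield next(source)
            except StopIteration:
                return


print(list(sample_equally("abc", [1, 2, 3, 4], "xy")))
# ['a', 1, 'x', 'b', 2, 'y', 'c', 3]
```

The output shows the trade-off of this design: once the shortest source ends, the remaining samples of the longer sources (here the 4) are not used in that pass.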