%pylab inline
import io
from itertools import islice

import imageio
import torch
from torch.utils.data import IterableDataset
from torchvision import transforms
import webdataset as wds
Data Decoding
Data decoding is a special kind of sample transformation. You could simply write a decoding function like this:
def my_sample_decoder(sample):
    # Copy the sample key, then decode each field based on its extension.
    result = dict(__key__=sample["__key__"])
    for key, value in sample.items():
        if key == "png" or key.endswith(".png"):
            result[key] = imageio.imread(io.BytesIO(value))
        elif ...:
            ...
    return result

dataset = wds.Processor(dataset, wds.map, my_sample_decoder)
This gets tedious, though, and it also unnecessarily hardcodes the sample's keys into the processing pipeline. To help with this, the Decoder helper class simplifies this kind of code. The primary use of Decoder is decoding compressed image, video, and audio formats, as well as unzipping .gz files.
Here is an example of automatically decoding .png images with imread and using the default torch_video and torch_audio decoders for video and audio:
def my_png_decoder(key, value):
    # Returning None tells the Decoder that this handler does not apply.
    if not key.endswith(".png"):
        return None
    assert isinstance(value, bytes)
    return imageio.imread(io.BytesIO(value))

dataset = wds.Decoder(my_png_decoder, wds.torch_video, wds.torch_audio)(dataset)
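The handler convention used above (return a decoded value, or None to decline) can be illustrated without any libraries. The following is a minimal sketch, not the Decoder implementation: decode_sample, png_marker, and text_handler are hypothetical names, and the "first non-None result wins" rule is an assumption made for illustration.

```python
def png_marker(key, value):
    # Hypothetical handler: claims PNG fields only and declines everything else.
    if not (key == "png" or key.endswith(".png")):
        return None
    return ("image", len(value))


def text_handler(key, value):
    # Hypothetical handler: decodes text fields as UTF-8.
    if not (key == "txt" or key.endswith(".txt")):
        return None
    return value.decode("utf-8")


def decode_sample(sample, handlers):
    """Decode each bytes field with the first handler that returns non-None."""
    result = {}
    for key, value in sample.items():
        # Metadata such as "__key__" and non-bytes values pass through unchanged.
        if key.startswith("__") or not isinstance(value, bytes):
            result[key] = value
            continue
        for handler in handlers:
            decoded = handler(key, value)
            if decoded is not None:
                value = decoded
                break
        result[key] = value
    return result


sample = {"__key__": "000001", "png": b"\x89PNG", "txt": b"hello", "bin": b"\x00\x01"}
decoded = decode_sample(sample, [png_marker, text_handler])
```

In this sketch, fields that no handler claims (here "bin") are left as raw bytes, so a pipeline can mix decoded and undecoded fields.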
You can use whatever criteria you like for deciding how to decode values in samples. When used with standard WebDataset format files, the keys are the full extensions of the file names inside a .tar file. For consistency, it's recommended that you primarily rely on the extensions (e.g., .png, .mp4) to decide which decoders to use. There is a special helper function that simplifies this:
def my_decoder(value):
    # The handler receives only the raw bytes; the extension check is done by the wrapper.
    return imageio.imread(io.BytesIO(value))

dataset = wds.Decoder(wds.handle_extension(".png", my_decoder))(dataset)
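To make the idea of extension-based dispatch concrete, here is a small sketch of a wrapper that turns a value-only decoder into a (key, value) handler. It is an illustration under stated assumptions, not the library's handle_extension: match_extension is a hypothetical name, and the matching rule (compare whole dot-separated components, case-insensitively, ignoring a leading dot) is chosen for this example.

```python
def match_extension(extension, handler):
    """Wrap handler(value) so it only runs for keys ending in `extension`."""
    target = extension.lower().lstrip(".").split(".")

    def wrapped(key, value):
        parts = key.lower().split(".")
        # Compare whole components, so "left.png" matches "cam.left.png"
        # while "txt" does not match "atxt".
        if len(target) > len(parts) or parts[-len(target):] != target:
            return None  # decline, so another handler can try
        return handler(value)

    return wrapped


to_upper = match_extension(".txt", lambda value: value.decode("utf-8").upper())
left_only = match_extension("left.png", len)
```

Calling to_upper("caption.txt", b"hi") runs the wrapped decoder, while left_only("right.png", b"abc") returns None because the final two components do not match.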
If you want to "decode everything" automatically and even override some extensions, you can use something like:
url = "http://storage.googleapis.com/nvdata-openimages/openimages-train-000000.tar"
url = f"pipe:curl -L -s {url} || true"

def png_decoder_16bpp(data):
    # Takes only the raw bytes, like my_decoder above, since it is wrapped by handle_extension.
    ...

dataset = wds.WebDataset(url).decode(
    wds.handle_extension("left.png", png_decoder_16bpp),
    wds.handle_extension("right.png", png_decoder_16bpp),
    wds.imagehandler("torchrgb"),
    wds.torch_audio,
    wds.torch_video,
)
This code would...
- handle any file with a ".left.png" or ".right.png" extension using a special 16bpp PNG decoder function
- decode all other image extensions to three-channel Torch tensors
- decode audio files using the torchaudio library
- decode video files using the torchvision library
In order to decode images, audio, and video, it would dynamically load the Pillow, torchaudio, and torchvision libraries.
Automatic Decompression
The default decoder handles compressed files automatically. That is, a .json.gz file is first decompressed using the gzip library and then treated as if it had been called .json.
In other words, you can store compressed files directly in a WebDataset and decompression is handled for you automatically.
If you want to add your own decompressors, look at the implementation of webdataset.autodecode.gzfilter.
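The decompress-then-decode behavior described in this section can be sketched with the standard library alone. This is not the code in webdataset.autodecode.gzfilter; gunzip_field and decode_field are hypothetical helpers that only illustrate stripping the .gz suffix before choosing a decoder.

```python
import gzip
import json


def gunzip_field(key, value):
    """Return (key without ".gz", decompressed bytes), or the input unchanged."""
    if key.endswith(".gz"):
        return key[: -len(".gz")], gzip.decompress(value)
    return key, value


def decode_field(key, value):
    # Decompress first, then decide how to decode based on the remaining extension.
    key, value = gunzip_field(key, value)
    if key == "json" or key.endswith(".json"):
        return json.loads(value.decode("utf-8"))
    return value  # unknown extensions stay as raw bytes


payload = gzip.compress(json.dumps({"label": 3}).encode("utf-8"))
from_gz = decode_field("json.gz", payload)
from_plain = decode_field("json", b'{"label": 3}')
```

Both calls produce the same dictionary, which is the point of handling decompression as a separate step ahead of format decoding.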