%pylab inline

import torch
from torch.utils.data import IterableDataset
from torchvision import transforms
import webdataset as wds
from itertools import islice
import numpy as np
Populating the interactive namespace from numpy and matplotlib

Writing Filters and Offline Augmentation

WebDataset can be used for writing filters and for offline augmentation of datasets. Here is a complete example that pre-augments a shard and extracts class labels.

def extract_class(data):
    # Mock implementation; a real version would derive the label
    # from the decoded JSON annotations.
    return 0

def augment(a):
    # Add a small amount of Gaussian noise to the image tensor.
    a += torch.randn_like(a) * 0.01
    return a

def augment_wds(url, output, maxcount=999999999):
    # Read the input shard, decode images to torch tensors, and augment them;
    # the key and the JSON annotations are passed through unchanged.
    src = (
        wds.WebDataset(url)
        .decode("torchrgb")
        .to_tuple("__key__", "jpg;png", "json")
        .map_tuple(lambda x: x, augment, lambda x: x)
    )
    with wds.TarWriter(output) as dst:
        for key, image, data in islice(src, 0, maxcount):
            print(key)
            # Convert the CHW tensor to an HWC numpy array and rescale it to
            # [0, 1] so that TarWriter can encode it as a PNG.
            image = image.numpy().transpose(1, 2, 0)
            image -= amin(image)
            image /= amax(image)
            sample = {
                "__key__": key,
                "png": image,
                "cls": extract_class(data),
            }
            dst.write(sample)

Now run the augmentation pipeline:

url = "http://storage.googleapis.com/nvdata-openimages/openimages-train-000000.tar"
url = f"pipe:curl -L -s {url} || true"
augment_wds(url, "_temp.tar", maxcount=5)
e39871fd9fd74f55
f18b91585c4d3f3e
ede6e66b2fb59aab
ed600d57fcee4f94
ff47e649b23f446d

To verify that things worked correctly, let's look at the output file:

%%bash
tar tf _temp.tar
e39871fd9fd74f55.cls
e39871fd9fd74f55.png
f18b91585c4d3f3e.cls
f18b91585c4d3f3e.png
ede6e66b2fb59aab.cls
ede6e66b2fb59aab.png
ed600d57fcee4f94.cls
ed600d57fcee4f94.png
ff47e649b23f446d.cls
ff47e649b23f446d.png
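
You can also read the new shard back with WebDataset to check that the images and class labels decode as expected. The following is a minimal sketch, assuming _temp.tar from the step above is still in the current directory:

check = (
    wds.WebDataset("_temp.tar")
    .decode("torchrgb")                 # decodes .png to a float tensor, .cls to an int
    .to_tuple("__key__", "png", "cls")
)

for key, image, cls in islice(check, 0, 3):
    print(key, image.shape, image.dtype, cls)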

If you want to preprocess the entire OpenImages dataset with a process like this, you can use your favorite job queueing or workflow system.

For example, using Dask, you could process all 554 shards in parallel using code like this:

import braceexpand
import dask

shards = list(braceexpand.braceexpand("{000000..000554}"))
inputs = [f"gs://bucket/openimages-{shard}.tar" for shard in shards]
outputs = [f"gs://bucket2/openimages-augmented-{shard}.tar" for shard in shards]
results = [dask.delayed(augment_wds)(inp, out) for inp, out in zip(inputs, outputs)]
dask.compute(*results)

Note that the data is streamed from and to Google Cloud Storage buckets, so very little local storage is required on each worker.
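
If you prefer to make the streaming explicit, you can also wrap gsutil in pipe: URLs and pass those to augment_wds. This is a minimal sketch; the bucket names and the helper are placeholders, just as in the Dask example above:

def augment_gcs_shard(shard):
    # Stream the input shard from GCS and stream the augmented output back,
    # so no shard is ever written to local disk.
    input_url = f"pipe:gsutil cat gs://bucket/openimages-{shard}.tar"
    output_url = f"pipe:gsutil cp - gs://bucket2/openimages-augmented-{shard}.tar"
    augment_wds(input_url, output_url)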

For very large-scale processing, it is easiest to submit separate jobs to a Kubernetes cluster using the Kubernetes Job template, or to use a workflow engine like Argo.
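
As a rough illustration of the Kubernetes route: assuming the augmentation code has been packaged into a container image (the image name and the augment.py entrypoint below are placeholders), you could submit one Job per shard with kubectl:

import subprocess
import braceexpand

for shard in braceexpand.braceexpand("{000000..000554}"):
    # One Kubernetes Job per shard; the image and script are hypothetical.
    subprocess.run(
        [
            "kubectl", "create", "job", f"augment-{shard}",
            "--image=my-registry/wds-augment",
            "--", "python3", "augment.py", shard,
        ],
        check=True,
    )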