%pylab inline
import torch
import webdataset as wds
import braceexpand
Populating the interactive namespace from numpy and matplotlib
Splitting Shards across Nodes and Workers
Unlike traditional PyTorch Dataset
instances, WebDataset
splits data across nodes at the shard level, not at the sample level.
This functionality is handled inside the ShardList
class. Recall that dataset = webdataset.Webdataset(urls)
is just a shorthand for:
urls = list(braceexpand.braceexpand("dataset-{000000..000999}.tar"))
dataset = wds.ShardList(urls, splitter=wds.split_by_worker, nodesplitter=wds.split_by_node, shuffle=False)
dataset = wds.Processor(dataset, wds.url_opener)
dataset = wds.Processor(dataset, wds.tar_file_expander)
dataset = wds.Processor(dataset, wds.group_by_keys)
Here, nodesplitter
and splitter
are functions that are called inside ShardList
to split up the URLs in urls
by node and worker. You can use any functions you like there, all they need to do is take a list of URLs and return a subset of those URLs as a result.
The default split_by_worker
looks roughly like:
def my_split_by_worker(urls):
wi = torch.utils.data.get_worker_info()
if wi is None:
return urls
else:
return urls[wi.id::wi.num_workers]
The same approach works for multiple worker nodes:
def my_split_by_node(urls):
node_id, node_count = torch.distributed.get_rank(), torch.distributed.get_world_size()
return urls[node_id::node_count]
dataset = wds.WebDataset(urls, splitter=my_split_by_worker, nodesplitter=my_split_by_node)
Of course, you can also create more complex splitting strategies if necessary.
DistributedDataParallel
DistributedDataParallel training requires that each participating node receive exactly the same number of training batches as all others. The ddp_equalize
method ensures this:
urls = "./shards/imagenet-train-{000000..001281}.tar"
dataset_size, batch_size = 1282000, 64
dataset = wds.WebDataset(urls).decode("pil").shuffle(5000).batched(batch_size, partial=False)
loader = wds.WebLoader(dataset, num_workers=4)
loader = loader.ddp_equalize(dataset_size // batch_size)
You need to give the total number of batches in your dataset to ddp_equalize
; it will compute the batches per node from this and equalize batches accordingly.
You need to apply ddp_equalize
to the WebLoader
rather than the Dataset
.