WebDataset + Distributed PyTorch Training
This notebook illustrates how to use the WebDataset (wds) library for distributed PyTorch training using DistributedDataParallel.
Using webdataset results in training code that is almost identical to plain PyTorch, except for the dataset creation.
Since WebDataset is an iterable dataset, you need to account for that when creating the DataLoader. Furthermore, for distributed training, easy restarts, etc., it is convenient to use a resampled dataset; this is in contrast to the per-epoch sampling without replacement that is more common for small, local training runs. (If you want to use sampling without replacement with webdataset-format datasets, see the companion wids-based training notebooks.)
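To make the IterableDataset point concrete, here is a minimal sketch (the shard URL and pipeline stages are placeholders, not the loader used below): batching happens inside the dataset pipeline, the loader is created with batch_size=None, and resampled=True makes every rank draw shards with replacement so all ranks see an endless, statistically identical stream.
import webdataset as wds

# Hypothetical shard pattern, for illustration only.
urls = "https://example.com/shards/data-{000000..000099}.tar"

# resampled=True: shards are drawn with replacement (infinite stream, same behavior on every rank).
dataset = (
    wds.WebDataset(urls, resampled=True, shardshuffle=True)
    .decode("pil")
    .to_tuple("jpg", "cls")
    .batched(32)
)

# The dataset already yields batches, so no batching happens in the loader.
loader = wds.WebLoader(dataset, batch_size=None, num_workers=2)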
Training with WebDataset can be carried out completely without local storage; this is the usual setup in the cloud and on high-speed compute clusters. When running locally on a desktop, you may want to cache the data; for that, you set a cache_dir directory.
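A minimal sketch of the caching knob (the URL below is a placeholder): when cache_dir is set, each shard is downloaded once into that directory and re-read from the local copy afterwards; with cache_dir=None, shards are streamed and nothing is written to disk.
import webdataset as wds

# Hypothetical URL; with cache_dir set, shards are fetched once and then reused from ./_cache.
dataset = wds.WebDataset(
    "https://example.com/shards/data-{000000..000099}.tar",
    shardshuffle=True,
    cache_dir="./_cache",  # use cache_dir=None to stream without local storage
)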
import os
import sys
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torchvision.models import resnet50
from torchvision import datasets, transforms
import ray
import webdataset as wds
import dataclasses
import time
from collections import deque
from typing import Optional
def enumerate_report(seq, delta, growth=1.0):
    """Like enumerate, but also yield a flag that is True at most every `delta` seconds (scaled by `growth`)."""
    last = 0
    count = 0
    for count, item in enumerate(seq):
        now = time.time()
        if now - last > delta:
            last = now
            yield count, item, True
        else:
            yield count, item, False
        delta *= growth
# Parameters
epochs = 10
max_steps = int(1e12)
batch_size = 32
Data Loading for Distributed Training
# Datasets are just collections of shards in the cloud. We usually specify
# them using {lo..hi} brace notation (there is also a YAML spec for more complex
# datasets).
bucket = "https://storage.googleapis.com/webdataset/fake-imagenet"
trainset_url = bucket + "/imagenet-train-{000000..001281}.tar"
valset_url = bucket + "/imagenet-val-{000000..000049}.tar"
batch_size = 32
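The brace pattern is expanded on the client into an explicit list of shard URLs; the braceexpand package, which webdataset uses internally, lets you inspect what a pattern covers. A quick sketch (the slice of the pattern is arbitrary):
import braceexpand

# Expand a small slice of the training pattern into individual shard URLs.
shards = list(braceexpand.braceexpand(bucket + "/imagenet-train-{000000..000003}.tar"))
print(len(shards))  # 4 shard URLs
print(shards[0])    # .../imagenet-train-000000.tar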
# If running in the cloud or with a fast network storage system, we don't
# need any local storage.
if "google.colab" in sys.modules:
cache_dir = None
print("running on colab, streaming data directly from storage")
else:
cache_dir = "./_cache"
print(f"not running in colab, caching data locally in {cache_dir}")
# The dataloader pipeline is a fairly typical `IterableDataset` pipeline
# for PyTorch
def make_dataloader_train():
    """Create a DataLoader for training on the ImageNet dataset using WebDataset."""
    transform = transforms.Compose(
        [
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]
    )

    def make_sample(sample):
        return transform(sample["jpg"]), sample["cls"]

    # This is the basic WebDataset definition: it starts with a URL and adds shuffling,
    # decoding, and augmentation. Note `resampled=True`; this is essential for
    # distributed training to work correctly.
    trainset = wds.WebDataset(
        trainset_url,
        resampled=True,
        shardshuffle=True,
        cache_dir=cache_dir,
        nodesplitter=wds.split_by_node,
    )
    trainset = trainset.shuffle(1000).decode("pil").map(make_sample)

    # For IterableDataset objects, the batching needs to happen in the dataset.
    trainset = trainset.batched(64)
    trainloader = wds.WebLoader(trainset, batch_size=None, num_workers=4)

    # We unbatch, shuffle, and rebatch to mix samples from different workers.
    trainloader = trainloader.unbatched().shuffle(1000).batched(batch_size)

    # A resampled dataset is infinite size, but we can recreate a fixed epoch length.
    trainloader = trainloader.with_epoch(1282 * 100 // 64)
    return trainloader
# Let's try it out
def make_dataloader(split="train"):
    """Make a dataloader for training or validation."""
    if split == "train":
        return make_dataloader_train()
    elif split == "val":
        return make_dataloader_val()  # not part of this notebook; a sketch follows below
    else:
        raise ValueError(f"unknown split {split}")
# Try it out.
os.environ["GOPEN_VERBOSE"] = "1"
sample = next(iter(make_dataloader()))
print(sample[0].shape, sample[1].shape)
os.environ["GOPEN_VERBOSE"] = "0"
Standard PyTorch Training
This is completely standard PyTorch training; nothing changes by using WebDataset.
# We gather all the configuration info into a single typed dataclass.
@dataclasses.dataclass
class Config:
    epochs: int = 1
    max_steps: int = int(1e18)
    batch_size: int = 32
    lr: float = 0.001
    momentum: float = 0.9
    rank: Optional[int] = None
    world_size: int = 2
    backend: str = "nccl"
    master_addr: str = "localhost"
    master_port: str = "12355"
    report_s: float = 15.0
    report_growth: float = 1.1
def train(config):
    # Define the model, loss function, and optimizer
    model = resnet50(pretrained=False).cuda()
    if config.rank is not None:
        model = DistributedDataParallel(model)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=config.lr, momentum=config.momentum)

    # Data loading code
    trainloader = make_dataloader(split="train")

    losses, accuracies, steps = deque(maxlen=100), deque(maxlen=100), 0

    # Training loop
    for epoch in range(config.epochs):
        for i, data, verbose in enumerate_report(trainloader, config.report_s):
            inputs, labels = data[0].cuda(), data[1].cuda()

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward pass
            outputs = model(inputs)

            # update statistics
            loss = loss_fn(outputs, labels)
            accuracy = (outputs.argmax(1) == labels).float().mean()
            losses.append(loss.item())
            accuracies.append(accuracy.item())
            if verbose and len(losses) > 0:
                avgloss = sum(losses) / len(losses)
                avgaccuracy = sum(accuracies) / len(accuracies)
                print(
                    f"rank {config.rank} epoch {epoch:5d}/{i:9d} loss {avgloss:8.3f} acc {avgaccuracy:8.3f} {steps:9d}",
                    file=sys.stderr,
                )

            # backward + optimize
            loss.backward()
            optimizer.step()

            steps += len(labels)
            if steps > config.max_steps:
                print(
                    "finished training (max_steps)", steps, config.max_steps,
                    file=sys.stderr,
                )
                return

    print("finished training", steps)
# A quick smoke test of the training function.
config = Config()
config.epochs = 1
config.max_steps = 1000
train(config)
Setting up Distributed Training with Ray
Ray is a convenient distributed computing framework. We are using it here to start the training jobs on multiple GPUs. You can also use torchrun (or the older torch.distributed.launch) or similar tools with the above code. Ray has the advantage that it is independent of the runtime environment: you set up your Ray cluster in whatever way works for your environment, and afterwards this code will run in it without change.
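For reference, here is a minimal sketch of what a torchrun-based launch of the same train() function could look like; the script layout and the launch command are assumptions, not part of this notebook. torchrun exports RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT to each process it starts.
# Hypothetical standalone script, e.g. launched with:
#   torchrun --nproc_per_node=2 train_ddp.py
import os
import torch.distributed as dist

def main():
    # torchrun sets these environment variables for every worker process.
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    # MASTER_ADDR/MASTER_PORT are already exported, so env:// initialization works.
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    config = Config()
    config.rank = rank
    config.world_size = world_size
    train(config)

if __name__ == "__main__":
    main()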
@ray.remote(num_gpus=1)
def train_on_ray(rank, config):
    """Set up the distributed torch environment and train the model on this node."""
    # Set up distributed PyTorch.
    if rank is not None:
        os.environ["MASTER_ADDR"] = config.master_addr
        os.environ["MASTER_PORT"] = config.master_port
        dist.init_process_group(
            backend=config.backend, rank=rank, world_size=config.world_size
        )
        config.rank = rank
    # Ray will automatically set CUDA_VISIBLE_DEVICES for each task.
    train(config)
if not ray.is_initialized():
    ray.init()
ray.available_resources()["GPU"]
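As an optional sanity check (not part of the original notebook), you can verify that Ray hands each num_gpus=1 task exactly one GPU via CUDA_VISIBLE_DEVICES:
@ray.remote(num_gpus=1)
def show_visible_gpus():
    # Ray sets CUDA_VISIBLE_DEVICES per task according to the GPUs it assigns.
    return os.environ.get("CUDA_VISIBLE_DEVICES")

# Example (requires two free GPUs): print(ray.get([show_visible_gpus.remote() for _ in range(2)]))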
def distributed_training(config):
    """Perform distributed training with the given config."""
    # ray.available_resources() reports GPU counts as floats, so cast before using range().
    num_gpus = int(ray.available_resources()["GPU"])
    config.world_size = min(config.world_size, num_gpus)
    results = ray.get(
        [train_on_ray.remote(i, config) for i in range(config.world_size)]
    )
    print(results)
config = Config()
config.epochs = epochs
config.max_steps = max_steps
config.batch_size = batch_size
print(config)
distributed_training(config)