WebDataset + Distributed PyTorch Training

This notebook illustrates how to use the WebDataset (wds) library for distributed PyTorch training using DistributedDataParallel.

Training code with WebDataset is almost identical to plain PyTorch, except for the dataset creation. Since WebDataset is an iterable dataset, you need to account for that when creating the DataLoader. Furthermore, for distributed training, easy restarts, etc., it is convenient to use a resampled dataset; this is in contrast to sampling without replacement for each epoch, as is more common for small, local training. (If you want sampling without replacement with WebDataset-format datasets, see the companion wids-based training notebooks.)
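
To make the distinction concrete, here is a minimal sketch (the URL is a placeholder, not one of the datasets used below): with resampled=True the dataset draws shards with replacement and yields an effectively infinite stream, whose epoch length is imposed later with with_epoch(); without it, each shard is visited once per epoch.

import webdataset as wds

shards = "https://example.com/dataset-{000000..000099}.tar"  # placeholder URL

# Resampled: shards are drawn with replacement; the stream never ends on its
# own, so a nominal epoch length is imposed later (e.g. with .with_epoch()).
train_ds = wds.WebDataset(shards, resampled=True, shardshuffle=True)

# Default: sampling without replacement; each shard is read once per epoch.
eval_ds = wds.WebDataset(shards, shardshuffle=False)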

Training with WebDataset can be carried out entirely without local storage; this is the usual setup in the cloud and on high-speed compute clusters. When running locally on a desktop, you may want to cache the data; for that, set a cache_dir directory.

import os
import sys
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torchvision.models import resnet50
from torchvision import datasets, transforms
import ray
import webdataset as wds
import dataclasses
import time
from collections import deque
from typing import Optional


def enumerate_report(seq, delta, growth=1.0):
    """Like enumerate(), but additionally yield a flag that is True at most
    once every `delta` seconds; the interval is multiplied by `growth` after
    each item, so reporting can become progressively less frequent."""
    last = 0
    for count, item in enumerate(seq):
        now = time.time()
        if now - last > delta:
            last = now
            yield count, item, True
        else:
            yield count, item, False
        delta *= growth
# Parameters
epochs = 10
max_steps = int(1e12)
batch_size = 32

Data Loading for Distributed Training

# Datasets are just collections of shards in the cloud. We usually specify
# them using {lo..hi} brace notation (there is also a YAML spec for more complex
# datasets).

bucket = "https://storage.googleapis.com/webdataset/fake-imagenet"
trainset_url = bucket + "/imagenet-train-{000000..001281}.tar"
valset_url = bucket + "/imagenet-val-{000000..000049}.tar"
batch_size = 32
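
# Optional sanity check (illustrative only): the brace pattern can be expanded
# explicitly with the braceexpand package, which webdataset uses for shard specs.
from braceexpand import braceexpand

print(len(list(braceexpand(trainset_url))), "training shards")  # expect 1282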
# If running in the cloud or with a fast network storage system, we don't
# need any local storage.

if "google.colab" in sys.modules:
    cache_dir = None
    print("running on colab, streaming data directly from storage")
else:
    cache_dir = "./_cache"
    print(f"not running in colab, caching data locally in {cache_dir}")
# The dataloader pipeline is a fairly typical `IterableDataset` pipeline
# for PyTorch


def make_dataloader_train():
    """Create a DataLoader for training on the ImageNet dataset using WebDataset."""

    transform = transforms.Compose(
        [
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]
    )

    def make_sample(sample):
        return transform(sample["jpg"]), sample["cls"]

    # This is the basic WebDataset definition: it starts with a URL and add shuffling,
    # decoding, and augmentation. Note `resampled=True`; this is essential for
    # distributed training to work correctly.
    trainset = wds.WebDataset(
        trainset_url,
        resampled=True,
        shardshuffle=True,
        cache_dir=cache_dir,
        nodesplitter=wds.split_by_node,
    )
    trainset = trainset.shuffle(1000).decode("pil").map(make_sample)

    # For IterableDataset objects, batching needs to happen inside the dataset.
    # The worker-side batch size (64) is independent of the final batch_size;
    # samples are rebatched below after mixing across workers.
    trainset = trainset.batched(64)
    trainloader = wds.WebLoader(trainset, batch_size=None, num_workers=4)

    # We unbatch, shuffle, and rebatch to mix samples from different workers.
    trainloader = trainloader.unbatched().shuffle(1000).batched(batch_size)

    # A resampled dataset is infinite, so we impose a fixed, nominal epoch length.
    trainloader = trainloader.with_epoch(1282 * 100 // 64)

    return trainloader
# Let's try it out


def make_dataloader(split="train"):
    """Make a dataloader for training or validation."""
    if split == "train":
        return make_dataloader_train()
    elif split == "val":
        return make_dataloader_val()  # sketched below; not exercised in this notebook
    else:
        raise ValueError(f"unknown split {split}")
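
For reference, here is a minimal sketch of what make_dataloader_val could look like; it is not part of the original notebook, and the center-crop preprocessing and worker count are just plausible defaults.

def make_dataloader_val():
    """Create a DataLoader for validation (sketch; not exercised in this notebook)."""
    transform = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
        ]
    )

    def make_sample(sample):
        return transform(sample["jpg"]), sample["cls"]

    # No resampling for validation: every sample should be seen exactly once.
    valset = wds.WebDataset(
        valset_url, shardshuffle=False, cache_dir=cache_dir, nodesplitter=wds.split_by_node
    )
    valset = valset.decode("pil").map(make_sample).batched(batch_size)
    return wds.WebLoader(valset, batch_size=None, num_workers=4)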


# Try it out.
os.environ["GOPEN_VERBOSE"] = "1"
sample = next(iter(make_dataloader()))
print(sample[0].shape, sample[1].shape)
os.environ["GOPEN_VERBOSE"] = "0"

Standard PyTorch Training

This is completely standard PyTorch training; nothing changes by using WebDataset.

# We gather all the configuration info into a single typed dataclass.


@dataclasses.dataclass
class Config:
    epochs: int = 1
    max_steps: int = int(1e18)
    batch_size: int = 32
    lr: float = 0.001
    momentum: float = 0.9
    rank: Optional[int] = None
    world_size: int = 2
    backend: str = "nccl"
    master_addr: str = "localhost"
    master_port: str = "12355"
    report_s: float = 15.0
    report_growth: float = 1.1
def train(config):
    # Define the model, loss function, and optimizer
    model = resnet50(pretrained=False).cuda()
    if config.rank is not None:
        model = DistributedDataParallel(model)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=config.lr)

    # Data loading code
    trainloader = make_dataloader(split="train")

    losses, accuracies, steps = deque(maxlen=100), deque(maxlen=100), 0

    # Training loop
    for epoch in range(config.epochs):
        for i, data, verbose in enumerate_report(
            trainloader, config.report_s, growth=config.report_growth
        ):
            inputs, labels = data[0].cuda(), data[1].cuda()

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward pass
            outputs = model(inputs)

            # update statistics
            loss = loss_fn(outputs, labels)
            accuracy = (outputs.argmax(1) == labels).float().mean()  # batch accuracy
            losses.append(loss.item())
            accuracies.append(accuracy.item())

            if verbose and len(losses) > 0:
                avgloss = sum(losses) / len(losses)
                avgaccuracy = sum(accuracies) / len(accuracies)
                print(
                    f"rank {config.rank} epoch {epoch:5d}/{i:9d} loss {avgloss:8.3f} acc {avgaccuracy:8.3f} {steps:9d}",
                    file=sys.stderr,
                )
            loss.backward()
            optimizer.step()
            steps += len(labels)
            if steps > config.max_steps:
                print(
                    "finished training (max_steps)",
                    steps,
                    config.max_steps,
                    file=sys.stderr,
                )
                return

    print("finished Training", steps)
# A quick smoke test of the training function.

config = Config()
config.epochs = 1
config.max_steps = 1000
train(config)

Setting up Distributed Training with Ray

Ray is a convenient distributed computing framework. We are using it here to start the training jobs on multiple GPUs. You could also launch the above code with torchrun (or the older torch.distributed.launch). Ray has the advantage of being independent of the runtime environment: you set up your Ray cluster in whatever way works for your environment, and afterwards this code runs in it without change.
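
For comparison, here is a minimal sketch of a torchrun-style entry point (the file name train_ddp.py is hypothetical, and Config and train are assumed to be importable from the code above). torchrun sets RANK, LOCAL_RANK, WORLD_SIZE, and the master address/port in the environment, so the script only needs to read them:

# train_ddp.py (hypothetical) -- launch with:
#   torchrun --nproc_per_node=2 train_ddp.py
import os
import torch
import torch.distributed as dist


def main():
    # torchrun exports RANK, LOCAL_RANK, and WORLD_SIZE for each process.
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    torch.cuda.set_device(local_rank)
    # MASTER_ADDR and MASTER_PORT are also set by torchrun, so the default
    # env:// initialization works without further configuration.
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)

    config = Config()
    config.rank = rank
    config.world_size = world_size
    train(config)

    dist.destroy_process_group()


if __name__ == "__main__":
    main()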

@ray.remote(num_gpus=1)
def train_on_ray(rank, config):
    """Set up distributed torch env and train the model on this node."""
    # Set up distributed PyTorch.
    if rank is not None:
        os.environ["MASTER_ADDR"] = config.master_addr
        os.environ["MASTER_PORT"] = config.master_port
        dist.init_process_group(
            backend=config.backend, rank=rank, world_size=config.world_size
        )
        config.rank = rank
        # Ray will automatically set CUDA_VISIBLE_DEVICES for each task.
    train(config)
if not ray.is_initialized():
    ray.init()

ray.available_resources()["GPU"]


def distributed_training(config):
    """Perform distributed training with the given config."""
    num_gpus = int(ray.available_resources()["GPU"])
    config.world_size = min(config.world_size, num_gpus)
    results = ray.get(
        [train_on_ray.remote(i, config) for i in range(config.world_size)]
    )
    print(results)


config = Config()
config.epochs = epochs
config.max_steps = max_steps
config.batch_size = batch_size
print(config)
distributed_training(config)