WebIndexedDataset + Distributed PyTorch Training

This notebook illustrates how to use the Web Indexed Dataset (wids) library for distributed PyTorch training using DistributedDataParallel.

Using wids results in training code that is almost identical to plain PyTorch, with the only changes being the use of ShardListDataset for the dataset construction, and the use of the DistributedChunkedSampler for generating random samples from the dataset.

ShardListDataset requires some local storage. By default, that local storage just grows as shards are downloaded, but if you have limited space, you can run create_cleanup_background_process to clean up the cache; shards will be re-downloaded as necessary.

import os
import sys
from typing import (
    List,
    Tuple,
    Dict,
    Optional,
    Any,
    Union,
    Callable,
    Iterable,
    Iterator,
    NamedTuple,
    Set,
    Sequence,
)
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torchvision.models import resnet50
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import ray
import wids
import dataclasses
import time
from collections import deque
from pprint import pprint


def enumerate_report(seq, delta, growth=1.0):
    """Enumerate ``seq`` and flag the items at which a progress report is due.

    Yields ``(index, item, report)`` triples. ``report`` is True when more
    than ``delta`` seconds of wall-clock time (``time.time()``) have passed
    since the previously flagged item. The reference time starts at 0, so the
    first item is flagged whenever ``delta`` is smaller than the current
    ``time.time()`` value.

    Args:
        seq: Any iterable.
        delta: Minimum number of seconds between two flagged items.
        growth: Factor by which ``delta`` is multiplied after each flagged
            item, so reports become progressively rarer when ``growth > 1``.

    Yields:
        Tuples ``(index, item, report)``.
    """
    last = 0
    for count, item in enumerate(seq):
        now = time.time()
        report = now - last > delta
        if report:
            last = now
            # Stretch the interval only when a report was actually issued.
            # Scaling on every item would tie the interval to the item count
            # instead of the number of reports and silence reporting almost
            # immediately for any growth > 1.
            delta *= growth
        yield count, item, report

Data Loading for Distributed Training

The datasets we use for training are stored in the cloud. We use fake-imagenet, which is 1/10th the size of ImageNet and artificially generated, but it has the same number of shards and trains quickly.

Note that unlike the webdataset library, wids always needs a local cache directory (it will use /tmp if you don't give it anything explicitly).

# Parameters
# Training length and batch size used by the cells below.
epochs = 1
max_steps = 10**12
batch_size = 32
# Location of the dataset index files and of the local shard cache.
bucket = "https://storage.googleapis.com/webdataset/fake-imagenet/"
trainset_url = f"{bucket}imagenet-train.json"
valset_url = f"{bucket}imagenet-val.json"
cache_dir = "./_cache"
# This is a typical PyTorch dataset, except that we read from the cloud.


def make_dataset_train():
    """Build the training dataset: a ShardListDataset with augmentation.

    The shard index is read from ``trainset_url`` and shards are cached under
    the module-level ``cache_dir``. Each raw sample is mapped to an
    ``(image_tensor, label)`` pair.

    Returns:
        The ``wids.ShardListDataset`` with the sample transform attached.
    """
    transform_train = transforms.Compose(
        [
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            # Per-channel mean/std normalization.
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )

    def make_sample(sample):
        # Samples are dict-like, keyed by file extension.
        # NOTE(review): presumably the ".jpg" entry is already decoded into
        # an image that torchvision transforms accept -- confirm.
        image = sample[".jpg"]
        label = sample[".cls"]
        return transform_train(image), label

    # Use the shared cache_dir setting rather than repeating the literal path,
    # so changing the parameter actually moves the cache.
    # NOTE(review): keep=True presumably keeps downloaded shards in the cache;
    # confirm against the wids documentation.
    trainset = wids.ShardListDataset(trainset_url, cache_dir=cache_dir, keep=True)
    trainset = trainset.add_transform(make_sample)

    return trainset

This is really the only thing that is ever so slightly special about the wids library: you should use the special DistributedChunkedSampler for sampling.

The regular DistributedSampler will technically work, but because of its poor locality of reference, it will be significantly slower.

# To keep locality of reference in the dataloader, we use a special sampler
# for distributed training, DistributedChunkedSampler.


def make_dataloader_train():
    """Return a DataLoader over the training set.

    Batches are drawn with ``wids.DistributedChunkedSampler`` (chunks of 1000
    samples, shuffled), using the module-level ``batch_size`` and four worker
    processes.
    """
    trainset = make_dataset_train()
    return DataLoader(
        trainset,
        batch_size=batch_size,
        sampler=wids.DistributedChunkedSampler(
            trainset, chunksize=1000, shuffle=True
        ),
        num_workers=4,
    )


def make_dataloader(split="train"):
    """Return the dataloader for the requested split.

    Args:
        split: Either ``"train"`` or ``"val"``.

    Raises:
        ValueError: If ``split`` is neither ``"train"`` nor ``"val"``.
    """
    if split == "train":
        return make_dataloader_train()
    # NOTE(review): make_dataloader_val is not defined in this part of the
    # file; it is only looked up when split == "val".
    if split == "val":
        return make_dataloader_val()
    raise ValueError(f"unknown split {split}")


# Try it out: fetch one batch from the default (training) loader and print
# the shapes of its first two elements.
sample = next(iter(make_dataloader()))
print(*(part.shape for part in sample[:2]))

PyTorch Distributed Training Code

Really, all that's needed for distributed training is the DistributedDataParallel wrapper around the model.

# For convenience, we collect all the configuration parameters into
# a dataclass.


@dataclasses.dataclass
class Config:
    """Configuration shared by the training and launch code in this file."""

    # Rank of this process; None means non-distributed (train() then skips
    # the DistributedDataParallel wrapper).
    rank: Optional[int] = None
    # Number of passes over the training loader.
    epochs: int = 1
    # train() stops once its running sample count exceeds this value.
    max_steps: int = int(1e18)
    # Learning rate for the optimizer.
    lr: float = 0.001
    # Momentum hyperparameter.
    momentum: float = 0.9
    # Number of processes in the process group; capped to the number of
    # available GPUs by distributed_training().
    world_size: int = 8
    # torch.distributed backend and rendezvous address/port.
    backend: str = "nccl"
    master_addr: str = "localhost"
    master_port: str = "12355"
    # Seconds between progress reports, and growth factor for that interval.
    report_s: float = 15.0
    report_growth: float = 1.1
    # Batch size. Declared as a real field (added last to keep positional
    # construction compatible) so that assigning config.batch_size shows up
    # in repr() and dataclasses.asdict().
    # NOTE: make_dataloader_train reads the module-level batch_size, not
    # this field.
    batch_size: int = 32


# Show the default configuration (notebook cell output).
Config()
# A typical PyTorch training function.


def train(config):
    """Train a ResNet-50 on the training split, reporting progress to stderr.

    The model is moved to the current CUDA device and, when ``config.rank``
    is set, wrapped in ``DistributedDataParallel`` (the process group must
    already be initialized by the caller).

    Args:
        config: A ``Config`` instance. Reads ``rank``, ``epochs``, ``lr``,
            ``momentum``, ``report_s`` and ``max_steps``.

    Returns:
        None. Returns early once the number of processed samples exceeds
        ``config.max_steps``.
    """
    # Define the model, loss function, and optimizer.
    # resnet50() without arguments builds an untrained network; this avoids
    # the deprecated ``pretrained=False`` spelling.
    model = resnet50().cuda()
    if config.rank is not None:
        model = DistributedDataParallel(model)
    loss_fn = nn.CrossEntropyLoss()
    # Honor the configured momentum; it was previously ignored.
    optimizer = torch.optim.SGD(
        model.parameters(), lr=config.lr, momentum=config.momentum
    )

    # Data loading code
    trainloader = make_dataloader(split="train")

    # Running windows over the last 100 batches for smoothed reporting.
    losses, accuracies = deque(maxlen=100), deque(maxlen=100)
    # Counts processed samples (not optimizer steps); this is the quantity
    # compared against config.max_steps below.
    samples_seen = 0

    # Training loop
    # NOTE(review): config.report_growth is currently not forwarded to
    # enumerate_report.
    for epoch in range(config.epochs):
        for i, data, verbose in enumerate_report(trainloader, config.report_s):
            inputs, labels = data[0].cuda(), data[1].cuda()

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(inputs)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()

            # just bookkeeping and progress report
            samples_seen += len(labels)
            accuracy = (outputs.argmax(1) == labels).float().mean()
            losses.append(loss.item())
            accuracies.append(accuracy.item())
            if verbose:
                avgloss = sum(losses) / len(losses)
                avgaccuracy = sum(accuracies) / len(accuracies)
                print(
                    f"rank {config.rank} epoch {epoch:5d}/{i:9d} loss {avgloss:8.3f} acc {avgaccuracy:8.3f} {samples_seen:9d}",
                    file=sys.stderr,
                )
            if samples_seen > config.max_steps:
                print(
                    "finished training (max_steps)",
                    samples_seen,
                    config.max_steps,
                    file=sys.stderr,
                )
                return

    print("finished Training", samples_seen)
# A quick smoke test.

# NOTE(review): GOPEN_VERBOSE presumably turns on verbose logging of shard
# opens/downloads in the data loading library -- confirm.
os.environ["GOPEN_VERBOSE"] = "1"
try:
    # Short single-process run (rank is None, so no DistributedDataParallel).
    config = Config(epochs=1, max_steps=1000)
    train(config)
finally:
    # Switch verbose output back off even if the smoke test fails, so a
    # failure here does not leave later cells in verbose mode.
    os.environ["GOPEN_VERBOSE"] = "0"

Distributed Training in Ray

The code above can be used with any distributed computing framework, including torch.distributed.launch.

Below is simply an example of how to launch the training jobs with the Ray framework. Ray is nice for distributed training because it makes the Python code independent of the runtime environment (Kubernetes, Slurm, ad-hoc networking, etc.). This means the code below will work regardless of how you start up your Ray cluster.

# The distributed training function to be used with Ray.
# Since this is started via Ray remote, we set up the distributed
# training environment here.


@ray.remote(num_gpus=1)
def train_in_ray(rank, config):
    """Run ``train`` inside a Ray task, setting up torch.distributed first.

    Args:
        rank: Rank of this task within the process group, or None to train
            without initializing torch.distributed.
        config: A ``Config`` instance; ``rank`` is stored on it, and
            ``master_addr``, ``master_port``, ``backend`` and ``world_size``
            are used to initialize the process group.
    """
    if rank is None:
        train(config)
        return

    # Set up distributed PyTorch.
    config.rank = rank
    os.environ["MASTER_ADDR"] = config.master_addr
    os.environ["MASTER_PORT"] = config.master_port
    dist.init_process_group(
        backend=config.backend, rank=rank, world_size=config.world_size
    )
    # Ray will automatically set CUDA_VISIBLE_DEVICES for each task.
    try:
        train(config)
    finally:
        # Tear the process group down again: if this worker process is
        # reused for another task, a second init_process_group would fail
        # while the default group is still initialized.
        dist.destroy_process_group()
# Start (or connect to) Ray only if this process has not done so already.
if not ray.is_initialized():
    ray.init()
# The resource dict may have no "GPU" entry when no GPU is available, so fall
# back to 0 instead of raising KeyError.
print("#gpus available in the cluster", ray.available_resources().get("GPU", 0))


def distributed_training(config):
    """Launch one ``train_in_ray`` task per rank and wait for all of them.

    ``config.world_size`` is capped (in place) to the number of GPUs Ray
    currently reports as available.

    Args:
        config: A ``Config`` instance; its ``world_size`` is updated.

    Raises:
        RuntimeError: If Ray reports no available GPU.
    """
    # The "GPU" entry may be missing when no GPU is available.
    num_gpus = ray.available_resources().get("GPU", 0)
    world_size = int(min(config.world_size, num_gpus))
    if world_size < 1:
        # Fail loudly rather than launching zero tasks and "succeeding".
        raise RuntimeError(
            "no GPUs available in the Ray cluster for distributed training"
        )
    config.world_size = world_size
    pprint(config)
    results = ray.get(
        [train_in_ray.remote(i, config) for i in range(config.world_size)]
    )
    print(results)


# Build the run configuration from the notebook-level parameters and launch
# the distributed job.
config = Config(epochs=epochs, max_steps=max_steps)
# Record the batch size on the config as well.
config.batch_size = batch_size
print(config)
distributed_training(config)