WebIndexedDataset + Distributed PyTorch Training
This notebook illustrates how to use the Web Indexed Dataset (wids) library for distributed PyTorch training with DistributedDataParallel.
Using wids results in training code that is almost identical to plain PyTorch; the only changes are the use of ShardListDataset to construct the dataset and of DistributedChunkedSampler to draw random samples from it.

ShardListDataset requires some local storage. By default, that local cache simply grows as shards are downloaded, but if you have limited space you can run create_cleanup_background_process to clean up the cache; shards will be re-downloaded as necessary.
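As a rough sketch of what that setup can look like: the cache directory and the cleanup process might be wired up as below. Only the name and purpose of create_cleanup_background_process come from the text above; the argument list shown here is an assumption, so check the wids documentation for the actual signature.

# A minimal sketch of cache handling, under the assumptions stated above.
# NOTE: the arguments passed to create_cleanup_background_process are an assumption;
# only the function's name and purpose are taken from the description above.
import wids

cache_dir = "./_cache"

# Optionally start a background process that evicts old shards from the cache
# so that local storage stays bounded; evicted shards are re-downloaded on demand.
wids.create_cleanup_background_process(cache_dir)

# Construct the dataset; shards are downloaded into cache_dir as they are accessed.
dataset = wids.ShardListDataset(
    "https://storage.googleapis.com/webdataset/fake-imagenet/imagenet-train.json",
    cache_dir=cache_dir,
)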
import os
import sys
import time
import dataclasses
from collections import deque
from pprint import pprint
from typing import Optional

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.models import resnet50

import ray
import wids
def enumerate_report(seq, delta, growth=1.0):
    """Like enumerate, but also yield a flag that is True roughly every `delta` seconds."""
    last = 0
    count = 0
    for count, item in enumerate(seq):
        now = time.time()
        if now - last > delta:
            last = now
            yield count, item, True
        else:
            yield count, item, False
        delta *= growth
Data Loading for Distributed Training
The datasets we use for training are stored in the cloud. We use fake-imagenet, which is 1/10th the size of ImageNet and artificially generated, but it has the same number of shards and trains quickly.

Note that unlike the webdataset library, wids always needs a local cache directory (it will use /tmp if you don't give it anything explicitly).
# Parameters
epochs = 1
max_steps = int(1e12)
batch_size = 32
bucket = "https://storage.googleapis.com/webdataset/fake-imagenet/"
trainset_url = bucket + "imagenet-train.json"
valset_url = bucket + "imagenet-val.json"
cache_dir = "./_cache"
# This is a typical PyTorch dataset, except that we read from the cloud.
def make_dataset_train():
    transform_train = transforms.Compose(
        [
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )

    def make_sample(sample):
        image = sample[".jpg"]
        label = sample[".cls"]
        return transform_train(image), label

    trainset = wids.ShardListDataset(trainset_url, cache_dir=cache_dir, keep=True)
    trainset = trainset.add_transform(make_sample)
    return trainset
This is really the only thing that is ever so slightly special about the wids library: you should use the special DistributedChunkedSampler for sampling. The regular DistributedSampler will technically work, but because of its poor locality of reference it will be significantly slower.
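To make the locality argument concrete, here is a small, self-contained sketch of the chunked-shuffling idea in plain Python. It is not the wids implementation and it omits the splitting of indices across ranks; the point is only that shuffling within fixed-size chunks, and shuffling the chunk order, keeps consecutive reads within a small set of shards.

# A standalone illustration of chunked shuffling; not the wids implementation.
# Globally random sampling touches shards all over the dataset, while chunked
# shuffling keeps consecutive reads within one chunk (and hence few shards).
import random

def chunked_shuffle(num_samples, chunksize=1000, seed=0):
    rng = random.Random(seed)
    # Build consecutive chunks of indices.
    chunks = [
        list(range(start, min(start + chunksize, num_samples)))
        for start in range(0, num_samples, chunksize)
    ]
    rng.shuffle(chunks)  # randomize the order in which chunks are visited
    for chunk in chunks:
        rng.shuffle(chunk)  # randomize sample order within each chunk
    return [index for chunk in chunks for index in chunk]

# Indices come out grouped by chunk, so nearby samples (and shards) are read together.
print(chunked_shuffle(10, chunksize=4))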
# To keep locality of reference in the dataloader, we use a special sampler
# for distributed training, DistributedChunkedSampler.
def make_dataloader_train():
    dataset = make_dataset_train()
    sampler = wids.DistributedChunkedSampler(dataset, chunksize=1000, shuffle=True)
    dataloader = DataLoader(
        dataset, batch_size=batch_size, sampler=sampler, num_workers=4
    )
    return dataloader
def make_dataloader(split="train"):
    """Make a dataloader for training or validation."""
    if split == "train":
        return make_dataloader_train()
    elif split == "val":
        return make_dataloader_val()
    else:
        raise ValueError(f"unknown split {split}")
# Try it out.
sample = next(iter(make_dataloader()))
print(sample[0].shape, sample[1].shape)
PyTorch Distributed Training Code
Really, all that's needed for distributed training is the DistributedDataParallel wrapper around the model.
# For convenience, we collect all the configuration parameters into
# a dataclass.
@dataclasses.dataclass
class Config:
    rank: Optional[int] = None
    epochs: int = 1
    max_steps: int = int(1e18)
    lr: float = 0.001
    momentum: float = 0.9
    world_size: int = 8
    backend: str = "nccl"
    master_addr: str = "localhost"
    master_port: str = "12355"
    report_s: float = 15.0
    report_growth: float = 1.1

Config()
# A typical PyTorch training function.
def train(config):
    # Define the model, loss function, and optimizer.
    model = resnet50(pretrained=False).cuda()
    if config.rank is not None:
        model = DistributedDataParallel(model)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(
        model.parameters(), lr=config.lr, momentum=config.momentum
    )

    # Data loading code.
    trainloader = make_dataloader(split="train")

    losses, accuracies, steps = deque(maxlen=100), deque(maxlen=100), 0

    # Training loop.
    for epoch in range(config.epochs):
        for i, data, verbose in enumerate_report(
            trainloader, config.report_s, growth=config.report_growth
        ):
            inputs, labels = data[0].cuda(), data[1].cuda()

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(inputs)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()

            # just bookkeeping and progress reporting
            steps += len(labels)
            accuracy = (outputs.argmax(1) == labels).float().mean()  # calculate accuracy
            losses.append(loss.item())
            accuracies.append(accuracy.item())
            if verbose and len(losses) > 0:
                avgloss = sum(losses) / len(losses)
                avgaccuracy = sum(accuracies) / len(accuracies)
                print(
                    f"rank {config.rank} epoch {epoch:5d}/{i:9d} loss {avgloss:8.3f} acc {avgaccuracy:8.3f} {steps:9d}",
                    file=sys.stderr,
                )
            if steps > config.max_steps:
                print(
                    "finished training (max_steps)",
                    steps,
                    config.max_steps,
                    file=sys.stderr,
                )
                return

    print("finished training", steps)
# A quick smoke test. Setting GOPEN_VERBOSE=1 makes the wids I/O layer report the URLs it opens.
os.environ["GOPEN_VERBOSE"] = "1"
config = Config()
config.epochs = 1
config.max_steps = 1000
train(config)
os.environ["GOPEN_VERBOSE"] = "0"
Distributed Training in Ray
The code above can be used with any distributed computing framework, including torch.distributed.launch.

Below is simply an example of how to launch the training jobs with the Ray framework. Ray is nice for distributed training because it makes the Python code independent of the runtime environment (Kubernetes, Slurm, ad-hoc networking, etc.). That is, the code below will work regardless of how you start up your Ray cluster.
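For comparison, here is a minimal sketch of driving the same train() function with torchrun instead of Ray. It assumes the launcher provides the standard RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT environment variables, and that the Config and train definitions from this notebook are available in the same file; the file name run_train.py is just an illustrative placeholder.

# run_train.py -- a minimal sketch; Config and train are the definitions from this notebook.
# Assumes the torchrun launcher sets RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT.
# Launch with, e.g.:  torchrun --nproc_per_node=8 run_train.py
import os

import torch
import torch.distributed as dist

def main():
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    # Bind this process to its GPU before creating the model (Ray does this via
    # CUDA_VISIBLE_DEVICES; with torchrun we use LOCAL_RANK instead).
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    # torchrun exports MASTER_ADDR/MASTER_PORT, so the default env:// init works.
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    config = Config()
    config.rank = rank
    config.world_size = world_size
    train(config)
    dist.destroy_process_group()

if __name__ == "__main__":
    main()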
# The distributed training function to be used with Ray.
# Since this is started via Ray remote, we set up the distributed
# training environment here.
@ray.remote(num_gpus=1)
def train_in_ray(rank, config):
    if rank is not None:
        # Set up distributed PyTorch.
        config.rank = rank
        os.environ["MASTER_ADDR"] = config.master_addr
        os.environ["MASTER_PORT"] = config.master_port
        dist.init_process_group(
            backend=config.backend, rank=rank, world_size=config.world_size
        )
        # Ray will automatically set CUDA_VISIBLE_DEVICES for each task.
    train(config)
if not ray.is_initialized():
    ray.init()
print("#gpus available in the cluster", ray.available_resources()["GPU"])
def distributed_training(config):
    num_gpus = ray.available_resources()["GPU"]
    config.world_size = int(min(config.world_size, num_gpus))
    pprint(config)
    results = ray.get(
        [train_in_ray.remote(i, config) for i in range(config.world_size)]
    )
    print(results)
config = Config()
config.epochs = epochs
config.max_steps = max_steps
config.batch_size = batch_size
print(config)
distributed_training(config)