%matplotlib inline
from functools import partial
from pprint import pprint
import random
from collections import deque
import numpy as np
from matplotlib import pyplot as plt
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet50
from torch.utils.data import DataLoader
from torch import nn, optim
import wids
# parameters
epochs = 3
max_steps = 100000
batch_size = 32
bucket = "https://storage.googleapis.com/webdataset/fake-imagenet/"
num_workers = 4
cache_dir = "./_cache"
# helpers
import time
def enumerate_report(seq, delta, growth=1.0):
    last = 0
    count = 0
    for count, item in enumerate(seq):
        now = time.time()
        if now - last > delta:
            last = now
            yield count, item, True
        else:
            yield count, item, False
        delta *= growth
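# Quick illustration of enumerate_report (not part of training): with
# delta=0.5s and a loop that sleeps 0.3s per step, only some iterations
# come back flagged as verbose.
for i, item, verbose in enumerate_report(range(5), 0.5):
    time.sleep(0.3)
    if verbose:
        print("verbose at step", i)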
# The standard TorchVision transformations.
transform_train = transforms.Compose(
    [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
transform_val = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
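# Quick sanity check (illustrative): both pipelines turn a PIL image into a
# normalized 3x224x224 float tensor.
from PIL import Image
_dummy = Image.new("RGB", (640, 480))
print(transform_train(_dummy).shape, transform_val(_dummy).shape)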
# The dataset returns dictionaries. This small function transforms each
# sample into an (augmented image, label) pair.
def make_sample(sample, val=False):
    image = sample[".jpg"]
    label = sample[".cls"]
    if val:
        return transform_val(image), label
    else:
        return transform_train(image), label
# These are standard map-style PyTorch datasets. Shards are downloaded
# incrementally into the cache directory as samples are accessed.
trainset = wids.ShardListDataset(
    bucket + "imagenet-train.json", cache_dir=cache_dir, keep=True
)
valset = wids.ShardListDataset(
    bucket + "imagenet-val.json", cache_dir=cache_dir, keep=True
)
trainset[0]
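# Illustrative check: the dataset supports len(), and each raw sample is a
# dictionary keyed by file extension (make_sample relies on ".jpg" and ".cls").
print(len(trainset), len(valset))
pprint(sorted(trainset[0].keys()))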
# Next, we add our transformation to the dataset. Transformations are
# executed in sequence; by default, the dataset already includes a
# transformation that reads and decodes the images.
trainset.add_transform(make_sample)
valset.add_transform(partial(make_sample, val=True))
print(trainset[0][0].shape, trainset[0][1])
# We also need a sampler for the training set. There are three
# special samplers in the `wids` package that work particularly
# well with sharded datasets:
# - `wids.ShardedSampler` shuffles the shards and then samples within each
# shard; it guarantees that only one shard is used at a time
# - `wids.ChunkedSampler` samples in fixed-size chunks, shuffles the
# chunks, and then shuffles the samples within each chunk
# - `wids.DistributedChunkedSampler` is like `ChunkedSampler` but
# works with distributed training (it first divides the entire
# dataset into per-node chunks, then the per-node chunks into
# smaller chunks, then shuffles the smaller chunks)
# trainsampler = wids.ShardedSampler(trainset)
# trainsampler = wids.ChunkedSampler(trainset, chunksize=1000, shuffle=True)
trainsampler = wids.DistributedChunkedSampler(trainset, chunksize=1000, shuffle=True)
plt.plot(list(trainsampler)[:2500])
# Note that the sampler shuffles within each shard before moving on to
# the next shard. Furthermore, on the first epoch, the sampler
# uses the shards in order, but on subsequent epochs, it shuffles
# them. This makes testing and debugging easier. If you don't like
# this behavior, you can pass shufflefirst=True to the sampler.
trainsampler.set_epoch(0)
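# Illustrative check of the epoch-dependent shuffling described above: the
# first few indices will generally differ between epoch 0 and epoch 1.
head0 = list(trainsampler)[:10]
trainsampler.set_epoch(1)
head1 = list(trainsampler)[:10]
print(head0)
print(head1)
trainsampler.set_epoch(0)  # reset to the epoch used for training below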
# Create data loaders for the training and validation datasets
trainloader = DataLoader(trainset, batch_size=batch_size, sampler=trainsampler, num_workers=num_workers)
valloader = DataLoader(valset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
images, classes = next(iter(trainloader))
print(images.shape, classes.shape)
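# Optional visualization (sketch): undo the normalization and show a few
# training images from the batch loaded above.
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
grid = torchvision.utils.make_grid(images[:8] * std + mean, nrow=4)
plt.imshow(grid.permute(1, 2, 0).clamp(0, 1))
plt.axis("off")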
# The usual PyTorch model definition. We use a randomly initialized (not
# pretrained) ResNet-50.
model = resnet50(weights=None)
# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
# Move the model to the GPU if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)
losses, accuracies = deque(maxlen=100), deque(maxlen=100)
steps = 0
# Train the model
for epoch in range(epochs):
    for i, data, verbose in enumerate_report(trainloader, 5):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data[0].to(device), data[1].to(device)
        # zero the parameter gradients
        optimizer.zero_grad()
        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # track loss and accuracy over a sliding window of recent batches
        pred = outputs.cpu().detach().argmax(dim=1, keepdim=True)
        correct = pred.eq(labels.cpu().view_as(pred)).sum().item()
        accuracy = correct / float(len(labels))
        losses.append(loss.item())
        accuracies.append(accuracy)
        steps += len(labels)
        # report at most once every few seconds (see enumerate_report)
        if verbose and len(losses) > 5:
            print(
                "[%d, %5d] loss: %.5f correct: %.5f"
                % (epoch + 1, i + 1, np.mean(losses), np.mean(accuracies))
            )
        if steps > max_steps:
            break
    if steps > max_steps:
        break
print("Finished Training")