%matplotlib inline

import time
from collections import deque
from functools import partial

import numpy as np
from matplotlib import pyplot as plt

import torch
import torchvision.transforms as transforms
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision.models import resnet50

import wids
# parameters
epochs = 3
max_steps = 100000
batch_size = 32
bucket = "https://storage.googleapis.com/webdataset/fake-imagenet/"
num_workers = 4
cache_dir = "./_cache"
# helpers

def enumerate_report(seq, delta, growth=1.0):
    """Like enumerate, but also yield a flag that is True at most once
    every `delta` seconds; `growth` can lengthen the interval over time."""
    last = 0
    for count, item in enumerate(seq):
        now = time.time()
        if now - last > delta:
            last = now
            yield count, item, True
        else:
            yield count, item, False
        delta *= growth
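
# A quick illustration of the helper (hypothetical toy loop): the flag is True
# at most once per `delta` seconds, so reporting stays throttled no matter
# how fast the loop runs.

for i, item, verbose in enumerate_report(range(10**6), 0.5):
    if verbose:
        print("reached index", i)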
# The standard TorchVision transformations.

transform_train = transforms.Compose(
    [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

transform_val = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# The dataset returns dictionaries keyed by file extension. This small
# function turns each sample into an (augmented image, label) pair.


def make_sample(sample, val=False):
    image = sample[".jpg"]
    label = sample[".cls"]
    if val:
        return transform_val(image), label
    else:
        return transform_train(image), label
# These behave like standard PyTorch (map-style) datasets. Shards are
# downloaded incrementally into the cache as samples are accessed.

trainset = wids.ShardListDataset(
    bucket + "imagenet-train.json", cache_dir=cache_dir, keep=True
)
valset = wids.ShardListDataset(
    bucket + "imagenet-val.json", cache_dir=cache_dir, keep=True
)

# Fetch one raw sample; it is a dictionary keyed by file extension.
trainset[0]
# Next, we add the transformation to the dataset. Transformations are
# executed in sequence; by default, the pipeline already contains a
# transformation that reads and decodes images, which is why `make_sample`
# receives a decoded image in `sample[".jpg"]`.

trainset.add_transform(make_sample)
valset.add_transform(partial(make_sample, val=True))

print(trainset[0][0].shape, trainset[0][1])
# We also need a sampler for the training set. There are three
# special samplers in the `wids` package that work particularly
# well with sharded datasets:
# - `wids.ShardedSampler` shuffles shards and then samples in shards;
#   it guarantees that only one shard is used at a time
# - `wids.ChunkedSampler` samples in fixed-size chunks, shuffles
#   the chunks, and then the samples within each chunk
# - `wids.DistributedChunkedSampler` is like `ChunkedSampler` but
#   works with distributed training (it first divides the entire
#   dataset into per-node chunks, then the per-node chunks into
#   smaller chunks, then shuffles the smaller chunks)

# trainsampler = wids.ShardedSampler(trainset)
# trainsampler = wids.ChunkedSampler(trainset, chunksize=1000, shuffle=True)
trainsampler = wids.DistributedChunkedSampler(trainset, chunksize=1000, shuffle=True)

# Visualize the index order the sampler produces for the first 2500 samples.
plt.plot(list(trainsampler)[:2500])

# Note that the sampler shuffles within each shard before moving on to
# the next shard. Furthermore, on the first epoch, the sampler
# uses the shards in order, but on subsequent epochs, it shuffles
# them. This makes testing and debugging easier. If you don't like
# this behavior, you can pass `shufflefirst=True`.
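#
# For example, assuming your wids version accepts this keyword:
#
# trainsampler = wids.DistributedChunkedSampler(
#     trainset, chunksize=1000, shuffle=True, shufflefirst=True
# )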

# The sampler's shuffling depends on the epoch number, which must be set explicitly.
trainsampler.set_epoch(0)
# Create data loaders for the training and validation datasets

trainloader = DataLoader(trainset, batch_size=batch_size, num_workers=num_workers, sampler=trainsampler)
valloader = DataLoader(valset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

images, classes = next(iter(trainloader))
print(images.shape, classes.shape)
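
# Optional sanity check (illustrative): display the first augmented image in
# the batch, undoing the ImageNet normalization so the colors look right.

mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
img = (images[0] * std + mean).clamp(0, 1)
plt.imshow(img.permute(1, 2, 0))
plt.title(f"class {classes[0].item()}")
plt.show()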
# The usual PyTorch model definition. We train a randomly initialized ResNet-50.

model = resnet50(weights=None)

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)

# Move the model to the GPU if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)
# Keep running averages of loss and accuracy over the last 100 batches.
losses, accuracies = deque(maxlen=100), deque(maxlen=100)

steps = 0  # number of training samples seen so far

# Train the model
for epoch in range(epochs):
    trainsampler.set_epoch(epoch)  # reshuffle differently on each epoch
    for i, data, verbose in enumerate_report(trainloader, 5):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data[0].to(device), data[1].to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        pred = outputs.cpu().detach().argmax(dim=1, keepdim=True)
        correct = pred.eq(labels.cpu().view_as(pred)).sum().item()
        accuracy = correct / float(len(labels))

        losses.append(loss.item())
        accuracies.append(accuracy)
        steps += len(labels)

        if verbose and len(losses) > 5:
            print(
                "[%d, %5d] loss: %.5f correct: %.5f"
                % (epoch + 1, i + 1, np.mean(losses), np.mean(accuracies))
            )

        if steps > max_steps:
            break
    if steps > max_steps:
        break

print("Finished Training")