WebDataset FAQ

This is a Frequently Asked Questions file for WebDataset. It is automatically generated from selected WebDataset issues using AI.

Since the entries are generated automatically, not all of them may be correct. When in doubt, check the original issue.


Issue #367

Q: How can I sample sequences of frames from large video datasets using WebDataset?

A: To sample sequences of frames from large video datasets with WebDataset, you can precompute sampled sequences of frames and treat each collection as a batch. Alternatively, you can split your videos into shorter clips with overlapping frames, generate multiple samples from each clip, and shuffle the resulting sequences. Here's a code snippet demonstrating how to generate and shuffle five-frame sequences from 50-frame clips:

from webdataset import WebDataset
import random

ds = WebDataset("video-clips-{000000..000999}.tar").decode()

def generate_clips(src):
    for sample in src:
        # assume that each video clip sample contains sample.000.jpg to sample.049.jpg images
        clip = [sample["%03d.jpg" % i] for i in range(50)]
        starts = random.sample(range(46), 10)  # Choose 10 starting points
        key = sample["__key__"]
        for i in starts:
            yield {
               "__key__": f"{key}-{i}",
               "sequence": clip[i:i+5],
            }

ds = ds.compose(generate_clips).shuffle(1000)

This approach allows you to work with large datasets by handling smaller, manageable sequences, which can be efficiently preprocessed and shuffled to create a diverse training set.
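The clip-splitting step mentioned in the answer can be sketched in plain Python. This is only an illustration: the function name clip_ranges and its parameters are hypothetical and not part of WebDataset. It computes frame ranges for clips that overlap by a fixed number of frames, so that a short window falling on a clip boundary is still fully contained in one clip.

```python
def clip_ranges(num_frames, clip_len=50, overlap=4):
    """Return (start, end) frame index pairs for overlapping clips.

    Consecutive clips share `overlap` frames, so any window of up to
    overlap + 1 consecutive frames lies entirely inside at least one clip.
    """
    if not 0 <= overlap < clip_len:
        raise ValueError("overlap must satisfy 0 <= overlap < clip_len")
    step = clip_len - overlap
    ranges = []
    start = 0
    while start < num_frames:
        end = min(start + clip_len, num_frames)
        ranges.append((start, end))
        if end == num_frames:
            break
        start += step
    return ranges


# A 120-frame video cut into 50-frame clips that overlap by 4 frames
ranges = clip_ranges(120, clip_len=50, overlap=4)
print(ranges)  # [(0, 50), (46, 96), (92, 120)]
```

With an overlap of 4 frames, every five-frame sequence of the original video appears in at least one clip, so sampling sequences per clip does not lose the windows that straddle clip boundaries.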


Issue #364

Q: How can I ensure that each validation sample is seen exactly once per epoch in a multi-node setup using WebDataset with FSDP?

A: When using WebDataset in a multi-node setup with Fully Sharded Data Parallel (FSDP), you can ensure that each validation sample is seen exactly once per epoch by assigning each shard to a specific GPU. Since you have an equal number of shards and GPUs, you can map each shard to a GPU. For the shard that is about half the size, you can either accept that the corresponding GPU will do less work, or you can split another shard to balance the load. To ensure that each sample is loaded exactly once, avoid wds.ResampledShards, which draws shards with replacement and can therefore repeat or skip samples; list the shards once with wds.SimpleShardList and split them across processes with wds.split_by_node instead. Also avoid ddp_equalize, since it is designed for training rather than validation. Here's an example of how you might set up your validation dataset:

import os
import webdataset as wds

val_dataset = wds.DataPipeline(
    wds.SimpleShardList(
        os.path.join('path', 'to', 'val_samples_{0000..xxxx}.tar')  # xxxx: index of your last shard
    ),
    wds.split_by_node,    # each rank reads a disjoint subset of the shards
    wds.split_by_worker,  # each DataLoader worker reads a disjoint subset of those
    wds.tarfile_to_samples(),
    wds.decode(),
    wds.to_tuple("input.npy", "target.npy"),
    wds.batched(1)
).with_length(num_val_samples)

To ensure that the validation loop stops after all samples have been loaded, you can use the length of the dataset to control the number of iterations in your validation loop. This way, you can manually iterate over the dataset and stop when you've reached the total number of samples.


Issue #331

Q: How can I handle gzipped tar files with WebDataset/WIDS?

A: When working with gzipped tar files in WebDataset or WIDS, it's important to understand that random access to compressed files is not straightforward due to the nature of compression. However, Python's tarfile library can handle gzip-compressed streams using tarfile.open("filename.tar.gz", "r:gz"). For WIDS, the best practice is to use uncompressed .tar files for the dataset, which allows for efficient random access. If storage is a concern, you can compress individual files within the tar archive (e.g., .json.gz instead of .json). This approach provides a balance between storage efficiency and compatibility with WIDS. Here's an example of how to compress individual files:

import tarfile
import gzip

# Compress individual files and add them to a tar archive
with tarfile.open('archive.tar', 'w') as tar:
    with open('file.json', 'rb') as f_in:
        with gzip.open('file.json.gz', 'wb') as f_out:
            f_out.writelines(f_in)
    tar.add('file.json.gz', arcname='file.json.gz')

Remember that for WebDataset, you can use .tar.gz files directly, as it supports on-the-fly decompression. If you encounter datasets that are not in order, you can repack them using GNU tar with sorting to ensure that corresponding files are adjacent, which is a requirement for WebDataset.
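To illustrate the per-file compression approach end to end, here is a small sketch that uses only the Python standard library. The member name sample0.json.gz and the record contents are made up for the example. The archive itself stays an uncompressed tar, and only the member is gzip-compressed, so it is decompressed after it has been read.

```python
import gzip
import io
import json
import tarfile

# Build a small uncompressed tar in memory that holds one gzip-compressed member.
payload = gzip.compress(json.dumps({"label": 3}).encode("utf-8"))
buffer = io.BytesIO()
with tarfile.open(fileobj=buffer, mode="w") as tar:
    info = tarfile.TarInfo(name="sample0.json.gz")
    info.size = len(payload)
    tar.addfile(info, io.BytesIO(payload))

# Read it back. Mode "r:" opens the archive as plain (uncompressed) tar;
# only the member's contents need to be gunzipped.
buffer.seek(0)
with tarfile.open(fileobj=buffer, mode="r:") as tar:
    member = tar.getmember("sample0.json.gz")
    record = json.loads(gzip.decompress(tar.extractfile(member).read()))

print(record)  # {'label': 3}
```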


Issue #329

Q: How can I create a JSON metafile for random access in a WebDataset?

A: To create a JSON metafile for a WebDataset, you can use the widsindex command that comes with the webdataset package. This command generates an index file for a given list of WebDataset shards. The index file is in JSON format and allows for efficient random access to the dataset. Here's a simple example of how to use widsindex:

widsindex mydataset-0000.tar mydataset-0001.tar > mydataset-index.json

This command will create a JSON file named mydataset-index.json that contains the index for the shards mydataset-0000.tar and mydataset-0001.tar.


Issue #319

Q: How can I handle complex hierarchical data structures in WebDataset?

A: When dealing with complex hierarchical data structures in WebDataset, it's often more practical to use a flat file naming scheme and express the hierarchy within a JSON metadata file. This approach simplifies the file naming while allowing for detailed structuring of the data. You can sequentially number the files and reference them in the JSON, which contains the structure of your dataset, including frame order, timestamps, and other relevant information.

For example, instead of trying to express the hierarchy in the file names, you can name your files like this:

sample_0.000.jpg
sample_0.001.jpg
sample_0.002.jpg
sample_0.json

And then use a JSON file to define the structure:

{
    "frames": ["000.jpg", "001.jpg", "002.jpg"],
    "timestamps": [10001, 10002, 10003],
    "duration": 3
}

This method keeps the file naming simple and leverages the JSON file to maintain the hierarchical relationships within the dataset.
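As a rough sketch of how such a sample might be consumed, the helper below pairs each frame with its timestamp in the order given by the JSON metadata. It assumes the sample has already been read into a dictionary keyed by extension and that the "json" entry has been decoded into a Python dict; the function name ordered_frames is hypothetical.

```python
def ordered_frames(sample):
    """Return (timestamp, frame_bytes) pairs in the order listed in the JSON metadata."""
    meta = sample["json"]
    frames = meta["frames"]
    timestamps = meta["timestamps"]
    if len(frames) != len(timestamps):
        raise ValueError("frames and timestamps must have the same length")
    return [(ts, sample[name]) for name, ts in zip(frames, timestamps)]


# A sample shaped like the example above (image bytes are stand-ins)
sample = {
    "__key__": "sample_0",
    "000.jpg": b"<jpeg 0>",
    "001.jpg": b"<jpeg 1>",
    "002.jpg": b"<jpeg 2>",
    "json": {
        "frames": ["000.jpg", "001.jpg", "002.jpg"],
        "timestamps": [10001, 10002, 10003],
        "duration": 3,
    },
}
pairs = ordered_frames(sample)
```

Because the ordering lives in the JSON file, the frame files themselves can keep simple sequential names.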


Issue #316

Q: Why am I getting a ValueError when trying to batch variable-length numpy arrays using webdataset?

A: The error you're encountering is due to the attempt to collate numpy arrays with different shapes into a single batch. Since the num_frames dimension varies, you cannot directly convert a list of such arrays into a single numpy array without padding or truncating them to a uniform size. To resolve this, you can specify a custom collation function that handles variable-length sequences appropriately. This function can either pad the sequences to the same length or store them in a data structure that accommodates variable lengths, such as a list or a padded tensor. Here's an example of how to specify a custom collation function:

def custom_collate_fn(batch):
    # Handle variable-length sequences here, e.g., by padding
    # Return the batch in the desired format
    return batch

pipeline.extend([
    # ... other pipeline steps ...
    wds.batched(args.batch_size, collation_fn=custom_collate_fn, partial=not is_train)
])

By providing a custom collation function, you can ensure that the data is prepared in a way that is compatible with your model's input requirements.
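One possible padding strategy is sketched below with NumPy. It assumes that each element of the batch is a single array of shape (num_frames, ...); if your samples are tuples or dictionaries, the same padding would be applied to the relevant field. The name pad_collate is hypothetical.

```python
import numpy as np


def pad_collate(batch):
    """Zero-pad arrays of shape (num_frames, ...) to the longest num_frames in the batch.

    Returns (padded, lengths): padded has shape (batch, max_frames, ...),
    and lengths holds each array's original num_frames.
    """
    lengths = np.array([len(x) for x in batch], dtype=np.int64)
    max_len = int(lengths.max())
    padded = np.zeros((len(batch), max_len) + batch[0].shape[1:], dtype=batch[0].dtype)
    for i, x in enumerate(batch):
        padded[i, : len(x)] = x
    return padded, lengths


# Two clips with 2 and 4 frames, 3 features per frame
clips = [np.ones((2, 3), dtype=np.float32), np.ones((4, 3), dtype=np.float32)]
padded, lengths = pad_collate(clips)
```

Returning the original lengths alongside the padded array lets the model mask out the padded frames.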


Issue #307

Q: Can I skip loading large files in a tar file when using WebDataset?

A: When working with WebDataset, it is not possible to skip the reading of files within a tar archive that you do not need. The library operates on a streaming basis, which means that all bytes are read sequentially. However, you can filter out unwanted data after it has been read into memory. If performance is a concern, consider creating a new dataset containing only the necessary files. For indexed access to WebDataset files, you can use the "wids" interface, which reads only the data you use from disk when working with local files.

Here's a short example of filtering out unwanted data after reading:

dataset = wds.WebDataset(["path/to/dataset.tar"])
keys_to_keep = ["__key__", "__url__", "txt"]

def filter_keys(sample):
    return {k: sample[k] for k in keys_to_keep if k in sample}

filtered_dataset = dataset.map(filter_keys)

Issue #303

Q: Why does the number of steps per epoch change when increasing num_workers in DDP training with Webdataset?

A: When using multiple workers in a distributed data parallel (DDP) training setup with Webdataset, the number of steps per epoch may change if the epoch size is not properly configured to account for the parallelism introduced by the workers. The with_epoch method should be applied to the WebLoader instead of the WebDataset to ensure that the dataset is correctly divided among the workers. Additionally, to maintain proper shuffling across workers, you may need to add cross-worker shuffling. Here's an example of how to configure the loader:

data = wds.WebDataset(self.url, resampled=True).shuffle(1000).map(preprocess_train)
loader = wds.WebLoader(data, pin_memory=True, shuffle=False, batch_size=20, num_workers=2).with_epoch(...)

For cross-worker shuffling, you can modify the loader like this:

loader = loader.unbatched().shuffle(2000).batched(20).with_epoch(200)
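A common way to choose the value passed to with_epoch on the loader is to divide the dataset size by the global batch size. The helper below is a hypothetical sketch of that arithmetic, assuming with_epoch is applied after batching (so it counts batches) and that every process runs its own loader.

```python
def batches_per_epoch(num_samples, batch_size, world_size):
    """Batches each process should draw so that all processes together
    cover roughly num_samples samples per epoch."""
    global_batch = batch_size * world_size
    if global_batch <= 0:
        raise ValueError("batch_size and world_size must be positive")
    return num_samples // global_batch


# 100,000 samples, 20 samples per batch, 4 processes
steps = batches_per_epoch(num_samples=100_000, batch_size=20, world_size=4)
print(steps)  # 1250
```

Since the count is applied to the loader's output, it does not depend on num_workers, which is why the number of steps stays constant when the number of workers changes.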

Issue #291

Q: How can I skip a corrupt image sample when using NVIDIA DALI for data loading?

A: The handler parameter used here belongs to WebDataset, which reads the samples, rather than to NVIDIA DALI itself. It lets you specify what happens when an error occurs while reading or decoding a sample: the handler is called with the exception, and if it returns True the problematic sample is skipped and processing continues, otherwise iteration stops. For example, you can use warn_and_continue (WebDataset also provides a built-in wds.warn_and_continue) to issue a warning and skip the sample, allowing the data pipeline to continue with the remaining samples. This is particularly useful when dealing with large datasets where some samples may be corrupt or unreadable.

Here's a short code example demonstrating how to use the handler parameter:

from nvidia.dali.plugin import pytorch
import webdataset as wds

def warn_and_continue(e):
    print("Warning: skipping a corrupt sample.", e)
    return True  # a true return value tells WebDataset to continue with the next sample

ds = (
    wds.WebDataset(url, handler=warn_and_continue, shardshuffle=True, verbose=verbose)
    .map(_mapper, handler=warn_and_continue)
    .to_tuple("jpg", "cls")
    .map_tuple(transform, identity, handler=warn_and_continue)
    .batched(batch_size)
)

By passing warn_and_continue to the .map, .map_tuple, or .decode methods, you instruct WebDataset to report the exception and continue with the next sample instead of aborting.


Issue #289

Q: Can WebDataset support interleaved datasets such as MMC4, where one example may include a list of texts with several images?

A: Yes, WebDataset can support interleaved datasets like MMC4. You can organize your dataset by creating a .json file that contains the hierarchical structure and references to the image files. This .json file acts as a manifest for each sample, detailing the associated text and images. The image files themselves are stored alongside the .json file. Here's a simple example of how you might structure a .json file for an interleaved dataset:

{
  "text": ["This is the first text", "This is the second text"],
  "images": ["image1.jpg", "image2.jpg", "image3.jpg"]
}

And in your dataset, you would have the .json file and the referenced images in the same sample directory or archive.


Issue #283

Q: How can I authenticate to read objects from a private bucket with WebDataset?

A: To authenticate and read objects from a private bucket using WebDataset, you need to provide the necessary credentials to the underlying command line programs that WebDataset uses for data access. If you are using a storage provider like NetApp, which is not directly supported by WebDataset's built-in protocols, you can use the pipe: protocol to specify a custom command that includes the necessary authentication steps. For example, you can create a shell script that uses your storage provider's CLI tools to authenticate with your access key id and secret access key, and then pass this script to WebDataset:

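# auth_script.sh
# This script authenticates and streams a shard from a private bucket to stdout,
# which is where the pipe: protocol reads the tar data from.
# "netappcli" is a placeholder, not a real command; substitute your provider's CLI.
# Replace <ACCESS_KEY>, <SECRET_KEY>, <BUCKET_NAME>, and <SHARD_NAME> with your actual values
netappcli --access-key <ACCESS_KEY> --secret-key <SECRET_KEY> download <BUCKET_NAME>/<SHARD_NAME>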

Then, use this script with WebDataset:

import webdataset as wds

# Use the 'pipe:' protocol with your authentication script
dataset = wds.WebDataset("pipe:./auth_script.sh")

Ensure that your script has the necessary permissions to be executed and that it correctly handles the authentication and data retrieval process.


Issue #278

Q: Why is with_epoch(N) needed for multinode training with WebDataset?

A: When using WebDataset for training models in PyTorch, the with_epoch(N) function is used to define the end of an epoch when working with an infinite stream of samples. This is particularly important in distributed training scenarios to ensure that all nodes process the same number of batches per epoch, which helps in synchronizing the training process across nodes. Without with_epoch(N), the training loop would not have a clear indication of when an epoch ends, potentially leading to inconsistent training states among different nodes. WebDataset operates with the IterableDataset interface, which does not support the set_epoch method used by DistributedSampler in PyTorch's DataLoader. Therefore, with_epoch(N) serves as a mechanism to delineate epochs in the absence of set_epoch.

# Example of using with_epoch in a training loop
for epoch in range(num_epochs):
    for sample in webdataset_reader.with_epoch(epoch_length):
        train(sample)

Issue #264

Q: How can I include the file name (only the stem, not the extension) in the metadata dictionary when using WebDataset?

A: When working with WebDataset, each sample in the dataset contains a special key __key__ that holds the file name without the extension. To include the file name in the metadata dictionary, you can create a custom mapping function that extracts the __key__ and adds it to the metadata. Here's a short code example on how to modify the pipeline to include the file name in the metadata:

def add_filename_to_metadata(sample):
    sample["metadata"]["filename"] = sample["__key__"]
    return sample

pipeline = [
    # ... (other pipeline steps)
    wds.map(add_filename_to_metadata),
    # ... (remaining pipeline steps)
]

This function should be added to the pipeline after the wds.decode step and before the wds.to_tuple step. This way, the metadata dictionary will contain the file name for each sample processed by the pipeline.


Issue #261

Q: Why is my WebDataset tar file unexpectedly large when saving individual tensors?

A: The large file size is due to the fact that each tensor is pointing to a large underlying byte array buffer, which is being saved in its entirety. This results in saving much more data than just the tensor's contents. To fix this, you should clone the tensor before saving it to ensure that only the relevant data is written to the file. Additionally, each file in a tar archive has a 512-byte header, which can add significant overhead when saving many small files. To reduce file size, consider compressing the tar file or batching tensors before saving.
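The per-file overhead of the tar format can be checked with the standard library alone. The sketch below (the helper tar_size is hypothetical) writes many tiny members into an in-memory tar and compares the archive size with the payload size. Each member costs a 512-byte header, and its data is padded to a multiple of 512 bytes.

```python
import io
import tarfile


def tar_size(num_files, payload_size):
    """Size in bytes of an uncompressed tar holding num_files members of payload_size bytes each."""
    buffer = io.BytesIO()
    with tarfile.open(fileobj=buffer, mode="w", format=tarfile.USTAR_FORMAT) as tar:
        for i in range(num_files):
            info = tarfile.TarInfo(name=f"{i:06d}.bin")
            info.size = payload_size
            tar.addfile(info, io.BytesIO(b"\0" * payload_size))
    return len(buffer.getvalue())


payload_bytes = 100 * 16                            # 100 members of 16 bytes each
archive_bytes = tar_size(num_files=100, payload_size=16)
print(payload_bytes, archive_bytes)                 # the archive is far larger than the payload
```

For tensors of only a few bytes, the archive is dominated by headers and padding, which is why batching several tensors into one file or compressing the archive helps.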

Here's a code snippet showing how to clone the tensor before saving:

with wds.TarWriter(f"/tmp/dest.tar") as sink:
    for i, d in tqdm(enumerate(tensordict), total=N):
        obj = {"__key__": f"{i}"}
        for k, v in d.items():
            buffer = io.BytesIO()
            torch.save(v.clone(), buffer)  # Clone the tensor here
            obj[f"{k}.pth"] = buffer.getvalue()
        sink.write(obj)

To compress the tar file, give it a .tar.gz extension; TarWriter selects gzip compression when the file name ends in "gz":

with wds.TarWriter("/tmp/dest.tar.gz") as sink:
    # ... rest of the code ...

Issue #260

Q: What is the purpose of the .with_epoch() method in WebDataset and could it be named more descriptively?

A: The .with_epoch() method in WebDataset is used to explicitly set the number of samples that constitute an epoch during distributed training. This is important for ensuring that each worker in a distributed system processes a full epoch's worth of data. The name .with_epoch() might not be immediately clear, but it is intended to indicate that the dataset is being configured with a specific epoch length. A more descriptive name like .set_epoch_size() could potentially convey the purpose more clearly. However, changing the method name would be a breaking change for existing codebases. Improving the documentation with examples can help clarify the usage:

# Original method name
dataset = dataset.with_epoch(10000)

# Hypothetical more descriptive method name
dataset = dataset.set_epoch_size(10000)

In the meantime, users should refer to the improved documentation for guidance on how to use the .with_epoch() method effectively.


Issue #257

Q: How can I efficiently load only the necessary auxiliary images for a sample in my training configuration to save on I/O and decoding time?

A: When working with datasets that include a main image and multiple auxiliary images, you can optimize the data loading process by selectively reading only the required files. This can be achieved by using the select_files option in WebDataset or similar tools, which allows you to specify which files to extract from the dataset. By pre-selecting the files during the dataset preparation phase, you ensure that your tar files contain exactly the files needed for training, minimizing unnecessary I/O operations and decoding time for unused images. Here's a short example of how you might use select_files:

import webdataset as wds

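# Decide which tar members to keep, based on the training configuration.
# Assumes a WebDataset version whose constructor accepts select_files, and that
# number_of_aux_images is defined by your configuration.
wanted = ["main.jpg"] + [f"aux{i}.jpg" for i in range(number_of_aux_images)]

def select_files(fname):
    # fname is the member name inside the tar, e.g. "sample0001.aux0.jpg"
    return any(fname.endswith("." + ext) for ext in wanted)

# Create a dataset that only extracts the selected files
dataset = wds.WebDataset("dataset.tar", select_files=select_files)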

This approach is cheaper than decoding all files and discarding the unneeded ones. Note, however, that a tar stream is still read sequentially, so the largest I/O savings come from writing shards that contain only the files you need.


Issue #256

Q: Why does my training program using WebDataset consume so much memory and crash?

A: The memory consumption issue you're experiencing with WebDataset during training is likely due to the shuffle buffer size. WebDataset uses in-memory buffering to shuffle data, and if the buffer size is too large, it can consume a significant amount of memory, especially when dealing with large samples or when running on systems with limited memory. The constants _SHARD_SHUFFLE_SIZE and _SAMPLE_SHUFFLE_SIZE are defined in the training code rather than in WebDataset itself; they set the sizes of the shard and sample shuffle buffers. Reducing these values can help mitigate memory usage issues. For example, you can try setting:

_SHARD_SHUFFLE_SIZE = 1000  # Reduced from 2000
_SAMPLE_SHUFFLE_SIZE = 2500  # Reduced from 5000

Adjust these values based on your system's memory capacity and the size of your dataset. Keep in mind that reducing the shuffle buffer size may affect the randomness of your data shuffling and potentially the training results. It's a trade-off between memory usage and shuffle effectiveness.
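A back-of-the-envelope estimate can help pick a buffer size. The sketch below is hypothetical and assumes that every DataLoader worker runs its own copy of the pipeline, and therefore holds its own sample shuffle buffer. The average sample size of 150 kB is an assumed figure for illustration.

```python
def shuffle_buffer_bytes(buffer_size, avg_sample_bytes, num_workers, procs_per_node=1):
    """Rough upper bound on host memory held by sample shuffle buffers on one node."""
    workers = max(num_workers, 1)  # num_workers=0 still runs one pipeline in the main process
    return buffer_size * avg_sample_bytes * workers * procs_per_node


# 5000-sample buffer, ~150 kB per sample, 8 workers per process, 8 processes per node
estimate = shuffle_buffer_bytes(5000, 150_000, num_workers=8, procs_per_node=8)
print(f"{estimate / 2**30:.1f} GiB")
```

Halving the buffer size, or the number of workers, halves this estimate, which makes the trade-off between memory usage and shuffle quality easier to reason about.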


Issue #249

Q: Should I use WebDataset or TorchData for my data loading in PyTorch?

A: The choice between WebDataset and TorchData depends on your specific needs and the context of your project. WebDataset is still a good choice if you require backwards compatibility or if you need to work without PyTorch. It is also being integrated with other frameworks like Ray, which may be beneficial for certain use cases. However, it's important to note that as of July 2023, active development on TorchData has been paused to re-evaluate its technical design. This means that while TorchData is still usable, it may not receive updates or new features in the near future. If you are starting a new project or are able to adapt to changes, you might want to consider this factor. Here's a simple example of how you might use WebDataset:

import webdataset as wds

# Create a dataset
dataset = wds.WebDataset("path/to/data-{000000..000999}.tar")

# Iterate over the dataset
for sample in dataset:
    image, label = sample["image"], sample["label"]
    # process image and label

And here's how you might use TorchData:

from torchdata.datapipes.iter import FileOpener, IterableWrapper

# Create a data pipeline: open the tar file in binary mode, then iterate over its members
datapipes = FileOpener(IterableWrapper(["path/to/data.tar"]), mode="b").load_from_tar()

# Iterate over the data pipeline
for file_name, file_stream in datapipes:
    data = file_stream.read()  # process the raw bytes of each member

Given the pause in TorchData development, you should consider the stability and future support of the library when making your decision.


Issue #247

Q: How can I load images from nested tar files using webdataset?

A: To load images from nested tar files with webdataset, you can create a custom decoder that handles .tar files using Python's tarfile module. This decoder can be applied to your dataset with the .map() method, which allows you to modify each sample in the dataset. The custom decoder will read the nested tar file from the sample, extract its contents, and add them to the sample dictionary. Here's a short example of how you can implement this:

import io
import tarfile
from webdataset import WebDataset

def expand_tar_files(sample):
    stream = tarfile.open(fileobj=io.BytesIO(sample["tar"]))
    for tarinfo in stream:
        if tarinfo.isfile():
            name = tarinfo.name
            data = stream.extractfile(tarinfo).read()
            sample[name] = data
    return sample

ds = WebDataset("dataset.tar").map(expand_tar_files).decode("...")

In this example, expand_tar_files is a function that takes a sample from the dataset, opens the nested tar file contained within it, and adds each file from the nested tar to the sample. The WebDataset object is then created with the path to the dataset tar file, and the expand_tar_files function is applied to each sample in the dataset.


Issue #246

Q: What is the purpose of .to_tuple() in WebDataset and how does it handle missing files?

A: The .to_tuple() method in WebDataset is used to extract specific fields from a dataset where each sample is a dictionary with keys corresponding to file extensions. This method simplifies the process of preparing data for training by converting dictionaries into tuples, which are more convenient to work with in many machine learning frameworks. When you specify multiple file extensions separated by semicolons, .to_tuple() will return the first file that matches any of the given extensions. If a file with a specified extension is not present in a sample, .to_tuple() will raise an error. To handle optional files, you can use a custom function with .map() that uses the get method to return None if a key is missing, thus avoiding errors and allowing for flexible data structures.

Here's an example of using .map() with a custom function to handle mandatory and optional files:

# Mandatory jpg and txt, optional npy
def make_tuple(sample):
    return sample["jpg"], sample.get("npy"), sample["txt"]

ds = WebDataset(...) ... .map(make_tuple)

And here's how you might use .to_tuple() directly for mandatory files:

ds = WebDataset(...) ... .to_tuple("jpg", "txt")

Issue #244

Q: How can I combine multiple data sources with a specified frequency for sampling from each?

A: To combine multiple data sources with non-integer sampling frequencies, you can use the RandomMix function from the WebDataset library. This function allows you to specify the relative sampling weights as floating-point numbers, which can represent the desired sampling frequency from each dataset. Here's an example of how to use RandomMix to combine two datasets with a specified sampling frequency:

from webdataset import WebDataset, RandomMix

ds1 = WebDataset('path_to_shards_A/{00..99}.tar')
ds2 = WebDataset('path_to_shards_B/{00..99}.tar')
mix = RandomMix([ds1, ds2], [1.45, 1.0])  # Sampling from ds1 1.45 times more frequently than ds2

This will create a mixed dataset where samples from ds1 are drawn approximately 1.45 times more often than samples from ds2.
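The weights do not need to sum to one; what matters is their ratio. The following sketch shows the idea in plain Python. It is not WebDataset's implementation, and the names mixing_probabilities and random_mix are hypothetical: for each item it picks a source at random in proportion to the weights.

```python
import itertools
import random


def mixing_probabilities(weights):
    """Normalize relative weights into probabilities that sum to 1."""
    total = float(sum(weights))
    return [w / total for w in weights]


def random_mix(sources, weights, rng):
    """Yield items from several iterables, picking the source of each item
    at random in proportion to weights. Stops when the picked source is exhausted."""
    iterators = [iter(s) for s in sources]
    while True:
        (it,) = rng.choices(iterators, weights=weights, k=1)
        try:
            yield next(it)
        except StopIteration:
            return


probs = mixing_probabilities([1.45, 1.0])  # roughly 59% from the first source, 41% from the second
mixed = random_mix([itertools.repeat("A"), itertools.repeat("B")], [1.45, 1.0], random.Random(0))
first_items = list(itertools.islice(mixed, 10_000))
share_a = first_items.count("A") / len(first_items)
```

Over many samples, the share of items from the first source approaches 1.45 / (1.45 + 1.0).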


Issue #239

Q: Can I filter a WebDataset to select only a subset of categories?

A: Yes, you can filter a WebDataset to select only a subset of categories by using a map function that returns None for samples you want to drop. This is efficient as long as the subset is not too small; otherwise, most of the data read from the shards is discarded, which wastes I/O. For very small subsets, it's recommended to create a new WebDataset. Here's a simple example of how to filter categories:

def select(sample):
    if sample["cls"] in [0, 3, 9]:  # Replace with desired categories
        return sample
    else:
        return None

dataset = wds.WebDataset(...).decode().map(select)

This approach works well when the number of classes is much larger than the number of shards, and you're not discarding a significant portion of the data. If you find yourself discarding a large percentage of the data, consider creating a new WebDataset for efficiency.


Issue #237

Q: How does WebDataset handle filenames with multiple periods when extracting keys?

A: WebDataset splits each file name at the first period of its base name: the part before that period becomes the sample key, and everything after it becomes the extension. Additional periods in what you think of as the base filename therefore lead to unexpected keys. This is by design to support filenames with multiple extensions, such as .seg.jpg. It's important to follow this convention when creating datasets to avoid issues in downstream processing. If you have filenames with multiple periods, consider renaming them before creating the dataset. For matching files, you can use glob patterns like *.mp3 to ensure you're working with the correct file type.
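The effect on keys can be illustrated with a small stand-alone function. This is an approximation written for this FAQ, not the library's own code, and the name split_key is hypothetical. It splits a path at the first period of the file's base name.

```python
import re

# Everything up to the first period of the base name is the key;
# everything after that period is the extension.
_KEY_EXT = re.compile(r"^((?:.*/)?[^/.]+)\.([^/]*)$")


def split_key(path):
    """Split a tar member path into (key, extension), or (None, None) if it has no extension."""
    match = _KEY_EXT.match(path)
    if match is None:
        return None, None
    return match.group(1), match.group(2)


print(split_key("images/cat.seg.jpg"))  # ('images/cat', 'seg.jpg')
print(split_key("my.song.v2.mp3"))      # ('my', 'song.v2.mp3')
```

In the second case the key is only "my", so two files such as my.song.v2.mp3 and my.other.mp3 would be grouped into the same sample, which is usually not what was intended.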

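# Keep samples that contain an .mp3 entry; sample keys hold the text after the first period, e.g. "mp3" or "b.c.mp3"
import fnmatch
import webdataset as wds
dataset = wds.WebDataset("dataset.tar").select(lambda s: any(fnmatch.fnmatch("." + k, "*.mp3") for k in s))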

Issue #236

Q: How does webdataset handle the conversion of tensors to different file formats like .jpg and .npy?

A: In webdataset, the conversion of tensors to specific file formats is determined by the file extension you specify in the key when writing the data using ShardWriter. The format is not inferred from the data itself; the extension selects the encoder that is applied to the tensor. When reading the data, you can decode the files into tensors using the appropriate arguments. Here's a short example of how to write a tensor as different file formats:

from webdataset import ShardWriter

writer = ShardWriter(...)

sample = {}
sample["__key__"] = "dataset/sample00003"
sample["image.jpg"] = some_tensor  # Will be saved as a JPEG file
sample["image.npy"] = some_tensor  # Will be saved as an NPY file

writer.write(sample)

When you write a sample with {"__key__": "xyz", "image.jpg": some_tensor}, a JPEG file named xyz.image.jpg is created. Conversely, if you write {"__key__": "xyz", "image.npy": some_tensor}, an NPY file named xyz.image.npy is created.


Issue #233

Q: How do I ensure that WebDataset correctly splits shards across multiple nodes and workers?

A: When using WebDataset for distributed training across multiple nodes and workers, it's important to use the split_by_node and split_by_worker functions to ensure that each node and worker processes a unique subset of the data. The detshuffle function can be used for deterministic shuffling of shards before splitting. Here's a minimal example of how to set up the dataset pipeline for multi-node training:

import webdataset as wds

dataset = wds.DataPipeline(
    wds.SimpleShardList("source-{000000..000999}.tar"),
    wds.detshuffle(),
    wds.split_by_node,
    wds.split_by_worker,
)

for idx, item in enumerate(iter(dataset)):
    if idx < 2:  # Just for demonstration
        print(f"item: {item}")

Make sure you are using a recent version of WebDataset that supports these features. If you encounter any issues, check the version and consider updating to the latest release.
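Conceptually, splitting amounts to striding through the shard list twice: once by node rank and once by worker id. The sketch below shows that idea in plain Python so the resulting assignment can be inspected. It is an illustration rather than WebDataset's code, and split_shards is a hypothetical name.

```python
from itertools import islice


def split_shards(shards, rank, world_size, worker_id, num_workers):
    """Give each (rank, worker) pair a disjoint slice of the shard list."""
    for_node = islice(shards, rank, None, world_size)             # every world_size-th shard
    return list(islice(for_node, worker_id, None, num_workers))   # then every num_workers-th


shards = [f"source-{i:06d}.tar" for i in range(16)]
assignment = {
    (rank, worker): split_shards(shards, rank, 2, worker, 2)
    for rank in range(2)
    for worker in range(2)
}
```

With 16 shards, 2 ranks and 2 workers per rank, each (rank, worker) pair receives 4 shards and no shard is assigned twice. If there are fewer shards than rank-worker pairs, some workers receive nothing, which is why the shard count should be at least as large as the total number of workers.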


Issue #227

Q: How can I use Apache Beam to write data to a WebDataset tar file for large-scale machine learning datasets?

A: Apache Beam is a powerful tool for parallel data processing, which can be used to build large datasets for machine learning. When dealing with datasets larger than 10TB and requiring complex preprocessing, you can use Apache Beam to process and write the data into a WebDataset tar file format. Below is a simplified example of how you might set up your Beam pipeline to write to a WebDataset. This example assumes you have a function preprocess_sample that takes a sample and performs the necessary preprocessing:

import uuid

import apache_beam as beam
from webdataset import TarWriter

def write_shard(samples):
    # Write one batch of samples to its own uniquely named shard file.
    # Assuming 'preprocess_sample' returns a dict with "__key__" plus extension keys.
    fname = f"output-shard-{uuid.uuid4().hex}.tar"
    with TarWriter(fname) as sink:
        for sample in samples:
            sink.write(preprocess_sample(sample))
    return fname

# Set up your Apache Beam pipeline
with beam.Pipeline() as pipeline:
    (
        pipeline
        | 'Read Data' >> beam.io.ReadFromSomething(...)  # Replace with your data source
        | 'Batch' >> beam.BatchElements(min_batch_size=1000, max_batch_size=1000)
        | 'Process and Write' >> beam.Map(write_shard)
    )

Each batch (up to 1000 samples here) becomes one shard, so the batch size plays the role that maxcount plays in ShardWriter, and the random shard names keep parallel workers from overwriting each other's files. Opening a ShardWriter inside a per-element function would not work, because every call would start again at the first shard and overwrite it. You will also need to handle the copying of the temporary shard files to your destination bucket as needed.


Issue #225

Q: How can I ensure that Distributed Data Parallel (DDP) training with WebDataset doesn't hang due to uneven data distribution across nodes?

A: When using WebDataset for DDP training, it's important to ensure that all nodes receive the same number of samples to prevent hanging during synchronization. One effective method is to create a number of shards that is divisible by the total number of workers and ensure each shard contains the same number of samples. Assign each worker the same number of shards to achieve exact epochs with no resampling, duplication, or missing samples. If the dataset cannot be evenly divided, you can use resampled=True to generate an infinite stream of samples, and set an epoch length using with_epoch. This approach allows for synchronization across workers even if the dataset size is not divisible by the number of workers. Here's an example of setting an epoch length:

from webdataset import WebDataset

dataset = WebDataset(urls, resampled=True).with_epoch(epoch_length)

For validation, where you want to avoid arbitrary epoch lengths, you can drop samples from the end of the validation set to make its size divisible by the world size. This can be done using TorchData as follows:

from torch.utils.data import DataLoader
import torch.distributed

dataset = dataset.batch(torch.distributed.get_world_size(), drop_last=True)
dataset = dataset.unbatch()
dataset = dataset.sharding_filter()

Remember to use the sharding_filter to ensure that each process only sees its own subset of the data.


Issue #219

Q: What should I use instead of ShardList in webdataset v2, and how do I specify a splitter?

A: In webdataset v2, the ShardList class has been renamed to SimpleShardList. If you encounter an AttributeError stating that the module webdataset has no attribute ShardList, you should replace it with SimpleShardList. SimpleShardList only produces the list of shard URLs; splitting across nodes and workers is done by separate pipeline stages (wds.split_by_node and wds.split_by_worker), and the WebDataset class takes a nodesplitter argument in place of splitter. Here's how you can update your code to reflect these changes:

import braceexpand
import webdataset as wds

urls = list(braceexpand.braceexpand("dataset-{000000..000999}.tar"))
dataset = wds.DataPipeline(
    wds.SimpleShardList(urls),
    wds.split_by_node,
    wds.split_by_worker,
    wds.tarfile_to_samples(),  # opens URLs, expands tar files, groups files by key
)

If you are using WebDataset and encounter a TypeError regarding an unexpected keyword argument splitter, ensure that you are using the correct argument name nodesplitter instead.


Issue #216

Q: Can I use ShardWriter to write directly to a cloud storage URL like Google Cloud Storage?

A: The ShardWriter from the webdataset library is primarily designed to write shards to a local disk, and then these shards can be copied to cloud storage. Writing directly to cloud storage is not the default behavior because it can be less efficient and more error-prone due to network issues. However, if you have a large dataset that cannot be stored locally, you can modify the ShardWriter code to write directly to a cloud URL by changing the line where the TarWriter is instantiated. Here's a short example of the modification:

# Original line in ShardWriter
self.tarstream = TarWriter(open(self.fname, "wb"), **self.kw)

# Modified line to write directly to a cloud URL
self.tarstream = TarWriter(self.fname, **self.kw)

Please note that this is a workaround and may not be officially supported. It's recommended to test thoroughly to ensure data integrity and handle any potential exceptions related to network issues.
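
The reason such a change can work at all is that a tar archive can be written strictly sequentially, so the target only needs to accept writes and never has to seek. The sketch below illustrates this with Python's standard tarfile module rather than with TarWriter itself; the io.BytesIO object is a stand-in for a remote stream, and write_samples is a hypothetical helper.

```python
import io
import tarfile


def write_samples(stream, samples):
    """Write (name, bytes) pairs to a writable stream as a tar archive."""
    # Mode "w|" produces a plain tar stream and never seeks backwards.
    with tarfile.open(fileobj=stream, mode="w|") as tar:
        for name, data in samples:
            info = tarfile.TarInfo(name=name)
            info.size = len(data)
            tar.addfile(info, io.BytesIO(data))


# BytesIO stands in for a remote object opened for writing.
remote = io.BytesIO()
write_samples(remote, [("sample0.txt", b"hello"), ("sample0.cls", b"3")])

# Read the archive back sequentially to confirm its contents.
remote.seek(0)
with tarfile.open(fileobj=remote, mode="r|") as tar:
    names = [member.name for member in tar]
```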


Issue #212

Q: Does WebDataset download all shards at once, and how does caching affect the download behavior?

A: WebDataset accesses shards individually and handles data in a streaming fashion by default, meaning that shards are not cached locally unless caching is explicitly enabled. When caching is enabled, each shard is downloaded completely before being used, which can block training until the download is finished. This behavior contrasts with the streaming mode, where training can start as soon as the first batch is ready. The caching mechanism does not currently download shards in parallel with training, which can lead to delays when starting the training process. To change the local cache name when using pipe:s3, you can override the url_to_name argument to map shard names to cache file names as desired.

Here's an example of how to override the url_to_name function:

import webdataset as wds

def custom_url_to_name(url):
    # Keep only the last path component so that the "pipe:s3 " prefix and the
    # host do not end up in the cache filename, then shorten and rename it
    return url.split("/")[-1].replace("dataset-", "").replace(".tar", ".cache")

dataset = wds.WebDataset("pipe:s3 http://url/dataset-{001..099}.tar", url_to_name=custom_url_to_name)


Issue #211

Q: How can I write to a remote location using ShardWriter?

A: ShardWriter is designed to write to local disk for simplicity and reliability, but it provides a hook for uploading data to a remote location. You can define a function that handles the upload process and then pass this function to the post parameter of ShardWriter. Here's a short example of how to use this feature:

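import os
import subprocess

from webdataset import ShardWriter

def upload_shard(fname):
    # Copy the finished shard to the bucket; check=True raises if gsutil fails,
    # so the local file is only deleted after a successful upload
    subprocess.run(["gsutil", "cp", fname, "gs://mybucket"], check=True)
    os.unlink(fname)

with ShardWriter(..., post=upload_shard) as writer: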
    # Your code to add data to the writer
    ...

This approach allows you to have control over the upload process and handle any errors that may occur during the transfer to the remote storage.
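
One way to handle transfer errors is to wrap the upload in a retry loop before handing it to the post parameter. The following is a minimal sketch under stated assumptions: make_upload_hook is a hypothetical helper, the upload argument stands for whatever copy command or client library you actually use, and flaky_upload only simulates a network failure for the demonstration.

```python
import time


def make_upload_hook(upload, retries=3, delay=0.0):
    """Wrap an upload callable so that it is retried before giving up."""
    def post(fname):
        last_error = None
        for _ in range(retries):
            try:
                upload(fname)
                return
            except OSError as exc:
                # Remember the failure and wait before the next attempt.
                last_error = exc
                time.sleep(delay)
        raise RuntimeError(
            f"upload of {fname} failed after {retries} attempts"
        ) from last_error
    return post


# Demonstration with a fake upload that fails once and then succeeds.
calls = []

def flaky_upload(fname):
    calls.append(fname)
    if len(calls) == 1:
        raise OSError("simulated network error")

hook = make_upload_hook(flaky_upload, retries=3)
hook("shard-000000.tar")
```

The resulting hook takes a single filename, so it could be passed as the post argument in the same way as upload_shard.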


Issue #210

Q: How does the default_collation_fn work in WebDataset when it seems to expect a list or tuple, but the documentation suggests it should handle a collection of samples as dictionaries?

A: The confusion arises from the mismatch between the documentation and the actual implementation of default_collation_fn. The function is designed to take a batch of samples and collate them into a single batch for processing. However, the current implementation of default_collation_fn in WebDataset does not handle dictionaries directly. Instead, it expects each sample in the batch to be a list or tuple. If you have a batch of dictionaries, you would need to convert them into a list or tuple format before using default_collation_fn. Alternatively, you can use torch.utils.data.default_collate from PyTorch 1.11 or later, which can handle dictionaries, or you can provide a custom collate function that handles dictionaries. Here's an example of a custom collate function that could handle a list of dictionaries:

def custom_collate_fn(batch):
    # Assuming each element in batch is a dictionary
    collated_batch = {}
    for key in batch[0].keys():
        collated_batch[key] = [d[key] for d in batch]
    return collated_batch

You can then pass this custom_collate_fn to your data loader.
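
The other option mentioned above, converting dictionary samples to tuples before collation, might look like the sketch below. The helpers dicts_to_tuples and transpose are hypothetical and written in plain Python; they only show the reshaping step, not WebDataset's own collation.

```python
def dicts_to_tuples(batch, keys):
    """Convert a list of dict samples into a list of tuples in a fixed key order."""
    return [tuple(sample[key] for key in keys) for sample in batch]


def transpose(batch):
    """Turn a list of tuples into one list per field."""
    return [list(column) for column in zip(*batch)]


batch = [{"image": "img0", "label": 0}, {"image": "img1", "label": 1}]
as_tuples = dicts_to_tuples(batch, keys=("image", "label"))
columns = transpose(as_tuples)
```

Fixing the key order explicitly matters because a tuple-based collation function identifies fields by position rather than by name.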


Issue #209

Q: How can I ensure each batch contains only one description per image when using webdatasets?

A: To ensure that each batch contains only one description per image in webdatasets, you can create a custom transformation function that acts as a filter or collate function. This function can be composed with your dataset to enforce the batching rule. You can use buffers or other conditional logic within your transformation to manage the batching process. Here's a simple example of how you might start implementing such a transformation:

batch_size = 32  # example value; set this to your actual batch size

def unique_image_collate(src):
    buffer = {}
    for sample in src:
        image_id = sample['image_id']
        # A further description of an image already in the buffer is skipped
        if image_id not in buffer:
            buffer[image_id] = sample
            if len(buffer) == batch_size:
                yield list(buffer.values())
                buffer.clear()
    # Emit the final, possibly smaller, batch
    if buffer:
        yield list(buffer.values())

dataset = dataset.compose(unique_image_collate)

This function collects samples in a buffer until it has a batch's worth of unique images, then yields that batch and clears the buffer for the next batch. A sample whose image is already in the buffer is discarded, and any samples left over at the end of the stream are yielded as a final, smaller batch. Adjust either behavior if you need to keep every description or require fixed-size batches.
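
If discarding extra descriptions is not acceptable, a variation is to defer them to a later batch. The sketch below is one possible way to do that in plain Python, not the approach from the original issue; unique_image_batches and _fill are hypothetical names, and the samples are assumed to carry an image_id field as above.

```python
def _fill(batch, samples, batch_size):
    """Move samples into batch while keeping image_ids unique; return the rest."""
    leftover = []
    for sample in samples:
        if len(batch) < batch_size and sample["image_id"] not in batch:
            batch[sample["image_id"]] = sample
        else:
            leftover.append(sample)
    return leftover


def unique_image_batches(src, batch_size):
    """Yield batches with at most one sample per image_id, deferring duplicates."""
    batch, deferred = {}, []
    for sample in src:
        # Offer previously deferred samples first, then the new one.
        deferred = _fill(batch, deferred + [sample], batch_size)
        while len(batch) == batch_size:
            yield list(batch.values())
            batch = {}
            deferred = _fill(batch, deferred, batch_size)
    # End of stream: flush the partial batch, then any deferred samples.
    while batch:
        yield list(batch.values())
        batch = {}
        deferred = _fill(batch, deferred, batch_size)


pairs = [("a", "a1"), ("a", "a2"), ("b", "b1"), ("c", "c1"), ("b", "b2")]
samples = [{"image_id": i, "caption": c} for i, c in pairs]
batches = list(unique_image_batches(samples, batch_size=2))
captions = [[s["caption"] for s in b] for b in batches]
```

Deferred samples are re-offered on every step, which is simple but not efficient when many descriptions share one image.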


Issue #201

Q: How can I efficiently subsample a large dataset without slowing down iteration speed?

A: When dealing with large datasets, such as LAION 400M, and needing to subsample based on metadata, there are several strategies to maintain high I/O performance. If the subset is small and static, it's best to create a new dataset ahead of time. This can be done using a WebDataset/TarWriter pipeline or with tarp proc ... | tarp split ... commands, potentially parallelizing the process with tools like ray. If dynamic selection is necessary, consider splitting the dataset into shards by the categories of interest. This approach avoids random file accesses, which can significantly slow down data pipelines. Here's a simple example of creating a subset using tarp:

tarp proc mydataset.tar -c 'if sample["metadata"] in metadata_list: yield sample'
tarp split -o subset-%06d.tar --size=1e9

Remember to perform filtering before any heavy operations like decoding or augmentation to avoid unnecessary processing.
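
The ordering advice can be illustrated with a small generator pipeline in plain Python. This is a sketch under assumptions: the field names json and label, the sample keys, and the expensive_decode stand-in are all made up for the example, and the samples are ordinary dictionaries rather than a real WebDataset stream.

```python
import json


def select_by_metadata(samples, wanted_labels):
    """Yield only samples whose raw JSON metadata has a label in wanted_labels."""
    for sample in samples:
        # Parsing a small JSON record is cheap compared with decoding an image.
        metadata = json.loads(sample["json"])
        if metadata["label"] in wanted_labels:
            yield sample


decoded = []

def expensive_decode(sample):
    """Stand-in for image decoding or augmentation; records what it processed."""
    decoded.append(sample["__key__"])
    return sample


raw = [
    {"__key__": "s0", "json": b'{"label": "cat"}', "jpg": b"raw-bytes-0"},
    {"__key__": "s1", "json": b'{"label": "dog"}', "jpg": b"raw-bytes-1"},
    {"__key__": "s2", "json": b'{"label": "cat"}', "jpg": b"raw-bytes-2"},
]
subset = [expensive_decode(s) for s in select_by_metadata(raw, {"cat"})]
```

Because the filter runs first, the costly step is applied only to the two samples that survive it.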


Issue #196

Q: How can I speed up subsampling from a tar file when using WebDataset?

A: When working with WebDataset, it's important to remember that it is optimized for streaming data and does not support efficient random access within tar files. To speed up subsampling, you should avoid using very small probabilities with rsample as it requires reading the entire stream. Instead, consider using more shards and applying rsample to the shards rather than individual samples. This approach avoids the overhead of sequential reading. Additionally, some storage servers like AIStore can perform server-side sampling, which can be more efficient as they can use random access.

# Example of using rsample with shards
dataset = WebDataset("dataset-{0000..9999}.tar").rsample(0.1)
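
Selecting whole shards can also be done before the dataset is constructed, by sampling the list of shard names. The sketch below uses only the standard library; subsample_shards is a hypothetical helper, and the shard names simply mirror the pattern used above.

```python
import random


def subsample_shards(shards, fraction, seed=0):
    """Pick a random subset of whole shards instead of filtering single samples."""
    rng = random.Random(seed)
    count = max(1, round(len(shards) * fraction))
    return sorted(rng.sample(shards, count))


shards = [f"dataset-{i:04d}.tar" for i in range(10000)]
selected = subsample_shards(shards, fraction=0.1)
```

The selected list could then be passed to the dataset in place of the brace pattern, so that the unselected shards are never opened.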

Issue #194

Q: How should I balance dataset elements across DDP nodes when using WebDataset?

A: When using WebDataset with Distributed Data Parallel (DDP) in PyTorch, you may encounter situations where the dataset is not evenly distributed across the workers. To address this, you can use the .repeat() method in combination with .with_epoch() to ensure that each worker processes the same number of batches. The .repeat(2) method is used to repeat the dataset twice, which should be sufficient for most cases. If the dataset is highly unbalanced, you may need to adjust this number. The .with_epoch(n) method is used to limit the number of samples processed in an epoch to n, where n is typically set to the total number of samples divided by the batch size. This combination ensures that each epoch has a consistent size across workers, while also handling any imbalance in the number of shards or samples per worker.

Here's an example of how to use these methods:

batch_size = 64
epoch_size = 1281237  # Total number of samples in the dataset
loader = wds.WebLoader(dataset, num_workers=4)
loader = loader.repeat(2).with_epoch(epoch_size // batch_size)

This approach allows for a balanced distribution of data across DDP nodes, with the caveat that some batches may be missing or repeated. It's a trade-off between perfect balance and resource usage.
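
The epoch length passed to with_epoch is simple integer arithmetic, sketched below. Dividing by the world size is an assumption that is not stated above: it applies when every rank builds its own loader and should draw only its share of the batches. The helper name batches_per_epoch is hypothetical.

```python
def batches_per_epoch(num_samples, batch_size, world_size=1):
    """Number of batches each rank should draw so that all ranks stay in step."""
    # Integer division drops the final partial batch on every rank.
    return num_samples // (batch_size * world_size)


single = batches_per_epoch(1281237, batch_size=64)
per_rank = batches_per_epoch(1281237, batch_size=64, world_size=8)
```

For the numbers used above this gives 20019 batches for a single loader and 2502 batches per rank across 8 ranks.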


Issue #185

Q: How can I include the original file name in the metadata dictionary when iterating through a WebDataset?

A: When working with WebDataset, you can include the original file name in the metadata dictionary by defining a function that extracts the __key__ from the sample and adds it to the metadata. You then apply this function using the .map() method in your pipeline. Here's a short example of how to define and use such a function:

def add_filename_to_metadata(sample):
    sample["metadata"]["filename"] = sample["__key__"]
    return sample

# Add this to your pipeline after renaming the keys
pipeline.append(wds.map(add_filename_to_metadata))

This function should be added to the pipeline after the renaming step to ensure that the metadata key is already present in the sample dictionary.
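
If some samples might lack a metadata entry, a slightly more defensive variant creates the dictionary on demand. This is a sketch on plain dictionaries; the sample keys and the width field are made up for the demonstration.

```python
def add_filename_to_metadata(sample):
    """Copy the sample key into the metadata dict, creating the dict if needed."""
    # setdefault avoids a KeyError when a sample has no "metadata" entry yet.
    metadata = sample.setdefault("metadata", {})
    metadata["filename"] = sample["__key__"]
    return sample


with_meta = add_filename_to_metadata(
    {"__key__": "images/0001", "metadata": {"width": 640}}
)
without_meta = add_filename_to_metadata({"__key__": "images/0002"})
```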


Issue #177

Q: How can I resume training from a specific step without iterating over unused data when using WebDataset?

A: When using WebDataset for training with large datasets, it's common to want to resume training from a specific step without loading all the previous data into memory. WebDataset provides a feature for this scenario through shard resampling. By setting resampled=True or using the wds.resampled pipeline stage, you can ensure that you get the same training statistics when restarting your job without the need to skip samples manually. This approach is recommended over trying to implement "each sample exactly once per epoch," which can be complex and environment-dependent.

Here's a short example of how you might use the resampled option:

from webdataset import WebDataset

dataset = WebDataset(urls, resampled=True)

And here's how you might use the wds.resampled pipeline stage:

import webdataset as wds

dataset = wds.WebDataset(urls).pipe(wds.resampled)


Issue #172

Q: Why does the detshuffle epoch count not increment across epochs when using WebDataset?

A: The issue with detshuffle not incrementing the epoch count across epochs is likely due to the interaction between the DataLoader's worker process management and the internal state of the detshuffle. When persistent_workers=False, the DataLoader creates new worker processes each epoch, which do not retain the state of the detshuffle instance. This results in the detshuffle epoch count resetting each time. To maintain the state across epochs, you can set persistent_workers=True in the DataLoader. Alternatively, you can manage the epoch count externally and pass it to detshuffle if needed. Here's a short example of how to set persistent_workers:

from torch.utils.data import DataLoader

# Assuming 'dataset' is your WebDataset instance.
# persistent_workers=True requires num_workers > 0.
loader = DataLoader(dataset, num_workers=4, persistent_workers=True)

If you need to manage the epoch count externally, you could use an environment variable or another mechanism to pass the epoch count to detshuffle. However, this approach is less clean and should be used with caution, as it may introduce complexity and potential bugs into your code.
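
The idea of an externally managed epoch count can be sketched without detshuffle: derive the shuffle order from a seed and an epoch number that the training loop publishes. Everything here is an assumption for illustration, including the MY_EPOCH environment variable and the helper names epoch_shuffle and current_epoch; it does not show how detshuffle itself is implemented.

```python
import os
import random


def epoch_shuffle(items, epoch, seed=0):
    """Return items in an order that depends only on seed and epoch."""
    rng = random.Random(f"{seed}-{epoch}")
    shuffled = list(items)
    rng.shuffle(shuffled)
    return shuffled


def current_epoch(default=0):
    """Read the epoch number from the hypothetical MY_EPOCH environment variable."""
    return int(os.environ.get("MY_EPOCH", default))


shards = [f"shard-{i:03d}" for i in range(8)]
# The training loop would set this before starting each epoch.
os.environ["MY_EPOCH"] = "3"
order = epoch_shuffle(shards, current_epoch())
```

Since the order is a pure function of the seed and the epoch, a freshly started worker process reproduces it without any carried-over state.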


Issue #171

Q: I'm getting an ImportError when trying to import PytorchShardList from webdataset. What should I do?

A: The PytorchShardList class has been removed in recent versions of the webdataset package. It was available in version 0.1 of webdataset; in later versions it has likely been replaced by SimpleShardList. To resolve the ImportError, update your import statement to use the new class name. Here's how you can import SimpleShardList:

from webdataset import SimpleShardList

If SimpleShardList does not meet your requirements, you may need to check the documentation for the version of webdataset you are using to find the appropriate replacement or consider downgrading to the version that contains PytorchShardList.


Issue #170

Q: How do I use glob patterns with WebDataset to read data from Google Cloud Storage (GCS)?

A: WebDataset does not natively support glob patterns due to the lack of a consistent API for globbing across different object stores. To use glob patterns with files stored in GCS, you need to manually resolve the glob pattern using gsutil and then pass the list of shards to WebDataset. Here's an example of how to do this in Python:

import os
import webdataset as wds

# Use gsutil to resolve the glob pattern and get the list of shard URLs
shard_list = [shard.strip() for shard in os.popen("gsutil ls gs://BUCKET/PATH/training_*.tar").readlines()]

# Create the WebDataset with the resolved list of shard URLs
train_data = wds.WebDataset(shard_list, shardshuffle=True, repeat=True)

This approach ensures that you get the expected behavior when reading data from shards that match a glob pattern in GCS. Remember to install gsutil and authenticate with GCS before running the code.
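
If you already have a listing of object names from another source, such as a storage client library, the pattern matching can be done in Python with the standard fnmatch module. This is a sketch only: match_shards is a hypothetical helper and the listing below is made up.

```python
import fnmatch


def match_shards(object_names, pattern):
    """Return the object names that match a shell-style glob pattern, sorted."""
    return sorted(
        name for name in object_names if fnmatch.fnmatchcase(name, pattern)
    )


# Made-up listing; in practice this would come from your storage tooling.
listing = [
    "gs://BUCKET/PATH/training_001.tar",
    "gs://BUCKET/PATH/training_000.tar",
    "gs://BUCKET/PATH/validation_000.tar",
    "gs://BUCKET/PATH/training_notes.txt",
]
shard_list = match_shards(listing, "gs://BUCKET/PATH/training_*.tar")
```

Note that in fnmatch the * wildcard also matches the / character, so a pattern can match objects in nested paths.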