WebDataset API

Fluid Interfaces

The FluidInterface mixin class provides a fluent interface for chaining operations on datasets; most of the operations documented below are defined on it. with_epoch sets the epoch size (the number of samples per epoch), effectively applying an itertools.islice to the dataset, restarting it whenever it is exhausted.
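
A minimal sketch of the fluent style (the shard pattern is illustrative):

import webdataset as wds

# Each call returns an updated pipeline, so stages can be chained.
dataset = (
    wds.WebDataset("shards/train-{000000..000099}.tar", shardshuffle=100)
    .shuffle(1000)               # sample-level shuffle buffer
    .decode("rgb")               # decode images to float RGB arrays
    .to_tuple("jpg;png", "cls")  # convert dict samples to tuples
    .batched(16)                 # collate into batches of 16
)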

webdataset.WebDataset

Bases: DataPipeline, FluidInterface

Create a WebDataset pipeline for efficient data loading.

This class sets up a data pipeline for loading and processing WebDataset-format data. It handles URL generation, shard shuffling, caching, and sample grouping.

Parameters:
  • urls

    The source URLs or specifications for the dataset.

  • handler

    Function to handle exceptions. Defaults to reraise_exception.

  • mode

    The mode of operation. Defaults to None.

  • resampled

    Whether to use resampled mode. Defaults to False.

  • repeat

    Whether to repeat the dataset. Defaults to False.

  • shardshuffle

    The number of shards to shuffle, or None. Defaults to None.

  • cache_size

    The size of the cache in bytes. Defaults to -1 (unlimited).

  • cache_dir

    The directory to use for caching. Defaults to None.

  • url_to_name

    Function to convert URLs to cache names. Defaults to pipe_cleaner.

  • detshuffle

    Whether to use deterministic shuffling. Defaults to False.

  • nodesplitter

    Function to split data by node. Defaults to single_node_only.

  • workersplitter

    Function to split data by worker. Defaults to split_by_worker.

  • select_files

    Function to select files from tar archives. Defaults to None.

  • rename_files

    Function to rename files from tar archives. Defaults to None.

  • empty_check

    Whether to check for empty datasets. Defaults to True.

  • verbose

    Whether to print verbose output. Defaults to False.

  • seed

    Random seed for shuffling. Defaults to None.

Raises:
  • ValueError

    If the cache directory does not exist or if the URL type is not supported.
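
As a sketch (shard URLs illustrative), the two most common configurations are sequential iteration over a fixed shard list and infinite resampling:

import webdataset as wds

# Sequential: iterate the listed shards in order, once per epoch.
ds = wds.WebDataset("data/train-{000000..000009}.tar", shardshuffle=False)

# Resampled: draw shards with replacement, yielding an unbounded stream;
# combine with with_epoch to impose nominal epoch boundaries.
ds = wds.WebDataset("data/train-{000000..000009}.tar", resampled=True)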

Source code in webdataset/compat.py
class WebDataset(DataPipeline, FluidInterface):
    """Create a WebDataset pipeline for efficient data loading.

    This class sets up a data pipeline for loading and processing WebDataset-format data.
    It handles URL generation, shard shuffling, caching, and sample grouping.

    Args:
        urls: The source URLs or specifications for the dataset.
        handler: Function to handle exceptions. Defaults to reraise_exception.
        mode: The mode of operation. Defaults to None.
        resampled: Whether to use resampled mode. Defaults to False.
        repeat: Whether to repeat the dataset. Defaults to False.
        shardshuffle: The number of shards to shuffle, or None. Defaults to None.
        cache_size: The size of the cache in bytes. Defaults to -1 (unlimited).
        cache_dir: The directory to use for caching. Defaults to None.
        url_to_name: Function to convert URLs to cache names. Defaults to pipe_cleaner.
        detshuffle: Whether to use deterministic shuffling. Defaults to False.
        nodesplitter: Function to split data by node. Defaults to single_node_only.
        workersplitter: Function to split data by worker. Defaults to split_by_worker.
        select_files: Function to select files from tar archives. Defaults to None.
        rename_files: Function to rename files from tar archives. Defaults to None.
        empty_check: Whether to check for empty datasets. Defaults to True.
        verbose: Whether to print verbose output. Defaults to False.
        seed: Random seed for shuffling. Defaults to None.

    Raises:
        ValueError: If the cache directory does not exist or if the URL type is not supported.
    """

    def __init__(
        self,
        urls,
        handler=reraise_exception,
        mode=None,
        resampled=False,
        repeat=False,
        shardshuffle=None,
        cache_size=-1,
        cache_dir=None,
        url_to_name=cache.pipe_cleaner,
        detshuffle=False,
        nodesplitter=shardlists.single_node_only,
        workersplitter=shardlists.split_by_worker,
        select_files=None,
        rename_files=None,
        empty_check=True,
        verbose=False,
        seed=None,
    ):
        super().__init__()
        if resampled:
            mode = "resampled"
        if mode == "resampled" and shardshuffle not in (False, None):
            warnings.warn(
                "WebDataset(shardshuffle=...) is ignored for resampled datasets"
            )
        elif shardshuffle is None:
            warnings.warn(
                "WebDataset(shardshuffle=...) is None; set explicitly to False or a number"
            )
        if shardshuffle is True:
            warnings.warn(
                "set WebDataset(shardshuffle=...) to a positive integer or 0 or False"
            )
            shardshuffle = 100
        args = SimpleNamespace(**locals())
        self.seed = (
            os.environ.get("WDS_SEED", random.randint(0, 1000000))
            if seed is None
            else seed
        )
        self.update_cache_info(args)

        # first, we add a generator for the urls to be used
        # this generates a stream of dict(url=...)
        self.create_url_iterator(args)

        # split by node (for distributed processing)
        if nodesplitter is not None:
            self.append(nodesplitter)

        # split by worker (for DataLoader)
        if workersplitter:
            self.append(workersplitter)

        # add a shard shuffler
        if args.shardshuffle is not None:
            if args.detshuffle:
                self.append(filters.detshuffle(args.shardshuffle, seed=self.seed))
            else:
                self.append(filters.shuffle(args.shardshuffle, seed=self.seed))

        # next, we select a URL opener, either with or without caching
        # this generates a stream of dict(url=..., stream=...)
        if cache_dir is None or cache_size == 0:
            opener = cache.StreamingOpen(handler=handler)
        else:
            opener = cache.FileCache(
                cache_dir=cache_dir, cache_size=cache_size, handler=handler
            )
        self.append(opener)

        # now we need to open each stream and read the tar files contained in it
        # this generates a stream of dict(fname=..., data=...) objects
        expander = pipelinefilter(tar_file_expander)
        self.append(
            expander(
                handler=handler, select_files=select_files, rename_files=rename_files
            )
        )

        # finally, the files need to be grouped into samples
        # this generates a stream of dict(__key__=..., ...=...) objects
        grouper = pipelinefilter(group_by_keys)
        self.append(grouper(handler=handler))

        # check for empty datasets
        if empty_check:
            self.append(check_empty)

    def update_cache_info(self, args):
        """Update cache information based on arguments and environment variables.

        Args:
            args: A SimpleNamespace object containing the arguments.

        Raises:
            ValueError: If the specified cache directory does not exist.
        """
        args.cache_size = int(os.environ.get("WDS_CACHE_SIZE", args.cache_size))
        args.cache_dir = os.environ.get("WDS_CACHE", args.cache_dir)
        if args.cache_dir is not None:
            args.cache_dir = os.path.expanduser(args.cache_dir)
            if not os.path.exists(args.cache_dir):
                raise ValueError(f"cache directory {args.cache_dir} does not exist")

    def create_url_iterator(self, args):
        """Create an appropriate URL iterator based on the input type.

        This method determines the type of URL input and creates the corresponding
        iterator for the dataset.

        Args:
            args: A SimpleNamespace object containing the arguments.

        Raises:
            ValueError: If the URL type is not supported or implemented.
        """
        urls = args.urls

        # .yaml specification files
        if isinstance(urls, str) and (urls.endswith(".yaml") or urls.endswith(".yml")):
            with open(args.urls) as stream:
                spec = yaml.safe_load(stream)
            assert "datasets" in spec
            self.append(shardlists.MultiShardSample(spec))
            return

        # .yaml specifications already loaded as dictionaries
        if isinstance(args.urls, dict):
            assert "datasets" in args.urls
            self.append(shardlists.MultiShardSample(args.urls))
            return

        # .json specification files (from wids)
        if isinstance(urls, str) and urls.endswith(".json"):
            raise ValueError("unimplemented")

        # any URL ending in "/" is assumed to be a directory
        if isinstance(urls, str) and urlparse(urls).path.endswith("/"):
            self.append(shardlists.DirectoryShardList(urls, mode=args.mode))
            return

        # the rest is either a shard list or a resampled shard list
        if isinstance(args.urls, str) or utils.is_iterable(args.urls):
            if args.mode == "resampled":
                self.append(shardlists.ResampledShardList(args.urls))
            else:
                self.append(shardlists.SimpleShardList(args.urls))
            return

        raise ValueError(f"cannot handle urls of type {type(args.urls)}")

    def __enter__(self):
        """Enter the runtime context for the WebDataset.

        Returns:
            self: The WebDataset instance.
        """
        return self

    def __exit__(self, *args):
        """Exit the runtime context for the WebDataset.

        Args:
            *args: Exception type, value, and traceback if an exception occurred.
        """
        self.close()

__enter__()

Enter the runtime context for the WebDataset.

Returns:
  • self

    The WebDataset instance.

Source code in webdataset/compat.py
def __enter__(self):
    """Enter the runtime context for the WebDataset.

    Returns:
        self: The WebDataset instance.
    """
    return self

__exit__(*args)

Exit the runtime context for the WebDataset.

Parameters:
  • *args

    Exception type, value, and traceback if an exception occurred.

Source code in webdataset/compat.py
def __exit__(self, *args):
    """Exit the runtime context for the WebDataset.

    Args:
        *args: Exception type, value, and traceback if an exception occurred.
    """
    self.close()

create_url_iterator(args)

Create an appropriate URL iterator based on the input type.

This method determines the type of URL input and creates the corresponding iterator for the dataset.

Parameters:
  • args

    A SimpleNamespace object containing the arguments.

Raises:
  • ValueError

    If the URL type is not supported or implemented.
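
The dispatch described above corresponds to inputs like these (all paths hypothetical):

import webdataset as wds

ds = wds.WebDataset("dataset.yaml", shardshuffle=False)   # YAML spec -> MultiShardSample
ds = wds.WebDataset("./shards/", shardshuffle=False)      # trailing "/" -> DirectoryShardList
ds = wds.WebDataset("train-{000000..000099}.tar", shardshuffle=False)  # string -> SimpleShardList
ds = wds.WebDataset(["a.tar", "b.tar"], resampled=True)   # resampled -> ResampledShardList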

Source code in webdataset/compat.py
def create_url_iterator(self, args):
    """Create an appropriate URL iterator based on the input type.

    This method determines the type of URL input and creates the corresponding
    iterator for the dataset.

    Args:
        args: A SimpleNamespace object containing the arguments.

    Raises:
        ValueError: If the URL type is not supported or implemented.
    """
    urls = args.urls

    # .yaml specification files
    if isinstance(urls, str) and (urls.endswith(".yaml") or urls.endswith(".yml")):
        with open(args.urls) as stream:
            spec = yaml.safe_load(stream)
        assert "datasets" in spec
        self.append(shardlists.MultiShardSample(spec))
        return

    # .yaml specifications already loaded as dictionaries
    if isinstance(args.urls, dict):
        assert "datasets" in args.urls
        self.append(shardlists.MultiShardSample(args.urls))
        return

    # .json specification files (from wids)
    if isinstance(urls, str) and urls.endswith(".json"):
        raise ValueError("unimplemented")

    # any URL ending in "/" is assumed to be a directory
    if isinstance(urls, str) and urlparse(urls).path.endswith("/"):
        self.append(shardlists.DirectoryShardList(urls, mode=args.mode))
        return

    # the rest is either a shard list or a resampled shard list
    if isinstance(args.urls, str) or utils.is_iterable(args.urls):
        if args.mode == "resampled":
            self.append(shardlists.ResampledShardList(args.urls))
        else:
            self.append(shardlists.SimpleShardList(args.urls))
        return

    raise ValueError(f"cannot handle urls of type {type(args.urls)}")

update_cache_info(args)

Update cache information based on arguments and environment variables.

Parameters:
  • args

    A SimpleNamespace object containing the arguments.

Raises:
  • ValueError

    If the specified cache directory does not exist.
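
For example (paths and sizes illustrative), caching can be enabled by passing a cache directory, with the environment variables overriding the arguments:

import os
import webdataset as wds

cache_dir = os.path.expanduser("~/.cache/wds")
os.makedirs(cache_dir, exist_ok=True)  # the directory must already exist

# WDS_CACHE and WDS_CACHE_SIZE, when set, override these arguments.
ds = wds.WebDataset(
    "https://example.com/train-{000000..000009}.tar",
    cache_dir=cache_dir,
    cache_size=10 * 10**9,  # 10 GB
    shardshuffle=False,
)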

Source code in webdataset/compat.py
def update_cache_info(self, args):
    """Update cache information based on arguments and environment variables.

    Args:
        args: A SimpleNamespace object containing the arguments.

    Raises:
        ValueError: If the specified cache directory does not exist.
    """
    args.cache_size = int(os.environ.get("WDS_CACHE_SIZE", args.cache_size))
    args.cache_dir = os.environ.get("WDS_CACHE", args.cache_dir)
    if args.cache_dir is not None:
        args.cache_dir = os.path.expanduser(args.cache_dir)
        if not os.path.exists(args.cache_dir):
            raise ValueError(f"cache directory {args.cache_dir} does not exist")

webdataset.WebLoader

Bases: DataPipeline, FluidInterface

A wrapper for DataLoader that adds a fluid interface.
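
A common idiom (a sketch; batch sizes illustrative) is to batch inside the dataset, then unbatch, shuffle, and rebatch on the loader side so that samples mix across workers:

import webdataset as wds

dataset = (
    wds.WebDataset("train-{000000..000099}.tar", shardshuffle=100)
    .decode("rgb")
    .to_tuple("jpg;png", "cls")
    .batched(16)
)

loader = (
    wds.WebLoader(dataset, batch_size=None, num_workers=4)
    .unbatched()    # undo the per-worker batching
    .shuffle(1000)  # mix samples coming from different workers
    .batched(64)    # final training batch size
)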

Source code in webdataset/compat.py
class WebLoader(DataPipeline, FluidInterface):
    """A wrapper for DataLoader that adds a fluid interface."""

    def __init__(self, *args, **kw):
        super().__init__(DataLoader(*args, **kw))

webdataset.FluidInterface

Source code in webdataset/compat.py
class FluidInterface:
    def batched(
        self, batchsize, collation_fn=filters.default_collation_fn, partial=True
    ):
        """Create batches of the given size.

        This method forwards to the filters.batched function.

        Args:
            batchsize (int): Target batch size.
            collation_fn (callable, optional): Function to collate samples into a batch.
                Defaults to filters.default_collation_fn.
            partial (bool, optional): Whether to return partial batches. Defaults to True.

        Returns:
            FluidInterface: Updated pipeline with batched filter.
        """
        return self.compose(
            filters.batched(batchsize, collation_fn=collation_fn, partial=partial)
        )

    def unbatched(self):
        """Turn batched data back into unbatched data.

        This method forwards to the filters.unbatched function.

        Returns:
            FluidInterface: Updated pipeline with unbatched filter.
        """
        return self.compose(filters.unbatched())

    def listed(self, batchsize, partial=True):
        """Create lists of samples without collation.

        This method forwards to the filters.batched function with collation_fn set to None.

        Args:
            batchsize (int): Target list size.
            partial (bool, optional): Whether to return partial lists. Defaults to True.

        Returns:
            FluidInterface: Updated pipeline with listed filter.
        """
        return self.compose(filters.batched(batchsize=batchsize, collation_fn=None))

    def unlisted(self):
        """Turn listed data back into individual samples.

        This method forwards to the filters.unlisted function.

        Returns:
            FluidInterface: Updated pipeline with unlisted filter.
        """
        return self.compose(filters.unlisted())

    def log_keys(self, logfile=None):
        """Log keys of samples passing through the pipeline.

        This method forwards to the filters.log_keys function.

        Args:
            logfile (str, optional): Path to the log file. If None, logging is disabled.

        Returns:
            FluidInterface: Updated pipeline with log_keys filter.
        """
        return self.compose(filters.log_keys(logfile))

    def shuffle(self, size, **kw):
        """Shuffle the data in the stream.

        This method forwards to the filters.shuffle function if size > 0.

        Args:
            size (int): Buffer size for shuffling.
            **kw: Additional keyword arguments for filters.shuffle.

        Returns:
            FluidInterface: Updated pipeline with shuffle filter, or self if size < 1.
        """
        if size < 1:
            return self
        else:
            return self.compose(filters.shuffle(size, **kw))

    def map(self, f, handler=reraise_exception):
        """Apply a function to each sample in the stream.

        This method forwards to the filters.map function.

        Args:
            f (callable): Function to apply to each sample.
            handler (callable, optional): Exception handler. Defaults to reraise_exception.

        Returns:
            FluidInterface: Updated pipeline with map filter.
        """
        return self.compose(filters.map(f, handler=handler))

    def decode(
        self,
        *args,
        pre=None,
        post=None,
        only=None,
        partial=False,
        handler=reraise_exception,
    ):
        """Decode data based on the decoding functions given as arguments.

        This method creates a decoder using autodecode.Decoder and applies it using filters.map.

        Args:
            *args: Decoding functions or strings representing image handlers.
            pre (callable, optional): Pre-processing function.
            post (callable, optional): Post-processing function.
            only (list, optional): List of keys to decode.
            partial (bool, optional): Whether to allow partial decoding. Defaults to False.
            handler (callable, optional): Exception handler. Defaults to reraise_exception.

        Returns:
            FluidInterface: Updated pipeline with decode filter.
        """
        handlers = [
            autodecode.ImageHandler(x) if isinstance(x, str) else x for x in args
        ]
        decoder = autodecode.Decoder(
            handlers, pre=pre, post=post, only=only, partial=partial
        )
        return self.map(decoder, handler=handler)

    def map_dict(self, handler=reraise_exception, **kw):
        """Map the entries in a dict sample with individual functions.

        This method forwards to the filters.map_dict function.

        Args:
            handler (callable, optional): Exception handler. Defaults to reraise_exception.
            **kw: Mapping of keys to functions to apply.

        Returns:
            FluidInterface: Updated pipeline with map_dict filter.
        """
        return self.compose(filters.map_dict(handler=handler, **kw))

    def select(self, predicate, **kw):
        """Select samples based on a predicate.

        This method forwards to the filters.select function.

        Args:
            predicate (callable): Function that returns True for samples to keep.
            **kw: Additional keyword arguments for filters.select.

        Returns:
            FluidInterface: Updated pipeline with select filter.
        """
        return self.compose(filters.select(predicate, **kw))

    def to_tuple(self, *args, **kw):
        """Convert dict samples to tuples.

        This method forwards to the filters.to_tuple function.

        Args:
            *args: Keys to extract from the dict.
            **kw: Additional keyword arguments for filters.to_tuple.

        Returns:
            FluidInterface: Updated pipeline with to_tuple filter.
        """
        return self.compose(filters.to_tuple(*args, **kw))

    def map_tuple(self, *args, handler=reraise_exception):
        """Map the entries of a tuple with individual functions.

        This method forwards to the filters.map_tuple function.

        Args:
            *args: Functions to apply to each element of the tuple.
            handler (callable, optional): Exception handler. Defaults to reraise_exception.

        Returns:
            FluidInterface: Updated pipeline with map_tuple filter.
        """
        return self.compose(filters.map_tuple(*args, handler=handler))

    def slice(self, *args):
        """Slice the data stream.

        This method forwards to the filters.slice function.

        Args:
            *args: Arguments for slicing (start, stop, step).

        Returns:
            FluidInterface: Updated pipeline with slice filter.
        """
        return self.compose(filters.slice(*args))

    def rename(self, **kw):
        """Rename samples based on keyword arguments.

        This method forwards to the filters.rename function.

        Args:
            **kw: Mapping of old names to new names.

        Returns:
            FluidInterface: Updated pipeline with rename filter.
        """
        return self.compose(filters.rename(**kw))

    def rsample(self, p=0.5):
        """Randomly subsample a stream of data.

        This method forwards to the filters.rsample function.

        Args:
            p (float, optional): Probability of keeping each sample. Defaults to 0.5.

        Returns:
            FluidInterface: Updated pipeline with rsample filter.
        """
        return self.compose(filters.rsample(p))

    def rename_keys(self, *args, **kw):
        """Rename keys in samples based on patterns.

        This method forwards to the filters.rename_keys function.

        Args:
            *args: Positional arguments for filters.rename_keys.
            **kw: Keyword arguments for filters.rename_keys.

        Returns:
            FluidInterface: Updated pipeline with rename_keys filter.
        """
        return self.compose(filters.rename_keys(*args, **kw))

    def extract_keys(self, *args, **kw):
        """Extract specific keys from samples.

        This method forwards to the filters.extract_keys function.

        Args:
            *args: Keys or patterns to extract.
            **kw: Additional keyword arguments for filters.extract_keys.

        Returns:
            FluidInterface: Updated pipeline with extract_keys filter.
        """
        return self.compose(filters.extract_keys(*args, **kw))

    def xdecode(self, *args, **kw):
        """Decode data based on file extensions.

        This method forwards to the filters.xdecode function.

        Args:
            *args: Positional arguments for filters.xdecode.
            **kw: Keyword arguments for filters.xdecode.

        Returns:
            FluidInterface: Updated pipeline with xdecode filter.
        """
        return self.compose(filters.xdecode(*args, **kw))

    def mcached(self):
        """Cache samples in memory.

        This method forwards to the filters.Cached class.

        Returns:
            FluidInterface: Updated pipeline with memory caching.
        """
        return self.compose(filters.Cached())

    def lmdb_cached(self, *args, **kw):
        """Cache samples using LMDB.

        This method forwards to the filters.LMDBCached class.

        Args:
            *args: Positional arguments for filters.LMDBCached.
            **kw: Keyword arguments for filters.LMDBCached.

        Returns:
            FluidInterface: Updated pipeline with LMDB caching.
        """
        return self.compose(filters.LMDBCached(*args, **kw))

batched(batchsize, collation_fn=filters.default_collation_fn, partial=True)

Create batches of the given size.

This method forwards to the filters.batched function.

Parameters:
  • batchsize (int) –

    Target batch size.

  • collation_fn (callable, default: default_collation_fn ) –

    Function to collate samples into a batch. Defaults to filters.default_collation_fn.

  • partial (bool, default: True ) –

    Whether to return partial batches. Defaults to True.

Returns:
  • FluidInterface

    Updated pipeline with batched filter.
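
For example (assuming dataset is an existing pipeline):

batches = dataset.batched(16)                 # keep the trailing partial batch
batches = dataset.batched(16, partial=False)  # drop it instead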

Source code in webdataset/compat.py
def batched(
    self, batchsize, collation_fn=filters.default_collation_fn, partial=True
):
    """Create batches of the given size.

    This method forwards to the filters.batched function.

    Args:
        batchsize (int): Target batch size.
        collation_fn (callable, optional): Function to collate samples into a batch.
            Defaults to filters.default_collation_fn.
        partial (bool, optional): Whether to return partial batches. Defaults to True.

    Returns:
        FluidInterface: Updated pipeline with batched filter.
    """
    return self.compose(
        filters.batched(batchsize, collation_fn=collation_fn, partial=partial)
    )

decode(*args, pre=None, post=None, only=None, partial=False, handler=reraise_exception)

Decode data based on the decoding functions given as arguments.

This method creates a decoder using autodecode.Decoder and applies it using filters.map.

Parameters:
  • *args

    Decoding functions or strings representing image handlers.

  • pre (callable, default: None ) –

    Pre-processing function.

  • post (callable, default: None ) –

    Post-processing function.

  • only (list, default: None ) –

    List of keys to decode.

  • partial (bool, default: False ) –

    Whether to allow partial decoding. Defaults to False.

  • handler (callable, default: reraise_exception ) –

    Exception handler. Defaults to reraise_exception.

Returns:
  • FluidInterface

    Updated pipeline with decode filter.
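
For example (assuming dataset is an existing pipeline; handler names such as "pil" and "rgb" come from autodecode):

ds = dataset.decode("pil")                       # decode images to PIL images
ds = dataset.decode("rgb", only=["jpg", "png"])  # decode only these keys, keep the rest as bytes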

Source code in webdataset/compat.py
def decode(
    self,
    *args,
    pre=None,
    post=None,
    only=None,
    partial=False,
    handler=reraise_exception,
):
    """Decode data based on the decoding functions given as arguments.

    This method creates a decoder using autodecode.Decoder and applies it using filters.map.

    Args:
        *args: Decoding functions or strings representing image handlers.
        pre (callable, optional): Pre-processing function.
        post (callable, optional): Post-processing function.
        only (list, optional): List of keys to decode.
        partial (bool, optional): Whether to allow partial decoding. Defaults to False.
        handler (callable, optional): Exception handler. Defaults to reraise_exception.

    Returns:
        FluidInterface: Updated pipeline with decode filter.
    """
    handlers = [
        autodecode.ImageHandler(x) if isinstance(x, str) else x for x in args
    ]
    decoder = autodecode.Decoder(
        handlers, pre=pre, post=post, only=only, partial=partial
    )
    return self.map(decoder, handler=handler)

extract_keys(*args, **kw)

Extract specific keys from samples.

This method forwards to the filters.extract_keys function.

Parameters:
  • *args

    Keys or patterns to extract.

  • **kw

    Additional keyword arguments for filters.extract_keys.

Returns:
  • FluidInterface

    Updated pipeline with extract_keys filter.

Source code in webdataset/compat.py
def extract_keys(self, *args, **kw):
    """Extract specific keys from samples.

    This method forwards to the filters.extract_keys function.

    Args:
        *args: Keys or patterns to extract.
        **kw: Additional keyword arguments for filters.extract_keys.

    Returns:
        FluidInterface: Updated pipeline with extract_keys filter.
    """
    return self.compose(filters.extract_keys(*args, **kw))

listed(batchsize, partial=True)

Create lists of samples without collation.

This method forwards to the filters.batched function with collation_fn set to None.

Parameters:
  • batchsize (int) –

    Target list size.

  • partial (bool, default: True ) –

    Whether to return partial lists. Defaults to True.

Returns:
  • FluidInterface

    Updated pipeline with listed filter.

Source code in webdataset/compat.py
def listed(self, batchsize, partial=True):
    """Create lists of samples without collation.

    This method forwards to the filters.batched function with collation_fn set to None.

    Args:
        batchsize (int): Target list size.
        partial (bool, optional): Whether to return partial lists. Defaults to True.

    Returns:
        FluidInterface: Updated pipeline with listed filter.
    """
    return self.compose(filters.batched(batchsize=batchsize, collation_fn=None))

lmdb_cached(*args, **kw)

Cache samples using LMDB.

This method forwards to the filters.LMDBCached class.

Parameters:
  • *args

    Positional arguments for filters.LMDBCached.

  • **kw

    Keyword arguments for filters.LMDBCached.

Returns:
  • FluidInterface

    Updated pipeline with LMDB caching.

Source code in webdataset/compat.py
def lmdb_cached(self, *args, **kw):
    """Cache samples using LMDB.

    This method forwards to the filters.LMDBCached class.

    Args:
        *args: Positional arguments for filters.LMDBCached.
        **kw: Keyword arguments for filters.LMDBCached.

    Returns:
        FluidInterface: Updated pipeline with LMDB caching.
    """
    return self.compose(filters.LMDBCached(*args, **kw))

log_keys(logfile=None)

Log keys of samples passing through the pipeline.

This method forwards to the filters.log_keys function.

Parameters:
  • logfile (str, default: None ) –

    Path to the log file. If None, logging is disabled.

Returns:
  • FluidInterface

    Updated pipeline with log_keys filter.

Source code in webdataset/compat.py
def log_keys(self, logfile=None):
    """Log keys of samples passing through the pipeline.

    This method forwards to the filters.log_keys function.

    Args:
        logfile (str, optional): Path to the log file. If None, logging is disabled.

    Returns:
        FluidInterface: Updated pipeline with log_keys filter.
    """
    return self.compose(filters.log_keys(logfile))

map(f, handler=reraise_exception)

Apply a function to each sample in the stream.

This method forwards to the filters.map function.

Parameters:
  • f (callable) –

    Function to apply to each sample.

  • handler (callable, default: reraise_exception ) –

    Exception handler. Defaults to reraise_exception.

Returns:
  • FluidInterface

    Updated pipeline with map filter.
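
For example (a sketch; add_id is a hypothetical transform, and dataset an existing pipeline):

def add_id(sample):
    # Derive an extra field from the sample key.
    sample["id.txt"] = sample["__key__"].encode("utf-8")
    return sample

ds = dataset.map(add_id)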

Source code in webdataset/compat.py
def map(self, f, handler=reraise_exception):
    """Apply a function to each sample in the stream.

    This method forwards to the filters.map function.

    Args:
        f (callable): Function to apply to each sample.
        handler (callable, optional): Exception handler. Defaults to reraise_exception.

    Returns:
        FluidInterface: Updated pipeline with map filter.
    """
    return self.compose(filters.map(f, handler=handler))

map_dict(handler=reraise_exception, **kw)

Map the entries in a dict sample with individual functions.

This method forwards to the filters.map_dict function.

Parameters:
  • handler (callable, default: reraise_exception ) –

    Exception handler. Defaults to reraise_exception.

  • **kw

    Mapping of keys to functions to apply.

Returns:
  • FluidInterface

    Updated pipeline with map_dict filter.
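
For example (a sketch; scale is a hypothetical transform on decoded images):

def scale(image):
    return image / 255.0

# Apply scale to the "jpg" entry of each sample; other keys pass through.
ds = dataset.map_dict(jpg=scale)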

Source code in webdataset/compat.py
def map_dict(self, handler=reraise_exception, **kw):
    """Map the entries in a dict sample with individual functions.

    This method forwards to the filters.map_dict function.

    Args:
        handler (callable, optional): Exception handler. Defaults to reraise_exception.
        **kw: Mapping of keys to functions to apply.

    Returns:
        FluidInterface: Updated pipeline with map_dict filter.
    """
    return self.compose(filters.map_dict(handler=handler, **kw))

map_tuple(*args, handler=reraise_exception)

Map the entries of a tuple with individual functions.

This method forwards to the filters.map_tuple function.

Parameters:
  • *args

    Functions to apply to each element of the tuple.

  • handler (callable, default: reraise_exception ) –

    Exception handler. Defaults to reraise_exception.

Returns:
  • FluidInterface

    Updated pipeline with map_tuple filter.
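
For example (a sketch; augment is hypothetical, and the tuple layout comes from a preceding to_tuple):

def augment(image):
    return image[..., ::-1]  # e.g. a horizontal flip

ds = dataset.to_tuple("jpg", "cls").map_tuple(augment, lambda label: label)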

Source code in webdataset/compat.py
def map_tuple(self, *args, handler=reraise_exception):
    """Map the entries of a tuple with individual functions.

    This method forwards to the filters.map_tuple function.

    Args:
        *args: Functions to apply to each element of the tuple.
        handler (callable, optional): Exception handler. Defaults to reraise_exception.

    Returns:
        FluidInterface: Updated pipeline with map_tuple filter.
    """
    return self.compose(filters.map_tuple(*args, handler=handler))

mcached()

Cache samples in memory.

This method forwards to the filters.Cached class.

Returns:
  • FluidInterface

    Updated pipeline with memory caching.

Source code in webdataset/compat.py
def mcached(self):
    """Cache samples in memory.

    This method forwards to the filters.Cached class.

    Returns:
        FluidInterface: Updated pipeline with memory caching.
    """
    return self.compose(filters.Cached())

rename(**kw)

Rename samples based on keyword arguments.

This method forwards to the filters.rename function.

Parameters:
  • **kw

    Mapping of old names to new names.

Returns:
  • FluidInterface

    Updated pipeline with rename filter.
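
In common webdataset usage (a sketch), the keyword gives the new field name and the value the old name(s), with ";" separating alternatives:

ds = dataset.rename(image="jpg;png", label="cls")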

Source code in webdataset/compat.py
def rename(self, **kw):
    """Rename samples based on keyword arguments.

    This method forwards to the filters.rename function.

    Args:
        **kw: Mapping of old names to new names.

    Returns:
        FluidInterface: Updated pipeline with rename filter.
    """
    return self.compose(filters.rename(**kw))

rename_keys(*args, **kw)

Rename keys in samples based on patterns.

This method forwards to the filters.rename_keys function.

Parameters:
  • *args

    Positional arguments for filters.rename_keys.

  • **kw

    Keyword arguments for filters.rename_keys.

Returns:
  • FluidInterface

    Updated pipeline with rename_keys filter.

Source code in webdataset/compat.py
def rename_keys(self, *args, **kw):
    """Rename keys in samples based on patterns.

    This method forwards to the filters.rename_keys function.

    Args:
        *args: Positional arguments for filters.rename_keys.
        **kw: Keyword arguments for filters.rename_keys.

    Returns:
        FluidInterface: Updated pipeline with rename_keys filter.
    """
    return self.compose(filters.rename_keys(*args, **kw))

rsample(p=0.5)

Randomly subsample a stream of data.

This method forwards to the filters.rsample function.

Parameters:
  • p (float, default: 0.5 ) –

    Probability of keeping each sample. Defaults to 0.5.

Returns:
  • FluidInterface

    Updated pipeline with rsample filter.
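
For example (assuming dataset is an existing pipeline):

ds = dataset.rsample(0.1)  # keep roughly 10% of the samples, chosen independently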

Source code in webdataset/compat.py
def rsample(self, p=0.5):
    """Randomly subsample a stream of data.

    This method forwards to the filters.rsample function.

    Args:
        p (float, optional): Probability of keeping each sample. Defaults to 0.5.

    Returns:
        FluidInterface: Updated pipeline with rsample filter.
    """
    return self.compose(filters.rsample(p))

select(predicate, **kw)

Select samples based on a predicate.

This method forwards to the filters.select function.

Parameters:
  • predicate (callable) –

    Function that returns True for samples to keep.

  • **kw

    Additional keyword arguments for filters.select.

Returns:
  • FluidInterface

    Updated pipeline with select filter.
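
For example (a sketch, assuming samples carry a decoded "cls" field):

ds = dataset.select(lambda sample: sample["cls"] != 9)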

Source code in webdataset/compat.py
def select(self, predicate, **kw):
    """Select samples based on a predicate.

    This method forwards to the filters.select function.

    Args:
        predicate (callable): Function that returns True for samples to keep.
        **kw: Additional keyword arguments for filters.select.

    Returns:
        FluidInterface: Updated pipeline with select filter.
    """
    return self.compose(filters.select(predicate, **kw))

shuffle(size, **kw)

Shuffle the data in the stream.

This method forwards to the filters.shuffle function if size > 0.

Parameters:
  • size (int) –

    Buffer size for shuffling.

  • **kw

    Additional keyword arguments for filters.shuffle.

Returns:
  • FluidInterface

    Updated pipeline with shuffle filter, or self if size < 1.
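
For example (assuming dataset is an existing pipeline):

# Maintain a 1000-sample buffer and emit samples from it at random;
# larger buffers shuffle better at the cost of memory and startup delay.
ds = dataset.shuffle(1000)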

Source code in webdataset/compat.py
def shuffle(self, size, **kw):
    """Shuffle the data in the stream.

    This method forwards to the filters.shuffle function if size > 0.

    Args:
        size (int): Buffer size for shuffling.
        **kw: Additional keyword arguments for filters.shuffle.

    Returns:
        FluidInterface: Updated pipeline with shuffle filter, or self if size < 1.
    """
    if size < 1:
        return self
    else:
        return self.compose(filters.shuffle(size, **kw))

slice(*args)

Slice the data stream.

This method forwards to the filters.slice function.

Parameters:
  • *args

    Arguments for slicing (start, stop, step).

Returns:
  • FluidInterface

    Updated pipeline with slice filter.

Source code in webdataset/compat.py
def slice(self, *args):
    """Slice the data stream.

    This method forwards to the filters.slice function.

    Args:
        *args: Arguments for slicing (start, stop, step).

    Returns:
        FluidInterface: Updated pipeline with slice filter.
    """
    return self.compose(filters.slice(*args))

to_tuple(*args, **kw)

Convert dict samples to tuples.

This method forwards to the filters.to_tuple function.

Parameters:
  • *args

    Keys to extract from the dict.

  • **kw

    Additional keyword arguments for filters.to_tuple.

Returns:
  • FluidInterface

    Updated pipeline with to_tuple filter.
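
For example (field names illustrative; ";" lists alternative extensions for one slot):

ds = dataset.to_tuple("jpg;png", "json")
for image, metadata in ds:
    break  # each sample is now an (image, metadata) tuple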

Source code in webdataset/compat.py
def to_tuple(self, *args, **kw):
    """Convert dict samples to tuples.

    This method forwards to the filters.to_tuple function.

    Args:
        *args: Keys to extract from the dict.
        **kw: Additional keyword arguments for filters.to_tuple.

    Returns:
        FluidInterface: Updated pipeline with to_tuple filter.
    """
    return self.compose(filters.to_tuple(*args, **kw))

unbatched()

Turn batched data back into unbatched data.

This method forwards to the filters.unbatched function.

Returns:
  • FluidInterface

    Updated pipeline with unbatched filter.

Source code in webdataset/compat.py
def unbatched(self):
    """Turn batched data back into unbatched data.

    This method forwards to the filters.unbatched function.

    Returns:
        FluidInterface: Updated pipeline with unbatched filter.
    """
    return self.compose(filters.unbatched())

unlisted()

Turn listed data back into individual samples.

This method forwards to the filters.unlisted function.

Returns:
  • FluidInterface

    Updated pipeline with unlisted filter.

Source code in webdataset/compat.py
def unlisted(self):
    """Turn listed data back into individual samples.

    This method forwards to the filters.unlisted function.

    Returns:
        FluidInterface: Updated pipeline with unlisted filter.
    """
    return self.compose(filters.unlisted())

xdecode(*args, **kw)

Decode data based on file extensions.

This method forwards to the filters.xdecode function.

Parameters:
  • *args

    Positional arguments for filters.xdecode.

  • **kw

    Keyword arguments for filters.xdecode.

Returns:
  • FluidInterface

    Updated pipeline with xdecode filter.

Source code in webdataset/compat.py
def xdecode(self, *args, **kw):
    """Decode data based on file extensions.

    This method forwards to the filters.xdecode function.

    Args:
        *args: Positional arguments for filters.xdecode.
        **kw: Keyword arguments for filters.xdecode.

    Returns:
        FluidInterface: Updated pipeline with xdecode filter.
    """
    return self.compose(filters.xdecode(*args, **kw))

webdataset.with_epoch

Bases: IterableDataset

Change the actual and nominal length of an IterableDataset.

This will continuously iterate through the original dataset, but impose new epoch boundaries at the given nominal length. This exists mainly as a workaround for the odd epoch logic in DataLoader. It is also useful for choosing smaller nominal epoch sizes with very large datasets.

Parameters:
  • dataset

    The source IterableDataset.

  • length (int) –

    Declared length of the dataset.
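
For example (shard pattern illustrative), a resampled dataset is unbounded, so with_epoch is used to give it a nominal epoch size:

import webdataset as wds

base = wds.WebDataset("train-{000000..000099}.tar", resampled=True)
ds = wds.with_epoch(base, 10000)  # every epoch yields exactly 10000 samples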

Source code in webdataset/extradatasets.py
class with_epoch(IterableDataset):
    """Change the actual and nominal length of an IterableDataset.

    This will continuously iterate through the original dataset, but
    impose new epoch boundaries at the given length/nominal.
    This exists mainly as a workaround for the odd logic in DataLoader.
    It is also useful for choosing smaller nominal epoch sizes with
    very large datasets.

    Args:
        dataset: The source IterableDataset.
        length (int): Declared length of the dataset.
    """

    def __init__(self, dataset, length):
        super().__init__()
        self.length = length
        self.source = None

    def __getstate__(self):
        """Return the pickled state of the dataset.

        This resets the dataset iterator, since that can't be pickled.

        Returns:
            dict: A dictionary representing the pickled state of the dataset.
        """
        result = dict(self.__dict__)
        result["source"] = None
        return result

    def invoke(self, dataset):
        """Return an iterator over the dataset.

        This iterator returns as many samples as given by the `length` parameter.

        Args:
            dataset: The source dataset to iterate over.

        Yields:
            Sample: The next sample from the dataset.
        """
        if self.source is None:
            self.source = iter(dataset)
        for _ in range(self.length):
            try:
                sample = next(self.source)
            except StopIteration:
                self.source = iter(dataset)
                try:
                    sample = next(self.source)
                except StopIteration:
                    return
            yield sample
        self.source = None

__getstate__()

Return the pickled state of the dataset.

This resets the dataset iterator, since that can't be pickled.

Returns:
  • dict

    A dictionary representing the pickled state of the dataset.

Source code in webdataset/extradatasets.py
def __getstate__(self):
    """Return the pickled state of the dataset.

    This resets the dataset iterator, since that can't be pickled.

    Returns:
        dict: A dictionary representing the pickled state of the dataset.
    """
    result = dict(self.__dict__)
    result["source"] = None
    return result

invoke(dataset)

Return an iterator over the dataset.

This iterator returns as many samples as given by the length parameter.

Parameters:
  • dataset

    The source dataset to iterate over.

Yields:
  • Sample

    The next sample from the dataset.

Source code in webdataset/extradatasets.py
def invoke(self, dataset):
    """Return an iterator over the dataset.

    This iterator returns as many samples as given by the `length` parameter.

    Args:
        dataset: The source dataset to iterate over.

    Yields:
        Sample: The next sample from the dataset.
    """
    if self.source is None:
        self.source = iter(dataset)
    for _ in range(self.length):
        try:
            sample = next(self.source)
        except StopIteration:
            self.source = iter(dataset)
            try:
                sample = next(self.source)
            except StopIteration:
                return
        yield sample
    self.source = None

Writing WebDatasets

webdataset.ShardWriter

Like TarWriter but splits into multiple shards.

Parameters:
  • pattern (str) –

    Output file pattern.

  • maxcount (int, default: 100000 ) –

    Maximum number of records per shard. Defaults to 100000.

  • maxsize (float, default: 3000000000.0 ) –

    Maximum size of each shard in bytes. Defaults to 3e9.

  • post (Optional[Callable], default: None ) –

    Optional callable to be executed after each shard is written. Defaults to None.

  • start_shard (int, default: 0 ) –

    Starting shard number. Defaults to 0.

  • verbose (int, default: 1 ) –

    Verbosity level. Defaults to 1.

  • opener (Optional[Callable], default: None ) –

    Optional callable to open output files. Defaults to None.

  • **kw

    Other options passed to TarWriter.
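
For example (a sketch; pattern and payloads illustrative):

import webdataset as wds

# Write shards of at most 10000 records or ~1 GB each; the pattern is
# %-formatted with the shard index: out-000000.tar, out-000001.tar, ...
with wds.ShardWriter("out-%06d.tar", maxcount=10000, maxsize=1e9) as sink:
    for i in range(25000):
        sink.write({
            "__key__": f"sample{i:08d}",
            "txt": f"sample number {i}",
        })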

Source code in webdataset/writer.py
class ShardWriter:
    """Like TarWriter but splits into multiple shards.

    Args:
        pattern: Output file pattern.
        maxcount: Maximum number of records per shard. Defaults to 100000.
        maxsize: Maximum size of each shard. Defaults to 3e9.
        post: Optional callable to be executed after each shard is written. Defaults to None.
        start_shard: Starting shard number. Defaults to 0.
        verbose: Verbosity level. Defaults to 1.
        opener: Optional callable to open output files. Defaults to None.
        **kw: Other options passed to TarWriter.
    """

    def __init__(
        self,
        pattern: str,
        maxcount: int = 100000,
        maxsize: float = 3e9,
        post: Optional[Callable] = None,
        start_shard: int = 0,
        verbose: int = 1,
        opener: Optional[Callable] = None,
        **kw,
    ):
        """Create a ShardWriter.

        Args:
            pattern: Output file pattern.
            maxcount: Maximum number of records per shard.
            maxsize: Maximum size of each shard.
            post: Optional callable to be executed after each shard is written.
            start_shard: Starting shard number.
            verbose: Verbosity level.
            opener: Optional callable to open output files.
            **kw: Other options passed to TarWriter.
        """
        self.verbose = verbose
        self.kw = kw
        self.maxcount = maxcount
        self.maxsize = maxsize
        self.post = post

        self.tarstream = None
        self.shard = start_shard
        self.pattern = pattern
        self.total = 0
        self.count = 0
        self.size = 0
        self.fname = None
        self.opener = opener
        self.next_stream()

    def next_stream(self):
        """Close the current stream and move to the next."""
        self.finish()
        self.fname = self.pattern % self.shard
        if self.verbose:
            print(
                "# writing",
                self.fname,
                self.count,
                "%.1f GB" % (self.size / 1e9),
                self.total,
            )
        self.shard += 1
        if self.opener:
            self.tarstream = TarWriter(self.opener(self.fname), **self.kw)
        else:
            self.tarstream = TarWriter(self.fname, **self.kw)
        self.count = 0
        self.size = 0

    def write(self, obj):
        """Write a sample.

        Args:
            obj: Sample to be written.
        """
        if (
            self.tarstream is None
            or self.count >= self.maxcount
            or self.size >= self.maxsize
        ):
            self.next_stream()
        size = self.tarstream.write(obj)
        self.count += 1
        self.total += 1
        self.size += size

    def finish(self):
        """Finish all writing (use close instead)."""
        if self.tarstream is not None:
            self.tarstream.close()
            assert self.fname is not None
            if callable(self.post):
                self.post(self.fname)
            self.tarstream = None

    def close(self):
        """Close the stream."""
        self.finish()
        del self.tarstream
        del self.shard
        del self.count
        del self.size

    def __enter__(self):
        """Enter context.

        Returns:
            self: The ShardWriter object.
        """
        return self

    def __exit__(self, *args, **kw):
        """Exit context."""
        self.close()

__enter__()

Enter context.

Returns:
  • self

    The ShardWriter object.

Source code in webdataset/writer.py
def __enter__(self):
    """Enter context.

    Returns:
        self: The ShardWriter object.
    """
    return self

__exit__(*args, **kw)

Exit context.

Source code in webdataset/writer.py
def __exit__(self, *args, **kw):
    """Exit context."""
    self.close()

__init__(pattern, maxcount=100000, maxsize=3000000000.0, post=None, start_shard=0, verbose=1, opener=None, **kw)

Create a ShardWriter.

Parameters:
  • pattern (str) –

    Output file pattern.

  • maxcount (int, default: 100000 ) –

    Maximum number of records per shard.

  • maxsize (float, default: 3000000000.0 ) –

    Maximum size of each shard in bytes.

  • post (Optional[Callable], default: None ) –

    Optional callable to be executed after each shard is written.

  • start_shard (int, default: 0 ) –

    Starting shard number.

  • verbose (int, default: 1 ) –

    Verbosity level.

  • opener (Optional[Callable], default: None ) –

    Optional callable to open output files.

  • **kw

    Other options passed to TarWriter.

Source code in webdataset/writer.py
def __init__(
    self,
    pattern: str,
    maxcount: int = 100000,
    maxsize: float = 3e9,
    post: Optional[Callable] = None,
    start_shard: int = 0,
    verbose: int = 1,
    opener: Optional[Callable] = None,
    **kw,
):
    """Create a ShardWriter.

    Args:
        pattern: Output file pattern.
        maxcount: Maximum number of records per shard.
        maxsize: Maximum size of each shard.
        post: Optional callable to be executed after each shard is written.
        start_shard: Starting shard number.
        verbose: Verbosity level.
        opener: Optional callable to open output files.
        **kw: Other options passed to TarWriter.
    """
    self.verbose = verbose
    self.kw = kw
    self.maxcount = maxcount
    self.maxsize = maxsize
    self.post = post

    self.tarstream = None
    self.shard = start_shard
    self.pattern = pattern
    self.total = 0
    self.count = 0
    self.size = 0
    self.fname = None
    self.opener = opener
    self.next_stream()

close()

Close the stream.

Source code in webdataset/writer.py
def close(self):
    """Close the stream."""
    self.finish()
    del self.tarstream
    del self.shard
    del self.count
    del self.size

finish()

Finish all writing (use close instead).

Source code in webdataset/writer.py
def finish(self):
    """Finish all writing (use close instead)."""
    if self.tarstream is not None:
        self.tarstream.close()
        assert self.fname is not None
        if callable(self.post):
            self.post(self.fname)
        self.tarstream = None

next_stream()

Close the current stream and move to the next.

Source code in webdataset/writer.py
def next_stream(self):
    """Close the current stream and move to the next."""
    self.finish()
    self.fname = self.pattern % self.shard
    if self.verbose:
        print(
            "# writing",
            self.fname,
            self.count,
            "%.1f GB" % (self.size / 1e9),
            self.total,
        )
    self.shard += 1
    if self.opener:
        self.tarstream = TarWriter(self.opener(self.fname), **self.kw)
    else:
        self.tarstream = TarWriter(self.fname, **self.kw)
    self.count = 0
    self.size = 0

write(obj)

Write a sample.

Parameters:
  • obj

    Sample to be written.

Source code in webdataset/writer.py
def write(self, obj):
    """Write a sample.

    Args:
        obj: Sample to be written.
    """
    if (
        self.tarstream is None
        or self.count >= self.maxcount
        or self.size >= self.maxsize
    ):
        self.next_stream()
    size = self.tarstream.write(obj)
    self.count += 1
    self.total += 1
    self.size += size

webdataset.TarWriter

A class for writing dictionaries to tar files.

Parameters:
  • fileobj

    File name for tar file (.tgz/.tar) or open file descriptor.

  • encoder (Union[None, bool, Callable], default: True ) –

    Sample encoding. Defaults to True.

  • compress (Optional[Union[bool, str]], default: None ) –

    Compression flag. Defaults to None.

  • user (str, default: 'bigdata' ) –

    User for tar files. Defaults to "bigdata".

  • group (str, default: 'bigdata' ) –

    Group for tar files. Defaults to "bigdata".

  • mode (int, default: 292 ) –

    Mode for tar files. Defaults to 0o0444.

  • keep_meta (bool, default: False ) –

    Flag to keep metadata (entries starting with "_"). Defaults to False.

  • mtime (Optional[float], default: None ) –

    Modification time. Defaults to None.

  • format (Any, default: None ) –

    Tar format. Defaults to None.

Returns:
  • TarWriter object.

Raises:
  • ValueError

    If the encoder doesn't yield bytes for a key.

True will use an encoder that behaves similarly to the automatic decoder for Dataset. False disables encoding and expects byte strings (except for metadata, which must be strings). The encoder argument can also be a callable, or a dictionary mapping extensions to encoders.

The following code will add two files to the tar archive: a/b.png and a/b.output.png.

tarwriter = TarWriter(stream)
image = imread("b.jpg")
image2 = imread("b.out.jpg")
sample = {"__key__": "a/b", "png": image, "output.png": image2}
tarwriter.write(sample)
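
If the automatic encoder is not wanted, encoding can be disabled or customized as described above. A hedged sketch of both variants (the file names and payloads are illustrative, and the dictionary form is assumed to dispatch on each key's extension as documented):

import json
import webdataset as wds

# encoder=False: every non-metadata value must already be bytes.
writer = wds.TarWriter("raw.tar", encoder=False)
writer.write({"__key__": "a/b", "png": b"...raw PNG bytes..."})  # placeholder bytes
writer.close()

# encoder as a dictionary mapping extensions to encoder callables:
writer = wds.TarWriter("encoded.tar",
                       encoder={"json": lambda d: json.dumps(d).encode("utf-8")})
writer.write({"__key__": "a/c", "json": {"label": 3}})
writer.close()
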
Source code in webdataset/writer.py
class TarWriter:
    """A class for writing dictionaries to tar files.

    Args:
        fileobj: File name for tar file (.tgz/.tar) or open file descriptor.
        encoder: Sample encoding. Defaults to True.
        compress: Compression flag. Defaults to None.
        user: User for tar files. Defaults to "bigdata".
        group: Group for tar files. Defaults to "bigdata".
        mode: Mode for tar files. Defaults to 0o0444.
        keep_meta: Flag to keep metadata (entries starting with "_"). Defaults to False.
        mtime: Modification time. Defaults to None.
        format: Tar format. Defaults to None.

    Returns:
        TarWriter object.

    Raises:
        ValueError: If the encoder doesn't yield bytes for a key.

    `True` will use an encoder that behaves similarly to the automatic
    decoder for `Dataset`. `False` disables encoding and expects byte strings
    (except for metadata, which must be strings). The `encoder` argument can
    also be a `callable`, or a dictionary mapping extensions to encoders.

    The following code will add two files to the tar archive: `a/b.png` and
    `a/b.output.png`.


        tarwriter = TarWriter(stream)
        image = imread("b.jpg")
        image2 = imread("b.out.jpg")
        sample = {"__key__": "a/b", "png": image, "output.png": image2}
        tarwriter.write(sample)

    """

    def __init__(
        self,
        fileobj,
        user: str = "bigdata",
        group: str = "bigdata",
        mode: int = 0o0444,
        compress: Optional[Union[bool, str]] = None,
        encoder: Union[None, bool, Callable] = True,
        keep_meta: bool = False,
        mtime: Optional[float] = None,
        format: Any = None,
    ):  # sourcery skip: avoid-builtin-shadow
        """Create a tar writer.

        Args:
            fileobj: Stream to write data to.
            user: User for tar files.
            group: Group for tar files.
            mode: Mode for tar files.
            compress: Desired compression.
            encoder: Encoder function.
            keep_meta: Keep metadata (entries starting with "_").
            mtime: Modification time (set this to some fixed value to get reproducible tar files).
            format: Tar format.
        """
        format = getattr(tarfile, format, format) if format else tarfile.USTAR_FORMAT
        self.mtime = mtime
        tarmode = self.tarmode(fileobj, compress)
        if isinstance(fileobj, str):
            fileobj = gopen(fileobj, "wb")
            self.own_fileobj = fileobj
        else:
            self.own_fileobj = None
        self.encoder = make_encoder(encoder)
        self.keep_meta = keep_meta
        self.stream = fileobj
        self.tarstream = tarfile.open(fileobj=fileobj, mode=tarmode)

        self.user = user
        self.group = group
        self.mode = mode
        self.compress = compress

    def __enter__(self):
        """Enter context.

        Returns:
            self: The TarWriter object.
        """
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Exit context."""
        self.close()

    def close(self):
        """Close the tar file."""
        self.tarstream.close()
        if self.own_fileobj is not None:
            self.own_fileobj.close()
            self.own_fileobj = None

    def write(self, obj):
        """Write a dictionary to the tar file.

        Args:
            obj: Dictionary of objects to be stored.

        Returns:
            int: Size of the entry.

        Raises:
            ValueError: If the object doesn't contain a __key__ or if a key doesn't map to bytes after encoding.
        """
        total = 0
        obj = self.encoder(obj)
        if "__key__" not in obj:
            raise ValueError("object must contain a __key__")
        for k, v in list(obj.items()):
            if k[0] == "_":
                continue
            if not isinstance(v, (bytes, bytearray, memoryview)):
                raise ValueError(
                    f"{k} doesn't map to a bytes after encoding ({type(v)})"
                )
        key = obj["__key__"]
        for k in sorted(obj.keys()):
            if k == "__key__":
                continue
            if not self.keep_meta and k[0] == "_":
                continue
            v = obj[k]
            if isinstance(v, str):
                v = v.encode("utf-8")
            now = time.time()
            ti = tarfile.TarInfo(key + "." + k)
            ti.size = len(v)
            ti.mtime = self.mtime if self.mtime is not None else now
            ti.mode = self.mode
            ti.uname = self.user
            ti.gname = self.group
            if not isinstance(v, (bytes, bytearray, memoryview)):
                raise ValueError(f"converter didn't yield bytes: {k}, {type(v)}")
            stream = io.BytesIO(v)
            self.tarstream.addfile(ti, stream)
            total += ti.size

        return total

    @staticmethod
    def tarmode(fileobj, compress: Optional[Union[bool, str]] = None):
        if compress is False:
            return "w|"
        elif (
            compress is True
            or compress == "gz"
            or (isinstance(fileobj, str) and fileobj.endswith("gz"))
        ):
            return "w|gz"
        elif compress == "bz2" or (
            isinstance(fileobj, str) and fileobj.endswith("bz2")
        ):
            return "w|bz2"
        elif compress == "xz" or (isinstance(fileobj, str) and fileobj.endswith("xz")):
            return "w|xz"
        else:
            return "w|"

__enter__()

Enter context.

Returns:
  • self

    The TarWriter object.

Source code in webdataset/writer.py
def __enter__(self):
    """Enter context.

    Returns:
        self: The TarWriter object.
    """
    return self

__exit__(exc_type, exc_val, exc_tb)

Exit context.

Source code in webdataset/writer.py
def __exit__(self, exc_type, exc_val, exc_tb):
    """Exit context."""
    self.close()
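
Because TarWriter provides __enter__ and __exit__, it can be used as a context manager, which guarantees close is called even if writing fails. A brief sketch:

import webdataset as wds

# close() runs automatically when the with-block exits, even on error.
with wds.TarWriter("dataset.tar") as tw:
    tw.write({"__key__": "sample0", "txt": "hello world"})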

__init__(fileobj, user='bigdata', group='bigdata', mode=0o0444, compress=None, encoder=True, keep_meta=False, mtime=None, format=None)

Create a tar writer.

Parameters:
  • fileobj

    Stream to write data to.

  • user (str, default: 'bigdata' ) –

    User for tar files.

  • group (str, default: 'bigdata' ) –

    Group for tar files.

  • mode (int, default: 0o0444 ) –

    Mode for tar files.

  • compress (Optional[Union[bool, str]], default: None ) –

    Desired compression.

  • encoder (Union[None, bool, Callable], default: True ) –

    Encoder function.

  • keep_meta (bool, default: False ) –

    Keep metadata (entries starting with "_").

  • mtime (Optional[float], default: None ) –

    Modification time (set this to some fixed value to get reproducible tar files).

  • format (Any, default: None ) –

    Tar format.

Source code in webdataset/writer.py
def __init__(
    self,
    fileobj,
    user: str = "bigdata",
    group: str = "bigdata",
    mode: int = 0o0444,
    compress: Optional[Union[bool, str]] = None,
    encoder: Union[None, bool, Callable] = True,
    keep_meta: bool = False,
    mtime: Optional[float] = None,
    format: Any = None,
):  # sourcery skip: avoid-builtin-shadow
    """Create a tar writer.

    Args:
        fileobj: Stream to write data to.
        user: User for tar files.
        group: Group for tar files.
        mode: Mode for tar files.
        compress: Desired compression.
        encoder: Encoder function.
        keep_meta: Keep metadata (entries starting with "_").
        mtime: Modification time (set this to some fixed value to get reproducible tar files).
        format: Tar format.
    """
    format = getattr(tarfile, format, format) if format else tarfile.USTAR_FORMAT
    self.mtime = mtime
    tarmode = self.tarmode(fileobj, compress)
    if isinstance(fileobj, str):
        fileobj = gopen(fileobj, "wb")
        self.own_fileobj = fileobj
    else:
        self.own_fileobj = None
    self.encoder = make_encoder(encoder)
    self.keep_meta = keep_meta
    self.stream = fileobj
    self.tarstream = tarfile.open(fileobj=fileobj, mode=tarmode)

    self.user = user
    self.group = group
    self.mode = mode
    self.compress = compress
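
As noted for mtime, pinning the modification time makes uncompressed tar output reproducible across runs; compress, or a .gz/.bz2/.xz suffix on the file name, selects the stream compression. A small sketch with an arbitrary fixed timestamp:

import webdataset as wds

# Fixed mtime: identical output bytes for identical input on every run.
with wds.TarWriter("repro.tar", mtime=0.0) as tw:
    tw.write({"__key__": "a", "txt": "fixed content"})

# compress="gz" forces gzip regardless of the file name.
with wds.TarWriter("compressed.tar", compress="gz") as tw:
    tw.write({"__key__": "b", "txt": "more content"})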

close()

Close the tar file.

Source code in webdataset/writer.py
def close(self):
    """Close the tar file."""
    self.tarstream.close()
    if self.own_fileobj is not None:
        self.own_fileobj.close()
        self.own_fileobj = None

write(obj)

Write a dictionary to the tar file.

Parameters:
  • obj

    Dictionary of objects to be stored.

Returns:
  • int

    Size of the entry.

Raises:
  • ValueError

    If the object doesn't contain a __key__ or if a key doesn't map to bytes after encoding.

Source code in webdataset/writer.py
def write(self, obj):
    """Write a dictionary to the tar file.

    Args:
        obj: Dictionary of objects to be stored.

    Returns:
        int: Size of the entry.

    Raises:
        ValueError: If the object doesn't contain a __key__ or if a key doesn't map to bytes after encoding.
    """
    total = 0
    obj = self.encoder(obj)
    if "__key__" not in obj:
        raise ValueError("object must contain a __key__")
    for k, v in list(obj.items()):
        if k[0] == "_":
            continue
        if not isinstance(v, (bytes, bytearray, memoryview)):
            raise ValueError(
                f"{k} doesn't map to a bytes after encoding ({type(v)})"
            )
    key = obj["__key__"]
    for k in sorted(obj.keys()):
        if k == "__key__":
            continue
        if not self.keep_meta and k[0] == "_":
            continue
        v = obj[k]
        if isinstance(v, str):
            v = v.encode("utf-8")
        now = time.time()
        ti = tarfile.TarInfo(key + "." + k)
        ti.size = len(v)
        ti.mtime = self.mtime if self.mtime is not None else now
        ti.mode = self.mode
        ti.uname = self.user
        ti.gname = self.group
        if not isinstance(v, (bytes, bytearray, memoryview)):
            raise ValueError(f"converter didn't yield bytes: {k}, {type(v)}")
        stream = io.BytesIO(v)
        self.tarstream.addfile(ti, stream)
        total += ti.size

    return total

Low Level I/O

webdataset.gopen.gopen(url, mode='rb', bufsize=8192, **kw)

Open the URL using various schemes and protocols.

This function provides a unified interface for opening resources specified by URLs, supporting multiple schemes and protocols. It uses the gopen_schemes dispatch table to handle different URL schemes.

Built-in support is provided for the following schemes:

- pipe: for opening named pipes
- file: for local file system access
- http, https: for web resources
- sftp, ftps: for secure file transfer
- scp: for secure copy protocol

When no scheme is specified in the URL, it is treated as a local file path.

Environment Variables:

- GOPEN_VERBOSE: Set to a non-zero value to enable verbose logging of file operations. Format: GOPEN_VERBOSE=1
- USE_AIS_FOR: Specifies which cloud storage services should use AIS (and its cache) for access. Format: USE_AIS_FOR=aws:gs:s3
- GOPEN_BUFFER: Sets the buffer size for file operations (in bytes). Format: GOPEN_BUFFER=8192

Parameters:
  • url (str) –

    The source URL or file path to open.

  • mode (str, default: 'rb' ) –

    The mode for opening the resource. Only "rb" (read binary) and "wb" (write binary) are supported.

  • bufsize (int, default: 8192 ) –

    The buffer size for file operations. Default is 8192 bytes.

  • **kw

    Additional keyword arguments to pass to the underlying open function.

Returns:
  • file-like object: An opened file-like object for the specified resource.

Raises:
  • ValueError

    If an unsupported mode is specified.

Note:

- For stdin/stdout operations, use "-" as the URL.
- The function applies URL rewriting based on the GOPEN_REWRITE environment variable before processing.

Source code in webdataset/gopen.py
def gopen(url, mode="rb", bufsize=8192, **kw):
    """Open the URL using various schemes and protocols.

    This function provides a unified interface for opening resources specified by URLs,
    supporting multiple schemes and protocols. It uses the `gopen_schemes` dispatch table
    to handle different URL schemes.

    Built-in support is provided for the following schemes:
    - pipe: for opening named pipes
    - file: for local file system access
    - http, https: for web resources
    - sftp, ftps: for secure file transfer
    - scp: for secure copy protocol

    When no scheme is specified in the URL, it is treated as a local file path.

    Environment Variables:
    - GOPEN_VERBOSE: Set to a non-zero value to enable verbose logging of file operations.
      Format: GOPEN_VERBOSE=1
    - USE_AIS_FOR: Specifies which cloud storage services should use AIS (and its cache) for access.
      Format: USE_AIS_FOR=aws:gs:s3
    - GOPEN_BUFFER: Sets the buffer size for file operations (in bytes).
      Format: GOPEN_BUFFER=8192

    Args:
        url (str): The source URL or file path to open.
        mode (str): The mode for opening the resource. Only "rb" (read binary) and "wb" (write binary) are supported.
        bufsize (int): The buffer size for file operations. Default is 8192 bytes.
        **kw: Additional keyword arguments to pass to the underlying open function.

    Returns:
        file-like object: An opened file-like object for the specified resource.

    Raises:
        ValueError: If an unsupported mode is specified.
        Other exceptions may be raised depending on the specific handler used for the URL scheme.

    Note:
    - For stdin/stdout operations, use "-" as the URL.
    - The function applies URL rewriting based on the GOPEN_REWRITE environment variable before processing.
    """
    global fallback_gopen
    verbose = int(os.environ.get("GOPEN_VERBOSE", 0))
    if verbose:
        print("GOPEN", url, info, file=sys.stderr)
    assert mode in ["rb", "wb"], mode
    if url == "-":
        if mode == "rb":
            return sys.stdin.buffer
        elif mode == "wb":
            return sys.stdout.buffer
        else:
            raise ValueError(f"unknown mode {mode}")
    url = rewrite_url(url)
    pr = urlparse(url)
    if pr.scheme == "":
        bufsize = int(os.environ.get("GOPEN_BUFFER", -1))
        return open(url, mode, buffering=bufsize)
    if pr.scheme == "file":
        bufsize = int(os.environ.get("GOPEN_BUFFER", -1))
        return open(url2pathname(pr.path), mode, buffering=bufsize)
    handler = gopen_schemes["__default__"]
    handler = gopen_schemes.get(pr.scheme, handler)
    return handler(url, mode, bufsize, **kw)
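
A short usage sketch (the paths and URL are illustrative, and the pipe: example assumes curl is installed):

from webdataset.gopen import gopen

# No scheme: treated as a local file path, opened in binary mode.
stream = gopen("shards/data-000000.tar")  # hypothetical path
header = stream.read(512)
stream.close()

# pipe: scheme runs a shell command and streams its stdout.
stream = gopen("pipe:curl -s -L https://example.com/data.tar")
chunk = stream.read(8192)
stream.close()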

Error Handling

webdataset.ignore_and_continue(exn)

Ignore the exception and continue processing.

Parameters:
  • exn

    The exception to be ignored.

Returns:
  • bool

    Always returns True to indicate continuation.

Source code in webdataset/handlers.py
def ignore_and_continue(exn):
    """Ignore the exception and continue processing.

    Args:
        exn: The exception to be ignored.

    Returns:
        bool: Always returns True to indicate continuation.
    """
    return True

webdataset.ignore_and_stop(exn)

Ignore the exception and stop further processing.

Parameters:
  • exn

    The exception to be ignored.

Returns:
  • bool

    Always returns False to indicate stopping.

Source code in webdataset/handlers.py
def ignore_and_stop(exn):
    """Ignore the exception and stop further processing.

    Args:
        exn: The exception to be ignored.

    Returns:
        bool: Always returns False to indicate stopping.
    """
    return False

webdataset.reraise_exception(exn)

Re-raise the given exception.

Parameters:
  • exn

    The exception to be re-raised.

Source code in webdataset/handlers.py
def reraise_exception(exn):
    """Re-raise the given exception.

    Args:
        exn: The exception to be re-raised.

    Raises:
        The input exception.
    """
    raise exn

webdataset.warn_and_continue(exn)

Issue a warning for the exception and continue processing.

Parameters:
  • exn

    The exception to be warned about.

Returns:
  • bool

    Always returns True to indicate continuation.

Source code in webdataset/handlers.py
def warn_and_continue(exn):
    """Issue a warning for the exception and continue processing.

    Args:
        exn: The exception to be warned about.

    Returns:
        bool: Always returns True to indicate continuation.
    """
    warnings.warn(repr(exn))
    time.sleep(0.5)
    return True

webdataset.warn_and_stop(exn)

Issue a warning for the exception and stop further processing.

Parameters:
  • exn

    The exception to be warned about.

Returns:
  • bool

    Always returns False to indicate stopping.

Source code in webdataset/handlers.py
def warn_and_stop(exn):
    """Issue a warning for the exception and stop further processing.

    Args:
        exn: The exception to be warned about.

    Returns:
        bool: Always returns False to indicate stopping.
    """
    warnings.warn(repr(exn))
    time.sleep(0.5)
    return False
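
All four non-raising handlers share one contract: they receive the exception and return True to skip the offending sample and continue, or False to stop iteration; reraise_exception propagates the error instead. A hedged sketch of plugging a built-in handler, and a custom one following the same protocol, into a pipeline (the shard pattern and the skip_value_errors helper are illustrative):

import webdataset as wds
from webdataset import warn_and_continue

# Built-in handler: warn about each bad sample and keep going.
ds = wds.WebDataset("shards/data-{000000..000009}.tar",  # hypothetical shards
                    shardshuffle=False, handler=warn_and_continue)

# Custom handler with the same True/False protocol.
def skip_value_errors(exn):
    if isinstance(exn, ValueError):
        return True   # skip this sample, keep iterating
    return False      # stop on anything else

ds = wds.WebDataset("shards/data-{000000..000009}.tar",
                    shardshuffle=False, handler=skip_value_errors)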