WIDS API

wids.ShardListDataset

Bases: Dataset[T]

An indexable dataset based on a list of shards.

The dataset is given either as a list of shards (optionally with options and a name) or as a URL pointing to a JSON descriptor file.

Datasets can reference other datasets via source_url.

Shard references within a dataset are resolved relative to an explicitly given base property, or to the URL from which the dataset descriptor was loaded.
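
For illustration, a minimal usage sketch (the descriptor URL, shard URLs, and index are hypothetical; the dict form of the shard list matches the "url" and "nsamples" keys the implementation reads):

import wids

# From a JSON dataset descriptor:
dataset = wids.ShardListDataset("https://example.com/dataset.json")

# Or from an explicit shard list:
shards = [
    {"url": "https://example.com/shard-000000.tar", "nsamples": 1000},
    {"url": "https://example.com/shard-000001.tar", "nsamples": 1000},
]
dataset = wids.ShardListDataset(shards)

sample = dataset[0]  # samples are addressed by a global integer index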

Source code in wids/wids.py
class ShardListDataset(Dataset[T]):
    """An indexable dataset based on a list of shards.

    The dataset is given either as a list of shards (optionally with options
    and a name) or as a URL pointing to a JSON descriptor file.

    Datasets can reference other datasets via `source_url`.

    Shard references within a dataset are resolved relative to an explicitly
    given `base` property, or relative to the URL from which the dataset
    descriptor was loaded.
    """

    def __init__(
        self,
        shards,
        *,
        cache_size=int(1e12),
        cache_dir=None,
        lru_size=10,
        dataset_name=None,
        localname=None,
        transformations="PIL",
        keep=False,
        base=None,
        options=None,
    ):
        """Create a ShardListDataset.

        Args:
            shards: a list of shard descriptors (dicts with "url" and "nsamples" keys) or a URL pointing to a JSON descriptor file
            cache_size: the number of shards to keep in the cache
            lru_size: the number of shards to keep in the LRU cache
            localname: a function that maps URLs to local filenames

        Note that there are two caches: an on-disk directory, and an in-memory LRU cache.
        """
        if options is None:
            options = {}
        super(ShardListDataset, self).__init__()
        # shards is either a URL/stream for a JSON descriptor or a list of
        # shard descriptors. We keep track of the lengths and cumulative
        # lengths to map global indices to shards and indices within shards.
        if isinstance(shards, (str, io.IOBase)):
            if base is None and isinstance(shards, str):
                base = urldir(shards)
            self.base = base
            self.spec = load_dsdesc_and_resolve(shards, options=options, base=base)
            self.shards = self.spec.get("shardlist", [])
            self.dataset_name = self.spec.get("name") or hash_dataset_name(str(shards))
        else:
            self.base = None
            self.spec = options
            self.shards = shards
            self.dataset_name = dataset_name or hash_dataset_name(str(shards))

        self.lengths = [shard["nsamples"] for shard in self.shards]
        self.cum_lengths = np.cumsum(self.lengths)
        self.total_length = self.cum_lengths[-1]

        if cache_dir is not None:
            # when a cache dir is explicitly given, we download files into
            # that directory without any changes
            self.cache_dir = cache_dir
            self.localname = cache_localname(cache_dir)
        elif localname is not None:
            # when a localname function is given, we use that
            self.cache_dir = None
            self.localname = localname
        else:
            # when no cache dir or localname are given, use the cache from the environment
            self.cache_dir = os.environ.get("WIDS_CACHE", "/tmp/_wids_cache")
            self.localname = default_localname(self.cache_dir)

        if int(os.environ.get("WIDS_VERBOSE", 0)):
            nbytes = sum(shard.get("filesize", 0) for shard in self.shards)
            nsamples = sum(shard["nsamples"] for shard in self.shards)
            print(
                str(shards)[:50],
                "base:",
                self.base,
                "name:",
                self.spec.get("name"),
                "nfiles:",
                len(self.shards),
                "nbytes:",
                nbytes,
                "samples:",
                nsamples,
                "cache:",
                self.cache_dir,
                file=sys.stderr,
            )
        self.transformations = interpret_transformations(transformations)

        if lru_size > 200:
            warnings.warn(
                "LRU size is very large; consider reducing it to avoid running out of file descriptors"
            )
        self.cache = LRUShards(lru_size, localname=self.localname, keep=keep)

    def add_transform(self, transform):
        """Add a transformation to the dataset."""
        self.transformations.append(transform)
        return self

    def __len__(self):
        """Return the total number of samples in the dataset."""
        return self.total_length

    def get_stats(self):
        """Return the number of cache accesses and misses."""
        return self.cache.accesses, self.cache.misses

    def check_cache_misses(self):
        """Check if the cache miss rate is too high."""
        accesses, misses = self.get_stats()
        if accesses > 100 and misses / accesses > 0.3:
            # output a warning only once
            self.check_cache_misses = lambda: None
            print(
                "Warning: ShardListDataset has a cache miss rate of {:.1%}".format(
                    misses / accesses
                )
            )

    def get_shard(self, index):
        """Get the shard and index within the shard corresponding to the given index."""
        # Find the shard corresponding to the given index.
        shard_idx = np.searchsorted(self.cum_lengths, index, side="right")

        # Figure out which index within the shard corresponds to the
        # given index.
        if shard_idx == 0:
            inner_idx = index
        else:
            inner_idx = index - self.cum_lengths[shard_idx - 1]

        # Get the shard and return the corresponding element.
        desc = self.shards[shard_idx]
        url = desc["url"]
        shard = self.cache.get_shard(url)
        return shard, inner_idx, desc

    def __getitem__(self, index):
        """Return the sample corresponding to the given index."""
        shard, inner_idx, desc = self.get_shard(index)
        sample = shard[inner_idx]

        # Check if we're missing the cache too often.
        self.check_cache_misses()

        sample["__dataset__"] = desc.get("dataset")
        sample["__index__"] = index
        sample["__shard__"] = desc["url"]
        sample["__shardindex__"] = inner_idx

        # Apply transformations
        for transform in self.transformations:
            sample = transform(sample)

        return sample

    def close(self):
        """Close the dataset."""
        self.cache.clear()

__getitem__(index)

Return the sample corresponding to the given index.
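
In addition to the decoded fields, each returned sample carries provenance metadata injected by this method (a sketch; the index is arbitrary):

sample = dataset[1234]
sample["__dataset__"]     # name of the dataset the shard belongs to
sample["__index__"]       # the global index that was requested
sample["__shard__"]       # URL of the shard the sample came from
sample["__shardindex__"]  # position of the sample within that shard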

Source code in wids/wids.py
def __getitem__(self, index):
    """Return the sample corresponding to the given index."""
    shard, inner_idx, desc = self.get_shard(index)
    sample = shard[inner_idx]

    # Check if we're missing the cache too often.
    self.check_cache_misses()

    sample["__dataset__"] = desc.get("dataset")
    sample["__index__"] = index
    sample["__shard__"] = desc["url"]
    sample["__shardindex__"] = inner_idx

    # Apply transformations
    for transform in self.transformations:
        sample = transform(sample)

    return sample

__init__(shards, *, cache_size=int(1e12), cache_dir=None, lru_size=10, dataset_name=None, localname=None, transformations='PIL', keep=False, base=None, options=None)

Create a ShardListDataset.

Parameters:
  • shards

    a list of shard descriptors (dicts with "url" and "nsamples" keys) or a URL pointing to a JSON descriptor file

  • cache_size

    the number of shards to keep in the cache

  • lru_size

    the number of shards to keep in the LRU cache

  • localname

    a function that maps URLs to local filenames

Note that there are two caches: an on-disk directory, and an in-memory LRU cache.
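
A sketch of the three cache configurations, as handled in the constructor (paths are hypothetical):

import os
import wids

# 1. Explicit on-disk cache directory; downloads keep their names.
ds = wids.ShardListDataset("https://example.com/dataset.json",
                           cache_dir="/scratch/wids")

# 2. Custom URL -> local filename mapping.
def my_localname(url):
    return os.path.join("/scratch/mirror", os.path.basename(url))

ds = wids.ShardListDataset("https://example.com/dataset.json",
                           localname=my_localname)

# 3. Neither given: the WIDS_CACHE environment variable is used,
#    defaulting to /tmp/_wids_cache.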

Source code in wids/wids.py
def __init__(
    self,
    shards,
    *,
    cache_size=int(1e12),
    cache_dir=None,
    lru_size=10,
    dataset_name=None,
    localname=None,
    transformations="PIL",
    keep=False,
    base=None,
    options=None,
):
    """Create a ShardListDataset.

    Args:
        shards: a list of shard descriptors (dicts with "url" and "nsamples" keys) or a URL pointing to a JSON descriptor file
        cache_size: the number of shards to keep in the cache
        lru_size: the number of shards to keep in the LRU cache
        localname: a function that maps URLs to local filenames

    Note that there are two caches: an on-disk directory, and an in-memory LRU cache.
    """
    if options is None:
        options = {}
    super(ShardListDataset, self).__init__()
    # shards is either a URL/stream for a JSON descriptor or a list of
    # shard descriptors. We keep track of the lengths and cumulative
    # lengths to map global indices to shards and indices within shards.
    if isinstance(shards, (str, io.IOBase)):
        if base is None and isinstance(shards, str):
            base = urldir(shards)
        self.base = base
        self.spec = load_dsdesc_and_resolve(shards, options=options, base=base)
        self.shards = self.spec.get("shardlist", [])
        self.dataset_name = self.spec.get("name") or hash_dataset_name(str(shards))
    else:
        self.base = None
        self.spec = options
        self.shards = shards
        self.dataset_name = dataset_name or hash_dataset_name(str(shards))

    self.lengths = [shard["nsamples"] for shard in self.shards]
    self.cum_lengths = np.cumsum(self.lengths)
    self.total_length = self.cum_lengths[-1]

    if cache_dir is not None:
        # when a cache dir is explicitly given, we download files into
        # that directory without any changes
        self.cache_dir = cache_dir
        self.localname = cache_localname(cache_dir)
    elif localname is not None:
        # when a localname function is given, we use that
        self.cache_dir = None
        self.localname = localname
    else:
        # when no cache dir or localname are given, use the cache from the environment
        self.cache_dir = os.environ.get("WIDS_CACHE", "/tmp/_wids_cache")
        self.localname = default_localname(self.cache_dir)

    if int(os.environ.get("WIDS_VERBOSE", 0)):
        nbytes = sum(shard.get("filesize", 0) for shard in self.shards)
        nsamples = sum(shard["nsamples"] for shard in self.shards)
        print(
            str(shards)[:50],
            "base:",
            self.base,
            "name:",
            self.spec.get("name"),
            "nfiles:",
            len(self.shards),
            "nbytes:",
            nbytes,
            "samples:",
            nsamples,
            "cache:",
            self.cache_dir,
            file=sys.stderr,
        )
    self.transformations = interpret_transformations(transformations)

    if lru_size > 200:
        warnings.warn(
            "LRU size is very large; consider reducing it to avoid running out of file descriptors"
        )
    self.cache = LRUShards(lru_size, localname=self.localname, keep=keep)

__len__()

Return the total number of samples in the dataset.

Source code in wids/wids.py
def __len__(self):
    """Return the total number of samples in the dataset."""
    return self.total_length

add_transform(transform)

Add a transformation to the dataset.
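
Transforms receive the whole sample dict and return a (possibly new) dict; since add_transform returns self, calls can be chained. A sketch with a hypothetical transform:

def strip_metadata(sample):
    # Drop the injected __dataset__/__index__/... bookkeeping fields.
    return {k: v for k, v in sample.items() if not k.startswith("__")}

dataset.add_transform(strip_metadata)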

Source code in wids/wids.py
def add_transform(self, transform):
    """Add a transformation to the dataset."""
    self.transformations.append(transform)
    return self

check_cache_misses()

Check if the cache miss rate is too high.

Source code in wids/wids.py
def check_cache_misses(self):
    """Check if the cache miss rate is too high."""
    accesses, misses = self.get_stats()
    if accesses > 100 and misses / accesses > 0.3:
        # output a warning only once
        self.check_cache_misses = lambda: None
        print(
            "Warning: ShardListDataset has a cache miss rate of {:.1%}".format(
                misses / accesses
            )
        )

close()

Close the dataset.

Source code in wids/wids.py
def close(self):
    """Close the dataset."""
    self.cache.clear()

get_shard(index)

Get the shard and index within the shard corresponding to the given index.
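
The index arithmetic relies on the cumulative shard lengths; a standalone illustration of the same mapping (shard sizes made up):

import numpy as np

lengths = [1000, 500, 2000]
cum_lengths = np.cumsum(lengths)  # [1000, 1500, 3500]

index = 1200
shard_idx = np.searchsorted(cum_lengths, index, side="right")  # -> 1
inner_idx = index - cum_lengths[shard_idx - 1] if shard_idx > 0 else index  # -> 200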

Source code in wids/wids.py
def get_shard(self, index):
    """Get the shard and index within the shard corresponding to the given index."""
    # Find the shard corresponding to the given index.
    shard_idx = np.searchsorted(self.cum_lengths, index, side="right")

    # Figure out which index within the shard corresponds to the
    # given index.
    if shard_idx == 0:
        inner_idx = index
    else:
        inner_idx = index - self.cum_lengths[shard_idx - 1]

    # Get the shard and return the corresponding element.
    desc = self.shards[shard_idx]
    url = desc["url"]
    shard = self.cache.get_shard(url)
    return shard, inner_idx, desc

get_stats()

Return the number of cache accesses and misses.
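
For example, to monitor cache behavior during a run:

accesses, misses = dataset.get_stats()
print(f"cache miss rate: {misses / max(accesses, 1):.1%}")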

Source code in wids/wids.py
def get_stats(self):
    """Return the number of cache accesses and misses."""
    return self.cache.accesses, self.cache.misses

wids.ChunkedSampler

Bases: Sampler

A sampler that samples in chunks and then shuffles the samples within each chunk.

This preserves locality of reference while still shuffling the data.
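
A usage sketch with a PyTorch DataLoader (batch size and chunk size are arbitrary; dataset is a ShardListDataset as above):

from torch.utils.data import DataLoader

sampler = wids.ChunkedSampler(dataset, chunksize=2000, shuffle=True)
loader = DataLoader(dataset, sampler=sampler, batch_size=64)

for epoch in range(3):
    sampler.set_epoch(epoch)  # vary the shuffle from epoch to epoch
    for batch in loader:
        ...  # training step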

Source code in wids/wids.py
class ChunkedSampler(Sampler):
    """A sampler that samples in chunks and then shuffles the samples within each chunk.

    This preserves locality of reference while still shuffling the data.
    """

    def __init__(
        self,
        dataset,
        *,
        dslength_per_replica=-1,
        num_samples=None,
        chunksize=2000,
        seed=0,
        shuffle=True,
        shufflefirst=False,
    ):
        if isinstance(num_samples, int):
            lo, hi = 0, num_samples
        elif num_samples is None:
            lo, hi = 0, len(dataset)
        else:
            lo, hi = num_samples

        self.dslength_per_replica = (
            dslength_per_replica if dslength_per_replica > 0 else (hi - lo)
        )

        self.ranges = [(i, min(i + chunksize, hi)) for i in range(lo, hi, chunksize)]
        self.seed = seed
        self.shuffle = shuffle
        self.shufflefirst = shufflefirst
        self.epoch = 0

    def set_epoch(self, epoch):
        self.epoch = epoch

    def __iter__(self):
        self.rng = random.Random(self.seed + 1289738273 * self.epoch)
        shardshuffle = self.shufflefirst or self.epoch > 0
        yield from iterate_ranges(
            self.ranges,
            self.rng,
            indexshuffle=self.shuffle,
            shardshuffle=(self.shuffle and shardshuffle),
        )
        self.epoch += 1

    def __len__(self) -> int:
        return self.dslength_per_replica

wids.DistributedChunkedSampler(dataset, *, num_replicas=None, num_samples=None, rank=None, shuffle=True, shufflefirst=False, seed=0, drop_last=None, chunksize=1000000)

Return a ChunkedSampler for the current worker in distributed training.

Reverts to a simple ChunkedSampler if not running in distributed mode.

Since the split among workers takes place before the chunk shuffle, workers end up with a fixed set of shards they need to download. The more workers, the fewer shards are used by each worker.
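
A sketch of usage under torch.distributed (process-group initialization is assumed to have happened elsewhere):

from torch.utils.data import DataLoader

# Each rank builds its own sampler; its slice of the sample range is
# fixed before chunk shuffling, so each rank touches a stable shard subset.
sampler = wids.DistributedChunkedSampler(dataset, shuffle=True)
loader = DataLoader(dataset, sampler=sampler, batch_size=64)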

Source code in wids/wids.py
def DistributedChunkedSampler(
    dataset: Dataset,
    *,
    num_replicas: Optional[int] = None,
    num_samples: Optional[int] = None,
    rank: Optional[int] = None,
    shuffle: bool = True,
    shufflefirst: bool = False,
    seed: int = 0,
    drop_last: Optional[bool] = None,
    chunksize: int = 1000000,
) -> ChunkedSampler:
    """Return a ChunkedSampler for the current worker in distributed training.

    Reverts to a simple ChunkedSampler if not running in distributed mode.

    Since the split among workers takes place before the chunk shuffle,
    workers end up with a fixed set of shards they need to download. The
    more workers, the fewer shards are used by each worker.
    """
    if drop_last is not None:
        warnings.warn(
            "DistributedChunkedSampler does not support drop_last, thus it will be ignored"
        )
    if not dist.is_initialized():
        warnings.warn(
            "DistributedChunkedSampler is called without distributed initialized; assuming single process"
        )
        num_replicas = 1
        rank = 0
    else:
        num_replicas = num_replicas or dist.get_world_size()
        rank = rank if rank is not None else dist.get_rank()
    assert rank >= 0 and rank < num_replicas

    # From https://github.com/pytorch/pytorch/blob/13fa59580e4dd695817ccf2f24922fd211667fc8/torch/utils/data/distributed.py#L93
    dslength_per_replica = (
        math.ceil(len(dataset) / num_replicas) if num_replicas > 1 else len(dataset)
    )

    num_samples = num_samples or len(dataset)
    worker_chunk = (num_samples + num_replicas - 1) // num_replicas
    worker_start = rank * worker_chunk
    worker_end = min(worker_start + worker_chunk, num_samples)
    return ChunkedSampler(
        dataset,
        dslength_per_replica=dslength_per_replica,
        num_samples=(worker_start, worker_end),
        chunksize=chunksize,
        seed=seed,
        shuffle=shuffle,
        shufflefirst=shufflefirst,
    )

wids.ShardedSampler = ShardListSampler (module attribute)

An alias for ShardListSampler.