webdataset.utils

Miscellaneous utility functions.

View Source
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#

"""Miscellaneous utility functions."""

import sys
import re
import importlib
import itertools as itt


def identity(x):
    """Identity function.

    Hands back the very object it was given (no copy is made); useful as a
    default no-op transform.
    """
    return x


def safe_eval(s, expr="{}"):
    """Substitute ``s`` into the template ``expr`` and evaluate the result.

    ``s`` may contain only ASCII letters, digits and underscores, so it can
    only be an identifier- or number-like token; any other character raises
    ``ValueError`` before anything is evaluated.  ``expr`` itself is NOT
    validated: it goes through ``str.format`` and then ``eval`` unchanged, so
    it must come from trusted code.

    NOTE(review): ``eval`` is called without explicit namespaces, so it sees
    this module's globals and this function's locals (``s`` and ``expr``).
    """
    if re.fullmatch(r"[A-Za-z0-9_]*", s) is None:
        raise ValueError(f"safe_eval: illegal characters in: '{s}'")
    return eval(expr.format(s))


def lookup_sym(sym, modules):
    """Return the first non-None attribute named ``sym`` found in ``modules``.

    Module names are imported in order with ``importlib.import_module``
    (``package="webdataset"`` only affects relative names such as ``".x"``).
    The search stops at the first module whose ``sym`` attribute exists and
    is not ``None``.  Import errors are not caught and propagate to the caller.

    Returns ``None`` when no module supplies a non-None value.
    """
    for module_name in modules:
        candidate = getattr(
            importlib.import_module(module_name, package="webdataset"), sym, None
        )
        if candidate is None:
            continue
        return candidate
    return None


def repeatedly0(loader, nepochs=sys.maxsize, nbatches=sys.maxsize):
    """Yield items from ``loader`` over several epochs.

    Each epoch iterates ``loader`` anew and yields at most ``nbatches`` of its
    items; ``nepochs`` epochs are run in total.  The ``sys.maxsize`` defaults
    act as "practically unlimited".

    NOTE(review): if ``loader`` is a one-shot iterator rather than a
    re-iterable, every epoch after it is exhausted yields nothing, but the
    epoch loop still runs until ``nepochs`` is reached.
    """
    for _ in range(nepochs):
        per_epoch = itt.islice(loader, nbatches)
        for batch in per_epoch:
            yield batch


def guess_batchsize(batch):
    """Guess the batch size as the length of the batch's first element.

    ``batch[0]`` is taken as representative of the whole batch and its
    ``len()`` is returned.  Presumably batches are tuples/lists of per-field
    sequences (e.g. ``(images, labels)``) — TODO confirm against callers.
    """
    return len(batch[0])


def repeatedly(source, nepochs=None, nbatches=None, nsamples=None, batchsize=guess_batchsize):
    """Repeatedly yield items from ``source``, re-iterating it once per epoch.

    Iteration stops as soon as any of the enabled limits is reached:

    source: iterable that is iterated again at the start of every epoch.
        NOTE(review): a one-shot iterator is exhausted after the first epoch,
        and later epochs then yield nothing.
    nepochs: stop after this many full passes over ``source`` (None = no
        epoch limit).
    nbatches: stop after this many items have been yielded in total
        (None = no limit).
    nsamples: stop once the running sum of ``batchsize(item)`` over the
        yielded items reaches this value (None = no limit).
    batchsize: callable returning the number of samples in one item.  It is
        called only when ``nsamples`` is set; defaults to ``guess_batchsize``.

    NOTE(review): if ``nepochs`` is None and no other limit is reached (for
    example, all limits are None, or ``source`` stops producing items), the
    outer loop never terminates.
    """
    epoch = 0
    count = 0
    total = 0
    while True:
        for sample in source:
            yield sample
            count += 1
            if nbatches is not None and count >= nbatches:
                return
            if nsamples is not None:
                # Honor the caller-supplied size function rather than always
                # falling back to guess_batchsize.
                total += batchsize(sample)
                if total >= nsamples:
                    return
        epoch += 1
        if nepochs is not None and epoch >= nepochs:
            return
#   def identity(x):
View Source
def identity(x):
    """Identity function.

    Hands back the very object it was given (no copy is made); useful as a
    default no-op transform.
    """
    return x

Return the argument as is.

#   def safe_eval(s, expr='{}'):
View Source
def safe_eval(s, expr="{}"):
    """Substitute ``s`` into the template ``expr`` and evaluate the result.

    ``s`` may contain only ASCII letters, digits and underscores, so it can
    only be an identifier- or number-like token; any other character raises
    ``ValueError`` before anything is evaluated.  ``expr`` itself is NOT
    validated: it goes through ``str.format`` and then ``eval`` unchanged, so
    it must come from trusted code.

    NOTE(review): ``eval`` is called without explicit namespaces, so it sees
    this module's globals and this function's locals (``s`` and ``expr``).
    """
    if re.fullmatch(r"[A-Za-z0-9_]*", s) is None:
        raise ValueError(f"safe_eval: illegal characters in: '{s}'")
    return eval(expr.format(s))

Evaluate the given expression more safely.

#   def lookup_sym(sym, modules):
View Source
def lookup_sym(sym, modules):
    """Return the first non-None attribute named ``sym`` found in ``modules``.

    Module names are imported in order with ``importlib.import_module``
    (``package="webdataset"`` only affects relative names such as ``".x"``).
    The search stops at the first module whose ``sym`` attribute exists and
    is not ``None``.  Import errors are not caught and propagate to the caller.

    Returns ``None`` when no module supplies a non-None value.
    """
    for module_name in modules:
        candidate = getattr(
            importlib.import_module(module_name, package="webdataset"), sym, None
        )
        if candidate is None:
            continue
        return candidate
    return None

Look up a symbol in a list of modules.

#   def repeatedly0(loader, nepochs=9223372036854775807, nbatches=9223372036854775807):
View Source
def repeatedly0(loader, nepochs=sys.maxsize, nbatches=sys.maxsize):
    """Yield items from ``loader`` over several epochs.

    Each epoch iterates ``loader`` anew and yields at most ``nbatches`` of its
    items; ``nepochs`` epochs are run in total.  The ``sys.maxsize`` defaults
    act as "practically unlimited".

    NOTE(review): if ``loader`` is a one-shot iterator rather than a
    re-iterable, every epoch after it is exhausted yields nothing, but the
    epoch loop still runs until ``nepochs`` is reached.
    """
    for _ in range(nepochs):
        per_epoch = itt.islice(loader, nbatches)
        for batch in per_epoch:
            yield batch

Repeatedly returns batches from a DataLoader.

#   def guess_batchsize(batch):
View Source
def guess_batchsize(batch):
    """Estimate a batch's size as ``len(batch[0])``.

    The first element of ``batch`` stands in for the whole batch.  Presumably
    batches are tuples/lists of per-field sequences — TODO confirm against
    callers.
    """
    first_field = batch[0]
    return len(first_field)

Guess the batch size by looking at the length of the first element in a tuple.

#   def repeatedly( source, nepochs=None, nbatches=None, nsamples=None, batchsize=<function guess_batchsize> ):
View Source
def repeatedly(source, nepochs=None, nbatches=None, nsamples=None, batchsize=guess_batchsize):
    """Repeatedly yield items from ``source``, re-iterating it once per epoch.

    Iteration stops as soon as any of the enabled limits is reached:

    source: iterable that is iterated again at the start of every epoch.
        NOTE(review): a one-shot iterator is exhausted after the first epoch,
        and later epochs then yield nothing.
    nepochs: stop after this many full passes over ``source`` (None = no
        epoch limit).
    nbatches: stop after this many items have been yielded in total
        (None = no limit).
    nsamples: stop once the running sum of ``batchsize(item)`` over the
        yielded items reaches this value (None = no limit).
    batchsize: callable returning the number of samples in one item.  It is
        called only when ``nsamples`` is set; defaults to ``guess_batchsize``.

    NOTE(review): if ``nepochs`` is None and no other limit is reached (for
    example, all limits are None, or ``source`` stops producing items), the
    outer loop never terminates.
    """
    epoch = 0
    count = 0
    total = 0
    while True:
        for sample in source:
            yield sample
            count += 1
            if nbatches is not None and count >= nbatches:
                return
            if nsamples is not None:
                # Honor the caller-supplied size function rather than always
                # falling back to guess_batchsize.
                total += batchsize(sample)
                if total >= nsamples:
                    return
        epoch += 1
        if nepochs is not None and epoch >= nepochs:
            return

Repeatedly yield samples from an iterator.