webdataset.writer

Classes and functions for writing tar files and WebDataset files.

#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#

"""Classes and functions for writing tar files and WebDataset files."""

import io
import pickle
import re
import tarfile
import time
import json

import numpy as np

from . import gopen


def imageencoder(image, format="PNG"):  # skipcq: PYL-W0622
    """Compress an image using PIL and return it as a string.

    Can handle float or uint8 images.

    :param image: ndarray representing an image
    :param format: compression format (PNG, JPEG, PPM)

    """
    import PIL.Image

    if isinstance(image, np.ndarray):
        if image.dtype in [np.dtype("f"), np.dtype("d")]:
            if not (np.amin(image) > -0.001 and np.amax(image) < 1.001):
                raise ValueError(f"image values out of range {np.amin(image)} {np.amax(image)}")
            image = np.clip(image, 0.0, 1.0)
            image = np.array(image * 255.0, "uint8")
        assert image.ndim in [2, 3]
        if image.ndim == 3:
            assert image.shape[2] in [1, 3]
        image = PIL.Image.fromarray(image)
    if format.upper() == "JPG":
        format = "JPEG"
    elif format.upper() in ["IMG", "IMAGE"]:
        format = "PPM"
    if format == "JPEG":
        opts = dict(quality=100)
    else:
        opts = {}
    with io.BytesIO() as result:
        image.save(result, format=format, **opts)
        return result.getvalue()


def bytestr(data):
    """Convert data into a bytestring.

    Uses str() and ASCII encoding for data that isn't already bytes.

    :param data: data
    """
    if isinstance(data, bytes):
        return data
    if isinstance(data, str):
        return data.encode("ascii")
    return str(data).encode("ascii")


def torch_dumps(data):
    """Dump data into a bytestring using torch.dumps.

    This delays importing torch until needed.

    :param data: data to be dumped
    """
    import io
    import torch

    stream = io.BytesIO()
    torch.save(data, stream)
    return stream.getvalue()


def numpy_dumps(data):
    """Dump data into a bytestring using numpy npy format

    :param data: data to be dumped
    """
    import io
    import numpy.lib.format

    stream = io.BytesIO()
    numpy.lib.format.write_array(stream, data)
    return stream.getvalue()


def make_handlers():
    """Create a list of handlers for encoding data."""
    handlers = {}
    for extension in ["cls", "cls2", "class", "count", "index", "inx", "id"]:
        handlers[extension] = lambda x: str(x).encode("ascii")
    for extension in ["txt", "text", "transcript"]:
        handlers[extension] = lambda x: x.encode("utf-8")
    for extension in ["png", "jpg", "jpeg", "img", "image", "pbm", "pgm", "ppm"]:

        def f(extension_):
            """Install an image handler bound to the given extension.

            :param extension_: file extension the handler encodes for
            """
            handlers[extension_] = lambda data: imageencoder(data, extension_)

        f(extension)
    for extension in ["pyd", "pickle"]:
        handlers[extension] = pickle.dumps
    for extension in ["pth"]:
        handlers[extension] = torch_dumps
    for extension in ["npy"]:
        handlers[extension] = numpy_dumps
    for extension in ["json", "jsn"]:
        handlers[extension] = lambda x: json.dumps(x).encode("utf-8")
    for extension in ["ten", "tb"]:
        from . import tenbin

        def f(x):  # skipcq: PYL-E0102
            if isinstance(x, list):
                return memoryview(tenbin.encode_buffer(x))
            else:
                return memoryview(tenbin.encode_buffer([x]))

        handlers[extension] = f
    try:
        import msgpack

        for extension in ["mp", "msgpack", "msg"]:
            handlers[extension] = msgpack.packb
    except ImportError:
        pass
    return handlers


default_handlers = {"default": make_handlers()}


def encode_based_on_extension1(data, tname, handlers):
    """Encode data based on its extension and a dict of handlers.

    :param data: data
    :param tname: file extension
    :param handlers: handlers
    """
    if tname[0] == "_":
        if not isinstance(data, str):
            raise ValueError("the values of metadata must be of string type")
        return data
    extension = re.sub(r".*\.", "", tname).lower()
    if isinstance(data, bytes):
        return data
    if isinstance(data, str):
        return data.encode("utf-8")
    handler = handlers.get(extension)
    if handler is None:
        raise ValueError(f"no handler found for {extension}")
    return handler(data)


def encode_based_on_extension(sample, handlers):
    """Encode an entire sample with a collection of handlers.

    :param sample: data sample (a dict)
    :param handlers: handlers for encoding
    """
    return {k: encode_based_on_extension1(v, k, handlers) for k, v in list(sample.items())}


def make_encoder(spec):
    """Make an encoder function from a specification.

    :param spec: specification; False/None disables encoding, a callable is
        used directly, a dict maps extensions to handlers, and True or a
        string selects a named set of default handlers
    """
    if spec is False or spec is None:

        def encoder(x):
            """Do not encode at all."""
            return x

    elif callable(spec):
        encoder = spec
    elif isinstance(spec, dict):

        def encoder(sample):
            """Encode based on extension."""
            return encode_based_on_extension(sample, spec)

    elif isinstance(spec, str) or spec is True:
        if spec is True:
            spec = "default"
        handlers = default_handlers.get(spec)
        if handlers is None:
            raise ValueError(f"no handler found for {spec}")

        def encoder(sample):
            """Encode based on extension."""
            return encode_based_on_extension(sample, handlers)

    else:
        raise ValueError(f"{spec}: unknown decoder spec")
    if not callable(encoder):
        raise ValueError(f"{spec} did not yield a callable encoder")
    return encoder


class TarWriter:
    """A class for writing dictionaries to tar files.

    :param fileobj: file name for tar file (.tgz/.tar) or open file descriptor
    :param encoder: sample encoding (Default value = True)
    :param compress: whether to compress the output (Default value = None)

    `True` will use an encoder that behaves similarly to the automatic
    decoder for `Dataset`. `False` disables encoding and expects byte strings
    (except for metadata, which must be strings). The `encoder` argument can
    also be a `callable`, or a dictionary mapping extensions to encoders.

    The following code will add two files to the tar archive: `a/b.png` and
    `a/b.output.png`.

    ```Python
        tarwriter = TarWriter(stream)
        image = imread("b.jpg")
        image2 = imread("b.out.jpg")
        sample = {"__key__": "a/b", "png": image, "output.png": image2}
        tarwriter.write(sample)
    ```
    """

    def __init__(
        self,
        fileobj,
        user="bigdata",
        group="bigdata",
        mode=0o0444,
        compress=None,
        encoder=True,
        keep_meta=False,
    ):
        """Create a tar writer.

        :param fileobj: stream to write data to
        :param user: user for tar files
        :param group: group for tar files
        :param mode: mode for tar files
        :param compress: desired compression
        :param encoder: encoder function
        :param keep_meta: keep metadata (entries starting with "_")
        """
        if isinstance(fileobj, str):
            if compress is False:
                tarmode = "w|"
            elif compress is True:
                tarmode = "w|gz"
            else:
                tarmode = "w|gz" if fileobj.endswith("gz") else "w|"
            fileobj = gopen.gopen(fileobj, "wb")
            self.own_fileobj = fileobj
        else:
            tarmode = "w|gz" if compress is True else "w|"
            self.own_fileobj = None
        self.encoder = make_encoder(encoder)
        self.keep_meta = keep_meta
        self.stream = fileobj
        self.tarstream = tarfile.open(fileobj=fileobj, mode=tarmode)

        self.user = user
        self.group = group
        self.mode = mode
        self.compress = compress

    def __enter__(self):
        """Enter context."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Exit context."""
        self.close()

    def close(self):
        """Close the tar file."""
        self.tarstream.close()
        if self.own_fileobj is not None:
            self.own_fileobj.close()
            self.own_fileobj = None

    def write(self, obj):
        """Write a dictionary to the tar file.

        :param obj: dictionary of objects to be stored
        :returns: size of the entry

        """
        total = 0
        obj = self.encoder(obj)
        if "__key__" not in obj:
            raise ValueError("object must contain a __key__")
        for k, v in list(obj.items()):
            if k[0] == "_":
                continue
            if not isinstance(v, (bytes, bytearray, memoryview)):
                raise ValueError(f"{k} doesn't map to a bytes after encoding ({type(v)})")
        key = obj["__key__"]
        for k in sorted(obj.keys()):
            if k == "__key__":
                continue
            if not self.keep_meta and k[0] == "_":
                continue
            v = obj[k]
            if isinstance(v, str):
                v = v.encode("utf-8")
            now = time.time()
            ti = tarfile.TarInfo(key + "." + k)
            ti.size = len(v)
            ti.mtime = now
            ti.mode = self.mode
            ti.uname = self.user
            ti.gname = self.group
            if not isinstance(v, (bytes, bytearray, memoryview)):
                raise ValueError(f"converter didn't yield bytes: {k}, {type(v)}")
            stream = io.BytesIO(v)
            self.tarstream.addfile(ti, stream)
            total += ti.size
        return total


class ShardWriter:
    """Like TarWriter but splits into multiple shards."""

    def __init__(self, pattern, maxcount=100000, maxsize=3e9, post=None, start_shard=0, **kw):
        """Create a ShardWriter.

        :param pattern: output file pattern
        :param maxcount: maximum number of records per shard (Default value = 100000)
        :param maxsize: maximum size of each shard (Default value = 3e9)
        :param post: optional callable invoked with the file name of each finished shard
        :param start_shard: index of the first shard (Default value = 0)
        :param kw: other options passed to TarWriter
        """
        self.verbose = 1
        self.kw = kw
        self.maxcount = maxcount
        self.maxsize = maxsize
        self.post = post

        self.tarstream = None
        self.shard = start_shard
        self.pattern = pattern
        self.total = 0
        self.count = 0
        self.size = 0
        self.fname = None
        self.next_stream()

    def next_stream(self):
        """Close the current stream and move to the next."""
        self.finish()
        self.fname = self.pattern % self.shard
        if self.verbose:
            print(
                "# writing",
                self.fname,
                self.count,
                "%.1f GB" % (self.size / 1e9),
                self.total,
            )
        self.shard += 1
        stream = open(self.fname, "wb")
        self.tarstream = TarWriter(stream, **self.kw)
        self.count = 0
        self.size = 0

    def write(self, obj):
        """Write a sample.

        :param obj: sample to be written
        """
        if self.tarstream is None or self.count >= self.maxcount or self.size >= self.maxsize:
            self.next_stream()
        size = self.tarstream.write(obj)
        self.count += 1
        self.total += 1
        self.size += size

    def finish(self):
        """Finish all writing (use close instead)."""
        if self.tarstream is not None:
            self.tarstream.close()
            assert self.fname is not None
            if callable(self.post):
                self.post(self.fname)
            self.tarstream = None

    def close(self):
        """Close the stream."""
        self.finish()
        del self.tarstream
        del self.shard
        del self.count
        del self.size

    def __enter__(self):
        """Enter context."""
        return self

    def __exit__(self, *args, **kw):
        """Exit context."""
        self.close()
#   def imageencoder(image, format='PNG'):

Compress an image using PIL and return it as a bytestring.

Can handle float or uint8 images.

:param image: ndarray representing an image
:param format: compression format (PNG, JPEG, PPM)
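
For illustration, a minimal sketch of calling `imageencoder` on a synthetic float image (the array shape and values are made up for the example; assumes numpy and Pillow are installed):

```Python
import numpy as np
from webdataset.writer import imageencoder

# A synthetic 32x32 grayscale float image with values in [0, 1].
image = np.random.uniform(0.0, 1.0, size=(32, 32)).astype("f")

png_bytes = imageencoder(image, format="PNG")
print(type(png_bytes), len(png_bytes))  # bytes, plus the compressed size
```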

#   def bytestr(data):

Convert data into a bytestring.

Uses str() and ASCII encoding for data that isn't already bytes.

:param data: data
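
A quick sketch of the conversions:

```Python
from webdataset.writer import bytestr

assert bytestr(b"abc") == b"abc"  # bytes pass through unchanged
assert bytestr("abc") == b"abc"   # str is ASCII-encoded
assert bytestr(42) == b"42"       # other values go through str() first
```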

#   def torch_dumps(data):

Dump data into a bytestring using torch.save.

This delays importing torch until needed.

:param data: data to be dumped
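
A minimal roundtrip sketch (assumes torch is installed):

```Python
import io
import torch
from webdataset.writer import torch_dumps

# Serialize a tensor, then read it back with torch.load to verify.
data = torch.arange(6).reshape(2, 3)
blob = torch_dumps(data)
restored = torch.load(io.BytesIO(blob))
assert torch.equal(data, restored)
```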

#   def numpy_dumps(data):

Dump data into a bytestring using the numpy npy format.

:param data: data to be dumped
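
A minimal roundtrip sketch:

```Python
import io
import numpy as np
import numpy.lib.format
from webdataset.writer import numpy_dumps

# The result is a standard .npy payload, so numpy can read it back directly.
array = np.arange(6.0).reshape(2, 3)
blob = numpy_dumps(array)
restored = numpy.lib.format.read_array(io.BytesIO(blob))
assert np.array_equal(array, restored)
```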

#   def make_handlers():

Create a dict of handlers, keyed by file extension, for encoding data.
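
A quick sketch of the returned dict in use (output values shown as comments):

```Python
from webdataset.writer import make_handlers

handlers = make_handlers()
# Each handler turns a Python object into bytes for its extension.
blob = handlers["json"]({"a": 1})  # b'{"a": 1}'
label = handlers["cls"](7)         # b'7'
```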

#   def encode_based_on_extension1(data, tname, handlers):

Encode data based on its extension and a dict of handlers.

:param data: data
:param tname: file extension
:param handlers: handlers

#   def encode_based_on_extension(sample, handlers):

Encode an entire sample with a collection of handlers.

:param sample: data sample (a dict)
:param handlers: handlers for encoding
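
A short sketch with the built-in default handlers:

```Python
from webdataset.writer import default_handlers, encode_based_on_extension

sample = {"__key__": "sample000000", "cls": 3, "txt": "hello"}
encoded = encode_based_on_extension(sample, default_handlers["default"])
# Metadata keys (starting with "_") stay as strings; the rest become bytes:
# {'__key__': 'sample000000', 'cls': b'3', 'txt': b'hello'}
```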

#   def make_encoder(spec):

Make an encoder function from a specification.

:param spec: specification; `False`/`None` disables encoding, a callable is used directly, a dict maps extensions to handlers, and `True` or a string selects a named set of default handlers
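
A minimal sketch of two specs:

```Python
from webdataset.writer import make_encoder

# The default encoder (spec=True) encodes by file extension.
encoder = make_encoder(True)
encoder({"__key__": "x", "cls": 5})  # {'__key__': 'x', 'cls': b'5'}

# A dict spec supplies handlers for specific extensions.
encoder2 = make_encoder({"cls": lambda n: str(n).encode("ascii")})
encoder2({"__key__": "y", "cls": 7})  # {'__key__': 'y', 'cls': b'7'}
```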

#   class TarWriter:

A class for writing dictionaries to tar files.

:param fileobj: file name for tar file (.tgz/.tar) or open file descriptor
:param encoder: sample encoding (Default value = True)
:param compress: whether to compress the output (Default value = None)

`True` will use an encoder that behaves similarly to the automatic decoder for `Dataset`. `False` disables encoding and expects byte strings (except for metadata, which must be strings). The `encoder` argument can also be a `callable`, or a dictionary mapping extensions to encoders.

The following code will add two files to the tar archive: `a/b.png` and `a/b.output.png`.

    tarwriter = TarWriter(stream)
    image = imread("b.jpg")
    image2 = imread("b.out.jpg")
    sample = {"__key__": "a/b", "png": image, "output.png": image2}
    tarwriter.write(sample)
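
A fuller sketch using the context-manager interface (the file name and sample contents here are illustrative; assumes numpy is installed):

```Python
import numpy as np
from webdataset.writer import TarWriter

# Write two samples to a gzip-compressed shard (".gz" selects compression).
with TarWriter("dataset-000000.tar.gz") as sink:
    for i in range(2):
        image = np.random.uniform(size=(32, 32)).astype("f")
        sink.write({
            "__key__": f"sample{i:06d}",
            "png": image,   # encoded with imageencoder
            "cls": i % 10,  # encoded as ASCII digits
        })
```
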
#   TarWriter( fileobj, user='bigdata', group='bigdata', mode=292, compress=None, encoder=True, keep_meta=False )

Create a tar writer.

:param fileobj: stream to write data to
:param user: user for tar files
:param group: group for tar files
:param mode: mode for tar files
:param compress: desired compression
:param encoder: encoder function
:param keep_meta: keep metadata (entries starting with "_")

#   def close(self):

Close the tar file.

#   def write(self, obj):

Write a dictionary to the tar file.

:param obj: dictionary of objects to be stored
:returns: size of the entry

#   class ShardWriter:

Like TarWriter but splits into multiple shards.
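
A minimal sketch of sharded writing (the pattern and limits are illustrative):

```Python
from webdataset.writer import ShardWriter

# "%06d" in the pattern is filled in with the shard index (000000, 000001, ...).
with ShardWriter("shards-%06d.tar", maxcount=1000) as sink:
    for i in range(2500):
        sink.write({
            "__key__": f"sample{i:06d}",
            "cls": i % 10,
            "txt": f"caption for sample {i}",
        })
# With maxcount=1000, the 2500 samples above land in three shard files.
```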

#   ShardWriter( pattern, maxcount=100000, maxsize=3000000000.0, post=None, start_shard=0, **kw )

Create a ShardWriter.

:param pattern: output file pattern
:param maxcount: maximum number of records per shard (Default value = 100000)
:param maxsize: maximum size of each shard (Default value = 3e9)
:param post: optional callable invoked with the file name of each finished shard
:param start_shard: index of the first shard (Default value = 0)
:param kw: other options passed to TarWriter

#   def next_stream(self):

Close the current stream and move to the next.

#   def write(self, obj):

Write a sample.

:param obj: sample to be written

#   def finish(self):

Finish all writing (use close instead).

#   def close(self):

Close the stream.