# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
import os, random, re, shlex, sys, time
from urllib.parse import urlparse

from . import filters
from webdataset import gopen
from webdataset.handlers import reraise_exception
from .tariterators import tar_file_and_group_expander

default_cache_dir = os.environ.get("WDS_CACHE", "./_cache")
default_cache_size = float(os.environ.get("WDS_CACHE_SIZE", "1e18"))


def lru_cleanup(cache_dir, cache_size, keyfn=os.path.getctime, verbose=False):
    """Clean up the file cache in `cache_dir`, deleting the oldest files
    (as ranked by `keyfn`) until the total size of the remaining files
    drops below `cache_size`."""
    if not os.path.exists(cache_dir):
        return
    total_size = 0
    for dirpath, dirnames, filenames in os.walk(cache_dir):
        for filename in filenames:
            total_size += os.path.getsize(os.path.join(dirpath, filename))
    if total_size <= cache_size:
        return
    # rank files newest-first by keyfn (ctime by default); popping from the end yields the oldest
    files = []
    for dirpath, dirnames, filenames in os.walk(cache_dir):
        for filename in filenames:
            files.append(os.path.join(dirpath, filename))
    files.sort(key=keyfn, reverse=True)
    # delete files until we're under the cache size
    while len(files) > 0 and total_size > cache_size:
        fname = files.pop()
        total_size -= os.path.getsize(fname)
        if verbose:
            print("# deleting %s" % fname, file=sys.stderr)
        os.remove(fname)

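# Usage sketch (illustrative numbers): keep ./_cache under roughly 10 GB,
# deleting the stalest files first:
#
#   lru_cleanup("./_cache", 10e9, verbose=True)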

def download(url, dest, chunk_size=1024 ** 2, verbose=False):
    """Download a file from `url` to `dest`."""
    temp = dest + f".temp{os.getpid()}"
    with gopen.gopen(url) as stream:
        with open(temp, "wb") as f:
            while True:
                data = stream.read(chunk_size)
                if not data:
                    break
                f.write(data)
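    # rename is atomic on POSIX, so readers never observe a partial download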
    os.rename(temp, dest)


def pipe_cleaner(spec):
    """Guess the actual URL from a "pipe:" specification."""
    if spec.startswith("pipe:"):
        spec = spec[5:]
        words = spec.split(" ")
        for word in words:
            if re.match(r"^(https?|gs|ais|s3)", word):
                return word
    return spec

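# pipe_cleaner example (hypothetical URL):
#
#   pipe_cleaner("pipe:curl -s -L https://host/shard-000000.tar")
#   -> "https://host/shard-000000.tar"
#
# Specs without a "pipe:" prefix are returned unchanged.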

def get_file_cached(
    spec,
    cache_size=-1,
    cache_dir=None,
    url_to_name=pipe_cleaner,
    verbose=False,
):
    """Fetch `spec` into the cache (unless already present) and return the local path."""
    if cache_size == -1:
        cache_size = default_cache_size
    if cache_dir is None:
        cache_dir = default_cache_dir
    url = url_to_name(spec)
    parsed = urlparse(url)
    dirname, filename = os.path.split(parsed.path)
    dirname = dirname.lstrip("/")
    dirname = re.sub(r"[:/|;]", "_", dirname)
    destdir = os.path.join(cache_dir, dirname)
    os.makedirs(destdir, exist_ok=True)
    dest = os.path.join(cache_dir, dirname, filename)
    if not os.path.exists(dest):
        if verbose:
            print("# downloading %s to %s" % (url, dest), file=sys.stderr)
        lru_cleanup(cache_dir, cache_size, verbose=verbose)
        download(spec, dest, verbose=verbose)
    return dest

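# get_file_cached usage sketch (hypothetical URL); the cached copy lands under
# WDS_CACHE (default ./_cache), keyed by the URL's path:
#
#   dest = get_file_cached("https://host/data/shard-000000.tar")
#   # -> "./_cache/data/shard-000000.tar"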

def get_filetype(fname):
    """Return the output of file(1) for `fname`."""
    # shlex.quote keeps filenames with quotes or spaces from breaking the shell command
    with os.popen("file %s" % shlex.quote(fname)) as f:
        ftype = f.read()
    return ftype


def check_tar_format(fname):
    """Check whether a file looks like a (possibly gzip-compressed) tar archive."""
    ftype = get_filetype(fname)
    return "tar archive" in ftype or "gzip compressed" in ftype

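# check_tar_format accepts both plain shards (file(1) reports "POSIX tar
# archive") and gzip-compressed ones ("gzip compressed data").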

verbose_cache = int(os.environ.get("WDS_VERBOSE_CACHE", "0"))


def cached_url_opener(
    data,
    handler=reraise_exception,
    cache_size=-1,
    cache_dir=None,
    url_to_name=pipe_cleaner,
    validator=check_tar_format,
    verbose=False,
    always=False,
):
    """Given a stream of url names (packaged in `dict(url=url)`), yield opened streams."""
    verbose = verbose or verbose_cache
    for sample in data:
        assert isinstance(sample, dict), sample
        assert "url" in sample
        url = sample["url"]
        attempts = 5
        try:
            if not always and os.path.exists(url):
                dest = url
            else:
                dest = get_file_cached(
                    url,
                    cache_size=cache_size,
                    cache_dir=cache_dir,
                    url_to_name=url_to_name,
                    verbose=verbose,
                )
            if verbose:
                print("# opening %s" % dest, file=sys.stderr)
            assert os.path.exists(dest)
            if not validator(dest):
                ftype = get_filetype(dest)
                with open(dest, "rb") as f:
                    header = f.read(200)  # renamed from `data`, which shadowed the input iterator
                os.remove(dest)
                raise ValueError(
                    "%s (%s) is not a tar archive, but a %s, contains %s"
                    % (dest, url, ftype, repr(header))
                )
            try:
                stream = open(dest, "rb")
                sample.update(stream=stream)
                yield sample
            except FileNotFoundError as exn:
                # dealing with race conditions in lru_cleanup: the cached file
                # may be evicted between download and open; back off briefly,
                # then continue with the next sample
                attempts -= 1
                if attempts > 0:
                    time.sleep(random.random() * 10)
                    continue
                raise exn
        except Exception as exn:
            exn.args = exn.args + (url,)
            if handler(exn):
                continue
            else:
                break

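# cached_url_opener usage sketch (hypothetical URL):
#
#   for sample in cached_url_opener([dict(url="https://host/shard-000000.tar")]):
#       ...  # sample["stream"] is an open binary file over the cached shard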

def cached_tarfile_samples(
    src,
    handler=reraise_exception,
    cache_size=-1,
    cache_dir=None,
    verbose=False,
    url_to_name=pipe_cleaner,
    always=False,
):
    """Given a stream of URLs, cache the shards locally and expand them into samples."""
    streams = cached_url_opener(
        src,
        handler=handler,
        cache_size=cache_size,
        cache_dir=cache_dir,
        verbose=verbose,
        url_to_name=url_to_name,
        always=always,
    )
    samples = tar_file_and_group_expander(streams, handler=handler)
    return samples


cached_tarfile_to_samples = filters.pipelinefilter(cached_tarfile_samples)
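

if __name__ == "__main__":
    # Smoke-test sketch (illustrative, not part of the library). Because this
    # module uses relative imports, run it as a module, e.g.
    # `python -m <package>.cache`, so the imports resolve.
    import tempfile

    with tempfile.TemporaryDirectory() as tmp:
        src = os.path.join(tmp, "shard-000000.tar")
        with open(src, "wb") as f:
            f.write(b"\0" * 10240)  # placeholder bytes, not a real archive
        cache = os.path.join(tmp, "_cache")
        # copy the "remote" file into the cache, then force a full eviction
        dest = get_file_cached(src, cache_dir=cache, verbose=True)
        print("cached at", dest)
        lru_cleanup(cache, cache_size=0, verbose=True)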