# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import requests
import hashlib
import os
import errno
import shutil
import sys
import importlib
import paddle.dataset
import pickle
import tempfile
import glob
import paddle

__all__ = []

HOME = os.path.expanduser('~')

# If the default HOME dir does not support writing, we
# will create a temporary folder to store the cache files.
if not os.access(HOME, os.W_OK):
    """
    gettempdir() return the name of the directory used for temporary files.
    On Windows, the directories C:\TEMP, C:\TMP, \TEMP, and \TMP, in that order.
    On all other platforms, the directories /tmp, /var/tmp, and /usr/tmp, in that order.
    For more details, please refer to https://docs.python.org/3/library/tempfile.html
    """
    HOME = tempfile.gettempdir()

DATA_HOME = os.path.join(HOME, '.cache', 'paddle', 'dataset')


# When running unit tests, there could be multiple processes trying to
# create the DATA_HOME directory simultaneously, so we cannot use an if
# condition to check for the existence of the directory; instead, we use
# the filesystem as the synchronization mechanism by catching the
# returned errors.
def must_mkdirs(path):
    try:
        os.makedirs(path)
    except OSError as exc:
        if exc.errno != errno.EEXIST:
            raise


must_mkdirs(DATA_HOME)
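# Note: on Python 3, os.makedirs(DATA_HOME, exist_ok=True) would achieve
# the same race-tolerant behavior in one call; the explicit errno check
# above makes it clear that only "already exists" errors are swallowed.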


def md5file(fname):
    """Return the md5 hex digest of the file at ``fname``, reading it in
    4 KB chunks so that large files need not fit in memory."""
    hash_md5 = hashlib.md5()
    with open(fname, "rb") as f:
        for chunk in iter(lambda: f.read(4096), b""):
            hash_md5.update(chunk)
    return hash_md5.hexdigest()
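
# Illustrative check, e.g. before unpacking a downloaded archive; the path
# and digest below are placeholders, not real dataset values:
#
#     expected = "d41d8cd98f00b204e9800998ecf8427e"  # md5 of an empty file
#     if md5file("/tmp/archive.tar.gz") != expected:
#         raise IOError("md5 mismatch, the download may be corrupted")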


def download(url, module_name, md5sum, save_name=None):
    """Download ``url`` into DATA_HOME/``module_name``, verify it against
    ``md5sum``, and return the local file path. Retries up to three times
    on failure."""
    dirname = os.path.join(DATA_HOME, module_name)
    os.makedirs(dirname, exist_ok=True)
    filename = os.path.join(
        dirname,
        url.split('/')[-1] if save_name is None else save_name)

    if os.path.exists(filename) and md5file(filename) == md5sum:
        return filename

    retry = 0
    retry_limit = 3
    while not (os.path.exists(filename) and md5file(filename) == md5sum):
        if os.path.exists(filename):
            sys.stderr.write("md5 of %s is %s, expected %s; re-downloading\n"
                             % (filename, md5file(filename), md5sum))
        if retry < retry_limit:
            retry += 1
        else:
            raise RuntimeError(
                "Cannot download {0} within retry limit {1}".format(
                    url, retry_limit))
        sys.stderr.write("Cache file %s not found, downloading %s \n" %
95
                         (filename, url))
        sys.stderr.write("Begin to download\n")
        try:
            r = requests.get(url, stream=True)
            total_length = r.headers.get('content-length')

            if total_length is None:
                with open(filename, 'wb') as f:
                    shutil.copyfileobj(r.raw, f)
            else:
                with open(filename, 'wb') as f:
                    chunk_size = 4096
                    total_length = int(total_length)
                    # Integer division: on Python 3, "/" would yield a
                    # float, breaking the progress-bar bookkeeping below.
                    total_iter = total_length // chunk_size + 1
                    log_interval = total_iter // 20 if total_iter > 20 else 1
                    log_index = 0
                    bar = paddle.hapi.progressbar.ProgressBar(total_iter,
                                                              name='item')
                    for data in r.iter_content(chunk_size=chunk_size):
                        f.write(data)
                        log_index += 1
                        # Refresh the bar roughly every 5% of the download
                        # rather than on every 4 KB chunk.
                        if log_index % log_interval == 0:
                            bar.update(log_index)

        except Exception:
            # Any failure (network error, partial download) falls through
            # to the next retry iteration.
            continue
    sys.stderr.write("\nDownload finished\n")
    sys.stderr.flush()
    return filename
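
# Illustrative call; the URL and md5sum below are placeholders, not a real
# dataset:
#
#     path = download("https://example.com/data/train.tar.gz",
#                     module_name="my_dataset",
#                     md5sum="0123456789abcdef0123456789abcdef")
#     # path == os.path.join(DATA_HOME, "my_dataset", "train.tar.gz")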


def fetch_all():
    """Download every paddle.dataset module that exposes a ``fetch``
    function."""
    for module_name in [
            x for x in dir(paddle.dataset) if not x.startswith("__")
    ]:
        module = importlib.import_module("paddle.dataset.%s" % module_name)
        if "fetch" in dir(module):
            getattr(module, "fetch")()


def split(reader, line_count, suffix="%05d.pickle", dumper=pickle.dump):
    """
    Split the samples yielded by ``reader`` into multiple pickle files,
    ``line_count`` samples per file. For example, calling

    split(paddle.dataset.cifar.train10(), line_count=1000,
        suffix="imikolov-train-%05d.pickle")

    produces output files such as:

    |-imikolov-train-00000.pickle
    |-imikolov-train-00001.pickle
    |- ...
    |-imikolov-train-00480.pickle

    :param reader: a reader creator
    :param line_count: the number of samples written to each file
    :param suffix: the suffix for the output files; should contain "%d",
                which stands for the id of each file. Default is
                "%05d.pickle"
    :param dumper: a callable that dumps an object to a file; it is called
                as dumper(obj, f), where obj is the object to be dumped
                and f is a binary file object. Default is pickle.dump.
    """
    if not callable(dumper):
        raise TypeError("dumper should be callable.")
    lines = []
    indx_f = 0
    for i, d in enumerate(reader()):
        lines.append(d)
        # Flush a file once exactly line_count samples have accumulated.
        if (i + 1) % line_count == 0:
            with open(suffix % indx_f, "wb") as f:
                dumper(lines, f)
            lines = []
            indx_f += 1
    if lines:
        with open(suffix % indx_f, "wb") as f:
            dumper(lines, f)
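
# Minimal sketch of splitting a toy in-memory reader; the names and counts
# are illustrative:
#
#     def toy_reader():
#         for i in range(2500):
#             yield i
#
#     split(toy_reader, line_count=1000, suffix="toy-%05d.pickle")
#     # writes toy-00000.pickle and toy-00001.pickle (1000 samples each),
#     # then toy-00002.pickle with the remaining 500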


def cluster_files_reader(files_pattern,
                         trainer_count,
                         trainer_id,
                         loader=pickle.load):
    """
    Create a reader that yield element from the given files, select
    a file set according trainer count and trainer_id

    :param files_pattern: the files which generating by split(...)
    :param trainer_count: total trainer count
    :param trainer_id: the trainer rank id
    :param loader: is a callable function that load object from file, this
                function will be called as loader(f) and f is a file object.
                Default is cPickle.load
    """

    def reader():
        if not callable(loader):
            raise TypeError("loader should be callable.")
        file_list = glob.glob(files_pattern)
        file_list.sort()
        my_file_list = []
        for idx, fn in enumerate(file_list):
            if idx % trainer_count == trainer_id:
                print("append file: %s" % fn)
                my_file_list.append(fn)
        for fn in my_file_list:
            with open(fn, "r") as f:
                lines = loader(f)
                for line in lines:
                    yield line

    return reader
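
# Continuing the toy example above: with two trainers, trainer 0 reads
# every even-indexed file produced by split (file names are illustrative):
#
#     read = cluster_files_reader("toy-*.pickle",
#                                 trainer_count=2, trainer_id=0)
#     for sample in read():  # samples from toy-00000 and toy-00002
#         pass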


def _check_exists_and_download(path, url, md5, module_name, download=True):
    if path and os.path.exists(path):
        return path

    if download:
        return paddle.dataset.common.download(url, module_name, md5)
    else:
        raise ValueError(
            '{} does not exist and auto-download is disabled'.format(path))
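
# Illustrative call; the URL and md5 below are placeholders. Returns path
# unchanged if it already exists, otherwise downloads the file into
# DATA_HOME/my_dataset:
#
#     path = _check_exists_and_download(
#         None, "https://example.com/labels.txt",
#         "0123456789abcdef0123456789abcdef", "my_dataset")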