# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import requests
import hashlib
import os
import errno
import shutil
import six
import sys
import importlib
import paddle.dataset
import six.moves.cPickle as pickle
import tempfile
import glob
import paddle

__all__ = []

HOME = os.path.expanduser('~')

# If the default HOME dir does not support writing, we
# will create a temporary folder to store the cache files.
if not os.access(HOME, os.W_OK):
    """
    gettempdir() return the name of the directory used for temporary files.
    On Windows, the directories C:\TEMP, C:\TMP, \TEMP, and \TMP, in that order.
    On all other platforms, the directories /tmp, /var/tmp, and /usr/tmp, in that order.
    For more details, please refer to https://docs.python.org/3/library/tempfile.html
    """
    HOME = tempfile.gettempdir()

DATA_HOME = os.path.join(HOME, '.cache', 'paddle', 'dataset')


# When running unit tests, there could be multiple processes trying to
# create the DATA_HOME directory simultaneously, so we cannot use an if
# condition to check for the existence of the directory; instead, we use
# the filesystem as the synchronization mechanism by catching the
# returned errors.
def must_mkdirs(path):
    try:
        os.makedirs(path)
    except OSError as exc:
        # The directory may already have been created by another process.
        if exc.errno != errno.EEXIST:
            raise


must_mkdirs(DATA_HOME)


def md5file(fname):
    """Compute the MD5 checksum of the file at fname, reading it in chunks."""
    hash_md5 = hashlib.md5()
    with open(fname, "rb") as f:
        for chunk in iter(lambda: f.read(4096), b""):
            hash_md5.update(chunk)
    return hash_md5.hexdigest()
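
# A minimal usage sketch (the path and checksum below are hypothetical):
#
#     expected = '0123456789abcdef0123456789abcdef'
#     if md5file('/path/to/archive.tar.gz') != expected:
#         raise IOError('checksum mismatch; the archive may be corrupted')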


def download(url, module_name, md5sum, save_name=None):
    """Download url into the per-module cache directory under DATA_HOME.

    The download is retried (up to a retry limit of 3) until the MD5 checksum
    of the file matches md5sum; the path of the cached file is returned.

    :param url: URL to download from
    :param module_name: name of the sub-directory under DATA_HOME
    :param md5sum: expected MD5 checksum of the downloaded file
    :param save_name: optional file name to save as; defaults to the last
                path component of url
    """
    dirname = os.path.join(DATA_HOME, module_name)
    if not os.path.exists(dirname):
        os.makedirs(dirname)

    filename = os.path.join(
        dirname,
        url.split('/')[-1] if save_name is None else save_name)

    if os.path.exists(filename) and md5file(filename) == md5sum:
        return filename

    retry = 0
    retry_limit = 3
    while not (os.path.exists(filename) and md5file(filename) == md5sum):
        if os.path.exists(filename):
            sys.stderr.write("file %s  md5 %s\n" % (md5file(filename), md5sum))
        if retry < retry_limit:
            retry += 1
        else:
            raise RuntimeError(
                "Cannot download {0} within retry limit {1}".format(
                    url, retry_limit))
        sys.stderr.write("Cache file %s not found, downloading %s \n" %
                         (filename, url))
        sys.stderr.write("Begin to download\n")
        try:
            r = requests.get(url, stream=True)
            total_length = r.headers.get('content-length')

            if total_length is None:
                with open(filename, 'wb') as f:
                    shutil.copyfileobj(r.raw, f)
            else:
                with open(filename, 'wb') as f:
                    chunk_size = 4096
                    total_length = int(total_length)
                    # Integer division keeps the progress-bar total an int.
                    total_iter = total_length // chunk_size + 1
                    log_interval = total_iter // 20 if total_iter > 20 else 1
                    log_index = 0
                    bar = paddle.hapi.progressbar.ProgressBar(total_iter,
                                                              name='item')
                    for data in r.iter_content(chunk_size=chunk_size):
                        f.write(data)
                        log_index += 1
                        bar.update(log_index, {})
                        if log_index % log_interval == 0:
                            bar.update(log_index)

        except Exception:
            # The download failed (e.g. a network error); retry it.
            continue
    sys.stderr.write("\nDownload finished\n")
    sys.stdout.flush()
    return filename
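
# A minimal usage sketch (the URL and md5sum below are hypothetical
# placeholders, not a real dataset):
#
#     path = download('https://example.com/datasets/train.tar.gz',
#                     'my_dataset', '0123456789abcdef0123456789abcdef')
#     # path now points at the verified copy cached under DATA_HOME/my_dataset.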


def fetch_all():
    for module_name in [
            x for x in dir(paddle.dataset) if not x.startswith("__")
    ]:
        if "fetch" in dir(
                importlib.import_module("paddle.dataset.%s" % module_name)):
            getattr(importlib.import_module("paddle.dataset.%s" % module_name),
                    "fetch")()


def split(reader, line_count, suffix="%05d.pickle", dumper=pickle.dump):
    """
    you can call the function as:

144
    split(paddle.dataset.cifar.train10(), line_count=1000,
145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180
        suffix="imikolov-train-%05d.pickle")

    the output files as:

    |-imikolov-train-00000.pickle
    |-imikolov-train-00001.pickle
    |- ...
    |-imikolov-train-00480.pickle

    :param reader: is a reader creator
    :param line_count: line count for each file
    :param suffix: the suffix for the output files, should contain "%d"
                means the id for each file. Default is "%05d.pickle"
    :param dumper: is a callable function that dump object to file, this
                function will be called as dumper(obj, f) and obj is the object
                will be dumped, f is a file object. Default is cPickle.dump.
    """
    if not callable(dumper):
        raise TypeError("dumper should be callable.")
    lines = []
    indx_f = 0
    for i, d in enumerate(reader()):
        lines.append(d)
        if i >= line_count and i % line_count == 0:
            # Pickle requires binary mode.
            with open(suffix % indx_f, "wb") as f:
                dumper(lines, f)
                lines = []
                indx_f += 1
    if lines:
        with open(suffix % indx_f, "wb") as f:
            dumper(lines, f)


def cluster_files_reader(files_pattern,
                         trainer_count,
                         trainer_id,
                         loader=pickle.load):
    """
    Create a reader that yield element from the given files, select
    a file set according trainer count and trainer_id

    :param files_pattern: the files which generating by split(...)
    :param trainer_count: total trainer count
    :param trainer_id: the trainer rank id
    :param loader: is a callable function that load object from file, this
                function will be called as loader(f) and f is a file object.
                Default is cPickle.load
    """

    def reader():
        if not callable(loader):
            raise TypeError("loader should be callable.")
        file_list = glob.glob(files_pattern)
        file_list.sort()
        my_file_list = []
        for idx, fn in enumerate(file_list):
            if idx % trainer_count == trainer_id:
                print("append file: %s" % fn)
                my_file_list.append(fn)
        for fn in my_file_list:
            with open(fn, "r") as f:
                lines = loader(f)
                for line in lines:
                    yield line

    return reader
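
# A minimal usage sketch (the dataset and file suffix are arbitrary examples,
# not a required convention):
#
#     split(paddle.dataset.cifar.train10(), line_count=1000,
#           suffix="cifar-train-%05d.pickle")
#     reader = cluster_files_reader("cifar-train-*.pickle",
#                                   trainer_count=4, trainer_id=0)
#     for sample in reader():
#         pass  # consume one trainer's shard of the data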


def _check_exists_and_download(path, url, md5, module_name, download=True):
    if path and os.path.exists(path):
        return path

    if download:
        return paddle.dataset.common.download(url, module_name, md5)
    else:
        raise ValueError(
            '{} does not exist and auto download is disabled'.format(path))