From a214a3081b3da3064f8b7f79143cfe35db4a3ffc Mon Sep 17 00:00:00 2001 From: hong <43953930+phlrain@users.noreply.github.com> Date: Tue, 26 Nov 2019 12:04:43 +0800 Subject: [PATCH] change download log format (#21290) * change download log formate; test=develop * add unittest for data download; test=develop * remove cache before download; test=develop --- python/paddle/dataset/common.py | 21 +++++---- .../tests/unittests/test_dataset_download.py | 46 +++++++++++++++++++ 2 files changed, 58 insertions(+), 9 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_dataset_download.py diff --git a/python/paddle/dataset/common.py b/python/paddle/dataset/common.py index e8c27180ee4..9060b8c0ddb 100644 --- a/python/paddle/dataset/common.py +++ b/python/paddle/dataset/common.py @@ -79,14 +79,15 @@ def download(url, module_name, md5sum, save_name=None): retry_limit = 3 while not (os.path.exists(filename) and md5file(filename) == md5sum): if os.path.exists(filename): - sys.stderr.write("file %s md5 %s" % (md5file(filename), md5sum)) + sys.stderr.write("file %s md5 %s\n" % (md5file(filename), md5sum)) if retry < retry_limit: retry += 1 else: raise RuntimeError("Cannot download {0} within retry limit {1}". format(url, retry_limit)) - sys.stderr.write("Cache file %s not found, downloading %s" % + sys.stderr.write("Cache file %s not found, downloading %s \n" % (filename, url)) + sys.stderr.write("Begin to download\n") r = requests.get(url, stream=True) total_length = r.headers.get('content-length') @@ -95,18 +96,20 @@ def download(url, module_name, md5sum, save_name=None): shutil.copyfileobj(r.raw, f) else: with open(filename, 'wb') as f: - dl = 0 + chunk_size = 4096 total_length = int(total_length) - for data in r.iter_content(chunk_size=4096): + total_iter = total_length / chunk_size + 1 + log_interval = total_iter / 20 if total_iter > 20 else 1 + log_index = 0 + for data in r.iter_content(chunk_size=chunk_size): if six.PY2: data = six.b(data) - dl += len(data) f.write(data) - done = int(50 * dl / total_length) - sys.stderr.write("\r[%s%s]" % ('=' * done, - ' ' * (50 - done))) + log_index += 1 + if log_index % log_interval == 0: + sys.stderr.write(".") sys.stdout.flush() - sys.stderr.write("\n") + sys.stderr.write("\nDownload finished\n") sys.stdout.flush() return filename diff --git a/python/paddle/fluid/tests/unittests/test_dataset_download.py b/python/paddle/fluid/tests/unittests/test_dataset_download.py new file mode 100644 index 00000000000..f1fba215b93 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_dataset_download.py @@ -0,0 +1,46 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest +from paddle.dataset.common import download, DATA_HOME, md5file + + +class TestDataSetDownload(unittest.TestCase): + def setUp(self): + flower_path = DATA_HOME + "/flowers/imagelabels.mat" + + if os.path.exists(flower_path): + os.remove(flower_path) + + def test_download_url(self): + LABEL_URL = 'http://paddlemodels.bj.bcebos.com/flowers/imagelabels.mat' + LABEL_MD5 = 'e0620be6f572b9609742df49c70aed4d' + + catch_exp = False + try: + download(LABEL_URL, 'flowers', LABEL_MD5) + except Exception as e: + catch_exp = True + + self.assertTrue(catch_exp == False) + + file_path = DATA_HOME + "/flowers/imagelabels.mat" + + self.assertTrue(os.path.exists(file_path)) + self.assertTrue(md5file(file_path), LABEL_MD5) + + +if __name__ == '__main__': + unittest.main() -- GitLab