未验证 提交 a214a308 编写于 作者: H hong 提交者: GitHub

change download log format (#21290)

* change download log formate; test=develop

* add unittest for data download; test=develop

* remove cache before download; test=develop
上级 234060f8
...@@ -79,14 +79,15 @@ def download(url, module_name, md5sum, save_name=None): ...@@ -79,14 +79,15 @@ def download(url, module_name, md5sum, save_name=None):
retry_limit = 3 retry_limit = 3
while not (os.path.exists(filename) and md5file(filename) == md5sum): while not (os.path.exists(filename) and md5file(filename) == md5sum):
if os.path.exists(filename): if os.path.exists(filename):
sys.stderr.write("file %s md5 %s" % (md5file(filename), md5sum)) sys.stderr.write("file %s md5 %s\n" % (md5file(filename), md5sum))
if retry < retry_limit: if retry < retry_limit:
retry += 1 retry += 1
else: else:
raise RuntimeError("Cannot download {0} within retry limit {1}". raise RuntimeError("Cannot download {0} within retry limit {1}".
format(url, retry_limit)) format(url, retry_limit))
sys.stderr.write("Cache file %s not found, downloading %s" % sys.stderr.write("Cache file %s not found, downloading %s \n" %
(filename, url)) (filename, url))
sys.stderr.write("Begin to download\n")
r = requests.get(url, stream=True) r = requests.get(url, stream=True)
total_length = r.headers.get('content-length') total_length = r.headers.get('content-length')
...@@ -95,18 +96,20 @@ def download(url, module_name, md5sum, save_name=None): ...@@ -95,18 +96,20 @@ def download(url, module_name, md5sum, save_name=None):
shutil.copyfileobj(r.raw, f) shutil.copyfileobj(r.raw, f)
else: else:
with open(filename, 'wb') as f: with open(filename, 'wb') as f:
dl = 0 chunk_size = 4096
total_length = int(total_length) total_length = int(total_length)
for data in r.iter_content(chunk_size=4096): total_iter = total_length / chunk_size + 1
log_interval = total_iter / 20 if total_iter > 20 else 1
log_index = 0
for data in r.iter_content(chunk_size=chunk_size):
if six.PY2: if six.PY2:
data = six.b(data) data = six.b(data)
dl += len(data)
f.write(data) f.write(data)
done = int(50 * dl / total_length) log_index += 1
sys.stderr.write("\r[%s%s]" % ('=' * done, if log_index % log_interval == 0:
' ' * (50 - done))) sys.stderr.write(".")
sys.stdout.flush() sys.stdout.flush()
sys.stderr.write("\n") sys.stderr.write("\nDownload finished\n")
sys.stdout.flush() sys.stdout.flush()
return filename return filename
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
from paddle.dataset.common import download, DATA_HOME, md5file
class TestDataSetDownload(unittest.TestCase):
def setUp(self):
flower_path = DATA_HOME + "/flowers/imagelabels.mat"
if os.path.exists(flower_path):
os.remove(flower_path)
def test_download_url(self):
LABEL_URL = 'http://paddlemodels.bj.bcebos.com/flowers/imagelabels.mat'
LABEL_MD5 = 'e0620be6f572b9609742df49c70aed4d'
catch_exp = False
try:
download(LABEL_URL, 'flowers', LABEL_MD5)
except Exception as e:
catch_exp = True
self.assertTrue(catch_exp == False)
file_path = DATA_HOME + "/flowers/imagelabels.mat"
self.assertTrue(os.path.exists(file_path))
self.assertTrue(md5file(file_path), LABEL_MD5)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册