diff --git a/python/paddle/dataset/common.py b/python/paddle/dataset/common.py index 5eba18776c9643077d79e2b6b3c9a239bebec637..372249e01f66bc39045db3f36238da1500c7738e 100644 --- a/python/paddle/dataset/common.py +++ b/python/paddle/dataset/common.py @@ -88,27 +88,31 @@ def download(url, module_name, md5sum, save_name=None): sys.stderr.write("Cache file %s not found, downloading %s \n" % (filename, url)) sys.stderr.write("Begin to download\n") - r = requests.get(url, stream=True) - total_length = r.headers.get('content-length') - - if total_length is None: - with open(filename, 'wb') as f: - shutil.copyfileobj(r.raw, f) - else: - with open(filename, 'wb') as f: - chunk_size = 4096 - total_length = int(total_length) - total_iter = total_length / chunk_size + 1 - log_interval = total_iter / 20 if total_iter > 20 else 1 - log_index = 0 - for data in r.iter_content(chunk_size=chunk_size): - if six.PY2: - data = six.b(data) - f.write(data) - log_index += 1 - if log_index % log_interval == 0: - sys.stderr.write(".") - sys.stdout.flush() + try: + r = requests.get(url, stream=True) + total_length = r.headers.get('content-length') + + if total_length is None: + with open(filename, 'wb') as f: + shutil.copyfileobj(r.raw, f) + else: + with open(filename, 'wb') as f: + chunk_size = 4096 + total_length = int(total_length) + total_iter = total_length / chunk_size + 1 + log_interval = total_iter / 20 if total_iter > 20 else 1 + log_index = 0 + for data in r.iter_content(chunk_size=chunk_size): + if six.PY2: + data = six.b(data) + f.write(data) + log_index += 1 + if log_index % log_interval == 0: + sys.stderr.write(".") + sys.stdout.flush() + except Exception as e: + # re-try + continue sys.stderr.write("\nDownload finished\n") sys.stdout.flush() return filename