From f89a7b5582c377a2db33948d865f0c18a4d0781f Mon Sep 17 00:00:00 2001
From: Wenyu
Date: Thu, 10 Jun 2021 17:22:47 +0800
Subject: [PATCH] add wget option in download (#33379)

* add wget option in download
---
 python/paddle/hapi/hub.py            |   6 +-
 python/paddle/tests/test_download.py |  25 ++++++
 python/paddle/utils/download.py      | 129 ++++++++++++++++++++-------
 3 files changed, 126 insertions(+), 34 deletions(-)

diff --git a/python/paddle/hapi/hub.py b/python/paddle/hapi/hub.py
index 243bd79c191..b491bc0271b 100644
--- a/python/paddle/hapi/hub.py
+++ b/python/paddle/hapi/hub.py
@@ -110,7 +110,11 @@ def _get_cache_or_reload(repo, force_reload, verbose=True, source='github'):
         url = _git_archive_link(repo_owner, repo_name, branch, source=source)
 
         fpath = get_path_from_url(
-            url, hub_dir, check_exist=not force_reload, decompress=False)
+            url,
+            hub_dir,
+            check_exist=not force_reload,
+            decompress=False,
+            method=('wget' if source == 'gitee' else 'get'))
         shutil.move(fpath, cached_file)
 
         with zipfile.ZipFile(cached_file) as cached_zipfile:
diff --git a/python/paddle/tests/test_download.py b/python/paddle/tests/test_download.py
index 4be2dde1bcc..986d84dd153 100644
--- a/python/paddle/tests/test_download.py
+++ b/python/paddle/tests/test_download.py
@@ -77,6 +77,31 @@ class TestDownload(unittest.TestCase):
                 'www.baidu.com',
                 './test', )
 
+    def test_wget_download_error(self, ):
+        with self.assertRaises(RuntimeError):
+            from paddle.utils.download import _download
+            _download('www.baidu', './test', method='wget')
+
+    def test_download_methods(self, ):
+        urls = [
+            "https://paddle-hapi.bj.bcebos.com/unittest/files.tar",
+            "https://paddle-hapi.bj.bcebos.com/unittest/files.zip",
+        ]
+
+        import sys
+        from paddle.utils.download import _download
+        if sys.platform == 'linux':
+            methods = ['wget', 'get']
+        else:
+            methods = ['get']
+
+        for url in urls:
+            for method in methods:
+                _download(
+                    url,
+                    path='./test',
+                    method=method, )
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/utils/download.py b/python/paddle/utils/download.py
index 3ad627ddea9..29baddff05a 100644
--- a/python/paddle/utils/download.py
+++ b/python/paddle/utils/download.py
@@ -21,6 +21,7 @@ import sys
 import os.path as osp
 import shutil
 import requests
+import subprocess
 import hashlib
 import tarfile
 import zipfile
@@ -121,7 +122,8 @@ def get_path_from_url(url,
                       root_dir,
                       md5sum=None,
                       check_exist=True,
-                      decompress=True):
+                      decompress=True,
+                      method='get'):
     """ Download from given url to root_dir.
     if file or directory specified by url is exists under
     root_dir, return the path directly, otherwise download
@@ -132,7 +134,9 @@ def get_path_from_url(url,
         root_dir (str): root dir for downloading, it should be
                         WEIGHTS_HOME or DATASET_HOME
         md5sum (str): md5 sum of download package
-    
+        decompress (bool): decompress zip or tar file. Default is `True`
+        method (str): which download method to use. Support `wget` and `get`. Default is `get`.
+
     Returns:
         str: a local path to save downloaded models & weights & datasets.
     """
@@ -150,7 +154,7 @@ def get_path_from_url(url,
         logger.info("Found {}".format(fullpath))
     else:
         if ParallelEnv().current_endpoint in unique_endpoints:
-            fullpath = _download(url, root_dir, md5sum)
+            fullpath = _download(url, root_dir, md5sum, method=method)
         else:
             while not os.path.exists(fullpath):
                 time.sleep(1)
@@ -163,13 +167,99 @@ def get_path_from_url(url,
     return fullpath
 
 
-def _download(url, path, md5sum=None):
+def _get_download(url, fullname):
+    """Download url to fullname with `requests.get`.
+
+    Returns fullname on success, or False if the request raised an
+    exception so that the caller can retry. Raises RuntimeError if
+    the response status code is not 200.
+    """
+    fname = osp.basename(fullname)
+    try:
+        req = requests.get(url, stream=True)
+    except Exception as e:  # requests.exceptions.ConnectionError
+        logger.info("Downloading {} from {} failed with exception {}".format(
+            fname, url, str(e)))
+        return False
+
+    if req.status_code != 200:
+        raise RuntimeError("Downloading from {} failed with code "
+                           "{}!".format(url, req.status_code))
+
+    # For protecting download interrupted, download to
+    # tmp_fullname firstly, move tmp_fullname to fullname
+    # after download finished
+    tmp_fullname = fullname + "_tmp"
+    total_size = req.headers.get('content-length')
+    with open(tmp_fullname, 'wb') as f:
+        if total_size:
+            with tqdm(total=(int(total_size) + 1023) // 1024) as pbar:
+                for chunk in req.iter_content(chunk_size=1024):
+                    f.write(chunk)
+                    pbar.update(1)
+        else:
+            for chunk in req.iter_content(chunk_size=1024):
+                if chunk:
+                    f.write(chunk)
+    shutil.move(tmp_fullname, fullname)
+
+    return fullname
+
+
+def _wget_download(url, fullname):
+    """Download url to fullname with the `wget` command line tool.
+
+    Returns fullname on success. Raises RuntimeError if `wget` can
+    not be started or exits with a non-zero code.
+    """
+    # Download to tmp_fullname firstly, same as _get_download
+    tmp_fullname = fullname + "_tmp"
+    # Pass arguments as a list instead of a shell string, so that
+    # urls or paths containing `&`, `;` or spaces reach wget untouched
+    # and are never interpreted by a shell
+    command = [
+        'wget', '-O', tmp_fullname, '-t', str(DOWNLOAD_RETRY_LIMIT), url
+    ]
+    try:
+        subprc = subprocess.Popen(
+            command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+        subprc.communicate()
+        returncode = subprc.returncode
+    except OSError:  # wget can not be executed, e.g. not installed
+        returncode = -1
+
+    if returncode != 0:
+        # wget -O may leave an empty or partial tmp_fullname behind
+        if osp.exists(tmp_fullname):
+            os.remove(tmp_fullname)
+        raise RuntimeError(
+            '{} failed. Please make sure `wget` is installed or {} exists'.
+            format(' '.join(command), url))
+
+    shutil.move(tmp_fullname, fullname)
+
+    return fullname
+
+
+_download_methods = {
+    'get': _get_download,
+    'wget': _wget_download,
+}
+
+
+def _download(url, path, md5sum=None, method='get'):
     """
     Download from url, save to path.
 
     url (str): download url
     path (str): download to given path
+    md5sum (str): md5 sum of download package
+    method (str): which download method to use. Support `wget` and `get`. Default is `get`.
+
     """
+    assert method in _download_methods, 'make sure `{}` implemented'.format(
+        method)
+
     if not osp.exists(path):
         os.makedirs(path)
 
@@ -177,6 +267,7 @@ def _download(url, path, md5sum=None):
     fullname = osp.join(path, fname)
     retry_cnt = 0
 
+    logger.info("Downloading {} from {}".format(fname, url))
     while not (osp.exists(fullname) and _md5check(fullname, md5sum)):
         if retry_cnt < DOWNLOAD_RETRY_LIMIT:
             retry_cnt += 1
@@ -184,38 +275,10 @@ def _download(url, path, md5sum=None):
                 raise RuntimeError("Download from {} failed. "
                                    "Retry limit reached".format(url))
 
-        logger.info("Downloading {} from {}".format(fname, url))
-
-        try:
-            req = requests.get(url, stream=True)
-        except Exception as e:  # requests.exceptions.ConnectionError
-            logger.info(
-                "Downloading {} from {} failed {} times with exception {}".
-                format(fname, url, retry_cnt + 1, str(e)))
+        if not _download_methods[method](url, fullname):
             time.sleep(1)
             continue
 
-        if req.status_code != 200:
-            raise RuntimeError("Downloading from {} failed with code "
-                               "{}!".format(url, req.status_code))
-
-        # For protecting download interupted, download to
-        # tmp_fullname firstly, move tmp_fullname to fullname
-        # after download finished
-        tmp_fullname = fullname + "_tmp"
-        total_size = req.headers.get('content-length')
-        with open(tmp_fullname, 'wb') as f:
-            if total_size:
-                with tqdm(total=(int(total_size) + 1023) // 1024) as pbar:
-                    for chunk in req.iter_content(chunk_size=1024):
-                        f.write(chunk)
-                        pbar.update(1)
-            else:
-                for chunk in req.iter_content(chunk_size=1024):
-                    if chunk:
-                        f.write(chunk)
-        shutil.move(tmp_fullname, fullname)
-
     return fullname
 
 
-- 
GitLab