未验证 提交 f89a7b55 编写于 作者: W Wenyu 提交者: GitHub

add wget option in download (#33379)

* add wget option in download
上级 945e0847
...@@ -110,7 +110,11 @@ def _get_cache_or_reload(repo, force_reload, verbose=True, source='github'): ...@@ -110,7 +110,11 @@ def _get_cache_or_reload(repo, force_reload, verbose=True, source='github'):
url = _git_archive_link(repo_owner, repo_name, branch, source=source) url = _git_archive_link(repo_owner, repo_name, branch, source=source)
fpath = get_path_from_url( fpath = get_path_from_url(
url, hub_dir, check_exist=not force_reload, decompress=False) url,
hub_dir,
check_exist=not force_reload,
decompress=False,
method=('wget' if source == 'gitee' else 'get'))
shutil.move(fpath, cached_file) shutil.move(fpath, cached_file)
with zipfile.ZipFile(cached_file) as cached_zipfile: with zipfile.ZipFile(cached_file) as cached_zipfile:
......
...@@ -77,6 +77,31 @@ class TestDownload(unittest.TestCase): ...@@ -77,6 +77,31 @@ class TestDownload(unittest.TestCase):
'www.baidu.com', 'www.baidu.com',
'./test', ) './test', )
def test_wget_download_error(self, ):
with self.assertRaises(RuntimeError):
from paddle.utils.download import _download
_download('www.baidu', './test', method='wget')
def test_download_methods(self, ):
urls = [
"https://paddle-hapi.bj.bcebos.com/unittest/files.tar",
"https://paddle-hapi.bj.bcebos.com/unittest/files.zip",
]
import sys
from paddle.utils.download import _download
if sys.platform == 'linux':
methods = ['wget', 'get']
else:
methods = ['get']
for url in urls:
for method in methods:
_download(
url,
path='./test',
method=method, )
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -21,6 +21,7 @@ import sys ...@@ -21,6 +21,7 @@ import sys
import os.path as osp import os.path as osp
import shutil import shutil
import requests import requests
import subprocess
import hashlib import hashlib
import tarfile import tarfile
import zipfile import zipfile
...@@ -121,7 +122,8 @@ def get_path_from_url(url, ...@@ -121,7 +122,8 @@ def get_path_from_url(url,
root_dir, root_dir,
md5sum=None, md5sum=None,
check_exist=True, check_exist=True,
decompress=True): decompress=True,
method='get'):
""" Download from given url to root_dir. """ Download from given url to root_dir.
if file or directory specified by url is exists under if file or directory specified by url is exists under
root_dir, return the path directly, otherwise download root_dir, return the path directly, otherwise download
...@@ -132,6 +134,8 @@ def get_path_from_url(url, ...@@ -132,6 +134,8 @@ def get_path_from_url(url,
root_dir (str): root dir for downloading, it should be root_dir (str): root dir for downloading, it should be
WEIGHTS_HOME or DATASET_HOME WEIGHTS_HOME or DATASET_HOME
md5sum (str): md5 sum of download package md5sum (str): md5 sum of download package
decompress (bool): decompress zip or tar file. Default is `True`
method (str): which download method to use. Support `wget` and `get`. Default is `get`.
Returns: Returns:
str: a local path to save downloaded models & weights & datasets. str: a local path to save downloaded models & weights & datasets.
...@@ -150,7 +154,7 @@ def get_path_from_url(url, ...@@ -150,7 +154,7 @@ def get_path_from_url(url,
logger.info("Found {}".format(fullpath)) logger.info("Found {}".format(fullpath))
else: else:
if ParallelEnv().current_endpoint in unique_endpoints: if ParallelEnv().current_endpoint in unique_endpoints:
fullpath = _download(url, root_dir, md5sum) fullpath = _download(url, root_dir, md5sum, method=method)
else: else:
while not os.path.exists(fullpath): while not os.path.exists(fullpath):
time.sleep(1) time.sleep(1)
...@@ -163,37 +167,15 @@ def get_path_from_url(url, ...@@ -163,37 +167,15 @@ def get_path_from_url(url,
return fullpath return fullpath
def _download(url, path, md5sum=None): def _get_download(url, fullname):
""" # using requests.get method
Download from url, save to path. fname = osp.basename(fullname)
url (str): download url
path (str): download to given path
"""
if not osp.exists(path):
os.makedirs(path)
fname = osp.split(url)[-1]
fullname = osp.join(path, fname)
retry_cnt = 0
while not (osp.exists(fullname) and _md5check(fullname, md5sum)):
if retry_cnt < DOWNLOAD_RETRY_LIMIT:
retry_cnt += 1
else:
raise RuntimeError("Download from {} failed. "
"Retry limit reached".format(url))
logger.info("Downloading {} from {}".format(fname, url))
try: try:
req = requests.get(url, stream=True) req = requests.get(url, stream=True)
except Exception as e: # requests.exceptions.ConnectionError except Exception as e: # requests.exceptions.ConnectionError
logger.info( logger.info("Downloading {} from {} failed with exception {}".format(
"Downloading {} from {} failed {} times with exception {}". fname, url, str(e)))
format(fname, url, retry_cnt + 1, str(e))) return False
time.sleep(1)
continue
if req.status_code != 200: if req.status_code != 200:
raise RuntimeError("Downloading from {} failed with code " raise RuntimeError("Downloading from {} failed with code "
...@@ -219,6 +201,67 @@ def _download(url, path, md5sum=None): ...@@ -219,6 +201,67 @@ def _download(url, path, md5sum=None):
return fullname return fullname
def _wget_download(url, fullname):
# using wget to download url
tmp_fullname = fullname + "_tmp"
# –user-agent
command = 'wget -O {} -t {} {}'.format(tmp_fullname, DOWNLOAD_RETRY_LIMIT,
url)
subprc = subprocess.Popen(
command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
_ = subprc.communicate()
if subprc.returncode != 0:
raise RuntimeError(
'{} failed. Please make sure `wget` is installed or {} exists'.
format(command, url))
shutil.move(tmp_fullname, fullname)
return fullname
_download_methods = {
'get': _get_download,
'wget': _wget_download,
}
def _download(url, path, md5sum=None, method='get'):
"""
Download from url, save to path.
url (str): download url
path (str): download to given path
md5sum (str): md5 sum of download package
method (str): which download method to use. Support `wget` and `get`. Default is `get`.
"""
assert method in _download_methods, 'make sure `{}` implemented'.format(
method)
if not osp.exists(path):
os.makedirs(path)
fname = osp.split(url)[-1]
fullname = osp.join(path, fname)
retry_cnt = 0
logger.info("Downloading {} from {}".format(fname, url))
while not (osp.exists(fullname) and _md5check(fullname, md5sum)):
if retry_cnt < DOWNLOAD_RETRY_LIMIT:
retry_cnt += 1
else:
raise RuntimeError("Download from {} failed. "
"Retry limit reached".format(url))
if not _download_methods[method](url, fullname):
time.sleep(1)
continue
return fullname
def _md5check(fullname, md5sum=None): def _md5check(fullname, md5sum=None):
if md5sum is None: if md5sum is None:
return True return True
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册