未验证 提交 f89a7b55 编写于 作者: W Wenyu 提交者: GitHub

add wget option in download (#33379)

* add wget option in download
上级 945e0847
......@@ -110,7 +110,11 @@ def _get_cache_or_reload(repo, force_reload, verbose=True, source='github'):
url = _git_archive_link(repo_owner, repo_name, branch, source=source)
fpath = get_path_from_url(
url, hub_dir, check_exist=not force_reload, decompress=False)
url,
hub_dir,
check_exist=not force_reload,
decompress=False,
method=('wget' if source == 'gitee' else 'get'))
shutil.move(fpath, cached_file)
with zipfile.ZipFile(cached_file) as cached_zipfile:
......
......@@ -77,6 +77,31 @@ class TestDownload(unittest.TestCase):
'www.baidu.com',
'./test', )
def test_wget_download_error(self, ):
with self.assertRaises(RuntimeError):
from paddle.utils.download import _download
_download('www.baidu', './test', method='wget')
def test_download_methods(self, ):
urls = [
"https://paddle-hapi.bj.bcebos.com/unittest/files.tar",
"https://paddle-hapi.bj.bcebos.com/unittest/files.zip",
]
import sys
from paddle.utils.download import _download
if sys.platform == 'linux':
methods = ['wget', 'get']
else:
methods = ['get']
for url in urls:
for method in methods:
_download(
url,
path='./test',
method=method, )
if __name__ == '__main__':
unittest.main()
......@@ -21,6 +21,7 @@ import sys
import os.path as osp
import shutil
import requests
import subprocess
import hashlib
import tarfile
import zipfile
......@@ -121,7 +122,8 @@ def get_path_from_url(url,
root_dir,
md5sum=None,
check_exist=True,
decompress=True):
decompress=True,
method='get'):
""" Download from given url to root_dir.
if file or directory specified by url is exists under
root_dir, return the path directly, otherwise download
......@@ -132,7 +134,9 @@ def get_path_from_url(url,
root_dir (str): root dir for downloading, it should be
WEIGHTS_HOME or DATASET_HOME
md5sum (str): md5 sum of download package
decompress (bool): decompress zip or tar file. Default is `True`
method (str): which download method to use. Support `wget` and `get`. Default is `get`.
Returns:
str: a local path to save downloaded models & weights & datasets.
"""
......@@ -150,7 +154,7 @@ def get_path_from_url(url,
logger.info("Found {}".format(fullpath))
else:
if ParallelEnv().current_endpoint in unique_endpoints:
fullpath = _download(url, root_dir, md5sum)
fullpath = _download(url, root_dir, md5sum, method=method)
else:
while not os.path.exists(fullpath):
time.sleep(1)
......@@ -163,13 +167,79 @@ def get_path_from_url(url,
return fullpath
def _download(url, path, md5sum=None):
def _get_download(url, fullname):
# using requests.get method
fname = osp.basename(fullname)
try:
req = requests.get(url, stream=True)
except Exception as e: # requests.exceptions.ConnectionError
logger.info("Downloading {} from {} failed with exception {}".format(
fname, url, str(e)))
return False
if req.status_code != 200:
raise RuntimeError("Downloading from {} failed with code "
"{}!".format(url, req.status_code))
# For protecting download interupted, download to
# tmp_fullname firstly, move tmp_fullname to fullname
# after download finished
tmp_fullname = fullname + "_tmp"
total_size = req.headers.get('content-length')
with open(tmp_fullname, 'wb') as f:
if total_size:
with tqdm(total=(int(total_size) + 1023) // 1024) as pbar:
for chunk in req.iter_content(chunk_size=1024):
f.write(chunk)
pbar.update(1)
else:
for chunk in req.iter_content(chunk_size=1024):
if chunk:
f.write(chunk)
shutil.move(tmp_fullname, fullname)
return fullname
def _wget_download(url, fullname):
# using wget to download url
tmp_fullname = fullname + "_tmp"
# –user-agent
command = 'wget -O {} -t {} {}'.format(tmp_fullname, DOWNLOAD_RETRY_LIMIT,
url)
subprc = subprocess.Popen(
command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
_ = subprc.communicate()
if subprc.returncode != 0:
raise RuntimeError(
'{} failed. Please make sure `wget` is installed or {} exists'.
format(command, url))
shutil.move(tmp_fullname, fullname)
return fullname
_download_methods = {
'get': _get_download,
'wget': _wget_download,
}
def _download(url, path, md5sum=None, method='get'):
"""
Download from url, save to path.
url (str): download url
path (str): download to given path
md5sum (str): md5 sum of download package
method (str): which download method to use. Support `wget` and `get`. Default is `get`.
"""
assert method in _download_methods, 'make sure `{}` implemented'.format(
method)
if not osp.exists(path):
os.makedirs(path)
......@@ -177,6 +247,7 @@ def _download(url, path, md5sum=None):
fullname = osp.join(path, fname)
retry_cnt = 0
logger.info("Downloading {} from {}".format(fname, url))
while not (osp.exists(fullname) and _md5check(fullname, md5sum)):
if retry_cnt < DOWNLOAD_RETRY_LIMIT:
retry_cnt += 1
......@@ -184,38 +255,10 @@ def _download(url, path, md5sum=None):
raise RuntimeError("Download from {} failed. "
"Retry limit reached".format(url))
logger.info("Downloading {} from {}".format(fname, url))
try:
req = requests.get(url, stream=True)
except Exception as e: # requests.exceptions.ConnectionError
logger.info(
"Downloading {} from {} failed {} times with exception {}".
format(fname, url, retry_cnt + 1, str(e)))
if not _download_methods[method](url, fullname):
time.sleep(1)
continue
if req.status_code != 200:
raise RuntimeError("Downloading from {} failed with code "
"{}!".format(url, req.status_code))
# For protecting download interupted, download to
# tmp_fullname firstly, move tmp_fullname to fullname
# after download finished
tmp_fullname = fullname + "_tmp"
total_size = req.headers.get('content-length')
with open(tmp_fullname, 'wb') as f:
if total_size:
with tqdm(total=(int(total_size) + 1023) // 1024) as pbar:
for chunk in req.iter_content(chunk_size=1024):
f.write(chunk)
pbar.update(1)
else:
for chunk in req.iter_content(chunk_size=1024):
if chunk:
f.write(chunk)
shutil.move(tmp_fullname, fullname)
return fullname
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册