未验证 提交 f89a7b55 编写于 作者: W Wenyu 提交者: GitHub

add wget option in download (#33379)

* add wget option in download
上级 945e0847
...@@ -110,7 +110,11 @@ def _get_cache_or_reload(repo, force_reload, verbose=True, source='github'): ...@@ -110,7 +110,11 @@ def _get_cache_or_reload(repo, force_reload, verbose=True, source='github'):
url = _git_archive_link(repo_owner, repo_name, branch, source=source) url = _git_archive_link(repo_owner, repo_name, branch, source=source)
fpath = get_path_from_url( fpath = get_path_from_url(
url, hub_dir, check_exist=not force_reload, decompress=False) url,
hub_dir,
check_exist=not force_reload,
decompress=False,
method=('wget' if source == 'gitee' else 'get'))
shutil.move(fpath, cached_file) shutil.move(fpath, cached_file)
with zipfile.ZipFile(cached_file) as cached_zipfile: with zipfile.ZipFile(cached_file) as cached_zipfile:
......
...@@ -77,6 +77,31 @@ class TestDownload(unittest.TestCase): ...@@ -77,6 +77,31 @@ class TestDownload(unittest.TestCase):
'www.baidu.com', 'www.baidu.com',
'./test', ) './test', )
def test_wget_download_error(self, ):
    """Downloading a malformed url via the wget backend raises RuntimeError."""
    from paddle.utils.download import _download

    with self.assertRaises(RuntimeError):
        _download('www.baidu', './test', method='wget')
def test_download_methods(self, ):
    """Every supported download backend fetches both archive urls.

    The 'wget' backend is only exercised on linux, where the wget
    binary can be assumed to exist.
    """
    import sys

    from paddle.utils.download import _download

    urls = [
        "https://paddle-hapi.bj.bcebos.com/unittest/files.tar",
        "https://paddle-hapi.bj.bcebos.com/unittest/files.zip",
    ]

    methods = ['wget', 'get'] if sys.platform == 'linux' else ['get']

    for url in urls:
        for method in methods:
            _download(
                url,
                path='./test',
                method=method, )
# De-garbled from the fused diff columns: run the test suite when the
# module is executed directly.
if __name__ == '__main__':
    unittest.main()
...@@ -21,6 +21,7 @@ import sys ...@@ -21,6 +21,7 @@ import sys
import os.path as osp import os.path as osp
import shutil import shutil
import requests import requests
import subprocess
import hashlib import hashlib
import tarfile import tarfile
import zipfile import zipfile
...@@ -121,7 +122,8 @@ def get_path_from_url(url, ...@@ -121,7 +122,8 @@ def get_path_from_url(url,
root_dir, root_dir,
md5sum=None, md5sum=None,
check_exist=True, check_exist=True,
decompress=True): decompress=True,
method='get'):
""" Download from given url to root_dir. """ Download from given url to root_dir.
if file or directory specified by url is exists under if file or directory specified by url is exists under
root_dir, return the path directly, otherwise download root_dir, return the path directly, otherwise download
...@@ -132,7 +134,9 @@ def get_path_from_url(url, ...@@ -132,7 +134,9 @@ def get_path_from_url(url,
root_dir (str): root dir for downloading, it should be root_dir (str): root dir for downloading, it should be
WEIGHTS_HOME or DATASET_HOME WEIGHTS_HOME or DATASET_HOME
md5sum (str): md5 sum of download package md5sum (str): md5 sum of download package
decompress (bool): decompress zip or tar file. Default is `True`
method (str): which download method to use. Support `wget` and `get`. Default is `get`.
Returns: Returns:
str: a local path to save downloaded models & weights & datasets. str: a local path to save downloaded models & weights & datasets.
""" """
...@@ -150,7 +154,7 @@ def get_path_from_url(url, ...@@ -150,7 +154,7 @@ def get_path_from_url(url,
logger.info("Found {}".format(fullpath)) logger.info("Found {}".format(fullpath))
else: else:
if ParallelEnv().current_endpoint in unique_endpoints: if ParallelEnv().current_endpoint in unique_endpoints:
fullpath = _download(url, root_dir, md5sum) fullpath = _download(url, root_dir, md5sum, method=method)
else: else:
while not os.path.exists(fullpath): while not os.path.exists(fullpath):
time.sleep(1) time.sleep(1)
...@@ -163,13 +167,79 @@ def get_path_from_url(url, ...@@ -163,13 +167,79 @@ def get_path_from_url(url,
return fullpath return fullpath
def _get_download(url, fullname):
    """Download `url` to `fullname` using `requests.get`.

    Args:
        url (str): source url to fetch.
        fullname (str): destination file path.

    Returns:
        str|bool: `fullname` on success, or ``False`` when the connection
            could not be established (the caller's retry loop handles it).

    Raises:
        RuntimeError: if the server answers with a non-200 status code.
    """
    # using requests.get method
    fname = osp.basename(fullname)
    try:
        req = requests.get(url, stream=True)
    except Exception as e:  # requests.exceptions.ConnectionError
        logger.info("Downloading {} from {} failed with exception {}".format(
            fname, url, str(e)))
        return False

    try:
        if req.status_code != 200:
            raise RuntimeError("Downloading from {} failed with code "
                               "{}!".format(url, req.status_code))

        # For protecting download interupted, download to
        # tmp_fullname firstly, move tmp_fullname to fullname
        # after download finished
        tmp_fullname = fullname + "_tmp"
        total_size = req.headers.get('content-length')
        with open(tmp_fullname, 'wb') as f:
            if total_size:
                with tqdm(total=(int(total_size) + 1023) // 1024) as pbar:
                    for chunk in req.iter_content(chunk_size=1024):
                        f.write(chunk)
                        pbar.update(1)
            else:
                for chunk in req.iter_content(chunk_size=1024):
                    if chunk:
                        f.write(chunk)
    finally:
        # A streamed response keeps its connection open until fully read;
        # close it explicitly so an early raise does not leak the socket.
        req.close()

    shutil.move(tmp_fullname, fullname)

    return fullname
def _wget_download(url, fullname):
    """Download `url` to `fullname` by shelling out to the `wget` binary.

    Args:
        url (str): source url to fetch.
        fullname (str): destination file path.

    Returns:
        str: `fullname` after a successful download.

    Raises:
        RuntimeError: if the wget process exits non-zero (wget missing,
            url unreachable, etc.).
    """
    # Download to a temp file first and move it into place afterwards so
    # an interrupted download never leaves a truncated `fullname` behind.
    tmp_fullname = fullname + "_tmp"
    # Pass the command as an argument list with shell=False so special
    # characters in the url (&, ;, spaces, quotes) cannot be interpreted
    # by a shell — the previous shell=True string form broke on such urls
    # and was an injection hazard.
    command = [
        'wget', '-O', tmp_fullname, '-t', str(DOWNLOAD_RETRY_LIMIT), url
    ]
    subprc = subprocess.Popen(
        command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    _ = subprc.communicate()

    if subprc.returncode != 0:
        raise RuntimeError(
            '{} failed. Please make sure `wget` is installed or {} exists'.
            format(' '.join(command), url))

    shutil.move(tmp_fullname, fullname)

    return fullname
# Dispatch table mapping the public `method` argument of `_download`
# to its backend implementation; `_download` asserts membership here
# before dispatching.
_download_methods = {
    'get': _get_download,
    'wget': _wget_download,
}
def _download(url, path, md5sum=None, method='get'):
""" """
Download from url, save to path. Download from url, save to path.
url (str): download url url (str): download url
path (str): download to given path path (str): download to given path
md5sum (str): md5 sum of download package
method (str): which download method to use. Support `wget` and `get`. Default is `get`.
""" """
assert method in _download_methods, 'make sure `{}` implemented'.format(
method)
if not osp.exists(path): if not osp.exists(path):
os.makedirs(path) os.makedirs(path)
...@@ -177,6 +247,7 @@ def _download(url, path, md5sum=None): ...@@ -177,6 +247,7 @@ def _download(url, path, md5sum=None):
fullname = osp.join(path, fname) fullname = osp.join(path, fname)
retry_cnt = 0 retry_cnt = 0
logger.info("Downloading {} from {}".format(fname, url))
while not (osp.exists(fullname) and _md5check(fullname, md5sum)): while not (osp.exists(fullname) and _md5check(fullname, md5sum)):
if retry_cnt < DOWNLOAD_RETRY_LIMIT: if retry_cnt < DOWNLOAD_RETRY_LIMIT:
retry_cnt += 1 retry_cnt += 1
...@@ -184,38 +255,10 @@ def _download(url, path, md5sum=None): ...@@ -184,38 +255,10 @@ def _download(url, path, md5sum=None):
raise RuntimeError("Download from {} failed. " raise RuntimeError("Download from {} failed. "
"Retry limit reached".format(url)) "Retry limit reached".format(url))
logger.info("Downloading {} from {}".format(fname, url)) if not _download_methods[method](url, fullname):
try:
req = requests.get(url, stream=True)
except Exception as e: # requests.exceptions.ConnectionError
logger.info(
"Downloading {} from {} failed {} times with exception {}".
format(fname, url, retry_cnt + 1, str(e)))
time.sleep(1) time.sleep(1)
continue continue
if req.status_code != 200:
raise RuntimeError("Downloading from {} failed with code "
"{}!".format(url, req.status_code))
# For protecting download interupted, download to
# tmp_fullname firstly, move tmp_fullname to fullname
# after download finished
tmp_fullname = fullname + "_tmp"
total_size = req.headers.get('content-length')
with open(tmp_fullname, 'wb') as f:
if total_size:
with tqdm(total=(int(total_size) + 1023) // 1024) as pbar:
for chunk in req.iter_content(chunk_size=1024):
f.write(chunk)
pbar.update(1)
else:
for chunk in req.iter_content(chunk_size=1024):
if chunk:
f.write(chunk)
shutil.move(tmp_fullname, fullname)
return fullname return fullname
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册