提交 de95bb4f 编写于 作者: W wuzewu

update downloader

上级 cb09f5f4
...@@ -17,7 +17,7 @@ from __future__ import division ...@@ -17,7 +17,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from paddle_hub.tools import utils from paddle_hub.tools import utils
from paddle_hub.tools.logger import logger from paddle_hub.tools.logger import logger
from paddle_hub.tools import downloader from paddle_hub.tools.downloader import default_downloader
from paddle_hub.tools import paddle_helper from paddle_hub.tools import paddle_helper
from paddle_hub.module import module_desc_pb2 from paddle_hub.module import module_desc_pb2
from paddle_hub.module.signature import Signature, create_signature from paddle_hub.module.signature import Signature, create_signature
...@@ -116,7 +116,8 @@ class Module: ...@@ -116,7 +116,8 @@ class Module:
def _init_with_url(self, url): def _init_with_url(self, url):
utils.check_url_valid(url) utils.check_url_valid(url)
module_dir = downloader.download_and_uncompress(module_url) module_dir = default_downloader.download_file_and_uncompress(
module_url, save_path=".")
self._init_with_module_file(module_dir) self._init_with_module_file(module_dir)
def _dump_processor(self): def _dump_processor(self):
......
...@@ -23,11 +23,13 @@ import hashlib ...@@ -23,11 +23,13 @@ import hashlib
import requests import requests
import tempfile import tempfile
import tarfile import tarfile
from paddle_hub.tools import utils
from paddle_hub.tools.logger import logger
__all__ = ['MODULE_HOME', 'download', 'md5file', 'download_and_uncompress'] __all__ = ['MODULE_HOME', 'downloader', 'md5file', 'Downloader']
# TODO(ZeyuChen) add environment varialble to set MODULE_HOME # TODO(ZeyuChen) add environment varialble to set MODULE_HOME
MODULE_HOME = os.path.expanduser('~/.cache/paddle/module') MODULE_HOME = os.path.expanduser('~/.hub/module')
# When running unit tests, there could be multiple processes that # When running unit tests, there could be multiple processes that
...@@ -53,29 +55,29 @@ def md5file(fname): ...@@ -53,29 +55,29 @@ def md5file(fname):
return hash_md5.hexdigest() return hash_md5.hexdigest()
def download_and_uncompress(url, save_name=None): class Downloader:
module_name = url.split("/")[-2] def __init__(self, module_home=None):
dirname = os.path.join(MODULE_HOME, module_name) self.module_home = module_home if module_home else MODULE_HOME
if not os.path.exists(dirname):
os.makedirs(dirname)
#TODO(ZeyuChen) add download md5 file to verify file completeness
file_name = os.path.join(
dirname,
url.split('/')[-1] if save_name is None else save_name)
retry = 0 def download_file(self, url, save_path=None, save_name=None, retry_limit=3):
retry_limit = 3 module_name = url.split("/")[-2]
save_path = self.module_home if save_path is None else save_path
if not os.path.exists(save_path):
utils.mkdir(save_path)
save_name = url.split('/')[-1] if save_name is None else save_name
file_name = os.path.join(save_path, save_name)
retry_times = 0
while not (os.path.exists(file_name)): while not (os.path.exists(file_name)):
if os.path.exists(file_name): if os.path.exists(file_name):
print("file md5", md5file(file_name)) logger.info("file md5", md5file(file_name))
if retry < retry_limit: if retry_times < retry_limit:
retry += 1 retry_times += 1
else: else:
raise RuntimeError( raise RuntimeError(
"Cannot download {0} within retry limit {1}".format( "Cannot download {0} within retry limit {1}".format(
url, retry_limit)) url, retry_limit))
print("Cache file %s not found, downloading %s" % (file_name, url)) logger.info(
"Cache file %s not found, downloading %s" % (file_name, url))
r = requests.get(url, stream=True) r = requests.get(url, stream=True)
total_length = r.headers.get('content-length') total_length = r.headers.get('content-length')
...@@ -95,22 +97,35 @@ def download_and_uncompress(url, save_name=None): ...@@ -95,22 +97,35 @@ def download_and_uncompress(url, save_name=None):
"\r[%s%s]" % ('=' * done, ' ' * (50 - done))) "\r[%s%s]" % ('=' * done, ' ' * (50 - done)))
sys.stdout.flush() sys.stdout.flush()
print("file download completed!", file_name) logger.info("file %s download completed!" % (file_name))
#TODO(ZeyuChen) add md5 check error and file incompleted error, then raise return file_name
# them and catch them
with tarfile.open(file_name, "r:gz") as tar: def uncompress(self, file, dirname=None, delete_file=False):
dirname = os.path.dirname(file) if dirname is None else dirname
with tarfile.open(file, "r:gz") as tar:
file_names = tar.getnames() file_names = tar.getnames()
print(file_names) logger.info(file_names)
module_dir = os.path.join(dirname, file_names[0]) module_dir = os.path.join(dirname, file_names[0])
for file_name in file_names: for file_name in file_names:
tar.extract(file_name, dirname) tar.extract(file_name, dirname)
return module_dir if delete_file:
os.remove(file)
if __name__ == "__main__": return module_dir
# TODO(ZeyuChen) add unit test
link = "http://paddlehub.bj.bcebos.com/word2vec/word2vec-dim16-simple-example-1.tar.gz"
module_path = download_and_uncompress(link) def download_file_and_uncompress(self,
print("module path", module_path) url,
save_path=None,
save_name=None,
retry_limit=3,
delete_file=True):
file = self.download_file(
url=url,
save_path=save_path,
save_name=save_name,
retry_limit=retry_limit)
return self.uncompress(file, delete_file=delete_file)
default_downloader = Downloader()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册