提交 de95bb4f 编写于 作者: W wuzewu

update downloader

上级 cb09f5f4
...@@ -17,7 +17,7 @@ from __future__ import division ...@@ -17,7 +17,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from paddle_hub.tools import utils from paddle_hub.tools import utils
from paddle_hub.tools.logger import logger from paddle_hub.tools.logger import logger
from paddle_hub.tools import downloader from paddle_hub.tools.downloader import default_downloader
from paddle_hub.tools import paddle_helper from paddle_hub.tools import paddle_helper
from paddle_hub.module import module_desc_pb2 from paddle_hub.module import module_desc_pb2
from paddle_hub.module.signature import Signature, create_signature from paddle_hub.module.signature import Signature, create_signature
...@@ -116,7 +116,8 @@ class Module: ...@@ -116,7 +116,8 @@ class Module:
def _init_with_url(self, url): def _init_with_url(self, url):
utils.check_url_valid(url) utils.check_url_valid(url)
module_dir = downloader.download_and_uncompress(module_url) module_dir = default_downloader.download_file_and_uncompress(
module_url, save_path=".")
self._init_with_module_file(module_dir) self._init_with_module_file(module_dir)
def _dump_processor(self): def _dump_processor(self):
......
...@@ -23,11 +23,13 @@ import hashlib ...@@ -23,11 +23,13 @@ import hashlib
import requests import requests
import tempfile import tempfile
import tarfile import tarfile
from paddle_hub.tools import utils
from paddle_hub.tools.logger import logger
__all__ = ['MODULE_HOME', 'download', 'md5file', 'download_and_uncompress'] __all__ = ['MODULE_HOME', 'downloader', 'md5file', 'Downloader']
# TODO(ZeyuChen) add environment variable to set MODULE_HOME # TODO(ZeyuChen) add environment variable to set MODULE_HOME
MODULE_HOME = os.path.expanduser('~/.cache/paddle/module') MODULE_HOME = os.path.expanduser('~/.hub/module')
# When running unit tests, there could be multiple processes that # When running unit tests, there could be multiple processes that
...@@ -53,64 +55,77 @@ def md5file(fname): ...@@ -53,64 +55,77 @@ def md5file(fname):
return hash_md5.hexdigest() return hash_md5.hexdigest()
def download_and_uncompress(url, save_name=None): class Downloader:
module_name = url.split("/")[-2] def __init__(self, module_home=None):
dirname = os.path.join(MODULE_HOME, module_name) self.module_home = module_home if module_home else MODULE_HOME
if not os.path.exists(dirname):
os.makedirs(dirname) def download_file(self, url, save_path=None, save_name=None, retry_limit=3):
module_name = url.split("/")[-2]
#TODO(ZeyuChen) add download md5 file to verify file completeness save_path = self.module_home if save_path is None else save_path
file_name = os.path.join( if not os.path.exists(save_path):
dirname, utils.mkdir(save_path)
url.split('/')[-1] if save_name is None else save_name) save_name = url.split('/')[-1] if save_name is None else save_name
file_name = os.path.join(save_path, save_name)
retry = 0 retry_times = 0
retry_limit = 3 while not (os.path.exists(file_name)):
while not (os.path.exists(file_name)): if os.path.exists(file_name):
if os.path.exists(file_name): logger.info("file md5", md5file(file_name))
print("file md5", md5file(file_name)) if retry_times < retry_limit:
if retry < retry_limit: retry_times += 1
retry += 1 else:
else: raise RuntimeError(
raise RuntimeError( "Cannot download {0} within retry limit {1}".format(
"Cannot download {0} within retry limit {1}".format( url, retry_limit))
url, retry_limit)) logger.info(
print("Cache file %s not found, downloading %s" % (file_name, url)) "Cache file %s not found, downloading %s" % (file_name, url))
r = requests.get(url, stream=True) r = requests.get(url, stream=True)
total_length = r.headers.get('content-length') total_length = r.headers.get('content-length')
if total_length is None: if total_length is None:
with open(file_name, 'wb') as f: with open(file_name, 'wb') as f:
shutil.copyfileobj(r.raw, f) shutil.copyfileobj(r.raw, f)
else: else:
#TODO(ZeyuChen) upgrade to tqdm process #TODO(ZeyuChen) upgrade to tqdm process
with open(file_name, 'wb') as f: with open(file_name, 'wb') as f:
dl = 0 dl = 0
total_length = int(total_length) total_length = int(total_length)
for data in r.iter_content(chunk_size=4096): for data in r.iter_content(chunk_size=4096):
dl += len(data) dl += len(data)
f.write(data) f.write(data)
done = int(50 * dl / total_length) done = int(50 * dl / total_length)
sys.stdout.write( sys.stdout.write(
"\r[%s%s]" % ('=' * done, ' ' * (50 - done))) "\r[%s%s]" % ('=' * done, ' ' * (50 - done)))
sys.stdout.flush() sys.stdout.flush()
print("file download completed!", file_name) logger.info("file %s download completed!" % (file_name))
#TODO(ZeyuChen) add md5 check error and file incomplete error, then raise return file_name
# them and catch them
with tarfile.open(file_name, "r:gz") as tar: def uncompress(self, file, dirname=None, delete_file=False):
file_names = tar.getnames() dirname = os.path.dirname(file) if dirname is None else dirname
print(file_names) with tarfile.open(file, "r:gz") as tar:
module_dir = os.path.join(dirname, file_names[0]) file_names = tar.getnames()
for file_name in file_names: logger.info(file_names)
tar.extract(file_name, dirname) module_dir = os.path.join(dirname, file_names[0])
for file_name in file_names:
return module_dir tar.extract(file_name, dirname)
if delete_file:
if __name__ == "__main__": os.remove(file)
# TODO(ZeyuChen) add unit test
link = "http://paddlehub.bj.bcebos.com/word2vec/word2vec-dim16-simple-example-1.tar.gz" return module_dir
module_path = download_and_uncompress(link) def download_file_and_uncompress(self,
print("module path", module_path) url,
save_path=None,
save_name=None,
retry_limit=3,
delete_file=True):
file = self.download_file(
url=url,
save_path=save_path,
save_name=save_name,
retry_limit=retry_limit)
return self.uncompress(file, delete_file=delete_file)
default_downloader = Downloader()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册