def _init_with_url(self, url):
    """Initialize the module from a downloadable archive URL.

    Validates *url*, downloads and uncompresses the module archive into
    the current directory, then delegates to ``_init_with_module_file``.

    Args:
        url: HTTP(S) location of the module ``.tar.gz`` archive.
    """
    utils.check_url_valid(url)
    # BUG FIX: the original body referenced ``module_url``, which is not
    # defined in this scope (the parameter is named ``url``) and would
    # raise NameError on every call.
    module_dir = default_downloader.download_file_and_uncompress(
        url, save_path=".")
    self._init_with_module_file(module_dir)
class Downloader:
    """Download files over HTTP and optionally unpack ``.tar.gz`` archives.

    Files land in *module_home* (default: the package-level ``MODULE_HOME``
    cache directory) unless an explicit ``save_path`` is given per call.
    """

    def __init__(self, module_home=None):
        # Fall back to the shared default cache directory when no explicit
        # home is supplied (empty string also falls back, matching the
        # original truthiness test).
        self.module_home = module_home if module_home else MODULE_HOME

    def download_file(self, url, save_path=None, save_name=None,
                      retry_limit=3):
        """Download *url* to ``save_path/save_name`` and return the local path.

        The download is skipped when the target file already exists (it is
        treated as a cache hit — note there is no checksum validation, so a
        stale or truncated cached file is reused as-is).

        Args:
            url: source URL to fetch.
            save_path: destination directory; defaults to ``self.module_home``.
            save_name: destination file name; defaults to the URL basename.
            retry_limit: maximum number of download attempts.

        Returns:
            Absolute/relative path of the downloaded file.

        Raises:
            RuntimeError: if the file still does not exist after
                ``retry_limit`` attempts.
            requests.HTTPError: if the server answers with an error status.
        """
        save_path = self.module_home if save_path is None else save_path
        if not os.path.exists(save_path):
            utils.mkdir(save_path)
        save_name = url.split('/')[-1] if save_name is None else save_name
        file_name = os.path.join(save_path, save_name)
        retry_times = 0
        # Keep trying until the target file exists, up to retry_limit tries.
        # (The original also had a dead `if os.path.exists(file_name)` md5
        # branch inside this loop — unreachable under the loop condition —
        # which has been removed.)
        while not os.path.exists(file_name):
            if retry_times < retry_limit:
                retry_times += 1
            else:
                raise RuntimeError(
                    "Cannot download {0} within retry limit {1}".format(
                        url, retry_limit))
            logger.info(
                "Cache file %s not found, downloading %s" % (file_name, url))
            r = requests.get(url, stream=True)
            # Fail fast on HTTP errors; otherwise a 404/5xx error page would
            # be written to disk and treated as a valid cached download.
            r.raise_for_status()
            total_length = r.headers.get('content-length')

            if total_length is None:
                # Unknown size: stream straight to disk with no progress bar.
                with open(file_name, 'wb') as f:
                    shutil.copyfileobj(r.raw, f)
            else:
                #TODO(ZeyuChen) upgrade to tqdm process
                with open(file_name, 'wb') as f:
                    downloaded = 0
                    total_length = int(total_length)
                    for chunk in r.iter_content(chunk_size=4096):
                        downloaded += len(chunk)
                        f.write(chunk)
                        done = int(50 * downloaded / total_length)
                        sys.stdout.write(
                            "\r[%s%s]" % ('=' * done, ' ' * (50 - done)))
                        sys.stdout.flush()
                    # Terminate the carriage-return progress bar line.
                    sys.stdout.write("\n")

        logger.info("file %s download completed!" % file_name)
        return file_name

    def uncompress(self, file, dirname=None, delete_file=False):
        """Extract a ``.tar.gz`` archive and return the extracted module dir.

        Args:
            file: path of the archive to extract.
            dirname: extraction directory; defaults to the archive's own
                directory.
            delete_file: when True, remove the archive after extraction.

        Returns:
            Path of the first entry in the archive joined onto *dirname* —
            assumed to be the top-level module directory of the archive.
        """
        dirname = os.path.dirname(file) if dirname is None else dirname
        with tarfile.open(file, "r:gz") as tar:
            member_names = tar.getnames()
            logger.info(member_names)
            module_dir = os.path.join(dirname, member_names[0])
            # NOTE(review): tar.extract() does not guard against malicious
            # "../" member paths (path traversal) — only use with trusted
            # archives, or sanitize member names before extracting.
            for member_name in member_names:
                tar.extract(member_name, dirname)

        if delete_file:
            os.remove(file)

        return module_dir

    def download_file_and_uncompress(self,
                                     url,
                                     save_path=None,
                                     save_name=None,
                                     retry_limit=3,
                                     delete_file=True):
        """Download *url* and extract it; return the extracted directory.

        Convenience wrapper around :meth:`download_file` followed by
        :meth:`uncompress`. By default the downloaded archive is deleted
        after extraction.
        """
        file = self.download_file(
            url=url,
            save_path=save_path,
            save_name=save_name,
            retry_limit=retry_limit)
        return self.uncompress(file, delete_file=delete_file)


# Shared module-level instance used by callers (e.g. Module._init_with_url).
default_downloader = Downloader()