提交 de95bb4f 编写于 作者: W wuzewu

update downloader

上级 cb09f5f4
......@@ -17,7 +17,7 @@ from __future__ import division
from __future__ import print_function
from paddle_hub.tools import utils
from paddle_hub.tools.logger import logger
from paddle_hub.tools import downloader
from paddle_hub.tools.downloader import default_downloader
from paddle_hub.tools import paddle_helper
from paddle_hub.module import module_desc_pb2
from paddle_hub.module.signature import Signature, create_signature
......@@ -116,7 +116,8 @@ class Module:
def _init_with_url(self, url):
    """Download the module archive at ``url``, unpack it and load it.

    The archive is fetched into the current working directory
    (``save_path="."``) through ``default_downloader`` and the resulting
    directory is handed to ``_init_with_module_file``.

    Args:
        url: Location of the module archive; checked with
            ``utils.check_url_valid`` before anything is downloaded.
    """
    utils.check_url_valid(url)
    # Pass the ``url`` parameter itself: no ``module_url`` name exists in
    # this scope, so referring to it raised NameError on every call.
    module_dir = default_downloader.download_file_and_uncompress(
        url, save_path=".")
    self._init_with_module_file(module_dir)
def _dump_processor(self):
......
......@@ -23,11 +23,13 @@ import hashlib
import requests
import tempfile
import tarfile
from paddle_hub.tools import utils
from paddle_hub.tools.logger import logger
__all__ = ['MODULE_HOME', 'download', 'md5file', 'download_and_uncompress']
# Names exported by a star import. Each entry has to be defined in this
# module, otherwise ``from ... import *`` fails with AttributeError.
__all__ = ['MODULE_HOME', 'default_downloader', 'md5file', 'Downloader']
# TODO(ZeyuChen) add environment variable to set MODULE_HOME
MODULE_HOME = os.path.expanduser('~/.cache/paddle/module')
MODULE_HOME = os.path.expanduser('~/.hub/module')
# When running unit tests, there could be multiple processes that
......@@ -53,64 +55,77 @@ def md5file(fname):
return hash_md5.hexdigest()
def download_and_uncompress(url, save_name=None):
    """Download a gzipped tar archive into the module cache and unpack it.

    The archive is stored in ``MODULE_HOME/<module_name>/`` where
    ``<module_name>`` is the second-to-last ``/``-separated component of
    ``url``. An archive that is already present is reused without any
    network access.

    Args:
        url: Address of the ``.tar.gz`` archive.
        save_name: File name for the stored archive. Defaults to the last
            path component of ``url``.

    Returns:
        ``MODULE_HOME/<module_name>`` joined with the first member name of
        the archive.

    Raises:
        RuntimeError: The archive could not be downloaded within three
            attempts, or a member would be extracted outside the target
            directory.
        tarfile.ReadError: The archive is unreadable or has no members.
    """
    module_name = url.split("/")[-2]
    dirname = os.path.join(MODULE_HOME, module_name)
    if not os.path.exists(dirname):
        os.makedirs(dirname)

    #TODO(ZeyuChen) add download md5 file to verify file completeness
    file_name = os.path.join(
        dirname,
        url.split('/')[-1] if save_name is None else save_name)
    # Data is written under a temporary name and renamed once complete, so
    # an interrupted transfer is never mistaken for a cached archive.
    part_name = file_name + ".part"

    retry = 0
    retry_limit = 3
    while not os.path.exists(file_name):
        if retry < retry_limit:
            retry += 1
        else:
            raise RuntimeError(
                "Cannot download {0} within retry limit {1}".format(
                    url, retry_limit))
        print("Cache file %s not found, downloading %s" % (file_name, url))
        try:
            r = requests.get(url, stream=True)
            try:
                # Reject 4xx/5xx answers instead of storing the error page
                # as if it were the archive.
                r.raise_for_status()
                total_length = int(r.headers.get('content-length') or 0)
                with open(part_name, 'wb') as f:
                    if total_length <= 0:
                        # Size unknown (or reported as 0): no progress bar,
                        # and no division by zero.
                        shutil.copyfileobj(r.raw, f)
                    else:
                        #TODO(ZeyuChen) upgrade to tqdm process
                        dl = 0
                        for data in r.iter_content(chunk_size=4096):
                            dl += len(data)
                            f.write(data)
                            done = min(50, int(50 * dl / total_length))
                            sys.stdout.write(
                                "\r[%s%s]" % ('=' * done, ' ' * (50 - done)))
                            sys.stdout.flush()
            finally:
                r.close()
        except (requests.RequestException, IOError, OSError) as e:
            # A failed attempt consumes one retry; drop the partial file.
            print("download attempt %d for %s failed: %s" % (retry, url, e))
            if os.path.exists(part_name):
                os.remove(part_name)
            continue
        os.rename(part_name, file_name)

    print("file download completed!", file_name)
    #TODO(ZeyuChen) add md5 check error and file incompleted error, then raise
    # them and catch them
    with tarfile.open(file_name, "r:gz") as tar:
        member_names = tar.getnames()
        print(member_names)
        if not member_names:
            raise tarfile.ReadError(
                "archive %s contains no members" % file_name)
        # The archive comes from the network: refuse member names that
        # resolve outside ``dirname`` (absolute paths, ``..`` components).
        # NOTE: link targets inside the archive are not inspected here.
        root = os.path.join(os.path.realpath(dirname), "")
        for member_name in member_names:
            target = os.path.realpath(os.path.join(dirname, member_name))
            if not os.path.join(target, "").startswith(root):
                raise RuntimeError(
                    "Unsafe path %s in archive %s" % (member_name, file_name))
        module_dir = os.path.join(dirname, member_names[0])
        for member_name in member_names:
            tar.extract(member_name, dirname)

    return module_dir
if __name__ == "__main__":
    # TODO(ZeyuChen) add unit test
    # Manual smoke run: fetch one sample archive and show where it landed.
    sample_url = "http://paddlehub.bj.bcebos.com/word2vec/word2vec-dim16-simple-example-1.tar.gz"
    print("module path", download_and_uncompress(sample_url))
class Downloader:
    """Download archives over HTTP and unpack gzipped tar files."""

    def __init__(self, module_home=None):
        """Create a downloader.

        Args:
            module_home: Directory used by ``download_file`` when it is
                given no ``save_path``. ``MODULE_HOME`` is used when this
                is ``None`` or otherwise falsy.
        """
        self.module_home = module_home if module_home else MODULE_HOME

    def download_file(self, url, save_path=None, save_name=None, retry_limit=3):
        """Download ``url`` unless the target file already exists.

        Args:
            url: Address of the file to fetch.
            save_path: Directory to store the file in; defaults to
                ``self.module_home``. Created with ``utils.mkdir`` if it
                does not exist.
            save_name: Name of the stored file; defaults to the last
                ``/``-separated component of ``url``.
            retry_limit: Maximum number of download attempts.

        Returns:
            Path of the stored file (``save_path`` joined with
            ``save_name``).

        Raises:
            RuntimeError: The file is not present and could not be
                downloaded within ``retry_limit`` attempts.
        """
        save_path = self.module_home if save_path is None else save_path
        if not os.path.exists(save_path):
            utils.mkdir(save_path)
        save_name = url.split('/')[-1] if save_name is None else save_name
        file_name = os.path.join(save_path, save_name)
        # Data is written under a temporary name and renamed once complete,
        # so an interrupted transfer is never mistaken for a finished file.
        part_name = file_name + ".part"

        retry_times = 0
        while not os.path.exists(file_name):
            if retry_times < retry_limit:
                retry_times += 1
            else:
                raise RuntimeError(
                    "Cannot download {0} within retry limit {1}".format(
                        url, retry_limit))
            logger.info(
                "Cache file %s not found, downloading %s" % (file_name, url))
            try:
                self._fetch(url, part_name)
            except (requests.RequestException, IOError, OSError) as e:
                # A failed attempt consumes one retry; drop the partial file.
                logger.info("download attempt %d for %s failed: %s" %
                            (retry_times, url, e))
                if os.path.exists(part_name):
                    os.remove(part_name)
                continue
            os.rename(part_name, file_name)

        logger.info("file %s download completed!" % (file_name))
        return file_name

    def _fetch(self, url, target):
        """Stream the body of ``url`` into the file ``target``."""
        r = requests.get(url, stream=True)
        try:
            # Reject 4xx/5xx answers instead of storing the error page.
            r.raise_for_status()
            total_length = int(r.headers.get('content-length') or 0)
            with open(target, 'wb') as f:
                if total_length <= 0:
                    # Size unknown (or reported as 0): no progress bar, and
                    # no division by zero.
                    shutil.copyfileobj(r.raw, f)
                else:
                    #TODO(ZeyuChen) upgrade to tqdm process
                    dl = 0
                    for data in r.iter_content(chunk_size=4096):
                        dl += len(data)
                        f.write(data)
                        done = min(50, int(50 * dl / total_length))
                        sys.stdout.write(
                            "\r[%s%s]" % ('=' * done, ' ' * (50 - done)))
                        sys.stdout.flush()
        finally:
            r.close()

    @staticmethod
    def _check_member_paths(dirname, member_names):
        """Raise RuntimeError if a member would land outside ``dirname``."""
        # NOTE: only member names are checked (absolute paths, ``..``
        # components); link targets inside the archive are not inspected.
        root = os.path.join(os.path.realpath(dirname), "")
        for member_name in member_names:
            target = os.path.realpath(os.path.join(dirname, member_name))
            if not os.path.join(target, "").startswith(root):
                raise RuntimeError(
                    "Unsafe path %s in archive" % member_name)

    def uncompress(self, file, dirname=None, delete_file=False):
        """Extract a gzipped tar archive.

        Args:
            file: Path of the ``.tar.gz`` archive.
            dirname: Directory to extract into; defaults to the directory
                that contains ``file``.
            delete_file: Remove the archive after a successful extraction.

        Returns:
            ``dirname`` joined with the first member name of the archive.

        Raises:
            tarfile.ReadError: The archive is unreadable or has no members.
            RuntimeError: A member would be extracted outside ``dirname``.
        """
        dirname = os.path.dirname(file) if dirname is None else dirname
        with tarfile.open(file, "r:gz") as tar:
            file_names = tar.getnames()
            logger.info(file_names)
            if not file_names:
                raise tarfile.ReadError(
                    "archive %s contains no members" % file)
            # Archives usually come from the network, so validate every
            # member before anything is written to disk.
            self._check_member_paths(dirname, file_names)
            module_dir = os.path.join(dirname, file_names[0])
            for file_name in file_names:
                tar.extract(file_name, dirname)

        if delete_file:
            os.remove(file)

        return module_dir

    def download_file_and_uncompress(self,
                                     url,
                                     save_path=None,
                                     save_name=None,
                                     retry_limit=3,
                                     delete_file=True):
        """Download an archive and extract it next to where it was saved.

        Arguments are forwarded to ``download_file``; ``delete_file`` is
        forwarded to ``uncompress`` and, unlike there, defaults to True.

        Returns:
            The directory path returned by ``uncompress``.
        """
        file = self.download_file(
            url=url,
            save_path=save_path,
            save_name=save_name,
            retry_limit=retry_limit)
        return self.uncompress(file, delete_file=delete_file)
# Shared module-level instance; built without arguments, so its
# ``module_home`` falls back to MODULE_HOME.
default_downloader = Downloader()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册