Unverified · Commit fdb59529 · authored by nbcc, committed by GitHub

Merge pull request #673 from Meiyim/multihead-download

Multihead download
@@ -17,6 +17,9 @@ from __future__ import absolute_import
 from __future__ import print_function
 from __future__ import unicode_literals
+import sys
+import logging
 import paddle
 if paddle.__version__ != '0.0.0' and paddle.__version__ < '2.0.0':
     raise RuntimeError('propeller 0.2 requires paddle 2.0+, got %s' %
@@ -28,3 +31,10 @@ from ernie.modeling_ernie import (
     ErnieModelForQuestionAnswering, ErnieModelForPretraining)
 from ernie.tokenizing_ernie import ErnieTokenizer, ErnieTinyTokenizer
+log = logging.getLogger(__name__)
+formatter = logging.Formatter(fmt='[%(levelname)s] %(asctime)s [%(filename)12s:%(lineno)5d]: %(message)s')
+stream_hdl = logging.StreamHandler(stream=sys.stderr)
+stream_hdl.setFormatter(formatter)
+log.addHandler(stream_hdl)
+log.propagate = False
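
The handler attached above only covers the package logger (propagation is disabled), so debug-level messages such as the download/cache lines stay silent unless a caller raises the level. A minimal usage sketch, assuming the package logger is named 'ernie' (it is created with logging.getLogger(__name__) in the package __init__) and that the caller wants those messages visible:

import logging
import ernie  # importing the package attaches the stderr handler shown above

# Lower the package logger's level to surface log.debug(...) calls,
# e.g. the '%s cached in %s' message emitted while fetching weights.
logging.getLogger('ernie').setLevel(logging.DEBUG)
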
@@ -21,7 +21,6 @@ import logging
 from tqdm import tqdm
 from pathlib import Path
 import six
-import paddle as P
 import time
 if six.PY2:
     from pathlib2 import Path
@@ -35,8 +34,6 @@ def _fetch_from_remote(url,
                        force_download=False,
                        cached_dir='~/.paddle-ernie-cache'):
     import hashlib, tempfile, requests, tarfile
-    env = P.distributed.ParallelEnv()
     sig = hashlib.md5(url.encode('utf8')).hexdigest()
     cached_dir = Path(cached_dir).expanduser()
     try:
@@ -44,13 +41,15 @@ def _fetch_from_remote(url,
     except OSError:
         pass
     cached_dir_model = cached_dir / sig
-    done_file = cached_dir_model / 'fetch_done'
-    if force_download or not done_file.exists():
-        if env.dev_id == 0:
-            cached_dir_model.mkdir()
+    from filelock import FileLock
+    with FileLock(str(cached_dir_model) + '.lock'):
+        donefile = cached_dir_model / 'done'
+        if (not force_download) and donefile.exists():
+            log.debug('%s cached in %s' % (url, cached_dir_model))
+            return cached_dir_model
+        cached_dir_model.mkdir(exist_ok=True)
         tmpfile = cached_dir_model / 'tmp'
         with tmpfile.open('wb') as f:
-            #url = 'https://ernie.bj.bcebos.com/ERNIE_stable.tgz'
             r = requests.get(url, stream=True)
             total_len = int(r.headers.get('content-length'))
             for chunk in tqdm(
@@ -63,15 +62,10 @@ def _fetch_from_remote(url,
                     f.flush()
         log.debug('extacting... to %s' % tmpfile)
         with tarfile.open(tmpfile.as_posix()) as tf:
-            tf.extractall(path=cached_dir_model.as_posix())
+            tf.extractall(path=str(cached_dir_model))
+        donefile.touch()
         os.remove(tmpfile.as_posix())
-        f = done_file.open('wb')
-        f.close()
-    else:
-        while not done_file.exists():
-            time.sleep(1)
-    log.debug('%s cached in %s' % (url, cached_dir))
     return cached_dir_model
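
The change above replaces the old rank-0 download guard (dev_id == 0 plus a polling loop on the fetch_done marker) with a file lock: whichever process acquires the lock first performs the download and touches the done marker, and every later process finds the marker under the same lock and returns immediately. A minimal sketch of that pattern, assuming several data-parallel workers call it concurrently; fetch_once and download are hypothetical names, not part of this diff:

from pathlib import Path
from filelock import FileLock  # inter-process lock backed by a lock file on disk

def fetch_once(cached_dir_model: Path, download) -> Path:
    # All workers contend for the same lock file next to the cache directory.
    with FileLock(str(cached_dir_model) + '.lock'):
        donefile = cached_dir_model / 'done'
        if donefile.exists():
            # An earlier worker already finished; reuse the cached copy.
            return cached_dir_model
        cached_dir_model.mkdir(exist_ok=True)
        download(cached_dir_model)  # only the lock holder reaches this point
        donefile.touch()            # mark success for the workers still waiting
    return cached_dir_model
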
@@ -272,7 +272,7 @@ class PretrainedModel(object):
         pretrain_dir = Path(pretrain_dir_or_url)
         if not pretrain_dir.exists():
-            raise ValueError('pretrain dir not found: %s' % pretrain_dir)
+            raise ValueError('pretrain dir not found: %s, optional: %s' % (pretrain_dir, cls.resource_map.keys()))
         state_dict_path = pretrain_dir / 'saved_weights.pdparams'
         config_path = pretrain_dir / 'ernie_config.json'
@@ -107,7 +107,7 @@ class ErnieTokenizer(object):
                              (pretrain_dir_or_url, repr(cls.resource_map)))
         pretrain_dir = Path(pretrain_dir_or_url)
         if not pretrain_dir.exists():
-            raise ValueError('pretrain dir not found: %s' % pretrain_dir)
+            raise ValueError('pretrain dir not found: %s, optional: %s' % (pretrain_dir, cls.resource_map.keys()))
         vocab_path = pretrain_dir / 'vocab.txt'
         if not vocab_path.exists():
             raise ValueError('no vocab file in pretrain dir: %s' %
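
Both raise sites now list the names registered in cls.resource_map, so a typo in a model name or a missing local directory points the user at the valid options. A hedged example, assuming the from_pretrained flow shown in the hunk and a deliberately nonexistent path:

from ernie.tokenizing_ernie import ErnieTokenizer

try:
    # Neither a known resource_map key nor an existing local directory.
    ErnieTokenizer.from_pretrained('./no-such-pretrain-dir')
except ValueError as e:
    print(e)  # e.g. "pretrain dir not found: no-such-pretrain-dir, optional: dict_keys([...])"
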
 numpy
 pyzmq==18.0.2
-six==1.11.0
+six>=1.11.0
 sklearn==0.0
 sentencepiece==0.1.8
 jieba==0.39
 visualdl>=2.0.0b7
 pathlib2>=2.3.2
+filelock>=3.0.0
 tqdm>=4.32.2