Unverified · Commit 7860c6f0 · authored by Meiyim · committed by GitHub

use pathlib, fix #546 (#560)

* use pathlib, fix #546

* use encoding

* Update feature_column.py

fix #559
Parent 32a463eb
@@ -19,27 +19,44 @@ from __future__ import unicode_literals
 import os
 import logging
 from tqdm import tqdm
-from pathlib import Path
+import six
+if six.PY2:
+    from pathlib2 import Path
+else:
+    from pathlib import Path

 log = logging.getLogger(__name__)

-def _fetch_from_remote(url, force_download=False):
+def _fetch_from_remote(url, force_download=False, cached_dir='~/.paddle-ernie-cache'):
     import hashlib, tempfile, requests, tarfile
     sig = hashlib.md5(url.encode('utf8')).hexdigest()
-    cached_dir = os.path.join(tempfile.gettempdir(), sig)
-    if force_download or not os.path.exists(cached_dir):
-        with tempfile.NamedTemporaryFile() as f:
+    cached_dir = Path(cached_dir).expanduser()
+    try:
+        cached_dir.mkdir()
+    except OSError:
+        pass
+    cached_dir_model = cached_dir / sig
+    if force_download or not cached_dir_model.exists():
+        cached_dir_model.mkdir()
+        tmpfile = cached_dir_model / 'tmp'
+        with tmpfile.open('wb') as f:
             #url = 'https://ernie.bj.bcebos.com/ERNIE_stable.tgz'
             r = requests.get(url, stream=True)
             total_len = int(r.headers.get('content-length'))
-            for chunk in tqdm(r.iter_content(chunk_size=1024), total=total_len // 1024, desc='downloading %s' % url, unit='KB'):
+            for chunk in tqdm(r.iter_content(chunk_size=1024),
+                              total=total_len // 1024,
+                              desc='downloading %s' % url,
+                              unit='KB'):
                 if chunk:
                     f.write(chunk)
             f.flush()
-            log.debug('extracting... to %s' % f.name)
-            with tarfile.open(f.name) as tf:
-                tf.extractall(path=cached_dir)
+            log.debug('extracting... to %s' % tmpfile)
+            with tarfile.open(tmpfile.as_posix()) as tf:
+                tf.extractall(path=cached_dir_model.as_posix())
+            os.remove(tmpfile.as_posix())

     log.debug('%s cached in %s' % (url, cached_dir))
-    return cached_dir
+    return cached_dir_model

 def add_docstring(doc):
...
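The hunk above moves the download cache from a per-boot temp directory to a persistent ~/.paddle-ernie-cache, keyed by the md5 of the URL, so downloaded weights survive reboots and repeated runs. A minimal sketch of how the cache path is derived (the URL is the one from the comment in the diff):

    import hashlib
    from pathlib import Path

    url = 'https://ernie.bj.bcebos.com/ERNIE_stable.tgz'
    sig = hashlib.md5(url.encode('utf8')).hexdigest()
    cache_root = Path('~/.paddle-ernie-cache').expanduser()
    # _fetch_from_remote extracts the archive into this directory
    # and reuses it on later calls unless force_download is set
    print(cache_root / sig)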
@@ -24,6 +24,11 @@ import json
 import logging
 import logging
 from functools import partial
+import six
+if six.PY2:
+    from pathlib2 import Path
+else:
+    from pathlib import Path

 import paddle.fluid.dygraph as D
 import paddle.fluid as F
@@ -191,7 +196,7 @@ class PretrainedModel(object):
     }

     @classmethod
     def from_pretrained(cls, pretrain_dir_or_url, force_download=False, **kwargs):
-        if pretrain_dir_or_url in cls.resource_map:
+        if not Path(pretrain_dir_or_url).exists() and pretrain_dir_or_url in cls.resource_map:
             url = cls.resource_map[pretrain_dir_or_url]
             log.info('get pretrain dir from %s' % url)
             pretrain_dir = _fetch_from_remote(url, force_download)
@@ -199,16 +204,16 @@ class PretrainedModel(object):
             log.info('pretrain dir %s not in %s, read from local' % (pretrain_dir_or_url, repr(cls.resource_map)))
-            pretrain_dir = pretrain_dir_or_url
-        if not os.path.exists(pretrain_dir):
+            pretrain_dir = Path(pretrain_dir_or_url)
+        if not pretrain_dir.exists():
             raise ValueError('pretrain dir not found: %s' % pretrain_dir)
-        param_path = os.path.join(pretrain_dir, 'params')
-        state_dict_path = os.path.join(pretrain_dir, 'saved_weights')
-        config_path = os.path.join(pretrain_dir, 'ernie_config.json')
-        if not os.path.exists(config_path):
+        param_path = pretrain_dir / 'params'
+        state_dict_path = pretrain_dir / 'saved_weights'
+        config_path = pretrain_dir / 'ernie_config.json'
+        if not config_path.exists():
             raise ValueError('config path not found: %s' % config_path)
         name_prefix = kwargs.pop('name', None)
-        cfg_dict = dict(json.loads(open(config_path).read()), **kwargs)
+        cfg_dict = dict(json.loads(config_path.open().read()), **kwargs)
         model = cls(cfg_dict, name=name_prefix)
         log.info('loading pretrained model from %s' % pretrain_dir)
@@ -217,8 +222,8 @@ class PretrainedModel(object):
         # raise NotImplementedError()
         # log.debug('load pretrained weight from program state')
         # F.io.load_program_state(param_path) # buggy in dygraph.guard, push paddle to fix
-        if os.path.exists(state_dict_path + '.pdparams'):
-            m, _ = D.load_dygraph(state_dict_path)
+        if state_dict_path.with_suffix('.pdparams').exists():
+            m, _ = D.load_dygraph(state_dict_path.as_posix())
             for k, v in model.state_dict().items():
                 if k not in m:
                     log.warn('param:%s not set in pretrained model, skip' % k)
...
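With the guard above, from_pretrained now checks whether its argument exists as a local path before consulting resource_map, so a local directory that shares a name with a hosted model wins. A hedged usage sketch ('ernie-1.0' is assumed to be a resource_map key, per this repo's README):

    import paddle.fluid.dygraph as D
    from ernie.modeling_ernie import ErnieModel

    with D.guard():
        # fetched once into ~/.paddle-ernie-cache, then reused
        model = ErnieModel.from_pretrained('ernie-1.0')
        # an existing local dir now bypasses resource_map entirely:
        # model = ErnieModel.from_pretrained('./my_local_dir')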
@@ -91,12 +91,12 @@ class ErnieTokenizer(object):
         else:
             log.info('pretrain dir %s not in %s, read from local' % (pretrain_dir_or_url, repr(cls.resource_map)))
-            pretrain_dir = pretrain_dir_or_url
-        if not os.path.exists(pretrain_dir):
+            pretrain_dir = Path(pretrain_dir_or_url)
+        if not pretrain_dir.exists():
             raise ValueError('pretrain dir not found: %s' % pretrain_dir)
-        vocab_path = os.path.join(pretrain_dir, 'vocab.txt')
-        if not os.path.exists(vocab_path):
+        vocab_path = pretrain_dir / 'vocab.txt'
+        if not vocab_path.exists():
             raise ValueError('no vocab file in pretrain dir: %s' % pretrain_dir)
-        vocab_dict = {j.strip().split('\t')[0]: i for i, j in enumerate(open(vocab_path).readlines())}
+        vocab_dict = {j.strip().split('\t')[0]: i for i, j in enumerate(vocab_path.open(encoding='utf8').readlines())}
         t = cls(vocab_dict, **kwargs)
         return t
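The explicit encoding='utf8' (the "use encoding" commit in this PR) guards against platforms whose default codec is not UTF-8, such as GBK on Chinese-locale Windows, where opening the multilingual vocab would otherwise raise a decode error. A standalone sketch of the same parsing, assuming a local vocab.txt with one tab-separated token per line:

    from pathlib import Path

    vocab_path = Path('vocab.txt')  # assumed local file
    with vocab_path.open(encoding='utf8') as f:
        # first tab-separated field is the token; line index becomes its id
        vocab_dict = {line.strip().split('\t')[0]: idx for idx, line in enumerate(f)}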
@@ -207,14 +207,14 @@ class ErnieTinyTokenizer(ErnieTokenizer):
         else:
             log.info('pretrain dir %s not in %s, read from local' % (pretrain_dir_or_url, repr(cls.resource_map)))
-            pretrain_dir = pretrain_dir_or_url
-        if not os.path.exists(pretrain_dir):
+            pretrain_dir = Path(pretrain_dir_or_url)
+        if not pretrain_dir.exists():
             raise ValueError('pretrain dir not found: %s' % pretrain_dir)
-        vocab_path = os.path.join(pretrain_dir, 'vocab.txt')
-        sp_model_path = os.path.join(pretrain_dir, 'subword/spm_cased_simp_sampled.model')
-        if not os.path.exists(vocab_path):
+        vocab_path = pretrain_dir / 'vocab.txt'
+        sp_model_path = pretrain_dir / 'subword/spm_cased_simp_sampled.model'
+        if not vocab_path.exists():
             raise ValueError('no vocab file in pretrain dir: %s' % pretrain_dir)
-        vocab_dict = {j.strip().split('\t')[0]: i for i, j in enumerate(open(vocab_path).readlines())}
+        vocab_dict = {j.strip().split('\t')[0]: i for i, j in enumerate(vocab_path.open(encoding='utf8').readlines())}
         t = cls(vocab_dict, sp_model_path, **kwargs)
         return t
...
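ErnieTinyTokenizer additionally loads the sentencepiece model next to the vocab; usage mirrors the base tokenizer. A hedged sketch ('ernie-tiny' is assumed to be a resource_map key, and encode() is assumed to return token-id and segment-id arrays as elsewhere in this repo):

    from ernie.tokenizing_ernie import ErnieTinyTokenizer

    tokenizer = ErnieTinyTokenizer.from_pretrained('ernie-tiny')
    ids, segment_ids = tokenizer.encode('百度预训练模型')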
@@ -125,7 +125,7 @@ class LabelColumn(Column):
             ids = int(raw)
         else:
             ids = self.vocab[raw]
-        return ids
+        return np.array(ids, dtype=np.int64)

 class TextColumn(Column):
...
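Returning np.array(ids, dtype=np.int64) instead of a bare int (the #559 fix) pins the label dtype on every platform: converting a Python int without an explicit dtype yields int32 on Windows, which breaks consumers expecting int64. A small sketch of the difference:

    import numpy as np

    labels = [np.array(i, dtype=np.int64) for i in (0, 2, 1)]
    batch = np.stack(labels)        # shape (3,), dtype int64 everywhere
    assert batch.dtype == np.int64  # without the explicit dtype, Windows would give int32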
@@ -5,3 +5,4 @@ sklearn==0.0
 sentencepiece==0.1.8
 jieba==0.39
 visualdl>=2.0.0b7
+pathlib2>=2.3.2
@@ -22,7 +22,7 @@ with open("README.md", "r", encoding='utf-8') as fh:
 setuptools.setup(
     name="paddle-ernie",  # Replace with your own username
-    version="0.0.4dev1",
+    version="0.0.5dev1",
     author="Baidu Ernie Team",
     author_email="ernieernie.team@gmail.com",
     description="A pretrained NLP model for every NLP task",
@@ -33,6 +33,7 @@ setuptools.setup(
     install_requires=[
         'requests',
         'tqdm',
+        'pathlib2',
     ],
     classifiers=[
         'Intended Audience :: Developers',
...
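Since pathlib2 is only needed on Python 2, an alternative (not what this commit does) is a PEP 508 environment marker, so Python 3 installs skip the backport entirely:

    install_requires=[
        'requests',
        'tqdm',
        'pathlib2; python_version < "3"',
    ],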