未验证 提交 7860c6f0 编写于 作者: M Meiyim 提交者: GitHub

use pathlib, fix #546 (#560)

* use pathlib, fix #546

* use encoding

* Update feature_column.py

fix #559
上级 32a463eb
......@@ -19,27 +19,44 @@ from __future__ import unicode_literals
import os
import logging
from tqdm import tqdm
from pathlib import Path
import six
if six.PY2:
from pathlib2 import Path
else:
from pathlib import Path
log = logging.getLogger(__name__)
def _fetch_from_remote(url, force_download=False, cached_dir='~/.paddle-ernie-cache'):
    """Download a tarball from `url` and extract it into a local cache.

    The md5 hex digest of the URL names the cache entry, so each distinct
    URL gets its own subdirectory under `cached_dir`.

    Args:
        url: remote tarball URL.
        force_download: if True, re-download even when a cached copy exists.
        cached_dir: cache root directory; a leading '~' is expanded.

    Returns:
        pathlib Path of the directory the archive was extracted into.
    """
    import hashlib, tempfile, requests, tarfile
    sig = hashlib.md5(url.encode('utf8')).hexdigest()
    cached_dir = Path(cached_dir).expanduser()
    # parents/exist_ok replaces the old `try: mkdir() except OSError: pass`,
    # which also swallowed real failures such as permission-denied.
    cached_dir.mkdir(parents=True, exist_ok=True)
    cached_dir_model = cached_dir / sig
    if force_download or not cached_dir_model.exists():
        # exist_ok: with force_download=True the directory may already exist;
        # a plain mkdir() would raise FileExistsError here.
        cached_dir_model.mkdir(exist_ok=True)
        tmpfile = cached_dir_model / 'tmp'
        with tmpfile.open('wb') as f:
            #url = 'https://ernie.bj.bcebos.com/ERNIE_stable.tgz'
            r = requests.get(url, stream=True)
            # default 0: servers may omit content-length; int(None) would raise.
            total_len = int(r.headers.get('content-length', 0))
            for chunk in tqdm(r.iter_content(chunk_size=1024),
                    total=total_len // 1024,
                    desc='downloading %s' % url,
                    unit='KB'):
                if chunk:
                    f.write(chunk)
                    f.flush()
        log.debug('extracting... to %s' % tmpfile)
        # NOTE(review): extractall on a downloaded archive trusts its member
        # paths; fine for the project's own bucket, but worth confirming.
        with tarfile.open(tmpfile.as_posix()) as tf:
            tf.extractall(path=cached_dir_model.as_posix())
        os.remove(tmpfile.as_posix())
    log.debug('%s cached in %s' % (url, cached_dir))
    return cached_dir_model
def add_docstring(doc):
......
......@@ -24,6 +24,11 @@ import json
import logging
import logging
from functools import partial
import six
if six.PY2:
from pathlib2 import Path
else:
from pathlib import Path
import paddle.fluid.dygraph as D
import paddle.fluid as F
......@@ -191,7 +196,7 @@ class PretrainedModel(object):
}
@classmethod
def from_pretrained(cls, pretrain_dir_or_url, force_download=False, **kwargs):
if pretrain_dir_or_url in cls.resource_map:
if not Path(pretrain_dir_or_url).exists() and pretrain_dir_or_url in cls.resource_map:
url = cls.resource_map[pretrain_dir_or_url]
log.info('get pretrain dir from %s' % url)
pretrain_dir = _fetch_from_remote(url, force_download)
......@@ -199,16 +204,16 @@ class PretrainedModel(object):
log.info('pretrain dir %s not in %s, read from local' % (pretrain_dir_or_url, repr(cls.resource_map)))
pretrain_dir = pretrain_dir_or_url
if not os.path.exists(pretrain_dir):
if not pretrain_dir.exists():
raise ValueError('pretrain dir not found: %s' % pretrain_dir)
param_path = os.path.join(pretrain_dir, 'params')
state_dict_path = os.path.join(pretrain_dir, 'saved_weights')
config_path = os.path.join(pretrain_dir, 'ernie_config.json')
param_path = pretrain_dir /'params'
state_dict_path = pretrain_dir / 'saved_weights'
config_path = pretrain_dir / 'ernie_config.json'
if not os.path.exists(config_path):
if not config_path.exists():
raise ValueError('config path not found: %s' % config_path)
name_prefix=kwargs.pop('name', None)
cfg_dict = dict(json.loads(open(config_path).read()), **kwargs)
cfg_dict = dict(json.loads(config_path.open().read()), **kwargs)
model = cls(cfg_dict, name=name_prefix)
log.info('loading pretrained model from %s' % pretrain_dir)
......@@ -217,8 +222,8 @@ class PretrainedModel(object):
# raise NotImplementedError()
# log.debug('load pretrained weight from program state')
# F.io.load_program_state(param_path) #buggy in dygraph.gurad, push paddle to fix
if os.path.exists(state_dict_path + '.pdparams'):
m, _ = D.load_dygraph(state_dict_path)
if state_dict_path.with_suffix('.pdparams').exists():
m, _ = D.load_dygraph(state_dict_path.as_posix())
for k, v in model.state_dict().items():
if k not in m:
log.warn('param:%s not set in pretrained model, skip' % k)
......
......@@ -91,12 +91,12 @@ class ErnieTokenizer(object):
else:
log.info('pretrain dir %s not in %s, read from local' % (pretrain_dir_or_url, repr(cls.resource_map)))
pretrain_dir = pretrain_dir_or_url
if not os.path.exists(pretrain_dir):
if not pretrain_dir.exists():
raise ValueError('pretrain dir not found: %s' % pretrain_dir)
vocab_path = os.path.join(pretrain_dir, 'vocab.txt')
if not os.path.exists(vocab_path):
vocab_path = pretrain_dir / 'vocab.txt'
if not vocab_path.exists():
raise ValueError('no vocab file in pretrain dir: %s' % pretrain_dir)
vocab_dict = {j.strip().split('\t')[0]: i for i, j in enumerate(open(vocab_path).readlines())}
vocab_dict = {j.strip().split('\t')[0]: i for i, j in enumerate(vocab_path.open(encoding='utf8').readlines())}
t = cls(vocab_dict, **kwargs)
return t
......@@ -207,14 +207,14 @@ class ErnieTinyTokenizer(ErnieTokenizer):
else:
log.info('pretrain dir %s not in %s, read from local' % (pretrain_dir_or_url, repr(cls.resource_map)))
pretrain_dir = pretrain_dir_or_url
if not os.path.exists(pretrain_dir):
if not pretrain_dir.exists():
raise ValueError('pretrain dir not found: %s' % pretrain_dir)
vocab_path = os.path.join(pretrain_dir, 'vocab.txt')
sp_model_path = os.path.join(pretrain_dir, 'subword/spm_cased_simp_sampled.model')
vocab_path = pretrain_dir / 'vocab.txt'
sp_model_path = pretrain_dir / 'subword/spm_cased_simp_sampled.model'
if not os.path.exists(vocab_path):
if not vocab_path.exists():
raise ValueError('no vocab file in pretrain dir: %s' % pretrain_dir)
vocab_dict = {j.strip().split('\t')[0]: i for i, j in enumerate(open(vocab_path).readlines())}
vocab_dict = {j.strip().split('\t')[0]: i for i, j in enumerate(vocab_path.open(encoding='utf8').readlines())}
t = cls(vocab_dict, sp_model_path, **kwargs)
return t
......
......@@ -125,7 +125,7 @@ class LabelColumn(Column):
ids = int(raw)
else:
ids = self.vocab[raw]
return ids
return np.array(ids, dtype=np.int64)
class TextColumn(Column):
......
......@@ -5,3 +5,4 @@ sklearn==0.0
sentencepiece==0.1.8
jieba==0.39
visualdl>=2.0.0b7
pathlib2>=2.3.2
......@@ -22,7 +22,7 @@ with open("README.md", "r", encoding='utf-8') as fh:
setuptools.setup(
name="paddle-ernie", # Replace with your own username
version="0.0.4dev1",
version="0.0.5dev1",
author="Baidu Ernie Team",
author_email="ernieernie.team@gmail.com",
description="A pretrained NLP model for every NLP tasks",
......@@ -33,6 +33,7 @@ setuptools.setup(
install_requires=[
'requests',
'tqdm',
'pathlib2',
],
classifiers=[
'Intended Audience :: Developers',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册