Unverified commit 85e6e6a4, authored by Meiyim, committed by GitHub

Merge pull request #371 from Meiyim/ernie_tiny

Ernie tiny
...@@ -19,7 +19,6 @@ English | [简体中文](./README.zh.md)
* [Results](#results)
* [Results on English Datasets](#results-on-english-datasets)
* [Results on Chinese Datasets](#results-on-chinese-datasets)
* [Release Notes](#release-notes)
* [Communication](#communication)
* [Usage](#usage)
...@@ -615,14 +614,6 @@ LCQMC is a Chinese question semantic matching corpus published in COLING2018. [u
BQ Corpus (Bank Question corpus) is a Chinese corpus for sentence semantic equivalence identification. This dataset was published in EMNLP 2018. [url: https://www.aclweb.org/anthology/D18-1536]
```
## Release Notes
- Aug 21, 2019: feature updates: fp16 finetuning, multiprocess finetuning.
- July 30, 2019: release ERNIE 2.0
- Apr 10, 2019: update ERNIE_stable-1.0.1.tar.gz, update config and vocab
- Mar 18, 2019: update ERNIE_stable.tgz
- Mar 15, 2019: release ERNIE 1.0
## Communication
...@@ -657,6 +648,7 @@ BQ Corpus (Bank Question corpus) is a Chinese corpus for sentence semantic equiv
* [FAQ3: Is the argument batch_size for one GPU card or for all GPU cards?](#faq3-is-the--argument-batch_size-for-one-gpu-card-or-for-all-gpu-cards)
* [FAQ4: Can not find library: libcudnn.so. Please try to add the lib path to LD_LIBRARY_PATH.](#faq4-can-not-find-library-libcudnnso-please-try-to-add-the-lib-path-to-ld_library_path)
* [FAQ5: Can not find library: libnccl.so. Please try to add the lib path to LD_LIBRARY_PATH.](#faq5-can-not-find-library-libncclso-please-try-to-add-the-lib-path-to-ld_library_path)
* [FAQ6: Runtime error: `ModuleNotFoundError: No module named 'propeller'`](#faq6)
### Install PaddlePaddle
...@@ -1009,3 +1001,9 @@ Export the path of cuda to LD_LIBRARY_PATH, e.g.: `export LD_LIBRARY_PATH=/home/
#### FAQ5: Can not find library: libnccl.so. Please try to add the lib path to LD_LIBRARY_PATH.
Download [NCCL2](https://developer.nvidia.com/nccl/nccl-download), and export the library path to LD_LIBRARY_PATH, e.g.: `export LD_LIBRARY_PATH=/home/work/nccl/lib`
#### FAQ6: Runtime error: `ModuleNotFoundError: No module named 'propeller'`<a name="faq6"></a>
You can make propeller importable by adding the repository root to your PYTHONPATH: `export PYTHONPATH=./:$PYTHONPATH`
...@@ -19,7 +19,7 @@
* [Results](#效果验证)
* [Results on Chinese Datasets](#中文效果验证)
* [Results on English Datasets](#英文效果验证)
* [Release Notes](#开源记录)
* [ERNIE tiny](#ernie-tiny)
* [Communication](#技术交流)
* [Usage](#使用)
...@@ -589,7 +589,6 @@ ERNIE 2.0 的英文效果验证在 GLUE 上进行。GLUE 评测的官方地址
#### GLUE - Dev Set Results
| <strong>Dataset</strong> | <strong>CoLA</strong> | <strong>SST-2</strong> | <strong>MRPC</strong> | <strong>STS-B</strong> | <strong>QQP</strong> | <strong>MNLI-m</strong> | <strong>QNLI</strong> | <strong>RTE</strong> |
...@@ -617,11 +616,34 @@ ERNIE 2.0 的英文效果验证在 GLUE 上进行。GLUE 评测的官方地址
Since XLNet has not yet released single-model results on the GLUE test set, we only make the single-model comparison against BERT. The table above shows the single-model results of ERNIE 2.0 on the GLUE test set.
## Release Notes
- 2019-07-30: released ERNIE 2.0
- 2019-04-10: updated ERNIE_stable-1.0.1.tar.gz, packaging the model parameters, the config ernie_config.json, and vocab.txt into one release
- 2019-03-18: updated ERNIE_stable.tgz
- 2019-03-15: released ERNIE 1.0

### ERNIE tiny

To make ERNIE easier to deploy in real industrial applications, we introduce the ERNIE-tiny model.

![ernie_tiny](.metas/ernie_tiny.png)
As a compact version of ERNIE, ERNIE-tiny combines the following four techniques to deliver a prediction speed-up of nearly 4.3x on real-world data.

1. Shallow (浅): the 12-layer ERNIE Base model is compressed directly to 3 layers, a linear 4x speed-up, though at a noticeable cost in accuracy;
1. Wide (胖): the loss from making the model shallower can be compensated by enlarging the hidden size. Because the fluid inference framework deeply optimizes general matrix multiplication (GEMM) for different values of its last dimension (the hidden size), raising the hidden size from 768 to 1024 does not slow prediction down proportionally;
1. Short (短): ERNIE Tiny is the first open-sourced Chinese pretrained model with subword granularity. "Short" means that replacing character (char) granularity with subword granularity clearly shortens the input text, and input length is linearly related to prediction speed. Statistics show that on the XNLI dev set, sequences segmented with the subword vocabulary are on average 40% shorter than with the character vocabulary (a minimal illustration follows this list);
1. Distilled (萃): to further improve accuracy, ERNIE Tiny acts as the student and uses model distillation to learn, at the Transformer layers and the prediction layer, the distributions or outputs of the corresponding layers of the teacher ERNIE model, which narrows the gap between ERNIE Tiny and ERNIE.
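Below is a minimal illustration (not part of the original README) of the sequence-length reduction described in point 3, assuming the `sentencepiece` package is installed and using the `spm_cased_simp_sampled.model` file shipped in the ERNIE tiny pretrained package (adjust the path as needed):

```python
# Hedged sketch: compare character-granularity length with subword-granularity length.
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.Load('./subword/spm_cased_simp_sampled.model')  # path taken from the ERNIE tiny package

text = '为了提升ERNIE模型在实际工业应用中的落地能力,我们推出ERNIE-tiny模型。'
char_tokens = list(text)                  # char granularity: one token per character
subword_tokens = sp.EncodeAsPieces(text)  # subword granularity via sentencepiece

# The subword sequence is typically much shorter, which is what drives the speed-up.
print(len(char_tokens), len(subword_tokens))
```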
#### Benchmark
The results of the lightweight ERNIE Tiny model on public datasets are shown below: the task average drops by only 2.37% relative to ERNIE Base, while it is 8% higher than the "SOTA Before BERT" baselines. In the latency test, ERNIE Tiny delivers a 4.3x speed-up
(test environment: P4 GPU, Paddle Inference C++ API, XNLI dev set, maxlen=128; reported latency is the mean over 10 runs).
|model|XNLI(acc)|LCQMC(acc)|CHNSENTICORP(acc)|NLPCC-DBQA(mrr/f1)|Average|Latency|
|--|--|--|--|--|--|--|
|SOTA-before-ERNIE|68.3|83.4|92.2|72.01/-|78.98|-|
|ERNIE2.0-base|79.7|87.9|95.5|95.7/85.3|89.70|633ms (1x)|
|ERNIE-tiny-subword|75.1|86.1|95.2|92.9/78.6|87.33|146ms (4.3x)|
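For reference, a minimal sketch of the stated latency methodology (mean over 10 runs); `predict_fn` is a placeholder for whatever inference entry point is used (the published numbers were measured with the Paddle Inference C++ API, not this Python harness):

```python
# Hedged sketch: average latency of a prediction callable over several runs.
import time

def average_latency_ms(predict_fn, batch, n_runs=10, warmup=2):
    for _ in range(warmup):      # warm-up runs are discarded
        predict_fn(batch)
    start = time.time()
    for _ in range(n_runs):
        predict_fn(batch)
    return (time.time() - start) / n_runs * 1000.0

# Usage (hypothetical): average_latency_ms(my_model.predict, my_batch)
```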
## Communication
...@@ -646,6 +668,7 @@ ERNIE 2.0 的英文效果验证在 GLUE 上进行。GLUE 评测的官方地址
* [Sequence Labeling](#序列标注任务)
* [Named Entity Recognition](#实体识别)
* [Machine Reading Comprehension](#阅读理解任务-1)
* [ERNIE tiny](#tune-ernie-tiny)
* [Customized Development with Propeller](#利用propeller进行二次开发)
* [Pre-training (ERNIE 1.0)](#预训练-ernie-10)
* [Data Preprocessing](#数据预处理)
...@@ -695,6 +718,7 @@ pip install -r requirements.txt
| [ERNIE 1.0 Chinese Base model (max_len=512)](https://ernie.bj.bcebos.com/ERNIE_1.0_max-len-512.tar.gz) | Contains the pretrained model parameters, vocabulary vocab.txt, and model config ernie_config.json |
| [ERNIE 2.0 English Base model](https://ernie.bj.bcebos.com/ERNIE_Base_en_stable-2.0.0.tar.gz) | Contains the pretrained model parameters, vocabulary vocab.txt, and model config ernie_config.json |
| [ERNIE 2.0 English Large model](https://ernie.bj.bcebos.com/ERNIE_Large_en_stable-2.0.0.tar.gz) | Contains the pretrained model parameters, vocabulary vocab.txt, and model config ernie_config.json |
| [ERNIE tiny Chinese model](https://ernie.bj.bcebos.com/ernie_tiny.tar.gz) | Contains the pretrained model parameters, vocabulary vocab.txt, model config ernie_config.json, and the word-segmentation vocabulary |
...@@ -894,6 +918,16 @@ text_a label
[test evaluation] em: 88.061838, f1: 93.520152, avg: 90.790995, question_num: 3493
```
### ERNIE tiny <a name="tune-ernie-tiny"></a>
ERNIE tiny takes subword-level input, so data preprocessing needs an extra word segmentation step, followed by tokenization with [sentence piece](https://github.com/google/sentencepiece).
The models required for segmentation and tokenization are included in the ERNIE tiny [pretrained model package](#预训练模型下载): `./subword/dict.wordseg.pickle` and `./subword/spm_cased_simp_sampled.model`, respectively.
The code under `./example/` has already been adapted to ERNIE tiny's preprocessing: pass the tokenization model via `--sentence_piece_model` and the segmentation model via `--word_dict`, and you can fine-tune ERNIE tiny.
For named entity recognition tasks, to stay aligned with the token-level labels, ERNIE tiny still uses Chinese character-level input; when using `./example/finetune_ner.py`, simply enable `--use_sentence_piece_vocab`.
See the [next section](#利用propeller进行二次开发) for detailed usage.
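For orientation, a minimal sketch (not from the original README) of how these two files map onto the `WSSPTokenizer` added by this PR in `utils/data.py`; the paths are the ones shipped in the ERNIE tiny package:

```python
# Hedged sketch: construct the tokenizer that the fine-tuning scripts build from
# --sentence_piece_model and --word_dict.
import utils.data

tokenizer = utils.data.WSSPTokenizer(
    './subword/spm_cased_simp_sampled.model',  # passed via --sentence_piece_model
    './subword/dict.wordseg.pickle',           # passed via --word_dict
    ws=True,     # run dictionary-based word segmentation first
    lower=True)  # lowercase before sentencepiece encoding

# WSSPTokenizer.__call__ expects utf-8 encoded bytes (see utils/data.py in this PR).
pieces = tokenizer('ERNIE tiny 采用 subword 粒度输入'.encode('utf8'))
print(pieces)
```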
## Customized Development with Propeller
[Propeller](./propeller/README.md) is a one-stop training API built on PaddlePaddle; developers with some machine learning experience can use Propeller for a customized development workflow.
...@@ -1099,6 +1133,6 @@ python -u infer_classifyer.py \
Download [NCCL](https://developer.nvidia.com/nccl/nccl-download) first, then add the NCCL library path to LD_LIBRARY_PATH, e.g. `export LD_LIBRARY_PATH=/home/work/nccl/lib`
### FAQ6: Runtime error `ModuleNotFoundError: No module named 'propeller'`<a name="faq6"></a>
You can make Propeller importable via `export PYTHONPATH=./:$PYTHONPATH`.
...@@ -4,6 +4,7 @@ import re
from propeller import log
import itertools
from propeller.paddle.data import Dataset
import pickle
import six
...@@ -101,7 +102,7 @@ class SpaceTokenizer(object):
class CharTokenizer(object):
    def __init__(self, vocab, lower=True, sentencepiece_style_vocab=False):
        """
        char tokenizer (wordpiece english)
        normed txt(space separated or not) => list of word-piece
...@@ -110,6 +111,7 @@ class CharTokenizer(object):
        #self.pat = re.compile(r'([,.!?\u3002\uff1b\uff0c\uff1a\u201c\u201d\uff08\uff09\u3001\uff1f\u300a\u300b]|[\u4e00-\u9fa5]|[a-zA-Z0-9]+)')
        self.pat = re.compile(r'([a-zA-Z0-9]+|\S)')
        self.lower = lower
        self.sentencepiece_style_vocab = sentencepiece_style_vocab

    def __call__(self, sen):
        if len(sen) == 0:
...@@ -119,11 +121,51 @@ class CharTokenizer(object):
        sen = sen.lower()
        res = []
        for match in self.pat.finditer(sen):
            words, _ = wordpiece(match.group(0), vocab=self.vocab, unk_token='[UNK]', sentencepiece_style_vocab=self.sentencepiece_style_vocab)
            res.extend(words)
        return res
class WSSPTokenizer(object):
    def __init__(self, sp_model_dir, word_dict, ws=True, lower=True):
        self.ws = ws
        self.lower = lower
        # word-segmentation dictionary (pickled mapping of known words)
        self.dict = pickle.load(open(word_dict, 'rb'), encoding='utf8')
        import sentencepiece as spm
        self.sp_model = spm.SentencePieceProcessor()
        self.window_size = 5
        self.sp_model.Load(sp_model_dir)

    def cut(self, chars):
        # greedy forward maximum matching against the word dictionary
        words = []
        idx = 0
        while idx < len(chars):
            matched = False
            for i in range(self.window_size, 0, -1):
                cand = chars[idx: idx + i]
                if cand in self.dict:
                    words.append(cand)
                    matched = True
                    break
            if not matched:
                i = 1
                words.append(chars[idx])
            idx += i
        return words

    def __call__(self, sen):
        sen = sen.decode('utf8')
        if self.ws:
            sen = [s for s in self.cut(sen) if s != ' ']
        else:
            sen = sen.split(' ')
        if self.lower:
            sen = [s.lower() for s in sen]
        sen = ' '.join(sen)
        ret = self.sp_model.EncodeAsPieces(sen)
        return ret
def build_2_pair(seg_a, seg_b, max_seqlen, cls_id, sep_id):
    token_type_a = np.ones_like(seg_a, dtype=np.int64) * 0
    token_type_b = np.ones_like(seg_b, dtype=np.int64) * 1
......
...@@ -55,7 +55,7 @@ class ClassificationErnieModel(propeller.train.Model):
        pos_ids = L.cast(pos_ids, 'int64')
        pos_ids.stop_gradient = True
        input_mask.stop_gradient = True
        task_ids = L.zeros_like(src_ids) + self.hparam.task_id
        task_ids.stop_gradient = True
        ernie = ErnieModel(
...@@ -128,6 +128,8 @@ if __name__ == '__main__':
    parser.add_argument('--vocab_file', type=str, required=True)
    parser.add_argument('--do_predict', action='store_true')
    parser.add_argument('--warm_start_from', type=str)
    parser.add_argument('--sentence_piece_model', type=str, default=None)
    parser.add_argument('--word_dict', type=str, default=None)
    args = parser.parse_args()
    run_config = propeller.parse_runconfig(args)
    hparams = propeller.parse_hparam(args)
...@@ -138,7 +140,12 @@ if __name__ == '__main__':
    cls_id = vocab['[CLS]']
    unk_id = vocab['[UNK]']
    if args.sentence_piece_model is not None:
        if args.word_dict is None:
            raise ValueError('--word_dict not specified in subword model')
        tokenizer = utils.data.WSSPTokenizer(args.sentence_piece_model, args.word_dict, ws=True, lower=True)
    else:
        tokenizer = utils.data.CharTokenizer(vocab.keys())
    def tokenizer_func(inputs):
        '''avoid pickle error'''
...@@ -179,7 +186,7 @@ if __name__ == '__main__':
    dev_ds.data_shapes = shapes
    dev_ds.data_types = types
    varname_to_warmstart = re.compile(r'^encoder.*[wb]_0$|^.*embedding$|^.*bias$|^.*scale$|^pooled_fc.[wb]_0$')
    warm_start_dir = args.warm_start_from
    ws = propeller.WarmStartSetting(
        predicate_fn=lambda v: varname_to_warmstart.match(v.name) and os.path.exists(os.path.join(warm_start_dir, v.name)),
......
...@@ -32,7 +32,6 @@ import paddle.fluid.layers as L
from model.ernie import ErnieModel
from optimization import optimization
import utils.data
from propeller import log
...@@ -121,7 +120,7 @@ class SequenceLabelErnieModel(propeller.train.Model):
def make_sequence_label_dataset(name, input_files, label_list, tokenizer, batch_size, max_seqlen, is_train):
    label_map = {v: i for i, v in enumerate(label_list)}
    no_entity_id = label_map['O']
    delimiter = b''
    def read_bio_data(filename):
        ds = propeller.data.Dataset.from_file(filename)
...@@ -132,10 +131,10 @@ def make_sequence_label_dataset(name, input_files, label_list, tokenizer, batch_
        while 1:
            line = next(iterator)
            cols = line.rstrip(b'\n').split(b'\t')
            tokens = cols[0].split(delimiter)
            labels = cols[1].split(delimiter)
            if len(cols) != 2:
                continue
            if len(tokens) != len(labels) or len(tokens) == 0:
                continue
            yield [tokens, labels]
...@@ -151,7 +150,8 @@ def make_sequence_label_dataset(name, input_files, label_list, tokenizer, batch_
        ret_tokens = []
        ret_labels = []
        for token, label in zip(tokens, labels):
            sub_token = tokenizer(token)
            label = label.decode('utf8')
            if len(sub_token) == 0:
                continue
            ret_tokens.extend(sub_token)
...@@ -179,7 +179,7 @@ def make_sequence_label_dataset(name, input_files, label_list, tokenizer, batch_
        labels = labels[: max_seqlen - 2]
        tokens = ['[CLS]'] + tokens + ['[SEP]']
        token_ids = [vocab[t] for t in tokens]
        label_ids = [no_entity_id] + [label_map[x] for x in labels] + [no_entity_id]
        token_type_ids = [0] * len(token_ids)
        input_seqlen = len(token_ids)
...@@ -211,7 +211,7 @@ def make_sequence_label_dataset(name, input_files, label_list, tokenizer, batch_
def make_sequence_label_dataset_from_stdin(name, tokenizer, batch_size, max_seqlen):
    delimiter = b''
    def stdin_gen():
        if six.PY3:
...@@ -232,9 +232,9 @@ def make_sequence_label_dataset_from_stdin(name, tokenizer, batch_size, max_seql
        while 1:
            line, = next(iterator)
            cols = line.rstrip(b'\n').split(b'\t')
            tokens = cols[0].split(delimiter)
            if len(cols) != 1:
                continue
            if len(tokens) == 0:
                continue
            yield tokens,
...@@ -247,7 +247,7 @@ def make_sequence_label_dataset_from_stdin(name, tokenizer, batch_size, max_seql
        tokens, = next(iterator)
        ret_tokens = []
        for token in tokens:
            sub_token = tokenizer(token)
            if len(sub_token) == 0:
                continue
            ret_tokens.extend(sub_token)
...@@ -266,7 +266,7 @@ def make_sequence_label_dataset_from_stdin(name, tokenizer, batch_size, max_seql
        tokens = tokens[: max_seqlen - 2]
        tokens = ['[CLS]'] + tokens + ['[SEP]']
        token_ids = [vocab[t] for t in tokens]
        token_type_ids = [0] * len(token_ids)
        input_seqlen = len(token_ids)
...@@ -296,13 +296,15 @@ if __name__ == '__main__':
    parser.add_argument('--data_dir', type=str, required=True)
    parser.add_argument('--vocab_file', type=str, required=True)
    parser.add_argument('--do_predict', action='store_true')
    parser.add_argument('--use_sentence_piece_vocab', action='store_true')
    parser.add_argument('--warm_start_from', type=str)
    args = parser.parse_args()
    run_config = propeller.parse_runconfig(args)
    hparams = propeller.parse_hparam(args)
    vocab = {j.strip().split('\t')[0]: i for i, j in enumerate(open(args.vocab_file, 'r', encoding='utf8'))}
    tokenizer = utils.data.CharTokenizer(vocab, sentencepiece_style_vocab=args.use_sentence_piece_vocab)
    sep_id = vocab['[SEP]']
    cls_id = vocab['[CLS]']
    unk_id = vocab['[UNK]']
...@@ -358,7 +360,7 @@ if __name__ == '__main__':
        from_dir=warm_start_dir
    )
    best_exporter = propeller.train.exporter.BestInferenceModelExporter(os.path.join(run_config.model_dir, 'best'), cmp_fn=lambda old, new: new['dev']['f1'] > old['dev']['f1'])
    propeller.train.train_and_eval(
        model_class_or_model_fn=SequenceLabelErnieModel,
        params=hparams,
...@@ -387,7 +389,6 @@ if __name__ == '__main__':
    predict_ds.data_types = types
    rev_label_map = {i: v for i, v in enumerate(label_list)}
    learner = propeller.Learner(SequenceLabelErnieModel, run_config, hparams)
    for pred, _ in learner.predict(predict_ds, ckpt=-1):
        pred_str = ' '.join([rev_label_map[idx] for idx in np.argmax(pred, 1).tolist()])
......
...@@ -146,6 +146,7 @@ if __name__ == '__main__':
    parser.add_argument('--data_dir', type=str, required=True)
    parser.add_argument('--warm_start_from', type=str)
    parser.add_argument('--sentence_piece_model', type=str, default=None)
    parser.add_argument('--word_dict', type=str, default=None)
    args = parser.parse_args()
    run_config = propeller.parse_runconfig(args)
    hparams = propeller.parse_hparam(args)
...@@ -157,7 +158,9 @@ if __name__ == '__main__':
    unk_id = vocab['[UNK]']
    if args.sentence_piece_model is not None:
        if args.word_dict is None:
            raise ValueError('--word_dict not specified in subword model')
        tokenizer = utils.data.WSSPTokenizer(args.sentence_piece_model, args.word_dict, ws=True, lower=True)
    else:
        tokenizer = utils.data.CharTokenizer(vocab.keys())
...@@ -218,7 +221,7 @@ if __name__ == '__main__':
        from_dir=warm_start_dir
    )
    best_exporter = propeller.train.exporter.BestInferenceModelExporter(os.path.join(run_config.model_dir, 'best'), cmp_fn=lambda old, new: new['dev']['f1'] > old['dev']['f1'])
    propeller.train_and_eval(
        model_class_or_model_fn=RankingErnieModel,
        params=hparams,
...@@ -258,6 +261,7 @@ if __name__ == '__main__':
    est = propeller.Learner(RankingErnieModel, run_config, hparams)
    for qid, res in est.predict(predict_ds, ckpt=-1):
        print('%d\t%d\t%.5f\t%.5f' % (qid[0], np.argmax(res), res[0], res[1]))
    #for i in predict_ds:
    #    sen = i[0]
    #    for ss in np.squeeze(sen):
......
...@@ -5,3 +5,5 @@ scikit-learn==0.20.3
scipy==1.2.1
six==1.11.0
sklearn==0.0
sentencepiece==0.1.8
paddlepaddle-gpu==1.5.2.post107