diff --git a/.metas/ernie_tiny.png b/.metas/ernie_tiny.png
new file mode 100644
index 0000000000000000000000000000000000000000..580d9381c75232e09ef1f26a33b0ee5deba1f11f
Binary files /dev/null and b/.metas/ernie_tiny.png differ
diff --git a/README.md b/README.md
index 58b0bd325878ff9c68927f0717ad5ed0688dab41..7a3854bb5f8a35e7d2acb7d8457b5b3549fefc9a 100644
--- a/README.md
+++ b/README.md
@@ -19,7 +19,6 @@ English | [简体中文](./README.zh.md)
   * [Results](#results)
     * [Results on English Datasets](#results-on-english-datasets)
     * [Results on Chinese Datasets](#results-on-chinese-datasets)
-  * [Release Notes](#release-notes)
   * [Communication](#communication)
   * [Usage](#usage)
@@ -615,14 +614,6 @@ LCQMC is a Chinese question semantic matching corpus published in COLING2018. [u
 BQ Corpus (Bank Question corpus) is a Chinese corpus for sentence semantic equivalence identification. This dataset was published in EMNLP 2018. [url: https://www.aclweb.org/anthology/D18-1536]
 ```
 
-## Release Notes
-
-- Aug 21, 2019: featuers update: fp16 finetuning, multiprocess finetining.
-- July 30, 2019: release ERNIE 2.0
-- Apr 10, 2019: update ERNIE_stable-1.0.1.tar.gz, update config and vocab
-- Mar 18, 2019: update ERNIE_stable.tgz
-- Mar 15, 2019: release ERNIE 1.0
-
 ## Communication
@@ -657,6 +648,7 @@ BQ Corpus (Bank Question corpus) is a Chinese corpus for sentence semantic equiv
   * [FAQ3: Is the argument batch_size for one GPU card or for all GPU cards?](#faq3-is-the--argument-batch_size-for-one-gpu-card-or-for-all-gpu-cards)
   * [FAQ4: Can not find library: libcudnn.so. Please try to add the lib path to LD_LIBRARY_PATH.](#faq4-can-not-find-library-libcudnnso-please-try-to-add-the-lib-path-to-ld_library_path)
   * [FAQ5: Can not find library: libnccl.so. Please try to add the lib path to LD_LIBRARY_PATH.](#faq5-can-not-find-library-libncclso-please-try-to-add-the-lib-path-to-ld_library_path)
+  * [FAQ6: Runtime error: `ModuleNotFoundError: No module named 'propeller'`](#faq6-runtime-error-modulenotfounderror-no-module-named-propeller)
 
 ### Install PaddlePaddle
@@ -1009,3 +1001,9 @@ Export the path of cuda to LD_LIBRARY_PATH, e.g.: `export LD_LIBRARY_PATH=/home/
 #### FAQ5: Can not find library: libnccl.so. Please try to add the lib path to LD_LIBRARY_PATH.
 
 Download [NCCL2](https://developer.nvidia.com/nccl/nccl-download), and export the library path to LD_LIBRARY_PATH, e.g.:`export LD_LIBRARY_PATH=/home/work/nccl/lib`
+
+#### FAQ6: Runtime error: `ModuleNotFoundError: No module named 'propeller'`
+
+You can add propeller to your PYTHONPATH via `export PYTHONPATH=./:$PYTHONPATH`.
diff --git a/README.zh.md b/README.zh.md
index 461042bb69a49213f56011a7aa0d85235c5db99b..9101f437776212530b214134aee5a42a567ebdc2 100644
--- a/README.zh.md
+++ b/README.zh.md
@@ -19,7 +19,7 @@
   * [效果验证](#效果验证)
     * [中文效果验证](#中文效果验证)
     * [英文效果验证](#英文效果验证)
-  * [开源记录](#开源记录)
+  * [ERNIE tiny](#ernie-tiny)
   * [技术交流](#技术交流)
   * [使用](#使用)
@@ -589,7 +589,6 @@ ERNIE 2.0 的英文效果验证在 GLUE 上进行。GLUE 评测的官方地址
-
 #### GLUE - 验证集结果
 
 | 数据集 | CoLA | SST-2 | MRPC | STS-B | QQP | MNLI-m | QNLI | RTE |
@@ -617,11 +616,34 @@ ERNIE 2.0 的英文效果验证在 GLUE 上进行。GLUE 评测的官方地址
 由于 XLNet 暂未公布 GLUE 测试集上的单模型结果,所以我们只与 BERT 进行单模型比较。上表为ERNIE 2.0 单模型在 GLUE 测试集的表现结果。
 
-## 开源记录
-- 2019-07-30 发布 ERNIE 2.0
-- 2019-04-10 更新: update ERNIE_stable-1.0.1.tar.gz, 将模型参数、配置 ernie_config.json、vocab.txt 打包发布
-- 2019-03-18 更新: update ERNIE_stable.tgz
-- 2019-03-15 发布 ERNIE 1.0
+### ERNIE tiny
+
+为了提升ERNIE模型在实际工业应用中的落地能力,我们推出ERNIE-tiny模型。
+
+![ernie_tiny](.metas/ernie_tiny.png)
+
+ERNIE-tiny作为小型化ERNIE,采用了以下4点技术,保证了在实际真实数据中将近4.3倍的预测提速。
+
+1. 浅:12层的ERNIE Base模型直接压缩为3层,线性提速4倍,但效果也会有较大幅度的下降;
+
+1. 胖:模型变浅带来的损失可通过hidden size的增大来弥补。由于fluid inference框架对通用矩阵运算(gemm)的最后一维(hidden size)参数的不同取值有深度优化,因此将hidden size从768提升至1024并不会带来预测速度的线性下降;
+
+1. 短:ERNIE Tiny是首个开源的中文subword粒度预训练模型。这里的短是指通过subword粒度替换字(char)粒度,能够明显缩短输入文本的长度,而输入文本长度与预测速度线性相关。统计表明,在XNLI dev集上,采用subword字典切分出的序列长度比字粒度平均缩短40%;
+
+1. 萃:为了进一步提升模型效果,ERNIE Tiny扮演学生角色,利用模型蒸馏的方式,在Transformer层和Prediction层学习教师模型ERNIE对应层的分布或输出,从而缩小ERNIE Tiny与ERNIE的效果差距。
+
+
+#### Benchmark
+
+ERNIE Tiny轻量级模型在公开数据集上的效果如下所示:任务均值相对于ERNIE 2.0 base只下降了2.37个百分点,但相对于SOTA-before-ERNIE提升了8个百分点以上。在延迟测试中,ERNIE Tiny能够带来4.3倍的速度提升(测试环境:GPU P4,Paddle Inference C++ API,XNLI dev集,maxlen=128,测试结果为10次均值)。
+
+|model|XNLI(acc)|LCQMC(acc)|CHNSENTICORP(acc)|NLPCC-DBQA(mrr/f1)|Average|Latency|
+|--|--|--|--|--|--|--|
+|SOTA-before-ERNIE|68.3|83.4|92.2|72.01/-|78.98|-|
+|ERNIE2.0-base|79.7|87.9|95.5|95.7/85.3|89.70|633ms(1x)|
+|ERNIE-tiny-subword|75.1|86.1|95.2|92.9/78.6|87.33|146ms(4.3x)|
 
 ## 技术交流
@@ -646,6 +668,7 @@ ERNIE 2.0 的英文效果验证在 GLUE 上进行。GLUE 评测的官方地址
   * [序列标注任务](#序列标注任务)
   * [实体识别](#实体识别)
   * [阅读理解任务](#阅读理解任务-1)
+  * [ERNIE tiny](#ernie-tiny-1)
   * [利用Propeller进行二次开发](#利用propeller进行二次开发)
   * [预训练 (ERNIE 1.0)](#预训练-ernie-10)
   * [数据预处理](#数据预处理)
@@ -695,6 +718,7 @@ pip install -r requirements.txt
 | [ERNIE 1.0 中文 Base 模型(max_len=512)](https://ernie.bj.bcebos.com/ERNIE_1.0_max-len-512.tar.gz) | 包含预训练模型参数、词典 vocab.txt、模型配置 ernie_config.json|
 | [ERNIE 2.0 英文 Base 模型](https://ernie.bj.bcebos.com/ERNIE_Base_en_stable-2.0.0.tar.gz) | 包含预训练模型参数、词典 vocab.txt、模型配置 ernie_config.json|
 | [ERNIE 2.0 英文 Large 模型](https://ernie.bj.bcebos.com/ERNIE_Large_en_stable-2.0.0.tar.gz) | 包含预训练模型参数、词典 vocab.txt、模型配置 ernie_config.json|
+| [ERNIE tiny 中文模型](https://ernie.bj.bcebos.com/ernie_tiny.tar.gz)|包含预训练模型参数、词典 vocab.txt、模型配置 ernie_config.json 以及切词词表|
 
@@ -894,6 +918,16 @@ text_a label
 [test evaluation] em: 88.061838, f1: 93.520152, avg: 90.790995, question_num: 3493
 ```
+
+### ERNIE tiny
+
+ERNIE tiny 模型采用了subword粒度输入,需要在数据前处理中加入切词(segmentation),并使用[sentence piece](https://github.com/google/sentencepiece)进行tokenization。
+segmentation 以及 tokenization 所需的模型已包含在 ERNIE tiny 的[预训练模型文件](#预训练模型下载)中,分别是 `./subword/dict.wordseg.pickle` 和 `./subword/spm_cased_simp_sampled.model`。
+
+目前`./example/`下的代码已针对 ERNIE tiny 的前处理进行了适配:只需在脚本中通过 `--sentence_piece_model` 引入 tokenization 模型,再通过 `--word_dict` 引入 segmentation 模型,即可进行 ERNIE tiny 的 Fine-tune。
+对于命名实体识别类型的任务,为了与输入标注对齐,ERNIE tiny 仍然采用中文单字粒度作为输入,因此使用 `./example/finetune_ner.py` 时只需打开 `--use_sentence_piece_vocab` 即可。
+具体的使用方法可以参考[下节](#利用propeller进行二次开发)。
 
 ## 利用Propeller进行二次开发
 [Propeller](./propeller/README.md) 是基于PaddlePaddle构建的一键式训练API,对于具备一定机器学习应用经验的开发者可以使用Propeller获得定制化开发体验。
@@ -1099,6 +1133,6 @@ python -u infer_classifyer.py \
 需要先下载 [NCCL](https://developer.nvidia.com/nccl/nccl-download),然后在 LD_LIBRARY_PATH 中添加 NCCL 库的路径,如`export LD_LIBRARY_PATH=/home/work/nccl/lib`
 
-### FQA6: 运行报错`ModuleNotFoundError: No module named 'propeller'`
+### FAQ6: 运行报错`ModuleNotFoundError: No module named 'propeller'`
 
 您可以通过`export PYTHONPATH=./:$PYTHONPATH`的方式引入Propeller.
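A note on the preprocessing described in the README.zh.md section above: dictionary-based word segmentation followed by sentencepiece tokenization is exactly what the `WSSPTokenizer` added to `ernie/utils/data.py` below implements. Here is a minimal standalone sketch of that pipeline; the `./subword/...` paths are the files listed for the pretrained package above, while the `segment` helper and the sample sentence are illustrative only:

```python
# Standalone sketch of ERNIE tiny preprocessing: greedy forward maximum
# matching against the shipped word dict, then sentencepiece tokenization.
import pickle
import sentencepiece as spm

word_dict = pickle.load(open('./subword/dict.wordseg.pickle', 'rb'), encoding='utf8')
sp_model = spm.SentencePieceProcessor()
sp_model.Load('./subword/spm_cased_simp_sampled.model')

def segment(chars, window_size=5):
    # longest match within a 5-char window wins; fall back to a single char
    words, idx = [], 0
    while idx < len(chars):
        n = next((i for i in range(window_size, 0, -1)
                  if chars[idx:idx + i] in word_dict), 1)
        words.append(chars[idx:idx + n])
        idx += n
    return words

sen = ' '.join(w.lower() for w in segment(u'百度提出了ERNIE模型') if w != ' ')
print(sp_model.EncodeAsPieces(sen))  # subword pieces; per the README, ~40% shorter than chars on XNLI dev
```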
diff --git a/ernie/utils/data.py b/ernie/utils/data.py
index 42ff3d816d8e4fa77a539db61925f97a83281606..8f54826a12aced7ff683571548adeb8d8ce528d7 100644
--- a/ernie/utils/data.py
+++ b/ernie/utils/data.py
@@ -4,6 +4,7 @@ import re
 from propeller import log
 import itertools
 from propeller.paddle.data import Dataset
+import pickle
 
 import six
@@ -101,7 +102,7 @@ class SpaceTokenizer(object):
 
 class CharTokenizer(object):
-    def __init__(self, vocab, lower=True):
+    def __init__(self, vocab, lower=True, sentencepiece_style_vocab=False):
         """
         char tokenizer (wordpiece english)
         normed txt(space seperated or not) => list of word-piece
@@ -110,6 +111,7 @@ class CharTokenizer(object):
         #self.pat = re.compile(r'([,.!?\u3002\uff1b\uff0c\uff1a\u201c\u201d\uff08\uff09\u3001\uff1f\u300a\u300b]|[\u4e00-\u9fa5]|[a-zA-Z0-9]+)')
         self.pat = re.compile(r'([a-zA-Z0-9]+|\S)')
         self.lower = lower
+        self.sentencepiece_style_vocab = sentencepiece_style_vocab
 
     def __call__(self, sen):
         if len(sen) == 0:
@@ -119,11 +121,51 @@ class CharTokenizer(object):
             sen = sen.lower()
         res = []
         for match in self.pat.finditer(sen):
-            words, _ = wordpiece(match.group(0), vocab=self.vocab, unk_token='[UNK]')
+            words, _ = wordpiece(match.group(0), vocab=self.vocab, unk_token='[UNK]', sentencepiece_style_vocab=self.sentencepiece_style_vocab)
             res.extend(words)
         return res
 
+class WSSPTokenizer(object):
+    def __init__(self, sp_model_dir, word_dict, ws=True, lower=True):
+        self.ws = ws
+        self.lower = lower
+        self.dict = pickle.load(open(word_dict, 'rb'), encoding='utf8')
+        import sentencepiece as spm
+        self.sp_model = spm.SentencePieceProcessor()
+        self.window_size = 5
+        self.sp_model.Load(sp_model_dir)
+
+    def cut(self, chars):
+        words = []
+        idx = 0
+        while idx < len(chars):
+            matched = False
+            for i in range(self.window_size, 0, -1):
+                cand = chars[idx: idx + i]
+                if cand in self.dict:
+                    words.append(cand)
+                    matched = True
+                    break
+            if not matched:
+                i = 1
+                words.append(chars[idx])
+            idx += i
+        return words
+
+    def __call__(self, sen):
+        sen = sen.decode('utf8')
+        if self.ws:
+            sen = [s for s in self.cut(sen) if s != ' ']
+        else:
+            sen = sen.split(' ')
+        if self.lower:
+            sen = [s.lower() for s in sen]
+        sen = ' '.join(sen)
+        ret = self.sp_model.EncodeAsPieces(sen)
+        return ret
+
+
 def build_2_pair(seg_a, seg_b, max_seqlen, cls_id, sep_id):
     token_type_a = np.ones_like(seg_a, dtype=np.int64) * 0
     token_type_b = np.ones_like(seg_b, dtype=np.int64) * 1
diff --git a/example/finetune_classifier.py b/example/finetune_classifier.py
index 77a68ad989def8d69d723cea69989ddfa067f577..fb65ec6abf4ae91ec4f932259c135febf47fba8a 100644
--- a/example/finetune_classifier.py
+++ b/example/finetune_classifier.py
@@ -55,7 +55,7 @@ class ClassificationErnieModel(propeller.train.Model):
         pos_ids = L.cast(pos_ids, 'int64')
         pos_ids.stop_gradient = True
         input_mask.stop_gradient = True
-        task_ids = L.zeros_like(src_ids) + self.hparam.task_id #this shit wont use at the moment
+        task_ids = L.zeros_like(src_ids) + self.hparam.task_id
         task_ids.stop_gradient = True
 
         ernie = ErnieModel(
@@ -128,6 +128,8 @@ if __name__ == '__main__':
     parser.add_argument('--vocab_file', type=str, required=True)
     parser.add_argument('--do_predict', action='store_true')
     parser.add_argument('--warm_start_from', type=str)
+    parser.add_argument('--sentence_piece_model', type=str, default=None)
+    parser.add_argument('--word_dict', type=str, default=None)
     args = parser.parse_args()
     run_config = propeller.parse_runconfig(args)
     hparams = propeller.parse_hparam(args)
@@ -138,7 +140,12 @@ if __name__ == '__main__':
     cls_id = vocab['[CLS]']
     unk_id = vocab['[UNK]']
-    tokenizer = utils.data.CharTokenizer(vocab.keys())
+    if args.sentence_piece_model is not None:
+        if args.word_dict is None:
+            raise ValueError('--word_dict not specified when using a sentence piece model')
+        tokenizer = utils.data.WSSPTokenizer(args.sentence_piece_model, args.word_dict, ws=True, lower=True)
+    else:
+        tokenizer = utils.data.CharTokenizer(vocab.keys())
 
     def tokenizer_func(inputs):
         '''avoid pickle error'''
@@ -179,7 +186,7 @@ if __name__ == '__main__':
     dev_ds.data_shapes = shapes
     dev_ds.data_types = types
 
-    varname_to_warmstart = re.compile('encoder.*|pooled.*|.*embedding|pre_encoder_.*')
+    varname_to_warmstart = re.compile(r'^encoder.*[wb]_0$|^.*embedding$|^.*bias$|^.*scale$|^pooled_fc.[wb]_0$')
    warm_start_dir = args.warm_start_from
     ws = propeller.WarmStartSetting(
         predicate_fn=lambda v: varname_to_warmstart.match(v.name) and os.path.exists(os.path.join(warm_start_dir, v.name)),
diff --git a/example/finetune_ner.py b/example/finetune_ner.py
index 89a9e22ffc16d1be3149333a395fcd3c7a8d4f55..954f2a7b44de714ed6e74b56fe96480e12b3cb89 100644
--- a/example/finetune_ner.py
+++ b/example/finetune_ner.py
@@ -32,7 +32,6 @@ import paddle.fluid.layers as L
 from model.ernie import ErnieModel
 from optimization import optimization
-import tokenization
 import utils.data
 
 from propeller import log
@@ -121,7 +120,7 @@ class SequenceLabelErnieModel(propeller.train.Model):
 def make_sequence_label_dataset(name, input_files, label_list, tokenizer, batch_size, max_seqlen, is_train):
     label_map = {v: i for i, v in enumerate(label_list)}
     no_entity_id = label_map['O']
-    delimiter = ''
+    delimiter = b''
 
     def read_bio_data(filename):
         ds = propeller.data.Dataset.from_file(filename)
@@ -132,10 +131,10 @@ def make_sequence_label_dataset(name, input_files, label_list, tokenizer, batch_
             while 1:
                 line = next(iterator)
                 cols = line.rstrip(b'\n').split(b'\t')
                 if len(cols) != 2:
                     continue
-                tokens = tokenization.convert_to_unicode(cols[0]).split(delimiter)
-                labels = tokenization.convert_to_unicode(cols[1]).split(delimiter)
+                tokens = cols[0].split(delimiter)
+                labels = cols[1].split(delimiter)
                 if len(tokens) != len(labels) or len(tokens) == 0:
                     continue
                 yield [tokens, labels]
@@ -151,7 +150,8 @@ def make_sequence_label_dataset(name, input_files, label_list, tokenizer, batch_
             ret_tokens = []
             ret_labels = []
             for token, label in zip(tokens, labels):
-                sub_token = tokenizer.tokenize(token)
+                sub_token = tokenizer(token)
+                label = label.decode('utf8')
                 if len(sub_token) == 0:
                     continue
                 ret_tokens.extend(sub_token)
@@ -179,7 +179,7 @@ def make_sequence_label_dataset(name, input_files, label_list, tokenizer, batch_
                 labels = labels[: max_seqlen - 2]
 
             tokens = ['[CLS]'] + tokens + ['[SEP]']
-            token_ids = tokenizer.convert_tokens_to_ids(tokens)
+            token_ids = [vocab[t] for t in tokens]
             label_ids = [no_entity_id] + [label_map[x] for x in labels] + [no_entity_id]
             token_type_ids = [0] * len(token_ids)
             input_seqlen = len(token_ids)
@@ -211,7 +211,7 @@ def make_sequence_label_dataset(name, input_files, label_list, tokenizer, batch_
 
 def make_sequence_label_dataset_from_stdin(name, tokenizer, batch_size, max_seqlen):
-    delimiter = ''
+    delimiter = b''
 
     def stdin_gen():
         if six.PY3:
@@ -232,9 +232,9 @@ def make_sequence_label_dataset_from_stdin(name, tokenizer, batch_size, max_seql
             while 1:
                 line, = next(iterator)
                 cols = line.rstrip(b'\n').split(b'\t')
                 if len(cols) != 1:
                     continue
-                tokens = tokenization.convert_to_unicode(cols[0]).split(delimiter)
+                tokens = cols[0].split(delimiter)
                 if len(tokens) == 0:
                     continue
                 yield tokens,
@@ -247,7 +247,7 @@ def make_sequence_label_dataset_from_stdin(name, tokenizer, batch_size, max_seql
             tokens, = next(iterator)
             ret_tokens = []
             for token in tokens:
-                sub_token = tokenizer.tokenize(token)
+                sub_token = tokenizer(token)
                 if len(sub_token) == 0:
                     continue
                 ret_tokens.extend(sub_token)
@@ -266,7 +266,7 @@ def make_sequence_label_dataset_from_stdin(name, tokenizer, batch_size, max_seql
                 tokens = tokens[: max_seqlen - 2]
 
             tokens = ['[CLS]'] + tokens + ['[SEP]']
-            token_ids = tokenizer.convert_tokens_to_ids(tokens)
+            token_ids = [vocab[t] for t in tokens]
             token_type_ids = [0] * len(token_ids)
             input_seqlen = len(token_ids)
@@ -296,13 +296,15 @@ if __name__ == '__main__':
     parser.add_argument('--data_dir', type=str, required=True)
     parser.add_argument('--vocab_file', type=str, required=True)
     parser.add_argument('--do_predict', action='store_true')
+    parser.add_argument('--use_sentence_piece_vocab', action='store_true')
     parser.add_argument('--warm_start_from', type=str)
     args = parser.parse_args()
     run_config = propeller.parse_runconfig(args)
     hparams = propeller.parse_hparam(args)
 
-    tokenizer = tokenization.FullTokenizer(args.vocab_file)
-    vocab = tokenizer.vocab
+
+    vocab = {j.strip().split('\t')[0]: i for i, j in enumerate(open(args.vocab_file, 'r', encoding='utf8'))}
+    tokenizer = utils.data.CharTokenizer(vocab, sentencepiece_style_vocab=args.use_sentence_piece_vocab)
     sep_id = vocab['[SEP]']
     cls_id = vocab['[CLS]']
     unk_id = vocab['[UNK]']
@@ -358,7 +360,7 @@ if __name__ == '__main__':
             from_dir=warm_start_dir
         )
 
-        best_exporter = propeller.train.exporter.BestExporter(os.path.join(run_config.model_dir, 'best'), cmp_fn=lambda old, new: new['dev']['f1'] > old['dev']['f1'])
+        best_exporter = propeller.train.exporter.BestInferenceModelExporter(os.path.join(run_config.model_dir, 'best'), cmp_fn=lambda old, new: new['dev']['f1'] > old['dev']['f1'])
         propeller.train.train_and_eval(
             model_class_or_model_fn=SequenceLabelErnieModel,
             params=hparams,
@@ -387,7 +389,6 @@ if __name__ == '__main__':
         predict_ds.data_shapes = shapes
         predict_ds.data_types = types
 
         rev_label_map = {i: v for i, v in enumerate(label_list)}
-        best_exporter = propeller.train.exporter.BestExporter(os.path.join(run_config.model_dir, 'best'), cmp_fn=lambda old, new: new['dev']['f1'] > old['dev']['f1'])
         learner = propeller.Learner(SequenceLabelErnieModel, run_config, hparams)
         for pred, _ in learner.predict(predict_ds, ckpt=-1):
             pred_str = ' '.join([rev_label_map[idx] for idx in np.argmax(pred, 1).tolist()])
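For orientation between the two example scripts: the NER path above now builds its vocab directly from `vocab.txt` and relies on `CharTokenizer` with `sentencepiece_style_vocab=True` to keep char-granularity input compatible with ERNIE tiny's subword vocab. A small sketch of that wiring; the `vocab.txt` path is a placeholder, and it assumes the working directory is `ernie/` so that `utils.data` imports:

```python
import utils.data  # resolves when run from the ernie/ directory, as the example scripts assume

# Vocab layout as in finetune_ner.py: token in the first tab-separated field, id = line number.
vocab = {line.strip().split('\t')[0]: idx
         for idx, line in enumerate(open('vocab.txt', encoding='utf8'))}

# finetune_ner.py sets this flag from --use_sentence_piece_vocab.
tokenizer = utils.data.CharTokenizer(vocab, sentencepiece_style_vocab=True)

tokens = tokenizer(u'百度')  # single-char pieces, so they stay aligned with BIO labels
token_ids = [vocab[t] for t in ['[CLS]'] + tokens + ['[SEP]']]
```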
diff --git a/example/finetune_ranker.py b/example/finetune_ranker.py
index db40b26a56e150331f9fff35daf49e582d6b69f1..bb0661ece976cd1e094499bf13734ee3fb43b1fe 100644
--- a/example/finetune_ranker.py
+++ b/example/finetune_ranker.py
@@ -146,6 +146,7 @@ if __name__ == '__main__':
     parser.add_argument('--data_dir', type=str, required=True)
     parser.add_argument('--warm_start_from', type=str)
     parser.add_argument('--sentence_piece_model', type=str, default=None)
+    parser.add_argument('--word_dict', type=str, default=None)
     args = parser.parse_args()
     run_config = propeller.parse_runconfig(args)
     hparams = propeller.parse_hparam(args)
@@ -157,7 +158,9 @@ if __name__ == '__main__':
     unk_id = vocab['[UNK]']
 
     if args.sentence_piece_model is not None:
-        tokenizer = utils.data.JBSPTokenizer(args.sentence_piece_model, jb=True, lower=True)
+        if args.word_dict is None:
+            raise ValueError('--word_dict not specified when using a sentence piece model')
+        tokenizer = utils.data.WSSPTokenizer(args.sentence_piece_model, args.word_dict, ws=True, lower=True)
     else:
         tokenizer = utils.data.CharTokenizer(vocab.keys())
@@ -218,7 +221,7 @@ if __name__ == '__main__':
             from_dir=warm_start_dir
         )
 
-        best_exporter = propeller.train.exporter.BestExporter(os.path.join(run_config.model_dir, 'best'), cmp_fn=lambda old, new: new['dev']['f1'] > old['dev']['f1'])
+        best_exporter = propeller.train.exporter.BestInferenceModelExporter(os.path.join(run_config.model_dir, 'best'), cmp_fn=lambda old, new: new['dev']['f1'] > old['dev']['f1'])
         propeller.train_and_eval(
             model_class_or_model_fn=RankingErnieModel,
             params=hparams,
@@ -258,6 +261,7 @@ if __name__ == '__main__':
         est = propeller.Learner(RankingErnieModel, run_config, hparams)
         for qid, res in est.predict(predict_ds, ckpt=-1):
             print('%d\t%d\t%.5f\t%.5f' % (qid[0], np.argmax(res), res[0], res[1]))
+
         #for i in predict_ds:
         #    sen = i[0]
         #    for ss in np.squeeze(sen):
diff --git a/requirements.txt b/requirements.txt
index 84aaf34f1604759b726911842791613ea3601ea4..2e08a150940fd1d877d0da9fd1a023ce31afd8f4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,3 +5,5 @@ scikit-learn==0.20.3
 scipy==1.2.1
 six==1.11.0
 sklearn==0.0
+sentencepiece==0.1.8
+paddlepaddle-gpu==1.5.2.post107
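To close, a quick smoke test for the two new pins in `requirements.txt` and the `PYTHONPATH` advice from FAQ6; a minimal sketch, assuming only the pip distribution names pinned above and a repo-root working directory:

```python
# Run from the repo root after `pip install -r requirements.txt` and
# `export PYTHONPATH=./:$PYTHONPATH` (see FAQ6); verifies the new dependencies load.
import pkg_resources

for pkg in ('sentencepiece', 'paddlepaddle-gpu'):
    # expect 0.1.8 and 1.5.2.post107 per requirements.txt
    print(pkg, pkg_resources.get_distribution(pkg).version)

import sentencepiece as spm
import propeller  # provided by this repo via PYTHONPATH, not by pip

sp = spm.SentencePieceProcessor()  # constructing a processor exercises the native binding
```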