diff --git a/demo/sequence_labeling/README.md b/demo/sequence_labeling/README.md new file mode 100644 index 0000000000000000000000000000000000000000..fda17c32c7d4edf234392659545a57998fb0dfd4 --- /dev/null +++ b/demo/sequence_labeling/README.md @@ -0,0 +1,156 @@ +# PaddleHub Transformer模型fine-tune序列标注(动态图) + +在2017年之前,工业界和学术界对NLP文本处理依赖于序列模型[Recurrent Neural Network (RNN)](https://baike.baidu.com/item/%E5%BE%AA%E7%8E%AF%E7%A5%9E%E7%BB%8F%E7%BD%91%E7%BB%9C/23199490?fromtitle=RNN&fromid=5707183&fr=aladdin). + +![](http://colah.github.io/posts/2015-09-NN-Types-FP/img/RNN-general.png) + +近年来随着深度学习的发展,模型参数数量飞速增长,为了训练这些参数,需要更大的数据集来避免过拟合。然而,对于大部分NLP任务来说,构建大规模的标注数据集成本过高,非常困难,特别是对于句法和语义相关的任务。相比之下,大规模的未标注语料库的构建则相对容易。最近的研究表明,基于大规模未标注语料库的预训练模型(Pretrained Models, PTM) 能够习得通用的语言表示,将预训练模型Fine-tune到下游任务,能够获得出色的表现。另外,预训练模型能够避免从零开始训练模型。 + +![](https://ai-studio-static-online.cdn.bcebos.com/327f44ff3ed24493adca5ddc4dc24bf61eebe67c84a6492f872406f464fde91e) + + +本示例将展示如何使用PaddleHub Transformer模型(如 ERNIE、BERT、RoBERTa等模型)Module 以动态图方式fine-tune并完成预测任务。 + +## 如何开始Fine-tune + + +我们以微软亚洲研究院发布的中文实体识别数据集MSRA-NER为示例数据集,可以运行下面的命令,在训练集(train.tsv)上进行模型训练,并在开发集(dev.tsv)验证。通过如下命令,即可启动训练。 + +```shell +# 设置使用的GPU卡号 +export CUDA_VISIBLE_DEVICES=0 +python train.py +``` + + +## 代码步骤 + +使用PaddleHub Fine-tune API进行Fine-tune可以分为4个步骤。 + +### Step1: 选择模型 +```python +import paddlehub as hub + +model = hub.Module(name='ernie_tiny', version='2.0.1', task='token-cls') +``` + +其中,参数: + +* `name`:模型名称,可以选择`ernie`,`ernie_tiny`,`bert-base-cased`, `bert-base-chinese`, `roberta-wwm-ext`,`roberta-wwm-ext-large`等。 +* `version`:module版本号 +* `task`:fine-tune任务。此处为`token-cls`,表示序列标注任务。 + +通过以上的一行代码,`model`初始化为一个适用于序列标注任务的模型,为ERNIE Tiny的预训练模型后拼接上一个输出token共享的全连接网络(Full Connected)。 +![](https://ss1.bdstatic.com/70cFuXSh_Q1YnxGkpoWK1HF6hhy/it/u=224484727,3049769188&fm=15&gp=0.jpg) + +以上图片来自于:https://arxiv.org/pdf/1810.04805.pdf + +### Step2: 下载并加载数据集 + +```python +train_dataset = hub.datasets.MSRA_NER( + tokenizer=model.get_tokenizer(tokenize_chinese_chars=True), max_seq_len=50, mode='train') +dev_dataset = hub.datasets.MSRA_NER( + tokenizer=model.get_tokenizer(tokenize_chinese_chars=True), max_seq_len=50, mode='dev') +``` + +* `tokenizer`:表示该module所需用到的tokenizer,其将对输入文本完成切词,并转化成module运行所需模型输入格式。 +* `mode`:选择数据模式,可选项有 `train`, `test`, `val`, 默认为`train`。 +* `max_seq_len`:ERNIE/BERT模型使用的最大序列长度,若出现显存不足,请适当调低这一参数。 + +预训练模型ERNIE对中文数据的处理是以字为单位,tokenizer作用为将原始输入文本转化成模型model可以接受的输入数据形式。 PaddleHub 2.0中的各种预训练模型已经内置了相应的tokenizer,可以通过`model.get_tokenizer`方法获取。 + +![](https://bj.bcebos.com/paddlehub/paddlehub-img/ernie_network_1.png) +![](https://bj.bcebos.com/paddlehub/paddlehub-img/ernie_network_2.png) + +### Step3: 选择优化策略和运行配置 + +```python +optimizer = paddle.optimizer.AdamW(learning_rate=5e-5, parameters=model.parameters()) +trainer = hub.Trainer(model, optimizer, checkpoint_dir='test_ernie_token_cls', use_gpu=False) + +trainer.train(train_dataset, epochs=3, batch_size=32, eval_dataset=dev_dataset) + +# 在测试集上评估当前训练模型 +trainer.evaluate(test_dataset, batch_size=32) +``` + +#### 优化策略 + +Paddle2.0-rc提供了多种优化器选择,如`SGD`, `Adam`, `Adamax`, `AdamW`等,详细参见[策略](https://www.paddlepaddle.org.cn/documentation/docs/zh/2.0-rc/api/paddle/optimizer/optimizer/Optimizer_cn.html)。 + +其中`AdamW`: + +* `learning_rate`: 全局学习率。默认为1e-3; +* `parameters`: 待优化模型参数。 + +#### 运行配置 + +`Trainer` 主要控制Fine-tune的训练,包含以下可控制的参数: + +* `model`: 被优化模型; +* `optimizer`: 优化器选择; +* `use_gpu`: 是否使用GPU训练,默认为False; +* `use_vdl`: 是否使用vdl可视化训练过程; +* `checkpoint_dir`: 保存模型参数的地址; +* `compare_metrics`: 保存最优模型的衡量指标; + +`trainer.train` 主要控制具体的训练过程,包含以下可控制的参数: + +* `train_dataset`: 训练时所用的数据集; +* `epochs`: 训练轮数; +* `batch_size`: 训练的批大小,如果使用GPU,请根据实际情况调整batch_size; +* `num_workers`: workers的数量,默认为0; +* `eval_dataset`: 验证集; +* `log_interval`: 打印日志的间隔, 单位为执行批训练的次数。 +* `save_interval`: 保存模型的间隔频次,单位为执行训练的轮数。 + +## 模型预测 + +当完成Fine-tune后,Fine-tune过程在验证集上表现最优的模型会被保存在`${CHECKPOINT_DIR}/best_model`目录下,其中`${CHECKPOINT_DIR}`目录为Fine-tune时所选择的保存checkpoint的目录。 + +我们以以下数据为待预测数据,使用该模型来进行预测 + +```text +去年十二月二十四日,市委书记张敬涛召集县市主要负责同志研究信访工作时,提出三问:『假如上访群众是我们的父母姐妹,你会用什么样的感情对待他们? +新华社北京5月7日电国务院副总理李岚清今天在中南海会见了美国前商务部长芭芭拉·弗兰克林。 +根据测算,海卫1表面温度已经从“旅行者”号探测器1989年造访时的零下236摄氏度上升到零下234摄氏度。 +华裔作家韩素音女士曾三次到大足,称“大足石窟是一座未被开发的金矿”。 +``` + +```python +import paddlehub as hub + +split_char = "\002" +label_list = ["B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "O"] +text_a = [ + '去年十二月二十四日,市委书记张敬涛召集县市主要负责同志研究信访工作时,提出三问:『假如上访群众是我们的父母姐妹,你会用什么样的感情对待他们?', + '新华社北京5月7日电国务院副总理李岚清今天在中南海会见了美国前商务部长芭芭拉·弗兰克林。', + '根据测算,海卫1表面温度已经从“旅行者”号探测器1989年造访时的零下236摄氏度上升到零下234摄氏度。', + '华裔作家韩素音女士曾三次到大足,称“大足石窟是一座未被开发的金矿”。', +] +data = [[split_char.join(text)] for text in text_a] +label_map = { + idx: label for idx, label in enumerate(label_list) +} + +model = hub.Module( + name='ernie_tiny', + version='2.0.1', + task='token_cls', + load_checkpoint='./token_cls_save_dir/best_model/model.pdparams', + label_map=label_map, +) + +results = model.predict(data, max_seq_len=50, batch_size=1, use_gpu=False) +for idx, text in enumerate(text_a): + print(f'Data: {text} \t Lable: {", ".join(results[idx][1:len(text)+1])}') +``` + +参数配置正确后,请执行脚本`python predict.py`, 加载模型具体可参见[加载](https://www.paddlepaddle.org.cn/documentation/docs/zh/2.0-rc/api/paddle/framework/io/load_cn.html#load)。 + +### 依赖 + +paddlepaddle >= 2.0.0rc + +paddlehub >= 2.0.0 diff --git a/demo/sequence_labeling/predict.py b/demo/sequence_labeling/predict.py new file mode 100644 index 0000000000000000000000000000000000000000..13a5e5059449c29d3371c9ef52ca55fbc6901f6d --- /dev/null +++ b/demo/sequence_labeling/predict.py @@ -0,0 +1,42 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddlehub as hub + +if __name__ == '__main__': + split_char = "\002" + label_list = ["B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "O"] + text_a = [ + '去年十二月二十四日,市委书记张敬涛召集县市主要负责同志研究信访工作时,提出三问:『假如上访群众是我们的父母姐妹,你会用什么样的感情对待他们?', + '新华社北京5月7日电国务院副总理李岚清今天在中南海会见了美国前商务部长芭芭拉·弗兰克林。', + '根据测算,海卫1表面温度已经从“旅行者”号探测器1989年造访时的零下236摄氏度上升到零下234摄氏度。', + '华裔作家韩素音女士曾三次到大足,称“大足石窟是一座未被开发的金矿”。', + ] + data = [[split_char.join(text)] for text in text_a] + label_map = { + idx: label for idx, label in enumerate(label_list) + } + + model = hub.Module( + name='ernie_tiny', + version='2.0.1', + task='token-cls', + load_checkpoint='./token_cls_save_dir/best/model.pdparams', + label_map=label_map, + ) + + results = model.predict(data=data, max_seq_len=128, batch_size=1, use_gpu=True) + for idx, text in enumerate(text_a): + print(f'Text:\n{text} \nLable: \n{", ".join(results[idx][1:len(text)+1])} \n') + diff --git a/demo/sequence_labeling/train.py b/demo/sequence_labeling/train.py new file mode 100644 index 0000000000000000000000000000000000000000..43a81fb48c6f595cf43102a26010e96ee20fcfff --- /dev/null +++ b/demo/sequence_labeling/train.py @@ -0,0 +1,45 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddlehub as hub + +if __name__ == '__main__': + label_list = ["B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "O"] + label_map = { + idx: label for idx, label in enumerate(label_list) + } + model = hub.Module( + name='ernie_tiny', + version='2.0.1', + task='token-cls', + label_map=label_map, + ) + + train_dataset = hub.datasets.MSRA_NER( + tokenizer=model.get_tokenizer(), + max_seq_len=128, + mode='train' + ) + + dev_dataset = hub.datasets.MSRA_NER( + tokenizer=model.get_tokenizer(), + max_seq_len=50, + mode='dev' + ) + + optimizer = paddle.optimizer.AdamW(learning_rate=5e-5, parameters=model.parameters()) + trainer = hub.Trainer(model, optimizer, checkpoint_dir='token_cls_save_dir', use_gpu=True) + + trainer.train(train_dataset, epochs=3, batch_size=32, eval_dataset=dev_dataset, save_interval=1) diff --git a/demo/text_classification/README.md b/demo/text_classification/README.md index 9f52bb97d137e42881d71b8f090e853ffef0d6b9..6d75757d42c752f79640925648a2ce189c4c1454 100644 --- a/demo/text_classification/README.md +++ b/demo/text_classification/README.md @@ -1,6 +1,15 @@ # PaddleHub Transformer模型fine-tune文本分类(动态图) -本示例将展示如何使用PaddleHub Transformer模型(如 ERNIE、BERT、RoBERTa等模型)module 以动态图方式fine-tune并完成预测任务。 +在2017年之前,工业界和学术界对NLP文本处理依赖于序列模型[Recurrent Neural Network (RNN)](https://baike.baidu.com/item/%E5%BE%AA%E7%8E%AF%E7%A5%9E%E7%BB%8F%E7%BD%91%E7%BB%9C/23199490?fromtitle=RNN&fromid=5707183&fr=aladdin). + +![](http://colah.github.io/posts/2015-09-NN-Types-FP/img/RNN-general.png) + +近年来随着深度学习的发展,模型参数数量飞速增长,为了训练这些参数,需要更大的数据集来避免过拟合。然而,对于大部分NLP任务来说,构建大规模的标注数据集成本过高,非常困难,特别是对于句法和语义相关的任务。相比之下,大规模的未标注语料库的构建则相对容易。最近的研究表明,基于大规模未标注语料库的预训练模型(Pretrained Models, PTM) 能够习得通用的语言表示,将预训练模型Fine-tune到下游任务,能够获得出色的表现。另外,预训练模型能够避免从零开始训练模型。 + +![](https://ai-studio-static-online.cdn.bcebos.com/327f44ff3ed24493adca5ddc4dc24bf61eebe67c84a6492f872406f464fde91e) + + +本示例将展示如何使用PaddleHub Transformer模型(如 ERNIE、BERT、RoBERTa等模型)Module 以动态图方式fine-tune并完成预测任务。 ## 如何开始Fine-tune @@ -22,14 +31,19 @@ python train.py ```python import paddlehub as hub -model = hub.Module(name='ernie_tiny', version='2.0.0', task='sequence_classification') +model = hub.Module(name='ernie_tiny', version='2.0.1', task='seq-cls') ``` 其中,参数: -* `name`:模型名称,可以选择`ernie`,`ernie-tiny`,`bert_chinese_L-12_H-768_A-12`,`chinese-roberta-wwm-ext`,`chinese-roberta-wwm-ext-large`等。 +* `name`:模型名称,可以选择`ernie`,`ernie_tiny`,`bert-base-cased`, `bert-base-chinese`, `roberta-wwm-ext`,`roberta-wwm-ext-large`等。 * `version`:module版本号 -* `task`:fine-tune任务。此处为`sequence_classification`,表示文本分类任务。 +* `task`:fine-tune任务。此处为`seq-cls`,表示文本分类任务。 + +通过以上的一行代码,`model`初始化为一个适用于文本分类任务的模型,为ERNIE Tiny的预训练模型后拼接上一个全连接网络(Full Connected)。 +![](https://ai-studio-static-online.cdn.bcebos.com/f9e1bf9d56c6412d939960f2e3767c2f13b93eab30554d738b137ab2b98e328c) + +以上图片来自于:https://arxiv.org/pdf/1810.04805.pdf ### Step2: 下载并加载数据集 @@ -44,6 +58,11 @@ dev_dataset = hub.datasets.ChnSentiCorp( * `mode`:选择数据模式,可选项有 `train`, `test`, `val`, 默认为`train`。 * `max_seq_len`:ERNIE/BERT模型使用的最大序列长度,若出现显存不足,请适当调低这一参数。 +预训练模型ERNIE对中文数据的处理是以字为单位,tokenizer作用为将原始输入文本转化成模型model可以接受的输入数据形式。 PaddleHub 2.0中的各种预训练模型已经内置了相应的tokenizer,可以通过`model.get_tokenizer`方法获取。 + +![](https://bj.bcebos.com/paddlehub/paddlehub-img/ernie_network_1.png) +![](https://bj.bcebos.com/paddlehub/paddlehub-img/ernie_network_2.png) + ### Step3: 选择优化策略和运行配置 ```python @@ -110,7 +129,7 @@ label_map = {0: 'negative', 1: 'positive'} model = hub.Module( directory='/mnt/zhangxuefei/program-paddle/PaddleHub/modules/text/language_model/ernie_tiny', version='2.0.0', - task='sequence_classification', + task='seq-cls', load_checkpoint='./test_ernie_text_cls/best_model/model.pdparams', label_map=label_map) results = model.predict(data, max_seq_len=50, batch_size=1, use_gpu=False) diff --git a/demo/text_classification/predict.py b/demo/text_classification/predict.py index d70e9b37b7ed2eb7db294e390fcab42f748450d6..ad7721cafcc5292e02b3329eaa5510ec829b9db0 100644 --- a/demo/text_classification/predict.py +++ b/demo/text_classification/predict.py @@ -24,8 +24,8 @@ if __name__ == '__main__': model = hub.Module( name='ernie_tiny', - version='2.0.0', - task='sequence_classification', + version='2.0.1', + task='seq-cls', load_checkpoint='./test_ernie_text_cls/best_model/model.pdparams', label_map=label_map) results = model.predict(data, max_seq_len=50, batch_size=1, use_gpu=False) diff --git a/demo/text_classification/train.py b/demo/text_classification/train.py index 1f2084f362cd6be86e2f0d7de1efe9bae6ed9bd0..3f1ec858b2adaca50238bec810cc3e75dedad51b 100644 --- a/demo/text_classification/train.py +++ b/demo/text_classification/train.py @@ -15,7 +15,7 @@ import paddle import paddlehub as hub if __name__ == '__main__': - model = hub.Module(name='ernie_tiny', version='2.0.0', task='sequence_classification') + model = hub.Module(name='ernie_tiny', version='2.0.1', task='seq-cls') train_dataset = hub.datasets.ChnSentiCorp( tokenizer=model.get_tokenizer(tokenize_chinese_chars=True), max_seq_len=128, mode='train') diff --git a/modules/text/language_model/ernie_tiny/README.md b/modules/text/language_model/ernie_tiny/README.md index 4dd0ef322388cd870cea245ef17ca97945f72953..899c500e9a826ee892f05af7a973a6f3e659e286 100644 --- a/modules/text/language_model/ernie_tiny/README.md +++ b/modules/text/language_model/ernie_tiny/README.md @@ -1,5 +1,5 @@ ```shell -$ hub install ernie_tiny==2.0.0 +$ hub install ernie_tiny==2.0.1 ``` ## 在线体验 AI Studio 快速体验 @@ -30,7 +30,7 @@ def __init__( **参数** -* `task`: 任务名称,可为`sequence_classification`。 +* `task`: 任务名称,可为`seq-cls`(文本分类任务,原来的`sequence_classification`在未来会被弃用)或`token-cls`(序列标注任务)。 * `load_checkpoint`:使用PaddleHub Fine-tune api训练保存的模型参数文件路径。 * `label_map`:预测时的类别映射表。 @@ -77,15 +77,15 @@ def get_embedding( import paddlehub as hub data = [ - '这个宾馆比较陈旧了,特价的房间也很一般。总体来说一般', - '怀着十分激动的心情放映,可是看着看着发现,在放映完毕后,出现一集米老鼠的动画片', - '作为老的四星酒店,房间依然很整洁,相当不错。机场接机服务很好,可以在车上办理入住手续,节省时间。', + ['这个宾馆比较陈旧了,特价的房间也很一般。总体来说一般'], + ['怀着十分激动的心情放映,可是看着看着发现,在放映完毕后,出现一集米老鼠的动画片'], + ['作为老的四星酒店,房间依然很整洁,相当不错。机场接机服务很好,可以在车上办理入住手续,节省时间。'], ] label_map = {0: 'negative', 1: 'positive'} model = hub.Module( name='ernie_tiny', - version='2.0.0', + version='2.0.1', task='sequence_classification', load_checkpoint='/path/to/parameters', label_map=label_map) @@ -161,3 +161,7 @@ paddlehub >= 2.0.0 * 2.0.0 全面升级动态图版本,接口有所变化 + +* 2.0.1 + + 任务名称调整,增加序列标注任务`token-cls` diff --git a/modules/text/language_model/ernie_tiny/module.py b/modules/text/language_model/ernie_tiny/module.py index a05f764a1c899c0bf7b2bebc868e03895c162719..77b954c8ace43a247427d147be64e760576b0929 100644 --- a/modules/text/language_model/ernie_tiny/module.py +++ b/modules/text/language_model/ernie_tiny/module.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Optional, Union, Tuple import os from paddle.dataset.common import DATA_HOME @@ -20,19 +19,21 @@ import paddle.nn as nn import paddle.nn.functional as F from paddlehub import ErnieTinyTokenizer -from paddlehub.module.modeling_ernie import ErnieModel, ErnieForSequenceClassification -from paddlehub.module.module import moduleinfo, serving +from paddlenlp.transformers.ernie.modeling import ErnieModel, ErnieForSequenceClassification, ErnieForTokenClassification +from paddlehub.module.module import moduleinfo +from paddlehub.module.nlp_module import TransformerModule from paddlehub.utils.log import logger from paddlehub.utils.utils import download @moduleinfo( name="ernie_tiny", - version="2.0.0", + version="2.0.1", summary="Baidu's ERNIE-tiny, Enhanced Representation through kNowledge IntEgration, tiny version, max_seq_len=512", author="paddlepaddle", author_email="", - type="nlp/semantic_model") + type="nlp/semantic_model", + meta=TransformerModule) class ErnieTiny(nn.Layer): """ Ernie model @@ -43,17 +44,34 @@ class ErnieTiny(nn.Layer): task=None, load_checkpoint=None, label_map=None, + num_classes=2, + **kwargs, ): super(ErnieTiny, self).__init__() - # TODO(zhangxuefei): add token_classification task + if label_map: + self.num_classes = len(label_map) + else: + self.num_classes = num_classes + if task == 'sequence_classification': - self.model = ErnieForSequenceClassification.from_pretrained(pretrained_model_name_or_path='ernie_tiny') + task = 'seq-cls' + logger.warning( + "current task name 'sequence_classification' was renamed to 'seq-cls', " + "'sequence_classification' has been deprecated and will be removed the future.", + ) + if task == 'seq-cls': + self.model = ErnieForSequenceClassification.from_pretrained(pretrained_model_name_or_path='ernie-tiny', num_classes=self.num_classes, **kwargs) + self.criterion = paddle.nn.loss.CrossEntropyLoss() + self.metric = paddle.metric.Accuracy() + elif task == 'token-cls': + self.model = ErnieForTokenClassification.from_pretrained(pretrained_model_name_or_path='ernie-tiny', num_classes=self.num_classes, **kwargs) self.criterion = paddle.nn.loss.CrossEntropyLoss() - self.metric = paddle.metric.Accuracy(name='acc_accumulation') + self.metric = paddle.metric.Accuracy() elif task is None: - self.model = ErnieModel.from_pretrained(pretrained_model_name_or_path='ernie_tiny') + self.model = ErnieModel.from_pretrained(pretrained_model_name_or_path='ernie-tiny', **kwargs) else: - raise RuntimeError("Unknown task %s, task should be sequence_classification" % task) + raise RuntimeError("Unknown task {}, task should be one in {}".format( + task, self._tasks_supported)) self.task = task self.label_map = label_map @@ -65,7 +83,7 @@ class ErnieTiny(nn.Layer): def forward(self, input_ids, token_type_ids=None, position_ids=None, attention_mask=None, labels=None): result = self.model(input_ids, token_type_ids, position_ids, attention_mask) - if self.task is not None: + if self.task == 'seq-cls': logits = result probs = F.softmax(logits, axis=1) if labels is not None: @@ -74,6 +92,16 @@ class ErnieTiny(nn.Layer): acc = self.metric.update(correct) return probs, loss, acc return probs + elif self.task == 'token-cls': + logits = result + token_level_probs = F.softmax(logits, axis=2) + if labels is not None: + labels = paddle.to_tensor(labels).unsqueeze(-1) + loss = self.criterion(logits, labels) + correct = self.metric.compute(token_level_probs, labels) + acc = self.metric.update(correct) + return token_level_probs, loss, acc + return token_level_probs else: sequence_output, pooled_output = result return sequence_output, pooled_output @@ -108,122 +136,3 @@ class ErnieTiny(nn.Layer): download(url, os.path.join(DATA_HOME, 'ernie_tiny')) return ErnieTinyTokenizer(self.get_vocab_path(), spm_path, word_dict_path) - - def training_step(self, batch: List[paddle.Tensor], batch_idx: int): - """ - One step for training, which should be called as forward computation. - Args: - batch(:obj:List[paddle.Tensor]): The one batch data, which contains the model needed, - such as input_ids, sent_ids, pos_ids, input_mask and labels. - batch_idx(int): The index of batch. - Returns: - results(:obj: Dict) : The model outputs, such as loss and metrics. - """ - predictions, avg_loss, acc = self(input_ids=batch[0], token_type_ids=batch[1], labels=batch[2]) - return {'loss': avg_loss, 'metrics': {'acc': acc}} - - def validation_step(self, batch: List[paddle.Tensor], batch_idx: int): - """ - One step for validation, which should be called as forward computation. - Args: - batch(:obj:List[paddle.Tensor]): The one batch data, which contains the model needed, - such as input_ids, sent_ids, pos_ids, input_mask and labels. - batch_idx(int): The index of batch. - Returns: - results(:obj: Dict) : The model outputs, such as metrics. - """ - predictions, avg_loss, acc = self(input_ids=batch[0], token_type_ids=batch[1], labels=batch[2]) - return {'metrics': {'acc': acc}} - - def predict(self, data, max_seq_len=128, batch_size=1, use_gpu=False): - """ - Predicts the data labels. - - Args: - data (obj:`List(str)`): The processed data whose each element is the raw text. - max_seq_len (:obj:`int`, `optional`, defaults to :int:`None`): - If set to a number, will limit the total sequence returned so that it has a maximum length. - batch_size(obj:`int`, defaults to 1): The number of batch. - use_gpu(obj:`bool`, defaults to `False`): Whether to use gpu to run or not. - - Returns: - results(obj:`list`): All the predictions labels. - """ - # TODO(zhangxuefei): add task token_classification task predict. - if self.task not in ['sequence_classification']: - raise RuntimeError("The predict method is for sequence_classification task, but got task %s." % self.task) - - paddle.set_device('gpu') if use_gpu else paddle.set_device('cpu') - tokenizer = self.get_tokenizer() - - examples = [] - for text in data: - if len(text) == 1: - encoded_inputs = tokenizer.encode(text[0], text_pair=None, max_seq_len=max_seq_len) - elif len(text) == 2: - encoded_inputs = tokenizer.encode(text[0], text_pair=text[1], max_seq_len=max_seq_len) - else: - raise RuntimeError( - 'The input text must have one or two sequence, but got %d. Please check your inputs.' % len(text)) - examples.append((encoded_inputs['input_ids'], encoded_inputs['segment_ids'])) - - def _batchify_fn(batch): - input_ids = [entry[0] for entry in batch] - segment_ids = [entry[1] for entry in batch] - return input_ids, segment_ids - - # Seperates data into some batches. - batches = [] - one_batch = [] - for example in examples: - one_batch.append(example) - if len(one_batch) == batch_size: - batches.append(one_batch) - one_batch = [] - if one_batch: - # The last batch whose size is less than the config batch_size setting. - batches.append(one_batch) - - results = [] - self.eval() - for batch in batches: - input_ids, segment_ids = _batchify_fn(batch) - input_ids = paddle.to_tensor(input_ids) - segment_ids = paddle.to_tensor(segment_ids) - - # TODO(zhangxuefei): add task token_classification postprocess after prediction. - if self.task == 'sequence_classification': - probs = self(input_ids, segment_ids) - idx = paddle.argmax(probs, axis=1).numpy() - idx = idx.tolist() - labels = [self.label_map[i] for i in idx] - results.extend(labels) - - return results - - @serving - def get_embedding(self, texts, use_gpu=False): - if self.task is not None: - raise RuntimeError("The get_embedding method is only valid when task is None, but got task %s" % self.task) - - paddle.set_device('gpu') if use_gpu else paddle.set_device('cpu') - - tokenizer = self.get_tokenizer() - results = [] - for text in texts: - if len(text) == 1: - encoded_inputs = tokenizer.encode(text[0], text_pair=None, pad_to_max_seq_len=False) - elif len(text) == 2: - encoded_inputs = tokenizer.encode(text[0], text_pair=text[1], pad_to_max_seq_len=False) - else: - raise RuntimeError( - 'The input text must have one or two sequence, but got %d. Please check your inputs.' % len(text)) - - input_ids = paddle.to_tensor(encoded_inputs['input_ids']).unsqueeze(0) - segment_ids = paddle.to_tensor(encoded_inputs['segment_ids']).unsqueeze(0) - sequence_output, pooled_output = self(input_ids, segment_ids) - - sequence_output = sequence_output.squeeze(0) - pooled_output = pooled_output.squeeze(0) - results.append((sequence_output.numpy().tolist(), pooled_output.numpy().tolist())) - return results diff --git a/paddlehub/datasets/__init__.py b/paddlehub/datasets/__init__.py index d88b816f1729916aac86b88fbacfdfa8ec39ac14..12447233a967204773d59bb19e6088d0bb47035a 100644 --- a/paddlehub/datasets/__init__.py +++ b/paddlehub/datasets/__init__.py @@ -16,3 +16,4 @@ from paddlehub.datasets.canvas import Canvas from paddlehub.datasets.flowers import Flowers from paddlehub.datasets.minicoco import MiniCOCO from paddlehub.datasets.chnsenticorp import ChnSentiCorp +from paddlehub.datasets.msra_ner import MSRA_NER diff --git a/paddlehub/datasets/base_nlp_dataset.py b/paddlehub/datasets/base_nlp_dataset.py index 1259cabec5def11869604459156f083331fe3238..acca7b8c5e8393f6fe03214a3d7a557d4670d0df 100644 --- a/paddlehub/datasets/base_nlp_dataset.py +++ b/paddlehub/datasets/base_nlp_dataset.py @@ -17,7 +17,7 @@ import io import os import numpy as np -import paddle.fluid as fluid +import paddle from paddlehub.env import DATA_HOME from paddlehub.text.bert_tokenizer import BertTokenizer @@ -152,7 +152,7 @@ class BaseNLPDataset(object): return self.label_list -class TextClassificationDataset(BaseNLPDataset, fluid.io.Dataset): +class TextClassificationDataset(BaseNLPDataset, paddle.io.Dataset): """ The dataset class which is fit for all datatset of text classification. """ @@ -258,3 +258,138 @@ class TextClassificationDataset(BaseNLPDataset, fluid.io.Dataset): def __len__(self): return len(self.records) + + +class SeqLabelingDataset(BaseNLPDataset, paddle.io.Dataset): + def __init__(self, + base_path: str, + tokenizer: Union[BertTokenizer, CustomTokenizer], + max_seq_len: int = 128, + mode: str = "train", + data_file: str = None, + label_file: str = None, + label_list: list = None, + split_char="\002", + no_entity_label="O", + is_file_with_header: bool = False): + super(SeqLabelingDataset, self).__init__( + base_path=base_path, + tokenizer=tokenizer, + max_seq_len=max_seq_len, + mode=mode, + data_file=data_file, + label_file=label_file, + label_list=label_list) + + self.no_entity_label = no_entity_label + self.split_char = split_char + + self.examples = self._read_file(self.data_file, is_file_with_header) + self.records = self._convert_examples_to_records(self.examples) + + def _read_file(self, input_file, is_file_with_header: bool = False) -> List[InputExample]: + """Reads a tab separated value file.""" + if not os.path.exists(input_file): + raise RuntimeError("The file {} is not found.".format(input_file)) + else: + with io.open(input_file, "r", encoding="UTF-8") as f: + reader = csv.reader(f, delimiter="\t", quotechar=None) + examples = [] + seq_id = 0 + header = next(reader) if is_file_with_header else None + for line in reader: + example = InputExample(guid=seq_id, label=line[1], text_a=line[0]) + seq_id += 1 + examples.append(example) + return examples + + def _convert_examples_to_records(self, examples: List[InputExample]) -> List[dict]: + """ + Returns a list[dict] including all the input information what the model need. + Args: + examples (list): the data examples, returned by _read_file. + Returns: + a list with all the examples record. + """ + records = [] + for example in examples: + tokens, labels = self._reseg_token_label( + tokens=example.text_a.split(self.split_char), + labels=example.label.split(self.split_char)) + record = self.tokenizer.encode( + text=tokens, max_seq_len=self.max_seq_len) + # CustomTokenizer will tokenize the text firstly and then lookup words in the vocab + # When all words are not found in the vocab, the text will be dropped. + if not record: + logger.info( + "The text %s has been dropped as it has no words in the vocab after tokenization." + % example.text_a) + continue + if labels: + record["label"] = [] + tokens_with_specical_token = self.tokenizer.decode( + record, only_convert_to_tokens=True) + tokens_index = 0 + for token in tokens_with_specical_token: + if tokens_index < len( + tokens) and token == tokens[tokens_index]: + record["label"].append( + self.label_list.index(labels[tokens_index])) + tokens_index += 1 + else: + record["label"].append( + self.label_list.index(self.no_entity_label)) + records.append(record) + return records + + def _reseg_token_label( + self, tokens: List[str], labels: List[str] = None) -> Tuple[List[str], List[str]] or List[str]: + if labels: + if len(tokens) != len(labels): + raise ValueError( + "The length of tokens must be same with labels") + ret_tokens = [] + ret_labels = [] + for token, label in zip(tokens, labels): + sub_token = self.tokenizer.tokenize(token) + if len(sub_token) == 0: + continue + ret_tokens.extend(sub_token) + ret_labels.append(label) + if len(sub_token) < 2: + continue + sub_label = label + if label.startswith("B-"): + sub_label = "I-" + label[2:] + ret_labels.extend([sub_label] * (len(sub_token) - 1)) + + if len(ret_tokens) != len(ret_labels): + raise ValueError( + "The length of ret_tokens can't match with labels") + return ret_tokens, ret_labels + else: + ret_tokens = [] + for token in tokens: + sub_token = self.tokenizer.tokenize(token) + if len(sub_token) == 0: + continue + ret_tokens.extend(sub_token) + if len(sub_token) < 2: + continue + return ret_tokens, None + + def __getitem__(self, idx): + record = self.records[idx] + if 'label' in record.keys(): + if isinstance(self.tokenizer, BertTokenizer): + return np.array(record['input_ids']), np.array(record['segment_ids']), np.array(record['label']) + else: # TODO(chenxiaojie): add CustomTokenizer supported + raise NotImplementedError + else: + if isinstance(self.tokenizer, BertTokenizer): + return np.array(record['input_ids']), np.array(record['segment_ids']) + else: # TODO(chenxiaojie): add CustomTokenizer supported + raise NotImplementedError + + def __len__(self): + return len(self.records) diff --git a/paddlehub/datasets/msra_ner.py b/paddlehub/datasets/msra_ner.py new file mode 100644 index 0000000000000000000000000000000000000000..c9f4b2e324a75b28058e2bd9af08d1559cbaade0 --- /dev/null +++ b/paddlehub/datasets/msra_ner.py @@ -0,0 +1,51 @@ +#coding:utf-8 +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Union +import os + +from paddlehub.env import DATA_HOME +from paddlehub.utils.download import download_data +from paddlehub.datasets.base_nlp_dataset import SeqLabelingDataset +from paddlehub.text.bert_tokenizer import BertTokenizer +from paddlehub.text.tokenizer import CustomTokenizer + + +@download_data(url="https://bj.bcebos.com/paddlehub-dataset/msra_ner.tar.gz") +class MSRA_NER(SeqLabelingDataset): + """ + A set of manually annotated Chinese word-segmentation data and + specifications for training and testing a Chinese word-segmentation system + for research purposes. For more information please refer to + https://www.microsoft.com/en-us/download/details.aspx?id=52531 + """ + def __init__(self, tokenizer: Union[BertTokenizer, CustomTokenizer], max_seq_len: int = 128, mode: str = 'train'): + base_path = os.path.join(DATA_HOME, "msra_ner") + if mode == 'train': + data_file = 'train.tsv' + elif mode == 'test': + data_file = 'test.tsv' + else: + data_file = 'dev.tsv' + super().__init__( + base_path=base_path, + tokenizer=tokenizer, + max_seq_len=max_seq_len, + mode=mode, + data_file=data_file, + label_file=None, + label_list=["B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "O"], + is_file_with_header=True, + ) diff --git a/paddlehub/module/nlp_module.py b/paddlehub/module/nlp_module.py index c58e83fd8ae7dff3470acb9336eb4440b2a80d72..638be412de603f3758c4e067d1f5d63a7a794b35 100644 --- a/paddlehub/module/nlp_module.py +++ b/paddlehub/module/nlp_module.py @@ -21,17 +21,20 @@ import io import json import os import six +from typing import List import paddle import paddle.nn as nn from paddle.dataset.common import DATA_HOME from paddle.utils.download import get_path_from_url +from paddlehub.module.module import serving, RunModule, runnable from paddlehub.utils.log import logger __all__ = [ 'PretrainedModel', 'register_base_model', + 'TransformerModule', ] @@ -342,3 +345,136 @@ class PretrainedModel(nn.Layer): # save model file_name = os.path.join(save_directory, list(self.resource_files_names.values())[0]) paddle.save(self.state_dict(), file_name) + + +class EmbeddingServing(object): + @serving + def get_embedding(self, texts, use_gpu=False): + if self.task is not None: + raise RuntimeError("The get_embedding method is only valid when task is None, but got task %s" % self.task) + + paddle.set_device('gpu') if use_gpu else paddle.set_device('cpu') + + tokenizer = self.get_tokenizer() + results = [] + for text in texts: + if len(text) == 1: + encoded_inputs = tokenizer.encode(text[0], text_pair=None, pad_to_max_seq_len=False) + elif len(text) == 2: + encoded_inputs = tokenizer.encode(text[0], text_pair=text[1], pad_to_max_seq_len=False) + else: + raise RuntimeError( + 'The input text must have one or two sequence, but got %d. Please check your inputs.' % len(text)) + + input_ids = paddle.to_tensor(encoded_inputs['input_ids']).unsqueeze(0) + segment_ids = paddle.to_tensor(encoded_inputs['segment_ids']).unsqueeze(0) + sequence_output, pooled_output = self(input_ids, segment_ids) + + sequence_output = sequence_output.squeeze(0) + pooled_output = pooled_output.squeeze(0) + results.append((sequence_output.numpy().tolist(), pooled_output.numpy().tolist())) + return results + + +class TransformerModule(RunModule, EmbeddingServing): + _tasks_supported = [ + 'seq-cls', + 'token-cls', + ] + + def training_step(self, batch: List[paddle.Tensor], batch_idx: int): + """ + One step for training, which should be called as forward computation. + Args: + batch(:obj:List[paddle.Tensor]): The one batch data, which contains the model needed, + such as input_ids, sent_ids, pos_ids, input_mask and labels. + batch_idx(int): The index of batch. + Returns: + results(:obj: Dict) : The model outputs, such as loss and metrics. + """ + predictions, avg_loss, acc = self(input_ids=batch[0], token_type_ids=batch[1], labels=batch[2]) + return {'loss': avg_loss, 'metrics': {'acc': acc}} + + def validation_step(self, batch: List[paddle.Tensor], batch_idx: int): + """ + One step for validation, which should be called as forward computation. + Args: + batch(:obj:List[paddle.Tensor]): The one batch data, which contains the model needed, + such as input_ids, sent_ids, pos_ids, input_mask and labels. + batch_idx(int): The index of batch. + Returns: + results(:obj: Dict) : The model outputs, such as metrics. + """ + predictions, avg_loss, acc = self(input_ids=batch[0], token_type_ids=batch[1], labels=batch[2]) + return {'metrics': {'acc': acc}} + + def predict(self, data, max_seq_len=128, batch_size=1, use_gpu=False): + """ + Predicts the data labels. + + Args: + data (obj:`List(str)`): The processed data whose each element is the raw text. + max_seq_len (:obj:`int`, `optional`, defaults to :int:`None`): + If set to a number, will limit the total sequence returned so that it has a maximum length. + batch_size(obj:`int`, defaults to 1): The number of batch. + use_gpu(obj:`bool`, defaults to `False`): Whether to use gpu to run or not. + + Returns: + results(obj:`list`): All the predictions labels. + """ + if self.task not in self._tasks_supported: + raise RuntimeError("The predict method supports task in {}, but got task {}.".format( + self._tasks_supported, self.task)) + + paddle.set_device('gpu') if use_gpu else paddle.set_device('cpu') + tokenizer = self.get_tokenizer() + + examples = [] + for text in data: + if len(text) == 1: + encoded_inputs = tokenizer.encode(text[0], text_pair=None, max_seq_len=max_seq_len) + elif len(text) == 2: + encoded_inputs = tokenizer.encode(text[0], text_pair=text[1], max_seq_len=max_seq_len) + else: + raise RuntimeError( + 'The input text must have one or two sequence, but got %d. Please check your inputs.' % len(text)) + examples.append((encoded_inputs['input_ids'], encoded_inputs['segment_ids'])) + + def _batchify_fn(batch): + input_ids = [entry[0] for entry in batch] + segment_ids = [entry[1] for entry in batch] + return input_ids, segment_ids + + # Seperates data into some batches. + batches = [] + one_batch = [] + for example in examples: + one_batch.append(example) + if len(one_batch) == batch_size: + batches.append(one_batch) + one_batch = [] + if one_batch: + # The last batch whose size is less than the config batch_size setting. + batches.append(one_batch) + + results = [] + self.eval() + for batch in batches: + input_ids, segment_ids = _batchify_fn(batch) + input_ids = paddle.to_tensor(input_ids) + segment_ids = paddle.to_tensor(segment_ids) + + if self.task == 'seq-cls': + probs = self(input_ids, segment_ids) + idx = paddle.argmax(probs, axis=1).numpy() + idx = idx.tolist() + labels = [self.label_map[i] for i in idx] + results.extend(labels) + elif self.task == 'token-cls': + probs = self(input_ids, segment_ids) + batch_ids = paddle.argmax(probs, axis=2).numpy() # (batch_size, max_seq_len) + batch_ids = batch_ids.tolist() + token_labels = [[self.label_map[i] for i in token_ids] for token_ids in batch_ids] + results.extend(token_labels) + + return results diff --git a/requirements.txt b/requirements.txt index 3b39befcdc39698f0039ca38d4ee47c06185d8bb..fe998426235dc0995ccde15e72d7dbf54fadd444 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,3 +16,4 @@ tqdm visualdl >= 2.0.0 # gunicorn not support windows gunicorn >= 19.10.0; sys_platform != "win32" +paddlenlp >= 2.0.0b \ No newline at end of file