From f5f9dee4e7f01aae639cb4290476963b9aca462d Mon Sep 17 00:00:00 2001 From: Steffy-zxf <48793257+Steffy-zxf@users.noreply.github.com> Date: Fri, 11 Dec 2020 21:30:29 +0800 Subject: [PATCH] Add Sentence Transformer for text matching and Add readme (#5004) * update docs * add sbert * add readme * update readme * update codes --- .../pretrained_models/README.md | 28 +- .../pretrained_models/train.py | 4 +- .../text_classification/rnn/README.md | 83 ++-- PaddleNLP/examples/text_matching/README.md | 30 +- .../sentence_transformers/README.md | 207 ++++++++++ .../sentence_transformers/model.py | 71 ++++ .../sentence_transformers/predict.py | 227 +++++++++++ .../sentence_transformers/train.py | 370 ++++++++++++++++++ .../examples/text_matching/simnet/README.md | 168 ++++++++ .../examples/text_matching/simnet/predict.py | 2 +- .../examples/text_matching/simnet/train.py | 8 +- 11 files changed, 1130 insertions(+), 68 deletions(-) create mode 100644 PaddleNLP/examples/text_matching/sentence_transformers/README.md create mode 100644 PaddleNLP/examples/text_matching/sentence_transformers/model.py create mode 100644 PaddleNLP/examples/text_matching/sentence_transformers/predict.py create mode 100644 PaddleNLP/examples/text_matching/sentence_transformers/train.py create mode 100644 PaddleNLP/examples/text_matching/simnet/README.md diff --git a/PaddleNLP/examples/text_classification/pretrained_models/README.md b/PaddleNLP/examples/text_classification/pretrained_models/README.md index 3f75f6a2..efdb0f8f 100644 --- a/PaddleNLP/examples/text_classification/pretrained_models/README.md +++ b/PaddleNLP/examples/text_classification/pretrained_models/README.md @@ -11,9 +11,11 @@ 本项目针对中文文本分类问题,开源了一系列模型,供用户可配置地使用: + BERT([Bidirectional Encoder Representations from Transformers](https://arxiv.org/abs/1810.04805))中文模型,简写`bert-base-chinese`, 其由12层Transformer网络组成。 -+ ERNIE([Enhanced Representation through Knowledge Integration](https://arxiv.org/pdf/1904.09223)),支持ERNIE 1.0中文模型(简写`ernie`)和ERNIE Tiny中文模型(简写`ernie_tiny`)。 ++ ERNIE([Enhanced Representation through Knowledge Integration](https://arxiv.org/pdf/1904.09223)),支持ERNIE 1.0中文模型(简写`ernie-1.0`)和ERNIE Tiny中文模型(简写`ernie_tiny`)。 其中`ernie`由12层Transformer网络组成,`ernie_tiny`由3层Transformer网络组成。 + RoBERTa([A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692)),支持24层Transformer网络的`roberta-wwm-ext-large`和12层Transformer网络的`roberta-wwm-ext`。 ++ Electra([ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators](https://arxiv.org/abs/2003.10555)), 支持hidden_size=256的`chinese-electra-discriminator-small`和 + hidden_size=768的`chinese-electra-discriminator-base` | 模型 | dev acc | test acc | | ---- | ------- | -------- | @@ -63,24 +65,24 @@ pretrained_models/ 我们以中文情感分类公开数据集ChnSentiCorp为示例数据集,可以运行下面的命令,在训练集(train.tsv)上进行模型训练,并在开发集(dev.tsv)验证 ```shell # 设置使用的GPU卡号 -export CUDA_VISIBLE_DEVICES=0 +CUDA_VISIBLE_DEVICES=0 python train.py --model_type ernie --model_name ernie_tiny --n_gpu 1 --save_dir ./checkpoints ``` 可支持配置的参数: -* model_type:必选,模型类型,可以选择bert,ernie,roberta。 -* model_name: 必选,具体的模型简称。如`model_type=ernie`,则model_name可以选择`ernie`和`ernie_tiny`。`model_type=bert`,则model_name可以选择`bert-base-chinese`。 +* `model_type`:必选,模型类型,可以选择bert,ernie,roberta。 +* `model_name`: 必选,具体的模型简称。如`model_type=ernie`,则model_name可以选择`ernie`和`ernie_tiny`。`model_type=bert`,则model_name可以选择`bert-base-chinese`。 `model_type=roberta`,则model_name可以选择`roberta-wwm-ext-large`和`roberta-wwm-ext`。 -* save_dir:必选,保存训练模型的目录。 -* max_seq_length:可选,ERNIE/BERT模型使用的最大序列长度,最大不能超过512, 
若出现显存不足,请适当调低这一参数;默认为128。 -* batch_size:可选,批处理大小,请结合显存情况进行调整,若出现显存不足,请适当调低这一参数;默认为32。 -* learning_rate:可选,Fine-tune的最大学习率;默认为5e-5。 -* weight_decay:可选,控制正则项力度的参数,用于防止过拟合,默认为0.00。 -* warmup_proption:可选,学习率warmup策略的比例,如果0.1,则学习率会在前10%训练step的过程中从0慢慢增长到learning_rate, 而后再缓慢衰减,默认为0.1。 -* init_from_ckpt:可选,模型参数路径,热启动模型训练;默认为None。 -* seed:可选,随机种子,默认为1000. -* n_gpu:可选,训练过程中使用GPU卡数量,默认为1。若n_gpu=0,则使用CPU训练。 +* `save_dir`:必选,保存训练模型的目录。 +* `max_seq_length`:可选,ERNIE/BERT模型使用的最大序列长度,最大不能超过512, 若出现显存不足,请适当调低这一参数;默认为128。 +* `batch_size`:可选,批处理大小,请结合显存情况进行调整,若出现显存不足,请适当调低这一参数;默认为32。 +* `learning_rate`:可选,Fine-tune的最大学习率;默认为5e-5。 +* `weight_decay`:可选,控制正则项力度的参数,用于防止过拟合,默认为0.00。 +* `warmup_proption`:可选,学习率warmup策略的比例,如果0.1,则学习率会在前10%训练step的过程中从0慢慢增长到learning_rate, 而后再缓慢衰减,默认为0.1。 +* `init_from_ckpt`:可选,模型参数路径,热启动模型训练;默认为None。 +* `seed`:可选,随机种子,默认为1000. +* `n_gpu`:可选,训练过程中使用GPU卡数量,默认为1。若n_gpu=0,则使用CPU训练。 程序运行时将会自动进行训练,评估,测试。同时训练过程中会自动保存模型在指定的`save_dir`中。 diff --git a/PaddleNLP/examples/text_classification/pretrained_models/train.py b/PaddleNLP/examples/text_classification/pretrained_models/train.py index 2a590b9a..33cf2211 100644 --- a/PaddleNLP/examples/text_classification/pretrained_models/train.py +++ b/PaddleNLP/examples/text_classification/pretrained_models/train.py @@ -94,7 +94,7 @@ def parse_args(): help="Total number of training epochs to perform.") parser.add_argument( "--warmup_proption", - default=0.1, + default=0.0, type=float, help="Linear warmup proption over the training process.") parser.add_argument( @@ -304,7 +304,7 @@ def do_train(args): ]) criterion = paddle.nn.loss.CrossEntropyLoss() - metric = paddle.metric.Accuracy(name='acc_accumulation') + metric = paddle.metric.Accuracy() global_step = 0 tic_train = time.time() diff --git a/PaddleNLP/examples/text_classification/rnn/README.md b/PaddleNLP/examples/text_classification/rnn/README.md index b8fd860e..0eac7ee4 100644 --- a/PaddleNLP/examples/text_classification/rnn/README.md +++ b/PaddleNLP/examples/text_classification/rnn/README.md @@ -28,15 +28,6 @@ | Bi-LSTM Attention | 序列模型,在双向LSTM结构之上加入Attention机制,结合上下文更好地表征句子语义特征 | | TextCNN | 序列模型,使用多种卷积核大小,提取局部区域地特征 | -+ BOW(Bag Of Words)模型,是一个非序列模型,使用基本的全连接结构; -+ RNN (Recurrent Neural Network),序列模型,能够有效地处理序列信息; -+ GRU(Gated Recurrent Unit),序列模型,能够较好地解决序列文本中长距离依赖的问题; -+ LSTM(Long Short Term Memory),序列模型,能够较好地解决序列文本中长距离依赖的问题; -+ Bi-LSTM(Bidirectional Long Short Term Memory),序列模型,采用双向LSTM结构,更好地捕获句子中的语义特征; -+ Bi-GRU(Bidirectional Gated Recurrent Unit),序列模型,采用双向GRU结构,更好地捕获句子中的语义特征; -+ Bi-RNN(Bidirectional Recurrent Neural Network),序列模型,采用双向RNN结构,更好地捕获句子中的语义特征; -+ Bi-LSTM Attention, 序列模型,在双向LSTM结构之上加入Attention机制,结合上下文更好地表征句子语义特征; -+ TextCNN, 序列模型,使用多种卷积核大小,提取局部区域地特征; | 模型 | dev acc | test acc | | ---- | ------- | -------- | @@ -73,11 +64,10 @@ ```text . 
-├── config.py # 运行配置文件 ├── data.py # 数据读取 -├── train.py # 训练模型主程序入口,包括训练、评估 ├── predict.py # 模型预测 -├── model.py # 模型组网 +├── utils.py # 数据处理工具 +├── train.py # 训练模型主程序入口,包括训练、评估 └── README.md # 文档说明 ``` @@ -86,9 +76,9 @@ #### 使用PaddleNLP内置数据集 ```python -train_dataset = ppnlp.datasets.ChnSentiCorp('train') -dev_dataset = ppnlp.datasets.ChnSentiCorp('dev') -test_dataset = ppnlp.datasets.ChnSentiCorp('test') +from paddlenlp.datasets import ChnSentiCorp + +train_ds, dev_ds, test_ds = ChnSentiCorp.get_datasets(['train', 'dev', 'test']) ``` #### 自定义数据集 @@ -100,29 +90,33 @@ test_dataset = ppnlp.datasets.ChnSentiCorp('test') 在模型训练之前,需要先下载词汇表文件word_dict.txt,用于构造词-id映射关系。 ```shell -wget https://paddlenlp.bj.bcebos.com/data/word_dict.txt +wget https://paddlenlp.bj.bcebos.com/data/senta_word_dict.txt ``` 我们以中文情感分类公开数据集ChnSentiCorp为示例数据集,可以运行下面的命令,在训练集(train.tsv)上进行模型训练,并在开发集(dev.tsv)验证 + +CPU 启动: + ```shell -# CPU启动 -python train.py --vocab_path='./word_dict.txt' --use_gpu=False --network_name=bilstm --lr=5e-4 --batch_size=64 --epochs=5 --save_dir='./checkpoints' +python train.py --vocab_path='./senta_word_dict.txt' --use_gpu=False --network_name=bilstm --lr=5e-4 --batch_size=64 --epochs=5 --save_dir='./checkpoints' +``` -# GPU启动 -# CUDA_VSIBLE_DEVICES指定想要利用的GPU卡号,可以是单卡,也可以多卡 -# CUDA_VISIBLE_DEVICES=0,1 python -m paddle.distributed.launch train.py --vocab_path='./word_dict.txt' --use_gpu=True --network_name=bilstm --lr=5e-4 --batch_size=64 --epochs=5 --save_dir='./checkpoints' +GPU 启动: + +```shell +# CUDA_VISIBLE_DEVICES=0 python train.py --vocab_path='./senta_word_dict.txt' --use_gpu=True --network_name=bilstm --lr=5e-4 --batch_size=64 --epochs=5 --save_dir='./checkpoints' ``` 以上参数表示: -* vocab_path: 词汇表文件路径。 -* use_gpu: 是否使用GPU进行训练, 默认为`False`。 -* network_name: 模型网络名称,默认为`bilstm_attn`, 可更换为bilstm, bigru, birnn,bow,lstm,rnn,gru,bilstm_attn,textcnn等。 -* lr: 学习率, 默认为5e-4。 -* batch_size: 运行一个batch大小,默认为64。 -* epochs: 训练轮次,默认为5。 -* save_dir: 训练保存模型的文件路径。 -* init_from_ckpt: 恢复模型训练的断点路径。 +* `vocab_path`: 词汇表文件路径。 +* `use_gpu`: 是否使用GPU进行训练, 默认为`False`。 +* `network_name`: 模型网络名称,默认为`bilstm_attn`, 可更换为bilstm, bigru, birnn,bow,lstm,rnn,gru,bilstm_attn,textcnn等。 +* `lr`: 学习率, 默认为5e-4。 +* `batch_size`: 运行一个batch大小,默认为64。 +* `epochs`: 训练轮次,默认为5。 +* `save_dir`: 训练保存模型的文件路径。 +* `init_from_ckpt`: 恢复模型训练的断点路径。 程序运行时将会自动进行训练,评估,测试。同时训练过程中会自动保存模型在指定的`save_dir`中。 @@ -142,21 +136,25 @@ checkpoints/ ### 模型预测 启动预测: + +CPU启动: + ```shell -# CPU启动 -python predict.py --vocab_path='./word_dict.txt' --use_gpu=False --network_name=bilstm --params_path=checkpoints/final.pdparams +python predict.py --vocab_path='./senta_word_dict.txt' --use_gpu=False --network_name=bilstm --params_path=checkpoints/final.pdparams +``` + +GPU启动: -# GPU启动 -# CUDA_VSIBLE_DEVICES指定想要利用的GPU卡号,可以是单卡,也可以多卡 -# CUDA_VISIBLE_DEVICES=0 python predict.py --vocab_path='./word_dict.txt' --use_gpu=True --network_name=bilstm --params_path='./checkpoints/final.pdparams' +```shell +CUDA_VISIBLE_DEVICES=0 python predict.py --vocab_path='./senta_word_dict.txt' --use_gpu=True --network_name=bilstm --params_path='./checkpoints/final.pdparams' ``` 将待预测数据分词完毕后,如以下示例: ```text -这个 宾馆 比较 陈旧 了 , 特价 的 房间 也 很一般 。 总体来说 一般 -怀着 十分 激动 的 心情 放映 , 可是 看着 看着 发现 , 在 放映 完毕 后 , 出现 一集米 老鼠 的 动画片 ! 
-作为 老 的 四星酒店 , 房间 依然 很 整洁 , 相当 不错 。 机场 接机 服务 很好 , 可以 在 车上 办理 入住 手续 , 节省 时间 。 +这个宾馆比较陈旧了,特价的房间也很一般。总体来说一般 +怀着十分激动的心情放映,可是看着看着发现,在放映完毕后,出现一集米老鼠的动画片 +作为老的四星酒店,房间依然很整洁,相当不错。机场接机服务很好,可以在车上办理入住手续,节省时间。 ``` 处理成模型所需的`Tensor`,如可以直接调用`preprocess_prediction_data`函数既可处理完毕。之后传入`predict`函数即可输出预测结果。 @@ -164,12 +162,7 @@ python predict.py --vocab_path='./word_dict.txt' --use_gpu=False --network_name= 如 ```text -Data: 这个 宾馆 比较 陈旧 了 , 特价 的 房间 也 很一般 。 总体来说 一般 Lable: negative -Data: 怀着 十分 激动 的 心情 放映 , 可是 看着 看着 发现 , 在 放映 完毕 后 , 出现 一集米 老鼠 的 动画片 ! Lable: negative -Data: 作为 老 的 四星酒店 , 房间 依然 很 整洁 , 相当 不错 。 机场 接机 服务 很好 , 可以 在 车上 办理 入住 手续 , 节省 时间 。 Lable: positive +Data: 这个宾馆比较陈旧了,特价的房间也很一般。总体来说一般 Lable: negative +Data: 怀着十分激动的心情放映,可是看着看着发现,在放映完毕后,出现一集米老鼠的动画片 Lable: negative +Data: 作为老的四星酒店,房间依然很整洁,相当不错。机场接机服务很好,可以在车上办理入住手续,节省时间。 Lable: positive ``` - -## 其他 - -1、如何进行多分类? -本项目采用二分类数据集,如需进行多分类任务,修改类别数目及类别标签列表即可。 \ No newline at end of file diff --git a/PaddleNLP/examples/text_matching/README.md b/PaddleNLP/examples/text_matching/README.md index d83d18bc..a875dcbc 100644 --- a/PaddleNLP/examples/text_matching/README.md +++ b/PaddleNLP/examples/text_matching/README.md @@ -1,5 +1,29 @@ -# Text Matching +# Pointwise文本匹配 -## SimNet +**文本匹配一直是自然语言处理(NLP)领域一个基础且重要的方向,一般研究两段文本之间的关系。文本相似度计算、自然语言推理、问答系统、信息检索等,都可以看作针对不同数据和场景的文本匹配应用。这些自然语言处理任务在很大程度上都可以抽象成文本匹配问题,比如信息检索可以归结为搜索词和文档资源的匹配,问答系统可以归结为问题和候选答案的匹配,复述问题可以归结为两个同义句的匹配,对话系统可以归结为前一句对话和回复的匹配,机器翻译则可以归结为两种语言的匹配。** -## Sentence-BERT +

+ + +文本匹配任务可以分为pointwise和pairwise类型。 + +pointwise,每一个样本通常由两个文本组成(query,title)。类别形式为0或1,0表示query与title不匹配; 1表示匹配。 +pairwise,每一个样本通常由三个文本组成(query,positive_title, negative_title)。positive_title比negative_title更加匹配query。 +根据本数据集示例,该匹配任务为pointwise类型。 + +该项目展示了使用传统的[SimNet](./simnet) 和 [SentenceBert](./sentence_bert)两种方法完成pointwise本匹配任务。 + +## Conventional Models + +[SimNet](./simnet) 展示了如何使用CNN、LSTM、GRU等网络完成pointwise文本匹配任务。 + +## Pretrained Model (PTMs) + +[Sentence Transformers](./sentence_transformers) 展示了如何使用以ERNIE为代表的模型Fine-tune完成pointwise文本匹配任务。 diff --git a/PaddleNLP/examples/text_matching/sentence_transformers/README.md b/PaddleNLP/examples/text_matching/sentence_transformers/README.md new file mode 100644 index 00000000..5670563c --- /dev/null +++ b/PaddleNLP/examples/text_matching/sentence_transformers/README.md @@ -0,0 +1,207 @@ +# 使用预训练模型Fine-tune完成pointwise中文文本匹配任务 + +随着深度学习的发展,模型参数的数量飞速增长。为了训练这些参数,需要更大的数据集来避免过拟合。然而,对于大部分NLP任务来说,构建大规模的标注数据集非常困难(成本过高),特别是对于句法和语义相关的任务。相比之下,大规模的未标注语料库的构建则相对容易。为了利用这些数据,我们可以先从其中学习到一个好的表示,再将这些表示应用到其他任务中。最近的研究表明,基于大规模未标注语料库的预训练模型(Pretrained Models, PTM) 在NLP任务上取得了很好的表现。 + +近年来,大量的研究表明基于大型语料库的预训练模型(Pretrained Models, PTM)可以学习通用的语言表示,有利于下游NLP任务,同时能够避免从零开始训练模型。随着计算能力的发展,深度模型的出现(即 Transformer)和训练技巧的增强使得 PTM 不断发展,由浅变深。 + +百度的预训练模型ERNIE经过海量的数据训练后,其特征抽取的工作已经做的非常好。借鉴迁移学习的思想,我们可以利用其在海量数据中学习的语义信息辅助小数据集(如本示例中的医疗文本数据集)上的任务。 + +
+
+使用预训练模型ERNIE完成pointwise文本匹配任务时,大家可能会想到将query和title文本拼接后一起输入ERNIE,取`[CLS]`位置的特征(pooled_output),再接入全连接层进行二分类。如下图ERNIE用于句对分类任务的用法:
+

+(图:ERNIE用于句对分类任务的网络结构示意图)
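+
+下面给出这种“句对拼接 + [CLS]二分类”做法的简化示意代码(仅为帮助理解的示意,假设使用PaddleNLP的`ErnieTokenizer`与`ErnieForSequenceClassification`,并非本项目的实现):
+
+```python
+import paddle
+import paddle.nn.functional as F
+import paddlenlp as ppnlp
+
+tokenizer = ppnlp.transformers.ErnieTokenizer.from_pretrained('ernie-1.0')
+model = ppnlp.transformers.ErnieForSequenceClassification.from_pretrained(
+    'ernie-1.0', num_classes=2)
+
+query, title = '世界上什么东西最小', '世界上什么东西最小?'
+# 将query与title拼接成一条输入: [CLS] query [SEP] title [SEP]
+encoded = tokenizer.encode(text=query, text_pair=title, max_seq_len=128)
+input_ids = paddle.to_tensor([encoded["input_ids"]])
+segment_ids = paddle.to_tensor([encoded["segment_ids"]])
+
+# 模型取[CLS]特征(pooled_output)接全连接层,输出二分类logits
+logits = model(input_ids, segment_ids)
+probs = F.softmax(logits, axis=1)
+```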
+
+
+然而,以上用法的问题在于,**ERNIE的模型参数非常庞大,导致计算量非常大,预测的速度也不够理想**,从而达不到线上业务的要求。针对该问题,可以使用PaddleNLP工具搭建Sentence Transformer网络。
+
+

+(图:Sentence Transformer双塔(Siamese)网络结构示意图)
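+
+上图中mean pooling与特征拼接部分的核心计算,可用下面的示意代码理解(假设已得到query和title经ERNIE编码后的token embedding及对应的padding mask;完整实现见本目录的model.py):
+
+```python
+import paddle
+
+def siamese_features(query_embedding, title_embedding, query_mask, title_mask):
+    """query_embedding/title_embedding: [batch_size, seq_len, hidden_size]
+    query_mask/title_mask: [batch_size, seq_len, 1],非padding位置为1,padding位置为0
+    """
+    # 对非padding位置的token embedding做mean pooling,得到句向量u、v
+    u = paddle.sum(query_embedding * query_mask, axis=1) / paddle.sum(query_mask, axis=1)
+    v = paddle.sum(title_embedding * title_mask, axis=1) / paddle.sum(title_mask, axis=1)
+    # 拼接三个表征(u, v, |u-v|),后续再接全连接层做二分类
+    return paddle.concat([u, v, paddle.abs(u - v)], axis=-1)
+```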
+ +Sentence Transformer采用了双塔(Siamese)的网络结构。Query和Title分别输入ERNIE,共享一个ERNIE参数,得到各自的token embedding特征。之后对token embedding进行pooling(此处教程使用mean pooling操作),之后输出分别记作u,v。之后将三个表征(u,v,|u-v|)拼接起来,进行二分类。网络结构如上图所示。 + +更多关于Sentence Transformer的信息可以参考论文:https://arxiv.org/abs/1908.10084 + +**同时,不仅可以使用ERNIR作为文本语义特征提取器,可以利用BERT/RoBerta/Electra等模型作为文本语义特征提取器** + +**那么Sentence Transformer采用Siamese的网路结构,是如何提升预测速度呢?** + +**Siamese的网络结构好处在于query和title分别输入同一套网络。如在信息搜索任务中,此时就可以将数据库中的title文本提前计算好对应sequence_output特征,保存在数据库中。当用户搜索query时,只需计算query的sequence_output特征与保存在数据库中的title sequence_output特征,通过一个简单的mean_pooling和全连接层进行二分类即可。从而大幅提升预测效率,同时也保障了模型性能。** + +关于匹配任务常用的Siamese网络结构可以参考:https://blog.csdn.net/thriving_fcl/article/details/73730552 + +PaddleNLP提供了丰富的预训练模型,并且可以便捷地获取PaddlePaddle生态下的所有预训练模型。下面展示如何使用PaddleNLP一键加载ERNIE,优化文本匹配任务。 + +## 模型简介 + +本项目针对中文文本匹配问题,开源了一系列模型,供用户可配置地使用: + ++ BERT([Bidirectional Encoder Representations from Transformers](https://arxiv.org/abs/1810.04805))中文模型,简写`bert-base-chinese`, 其由12层Transformer网络组成。 ++ ERNIE([Enhanced Representation through Knowledge Integration](https://arxiv.org/pdf/1904.09223)),支持ERNIE 1.0中文模型(简写`ernie-1.0`)和ERNIE Tiny中文模型(简写`ernie_tiny`)。 + 其中`ernie`由12层Transformer网络组成,`ernie_tiny`由3层Transformer网络组成。 ++ RoBERTa([A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692)),支持24层Transformer网络的`roberta-wwm-ext-large`和12层Transformer网络的`roberta-wwm-ext`。 ++ Electra([ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators](https://arxiv.org/abs/2003.10555)), 支持hidden_size=256的`chinese-electra-discriminator-small`和 + hidden_size=768的`chinese-electra-discriminator-base` + +## TODO 增加模型效果 +| 模型 | dev acc | test acc | +| ---- | ------- | -------- | +| bert-base-chinese | | | +| bert-wwm-chinese | | | +| bert-wwm-ext-chinese | | | +| ernie | | | +| ernie-tiny | | | +| roberta-wwm-ext | | | +| roberta-wwm-ext-large | | | +| rbt3 | | | +| rbtl3 | | | +| chinese-electra-discriminator-base | | | +| chinese-electra-discriminator-small | | | + +## 快速开始 + +### 安装说明 + +* PaddlePaddle 安装 + + 本项目依赖于 PaddlePaddle 2.0 及以上版本,请参考 [安装指南](http://www.paddlepaddle.org/#quick-start) 进行安装 + +* PaddleNLP 安装 + + ```shell + pip install paddlenlp + ``` + +* 环境依赖 + + Python的版本要求 3.6+,其它环境请参考 PaddlePaddle [安装说明](https://www.paddlepaddle.org.cn/documentation/docs/zh/1.5/beginners_guide/install/index_cn.html) 部分的内容 + +### 代码结构说明 + +以下是本项目主要代码结构及说明: + +```text +sentence_transformers/ +├── checkpoint +│   ├── model_100 +│   │   ├── model_state.pdparams +│   │   ├── tokenizer_config.json +│   │   └── vocab.txt +│   ├── ... 
+│ +├── model.py # Sentence Transfomer 组网文件 +├── README.md # 文本说明 +└── train.py # 模型训练评估 +``` + +### 模型训练 + +我们以中文文本匹配公开数据集LCQMC为示例数据集,可以运行下面的命令,在训练集(train.tsv)上进行模型训练,并在开发集(dev.tsv)验证 +```shell +# 设置使用的GPU卡号 +CUDA_VISIBLE_DEVICES=0 +python train.py --model_type ernie --model_name ernie-1.0 --n_gpu 1 --save_dir ./checkpoints +``` + +可支持配置的参数: + +* `model_type`:必选,模型类型,可以选择bert,ernie,roberta。 +* `model_name`: 必选,具体的模型简称。如`model_type=ernie`,则model_name可以选择`ernie`和`ernie_tiny`。`model_type=bert`,则model_name可以选择`bert-base-chinese`。 + `model_type=roberta`,则model_name可以选择`roberta-wwm-ext-large`和`roberta-wwm-ext`。 +* `save_dir`:必选,保存训练模型的目录。 +* `max_seq_length`:可选,ERNIE/BERT模型使用的最大序列长度,最大不能超过512, 若出现显存不足,请适当调低这一参数;默认为128。 +* `batch_size`:可选,批处理大小,请结合显存情况进行调整,若出现显存不足,请适当调低这一参数;默认为32。 +* `learning_rate`:可选,Fine-tune的最大学习率;默认为5e-5。 +* `weight_decay`:可选,控制正则项力度的参数,用于防止过拟合,默认为0.00。 +* `warmup_proption`:可选,学习率warmup策略的比例,如果0.1,则学习率会在前10%训练step的过程中从0慢慢增长到learning_rate, 而后再缓慢衰减,默认为0.1。 +* `init_from_ckpt`:可选,模型参数路径,热启动模型训练;默认为None。 +* `seed`:可选,随机种子,默认为1000. +* `n_gpu`:可选,训练过程中使用GPU卡数量,默认为1。若n_gpu=0,则使用CPU训练。 + + +程序运行时将会自动进行训练,评估,测试。同时训练过程中会自动保存模型在指定的`save_dir`中。 +如: +```text +checkpoints/ +├── model_100 +│   ├── model_config.json +│   ├── model_state.pdparams +│   ├── tokenizer_config.json +│   └── vocab.txt +└── ... +``` + +**NOTE:** +* 如需恢复模型训练,则可以设置`init_from_ckpt`, 如`init_from_ckpt=checkpoints/model_100/model_state.pdparams`。 +* 如需使用ernie_tiny模型,则需要提前先安装sentencepiece依赖,如`pip install sentencepiece` + +### 模型预测 + +启动预测: +```shell +CUDA_VISIBLE_DEVICES=0 +python predict.py --model_type ernie --model_name ernie_tiny --params_path checkpoints/model_400/model_state.pdparams +``` + +将待预测数据如以下示例: + +```text +世界上什么东西最小 世界上什么东西最小? +光眼睛大就好看吗 眼睛好看吗? +小蝌蚪找妈妈怎么样 小蝌蚪找妈妈是谁画的 +``` + +可以直接调用`predict`函数即可输出预测结果。 + +如 + +```text +Data: ['世界上什么东西最小', '世界上什么东西最小?'] Label: similar +Data: ['光眼睛大就好看吗', '眼睛好看吗?'] Label: dissimilar +Data: ['小蝌蚪找妈妈怎么样', '小蝌蚪找妈妈是谁画的'] Label: dissimilar +``` + + +## 引用 + +关于Sentence Transformer更多信息参考[www.SBERT.net](https://www.sbert.net)以及论文: +- [Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks](https://arxiv.org/abs/1908.10084) (EMNLP 2019) +- [Making Monolingual Sentence Embeddings Multilingual using Knowledge Distillation](https://arxiv.org/abs/2004.09813) (EMNLP 2020) +- [Augmented SBERT: Data Augmentation Method for Improving Bi-Encoders for Pairwise Sentence Scoring Tasks](https://arxiv.org/abs/2010.08240) (arXiv 2020) + +``` +@inproceedings{reimers-2019-sentence-bert, + title = "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks", + author = "Reimers, Nils and Gurevych, Iryna", + booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing", + month = "11", + year = "2019", + publisher = "Association for Computational Linguistics", + url = "https://arxiv.org/abs/1908.10084", +} +``` + +``` +@inproceedings{reimers-2020-multilingual-sentence-bert, + title = "Making Monolingual Sentence Embeddings Multilingual using Knowledge Distillation", + author = "Reimers, Nils and Gurevych, Iryna", + booktitle = "Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing", + month = "11", + year = "2020", + publisher = "Association for Computational Linguistics", + url = "https://arxiv.org/abs/2004.09813", +} +``` + +``` +@article{thakur-2020-AugSBERT, + title = "Augmented SBERT: Data Augmentation Method for Improving Bi-Encoders for Pairwise Sentence Scoring Tasks", + author = "Thakur, Nandan and Reimers, Nils 
and Daxenberger, Johannes and Gurevych, Iryna", + journal= "arXiv preprint arXiv:2010.08240", + month = "10", + year = "2020", + url = "https://arxiv.org/abs/2010.08240", +} +``` diff --git a/PaddleNLP/examples/text_matching/sentence_transformers/model.py b/PaddleNLP/examples/text_matching/sentence_transformers/model.py new file mode 100644 index 00000000..44c178f6 --- /dev/null +++ b/PaddleNLP/examples/text_matching/sentence_transformers/model.py @@ -0,0 +1,71 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + + +class SentenceTransformer(nn.Layer): + def __init__(self, pretrained_model, dropout=None): + super().__init__() + self.ptm = pretrained_model + self.dropout = nn.Dropout(dropout if dropout is not None else 0.1) + # num_labels = 2 (similar or dissimilar) + self.classifier = nn.Linear(self.ptm.config["hidden_size"] * 3, 2) + + def forward(self, + query_input_ids, + title_input_ids, + query_token_type_ids=None, + query_position_ids=None, + query_attention_mask=None, + title_token_type_ids=None, + title_position_ids=None, + title_attention_mask=None): + query_token_embedding, _ = self.ptm( + query_input_ids, query_token_type_ids, query_position_ids, + query_attention_mask) + query_token_embedding = self.dropout(query_token_embedding) + query_attention_mask = paddle.unsqueeze( + (query_input_ids != self.ptm.pad_token_id + ).astype(self.ptm.pooler.dense.weight.dtype), + axis=2) + # Set token embeddings to 0 for padding tokens + query_token_embedding = query_token_embedding * query_attention_mask + query_sum_embedding = paddle.sum(query_token_embedding, axis=1) + query_sum_mask = paddle.sum(query_attention_mask, axis=1) + query_mean = query_sum_embedding / query_sum_mask + + title_token_embedding, _ = self.ptm( + title_input_ids, title_token_type_ids, title_position_ids, + title_attention_mask) + title_token_embedding = self.dropout(title_token_embedding) + title_attention_mask = paddle.unsqueeze( + (title_input_ids != self.ptm.pad_token_id + ).astype(self.ptm.pooler.dense.weight.dtype), + axis=2) + # Set token embeddings to 0 for padding tokens + title_token_embedding = title_token_embedding * title_attention_mask + title_sum_embedding = paddle.sum(title_token_embedding, axis=1) + title_sum_mask = paddle.sum(title_attention_mask, axis=1) + title_mean = title_sum_embedding / title_sum_mask + + sub = paddle.abs(paddle.subtract(query_mean, title_mean)) + projection = paddle.concat([query_mean, title_mean, sub], axis=-1) + + logits = self.classifier(projection) + probs = F.softmax(logits) + + return probs diff --git a/PaddleNLP/examples/text_matching/sentence_transformers/predict.py b/PaddleNLP/examples/text_matching/sentence_transformers/predict.py new file mode 100644 index 00000000..6d90dc9a --- /dev/null +++ b/PaddleNLP/examples/text_matching/sentence_transformers/predict.py @@ -0,0 +1,227 @@ +# Copyright (c) 2020 PaddlePaddle Authors. 
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial +import argparse +import os +import random +import time + +import numpy as np +import paddle +import paddle.nn.functional as F +import paddlenlp as ppnlp +from paddlenlp.data import Stack, Tuple, Pad + +from model import SentenceTransformer + +MODEL_CLASSES = { + "bert": (ppnlp.transformers.BertModel, ppnlp.transformers.BertTokenizer), + 'ernie': (ppnlp.transformers.ErnieModel, ppnlp.transformers.ErnieTokenizer), + 'roberta': (ppnlp.transformers.RobertaModel, + ppnlp.transformers.RobertaTokenizer), + # 'electra': (ppnlp.transformers.Electra, + # ppnlp.transformers.ElectraTokenizer) +} + + +# yapf: disable +def parse_args(): + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument("--model_type", default='ernie', type=str, help="Model type selected in the list: " +", ".join(MODEL_CLASSES.keys())) + parser.add_argument("--model_name", default='ernie-1.0', type=str, help="Path to pre-trained model or shortcut name selected in the list: " + + ", ".join(sum([list(classes[-1].pretrained_init_configuration.keys()) for classes in MODEL_CLASSES.values()], []))) + parser.add_argument("--params_path", type=str, default='./checkpoint/model_4900/model_state.pdparams', help="The path to model parameters to be loaded.") + + parser.add_argument("--max_seq_length", default=50, type=int, help="The maximum total input sequence length after tokenization. " + "Sequences longer than this will be truncated, sequences shorter will be padded.") + parser.add_argument("--batch_size", default=32, type=int, help="Batch size per GPU/CPU for training.") + parser.add_argument("--n_gpu", type=int, default=0, help="Number of GPUs to use, 0 for CPU.") + args = parser.parse_args() + return args +# yapf: enable + + +def convert_example(example, + tokenizer, + label_list, + max_seq_length=512, + is_test=False): + """ + Builds model inputs from a sequence or a pair of sequence for sequence classification tasks + by concatenating and adding special tokens. And creates a mask from the two sequences passed + to be used in a sequence-pair classification task. + + A BERT sequence has the following format: + + - single sequence: ``[CLS] X [SEP]`` + - pair of sequences: ``[CLS] A [SEP] B [SEP]`` + + A BERT sequence pair mask has the following format: + :: + 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence | + + If only one sequence, only returns the first portion of the mask (0's). + + + Args: + example(obj:`list[str]`): List of input data, containing query, title and label if it have label. + tokenizer(obj:`PretrainedTokenizer`): This tokenizer inherits from :class:`~paddlenlp.transformers.PretrainedTokenizer` + which contains most of the methods. Users should refer to the superclass for more information regarding methods. + label_list(obj:`list[str]`): All the labels that the data has. + max_seq_len(obj:`int`): The maximum total input sequence length after tokenization. 
+ Sequences longer than this will be truncated, sequences shorter will be padded. + is_test(obj:`False`, defaults to `False`): Whether the example contains label or not. + + Returns: + query_input_ids(obj:`list[int]`): The list of query token ids. + query_segment_ids(obj: `list[int]`): List of query sequence pair mask. + title_input_ids(obj:`list[int]`): The list of title token ids. + title_segment_ids(obj: `list[int]`): List of title sequence pair mask. + label(obj:`numpy.array`, data type of int64, optional): The input label if not is_test. + """ + print(example) + query, title = example[0], example[1] + + query_encoded_inputs = tokenizer.encode( + text=query, max_seq_len=max_seq_length) + query_input_ids = query_encoded_inputs["input_ids"] + query_segment_ids = query_encoded_inputs["segment_ids"] + + title_encoded_inputs = tokenizer.encode( + text=title, max_seq_len=max_seq_length) + title_input_ids = title_encoded_inputs["input_ids"] + title_segment_ids = title_encoded_inputs["segment_ids"] + + if not is_test: + # create label maps if classification task + label = example[-1] + label_map = {} + for (i, l) in enumerate(label_list): + label_map[l] = i + label = label_map[label] + label = np.array([label], dtype="int64") + return query_input_ids, query_segment_ids, title_input_ids, title_segment_ids, label + else: + return query_input_ids, query_segment_ids, title_input_ids, title_segment_ids + + +def predict(model, data, tokenizer, label_map, batch_size=1): + """ + Predicts the data labels. + + Args: + model (obj:`paddle.nn.Layer`): A model to classify texts. + data (obj:`List(Example)`): The processed data whose each element is a Example (numedtuple) object. + A Example object contains `text`(word_ids) and `se_len`(sequence length). + tokenizer(obj:`PretrainedTokenizer`): This tokenizer inherits from :class:`~paddlenlp.transformers.PretrainedTokenizer` + which contains most of the methods. Users should refer to the superclass for more information regarding methods. + label_map(obj:`dict`): The label id (key) to label str (value) map. + batch_size(obj:`int`, defaults to 1): The number of batch. + + Returns: + results(obj:`dict`): All the predictions labels. + """ + examples = [] + for text_pair in data: + query_input_ids, query_segment_ids, title_input_ids, title_segment_ids = convert_example( + text_pair, + tokenizer, + label_list=label_map.values(), + max_seq_length=args.max_seq_length, + is_test=True) + examples.append((query_input_ids, query_segment_ids, title_input_ids, + title_segment_ids)) + + batchify_fn = lambda samples, fn=Tuple( + Pad(axis=0, pad_val=tokenizer.pad_token_id), # query_input + Pad(axis=0, pad_val=tokenizer.pad_token_id), # query_segment + Pad(axis=0, pad_val=tokenizer.pad_token_id), # title_input + Pad(axis=0, pad_val=tokenizer.pad_token_id), # tilte_segment + ): [data for data in fn(samples)] + + # Seperates data into some batches. + batches = [] + one_batch = [] + for example in examples: + one_batch.append(example) + if len(one_batch) == batch_size: + batches.append(one_batch) + one_batch = [] + if one_batch: + # The last batch whose size is less than the config batch_size setting. 
+ batches.append(one_batch) + + results = [] + model.eval() + for batch in batches: + query_input_ids, query_segment_ids, title_input_ids, title_segment_ids = batchify_fn( + batch) + + query_input_ids = paddle.to_tensor(query_input_ids) + query_segment_ids = paddle.to_tensor(query_segment_ids) + title_input_ids = paddle.to_tensor(title_input_ids) + title_segment_ids = paddle.to_tensor(title_segment_ids) + + print(query_input_ids) + print(query_segment_ids) + print(title_segment_ids) + + probs = model( + query_input_ids, + title_input_ids, + query_token_type_ids=query_segment_ids, + title_token_type_ids=title_segment_ids) + idx = paddle.argmax(probs, axis=1).numpy() + idx = idx.tolist() + labels = [label_map[i] for i in idx] + results.extend(labels) + return results + + +if __name__ == "__main__": + args = parse_args() + paddle.set_device("gpu" if args.n_gpu else "cpu") + + args.model_type = args.model_type.lower() + model_class, tokenizer_class = MODEL_CLASSES[args.model_type] + + if args.model_name == 'ernie_tiny': + # ErnieTinyTokenizer is special for ernie_tiny pretained model. + tokenizer = ppnlp.transformers.ErnieTinyTokenizer.from_pretrained( + args.model_name) + else: + tokenizer = tokenizer_class.from_pretrained(args.model_name) + + data = [ + ['世界上什么东西最小', '世界上什么东西最小?'], + ['光眼睛大就好看吗', '眼睛好看吗?'], + ['小蝌蚪找妈妈怎么样', '小蝌蚪找妈妈是谁画的'], + ] + label_map = {0: 'dissimilar', 1: 'similar'} + + pretrained_model = model_class.from_pretrained(args.model_name) + model = SentenceTransformer(pretrained_model) + + if args.params_path and os.path.isfile(args.params_path): + state_dict = paddle.load(args.params_path) + model.set_dict(state_dict) + print("Loaded parameters from %s" % args.params_path) + + results = predict( + model, data, tokenizer, label_map, batch_size=args.batch_size) + for idx, text in enumerate(data): + print('Data: {} \t Lable: {}'.format(text, results[idx])) diff --git a/PaddleNLP/examples/text_matching/sentence_transformers/train.py b/PaddleNLP/examples/text_matching/sentence_transformers/train.py new file mode 100644 index 00000000..bb9673ab --- /dev/null +++ b/PaddleNLP/examples/text_matching/sentence_transformers/train.py @@ -0,0 +1,370 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from functools import partial +import argparse +import os +import random +import time + +import numpy as np +import paddle +import paddle.nn.functional as F + +from paddlenlp.data import Stack, Tuple, Pad +import paddlenlp as ppnlp + +from model import SentenceTransformer + +MODEL_CLASSES = { + "bert": (ppnlp.transformers.BertModel, ppnlp.transformers.BertTokenizer), + 'ernie': (ppnlp.transformers.ErnieModel, ppnlp.transformers.ErnieTokenizer), + 'roberta': (ppnlp.transformers.RobertaModel, + ppnlp.transformers.RobertaTokenizer), + # 'electra': (ppnlp.transformers.Electra, + # ppnlp.transformers.ElectraTokenizer) +} + + +def parse_args(): + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--model_type", + default='ernie', + required=True, + type=str, + help="Model type selected in the list: " + + ", ".join(MODEL_CLASSES.keys())) + parser.add_argument( + "--model_name", + default='ernie-1.0', + required=True, + type=str, + help="Path to pre-trained model or shortcut name selected in the list: " + + ", ".join( + sum([ + list(classes[-1].pretrained_init_configuration.keys()) + for classes in MODEL_CLASSES.values() + ], []))) + parser.add_argument( + "--save_dir", + default='./checkpoint', + required=True, + type=str, + help="The output directory where the model checkpoints will be written.") + + parser.add_argument( + "--max_seq_length", + default=128, + type=int, + help="The maximum total input sequence length after tokenization. " + "Sequences longer than this will be truncated, sequences shorter will be padded." + ) + parser.add_argument( + "--batch_size", + default=32, + type=int, + help="Batch size per GPU/CPU for training.") + parser.add_argument( + "--learning_rate", + default=5e-5, + type=float, + help="The initial learning rate for Adam.") + parser.add_argument( + "--weight_decay", + default=0.0, + type=float, + help="Weight decay if we apply some.") + parser.add_argument( + "--epochs", + default=3, + type=int, + help="Total number of training epochs to perform.") + parser.add_argument( + "--warmup_proption", + default=0.0, + type=float, + help="Linear warmup proption over the training process.") + parser.add_argument( + "--init_from_ckpt", + type=str, + default=None, + help="The path of checkpoint to be loaded.") + parser.add_argument( + "--seed", type=int, default=1000, help="random seed for initialization") + parser.add_argument( + "--n_gpu", + type=int, + default=1, + help="Number of GPUs to use, 0 for CPU.") + args = parser.parse_args() + return args + + +def set_seed(args): + """sets random seed""" + random.seed(args.seed) + np.random.seed(args.seed) + paddle.seed(args.seed) + + +@paddle.no_grad() +def evaluate(model, criterion, metric, data_loader): + """ + Given a dataset, it evals model and computes the metric. + + Args: + model(obj:`paddle.nn.Layer`): A model to classify texts. + data_loader(obj:`paddle.io.DataLoader`): The dataset loader which generates batches. + criterion(obj:`paddle.nn.Layer`): It can compute the loss. + metric(obj:`paddle.metric.Metric`): The evaluation metric. 
+ """ + model.eval() + metric.reset() + losses = [] + for batch in data_loader: + query_input_ids, query_segment_ids, title_input_ids, title_segment_ids, labels = batch + probs = model( + query_input_ids, + title_input_ids, + query_token_type_ids=query_segment_ids, + title_token_type_ids=title_segment_ids) + loss = criterion(probs, labels) + losses.append(loss.numpy()) + correct = metric.compute(probs, labels) + metric.update(correct) + accu = metric.accumulate() + print("eval loss: %.5f, accu: %.5f" % (np.mean(losses), accu)) + model.train() + metric.reset() + + +def convert_example(example, + tokenizer, + label_list, + max_seq_length=512, + is_test=False): + """ + Builds model inputs from a sequence or a pair of sequence for sequence classification tasks + by concatenating and adding special tokens. And creates a mask from the two sequences passed + to be used in a sequence-pair classification task. + + A BERT sequence has the following format: + + - single sequence: ``[CLS] X [SEP]`` + - pair of sequences: ``[CLS] A [SEP] B [SEP]`` + + A BERT sequence pair mask has the following format: + :: + 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence | + + If only one sequence, only returns the first portion of the mask (0's). + + + Args: + example(obj:`list[str]`): List of input data, containing query, title and label if it have label. + tokenizer(obj:`PretrainedTokenizer`): This tokenizer inherits from :class:`~paddlenlp.transformers.PretrainedTokenizer` + which contains most of the methods. Users should refer to the superclass for more information regarding methods. + label_list(obj:`list[str]`): All the labels that the data has. + max_seq_len(obj:`int`): The maximum total input sequence length after tokenization. + Sequences longer than this will be truncated, sequences shorter will be padded. + is_test(obj:`False`, defaults to `False`): Whether the example contains label or not. + + Returns: + query_input_ids(obj:`list[int]`): The list of query token ids. + query_segment_ids(obj: `list[int]`): List of query sequence pair mask. + title_input_ids(obj:`list[int]`): The list of title token ids. + title_segment_ids(obj: `list[int]`): List of title sequence pair mask. + label(obj:`numpy.array`, data type of int64, optional): The input label if not is_test. 
+ """ + query, title = example[0], example[1] + + query_encoded_inputs = tokenizer.encode( + text=query, max_seq_len=max_seq_length) + query_input_ids = query_encoded_inputs["input_ids"] + query_segment_ids = query_encoded_inputs["segment_ids"] + + title_encoded_inputs = tokenizer.encode( + text=title, max_seq_len=max_seq_length) + title_input_ids = title_encoded_inputs["input_ids"] + title_segment_ids = title_encoded_inputs["segment_ids"] + + if not is_test: + # create label maps if classification task + label = example[-1] + label_map = {} + for (i, l) in enumerate(label_list): + label_map[l] = i + label = label_map[label] + label = np.array([label], dtype="int64") + return query_input_ids, query_segment_ids, title_input_ids, title_segment_ids, label + else: + return query_input_ids, query_segment_ids, title_input_ids, title_segment_ids + + +def create_dataloader(dataset, + mode='train', + batch_size=1, + batchify_fn=None, + trans_fn=None): + if trans_fn: + dataset = dataset.apply(trans_fn, lazy=True) + + shuffle = True if mode == 'train' else False + if mode == 'train': + batch_sampler = paddle.io.DistributedBatchSampler( + dataset, batch_size=batch_size, shuffle=shuffle) + else: + batch_sampler = paddle.io.BatchSampler( + dataset, batch_size=batch_size, shuffle=shuffle) + + return paddle.io.DataLoader( + dataset=dataset, + batch_sampler=batch_sampler, + collate_fn=batchify_fn, + return_list=True) + + +def do_train(args): + set_seed(args) + paddle.set_device("gpu" if args.n_gpu else "cpu") + world_size = paddle.distributed.get_world_size() + if world_size > 1: + paddle.distributed.init_parallel_env() + + args.model_type = args.model_type.lower() + model_class, tokenizer_class = MODEL_CLASSES[args.model_type] + + train_dataset, dev_dataset, test_dataset = ppnlp.datasets.LCQMC.get_datasets( + ['train', 'dev', 'test']) + if args.model_name == 'ernie_tiny': + # ErnieTinyTokenizer is special for ernie_tiny pretained model. 
+ tokenizer = ppnlp.transformers.ErnieTinyTokenizer.from_pretrained( + args.model_name) + else: + tokenizer = tokenizer_class.from_pretrained(args.model_name) + + trans_func = partial( + convert_example, + tokenizer=tokenizer, + label_list=train_dataset.get_labels(), + max_seq_length=args.max_seq_length) + batchify_fn = lambda samples, fn=Tuple( + Pad(axis=0, pad_val=tokenizer.pad_token_id), # query_input + Pad(axis=0, pad_val=tokenizer.pad_token_id), # query_segment + Pad(axis=0, pad_val=tokenizer.pad_token_id), # title_input + Pad(axis=0, pad_val=tokenizer.pad_token_id), # tilte_segment + Stack(dtype="int64") # label + ): [data for data in fn(samples)] + train_data_loader = create_dataloader( + train_dataset, + mode='train', + batch_size=args.batch_size, + batchify_fn=batchify_fn, + trans_fn=trans_func) + dev_data_loader = create_dataloader( + dev_dataset, + mode='dev', + batch_size=args.batch_size, + batchify_fn=batchify_fn, + trans_fn=trans_func) + test_data_loader = create_dataloader( + test_dataset, + mode='test', + batch_size=args.batch_size, + batchify_fn=batchify_fn, + trans_fn=trans_func) + + pretrained_model = model_class.from_pretrained(args.model_name) + model = SentenceTransformer(pretrained_model) + + if args.init_from_ckpt and os.path.isfile(args.init_from_ckpt): + state_dict = paddle.load(args.init_from_ckpt) + model.set_dict(state_dict) + model = paddle.DataParallel(model) + + num_training_steps = len(train_data_loader) * args.epochs + num_warmup_steps = int(args.warmup_proption * num_training_steps) + + def get_lr_factor(current_step): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + else: + return max(0.0, + float(num_training_steps - current_step) / + float(max(1, num_training_steps - num_warmup_steps))) + + lr_scheduler = paddle.optimizer.lr.LambdaDecay( + args.learning_rate, + lr_lambda=lambda current_step: get_lr_factor(current_step)) + optimizer = paddle.optimizer.AdamW( + learning_rate=lr_scheduler, + parameters=model.parameters(), + weight_decay=args.weight_decay, + apply_decay_param_fun=lambda x: x in [ + p.name for n, p in model.named_parameters() + if not any(nd in n for nd in ["bias", "norm"]) + ]) + + criterion = paddle.nn.loss.CrossEntropyLoss() + metric = paddle.metric.Accuracy() + + global_step = 0 + tic_train = time.time() + for epoch in range(1, args.epochs + 1): + for step, batch in enumerate(train_data_loader, start=1): + query_input_ids, query_segment_ids, title_input_ids, title_segment_ids, labels = batch + probs = model( + query_input_ids, + title_input_ids, + query_token_type_ids=query_segment_ids, + title_token_type_ids=title_segment_ids) + loss = criterion(probs, labels) + correct = metric.compute(probs, labels) + metric.update(correct) + acc = metric.accumulate() + + global_step += 1 + if global_step % 10 == 0 and paddle.distributed.get_rank() == 0: + print( + "global step %d, epoch: %d, batch: %d, loss: %.5f, accu: %.5f, speed: %.2f step/s" + % (global_step, epoch, step, loss, acc, + 10 / (time.time() - tic_train))) + tic_train = time.time() + loss.backward() + optimizer.step() + lr_scheduler.step() + optimizer.clear_gradients() + if global_step % 100 == 0 and paddle.distributed.get_rank() == 0: + save_dir = os.path.join(args.save_dir, "model_%d" % global_step) + if not os.path.exists(save_dir): + os.makedirs(save_dir) + evaluate(model, criterion, metric, dev_data_loader) + save_param_path = os.path.join(save_dir, 'model_state.pdparams') + paddle.save(model.state_dict(), save_param_path) + 
tokenizer.save_pretrained(save_dir) + + if paddle.distributed.get_rank() == 0: + print('Evaluating on test data.') + evaluate(model, criterion, metric, test_data_loader) + + +if __name__ == "__main__": + args = parse_args() + if args.n_gpu > 1: + paddle.distributed.spawn(do_train, args=(args, ), nprocs=args.n_gpu) + else: + do_train(args) diff --git a/PaddleNLP/examples/text_matching/simnet/README.md b/PaddleNLP/examples/text_matching/simnet/README.md new file mode 100644 index 00000000..111f836e --- /dev/null +++ b/PaddleNLP/examples/text_matching/simnet/README.md @@ -0,0 +1,168 @@ +# 使用SimNet完成pointwise文本匹配任务 + +短文本语义匹配(SimilarityNet, SimNet)是一个计算短文本相似度的框架,可以根据用户输入的两个文本,计算出相似度得分。 +SimNet框架在百度各产品上广泛应用,主要包括BOW、CNN、RNN、MMDNN等核心网络结构形式,提供语义相似度计算训练和预测框架, +适用于信息检索、新闻推荐、智能客服等多个应用场景,帮助企业解决语义匹配问题。 +可通过[AI开放平台-短文本相似度](https://ai.baidu.com/tech/nlp_basic/simnet)线上体验。 + +## 模型简介 + + +本项目通过调用[Seq2Vec](../../../paddlenlp/seq2vec/)中内置的模型进行序列建模,完成句子的向量表示。包含最简单的词袋模型和一系列经典的RNN类模型。 + +| 模型 | 模型介绍 | +| ------------------------------------------------ | ------------------------------------------------------------ | +| BOW(Bag Of Words) | 非序列模型,将句子表示为其所包含词的向量的加和 | +| RNN (Recurrent Neural Network) | 序列模型,能够有效地处理序列信息 | +| GRU(Gated Recurrent Unit) | 序列模型,能够较好地解决序列文本中长距离依赖的问题 | +| LSTM(Long Short Term Memory) | 序列模型,能够较好地解决序列文本中长距离依赖的问题 | + + +## TBD 增加模型效果 +| 模型 | dev acc | test acc | +| ---- | ------- | -------- | +| BoW | | | +| CNN | | | +| GRU | | | +| LSTM | | | + + + +## 快速开始 + +### 安装说明 + +* PaddlePaddle 安装 + + 本项目依赖于 PaddlePaddle 2.0 及以上版本,请参考 [安装指南](http://www.paddlepaddle.org/#quick-start) 进行安装 + +* PaddleNLP 安装 + + ```shell + pip install paddlenlp + ``` + +* 环境依赖 + + 本项目依赖于jieba分词,请在运行本项目之前,安装jieba,如`pip install -U jieba` + + Python的版本要求 3.6+,其它环境请参考 PaddlePaddle [安装说明](https://www.paddlepaddle.org.cn/install/quick/zh/2.0rc-linux-docker) 部分的内容 + +### 代码结构说明 + +以下是本项目主要代码结构及说明: + +```text +. +├── data.py # 数据读取 +├── predict.py # 模型预测 +├── utils.py # 数据处理工具 +├── train.py # 训练模型主程序入口,包括训练、评估 +└── README.md # 文档说明 +``` + +### 数据准备 + +#### 使用PaddleNLP内置数据集 + +```python +from paddlenlp.datasets import LCQMC + +train_ds, dev_dataset, test_ds = LCQMC.get_datasets(['train', 'dev', 'test']) +``` + +部分样例数据如下: + +```text +query title label +最近有什么好看的电视剧,推荐一下 近期有什么好看的电视剧,求推荐? 1 +大学生验证仅针对在读学生,已毕业学生不能申请的哦。 通过了大学生验证的用户,可以在支付宝的合作商户,享受学生优惠 0 +如何在网上查户口 如何网上查户口 1 +关于故事的成语 来自故事的成语 1 + 湖北农村信用社手机银行客户端下载 湖北长阳农村商业银行手机银行客户端下载 0 +草泥马是什么动物 草泥马是一种什么动物 1 +``` + +### 模型训练 + +在模型训练之前,需要先下载词汇表文件term2id.dict,用于构造词-id映射关系。 + +```shell +wget https://paddlenlp.bj.bcebos.com/data/simnet_word_dict.txt +``` + +我们以中文pointwise文本匹配数据集LCQMC为示例数据集,可以运行下面的命令,在训练集(train.tsv)上进行模型训练,并在开发集(dev.tsv)验证 + +CPU启动: + +```shell +CPU启动 +python train.py --vocab_path='./simnet_word_dict.txt' --use_gpu=False --network=lstm --lr=5e-4 --batch_size=64 --epochs=5 --save_dir='./checkpoints' +``` + +GPU启动: + +```shell +CUDA_VISIBLE_DEVICES=0 +python train.py --vocab_path='./simnet_word_dict.txt' --use_gpu=True --network=lstm --lr=5e-4 --batch_size=64 --epochs=5 --save_dir='./checkpoints' +``` + +以上参数表示: + +* `vocab_path`: 词汇表文件路径。 +* `use_gpu`: 是否使用GPU进行训练, 默认为`False`。 +* `network`: 模型网络名称,默认为`lstm`, 可更换为lstm, gru, rnn,bow,cnn等。 +* `lr`: 学习率, 默认为5e-4。 +* `batch_size`: 运行一个batch大小,默认为64。 +* `epochs`: 训练轮次,默认为5。 +* `save_dir`: 训练保存模型的文件路径。 +* `init_from_ckpt`: 恢复模型训练的断点路径。 + + +程序运行时将会自动进行训练,评估,测试。同时训练过程中会自动保存模型在指定的`save_dir`中。 +如: +```text +checkpoints/ +├── 0.pdopt +├── 0.pdparams +├── 1.pdopt +├── 1.pdparams +├── ... 
+└── final.pdparams +``` + +**NOTE:** 如需恢复模型训练,则init_from_ckpt只需指定到文件名即可,不需要添加文件尾缀。如`--init_from_ckpt=checkpoints/0`即可,程序会自动加载模型参数`checkpoints/0.pdparams`,也会自动加载优化器状态`checkpoints/0.pdopt`。 + +### 模型预测 + +启动预测 + +CPU启动: + +```shell +python predict.py --vocab_path='./simnet_word_dict.txt' --use_gpu=False --network=lstm --params_path=checkpoints/final.pdparams +``` + +GPU启动: + +```shell +CUDA_VISIBLE_DEVICES=0 python predict.py --vocab_path='./simnet_word_dict.txt' --use_gpu=True --network=lstm --params_path='./checkpoints/final.pdparams' +``` + +将待预测数据分词完毕后,如以下示例: + +```text +世界上什么东西最小 世界上什么东西最小? +光眼睛大就好看吗 眼睛好看吗? +小蝌蚪找妈妈怎么样 小蝌蚪找妈妈是谁画的 +``` + +处理成模型所需的`Tensor`,如可以直接调用`preprocess_prediction_data`函数既可处理完毕。之后传入`predict`函数即可输出预测结果。 + +如 + +```text +Data: ['世界上什么东西最小', '世界上什么东西最小?'] Label: similar +Data: ['光眼睛大就好看吗', '眼睛好看吗?'] Label: dissimilar +Data: ['小蝌蚪找妈妈怎么样', '小蝌蚪找妈妈是谁画的'] Label: dissimilar +``` diff --git a/PaddleNLP/examples/text_matching/simnet/predict.py b/PaddleNLP/examples/text_matching/simnet/predict.py index 0a82fa68..07e252aa 100644 --- a/PaddleNLP/examples/text_matching/simnet/predict.py +++ b/PaddleNLP/examples/text_matching/simnet/predict.py @@ -25,7 +25,7 @@ parser = argparse.ArgumentParser(__doc__) parser.add_argument("--use_gpu", type=eval, default=False, help="Whether use GPU for training, input should be True or False") parser.add_argument("--batch_size", type=int, default=64, help="Total examples' number of a batch for training.") parser.add_argument("--vocab_path", type=str, default="./data/term2id.dict", help="The path to vocabulary.") -parser.add_argument('--network_name', type=str, default="lstm", help="Which network you would like to choose bow, lstm, bilstm, gru, bigru, rnn, birnn, bilstm_attn, cnn and textcnn?") +parser.add_argument('--network_name', type=str, default="lstm", help="Which network you would like to choose bow, cnn, lstm or gru ?") parser.add_argument("--params_path", type=str, default='./chekpoints/final.pdparams', help="The path of model parameter to be loaded.") args = parser.parse_args() # yapf: enable diff --git a/PaddleNLP/examples/text_matching/simnet/train.py b/PaddleNLP/examples/text_matching/simnet/train.py index a17e1df0..df308ab8 100644 --- a/PaddleNLP/examples/text_matching/simnet/train.py +++ b/PaddleNLP/examples/text_matching/simnet/train.py @@ -26,13 +26,13 @@ from utils import load_vocab, generate_batch, convert_example # yapf: disable parser = argparse.ArgumentParser(__doc__) -parser.add_argument("--epochs", type=int, default=3, help="Number of epoches for training.") -parser.add_argument('--use_gpu', type=eval, default=True, help="Whether use GPU for training, input should be True or False") +parser.add_argument("--epochs", type=int, default=10, help="Number of epoches for training.") +parser.add_argument('--use_gpu', type=eval, default=False, help="Whether use GPU for training, input should be True or False") parser.add_argument("--lr", type=float, default=5e-4, help="Learning rate used to train.") parser.add_argument("--save_dir", type=str, default='chekpoints/', help="Directory to save model checkpoint") parser.add_argument("--batch_size", type=int, default=64, help="Total examples' number of a batch for training.") -parser.add_argument("--vocab_path", type=str, default="./data/term2id.dict", help="The directory to dataset.") -parser.add_argument('--network', type=str, default="cnn", help="Which network you would like to choose bow, lstm, bilstm, gru, bigru, rnn, birnn, bilstm_attn and textcnn?") 
+parser.add_argument("--vocab_path", type=str, default="./simnet_word_dict.txt", help="The directory to dataset.") +parser.add_argument('--network', type=str, default="lstm", help="Which network you would like to choose bow, cnn, lstm or gru ?") parser.add_argument("--init_from_ckpt", type=str, default=None, help="The path of checkpoint to be loaded.") args = parser.parse_args() # yapf: enable -- GitLab