README.md 11.3 KB
Newer Older
S
Steffy-zxf 已提交
1
# 使用预训练模型Fine-tune完成中文文本匹配任务
2 3 4 5 6 7 8 9 10

随着深度学习的发展,模型参数的数量飞速增长。为了训练这些参数,需要更大的数据集来避免过拟合。然而,对于大部分NLP任务来说,构建大规模的标注数据集非常困难(成本过高),特别是对于句法和语义相关的任务。相比之下,大规模的未标注语料库的构建则相对容易。为了利用这些数据,我们可以先从其中学习到一个好的表示,再将这些表示应用到其他任务中。最近的研究表明,基于大规模未标注语料库的预训练模型(Pretrained Models, PTM) 在NLP任务上取得了很好的表现。

近年来,大量的研究表明基于大型语料库的预训练模型(Pretrained Models, PTM)可以学习通用的语言表示,有利于下游NLP任务,同时能够避免从零开始训练模型。随着计算能力的发展,深度模型的出现(即 Transformer)和训练技巧的增强使得 PTM 不断发展,由浅变深。

百度的预训练模型ERNIE经过海量的数据训练后,其特征抽取的工作已经做的非常好。借鉴迁移学习的思想,我们可以利用其在海量数据中学习的语义信息辅助小数据集(如本示例中的医疗文本数据集)上的任务。

<center> <img width="600px" src="https://ai-studio-static-online.cdn.bcebos.com/d96c602338044ee8bcd4171f38ea6d49506d1f3253f3496b802ec56cb654ecf5" /> </center>

S
Steffy-zxf 已提交
11
使用预训练模型ERNIE完成文本匹配任务,大家可能会想到将query和title文本拼接,之后输入ERNIE中,取`CLS`特征(pooled_output),之后输出全连接层,进行二分类。如下图ERNIE用于句对分类任务的用法:
12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41

<p align="center">
<img src="https://ai-studio-static-online.cdn.bcebos.com/45440029c07240ad89d665c5b176e63297e9584e1da24e02b79dd54fb990f74a" width='30%'/> <br />
</p>

然而,以上用法的问题在于,**ERNIE的模型参数非常庞大,导致计算量非常大,预测的速度也不够理想**。从而达不到线上业务的要求。针对该问题,可以使用PaddleNLP工具搭建Sentence Transformer网络。

<p align="center">
<img src="https://ai-studio-static-online.cdn.bcebos.com/103998703e134a7184883511a538620e16fed045e2614dcc8afacec446600438" width='30%'/> <br />
</p>

Sentence Transformer采用了双塔(Siamese)的网络结构。Query和Title分别输入ERNIE,共享一个ERNIE参数,得到各自的token embedding特征。之后对token embedding进行pooling(此处教程使用mean pooling操作),之后输出分别记作u,v。之后将三个表征(u,v,|u-v|)拼接起来,进行二分类。网络结构如上图所示。

更多关于Sentence Transformer的信息可以参考论文:https://arxiv.org/abs/1908.10084

**同时,不仅可以使用ERNIR作为文本语义特征提取器,可以利用BERT/RoBerta/Electra等模型作为文本语义特征提取器**

**那么Sentence Transformer采用Siamese的网路结构,是如何提升预测速度呢?**

**Siamese的网络结构好处在于query和title分别输入同一套网络。如在信息搜索任务中,此时就可以将数据库中的title文本提前计算好对应sequence_output特征,保存在数据库中。当用户搜索query时,只需计算query的sequence_output特征与保存在数据库中的title sequence_output特征,通过一个简单的mean_pooling和全连接层进行二分类即可。从而大幅提升预测效率,同时也保障了模型性能。**

关于匹配任务常用的Siamese网络结构可以参考:https://blog.csdn.net/thriving_fcl/article/details/73730552

PaddleNLP提供了丰富的预训练模型,并且可以便捷地获取PaddlePaddle生态下的所有预训练模型。下面展示如何使用PaddleNLP一键加载ERNIE,优化文本匹配任务。

## 模型简介

本项目针对中文文本匹配问题,开源了一系列模型,供用户可配置地使用:

+ BERT([Bidirectional Encoder Representations from Transformers](https://arxiv.org/abs/1810.04805))中文模型,简写`bert-base-chinese`, 其由12层Transformer网络组成。
S
Steffy-zxf 已提交
42
+ ERNIE([Enhanced Representation through Knowledge Integration](https://arxiv.org/abs/1904.09223)),支持ERNIE 1.0中文模型(简写`ernie-1.0`)和ERNIE Tiny中文模型(简写`ernie-tiny`)。
43
   其中`ernie`由12层Transformer网络组成,`ernie-tiny`由3层Transformer网络组成。
44
+ RoBERTa([A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692)),支持12层Transformer网络的`roberta-wwm-ext`
45 46 47

| 模型  | dev acc | test acc |
| ---- | ------- | -------- |
48 49
| bert-base-chinese  | 0.86537 | 0.84440 |
| bert-wwm-chinese | 0.86333 | 0.84128 |
S
Steffy-zxf 已提交
50
| bert-wwm-ext-chinese | 0.86049 | 0.83848 |
S
Steffy-zxf 已提交
51
| ernie-1.0  | 0.87480  | 0.84760  |
52
| ernie-tiny  | 0.86071 | 0.83352 |
S
Steffy-zxf 已提交
53 54 55
| roberta-wwm-ext  | 0.87526 | 0.84904 |
| rbt3 | 0.85367 | 0.83464 |
| rbtl3 | 0.85174 | 0.83744 |
56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97

## 快速开始

### 安装说明

* PaddlePaddle 安装

   本项目依赖于 PaddlePaddle 2.0 及以上版本,请参考 [安装指南](http://www.paddlepaddle.org/#quick-start) 进行安装

* PaddleNLP 安装

   ```shell
   pip install paddlenlp
   ```

* 环境依赖

   Python的版本要求 3.6+,其它环境请参考 PaddlePaddle [安装说明](https://www.paddlepaddle.org.cn/documentation/docs/zh/1.5/beginners_guide/install/index_cn.html) 部分的内容

### 代码结构说明

以下是本项目主要代码结构及说明:

```text
sentence_transformers/
├── model.py # Sentence Transfomer 组网文件
├── README.md # 文本说明
└── train.py # 模型训练评估
```

### 模型训练

我们以中文文本匹配公开数据集LCQMC为示例数据集,可以运行下面的命令,在训练集(train.tsv)上进行模型训练,并在开发集(dev.tsv)验证
```shell
# 设置使用的GPU卡号
CUDA_VISIBLE_DEVICES=0
python train.py --model_type ernie --model_name ernie-1.0 --n_gpu 1 --save_dir ./checkpoints
```

可支持配置的参数:

* `model_type`:必选,模型类型,可以选择bert,ernie,roberta。
98 99 100
* `model_name`: 必选,具体的模型简称。
`model_type=ernie`,则model_name可以选择`ernie-1.0``ernie-tiny`
`model_type=bert`,则model_name可以选择`bert-base-chinese``bert-wwm-chinese``bert-wwm-ext-chinese`
S
Steffy-zxf 已提交
101
`model_type=roberta`,则model_name可以选择`roberta-wwm-ext``rbt3``rbtl3`
102 103 104 105 106
* `save_dir`:必选,保存训练模型的目录。
* `max_seq_length`:可选,ERNIE/BERT模型使用的最大序列长度,最大不能超过512, 若出现显存不足,请适当调低这一参数;默认为128。
* `batch_size`:可选,批处理大小,请结合显存情况进行调整,若出现显存不足,请适当调低这一参数;默认为32。
* `learning_rate`:可选,Fine-tune的最大学习率;默认为5e-5。
* `weight_decay`:可选,控制正则项力度的参数,用于防止过拟合,默认为0.00。
107
* `epochs`: 训练轮次,默认为3。
108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127
* `warmup_proption`:可选,学习率warmup策略的比例,如果0.1,则学习率会在前10%训练step的过程中从0慢慢增长到learning_rate, 而后再缓慢衰减,默认为0.1。
* `init_from_ckpt`:可选,模型参数路径,热启动模型训练;默认为None。
* `seed`:可选,随机种子,默认为1000.
* `n_gpu`:可选,训练过程中使用GPU卡数量,默认为1。若n_gpu=0,则使用CPU训练。


程序运行时将会自动进行训练,评估,测试。同时训练过程中会自动保存模型在指定的`save_dir`中。
如:
```text
checkpoints/
├── model_100
│   ├── model_config.json
│   ├── model_state.pdparams
│   ├── tokenizer_config.json
│   └── vocab.txt
└── ...
```

**NOTE:**
* 如需恢复模型训练,则可以设置`init_from_ckpt`, 如`init_from_ckpt=checkpoints/model_100/model_state.pdparams`
128
* 如需使用ernie-tiny模型,则需要提前先安装sentencepiece依赖,如`pip install sentencepiece`
129 130 131 132 133 134

### 模型预测

启动预测:
```shell
CUDA_VISIBLE_DEVICES=0
135
python predict.py --model_type ernie --model_name ernie-tiny --params_path checkpoints/model_400/model_state.pdparams
136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197
```

将待预测数据如以下示例:

```text
世界上什么东西最小   世界上什么东西最小?
光眼睛大就好看吗  眼睛好看吗?
小蝌蚪找妈妈怎么样   小蝌蚪找妈妈是谁画的
```

可以直接调用`predict`函数即可输出预测结果。



```text
Data: ['世界上什么东西最小', '世界上什么东西最小?']      Label: similar
Data: ['光眼睛大就好看吗', '眼睛好看吗?']      Label: dissimilar
Data: ['小蝌蚪找妈妈怎么样', '小蝌蚪找妈妈是谁画的']      Label: dissimilar
```


## 引用

关于Sentence Transformer更多信息参考[www.SBERT.net](https://www.sbert.net)以及论文:
- [Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks](https://arxiv.org/abs/1908.10084) (EMNLP 2019)
- [Making Monolingual Sentence Embeddings Multilingual using Knowledge Distillation](https://arxiv.org/abs/2004.09813) (EMNLP 2020)
- [Augmented SBERT: Data Augmentation Method for Improving Bi-Encoders for Pairwise Sentence Scoring Tasks](https://arxiv.org/abs/2010.08240) (arXiv 2020)

```
@inproceedings{reimers-2019-sentence-bert,
    title = "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks",
    author = "Reimers, Nils and Gurevych, Iryna",
    booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing",
    month = "11",
    year = "2019",
    publisher = "Association for Computational Linguistics",
    url = "https://arxiv.org/abs/1908.10084",
}
```

```
@inproceedings{reimers-2020-multilingual-sentence-bert,
    title = "Making Monolingual Sentence Embeddings Multilingual using Knowledge Distillation",
    author = "Reimers, Nils and Gurevych, Iryna",
    booktitle = "Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing",
    month = "11",
    year = "2020",
    publisher = "Association for Computational Linguistics",
    url = "https://arxiv.org/abs/2004.09813",
}
```

```
@article{thakur-2020-AugSBERT,
    title = "Augmented SBERT: Data Augmentation Method for Improving Bi-Encoders for Pairwise Sentence Scoring Tasks",
    author = "Thakur, Nandan and Reimers, Nils and Daxenberger, Johannes and  Gurevych, Iryna",
    journal= "arXiv preprint arXiv:2010.08240",
    month = "10",
    year = "2020",
    url = "https://arxiv.org/abs/2010.08240",
}
```
S
Steffy-zxf 已提交
198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216


## 线上体验教程

- [使用seq2vec模块进行句子情感分类](https://aistudio.baidu.com/aistudio/projectdetail/1283423)

- [如何将预训练模型Fine-tune下游任务](https://aistudio.baidu.com/aistudio/projectdetail/1294333)

- [使用Bi-GRU+CRF完成快递单信息抽取](https://aistudio.baidu.com/aistudio/projectdetail/1317771)

- [使用预训练模型ERNIE优化快递单信息抽取](https://aistudio.baidu.com/aistudio/projectdetail/1329361)

- [使用Seq2Seq模型完成自动对联模型](https://aistudio.baidu.com/aistudio/projectdetail/1321118)

- [使用预训练模型ERNIE-GEN实现智能写诗](https://aistudio.baidu.com/aistudio/projectdetail/1339888)

- [使用TCN网络完成新冠疫情病例数预测](https://aistudio.baidu.com/aistudio/projectdetail/1290873)

更多教程参见[PaddleNLP on AI Studio](https://aistudio.baidu.com/aistudio/personalcenter/thirdview/574995)