README.md 17.8 KB
Newer Older
X
Xiaoyao Xi 已提交
1
# PaddlePALM
X
xixiaoyao 已提交
2

X
Xiaoyao Xi 已提交
3
PaddlePALM (PAddLe for Multi-task) 是一个强大快速、灵活易用的NLP多任务学习框架,用户仅需书写极少量代码即可完成复杂的多任务训练与推理。同时框架提供了定制化接口,若内置工具、主干网络和任务无法满足需求,开发者可以轻松完成相关组件的自定义。
X
xixiaoyao 已提交
4

X
Xiaoyao Xi 已提交
5
框架中内置了丰富的主干网络及其预训练模型(BERT、ERNIE等)、常见的任务范式(分类、匹配、机器阅读理解等)和相应的数据集读取与处理工具。相关列表见这里。
X
Xiaoyao Xi 已提交
6 7 8 9 10 11

## 安装

推荐使用pip安装paddlepalm框架:

```shell
X
xixiaoyao 已提交
12
pip install paddlepalm
X
Xiaoyao Xi 已提交
13 14 15 16 17 18
```

对于离线机器,可以使用基于源码的安装方式:

```shell
git clone https://github.com/PaddlePaddle/PALM.git
X
Xiaoyao Xi 已提交
19
cd PALM && python setup.py install
X
Xiaoyao Xi 已提交
20 21 22 23 24 25 26 27 28
```



**环境依赖**

- Python >= 2.7
- cuda >= 9.0
- cudnn >= 7.0
X
Xiaoyao Xi 已提交
29
- PaddlePaddle >= 1.6.1 (请参考[安装指南](http://www.paddlepaddle.org/#quick-start)进行安装)
X
Xiaoyao Xi 已提交
30 31 32 33 34 35 36 37 38 39



## 前期准备

#### 理论准备

框架默认采用一对多(One-to-Many)的参数共享方式,如图


X
Xiaoyao Xi 已提交
40
![image-20191022194400259](https://tva1.sinaimg.cn/large/006y8mN6ly1g88ajvpqmgj31hu07wn5s.jpg)
X
xixiaoyao 已提交
41 42


X
Xiaoyao Xi 已提交
43
例如我们有一个目标任务MRC和两个辅助任务MLM和MATCH,我们希望通过MLM和MATCH来提高目标任务MRC的测试集表现(即提升模型泛化能力),那么我们可以令三个任务共享相同的文本特征抽取模型(例如BERT、ERNIE等),然后分别为每个任务定义输出层,计算各自的loss值。
44

X
Xiaoyao Xi 已提交
45
框架默认采用任务采样+mini-batch采样的方式(alternating mini-batches optimization)进行模型训练,即对于每个训练step,首先对候选任务进行采样(采样权重取决于用户设置的`mix_ratio`),而后从该任务的训练集中采样出一个mini-batch(采样出的样本数取决于用户设置的`batch_size`)。
X
xixiaoyao 已提交
46

X
Xiaoyao Xi 已提交
47
#### 模型准备
X
xixiaoyao 已提交
48

X
Xiaoyao Xi 已提交
49
我们提供了BERT、ERNIE等主干模型及其相关预训练模型。为了加速模型收敛,获得更佳的测试集表现,我们强烈建议用户在预训练模型的基础上进行多任务学习(而不是从参数随机初始化开始)。用户可通过运行脚本`script/download_pretrain_models`下载需要的预训练模型,例如,下载预训练BERT模型的命令如下
X
xixiaoyao 已提交
50

X
Xiaoyao Xi 已提交
51
```shell
X
Xiaoyao Xi 已提交
52
bash script/download_pretrain_backbone.sh bert
X
Xiaoyao Xi 已提交
53
```
X
xixiaoyao 已提交
54

X
Xiaoyao Xi 已提交
55
脚本会自动在当前文件夹中创建一个pretrain_model目录,并在其中创建bert子目录,里面存放预训练模型(`params`文件夹内)、相关的网络参数(`bert_config.json`)和字典(`vocab.txt`)。
X
xixiaoyao 已提交
56

X
Xiaoyao Xi 已提交
57 58 59 60
然后通过运行`script/convert_params.sh`将预训练模型转换成框架可用的预训练backbone

```shell
bash script/convert_params.sh pretrain_model/bert/params
X
Xiaoyao Xi 已提交
61
```
X
xixiaoyao 已提交
62

X
Xiaoyao Xi 已提交
63
*注:未来框架将支持更多的预置主干网络,如XLNet、多层LSTM等。此外,用户可以自定义添加新的主干网络,详情见[这里]()*
X
xixiaoyao 已提交
64

X
Xiaoyao Xi 已提交
65
## DEMO1:单任务训练
X
Xiaoyao Xi 已提交
66 67 68 69 70 71 72 73 74 75 76 77

接下来我们启动一个复杂的机器阅读理解任务的训练,我们在`data/mrqa`文件夹中提供了EMNLP2019 MRQA机器阅读理解评测的部分比赛数据。

用户可通过运行如下脚本一键开始本节任务的训练

```shell
bash run_demo1.sh
```


下面以该任务为例,讲解如何基于paddlepalm框架轻松实现该任务。

X
Xiaoyao Xi 已提交
78 79 80
**1. 配置任务实例**

首先,我们编写该任务实例的配置文件`mrqa.yaml`,若该任务实例参与训练或预测,则框架将自动解析该配置文件并创建相应的任务实例。配置文件需符合yaml格式的要求。一个任务实例的配置文件最少应包含`train_file``reader``paradigm`这三个字段,分别代表训练集的文件路径`train_file`、使用的数据集载入与处理工具`reader`、任务范式`paradigm`
X
Xiaoyao Xi 已提交
81 82

```yaml
X
Xiaoyao Xi 已提交
83 84
train_file: data/mrqa/train.json
reader: mrc
X
Xiaoyao Xi 已提交
85 86 87
paradigm: mrc
```

X
Xiaoyao Xi 已提交
88
*注:框架内置的其他数据集载入与处理工具和任务范式列表见[这里]()*
X
Xiaoyao Xi 已提交
89 90 91 92 93 94 95 96

此外,我们还需要配置reader的预处理规则,各个预置reader支持的预处理配置和规则请参考【这里】。预处理规则同样直接写入`mrqa.yaml`中。

```yaml
max_seq_len: 512
max_query_len: 64
doc_stride: 128 # 在MRQA数据集中,存在较长的文档,因此我们这里使用滑动窗口处理样本,滑动步长设置为128
do_lower_case: True
X
Xiaoyao Xi 已提交
97
vocab_path: "pretrain_model/bert/vocab.txt"
X
Xiaoyao Xi 已提交
98 99
```

X
Xiaoyao Xi 已提交
100 101 102 103
更详细的任务实例配置方法可参考这里

**2.配置全局参数**

X
Xiaoyao Xi 已提交
104 105 106 107 108 109 110
然后我们配置全局的学习规则,同样使用yaml格式描述,我们新建`mtl_conf.yaml`。在这里我们配置一下需要学习的任务、模型的保存路径`save_path`和规则、使用的模型骨架`backbone`、学习器`optimizer`等。

```yaml
task_instance: "mrqa"

save_path: "output_model/firstrun"

X
Xiaoyao Xi 已提交
111 112
backbone: "bert"
backbone_config_path: "pretrain_model/bert/bert_config.json"
X
Xiaoyao Xi 已提交
113 114 115

optimizer: "adam"
learning_rate: 3e-5
X
Xiaoyao Xi 已提交
116
batch_size: 4
X
Xiaoyao Xi 已提交
117 118 119

num_epochs: 2                                                                                    
warmup_proportion: 0.1 
X
xixiaoyao 已提交
120
```
X
xixiaoyao 已提交
121

X
Xiaoyao Xi 已提交
122 123 124
*注:框架支持的其他backbone参数如日志打印控制等见[这里]()*

**3.开始训练**
X
xixiaoyao 已提交
125

X
Xiaoyao Xi 已提交
126
下面我们开始尝试启动MRQA任务的训练(该代码位于`demo1.py`中)。框架的核心组件是`Controller`
X
xixiaoyao 已提交
127 128

```python
X
Xiaoyao Xi 已提交
129
# Demo 1: single task training of MRQA 
X
xixiaoyao 已提交
130
import paddlepalm as palm
X
xixiaoyao 已提交
131

X
xixiaoyao 已提交
132
if __name__ == '__main__':
X
Xiaoyao Xi 已提交
133 134
    controller = palm.Controller('config_demo1.yaml', task_dir='demo1_tasks')
    controller.load_pretrain('pretrain_model/bert/params')
X
xixiaoyao 已提交
135
    controller.train()
X
xixiaoyao 已提交
136 137
```

X
Xiaoyao Xi 已提交
138
默认情况下每5个训练step打印一次训练日志,如下(该日志在8卡P40机器上运行得到),可以看到loss值随着训练收敛。在训练结束后,`Controller`自动为mrqa任务保存预测模型。
X
Xiaoyao Xi 已提交
139

X
Xiaoyao Xi 已提交
140 141 142 143 144 145 146 147 148
```
Global step: 5. Task: mrqa, step 5/13 (epoch 0), loss: 4.976, speed: 0.11 steps/s
Global step: 10. Task: mrqa, step 10/13 (epoch 0), loss: 2.938, speed: 0.48 steps/s
Global step: 15. Task: mrqa, step 2/13 (epoch 1), loss: 2.422, speed: 0.47 steps/s
Global step: 20. Task: mrqa, step 7/13 (epoch 1), loss: 2.809, speed: 0.53 steps/s
Global step: 25. Task: mrqa, step 12/13 (epoch 1), loss: 1.744, speed: 0.50 steps/s
mrqa: train finished!
mrqa: inference model saved at output_model/firstrun/infer_model
```
X
xixiaoyao 已提交
149

X
Xiaoyao Xi 已提交
150
## DEMO2:多任务训练与推理
X
Xiaoyao Xi 已提交
151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168

本节我们考虑更加复杂的学习目标,我们引入一个问答匹配(QA Matching)任务来辅助MRQA任务的学习。在多任务训练结束后,我们希望使用训练好的模型来对MRQA任务的测试集进行预测。

用户可通过运行如下脚本直接开始本节任务的训练

```shell
bash run_demo2.sh
```



下面以该任务为例,讲解如何基于paddlepalm框架轻松实现这个复杂的多任务学习。

**1. 配置任务实例**

首先,我们像上一节一样为Matching任务分别配置任务实例`match4mrqa.yaml`

```yaml
X
Xiaoyao Xi 已提交
169 170
train_file: "data/match/train.tsv"
reader: match
X
Xiaoyao Xi 已提交
171 172 173
paradigm: match
```

X
Xiaoyao Xi 已提交
174 175
*注:更详细的任务实例配置方法可参考[这里]()*

X
Xiaoyao Xi 已提交
176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192
**2.配置全局参数**

由于MRQA和Matching任务有相同的字典、大小写配置、截断长度等,因此我们可以将这些各个任务中相同的参数写入到全局配置文件`mtl_config.yaml`中,**框架会自动将该文件中的配置广播(broadcast)到各个任务实例。**

```yaml
task_instance: "mrqa, match4mrqa"
target_tag: 1,0 

save_path: "output_model/secondrun"

backbone: "ernie"
backbone_config_path: "pretrain_model/ernie/ernie_config.json"

vocab_path: "pretrain_model/ernie/vocab.txt"
do_lower_case: True
max_seq_len: 512 # 写入全局配置文件的参数会被自动广播到各个任务实例

X
Xiaoyao Xi 已提交
193
batch_size: 4
X
Xiaoyao Xi 已提交
194 195 196 197 198 199 200
num_epochs: 2
optimizer: "adam"
learning_rate: 3e-5
warmup_proportion: 0.1 
weight_decay: 0.1 
```

X
Xiaoyao Xi 已提交
201 202
这里我们可以使用`target_tag`来标记目标任务和辅助任务,各个任务的tag使用逗号`,`隔开。target_tag与task_instance中的元素一一对应,当某任务的tag设置为1时,表示对应的任务被设置为目标任务;设置为0时,表示对应的任务被设置为辅助任务,默认情况下所以任务均被设置为目标任务(即默认`target_tag`为全1)。

X
Xiaoyao Xi 已提交
203
辅助任务不会保存预测模型,且不会影响训练的终止。当所有的目标任务达到预期的训练步数后多任务学习终止,框架自动为每个目标任务保存预测模型(inference model)到设置的`save_path`位置。
X
Xiaoyao Xi 已提交
204 205 206 207 208

在训练过程中,默认每个训练step会从各个任务等概率采样,来决定当前step训练哪个任务。若用户希望改变采样比率,可以通过`mix_ratio`字段来进行设置,例如

```yaml
mix_ratio: 1.0, 0.5
X
xixiaoyao 已提交
209 210
```

X
Xiaoyao Xi 已提交
211
若将如上设置加入到全局配置文件中,则辅助任务`match4mrqa`的采样概率仅为`mrqa`任务的一半。
X
xixiaoyao 已提交
212

X
Xiaoyao Xi 已提交
213
这里`num_epochs`指代目标任务`mrqa`的训练epoch数量(训练集遍历次数),**当目标任务有多个时,该参数将作用于第一个出现的目标任务(称为主任务,main task)**
X
xixiaoyao 已提交
214

X
Xiaoyao Xi 已提交
215
**3.开始多任务训练**
X
xixiaoyao 已提交
216

X
Xiaoyao Xi 已提交
217 218 219 220
```python
import paddlepalm as palm

if __name__ == '__main__':
X
Xiaoyao Xi 已提交
221
    controller = palm.Controller('config_demo2.yaml', task_dir='demo2_tasks')
X
Xiaoyao Xi 已提交
222 223
    controller.load_pretrain('pretrain_model/ernie/params')
    controller.train()
X
Xiaoyao Xi 已提交
224

X
Xiaoyao Xi 已提交
225
```
X
Xiaoyao Xi 已提交
226

X
Xiaoyao Xi 已提交
227
**4.预测**
X
Xiaoyao Xi 已提交
228

X
Xiaoyao Xi 已提交
229 230 231 232
在得到目标任务的预测模型(inference_model)后,我们可以加载预测模型对该任务的测试集进行预测。在多任务训练阶段,在全局配置文件的`save_path`指定的路径下会为每个目标任务创建同名子目录,子目录中都有预测模型文件夹`infermodel`。我们可以将该路径传给框架的`controller`来完成对该目标任务的预测。

例如,我们在上一节得到了mrqa任务的预测模型。首先创建一个新的*Controller***并且创建时要将`for_train`标志位置为*False***。而后调用*pred*接口,将要预测的任务实例名字和预测模型的路径传入,即可完成相关预测。预测的结果默认保存在任务实例配置文件的`pred_output_path`指定的路径中。代码段如下:

X
Xiaoyao Xi 已提交
233 234
```python
    controller = palm.Controller(config='demo2_config.yaml', task_dir='demo2_tasks', for_train=False)
X
Xiaoyao Xi 已提交
235
    controller.pred('mrqa', inference_model_dir='output_model/secondrun/mrqa/infermodel') 
X
Xiaoyao Xi 已提交
236
```
X
xixiaoyao 已提交
237

X
Xiaoyao Xi 已提交
238

X
Xiaoyao Xi 已提交
239 240
## 进阶篇

X
Xiaoyao Xi 已提交
241 242 243 244 245 246 247 248 249 250 251
### 设置多个目标任务

框架内支持设定多个目标任务,当全局配置文件的`task_instance`字段指定超过一个任务实例时,这多个任务实例默认均为目标任务(即`target_tag`字段被自动填充为全1)。对于被设置成目标任务的任务实例,框架会为其计算预期的训练步数(详情见下一节)并在达到预期训练步数后为其保存预测模型。

当框架存在多个目标任务时,全局配置文件中的`num_epochs`(训练集遍历次数)仅会作用于第一个出现的目标任务,称为主任务(main task)。框架会根据主任务的训练步数来推理其他目标任务的预期训练步数(可通过`mix_ratio`控制)。**注意,除了用来标记`num_epochs`的作用对象外,主任务与其他目标任务没有任何不同。**例如

```yaml
task_instance: domain_cls, mrqa, senti_cls, mlm, qq_match
target_tag: 0, 1, 1, 0, 1
```
在上述的设置中,mrqa,senti_cls和qq_match这三个任务被标记成了目标任务(其中mrqa为主任务),domain_cls和mlm被标记为了辅助任务。辅助任务仅仅“陪同”目标任务训练,框架不会为其保存预测模型(inference_model),也不会计算预期训练步数。但包括辅助任务在内,各个任务的采样概率是可以被控制的,详情见下一小节。
X
Xiaoyao Xi 已提交
252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269

### 更改任务采样概率(期望的训练步数)

在默认情况下,每个训练step的各个任务被采样到的概率均等,若用户希望更改其中某些任务的采样概率(比如某些任务的训练集较小,希望减少对其采样的次数;或某些任务较难,希望被更多的训练),可以在全局配置文件中通过`mix_ratio`字段控制各个任务的采样概率。例如

```yaml
task_instance: mrqa, match4mrqa, mlm4mrqa
mix_ratio: 1.0, 0.5, 0.5
```

上述设置表示`match4mrqa``mlm4mrqa`任务的期望被采样次数均为`mrqa`任务的一半。此时,在mrqa任务被设置为主任务的情况下(第一个目标任务即为主任务),若mrqa任务训练一个epoch要经历5000 steps,且全局配置文件中设置了num_epochs为2,则根据上述`mix_ratio`的设置,mrqa任务将被训练5000\*2\*1.0=10000个steps,而`match4mrqa`任务和`mlm4mrqa`任务都会被训练5000个steps**左右**

> 注意:若match4mrqa, mlm4mrqa被设置为辅助任务,则实际训练步数可能略多或略少于5000个steps。对于目标任务,则是精确的5000 steps。

### 共享任务层参数

### 分布式训练

X
Xiaoyao Xi 已提交
270
## 内置数据集载入与处理工具(reader)
X
Xiaoyao Xi 已提交
271

X
Xiaoyao Xi 已提交
272 273 274 275 276 277 278 279
所有的内置reader均同时支持中英文输入数据,**默认读取的数据为英文数据**,希望读入中文数据时,需在配置文件中设置

```yaml
for_cn: True
```

所有的内置reader,均支持以下字段

X
Xiaoyao Xi 已提交
280 281 282 283 284 285 286 287 288 289 290 291 292
```yaml
- vocab_path(REQUIRED): str类型。字典文件路径。
- max_seq_len(REQUIRED): int类型。切词后的序列最大长度(即token ids的最大长度)。注意经过分词后,token ids的数量往往多于原始的单词数(e.g., 使用wordpiece tokenizer时)。
- batch_size(REQUIRED): int类型。训练或预测时的批大小(每个step喂入神经网络的样本数)。
- train_file(REQUIRED): str类型。训练集文件所在路径。仅进行预测时,该字段可不设置。
- pred_file(REQUIRED): str类型。测试集文件所在路径。仅进行训练时,该字段可不设置。

- do_lower_case(OPTIONAL): bool类型,默认为False。是否将大写英文字母转换成小写。
- shuffle(OPTIONAL): bool类型,默认为True。训练阶段打乱数据集样本的标志位,当置为True时,对数据集的样本进行全局打乱。注意,该标志位的设置不会影响预测阶段(预测阶段不会shuffle数据集)。
- seed(OPTIONAL): int类型,默认为。
- pred_batch_size(OPTIONAL): int类型。预测阶段的批大小,当该参数未设置时,预测阶段的批大小取决于`batch_size`字段的值。
- print_first_n(OPTIONAL): int类型。打印数据集的前n条样本和对应的reader输出,默认为0。
```
X
Xiaoyao Xi 已提交
293

X
Xiaoyao Xi 已提交
294
### 文本分类数据集reader工具:cls
X
Xiaoyao Xi 已提交
295

X
Xiaoyao Xi 已提交
296 297 298 299 300 301 302 303 304 305 306 307
该reader完成文本分类数据集的载入与处理,reader接受[tsv格式](https://en.wikipedia.org/wiki/Tab-separated_values)的数据集输入,数据集应该包含两列,一列为样本标签`label`,一列为原始文本`text_a`。形如

```
label   text_a                                                                                   
1   when was the last time the san antonio spurs missed the playoffshave only missed the playoffs four times since entering the NBA
0   the creation of the federal reserve system was an attempt toReserve System ( also known as the Federal Reserve or simply the Fed ) is the central banking system of the United States of America . 
2   group f / 64 was a major backlash against the earlier photographic movement off / 64 was formed , Edward Weston went to a meeting of the John Reed Club , which was founded to support Marxist artists and writers . 
0   Bessarabia eventually became under the control of which country?
```
***注意:数据集的第一列必须为header,即标注每一列的列名***

该reader额外包含以下配置字段
X
Xiaoyao Xi 已提交
308

X
Xiaoyao Xi 已提交
309 310 311
```yaml
- n_classes(REQUIRED): int类型。分类任务的类别数。
```
X
Xiaoyao Xi 已提交
312

X
Xiaoyao Xi 已提交
313
### 文本匹配数据集reader工具:match
X
Xiaoyao Xi 已提交
314 315 316

该reader完成文本匹配数据集的载入与处理,reader接受[tsv格式](https://en.wikipedia.org/wiki/Tab-separated_values)的数据集输入,数据集应该包含三列,一列为样本标签`label`,其余两列分别为待匹配的文本`text_a`和文本`text_b`,形如

X
Xiaoyao Xi 已提交
317 318 319 320 321 322 323
```yaml
label   text_a  text_b                                                                           
1   From what work of Durkheim's was interaction ritual theory derived? **[TAB]** Subsequent to these developments, Randall Collins (2004) formulated his interaction ritual theory by drawing on Durkheim's work on totemic rituals that was extended by Goffman (1964/2013; 1967) into everyday focused encounters. Based on interaction ritual theory, we experience different levels
0   where is port au prince located in haiti **[TAB]** Its population is difficult to ascertain due to the rapid growth of slums in the hillsides
0   What is the world’s first-ever pilsner type blond lager, the company also awarded the Master Homebrewer Competition held in San Francisco to an award-winning brewer who won the prestigious American Homebrewers Associations' Homebrewer of the Year award in 2013? **[TAB]** of the Year award in 2013, becoming the first woman in thirty years, and the first African American person ever to ever win the award.
1   What has Pakistan told phone companies? **[TAB]** Islamabad, Pakistan (CNN) -- Under heavy criticism for a telling cell phone carriers to ban certain words in text messages, the Pakistan Telecommunication Authority went into damage control mode Wednesday.
```
X
Xiaoyao Xi 已提交
324 325

***注意:数据集的第一列必须为header,即标注每一列的列名***
X
Xiaoyao Xi 已提交
326

X
Xiaoyao Xi 已提交
327 328 329
### 机器阅读理解数据集reader工具:mrc

### 掩码语言模型数据集reader工具:mlm
X
Xiaoyao Xi 已提交
330 331 332

## 内置主干网络(backbone)

X
Xiaoyao Xi 已提交
333 334 335 336
### BERT

### ERNIE

X
Xiaoyao Xi 已提交
337
## 内置任务范式(paradigm)
X
Xiaoyao Xi 已提交
338

X
Xiaoyao Xi 已提交
339 340 341 342 343 344 345 346
### 分类任务

### 匹配任务

### 机器阅读理解任务

### 掩码语言模型任务

X
Xiaoyao Xi 已提交
347 348 349 350 351 352 353
## License

This tutorial is contributed by [PaddlePaddle](https://github.com/PaddlePaddle/Paddle) and licensed under the [Apache-2.0 license](https://github.com/PaddlePaddle/models/blob/develop/LICENSE).

## 许可证书

此向导由[PaddlePaddle](https://github.com/PaddlePaddle/Paddle)贡献,受[Apache-2.0 license](https://github.com/PaddlePaddle/models/blob/develop/LICENSE)许可认证。
X
xixiaoyao 已提交
354 355