未验证 提交 880321cc 编写于 作者: X Xiaoyao Xi 提交者: GitHub

Update README.md

上级 bc504176
...@@ -138,12 +138,14 @@ bash script/recover_params.sh pretrain_model/bert/params ...@@ -138,12 +138,14 @@ bash script/recover_params.sh pretrain_model/bert/params
### DEMO1:单任务训练 ### DEMO1:单任务训练
> 本demo路径位于`demo/demo1`
框架支持对任何一个内置任务进行传统的单任务训练。接下来我们启动一个复杂的机器阅读理解任务的训练,我们在`data/mrqa`文件夹中提供了[EMNLP2019 MRQA机器阅读理解评测](https://mrqa.github.io/shared)的部分比赛数据。下面我们利用该数据尝试完成一个基于BERT的机器阅读理解任务MRQA的单任务学习。 框架支持对任何一个内置任务进行传统的单任务训练。接下来我们启动一个复杂的机器阅读理解任务的训练,我们在`data/mrqa`文件夹中提供了[EMNLP2019 MRQA机器阅读理解评测](https://mrqa.github.io/shared)的部分比赛数据。下面我们利用该数据尝试完成一个基于BERT的机器阅读理解任务MRQA的单任务学习。
用户可通过运行如下脚本一键开始本节任务的训练 用户进入本demo目录后,可通过运行如下脚本一键开始本节任务的训练
```shell ```shell
bash run_demo1.sh bash run.sh
``` ```
下面以该任务为例,讲解如何基于paddlepalm框架轻松实现该任务。 下面以该任务为例,讲解如何基于paddlepalm框架轻松实现该任务。
...@@ -167,7 +169,7 @@ max_seq_len: 512 ...@@ -167,7 +169,7 @@ max_seq_len: 512
max_query_len: 64 max_query_len: 64
doc_stride: 128 # 在MRQA数据集中,存在较长的文档,因此我们这里使用滑动窗口处理样本,滑动步长设置为128 doc_stride: 128 # 在MRQA数据集中,存在较长的文档,因此我们这里使用滑动窗口处理样本,滑动步长设置为128
do_lower_case: True do_lower_case: True
vocab_path: "pretrain_model/bert/vocab.txt" vocab_path: "../../pretrain_model/bert/vocab.txt"
``` ```
更详细的任务实例配置方法(为任务实例选择合适的reader、paradigm和backbone)可参考[这里](#readerbackbone与paradigm的选择) 更详细的任务实例配置方法(为任务实例选择合适的reader、paradigm和backbone)可参考[这里](#readerbackbone与paradigm的选择)
...@@ -182,7 +184,7 @@ task_instance: "mrqa" ...@@ -182,7 +184,7 @@ task_instance: "mrqa"
save_path: "output_model/firstrun" save_path: "output_model/firstrun"
backbone: "bert" backbone: "bert"
backbone_config_path: "pretrain_model/bert/bert_config.json" backbone_config_path: "../../pretrain_model/bert/bert_config.json"
optimizer: "adam" optimizer: "adam"
learning_rate: 3e-5 learning_rate: 3e-5
...@@ -207,8 +209,8 @@ warmup_proportion: 0.1 ...@@ -207,8 +209,8 @@ warmup_proportion: 0.1
import paddlepalm as palm import paddlepalm as palm
if __name__ == '__main__': if __name__ == '__main__':
controller = palm.Controller('config_demo1.yaml', task_dir='demo1_tasks') controller = palm.Controller('config.yaml')
controller.load_pretrain('pretrain_model/bert/params') controller.load_pretrain('../../pretrain_model/bert/params')
controller.train() controller.train()
``` ```
...@@ -228,19 +230,21 @@ mrqa: inference model saved at output_model/firstrun/mrqa/infer_model ...@@ -228,19 +230,21 @@ mrqa: inference model saved at output_model/firstrun/mrqa/infer_model
### DEMO2:多任务辅助训练与目标任务预测 ### DEMO2:多任务辅助训练与目标任务预测
> 本demo路径位于`demo/demo2`
本节我们考虑更加复杂的学习目标,我们引入一个掩码语言模型(Mask Language Model,MLM)问答匹配(QA Match)任务来辅助上一节MRQA任务的训练,相关训练数据分别位于`data/mlm4mrqa``data/match4mrqa`。并且我们这里换用ERNIE模型作为主干网络,来获得更佳的效果。在多任务训练结束后,我们使用训练好的模型来对MRQA任务的测试集进行预测。 本节我们考虑更加复杂的学习目标,我们引入一个掩码语言模型(Mask Language Model,MLM)问答匹配(QA Match)任务来辅助上一节MRQA任务的训练,相关训练数据分别位于`data/mlm4mrqa``data/match4mrqa`。并且我们这里换用ERNIE模型作为主干网络,来获得更佳的效果。在多任务训练结束后,我们使用训练好的模型来对MRQA任务的测试集进行预测。
用户可通过运行如下脚本直接开始本节任务的训练 用户可通过运行如下脚本直接开始本节任务的训练
```shell ```shell
bash run_demo2.sh bash run.sh
``` ```
下面以该任务为例,讲解如何基于paddlepalm框架轻松实现这个复杂的多任务学习。 下面以该任务为例,讲解如何基于paddlepalm框架轻松实现这个复杂的多任务学习。
**1. 配置任务实例** **1. 配置任务实例**
首先,我们像上一节一样为MLM任务和Matching任务分别创建任务实例的配置文件`mlm4mrqa.yaml``match4mrqa.yaml` 首先,我们像上一节一样为MLM任务和Matching任务分别创建任务实例的配置文件`mlm4mrqa.yaml``match4mrqa.yaml`,并将两个文件放入`task`文件夹中
```yaml ```yaml
----- mlm4mrqa.yaml ----- ----- mlm4mrqa.yaml -----
...@@ -273,9 +277,9 @@ target_tag: 1,0,0 ...@@ -273,9 +277,9 @@ target_tag: 1,0,0
save_path: "output_model/secondrun" save_path: "output_model/secondrun"
backbone: "ernie" backbone: "ernie"
backbone_config_path: "pretrain_model/ernie/ernie_config.json" backbone_config_path: "../../pretrain_model/ernie/ernie_config.json"
vocab_path: "pretrain_model/ernie/vocab.txt" vocab_path: "../../pretrain_model/ernie/vocab.txt"
do_lower_case: True do_lower_case: True
max_seq_len: 512 # 写入全局配置文件的参数会被自动广播到各个任务实例 max_seq_len: 512 # 写入全局配置文件的参数会被自动广播到各个任务实例
...@@ -309,8 +313,8 @@ mix_ratio: 1.0, 0.5, 0.5 ...@@ -309,8 +313,8 @@ mix_ratio: 1.0, 0.5, 0.5
import paddlepalm as palm import paddlepalm as palm
if __name__ == '__main__': if __name__ == '__main__':
controller = palm.Controller('config_demo2.yaml', task_dir='demo2_tasks') controller = palm.Controller('config.yaml', task_dir='tasks')
controller.load_pretrain('pretrain_model/ernie/params') controller.load_pretrain('../../pretrain_model/ernie/params')
controller.train() controller.train()
``` ```
...@@ -340,7 +344,7 @@ mrqa: inference model saved at output_model/secondrun/mrqa/infer_model ...@@ -340,7 +344,7 @@ mrqa: inference model saved at output_model/secondrun/mrqa/infer_model
例如,我们在上一节得到了mrqa任务的预测模型。首先创建一个新的*Controller***并且创建时要将`for_train`标志位置为*False***。而后调用*pred*接口,将要预测的任务实例名字和预测模型的路径传入,即可完成相关预测。预测的结果默认保存在任务实例配置文件的`pred_output_path`指定的路径中。代码段如下: 例如,我们在上一节得到了mrqa任务的预测模型。首先创建一个新的*Controller***并且创建时要将`for_train`标志位置为*False***。而后调用*pred*接口,将要预测的任务实例名字和预测模型的路径传入,即可完成相关预测。预测的结果默认保存在任务实例配置文件的`pred_output_path`指定的路径中。代码段如下:
```python ```python
controller = palm.Controller(config='config_demo2.yaml', task_dir='demo2_tasks', for_train=False) controller = palm.Controller(config='config.yaml', task_dir='tasks', for_train=False)
controller.pred('mrqa', inference_model_dir='output_model/secondrun/mrqa/infermodel') controller.pred('mrqa', inference_model_dir='output_model/secondrun/mrqa/infermodel')
``` ```
...@@ -359,6 +363,8 @@ mrqa: inference model saved at output_model/secondrun/mrqa/infer_model ...@@ -359,6 +363,8 @@ mrqa: inference model saved at output_model/secondrun/mrqa/infer_model
### DEMO3:多目标任务联合训练与任务层参数复用 ### DEMO3:多目标任务联合训练与任务层参数复用
> 本demo路径位于`demo/demo3`
本节我们考虑一个更加复杂的大规模多任务学习场景。假如手头有若干任务,其中每个任务都可能将来被用于预测(即均为目标任务),且鉴于这若干个任务之间存在一些相关性,我们希望将其中一部分任务的任务层参数也进行复用。分类数据集位于`data/cls4mrqa`内。 本节我们考虑一个更加复杂的大规模多任务学习场景。假如手头有若干任务,其中每个任务都可能将来被用于预测(即均为目标任务),且鉴于这若干个任务之间存在一些相关性,我们希望将其中一部分任务的任务层参数也进行复用。分类数据集位于`data/cls4mrqa`内。
具体来说,例如我们有6个分类任务(CLS1 ~ CLS6),均为目标任务(每个任务的模型都希望未来拿来做预测和部署),且我们希望任务1,2,5的任务输出层共享同一份参数,任务3、4共享同一份参数,任务6自己一份参数,即希望对6个任务实现如图所示的参数复用关系。 具体来说,例如我们有6个分类任务(CLS1 ~ CLS6),均为目标任务(每个任务的模型都希望未来拿来做预测和部署),且我们希望任务1,2,5的任务输出层共享同一份参数,任务3、4共享同一份参数,任务6自己一份参数,即希望对6个任务实现如图所示的参数复用关系。
...@@ -370,7 +376,7 @@ mrqa: inference model saved at output_model/secondrun/mrqa/infer_model ...@@ -370,7 +376,7 @@ mrqa: inference model saved at output_model/secondrun/mrqa/infer_model
用户可通过运行如下脚本一键开始学习本节任务目标: 用户可通过运行如下脚本一键开始学习本节任务目标:
```shell ```shell
bash run_demo3.sh bash run.sh
``` ```
**1. 配置任务实例** **1. 配置任务实例**
...@@ -400,9 +406,9 @@ task_reuse_tag: 0, 0, 1, 1, 0, 2 ...@@ -400,9 +406,9 @@ task_reuse_tag: 0, 0, 1, 1, 0, 2
save_path: "output_model/secondrun" save_path: "output_model/secondrun"
backbone: "ernie" backbone: "ernie"
backbone_config_path: "pretrain_model/ernie/ernie_config.json" backbone_config_path: "../../pretrain_model/ernie/ernie_config.json"
vocab_path: "pretrain_model/ernie/vocab.txt" vocab_path: "../../pretrain_model/ernie/vocab.txt"
do_lower_case: True do_lower_case: True
max_seq_len: 512 # 写入全局配置文件的参数会被自动广播到各个任务实例 max_seq_len: 512 # 写入全局配置文件的参数会被自动广播到各个任务实例
...@@ -421,8 +427,8 @@ weight_decay: 0.1 ...@@ -421,8 +427,8 @@ weight_decay: 0.1
import paddlepalm as palm import paddlepalm as palm
if __name__ == '__main__': if __name__ == '__main__':
controller = palm.Controller('config_demo3.yaml', task_dir='demo3_tasks') controller = palm.Controller('config.yaml', task_dir='tasks')
controller.load_pretrain('pretrain_model/ernie/params') controller.load_pretrain('../../pretrain_model/ernie/params')
controller.train() controller.train()
``` ```
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册