README.md 2.8 KB
Newer Older
X
xixiaoyao 已提交
1

X
xixiaoyao 已提交
2
# 多任务学习框架PaddlePALM
X
xixiaoyao 已提交
3

X
xixiaoyao 已提交
4 5
# 安装
pip install paddlepalm
X
xixiaoyao 已提交
6

X
xixiaoyao 已提交
7
# 使用
X
xixiaoyao 已提交
8

X
xixiaoyao 已提交
9
### 1. 创建任务实例
10

X
xixiaoyao 已提交
11
使用yaml格式描述任务实例,每个任务实例中的必选字段包括
X
xixiaoyao 已提交
12

X
xixiaoyao 已提交
13 14 15 16
- train_file: 训练集文件路径
- reader: 数据集载入与处理工具名,框架预置reader列表见[这里](https://www.baidu.com/)
- backbone: 骨架模型名,框架预置reader列表见[这里](https://www.baidu.com/)
- paradigm: 任务范式(类型)名,框架预置paradigm列表见[这里](https://www.baidu.com/)
X
xixiaoyao 已提交
17

X
xixiaoyao 已提交
18
### 2. 完成训练配置
X
xixiaoyao 已提交
19

X
xixiaoyao 已提交
20
使用yaml格式完成配置多任务学习中的相关参数,如指定任务实例及其相关的主辅关系、参数复用关系、采样权重等
X
xixiaoyao 已提交
21

X
xixiaoyao 已提交
22
### 3. 开始训练
X
xixiaoyao 已提交
23

X
xixiaoyao 已提交
24
```python
X
xixiaoyao 已提交
25

X
xixiaoyao 已提交
26
import paddlepalm as palm
X
xixiaoyao 已提交
27

X
xixiaoyao 已提交
28 29 30 31 32
if __name__ == '__main__':
    controller = palm.Controller('config.yaml', task_dir='task_instance')
    controller.load_pretrain('pretrain_model/ernie/params')
    controller.train()
```
X
xixiaoyao 已提交
33

X
xixiaoyao 已提交
34
### 4. 预测
X
xixiaoyao 已提交
35

X
xixiaoyao 已提交
36
用户可在训练结束后直接调用pred接口对某个目标任务进行预测
X
xixiaoyao 已提交
37

X
xixiaoyao 已提交
38
示例:
X
xixiaoyao 已提交
39
```python
X
xixiaoyao 已提交
40
import paddlepalm as palm
X
xixiaoyao 已提交
41

X
xixiaoyao 已提交
42 43 44 45 46
if __name__ == '__main__':
    controller = palm.Controller(config_path='config.yaml', task_dir='task_instance')
    controller.load_pretrain('pretrain_model/ernie/params')
    controller.train()
    controller.pred('mrqa')
X
xixiaoyao 已提交
47 48
```

X
xixiaoyao 已提交
49
也可新建controller直接预测
X
Xiaoyao Xi 已提交
50

X
xixiaoyao 已提交
51
```python
X
xixiaoyao 已提交
52 53 54 55 56
import paddlepalm as palm

if __name__ == '__main__':
    controller = palm.Controller(config_path='config.yaml', task_dir='task_instance')
    controller.pred('mrqa', infermodel_path='output_model/firstrun2/infer_model')
X
xixiaoyao 已提交
57 58
```

X
xixiaoyao 已提交
59

X
xixiaoyao 已提交
60 61 62 63
# 运行机制

### 多任务学习机制
pass 
X
xixiaoyao 已提交
64

X
xixiaoyao 已提交
65
### 训练终止机制
X
Xiaoyao Xi 已提交
66

X
xixiaoyao 已提交
67 68 69 70 71
- 默认的设置:
  - **所有target任务达到目标训练步数后多任务学习停止**
  - 未设置成target任务的任务(即辅助任务)不会影响训练终止与否,只是担任”陪训“的角色
  - 注:默认所有的任务都是target任务,用户可以通过`target_tag`来标记目标/辅助任务
  - 每个目标任务的目标训练步数由num_epochs和mix_ratio计算得到
X
Xiaoyao Xi 已提交
72

X
xixiaoyao 已提交
73
### 保存机制
X
Xiaoyao Xi 已提交
74

X
xixiaoyao 已提交
75 76 77 78 79 80 81 82
- 默认的设置:
  - 训练过程中,保存下来的模型分为checkpoint (ckpt)和inference model (infermodel)两种:
    - ckpt保存的是包含所有任务的总计算图(即整个多任务学习计算图),用于训练中断恢复
    - infermodel保存的是某个目标任务的推理计算图和推理依赖的相关配置
  - 对于每个target任务,训练到预期的步数后自动保存inference model,之后不再保存。(注:保存inference model不影响ckpt的保存)
- 用户可改配置
  - 使用`save_ckpt_every_steps`来控制保存ckpt的频率,默认不保存
  - 每个task instance均可使用`save_infermodel_every_steps`来控制该task保存infermodel的频率,默认为-1,即只在达到目标训练步数时保存一下
X
xixiaoyao 已提交
83 84 85