Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PALM
提交
84e9b247
P
PALM
项目概览
PaddlePaddle
/
PALM
通知
5
Star
3
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
10
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PALM
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
10
Issue
10
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
84e9b247
编写于
12月 05, 2019
作者:
X
xixiaoyao
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix infermodel save
上级
cb213b02
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
13 addition
and
7 deletion
+13
-7
README.md
README.md
+7
-2
paddlepalm/mtl_controller.py
paddlepalm/mtl_controller.py
+3
-2
setup.py
setup.py
+3
-3
未找到文件。
README.md
浏览文件 @
84e9b247
...
@@ -741,7 +741,7 @@ BERT包含了如下输入对象
...
@@ -741,7 +741,7 @@ BERT包含了如下输入对象
```
yaml
```
yaml
token_ids: 一个shape为[batch_size, seq_len]的矩阵,每行是一条样本,其中的每个元素为文本中的每个token对应的单词id。
token_ids: 一个shape为[batch_size, seq_len]的矩阵,每行是一条样本,其中的每个元素为文本中的每个token对应的单词id。
position_ids: 一个shape为[batch_size, seq_len]的矩阵,每行是一条样本,其中的每个元素为文本中的每个token对应的位置id。
position_ids: 一个shape为[batch_size, seq_len]的矩阵,每行是一条样本,其中的每个元素为文本中的每个token对应的位置id。
segment_ids: 一个shape为[batch_size, seq_len]的0/1矩阵,用于支持BERT、ERNIE等模型的输入,当元素为0时,代表当前token属于分类任务或匹配任务的text1,为1时代表当前token属于匹配任务的text2
.
segment_ids: 一个shape为[batch_size, seq_len]的0/1矩阵,用于支持BERT、ERNIE等模型的输入,当元素为0时,代表当前token属于分类任务或匹配任务的text1,为1时代表当前token属于匹配任务的text2
。
input_mask: 一个shape为[batch_size, seq_len]的矩阵,其中的每个元素为0或1,表示该位置是否是padding词(为1时代表是真实词,为0时代表是填充词)。
input_mask: 一个shape为[batch_size, seq_len]的矩阵,其中的每个元素为0或1,表示该位置是否是padding词(为1时代表是真实词,为0时代表是填充词)。
```
```
...
@@ -781,6 +781,7 @@ sentence_pair_embedding: 一个shape为[batch_size, hidden_size]的matrix, float
...
@@ -781,6 +781,7 @@ sentence_pair_embedding: 一个shape为[batch_size, hidden_size]的matrix, float
## 附录C:内置任务范式(paradigm)
## 附录C:内置任务范式(paradigm)
#### 分类范式:cls
#### 分类范式:cls
分类范式额外包含以下配置字段:
分类范式额外包含以下配置字段:
...
@@ -788,6 +789,7 @@ sentence_pair_embedding: 一个shape为[batch_size, hidden_size]的matrix, float
...
@@ -788,6 +789,7 @@ sentence_pair_embedding: 一个shape为[batch_size, hidden_size]的matrix, float
```
yaml
```
yaml
n_classes(REQUIRED): int类型。分类任务的类别数。
n_classes(REQUIRED): int类型。分类任务的类别数。
pred_output_path (OPTIONAL) : str类型。预测输出结果的保存路径,当该参数未空时,保存至全局配置文件中的
`save_path`
字段指定路径下的任务目录。
pred_output_path (OPTIONAL) : str类型。预测输出结果的保存路径,当该参数未空时,保存至全局配置文件中的
`save_path`
字段指定路径下的任务目录。
save_infermodel_every_n_steps (OPTIONAL) : int类型。周期性保存预测模型的间隔,未设置或设为-1时仅在该任务训练结束时保存预测模型。默认为-1。
```
```
分类范式包含如下的输入对象:
分类范式包含如下的输入对象:
...
@@ -812,6 +814,7 @@ sentence_embedding: 一个shape为[batch_size, hidden_size]的matrix, float32类
...
@@ -812,6 +814,7 @@ sentence_embedding: 一个shape为[batch_size, hidden_size]的matrix, float32类
```
yaml
```
yaml
pred_output_path (OPTIONAL) : str类型。预测输出结果的保存路径,当该参数未空时,保存至全局配置文件中的
`save_path`
字段指定路径下的任务目录。
pred_output_path (OPTIONAL) : str类型。预测输出结果的保存路径,当该参数未空时,保存至全局配置文件中的
`save_path`
字段指定路径下的任务目录。
save_infermodel_every_n_steps (OPTIONAL) : int类型。周期性保存预测模型的间隔,未设置或设为-1时仅在该任务训练结束时保存预测模型。默认为-1。
```
```
匹配范式包含如下的输入对象:
匹配范式包含如下的输入对象:
...
@@ -838,6 +841,7 @@ sentence_pair_embedding: 一个shape为[batch_size, hidden_size]的matrix, float
...
@@ -838,6 +841,7 @@ sentence_pair_embedding: 一个shape为[batch_size, hidden_size]的matrix, float
max_answer_len(REQUIRED): int类型。预测的最大答案长度
max_answer_len(REQUIRED): int类型。预测的最大答案长度
n_best_size (OPTIONAL) : int类型,默认为20。预测时保存的nbest回答文件中每条样本的n_best数量
n_best_size (OPTIONAL) : int类型,默认为20。预测时保存的nbest回答文件中每条样本的n_best数量
pred_output_path (OPTIONAL) : str类型。预测输出结果的保存路径,当该参数未空时,保存至全局配置文件中的
`save_path`
字段指定路径下的任务目录
pred_output_path (OPTIONAL) : str类型。预测输出结果的保存路径,当该参数未空时,保存至全局配置文件中的
`save_path`
字段指定路径下的任务目录
save_infermodel_every_n_steps (OPTIONAL) : int类型。周期性保存预测模型的间隔,未设置或设为-1时仅在该任务训练结束时保存预测模型。默认为-1。
```
```
机器阅读理解范式包含如下的输入对象:
机器阅读理解范式包含如下的输入对象:
...
@@ -885,7 +889,8 @@ do_lower_case (OPTIONAL): bool类型。大小写标志位。默认为False,即
...
@@ -885,7 +889,8 @@ do_lower_case (OPTIONAL): bool类型。大小写标志位。默认为False,即
for_cn: bool类型。中文模式标志位。默认为False,即默认输入为英文,设置为True后,分词器、后处理等按照中文语言进行处理。
for_cn: bool类型。中文模式标志位。默认为False,即默认输入为英文,设置为True后,分词器、后处理等按照中文语言进行处理。
print_every_n_steps (OPTIONAL): int类型。默认为5。训练阶段打印日志的频率(step为单位)。
print_every_n_steps (OPTIONAL): int类型。默认为5。训练阶段打印日志的频率(step为单位)。
save_every_n_steps (OPTIONAL): int类型。默认为-1。训练过程中保存checkpoint模型的频率,默认不保存。
save_ckpt_every_n_steps (OPTIONAL): int类型。默认为-1。训练过程中保存完整计算图的检查点(checkpoint)的频率,默认-1,仅在最后一个step自动保存检查点。
save_infermodel_every_n_steps (OPTIONAL) : int类型。周期性保存预测模型的间隔,未设置或设为-1时仅在该任务训练结束时保存预测模型。默认为-1。
optimizer(REQUIRED): str类型。优化器名称,目前框架只支持adam,未来会支持更多优化器。
optimizer(REQUIRED): str类型。优化器名称,目前框架只支持adam,未来会支持更多优化器。
learning_rate(REQUIRED): str类型。训练阶段的学习率。
learning_rate(REQUIRED): str类型。训练阶段的学习率。
...
...
paddlepalm/mtl_controller.py
浏览文件 @
84e9b247
...
@@ -592,8 +592,9 @@ class Controller(object):
...
@@ -592,8 +592,9 @@ class Controller(object):
global_step
+=
1
global_step
+=
1
cur_task
.
cur_train_step
+=
1
cur_task
.
cur_train_step
+=
1
if
cur_task
.
save_infermodel_every_n_steps
>
0
and
cur_task
.
cur_train_step
%
cur_task
.
save_infermodel_every_n_steps
==
0
:
cur_task_global_step
=
cur_task
.
cur_train_step
+
cur_task
.
cur_train_epoch
*
cur_task
.
steps_pur_epoch
cur_task
.
save
(
suffix
=
'.step'
+
str
(
cur_task
.
cur_train_step
))
if
cur_task
.
save_infermodel_every_n_steps
>
0
and
cur_task_global_step
%
cur_task
.
save_infermodel_every_n_steps
==
0
:
cur_task
.
save
(
suffix
=
'.step'
+
str
(
cur_task_global_step
))
if
global_step
%
main_conf
.
get
(
'print_every_n_steps'
,
5
)
==
0
:
if
global_step
%
main_conf
.
get
(
'print_every_n_steps'
,
5
)
==
0
:
loss
=
rt_outputs
[
cur_task
.
name
+
'/loss'
]
loss
=
rt_outputs
[
cur_task
.
name
+
'/loss'
]
...
...
setup.py
浏览文件 @
84e9b247
...
@@ -18,7 +18,7 @@
...
@@ -18,7 +18,7 @@
"""
"""
Setup script.
Setup script.
Authors: zhouxiangyang(zhouxiangyang@baidu.com)
Authors: zhouxiangyang(zhouxiangyang@baidu.com)
Date: 2019/
09/29 21:00
:01
Date: 2019/
12/05 13:24
:01
"""
"""
import
setuptools
import
setuptools
from
io
import
open
from
io
import
open
...
@@ -27,10 +27,10 @@ with open("README.md", "r", encoding='utf-8') as fh:
...
@@ -27,10 +27,10 @@ with open("README.md", "r", encoding='utf-8') as fh:
setuptools
.
setup
(
setuptools
.
setup
(
name
=
"paddlepalm"
,
name
=
"paddlepalm"
,
version
=
"0.2.
1
"
,
version
=
"0.2.
2
"
,
author
=
"PaddlePaddle"
,
author
=
"PaddlePaddle"
,
author_email
=
"zhangyiming04@baidu.com"
,
author_email
=
"zhangyiming04@baidu.com"
,
description
=
"A
Multi-task Learning
Lib for PaddlePaddle Users."
,
description
=
"A Lib for PaddlePaddle Users."
,
# long_description=long_description,
# long_description=long_description,
# long_description_content_type="text/markdown",
# long_description_content_type="text/markdown",
url
=
"https://github.com/PaddlePaddle/PALM"
,
url
=
"https://github.com/PaddlePaddle/PALM"
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录