提交 84e9b247 编写于 作者: X xixiaoyao

fix infermodel save

上级 cb213b02
...@@ -741,7 +741,7 @@ BERT包含了如下输入对象 ...@@ -741,7 +741,7 @@ BERT包含了如下输入对象
```yaml ```yaml
token_ids: 一个shape为[batch_size, seq_len]的矩阵,每行是一条样本,其中的每个元素为文本中的每个token对应的单词id。 token_ids: 一个shape为[batch_size, seq_len]的矩阵,每行是一条样本,其中的每个元素为文本中的每个token对应的单词id。
position_ids: 一个shape为[batch_size, seq_len]的矩阵,每行是一条样本,其中的每个元素为文本中的每个token对应的位置id。 position_ids: 一个shape为[batch_size, seq_len]的矩阵,每行是一条样本,其中的每个元素为文本中的每个token对应的位置id。
segment_ids: 一个shape为[batch_size, seq_len]的0/1矩阵,用于支持BERT、ERNIE等模型的输入,当元素为0时,代表当前token属于分类任务或匹配任务的text1,为1时代表当前token属于匹配任务的text2. segment_ids: 一个shape为[batch_size, seq_len]的0/1矩阵,用于支持BERT、ERNIE等模型的输入,当元素为0时,代表当前token属于分类任务或匹配任务的text1,为1时代表当前token属于匹配任务的text2
input_mask: 一个shape为[batch_size, seq_len]的矩阵,其中的每个元素为0或1,表示该位置是否是padding词(为1时代表是真实词,为0时代表是填充词)。 input_mask: 一个shape为[batch_size, seq_len]的矩阵,其中的每个元素为0或1,表示该位置是否是padding词(为1时代表是真实词,为0时代表是填充词)。
``` ```
...@@ -781,6 +781,7 @@ sentence_pair_embedding: 一个shape为[batch_size, hidden_size]的matrix, float ...@@ -781,6 +781,7 @@ sentence_pair_embedding: 一个shape为[batch_size, hidden_size]的matrix, float
## 附录C:内置任务范式(paradigm) ## 附录C:内置任务范式(paradigm)
#### 分类范式:cls #### 分类范式:cls
分类范式额外包含以下配置字段: 分类范式额外包含以下配置字段:
...@@ -788,6 +789,7 @@ sentence_pair_embedding: 一个shape为[batch_size, hidden_size]的matrix, float ...@@ -788,6 +789,7 @@ sentence_pair_embedding: 一个shape为[batch_size, hidden_size]的matrix, float
```yaml ```yaml
n_classes(REQUIRED): int类型。分类任务的类别数。 n_classes(REQUIRED): int类型。分类任务的类别数。
pred_output_path (OPTIONAL) : str类型。预测输出结果的保存路径,当该参数未空时,保存至全局配置文件中的`save_path`字段指定路径下的任务目录。 pred_output_path (OPTIONAL) : str类型。预测输出结果的保存路径,当该参数未空时,保存至全局配置文件中的`save_path`字段指定路径下的任务目录。
save_infermodel_every_n_steps (OPTIONAL) : int类型。周期性保存预测模型的间隔,未设置或设为-1时仅在该任务训练结束时保存预测模型。默认为-1。
``` ```
分类范式包含如下的输入对象: 分类范式包含如下的输入对象:
...@@ -812,6 +814,7 @@ sentence_embedding: 一个shape为[batch_size, hidden_size]的matrix, float32类 ...@@ -812,6 +814,7 @@ sentence_embedding: 一个shape为[batch_size, hidden_size]的matrix, float32类
```yaml ```yaml
pred_output_path (OPTIONAL) : str类型。预测输出结果的保存路径,当该参数未空时,保存至全局配置文件中的`save_path`字段指定路径下的任务目录。 pred_output_path (OPTIONAL) : str类型。预测输出结果的保存路径,当该参数未空时,保存至全局配置文件中的`save_path`字段指定路径下的任务目录。
save_infermodel_every_n_steps (OPTIONAL) : int类型。周期性保存预测模型的间隔,未设置或设为-1时仅在该任务训练结束时保存预测模型。默认为-1。
``` ```
匹配范式包含如下的输入对象: 匹配范式包含如下的输入对象:
...@@ -838,6 +841,7 @@ sentence_pair_embedding: 一个shape为[batch_size, hidden_size]的matrix, float ...@@ -838,6 +841,7 @@ sentence_pair_embedding: 一个shape为[batch_size, hidden_size]的matrix, float
max_answer_len(REQUIRED): int类型。预测的最大答案长度 max_answer_len(REQUIRED): int类型。预测的最大答案长度
n_best_size (OPTIONAL) : int类型,默认为20。预测时保存的nbest回答文件中每条样本的n_best数量 n_best_size (OPTIONAL) : int类型,默认为20。预测时保存的nbest回答文件中每条样本的n_best数量
pred_output_path (OPTIONAL) : str类型。预测输出结果的保存路径,当该参数未空时,保存至全局配置文件中的`save_path`字段指定路径下的任务目录 pred_output_path (OPTIONAL) : str类型。预测输出结果的保存路径,当该参数未空时,保存至全局配置文件中的`save_path`字段指定路径下的任务目录
save_infermodel_every_n_steps (OPTIONAL) : int类型。周期性保存预测模型的间隔,未设置或设为-1时仅在该任务训练结束时保存预测模型。默认为-1。
``` ```
机器阅读理解范式包含如下的输入对象: 机器阅读理解范式包含如下的输入对象:
...@@ -885,7 +889,8 @@ do_lower_case (OPTIONAL): bool类型。大小写标志位。默认为False,即 ...@@ -885,7 +889,8 @@ do_lower_case (OPTIONAL): bool类型。大小写标志位。默认为False,即
for_cn: bool类型。中文模式标志位。默认为False,即默认输入为英文,设置为True后,分词器、后处理等按照中文语言进行处理。 for_cn: bool类型。中文模式标志位。默认为False,即默认输入为英文,设置为True后,分词器、后处理等按照中文语言进行处理。
print_every_n_steps (OPTIONAL): int类型。默认为5。训练阶段打印日志的频率(step为单位)。 print_every_n_steps (OPTIONAL): int类型。默认为5。训练阶段打印日志的频率(step为单位)。
save_every_n_steps (OPTIONAL): int类型。默认为-1。训练过程中保存checkpoint模型的频率,默认不保存。 save_ckpt_every_n_steps (OPTIONAL): int类型。默认为-1。训练过程中保存完整计算图的检查点(checkpoint)的频率,默认-1,仅在最后一个step自动保存检查点。
save_infermodel_every_n_steps (OPTIONAL) : int类型。周期性保存预测模型的间隔,未设置或设为-1时仅在该任务训练结束时保存预测模型。默认为-1。
optimizer(REQUIRED): str类型。优化器名称,目前框架只支持adam,未来会支持更多优化器。 optimizer(REQUIRED): str类型。优化器名称,目前框架只支持adam,未来会支持更多优化器。
learning_rate(REQUIRED): str类型。训练阶段的学习率。 learning_rate(REQUIRED): str类型。训练阶段的学习率。
......
...@@ -592,8 +592,9 @@ class Controller(object): ...@@ -592,8 +592,9 @@ class Controller(object):
global_step += 1 global_step += 1
cur_task.cur_train_step += 1 cur_task.cur_train_step += 1
if cur_task.save_infermodel_every_n_steps > 0 and cur_task.cur_train_step % cur_task.save_infermodel_every_n_steps == 0: cur_task_global_step = cur_task.cur_train_step + cur_task.cur_train_epoch * cur_task.steps_pur_epoch
cur_task.save(suffix='.step'+str(cur_task.cur_train_step)) if cur_task.save_infermodel_every_n_steps > 0 and cur_task_global_step % cur_task.save_infermodel_every_n_steps == 0:
cur_task.save(suffix='.step'+str(cur_task_global_step))
if global_step % main_conf.get('print_every_n_steps', 5) == 0: if global_step % main_conf.get('print_every_n_steps', 5) == 0:
loss = rt_outputs[cur_task.name+'/loss'] loss = rt_outputs[cur_task.name+'/loss']
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
""" """
Setup script. Setup script.
Authors: zhouxiangyang(zhouxiangyang@baidu.com) Authors: zhouxiangyang(zhouxiangyang@baidu.com)
Date: 2019/09/29 21:00:01 Date: 2019/12/05 13:24:01
""" """
import setuptools import setuptools
from io import open from io import open
...@@ -27,10 +27,10 @@ with open("README.md", "r", encoding='utf-8') as fh: ...@@ -27,10 +27,10 @@ with open("README.md", "r", encoding='utf-8') as fh:
setuptools.setup( setuptools.setup(
name="paddlepalm", name="paddlepalm",
version="0.2.1", version="0.2.2",
author="PaddlePaddle", author="PaddlePaddle",
author_email="zhangyiming04@baidu.com", author_email="zhangyiming04@baidu.com",
description="A Multi-task Learning Lib for PaddlePaddle Users.", description="A Lib for PaddlePaddle Users.",
# long_description=long_description, # long_description=long_description,
# long_description_content_type="text/markdown", # long_description_content_type="text/markdown",
url="https://github.com/PaddlePaddle/PALM", url="https://github.com/PaddlePaddle/PALM",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册