From 84e9b24729abf5a05e83901ebe81a0efb3a999d4 Mon Sep 17 00:00:00 2001 From: xixiaoyao Date: Thu, 5 Dec 2019 16:50:09 +0800 Subject: [PATCH] fix infermodel save --- README.md | 9 +++++++-- paddlepalm/mtl_controller.py | 5 +++-- setup.py | 6 +++--- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 84b2fbe..b2b3e3c 100644 --- a/README.md +++ b/README.md @@ -741,7 +741,7 @@ BERT包含了如下输入对象 ```yaml token_ids: 一个shape为[batch_size, seq_len]的矩阵,每行是一条样本,其中的每个元素为文本中的每个token对应的单词id。 position_ids: 一个shape为[batch_size, seq_len]的矩阵,每行是一条样本,其中的每个元素为文本中的每个token对应的位置id。 -segment_ids: 一个shape为[batch_size, seq_len]的0/1矩阵,用于支持BERT、ERNIE等模型的输入,当元素为0时,代表当前token属于分类任务或匹配任务的text1,为1时代表当前token属于匹配任务的text2. +segment_ids: 一个shape为[batch_size, seq_len]的0/1矩阵,用于支持BERT、ERNIE等模型的输入,当元素为0时,代表当前token属于分类任务或匹配任务的text1,为1时代表当前token属于匹配任务的text2。 input_mask: 一个shape为[batch_size, seq_len]的矩阵,其中的每个元素为0或1,表示该位置是否是padding词(为1时代表是真实词,为0时代表是填充词)。 ``` @@ -781,6 +781,7 @@ sentence_pair_embedding: 一个shape为[batch_size, hidden_size]的matrix, float ## 附录C:内置任务范式(paradigm) + #### 分类范式:cls 分类范式额外包含以下配置字段: @@ -788,6 +789,7 @@ sentence_pair_embedding: 一个shape为[batch_size, hidden_size]的matrix, float ```yaml n_classes(REQUIRED): int类型。分类任务的类别数。 pred_output_path (OPTIONAL) : str类型。预测输出结果的保存路径,当该参数未空时,保存至全局配置文件中的`save_path`字段指定路径下的任务目录。 +save_infermodel_every_n_steps (OPTIONAL) : int类型。周期性保存预测模型的间隔,未设置或设为-1时仅在该任务训练结束时保存预测模型。默认为-1。 ``` 分类范式包含如下的输入对象: @@ -812,6 +814,7 @@ sentence_embedding: 一个shape为[batch_size, hidden_size]的matrix, float32类 ```yaml pred_output_path (OPTIONAL) : str类型。预测输出结果的保存路径,当该参数未空时,保存至全局配置文件中的`save_path`字段指定路径下的任务目录。 +save_infermodel_every_n_steps (OPTIONAL) : int类型。周期性保存预测模型的间隔,未设置或设为-1时仅在该任务训练结束时保存预测模型。默认为-1。 ``` 匹配范式包含如下的输入对象: @@ -838,6 +841,7 @@ sentence_pair_embedding: 一个shape为[batch_size, hidden_size]的matrix, float max_answer_len(REQUIRED): int类型。预测的最大答案长度 n_best_size (OPTIONAL) : int类型,默认为20。预测时保存的nbest回答文件中每条样本的n_best数量 pred_output_path (OPTIONAL) : str类型。预测输出结果的保存路径,当该参数未空时,保存至全局配置文件中的`save_path`字段指定路径下的任务目录 +save_infermodel_every_n_steps (OPTIONAL) : int类型。周期性保存预测模型的间隔,未设置或设为-1时仅在该任务训练结束时保存预测模型。默认为-1。 ``` 机器阅读理解范式包含如下的输入对象: @@ -885,7 +889,8 @@ do_lower_case (OPTIONAL): bool类型。大小写标志位。默认为False,即 for_cn: bool类型。中文模式标志位。默认为False,即默认输入为英文,设置为True后,分词器、后处理等按照中文语言进行处理。 print_every_n_steps (OPTIONAL): int类型。默认为5。训练阶段打印日志的频率(step为单位)。 -save_every_n_steps (OPTIONAL): int类型。默认为-1。训练过程中保存checkpoint模型的频率,默认不保存。 +save_ckpt_every_n_steps (OPTIONAL): int类型。默认为-1。训练过程中保存完整计算图的检查点(checkpoint)的频率,默认-1,仅在最后一个step自动保存检查点。 +save_infermodel_every_n_steps (OPTIONAL) : int类型。周期性保存预测模型的间隔,未设置或设为-1时仅在该任务训练结束时保存预测模型。默认为-1。 optimizer(REQUIRED): str类型。优化器名称,目前框架只支持adam,未来会支持更多优化器。 learning_rate(REQUIRED): str类型。训练阶段的学习率。 diff --git a/paddlepalm/mtl_controller.py b/paddlepalm/mtl_controller.py index 63af1a4..e265758 100755 --- a/paddlepalm/mtl_controller.py +++ b/paddlepalm/mtl_controller.py @@ -592,8 +592,9 @@ class Controller(object): global_step += 1 cur_task.cur_train_step += 1 - if cur_task.save_infermodel_every_n_steps > 0 and cur_task.cur_train_step % cur_task.save_infermodel_every_n_steps == 0: - cur_task.save(suffix='.step'+str(cur_task.cur_train_step)) + cur_task_global_step = cur_task.cur_train_step + cur_task.cur_train_epoch * cur_task.steps_pur_epoch + if cur_task.save_infermodel_every_n_steps > 0 and cur_task_global_step % cur_task.save_infermodel_every_n_steps == 0: + cur_task.save(suffix='.step'+str(cur_task_global_step)) if global_step % main_conf.get('print_every_n_steps', 5) == 0: loss = rt_outputs[cur_task.name+'/loss'] diff --git a/setup.py b/setup.py index bfeb6be..6c81d9e 100644 --- a/setup.py +++ b/setup.py @@ -18,7 +18,7 @@ """ Setup script. Authors: zhouxiangyang(zhouxiangyang@baidu.com) -Date: 2019/09/29 21:00:01 +Date: 2019/12/05 13:24:01 """ import setuptools from io import open @@ -27,10 +27,10 @@ with open("README.md", "r", encoding='utf-8') as fh: setuptools.setup( name="paddlepalm", - version="0.2.1", + version="0.2.2", author="PaddlePaddle", author_email="zhangyiming04@baidu.com", - description="A Multi-task Learning Lib for PaddlePaddle Users.", + description="A Lib for PaddlePaddle Users.", # long_description=long_description, # long_description_content_type="text/markdown", url="https://github.com/PaddlePaddle/PALM", -- GitLab