提交 985026ce 编写于 作者: T tangwei12

add checkpoint_notify in python

上级 1c2e9bdd
...@@ -76,7 +76,7 @@ class CheckpointNotifyOpShapeInference : public framework::InferShapeBase { ...@@ -76,7 +76,7 @@ class CheckpointNotifyOpShapeInference : public framework::InferShapeBase {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(checkpointnotify, ops::CheckpointNotifyOp, REGISTER_OPERATOR(checkpoint_notify, ops::CheckpointNotifyOp,
paddle::framework::EmptyGradOpMaker, paddle::framework::EmptyGradOpMaker,
ops::CheckpointNotifyOpMaker, ops::CheckpointNotifyOpMaker,
ops::CheckpointNotifyOpShapeInference); ops::CheckpointNotifyOpShapeInference);
...@@ -382,7 +382,7 @@ class Operator(object): ...@@ -382,7 +382,7 @@ class Operator(object):
'rnn_memory_helper_grad', 'conditional_block', 'while', 'send', 'recv', 'rnn_memory_helper_grad', 'conditional_block', 'while', 'send', 'recv',
'listen_and_serv', 'parallel_do', 'save_combine', 'load_combine', 'listen_and_serv', 'parallel_do', 'save_combine', 'load_combine',
'ncclInit', 'channel_create', 'channel_close', 'channel_send', 'ncclInit', 'channel_create', 'channel_close', 'channel_send',
'channel_recv', 'select' 'channel_recv', 'select', 'checkpoint_notify'
} }
def __init__(self, def __init__(self,
......
...@@ -613,7 +613,7 @@ def save_pserver_vars_by_notify(executor, dirname, epmap): ...@@ -613,7 +613,7 @@ def save_pserver_vars_by_notify(executor, dirname, epmap):
attrs['dir'] = cur_dir attrs['dir'] = cur_dir
checkpoint_notify_block.append_op( checkpoint_notify_block.append_op(
type='checkpointnotify', inputs={}, output={}, attrs=attrs) type='checkpoint_notify', inputs={}, output={}, attrs=attrs)
executor.run(checkpoint_notify_program) executor.run(checkpoint_notify_program)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册