From af0a6a149f7e77ffa3b3768f27dd4cc0615cab90 Mon Sep 17 00:00:00 2001
From: tangwei12
Date: Tue, 19 Jun 2018 02:56:37 +0800
Subject: [PATCH] checkpoint notify

---
 .../operators/detail/request_handler_impl.cc  |  5 ++--
 paddle/fluid/operators/save_op.cc             | 29 +++++++++++++++++--
 .../fluid/transpiler/distribute_transpiler.py |  2 +-
 3 files changed, 29 insertions(+), 7 deletions(-)

diff --git a/paddle/fluid/operators/detail/request_handler_impl.cc b/paddle/fluid/operators/detail/request_handler_impl.cc
index 48739731221..87fa5842c4e 100644
--- a/paddle/fluid/operators/detail/request_handler_impl.cc
+++ b/paddle/fluid/operators/detail/request_handler_impl.cc
@@ -126,11 +126,10 @@ bool RequestCheckpointHandler::Handle(const std::string& varname,
                                       framework::Variable** outvar,
                                       const std::string& out_var_name) {
 
-  auto lt_varname = string::Sprintf("%s.path", varname);
-  auto *lt_var = scope->FindVar(lt_varname)->GetMutable<std::string>();
+  auto *lt_var = scope->FindVar("loopup_table_path")->GetMutable<std::string>();
   lt_var->clear();
   lt_var->append(out_var_name);
-  VLOG(4) << "RequestCheckpointHandler update " << lt_varname << " to: " << out_var_name;
+  VLOG(4) << "RequestCheckpointHandler update loopup_table_path to: " << out_var_name;
   executor_->RunPreparedContext(checkpoint_prepared_ctx_.get(), scope);
   return true;
 }
diff --git a/paddle/fluid/operators/save_op.cc b/paddle/fluid/operators/save_op.cc
index 005e03e69d2..13798c88b18 100644
--- a/paddle/fluid/operators/save_op.cc
+++ b/paddle/fluid/operators/save_op.cc
@@ -182,9 +182,32 @@ This operator will serialize and write a tensor/selected rows variable to file o
   }
 };
 
-}  // namespace operators
-}  // namespace paddle
+class SaveOpVarTypeInference : public framework::VarTypeInference {
+ public:
+  void operator()(const framework::OpDesc &op_desc,
+                  framework::BlockDesc *block) const override {
+    auto out_var_name = op_desc.Output("loopup_table_path").front();
+    auto &out_var = block->FindRecursiveOrCreateVar(out_var_name);
+    auto var_type = framework::proto::VarType::RAW;
+    out_var.SetType(var_type);
+  }
+};
+
+class SaveOpShapeInference : public framework::InferShapeBase {
+ public:
+  void operator()(framework::InferShapeContext *ctx) const override {}
+};
+}
+}
+
+// namespace operators
+// namespace paddle
 
 namespace ops = paddle::operators;
 
-REGISTER_OPERATOR(save, ops::SaveOp, ops::SaveOpProtoMaker);
+REGISTER_OPERATOR(save, ops::SaveOp,
+                  paddle::framework::EmptyGradOpMaker,
+                  ops::SaveOpProtoMaker,
+                  ops::SaveOpVarTypeInference,
+                  ops::SaveOpShapeInference);
+
diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py
index d5ce6e2704a..f9c39262ce3 100644
--- a/python/paddle/fluid/transpiler/distribute_transpiler.py
+++ b/python/paddle/fluid/transpiler/distribute_transpiler.py
@@ -838,7 +838,7 @@ class DistributeTranspiler:
         """
         import os
 
-        pserver_program.global_block().create_var(name="%s.path"%self.table_name, persistable=True, type=core.VarDesc.VarType.RAW)
+        pserver_program.global_block().create_var(name="loopup_table_path", persistable=True, type=core.VarDesc.VarType.RAW)
 
         checkpoint_save_block = pserver_program.create_block(pre_block_idx)
         checkpoint_save_block.append_op(
-- 
GitLab