diff --git a/paddle/fluid/operators/detail/request_handler_impl.cc b/paddle/fluid/operators/detail/request_handler_impl.cc index 487397312217b7bc40446411a130599751779cb4..87fa5842c4e9aebb47c8e48023f44a5154ef7c16 100644 --- a/paddle/fluid/operators/detail/request_handler_impl.cc +++ b/paddle/fluid/operators/detail/request_handler_impl.cc @@ -126,11 +126,10 @@ bool RequestCheckpointHandler::Handle(const std::string& varname, framework::Variable** outvar, const std::string& out_var_name) { - auto lt_varname = string::Sprintf("%s.path", varname); - auto *lt_var = scope->FindVar(lt_varname)->GetMutable(); + auto *lt_var = scope->FindVar("loopup_table_path")->GetMutable(); lt_var->clear(); lt_var->append(out_var_name); - VLOG(4) << "RequestCheckpointHandler update " << lt_varname << " to: " << out_var_name; + VLOG(4) << "RequestCheckpointHandler update loopup_table_path to: " << out_var_name; executor_->RunPreparedContext(checkpoint_prepared_ctx_.get(), scope); return true; } diff --git a/paddle/fluid/operators/save_op.cc b/paddle/fluid/operators/save_op.cc index 005e03e69d2b84734f5345ae22b0d529b8426be8..13798c88b1856a6ca0ed3cdb1c4a7bf8d82f4a09 100644 --- a/paddle/fluid/operators/save_op.cc +++ b/paddle/fluid/operators/save_op.cc @@ -182,9 +182,32 @@ This operator will serialize and write a tensor/selected rows variable to file o } }; -} // namespace operators -} // namespace paddle +class SaveOpVarTypeInference : public framework::VarTypeInference { + public: + void operator()(const framework::OpDesc &op_desc, + framework::BlockDesc *block) const override { + auto out_var_name = op_desc.Output("loopup_table_path").front(); + auto &out_var = block->FindRecursiveOrCreateVar(out_var_name); + auto var_type = framework::proto::VarType::RAW; + out_var.SetType(var_type); + } +}; + +class SaveOpShapeInference : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *ctx) const override {} +}; +} +} + +// namespace operators +// namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR(save, ops::SaveOp, ops::SaveOpProtoMaker); +REGISTER_OPERATOR(save, ops::SaveOp, + paddle::framework::EmptyGradOpMaker, + ops::SaveOpProtoMaker, + ops::SaveOpVarTypeInference, + ops::SaveOpShapeInference); + diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index d5ce6e2704a92cf500327c8b8a0a455b979bd1a9..f9c39262ce32ee97ea1ce5e14420f5d67ad5b120 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -838,7 +838,7 @@ class DistributeTranspiler: """ import os - pserver_program.global_block().create_var(name="%s.path"%self.table_name, persistable=True, type=core.VarDesc.VarType.RAW) + pserver_program.global_block().create_var(name="loopup_table_path", persistable=True, type=core.VarDesc.VarType.RAW) checkpoint_save_block = pserver_program.create_block(pre_block_idx) checkpoint_save_block.append_op(