From ae12281d9b91b4d13bf0979d92cc1b3587c4fd1b Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Tue, 19 Jun 2018 02:12:27 +0800 Subject: [PATCH] checkpoint notify --- paddle/fluid/operators/checkpoint_notify_op.cc | 9 +++++++-- paddle/fluid/operators/detail/grpc_server.cc | 5 ++++- .../operators/detail/request_handler_impl.cc | 7 +++++++ paddle/fluid/operators/save_op.cc | 12 ++++++++++-- python/paddle/fluid/io.py | 15 ++++++++++----- .../fluid/transpiler/distribute_transpiler.py | 4 +++- 6 files changed, 41 insertions(+), 11 deletions(-) diff --git a/paddle/fluid/operators/checkpoint_notify_op.cc b/paddle/fluid/operators/checkpoint_notify_op.cc index 026ad722c..3e5019dd4 100644 --- a/paddle/fluid/operators/checkpoint_notify_op.cc +++ b/paddle/fluid/operators/checkpoint_notify_op.cc @@ -20,6 +20,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/detail/macros.h" #include "paddle/fluid/operators/send_recv_util.h" +#include "paddle/fluid/string/printf.h" namespace paddle { namespace operators { @@ -36,12 +37,14 @@ class CheckpointNotifyOp : public framework::OperatorBase { const platform::Place& place) const override { std::vector epmap = Attr>("epmap"); std::string dir = Attr("dir"); + std::string lookup_table_name = Attr("lookup_table"); detail::RPCClient* rpc_client = detail::RPCClient::GetInstance(); for (size_t i = 0; i < epmap.size(); i++) { - VLOG(3) << "sending to " << epmap[i] << " to checkpoint notify ... "; - rpc_client->AsyncCheckpointNotify(epmap[i], dir); + VLOG(3) << "sending " << dir <<" to " << epmap[i] << " to checkpoint notify ... "; + auto serial_looku_table = string::Sprintf("%s/%s.%d", dir, lookup_table_name, i); + rpc_client->AsyncCheckpointNotify(epmap[i], serial_looku_table); } rpc_client->Wait(); } @@ -57,6 +60,8 @@ class CheckpointNotifyOpMaker : public framework::OpProtoAndCheckerMaker { .SetDefault({"127.0.0.1:6164"}); AddAttr( "dir", "(string, default '') indicate the folder checkpoint will use"); + AddAttr( + "lookup_table", "(string, default '') the lookup table name"); AddComment(R"DOC( Prefetch operator diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc index ed3e60ec4..9f4971dc1 100644 --- a/paddle/fluid/operators/detail/grpc_server.cc +++ b/paddle/fluid/operators/detail/grpc_server.cc @@ -208,11 +208,14 @@ class RequestCheckpointNotify final : public RequestBase { auto scope = request_->GetMutableLocalScope(); std::string checkpoint_notify = request_->Varname(); - std::string checkpoint_dir = request_->Varname(); + std::string checkpoint_dir = request_->OutVarname(); framework::Variable* invar = nullptr; framework::Variable* outvar = nullptr; + VLOG(4) << "RequestCheckpointNotify notify: " << checkpoint_notify + << ", dir: " << checkpoint_dir; + request_handler_->Handle(checkpoint_notify, scope, invar, &outvar, checkpoint_dir); Finish(reply_, &responder_); diff --git a/paddle/fluid/operators/detail/request_handler_impl.cc b/paddle/fluid/operators/detail/request_handler_impl.cc index 41b22e214..487397312 100644 --- a/paddle/fluid/operators/detail/request_handler_impl.cc +++ b/paddle/fluid/operators/detail/request_handler_impl.cc @@ -22,6 +22,7 @@ #include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/operators/detail/request_handler_impl.h" #include "paddle/fluid/operators/detail/rpc_server.h" +#include "paddle/fluid/string/printf.h" namespace paddle { namespace operators { @@ -124,6 +125,12 @@ bool RequestCheckpointHandler::Handle(const std::string& varname, framework::Variable* invar, framework::Variable** outvar, const std::string& out_var_name) { + + auto lt_varname = string::Sprintf("%s.path", varname); + auto *lt_var = scope->FindVar(lt_varname)->GetMutable(); + lt_var->clear(); + lt_var->append(out_var_name); + VLOG(4) << "RequestCheckpointHandler update " << lt_varname << " to: " << out_var_name; executor_->RunPreparedContext(checkpoint_prepared_ctx_.get(), scope); return true; } diff --git a/paddle/fluid/operators/save_op.cc b/paddle/fluid/operators/save_op.cc index b54bd7db3..005e03e69 100644 --- a/paddle/fluid/operators/save_op.cc +++ b/paddle/fluid/operators/save_op.cc @@ -87,7 +87,7 @@ class SaveOp : public framework::OperatorBase { if (var->IsType()) { SaveLodTensor(filename, place, var); } else if (var->IsType()) { - SaveSelectedRows(filename, place, var); + SaveSelectedRows(scope, place, var); } else { PADDLE_ENFORCE( false, @@ -128,9 +128,17 @@ class SaveOp : public framework::OperatorBase { fout.close(); } - void SaveSelectedRows(const std::string &filename, + void SaveSelectedRows(const framework::Scope &scope, const platform::Place &place, framework::Variable *var) const { + + auto lt_varname = string::Sprintf("%s.path", Input("X")); + auto *lt_var = scope.FindVar(lt_varname)->GetMutable(); + PADDLE_ENFORCE(lt_var != nullptr, "Cannot find variable %s for SaveSelectedRows", + lt_varname); + std::string filename = lt_var->data(); + VLOG(4) << "SaveSelectedRows get File name: " << filename; + auto &selectedRows = var->Get(); // get device context from pool diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 253fd5651..ce82b6b90 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -471,7 +471,10 @@ def save_checkpoint(executor, trainer_id, trainer_args=None, main_program=None, - max_num_checkpoints=3): + max_num_checkpoints=3, + lookup_table=None, + ps_endpoint_list=None + ): """ Save Checkpoint will save persistable LodTensor variables from main_program in checkpoint directory, the directory named by serial number from 0 to (n -1), save_checkpoint use LRU strategy @@ -500,7 +503,7 @@ def save_checkpoint(executor, if trainer_id == 0: save_persist_vars_without_grad(executor, cur_dir, main_program) - save_pserver_vars_by_notify(executor, cur_dir, "") + save_pserver_vars_by_notify(executor, cur_dir, ps_endpoint_list, lookup_table) _scroll_delete(checkpoint_dir, max_num_checkpoints) @@ -600,7 +603,7 @@ def save_persist_vars_without_grad(executor, dirname, program): _write_success(cur_dir) -def save_pserver_vars_by_notify(executor, dirname, epmap): +def save_pserver_vars_by_notify(executor, dirname, lookup_table, ps_endpoint_list): """ """ cur_dir = _get_lookuptable_dir(dirname) @@ -609,11 +612,12 @@ def save_pserver_vars_by_notify(executor, dirname, epmap): checkpoint_notify_block = checkpoint_notify_program.global_block() attrs = {} - attrs['epmap'] = None + attrs['epmap'] = ps_endpoint_list attrs['dir'] = cur_dir + attrs['lookup_table'] = lookup_table checkpoint_notify_block.append_op( - type='checkpoint_notify', inputs={}, output={}, attrs=attrs) + type='checkpoint_notify', inputs={}, outputs={}, attrs=attrs) executor.run(checkpoint_notify_program) @@ -783,3 +787,4 @@ def get_latest_checkpoint_serial(checkpoint_dir): if success_num > current_dir: current_dir = success_num return current_dir + diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 55a439660..d5ce6e270 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -838,13 +838,15 @@ class DistributeTranspiler: """ import os + pserver_program.global_block().create_var(name="%s.path"%self.table_name, persistable=True, type=core.VarDesc.VarType.RAW) + checkpoint_save_block = pserver_program.create_block(pre_block_idx) checkpoint_save_block.append_op( type='save', inputs={'X': [self.table_name]}, outputs={}, attrs={ - 'file_path': os.path.join("/tmp/pserver_ckpt/", self.table_name) + 'file_path': self.table_name) }) return checkpoint_save_block.idx -- GitLab