From f70096a2a4fb8cd673b455873d52a266b5254dd3 Mon Sep 17 00:00:00 2001 From: seiriosPlus Date: Wed, 23 Sep 2020 17:05:52 +0800 Subject: [PATCH] add mode for save --- .../operators/distributed/brpc/brpc_client.cc | 1 + .../operators/distributed/brpc/brpc_client.h | 2 +- .../operators/distributed/grpc/grpc_client.cc | 2 ++ .../operators/distributed/grpc/grpc_client.h | 2 +- .../distributed/request_handler_impl.cc | 10 +++++----- .../fluid/operators/distributed/rpc_client.h | 3 ++- .../distributed_ops/checkpoint_notify_op.cc | 20 +++++++++++-------- .../fleet/runtime/parameter_server_runtime.py | 17 +++++++++------- 8 files changed, 34 insertions(+), 23 deletions(-) diff --git a/paddle/fluid/operators/distributed/brpc/brpc_client.cc b/paddle/fluid/operators/distributed/brpc/brpc_client.cc index cb93b8d910a..b2a26089c86 100644 --- a/paddle/fluid/operators/distributed/brpc/brpc_client.cc +++ b/paddle/fluid/operators/distributed/brpc/brpc_client.cc @@ -448,6 +448,7 @@ VarHandlePtr BRPCClient::AsyncSendMessage(const std::string& ep, VarHandlePtr BRPCClient::AsyncCheckpointNotify(const std::string& ep, const std::string& dirname, const std::string& varname, + const int mode, int64_t time_out) { sendrecv::VariableMessage req; req.set_varname(varname); diff --git a/paddle/fluid/operators/distributed/brpc/brpc_client.h b/paddle/fluid/operators/distributed/brpc/brpc_client.h index 2ea90d560f5..91f94b4c9d5 100644 --- a/paddle/fluid/operators/distributed/brpc/brpc_client.h +++ b/paddle/fluid/operators/distributed/brpc/brpc_client.h @@ -103,7 +103,7 @@ class BRPCClient : public RPCClient { VarHandlePtr AsyncCheckpointNotify( const std::string& ep, const std::string& dirname, - const std::string& varname, + const std::string& varname, const int mode, int64_t time_out = FLAGS_rpc_deadline) override; bool Wait() override; diff --git a/paddle/fluid/operators/distributed/grpc/grpc_client.cc b/paddle/fluid/operators/distributed/grpc/grpc_client.cc index 0983b4a406e..f935e452f71 100644 --- a/paddle/fluid/operators/distributed/grpc/grpc_client.cc +++ b/paddle/fluid/operators/distributed/grpc/grpc_client.cc @@ -420,6 +420,7 @@ VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep, VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep, const std::string& dirname, const std::string& varname, + const int mode, int64_t time_out) { const auto ch = GetChannel(ep); @@ -433,6 +434,7 @@ VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep, sendrecv::VariableMessage req; req.set_varname(varname); + req.set_table_name(std::to_string(mode)); req.set_out_varname(dirname); platform::RecordRPCEvent record_event(method); diff --git a/paddle/fluid/operators/distributed/grpc/grpc_client.h b/paddle/fluid/operators/distributed/grpc/grpc_client.h index 6b6249540c6..c9448c37950 100644 --- a/paddle/fluid/operators/distributed/grpc/grpc_client.h +++ b/paddle/fluid/operators/distributed/grpc/grpc_client.h @@ -247,7 +247,7 @@ class GRPCClient : public RPCClient { VarHandlePtr AsyncCheckpointNotify( const std::string& ep, const std::string& dirname, - const std::string& varname, + const std::string& varname, const int mode, int64_t time_out = FLAGS_rpc_deadline) override; VarHandlePtr AsyncDistributeNotify( diff --git a/paddle/fluid/operators/distributed/request_handler_impl.cc b/paddle/fluid/operators/distributed/request_handler_impl.cc index 761a4edc523..52fc5fe8b46 100644 --- a/paddle/fluid/operators/distributed/request_handler_impl.cc +++ b/paddle/fluid/operators/distributed/request_handler_impl.cc @@ -266,13 +266,13 @@ bool RequestCheckpointHandler::Handle(const std::string &varname, const int trainer_id, const std::string &out_var_name, const std::string &table_name) { - VLOG(4) << "receive save var " << varname << " with path " << out_var_name; + int mode = std::stoi(out_var_name); + + VLOG(4) << "receive save var " << varname << " with path " << out_var_name + << " mode " << mode; auto *ins = distributed::LargeScaleKV::GetInstance(); - ins->Get(varname)->Save(out_var_name); - // auto checkpoint_op = BuildCheckpointOp(varname, out_var_name); - // paddle::platform::CPUPlace cpu_place; - // checkpoint_op->Run(*scope_, cpu_place); + ins->Get(varname)->Save(out_var_name, mode); return true; } diff --git a/paddle/fluid/operators/distributed/rpc_client.h b/paddle/fluid/operators/distributed/rpc_client.h index 69a5e327431..b82589ac30e 100644 --- a/paddle/fluid/operators/distributed/rpc_client.h +++ b/paddle/fluid/operators/distributed/rpc_client.h @@ -78,7 +78,8 @@ class RPCClient { virtual VarHandlePtr AsyncCheckpointNotify( const std::string& ep, const std::string& dirname, - const std::string& varname, int64_t time_out = FLAGS_rpc_deadline) = 0; + const std::string& varname, const int mode, + int64_t time_out = FLAGS_rpc_deadline) = 0; virtual VarHandlePtr AsyncDistributeNotify( const std::string& ep, const platform::DeviceContext& ctx, diff --git a/paddle/fluid/operators/distributed_ops/checkpoint_notify_op.cc b/paddle/fluid/operators/distributed_ops/checkpoint_notify_op.cc index 2ed2acb96dc..62fd72b335b 100644 --- a/paddle/fluid/operators/distributed_ops/checkpoint_notify_op.cc +++ b/paddle/fluid/operators/distributed_ops/checkpoint_notify_op.cc @@ -36,8 +36,12 @@ class CheckpointNotifyOp : public framework::OperatorBase { Attr>("endpoints"); std::string dirname = Attr("dirname"); std::string varname = Attr("varname"); - auto is_slice = Attr("is_slice"); - VLOG(1) << "is_slice: " << is_slice; + auto mode = Attr("mode"); + + if (mode != 0 && mode != 1 && mode != 2) { + PADDLE_THROW(platform::errors::InvalidArgument( + "mode expected in [0/1/2], but got %d", mode)); + } std::vector slice_varnames = Attr>("slice_varnames"); @@ -52,11 +56,12 @@ class CheckpointNotifyOp : public framework::OperatorBase { auto save_path = string::Sprintf("%s/%s/%s", dirname, varname, slice_varnames[i]); - rpc_client->AsyncCheckpointNotify(epmap[i], save_path, - remote_varnames[i]); + rpc_client->AsyncCheckpointNotify(epmap[i], save_path, remote_varnames[i], + mode); VLOG(3) << "checkpoint notify sending with path: " << save_path - << " and var:" << slice_varnames[i] << " to " << epmap[i]; + << " and var:" << slice_varnames[i] << " to " << epmap[i] + << " with mode " << mode; } PADDLE_ENFORCE_EQ( rpc_client->Wait(), true, @@ -79,9 +84,8 @@ class CheckpointNotifyOpMaker : public framework::OpProtoAndCheckerMaker { "slice_varnames", "(string vector) the slice vars need to be saved"); AddAttr>( "remote_varnames", "(string vector) the slice vars need to be saved"); - AddAttr( - "is_slice", - "is_slice=True means the var has been slice by parameter server"); + AddAttr("mode", "mode=0/1/2 means nothing/save base/save delta") + .SetDefault(0); AddComment(R"DOC( CheckpointNotify operator This operator will send lookup table and it's checkpoint direcoty to listen_and_serve op at diff --git a/python/paddle/distributed/fleet/runtime/parameter_server_runtime.py b/python/paddle/distributed/fleet/runtime/parameter_server_runtime.py index 6dd4661f000..3f1fc518fe6 100644 --- a/python/paddle/distributed/fleet/runtime/parameter_server_runtime.py +++ b/python/paddle/distributed/fleet/runtime/parameter_server_runtime.py @@ -436,8 +436,7 @@ class ParameterServerRuntime(RuntimeBase): executor.run(prog) return context.keys() - def _save_distributed_params(self, executor, dirname, context, - main_program): + def _save_distributed_params(self, executor, dirname, context, mode): prog = Program() block = prog.global_block() @@ -446,7 +445,7 @@ class ParameterServerRuntime(RuntimeBase): type='checkpoint_notify', attrs={ "varname": name, - "is_slice": True, + "mode": mode, "slice_varnames": var_ctx.split_varnames(), "remote_varnames": var_ctx.split_varnames(), "endpoints": var_ctx.split_endpoints(), @@ -456,7 +455,8 @@ class ParameterServerRuntime(RuntimeBase): executor.run(prog) return context.keys() - def _save_distributed_persistables(self, executor, dirname, main_program): + def _save_distributed_persistables(self, executor, dirname, main_program, + mode): dense_ctx = self.compiled_strategy.get_communicator_recv_context( recv_type=1) @@ -473,7 +473,7 @@ class ParameterServerRuntime(RuntimeBase): executor, dirname, sparse_ctx, main_program) recv_distributed_varnames = self._save_distributed_params( - executor, dirname, distributed_ctx, main_program) + executor, dirname, distributed_ctx, mode) saved_varnames = recv_dense_varnames + list( recv_sparse_varnames) + list(recv_distributed_varnames) @@ -493,6 +493,7 @@ class ParameterServerRuntime(RuntimeBase): executor, dirname, main_program=None, + mode=0, **kwargs): """ This function filters out all variables with `persistable==True` from the @@ -523,7 +524,8 @@ class ParameterServerRuntime(RuntimeBase): "in fleet.save_persistables() function, main_program must be as Program type, CompiledProgram is not allowed" ) - self._save_distributed_persistables(executor, dirname, main_program) + self._save_distributed_persistables(executor, dirname, main_program, + mode) def _ps_inference_save_inference_model(self, executor, @@ -569,7 +571,8 @@ class ParameterServerRuntime(RuntimeBase): program = Program.parse_from_string(program_desc_str) program._copy_dist_param_info_from(fluid.default_main_program()) - self._ps_inference_save_persistables(executor, dirname, program) + self._ps_inference_save_persistables( + executor, dirname, program, mode=0) def _save_inference_model(self, *args, **kwargs): self._ps_inference_save_inference_model(*args, **kwargs) -- GitLab