diff --git a/paddle/fluid/operators/distributed/brpc/brpc_client.cc b/paddle/fluid/operators/distributed/brpc/brpc_client.cc index cb93b8d910a2353b8c9a1e793338fa5d50a93165..b2a26089c868968734df35383f2ddf2ba512b137 100644 --- a/paddle/fluid/operators/distributed/brpc/brpc_client.cc +++ b/paddle/fluid/operators/distributed/brpc/brpc_client.cc @@ -448,6 +448,7 @@ VarHandlePtr BRPCClient::AsyncSendMessage(const std::string& ep, VarHandlePtr BRPCClient::AsyncCheckpointNotify(const std::string& ep, const std::string& dirname, const std::string& varname, + const int mode, int64_t time_out) { sendrecv::VariableMessage req; req.set_varname(varname); diff --git a/paddle/fluid/operators/distributed/brpc/brpc_client.h b/paddle/fluid/operators/distributed/brpc/brpc_client.h index 2ea90d560f5685e19a8f16d15d07414c927001ba..91f94b4c9d5a3076e17b0b54cf153e2f3b9edc31 100644 --- a/paddle/fluid/operators/distributed/brpc/brpc_client.h +++ b/paddle/fluid/operators/distributed/brpc/brpc_client.h @@ -103,7 +103,7 @@ class BRPCClient : public RPCClient { VarHandlePtr AsyncCheckpointNotify( const std::string& ep, const std::string& dirname, - const std::string& varname, + const std::string& varname, const int mode, int64_t time_out = FLAGS_rpc_deadline) override; bool Wait() override; diff --git a/paddle/fluid/operators/distributed/grpc/grpc_client.cc b/paddle/fluid/operators/distributed/grpc/grpc_client.cc index 0983b4a406e042f094965ad9a7de437684940fa9..f935e452f7147ad762f5b03709603a961fb8f17e 100644 --- a/paddle/fluid/operators/distributed/grpc/grpc_client.cc +++ b/paddle/fluid/operators/distributed/grpc/grpc_client.cc @@ -420,6 +420,7 @@ VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep, VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep, const std::string& dirname, const std::string& varname, + const int mode, int64_t time_out) { const auto ch = GetChannel(ep); @@ -433,6 +434,7 @@ VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep, sendrecv::VariableMessage req; req.set_varname(varname); + req.set_table_name(std::to_string(mode)); req.set_out_varname(dirname); platform::RecordRPCEvent record_event(method); diff --git a/paddle/fluid/operators/distributed/grpc/grpc_client.h b/paddle/fluid/operators/distributed/grpc/grpc_client.h index 6b6249540c6d15954743c414a60472bf1f831151..c9448c37950c4d05052b3b73ea25b7323b847539 100644 --- a/paddle/fluid/operators/distributed/grpc/grpc_client.h +++ b/paddle/fluid/operators/distributed/grpc/grpc_client.h @@ -247,7 +247,7 @@ class GRPCClient : public RPCClient { VarHandlePtr AsyncCheckpointNotify( const std::string& ep, const std::string& dirname, - const std::string& varname, + const std::string& varname, const int mode, int64_t time_out = FLAGS_rpc_deadline) override; VarHandlePtr AsyncDistributeNotify( diff --git a/paddle/fluid/operators/distributed/request_handler_impl.cc b/paddle/fluid/operators/distributed/request_handler_impl.cc index 761a4edc523da52ffdbdd2039444c133e8da368c..52fc5fe8b4685b188364398b164952f8ba50fa83 100644 --- a/paddle/fluid/operators/distributed/request_handler_impl.cc +++ b/paddle/fluid/operators/distributed/request_handler_impl.cc @@ -266,13 +266,13 @@ bool RequestCheckpointHandler::Handle(const std::string &varname, const int trainer_id, const std::string &out_var_name, const std::string &table_name) { - VLOG(4) << "receive save var " << varname << " with path " << out_var_name; + int mode = std::stoi(out_var_name); + + VLOG(4) << "receive save var " << varname << " with path " << out_var_name + << " mode " << mode; auto *ins = distributed::LargeScaleKV::GetInstance(); - ins->Get(varname)->Save(out_var_name); - // auto checkpoint_op = BuildCheckpointOp(varname, out_var_name); - // paddle::platform::CPUPlace cpu_place; - // checkpoint_op->Run(*scope_, cpu_place); + ins->Get(varname)->Save(out_var_name, mode); return true; } diff --git a/paddle/fluid/operators/distributed/rpc_client.h b/paddle/fluid/operators/distributed/rpc_client.h index 69a5e3274318337f5424afa6492da829e04daa69..b82589ac30e10a6965e307929085f4f28f2034f5 100644 --- a/paddle/fluid/operators/distributed/rpc_client.h +++ b/paddle/fluid/operators/distributed/rpc_client.h @@ -78,7 +78,8 @@ class RPCClient { virtual VarHandlePtr AsyncCheckpointNotify( const std::string& ep, const std::string& dirname, - const std::string& varname, int64_t time_out = FLAGS_rpc_deadline) = 0; + const std::string& varname, const int mode, + int64_t time_out = FLAGS_rpc_deadline) = 0; virtual VarHandlePtr AsyncDistributeNotify( const std::string& ep, const platform::DeviceContext& ctx, diff --git a/paddle/fluid/operators/distributed_ops/checkpoint_notify_op.cc b/paddle/fluid/operators/distributed_ops/checkpoint_notify_op.cc index 2ed2acb96dc842b6a60bf31701d39ac94dab9804..62fd72b335b9211586362a38c159408a7662eb0f 100644 --- a/paddle/fluid/operators/distributed_ops/checkpoint_notify_op.cc +++ b/paddle/fluid/operators/distributed_ops/checkpoint_notify_op.cc @@ -36,8 +36,12 @@ class CheckpointNotifyOp : public framework::OperatorBase { Attr>("endpoints"); std::string dirname = Attr("dirname"); std::string varname = Attr("varname"); - auto is_slice = Attr("is_slice"); - VLOG(1) << "is_slice: " << is_slice; + auto mode = Attr("mode"); + + if (mode != 0 && mode != 1 && mode != 2) { + PADDLE_THROW(platform::errors::InvalidArgument( + "mode expected in [0/1/2], but got %d", mode)); + } std::vector slice_varnames = Attr>("slice_varnames"); @@ -52,11 +56,12 @@ class CheckpointNotifyOp : public framework::OperatorBase { auto save_path = string::Sprintf("%s/%s/%s", dirname, varname, slice_varnames[i]); - rpc_client->AsyncCheckpointNotify(epmap[i], save_path, - remote_varnames[i]); + rpc_client->AsyncCheckpointNotify(epmap[i], save_path, remote_varnames[i], + mode); VLOG(3) << "checkpoint notify sending with path: " << save_path - << " and var:" << slice_varnames[i] << " to " << epmap[i]; + << " and var:" << slice_varnames[i] << " to " << epmap[i] + << " with mode " << mode; } PADDLE_ENFORCE_EQ( rpc_client->Wait(), true, @@ -79,9 +84,8 @@ class CheckpointNotifyOpMaker : public framework::OpProtoAndCheckerMaker { "slice_varnames", "(string vector) the slice vars need to be saved"); AddAttr>( "remote_varnames", "(string vector) the slice vars need to be saved"); - AddAttr( - "is_slice", - "is_slice=True means the var has been slice by parameter server"); + AddAttr("mode", "mode=0/1/2 means nothing/save base/save delta") + .SetDefault(0); AddComment(R"DOC( CheckpointNotify operator This operator will send lookup table and it's checkpoint direcoty to listen_and_serve op at diff --git a/python/paddle/distributed/fleet/runtime/parameter_server_runtime.py b/python/paddle/distributed/fleet/runtime/parameter_server_runtime.py index 6dd4661f00062f55bb834bbee50daf1924a0c87a..3f1fc518fe6384842727c719fcf3f1d7e9a044ea 100644 --- a/python/paddle/distributed/fleet/runtime/parameter_server_runtime.py +++ b/python/paddle/distributed/fleet/runtime/parameter_server_runtime.py @@ -436,8 +436,7 @@ class ParameterServerRuntime(RuntimeBase): executor.run(prog) return context.keys() - def _save_distributed_params(self, executor, dirname, context, - main_program): + def _save_distributed_params(self, executor, dirname, context, mode): prog = Program() block = prog.global_block() @@ -446,7 +445,7 @@ class ParameterServerRuntime(RuntimeBase): type='checkpoint_notify', attrs={ "varname": name, - "is_slice": True, + "mode": mode, "slice_varnames": var_ctx.split_varnames(), "remote_varnames": var_ctx.split_varnames(), "endpoints": var_ctx.split_endpoints(), @@ -456,7 +455,8 @@ class ParameterServerRuntime(RuntimeBase): executor.run(prog) return context.keys() - def _save_distributed_persistables(self, executor, dirname, main_program): + def _save_distributed_persistables(self, executor, dirname, main_program, + mode): dense_ctx = self.compiled_strategy.get_communicator_recv_context( recv_type=1) @@ -473,7 +473,7 @@ class ParameterServerRuntime(RuntimeBase): executor, dirname, sparse_ctx, main_program) recv_distributed_varnames = self._save_distributed_params( - executor, dirname, distributed_ctx, main_program) + executor, dirname, distributed_ctx, mode) saved_varnames = recv_dense_varnames + list( recv_sparse_varnames) + list(recv_distributed_varnames) @@ -493,6 +493,7 @@ class ParameterServerRuntime(RuntimeBase): executor, dirname, main_program=None, + mode=0, **kwargs): """ This function filters out all variables with `persistable==True` from the @@ -523,7 +524,8 @@ class ParameterServerRuntime(RuntimeBase): "in fleet.save_persistables() function, main_program must be as Program type, CompiledProgram is not allowed" ) - self._save_distributed_persistables(executor, dirname, main_program) + self._save_distributed_persistables(executor, dirname, main_program, + mode) def _ps_inference_save_inference_model(self, executor, @@ -569,7 +571,8 @@ class ParameterServerRuntime(RuntimeBase): program = Program.parse_from_string(program_desc_str) program._copy_dist_param_info_from(fluid.default_main_program()) - self._ps_inference_save_persistables(executor, dirname, program) + self._ps_inference_save_persistables( + executor, dirname, program, mode=0) def _save_inference_model(self, *args, **kwargs): self._ps_inference_save_inference_model(*args, **kwargs)