提交 f70096a2 编写于 作者: S seiriosPlus

add mode for save

上级 11d17938
...@@ -448,6 +448,7 @@ VarHandlePtr BRPCClient::AsyncSendMessage(const std::string& ep, ...@@ -448,6 +448,7 @@ VarHandlePtr BRPCClient::AsyncSendMessage(const std::string& ep,
VarHandlePtr BRPCClient::AsyncCheckpointNotify(const std::string& ep, VarHandlePtr BRPCClient::AsyncCheckpointNotify(const std::string& ep,
const std::string& dirname, const std::string& dirname,
const std::string& varname, const std::string& varname,
const int mode,
int64_t time_out) { int64_t time_out) {
sendrecv::VariableMessage req; sendrecv::VariableMessage req;
req.set_varname(varname); req.set_varname(varname);
......
...@@ -103,7 +103,7 @@ class BRPCClient : public RPCClient { ...@@ -103,7 +103,7 @@ class BRPCClient : public RPCClient {
VarHandlePtr AsyncCheckpointNotify( VarHandlePtr AsyncCheckpointNotify(
const std::string& ep, const std::string& dirname, const std::string& ep, const std::string& dirname,
const std::string& varname, const std::string& varname, const int mode,
int64_t time_out = FLAGS_rpc_deadline) override; int64_t time_out = FLAGS_rpc_deadline) override;
bool Wait() override; bool Wait() override;
......
...@@ -420,6 +420,7 @@ VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep, ...@@ -420,6 +420,7 @@ VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep,
VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep, VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
const std::string& dirname, const std::string& dirname,
const std::string& varname, const std::string& varname,
const int mode,
int64_t time_out) { int64_t time_out) {
const auto ch = GetChannel(ep); const auto ch = GetChannel(ep);
...@@ -433,6 +434,7 @@ VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep, ...@@ -433,6 +434,7 @@ VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
sendrecv::VariableMessage req; sendrecv::VariableMessage req;
req.set_varname(varname); req.set_varname(varname);
req.set_table_name(std::to_string(mode));
req.set_out_varname(dirname); req.set_out_varname(dirname);
platform::RecordRPCEvent record_event(method); platform::RecordRPCEvent record_event(method);
......
...@@ -247,7 +247,7 @@ class GRPCClient : public RPCClient { ...@@ -247,7 +247,7 @@ class GRPCClient : public RPCClient {
VarHandlePtr AsyncCheckpointNotify( VarHandlePtr AsyncCheckpointNotify(
const std::string& ep, const std::string& dirname, const std::string& ep, const std::string& dirname,
const std::string& varname, const std::string& varname, const int mode,
int64_t time_out = FLAGS_rpc_deadline) override; int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncDistributeNotify( VarHandlePtr AsyncDistributeNotify(
......
...@@ -266,13 +266,13 @@ bool RequestCheckpointHandler::Handle(const std::string &varname, ...@@ -266,13 +266,13 @@ bool RequestCheckpointHandler::Handle(const std::string &varname,
const int trainer_id, const int trainer_id,
const std::string &out_var_name, const std::string &out_var_name,
const std::string &table_name) { const std::string &table_name) {
VLOG(4) << "receive save var " << varname << " with path " << out_var_name; int mode = std::stoi(out_var_name);
VLOG(4) << "receive save var " << varname << " with path " << out_var_name
<< " mode " << mode;
auto *ins = distributed::LargeScaleKV::GetInstance(); auto *ins = distributed::LargeScaleKV::GetInstance();
ins->Get(varname)->Save(out_var_name); ins->Get(varname)->Save(out_var_name, mode);
// auto checkpoint_op = BuildCheckpointOp(varname, out_var_name);
// paddle::platform::CPUPlace cpu_place;
// checkpoint_op->Run(*scope_, cpu_place);
return true; return true;
} }
......
...@@ -78,7 +78,8 @@ class RPCClient { ...@@ -78,7 +78,8 @@ class RPCClient {
virtual VarHandlePtr AsyncCheckpointNotify( virtual VarHandlePtr AsyncCheckpointNotify(
const std::string& ep, const std::string& dirname, const std::string& ep, const std::string& dirname,
const std::string& varname, int64_t time_out = FLAGS_rpc_deadline) = 0; const std::string& varname, const int mode,
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual VarHandlePtr AsyncDistributeNotify( virtual VarHandlePtr AsyncDistributeNotify(
const std::string& ep, const platform::DeviceContext& ctx, const std::string& ep, const platform::DeviceContext& ctx,
......
...@@ -36,8 +36,12 @@ class CheckpointNotifyOp : public framework::OperatorBase { ...@@ -36,8 +36,12 @@ class CheckpointNotifyOp : public framework::OperatorBase {
Attr<std::vector<std::string>>("endpoints"); Attr<std::vector<std::string>>("endpoints");
std::string dirname = Attr<std::string>("dirname"); std::string dirname = Attr<std::string>("dirname");
std::string varname = Attr<std::string>("varname"); std::string varname = Attr<std::string>("varname");
auto is_slice = Attr<bool>("is_slice"); auto mode = Attr<int>("mode");
VLOG(1) << "is_slice: " << is_slice;
if (mode != 0 && mode != 1 && mode != 2) {
PADDLE_THROW(platform::errors::InvalidArgument(
"mode expected in [0/1/2], but got %d", mode));
}
std::vector<std::string> slice_varnames = std::vector<std::string> slice_varnames =
Attr<std::vector<std::string>>("slice_varnames"); Attr<std::vector<std::string>>("slice_varnames");
...@@ -52,11 +56,12 @@ class CheckpointNotifyOp : public framework::OperatorBase { ...@@ -52,11 +56,12 @@ class CheckpointNotifyOp : public framework::OperatorBase {
auto save_path = auto save_path =
string::Sprintf("%s/%s/%s", dirname, varname, slice_varnames[i]); string::Sprintf("%s/%s/%s", dirname, varname, slice_varnames[i]);
rpc_client->AsyncCheckpointNotify(epmap[i], save_path, rpc_client->AsyncCheckpointNotify(epmap[i], save_path, remote_varnames[i],
remote_varnames[i]); mode);
VLOG(3) << "checkpoint notify sending with path: " << save_path VLOG(3) << "checkpoint notify sending with path: " << save_path
<< " and var:" << slice_varnames[i] << " to " << epmap[i]; << " and var:" << slice_varnames[i] << " to " << epmap[i]
<< " with mode " << mode;
} }
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
rpc_client->Wait(), true, rpc_client->Wait(), true,
...@@ -79,9 +84,8 @@ class CheckpointNotifyOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -79,9 +84,8 @@ class CheckpointNotifyOpMaker : public framework::OpProtoAndCheckerMaker {
"slice_varnames", "(string vector) the slice vars need to be saved"); "slice_varnames", "(string vector) the slice vars need to be saved");
AddAttr<std::vector<std::string>>( AddAttr<std::vector<std::string>>(
"remote_varnames", "(string vector) the slice vars need to be saved"); "remote_varnames", "(string vector) the slice vars need to be saved");
AddAttr<bool>( AddAttr<int>("mode", "mode=0/1/2 means nothing/save base/save delta")
"is_slice", .SetDefault(0);
"is_slice=True means the var has been slice by parameter server");
AddComment(R"DOC( AddComment(R"DOC(
CheckpointNotify operator CheckpointNotify operator
This operator will send lookup table and it's checkpoint direcoty to listen_and_serve op at This operator will send lookup table and it's checkpoint direcoty to listen_and_serve op at
......
...@@ -436,8 +436,7 @@ class ParameterServerRuntime(RuntimeBase): ...@@ -436,8 +436,7 @@ class ParameterServerRuntime(RuntimeBase):
executor.run(prog) executor.run(prog)
return context.keys() return context.keys()
def _save_distributed_params(self, executor, dirname, context, def _save_distributed_params(self, executor, dirname, context, mode):
main_program):
prog = Program() prog = Program()
block = prog.global_block() block = prog.global_block()
...@@ -446,7 +445,7 @@ class ParameterServerRuntime(RuntimeBase): ...@@ -446,7 +445,7 @@ class ParameterServerRuntime(RuntimeBase):
type='checkpoint_notify', type='checkpoint_notify',
attrs={ attrs={
"varname": name, "varname": name,
"is_slice": True, "mode": mode,
"slice_varnames": var_ctx.split_varnames(), "slice_varnames": var_ctx.split_varnames(),
"remote_varnames": var_ctx.split_varnames(), "remote_varnames": var_ctx.split_varnames(),
"endpoints": var_ctx.split_endpoints(), "endpoints": var_ctx.split_endpoints(),
...@@ -456,7 +455,8 @@ class ParameterServerRuntime(RuntimeBase): ...@@ -456,7 +455,8 @@ class ParameterServerRuntime(RuntimeBase):
executor.run(prog) executor.run(prog)
return context.keys() return context.keys()
def _save_distributed_persistables(self, executor, dirname, main_program): def _save_distributed_persistables(self, executor, dirname, main_program,
mode):
dense_ctx = self.compiled_strategy.get_communicator_recv_context( dense_ctx = self.compiled_strategy.get_communicator_recv_context(
recv_type=1) recv_type=1)
...@@ -473,7 +473,7 @@ class ParameterServerRuntime(RuntimeBase): ...@@ -473,7 +473,7 @@ class ParameterServerRuntime(RuntimeBase):
executor, dirname, sparse_ctx, main_program) executor, dirname, sparse_ctx, main_program)
recv_distributed_varnames = self._save_distributed_params( recv_distributed_varnames = self._save_distributed_params(
executor, dirname, distributed_ctx, main_program) executor, dirname, distributed_ctx, mode)
saved_varnames = recv_dense_varnames + list( saved_varnames = recv_dense_varnames + list(
recv_sparse_varnames) + list(recv_distributed_varnames) recv_sparse_varnames) + list(recv_distributed_varnames)
...@@ -493,6 +493,7 @@ class ParameterServerRuntime(RuntimeBase): ...@@ -493,6 +493,7 @@ class ParameterServerRuntime(RuntimeBase):
executor, executor,
dirname, dirname,
main_program=None, main_program=None,
mode=0,
**kwargs): **kwargs):
""" """
This function filters out all variables with `persistable==True` from the This function filters out all variables with `persistable==True` from the
...@@ -523,7 +524,8 @@ class ParameterServerRuntime(RuntimeBase): ...@@ -523,7 +524,8 @@ class ParameterServerRuntime(RuntimeBase):
"in fleet.save_persistables() function, main_program must be as Program type, CompiledProgram is not allowed" "in fleet.save_persistables() function, main_program must be as Program type, CompiledProgram is not allowed"
) )
self._save_distributed_persistables(executor, dirname, main_program) self._save_distributed_persistables(executor, dirname, main_program,
mode)
def _ps_inference_save_inference_model(self, def _ps_inference_save_inference_model(self,
executor, executor,
...@@ -569,7 +571,8 @@ class ParameterServerRuntime(RuntimeBase): ...@@ -569,7 +571,8 @@ class ParameterServerRuntime(RuntimeBase):
program = Program.parse_from_string(program_desc_str) program = Program.parse_from_string(program_desc_str)
program._copy_dist_param_info_from(fluid.default_main_program()) program._copy_dist_param_info_from(fluid.default_main_program())
self._ps_inference_save_persistables(executor, dirname, program) self._ps_inference_save_persistables(
executor, dirname, program, mode=0)
def _save_inference_model(self, *args, **kwargs): def _save_inference_model(self, *args, **kwargs):
self._ps_inference_save_inference_model(*args, **kwargs) self._ps_inference_save_inference_model(*args, **kwargs)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册