提交 f70096a2 编写于 作者: S seiriosPlus

add mode for save

上级 11d17938
......@@ -448,6 +448,7 @@ VarHandlePtr BRPCClient::AsyncSendMessage(const std::string& ep,
VarHandlePtr BRPCClient::AsyncCheckpointNotify(const std::string& ep,
const std::string& dirname,
const std::string& varname,
const int mode,
int64_t time_out) {
sendrecv::VariableMessage req;
req.set_varname(varname);
......
......@@ -103,7 +103,7 @@ class BRPCClient : public RPCClient {
VarHandlePtr AsyncCheckpointNotify(
const std::string& ep, const std::string& dirname,
const std::string& varname,
const std::string& varname, const int mode,
int64_t time_out = FLAGS_rpc_deadline) override;
bool Wait() override;
......
......@@ -420,6 +420,7 @@ VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep,
VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
const std::string& dirname,
const std::string& varname,
const int mode,
int64_t time_out) {
const auto ch = GetChannel(ep);
......@@ -433,6 +434,7 @@ VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
sendrecv::VariableMessage req;
req.set_varname(varname);
req.set_table_name(std::to_string(mode));
req.set_out_varname(dirname);
platform::RecordRPCEvent record_event(method);
......
......@@ -247,7 +247,7 @@ class GRPCClient : public RPCClient {
VarHandlePtr AsyncCheckpointNotify(
const std::string& ep, const std::string& dirname,
const std::string& varname,
const std::string& varname, const int mode,
int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncDistributeNotify(
......
......@@ -266,13 +266,13 @@ bool RequestCheckpointHandler::Handle(const std::string &varname,
const int trainer_id,
const std::string &out_var_name,
const std::string &table_name) {
VLOG(4) << "receive save var " << varname << " with path " << out_var_name;
int mode = std::stoi(out_var_name);
VLOG(4) << "receive save var " << varname << " with path " << out_var_name
<< " mode " << mode;
auto *ins = distributed::LargeScaleKV::GetInstance();
ins->Get(varname)->Save(out_var_name);
// auto checkpoint_op = BuildCheckpointOp(varname, out_var_name);
// paddle::platform::CPUPlace cpu_place;
// checkpoint_op->Run(*scope_, cpu_place);
ins->Get(varname)->Save(out_var_name, mode);
return true;
}
......
......@@ -78,7 +78,8 @@ class RPCClient {
virtual VarHandlePtr AsyncCheckpointNotify(
const std::string& ep, const std::string& dirname,
const std::string& varname, int64_t time_out = FLAGS_rpc_deadline) = 0;
const std::string& varname, const int mode,
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual VarHandlePtr AsyncDistributeNotify(
const std::string& ep, const platform::DeviceContext& ctx,
......
......@@ -36,8 +36,12 @@ class CheckpointNotifyOp : public framework::OperatorBase {
Attr<std::vector<std::string>>("endpoints");
std::string dirname = Attr<std::string>("dirname");
std::string varname = Attr<std::string>("varname");
auto is_slice = Attr<bool>("is_slice");
VLOG(1) << "is_slice: " << is_slice;
auto mode = Attr<int>("mode");
if (mode != 0 && mode != 1 && mode != 2) {
PADDLE_THROW(platform::errors::InvalidArgument(
"mode expected in [0/1/2], but got %d", mode));
}
std::vector<std::string> slice_varnames =
Attr<std::vector<std::string>>("slice_varnames");
......@@ -52,11 +56,12 @@ class CheckpointNotifyOp : public framework::OperatorBase {
auto save_path =
string::Sprintf("%s/%s/%s", dirname, varname, slice_varnames[i]);
rpc_client->AsyncCheckpointNotify(epmap[i], save_path,
remote_varnames[i]);
rpc_client->AsyncCheckpointNotify(epmap[i], save_path, remote_varnames[i],
mode);
VLOG(3) << "checkpoint notify sending with path: " << save_path
<< " and var:" << slice_varnames[i] << " to " << epmap[i];
<< " and var:" << slice_varnames[i] << " to " << epmap[i]
<< " with mode " << mode;
}
PADDLE_ENFORCE_EQ(
rpc_client->Wait(), true,
......@@ -79,9 +84,8 @@ class CheckpointNotifyOpMaker : public framework::OpProtoAndCheckerMaker {
"slice_varnames", "(string vector) the slice vars need to be saved");
AddAttr<std::vector<std::string>>(
"remote_varnames", "(string vector) the slice vars need to be saved");
AddAttr<bool>(
"is_slice",
"is_slice=True means the var has been slice by parameter server");
AddAttr<int>("mode", "mode=0/1/2 means nothing/save base/save delta")
.SetDefault(0);
AddComment(R"DOC(
CheckpointNotify operator
This operator will send lookup table and it's checkpoint direcoty to listen_and_serve op at
......
......@@ -436,8 +436,7 @@ class ParameterServerRuntime(RuntimeBase):
executor.run(prog)
return context.keys()
def _save_distributed_params(self, executor, dirname, context,
main_program):
def _save_distributed_params(self, executor, dirname, context, mode):
prog = Program()
block = prog.global_block()
......@@ -446,7 +445,7 @@ class ParameterServerRuntime(RuntimeBase):
type='checkpoint_notify',
attrs={
"varname": name,
"is_slice": True,
"mode": mode,
"slice_varnames": var_ctx.split_varnames(),
"remote_varnames": var_ctx.split_varnames(),
"endpoints": var_ctx.split_endpoints(),
......@@ -456,7 +455,8 @@ class ParameterServerRuntime(RuntimeBase):
executor.run(prog)
return context.keys()
def _save_distributed_persistables(self, executor, dirname, main_program):
def _save_distributed_persistables(self, executor, dirname, main_program,
mode):
dense_ctx = self.compiled_strategy.get_communicator_recv_context(
recv_type=1)
......@@ -473,7 +473,7 @@ class ParameterServerRuntime(RuntimeBase):
executor, dirname, sparse_ctx, main_program)
recv_distributed_varnames = self._save_distributed_params(
executor, dirname, distributed_ctx, main_program)
executor, dirname, distributed_ctx, mode)
saved_varnames = recv_dense_varnames + list(
recv_sparse_varnames) + list(recv_distributed_varnames)
......@@ -493,6 +493,7 @@ class ParameterServerRuntime(RuntimeBase):
executor,
dirname,
main_program=None,
mode=0,
**kwargs):
"""
This function filters out all variables with `persistable==True` from the
......@@ -523,7 +524,8 @@ class ParameterServerRuntime(RuntimeBase):
"in fleet.save_persistables() function, main_program must be as Program type, CompiledProgram is not allowed"
)
self._save_distributed_persistables(executor, dirname, main_program)
self._save_distributed_persistables(executor, dirname, main_program,
mode)
def _ps_inference_save_inference_model(self,
executor,
......@@ -569,7 +571,8 @@ class ParameterServerRuntime(RuntimeBase):
program = Program.parse_from_string(program_desc_str)
program._copy_dist_param_info_from(fluid.default_main_program())
self._ps_inference_save_persistables(executor, dirname, program)
self._ps_inference_save_persistables(
executor, dirname, program, mode=0)
def _save_inference_model(self, *args, **kwargs):
self._ps_inference_save_inference_model(*args, **kwargs)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册