提交 ae12281d 编写于 作者: T tangwei12

checkpoint notify

上级 30880844
...@@ -20,6 +20,7 @@ limitations under the License. */ ...@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/macros.h" #include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/send_recv_util.h" #include "paddle/fluid/operators/send_recv_util.h"
#include "paddle/fluid/string/printf.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -36,12 +37,14 @@ class CheckpointNotifyOp : public framework::OperatorBase { ...@@ -36,12 +37,14 @@ class CheckpointNotifyOp : public framework::OperatorBase {
const platform::Place& place) const override { const platform::Place& place) const override {
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap"); std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
std::string dir = Attr<std::string>("dir"); std::string dir = Attr<std::string>("dir");
std::string lookup_table_name = Attr<std::string>("lookup_table");
detail::RPCClient* rpc_client = detail::RPCClient* rpc_client =
detail::RPCClient::GetInstance<RPCCLIENT_T>(); detail::RPCClient::GetInstance<RPCCLIENT_T>();
for (size_t i = 0; i < epmap.size(); i++) { for (size_t i = 0; i < epmap.size(); i++) {
VLOG(3) << "sending to " << epmap[i] << " to checkpoint notify ... "; VLOG(3) << "sending " << dir <<" to " << epmap[i] << " to checkpoint notify ... ";
rpc_client->AsyncCheckpointNotify(epmap[i], dir); auto serial_looku_table = string::Sprintf("%s/%s.%d", dir, lookup_table_name, i);
rpc_client->AsyncCheckpointNotify(epmap[i], serial_looku_table);
} }
rpc_client->Wait(); rpc_client->Wait();
} }
...@@ -57,6 +60,8 @@ class CheckpointNotifyOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -57,6 +60,8 @@ class CheckpointNotifyOpMaker : public framework::OpProtoAndCheckerMaker {
.SetDefault({"127.0.0.1:6164"}); .SetDefault({"127.0.0.1:6164"});
AddAttr<std::string>( AddAttr<std::string>(
"dir", "(string, default '') indicate the folder checkpoint will use"); "dir", "(string, default '') indicate the folder checkpoint will use");
AddAttr<std::string>(
"lookup_table", "(string, default '') the lookup table name");
AddComment(R"DOC( AddComment(R"DOC(
Prefetch operator Prefetch operator
......
...@@ -208,11 +208,14 @@ class RequestCheckpointNotify final : public RequestBase { ...@@ -208,11 +208,14 @@ class RequestCheckpointNotify final : public RequestBase {
auto scope = request_->GetMutableLocalScope(); auto scope = request_->GetMutableLocalScope();
std::string checkpoint_notify = request_->Varname(); std::string checkpoint_notify = request_->Varname();
std::string checkpoint_dir = request_->Varname(); std::string checkpoint_dir = request_->OutVarname();
framework::Variable* invar = nullptr; framework::Variable* invar = nullptr;
framework::Variable* outvar = nullptr; framework::Variable* outvar = nullptr;
VLOG(4) << "RequestCheckpointNotify notify: " << checkpoint_notify
<< ", dir: " << checkpoint_dir;
request_handler_->Handle(checkpoint_notify, scope, invar, &outvar, request_handler_->Handle(checkpoint_notify, scope, invar, &outvar,
checkpoint_dir); checkpoint_dir);
Finish(reply_, &responder_); Finish(reply_, &responder_);
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h" #include "paddle/fluid/operators/detail/request_handler_impl.h"
#include "paddle/fluid/operators/detail/rpc_server.h" #include "paddle/fluid/operators/detail/rpc_server.h"
#include "paddle/fluid/string/printf.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -124,6 +125,12 @@ bool RequestCheckpointHandler::Handle(const std::string& varname, ...@@ -124,6 +125,12 @@ bool RequestCheckpointHandler::Handle(const std::string& varname,
framework::Variable* invar, framework::Variable* invar,
framework::Variable** outvar, framework::Variable** outvar,
const std::string& out_var_name) { const std::string& out_var_name) {
auto lt_varname = string::Sprintf("%s.path", varname);
auto *lt_var = scope->FindVar(lt_varname)->GetMutable<std::string>();
lt_var->clear();
lt_var->append(out_var_name);
VLOG(4) << "RequestCheckpointHandler update " << lt_varname << " to: " << out_var_name;
executor_->RunPreparedContext(checkpoint_prepared_ctx_.get(), scope); executor_->RunPreparedContext(checkpoint_prepared_ctx_.get(), scope);
return true; return true;
} }
......
...@@ -87,7 +87,7 @@ class SaveOp : public framework::OperatorBase { ...@@ -87,7 +87,7 @@ class SaveOp : public framework::OperatorBase {
if (var->IsType<framework::LoDTensor>()) { if (var->IsType<framework::LoDTensor>()) {
SaveLodTensor(filename, place, var); SaveLodTensor(filename, place, var);
} else if (var->IsType<framework::SelectedRows>()) { } else if (var->IsType<framework::SelectedRows>()) {
SaveSelectedRows(filename, place, var); SaveSelectedRows(scope, place, var);
} else { } else {
PADDLE_ENFORCE( PADDLE_ENFORCE(
false, false,
...@@ -128,9 +128,17 @@ class SaveOp : public framework::OperatorBase { ...@@ -128,9 +128,17 @@ class SaveOp : public framework::OperatorBase {
fout.close(); fout.close();
} }
void SaveSelectedRows(const std::string &filename, void SaveSelectedRows(const framework::Scope &scope,
const platform::Place &place, const platform::Place &place,
framework::Variable *var) const { framework::Variable *var) const {
auto lt_varname = string::Sprintf("%s.path", Input("X"));
auto *lt_var = scope.FindVar(lt_varname)->GetMutable<std::string>();
PADDLE_ENFORCE(lt_var != nullptr, "Cannot find variable %s for SaveSelectedRows",
lt_varname);
std::string filename = lt_var->data();
VLOG(4) << "SaveSelectedRows get File name: " << filename;
auto &selectedRows = var->Get<framework::SelectedRows>(); auto &selectedRows = var->Get<framework::SelectedRows>();
// get device context from pool // get device context from pool
......
...@@ -471,7 +471,10 @@ def save_checkpoint(executor, ...@@ -471,7 +471,10 @@ def save_checkpoint(executor,
trainer_id, trainer_id,
trainer_args=None, trainer_args=None,
main_program=None, main_program=None,
max_num_checkpoints=3): max_num_checkpoints=3,
lookup_table=None,
ps_endpoint_list=None
):
""" """
Save Checkpoint will save persistable LodTensor variables from main_program in checkpoint directory, Save Checkpoint will save persistable LodTensor variables from main_program in checkpoint directory,
the directory named by serial number from 0 to (n -1), save_checkpoint use LRU strategy the directory named by serial number from 0 to (n -1), save_checkpoint use LRU strategy
...@@ -500,7 +503,7 @@ def save_checkpoint(executor, ...@@ -500,7 +503,7 @@ def save_checkpoint(executor,
if trainer_id == 0: if trainer_id == 0:
save_persist_vars_without_grad(executor, cur_dir, main_program) save_persist_vars_without_grad(executor, cur_dir, main_program)
save_pserver_vars_by_notify(executor, cur_dir, "") save_pserver_vars_by_notify(executor, cur_dir, ps_endpoint_list, lookup_table)
_scroll_delete(checkpoint_dir, max_num_checkpoints) _scroll_delete(checkpoint_dir, max_num_checkpoints)
...@@ -600,7 +603,7 @@ def save_persist_vars_without_grad(executor, dirname, program): ...@@ -600,7 +603,7 @@ def save_persist_vars_without_grad(executor, dirname, program):
_write_success(cur_dir) _write_success(cur_dir)
def save_pserver_vars_by_notify(executor, dirname, epmap): def save_pserver_vars_by_notify(executor, dirname, lookup_table, ps_endpoint_list):
""" """
""" """
cur_dir = _get_lookuptable_dir(dirname) cur_dir = _get_lookuptable_dir(dirname)
...@@ -609,11 +612,12 @@ def save_pserver_vars_by_notify(executor, dirname, epmap): ...@@ -609,11 +612,12 @@ def save_pserver_vars_by_notify(executor, dirname, epmap):
checkpoint_notify_block = checkpoint_notify_program.global_block() checkpoint_notify_block = checkpoint_notify_program.global_block()
attrs = {} attrs = {}
attrs['epmap'] = None attrs['epmap'] = ps_endpoint_list
attrs['dir'] = cur_dir attrs['dir'] = cur_dir
attrs['lookup_table'] = lookup_table
checkpoint_notify_block.append_op( checkpoint_notify_block.append_op(
type='checkpoint_notify', inputs={}, output={}, attrs=attrs) type='checkpoint_notify', inputs={}, outputs={}, attrs=attrs)
executor.run(checkpoint_notify_program) executor.run(checkpoint_notify_program)
...@@ -783,3 +787,4 @@ def get_latest_checkpoint_serial(checkpoint_dir): ...@@ -783,3 +787,4 @@ def get_latest_checkpoint_serial(checkpoint_dir):
if success_num > current_dir: if success_num > current_dir:
current_dir = success_num current_dir = success_num
return current_dir return current_dir
...@@ -838,13 +838,15 @@ class DistributeTranspiler: ...@@ -838,13 +838,15 @@ class DistributeTranspiler:
""" """
import os import os
pserver_program.global_block().create_var(name="%s.path"%self.table_name, persistable=True, type=core.VarDesc.VarType.RAW)
checkpoint_save_block = pserver_program.create_block(pre_block_idx) checkpoint_save_block = pserver_program.create_block(pre_block_idx)
checkpoint_save_block.append_op( checkpoint_save_block.append_op(
type='save', type='save',
inputs={'X': [self.table_name]}, inputs={'X': [self.table_name]},
outputs={}, outputs={},
attrs={ attrs={
'file_path': os.path.join("/tmp/pserver_ckpt/", self.table_name) 'file_path': self.table_name)
}) })
return checkpoint_save_block.idx return checkpoint_save_block.idx
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册