提交 ae12281d 编写于 作者: T tangwei12

checkpoint notify

上级 30880844
......@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/send_recv_util.h"
#include "paddle/fluid/string/printf.h"
namespace paddle {
namespace operators {
......@@ -36,12 +37,14 @@ class CheckpointNotifyOp : public framework::OperatorBase {
const platform::Place& place) const override {
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
std::string dir = Attr<std::string>("dir");
std::string lookup_table_name = Attr<std::string>("lookup_table");
detail::RPCClient* rpc_client =
detail::RPCClient::GetInstance<RPCCLIENT_T>();
for (size_t i = 0; i < epmap.size(); i++) {
VLOG(3) << "sending to " << epmap[i] << " to checkpoint notify ... ";
rpc_client->AsyncCheckpointNotify(epmap[i], dir);
VLOG(3) << "sending " << dir <<" to " << epmap[i] << " to checkpoint notify ... ";
auto serial_looku_table = string::Sprintf("%s/%s.%d", dir, lookup_table_name, i);
rpc_client->AsyncCheckpointNotify(epmap[i], serial_looku_table);
}
rpc_client->Wait();
}
......@@ -57,6 +60,8 @@ class CheckpointNotifyOpMaker : public framework::OpProtoAndCheckerMaker {
.SetDefault({"127.0.0.1:6164"});
AddAttr<std::string>(
"dir", "(string, default '') indicate the folder checkpoint will use");
AddAttr<std::string>(
"lookup_table", "(string, default '') the lookup table name");
AddComment(R"DOC(
Prefetch operator
......
......@@ -208,11 +208,14 @@ class RequestCheckpointNotify final : public RequestBase {
auto scope = request_->GetMutableLocalScope();
std::string checkpoint_notify = request_->Varname();
std::string checkpoint_dir = request_->Varname();
std::string checkpoint_dir = request_->OutVarname();
framework::Variable* invar = nullptr;
framework::Variable* outvar = nullptr;
VLOG(4) << "RequestCheckpointNotify notify: " << checkpoint_notify
<< ", dir: " << checkpoint_dir;
request_handler_->Handle(checkpoint_notify, scope, invar, &outvar,
checkpoint_dir);
Finish(reply_, &responder_);
......
......@@ -22,6 +22,7 @@
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h"
#include "paddle/fluid/operators/detail/rpc_server.h"
#include "paddle/fluid/string/printf.h"
namespace paddle {
namespace operators {
......@@ -124,6 +125,12 @@ bool RequestCheckpointHandler::Handle(const std::string& varname,
framework::Variable* invar,
framework::Variable** outvar,
const std::string& out_var_name) {
auto lt_varname = string::Sprintf("%s.path", varname);
auto *lt_var = scope->FindVar(lt_varname)->GetMutable<std::string>();
lt_var->clear();
lt_var->append(out_var_name);
VLOG(4) << "RequestCheckpointHandler update " << lt_varname << " to: " << out_var_name;
executor_->RunPreparedContext(checkpoint_prepared_ctx_.get(), scope);
return true;
}
......
......@@ -87,7 +87,7 @@ class SaveOp : public framework::OperatorBase {
if (var->IsType<framework::LoDTensor>()) {
SaveLodTensor(filename, place, var);
} else if (var->IsType<framework::SelectedRows>()) {
SaveSelectedRows(filename, place, var);
SaveSelectedRows(scope, place, var);
} else {
PADDLE_ENFORCE(
false,
......@@ -128,9 +128,17 @@ class SaveOp : public framework::OperatorBase {
fout.close();
}
void SaveSelectedRows(const std::string &filename,
void SaveSelectedRows(const framework::Scope &scope,
const platform::Place &place,
framework::Variable *var) const {
auto lt_varname = string::Sprintf("%s.path", Input("X"));
auto *lt_var = scope.FindVar(lt_varname)->GetMutable<std::string>();
PADDLE_ENFORCE(lt_var != nullptr, "Cannot find variable %s for SaveSelectedRows",
lt_varname);
std::string filename = lt_var->data();
VLOG(4) << "SaveSelectedRows get File name: " << filename;
auto &selectedRows = var->Get<framework::SelectedRows>();
// get device context from pool
......
......@@ -471,7 +471,10 @@ def save_checkpoint(executor,
trainer_id,
trainer_args=None,
main_program=None,
max_num_checkpoints=3):
max_num_checkpoints=3,
lookup_table=None,
ps_endpoint_list=None
):
"""
Save Checkpoint will save persistable LodTensor variables from main_program in checkpoint directory,
the directory named by serial number from 0 to (n -1), save_checkpoint use LRU strategy
......@@ -500,7 +503,7 @@ def save_checkpoint(executor,
if trainer_id == 0:
save_persist_vars_without_grad(executor, cur_dir, main_program)
save_pserver_vars_by_notify(executor, cur_dir, "")
save_pserver_vars_by_notify(executor, cur_dir, ps_endpoint_list, lookup_table)
_scroll_delete(checkpoint_dir, max_num_checkpoints)
......@@ -600,7 +603,7 @@ def save_persist_vars_without_grad(executor, dirname, program):
_write_success(cur_dir)
def save_pserver_vars_by_notify(executor, dirname, epmap):
def save_pserver_vars_by_notify(executor, dirname, lookup_table, ps_endpoint_list):
"""
"""
cur_dir = _get_lookuptable_dir(dirname)
......@@ -609,11 +612,12 @@ def save_pserver_vars_by_notify(executor, dirname, epmap):
checkpoint_notify_block = checkpoint_notify_program.global_block()
attrs = {}
attrs['epmap'] = None
attrs['epmap'] = ps_endpoint_list
attrs['dir'] = cur_dir
attrs['lookup_table'] = lookup_table
checkpoint_notify_block.append_op(
type='checkpoint_notify', inputs={}, output={}, attrs=attrs)
type='checkpoint_notify', inputs={}, outputs={}, attrs=attrs)
executor.run(checkpoint_notify_program)
......@@ -783,3 +787,4 @@ def get_latest_checkpoint_serial(checkpoint_dir):
if success_num > current_dir:
current_dir = success_num
return current_dir
......@@ -838,13 +838,15 @@ class DistributeTranspiler:
"""
import os
pserver_program.global_block().create_var(name="%s.path"%self.table_name, persistable=True, type=core.VarDesc.VarType.RAW)
checkpoint_save_block = pserver_program.create_block(pre_block_idx)
checkpoint_save_block.append_op(
type='save',
inputs={'X': [self.table_name]},
outputs={},
attrs={
'file_path': os.path.join("/tmp/pserver_ckpt/", self.table_name)
'file_path': self.table_name)
})
return checkpoint_save_block.idx
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册