提交 dc847f12 编写于 作者: T tangwei12

Bug fix and code optimization

上级 fb7e4791
...@@ -30,7 +30,7 @@ namespace distributed { ...@@ -30,7 +30,7 @@ namespace distributed {
// Define LOOKUP_TABLE_PATH so that checkpoint notify can save lookup table variables // Define LOOKUP_TABLE_PATH so that checkpoint notify can save lookup table variables
// to the specified directory. // to the specified directory.
constexpr char LOOKUP_TABLE_PATH[] = "lookup_table_path"; constexpr char LOOKUP_TABLE_PATH[] = "kLookupTablePath";
bool RequestSendHandler::Handle(const std::string& varname, bool RequestSendHandler::Handle(const std::string& varname,
framework::Scope* scope, framework::Scope* scope,
...@@ -136,7 +136,7 @@ bool RequestCheckpointHandler::Handle(const std::string& varname, ...@@ -136,7 +136,7 @@ bool RequestCheckpointHandler::Handle(const std::string& varname,
auto* lt_var = scope->FindVar(LOOKUP_TABLE_PATH)->GetMutable<std::string>(); auto* lt_var = scope->FindVar(LOOKUP_TABLE_PATH)->GetMutable<std::string>();
lt_var->clear(); lt_var->clear();
lt_var->append(out_var_name); lt_var->append(out_var_name);
VLOG(4) << "RequestCheckpointHandler update var lookup_table_path to: " VLOG(4) << "RequestCheckpointHandler update var kLookupTablePath to: "
<< out_var_name; << out_var_name;
executor_->RunPreparedContext(checkpoint_prepared_ctx_.get(), scope); executor_->RunPreparedContext(checkpoint_prepared_ctx_.get(), scope);
return true; return true;
......
...@@ -206,7 +206,7 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, ...@@ -206,7 +206,7 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
VLOG(3) << "RunAsyncLoop into while"; VLOG(3) << "RunAsyncLoop into while";
while (true) { while (true) {
if (rpc_service_->IsExit()) { if (rpc_service_->IsExit()) {
LOG(INFO) << "get exit!rpc_processor break!"; VLOG(4) << "get exit!rpc_processor break!";
break; break;
} }
...@@ -245,11 +245,11 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -245,11 +245,11 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
PADDLE_ENFORCE(!rpc_service_); PADDLE_ENFORCE(!rpc_service_);
std::string endpoint = Attr<std::string>("endpoint"); std::string endpoint = Attr<std::string>("endpoint");
int checkpoint_notify_block_id = Attr<int>(kCheckpointBlockId); int checkpoint_block_id = Attr<int>(kCheckpointBlockId);
LOG(INFO) << "sync_mode:" << sync_mode << ", fan_in:" << fan_in VLOG(4) << "sync_mode:" << sync_mode << ", fan_in:" << fan_in
<< ", end_point:" << endpoint << ", end_point:" << endpoint
<< ", CheckpointNotify Id: " << checkpoint_notify_block_id; << ", checkpoint_block_id: " << checkpoint_block_id;
rpc_service_.reset(new RPCSERVER_T(endpoint, fan_in)); rpc_service_.reset(new RPCSERVER_T(endpoint, fan_in));
...@@ -258,7 +258,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -258,7 +258,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
request_prefetch_handler_.reset( request_prefetch_handler_.reset(
new distributed::RequestPrefetchHandler(sync_mode)); new distributed::RequestPrefetchHandler(sync_mode));
request_checkpoint_handler_.reset(new distributed::RequestCheckpointHandler( request_checkpoint_handler_.reset(new distributed::RequestCheckpointHandler(
sync_mode, checkpoint_notify_block_id)); sync_mode, checkpoint_block_id));
rpc_service_->RegisterRPC(distributed::kRequestSend, rpc_service_->RegisterRPC(distributed::kRequestSend,
request_send_handler_.get()); request_send_handler_.get());
...@@ -277,8 +277,9 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -277,8 +277,9 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
framework::Executor executor(dev_place); framework::Executor executor(dev_place);
std::shared_ptr<framework::ExecutorPrepareContext> ckpt_pre_context = nullptr; std::shared_ptr<framework::ExecutorPrepareContext> ckpt_pre_context = nullptr;
if (checkpoint_notify_block_id != -1) { if (checkpoint_block_id != -1) {
auto ctx = executor.Prepare(*program, checkpoint_notify_block_id); auto ctx = executor.Prepare(*program, checkpoint_block_id);
// see: https://stackoverflow.com/a/14856553
ckpt_pre_context = std::move(ctx); ckpt_pre_context = std::move(ctx);
} }
...@@ -335,7 +336,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -335,7 +336,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
SavePort(); SavePort();
if (sync_mode) { if (sync_mode) {
RunSyncLoop(&executor, program, &recv_scope, prefetch_block_id_list, RunSyncLoop(&executor, program, &recv_scope, prefetch_block_id_list,
checkpoint_notify_block_id); checkpoint_block_id);
} else { } else {
RunAsyncLoop(&executor, program); RunAsyncLoop(&executor, program);
} }
......
...@@ -31,7 +31,7 @@ namespace operators { ...@@ -31,7 +31,7 @@ namespace operators {
// Define LOOKUP_TABLE_PATH so that checkpoint notify can save lookup table variables // Define LOOKUP_TABLE_PATH so that checkpoint notify can save lookup table variables
// to the specified directory. // to the specified directory.
constexpr char LOOKUP_TABLE_PATH[] = "lookup_table_path"; constexpr char LOOKUP_TABLE_PATH[] = "kLookupTablePath";
// TODO(yuyang18): If the functions below are needed by other files, move them // TODO(yuyang18): If the functions below are needed by other files, move them
// to paddle::filesystem namespace. // to paddle::filesystem namespace.
...@@ -138,7 +138,7 @@ class SaveOp : public framework::OperatorBase { ...@@ -138,7 +138,7 @@ class SaveOp : public framework::OperatorBase {
auto *lt_var = scope.FindVar(LOOKUP_TABLE_PATH)->GetMutable<std::string>(); auto *lt_var = scope.FindVar(LOOKUP_TABLE_PATH)->GetMutable<std::string>();
PADDLE_ENFORCE( PADDLE_ENFORCE(
lt_var != nullptr, lt_var != nullptr,
"Can not find variable lookup_table_path for SaveSelectedRows"); "Can not find variable kLookupTablePath for SaveSelectedRows");
std::string filename = lt_var->data(); std::string filename = lt_var->data();
VLOG(4) << "SaveSelectedRows get File name: " << filename; VLOG(4) << "SaveSelectedRows get File name: " << filename;
......
...@@ -920,19 +920,17 @@ class DistributeTranspiler(object): ...@@ -920,19 +920,17 @@ class DistributeTranspiler(object):
import os import os
pserver_program.global_block().create_var( pserver_program.global_block().create_var(
name="lookup_table_path", name="kLookupTablePath",
persistable=True, persistable=True,
type=core.VarDesc.VarType.RAW) type=core.VarDesc.VarType.RAW)
checkpoint_save_block = pserver_program.create_block(pre_block_idx) checkpoint_save_block = pserver_program.create_block(pre_block_idx)
# this 'file_path' is not used when saving the lookup table variable
checkpoint_save_block.append_op( checkpoint_save_block.append_op(
type='save', type='save',
inputs={'X': [self.table_name]}, inputs={'X': [self.table_name]},
outputs={}, outputs={},
attrs={ attrs={'file_path': ""})
'file_path':
"this 'file_path' do not be used in save lookup table variable"
})
return checkpoint_save_block.idx return checkpoint_save_block.idx
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册