提交 d93dc81c 编写于 作者: T tangwei12

add handle when checkpoint_notify_id = -1

上级 1571c25a
...@@ -125,11 +125,15 @@ bool RequestCheckpointHandler::Handle(const std::string& varname, ...@@ -125,11 +125,15 @@ bool RequestCheckpointHandler::Handle(const std::string& varname,
framework::Variable* invar, framework::Variable* invar,
framework::Variable** outvar, framework::Variable** outvar,
const std::string& out_var_name) { const std::string& out_var_name) {
PADDLE_ENFORCE(
checkpoint_notify_id != -1,
"when checkpoint_notify_id = -1, there should be no RPC invoke.");
auto *lt_var = scope->FindVar("loopup_table_path")->GetMutable<std::string>(); auto* lt_var = scope->FindVar("loopup_table_path")->GetMutable<std::string>();
lt_var->clear(); lt_var->clear();
lt_var->append(out_var_name); lt_var->append(out_var_name);
VLOG(4) << "RequestCheckpointHandler update loopup_table_path to: " << out_var_name; VLOG(4) << "RequestCheckpointHandler update loopup_table_path to: "
<< out_var_name;
executor_->RunPreparedContext(checkpoint_prepared_ctx_.get(), scope); executor_->RunPreparedContext(checkpoint_prepared_ctx_.get(), scope);
return true; return true;
} }
......
...@@ -68,12 +68,17 @@ class RequestPrefetchHandler final : public RequestHandler { ...@@ -68,12 +68,17 @@ class RequestPrefetchHandler final : public RequestHandler {
class RequestCheckpointHandler final : public RequestHandler { class RequestCheckpointHandler final : public RequestHandler {
public: public:
explicit RequestCheckpointHandler(bool sync_mode) explicit RequestCheckpointHandler(bool sync_mode, int checkpoint_notify_id)
: RequestHandler(sync_mode) {} : RequestHandler(sync_mode) {
this.checkpoint_notify_id = checkpoint_notify_id;
}
virtual ~RequestCheckpointHandler() {} virtual ~RequestCheckpointHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope, bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar, framework::Variable* var, framework::Variable** outvar,
const std::string& out_var_name = "") override; const std::string& out_var_name = "") override;
private:
int checkpoint_notify_id;
}; };
} // namespace detail } // namespace detail
......
...@@ -247,9 +247,11 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -247,9 +247,11 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
PADDLE_ENFORCE(!rpc_service_); PADDLE_ENFORCE(!rpc_service_);
std::string endpoint = Attr<std::string>("endpoint"); std::string endpoint = Attr<std::string>("endpoint");
int checkpoint_point_block_id = Attr<int>(kCheckpointBlockId);
LOG(INFO) << "sync_mode:" << sync_mode << ", fan_in:" << fan_in LOG(INFO) << "sync_mode:" << sync_mode << ", fan_in:" << fan_in
<< ", end_point:" << endpoint; << ", end_point:" << endpoint
<< ", CheckpointNotify Id: " << checkpoint_notify_id;
rpc_service_.reset(new RPCSERVER_T(endpoint, fan_in)); rpc_service_.reset(new RPCSERVER_T(endpoint, fan_in));
...@@ -258,7 +260,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -258,7 +260,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
request_prefetch_handler_.reset( request_prefetch_handler_.reset(
new detail::RequestPrefetchHandler(sync_mode)); new detail::RequestPrefetchHandler(sync_mode));
request_checkpoint_handler_.reset( request_checkpoint_handler_.reset(
new detail::RequestCheckpointHandler(sync_mode)); new detail::RequestCheckpointHandler(sync_mode, checkpoint_notify_id));
rpc_service_->RegisterRPC(detail::kRequestSend, request_send_handler_.get()); rpc_service_->RegisterRPC(detail::kRequestSend, request_send_handler_.get());
rpc_service_->RegisterRPC(detail::kRequestGet, request_get_handler_.get()); rpc_service_->RegisterRPC(detail::kRequestGet, request_get_handler_.get());
...@@ -267,6 +269,12 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -267,6 +269,12 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
rpc_service_->RegisterRPC(detail::kRequestCheckpoint, rpc_service_->RegisterRPC(detail::kRequestCheckpoint,
request_checkpoint_handler_.get()); request_checkpoint_handler_.get());
std::shared_ptr<framework::ExecutorPrepareContext> ckpt_pre_context = nullptr;
if (checkpoint_notify_id != -1) {
auto ctx = executor.Prepare(*program, checkpoint_point_block_id);
ckpt_pre_context = std::move(ctx);
}
auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock); auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock);
auto *program = optimize_block->Program(); auto *program = optimize_block->Program();
framework::Executor executor(dev_place); framework::Executor executor(dev_place);
...@@ -301,12 +309,6 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -301,12 +309,6 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
prefetch_var_name_to_prepared_ctx[prefetch_var_name] = prefetch_prepared[i]; prefetch_var_name_to_prepared_ctx[prefetch_var_name] = prefetch_prepared[i];
} }
int checkpoint_point_block_id = Attr<int>(kCheckpointBlockId);
auto ctx = executor.Prepare(*program, checkpoint_point_block_id);
std::shared_ptr<framework::ExecutorPrepareContext> ckpt_pre_context =
std::move(ctx);
auto f = auto f =
std::bind(FillRequestCtx, std::placeholders::_1, &recv_scope, &dev_ctx, std::bind(FillRequestCtx, std::placeholders::_1, &recv_scope, &dev_ctx,
&executor, program, &prefetch_var_name_to_prepared_ctx, &executor, program, &prefetch_var_name_to_prepared_ctx,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册