From d93dc81c4eeaa070586ed25055933a4e6bda57e4 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Tue, 19 Jun 2018 15:14:10 +0800 Subject: [PATCH] add handle when checkpoint_notify_id = -1 --- .../operators/detail/request_handler_impl.cc | 8 ++++++-- .../operators/detail/request_handler_impl.h | 9 +++++++-- paddle/fluid/operators/listen_and_serv_op.cc | 18 ++++++++++-------- 3 files changed, 23 insertions(+), 12 deletions(-) diff --git a/paddle/fluid/operators/detail/request_handler_impl.cc b/paddle/fluid/operators/detail/request_handler_impl.cc index 87fa5842c..859f6a757 100644 --- a/paddle/fluid/operators/detail/request_handler_impl.cc +++ b/paddle/fluid/operators/detail/request_handler_impl.cc @@ -125,11 +125,15 @@ bool RequestCheckpointHandler::Handle(const std::string& varname, framework::Variable* invar, framework::Variable** outvar, const std::string& out_var_name) { + PADDLE_ENFORCE( + checkpoint_notify_id != -1, + "when checkpoint_notify_id = -1, there should be no RPC invoke."); - auto *lt_var = scope->FindVar("loopup_table_path")->GetMutable(); + auto* lt_var = scope->FindVar("loopup_table_path")->GetMutable(); lt_var->clear(); lt_var->append(out_var_name); - VLOG(4) << "RequestCheckpointHandler update loopup_table_path to: " << out_var_name; + VLOG(4) << "RequestCheckpointHandler update loopup_table_path to: " + << out_var_name; executor_->RunPreparedContext(checkpoint_prepared_ctx_.get(), scope); return true; } diff --git a/paddle/fluid/operators/detail/request_handler_impl.h b/paddle/fluid/operators/detail/request_handler_impl.h index 643eae4d3..b7cebf1a6 100644 --- a/paddle/fluid/operators/detail/request_handler_impl.h +++ b/paddle/fluid/operators/detail/request_handler_impl.h @@ -68,12 +68,17 @@ class RequestPrefetchHandler final : public RequestHandler { class RequestCheckpointHandler final : public RequestHandler { public: - explicit RequestCheckpointHandler(bool sync_mode) - : RequestHandler(sync_mode) {} + explicit RequestCheckpointHandler(bool sync_mode, int checkpoint_notify_id) + : RequestHandler(sync_mode) { + this.checkpoint_notify_id = checkpoint_notify_id; + } virtual ~RequestCheckpointHandler() {} bool Handle(const std::string& varname, framework::Scope* scope, framework::Variable* var, framework::Variable** outvar, const std::string& out_var_name = "") override; + + private: + int checkpoint_notify_id; }; } // namespace detail diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index 78b8c96f4..477cb90ef 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -247,9 +247,11 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, PADDLE_ENFORCE(!rpc_service_); std::string endpoint = Attr("endpoint"); + int checkpoint_point_block_id = Attr(kCheckpointBlockId); LOG(INFO) << "sync_mode:" << sync_mode << ", fan_in:" << fan_in - << ", end_point:" << endpoint; + << ", end_point:" << endpoint + << ", CheckpointNotify Id: " << checkpoint_notify_id; rpc_service_.reset(new RPCSERVER_T(endpoint, fan_in)); @@ -258,7 +260,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, request_prefetch_handler_.reset( new detail::RequestPrefetchHandler(sync_mode)); request_checkpoint_handler_.reset( - new detail::RequestCheckpointHandler(sync_mode)); + new detail::RequestCheckpointHandler(sync_mode, checkpoint_notify_id)); rpc_service_->RegisterRPC(detail::kRequestSend, request_send_handler_.get()); rpc_service_->RegisterRPC(detail::kRequestGet, request_get_handler_.get()); @@ -267,6 +269,12 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, rpc_service_->RegisterRPC(detail::kRequestCheckpoint, request_checkpoint_handler_.get()); + std::shared_ptr ckpt_pre_context = nullptr; + if (checkpoint_notify_id != -1) { + auto ctx = executor.Prepare(*program, checkpoint_point_block_id); + ckpt_pre_context = std::move(ctx); + } + auto *optimize_block = Attr(kOptimizeBlock); auto *program = optimize_block->Program(); framework::Executor executor(dev_place); @@ -301,12 +309,6 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, prefetch_var_name_to_prepared_ctx[prefetch_var_name] = prefetch_prepared[i]; } - int checkpoint_point_block_id = Attr(kCheckpointBlockId); - auto ctx = executor.Prepare(*program, checkpoint_point_block_id); - - std::shared_ptr ckpt_pre_context = - std::move(ctx); - auto f = std::bind(FillRequestCtx, std::placeholders::_1, &recv_scope, &dev_ctx, &executor, program, &prefetch_var_name_to_prepared_ctx, -- GitLab