diff --git a/paddle/fluid/operators/distributed/request_handler_impl.cc b/paddle/fluid/operators/distributed/request_handler_impl.cc index 55995783c6eab10632ab2a5bca64ca856f000df1..de1a503154deb967eb4389a9f43b86c05626d966 100644 --- a/paddle/fluid/operators/distributed/request_handler_impl.cc +++ b/paddle/fluid/operators/distributed/request_handler_impl.cc @@ -41,6 +41,7 @@ bool RequestSendHandler::Handle(const std::string& varname, // Async if (!sync_mode_) { + rpc_server_->Profiler().OneStep(); try { executor_->RunPreparedContext((*grad_to_prepared_ctx_)[varname].get(), scope); diff --git a/paddle/fluid/operators/distributed/rpc_server.cc b/paddle/fluid/operators/distributed/rpc_server.cc index 83b14fa64d735d80f43bf55c798cddb2f3ea7032..406e7294c190172347d432fb155c2a81c43dda25 100644 --- a/paddle/fluid/operators/distributed/rpc_server.cc +++ b/paddle/fluid/operators/distributed/rpc_server.cc @@ -18,11 +18,44 @@ #include #include "paddle/fluid/operators/distributed/rpc_server.h" +#include "paddle/fluid/platform/profiler.h" + +DEFINE_int32(rpc_server_profile_period, 0, + "the period of listen_and_serv to do profile"); +DEFINE_string(rpc_server_profile_path, "/dev/null", + "the profile log file path"); namespace paddle { namespace operators { namespace distributed { +RPCServerProfiler::RPCServerProfiler(int profile_period, + const std::string& profile_log_path) + : profile_period_(profile_period), profile_log_path_(profile_log_path) { + step_ = 0; +} + +void RPCServerProfiler::OneStep() { + PADDLE_ENFORCE_LE(step_, profile_period_, + "step_ should not be larger then " + "profile_period_"); + if (profile_period_ <= 0) { + return; + } + + if (step_ == 0) { + auto pf_state = paddle::platform::ProfilerState::kCPU; + paddle::platform::EnableProfiler(pf_state); + } + if (step_ == profile_period_) { + paddle::platform::DisableProfiler(paddle::platform::EventSortingKey::kTotal, + profile_log_path_); + step_ = 0; + } else { + step_++; + } +} + void RPCServer::ShutDown() { LOG(INFO) << "RPCServer ShutDown "; ShutDownImpl(); diff --git a/paddle/fluid/operators/distributed/rpc_server.h b/paddle/fluid/operators/distributed/rpc_server.h index fd914d7a72e61bc9472876c433b65598ef5b1980..d813ba03e2fbec6e808f59f814a9b2f4bfbcd77b 100644 --- a/paddle/fluid/operators/distributed/rpc_server.h +++ b/paddle/fluid/operators/distributed/rpc_server.h @@ -19,16 +19,33 @@ #include // NOLINT #include #include + #include "paddle/fluid/operators/distributed/request_handler.h" +DECLARE_int32(rpc_server_profile_period); +DECLARE_string(rpc_server_profile_path); + namespace paddle { namespace operators { namespace distributed { +class RPCServerProfiler { + public: + RPCServerProfiler(int profile_period, const std::string& profile_log_path); + void OneStep(); + + private: + const int profile_period_; + std::string profile_log_path_; + int step_; +}; + class RPCServer { public: explicit RPCServer(const std::string& address, int client_num) : cur_cond_(0), + profiler_(FLAGS_rpc_server_profile_period, + FLAGS_rpc_server_profile_path), bind_address_(address), exit_flag_(false), selected_port_(0), @@ -67,6 +84,7 @@ class RPCServer { void Complete(); void ResetBarrierCounter(); + RPCServerProfiler& Profiler() { return profiler_; } protected: virtual void ShutDownImpl() = 0; @@ -79,6 +97,7 @@ class RPCServer { std::unordered_map rpc_cond_map_; std::atomic cur_cond_; std::condition_variable rpc_cond_; + RPCServerProfiler profiler_; protected: std::string bind_address_; diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index e14b148cc00f425e90b0b2256ab3462753a34f47..b1948076969e7a651179b97b107b85beea175280 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -25,10 +25,6 @@ limitations under the License. */ #include "paddle/fluid/operators/distributed/request_handler_impl.h" #include "paddle/fluid/operators/listen_and_serv_op.h" -#include "paddle/fluid/platform/profiler.h" - -DEFINE_int32(listen_and_serv_profile_period, 0, - "the period of listen_and_serv to do profile"); namespace paddle { namespace operators { @@ -108,6 +104,7 @@ void ListenAndServOp::RunSyncLoop( framework::Scope *recv_scope, const std::vector &prefetch_block_id_list, const int checkpoint_point_block_id) const { + VLOG(2) << "RunSyncLoop"; size_t num_blocks = program->Size(); auto optimize_blocks = Attr>(kOptimizeBlocks); @@ -128,17 +125,8 @@ void ListenAndServOp::RunSyncLoop( rpc_service_->ResetBarrierCounter(); - int32_t profile_step = 0; while (true) { - PADDLE_ENFORCE_LE(profile_step, FLAGS_listen_and_serv_profile_period, - "profile_step should not be larger then " - "FLAGS_listen_and_serv_profile_period"); - if (FLAGS_listen_and_serv_profile_period > 0) { - if (profile_step == 0) { - auto pf_state = paddle::platform::ProfilerState::kCPU; - paddle::platform::EnableProfiler(pf_state); - } - } + rpc_service_->Profiler().OneStep(); // Get from multiple trainers, we don't care about the order in which // the gradients arrives, just add suffix 0~n and merge the gradient. rpc_service_->SetCond(distributed::kRequestSend); @@ -180,21 +168,13 @@ void ListenAndServOp::RunSyncLoop( // reset received sparse vars to avoid reuse it in the next mini-batch dynamic_cast(request_send_handler_.get()) ->ResetSparseVarRecorder(); - if (FLAGS_listen_and_serv_profile_period > 0) { - if (profile_step == FLAGS_listen_and_serv_profile_period) { - paddle::platform::DisableProfiler( - paddle::platform::EventSortingKey::kTotal, "/dev/null"); - profile_step = 0; - } else { - profile_step++; - } - } } // while(true) } void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, framework::ProgramDesc *program, framework::Scope *recv_scope) const { + VLOG(2) << "RunAsyncLoop"; // grad name to block id std::unordered_map grad_to_block_id; std::unordered_map id_to_grad; diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index 2ab176ec272a9f250d5229ae67faf350e5eda72c..cccb4abe6cf14ac3e90ba59d970c9398b09b7db2 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -128,7 +128,8 @@ def __bootstrap__(): ] if core.is_compiled_with_dist(): read_env_flags.append('rpc_deadline') - read_env_flags.append('listen_and_serv_profile_period') + read_env_flags.append('rpc_server_profile_period') + read_env_flags.append('rpc_server_profile_path') if core.is_compiled_with_cuda(): read_env_flags += [