From 08e4970e458a068c76af8ba89c78403b45c430d0 Mon Sep 17 00:00:00 2001
From: Xin Pan
Date: Wed, 23 May 2018 01:18:09 -0700
Subject: [PATCH] follow comments

---
 paddle/fluid/operators/detail/grpc_server.cc    | 24 ++++++++++---------
 paddle/fluid/operators/detail/grpc_server.h     |  6 ++---
 .../operators/detail/sendrecvop_utils.cc        |  8 ++-----
 .../operators/detail/variable_response.cc       |  6 +++--
 paddle/fluid/platform/profiler.h                |  2 ++
 5 files changed, 24 insertions(+), 22 deletions(-)

diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc
index 51ddda6255..58faead2bd 100644
--- a/paddle/fluid/operators/detail/grpc_server.cc
+++ b/paddle/fluid/operators/detail/grpc_server.cc
@@ -19,14 +19,16 @@ limitations under the License. */
 using ::grpc::ServerAsyncResponseWriter;
 
+DEFINE_int32(rpc_server_handle_send_threads, 20,
+             "Number of threads used to handle send at rpc server.");
+DEFINE_int32(rpc_server_handle_get_threads, 20,
+             "Number of threads used to handle get at rpc server.");
+DEFINE_int32(rpc_server_handle_prefetch_threads, 1,
+             "Number of threads used to handle prefetch at rpc server.");
+
 namespace paddle {
 namespace operators {
 namespace detail {
-namespace {
-const int kNumHandleSendThreads = 20;
-const int kNumHandleGetThreads = 20;
-const int kNumHandlePrefetchThreads = 1;
-}  // namespace
 
 enum CallStatus { PROCESS = 0, FINISH };
 
 // reference:
@@ -268,17 +270,17 @@ void AsyncGRPCServer::RunSyncUpdate() {
     TryToRegisterNewPrefetchOne(i);
   }
 
-  for (int i = 0; i < kNumHandleSendThreads; ++i) {
+  for (int i = 0; i < FLAGS_rpc_server_handle_send_threads; ++i) {
     t_sends_.emplace_back(
         new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this,
                                   cq_send_.get(), "cq_send", send_register)));
   }
-  for (int i = 0; i < kNumHandleGetThreads; ++i) {
+  for (int i = 0; i < FLAGS_rpc_server_handle_get_threads; ++i) {
     t_gets_.emplace_back(
         new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this,
                                   cq_get_.get(), "cq_get", get_register)));
   }
-  for (int i = 0; i < kNumHandlePrefetchThreads; ++i) {
+  for (int i = 0; i < FLAGS_rpc_server_handle_prefetch_threads; ++i) {
     t_prefetchs_.emplace_back(new std::thread(
         std::bind(&AsyncGRPCServer::HandleRequest, this, cq_prefetch_.get(),
                   "cq_prefetch", prefetch_register)));
@@ -290,13 +292,13 @@ void AsyncGRPCServer::RunSyncUpdate() {
   condition_ready_.notify_all();
   // wait server
   server_->Wait();
-  for (int i = 0; i < kNumHandleSendThreads; ++i) {
+  for (int i = 0; i < FLAGS_rpc_server_handle_send_threads; ++i) {
     t_sends_[i]->join();
   }
-  for (int i = 0; i < kNumHandleGetThreads; ++i) {
+  for (int i = 0; i < FLAGS_rpc_server_handle_get_threads; ++i) {
     t_gets_[i]->join();
   }
-  for (int i = 0; i < kNumHandlePrefetchThreads; ++i) {
+  for (int i = 0; i < FLAGS_rpc_server_handle_prefetch_threads; ++i) {
     t_prefetchs_[i]->join();
   }
 }
diff --git a/paddle/fluid/operators/detail/grpc_server.h b/paddle/fluid/operators/detail/grpc_server.h
index 9a60ee5579..bdff9801a9 100644
--- a/paddle/fluid/operators/detail/grpc_server.h
+++ b/paddle/fluid/operators/detail/grpc_server.h
@@ -85,9 +85,9 @@ class AsyncGRPCServer final {
   void HandleRequest(::grpc::ServerCompletionQueue *cq,
                      const std::string &cq_name,
                      std::function<void(int)> TryToRegisterNewOne);
-  void TryToRegisterNewSendOne(int i);
-  void TryToRegisterNewGetOne(int i);
-  void TryToRegisterNewPrefetchOne(int i);
+  void TryToRegisterNewSendOne(int req_id);
+  void TryToRegisterNewGetOne(int req_id);
+  void TryToRegisterNewPrefetchOne(int req_id);
   void ShutdownQueue();
 
  private:
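The grpc_server changes above replace the hard-coded anonymous-namespace thread counts with gflags, so the number of send/get/prefetch handler threads can be tuned when the RPC server is launched instead of requiring a recompile. Below is a minimal, standalone sketch of the gflags mechanism the patch relies on; the demo binary name and main() are illustrative only and are not part of this patch.

    // Sketch: DEFINE_int32 registers an int flag and exposes it as
    // FLAGS_<name>; gflags parses command-line overrides such as
    //   ./rpc_server_demo --rpc_server_handle_send_threads=32
    #include <gflags/gflags.h>
    #include <iostream>

    DEFINE_int32(rpc_server_handle_send_threads, 20,
                 "Number of threads used to handle send at rpc server.");

    int main(int argc, char* argv[]) {
      gflags::ParseCommandLineFlags(&argc, &argv, /*remove_flags=*/true);
      // A server would size its send-handling thread pool from this value.
      std::cout << "send handler threads: "
                << FLAGS_rpc_server_handle_send_threads << std::endl;
      return 0;
    }
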
diff --git a/paddle/fluid/operators/detail/sendrecvop_utils.cc b/paddle/fluid/operators/detail/sendrecvop_utils.cc
index a0d3345685..0601988351 100644
--- a/paddle/fluid/operators/detail/sendrecvop_utils.cc
+++ b/paddle/fluid/operators/detail/sendrecvop_utils.cc
@@ -31,10 +31,6 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 namespace detail {
-namespace {
-const int kStartProfile = 1;
-const int kStopProfile = 2;
-}  // namespace
 
 using VarMsg = sendrecv::VariableMessage;
 
@@ -128,9 +124,9 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
   // trainer.
   if (platform::ShouldSendProfileState()) {
     if (platform::IsProfileEnabled()) {
-      request.set_profile(kStartProfile);
+      request.set_profile(platform::kEnableProfiler);
     } else {
-      request.set_profile(kStopProfile);
+      request.set_profile(platform::kDisableProfiler);
     }
   }
   if (!out_name.empty()) {
diff --git a/paddle/fluid/operators/detail/variable_response.cc b/paddle/fluid/operators/detail/variable_response.cc
index 2dfd9b2621..24cb91a3bb 100644
--- a/paddle/fluid/operators/detail/variable_response.cc
+++ b/paddle/fluid/operators/detail/variable_response.cc
@@ -458,9 +458,11 @@ int VariableResponse::Parse(Source* source) {
         if (listener_id <= 0) {
           break;
         }
-        if (profiling == 1 && !platform::IsProfileEnabled()) {
+        if (profiling == platform::kEnableProfiler &&
+            !platform::IsProfileEnabled()) {
           platform::EnableProfiler(platform::ProfilerState::kCPU);
-        } else if (profiling == 2 && platform::IsProfileEnabled()) {
+        } else if (profiling == platform::kDisableProfiler &&
+                   platform::IsProfileEnabled()) {
           // TODO(panyx0718): Should we allow to customize file dir.
           platform::DisableProfiler(
               platform::EventSortingKey::kDefault,
diff --git a/paddle/fluid/platform/profiler.h b/paddle/fluid/platform/profiler.h
index 643bb6183d..bf43925373 100644
--- a/paddle/fluid/platform/profiler.h
+++ b/paddle/fluid/platform/profiler.h
@@ -116,6 +116,8 @@ void ResetProfiler();
 void DisableProfiler(EventSortingKey sorted_key,
                      const std::string& profile_path);
 
+const int kEnableProfiler = 1;
+const int kDisableProfiler = 2;
 // Test if the profiler is currently enabled.
 bool IsProfileEnabled();
 // Whether the trainer should send profiling state to PS.
-- 
GitLab
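The last three files replace the literal profile values 1 and 2, previously duplicated as kStartProfile/kStopProfile on the serializing side and as bare magic numbers on the parsing side, with kEnableProfiler/kDisableProfiler declared once in platform/profiler.h, so both ends of the RPC agree on the wire values by construction. A small self-contained sketch of that pattern, using assumed names rather than Paddle's real message types:

    // Sketch: one shared declaration drives both the encode and decode paths,
    // so the on-the-wire value cannot silently diverge between sender and
    // receiver. The function names below are illustrative.
    #include <cassert>

    const int kEnableProfiler = 1;   // mirrors the constants added to profiler.h
    const int kDisableProfiler = 2;

    // Sender: pick the message field value from the current profiling state.
    int EncodeProfileField(bool profiling_enabled) {
      return profiling_enabled ? kEnableProfiler : kDisableProfiler;
    }

    // Receiver: interpret the decoded field with the same constants.
    bool ShouldEnableProfiler(int profile_field) {
      return profile_field == kEnableProfiler;
    }

    int main() {
      assert(ShouldEnableProfiler(EncodeProfileField(true)));
      assert(!ShouldEnableProfiler(EncodeProfileField(false)));
      return 0;
    }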