diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc index 51ddda6255b8d0a95ed44d213235fe5fb1a0e1ce..58faead2bdf9a89749e08207d964836bbf5cb68e 100644 --- a/paddle/fluid/operators/detail/grpc_server.cc +++ b/paddle/fluid/operators/detail/grpc_server.cc @@ -19,14 +19,16 @@ limitations under the License. */ using ::grpc::ServerAsyncResponseWriter; +DEFINE_int32(rpc_server_handle_send_threads, 20, + "Number of threads used to handle send at rpc server."); +DEFINE_int32(rpc_server_handle_get_threads, 20, + "Number of threads used to handle get at rpc server."); +DEFINE_int32(rpc_server_handle_prefetch_threads, 1, + "Number of threads used to handle prefetch at rpc server."); + namespace paddle { namespace operators { namespace detail { -namespace { -const int kNumHandleSendThreads = 20; -const int kNumHandleGetThreads = 20; -const int kNumHandlePrefetchThreads = 1; -} // namespace enum CallStatus { PROCESS = 0, FINISH }; // reference: @@ -268,17 +270,17 @@ void AsyncGRPCServer::RunSyncUpdate() { TryToRegisterNewPrefetchOne(i); } - for (int i = 0; i < kNumHandleSendThreads; ++i) { + for (int i = 0; i < FLAGS_rpc_server_handle_send_threads; ++i) { t_sends_.emplace_back( new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, cq_send_.get(), "cq_send", send_register))); } - for (int i = 0; i < kNumHandleGetThreads; ++i) { + for (int i = 0; i < FLAGS_rpc_server_handle_get_threads; ++i) { t_gets_.emplace_back( new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, cq_get_.get(), "cq_get", get_register))); } - for (int i = 0; i < kNumHandlePrefetchThreads; ++i) { + for (int i = 0; i < FLAGS_rpc_server_handle_prefetch_threads; ++i) { t_prefetchs_.emplace_back(new std::thread( std::bind(&AsyncGRPCServer::HandleRequest, this, cq_prefetch_.get(), "cq_prefetch", prefetch_register))); @@ -290,13 +292,13 @@ void AsyncGRPCServer::RunSyncUpdate() { condition_ready_.notify_all(); // wait server server_->Wait(); - for (int i = 0; i < kNumHandleSendThreads; ++i) { + for (int i = 0; i < FLAGS_rpc_server_handle_send_threads; ++i) { t_sends_[i]->join(); } - for (int i = 0; i < kNumHandleGetThreads; ++i) { + for (int i = 0; i < FLAGS_rpc_server_handle_get_threads; ++i) { t_gets_[i]->join(); } - for (int i = 0; i < kNumHandlePrefetchThreads; ++i) { + for (int i = 0; i < FLAGS_rpc_server_handle_prefetch_threads; ++i) { t_prefetchs_[i]->join(); } } diff --git a/paddle/fluid/operators/detail/grpc_server.h b/paddle/fluid/operators/detail/grpc_server.h index 9a60ee5579a6a50d913123d061dc43625ccc6013..bdff9801a928699f8391bfb68c1c7bd2d75aa642 100644 --- a/paddle/fluid/operators/detail/grpc_server.h +++ b/paddle/fluid/operators/detail/grpc_server.h @@ -85,9 +85,9 @@ class AsyncGRPCServer final { void HandleRequest(::grpc::ServerCompletionQueue *cq, const std::string &cq_name, std::function TryToRegisterNewOne); - void TryToRegisterNewSendOne(int i); - void TryToRegisterNewGetOne(int i); - void TryToRegisterNewPrefetchOne(int i); + void TryToRegisterNewSendOne(int req_id); + void TryToRegisterNewGetOne(int req_id); + void TryToRegisterNewPrefetchOne(int req_id); void ShutdownQueue(); private: diff --git a/paddle/fluid/operators/detail/sendrecvop_utils.cc b/paddle/fluid/operators/detail/sendrecvop_utils.cc index a0d3345685210fee92e1871442d82a4a103b6f2a..0601988351789ed496bd132f3c7616a068b65b3c 100644 --- a/paddle/fluid/operators/detail/sendrecvop_utils.cc +++ b/paddle/fluid/operators/detail/sendrecvop_utils.cc @@ -31,10 +31,6 @@ limitations under the License. */ namespace paddle { namespace operators { namespace detail { -namespace { -const int kStartProfile = 1; -const int kStopProfile = 2; -} // namespace using VarMsg = sendrecv::VariableMessage; @@ -128,9 +124,9 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, // trainer. if (platform::ShouldSendProfileState()) { if (platform::IsProfileEnabled()) { - request.set_profile(kStartProfile); + request.set_profile(platform::kEnableProfiler); } else { - request.set_profile(kStopProfile); + request.set_profile(platform::kDisableProfiler); } } if (!out_name.empty()) { diff --git a/paddle/fluid/operators/detail/variable_response.cc b/paddle/fluid/operators/detail/variable_response.cc index 2dfd9b2621e1d21e4bb02d9a2d50304638362e3e..24cb91a3bb820a0e5d51aaa49154434919080f69 100644 --- a/paddle/fluid/operators/detail/variable_response.cc +++ b/paddle/fluid/operators/detail/variable_response.cc @@ -458,9 +458,11 @@ int VariableResponse::Parse(Source* source) { if (listener_id <= 0) { break; } - if (profiling == 1 && !platform::IsProfileEnabled()) { + if (profiling == platform::kEnableProfiler && + !platform::IsProfileEnabled()) { platform::EnableProfiler(platform::ProfilerState::kCPU); - } else if (profiling == 2 && platform::IsProfileEnabled()) { + } else if (profiling == platform::kDisableProfiler && + platform::IsProfileEnabled()) { // TODO(panyx0718): Should we allow to customize file dir. platform::DisableProfiler( platform::EventSortingKey::kDefault, diff --git a/paddle/fluid/platform/profiler.h b/paddle/fluid/platform/profiler.h index 643bb6183d144ec11a4890d9ea1ca970acb08b4c..bf43925373a12cd9ff2155d68c42d0266ba4df60 100644 --- a/paddle/fluid/platform/profiler.h +++ b/paddle/fluid/platform/profiler.h @@ -116,6 +116,8 @@ void ResetProfiler(); void DisableProfiler(EventSortingKey sorted_key, const std::string& profile_path); +const int kEnableProfiler = 1; +const int kDisableProfiler = 2; // Test if the profiler is currently enabled. bool IsProfileEnabled(); // Whether the trainer should send profiling state to PS.