提交 08e4970e 编写于 作者: X Xin Pan

follow comments

上级 a848303e
...@@ -19,14 +19,16 @@ limitations under the License. */ ...@@ -19,14 +19,16 @@ limitations under the License. */
using ::grpc::ServerAsyncResponseWriter; using ::grpc::ServerAsyncResponseWriter;
DEFINE_int32(rpc_server_handle_send_threads, 20,
"Number of threads used to handle send at rpc server.");
DEFINE_int32(rpc_server_handle_get_threads, 20,
"Number of threads used to handle get at rpc server.");
DEFINE_int32(rpc_server_handle_prefetch_threads, 1,
"Number of threads used to handle prefetch at rpc server.");
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace detail {
namespace {
const int kNumHandleSendThreads = 20;
const int kNumHandleGetThreads = 20;
const int kNumHandlePrefetchThreads = 1;
} // namespace
enum CallStatus { PROCESS = 0, FINISH }; enum CallStatus { PROCESS = 0, FINISH };
// reference: // reference:
...@@ -268,17 +270,17 @@ void AsyncGRPCServer::RunSyncUpdate() { ...@@ -268,17 +270,17 @@ void AsyncGRPCServer::RunSyncUpdate() {
TryToRegisterNewPrefetchOne(i); TryToRegisterNewPrefetchOne(i);
} }
for (int i = 0; i < kNumHandleSendThreads; ++i) { for (int i = 0; i < FLAGS_rpc_server_handle_send_threads; ++i) {
t_sends_.emplace_back( t_sends_.emplace_back(
new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this,
cq_send_.get(), "cq_send", send_register))); cq_send_.get(), "cq_send", send_register)));
} }
for (int i = 0; i < kNumHandleGetThreads; ++i) { for (int i = 0; i < FLAGS_rpc_server_handle_get_threads; ++i) {
t_gets_.emplace_back( t_gets_.emplace_back(
new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this,
cq_get_.get(), "cq_get", get_register))); cq_get_.get(), "cq_get", get_register)));
} }
for (int i = 0; i < kNumHandlePrefetchThreads; ++i) { for (int i = 0; i < FLAGS_rpc_server_handle_prefetch_threads; ++i) {
t_prefetchs_.emplace_back(new std::thread( t_prefetchs_.emplace_back(new std::thread(
std::bind(&AsyncGRPCServer::HandleRequest, this, cq_prefetch_.get(), std::bind(&AsyncGRPCServer::HandleRequest, this, cq_prefetch_.get(),
"cq_prefetch", prefetch_register))); "cq_prefetch", prefetch_register)));
...@@ -290,13 +292,13 @@ void AsyncGRPCServer::RunSyncUpdate() { ...@@ -290,13 +292,13 @@ void AsyncGRPCServer::RunSyncUpdate() {
condition_ready_.notify_all(); condition_ready_.notify_all();
// wait server // wait server
server_->Wait(); server_->Wait();
for (int i = 0; i < kNumHandleSendThreads; ++i) { for (int i = 0; i < FLAGS_rpc_server_handle_send_threads; ++i) {
t_sends_[i]->join(); t_sends_[i]->join();
} }
for (int i = 0; i < kNumHandleGetThreads; ++i) { for (int i = 0; i < FLAGS_rpc_server_handle_get_threads; ++i) {
t_gets_[i]->join(); t_gets_[i]->join();
} }
for (int i = 0; i < kNumHandlePrefetchThreads; ++i) { for (int i = 0; i < FLAGS_rpc_server_handle_prefetch_threads; ++i) {
t_prefetchs_[i]->join(); t_prefetchs_[i]->join();
} }
} }
......
...@@ -85,9 +85,9 @@ class AsyncGRPCServer final { ...@@ -85,9 +85,9 @@ class AsyncGRPCServer final {
void HandleRequest(::grpc::ServerCompletionQueue *cq, void HandleRequest(::grpc::ServerCompletionQueue *cq,
const std::string &cq_name, const std::string &cq_name,
std::function<void(int)> TryToRegisterNewOne); std::function<void(int)> TryToRegisterNewOne);
void TryToRegisterNewSendOne(int i); void TryToRegisterNewSendOne(int req_id);
void TryToRegisterNewGetOne(int i); void TryToRegisterNewGetOne(int req_id);
void TryToRegisterNewPrefetchOne(int i); void TryToRegisterNewPrefetchOne(int req_id);
void ShutdownQueue(); void ShutdownQueue();
private: private:
......
...@@ -31,10 +31,6 @@ limitations under the License. */ ...@@ -31,10 +31,6 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace detail {
namespace {
const int kStartProfile = 1;
const int kStopProfile = 2;
} // namespace
using VarMsg = sendrecv::VariableMessage; using VarMsg = sendrecv::VariableMessage;
...@@ -128,9 +124,9 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, ...@@ -128,9 +124,9 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
// trainer. // trainer.
if (platform::ShouldSendProfileState()) { if (platform::ShouldSendProfileState()) {
if (platform::IsProfileEnabled()) { if (platform::IsProfileEnabled()) {
request.set_profile(kStartProfile); request.set_profile(platform::kEnableProfiler);
} else { } else {
request.set_profile(kStopProfile); request.set_profile(platform::kDisableProfiler);
} }
} }
if (!out_name.empty()) { if (!out_name.empty()) {
......
...@@ -458,9 +458,11 @@ int VariableResponse::Parse(Source* source) { ...@@ -458,9 +458,11 @@ int VariableResponse::Parse(Source* source) {
if (listener_id <= 0) { if (listener_id <= 0) {
break; break;
} }
if (profiling == 1 && !platform::IsProfileEnabled()) { if (profiling == platform::kEnableProfiler &&
!platform::IsProfileEnabled()) {
platform::EnableProfiler(platform::ProfilerState::kCPU); platform::EnableProfiler(platform::ProfilerState::kCPU);
} else if (profiling == 2 && platform::IsProfileEnabled()) { } else if (profiling == platform::kDisableProfiler &&
platform::IsProfileEnabled()) {
// TODO(panyx0718): Should we allow to customize file dir. // TODO(panyx0718): Should we allow to customize file dir.
platform::DisableProfiler( platform::DisableProfiler(
platform::EventSortingKey::kDefault, platform::EventSortingKey::kDefault,
......
...@@ -116,6 +116,8 @@ void ResetProfiler(); ...@@ -116,6 +116,8 @@ void ResetProfiler();
void DisableProfiler(EventSortingKey sorted_key, void DisableProfiler(EventSortingKey sorted_key,
const std::string& profile_path); const std::string& profile_path);
const int kEnableProfiler = 1;
const int kDisableProfiler = 2;
// Test if the profiler is currently enabled. // Test if the profiler is currently enabled.
bool IsProfileEnabled(); bool IsProfileEnabled();
// Whether the trainer should send profiling state to PS. // Whether the trainer should send profiling state to PS.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册