提交 d69c8207 编写于 作者: Q Qiao Longfei

Merge branch 'add-flag-to-control-rpc-thread-num' of...

Merge branch 'add-flag-to-control-rpc-thread-num' of ssh://github.com/jacquesqiao/Paddle into cpu-for-1.1-merge
...@@ -27,6 +27,10 @@ limitations under the License. */ ...@@ -27,6 +27,10 @@ limitations under the License. */
#include "paddle/fluid/operators/distributed/request_handler_impl.h" #include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/listen_and_serv_op.h" #include "paddle/fluid/operators/listen_and_serv_op.h"
DEFINE_int32(rpc_send_thread_num, 5, "number of threads for rpc send");
DEFINE_int32(rpc_get_thread_num, 5, "number of threads for rpc get");
DEFINE_int32(rpc_prefetch_thread_num, 5, "number of threads for rpc prefetch");
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -332,11 +336,14 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -332,11 +336,14 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
sync_mode, checkpoint_block_id)); sync_mode, checkpoint_block_id));
rpc_service_->RegisterRPC(distributed::kRequestSend, rpc_service_->RegisterRPC(distributed::kRequestSend,
request_send_handler_.get()); request_send_handler_.get(),
FLAGS_rpc_send_thread_num);
rpc_service_->RegisterRPC(distributed::kRequestGet, rpc_service_->RegisterRPC(distributed::kRequestGet,
request_get_handler_.get()); request_get_handler_.get(),
FLAGS_rpc_get_thread_num);
rpc_service_->RegisterRPC(distributed::kRequestPrefetch, rpc_service_->RegisterRPC(distributed::kRequestPrefetch,
request_prefetch_handler_.get()); request_prefetch_handler_.get(),
FLAGS_rpc_prefetch_thread_num);
rpc_service_->RegisterRPC(distributed::kRequestCheckpoint, rpc_service_->RegisterRPC(distributed::kRequestCheckpoint,
request_checkpoint_handler_.get()); request_checkpoint_handler_.get());
......
...@@ -121,6 +121,9 @@ def __bootstrap__(): ...@@ -121,6 +121,9 @@ def __bootstrap__():
read_env_flags.append('rpc_server_profile_period') read_env_flags.append('rpc_server_profile_period')
read_env_flags.append('rpc_server_profile_path') read_env_flags.append('rpc_server_profile_path')
read_env_flags.append('enable_rpc_profiler') read_env_flags.append('enable_rpc_profiler')
read_env_flags.append('rpc_send_thread_num')
read_env_flags.append('rpc_get_thread_num')
read_env_flags.append('rpc_prefetch_thread_num')
if core.is_compiled_with_cuda(): if core.is_compiled_with_cuda():
read_env_flags += [ read_env_flags += [
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册