From 0a828fef8286c6b9cd7a5ca2345d19057762dc79 Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Sun, 10 Mar 2019 23:16:50 +0800 Subject: [PATCH] add some flags for communicator --- .../operators/distributed/communicator.cc | 54 +++++++++++++++++-- .../operators/distributed/communicator.h | 23 +------- python/paddle/fluid/__init__.py | 4 ++ 3 files changed, 55 insertions(+), 26 deletions(-) diff --git a/paddle/fluid/operators/distributed/communicator.cc b/paddle/fluid/operators/distributed/communicator.cc index a7bce26234..73b9800d43 100644 --- a/paddle/fluid/operators/distributed/communicator.cc +++ b/paddle/fluid/operators/distributed/communicator.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include "paddle/fluid/operators/distributed/communicator.h" +#include #include // NOLINT #include // NOLINT @@ -24,6 +25,13 @@ limitations under the License. */ #include "paddle/fluid/operators/distributed/parameter_send.h" #include "paddle/fluid/operators/math/selected_rows_functor.h" +DEFINE_bool(communicator_independent_recv_thread, true, + "use an independent to recv vars from parameter server"); +DEFINE_int32(communicator_send_queue_size, 20, + "queue size to recv gradient before send"); +DEFINE_int32(communicator_recv_wait_ms, 200, "wait time between each recv"); +DEFINE_int32(communicator_thread_pool_size, 5, "wait time between each recv"); + namespace paddle { namespace operators { namespace distributed { @@ -70,6 +78,38 @@ static inline void MergeVars(const std::string &var_name, std::unique_ptr Communicator::communicator_(nullptr); std::once_flag Communicator::init_flag_; +Communicator::Communicator(const RpcCtxMap &send_varname_to_ctx, + const RpcCtxMap &recv_varname_to_ctx, + Scope *recv_scope) + : send_varname_to_ctx_(send_varname_to_ctx), + recv_varname_to_ctx_(recv_varname_to_ctx), + recv_scope_(recv_scope) { + // get all send information from graph, build vars_to_send + VLOG(0) << "communicator_independent_recv_thread: " + << FLAGS_communicator_independent_recv_thread; + VLOG(0) << "communicator_send_queue_size: " + << FLAGS_communicator_send_queue_size; + VLOG(0) << "communicator_recv_wait_ms: " << FLAGS_communicator_recv_wait_ms; + VLOG(0) << "communicator_thread_pool_size: " + << FLAGS_communicator_thread_pool_size; + send_scope_.reset(new Scope()); + for (auto &iter : send_varname_to_ctx_) { + send_varname_to_queue_[iter.first] = + std::make_shared>>( + FLAGS_communicator_send_queue_size); + } + send_threadpool_.reset(new ::ThreadPool(FLAGS_communicator_thread_pool_size)); + recv_threadpool_.reset(new ::ThreadPool(FLAGS_communicator_thread_pool_size)); +} + +Communicator::~Communicator() { + VLOG(3) << "~Communicator"; + running_ = false; + if (send_thread_) send_thread_->join(); + if (recv_thread_) recv_thread_->join(); + VLOG(3) << "~Communicator done"; +} + void Communicator::SendThread() { VLOG(3) << "SendThread start!"; while (running_) { @@ -105,7 +145,9 @@ void Communicator::SendThread() { task_f.wait(); } VLOG(3) << "run send graph done"; - RecvAll(); + if (!FLAGS_communicator_independent_recv_thread) { + RecvAll(); + } } } @@ -132,8 +174,8 @@ void Communicator::RecvThread() { VLOG(3) << "RecvThread start!"; while (running_) { RecvAll(); - // TODO(qiao) need to be configuable - std::this_thread::sleep_for(std::chrono::milliseconds(200)); + std::this_thread::sleep_for( + std::chrono::milliseconds(FLAGS_communicator_recv_wait_ms)); } } @@ -157,8 +199,10 @@ void Communicator::Start() { // start send and recv thread send_thread_.reset( new std::thread(std::bind(&Communicator::SendThread, this))); - // recv_thread_.reset( - // new std::thread(std::bind(&Communicator::RecvThread, this))); + if (FLAGS_communicator_independent_recv_thread) { + recv_thread_.reset( + new std::thread(std::bind(&Communicator::RecvThread, this))); + } } } // namespace distributed diff --git a/paddle/fluid/operators/distributed/communicator.h b/paddle/fluid/operators/distributed/communicator.h index 3c98b36b74..4104cb20a3 100644 --- a/paddle/fluid/operators/distributed/communicator.h +++ b/paddle/fluid/operators/distributed/communicator.h @@ -96,28 +96,9 @@ using RpcCtxMap = std::unordered_map; class Communicator { public: Communicator(const RpcCtxMap& send_varname_to_ctx, - const RpcCtxMap& recv_varname_to_ctx, Scope* recv_scope) - : send_varname_to_ctx_(send_varname_to_ctx), - recv_varname_to_ctx_(recv_varname_to_ctx), - recv_scope_(recv_scope) { - // get all send information from graph, build vars_to_send - send_scope_.reset(new Scope()); - for (auto& iter : send_varname_to_ctx_) { - send_varname_to_queue_[iter.first] = - std::make_shared>>(10); - } - // TODO(qiao): default 5, need to config - send_threadpool_.reset(new ::ThreadPool(5)); - recv_threadpool_.reset(new ::ThreadPool(5)); - } + const RpcCtxMap& recv_varname_to_ctx, Scope* recv_scope); - ~Communicator() { - VLOG(3) << "~Communicator"; - running_ = false; - send_thread_->join(); - recv_thread_->join(); - VLOG(3) << "~Communicator done"; - } + ~Communicator(); void Start(); diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index d12f04a6ab..8af5e1c509 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -150,6 +150,10 @@ def __bootstrap__(): read_env_flags.append('rpc_get_thread_num') read_env_flags.append('rpc_prefetch_thread_num') read_env_flags.append('rpc_disable_reuse_port') + read_env_flags.append('communicator_independent_recv_thread') + read_env_flags.append('communicator_send_queue_size') + read_env_flags.append('communicator_recv_wait_ms') + read_env_flags.append('communicator_thread_pool_size') if core.is_compiled_with_brpc(): read_env_flags.append('max_body_size') #set brpc max body size -- GitLab