From 50601501e52ce6bd0b34864dc2410e1a6083a3cd Mon Sep 17 00:00:00 2001
From: Qiao Longfei
Date: Mon, 4 Mar 2019 15:01:22 +0800
Subject: [PATCH] improve communicator

---
 .../operators/distributed/CMakeLists.txt      |  2 +-
 .../operators/distributed/communicator.cc     | 69 ++++++++++++-------
 .../operators/distributed/communicator.h      | 16 ++++-
 .../fluid/operators/distributed/rpc_common.h  |  8 +++
 4 files changed, 70 insertions(+), 25 deletions(-)

diff --git a/paddle/fluid/operators/distributed/CMakeLists.txt b/paddle/fluid/operators/distributed/CMakeLists.txt
index 22f44c4217..1301467fa7 100644
--- a/paddle/fluid/operators/distributed/CMakeLists.txt
+++ b/paddle/fluid/operators/distributed/CMakeLists.txt
@@ -54,7 +54,7 @@ cc_test(varhandle_test SRCS varhandle_test.cc DEPS profiler scope)
 cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_rpc memory)
 cc_library(parameter_send SRCS parameter_send.cc DEPS sendrecvop_rpc memory)
 cc_library(parameter_recv SRCS parameter_recv.cc DEPS sendrecvop_rpc memory)
-cc_library(communicator SRCS communicator.cc DEPS scope selected_rows tensor variable_helper selected_rows_functor)
+cc_library(communicator SRCS communicator.cc DEPS scope selected_rows tensor variable_helper selected_rows_functor simple_threadpool)
 
 if(WITH_GPU)
     cc_test(collective_server_test SRCS collective_server_test.cc DEPS sendrecvop_rpc executor ${RPC_DEPS}
diff --git a/paddle/fluid/operators/distributed/communicator.cc b/paddle/fluid/operators/distributed/communicator.cc
index bc0a57f344..403fcf4b16 100644
--- a/paddle/fluid/operators/distributed/communicator.cc
+++ b/paddle/fluid/operators/distributed/communicator.cc
@@ -25,9 +25,9 @@ namespace paddle {
 namespace operators {
 namespace distributed {
 
-static void MergeVars(const std::string &var_name,
-                      const std::vector<std::shared_ptr<Variable>> &vars,
-                      Scope *scope) {
+static inline void MergeVars(const std::string &var_name,
+                             const std::vector<std::shared_ptr<Variable>> &vars,
+                             Scope *scope) {
   PADDLE_ENFORCE(!vars.empty(), "should have value to merge!");
   auto cpu_place = platform::CPUPlace();
   auto &var0 = vars[0];
@@ -62,31 +62,53 @@ static void MergeVars(const std::string &var_name,
 }
 
 void Communicator::SendThread() {
-  for (auto &iter : send_varname_to_queue_) {
-    auto &var_name = iter.first;
-    VLOG(3) << "merge var " << var_name << " and send";
-    auto &var_queue = iter.second;
-    std::vector<std::shared_ptr<Variable>> vars;
-    const size_t max_merge_var_num = 20;
-    size_t merged_var_num = 0;
-    while (var_queue->Size() > 0 && merged_var_num < max_merge_var_num) {
-      vars.push_back(var_queue->Pop());
-      merged_var_num++;
+  while (running_) {
+    std::vector<std::future<void>> task_futures;
+    task_futures.reserve(send_varname_to_ctx_.size());
+    for (auto &iter : send_varname_to_queue_) {
+      auto send_task = [this, &iter] {
+        auto &var_name = iter.first;
+        VLOG(3) << "merge var " << var_name << " and send";
+        auto &var_queue = iter.second;
+        std::vector<std::shared_ptr<Variable>> vars;
+        const size_t max_merge_var_num = 20;
+        size_t merged_var_num = 0;
+        while (var_queue->Size() > 0 && merged_var_num < max_merge_var_num) {
+          vars.push_back(var_queue->Pop());
+          merged_var_num++;
+        }
+        MergeVars(var_name, vars, send_scope_.get());
+        auto send_functor = distributed::ParameterSend<float>();
+        auto &ctx = send_varname_to_ctx_.at(var_name);
+        send_functor(ctx, *send_scope_, true);
+      };
+      task_futures.emplace_back(
+          send_threadpool_->enqueue(std::move(send_task)));
+    }
+    for (auto &task_f : task_futures) {
+      task_f.wait();
     }
-    MergeVars(var_name, vars, send_scope_.get());
-    // auto send_functor = distributed::ParameterSend<float>();
-    // send_functor(var_name, send_varname_to_ctx_[var_name], exe_ctx,
-    // send_scope_, true);
   }
 }
 
 void Communicator::RecvThread() {
-  // parallel run recv graph
-  for (auto &iter : recv_varname_to_ctx_) {
-    auto &var_name = iter.first;
-    VLOG(3) << "recv var " << iter.first;
-    // auto recv_functor = distributed::ParameterRecv<float>();
-    // recv_functor(var_name, iter.second, exe_ctx, recv_scope_);
+  while (running_) {
+    // parallel run recv graph
+    std::vector<std::future<void>> task_futures;
+    task_futures.reserve(recv_varname_to_ctx_.size());
+    for (auto &iter : recv_varname_to_ctx_) {
+      auto recv_task = [this, &iter] {
+        auto &var_name = iter.first;
+        VLOG(3) << "recv var " << var_name;
+        auto recv_functor = distributed::ParameterRecv<float>();
+        recv_functor(iter.second, *recv_scope_);
+      };
+      task_futures.emplace_back(
+          recv_threadpool_->enqueue(std::move(recv_task)));
+    }
+    for (auto &task : task_futures) {
+      task.wait();
+    }
   }
 }
 
@@ -101,6 +123,7 @@ void Communicator::Send(const std::string &var_name,
 }
 
 void Communicator::Start() {
+  running_ = true;
   // start send and recv thread
   send_thread_.reset(
       new std::thread(std::bind(&Communicator::SendThread, this)));
diff --git a/paddle/fluid/operators/distributed/communicator.h b/paddle/fluid/operators/distributed/communicator.h
index 614d6ade81..ffdfa38b12 100644
--- a/paddle/fluid/operators/distributed/communicator.h
+++ b/paddle/fluid/operators/distributed/communicator.h
@@ -19,6 +19,8 @@ limitations under the License. */
 #include <unordered_map>
 #include <vector>
 
+#include <ThreadPool.h>
+
 #include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/framework/variable.h"
 #include "paddle/fluid/operators/distributed/rpc_common.h"
@@ -100,9 +102,18 @@ class Communicator {
       send_varname_to_queue_[iter.first] =
           std::make_shared<BlockingQueue<std::shared_ptr<Variable>>>(10);
     }
+    // TODO(qiao): default 5, need to config
+    send_threadpool_.reset(new ::ThreadPool(5));
+    recv_threadpool_.reset(new ::ThreadPool(5));
   }
 
-  ~Communicator() {}
+  ~Communicator() {
+    VLOG(3) << "~Communicator";
+    running_ = false;
+    send_thread_->join();
+    recv_thread_->join();
+    VLOG(3) << "~Communicator done";
+  }
 
   void Start();
 
@@ -113,6 +124,7 @@ class Communicator {
   void SendThread();
   void RecvThread();
 
+  bool running_ = false;
   std::unordered_map<std::string,
                      std::shared_ptr<BlockingQueue<std::shared_ptr<Variable>>>>
       send_varname_to_queue_;
@@ -122,6 +134,8 @@ class Communicator {
   std::unique_ptr<std::thread> recv_thread_;
   Scope* recv_scope_;                  // should be global scope
   std::unique_ptr<Scope> send_scope_;  // an independent scope
+  std::unique_ptr<::ThreadPool> send_threadpool_{nullptr};
+  std::unique_ptr<::ThreadPool> recv_threadpool_{nullptr};
 };
 
 }  // namespace distributed
diff --git a/paddle/fluid/operators/distributed/rpc_common.h b/paddle/fluid/operators/distributed/rpc_common.h
index 7dede07b5a..39eb2d078c 100644
--- a/paddle/fluid/operators/distributed/rpc_common.h
+++ b/paddle/fluid/operators/distributed/rpc_common.h
@@ -29,6 +29,14 @@ struct RpcContext {
         splited_var_names(names),
         epmap(emap),
         height_sections(sections) {}
+
+  RpcContext(const RpcContext& ctx) {
+    var_name = ctx.var_name;
+    splited_var_names = ctx.splited_var_names;
+    epmap = ctx.epmap;
+    height_sections = ctx.height_sections;
+  }
+
   std::string var_name;
   std::vector<std::string> splited_var_names;
   std::vector<std::string> epmap;
--
GitLab
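
The heart of this patch is the new control loop in SendThread()/RecvThread(): while running_ is true, each pass enqueues one task per variable into a ::ThreadPool, keeps the returned std::future<void> handles, and waits on all of them before the next pass, so per-variable merge/send (or recv) work runs in parallel but each round still ends at a barrier. Below is a minimal, self-contained C++ sketch of just that fan-out-and-wait pattern; it uses std::async in place of ThreadPool::enqueue, and the variable map and all names in it are illustrative rather than taken from the Paddle sources.

// Sketch only: mirrors the fan-out-and-wait control flow of the patched
// SendThread()/RecvThread(); std::async stands in for ThreadPool::enqueue.
#include <future>
#include <iostream>
#include <map>
#include <string>
#include <vector>

int main() {
  // One entry per variable, playing the role of send_varname_to_queue_.
  std::map<std::string, int> var_to_pending{{"w0", 3}, {"w1", 5}};
  bool running = true;  // set by Start(), cleared at shutdown in the patch
  int rounds = 0;

  while (running) {
    // Fan out: launch one task per variable and keep its future.
    std::vector<std::future<void>> task_futures;
    task_futures.reserve(var_to_pending.size());
    for (auto &iter : var_to_pending) {
      const std::string name = iter.first;
      const int pending = iter.second;
      task_futures.emplace_back(std::async(std::launch::async, [name, pending] {
        // Stand-in for: pop queued grads, MergeVars, then ParameterSend.
        std::cout << "send " + name + " (" + std::to_string(pending) + " pending)\n";
      }));
    }
    // Barrier: wait for every per-variable task before the next round.
    for (auto &f : task_futures) f.wait();
    if (++rounds == 2) running = false;  // the real loop runs until shutdown
  }
  return 0;
}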