From 8c38aca95401324a44a0aab8e017cae26a179b65 Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Tue, 5 Mar 2019 16:49:52 +0800 Subject: [PATCH] tmp commit --- paddle/fluid/framework/details/CMakeLists.txt | 2 +- .../details/async_ssa_graph_executor.cc | 38 +++++++++++++++++++ .../operators/distributed/communicator.h | 36 +++++++++++++++--- 3 files changed, 69 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index b39673e2297..88e7dd3f88f 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -82,7 +82,7 @@ cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS cc_library(parallel_ssa_graph_executor SRCS parallel_ssa_graph_executor.cc DEPS threaded_ssa_graph_executor) -cc_library(async_ssa_graph_executor SRCS async_ssa_graph_executor.cc DEPS threaded_ssa_graph_executor) +cc_library(async_ssa_graph_executor SRCS async_ssa_graph_executor.cc DEPS threaded_ssa_graph_executor communicator) cc_test(broadcast_op_test SRCS broadcast_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory device_context broadcast_op_handle) diff --git a/paddle/fluid/framework/details/async_ssa_graph_executor.cc b/paddle/fluid/framework/details/async_ssa_graph_executor.cc index 69f770afee9..43391804c54 100644 --- a/paddle/fluid/framework/details/async_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/async_ssa_graph_executor.cc @@ -15,6 +15,7 @@ #include "paddle/fluid/framework/details/async_ssa_graph_executor.h" #include "paddle/fluid/framework/variable_helper.h" +#include "paddle/fluid/operators/distributed/communicator.h" namespace paddle { namespace framework { @@ -39,6 +40,43 @@ inline void NewTempScopeAndInitVars(const std::vector &var_infos, } } +// get RpcContext and remote send and recv op +void ProcessGraph(std::vector graphs, Scope *scope) { + using RpcCtxMap = operators::distributed::RpcCtxMap; + RpcCtxMap send_varname_to_ctx; + RpcCtxMap recv_varname_to_ctx; + for (auto i = 0; i < graphs.size(); ++i) { + for (auto &node : graphs[i]->Nodes()) { + if (node->IsOp()) { + if (node->Op()->Type() == "send") { + auto send_var_name = node->Op()->Input("X")[0]; + auto send_varnames = boost::get>( + node->Op()->GetNullableAttr("send_varnames")); + auto epmap = boost::get>( + node->Op()->GetNullableAttr("epmap")); + auto height_section = boost::get>( + node->Op()->GetNullableAttr("sections")); + send_varname_to_ctx[send_var_name] = + operators::distributed::RpcContext(send_var_name, send_varnames, + epmap, height_section); + } else if (node->Op()->Type() == "recv") { + auto recv_var_name = node->Op()->Input("X")[0]; + auto recv_varnames = boost::get>( + node->Op()->GetNullableAttr("recv_varnames")); + auto epmap = boost::get>( + node->Op()->GetNullableAttr("epmap")); + recv_varname_to_ctx[recv_var_name] = + operators::distributed::RpcContext(recv_var_name, recv_varnames, + epmap, {}); + } + } + } + } + // init communicator here + operators::distributed::Communicator::Init(send_varname_to_ctx, + recv_varname_to_ctx, scope); +} + AsyncSSAGraphExecutor::AsyncSSAGraphExecutor( const ExecutionStrategy &strategy, const std::vector &local_scopes, const std::vector &places, std::vector graphs) diff --git a/paddle/fluid/operators/distributed/communicator.h b/paddle/fluid/operators/distributed/communicator.h index ffdfa38b12f..44e2aa3be73 100644 --- a/paddle/fluid/operators/distributed/communicator.h +++ b/paddle/fluid/operators/distributed/communicator.h @@ -87,12 +87,12 @@ class BlockingQueue { std::condition_variable send_cv_; }; +using RpcCtxMap = std::unordered_map; + class Communicator { public: - Communicator( - const std::unordered_map& send_varname_to_ctx, - const std::unordered_map& recv_varname_to_ctx, - Scope* recv_scope) + Communicator(const RpcCtxMap& send_varname_to_ctx, + const RpcCtxMap& recv_varname_to_ctx, Scope* recv_scope) : send_varname_to_ctx_(send_varname_to_ctx), recv_varname_to_ctx_(recv_varname_to_ctx), recv_scope_(recv_scope) { @@ -128,14 +128,38 @@ class Communicator { std::unordered_map>>> send_varname_to_queue_; - std::unordered_map send_varname_to_ctx_; - std::unordered_map recv_varname_to_ctx_; + RpcCtxMap send_varname_to_ctx_; + RpcCtxMap recv_varname_to_ctx_; std::unique_ptr send_thread_; std::unique_ptr recv_thread_; Scope* recv_scope_; // should be global scope std::unique_ptr send_scope_; // an independent scope std::unique_ptr<::ThreadPool> send_threadpool_{nullptr}; std::unique_ptr<::ThreadPool> recv_threadpool_{nullptr}; + + // the following code is for initialize the commnunicator + public: + static void Init(const RpcCtxMap& send_varname_to_ctx, + const RpcCtxMap& recv_varname_to_ctx, Scope* recv_scope) { + InitImpl(send_varname_to_ctx, recv_varname_to_ctx, recv_scope); + } + + static Communicator* GetInstance() { return communicator_.get(); } + + private: + // Init is called by GetInstance. + static void InitImpl(const RpcCtxMap& send_varname_to_ctx, + const RpcCtxMap& recv_varname_to_ctx, + Scope* recv_scope) { + if (communicator_ == nullptr) { + communicator_.reset(new Communicator(send_varname_to_ctx, + recv_varname_to_ctx, recv_scope)); + } + } + + private: + static std::once_flag init_flag_; + static std::unique_ptr communicator_; }; } // namespace distributed -- GitLab