diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt
index b39673e2297d4cb6299c4b2695af3896eb5dea91..88e7dd3f88fe01cc7c35e6b40ab0ac320d673cbf 100644
--- a/paddle/fluid/framework/details/CMakeLists.txt
+++ b/paddle/fluid/framework/details/CMakeLists.txt
@@ -82,7 +82,7 @@ cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS
 cc_library(parallel_ssa_graph_executor SRCS parallel_ssa_graph_executor.cc DEPS threaded_ssa_graph_executor)
-cc_library(async_ssa_graph_executor SRCS async_ssa_graph_executor.cc DEPS threaded_ssa_graph_executor)
+cc_library(async_ssa_graph_executor SRCS async_ssa_graph_executor.cc DEPS threaded_ssa_graph_executor communicator)
 
 cc_test(broadcast_op_test SRCS broadcast_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
         device_context broadcast_op_handle)
diff --git a/paddle/fluid/framework/details/async_ssa_graph_executor.cc b/paddle/fluid/framework/details/async_ssa_graph_executor.cc
index 69f770afee9a6cbb83e43da31f04aeece5581cde..43391804c5430cf49bbdb68c43f623d204f24cea 100644
--- a/paddle/fluid/framework/details/async_ssa_graph_executor.cc
+++ b/paddle/fluid/framework/details/async_ssa_graph_executor.cc
@@ -15,6 +15,7 @@
 #include "paddle/fluid/framework/details/async_ssa_graph_executor.h"
 
 #include "paddle/fluid/framework/variable_helper.h"
+#include "paddle/fluid/operators/distributed/communicator.h"
 
 namespace paddle {
 namespace framework {
@@ -39,6 +40,43 @@ inline void NewTempScopeAndInitVars(const std::vector<VarInfo> &var_infos,
   }
 }
 
+// collect an RpcContext for every remote send and recv op in the graphs
+void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
+  using RpcCtxMap = operators::distributed::RpcCtxMap;
+  RpcCtxMap send_varname_to_ctx;
+  RpcCtxMap recv_varname_to_ctx;
+  for (size_t i = 0; i < graphs.size(); ++i) {
+    for (auto &node : graphs[i]->Nodes()) {
+      if (node->IsOp()) {
+        if (node->Op()->Type() == "send") {
+          auto send_var_name = node->Op()->Input("X")[0];
+          auto send_varnames = boost::get<std::vector<std::string>>(
+              node->Op()->GetNullableAttr("send_varnames"));
+          auto epmap = boost::get<std::vector<std::string>>(
+              node->Op()->GetNullableAttr("epmap"));
+          auto height_section = boost::get<std::vector<int64_t>>(
+              node->Op()->GetNullableAttr("sections"));
+          send_varname_to_ctx[send_var_name] =
+              operators::distributed::RpcContext(send_var_name, send_varnames,
+                                                 epmap, height_section);
+        } else if (node->Op()->Type() == "recv") {
+          auto recv_var_name = node->Op()->Input("X")[0];
+          auto recv_varnames = boost::get<std::vector<std::string>>(
+              node->Op()->GetNullableAttr("recv_varnames"));
+          auto epmap = boost::get<std::vector<std::string>>(
+              node->Op()->GetNullableAttr("epmap"));
+          recv_varname_to_ctx[recv_var_name] =
+              operators::distributed::RpcContext(recv_var_name, recv_varnames,
+                                                 epmap, {});
+        }
+      }
+    }
+  }
+  // init communicator here
+  operators::distributed::Communicator::Init(send_varname_to_ctx,
+                                             recv_varname_to_ctx, scope);
+}
+
 AsyncSSAGraphExecutor::AsyncSSAGraphExecutor(
     const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
     const std::vector<platform::Place> &places,
     std::vector<ir::Graph *> graphs)
diff --git a/paddle/fluid/operators/distributed/communicator.h b/paddle/fluid/operators/distributed/communicator.h
index ffdfa38b12fc6240d31aeaa9d95fe498805548b9..44e2aa3be7392e25f0ac1fb86de3e8600b97eb53 100644
--- a/paddle/fluid/operators/distributed/communicator.h
+++ b/paddle/fluid/operators/distributed/communicator.h
@@ -87,12 +87,12 @@ class BlockingQueue {
   std::condition_variable send_cv_;
 };
 
+using RpcCtxMap = std::unordered_map<std::string, RpcContext>;
+
 class Communicator {
  public:
-  Communicator(
-      const std::unordered_map<std::string, RpcContext>& send_varname_to_ctx,
-      const std::unordered_map<std::string, RpcContext>& recv_varname_to_ctx,
-      Scope* recv_scope)
+  Communicator(const RpcCtxMap& send_varname_to_ctx,
+               const RpcCtxMap& recv_varname_to_ctx, Scope* recv_scope)
       : send_varname_to_ctx_(send_varname_to_ctx),
         recv_varname_to_ctx_(recv_varname_to_ctx),
         recv_scope_(recv_scope) {
@@ -128,14 +128,38 @@ class Communicator {
   std::unordered_map<std::string,
                      std::shared_ptr<BlockingQueue<std::shared_ptr<Variable>>>>
       send_varname_to_queue_;
-  std::unordered_map<std::string, RpcContext> send_varname_to_ctx_;
-  std::unordered_map<std::string, RpcContext> recv_varname_to_ctx_;
+  RpcCtxMap send_varname_to_ctx_;
+  RpcCtxMap recv_varname_to_ctx_;
   std::unique_ptr<std::thread> send_thread_;
   std::unique_ptr<std::thread> recv_thread_;
   Scope* recv_scope_;                  // should be global scope
   std::unique_ptr<Scope> send_scope_;  // an independent scope
   std::unique_ptr<::ThreadPool> send_threadpool_{nullptr};
   std::unique_ptr<::ThreadPool> recv_threadpool_{nullptr};
+
+  // the following code is for initializing the communicator
+ public:
+  static void Init(const RpcCtxMap& send_varname_to_ctx,
+                   const RpcCtxMap& recv_varname_to_ctx, Scope* recv_scope) {
+    InitImpl(send_varname_to_ctx, recv_varname_to_ctx, recv_scope);
+  }
+
+  static Communicator* GetInstance() { return communicator_.get(); }
+
+ private:
+  // InitImpl is called by Init and keeps the first instance it creates.
+  static void InitImpl(const RpcCtxMap& send_varname_to_ctx,
+                       const RpcCtxMap& recv_varname_to_ctx,
+                       Scope* recv_scope) {
+    if (communicator_ == nullptr) {
+      communicator_.reset(new Communicator(send_varname_to_ctx,
+                                           recv_varname_to_ctx, recv_scope));
+    }
+  }
+
+ private:
+  static std::once_flag init_flag_;
+  static std::unique_ptr<Communicator> communicator_;
 };
 
 }  // namespace distributed
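
For reviewers, a minimal sketch of how the new entry points would be driven by hand, mirroring what ProcessGraph now does automatically when it walks the graphs. The endpoints, variable names, slice names, and section sizes below are made-up placeholders; only RpcContext, RpcCtxMap, Communicator::Init, and Communicator::GetInstance come from this patch:

#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/operators/distributed/communicator.h"

namespace distributed = paddle::operators::distributed;

// Hypothetical trainer-side setup; all concrete values are examples only.
void ExampleInitCommunicator(paddle::framework::Scope* global_scope) {
  distributed::RpcCtxMap send_varname_to_ctx;
  distributed::RpcCtxMap recv_varname_to_ctx;

  // A send context: "w@GRAD" is split into two slices and pushed to two
  // (made-up) parameter-server endpoints; the last argument is the height
  // sections, read from the send op's "sections" attribute in ProcessGraph.
  send_varname_to_ctx["w@GRAD"] = distributed::RpcContext(
      "w@GRAD", {"w@GRAD.block0", "w@GRAD.block1"},
      {"127.0.0.1:6174", "127.0.0.1:6175"}, {512, 512});

  // A recv context: the height sections stay empty, exactly as in
  // ProcessGraph's recv branch.
  recv_varname_to_ctx["w"] = distributed::RpcContext(
      "w", {"w.block0", "w.block1"},
      {"127.0.0.1:6174", "127.0.0.1:6175"}, {});

  // First call wins: InitImpl only resets communicator_ while it is null,
  // so repeated Init calls keep the already-built instance.
  distributed::Communicator::Init(send_varname_to_ctx, recv_varname_to_ctx,
                                  global_scope);

  // Any later code path can now reach the singleton.
  distributed::Communicator* comm = distributed::Communicator::GetInstance();
  (void)comm;
}

One design note: init_flag_ is declared but the null check in InitImpl is what actually guards re-initialization here; if Init can race across threads, routing it through std::call_once on init_flag_ would make the singleton creation thread-safe.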