diff --git a/paddle/fluid/distributed/fleet_executor/CMakeLists.txt b/paddle/fluid/distributed/fleet_executor/CMakeLists.txt
index 9d6b30b06334a917193cc0cfda5acbe3293f23db..4114cb08119dba636b1504292a49c6cb426c3baf 100644
--- a/paddle/fluid/distributed/fleet_executor/CMakeLists.txt
+++ b/paddle/fluid/distributed/fleet_executor/CMakeLists.txt
@@ -19,6 +19,7 @@ if(WITH_DISTRIBUTE)
   set_source_files_properties(interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
   set_source_files_properties(message_bus.h PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
   set_source_files_properties(message_bus.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
+  set_source_files_properties(fleet_executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
   set_source_files_properties(carrier.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
   set_source_files_properties(interceptor_message_service.h PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
   set_source_files_properties(interceptor_message_service.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
diff --git a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc
index 842695fdea60b26935223365ca083dbc451d38cc..eed6d6ef7e47e106a4ccb222d7a120f6fd0956cf 100644
--- a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc
+++ b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
+#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
 #include "paddle/fluid/distributed/fleet_executor/runtime_graph.h"
 #include "paddle/fluid/framework/program_desc.h"
 
@@ -31,6 +32,46 @@ FleetExecutor::~FleetExecutor() {
 
 void FleetExecutor::Init(const paddle::framework::ProgramDesc& program_desc) {
   // Compile and Initialize
+  InitMessageBus();
+}
+
+// Collects the rank -> ip_port entries of exe_desc_.cluster_info() and uses
+// them to initialize the MessageBus singleton, unless it is already
+// initialized. Enforces (NotFound) that the current rank has an ip_port entry.
+void FleetExecutor::InitMessageBus() {
+  std::stringstream ss;
+  ss << "\nThe DNS table of the message bus is: \n";
+  int64_t cur_rank = exe_desc_.cur_rank();
+  std::unordered_map<int64_t, int64_t> interceptor_id_to_rank;
+  std::unordered_map<int64_t, std::string> rank_to_addr;
+  std::string addr;
+  for (const auto& rank_info : exe_desc_.cluster_info()) {
+    int64_t rank = rank_info.rank();
+    // Bind by reference: avoids copying the address on every iteration.
+    const auto& ip_port = rank_info.ip_port();
+    ss << rank << "\t->\t" << ip_port << "\n";
+    // TODO(Yuang): replace the first 'rank' with real interceptor id
+    interceptor_id_to_rank.emplace(rank, rank);
+    rank_to_addr.emplace(rank, ip_port);
+    if (rank == cur_rank) {
+      addr = ip_port;
+    }
+  }
+  // addr is still empty if no cluster_info entry supplied one for cur_rank.
+  PADDLE_ENFORCE_NE(
+      addr, "",
+      platform::errors::NotFound(
+          "Current rank is %s, which ip_port cannot be found in the config.",
+          cur_rank));
+  VLOG(3) << "Current rank is " << cur_rank << " and the ip_port is " << addr
+          << ".";
+  VLOG(3) << "The number of ranks are " << interceptor_id_to_rank.size() << ".";
+  VLOG(5) << ss.str();
+  MessageBus& message_bus_instance = MessageBus::Instance();
+  // Skip initialization when the bus reports it is already initialized.
+  if (!message_bus_instance.IsInit()) {
+    message_bus_instance.Init(interceptor_id_to_rank, rank_to_addr, addr);
+  }
 }
 
 void FleetExecutor::Run() {
diff --git a/paddle/fluid/distributed/fleet_executor/fleet_executor.h b/paddle/fluid/distributed/fleet_executor/fleet_executor.h
index e12629844933a7e69f0df9984e4345a79935ef9e..242e1a74fc489d270c86da1fcd01582723c41e9e 100644
--- a/paddle/fluid/distributed/fleet_executor/fleet_executor.h
+++ b/paddle/fluid/distributed/fleet_executor/fleet_executor.h
@@ -42,6 +42,8 @@ class FleetExecutor final {
   DISABLE_COPY_AND_ASSIGN(FleetExecutor);
   FleetExecutorDesc exe_desc_;
   std::unique_ptr<RuntimeGraph> runtime_graph_;
+  // Initializes the MessageBus singleton from the cluster info in exe_desc_.
+  void InitMessageBus();
   static std::shared_ptr<Carrier> global_carrier_;
 };
 
diff --git a/paddle/fluid/distributed/fleet_executor/message_bus.cc b/paddle/fluid/distributed/fleet_executor/message_bus.cc
index 08cd100f108fd8efa3afcf590017196f27e651e9..75e7b2fcb3dc5b5493d38067226abcce13a1352c 100644
--- a/paddle/fluid/distributed/fleet_executor/message_bus.cc
+++ b/paddle/fluid/distributed/fleet_executor/message_bus.cc
@@ -42,7 +42,10 @@ void MessageBus::Init(
   });
 }
 
+bool MessageBus::IsInit() const { return is_init_; }
+
 void MessageBus::Release() {
+  VLOG(3) << "Message bus releases resource.";
 #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
     !defined(PADDLE_WITH_ASCEND_CL)
   server_.Stop(1000);
diff --git a/paddle/fluid/distributed/fleet_executor/message_bus.h b/paddle/fluid/distributed/fleet_executor/message_bus.h
index 08e8d2e24abd8347c171d08a8c87cee5161884d5..e45f2e3c7125955c310aae6f43b207b7bbc44382 100644
--- a/paddle/fluid/distributed/fleet_executor/message_bus.h
+++ b/paddle/fluid/distributed/fleet_executor/message_bus.h
@@ -48,6 +48,9 @@ class MessageBus final {
             const std::unordered_map<int64_t, std::string>& rank_to_addr,
             const std::string& addr);
 
+  // Returns is_init_, i.e. whether this bus reports itself as initialized.
+  bool IsInit() const;
+
   void Release();
 
   // called by Interceptor, send InterceptorMessage to dst