未验证 提交 6bf208c3 编写于 作者: Y Yuang Liu 提交者: GitHub

[fleet_executor] Parse rank_to_ip map on cpp side and start message bus. (#37126)

上级 778a3630
...@@ -19,6 +19,7 @@ if(WITH_DISTRIBUTE) ...@@ -19,6 +19,7 @@ if(WITH_DISTRIBUTE)
set_source_files_properties(interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(message_bus.h PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(message_bus.h PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(message_bus.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(message_bus.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(fleet_executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(carrier.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(carrier.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(interceptor_message_service.h PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(interceptor_message_service.h PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(interceptor_message_service.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(interceptor_message_service.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h" #include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/runtime_graph.h" #include "paddle/fluid/distributed/fleet_executor/runtime_graph.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
...@@ -31,6 +32,40 @@ FleetExecutor::~FleetExecutor() { ...@@ -31,6 +32,40 @@ FleetExecutor::~FleetExecutor() {
void FleetExecutor::Init(const paddle::framework::ProgramDesc& program_desc) { void FleetExecutor::Init(const paddle::framework::ProgramDesc& program_desc) {
// Compile and Initialize // Compile and Initialize
InitMessageBus();
}
void FleetExecutor::InitMessageBus() {
std::stringstream ss;
ss << "\nThe DNS table of the message bus is: \n";
int64_t cur_rank = exe_desc_.cur_rank();
std::unordered_map<int64_t, int64_t> interceptor_id_to_rank;
std::unordered_map<int64_t, std::string> rank_to_addr;
std::string addr;
for (const auto& rank_info : exe_desc_.cluster_info()) {
int64_t rank = rank_info.rank();
std::string ip_port = rank_info.ip_port();
ss << rank << "\t->\t" << ip_port << "\n";
// TODO(Yuang): replace the first 'rank' with real interceptor id
interceptor_id_to_rank.insert(std::make_pair(rank, rank));
rank_to_addr.insert(std::make_pair(rank, ip_port));
if (rank == cur_rank) {
addr = ip_port;
}
}
PADDLE_ENFORCE_NE(
addr, "",
platform::errors::NotFound(
"Current rank is %s, which ip_port cannot be found in the config.",
cur_rank));
VLOG(3) << "Current rank is " << cur_rank << " and the ip_port is " << addr
<< ".";
VLOG(3) << "The number of ranks are " << interceptor_id_to_rank.size() << ".";
VLOG(5) << ss.str();
MessageBus& message_bus_instance = MessageBus::Instance();
if (!message_bus_instance.IsInit()) {
message_bus_instance.Init(interceptor_id_to_rank, rank_to_addr, addr);
}
} }
void FleetExecutor::Run() { void FleetExecutor::Run() {
......
...@@ -42,6 +42,7 @@ class FleetExecutor final { ...@@ -42,6 +42,7 @@ class FleetExecutor final {
DISABLE_COPY_AND_ASSIGN(FleetExecutor); DISABLE_COPY_AND_ASSIGN(FleetExecutor);
FleetExecutorDesc exe_desc_; FleetExecutorDesc exe_desc_;
std::unique_ptr<RuntimeGraph> runtime_graph_; std::unique_ptr<RuntimeGraph> runtime_graph_;
void InitMessageBus();
static std::shared_ptr<Carrier> global_carrier_; static std::shared_ptr<Carrier> global_carrier_;
}; };
......
...@@ -42,7 +42,10 @@ void MessageBus::Init( ...@@ -42,7 +42,10 @@ void MessageBus::Init(
}); });
} }
bool MessageBus::IsInit() const { return is_init_; }
void MessageBus::Release() { void MessageBus::Release() {
VLOG(3) << "Message bus releases resource.";
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
!defined(PADDLE_WITH_ASCEND_CL) !defined(PADDLE_WITH_ASCEND_CL)
server_.Stop(1000); server_.Stop(1000);
......
...@@ -48,6 +48,8 @@ class MessageBus final { ...@@ -48,6 +48,8 @@ class MessageBus final {
const std::unordered_map<int64_t, std::string>& rank_to_addr, const std::unordered_map<int64_t, std::string>& rank_to_addr,
const std::string& addr); const std::string& addr);
bool IsInit() const;
void Release(); void Release();
// called by Interceptor, send InterceptorMessage to dst // called by Interceptor, send InterceptorMessage to dst
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册