// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include #include #include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/platform/gen_comm_id_helper.h" namespace paddle { namespace distributed { void MessageBus::Init( int64_t rank, const std::unordered_map& rank_to_addr, const std::string& addr) { PADDLE_ENFORCE_EQ(is_init_, false, platform::errors::AlreadyExists( "MessageBus is already init.")); rank_ = rank; is_init_ = true; rank_to_addr_ = rank_to_addr; addr_ = addr; if (addr_ != "") { const auto& addr = GetAddr(rank_); PADDLE_ENFORCE_EQ(addr, addr_, platform::errors::Fatal( "The current rank's addr is %s, while the " "message bus's addr is %s, which are different. " "Init error.", addr, addr_)); } #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \ defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_ASCEND_CL) // NOTE: To make the brpc is compatible with collective, // need release the handler holding the ip address. if (addr_ != "") { VLOG(3) << "Message bus is releasing the fd held by gen_comm_id."; paddle::platform::SocketServer& socket_server = paddle::platform::SocketServer::GetInstance(addr_); int server_fd = socket_server.socket(); if (server_fd != -1) { socket_server.Release(); } } #endif ListenPort(); } bool MessageBus::IsInit() const { return is_init_; } MessageBus::~MessageBus() { VLOG(3) << "Message bus releases resource."; #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ !defined(PADDLE_WITH_ASCEND_CL) server_.Stop(1000); server_.Join(); #endif } const std::string& MessageBus::GetAddr(int64_t rank) const { PADDLE_ENFORCE_NE( rank_to_addr_.find(rank), rank_to_addr_.end(), platform::errors::NotFound("Cannot find addr rank id %lld.", rank)); return rank_to_addr_.at(rank); } bool MessageBus::Send(int64_t dst_rank, const InterceptorMessage& interceptor_message) { #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ !defined(PADDLE_WITH_ASCEND_CL) int retry_time = 0; // message bus will retry sending for 10 times while (retry_time < 10) { ++retry_time; if (SendInterRank(dst_rank, interceptor_message)) { VLOG(3) << "Message bus sends inter rank successfully with " << retry_time << " times retries."; return true; } VLOG(3) << "Message bus sends failed, retry after 1 seconds."; std::this_thread::sleep_for(std::chrono::milliseconds(1000)); } VLOG(3) << "Message bus sends inter rank fail after 10 times retries."; return false; #else PADDLE_THROW(platform::errors::Unavailable( "Fleet executor does not support sending message between different " "ranks when Paddle is compiled with npu or " "isn't compiled with distributed for now.")); #endif return true; } void MessageBus::TestConnection() { InterceptorMessage ctrl_msg; ctrl_msg.set_ctrl_message(true); ctrl_msg.set_src_id(rank_); for (const auto& dst_rank_pair : rank_to_addr_) { int64_t dst_rank = dst_rank_pair.first; if (dst_rank != rank_) { ctrl_msg.set_dst_id(dst_rank); VLOG(3) << "Send control message bus from rank " << rank_ << " to rank " << dst_rank; while (!Send(dst_rank, ctrl_msg)) { std::this_thread::sleep_for(std::chrono::milliseconds(1000)); } VLOG(3) << "Message bus has connected to rank: " << dst_rank << "."; } } } void MessageBus::ListenPort() { if (addr_ == "") { LOG(INFO) << "No need listen to port since training on single card."; return; } #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ !defined(PADDLE_WITH_ASCEND_CL) // function keep listen the port and handle the message PADDLE_ENFORCE_EQ(server_.AddService(&interceptor_message_service_, brpc::SERVER_DOESNT_OWN_SERVICE), 0, platform::errors::Unavailable( "Message bus: init brpc service error.")); // start the server const char* ip_for_brpc = addr_.c_str(); brpc::ServerOptions options; options.idle_timeout_sec = -1; int retry_times = 0; int interval = 100; while (server_.Start(ip_for_brpc, &options) != 0) { ++retry_times; LOG(INFO) << "Message bus is retring for starting brpc for " << retry_times << " times. And will retry after " << interval / 1000 << " seconds."; std::this_thread::sleep_for(std::chrono::milliseconds(interval)); interval += 500; } LOG(INFO) << "Message bus's listen port thread starts successful."; TestConnection(); #else LOG(WARNING) << "Fleet executor's ListenPort() is a fake function when Paddle is " "compiled with npu or Paddle isn't compiled " "with distributed for now."; #endif } #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ !defined(PADDLE_WITH_ASCEND_CL) bool MessageBus::SendInterRank(int64_t dst_rank, const InterceptorMessage& interceptor_message) { const auto& dst_addr = GetAddr(dst_rank); VLOG(3) << "Message bus sending to addr: " << dst_addr; const char* dst_addr_for_brpc = dst_addr.c_str(); brpc::Channel channel; brpc::ChannelOptions options; options.protocol = "baidu_std"; options.connect_timeout_ms = 1000; options.timeout_ms = 1000; options.max_retry = 5; PADDLE_ENFORCE_EQ( channel.Init(dst_addr_for_brpc, &options), 0, platform::errors::Unavailable("Message bus: init brpc channel error.")); TheInterceptorMessageService_Stub stub(&channel); InterceptorResponse response; brpc::Controller ctrl; ctrl.set_log_id(0); stub.InterceptorMessageService(&ctrl, &interceptor_message, &response, NULL); if (!ctrl.Failed()) { if (response.rst()) { VLOG(3) << "Message bus: brpc sends success."; return true; } else { VLOG(4) << "Message bus: InterceptorMessageService error."; return false; } } else { VLOG(4) << "Message bus: brpc sends failed with error text: " << ctrl.ErrorText(); return false; } } #endif } // namespace distributed } // namespace paddle