diff --git a/paddle/fluid/distributed/fleet_executor/carrier.cc b/paddle/fluid/distributed/fleet_executor/carrier.cc index 443310c9d78fe5ef0f94e1095e217af0481145ac..53a3af22c45e7e6495b0b3dac249bf97bc52ed45 100644 --- a/paddle/fluid/distributed/fleet_executor/carrier.cc +++ b/paddle/fluid/distributed/fleet_executor/carrier.cc @@ -21,22 +21,57 @@ namespace paddle { namespace distributed { Carrier::Carrier( - const std::unordered_map& interceptor_id_to_node) { - // init -} - -Carrier::~Carrier() { - // destroy + const std::unordered_map& interceptor_id_to_node) + : interceptor_id_to_node_(interceptor_id_to_node) { + CreateInterceptors(); } bool Carrier::EnqueueInterceptorMessage( const InterceptorMessage& interceptor_message) { // enqueue message to interceptor - return true; + if (interceptor_message.ctrl_message()) { + // handle control message + return true; + } else { + int64_t dst_id = interceptor_message.dst_id(); + Interceptor* dst_interceptor = GetInterceptor(dst_id); + bool rst = + dst_interceptor->EnqueueRemoteInterceptorMessage(interceptor_message); + if (rst) { + std::condition_variable& interceptor_cond_var = + dst_interceptor->GetCondVar(); + interceptor_cond_var.notify_all(); + } + return rst; + } +} + +Interceptor* Carrier::GetInterceptor(int64_t interceptor_id) { + auto iter = interceptor_idx_to_interceptor_.find(interceptor_id); + PADDLE_ENFORCE_NE(iter, interceptor_idx_to_interceptor_.end(), + platform::errors::InvalidArgument( + "Cannot find interceptor instance for interceptor " + "id %lld. Wrong dst? Call before init?", + interceptor_id)); + return iter->second.get(); } void Carrier::CreateInterceptors() { // create each Interceptor + for (const auto& item : interceptor_id_to_node_) { + int64_t interceptor_id = item.first; + TaskNode* task_node = item.second; + const auto& iter = interceptor_idx_to_interceptor_.find(interceptor_id); + PADDLE_ENFORCE_EQ(iter, interceptor_idx_to_interceptor_.end(), + platform::errors::AlreadyExists( + "The interceptor id %lld has already been created! " + "The interceptor is should be unique.", + interceptor_id)); + interceptor_idx_to_interceptor_.insert(std::make_pair( + interceptor_id, + std::make_unique(interceptor_id, task_node))); + VLOG(3) << "Create Interceptor for " << interceptor_id; + } } } // namespace distributed diff --git a/paddle/fluid/distributed/fleet_executor/carrier.h b/paddle/fluid/distributed/fleet_executor/carrier.h index b0b0922e7bad0f40b4c001867347c152986ac21d..bac836deaaaf7f184dfcfe824ace1d97469d059e 100644 --- a/paddle/fluid/distributed/fleet_executor/carrier.h +++ b/paddle/fluid/distributed/fleet_executor/carrier.h @@ -19,6 +19,8 @@ #include #include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/errors.h" #include "paddle/fluid/platform/macros.h" namespace paddle { @@ -32,9 +34,10 @@ class Carrier final { public: Carrier() = delete; - Carrier(const std::unordered_map& interceptor_id_to_node); + explicit Carrier( + const std::unordered_map& interceptor_id_to_node); - ~Carrier(); + ~Carrier() = default; // Enqueue a message to corresponding interceptor id bool EnqueueInterceptorMessage(const InterceptorMessage& interceptor_message); diff --git a/paddle/fluid/distributed/fleet_executor/interceptor.cc b/paddle/fluid/distributed/fleet_executor/interceptor.cc index 48b5eddc41095951f7d4e0b23284e912a0479808..2bee99d183b996c3564eac14210968c666b56c0f 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor.cc +++ b/paddle/fluid/distributed/fleet_executor/interceptor.cc @@ -17,28 +17,73 @@ namespace paddle { namespace distributed { -Interceptor::Interceptor(int64_t interceptor_id_, TaskNode* node) { - // init +Interceptor::Interceptor(int64_t interceptor_id, TaskNode* node) + : interceptor_id_(interceptor_id), node_(node) { + interceptor_thread_ = std::thread([this]() { + VLOG(3) << "Start pooling local mailbox's thread."; + PoolTheMailbox(); + }); +} + +Interceptor::~Interceptor() { interceptor_thread_.join(); } + +std::condition_variable& Interceptor::GetCondVar() { + // get the conditional var + return cond_var_; } int64_t Interceptor::GetInterceptorId() const { // return the interceptor id - return 0; + return interceptor_id_; } bool Interceptor::EnqueueRemoteInterceptorMessage( const InterceptorMessage& interceptor_message) { // Called by Carrier, enqueue an InterceptorMessage to remote mailbox + VLOG(3) << "Enqueue message: " << interceptor_message.message_type() + << " into " << interceptor_id_ << "'s remote mailbox."; + remote_mailbox_mutex_.lock(); + remote_mailbox_.push(interceptor_message); + remote_mailbox_mutex_.unlock(); return true; } void Interceptor::PoolTheMailbox() { // pool the local mailbox, parse the Message + while (true) { + if (local_mailbox_.empty()) { + // local mailbox is empty, fetch the remote mailbox + VLOG(3) << interceptor_id_ << "'s local mailbox is empty. " + << "Fetch the remote mailbox."; + PADDLE_ENFORCE_EQ(FetchRemoteMailbox(), true, + platform::errors::InvalidArgument( + "Error encountered when fetch remote mailbox.")); + } + const InterceptorMessage interceptor_message = local_mailbox_.front(); + local_mailbox_.pop(); + const MessageType message_type = interceptor_message.message_type(); + VLOG(3) << interceptor_id_ << " has received a message: " << message_type + << "."; + if (message_type == STOP) { + // break the pooling thread + break; + } + } } bool Interceptor::FetchRemoteMailbox() { // fetch all Message from remote mailbox to local mailbox // return true if remote mailbox not empty, otherwise return false + std::unique_lock lock(remote_mailbox_mutex_); + cond_var_.wait(lock, [this]() { return !remote_mailbox_.empty(); }); + if (remote_mailbox_.empty()) { + // the thread has been unblocked accidentally + return false; + } + while (!remote_mailbox_.empty()) { + local_mailbox_.push(std::move(remote_mailbox_.front())); + remote_mailbox_.pop(); + } return true; } diff --git a/paddle/fluid/distributed/fleet_executor/interceptor.h b/paddle/fluid/distributed/fleet_executor/interceptor.h index a2e25a591bf4fb6fea581ebab145d93ef6a56cf4..85b1d2351f2496d106d25e398e84f4d8e5bbf43e 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor.h +++ b/paddle/fluid/distributed/fleet_executor/interceptor.h @@ -22,6 +22,8 @@ #include #include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/errors.h" #include "paddle/fluid/platform/macros.h" namespace paddle { @@ -33,13 +35,16 @@ class Interceptor { public: Interceptor() = delete; - Interceptor(int64_t interceptor_id_, TaskNode* node); + Interceptor(int64_t interceptor_id, TaskNode* node); - virtual ~Interceptor() = default; + virtual ~Interceptor(); // return the interceptor id int64_t GetInterceptorId() const; + // return the conditional var + std::condition_variable& GetCondVar(); + // Called by Carrier, enqueue an InterceptorMessage to remote mailbox bool EnqueueRemoteInterceptorMessage( const InterceptorMessage& interceptor_message); diff --git a/paddle/fluid/distributed/fleet_executor/interceptor_message_service.cc b/paddle/fluid/distributed/fleet_executor/interceptor_message_service.cc index c038e2333d29a21db65fa74ab38612f55eb85d5d..d30d356e4ff28f54728efc6ded5a6a59eb0f23a4 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor_message_service.cc +++ b/paddle/fluid/distributed/fleet_executor/interceptor_message_service.cc @@ -11,9 +11,12 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -#ifndef PADDLE_WITH_ASCEND_CL -#ifdef PADDLE_WITH_DISTRIBUTE +#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ + !defined(PADDLE_WITH_ASCEND_CL) #include "paddle/fluid/distributed/fleet_executor/interceptor_message_service.h" +#include "brpc/server.h" +#include "paddle/fluid/distributed/fleet_executor/carrier.h" +#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h" namespace paddle { namespace distributed { @@ -22,10 +25,18 @@ void InterceptorMessageServiceImpl::InterceptorMessageService( google::protobuf::RpcController* control_base, const InterceptorMessage* request, InterceptorResponse* response, google::protobuf::Closure* done) { - // receive msg + brpc::ClosureGuard done_guard(done); + VLOG(3) << "Interceptor Message Service receives a message from: " + << request->src_id() + << ", with the message: " << request->message_type(); + response->set_rst(true); + // call interceptor manager's method to handle the message + std::shared_ptr carrier = FleetExecutor::GetCarrier(); + if (carrier != nullptr) { + carrier->EnqueueInterceptorMessage(*request); + } } } // namespace distributed } // namespace paddle #endif -#endif diff --git a/paddle/fluid/distributed/fleet_executor/interceptor_message_service.h b/paddle/fluid/distributed/fleet_executor/interceptor_message_service.h index 77eda7816e468dd4ae0b860f33dde189184059ad..0a8dfc861a910b2dfe500f684af1996612d10299 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor_message_service.h +++ b/paddle/fluid/distributed/fleet_executor/interceptor_message_service.h @@ -11,8 +11,8 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -#ifndef PADDLE_WITH_ASCEND_CL -#ifdef PADDLE_WITH_DISTRIBUTE +#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ + !defined(PADDLE_WITH_ASCEND_CL) #pragma once #include "brpc/server.h" @@ -34,4 +34,3 @@ class InterceptorMessageServiceImpl : public TheInterceptorMessageService { } // namespace distributed } // namespace paddle #endif -#endif diff --git a/paddle/fluid/distributed/fleet_executor/message_bus.cc b/paddle/fluid/distributed/fleet_executor/message_bus.cc index d529a0ba5fa26269115b83ebabc89d4c30592fba..6853768c84839304d0fa536c2f2e229f61fc0584 100644 --- a/paddle/fluid/distributed/fleet_executor/message_bus.cc +++ b/paddle/fluid/distributed/fleet_executor/message_bus.cc @@ -12,41 +12,160 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/distributed/fleet_executor/message_bus.h" +#include + #include "paddle/fluid/distributed/fleet_executor/carrier.h" +#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h" +#include "paddle/fluid/distributed/fleet_executor/message_bus.h" namespace paddle { namespace distributed { +MessageBus::MessageBus( + const std::unordered_map& interceptor_id_to_rank, + const std::unordered_map& rank_to_addr, + const std::string& addr) + : interceptor_id_to_rank_(interceptor_id_to_rank), + rank_to_addr_(rank_to_addr), + addr_(addr) { + listen_port_thread_ = std::thread([this]() { + VLOG(3) << "Start listen_port_thread_ for message bus"; + ListenPort(); + }); +} + MessageBus::~MessageBus() { - // destroy +#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ + !defined(PADDLE_WITH_ASCEND_CL) + server_.Stop(1000); + server_.Join(); +#endif + listen_port_thread_.join(); } bool MessageBus::Send(const InterceptorMessage& interceptor_message) { // called by Interceptor, send InterceptorMessage to dst + int64_t src_id = interceptor_message.src_id(); + int64_t dst_id = interceptor_message.dst_id(); + if (IsSameRank(src_id, dst_id)) { + VLOG(3) << "Send a message from: " << src_id << " to " << dst_id + << " within a same rank."; + return SendIntraRank(interceptor_message); + } else { + VLOG(3) << "Send a message from: " << src_id << " to " << dst_id + << " between different ranks."; +#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ + !defined(PADDLE_WITH_ASCEND_CL) + return SendInterRank(interceptor_message); +#else + PADDLE_THROW(platform::errors::Unavailable( + "Fleet executor does not support sending message between different " + "ranks when Paddle is compiled with npu or " + "isn't compiled with distributed for now.")); +#endif + } return true; } void MessageBus::ListenPort() { +#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ + !defined(PADDLE_WITH_ASCEND_CL) // function keep listen the port and handle the message + InterceptorMessageServiceImpl interceptor_message_service; + PADDLE_ENFORCE_EQ(server_.AddService(&interceptor_message_service, + brpc::SERVER_DOESNT_OWN_SERVICE), + 0, platform::errors::Unavailable( + "Message bus: init brpc service error.")); + + // start the server + const char* ip_for_brpc = addr_.c_str(); + brpc::ServerOptions options; + options.idle_timeout_sec = -1; + PADDLE_ENFORCE_EQ( + server_.Start(ip_for_brpc, &options), 0, + platform::errors::Unavailable("Message bus: start brpc service error.")); + VLOG(3) << "Message bus's listen port thread starts successful."; +#else + VLOG(3) << "Fleet executor's ListenPort() is a fake function when Paddle is " + "compiled with npu or Paddle isn't compiled " + "with distributed for now."; +#endif } bool MessageBus::IsSameRank(int64_t src_id, int64_t dst_id) { // check whether the dst is the same rank or different rank with src - return true; + const auto& src_rank = interceptor_id_to_rank_.find(src_id); + const auto& dst_rank = interceptor_id_to_rank_.find(dst_id); + PADDLE_ENFORCE_NE( + src_rank, interceptor_id_to_rank_.end(), + platform::errors::NotFound( + "Cannot find rank for src interceptor id %lld. Init error.", src_id)); + PADDLE_ENFORCE_NE( + dst_rank, interceptor_id_to_rank_.end(), + platform::errors::NotFound( + "Cannot find rank for dst interceptor id %lld. Init error.", dst_id)); + const auto& src_ip = rank_to_addr_.find(src_rank->second); + PADDLE_ENFORCE_NE(src_ip, rank_to_addr_.end(), + platform::errors::NotFound( + "Cannot find addr for src rank id %lld. Init error.", + src_rank->second)); + PADDLE_ENFORCE_EQ( + src_ip->second, addr_, + platform::errors::Fatal("The src interceptor's addr is %s, while the " + "message bus's addr is %s, which are different. " + "Init error.", + src_ip->second, addr_)); + return src_rank->second == dst_rank->second; } -#ifndef PADDLE_WITH_ASCEND_CL -#ifdef PADDLE_WITH_DISTRIBUTE +#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ + !defined(PADDLE_WITH_ASCEND_CL) bool MessageBus::SendInterRank(const InterceptorMessage& interceptor_message) { // send the message inter rank (dst is different rank with src) - return true; + int64_t dst_id = interceptor_message.dst_id(); + int64_t dst_rank = interceptor_id_to_rank_[dst_id]; + auto dst_ip = rank_to_addr_.find(dst_rank); + PADDLE_ENFORCE_NE(dst_ip, rank_to_addr_.end(), + platform::errors::InvalidArgument( + "Cannot find rank for dst interceptor id %lld. " + "Init error.", + dst_id)); + const char* dst_ip_for_brpc = dst_ip->second.c_str(); + brpc::Channel channel; + brpc::ChannelOptions options; + options.protocol = "baidu_std"; + options.timeout_ms = 1000; + options.max_retry = 5; + PADDLE_ENFORCE_EQ( + channel.Init(dst_ip_for_brpc, &options), 0, + platform::errors::Unavailable("Message bus: init brpc channel error.")); + TheInterceptorMessageService_Stub stub(&channel); + InterceptorResponse response; + brpc::Controller ctrl; + ctrl.set_log_id(0); + stub.InterceptorMessageService(&ctrl, &interceptor_message, &response, NULL); + if (!ctrl.Failed()) { + if (response.rst()) { + VLOG(3) << "Message bus: brpc sends success."; + return true; + } else { + VLOG(3) << "Message bus: InterceptorMessageService error."; + return false; + } + } else { + VLOG(3) << "Message bus: brpc sends failed with error text: " + << ctrl.ErrorText(); + return false; + } } #endif -#endif bool MessageBus::SendIntraRank(const InterceptorMessage& interceptor_message) { // send the message intra rank (dst is the same rank with src) + std::shared_ptr carrier = FleetExecutor::GetCarrier(); + if (carrier != nullptr) { + return carrier->EnqueueInterceptorMessage(interceptor_message); + } return true; } diff --git a/paddle/fluid/distributed/fleet_executor/message_bus.h b/paddle/fluid/distributed/fleet_executor/message_bus.h index f0f491b603061e6bcd9da96dd761068686e040fd..86f34e203c5de589b06bbe769ba6811295b5b700 100644 --- a/paddle/fluid/distributed/fleet_executor/message_bus.h +++ b/paddle/fluid/distributed/fleet_executor/message_bus.h @@ -18,14 +18,16 @@ #include #include -#ifndef PADDLE_WITH_ASCEND_CL -#ifdef PADDLE_WITH_DISTRIBUTE +#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ + !defined(PADDLE_WITH_ASCEND_CL) #include "brpc/channel.h" #include "brpc/server.h" -#endif +#include "paddle/fluid/distributed/fleet_executor/interceptor_message_service.h" #endif #include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/errors.h" #include "paddle/fluid/platform/macros.h" namespace paddle { @@ -37,13 +39,9 @@ class MessageBus final { public: MessageBus() = delete; - explicit MessageBus( - const std::unordered_map& interceptor_id_to_rank, - const std::unordered_map& rank_to_addr, - const std::string& addr) - : interceptor_id_to_rank_(interceptor_id_to_rank), - rank_to_addr_(rank_to_addr), - addr_(addr) {} + MessageBus(const std::unordered_map& interceptor_id_to_rank, + const std::unordered_map& rank_to_addr, + const std::string& addr); ~MessageBus(); @@ -59,11 +57,10 @@ class MessageBus final { // check whether the dst is the same rank or different rank with src bool IsSameRank(int64_t src_id, int64_t dst_id); -#ifndef PADDLE_WITH_ASCEND_CL -#ifdef PADDLE_WITH_DISTRIBUTE +#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ + !defined(PADDLE_WITH_ASCEND_CL) // send the message inter rank (dst is different rank with src) bool SendInterRank(const InterceptorMessage& interceptor_message); -#endif #endif // send the message intra rank (dst is the same rank with src) @@ -78,11 +75,10 @@ class MessageBus final { // the ip needs to be listened std::string addr_; -#ifndef PADDLE_WITH_ASCEND_CL -#ifdef PADDLE_WITH_DISTRIBUTE +#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ + !defined(PADDLE_WITH_ASCEND_CL) // brpc server brpc::Server server_; -#endif #endif // thread keeps listening to the port to receive remote message