From 742378f47f71083291dafd414c33e5b69d775f8b Mon Sep 17 00:00:00 2001 From: WangXi Date: Fri, 12 Nov 2021 14:25:07 +0800 Subject: [PATCH] [fleet_executor] Add interceptor ping pong test (#37143) --- .../distributed/fleet_executor/CMakeLists.txt | 4 +- .../distributed/fleet_executor/carrier.cc | 33 +++++---- .../distributed/fleet_executor/carrier.h | 22 ++++-- .../fleet_executor/fleet_executor.cc | 5 -- .../fleet_executor/fleet_executor.h | 1 - .../distributed/fleet_executor/interceptor.cc | 9 ++- .../distributed/fleet_executor/interceptor.h | 2 +- .../interceptor_message_service.cc | 5 +- .../distributed/fleet_executor/message_bus.cc | 12 +--- .../distributed/fleet_executor/message_bus.h | 4 -- .../fleet_executor/test/CMakeLists.txt | 2 + .../test/interceptor_ping_pong_test.cc | 70 +++++++++++++++++++ 12 files changed, 120 insertions(+), 49 deletions(-) create mode 100644 paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt create mode 100644 paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc diff --git a/paddle/fluid/distributed/fleet_executor/CMakeLists.txt b/paddle/fluid/distributed/fleet_executor/CMakeLists.txt index 4114cb08119..adcf3c5e41d 100644 --- a/paddle/fluid/distributed/fleet_executor/CMakeLists.txt +++ b/paddle/fluid/distributed/fleet_executor/CMakeLists.txt @@ -5,7 +5,7 @@ endif() proto_library(interceptor_message_proto SRCS interceptor_message.proto) if(WITH_DISTRIBUTE AND NOT (WITH_ASCEND OR WITH_ASCEND_CL)) - set(BRPC_DEPS brpc ssl crypto) + set(BRPC_DEPS brpc ssl crypto protobuf gflags glog zlib leveldb snappy gflags glog) else() set(BRPC_DEPS "") endif() @@ -23,4 +23,6 @@ if(WITH_DISTRIBUTE) set_source_files_properties(carrier.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(interceptor_message_service.h PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(interceptor_message_service.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) + + 
add_subdirectory(test) endif() diff --git a/paddle/fluid/distributed/fleet_executor/carrier.cc b/paddle/fluid/distributed/fleet_executor/carrier.cc index 53a3af22c45..0e79656edea 100644 --- a/paddle/fluid/distributed/fleet_executor/carrier.cc +++ b/paddle/fluid/distributed/fleet_executor/carrier.cc @@ -20,9 +20,9 @@ namespace paddle { namespace distributed { -Carrier::Carrier( - const std::unordered_map& interceptor_id_to_node) - : interceptor_id_to_node_(interceptor_id_to_node) { +void Carrier::Init( + const std::unordered_map& interceptor_id_to_node) { + interceptor_id_to_node_ = interceptor_id_to_node; CreateInterceptors(); } @@ -56,20 +56,29 @@ Interceptor* Carrier::GetInterceptor(int64_t interceptor_id) { return iter->second.get(); } +Interceptor* Carrier::SetInterceptor(int64_t interceptor_id, + std::unique_ptr interceptor) { + auto iter = interceptor_idx_to_interceptor_.find(interceptor_id); + PADDLE_ENFORCE_EQ(iter, interceptor_idx_to_interceptor_.end(), + platform::errors::AlreadyExists( + "The interceptor id %lld has already been created! " + "The interceptor id should be unique.", + interceptor_id)); + auto* ptr = interceptor.get(); + interceptor_idx_to_interceptor_.insert( + std::make_pair(interceptor_id, std::move(interceptor))); + return ptr; +} + void Carrier::CreateInterceptors() { // create each Interceptor for (const auto& item : interceptor_id_to_node_) { int64_t interceptor_id = item.first; TaskNode* task_node = item.second; - const auto& iter = interceptor_idx_to_interceptor_.find(interceptor_id); - PADDLE_ENFORCE_EQ(iter, interceptor_idx_to_interceptor_.end(), - platform::errors::AlreadyExists( - "The interceptor id %lld has already been created! 
" - "The interceptor is should be unique.", - interceptor_id)); - interceptor_idx_to_interceptor_.insert(std::make_pair( - interceptor_id, - std::make_unique(interceptor_id, task_node))); + + // TODO(wangxi): use node_type to select different Interceptor + auto interceptor = std::make_unique(interceptor_id, task_node); + SetInterceptor(interceptor_id, std::move(interceptor)); VLOG(3) << "Create Interceptor for " << interceptor_id; } } diff --git a/paddle/fluid/distributed/fleet_executor/carrier.h b/paddle/fluid/distributed/fleet_executor/carrier.h index bac836deaaa..64974714f7b 100644 --- a/paddle/fluid/distributed/fleet_executor/carrier.h +++ b/paddle/fluid/distributed/fleet_executor/carrier.h @@ -18,6 +18,7 @@ #include #include +#include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/errors.h" @@ -26,15 +27,18 @@ namespace paddle { namespace distributed { -class Interceptor; class TaskNode; class InterceptorMessageServiceImpl; +// A singleton Carrier class Carrier final { public: - Carrier() = delete; + static Carrier& Instance() { + static Carrier carrier; + return carrier; + } - explicit Carrier( + void Init( const std::unordered_map& interceptor_id_to_node); ~Carrier() = default; @@ -42,15 +46,21 @@ class Carrier final { // Enqueue a message to corresponding interceptor id bool EnqueueInterceptorMessage(const InterceptorMessage& interceptor_message); + // get interceptor based on the interceptor id + Interceptor* GetInterceptor(int64_t interceptor_id); + + // set interceptor with interceptor id + Interceptor* SetInterceptor(int64_t interceptor_id, + std::unique_ptr); + DISABLE_COPY_AND_ASSIGN(Carrier); private: + Carrier() = default; + // create each Interceptor void CreateInterceptors(); - // get interceptor based on the interceptor id - Interceptor* GetInterceptor(int64_t interceptor_id); - // 
interceptor logic id to the Nodes info std::unordered_map interceptor_id_to_node_; diff --git a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc index eed6d6ef7e4..47d0c526c03 100644 --- a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc +++ b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc @@ -76,10 +76,5 @@ void FleetExecutor::Release() { // Release } -std::shared_ptr FleetExecutor::GetCarrier() { - // get carrier - return nullptr; -} - } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/distributed/fleet_executor/fleet_executor.h b/paddle/fluid/distributed/fleet_executor/fleet_executor.h index 242e1a74fc4..779d2f91221 100644 --- a/paddle/fluid/distributed/fleet_executor/fleet_executor.h +++ b/paddle/fluid/distributed/fleet_executor/fleet_executor.h @@ -43,7 +43,6 @@ class FleetExecutor final { FleetExecutorDesc exe_desc_; std::unique_ptr runtime_graph_; void InitMessageBus(); - static std::shared_ptr global_carrier_; }; } // namespace distributed diff --git a/paddle/fluid/distributed/fleet_executor/interceptor.cc b/paddle/fluid/distributed/fleet_executor/interceptor.cc index 03f04d8340f..dbee46afcf8 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor.cc +++ b/paddle/fluid/distributed/fleet_executor/interceptor.cc @@ -56,11 +56,10 @@ bool Interceptor::EnqueueRemoteInterceptorMessage( return true; } -void Interceptor::Send(int64_t dst_id, - std::unique_ptr msg) { - msg->set_src_id(interceptor_id_); - msg->set_dst_id(dst_id); - MessageBus::Instance().Send(*msg.get()); +void Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) { + msg.set_src_id(interceptor_id_); + msg.set_dst_id(dst_id); + MessageBus::Instance().Send(msg); } void Interceptor::PoolTheMailbox() { diff --git a/paddle/fluid/distributed/fleet_executor/interceptor.h b/paddle/fluid/distributed/fleet_executor/interceptor.h index 7744ecbb110..24fad833186 100644 --- 
a/paddle/fluid/distributed/fleet_executor/interceptor.h +++ b/paddle/fluid/distributed/fleet_executor/interceptor.h @@ -58,7 +58,7 @@ class Interceptor { bool EnqueueRemoteInterceptorMessage( const InterceptorMessage& interceptor_message); - void Send(int64_t dst_id, std::unique_ptr msg); + void Send(int64_t dst_id, InterceptorMessage& msg); // NOLINT DISABLE_COPY_AND_ASSIGN(Interceptor); diff --git a/paddle/fluid/distributed/fleet_executor/interceptor_message_service.cc b/paddle/fluid/distributed/fleet_executor/interceptor_message_service.cc index d30d356e4ff..2205c6e5544 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor_message_service.cc +++ b/paddle/fluid/distributed/fleet_executor/interceptor_message_service.cc @@ -31,10 +31,7 @@ void InterceptorMessageServiceImpl::InterceptorMessageService( << ", with the message: " << request->message_type(); response->set_rst(true); // call interceptor manager's method to handle the message - std::shared_ptr carrier = FleetExecutor::GetCarrier(); - if (carrier != nullptr) { - carrier->EnqueueInterceptorMessage(*request); - } + Carrier::Instance().EnqueueInterceptorMessage(*request); } } // namespace distributed diff --git a/paddle/fluid/distributed/fleet_executor/message_bus.cc b/paddle/fluid/distributed/fleet_executor/message_bus.cc index 75e7b2fcb3d..2a8afb99ba7 100644 --- a/paddle/fluid/distributed/fleet_executor/message_bus.cc +++ b/paddle/fluid/distributed/fleet_executor/message_bus.cc @@ -32,10 +32,7 @@ void MessageBus::Init( rank_to_addr_ = rank_to_addr; addr_ = addr; - listen_port_thread_ = std::thread([this]() { - VLOG(3) << "Start listen_port_thread_ for message bus"; - ListenPort(); - }); + ListenPort(); std::call_once(once_flag_, []() { std::atexit([]() { MessageBus::Instance().Release(); }); @@ -51,7 +48,6 @@ void MessageBus::Release() { server_.Stop(1000); server_.Join(); #endif - listen_port_thread_.join(); } bool MessageBus::Send(const InterceptorMessage& interceptor_message) { @@ -184,11 
+180,7 @@ bool MessageBus::SendInterRank(const InterceptorMessage& interceptor_message) { bool MessageBus::SendIntraRank(const InterceptorMessage& interceptor_message) { // send the message intra rank (dst is the same rank with src) - std::shared_ptr carrier = FleetExecutor::GetCarrier(); - if (carrier != nullptr) { - return carrier->EnqueueInterceptorMessage(interceptor_message); - } - return true; + return Carrier::Instance().EnqueueInterceptorMessage(interceptor_message); } } // namespace distributed diff --git a/paddle/fluid/distributed/fleet_executor/message_bus.h b/paddle/fluid/distributed/fleet_executor/message_bus.h index e45f2e3c712..9212a93df42 100644 --- a/paddle/fluid/distributed/fleet_executor/message_bus.h +++ b/paddle/fluid/distributed/fleet_executor/message_bus.h @@ -92,10 +92,6 @@ class MessageBus final { // brpc server brpc::Server server_; #endif - - // thread keeps listening to the port to receive remote message - // this thread runs ListenPort() function - std::thread listen_port_thread_; }; } // namespace distributed diff --git a/paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt b/paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt new file mode 100644 index 00000000000..524aebe3b95 --- /dev/null +++ b/paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt @@ -0,0 +1,2 @@ +set_source_files_properties(interceptor_ping_pong_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +cc_test(interceptor_ping_pong_test SRCS interceptor_ping_pong_test.cc DEPS fleet_executor ${BRPC_DEPS}) diff --git a/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc b/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc new file mode 100644 index 00000000000..856bbb47547 --- /dev/null +++ b/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc @@ -0,0 +1,70 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include + +#include "gtest/gtest.h" + +#include "paddle/fluid/distributed/fleet_executor/carrier.h" +#include "paddle/fluid/distributed/fleet_executor/interceptor.h" +#include "paddle/fluid/distributed/fleet_executor/message_bus.h" + +namespace paddle { +namespace distributed { + +class PingPongInterceptor : public Interceptor { + public: + PingPongInterceptor(int64_t interceptor_id, TaskNode* node) + : Interceptor(interceptor_id, node) { + RegisterMsgHandle([this](const InterceptorMessage& msg) { PingPong(msg); }); + } + + void PingPong(const InterceptorMessage& msg) { + std::cout << GetInterceptorId() << " recv msg, count=" << count_ + << std::endl; + ++count_; + if (count_ == 20) { + InterceptorMessage stop; + stop.set_message_type(STOP); + Send(0, stop); + Send(1, stop); + return; + } + + InterceptorMessage resp; + Send(msg.src_id(), resp); + } + + private: + int count_{0}; +}; + +TEST(InterceptorTest, PingPong) { + MessageBus& msg_bus = MessageBus::Instance(); + msg_bus.Init({{0, 0}, {1, 0}}, {{0, "127.0.0.0:0"}}, "127.0.0.0:0"); + + Carrier& carrier = Carrier::Instance(); + + Interceptor* a = carrier.SetInterceptor( + 0, std::make_unique(0, nullptr)); + + carrier.SetInterceptor(1, std::make_unique(1, nullptr)); + + InterceptorMessage msg; + a->Send(1, msg); +} + +} // namespace distributed +} // namespace paddle -- GitLab