未验证 提交 742378f4 编写于 作者: W WangXi 提交者: GitHub

[fleet_executor] Add interceptor ping pong test (#37143)

上级 63c8c8c2
...@@ -5,7 +5,7 @@ endif() ...@@ -5,7 +5,7 @@ endif()
proto_library(interceptor_message_proto SRCS interceptor_message.proto) proto_library(interceptor_message_proto SRCS interceptor_message.proto)
if(WITH_DISTRIBUTE AND NOT (WITH_ASCEND OR WITH_ASCEND_CL)) if(WITH_DISTRIBUTE AND NOT (WITH_ASCEND OR WITH_ASCEND_CL))
set(BRPC_DEPS brpc ssl crypto) set(BRPC_DEPS brpc ssl crypto protobuf gflags glog zlib leveldb snappy gflags glog)
else() else()
set(BRPC_DEPS "") set(BRPC_DEPS "")
endif() endif()
...@@ -23,4 +23,6 @@ if(WITH_DISTRIBUTE) ...@@ -23,4 +23,6 @@ if(WITH_DISTRIBUTE)
set_source_files_properties(carrier.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(carrier.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(interceptor_message_service.h PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(interceptor_message_service.h PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(interceptor_message_service.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(interceptor_message_service.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
add_subdirectory(test)
endif() endif()
...@@ -20,9 +20,9 @@ ...@@ -20,9 +20,9 @@
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
Carrier::Carrier( void Carrier::Init(
const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node) const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node) {
: interceptor_id_to_node_(interceptor_id_to_node) { interceptor_id_to_node_ = interceptor_id_to_node;
CreateInterceptors(); CreateInterceptors();
} }
...@@ -56,20 +56,29 @@ Interceptor* Carrier::GetInterceptor(int64_t interceptor_id) { ...@@ -56,20 +56,29 @@ Interceptor* Carrier::GetInterceptor(int64_t interceptor_id) {
return iter->second.get(); return iter->second.get();
} }
Interceptor* Carrier::SetInterceptor(int64_t interceptor_id,
std::unique_ptr<Interceptor> interceptor) {
auto iter = interceptor_idx_to_interceptor_.find(interceptor_id);
PADDLE_ENFORCE_EQ(iter, interceptor_idx_to_interceptor_.end(),
platform::errors::AlreadyExists(
"The interceptor id %lld has already been created! "
"The interceptor id should be unique.",
interceptor_id));
auto* ptr = interceptor.get();
interceptor_idx_to_interceptor_.insert(
std::make_pair(interceptor_id, std::move(interceptor)));
return ptr;
}
void Carrier::CreateInterceptors() { void Carrier::CreateInterceptors() {
// create each Interceptor // create each Interceptor
for (const auto& item : interceptor_id_to_node_) { for (const auto& item : interceptor_id_to_node_) {
int64_t interceptor_id = item.first; int64_t interceptor_id = item.first;
TaskNode* task_node = item.second; TaskNode* task_node = item.second;
const auto& iter = interceptor_idx_to_interceptor_.find(interceptor_id);
PADDLE_ENFORCE_EQ(iter, interceptor_idx_to_interceptor_.end(), // TODO(wangxi): use node_type to select different Interceptor
platform::errors::AlreadyExists( auto interceptor = std::make_unique<Interceptor>(interceptor_id, task_node);
"The interceptor id %lld has already been created! " SetInterceptor(interceptor_id, std::move(interceptor));
"The interceptor is should be unique.",
interceptor_id));
interceptor_idx_to_interceptor_.insert(std::make_pair(
interceptor_id,
std::make_unique<Interceptor>(interceptor_id, task_node)));
VLOG(3) << "Create Interceptor for " << interceptor_id; VLOG(3) << "Create Interceptor for " << interceptor_id;
} }
} }
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h" #include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h" #include "paddle/fluid/platform/errors.h"
...@@ -26,15 +27,18 @@ ...@@ -26,15 +27,18 @@
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
class Interceptor;
class TaskNode; class TaskNode;
class InterceptorMessageServiceImpl; class InterceptorMessageServiceImpl;
// A singleton MessageBus
class Carrier final { class Carrier final {
public: public:
Carrier() = delete; static Carrier& Instance() {
static Carrier carrier;
return carrier;
}
explicit Carrier( void Init(
const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node); const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node);
~Carrier() = default; ~Carrier() = default;
...@@ -42,15 +46,21 @@ class Carrier final { ...@@ -42,15 +46,21 @@ class Carrier final {
// Enqueue a message to corresponding interceptor id // Enqueue a message to corresponding interceptor id
bool EnqueueInterceptorMessage(const InterceptorMessage& interceptor_message); bool EnqueueInterceptorMessage(const InterceptorMessage& interceptor_message);
// get interceptor based on the interceptor id
Interceptor* GetInterceptor(int64_t interceptor_id);
// set interceptor with interceptor id
Interceptor* SetInterceptor(int64_t interceptor_id,
std::unique_ptr<Interceptor>);
DISABLE_COPY_AND_ASSIGN(Carrier); DISABLE_COPY_AND_ASSIGN(Carrier);
private: private:
Carrier() = default;
// create each Interceptor // create each Interceptor
void CreateInterceptors(); void CreateInterceptors();
// get interceptor based on the interceptor id
Interceptor* GetInterceptor(int64_t interceptor_id);
// interceptor logic id to the Nodes info // interceptor logic id to the Nodes info
std::unordered_map<int64_t, TaskNode*> interceptor_id_to_node_; std::unordered_map<int64_t, TaskNode*> interceptor_id_to_node_;
......
...@@ -76,10 +76,5 @@ void FleetExecutor::Release() { ...@@ -76,10 +76,5 @@ void FleetExecutor::Release() {
// Release // Release
} }
std::shared_ptr<Carrier> FleetExecutor::GetCarrier() {
// get carrier
return nullptr;
}
} // namespace distributed } // namespace distributed
} // namespace paddle } // namespace paddle
...@@ -43,7 +43,6 @@ class FleetExecutor final { ...@@ -43,7 +43,6 @@ class FleetExecutor final {
FleetExecutorDesc exe_desc_; FleetExecutorDesc exe_desc_;
std::unique_ptr<RuntimeGraph> runtime_graph_; std::unique_ptr<RuntimeGraph> runtime_graph_;
void InitMessageBus(); void InitMessageBus();
static std::shared_ptr<Carrier> global_carrier_;
}; };
} // namespace distributed } // namespace distributed
......
...@@ -56,11 +56,10 @@ bool Interceptor::EnqueueRemoteInterceptorMessage( ...@@ -56,11 +56,10 @@ bool Interceptor::EnqueueRemoteInterceptorMessage(
return true; return true;
} }
void Interceptor::Send(int64_t dst_id, void Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) {
std::unique_ptr<InterceptorMessage> msg) { msg.set_src_id(interceptor_id_);
msg->set_src_id(interceptor_id_); msg.set_dst_id(dst_id);
msg->set_dst_id(dst_id); MessageBus::Instance().Send(msg);
MessageBus::Instance().Send(*msg.get());
} }
void Interceptor::PoolTheMailbox() { void Interceptor::PoolTheMailbox() {
......
...@@ -58,7 +58,7 @@ class Interceptor { ...@@ -58,7 +58,7 @@ class Interceptor {
bool EnqueueRemoteInterceptorMessage( bool EnqueueRemoteInterceptorMessage(
const InterceptorMessage& interceptor_message); const InterceptorMessage& interceptor_message);
void Send(int64_t dst_id, std::unique_ptr<InterceptorMessage> msg); void Send(int64_t dst_id, InterceptorMessage& msg); // NOLINT
DISABLE_COPY_AND_ASSIGN(Interceptor); DISABLE_COPY_AND_ASSIGN(Interceptor);
......
...@@ -31,10 +31,7 @@ void InterceptorMessageServiceImpl::InterceptorMessageService( ...@@ -31,10 +31,7 @@ void InterceptorMessageServiceImpl::InterceptorMessageService(
<< ", with the message: " << request->message_type(); << ", with the message: " << request->message_type();
response->set_rst(true); response->set_rst(true);
// call interceptor manager's method to handle the message // call interceptor manager's method to handle the message
std::shared_ptr<Carrier> carrier = FleetExecutor::GetCarrier(); Carrier::Instance().EnqueueInterceptorMessage(*request);
if (carrier != nullptr) {
carrier->EnqueueInterceptorMessage(*request);
}
} }
} // namespace distributed } // namespace distributed
......
...@@ -32,10 +32,7 @@ void MessageBus::Init( ...@@ -32,10 +32,7 @@ void MessageBus::Init(
rank_to_addr_ = rank_to_addr; rank_to_addr_ = rank_to_addr;
addr_ = addr; addr_ = addr;
listen_port_thread_ = std::thread([this]() { ListenPort();
VLOG(3) << "Start listen_port_thread_ for message bus";
ListenPort();
});
std::call_once(once_flag_, []() { std::call_once(once_flag_, []() {
std::atexit([]() { MessageBus::Instance().Release(); }); std::atexit([]() { MessageBus::Instance().Release(); });
...@@ -51,7 +48,6 @@ void MessageBus::Release() { ...@@ -51,7 +48,6 @@ void MessageBus::Release() {
server_.Stop(1000); server_.Stop(1000);
server_.Join(); server_.Join();
#endif #endif
listen_port_thread_.join();
} }
bool MessageBus::Send(const InterceptorMessage& interceptor_message) { bool MessageBus::Send(const InterceptorMessage& interceptor_message) {
...@@ -184,11 +180,7 @@ bool MessageBus::SendInterRank(const InterceptorMessage& interceptor_message) { ...@@ -184,11 +180,7 @@ bool MessageBus::SendInterRank(const InterceptorMessage& interceptor_message) {
bool MessageBus::SendIntraRank(const InterceptorMessage& interceptor_message) { bool MessageBus::SendIntraRank(const InterceptorMessage& interceptor_message) {
// send the message intra rank (dst is the same rank with src) // send the message intra rank (dst is the same rank with src)
std::shared_ptr<Carrier> carrier = FleetExecutor::GetCarrier(); return Carrier::Instance().EnqueueInterceptorMessage(interceptor_message);
if (carrier != nullptr) {
return carrier->EnqueueInterceptorMessage(interceptor_message);
}
return true;
} }
} // namespace distributed } // namespace distributed
......
...@@ -92,10 +92,6 @@ class MessageBus final { ...@@ -92,10 +92,6 @@ class MessageBus final {
// brpc server // brpc server
brpc::Server server_; brpc::Server server_;
#endif #endif
// thread keeps listening to the port to receive remote message
// this thread runs ListenPort() function
std::thread listen_port_thread_;
}; };
} // namespace distributed } // namespace distributed
......
set_source_files_properties(interceptor_ping_pong_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(interceptor_ping_pong_test SRCS interceptor_ping_pong_test.cc DEPS fleet_executor ${BRPC_DEPS})
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include <unordered_map>
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
namespace paddle {
namespace distributed {
class PingPongInterceptor : public Interceptor {
public:
PingPongInterceptor(int64_t interceptor_id, TaskNode* node)
: Interceptor(interceptor_id, node) {
RegisterMsgHandle([this](const InterceptorMessage& msg) { PingPong(msg); });
}
void PingPong(const InterceptorMessage& msg) {
std::cout << GetInterceptorId() << " recv msg, count=" << count_
<< std::endl;
++count_;
if (count_ == 20) {
InterceptorMessage stop;
stop.set_message_type(STOP);
Send(0, stop);
Send(1, stop);
return;
}
InterceptorMessage resp;
Send(msg.src_id(), resp);
}
private:
int count_{0};
};
TEST(InterceptorTest, PingPong) {
MessageBus& msg_bus = MessageBus::Instance();
msg_bus.Init({{0, 0}, {1, 0}}, {{0, "127.0.0.0:0"}}, "127.0.0.0:0");
Carrier& carrier = Carrier::Instance();
Interceptor* a = carrier.SetInterceptor(
0, std::make_unique<PingPongInterceptor>(0, nullptr));
carrier.SetInterceptor(1, std::make_unique<PingPongInterceptor>(1, nullptr));
InterceptorMessage msg;
a->Send(1, msg);
}
} // namespace distributed
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册