未验证 提交 742378f4 编写于 作者: W WangXi 提交者: GitHub

[fleet_executor] Add interceptor ping pong test (#37143)

上级 63c8c8c2
......@@ -5,7 +5,7 @@ endif()
proto_library(interceptor_message_proto SRCS interceptor_message.proto)
if(WITH_DISTRIBUTE AND NOT (WITH_ASCEND OR WITH_ASCEND_CL))
set(BRPC_DEPS brpc ssl crypto)
set(BRPC_DEPS brpc ssl crypto protobuf gflags glog zlib leveldb snappy gflags glog)
else()
set(BRPC_DEPS "")
endif()
......@@ -23,4 +23,6 @@ if(WITH_DISTRIBUTE)
set_source_files_properties(carrier.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(interceptor_message_service.h PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(interceptor_message_service.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
add_subdirectory(test)
endif()
......@@ -20,9 +20,9 @@
namespace paddle {
namespace distributed {
Carrier::Carrier(
const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node)
: interceptor_id_to_node_(interceptor_id_to_node) {
void Carrier::Init(
const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node) {
interceptor_id_to_node_ = interceptor_id_to_node;
CreateInterceptors();
}
......@@ -56,20 +56,29 @@ Interceptor* Carrier::GetInterceptor(int64_t interceptor_id) {
return iter->second.get();
}
Interceptor* Carrier::SetInterceptor(int64_t interceptor_id,
std::unique_ptr<Interceptor> interceptor) {
auto iter = interceptor_idx_to_interceptor_.find(interceptor_id);
PADDLE_ENFORCE_EQ(iter, interceptor_idx_to_interceptor_.end(),
platform::errors::AlreadyExists(
"The interceptor id %lld has already been created! "
"The interceptor id should be unique.",
interceptor_id));
auto* ptr = interceptor.get();
interceptor_idx_to_interceptor_.insert(
std::make_pair(interceptor_id, std::move(interceptor)));
return ptr;
}
void Carrier::CreateInterceptors() {
// create each Interceptor
for (const auto& item : interceptor_id_to_node_) {
int64_t interceptor_id = item.first;
TaskNode* task_node = item.second;
const auto& iter = interceptor_idx_to_interceptor_.find(interceptor_id);
PADDLE_ENFORCE_EQ(iter, interceptor_idx_to_interceptor_.end(),
platform::errors::AlreadyExists(
"The interceptor id %lld has already been created! "
"The interceptor is should be unique.",
interceptor_id));
interceptor_idx_to_interceptor_.insert(std::make_pair(
interceptor_id,
std::make_unique<Interceptor>(interceptor_id, task_node)));
// TODO(wangxi): use node_type to select different Interceptor
auto interceptor = std::make_unique<Interceptor>(interceptor_id, task_node);
SetInterceptor(interceptor_id, std::move(interceptor));
VLOG(3) << "Create Interceptor for " << interceptor_id;
}
}
......
......@@ -18,6 +18,7 @@
#include <string>
#include <unordered_map>
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
......@@ -26,15 +27,18 @@
namespace paddle {
namespace distributed {
class Interceptor;
class TaskNode;
class InterceptorMessageServiceImpl;
// A singleton MessageBus
class Carrier final {
public:
Carrier() = delete;
static Carrier& Instance() {
static Carrier carrier;
return carrier;
}
explicit Carrier(
void Init(
const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node);
~Carrier() = default;
......@@ -42,15 +46,21 @@ class Carrier final {
// Enqueue a message to corresponding interceptor id
bool EnqueueInterceptorMessage(const InterceptorMessage& interceptor_message);
// get interceptor based on the interceptor id
Interceptor* GetInterceptor(int64_t interceptor_id);
// set interceptor with interceptor id
Interceptor* SetInterceptor(int64_t interceptor_id,
std::unique_ptr<Interceptor>);
DISABLE_COPY_AND_ASSIGN(Carrier);
private:
Carrier() = default;
// create each Interceptor
void CreateInterceptors();
// get interceptor based on the interceptor id
Interceptor* GetInterceptor(int64_t interceptor_id);
// interceptor logic id to the Nodes info
std::unordered_map<int64_t, TaskNode*> interceptor_id_to_node_;
......
......@@ -76,10 +76,5 @@ void FleetExecutor::Release() {
// Release
}
std::shared_ptr<Carrier> FleetExecutor::GetCarrier() {
// get carrier
return nullptr;
}
} // namespace distributed
} // namespace paddle
......@@ -43,7 +43,6 @@ class FleetExecutor final {
FleetExecutorDesc exe_desc_;
std::unique_ptr<RuntimeGraph> runtime_graph_;
void InitMessageBus();
static std::shared_ptr<Carrier> global_carrier_;
};
} // namespace distributed
......
......@@ -56,11 +56,10 @@ bool Interceptor::EnqueueRemoteInterceptorMessage(
return true;
}
void Interceptor::Send(int64_t dst_id,
std::unique_ptr<InterceptorMessage> msg) {
msg->set_src_id(interceptor_id_);
msg->set_dst_id(dst_id);
MessageBus::Instance().Send(*msg.get());
void Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) {
msg.set_src_id(interceptor_id_);
msg.set_dst_id(dst_id);
MessageBus::Instance().Send(msg);
}
void Interceptor::PoolTheMailbox() {
......
......@@ -58,7 +58,7 @@ class Interceptor {
bool EnqueueRemoteInterceptorMessage(
const InterceptorMessage& interceptor_message);
void Send(int64_t dst_id, std::unique_ptr<InterceptorMessage> msg);
void Send(int64_t dst_id, InterceptorMessage& msg); // NOLINT
DISABLE_COPY_AND_ASSIGN(Interceptor);
......
......@@ -31,10 +31,7 @@ void InterceptorMessageServiceImpl::InterceptorMessageService(
<< ", with the message: " << request->message_type();
response->set_rst(true);
// call interceptor manager's method to handle the message
std::shared_ptr<Carrier> carrier = FleetExecutor::GetCarrier();
if (carrier != nullptr) {
carrier->EnqueueInterceptorMessage(*request);
}
Carrier::Instance().EnqueueInterceptorMessage(*request);
}
} // namespace distributed
......
......@@ -32,10 +32,7 @@ void MessageBus::Init(
rank_to_addr_ = rank_to_addr;
addr_ = addr;
listen_port_thread_ = std::thread([this]() {
VLOG(3) << "Start listen_port_thread_ for message bus";
ListenPort();
});
ListenPort();
std::call_once(once_flag_, []() {
std::atexit([]() { MessageBus::Instance().Release(); });
......@@ -51,7 +48,6 @@ void MessageBus::Release() {
server_.Stop(1000);
server_.Join();
#endif
listen_port_thread_.join();
}
bool MessageBus::Send(const InterceptorMessage& interceptor_message) {
......@@ -184,11 +180,7 @@ bool MessageBus::SendInterRank(const InterceptorMessage& interceptor_message) {
bool MessageBus::SendIntraRank(const InterceptorMessage& interceptor_message) {
// send the message intra rank (dst is the same rank with src)
std::shared_ptr<Carrier> carrier = FleetExecutor::GetCarrier();
if (carrier != nullptr) {
return carrier->EnqueueInterceptorMessage(interceptor_message);
}
return true;
return Carrier::Instance().EnqueueInterceptorMessage(interceptor_message);
}
} // namespace distributed
......
......@@ -92,10 +92,6 @@ class MessageBus final {
// brpc server
brpc::Server server_;
#endif
// thread keeps listening to the port to receive remote message
// this thread runs ListenPort() function
std::thread listen_port_thread_;
};
} // namespace distributed
......
set_source_files_properties(interceptor_ping_pong_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(interceptor_ping_pong_test SRCS interceptor_ping_pong_test.cc DEPS fleet_executor ${BRPC_DEPS})
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include <unordered_map>
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
namespace paddle {
namespace distributed {
class PingPongInterceptor : public Interceptor {
public:
PingPongInterceptor(int64_t interceptor_id, TaskNode* node)
: Interceptor(interceptor_id, node) {
RegisterMsgHandle([this](const InterceptorMessage& msg) { PingPong(msg); });
}
void PingPong(const InterceptorMessage& msg) {
std::cout << GetInterceptorId() << " recv msg, count=" << count_
<< std::endl;
++count_;
if (count_ == 20) {
InterceptorMessage stop;
stop.set_message_type(STOP);
Send(0, stop);
Send(1, stop);
return;
}
InterceptorMessage resp;
Send(msg.src_id(), resp);
}
private:
int count_{0};
};
TEST(InterceptorTest, PingPong) {
MessageBus& msg_bus = MessageBus::Instance();
msg_bus.Init({{0, 0}, {1, 0}}, {{0, "127.0.0.0:0"}}, "127.0.0.0:0");
Carrier& carrier = Carrier::Instance();
Interceptor* a = carrier.SetInterceptor(
0, std::make_unique<PingPongInterceptor>(0, nullptr));
carrier.SetInterceptor(1, std::make_unique<PingPongInterceptor>(1, nullptr));
InterceptorMessage msg;
a->Send(1, msg);
}
} // namespace distributed
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册