未验证 提交 769e5bc4 编写于 作者: L LiYuRio 提交者: GitHub

[fleet_executor] Support multi carriers (#38709)

上级 7f3b0877
......@@ -19,6 +19,9 @@ cc_library(fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime
if(WITH_DISTRIBUTE)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0)
set(DISTRIBUTE_COMPILE_FLAGS "${DISTRIBUTE_COMPILE_FLAGS} -faligned-new")
endif()
set_source_files_properties(interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(compute_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(amplifier_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
......
......@@ -13,7 +13,7 @@
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global_map.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor_message_service.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
......@@ -71,17 +71,13 @@ Carrier::~Carrier() { VLOG(3) << "Carrier's destructor."; }
bool Carrier::EnqueueInterceptorMessage(
const InterceptorMessage& interceptor_message) {
if (interceptor_message.ctrl_message()) {
VLOG(3) << "Receiving control message from rank "
<< interceptor_message.src_id() << " to rank "
<< interceptor_message.dst_id();
// for barrier
msg_bus_->IncreaseBarrierCount();
} else {
PADDLE_ENFORCE_EQ(
interceptor_message.ctrl_message(), false,
platform::errors::Fatal(
"Control message should be only send inter rank using message bus."));
int64_t dst_id = interceptor_message.dst_id();
Interceptor* dst_interceptor = GetInterceptor(dst_id);
dst_interceptor->EnqueueRemoteInterceptorMessage(interceptor_message);
}
return true;
}
......@@ -106,11 +102,6 @@ void Carrier::WakeUp() {
}
void Carrier::Start() {
PADDLE_ENFORCE_EQ(msg_bus_->IsInit(), true,
platform::errors::PreconditionNotMet(
"Using message bus since it has not been initialized. "
"Please invoke MessageBus::Init() before using it or "
"neccessary components are not ready."));
PADDLE_ENFORCE_EQ(is_init_, true, platform::errors::PreconditionNotMet(
"Using carrier before initialized."));
for (int64_t id : source_interceptor_ids_) {
......@@ -154,19 +145,10 @@ bool Carrier::Send(const InterceptorMessage& msg) {
<< " to interceptor " << dst_id << ", which are in the same ranks.";
return EnqueueInterceptorMessage(msg);
} else {
PADDLE_ENFORCE_NOT_NULL(
msg_bus_.get(),
platform::errors::Unavailable("Message bus is released accidently"));
PADDLE_ENFORCE_EQ(
msg_bus_->IsInit(), true,
platform::errors::PreconditionNotMet(
"Using message bus since it has not been initialized. "
"Please invoke MessageBus::Init() before using it or "
"neccessary components are not ready."));
VLOG(3) << "Send a message from interceptor " << src_id
<< " to interceptor " << dst_id
<< ", which are in different ranks.";
return msg_bus_->Send(dst_rank, msg);
return GlobalVal<MessageBus>::Get()->Send(dst_rank, msg);
}
}
......
......@@ -73,10 +73,6 @@ class Carrier final {
Interceptor* SetInterceptor(int64_t interceptor_id,
std::unique_ptr<Interceptor>);
void SetMsgBus(const std::shared_ptr<MessageBus>& msg_bus) {
msg_bus_ = msg_bus;
}
void Start();
bool IsInit() const;
......@@ -107,7 +103,6 @@ class Carrier final {
framework::Scope* minibatch_scope_;
paddle::platform::Place place_;
paddle::platform::DeviceContext* dev_ctx_{nullptr};
std::shared_ptr<MessageBus> msg_bus_;
int64_t rank_;
std::string carrier_id_;
std::unordered_map<int64_t, TaskNode*> interceptor_id_to_node_;
......
......@@ -13,7 +13,7 @@
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/global_map.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/runtime_graph.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
......@@ -32,6 +32,9 @@ FleetExecutor::FleetExecutor(const std::string& exe_desc_str) {
bool parse_flag = exe_desc_.ParseFromString(exe_desc_str);
PADDLE_ENFORCE(parse_flag, platform::errors::PreconditionNotMet(
"Error occurs while parsing string to proto"));
// Message bus will be created and inited only once
GlobalVal<MessageBus>::Create();
InitMessageBus();
}
FleetExecutor::~FleetExecutor() {
......@@ -81,21 +84,16 @@ void FleetExecutor::Init(
CopyParameters(i, program_desc);
}
VLOG(5) << runtime_graph_->DebugString();
msg_bus_ = std::make_shared<MessageBus>();
Carrier* carrier =
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
carrier_ids_.insert(carrier_id);
GlobalVal<std::string>::Set(carrier_id);
// TODO(liyurui): Maybe message bus should be created only once
// Set current running carrier
GlobalVal<std::string>::Set(new std::string(carrier_id));
InitCarrier(carrier);
InitMessageBus();
// Wait for all message bus connected.
msg_bus_->Barrier();
GlobalVal<MessageBus>::Get()->Barrier();
}
void FleetExecutor::InitCarrier(Carrier* carrier) {
carrier->SetMsgBus(msg_bus_);
carrier->Init(exe_desc_.cur_rank(), runtime_graph_->interceptor_id_to_rank(),
runtime_graph_->interceptor_id_to_node(), root_scope_,
minibatch_scope_, microbatch_scopes_, place_);
......@@ -131,14 +129,18 @@ void FleetExecutor::InitMessageBus() {
VLOG(3) << "The number of ranks are "
<< (rank_to_addr.size() == 0 ? 1 : rank_to_addr.size()) << ".";
VLOG(5) << ss.str();
if (!msg_bus_->IsInit()) {
msg_bus_->Init(cur_rank, rank_to_addr, addr);
}
GlobalVal<MessageBus>::Get()->Init(cur_rank, rank_to_addr, addr);
}
void FleetExecutor::Run(const std::string& carrier_id) {
GlobalMap<std::string, Carrier>::Get(carrier_id)->Start();
GlobalVal<std::string>::Set(carrier_id);
Carrier* carrier = GlobalMap<std::string, Carrier>::Get(carrier_id);
// Set current running carrier
if (*GlobalVal<std::string>::Get() != carrier_id) {
GlobalVal<std::string>::Set(new std::string(carrier_id));
// TODO(liyurui): Move barrier to service
GlobalVal<MessageBus>::Get()->Barrier();
}
carrier->Start();
for (auto* micro_scop : microbatch_scopes_) {
// By default, we should delete all kid scopes after run executor because
// some operators may create local scope when running, such as while_op.
......
......@@ -55,9 +55,6 @@ class FleetExecutor final {
framework::Scope* minibatch_scope_;
platform::Place place_;
std::vector<framework::Scope*> microbatch_scopes_;
// The carriers under FleetExecutor will share message bus,
// using shared_ptr to manage lifetime and condition race.
std::shared_ptr<MessageBus> msg_bus_;
std::unordered_set<std::string> carrier_ids_;
};
......
......@@ -14,24 +14,41 @@
#pragma once
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace distributed {
// TODO(liyurui): Change this file to global.h
template <typename T>
class GlobalVal final {
public:
static T Get() { return *GetPtr(); }
static T Set(T val) {
auto* ptr = GetPtr();
*ptr = val;
return val;
static T* Get() {
T* ptr = GetPPtr()->get();
PADDLE_ENFORCE_NOT_NULL(
ptr, platform::errors::NotFound("This value is not global value."));
return ptr;
}
template <typename... Args>
static T* Create(Args&&... args) {
auto* ptr = GetPPtr();
PADDLE_ENFORCE_EQ(ptr->get(), nullptr,
platform::errors::AlreadyExists(
"This value is already a global value."));
T* item = new T(std::forward<Args>(args)...);
ptr->reset(item);
return item;
}
static T* Set(T* new_item) {
auto* ptr = GetPPtr();
ptr->reset(new_item);
return ptr->get();
}
private:
static T* GetPtr() {
static T value;
return &value;
static std::unique_ptr<T>* GetPPtr() {
static std::unique_ptr<T> ptr;
return &ptr;
}
};
......
......@@ -15,8 +15,8 @@
!defined(PADDLE_WITH_ASCEND_CL)
#include "paddle/fluid/distributed/fleet_executor/interceptor_message_service.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global_map.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
namespace paddle {
namespace distributed {
......@@ -29,9 +29,7 @@ void InterceptorMessageServiceImpl::InterceptorMessageService(
VLOG(3) << "Interceptor Message Service receives a message from interceptor "
<< request->src_id() << " to interceptor " << request->dst_id()
<< ", with the message: " << request->message_type();
const auto& carrier_id = GlobalVal<std::string>::Get();
bool flag = GlobalMap<std::string, Carrier>::Get(carrier_id)
->EnqueueInterceptorMessage(*request);
bool flag = GlobalVal<MessageBus>::Get()->DispatchMsgToCarrier(*request);
response->set_rst(flag);
}
......
......@@ -17,6 +17,8 @@
#include <set>
#include <thread>
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/platform/gen_comm_id_helper.h"
......@@ -81,6 +83,10 @@ const std::string& MessageBus::GetAddr(int64_t rank) const {
bool MessageBus::Send(int64_t dst_rank,
const InterceptorMessage& interceptor_message) {
PADDLE_ENFORCE_EQ(
IsInit(), true,
platform::errors::PreconditionNotMet(
"Using message bus since it has not been initialized."));
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
!defined(PADDLE_WITH_ASCEND_CL)
int retry_time = 0; // message bus will retry sending for 10 times
......@@ -155,6 +161,22 @@ void MessageBus::Barrier() {
}
}
bool MessageBus::DispatchMsgToCarrier(
const InterceptorMessage& interceptor_message) {
if (interceptor_message.ctrl_message()) {
VLOG(3) << "Receiving control message from rank "
<< interceptor_message.src_id() << " to rank "
<< interceptor_message.dst_id();
// for barrier
IncreaseBarrierCount();
return true;
} else {
const std::string& carrier_id = *GlobalVal<std::string>::Get();
return GlobalMap<std::string, Carrier>::Get(carrier_id)
->EnqueueInterceptorMessage(interceptor_message);
}
}
void MessageBus::ListenPort() {
if (addr_ == "") {
LOG(INFO) << "No need listen to port since training on single card.";
......
......@@ -54,6 +54,7 @@ class MessageBus final {
void IncreaseBarrierCount();
void Barrier();
bool DispatchMsgToCarrier(const InterceptorMessage& interceptor_message);
private:
DISABLE_COPY_AND_ASSIGN(MessageBus);
......
......@@ -18,7 +18,7 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global_map.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
......@@ -67,9 +67,8 @@ TEST(ComputeInterceptor, Compute) {
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
carrier->Init(0, {{0, 0}, {1, 0}});
auto msg_bus = std::make_shared<MessageBus>();
MessageBus* msg_bus = GlobalVal<MessageBus>::Create();
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");
carrier->SetMsgBus(msg_bus);
// FIXME: don't delete, otherwise interceptor will use undefined node
TaskNode* node_a =
......
......@@ -18,7 +18,7 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global_map.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
......@@ -52,9 +52,8 @@ TEST(ComputeInterceptor, Compute) {
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
carrier->Init(0, {{0, 0}, {1, 0}, {2, 0}});
auto msg_bus = std::make_shared<MessageBus>();
MessageBus* msg_bus = GlobalVal<MessageBus>::Create();
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");
carrier->SetMsgBus(msg_bus);
// NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0); // role, rank, task_id
......
......@@ -18,7 +18,7 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global_map.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
......@@ -64,9 +64,8 @@ TEST(InterceptorTest, PingPong) {
Carrier* carrier =
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
carrier->Init(0, {{0, 0}, {1, 0}});
auto msg_bus = std::make_shared<MessageBus>();
MessageBus* msg_bus = GlobalVal<MessageBus>::Create();
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");
carrier->SetMsgBus(msg_bus);
Interceptor* a = carrier->SetInterceptor(
0, InterceptorFactory::Create("PingPong", 0, nullptr));
......
......@@ -20,7 +20,7 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global_map.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
......@@ -112,12 +112,10 @@ TEST(InterceptorTest, PingPong) {
if (pid == 0) {
Carrier* carrier =
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
GlobalVal<std::string>::Set(carrier_id);
auto msg_bus = std::make_shared<MessageBus>();
carrier->SetMsgBus(msg_bus);
// NOTE: need Init msg_bus after carrier SetMsgBus
carrier->Init(0, interceptor_id_to_rank);
GlobalVal<std::string>::Set(new std::string(carrier_id));
MessageBus* msg_bus = GlobalVal<MessageBus>::Create();
msg_bus->Init(0, {{0, ip0}, {1, ip1}}, ip0);
carrier->Init(0, interceptor_id_to_rank);
Interceptor* a = carrier->SetInterceptor(
0, InterceptorFactory::Create("PingPong", 0, nullptr));
msg_bus->Barrier();
......@@ -127,11 +125,10 @@ TEST(InterceptorTest, PingPong) {
} else {
Carrier* carrier =
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
GlobalVal<std::string>::Set(carrier_id);
auto msg_bus = std::make_shared<MessageBus>();
carrier->SetMsgBus(msg_bus);
carrier->Init(1, interceptor_id_to_rank);
GlobalVal<std::string>::Set(new std::string(carrier_id));
MessageBus* msg_bus = GlobalVal<MessageBus>::Create();
msg_bus->Init(1, {{0, ip0}, {1, ip1}}, ip1);
carrier->Init(1, interceptor_id_to_rank);
carrier->SetInterceptor(1,
InterceptorFactory::Create("PingPong", 1, nullptr));
msg_bus->Barrier();
......
......@@ -18,7 +18,7 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global_map.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
......@@ -56,9 +56,8 @@ TEST(AmplifierInterceptor, Amplifier) {
Carrier* carrier =
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
carrier->Init(0, {{0, 0}, {1, 0}, {2, 0}, {3, 0}, {4, 0}, {5, 0}});
auto msg_bus = std::make_shared<MessageBus>();
MessageBus* msg_bus = GlobalVal<MessageBus>::Create();
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "127.0.0.0:0");
carrier->SetMsgBus(msg_bus);
int64_t micro_steps = 3;
......
......@@ -18,7 +18,7 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global_map.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
......@@ -74,9 +74,8 @@ TEST(AmplifierInterceptor, Amplifier) {
Carrier* carrier =
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
carrier->Init(0, {{0, 0}, {1, 0}, {2, 0}, {3, 0}});
auto msg_bus = std::make_shared<MessageBus>();
MessageBus* msg_bus = GlobalVal<MessageBus>::Create();
msg_bus->Init(0, {{0, ""}}, "");
carrier->SetMsgBus(msg_bus);
int64_t micro_steps = 6;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册