未验证 提交 0adc2006 编写于 作者: Y Yuang Liu 提交者: GitHub

[fleet_executor] auto STOP msg and auto notify carrier (#37742)

上级 79095918
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor_message_service.h" #include "paddle/fluid/distributed/fleet_executor/interceptor_message_service.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/runtime_graph.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h" #include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
...@@ -24,14 +25,14 @@ namespace distributed { ...@@ -24,14 +25,14 @@ namespace distributed {
USE_INTERCEPTOR(Compute); USE_INTERCEPTOR(Compute);
void Carrier::Init( void Carrier::Init(std::shared_ptr<RuntimeGraph> runtime_graph,
const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node, framework::Scope* root_scope,
framework::Scope* root_scope, framework::Scope* minibatch_scope, framework::Scope* minibatch_scope,
const std::vector<framework::Scope*>& microbatch_scopes, const std::vector<framework::Scope*>& microbatch_scopes,
const platform::Place& place) { const platform::Place& place) {
PADDLE_ENFORCE_EQ(is_init_, false, platform::errors::AlreadyExists( PADDLE_ENFORCE_EQ(is_init_, false, platform::errors::AlreadyExists(
"Carrier is already init.")); "Carrier is already init."));
interceptor_id_to_node_ = interceptor_id_to_node; runtime_graph_ = runtime_graph;
minibatch_scope_ = minibatch_scope; minibatch_scope_ = minibatch_scope;
microbatch_scopes_ = microbatch_scopes; microbatch_scopes_ = microbatch_scopes;
place_ = place; place_ = place;
...@@ -41,15 +42,34 @@ void Carrier::Init( ...@@ -41,15 +42,34 @@ void Carrier::Init(
is_init_ = true; is_init_ = true;
} }
Carrier::~Carrier() { void Carrier::Release() {
// NOTE(wangxi): must join before `Derived Interceptor` destruct, // NOTE(wangxi): must join before `Derived Interceptor` destruct,
// otherwise Derived object will be destructed before thread complete. // otherwise Derived object will be destructed before thread complete.
// Sending STOP msg to the source interceptor
MessageBus& msg_bus = MessageBus::Instance();
PADDLE_ENFORCE_EQ(msg_bus.IsInit(), true,
platform::errors::PreconditionNotMet(
"Message bus has not been initialized."));
for (int64_t id : source_interceptor_ids_) {
VLOG(3) << "Carrier Release is sending stop to source interceptor " << id
<< ".";
InterceptorMessage stop_msg;
// source node STOP is send by carrier, so set src_id=-1
stop_msg.set_src_id(-1);
stop_msg.set_dst_id(id);
stop_msg.set_message_type(STOP);
msg_bus.Send(stop_msg);
}
// TODO(wangxi): Maybe need a better to use thread. // TODO(wangxi): Maybe need a better to use thread.
for (auto& interceptor : interceptor_idx_to_interceptor_) { for (auto& interceptor : interceptor_idx_to_interceptor_) {
interceptor.second->Join(); interceptor.second->Join();
} }
} }
Carrier::~Carrier() { VLOG(3) << "Carrier's destructor."; }
bool Carrier::EnqueueInterceptorMessage( bool Carrier::EnqueueInterceptorMessage(
const InterceptorMessage& interceptor_message) { const InterceptorMessage& interceptor_message) {
// enqueue message to interceptor // enqueue message to interceptor
...@@ -139,6 +159,17 @@ void Carrier::SetCreatingFlag(bool flag) { ...@@ -139,6 +159,17 @@ void Carrier::SetCreatingFlag(bool flag) {
creating_interceptors_ = flag; creating_interceptors_ = flag;
creating_flag_mutex_.unlock(); creating_flag_mutex_.unlock();
if (!flag) { if (!flag) {
for (auto& pair : interceptor_idx_to_interceptor_) {
// update the source interceptor id
if (std::find(source_interceptor_ids_.begin(),
source_interceptor_ids_.end(),
pair.first) == source_interceptor_ids_.end()) {
auto task = pair.second->GetTaskNode();
if (task != nullptr && task->upstream().empty()) {
source_interceptor_ids_.emplace_back(pair.first);
}
}
}
// finish create interceptors outside, handle tmp messsages // finish create interceptors outside, handle tmp messsages
HandleTmpMessages(); HandleTmpMessages();
} }
...@@ -161,9 +192,9 @@ void Carrier::HandleTmpMessages() { ...@@ -161,9 +192,9 @@ void Carrier::HandleTmpMessages() {
void Carrier::CreateInterceptors() { void Carrier::CreateInterceptors() {
// create each Interceptor // create each Interceptor
if (!interceptor_id_to_node_.empty()) { if (!(runtime_graph_->intercepter_id_to_node().empty())) {
// no auto init since there is no config // no auto init since there is no config
for (const auto& item : interceptor_id_to_node_) { for (const auto& item : runtime_graph_->intercepter_id_to_node()) {
int64_t interceptor_id = item.first; int64_t interceptor_id = item.first;
TaskNode* task_node = item.second; TaskNode* task_node = item.second;
......
...@@ -39,6 +39,7 @@ namespace distributed { ...@@ -39,6 +39,7 @@ namespace distributed {
class TaskNode; class TaskNode;
class InterceptorMessageServiceImpl; class InterceptorMessageServiceImpl;
class RuntimeGraph;
// A singleton MessageBus // A singleton MessageBus
class Carrier final { class Carrier final {
...@@ -48,13 +49,13 @@ class Carrier final { ...@@ -48,13 +49,13 @@ class Carrier final {
return carrier; return carrier;
} }
void Init( void Init(std::shared_ptr<RuntimeGraph> runtime_graph,
const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node, framework::Scope* root_scope, framework::Scope* minibatch_scope,
framework::Scope* root_scope, framework::Scope* minibatch_scope, const std::vector<framework::Scope*>& microbatch_scopes,
const std::vector<framework::Scope*>& microbatch_scopes, const platform::Place& place);
const platform::Place& place);
~Carrier(); ~Carrier();
void Release();
// Enqueue a message to corresponding interceptor id // Enqueue a message to corresponding interceptor id
bool EnqueueInterceptorMessage(const InterceptorMessage& interceptor_message); bool EnqueueInterceptorMessage(const InterceptorMessage& interceptor_message);
...@@ -84,9 +85,6 @@ class Carrier final { ...@@ -84,9 +85,6 @@ class Carrier final {
void HandleTmpMessages(); void HandleTmpMessages();
// interceptor logic id to the Nodes info
std::unordered_map<int64_t, TaskNode*> interceptor_id_to_node_;
// interceptor logic id to actually interceptor // interceptor logic id to actually interceptor
std::unordered_map<int64_t, std::unique_ptr<Interceptor>> std::unordered_map<int64_t, std::unique_ptr<Interceptor>>
interceptor_idx_to_interceptor_; interceptor_idx_to_interceptor_;
...@@ -105,7 +103,8 @@ class Carrier final { ...@@ -105,7 +103,8 @@ class Carrier final {
framework::Scope* root_scope_; framework::Scope* root_scope_;
framework::Scope* minibatch_scope_; framework::Scope* minibatch_scope_;
paddle::platform::Place place_; paddle::platform::Place place_;
paddle::platform::DeviceContext* dev_ctx_ = nullptr; paddle::platform::DeviceContext* dev_ctx_{nullptr};
std::shared_ptr<RuntimeGraph> runtime_graph_;
}; };
} // namespace distributed } // namespace distributed
......
...@@ -51,6 +51,11 @@ void ComputeInterceptor::PrepareDeps() { ...@@ -51,6 +51,11 @@ void ComputeInterceptor::PrepareDeps() {
"times, but now max_run_times=%ld", "times, but now max_run_times=%ld",
node_->max_run_times())); node_->max_run_times()));
} }
// If there is no downstream or every downstream is in different rank,
// then this interceptor is the last one for current rank.
// This can be get during init, can be cached for later use.
is_last_ = downstream.empty();
} }
void ComputeInterceptor::IncreaseReady(int64_t up_id) { void ComputeInterceptor::IncreaseReady(int64_t up_id) {
...@@ -129,7 +134,8 @@ void ComputeInterceptor::SendDataReadyToDownStream() { ...@@ -129,7 +134,8 @@ void ComputeInterceptor::SendDataReadyToDownStream() {
InterceptorMessage ready_msg; InterceptorMessage ready_msg;
ready_msg.set_message_type(DATA_IS_READY); ready_msg.set_message_type(DATA_IS_READY);
VLOG(3) << "ComputeInterceptor Send data_is_ready msg to " << down_id; VLOG(3) << "ComputeInterceptor " << interceptor_id_
<< " Send data_is_ready msg to " << down_id;
Send(down_id, ready_msg); Send(down_id, ready_msg);
} }
} }
...@@ -148,7 +154,8 @@ void ComputeInterceptor::ReplyCompletedToUpStream() { ...@@ -148,7 +154,8 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
InterceptorMessage reply_msg; InterceptorMessage reply_msg;
reply_msg.set_message_type(DATE_IS_USELESS); reply_msg.set_message_type(DATE_IS_USELESS);
VLOG(3) << "ComputeInterceptor Reply data_is_useless msg to " << up_id; VLOG(3) << "ComputeInterceptor " << interceptor_id_
<< " Reply data_is_useless msg to " << up_id;
Send(up_id, reply_msg); Send(up_id, reply_msg);
} }
} }
...@@ -159,7 +166,7 @@ void ComputeInterceptor::Run() { ...@@ -159,7 +166,7 @@ void ComputeInterceptor::Run() {
// step_ %= node_->max_run_times(); // step_ %= node_->max_run_times();
for (auto op : node_->ops()) { for (auto op : node_->ops()) {
auto* scope = microbatch_scopes_[step_ % node_->max_slot_nums()]; auto* scope = microbatch_scopes_[step_ % node_->max_run_times()];
op->Run(*scope, place_); op->Run(*scope, place_);
} }
++step_; ++step_;
...@@ -168,6 +175,10 @@ void ComputeInterceptor::Run() { ...@@ -168,6 +175,10 @@ void ComputeInterceptor::Run() {
SendDataReadyToDownStream(); SendDataReadyToDownStream();
// reply to upstream and decrease ready data // reply to upstream and decrease ready data
ReplyCompletedToUpStream(); ReplyCompletedToUpStream();
// Try to stop Carrier
if (step_ % node_->max_run_times() == 0 && is_last_) {
StopCarrier();
}
} }
// If there is no limit, source interceptor can be executed // If there is no limit, source interceptor can be executed
...@@ -221,11 +232,6 @@ void ComputeInterceptor::TryStop() { ...@@ -221,11 +232,6 @@ void ComputeInterceptor::TryStop() {
Send(down_id, stop); Send(down_id, stop);
} }
stop_ = true; stop_ = true;
if (out_buffs_.size() == 0) {
// TODO(fleet executor dev) need a better place to notify
StopCarrier();
}
} }
void ComputeInterceptor::Compute(const InterceptorMessage& msg) { void ComputeInterceptor::Compute(const InterceptorMessage& msg) {
......
...@@ -44,6 +44,7 @@ class ComputeInterceptor : public Interceptor { ...@@ -44,6 +44,7 @@ class ComputeInterceptor : public Interceptor {
private: private:
bool is_source_{false}; bool is_source_{false};
bool is_last_{false};
int64_t step_{0}; int64_t step_{0};
// upstream_id-->(max_ready_size, ready_size) // upstream_id-->(max_ready_size, ready_size)
......
...@@ -38,7 +38,7 @@ FleetExecutor::~FleetExecutor() { ...@@ -38,7 +38,7 @@ FleetExecutor::~FleetExecutor() {
void FleetExecutor::Init(const framework::ProgramDesc& program_desc, void FleetExecutor::Init(const framework::ProgramDesc& program_desc,
framework::Scope* scope, framework::Scope* scope,
const platform::Place& place) { const platform::Place& place) {
runtime_graph_ = std::make_unique<RuntimeGraph>(program_desc, exe_desc_); runtime_graph_ = std::make_shared<RuntimeGraph>(program_desc, exe_desc_);
root_scope_ = scope; root_scope_ = scope;
place_ = place; place_ = place;
PADDLE_ENFORCE_NOT_NULL(root_scope_, platform::errors::InvalidArgument( PADDLE_ENFORCE_NOT_NULL(root_scope_, platform::errors::InvalidArgument(
...@@ -58,8 +58,8 @@ void FleetExecutor::Init(const framework::ProgramDesc& program_desc, ...@@ -58,8 +58,8 @@ void FleetExecutor::Init(const framework::ProgramDesc& program_desc,
void FleetExecutor::InitCarrier() { void FleetExecutor::InitCarrier() {
Carrier& carrier_instance = Carrier::Instance(); Carrier& carrier_instance = Carrier::Instance();
if (!carrier_instance.IsInit()) { if (!carrier_instance.IsInit()) {
carrier_instance.Init(runtime_graph_->intercepter_id_to_node(), root_scope_, carrier_instance.Init(runtime_graph_, root_scope_, minibatch_scope_,
minibatch_scope_, microbatch_scopes_, place_); microbatch_scopes_, place_);
} }
} }
......
...@@ -47,7 +47,7 @@ class FleetExecutor final { ...@@ -47,7 +47,7 @@ class FleetExecutor final {
void InitCarrier(); void InitCarrier();
void CopyParameters(int microbatch_id, const framework::ProgramDesc& program); void CopyParameters(int microbatch_id, const framework::ProgramDesc& program);
FleetExecutorDesc exe_desc_; FleetExecutorDesc exe_desc_;
std::unique_ptr<RuntimeGraph> runtime_graph_; std::shared_ptr<RuntimeGraph> runtime_graph_;
framework::Scope* root_scope_; framework::Scope* root_scope_;
framework::Scope* minibatch_scope_; framework::Scope* minibatch_scope_;
platform::Place place_; platform::Place place_;
......
...@@ -46,7 +46,6 @@ void Interceptor::Handle(const InterceptorMessage& msg) { ...@@ -46,7 +46,6 @@ void Interceptor::Handle(const InterceptorMessage& msg) {
VLOG(3) << "Interceptor is using default message handler. This handler is " VLOG(3) << "Interceptor is using default message handler. This handler is "
"only used for test purpose. Check whether you init interceptor " "only used for test purpose. Check whether you init interceptor "
"in the proper way."; "in the proper way.";
if (msg.message_type() == DATA_IS_READY) { if (msg.message_type() == DATA_IS_READY) {
if (node_->role() != 2) { if (node_->role() != 2) {
VLOG(3) << "Fake handler is sending DATA_IS_READY message to: " VLOG(3) << "Fake handler is sending DATA_IS_READY message to: "
...@@ -54,14 +53,19 @@ void Interceptor::Handle(const InterceptorMessage& msg) { ...@@ -54,14 +53,19 @@ void Interceptor::Handle(const InterceptorMessage& msg) {
InterceptorMessage data_is_ready_msg; InterceptorMessage data_is_ready_msg;
data_is_ready_msg.set_message_type(DATA_IS_READY); data_is_ready_msg.set_message_type(DATA_IS_READY);
Send(interceptor_id_ + 1, data_is_ready_msg); Send(interceptor_id_ + 1, data_is_ready_msg);
} else {
// NOTE: max run time is reach for last interceptor
StopCarrier();
} }
VLOG(3) << "Fake handler is sending stop message to it self.";
InterceptorMessage stop_msg;
stop_msg.set_message_type(STOP);
Send(interceptor_id_, stop_msg);
} else if (msg.message_type() == STOP) { } else if (msg.message_type() == STOP) {
stop_ = true; stop_ = true;
StopCarrier(); if (node_->role() != 2) {
VLOG(3) << "Fake handler is sending STOP message to: "
<< interceptor_id_ + 1 << ".";
InterceptorMessage stop_msg;
stop_msg.set_message_type(STOP);
Send(interceptor_id_ + 1, stop_msg);
}
} }
} }
} }
......
...@@ -57,6 +57,10 @@ bool MessageBus::IsInit() const { return is_init_; } ...@@ -57,6 +57,10 @@ bool MessageBus::IsInit() const { return is_init_; }
MessageBus::~MessageBus() { MessageBus::~MessageBus() {
VLOG(3) << "Message bus releases resource."; VLOG(3) << "Message bus releases resource.";
// NOTE: fleet_executor inits carrier before message bus,
// therefore the message bus's destructor will be called first
Carrier& carrier = Carrier::Instance();
carrier.Release();
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
!defined(PADDLE_WITH_ASCEND_CL) !defined(PADDLE_WITH_ASCEND_CL)
server_.Stop(1000); server_.Stop(1000);
......
...@@ -61,15 +61,15 @@ TEST(ComputeInterceptor, Compute) { ...@@ -61,15 +61,15 @@ TEST(ComputeInterceptor, Compute) {
std::vector<framework::Scope*> scopes = {scope, scope}; std::vector<framework::Scope*> scopes = {scope, scope};
platform::Place place = platform::CPUPlace(); platform::Place place = platform::CPUPlace();
Carrier& carrier = Carrier::Instance();
MessageBus& msg_bus = MessageBus::Instance(); MessageBus& msg_bus = MessageBus::Instance();
msg_bus.Init({{0, 0}, {1, 0}}, {{0, "127.0.0.0:0"}}, "127.0.0.0:0"); msg_bus.Init({{0, 0}, {1, 0}}, {{0, "127.0.0.0:0"}}, "127.0.0.0:0");
Carrier& carrier = Carrier::Instance();
// FIXME: don't delete, otherwise interceptor will use undefined node // FIXME: don't delete, otherwise interceptor will use undefined node
TaskNode* node_a = TaskNode* node_a =
new TaskNode(0, ops, 0, 0, 2, 2); // role, ops, rank, task_id new TaskNode(0, ops, 0, 0, 2, 0); // role, ops, rank, task_id
TaskNode* node_b = new TaskNode(0, 0, 1, 0, 0); TaskNode* node_b = new TaskNode(0, 0, 1, 2, 0);
// a->b // a->b
node_a->AddDownstreamTask(1); node_a->AddDownstreamTask(1);
...@@ -90,13 +90,6 @@ TEST(ComputeInterceptor, Compute) { ...@@ -90,13 +90,6 @@ TEST(ComputeInterceptor, Compute) {
msg.set_src_id(-1); msg.set_src_id(-1);
msg.set_dst_id(0); msg.set_dst_id(0);
carrier.EnqueueInterceptorMessage(msg); carrier.EnqueueInterceptorMessage(msg);
// stop
InterceptorMessage stop;
stop.set_message_type(STOP);
stop.set_src_id(-1);
stop.set_dst_id(0);
carrier.EnqueueInterceptorMessage(stop);
} }
} // namespace distributed } // namespace distributed
......
...@@ -35,31 +35,25 @@ class StartInterceptor : public Interceptor { ...@@ -35,31 +35,25 @@ class StartInterceptor : public Interceptor {
void NOP(const InterceptorMessage& msg) { void NOP(const InterceptorMessage& msg) {
if (msg.message_type() == STOP) { if (msg.message_type() == STOP) {
stop_ = true; stop_ = true;
InterceptorMessage stop;
stop.set_message_type(STOP);
Send(1, stop); // stop 1, compute
return; return;
} }
std::cout << GetInterceptorId() << " recv msg from " << msg.src_id() std::cout << GetInterceptorId() << " recv msg from " << msg.src_id()
<< std::endl; << std::endl;
++count_;
if (count_ == 3) {
InterceptorMessage stop;
stop.set_message_type(STOP);
Send(msg.dst_id(), stop); // stop 0, this
Send(msg.src_id(), stop); // stop 1, compute
}
} }
int count_{0};
}; };
TEST(ComputeInterceptor, Compute) { TEST(ComputeInterceptor, Compute) {
Carrier& carrier = Carrier::Instance();
MessageBus& msg_bus = MessageBus::Instance(); MessageBus& msg_bus = MessageBus::Instance();
msg_bus.Init({{0, 0}, {1, 0}, {2, 0}}, {{0, "127.0.0.0:0"}}, "127.0.0.0:0"); msg_bus.Init({{0, 0}, {1, 0}, {2, 0}}, {{0, "127.0.0.0:0"}}, "127.0.0.0:0");
Carrier& carrier = Carrier::Instance();
// NOTE: don't delete, otherwise interceptor will use undefined node // NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode* node_a = new TaskNode(0, 0, 0, 0, 0); // role, rank, task_id TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0); // role, rank, task_id
TaskNode* node_b = new TaskNode(0, 0, 1, 0, 0); TaskNode* node_b = new TaskNode(0, 0, 1, 3, 0);
TaskNode* node_c = new TaskNode(0, 0, 2, 0, 0); TaskNode* node_c = new TaskNode(0, 0, 2, 3, 0);
// a->b->c // a->b->c
node_a->AddDownstreamTask(1); node_a->AddDownstreamTask(1);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册