未验证 提交 989e39a5 编写于 作者: L LiYuRio 提交者: GitHub

Modified compute and amplifier interceptor (#42044)

上级 39c6765a
...@@ -33,7 +33,7 @@ void AmplifierInterceptor::RunOps() { ...@@ -33,7 +33,7 @@ void AmplifierInterceptor::RunOps() {
// run_per_steps_, run_at_offset_ // run_per_steps_, run_at_offset_
// 4, 0 --> run at step 0, 4, 8, 12 // 4, 0 --> run at step 0, 4, 8, 12
// 4, 3 --> run at step 3, 7, 11, 15 // 4, 3 --> run at step 3, 7, 11, 15
if ((step_ % run_per_steps_) == run_at_offset_) { if ((cur_scope_id_ % run_per_steps_) == run_at_offset_) {
ComputeInterceptor::RunOps(); ComputeInterceptor::RunOps();
} }
} }
...@@ -41,7 +41,7 @@ void AmplifierInterceptor::RunOps() { ...@@ -41,7 +41,7 @@ void AmplifierInterceptor::RunOps() {
void AmplifierInterceptor::SendDataReadyToDownStream() { void AmplifierInterceptor::SendDataReadyToDownStream() {
// run multi times, send ready one times to downstream, that is // run multi times, send ready one times to downstream, that is
// input multi times, output one times // input multi times, output one times
if (step_ % send_down_per_steps_ == 0) { if (cur_scope_id_ % send_down_per_steps_ == 0) {
ComputeInterceptor::SendDataReadyToDownStream(); ComputeInterceptor::SendDataReadyToDownStream();
} }
} }
...@@ -49,7 +49,7 @@ void AmplifierInterceptor::SendDataReadyToDownStream() { ...@@ -49,7 +49,7 @@ void AmplifierInterceptor::SendDataReadyToDownStream() {
void AmplifierInterceptor::ReplyCompletedToUpStream() { void AmplifierInterceptor::ReplyCompletedToUpStream() {
// run multi times, reply one times to upstream, that is // run multi times, reply one times to upstream, that is
// input one times, output multi times // input one times, output multi times
if (step_ % reply_up_per_steps_ == 0) { if (cur_scope_id_ % reply_up_per_steps_ == 0) {
ComputeInterceptor::ReplyCompletedToUpStream(); ComputeInterceptor::ReplyCompletedToUpStream();
} }
} }
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
class AmplifierInterceptor : public ComputeInterceptor { class AmplifierInterceptor final : public ComputeInterceptor {
public: public:
AmplifierInterceptor(int64_t interceptor_id, TaskNode* node); AmplifierInterceptor(int64_t interceptor_id, TaskNode* node);
......
...@@ -71,6 +71,9 @@ void Carrier::Init( ...@@ -71,6 +71,9 @@ void Carrier::Init(
microbatch_scopes_[i] = &minibatch_scope_->NewScope(); microbatch_scopes_[i] = &minibatch_scope_->NewScope();
CopyParameters(i, program, inference_root_scope_vars); CopyParameters(i, program, inference_root_scope_vars);
} }
// Add source and sink interceptor id to rank
interceptor_id_to_rank_.emplace(SOURCE_ID, rank);
interceptor_id_to_rank_.emplace(SINK_ID, rank);
// TODO(fleet_exe dev): thread pool // TODO(fleet_exe dev): thread pool
thread_num_ = 1; thread_num_ = 1;
...@@ -159,16 +162,10 @@ void Carrier::Start() { ...@@ -159,16 +162,10 @@ void Carrier::Start() {
true, true,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"Using carrier before initialized.")); "Using carrier before initialized."));
for (int64_t id : source_interceptor_ids_) { InterceptorMessage start_msg;
VLOG(3) << "Carrier Start is sending start to source interceptor " << id start_msg.set_dst_id(SOURCE_ID);
<< "."; start_msg.set_message_type(START);
InterceptorMessage start_msg; Send(start_msg);
// source node data_is_ready is send by carrier, so set src_id=-1
start_msg.set_src_id(-1);
start_msg.set_dst_id(id);
start_msg.set_message_type(DATA_IS_READY);
Send(start_msg);
}
// TODO(wangxi): async step // TODO(wangxi): async step
Wait(); Wait();
dev_ctx_->Wait(); dev_ctx_->Wait();
...@@ -270,6 +267,38 @@ void Carrier::CreateInterceptors() { ...@@ -270,6 +267,38 @@ void Carrier::CreateInterceptors() {
auto gc = GetGC(place_); auto gc = GetGC(place_);
// create source and sink task node
auto max_run_times = microbatch_scopes_.size();
TaskNode* source = new TaskNode(
rank_, SOURCE_ID, max_run_times); // rank, task_id, max_run_times
TaskNode* sink = new TaskNode(rank_, SINK_ID, max_run_times);
// find nodes without upstreams or without downstreams
std::vector<TaskNode*> origin_sources, origin_sinks;
for (const auto& item : interceptor_id_to_node_) {
TaskNode* task_node = item.second;
if (task_node->upstream().empty()) {
origin_sources.emplace_back(task_node);
}
if (task_node->downstream().empty()) {
origin_sinks.emplace_back(task_node);
}
}
// link source node with origin source
for (const auto& node : origin_sources) {
source->AddDownstreamTask(node->task_id(),
std::numeric_limits<int64_t>::max());
node->AddUpstreamTask(SOURCE_ID, std::numeric_limits<int64_t>::max());
}
// link sink node with origin sink
for (const auto& node : origin_sinks) {
sink->AddUpstreamTask(node->task_id(), std::numeric_limits<int64_t>::max());
node->AddDownstreamTask(SINK_ID, std::numeric_limits<int64_t>::max());
}
// create source and sink interceptor
SetInterceptor(SOURCE_ID,
InterceptorFactory::Create("Source", SOURCE_ID, source));
SetInterceptor(SINK_ID, InterceptorFactory::Create("Sink", SINK_ID, sink));
// create each Interceptor // create each Interceptor
// no auto init since there is no config // no auto init since there is no config
for (const auto& item : interceptor_id_to_node_) { for (const auto& item : interceptor_id_to_node_) {
...@@ -303,9 +332,15 @@ void Carrier::CreateInterceptors() { ...@@ -303,9 +332,15 @@ void Carrier::CreateInterceptors() {
VLOG(3) << "Create Interceptor with interceptor id: " << interceptor_id VLOG(3) << "Create Interceptor with interceptor id: " << interceptor_id
<< " with type: " << task_node->type() << "."; << " with type: " << task_node->type() << ".";
if (task_node->upstream().empty()) { PADDLE_ENFORCE_EQ(
source_interceptor_ids_.emplace_back(interceptor_id); task_node->upstream().empty(),
} false,
platform::errors::PreconditionNotMet(
"There should not have normal nodes as source nodes"));
PADDLE_ENFORCE_EQ(task_node->downstream().empty(),
false,
platform::errors::PreconditionNotMet(
"There should not have normal nodes as sink nodes"));
} }
} }
......
...@@ -100,8 +100,6 @@ class Carrier final { ...@@ -100,8 +100,6 @@ class Carrier final {
std::unordered_map<int64_t, std::unique_ptr<Interceptor>> std::unordered_map<int64_t, std::unique_ptr<Interceptor>>
interceptor_idx_to_interceptor_; interceptor_idx_to_interceptor_;
std::vector<int64_t> source_interceptor_ids_;
bool is_init_{false}; bool is_init_{false};
std::mutex running_mutex_; std::mutex running_mutex_;
......
...@@ -34,29 +34,10 @@ void ComputeInterceptor::PrepareDeps() { ...@@ -34,29 +34,10 @@ void ComputeInterceptor::PrepareDeps() {
for (auto up : upstream) { for (auto up : upstream) {
in_readys_.emplace(up.first, std::make_pair(up.second, 0)); in_readys_.emplace(up.first, std::make_pair(up.second, 0));
in_stops_.emplace(up.first, false);
} }
for (auto down : downstream) { for (auto down : downstream) {
out_buffs_.emplace(down.first, std::make_pair(down.second, 0)); out_buffs_.emplace(down.first, std::make_pair(down.second, 0));
} }
// source compute node, should we add a new SourceInterceptor?
if (upstream.empty()) {
is_source_ = true;
PADDLE_ENFORCE_GT(node_->max_run_times(),
0,
platform::errors::InvalidArgument(
"Source ComputeInterceptor must run at least one "
"times, but now max_run_times=%ld",
node_->max_run_times()));
in_readys_.emplace(-1,
std::make_pair(std::numeric_limits<int64_t>::max(), 0));
}
// If there is no downstream or every downstream is in different rank,
// then this interceptor is the last one for current rank.
// This can be get during init, can be cached for later use.
is_last_ = downstream.empty();
} }
void ComputeInterceptor::IncreaseReady(int64_t up_id) { void ComputeInterceptor::IncreaseReady(int64_t up_id) {
...@@ -66,12 +47,6 @@ void ComputeInterceptor::IncreaseReady(int64_t up_id) { ...@@ -66,12 +47,6 @@ void ComputeInterceptor::IncreaseReady(int64_t up_id) {
platform::errors::NotFound( platform::errors::NotFound(
"Cannot find upstream=%lld in in_readys.", up_id)); "Cannot find upstream=%lld in in_readys.", up_id));
// source node has no upstream, data_is_ready is send by carrier or others
if (is_source_ && up_id == -1) {
it->second.second += GetTaskNode()->max_run_times();
return;
}
auto max_ready_size = it->second.first; auto max_ready_size = it->second.first;
auto ready_size = it->second.second; auto ready_size = it->second.second;
ready_size += 1; ready_size += 1;
...@@ -152,7 +127,7 @@ void ComputeInterceptor::SendDataReadyToDownStream() { ...@@ -152,7 +127,7 @@ void ComputeInterceptor::SendDataReadyToDownStream() {
ready_msg.set_message_type(DATA_IS_READY); ready_msg.set_message_type(DATA_IS_READY);
VLOG(3) << "ComputeInterceptor " << interceptor_id_ VLOG(3) << "ComputeInterceptor " << interceptor_id_
<< " Send data_is_ready msg to " << down_id << " Send data_is_ready msg to " << down_id
<< " for step: " << step_; << " in scope: " << cur_scope_id_;
Send(down_id, ready_msg); Send(down_id, ready_msg);
} }
} }
...@@ -173,8 +148,7 @@ void ComputeInterceptor::ReplyCompletedToUpStream() { ...@@ -173,8 +148,7 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
VLOG(3) << "ComputeInterceptor " << interceptor_id_ VLOG(3) << "ComputeInterceptor " << interceptor_id_
<< " Reply data_is_useless msg to " << up_id << " Reply data_is_useless msg to " << up_id
<< " for step: " << step_; << " in scope: " << cur_scope_id_;
if (is_source_ && up_id == -1) return;
InterceptorMessage reply_msg; InterceptorMessage reply_msg;
reply_msg.set_message_type(DATA_IS_USELESS); reply_msg.set_message_type(DATA_IS_USELESS);
...@@ -183,16 +157,20 @@ void ComputeInterceptor::ReplyCompletedToUpStream() { ...@@ -183,16 +157,20 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
} }
void ComputeInterceptor::RunOps() { void ComputeInterceptor::RunOps() {
VLOG(3) << "ComputeInterceptor " << interceptor_id_ << " running ops for the "
<< step_ + 1 << " time.";
for (auto op : node_->ops()) { for (auto op : node_->ops()) {
op->Run(*microbatch_scopes_[step_ % node_->max_run_times()], place_); PADDLE_ENFORCE_LT(cur_scope_id_,
microbatch_scopes_.size(),
platform::errors::InvalidArgument(
"Step out of range. There are %ld "
"microbatch_scopes, but recevice scope index %ld",
microbatch_scopes_.size(),
cur_scope_id_));
op->Run(*microbatch_scopes_[cur_scope_id_], place_);
if (gc_) { if (gc_) {
framework::DeleteUnusedTensors( framework::DeleteUnusedTensors(*microbatch_scopes_[cur_scope_id_],
*microbatch_scopes_[step_ % node_->max_run_times()], op,
op, node_->unused_vars(),
node_->unused_vars(), gc_.get());
gc_.get());
} }
} }
} }
...@@ -201,77 +179,28 @@ void ComputeInterceptor::Run() { ...@@ -201,77 +179,28 @@ void ComputeInterceptor::Run() {
while (IsInputReady() && CanWriteOutput()) { while (IsInputReady() && CanWriteOutput()) {
VLOG(3) << "id=" << GetInterceptorId() << " ComputeInterceptor running"; VLOG(3) << "id=" << GetInterceptorId() << " ComputeInterceptor running";
// get the ready scope id from queue
cur_scope_id_ = ready_queue_.front();
ready_queue_.pop();
RunOps(); RunOps();
++step_;
// send to downstream and increase buff used // send to downstream and increase buff used
SendDataReadyToDownStream(); SendDataReadyToDownStream();
// reply to upstream and decrease ready data // reply to upstream and decrease ready data
ReplyCompletedToUpStream(); ReplyCompletedToUpStream();
// Try to stop Carrier
if (is_last_ && (step_ % node_->max_run_times() == 0)) {
VLOG(3) << "Interceptor " << GetInterceptorId()
<< " is stopping carrier.";
// FIXME(wangxi): with multi sink interceptor
StopCarrier();
}
}
}
void ComputeInterceptor::ReceivedStop(int64_t up_id) {
received_stop_ = true;
// source node has no upstream, stop is send by carrier or others
if (is_source_ && up_id == -1) return;
auto it = in_stops_.find(up_id);
PADDLE_ENFORCE_NE(it,
in_stops_.end(),
platform::errors::NotFound(
"Cannot find upstream=%lld in in_stops.", up_id));
PADDLE_ENFORCE_EQ(
it->second,
false,
platform::errors::AlreadyExists("Already received stop from %lld, stop "
"cannot be send more than once."));
it->second = true;
}
void ComputeInterceptor::TryStop() {
if (!received_stop_) return;
// can stop only when all upstream is stop and
// downstream complete
for (auto& in_stop : in_stops_) {
if (!in_stop.second) return;
}
for (auto& out_buff : out_buffs_) {
auto used_size = out_buff.second.second;
if (used_size != 0) return;
} }
// send stop to downstream
for (auto& out : out_buffs_) {
auto down_id = out.first;
InterceptorMessage stop;
stop.set_message_type(STOP);
Send(down_id, stop);
}
stop_ = true;
} }
void ComputeInterceptor::Compute(const InterceptorMessage& msg) { void ComputeInterceptor::Compute(const InterceptorMessage& msg) {
if (msg.message_type() == DATA_IS_READY) { if (msg.message_type() == DATA_IS_READY) {
IncreaseReady(msg.src_id()); IncreaseReady(msg.src_id());
ready_queue_.push(msg.scope_idx());
Run(); Run();
} else if (msg.message_type() == DATA_IS_USELESS) { } else if (msg.message_type() == DATA_IS_USELESS) {
DecreaseBuff(msg.src_id()); DecreaseBuff(msg.src_id());
Run(); Run();
} else if (msg.message_type() == STOP) {
ReceivedStop(msg.src_id());
} }
TryStop();
} }
REGISTER_INTERCEPTOR(Compute, ComputeInterceptor); REGISTER_INTERCEPTOR(Compute, ComputeInterceptor);
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#pragma once #pragma once
#include <queue>
#include <utility> #include <utility>
#include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h"
...@@ -30,7 +31,8 @@ class ComputeInterceptor : public Interceptor { ...@@ -30,7 +31,8 @@ class ComputeInterceptor : public Interceptor {
virtual void SendDataReadyToDownStream(); virtual void SendDataReadyToDownStream();
virtual void ReplyCompletedToUpStream(); virtual void ReplyCompletedToUpStream();
int64_t step_{0}; std::queue<int64_t> ready_queue_;
int64_t cur_scope_id_;
private: private:
void PrepareDeps(); void PrepareDeps();
...@@ -43,19 +45,10 @@ class ComputeInterceptor : public Interceptor { ...@@ -43,19 +45,10 @@ class ComputeInterceptor : public Interceptor {
void Run(); void Run();
void Compute(const InterceptorMessage& msg); void Compute(const InterceptorMessage& msg);
void ReceivedStop(int64_t up_id);
void TryStop();
bool is_source_{false};
bool is_last_{false};
// upstream_id-->(max_ready_size, ready_size) // upstream_id-->(max_ready_size, ready_size)
std::map<int64_t, std::pair<int64_t, int64_t>> in_readys_{}; std::map<int64_t, std::pair<int64_t, int64_t>> in_readys_{};
// downstream_id-->(max_buffer_size, used_size) // downstream_id-->(max_buffer_size, used_size)
std::map<int64_t, std::pair<int64_t, int64_t>> out_buffs_{}; std::map<int64_t, std::pair<int64_t, int64_t>> out_buffs_{};
bool received_stop_{false};
std::map<int64_t, bool> in_stops_{};
}; };
} // namespace distributed } // namespace distributed
......
...@@ -93,7 +93,6 @@ class Interceptor { ...@@ -93,7 +93,6 @@ class Interceptor {
TaskNode* node_; TaskNode* node_;
// for stop // for stop
bool stop_{false};
void StopCarrier(); void StopCarrier();
// for runtime // for runtime
...@@ -114,9 +113,6 @@ class Interceptor { ...@@ -114,9 +113,6 @@ class Interceptor {
std::mutex mutex_; std::mutex mutex_;
std::deque<InterceptorMessage> messages_; std::deque<InterceptorMessage> messages_;
int64_t already_run_times_{0};
int64_t used_slot_nums_{0};
}; };
class InterceptorFactory { class InterceptorFactory {
......
...@@ -25,7 +25,7 @@ namespace distributed { ...@@ -25,7 +25,7 @@ namespace distributed {
* 1. record the num of micro-step * 1. record the num of micro-step
* 2. check whether to notify carrier the current step is finished * 2. check whether to notify carrier the current step is finished
*/ */
class SinkInterceptor : public Interceptor { class SinkInterceptor final : public Interceptor {
public: public:
SinkInterceptor(int64_t interceptor_id, TaskNode* node); SinkInterceptor(int64_t interceptor_id, TaskNode* node);
......
...@@ -25,7 +25,7 @@ namespace distributed { ...@@ -25,7 +25,7 @@ namespace distributed {
* 1. receive `start` message from carrier * 1. receive `start` message from carrier
* 2. send num_of_steps `data_is_ready` message to downstream * 2. send num_of_steps `data_is_ready` message to downstream
*/ */
class SourceInterceptor : public Interceptor { class SourceInterceptor final : public Interceptor {
public: public:
SourceInterceptor(int64_t interceptor_id, TaskNode* node); SourceInterceptor(int64_t interceptor_id, TaskNode* node);
......
...@@ -25,57 +25,42 @@ limitations under the License. */ ...@@ -25,57 +25,42 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
class StartInterceptor : public Interceptor {
public:
StartInterceptor(int64_t interceptor_id, TaskNode* node)
: Interceptor(interceptor_id, node) {
RegisterMsgHandle([this](const InterceptorMessage& msg) { NOP(msg); });
}
void NOP(const InterceptorMessage& msg) {
if (msg.message_type() == STOP) {
stop_ = true;
InterceptorMessage stop;
stop.set_message_type(STOP);
Send(1, stop); // stop 1, compute
return;
}
std::cout << GetInterceptorId() << " recv msg from " << msg.src_id()
<< std::endl;
}
};
TEST(ComputeInterceptor, Compute) { TEST(ComputeInterceptor, Compute) {
std::string carrier_id = "0"; std::string carrier_id = "0";
Carrier* carrier = Carrier* carrier =
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id); GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
carrier->Init(0, {{0, 0}, {1, 0}, {2, 0}}); carrier->Init(0, {{SOURCE_ID, 0}, {0, 0}, {1, 0}, {SINK_ID, 0}});
MessageBus* msg_bus = GlobalVal<MessageBus>::Create(); MessageBus* msg_bus = GlobalVal<MessageBus>::Create();
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, ""); msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");
// NOTE: don't delete, otherwise interceptor will use undefined node // NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0); // role, rank, task_id TaskNode* source =
new TaskNode(0, SOURCE_ID, 3); // rank, task_id, max_run_times
TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0);
TaskNode* node_b = new TaskNode(0, 0, 1, 3, 0); TaskNode* node_b = new TaskNode(0, 0, 1, 3, 0);
TaskNode* node_c = new TaskNode(0, 0, 2, 3, 0); TaskNode* sink = new TaskNode(0, SINK_ID, 3);
// a->b->c // source->a->b->sink
source->AddDownstreamTask(0);
node_a->AddUpstreamTask(SOURCE_ID);
node_a->AddDownstreamTask(1, 3); node_a->AddDownstreamTask(1, 3);
node_b->AddUpstreamTask(0, 3); node_b->AddUpstreamTask(0, 3);
node_b->AddDownstreamTask(2); node_b->AddDownstreamTask(SINK_ID);
node_c->AddUpstreamTask(1); sink->AddUpstreamTask(1);
Interceptor* a = carrier->SetInterceptor(
carrier->SetInterceptor(0, std::make_unique<StartInterceptor>(0, node_a)); SOURCE_ID, InterceptorFactory::Create("Source", SOURCE_ID, source));
carrier->SetInterceptor(0, InterceptorFactory::Create("Compute", 0, node_a));
carrier->SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b)); carrier->SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b));
carrier->SetInterceptor(2, InterceptorFactory::Create("Compute", 2, node_c)); carrier->SetInterceptor(SINK_ID,
InterceptorFactory::Create("Sink", SINK_ID, sink));
// start
InterceptorMessage msg; InterceptorMessage msg;
msg.set_message_type(DATA_IS_READY); msg.set_message_type(START);
// test run three times msg.set_dst_id(SOURCE_ID);
a->Send(1, msg); carrier->EnqueueInterceptorMessage(msg);
a->Send(1, msg);
a->Send(1, msg);
carrier->Wait(); carrier->Wait();
carrier->Release(); carrier->Release();
......
...@@ -33,7 +33,6 @@ class PingPongInterceptor : public Interceptor { ...@@ -33,7 +33,6 @@ class PingPongInterceptor : public Interceptor {
void PingPong(const InterceptorMessage& msg) { void PingPong(const InterceptorMessage& msg) {
if (msg.message_type() == STOP) { if (msg.message_type() == STOP) {
stop_ = true;
return; return;
} }
std::cout << GetInterceptorId() << " recv msg, count=" << count_ std::cout << GetInterceptorId() << " recv msg, count=" << count_
......
...@@ -36,7 +36,6 @@ class PingPongInterceptor : public Interceptor { ...@@ -36,7 +36,6 @@ class PingPongInterceptor : public Interceptor {
void PingPong(const InterceptorMessage& msg) { void PingPong(const InterceptorMessage& msg) {
if (msg.message_type() == STOP) { if (msg.message_type() == STOP) {
stop_ = true;
StopCarrier(); StopCarrier();
return; return;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册