未验证 提交 989e39a5 编写于 作者: L LiYuRio 提交者: GitHub

Modified compute and amplifier interceptor (#42044)

上级 39c6765a
......@@ -33,7 +33,7 @@ void AmplifierInterceptor::RunOps() {
// run_per_steps_, run_at_offset_
// 4, 0 --> run at step 0, 4, 8, 12
// 4, 3 --> run at step 3, 7, 11, 15
if ((step_ % run_per_steps_) == run_at_offset_) {
if ((cur_scope_id_ % run_per_steps_) == run_at_offset_) {
ComputeInterceptor::RunOps();
}
}
......@@ -41,7 +41,7 @@ void AmplifierInterceptor::RunOps() {
void AmplifierInterceptor::SendDataReadyToDownStream() {
// run multi times, send ready one times to downstream, that is
// input multi times, output one times
if (step_ % send_down_per_steps_ == 0) {
if (cur_scope_id_ % send_down_per_steps_ == 0) {
ComputeInterceptor::SendDataReadyToDownStream();
}
}
......@@ -49,7 +49,7 @@ void AmplifierInterceptor::SendDataReadyToDownStream() {
void AmplifierInterceptor::ReplyCompletedToUpStream() {
// run multi times, reply one times to upstream, that is
// input one times, output multi times
if (step_ % reply_up_per_steps_ == 0) {
if (cur_scope_id_ % reply_up_per_steps_ == 0) {
ComputeInterceptor::ReplyCompletedToUpStream();
}
}
......
......@@ -21,7 +21,7 @@
namespace paddle {
namespace distributed {
class AmplifierInterceptor : public ComputeInterceptor {
class AmplifierInterceptor final : public ComputeInterceptor {
public:
AmplifierInterceptor(int64_t interceptor_id, TaskNode* node);
......
......@@ -71,6 +71,9 @@ void Carrier::Init(
microbatch_scopes_[i] = &minibatch_scope_->NewScope();
CopyParameters(i, program, inference_root_scope_vars);
}
// Add source and sink interceptor id to rank
interceptor_id_to_rank_.emplace(SOURCE_ID, rank);
interceptor_id_to_rank_.emplace(SINK_ID, rank);
// TODO(fleet_exe dev): thread pool
thread_num_ = 1;
......@@ -159,16 +162,10 @@ void Carrier::Start() {
true,
platform::errors::PreconditionNotMet(
"Using carrier before initialized."));
for (int64_t id : source_interceptor_ids_) {
VLOG(3) << "Carrier Start is sending start to source interceptor " << id
<< ".";
InterceptorMessage start_msg;
// source node data_is_ready is send by carrier, so set src_id=-1
start_msg.set_src_id(-1);
start_msg.set_dst_id(id);
start_msg.set_message_type(DATA_IS_READY);
Send(start_msg);
}
InterceptorMessage start_msg;
start_msg.set_dst_id(SOURCE_ID);
start_msg.set_message_type(START);
Send(start_msg);
// TODO(wangxi): async step
Wait();
dev_ctx_->Wait();
......@@ -270,6 +267,38 @@ void Carrier::CreateInterceptors() {
auto gc = GetGC(place_);
// create source and sink task node
auto max_run_times = microbatch_scopes_.size();
TaskNode* source = new TaskNode(
rank_, SOURCE_ID, max_run_times); // rank, task_id, max_run_times
TaskNode* sink = new TaskNode(rank_, SINK_ID, max_run_times);
// find nodes without upstreams or without downstreams
std::vector<TaskNode*> origin_sources, origin_sinks;
for (const auto& item : interceptor_id_to_node_) {
TaskNode* task_node = item.second;
if (task_node->upstream().empty()) {
origin_sources.emplace_back(task_node);
}
if (task_node->downstream().empty()) {
origin_sinks.emplace_back(task_node);
}
}
// link source node with origin source
for (const auto& node : origin_sources) {
source->AddDownstreamTask(node->task_id(),
std::numeric_limits<int64_t>::max());
node->AddUpstreamTask(SOURCE_ID, std::numeric_limits<int64_t>::max());
}
// link sink node with origin sink
for (const auto& node : origin_sinks) {
sink->AddUpstreamTask(node->task_id(), std::numeric_limits<int64_t>::max());
node->AddDownstreamTask(SINK_ID, std::numeric_limits<int64_t>::max());
}
// create source and sink interceptor
SetInterceptor(SOURCE_ID,
InterceptorFactory::Create("Source", SOURCE_ID, source));
SetInterceptor(SINK_ID, InterceptorFactory::Create("Sink", SINK_ID, sink));
// create each Interceptor
// no auto init since there is no config
for (const auto& item : interceptor_id_to_node_) {
......@@ -303,9 +332,15 @@ void Carrier::CreateInterceptors() {
VLOG(3) << "Create Interceptor with interceptor id: " << interceptor_id
<< " with type: " << task_node->type() << ".";
if (task_node->upstream().empty()) {
source_interceptor_ids_.emplace_back(interceptor_id);
}
PADDLE_ENFORCE_EQ(
task_node->upstream().empty(),
false,
platform::errors::PreconditionNotMet(
"There should not have normal nodes as source nodes"));
PADDLE_ENFORCE_EQ(task_node->downstream().empty(),
false,
platform::errors::PreconditionNotMet(
"There should not have normal nodes as sink nodes"));
}
}
......
......@@ -100,8 +100,6 @@ class Carrier final {
std::unordered_map<int64_t, std::unique_ptr<Interceptor>>
interceptor_idx_to_interceptor_;
std::vector<int64_t> source_interceptor_ids_;
bool is_init_{false};
std::mutex running_mutex_;
......
......@@ -34,29 +34,10 @@ void ComputeInterceptor::PrepareDeps() {
for (auto up : upstream) {
in_readys_.emplace(up.first, std::make_pair(up.second, 0));
in_stops_.emplace(up.first, false);
}
for (auto down : downstream) {
out_buffs_.emplace(down.first, std::make_pair(down.second, 0));
}
// source compute node, should we add a new SourceInterceptor?
if (upstream.empty()) {
is_source_ = true;
PADDLE_ENFORCE_GT(node_->max_run_times(),
0,
platform::errors::InvalidArgument(
"Source ComputeInterceptor must run at least one "
"times, but now max_run_times=%ld",
node_->max_run_times()));
in_readys_.emplace(-1,
std::make_pair(std::numeric_limits<int64_t>::max(), 0));
}
// If there is no downstream or every downstream is in different rank,
// then this interceptor is the last one for current rank.
// This can be get during init, can be cached for later use.
is_last_ = downstream.empty();
}
void ComputeInterceptor::IncreaseReady(int64_t up_id) {
......@@ -66,12 +47,6 @@ void ComputeInterceptor::IncreaseReady(int64_t up_id) {
platform::errors::NotFound(
"Cannot find upstream=%lld in in_readys.", up_id));
// source node has no upstream, data_is_ready is send by carrier or others
if (is_source_ && up_id == -1) {
it->second.second += GetTaskNode()->max_run_times();
return;
}
auto max_ready_size = it->second.first;
auto ready_size = it->second.second;
ready_size += 1;
......@@ -152,7 +127,7 @@ void ComputeInterceptor::SendDataReadyToDownStream() {
ready_msg.set_message_type(DATA_IS_READY);
VLOG(3) << "ComputeInterceptor " << interceptor_id_
<< " Send data_is_ready msg to " << down_id
<< " for step: " << step_;
<< " in scope: " << cur_scope_id_;
Send(down_id, ready_msg);
}
}
......@@ -173,8 +148,7 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
VLOG(3) << "ComputeInterceptor " << interceptor_id_
<< " Reply data_is_useless msg to " << up_id
<< " for step: " << step_;
if (is_source_ && up_id == -1) return;
<< " in scope: " << cur_scope_id_;
InterceptorMessage reply_msg;
reply_msg.set_message_type(DATA_IS_USELESS);
......@@ -183,16 +157,20 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
}
void ComputeInterceptor::RunOps() {
VLOG(3) << "ComputeInterceptor " << interceptor_id_ << " running ops for the "
<< step_ + 1 << " time.";
for (auto op : node_->ops()) {
op->Run(*microbatch_scopes_[step_ % node_->max_run_times()], place_);
PADDLE_ENFORCE_LT(cur_scope_id_,
microbatch_scopes_.size(),
platform::errors::InvalidArgument(
"Step out of range. There are %ld "
"microbatch_scopes, but recevice scope index %ld",
microbatch_scopes_.size(),
cur_scope_id_));
op->Run(*microbatch_scopes_[cur_scope_id_], place_);
if (gc_) {
framework::DeleteUnusedTensors(
*microbatch_scopes_[step_ % node_->max_run_times()],
op,
node_->unused_vars(),
gc_.get());
framework::DeleteUnusedTensors(*microbatch_scopes_[cur_scope_id_],
op,
node_->unused_vars(),
gc_.get());
}
}
}
......@@ -201,77 +179,28 @@ void ComputeInterceptor::Run() {
while (IsInputReady() && CanWriteOutput()) {
VLOG(3) << "id=" << GetInterceptorId() << " ComputeInterceptor running";
// get the ready scope id from queue
cur_scope_id_ = ready_queue_.front();
ready_queue_.pop();
RunOps();
++step_;
// send to downstream and increase buff used
SendDataReadyToDownStream();
// reply to upstream and decrease ready data
ReplyCompletedToUpStream();
// Try to stop Carrier
if (is_last_ && (step_ % node_->max_run_times() == 0)) {
VLOG(3) << "Interceptor " << GetInterceptorId()
<< " is stopping carrier.";
// FIXME(wangxi): with multi sink interceptor
StopCarrier();
}
}
}
void ComputeInterceptor::ReceivedStop(int64_t up_id) {
received_stop_ = true;
// source node has no upstream, stop is send by carrier or others
if (is_source_ && up_id == -1) return;
auto it = in_stops_.find(up_id);
PADDLE_ENFORCE_NE(it,
in_stops_.end(),
platform::errors::NotFound(
"Cannot find upstream=%lld in in_stops.", up_id));
PADDLE_ENFORCE_EQ(
it->second,
false,
platform::errors::AlreadyExists("Already received stop from %lld, stop "
"cannot be send more than once."));
it->second = true;
}
void ComputeInterceptor::TryStop() {
if (!received_stop_) return;
// can stop only when all upstream is stop and
// downstream complete
for (auto& in_stop : in_stops_) {
if (!in_stop.second) return;
}
for (auto& out_buff : out_buffs_) {
auto used_size = out_buff.second.second;
if (used_size != 0) return;
}
// send stop to downstream
for (auto& out : out_buffs_) {
auto down_id = out.first;
InterceptorMessage stop;
stop.set_message_type(STOP);
Send(down_id, stop);
}
stop_ = true;
}
void ComputeInterceptor::Compute(const InterceptorMessage& msg) {
if (msg.message_type() == DATA_IS_READY) {
IncreaseReady(msg.src_id());
ready_queue_.push(msg.scope_idx());
Run();
} else if (msg.message_type() == DATA_IS_USELESS) {
DecreaseBuff(msg.src_id());
Run();
} else if (msg.message_type() == STOP) {
ReceivedStop(msg.src_id());
}
TryStop();
}
REGISTER_INTERCEPTOR(Compute, ComputeInterceptor);
......
......@@ -14,6 +14,7 @@
#pragma once
#include <queue>
#include <utility>
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
......@@ -30,7 +31,8 @@ class ComputeInterceptor : public Interceptor {
virtual void SendDataReadyToDownStream();
virtual void ReplyCompletedToUpStream();
int64_t step_{0};
std::queue<int64_t> ready_queue_;
int64_t cur_scope_id_;
private:
void PrepareDeps();
......@@ -43,19 +45,10 @@ class ComputeInterceptor : public Interceptor {
void Run();
void Compute(const InterceptorMessage& msg);
void ReceivedStop(int64_t up_id);
void TryStop();
bool is_source_{false};
bool is_last_{false};
// upstream_id-->(max_ready_size, ready_size)
std::map<int64_t, std::pair<int64_t, int64_t>> in_readys_{};
// downstream_id-->(max_buffer_size, used_size)
std::map<int64_t, std::pair<int64_t, int64_t>> out_buffs_{};
bool received_stop_{false};
std::map<int64_t, bool> in_stops_{};
};
} // namespace distributed
......
......@@ -93,7 +93,6 @@ class Interceptor {
TaskNode* node_;
// for stop
bool stop_{false};
void StopCarrier();
// for runtime
......@@ -114,9 +113,6 @@ class Interceptor {
std::mutex mutex_;
std::deque<InterceptorMessage> messages_;
int64_t already_run_times_{0};
int64_t used_slot_nums_{0};
};
class InterceptorFactory {
......
......@@ -25,7 +25,7 @@ namespace distributed {
* 1. record the num of micro-step
* 2. check whether to notify carrier the current step is finished
*/
class SinkInterceptor : public Interceptor {
class SinkInterceptor final : public Interceptor {
public:
SinkInterceptor(int64_t interceptor_id, TaskNode* node);
......
......@@ -25,7 +25,7 @@ namespace distributed {
* 1. receive `start` message from carrier
* 2. send num_of_steps `data_is_ready` message to downstream
*/
class SourceInterceptor : public Interceptor {
class SourceInterceptor final : public Interceptor {
public:
SourceInterceptor(int64_t interceptor_id, TaskNode* node);
......
......@@ -25,57 +25,42 @@ limitations under the License. */
namespace paddle {
namespace distributed {
class StartInterceptor : public Interceptor {
public:
StartInterceptor(int64_t interceptor_id, TaskNode* node)
: Interceptor(interceptor_id, node) {
RegisterMsgHandle([this](const InterceptorMessage& msg) { NOP(msg); });
}
void NOP(const InterceptorMessage& msg) {
if (msg.message_type() == STOP) {
stop_ = true;
InterceptorMessage stop;
stop.set_message_type(STOP);
Send(1, stop); // stop 1, compute
return;
}
std::cout << GetInterceptorId() << " recv msg from " << msg.src_id()
<< std::endl;
}
};
TEST(ComputeInterceptor, Compute) {
std::string carrier_id = "0";
Carrier* carrier =
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
carrier->Init(0, {{0, 0}, {1, 0}, {2, 0}});
carrier->Init(0, {{SOURCE_ID, 0}, {0, 0}, {1, 0}, {SINK_ID, 0}});
MessageBus* msg_bus = GlobalVal<MessageBus>::Create();
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");
// NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0); // role, rank, task_id
TaskNode* source =
new TaskNode(0, SOURCE_ID, 3); // rank, task_id, max_run_times
TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0);
TaskNode* node_b = new TaskNode(0, 0, 1, 3, 0);
TaskNode* node_c = new TaskNode(0, 0, 2, 3, 0);
TaskNode* sink = new TaskNode(0, SINK_ID, 3);
// a->b->c
// source->a->b->sink
source->AddDownstreamTask(0);
node_a->AddUpstreamTask(SOURCE_ID);
node_a->AddDownstreamTask(1, 3);
node_b->AddUpstreamTask(0, 3);
node_b->AddDownstreamTask(2);
node_c->AddUpstreamTask(1);
node_b->AddDownstreamTask(SINK_ID);
sink->AddUpstreamTask(1);
Interceptor* a =
carrier->SetInterceptor(0, std::make_unique<StartInterceptor>(0, node_a));
carrier->SetInterceptor(
SOURCE_ID, InterceptorFactory::Create("Source", SOURCE_ID, source));
carrier->SetInterceptor(0, InterceptorFactory::Create("Compute", 0, node_a));
carrier->SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b));
carrier->SetInterceptor(2, InterceptorFactory::Create("Compute", 2, node_c));
carrier->SetInterceptor(SINK_ID,
InterceptorFactory::Create("Sink", SINK_ID, sink));
// start
InterceptorMessage msg;
msg.set_message_type(DATA_IS_READY);
// test run three times
a->Send(1, msg);
a->Send(1, msg);
a->Send(1, msg);
msg.set_message_type(START);
msg.set_dst_id(SOURCE_ID);
carrier->EnqueueInterceptorMessage(msg);
carrier->Wait();
carrier->Release();
......
......@@ -33,7 +33,6 @@ class PingPongInterceptor : public Interceptor {
void PingPong(const InterceptorMessage& msg) {
if (msg.message_type() == STOP) {
stop_ = true;
return;
}
std::cout << GetInterceptorId() << " recv msg, count=" << count_
......
......@@ -36,7 +36,6 @@ class PingPongInterceptor : public Interceptor {
void PingPong(const InterceptorMessage& msg) {
if (msg.message_type() == STOP) {
stop_ = true;
StopCarrier();
return;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册