未验证 提交 5163c538 编写于 作者: L LiYuRio 提交者: GitHub

[FleetExecutor] Optimize gap between generation step (#53072)

* optimize gap between generation step

* remove useless header
上级 b1d3ec16
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <algorithm> #include <algorithm>
#include <vector> #include <vector>
#include "gflags/gflags.h"
#include "paddle/fluid/distributed/fleet_executor/global.h" #include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h"
...@@ -237,12 +238,10 @@ bool Carrier::Send(const InterceptorMessage& msg) { ...@@ -237,12 +238,10 @@ bool Carrier::Send(const InterceptorMessage& msg) {
VLOG(3) << "Send a message from interceptor " << src_id VLOG(3) << "Send a message from interceptor " << src_id
<< " to interceptor " << dst_id << ", which are in the same ranks."; << " to interceptor " << dst_id << ", which are in the same ranks.";
return EnqueueInterceptorMessage(msg); return EnqueueInterceptorMessage(msg);
} else {
VLOG(3) << "Send a message from interceptor " << src_id
<< " to interceptor " << dst_id
<< ", which are in different ranks.";
return GlobalVal<MessageBus>::Get()->Send(dst_rank, msg);
} }
VLOG(3) << "Send a message from interceptor " << src_id << " to interceptor "
<< dst_id << ", which are in different ranks.";
return GlobalVal<MessageBus>::Get()->Send(dst_rank, msg);
} }
Interceptor* Carrier::SetInterceptor(int64_t interceptor_id, Interceptor* Carrier::SetInterceptor(int64_t interceptor_id,
......
...@@ -158,25 +158,47 @@ void ComputeInterceptor::DecreaseBuff(int64_t down_id) { ...@@ -158,25 +158,47 @@ void ComputeInterceptor::DecreaseBuff(int64_t down_id) {
} }
bool ComputeInterceptor::IsInputReady() { bool ComputeInterceptor::IsInputReady() {
for (int64_t i = 0; i < node_->max_run_times(); ++i) { std::map<int64_t, bool> scope_id_to_finish_flag;
if (!gen_step_to_scope_id_to_finish_flag_.empty()) {
scope_id_to_finish_flag =
gen_step_to_scope_id_to_finish_flag_.begin()->second;
VLOG(3) << "Is Input Ready in gen step "
<< gen_step_to_scope_id_to_finish_flag_.begin()->first;
}
int64_t num_micro_step =
(num_micro_step_ == -1 ? node_->max_run_times() : num_micro_step_);
int64_t start_micro_step = (start_micro_step_ == -1 ? 0 : start_micro_step_);
for (int64_t i = start_micro_step; i < start_micro_step + num_micro_step;
++i) {
bool flag = true; bool flag = true;
for (auto& ins : in_readys_) { for (auto& ins : in_readys_) {
auto ready_size_map = ins.second.second; auto ready_size_map = ins.second.second;
flag = flag && (ready_size_map.at(i) != 0); flag = flag && (ready_size_map.at(i) != 0);
} }
if (flag) { if (flag) {
for (auto iter : scope_id_to_finish_flag_) { if (scope_id_to_finish_flag.empty()) {
if (iter.first == i) { cur_scope_id_ = i;
break; return true;
} else if (!iter.second) { } else if (scope_id_to_finish_flag.find(i) !=
VLOG(3) << "The previous scope is not ready, waiting for the " scope_id_to_finish_flag.end()) {
"previous scope " for (auto iter : scope_id_to_finish_flag) {
<< iter.first; if (iter.first == i) {
return false; break;
} else if (!iter.second) {
VLOG(3) << "The previous scope is not ready, waiting for the "
"previous scope "
<< iter.first << " in gen_step "
<< gen_step_to_scope_id_to_finish_flag_.begin()->first;
return false;
}
} }
cur_scope_id_ = i;
return true;
} else {
VLOG(3) << "Interceptor " << GetInterceptorId() << " in scope " << i
<< " is larger than gen_step "
<< gen_step_to_scope_id_to_finish_flag_.begin()->first;
} }
cur_scope_id_ = i;
return true;
} else { } else {
VLOG(3) << "Interceptor " << GetInterceptorId() << " in scope " << i VLOG(3) << "Interceptor " << GetInterceptorId() << " in scope " << i
<< "'s upstreams aren't all ready."; << "'s upstreams aren't all ready.";
...@@ -203,6 +225,16 @@ bool ComputeInterceptor::CanWriteOutput() { ...@@ -203,6 +225,16 @@ bool ComputeInterceptor::CanWriteOutput() {
} }
void ComputeInterceptor::SendDataReadyToDownStream() { void ComputeInterceptor::SendDataReadyToDownStream() {
bool need_send_vars = !(node_->vars_to_dtype().empty());
InterceptorMessage ready_msg;
ready_msg.set_start_micro_step(start_micro_step_);
ready_msg.set_num_micro_step(num_micro_step_);
if (need_send_vars) {
ready_msg = PrepareVarsMsg();
} else {
ready_msg.set_message_type(DATA_IS_READY);
ready_msg.set_scope_idx(cur_scope_id_);
}
for (auto& outs : out_buffs_) { for (auto& outs : out_buffs_) {
auto down_id = outs.first; auto down_id = outs.first;
auto max_buff_size = outs.second.first; auto max_buff_size = outs.second.first;
...@@ -221,13 +253,17 @@ void ComputeInterceptor::SendDataReadyToDownStream() { ...@@ -221,13 +253,17 @@ void ComputeInterceptor::SendDataReadyToDownStream() {
} }
outs.second.second = used_size; outs.second.second = used_size;
InterceptorMessage ready_msg; if (need_send_vars) {
ready_msg.set_message_type(DATA_IS_READY); VLOG(3) << "ComputeInterceptor " << interceptor_id_
ready_msg.set_scope_idx(cur_scope_id_); << " Send data_with_vars msg to " << down_id
VLOG(3) << "ComputeInterceptor " << interceptor_id_ << " in scope: " << cur_scope_id_;
<< " Send data_is_ready msg to " << down_id Send(down_id, ready_msg);
<< " in scope: " << cur_scope_id_; } else {
Send(down_id, ready_msg); VLOG(3) << "ComputeInterceptor " << interceptor_id_
<< " Send data_is_ready msg to " << down_id
<< " in scope: " << cur_scope_id_;
Send(down_id, ready_msg);
}
} }
} }
...@@ -289,13 +325,21 @@ void ComputeInterceptor::Run() { ...@@ -289,13 +325,21 @@ void ComputeInterceptor::Run() {
RunOps(); RunOps();
if (!scope_id_to_finish_flag_.empty()) { if (!gen_step_to_scope_id_to_finish_flag_.empty()) {
auto iter = gen_step_to_scope_id_to_finish_flag_.begin();
VLOG(3) << "id=" << GetInterceptorId()
<< " ComputeInterceptor running in scope " << cur_scope_id_
<< " with gen_step " << iter->first;
auto& scope_id_to_finish_flag = iter->second;
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
scope_id_to_finish_flag_.find(cur_scope_id_), scope_id_to_finish_flag.find(cur_scope_id_),
scope_id_to_finish_flag_.end(), scope_id_to_finish_flag.end(),
platform::errors::NotFound( platform::errors::NotFound(
"Can not find scope %ld in scope_id_to_finish", cur_scope_id_)); "Can not find scope %ld in scope_id_to_finish", cur_scope_id_));
scope_id_to_finish_flag_.erase(cur_scope_id_); scope_id_to_finish_flag.erase(cur_scope_id_);
if (scope_id_to_finish_flag.empty()) {
gen_step_to_scope_id_to_finish_flag_.erase(iter);
}
} }
// send to downstream and increase buff used // send to downstream and increase buff used
...@@ -310,6 +354,8 @@ void ComputeInterceptor::Compute(const InterceptorMessage& msg) { ...@@ -310,6 +354,8 @@ void ComputeInterceptor::Compute(const InterceptorMessage& msg) {
VLOG(3) << "Compute interceptor " << interceptor_id_ VLOG(3) << "Compute interceptor " << interceptor_id_
<< " receive data_is_ready " << msg.src_id() << " " << " receive data_is_ready " << msg.src_id() << " "
<< msg.scope_idx() << " "; << msg.scope_idx() << " ";
start_micro_step_ = msg.start_micro_step();
num_micro_step_ = msg.num_micro_step();
IncreaseReady(msg.src_id(), msg.scope_idx()); IncreaseReady(msg.src_id(), msg.scope_idx());
Run(); Run();
} else if (msg.message_type() == DATA_IS_USELESS) { } else if (msg.message_type() == DATA_IS_USELESS) {
...@@ -327,10 +373,14 @@ void ComputeInterceptor::Compute(const InterceptorMessage& msg) { ...@@ -327,10 +373,14 @@ void ComputeInterceptor::Compute(const InterceptorMessage& msg) {
Run(); Run();
} else if (msg.message_type() == START_LOOP) { } else if (msg.message_type() == START_LOOP) {
VLOG(3) << "Compute interceptor " << interceptor_id_ VLOG(3) << "Compute interceptor " << interceptor_id_
<< " receive start_loop " << msg.src_id() << " " << msg.scope_idx() << " receive start_loop " << msg.src_id() << " in scope "
<< " "; << msg.scope_idx() << " with gen_step " << msg.gen_step();
start_micro_step_ = msg.start_micro_step();
num_micro_step_ = msg.num_micro_step();
IncreaseReady(msg.src_id(), msg.scope_idx()); IncreaseReady(msg.src_id(), msg.scope_idx());
scope_id_to_finish_flag_.emplace(msg.scope_idx(), false); int64_t gen_step = msg.gen_step();
gen_step_to_scope_id_to_finish_flag_[gen_step].emplace(msg.scope_idx(),
false);
Run(); Run();
} }
} }
......
...@@ -52,7 +52,10 @@ class ComputeInterceptor : public Interceptor { ...@@ -52,7 +52,10 @@ class ComputeInterceptor : public Interceptor {
bool IsInputReady(); bool IsInputReady();
bool CanWriteOutput(); bool CanWriteOutput();
std::map<int64_t, bool> scope_id_to_finish_flag_; std::map<int64_t, std::map<int64_t, bool>>
gen_step_to_scope_id_to_finish_flag_;
int64_t start_micro_step_{-1};
int64_t num_micro_step_{-1};
}; };
} // namespace distributed } // namespace distributed
......
...@@ -90,13 +90,18 @@ void CondInterceptor::SendDataReady(int64_t down_id) { ...@@ -90,13 +90,18 @@ void CondInterceptor::SendDataReady(int64_t down_id) {
InterceptorMessage ready_msg; InterceptorMessage ready_msg;
ready_msg.set_message_type(DATA_IS_READY); ready_msg.set_message_type(DATA_IS_READY);
ready_msg.set_scope_idx(cur_scope_id_); ready_msg.set_scope_idx(cur_scope_id_);
ready_msg.set_start_micro_step(start_micro_step_);
ready_msg.set_num_micro_step(num_micro_step_);
Send(down_id, ready_msg); Send(down_id, ready_msg);
} }
void CondInterceptor::SendStartLoop(int64_t down_id) { void CondInterceptor::SendStartLoop(int64_t down_id, int64_t gen_step) {
InterceptorMessage ready_msg; InterceptorMessage ready_msg;
ready_msg.set_message_type(START_LOOP); ready_msg.set_message_type(START_LOOP);
ready_msg.set_scope_idx(cur_scope_id_); ready_msg.set_scope_idx(cur_scope_id_);
ready_msg.set_gen_step(gen_step);
ready_msg.set_start_micro_step(start_micro_step_);
ready_msg.set_num_micro_step(num_micro_step_);
Send(down_id, ready_msg); Send(down_id, ready_msg);
} }
...@@ -107,43 +112,36 @@ void CondInterceptor::ReplyDataIsUseless(int64_t up_id) { ...@@ -107,43 +112,36 @@ void CondInterceptor::ReplyDataIsUseless(int64_t up_id) {
Send(up_id, ready_msg); Send(up_id, ready_msg);
} }
void CondInterceptor::Compute() { void CondInterceptor::Compute(int64_t gen_step) {
bool cond = GetCondResult(); bool cond = GetCondResult();
VLOG(3) << "Cond interceptor get condition var " << node_->cond_var() VLOG(3) << "Cond interceptor get condition var " << node_->cond_var()
<< " with value " << cond; << " with value " << cond;
if (cond) { if (cond) {
VLOG(3) << "Loop again in scope " << cur_scope_id_; VLOG(3) << "Loop again in scope " << cur_scope_id_ << " gen_step "
<< gen_step;
for (auto& down_id : normal_out_id_) { for (auto& down_id : normal_out_id_) {
SendStartLoop(down_id); SendStartLoop(down_id, gen_step);
} }
++num_of_scopes_;
} else { } else {
VLOG(3) << "Finish loop in scope " << cur_scope_id_; PADDLE_ENFORCE_NE(scope_id_to_gen_step_.find(cur_scope_id_),
scope_id_to_gen_step_.end(),
platform::errors::InvalidArgument(
"Can not find scope id %ld in scope_id_to_gen_step",
cur_scope_id_));
VLOG(3) << "Finish loop in scope " << cur_scope_id_ << " with "
<< scope_id_to_gen_step_.at(cur_scope_id_) << " generation steps.";
scope_id_to_gen_step_.erase(cur_scope_id_);
SendDataReady(stop_loop_id_); SendDataReady(stop_loop_id_);
} }
} }
void CondInterceptor::Run(const InterceptorMessage& msg) { void CondInterceptor::Run(const InterceptorMessage& msg) {
if (msg.message_type() == DATA_IS_READY || if (msg.message_type() == DATA_IS_READY) {
msg.message_type() == DATA_WITH_VARS) { cur_scope_id_ = msg.scope_idx();
if (msg.src_id() == loop_id_) { start_micro_step_ = msg.start_micro_step();
--num_of_scopes_; num_micro_step_ = msg.num_micro_step();
VLOG(3) << "Receving loop again message from " << msg.src_id() scope_id_to_gen_step_.emplace(cur_scope_id_, 0);
<< " waiting other " << num_of_scopes_ << " scopes ready"; Compute(/*gen_step=*/0);
ready_scope_id_.emplace_back(msg.scope_idx());
if (num_of_scopes_ == 0) {
std::sort(ready_scope_id_.begin(), ready_scope_id_.end());
for (auto scope_id : ready_scope_id_) {
VLOG(3) << "Start a new loop in scope " << scope_id;
cur_scope_id_ = scope_id;
Compute();
}
ready_scope_id_.clear();
}
} else {
cur_scope_id_ = msg.scope_idx();
Compute();
}
} else if (msg.message_type() == DATA_IS_USELESS) { } else if (msg.message_type() == DATA_IS_USELESS) {
if (node_->id_to_dep_type().at(msg.src_id()) == DependType::STOP_LOOP) { if (node_->id_to_dep_type().at(msg.src_id()) == DependType::STOP_LOOP) {
for (auto& up_id : normal_in_id_) { for (auto& up_id : normal_in_id_) {
...@@ -158,6 +156,53 @@ void CondInterceptor::Run(const InterceptorMessage& msg) { ...@@ -158,6 +156,53 @@ void CondInterceptor::Run(const InterceptorMessage& msg) {
gc_.get()); gc_.get());
} }
} }
} else if (msg.message_type() == DATA_WITH_VARS) {
int64_t scope_id = msg.scope_idx();
PADDLE_ENFORCE_NE(
scope_id_to_gen_step_.find(scope_id),
scope_id_to_gen_step_.end(),
platform::errors::InvalidArgument(
"Can not find scope id %ld in scope_id_to_gen_step", scope_id));
// Keep the message in order with scope_id
// message with scope 3 never send before scope 1.
int64_t gen_step = scope_id_to_gen_step_.at(scope_id) + 1;
bool wait_prev_scope = false;
// If the previous scope gen_step less than cur scope
// means: the previous scope doesn't finish last step generation, should
// wait.
auto iter = scope_id_to_gen_step_.begin();
while (iter != scope_id_to_gen_step_.end()) {
if (iter->first == scope_id) {
break;
}
if (iter->second < gen_step) {
wait_prev_scope = true;
break;
}
++iter;
}
scope_id_to_gen_step_.at(scope_id) = gen_step;
if (!wait_prev_scope) {
// Start send message to all scopes gen_step equal to cur_scope
std::vector<int64_t> ready_scope_ids;
while (iter != scope_id_to_gen_step_.end()) {
if (iter->second == gen_step) {
ready_scope_ids.emplace_back(iter->first);
} else if (iter->second > gen_step) {
PADDLE_THROW(
platform::errors::Fatal("Some error may occur. Scope %ld's "
"gen_step is much larger than previous.",
iter->first));
} else {
break;
}
++iter;
}
for (auto& scope_id : ready_scope_ids) {
cur_scope_id_ = scope_id;
Compute(gen_step);
}
}
} }
} }
......
...@@ -35,10 +35,10 @@ class CondInterceptor final : public Interceptor { ...@@ -35,10 +35,10 @@ class CondInterceptor final : public Interceptor {
private: private:
void PrepareDeps(); void PrepareDeps();
void Run(const InterceptorMessage& msg); void Run(const InterceptorMessage& msg);
void Compute(); void Compute(int64_t gen_step);
bool GetCondResult(); bool GetCondResult();
void SendDataReady(int64_t down_id); void SendDataReady(int64_t down_id);
void SendStartLoop(int64_t down_id); void SendStartLoop(int64_t down_id, int64_t gen_step);
void ReplyDataIsUseless(int64_t up_id); void ReplyDataIsUseless(int64_t up_id);
int64_t cur_scope_id_; int64_t cur_scope_id_;
...@@ -47,8 +47,9 @@ class CondInterceptor final : public Interceptor { ...@@ -47,8 +47,9 @@ class CondInterceptor final : public Interceptor {
std::set<int64_t> normal_out_id_; std::set<int64_t> normal_out_id_;
int64_t stop_loop_id_; int64_t stop_loop_id_;
int64_t loop_id_; int64_t loop_id_;
int64_t num_of_scopes_{0}; std::map<int64_t, int64_t> scope_id_to_gen_step_;
std::vector<int64_t> ready_scope_id_; int64_t start_micro_step_;
int64_t num_micro_step_;
}; };
} // namespace distributed } // namespace distributed
......
...@@ -40,6 +40,9 @@ message InterceptorMessage { ...@@ -40,6 +40,9 @@ message InterceptorMessage {
optional bool ctrl_message = 4 [ default = false ]; optional bool ctrl_message = 4 [ default = false ];
optional int64 scope_idx = 5 [ default = 0 ]; optional int64 scope_idx = 5 [ default = 0 ];
repeated VarList vars_list = 6; repeated VarList vars_list = 6;
optional int64 gen_step = 7 [ default = -1 ];
optional int64 start_micro_step = 8 [ default = -1 ];
optional int64 num_micro_step = 9 [ default = -1 ];
} }
message InterceptorResponse { optional bool rst = 1 [ default = false ]; } message InterceptorResponse { optional bool rst = 1 [ default = false ]; }
......
...@@ -67,13 +67,16 @@ void StartInterceptor::SendDataReadyToDownStream() { ...@@ -67,13 +67,16 @@ void StartInterceptor::SendDataReadyToDownStream() {
outs.second.second = used_size; outs.second.second = used_size;
} }
if (finish_count_ == batch_size_) { if (finish_count_ == batch_size_) {
int64_t start_micro_step = step_ % node_->max_run_times();
for (int64_t i = 0; i < batch_size_; ++i) { for (int64_t i = 0; i < batch_size_; ++i) {
int64_t scope_id = step_ % node_->max_run_times(); int64_t scope_id = step_ % node_->max_run_times();
InterceptorMessage ready_msg;
ready_msg.set_message_type(DATA_IS_READY);
ready_msg.set_scope_idx(scope_id);
ready_msg.set_start_micro_step(start_micro_step);
ready_msg.set_num_micro_step(batch_size_);
for (auto& outs : out_buffs_) { for (auto& outs : out_buffs_) {
auto down_id = outs.first; auto down_id = outs.first;
InterceptorMessage ready_msg;
ready_msg.set_message_type(DATA_IS_READY);
ready_msg.set_scope_idx(scope_id);
VLOG(3) << "StartInterceptor " << interceptor_id_ VLOG(3) << "StartInterceptor " << interceptor_id_
<< " Send data_is_ready msg to " << down_id << " Send data_is_ready msg to " << down_id
<< " in scope: " << scope_id; << " in scope: " << scope_id;
...@@ -96,6 +99,15 @@ void StartInterceptor::Compute(const InterceptorMessage& msg) { ...@@ -96,6 +99,15 @@ void StartInterceptor::Compute(const InterceptorMessage& msg) {
<< " " << finish_count_; << " " << finish_count_;
finish_count_--; finish_count_--;
if (finish_count_ == 0) { if (finish_count_ == 0) {
auto end = std::chrono::system_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(
end - start_time_);
VLOG(3) << "Spent "
<< double(duration.count()) *
std::chrono::microseconds::period::num /
std::chrono::microseconds::period::den
<< " seconds.";
start_time_ = std::chrono::system_clock::now();
for (int64_t i = 0; i < batch_size_; ++i) { for (int64_t i = 0; i < batch_size_; ++i) {
for (auto& outs : out_buffs_) { for (auto& outs : out_buffs_) {
auto down_id = outs.first; auto down_id = outs.first;
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#pragma once #pragma once
#include <chrono>
#include <utility> #include <utility>
#include "paddle/fluid/distributed/fleet_executor/compute_interceptor.h" #include "paddle/fluid/distributed/fleet_executor/compute_interceptor.h"
...@@ -33,6 +34,8 @@ class StartInterceptor final : public ComputeInterceptor { ...@@ -33,6 +34,8 @@ class StartInterceptor final : public ComputeInterceptor {
int64_t batch_size_{0}; int64_t batch_size_{0};
int64_t finish_count_{0}; int64_t finish_count_{0};
int64_t step_{0}; int64_t step_{0};
std::chrono::time_point<std::chrono::system_clock> start_time_{
std::chrono::system_clock::now()};
}; };
} // namespace distributed } // namespace distributed
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册