未验证 提交 5163c538 编写于 作者: L LiYuRio 提交者: GitHub

[FleetExecutor] Optimize gap between generation step (#53072)

* optimize gap between generation step

* remove useless header
上级 b1d3ec16
......@@ -17,6 +17,7 @@
#include <algorithm>
#include <vector>
#include "gflags/gflags.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
......@@ -237,12 +238,10 @@ bool Carrier::Send(const InterceptorMessage& msg) {
VLOG(3) << "Send a message from interceptor " << src_id
<< " to interceptor " << dst_id << ", which are in the same ranks.";
return EnqueueInterceptorMessage(msg);
} else {
VLOG(3) << "Send a message from interceptor " << src_id
<< " to interceptor " << dst_id
<< ", which are in different ranks.";
return GlobalVal<MessageBus>::Get()->Send(dst_rank, msg);
}
VLOG(3) << "Send a message from interceptor " << src_id << " to interceptor "
<< dst_id << ", which are in different ranks.";
return GlobalVal<MessageBus>::Get()->Send(dst_rank, msg);
}
Interceptor* Carrier::SetInterceptor(int64_t interceptor_id,
......
......@@ -158,25 +158,47 @@ void ComputeInterceptor::DecreaseBuff(int64_t down_id) {
}
bool ComputeInterceptor::IsInputReady() {
for (int64_t i = 0; i < node_->max_run_times(); ++i) {
std::map<int64_t, bool> scope_id_to_finish_flag;
if (!gen_step_to_scope_id_to_finish_flag_.empty()) {
scope_id_to_finish_flag =
gen_step_to_scope_id_to_finish_flag_.begin()->second;
VLOG(3) << "Is Input Ready in gen step "
<< gen_step_to_scope_id_to_finish_flag_.begin()->first;
}
int64_t num_micro_step =
(num_micro_step_ == -1 ? node_->max_run_times() : num_micro_step_);
int64_t start_micro_step = (start_micro_step_ == -1 ? 0 : start_micro_step_);
for (int64_t i = start_micro_step; i < start_micro_step + num_micro_step;
++i) {
bool flag = true;
for (auto& ins : in_readys_) {
auto ready_size_map = ins.second.second;
flag = flag && (ready_size_map.at(i) != 0);
}
if (flag) {
for (auto iter : scope_id_to_finish_flag_) {
if (iter.first == i) {
break;
} else if (!iter.second) {
VLOG(3) << "The previous scope is not ready, waiting for the "
"previous scope "
<< iter.first;
return false;
if (scope_id_to_finish_flag.empty()) {
cur_scope_id_ = i;
return true;
} else if (scope_id_to_finish_flag.find(i) !=
scope_id_to_finish_flag.end()) {
for (auto iter : scope_id_to_finish_flag) {
if (iter.first == i) {
break;
} else if (!iter.second) {
VLOG(3) << "The previous scope is not ready, waiting for the "
"previous scope "
<< iter.first << " in gen_step "
<< gen_step_to_scope_id_to_finish_flag_.begin()->first;
return false;
}
}
cur_scope_id_ = i;
return true;
} else {
VLOG(3) << "Interceptor " << GetInterceptorId() << " in scope " << i
<< " is larger than gen_step "
<< gen_step_to_scope_id_to_finish_flag_.begin()->first;
}
cur_scope_id_ = i;
return true;
} else {
VLOG(3) << "Interceptor " << GetInterceptorId() << " in scope " << i
<< "'s upstreams aren't all ready.";
......@@ -203,6 +225,16 @@ bool ComputeInterceptor::CanWriteOutput() {
}
void ComputeInterceptor::SendDataReadyToDownStream() {
bool need_send_vars = !(node_->vars_to_dtype().empty());
InterceptorMessage ready_msg;
ready_msg.set_start_micro_step(start_micro_step_);
ready_msg.set_num_micro_step(num_micro_step_);
if (need_send_vars) {
ready_msg = PrepareVarsMsg();
} else {
ready_msg.set_message_type(DATA_IS_READY);
ready_msg.set_scope_idx(cur_scope_id_);
}
for (auto& outs : out_buffs_) {
auto down_id = outs.first;
auto max_buff_size = outs.second.first;
......@@ -221,13 +253,17 @@ void ComputeInterceptor::SendDataReadyToDownStream() {
}
outs.second.second = used_size;
InterceptorMessage ready_msg;
ready_msg.set_message_type(DATA_IS_READY);
ready_msg.set_scope_idx(cur_scope_id_);
VLOG(3) << "ComputeInterceptor " << interceptor_id_
<< " Send data_is_ready msg to " << down_id
<< " in scope: " << cur_scope_id_;
Send(down_id, ready_msg);
if (need_send_vars) {
VLOG(3) << "ComputeInterceptor " << interceptor_id_
<< " Send data_with_vars msg to " << down_id
<< " in scope: " << cur_scope_id_;
Send(down_id, ready_msg);
} else {
VLOG(3) << "ComputeInterceptor " << interceptor_id_
<< " Send data_is_ready msg to " << down_id
<< " in scope: " << cur_scope_id_;
Send(down_id, ready_msg);
}
}
}
......@@ -289,13 +325,21 @@ void ComputeInterceptor::Run() {
RunOps();
if (!scope_id_to_finish_flag_.empty()) {
if (!gen_step_to_scope_id_to_finish_flag_.empty()) {
auto iter = gen_step_to_scope_id_to_finish_flag_.begin();
VLOG(3) << "id=" << GetInterceptorId()
<< " ComputeInterceptor running in scope " << cur_scope_id_
<< " with gen_step " << iter->first;
auto& scope_id_to_finish_flag = iter->second;
PADDLE_ENFORCE_NE(
scope_id_to_finish_flag_.find(cur_scope_id_),
scope_id_to_finish_flag_.end(),
scope_id_to_finish_flag.find(cur_scope_id_),
scope_id_to_finish_flag.end(),
platform::errors::NotFound(
"Can not find scope %ld in scope_id_to_finish", cur_scope_id_));
scope_id_to_finish_flag_.erase(cur_scope_id_);
scope_id_to_finish_flag.erase(cur_scope_id_);
if (scope_id_to_finish_flag.empty()) {
gen_step_to_scope_id_to_finish_flag_.erase(iter);
}
}
// send to downstream and increase buff used
......@@ -310,6 +354,8 @@ void ComputeInterceptor::Compute(const InterceptorMessage& msg) {
VLOG(3) << "Compute interceptor " << interceptor_id_
<< " receive data_is_ready " << msg.src_id() << " "
<< msg.scope_idx() << " ";
start_micro_step_ = msg.start_micro_step();
num_micro_step_ = msg.num_micro_step();
IncreaseReady(msg.src_id(), msg.scope_idx());
Run();
} else if (msg.message_type() == DATA_IS_USELESS) {
......@@ -327,10 +373,14 @@ void ComputeInterceptor::Compute(const InterceptorMessage& msg) {
Run();
} else if (msg.message_type() == START_LOOP) {
VLOG(3) << "Compute interceptor " << interceptor_id_
<< " receive start_loop " << msg.src_id() << " " << msg.scope_idx()
<< " ";
<< " receive start_loop " << msg.src_id() << " in scope "
<< msg.scope_idx() << " with gen_step " << msg.gen_step();
start_micro_step_ = msg.start_micro_step();
num_micro_step_ = msg.num_micro_step();
IncreaseReady(msg.src_id(), msg.scope_idx());
scope_id_to_finish_flag_.emplace(msg.scope_idx(), false);
int64_t gen_step = msg.gen_step();
gen_step_to_scope_id_to_finish_flag_[gen_step].emplace(msg.scope_idx(),
false);
Run();
}
}
......
......@@ -52,7 +52,10 @@ class ComputeInterceptor : public Interceptor {
bool IsInputReady();
bool CanWriteOutput();
std::map<int64_t, bool> scope_id_to_finish_flag_;
std::map<int64_t, std::map<int64_t, bool>>
gen_step_to_scope_id_to_finish_flag_;
int64_t start_micro_step_{-1};
int64_t num_micro_step_{-1};
};
} // namespace distributed
......
......@@ -90,13 +90,18 @@ void CondInterceptor::SendDataReady(int64_t down_id) {
InterceptorMessage ready_msg;
ready_msg.set_message_type(DATA_IS_READY);
ready_msg.set_scope_idx(cur_scope_id_);
ready_msg.set_start_micro_step(start_micro_step_);
ready_msg.set_num_micro_step(num_micro_step_);
Send(down_id, ready_msg);
}
void CondInterceptor::SendStartLoop(int64_t down_id) {
void CondInterceptor::SendStartLoop(int64_t down_id, int64_t gen_step) {
InterceptorMessage ready_msg;
ready_msg.set_message_type(START_LOOP);
ready_msg.set_scope_idx(cur_scope_id_);
ready_msg.set_gen_step(gen_step);
ready_msg.set_start_micro_step(start_micro_step_);
ready_msg.set_num_micro_step(num_micro_step_);
Send(down_id, ready_msg);
}
......@@ -107,43 +112,36 @@ void CondInterceptor::ReplyDataIsUseless(int64_t up_id) {
Send(up_id, ready_msg);
}
void CondInterceptor::Compute() {
void CondInterceptor::Compute(int64_t gen_step) {
bool cond = GetCondResult();
VLOG(3) << "Cond interceptor get condition var " << node_->cond_var()
<< " with value " << cond;
if (cond) {
VLOG(3) << "Loop again in scope " << cur_scope_id_;
VLOG(3) << "Loop again in scope " << cur_scope_id_ << " gen_step "
<< gen_step;
for (auto& down_id : normal_out_id_) {
SendStartLoop(down_id);
SendStartLoop(down_id, gen_step);
}
++num_of_scopes_;
} else {
VLOG(3) << "Finish loop in scope " << cur_scope_id_;
PADDLE_ENFORCE_NE(scope_id_to_gen_step_.find(cur_scope_id_),
scope_id_to_gen_step_.end(),
platform::errors::InvalidArgument(
"Can not find scope id %ld in scope_id_to_gen_step",
cur_scope_id_));
VLOG(3) << "Finish loop in scope " << cur_scope_id_ << " with "
<< scope_id_to_gen_step_.at(cur_scope_id_) << " generation steps.";
scope_id_to_gen_step_.erase(cur_scope_id_);
SendDataReady(stop_loop_id_);
}
}
void CondInterceptor::Run(const InterceptorMessage& msg) {
if (msg.message_type() == DATA_IS_READY ||
msg.message_type() == DATA_WITH_VARS) {
if (msg.src_id() == loop_id_) {
--num_of_scopes_;
VLOG(3) << "Receving loop again message from " << msg.src_id()
<< " waiting other " << num_of_scopes_ << " scopes ready";
ready_scope_id_.emplace_back(msg.scope_idx());
if (num_of_scopes_ == 0) {
std::sort(ready_scope_id_.begin(), ready_scope_id_.end());
for (auto scope_id : ready_scope_id_) {
VLOG(3) << "Start a new loop in scope " << scope_id;
cur_scope_id_ = scope_id;
Compute();
}
ready_scope_id_.clear();
}
} else {
cur_scope_id_ = msg.scope_idx();
Compute();
}
if (msg.message_type() == DATA_IS_READY) {
cur_scope_id_ = msg.scope_idx();
start_micro_step_ = msg.start_micro_step();
num_micro_step_ = msg.num_micro_step();
scope_id_to_gen_step_.emplace(cur_scope_id_, 0);
Compute(/*gen_step=*/0);
} else if (msg.message_type() == DATA_IS_USELESS) {
if (node_->id_to_dep_type().at(msg.src_id()) == DependType::STOP_LOOP) {
for (auto& up_id : normal_in_id_) {
......@@ -158,6 +156,53 @@ void CondInterceptor::Run(const InterceptorMessage& msg) {
gc_.get());
}
}
} else if (msg.message_type() == DATA_WITH_VARS) {
int64_t scope_id = msg.scope_idx();
PADDLE_ENFORCE_NE(
scope_id_to_gen_step_.find(scope_id),
scope_id_to_gen_step_.end(),
platform::errors::InvalidArgument(
"Can not find scope id %ld in scope_id_to_gen_step", scope_id));
// Keep the message in order with scope_id
// message with scope 3 never send before scope 1.
int64_t gen_step = scope_id_to_gen_step_.at(scope_id) + 1;
bool wait_prev_scope = false;
// If the previous scope gen_step less than cur scope
// means: the previous scope doesn't finish last step generation, should
// wait.
auto iter = scope_id_to_gen_step_.begin();
while (iter != scope_id_to_gen_step_.end()) {
if (iter->first == scope_id) {
break;
}
if (iter->second < gen_step) {
wait_prev_scope = true;
break;
}
++iter;
}
scope_id_to_gen_step_.at(scope_id) = gen_step;
if (!wait_prev_scope) {
// Start send message to all scopes gen_step equal to cur_scope
std::vector<int64_t> ready_scope_ids;
while (iter != scope_id_to_gen_step_.end()) {
if (iter->second == gen_step) {
ready_scope_ids.emplace_back(iter->first);
} else if (iter->second > gen_step) {
PADDLE_THROW(
platform::errors::Fatal("Some error may occur. Scope %ld's "
"gen_step is much larger than previous.",
iter->first));
} else {
break;
}
++iter;
}
for (auto& scope_id : ready_scope_ids) {
cur_scope_id_ = scope_id;
Compute(gen_step);
}
}
}
}
......
......@@ -35,10 +35,10 @@ class CondInterceptor final : public Interceptor {
private:
void PrepareDeps();
void Run(const InterceptorMessage& msg);
void Compute();
void Compute(int64_t gen_step);
bool GetCondResult();
void SendDataReady(int64_t down_id);
void SendStartLoop(int64_t down_id);
void SendStartLoop(int64_t down_id, int64_t gen_step);
void ReplyDataIsUseless(int64_t up_id);
int64_t cur_scope_id_;
......@@ -47,8 +47,9 @@ class CondInterceptor final : public Interceptor {
std::set<int64_t> normal_out_id_;
int64_t stop_loop_id_;
int64_t loop_id_;
int64_t num_of_scopes_{0};
std::vector<int64_t> ready_scope_id_;
std::map<int64_t, int64_t> scope_id_to_gen_step_;
int64_t start_micro_step_;
int64_t num_micro_step_;
};
} // namespace distributed
......
......@@ -40,6 +40,9 @@ message InterceptorMessage {
optional bool ctrl_message = 4 [ default = false ];
optional int64 scope_idx = 5 [ default = 0 ];
repeated VarList vars_list = 6;
optional int64 gen_step = 7 [ default = -1 ];
optional int64 start_micro_step = 8 [ default = -1 ];
optional int64 num_micro_step = 9 [ default = -1 ];
}
message InterceptorResponse { optional bool rst = 1 [ default = false ]; }
......
......@@ -67,13 +67,16 @@ void StartInterceptor::SendDataReadyToDownStream() {
outs.second.second = used_size;
}
if (finish_count_ == batch_size_) {
int64_t start_micro_step = step_ % node_->max_run_times();
for (int64_t i = 0; i < batch_size_; ++i) {
int64_t scope_id = step_ % node_->max_run_times();
InterceptorMessage ready_msg;
ready_msg.set_message_type(DATA_IS_READY);
ready_msg.set_scope_idx(scope_id);
ready_msg.set_start_micro_step(start_micro_step);
ready_msg.set_num_micro_step(batch_size_);
for (auto& outs : out_buffs_) {
auto down_id = outs.first;
InterceptorMessage ready_msg;
ready_msg.set_message_type(DATA_IS_READY);
ready_msg.set_scope_idx(scope_id);
VLOG(3) << "StartInterceptor " << interceptor_id_
<< " Send data_is_ready msg to " << down_id
<< " in scope: " << scope_id;
......@@ -96,6 +99,15 @@ void StartInterceptor::Compute(const InterceptorMessage& msg) {
<< " " << finish_count_;
finish_count_--;
if (finish_count_ == 0) {
auto end = std::chrono::system_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(
end - start_time_);
VLOG(3) << "Spent "
<< double(duration.count()) *
std::chrono::microseconds::period::num /
std::chrono::microseconds::period::den
<< " seconds.";
start_time_ = std::chrono::system_clock::now();
for (int64_t i = 0; i < batch_size_; ++i) {
for (auto& outs : out_buffs_) {
auto down_id = outs.first;
......
......@@ -14,6 +14,7 @@
#pragma once
#include <chrono>
#include <utility>
#include "paddle/fluid/distributed/fleet_executor/compute_interceptor.h"
......@@ -33,6 +34,8 @@ class StartInterceptor final : public ComputeInterceptor {
int64_t batch_size_{0};
int64_t finish_count_{0};
int64_t step_{0};
std::chrono::time_point<std::chrono::system_clock> start_time_{
std::chrono::system_clock::now()};
};
} // namespace distributed
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册