未验证 提交 81f4ef4f 编写于 作者: L LiYuRio 提交者: GitHub

optimize overlap between steps (#51974)

上级 92c2dcbd
...@@ -51,6 +51,7 @@ cc_library( ...@@ -51,6 +51,7 @@ cc_library(
op_registry op_registry
executor_gc_helper executor_gc_helper
gflags gflags
flags
glog glog
${BRPC_DEPS}) ${BRPC_DEPS})
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <algorithm> #include <algorithm>
#include <vector> #include <vector>
#include "gflags/gflags.h"
#include "paddle/fluid/distributed/fleet_executor/global.h" #include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h"
...@@ -28,6 +29,8 @@ ...@@ -28,6 +29,8 @@
#include "paddle/fluid/framework/variable.h" #include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/framework/variable_helper.h" #include "paddle/fluid/framework/variable_helper.h"
DECLARE_bool(fleetexecutor_debug_mode);
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
...@@ -48,6 +51,72 @@ void Carrier::Init( ...@@ -48,6 +51,72 @@ void Carrier::Init(
thread_num_ = 1; thread_num_ = 1;
thread_pool_.SetThreadNum(thread_num_); thread_pool_.SetThreadNum(thread_num_);
thread_pool_.Start(); thread_pool_.Start();
if (FLAGS_fleetexecutor_debug_mode) {
test_thread_ = std::thread([this]() { loop_to_send_msg(); });
cache_begin_ == std::chrono::steady_clock::now();
}
}
void Carrier::loop_to_send_msg() {
// VLOG(3) << "loop_send_msg loop now";
while (1) {
while (1) {
int q_size = 0;
std::chrono::time_point<std::chrono::steady_clock> c_begin;
{
std::lock_guard<std::mutex> lock(running_mutex_);
q_size = messages_for_test_.size();
c_begin = cache_begin_;
}
auto now = std::chrono::steady_clock::now();
auto delta =
std::chrono::duration_cast<std::chrono::milliseconds>(now - c_begin)
.count();
if (q_size < 2 && delta < 5000) {
std::this_thread::sleep_for(std::chrono::milliseconds(10));
continue;
} else {
VLOG(3) << "messages_for_test_ q_size:" << q_size << ", delta:" << delta
<< ", will send all msg";
break;
}
}
{
std::lock_guard<std::mutex> lock(running_mutex_);
while (!messages_for_test_.empty()) {
auto msg = messages_for_test_.back();
messages_for_test_.pop_back();
int64_t src_id = msg.src_id();
// TODO(liyurui): compatible solution, will be removed completely in the
// future
if (interceptor_id_to_rank_.find(src_id) ==
interceptor_id_to_rank_.end() &&
src_id == SOURCE_ID) {
src_id = msg.dst_id();
}
int64_t dst_id = msg.dst_id();
int64_t dst_rank = GetRank(dst_id);
VLOG(3) << "Send a cached message from interceptor " << src_id
<< " to interceptor " << dst_id
<< ", which are in different ranks, scope_idx:"
<< msg.scope_idx();
if (!GlobalVal<MessageBus>::Get()->Send(dst_rank, msg)) {
LOG(FATAL) << "send msg error";
}
std::this_thread::sleep_for(std::chrono::milliseconds(2));
}
cache_begin_ = std::chrono::steady_clock::now();
}
}
VLOG(3) << "reset cache_begin_";
} }
void Carrier::Init( void Carrier::Init(
...@@ -95,6 +164,11 @@ void Carrier::Init( ...@@ -95,6 +164,11 @@ void Carrier::Init(
thread_pool_.SetThreadNum(thread_num_); thread_pool_.SetThreadNum(thread_num_);
thread_pool_.Start(); thread_pool_.Start();
if (FLAGS_fleetexecutor_debug_mode) {
test_thread_ = std::thread([this]() { loop_to_send_msg(); });
cache_begin_ == std::chrono::steady_clock::now();
}
CreateInterceptors(); CreateInterceptors();
is_init_ = true; is_init_ = true;
} }
...@@ -230,12 +304,39 @@ bool Carrier::Send(const InterceptorMessage& msg) { ...@@ -230,12 +304,39 @@ bool Carrier::Send(const InterceptorMessage& msg) {
VLOG(3) << "Send a message from interceptor " << src_id VLOG(3) << "Send a message from interceptor " << src_id
<< " to interceptor " << dst_id << ", which are in the same ranks."; << " to interceptor " << dst_id << ", which are in the same ranks.";
return EnqueueInterceptorMessage(msg); return EnqueueInterceptorMessage(msg);
} else { }
if (!FLAGS_fleetexecutor_debug_mode) {
VLOG(3) << "Send a message from interceptor " << src_id VLOG(3) << "Send a message from interceptor " << src_id
<< " to interceptor " << dst_id << " to interceptor " << dst_id
<< ", which are in different ranks."; << ", which are in different ranks.";
return GlobalVal<MessageBus>::Get()->Send(dst_rank, msg); return GlobalVal<MessageBus>::Get()->Send(dst_rank, msg);
} }
if (msg.message_type() != DATA_IS_READY) {
VLOG(3) << "Send a message from interceptor " << src_id
<< " to interceptor " << dst_id
<< ", which are in different ranks.";
return GlobalVal<MessageBus>::Get()->Send(dst_rank, msg);
}
{
VLOG(3) << "prepare executor debug";
std::unique_lock<std::mutex> lock(running_mutex_);
if (messages_for_test_.empty()) {
cache_begin_ = std::chrono::steady_clock::now();
// std::time_t now_c =
// std::chrono::system_clock::to_time_t(cache_begin_));
VLOG(3) << "messages_for_test_ empty, reset cache_begin_";
}
VLOG(3) << "Cache message from interceptor " << src_id << " to interceptor "
<< dst_id
<< ", which are in different ranks, scope_idx:" << msg.scope_idx();
messages_for_test_.emplace_back(msg);
}
return true;
} }
Interceptor* Carrier::SetInterceptor(int64_t interceptor_id, Interceptor* Carrier::SetInterceptor(int64_t interceptor_id,
......
...@@ -14,11 +14,14 @@ ...@@ -14,11 +14,14 @@
#pragma once #pragma once
#include <chrono>
#include <condition_variable> #include <condition_variable>
#include <memory> #include <memory>
#include <mutex> #include <mutex>
#include <queue>
#include <set> #include <set>
#include <string> #include <string>
#include <thread>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
...@@ -118,6 +121,12 @@ class Carrier final { ...@@ -118,6 +121,12 @@ class Carrier final {
int thread_num_; int thread_num_;
TaskLoopThreadPool thread_pool_; TaskLoopThreadPool thread_pool_;
std::unordered_set<int64_t> interceptor_ids_; std::unordered_set<int64_t> interceptor_ids_;
std::deque<InterceptorMessage> messages_for_test_;
std::thread test_thread_;
std::chrono::time_point<std::chrono::steady_clock> cache_begin_;
void loop_to_send_msg();
}; };
} // namespace distributed } // namespace distributed
......
...@@ -29,74 +29,6 @@ ...@@ -29,74 +29,6 @@
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
namespace {
template <typename T>
void SetVarResult(const std::string& name,
T value,
int64_t scope_id,
framework::Scope* scope,
const platform::Place& place,
const std::vector<int64_t>& dim_vec) {
auto* var = scope->FindVar(name);
auto* tensor = var->GetMutable<phi::DenseTensor>();
if (!var) {
VLOG(3) << "Create var and memory for var " << name;
var = scope->Var(name);
phi::DDim dims = phi::make_ddim(dim_vec);
tensor->Resize(dims);
tensor->mutable_data<T>(dims, place);
}
PADDLE_ENFORCE_EQ(
tensor->dims().size(),
1,
platform::errors::OutOfRange("Only support transfer size 1 value."));
PADDLE_ENFORCE_EQ(
tensor->dims().at(0),
1,
platform::errors::OutOfRange("Only support transfer size 1 value."));
if (platform::is_gpu_place(tensor->place())) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
phi::DenseTensor cpu_tensor;
auto dim = phi::make_ddim({1});
cpu_tensor.mutable_data<T>(dim, platform::CPUPlace());
auto* cpu_tensor_ptr = cpu_tensor.data<T>();
cpu_tensor_ptr[0] = value;
framework::TensorCopySync(cpu_tensor, tensor->place(), tensor);
#endif
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupport device for cond interceptor."));
}
}
template <typename T>
T GetVarResult(const std::string& name,
int64_t scope_id,
framework::Scope* scope) {
auto* var = scope->FindVar(name);
PADDLE_ENFORCE(var,
platform::errors::NotFound(
"Variable %s not exists in scope %ld", name, scope_id));
const auto& tensor = var->Get<phi::DenseTensor>();
T res;
if (platform::is_gpu_place(tensor.place())) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
phi::DenseTensor cpu_tensor;
framework::TensorCopySync(tensor, platform::CPUPlace(), &cpu_tensor);
res = cpu_tensor.data<T>()[0];
#endif
} else if (platform::is_cpu_place(tensor.place())) {
res = tensor.data<T>()[0];
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupport device for cond interceptor."));
}
return res;
}
} // namespace
ComputeInterceptor::ComputeInterceptor(int64_t interceptor_id, TaskNode* node) ComputeInterceptor::ComputeInterceptor(int64_t interceptor_id, TaskNode* node)
: Interceptor(interceptor_id, node) { : Interceptor(interceptor_id, node) {
PrepareDeps(); PrepareDeps();
...@@ -172,25 +104,47 @@ void ComputeInterceptor::DecreaseBuff(int64_t down_id) { ...@@ -172,25 +104,47 @@ void ComputeInterceptor::DecreaseBuff(int64_t down_id) {
} }
bool ComputeInterceptor::IsInputReady() { bool ComputeInterceptor::IsInputReady() {
for (int64_t i = 0; i < node_->max_run_times(); ++i) { std::map<int64_t, bool> scope_id_to_finish_flag;
if (!gen_step_to_scope_id_to_finish_flag_.empty()) {
scope_id_to_finish_flag =
gen_step_to_scope_id_to_finish_flag_.begin()->second;
VLOG(3) << "Is Input Ready in gen step "
<< gen_step_to_scope_id_to_finish_flag_.begin()->first;
}
int64_t num_micro_step =
(num_micro_step_ == -1 ? node_->max_run_times() : num_micro_step_);
int64_t start_micro_step = (start_micro_step_ == -1 ? 0 : start_micro_step_);
for (int64_t i = start_micro_step; i < start_micro_step + num_micro_step;
++i) {
bool flag = true; bool flag = true;
for (auto& ins : in_readys_) { for (auto& ins : in_readys_) {
auto ready_size_map = ins.second.second; auto ready_size_map = ins.second.second;
flag = flag && (ready_size_map.at(i) != 0); flag = flag && (ready_size_map.at(i) != 0);
} }
if (flag) { if (flag) {
for (auto iter : scope_id_to_finish_flag_) { if (scope_id_to_finish_flag.empty()) {
cur_scope_id_ = i;
return true;
} else if (scope_id_to_finish_flag.find(i) !=
scope_id_to_finish_flag.end()) {
for (auto iter : scope_id_to_finish_flag) {
if (iter.first == i) { if (iter.first == i) {
break; break;
} else if (!iter.second) { } else if (!iter.second) {
VLOG(3) << "The previous scope is not ready, waiting for the " VLOG(3) << "The previous scope is not ready, waiting for the "
"previous scope " "previous scope "
<< iter.first; << iter.first << " in gen_step "
<< gen_step_to_scope_id_to_finish_flag_.begin()->first;
return false; return false;
} }
} }
cur_scope_id_ = i; cur_scope_id_ = i;
return true; return true;
} else {
VLOG(3) << "Interceptor " << GetInterceptorId() << " in scope " << i
<< " is larger than gen_step "
<< gen_step_to_scope_id_to_finish_flag_.begin()->first;
}
} else { } else {
VLOG(3) << "Interceptor " << GetInterceptorId() << " in scope " << i VLOG(3) << "Interceptor " << GetInterceptorId() << " in scope " << i
<< "'s upstreams aren't all ready."; << "'s upstreams aren't all ready.";
...@@ -217,6 +171,16 @@ bool ComputeInterceptor::CanWriteOutput() { ...@@ -217,6 +171,16 @@ bool ComputeInterceptor::CanWriteOutput() {
} }
void ComputeInterceptor::SendDataReadyToDownStream() { void ComputeInterceptor::SendDataReadyToDownStream() {
bool need_send_vars = !(node_->vars_to_dtype().empty());
InterceptorMessage ready_msg;
ready_msg.set_start_micro_step(start_micro_step_);
ready_msg.set_num_micro_step(num_micro_step_);
if (need_send_vars) {
ready_msg = PrepareVarsMsg();
} else {
ready_msg.set_message_type(DATA_IS_READY);
ready_msg.set_scope_idx(cur_scope_id_);
}
for (auto& outs : out_buffs_) { for (auto& outs : out_buffs_) {
auto down_id = outs.first; auto down_id = outs.first;
auto max_buff_size = outs.second.first; auto max_buff_size = outs.second.first;
...@@ -235,17 +199,12 @@ void ComputeInterceptor::SendDataReadyToDownStream() { ...@@ -235,17 +199,12 @@ void ComputeInterceptor::SendDataReadyToDownStream() {
} }
outs.second.second = used_size; outs.second.second = used_size;
bool need_send_vars = !(node_->vars_to_dtype().empty());
if (need_send_vars) { if (need_send_vars) {
InterceptorMessage ready_msg = PrepareVarsMsg();
VLOG(3) << "ComputeInterceptor " << interceptor_id_ VLOG(3) << "ComputeInterceptor " << interceptor_id_
<< " Send data_with_vars msg to " << down_id << " Send data_with_vars msg to " << down_id
<< " in scope: " << cur_scope_id_; << " in scope: " << cur_scope_id_;
Send(down_id, ready_msg); Send(down_id, ready_msg);
} else { } else {
InterceptorMessage ready_msg;
ready_msg.set_message_type(DATA_IS_READY);
ready_msg.set_scope_idx(cur_scope_id_);
VLOG(3) << "ComputeInterceptor " << interceptor_id_ VLOG(3) << "ComputeInterceptor " << interceptor_id_
<< " Send data_is_ready msg to " << down_id << " Send data_is_ready msg to " << down_id
<< " in scope: " << cur_scope_id_; << " in scope: " << cur_scope_id_;
...@@ -339,13 +298,21 @@ void ComputeInterceptor::Run() { ...@@ -339,13 +298,21 @@ void ComputeInterceptor::Run() {
RunOps(); RunOps();
if (!scope_id_to_finish_flag_.empty()) { if (!gen_step_to_scope_id_to_finish_flag_.empty()) {
auto iter = gen_step_to_scope_id_to_finish_flag_.begin();
VLOG(3) << "id=" << GetInterceptorId()
<< " ComputeInterceptor running in scope " << cur_scope_id_
<< " with gen_step " << iter->first;
auto& scope_id_to_finish_flag = iter->second;
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
scope_id_to_finish_flag_.find(cur_scope_id_), scope_id_to_finish_flag.find(cur_scope_id_),
scope_id_to_finish_flag_.end(), scope_id_to_finish_flag.end(),
platform::errors::NotFound( platform::errors::NotFound(
"Can not find scope %ld in scope_id_to_finish", cur_scope_id_)); "Can not find scope %ld in scope_id_to_finish", cur_scope_id_));
scope_id_to_finish_flag_.erase(cur_scope_id_); scope_id_to_finish_flag.erase(cur_scope_id_);
if (scope_id_to_finish_flag.empty()) {
gen_step_to_scope_id_to_finish_flag_.erase(iter);
}
} }
// send to downstream and increase buff used // send to downstream and increase buff used
...@@ -385,6 +352,8 @@ void ComputeInterceptor::Compute(const InterceptorMessage& msg) { ...@@ -385,6 +352,8 @@ void ComputeInterceptor::Compute(const InterceptorMessage& msg) {
VLOG(3) << "Compute interceptor " << interceptor_id_ VLOG(3) << "Compute interceptor " << interceptor_id_
<< " receive data_is_ready " << msg.src_id() << " " << " receive data_is_ready " << msg.src_id() << " "
<< msg.scope_idx() << " "; << msg.scope_idx() << " ";
start_micro_step_ = msg.start_micro_step();
num_micro_step_ = msg.num_micro_step();
IncreaseReady(msg.src_id(), msg.scope_idx()); IncreaseReady(msg.src_id(), msg.scope_idx());
Run(); Run();
} else if (msg.message_type() == DATA_IS_USELESS) { } else if (msg.message_type() == DATA_IS_USELESS) {
...@@ -402,10 +371,14 @@ void ComputeInterceptor::Compute(const InterceptorMessage& msg) { ...@@ -402,10 +371,14 @@ void ComputeInterceptor::Compute(const InterceptorMessage& msg) {
Run(); Run();
} else if (msg.message_type() == START_LOOP) { } else if (msg.message_type() == START_LOOP) {
VLOG(3) << "Compute interceptor " << interceptor_id_ VLOG(3) << "Compute interceptor " << interceptor_id_
<< " receive start_loop " << msg.src_id() << " " << msg.scope_idx() << " receive start_loop " << msg.src_id() << " in scope "
<< " "; << msg.scope_idx() << " with gen_step " << msg.gen_step();
start_micro_step_ = msg.start_micro_step();
num_micro_step_ = msg.num_micro_step();
IncreaseReady(msg.src_id(), msg.scope_idx()); IncreaseReady(msg.src_id(), msg.scope_idx());
scope_id_to_finish_flag_.emplace(msg.scope_idx(), false); int64_t gen_step = msg.gen_step();
gen_step_to_scope_id_to_finish_flag_[gen_step].emplace(msg.scope_idx(),
false);
Run(); Run();
} }
} }
......
...@@ -52,7 +52,10 @@ class ComputeInterceptor : public Interceptor { ...@@ -52,7 +52,10 @@ class ComputeInterceptor : public Interceptor {
bool IsInputReady(); bool IsInputReady();
bool CanWriteOutput(); bool CanWriteOutput();
std::map<int64_t, bool> scope_id_to_finish_flag_; std::map<int64_t, std::map<int64_t, bool>>
gen_step_to_scope_id_to_finish_flag_;
int64_t start_micro_step_{-1};
int64_t num_micro_step_{-1};
}; };
} // namespace distributed } // namespace distributed
......
...@@ -90,13 +90,18 @@ void CondInterceptor::SendDataReady(int64_t down_id) { ...@@ -90,13 +90,18 @@ void CondInterceptor::SendDataReady(int64_t down_id) {
InterceptorMessage ready_msg; InterceptorMessage ready_msg;
ready_msg.set_message_type(DATA_IS_READY); ready_msg.set_message_type(DATA_IS_READY);
ready_msg.set_scope_idx(cur_scope_id_); ready_msg.set_scope_idx(cur_scope_id_);
ready_msg.set_start_micro_step(start_micro_step_);
ready_msg.set_num_micro_step(num_micro_step_);
Send(down_id, ready_msg); Send(down_id, ready_msg);
} }
void CondInterceptor::SendStartLoop(int64_t down_id) { void CondInterceptor::SendStartLoop(int64_t down_id, int64_t gen_step) {
InterceptorMessage ready_msg; InterceptorMessage ready_msg;
ready_msg.set_message_type(START_LOOP); ready_msg.set_message_type(START_LOOP);
ready_msg.set_scope_idx(cur_scope_id_); ready_msg.set_scope_idx(cur_scope_id_);
ready_msg.set_gen_step(gen_step);
ready_msg.set_start_micro_step(start_micro_step_);
ready_msg.set_num_micro_step(num_micro_step_);
Send(down_id, ready_msg); Send(down_id, ready_msg);
} }
...@@ -107,43 +112,36 @@ void CondInterceptor::ReplyDataIsUseless(int64_t up_id) { ...@@ -107,43 +112,36 @@ void CondInterceptor::ReplyDataIsUseless(int64_t up_id) {
Send(up_id, ready_msg); Send(up_id, ready_msg);
} }
void CondInterceptor::Compute() { void CondInterceptor::Compute(int64_t gen_step) {
bool cond = GetCondResult(); bool cond = GetCondResult();
VLOG(3) << "Cond interceptor get condition var " << node_->cond_var() VLOG(3) << "Cond interceptor get condition var " << node_->cond_var()
<< " with value " << cond; << " with value " << cond;
if (cond) { if (cond) {
VLOG(3) << "Loop again in scope " << cur_scope_id_; VLOG(3) << "Loop again in scope " << cur_scope_id_ << " gen_step "
<< gen_step;
for (auto& down_id : normal_out_id_) { for (auto& down_id : normal_out_id_) {
SendStartLoop(down_id); SendStartLoop(down_id, gen_step);
} }
++num_of_scopes_;
} else { } else {
VLOG(3) << "Finish loop in scope " << cur_scope_id_; PADDLE_ENFORCE_NE(scope_id_to_gen_step_.find(cur_scope_id_),
scope_id_to_gen_step_.end(),
platform::errors::InvalidArgument(
"Can not find scope id %ld in scope_id_to_gen_step",
cur_scope_id_));
VLOG(3) << "Finish loop in scope " << cur_scope_id_ << " with "
<< scope_id_to_gen_step_.at(cur_scope_id_) << " generation steps.";
scope_id_to_gen_step_.erase(cur_scope_id_);
SendDataReady(stop_loop_id_); SendDataReady(stop_loop_id_);
} }
} }
void CondInterceptor::Run(const InterceptorMessage& msg) { void CondInterceptor::Run(const InterceptorMessage& msg) {
if (msg.message_type() == DATA_IS_READY || if (msg.message_type() == DATA_IS_READY) {
msg.message_type() == DATA_WITH_VARS) {
if (msg.src_id() == loop_id_) {
--num_of_scopes_;
VLOG(3) << "Receving loop again message from " << msg.src_id()
<< " waiting other " << num_of_scopes_ << " scopes ready";
ready_scope_id_.emplace_back(msg.scope_idx());
if (num_of_scopes_ == 0) {
std::sort(ready_scope_id_.begin(), ready_scope_id_.end());
for (auto scope_id : ready_scope_id_) {
VLOG(3) << "Start a new loop in scope " << scope_id;
cur_scope_id_ = scope_id;
Compute();
}
ready_scope_id_.clear();
}
} else {
cur_scope_id_ = msg.scope_idx(); cur_scope_id_ = msg.scope_idx();
Compute(); start_micro_step_ = msg.start_micro_step();
} num_micro_step_ = msg.num_micro_step();
scope_id_to_gen_step_.emplace(cur_scope_id_, 0);
Compute(/*gen_step=*/0);
} else if (msg.message_type() == DATA_IS_USELESS) { } else if (msg.message_type() == DATA_IS_USELESS) {
if (node_->id_to_dep_type().at(msg.src_id()) == DependType::STOP_LOOP) { if (node_->id_to_dep_type().at(msg.src_id()) == DependType::STOP_LOOP) {
for (auto& up_id : normal_in_id_) { for (auto& up_id : normal_in_id_) {
...@@ -158,6 +156,53 @@ void CondInterceptor::Run(const InterceptorMessage& msg) { ...@@ -158,6 +156,53 @@ void CondInterceptor::Run(const InterceptorMessage& msg) {
gc_.get()); gc_.get());
} }
} }
} else if (msg.message_type() == DATA_WITH_VARS) {
int64_t scope_id = msg.scope_idx();
PADDLE_ENFORCE_NE(
scope_id_to_gen_step_.find(scope_id),
scope_id_to_gen_step_.end(),
platform::errors::InvalidArgument(
"Can not find scope id %ld in scope_id_to_gen_step", scope_id));
// Keep the message in order with scope_id
// message with scope 3 never send before scope 1.
int64_t gen_step = scope_id_to_gen_step_.at(scope_id) + 1;
bool wait_prev_scope = false;
// If the previous scope gen_step less than cur scope
// means: the previous scope doesn't finish last step generation, should
// wait.
auto iter = scope_id_to_gen_step_.begin();
while (iter != scope_id_to_gen_step_.end()) {
if (iter->first == scope_id) {
break;
}
if (iter->second < gen_step) {
wait_prev_scope = true;
break;
}
++iter;
}
scope_id_to_gen_step_.at(scope_id) = gen_step;
if (!wait_prev_scope) {
// Start send message to all scopes gen_step equal to cur_scope
std::vector<int64_t> ready_scope_ids;
while (iter != scope_id_to_gen_step_.end()) {
if (iter->second == gen_step) {
ready_scope_ids.emplace_back(iter->first);
} else if (iter->second > gen_step) {
PADDLE_THROW(
platform::errors::Fatal("Some error may occur. Scope %ld's "
"gen_step is much larger than previous.",
iter->first));
} else {
break;
}
++iter;
}
for (auto& scope_id : ready_scope_ids) {
cur_scope_id_ = scope_id;
Compute(gen_step);
}
}
} }
} }
......
...@@ -35,10 +35,10 @@ class CondInterceptor final : public Interceptor { ...@@ -35,10 +35,10 @@ class CondInterceptor final : public Interceptor {
private: private:
void PrepareDeps(); void PrepareDeps();
void Run(const InterceptorMessage& msg); void Run(const InterceptorMessage& msg);
void Compute(); void Compute(int64_t gen_step);
bool GetCondResult(); bool GetCondResult();
void SendDataReady(int64_t down_id); void SendDataReady(int64_t down_id);
void SendStartLoop(int64_t down_id); void SendStartLoop(int64_t down_id, int64_t gen_step);
void ReplyDataIsUseless(int64_t up_id); void ReplyDataIsUseless(int64_t up_id);
int64_t cur_scope_id_; int64_t cur_scope_id_;
...@@ -47,8 +47,9 @@ class CondInterceptor final : public Interceptor { ...@@ -47,8 +47,9 @@ class CondInterceptor final : public Interceptor {
std::set<int64_t> normal_out_id_; std::set<int64_t> normal_out_id_;
int64_t stop_loop_id_; int64_t stop_loop_id_;
int64_t loop_id_; int64_t loop_id_;
int64_t num_of_scopes_{0}; std::map<int64_t, int64_t> scope_id_to_gen_step_;
std::vector<int64_t> ready_scope_id_; int64_t start_micro_step_;
int64_t num_micro_step_;
}; };
} // namespace distributed } // namespace distributed
......
...@@ -48,6 +48,9 @@ message InterceptorMessage { ...@@ -48,6 +48,9 @@ message InterceptorMessage {
optional bool ctrl_message = 4 [ default = false ]; optional bool ctrl_message = 4 [ default = false ];
optional int64 scope_idx = 5 [ default = 0 ]; optional int64 scope_idx = 5 [ default = 0 ];
repeated VarList vars_list = 6; repeated VarList vars_list = 6;
optional int64 gen_step = 7 [ default = -1 ];
optional int64 start_micro_step = 8 [ default = -1 ];
optional int64 num_micro_step = 9 [ default = -1 ];
} }
message InterceptorResponse { optional bool rst = 1 [ default = false ]; } message InterceptorResponse { optional bool rst = 1 [ default = false ]; }
......
...@@ -67,13 +67,16 @@ void StartInterceptor::SendDataReadyToDownStream() { ...@@ -67,13 +67,16 @@ void StartInterceptor::SendDataReadyToDownStream() {
outs.second.second = used_size; outs.second.second = used_size;
} }
if (finish_count_ == batch_size_) { if (finish_count_ == batch_size_) {
int64_t start_micro_step = step_ % node_->max_run_times();
for (int64_t i = 0; i < batch_size_; ++i) { for (int64_t i = 0; i < batch_size_; ++i) {
int64_t scope_id = step_ % node_->max_run_times(); int64_t scope_id = step_ % node_->max_run_times();
for (auto& outs : out_buffs_) {
auto down_id = outs.first;
InterceptorMessage ready_msg; InterceptorMessage ready_msg;
ready_msg.set_message_type(DATA_IS_READY); ready_msg.set_message_type(DATA_IS_READY);
ready_msg.set_scope_idx(scope_id); ready_msg.set_scope_idx(scope_id);
ready_msg.set_start_micro_step(start_micro_step);
ready_msg.set_num_micro_step(batch_size_);
for (auto& outs : out_buffs_) {
auto down_id = outs.first;
VLOG(3) << "StartInterceptor " << interceptor_id_ VLOG(3) << "StartInterceptor " << interceptor_id_
<< " Send data_is_ready msg to " << down_id << " Send data_is_ready msg to " << down_id
<< " in scope: " << scope_id; << " in scope: " << scope_id;
...@@ -96,6 +99,15 @@ void StartInterceptor::Compute(const InterceptorMessage& msg) { ...@@ -96,6 +99,15 @@ void StartInterceptor::Compute(const InterceptorMessage& msg) {
<< " " << finish_count_; << " " << finish_count_;
finish_count_--; finish_count_--;
if (finish_count_ == 0) { if (finish_count_ == 0) {
auto end = std::chrono::system_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(
end - start_time_);
VLOG(3) << "Spent "
<< double(duration.count()) *
std::chrono::microseconds::period::num /
std::chrono::microseconds::period::den
<< " seconds.";
start_time_ = std::chrono::system_clock::now();
for (int64_t i = 0; i < batch_size_; ++i) { for (int64_t i = 0; i < batch_size_; ++i) {
for (auto& outs : out_buffs_) { for (auto& outs : out_buffs_) {
auto down_id = outs.first; auto down_id = outs.first;
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#pragma once #pragma once
#include <chrono>
#include <utility> #include <utility>
#include "paddle/fluid/distributed/fleet_executor/compute_interceptor.h" #include "paddle/fluid/distributed/fleet_executor/compute_interceptor.h"
...@@ -33,6 +34,8 @@ class StartInterceptor final : public ComputeInterceptor { ...@@ -33,6 +34,8 @@ class StartInterceptor final : public ComputeInterceptor {
int64_t batch_size_{0}; int64_t batch_size_{0};
int64_t finish_count_{0}; int64_t finish_count_{0};
int64_t step_{0}; int64_t step_{0};
std::chrono::time_point<std::chrono::system_clock> start_time_{
std::chrono::system_clock::now()};
}; };
} // namespace distributed } // namespace distributed
......
...@@ -1020,3 +1020,17 @@ PADDLE_DEFINE_EXPORTED_bool( ...@@ -1020,3 +1020,17 @@ PADDLE_DEFINE_EXPORTED_bool(
PADDLE_DEFINE_EXPORTED_string(jit_engine_type, PADDLE_DEFINE_EXPORTED_string(jit_engine_type,
"Predictor", "Predictor",
"Choose default funciton type in JitLayer."); "Choose default funciton type in JitLayer.");
/**
* Executor debug FLAG
* Name: FLAGS_fleetexecutor_debug_mode
* Since Version: 2.5
* Value Range: bool
* default=False
* Example:
* Note:
* FLAGS_fleetexecutor_debug_mode == 1, enter in debug mode
*/
PADDLE_DEFINE_EXPORTED_bool(fleetexecutor_debug_mode,
false,
"Enter in FleetExecutor debug mode.");
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册