diff --git a/paddle/fluid/distributed/fleet_executor/CMakeLists.txt b/paddle/fluid/distributed/fleet_executor/CMakeLists.txt index ff8ed811ee6f84a84a4d2e11b19a489a3c2062d6..4ab6b92bcb97eee01a20f0f240ac4f726928b06e 100755 --- a/paddle/fluid/distributed/fleet_executor/CMakeLists.txt +++ b/paddle/fluid/distributed/fleet_executor/CMakeLists.txt @@ -51,6 +51,7 @@ cc_library( op_registry executor_gc_helper gflags + flags glog ${BRPC_DEPS}) diff --git a/paddle/fluid/distributed/fleet_executor/carrier.cc b/paddle/fluid/distributed/fleet_executor/carrier.cc index 9b023e12a8893c3788a45b5fd87cdcbef2cb0b97..257b1c80cd66d2815b8604da0c6fe99cf1d8a589 100644 --- a/paddle/fluid/distributed/fleet_executor/carrier.cc +++ b/paddle/fluid/distributed/fleet_executor/carrier.cc @@ -17,6 +17,7 @@ #include #include +#include "gflags/gflags.h" #include "paddle/fluid/distributed/fleet_executor/global.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h" @@ -28,6 +29,8 @@ #include "paddle/fluid/framework/variable.h" #include "paddle/fluid/framework/variable_helper.h" +DECLARE_bool(fleetexecutor_debug_mode); + namespace paddle { namespace distributed { @@ -48,6 +51,72 @@ void Carrier::Init( thread_num_ = 1; thread_pool_.SetThreadNum(thread_num_); thread_pool_.Start(); + + if (FLAGS_fleetexecutor_debug_mode) { + test_thread_ = std::thread([this]() { loop_to_send_msg(); }); + cache_begin_ == std::chrono::steady_clock::now(); + } +} + +void Carrier::loop_to_send_msg() { + // VLOG(3) << "loop_send_msg loop now"; + while (1) { + while (1) { + int q_size = 0; + std::chrono::time_point c_begin; + { + std::lock_guard lock(running_mutex_); + q_size = messages_for_test_.size(); + c_begin = cache_begin_; + } + + auto now = std::chrono::steady_clock::now(); + auto delta = + std::chrono::duration_cast(now - c_begin) + .count(); + + if (q_size < 2 && delta < 5000) { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + continue; + } else { + VLOG(3) << "messages_for_test_ q_size:" << q_size << ", delta:" << delta + << ", will send all msg"; + break; + } + } + + { + std::lock_guard lock(running_mutex_); + while (!messages_for_test_.empty()) { + auto msg = messages_for_test_.back(); + messages_for_test_.pop_back(); + + int64_t src_id = msg.src_id(); + // TODO(liyurui): compatible solution, will be removed completely in the + // future + if (interceptor_id_to_rank_.find(src_id) == + interceptor_id_to_rank_.end() && + src_id == SOURCE_ID) { + src_id = msg.dst_id(); + } + int64_t dst_id = msg.dst_id(); + int64_t dst_rank = GetRank(dst_id); + + VLOG(3) << "Send a cached message from interceptor " << src_id + << " to interceptor " << dst_id + << ", which are in different ranks, scope_idx:" + << msg.scope_idx(); + + if (!GlobalVal::Get()->Send(dst_rank, msg)) { + LOG(FATAL) << "send msg error"; + } + std::this_thread::sleep_for(std::chrono::milliseconds(2)); + } + + cache_begin_ = std::chrono::steady_clock::now(); + } + } + VLOG(3) << "reset cache_begin_"; } void Carrier::Init( @@ -95,6 +164,11 @@ void Carrier::Init( thread_pool_.SetThreadNum(thread_num_); thread_pool_.Start(); + if (FLAGS_fleetexecutor_debug_mode) { + test_thread_ = std::thread([this]() { loop_to_send_msg(); }); + cache_begin_ == std::chrono::steady_clock::now(); + } + CreateInterceptors(); is_init_ = true; } @@ -230,12 +304,39 @@ bool Carrier::Send(const InterceptorMessage& msg) { VLOG(3) << "Send a message from interceptor " << src_id << " to interceptor " << dst_id << ", which are in the same ranks."; return EnqueueInterceptorMessage(msg); - } else { + } + if (!FLAGS_fleetexecutor_debug_mode) { + VLOG(3) << "Send a message from interceptor " << src_id + << " to interceptor " << dst_id + << ", which are in different ranks."; + return GlobalVal::Get()->Send(dst_rank, msg); + } + + if (msg.message_type() != DATA_IS_READY) { VLOG(3) << "Send a message from interceptor " << src_id << " to interceptor " << dst_id << ", which are in different ranks."; return GlobalVal::Get()->Send(dst_rank, msg); } + + { + VLOG(3) << "prepare executor debug"; + + std::unique_lock lock(running_mutex_); + if (messages_for_test_.empty()) { + cache_begin_ = std::chrono::steady_clock::now(); + // std::time_t now_c = + // std::chrono::system_clock::to_time_t(cache_begin_)); + VLOG(3) << "messages_for_test_ empty, reset cache_begin_"; + } + + VLOG(3) << "Cache message from interceptor " << src_id << " to interceptor " + << dst_id + << ", which are in different ranks, scope_idx:" << msg.scope_idx(); + messages_for_test_.emplace_back(msg); + } + + return true; } Interceptor* Carrier::SetInterceptor(int64_t interceptor_id, diff --git a/paddle/fluid/distributed/fleet_executor/carrier.h b/paddle/fluid/distributed/fleet_executor/carrier.h index 8e7fad3e892d87d735bc79069692238e6b7015f4..d99dcd19477f474c8d941dda0e3ab9806658df42 100644 --- a/paddle/fluid/distributed/fleet_executor/carrier.h +++ b/paddle/fluid/distributed/fleet_executor/carrier.h @@ -14,11 +14,14 @@ #pragma once +#include #include #include #include +#include #include #include +#include #include #include @@ -118,6 +121,12 @@ class Carrier final { int thread_num_; TaskLoopThreadPool thread_pool_; std::unordered_set interceptor_ids_; + + std::deque messages_for_test_; + std::thread test_thread_; + std::chrono::time_point cache_begin_; + + void loop_to_send_msg(); }; } // namespace distributed diff --git a/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc b/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc index 08b2cb4b6cb03f2db86525c786dbe79a1a18e281..8938be849eb5e72550e685e78b42e713d0697a39 100644 --- a/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc +++ b/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc @@ -29,74 +29,6 @@ namespace paddle { namespace distributed { -namespace { - -template -void SetVarResult(const std::string& name, - T value, - int64_t scope_id, - framework::Scope* scope, - const platform::Place& place, - const std::vector& dim_vec) { - auto* var = scope->FindVar(name); - auto* tensor = var->GetMutable(); - if (!var) { - VLOG(3) << "Create var and memory for var " << name; - var = scope->Var(name); - phi::DDim dims = phi::make_ddim(dim_vec); - tensor->Resize(dims); - tensor->mutable_data(dims, place); - } - - PADDLE_ENFORCE_EQ( - tensor->dims().size(), - 1, - platform::errors::OutOfRange("Only support transfer size 1 value.")); - PADDLE_ENFORCE_EQ( - tensor->dims().at(0), - 1, - platform::errors::OutOfRange("Only support transfer size 1 value.")); - if (platform::is_gpu_place(tensor->place())) { -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - phi::DenseTensor cpu_tensor; - auto dim = phi::make_ddim({1}); - cpu_tensor.mutable_data(dim, platform::CPUPlace()); - auto* cpu_tensor_ptr = cpu_tensor.data(); - cpu_tensor_ptr[0] = value; - framework::TensorCopySync(cpu_tensor, tensor->place(), tensor); -#endif - } else { - PADDLE_THROW(platform::errors::Unimplemented( - "Unsupport device for cond interceptor.")); - } -} - -template -T GetVarResult(const std::string& name, - int64_t scope_id, - framework::Scope* scope) { - auto* var = scope->FindVar(name); - PADDLE_ENFORCE(var, - platform::errors::NotFound( - "Variable %s not exists in scope %ld", name, scope_id)); - const auto& tensor = var->Get(); - T res; - if (platform::is_gpu_place(tensor.place())) { -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - phi::DenseTensor cpu_tensor; - framework::TensorCopySync(tensor, platform::CPUPlace(), &cpu_tensor); - res = cpu_tensor.data()[0]; -#endif - } else if (platform::is_cpu_place(tensor.place())) { - res = tensor.data()[0]; - } else { - PADDLE_THROW(platform::errors::Unimplemented( - "Unsupport device for cond interceptor.")); - } - return res; -} -} // namespace - ComputeInterceptor::ComputeInterceptor(int64_t interceptor_id, TaskNode* node) : Interceptor(interceptor_id, node) { PrepareDeps(); @@ -172,25 +104,47 @@ void ComputeInterceptor::DecreaseBuff(int64_t down_id) { } bool ComputeInterceptor::IsInputReady() { - for (int64_t i = 0; i < node_->max_run_times(); ++i) { + std::map scope_id_to_finish_flag; + if (!gen_step_to_scope_id_to_finish_flag_.empty()) { + scope_id_to_finish_flag = + gen_step_to_scope_id_to_finish_flag_.begin()->second; + VLOG(3) << "Is Input Ready in gen step " + << gen_step_to_scope_id_to_finish_flag_.begin()->first; + } + int64_t num_micro_step = + (num_micro_step_ == -1 ? node_->max_run_times() : num_micro_step_); + int64_t start_micro_step = (start_micro_step_ == -1 ? 0 : start_micro_step_); + for (int64_t i = start_micro_step; i < start_micro_step + num_micro_step; + ++i) { bool flag = true; for (auto& ins : in_readys_) { auto ready_size_map = ins.second.second; flag = flag && (ready_size_map.at(i) != 0); } if (flag) { - for (auto iter : scope_id_to_finish_flag_) { - if (iter.first == i) { - break; - } else if (!iter.second) { - VLOG(3) << "The previous scope is not ready, waiting for the " - "previous scope " - << iter.first; - return false; + if (scope_id_to_finish_flag.empty()) { + cur_scope_id_ = i; + return true; + } else if (scope_id_to_finish_flag.find(i) != + scope_id_to_finish_flag.end()) { + for (auto iter : scope_id_to_finish_flag) { + if (iter.first == i) { + break; + } else if (!iter.second) { + VLOG(3) << "The previous scope is not ready, waiting for the " + "previous scope " + << iter.first << " in gen_step " + << gen_step_to_scope_id_to_finish_flag_.begin()->first; + return false; + } } + cur_scope_id_ = i; + return true; + } else { + VLOG(3) << "Interceptor " << GetInterceptorId() << " in scope " << i + << " is larger than gen_step " + << gen_step_to_scope_id_to_finish_flag_.begin()->first; } - cur_scope_id_ = i; - return true; } else { VLOG(3) << "Interceptor " << GetInterceptorId() << " in scope " << i << "'s upstreams aren't all ready."; @@ -217,6 +171,16 @@ bool ComputeInterceptor::CanWriteOutput() { } void ComputeInterceptor::SendDataReadyToDownStream() { + bool need_send_vars = !(node_->vars_to_dtype().empty()); + InterceptorMessage ready_msg; + ready_msg.set_start_micro_step(start_micro_step_); + ready_msg.set_num_micro_step(num_micro_step_); + if (need_send_vars) { + ready_msg = PrepareVarsMsg(); + } else { + ready_msg.set_message_type(DATA_IS_READY); + ready_msg.set_scope_idx(cur_scope_id_); + } for (auto& outs : out_buffs_) { auto down_id = outs.first; auto max_buff_size = outs.second.first; @@ -235,17 +199,12 @@ void ComputeInterceptor::SendDataReadyToDownStream() { } outs.second.second = used_size; - bool need_send_vars = !(node_->vars_to_dtype().empty()); if (need_send_vars) { - InterceptorMessage ready_msg = PrepareVarsMsg(); VLOG(3) << "ComputeInterceptor " << interceptor_id_ << " Send data_with_vars msg to " << down_id << " in scope: " << cur_scope_id_; Send(down_id, ready_msg); } else { - InterceptorMessage ready_msg; - ready_msg.set_message_type(DATA_IS_READY); - ready_msg.set_scope_idx(cur_scope_id_); VLOG(3) << "ComputeInterceptor " << interceptor_id_ << " Send data_is_ready msg to " << down_id << " in scope: " << cur_scope_id_; @@ -339,13 +298,21 @@ void ComputeInterceptor::Run() { RunOps(); - if (!scope_id_to_finish_flag_.empty()) { + if (!gen_step_to_scope_id_to_finish_flag_.empty()) { + auto iter = gen_step_to_scope_id_to_finish_flag_.begin(); + VLOG(3) << "id=" << GetInterceptorId() + << " ComputeInterceptor running in scope " << cur_scope_id_ + << " with gen_step " << iter->first; + auto& scope_id_to_finish_flag = iter->second; PADDLE_ENFORCE_NE( - scope_id_to_finish_flag_.find(cur_scope_id_), - scope_id_to_finish_flag_.end(), + scope_id_to_finish_flag.find(cur_scope_id_), + scope_id_to_finish_flag.end(), platform::errors::NotFound( "Can not find scope %ld in scope_id_to_finish", cur_scope_id_)); - scope_id_to_finish_flag_.erase(cur_scope_id_); + scope_id_to_finish_flag.erase(cur_scope_id_); + if (scope_id_to_finish_flag.empty()) { + gen_step_to_scope_id_to_finish_flag_.erase(iter); + } } // send to downstream and increase buff used @@ -385,6 +352,8 @@ void ComputeInterceptor::Compute(const InterceptorMessage& msg) { VLOG(3) << "Compute interceptor " << interceptor_id_ << " receive data_is_ready " << msg.src_id() << " " << msg.scope_idx() << " "; + start_micro_step_ = msg.start_micro_step(); + num_micro_step_ = msg.num_micro_step(); IncreaseReady(msg.src_id(), msg.scope_idx()); Run(); } else if (msg.message_type() == DATA_IS_USELESS) { @@ -402,10 +371,14 @@ void ComputeInterceptor::Compute(const InterceptorMessage& msg) { Run(); } else if (msg.message_type() == START_LOOP) { VLOG(3) << "Compute interceptor " << interceptor_id_ - << " receive start_loop " << msg.src_id() << " " << msg.scope_idx() - << " "; + << " receive start_loop " << msg.src_id() << " in scope " + << msg.scope_idx() << " with gen_step " << msg.gen_step(); + start_micro_step_ = msg.start_micro_step(); + num_micro_step_ = msg.num_micro_step(); IncreaseReady(msg.src_id(), msg.scope_idx()); - scope_id_to_finish_flag_.emplace(msg.scope_idx(), false); + int64_t gen_step = msg.gen_step(); + gen_step_to_scope_id_to_finish_flag_[gen_step].emplace(msg.scope_idx(), + false); Run(); } } diff --git a/paddle/fluid/distributed/fleet_executor/compute_interceptor.h b/paddle/fluid/distributed/fleet_executor/compute_interceptor.h index 07e0dd5b0255f5b29b9320b589b4968b96be60fa..26205d5ac82644b2bc77f5d44670acd79f5baec7 100644 --- a/paddle/fluid/distributed/fleet_executor/compute_interceptor.h +++ b/paddle/fluid/distributed/fleet_executor/compute_interceptor.h @@ -52,7 +52,10 @@ class ComputeInterceptor : public Interceptor { bool IsInputReady(); bool CanWriteOutput(); - std::map scope_id_to_finish_flag_; + std::map> + gen_step_to_scope_id_to_finish_flag_; + int64_t start_micro_step_{-1}; + int64_t num_micro_step_{-1}; }; } // namespace distributed diff --git a/paddle/fluid/distributed/fleet_executor/cond_interceptor.cc b/paddle/fluid/distributed/fleet_executor/cond_interceptor.cc index d3412a2443fda941c0ede9f18105ace1b62bcc81..846f3d722e1faa24b3b5323ef9397427694e2b0b 100644 --- a/paddle/fluid/distributed/fleet_executor/cond_interceptor.cc +++ b/paddle/fluid/distributed/fleet_executor/cond_interceptor.cc @@ -90,13 +90,18 @@ void CondInterceptor::SendDataReady(int64_t down_id) { InterceptorMessage ready_msg; ready_msg.set_message_type(DATA_IS_READY); ready_msg.set_scope_idx(cur_scope_id_); + ready_msg.set_start_micro_step(start_micro_step_); + ready_msg.set_num_micro_step(num_micro_step_); Send(down_id, ready_msg); } -void CondInterceptor::SendStartLoop(int64_t down_id) { +void CondInterceptor::SendStartLoop(int64_t down_id, int64_t gen_step) { InterceptorMessage ready_msg; ready_msg.set_message_type(START_LOOP); ready_msg.set_scope_idx(cur_scope_id_); + ready_msg.set_gen_step(gen_step); + ready_msg.set_start_micro_step(start_micro_step_); + ready_msg.set_num_micro_step(num_micro_step_); Send(down_id, ready_msg); } @@ -107,43 +112,36 @@ void CondInterceptor::ReplyDataIsUseless(int64_t up_id) { Send(up_id, ready_msg); } -void CondInterceptor::Compute() { +void CondInterceptor::Compute(int64_t gen_step) { bool cond = GetCondResult(); VLOG(3) << "Cond interceptor get condition var " << node_->cond_var() << " with value " << cond; if (cond) { - VLOG(3) << "Loop again in scope " << cur_scope_id_; + VLOG(3) << "Loop again in scope " << cur_scope_id_ << " gen_step " + << gen_step; for (auto& down_id : normal_out_id_) { - SendStartLoop(down_id); + SendStartLoop(down_id, gen_step); } - ++num_of_scopes_; } else { - VLOG(3) << "Finish loop in scope " << cur_scope_id_; + PADDLE_ENFORCE_NE(scope_id_to_gen_step_.find(cur_scope_id_), + scope_id_to_gen_step_.end(), + platform::errors::InvalidArgument( + "Can not find scope id %ld in scope_id_to_gen_step", + cur_scope_id_)); + VLOG(3) << "Finish loop in scope " << cur_scope_id_ << " with " + << scope_id_to_gen_step_.at(cur_scope_id_) << " generation steps."; + scope_id_to_gen_step_.erase(cur_scope_id_); SendDataReady(stop_loop_id_); } } void CondInterceptor::Run(const InterceptorMessage& msg) { - if (msg.message_type() == DATA_IS_READY || - msg.message_type() == DATA_WITH_VARS) { - if (msg.src_id() == loop_id_) { - --num_of_scopes_; - VLOG(3) << "Receving loop again message from " << msg.src_id() - << " waiting other " << num_of_scopes_ << " scopes ready"; - ready_scope_id_.emplace_back(msg.scope_idx()); - if (num_of_scopes_ == 0) { - std::sort(ready_scope_id_.begin(), ready_scope_id_.end()); - for (auto scope_id : ready_scope_id_) { - VLOG(3) << "Start a new loop in scope " << scope_id; - cur_scope_id_ = scope_id; - Compute(); - } - ready_scope_id_.clear(); - } - } else { - cur_scope_id_ = msg.scope_idx(); - Compute(); - } + if (msg.message_type() == DATA_IS_READY) { + cur_scope_id_ = msg.scope_idx(); + start_micro_step_ = msg.start_micro_step(); + num_micro_step_ = msg.num_micro_step(); + scope_id_to_gen_step_.emplace(cur_scope_id_, 0); + Compute(/*gen_step=*/0); } else if (msg.message_type() == DATA_IS_USELESS) { if (node_->id_to_dep_type().at(msg.src_id()) == DependType::STOP_LOOP) { for (auto& up_id : normal_in_id_) { @@ -158,6 +156,53 @@ void CondInterceptor::Run(const InterceptorMessage& msg) { gc_.get()); } } + } else if (msg.message_type() == DATA_WITH_VARS) { + int64_t scope_id = msg.scope_idx(); + PADDLE_ENFORCE_NE( + scope_id_to_gen_step_.find(scope_id), + scope_id_to_gen_step_.end(), + platform::errors::InvalidArgument( + "Can not find scope id %ld in scope_id_to_gen_step", scope_id)); + // Keep the message in order with scope_id + // message with scope 3 never send before scope 1. + int64_t gen_step = scope_id_to_gen_step_.at(scope_id) + 1; + bool wait_prev_scope = false; + // If the previous scope gen_step less than cur scope + // means: the previous scope doesn't finish last step generation, should + // wait. + auto iter = scope_id_to_gen_step_.begin(); + while (iter != scope_id_to_gen_step_.end()) { + if (iter->first == scope_id) { + break; + } + if (iter->second < gen_step) { + wait_prev_scope = true; + break; + } + ++iter; + } + scope_id_to_gen_step_.at(scope_id) = gen_step; + if (!wait_prev_scope) { + // Start send message to all scopes gen_step equal to cur_scope + std::vector ready_scope_ids; + while (iter != scope_id_to_gen_step_.end()) { + if (iter->second == gen_step) { + ready_scope_ids.emplace_back(iter->first); + } else if (iter->second > gen_step) { + PADDLE_THROW( + platform::errors::Fatal("Some error may occur. Scope %ld's " + "gen_step is much larger than previous.", + iter->first)); + } else { + break; + } + ++iter; + } + for (auto& scope_id : ready_scope_ids) { + cur_scope_id_ = scope_id; + Compute(gen_step); + } + } } } diff --git a/paddle/fluid/distributed/fleet_executor/cond_interceptor.h b/paddle/fluid/distributed/fleet_executor/cond_interceptor.h index a69468b28b45dc92c8f0bf0c9f5e00f1f7d77d94..4371457606d13745cef339aefa836c5e90b751e2 100644 --- a/paddle/fluid/distributed/fleet_executor/cond_interceptor.h +++ b/paddle/fluid/distributed/fleet_executor/cond_interceptor.h @@ -35,10 +35,10 @@ class CondInterceptor final : public Interceptor { private: void PrepareDeps(); void Run(const InterceptorMessage& msg); - void Compute(); + void Compute(int64_t gen_step); bool GetCondResult(); void SendDataReady(int64_t down_id); - void SendStartLoop(int64_t down_id); + void SendStartLoop(int64_t down_id, int64_t gen_step); void ReplyDataIsUseless(int64_t up_id); int64_t cur_scope_id_; @@ -47,8 +47,9 @@ class CondInterceptor final : public Interceptor { std::set normal_out_id_; int64_t stop_loop_id_; int64_t loop_id_; - int64_t num_of_scopes_{0}; - std::vector ready_scope_id_; + std::map scope_id_to_gen_step_; + int64_t start_micro_step_; + int64_t num_micro_step_; }; } // namespace distributed diff --git a/paddle/fluid/distributed/fleet_executor/interceptor_message.proto b/paddle/fluid/distributed/fleet_executor/interceptor_message.proto index 4db5a72d897c32a9146f3413e15b5f565da779cd..86728ba3f93da8ba2b1eafd728802dde2e3638c8 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor_message.proto +++ b/paddle/fluid/distributed/fleet_executor/interceptor_message.proto @@ -48,6 +48,9 @@ message InterceptorMessage { optional bool ctrl_message = 4 [ default = false ]; optional int64 scope_idx = 5 [ default = 0 ]; repeated VarList vars_list = 6; + optional int64 gen_step = 7 [ default = -1 ]; + optional int64 start_micro_step = 8 [ default = -1 ]; + optional int64 num_micro_step = 9 [ default = -1 ]; } message InterceptorResponse { optional bool rst = 1 [ default = false ]; } diff --git a/paddle/fluid/distributed/fleet_executor/start_interceptor.cc b/paddle/fluid/distributed/fleet_executor/start_interceptor.cc index b9ce4fabed4ad639f9a36ef81b785ff401d1ae6a..830f619ed3c00c32af889e5b75ace987b24566b1 100644 --- a/paddle/fluid/distributed/fleet_executor/start_interceptor.cc +++ b/paddle/fluid/distributed/fleet_executor/start_interceptor.cc @@ -67,13 +67,16 @@ void StartInterceptor::SendDataReadyToDownStream() { outs.second.second = used_size; } if (finish_count_ == batch_size_) { + int64_t start_micro_step = step_ % node_->max_run_times(); for (int64_t i = 0; i < batch_size_; ++i) { int64_t scope_id = step_ % node_->max_run_times(); + InterceptorMessage ready_msg; + ready_msg.set_message_type(DATA_IS_READY); + ready_msg.set_scope_idx(scope_id); + ready_msg.set_start_micro_step(start_micro_step); + ready_msg.set_num_micro_step(batch_size_); for (auto& outs : out_buffs_) { auto down_id = outs.first; - InterceptorMessage ready_msg; - ready_msg.set_message_type(DATA_IS_READY); - ready_msg.set_scope_idx(scope_id); VLOG(3) << "StartInterceptor " << interceptor_id_ << " Send data_is_ready msg to " << down_id << " in scope: " << scope_id; @@ -96,6 +99,15 @@ void StartInterceptor::Compute(const InterceptorMessage& msg) { << " " << finish_count_; finish_count_--; if (finish_count_ == 0) { + auto end = std::chrono::system_clock::now(); + auto duration = std::chrono::duration_cast( + end - start_time_); + VLOG(3) << "Spent " + << double(duration.count()) * + std::chrono::microseconds::period::num / + std::chrono::microseconds::period::den + << " seconds."; + start_time_ = std::chrono::system_clock::now(); for (int64_t i = 0; i < batch_size_; ++i) { for (auto& outs : out_buffs_) { auto down_id = outs.first; diff --git a/paddle/fluid/distributed/fleet_executor/start_interceptor.h b/paddle/fluid/distributed/fleet_executor/start_interceptor.h index f082c48922bdfa952e2a55ed2f18fb59658bd574..bb709beb070f4da14d9b6ffe616b31054950fc79 100644 --- a/paddle/fluid/distributed/fleet_executor/start_interceptor.h +++ b/paddle/fluid/distributed/fleet_executor/start_interceptor.h @@ -14,6 +14,7 @@ #pragma once +#include #include #include "paddle/fluid/distributed/fleet_executor/compute_interceptor.h" @@ -33,6 +34,8 @@ class StartInterceptor final : public ComputeInterceptor { int64_t batch_size_{0}; int64_t finish_count_{0}; int64_t step_{0}; + std::chrono::time_point start_time_{ + std::chrono::system_clock::now()}; }; } // namespace distributed diff --git a/paddle/fluid/platform/flags.cc b/paddle/fluid/platform/flags.cc index 518aabbb09ead8dcbdf760ae372e71d69cc53b28..faea60903b35fad1c519a116ca76afda0cdc0b81 100644 --- a/paddle/fluid/platform/flags.cc +++ b/paddle/fluid/platform/flags.cc @@ -1020,3 +1020,17 @@ PADDLE_DEFINE_EXPORTED_bool( PADDLE_DEFINE_EXPORTED_string(jit_engine_type, "Predictor", "Choose default funciton type in JitLayer."); + +/** + * Executor debug FLAG + * Name: FLAGS_fleetexecutor_debug_mode + * Since Version: 2.5 + * Value Range: bool + * default=False + * Example: + * Note: + * FLAGS_fleetexecutor_debug_mode == 1, enter in debug mode + */ +PADDLE_DEFINE_EXPORTED_bool(fleetexecutor_debug_mode, + false, + "Enter in FleetExecutor debug mode.");