diff --git a/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc b/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc index 084e91c11caa7372fc079999aed279871cd96792..41c77c1ead045fd79eb43d7d8bd7e4472ebf58c1 100644 --- a/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc +++ b/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc @@ -46,6 +46,8 @@ void ComputeInterceptor::PrepareDeps() { "Source ComputeInterceptor must run at least one " "times, but now max_run_times=%ld", node_->max_run_times())); + in_readys_.emplace(-1, + std::make_pair(std::numeric_limits::max(), 0)); } // If there is no downstream or every downstream is in different rank, @@ -55,14 +57,17 @@ void ComputeInterceptor::PrepareDeps() { } void ComputeInterceptor::IncreaseReady(int64_t up_id) { - // source node has no upstream, data_is_ready is send by carrier or others - if (is_source_ && up_id == -1) return; - auto it = in_readys_.find(up_id); PADDLE_ENFORCE_NE(it, in_readys_.end(), platform::errors::NotFound( "Cannot find upstream=%lld in in_readys.", up_id)); + // source node has no upstream, data_is_ready is send by carrier or others + if (is_source_ && up_id == -1) { + it->second.second = GetTaskNode()->max_run_times(); + return; + } + auto max_ready_size = it->second.first; auto ready_size = it->second.second; ready_size += 1; @@ -93,7 +98,11 @@ bool ComputeInterceptor::IsInputReady() { for (auto& ins : in_readys_) { auto ready_size = ins.second.second; // not ready, return false - if (ready_size == 0) return false; + if (ready_size == 0) { + VLOG(3) << "Interceptor " << GetInterceptorId() + << "'s upstreams aren't all ready."; + return false; + } } return true; } @@ -103,14 +112,23 @@ bool ComputeInterceptor::CanWriteOutput() { auto max_buffer_size = outs.second.first; auto used_size = outs.second.second; // full, return false - if (used_size == max_buffer_size) return false; + if (used_size == max_buffer_size) { + VLOG(3) << "Interceptor " << GetInterceptorId() + << "'s out buffer is full."; + return false; + } } return true; } // only source node need reset bool ComputeInterceptor::ShouldReset() { - return is_source_ && (step_ == node_->max_run_times()); + if (is_source_ && step_ == node_->max_run_times()) { + VLOG(3) << "Interceptor " << GetInterceptorId() + << " should reset for step: " << step_ << "."; + return true; + } + return false; } void ComputeInterceptor::SendDataReadyToDownStream() { @@ -130,7 +148,8 @@ void ComputeInterceptor::SendDataReadyToDownStream() { InterceptorMessage ready_msg; ready_msg.set_message_type(DATA_IS_READY); VLOG(3) << "ComputeInterceptor " << interceptor_id_ - << " Send data_is_ready msg to " << down_id; + << " Send data_is_ready msg to " << down_id + << " for step: " << step_; Send(down_id, ready_msg); } } @@ -147,23 +166,43 @@ void ComputeInterceptor::ReplyCompletedToUpStream() { ready_size)); ins.second.second = ready_size; + VLOG(3) << "ComputeInterceptor " << interceptor_id_ + << " Reply data_is_useless msg to " << up_id + << " for step: " << step_; + if (up_id == -1) return; + InterceptorMessage reply_msg; reply_msg.set_message_type(DATE_IS_USELESS); - VLOG(3) << "ComputeInterceptor " << interceptor_id_ - << " Reply data_is_useless msg to " << up_id; Send(up_id, reply_msg); } } void ComputeInterceptor::RunOps() { VLOG(3) << "ComputeInterceptor " << interceptor_id_ << " running ops for the " - << step_ << " time."; + << step_ + 1 << " time."; for (auto op : node_->ops()) { op->Run(*microbatch_scopes_[step_ % node_->max_run_times()], place_); } } void ComputeInterceptor::Run() { + // If there is no limit, source interceptor can be executed + // an unlimited number of times. + // Now source node can only run max_run_times. + if (ShouldReset()) { + for (auto& out_buff : out_buffs_) { + // buffer is using + if (out_buff.second.second != 0) { + VLOG(3) << "Interceptor " << GetInterceptorId() + << " out buffer for downstream: " << out_buff.first + << "'s counter is: " << out_buff.second.second + << ". Cannot be reset."; + return; + } + } + step_ = 0; // reset + } + while (IsInputReady() && CanWriteOutput() && !ShouldReset()) { VLOG(3) << "id=" << GetInterceptorId() << " ComputeInterceptor running"; @@ -181,18 +220,6 @@ void ComputeInterceptor::Run() { StopCarrier(); } } - - // If there is no limit, source interceptor can be executed - // an unlimited number of times. - // Now source node can only run max_run_times. - if (ShouldReset()) { - for (auto& out_buff : out_buffs_) { - // buffer is using - if (out_buff.second.second != 0) return; - } - step_ = 0; // reset - return; - } } void ComputeInterceptor::ReceivedStop(int64_t up_id) { diff --git a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc index 3479157de5c454325c50bdc0dfea2bedf485d58b..3a823674d842c5a8e76d10d36b0e44dbeef90148 100644 --- a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc +++ b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc @@ -109,6 +109,15 @@ void FleetExecutor::Run() { message_bus_instance.IsInit(), true, platform::errors::Unavailable("MessageBus has not been init yet.")); carrier_instance.Start(); + for (auto* micro_scop : microbatch_scopes_) { + // By default, we should delete all kid scopes after run executor because + // some operators may create local scope when running, such as while_op. + // But when while_op also create a local executor to run it's sub block, + // the sub scopes it created should not be dropped immediately, because + // while_grad_op will use some variables created during while_op run, so + // we need to keep the kids and wait for the outer executor to drop them. + micro_scop->DropKids(); + } } void FleetExecutor::CopyParameters(int microbatch_id, diff --git a/paddle/fluid/distributed/fleet_executor/message_bus.cc b/paddle/fluid/distributed/fleet_executor/message_bus.cc index 688a6f3a3882183c324d2b05062bf97b0b76602b..f087de69fa96b2861d916e28fa4f4a292f791401 100644 --- a/paddle/fluid/distributed/fleet_executor/message_bus.cc +++ b/paddle/fluid/distributed/fleet_executor/message_bus.cc @@ -14,6 +14,7 @@ #include #include +#include #include #include "paddle/fluid/distributed/fleet_executor/carrier.h" @@ -56,11 +57,11 @@ void MessageBus::Init( bool MessageBus::IsInit() const { return is_init_; } MessageBus::~MessageBus() { - VLOG(3) << "Message bus releases resource."; // NOTE: fleet_executor inits carrier before message bus, // therefore the message bus's destructor will be called first Carrier& carrier = Carrier::Instance(); carrier.Release(); + VLOG(3) << "Message bus releases resource."; #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ !defined(PADDLE_WITH_ASCEND_CL) server_.Stop(1000); @@ -90,6 +91,8 @@ bool MessageBus::Send(const InterceptorMessage& interceptor_message) { << retry_time << " times retries."; return true; } + VLOG(3) << "Message bus sends failed, retry after 1 seconds."; + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); } VLOG(3) << "Message bus sends inter rank fail after 10 times retries."; return false; @@ -121,16 +124,40 @@ void MessageBus::ListenPort() { brpc::ServerOptions options; options.idle_timeout_sec = -1; int retry_times = 0; - int interval = 1000; + int interval = 100; while (server_.Start(ip_for_brpc, &options) != 0) { ++retry_times; LOG(INFO) << "Message bus is retring for starting brpc for " << retry_times << " times. And will retry after " << interval / 1000 << " seconds."; std::this_thread::sleep_for(std::chrono::milliseconds(interval)); - interval += 2000; + interval += 500; } LOG(INFO) << "Message bus's listen port thread starts successful."; + + std::set visit; + InterceptorMessage tmp_msg; + tmp_msg.set_ctrl_message(true); + for (auto pair : interceptor_id_to_rank_) { + if (rank_to_addr_.at(pair.second) == addr_) { + tmp_msg.set_src_id(pair.first); + } + } + for (auto pair : interceptor_id_to_rank_) { + int64_t rank = pair.second; + if (rank_to_addr_.at(rank) == addr_) { + continue; + } + tmp_msg.set_dst_id(pair.first); + if (visit.find(rank) == visit.end()) { + VLOG(3) << "Message bus is testing connection for rank: " << rank << "."; + visit.insert(rank); + while (!Send(tmp_msg)) { + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + } + VLOG(3) << "Message bus has connected to rank: " << rank << "."; + } + } #else LOG(WARNING) << "Fleet executor's ListenPort() is a fake function when Paddle is "