From 74ca89efe3ddeab6fbb927c85e5e7612ce110d00 Mon Sep 17 00:00:00 2001 From: Yuang Liu Date: Mon, 29 Nov 2021 13:45:59 +0800 Subject: [PATCH] [fleet_executor] Hold the carrier while running for one micro step. (#37605) --- .../distributed/fleet_executor/carrier.cc | 4 ++++ .../distributed/fleet_executor/carrier.h | 6 ++++++ .../fleet_executor/compute_interceptor.cc | 11 +++++----- .../fleet_executor/compute_interceptor.h | 1 - .../distributed/fleet_executor/interceptor.cc | 20 +++++++++++-------- .../distributed/fleet_executor/interceptor.h | 3 +-- .../test/compute_interceptor_test.cc | 4 ++++ .../test/interceptor_ping_pong_test.cc | 4 ++++ .../interceptor_ping_pong_with_brpc_test.cc | 4 ++++ .../tests/unittests/test_fleet_executor.py | 5 +---- 10 files changed, 42 insertions(+), 20 deletions(-) diff --git a/paddle/fluid/distributed/fleet_executor/carrier.cc b/paddle/fluid/distributed/fleet_executor/carrier.cc index 939e987b397..6b47095dbe1 100644 --- a/paddle/fluid/distributed/fleet_executor/carrier.cc +++ b/paddle/fluid/distributed/fleet_executor/carrier.cc @@ -96,8 +96,12 @@ void Carrier::Start() { "Message bus has not been initialized.")); message_bus_instance.Send(tmp_msg); } + std::unique_lock lock(running_mutex_); + cond_var_.wait(lock); } +std::condition_variable& Carrier::GetCondVar() { return cond_var_; } + bool Carrier::IsInit() const { return is_init_; } Interceptor* Carrier::SetInterceptor(int64_t interceptor_id, diff --git a/paddle/fluid/distributed/fleet_executor/carrier.h b/paddle/fluid/distributed/fleet_executor/carrier.h index 980847c716b..6ad5aee128a 100644 --- a/paddle/fluid/distributed/fleet_executor/carrier.h +++ b/paddle/fluid/distributed/fleet_executor/carrier.h @@ -14,6 +14,7 @@ #pragma once +#include #include #include #include @@ -57,6 +58,8 @@ class Carrier final { void SetCreatingFlag(bool flag); + std::condition_variable& GetCondVar(); + void Start(); bool IsInit() const; @@ -83,6 +86,9 @@ class Carrier final { bool creating_interceptors_{true}; std::mutex creating_flag_mutex_; bool is_init_{false}; + + std::mutex running_mutex_; + std::condition_variable cond_var_; }; } // namespace distributed diff --git a/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc b/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc index 09b86ba18e3..3008c830699 100644 --- a/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc +++ b/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc @@ -221,12 +221,11 @@ void ComputeInterceptor::TryStop() { Send(down_id, stop); } stop_ = true; -} - -void ComputeInterceptor::HandleStop(const InterceptorMessage& msg) { - ReceivedStop(msg.src_id()); - TryStop(); + if (out_buffs_.size() == 0) { + // TODO(fleet executor dev) need a better place to notify + StopCarrier(); + } } void ComputeInterceptor::Compute(const InterceptorMessage& msg) { @@ -236,6 +235,8 @@ void ComputeInterceptor::Compute(const InterceptorMessage& msg) { } else if (msg.message_type() == DATE_IS_USELESS) { DecreaseBuff(msg.src_id()); Run(); + } else if (msg.message_type() == STOP) { + ReceivedStop(msg.src_id()); } TryStop(); diff --git a/paddle/fluid/distributed/fleet_executor/compute_interceptor.h b/paddle/fluid/distributed/fleet_executor/compute_interceptor.h index fd540e81afa..97e6da2f00e 100644 --- a/paddle/fluid/distributed/fleet_executor/compute_interceptor.h +++ b/paddle/fluid/distributed/fleet_executor/compute_interceptor.h @@ -39,7 +39,6 @@ class ComputeInterceptor : public Interceptor { void Run(); void Compute(const InterceptorMessage& msg); - void HandleStop(const InterceptorMessage& msg) override; void ReceivedStop(int64_t up_id); void TryStop(); diff --git a/paddle/fluid/distributed/fleet_executor/interceptor.cc b/paddle/fluid/distributed/fleet_executor/interceptor.cc index 916923ce590..40429502825 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor.cc +++ b/paddle/fluid/distributed/fleet_executor/interceptor.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/distributed/fleet_executor/interceptor.h" +#include "paddle/fluid/distributed/fleet_executor/carrier.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/distributed/fleet_executor/task_node.h" @@ -50,10 +51,20 @@ void Interceptor::Handle(const InterceptorMessage& msg) { InterceptorMessage msg; msg.set_message_type(STOP); Send(interceptor_id_, msg); + } else if (msg.message_type() == STOP) { + stop_ = true; + StopCarrier(); } } } +void Interceptor::StopCarrier() { + Carrier& carrier_instance = Carrier::Instance(); + std::condition_variable& cond_var = carrier_instance.GetCondVar(); + // probably double notify, but ok for ut + cond_var.notify_all(); +} + std::condition_variable& Interceptor::GetCondVar() { // get the conditional var return cond_var_; @@ -80,9 +91,6 @@ bool Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) { return MessageBus::Instance().Send(msg); } -// maybe need a better method for interceptor base -void Interceptor::HandleStop(const InterceptorMessage& msg) { stop_ = true; } - void Interceptor::PoolTheMailbox() { // pool the local mailbox, parse the Message for (;;) { @@ -101,11 +109,7 @@ void Interceptor::PoolTheMailbox() { << " from interceptor " << interceptor_message.src_id() << " with message: " << message_type << "."; - if (message_type == STOP) { - HandleStop(interceptor_message); - } else { - Handle(interceptor_message); - } + Handle(interceptor_message); if (stop_) { // break the pooling thread diff --git a/paddle/fluid/distributed/fleet_executor/interceptor.h b/paddle/fluid/distributed/fleet_executor/interceptor.h index 052c0cc55d5..ef1ffb1a53b 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor.h +++ b/paddle/fluid/distributed/fleet_executor/interceptor.h @@ -52,8 +52,6 @@ class Interceptor { // register interceptor handle void RegisterMsgHandle(MsgHandle handle); - virtual void HandleStop(const InterceptorMessage& msg); - void Handle(const InterceptorMessage& msg); // return the interceptor id @@ -89,6 +87,7 @@ class Interceptor { // for stop bool stop_{false}; + void StopCarrier(); // for runtime platform::Place place_; diff --git a/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc b/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc index 5b85abb4258..3cfd3073c8c 100644 --- a/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc @@ -33,6 +33,10 @@ class StartInterceptor : public Interceptor { } void NOP(const InterceptorMessage& msg) { + if (msg.message_type() == STOP) { + stop_ = true; + return; + } std::cout << GetInterceptorId() << " recv msg from " << msg.src_id() << std::endl; ++count_; diff --git a/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc b/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc index 52df12395d5..c68688bfea6 100644 --- a/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc @@ -32,6 +32,10 @@ class PingPongInterceptor : public Interceptor { } void PingPong(const InterceptorMessage& msg) { + if (msg.message_type() == STOP) { + stop_ = true; + return; + } std::cout << GetInterceptorId() << " recv msg, count=" << count_ << std::endl; ++count_; diff --git a/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc b/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc index dbbcd647292..233c4d92c9f 100644 --- a/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc @@ -34,6 +34,10 @@ class PingPongInterceptor : public Interceptor { } void PingPong(const InterceptorMessage& msg) { + if (msg.message_type() == STOP) { + stop_ = true; + return; + } std::cout << GetInterceptorId() << " recv msg, count=" << count_ << std::endl; ++count_; diff --git a/python/paddle/fluid/tests/unittests/test_fleet_executor.py b/python/paddle/fluid/tests/unittests/test_fleet_executor.py index 6ba9e2d9e21..76d4a546746 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_executor.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_executor.py @@ -32,11 +32,8 @@ class TestFleetExecutor(unittest.TestCase): exe.run(empty_program, feed={'x': [1]}) def test_executor_on_single_device(self): - places = [fluid.CPUPlace()] if fluid.is_compiled_with_cuda(): - places.append(fluid.CUDAPlace(0)) - for place in places: - self.run_fleet_executor(place) + self.run_fleet_executor(fluid.CUDAPlace(0)) if __name__ == "__main__": -- GitLab