diff --git a/paddle/fluid/distributed/fleet_executor/carrier.cc b/paddle/fluid/distributed/fleet_executor/carrier.cc index 939e987b397a37b9de2ecc33056819994e5e2eb2..6b47095dbe1b438e840c3816ce9ee46d73986510 100644 --- a/paddle/fluid/distributed/fleet_executor/carrier.cc +++ b/paddle/fluid/distributed/fleet_executor/carrier.cc @@ -96,8 +96,12 @@ void Carrier::Start() { "Message bus has not been initialized.")); message_bus_instance.Send(tmp_msg); } + std::unique_lock lock(running_mutex_); + cond_var_.wait(lock); } +std::condition_variable& Carrier::GetCondVar() { return cond_var_; } + bool Carrier::IsInit() const { return is_init_; } Interceptor* Carrier::SetInterceptor(int64_t interceptor_id, diff --git a/paddle/fluid/distributed/fleet_executor/carrier.h b/paddle/fluid/distributed/fleet_executor/carrier.h index 980847c716b79f7e5be91cb1751699e31ba3c26a..6ad5aee128a83fbb0458b00c33bbc31b57bccd76 100644 --- a/paddle/fluid/distributed/fleet_executor/carrier.h +++ b/paddle/fluid/distributed/fleet_executor/carrier.h @@ -14,6 +14,7 @@ #pragma once +#include #include #include #include @@ -57,6 +58,8 @@ class Carrier final { void SetCreatingFlag(bool flag); + std::condition_variable& GetCondVar(); + void Start(); bool IsInit() const; @@ -83,6 +86,9 @@ class Carrier final { bool creating_interceptors_{true}; std::mutex creating_flag_mutex_; bool is_init_{false}; + + std::mutex running_mutex_; + std::condition_variable cond_var_; }; } // namespace distributed diff --git a/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc b/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc index 09b86ba18e34be69a79a8cd2a25d543b5643793a..3008c83069942c2b7bf8cf3759de7d1ec5dde2b0 100644 --- a/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc +++ b/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc @@ -221,12 +221,11 @@ void ComputeInterceptor::TryStop() { Send(down_id, stop); } stop_ = true; -} - -void ComputeInterceptor::HandleStop(const InterceptorMessage& msg) { - ReceivedStop(msg.src_id()); - TryStop(); + if (out_buffs_.size() == 0) { + // TODO(fleet executor dev) need a better place to notify + StopCarrier(); + } } void ComputeInterceptor::Compute(const InterceptorMessage& msg) { @@ -236,6 +235,8 @@ void ComputeInterceptor::Compute(const InterceptorMessage& msg) { } else if (msg.message_type() == DATE_IS_USELESS) { DecreaseBuff(msg.src_id()); Run(); + } else if (msg.message_type() == STOP) { + ReceivedStop(msg.src_id()); } TryStop(); diff --git a/paddle/fluid/distributed/fleet_executor/compute_interceptor.h b/paddle/fluid/distributed/fleet_executor/compute_interceptor.h index fd540e81afacae1cedfc410464ddf774b1bc7f27..97e6da2f00eaead14a075de86ce552b87c30633f 100644 --- a/paddle/fluid/distributed/fleet_executor/compute_interceptor.h +++ b/paddle/fluid/distributed/fleet_executor/compute_interceptor.h @@ -39,7 +39,6 @@ class ComputeInterceptor : public Interceptor { void Run(); void Compute(const InterceptorMessage& msg); - void HandleStop(const InterceptorMessage& msg) override; void ReceivedStop(int64_t up_id); void TryStop(); diff --git a/paddle/fluid/distributed/fleet_executor/interceptor.cc b/paddle/fluid/distributed/fleet_executor/interceptor.cc index 916923ce5900df57b36bb343dc86bf90dbdf620f..40429502825c9ca02c3503a51da0fd87b6805af2 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor.cc +++ b/paddle/fluid/distributed/fleet_executor/interceptor.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/distributed/fleet_executor/interceptor.h" +#include "paddle/fluid/distributed/fleet_executor/carrier.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/distributed/fleet_executor/task_node.h" @@ -50,10 +51,20 @@ void Interceptor::Handle(const InterceptorMessage& msg) { InterceptorMessage msg; msg.set_message_type(STOP); Send(interceptor_id_, msg); + } else if (msg.message_type() == STOP) { + stop_ = true; + StopCarrier(); } } } +void Interceptor::StopCarrier() { + Carrier& carrier_instance = Carrier::Instance(); + std::condition_variable& cond_var = carrier_instance.GetCondVar(); + // probably double notify, but ok for ut + cond_var.notify_all(); +} + std::condition_variable& Interceptor::GetCondVar() { // get the conditional var return cond_var_; @@ -80,9 +91,6 @@ bool Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) { return MessageBus::Instance().Send(msg); } -// maybe need a better method for interceptor base -void Interceptor::HandleStop(const InterceptorMessage& msg) { stop_ = true; } - void Interceptor::PoolTheMailbox() { // pool the local mailbox, parse the Message for (;;) { @@ -101,11 +109,7 @@ void Interceptor::PoolTheMailbox() { << " from interceptor " << interceptor_message.src_id() << " with message: " << message_type << "."; - if (message_type == STOP) { - HandleStop(interceptor_message); - } else { - Handle(interceptor_message); - } + Handle(interceptor_message); if (stop_) { // break the pooling thread diff --git a/paddle/fluid/distributed/fleet_executor/interceptor.h b/paddle/fluid/distributed/fleet_executor/interceptor.h index 052c0cc55d550cbae3b4ce7803143ac76cee6c76..ef1ffb1a53b3fc52ee5baff4ce93b4f4644fc623 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor.h +++ b/paddle/fluid/distributed/fleet_executor/interceptor.h @@ -52,8 +52,6 @@ class Interceptor { // register interceptor handle void RegisterMsgHandle(MsgHandle handle); - virtual void HandleStop(const InterceptorMessage& msg); - void Handle(const InterceptorMessage& msg); // return the interceptor id @@ -89,6 +87,7 @@ class Interceptor { // for stop bool stop_{false}; + void StopCarrier(); // for runtime platform::Place place_; diff --git a/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc b/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc index 5b85abb4258326d9cc2d1d26af26c844e239d1fe..3cfd3073c8cb9c1b8537ee9b3c2dc00acab0b192 100644 --- a/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc @@ -33,6 +33,10 @@ class StartInterceptor : public Interceptor { } void NOP(const InterceptorMessage& msg) { + if (msg.message_type() == STOP) { + stop_ = true; + return; + } std::cout << GetInterceptorId() << " recv msg from " << msg.src_id() << std::endl; ++count_; diff --git a/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc b/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc index 52df12395d55a6173098cdb44b519c69bc682d78..c68688bfea646b11470eef0a5de62ddf3369f6da 100644 --- a/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc @@ -32,6 +32,10 @@ class PingPongInterceptor : public Interceptor { } void PingPong(const InterceptorMessage& msg) { + if (msg.message_type() == STOP) { + stop_ = true; + return; + } std::cout << GetInterceptorId() << " recv msg, count=" << count_ << std::endl; ++count_; diff --git a/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc b/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc index dbbcd647292db3a0a2e81fe1f837bfe1016113de..233c4d92c9f3943b880a4e34db3877ff6b5c9096 100644 --- a/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc @@ -34,6 +34,10 @@ class PingPongInterceptor : public Interceptor { } void PingPong(const InterceptorMessage& msg) { + if (msg.message_type() == STOP) { + stop_ = true; + return; + } std::cout << GetInterceptorId() << " recv msg, count=" << count_ << std::endl; ++count_; diff --git a/python/paddle/fluid/tests/unittests/test_fleet_executor.py b/python/paddle/fluid/tests/unittests/test_fleet_executor.py index 6ba9e2d9e21211662a480660e55cd64726d89371..76d4a546746a434cae6c64d5750c61322ba8f7f2 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_executor.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_executor.py @@ -32,11 +32,8 @@ class TestFleetExecutor(unittest.TestCase): exe.run(empty_program, feed={'x': [1]}) def test_executor_on_single_device(self): - places = [fluid.CPUPlace()] if fluid.is_compiled_with_cuda(): - places.append(fluid.CUDAPlace(0)) - for place in places: - self.run_fleet_executor(place) + self.run_fleet_executor(fluid.CUDAPlace(0)) if __name__ == "__main__":