未验证 提交 74ca89ef 编写于 作者: Y Yuang Liu 提交者: GitHub

[fleet_executor] Hold the carrier while running for one micro step. (#37605)

上级 27a5f52b
...@@ -96,8 +96,12 @@ void Carrier::Start() { ...@@ -96,8 +96,12 @@ void Carrier::Start() {
"Message bus has not been initialized.")); "Message bus has not been initialized."));
message_bus_instance.Send(tmp_msg); message_bus_instance.Send(tmp_msg);
} }
std::unique_lock<std::mutex> lock(running_mutex_);
cond_var_.wait(lock);
} }
// Returns a reference to the condition variable that Carrier::Start() blocks
// on, so that other components can notify it.
std::condition_variable& Carrier::GetCondVar() {
  return cond_var_;
}
bool Carrier::IsInit() const { return is_init_; } bool Carrier::IsInit() const { return is_init_; }
Interceptor* Carrier::SetInterceptor(int64_t interceptor_id, Interceptor* Carrier::SetInterceptor(int64_t interceptor_id,
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#pragma once #pragma once
#include <condition_variable>
#include <memory> #include <memory>
#include <mutex> #include <mutex>
#include <string> #include <string>
...@@ -57,6 +58,8 @@ class Carrier final { ...@@ -57,6 +58,8 @@ class Carrier final {
void SetCreatingFlag(bool flag); void SetCreatingFlag(bool flag);
std::condition_variable& GetCondVar();
void Start(); void Start();
bool IsInit() const; bool IsInit() const;
...@@ -83,6 +86,9 @@ class Carrier final { ...@@ -83,6 +86,9 @@ class Carrier final {
bool creating_interceptors_{true}; bool creating_interceptors_{true};
std::mutex creating_flag_mutex_; std::mutex creating_flag_mutex_;
bool is_init_{false}; bool is_init_{false};
std::mutex running_mutex_;
std::condition_variable cond_var_;
}; };
} // namespace distributed } // namespace distributed
......
...@@ -221,12 +221,11 @@ void ComputeInterceptor::TryStop() { ...@@ -221,12 +221,11 @@ void ComputeInterceptor::TryStop() {
Send(down_id, stop); Send(down_id, stop);
} }
stop_ = true; stop_ = true;
}
void ComputeInterceptor::HandleStop(const InterceptorMessage& msg) {
ReceivedStop(msg.src_id());
TryStop(); if (out_buffs_.size() == 0) {
// TODO(fleet executor dev) need a better place to notify
StopCarrier();
}
} }
void ComputeInterceptor::Compute(const InterceptorMessage& msg) { void ComputeInterceptor::Compute(const InterceptorMessage& msg) {
...@@ -236,6 +235,8 @@ void ComputeInterceptor::Compute(const InterceptorMessage& msg) { ...@@ -236,6 +235,8 @@ void ComputeInterceptor::Compute(const InterceptorMessage& msg) {
} else if (msg.message_type() == DATE_IS_USELESS) { } else if (msg.message_type() == DATE_IS_USELESS) {
DecreaseBuff(msg.src_id()); DecreaseBuff(msg.src_id());
Run(); Run();
} else if (msg.message_type() == STOP) {
ReceivedStop(msg.src_id());
} }
TryStop(); TryStop();
......
...@@ -39,7 +39,6 @@ class ComputeInterceptor : public Interceptor { ...@@ -39,7 +39,6 @@ class ComputeInterceptor : public Interceptor {
void Run(); void Run();
void Compute(const InterceptorMessage& msg); void Compute(const InterceptorMessage& msg);
void HandleStop(const InterceptorMessage& msg) override;
void ReceivedStop(int64_t up_id); void ReceivedStop(int64_t up_id);
void TryStop(); void TryStop();
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h" #include "paddle/fluid/distributed/fleet_executor/task_node.h"
...@@ -50,10 +51,20 @@ void Interceptor::Handle(const InterceptorMessage& msg) { ...@@ -50,10 +51,20 @@ void Interceptor::Handle(const InterceptorMessage& msg) {
InterceptorMessage msg; InterceptorMessage msg;
msg.set_message_type(STOP); msg.set_message_type(STOP);
Send(interceptor_id_, msg); Send(interceptor_id_, msg);
} else if (msg.message_type() == STOP) {
stop_ = true;
StopCarrier();
} }
} }
} }
void Interceptor::StopCarrier() {
Carrier& carrier_instance = Carrier::Instance();
std::condition_variable& cond_var = carrier_instance.GetCondVar();
// probably double notify, but ok for ut
cond_var.notify_all();
}
std::condition_variable& Interceptor::GetCondVar() { std::condition_variable& Interceptor::GetCondVar() {
// get the conditional var // get the conditional var
return cond_var_; return cond_var_;
...@@ -80,9 +91,6 @@ bool Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) { ...@@ -80,9 +91,6 @@ bool Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) {
return MessageBus::Instance().Send(msg); return MessageBus::Instance().Send(msg);
} }
// maybe need a better method for interceptor base
void Interceptor::HandleStop(const InterceptorMessage& msg) { stop_ = true; }
void Interceptor::PoolTheMailbox() { void Interceptor::PoolTheMailbox() {
// pool the local mailbox, parse the Message // pool the local mailbox, parse the Message
for (;;) { for (;;) {
...@@ -101,11 +109,7 @@ void Interceptor::PoolTheMailbox() { ...@@ -101,11 +109,7 @@ void Interceptor::PoolTheMailbox() {
<< " from interceptor " << interceptor_message.src_id() << " from interceptor " << interceptor_message.src_id()
<< " with message: " << message_type << "."; << " with message: " << message_type << ".";
if (message_type == STOP) { Handle(interceptor_message);
HandleStop(interceptor_message);
} else {
Handle(interceptor_message);
}
if (stop_) { if (stop_) {
// break the pooling thread // break the pooling thread
......
...@@ -52,8 +52,6 @@ class Interceptor { ...@@ -52,8 +52,6 @@ class Interceptor {
// register interceptor handle // register interceptor handle
void RegisterMsgHandle(MsgHandle handle); void RegisterMsgHandle(MsgHandle handle);
virtual void HandleStop(const InterceptorMessage& msg);
void Handle(const InterceptorMessage& msg); void Handle(const InterceptorMessage& msg);
// return the interceptor id // return the interceptor id
...@@ -89,6 +87,7 @@ class Interceptor { ...@@ -89,6 +87,7 @@ class Interceptor {
// for stop // for stop
bool stop_{false}; bool stop_{false};
void StopCarrier();
// for runtime // for runtime
platform::Place place_; platform::Place place_;
......
...@@ -33,6 +33,10 @@ class StartInterceptor : public Interceptor { ...@@ -33,6 +33,10 @@ class StartInterceptor : public Interceptor {
} }
void NOP(const InterceptorMessage& msg) { void NOP(const InterceptorMessage& msg) {
if (msg.message_type() == STOP) {
stop_ = true;
return;
}
std::cout << GetInterceptorId() << " recv msg from " << msg.src_id() std::cout << GetInterceptorId() << " recv msg from " << msg.src_id()
<< std::endl; << std::endl;
++count_; ++count_;
......
...@@ -32,6 +32,10 @@ class PingPongInterceptor : public Interceptor { ...@@ -32,6 +32,10 @@ class PingPongInterceptor : public Interceptor {
} }
void PingPong(const InterceptorMessage& msg) { void PingPong(const InterceptorMessage& msg) {
if (msg.message_type() == STOP) {
stop_ = true;
return;
}
std::cout << GetInterceptorId() << " recv msg, count=" << count_ std::cout << GetInterceptorId() << " recv msg, count=" << count_
<< std::endl; << std::endl;
++count_; ++count_;
......
...@@ -34,6 +34,10 @@ class PingPongInterceptor : public Interceptor { ...@@ -34,6 +34,10 @@ class PingPongInterceptor : public Interceptor {
} }
void PingPong(const InterceptorMessage& msg) { void PingPong(const InterceptorMessage& msg) {
if (msg.message_type() == STOP) {
stop_ = true;
return;
}
std::cout << GetInterceptorId() << " recv msg, count=" << count_ std::cout << GetInterceptorId() << " recv msg, count=" << count_
<< std::endl; << std::endl;
++count_; ++count_;
......
...@@ -32,11 +32,8 @@ class TestFleetExecutor(unittest.TestCase): ...@@ -32,11 +32,8 @@ class TestFleetExecutor(unittest.TestCase):
exe.run(empty_program, feed={'x': [1]}) exe.run(empty_program, feed={'x': [1]})
def test_executor_on_single_device(self): def test_executor_on_single_device(self):
places = [fluid.CPUPlace()]
if fluid.is_compiled_with_cuda(): if fluid.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0)) self.run_fleet_executor(fluid.CUDAPlace(0))
for place in places:
self.run_fleet_executor(place)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册