未验证 提交 74ca89ef 编写于 作者: Y Yuang Liu 提交者: GitHub

[fleet_executor] Hold the carrier while running for one micro step. (#37605)

上级 27a5f52b
......@@ -96,8 +96,12 @@ void Carrier::Start() {
"Message bus has not been initialized."));
message_bus_instance.Send(tmp_msg);
}
std::unique_lock<std::mutex> lock(running_mutex_);
cond_var_.wait(lock);
}
// Returns a reference to the carrier's condition variable (cond_var_) so that
// code outside the class can notify on it.
std::condition_variable& Carrier::GetCondVar() {
  return cond_var_;
}
// Returns the current value of the is_init_ flag.
bool Carrier::IsInit() const {
  return is_init_;
}
Interceptor* Carrier::SetInterceptor(int64_t interceptor_id,
......
......@@ -14,6 +14,7 @@
#pragma once
#include <condition_variable>
#include <memory>
#include <mutex>
#include <string>
......@@ -57,6 +58,8 @@ class Carrier final {
void SetCreatingFlag(bool flag);
std::condition_variable& GetCondVar();
void Start();
bool IsInit() const;
......@@ -83,6 +86,9 @@ class Carrier final {
bool creating_interceptors_{true};
std::mutex creating_flag_mutex_;
bool is_init_{false};
std::mutex running_mutex_;
std::condition_variable cond_var_;
};
} // namespace distributed
......
......@@ -221,12 +221,11 @@ void ComputeInterceptor::TryStop() {
Send(down_id, stop);
}
stop_ = true;
}
// Handles a STOP message for a compute interceptor: records the stop from the
// sender (msg.src_id()) via ReceivedStop(), calls TryStop(), and then calls
// StopCarrier() when out_buffs_ is empty.
// NOTE(review): presumably an empty out_buffs_ means this interceptor has no
// downstream consumers — confirm.
void ComputeInterceptor::HandleStop(const InterceptorMessage& msg) {
ReceivedStop(msg.src_id());
TryStop();
if (out_buffs_.size() == 0) {
// TODO(fleet executor dev): find a better place to notify the carrier.
StopCarrier();
}
}
void ComputeInterceptor::Compute(const InterceptorMessage& msg) {
......@@ -236,6 +235,8 @@ void ComputeInterceptor::Compute(const InterceptorMessage& msg) {
} else if (msg.message_type() == DATE_IS_USELESS) {
DecreaseBuff(msg.src_id());
Run();
} else if (msg.message_type() == STOP) {
ReceivedStop(msg.src_id());
}
TryStop();
......
......@@ -39,7 +39,6 @@ class ComputeInterceptor : public Interceptor {
void Run();
void Compute(const InterceptorMessage& msg);
void HandleStop(const InterceptorMessage& msg) override;
void ReceivedStop(int64_t up_id);
void TryStop();
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
......@@ -50,10 +51,20 @@ void Interceptor::Handle(const InterceptorMessage& msg) {
InterceptorMessage msg;
msg.set_message_type(STOP);
Send(interceptor_id_, msg);
} else if (msg.message_type() == STOP) {
stop_ = true;
StopCarrier();
}
}
}
void Interceptor::StopCarrier() {
Carrier& carrier_instance = Carrier::Instance();
std::condition_variable& cond_var = carrier_instance.GetCondVar();
// probably double notify, but ok for ut
cond_var.notify_all();
}
std::condition_variable& Interceptor::GetCondVar() {
// get the conditional var
return cond_var_;
......@@ -80,9 +91,6 @@ bool Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) {
return MessageBus::Instance().Send(msg);
}
// Default stop handling for the interceptor base class: sets stop_ to true.
// The message itself is not inspected.
// TODO: the base interceptor may need a better stop mechanism.
void Interceptor::HandleStop(const InterceptorMessage& msg) {
  stop_ = true;
}
void Interceptor::PoolTheMailbox() {
// poll the local mailbox and parse each message
for (;;) {
......@@ -101,11 +109,7 @@ void Interceptor::PoolTheMailbox() {
<< " from interceptor " << interceptor_message.src_id()
<< " with message: " << message_type << ".";
if (message_type == STOP) {
HandleStop(interceptor_message);
} else {
Handle(interceptor_message);
}
Handle(interceptor_message);
if (stop_) {
// stop the polling thread
......
......@@ -52,8 +52,6 @@ class Interceptor {
// register interceptor handle
void RegisterMsgHandle(MsgHandle handle);
virtual void HandleStop(const InterceptorMessage& msg);
void Handle(const InterceptorMessage& msg);
// return the interceptor id
......@@ -89,6 +87,7 @@ class Interceptor {
// for stop
bool stop_{false};
void StopCarrier();
// for runtime
platform::Place place_;
......
......@@ -33,6 +33,10 @@ class StartInterceptor : public Interceptor {
}
void NOP(const InterceptorMessage& msg) {
if (msg.message_type() == STOP) {
stop_ = true;
return;
}
std::cout << GetInterceptorId() << " recv msg from " << msg.src_id()
<< std::endl;
++count_;
......
......@@ -32,6 +32,10 @@ class PingPongInterceptor : public Interceptor {
}
void PingPong(const InterceptorMessage& msg) {
if (msg.message_type() == STOP) {
stop_ = true;
return;
}
std::cout << GetInterceptorId() << " recv msg, count=" << count_
<< std::endl;
++count_;
......
......@@ -34,6 +34,10 @@ class PingPongInterceptor : public Interceptor {
}
void PingPong(const InterceptorMessage& msg) {
if (msg.message_type() == STOP) {
stop_ = true;
return;
}
std::cout << GetInterceptorId() << " recv msg, count=" << count_
<< std::endl;
++count_;
......
......@@ -32,11 +32,8 @@ class TestFleetExecutor(unittest.TestCase):
exe.run(empty_program, feed={'x': [1]})
def test_executor_on_single_device(self):
# Runs run_fleet_executor on CPUPlace, and also on CUDAPlace(0) when
# fluid.is_compiled_with_cuda() is true.
# NOTE(review): indentation is not preserved in this view, so it is unclear
# whether the final CUDAPlace(0) call below is guarded by the CUDA check;
# unguarded, it would run on builds without CUDA — confirm.
places = [fluid.CPUPlace()]
if fluid.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for place in places:
self.run_fleet_executor(place)
self.run_fleet_executor(fluid.CUDAPlace(0))
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册