未验证 提交 843435ff 编写于 作者: L LiYuRio 提交者: GitHub

[fleet_executor] Fix the problem in fleet executor stop (#38114)

上级 e3b033f9
...@@ -49,10 +49,11 @@ void Carrier::Release() { ...@@ -49,10 +49,11 @@ void Carrier::Release() {
// otherwise Derived object will be destructed before thread complete. // otherwise Derived object will be destructed before thread complete.
// Sending STOP msg to the source interceptor // Sending STOP msg to the source interceptor
MessageBus& msg_bus = MessageBus::Instance(); PADDLE_ENFORCE_EQ(msg_bus_->IsInit(), true,
PADDLE_ENFORCE_EQ(msg_bus.IsInit(), true,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"Message bus has not been initialized.")); "Using message bus since it has not been initialized. "
"Please invoke MessageBus::Init() before using it or "
"neccessary components are not ready."));
for (int64_t id : source_interceptor_ids_) { for (int64_t id : source_interceptor_ids_) {
VLOG(3) << "Carrier Release is sending stop to source interceptor " << id VLOG(3) << "Carrier Release is sending stop to source interceptor " << id
<< "."; << ".";
...@@ -61,7 +62,7 @@ void Carrier::Release() { ...@@ -61,7 +62,7 @@ void Carrier::Release() {
stop_msg.set_src_id(-1); stop_msg.set_src_id(-1);
stop_msg.set_dst_id(id); stop_msg.set_dst_id(id);
stop_msg.set_message_type(STOP); stop_msg.set_message_type(STOP);
msg_bus.Send(stop_msg); Send(stop_msg);
} }
// TODO(wangxi): Maybe need a better to use thread. // TODO(wangxi): Maybe need a better to use thread.
...@@ -113,11 +114,17 @@ Interceptor* Carrier::GetInterceptor(int64_t interceptor_id) { ...@@ -113,11 +114,17 @@ Interceptor* Carrier::GetInterceptor(int64_t interceptor_id) {
return iter->second.get(); return iter->second.get();
} }
void Carrier::Wait() {
std::unique_lock<std::mutex> lock(running_mutex_);
cond_var_.wait(lock);
}
void Carrier::Start() { void Carrier::Start() {
MessageBus& msg_bus = MessageBus::Instance(); PADDLE_ENFORCE_EQ(msg_bus_->IsInit(), true,
PADDLE_ENFORCE_EQ(msg_bus.IsInit(), true,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"Message bus has not been initialized.")); "Using message bus since it has not been initialized. "
"Please invoke MessageBus::Init() before using it or "
"neccessary components are not ready."));
for (int64_t id : source_interceptor_ids_) { for (int64_t id : source_interceptor_ids_) {
VLOG(3) << "Carrier Start is sending start to source interceptor " << id VLOG(3) << "Carrier Start is sending start to source interceptor " << id
...@@ -127,11 +134,9 @@ void Carrier::Start() { ...@@ -127,11 +134,9 @@ void Carrier::Start() {
start_msg.set_src_id(-1); start_msg.set_src_id(-1);
start_msg.set_dst_id(id); start_msg.set_dst_id(id);
start_msg.set_message_type(DATA_IS_READY); start_msg.set_message_type(DATA_IS_READY);
msg_bus.Send(start_msg); Send(start_msg);
} }
Wait();
std::unique_lock<std::mutex> lock(running_mutex_);
cond_var_.wait(lock);
dev_ctx_->Wait(); dev_ctx_->Wait();
} }
...@@ -139,6 +144,11 @@ std::condition_variable& Carrier::GetCondVar() { return cond_var_; } ...@@ -139,6 +144,11 @@ std::condition_variable& Carrier::GetCondVar() { return cond_var_; }
bool Carrier::IsInit() const { return is_init_; } bool Carrier::IsInit() const { return is_init_; }
// TODO(liyurui): Move SendIntra into carrier
bool Carrier::Send(const InterceptorMessage& msg) const {
return msg_bus_->Send(msg);
}
Interceptor* Carrier::SetInterceptor(int64_t interceptor_id, Interceptor* Carrier::SetInterceptor(int64_t interceptor_id,
std::unique_ptr<Interceptor> interceptor) { std::unique_ptr<Interceptor> interceptor) {
auto iter = interceptor_idx_to_interceptor_.find(interceptor_id); auto iter = interceptor_idx_to_interceptor_.find(interceptor_id);
...@@ -147,6 +157,7 @@ Interceptor* Carrier::SetInterceptor(int64_t interceptor_id, ...@@ -147,6 +157,7 @@ Interceptor* Carrier::SetInterceptor(int64_t interceptor_id,
"The interceptor id %lld has already been created! " "The interceptor id %lld has already been created! "
"The interceptor id should be unique.", "The interceptor id should be unique.",
interceptor_id)); interceptor_id));
interceptor->RegisterCarrier(this);
auto* ptr = interceptor.get(); auto* ptr = interceptor.get();
interceptor_idx_to_interceptor_.insert( interceptor_idx_to_interceptor_.insert(
std::make_pair(interceptor_id, std::move(interceptor))); std::make_pair(interceptor_id, std::move(interceptor)));
......
...@@ -40,22 +40,19 @@ namespace distributed { ...@@ -40,22 +40,19 @@ namespace distributed {
class TaskNode; class TaskNode;
class InterceptorMessageServiceImpl; class InterceptorMessageServiceImpl;
class RuntimeGraph; class RuntimeGraph;
class MessageBus;
// A singleton MessageBus
class Carrier final { class Carrier final {
public: public:
static Carrier& Instance() { Carrier() = default;
static Carrier carrier; ~Carrier();
return carrier;
}
void Init(std::shared_ptr<RuntimeGraph> runtime_graph, void Init(std::shared_ptr<RuntimeGraph> runtime_graph,
framework::Scope* root_scope, framework::Scope* minibatch_scope, framework::Scope* root_scope, framework::Scope* minibatch_scope,
const std::vector<framework::Scope*>& microbatch_scopes, const std::vector<framework::Scope*>& microbatch_scopes,
const platform::Place& place); const platform::Place& place);
~Carrier();
void Release(); void Release();
void Wait();
// Enqueue a message to corresponding interceptor id // Enqueue a message to corresponding interceptor id
bool EnqueueInterceptorMessage(const InterceptorMessage& interceptor_message); bool EnqueueInterceptorMessage(const InterceptorMessage& interceptor_message);
...@@ -68,6 +65,9 @@ class Carrier final { ...@@ -68,6 +65,9 @@ class Carrier final {
std::unique_ptr<Interceptor>); std::unique_ptr<Interceptor>);
void SetCreatingFlag(bool flag); void SetCreatingFlag(bool flag);
void SetMsgBus(const std::shared_ptr<MessageBus>& msg_bus) {
msg_bus_ = msg_bus;
}
std::condition_variable& GetCondVar(); std::condition_variable& GetCondVar();
...@@ -75,15 +75,15 @@ class Carrier final { ...@@ -75,15 +75,15 @@ class Carrier final {
bool IsInit() const; bool IsInit() const;
bool Send(const InterceptorMessage& msg) const;
// NOTE: This mutex will be used in interceptor's RunOps function. // NOTE: This mutex will be used in interceptor's RunOps function.
// This mutex is used for avoiding forward ops and backward ops run // This mutex is used for avoiding forward ops and backward ops run
// simultaneously, which will lead to a random hang for some sync ops. // simultaneously, which will lead to a random hang for some sync ops.
std::mutex run; std::mutex run;
DISABLE_COPY_AND_ASSIGN(Carrier);
private: private:
Carrier() = default; DISABLE_COPY_AND_ASSIGN(Carrier);
// create each Interceptor // create each Interceptor
void CreateInterceptors(); void CreateInterceptors();
...@@ -110,6 +110,8 @@ class Carrier final { ...@@ -110,6 +110,8 @@ class Carrier final {
paddle::platform::Place place_; paddle::platform::Place place_;
paddle::platform::DeviceContext* dev_ctx_{nullptr}; paddle::platform::DeviceContext* dev_ctx_{nullptr};
std::shared_ptr<RuntimeGraph> runtime_graph_; std::shared_ptr<RuntimeGraph> runtime_graph_;
std::shared_ptr<MessageBus> msg_bus_;
std::unordered_map<int64_t, int64_t> interceptor_id_to_rank_;
}; };
} // namespace distributed } // namespace distributed
......
...@@ -170,8 +170,7 @@ void ComputeInterceptor::ReplyCompletedToUpStream() { ...@@ -170,8 +170,7 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
} }
void ComputeInterceptor::RunOps() { void ComputeInterceptor::RunOps() {
Carrier& carrier_instance = Carrier::Instance(); std::unique_lock<std::mutex> lock(carrier_->run);
std::unique_lock<std::mutex> lock(carrier_instance.run);
VLOG(3) << "ComputeInterceptor " << interceptor_id_ << " running ops for the " VLOG(3) << "ComputeInterceptor " << interceptor_id_ << " running ops for the "
<< step_ + 1 << " time."; << step_ + 1 << " time.";
for (auto op : node_->ops()) { for (auto op : node_->ops()) {
......
...@@ -34,7 +34,15 @@ FleetExecutor::FleetExecutor(const std::string& exe_desc_str) { ...@@ -34,7 +34,15 @@ FleetExecutor::FleetExecutor(const std::string& exe_desc_str) {
"Error occurs while parsing string to proto")); "Error occurs while parsing string to proto"));
} }
FleetExecutor::~FleetExecutor() { root_scope_->DropKids(); } FleetExecutor::~FleetExecutor() {
root_scope_->DropKids();
GetCarrier().Release();
}
Carrier& FleetExecutor::GetCarrier() {
static Carrier carrier;
return carrier;
}
void FleetExecutor::Init( void FleetExecutor::Init(
const framework::ProgramDesc& program_desc, framework::Scope* scope, const framework::ProgramDesc& program_desc, framework::Scope* scope,
...@@ -78,15 +86,17 @@ void FleetExecutor::Init( ...@@ -78,15 +86,17 @@ void FleetExecutor::Init(
CopyParameters(i, program_desc); CopyParameters(i, program_desc);
} }
VLOG(5) << runtime_graph_->DebugString(); VLOG(5) << runtime_graph_->DebugString();
msg_bus_ = std::make_shared<MessageBus>();
InitCarrier(); InitCarrier();
InitMessageBus(); InitMessageBus();
} }
void FleetExecutor::InitCarrier() { void FleetExecutor::InitCarrier() {
Carrier& carrier_instance = Carrier::Instance(); Carrier& carrier = GetCarrier();
if (!carrier_instance.IsInit()) { if (!carrier.IsInit()) {
carrier_instance.Init(runtime_graph_, root_scope_, minibatch_scope_, carrier.SetMsgBus(msg_bus_);
microbatch_scopes_, place_); carrier.Init(runtime_graph_, root_scope_, minibatch_scope_,
microbatch_scopes_, place_);
} }
} }
...@@ -120,24 +130,22 @@ void FleetExecutor::InitMessageBus() { ...@@ -120,24 +130,22 @@ void FleetExecutor::InitMessageBus() {
VLOG(3) << "The number of ranks are " VLOG(3) << "The number of ranks are "
<< (rank_to_addr.size() == 0 ? 1 : rank_to_addr.size()) << "."; << (rank_to_addr.size() == 0 ? 1 : rank_to_addr.size()) << ".";
VLOG(5) << ss.str(); VLOG(5) << ss.str();
MessageBus& message_bus_instance = MessageBus::Instance(); if (!msg_bus_->IsInit()) {
if (!message_bus_instance.IsInit()) { msg_bus_->Init(runtime_graph_->intercepter_id_to_rank(), rank_to_addr,
message_bus_instance.Init(runtime_graph_->intercepter_id_to_rank(), addr);
rank_to_addr, addr);
} }
} }
void FleetExecutor::Run() { void FleetExecutor::Run() {
// Run // Run
Carrier& carrier_instance = Carrier::Instance(); Carrier& carrier = GetCarrier();
MessageBus& message_bus_instance = MessageBus::Instance();
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
carrier_instance.IsInit(), true, carrier.IsInit(), true,
platform::errors::Unavailable("Carrier has not been init yet.")); platform::errors::Unavailable("Carrier has not been init yet."));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
message_bus_instance.IsInit(), true, msg_bus_->IsInit(), true,
platform::errors::Unavailable("MessageBus has not been init yet.")); platform::errors::Unavailable("MessageBus has not been init yet."));
carrier_instance.Start(); carrier.Start();
for (auto* micro_scop : microbatch_scopes_) { for (auto* micro_scop : microbatch_scopes_) {
// By default, we should delete all kid scopes after run executor because // By default, we should delete all kid scopes after run executor because
// some operators may create local scope when running, such as while_op. // some operators may create local scope when running, such as while_op.
......
...@@ -28,9 +28,9 @@ class Scope; ...@@ -28,9 +28,9 @@ class Scope;
namespace distributed { namespace distributed {
class RuntimeGraph; class RuntimeGraph;
class Carrier;
class MessageBus; class MessageBus;
class TaskNode; class TaskNode;
class Carrier;
class FleetExecutor final { class FleetExecutor final {
public: public:
...@@ -42,6 +42,8 @@ class FleetExecutor final { ...@@ -42,6 +42,8 @@ class FleetExecutor final {
const std::vector<TaskNode*>& task_nodes, const std::vector<TaskNode*>& task_nodes,
const std::unordered_map<int64_t, int64_t>& task_id_to_rank); const std::unordered_map<int64_t, int64_t>& task_id_to_rank);
void Run(); void Run();
// TODO(liyurui): Change to use registry table for multi-carrier.
static Carrier& GetCarrier();
private: private:
DISABLE_COPY_AND_ASSIGN(FleetExecutor); DISABLE_COPY_AND_ASSIGN(FleetExecutor);
...@@ -54,6 +56,9 @@ class FleetExecutor final { ...@@ -54,6 +56,9 @@ class FleetExecutor final {
framework::Scope* minibatch_scope_; framework::Scope* minibatch_scope_;
platform::Place place_; platform::Place place_;
std::vector<framework::Scope*> microbatch_scopes_; std::vector<framework::Scope*> microbatch_scopes_;
// The carriers under FleetExecutor will share message bus,
// using shared_ptr to manage lifetime and condition race.
std::shared_ptr<MessageBus> msg_bus_;
}; };
} // namespace distributed } // namespace distributed
......
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
#include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h" #include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h" #include "paddle/fluid/distributed/fleet_executor/task_node.h"
namespace paddle { namespace paddle {
...@@ -46,8 +45,9 @@ void Interceptor::Handle(const InterceptorMessage& msg) { ...@@ -46,8 +45,9 @@ void Interceptor::Handle(const InterceptorMessage& msg) {
} }
void Interceptor::StopCarrier() { void Interceptor::StopCarrier() {
Carrier& carrier_instance = Carrier::Instance(); PADDLE_ENFORCE_NOT_NULL(carrier_, platform::errors::PreconditionNotMet(
std::condition_variable& cond_var = carrier_instance.GetCondVar(); "Carrier is not registered."));
std::condition_variable& cond_var = carrier_->GetCondVar();
// probably double notify, but ok for ut // probably double notify, but ok for ut
cond_var.notify_all(); cond_var.notify_all();
} }
...@@ -73,9 +73,11 @@ bool Interceptor::EnqueueRemoteInterceptorMessage( ...@@ -73,9 +73,11 @@ bool Interceptor::EnqueueRemoteInterceptorMessage(
} }
bool Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) { bool Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) {
PADDLE_ENFORCE_NOT_NULL(carrier_, platform::errors::PreconditionNotMet(
"Carrier is not registered."));
msg.set_src_id(interceptor_id_); msg.set_src_id(interceptor_id_);
msg.set_dst_id(dst_id); msg.set_dst_id(dst_id);
return MessageBus::Instance().Send(msg); return carrier_->Send(msg);
} }
void Interceptor::PoolTheMailbox() { void Interceptor::PoolTheMailbox() {
......
...@@ -36,6 +36,7 @@ class GarbageCollector; ...@@ -36,6 +36,7 @@ class GarbageCollector;
namespace distributed { namespace distributed {
class TaskNode; class TaskNode;
class Carrier;
class Interceptor { class Interceptor {
public: public:
...@@ -77,6 +78,7 @@ class Interceptor { ...@@ -77,6 +78,7 @@ class Interceptor {
void SetGC(const std::shared_ptr<framework::GarbageCollector>& gc) { void SetGC(const std::shared_ptr<framework::GarbageCollector>& gc) {
gc_ = gc; gc_ = gc;
} }
void RegisterCarrier(Carrier* carrier) { carrier_ = carrier; }
TaskNode* GetTaskNode() const { return node_; } TaskNode* GetTaskNode() const { return node_; }
...@@ -100,6 +102,8 @@ class Interceptor { ...@@ -100,6 +102,8 @@ class Interceptor {
std::vector<framework::Scope*> microbatch_scopes_{}; std::vector<framework::Scope*> microbatch_scopes_{};
std::shared_ptr<framework::GarbageCollector> gc_{nullptr}; std::shared_ptr<framework::GarbageCollector> gc_{nullptr};
Carrier* carrier_;
private: private:
// pool the local mailbox, parse the Message // pool the local mailbox, parse the Message
void PoolTheMailbox(); void PoolTheMailbox();
......
...@@ -29,9 +29,8 @@ void InterceptorMessageServiceImpl::InterceptorMessageService( ...@@ -29,9 +29,8 @@ void InterceptorMessageServiceImpl::InterceptorMessageService(
VLOG(3) << "Interceptor Message Service receives a message from interceptor " VLOG(3) << "Interceptor Message Service receives a message from interceptor "
<< request->src_id() << " to interceptor " << request->dst_id() << request->src_id() << " to interceptor " << request->dst_id()
<< ", with the message: " << request->message_type(); << ", with the message: " << request->message_type();
FleetExecutor::GetCarrier().EnqueueInterceptorMessage(*request);
response->set_rst(true); response->set_rst(true);
// call interceptor manager's method to handle the message
Carrier::Instance().EnqueueInterceptorMessage(*request);
} }
} // namespace distributed } // namespace distributed
......
...@@ -57,10 +57,6 @@ void MessageBus::Init( ...@@ -57,10 +57,6 @@ void MessageBus::Init(
bool MessageBus::IsInit() const { return is_init_; } bool MessageBus::IsInit() const { return is_init_; }
MessageBus::~MessageBus() { MessageBus::~MessageBus() {
// NOTE: fleet_executor inits carrier before message bus,
// therefore the message bus's destructor will be called first
Carrier& carrier = Carrier::Instance();
carrier.Release();
VLOG(3) << "Message bus releases resource."; VLOG(3) << "Message bus releases resource.";
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
!defined(PADDLE_WITH_ASCEND_CL) !defined(PADDLE_WITH_ASCEND_CL)
...@@ -245,7 +241,8 @@ bool MessageBus::SendInterRank(const InterceptorMessage& interceptor_message) { ...@@ -245,7 +241,8 @@ bool MessageBus::SendInterRank(const InterceptorMessage& interceptor_message) {
bool MessageBus::SendIntraRank(const InterceptorMessage& interceptor_message) { bool MessageBus::SendIntraRank(const InterceptorMessage& interceptor_message) {
// send the message intra rank (dst is the same rank with src) // send the message intra rank (dst is the same rank with src)
return Carrier::Instance().EnqueueInterceptorMessage(interceptor_message); return FleetExecutor::GetCarrier().EnqueueInterceptorMessage(
interceptor_message);
} }
} // namespace distributed } // namespace distributed
......
...@@ -39,10 +39,8 @@ class Carrier; ...@@ -39,10 +39,8 @@ class Carrier;
// A singleton MessageBus // A singleton MessageBus
class MessageBus final { class MessageBus final {
public: public:
static MessageBus& Instance() { MessageBus() = default;
static MessageBus msg_bus; ~MessageBus();
return msg_bus;
}
void Init(const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank, void Init(const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank,
const std::unordered_map<int64_t, std::string>& rank_to_addr, const std::unordered_map<int64_t, std::string>& rank_to_addr,
...@@ -53,12 +51,8 @@ class MessageBus final { ...@@ -53,12 +51,8 @@ class MessageBus final {
// called by Interceptor, send InterceptorMessage to dst // called by Interceptor, send InterceptorMessage to dst
bool Send(const InterceptorMessage& interceptor_message); bool Send(const InterceptorMessage& interceptor_message);
~MessageBus();
DISABLE_COPY_AND_ASSIGN(MessageBus);
private: private:
MessageBus() = default; DISABLE_COPY_AND_ASSIGN(MessageBus);
// function keep listen the port and handle the message // function keep listen the port and handle the message
void ListenPort(); void ListenPort();
...@@ -72,12 +66,11 @@ class MessageBus final { ...@@ -72,12 +66,11 @@ class MessageBus final {
bool SendInterRank(const InterceptorMessage& interceptor_message); bool SendInterRank(const InterceptorMessage& interceptor_message);
#endif #endif
bool is_init_{false};
// send the message intra rank (dst is the same rank with src) // send the message intra rank (dst is the same rank with src)
bool SendIntraRank(const InterceptorMessage& interceptor_message); bool SendIntraRank(const InterceptorMessage& interceptor_message);
bool is_init_{false};
std::once_flag once_flag_;
// handed by above layer, save the info mapping interceptor id to rank id // handed by above layer, save the info mapping interceptor id to rank id
std::unordered_map<int64_t, int64_t> interceptor_id_to_rank_; std::unordered_map<int64_t, int64_t> interceptor_id_to_rank_;
......
...@@ -18,6 +18,7 @@ limitations under the License. */ ...@@ -18,6 +18,7 @@ limitations under the License. */
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h" #include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h" #include "paddle/fluid/distributed/fleet_executor/task_node.h"
...@@ -61,10 +62,12 @@ TEST(ComputeInterceptor, Compute) { ...@@ -61,10 +62,12 @@ TEST(ComputeInterceptor, Compute) {
std::vector<framework::Scope*> scopes = {scope, scope}; std::vector<framework::Scope*> scopes = {scope, scope};
platform::Place place = platform::CPUPlace(); platform::Place place = platform::CPUPlace();
Carrier& carrier = Carrier::Instance(); // TODO(liyurui): Remove singleton when move SendIntra into Carrier
Carrier& carrier = FleetExecutor::GetCarrier();
MessageBus& msg_bus = MessageBus::Instance(); auto msg_bus = std::make_shared<MessageBus>();
msg_bus.Init({{0, 0}, {1, 0}}, {{0, "127.0.0.0:0"}}, "127.0.0.0:0"); msg_bus->Init({{0, 0}, {1, 0}}, {{0, "127.0.0.0:0"}}, "");
carrier.SetMsgBus(msg_bus);
// FIXME: don't delete, otherwise interceptor will use undefined node // FIXME: don't delete, otherwise interceptor will use undefined node
TaskNode* node_a = TaskNode* node_a =
...@@ -90,6 +93,9 @@ TEST(ComputeInterceptor, Compute) { ...@@ -90,6 +93,9 @@ TEST(ComputeInterceptor, Compute) {
msg.set_src_id(-1); msg.set_src_id(-1);
msg.set_dst_id(0); msg.set_dst_id(0);
carrier.EnqueueInterceptorMessage(msg); carrier.EnqueueInterceptorMessage(msg);
carrier.Wait();
carrier.Release();
} }
} // namespace distributed } // namespace distributed
......
...@@ -18,6 +18,7 @@ limitations under the License. */ ...@@ -18,6 +18,7 @@ limitations under the License. */
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h" #include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h" #include "paddle/fluid/distributed/fleet_executor/task_node.h"
...@@ -46,9 +47,12 @@ class StartInterceptor : public Interceptor { ...@@ -46,9 +47,12 @@ class StartInterceptor : public Interceptor {
}; };
TEST(ComputeInterceptor, Compute) { TEST(ComputeInterceptor, Compute) {
Carrier& carrier = Carrier::Instance(); // TODO(liyurui): Remove singleton when move SendIntra into Carrier
MessageBus& msg_bus = MessageBus::Instance(); Carrier& carrier = FleetExecutor::GetCarrier();
msg_bus.Init({{0, 0}, {1, 0}, {2, 0}}, {{0, "127.0.0.0:0"}}, "127.0.0.0:0");
auto msg_bus = std::make_shared<MessageBus>();
msg_bus->Init({{0, 0}, {1, 0}, {2, 0}}, {{0, "127.0.0.0:0"}}, "");
carrier.SetMsgBus(msg_bus);
// NOTE: don't delete, otherwise interceptor will use undefined node // NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0); // role, rank, task_id TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0); // role, rank, task_id
...@@ -74,6 +78,9 @@ TEST(ComputeInterceptor, Compute) { ...@@ -74,6 +78,9 @@ TEST(ComputeInterceptor, Compute) {
a->Send(1, msg); a->Send(1, msg);
a->Send(1, msg); a->Send(1, msg);
a->Send(1, msg); a->Send(1, msg);
carrier.Wait();
carrier.Release();
} }
} // namespace distributed } // namespace distributed
......
...@@ -18,6 +18,7 @@ limitations under the License. */ ...@@ -18,6 +18,7 @@ limitations under the License. */
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h" #include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h"
...@@ -44,6 +45,7 @@ class PingPongInterceptor : public Interceptor { ...@@ -44,6 +45,7 @@ class PingPongInterceptor : public Interceptor {
stop.set_message_type(STOP); stop.set_message_type(STOP);
Send(0, stop); Send(0, stop);
Send(1, stop); Send(1, stop);
StopCarrier();
return; return;
} }
...@@ -58,10 +60,12 @@ class PingPongInterceptor : public Interceptor { ...@@ -58,10 +60,12 @@ class PingPongInterceptor : public Interceptor {
REGISTER_INTERCEPTOR(PingPong, PingPongInterceptor); REGISTER_INTERCEPTOR(PingPong, PingPongInterceptor);
TEST(InterceptorTest, PingPong) { TEST(InterceptorTest, PingPong) {
MessageBus& msg_bus = MessageBus::Instance(); // TODO(liyurui): Remove singleton when move SendIntra into Carrier
msg_bus.Init({{0, 0}, {1, 0}}, {{0, "127.0.0.0:0"}}, "127.0.0.0:0"); Carrier& carrier = FleetExecutor::GetCarrier();
Carrier& carrier = Carrier::Instance(); auto msg_bus = std::make_shared<MessageBus>();
msg_bus->Init({{0, 0}, {1, 0}}, {{0, "127.0.0.0:0"}}, "");
carrier.SetMsgBus(msg_bus);
Interceptor* a = carrier.SetInterceptor( Interceptor* a = carrier.SetInterceptor(
0, InterceptorFactory::Create("PingPong", 0, nullptr)); 0, InterceptorFactory::Create("PingPong", 0, nullptr));
...@@ -71,6 +75,8 @@ TEST(InterceptorTest, PingPong) { ...@@ -71,6 +75,8 @@ TEST(InterceptorTest, PingPong) {
InterceptorMessage msg; InterceptorMessage msg;
a->Send(1, msg); a->Send(1, msg);
carrier.Wait();
} }
} // namespace distributed } // namespace distributed
......
...@@ -20,6 +20,7 @@ limitations under the License. */ ...@@ -20,6 +20,7 @@ limitations under the License. */
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h" #include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h"
...@@ -36,6 +37,7 @@ class PingPongInterceptor : public Interceptor { ...@@ -36,6 +37,7 @@ class PingPongInterceptor : public Interceptor {
void PingPong(const InterceptorMessage& msg) { void PingPong(const InterceptorMessage& msg) {
if (msg.message_type() == STOP) { if (msg.message_type() == STOP) {
stop_ = true; stop_ = true;
StopCarrier();
return; return;
} }
std::cout << GetInterceptorId() << " recv msg, count=" << count_ std::cout << GetInterceptorId() << " recv msg, count=" << count_
...@@ -105,10 +107,12 @@ TEST(InterceptorTest, PingPong) { ...@@ -105,10 +107,12 @@ TEST(InterceptorTest, PingPong) {
int pid = fork(); int pid = fork();
if (pid == 0) { if (pid == 0) {
MessageBus& msg_bus = MessageBus::Instance(); auto msg_bus = std::make_shared<MessageBus>();
msg_bus.Init({{0, 0}, {1, 1}}, {{0, ip0}, {1, ip1}}, ip0); msg_bus->Init({{0, 0}, {1, 1}}, {{0, ip0}, {1, ip1}}, ip0);
Carrier& carrier = Carrier::Instance(); // TODO(liyurui): Remove singleton when move SendIntra into Carrier
Carrier& carrier = FleetExecutor::GetCarrier();
carrier.SetMsgBus(msg_bus);
Interceptor* a = carrier.SetInterceptor( Interceptor* a = carrier.SetInterceptor(
0, InterceptorFactory::Create("PingPong", 0, nullptr)); 0, InterceptorFactory::Create("PingPong", 0, nullptr));
...@@ -116,15 +120,19 @@ TEST(InterceptorTest, PingPong) { ...@@ -116,15 +120,19 @@ TEST(InterceptorTest, PingPong) {
InterceptorMessage msg; InterceptorMessage msg;
a->Send(1, msg); a->Send(1, msg);
carrier.Wait();
} else { } else {
MessageBus& msg_bus = MessageBus::Instance(); auto msg_bus = std::make_shared<MessageBus>();
msg_bus.Init({{0, 0}, {1, 1}}, {{0, ip0}, {1, ip1}}, ip1); msg_bus->Init({{0, 0}, {1, 1}}, {{0, ip0}, {1, ip1}}, ip1);
Carrier& carrier = Carrier::Instance(); // TODO(liyurui): Remove singleton when move SendIntra into Carrier
Carrier& carrier = FleetExecutor::GetCarrier();
carrier.SetMsgBus(msg_bus);
carrier.SetInterceptor(1, carrier.SetInterceptor(1,
InterceptorFactory::Create("PingPong", 1, nullptr)); InterceptorFactory::Create("PingPong", 1, nullptr));
carrier.SetCreatingFlag(false); carrier.SetCreatingFlag(false);
carrier.Wait();
} }
} }
......
...@@ -18,6 +18,7 @@ limitations under the License. */ ...@@ -18,6 +18,7 @@ limitations under the License. */
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h" #include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h" #include "paddle/fluid/distributed/fleet_executor/task_node.h"
...@@ -51,10 +52,12 @@ void LinkNodes(const std::vector<TaskNode*>& nodes) { ...@@ -51,10 +52,12 @@ void LinkNodes(const std::vector<TaskNode*>& nodes) {
} }
TEST(AmplifierInterceptor, Amplifier) { TEST(AmplifierInterceptor, Amplifier) {
Carrier& carrier = Carrier::Instance(); // TODO(liyurui): Remove singleton when move SendIntra into Carrier
MessageBus& msg_bus = MessageBus::Instance(); Carrier& carrier = FleetExecutor::GetCarrier();
msg_bus.Init({{0, 0}, {1, 0}, {2, 0}, {3, 0}, {4, 0}, {5, 0}}, auto msg_bus = std::make_shared<MessageBus>();
{{0, "127.0.0.0:0"}}, "127.0.0.0:0"); msg_bus->Init({{0, 0}, {1, 0}, {2, 0}, {3, 0}, {4, 0}, {5, 0}},
{{0, "127.0.0.0:0"}}, "127.0.0.0:0");
carrier.SetMsgBus(msg_bus);
int64_t micro_steps = 3; int64_t micro_steps = 3;
...@@ -88,6 +91,8 @@ TEST(AmplifierInterceptor, Amplifier) { ...@@ -88,6 +91,8 @@ TEST(AmplifierInterceptor, Amplifier) {
msg.set_src_id(-1); msg.set_src_id(-1);
msg.set_dst_id(0); msg.set_dst_id(0);
carrier.EnqueueInterceptorMessage(msg); carrier.EnqueueInterceptorMessage(msg);
carrier.Wait();
carrier.Release();
} }
} // namespace distributed } // namespace distributed
......
...@@ -18,6 +18,7 @@ limitations under the License. */ ...@@ -18,6 +18,7 @@ limitations under the License. */
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h" #include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h" #include "paddle/fluid/distributed/fleet_executor/task_node.h"
...@@ -69,9 +70,11 @@ void LinkNodes(const std::vector<TaskNode*>& nodes, ...@@ -69,9 +70,11 @@ void LinkNodes(const std::vector<TaskNode*>& nodes,
} }
TEST(AmplifierInterceptor, Amplifier) { TEST(AmplifierInterceptor, Amplifier) {
Carrier& carrier = Carrier::Instance(); // TODO(liyurui): Remove singleton when move SendIntra into Carrier
MessageBus& msg_bus = MessageBus::Instance(); Carrier& carrier = FleetExecutor::GetCarrier();
msg_bus.Init({{0, 0}, {1, 0}, {2, 0}, {3, 0}}, {{0, ""}}, ""); auto msg_bus = std::make_shared<MessageBus>();
msg_bus->Init({{0, 0}, {1, 0}, {2, 0}, {3, 0}}, {{0, ""}}, "");
carrier.SetMsgBus(msg_bus);
int64_t micro_steps = 6; int64_t micro_steps = 6;
...@@ -103,6 +106,8 @@ TEST(AmplifierInterceptor, Amplifier) { ...@@ -103,6 +106,8 @@ TEST(AmplifierInterceptor, Amplifier) {
msg.set_src_id(-1); msg.set_src_id(-1);
msg.set_dst_id(0); msg.set_dst_id(0);
carrier.EnqueueInterceptorMessage(msg); carrier.EnqueueInterceptorMessage(msg);
carrier.Wait();
carrier.Release();
} }
} // namespace distributed } // namespace distributed
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册