diff --git a/paddle/fluid/distributed/fleet_executor/CMakeLists.txt b/paddle/fluid/distributed/fleet_executor/CMakeLists.txt index 95ec6b329964ed9ad167a148d3270f55483007d0..e9da55c417e9a065aad96affa85b4af701d65f6d 100644 --- a/paddle/fluid/distributed/fleet_executor/CMakeLists.txt +++ b/paddle/fluid/distributed/fleet_executor/CMakeLists.txt @@ -19,6 +19,9 @@ cc_library(fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime if(WITH_DISTRIBUTE) set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") + if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0) + set(DISTRIBUTE_COMPILE_FLAGS "${DISTRIBUTE_COMPILE_FLAGS} -faligned-new") + endif() set_source_files_properties(interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(compute_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(amplifier_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) diff --git a/paddle/fluid/distributed/fleet_executor/carrier.cc b/paddle/fluid/distributed/fleet_executor/carrier.cc index 45296853adf7b4577eceaa7014a85fe745598bb7..79be1824b864db2ce848ddf1eb04771f894ff528 100644 --- a/paddle/fluid/distributed/fleet_executor/carrier.cc +++ b/paddle/fluid/distributed/fleet_executor/carrier.cc @@ -13,7 +13,7 @@ // limitations under the License. #include "paddle/fluid/distributed/fleet_executor/carrier.h" -#include "paddle/fluid/distributed/fleet_executor/global_map.h" +#include "paddle/fluid/distributed/fleet_executor/global.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/interceptor_message_service.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h" @@ -71,17 +71,13 @@ Carrier::~Carrier() { VLOG(3) << "Carrier's destructor."; } bool Carrier::EnqueueInterceptorMessage( const InterceptorMessage& interceptor_message) { - if (interceptor_message.ctrl_message()) { - VLOG(3) << "Receiving control message from rank " - << interceptor_message.src_id() << " to rank " - << interceptor_message.dst_id(); - // for barrier - msg_bus_->IncreaseBarrierCount(); - } else { - int64_t dst_id = interceptor_message.dst_id(); - Interceptor* dst_interceptor = GetInterceptor(dst_id); - dst_interceptor->EnqueueRemoteInterceptorMessage(interceptor_message); - } + PADDLE_ENFORCE_EQ( + interceptor_message.ctrl_message(), false, + platform::errors::Fatal( + "Control message should be only send inter rank using message bus.")); + int64_t dst_id = interceptor_message.dst_id(); + Interceptor* dst_interceptor = GetInterceptor(dst_id); + dst_interceptor->EnqueueRemoteInterceptorMessage(interceptor_message); return true; } @@ -106,11 +102,6 @@ void Carrier::WakeUp() { } void Carrier::Start() { - PADDLE_ENFORCE_EQ(msg_bus_->IsInit(), true, - platform::errors::PreconditionNotMet( - "Using message bus since it has not been initialized. " - "Please invoke MessageBus::Init() before using it or " - "neccessary components are not ready.")); PADDLE_ENFORCE_EQ(is_init_, true, platform::errors::PreconditionNotMet( "Using carrier before initialized.")); for (int64_t id : source_interceptor_ids_) { @@ -154,19 +145,10 @@ bool Carrier::Send(const InterceptorMessage& msg) { << " to interceptor " << dst_id << ", which are in the same ranks."; return EnqueueInterceptorMessage(msg); } else { - PADDLE_ENFORCE_NOT_NULL( - msg_bus_.get(), - platform::errors::Unavailable("Message bus is released accidently")); - PADDLE_ENFORCE_EQ( - msg_bus_->IsInit(), true, - platform::errors::PreconditionNotMet( - "Using message bus since it has not been initialized. " - "Please invoke MessageBus::Init() before using it or " - "neccessary components are not ready.")); VLOG(3) << "Send a message from interceptor " << src_id << " to interceptor " << dst_id << ", which are in different ranks."; - return msg_bus_->Send(dst_rank, msg); + return GlobalVal::Get()->Send(dst_rank, msg); } } diff --git a/paddle/fluid/distributed/fleet_executor/carrier.h b/paddle/fluid/distributed/fleet_executor/carrier.h index cd70ab46ce58e84eebae34e44fe35c8b0e00bd33..75ac07083a7968f379e7e946f2ac91fac180d65d 100644 --- a/paddle/fluid/distributed/fleet_executor/carrier.h +++ b/paddle/fluid/distributed/fleet_executor/carrier.h @@ -73,10 +73,6 @@ class Carrier final { Interceptor* SetInterceptor(int64_t interceptor_id, std::unique_ptr); - void SetMsgBus(const std::shared_ptr& msg_bus) { - msg_bus_ = msg_bus; - } - void Start(); bool IsInit() const; @@ -107,7 +103,6 @@ class Carrier final { framework::Scope* minibatch_scope_; paddle::platform::Place place_; paddle::platform::DeviceContext* dev_ctx_{nullptr}; - std::shared_ptr msg_bus_; int64_t rank_; std::string carrier_id_; std::unordered_map interceptor_id_to_node_; diff --git a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc index f81cdf200c65d7044b1a7d71482aa9194c10e641..e22d0945a23980bcd47679d1e63583d5ac9ed792 100644 --- a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc +++ b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc @@ -13,7 +13,7 @@ // limitations under the License. #include "paddle/fluid/distributed/fleet_executor/fleet_executor.h" -#include "paddle/fluid/distributed/fleet_executor/global_map.h" +#include "paddle/fluid/distributed/fleet_executor/global.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/distributed/fleet_executor/runtime_graph.h" #include "paddle/fluid/distributed/fleet_executor/task_node.h" @@ -32,6 +32,9 @@ FleetExecutor::FleetExecutor(const std::string& exe_desc_str) { bool parse_flag = exe_desc_.ParseFromString(exe_desc_str); PADDLE_ENFORCE(parse_flag, platform::errors::PreconditionNotMet( "Error occurs while parsing string to proto")); + // Message bus will be created and inited only once + GlobalVal::Create(); + InitMessageBus(); } FleetExecutor::~FleetExecutor() { @@ -81,21 +84,16 @@ void FleetExecutor::Init( CopyParameters(i, program_desc); } VLOG(5) << runtime_graph_->DebugString(); - msg_bus_ = std::make_shared(); Carrier* carrier = GlobalMap::Create(carrier_id, carrier_id); carrier_ids_.insert(carrier_id); - GlobalVal::Set(carrier_id); - // TODO(liyurui): Maybe message bus should be created only once + // Set current running carrier + GlobalVal::Set(new std::string(carrier_id)); InitCarrier(carrier); - InitMessageBus(); - - // Wait for all message bus connected. - msg_bus_->Barrier(); + GlobalVal::Get()->Barrier(); } void FleetExecutor::InitCarrier(Carrier* carrier) { - carrier->SetMsgBus(msg_bus_); carrier->Init(exe_desc_.cur_rank(), runtime_graph_->interceptor_id_to_rank(), runtime_graph_->interceptor_id_to_node(), root_scope_, minibatch_scope_, microbatch_scopes_, place_); @@ -131,14 +129,18 @@ void FleetExecutor::InitMessageBus() { VLOG(3) << "The number of ranks are " << (rank_to_addr.size() == 0 ? 1 : rank_to_addr.size()) << "."; VLOG(5) << ss.str(); - if (!msg_bus_->IsInit()) { - msg_bus_->Init(cur_rank, rank_to_addr, addr); - } + GlobalVal::Get()->Init(cur_rank, rank_to_addr, addr); } void FleetExecutor::Run(const std::string& carrier_id) { - GlobalMap::Get(carrier_id)->Start(); - GlobalVal::Set(carrier_id); + Carrier* carrier = GlobalMap::Get(carrier_id); + // Set current running carrier + if (*GlobalVal::Get() != carrier_id) { + GlobalVal::Set(new std::string(carrier_id)); + // TODO(liyurui): Move barrier to service + GlobalVal::Get()->Barrier(); + } + carrier->Start(); for (auto* micro_scop : microbatch_scopes_) { // By default, we should delete all kid scopes after run executor because // some operators may create local scope when running, such as while_op. diff --git a/paddle/fluid/distributed/fleet_executor/fleet_executor.h b/paddle/fluid/distributed/fleet_executor/fleet_executor.h index 33b7d4a40dc3bba6a8cb7a5be5b4ffdea2e959c8..89ab4c62d386f21ab46b45e8782dd1bf7e3ecc3e 100644 --- a/paddle/fluid/distributed/fleet_executor/fleet_executor.h +++ b/paddle/fluid/distributed/fleet_executor/fleet_executor.h @@ -55,9 +55,6 @@ class FleetExecutor final { framework::Scope* minibatch_scope_; platform::Place place_; std::vector microbatch_scopes_; - // The carriers under FleetExecutor will share message bus, - // using shared_ptr to manage lifetime and condition race. - std::shared_ptr msg_bus_; std::unordered_set carrier_ids_; }; diff --git a/paddle/fluid/distributed/fleet_executor/global_map.h b/paddle/fluid/distributed/fleet_executor/global.h similarity index 76% rename from paddle/fluid/distributed/fleet_executor/global_map.h rename to paddle/fluid/distributed/fleet_executor/global.h index 2e2923e447d299e27d3511aa7143a0fb950d18a7..776f314e6afb2ed604ea54182ddc8a135f38a7f8 100644 --- a/paddle/fluid/distributed/fleet_executor/global_map.h +++ b/paddle/fluid/distributed/fleet_executor/global.h @@ -14,24 +14,41 @@ #pragma once +#include "paddle/fluid/platform/enforce.h" + namespace paddle { namespace distributed { -// TODO(liyurui): Change this file to global.h template class GlobalVal final { public: - static T Get() { return *GetPtr(); } - static T Set(T val) { - auto* ptr = GetPtr(); - *ptr = val; - return val; + static T* Get() { + T* ptr = GetPPtr()->get(); + PADDLE_ENFORCE_NOT_NULL( + ptr, platform::errors::NotFound("This value is not global value.")); + return ptr; + } + template + static T* Create(Args&&... args) { + auto* ptr = GetPPtr(); + PADDLE_ENFORCE_EQ(ptr->get(), nullptr, + platform::errors::AlreadyExists( + "This value is already a global value.")); + T* item = new T(std::forward(args)...); + ptr->reset(item); + return item; + } + + static T* Set(T* new_item) { + auto* ptr = GetPPtr(); + ptr->reset(new_item); + return ptr->get(); } private: - static T* GetPtr() { - static T value; - return &value; + static std::unique_ptr* GetPPtr() { + static std::unique_ptr ptr; + return &ptr; } }; diff --git a/paddle/fluid/distributed/fleet_executor/interceptor_message_service.cc b/paddle/fluid/distributed/fleet_executor/interceptor_message_service.cc index 52be135f1ce42d68509fece254882695bb72d89c..ce8a73602d0bec6cfc049885c3e5d2d58f358831 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor_message_service.cc +++ b/paddle/fluid/distributed/fleet_executor/interceptor_message_service.cc @@ -15,8 +15,8 @@ !defined(PADDLE_WITH_ASCEND_CL) #include "paddle/fluid/distributed/fleet_executor/interceptor_message_service.h" #include "brpc/server.h" -#include "paddle/fluid/distributed/fleet_executor/carrier.h" -#include "paddle/fluid/distributed/fleet_executor/global_map.h" +#include "paddle/fluid/distributed/fleet_executor/global.h" +#include "paddle/fluid/distributed/fleet_executor/message_bus.h" namespace paddle { namespace distributed { @@ -29,9 +29,7 @@ void InterceptorMessageServiceImpl::InterceptorMessageService( VLOG(3) << "Interceptor Message Service receives a message from interceptor " << request->src_id() << " to interceptor " << request->dst_id() << ", with the message: " << request->message_type(); - const auto& carrier_id = GlobalVal::Get(); - bool flag = GlobalMap::Get(carrier_id) - ->EnqueueInterceptorMessage(*request); + bool flag = GlobalVal::Get()->DispatchMsgToCarrier(*request); response->set_rst(flag); } diff --git a/paddle/fluid/distributed/fleet_executor/message_bus.cc b/paddle/fluid/distributed/fleet_executor/message_bus.cc index dd95a90ad1ba447a974274d606f4167bac9ff607..110c5feafc71a11a85c86454dae59a08ebfaf587 100644 --- a/paddle/fluid/distributed/fleet_executor/message_bus.cc +++ b/paddle/fluid/distributed/fleet_executor/message_bus.cc @@ -17,6 +17,8 @@ #include #include +#include "paddle/fluid/distributed/fleet_executor/carrier.h" +#include "paddle/fluid/distributed/fleet_executor/global.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/platform/gen_comm_id_helper.h" @@ -81,6 +83,10 @@ const std::string& MessageBus::GetAddr(int64_t rank) const { bool MessageBus::Send(int64_t dst_rank, const InterceptorMessage& interceptor_message) { + PADDLE_ENFORCE_EQ( + IsInit(), true, + platform::errors::PreconditionNotMet( + "Using message bus since it has not been initialized.")); #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ !defined(PADDLE_WITH_ASCEND_CL) int retry_time = 0; // message bus will retry sending for 10 times @@ -155,6 +161,22 @@ void MessageBus::Barrier() { } } +bool MessageBus::DispatchMsgToCarrier( + const InterceptorMessage& interceptor_message) { + if (interceptor_message.ctrl_message()) { + VLOG(3) << "Receiving control message from rank " + << interceptor_message.src_id() << " to rank " + << interceptor_message.dst_id(); + // for barrier + IncreaseBarrierCount(); + return true; + } else { + const std::string& carrier_id = *GlobalVal::Get(); + return GlobalMap::Get(carrier_id) + ->EnqueueInterceptorMessage(interceptor_message); + } +} + void MessageBus::ListenPort() { if (addr_ == "") { LOG(INFO) << "No need listen to port since training on single card."; diff --git a/paddle/fluid/distributed/fleet_executor/message_bus.h b/paddle/fluid/distributed/fleet_executor/message_bus.h index c8685a73900d547f7985237a88da5d3119debe2f..456cd77e2dde8ce158a1735e93b01df5df442ddc 100644 --- a/paddle/fluid/distributed/fleet_executor/message_bus.h +++ b/paddle/fluid/distributed/fleet_executor/message_bus.h @@ -54,6 +54,7 @@ class MessageBus final { void IncreaseBarrierCount(); void Barrier(); + bool DispatchMsgToCarrier(const InterceptorMessage& interceptor_message); private: DISABLE_COPY_AND_ASSIGN(MessageBus); diff --git a/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc b/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc index c48fd0962379579030b6b44624e95c4dc286a9ec..07d2a0f6b727aa56ef804e5ca9dee8e7a86e2cdb 100644 --- a/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc @@ -18,7 +18,7 @@ limitations under the License. */ #include "gtest/gtest.h" #include "paddle/fluid/distributed/fleet_executor/carrier.h" -#include "paddle/fluid/distributed/fleet_executor/global_map.h" +#include "paddle/fluid/distributed/fleet_executor/global.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/distributed/fleet_executor/task_node.h" @@ -67,9 +67,8 @@ TEST(ComputeInterceptor, Compute) { GlobalMap::Create(carrier_id, carrier_id); carrier->Init(0, {{0, 0}, {1, 0}}); - auto msg_bus = std::make_shared(); + MessageBus* msg_bus = GlobalVal::Create(); msg_bus->Init(0, {{0, "127.0.0.0:0"}}, ""); - carrier->SetMsgBus(msg_bus); // FIXME: don't delete, otherwise interceptor will use undefined node TaskNode* node_a = diff --git a/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc b/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc index f34f862c6285c7a0e44069d48ee96a050c83db37..954b52693f46c0d4b87e030f5629800eefc7c9e1 100644 --- a/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc @@ -18,7 +18,7 @@ limitations under the License. */ #include "gtest/gtest.h" #include "paddle/fluid/distributed/fleet_executor/carrier.h" -#include "paddle/fluid/distributed/fleet_executor/global_map.h" +#include "paddle/fluid/distributed/fleet_executor/global.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/distributed/fleet_executor/task_node.h" @@ -52,9 +52,8 @@ TEST(ComputeInterceptor, Compute) { GlobalMap::Create(carrier_id, carrier_id); carrier->Init(0, {{0, 0}, {1, 0}, {2, 0}}); - auto msg_bus = std::make_shared(); + MessageBus* msg_bus = GlobalVal::Create(); msg_bus->Init(0, {{0, "127.0.0.0:0"}}, ""); - carrier->SetMsgBus(msg_bus); // NOTE: don't delete, otherwise interceptor will use undefined node TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0); // role, rank, task_id diff --git a/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc b/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc index 8289eab16750045bb677c307a27775b856e33dbd..19c1d0a0d7a6a20a18e2ead960a69230e065cac5 100644 --- a/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc @@ -18,7 +18,7 @@ limitations under the License. */ #include "gtest/gtest.h" #include "paddle/fluid/distributed/fleet_executor/carrier.h" -#include "paddle/fluid/distributed/fleet_executor/global_map.h" +#include "paddle/fluid/distributed/fleet_executor/global.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h" @@ -64,9 +64,8 @@ TEST(InterceptorTest, PingPong) { Carrier* carrier = GlobalMap::Create(carrier_id, carrier_id); carrier->Init(0, {{0, 0}, {1, 0}}); - auto msg_bus = std::make_shared(); + MessageBus* msg_bus = GlobalVal::Create(); msg_bus->Init(0, {{0, "127.0.0.0:0"}}, ""); - carrier->SetMsgBus(msg_bus); Interceptor* a = carrier->SetInterceptor( 0, InterceptorFactory::Create("PingPong", 0, nullptr)); diff --git a/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc b/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc index f7adf59a6e81996a0c9df5897f384015454817a6..78cff2606f6b8fabc459838ee17b9ec29221ba32 100644 --- a/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc @@ -20,7 +20,7 @@ limitations under the License. */ #include "gtest/gtest.h" #include "paddle/fluid/distributed/fleet_executor/carrier.h" -#include "paddle/fluid/distributed/fleet_executor/global_map.h" +#include "paddle/fluid/distributed/fleet_executor/global.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h" @@ -112,12 +112,10 @@ TEST(InterceptorTest, PingPong) { if (pid == 0) { Carrier* carrier = GlobalMap::Create(carrier_id, carrier_id); - GlobalVal::Set(carrier_id); - auto msg_bus = std::make_shared(); - carrier->SetMsgBus(msg_bus); - // NOTE: need Init msg_bus after carrier SetMsgBus - carrier->Init(0, interceptor_id_to_rank); + GlobalVal::Set(new std::string(carrier_id)); + MessageBus* msg_bus = GlobalVal::Create(); msg_bus->Init(0, {{0, ip0}, {1, ip1}}, ip0); + carrier->Init(0, interceptor_id_to_rank); Interceptor* a = carrier->SetInterceptor( 0, InterceptorFactory::Create("PingPong", 0, nullptr)); msg_bus->Barrier(); @@ -127,11 +125,10 @@ TEST(InterceptorTest, PingPong) { } else { Carrier* carrier = GlobalMap::Create(carrier_id, carrier_id); - GlobalVal::Set(carrier_id); - auto msg_bus = std::make_shared(); - carrier->SetMsgBus(msg_bus); - carrier->Init(1, interceptor_id_to_rank); + GlobalVal::Set(new std::string(carrier_id)); + MessageBus* msg_bus = GlobalVal::Create(); msg_bus->Init(1, {{0, ip0}, {1, ip1}}, ip1); + carrier->Init(1, interceptor_id_to_rank); carrier->SetInterceptor(1, InterceptorFactory::Create("PingPong", 1, nullptr)); msg_bus->Barrier(); diff --git a/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc b/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc index 2cd0813803f0c422604dc19144fa200acc96637a..3860e9f4e137e3f3222d5ce9995e466e3c22db00 100644 --- a/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc @@ -18,7 +18,7 @@ limitations under the License. */ #include "gtest/gtest.h" #include "paddle/fluid/distributed/fleet_executor/carrier.h" -#include "paddle/fluid/distributed/fleet_executor/global_map.h" +#include "paddle/fluid/distributed/fleet_executor/global.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/distributed/fleet_executor/task_node.h" @@ -56,9 +56,8 @@ TEST(AmplifierInterceptor, Amplifier) { Carrier* carrier = GlobalMap::Create(carrier_id, carrier_id); carrier->Init(0, {{0, 0}, {1, 0}, {2, 0}, {3, 0}, {4, 0}, {5, 0}}); - auto msg_bus = std::make_shared(); + MessageBus* msg_bus = GlobalVal::Create(); msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "127.0.0.0:0"); - carrier->SetMsgBus(msg_bus); int64_t micro_steps = 3; diff --git a/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc b/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc index 66c283b65fb76b3776b9f7d3094a5d9b03da4ef3..b510b68e4e2ed5e770d96cf92575188f879d62b6 100644 --- a/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc @@ -18,7 +18,7 @@ limitations under the License. */ #include "gtest/gtest.h" #include "paddle/fluid/distributed/fleet_executor/carrier.h" -#include "paddle/fluid/distributed/fleet_executor/global_map.h" +#include "paddle/fluid/distributed/fleet_executor/global.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/distributed/fleet_executor/task_node.h" @@ -74,9 +74,8 @@ TEST(AmplifierInterceptor, Amplifier) { Carrier* carrier = GlobalMap::Create(carrier_id, carrier_id); carrier->Init(0, {{0, 0}, {1, 0}, {2, 0}, {3, 0}}); - auto msg_bus = std::make_shared(); + MessageBus* msg_bus = GlobalVal::Create(); msg_bus->Init(0, {{0, ""}}, ""); - carrier->SetMsgBus(msg_bus); int64_t micro_steps = 6;