diff --git a/paddle/fluid/distributed/fleet_executor/carrier.cc b/paddle/fluid/distributed/fleet_executor/carrier.cc index ea35b36aa4a75e32a3331b4aa2f870c5d925c465..d1c57ccfddd41c25d1947e80102feb312f174699 100644 --- a/paddle/fluid/distributed/fleet_executor/carrier.cc +++ b/paddle/fluid/distributed/fleet_executor/carrier.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/distributed/fleet_executor/carrier.h" +#include "paddle/fluid/distributed/fleet_executor/global_map.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/interceptor_message_service.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h" @@ -27,16 +28,32 @@ namespace distributed { USE_INTERCEPTOR(Compute); USE_INTERCEPTOR(Amplifier); -void Carrier::Init(int64_t rank, std::shared_ptr runtime_graph, - framework::Scope* root_scope, - framework::Scope* minibatch_scope, - const std::vector& microbatch_scopes, - const platform::Place& place) { - PADDLE_ENFORCE_EQ(is_init_, false, platform::errors::AlreadyExists( - "Carrier is already init.")); +void Carrier::Init( + int64_t rank, + const std::unordered_map& interceptor_id_to_rank, + const std::unordered_set& interceptor_ids) { rank_ = rank; - runtime_graph_ = runtime_graph; - interceptor_id_to_rank_ = runtime_graph_->interceptor_id_to_rank(); + interceptor_id_to_rank_ = interceptor_id_to_rank; + interceptor_ids_ = interceptor_ids; + + // TODO(fleet_exe dev): thread pool + thread_num_ = 1; + thread_pool_.SetThreadNum(thread_num_); + thread_pool_.Start(); +} + +void Carrier::Init( + int64_t rank, + const std::unordered_map& interceptor_id_to_rank, + const std::unordered_set& interceptor_ids, + const std::unordered_map& interceptor_id_to_node, + framework::Scope* root_scope, framework::Scope* minibatch_scope, + const std::vector& microbatch_scopes, + const platform::Place& place) { + rank_ = rank; + interceptor_id_to_rank_ = interceptor_id_to_rank; + interceptor_ids_ = interceptor_ids; + interceptor_id_to_node_ = interceptor_id_to_node; minibatch_scope_ = minibatch_scope; microbatch_scopes_ = microbatch_scopes; place_ = place; @@ -72,8 +89,6 @@ bool Carrier::EnqueueInterceptorMessage( return true; } -void Carrier::Barrier() { msg_bus_->Barrier(); } - Interceptor* Carrier::GetInterceptor(int64_t interceptor_id) { auto iter = interceptor_idx_to_interceptor_.find(interceptor_id); PADDLE_ENFORCE_NE(iter, interceptor_idx_to_interceptor_.end(), @@ -100,7 +115,8 @@ void Carrier::Start() { "Using message bus since it has not been initialized. " "Please invoke MessageBus::Init() before using it or " "neccessary components are not ready.")); - + PADDLE_ENFORCE_EQ(is_init_, true, platform::errors::PreconditionNotMet( + "Using carrier before initialized.")); for (int64_t id : source_interceptor_ids_) { VLOG(3) << "Carrier Start is sending start to source interceptor " << id << "."; @@ -140,7 +156,9 @@ bool Carrier::Send(const InterceptorMessage& msg) { if (src_rank == dst_rank) { VLOG(3) << "Send a message from interceptor " << src_id << " to interceptor " << dst_id << ", which are in the same ranks."; - return EnqueueInterceptorMessage(msg); + int64_t carrier_id = *GlobalMap::Get(dst_id); + return GlobalMap::Get(carrier_id) + ->EnqueueInterceptorMessage(msg); } else { PADDLE_ENFORCE_NOT_NULL( msg_bus_.get(), @@ -174,6 +192,9 @@ Interceptor* Carrier::SetInterceptor(int64_t interceptor_id, loop, platform::errors::Fatal("thread task loop must not null")); interceptor->RegisterTaskLoop(loop); + // TODO(liyurui): Using struct InterceptorID replace int64_t + GlobalMap::Create(interceptor_id, carrier_id_); + auto* ptr = interceptor.get(); interceptor_idx_to_interceptor_.insert( std::make_pair(interceptor_id, std::move(interceptor))); @@ -199,15 +220,19 @@ static std::shared_ptr GetGC( } void Carrier::CreateInterceptors() { - if (runtime_graph_->interceptor_id_to_node().empty()) return; + if (interceptor_ids_.empty()) return; auto gc = GetGC(place_); // create each Interceptor // no auto init since there is no config - for (const auto& item : runtime_graph_->interceptor_id_to_node()) { - int64_t interceptor_id = item.first; - TaskNode* task_node = item.second; + for (int64_t interceptor_id : interceptor_ids_) { + const auto& task_node_iter = interceptor_id_to_node_.find(interceptor_id); + PADDLE_ENFORCE_NE( + task_node_iter, interceptor_id_to_node_.end(), + platform::errors::NotFound("Can not find task node for interceptor %ld", + interceptor_id)); + TaskNode* task_node = task_node_iter->second; PADDLE_ENFORCE_LT( task_node->run_at_offset(), task_node->run_per_steps(), diff --git a/paddle/fluid/distributed/fleet_executor/carrier.h b/paddle/fluid/distributed/fleet_executor/carrier.h index 5b7275416f57f8eb7bbdba878e51c9dc5e11bba0..87356f9ea8ddf38ced774d8496a496d184d4db9f 100644 --- a/paddle/fluid/distributed/fleet_executor/carrier.h +++ b/paddle/fluid/distributed/fleet_executor/carrier.h @@ -45,19 +45,19 @@ class MessageBus; class Carrier final { public: - Carrier() = default; - Carrier(int64_t rank, - const std::unordered_map& interceptor_id_to_rank) - : rank_(rank), interceptor_id_to_rank_(interceptor_id_to_rank) { - thread_num_ = 1; - thread_pool_.SetThreadNum(thread_num_); - thread_pool_.Start(); - } + explicit Carrier(int64_t carrier_id) : carrier_id_(carrier_id) {} ~Carrier(); - void Init(int64_t rank, std::shared_ptr runtime_graph, - framework::Scope* root_scope, framework::Scope* minibatch_scope, - const std::vector& microbatch_scopes, - const platform::Place& place); + void Init(int64_t rank, + const std::unordered_map& interceptor_id_to_rank, + const std::unordered_set& interceptor_ids); + void Init( + int64_t rank, + const std::unordered_map& interceptor_id_to_rank, + const std::unordered_set& interceptor_ids, + const std::unordered_map& interceptor_id_to_node, + framework::Scope* root_scope, framework::Scope* minibatch_scope, + const std::vector& microbatch_scopes, + const platform::Place& place); void Release(); void Wait(); @@ -83,10 +83,9 @@ class Carrier final { bool Send(const InterceptorMessage& msg); - void Barrier(); - private: DISABLE_COPY_AND_ASSIGN(Carrier); + Carrier() = delete; // create each Interceptor void CreateInterceptors(); @@ -108,13 +107,14 @@ class Carrier final { framework::Scope* minibatch_scope_; paddle::platform::Place place_; paddle::platform::DeviceContext* dev_ctx_{nullptr}; - std::shared_ptr runtime_graph_; std::shared_ptr msg_bus_; int64_t rank_; + int64_t carrier_id_; + std::unordered_map interceptor_id_to_node_; std::unordered_map interceptor_id_to_rank_; - int thread_num_; TaskLoopThreadPool thread_pool_; + std::unordered_set interceptor_ids_; }; } // namespace distributed diff --git a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc index a5badcb36eb3ebf5b7754b3ac83e63efda4ec3fb..29e9e0861831bff75b317c5fd6a0aada72d238d0 100644 --- a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc +++ b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/distributed/fleet_executor/fleet_executor.h" +#include "paddle/fluid/distributed/fleet_executor/global_map.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/distributed/fleet_executor/runtime_graph.h" #include "paddle/fluid/distributed/fleet_executor/task_node.h" @@ -27,8 +28,6 @@ namespace paddle { namespace distributed { -std::unique_ptr FleetExecutor::carrier_; - FleetExecutor::FleetExecutor(const std::string& exe_desc_str) { bool parse_flag = exe_desc_.ParseFromString(exe_desc_str); PADDLE_ENFORCE(parse_flag, platform::errors::PreconditionNotMet( @@ -37,13 +36,9 @@ FleetExecutor::FleetExecutor(const std::string& exe_desc_str) { FleetExecutor::~FleetExecutor() { root_scope_->DropKids(); - GetCarrier()->Release(); -} - -Carrier* FleetExecutor::GetCarrier() { - PADDLE_ENFORCE_NOT_NULL(carrier_.get(), platform::errors::NotFound( - "Carrier has not been created.")); - return carrier_.get(); + for (const auto& item : runtime_graph_->carrier_id_to_interceptor_ids()) { + GlobalMap::Get(item.first)->Release(); + } } void FleetExecutor::Init( @@ -63,13 +58,19 @@ void FleetExecutor::Init( auto unused_vars = framework::GetUnusedVars(program_desc.Block(0), ops, {}); runtime_graph_ = std::make_shared(); std::unordered_map interceptor_id_to_task; + std::unordered_map> + carrier_id_to_interceptor_ids; + std::unordered_set interceptor_ids; for (auto task_node : task_nodes) { task_node->SetUnusedVars(unused_vars); int64_t interceptor_id = task_node->task_id(); interceptor_id_to_task.emplace(interceptor_id, task_node); + interceptor_ids.insert(interceptor_id); } + carrier_id_to_interceptor_ids.emplace(0, interceptor_ids); runtime_graph_->SetInterceptorIdToRank(task_id_to_rank); runtime_graph_->SetInterceptorIdToNode(interceptor_id_to_task); + runtime_graph_->SetCarrierIdToInterceptorIds(carrier_id_to_interceptor_ids); for (auto& unique_op : ops) { unique_op.release(); } @@ -86,21 +87,26 @@ void FleetExecutor::Init( } VLOG(5) << runtime_graph_->DebugString(); msg_bus_ = std::make_shared(); - CreateCarrier(); + for (const auto& item : runtime_graph_->carrier_id_to_interceptor_ids()) { + GlobalMap::Create(item.first, item.first); + } InitCarrier(); InitMessageBus(); - // refine this? wait all carrier ready - // NOTE(wangxi): must add after Carrier::SetMsgBus, for we use - // MessageBus::IncreaseBarrierCount when receive barrier msg. - GetCarrier()->Barrier(); + // Wait for all message bus connected. + msg_bus_->Barrier(); } void FleetExecutor::InitCarrier() { - if (!GetCarrier()->IsInit()) { - GetCarrier()->SetMsgBus(msg_bus_); - GetCarrier()->Init(exe_desc_.cur_rank(), runtime_graph_, root_scope_, - minibatch_scope_, microbatch_scopes_, place_); + for (const auto& item : runtime_graph_->carrier_id_to_interceptor_ids()) { + Carrier* carrier = GlobalMap::Get(item.first); + PADDLE_ENFORCE_NOT_NULL(carrier, platform::errors::InvalidArgument( + "Carrier has not been created.")); + carrier->SetMsgBus(msg_bus_); + carrier->Init(exe_desc_.cur_rank(), + runtime_graph_->interceptor_id_to_rank(), item.second, + runtime_graph_->interceptor_id_to_node(), root_scope_, + minibatch_scope_, microbatch_scopes_, place_); } } @@ -140,14 +146,9 @@ void FleetExecutor::InitMessageBus() { } void FleetExecutor::Run() { - // Run - PADDLE_ENFORCE_EQ( - GetCarrier()->IsInit(), true, - platform::errors::Unavailable("Carrier has not been init yet.")); - PADDLE_ENFORCE_EQ( - msg_bus_->IsInit(), true, - platform::errors::Unavailable("MessageBus has not been init yet.")); - GetCarrier()->Start(); + for (const auto& item : runtime_graph_->carrier_id_to_interceptor_ids()) { + GlobalMap::Get(item.first)->Start(); + } for (auto* micro_scop : microbatch_scopes_) { // By default, we should delete all kid scopes after run executor because // some operators may create local scope when running, such as while_op. diff --git a/paddle/fluid/distributed/fleet_executor/fleet_executor.h b/paddle/fluid/distributed/fleet_executor/fleet_executor.h index 3572e07efc5da6797961c2661fcbb5781bce6d7c..e65155237d0878df10dbc3d88ccbf544a02db63e 100644 --- a/paddle/fluid/distributed/fleet_executor/fleet_executor.h +++ b/paddle/fluid/distributed/fleet_executor/fleet_executor.h @@ -42,16 +42,6 @@ class FleetExecutor final { const std::vector& task_nodes, const std::unordered_map& task_id_to_rank); void Run(); - // TODO(liyurui): Change to use registry table for multi-carrier. - static Carrier* GetCarrier(); - template - static Carrier* CreateCarrier(Args&&... args) { - PADDLE_ENFORCE_EQ( - carrier_.get(), nullptr, - platform::errors::AlreadyExists("Carrier has been created already.")); - carrier_ = std::make_unique(std::forward(args)...); - return carrier_.get(); - } private: DISABLE_COPY_AND_ASSIGN(FleetExecutor); @@ -67,7 +57,6 @@ class FleetExecutor final { // The carriers under FleetExecutor will share message bus, // using shared_ptr to manage lifetime and condition race. std::shared_ptr msg_bus_; - static std::unique_ptr carrier_; }; } // namespace distributed diff --git a/paddle/fluid/distributed/fleet_executor/global_map.h b/paddle/fluid/distributed/fleet_executor/global_map.h new file mode 100644 index 0000000000000000000000000000000000000000..ef50563d3f10d4fbe9fa3fe972f40ec44c73e856 --- /dev/null +++ b/paddle/fluid/distributed/fleet_executor/global_map.h @@ -0,0 +1,49 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +namespace paddle { +namespace distributed { + +template +class GlobalMap final { + public: + static ValueT* Get(KeyT id) { + ValueT* item = GetPPtr(id)->get(); + PADDLE_ENFORCE_NOT_NULL( + item, platform::errors::NotFound("This value is not in global map.")); + return item; + } + template + static ValueT* Create(KeyT id, Args&&... args) { + auto* ptr = GetPPtr(id); + PADDLE_ENFORCE_EQ(ptr->get(), nullptr, + platform::errors::AlreadyExists( + "This value has already in global map.")); + ValueT* item = new ValueT(std::forward(args)...); + ptr->reset(item); + return item; + } + + private: + static std::unique_ptr* GetPPtr(KeyT id) { + static std::mutex mutex; + static std::unordered_map> id_to_ptr; + std::unique_lock lock(mutex); + return &id_to_ptr[id]; + } +}; +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/fleet_executor/interceptor_message_service.cc b/paddle/fluid/distributed/fleet_executor/interceptor_message_service.cc index 231b6c780e24e77683def9955eca49a1b0a07b22..e939ebae13464a79eced6537512b054f58547e68 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor_message_service.cc +++ b/paddle/fluid/distributed/fleet_executor/interceptor_message_service.cc @@ -16,7 +16,7 @@ #include "paddle/fluid/distributed/fleet_executor/interceptor_message_service.h" #include "brpc/server.h" #include "paddle/fluid/distributed/fleet_executor/carrier.h" -#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h" +#include "paddle/fluid/distributed/fleet_executor/global_map.h" namespace paddle { namespace distributed { @@ -29,7 +29,15 @@ void InterceptorMessageServiceImpl::InterceptorMessageService( VLOG(3) << "Interceptor Message Service receives a message from interceptor " << request->src_id() << " to interceptor " << request->dst_id() << ", with the message: " << request->message_type(); - bool flag = FleetExecutor::GetCarrier()->EnqueueInterceptorMessage(*request); + // TODO(liyurui): Remove this hard code. + int64_t carrier_id; + if (request->ctrl_message()) { + carrier_id = 0; + } else { + carrier_id = *GlobalMap::Get(request->dst_id()); + } + bool flag = GlobalMap::Get(carrier_id) + ->EnqueueInterceptorMessage(*request); response->set_rst(flag); } diff --git a/paddle/fluid/distributed/fleet_executor/runtime_graph.h b/paddle/fluid/distributed/fleet_executor/runtime_graph.h index 1ca9f0174ed07f3c12a8fb937799cfc4dd444b37..6342aae483814df9a0d6c1b7bbdd96914fd29b6a 100644 --- a/paddle/fluid/distributed/fleet_executor/runtime_graph.h +++ b/paddle/fluid/distributed/fleet_executor/runtime_graph.h @@ -35,6 +35,10 @@ class RuntimeGraph final { const std::unordered_map& interceptor_id_to_rank() const { return interceptor_id_to_rank_; } + const std::unordered_map>& + carrier_id_to_interceptor_ids() const { + return carrier_id_to_interceptor_ids_; + } void SetInterceptorIdToRank( const std::unordered_map& interceptor_id_to_rank) { interceptor_id_to_rank_ = interceptor_id_to_rank; @@ -43,12 +47,19 @@ class RuntimeGraph final { const std::unordered_map& interceptor_id_to_node) { interceptor_id_to_node_ = interceptor_id_to_node; } + void SetCarrierIdToInterceptorIds( + const std::unordered_map>& + carrier_id_to_interceptor_ids) { + carrier_id_to_interceptor_ids_ = carrier_id_to_interceptor_ids; + } std::string DebugString() const; private: DISABLE_COPY_AND_ASSIGN(RuntimeGraph); std::unordered_map interceptor_id_to_node_; std::unordered_map interceptor_id_to_rank_; + std::unordered_map> + carrier_id_to_interceptor_ids_; }; } // namespace distributed diff --git a/paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt b/paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt index d4587b90c87f3deadba686230728ed084b2a18ad..e8b97aa508b8fce08a77e0c09c12ada573f1b028 100644 --- a/paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt +++ b/paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt @@ -13,6 +13,9 @@ cc_test(interceptor_pipeline_long_path_test SRCS interceptor_pipeline_long_path_ set_source_files_properties(compute_interceptor_run_op_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) cc_test(compute_interceptor_run_op_test SRCS compute_interceptor_run_op_test.cc DEPS fleet_executor ${BRPC_DEPS} op_registry fill_constant_op elementwise_add_op scope device_context) +set_source_files_properties(interceptor_pass_the_parcel_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +cc_test(interceptor_pass_the_parcel_test SRCS interceptor_pass_the_parcel_test.cc DEPS fleet_executor ${BRPC_DEPS}) + if(WITH_DISTRIBUTE AND WITH_PSCORE AND NOT (WITH_ASCEND OR WITH_ASCEND_CL)) set_source_files_properties(interceptor_ping_pong_with_brpc_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) cc_test(interceptor_ping_pong_with_brpc_test SRCS interceptor_ping_pong_with_brpc_test.cc DEPS fleet_executor ${BRPC_DEPS}) diff --git a/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc b/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc index b14ca5fc46d52e3b5a98d2645d056f125c795e23..f4b5f70948a04f7610052a9538bedf846e0a9468 100644 --- a/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc @@ -18,7 +18,7 @@ limitations under the License. */ #include "gtest/gtest.h" #include "paddle/fluid/distributed/fleet_executor/carrier.h" -#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h" +#include "paddle/fluid/distributed/fleet_executor/global_map.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/distributed/fleet_executor/task_node.h" @@ -62,11 +62,12 @@ TEST(ComputeInterceptor, Compute) { std::vector scopes = {scope, scope}; platform::Place place = platform::CPUPlace(); - Carrier carrier(0, {{0, 0}, {1, 0}}); + Carrier* carrier = GlobalMap::Create(0, 0); + carrier->Init(0, {{0, 0}, {1, 0}}, {0, 1}); auto msg_bus = std::make_shared(); msg_bus->Init(0, {{0, "127.0.0.0:0"}}, ""); - carrier.SetMsgBus(msg_bus); + carrier->SetMsgBus(msg_bus); // FIXME: don't delete, otherwise interceptor will use undefined node TaskNode* node_a = @@ -77,9 +78,9 @@ TEST(ComputeInterceptor, Compute) { node_a->AddDownstreamTask(1); node_b->AddUpstreamTask(0); - auto* a = carrier.SetInterceptor( + auto* a = carrier->SetInterceptor( 0, InterceptorFactory::Create("Compute", 0, node_a)); - carrier.SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b)); + carrier->SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b)); a->SetPlace(place); a->SetMicroBatchScope(scopes); @@ -89,10 +90,10 @@ TEST(ComputeInterceptor, Compute) { msg.set_message_type(DATA_IS_READY); msg.set_src_id(-1); msg.set_dst_id(0); - carrier.EnqueueInterceptorMessage(msg); + carrier->EnqueueInterceptorMessage(msg); - carrier.Wait(); - carrier.Release(); + carrier->Wait(); + carrier->Release(); } } // namespace distributed diff --git a/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc b/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc index 5b1c0de6f9ce550cf3ec46de201a00226da3660e..862a037414dc976c87a47afdd1d4de9a5656cc96 100644 --- a/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc @@ -18,7 +18,7 @@ limitations under the License. */ #include "gtest/gtest.h" #include "paddle/fluid/distributed/fleet_executor/carrier.h" -#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h" +#include "paddle/fluid/distributed/fleet_executor/global_map.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/distributed/fleet_executor/task_node.h" @@ -47,11 +47,12 @@ class StartInterceptor : public Interceptor { }; TEST(ComputeInterceptor, Compute) { - Carrier carrier(0, {{0, 0}, {1, 0}, {2, 0}}); + Carrier* carrier = GlobalMap::Create(0, 0); + carrier->Init(0, {{0, 0}, {1, 0}, {2, 0}}, {0, 1, 2}); auto msg_bus = std::make_shared(); msg_bus->Init(0, {{0, "127.0.0.0:0"}}, ""); - carrier.SetMsgBus(msg_bus); + carrier->SetMsgBus(msg_bus); // NOTE: don't delete, otherwise interceptor will use undefined node TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0); // role, rank, task_id @@ -65,9 +66,9 @@ TEST(ComputeInterceptor, Compute) { node_c->AddUpstreamTask(1); Interceptor* a = - carrier.SetInterceptor(0, std::make_unique(0, node_a)); - carrier.SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b)); - carrier.SetInterceptor(2, InterceptorFactory::Create("Compute", 2, node_c)); + carrier->SetInterceptor(0, std::make_unique(0, node_a)); + carrier->SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b)); + carrier->SetInterceptor(2, InterceptorFactory::Create("Compute", 2, node_c)); InterceptorMessage msg; msg.set_message_type(DATA_IS_READY); @@ -76,8 +77,8 @@ TEST(ComputeInterceptor, Compute) { a->Send(1, msg); a->Send(1, msg); - carrier.Wait(); - carrier.Release(); + carrier->Wait(); + carrier->Release(); } } // namespace distributed diff --git a/paddle/fluid/distributed/fleet_executor/test/interceptor_pass_the_parcel_test.cc b/paddle/fluid/distributed/fleet_executor/test/interceptor_pass_the_parcel_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..17c01edbcb17fe7648c6476279fc7181fe5e9069 --- /dev/null +++ b/paddle/fluid/distributed/fleet_executor/test/interceptor_pass_the_parcel_test.cc @@ -0,0 +1,101 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "gtest/gtest.h" + +#include "paddle/fluid/distributed/fleet_executor/carrier.h" +#include "paddle/fluid/distributed/fleet_executor/global_map.h" +#include "paddle/fluid/distributed/fleet_executor/interceptor.h" +#include "paddle/fluid/distributed/fleet_executor/message_bus.h" + +namespace paddle { +namespace distributed { + +class ParcelInterceptor : public Interceptor { + public: + ParcelInterceptor(int64_t interceptor_id, TaskNode* node) + : Interceptor(interceptor_id, node) { + RegisterMsgHandle( + [this](const InterceptorMessage& msg) { PassParcel(msg); }); + } + + void PassParcel(const InterceptorMessage& msg) { + if (msg.message_type() == STOP) { + stop_ = true; + return; + } + std::cout << GetInterceptorId() << " recv msg, count=" << count_ + << std::endl; + if (count_ == 5 && interceptor_id_ == 0) { + InterceptorMessage stop; + stop.set_message_type(STOP); + Send(0, stop); + Send(1, stop); + Send(2, stop); + Send(3, stop); + StopCarrier(); + return; + } + ++count_; + InterceptorMessage new_msg; + if (msg.dst_id() == 3) { + Send(0, new_msg); + } else { + Send(msg.dst_id() + 1, new_msg); + } + } + + private: + int count_{0}; +}; + +REGISTER_INTERCEPTOR(Parcel, ParcelInterceptor); + +TEST(InterceptorTest, PassTheParcel) { + auto msg_bus = std::make_shared(); + Carrier* carrier_0 = GlobalMap::Create(0, 0); + carrier_0->Init(0, {{0, 0}, {1, 0}, {2, 0}, {3, 0}}, {0}); + carrier_0->SetMsgBus(msg_bus); + Carrier* carrier_1 = GlobalMap::Create(1, 1); + carrier_1->Init(0, {{0, 0}, {1, 0}, {2, 0}, {3, 0}}, {1}); + carrier_1->SetMsgBus(msg_bus); + Carrier* carrier_2 = GlobalMap::Create(2, 2); + carrier_2->Init(0, {{0, 0}, {1, 0}, {2, 0}, {3, 0}}, {2}); + carrier_2->SetMsgBus(msg_bus); + Carrier* carrier_3 = GlobalMap::Create(3, 3); + carrier_3->Init(0, {{0, 0}, {1, 0}, {2, 0}, {3, 0}}, {3}); + carrier_3->SetMsgBus(msg_bus); + msg_bus->Init(0, {{0, "127.0.0.0:0"}}, ""); + + Interceptor* a = carrier_0->SetInterceptor( + 0, InterceptorFactory::Create("Parcel", 0, nullptr)); + + carrier_1->SetInterceptor(1, + InterceptorFactory::Create("Parcel", 1, nullptr)); + carrier_2->SetInterceptor(2, + InterceptorFactory::Create("Parcel", 2, nullptr)); + carrier_3->SetInterceptor(3, + InterceptorFactory::Create("Parcel", 3, nullptr)); + + InterceptorMessage msg; + a->Send(1, msg); + + carrier_0->Wait(); +} + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc b/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc index 37f13dabb0787a9b02457013ecf6e684892d8f40..c1864c81a46e8e7360863e00b22f9e2e195876e2 100644 --- a/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc @@ -18,6 +18,7 @@ limitations under the License. */ #include "gtest/gtest.h" #include "paddle/fluid/distributed/fleet_executor/carrier.h" +#include "paddle/fluid/distributed/fleet_executor/global_map.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h" @@ -59,20 +60,21 @@ class PingPongInterceptor : public Interceptor { REGISTER_INTERCEPTOR(PingPong, PingPongInterceptor); TEST(InterceptorTest, PingPong) { - Carrier carrier(0, {{0, 0}, {1, 0}}); + Carrier* carrier = GlobalMap::Create(0, 0); + carrier->Init(0, {{0, 0}, {1, 0}}, {0, 1}); auto msg_bus = std::make_shared(); msg_bus->Init(0, {{0, "127.0.0.0:0"}}, ""); - carrier.SetMsgBus(msg_bus); + carrier->SetMsgBus(msg_bus); - Interceptor* a = carrier.SetInterceptor( + Interceptor* a = carrier->SetInterceptor( 0, InterceptorFactory::Create("PingPong", 0, nullptr)); - carrier.SetInterceptor(1, std::make_unique(1, nullptr)); + carrier->SetInterceptor(1, std::make_unique(1, nullptr)); InterceptorMessage msg; a->Send(1, msg); - carrier.Wait(); + carrier->Wait(); } } // namespace distributed diff --git a/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc b/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc index 16e40de77460aab9c4731d7c89ff41dc3d5b342f..482ec73931a2e1f710c620aa0b5078d230d10a14 100644 --- a/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc @@ -20,7 +20,7 @@ limitations under the License. */ #include "gtest/gtest.h" #include "paddle/fluid/distributed/fleet_executor/carrier.h" -#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h" +#include "paddle/fluid/distributed/fleet_executor/global_map.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h" @@ -107,42 +107,31 @@ TEST(InterceptorTest, PingPong) { std::unordered_map interceptor_id_to_rank = {{0, 0}, {1, 1}}; - int exe_pid = fork(); - if (exe_pid == 0) { - int pid = fork(); - if (pid == 0) { - Carrier* carrier = - FleetExecutor::CreateCarrier(0, interceptor_id_to_rank); - auto msg_bus = std::make_shared(); - carrier->SetMsgBus(msg_bus); - // NOTE: need Init msg_bus after carrier SetMsgBus - msg_bus->Init(0, {{0, ip0}, {1, ip1}}, ip0); - Interceptor* a = carrier->SetInterceptor( - 0, InterceptorFactory::Create("PingPong", 0, nullptr)); - carrier->Barrier(); - - InterceptorMessage msg; - a->Send(1, msg); - carrier->Wait(); - } else { - Carrier* carrier = - FleetExecutor::CreateCarrier(1, interceptor_id_to_rank); - auto msg_bus = std::make_shared(); - carrier->SetMsgBus(msg_bus); - msg_bus->Init(1, {{0, ip0}, {1, ip1}}, ip1); - carrier->SetInterceptor( - 1, InterceptorFactory::Create("PingPong", 1, nullptr)); - carrier->Barrier(); - - carrier->Wait(); - int status; - int ret = waitpid(pid, &status, 0); - CHECK_EQ(ret, pid); - } + int pid = fork(); + if (pid == 0) { + Carrier* carrier = GlobalMap::Create(0, 0); + auto msg_bus = std::make_shared(); + carrier->SetMsgBus(msg_bus); + // NOTE: need Init msg_bus after carrier SetMsgBus + carrier->Init(0, interceptor_id_to_rank, {0}); + msg_bus->Init(0, {{0, ip0}, {1, ip1}}, ip0); + carrier->SetMsgBus(msg_bus); + Interceptor* a = carrier->SetInterceptor( + 0, InterceptorFactory::Create("PingPong", 0, nullptr)); + msg_bus->Barrier(); + InterceptorMessage msg; + a->Send(1, msg); + carrier->Wait(); } else { - int status; - int ret = waitpid(exe_pid, &status, 0); - CHECK_EQ(ret, exe_pid); + Carrier* carrier = GlobalMap::Create(0, 0); + auto msg_bus = std::make_shared(); + carrier->SetMsgBus(msg_bus); + carrier->Init(1, interceptor_id_to_rank, {1}); + msg_bus->Init(1, {{0, ip0}, {1, ip1}}, ip1); + carrier->SetInterceptor(1, + InterceptorFactory::Create("PingPong", 1, nullptr)); + msg_bus->Barrier(); + carrier->Wait(); } } diff --git a/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc b/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc index 0e902f3d744c48020e5cdacb8543a342b05108b3..531f4002b45de1699526c0a0d2faa3a75688ceee 100644 --- a/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc @@ -18,6 +18,7 @@ limitations under the License. */ #include "gtest/gtest.h" #include "paddle/fluid/distributed/fleet_executor/carrier.h" +#include "paddle/fluid/distributed/fleet_executor/global_map.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/distributed/fleet_executor/task_node.h" @@ -51,10 +52,12 @@ void LinkNodes(const std::vector& nodes) { } TEST(AmplifierInterceptor, Amplifier) { - Carrier carrier(0, {{0, 0}, {1, 0}, {2, 0}, {3, 0}, {4, 0}, {5, 0}}); + Carrier* carrier = GlobalMap::Create(0, 0); + carrier->Init(0, {{0, 0}, {1, 0}, {2, 0}, {3, 0}, {4, 0}, {5, 0}}, + {0, 1, 2, 3, 4, 5}); auto msg_bus = std::make_shared(); msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "127.0.0.0:0"); - carrier.SetMsgBus(msg_bus); + carrier->SetMsgBus(msg_bus); int64_t micro_steps = 3; @@ -73,21 +76,23 @@ TEST(AmplifierInterceptor, Amplifier) { node_b->SetReplyUpPerSteps(micro_steps); node_e->SetSendDownPerSteps(micro_steps); - carrier.SetInterceptor(0, InterceptorFactory::Create("Compute", 0, node_a)); - carrier.SetInterceptor(1, InterceptorFactory::Create("Amplifier", 1, node_b)); - carrier.SetInterceptor(2, InterceptorFactory::Create("Compute", 2, node_c)); - carrier.SetInterceptor(3, InterceptorFactory::Create("Compute", 3, node_d)); - carrier.SetInterceptor(4, InterceptorFactory::Create("Amplifier", 4, node_e)); - carrier.SetInterceptor(5, InterceptorFactory::Create("Compute", 5, node_f)); + carrier->SetInterceptor(0, InterceptorFactory::Create("Compute", 0, node_a)); + carrier->SetInterceptor(1, + InterceptorFactory::Create("Amplifier", 1, node_b)); + carrier->SetInterceptor(2, InterceptorFactory::Create("Compute", 2, node_c)); + carrier->SetInterceptor(3, InterceptorFactory::Create("Compute", 3, node_d)); + carrier->SetInterceptor(4, + InterceptorFactory::Create("Amplifier", 4, node_e)); + carrier->SetInterceptor(5, InterceptorFactory::Create("Compute", 5, node_f)); // start InterceptorMessage msg; msg.set_message_type(DATA_IS_READY); msg.set_src_id(-1); msg.set_dst_id(0); - carrier.EnqueueInterceptorMessage(msg); - carrier.Wait(); - carrier.Release(); + carrier->EnqueueInterceptorMessage(msg); + carrier->Wait(); + carrier->Release(); } } // namespace distributed diff --git a/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc b/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc index d84b909eec95f0a323baeebe8beb99bf144ea318..a6f1692f6374f52cbdbaf754d0d6e016f3c164ec 100644 --- a/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc @@ -18,6 +18,7 @@ limitations under the License. */ #include "gtest/gtest.h" #include "paddle/fluid/distributed/fleet_executor/carrier.h" +#include "paddle/fluid/distributed/fleet_executor/global_map.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/distributed/fleet_executor/task_node.h" @@ -69,10 +70,11 @@ void LinkNodes(const std::vector& nodes, } TEST(AmplifierInterceptor, Amplifier) { - Carrier carrier(0, {{0, 0}, {1, 0}, {2, 0}, {3, 0}}); + Carrier* carrier = GlobalMap::Create(0, 0); + carrier->Init(0, {{0, 0}, {1, 0}, {2, 0}, {3, 0}}, {0, 1, 2, 3}); auto msg_bus = std::make_shared(); msg_bus->Init(0, {{0, ""}}, ""); - carrier.SetMsgBus(msg_bus); + carrier->SetMsgBus(msg_bus); int64_t micro_steps = 6; @@ -91,19 +93,21 @@ TEST(AmplifierInterceptor, Amplifier) { node_d->SetRunPerSteps(micro_steps); node_d->SetRunAtOffset(micro_steps - 1); - carrier.SetInterceptor(0, InterceptorFactory::Create("Amplifier", 0, node_a)); - carrier.SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b)); - carrier.SetInterceptor(2, InterceptorFactory::Create("Compute", 2, node_c)); - carrier.SetInterceptor(3, InterceptorFactory::Create("Amplifier", 3, node_d)); + carrier->SetInterceptor(0, + InterceptorFactory::Create("Amplifier", 0, node_a)); + carrier->SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b)); + carrier->SetInterceptor(2, InterceptorFactory::Create("Compute", 2, node_c)); + carrier->SetInterceptor(3, + InterceptorFactory::Create("Amplifier", 3, node_d)); // start InterceptorMessage msg; msg.set_message_type(DATA_IS_READY); msg.set_src_id(-1); msg.set_dst_id(0); - carrier.EnqueueInterceptorMessage(msg); - carrier.Wait(); - carrier.Release(); + carrier->EnqueueInterceptorMessage(msg); + carrier->Wait(); + carrier->Release(); } } // namespace distributed