diff --git a/paddle/fluid/distributed/fleet_executor/carrier.cc b/paddle/fluid/distributed/fleet_executor/carrier.cc index d1c57ccfddd41c25d1947e80102feb312f174699..45296853adf7b4577eceaa7014a85fe745598bb7 100644 --- a/paddle/fluid/distributed/fleet_executor/carrier.cc +++ b/paddle/fluid/distributed/fleet_executor/carrier.cc @@ -30,11 +30,9 @@ USE_INTERCEPTOR(Amplifier); void Carrier::Init( int64_t rank, - const std::unordered_map& interceptor_id_to_rank, - const std::unordered_set& interceptor_ids) { + const std::unordered_map& interceptor_id_to_rank) { rank_ = rank; interceptor_id_to_rank_ = interceptor_id_to_rank; - interceptor_ids_ = interceptor_ids; // TODO(fleet_exe dev): thread pool thread_num_ = 1; @@ -45,14 +43,12 @@ void Carrier::Init( void Carrier::Init( int64_t rank, const std::unordered_map& interceptor_id_to_rank, - const std::unordered_set& interceptor_ids, const std::unordered_map& interceptor_id_to_node, framework::Scope* root_scope, framework::Scope* minibatch_scope, const std::vector& microbatch_scopes, const platform::Place& place) { rank_ = rank; interceptor_id_to_rank_ = interceptor_id_to_rank; - interceptor_ids_ = interceptor_ids; interceptor_id_to_node_ = interceptor_id_to_node; minibatch_scope_ = minibatch_scope; microbatch_scopes_ = microbatch_scopes; @@ -156,9 +152,7 @@ bool Carrier::Send(const InterceptorMessage& msg) { if (src_rank == dst_rank) { VLOG(3) << "Send a message from interceptor " << src_id << " to interceptor " << dst_id << ", which are in the same ranks."; - int64_t carrier_id = *GlobalMap::Get(dst_id); - return GlobalMap::Get(carrier_id) - ->EnqueueInterceptorMessage(msg); + return EnqueueInterceptorMessage(msg); } else { PADDLE_ENFORCE_NOT_NULL( msg_bus_.get(), @@ -192,9 +186,6 @@ Interceptor* Carrier::SetInterceptor(int64_t interceptor_id, loop, platform::errors::Fatal("thread task loop must not null")); interceptor->RegisterTaskLoop(loop); - // TODO(liyurui): Using struct InterceptorID replace int64_t - GlobalMap::Create(interceptor_id, carrier_id_); - auto* ptr = interceptor.get(); interceptor_idx_to_interceptor_.insert( std::make_pair(interceptor_id, std::move(interceptor))); @@ -220,19 +211,15 @@ static std::shared_ptr GetGC( } void Carrier::CreateInterceptors() { - if (interceptor_ids_.empty()) return; + if (interceptor_id_to_node_.empty()) return; auto gc = GetGC(place_); // create each Interceptor // no auto init since there is no config - for (int64_t interceptor_id : interceptor_ids_) { - const auto& task_node_iter = interceptor_id_to_node_.find(interceptor_id); - PADDLE_ENFORCE_NE( - task_node_iter, interceptor_id_to_node_.end(), - platform::errors::NotFound("Can not find task node for interceptor %ld", - interceptor_id)); - TaskNode* task_node = task_node_iter->second; + for (const auto& item : interceptor_id_to_node_) { + int64_t interceptor_id = item.first; + TaskNode* task_node = item.second; PADDLE_ENFORCE_LT( task_node->run_at_offset(), task_node->run_per_steps(), diff --git a/paddle/fluid/distributed/fleet_executor/carrier.h b/paddle/fluid/distributed/fleet_executor/carrier.h index 87356f9ea8ddf38ced774d8496a496d184d4db9f..cd70ab46ce58e84eebae34e44fe35c8b0e00bd33 100644 --- a/paddle/fluid/distributed/fleet_executor/carrier.h +++ b/paddle/fluid/distributed/fleet_executor/carrier.h @@ -43,17 +43,17 @@ class InterceptorMessageServiceImpl; class RuntimeGraph; class MessageBus; +// TODO(liyurui): Add CarrierId instead of std::string + class Carrier final { public: - explicit Carrier(int64_t carrier_id) : carrier_id_(carrier_id) {} + explicit Carrier(const std::string& carrier_id) : carrier_id_(carrier_id) {} ~Carrier(); void Init(int64_t rank, - const std::unordered_map& interceptor_id_to_rank, - const std::unordered_set& interceptor_ids); + const std::unordered_map& interceptor_id_to_rank); void Init( int64_t rank, const std::unordered_map& interceptor_id_to_rank, - const std::unordered_set& interceptor_ids, const std::unordered_map& interceptor_id_to_node, framework::Scope* root_scope, framework::Scope* minibatch_scope, const std::vector& microbatch_scopes, @@ -109,7 +109,7 @@ class Carrier final { paddle::platform::DeviceContext* dev_ctx_{nullptr}; std::shared_ptr msg_bus_; int64_t rank_; - int64_t carrier_id_; + std::string carrier_id_; std::unordered_map interceptor_id_to_node_; std::unordered_map interceptor_id_to_rank_; int thread_num_; diff --git a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc index 29e9e0861831bff75b317c5fd6a0aada72d238d0..f81cdf200c65d7044b1a7d71482aa9194c10e641 100644 --- a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc +++ b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc @@ -36,14 +36,15 @@ FleetExecutor::FleetExecutor(const std::string& exe_desc_str) { FleetExecutor::~FleetExecutor() { root_scope_->DropKids(); - for (const auto& item : runtime_graph_->carrier_id_to_interceptor_ids()) { - GlobalMap::Get(item.first)->Release(); + for (const auto& carrier_id : carrier_ids_) { + GlobalMap::Get(carrier_id)->Release(); } } void FleetExecutor::Init( - const framework::ProgramDesc& program_desc, framework::Scope* scope, - const platform::Place& place, const std::vector& task_nodes, + const std::string& carrier_id, const framework::ProgramDesc& program_desc, + framework::Scope* scope, const platform::Place& place, + const std::vector& task_nodes, const std::unordered_map& task_id_to_rank) { PADDLE_ENFORCE_GT(task_nodes.size(), 0, platform::errors::InvalidArgument( @@ -58,19 +59,13 @@ void FleetExecutor::Init( auto unused_vars = framework::GetUnusedVars(program_desc.Block(0), ops, {}); runtime_graph_ = std::make_shared(); std::unordered_map interceptor_id_to_task; - std::unordered_map> - carrier_id_to_interceptor_ids; - std::unordered_set interceptor_ids; for (auto task_node : task_nodes) { task_node->SetUnusedVars(unused_vars); int64_t interceptor_id = task_node->task_id(); interceptor_id_to_task.emplace(interceptor_id, task_node); - interceptor_ids.insert(interceptor_id); } - carrier_id_to_interceptor_ids.emplace(0, interceptor_ids); runtime_graph_->SetInterceptorIdToRank(task_id_to_rank); runtime_graph_->SetInterceptorIdToNode(interceptor_id_to_task); - runtime_graph_->SetCarrierIdToInterceptorIds(carrier_id_to_interceptor_ids); for (auto& unique_op : ops) { unique_op.release(); } @@ -87,27 +82,23 @@ void FleetExecutor::Init( } VLOG(5) << runtime_graph_->DebugString(); msg_bus_ = std::make_shared(); - for (const auto& item : runtime_graph_->carrier_id_to_interceptor_ids()) { - GlobalMap::Create(item.first, item.first); - } - InitCarrier(); + Carrier* carrier = + GlobalMap::Create(carrier_id, carrier_id); + carrier_ids_.insert(carrier_id); + GlobalVal::Set(carrier_id); + // TODO(liyurui): Maybe message bus should be created only once + InitCarrier(carrier); InitMessageBus(); // Wait for all message bus connected. msg_bus_->Barrier(); } -void FleetExecutor::InitCarrier() { - for (const auto& item : runtime_graph_->carrier_id_to_interceptor_ids()) { - Carrier* carrier = GlobalMap::Get(item.first); - PADDLE_ENFORCE_NOT_NULL(carrier, platform::errors::InvalidArgument( - "Carrier has not been created.")); - carrier->SetMsgBus(msg_bus_); - carrier->Init(exe_desc_.cur_rank(), - runtime_graph_->interceptor_id_to_rank(), item.second, - runtime_graph_->interceptor_id_to_node(), root_scope_, - minibatch_scope_, microbatch_scopes_, place_); - } +void FleetExecutor::InitCarrier(Carrier* carrier) { + carrier->SetMsgBus(msg_bus_); + carrier->Init(exe_desc_.cur_rank(), runtime_graph_->interceptor_id_to_rank(), + runtime_graph_->interceptor_id_to_node(), root_scope_, + minibatch_scope_, microbatch_scopes_, place_); } void FleetExecutor::InitMessageBus() { @@ -145,10 +136,9 @@ void FleetExecutor::InitMessageBus() { } } -void FleetExecutor::Run() { - for (const auto& item : runtime_graph_->carrier_id_to_interceptor_ids()) { - GlobalMap::Get(item.first)->Start(); - } +void FleetExecutor::Run(const std::string& carrier_id) { + GlobalMap::Get(carrier_id)->Start(); + GlobalVal::Set(carrier_id); for (auto* micro_scop : microbatch_scopes_) { // By default, we should delete all kid scopes after run executor because // some operators may create local scope when running, such as while_op. diff --git a/paddle/fluid/distributed/fleet_executor/fleet_executor.h b/paddle/fluid/distributed/fleet_executor/fleet_executor.h index e65155237d0878df10dbc3d88ccbf544a02db63e..33b7d4a40dc3bba6a8cb7a5be5b4ffdea2e959c8 100644 --- a/paddle/fluid/distributed/fleet_executor/fleet_executor.h +++ b/paddle/fluid/distributed/fleet_executor/fleet_executor.h @@ -37,16 +37,17 @@ class FleetExecutor final { FleetExecutor() = delete; explicit FleetExecutor(const std::string& exe_desc_str); ~FleetExecutor(); - void Init(const framework::ProgramDesc& program_desc, framework::Scope* scope, + void Init(const std::string& carrier_id, + const framework::ProgramDesc& program_desc, framework::Scope* scope, const platform::Place& place, const std::vector& task_nodes, const std::unordered_map& task_id_to_rank); - void Run(); + void Run(const std::string& carrier_id); private: DISABLE_COPY_AND_ASSIGN(FleetExecutor); void InitMessageBus(); - void InitCarrier(); + void InitCarrier(Carrier* carrier); void CopyParameters(int microbatch_id, const framework::ProgramDesc& program); FleetExecutorDesc exe_desc_; std::shared_ptr runtime_graph_; @@ -57,6 +58,7 @@ class FleetExecutor final { // The carriers under FleetExecutor will share message bus, // using shared_ptr to manage lifetime and condition race. std::shared_ptr msg_bus_; + std::unordered_set carrier_ids_; }; } // namespace distributed diff --git a/paddle/fluid/distributed/fleet_executor/global_map.h b/paddle/fluid/distributed/fleet_executor/global_map.h index ef50563d3f10d4fbe9fa3fe972f40ec44c73e856..2e2923e447d299e27d3511aa7143a0fb950d18a7 100644 --- a/paddle/fluid/distributed/fleet_executor/global_map.h +++ b/paddle/fluid/distributed/fleet_executor/global_map.h @@ -17,6 +17,24 @@ namespace paddle { namespace distributed { +// TODO(liyurui): Change this file to global.h +template +class GlobalVal final { + public: + static T Get() { return *GetPtr(); } + static T Set(T val) { + auto* ptr = GetPtr(); + *ptr = val; + return val; + } + + private: + static T* GetPtr() { + static T value; + return &value; + } +}; + template class GlobalMap final { public: @@ -26,6 +44,7 @@ class GlobalMap final { item, platform::errors::NotFound("This value is not in global map.")); return item; } + template static ValueT* Create(KeyT id, Args&&... args) { auto* ptr = GetPPtr(id); @@ -37,6 +56,34 @@ class GlobalMap final { return item; } + private: + static std::unique_ptr* GetPPtr(KeyT id) { + static std::unordered_map> id_to_ptr; + return &id_to_ptr[id]; + } +}; + +template +class ThreadSafeGlobalMap final { + public: + static ValueT* Get(KeyT id) { + ValueT* item = GetPPtr(id)->get(); + PADDLE_ENFORCE_NOT_NULL( + item, platform::errors::NotFound( + "This value is not in thread safe global map.")); + return item; + } + template + static ValueT* Create(KeyT id, Args&&... args) { + auto* ptr = GetPPtr(id); + PADDLE_ENFORCE_EQ(ptr->get(), nullptr, + platform::errors::AlreadyExists( + "This value has already in thread safe global map.")); + ValueT* item = new ValueT(std::forward(args)...); + ptr->reset(item); + return item; + } + private: static std::unique_ptr* GetPPtr(KeyT id) { static std::mutex mutex; diff --git a/paddle/fluid/distributed/fleet_executor/interceptor_message_service.cc b/paddle/fluid/distributed/fleet_executor/interceptor_message_service.cc index e939ebae13464a79eced6537512b054f58547e68..52be135f1ce42d68509fece254882695bb72d89c 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor_message_service.cc +++ b/paddle/fluid/distributed/fleet_executor/interceptor_message_service.cc @@ -29,14 +29,8 @@ void InterceptorMessageServiceImpl::InterceptorMessageService( VLOG(3) << "Interceptor Message Service receives a message from interceptor " << request->src_id() << " to interceptor " << request->dst_id() << ", with the message: " << request->message_type(); - // TODO(liyurui): Remove this hard code. - int64_t carrier_id; - if (request->ctrl_message()) { - carrier_id = 0; - } else { - carrier_id = *GlobalMap::Get(request->dst_id()); - } - bool flag = GlobalMap::Get(carrier_id) + const auto& carrier_id = GlobalVal::Get(); + bool flag = GlobalMap::Get(carrier_id) ->EnqueueInterceptorMessage(*request); response->set_rst(flag); } diff --git a/paddle/fluid/distributed/fleet_executor/runtime_graph.h b/paddle/fluid/distributed/fleet_executor/runtime_graph.h index 6342aae483814df9a0d6c1b7bbdd96914fd29b6a..1ca9f0174ed07f3c12a8fb937799cfc4dd444b37 100644 --- a/paddle/fluid/distributed/fleet_executor/runtime_graph.h +++ b/paddle/fluid/distributed/fleet_executor/runtime_graph.h @@ -35,10 +35,6 @@ class RuntimeGraph final { const std::unordered_map& interceptor_id_to_rank() const { return interceptor_id_to_rank_; } - const std::unordered_map>& - carrier_id_to_interceptor_ids() const { - return carrier_id_to_interceptor_ids_; - } void SetInterceptorIdToRank( const std::unordered_map& interceptor_id_to_rank) { interceptor_id_to_rank_ = interceptor_id_to_rank; @@ -47,19 +43,12 @@ class RuntimeGraph final { const std::unordered_map& interceptor_id_to_node) { interceptor_id_to_node_ = interceptor_id_to_node; } - void SetCarrierIdToInterceptorIds( - const std::unordered_map>& - carrier_id_to_interceptor_ids) { - carrier_id_to_interceptor_ids_ = carrier_id_to_interceptor_ids; - } std::string DebugString() const; private: DISABLE_COPY_AND_ASSIGN(RuntimeGraph); std::unordered_map interceptor_id_to_node_; std::unordered_map interceptor_id_to_rank_; - std::unordered_map> - carrier_id_to_interceptor_ids_; }; } // namespace distributed diff --git a/paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt b/paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt index e8b97aa508b8fce08a77e0c09c12ada573f1b028..d4587b90c87f3deadba686230728ed084b2a18ad 100644 --- a/paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt +++ b/paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt @@ -13,9 +13,6 @@ cc_test(interceptor_pipeline_long_path_test SRCS interceptor_pipeline_long_path_ set_source_files_properties(compute_interceptor_run_op_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) cc_test(compute_interceptor_run_op_test SRCS compute_interceptor_run_op_test.cc DEPS fleet_executor ${BRPC_DEPS} op_registry fill_constant_op elementwise_add_op scope device_context) -set_source_files_properties(interceptor_pass_the_parcel_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) -cc_test(interceptor_pass_the_parcel_test SRCS interceptor_pass_the_parcel_test.cc DEPS fleet_executor ${BRPC_DEPS}) - if(WITH_DISTRIBUTE AND WITH_PSCORE AND NOT (WITH_ASCEND OR WITH_ASCEND_CL)) set_source_files_properties(interceptor_ping_pong_with_brpc_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) cc_test(interceptor_ping_pong_with_brpc_test SRCS interceptor_ping_pong_with_brpc_test.cc DEPS fleet_executor ${BRPC_DEPS}) diff --git a/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc b/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc index f4b5f70948a04f7610052a9538bedf846e0a9468..c48fd0962379579030b6b44624e95c4dc286a9ec 100644 --- a/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc @@ -62,8 +62,10 @@ TEST(ComputeInterceptor, Compute) { std::vector scopes = {scope, scope}; platform::Place place = platform::CPUPlace(); - Carrier* carrier = GlobalMap::Create(0, 0); - carrier->Init(0, {{0, 0}, {1, 0}}, {0, 1}); + std::string carrier_id = "0"; + Carrier* carrier = + GlobalMap::Create(carrier_id, carrier_id); + carrier->Init(0, {{0, 0}, {1, 0}}); auto msg_bus = std::make_shared(); msg_bus->Init(0, {{0, "127.0.0.0:0"}}, ""); diff --git a/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc b/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc index 862a037414dc976c87a47afdd1d4de9a5656cc96..f34f862c6285c7a0e44069d48ee96a050c83db37 100644 --- a/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc @@ -47,8 +47,10 @@ class StartInterceptor : public Interceptor { }; TEST(ComputeInterceptor, Compute) { - Carrier* carrier = GlobalMap::Create(0, 0); - carrier->Init(0, {{0, 0}, {1, 0}, {2, 0}}, {0, 1, 2}); + std::string carrier_id = "0"; + Carrier* carrier = + GlobalMap::Create(carrier_id, carrier_id); + carrier->Init(0, {{0, 0}, {1, 0}, {2, 0}}); auto msg_bus = std::make_shared(); msg_bus->Init(0, {{0, "127.0.0.0:0"}}, ""); diff --git a/paddle/fluid/distributed/fleet_executor/test/interceptor_pass_the_parcel_test.cc b/paddle/fluid/distributed/fleet_executor/test/interceptor_pass_the_parcel_test.cc deleted file mode 100644 index 17c01edbcb17fe7648c6476279fc7181fe5e9069..0000000000000000000000000000000000000000 --- a/paddle/fluid/distributed/fleet_executor/test/interceptor_pass_the_parcel_test.cc +++ /dev/null @@ -1,101 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include - -#include "gtest/gtest.h" - -#include "paddle/fluid/distributed/fleet_executor/carrier.h" -#include "paddle/fluid/distributed/fleet_executor/global_map.h" -#include "paddle/fluid/distributed/fleet_executor/interceptor.h" -#include "paddle/fluid/distributed/fleet_executor/message_bus.h" - -namespace paddle { -namespace distributed { - -class ParcelInterceptor : public Interceptor { - public: - ParcelInterceptor(int64_t interceptor_id, TaskNode* node) - : Interceptor(interceptor_id, node) { - RegisterMsgHandle( - [this](const InterceptorMessage& msg) { PassParcel(msg); }); - } - - void PassParcel(const InterceptorMessage& msg) { - if (msg.message_type() == STOP) { - stop_ = true; - return; - } - std::cout << GetInterceptorId() << " recv msg, count=" << count_ - << std::endl; - if (count_ == 5 && interceptor_id_ == 0) { - InterceptorMessage stop; - stop.set_message_type(STOP); - Send(0, stop); - Send(1, stop); - Send(2, stop); - Send(3, stop); - StopCarrier(); - return; - } - ++count_; - InterceptorMessage new_msg; - if (msg.dst_id() == 3) { - Send(0, new_msg); - } else { - Send(msg.dst_id() + 1, new_msg); - } - } - - private: - int count_{0}; -}; - -REGISTER_INTERCEPTOR(Parcel, ParcelInterceptor); - -TEST(InterceptorTest, PassTheParcel) { - auto msg_bus = std::make_shared(); - Carrier* carrier_0 = GlobalMap::Create(0, 0); - carrier_0->Init(0, {{0, 0}, {1, 0}, {2, 0}, {3, 0}}, {0}); - carrier_0->SetMsgBus(msg_bus); - Carrier* carrier_1 = GlobalMap::Create(1, 1); - carrier_1->Init(0, {{0, 0}, {1, 0}, {2, 0}, {3, 0}}, {1}); - carrier_1->SetMsgBus(msg_bus); - Carrier* carrier_2 = GlobalMap::Create(2, 2); - carrier_2->Init(0, {{0, 0}, {1, 0}, {2, 0}, {3, 0}}, {2}); - carrier_2->SetMsgBus(msg_bus); - Carrier* carrier_3 = GlobalMap::Create(3, 3); - carrier_3->Init(0, {{0, 0}, {1, 0}, {2, 0}, {3, 0}}, {3}); - carrier_3->SetMsgBus(msg_bus); - msg_bus->Init(0, {{0, "127.0.0.0:0"}}, ""); - - Interceptor* a = carrier_0->SetInterceptor( - 0, InterceptorFactory::Create("Parcel", 0, nullptr)); - - carrier_1->SetInterceptor(1, - InterceptorFactory::Create("Parcel", 1, nullptr)); - carrier_2->SetInterceptor(2, - InterceptorFactory::Create("Parcel", 2, nullptr)); - carrier_3->SetInterceptor(3, - InterceptorFactory::Create("Parcel", 3, nullptr)); - - InterceptorMessage msg; - a->Send(1, msg); - - carrier_0->Wait(); -} - -} // namespace distributed -} // namespace paddle diff --git a/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc b/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc index c1864c81a46e8e7360863e00b22f9e2e195876e2..8289eab16750045bb677c307a27775b856e33dbd 100644 --- a/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc @@ -60,8 +60,10 @@ class PingPongInterceptor : public Interceptor { REGISTER_INTERCEPTOR(PingPong, PingPongInterceptor); TEST(InterceptorTest, PingPong) { - Carrier* carrier = GlobalMap::Create(0, 0); - carrier->Init(0, {{0, 0}, {1, 0}}, {0, 1}); + std::string carrier_id = "0"; + Carrier* carrier = + GlobalMap::Create(carrier_id, carrier_id); + carrier->Init(0, {{0, 0}, {1, 0}}); auto msg_bus = std::make_shared(); msg_bus->Init(0, {{0, "127.0.0.0:0"}}, ""); carrier->SetMsgBus(msg_bus); diff --git a/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc b/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc index 482ec73931a2e1f710c620aa0b5078d230d10a14..f7adf59a6e81996a0c9df5897f384015454817a6 100644 --- a/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc @@ -106,16 +106,18 @@ TEST(InterceptorTest, PingPong) { std::cout << "ip1: " << ip1 << std::endl; std::unordered_map interceptor_id_to_rank = {{0, 0}, {1, 1}}; + std::string carrier_id = "0"; int pid = fork(); if (pid == 0) { - Carrier* carrier = GlobalMap::Create(0, 0); + Carrier* carrier = + GlobalMap::Create(carrier_id, carrier_id); + GlobalVal::Set(carrier_id); auto msg_bus = std::make_shared(); carrier->SetMsgBus(msg_bus); // NOTE: need Init msg_bus after carrier SetMsgBus - carrier->Init(0, interceptor_id_to_rank, {0}); + carrier->Init(0, interceptor_id_to_rank); msg_bus->Init(0, {{0, ip0}, {1, ip1}}, ip0); - carrier->SetMsgBus(msg_bus); Interceptor* a = carrier->SetInterceptor( 0, InterceptorFactory::Create("PingPong", 0, nullptr)); msg_bus->Barrier(); @@ -123,10 +125,12 @@ TEST(InterceptorTest, PingPong) { a->Send(1, msg); carrier->Wait(); } else { - Carrier* carrier = GlobalMap::Create(0, 0); + Carrier* carrier = + GlobalMap::Create(carrier_id, carrier_id); + GlobalVal::Set(carrier_id); auto msg_bus = std::make_shared(); carrier->SetMsgBus(msg_bus); - carrier->Init(1, interceptor_id_to_rank, {1}); + carrier->Init(1, interceptor_id_to_rank); msg_bus->Init(1, {{0, ip0}, {1, ip1}}, ip1); carrier->SetInterceptor(1, InterceptorFactory::Create("PingPong", 1, nullptr)); diff --git a/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc b/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc index 531f4002b45de1699526c0a0d2faa3a75688ceee..2cd0813803f0c422604dc19144fa200acc96637a 100644 --- a/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc @@ -52,9 +52,10 @@ void LinkNodes(const std::vector& nodes) { } TEST(AmplifierInterceptor, Amplifier) { - Carrier* carrier = GlobalMap::Create(0, 0); - carrier->Init(0, {{0, 0}, {1, 0}, {2, 0}, {3, 0}, {4, 0}, {5, 0}}, - {0, 1, 2, 3, 4, 5}); + std::string carrier_id = "0"; + Carrier* carrier = + GlobalMap::Create(carrier_id, carrier_id); + carrier->Init(0, {{0, 0}, {1, 0}, {2, 0}, {3, 0}, {4, 0}, {5, 0}}); auto msg_bus = std::make_shared(); msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "127.0.0.0:0"); carrier->SetMsgBus(msg_bus); diff --git a/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc b/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc index a6f1692f6374f52cbdbaf754d0d6e016f3c164ec..66c283b65fb76b3776b9f7d3094a5d9b03da4ef3 100644 --- a/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc @@ -70,8 +70,10 @@ void LinkNodes(const std::vector& nodes, } TEST(AmplifierInterceptor, Amplifier) { - Carrier* carrier = GlobalMap::Create(0, 0); - carrier->Init(0, {{0, 0}, {1, 0}, {2, 0}, {3, 0}}, {0, 1, 2, 3}); + std::string carrier_id = "0"; + Carrier* carrier = + GlobalMap::Create(carrier_id, carrier_id); + carrier->Init(0, {{0, 0}, {1, 0}, {2, 0}, {3, 0}}); auto msg_bus = std::make_shared(); msg_bus->Init(0, {{0, ""}}, ""); carrier->SetMsgBus(msg_bus); diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index 710e86e14b51c24cc67aa985f01c7112229b14d9..995fa188e25128b771c23ead8bc7195127832b5c 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -1956,7 +1956,11 @@ class Executor(object): return ctx - def _prepare_fleet_executor(self, program=None, scope=None, fleet_opt=None): + def _prepare_fleet_executor(self, + carrier_id="", + program=None, + scope=None, + fleet_opt=None): from ..distributed.fleet.proto import fleet_executor_desc_pb2 assert program, "Program for fleet executor should not be None" assert fleet_opt, "Configurations for fleet executor should not be None" @@ -2014,7 +2018,8 @@ class Executor(object): fleet_exe = core.FleetExecutor(fleet_exe_desc.SerializeToString()) place = core.Place() place.set_place(self.place) - fleet_exe.init(program.desc, scope, place, tasks, task_id_to_rank) + fleet_exe.init(carrier_id, program.desc, scope, place, tasks, + task_id_to_rank) return fleet_exe def _run_using_fleet_executor(self, @@ -2023,6 +2028,7 @@ class Executor(object): feed_var_name="feed", fetch_var_name="fetch", fetch_list=None): + # TODO(liyurui): Change cache strategy for multi carriers cache_key = _get_strong_program_cache_key(program, feed, fetch_list) cached_ctx = self._get_ctx_cache(cache_key) cached_scope = self._get_scope_cache(cache_key) @@ -2088,7 +2094,10 @@ class Executor(object): fetch_task.set_program(fetch_program) cached_ctx = self._prepare_fleet_executor( - program=cached_program, scope=cached_scope, fleet_opt=fleet_opt) + cache_key, + program=cached_program, + scope=cached_scope, + fleet_opt=fleet_opt) self._add_ctx_cache(cache_key, cached_ctx) if feed: # NOTE: don't have to traverse programs in task nodes, @@ -2107,7 +2116,7 @@ class Executor(object): lr_sheduler._var_name) tensor.set(data, self.place) - cached_ctx.run() + cached_ctx.run(cache_key) if fetch_list: arr = cached_scope.find_var(fetch_var_name).get_fetch_list() tensors = arr._move_to_list()