未验证 提交 2273471d 编写于 作者: L LiYuRio 提交者: GitHub

[fleet_executor] Support multi carriers (#38650)

上级 2d2609ea
...@@ -30,11 +30,9 @@ USE_INTERCEPTOR(Amplifier); ...@@ -30,11 +30,9 @@ USE_INTERCEPTOR(Amplifier);
void Carrier::Init( void Carrier::Init(
int64_t rank, int64_t rank,
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank, const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank) {
const std::unordered_set<int64_t>& interceptor_ids) {
rank_ = rank; rank_ = rank;
interceptor_id_to_rank_ = interceptor_id_to_rank; interceptor_id_to_rank_ = interceptor_id_to_rank;
interceptor_ids_ = interceptor_ids;
// TODO(fleet_exe dev): thread pool // TODO(fleet_exe dev): thread pool
thread_num_ = 1; thread_num_ = 1;
...@@ -45,14 +43,12 @@ void Carrier::Init( ...@@ -45,14 +43,12 @@ void Carrier::Init(
void Carrier::Init( void Carrier::Init(
int64_t rank, int64_t rank,
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank, const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank,
const std::unordered_set<int64_t>& interceptor_ids,
const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node, const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node,
framework::Scope* root_scope, framework::Scope* minibatch_scope, framework::Scope* root_scope, framework::Scope* minibatch_scope,
const std::vector<framework::Scope*>& microbatch_scopes, const std::vector<framework::Scope*>& microbatch_scopes,
const platform::Place& place) { const platform::Place& place) {
rank_ = rank; rank_ = rank;
interceptor_id_to_rank_ = interceptor_id_to_rank; interceptor_id_to_rank_ = interceptor_id_to_rank;
interceptor_ids_ = interceptor_ids;
interceptor_id_to_node_ = interceptor_id_to_node; interceptor_id_to_node_ = interceptor_id_to_node;
minibatch_scope_ = minibatch_scope; minibatch_scope_ = minibatch_scope;
microbatch_scopes_ = microbatch_scopes; microbatch_scopes_ = microbatch_scopes;
...@@ -156,9 +152,7 @@ bool Carrier::Send(const InterceptorMessage& msg) { ...@@ -156,9 +152,7 @@ bool Carrier::Send(const InterceptorMessage& msg) {
if (src_rank == dst_rank) { if (src_rank == dst_rank) {
VLOG(3) << "Send a message from interceptor " << src_id VLOG(3) << "Send a message from interceptor " << src_id
<< " to interceptor " << dst_id << ", which are in the same ranks."; << " to interceptor " << dst_id << ", which are in the same ranks.";
int64_t carrier_id = *GlobalMap<int64_t, int64_t>::Get(dst_id); return EnqueueInterceptorMessage(msg);
return GlobalMap<int64_t, Carrier>::Get(carrier_id)
->EnqueueInterceptorMessage(msg);
} else { } else {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
msg_bus_.get(), msg_bus_.get(),
...@@ -192,9 +186,6 @@ Interceptor* Carrier::SetInterceptor(int64_t interceptor_id, ...@@ -192,9 +186,6 @@ Interceptor* Carrier::SetInterceptor(int64_t interceptor_id,
loop, platform::errors::Fatal("thread task loop must not null")); loop, platform::errors::Fatal("thread task loop must not null"));
interceptor->RegisterTaskLoop(loop); interceptor->RegisterTaskLoop(loop);
// TODO(liyurui): Using struct InterceptorID replace int64_t
GlobalMap<int64_t, int64_t>::Create(interceptor_id, carrier_id_);
auto* ptr = interceptor.get(); auto* ptr = interceptor.get();
interceptor_idx_to_interceptor_.insert( interceptor_idx_to_interceptor_.insert(
std::make_pair(interceptor_id, std::move(interceptor))); std::make_pair(interceptor_id, std::move(interceptor)));
...@@ -220,19 +211,15 @@ static std::shared_ptr<framework::GarbageCollector> GetGC( ...@@ -220,19 +211,15 @@ static std::shared_ptr<framework::GarbageCollector> GetGC(
} }
void Carrier::CreateInterceptors() { void Carrier::CreateInterceptors() {
if (interceptor_ids_.empty()) return; if (interceptor_id_to_node_.empty()) return;
auto gc = GetGC(place_); auto gc = GetGC(place_);
// create each Interceptor // create each Interceptor
// no auto init since there is no config // no auto init since there is no config
for (int64_t interceptor_id : interceptor_ids_) { for (const auto& item : interceptor_id_to_node_) {
const auto& task_node_iter = interceptor_id_to_node_.find(interceptor_id); int64_t interceptor_id = item.first;
PADDLE_ENFORCE_NE( TaskNode* task_node = item.second;
task_node_iter, interceptor_id_to_node_.end(),
platform::errors::NotFound("Can not find task node for interceptor %ld",
interceptor_id));
TaskNode* task_node = task_node_iter->second;
PADDLE_ENFORCE_LT( PADDLE_ENFORCE_LT(
task_node->run_at_offset(), task_node->run_per_steps(), task_node->run_at_offset(), task_node->run_per_steps(),
......
...@@ -43,17 +43,17 @@ class InterceptorMessageServiceImpl; ...@@ -43,17 +43,17 @@ class InterceptorMessageServiceImpl;
class RuntimeGraph; class RuntimeGraph;
class MessageBus; class MessageBus;
// TODO(liyurui): Add CarrierId instead of std::string
class Carrier final { class Carrier final {
public: public:
explicit Carrier(int64_t carrier_id) : carrier_id_(carrier_id) {} explicit Carrier(const std::string& carrier_id) : carrier_id_(carrier_id) {}
~Carrier(); ~Carrier();
void Init(int64_t rank, void Init(int64_t rank,
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank, const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank);
const std::unordered_set<int64_t>& interceptor_ids);
void Init( void Init(
int64_t rank, int64_t rank,
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank, const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank,
const std::unordered_set<int64_t>& interceptor_ids,
const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node, const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node,
framework::Scope* root_scope, framework::Scope* minibatch_scope, framework::Scope* root_scope, framework::Scope* minibatch_scope,
const std::vector<framework::Scope*>& microbatch_scopes, const std::vector<framework::Scope*>& microbatch_scopes,
...@@ -109,7 +109,7 @@ class Carrier final { ...@@ -109,7 +109,7 @@ class Carrier final {
paddle::platform::DeviceContext* dev_ctx_{nullptr}; paddle::platform::DeviceContext* dev_ctx_{nullptr};
std::shared_ptr<MessageBus> msg_bus_; std::shared_ptr<MessageBus> msg_bus_;
int64_t rank_; int64_t rank_;
int64_t carrier_id_; std::string carrier_id_;
std::unordered_map<int64_t, TaskNode*> interceptor_id_to_node_; std::unordered_map<int64_t, TaskNode*> interceptor_id_to_node_;
std::unordered_map<int64_t, int64_t> interceptor_id_to_rank_; std::unordered_map<int64_t, int64_t> interceptor_id_to_rank_;
int thread_num_; int thread_num_;
......
...@@ -36,14 +36,15 @@ FleetExecutor::FleetExecutor(const std::string& exe_desc_str) { ...@@ -36,14 +36,15 @@ FleetExecutor::FleetExecutor(const std::string& exe_desc_str) {
FleetExecutor::~FleetExecutor() { FleetExecutor::~FleetExecutor() {
root_scope_->DropKids(); root_scope_->DropKids();
for (const auto& item : runtime_graph_->carrier_id_to_interceptor_ids()) { for (const auto& carrier_id : carrier_ids_) {
GlobalMap<int64_t, Carrier>::Get(item.first)->Release(); GlobalMap<std::string, Carrier>::Get(carrier_id)->Release();
} }
} }
void FleetExecutor::Init( void FleetExecutor::Init(
const framework::ProgramDesc& program_desc, framework::Scope* scope, const std::string& carrier_id, const framework::ProgramDesc& program_desc,
const platform::Place& place, const std::vector<TaskNode*>& task_nodes, framework::Scope* scope, const platform::Place& place,
const std::vector<TaskNode*>& task_nodes,
const std::unordered_map<int64_t, int64_t>& task_id_to_rank) { const std::unordered_map<int64_t, int64_t>& task_id_to_rank) {
PADDLE_ENFORCE_GT(task_nodes.size(), 0, PADDLE_ENFORCE_GT(task_nodes.size(), 0,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
...@@ -58,19 +59,13 @@ void FleetExecutor::Init( ...@@ -58,19 +59,13 @@ void FleetExecutor::Init(
auto unused_vars = framework::GetUnusedVars(program_desc.Block(0), ops, {}); auto unused_vars = framework::GetUnusedVars(program_desc.Block(0), ops, {});
runtime_graph_ = std::make_shared<RuntimeGraph>(); runtime_graph_ = std::make_shared<RuntimeGraph>();
std::unordered_map<int64_t, TaskNode*> interceptor_id_to_task; std::unordered_map<int64_t, TaskNode*> interceptor_id_to_task;
std::unordered_map<int64_t, std::unordered_set<int64_t>>
carrier_id_to_interceptor_ids;
std::unordered_set<int64_t> interceptor_ids;
for (auto task_node : task_nodes) { for (auto task_node : task_nodes) {
task_node->SetUnusedVars(unused_vars); task_node->SetUnusedVars(unused_vars);
int64_t interceptor_id = task_node->task_id(); int64_t interceptor_id = task_node->task_id();
interceptor_id_to_task.emplace(interceptor_id, task_node); interceptor_id_to_task.emplace(interceptor_id, task_node);
interceptor_ids.insert(interceptor_id);
} }
carrier_id_to_interceptor_ids.emplace(0, interceptor_ids);
runtime_graph_->SetInterceptorIdToRank(task_id_to_rank); runtime_graph_->SetInterceptorIdToRank(task_id_to_rank);
runtime_graph_->SetInterceptorIdToNode(interceptor_id_to_task); runtime_graph_->SetInterceptorIdToNode(interceptor_id_to_task);
runtime_graph_->SetCarrierIdToInterceptorIds(carrier_id_to_interceptor_ids);
for (auto& unique_op : ops) { for (auto& unique_op : ops) {
unique_op.release(); unique_op.release();
} }
...@@ -87,27 +82,23 @@ void FleetExecutor::Init( ...@@ -87,27 +82,23 @@ void FleetExecutor::Init(
} }
VLOG(5) << runtime_graph_->DebugString(); VLOG(5) << runtime_graph_->DebugString();
msg_bus_ = std::make_shared<MessageBus>(); msg_bus_ = std::make_shared<MessageBus>();
for (const auto& item : runtime_graph_->carrier_id_to_interceptor_ids()) { Carrier* carrier =
GlobalMap<int64_t, Carrier>::Create(item.first, item.first); GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
} carrier_ids_.insert(carrier_id);
InitCarrier(); GlobalVal<std::string>::Set(carrier_id);
// TODO(liyurui): Maybe message bus should be created only once
InitCarrier(carrier);
InitMessageBus(); InitMessageBus();
// Wait for all message bus connected. // Wait for all message bus connected.
msg_bus_->Barrier(); msg_bus_->Barrier();
} }
void FleetExecutor::InitCarrier() { void FleetExecutor::InitCarrier(Carrier* carrier) {
for (const auto& item : runtime_graph_->carrier_id_to_interceptor_ids()) {
Carrier* carrier = GlobalMap<int64_t, Carrier>::Get(item.first);
PADDLE_ENFORCE_NOT_NULL(carrier, platform::errors::InvalidArgument(
"Carrier has not been created."));
carrier->SetMsgBus(msg_bus_); carrier->SetMsgBus(msg_bus_);
carrier->Init(exe_desc_.cur_rank(), carrier->Init(exe_desc_.cur_rank(), runtime_graph_->interceptor_id_to_rank(),
runtime_graph_->interceptor_id_to_rank(), item.second,
runtime_graph_->interceptor_id_to_node(), root_scope_, runtime_graph_->interceptor_id_to_node(), root_scope_,
minibatch_scope_, microbatch_scopes_, place_); minibatch_scope_, microbatch_scopes_, place_);
}
} }
void FleetExecutor::InitMessageBus() { void FleetExecutor::InitMessageBus() {
...@@ -145,10 +136,9 @@ void FleetExecutor::InitMessageBus() { ...@@ -145,10 +136,9 @@ void FleetExecutor::InitMessageBus() {
} }
} }
void FleetExecutor::Run() { void FleetExecutor::Run(const std::string& carrier_id) {
for (const auto& item : runtime_graph_->carrier_id_to_interceptor_ids()) { GlobalMap<std::string, Carrier>::Get(carrier_id)->Start();
GlobalMap<int64_t, Carrier>::Get(item.first)->Start(); GlobalVal<std::string>::Set(carrier_id);
}
for (auto* micro_scop : microbatch_scopes_) { for (auto* micro_scop : microbatch_scopes_) {
// By default, we should delete all kid scopes after run executor because // By default, we should delete all kid scopes after run executor because
// some operators may create local scope when running, such as while_op. // some operators may create local scope when running, such as while_op.
......
...@@ -37,16 +37,17 @@ class FleetExecutor final { ...@@ -37,16 +37,17 @@ class FleetExecutor final {
FleetExecutor() = delete; FleetExecutor() = delete;
explicit FleetExecutor(const std::string& exe_desc_str); explicit FleetExecutor(const std::string& exe_desc_str);
~FleetExecutor(); ~FleetExecutor();
void Init(const framework::ProgramDesc& program_desc, framework::Scope* scope, void Init(const std::string& carrier_id,
const framework::ProgramDesc& program_desc, framework::Scope* scope,
const platform::Place& place, const platform::Place& place,
const std::vector<TaskNode*>& task_nodes, const std::vector<TaskNode*>& task_nodes,
const std::unordered_map<int64_t, int64_t>& task_id_to_rank); const std::unordered_map<int64_t, int64_t>& task_id_to_rank);
void Run(); void Run(const std::string& carrier_id);
private: private:
DISABLE_COPY_AND_ASSIGN(FleetExecutor); DISABLE_COPY_AND_ASSIGN(FleetExecutor);
void InitMessageBus(); void InitMessageBus();
void InitCarrier(); void InitCarrier(Carrier* carrier);
void CopyParameters(int microbatch_id, const framework::ProgramDesc& program); void CopyParameters(int microbatch_id, const framework::ProgramDesc& program);
FleetExecutorDesc exe_desc_; FleetExecutorDesc exe_desc_;
std::shared_ptr<RuntimeGraph> runtime_graph_; std::shared_ptr<RuntimeGraph> runtime_graph_;
...@@ -57,6 +58,7 @@ class FleetExecutor final { ...@@ -57,6 +58,7 @@ class FleetExecutor final {
// The carriers under FleetExecutor share one message bus, // The carriers under FleetExecutor share one message bus,
// using a shared_ptr to manage its lifetime and avoid race conditions. // using a shared_ptr to manage its lifetime and avoid race conditions.
std::shared_ptr<MessageBus> msg_bus_; std::shared_ptr<MessageBus> msg_bus_;
std::unordered_set<std::string> carrier_ids_;
}; };
} // namespace distributed } // namespace distributed
......
...@@ -17,6 +17,24 @@ ...@@ -17,6 +17,24 @@
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
// TODO(liyurui): Change this file to global.h
// Holds a single value of type T that is shared by every caller using the
// same T. The value lives in a function-local static, so it is created on
// first use; until Set() is called it is a default-constructed T (zero for
// arithmetic types, because the object has static storage duration).
// No locking is performed around reads or writes.
template <typename T>
class GlobalVal final {
 public:
  // Returns a copy of the currently stored value.
  static T Get() {
    const T& current = Storage();
    return current;
  }

  // Replaces the stored value with `val` and hands `val` back to the caller.
  static T Set(T val) {
    Storage() = val;
    return val;
  }

 private:
  // Accessor for the one-and-only stored object of this instantiation.
  static T& Storage() {
    static T holder;
    return holder;
  }
};
template <typename KeyT, typename ValueT> template <typename KeyT, typename ValueT>
class GlobalMap final { class GlobalMap final {
public: public:
...@@ -26,6 +44,7 @@ class GlobalMap final { ...@@ -26,6 +44,7 @@ class GlobalMap final {
item, platform::errors::NotFound("This value is not in global map.")); item, platform::errors::NotFound("This value is not in global map."));
return item; return item;
} }
template <typename... Args> template <typename... Args>
static ValueT* Create(KeyT id, Args&&... args) { static ValueT* Create(KeyT id, Args&&... args) {
auto* ptr = GetPPtr(id); auto* ptr = GetPPtr(id);
...@@ -37,6 +56,34 @@ class GlobalMap final { ...@@ -37,6 +56,34 @@ class GlobalMap final {
return item; return item;
} }
private:
static std::unique_ptr<ValueT>* GetPPtr(KeyT id) {
  // One table per <KeyT, ValueT> instantiation, created on first call.
  static std::unordered_map<KeyT, std::unique_ptr<ValueT>> slots;
  // operator[] inserts an empty unique_ptr the first time an id is seen,
  // so the returned address is never null even for unknown ids.
  std::unique_ptr<ValueT>& slot = slots[id];
  return &slot;
}
};
template <typename KeyT, typename ValueT>
class ThreadSafeGlobalMap final {
public:
// Returns a non-owning pointer to the object stored under `id`.
// The slot is obtained from GetPPtr(id); ownership stays with the
// std::unique_ptr held in that slot. If the slot holds no object, the
// PADDLE_ENFORCE_NOT_NULL check below reports a NotFound error.
static ValueT* Get(KeyT id) {
ValueT* item = GetPPtr(id)->get();
PADDLE_ENFORCE_NOT_NULL(
item, platform::errors::NotFound(
"This value is not in thread safe global map."));
return item;
}
// Constructs a new ValueT from the forwarded `args` and stores it in the
// slot for `id`. The slot must currently be empty; otherwise the
// PADDLE_ENFORCE_EQ check reports an AlreadyExists error.
// Returns a non-owning pointer to the new object; the std::unique_ptr in the
// slot takes ownership via reset().
template <typename... Args>
static ValueT* Create(KeyT id, Args&&... args) {
auto* ptr = GetPPtr(id);
// Refuse to overwrite an object that was already created for this id.
PADDLE_ENFORCE_EQ(ptr->get(), nullptr,
platform::errors::AlreadyExists(
"This value has already in thread safe global map."));
ValueT* item = new ValueT(std::forward<Args>(args)...);
ptr->reset(item);
return item;
}
private: private:
static std::unique_ptr<ValueT>* GetPPtr(KeyT id) { static std::unique_ptr<ValueT>* GetPPtr(KeyT id) {
static std::mutex mutex; static std::mutex mutex;
......
...@@ -29,14 +29,8 @@ void InterceptorMessageServiceImpl::InterceptorMessageService( ...@@ -29,14 +29,8 @@ void InterceptorMessageServiceImpl::InterceptorMessageService(
VLOG(3) << "Interceptor Message Service receives a message from interceptor " VLOG(3) << "Interceptor Message Service receives a message from interceptor "
<< request->src_id() << " to interceptor " << request->dst_id() << request->src_id() << " to interceptor " << request->dst_id()
<< ", with the message: " << request->message_type(); << ", with the message: " << request->message_type();
// TODO(liyurui): Remove this hard code. const auto& carrier_id = GlobalVal<std::string>::Get();
int64_t carrier_id; bool flag = GlobalMap<std::string, Carrier>::Get(carrier_id)
if (request->ctrl_message()) {
carrier_id = 0;
} else {
carrier_id = *GlobalMap<int64_t, int64_t>::Get(request->dst_id());
}
bool flag = GlobalMap<int64_t, Carrier>::Get(carrier_id)
->EnqueueInterceptorMessage(*request); ->EnqueueInterceptorMessage(*request);
response->set_rst(flag); response->set_rst(flag);
} }
......
...@@ -35,10 +35,6 @@ class RuntimeGraph final { ...@@ -35,10 +35,6 @@ class RuntimeGraph final {
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank() const { const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank() const {
return interceptor_id_to_rank_; return interceptor_id_to_rank_;
} }
const std::unordered_map<int64_t, std::unordered_set<int64_t>>&
carrier_id_to_interceptor_ids() const {
return carrier_id_to_interceptor_ids_;
}
void SetInterceptorIdToRank( void SetInterceptorIdToRank(
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank) { const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank) {
interceptor_id_to_rank_ = interceptor_id_to_rank; interceptor_id_to_rank_ = interceptor_id_to_rank;
...@@ -47,19 +43,12 @@ class RuntimeGraph final { ...@@ -47,19 +43,12 @@ class RuntimeGraph final {
const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node) { const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node) {
interceptor_id_to_node_ = interceptor_id_to_node; interceptor_id_to_node_ = interceptor_id_to_node;
} }
void SetCarrierIdToInterceptorIds(
const std::unordered_map<int64_t, std::unordered_set<int64_t>>&
carrier_id_to_interceptor_ids) {
carrier_id_to_interceptor_ids_ = carrier_id_to_interceptor_ids;
}
std::string DebugString() const; std::string DebugString() const;
private: private:
DISABLE_COPY_AND_ASSIGN(RuntimeGraph); DISABLE_COPY_AND_ASSIGN(RuntimeGraph);
std::unordered_map<int64_t, TaskNode*> interceptor_id_to_node_; std::unordered_map<int64_t, TaskNode*> interceptor_id_to_node_;
std::unordered_map<int64_t, int64_t> interceptor_id_to_rank_; std::unordered_map<int64_t, int64_t> interceptor_id_to_rank_;
std::unordered_map<int64_t, std::unordered_set<int64_t>>
carrier_id_to_interceptor_ids_;
}; };
} // namespace distributed } // namespace distributed
......
...@@ -13,9 +13,6 @@ cc_test(interceptor_pipeline_long_path_test SRCS interceptor_pipeline_long_path_ ...@@ -13,9 +13,6 @@ cc_test(interceptor_pipeline_long_path_test SRCS interceptor_pipeline_long_path_
set_source_files_properties(compute_interceptor_run_op_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(compute_interceptor_run_op_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(compute_interceptor_run_op_test SRCS compute_interceptor_run_op_test.cc DEPS fleet_executor ${BRPC_DEPS} op_registry fill_constant_op elementwise_add_op scope device_context) cc_test(compute_interceptor_run_op_test SRCS compute_interceptor_run_op_test.cc DEPS fleet_executor ${BRPC_DEPS} op_registry fill_constant_op elementwise_add_op scope device_context)
set_source_files_properties(interceptor_pass_the_parcel_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(interceptor_pass_the_parcel_test SRCS interceptor_pass_the_parcel_test.cc DEPS fleet_executor ${BRPC_DEPS})
if(WITH_DISTRIBUTE AND WITH_PSCORE AND NOT (WITH_ASCEND OR WITH_ASCEND_CL)) if(WITH_DISTRIBUTE AND WITH_PSCORE AND NOT (WITH_ASCEND OR WITH_ASCEND_CL))
set_source_files_properties(interceptor_ping_pong_with_brpc_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(interceptor_ping_pong_with_brpc_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(interceptor_ping_pong_with_brpc_test SRCS interceptor_ping_pong_with_brpc_test.cc DEPS fleet_executor ${BRPC_DEPS}) cc_test(interceptor_ping_pong_with_brpc_test SRCS interceptor_ping_pong_with_brpc_test.cc DEPS fleet_executor ${BRPC_DEPS})
......
...@@ -62,8 +62,10 @@ TEST(ComputeInterceptor, Compute) { ...@@ -62,8 +62,10 @@ TEST(ComputeInterceptor, Compute) {
std::vector<framework::Scope*> scopes = {scope, scope}; std::vector<framework::Scope*> scopes = {scope, scope};
platform::Place place = platform::CPUPlace(); platform::Place place = platform::CPUPlace();
Carrier* carrier = GlobalMap<int64_t, Carrier>::Create(0, 0); std::string carrier_id = "0";
carrier->Init(0, {{0, 0}, {1, 0}}, {0, 1}); Carrier* carrier =
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
carrier->Init(0, {{0, 0}, {1, 0}});
auto msg_bus = std::make_shared<MessageBus>(); auto msg_bus = std::make_shared<MessageBus>();
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, ""); msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");
......
...@@ -47,8 +47,10 @@ class StartInterceptor : public Interceptor { ...@@ -47,8 +47,10 @@ class StartInterceptor : public Interceptor {
}; };
TEST(ComputeInterceptor, Compute) { TEST(ComputeInterceptor, Compute) {
Carrier* carrier = GlobalMap<int64_t, Carrier>::Create(0, 0); std::string carrier_id = "0";
carrier->Init(0, {{0, 0}, {1, 0}, {2, 0}}, {0, 1, 2}); Carrier* carrier =
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
carrier->Init(0, {{0, 0}, {1, 0}, {2, 0}});
auto msg_bus = std::make_shared<MessageBus>(); auto msg_bus = std::make_shared<MessageBus>();
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, ""); msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <unordered_map>
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global_map.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
namespace paddle {
namespace distributed {
// Test interceptor that passes a "parcel" message around a ring of four
// interceptors (ids 0..3). Every non-STOP message is logged and forwarded to
// the next id. Once interceptor 0 has already handled five messages, it sends
// STOP to all four ids and stops the carrier instead of forwarding.
class ParcelInterceptor : public Interceptor {
 public:
  ParcelInterceptor(int64_t interceptor_id, TaskNode* node)
      : Interceptor(interceptor_id, node) {
    // Route every incoming message to PassParcel.
    RegisterMsgHandle(
        [this](const InterceptorMessage& msg) { PassParcel(msg); });
  }

  // Message handler: marks this interceptor as stopped on STOP, otherwise
  // logs the receipt and either ends the game or forwards the parcel.
  void PassParcel(const InterceptorMessage& msg) {
    const bool is_stop = (msg.message_type() == STOP);
    if (is_stop) {
      stop_ = true;
      return;
    }

    std::cout << GetInterceptorId() << " recv msg, count=" << count_
              << std::endl;

    const bool game_over = (count_ == 5 && interceptor_id_ == 0);
    if (game_over) {
      BroadcastStop();
      return;
    }

    ++count_;
    ForwardParcel(msg);
  }

 private:
  // Sends STOP to interceptors 0, 1, 2 and 3 in that order, then stops the
  // carrier.
  void BroadcastStop() {
    InterceptorMessage stop;
    stop.set_message_type(STOP);
    for (int64_t target = 0; target < 4; ++target) {
      Send(target, stop);
    }
    StopCarrier();
  }

  // Sends a fresh message to the id after the one `msg` was addressed to,
  // wrapping from 3 back to 0.
  void ForwardParcel(const InterceptorMessage& msg) {
    InterceptorMessage new_msg;
    const int64_t next_hop = (msg.dst_id() == 3) ? 0 : msg.dst_id() + 1;
    Send(next_hop, new_msg);
  }

  // Number of non-STOP messages this interceptor has forwarded so far.
  int count_{0};
};
// Registers ParcelInterceptor under the name Parcel. The test in this file
// creates its interceptors with InterceptorFactory::Create("Parcel", ...),
// which presumably resolves through this registration (macro definition is
// not visible here).
REGISTER_INTERCEPTOR(Parcel, ParcelInterceptor);
// Runs four carriers (ids 0..3) on rank 0 that share one message bus, places
// one "Parcel" interceptor with the matching id in each carrier, starts the
// exchange by sending a message from interceptor 0 to interceptor 1, and then
// blocks on carrier 0's Wait().
TEST(InterceptorTest, PassTheParcel) {
  constexpr int64_t kCarrierNum = 4;
  auto msg_bus = std::make_shared<MessageBus>();

  // Create, initialise and wire each carrier in id order. All interceptors
  // are mapped to rank 0; carrier i is told it hosts interceptor i only.
  Carrier* carriers[kCarrierNum];
  for (int64_t id = 0; id < kCarrierNum; ++id) {
    carriers[id] = GlobalMap<int64_t, Carrier>::Create(id, id);
    carriers[id]->Init(0, {{0, 0}, {1, 0}, {2, 0}, {3, 0}}, {id});
    carriers[id]->SetMsgBus(msg_bus);
  }

  // The bus is initialised only after every carrier has been attached to it.
  msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");

  // Install one Parcel interceptor per carrier; keep the first one so the
  // test can kick off the exchange from it.
  Interceptor* a = nullptr;
  for (int64_t id = 0; id < kCarrierNum; ++id) {
    Interceptor* installed = carriers[id]->SetInterceptor(
        id, InterceptorFactory::Create("Parcel", id, nullptr));
    if (id == 0) {
      a = installed;
    }
  }

  InterceptorMessage msg;
  a->Send(1, msg);
  carriers[0]->Wait();
}
} // namespace distributed
} // namespace paddle
...@@ -60,8 +60,10 @@ class PingPongInterceptor : public Interceptor { ...@@ -60,8 +60,10 @@ class PingPongInterceptor : public Interceptor {
REGISTER_INTERCEPTOR(PingPong, PingPongInterceptor); REGISTER_INTERCEPTOR(PingPong, PingPongInterceptor);
TEST(InterceptorTest, PingPong) { TEST(InterceptorTest, PingPong) {
Carrier* carrier = GlobalMap<int64_t, Carrier>::Create(0, 0); std::string carrier_id = "0";
carrier->Init(0, {{0, 0}, {1, 0}}, {0, 1}); Carrier* carrier =
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
carrier->Init(0, {{0, 0}, {1, 0}});
auto msg_bus = std::make_shared<MessageBus>(); auto msg_bus = std::make_shared<MessageBus>();
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, ""); msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");
carrier->SetMsgBus(msg_bus); carrier->SetMsgBus(msg_bus);
......
...@@ -106,16 +106,18 @@ TEST(InterceptorTest, PingPong) { ...@@ -106,16 +106,18 @@ TEST(InterceptorTest, PingPong) {
std::cout << "ip1: " << ip1 << std::endl; std::cout << "ip1: " << ip1 << std::endl;
std::unordered_map<int64_t, int64_t> interceptor_id_to_rank = {{0, 0}, std::unordered_map<int64_t, int64_t> interceptor_id_to_rank = {{0, 0},
{1, 1}}; {1, 1}};
std::string carrier_id = "0";
int pid = fork(); int pid = fork();
if (pid == 0) { if (pid == 0) {
Carrier* carrier = GlobalMap<int64_t, Carrier>::Create(0, 0); Carrier* carrier =
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
GlobalVal<std::string>::Set(carrier_id);
auto msg_bus = std::make_shared<MessageBus>(); auto msg_bus = std::make_shared<MessageBus>();
carrier->SetMsgBus(msg_bus); carrier->SetMsgBus(msg_bus);
// NOTE: need Init msg_bus after carrier SetMsgBus // NOTE: need Init msg_bus after carrier SetMsgBus
carrier->Init(0, interceptor_id_to_rank, {0}); carrier->Init(0, interceptor_id_to_rank);
msg_bus->Init(0, {{0, ip0}, {1, ip1}}, ip0); msg_bus->Init(0, {{0, ip0}, {1, ip1}}, ip0);
carrier->SetMsgBus(msg_bus);
Interceptor* a = carrier->SetInterceptor( Interceptor* a = carrier->SetInterceptor(
0, InterceptorFactory::Create("PingPong", 0, nullptr)); 0, InterceptorFactory::Create("PingPong", 0, nullptr));
msg_bus->Barrier(); msg_bus->Barrier();
...@@ -123,10 +125,12 @@ TEST(InterceptorTest, PingPong) { ...@@ -123,10 +125,12 @@ TEST(InterceptorTest, PingPong) {
a->Send(1, msg); a->Send(1, msg);
carrier->Wait(); carrier->Wait();
} else { } else {
Carrier* carrier = GlobalMap<int64_t, Carrier>::Create(0, 0); Carrier* carrier =
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
GlobalVal<std::string>::Set(carrier_id);
auto msg_bus = std::make_shared<MessageBus>(); auto msg_bus = std::make_shared<MessageBus>();
carrier->SetMsgBus(msg_bus); carrier->SetMsgBus(msg_bus);
carrier->Init(1, interceptor_id_to_rank, {1}); carrier->Init(1, interceptor_id_to_rank);
msg_bus->Init(1, {{0, ip0}, {1, ip1}}, ip1); msg_bus->Init(1, {{0, ip0}, {1, ip1}}, ip1);
carrier->SetInterceptor(1, carrier->SetInterceptor(1,
InterceptorFactory::Create("PingPong", 1, nullptr)); InterceptorFactory::Create("PingPong", 1, nullptr));
......
...@@ -52,9 +52,10 @@ void LinkNodes(const std::vector<TaskNode*>& nodes) { ...@@ -52,9 +52,10 @@ void LinkNodes(const std::vector<TaskNode*>& nodes) {
} }
TEST(AmplifierInterceptor, Amplifier) { TEST(AmplifierInterceptor, Amplifier) {
Carrier* carrier = GlobalMap<int64_t, Carrier>::Create(0, 0); std::string carrier_id = "0";
carrier->Init(0, {{0, 0}, {1, 0}, {2, 0}, {3, 0}, {4, 0}, {5, 0}}, Carrier* carrier =
{0, 1, 2, 3, 4, 5}); GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
carrier->Init(0, {{0, 0}, {1, 0}, {2, 0}, {3, 0}, {4, 0}, {5, 0}});
auto msg_bus = std::make_shared<MessageBus>(); auto msg_bus = std::make_shared<MessageBus>();
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "127.0.0.0:0"); msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "127.0.0.0:0");
carrier->SetMsgBus(msg_bus); carrier->SetMsgBus(msg_bus);
......
...@@ -70,8 +70,10 @@ void LinkNodes(const std::vector<TaskNode*>& nodes, ...@@ -70,8 +70,10 @@ void LinkNodes(const std::vector<TaskNode*>& nodes,
} }
TEST(AmplifierInterceptor, Amplifier) { TEST(AmplifierInterceptor, Amplifier) {
Carrier* carrier = GlobalMap<int64_t, Carrier>::Create(0, 0); std::string carrier_id = "0";
carrier->Init(0, {{0, 0}, {1, 0}, {2, 0}, {3, 0}}, {0, 1, 2, 3}); Carrier* carrier =
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
carrier->Init(0, {{0, 0}, {1, 0}, {2, 0}, {3, 0}});
auto msg_bus = std::make_shared<MessageBus>(); auto msg_bus = std::make_shared<MessageBus>();
msg_bus->Init(0, {{0, ""}}, ""); msg_bus->Init(0, {{0, ""}}, "");
carrier->SetMsgBus(msg_bus); carrier->SetMsgBus(msg_bus);
......
...@@ -1956,7 +1956,11 @@ class Executor(object): ...@@ -1956,7 +1956,11 @@ class Executor(object):
return ctx return ctx
def _prepare_fleet_executor(self, program=None, scope=None, fleet_opt=None): def _prepare_fleet_executor(self,
carrier_id="",
program=None,
scope=None,
fleet_opt=None):
from ..distributed.fleet.proto import fleet_executor_desc_pb2 from ..distributed.fleet.proto import fleet_executor_desc_pb2
assert program, "Program for fleet executor should not be None" assert program, "Program for fleet executor should not be None"
assert fleet_opt, "Configurations for fleet executor should not be None" assert fleet_opt, "Configurations for fleet executor should not be None"
...@@ -2014,7 +2018,8 @@ class Executor(object): ...@@ -2014,7 +2018,8 @@ class Executor(object):
fleet_exe = core.FleetExecutor(fleet_exe_desc.SerializeToString()) fleet_exe = core.FleetExecutor(fleet_exe_desc.SerializeToString())
place = core.Place() place = core.Place()
place.set_place(self.place) place.set_place(self.place)
fleet_exe.init(program.desc, scope, place, tasks, task_id_to_rank) fleet_exe.init(carrier_id, program.desc, scope, place, tasks,
task_id_to_rank)
return fleet_exe return fleet_exe
def _run_using_fleet_executor(self, def _run_using_fleet_executor(self,
...@@ -2023,6 +2028,7 @@ class Executor(object): ...@@ -2023,6 +2028,7 @@ class Executor(object):
feed_var_name="feed", feed_var_name="feed",
fetch_var_name="fetch", fetch_var_name="fetch",
fetch_list=None): fetch_list=None):
# TODO(liyurui): Change cache strategy for multi carriers
cache_key = _get_strong_program_cache_key(program, feed, fetch_list) cache_key = _get_strong_program_cache_key(program, feed, fetch_list)
cached_ctx = self._get_ctx_cache(cache_key) cached_ctx = self._get_ctx_cache(cache_key)
cached_scope = self._get_scope_cache(cache_key) cached_scope = self._get_scope_cache(cache_key)
...@@ -2088,7 +2094,10 @@ class Executor(object): ...@@ -2088,7 +2094,10 @@ class Executor(object):
fetch_task.set_program(fetch_program) fetch_task.set_program(fetch_program)
cached_ctx = self._prepare_fleet_executor( cached_ctx = self._prepare_fleet_executor(
program=cached_program, scope=cached_scope, fleet_opt=fleet_opt) cache_key,
program=cached_program,
scope=cached_scope,
fleet_opt=fleet_opt)
self._add_ctx_cache(cache_key, cached_ctx) self._add_ctx_cache(cache_key, cached_ctx)
if feed: if feed:
# NOTE: don't have to traverse programs in task nodes, # NOTE: don't have to traverse programs in task nodes,
...@@ -2107,7 +2116,7 @@ class Executor(object): ...@@ -2107,7 +2116,7 @@ class Executor(object):
lr_sheduler._var_name) lr_sheduler._var_name)
tensor.set(data, self.place) tensor.set(data, self.place)
cached_ctx.run() cached_ctx.run(cache_key)
if fetch_list: if fetch_list:
arr = cached_scope.find_var(fetch_var_name).get_fetch_list() arr = cached_scope.find_var(fetch_var_name).get_fetch_list()
tensors = arr._move_to_list() tensors = arr._move_to_list()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册