未验证 提交 3658405c 编写于 作者: L LiYuRio 提交者: GitHub

[Fleet Executor] Support multi carrier (#38535)

上级 2421a25a
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/carrier.h" #include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global_map.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor_message_service.h" #include "paddle/fluid/distributed/fleet_executor/interceptor_message_service.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h"
...@@ -27,16 +28,32 @@ namespace distributed { ...@@ -27,16 +28,32 @@ namespace distributed {
USE_INTERCEPTOR(Compute); USE_INTERCEPTOR(Compute);
USE_INTERCEPTOR(Amplifier); USE_INTERCEPTOR(Amplifier);
void Carrier::Init(int64_t rank, std::shared_ptr<RuntimeGraph> runtime_graph, void Carrier::Init(
framework::Scope* root_scope, int64_t rank,
framework::Scope* minibatch_scope, const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank,
const std::unordered_set<int64_t>& interceptor_ids) {
rank_ = rank;
interceptor_id_to_rank_ = interceptor_id_to_rank;
interceptor_ids_ = interceptor_ids;
// TODO(fleet_exe dev): thread pool
thread_num_ = 1;
thread_pool_.SetThreadNum(thread_num_);
thread_pool_.Start();
}
void Carrier::Init(
int64_t rank,
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank,
const std::unordered_set<int64_t>& interceptor_ids,
const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node,
framework::Scope* root_scope, framework::Scope* minibatch_scope,
const std::vector<framework::Scope*>& microbatch_scopes, const std::vector<framework::Scope*>& microbatch_scopes,
const platform::Place& place) { const platform::Place& place) {
PADDLE_ENFORCE_EQ(is_init_, false, platform::errors::AlreadyExists(
"Carrier is already init."));
rank_ = rank; rank_ = rank;
runtime_graph_ = runtime_graph; interceptor_id_to_rank_ = interceptor_id_to_rank;
interceptor_id_to_rank_ = runtime_graph_->interceptor_id_to_rank(); interceptor_ids_ = interceptor_ids;
interceptor_id_to_node_ = interceptor_id_to_node;
minibatch_scope_ = minibatch_scope; minibatch_scope_ = minibatch_scope;
microbatch_scopes_ = microbatch_scopes; microbatch_scopes_ = microbatch_scopes;
place_ = place; place_ = place;
...@@ -72,8 +89,6 @@ bool Carrier::EnqueueInterceptorMessage( ...@@ -72,8 +89,6 @@ bool Carrier::EnqueueInterceptorMessage(
return true; return true;
} }
void Carrier::Barrier() { msg_bus_->Barrier(); }
Interceptor* Carrier::GetInterceptor(int64_t interceptor_id) { Interceptor* Carrier::GetInterceptor(int64_t interceptor_id) {
auto iter = interceptor_idx_to_interceptor_.find(interceptor_id); auto iter = interceptor_idx_to_interceptor_.find(interceptor_id);
PADDLE_ENFORCE_NE(iter, interceptor_idx_to_interceptor_.end(), PADDLE_ENFORCE_NE(iter, interceptor_idx_to_interceptor_.end(),
...@@ -100,7 +115,8 @@ void Carrier::Start() { ...@@ -100,7 +115,8 @@ void Carrier::Start() {
"Using message bus since it has not been initialized. " "Using message bus since it has not been initialized. "
"Please invoke MessageBus::Init() before using it or " "Please invoke MessageBus::Init() before using it or "
"neccessary components are not ready.")); "neccessary components are not ready."));
PADDLE_ENFORCE_EQ(is_init_, true, platform::errors::PreconditionNotMet(
"Using carrier before initialized."));
for (int64_t id : source_interceptor_ids_) { for (int64_t id : source_interceptor_ids_) {
VLOG(3) << "Carrier Start is sending start to source interceptor " << id VLOG(3) << "Carrier Start is sending start to source interceptor " << id
<< "."; << ".";
...@@ -140,7 +156,9 @@ bool Carrier::Send(const InterceptorMessage& msg) { ...@@ -140,7 +156,9 @@ bool Carrier::Send(const InterceptorMessage& msg) {
if (src_rank == dst_rank) { if (src_rank == dst_rank) {
VLOG(3) << "Send a message from interceptor " << src_id VLOG(3) << "Send a message from interceptor " << src_id
<< " to interceptor " << dst_id << ", which are in the same ranks."; << " to interceptor " << dst_id << ", which are in the same ranks.";
return EnqueueInterceptorMessage(msg); int64_t carrier_id = *GlobalMap<int64_t, int64_t>::Get(dst_id);
return GlobalMap<int64_t, Carrier>::Get(carrier_id)
->EnqueueInterceptorMessage(msg);
} else { } else {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
msg_bus_.get(), msg_bus_.get(),
...@@ -174,6 +192,9 @@ Interceptor* Carrier::SetInterceptor(int64_t interceptor_id, ...@@ -174,6 +192,9 @@ Interceptor* Carrier::SetInterceptor(int64_t interceptor_id,
loop, platform::errors::Fatal("thread task loop must not null")); loop, platform::errors::Fatal("thread task loop must not null"));
interceptor->RegisterTaskLoop(loop); interceptor->RegisterTaskLoop(loop);
// TODO(liyurui): Using struct InterceptorID replace int64_t
GlobalMap<int64_t, int64_t>::Create(interceptor_id, carrier_id_);
auto* ptr = interceptor.get(); auto* ptr = interceptor.get();
interceptor_idx_to_interceptor_.insert( interceptor_idx_to_interceptor_.insert(
std::make_pair(interceptor_id, std::move(interceptor))); std::make_pair(interceptor_id, std::move(interceptor)));
...@@ -199,15 +220,19 @@ static std::shared_ptr<framework::GarbageCollector> GetGC( ...@@ -199,15 +220,19 @@ static std::shared_ptr<framework::GarbageCollector> GetGC(
} }
void Carrier::CreateInterceptors() { void Carrier::CreateInterceptors() {
if (runtime_graph_->interceptor_id_to_node().empty()) return; if (interceptor_ids_.empty()) return;
auto gc = GetGC(place_); auto gc = GetGC(place_);
// create each Interceptor // create each Interceptor
// no auto init since there is no config // no auto init since there is no config
for (const auto& item : runtime_graph_->interceptor_id_to_node()) { for (int64_t interceptor_id : interceptor_ids_) {
int64_t interceptor_id = item.first; const auto& task_node_iter = interceptor_id_to_node_.find(interceptor_id);
TaskNode* task_node = item.second; PADDLE_ENFORCE_NE(
task_node_iter, interceptor_id_to_node_.end(),
platform::errors::NotFound("Can not find task node for interceptor %ld",
interceptor_id));
TaskNode* task_node = task_node_iter->second;
PADDLE_ENFORCE_LT( PADDLE_ENFORCE_LT(
task_node->run_at_offset(), task_node->run_per_steps(), task_node->run_at_offset(), task_node->run_per_steps(),
......
...@@ -45,16 +45,16 @@ class MessageBus; ...@@ -45,16 +45,16 @@ class MessageBus;
class Carrier final { class Carrier final {
public: public:
Carrier() = default; explicit Carrier(int64_t carrier_id) : carrier_id_(carrier_id) {}
Carrier(int64_t rank,
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank)
: rank_(rank), interceptor_id_to_rank_(interceptor_id_to_rank) {
thread_num_ = 1;
thread_pool_.SetThreadNum(thread_num_);
thread_pool_.Start();
}
~Carrier(); ~Carrier();
void Init(int64_t rank, std::shared_ptr<RuntimeGraph> runtime_graph, void Init(int64_t rank,
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank,
const std::unordered_set<int64_t>& interceptor_ids);
void Init(
int64_t rank,
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank,
const std::unordered_set<int64_t>& interceptor_ids,
const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node,
framework::Scope* root_scope, framework::Scope* minibatch_scope, framework::Scope* root_scope, framework::Scope* minibatch_scope,
const std::vector<framework::Scope*>& microbatch_scopes, const std::vector<framework::Scope*>& microbatch_scopes,
const platform::Place& place); const platform::Place& place);
...@@ -83,10 +83,9 @@ class Carrier final { ...@@ -83,10 +83,9 @@ class Carrier final {
bool Send(const InterceptorMessage& msg); bool Send(const InterceptorMessage& msg);
void Barrier();
private: private:
DISABLE_COPY_AND_ASSIGN(Carrier); DISABLE_COPY_AND_ASSIGN(Carrier);
Carrier() = delete;
// create each Interceptor // create each Interceptor
void CreateInterceptors(); void CreateInterceptors();
...@@ -108,13 +107,14 @@ class Carrier final { ...@@ -108,13 +107,14 @@ class Carrier final {
framework::Scope* minibatch_scope_; framework::Scope* minibatch_scope_;
paddle::platform::Place place_; paddle::platform::Place place_;
paddle::platform::DeviceContext* dev_ctx_{nullptr}; paddle::platform::DeviceContext* dev_ctx_{nullptr};
std::shared_ptr<RuntimeGraph> runtime_graph_;
std::shared_ptr<MessageBus> msg_bus_; std::shared_ptr<MessageBus> msg_bus_;
int64_t rank_; int64_t rank_;
int64_t carrier_id_;
std::unordered_map<int64_t, TaskNode*> interceptor_id_to_node_;
std::unordered_map<int64_t, int64_t> interceptor_id_to_rank_; std::unordered_map<int64_t, int64_t> interceptor_id_to_rank_;
int thread_num_; int thread_num_;
TaskLoopThreadPool thread_pool_; TaskLoopThreadPool thread_pool_;
std::unordered_set<int64_t> interceptor_ids_;
}; };
} // namespace distributed } // namespace distributed
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h" #include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/global_map.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/runtime_graph.h" #include "paddle/fluid/distributed/fleet_executor/runtime_graph.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h" #include "paddle/fluid/distributed/fleet_executor/task_node.h"
...@@ -27,8 +28,6 @@ ...@@ -27,8 +28,6 @@
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
std::unique_ptr<Carrier> FleetExecutor::carrier_;
FleetExecutor::FleetExecutor(const std::string& exe_desc_str) { FleetExecutor::FleetExecutor(const std::string& exe_desc_str) {
bool parse_flag = exe_desc_.ParseFromString(exe_desc_str); bool parse_flag = exe_desc_.ParseFromString(exe_desc_str);
PADDLE_ENFORCE(parse_flag, platform::errors::PreconditionNotMet( PADDLE_ENFORCE(parse_flag, platform::errors::PreconditionNotMet(
...@@ -37,13 +36,9 @@ FleetExecutor::FleetExecutor(const std::string& exe_desc_str) { ...@@ -37,13 +36,9 @@ FleetExecutor::FleetExecutor(const std::string& exe_desc_str) {
FleetExecutor::~FleetExecutor() { FleetExecutor::~FleetExecutor() {
root_scope_->DropKids(); root_scope_->DropKids();
GetCarrier()->Release(); for (const auto& item : runtime_graph_->carrier_id_to_interceptor_ids()) {
} GlobalMap<int64_t, Carrier>::Get(item.first)->Release();
}
Carrier* FleetExecutor::GetCarrier() {
PADDLE_ENFORCE_NOT_NULL(carrier_.get(), platform::errors::NotFound(
"Carrier has not been created."));
return carrier_.get();
} }
void FleetExecutor::Init( void FleetExecutor::Init(
...@@ -63,13 +58,19 @@ void FleetExecutor::Init( ...@@ -63,13 +58,19 @@ void FleetExecutor::Init(
auto unused_vars = framework::GetUnusedVars(program_desc.Block(0), ops, {}); auto unused_vars = framework::GetUnusedVars(program_desc.Block(0), ops, {});
runtime_graph_ = std::make_shared<RuntimeGraph>(); runtime_graph_ = std::make_shared<RuntimeGraph>();
std::unordered_map<int64_t, TaskNode*> interceptor_id_to_task; std::unordered_map<int64_t, TaskNode*> interceptor_id_to_task;
std::unordered_map<int64_t, std::unordered_set<int64_t>>
carrier_id_to_interceptor_ids;
std::unordered_set<int64_t> interceptor_ids;
for (auto task_node : task_nodes) { for (auto task_node : task_nodes) {
task_node->SetUnusedVars(unused_vars); task_node->SetUnusedVars(unused_vars);
int64_t interceptor_id = task_node->task_id(); int64_t interceptor_id = task_node->task_id();
interceptor_id_to_task.emplace(interceptor_id, task_node); interceptor_id_to_task.emplace(interceptor_id, task_node);
interceptor_ids.insert(interceptor_id);
} }
carrier_id_to_interceptor_ids.emplace(0, interceptor_ids);
runtime_graph_->SetInterceptorIdToRank(task_id_to_rank); runtime_graph_->SetInterceptorIdToRank(task_id_to_rank);
runtime_graph_->SetInterceptorIdToNode(interceptor_id_to_task); runtime_graph_->SetInterceptorIdToNode(interceptor_id_to_task);
runtime_graph_->SetCarrierIdToInterceptorIds(carrier_id_to_interceptor_ids);
for (auto& unique_op : ops) { for (auto& unique_op : ops) {
unique_op.release(); unique_op.release();
} }
...@@ -86,20 +87,25 @@ void FleetExecutor::Init( ...@@ -86,20 +87,25 @@ void FleetExecutor::Init(
} }
VLOG(5) << runtime_graph_->DebugString(); VLOG(5) << runtime_graph_->DebugString();
msg_bus_ = std::make_shared<MessageBus>(); msg_bus_ = std::make_shared<MessageBus>();
CreateCarrier(); for (const auto& item : runtime_graph_->carrier_id_to_interceptor_ids()) {
GlobalMap<int64_t, Carrier>::Create(item.first, item.first);
}
InitCarrier(); InitCarrier();
InitMessageBus(); InitMessageBus();
// refine this? wait all carrier ready // Wait for all message bus connected.
// NOTE(wangxi): must add after Carrier::SetMsgBus, for we use msg_bus_->Barrier();
// MessageBus::IncreaseBarrierCount when receive barrier msg.
GetCarrier()->Barrier();
} }
void FleetExecutor::InitCarrier() { void FleetExecutor::InitCarrier() {
if (!GetCarrier()->IsInit()) { for (const auto& item : runtime_graph_->carrier_id_to_interceptor_ids()) {
GetCarrier()->SetMsgBus(msg_bus_); Carrier* carrier = GlobalMap<int64_t, Carrier>::Get(item.first);
GetCarrier()->Init(exe_desc_.cur_rank(), runtime_graph_, root_scope_, PADDLE_ENFORCE_NOT_NULL(carrier, platform::errors::InvalidArgument(
"Carrier has not been created."));
carrier->SetMsgBus(msg_bus_);
carrier->Init(exe_desc_.cur_rank(),
runtime_graph_->interceptor_id_to_rank(), item.second,
runtime_graph_->interceptor_id_to_node(), root_scope_,
minibatch_scope_, microbatch_scopes_, place_); minibatch_scope_, microbatch_scopes_, place_);
} }
} }
...@@ -140,14 +146,9 @@ void FleetExecutor::InitMessageBus() { ...@@ -140,14 +146,9 @@ void FleetExecutor::InitMessageBus() {
} }
void FleetExecutor::Run() { void FleetExecutor::Run() {
// Run for (const auto& item : runtime_graph_->carrier_id_to_interceptor_ids()) {
PADDLE_ENFORCE_EQ( GlobalMap<int64_t, Carrier>::Get(item.first)->Start();
GetCarrier()->IsInit(), true, }
platform::errors::Unavailable("Carrier has not been init yet."));
PADDLE_ENFORCE_EQ(
msg_bus_->IsInit(), true,
platform::errors::Unavailable("MessageBus has not been init yet."));
GetCarrier()->Start();
for (auto* micro_scop : microbatch_scopes_) { for (auto* micro_scop : microbatch_scopes_) {
// By default, we should delete all kid scopes after run executor because // By default, we should delete all kid scopes after run executor because
// some operators may create local scope when running, such as while_op. // some operators may create local scope when running, such as while_op.
......
...@@ -42,16 +42,6 @@ class FleetExecutor final { ...@@ -42,16 +42,6 @@ class FleetExecutor final {
const std::vector<TaskNode*>& task_nodes, const std::vector<TaskNode*>& task_nodes,
const std::unordered_map<int64_t, int64_t>& task_id_to_rank); const std::unordered_map<int64_t, int64_t>& task_id_to_rank);
void Run(); void Run();
// TODO(liyurui): Change to use registry table for multi-carrier.
static Carrier* GetCarrier();
template <typename... Args>
static Carrier* CreateCarrier(Args&&... args) {
PADDLE_ENFORCE_EQ(
carrier_.get(), nullptr,
platform::errors::AlreadyExists("Carrier has been created already."));
carrier_ = std::make_unique<Carrier>(std::forward<Args>(args)...);
return carrier_.get();
}
private: private:
DISABLE_COPY_AND_ASSIGN(FleetExecutor); DISABLE_COPY_AND_ASSIGN(FleetExecutor);
...@@ -67,7 +57,6 @@ class FleetExecutor final { ...@@ -67,7 +57,6 @@ class FleetExecutor final {
// The carriers under FleetExecutor will share message bus, // The carriers under FleetExecutor will share message bus,
// using shared_ptr to manage lifetime and condition race. // using shared_ptr to manage lifetime and condition race.
std::shared_ptr<MessageBus> msg_bus_; std::shared_ptr<MessageBus> msg_bus_;
static std::unique_ptr<Carrier> carrier_;
}; };
} // namespace distributed } // namespace distributed
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
namespace paddle {
namespace distributed {
template <typename KeyT, typename ValueT>
class GlobalMap final {
public:
static ValueT* Get(KeyT id) {
ValueT* item = GetPPtr(id)->get();
PADDLE_ENFORCE_NOT_NULL(
item, platform::errors::NotFound("This value is not in global map."));
return item;
}
template <typename... Args>
static ValueT* Create(KeyT id, Args&&... args) {
auto* ptr = GetPPtr(id);
PADDLE_ENFORCE_EQ(ptr->get(), nullptr,
platform::errors::AlreadyExists(
"This value has already in global map."));
ValueT* item = new ValueT(std::forward<Args>(args)...);
ptr->reset(item);
return item;
}
private:
static std::unique_ptr<ValueT>* GetPPtr(KeyT id) {
static std::mutex mutex;
static std::unordered_map<KeyT, std::unique_ptr<ValueT>> id_to_ptr;
std::unique_lock<std::mutex> lock(mutex);
return &id_to_ptr[id];
}
};
} // namespace distributed
} // namespace paddle
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#include "paddle/fluid/distributed/fleet_executor/interceptor_message_service.h" #include "paddle/fluid/distributed/fleet_executor/interceptor_message_service.h"
#include "brpc/server.h" #include "brpc/server.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h" #include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h" #include "paddle/fluid/distributed/fleet_executor/global_map.h"
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
...@@ -29,7 +29,15 @@ void InterceptorMessageServiceImpl::InterceptorMessageService( ...@@ -29,7 +29,15 @@ void InterceptorMessageServiceImpl::InterceptorMessageService(
VLOG(3) << "Interceptor Message Service receives a message from interceptor " VLOG(3) << "Interceptor Message Service receives a message from interceptor "
<< request->src_id() << " to interceptor " << request->dst_id() << request->src_id() << " to interceptor " << request->dst_id()
<< ", with the message: " << request->message_type(); << ", with the message: " << request->message_type();
bool flag = FleetExecutor::GetCarrier()->EnqueueInterceptorMessage(*request); // TODO(liyurui): Remove this hard code.
int64_t carrier_id;
if (request->ctrl_message()) {
carrier_id = 0;
} else {
carrier_id = *GlobalMap<int64_t, int64_t>::Get(request->dst_id());
}
bool flag = GlobalMap<int64_t, Carrier>::Get(carrier_id)
->EnqueueInterceptorMessage(*request);
response->set_rst(flag); response->set_rst(flag);
} }
......
...@@ -35,6 +35,10 @@ class RuntimeGraph final { ...@@ -35,6 +35,10 @@ class RuntimeGraph final {
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank() const { const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank() const {
return interceptor_id_to_rank_; return interceptor_id_to_rank_;
} }
const std::unordered_map<int64_t, std::unordered_set<int64_t>>&
carrier_id_to_interceptor_ids() const {
return carrier_id_to_interceptor_ids_;
}
void SetInterceptorIdToRank( void SetInterceptorIdToRank(
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank) { const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank) {
interceptor_id_to_rank_ = interceptor_id_to_rank; interceptor_id_to_rank_ = interceptor_id_to_rank;
...@@ -43,12 +47,19 @@ class RuntimeGraph final { ...@@ -43,12 +47,19 @@ class RuntimeGraph final {
const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node) { const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node) {
interceptor_id_to_node_ = interceptor_id_to_node; interceptor_id_to_node_ = interceptor_id_to_node;
} }
void SetCarrierIdToInterceptorIds(
const std::unordered_map<int64_t, std::unordered_set<int64_t>>&
carrier_id_to_interceptor_ids) {
carrier_id_to_interceptor_ids_ = carrier_id_to_interceptor_ids;
}
std::string DebugString() const; std::string DebugString() const;
private: private:
DISABLE_COPY_AND_ASSIGN(RuntimeGraph); DISABLE_COPY_AND_ASSIGN(RuntimeGraph);
std::unordered_map<int64_t, TaskNode*> interceptor_id_to_node_; std::unordered_map<int64_t, TaskNode*> interceptor_id_to_node_;
std::unordered_map<int64_t, int64_t> interceptor_id_to_rank_; std::unordered_map<int64_t, int64_t> interceptor_id_to_rank_;
std::unordered_map<int64_t, std::unordered_set<int64_t>>
carrier_id_to_interceptor_ids_;
}; };
} // namespace distributed } // namespace distributed
......
...@@ -13,6 +13,9 @@ cc_test(interceptor_pipeline_long_path_test SRCS interceptor_pipeline_long_path_ ...@@ -13,6 +13,9 @@ cc_test(interceptor_pipeline_long_path_test SRCS interceptor_pipeline_long_path_
set_source_files_properties(compute_interceptor_run_op_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(compute_interceptor_run_op_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(compute_interceptor_run_op_test SRCS compute_interceptor_run_op_test.cc DEPS fleet_executor ${BRPC_DEPS} op_registry fill_constant_op elementwise_add_op scope device_context) cc_test(compute_interceptor_run_op_test SRCS compute_interceptor_run_op_test.cc DEPS fleet_executor ${BRPC_DEPS} op_registry fill_constant_op elementwise_add_op scope device_context)
set_source_files_properties(interceptor_pass_the_parcel_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(interceptor_pass_the_parcel_test SRCS interceptor_pass_the_parcel_test.cc DEPS fleet_executor ${BRPC_DEPS})
if(WITH_DISTRIBUTE AND WITH_PSCORE AND NOT (WITH_ASCEND OR WITH_ASCEND_CL)) if(WITH_DISTRIBUTE AND WITH_PSCORE AND NOT (WITH_ASCEND OR WITH_ASCEND_CL))
set_source_files_properties(interceptor_ping_pong_with_brpc_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(interceptor_ping_pong_with_brpc_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(interceptor_ping_pong_with_brpc_test SRCS interceptor_ping_pong_with_brpc_test.cc DEPS fleet_executor ${BRPC_DEPS}) cc_test(interceptor_ping_pong_with_brpc_test SRCS interceptor_ping_pong_with_brpc_test.cc DEPS fleet_executor ${BRPC_DEPS})
......
...@@ -18,7 +18,7 @@ limitations under the License. */ ...@@ -18,7 +18,7 @@ limitations under the License. */
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h" #include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h" #include "paddle/fluid/distributed/fleet_executor/global_map.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h" #include "paddle/fluid/distributed/fleet_executor/task_node.h"
...@@ -62,11 +62,12 @@ TEST(ComputeInterceptor, Compute) { ...@@ -62,11 +62,12 @@ TEST(ComputeInterceptor, Compute) {
std::vector<framework::Scope*> scopes = {scope, scope}; std::vector<framework::Scope*> scopes = {scope, scope};
platform::Place place = platform::CPUPlace(); platform::Place place = platform::CPUPlace();
Carrier carrier(0, {{0, 0}, {1, 0}}); Carrier* carrier = GlobalMap<int64_t, Carrier>::Create(0, 0);
carrier->Init(0, {{0, 0}, {1, 0}}, {0, 1});
auto msg_bus = std::make_shared<MessageBus>(); auto msg_bus = std::make_shared<MessageBus>();
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, ""); msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");
carrier.SetMsgBus(msg_bus); carrier->SetMsgBus(msg_bus);
// FIXME: don't delete, otherwise interceptor will use undefined node // FIXME: don't delete, otherwise interceptor will use undefined node
TaskNode* node_a = TaskNode* node_a =
...@@ -77,9 +78,9 @@ TEST(ComputeInterceptor, Compute) { ...@@ -77,9 +78,9 @@ TEST(ComputeInterceptor, Compute) {
node_a->AddDownstreamTask(1); node_a->AddDownstreamTask(1);
node_b->AddUpstreamTask(0); node_b->AddUpstreamTask(0);
auto* a = carrier.SetInterceptor( auto* a = carrier->SetInterceptor(
0, InterceptorFactory::Create("Compute", 0, node_a)); 0, InterceptorFactory::Create("Compute", 0, node_a));
carrier.SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b)); carrier->SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b));
a->SetPlace(place); a->SetPlace(place);
a->SetMicroBatchScope(scopes); a->SetMicroBatchScope(scopes);
...@@ -89,10 +90,10 @@ TEST(ComputeInterceptor, Compute) { ...@@ -89,10 +90,10 @@ TEST(ComputeInterceptor, Compute) {
msg.set_message_type(DATA_IS_READY); msg.set_message_type(DATA_IS_READY);
msg.set_src_id(-1); msg.set_src_id(-1);
msg.set_dst_id(0); msg.set_dst_id(0);
carrier.EnqueueInterceptorMessage(msg); carrier->EnqueueInterceptorMessage(msg);
carrier.Wait(); carrier->Wait();
carrier.Release(); carrier->Release();
} }
} // namespace distributed } // namespace distributed
......
...@@ -18,7 +18,7 @@ limitations under the License. */ ...@@ -18,7 +18,7 @@ limitations under the License. */
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h" #include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h" #include "paddle/fluid/distributed/fleet_executor/global_map.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h" #include "paddle/fluid/distributed/fleet_executor/task_node.h"
...@@ -47,11 +47,12 @@ class StartInterceptor : public Interceptor { ...@@ -47,11 +47,12 @@ class StartInterceptor : public Interceptor {
}; };
TEST(ComputeInterceptor, Compute) { TEST(ComputeInterceptor, Compute) {
Carrier carrier(0, {{0, 0}, {1, 0}, {2, 0}}); Carrier* carrier = GlobalMap<int64_t, Carrier>::Create(0, 0);
carrier->Init(0, {{0, 0}, {1, 0}, {2, 0}}, {0, 1, 2});
auto msg_bus = std::make_shared<MessageBus>(); auto msg_bus = std::make_shared<MessageBus>();
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, ""); msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");
carrier.SetMsgBus(msg_bus); carrier->SetMsgBus(msg_bus);
// NOTE: don't delete, otherwise interceptor will use undefined node // NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0); // role, rank, task_id TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0); // role, rank, task_id
...@@ -65,9 +66,9 @@ TEST(ComputeInterceptor, Compute) { ...@@ -65,9 +66,9 @@ TEST(ComputeInterceptor, Compute) {
node_c->AddUpstreamTask(1); node_c->AddUpstreamTask(1);
Interceptor* a = Interceptor* a =
carrier.SetInterceptor(0, std::make_unique<StartInterceptor>(0, node_a)); carrier->SetInterceptor(0, std::make_unique<StartInterceptor>(0, node_a));
carrier.SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b)); carrier->SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b));
carrier.SetInterceptor(2, InterceptorFactory::Create("Compute", 2, node_c)); carrier->SetInterceptor(2, InterceptorFactory::Create("Compute", 2, node_c));
InterceptorMessage msg; InterceptorMessage msg;
msg.set_message_type(DATA_IS_READY); msg.set_message_type(DATA_IS_READY);
...@@ -76,8 +77,8 @@ TEST(ComputeInterceptor, Compute) { ...@@ -76,8 +77,8 @@ TEST(ComputeInterceptor, Compute) {
a->Send(1, msg); a->Send(1, msg);
a->Send(1, msg); a->Send(1, msg);
carrier.Wait(); carrier->Wait();
carrier.Release(); carrier->Release();
} }
} // namespace distributed } // namespace distributed
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <unordered_map>
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global_map.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
namespace paddle {
namespace distributed {
class ParcelInterceptor : public Interceptor {
public:
ParcelInterceptor(int64_t interceptor_id, TaskNode* node)
: Interceptor(interceptor_id, node) {
RegisterMsgHandle(
[this](const InterceptorMessage& msg) { PassParcel(msg); });
}
void PassParcel(const InterceptorMessage& msg) {
if (msg.message_type() == STOP) {
stop_ = true;
return;
}
std::cout << GetInterceptorId() << " recv msg, count=" << count_
<< std::endl;
if (count_ == 5 && interceptor_id_ == 0) {
InterceptorMessage stop;
stop.set_message_type(STOP);
Send(0, stop);
Send(1, stop);
Send(2, stop);
Send(3, stop);
StopCarrier();
return;
}
++count_;
InterceptorMessage new_msg;
if (msg.dst_id() == 3) {
Send(0, new_msg);
} else {
Send(msg.dst_id() + 1, new_msg);
}
}
private:
int count_{0};
};
REGISTER_INTERCEPTOR(Parcel, ParcelInterceptor);
TEST(InterceptorTest, PassTheParcel) {
auto msg_bus = std::make_shared<MessageBus>();
Carrier* carrier_0 = GlobalMap<int64_t, Carrier>::Create(0, 0);
carrier_0->Init(0, {{0, 0}, {1, 0}, {2, 0}, {3, 0}}, {0});
carrier_0->SetMsgBus(msg_bus);
Carrier* carrier_1 = GlobalMap<int64_t, Carrier>::Create(1, 1);
carrier_1->Init(0, {{0, 0}, {1, 0}, {2, 0}, {3, 0}}, {1});
carrier_1->SetMsgBus(msg_bus);
Carrier* carrier_2 = GlobalMap<int64_t, Carrier>::Create(2, 2);
carrier_2->Init(0, {{0, 0}, {1, 0}, {2, 0}, {3, 0}}, {2});
carrier_2->SetMsgBus(msg_bus);
Carrier* carrier_3 = GlobalMap<int64_t, Carrier>::Create(3, 3);
carrier_3->Init(0, {{0, 0}, {1, 0}, {2, 0}, {3, 0}}, {3});
carrier_3->SetMsgBus(msg_bus);
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");
Interceptor* a = carrier_0->SetInterceptor(
0, InterceptorFactory::Create("Parcel", 0, nullptr));
carrier_1->SetInterceptor(1,
InterceptorFactory::Create("Parcel", 1, nullptr));
carrier_2->SetInterceptor(2,
InterceptorFactory::Create("Parcel", 2, nullptr));
carrier_3->SetInterceptor(3,
InterceptorFactory::Create("Parcel", 3, nullptr));
InterceptorMessage msg;
a->Send(1, msg);
carrier_0->Wait();
}
} // namespace distributed
} // namespace paddle
...@@ -18,6 +18,7 @@ limitations under the License. */ ...@@ -18,6 +18,7 @@ limitations under the License. */
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h" #include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global_map.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h"
...@@ -59,20 +60,21 @@ class PingPongInterceptor : public Interceptor { ...@@ -59,20 +60,21 @@ class PingPongInterceptor : public Interceptor {
REGISTER_INTERCEPTOR(PingPong, PingPongInterceptor); REGISTER_INTERCEPTOR(PingPong, PingPongInterceptor);
TEST(InterceptorTest, PingPong) { TEST(InterceptorTest, PingPong) {
Carrier carrier(0, {{0, 0}, {1, 0}}); Carrier* carrier = GlobalMap<int64_t, Carrier>::Create(0, 0);
carrier->Init(0, {{0, 0}, {1, 0}}, {0, 1});
auto msg_bus = std::make_shared<MessageBus>(); auto msg_bus = std::make_shared<MessageBus>();
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, ""); msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");
carrier.SetMsgBus(msg_bus); carrier->SetMsgBus(msg_bus);
Interceptor* a = carrier.SetInterceptor( Interceptor* a = carrier->SetInterceptor(
0, InterceptorFactory::Create("PingPong", 0, nullptr)); 0, InterceptorFactory::Create("PingPong", 0, nullptr));
carrier.SetInterceptor(1, std::make_unique<PingPongInterceptor>(1, nullptr)); carrier->SetInterceptor(1, std::make_unique<PingPongInterceptor>(1, nullptr));
InterceptorMessage msg; InterceptorMessage msg;
a->Send(1, msg); a->Send(1, msg);
carrier.Wait(); carrier->Wait();
} }
} // namespace distributed } // namespace distributed
......
...@@ -20,7 +20,7 @@ limitations under the License. */ ...@@ -20,7 +20,7 @@ limitations under the License. */
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h" #include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h" #include "paddle/fluid/distributed/fleet_executor/global_map.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h"
...@@ -107,42 +107,31 @@ TEST(InterceptorTest, PingPong) { ...@@ -107,42 +107,31 @@ TEST(InterceptorTest, PingPong) {
std::unordered_map<int64_t, int64_t> interceptor_id_to_rank = {{0, 0}, std::unordered_map<int64_t, int64_t> interceptor_id_to_rank = {{0, 0},
{1, 1}}; {1, 1}};
int exe_pid = fork();
if (exe_pid == 0) {
int pid = fork(); int pid = fork();
if (pid == 0) { if (pid == 0) {
Carrier* carrier = Carrier* carrier = GlobalMap<int64_t, Carrier>::Create(0, 0);
FleetExecutor::CreateCarrier(0, interceptor_id_to_rank);
auto msg_bus = std::make_shared<MessageBus>(); auto msg_bus = std::make_shared<MessageBus>();
carrier->SetMsgBus(msg_bus); carrier->SetMsgBus(msg_bus);
// NOTE: need Init msg_bus after carrier SetMsgBus // NOTE: need Init msg_bus after carrier SetMsgBus
carrier->Init(0, interceptor_id_to_rank, {0});
msg_bus->Init(0, {{0, ip0}, {1, ip1}}, ip0); msg_bus->Init(0, {{0, ip0}, {1, ip1}}, ip0);
carrier->SetMsgBus(msg_bus);
Interceptor* a = carrier->SetInterceptor( Interceptor* a = carrier->SetInterceptor(
0, InterceptorFactory::Create("PingPong", 0, nullptr)); 0, InterceptorFactory::Create("PingPong", 0, nullptr));
carrier->Barrier(); msg_bus->Barrier();
InterceptorMessage msg; InterceptorMessage msg;
a->Send(1, msg); a->Send(1, msg);
carrier->Wait(); carrier->Wait();
} else { } else {
Carrier* carrier = Carrier* carrier = GlobalMap<int64_t, Carrier>::Create(0, 0);
FleetExecutor::CreateCarrier(1, interceptor_id_to_rank);
auto msg_bus = std::make_shared<MessageBus>(); auto msg_bus = std::make_shared<MessageBus>();
carrier->SetMsgBus(msg_bus); carrier->SetMsgBus(msg_bus);
carrier->Init(1, interceptor_id_to_rank, {1});
msg_bus->Init(1, {{0, ip0}, {1, ip1}}, ip1); msg_bus->Init(1, {{0, ip0}, {1, ip1}}, ip1);
carrier->SetInterceptor( carrier->SetInterceptor(1,
1, InterceptorFactory::Create("PingPong", 1, nullptr)); InterceptorFactory::Create("PingPong", 1, nullptr));
carrier->Barrier(); msg_bus->Barrier();
carrier->Wait(); carrier->Wait();
int status;
int ret = waitpid(pid, &status, 0);
CHECK_EQ(ret, pid);
}
} else {
int status;
int ret = waitpid(exe_pid, &status, 0);
CHECK_EQ(ret, exe_pid);
} }
} }
......
...@@ -18,6 +18,7 @@ limitations under the License. */ ...@@ -18,6 +18,7 @@ limitations under the License. */
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h" #include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global_map.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h" #include "paddle/fluid/distributed/fleet_executor/task_node.h"
...@@ -51,10 +52,12 @@ void LinkNodes(const std::vector<TaskNode*>& nodes) { ...@@ -51,10 +52,12 @@ void LinkNodes(const std::vector<TaskNode*>& nodes) {
} }
TEST(AmplifierInterceptor, Amplifier) { TEST(AmplifierInterceptor, Amplifier) {
Carrier carrier(0, {{0, 0}, {1, 0}, {2, 0}, {3, 0}, {4, 0}, {5, 0}}); Carrier* carrier = GlobalMap<int64_t, Carrier>::Create(0, 0);
carrier->Init(0, {{0, 0}, {1, 0}, {2, 0}, {3, 0}, {4, 0}, {5, 0}},
{0, 1, 2, 3, 4, 5});
auto msg_bus = std::make_shared<MessageBus>(); auto msg_bus = std::make_shared<MessageBus>();
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "127.0.0.0:0"); msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "127.0.0.0:0");
carrier.SetMsgBus(msg_bus); carrier->SetMsgBus(msg_bus);
int64_t micro_steps = 3; int64_t micro_steps = 3;
...@@ -73,21 +76,23 @@ TEST(AmplifierInterceptor, Amplifier) { ...@@ -73,21 +76,23 @@ TEST(AmplifierInterceptor, Amplifier) {
node_b->SetReplyUpPerSteps(micro_steps); node_b->SetReplyUpPerSteps(micro_steps);
node_e->SetSendDownPerSteps(micro_steps); node_e->SetSendDownPerSteps(micro_steps);
carrier.SetInterceptor(0, InterceptorFactory::Create("Compute", 0, node_a)); carrier->SetInterceptor(0, InterceptorFactory::Create("Compute", 0, node_a));
carrier.SetInterceptor(1, InterceptorFactory::Create("Amplifier", 1, node_b)); carrier->SetInterceptor(1,
carrier.SetInterceptor(2, InterceptorFactory::Create("Compute", 2, node_c)); InterceptorFactory::Create("Amplifier", 1, node_b));
carrier.SetInterceptor(3, InterceptorFactory::Create("Compute", 3, node_d)); carrier->SetInterceptor(2, InterceptorFactory::Create("Compute", 2, node_c));
carrier.SetInterceptor(4, InterceptorFactory::Create("Amplifier", 4, node_e)); carrier->SetInterceptor(3, InterceptorFactory::Create("Compute", 3, node_d));
carrier.SetInterceptor(5, InterceptorFactory::Create("Compute", 5, node_f)); carrier->SetInterceptor(4,
InterceptorFactory::Create("Amplifier", 4, node_e));
carrier->SetInterceptor(5, InterceptorFactory::Create("Compute", 5, node_f));
// start // start
InterceptorMessage msg; InterceptorMessage msg;
msg.set_message_type(DATA_IS_READY); msg.set_message_type(DATA_IS_READY);
msg.set_src_id(-1); msg.set_src_id(-1);
msg.set_dst_id(0); msg.set_dst_id(0);
carrier.EnqueueInterceptorMessage(msg); carrier->EnqueueInterceptorMessage(msg);
carrier.Wait(); carrier->Wait();
carrier.Release(); carrier->Release();
} }
} // namespace distributed } // namespace distributed
......
...@@ -18,6 +18,7 @@ limitations under the License. */ ...@@ -18,6 +18,7 @@ limitations under the License. */
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h" #include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global_map.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h" #include "paddle/fluid/distributed/fleet_executor/task_node.h"
...@@ -69,10 +70,11 @@ void LinkNodes(const std::vector<TaskNode*>& nodes, ...@@ -69,10 +70,11 @@ void LinkNodes(const std::vector<TaskNode*>& nodes,
} }
TEST(AmplifierInterceptor, Amplifier) { TEST(AmplifierInterceptor, Amplifier) {
Carrier carrier(0, {{0, 0}, {1, 0}, {2, 0}, {3, 0}}); Carrier* carrier = GlobalMap<int64_t, Carrier>::Create(0, 0);
carrier->Init(0, {{0, 0}, {1, 0}, {2, 0}, {3, 0}}, {0, 1, 2, 3});
auto msg_bus = std::make_shared<MessageBus>(); auto msg_bus = std::make_shared<MessageBus>();
msg_bus->Init(0, {{0, ""}}, ""); msg_bus->Init(0, {{0, ""}}, "");
carrier.SetMsgBus(msg_bus); carrier->SetMsgBus(msg_bus);
int64_t micro_steps = 6; int64_t micro_steps = 6;
...@@ -91,19 +93,21 @@ TEST(AmplifierInterceptor, Amplifier) { ...@@ -91,19 +93,21 @@ TEST(AmplifierInterceptor, Amplifier) {
node_d->SetRunPerSteps(micro_steps); node_d->SetRunPerSteps(micro_steps);
node_d->SetRunAtOffset(micro_steps - 1); node_d->SetRunAtOffset(micro_steps - 1);
carrier.SetInterceptor(0, InterceptorFactory::Create("Amplifier", 0, node_a)); carrier->SetInterceptor(0,
carrier.SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b)); InterceptorFactory::Create("Amplifier", 0, node_a));
carrier.SetInterceptor(2, InterceptorFactory::Create("Compute", 2, node_c)); carrier->SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b));
carrier.SetInterceptor(3, InterceptorFactory::Create("Amplifier", 3, node_d)); carrier->SetInterceptor(2, InterceptorFactory::Create("Compute", 2, node_c));
carrier->SetInterceptor(3,
InterceptorFactory::Create("Amplifier", 3, node_d));
// start // start
InterceptorMessage msg; InterceptorMessage msg;
msg.set_message_type(DATA_IS_READY); msg.set_message_type(DATA_IS_READY);
msg.set_src_id(-1); msg.set_src_id(-1);
msg.set_dst_id(0); msg.set_dst_id(0);
carrier.EnqueueInterceptorMessage(msg); carrier->EnqueueInterceptorMessage(msg);
carrier.Wait(); carrier->Wait();
carrier.Release(); carrier->Release();
} }
} // namespace distributed } // namespace distributed
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册