未验证 提交 3658405c 编写于 作者: L LiYuRio 提交者: GitHub

[Fleet Executor] Support multi carrier (#38535)

上级 2421a25a
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global_map.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor_message_service.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
......@@ -27,16 +28,32 @@ namespace distributed {
USE_INTERCEPTOR(Compute);
USE_INTERCEPTOR(Amplifier);
void Carrier::Init(int64_t rank, std::shared_ptr<RuntimeGraph> runtime_graph,
framework::Scope* root_scope,
framework::Scope* minibatch_scope,
const std::vector<framework::Scope*>& microbatch_scopes,
const platform::Place& place) {
PADDLE_ENFORCE_EQ(is_init_, false, platform::errors::AlreadyExists(
"Carrier is already init."));
void Carrier::Init(
int64_t rank,
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank,
const std::unordered_set<int64_t>& interceptor_ids) {
rank_ = rank;
runtime_graph_ = runtime_graph;
interceptor_id_to_rank_ = runtime_graph_->interceptor_id_to_rank();
interceptor_id_to_rank_ = interceptor_id_to_rank;
interceptor_ids_ = interceptor_ids;
// TODO(fleet_exe dev): thread pool
thread_num_ = 1;
thread_pool_.SetThreadNum(thread_num_);
thread_pool_.Start();
}
void Carrier::Init(
int64_t rank,
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank,
const std::unordered_set<int64_t>& interceptor_ids,
const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node,
framework::Scope* root_scope, framework::Scope* minibatch_scope,
const std::vector<framework::Scope*>& microbatch_scopes,
const platform::Place& place) {
rank_ = rank;
interceptor_id_to_rank_ = interceptor_id_to_rank;
interceptor_ids_ = interceptor_ids;
interceptor_id_to_node_ = interceptor_id_to_node;
minibatch_scope_ = minibatch_scope;
microbatch_scopes_ = microbatch_scopes;
place_ = place;
......@@ -72,8 +89,6 @@ bool Carrier::EnqueueInterceptorMessage(
return true;
}
void Carrier::Barrier() { msg_bus_->Barrier(); }
Interceptor* Carrier::GetInterceptor(int64_t interceptor_id) {
auto iter = interceptor_idx_to_interceptor_.find(interceptor_id);
PADDLE_ENFORCE_NE(iter, interceptor_idx_to_interceptor_.end(),
......@@ -100,7 +115,8 @@ void Carrier::Start() {
"Using message bus since it has not been initialized. "
"Please invoke MessageBus::Init() before using it or "
"neccessary components are not ready."));
PADDLE_ENFORCE_EQ(is_init_, true, platform::errors::PreconditionNotMet(
"Using carrier before initialized."));
for (int64_t id : source_interceptor_ids_) {
VLOG(3) << "Carrier Start is sending start to source interceptor " << id
<< ".";
......@@ -140,7 +156,9 @@ bool Carrier::Send(const InterceptorMessage& msg) {
if (src_rank == dst_rank) {
VLOG(3) << "Send a message from interceptor " << src_id
<< " to interceptor " << dst_id << ", which are in the same ranks.";
return EnqueueInterceptorMessage(msg);
int64_t carrier_id = *GlobalMap<int64_t, int64_t>::Get(dst_id);
return GlobalMap<int64_t, Carrier>::Get(carrier_id)
->EnqueueInterceptorMessage(msg);
} else {
PADDLE_ENFORCE_NOT_NULL(
msg_bus_.get(),
......@@ -174,6 +192,9 @@ Interceptor* Carrier::SetInterceptor(int64_t interceptor_id,
loop, platform::errors::Fatal("thread task loop must not null"));
interceptor->RegisterTaskLoop(loop);
// TODO(liyurui): Using struct InterceptorID replace int64_t
GlobalMap<int64_t, int64_t>::Create(interceptor_id, carrier_id_);
auto* ptr = interceptor.get();
interceptor_idx_to_interceptor_.insert(
std::make_pair(interceptor_id, std::move(interceptor)));
......@@ -199,15 +220,19 @@ static std::shared_ptr<framework::GarbageCollector> GetGC(
}
void Carrier::CreateInterceptors() {
if (runtime_graph_->interceptor_id_to_node().empty()) return;
if (interceptor_ids_.empty()) return;
auto gc = GetGC(place_);
// create each Interceptor
// no auto init since there is no config
for (const auto& item : runtime_graph_->interceptor_id_to_node()) {
int64_t interceptor_id = item.first;
TaskNode* task_node = item.second;
for (int64_t interceptor_id : interceptor_ids_) {
const auto& task_node_iter = interceptor_id_to_node_.find(interceptor_id);
PADDLE_ENFORCE_NE(
task_node_iter, interceptor_id_to_node_.end(),
platform::errors::NotFound("Can not find task node for interceptor %ld",
interceptor_id));
TaskNode* task_node = task_node_iter->second;
PADDLE_ENFORCE_LT(
task_node->run_at_offset(), task_node->run_per_steps(),
......
......@@ -45,19 +45,19 @@ class MessageBus;
class Carrier final {
public:
Carrier() = default;
Carrier(int64_t rank,
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank)
: rank_(rank), interceptor_id_to_rank_(interceptor_id_to_rank) {
thread_num_ = 1;
thread_pool_.SetThreadNum(thread_num_);
thread_pool_.Start();
}
explicit Carrier(int64_t carrier_id) : carrier_id_(carrier_id) {}
~Carrier();
void Init(int64_t rank, std::shared_ptr<RuntimeGraph> runtime_graph,
framework::Scope* root_scope, framework::Scope* minibatch_scope,
const std::vector<framework::Scope*>& microbatch_scopes,
const platform::Place& place);
void Init(int64_t rank,
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank,
const std::unordered_set<int64_t>& interceptor_ids);
void Init(
int64_t rank,
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank,
const std::unordered_set<int64_t>& interceptor_ids,
const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node,
framework::Scope* root_scope, framework::Scope* minibatch_scope,
const std::vector<framework::Scope*>& microbatch_scopes,
const platform::Place& place);
void Release();
void Wait();
......@@ -83,10 +83,9 @@ class Carrier final {
bool Send(const InterceptorMessage& msg);
void Barrier();
private:
DISABLE_COPY_AND_ASSIGN(Carrier);
Carrier() = delete;
// create each Interceptor
void CreateInterceptors();
......@@ -108,13 +107,14 @@ class Carrier final {
framework::Scope* minibatch_scope_;
paddle::platform::Place place_;
paddle::platform::DeviceContext* dev_ctx_{nullptr};
std::shared_ptr<RuntimeGraph> runtime_graph_;
std::shared_ptr<MessageBus> msg_bus_;
int64_t rank_;
int64_t carrier_id_;
std::unordered_map<int64_t, TaskNode*> interceptor_id_to_node_;
std::unordered_map<int64_t, int64_t> interceptor_id_to_rank_;
int thread_num_;
TaskLoopThreadPool thread_pool_;
std::unordered_set<int64_t> interceptor_ids_;
};
} // namespace distributed
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/global_map.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/runtime_graph.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
......@@ -27,8 +28,6 @@
namespace paddle {
namespace distributed {
std::unique_ptr<Carrier> FleetExecutor::carrier_;
FleetExecutor::FleetExecutor(const std::string& exe_desc_str) {
bool parse_flag = exe_desc_.ParseFromString(exe_desc_str);
PADDLE_ENFORCE(parse_flag, platform::errors::PreconditionNotMet(
......@@ -37,13 +36,9 @@ FleetExecutor::FleetExecutor(const std::string& exe_desc_str) {
FleetExecutor::~FleetExecutor() {
root_scope_->DropKids();
GetCarrier()->Release();
}
Carrier* FleetExecutor::GetCarrier() {
PADDLE_ENFORCE_NOT_NULL(carrier_.get(), platform::errors::NotFound(
"Carrier has not been created."));
return carrier_.get();
for (const auto& item : runtime_graph_->carrier_id_to_interceptor_ids()) {
GlobalMap<int64_t, Carrier>::Get(item.first)->Release();
}
}
void FleetExecutor::Init(
......@@ -63,13 +58,19 @@ void FleetExecutor::Init(
auto unused_vars = framework::GetUnusedVars(program_desc.Block(0), ops, {});
runtime_graph_ = std::make_shared<RuntimeGraph>();
std::unordered_map<int64_t, TaskNode*> interceptor_id_to_task;
std::unordered_map<int64_t, std::unordered_set<int64_t>>
carrier_id_to_interceptor_ids;
std::unordered_set<int64_t> interceptor_ids;
for (auto task_node : task_nodes) {
task_node->SetUnusedVars(unused_vars);
int64_t interceptor_id = task_node->task_id();
interceptor_id_to_task.emplace(interceptor_id, task_node);
interceptor_ids.insert(interceptor_id);
}
carrier_id_to_interceptor_ids.emplace(0, interceptor_ids);
runtime_graph_->SetInterceptorIdToRank(task_id_to_rank);
runtime_graph_->SetInterceptorIdToNode(interceptor_id_to_task);
runtime_graph_->SetCarrierIdToInterceptorIds(carrier_id_to_interceptor_ids);
for (auto& unique_op : ops) {
unique_op.release();
}
......@@ -86,21 +87,26 @@ void FleetExecutor::Init(
}
VLOG(5) << runtime_graph_->DebugString();
msg_bus_ = std::make_shared<MessageBus>();
CreateCarrier();
for (const auto& item : runtime_graph_->carrier_id_to_interceptor_ids()) {
GlobalMap<int64_t, Carrier>::Create(item.first, item.first);
}
InitCarrier();
InitMessageBus();
// refine this? wait all carrier ready
// NOTE(wangxi): must add after Carrier::SetMsgBus, for we use
// MessageBus::IncreaseBarrierCount when receive barrier msg.
GetCarrier()->Barrier();
// Wait for all message bus connected.
msg_bus_->Barrier();
}
void FleetExecutor::InitCarrier() {
if (!GetCarrier()->IsInit()) {
GetCarrier()->SetMsgBus(msg_bus_);
GetCarrier()->Init(exe_desc_.cur_rank(), runtime_graph_, root_scope_,
minibatch_scope_, microbatch_scopes_, place_);
for (const auto& item : runtime_graph_->carrier_id_to_interceptor_ids()) {
Carrier* carrier = GlobalMap<int64_t, Carrier>::Get(item.first);
PADDLE_ENFORCE_NOT_NULL(carrier, platform::errors::InvalidArgument(
"Carrier has not been created."));
carrier->SetMsgBus(msg_bus_);
carrier->Init(exe_desc_.cur_rank(),
runtime_graph_->interceptor_id_to_rank(), item.second,
runtime_graph_->interceptor_id_to_node(), root_scope_,
minibatch_scope_, microbatch_scopes_, place_);
}
}
......@@ -140,14 +146,9 @@ void FleetExecutor::InitMessageBus() {
}
void FleetExecutor::Run() {
// Run
PADDLE_ENFORCE_EQ(
GetCarrier()->IsInit(), true,
platform::errors::Unavailable("Carrier has not been init yet."));
PADDLE_ENFORCE_EQ(
msg_bus_->IsInit(), true,
platform::errors::Unavailable("MessageBus has not been init yet."));
GetCarrier()->Start();
for (const auto& item : runtime_graph_->carrier_id_to_interceptor_ids()) {
GlobalMap<int64_t, Carrier>::Get(item.first)->Start();
}
for (auto* micro_scop : microbatch_scopes_) {
// By default, we should delete all kid scopes after run executor because
// some operators may create local scope when running, such as while_op.
......
......@@ -42,16 +42,6 @@ class FleetExecutor final {
const std::vector<TaskNode*>& task_nodes,
const std::unordered_map<int64_t, int64_t>& task_id_to_rank);
void Run();
// TODO(liyurui): Change to use registry table for multi-carrier.
static Carrier* GetCarrier();
template <typename... Args>
static Carrier* CreateCarrier(Args&&... args) {
PADDLE_ENFORCE_EQ(
carrier_.get(), nullptr,
platform::errors::AlreadyExists("Carrier has been created already."));
carrier_ = std::make_unique<Carrier>(std::forward<Args>(args)...);
return carrier_.get();
}
private:
DISABLE_COPY_AND_ASSIGN(FleetExecutor);
......@@ -67,7 +57,6 @@ class FleetExecutor final {
// The carriers under FleetExecutor will share message bus,
// using shared_ptr to manage lifetime and condition race.
std::shared_ptr<MessageBus> msg_bus_;
static std::unique_ptr<Carrier> carrier_;
};
} // namespace distributed
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
namespace paddle {
namespace distributed {
// Process-wide registry that maps a key to a heap-allocated value.
//
// Each (KeyT, ValueT) instantiation owns its own table and its own mutex.
// Values are created in place with Create() and looked up with Get();
// this class never removes an entry, so a pointer handed out by either
// call is not released by the registry itself.
//
// Every access to the table (lookup, existence check, construction and
// store) runs while the mutex is held, so concurrent Create()/Get() calls
// on the same id cannot observe or produce a half-updated slot.
template <typename KeyT, typename ValueT>
class GlobalMap final {
 public:
  // Returns the value registered under `id`.
  // Fails the PADDLE_ENFORCE_NOT_NULL check (NotFound) when no value has
  // been created for `id`.
  static ValueT* Get(KeyT id) {
    ValueT* item = nullptr;
    {
      std::lock_guard<std::mutex> guard(Mutex());
      auto& table = Table();
      // Use find() rather than operator[] so that probing an unknown id
      // does not leave an empty slot behind in the table.
      auto iter = table.find(id);
      if (iter != table.end()) {
        item = iter->second.get();
      }
    }
    PADDLE_ENFORCE_NOT_NULL(
        item, platform::errors::NotFound("This value is not in global map."));
    return item;
  }

  // Constructs a ValueT from `args`, registers it under `id` and returns it.
  // Fails the PADDLE_ENFORCE_EQ check (AlreadyExists) when `id` already
  // holds a value.
  //
  // The check, the construction and the store happen under one lock, so two
  // concurrent calls for the same id cannot both pass the check.
  // NOTE: ValueT's constructor runs while the (non-recursive) mutex is held
  // and therefore must not call back into the same GlobalMap<KeyT, ValueT>.
  template <typename... Args>
  static ValueT* Create(KeyT id, Args&&... args) {
    std::lock_guard<std::mutex> guard(Mutex());
    std::unique_ptr<ValueT>& slot = Table()[id];
    PADDLE_ENFORCE_EQ(slot.get(), nullptr,
                      platform::errors::AlreadyExists(
                          "This value has already in global map."));
    slot.reset(new ValueT(std::forward<Args>(args)...));
    return slot.get();
  }

 private:
  // Mutex guarding Table(). Callers always touch Mutex() before Table(), so
  // the mutex is constructed first and destroyed last.
  static std::mutex& Mutex() {
    static std::mutex mutex;
    return mutex;
  }

  // Storage for this (KeyT, ValueT) instantiation.
  static std::unordered_map<KeyT, std::unique_ptr<ValueT>>& Table() {
    static std::unordered_map<KeyT, std::unique_ptr<ValueT>> id_to_ptr;
    return id_to_ptr;
  }
};
} // namespace distributed
} // namespace paddle
......@@ -16,7 +16,7 @@
#include "paddle/fluid/distributed/fleet_executor/interceptor_message_service.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/global_map.h"
namespace paddle {
namespace distributed {
......@@ -29,7 +29,15 @@ void InterceptorMessageServiceImpl::InterceptorMessageService(
VLOG(3) << "Interceptor Message Service receives a message from interceptor "
<< request->src_id() << " to interceptor " << request->dst_id()
<< ", with the message: " << request->message_type();
bool flag = FleetExecutor::GetCarrier()->EnqueueInterceptorMessage(*request);
// TODO(liyurui): Remove this hard code.
int64_t carrier_id;
if (request->ctrl_message()) {
carrier_id = 0;
} else {
carrier_id = *GlobalMap<int64_t, int64_t>::Get(request->dst_id());
}
bool flag = GlobalMap<int64_t, Carrier>::Get(carrier_id)
->EnqueueInterceptorMessage(*request);
response->set_rst(flag);
}
......
......@@ -35,6 +35,10 @@ class RuntimeGraph final {
// Read-only view of the interceptor-id -> rank map stored in this graph.
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank() const {
return interceptor_id_to_rank_;
}
// Read-only view of the carrier-id -> set-of-interceptor-ids map stored in
// this graph.
const std::unordered_map<int64_t, std::unordered_set<int64_t>>&
carrier_id_to_interceptor_ids() const {
return carrier_id_to_interceptor_ids_;
}
void SetInterceptorIdToRank(
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank) {
interceptor_id_to_rank_ = interceptor_id_to_rank;
......@@ -43,12 +47,19 @@ class RuntimeGraph final {
const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node) {
interceptor_id_to_node_ = interceptor_id_to_node;
}
// Replaces the stored carrier-id -> set-of-interceptor-ids map with a copy
// of `carrier_id_to_interceptor_ids`.
void SetCarrierIdToInterceptorIds(
const std::unordered_map<int64_t, std::unordered_set<int64_t>>&
carrier_id_to_interceptor_ids) {
carrier_id_to_interceptor_ids_ = carrier_id_to_interceptor_ids;
}
std::string DebugString() const;
private:
DISABLE_COPY_AND_ASSIGN(RuntimeGraph);
std::unordered_map<int64_t, TaskNode*> interceptor_id_to_node_;
std::unordered_map<int64_t, int64_t> interceptor_id_to_rank_;
std::unordered_map<int64_t, std::unordered_set<int64_t>>
carrier_id_to_interceptor_ids_;
};
} // namespace distributed
......
......@@ -13,6 +13,9 @@ cc_test(interceptor_pipeline_long_path_test SRCS interceptor_pipeline_long_path_
set_source_files_properties(compute_interceptor_run_op_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(compute_interceptor_run_op_test SRCS compute_interceptor_run_op_test.cc DEPS fleet_executor ${BRPC_DEPS} op_registry fill_constant_op elementwise_add_op scope device_context)
set_source_files_properties(interceptor_pass_the_parcel_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(interceptor_pass_the_parcel_test SRCS interceptor_pass_the_parcel_test.cc DEPS fleet_executor ${BRPC_DEPS})
if(WITH_DISTRIBUTE AND WITH_PSCORE AND NOT (WITH_ASCEND OR WITH_ASCEND_CL))
set_source_files_properties(interceptor_ping_pong_with_brpc_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(interceptor_ping_pong_with_brpc_test SRCS interceptor_ping_pong_with_brpc_test.cc DEPS fleet_executor ${BRPC_DEPS})
......
......@@ -18,7 +18,7 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/global_map.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
......@@ -62,11 +62,12 @@ TEST(ComputeInterceptor, Compute) {
std::vector<framework::Scope*> scopes = {scope, scope};
platform::Place place = platform::CPUPlace();
Carrier carrier(0, {{0, 0}, {1, 0}});
Carrier* carrier = GlobalMap<int64_t, Carrier>::Create(0, 0);
carrier->Init(0, {{0, 0}, {1, 0}}, {0, 1});
auto msg_bus = std::make_shared<MessageBus>();
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");
carrier.SetMsgBus(msg_bus);
carrier->SetMsgBus(msg_bus);
// FIXME: don't delete, otherwise interceptor will use undefined node
TaskNode* node_a =
......@@ -77,9 +78,9 @@ TEST(ComputeInterceptor, Compute) {
node_a->AddDownstreamTask(1);
node_b->AddUpstreamTask(0);
auto* a = carrier.SetInterceptor(
auto* a = carrier->SetInterceptor(
0, InterceptorFactory::Create("Compute", 0, node_a));
carrier.SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b));
carrier->SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b));
a->SetPlace(place);
a->SetMicroBatchScope(scopes);
......@@ -89,10 +90,10 @@ TEST(ComputeInterceptor, Compute) {
msg.set_message_type(DATA_IS_READY);
msg.set_src_id(-1);
msg.set_dst_id(0);
carrier.EnqueueInterceptorMessage(msg);
carrier->EnqueueInterceptorMessage(msg);
carrier.Wait();
carrier.Release();
carrier->Wait();
carrier->Release();
}
} // namespace distributed
......
......@@ -18,7 +18,7 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/global_map.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
......@@ -47,11 +47,12 @@ class StartInterceptor : public Interceptor {
};
TEST(ComputeInterceptor, Compute) {
Carrier carrier(0, {{0, 0}, {1, 0}, {2, 0}});
Carrier* carrier = GlobalMap<int64_t, Carrier>::Create(0, 0);
carrier->Init(0, {{0, 0}, {1, 0}, {2, 0}}, {0, 1, 2});
auto msg_bus = std::make_shared<MessageBus>();
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");
carrier.SetMsgBus(msg_bus);
carrier->SetMsgBus(msg_bus);
// NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0); // role, rank, task_id
......@@ -65,9 +66,9 @@ TEST(ComputeInterceptor, Compute) {
node_c->AddUpstreamTask(1);
Interceptor* a =
carrier.SetInterceptor(0, std::make_unique<StartInterceptor>(0, node_a));
carrier.SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b));
carrier.SetInterceptor(2, InterceptorFactory::Create("Compute", 2, node_c));
carrier->SetInterceptor(0, std::make_unique<StartInterceptor>(0, node_a));
carrier->SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b));
carrier->SetInterceptor(2, InterceptorFactory::Create("Compute", 2, node_c));
InterceptorMessage msg;
msg.set_message_type(DATA_IS_READY);
......@@ -76,8 +77,8 @@ TEST(ComputeInterceptor, Compute) {
a->Send(1, msg);
a->Send(1, msg);
carrier.Wait();
carrier.Release();
carrier->Wait();
carrier->Release();
}
} // namespace distributed
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <unordered_map>
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global_map.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
namespace paddle {
namespace distributed {
// Test interceptor that forwards every non-STOP message to the next
// interceptor id (0 -> 1 -> 2 -> 3 -> 0). Once interceptor 0 has counted
// five received messages it sends STOP to ids 0..3 and stops the carrier.
class ParcelInterceptor : public Interceptor {
 public:
  ParcelInterceptor(int64_t interceptor_id, TaskNode* node)
      : Interceptor(interceptor_id, node) {
    // Route every incoming message through PassParcel.
    RegisterMsgHandle(
        [this](const InterceptorMessage& incoming) { PassParcel(incoming); });
  }

  // Handles one message: either marks this interceptor as stopped, ends the
  // whole game, or hands the parcel on to the next interceptor.
  void PassParcel(const InterceptorMessage& msg) {
    if (msg.message_type() == STOP) {
      stop_ = true;
      return;
    }

    std::cout << GetInterceptorId() << " recv msg, count=" << count_
              << std::endl;

    const bool game_over = (count_ == 5) && (interceptor_id_ == 0);
    if (game_over) {
      // Tell every participant (ids 0..3, in that order) to stop.
      InterceptorMessage stop;
      stop.set_message_type(STOP);
      for (int64_t target = 0; target <= 3; ++target) {
        Send(target, stop);
      }
      StopCarrier();
      return;
    }

    ++count_;
    InterceptorMessage new_msg;
    // The receiver with id 3 wraps around to id 0; everyone else passes the
    // parcel to the id that follows the message's destination.
    const int64_t next_id = (msg.dst_id() == 3) ? 0 : msg.dst_id() + 1;
    Send(next_id, new_msg);
  }

 private:
  // Number of non-STOP messages this interceptor has forwarded so far.
  int count_{0};
};
// Registers ParcelInterceptor under the name "Parcel"; presumably this is
// what lets InterceptorFactory::Create("Parcel", ...) build it below.
REGISTER_INTERCEPTOR(Parcel, ParcelInterceptor);
// Four carriers (ids 0..3) share one message bus on rank 0, each hosting a
// single "Parcel" interceptor with the same id as its carrier. Interceptor 0
// starts the relay by sending a message to interceptor 1, and the test then
// waits on carrier 0.
TEST(InterceptorTest, PassTheParcel) {
  constexpr int64_t kCarrierNum = 4;
  // All four interceptors live on rank 0.
  const std::unordered_map<int64_t, int64_t> interceptor_id_to_rank = {
      {0, 0}, {1, 0}, {2, 0}, {3, 0}};

  auto msg_bus = std::make_shared<MessageBus>();

  // Create, initialise and wire up the carriers one after another.
  Carrier* carriers[kCarrierNum];
  for (int64_t id = 0; id < kCarrierNum; ++id) {
    carriers[id] = GlobalMap<int64_t, Carrier>::Create(id, id);
    carriers[id]->Init(0, interceptor_id_to_rank, {id});
    carriers[id]->SetMsgBus(msg_bus);
  }

  msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");

  // Interceptor i is hosted by carrier i; keep a handle on interceptor 0.
  Interceptor* starter = carriers[0]->SetInterceptor(
      0, InterceptorFactory::Create("Parcel", 0, nullptr));
  for (int64_t id = 1; id < kCarrierNum; ++id) {
    carriers[id]->SetInterceptor(
        id, InterceptorFactory::Create("Parcel", id, nullptr));
  }

  InterceptorMessage msg;
  starter->Send(1, msg);

  carriers[0]->Wait();
}
} // namespace distributed
} // namespace paddle
......@@ -18,6 +18,7 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global_map.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
......@@ -59,20 +60,21 @@ class PingPongInterceptor : public Interceptor {
REGISTER_INTERCEPTOR(PingPong, PingPongInterceptor);
TEST(InterceptorTest, PingPong) {
Carrier carrier(0, {{0, 0}, {1, 0}});
Carrier* carrier = GlobalMap<int64_t, Carrier>::Create(0, 0);
carrier->Init(0, {{0, 0}, {1, 0}}, {0, 1});
auto msg_bus = std::make_shared<MessageBus>();
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");
carrier.SetMsgBus(msg_bus);
carrier->SetMsgBus(msg_bus);
Interceptor* a = carrier.SetInterceptor(
Interceptor* a = carrier->SetInterceptor(
0, InterceptorFactory::Create("PingPong", 0, nullptr));
carrier.SetInterceptor(1, std::make_unique<PingPongInterceptor>(1, nullptr));
carrier->SetInterceptor(1, std::make_unique<PingPongInterceptor>(1, nullptr));
InterceptorMessage msg;
a->Send(1, msg);
carrier.Wait();
carrier->Wait();
}
} // namespace distributed
......
......@@ -20,7 +20,7 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/global_map.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
......@@ -107,42 +107,31 @@ TEST(InterceptorTest, PingPong) {
std::unordered_map<int64_t, int64_t> interceptor_id_to_rank = {{0, 0},
{1, 1}};
int exe_pid = fork();
if (exe_pid == 0) {
int pid = fork();
if (pid == 0) {
Carrier* carrier =
FleetExecutor::CreateCarrier(0, interceptor_id_to_rank);
auto msg_bus = std::make_shared<MessageBus>();
carrier->SetMsgBus(msg_bus);
// NOTE: need Init msg_bus after carrier SetMsgBus
msg_bus->Init(0, {{0, ip0}, {1, ip1}}, ip0);
Interceptor* a = carrier->SetInterceptor(
0, InterceptorFactory::Create("PingPong", 0, nullptr));
carrier->Barrier();
InterceptorMessage msg;
a->Send(1, msg);
carrier->Wait();
} else {
Carrier* carrier =
FleetExecutor::CreateCarrier(1, interceptor_id_to_rank);
auto msg_bus = std::make_shared<MessageBus>();
carrier->SetMsgBus(msg_bus);
msg_bus->Init(1, {{0, ip0}, {1, ip1}}, ip1);
carrier->SetInterceptor(
1, InterceptorFactory::Create("PingPong", 1, nullptr));
carrier->Barrier();
carrier->Wait();
int status;
int ret = waitpid(pid, &status, 0);
CHECK_EQ(ret, pid);
}
int pid = fork();
if (pid == 0) {
Carrier* carrier = GlobalMap<int64_t, Carrier>::Create(0, 0);
auto msg_bus = std::make_shared<MessageBus>();
carrier->SetMsgBus(msg_bus);
// NOTE: need Init msg_bus after carrier SetMsgBus
carrier->Init(0, interceptor_id_to_rank, {0});
msg_bus->Init(0, {{0, ip0}, {1, ip1}}, ip0);
carrier->SetMsgBus(msg_bus);
Interceptor* a = carrier->SetInterceptor(
0, InterceptorFactory::Create("PingPong", 0, nullptr));
msg_bus->Barrier();
InterceptorMessage msg;
a->Send(1, msg);
carrier->Wait();
} else {
int status;
int ret = waitpid(exe_pid, &status, 0);
CHECK_EQ(ret, exe_pid);
Carrier* carrier = GlobalMap<int64_t, Carrier>::Create(0, 0);
auto msg_bus = std::make_shared<MessageBus>();
carrier->SetMsgBus(msg_bus);
carrier->Init(1, interceptor_id_to_rank, {1});
msg_bus->Init(1, {{0, ip0}, {1, ip1}}, ip1);
carrier->SetInterceptor(1,
InterceptorFactory::Create("PingPong", 1, nullptr));
msg_bus->Barrier();
carrier->Wait();
}
}
......
......@@ -18,6 +18,7 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global_map.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
......@@ -51,10 +52,12 @@ void LinkNodes(const std::vector<TaskNode*>& nodes) {
}
TEST(AmplifierInterceptor, Amplifier) {
Carrier carrier(0, {{0, 0}, {1, 0}, {2, 0}, {3, 0}, {4, 0}, {5, 0}});
Carrier* carrier = GlobalMap<int64_t, Carrier>::Create(0, 0);
carrier->Init(0, {{0, 0}, {1, 0}, {2, 0}, {3, 0}, {4, 0}, {5, 0}},
{0, 1, 2, 3, 4, 5});
auto msg_bus = std::make_shared<MessageBus>();
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "127.0.0.0:0");
carrier.SetMsgBus(msg_bus);
carrier->SetMsgBus(msg_bus);
int64_t micro_steps = 3;
......@@ -73,21 +76,23 @@ TEST(AmplifierInterceptor, Amplifier) {
node_b->SetReplyUpPerSteps(micro_steps);
node_e->SetSendDownPerSteps(micro_steps);
carrier.SetInterceptor(0, InterceptorFactory::Create("Compute", 0, node_a));
carrier.SetInterceptor(1, InterceptorFactory::Create("Amplifier", 1, node_b));
carrier.SetInterceptor(2, InterceptorFactory::Create("Compute", 2, node_c));
carrier.SetInterceptor(3, InterceptorFactory::Create("Compute", 3, node_d));
carrier.SetInterceptor(4, InterceptorFactory::Create("Amplifier", 4, node_e));
carrier.SetInterceptor(5, InterceptorFactory::Create("Compute", 5, node_f));
carrier->SetInterceptor(0, InterceptorFactory::Create("Compute", 0, node_a));
carrier->SetInterceptor(1,
InterceptorFactory::Create("Amplifier", 1, node_b));
carrier->SetInterceptor(2, InterceptorFactory::Create("Compute", 2, node_c));
carrier->SetInterceptor(3, InterceptorFactory::Create("Compute", 3, node_d));
carrier->SetInterceptor(4,
InterceptorFactory::Create("Amplifier", 4, node_e));
carrier->SetInterceptor(5, InterceptorFactory::Create("Compute", 5, node_f));
// start
InterceptorMessage msg;
msg.set_message_type(DATA_IS_READY);
msg.set_src_id(-1);
msg.set_dst_id(0);
carrier.EnqueueInterceptorMessage(msg);
carrier.Wait();
carrier.Release();
carrier->EnqueueInterceptorMessage(msg);
carrier->Wait();
carrier->Release();
}
} // namespace distributed
......
......@@ -18,6 +18,7 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global_map.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
......@@ -69,10 +70,11 @@ void LinkNodes(const std::vector<TaskNode*>& nodes,
}
TEST(AmplifierInterceptor, Amplifier) {
Carrier carrier(0, {{0, 0}, {1, 0}, {2, 0}, {3, 0}});
Carrier* carrier = GlobalMap<int64_t, Carrier>::Create(0, 0);
carrier->Init(0, {{0, 0}, {1, 0}, {2, 0}, {3, 0}}, {0, 1, 2, 3});
auto msg_bus = std::make_shared<MessageBus>();
msg_bus->Init(0, {{0, ""}}, "");
carrier.SetMsgBus(msg_bus);
carrier->SetMsgBus(msg_bus);
int64_t micro_steps = 6;
......@@ -91,19 +93,21 @@ TEST(AmplifierInterceptor, Amplifier) {
node_d->SetRunPerSteps(micro_steps);
node_d->SetRunAtOffset(micro_steps - 1);
carrier.SetInterceptor(0, InterceptorFactory::Create("Amplifier", 0, node_a));
carrier.SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b));
carrier.SetInterceptor(2, InterceptorFactory::Create("Compute", 2, node_c));
carrier.SetInterceptor(3, InterceptorFactory::Create("Amplifier", 3, node_d));
carrier->SetInterceptor(0,
InterceptorFactory::Create("Amplifier", 0, node_a));
carrier->SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b));
carrier->SetInterceptor(2, InterceptorFactory::Create("Compute", 2, node_c));
carrier->SetInterceptor(3,
InterceptorFactory::Create("Amplifier", 3, node_d));
// start
InterceptorMessage msg;
msg.set_message_type(DATA_IS_READY);
msg.set_src_id(-1);
msg.set_dst_id(0);
carrier.EnqueueInterceptorMessage(msg);
carrier.Wait();
carrier.Release();
carrier->EnqueueInterceptorMessage(msg);
carrier->Wait();
carrier->Release();
}
} // namespace distributed
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册