未验证 提交 ddc15a18 编写于 作者: L LiYuRio 提交者: GitHub

[fleet_executor] Move IntraSend to Carrier. Using blocking queue (#38322)

上级 142ea171
...@@ -5,7 +5,7 @@ endif() ...@@ -5,7 +5,7 @@ endif()
proto_library(interceptor_message_proto SRCS interceptor_message.proto) proto_library(interceptor_message_proto SRCS interceptor_message.proto)
if(WITH_DISTRIBUTE AND WITH_PSCORE AND NOT (WITH_ASCEND OR WITH_ASCEND_CL)) if(WITH_DISTRIBUTE AND WITH_PSCORE AND NOT (WITH_ASCEND OR WITH_ASCEND_CL))
set(BRPC_DEPS brpc ssl crypto protobuf gflags glog zlib leveldb snappy gflags glog) set(BRPC_DEPS brpc ssl crypto protobuf zlib leveldb snappy gflags glog)
else() else()
set(BRPC_DEPS "") set(BRPC_DEPS "")
endif() endif()
...@@ -13,7 +13,7 @@ endif() ...@@ -13,7 +13,7 @@ endif()
cc_library(fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime_graph.cc cc_library(fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime_graph.cc
interceptor.cc compute_interceptor.cc amplifier_interceptor.cc interceptor_message_service.cc message_bus.cc interceptor.cc compute_interceptor.cc amplifier_interceptor.cc interceptor_message_service.cc message_bus.cc
DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto collective_helper op_registry DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto collective_helper op_registry
executor_gc_helper ${BRPC_DEPS}) executor_gc_helper gflags glog ${BRPC_DEPS})
if(WITH_DISTRIBUTE) if(WITH_DISTRIBUTE)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
......
...@@ -27,14 +27,16 @@ namespace distributed { ...@@ -27,14 +27,16 @@ namespace distributed {
USE_INTERCEPTOR(Compute); USE_INTERCEPTOR(Compute);
USE_INTERCEPTOR(Amplifier); USE_INTERCEPTOR(Amplifier);
void Carrier::Init(std::shared_ptr<RuntimeGraph> runtime_graph, void Carrier::Init(int64_t rank, std::shared_ptr<RuntimeGraph> runtime_graph,
framework::Scope* root_scope, framework::Scope* root_scope,
framework::Scope* minibatch_scope, framework::Scope* minibatch_scope,
const std::vector<framework::Scope*>& microbatch_scopes, const std::vector<framework::Scope*>& microbatch_scopes,
const platform::Place& place) { const platform::Place& place) {
PADDLE_ENFORCE_EQ(is_init_, false, platform::errors::AlreadyExists( PADDLE_ENFORCE_EQ(is_init_, false, platform::errors::AlreadyExists(
"Carrier is already init.")); "Carrier is already init."));
rank_ = rank;
runtime_graph_ = runtime_graph; runtime_graph_ = runtime_graph;
interceptor_id_to_rank_ = runtime_graph_->interceptor_id_to_rank();
minibatch_scope_ = minibatch_scope; minibatch_scope_ = minibatch_scope;
microbatch_scopes_ = microbatch_scopes; microbatch_scopes_ = microbatch_scopes;
place_ = place; place_ = place;
...@@ -48,12 +50,6 @@ void Carrier::Release() { ...@@ -48,12 +50,6 @@ void Carrier::Release() {
// NOTE(wangxi): must join before `Derived Interceptor` destruct, // NOTE(wangxi): must join before `Derived Interceptor` destruct,
// otherwise Derived object will be destructed before thread complete. // otherwise Derived object will be destructed before thread complete.
// Sending STOP msg to the source interceptor
PADDLE_ENFORCE_EQ(msg_bus_->IsInit(), true,
platform::errors::PreconditionNotMet(
"Using message bus since it has not been initialized. "
"Please invoke MessageBus::Init() before using it or "
"neccessary components are not ready."));
for (int64_t id : source_interceptor_ids_) { for (int64_t id : source_interceptor_ids_) {
VLOG(3) << "Carrier Release is sending stop to source interceptor " << id VLOG(3) << "Carrier Release is sending stop to source interceptor " << id
<< "."; << ".";
...@@ -75,10 +71,10 @@ Carrier::~Carrier() { VLOG(3) << "Carrier's destructor."; } ...@@ -75,10 +71,10 @@ Carrier::~Carrier() { VLOG(3) << "Carrier's destructor."; }
bool Carrier::EnqueueInterceptorMessage( bool Carrier::EnqueueInterceptorMessage(
const InterceptorMessage& interceptor_message) { const InterceptorMessage& interceptor_message) {
// enqueue message to interceptor
if (interceptor_message.ctrl_message()) { if (interceptor_message.ctrl_message()) {
// handle control message VLOG(3) << "Receiving control message from rank "
return true; << interceptor_message.src_id() << " to rank "
<< interceptor_message.dst_id();
} else { } else {
{ {
std::unique_lock<std::mutex> lock_creating(creating_flag_mutex_); std::unique_lock<std::mutex> lock_creating(creating_flag_mutex_);
...@@ -93,15 +89,9 @@ bool Carrier::EnqueueInterceptorMessage( ...@@ -93,15 +89,9 @@ bool Carrier::EnqueueInterceptorMessage(
} }
int64_t dst_id = interceptor_message.dst_id(); int64_t dst_id = interceptor_message.dst_id();
Interceptor* dst_interceptor = GetInterceptor(dst_id); Interceptor* dst_interceptor = GetInterceptor(dst_id);
bool rst = dst_interceptor->EnqueueRemoteInterceptorMessage(interceptor_message);
dst_interceptor->EnqueueRemoteInterceptorMessage(interceptor_message);
if (rst) {
std::condition_variable& interceptor_cond_var =
dst_interceptor->GetCondVar();
interceptor_cond_var.notify_all();
}
return rst;
} }
return true;
} }
Interceptor* Carrier::GetInterceptor(int64_t interceptor_id) { Interceptor* Carrier::GetInterceptor(int64_t interceptor_id) {
...@@ -144,9 +134,44 @@ std::condition_variable& Carrier::GetCondVar() { return cond_var_; } ...@@ -144,9 +134,44 @@ std::condition_variable& Carrier::GetCondVar() { return cond_var_; }
bool Carrier::IsInit() const { return is_init_; } bool Carrier::IsInit() const { return is_init_; }
// TODO(liyurui): Move SendIntra into carrier int64_t Carrier::GetRank(int64_t interceptor_id) const {
bool Carrier::Send(const InterceptorMessage& msg) const { PADDLE_ENFORCE_NE(
return msg_bus_->Send(msg); interceptor_id_to_rank_.find(interceptor_id),
interceptor_id_to_rank_.end(),
platform::errors::NotFound("Cannot find rank for interceptor id %lld.",
interceptor_id));
return interceptor_id_to_rank_.at(interceptor_id);
}
bool Carrier::Send(const InterceptorMessage& msg) {
int64_t src_id = (msg.src_id() == -1) ? msg.dst_id() : msg.src_id();
int64_t dst_id = msg.dst_id();
int64_t src_rank = GetRank(src_id);
int64_t dst_rank = GetRank(dst_id);
PADDLE_ENFORCE_EQ(
src_rank, rank_,
platform::errors::Fatal("The source rank id %lld, which is not equal to "
"the carrier rank id %lld.",
src_rank, rank_));
if (src_rank == dst_rank) {
VLOG(3) << "Send a message from interceptor " << src_id
<< " to interceptor " << dst_id << ", which are in the same ranks.";
return EnqueueInterceptorMessage(msg);
} else {
PADDLE_ENFORCE_NOT_NULL(
msg_bus_.get(),
platform::errors::Unavailable("Message bus is released accidently"));
PADDLE_ENFORCE_EQ(
msg_bus_->IsInit(), true,
platform::errors::PreconditionNotMet(
"Using message bus since it has not been initialized. "
"Please invoke MessageBus::Init() before using it or "
"neccessary components are not ready."));
VLOG(3) << "Send a message from interceptor " << src_id
<< " to interceptor " << dst_id
<< ", which are in different ranks.";
return msg_bus_->Send(dst_rank, msg);
}
} }
Interceptor* Carrier::SetInterceptor(int64_t interceptor_id, Interceptor* Carrier::SetInterceptor(int64_t interceptor_id,
...@@ -222,13 +247,13 @@ static std::shared_ptr<framework::GarbageCollector> GetGC( ...@@ -222,13 +247,13 @@ static std::shared_ptr<framework::GarbageCollector> GetGC(
} }
void Carrier::CreateInterceptors() { void Carrier::CreateInterceptors() {
if (runtime_graph_->intercepter_id_to_node().empty()) return; if (runtime_graph_->interceptor_id_to_node().empty()) return;
auto gc = GetGC(place_); auto gc = GetGC(place_);
// create each Interceptor // create each Interceptor
// no auto init since there is no config // no auto init since there is no config
for (const auto& item : runtime_graph_->intercepter_id_to_node()) { for (const auto& item : runtime_graph_->interceptor_id_to_node()) {
int64_t interceptor_id = item.first; int64_t interceptor_id = item.first;
TaskNode* task_node = item.second; TaskNode* task_node = item.second;
......
...@@ -45,8 +45,11 @@ class MessageBus; ...@@ -45,8 +45,11 @@ class MessageBus;
class Carrier final { class Carrier final {
public: public:
Carrier() = default; Carrier() = default;
// Constructs a Carrier whose rank and interceptor-id -> rank table are known
// up front.
//
// rank:                   stored in rank_.
// interceptor_id_to_rank: copied into the interceptor_id_to_rank_ member.
//
// No other member is touched by this constructor.
// NOTE(review): Init() also receives a rank and assigns both rank_ and
// interceptor_id_to_rank_, so values passed here appear to be overwritten
// once Init() runs -- confirm this is intended.
Carrier(int64_t rank,
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank)
: rank_(rank), interceptor_id_to_rank_(interceptor_id_to_rank) {}
~Carrier(); ~Carrier();
void Init(std::shared_ptr<RuntimeGraph> runtime_graph, void Init(int64_t rank, std::shared_ptr<RuntimeGraph> runtime_graph,
framework::Scope* root_scope, framework::Scope* minibatch_scope, framework::Scope* root_scope, framework::Scope* minibatch_scope,
const std::vector<framework::Scope*>& microbatch_scopes, const std::vector<framework::Scope*>& microbatch_scopes,
const platform::Place& place); const platform::Place& place);
...@@ -75,7 +78,7 @@ class Carrier final { ...@@ -75,7 +78,7 @@ class Carrier final {
bool IsInit() const; bool IsInit() const;
bool Send(const InterceptorMessage& msg) const; bool Send(const InterceptorMessage& msg);
// NOTE: This mutex will be used in interceptor's RunOps function. // NOTE: This mutex will be used in interceptor's RunOps function.
// This mutex is used for avoiding forward ops and backward ops run // This mutex is used for avoiding forward ops and backward ops run
...@@ -90,6 +93,8 @@ class Carrier final { ...@@ -90,6 +93,8 @@ class Carrier final {
void HandleTmpMessages(); void HandleTmpMessages();
int64_t GetRank(int64_t interceptor_id) const;
// interceptor logic id to actually interceptor // interceptor logic id to actually interceptor
std::unordered_map<int64_t, std::unique_ptr<Interceptor>> std::unordered_map<int64_t, std::unique_ptr<Interceptor>>
interceptor_idx_to_interceptor_; interceptor_idx_to_interceptor_;
...@@ -111,6 +116,7 @@ class Carrier final { ...@@ -111,6 +116,7 @@ class Carrier final {
paddle::platform::DeviceContext* dev_ctx_{nullptr}; paddle::platform::DeviceContext* dev_ctx_{nullptr};
std::shared_ptr<RuntimeGraph> runtime_graph_; std::shared_ptr<RuntimeGraph> runtime_graph_;
std::shared_ptr<MessageBus> msg_bus_; std::shared_ptr<MessageBus> msg_bus_;
int64_t rank_;
std::unordered_map<int64_t, int64_t> interceptor_id_to_rank_; std::unordered_map<int64_t, int64_t> interceptor_id_to_rank_;
}; };
......
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h" #include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/runtime_graph.h" #include "paddle/fluid/distributed/fleet_executor/runtime_graph.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h" #include "paddle/fluid/distributed/fleet_executor/task_node.h"
...@@ -28,6 +27,8 @@ ...@@ -28,6 +27,8 @@
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
std::unique_ptr<Carrier> FleetExecutor::carrier_;
FleetExecutor::FleetExecutor(const std::string& exe_desc_str) { FleetExecutor::FleetExecutor(const std::string& exe_desc_str) {
bool parse_flag = exe_desc_.ParseFromString(exe_desc_str); bool parse_flag = exe_desc_.ParseFromString(exe_desc_str);
PADDLE_ENFORCE(parse_flag, platform::errors::PreconditionNotMet( PADDLE_ENFORCE(parse_flag, platform::errors::PreconditionNotMet(
...@@ -36,12 +37,13 @@ FleetExecutor::FleetExecutor(const std::string& exe_desc_str) { ...@@ -36,12 +37,13 @@ FleetExecutor::FleetExecutor(const std::string& exe_desc_str) {
FleetExecutor::~FleetExecutor() { FleetExecutor::~FleetExecutor() {
root_scope_->DropKids(); root_scope_->DropKids();
GetCarrier().Release(); GetCarrier()->Release();
} }
Carrier& FleetExecutor::GetCarrier() { Carrier* FleetExecutor::GetCarrier() {
static Carrier carrier; PADDLE_ENFORCE_NOT_NULL(carrier_.get(), platform::errors::NotFound(
return carrier; "Carrier has not been created."));
return carrier_.get();
} }
void FleetExecutor::Init( void FleetExecutor::Init(
...@@ -84,16 +86,16 @@ void FleetExecutor::Init( ...@@ -84,16 +86,16 @@ void FleetExecutor::Init(
} }
VLOG(5) << runtime_graph_->DebugString(); VLOG(5) << runtime_graph_->DebugString();
msg_bus_ = std::make_shared<MessageBus>(); msg_bus_ = std::make_shared<MessageBus>();
CreateCarrier();
InitCarrier(); InitCarrier();
InitMessageBus(); InitMessageBus();
} }
void FleetExecutor::InitCarrier() { void FleetExecutor::InitCarrier() {
Carrier& carrier = GetCarrier(); if (!GetCarrier()->IsInit()) {
if (!carrier.IsInit()) { GetCarrier()->SetMsgBus(msg_bus_);
carrier.SetMsgBus(msg_bus_); GetCarrier()->Init(exe_desc_.cur_rank(), runtime_graph_, root_scope_,
carrier.Init(runtime_graph_, root_scope_, minibatch_scope_, minibatch_scope_, microbatch_scopes_, place_);
microbatch_scopes_, place_);
} }
} }
...@@ -128,21 +130,19 @@ void FleetExecutor::InitMessageBus() { ...@@ -128,21 +130,19 @@ void FleetExecutor::InitMessageBus() {
<< (rank_to_addr.size() == 0 ? 1 : rank_to_addr.size()) << "."; << (rank_to_addr.size() == 0 ? 1 : rank_to_addr.size()) << ".";
VLOG(5) << ss.str(); VLOG(5) << ss.str();
if (!msg_bus_->IsInit()) { if (!msg_bus_->IsInit()) {
msg_bus_->Init(runtime_graph_->intercepter_id_to_rank(), rank_to_addr, msg_bus_->Init(cur_rank, rank_to_addr, addr);
addr);
} }
} }
void FleetExecutor::Run() { void FleetExecutor::Run() {
// Run // Run
Carrier& carrier = GetCarrier();
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
carrier.IsInit(), true, GetCarrier()->IsInit(), true,
platform::errors::Unavailable("Carrier has not been init yet.")); platform::errors::Unavailable("Carrier has not been init yet."));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
msg_bus_->IsInit(), true, msg_bus_->IsInit(), true,
platform::errors::Unavailable("MessageBus has not been init yet.")); platform::errors::Unavailable("MessageBus has not been init yet."));
carrier.Start(); GetCarrier()->Start();
for (auto* micro_scop : microbatch_scopes_) { for (auto* micro_scop : microbatch_scopes_) {
// By default, we should delete all kid scopes after run executor because // By default, we should delete all kid scopes after run executor because
// some operators may create local scope when running, such as while_op. // some operators may create local scope when running, such as while_op.
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor_desc.pb.h" #include "paddle/fluid/distributed/fleet_executor/fleet_executor_desc.pb.h"
#include "paddle/fluid/platform/macros.h" #include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
...@@ -30,7 +31,6 @@ namespace distributed { ...@@ -30,7 +31,6 @@ namespace distributed {
class RuntimeGraph; class RuntimeGraph;
class MessageBus; class MessageBus;
class TaskNode; class TaskNode;
class Carrier;
class FleetExecutor final { class FleetExecutor final {
public: public:
...@@ -43,7 +43,15 @@ class FleetExecutor final { ...@@ -43,7 +43,15 @@ class FleetExecutor final {
const std::unordered_map<int64_t, int64_t>& task_id_to_rank); const std::unordered_map<int64_t, int64_t>& task_id_to_rank);
void Run(); void Run();
// TODO(liyurui): Change to use registry table for multi-carrier. // TODO(liyurui): Change to use registry table for multi-carrier.
static Carrier& GetCarrier(); static Carrier* GetCarrier();
// Creates the single Carrier held in the static carrier_ member.
//
// args: perfectly forwarded to Carrier's constructor, so any Carrier
//       constructor overload can be selected by the caller.
//
// Returns a non-owning pointer to the new Carrier; ownership stays with
// carrier_.
//
// The PADDLE_ENFORCE_EQ check requires carrier_ to be empty before creation
// and reports an AlreadyExists error otherwise, so this is expected to be
// called at most once while carrier_ holds an object.
template <typename... Args>
static Carrier* CreateCarrier(Args&&... args) {
PADDLE_ENFORCE_EQ(
carrier_.get(), nullptr,
platform::errors::AlreadyExists("Carrier has been created already."));
carrier_ = std::make_unique<Carrier>(std::forward<Args>(args)...);
return carrier_.get();
}
private: private:
DISABLE_COPY_AND_ASSIGN(FleetExecutor); DISABLE_COPY_AND_ASSIGN(FleetExecutor);
...@@ -59,6 +67,7 @@ class FleetExecutor final { ...@@ -59,6 +67,7 @@ class FleetExecutor final {
// The carriers under FleetExecutor will share message bus, // The carriers under FleetExecutor will share message bus,
// using shared_ptr to manage lifetime and condition race. // using shared_ptr to manage lifetime and condition race.
std::shared_ptr<MessageBus> msg_bus_; std::shared_ptr<MessageBus> msg_bus_;
static std::unique_ptr<Carrier> carrier_;
}; };
} // namespace distributed } // namespace distributed
......
...@@ -52,24 +52,17 @@ void Interceptor::StopCarrier() { ...@@ -52,24 +52,17 @@ void Interceptor::StopCarrier() {
cond_var.notify_all(); cond_var.notify_all();
} }
std::condition_variable& Interceptor::GetCondVar() {
// get the conditional var
return cond_var_;
}
int64_t Interceptor::GetInterceptorId() const { int64_t Interceptor::GetInterceptorId() const {
// return the interceptor id // return the interceptor id
return interceptor_id_; return interceptor_id_;
} }
bool Interceptor::EnqueueRemoteInterceptorMessage( void Interceptor::EnqueueRemoteInterceptorMessage(
const InterceptorMessage& interceptor_message) { const InterceptorMessage& interceptor_message) {
// Called by Carrier, enqueue an InterceptorMessage to remote mailbox // Called by Carrier, enqueue an InterceptorMessage to remote mailbox
VLOG(3) << "Enqueue message: " << interceptor_message.message_type() VLOG(3) << "Enqueue message: " << interceptor_message.message_type()
<< " into " << interceptor_id_ << "'s remote mailbox."; << " into " << interceptor_id_ << "'s remote mailbox.";
std::unique_lock<std::mutex> lock(remote_mailbox_mutex_); remote_mailbox_.Push(interceptor_message);
remote_mailbox_.push(interceptor_message);
return true;
} }
bool Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) { bool Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) {
...@@ -92,7 +85,7 @@ void Interceptor::PoolTheMailbox() { ...@@ -92,7 +85,7 @@ void Interceptor::PoolTheMailbox() {
"Error encountered when fetch remote mailbox.")); "Error encountered when fetch remote mailbox."));
} }
const InterceptorMessage interceptor_message = local_mailbox_.front(); const InterceptorMessage interceptor_message = local_mailbox_.front();
local_mailbox_.pop(); local_mailbox_.pop_front();
const MessageType message_type = interceptor_message.message_type(); const MessageType message_type = interceptor_message.message_type();
VLOG(3) << "Interceptor " << interceptor_id_ << " has received a message" VLOG(3) << "Interceptor " << interceptor_id_ << " has received a message"
<< " from interceptor " << interceptor_message.src_id() << " from interceptor " << interceptor_message.src_id()
...@@ -109,19 +102,8 @@ void Interceptor::PoolTheMailbox() { ...@@ -109,19 +102,8 @@ void Interceptor::PoolTheMailbox() {
} }
bool Interceptor::FetchRemoteMailbox() { bool Interceptor::FetchRemoteMailbox() {
// fetch all Message from remote mailbox to local mailbox remote_mailbox_.PopAll(&local_mailbox_);
// return true if remote mailbox not empty, otherwise return false return !local_mailbox_.empty();
std::unique_lock<std::mutex> lock(remote_mailbox_mutex_);
cond_var_.wait(lock, [this]() { return !remote_mailbox_.empty(); });
if (remote_mailbox_.empty()) {
// the thread has been unblocked accidentally
return false;
}
while (!remote_mailbox_.empty()) {
local_mailbox_.push(std::move(remote_mailbox_.front()));
remote_mailbox_.pop();
}
return true;
} }
static InterceptorFactory::CreateInterceptorMap& GetInterceptorMap() { static InterceptorFactory::CreateInterceptorMap& GetInterceptorMap() {
......
...@@ -15,14 +15,15 @@ ...@@ -15,14 +15,15 @@
#pragma once #pragma once
#include <condition_variable> #include <condition_variable>
#include <deque>
#include <functional> #include <functional>
#include <map> #include <map>
#include <memory> #include <memory>
#include <queue>
#include <thread> #include <thread>
#include <vector> #include <vector>
#include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h" #include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h"
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h" #include "paddle/fluid/platform/errors.h"
#include "paddle/fluid/platform/macros.h" #include "paddle/fluid/platform/macros.h"
...@@ -59,11 +60,8 @@ class Interceptor { ...@@ -59,11 +60,8 @@ class Interceptor {
// return the interceptor id // return the interceptor id
int64_t GetInterceptorId() const; int64_t GetInterceptorId() const;
// return the conditional var
std::condition_variable& GetCondVar();
// Called by Carrier, enqueue an InterceptorMessage to remote mailbox // Called by Carrier, enqueue an InterceptorMessage to remote mailbox
bool EnqueueRemoteInterceptorMessage( void EnqueueRemoteInterceptorMessage(
const InterceptorMessage& interceptor_message); const InterceptorMessage& interceptor_message);
bool Send(int64_t dst_id, InterceptorMessage& msg); // NOLINT bool Send(int64_t dst_id, InterceptorMessage& msg); // NOLINT
...@@ -115,23 +113,16 @@ class Interceptor { ...@@ -115,23 +113,16 @@ class Interceptor {
// interceptor handle which process message // interceptor handle which process message
MsgHandle handle_{nullptr}; MsgHandle handle_{nullptr};
// mutex to control read/write conflict for remote mailbox
std::mutex remote_mailbox_mutex_;
// interceptor runs PoolTheMailbox() function to poll local mailbox // interceptor runs PoolTheMailbox() function to poll local mailbox
std::thread interceptor_thread_; std::thread interceptor_thread_;
// conditional variable for blocking the thread when
// fetch an empty remote mailbox
std::condition_variable cond_var_;
// remote mailbox, written by EnqueueRemoteMessage() // remote mailbox, written by EnqueueRemoteMessage()
// read by FetchRemoteMailbox() // read by FetchRemoteMailbox()
std::queue<InterceptorMessage> remote_mailbox_; framework::BlockingQueue<InterceptorMessage> remote_mailbox_;
// local mailbox, written by FetchRemoteMailbox() // local mailbox, written by FetchRemoteMailbox()
// read by PoolTheMailbox() // read by PoolTheMailbox()
std::queue<InterceptorMessage> local_mailbox_; std::deque<InterceptorMessage> local_mailbox_;
int64_t already_run_times_{0}; int64_t already_run_times_{0};
int64_t used_slot_nums_{0}; int64_t used_slot_nums_{0};
......
...@@ -29,8 +29,8 @@ void InterceptorMessageServiceImpl::InterceptorMessageService( ...@@ -29,8 +29,8 @@ void InterceptorMessageServiceImpl::InterceptorMessageService(
VLOG(3) << "Interceptor Message Service receives a message from interceptor " VLOG(3) << "Interceptor Message Service receives a message from interceptor "
<< request->src_id() << " to interceptor " << request->dst_id() << request->src_id() << " to interceptor " << request->dst_id()
<< ", with the message: " << request->message_type(); << ", with the message: " << request->message_type();
FleetExecutor::GetCarrier().EnqueueInterceptorMessage(*request); bool flag = FleetExecutor::GetCarrier()->EnqueueInterceptorMessage(*request);
response->set_rst(true); response->set_rst(flag);
} }
} // namespace distributed } // namespace distributed
......
...@@ -17,8 +17,6 @@ ...@@ -17,8 +17,6 @@
#include <set> #include <set>
#include <thread> #include <thread>
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/platform/gen_comm_id_helper.h" #include "paddle/fluid/platform/gen_comm_id_helper.h"
...@@ -26,16 +24,25 @@ namespace paddle { ...@@ -26,16 +24,25 @@ namespace paddle {
namespace distributed { namespace distributed {
void MessageBus::Init( void MessageBus::Init(
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank, int64_t rank, const std::unordered_map<int64_t, std::string>& rank_to_addr,
const std::unordered_map<int64_t, std::string>& rank_to_addr,
const std::string& addr) { const std::string& addr) {
PADDLE_ENFORCE_EQ(is_init_, false, platform::errors::AlreadyExists( PADDLE_ENFORCE_EQ(is_init_, false, platform::errors::AlreadyExists(
"MessageBus is already init.")); "MessageBus is already init."));
rank_ = rank;
is_init_ = true; is_init_ = true;
interceptor_id_to_rank_ = interceptor_id_to_rank;
rank_to_addr_ = rank_to_addr; rank_to_addr_ = rank_to_addr;
addr_ = addr; addr_ = addr;
if (addr_ != "") {
const auto& addr = GetAddr(rank_);
PADDLE_ENFORCE_EQ(addr, addr_,
platform::errors::Fatal(
"The current rank's addr is %s, while the "
"message bus's addr is %s, which are different. "
"Init error.",
addr, addr_));
}
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \ #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_ASCEND_CL) defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_ASCEND_CL)
// NOTE: To make the brpc is compatible with collective, // NOTE: To make the brpc is compatible with collective,
...@@ -65,43 +72,57 @@ MessageBus::~MessageBus() { ...@@ -65,43 +72,57 @@ MessageBus::~MessageBus() {
#endif #endif
} }
bool MessageBus::Send(const InterceptorMessage& interceptor_message) { const std::string& MessageBus::GetAddr(int64_t rank) const {
// called by Interceptor, send InterceptorMessage to dst PADDLE_ENFORCE_NE(
int64_t src_id = interceptor_message.src_id(); rank_to_addr_.find(rank), rank_to_addr_.end(),
int64_t dst_id = interceptor_message.dst_id(); platform::errors::NotFound("Cannot find addr rank id %lld.", rank));
if (IsSameRank(src_id, dst_id)) { return rank_to_addr_.at(rank);
VLOG(3) << "Send a message from interceptor " << src_id }
<< " to interceptor " << dst_id << ", which are in the same ranks.";
return SendIntraRank(interceptor_message); bool MessageBus::Send(int64_t dst_rank,
} else { const InterceptorMessage& interceptor_message) {
VLOG(3) << "Send a message from interceptor " << src_id
<< " to interceptor " << dst_id
<< ", which are in different ranks.";
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
!defined(PADDLE_WITH_ASCEND_CL) !defined(PADDLE_WITH_ASCEND_CL)
int retry_time = 0; // message bus will retry sending for 10 times int retry_time = 0; // message bus will retry sending for 10 times
while (retry_time < 10) { while (retry_time < 10) {
++retry_time; ++retry_time;
if (SendInterRank(interceptor_message)) { if (SendInterRank(dst_rank, interceptor_message)) {
VLOG(3) << "Message bus sends inter rank successfully with " VLOG(3) << "Message bus sends inter rank successfully with " << retry_time
<< retry_time << " times retries."; << " times retries.";
return true; return true;
}
VLOG(3) << "Message bus sends failed, retry after 1 seconds.";
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
} }
VLOG(3) << "Message bus sends inter rank fail after 10 times retries."; VLOG(3) << "Message bus sends failed, retry after 1 seconds.";
return false; std::this_thread::sleep_for(std::chrono::milliseconds(1000));
}
VLOG(3) << "Message bus sends inter rank fail after 10 times retries.";
return false;
#else #else
PADDLE_THROW(platform::errors::Unavailable( PADDLE_THROW(platform::errors::Unavailable(
"Fleet executor does not support sending message between different " "Fleet executor does not support sending message between different "
"ranks when Paddle is compiled with npu or " "ranks when Paddle is compiled with npu or "
"isn't compiled with distributed for now.")); "isn't compiled with distributed for now."));
#endif #endif
}
return true; return true;
} }
void MessageBus::TestConnection() {
InterceptorMessage ctrl_msg;
ctrl_msg.set_ctrl_message(true);
ctrl_msg.set_src_id(rank_);
for (const auto& dst_rank_pair : rank_to_addr_) {
int64_t dst_rank = dst_rank_pair.first;
if (dst_rank != rank_) {
ctrl_msg.set_dst_id(dst_rank);
VLOG(3) << "Send control message bus from rank " << rank_ << " to rank "
<< dst_rank;
while (!Send(dst_rank, ctrl_msg)) {
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
}
VLOG(3) << "Message bus has connected to rank: " << dst_rank << ".";
}
}
}
void MessageBus::ListenPort() { void MessageBus::ListenPort() {
if (addr_ == "") { if (addr_ == "") {
LOG(INFO) << "No need listen to port since training on single card."; LOG(INFO) << "No need listen to port since training on single card.";
...@@ -130,30 +151,7 @@ void MessageBus::ListenPort() { ...@@ -130,30 +151,7 @@ void MessageBus::ListenPort() {
interval += 500; interval += 500;
} }
LOG(INFO) << "Message bus's listen port thread starts successful."; LOG(INFO) << "Message bus's listen port thread starts successful.";
TestConnection();
std::set<int64_t> visit;
InterceptorMessage tmp_msg;
tmp_msg.set_ctrl_message(true);
for (auto pair : interceptor_id_to_rank_) {
if (rank_to_addr_.at(pair.second) == addr_) {
tmp_msg.set_src_id(pair.first);
}
}
for (auto pair : interceptor_id_to_rank_) {
int64_t rank = pair.second;
if (rank_to_addr_.at(rank) == addr_) {
continue;
}
tmp_msg.set_dst_id(pair.first);
if (visit.find(rank) == visit.end()) {
VLOG(3) << "Message bus is testing connection for rank: " << rank << ".";
visit.insert(rank);
while (!Send(tmp_msg)) {
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
}
VLOG(3) << "Message bus has connected to rank: " << rank << ".";
}
}
#else #else
LOG(WARNING) LOG(WARNING)
<< "Fleet executor's ListenPort() is a fake function when Paddle is " << "Fleet executor's ListenPort() is a fake function when Paddle is "
...@@ -162,53 +160,13 @@ void MessageBus::ListenPort() { ...@@ -162,53 +160,13 @@ void MessageBus::ListenPort() {
#endif #endif
} }
bool MessageBus::IsSameRank(int64_t src_id, int64_t dst_id) {
// -1 is sent by carrier to source interceptor
if (src_id == -1) src_id = dst_id;
// check whether the dst is the same rank or different rank with src
const auto& src_rank = interceptor_id_to_rank_.find(src_id);
const auto& dst_rank = interceptor_id_to_rank_.find(dst_id);
PADDLE_ENFORCE_NE(
src_rank, interceptor_id_to_rank_.end(),
platform::errors::NotFound(
"Cannot find rank for src interceptor id %lld. Init error.", src_id));
PADDLE_ENFORCE_NE(
dst_rank, interceptor_id_to_rank_.end(),
platform::errors::NotFound(
"Cannot find rank for dst interceptor id %lld. Init error.", dst_id));
if (addr_ == "") {
// single card training, must be same rank
return true;
}
const auto& src_ip = rank_to_addr_.find(src_rank->second);
PADDLE_ENFORCE_NE(src_ip, rank_to_addr_.end(),
platform::errors::NotFound(
"Cannot find addr for src rank id %lld. Init error.",
src_rank->second));
PADDLE_ENFORCE_EQ(
src_ip->second, addr_,
platform::errors::Fatal("The src interceptor's addr is %s, while the "
"message bus's addr is %s, which are different. "
"Init error.",
src_ip->second, addr_));
return src_rank->second == dst_rank->second;
}
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
!defined(PADDLE_WITH_ASCEND_CL) !defined(PADDLE_WITH_ASCEND_CL)
bool MessageBus::SendInterRank(const InterceptorMessage& interceptor_message) { bool MessageBus::SendInterRank(int64_t dst_rank,
// send the message inter rank (dst is different rank with src) const InterceptorMessage& interceptor_message) {
int64_t dst_id = interceptor_message.dst_id(); const auto& dst_addr = GetAddr(dst_rank);
int64_t dst_rank = interceptor_id_to_rank_[dst_id]; VLOG(3) << "Message bus sending to addr: " << dst_addr;
auto dst_ip = rank_to_addr_.find(dst_rank); const char* dst_addr_for_brpc = dst_addr.c_str();
PADDLE_ENFORCE_NE(dst_ip, rank_to_addr_.end(),
platform::errors::InvalidArgument(
"Cannot find rank for dst interceptor id %lld. "
"Init error.",
dst_id));
VLOG(3) << "Message bus sending to addr: " << dst_ip->second;
const char* dst_ip_for_brpc = dst_ip->second.c_str();
brpc::Channel channel; brpc::Channel channel;
brpc::ChannelOptions options; brpc::ChannelOptions options;
options.protocol = "baidu_std"; options.protocol = "baidu_std";
...@@ -216,7 +174,7 @@ bool MessageBus::SendInterRank(const InterceptorMessage& interceptor_message) { ...@@ -216,7 +174,7 @@ bool MessageBus::SendInterRank(const InterceptorMessage& interceptor_message) {
options.timeout_ms = 1000; options.timeout_ms = 1000;
options.max_retry = 5; options.max_retry = 5;
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
channel.Init(dst_ip_for_brpc, &options), 0, channel.Init(dst_addr_for_brpc, &options), 0,
platform::errors::Unavailable("Message bus: init brpc channel error.")); platform::errors::Unavailable("Message bus: init brpc channel error."));
TheInterceptorMessageService_Stub stub(&channel); TheInterceptorMessageService_Stub stub(&channel);
InterceptorResponse response; InterceptorResponse response;
...@@ -239,11 +197,5 @@ bool MessageBus::SendInterRank(const InterceptorMessage& interceptor_message) { ...@@ -239,11 +197,5 @@ bool MessageBus::SendInterRank(const InterceptorMessage& interceptor_message) {
} }
#endif #endif
bool MessageBus::SendIntraRank(const InterceptorMessage& interceptor_message) {
  // Intra-rank delivery: the destination interceptor is on the same rank as
  // the source, so the message is handed to the carrier returned by
  // FleetExecutor::GetCarrier(). The carrier's enqueue result is returned
  // to the caller unchanged.
  auto& carrier = FleetExecutor::GetCarrier();
  return carrier.EnqueueInterceptorMessage(interceptor_message);
}
} // namespace distributed } // namespace distributed
} // namespace paddle } // namespace paddle
...@@ -42,14 +42,14 @@ class MessageBus final { ...@@ -42,14 +42,14 @@ class MessageBus final {
MessageBus() = default; MessageBus() = default;
~MessageBus(); ~MessageBus();
void Init(const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank, void Init(int64_t rank,
const std::unordered_map<int64_t, std::string>& rank_to_addr, const std::unordered_map<int64_t, std::string>& rank_to_addr,
const std::string& addr); const std::string& addr);
bool IsInit() const; bool IsInit() const;
// called by Interceptor, send InterceptorMessage to dst // called by Interceptor, send InterceptorMessage to dst
bool Send(const InterceptorMessage& interceptor_message); bool Send(int64_t dst_rank, const InterceptorMessage& interceptor_message);
private: private:
DISABLE_COPY_AND_ASSIGN(MessageBus); DISABLE_COPY_AND_ASSIGN(MessageBus);
...@@ -57,22 +57,20 @@ class MessageBus final { ...@@ -57,22 +57,20 @@ class MessageBus final {
// function keep listen the port and handle the message // function keep listen the port and handle the message
void ListenPort(); void ListenPort();
// check whether the dst is the same rank or different rank with src void TestConnection();
bool IsSameRank(int64_t src_id, int64_t dst_id);
const std::string& GetAddr(int64_t rank) const;
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
!defined(PADDLE_WITH_ASCEND_CL) !defined(PADDLE_WITH_ASCEND_CL)
// send the message inter rank (dst is different rank with src) // send the message inter rank (dst is different rank with src)
bool SendInterRank(const InterceptorMessage& interceptor_message); bool SendInterRank(int64_t dst_rank,
const InterceptorMessage& interceptor_message);
#endif #endif
bool is_init_{false}; bool is_init_{false};
// send the message intra rank (dst is the same rank with src) int64_t rank_;
bool SendIntraRank(const InterceptorMessage& interceptor_message);
// handed by above layer, save the info mapping interceptor id to rank id
std::unordered_map<int64_t, int64_t> interceptor_id_to_rank_;
// handed by above layer, save the info mapping rank id to addr // handed by above layer, save the info mapping rank id to addr
std::unordered_map<int64_t, std::string> rank_to_addr_; std::unordered_map<int64_t, std::string> rank_to_addr_;
......
...@@ -21,7 +21,7 @@ namespace distributed { ...@@ -21,7 +21,7 @@ namespace distributed {
std::string RuntimeGraph::DebugString() const { std::string RuntimeGraph::DebugString() const {
std::ostringstream os; std::ostringstream os;
os << "\nRuntime Graph Debug: \n"; os << "\nRuntime Graph Debug: \n";
for (const auto& pair : intercepter_id_to_node_) { for (const auto& pair : interceptor_id_to_node_) {
os << pair.second->DebugString(); os << pair.second->DebugString();
os << "\n"; os << "\n";
} }
......
...@@ -29,26 +29,26 @@ class RuntimeGraph final { ...@@ -29,26 +29,26 @@ class RuntimeGraph final {
public: public:
RuntimeGraph() = default; RuntimeGraph() = default;
~RuntimeGraph() = default; ~RuntimeGraph() = default;
const std::unordered_map<int64_t, TaskNode*>& intercepter_id_to_node() const { const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node() const {
return intercepter_id_to_node_; return interceptor_id_to_node_;
} }
const std::unordered_map<int64_t, int64_t>& intercepter_id_to_rank() const { const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank() const {
return intercepter_id_to_rank_; return interceptor_id_to_rank_;
} }
void SetInterceptorIdToRank( void SetInterceptorIdToRank(
const std::unordered_map<int64_t, int64_t>& intercepter_id_to_rank) { const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank) {
intercepter_id_to_rank_ = intercepter_id_to_rank; interceptor_id_to_rank_ = interceptor_id_to_rank;
} }
void SetInterceptorIdToNode( void SetInterceptorIdToNode(
const std::unordered_map<int64_t, TaskNode*>& intercepter_id_to_node) { const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node) {
intercepter_id_to_node_ = intercepter_id_to_node; interceptor_id_to_node_ = interceptor_id_to_node;
} }
std::string DebugString() const; std::string DebugString() const;
private: private:
DISABLE_COPY_AND_ASSIGN(RuntimeGraph); DISABLE_COPY_AND_ASSIGN(RuntimeGraph);
std::unordered_map<int64_t, TaskNode*> intercepter_id_to_node_; std::unordered_map<int64_t, TaskNode*> interceptor_id_to_node_;
std::unordered_map<int64_t, int64_t> intercepter_id_to_rank_; std::unordered_map<int64_t, int64_t> interceptor_id_to_rank_;
}; };
} // namespace distributed } // namespace distributed
......
...@@ -62,11 +62,10 @@ TEST(ComputeInterceptor, Compute) { ...@@ -62,11 +62,10 @@ TEST(ComputeInterceptor, Compute) {
std::vector<framework::Scope*> scopes = {scope, scope}; std::vector<framework::Scope*> scopes = {scope, scope};
platform::Place place = platform::CPUPlace(); platform::Place place = platform::CPUPlace();
// TODO(liyurui): Remove singleton when move SendIntra into Carrier Carrier carrier(0, {{0, 0}, {1, 0}});
Carrier& carrier = FleetExecutor::GetCarrier();
auto msg_bus = std::make_shared<MessageBus>(); auto msg_bus = std::make_shared<MessageBus>();
msg_bus->Init({{0, 0}, {1, 0}}, {{0, "127.0.0.0:0"}}, ""); msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");
carrier.SetMsgBus(msg_bus); carrier.SetMsgBus(msg_bus);
// FIXME: don't delete, otherwise interceptor will use undefined node // FIXME: don't delete, otherwise interceptor will use undefined node
......
...@@ -47,11 +47,10 @@ class StartInterceptor : public Interceptor { ...@@ -47,11 +47,10 @@ class StartInterceptor : public Interceptor {
}; };
TEST(ComputeInterceptor, Compute) { TEST(ComputeInterceptor, Compute) {
// TODO(liyurui): Remove singleton when move SendIntra into Carrier Carrier carrier(0, {{0, 0}, {1, 0}, {2, 0}});
Carrier& carrier = FleetExecutor::GetCarrier();
auto msg_bus = std::make_shared<MessageBus>(); auto msg_bus = std::make_shared<MessageBus>();
msg_bus->Init({{0, 0}, {1, 0}, {2, 0}}, {{0, "127.0.0.0:0"}}, ""); msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");
carrier.SetMsgBus(msg_bus); carrier.SetMsgBus(msg_bus);
// NOTE: don't delete, otherwise interceptor will use undefined node // NOTE: don't delete, otherwise interceptor will use undefined node
......
...@@ -18,7 +18,6 @@ limitations under the License. */ ...@@ -18,7 +18,6 @@ limitations under the License. */
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h" #include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h"
...@@ -60,11 +59,9 @@ class PingPongInterceptor : public Interceptor { ...@@ -60,11 +59,9 @@ class PingPongInterceptor : public Interceptor {
REGISTER_INTERCEPTOR(PingPong, PingPongInterceptor); REGISTER_INTERCEPTOR(PingPong, PingPongInterceptor);
TEST(InterceptorTest, PingPong) { TEST(InterceptorTest, PingPong) {
// TODO(liyurui): Remove singleton when move SendIntra into Carrier Carrier carrier(0, {{0, 0}, {1, 0}});
Carrier& carrier = FleetExecutor::GetCarrier();
auto msg_bus = std::make_shared<MessageBus>(); auto msg_bus = std::make_shared<MessageBus>();
msg_bus->Init({{0, 0}, {1, 0}}, {{0, "127.0.0.0:0"}}, ""); msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");
carrier.SetMsgBus(msg_bus); carrier.SetMsgBus(msg_bus);
Interceptor* a = carrier.SetInterceptor( Interceptor* a = carrier.SetInterceptor(
......
...@@ -104,35 +104,42 @@ TEST(InterceptorTest, PingPong) { ...@@ -104,35 +104,42 @@ TEST(InterceptorTest, PingPong) {
std::string ip1 = "127.0.0.1:" + std::to_string(port1); std::string ip1 = "127.0.0.1:" + std::to_string(port1);
std::cout << "ip0: " << ip0 << std::endl; std::cout << "ip0: " << ip0 << std::endl;
std::cout << "ip1: " << ip1 << std::endl; std::cout << "ip1: " << ip1 << std::endl;
std::unordered_map<int64_t, int64_t> interceptor_id_to_rank = {{0, 0},
int pid = fork(); {1, 1}};
if (pid == 0) {
auto msg_bus = std::make_shared<MessageBus>(); int exe_pid = fork();
msg_bus->Init({{0, 0}, {1, 1}}, {{0, ip0}, {1, ip1}}, ip0); if (exe_pid == 0) {
int pid = fork();
// TODO(liyurui): Remove singleton when move SendIntra into Carrier if (pid == 0) {
Carrier& carrier = FleetExecutor::GetCarrier(); Carrier* carrier =
carrier.SetMsgBus(msg_bus); FleetExecutor::CreateCarrier(0, interceptor_id_to_rank);
carrier->SetCreatingFlag(false);
Interceptor* a = carrier.SetInterceptor( auto msg_bus = std::make_shared<MessageBus>();
0, InterceptorFactory::Create("PingPong", 0, nullptr)); msg_bus->Init(0, {{0, ip0}, {1, ip1}}, ip0);
carrier.SetCreatingFlag(false); carrier->SetMsgBus(msg_bus);
Interceptor* a = carrier->SetInterceptor(
InterceptorMessage msg; 0, InterceptorFactory::Create("PingPong", 0, nullptr));
a->Send(1, msg); InterceptorMessage msg;
carrier.Wait(); a->Send(1, msg);
carrier->Wait();
} else {
Carrier* carrier =
FleetExecutor::CreateCarrier(1, interceptor_id_to_rank);
carrier->SetCreatingFlag(false);
auto msg_bus = std::make_shared<MessageBus>();
msg_bus->Init(1, {{0, ip0}, {1, ip1}}, ip1);
carrier->SetMsgBus(msg_bus);
carrier->SetInterceptor(
1, InterceptorFactory::Create("PingPong", 1, nullptr));
carrier->Wait();
int status;
int ret = waitpid(pid, &status, 0);
CHECK_EQ(ret, pid);
}
} else { } else {
auto msg_bus = std::make_shared<MessageBus>(); int status;
msg_bus->Init({{0, 0}, {1, 1}}, {{0, ip0}, {1, ip1}}, ip1); int ret = waitpid(exe_pid, &status, 0);
CHECK_EQ(ret, exe_pid);
// TODO(liyurui): Remove singleton when move SendIntra into Carrier
Carrier& carrier = FleetExecutor::GetCarrier();
carrier.SetMsgBus(msg_bus);
carrier.SetInterceptor(1,
InterceptorFactory::Create("PingPong", 1, nullptr));
carrier.SetCreatingFlag(false);
carrier.Wait();
} }
} }
......
...@@ -18,7 +18,6 @@ limitations under the License. */ ...@@ -18,7 +18,6 @@ limitations under the License. */
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h" #include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h" #include "paddle/fluid/distributed/fleet_executor/task_node.h"
...@@ -52,11 +51,9 @@ void LinkNodes(const std::vector<TaskNode*>& nodes) { ...@@ -52,11 +51,9 @@ void LinkNodes(const std::vector<TaskNode*>& nodes) {
} }
TEST(AmplifierInterceptor, Amplifier) { TEST(AmplifierInterceptor, Amplifier) {
// TODO(liyurui): Remove singleton when move SendIntra into Carrier Carrier carrier(0, {{0, 0}, {1, 0}, {2, 0}, {3, 0}, {4, 0}, {5, 0}});
Carrier& carrier = FleetExecutor::GetCarrier();
auto msg_bus = std::make_shared<MessageBus>(); auto msg_bus = std::make_shared<MessageBus>();
msg_bus->Init({{0, 0}, {1, 0}, {2, 0}, {3, 0}, {4, 0}, {5, 0}}, msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "127.0.0.0:0");
{{0, "127.0.0.0:0"}}, "127.0.0.0:0");
carrier.SetMsgBus(msg_bus); carrier.SetMsgBus(msg_bus);
int64_t micro_steps = 3; int64_t micro_steps = 3;
......
...@@ -18,7 +18,6 @@ limitations under the License. */ ...@@ -18,7 +18,6 @@ limitations under the License. */
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h" #include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h" #include "paddle/fluid/distributed/fleet_executor/task_node.h"
...@@ -70,10 +69,9 @@ void LinkNodes(const std::vector<TaskNode*>& nodes, ...@@ -70,10 +69,9 @@ void LinkNodes(const std::vector<TaskNode*>& nodes,
} }
TEST(AmplifierInterceptor, Amplifier) { TEST(AmplifierInterceptor, Amplifier) {
// TODO(liyurui): Remove singleton when move SendIntra into Carrier Carrier carrier(0, {{0, 0}, {1, 0}, {2, 0}, {3, 0}});
Carrier& carrier = FleetExecutor::GetCarrier();
auto msg_bus = std::make_shared<MessageBus>(); auto msg_bus = std::make_shared<MessageBus>();
msg_bus->Init({{0, 0}, {1, 0}, {2, 0}, {3, 0}}, {{0, ""}}, ""); msg_bus->Init(0, {{0, ""}}, "");
carrier.SetMsgBus(msg_bus); carrier.SetMsgBus(msg_bus);
int64_t micro_steps = 6; int64_t micro_steps = 6;
......
...@@ -75,6 +75,12 @@ class BlockingQueue { ...@@ -75,6 +75,12 @@ class BlockingQueue {
return ret; return ret;
} }
// Blocks until the internal queue holds at least one element, then exchanges
// its contents with *empty_queue while holding the mutex. Whatever
// *empty_queue held on entry becomes the new content of the internal queue.
// NOTE(review): the parameter name suggests callers pass an empty deque -
// confirm against call sites.
void PopAll(std::deque<T> *empty_queue) {
  std::unique_lock<std::mutex> guard(mutex_);
  // Loop form of the predicate wait: re-check after every wakeup so spurious
  // wakeups do not let an empty queue through.
  while (q_.empty()) {
    cv_.wait(guard);
  }
  q_.swap(*empty_queue);
}
T Pop() { T Pop() {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [=] { return !q_.empty(); }); cv_.wait(lock, [=] { return !q_.empty(); });
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册