未验证 提交 ddc15a18 编写于 作者: L LiYuRio 提交者: GitHub

[fleet_executor] Move IntraSend to Carrier. Using blocking queue (#38322)

上级 142ea171
......@@ -5,7 +5,7 @@ endif()
proto_library(interceptor_message_proto SRCS interceptor_message.proto)
if(WITH_DISTRIBUTE AND WITH_PSCORE AND NOT (WITH_ASCEND OR WITH_ASCEND_CL))
set(BRPC_DEPS brpc ssl crypto protobuf gflags glog zlib leveldb snappy gflags glog)
set(BRPC_DEPS brpc ssl crypto protobuf zlib leveldb snappy gflags glog)
else()
set(BRPC_DEPS "")
endif()
......@@ -13,7 +13,7 @@ endif()
cc_library(fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime_graph.cc
interceptor.cc compute_interceptor.cc amplifier_interceptor.cc interceptor_message_service.cc message_bus.cc
DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto collective_helper op_registry
executor_gc_helper ${BRPC_DEPS})
executor_gc_helper gflags glog ${BRPC_DEPS})
if(WITH_DISTRIBUTE)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
......
......@@ -27,14 +27,16 @@ namespace distributed {
USE_INTERCEPTOR(Compute);
USE_INTERCEPTOR(Amplifier);
void Carrier::Init(std::shared_ptr<RuntimeGraph> runtime_graph,
void Carrier::Init(int64_t rank, std::shared_ptr<RuntimeGraph> runtime_graph,
framework::Scope* root_scope,
framework::Scope* minibatch_scope,
const std::vector<framework::Scope*>& microbatch_scopes,
const platform::Place& place) {
PADDLE_ENFORCE_EQ(is_init_, false, platform::errors::AlreadyExists(
"Carrier is already init."));
rank_ = rank;
runtime_graph_ = runtime_graph;
interceptor_id_to_rank_ = runtime_graph_->interceptor_id_to_rank();
minibatch_scope_ = minibatch_scope;
microbatch_scopes_ = microbatch_scopes;
place_ = place;
......@@ -48,12 +50,6 @@ void Carrier::Release() {
// NOTE(wangxi): must join before `Derived Interceptor` destruct,
// otherwise Derived object will be destructed before thread complete.
// Sending STOP msg to the source interceptor
PADDLE_ENFORCE_EQ(msg_bus_->IsInit(), true,
platform::errors::PreconditionNotMet(
"Using message bus since it has not been initialized. "
"Please invoke MessageBus::Init() before using it or "
"neccessary components are not ready."));
for (int64_t id : source_interceptor_ids_) {
VLOG(3) << "Carrier Release is sending stop to source interceptor " << id
<< ".";
......@@ -75,10 +71,10 @@ Carrier::~Carrier() { VLOG(3) << "Carrier's destructor."; }
bool Carrier::EnqueueInterceptorMessage(
const InterceptorMessage& interceptor_message) {
// enqueue message to interceptor
if (interceptor_message.ctrl_message()) {
// handle control message
return true;
VLOG(3) << "Receiving control message from rank "
<< interceptor_message.src_id() << " to rank "
<< interceptor_message.dst_id();
} else {
{
std::unique_lock<std::mutex> lock_creating(creating_flag_mutex_);
......@@ -93,15 +89,9 @@ bool Carrier::EnqueueInterceptorMessage(
}
int64_t dst_id = interceptor_message.dst_id();
Interceptor* dst_interceptor = GetInterceptor(dst_id);
bool rst =
dst_interceptor->EnqueueRemoteInterceptorMessage(interceptor_message);
if (rst) {
std::condition_variable& interceptor_cond_var =
dst_interceptor->GetCondVar();
interceptor_cond_var.notify_all();
}
return rst;
}
return true;
}
Interceptor* Carrier::GetInterceptor(int64_t interceptor_id) {
......@@ -144,9 +134,44 @@ std::condition_variable& Carrier::GetCondVar() { return cond_var_; }
bool Carrier::IsInit() const { return is_init_; }
// TODO(liyurui): Move SendIntra into carrier
bool Carrier::Send(const InterceptorMessage& msg) const {
return msg_bus_->Send(msg);
int64_t Carrier::GetRank(int64_t interceptor_id) const {
PADDLE_ENFORCE_NE(
interceptor_id_to_rank_.find(interceptor_id),
interceptor_id_to_rank_.end(),
platform::errors::NotFound("Cannot find rank for interceptor id %lld.",
interceptor_id));
return interceptor_id_to_rank_.at(interceptor_id);
}
bool Carrier::Send(const InterceptorMessage& msg) {
int64_t src_id = (msg.src_id() == -1) ? msg.dst_id() : msg.src_id();
int64_t dst_id = msg.dst_id();
int64_t src_rank = GetRank(src_id);
int64_t dst_rank = GetRank(dst_id);
PADDLE_ENFORCE_EQ(
src_rank, rank_,
platform::errors::Fatal("The source rank id %lld, which is not equal to "
"the carrier rank id %lld.",
src_rank, rank_));
if (src_rank == dst_rank) {
VLOG(3) << "Send a message from interceptor " << src_id
<< " to interceptor " << dst_id << ", which are in the same ranks.";
return EnqueueInterceptorMessage(msg);
} else {
PADDLE_ENFORCE_NOT_NULL(
msg_bus_.get(),
platform::errors::Unavailable("Message bus is released accidently"));
PADDLE_ENFORCE_EQ(
msg_bus_->IsInit(), true,
platform::errors::PreconditionNotMet(
"Using message bus since it has not been initialized. "
"Please invoke MessageBus::Init() before using it or "
"neccessary components are not ready."));
VLOG(3) << "Send a message from interceptor " << src_id
<< " to interceptor " << dst_id
<< ", which are in different ranks.";
return msg_bus_->Send(dst_rank, msg);
}
}
Interceptor* Carrier::SetInterceptor(int64_t interceptor_id,
......@@ -222,13 +247,13 @@ static std::shared_ptr<framework::GarbageCollector> GetGC(
}
void Carrier::CreateInterceptors() {
if (runtime_graph_->intercepter_id_to_node().empty()) return;
if (runtime_graph_->interceptor_id_to_node().empty()) return;
auto gc = GetGC(place_);
// create each Interceptor
// no auto init since there is no config
for (const auto& item : runtime_graph_->intercepter_id_to_node()) {
for (const auto& item : runtime_graph_->interceptor_id_to_node()) {
int64_t interceptor_id = item.first;
TaskNode* task_node = item.second;
......
......@@ -45,8 +45,11 @@ class MessageBus;
class Carrier final {
public:
Carrier() = default;
Carrier(int64_t rank,
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank)
: rank_(rank), interceptor_id_to_rank_(interceptor_id_to_rank) {}
~Carrier();
void Init(std::shared_ptr<RuntimeGraph> runtime_graph,
void Init(int64_t rank, std::shared_ptr<RuntimeGraph> runtime_graph,
framework::Scope* root_scope, framework::Scope* minibatch_scope,
const std::vector<framework::Scope*>& microbatch_scopes,
const platform::Place& place);
......@@ -75,7 +78,7 @@ class Carrier final {
bool IsInit() const;
bool Send(const InterceptorMessage& msg) const;
bool Send(const InterceptorMessage& msg);
// NOTE: This mutex will be used in interceptor's RunOps function.
// This mutex is used for avoiding forward ops and backward ops run
......@@ -90,6 +93,8 @@ class Carrier final {
void HandleTmpMessages();
int64_t GetRank(int64_t interceptor_id) const;
// interceptor logic id to actually interceptor
std::unordered_map<int64_t, std::unique_ptr<Interceptor>>
interceptor_idx_to_interceptor_;
......@@ -111,6 +116,7 @@ class Carrier final {
paddle::platform::DeviceContext* dev_ctx_{nullptr};
std::shared_ptr<RuntimeGraph> runtime_graph_;
std::shared_ptr<MessageBus> msg_bus_;
int64_t rank_;
std::unordered_map<int64_t, int64_t> interceptor_id_to_rank_;
};
......
......@@ -13,7 +13,6 @@
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/runtime_graph.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
......@@ -28,6 +27,8 @@
namespace paddle {
namespace distributed {
std::unique_ptr<Carrier> FleetExecutor::carrier_;
FleetExecutor::FleetExecutor(const std::string& exe_desc_str) {
bool parse_flag = exe_desc_.ParseFromString(exe_desc_str);
PADDLE_ENFORCE(parse_flag, platform::errors::PreconditionNotMet(
......@@ -36,12 +37,13 @@ FleetExecutor::FleetExecutor(const std::string& exe_desc_str) {
FleetExecutor::~FleetExecutor() {
root_scope_->DropKids();
GetCarrier().Release();
GetCarrier()->Release();
}
Carrier& FleetExecutor::GetCarrier() {
static Carrier carrier;
return carrier;
Carrier* FleetExecutor::GetCarrier() {
PADDLE_ENFORCE_NOT_NULL(carrier_.get(), platform::errors::NotFound(
"Carrier has not been created."));
return carrier_.get();
}
void FleetExecutor::Init(
......@@ -84,16 +86,16 @@ void FleetExecutor::Init(
}
VLOG(5) << runtime_graph_->DebugString();
msg_bus_ = std::make_shared<MessageBus>();
CreateCarrier();
InitCarrier();
InitMessageBus();
}
void FleetExecutor::InitCarrier() {
Carrier& carrier = GetCarrier();
if (!carrier.IsInit()) {
carrier.SetMsgBus(msg_bus_);
carrier.Init(runtime_graph_, root_scope_, minibatch_scope_,
microbatch_scopes_, place_);
if (!GetCarrier()->IsInit()) {
GetCarrier()->SetMsgBus(msg_bus_);
GetCarrier()->Init(exe_desc_.cur_rank(), runtime_graph_, root_scope_,
minibatch_scope_, microbatch_scopes_, place_);
}
}
......@@ -128,21 +130,19 @@ void FleetExecutor::InitMessageBus() {
<< (rank_to_addr.size() == 0 ? 1 : rank_to_addr.size()) << ".";
VLOG(5) << ss.str();
if (!msg_bus_->IsInit()) {
msg_bus_->Init(runtime_graph_->intercepter_id_to_rank(), rank_to_addr,
addr);
msg_bus_->Init(cur_rank, rank_to_addr, addr);
}
}
void FleetExecutor::Run() {
// Run
Carrier& carrier = GetCarrier();
PADDLE_ENFORCE_EQ(
carrier.IsInit(), true,
GetCarrier()->IsInit(), true,
platform::errors::Unavailable("Carrier has not been init yet."));
PADDLE_ENFORCE_EQ(
msg_bus_->IsInit(), true,
platform::errors::Unavailable("MessageBus has not been init yet."));
carrier.Start();
GetCarrier()->Start();
for (auto* micro_scop : microbatch_scopes_) {
// By default, we should delete all kid scopes after run executor because
// some operators may create local scope when running, such as while_op.
......
......@@ -16,6 +16,7 @@
#include <memory>
#include <string>
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor_desc.pb.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/platform/place.h"
......@@ -30,7 +31,6 @@ namespace distributed {
class RuntimeGraph;
class MessageBus;
class TaskNode;
class Carrier;
class FleetExecutor final {
public:
......@@ -43,7 +43,15 @@ class FleetExecutor final {
const std::unordered_map<int64_t, int64_t>& task_id_to_rank);
void Run();
// TODO(liyurui): Change to use registry table for multi-carrier.
static Carrier& GetCarrier();
static Carrier* GetCarrier();
template <typename... Args>
static Carrier* CreateCarrier(Args&&... args) {
PADDLE_ENFORCE_EQ(
carrier_.get(), nullptr,
platform::errors::AlreadyExists("Carrier has been created already."));
carrier_ = std::make_unique<Carrier>(std::forward<Args>(args)...);
return carrier_.get();
}
private:
DISABLE_COPY_AND_ASSIGN(FleetExecutor);
......@@ -59,6 +67,7 @@ class FleetExecutor final {
// The carriers under FleetExecutor will share message bus,
// using shared_ptr to manage lifetime and condition race.
std::shared_ptr<MessageBus> msg_bus_;
static std::unique_ptr<Carrier> carrier_;
};
} // namespace distributed
......
......@@ -52,24 +52,17 @@ void Interceptor::StopCarrier() {
cond_var.notify_all();
}
std::condition_variable& Interceptor::GetCondVar() {
// get the conditional var
return cond_var_;
}
int64_t Interceptor::GetInterceptorId() const {
// return the interceptor id
return interceptor_id_;
}
bool Interceptor::EnqueueRemoteInterceptorMessage(
void Interceptor::EnqueueRemoteInterceptorMessage(
const InterceptorMessage& interceptor_message) {
// Called by Carrier, enqueue an InterceptorMessage to remote mailbox
VLOG(3) << "Enqueue message: " << interceptor_message.message_type()
<< " into " << interceptor_id_ << "'s remote mailbox.";
std::unique_lock<std::mutex> lock(remote_mailbox_mutex_);
remote_mailbox_.push(interceptor_message);
return true;
remote_mailbox_.Push(interceptor_message);
}
bool Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) {
......@@ -92,7 +85,7 @@ void Interceptor::PoolTheMailbox() {
"Error encountered when fetch remote mailbox."));
}
const InterceptorMessage interceptor_message = local_mailbox_.front();
local_mailbox_.pop();
local_mailbox_.pop_front();
const MessageType message_type = interceptor_message.message_type();
VLOG(3) << "Interceptor " << interceptor_id_ << " has received a message"
<< " from interceptor " << interceptor_message.src_id()
......@@ -109,19 +102,8 @@ void Interceptor::PoolTheMailbox() {
}
bool Interceptor::FetchRemoteMailbox() {
// fetch all Message from remote mailbox to local mailbox
// return true if remote mailbox not empty, otherwise return false
std::unique_lock<std::mutex> lock(remote_mailbox_mutex_);
cond_var_.wait(lock, [this]() { return !remote_mailbox_.empty(); });
if (remote_mailbox_.empty()) {
// the thread has been unblocked accidentally
return false;
}
while (!remote_mailbox_.empty()) {
local_mailbox_.push(std::move(remote_mailbox_.front()));
remote_mailbox_.pop();
}
return true;
remote_mailbox_.PopAll(&local_mailbox_);
return !local_mailbox_.empty();
}
static InterceptorFactory::CreateInterceptorMap& GetInterceptorMap() {
......
......@@ -15,14 +15,15 @@
#pragma once
#include <condition_variable>
#include <deque>
#include <functional>
#include <map>
#include <memory>
#include <queue>
#include <thread>
#include <vector>
#include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h"
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/fluid/platform/macros.h"
......@@ -59,11 +60,8 @@ class Interceptor {
// return the interceptor id
int64_t GetInterceptorId() const;
// return the conditional var
std::condition_variable& GetCondVar();
// Called by Carrier, enqueue an InterceptorMessage to remote mailbox
bool EnqueueRemoteInterceptorMessage(
void EnqueueRemoteInterceptorMessage(
const InterceptorMessage& interceptor_message);
bool Send(int64_t dst_id, InterceptorMessage& msg); // NOLINT
......@@ -115,23 +113,16 @@ class Interceptor {
// interceptor handle which process message
MsgHandle handle_{nullptr};
// mutex to control read/write conflict for remote mailbox
std::mutex remote_mailbox_mutex_;
// interceptor runs PoolTheMailbox() function to poll local mailbox
std::thread interceptor_thread_;
// conditional variable for blocking the thread when
// fetch an empty remote mailbox
std::condition_variable cond_var_;
// remote mailbox, written by EnqueueRemoteMessage()
// read by FetchRemoteMailbox()
std::queue<InterceptorMessage> remote_mailbox_;
framework::BlockingQueue<InterceptorMessage> remote_mailbox_;
// local mailbox, written by FetchRemoteMailbox()
// read by PoolTheMailbox()
std::queue<InterceptorMessage> local_mailbox_;
std::deque<InterceptorMessage> local_mailbox_;
int64_t already_run_times_{0};
int64_t used_slot_nums_{0};
......
......@@ -29,8 +29,8 @@ void InterceptorMessageServiceImpl::InterceptorMessageService(
VLOG(3) << "Interceptor Message Service receives a message from interceptor "
<< request->src_id() << " to interceptor " << request->dst_id()
<< ", with the message: " << request->message_type();
FleetExecutor::GetCarrier().EnqueueInterceptorMessage(*request);
response->set_rst(true);
bool flag = FleetExecutor::GetCarrier()->EnqueueInterceptorMessage(*request);
response->set_rst(flag);
}
} // namespace distributed
......
......@@ -17,8 +17,6 @@
#include <set>
#include <thread>
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/platform/gen_comm_id_helper.h"
......@@ -26,16 +24,25 @@ namespace paddle {
namespace distributed {
void MessageBus::Init(
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank,
const std::unordered_map<int64_t, std::string>& rank_to_addr,
int64_t rank, const std::unordered_map<int64_t, std::string>& rank_to_addr,
const std::string& addr) {
PADDLE_ENFORCE_EQ(is_init_, false, platform::errors::AlreadyExists(
"MessageBus is already init."));
rank_ = rank;
is_init_ = true;
interceptor_id_to_rank_ = interceptor_id_to_rank;
rank_to_addr_ = rank_to_addr;
addr_ = addr;
if (addr_ != "") {
const auto& addr = GetAddr(rank_);
PADDLE_ENFORCE_EQ(addr, addr_,
platform::errors::Fatal(
"The current rank's addr is %s, while the "
"message bus's addr is %s, which are different. "
"Init error.",
addr, addr_));
}
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_ASCEND_CL)
// NOTE: To make the brpc is compatible with collective,
......@@ -65,26 +72,23 @@ MessageBus::~MessageBus() {
#endif
}
bool MessageBus::Send(const InterceptorMessage& interceptor_message) {
// called by Interceptor, send InterceptorMessage to dst
int64_t src_id = interceptor_message.src_id();
int64_t dst_id = interceptor_message.dst_id();
if (IsSameRank(src_id, dst_id)) {
VLOG(3) << "Send a message from interceptor " << src_id
<< " to interceptor " << dst_id << ", which are in the same ranks.";
return SendIntraRank(interceptor_message);
} else {
VLOG(3) << "Send a message from interceptor " << src_id
<< " to interceptor " << dst_id
<< ", which are in different ranks.";
const std::string& MessageBus::GetAddr(int64_t rank) const {
PADDLE_ENFORCE_NE(
rank_to_addr_.find(rank), rank_to_addr_.end(),
platform::errors::NotFound("Cannot find addr rank id %lld.", rank));
return rank_to_addr_.at(rank);
}
bool MessageBus::Send(int64_t dst_rank,
const InterceptorMessage& interceptor_message) {
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
!defined(PADDLE_WITH_ASCEND_CL)
int retry_time = 0; // message bus will retry sending for 10 times
while (retry_time < 10) {
++retry_time;
if (SendInterRank(interceptor_message)) {
VLOG(3) << "Message bus sends inter rank successfully with "
<< retry_time << " times retries.";
if (SendInterRank(dst_rank, interceptor_message)) {
VLOG(3) << "Message bus sends inter rank successfully with " << retry_time
<< " times retries.";
return true;
}
VLOG(3) << "Message bus sends failed, retry after 1 seconds.";
......@@ -98,10 +102,27 @@ bool MessageBus::Send(const InterceptorMessage& interceptor_message) {
"ranks when Paddle is compiled with npu or "
"isn't compiled with distributed for now."));
#endif
}
return true;
}
void MessageBus::TestConnection() {
InterceptorMessage ctrl_msg;
ctrl_msg.set_ctrl_message(true);
ctrl_msg.set_src_id(rank_);
for (const auto& dst_rank_pair : rank_to_addr_) {
int64_t dst_rank = dst_rank_pair.first;
if (dst_rank != rank_) {
ctrl_msg.set_dst_id(dst_rank);
VLOG(3) << "Send control message bus from rank " << rank_ << " to rank "
<< dst_rank;
while (!Send(dst_rank, ctrl_msg)) {
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
}
VLOG(3) << "Message bus has connected to rank: " << dst_rank << ".";
}
}
}
void MessageBus::ListenPort() {
if (addr_ == "") {
LOG(INFO) << "No need listen to port since training on single card.";
......@@ -130,30 +151,7 @@ void MessageBus::ListenPort() {
interval += 500;
}
LOG(INFO) << "Message bus's listen port thread starts successful.";
std::set<int64_t> visit;
InterceptorMessage tmp_msg;
tmp_msg.set_ctrl_message(true);
for (auto pair : interceptor_id_to_rank_) {
if (rank_to_addr_.at(pair.second) == addr_) {
tmp_msg.set_src_id(pair.first);
}
}
for (auto pair : interceptor_id_to_rank_) {
int64_t rank = pair.second;
if (rank_to_addr_.at(rank) == addr_) {
continue;
}
tmp_msg.set_dst_id(pair.first);
if (visit.find(rank) == visit.end()) {
VLOG(3) << "Message bus is testing connection for rank: " << rank << ".";
visit.insert(rank);
while (!Send(tmp_msg)) {
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
}
VLOG(3) << "Message bus has connected to rank: " << rank << ".";
}
}
TestConnection();
#else
LOG(WARNING)
<< "Fleet executor's ListenPort() is a fake function when Paddle is "
......@@ -162,53 +160,13 @@ void MessageBus::ListenPort() {
#endif
}
bool MessageBus::IsSameRank(int64_t src_id, int64_t dst_id) {
// -1 is sent by carrier to source interceptor
if (src_id == -1) src_id = dst_id;
// check whether the dst is the same rank or different rank with src
const auto& src_rank = interceptor_id_to_rank_.find(src_id);
const auto& dst_rank = interceptor_id_to_rank_.find(dst_id);
PADDLE_ENFORCE_NE(
src_rank, interceptor_id_to_rank_.end(),
platform::errors::NotFound(
"Cannot find rank for src interceptor id %lld. Init error.", src_id));
PADDLE_ENFORCE_NE(
dst_rank, interceptor_id_to_rank_.end(),
platform::errors::NotFound(
"Cannot find rank for dst interceptor id %lld. Init error.", dst_id));
if (addr_ == "") {
// single card training, must be same rank
return true;
}
const auto& src_ip = rank_to_addr_.find(src_rank->second);
PADDLE_ENFORCE_NE(src_ip, rank_to_addr_.end(),
platform::errors::NotFound(
"Cannot find addr for src rank id %lld. Init error.",
src_rank->second));
PADDLE_ENFORCE_EQ(
src_ip->second, addr_,
platform::errors::Fatal("The src interceptor's addr is %s, while the "
"message bus's addr is %s, which are different. "
"Init error.",
src_ip->second, addr_));
return src_rank->second == dst_rank->second;
}
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
!defined(PADDLE_WITH_ASCEND_CL)
bool MessageBus::SendInterRank(const InterceptorMessage& interceptor_message) {
// send the message inter rank (dst is different rank with src)
int64_t dst_id = interceptor_message.dst_id();
int64_t dst_rank = interceptor_id_to_rank_[dst_id];
auto dst_ip = rank_to_addr_.find(dst_rank);
PADDLE_ENFORCE_NE(dst_ip, rank_to_addr_.end(),
platform::errors::InvalidArgument(
"Cannot find rank for dst interceptor id %lld. "
"Init error.",
dst_id));
VLOG(3) << "Message bus sending to addr: " << dst_ip->second;
const char* dst_ip_for_brpc = dst_ip->second.c_str();
bool MessageBus::SendInterRank(int64_t dst_rank,
const InterceptorMessage& interceptor_message) {
const auto& dst_addr = GetAddr(dst_rank);
VLOG(3) << "Message bus sending to addr: " << dst_addr;
const char* dst_addr_for_brpc = dst_addr.c_str();
brpc::Channel channel;
brpc::ChannelOptions options;
options.protocol = "baidu_std";
......@@ -216,7 +174,7 @@ bool MessageBus::SendInterRank(const InterceptorMessage& interceptor_message) {
options.timeout_ms = 1000;
options.max_retry = 5;
PADDLE_ENFORCE_EQ(
channel.Init(dst_ip_for_brpc, &options), 0,
channel.Init(dst_addr_for_brpc, &options), 0,
platform::errors::Unavailable("Message bus: init brpc channel error."));
TheInterceptorMessageService_Stub stub(&channel);
InterceptorResponse response;
......@@ -239,11 +197,5 @@ bool MessageBus::SendInterRank(const InterceptorMessage& interceptor_message) {
}
#endif
bool MessageBus::SendIntraRank(const InterceptorMessage& interceptor_message) {
// send the message intra rank (dst is the same rank with src)
return FleetExecutor::GetCarrier().EnqueueInterceptorMessage(
interceptor_message);
}
} // namespace distributed
} // namespace paddle
......@@ -42,14 +42,14 @@ class MessageBus final {
MessageBus() = default;
~MessageBus();
void Init(const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank,
void Init(int64_t rank,
const std::unordered_map<int64_t, std::string>& rank_to_addr,
const std::string& addr);
bool IsInit() const;
// called by Interceptor, send InterceptorMessage to dst
bool Send(const InterceptorMessage& interceptor_message);
bool Send(int64_t dst_rank, const InterceptorMessage& interceptor_message);
private:
DISABLE_COPY_AND_ASSIGN(MessageBus);
......@@ -57,22 +57,20 @@ class MessageBus final {
// function keep listen the port and handle the message
void ListenPort();
// check whether the dst is the same rank or different rank with src
bool IsSameRank(int64_t src_id, int64_t dst_id);
void TestConnection();
const std::string& GetAddr(int64_t rank) const;
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
!defined(PADDLE_WITH_ASCEND_CL)
// send the message inter rank (dst is different rank with src)
bool SendInterRank(const InterceptorMessage& interceptor_message);
bool SendInterRank(int64_t dst_rank,
const InterceptorMessage& interceptor_message);
#endif
bool is_init_{false};
// send the message intra rank (dst is the same rank with src)
bool SendIntraRank(const InterceptorMessage& interceptor_message);
// handed by above layer, save the info mapping interceptor id to rank id
std::unordered_map<int64_t, int64_t> interceptor_id_to_rank_;
int64_t rank_;
// handed by above layer, save the info mapping rank id to addr
std::unordered_map<int64_t, std::string> rank_to_addr_;
......
......@@ -21,7 +21,7 @@ namespace distributed {
std::string RuntimeGraph::DebugString() const {
std::ostringstream os;
os << "\nRuntime Graph Debug: \n";
for (const auto& pair : intercepter_id_to_node_) {
for (const auto& pair : interceptor_id_to_node_) {
os << pair.second->DebugString();
os << "\n";
}
......
......@@ -29,26 +29,26 @@ class RuntimeGraph final {
public:
RuntimeGraph() = default;
~RuntimeGraph() = default;
const std::unordered_map<int64_t, TaskNode*>& intercepter_id_to_node() const {
return intercepter_id_to_node_;
const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node() const {
return interceptor_id_to_node_;
}
const std::unordered_map<int64_t, int64_t>& intercepter_id_to_rank() const {
return intercepter_id_to_rank_;
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank() const {
return interceptor_id_to_rank_;
}
void SetInterceptorIdToRank(
const std::unordered_map<int64_t, int64_t>& intercepter_id_to_rank) {
intercepter_id_to_rank_ = intercepter_id_to_rank;
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank) {
interceptor_id_to_rank_ = interceptor_id_to_rank;
}
void SetInterceptorIdToNode(
const std::unordered_map<int64_t, TaskNode*>& intercepter_id_to_node) {
intercepter_id_to_node_ = intercepter_id_to_node;
const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node) {
interceptor_id_to_node_ = interceptor_id_to_node;
}
std::string DebugString() const;
private:
DISABLE_COPY_AND_ASSIGN(RuntimeGraph);
std::unordered_map<int64_t, TaskNode*> intercepter_id_to_node_;
std::unordered_map<int64_t, int64_t> intercepter_id_to_rank_;
std::unordered_map<int64_t, TaskNode*> interceptor_id_to_node_;
std::unordered_map<int64_t, int64_t> interceptor_id_to_rank_;
};
} // namespace distributed
......
......@@ -62,11 +62,10 @@ TEST(ComputeInterceptor, Compute) {
std::vector<framework::Scope*> scopes = {scope, scope};
platform::Place place = platform::CPUPlace();
// TODO(liyurui): Remove singleton when move SendIntra into Carrier
Carrier& carrier = FleetExecutor::GetCarrier();
Carrier carrier(0, {{0, 0}, {1, 0}});
auto msg_bus = std::make_shared<MessageBus>();
msg_bus->Init({{0, 0}, {1, 0}}, {{0, "127.0.0.0:0"}}, "");
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");
carrier.SetMsgBus(msg_bus);
// FIXME: don't delete, otherwise interceptor will use undefined node
......
......@@ -47,11 +47,10 @@ class StartInterceptor : public Interceptor {
};
TEST(ComputeInterceptor, Compute) {
// TODO(liyurui): Remove singleton when move SendIntra into Carrier
Carrier& carrier = FleetExecutor::GetCarrier();
Carrier carrier(0, {{0, 0}, {1, 0}, {2, 0}});
auto msg_bus = std::make_shared<MessageBus>();
msg_bus->Init({{0, 0}, {1, 0}, {2, 0}}, {{0, "127.0.0.0:0"}}, "");
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");
carrier.SetMsgBus(msg_bus);
// NOTE: don't delete, otherwise interceptor will use undefined node
......
......@@ -18,7 +18,6 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
......@@ -60,11 +59,9 @@ class PingPongInterceptor : public Interceptor {
REGISTER_INTERCEPTOR(PingPong, PingPongInterceptor);
TEST(InterceptorTest, PingPong) {
// TODO(liyurui): Remove singleton when move SendIntra into Carrier
Carrier& carrier = FleetExecutor::GetCarrier();
Carrier carrier(0, {{0, 0}, {1, 0}});
auto msg_bus = std::make_shared<MessageBus>();
msg_bus->Init({{0, 0}, {1, 0}}, {{0, "127.0.0.0:0"}}, "");
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");
carrier.SetMsgBus(msg_bus);
Interceptor* a = carrier.SetInterceptor(
......
......@@ -104,35 +104,42 @@ TEST(InterceptorTest, PingPong) {
std::string ip1 = "127.0.0.1:" + std::to_string(port1);
std::cout << "ip0: " << ip0 << std::endl;
std::cout << "ip1: " << ip1 << std::endl;
std::unordered_map<int64_t, int64_t> interceptor_id_to_rank = {{0, 0},
{1, 1}};
int exe_pid = fork();
if (exe_pid == 0) {
int pid = fork();
if (pid == 0) {
Carrier* carrier =
FleetExecutor::CreateCarrier(0, interceptor_id_to_rank);
carrier->SetCreatingFlag(false);
auto msg_bus = std::make_shared<MessageBus>();
msg_bus->Init({{0, 0}, {1, 1}}, {{0, ip0}, {1, ip1}}, ip0);
// TODO(liyurui): Remove singleton when move SendIntra into Carrier
Carrier& carrier = FleetExecutor::GetCarrier();
carrier.SetMsgBus(msg_bus);
Interceptor* a = carrier.SetInterceptor(
msg_bus->Init(0, {{0, ip0}, {1, ip1}}, ip0);
carrier->SetMsgBus(msg_bus);
Interceptor* a = carrier->SetInterceptor(
0, InterceptorFactory::Create("PingPong", 0, nullptr));
carrier.SetCreatingFlag(false);
InterceptorMessage msg;
a->Send(1, msg);
carrier.Wait();
carrier->Wait();
} else {
Carrier* carrier =
FleetExecutor::CreateCarrier(1, interceptor_id_to_rank);
carrier->SetCreatingFlag(false);
auto msg_bus = std::make_shared<MessageBus>();
msg_bus->Init({{0, 0}, {1, 1}}, {{0, ip0}, {1, ip1}}, ip1);
// TODO(liyurui): Remove singleton when move SendIntra into Carrier
Carrier& carrier = FleetExecutor::GetCarrier();
carrier.SetMsgBus(msg_bus);
carrier.SetInterceptor(1,
InterceptorFactory::Create("PingPong", 1, nullptr));
carrier.SetCreatingFlag(false);
carrier.Wait();
msg_bus->Init(1, {{0, ip0}, {1, ip1}}, ip1);
carrier->SetMsgBus(msg_bus);
carrier->SetInterceptor(
1, InterceptorFactory::Create("PingPong", 1, nullptr));
carrier->Wait();
int status;
int ret = waitpid(pid, &status, 0);
CHECK_EQ(ret, pid);
}
} else {
int status;
int ret = waitpid(exe_pid, &status, 0);
CHECK_EQ(ret, exe_pid);
}
}
......
......@@ -18,7 +18,6 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
......@@ -52,11 +51,9 @@ void LinkNodes(const std::vector<TaskNode*>& nodes) {
}
TEST(AmplifierInterceptor, Amplifier) {
// TODO(liyurui): Remove singleton when move SendIntra into Carrier
Carrier& carrier = FleetExecutor::GetCarrier();
Carrier carrier(0, {{0, 0}, {1, 0}, {2, 0}, {3, 0}, {4, 0}, {5, 0}});
auto msg_bus = std::make_shared<MessageBus>();
msg_bus->Init({{0, 0}, {1, 0}, {2, 0}, {3, 0}, {4, 0}, {5, 0}},
{{0, "127.0.0.0:0"}}, "127.0.0.0:0");
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "127.0.0.0:0");
carrier.SetMsgBus(msg_bus);
int64_t micro_steps = 3;
......
......@@ -18,7 +18,6 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
......@@ -70,10 +69,9 @@ void LinkNodes(const std::vector<TaskNode*>& nodes,
}
TEST(AmplifierInterceptor, Amplifier) {
// TODO(liyurui): Remove singleton when move SendIntra into Carrier
Carrier& carrier = FleetExecutor::GetCarrier();
Carrier carrier(0, {{0, 0}, {1, 0}, {2, 0}, {3, 0}});
auto msg_bus = std::make_shared<MessageBus>();
msg_bus->Init({{0, 0}, {1, 0}, {2, 0}, {3, 0}}, {{0, ""}}, "");
msg_bus->Init(0, {{0, ""}}, "");
carrier.SetMsgBus(msg_bus);
int64_t micro_steps = 6;
......
......@@ -75,6 +75,12 @@ class BlockingQueue {
return ret;
}
void PopAll(std::deque<T> *empty_queue) {
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [this] { return !q_.empty(); });
std::swap(*empty_queue, q_);
}
T Pop() {
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [=] { return !q_.empty(); });
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册