未验证 提交 f85bd5c9 编写于 作者: Y Yuang Liu 提交者: GitHub

[fleet_executor] Parse runtime graph to start carrier (#37282)

上级 38141036
......@@ -15,6 +15,7 @@
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor_message_service.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
namespace paddle {
......@@ -22,8 +23,11 @@ namespace distributed {
// Initializes the carrier from a table mapping interceptor ids to task nodes.
//
// Raises an AlreadyExists error (through PADDLE_ENFORCE_EQ) when is_init_ is
// already set. Otherwise the table is stored in interceptor_id_to_node_,
// CreateInterceptors() is invoked, and the carrier is flagged as initialized.
void Carrier::Init(
    const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node) {
  // Refuse a second initialization of the same carrier.
  PADDLE_ENFORCE_EQ(
      is_init_, false,
      platform::errors::AlreadyExists("Carrier is already init."));

  // Store the table before building interceptors; presumably
  // CreateInterceptors() reads interceptor_id_to_node_ -- confirm against its
  // definition before reordering these two statements.
  interceptor_id_to_node_ = interceptor_id_to_node;
  CreateInterceptors();

  // Flag the carrier as ready only after the steps above have run.
  is_init_ = true;
}
bool Carrier::EnqueueInterceptorMessage(
......@@ -63,6 +67,24 @@ Interceptor* Carrier::GetInterceptor(int64_t interceptor_id) {
return iter->second.get();
}
void Carrier::Start() {
// TODO(fleet_executor dev): this start is a faked one, need replace
for (const auto& pair : interceptor_idx_to_interceptor_) {
VLOG(3) << "Fake run is sending start to interceptor " << pair.first << ".";
InterceptorMessage tmp_msg;
tmp_msg.set_src_id(pair.first);
tmp_msg.set_dst_id(pair.first);
tmp_msg.set_message_type(DATA_IS_READY);
MessageBus& message_bus_instance = MessageBus::Instance();
PADDLE_ENFORCE_EQ(message_bus_instance.IsInit(), true,
platform::errors::PreconditionNotMet(
"Message bus has not been initialized."));
message_bus_instance.Send(tmp_msg);
}
}
// Reports the current value of the is_init_ flag, which Init() sets to true
// once it has finished.
bool Carrier::IsInit() const {
  return is_init_;
}
Interceptor* Carrier::SetInterceptor(int64_t interceptor_id,
std::unique_ptr<Interceptor> interceptor) {
auto iter = interceptor_idx_to_interceptor_.find(interceptor_id);
......
......@@ -56,6 +56,10 @@ class Carrier final {
void SetCreatingFlag(bool flag);
void Start();
bool IsInit() const;
DISABLE_COPY_AND_ASSIGN(Carrier);
private:
......@@ -75,6 +79,7 @@ class Carrier final {
std::vector<InterceptorMessage> message_tmp_{};
bool creating_interceptors_{true};
bool is_init_{false};
};
} // namespace distributed
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/runtime_graph.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
......@@ -34,14 +35,21 @@ FleetExecutor::~FleetExecutor() {
// Prepares the executor for the given program.
//
// Builds a fresh RuntimeGraph from program_desc and the executor description
// (exe_desc_), then initializes the carrier followed by the message bus.
//
// NOTE(review): InitCarrier() skips a carrier that is already initialized, so
// a carrier set up earlier keeps the node table it received at that time --
// confirm those TaskNode pointers stay valid after runtime_graph_ is replaced.
void FleetExecutor::Init(const paddle::framework::ProgramDesc& program_desc) {
  // Replaces whatever graph was held before.
  runtime_graph_ = std::make_unique<RuntimeGraph>(program_desc, exe_desc_);

  // Carrier first, message bus second.
  InitCarrier();
  InitMessageBus();
}
void FleetExecutor::InitCarrier() {
Carrier& carrier_instance = Carrier::Instance();
if (!carrier_instance.IsInit()) {
carrier_instance.Init(runtime_graph_->intercepter_id_to_node());
}
}
void FleetExecutor::InitMessageBus() {
std::stringstream ss;
ss << "\nThe DNS table of the message bus is: \n";
int64_t cur_rank = exe_desc_.cur_rank();
std::unordered_map<int64_t, int64_t> interceptor_id_to_rank;
std::unordered_map<int64_t, std::string> rank_to_addr;
std::string addr;
for (const auto& rank_info : exe_desc_.cluster_info()) {
......@@ -49,8 +57,6 @@ void FleetExecutor::InitMessageBus() {
int64_t rank = rank_info.rank();
std::string ip_port = rank_info.ip_port();
ss << rank << "\t->\t" << ip_port << "\n";
// TODO(Yuang): init interceptor_id_to_rank out of this loop
interceptor_id_to_rank.insert(std::make_pair(rank, rank));
rank_to_addr.insert(std::make_pair(rank, ip_port));
if (rank == cur_rank) {
addr = ip_port;
......@@ -58,7 +64,7 @@ void FleetExecutor::InitMessageBus() {
}
if (addr == "") {
PADDLE_ENFORCE_EQ(
rank_to_addr.size(), 0,
rank_to_addr.size(), 1,
platform::errors::NotFound("Empty address is not valid for "
"paddle.distributed.launch method."));
PADDLE_ENFORCE_EQ(
......@@ -72,12 +78,22 @@ void FleetExecutor::InitMessageBus() {
VLOG(5) << ss.str();
MessageBus& message_bus_instance = MessageBus::Instance();
if (!message_bus_instance.IsInit()) {
message_bus_instance.Init(interceptor_id_to_rank, rank_to_addr, addr);
message_bus_instance.Init(runtime_graph_->intercepter_id_to_rank(),
rank_to_addr, addr);
}
}
void FleetExecutor::Run() {
// Run
Carrier& carrier_instance = Carrier::Instance();
MessageBus& message_bus_instance = MessageBus::Instance();
PADDLE_ENFORCE_EQ(
carrier_instance.IsInit(), true,
platform::errors::Unavailable("Carrier has not been init yet."));
PADDLE_ENFORCE_EQ(
message_bus_instance.IsInit(), true,
platform::errors::Unavailable("MessageBus has not been init yet."));
carrier_instance.Start();
}
void FleetExecutor::Release() {
......
......@@ -43,6 +43,7 @@ class FleetExecutor final {
FleetExecutorDesc exe_desc_;
std::unique_ptr<RuntimeGraph> runtime_graph_;
void InitMessageBus();
void InitCarrier();
};
} // namespace distributed
......
......@@ -33,6 +33,16 @@ void Interceptor::RegisterMsgHandle(MsgHandle handle) { handle_ = handle; }
void Interceptor::Handle(const InterceptorMessage& msg) {
if (handle_) {
handle_(msg);
} else {
VLOG(3) << "Interceptor is using default message handler. This handler is "
"only used for test purpose. Check whether you init interceptor "
"in the proper way.";
if (msg.message_type() == DATA_IS_READY) {
VLOG(3) << "Fake handler is sending stop message to it self.";
InterceptorMessage msg;
msg.set_message_type(STOP);
Send(interceptor_id_, msg);
}
}
}
......
......@@ -1958,6 +1958,11 @@ class Executor(object):
fleet_exe_desc.cluster_info.append(rank_info)
nrank = len(trainer_endpoints)
else:
fleet_exe_desc.cur_rank = 0
rank_info = fleet_executor_desc_pb2.RankInfo()
rank_info.rank = 0
rank_info.ip_port = ''
fleet_exe_desc.cluster_info.append(rank_info)
logging.warning("Fleet Executor will run on single device only.")
fleet_opt = program._pipeline_opt["fleet_opt"]
if "dist_strategy" in fleet_opt:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册