未验证 提交 8a4460f5 编写于 作者: W WangXi 提交者: GitHub

[fleet_executor] interceptor run from python interface (#37693)

上级 82b55961
...@@ -92,19 +92,22 @@ Interceptor* Carrier::GetInterceptor(int64_t interceptor_id) { ...@@ -92,19 +92,22 @@ Interceptor* Carrier::GetInterceptor(int64_t interceptor_id) {
} }
void Carrier::Start() { void Carrier::Start() {
// TODO(fleet_executor dev): this start is a faked one, need replace MessageBus& msg_bus = MessageBus::Instance();
for (const auto& pair : interceptor_idx_to_interceptor_) { PADDLE_ENFORCE_EQ(msg_bus.IsInit(), true,
VLOG(3) << "Fake run is sending start to interceptor " << pair.first << "."; platform::errors::PreconditionNotMet(
InterceptorMessage tmp_msg; "Message bus has not been initialized."));
tmp_msg.set_src_id(pair.first);
tmp_msg.set_dst_id(pair.first); for (int64_t id : source_interceptor_ids_) {
tmp_msg.set_message_type(DATA_IS_READY); VLOG(3) << "Carrier Start is sending start to source interceptor " << id
MessageBus& message_bus_instance = MessageBus::Instance(); << ".";
PADDLE_ENFORCE_EQ(message_bus_instance.IsInit(), true, InterceptorMessage start_msg;
platform::errors::PreconditionNotMet( // source node data_is_ready is sent by carrier, so set src_id=-1
"Message bus has not been initialized.")); start_msg.set_src_id(-1);
message_bus_instance.Send(tmp_msg); start_msg.set_dst_id(id);
start_msg.set_message_type(DATA_IS_READY);
msg_bus.Send(start_msg);
} }
std::unique_lock<std::mutex> lock(running_mutex_); std::unique_lock<std::mutex> lock(running_mutex_);
cond_var_.wait(lock); cond_var_.wait(lock);
dev_ctx_->Wait(); dev_ctx_->Wait();
...@@ -164,16 +167,26 @@ void Carrier::CreateInterceptors() { ...@@ -164,16 +167,26 @@ void Carrier::CreateInterceptors() {
int64_t interceptor_id = item.first; int64_t interceptor_id = item.first;
TaskNode* task_node = item.second; TaskNode* task_node = item.second;
// TODO(wangxi): use node_type to select different Interceptor std::unique_ptr<Interceptor> interceptor;
auto interceptor = if (task_node->type().empty()) {
std::make_unique<Interceptor>(interceptor_id, task_node); // TODO(wangxi): delete this in future
interceptor.reset(new Interceptor(interceptor_id, task_node));
} else {
interceptor = InterceptorFactory::Create(task_node->type(),
interceptor_id, task_node);
}
interceptor->SetPlace(place_); interceptor->SetPlace(place_);
interceptor->SetMiniBatchScope(minibatch_scope_); interceptor->SetMiniBatchScope(minibatch_scope_);
interceptor->SetMicroBatchScope(microbatch_scopes_); interceptor->SetMicroBatchScope(microbatch_scopes_);
interceptor->SetRootScope(root_scope_); interceptor->SetRootScope(root_scope_);
SetInterceptor(interceptor_id, std::move(interceptor)); SetInterceptor(interceptor_id, std::move(interceptor));
VLOG(3) << "Create Interceptor with interceptor id: " << interceptor_id VLOG(3) << "Create Interceptor with interceptor id: " << interceptor_id
<< "."; << ".";
if (task_node->upstream().empty()) {
source_interceptor_ids_.emplace_back(interceptor_id);
}
} }
// The carrier will always be waiting for an outside initializer // The carrier will always be waiting for an outside initializer
// since no interceptor has been created during auto init // since no interceptor has been created during auto init
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <condition_variable> #include <condition_variable>
#include <memory> #include <memory>
#include <mutex> #include <mutex>
#include <set>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
...@@ -90,6 +91,8 @@ class Carrier final { ...@@ -90,6 +91,8 @@ class Carrier final {
std::unordered_map<int64_t, std::unique_ptr<Interceptor>> std::unordered_map<int64_t, std::unique_ptr<Interceptor>>
interceptor_idx_to_interceptor_; interceptor_idx_to_interceptor_;
std::vector<int64_t> source_interceptor_ids_;
std::vector<InterceptorMessage> message_tmp_{}; std::vector<InterceptorMessage> message_tmp_{};
std::mutex tmp_message_mutex_; std::mutex tmp_message_mutex_;
bool creating_interceptors_{true}; bool creating_interceptors_{true};
......
...@@ -154,18 +154,6 @@ void ComputeInterceptor::ReplyCompletedToUpStream() { ...@@ -154,18 +154,6 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
} }
void ComputeInterceptor::Run() { void ComputeInterceptor::Run() {
// If there is no limit, source interceptor can be executed
// an unlimited number of times.
// Now source node can only run
if (ShouldReset()) {
for (auto& out_buff : out_buffs_) {
// buffer is using
if (out_buff.second.second != 0) return;
}
step_ = 0; // reset
return;
}
while (IsInputReady() && CanWriteOutput() && !ShouldReset()) { while (IsInputReady() && CanWriteOutput() && !ShouldReset()) {
VLOG(3) << "id=" << GetInterceptorId() << " ComputeInterceptor running"; VLOG(3) << "id=" << GetInterceptorId() << " ComputeInterceptor running";
...@@ -181,6 +169,18 @@ void ComputeInterceptor::Run() { ...@@ -181,6 +169,18 @@ void ComputeInterceptor::Run() {
// reply to upstream and decrease ready data // reply to upstream and decrease ready data
ReplyCompletedToUpStream(); ReplyCompletedToUpStream();
} }
// If there is no limit, source interceptor can be executed
// an unlimited number of times.
// Now source node can only run max_run_times.
if (ShouldReset()) {
for (auto& out_buff : out_buffs_) {
// buffer is using
if (out_buff.second.second != 0) return;
}
step_ = 0; // reset
return;
}
} }
void ComputeInterceptor::ReceivedStop(int64_t up_id) { void ComputeInterceptor::ReceivedStop(int64_t up_id) {
......
...@@ -46,11 +46,19 @@ void Interceptor::Handle(const InterceptorMessage& msg) { ...@@ -46,11 +46,19 @@ void Interceptor::Handle(const InterceptorMessage& msg) {
VLOG(3) << "Interceptor is using default message handler. This handler is " VLOG(3) << "Interceptor is using default message handler. This handler is "
"only used for test purpose. Check whether you init interceptor " "only used for test purpose. Check whether you init interceptor "
"in the proper way."; "in the proper way.";
if (msg.message_type() == DATA_IS_READY) { if (msg.message_type() == DATA_IS_READY) {
if (node_->role() != 2) {
VLOG(3) << "Fake handler is sending DATA_IS_READY message to: "
<< interceptor_id_ + 1 << ".";
InterceptorMessage data_is_ready_msg;
data_is_ready_msg.set_message_type(DATA_IS_READY);
Send(interceptor_id_ + 1, data_is_ready_msg);
}
VLOG(3) << "Fake handler is sending stop message to it self."; VLOG(3) << "Fake handler is sending stop message to it self.";
InterceptorMessage msg; InterceptorMessage stop_msg;
msg.set_message_type(STOP); stop_msg.set_message_type(STOP);
Send(interceptor_id_, msg); Send(interceptor_id_, stop_msg);
} else if (msg.message_type() == STOP) { } else if (msg.message_type() == STOP) {
stop_ = true; stop_ = true;
StopCarrier(); StopCarrier();
......
...@@ -136,6 +136,9 @@ void MessageBus::ListenPort() { ...@@ -136,6 +136,9 @@ void MessageBus::ListenPort() {
} }
bool MessageBus::IsSameRank(int64_t src_id, int64_t dst_id) { bool MessageBus::IsSameRank(int64_t src_id, int64_t dst_id) {
// -1 is sent by carrier to source interceptor
if (src_id == -1) src_id = dst_id;
// check whether the dst is the same rank or different rank with src // check whether the dst is the same rank or different rank with src
const auto& src_rank = interceptor_id_to_rank_.find(src_id); const auto& src_rank = interceptor_id_to_rank_.find(src_id);
const auto& dst_rank = interceptor_id_to_rank_.find(dst_id); const auto& dst_rank = interceptor_id_to_rank_.find(dst_id);
......
...@@ -112,6 +112,7 @@ void RuntimeGraph::SplitProgramBasedFunctionality(const ProgramDesc& program) { ...@@ -112,6 +112,7 @@ void RuntimeGraph::SplitProgramBasedFunctionality(const ProgramDesc& program) {
for (const auto& op_desc : program.Block(0).AllOps()) { for (const auto& op_desc : program.Block(0).AllOps()) {
ops_.emplace_back(OpRegistry::CreateOp(*op_desc)); ops_.emplace_back(OpRegistry::CreateOp(*op_desc));
} }
std::unordered_map<int32_t, std::vector<OperatorBase*>> role_to_ops; std::unordered_map<int32_t, std::vector<OperatorBase*>> role_to_ops;
for (const auto& op : ops_) { for (const auto& op : ops_) {
int32_t op_role = op->Attr<int32_t>("op_role"); int32_t op_role = op->Attr<int32_t>("op_role");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册