未验证 提交 8a4460f5 编写于 作者: W WangXi 提交者: GitHub

[fleet_executor] interceptor run from python interface (#37693)

上级 82b55961
......@@ -92,19 +92,22 @@ Interceptor* Carrier::GetInterceptor(int64_t interceptor_id) {
}
void Carrier::Start() {
// TODO(fleet_executor dev): this start is a faked one, need replace
for (const auto& pair : interceptor_idx_to_interceptor_) {
VLOG(3) << "Fake run is sending start to interceptor " << pair.first << ".";
InterceptorMessage tmp_msg;
tmp_msg.set_src_id(pair.first);
tmp_msg.set_dst_id(pair.first);
tmp_msg.set_message_type(DATA_IS_READY);
MessageBus& message_bus_instance = MessageBus::Instance();
PADDLE_ENFORCE_EQ(message_bus_instance.IsInit(), true,
platform::errors::PreconditionNotMet(
"Message bus has not been initialized."));
message_bus_instance.Send(tmp_msg);
MessageBus& msg_bus = MessageBus::Instance();
PADDLE_ENFORCE_EQ(msg_bus.IsInit(), true,
platform::errors::PreconditionNotMet(
"Message bus has not been initialized."));
for (int64_t id : source_interceptor_ids_) {
VLOG(3) << "Carrier Start is sending start to source interceptor " << id
<< ".";
InterceptorMessage start_msg;
// The DATA_IS_READY message for a source node is sent by the carrier itself, so set src_id=-1
start_msg.set_src_id(-1);
start_msg.set_dst_id(id);
start_msg.set_message_type(DATA_IS_READY);
msg_bus.Send(start_msg);
}
std::unique_lock<std::mutex> lock(running_mutex_);
cond_var_.wait(lock);
dev_ctx_->Wait();
......@@ -164,16 +167,26 @@ void Carrier::CreateInterceptors() {
int64_t interceptor_id = item.first;
TaskNode* task_node = item.second;
// TODO(wangxi): use node_type to select different Interceptor
auto interceptor =
std::make_unique<Interceptor>(interceptor_id, task_node);
std::unique_ptr<Interceptor> interceptor;
if (task_node->type().empty()) {
// TODO(wangxi): delete this in future
interceptor.reset(new Interceptor(interceptor_id, task_node));
} else {
interceptor = InterceptorFactory::Create(task_node->type(),
interceptor_id, task_node);
}
interceptor->SetPlace(place_);
interceptor->SetMiniBatchScope(minibatch_scope_);
interceptor->SetMicroBatchScope(microbatch_scopes_);
interceptor->SetRootScope(root_scope_);
SetInterceptor(interceptor_id, std::move(interceptor));
VLOG(3) << "Create Interceptor with interceptor id: " << interceptor_id
<< ".";
if (task_node->upstream().empty()) {
source_interceptor_ids_.emplace_back(interceptor_id);
}
}
// The carrier will always be waiting for an outside initializer,
// since no interceptor has been created during auto init
......
......@@ -17,6 +17,7 @@
#include <condition_variable>
#include <memory>
#include <mutex>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>
......@@ -90,6 +91,8 @@ class Carrier final {
std::unordered_map<int64_t, std::unique_ptr<Interceptor>>
interceptor_idx_to_interceptor_;
std::vector<int64_t> source_interceptor_ids_;
std::vector<InterceptorMessage> message_tmp_{};
std::mutex tmp_message_mutex_;
bool creating_interceptors_{true};
......
......@@ -154,18 +154,6 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
}
void ComputeInterceptor::Run() {
// If there is no limit, source interceptor can be executed
// an unlimited number of times.
// Now source node can only run
if (ShouldReset()) {
for (auto& out_buff : out_buffs_) {
// buffer is using
if (out_buff.second.second != 0) return;
}
step_ = 0; // reset
return;
}
while (IsInputReady() && CanWriteOutput() && !ShouldReset()) {
VLOG(3) << "id=" << GetInterceptorId() << " ComputeInterceptor running";
......@@ -181,6 +169,18 @@ void ComputeInterceptor::Run() {
// reply to upstream and decrease ready data
ReplyCompletedToUpStream();
}
// Without a limit, a source interceptor could be executed
// an unlimited number of times.
// For now, a source node can only run max_run_times times.
if (ShouldReset()) {
for (auto& out_buff : out_buffs_) {
// buffer is still in use
if (out_buff.second.second != 0) return;
}
step_ = 0; // reset
return;
}
}
void ComputeInterceptor::ReceivedStop(int64_t up_id) {
......
......@@ -46,11 +46,19 @@ void Interceptor::Handle(const InterceptorMessage& msg) {
VLOG(3) << "Interceptor is using default message handler. This handler is "
"only used for test purpose. Check whether you init interceptor "
"in the proper way.";
if (msg.message_type() == DATA_IS_READY) {
if (node_->role() != 2) {
VLOG(3) << "Fake handler is sending DATA_IS_READY message to: "
<< interceptor_id_ + 1 << ".";
InterceptorMessage data_is_ready_msg;
data_is_ready_msg.set_message_type(DATA_IS_READY);
Send(interceptor_id_ + 1, data_is_ready_msg);
}
VLOG(3) << "Fake handler is sending stop message to it self.";
InterceptorMessage msg;
msg.set_message_type(STOP);
Send(interceptor_id_, msg);
InterceptorMessage stop_msg;
stop_msg.set_message_type(STOP);
Send(interceptor_id_, stop_msg);
} else if (msg.message_type() == STOP) {
stop_ = true;
StopCarrier();
......
......@@ -136,6 +136,9 @@ void MessageBus::ListenPort() {
}
bool MessageBus::IsSameRank(int64_t src_id, int64_t dst_id) {
// -1 is sent by carrier to source interceptor
if (src_id == -1) src_id = dst_id;
// check whether dst is on the same rank as src or on a different one
const auto& src_rank = interceptor_id_to_rank_.find(src_id);
const auto& dst_rank = interceptor_id_to_rank_.find(dst_id);
......
......@@ -112,6 +112,7 @@ void RuntimeGraph::SplitProgramBasedFunctionality(const ProgramDesc& program) {
for (const auto& op_desc : program.Block(0).AllOps()) {
ops_.emplace_back(OpRegistry::CreateOp(*op_desc));
}
std::unordered_map<int32_t, std::vector<OperatorBase*>> role_to_ops;
for (const auto& op : ops_) {
int32_t op_role = op->Attr<int32_t>("op_role");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册