未验证 提交 b3e79731 编写于 作者: L LiYuRio 提交者: GitHub

[fleet executor] Add sink interceptor and test (#41497)

上级 330582e2
......@@ -13,7 +13,7 @@ endif()
cc_library(task_loop_thread_pool SRCS task_loop_thread_pool.cc task_loop_thread.cc task_loop.cc DEPS enforce glog)
cc_library(fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime_graph.cc dist_model.cc interceptor.cc
compute_interceptor.cc amplifier_interceptor.cc source_interceptor.cc message_service.cc message_bus.cc dist_model_tensor_wrapper.cc
compute_interceptor.cc amplifier_interceptor.cc source_interceptor.cc sink_interceptor.cc message_service.cc message_bus.cc dist_model_tensor_wrapper.cc
DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto task_loop_thread_pool collective_helper
op_registry executor_gc_helper gflags glog ${BRPC_DEPS})
......@@ -26,6 +26,7 @@ if(WITH_DISTRIBUTE)
set_source_files_properties(compute_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(amplifier_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(source_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(sink_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(message_bus.h PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(message_bus.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(fleet_executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
......
......@@ -31,6 +31,7 @@ namespace distributed {
USE_INTERCEPTOR(Source);
USE_INTERCEPTOR(Compute);
USE_INTERCEPTOR(Amplifier);
USE_INTERCEPTOR(Sink);
void Carrier::Init(
int64_t rank,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/sink_interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
namespace paddle {
namespace distributed {
// Builds the sink for task `node`: caches the node's max run times, creates a
// zeroed step counter for each upstream task, and routes every incoming
// message to Run().
SinkInterceptor::SinkInterceptor(int64_t interceptor_id, TaskNode* node)
    : Interceptor(interceptor_id, node),
      max_run_times_(node->max_run_times()) {
  // Every upstream starts at micro-step 0.
  for (const auto& upstream : node->upstream()) {
    upstream_step_.emplace(upstream.first, 0);
  }

  RegisterMsgHandle([this](const InterceptorMessage& incoming) {
    Run(incoming);
  });
}
// Stops the carrier once every upstream counter equals max_run_times_, then
// resets all counters to zero. Does nothing while any upstream is still
// behind.
void SinkInterceptor::StopCarrierIfComplete() {
  for (const auto& entry : upstream_step_) {
    if (entry.second != max_run_times_) {
      // At least one upstream has not delivered all of its micro-steps yet.
      return;
    }
  }

  VLOG(3) << "Sink Interceptor is stopping carrier";
  StopCarrier();

  // Rewind the per-upstream counters.
  for (auto& entry : upstream_step_) {
    entry.second = 0;
  }
}
void SinkInterceptor::ReplyCompletedToUpStream(int64_t upstream_id) {
int64_t micro_step = upstream_step_.at(upstream_id);
int64_t scope_idx = micro_step % max_run_times_;
InterceptorMessage msg;
msg.set_message_type(DATA_IS_USELESS);
msg.set_scope_idx(scope_idx);
Send(upstream_id, msg);
upstream_step_.at(upstream_id) = micro_step + 1;
if (micro_step == max_run_times_ - 1) {
StopCarrierIfComplete();
}
}
// Message handler: only DATA_IS_READY is acted on, by acknowledging the
// sender; every other message type is ignored.
void SinkInterceptor::Run(const InterceptorMessage& msg) {
  if (msg.message_type() != DATA_IS_READY) {
    return;
  }
  ReplyCompletedToUpStream(msg.src_id());
}
REGISTER_INTERCEPTOR(Sink, SinkInterceptor);
} // namespace distributed
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
namespace paddle {
namespace distributed {
/*
* Sink interceptor
* There is only one sink in the runtime graph.
* Takes charge of:
* 1. recording the number of micro-steps received from each upstream
* 2. checking whether to notify the carrier that the current step is finished
*
* For every DATA_IS_READY message it replies DATA_IS_USELESS to the sender;
* once every upstream has reached max_run_times_ micro-steps it stops the
* carrier and resets its counters.
*/
class SinkInterceptor : public Interceptor {
public:
// Creates one zeroed step counter per upstream task of `node` and registers
// Run() as the message handler.
SinkInterceptor(int64_t interceptor_id, TaskNode* node);
private:
// Sends DATA_IS_USELESS to upstream `up_id`, advances its step counter, and
// triggers the completion check after that upstream's last micro-step.
void ReplyCompletedToUpStream(int64_t up_id);
// Message handler; reacts only to DATA_IS_READY.
void Run(const InterceptorMessage& msg);
// Stops the carrier and resets all counters when every upstream counter
// equals max_run_times_.
void StopCarrierIfComplete();
// Taken from node->max_run_times() at construction; number of micro-steps
// expected from each upstream before the carrier is stopped.
int64_t max_run_times_;
// upstream_id -> current micro-step of that upstream
std::map<int64_t, int64_t> upstream_step_;
};
} // namespace distributed
} // namespace paddle
......@@ -7,6 +7,9 @@ cc_test(compute_interceptor_test SRCS compute_interceptor_test.cc DEPS fleet_exe
set_source_files_properties(source_interceptor_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(source_interceptor_test SRCS source_interceptor_test.cc DEPS fleet_executor ${BRPC_DEPS})
set_source_files_properties(sink_interceptor_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(sink_interceptor_test SRCS sink_interceptor_test.cc DEPS fleet_executor ${BRPC_DEPS})
set_source_files_properties(interceptor_pipeline_short_path_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(interceptor_pipeline_short_path_test SRCS interceptor_pipeline_short_path_test.cc DEPS fleet_executor ${BRPC_DEPS})
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <unordered_map>
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
namespace paddle {
namespace distributed {
// Test double that sits between the Source interceptor (id -1) and the Sink
// interceptor (id -2) in this test's graph. It does no real work: it only
// logs the scope index and forwards the handshake messages.
class FakeInterceptor : public Interceptor {
 public:
  FakeInterceptor(int64_t interceptor_id, TaskNode* node)
      : Interceptor(interceptor_id, node) {
    RegisterMsgHandle([this](const InterceptorMessage& msg) { NOP(msg); });
  }

  // Message handler.
  //  - DATA_IS_READY: log the scope, tell interceptor -1 its data is no
  //    longer needed, then tell interceptor -2 that data is ready.
  //  - DATA_IS_USELESS: log only.
  //  - any other message type: ignored.
  void NOP(const InterceptorMessage& msg) {
    if (msg.message_type() == DATA_IS_READY) {
      std::cout << "FakeInterceptor run in scope " << msg.scope_idx()
                << std::endl;
      InterceptorMessage reply;
      reply.set_message_type(DATA_IS_USELESS);
      Send(-1, reply);
      InterceptorMessage ready;
      ready.set_message_type(DATA_IS_READY);
      Send(-2, ready);
    } else if (msg.message_type() == DATA_IS_USELESS) {
      std::cout << "FakeInterceptor remove result in scope " << msg.scope_idx()
                << std::endl;
    }
  }
  // The former private `int64_t step_` member was never initialized and never
  // read, so it has been removed.
};
// End-to-end check of the Sink interceptor: Source(-1) -> Fake(0) -> Sink(-2).
// The test starts the source, then waits for the carrier to finish; it
// passes if Wait() returns. The suite/test name was previously the
// copy-pasted "SourceInterceptor, Source".
TEST(SinkInterceptor, Sink) {
  std::string carrier_id = "0";
  Carrier* carrier =
      GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
  carrier->Init(0, {{-1, 0}, {0, 0}, {-2, 0}});
  MessageBus* msg_bus = GlobalVal<MessageBus>::Create();
  msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");

  // NOTE: don't delete, otherwise interceptor will use undefined node
  TaskNode* source = new TaskNode(0, -1, 0, 3, 0);  // role, rank, task_id
  TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0);   // role, rank, task_id
  TaskNode* sink = new TaskNode(0, -2, 0, 3, 0);    // role, rank, task_id

  // Wire the linear graph: source -> node_a -> sink.
  source->AddDownstreamTask(0, 1);
  node_a->AddUpstreamTask(-1, 1);
  node_a->AddDownstreamTask(-2, 1);
  sink->AddUpstreamTask(0, 1);

  carrier->SetInterceptor(-1, InterceptorFactory::Create("Source", -1, source));
  carrier->SetInterceptor(0, std::make_unique<FakeInterceptor>(0, node_a));
  carrier->SetInterceptor(-2, InterceptorFactory::Create("Sink", -2, sink));

  // start
  InterceptorMessage msg;
  msg.set_message_type(START);
  msg.set_dst_id(-1);
  carrier->EnqueueInterceptorMessage(msg);
  carrier->Wait();
  carrier->Release();
}
} // namespace distributed
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册