diff --git a/paddle/fluid/distributed/fleet_executor/CMakeLists.txt b/paddle/fluid/distributed/fleet_executor/CMakeLists.txt index 4a2dfcb554ad33439d26517b660cec1e399f93e2..977a125627ba547122f483fab4690baadc704bb0 100644 --- a/paddle/fluid/distributed/fleet_executor/CMakeLists.txt +++ b/paddle/fluid/distributed/fleet_executor/CMakeLists.txt @@ -13,7 +13,7 @@ endif() cc_library(task_loop_thread_pool SRCS task_loop_thread_pool.cc task_loop_thread.cc task_loop.cc DEPS enforce glog) cc_library(fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime_graph.cc dist_model.cc interceptor.cc - compute_interceptor.cc amplifier_interceptor.cc source_interceptor.cc message_service.cc message_bus.cc dist_model_tensor_wrapper.cc + compute_interceptor.cc amplifier_interceptor.cc source_interceptor.cc sink_interceptor.cc message_service.cc message_bus.cc dist_model_tensor_wrapper.cc DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto task_loop_thread_pool collective_helper op_registry executor_gc_helper gflags glog ${BRPC_DEPS}) @@ -26,6 +26,7 @@ if(WITH_DISTRIBUTE) set_source_files_properties(compute_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(amplifier_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(source_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) + set_source_files_properties(sink_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(message_bus.h PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(message_bus.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(fleet_executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) diff --git a/paddle/fluid/distributed/fleet_executor/carrier.cc b/paddle/fluid/distributed/fleet_executor/carrier.cc index 358393d97f0710ee6762c86c6a1f56b3ceb1fa2f..2d2a3b688fefeda4deef1f0a24a270788b380cfe 100644 --- a/paddle/fluid/distributed/fleet_executor/carrier.cc +++ b/paddle/fluid/distributed/fleet_executor/carrier.cc @@ -31,6 +31,7 @@ namespace distributed { USE_INTERCEPTOR(Source); USE_INTERCEPTOR(Compute); USE_INTERCEPTOR(Amplifier); +USE_INTERCEPTOR(Sink); void Carrier::Init( int64_t rank, diff --git a/paddle/fluid/distributed/fleet_executor/sink_interceptor.cc b/paddle/fluid/distributed/fleet_executor/sink_interceptor.cc new file mode 100644 index 0000000000000000000000000000000000000000..af707c28acd9e9576ebd5491bd4517ea0cbae32e --- /dev/null +++ b/paddle/fluid/distributed/fleet_executor/sink_interceptor.cc @@ -0,0 +1,65 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/distributed/fleet_executor/sink_interceptor.h" +#include "paddle/fluid/distributed/fleet_executor/task_node.h" + +namespace paddle { +namespace distributed { + +SinkInterceptor::SinkInterceptor(int64_t interceptor_id, TaskNode* node) + : Interceptor(interceptor_id, node), max_run_times_(node->max_run_times()) { + // prepare the upstream running status + for (const auto& up : node->upstream()) { + upstream_step_.emplace(up.first, 0); + } + RegisterMsgHandle([this](const InterceptorMessage& msg) { Run(msg); }); +} + +void SinkInterceptor::StopCarrierIfComplete() { + bool flag = true; + for (const auto& up : upstream_step_) { + flag = flag & (up.second == max_run_times_); + } + if (flag) { + VLOG(3) << "Sink Interceptor is stopping carrier"; + StopCarrier(); + for (const auto& up : upstream_step_) { + upstream_step_.at(up.first) = 0; + } + } +} + +void SinkInterceptor::ReplyCompletedToUpStream(int64_t upstream_id) { + int64_t micro_step = upstream_step_.at(upstream_id); + int64_t scope_idx = micro_step % max_run_times_; + InterceptorMessage msg; + msg.set_message_type(DATA_IS_USELESS); + msg.set_scope_idx(scope_idx); + Send(upstream_id, msg); + upstream_step_.at(upstream_id) = micro_step + 1; + if (micro_step == max_run_times_ - 1) { + StopCarrierIfComplete(); + } +} + +void SinkInterceptor::Run(const InterceptorMessage& msg) { + if (msg.message_type() == DATA_IS_READY) { + ReplyCompletedToUpStream(msg.src_id()); + } +} + +REGISTER_INTERCEPTOR(Sink, SinkInterceptor); +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/fleet_executor/sink_interceptor.h b/paddle/fluid/distributed/fleet_executor/sink_interceptor.h new file mode 100644 index 0000000000000000000000000000000000000000..cb1d698a78526fdde61586304e588e8009340584 --- /dev/null +++ b/paddle/fluid/distributed/fleet_executor/sink_interceptor.h @@ -0,0 +1,41 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/fluid/distributed/fleet_executor/interceptor.h" + +namespace paddle { +namespace distributed { + +/* + * Sink interceptor + * There is only one sink in the runtime graph + * Take charge of: + * 1. record the num of micro-step + * 2. check whether to notify carrier the current step is finished + */ +class SinkInterceptor : public Interceptor { + public: + SinkInterceptor(int64_t interceptor_id, TaskNode* node); + + private: + void ReplyCompletedToUpStream(int64_t up_id); + void Run(const InterceptorMessage& msg); + void StopCarrierIfComplete(); + int64_t max_run_times_; + // upstream_id->cur_step + std::map upstream_step_; +}; +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt b/paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt index 33c08acd4498df9d422ec518cc458194501b72d1..e0db8a261b58594f7d3e8db8a535395bdb4a2a80 100644 --- a/paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt +++ b/paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt @@ -7,6 +7,9 @@ cc_test(compute_interceptor_test SRCS compute_interceptor_test.cc DEPS fleet_exe set_source_files_properties(source_interceptor_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) cc_test(source_interceptor_test SRCS source_interceptor_test.cc DEPS fleet_executor ${BRPC_DEPS}) +set_source_files_properties(sink_interceptor_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +cc_test(sink_interceptor_test SRCS sink_interceptor_test.cc DEPS fleet_executor ${BRPC_DEPS}) + set_source_files_properties(interceptor_pipeline_short_path_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) cc_test(interceptor_pipeline_short_path_test SRCS interceptor_pipeline_short_path_test.cc DEPS fleet_executor ${BRPC_DEPS}) diff --git a/paddle/fluid/distributed/fleet_executor/test/sink_interceptor_test.cc b/paddle/fluid/distributed/fleet_executor/test/sink_interceptor_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..6b1a555e987a380da6ca4db8ceeb5b9965150ff8 --- /dev/null +++ b/paddle/fluid/distributed/fleet_executor/test/sink_interceptor_test.cc @@ -0,0 +1,89 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "gtest/gtest.h" + +#include "paddle/fluid/distributed/fleet_executor/carrier.h" +#include "paddle/fluid/distributed/fleet_executor/global.h" +#include "paddle/fluid/distributed/fleet_executor/interceptor.h" +#include "paddle/fluid/distributed/fleet_executor/message_bus.h" +#include "paddle/fluid/distributed/fleet_executor/task_node.h" + +namespace paddle { +namespace distributed { + +class FakeInterceptor : public Interceptor { + public: + FakeInterceptor(int64_t interceptor_id, TaskNode* node) + : Interceptor(interceptor_id, node) { + RegisterMsgHandle([this](const InterceptorMessage& msg) { NOP(msg); }); + } + + void NOP(const InterceptorMessage& msg) { + if (msg.message_type() == DATA_IS_READY) { + std::cout << "FakeInterceptor run in scope " << msg.scope_idx() + << std::endl; + InterceptorMessage reply; + reply.set_message_type(DATA_IS_USELESS); + Send(-1, reply); + InterceptorMessage ready; + ready.set_message_type(DATA_IS_READY); + Send(-2, ready); + } else if (msg.message_type() == DATA_IS_USELESS) { + std::cout << "FakeInterceptor remove result in scope " << msg.scope_idx() + << std::endl; + } + } + + private: + int64_t step_; +}; + +TEST(SourceInterceptor, Source) { + std::string carrier_id = "0"; + Carrier* carrier = + GlobalMap::Create(carrier_id, carrier_id); + carrier->Init(0, {{-1, 0}, {0, 0}, {-2, 0}}); + + MessageBus* msg_bus = GlobalVal::Create(); + msg_bus->Init(0, {{0, "127.0.0.0:0"}}, ""); + + // NOTE: don't delete, otherwise interceptor will use undefined node + TaskNode* source = new TaskNode(0, -1, 0, 3, 0); // role, rank, task_id + TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0); // role, rank, task_id + TaskNode* sink = new TaskNode(0, -2, 0, 3, 0); // role, rank, task_id + + source->AddDownstreamTask(0, 1); + node_a->AddUpstreamTask(-1, 1); + node_a->AddDownstreamTask(-2, 1); + sink->AddUpstreamTask(0, 1); + carrier->SetInterceptor(-1, InterceptorFactory::Create("Source", -1, source)); + carrier->SetInterceptor(0, std::make_unique(0, node_a)); + carrier->SetInterceptor(-2, InterceptorFactory::Create("Sink", -2, sink)); + + // start + InterceptorMessage msg; + msg.set_message_type(START); + msg.set_dst_id(-1); + carrier->EnqueueInterceptorMessage(msg); + + carrier->Wait(); + carrier->Release(); +} + +} // namespace distributed +} // namespace paddle