diff --git a/paddle/fluid/distributed/fleet_executor/CMakeLists.txt b/paddle/fluid/distributed/fleet_executor/CMakeLists.txt index 8641b36a1be8ea51dc4ad911214c2cebe6121e20..4a2dfcb554ad33439d26517b660cec1e399f93e2 100644 --- a/paddle/fluid/distributed/fleet_executor/CMakeLists.txt +++ b/paddle/fluid/distributed/fleet_executor/CMakeLists.txt @@ -13,7 +13,7 @@ endif() cc_library(task_loop_thread_pool SRCS task_loop_thread_pool.cc task_loop_thread.cc task_loop.cc DEPS enforce glog) cc_library(fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime_graph.cc dist_model.cc interceptor.cc - compute_interceptor.cc amplifier_interceptor.cc message_service.cc message_bus.cc dist_model_tensor_wrapper.cc + compute_interceptor.cc amplifier_interceptor.cc source_interceptor.cc message_service.cc message_bus.cc dist_model_tensor_wrapper.cc DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto task_loop_thread_pool collective_helper op_registry executor_gc_helper gflags glog ${BRPC_DEPS}) @@ -25,6 +25,7 @@ if(WITH_DISTRIBUTE) set_source_files_properties(interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(compute_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(amplifier_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) + set_source_files_properties(source_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(message_bus.h PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(message_bus.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(fleet_executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) diff --git a/paddle/fluid/distributed/fleet_executor/carrier.cc b/paddle/fluid/distributed/fleet_executor/carrier.cc index 0d5d328fd32cc2e12d4f4e94c94dae51f0c040bc..358393d97f0710ee6762c86c6a1f56b3ceb1fa2f 100644 --- a/paddle/fluid/distributed/fleet_executor/carrier.cc +++ b/paddle/fluid/distributed/fleet_executor/carrier.cc @@ -28,6 +28,7 @@ namespace paddle { namespace distributed { +USE_INTERCEPTOR(Source); USE_INTERCEPTOR(Compute); USE_INTERCEPTOR(Amplifier); diff --git a/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc b/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc index d934ab1948e7e584ba1be5e844abc7ad4a6059dc..f49c84e6e5edc0f7e0c97041cd82d196ce0c44bb 100644 --- a/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc +++ b/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc @@ -164,7 +164,7 @@ void ComputeInterceptor::ReplyCompletedToUpStream() { if (up_id == -1) return; InterceptorMessage reply_msg; - reply_msg.set_message_type(DATE_IS_USELESS); + reply_msg.set_message_type(DATA_IS_USELESS); Send(up_id, reply_msg); } } @@ -247,7 +247,7 @@ void ComputeInterceptor::Compute(const InterceptorMessage& msg) { if (msg.message_type() == DATA_IS_READY) { IncreaseReady(msg.src_id()); Run(); - } else if (msg.message_type() == DATE_IS_USELESS) { + } else if (msg.message_type() == DATA_IS_USELESS) { DecreaseBuff(msg.src_id()); Run(); } else if (msg.message_type() == STOP) { diff --git a/paddle/fluid/distributed/fleet_executor/interceptor_message.proto b/paddle/fluid/distributed/fleet_executor/interceptor_message.proto index ed38894641c3a65b7041ea51b736e91d8ecd7a7c..7cf99e8741943774c935e41d8e152c39fffbcd75 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor_message.proto +++ b/paddle/fluid/distributed/fleet_executor/interceptor_message.proto @@ -20,9 +20,10 @@ option cc_enable_arenas = true; enum MessageType { STOP = 1; // STOP an Interceptor DATA_IS_READY = 2; // upstream data is ready - DATE_IS_USELESS = 3; // downstream has used the data + DATA_IS_USELESS = 3; // downstream has used the data ERR = 4; // current Interceptor encounters error RESET = 5; // reset the status + START = 6; } message InterceptorMessage { @@ -30,6 +31,7 @@ message InterceptorMessage { optional int64 dst_id = 2 [ default = 0 ]; optional MessageType message_type = 3 [ default = RESET ]; optional bool ctrl_message = 4 [ default = false ]; + optional int64 scope_idx = 5 [ default = 0 ]; } message InterceptorResponse { optional bool rst = 1 [ default = false ]; } diff --git a/paddle/fluid/distributed/fleet_executor/source_interceptor.cc b/paddle/fluid/distributed/fleet_executor/source_interceptor.cc new file mode 100644 index 0000000000000000000000000000000000000000..78b2bed66dd99d60d1e39782e501d116bd6c1ec7 --- /dev/null +++ b/paddle/fluid/distributed/fleet_executor/source_interceptor.cc @@ -0,0 +1,57 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/distributed/fleet_executor/source_interceptor.h" +#include "paddle/fluid/distributed/fleet_executor/task_node.h" + +namespace paddle { +namespace distributed { + +SourceInterceptor::SourceInterceptor(int64_t interceptor_id, TaskNode* node) + : Interceptor(interceptor_id, node), max_run_times_(node->max_run_times()) { + // prepare the downstream running status + for (const auto& down : node->downstream()) { + downstream_step_.emplace(down.first, 0); + } + RegisterMsgHandle([this](const InterceptorMessage& msg) { Run(msg); }); +} + +void SourceInterceptor::SendDataReadyToDownStream(int64_t downstream_id) { + int64_t micro_step = downstream_step_.at(downstream_id); + if (micro_step >= max_run_times_) { + return; + } + int64_t scope_idx = micro_step % max_run_times_; + InterceptorMessage ready_msg; + ready_msg.set_message_type(DATA_IS_READY); + ready_msg.set_scope_idx(scope_idx); + Send(downstream_id, ready_msg); + downstream_step_.at(downstream_id) = micro_step + 1; +} + +void SourceInterceptor::Run(const InterceptorMessage& msg) { + if (msg.message_type() == START) { + // start run in a new step, reset the previous running status + for (const auto& down : downstream_step_) { + downstream_step_.at(down.first) = 0; + SendDataReadyToDownStream(down.first); + } + } else if (msg.message_type() == DATA_IS_USELESS) { + SendDataReadyToDownStream(msg.src_id()); + } +} + +REGISTER_INTERCEPTOR(Source, SourceInterceptor); +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/fleet_executor/source_interceptor.h b/paddle/fluid/distributed/fleet_executor/source_interceptor.h new file mode 100644 index 0000000000000000000000000000000000000000..f8b18fb1848645c44c75db90a7d123ba48aeae21 --- /dev/null +++ b/paddle/fluid/distributed/fleet_executor/source_interceptor.h @@ -0,0 +1,41 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/fluid/distributed/fleet_executor/interceptor.h" + +namespace paddle { +namespace distributed { + +/* + * Source interceptor + * There is only one source in the runtime graph + * Take charge of: + * 1. receive `start` message from carrier + * 2. send num_of_steps `data_is_ready` message to downstream + */ +class SourceInterceptor : public Interceptor { + public: + SourceInterceptor(int64_t interceptor_id, TaskNode* node); + + private: + void SendDataReadyToDownStream(int64_t down_id); + void Run(const InterceptorMessage& msg); + int64_t max_run_times_; + // downstream_id->cur_step + std::map downstream_step_; +}; + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt b/paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt index d4587b90c87f3deadba686230728ed084b2a18ad..33c08acd4498df9d422ec518cc458194501b72d1 100644 --- a/paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt +++ b/paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt @@ -4,6 +4,9 @@ cc_test(interceptor_ping_pong_test SRCS interceptor_ping_pong_test.cc DEPS fleet set_source_files_properties(compute_interceptor_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) cc_test(compute_interceptor_test SRCS compute_interceptor_test.cc DEPS fleet_executor ${BRPC_DEPS}) +set_source_files_properties(source_interceptor_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +cc_test(source_interceptor_test SRCS source_interceptor_test.cc DEPS fleet_executor ${BRPC_DEPS}) + set_source_files_properties(interceptor_pipeline_short_path_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) cc_test(interceptor_pipeline_short_path_test SRCS interceptor_pipeline_short_path_test.cc DEPS fleet_executor ${BRPC_DEPS}) diff --git a/paddle/fluid/distributed/fleet_executor/test/source_interceptor_test.cc b/paddle/fluid/distributed/fleet_executor/test/source_interceptor_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..cf49e97474af0373497dc04c3d254dae820b2caf --- /dev/null +++ b/paddle/fluid/distributed/fleet_executor/test/source_interceptor_test.cc @@ -0,0 +1,84 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "gtest/gtest.h" + +#include "paddle/fluid/distributed/fleet_executor/carrier.h" +#include "paddle/fluid/distributed/fleet_executor/global.h" +#include "paddle/fluid/distributed/fleet_executor/interceptor.h" +#include "paddle/fluid/distributed/fleet_executor/message_bus.h" +#include "paddle/fluid/distributed/fleet_executor/task_node.h" + +namespace paddle { +namespace distributed { + +class FakeInterceptor : public Interceptor { + public: + FakeInterceptor(int64_t interceptor_id, TaskNode* node) + : Interceptor(interceptor_id, node) { + step_ = 0; + RegisterMsgHandle([this](const InterceptorMessage& msg) { NOP(msg); }); + } + + void NOP(const InterceptorMessage& msg) { + if (msg.message_type() == DATA_IS_READY) { + std::cout << "FakeInterceptor run in scope " << msg.scope_idx() + << std::endl; + InterceptorMessage reply; + reply.set_message_type(DATA_IS_USELESS); + Send(-1, reply); + step_++; + if (step_ == node_->max_run_times()) { + carrier_->WakeUp(); + } + } + } + + private: + int64_t step_; +}; + +TEST(SourceInterceptor, Source) { + std::string carrier_id = "0"; + Carrier* carrier = + GlobalMap::Create(carrier_id, carrier_id); + carrier->Init(0, {{-1, 0}, {0, 0}}); + + MessageBus* msg_bus = GlobalVal::Create(); + msg_bus->Init(0, {{0, "127.0.0.0:0"}}, ""); + + // NOTE: don't delete, otherwise interceptor will use undefined node + TaskNode* source = new TaskNode(0, -1, 0, 3, 0); // role, rank, task_id + TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0); // role, rank, task_id + + source->AddDownstreamTask(0, 1); + node_a->AddUpstreamTask(-1, 1); + carrier->SetInterceptor(-1, InterceptorFactory::Create("Source", -1, source)); + carrier->SetInterceptor(0, std::make_unique(0, node_a)); + + // start + InterceptorMessage msg; + msg.set_message_type(START); + msg.set_dst_id(-1); + carrier->EnqueueInterceptorMessage(msg); + + carrier->Wait(); + carrier->Release(); +} + +} // namespace distributed +} // namespace paddle