From 964e20e0a5d4188367bf9058bcfc4328b9ee1a2b Mon Sep 17 00:00:00 2001 From: WangXi Date: Mon, 22 Nov 2021 11:41:27 +0800 Subject: [PATCH] [fleet_executor] Add compute interceptor (#37376) --- .../distributed/fleet_executor/CMakeLists.txt | 3 +- .../distributed/fleet_executor/carrier.cc | 2 + .../fleet_executor/compute_interceptor.cc | 61 +++++++++++++++ .../fleet_executor/compute_interceptor.h | 37 +++++++++ .../distributed/fleet_executor/interceptor.cc | 2 +- .../distributed/fleet_executor/interceptor.h | 26 +++++-- .../distributed/fleet_executor/task_node.h | 8 ++ .../fleet_executor/test/CMakeLists.txt | 2 + .../test/compute_interceptor_test.cc | 75 +++++++++++++++++++ 9 files changed, 208 insertions(+), 8 deletions(-) create mode 100644 paddle/fluid/distributed/fleet_executor/compute_interceptor.cc create mode 100644 paddle/fluid/distributed/fleet_executor/compute_interceptor.h create mode 100644 paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc diff --git a/paddle/fluid/distributed/fleet_executor/CMakeLists.txt b/paddle/fluid/distributed/fleet_executor/CMakeLists.txt index 2d54e60265..162cbd8a7b 100644 --- a/paddle/fluid/distributed/fleet_executor/CMakeLists.txt +++ b/paddle/fluid/distributed/fleet_executor/CMakeLists.txt @@ -11,12 +11,13 @@ else() endif() cc_library(fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime_graph.cc - interceptor.cc interceptor_message_service.cc message_bus.cc + interceptor.cc compute_interceptor.cc interceptor_message_service.cc message_bus.cc DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto ${BRPC_DEPS}) if(WITH_DISTRIBUTE) set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") set_source_files_properties(interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) + set_source_files_properties(compute_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(message_bus.h PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(message_bus.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(fleet_executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) diff --git a/paddle/fluid/distributed/fleet_executor/carrier.cc b/paddle/fluid/distributed/fleet_executor/carrier.cc index 84548d7fd6..8a42533f59 100644 --- a/paddle/fluid/distributed/fleet_executor/carrier.cc +++ b/paddle/fluid/distributed/fleet_executor/carrier.cc @@ -21,6 +21,8 @@ namespace paddle { namespace distributed { +USE_INTERCEPTOR(Compute); + void Carrier::Init( const std::unordered_map& interceptor_id_to_node) { PADDLE_ENFORCE_EQ(is_init_, false, platform::errors::AlreadyExists( diff --git a/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc b/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc new file mode 100644 index 0000000000..4307665f30 --- /dev/null +++ b/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc @@ -0,0 +1,61 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/distributed/fleet_executor/compute_interceptor.h" + +#include "paddle/fluid/distributed/fleet_executor/task_node.h" + +namespace paddle { +namespace distributed { + +ComputeInterceptor::ComputeInterceptor(int64_t interceptor_id, TaskNode* node) + : Interceptor(interceptor_id, node) { + PrepareDeps(); + RegisterMsgHandle([this](const InterceptorMessage& msg) { Compute(msg); }); +} + +void ComputeInterceptor::PrepareDeps() { + auto& upstream = GetTaskNode()->upstream(); + upstream_deps_.insert(upstream.begin(), upstream.end()); +} + +void ComputeInterceptor::SendDataReadyToDownStream() { + auto& downstream = GetTaskNode()->downstream(); + for (auto dst_id : downstream) { + InterceptorMessage dst_msg; + dst_msg.set_message_type(DATA_IS_READY); + VLOG(3) << "ComputeInterceptor Send msg to " << dst_id; + Send(dst_id, dst_msg); + } +} + +void ComputeInterceptor::Compute(const InterceptorMessage& msg) { + if (msg.message_type() == DATA_IS_READY) { + auto src_id = msg.src_id(); + upstream_deps_.erase(src_id); + + // all input is ready + if (upstream_deps_.empty()) { + // TODO(wangxi): op run + VLOG(3) << "id=" << GetInterceptorId() << " ComputeInterceptor running"; + SendDataReadyToDownStream(); + PrepareDeps(); + } + } +} + +REGISTER_INTERCEPTOR(Compute, ComputeInterceptor); + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/fleet_executor/compute_interceptor.h b/paddle/fluid/distributed/fleet_executor/compute_interceptor.h new file mode 100644 index 0000000000..9b49910b9e --- /dev/null +++ b/paddle/fluid/distributed/fleet_executor/compute_interceptor.h @@ -0,0 +1,37 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/distributed/fleet_executor/interceptor.h" + +namespace paddle { +namespace distributed { + +class ComputeInterceptor : public Interceptor { + public: + ComputeInterceptor(int64_t interceptor_id, TaskNode* node); + + void PrepareDeps(); + + void SendDataReadyToDownStream(); + + void Compute(const InterceptorMessage& msg); + + private: + std::unordered_set upstream_deps_; +}; + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/fleet_executor/interceptor.cc b/paddle/fluid/distributed/fleet_executor/interceptor.cc index 3e9d6e4841..a342d4431a 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor.cc +++ b/paddle/fluid/distributed/fleet_executor/interceptor.cc @@ -76,7 +76,7 @@ bool Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) { void Interceptor::PoolTheMailbox() { // pool the local mailbox, parse the Message - while (true) { + for (;;) { if (local_mailbox_.empty()) { // local mailbox is empty, fetch the remote mailbox VLOG(3) << interceptor_id_ << "'s local mailbox is empty. " diff --git a/paddle/fluid/distributed/fleet_executor/interceptor.h b/paddle/fluid/distributed/fleet_executor/interceptor.h index 0b94318e18..9ea392ea5f 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor.h +++ b/paddle/fluid/distributed/fleet_executor/interceptor.h @@ -62,6 +62,9 @@ class Interceptor { DISABLE_COPY_AND_ASSIGN(Interceptor); + protected: + TaskNode* GetTaskNode() const { return node_; } + private: // pool the local mailbox, parse the Message void PoolTheMailbox(); @@ -114,19 +117,30 @@ class InterceptorFactory { int64_t id, TaskNode* node); }; +template +std::unique_ptr CreatorInterceptor(int64_t id, TaskNode* node) { + return std::make_unique(id, node); +} + #define REGISTER_INTERCEPTOR(interceptor_type, interceptor_class) \ - std::unique_ptr CreatorInterceptor_##interceptor_type( \ - int64_t id, TaskNode* node) { \ - return std::make_unique(id, node); \ - } \ class __RegisterInterceptor_##interceptor_type { \ public: \ __RegisterInterceptor_##interceptor_type() { \ InterceptorFactory::Register(#interceptor_type, \ - CreatorInterceptor_##interceptor_type); \ + CreatorInterceptor); \ } \ + void Touch() {} \ }; \ - __RegisterInterceptor_##interceptor_type g_register_##interceptor_type; + __RegisterInterceptor_##interceptor_type g_register_##interceptor_type; \ + int TouchRegisterInterceptor_##interceptor_type() { \ + g_register_##interceptor_type.Touch(); \ + return 0; \ + } + +#define USE_INTERCEPTOR(interceptor_type) \ + extern int TouchRegisterInterceptor_##interceptor_type(); \ + UNUSED static int use_interceptor_##interceptor_type = \ + TouchRegisterInterceptor_##interceptor_type(); } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/distributed/fleet_executor/task_node.h b/paddle/fluid/distributed/fleet_executor/task_node.h index ea7f43eb82..f438e491da 100644 --- a/paddle/fluid/distributed/fleet_executor/task_node.h +++ b/paddle/fluid/distributed/fleet_executor/task_node.h @@ -15,8 +15,10 @@ #pragma once #include #include +#include #include #include + #include "paddle/fluid/platform/macros.h" namespace paddle { @@ -33,6 +35,7 @@ class TaskNode final { TaskNode(int32_t role, const std::vector& ops, int64_t rank, int64_t task_id, int64_t max_run_times, int64_t max_slot_nums); ~TaskNode() = default; + int64_t rank() const { return rank_; } int64_t task_id() const { return task_id_; } int32_t role() const { return role_; } @@ -40,9 +43,12 @@ class TaskNode final { int64_t max_slot_nums() const { return max_slot_nums_; } const std::unordered_set& upstream() const { return upstream_; } const std::unordered_set& downstream() const { return downstream_; } + const std::string& type() const { return type_; } + void AddUpstreamTask(int64_t task_id); void AddDownstreamTask(int64_t task_id); std::string DebugString() const; + static std::unique_ptr CreateEmptyTaskNode(int32_t role, int64_t rank, int64_t task_id, @@ -63,6 +69,8 @@ class TaskNode final { int64_t task_id_; int64_t max_run_times_; int64_t max_slot_nums_; + + std::string type_; }; } // namespace distributed diff --git a/paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt b/paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt index 524aebe3b9..1d034d510a 100644 --- a/paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt +++ b/paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt @@ -1,2 +1,4 @@ set_source_files_properties(interceptor_ping_pong_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +set_source_files_properties(compute_interceptor_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) cc_test(interceptor_ping_pong_test SRCS interceptor_ping_pong_test.cc DEPS fleet_executor ${BRPC_DEPS}) +cc_test(compute_interceptor_test SRCS compute_interceptor_test.cc DEPS fleet_executor ${BRPC_DEPS}) diff --git a/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc b/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc new file mode 100644 index 0000000000..658ff25672 --- /dev/null +++ b/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc @@ -0,0 +1,75 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include + +#include "gtest/gtest.h" + +#include "paddle/fluid/distributed/fleet_executor/carrier.h" +#include "paddle/fluid/distributed/fleet_executor/interceptor.h" +#include "paddle/fluid/distributed/fleet_executor/message_bus.h" +#include "paddle/fluid/distributed/fleet_executor/task_node.h" + +namespace paddle { +namespace distributed { + +class StopInterceptor : public Interceptor { + public: + StopInterceptor(int64_t interceptor_id, TaskNode* node) + : Interceptor(interceptor_id, node) { + RegisterMsgHandle([this](const InterceptorMessage& msg) { Stop(msg); }); + } + + void Stop(const InterceptorMessage& msg) { + std::cout << GetInterceptorId() << " recv msg from " << msg.src_id() + << std::endl; + InterceptorMessage stop; + stop.set_message_type(STOP); + Send(0, stop); + Send(1, stop); + Send(2, stop); + } +}; + +TEST(ComputeInterceptor, Compute) { + MessageBus& msg_bus = MessageBus::Instance(); + msg_bus.Init({{0, 0}, {1, 0}, {2, 0}}, {{0, "127.0.0.0:0"}}, "127.0.0.0:0"); + + Carrier& carrier = Carrier::Instance(); + + // NOTE: don't delete, otherwise interceptor will use undefined node + TaskNode* node_a = new TaskNode(0, 0, 0, 0, 0); // role, rank, task_id + TaskNode* node_b = new TaskNode(0, 0, 1, 0, 0); + TaskNode* node_c = new TaskNode(0, 0, 2, 0, 0); + + // a->b->c + node_a->AddDownstreamTask(1); + node_b->AddUpstreamTask(0); + node_b->AddDownstreamTask(2); + + Interceptor* a = carrier.SetInterceptor( + 0, InterceptorFactory::Create("Compute", 0, node_a)); + carrier.SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b)); + carrier.SetInterceptor(2, std::make_unique(2, node_c)); + + carrier.SetCreatingFlag(false); + + InterceptorMessage msg; + msg.set_message_type(DATA_IS_READY); + a->Send(1, msg); +} + +} // namespace distributed +} // namespace paddle -- GitLab