diff --git a/paddle/fluid/distributed/fleet_executor/CMakeLists.txt b/paddle/fluid/distributed/fleet_executor/CMakeLists.txt index 2d54e602653fdfd209c572b191b5610f91b9d92b..162cbd8a7b5204761e521485b26470a708bc71ba 100644 --- a/paddle/fluid/distributed/fleet_executor/CMakeLists.txt +++ b/paddle/fluid/distributed/fleet_executor/CMakeLists.txt @@ -11,12 +11,13 @@ else() endif() cc_library(fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime_graph.cc - interceptor.cc interceptor_message_service.cc message_bus.cc + interceptor.cc compute_interceptor.cc interceptor_message_service.cc message_bus.cc DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto ${BRPC_DEPS}) if(WITH_DISTRIBUTE) set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") set_source_files_properties(interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) + set_source_files_properties(compute_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(message_bus.h PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(message_bus.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(fleet_executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) diff --git a/paddle/fluid/distributed/fleet_executor/carrier.cc b/paddle/fluid/distributed/fleet_executor/carrier.cc index 84548d7fd69c056c2cfe8de3d1f9092e978a44ab..8a42533f59e3d5b1e4a8925a9a37b81e504d076f 100644 --- a/paddle/fluid/distributed/fleet_executor/carrier.cc +++ b/paddle/fluid/distributed/fleet_executor/carrier.cc @@ -21,6 +21,8 @@ namespace paddle { namespace distributed { +USE_INTERCEPTOR(Compute); + void Carrier::Init( const std::unordered_map& interceptor_id_to_node) { PADDLE_ENFORCE_EQ(is_init_, false, platform::errors::AlreadyExists( diff --git a/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc b/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc new file mode 100644 index 0000000000000000000000000000000000000000..4307665f30e5389265d54926f43940ec1cc215e6 --- /dev/null +++ b/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc @@ -0,0 +1,61 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/distributed/fleet_executor/compute_interceptor.h" + +#include "paddle/fluid/distributed/fleet_executor/task_node.h" + +namespace paddle { +namespace distributed { + +ComputeInterceptor::ComputeInterceptor(int64_t interceptor_id, TaskNode* node) + : Interceptor(interceptor_id, node) { + PrepareDeps(); + RegisterMsgHandle([this](const InterceptorMessage& msg) { Compute(msg); }); +} + +void ComputeInterceptor::PrepareDeps() { + auto& upstream = GetTaskNode()->upstream(); + upstream_deps_.insert(upstream.begin(), upstream.end()); +} + +void ComputeInterceptor::SendDataReadyToDownStream() { + auto& downstream = GetTaskNode()->downstream(); + for (auto dst_id : downstream) { + InterceptorMessage dst_msg; + dst_msg.set_message_type(DATA_IS_READY); + VLOG(3) << "ComputeInterceptor Send msg to " << dst_id; + Send(dst_id, dst_msg); + } +} + +void ComputeInterceptor::Compute(const InterceptorMessage& msg) { + if (msg.message_type() == DATA_IS_READY) { + auto src_id = msg.src_id(); + upstream_deps_.erase(src_id); + + // all input is ready + if (upstream_deps_.empty()) { + // TODO(wangxi): op run + VLOG(3) << "id=" << GetInterceptorId() << " ComputeInterceptor running"; + SendDataReadyToDownStream(); + PrepareDeps(); + } + } +} + +REGISTER_INTERCEPTOR(Compute, ComputeInterceptor); + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/fleet_executor/compute_interceptor.h b/paddle/fluid/distributed/fleet_executor/compute_interceptor.h new file mode 100644 index 0000000000000000000000000000000000000000..9b49910b9eb78312faffa03228e4f7b7615f0d64 --- /dev/null +++ b/paddle/fluid/distributed/fleet_executor/compute_interceptor.h @@ -0,0 +1,37 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/distributed/fleet_executor/interceptor.h" + +namespace paddle { +namespace distributed { + +class ComputeInterceptor : public Interceptor { + public: + ComputeInterceptor(int64_t interceptor_id, TaskNode* node); + + void PrepareDeps(); + + void SendDataReadyToDownStream(); + + void Compute(const InterceptorMessage& msg); + + private: + std::unordered_set upstream_deps_; +}; + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/fleet_executor/interceptor.cc b/paddle/fluid/distributed/fleet_executor/interceptor.cc index 3e9d6e484125b92d8d71ad4d55c9a30be959bae0..a342d4431a1aa3070cf335322a6219ce1d18cec1 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor.cc +++ b/paddle/fluid/distributed/fleet_executor/interceptor.cc @@ -76,7 +76,7 @@ bool Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) { void Interceptor::PoolTheMailbox() { // pool the local mailbox, parse the Message - while (true) { + for (;;) { if (local_mailbox_.empty()) { // local mailbox is empty, fetch the remote mailbox VLOG(3) << interceptor_id_ << "'s local mailbox is empty. " diff --git a/paddle/fluid/distributed/fleet_executor/interceptor.h b/paddle/fluid/distributed/fleet_executor/interceptor.h index 0b94318e18d070a2e1bd3a899660c10ac7a17e00..9ea392ea5f8cab6ef1d201d892776ad813f60a2e 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor.h +++ b/paddle/fluid/distributed/fleet_executor/interceptor.h @@ -62,6 +62,9 @@ class Interceptor { DISABLE_COPY_AND_ASSIGN(Interceptor); + protected: + TaskNode* GetTaskNode() const { return node_; } + private: // pool the local mailbox, parse the Message void PoolTheMailbox(); @@ -114,19 +117,30 @@ class InterceptorFactory { int64_t id, TaskNode* node); }; +template +std::unique_ptr CreatorInterceptor(int64_t id, TaskNode* node) { + return std::make_unique(id, node); +} + #define REGISTER_INTERCEPTOR(interceptor_type, interceptor_class) \ - std::unique_ptr CreatorInterceptor_##interceptor_type( \ - int64_t id, TaskNode* node) { \ - return std::make_unique(id, node); \ - } \ class __RegisterInterceptor_##interceptor_type { \ public: \ __RegisterInterceptor_##interceptor_type() { \ InterceptorFactory::Register(#interceptor_type, \ - CreatorInterceptor_##interceptor_type); \ + CreatorInterceptor); \ } \ + void Touch() {} \ }; \ - __RegisterInterceptor_##interceptor_type g_register_##interceptor_type; + __RegisterInterceptor_##interceptor_type g_register_##interceptor_type; \ + int TouchRegisterInterceptor_##interceptor_type() { \ + g_register_##interceptor_type.Touch(); \ + return 0; \ + } + +#define USE_INTERCEPTOR(interceptor_type) \ + extern int TouchRegisterInterceptor_##interceptor_type(); \ + UNUSED static int use_interceptor_##interceptor_type = \ + TouchRegisterInterceptor_##interceptor_type(); } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/distributed/fleet_executor/task_node.h b/paddle/fluid/distributed/fleet_executor/task_node.h index ea7f43eb82bb9726404f3d34312d0d41cb78c230..f438e491daa53c51548489be8fc7ca3bd3cc7817 100644 --- a/paddle/fluid/distributed/fleet_executor/task_node.h +++ b/paddle/fluid/distributed/fleet_executor/task_node.h @@ -15,8 +15,10 @@ #pragma once #include #include +#include #include #include + #include "paddle/fluid/platform/macros.h" namespace paddle { @@ -33,6 +35,7 @@ class TaskNode final { TaskNode(int32_t role, const std::vector& ops, int64_t rank, int64_t task_id, int64_t max_run_times, int64_t max_slot_nums); ~TaskNode() = default; + int64_t rank() const { return rank_; } int64_t task_id() const { return task_id_; } int32_t role() const { return role_; } @@ -40,9 +43,12 @@ class TaskNode final { int64_t max_slot_nums() const { return max_slot_nums_; } const std::unordered_set& upstream() const { return upstream_; } const std::unordered_set& downstream() const { return downstream_; } + const std::string& type() const { return type_; } + void AddUpstreamTask(int64_t task_id); void AddDownstreamTask(int64_t task_id); std::string DebugString() const; + static std::unique_ptr CreateEmptyTaskNode(int32_t role, int64_t rank, int64_t task_id, @@ -63,6 +69,8 @@ class TaskNode final { int64_t task_id_; int64_t max_run_times_; int64_t max_slot_nums_; + + std::string type_; }; } // namespace distributed diff --git a/paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt b/paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt index 524aebe3b959f5f5833376765291e4694145ca39..1d034d510a9cc7c2dd2ce7e76ae8665f600da09f 100644 --- a/paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt +++ b/paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt @@ -1,2 +1,4 @@ set_source_files_properties(interceptor_ping_pong_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +set_source_files_properties(compute_interceptor_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) cc_test(interceptor_ping_pong_test SRCS interceptor_ping_pong_test.cc DEPS fleet_executor ${BRPC_DEPS}) +cc_test(compute_interceptor_test SRCS compute_interceptor_test.cc DEPS fleet_executor ${BRPC_DEPS}) diff --git a/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc b/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..658ff25672df4d5156a89a71bbe43cf57da84af7 --- /dev/null +++ b/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc @@ -0,0 +1,75 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include + +#include "gtest/gtest.h" + +#include "paddle/fluid/distributed/fleet_executor/carrier.h" +#include "paddle/fluid/distributed/fleet_executor/interceptor.h" +#include "paddle/fluid/distributed/fleet_executor/message_bus.h" +#include "paddle/fluid/distributed/fleet_executor/task_node.h" + +namespace paddle { +namespace distributed { + +class StopInterceptor : public Interceptor { + public: + StopInterceptor(int64_t interceptor_id, TaskNode* node) + : Interceptor(interceptor_id, node) { + RegisterMsgHandle([this](const InterceptorMessage& msg) { Stop(msg); }); + } + + void Stop(const InterceptorMessage& msg) { + std::cout << GetInterceptorId() << " recv msg from " << msg.src_id() + << std::endl; + InterceptorMessage stop; + stop.set_message_type(STOP); + Send(0, stop); + Send(1, stop); + Send(2, stop); + } +}; + +TEST(ComputeInterceptor, Compute) { + MessageBus& msg_bus = MessageBus::Instance(); + msg_bus.Init({{0, 0}, {1, 0}, {2, 0}}, {{0, "127.0.0.0:0"}}, "127.0.0.0:0"); + + Carrier& carrier = Carrier::Instance(); + + // NOTE: don't delete, otherwise interceptor will use undefined node + TaskNode* node_a = new TaskNode(0, 0, 0, 0, 0); // role, rank, task_id + TaskNode* node_b = new TaskNode(0, 0, 1, 0, 0); + TaskNode* node_c = new TaskNode(0, 0, 2, 0, 0); + + // a->b->c + node_a->AddDownstreamTask(1); + node_b->AddUpstreamTask(0); + node_b->AddDownstreamTask(2); + + Interceptor* a = carrier.SetInterceptor( + 0, InterceptorFactory::Create("Compute", 0, node_a)); + carrier.SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b)); + carrier.SetInterceptor(2, std::make_unique(2, node_c)); + + carrier.SetCreatingFlag(false); + + InterceptorMessage msg; + msg.set_message_type(DATA_IS_READY); + a->Send(1, msg); +} + +} // namespace distributed +} // namespace paddle