From 0daa69d45467bea8874bd14f2d6a6355ecc0ea32 Mon Sep 17 00:00:00 2001
From: LiYuRio <63526175+LiYuRio@users.noreply.github.com>
Date: Wed, 17 Nov 2021 12:09:35 +0800
Subject: [PATCH] [Fleet Executor] Construct runtime graph (#37158)

---
 .../distributed/fleet_executor/CMakeLists.txt |   2 +-
 .../fleet_executor/fleet_executor.cc          |   4 +-
 .../fleet_executor/fleet_executor_desc.proto  |   3 +
 .../fleet_executor/runtime_graph.cc           | 221 ++++++++++++++++++
 .../fleet_executor/runtime_graph.h            |  31 ++-
 .../distributed/fleet_executor/task_node.cc   |  49 ++++
 .../distributed/fleet_executor/task_node.h    |  36 ++-
 python/paddle/fluid/executor.py               |   9 +
 .../tests/unittests/test_fleet_executor.py    |   2 +-
 .../test_fleet_executor_multi_devices.py      |  19 +-
 10 files changed, 364 insertions(+), 12 deletions(-)
 create mode 100644 paddle/fluid/distributed/fleet_executor/runtime_graph.cc
 create mode 100644 paddle/fluid/distributed/fleet_executor/task_node.cc

diff --git a/paddle/fluid/distributed/fleet_executor/CMakeLists.txt b/paddle/fluid/distributed/fleet_executor/CMakeLists.txt
index bb17e018c35..2d54e602653 100644
--- a/paddle/fluid/distributed/fleet_executor/CMakeLists.txt
+++ b/paddle/fluid/distributed/fleet_executor/CMakeLists.txt
@@ -10,7 +10,7 @@ else()
   set(BRPC_DEPS "")
 endif()
 
-cc_library(fleet_executor SRCS fleet_executor.cc carrier.cc
+cc_library(fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime_graph.cc
            interceptor.cc interceptor_message_service.cc message_bus.cc
            DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto ${BRPC_DEPS})
 
diff --git a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc
index 8990cae1e27..05e78e77cb7 100644
--- a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc
+++ b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc
@@ -15,6 +15,8 @@
 #include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
 #include "paddle/fluid/distributed/fleet_executor/message_bus.h"
 #include "paddle/fluid/distributed/fleet_executor/runtime_graph.h"
+#include "paddle/fluid/distributed/fleet_executor/task_node.h"
+#include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/framework/program_desc.h"
 
 namespace paddle {
@@ -31,7 +33,7 @@ FleetExecutor::~FleetExecutor() {
 }
 
 void FleetExecutor::Init(const paddle::framework::ProgramDesc& program_desc) {
-  // Compile and Initialize
+  runtime_graph_ = std::make_unique<RuntimeGraph>(program_desc, exe_desc_);
   InitMessageBus();
 }
 
diff --git a/paddle/fluid/distributed/fleet_executor/fleet_executor_desc.proto b/paddle/fluid/distributed/fleet_executor/fleet_executor_desc.proto
index c817f743227..766463eceae 100644
--- a/paddle/fluid/distributed/fleet_executor/fleet_executor_desc.proto
+++ b/paddle/fluid/distributed/fleet_executor/fleet_executor_desc.proto
@@ -24,4 +24,7 @@ message FleetExecutorDesc {
   optional string grain = 1 [ default = "coarse" ];
   optional int64 cur_rank = 2 [ default = 0 ]; // Rank id of current processor
   repeated RankInfo cluster_info = 3;
+  optional int32 dp_degree = 4 [ default = 1 ];
+  optional int32 mp_degree = 5 [ default = 1 ];
+  optional int32 pp_degree = 6 [ default = 1 ];
 }
diff --git a/paddle/fluid/distributed/fleet_executor/runtime_graph.cc b/paddle/fluid/distributed/fleet_executor/runtime_graph.cc
new file mode 100644
index 00000000000..e0fbecf2ca9
--- /dev/null
+++ b/paddle/fluid/distributed/fleet_executor/runtime_graph.cc
@@ -0,0 +1,221 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/distributed/fleet_executor/runtime_graph.h"
+#include "paddle/fluid/distributed/fleet_executor/task_node.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/framework/program_desc.h"
+
+namespace paddle {
+namespace distributed {
+namespace {
+
+using OperatorBase = RuntimeGraph::OperatorBase;
+using OpRole = paddle::framework::OpRole;
+using OpRegistry = paddle::framework::OpRegistry;
+using ProgramDesc = paddle::framework::ProgramDesc;
+
+bool IsForward(int64_t op_role) {
+  return (op_role == static_cast<int64_t>(OpRole::kForward)) ||
+         (op_role == (static_cast<int64_t>(OpRole::kForward) |
+                      static_cast<int64_t>(OpRole::kLoss)));
+}
+
+bool IsLRSched(int64_t op_role) {
+  return op_role == static_cast<int64_t>(OpRole::kLRSched);
+}
+
+bool IsBackward(int64_t op_role) {
+  return (op_role == static_cast<int64_t>(OpRole::kBackward)) ||
+         (op_role == (static_cast<int64_t>(OpRole::kBackward) |
+                      static_cast<int64_t>(OpRole::kLoss)));
+}
+
+bool IsOptimize(int64_t op_role) {
+  return op_role == static_cast<int64_t>(OpRole::kOptimize);
+}
+
+struct DistCoord {
+  int32_t dp_idx;
+  int32_t pp_idx;
+  int32_t mp_idx;
+};
+
+class DistCoordSys final {
+ public:
+  DistCoordSys(int32_t dp_degree, int32_t pp_degree, int32_t mp_degree)
+      : dp_degree_(dp_degree), pp_degree_(pp_degree), mp_degree_(mp_degree) {}
+  DistCoord RankToCoord(int64_t rank) const;
+  int64_t CoordToRank(const DistCoord& coord) const;
+
+ private:
+  DISABLE_COPY_AND_ASSIGN(DistCoordSys);
+  bool InvalidCoord(const DistCoord& coord) const;
+  int32_t dp_degree_;
+  int32_t pp_degree_;
+  int32_t mp_degree_;
+};
+
+DistCoord DistCoordSys::RankToCoord(int64_t rank) const {
+  DistCoord coord;
+  coord.mp_idx = rank % mp_degree_;
+  rank /= mp_degree_;
+  coord.pp_idx = rank % pp_degree_;
+  rank /= pp_degree_;
+  coord.dp_idx = rank % dp_degree_;
+  return coord;
+}
+
+int64_t DistCoordSys::CoordToRank(const DistCoord& coord) const {
+  if (InvalidCoord(coord)) {
+    return -1;
+  }
+  return coord.dp_idx * pp_degree_ * mp_degree_ + coord.pp_idx * mp_degree_ +
+         coord.mp_idx;
+}
+
+bool DistCoordSys::InvalidCoord(const DistCoord& coord) const {
+  return coord.mp_idx < 0 || coord.mp_idx >= mp_degree_ || coord.pp_idx < 0 ||
+         coord.pp_idx >= pp_degree_ || coord.dp_idx < 0 ||
+         coord.dp_idx >= dp_degree_;
+}
+
+}  // namespace
+
+std::vector<OpRole> RuntimeGraph::functionality_order = {
+    OpRole::kLRSched, OpRole::kForward, OpRole::kBackward, OpRole::kOptimize};
+
+RuntimeGraph::RuntimeGraph(const ProgramDesc& program,
+                           const FleetExecutorDesc& exe_desc)
+    : exe_desc_(exe_desc) {
+  if (exe_desc.grain() == "coarse") {
+    SplitProgramBasedFunctionality(program);
+    AssignTaskToIntercepter();
+    FakeDependence();
+    FakeRuntimeInfo();
+  }
+}
+
+void RuntimeGraph::SplitProgramBasedFunctionality(const ProgramDesc& program) {
+  for (const auto& op_desc : program.Block(0).AllOps()) {
+    ops_.emplace_back(OpRegistry::CreateOp(*op_desc));
+  }
+  std::unordered_map<int64_t, std::vector<OperatorBase*>> role_to_ops;
+  for (const auto& op : ops_) {
+    int64_t op_role = op->Attr<int32_t>("op_role");
+    OpRole new_op_role;
+    if (IsLRSched(op_role)) {
+      new_op_role = OpRole::kLRSched;
+    } else if (IsForward(op_role)) {
+      new_op_role = OpRole::kForward;
+    } else if (IsBackward(op_role)) {
+      new_op_role = OpRole::kBackward;
+    } else if (IsOptimize(op_role)) {
+      new_op_role = OpRole::kOptimize;
+    } else {
+      PADDLE_THROW(platform::errors::PreconditionNotMet(
+          "The op %s is None of LRSched, Forward, Backward or Optimize.",
+          op->Type()));
+    }
+    int64_t new_op_role_id = static_cast<int64_t>(new_op_role);
+    if (role_to_ops.find(new_op_role_id) == role_to_ops.end()) {
+      role_to_ops.insert({new_op_role_id, {}});
+    }
+    role_to_ops.at(new_op_role_id).emplace_back(op.get());
+  }
+  int64_t cur_rank = exe_desc_.cur_rank();
+  int64_t task_id = cur_rank * functionality_order.size();
+  for (std::size_t i = 0; i < functionality_order.size(); ++i) {
+    OpRole role = functionality_order[i];
+    int64_t role_id = static_cast<int64_t>(role);
+    if (role_to_ops.find(role_id) == role_to_ops.end()) {
+      task_nodes_.emplace_back(
+          TaskNode::CreateEmptyTaskNode(role_id, cur_rank, task_id));
+    } else {
+      task_nodes_.emplace_back(TaskNode::CreateTaskNode(
+          role_id, role_to_ops.at(role_id), cur_rank, task_id));
+    }
+    ++task_id;
+  }
+}
+
+void RuntimeGraph::FakeDependence() {
+  int64_t cur_rank = exe_desc_.cur_rank();
+  DistCoordSys coord_sys(exe_desc_.dp_degree(), exe_desc_.pp_degree(),
+                         exe_desc_.mp_degree());
+  const auto& coord = coord_sys.RankToCoord(cur_rank);
+  DistCoord upstream_coord = coord, downstream_coord = coord;
+  upstream_coord.pp_idx -= 1;
+  downstream_coord.pp_idx += 1;
+  int64_t pp_upstream = coord_sys.CoordToRank(upstream_coord);
+  int64_t pp_downstream = coord_sys.CoordToRank(downstream_coord);
+  int32_t num_of_functionality = functionality_order.size();
+  // lr -> forward -> backward -> optimize
+  //          |          |
+  // lr -> forward -> backward -> optimize
+  for (std::size_t i = 0; i < task_nodes_.size(); ++i) {
+    if (i != 0) {
+      task_nodes_[i]->AddUpstreamTask(cur_rank * num_of_functionality + i - 1);
+    }
+    if (i != task_nodes_.size() - 1) {
+      task_nodes_[i]->AddDownstreamTask(cur_rank * num_of_functionality + i +
+                                        1);
+    }
+    if (IsForward(task_nodes_[i]->role())) {
+      if (pp_upstream != -1) {
+        task_nodes_[i]->AddUpstreamTask(pp_upstream * num_of_functionality + i);
+      }
+      if (pp_downstream != -1) {
+        task_nodes_[i]->AddDownstreamTask(pp_downstream * num_of_functionality +
+                                          i);
+      }
+    } else if (IsBackward(task_nodes_[i]->role())) {
+      if (pp_downstream != -1) {
+        task_nodes_[i]->AddUpstreamTask(pp_downstream * num_of_functionality +
+                                        i);
+      }
+      if (pp_upstream != -1) {
+        task_nodes_[i]->AddDownstreamTask(pp_upstream * num_of_functionality +
+                                          i);
+      }
+    }
+  }
+}
+
+void RuntimeGraph::AssignTaskToIntercepter() {
+  for (const auto& task : task_nodes_) {
+    int64_t intercepter_id = task->task_id();
+    if (intercepter_id_to_node_.find(intercepter_id) !=
+        intercepter_id_to_node_.end()) {
+      PADDLE_THROW(platform::errors::PreconditionNotMet(
+          "Repeated intercepter id: %d", intercepter_id));
+    }
+    intercepter_id_to_node_.insert({intercepter_id, task.get()});
+  }
+}
+
+void RuntimeGraph::FakeRuntimeInfo() {
+  int64_t nrank = exe_desc_.cluster_info().size();
+  int32_t num_of_functionality = functionality_order.size();
+  for (int64_t i = 0; i < nrank; ++i) {
+    for (int64_t j = 0; j < num_of_functionality; ++j) {
+      int64_t intercepter_id = i * num_of_functionality + j;
+      intercepter_id_to_rank_.insert({intercepter_id, i});
+    }
+  }
+}
+
+}  // namespace distributed
+}  // namespace paddle
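
For reference, the DistCoordSys helper above maps a linear rank to a (dp, pp, mp) coordinate with the model-parallel index varying fastest, and CoordToRank inverts that mapping; FakeDependence uses the resulting pipeline neighbours. Below is a minimal standalone sketch of the same row-major round trip, written outside the patch; the struct and function names and the 2x2x2 degrees are illustrative only and are not Paddle APIs.

#include <cassert>
#include <cstdint>
#include <iostream>

struct Coord { int32_t dp, pp, mp; };

// Row-major decomposition: mp varies fastest, then pp, then dp,
// mirroring DistCoordSys::RankToCoord above.
Coord RankToCoord(int64_t rank, int32_t pp_degree, int32_t mp_degree) {
  Coord c;
  c.mp = rank % mp_degree;
  rank /= mp_degree;
  c.pp = rank % pp_degree;
  rank /= pp_degree;
  c.dp = rank;
  return c;
}

// Inverse mapping, mirroring DistCoordSys::CoordToRank above.
int64_t CoordToRank(const Coord& c, int32_t pp_degree, int32_t mp_degree) {
  return (static_cast<int64_t>(c.dp) * pp_degree + c.pp) * mp_degree + c.mp;
}

int main() {
  const int32_t dp = 2, pp = 2, mp = 2;  // illustrative degrees only
  for (int64_t r = 0; r < dp * pp * mp; ++r) {
    Coord c = RankToCoord(r, pp, mp);
    assert(CoordToRank(c, pp, mp) == r);  // the round trip is the identity
    std::cout << r << " -> (dp=" << c.dp << ", pp=" << c.pp
              << ", mp=" << c.mp << ")\n";
  }
  return 0;
}
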
diff --git a/paddle/fluid/distributed/fleet_executor/runtime_graph.h b/paddle/fluid/distributed/fleet_executor/runtime_graph.h
index 7ae573039e6..b25a93080da 100644
--- a/paddle/fluid/distributed/fleet_executor/runtime_graph.h
+++ b/paddle/fluid/distributed/fleet_executor/runtime_graph.h
@@ -13,21 +13,50 @@
 // limitations under the License.
 
 #pragma once
+#include <memory>
+#include <unordered_map>
+#include <vector>
+#include "paddle/fluid/distributed/fleet_executor/fleet_executor_desc.pb.h"
+#include "paddle/fluid/framework/op_proto_maker.h"
+#include "paddle/fluid/platform/macros.h"
 
 namespace paddle {
 namespace framework {
 class ProgramDesc;
+class OperatorBase;
 }
 
 namespace distributed {
+class TaskNode;
 
 class RuntimeGraph final {
  public:
+  using ProgramDesc = paddle::framework::ProgramDesc;
+  using OperatorBase = paddle::framework::OperatorBase;
   RuntimeGraph() = default;
-  explicit RuntimeGraph(const paddle::framework::ProgramDesc &program) {}
+  explicit RuntimeGraph(const ProgramDesc& program,
+                        const FleetExecutorDesc& exe_desc);
   ~RuntimeGraph() = default;
+  const std::unordered_map<int64_t, TaskNode*>& intercepter_id_to_node() const {
+    return intercepter_id_to_node_;
+  }
+  const std::unordered_map<int64_t, int64_t>& intercepter_id_to_rank() const {
+    return intercepter_id_to_rank_;
+  }
 
  private:
   DISABLE_COPY_AND_ASSIGN(RuntimeGraph);
+  void SplitProgramBasedFunctionality(const ProgramDesc& program);
+  void FakeDependence();
+  void AssignTaskToIntercepter();
+  void FakeRuntimeInfo();
+  // LRSched, Forward, Backward, Optimize
+  static std::vector<paddle::framework::OpRole> functionality_order;
+  std::vector<std::unique_ptr<TaskNode>> task_nodes_;
+  std::vector<std::unique_ptr<OperatorBase>> ops_;
+  std::unordered_map<int64_t, TaskNode*> intercepter_id_to_node_;
+  std::unordered_map<int64_t, int64_t> intercepter_id_to_rank_;
+  FleetExecutorDesc exe_desc_;
 };
 
 }  // namespace distributed
diff --git a/paddle/fluid/distributed/fleet_executor/task_node.cc b/paddle/fluid/distributed/fleet_executor/task_node.cc
new file mode 100644
index 00000000000..de85871af51
--- /dev/null
+++ b/paddle/fluid/distributed/fleet_executor/task_node.cc
@@ -0,0 +1,49 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/distributed/fleet_executor/task_node.h"
+#include "paddle/fluid/framework/operator.h"
+
+namespace paddle {
+namespace distributed {
+namespace {
+using OperatorBase = TaskNode::OperatorBase;
+}
+
+TaskNode::TaskNode(int64_t role, const std::vector<OperatorBase*>& ops,
+                   int64_t rank, int64_t task_id)
+    : ops_(ops), role_(role), rank_(rank), task_id_(task_id) {}
+
+TaskNode::TaskNode(int64_t role, int64_t rank, int64_t task_id)
+    : role_(role), rank_(rank), task_id_(task_id) {}
+
+std::unique_ptr<TaskNode> TaskNode::CreateEmptyTaskNode(int64_t role,
+                                                        int64_t rank,
+                                                        int64_t task_id) {
+  return std::make_unique<TaskNode>(role, rank, task_id);
+}
+
+std::unique_ptr<TaskNode> TaskNode::CreateTaskNode(
+    int64_t role, const std::vector<OperatorBase*>& ops, int64_t rank,
+    int64_t task_id) {
+  return std::make_unique<TaskNode>(role, ops, rank, task_id);
+}
+
+void TaskNode::AddUpstreamTask(int64_t task_id) { upstream_.insert(task_id); }
+
+void TaskNode::AddDownstreamTask(int64_t task_id) {
+  downstream_.insert(task_id);
+}
+}  // namespace distributed
+}  // namespace paddle
diff --git a/paddle/fluid/distributed/fleet_executor/task_node.h b/paddle/fluid/distributed/fleet_executor/task_node.h
index 62fb9dfb011..e341f525071 100644
--- a/paddle/fluid/distributed/fleet_executor/task_node.h
+++ b/paddle/fluid/distributed/fleet_executor/task_node.h
@@ -13,14 +13,48 @@
 // limitations under the License.
 
 #pragma once
+#include <cstdint>
+#include <memory>
+#include <unordered_set>
+#include <vector>
+#include "paddle/fluid/platform/macros.h"
 
 namespace paddle {
+namespace framework {
+class OperatorBase;
+}
 namespace distributed {
 
 class TaskNode final {
  public:
-  TaskNode() = default;
+  using OperatorBase = paddle::framework::OperatorBase;
+  TaskNode(int64_t role, int64_t rank, int64_t task_id);
+  TaskNode(int64_t role, const std::vector<OperatorBase*>& ops, int64_t rank,
+           int64_t task_id);
   ~TaskNode() = default;
 
+  int64_t rank() const { return rank_; }
+  int64_t task_id() const { return task_id_; }
+  int64_t role() const { return role_; }
+  const std::unordered_set<int64_t>& upstream() const { return upstream_; }
+  const std::unordered_set<int64_t>& downstream() const { return downstream_; }
+  void AddUpstreamTask(int64_t task_id);
+  void AddDownstreamTask(int64_t task_id);
+  static std::unique_ptr<TaskNode> CreateEmptyTaskNode(int64_t role,
+                                                       int64_t rank,
+                                                       int64_t task_id);
+  static std::unique_ptr<TaskNode> CreateTaskNode(
+      int64_t role, const std::vector<OperatorBase*>& ops, int64_t rank,
+      int64_t task_id);
+
+ private:
+  DISABLE_COPY_AND_ASSIGN(TaskNode);
+  TaskNode() = default;
+  std::vector<OperatorBase*> ops_;
+  std::unordered_set<int64_t> upstream_;
+  std::unordered_set<int64_t> downstream_;
+  int64_t role_;
+  int64_t rank_;
+  int64_t task_id_;
 };
 
 }  // namespace distributed
diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py
index 393232a2061..377994252d7 100644
--- a/python/paddle/fluid/executor.py
+++ b/python/paddle/fluid/executor.py
@@ -1890,6 +1890,7 @@ class Executor(object):
         cur_rank = os.getenv("PADDLE_TRAINER_ID")
         trainer_endpoints_str = os.getenv("PADDLE_TRAINER_ENDPOINTS")
         fleet_exe_desc = fleet_executor_desc_pb2.FleetExecutorDesc()
+        nrank = 1
         if cur_rank and trainer_endpoints_str:
             fleet_exe_desc.cur_rank = int(cur_rank)
             trainer_endpoints = trainer_endpoints_str.split(',')
@@ -1898,8 +1899,16 @@
                 rank_info.rank = rank
                 rank_info.ip_port = endpoint
                 fleet_exe_desc.cluster_info.append(rank_info)
+            nrank = len(trainer_endpoints)
         else:
             logging.warning("Fleet Executor will run on single device only.")
+        fleet_opt = program._pipeline_opt["fleet_opt"]
+        if "dist_strategy" in fleet_opt:
fleet_opt["dist_strategy"]["dp_degree"] + fleet_exe_desc.mp_degree = fleet_opt["dist_strategy"]["mp_degree"] + fleet_exe_desc.pp_degree = fleet_opt["dist_strategy"]["pp_degree"] + num_of_gpu = fleet_exe_desc.dp_degree * fleet_exe_desc.mp_degree * fleet_exe_desc.pp_degree + assert nrank == num_of_gpu, "The number of rank is not equal to the number of gpu." fleet_exe = core.FleetExecutor(fleet_exe_desc.SerializeToString()) fleet_exe.init(program._pipeline_opt["section_program"].desc) fleet_exe.run() diff --git a/python/paddle/fluid/tests/unittests/test_fleet_executor.py b/python/paddle/fluid/tests/unittests/test_fleet_executor.py index 48952cab3db..6ba9e2d9e21 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_executor.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_executor.py @@ -26,7 +26,7 @@ class TestFleetExecutor(unittest.TestCase): with fluid.program_guard(empty_program, empty_program): x = fluid.layers.data(name='x', shape=[1], dtype=paddle.float32) empty_program._pipeline_opt = { - "fleet_opt": True, + "fleet_opt": {}, "section_program": empty_program } exe.run(empty_program, feed={'x': [1]}) diff --git a/python/paddle/fluid/tests/unittests/test_fleet_executor_multi_devices.py b/python/paddle/fluid/tests/unittests/test_fleet_executor_multi_devices.py index 1afe4b94753..473d49fea48 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_executor_multi_devices.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_executor_multi_devices.py @@ -16,18 +16,19 @@ import unittest import os import paddle import paddle.fluid as fluid +import paddle.distributed.fleet as fleet paddle.enable_static() class TestFleetExecutor(unittest.TestCase): - def run_fleet_executor(self, place): + def run_fleet_executor(self, place, fleet_opt=dict()): exe = paddle.static.Executor(place) empty_program = paddle.static.Program() with fluid.program_guard(empty_program, empty_program): x = fluid.layers.data(name='x', shape=[1], dtype=paddle.float32) empty_program._pipeline_opt = { - "fleet_opt": True, + "fleet_opt": fleet_opt, "section_program": empty_program } exe.run(empty_program, feed={'x': [1]}) @@ -35,12 +36,16 @@ class TestFleetExecutor(unittest.TestCase): def test_dist_executor_on_multi_devices(self): os.environ["PADDLE_TRAINER_ID"] = "0" os.environ[ - "PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:7000,127.0.0.1:7001,127.0.0.1:7002" - places = [fluid.CPUPlace()] + "PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:7000,127.0.0.1:7001,127.0.0.1:7002,127.0.0.1:7003,127.0.0.1:7004,127.0.0.1:7005,127.0.0.1:7006,127.0.0.1:7007" + strategy = fleet.DistributedStrategy() + strategy.sharding_configs = { + "dp_degree": 2, + "mp_degree": 2, + "pp_degree": 2 + } + fleet_opt = {"dist_strategy": strategy.sharding_configs} if fluid.is_compiled_with_cuda(): - places.append(fluid.CUDAPlace(0)) - for place in places: - self.run_fleet_executor(place) + self.run_fleet_executor(fluid.CUDAPlace(0), fleet_opt) if __name__ == "__main__": -- GitLab