From fc701369095706ce7f6d7282b60de37b4759cba5 Mon Sep 17 00:00:00 2001 From: Yuang Liu Date: Fri, 17 Dec 2021 10:36:40 +0800 Subject: [PATCH] [fleet_executor] run time graph on python side (#38164) --- .../fleet_executor/fleet_executor.cc | 17 ++ .../distributed/fleet_executor/task_node.cc | 24 ++- .../distributed/fleet_executor/task_node.h | 9 +- paddle/fluid/pybind/bind_fleet_executor.cc | 9 +- .../distributed/fleet/fleet_executor_utils.py | 203 ++++++++++++++++++ python/paddle/fluid/executor.py | 22 +- .../tests/unittests/test_fleet_executor.py | 3 +- 7 files changed, 279 insertions(+), 8 deletions(-) create mode 100644 python/paddle/distributed/fleet/fleet_executor_utils.py diff --git a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc index 9bba2624bf2..0369c442734 100644 --- a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc +++ b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc @@ -17,6 +17,9 @@ #include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/distributed/fleet_executor/runtime_graph.h" #include "paddle/fluid/distributed/fleet_executor/task_node.h" +#include "paddle/fluid/framework/executor_gc_helper.h" +#include "paddle/fluid/framework/op_desc.h" +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/scope.h" @@ -38,16 +41,30 @@ void FleetExecutor::Init( const platform::Place& place, const std::vector& task_nodes, const std::unordered_map& task_id_to_rank) { if (task_nodes.size() == 0) { + LOG(INFO) << "fleet executor will use c++ side scheduler construction."; runtime_graph_ = std::make_shared(program_desc, exe_desc_); } else { + LOG(INFO) << "fleet executor has been set dependency on python side."; + // TODO(fleet_exe devs): the unused_vars should be got from run time graph + std::vector> ops; + for (auto task_node : task_nodes) { + for (auto op : task_node->ops()) { + ops.emplace_back(std::unique_ptr(op)); + } + } + auto unused_vars = framework::GetUnusedVars(program_desc.Block(0), ops, {}); runtime_graph_ = std::make_shared(); std::unordered_map interceptor_id_to_task; for (auto task_node : task_nodes) { + task_node->SetUnusedVars(unused_vars); int64_t interceptor_id = task_node->task_id(); interceptor_id_to_task.emplace(interceptor_id, task_node); } runtime_graph_->SetInterceptorIdToRank(task_id_to_rank); runtime_graph_->SetInterceptorIdToNode(interceptor_id_to_task); + for (auto& unique_op : ops) { + unique_op.release(); + } } root_scope_ = scope; place_ = place; diff --git a/paddle/fluid/distributed/fleet_executor/task_node.cc b/paddle/fluid/distributed/fleet_executor/task_node.cc index e92ab09d481..f03ee0acb47 100644 --- a/paddle/fluid/distributed/fleet_executor/task_node.cc +++ b/paddle/fluid/distributed/fleet_executor/task_node.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/distributed/fleet_executor/task_node.h" +#include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" @@ -39,7 +40,28 @@ TaskNode::TaskNode(const framework::ProgramDesc& program, int64_t rank, } } -TaskNode::TaskNode(int32_t role, const std::vector& ops, +TaskNode::TaskNode(int32_t role, + const std::vector& op_descs, + int64_t rank, int64_t task_id, int64_t max_run_times, + int64_t max_slot_nums) + : role_(role), + rank_(rank), + task_id_(task_id), + max_run_times_(max_run_times), + max_slot_nums_(max_slot_nums) { + if (op_descs.empty()) { + return; + } + for (const auto& desc : op_descs) { + ops_vec_.emplace_back(framework::OpRegistry::CreateOp(*desc)); + } + for (const auto& op : ops_vec_) { + ops_.emplace_back(op.get()); + } +} + +TaskNode::TaskNode(int32_t role, + const std::vector& ops, int64_t rank, int64_t task_id, int64_t max_run_times, int64_t max_slot_nums) : ops_(ops), diff --git a/paddle/fluid/distributed/fleet_executor/task_node.h b/paddle/fluid/distributed/fleet_executor/task_node.h index 37105bdd230..d43cd99926d 100644 --- a/paddle/fluid/distributed/fleet_executor/task_node.h +++ b/paddle/fluid/distributed/fleet_executor/task_node.h @@ -25,6 +25,7 @@ namespace paddle { namespace framework { class OperatorBase; +class OpDesc; } namespace distributed { @@ -33,8 +34,12 @@ class TaskNode final { using OperatorBase = paddle::framework::OperatorBase; TaskNode(int32_t role, int64_t rank, int64_t task_id, int64_t max_run_times, int64_t max_slot_nums); - TaskNode(int32_t role, const std::vector& ops, int64_t rank, - int64_t task_id, int64_t max_run_times, int64_t max_slot_nums); + TaskNode(int32_t role, const std::vector& op_descs, + int64_t rank, int64_t task_id, int64_t max_run_times, + int64_t max_slot_nums); + TaskNode(int32_t role, const std::vector& ops, + int64_t rank, int64_t task_id, int64_t max_run_times, + int64_t max_slot_nums); TaskNode(const paddle::framework::ProgramDesc& program, int64_t rank, int64_t max_run_times, int64_t max_slot_nums); ~TaskNode() = default; diff --git a/paddle/fluid/pybind/bind_fleet_executor.cc b/paddle/fluid/pybind/bind_fleet_executor.cc index 6fc9b2a494f..cd7e559aa16 100644 --- a/paddle/fluid/pybind/bind_fleet_executor.cc +++ b/paddle/fluid/pybind/bind_fleet_executor.cc @@ -28,6 +28,7 @@ namespace pybind { using paddle::distributed::FleetExecutor; using paddle::distributed::TaskNode; +using paddle::framework::OpDesc; void BindFleetExecutor(py::module* m) { py::class_(*m, "FleetExecutor") @@ -38,9 +39,15 @@ void BindFleetExecutor(py::module* m) { py::class_(*m, "TaskNode") .def(py::init()) + .def(py::init&, int64_t, + int64_t, int64_t, int64_t>()) .def("task_id", &TaskNode::task_id) .def("add_upstream_task", &TaskNode::AddUpstreamTask) - .def("add_downstream_task", &TaskNode::AddDownstreamTask); + .def("add_downstream_task", &TaskNode::AddDownstreamTask) + .def("set_run_pre_steps", &TaskNode::SetRunPerSteps) + .def("set_run_at_offset", &TaskNode::SetRunAtOffset) + .def("set_type", &TaskNode::SetType) + .def("role", &TaskNode::role); } } // namespace pybind } // namespace paddle diff --git a/python/paddle/distributed/fleet/fleet_executor_utils.py b/python/paddle/distributed/fleet/fleet_executor_utils.py new file mode 100644 index 00000000000..9422774bb64 --- /dev/null +++ b/python/paddle/distributed/fleet/fleet_executor_utils.py @@ -0,0 +1,203 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY +from paddle.fluid import core + + +class CoordSys: + """ + This class is used to mapping rank to (mp rank, sharding rank, pp rank, dp rank). + """ + + def __init__(self, dist_opt): + self.dp_degree = dist_opt.get('dp_degree', 1) + self.pp_degree = dist_opt.get('pp_degree', 1) + self.sharding_degree = dist_opt.get('sharding_degree', 1) + self.mp_degree = dist_opt.get('mp_degree', 1) + + def _invalide_coord(self, coord): + """ + Test the input coord is valid or not. + :param coord: The coord to be tested + :return: False if valid, True if invalid. + """ + return coord['mp_idx'] < 0 or coord['mp_idx'] >= self.mp_degree or \ + coord['sharding_idx'] < 0 or coord['sharding_idx'] >= self.sharding_degree or \ + coord['pp_idx'] < 0 or coord['pp_idx'] >= self.pp_degree or \ + coord['dp_idx'] < 0 or coord['dp_idx'] >= self.dp_degree + + def coord_to_rank(self, coord): + """ + Map the input coord to it's corresponding rank. + :param coord: The coord to be converted + :return: The rank corresponding with the coord + """ + if self._invalide_coord(coord): + return -1 + return int(coord['dp_idx'] * self.pp_degree * self.sharding_degree * self.mp_degree + \ + coord['pp_idx'] * self.sharding_degree * self.mp_degree + \ + coord['sharding_idx'] * self.mp_degree + coord['mp_idx']) + + def rank_to_coord(self, rank): + """ + Map the input rank to it's corresponding coord + :param rank: The rank to be converted + :return: The coord corresponding with the rank + """ + mp_idx = rank % self.mp_degree + rank //= self.mp_degree + sharding_idx = rank % self.sharding_degree + rank //= self.sharding_degree + pp_idx = rank % self.pp_degree + rank //= self.pp_degree + dp_idx = rank % self.dp_degree + return { + 'mp_idx': int(mp_idx), + 'sharding_idx': int(sharding_idx), + 'pp_idx': int(pp_idx), + 'dp_idx': int(dp_idx) + } + + +def is_optimizer_op(op_role): + return op_role == int(OpRole.Optimize) + + +def is_lr_sched_op(op_role): + return op_role == int(OpRole.Optimize.LRSched) + + +def is_forward_op(op_role): + return (op_role == int(OpRole.Forward)) or \ + (op_role == (int(OpRole.Forward) ^ int(OpRole.Loss))) + + +def is_backward_op(op_role): + return (op_role == int(OpRole.Backward)) or \ + (op_role == (int(OpRole.Backward) ^ int(OpRole.Loss))) + + +def one_f_one_b(program, cur_rank, max_run_times, dist_opt, nrank): + """ + Split the program to support 1f1b pipeline scheduler. + This funct will split the program based on the op_role. + The program will be split into four parts: lr_sched, fwd, bwd, opt. + And will create task nodes based on the four parts of the program. + :param program: The origin program. + :param cur_rank: Current rank (can be got from fleet.worker_index()). + :param max_run_times: Max run times for a micro batch. AKA number of micro steps. + :param dist_opt: The fleet_opt configured by user. + :param nrank: Number of workers (can be got from fleet.worker_num()). + :return: + task_nodes (list): four task nodes for current rank + task_id_to_rank (dict): task nodes' ids to it's corresponding rank + """ + print("fleet executor will use python side 1f1b scheduler.") + coord_sys = CoordSys(dist_opt) + coord = coord_sys.rank_to_coord(cur_rank) + max_slot_times = int(max_run_times - coord['pp_idx']) + num_of_functionality = 4 + + def create_task_node(role, ops, offset, node_type): + task_id = int(cur_rank * num_of_functionality + offset) + print("Creating task node with role:", role, "and with id:", task_id) + node = core.TaskNode(role, ops, cur_rank, task_id, max_run_times, + max_slot_times) + node.set_type(node_type) + return node + + lr_ops, fwd_ops, bwd_ops, opt_ops = [], [], [], [] + for op in program.block(0).ops: + # split the program based on the op_role + op_role = int(op.all_attrs()[OP_ROLE_KEY]) + if is_lr_sched_op(op_role): + lr_ops.append(op.desc) + elif is_optimizer_op(op_role): + opt_ops.append(op.desc) + elif is_forward_op(op_role): + fwd_ops.append(op.desc) + elif is_backward_op(op_role): + bwd_ops.append(op.desc) + else: + raise "The op role: " + str( + op_role + ) + " isn't one of LRSched, Forward, Backward or Optimizer." + + # Create task nodes. + # The lr_sched and opt should be 'amplifier interceptor. + # The fwd and bwd should be 'compute interceptor'. + lr_task_node = create_task_node( + int(OpRole.Optimize.LRSched), lr_ops, 0, "Amplifier") + lr_task_node.set_run_pre_steps(max_run_times) + fwd_task_node = create_task_node(int(OpRole.Forward), fwd_ops, 1, "Compute") + bwd_task_node = create_task_node( + int(OpRole.Backward), bwd_ops, 2, "Compute") + opt_task_node = create_task_node( + int(OpRole.Optimize), opt_ops, 3, "Amplifier") + opt_task_node.set_run_pre_steps(max_run_times) + opt_task_node.set_run_at_offset(max_run_times - 1) + task_nodes = [lr_task_node, fwd_task_node, bwd_task_node, opt_task_node] + + # Generated the dependency based on this graph: + # lr(1:m) -> forward -> backward -> (m:1)optimize + # ↑ ↓ + # lr(1:m) -> forward -> backward -> (m:1)optimize + # ↑ ↓ + # lr(1:m) -> forward -> backward -> (m:1)optimize + upstream_coord, downstream_coord = coord.copy(), coord.copy() + upstream_coord['pp_idx'] = upstream_coord['pp_idx'] - 1 + downstream_coord['pp_idx'] = downstream_coord['pp_idx'] + 1 + pp_upstream = coord_sys.coord_to_rank(upstream_coord) + pp_downstream = coord_sys.coord_to_rank(downstream_coord) + first_stage = (pp_upstream == -1) + last_stage = (pp_downstream == -1) + for i in range(num_of_functionality): + task_node = task_nodes[i] + task_role = task_node.role() + cur_id = int(cur_rank * num_of_functionality + i) + prev_id = cur_id - 1 + next_id = cur_id + 1 + upstream_id = int(pp_upstream * num_of_functionality + i) + downstream_id = int(pp_downstream * num_of_functionality + i) + pp_buff_size = int(dist_opt['pp_degree'] - coord['pp_idx']) + ups = [] + downs = [] + if not is_lr_sched_op(task_role): + buf_size = pp_buff_size if is_backward_op(task_role) else 2 + ups.append((prev_id, buf_size)) + if not is_optimizer_op(task_role): + buf_size = pp_buff_size if is_forward_op(task_role) else 2 + downs.append((next_id, buf_size)) + if is_forward_op(task_role): + if not first_stage: + ups.append((upstream_id, 2)) + if not last_stage: + downs.append((downstream_id, 2)) + elif is_backward_op(task_role): + if not last_stage: + ups.append((downstream_id, 2)) + if not first_stage: + downs.append((upstream_id, 2)) + for up in ups: + print("Task:", cur_id, "'s upstream includes:", up[0]) + task_node.add_upstream_task(up[0], up[1]) + for down in downs: + print("Task:", cur_id, "'s downstream includes:", down[0]) + task_node.add_downstream_task(down[0], down[1]) + task_id_to_rank = {} + for i in range(nrank): + for j in range(num_of_functionality): + task_id_to_rank[int(i * num_of_functionality + j)] = i + return task_nodes, task_id_to_rank diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index eab00707f01..b00449a7475 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -1964,7 +1964,8 @@ class Executor(object): trainer_endpoints_str = os.getenv("PADDLE_TRAINER_ENDPOINTS", "") trainer_endpoints = trainer_endpoints_str.split(',') fleet_exe_desc = fleet_executor_desc_pb2.FleetExecutorDesc() - fleet_exe_desc.cur_rank = int(os.getenv("PADDLE_TRAINER_ID", 0)) + cur_rank = int(os.getenv("PADDLE_TRAINER_ID", 0)) + fleet_exe_desc.cur_rank = cur_rank nrank = len(trainer_endpoints) for rank, endpoint in enumerate(trainer_endpoints): rank_info = fleet_executor_desc_pb2.RankInfo() @@ -1979,8 +1980,23 @@ class Executor(object): fleet_exe_desc.num_micro_batches = fleet_opt["num_micro_batches"] num_of_gpu = fleet_exe_desc.dp_degree * fleet_exe_desc.mp_degree * fleet_exe_desc.pp_degree assert nrank == num_of_gpu, "The number of rank is not equal to the number of gpu." - task_id_to_rank = fleet_opt.get("task_id_to_rank", {}) - tasks = fleet_opt.get("tasks", []) + if 'python_side' in fleet_opt: + strategy = fleet_opt['python_side'] + if strategy == '1F1B': + from paddle.distributed.fleet.fleet_executor_utils import one_f_one_b + tasks, task_id_to_rank = one_f_one_b( + program, cur_rank, + fleet_opt.get('num_micro_batches', 1), + fleet_opt.get('dist_strategy', {}), nrank) + # NOTE: have to hold these vars, otherwise will be destructed + fleet_opt['tasks'] = tasks + fleet_opt['task_id_to_rank'] = task_id_to_rank + else: + raise "Fleet_executor only supports 1F1B scheduler if you choose python side split, " \ + "but received " + str(strategy) + "." + else: + task_id_to_rank = fleet_opt.get("task_id_to_rank", {}) + tasks = fleet_opt.get("tasks", []) fleet_exe = core.FleetExecutor(fleet_exe_desc.SerializeToString()) place = core.Place() place.set_place(self.place) diff --git a/python/paddle/fluid/tests/unittests/test_fleet_executor.py b/python/paddle/fluid/tests/unittests/test_fleet_executor.py index fbc5db341e5..30ba376dd4d 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_executor.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_executor.py @@ -33,7 +33,8 @@ class TestFleetExecutor(unittest.TestCase): strategy.pipeline_configs = {"accumulate_steps": 1} fleet_opt = { "dist_strategy": strategy.sharding_configs, - "num_micro_batches": strategy.pipeline_configs["accumulate_steps"] + "num_micro_batches": strategy.pipeline_configs["accumulate_steps"], + "python_side": "1F1B" } return fleet_opt -- GitLab