From 12339fa0b9914da7abbd83a3f68a5792b9af4792 Mon Sep 17 00:00:00 2001 From: Zeng Jinle <32832641+sneaxiy@users.noreply.github.com> Date: Mon, 15 Nov 2021 11:13:52 +0800 Subject: [PATCH] Add distributed pass framework: including PassBase/PassTest/PassUtils (#36643) * add split_program * make ut faster * increase ut timeout * make result deterministic * add fuse_all_reduce pass * add ut framework, update * fix ut framework * remove useless code * add coverage support * update * fix CI * fix some bugs and fix ci coverage * fix conflict --- paddle/fluid/framework/ir/graph_helper.cc | 102 +++++ paddle/fluid/framework/ir/graph_helper.h | 3 + paddle/fluid/framework/ir/node.h | 1 + paddle/fluid/operators/coalesce_tensor_op.cc | 113 ++++-- paddle/fluid/pybind/ir.cc | 7 + paddle/fluid/pybind/protobuf.cc | 6 +- .../paddle/distributed/fleet/launch_utils.py | 18 +- .../meta_optimizers/raw_program_optimizer.py | 14 +- python/paddle/distributed/passes/__init__.py | 23 ++ python/paddle/distributed/passes/cpp_pass.py | 25 ++ .../distributed/passes/fuse_all_reduce.py | 360 ++++++++++++++++++ python/paddle/distributed/passes/pass_base.py | 273 +++++++++++++ .../paddle/distributed/passes/pass_utils.py | 134 +++++++ .../fluid/tests/unittests/CMakeLists.txt | 3 + .../distributed_passes/CMakeLists.txt | 8 + .../distributed_passes/dist_pass_test_base.py | 218 +++++++++++ .../unittests/distributed_passes/launch.py | 22 ++ .../distributed_passes/pass_run_main.py | 75 ++++ .../test_dist_fuse_all_reduce_pass.py | 76 ++++ .../tests/unittests/test_split_program.py | 149 ++++++++ python/setup.py.in | 1 + 21 files changed, 1592 insertions(+), 39 deletions(-) create mode 100644 python/paddle/distributed/passes/__init__.py create mode 100644 python/paddle/distributed/passes/cpp_pass.py create mode 100644 python/paddle/distributed/passes/fuse_all_reduce.py create mode 100644 python/paddle/distributed/passes/pass_base.py create mode 100644 python/paddle/distributed/passes/pass_utils.py create mode 100644 python/paddle/fluid/tests/unittests/distributed_passes/CMakeLists.txt create mode 100644 python/paddle/fluid/tests/unittests/distributed_passes/dist_pass_test_base.py create mode 100644 python/paddle/fluid/tests/unittests/distributed_passes/launch.py create mode 100644 python/paddle/fluid/tests/unittests/distributed_passes/pass_run_main.py create mode 100644 python/paddle/fluid/tests/unittests/distributed_passes/test_dist_fuse_all_reduce_pass.py create mode 100644 python/paddle/fluid/tests/unittests/test_split_program.py diff --git a/paddle/fluid/framework/ir/graph_helper.cc b/paddle/fluid/framework/ir/graph_helper.cc index 5f7bfc61b4..b2ab6bed36 100644 --- a/paddle/fluid/framework/ir/graph_helper.cc +++ b/paddle/fluid/framework/ir/graph_helper.cc @@ -545,6 +545,108 @@ void GraphToProgram(const Graph &graph, ProgramDesc *program, program->CopyFrom(program_pb); } +static std::vector> GetOpDependencies( + const BlockDesc &block, const std::unordered_set &nodes) { + auto block_ops = block.AllOps(); + size_t op_num = block_ops.size(); + std::unordered_map> + preceding_ops(op_num); + std::unordered_map preceding_deps(op_num); + std::unordered_map> + pending_ops(op_num); + + std::queue ready_ops; + for (const auto *node : nodes) { + if (!node->IsOp()) continue; + + auto &tmp_preceding_ops = preceding_ops[node]; + for (const auto *in_var : node->inputs) { + for (const auto *in_op : in_var->inputs) { + tmp_preceding_ops.insert(in_op); + } + } + if (tmp_preceding_ops.empty()) { + ready_ops.push(node); + } + preceding_deps[node] = 
tmp_preceding_ops.size(); + + auto &tmp_pending_ops = pending_ops[node]; + for (const auto *out_var : node->outputs) { + for (const auto *out_op : out_var->outputs) { + tmp_pending_ops.insert(out_op); + } + } + } + + std::unordered_map> + all_preceding_ops; + while (!ready_ops.empty()) { + const auto *cur_op = ready_ops.front(); + ready_ops.pop(); + + auto &all_preceding_ops_of_cur_op = all_preceding_ops[cur_op]; + for (const auto *preceding_op : preceding_ops.at(cur_op)) { + all_preceding_ops_of_cur_op.insert(preceding_op); + auto &prev_preceding_ops = all_preceding_ops[preceding_op]; + all_preceding_ops_of_cur_op.insert(prev_preceding_ops.begin(), + prev_preceding_ops.end()); + } + + for (const auto *pending_op : pending_ops.at(cur_op)) { + if (--preceding_deps.at(pending_op) == 0) { + ready_ops.push(pending_op); + } + } + } + + std::unordered_map op_id_to_idx(op_num); + for (const auto *op_desc : block_ops) { + size_t op_idx = op_id_to_idx.size(); + PADDLE_ENFORCE_EQ( + op_id_to_idx.emplace(op_desc->Id(), op_idx).second, true, + platform::errors::InvalidArgument( + "There should not be duplicate op id: %d", op_desc->Id())); + } + + std::vector> dep_matrix(op_num); + for (size_t i = 0; i < op_num; ++i) { + dep_matrix[i].resize(op_num, ir::Node::Dep::kNoDep); + dep_matrix[i][i] = ir::Node::Dep::kSame; + } + + auto get_op_idx_by_id = [&op_id_to_idx](uint64_t op_id) { + auto iter = op_id_to_idx.find(op_id); + PADDLE_ENFORCE_NE(iter, op_id_to_idx.end(), + platform::errors::InvalidArgument( + "Cannot find OpDesc with id %d", op_id)); + return iter->second; + }; + + for (const auto &pair : all_preceding_ops) { + const auto *cur_op_node = pair.first; + size_t op_idx_1 = get_op_idx_by_id(cur_op_node->Op()->Id()); + for (const auto *preceding_op_node : pair.second) { + size_t op_idx_2 = get_op_idx_by_id(preceding_op_node->Op()->Id()); + dep_matrix[op_idx_1][op_idx_2] = ir::Node::Dep::kAfter; + dep_matrix[op_idx_2][op_idx_1] = ir::Node::Dep::kBefore; + } + } + return dep_matrix; +} + +std::vector>> GetOpDependencies( + const ProgramDesc &program) { + ir::Graph graph(program); + size_t block_num = program.Size(); + std::vector>> deps; + deps.reserve(block_num); + for (size_t i = 0; i < block_num; ++i) { + deps.emplace_back( + GetOpDependencies(program.Block(i), graph.GetSubGraph(i)->Nodes())); + } + return deps; +} + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/graph_helper.h b/paddle/fluid/framework/ir/graph_helper.h index f00e3ae37b..7affdab951 100644 --- a/paddle/fluid/framework/ir/graph_helper.h +++ b/paddle/fluid/framework/ir/graph_helper.h @@ -124,6 +124,9 @@ std::vector TopologySortGraphByDescOrder(const Graph &graph); void GraphToProgram(const Graph &graph, ProgramDesc *p_program, const SortKind *sort_kind = nullptr); +std::vector>> GetOpDependencies( + const ProgramDesc &program); + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/node.h b/paddle/fluid/framework/ir/node.h index 54bd4376c6..f4cca78b6d 100644 --- a/paddle/fluid/framework/ir/node.h +++ b/paddle/fluid/framework/ir/node.h @@ -63,6 +63,7 @@ class Node { } enum class Type { kOperation, kVariable }; + enum class Dep { kSame = 0, kBefore = 1, kAfter = 2, kNoDep = 3 }; #if !defined(_WIN32) // msvc not support constexpr correctly. 
static constexpr char kControlDepVarName[] = "__control_var"; #else diff --git a/paddle/fluid/operators/coalesce_tensor_op.cc b/paddle/fluid/operators/coalesce_tensor_op.cc index d2addb32bc..8d52cb9a08 100644 --- a/paddle/fluid/operators/coalesce_tensor_op.cc +++ b/paddle/fluid/operators/coalesce_tensor_op.cc @@ -87,8 +87,8 @@ class CoalesceTensorOpKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext &context) const override { auto in_var_names = context.InputNames("Input"); auto out_var_names = context.OutputNames("Output"); - auto &in_vars = context.MultiInputVar("Input"); - auto out_vars = context.MultiOutputVar("Output"); + const auto &in_tensors = context.MultiInput("Input"); + auto out_tensors = context.MultiOutput("Output"); PADDLE_ENFORCE_GT(in_var_names.size(), static_cast(0), platform::errors::InvalidArgument( @@ -101,30 +101,61 @@ class CoalesceTensorOpKernel : public framework::OpKernel { in_var_names.size(), out_var_names.size())); // Input & Output check: only support LoDTensor - for (size_t i = 0; i < in_var_names.size(); ++i) { + bool has_not_init_in_vars = false; + for (size_t i = 0; i < in_tensors.size(); ++i) { PADDLE_ENFORCE_NOT_NULL( - in_vars[i], - platform::errors::NotFound("The input variable %s of CoalesceTensor " - "operator does not exist.", - in_var_names[i])); + in_tensors[i], platform::errors::InvalidArgument( + "The %d-th input tensor cannot be nullptr.", i)); PADDLE_ENFORCE_NOT_NULL( - out_vars[i], - platform::errors::NotFound("The output variable %s of CoalesceTensor " - "operator does not exist.", - out_var_names[i])); - PADDLE_ENFORCE_EQ(in_vars[i]->IsType(), true, + out_tensors[i], platform::errors::InvalidArgument( + "The %d-th output tensor cannot be nullptr.", i)); + if (!in_tensors[i]->IsInitialized()) { + has_not_init_in_vars = true; + } + } + + if (has_not_init_in_vars) { + const auto &concated_shapes = + context.Attr>("concated_shapes"); + const auto &concated_ranks = + context.Attr>("concated_ranks"); + PADDLE_ENFORCE_EQ(concated_ranks.size(), out_tensors.size(), platform::errors::InvalidArgument( - "The input variable %s of CoalesceTensor operator " - "is not LoDTensor.", - in_var_names[i])); - PADDLE_ENFORCE_EQ(out_vars[i]->IsType(), true, + "The attribute(concated_ranks) length must be " + "equal to the output tensor number.")); + int64_t accumulated_ranks = 0; + for (size_t i = 0; i < in_tensors.size(); ++i) { + framework::DDim dims(concated_shapes.data() + accumulated_ranks, + concated_ranks[i]); + if (!in_tensors[i]->IsInitialized()) { + PADDLE_ENFORCE_EQ( + in_tensors[i], out_tensors[i], + platform::errors::InvalidArgument( + "The %d-th output tensor and %d-th input tensor when the " + "%d-th input tensor is not initialized.", + i, i, i)); + out_tensors[i]->Resize(dims); + } else { + PADDLE_ENFORCE_EQ( + in_tensors[i]->dims(), dims, + platform::errors::InvalidArgument( + "The %d-th input tensor shape does not match the " + "attribute(concated_shapes) and " + "attribute(concated_ranks).", + i)); + } + accumulated_ranks += concated_ranks[i]; + PADDLE_ENFORCE_LE(accumulated_ranks, concated_shapes.size(), + platform::errors::InvalidArgument( + "The attribute(concated_shapes) and " + "attribute(concated_ranks) do not match.")); + } + PADDLE_ENFORCE_EQ(accumulated_ranks, concated_shapes.size(), platform::errors::InvalidArgument( - "The output variable %s of CoalesceTensor operator " - "is not LoDTensor.", - out_var_names[i])); + "The attribute(concated_shapes) and " + "attribute(concated_ranks) do not match.")); } 
- auto in_tensors = context.MultiInput("Input"); bool use_align = context.Attr("use_align"); auto align_size = context.Attr("align_size"); auto size_of_dtype = context.Attr("user_defined_size_of_dtype"); @@ -141,8 +172,7 @@ class CoalesceTensorOpKernel : public framework::OpKernel { } else { // Init the output as input for (size_t i = 0; i < in_tensors.size(); ++i) { - out_vars[i]->GetMutable()->Resize( - in_tensors[i]->dims()); + out_tensors[i]->Resize(in_tensors[i]->dims()); } } @@ -160,11 +190,13 @@ class CoalesceTensorOpKernel : public framework::OpKernel { // Alloc the continuous space auto fused_tensor = context.Output("FusedOutput"); - fused_tensor->Resize(framework::make_ddim({static_cast(numel)})) - .mutable_data(context.GetPlace(), dtype); + void *fused_tensor_ptr = + fused_tensor + ->Resize(framework::make_ddim({static_cast(numel)})) + .mutable_data(context.GetPlace(), dtype); + VLOG(10) << "Fused tensor addr " << fused_tensor_ptr; // Init the continuous space - auto out_tensors = context.MultiOutput("Output"); size_t offset = 0; if (context.Attr("copy_data")) { #ifdef PADDLE_WITH_ASCEND_CL @@ -257,10 +289,6 @@ class CoalesceTensorOpKernel : public framework::OpKernel { std::stringstream ss; ss << "alloc_space_for_vars: "; for (size_t i = 0; i < var_names.size(); ++i) { - PADDLE_ENFORCE_EQ(lod_tensors[i]->IsInitialized(), true, - platform::errors::InvalidArgument( - "Tensor `%s` is not initialized.", var_names[i])); - auto size = lod_tensors[i]->numel(); PADDLE_ENFORCE_GT( size, 0, @@ -272,11 +300,13 @@ class CoalesceTensorOpKernel : public framework::OpKernel { place, align_size) / size_of_dtype : static_cast(size); + const void *ptr = lod_tensors[i]->IsInitialized() + ? lod_tensors[i]->data() + : nullptr; VLOG(4) << size << " " << len; ss << "input(" << var_names[i] << ") dim:(" << lod_tensors[i]->dims() << ") " - << " addres:" << lod_tensors[i]->data() << " len: " << len - << ", "; + << " addres:" << ptr << " len: " << len << ", "; *numel += len; } VLOG(10) << ss.str(); @@ -328,6 +358,13 @@ class CoalesceTensorOp : public framework::OperatorWithKernel { } protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &context) const override { + auto dtype = static_cast( + context.Attr("dtype")); + return framework::OpKernelType(dtype, context.GetPlace()); + } + framework::OpKernelType GetKernelTypeForVar( const std::string &var_name, const framework::Tensor &tensor, const framework::OpKernelType &expected_kernel_type) const override { @@ -386,6 +423,20 @@ class CoalesceTensorOpMaker : public framework::OpProtoAndCheckerMaker { "make sure the shape of these two vars are identical with " "each other, this attr is added.") .SetDefault(-1); + AddAttr>( + "concated_shapes", + "The concated shapes of each shape of the input tensors. " + "If any of the input tensors are not inited, this is used to " + "init the output tensor shape, together with " + "attribute(concated_ranks).") + .SetDefault({}); + AddAttr>( + "concated_ranks", + "The concated ranks of each rank of the input tensors. " + "If any of the input tensors are not inited, this is used to " + "init the output tensor shape, together with " + "attribute(concated_shapes).") + .SetDefault({}); AddComment(R"DOC( CoalesceTensor Operator. 
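The coalesce_tensor changes above add two attributes, concated_shapes and concated_ranks, which flatten the per-output shapes into one dims list plus a per-tensor rank list so the kernel can resize outputs whose inputs are not yet initialized. A minimal sketch of the encoding that the new fuse_all_reduce pass (added later in this patch) builds and of the decoding the kernel performs with a running rank offset; the helper names here are illustrative and not part of the patch:

# Sketch: how concated_shapes / concated_ranks encode a list of shapes.
def encode_shapes(shapes):
    concated_shapes, concated_ranks = [], []
    for shape in shapes:
        concated_shapes.extend(shape)       # all dims, flattened in order
        concated_ranks.append(len(shape))   # rank of each tensor
    return concated_shapes, concated_ranks

def decode_shapes(concated_shapes, concated_ranks):
    shapes, offset = [], 0
    for rank in concated_ranks:
        shapes.append(list(concated_shapes[offset:offset + rank]))
        offset += rank
    assert offset == len(concated_shapes), "shapes and ranks do not match"
    return shapes

# encode_shapes([[2, 3], [4], [5, 6, 7]]) -> ([2, 3, 4, 5, 6, 7], [2, 1, 3])
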
diff --git a/paddle/fluid/pybind/ir.cc b/paddle/fluid/pybind/ir.cc index 050bfc967d..f2fb4671df 100644 --- a/paddle/fluid/pybind/ir.cc +++ b/paddle/fluid/pybind/ir.cc @@ -193,6 +193,13 @@ void BindNode(py::module *m) { .value("Operation", Node::Type::kOperation) .value("Variable", Node::Type::kVariable) .export_values(); + + py::enum_(node, "Dep") + .value("Same", Node::Dep::kSame) + .value("Before", Node::Dep::kBefore) + .value("After", Node::Dep::kAfter) + .value("NoDep", Node::Dep::kNoDep) + .export_values(); } class PYBIND11_HIDDEN PassAttrGetterSetterRegistry { diff --git a/paddle/fluid/pybind/protobuf.cc b/paddle/fluid/pybind/protobuf.cc index 984f3d1a31..9e5e391920 100644 --- a/paddle/fluid/pybind/protobuf.cc +++ b/paddle/fluid/pybind/protobuf.cc @@ -19,6 +19,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/block_desc.h" +#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/process_mesh_desc.h" #include "paddle/fluid/framework/program_desc.h" @@ -81,7 +82,10 @@ void BindProgramDesc(pybind11::module *m) { }, pybind11::arg("version") = pd::kCurProgramVersion) .def("_version", - [](pd::ProgramDesc &self) -> int64_t { return self.Version(); }); + [](pd::ProgramDesc &self) -> int64_t { return self.Version(); }) + .def("get_op_deps", [](const framework::ProgramDesc &program) { + return framework::ir::GetOpDependencies(program); + }); } void BindProcessMeshDesc(pybind11::module *m) { diff --git a/python/paddle/distributed/fleet/launch_utils.py b/python/paddle/distributed/fleet/launch_utils.py index c44352303b..d1f4442ee6 100644 --- a/python/paddle/distributed/fleet/launch_utils.py +++ b/python/paddle/distributed/fleet/launch_utils.py @@ -465,6 +465,18 @@ class TrainerProc(object): self.cmd = None +_run_with_coverage = False + + +def run_with_coverage(*args): + global _run_with_coverage + assert len(args) <= 1, "len(args) {} should <= 1".format(len(args)) + if len(args) == 1: + assert isinstance(args[0], bool) + _run_with_coverage = args[0] + return _run_with_coverage + + def start_local_trainers(cluster, pod, training_script, @@ -518,7 +530,11 @@ def start_local_trainers(cluster, current_env.update(proc_env) - cmd = [sys.executable, "-u", training_script] + training_script_args + coverage_args = [] + if run_with_coverage(): + coverage_args = ["-m", "coverage", "run", "--branch", "-p"] + cmd = [sys.executable, "-u"] + coverage_args + [training_script + ] + training_script_args logger.debug("start trainer proc{} env:{}".format(cmd, current_env)) diff --git a/python/paddle/distributed/fleet/meta_optimizers/raw_program_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/raw_program_optimizer.py index c8eaa54f9c..d056d4e106 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/raw_program_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/raw_program_optimizer.py @@ -216,8 +216,8 @@ class RawProgramOptimizer(MetaOptimizerBase): gm_block._insert_op( first_optimize_op_idx + insert_op_num, type="c_sync_comm_stream", - inputs={'X': grad_vars[-1]}, - outputs={'Out': grad_vars[-1]}, + inputs={'X': grad_vars}, + outputs={'Out': grad_vars}, attrs={ 'ring_id': ring_id, OP_ROLE_KEY: OpRole.Backward, @@ -259,6 +259,7 @@ class RawProgramOptimizer(MetaOptimizerBase): block = self.main_program.global_block() ring_id = self.global_ring_id grad = None + grad_vars = [] for idx, op in reversed(list(enumerate(block.ops))): if is_backward_op(op) and \ OP_ROLE_VAR_KEY in 
op.attr_names: @@ -275,6 +276,7 @@ class RawProgramOptimizer(MetaOptimizerBase): if param.is_distributed: continue + grad_vars.append(grad) block._insert_op( idx + offset, type='c_sync_calc_stream', @@ -300,8 +302,8 @@ class RawProgramOptimizer(MetaOptimizerBase): block._insert_op( idx, type='c_sync_comm_stream', - inputs={'X': grad}, - outputs={'Out': grad}, + inputs={'X': grad_vars}, + outputs={'Out': grad_vars}, attrs={'ring_id': ring_id, OP_ROLE_KEY: OpRole.Backward}) break @@ -441,8 +443,8 @@ class RawProgramOptimizer(MetaOptimizerBase): block._insert_op_without_sync( idx, type='c_sync_comm_stream', - inputs={'X': grad_segment[0]}, - outputs={'Out': grad_segment[0]}, + inputs={'X': fused_vars}, + outputs={'Out': fused_vars}, attrs={'ring_id': ring_id, OP_ROLE_KEY: OpRole.Backward}) break diff --git a/python/paddle/distributed/passes/__init__.py b/python/paddle/distributed/passes/__init__.py new file mode 100644 index 0000000000..55c90abf14 --- /dev/null +++ b/python/paddle/distributed/passes/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .pass_base import new_pass, PassManager, PassContext +from .fuse_all_reduce import * +from .cpp_pass import * + +__all__ = [ + 'new_pass', + 'PassManager', + 'PassContext', +] diff --git a/python/paddle/distributed/passes/cpp_pass.py b/python/paddle/distributed/passes/cpp_pass.py new file mode 100644 index 0000000000..5dd50b9534 --- /dev/null +++ b/python/paddle/distributed/passes/cpp_pass.py @@ -0,0 +1,25 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .pass_base import CPPPassWrapper, register_pass + + +@register_pass("fuse_elewise_add_act") +class FuseElementwiseAddActPass(CPPPassWrapper): + def __init__(self): + super(FuseElementwiseAddActPass, self).__init__() + + @property + def cpp_name(self): + return "fuse_elewise_add_act_pass" diff --git a/python/paddle/distributed/passes/fuse_all_reduce.py b/python/paddle/distributed/passes/fuse_all_reduce.py new file mode 100644 index 0000000000..101f0c3dc3 --- /dev/null +++ b/python/paddle/distributed/passes/fuse_all_reduce.py @@ -0,0 +1,360 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.framework import core +from paddle.fluid import unique_name +from .pass_base import CommOptPass, register_pass +from collections import OrderedDict +import numpy as np + + +def find_adjacent_match_sequences(iterable, + filter_func, + adjacent_filter_func=None): + n = len(iterable) + match_sequences = [] + if adjacent_filter_func is None: + adjacent_filter_func = lambda ref_op, new_op: True + i = 0 + while True: + while i < n and not filter_func(iterable[i]): + i += 1 + j = i + 1 + while j < n and filter_func(iterable[j]) and adjacent_filter_func( + iterable[i], iterable[j]): + j += 1 + if i < n and j <= n: + match_sequences.append((i, j)) + i = j + 1 + if i >= n: + break + return match_sequences + + +def insert_fuse_all_reduce_ops(block, reversed_op_indices, input_var_names, + output_var_names, dtype, attrs): + fused_var = block.create_var( + name=unique_name.generate("FusedOutput_{}".format(input_var_names[0])), + dtype=dtype) + + # FIXME(zengjinle): here we assume that we use + # c_sync_calc_stream/c_sync_comm_stream to do sync. + # But someone may use c_wait_compute/c_wait_comm instead. + if not attrs["use_calc_stream"]: + ring_id = attrs["ring_id"] + new_op_indices = list(reversed_op_indices) + + for i, op_idx in enumerate(reversed_op_indices): + prev_op_idx = op_idx - 1 + while prev_op_idx >= 0 and block.ops[ + prev_op_idx].type == "c_sync_calc_stream": + new_op_indices.append(prev_op_idx) + prev_op_idx -= 1 + + if i > 0: + next_op_idx = op_idx + 1 + n = len(block.ops) + while next_op_idx < n and block.ops[ + next_op_idx].type == "c_sync_comm_stream": + assert block.ops[next_op_idx].attr("ring_id") == ring_id + new_op_indices.append(next_op_idx) + + new_op_indices = list(set(new_op_indices)) + new_op_indices.sort(reverse=True) + reversed_op_indices = new_op_indices + + insert_idx = reversed_op_indices[0] + 1 + op_role_key = core.op_proto_and_checker_maker.kOpRoleAttrName() + + concated_shapes = [] + concated_ranks = [] + for var_name in output_var_names: + shape = block._find_var_recursive(var_name).shape + concated_shapes.extend(shape) + concated_ranks.append(len(shape)) + + coalesce_tensor_op_kwargs = { + "type": "coalesce_tensor", + "inputs": { + "Input": input_var_names, + }, + "outputs": { + "Output": output_var_names, + "FusedOutput": fused_var, + }, + "attrs": { + "use_align": True, + "dtype": dtype, + "concated_shapes": concated_shapes, + "concated_ranks": concated_ranks, + op_role_key: attrs[op_role_key], + }, + } + + if not attrs["use_calc_stream"]: + block._insert_op_without_sync( + insert_idx, + type="c_sync_calc_stream", + inputs={"X": fused_var}, + outputs={"Out": fused_var, + op_role_key: attrs[op_role_key]}) + insert_idx += 1 + + # c_allreduce_sum should insert + block._insert_op_without_sync( + insert_idx, + type="c_allreduce_sum", + inputs={"X": fused_var}, + outputs={"Out": fused_var}, + attrs=attrs) + + for op_idx in reversed_op_indices: + block._remove_op(op_idx) + + return coalesce_tensor_op_kwargs + + +def has_same_attrs(op1, op2, attr_names): + for attr_name in attr_names: + if op1.attr(attr_name) != op2.attr(attr_name): + return False + 
return True + + +def filter_all_collective_op_indices(block): + # NOTE: should add more collective ops + all_collective_ops = { + "c_allreduce_sum", + "c_allreduce_prod", + "c_allreduce_max", + "c_allreduce_min", + "c_allgather", + "c_broadcast", + } + + match_op_indices = [] + for i, op in enumerate(block.ops): + if op.type in all_collective_ops: + match_op_indices.append(i) + return match_op_indices + + +def find_all_fuse_all_reduce_groups(block): + collective_op_indices = filter_all_collective_op_indices(block) + collective_ops = [block.ops[i] for i in collective_op_indices] + + def is_valid_allreduce_op(op): + if op.type != "c_allreduce_sum" or op.attr("use_model_parallel"): + return False + in_var_name = op.input("X")[0] + out_var_name = op.output("Out")[0] + if in_var_name != out_var_name: + return False + in_var = block._find_var_recursive(in_var_name) + assert in_var is not None + if in_var.type != core.VarDesc.VarType.LOD_TENSOR: + return False + shape = in_var.shape + if any([s <= 0 for s in shape]): + return False + return True + + same_attr_names = [ + "ring_id", + "use_calc_stream", + core.op_proto_and_checker_maker.kOpRoleAttrName(), + core.op_proto_and_checker_maker.kOpDeviceAttrName(), + ] + + def is_same_adjacent_op(ref_op, new_op): + if not has_same_attrs(ref_op, new_op, same_attr_names): + return False + ref_op_in_var = block._find_var_recursive(ref_op.input("X")[0]) + new_op_in_var = block._find_var_recursive(new_op.input("X")[0]) + if ref_op_in_var.dtype != new_op_in_var.dtype: + return False + return True + + match_seqs = find_adjacent_match_sequences( + collective_ops, is_valid_allreduce_op, is_same_adjacent_op) + new_match_seqs = [] + for i, j in match_seqs: + new_match_seqs.append([collective_op_indices[k] for k in range(i, j)]) + return new_match_seqs + + +def split_fuse_all_reduce_groups_by_deps(block, groups, op_deps): + new_groups = [] + + def insert_new_group(op_indices, start_idx, end_idx): + if end_idx - start_idx > 1: + new_groups.append(op_indices[start_idx:end_idx]) + + for op_indices in groups: + n = len(op_indices) + assert n > 0 + if n == 1: + continue + + start_idx = 0 + k = start_idx + 1 + while k < n: + found_group = False + for prev_idx in range(start_idx, k): + dep = op_deps[op_indices[prev_idx]][op_indices[k]] + if dep == core.Node.Dep.NoDep: + continue + # [start_idx, k) is valid groups + insert_new_group(op_indices, start_idx, k) + start_idx = k + break + k += 1 + + insert_new_group(op_indices, start_idx, k) + + return new_groups + + +def insert_coalesce_tensor_ops(block, coalesce_ops_kwargs): + if not coalesce_ops_kwargs: + return + + var_infos = {} + for idx, op in enumerate(block.ops): + for var in op.input_arg_names: + if var not in var_infos: + var_infos[var] = [idx, True] + + for var in op.output_arg_names: + if var not in var_infos: + var_infos[var] = [idx, False] + + n = len(block.ops) + insert_idx_and_kwargs = [] + for group_idx, kwargs in enumerate(coalesce_ops_kwargs): + all_vars = kwargs["inputs"]["Input"] + kwargs["outputs"]["Output"] + min_op_idx = n + copy_data = False + for var in all_vars: + if var not in var_infos: + copy_data = True + min_idx = 0 + break + op_idx, is_input = var_infos[var] + if is_input: + copy_data = True + min_op_idx = min(min_op_idx, op_idx) + kwargs["attrs"]["copy_data"] = copy_data + insert_idx_and_kwargs.append((min_op_idx, kwargs)) + + insert_idx_and_kwargs.sort(key=lambda element: element[0], reverse=True) + for idx, kwargs in insert_idx_and_kwargs: + block._insert_op_without_sync(idx, **kwargs) + + 
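# Illustrative sketch, not part of the patch: the index semantics of
# find_adjacent_match_sequences defined earlier in this file. It returns
# half-open (start, end) ranges covering maximal runs whose elements pass
# filter_func and also pass adjacent_filter_func against the head of the
# run; the allreduce grouping above maps those ranges back to op indices.
# Assuming the patch is applied, it can be exercised on a plain list:
from paddle.distributed.passes.fuse_all_reduce import find_adjacent_match_sequences

is_even = lambda x: x % 2 == 0
# Maximal runs of adjacent even numbers sit at indices [1, 3) and [4, 6).
assert find_adjacent_match_sequences(
    [1, 4, 6, 3, 8, 10], is_even) == [(1, 3), (4, 6)]
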
+def insert_fuse_all_reduce_by_memory_size(block, groups, max_memory_size): + op_role_key = core.op_proto_and_checker_maker.kOpRoleAttrName() + op_role_var_key = core.op_proto_and_checker_maker.kOpRoleVarAttrName() + op_device_key = core.op_proto_and_checker_maker.kOpDeviceAttrName() + coalesce_ops_kwargs = [] + for group in reversed(groups): + first_op = block.ops[group[0]] + ring_id = first_op.attr("ring_id") + use_calc_stream = first_op.attr("use_calc_stream") + use_model_parallel = first_op.attr("use_model_parallel") + op_role = first_op.attr(op_role_key) + op_device = first_op.attr(op_device_key) + + attrs = { + "ring_id": ring_id, + "use_calc_stream": use_calc_stream, + "use_model_parallel": use_model_parallel, + op_role_key: op_role, + op_device_key: op_device, + } + dtype = block._find_var_recursive(first_op.input("X")[0]).dtype + sizeof = core.size_of_dtype(dtype) + + cur_mem_size = 0 + op_role_vars = [] + recorded_op_indices = [] + in_var_names = [] + out_var_names = [] + for op_idx in reversed(group): + op = block.ops[op_idx] + in_var_name = op.input("X")[0] + out_var_name = op.output("Out")[0] + in_var = block._find_var_recursive(in_var_name) + mem_size = int(np.prod(in_var.shape)) * sizeof + if cur_mem_size + mem_size > max_memory_size: + if len(recorded_op_indices) > 1: + attrs[op_role_var_key] = op_role_vars + coalesce_op_kwargs = insert_fuse_all_reduce_ops( + block, recorded_op_indices, in_var_names, out_var_names, + dtype, attrs) + coalesce_ops_kwargs.append(coalesce_op_kwargs) + + cur_mem_size = 0 + op_role_vars = [] + recorded_op_indices = [] + in_var_names = [] + out_var_names = [] + + cur_mem_size += mem_size + recorded_op_indices.append(op_idx) + in_var_names.append(in_var_name) + out_var_names.append(out_var_name) + if op.has_attr(op_role_var_key): + op_role_vars.extend(op.attr(op_role_var_key)) + + if len(recorded_op_indices) > 1: + attrs[op_role_var_key] = op_role_vars + coalesce_op_kwargs = insert_fuse_all_reduce_ops( + block, recorded_op_indices, in_var_names, out_var_names, dtype, + attrs) + coalesce_ops_kwargs.append(coalesce_op_kwargs) + block._sync_with_cpp() + insert_coalesce_tensor_ops(block, coalesce_ops_kwargs) + + +@register_pass("fuse_all_reduce") +class FuseAllReducePass(CommOptPass): + def __init__(self): + super(FuseAllReducePass, self).__init__() + self.set_attr("max_memory_size", -1) + + def _check_self(self): + max_memory_size = self.get_attr("max_memory_size") + return max_memory_size > 0 + + def _check_conflict(self, other_pass): + return True + + # NOTE: why FuseAllReducePass can override apply_single_impl instead of + # apply_impl? AllReduce is a collective operation, so the program of each + # rank inside the same communication group should have the same + # c_allreduce_sum operations. Therefore, FuseAllReducePass can override + # apply_single_impl directly. 
+ def _apply_single_impl(self, main_program, startup_program, context): + max_memory_size = self.get_attr("max_memory_size") + op_deps = main_program.desc.get_op_deps() + num_blocks = main_program.num_blocks + for i in range(num_blocks): + block = main_program.block(i) + groups = find_all_fuse_all_reduce_groups(block) + groups = split_fuse_all_reduce_groups_by_deps(block, groups, + op_deps[i]) + insert_fuse_all_reduce_by_memory_size(block, groups, + max_memory_size) + main_program._sync_with_cpp() diff --git a/python/paddle/distributed/passes/pass_base.py b/python/paddle/distributed/passes/pass_base.py new file mode 100644 index 0000000000..4d4585ca6e --- /dev/null +++ b/python/paddle/distributed/passes/pass_base.py @@ -0,0 +1,273 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import six +import sys +from abc import ABC, abstractmethod +from paddle.fluid.framework import program_guard, _apply_pass as _apply_cpp_pass + + +class PassContext: + def __init__(self): + self._applied_passes = [] + self._attrs = [] + + def set_attr(self, key, value): + self._attrs[key] = value + + def get_attr(self, key, default=None): + return self._attrs.get(key, default) + + @property + def passes(self): + return self._applied_passes + + def _add_pass(self, pass_obj): + self._applied_passes.append(pass_obj) + + def _pop_pass(self): + del self._applied_passes[-1] + + +class PassBase(ABC): + _REGISTERED_PASSES = {} + _COMMON_RULES = [] + + @staticmethod + def _register(pass_name, pass_class): + assert issubclass(pass_class, PassBase) + PassBase._REGISTERED_PASSES[pass_name] = pass_class + + def __init__(self): + self._attrs = {} + + def set_attr(self, key, value): + self._attrs[key] = value + return self + + def get_attr(self, key, default=None): + return self._attrs.get(key, default) + + @abstractmethod + def _check_self(self): + pass + + @abstractmethod + def _check_conflict(self, other_pass): + pass + + def _check_conflict_including_common_rules(self, other_pass): + return self._check_conflict(other_pass) and all( + [r(other_pass, self) for r in PassBase._COMMON_RULES]) + + def apply(self, main_programs, startup_programs, context=None): + if context is None: + context = PassContext() + + if not self._check_self(): + return context + + if not all([ + self._check_conflict_including_common_rules(p) + for p in context.passes + ]): + return context + + assert isinstance(main_programs, list) + assert isinstance(startup_programs, list) + assert len(main_programs) == len(startup_programs) + self._apply_impl(main_programs, startup_programs, context) + context._add_pass(self) + return context + + def _apply_impl(self, main_programs, startup_programs, context): + for main_program, startup_program in zip(main_programs, + startup_programs): + self._apply_single_impl(main_program, startup_program, context) + + @abstractmethod + def _apply_single_impl(self, main_program, startup_program, context): + pass + + +def register_pass(name): + def impl(cls): + 
PassBase._register(name, cls) + cls.name = name + return cls + + return impl + + +def new_pass(name, pass_attrs={}): + pass_class = PassBase._REGISTERED_PASSES.get(name) + assert pass_class is not None, "Pass {} is not registered".format(name) + pass_obj = pass_class() + for k, v in pass_attrs.items(): + pass_obj.set_attr(k, v) + return pass_obj + + +class CPPPassWrapper(PassBase): + def __init__(self): + super(CPPPassWrapper, self).__init__() + + @property + def cpp_name(self): + raise NotImplementedError() + + @property + def cpp_attr_types(self): + return {} + + def _check_self(self): + return True + + def _check_conflict(self, other_pass): + return True + + def _apply_single_impl(self, main_program, startup_program, context): + _apply_cpp_pass(main_program, startup_program, self.cpp_name, + self._attrs, self.cpp_attr_types) + + +# Like AutoParallel/HybridParallel, etc. +class ParallelOptPass(PassBase): + def __init__(self): + super(ParallelOptPass, self).__init__() + + +# Like AMP, Recompute, etc. +class CalcOptPass(PassBase): + def __init__(self): + super(CalcOptPass, self).__init__() + + +# Like FuseAllReduce, FuseGradientMerge, etc. +class CommOptPass(PassBase): + def __init__(self): + super(CommOptPass, self).__init__() + + +def _make_pass_order_rule(pass_class_before, pass_class_after): + def impl(pass_obj_before, pass_obj_after): + if isinstance(pass_obj_before, pass_class_after) \ + and isinstance(pass_obj_after, pass_class_before): + return False + return True + + return impl + + +PassBase._COMMON_RULES = [ + _make_pass_order_rule(CalcOptPass, CommOptPass), + _make_pass_order_rule(ParallelOptPass, CPPPassWrapper), + _make_pass_order_rule(CalcOptPass, CPPPassWrapper), + _make_pass_order_rule(CommOptPass, CPPPassWrapper), + lambda pass_before, pass_after: type(pass_before) != type(pass_after), +] + + +def _find_longest_path(edges): + n = len(edges) + paths = [None] * n + dists = [None] * n + + min_path = [] + min_dist = 0 + for i in range(n): + paths[i] = [None] * n + dists[i] = [None] * n + for j in range(n): + assert isinstance(edges[i][j], bool) + if not edges[i][j]: + dists[i][j] = n # inf + paths[i][j] = [] + else: + assert edges[i][j] is True + dists[i][j] = -1 + paths[i][j] = [i, j] + if dists[i][j] < min_dist: + min_dist = -1 + min_path = paths[i][j] + + for k in range(n): + for i in range(n): + for j in range(n): + if dists[i][j] > dists[i][k] + dists[k][j]: + dists[i][j] = dists[i][k] + dists[k][j] + paths[i][j] = paths[i][k] + paths[k][j] + if dists[i][j] < min_dist: + min_dist = dists[i][j] + min_path = paths[i][j] + + return min_path if min_path else [0] + + +def _solve_pass_conflict(passes, context): + passes = [p for p in passes if p._check_self()] + if not passes: + return [] + + old_passes = passes + passes = [] + for p in old_passes: + if all([ + p._check_conflict_including_common_rules(applied_p) + for applied_p in context.passes + ]): + passes.append(p) + + if not passes: + return [] + + n = len(passes) + adjacent_matrix = [] + for _ in range(n): + adjacent_matrix.append([None] * n) + + for i in range(n): + for j in range(n): + adjacent_matrix[i][j] = passes[ + j]._check_conflict_including_common_rules(passes[i]) + + longest_path = _find_longest_path(adjacent_matrix) + return [passes[idx] for idx in longest_path] + + +class PassManager: + def __init__(self, passes, context=None, auto_solve_conflict=True): + if context is None: + context = PassContext() + self._context = context + + if auto_solve_conflict: + self._passes = _solve_pass_conflict(passes, context) 
+ else: + self._passes = list(passes) + + def apply(self, main_programs, startup_programs): + context = self._context + for p in self._passes: + context = p.apply(main_programs, startup_programs, context) + self._context = context + return context + + @property + def context(self): + return self._context + + @property + def names(self): + return [p.name for p in self._passes] diff --git a/python/paddle/distributed/passes/pass_utils.py b/python/paddle/distributed/passes/pass_utils.py new file mode 100644 index 0000000000..bd1eddce3b --- /dev/null +++ b/python/paddle/distributed/passes/pass_utils.py @@ -0,0 +1,134 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import OrderedDict + + +def list_to_ordered_dict(list_obj, ordered_dict=None): + if ordered_dict is None: + ordered_dict = OrderedDict() + else: + assert isinstance(ordered_dict, OrderedDict) + for obj in list_obj: + if obj not in ordered_dict: + ordered_dict[obj] = True + return ordered_dict + + +# The inputs of a program are the variables +# that first occur as the input of the op. +def get_inputs_of_program(program): + visited_vars = set() + input_vars = [] + for op in program.global_block().ops: + for in_var_name in op.input_arg_names: + if in_var_name not in visited_vars: + input_vars.append(in_var_name) + visited_vars.add(in_var_name) + + for out_var_name in op.output_arg_names: + visited_vars.add(out_var_name) + return input_vars + + +def get_outputs_of_program(program): + output_vars = OrderedDict() + for op in program.global_block().ops: + list_to_ordered_dict(op.output_arg_names, output_vars) + return list(output_vars.keys()) + + +def prune_program(program, start_op_idx, end_op_idx): + op_num = len(program.global_block().ops) + if start_op_idx < 0: + start_op_idx += op_num + assert start_op_idx >= 0 and start_op_idx < op_num + if end_op_idx < 0: + end_op_idx += op_num + assert end_op_idx >= 0 and end_op_idx <= op_num, end_op_idx + assert start_op_idx < end_op_idx + + program = program.clone() + for idx in range(op_num - 1, end_op_idx - 1, -1): + program.global_block()._remove_op(idx, sync=False) + for idx in range(start_op_idx - 1, -1, -1): + program.global_block()._remove_op(idx, sync=False) + program._sync_with_cpp() + + valid_vars = set() + for op in program.global_block().ops: + for in_var_name in op.input_arg_names: + valid_vars.add(in_var_name) + for out_var_name in op.output_arg_names: + valid_vars.add(out_var_name) + + vars_to_remove = [] + for var in program.global_block().vars: + if var not in valid_vars: + vars_to_remove.append(var) + + for var in vars_to_remove: + program.global_block()._remove_var(var, sync=False) + program._sync_with_cpp() + return program + + +def split_program(program, op_indices): + """ + Split the program by op_indices. + + For examples, a program has 100 ops, and op_indices = [25, 60]. + Then the program is splitted into 3 parts, containing 25, 35 and 40 + ops respectively. 
+ + The return values are a tuple with 3 elements: the splitted program + list, the input var names of each splitted program, and the output + var names of each splitted program. + """ + assert op_indices, "op_indices cannot be empty" + op_num = len(program.global_block().ops) + assert op_num > 0, "program cannot be empty" + + op_indices = [idx if idx >= 0 else idx + op_num for idx in op_indices] + + if op_indices[0] != 0: + op_indices = [0] + op_indices + if op_indices[-1] != op_num: + op_indices.append(op_num) + + for idx in range(len(op_indices) - 1): + assert op_indices[idx] < op_indices[ + idx + 1], "op_indices must be strictly sorted" + + splitted_programs = [] + for idx in range(len(op_indices) - 1): + new_split = prune_program(program, op_indices[idx], op_indices[idx + 1]) + splitted_programs.append(new_split) + + num_split = len(splitted_programs) + input_vars = [get_inputs_of_program(p) for p in splitted_programs] + output_vars = [ + list_to_ordered_dict(get_outputs_of_program(p)) + for p in splitted_programs + ] + valid_output_vars = [OrderedDict() for _ in range(num_split)] + valid_output_vars[-1] = output_vars[-1] + for i in range(1, num_split): + for in_var_name in input_vars[i]: + for j in reversed(range(i)): + if in_var_name in output_vars[j]: + valid_output_vars[j][in_var_name] = True + break + valid_output_vars = [list(item.keys()) for item in valid_output_vars] + return splitted_programs, input_vars, valid_output_vars diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index deabdc6c14..4ec4299513 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -581,6 +581,8 @@ set_tests_properties(test_conv_nn_grad PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE") set_tests_properties(test_norm_nn_grad PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE") set_tests_properties(test_nn_grad PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE") if(WITH_DISTRIBUTE) + add_subdirectory(distributed_passes) + # FIXME(typhoonzero): add these tests back list(REMOVE_ITEM DIST_TEST_OPS "test_dist_transformer") list(REMOVE_ITEM DIST_TEST_OPS "test_dist_transpiler") @@ -1023,6 +1025,7 @@ set_tests_properties(test_dataloader_unkeep_order PROPERTIES TIMEOUT 120) set_tests_properties(test_reader_reset PROPERTIES TIMEOUT 120) set_tests_properties(test_pool3d_api PROPERTIES TIMEOUT 120) set_tests_properties(test_cumprod_op PROPERTIES TIMEOUT 120) +set_tests_properties(test_split_program PROPERTIES TIMEOUT 120) if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL) set_tests_properties(test_parallel_dygraph_dataparallel PROPERTIES TIMEOUT 120) set_tests_properties(test_parallel_dygraph_unused_variables PROPERTIES TIMEOUT 120) diff --git a/python/paddle/fluid/tests/unittests/distributed_passes/CMakeLists.txt b/python/paddle/fluid/tests/unittests/distributed_passes/CMakeLists.txt new file mode 100644 index 0000000000..a286953153 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/distributed_passes/CMakeLists.txt @@ -0,0 +1,8 @@ +file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") +string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") + +foreach(TEST_OP ${TEST_OPS}) + py_test_modules(${TEST_OP} MODULES ${TEST_OP}) + list(APPEND DIST_TEST_OPS ${TEST_OP}) + set_tests_properties(${TEST_OP} PROPERTIES TIMEOUT 120) +endforeach(TEST_OP) diff --git a/python/paddle/fluid/tests/unittests/distributed_passes/dist_pass_test_base.py b/python/paddle/fluid/tests/unittests/distributed_passes/dist_pass_test_base.py 
new file mode 100644 index 0000000000..a5b1cdff0f --- /dev/null +++ b/python/paddle/fluid/tests/unittests/distributed_passes/dist_pass_test_base.py @@ -0,0 +1,218 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import paddle +import os +import random +import sys +import pickle +import shlex +import shutil +import inspect +import numpy as np +from collections import OrderedDict +from paddle.distributed.fleet.launch_utils import run_with_coverage + + +def prepare_python_path_and_return_module(path): + dirname, filename = os.path.split(path) + py_suffix = ".py" + assert filename.endswith(py_suffix), filename + + env_name = 'PYTHONPATH' + python_path = env_name + if python_path: + paths = [p for p in python_path.split(":") if p] + if dirname not in paths: + paths.append(dirname) + python_path = ":".join(paths) + else: + python_path = path + os.environ[env_name] = python_path + return filename[:-len(py_suffix)] + + +def remove_path_if_exists(path): + if not os.path.exists(path): + return + + if os.path.isfile(path): + os.remove(path) + else: + shutil.rmtree(path) + + +# NOTE: only support GPU now +class DistPassTestBase(unittest.TestCase): + def setUp(self): + paddle.enable_static() + seed = int(os.environ.get('SEED', -1)) + if seed <= 0: + seed = np.random.randint(low=1, high=1000000, size=[1])[0] + os.environ['SEED'] = str(seed) + self.seed = seed + paddle.seed(self.seed) + + self.rtol = 1e-5 + self.atol = 1e-8 + self.equal_nan = False + + self.init() + + def init(self): + pass + + def get_model(self, place, **kwargs): + raise NotImplementedError() + + def apply_passes(self, main_prog, startup_prog): + raise NotImplementedError() + + def check_main(self, gpus=None, **kwargs): + no_pass_rets = self._distributed_launch( + apply_pass=False, gpus=gpus, **kwargs) + pass_rets = self._distributed_launch( + apply_pass=True, gpus=gpus, **kwargs) + self.check_results(no_pass_rets, pass_rets) + + def check_results(self, no_pass_rets, pass_rets): + self.assertEqual(len(no_pass_rets), len(pass_rets)) + for no_pass_ret, pass_ret in zip(no_pass_rets, pass_rets): + self.assertEqual(len(no_pass_ret), len(pass_ret)) + for i, (out_var_no_pass, + out_var_pass) in enumerate(zip(no_pass_ret, pass_ret)): + if out_var_no_pass is None: + self.assertTrue(out_var_pass is None) + else: + self.assertTrue( + np.allclose( + out_var_no_pass, + out_var_pass, + rtol=self.rtol, + atol=self.atol, + equal_nan=self.equal_nan)) + + @classmethod + def _to_var_names(cls, program, names_or_vars): + if not isinstance(names_or_vars, (list, tuple)): + names_or_vars = [names_or_vars] + ret_var_names = [] + for name_or_var in names_or_vars: + if isinstance(name_or_var, str): + ret_var_names.append(name_or_var) + else: + ret_var_names.append(name_or_var.name) + return ret_var_names + + def _run_gpu_main(self, apply_pass, dump_file, **kwargs): + gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0)) + place = paddle.CUDAPlace(gpu_id) + scope = 
paddle.static.Scope() + with paddle.static.program_guard(paddle.static.Program(), + paddle.static.Program()): + with paddle.static.scope_guard(scope): + with paddle.fluid.unique_name.guard(): + main_prog, startup_prog, inputs, outputs, reader = self.get_model( + place, **kwargs) + inputs = self._to_var_names(main_prog, inputs) + outputs = self._to_var_names(main_prog, outputs) + if apply_pass: + self.apply_passes(main_prog, startup_prog) + + all_fetch_values = [] + exe = paddle.static.Executor(place) + with paddle.static.scope_guard(scope): + exe.run(startup_prog) + for batch_id, input_data in enumerate(reader()): + assert len(input_data) == len(inputs), "{} vs {}".format( + len(input_data), len(inputs)) + feed = dict(zip(inputs, input_data)) + fetch_values = exe.run(main_prog, feed=feed, fetch_list=outputs) + if paddle.distributed.get_rank() == 0: + output_dict = OrderedDict(zip(outputs, fetch_values)) + print('batch {}, outputs {}'.format(batch_id, output_dict)) + all_fetch_values.append(fetch_values) + with open(dump_file, "wb") as f: + pickle.dump(all_fetch_values, f) + + def _distributed_launch(self, apply_pass, gpus=None, **kwargs): + if gpus is None: + num_gpus = paddle.device.cuda.device_count() + gpus = list(range(num_gpus)) + else: + num_gpus = len(gpus) + + gpus = ','.join([str(gpu_id) for gpu_id in gpus]) + + pid = os.getpid() + if apply_pass: + output_dir = "test_with_pass_{}".format(pid) + else: + output_dir = "test_without_pass_{}".format(pid) + remove_path_if_exists(output_dir) + os.makedirs(output_dir, mode=777) + + input_dump_file = os.path.join(output_dir, 'inputs') + if os.environ.get("WITH_COVERAGE", "OFF") == "ON": + run_with_coverage(True) + coverage_args = ["-m", "coverage", "run", "--branch", "-p"] + else: + coverage_args = [] + + file_dir = os.path.dirname(os.path.abspath(__file__)) + + try: + with open(input_dump_file, 'wb') as f: + pickle.dump(kwargs, f) + + cmd = [ + sys.executable, + "-u", + ] + coverage_args + [ + "-m", + "launch", + "--log_dir", + output_dir, + "--gpus", + gpus, + os.path.join(file_dir, "pass_run_main.py"), + "--file_path", + inspect.getfile(type(self)), + "--class_name", + type(self).__name__, + "--input_file", + input_dump_file, + "--output_dir", + output_dir, + ] + (["--apply_pass"] if apply_pass else []) + cmd = [shlex.quote(c) for c in cmd] + prepare_python_path_and_return_module(__file__) + exitcode = os.system(' '.join(cmd)) + self.assertEqual( + exitcode, 0, + "Pass failed with apply_pass = {}".format(apply_pass)) + + results = [] + for i in range(num_gpus): + dump_file = '{0}/{1}.bin'.format(output_dir, i) + self.assertTrue( + os.path.exists(dump_file), + "Pass failed with apply_pass = {}".format(apply_pass)) + with open(dump_file, "rb") as f: + results.append(pickle.load(f)) + return results + finally: + if int(os.environ.get("DEBUG", 0)) == 0: + remove_path_if_exists(output_dir) diff --git a/python/paddle/fluid/tests/unittests/distributed_passes/launch.py b/python/paddle/fluid/tests/unittests/distributed_passes/launch.py new file mode 100644 index 0000000000..c225fe85cd --- /dev/null +++ b/python/paddle/fluid/tests/unittests/distributed_passes/launch.py @@ -0,0 +1,22 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from paddle.distributed.fleet import launch +from paddle.distributed.fleet.launch_utils import run_with_coverage + +if __name__ == "__main__": + if os.environ.get("WITH_COVERAGE", "OFF") == "ON": + run_with_coverage(True) + launch.launch() diff --git a/python/paddle/fluid/tests/unittests/distributed_passes/pass_run_main.py b/python/paddle/fluid/tests/unittests/distributed_passes/pass_run_main.py new file mode 100644 index 0000000000..6ff24ec176 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/distributed_passes/pass_run_main.py @@ -0,0 +1,75 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import paddle +import pickle +import importlib +import os +import sys +from paddle.distributed.fleet.launch_utils import run_with_coverage +from dist_pass_test_base import prepare_python_path_and_return_module, DistPassTestBase + + +def parse_args(): + parser = argparse.ArgumentParser( + description='arguments for distributed pass tests') + parser.add_argument('--file_path', type=str, help='The test file path.') + parser.add_argument( + '--class_name', + type=str, + help='The test class name. It is the class name that inherits the DistPassTestBase class.' 
+ ) + parser.add_argument( + '--apply_pass', + default=False, + action="store_true", + help='Whether to apply distributed passes.') + parser.add_argument( + '--input_file', + type=str, + help='The input file which contains the dumped input arguments.') + parser.add_argument( + '--output_dir', + type=str, + help='The output directory to save the logs and output results.') + return parser.parse_args() + + +def run_main(args): + if os.environ.get("WITH_COVERAGE", "OFF") == "ON": + run_with_coverage(True) + module_name = prepare_python_path_and_return_module(args.file_path) + test_module = importlib.import_module(module_name) + test_class = getattr(test_module, args.class_name) + assert issubclass(test_class, DistPassTestBase) + test_obj = test_class() + rank = paddle.distributed.get_rank() + with open(args.input_file, "rb") as f: + kwargs = pickle.load(f) + + output_file = "{}/{}.bin".format(args.output_dir, rank) + + try: + test_obj.setUpClass() + test_obj.setUp() + test_obj._run_gpu_main(args.apply_pass, output_file, **kwargs) + finally: + test_obj.tearDown() + test_obj.tearDownClass() + + +if __name__ == "__main__": + args = parse_args() + run_main(args) diff --git a/python/paddle/fluid/tests/unittests/distributed_passes/test_dist_fuse_all_reduce_pass.py b/python/paddle/fluid/tests/unittests/distributed_passes/test_dist_fuse_all_reduce_pass.py new file mode 100644 index 0000000000..1a55bd5ecc --- /dev/null +++ b/python/paddle/fluid/tests/unittests/distributed_passes/test_dist_fuse_all_reduce_pass.py @@ -0,0 +1,76 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import paddle +from paddle.distributed.passes import new_pass, PassManager +import paddle.distributed.fleet as fleet +from paddle.vision.models import resnet50 as resnet +import unittest +from dist_pass_test_base import DistPassTestBase +import paddle.nn as nn +import numpy as np + + +class TestFuseAllReducePass(DistPassTestBase): + def init(self): + if paddle.is_compiled_with_cuda(): + paddle.set_flags({'FLAGS_cudnn_deterministic': 1}) + self.atol = 0.0 + self.rtol = 0.0 + + def apply_passes(self, main_prog, startup_prog): + pass_manager = PassManager([ + new_pass("fuse_elewise_add_act"), + new_pass("fuse_all_reduce", {"max_memory_size": 1024 * 1024}) + ]) + pass_manager.apply([main_prog], [startup_prog]) + + def test_bs_32(self): + self.check_main(batch_size=32) + + def get_model(self, place, batch_size): + image = paddle.static.data( + shape=[batch_size, 3, 224, 224], dtype='float32', name='image') + label = paddle.static.data( + shape=[batch_size, 1], dtype='int64', name='label') + model = resnet(pretrained=False) + loss_fn = nn.loss.CrossEntropyLoss() + pred_out = model(image) + loss = loss_fn(pred_out, label) + optimizer = paddle.optimizer.Adam(learning_rate=1e-3) + + dist_strategy = fleet.DistributedStrategy() + dist_strategy.fuse_all_reduce_ops = False + dist_strategy.without_graph_optimization = True + fleet.init(is_collective=True, strategy=dist_strategy) + optimizer = fleet.distributed_optimizer(optimizer) + optimizer.minimize(loss) + + rank = paddle.distributed.get_rank() + + def reader(): + np.random.seed(self.seed + rank) + for _ in range(10): + image_np = np.random.random(size=image.shape).astype('float32') + label_np = np.random.randint( + low=0, high=1000, size=label.shape).astype('int64') + yield image_np, label_np + + main_program = paddle.static.default_main_program() + startup_program = paddle.static.default_startup_program() + return main_program, startup_program, [image, label], [loss], reader + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_split_program.py b/python/paddle/fluid/tests/unittests/test_split_program.py new file mode 100644 index 0000000000..3245e8d997 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_split_program.py @@ -0,0 +1,149 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from paddle.distributed.passes.pass_utils import split_program +from paddle.vision.models import resnet18 as resnet +import paddle +import paddle.nn as nn +import unittest +import json +import numpy as np + + +class TestSplitProgram(unittest.TestCase): + def setUp(self): + paddle.enable_static() + if paddle.is_compiled_with_cuda(): + paddle.set_flags({'FLAGS_cudnn_deterministic': 1}) + + def get_model(self, batch_size): + main = paddle.static.Program() + startup = paddle.static.Program() + with paddle.static.program_guard(main, startup): + image = paddle.static.data( + shape=[batch_size, 3, 224, 224], dtype='float32', name='image') + label = paddle.static.data( + shape=[batch_size, 1], dtype='int64', name='label') + + model = resnet(pretrained=False) + loss_fn = nn.loss.CrossEntropyLoss() + + pred_out = model(image) + loss = loss_fn(pred_out, label) + + optimizer = paddle.optimizer.SGD(learning_rate=1e-3) + optimizer.minimize(loss) + return main, startup, image, label + + def find_startup_vars(self, main_prog, startup_prog): + self.assertEqual(startup_prog.num_blocks, 1) + startup_vars = [] + for op in startup_prog.global_block().ops: + for var_name in op.output_arg_names: + var = main_prog.global_block().var(var_name) + if var.persistable: + startup_vars.append(var_name) + return startup_vars + + def test_split_program(self): + for p in self.get_places(): + vars_expected = self.check_split_program(p, use_split=False) + vars_actual = self.check_split_program(p, use_split=True) + self.assertEqual(len(vars_actual), len(vars_expected)) + for actual, expected in zip(vars_actual, vars_expected): + self.assertEqual(actual.shape, expected.shape) + self.assertTrue( + np.array_equal(actual, expected), + '{}\n{}\n'.format(actual, expected)) + + def get_places(self): + places = [paddle.CPUPlace()] + if paddle.is_compiled_with_cuda(): + places.append(paddle.CUDAPlace(0)) + return places + + def get_var_values(self, scope, var_names): + values = [] + for var_name in var_names: + values.append(np.array(scope.find_var(var_name).get_tensor())) + return values + + def check_split_program(self, place, use_split=True, seed=100, batch_num=5): + batch_size = 2 + + np.random.seed(seed) + paddle.seed(seed) + + main_prog, startup_prog, image, label = self.get_model(batch_size) + startup_vars = self.find_startup_vars(main_prog, startup_prog) + exe = paddle.static.Executor(place) + + image_np = np.random.random(size=image.shape).astype('float32') + label_np = np.random.randint( + low=0, high=1000, dtype='int64', size=label.shape) + + scope = paddle.static.Scope() + if not use_split: + with paddle.static.scope_guard(scope): + exe.run(startup_prog) + for _ in range(batch_num): + exe.run(main_prog, + feed={image.name: image_np, + label.name: label_np}) + return self.get_var_values(scope, startup_vars) + + op_num = len(main_prog.global_block().ops) + split_op_indices = [int(op_num / 3.0), int(op_num * 3 / 4.0)] + programs, input_vars, output_vars = split_program(main_prog, + split_op_indices) + op_nums = [0] + split_op_indices + [op_num] + op_nums = [op_nums[i + 1] - op_nums[i] for i in range(len(op_nums) - 1)] + num_split = len(split_op_indices) + 1 + self.assertEqual(len(programs), num_split) + self.assertEqual(len(input_vars), num_split) + self.assertEqual(len(output_vars), num_split) + self.assertEqual(len(programs), len(op_nums)) + for p, n in zip(programs, op_nums): + self.assertEqual(len(p.global_block().ops), n) + + with paddle.static.scope_guard(scope): + exe.run(startup_prog) + for _ in range(batch_num): + 
tmp_vars = {image.name: image_np, label.name: label_np} + for i, program in enumerate(programs): + feed_dict = {} + for in_name in input_vars[i]: + if in_name in startup_vars: + continue + self.assertTrue(in_name in tmp_vars) + if tmp_vars[in_name] is not None: + feed_dict[in_name] = tmp_vars[in_name] + + output_var_values = exe.run(program, + feed=feed_dict, + fetch_list=output_vars[i], + return_numpy=False) + for out_name, out_value in zip(output_vars[i], + output_var_values): + if not out_value._is_initialized(): + tmp_vars[out_name] = np.ndarray(out_value._get_dims( + )).astype('float32') + else: + tmp_vars[out_name] = np.array(out_value) + + return self.get_var_values(scope, startup_vars) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/setup.py.in b/python/setup.py.in index 6a252a5723..b1643e6a6d 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -293,6 +293,7 @@ packages=['paddle', 'paddle.distributed.fleet.meta_parallel.parallel_layers', 'paddle.distributed.auto_parallel', 'paddle.distributed.auto_parallel.operators', + 'paddle.distributed.passes', 'paddle.framework', 'paddle.jit', 'paddle.jit.dy2static', -- GitLab
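A condensed usage sketch of the pass framework introduced by this patch, mirroring what test_dist_fuse_all_reduce_pass.py does after fleet initialization; program construction and training are assumed and omitted, so this is illustrative rather than a complete runnable test:

from paddle.distributed.passes import new_pass, PassManager

def apply_distributed_passes(main_prog, startup_prog):
    # Build a manager from registered passes; max_memory_size is the
    # fusion threshold in bytes for the fused allreduce buffer.
    pass_manager = PassManager([
        new_pass("fuse_elewise_add_act"),
        new_pass("fuse_all_reduce", {"max_memory_size": 1024 * 1024}),
    ])
    # Passes are applied in order; with auto_solve_conflict=True (the
    # default) conflicting passes may be dropped or reordered before
    # application, and pass_manager.names reports the passes that remain.
    pass_manager.apply([main_prog], [startup_prog])
    return pass_manager.context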