From 5f369881244fbd424cd75ddba6bdbbe86cbd3f07 Mon Sep 17 00:00:00 2001 From: Zeng Jinle <32832641+sneaxiy@users.noreply.github.com> Date: Wed, 8 Sep 2021 11:58:41 +0800 Subject: [PATCH] Enable program passes on Fleet APIs (#34955) * add fleet api for program pass * turn on apply pass for CI test * fix disable fuse_all_optimizer bug * try to test ci * fix CI * fill unspecified op role * fix fuse_allreduce * add ut to improve coverage * remove useless change * improve c++ coverage * follow some comments * test ir pass pipeline * update doc * reduce ut time again --- paddle/fluid/framework/block_desc.cc | 22 +++- .../memory_optimize_pass/memory_reuse_pass.h | 2 + paddle/fluid/framework/ir/pass.cc | 110 ++++++++++++++---- paddle/fluid/framework/ir/pass.h | 12 +- paddle/fluid/framework/program_desc_test.cc | 34 ++++++ paddle/fluid/platform/flags.cc | 13 +++ .../pybind/global_value_getter_setter.cc | 4 +- paddle/fluid/pybind/ir.cc | 64 +++++++--- .../distributed/fleet/base/fleet_base.py | 32 ++++- .../graph_execution_optimizer.py | 2 +- .../meta_optimizers/raw_program_optimizer.py | 100 ++++++++++++++++ python/paddle/fluid/__init__.py | 1 + python/paddle/fluid/ir.py | 40 ++++++- python/paddle/fluid/optimizer.py | 5 + .../fluid/tests/unittests/CMakeLists.txt | 2 + ...dist_mnist_gradient_merge_raw_optimizer.py | 79 +++++++++++++ .../test_dist_mnist_gradient_merge.py | 20 +++- .../tests/unittests/test_ir_pass_pipeline.py | 25 ++++ .../fluid/tests/unittests/test_pipeline.py | 16 ++- 19 files changed, 523 insertions(+), 60 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/dist_mnist_gradient_merge_raw_optimizer.py create mode 100644 python/paddle/fluid/tests/unittests/test_ir_pass_pipeline.py diff --git a/paddle/fluid/framework/block_desc.cc b/paddle/fluid/framework/block_desc.cc index c225d4090a..71d439999f 100644 --- a/paddle/fluid/framework/block_desc.cc +++ b/paddle/fluid/framework/block_desc.cc @@ -263,7 +263,27 @@ void BlockDesc::MoveFrom(BlockDesc *block) { } ops_.clear(); for (const auto &src_op : block->ops_) { - AppendOp()->CopyFrom(*src_op); + auto *dst_op = AppendOp(); + dst_op->CopyFrom(*src_op); + for (const auto &pair : src_op->GetAttrMap()) { + const auto &attr_name = pair.first; + const auto &attr_value = pair.second; + auto attr_type = static_cast(attr_value.which() - 1); + if (attr_type == proto::AttrType::BLOCK) { + auto block_id = BOOST_GET_CONST(BlockDesc *, attr_value)->ID(); + dst_op->SetBlockAttr(attr_name, prog_->MutableBlock(block_id)); + VLOG(10) << "Set block attr " << attr_name << " id " << block_id; + } else if (attr_type == proto::AttrType::BLOCKS) { + auto old_blocks = BOOST_GET_CONST(std::vector, attr_value); + std::vector new_blocks; + new_blocks.reserve(old_blocks.size()); + for (auto *b : old_blocks) { + VLOG(10) << "Set block attr " << attr_name << " id " << b->ID(); + new_blocks.push_back(prog_->MutableBlock(b->ID())); + } + dst_op->SetBlocksAttr(attr_name, new_blocks); + } + } } need_update_ = true; Flush(); diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/memory_reuse_pass.h b/paddle/fluid/framework/ir/memory_optimize_pass/memory_reuse_pass.h index d908a37a2a..eb6a43e6a6 100644 --- a/paddle/fluid/framework/ir/memory_optimize_pass/memory_reuse_pass.h +++ b/paddle/fluid/framework/ir/memory_optimize_pass/memory_reuse_pass.h @@ -113,6 +113,8 @@ class MemoryReusePass : public Pass { details::VarHandle *in_var, details::VarHandle *out_var) const; + bool SupportApplyProgramViaGraph() const override { return false; } + private: VarDesc 
*GetVarDesc(const details::VarHandle &var) const; diff --git a/paddle/fluid/framework/ir/pass.cc b/paddle/fluid/framework/ir/pass.cc index 350f00ae2a..1199d251d2 100644 --- a/paddle/fluid/framework/ir/pass.cc +++ b/paddle/fluid/framework/ir/pass.cc @@ -16,6 +16,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/ir/graph_helper.h" +#include "paddle/fluid/framework/op_proto_maker.h" namespace paddle { namespace framework { @@ -72,19 +73,6 @@ Graph *Pass::Apply(Graph *graph) const { return graph; } -void Pass::Apply(ProgramDesc *main_program, - ProgramDesc *startup_program) const { - VLOG(10) << "apply pass " << Type() << " to program"; - PADDLE_ENFORCE_NOT_NULL(main_program, platform::errors::InvalidArgument( - "main program must be provided")); - PADDLE_ENFORCE_NOT_NULL( - startup_program, - platform::errors::InvalidArgument("startup program must be provided")); - - ApplyImpl(main_program, startup_program); - VLOG(10) << "finish to apply pass " << Type() << " to program"; -} - template static void VisitAllElements(Container &&container, Visitor &&visitor, bool reverse) { @@ -95,8 +83,8 @@ static void VisitAllElements(Container &&container, Visitor &&visitor, } } -void Pass::MergePrograms(ProgramDesc *dst, const details::ProgramDescs &srcs, - bool append) { +static void MergePrograms(ProgramDesc *dst, const details::ProgramDescs &srcs, + bool append) { PADDLE_ENFORCE_NOT_NULL( dst, platform::errors::InvalidArgument("Dst program must be provided.")); bool reverse = !append; @@ -137,27 +125,105 @@ void Pass::MergePrograms(ProgramDesc *dst, const details::ProgramDescs &srcs, VisitAllElements(srcs, create_op_visitor, reverse); } +static void FillNotSpecifiedOpRole(const ProgramDesc &main_program) { + for (size_t block_idx = 0; block_idx < main_program.Size(); ++block_idx) { + auto ops = main_program.Block(block_idx).AllOps(); + size_t n = ops.size(); + std::vector roles; + roles.reserve(n); + auto op_role_attr = OpProtoAndCheckerMaker::OpRoleAttrName(); + for (auto *op : ops) { + OpRole role; + if (op->HasAttr(op_role_attr)) { + role = static_cast(op->GetAttrIfExists(op_role_attr)); + } else { + role = OpRole::kNotSpecified; + } + roles.emplace_back(role); + } + + // NOTE: The following codes may be wrong in some cases. + // But how can we get the right OpRole? The right way + // is that all passes should deal with unspecified OpRole. 
+ auto prev_role = OpRole::kForward; + for (size_t i = 0; i < n; ++i) { + if (roles[i] == OpRole::kNotSpecified) { + VLOG(10) << "Fill op role of " << ops[i]->Type() << " as " + << static_cast(prev_role); + ops[i]->SetAttr(op_role_attr, static_cast(prev_role)); + } else { + prev_role = roles[i]; + } + } + } +} + +void Pass::ApplyPassesToProgram(const std::vector &passes, + ProgramDesc *main_program, + ProgramDesc *startup_program) { + VLOG(10) << "ApplyPassesToProgram is called"; + PADDLE_ENFORCE_NOT_NULL( + main_program, + platform::errors::InvalidArgument("The main program must be provided.")); + + PADDLE_ENFORCE_NOT_NULL(startup_program, + platform::errors::InvalidArgument( + "The startup program must be provided.")); + + for (auto *p : passes) { + PADDLE_ENFORCE_NOT_NULL(p, platform::errors::InvalidArgument( + "The provided pass cannot be nullptr.")); + VLOG(10) << "Pass " << p->Type(); + if (passes.size() > 1) { + PADDLE_ENFORCE_EQ(p->SupportApplyProgramViaGraph(), true, + platform::errors::PermissionDenied( + "Each pass must support to be applied via Graph if " + "multi-passes are applied.")); + } + } + + if (passes.size() == 1 && !passes[0]->SupportApplyProgramViaGraph()) { + VLOG(10) << "apply pass " << passes[0]->Type() << " to program"; + passes[0]->ApplyImpl(main_program, startup_program); + FillNotSpecifiedOpRole(*main_program); + VLOG(10) << "finish to apply pass " << passes[0]->Type() << " to program"; + return; + } + + Graph graph(*main_program); + for (auto *p : passes) { + p->Apply(&graph); + } + ConvertToPrograms(&graph, main_program, startup_program); + FillNotSpecifiedOpRole(*main_program); +} + void Pass::ApplyImpl(ProgramDesc *main_program, ProgramDesc *startup_program) const { - Graph graph(*main_program); - Apply(&graph); + PADDLE_THROW(platform::errors::Unimplemented( + "The pass %s does not support to apply ProgramDesc directly", Type())); +} +void Pass::ConvertToPrograms(Graph *graph, ProgramDesc *main_program, + ProgramDesc *startup_program) { ProgramDesc new_main_program; - GraphToProgram(graph, &new_main_program); + GraphToProgram(*graph, &new_main_program); main_program->CopyFrom(*new_main_program.Proto()); - if (graph.Has(details::kStartupProgramDescs)) { + if (graph->Has(details::kStartupProgramDescs)) { const auto &startups = - graph.Get(details::kStartupProgramDescs); + graph->Get(details::kStartupProgramDescs); VLOG(10) << "Merge startup programs"; MergePrograms(startup_program, startups, /*append=*/true); + graph->Erase(details::kStartupProgramDescs); } - if (graph.Has(details::kProgramDescs)) { + if (graph->Has(details::kProgramDescs)) { const auto &mains = - graph.Get(details::kProgramDescs); + graph->Get(details::kProgramDescs); VLOG(10) << "Merge main programs"; MergePrograms(main_program, mains, /*append=*/false); + graph->Erase(details::kProgramDescs); } startup_program->Flush(); diff --git a/paddle/fluid/framework/ir/pass.h b/paddle/fluid/framework/ir/pass.h index 1d1ebcb17e..016d0fd4a6 100644 --- a/paddle/fluid/framework/ir/pass.h +++ b/paddle/fluid/framework/ir/pass.h @@ -65,8 +65,6 @@ class Pass { Graph *Apply(Graph *graph) const; - void Apply(ProgramDesc *main_program, ProgramDesc *startup_program) const; - // Get a reference to the attributed previously set. 
template AttrType &Get(const std::string &attr_name) const { @@ -142,6 +140,12 @@ class Pass { attrs_[attr_name] = attr; } + static void ApplyPassesToProgram(const std::vector &passes, + ProgramDesc *main_program, + ProgramDesc *startup_program); + + virtual bool SupportApplyProgramViaGraph() const { return true; } + protected: virtual void ApplyImpl(Graph *graph) const { PADDLE_THROW(platform::errors::Unimplemented( @@ -151,8 +155,8 @@ class Pass { virtual void ApplyImpl(ProgramDesc *main_program, ProgramDesc *startup_program) const; - static void MergePrograms(ProgramDesc *dst, const details::ProgramDescs &srcs, - bool append); + static void ConvertToPrograms(ir::Graph *graph, ProgramDesc *main_program, + ProgramDesc *startup_program); // Some Pass must be placed before this Pass, and some // Pass must be placed after this Pass. diff --git a/paddle/fluid/framework/program_desc_test.cc b/paddle/fluid/framework/program_desc_test.cc index 7d5d61c4c5..e57b883d91 100644 --- a/paddle/fluid/framework/program_desc_test.cc +++ b/paddle/fluid/framework/program_desc_test.cc @@ -23,6 +23,40 @@ namespace paddle { namespace framework { class VarDesc; +TEST(ProgramDesc, block_desc_move) { + auto program = std::make_unique(); + auto* global_block = program->MutableBlock(0); + + auto* op = global_block->AppendOp(); + op->SetType("op_with_subblock"); + op->SetAttr("sub_block", program->AppendBlock(*global_block)); + + std::vector sub_blocks; + sub_blocks.push_back(program->AppendBlock(*global_block)); + sub_blocks.push_back(program->AppendBlock(*global_block)); + op->SetAttr("sub_blocks", sub_blocks); + + program->Flush(); + + ProgramDesc program_move; + for (size_t i = 1; i < program->Size(); ++i) { + program_move.AppendBlock(program_move.Block(0)); + } + for (size_t i = 0; i < program->Size(); ++i) { + program_move.MutableBlock(i)->MoveFrom(program->MutableBlock(i)); + } + program = nullptr; + EXPECT_EQ(program_move.Size(), static_cast(4)); + op = program_move.Block(0).Op(0); + auto sub_block = op->GetAttrIfExists("sub_block"); + EXPECT_EQ(sub_block, program_move.MutableBlock(1)); + + sub_blocks = op->GetAttrIfExists>("sub_blocks"); + EXPECT_EQ(sub_blocks.size(), static_cast(2)); + EXPECT_EQ(sub_blocks[0], program_move.MutableBlock(2)); + EXPECT_EQ(sub_blocks[1], program_move.MutableBlock(3)); +} + TEST(ProgramDesc, copy_ctor) { ProgramDesc program; auto* global_block = program.MutableBlock(0); diff --git a/paddle/fluid/platform/flags.cc b/paddle/fluid/platform/flags.cc index 135cf4e399..ed465c9ea2 100644 --- a/paddle/fluid/platform/flags.cc +++ b/paddle/fluid/platform/flags.cc @@ -624,3 +624,16 @@ DEFINE_bool(conv2d_disable_cudnn, false, "Disable cudnn in conv2d"); DEFINE_int32(get_host_by_name_time, 120, "The maximum time for get host by name time"); #endif + +/** + * Distributed related FLAG + * Name: FLAGS_apply_pass_to_program + * Since Version: 2.2.0 + * Value Range: bool, default=false + * Example: FLAGS_apply_pass_to_program=true would apply IR Pass to + * program when using Fleet APIs. + * Note: Apply IR pass to program. Be only useful when using Fleet APIs. 
+ */ +DEFINE_bool( + apply_pass_to_program, false, + "It controls whether to apply IR pass to program when using Fleet APIs"); diff --git a/paddle/fluid/pybind/global_value_getter_setter.cc b/paddle/fluid/pybind/global_value_getter_setter.cc index dd45443a04..59c7628447 100644 --- a/paddle/fluid/pybind/global_value_getter_setter.cc +++ b/paddle/fluid/pybind/global_value_getter_setter.cc @@ -67,6 +67,7 @@ DECLARE_bool(benchmark); DECLARE_int32(inner_op_parallelism); DECLARE_int32(max_inplace_grad_add); DECLARE_string(tracer_profile_fname); +DECLARE_bool(apply_pass_to_program); #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) // cudnn DECLARE_uint64(conv_workspace_size_limit); @@ -367,7 +368,8 @@ static void RegisterGlobalVarGetterSetter() { FLAGS_memory_fraction_of_eager_deletion, FLAGS_use_pinned_memory, FLAGS_benchmark, FLAGS_inner_op_parallelism, FLAGS_tracer_profile_fname, FLAGS_paddle_num_threads, FLAGS_use_mkldnn, FLAGS_max_inplace_grad_add, - FLAGS_tracer_mkldnn_ops_on, FLAGS_tracer_mkldnn_ops_off); + FLAGS_tracer_mkldnn_ops_on, FLAGS_tracer_mkldnn_ops_off, + FLAGS_apply_pass_to_program); #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) REGISTER_PUBLIC_GLOBAL_VAR( diff --git a/paddle/fluid/pybind/ir.cc b/paddle/fluid/pybind/ir.cc index 788d8d15ff..e27e3674ee 100644 --- a/paddle/fluid/pybind/ir.cc +++ b/paddle/fluid/pybind/ir.cc @@ -293,6 +293,19 @@ static void SetAttrsToPass( } } +static std::vector GetPassNames(const py::object &names) { + try { + return {py::cast(names)}; + } catch (py::cast_error &) { + try { + return py::cast>(names); + } catch (py::cast_error &) { + PADDLE_THROW(platform::errors::InvalidArgument( + "Pass names must be either str or list[str]")); + } + } +} + void BindPass(py::module *m) { // NOTE: pass_attr_types is a dict to indicate the type of each attribute. // Python has only one integral type "int", but C++ has many integral types. 
@@ -312,25 +325,38 @@ void BindPass(py::module *m) { REGISTER_PASS_ATTR_GETTER_SETTER("str", std::string); REGISTER_PASS_ATTR_GETTER_SETTER("list[str]", std::vector); - m->def( - "apply_pass", - [](framework::ProgramDesc *main_program, - framework::ProgramDesc *startup_program, const std::string &pass_name, - const std::unordered_map &pass_attrs, - std::unordered_map pass_attr_types) { - auto pass = framework::ir::PassRegistry::Instance().Get(pass_name); - SetAttrsToPass(pass_attrs, &pass_attr_types, pass.get()); - pass->Apply(main_program, startup_program); - std::unordered_map result_attrs; - for (const auto &name_and_value : pass_attrs) { - const auto &attr_name = name_and_value.first; - const auto &attr_type = pass_attr_types.at(attr_name); - result_attrs[attr_name] = - PassAttrGetterSetterRegistry::Instance().Get(*pass, attr_name, - attr_type); - } - return result_attrs; - }); + m->def("apply_pass", + [](framework::ProgramDesc *main_program, + framework::ProgramDesc *startup_program, + const py::object &py_pass_names, + const std::unordered_map &pass_attrs, + std::unordered_map pass_attr_types) { + auto pass_names = GetPassNames(py_pass_names); + std::vector> passes; + std::vector passes_not_owned; + passes.reserve(pass_names.size()); + passes_not_owned.reserve(pass_names.size()); + for (const auto &name : pass_names) { + auto pass = framework::ir::PassRegistry::Instance().Get(name); + SetAttrsToPass(pass_attrs, &pass_attr_types, pass.get()); + passes.push_back(std::move(pass)); + passes_not_owned.push_back(passes.back().get()); + } + + framework::ir::Pass::ApplyPassesToProgram( + passes_not_owned, main_program, startup_program); + std::unordered_map result_attrs; + for (const auto &pass : passes) { + for (const auto &name_and_value : pass_attrs) { + const auto &attr_name = name_and_value.first; + const auto &attr_type = pass_attr_types.at(attr_name); + result_attrs[attr_name] = + PassAttrGetterSetterRegistry::Instance().Get( + *pass, attr_name, attr_type); + } + } + return result_attrs; + }); } } // namespace pybind diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py index 1ed84c146f..d1f6802919 100644 --- a/python/paddle/distributed/fleet/base/fleet_base.py +++ b/python/paddle/distributed/fleet/base/fleet_base.py @@ -18,7 +18,7 @@ import warnings import paddle import os import numpy as np -from paddle.fluid.framework import dygraph_only +from paddle.fluid.framework import dygraph_only, _global_flags from paddle.fluid import compiler from .role_maker import UserDefinedRoleMaker, PaddleCloudRoleMaker, RoleMakerBase from .strategy_compiler import StrategyCompiler @@ -27,6 +27,7 @@ from .meta_optimizer_factory import MetaOptimizerFactory from .runtime_factory import RuntimeFactory from paddle.fluid.wrapped_decorator import wrap_decorator from paddle.fluid.dygraph import parallel_helper +from paddle.fluid.ir import apply_build_strategy from . 
import topology as tp from .topology import ParallelMode from ..meta_parallel import TensorParallel, model_parallel_random_seed @@ -37,6 +38,33 @@ from ..meta_optimizers import HybridParallelGradScaler __all__ = [] +def apply_ir_passes(main_program, startup_program, config): + build_strategy = config._user_defined_strategy.build_strategy._copy() + if not _global_flags()['FLAGS_apply_pass_to_program']: + return build_strategy + + pipeline_opt = getattr(main_program, "_pipeline_opt", {}) + if pipeline_opt: + main_program = pipeline_opt["section_program"] + startup_program = startup_program._pipeline_opt["startup_program"] + + pass_attrs = {"use_cuda": config._is_collective} + fuse_all_reduce = config._user_defined_strategy.fuse_all_reduce_ops + if fuse_all_reduce and build_strategy.fuse_all_optimizer_ops: + # FIXME(zjl): currently, fuse_all_optimizer_ops + # have conflict with fuse_all_reduce_ops because + # RawProgramOptimizer also inserts coalesce_tensor + # into program. These two procedures may conflict + # in which vars are to be fused. + warnings.warn( + 'Currently, the fuse_all_optimizer_ops pass has conflict with fuse_all_reduce_ops pass. Disable the fuse_all_optimizer_ops pass temporarily.' + ) + build_strategy.fuse_all_optimizer_ops = False + + return apply_build_strategy(main_program, startup_program, build_strategy, + pass_attrs) + + def _inited_runtime_handler_(func): def __impl__(*args, **kwargs): cls = args[0] @@ -1475,6 +1503,8 @@ class Fleet(object): # i.e. users can not modify current computation graph anymore context["graph_optimize_ops"] = optimize_ops context["graph_optimize_grads"] = params_grads + else: + apply_ir_passes(loss.block.program, startup_program, self) program = paddle.static.default_main_program() opt_info = {} diff --git a/python/paddle/distributed/fleet/meta_optimizers/graph_execution_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/graph_execution_optimizer.py index 5827f6bb3a..0fd7db56de 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/graph_execution_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/graph_execution_optimizer.py @@ -42,7 +42,7 @@ class GraphExecutionOptimizer(MetaOptimizerBase): # update me. 
currently, if parameter server is used # graph execution optimizer can not be applied return False - return True + return not self.user_defined_strategy.without_graph_optimization def backward(self, loss, diff --git a/python/paddle/distributed/fleet/meta_optimizers/raw_program_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/raw_program_optimizer.py index c923624651..754272f7fc 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/raw_program_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/raw_program_optimizer.py @@ -32,6 +32,11 @@ class RawProgramOptimizer(MetaOptimizerBase): self.meta_optimizers_white_list = [ "RecomputeOptimizer", "AMPOptimizer", + "GradientMergeOptimizer", + "LambOptimizer", + "LarsOptimizer", + "DGCOptimizer", + "LocalSGDOptimizer", ] self.meta_optimizers_black_list = ["GraphExecutionOptimizer", ] self.global_ring_id = 0 @@ -129,8 +134,103 @@ class RawProgramOptimizer(MetaOptimizerBase): self._transpile_main_program(loss) return optimize_ops, params_grads + def _find_gradient_merge_block(self): + GRAD_MERGE_COND_NAME = "grad_merge_cond_name" + gm_cond_var_name = None + for op in self.main_program.global_block().ops: + if GRAD_MERGE_COND_NAME not in op.attr_names: + continue + if gm_cond_var_name is None: + gm_cond_var_name = op.attr(GRAD_MERGE_COND_NAME) + else: + assert gm_cond_var_name == op.attr( + GRAD_MERGE_COND_NAME + ), "multiple gradient merge condition found" + if gm_cond_var_name is None: + return None + + cond_op = None # false_fn of gm is None, so we should only find one block + for op in self.main_program.global_block().ops: + if op.type != 'conditional_block' or 'Cond' not in op.input_names: + continue + cond_vars = op.input('Cond') + if not cond_vars or cond_vars[0] != gm_cond_var_name: + continue + assert cond_op is None, "multiple gradient merge block found" + cond_op = op + assert cond_op is not None, "cannot find gradient merge block" + return cond_op._block_attr("sub_block") + + def _insert_allreduce_ops_for_gm(self, gm_block): + block = self.main_program.global_block() + + last_backward_op_idx = None + for i, op in enumerate(reversed(gm_block.ops)): + if is_backward_op(op) and last_backward_op_idx is None: + last_backward_idx = i + break + if last_backward_op_idx is None: + last_backward_op_idx = 0 + + param_vars = [] + grad_vars = [] + for op in block.ops: + if is_backward_op(op) and \ + OP_ROLE_VAR_KEY in op.attr_names: + op_role_var = op.attr(OP_ROLE_VAR_KEY) + assert len(op_role_var) % 2 == 0 + for i in range(0, len(op_role_var), 2): + param = block.var(op_role_var[i]) + grad = block.var(op_role_var[i + 1]) + if param.is_distributed: + continue + param_vars.append(param) + grad_vars.append(grad) + + if not grad_vars: + return + + gm_block._insert_op( + last_backward_op_idx, + type="c_sync_calc_stream", + inputs={'X': grad_vars[0]}, + outputs={'Out': grad_vars[0]}, + attrs={OP_ROLE_KEY: OpRole.Backward}) + + insert_op_num = 1 + ring_id = self.global_ring_id + + # NOTE: can perform fuse allreduce inside the loop in the future + for i, (p, g) in enumerate(zip(param_vars, grad_vars)): + gm_block._insert_op( + last_backward_op_idx + insert_op_num, + type="c_allreduce_sum", + inputs={'X': g}, + outputs={'Out': g}, + attrs={ + 'ring_id': ring_id, + OP_ROLE_KEY: OpRole.Backward, + }) + insert_op_num += 1 + + gm_block._insert_op( + last_backward_op_idx + insert_op_num, + type="c_sync_comm_stream", + inputs={'X': grad_vars[-1]}, + outputs={'Out': grad_vars[-1]}, + attrs={ + 'ring_id': ring_id, + OP_ROLE_KEY: 
OpRole.Backward, + }) + def _transpile_main_program(self, loss): self._insert_loss_grad_ops(loss) + gm_block = self._find_gradient_merge_block() + if gm_block is not None: + # TODO(zjl): support fuse allreduce + self._insert_allreduce_ops_for_gm(gm_block) + return + if self.fuse_all_reduce_ops and self.fuse_grad_size_in_num > 1: self._allreduce_fusion_program() else: diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index 3fe7f90a5b..cffbc29466 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -208,6 +208,7 @@ def __bootstrap__(): 'call_stack_level', 'sort_sum_gradient', 'max_inplace_grad_add', + 'apply_pass_to_program', ] if 'Darwin' not in sysstr: read_env_flags.append('use_pinned_memory') diff --git a/python/paddle/fluid/ir.py b/python/paddle/fluid/ir.py index 765272f9dc..69775dbdaf 100644 --- a/python/paddle/fluid/ir.py +++ b/python/paddle/fluid/ir.py @@ -14,6 +14,7 @@ import os import copy +from . import core from .framework import _apply_pass @@ -25,6 +26,35 @@ def get_data_vars(program): return data_vars +def _update_grad_persistable(main_program): + grad_merge_attr_name = "grad_merge_cond_name" + op_role_var_attr_name = core.op_proto_and_checker_maker.kOpRoleVarAttrName() + has_grad_merge = False + has_persistable_grad_var = False + grad_vars = [] + for block_id in range(main_program.num_blocks): + block = main_program.block(block_id) + for op in block.ops: + if grad_merge_attr_name in op.attr_names: + has_grad_merge = True + + if op_role_var_attr_name not in op.attr_names: + continue + + p_g = op.attr(op_role_var_attr_name) + for g in p_g[1::2]: + g_var = block._find_var_recursive(g) + if g_var is None: + continue + grad_vars.append(g_var) + if g_var.persistable: + has_persistable_grad_var = True + + if has_grad_merge and has_persistable_grad_var: + for g_var in grad_vars: + g_var.persistable = True + + def apply_build_strategy(main_program, startup_program, build_strategy, pass_attrs): def update_attr(attrs, attr_types, name, value, typ=None): @@ -43,6 +73,7 @@ def apply_build_strategy(main_program, startup_program, build_strategy, get_data_vars(main_program), "list[str]") _apply_pass(main_program, startup_program, name, attrs, attr_types) + _update_grad_persistable(main_program) use_cuda = pass_attrs.get("use_cuda", False) build_strategy = build_strategy._copy() if build_strategy.sync_batch_norm: @@ -64,9 +95,12 @@ def apply_build_strategy(main_program, startup_program, build_strategy, apply_pass("fuse_elewise_add_act_pass") build_strategy.fuse_elewise_add_act_ops = False if build_strategy.fuse_all_optimizer_ops: - apply_pass("fuse_adam_op_pass") - apply_pass("fuse_sgd_op_pass") - apply_pass("fuse_momentum_op_pass") + apply_pass([ + "coalesce_grad_tensor_pass", + "fuse_adam_op_pass", + "fuse_sgd_op_pass", + "fuse_momentum_op_pass", + ]) build_strategy.fuse_all_optimizer_ops = False # TODO(zjl): support fuse all reduce ops if build_strategy.cache_runtime_context: diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 8b2af328f5..676642a9c9 100755 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -6974,6 +6974,7 @@ class GradientMergeOptimizer(object): # cur_block's forward_block & backward_block is itself cur_block._set_forward_block_idx(cur_block_idx) + op_maker = core.op_proto_and_checker_maker if self.avg: for param, new_grad in new_params_grads: @@ -6987,6 +6988,8 @@ class GradientMergeOptimizer(object): 'bias': 0.0, 'bias_after_scale': False }) + 
new_grad.op._set_attr(op_maker.kOpRoleAttrName(), + op_maker.OpRole.Backward) for param, new_grad in new_params_grads: # NOTE. regularization will append ops to grad.block, @@ -7005,6 +7008,8 @@ class GradientMergeOptimizer(object): dtype=new_grad.dtype, value=0.0, out=new_grad) + new_grad.op._set_attr(op_maker.kOpRoleAttrName(), + op_maker.OpRole.Optimize) # step3. apply gradient layers.cond(cond, true_fn=true_apply_gradient, false_fn=None) diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 00f2d2aa0b..44fd4c04e2 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -12,6 +12,7 @@ endif() string(REPLACE ".py" "" DIST_TEST_OPS "${DIST_TEST_OPS}") list(APPEND DIST_TEST_OPS test_parallel_dygraph_mnist) list(APPEND DIST_TEST_OPS test_pipeline) +list(APPEND DIST_TEST_OPS test_ir_pass_pipeline) list(APPEND DIST_TEST_OPS test_static_model_parallel) list(APPEND DIST_TEST_OPS test_parallel_dygraph_se_resnext) list(APPEND DIST_TEST_OPS test_parallel_dygraph_sparse_embedding) @@ -968,6 +969,7 @@ if((WITH_ROCM OR WITH_GPU) AND NOT WIN32) if(WITH_DISTRIBUTE) set_tests_properties(test_new_group_api PROPERTIES TIMEOUT 120) set_tests_properties(test_pipeline PROPERTIES TIMEOUT 120) + set_tests_properties(test_ir_pass_pipeline PROPERTIES TIMEOUT 120) set_tests_properties(test_static_model_parallel PROPERTIES TIMEOUT 240) set_tests_properties(test_collective_split_embedding test_collective_split_embedding_none_divisible diff --git a/python/paddle/fluid/tests/unittests/dist_mnist_gradient_merge_raw_optimizer.py b/python/paddle/fluid/tests/unittests/dist_mnist_gradient_merge_raw_optimizer.py new file mode 100644 index 0000000000..ce8b6dec3c --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dist_mnist_gradient_merge_raw_optimizer.py @@ -0,0 +1,79 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import paddle +import paddle.nn as nn +import paddle.fluid as fluid +import paddle.distributed.fleet as fleet +import numpy as np +from test_dist_base import TestDistRunnerBase, runtime_main +from dist_mnist import cnn_model + + +class TestDistMnistGradientMergeRawOptimizer(TestDistRunnerBase): + def get_model(self, batch_size=2, single_device=False): + paddle.enable_static() + paddle.seed(1) + np.random.seed(1) + + assert fluid.core.globals()['FLAGS_apply_pass_to_program'] + strategy = fleet.DistributedStrategy() + build_strategy = paddle.static.BuildStrategy() + settings = { + "fuse_relu_depthwise_conv": True, + "fuse_bn_act_ops": True, + "fuse_bn_add_act_ops": True, + "fuse_elewise_add_act_ops": True, + "fuse_all_optimizer_ops": True, + "enable_addto": True, + "enable_inplace": True, + } + for k, v in settings.items(): + setattr(build_strategy, k, v) + strategy.build_strategy = build_strategy + + strategy.gradient_merge = True + strategy.gradient_merge_configs = { + "k_steps": 2, + "avg": False, + } + strategy.without_graph_optimization = True + + fleet.init(is_collective=True, strategy=strategy) + image = paddle.static.data( + name='image', shape=[None, 1, 28, 28], dtype="float32") + label = paddle.static.data(name='label', shape=[None, 1], dtype='int64') + predict = cnn_model(image) + acc = paddle.metric.accuracy(predict, label) + loss_fn = nn.CrossEntropyLoss(use_softmax=False) + cost = loss_fn(predict, label) + test_program = paddle.static.default_main_program().clone(for_test=True) + optimizer = paddle.optimizer.Adam(learning_rate=1e-3) + if single_device: + optimizer = fluid.optimizer.GradientMergeOptimizer( + optimizer, + k_steps=strategy.gradient_merge_configs["k_steps"], + avg=strategy.gradient_merge_configs["avg"]) + else: + optimizer = fleet.distributed_optimizer(optimizer) + optimizer.minimize(cost) + train_reader = paddle.batch( + paddle.dataset.mnist.test(), batch_size=batch_size) + test_reader = paddle.batch( + paddle.dataset.mnist.test(), batch_size=batch_size) + return test_program, cost, train_reader, test_reader, acc, predict + + +if __name__ == "__main__": + runtime_main(TestDistMnistGradientMergeRawOptimizer) diff --git a/python/paddle/fluid/tests/unittests/test_dist_mnist_gradient_merge.py b/python/paddle/fluid/tests/unittests/test_dist_mnist_gradient_merge.py index a5610caa52..e10e4fd09e 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_mnist_gradient_merge.py +++ b/python/paddle/fluid/tests/unittests/test_dist_mnist_gradient_merge.py @@ -16,6 +16,7 @@ from __future__ import print_function import os import unittest from test_dist_base import TestDistBase +import paddle.fluid as fluid flag_name = os.path.splitext(__file__)[0] @@ -27,7 +28,6 @@ class TestDistMnistGradMerge(TestDistBase): self._nccl2_mode = True def test_dist_train(self): - import paddle.fluid as fluid if fluid.core.is_compiled_with_cuda(): self.check_with_place( "dist_mnist_gradient_merge.py", @@ -44,7 +44,6 @@ class TestDistMnistGradMergeNoFuse(TestDistBase): self._fuse_all_reduce = False def test_dist_train(self): - import paddle.fluid as fluid if fluid.core.is_compiled_with_cuda(): self.check_with_place( "dist_mnist_gradient_merge.py", @@ -53,5 +52,22 @@ class TestDistMnistGradMergeNoFuse(TestDistBase): log_name=flag_name + "_no_fuse") +class TestDistMnistGradMergeRawOptimizer(TestDistBase): + def _setup_config(self): + self._use_reader_alloc = False + self._nccl2_mode = True + self._use_fleet_api = True + self._use_fleet_api_20 = True + + def test_dist_train(self): + if 
fluid.core.is_compiled_with_cuda(): + self.check_with_place( + "dist_mnist_gradient_merge_raw_optimizer.py", + delta=1e-5, + check_error_log=True, + log_name=flag_name + "_raw_optimizer", + need_envs={'FLAGS_apply_pass_to_program': '1'}) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_ir_pass_pipeline.py b/python/paddle/fluid/tests/unittests/test_ir_pass_pipeline.py new file mode 100644 index 0000000000..7d11c03a1f --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_ir_pass_pipeline.py @@ -0,0 +1,25 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import test_pipeline + + +class TestPipelineWithIRPass(test_pipeline.TestPipeline): + def need_envs(self): + return {'FLAGS_apply_pass_to_program': '1'} + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_pipeline.py b/python/paddle/fluid/tests/unittests/test_pipeline.py index 1be10113a5..8f46119d55 100644 --- a/python/paddle/fluid/tests/unittests/test_pipeline.py +++ b/python/paddle/fluid/tests/unittests/test_pipeline.py @@ -18,6 +18,7 @@ from test_dist_base import TestDistBase import os import paddle +import paddle.fluid as fluid paddle.enable_static() flag_name = os.path.splitext(__file__)[0] @@ -31,8 +32,10 @@ class TestPipeline(TestDistBase): self._pipeline_mode = True self._nccl_comm_num = 1 + def need_envs(self): + return {} + def test_dist_train(self): - import paddle.fluid as fluid if fluid.core.is_compiled_with_cuda(): # TODO (sandyhouse) fix the delta value. # Now pipeline only gets the loss value of the last @@ -42,24 +45,25 @@ class TestPipeline(TestDistBase): "pipeline_mnist.py", delta=1e0, check_error_log=True, - log_name=flag_name) + log_name=flag_name, + need_envs=self.need_envs()) def test_dist_train_multi_device(self): - import paddle.fluid as fluid if fluid.core.is_compiled_with_cuda(): self.check_with_place( "pipeline_mnist_multi_device.py", check_error_log=True, delta=1e0, - log_name=flag_name) + log_name=flag_name, + need_envs=self.need_envs()) def test_dist_train_one_device(self): - import paddle.fluid as fluid if fluid.core.is_compiled_with_cuda(): self.check_with_place( "pipeline_mnist_one_device.py", check_error_log=True, - log_name=flag_name) + log_name=flag_name, + need_envs=self.need_envs()) if __name__ == '__main__': -- GitLab
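
For reviewers who want to try the feature end to end, below is a minimal usage sketch. It is not part of the patch; it is distilled from the new unit test `dist_mnist_gradient_merge_raw_optimizer.py` and from `apply_ir_passes()` in `fleet_base.py`. It assumes a single collective trainer launched through fleet, and `build_model()` is a hypothetical helper that builds the static graph and returns the loss. With `FLAGS_apply_pass_to_program` enabled, `fleet.distributed_optimizer(...).minimize(...)` routes through `apply_ir_passes()`, which calls `paddle.fluid.ir.apply_build_strategy` and ultimately `Pass::ApplyPassesToProgram` on the main/startup programs.

```python
# Sketch only (not from the patch): enabling the new FLAGS_apply_pass_to_program
# path so that build-strategy IR passes are applied to the Program when the
# Fleet APIs are used. `build_model` is a placeholder, not a Paddle API.
import os
# The flag is read when paddle bootstraps, so set it before importing paddle
# (the unit tests pass it via need_envs={'FLAGS_apply_pass_to_program': '1'}).
os.environ['FLAGS_apply_pass_to_program'] = '1'

import paddle
import paddle.distributed.fleet as fleet

paddle.enable_static()

strategy = fleet.DistributedStrategy()
build_strategy = paddle.static.BuildStrategy()
build_strategy.fuse_elewise_add_act_ops = True
build_strategy.fuse_all_optimizer_ops = True
strategy.build_strategy = build_strategy

# Use the RawProgramOptimizer path; gradient merge is optional and now sits on
# its white list, with allreduce ops inserted into the gradient-merge block.
strategy.without_graph_optimization = True
strategy.gradient_merge = True
strategy.gradient_merge_configs = {"k_steps": 2, "avg": False}

fleet.init(is_collective=True, strategy=strategy)

loss = build_model()  # placeholder: builds the network, returns a scalar loss
opt = paddle.optimizer.Adam(learning_rate=1e-3)
opt = fleet.distributed_optimizer(opt)
# minimize() now triggers apply_ir_passes() on the main/startup programs.
opt.minimize(loss)
```

With the flag left at its default (`false`), nothing changes: the passes are still applied at graph/executor level as before, which is why the new tests (`test_ir_pass_pipeline`, `TestDistMnistGradMergeRawOptimizer`) only differ from the existing ones in the environment they set.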