From ca552933503f9c4a7f9c36099504a975f071832d Mon Sep 17 00:00:00 2001 From: zhangbo9674 <82555433+zhangbo9674@users.noreply.github.com> Date: Wed, 30 Nov 2022 11:27:40 +0800 Subject: [PATCH] Add fuse_act_add_grad_pass (#48346) * add fuse act add grad pass * polish code * refine code * add test * refine code --- .../framework/ir/fuse_elewise_add_act_pass.cc | 129 +++++++++++++++++- .../framework/ir/fuse_elewise_add_act_pass.h | 8 ++ .../framework/ir/graph_pattern_detector.cc | 29 +++- .../framework/ir/graph_pattern_detector.h | 21 +++ .../new_executor/interpreter/data_transfer.cc | 1 - .../fused/fused_elemwise_activation_op.h | 25 +++- .../test_fuse_elewise_add_act_pass.py | 71 +++++++++- 7 files changed, 272 insertions(+), 12 deletions(-) diff --git a/paddle/fluid/framework/ir/fuse_elewise_add_act_pass.cc b/paddle/fluid/framework/ir/fuse_elewise_add_act_pass.cc index 67aa5a822e..b6faf76f11 100644 --- a/paddle/fluid/framework/ir/fuse_elewise_add_act_pass.cc +++ b/paddle/fluid/framework/ir/fuse_elewise_add_act_pass.cc @@ -31,6 +31,7 @@ void FuseElewiseAddActPass::ApplyImpl(ir::Graph *graph) const { { std::unordered_set in_place_act_types = {"relu_grad"}; graph = FuseElewiseAddActInplaceGrad(graph, in_place_act_types); + graph = FuseActElewiseAddInplaceGrad(graph, in_place_act_types); } // Remove the removable intermediate_out. @@ -110,7 +111,7 @@ ir::Graph *FuseElewiseAddActPass::FuseActElewiseAdd( auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, Graph *g) { - VLOG(4) << "handle FuseElewiseAddAct fuse"; + VLOG(4) << "handle FuseActElewiseAdd fuse"; GET_IR_NODE_FROM_SUBGRAPH(act_out, act_out, act_elewise_add_pattern); GET_IR_NODE_FROM_SUBGRAPH(ele_x, ele_x, act_elewise_add_pattern); GET_IR_NODE_FROM_SUBGRAPH( @@ -220,6 +221,86 @@ ir::Graph *FuseElewiseAddActPass::FuseElewiseAddActInplaceGrad( return graph; } +// the backward of act(ele_add(x,y)) +// act_grad: in["Out", "Out@GRAD"], out["X@GRAD"] +// ele_add_grad: in["Y", "Out@GRAD"], out["X@GRAD", "Y@GRAD"] +ir::Graph *FuseElewiseAddActPass::FuseActElewiseAddInplaceGrad( + ir::Graph *graph, const std::unordered_set &act_types) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); + FusePassBase::Init("act_elewise_add_grad", graph); + GraphPatternDetector gpd; + auto *d_out_var = + gpd.mutable_pattern() + ->NewNode("act_elewise_add_grad_inplace/d_out_var") + ->AsInput() + ->assert_is_ops_input({"elementwise_add_grad"}, GradVarName("Out")); + patterns::ActElewiseAddInplaceGrad act_elewise_add_grad_pattern( + gpd.mutable_pattern(), "act_elewise_add_grad_inplace"); + act_elewise_add_grad_pattern(d_out_var, act_types); + + int found_elewise_add_act_count = 0; + + auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, + Graph *g) { + VLOG(4) << "handle ActFuseElewiseAddGrad1 fuse"; + + GET_IR_NODE_FROM_SUBGRAPH( + ele_add_grad_op, ele_add_grad_op, act_elewise_add_grad_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + act_grad_op, act_grad_op, act_elewise_add_grad_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + intermediate_var, intermediate_var, act_elewise_add_grad_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + d_intermediate_var, d_intermediate_var, act_elewise_add_grad_pattern); + + std::string d_out_var_n = subgraph.at(d_out_var)->Name(); + std::string intermediate_var_n = intermediate_var->Name(); + std::string d_intermediate_var_n = d_intermediate_var->Name(); + + OpDesc desc; + desc.SetType("fused_elemwise_add_activation_grad"); + desc.SetInput("IntermediateOut", + std::vector({intermediate_var_n})); + desc.SetInput("X", {}); + desc.SetInput("Y", ele_add_grad_op->Op()->Input("X")); + desc.SetInput("Out", {}); + desc.SetInput(GradVarName("Out"), std::vector({d_out_var_n})); + desc.SetOutput(GradVarName("X"), + act_grad_op->Op()->Output(GradVarName("X"))); + desc.SetOutput(GradVarName("Y"), + ele_add_grad_op->Op()->Output(GradVarName("X"))); + desc.SetOutput(GradVarName("IntermediateOut"), + std::vector({d_intermediate_var_n})); + + desc.SetAttr("save_intermediate_out", false); + desc.SetAttr("functor_list", + std::vector({ele_add_grad_op->Op()->Type(), + act_grad_op->Op()->Type()})); + + for (auto &n : {ele_add_grad_op->Op(), act_grad_op->Op()}) { + for (auto &m_ele : n->GetAttrMap()) { + desc.SetAttr(m_ele.first, m_ele.second); + } + } + + auto fused_node = g->CreateOpNode(&desc); + + VLOG(4) << "\n\t " << d_out_var_n << " -> " << ele_add_grad_op->Name() + << " -> " << d_intermediate_var_n << "\n\t " << intermediate_var_n + << " and " << d_intermediate_var_n << " -> " << act_grad_op->Name(); + + ReLinkNodes2( + g, d_intermediate_var, ele_add_grad_op, act_grad_op, fused_node); + found_elewise_add_act_count++; + }; + + gpd(graph, handler); + + AddStatis(found_elewise_add_act_count); + return graph; +} + Node *FuseElewiseAddActPass::CreateFuseElewiseAddActNode( Graph *g, const Node *op_1, @@ -364,6 +445,52 @@ void FuseElewiseAddActPass::ReLinkNodes(Graph *graph, GraphSafeRemoveNodes(graph, nodes2delete); } +void FuseElewiseAddActPass::ReLinkNodes2(Graph *graph, + const Node *intermediate_out, + Node *op_1, + Node *op_2, + Node *fused_op) const { // delete act + for (auto &in : op_1->inputs) { + fused_op->inputs.emplace_back(in); + in->outputs = this->ReplaceNode(op_1, fused_op, in->outputs); + } + + std::unordered_set nodes2delete; + for (auto &out : op_1->outputs) { + if (out->IsCtrlVar()) { + auto result_iter = std::find_if( + op_2->inputs.begin(), + op_2->inputs.end(), + [&out](const Node *node) -> bool { return node == out; }); + + if (result_iter == op_2->inputs.end()) { + IR_OP_VAR_LINK(fused_op, out); + } else { + nodes2delete.emplace(out); + } + } else { + IR_OP_VAR_LINK(fused_op, out); + } + } + + for (auto &in : op_2->inputs) { + if (in == intermediate_out || nodes2delete.count(in)) { + continue; + } + fused_op->inputs.emplace_back(in); + in->outputs = this->ReplaceNode(op_2, fused_op, in->outputs); + } + + for (auto &out : op_2->outputs) { + IR_OP_VAR_LINK(fused_op, out); + } + + nodes2delete.insert(std::move(op_1)); + nodes2delete.insert(std::move(op_2)); + + GraphSafeRemoveNodes(graph, nodes2delete); +} + std::vector FuseElewiseAddActPass::ReplaceNode( Node *cur_node, Node *new_node, const std::vector &nodes) const { std::vector new_list(nodes.size()); diff --git a/paddle/fluid/framework/ir/fuse_elewise_add_act_pass.h b/paddle/fluid/framework/ir/fuse_elewise_add_act_pass.h index d9b0ec928a..c608bb5845 100644 --- a/paddle/fluid/framework/ir/fuse_elewise_add_act_pass.h +++ b/paddle/fluid/framework/ir/fuse_elewise_add_act_pass.h @@ -49,6 +49,9 @@ class FuseElewiseAddActPass : public FusePassBase { ir::Graph *FuseElewiseAddActInplaceGrad( ir::Graph *graph, const std::unordered_set &act_types) const; + ir::Graph *FuseActElewiseAddInplaceGrad( + ir::Graph *graph, const std::unordered_set &act_types) const; + /** * Remove the removable intermediate_out. * - If the intermediate_out is only used by the backward op, but the @@ -69,6 +72,11 @@ class FuseElewiseAddActPass : public FusePassBase { Node *op_1, Node *op_2, Node *fused_op) const; + void ReLinkNodes2(Graph *graph, + const Node *intermediate_out, + Node *op_1, + Node *op_2, + Node *fused_op) const; Node *CreateFuseElewiseAddActNode(Graph *g, const Node *op_1, const Node *op_2, diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 753c169f8f..acbaef67a6 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -91,7 +91,6 @@ void GraphPatternDetector::operator()(Graph *graph, if (!MarkPDNodesInGraph(*graph)) { return; } - auto subgraphs = DetectPatterns(); UniquePatterns(&subgraphs); SortSubgraphs(&subgraphs); @@ -99,7 +98,6 @@ void GraphPatternDetector::operator()(Graph *graph, ValidateByNodeRole(&subgraphs); if (subgraphs.empty()) return; - int id = 0; for (auto &g : subgraphs) { VLOG(3) << "optimizing #" << id++ << " subgraph"; @@ -1613,6 +1611,33 @@ PDNode *patterns::ElewiseAddActInplaceGrad::operator()( return ele_add_grad; } +PDNode *patterns::ActElewiseAddInplaceGrad::operator()( + paddle::framework::ir::PDNode *d_out_var, + std::unordered_set act_types) { + VLOG(4) << "ActElewiseAddInplaceGrad::operator"; + + auto *ele_add_grad_op = pattern->NewNode(ele_add_grad_op_repr()) + ->assert_is_op("elementwise_add_grad"); + auto *act_grad_op = + pattern->NewNode(act_grad_op_repr())->assert_is_ops(act_types); + + auto *d_intermediate_out_var = + pattern->NewNode(d_intermediate_var_repr()) + ->assert_is_op_output("elementwise_add_grad", GradVarName("Y")) + ->assert_is_ops_input(act_types, GradVarName("Out")); + auto *intermediate_out_var = + pattern->NewNode(intermediate_var_repr()) + ->assert_is_op_input("elementwise_add_grad", "Y") + ->assert_is_ops_input(act_types, "Out"); + + ele_add_grad_op->LinksFrom({d_out_var}); + d_intermediate_out_var->LinksFrom({ele_add_grad_op}).LinksTo({act_grad_op}); + intermediate_out_var->LinksTo({ele_add_grad_op}); + intermediate_out_var->LinksTo({act_grad_op}); + + return act_grad_op; +} + PDNode *patterns::ElewiseAddAct::operator()( paddle::framework::ir::PDNode *ele_x_var, std::unordered_set act_types) { diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index cb1b9266b1..da479c1bf7 100755 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -928,6 +928,27 @@ struct ElewiseAddActInplaceGrad : public PatternBase { PATTERN_DECL_NODE(ele_y); }; +// the backward of ele_add(act(x), y) +// the act is inplace. +// op: elementwise_add_grad + act_grad +// named nodes: elementwise_add_grad, act_grad +// ele_y, d_ele_y, d_intermeiate_out, intermediate_out, d_x +struct ActElewiseAddInplaceGrad : public PatternBase { + ActElewiseAddInplaceGrad(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "act_elewise_add_grad1") {} + + // ele_add_grad: in["Y", "Out@GRAD"], out["IntermediateOut@GRAD", "Y@GRAD"] + // act_grad: in["IntermediateOut", "IntermediateOut@GRAD"], out["X@GRAD"] + PDNode* operator()(PDNode* d_out_var, std::unordered_set acts); + + // declare operator node's name + PATTERN_DECL_NODE(ele_add_grad_op); + PATTERN_DECL_NODE(act_grad_op); + // // declare variable node's name + PATTERN_DECL_NODE(intermediate_var); + PATTERN_DECL_NODE(d_intermediate_var); +}; + // The following patterns are used to fuse linear and act (ReLu or GeLU) // formula: act(F.linear(x)) // op: matmul_v2 + elementwise_add + act diff --git a/paddle/fluid/framework/new_executor/interpreter/data_transfer.cc b/paddle/fluid/framework/new_executor/interpreter/data_transfer.cc index 8f9209f6a9..f2882eaf59 100644 --- a/paddle/fluid/framework/new_executor/interpreter/data_transfer.cc +++ b/paddle/fluid/framework/new_executor/interpreter/data_transfer.cc @@ -462,7 +462,6 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key, for (auto& var_name_item : *ins_map_temp) { bool should_skip_input = no_buffer_ins && no_buffer_ins->count(var_name_item.first) > 0; - for (size_t i = 0; i < var_name_item.second.size(); ++i) { auto var = var_name_item.second[i]; auto var_name = new_ins[var_name_item.first].at(i); diff --git a/paddle/fluid/operators/fused/fused_elemwise_activation_op.h b/paddle/fluid/operators/fused/fused_elemwise_activation_op.h index 0d6a5e3b40..50d2057dbd 100644 --- a/paddle/fluid/operators/fused/fused_elemwise_activation_op.h +++ b/paddle/fluid/operators/fused/fused_elemwise_activation_op.h @@ -664,11 +664,9 @@ class FusedElemwiseActivationGradKernel : public framework::OpKernel { in_y, nullptr, platform::errors::InvalidArgument("Input(Y) should not be nullptr.")); - auto in_out = ctx.Input("Out"); - PADDLE_ENFORCE_NE( - in_out, - nullptr, - platform::errors::InvalidArgument("Input(Out) should not be nullptr.")); + phi::DenseTensor *in_out = + const_cast(ctx.Input("Out")); + auto in_out_grad = ctx.Input(framework::GradVarName("Out")); PADDLE_ENFORCE_NE(in_out_grad, @@ -726,6 +724,23 @@ class FusedElemwiseActivationGradKernel : public framework::OpKernel { in_x = const_cast(in_out_grad); } + // Get in_Out + if (ctx.HasInput("Out")) { + PADDLE_ENFORCE_NE( + in_out, + nullptr, + platform::errors::InvalidArgument("Input(X) should not be null.")); + } else { + // If functor_list contains elementwise_add, the backward doesn't use + // in_x, in_y and in_out. + PADDLE_ENFORCE_EQ(InputXCanBeAbsent(functor_list), + true, + platform::errors::InvalidArgument( + "Only when the compoundfunctor contains " + "elementwise_add_grad, the 'X' could be absent.")); + in_out = const_cast(in_out_grad); + } + bool has_in_place = HasInPlaceUnary(functor_list); if (has_in_place) { RunGradFunctors(ctx, diff --git a/python/paddle/fluid/tests/unittests/test_fuse_elewise_add_act_pass.py b/python/paddle/fluid/tests/unittests/test_fuse_elewise_add_act_pass.py index 9a33552f31..6f3bc21e4b 100644 --- a/python/paddle/fluid/tests/unittests/test_fuse_elewise_add_act_pass.py +++ b/python/paddle/fluid/tests/unittests/test_fuse_elewise_add_act_pass.py @@ -14,10 +14,11 @@ import os import unittest +import numpy from parallel_executor_test_base import DeviceType, TestParallelExecutorBase from simple_nets import fc_with_batchnorm, init_data, simple_fc_net - +import paddle import paddle.fluid as fluid import paddle.fluid.core as core @@ -89,8 +90,72 @@ class TestMNIST(TestParallelExecutorBase): ) -if __name__ == '__main__': - import paddle +class TestFuseActElewiseAddInplaceGradPass(unittest.TestCase): + def build_program(self, main_program, startup_program): + with paddle.static.program_guard(main_program, startup_program): + X = fluid.data(name="X", shape=[3, 3], dtype='float32') + Y = fluid.data(name="Y", shape=[3, 3], dtype='float32') + Out1 = X * 5 + Out2 = fluid.layers.relu(Out1) + prediction = fluid.layers.elementwise_add(Y, Out2, axis=1) + loss = paddle.mean(prediction) + sgd = fluid.optimizer.SGD(learning_rate=0.001) + sgd.minimize(loss) + return X, Y, loss + + def check(self, place): + paddle.seed(1) + numpy.random.seed(1) + paddle.framework.random._manual_program_seed(1) + main_program = fluid.Program() + startup_program = fluid.Program() + X, Y, loss = self.build_program(main_program, startup_program) + exe = fluid.Executor(place) + + x = numpy.random.random(size=(3, 3)).astype('float32') + y = numpy.random.random(size=(3, 3)).astype('float32') + label = numpy.random.random(size=(3, 3)).astype('float32') + # open fused_pass + build_strategy = fluid.BuildStrategy() + build_strategy.fuse_elewise_add_act_ops = True + compiled_prog_fused = paddle.static.CompiledProgram( + main_program, build_strategy=build_strategy + ) + scope = fluid.Scope() + with fluid.scope_guard(scope): + exe.run(startup_program) + loss_data_fused = exe.run( + compiled_prog_fused, + feed={"X": x, "Y": y}, + fetch_list=[loss.name], + ) + + # close fused_pass + build_strategy = fluid.BuildStrategy() + build_strategy.fuse_elewise_add_act_ops = False + compiled_prog = paddle.static.CompiledProgram( + main_program, build_strategy=build_strategy + ) + scope = fluid.Scope() + with fluid.scope_guard(scope): + exe.run(startup_program) + loss_data = exe.run( + compiled_prog, feed={"X": x, "Y": y}, fetch_list=[loss.name] + ) + + self.assertEqual(loss_data_fused, loss_data) + + def test_fuse_act_add_grad_pass_cpu(self): + place = fluid.CPUPlace() + self.check(place) + + def test_fuse_act_add_grad_pass_cuda(self): + if fluid.core.is_compiled_with_cuda(): + place = fluid.CUDAPlace(0) + self.check(place) + + +if __name__ == '__main__': paddle.enable_static() unittest.main() -- GitLab