From 3508bd28e55d8ce6ad79231fb7dde047c4296cd1 Mon Sep 17 00:00:00 2001
From: dyning
Date: Thu, 8 Jul 2021 21:20:26 +0800
Subject: [PATCH] Add the op def for elementwise_mul and enhance
 layer_norm_fuse_pass (#33560)

---
 .../framework/ir/layer_norm_fuse_pass.cc      | 126 ++++++++++++++++++
 .../fluid/framework/ir/layer_norm_fuse_pass.h |   1 +
 .../ir/layer_norm_fuse_pass_tester.cc         |  49 ++++---
 3 files changed, 158 insertions(+), 18 deletions(-)

diff --git a/paddle/fluid/framework/ir/layer_norm_fuse_pass.cc b/paddle/fluid/framework/ir/layer_norm_fuse_pass.cc
index 18d2e9817eb..95d55834f82 100644
--- a/paddle/fluid/framework/ir/layer_norm_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/layer_norm_fuse_pass.cc
@@ -99,6 +99,122 @@ void addIntermediateOut(Node* op_node, const std::string& out_name,
 
 }  // namespace
 
+LayerNormFusePass::LayerNormFusePass() {
+  AddOpCompat(OpCompat("layer_norm"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Scale")
+      .IsTensor()
+      .End()
+      .AddInput("Bias")
+      .IsTensor()
+      .End()
+      .AddOutput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Mean")
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddOutput("Variance")
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddAttr("epsilon")
+      .IsNumGE(0.0f)
+      .IsNumLE(0.001f)
+      .End()
+      .AddAttr("begin_norm_axis")
+      .IsNumGT(0)
+      .End();
+  AddOpCompat(OpCompat("reduce_mean"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("dim")
+      .IsType<std::vector<int>>()
+      .End()
+      .AddAttr("keep_dim")
+      .IsBoolEQ(true)
+      .End();
+  AddOpCompat(OpCompat("sqrt"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End();
+  AddOpCompat(OpCompat("elementwise_sub"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("axis")
+      .IsNumEQ(1)
+      .End();
+  AddOpCompat(OpCompat("elementwise_pow"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("axis")
+      .IsNumEQ(1)
+      .End();
+  AddOpCompat(OpCompat("elementwise_add"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("axis")
+      .IsNumEQ(1)
+      .End();
+  AddOpCompat(OpCompat("elementwise_div"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("axis")
+      .IsNumEQ(1)
+      .End();
+  AddOpCompat(OpCompat("elementwise_mul"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("axis")
+      .IsNumEQ(1)
+      .End();
+}
+
 void LayerNormFusePass::ApplyImpl(Graph* graph) const {
   PADDLE_ENFORCE_NOT_NULL(graph, platform::errors::InvalidArgument(
@@ -117,6 +233,10 @@ void LayerNormFusePass::ApplyImpl(Graph* graph) const {
   int found_layer_norm_count = 0;
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
+    if (!IsCompat(subgraph, g)) {
+      LOG(WARNING) << "Pass in op compat failed.";
+      return;
+    }
     VLOG(4) << "Fuse LayerNorm from subgraph.";
     GET_IR_NODE_FROM_SUBGRAPH(x, x, layer_norm_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(x_mean, x_mean, layer_norm_pattern);
@@ -205,6 +325,12 @@ void LayerNormFusePass::ApplyImpl(Graph* graph) const {
     ln_op_desc.SetAttr("begin_norm_axis", static_cast<int>(x_shape.size() - 1));
     ln_op_desc.SetAttr("epsilon", *(eps_tensor->data<float>()));
     ln_op_desc.SetAttr("is_test", true);
+
+    if (!IsCompat(ln_op_desc)) {
+      LOG(WARNING) << "layer norm pass in out layer_norm op compat failed.";
+      return;
+    }
+
     Node* ln_op = g->CreateOpNode(&ln_op_desc);
 
     addIntermediateOut(ln_op, "Mean", scope_name_, g);
diff --git a/paddle/fluid/framework/ir/layer_norm_fuse_pass.h b/paddle/fluid/framework/ir/layer_norm_fuse_pass.h
index 29a6f127065..a9d49ea012d 100644
--- a/paddle/fluid/framework/ir/layer_norm_fuse_pass.h
+++ b/paddle/fluid/framework/ir/layer_norm_fuse_pass.h
@@ -70,6 +70,7 @@ namespace ir {
  */
 class LayerNormFusePass : public FusePassBase {
  public:
+  LayerNormFusePass();
   virtual ~LayerNormFusePass() {}
 
  protected:
diff --git a/paddle/fluid/framework/ir/layer_norm_fuse_pass_tester.cc b/paddle/fluid/framework/ir/layer_norm_fuse_pass_tester.cc
index 5fe71fbc214..accfe8920a8 100644
--- a/paddle/fluid/framework/ir/layer_norm_fuse_pass_tester.cc
+++ b/paddle/fluid/framework/ir/layer_norm_fuse_pass_tester.cc
@@ -66,12 +66,16 @@ class LayerNormFuseTest {
     x_mean->SetAttr("keep_dim", true);
     x_mean->SetAttr("reduce_all", false);
 
-    test::CreateOp(&m_prog, "elementwise_sub",
-                   {{"X", "x"}, {"Y", "x_mean_out"}},
-                   {{"Out", "x_sub_mean_out"}}, false);
-    test::CreateOp(&m_prog, "elementwise_pow",
-                   {{"X", "x_sub_mean_out"}, {"Y", "sqr_pow"}},
-                   {{"Out", "x_sub_mean_sqr_out"}}, false);
+    auto* x_sub = test::CreateOp(&m_prog, "elementwise_sub",
+                                 {{"X", "x"}, {"Y", "x_mean_out"}},
+                                 {{"Out", "x_sub_mean_out"}}, false);
+    x_sub->SetAttr("axis", 1);
+
+    auto* x_pow = test::CreateOp(&m_prog, "elementwise_pow",
+                                 {{"X", "x_sub_mean_out"}, {"Y", "sqr_pow"}},
+                                 {{"Out", "x_sub_mean_sqr_out"}}, false);
+    x_pow->SetAttr("axis", 1);
+
     auto* std_dev = test::CreateOp(&m_prog, "reduce_mean",
                                    {{"X", "x_sub_mean_sqr_out"}},
                                    {{"Out", "std_dev_out"}}, false);
@@ -79,20 +83,29 @@ class LayerNormFuseTest {
     std_dev->SetAttr("keep_dim", true);
     std_dev->SetAttr("reduce_all", false);
 
-    test::CreateOp(&m_prog, "elementwise_add",
-                   {{"X", "std_dev_out"}, {"Y", "eps"}},
-                   {{"Out", "std_dev_eps_out"}}, false);
+    auto* x_add = test::CreateOp(&m_prog, "elementwise_add",
+                                 {{"X", "std_dev_out"}, {"Y", "eps"}},
+                                 {{"Out", "std_dev_eps_out"}}, false);
+    x_add->SetAttr("axis", 1);
+
     test::CreateOp(&m_prog, "sqrt", {{"X", "std_dev_eps_out"}},
                    {{"Out", "std_dev_eps_sqrt_out"}}, false);
-    test::CreateOp(&m_prog, "elementwise_div",
-                   {{"X", "x_sub_mean_out"}, {"Y", "std_dev_eps_sqrt_out"}},
-                   {{"Out", "division_out"}}, false);
-    test::CreateOp(&m_prog, "elementwise_mul",
-                   {{"X", "division_out"}, {"Y", "gamma"}},
-                   {{"Out", "scale_out"}}, false);
-    test::CreateOp(&m_prog, "elementwise_add",
-                   {{"X", "scale_out"}, {"Y", "beta"}}, {{"Out", "shift_out"}},
-                   false);
+
+    auto* x_div =
+        test::CreateOp(&m_prog, "elementwise_div",
+                       {{"X", "x_sub_mean_out"}, {"Y", "std_dev_eps_sqrt_out"}},
+                       {{"Out", "division_out"}}, false);
+    x_div->SetAttr("axis", 1);
+
+    auto* x_mul = test::CreateOp(&m_prog, "elementwise_mul",
+                                 {{"X", "division_out"}, {"Y", "gamma"}},
+                                 {{"Out", "scale_out"}}, false);
+    x_mul->SetAttr("axis", 1);
+
+    auto* x_add_v1 = test::CreateOp(&m_prog, "elementwise_add",
+                                    {{"X", "scale_out"}, {"Y", "beta"}},
+                                    {{"Out", "shift_out"}}, false);
+    x_add_v1->SetAttr("axis", 1);
   }
 
   template
-- 
GitLab
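
A minimal usage sketch, appended below the "-- " trailer (text here is
ignored by git am) and not part of the patch itself. It shows how the new
constructor's OpCompat checks come into play when the pass runs: Apply()
now skips, with a warning, any matched subgraph whose ops fail the checks,
e.g. an elementwise op created without axis == 1 or an epsilon outside
[0.0f, 0.001f]. The helper name RunLayerNormFuse is hypothetical; the
PassRegistry lookup and Graph construction are the standard Paddle ir APIs.

    #include <memory>

    #include "paddle/fluid/framework/ir/graph.h"
    #include "paddle/fluid/framework/ir/pass.h"
    #include "paddle/fluid/framework/program_desc.h"

    namespace fw = paddle::framework;

    // Builds a graph from a program and runs the pass registered as
    // "layer_norm_fuse_pass". Note: the pass reads the eps/gamma/beta
    // tensors through FusePassBase's parameter scope, so a real run also
    // needs a scope attached to the graph beforehand (an assumption based
    // on how fuse-pass testers drive these passes).
    std::unique_ptr<fw::ir::Graph> RunLayerNormFuse(const fw::ProgramDesc& prog) {
      auto graph = std::make_unique<fw::ir::Graph>(prog);
      auto pass = fw::ir::PassRegistry::Instance().Get("layer_norm_fuse_pass");
      graph.reset(pass->Apply(graph.release()));
      return graph;
    }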