From dc4b48f629404d31785816d90c34e471b5bce8ed Mon Sep 17 00:00:00 2001 From: wz1qqx <55830058+wz1qqx@users.noreply.github.com> Date: Thu, 3 Aug 2023 00:59:38 -0700 Subject: [PATCH] eliminate small pattern (#55843) --- paddle/fluid/framework/ir/CMakeLists.txt | 4 +- .../ir/xpu/add_layernorm_xpu_fuse_pass.cc | 49 ++- .../framework/ir/xpu/reduce_ops_fuse_pass.cc | 12 +- .../redundant_onnx_ops_elimination_pass.cc | 209 ----------- ...dant_unsqueeze_squeeze_elimination_pass.cc | 330 ++++++++++++++++++ ...dant_unsqueeze_squeeze_elimination_pass.h} | 50 +-- .../inference/api/paddle_pass_builder.cc | 2 +- paddle/phi/api/yaml/fused_ops.yaml | 2 +- paddle/phi/backends/xpu/xpu2_op_list.cc | 3 +- paddle/phi/infermeta/fusion.cc | 11 +- paddle/phi/infermeta/fusion.h | 3 +- .../fusion/xpu/add_layernorm_xpu_kernel.cc | 64 +++- 12 files changed, 455 insertions(+), 284 deletions(-) delete mode 100644 paddle/fluid/framework/ir/xpu/redundant_onnx_ops_elimination_pass.cc create mode 100644 paddle/fluid/framework/ir/xpu/redundant_unsqueeze_squeeze_elimination_pass.cc rename paddle/fluid/framework/ir/xpu/{redundant_onnx_ops_elimination_pass.h => redundant_unsqueeze_squeeze_elimination_pass.h} (60%) diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index bbc045271ca..39e676dac85 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -240,8 +240,8 @@ if(WITH_XPU) pass_library(yolo_box_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) pass_library(conv1d_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) pass_library(conv2d_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) - pass_library(redundant_onnx_ops_elimination_pass inference DIR xpu DEPS - ${XPU_PASS_DEPS}) + pass_library(redundant_unsqueeze_squeeze_elimination_pass inference DIR xpu + DEPS ${XPU_PASS_DEPS}) pass_library(redundant_squeeze_unsqueeze_elimination_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) pass_library(conv2d_transpose_xpu_fuse_pass inference DIR xpu DEPS diff --git a/paddle/fluid/framework/ir/xpu/add_layernorm_xpu_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/add_layernorm_xpu_fuse_pass.cc index 698c0b6c033..5e50b762e8c 100644 --- a/paddle/fluid/framework/ir/xpu/add_layernorm_xpu_fuse_pass.cc +++ b/paddle/fluid/framework/ir/xpu/add_layernorm_xpu_fuse_pass.cc @@ -43,17 +43,17 @@ namespace patterns { fuse ele_add + activation block in to xpu_ele_fusion op For example: graph: - ele_x + add_x | - elementwise_add -----ele_y + elementwise_add -----add_y | layernorm | output ------------------------------------------------------ After the pass is applied: - ele_x - | ele_y + add_x + | add_y | / | / scale---- add_layernorm_fusion ---- bias @@ -68,8 +68,8 @@ struct AddLayernormXPUPattern : public PatternBase { PATTERN_DECL_NODE(ele_add); PATTERN_DECL_NODE(l_norm); // declare variable node's name - PATTERN_DECL_NODE(ele_x); - PATTERN_DECL_NODE(ele_y); + PATTERN_DECL_NODE(add_x); + PATTERN_DECL_NODE(add_y); PATTERN_DECL_NODE(ele_out); PATTERN_DECL_NODE(norm_bias); PATTERN_DECL_NODE(norm_scale); @@ -83,17 +83,16 @@ AddLayernormXPUPattern::AddLayernormXPUPattern(PDPattern* pattern, : PatternBase(pattern, name_scope, name_scope) { auto ele_add = pattern->NewNode(ele_add_repr())->assert_is_op("elementwise_add"); - auto ele_x = pattern->NewNode(ele_x_repr()) + auto add_x = pattern->NewNode(add_x_repr()) ->assert_is_op_input("elementwise_add", "X") ->AsInput(); - auto ele_y = pattern->NewNode(ele_y_repr()) + auto add_y = 
pattern->NewNode(add_y_repr()) ->assert_is_op_input("elementwise_add", "Y") ->AsInput(); auto ele_out = pattern->NewNode(ele_out_repr()) ->assert_is_op_output("elementwise_add", "Out") - ->assert_is_op_input("layer_norm", "X") - ->assert_has_n_outputs(1); - ele_add->LinksFrom({ele_x, ele_y}).LinksTo({ele_out}); + ->assert_is_op_input("layer_norm", "X"); + ele_add->LinksFrom({add_x, add_y}).LinksTo({ele_out}); auto l_norm = pattern->NewNode(l_norm_repr())->assert_is_op("layer_norm"); auto norm_bias = pattern->NewNode(norm_bias_repr()) ->AsInput() @@ -169,8 +168,8 @@ void AddLayernormXPUFusePass::FuseAddLayernorm(ir::Graph* graph) const { GET_IR_NODE(ele_add); GET_IR_NODE(l_norm); // declare variable node's name - GET_IR_NODE(ele_x); - GET_IR_NODE(ele_y); + GET_IR_NODE(add_x); + GET_IR_NODE(add_y); GET_IR_NODE(ele_out); GET_IR_NODE(norm_bias); GET_IR_NODE(norm_scale); @@ -178,21 +177,21 @@ void AddLayernormXPUFusePass::FuseAddLayernorm(ir::Graph* graph) const { GET_IR_NODE(norm_variance); GET_IR_NODE(norm_out); - auto* block = ele_add->Op()->Block(); + auto* block = l_norm->Op()->Block(); auto* scope = param_scope(); PADDLE_ENFORCE_NOT_NULL( scope, platform::errors::InvalidArgument("Scope cannot be nullptr.")); + auto x_shape = add_x->Var()->GetShape(); + auto x_rank = x_shape.size(); + auto y_shape = add_y->Var()->GetShape(); + auto y_rank = y_shape.size(); + if (x_rank != y_rank) return; // delete useless node std::unordered_set delete_nodes; float eps = PADDLE_GET_CONST(float, l_norm->Op()->GetAttr("epsilon")); int begin_norm_axis = PADDLE_GET_CONST(int, l_norm->Op()->GetAttr("begin_norm_axis")); - auto layer_norm_x_dims = ele_out->Var()->GetShape(); - auto layer_norm_x_mat_dims = - phi::flatten_to_2d(phi::make_ddim(layer_norm_x_dims), begin_norm_axis); - int64_t m = layer_norm_x_mat_dims[0]; - int64_t n = layer_norm_x_mat_dims[1]; std::string fused_op_out_name; fused_op_out_name = norm_out->Name(); @@ -200,28 +199,26 @@ void AddLayernormXPUFusePass::FuseAddLayernorm(ir::Graph* graph) const { framework::OpDesc fused_op_desc(block); fused_op_desc.SetType("add_layernorm_xpu"); // set attrs for fused op - fused_op_desc.SetInput("x", {ele_x->Name()}); - fused_op_desc.SetInput("y", {ele_y->Name()}); + fused_op_desc.SetInput("x", {add_x->Name()}); + fused_op_desc.SetInput("y", {add_y->Name()}); fused_op_desc.SetInput("scale", {norm_scale->Name()}); fused_op_desc.SetInput("bias", {norm_bias->Name()}); - fused_op_desc.SetAttr("m", m); - fused_op_desc.SetAttr("n", n); fused_op_desc.SetAttr("epsilon", eps); + fused_op_desc.SetAttr("begin_norm_axis", begin_norm_axis); fused_op_desc.SetOutput("out", {fused_op_out_name}); setIntermediateOut(&fused_op_desc, "mean", name_scope_); setIntermediateOut(&fused_op_desc, "variance", name_scope_); setIntermediateOut(&fused_op_desc, "z_add", name_scope_); // relink fused op auto* fused_op = graph->CreateOpNode(&fused_op_desc); - IR_NODE_LINK_TO(ele_x, fused_op); - IR_NODE_LINK_TO(ele_y, fused_op); + IR_NODE_LINK_TO(add_x, fused_op); + IR_NODE_LINK_TO(add_y, fused_op); IR_NODE_LINK_TO(norm_scale, fused_op); IR_NODE_LINK_TO(norm_bias, fused_op); IR_NODE_LINK_TO(fused_op, norm_out); addIntermediateOut(fused_op, "mean", name_scope_, graph); addIntermediateOut(fused_op, "variance", name_scope_, graph); addIntermediateOut(fused_op, "z_add", name_scope_, graph); - delete_nodes.insert({ele_add, l_norm, ele_out, norm_mean, norm_variance}); GraphSafeRemoveNodes(graph, delete_nodes); found_subgraph_count++; diff --git a/paddle/fluid/framework/ir/xpu/reduce_ops_fuse_pass.cc 
b/paddle/fluid/framework/ir/xpu/reduce_ops_fuse_pass.cc index eae3d5d9c17..1738d39f155 100644 --- a/paddle/fluid/framework/ir/xpu/reduce_ops_fuse_pass.cc +++ b/paddle/fluid/framework/ir/xpu/reduce_ops_fuse_pass.cc @@ -88,7 +88,7 @@ ReduceMaxFusePattern::ReduceMaxFusePattern(PDPattern* pattern, auto* op_desc = node->Op(); auto input_var = node->inputs[0]->Var(); auto pool2d_x_shape = input_var->GetShape(); - std::vector HW = {static_cast(pool2d_x_shape[2]), + std::vector hw = {static_cast(pool2d_x_shape[2]), static_cast(pool2d_x_shape[3])}; auto pool_type = op_desc->GetAttrIfExists("pooling_type"); @@ -98,8 +98,8 @@ ReduceMaxFusePattern::ReduceMaxFusePattern(PDPattern* pattern, op_desc->GetAttrIfExists>("strides"); auto paddings_array = op_desc->GetAttrIfExists>("paddings"); - return pool_type == "max" && ksize_array == HW && - strides_array == HW && + return pool_type == "max" && ksize_array == hw && + strides_array == hw && paddings_array == std::vector{0, 0}; }); auto* pool2d_out = pattern->NewNode(pool2d_out_repr()) @@ -181,7 +181,7 @@ ReduceMeanFusePattern::ReduceMeanFusePattern(PDPattern* pattern, auto* op_desc = node->Op(); auto input_var = node->inputs[0]->Var(); auto pool2d_x_shape = input_var->GetShape(); - std::vector HW = {static_cast(pool2d_x_shape[2]), + std::vector hw = {static_cast(pool2d_x_shape[2]), static_cast(pool2d_x_shape[3])}; auto pool_type = op_desc->GetAttrIfExists("pooling_type"); @@ -191,8 +191,8 @@ ReduceMeanFusePattern::ReduceMeanFusePattern(PDPattern* pattern, op_desc->GetAttrIfExists>("strides"); auto paddings_array = op_desc->GetAttrIfExists>("paddings"); - return pool_type == "avg" && ksize_array == HW && - strides_array == HW && + return pool_type == "avg" && ksize_array == hw && + strides_array == hw && paddings_array == std::vector{0, 0}; }); auto* pool2d_out = pattern->NewNode(pool2d_out_repr()) diff --git a/paddle/fluid/framework/ir/xpu/redundant_onnx_ops_elimination_pass.cc b/paddle/fluid/framework/ir/xpu/redundant_onnx_ops_elimination_pass.cc deleted file mode 100644 index f63c51e36db..00000000000 --- a/paddle/fluid/framework/ir/xpu/redundant_onnx_ops_elimination_pass.cc +++ /dev/null @@ -1,209 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "paddle/fluid/framework/ir/xpu/redundant_onnx_ops_elimination_pass.h" -#include - -#include "glog/logging.h" - -#include "paddle/fluid/framework/ir/graph_pattern_detector.h" -#include "paddle/fluid/framework/ir/xpu/pass_utils.h" -#include "paddle/fluid/framework/op_version_registry.h" -#include "paddle/fluid/platform/enforce.h" - -namespace paddle { -namespace framework { -namespace ir { -namespace patterns { - -struct FoldConv1dSqueeze2Pattern : public PatternBase { - FoldConv1dSqueeze2Pattern(PDPattern* pattern, - const std::string& name_scope, - const std::string& act_type); - - // declare operator node's name - PATTERN_DECL_NODE(squeeze2); - PATTERN_DECL_NODE(bn); - PATTERN_DECL_NODE(act); - PATTERN_DECL_NODE(unsqueeze2); - // declare variable node's name - PATTERN_DECL_NODE(x); - PATTERN_DECL_NODE(squeeze2_out); - PATTERN_DECL_NODE(bn_bias); - PATTERN_DECL_NODE(bn_mean); - PATTERN_DECL_NODE(bn_scale); - PATTERN_DECL_NODE(bn_var); - PATTERN_DECL_NODE(bn_out); - PATTERN_DECL_NODE(bn_mean_out); - PATTERN_DECL_NODE(bn_saved_mean); - PATTERN_DECL_NODE(bn_saved_var); - PATTERN_DECL_NODE(bn_var_out); - PATTERN_DECL_NODE(act_out); - PATTERN_DECL_NODE(unsqueeze2_out); - - private: - std::string act_type_; -}; - -FoldConv1dSqueeze2Pattern::FoldConv1dSqueeze2Pattern( - PDPattern* pattern, - const std::string& name_scope, - const std::string& act_type) - : PatternBase(pattern, name_scope, name_scope), act_type_(act_type) { - auto* x = pattern->NewNode(x_repr()) - ->assert_is_op_input("squeeze2", "X") - ->assert_more([](Node* node) { - auto x_shape = node->Var()->GetShape(); - size_t x_rank = x_shape.size(); - return x_rank == 4 && x_shape[2] == 1; - }); - auto* squeeze2 = pattern->NewNode(squeeze2_repr()) - ->assert_is_op("squeeze2") - ->assert_more([](Node* node) { - auto* op_desc = node->Op(); - auto axes_array = - op_desc->GetAttrIfExists>("axes"); - return axes_array == std::vector{-2}; - }); - auto* squeeze2_out = pattern->NewNode(squeeze2_out_repr()) - ->assert_is_op_output("squeeze2", "Out") - ->assert_is_op_input("batch_norm", "X"); - squeeze2->LinksFrom({x}).LinksTo({squeeze2_out}); - - auto* bn_bias = pattern->NewNode(bn_bias_repr()) - ->AsInput() - ->assert_is_persistable_var() - ->assert_is_op_input("batch_norm", "Bias") - ->assert_has_n_outputs(1); - auto* bn_mean = pattern->NewNode(bn_mean_repr()) - ->AsInput() - ->assert_is_persistable_var() - ->assert_is_op_input("batch_norm", "Mean") - ->assert_has_n_outputs(1); - auto* bn_scale = pattern->NewNode(bn_scale_repr()) - ->AsInput() - ->assert_is_persistable_var() - ->assert_is_op_input("batch_norm", "Scale") - ->assert_has_n_outputs(1); - auto* bn_var = pattern->NewNode(bn_var_repr()) - ->AsInput() - ->assert_is_persistable_var() - ->assert_is_op_input("batch_norm", "Variance") - ->assert_has_n_outputs(1); - auto* bn = pattern->NewNode(bn_repr())->assert_is_op("batch_norm"); - auto* bn_out = pattern->NewNode(bn_out_repr()) - ->assert_is_op_output("batch_norm", "Y") - ->assert_is_op_input(act_type_, "X"); - auto* bn_mean_out = pattern->NewNode(bn_mean_out_repr()) - ->assert_is_op_output("batch_norm", "MeanOut"); - auto* bn_saved_mean = pattern->NewNode(bn_saved_mean_repr()) - ->assert_is_op_output("batch_norm", "SavedMean"); - auto* bn_var_out = pattern->NewNode(bn_var_out_repr()) - ->assert_is_op_output("batch_norm", "VarianceOut"); - auto* bn_saved_var = pattern->NewNode(bn_saved_var_repr()) - ->assert_is_op_output("batch_norm", "SavedVariance"); - bn->LinksFrom({squeeze2_out, bn_bias, bn_mean, bn_scale, bn_var}) - 
.LinksTo({bn_out, bn_mean_out, bn_var_out, bn_saved_mean, bn_saved_var}); - - auto act = pattern->NewNode(act_repr())->assert_is_op(act_type_); - auto act_out = pattern->NewNode(act_out_repr()) - ->assert_is_op_output(act_type_, "Out") - ->assert_is_op_input("unsqueeze2", "X"); - act->LinksFrom({bn_out}).LinksTo({act_out}); - - auto* unsqueeze2 = - pattern->NewNode(unsqueeze2_repr()) - ->assert_is_op("unsqueeze2") - ->assert_more([](Node* node) { - auto* op_desc = node->Op(); - auto axes_array = - op_desc->GetAttrIfExists>("axes"); - return axes_array == std::vector{-2} || - axes_array == std::vector{2}; - }); - auto* unsqueeze2_out = pattern->NewNode(unsqueeze2_out_repr()) - ->assert_is_op_output("unsqueeze2", "Out"); - unsqueeze2->LinksFrom({act_out}).LinksTo({unsqueeze2_out}); -} - -} // namespace patterns - -void RedundantOnnxOpsEliminationPass::FoldConv1dSqueeze2Ops( - ir::Graph* graph, const std::string& act_type) const { - GraphPatternDetector gpd; - patterns::FoldConv1dSqueeze2Pattern pattern( - gpd.mutable_pattern(), name_scope_, act_type); - int found_subgraph_count = 0; - - auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, - Graph* graph) { - VLOG(4) << "handle FoldConv1dSqueeze2Ops"; - // declare operator node's name - GET_IR_NODE(squeeze2); - GET_IR_NODE(bn); - GET_IR_NODE(act); - GET_IR_NODE(unsqueeze2); - // declare variable node's name - GET_IR_NODE(x); - GET_IR_NODE(squeeze2_out); - GET_IR_NODE(bn_out); - GET_IR_NODE(act_out); - GET_IR_NODE(unsqueeze2_out); - - auto bn_op_desc = bn->Op(); - bn_op_desc->RenameInput(squeeze2_out->Var()->Name(), x->Var()->Name()); - bn_out->Var()->SetShape(x->Var()->GetShape()); - act_out->Var()->SetShape(x->Var()->GetShape()); - bn_op_desc->Flush(); - IR_NODE_LINK_TO(x, bn); - // behind unsqueeze op node - auto unsqueeze_out_link_nodes = unsqueeze2_out->outputs; - for (auto out_link_node : unsqueeze_out_link_nodes) { - auto op_desc = out_link_node->Op(); - op_desc->RenameInput(unsqueeze2_out->Var()->Name(), - act_out->Var()->Name()); - op_desc->Flush(); - IR_NODE_LINK_TO(act_out, out_link_node); - } - // delete useless node - std::unordered_set delete_nodes = { - squeeze2, squeeze2_out, unsqueeze2, unsqueeze2_out}; - GraphSafeRemoveNodes(graph, delete_nodes); - found_subgraph_count++; - }; - - gpd(graph, handler); - AddStatis(found_subgraph_count); -} - -void RedundantOnnxOpsEliminationPass::ApplyImpl(ir::Graph* graph) const { - PADDLE_ENFORCE_NOT_NULL( - graph, platform::errors::PreconditionNotMet("graph should not be null.")); - Init(name_scope_, graph); - for (auto act_type : {"leaky_relu", "elu"}) { - FoldConv1dSqueeze2Ops(graph, act_type); - } -} - -} // namespace ir -} // namespace framework -} // namespace paddle - -REGISTER_PASS(redundant_onnx_ops_elimination_pass, - paddle::framework::ir::RedundantOnnxOpsEliminationPass); - -REGISTER_PASS_CAPABILITY(redundant_onnx_ops_elimination_pass) - .AddCombination( - paddle::framework::compatible::OpVersionComparatorCombination().EQ( - "conv2d", 0)); diff --git a/paddle/fluid/framework/ir/xpu/redundant_unsqueeze_squeeze_elimination_pass.cc b/paddle/fluid/framework/ir/xpu/redundant_unsqueeze_squeeze_elimination_pass.cc new file mode 100644 index 00000000000..710fa94b4e0 --- /dev/null +++ b/paddle/fluid/framework/ir/xpu/redundant_unsqueeze_squeeze_elimination_pass.cc @@ -0,0 +1,330 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/xpu/redundant_unsqueeze_squeeze_elimination_pass.h" +#include + +#include "glog/logging.h" + +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/ir/xpu/pass_utils.h" +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace framework { +namespace ir { +namespace patterns { + +struct FoldTranspose2OpsPattern : public PatternBase { + FoldTranspose2OpsPattern(PDPattern* pattern, + const std::string& name_scope, + const std::string& act_type); + + // declare operator node's name + PATTERN_DECL_NODE(transpose2_1); + PATTERN_DECL_NODE(unsqueeze2); + PATTERN_DECL_NODE(reduce_sum); + PATTERN_DECL_NODE(act); + PATTERN_DECL_NODE(transpose2_2); + // declare variable node's name + PATTERN_DECL_NODE(x); + PATTERN_DECL_NODE(transpose2_1_out); + PATTERN_DECL_NODE(unsqueeze2_out); + PATTERN_DECL_NODE(sum_out); + PATTERN_DECL_NODE(act_out); + PATTERN_DECL_NODE(transpose2_2_out); + + private: + std::string act_type_; +}; + +FoldTranspose2OpsPattern::FoldTranspose2OpsPattern( + PDPattern* pattern, + const std::string& name_scope, + const std::string& act_type) + : PatternBase(pattern, name_scope, name_scope), act_type_(act_type) { + auto* x = pattern->NewNode(x_repr()) + ->assert_is_op_input("transpose2", "X") + ->assert_more([](Node* node) { + auto x_shape = node->Var()->GetShape(); + size_t x_rank = x_shape.size(); + return x_rank == 3; + }); + auto* transpose2_1 = + pattern->NewNode(transpose2_1_repr()) + ->assert_is_op("transpose2") + ->assert_more([](Node* node) { + auto* op_desc = node->Op(); + auto axis_array = + op_desc->GetAttrIfExists>("axis"); + return axis_array == std::vector{0, 2, 1}; + }); + auto* transpose2_1_out = pattern->NewNode(transpose2_1_out_repr()) + ->assert_is_op_output("transpose2", "Out") + ->assert_is_op_input("unsqueeze2", "X"); + transpose2_1->LinksFrom({x}).LinksTo({transpose2_1_out}); + + auto* unsqueeze2 = + pattern->NewNode(unsqueeze2_repr()) + ->assert_is_op("unsqueeze2") + ->assert_more([](Node* node) { + auto* op_desc = node->Op(); + auto axes_array = + op_desc->GetAttrIfExists>("axes"); + return axes_array == std::vector{-2}; + }); + auto* unsqueeze2_out = pattern->NewNode(unsqueeze2_out_repr()) + ->assert_is_op_output("unsqueeze2", "Out") + ->assert_is_op_input("reduce_sum", "X"); + unsqueeze2->LinksFrom({transpose2_1_out}).LinksTo({unsqueeze2_out}); + + auto* reduce_sum = + pattern->NewNode(reduce_sum_repr()) + ->assert_is_op("reduce_sum") + ->assert_more([](Node* node) { + auto* op_desc = node->Op(); + auto keep_dim = op_desc->GetAttrIfExists("keep_dim"); + auto dim_array = op_desc->GetAttrIfExists>("dim"); + return dim_array == std::vector{-2} && !keep_dim; + }); + auto* sum_out = pattern->NewNode(sum_out_repr()) + ->assert_is_op_output("reduce_sum", "Out") + ->assert_is_op_input(act_type_, "X"); + 
reduce_sum->LinksFrom({unsqueeze2_out}).LinksTo({sum_out}); + + auto* act = pattern->NewNode(act_repr())->assert_is_op(act_type_); + auto* act_out = pattern->NewNode(act_out_repr()) + ->assert_is_op_output(act_type_, "Out") + ->assert_is_op_input("transpose2", "X"); + act->LinksFrom({sum_out}).LinksTo({act_out}); + + auto* transpose2_2 = + pattern->NewNode(transpose2_2_repr()) + ->assert_is_op("transpose2") + ->assert_more([](Node* node) { + auto* op_desc = node->Op(); + auto axis_array = + op_desc->GetAttrIfExists>("axis"); + return axis_array == std::vector{0, 2, 1}; + }); + auto* transpose2_2_out = pattern->NewNode(transpose2_2_out_repr()) + ->assert_is_op_output("transpose2", "Out"); + transpose2_2->LinksFrom({act_out}).LinksTo({transpose2_2_out}); +} + +struct FoldGatherSqueeze2Pattern : public PatternBase { + FoldGatherSqueeze2Pattern(PDPattern* pattern, const std::string& name_scope); + + // declare operator node's name + PATTERN_DECL_NODE(unsqueeze2_op); + PATTERN_DECL_NODE(gather_op); + PATTERN_DECL_NODE(squeeze2_op); + + // declare variable node's name + PATTERN_DECL_NODE(x); + PATTERN_DECL_NODE(unsqueeze2_op_out); + PATTERN_DECL_NODE(gather_i); + PATTERN_DECL_NODE(gather_op_out); + PATTERN_DECL_NODE(squeeze2_op_out); +}; + +FoldGatherSqueeze2Pattern::FoldGatherSqueeze2Pattern( + PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, name_scope) { + auto* x = pattern->NewNode(x_repr())->assert_is_op_input("unsqueeze2", "X"); + auto* unsqueeze2_op = + pattern->NewNode(unsqueeze2_op_repr()) + ->assert_is_op("unsqueeze2") + ->assert_more([](Node* node) { + auto* op_desc = node->Op(); + auto axes_array = + op_desc->GetAttrIfExists>("axes"); + return axes_array.size() == 1; + }); + auto* unsqueeze2_op_out = pattern->NewNode(unsqueeze2_op_out_repr()) + ->assert_is_op_output("unsqueeze2", "Out") + ->assert_is_op_input("gather", "X"); + unsqueeze2_op->LinksFrom({x}).LinksTo({unsqueeze2_op_out}); + auto* gather_op = pattern->NewNode(gather_op_repr())->assert_is_op("gather"); + auto* gather_i = pattern->NewNode(gather_i_repr()) + ->assert_is_op_input("gather", "Index") + ->assert_is_persistable_var() + ->assert_more([](Node* node) { + auto i_shape = node->Var()->GetShape(); + size_t i_rank = i_shape.size(); + return i_rank == 1; + }); + auto* gather_op_out = pattern->NewNode(gather_op_out_repr()) + ->assert_is_op_output("gather", "Out") + ->assert_is_op_input("squeeze2", "X"); + gather_op->LinksFrom({unsqueeze2_op_out, gather_i}).LinksTo({gather_op_out}); + auto* squeeze2_op = + pattern->NewNode(squeeze2_op_repr()) + ->assert_is_op("squeeze2") + ->assert_more([](Node* node) { + auto* op_desc = node->Op(); + auto axes_array = + op_desc->GetAttrIfExists>("axes"); + return axes_array.size() == 1; + }); + auto* squeeze2_op_out = pattern->NewNode(squeeze2_op_out_repr()) + ->assert_is_op_output("squeeze2", "Out"); + squeeze2_op->LinksFrom({gather_op_out}).LinksTo({squeeze2_op_out}); +} + +} // namespace patterns + +void RedundantUnsqueeze2EliminationPass::FoldTranspose2Ops( + ir::Graph* graph, const std::string& act_type) const { + GraphPatternDetector gpd; + patterns::FoldTranspose2OpsPattern pattern( + gpd.mutable_pattern(), name_scope_, act_type); + int found_subgraph_count = 0; + + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* graph) { + VLOG(4) << "handle FoldTranspose2Ops"; + // declare operator node's name + GET_IR_NODE(transpose2_1); + GET_IR_NODE(unsqueeze2); + GET_IR_NODE(reduce_sum); + GET_IR_NODE(act); + 
GET_IR_NODE(transpose2_2); + // declare variable node's name + GET_IR_NODE(x); + GET_IR_NODE(transpose2_1_out); + GET_IR_NODE(unsqueeze2_out); + GET_IR_NODE(sum_out); + GET_IR_NODE(act_out); + GET_IR_NODE(transpose2_2_out); + + auto act_op_desc = act->Op(); + act_op_desc->RenameInput(sum_out->Var()->Name(), x->Var()->Name()); + act_out->Var()->SetShape(x->Var()->GetShape()); + act_op_desc->Flush(); + IR_NODE_LINK_TO(x, act); + // behind unsqueeze op node + auto final_out_link_nodes = transpose2_2_out->outputs; + for (auto out_link_node : final_out_link_nodes) { + auto op_desc = out_link_node->Op(); + op_desc->RenameInput(transpose2_2_out->Var()->Name(), + act_out->Var()->Name()); + op_desc->Flush(); + IR_NODE_LINK_TO(act_out, out_link_node); + } + // delete useless node + std::unordered_set delete_nodes = {transpose2_1, + transpose2_1_out, + unsqueeze2, + unsqueeze2_out, + reduce_sum, + sum_out, + transpose2_2, + transpose2_2_out}; + GraphSafeRemoveNodes(graph, delete_nodes); + found_subgraph_count++; + }; + + gpd(graph, handler); + AddStatis(found_subgraph_count); +} + +void RedundantUnsqueeze2EliminationPass::FoldGatherSqueeze2Ops( + ir::Graph* graph) const { + GraphPatternDetector gpd; + patterns::FoldGatherSqueeze2Pattern pattern(gpd.mutable_pattern(), + name_scope_); + int found_subgraph_count = 0; + + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* graph) { + VLOG(4) << "handle FoldGatherSqueeze2Ops"; + // declare operator node's name + GET_IR_NODE(unsqueeze2_op); + GET_IR_NODE(gather_op); + GET_IR_NODE(squeeze2_op); + // declare variable node's name + GET_IR_NODE(x); + GET_IR_NODE(unsqueeze2_op_out); + GET_IR_NODE(gather_i); + GET_IR_NODE(gather_op_out); + GET_IR_NODE(squeeze2_op_out); + + bool flag = true; + auto x_shape = x->Var()->GetShape(); + auto x_rank = static_cast(x_shape.size()); + std::vector unsqueeze_axes_attr = PADDLE_GET_CONST( + std::vector, unsqueeze2_op->Op()->GetAttr("axes")); + auto unsqueeze_axes = unsqueeze_axes_attr.front(); + unsqueeze_axes = + unsqueeze_axes < 0 ? unsqueeze_axes + x_rank : unsqueeze_axes; + auto gather_axis = PADDLE_GET_CONST(int, gather_op->Op()->GetAttr("axis")); + gather_axis = gather_axis < 0 ? gather_axis + x_rank + 1 : gather_axis; + std::vector squeeze_axes_attr = + PADDLE_GET_CONST(std::vector, squeeze2_op->Op()->GetAttr("axes")); + auto squeeze_axes = squeeze_axes_attr.front(); + squeeze_axes = squeeze_axes < 0 ? 
squeeze_axes + x_rank + 1 : squeeze_axes; + flag &= (unsqueeze_axes >= 0 && unsqueeze_axes < x_rank); + flag &= + ((gather_axis == unsqueeze_axes + 1) && (squeeze_axes == gather_axis)); + if (!flag) return; + // x->gather->squeeze2_op_out + auto gather_op_desc = gather_op->Op(); + gather_op_desc->RenameInput(unsqueeze2_op_out->Var()->Name(), + x->Var()->Name()); + gather_op_desc->SetAttr("axis", gather_axis - 1); + gather_op_out->Var()->SetShape(squeeze2_op_out->Var()->GetShape()); + gather_op_desc->Flush(); + IR_NODE_LINK_TO(x, gather_op); + // behind squeeze op node + auto squeeze_out_link_nodes = squeeze2_op_out->outputs; + for (auto out_link_node : squeeze_out_link_nodes) { + auto op_desc = out_link_node->Op(); + op_desc->RenameInput(squeeze2_op_out->Var()->Name(), + gather_op_out->Var()->Name()); + op_desc->Flush(); + IR_NODE_LINK_TO(gather_op_out, out_link_node); + } + std::unordered_set delete_nodes{ + squeeze2_op, squeeze2_op_out, unsqueeze2_op, unsqueeze2_op_out}; + GraphSafeRemoveNodes(graph, delete_nodes); + found_subgraph_count++; + }; + + gpd(graph, handler); + AddStatis(found_subgraph_count); +} + +void RedundantUnsqueeze2EliminationPass::ApplyImpl(ir::Graph* graph) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::PreconditionNotMet("graph should not be null.")); + Init(name_scope_, graph); + for (auto act_type : {"relu"}) { + FoldTranspose2Ops(graph, act_type); + } + FoldGatherSqueeze2Ops(graph); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(redundant_unsqueeze_squeeze_elimination_pass, + paddle::framework::ir::RedundantUnsqueeze2EliminationPass); + +REGISTER_PASS_CAPABILITY(redundant_unsqueeze_squeeze_elimination_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination().EQ( + "conv2d", 0)); diff --git a/paddle/fluid/framework/ir/xpu/redundant_onnx_ops_elimination_pass.h b/paddle/fluid/framework/ir/xpu/redundant_unsqueeze_squeeze_elimination_pass.h similarity index 60% rename from paddle/fluid/framework/ir/xpu/redundant_onnx_ops_elimination_pass.h rename to paddle/fluid/framework/ir/xpu/redundant_unsqueeze_squeeze_elimination_pass.h index ac7854761a9..04ed41e2b6d 100644 --- a/paddle/fluid/framework/ir/xpu/redundant_onnx_ops_elimination_pass.h +++ b/paddle/fluid/framework/ir/xpu/redundant_unsqueeze_squeeze_elimination_pass.h @@ -31,51 +31,51 @@ namespace paddle { namespace framework { namespace ir { -class RedundantOnnxOpsEliminationPass : public FusePassBase { +class RedundantUnsqueeze2EliminationPass : public FusePassBase { protected: void ApplyImpl(ir::Graph* graph) const override; private: /* Origin subgraph: - x filter - | | - unsqueeze2(axes={-2}) unsqueeze2(axes={-2}) - \ / - \ / - conv2d(conv1d) + x | - elementwise_add + transpose2 | - squeeze2(axes={-2}) + unsqueeze2(axes={-2}) | - batch_norm + reduce_sum | act | - unsqueeze2 + transpose2 | - conv2d(conv1d) Fused subgraph: - x filter - | | - unsqueeze2(axes={-2}) unsqueeze2(axes={-2}) - \ / - \ / - conv2d(conv1d) + x | - elementwise_add + act | - batch_norm + */ + void FoldTranspose2Ops(ir::Graph* graph, const std::string& act_type) const; + /* + Origin subgraph: + x | - act + unsqueeze2(axes={-2}) + | + gather + | + squeeze2 + | + Fused subgraph: + x + | + gather | - conv2d(conv1d) */ - void FoldConv1dSqueeze2Ops(ir::Graph* graph, - const std::string& act_type) const; + void FoldGatherSqueeze2Ops(ir::Graph* graph) const; - const std::string name_scope_{"redundant_onnx_ops_elimination_pass"}; + const std::string 
name_scope_{"redundant_unsqueeze_squeeze_elimination_pass"}; }; } // namespace ir diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index c426d8e8759..cb94cf4a5a5 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -527,7 +527,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) { "fold_interp_outsize_fuse_pass", "fold_two_squeeze2_fuse_pass", "conv1d_xpu_fuse_pass", - "redundant_onnx_ops_elimination_pass", + "redundant_unsqueeze_squeeze_elimination_pass", "reduce_ops_fuse_pass", "delete_cast_op_pass", "xpu_delete_cast_op_pass", diff --git a/paddle/phi/api/yaml/fused_ops.yaml b/paddle/phi/api/yaml/fused_ops.yaml index 09bd0ec82c2..211949bd2c0 100644 --- a/paddle/phi/api/yaml/fused_ops.yaml +++ b/paddle/phi/api/yaml/fused_ops.yaml @@ -15,7 +15,7 @@ optional : x_max, y_max - op : add_layernorm_xpu - args : (Tensor x, Tensor y, Tensor scale, Tensor bias, int64_t m, int64_t n, float epsilon) + args : (Tensor x, Tensor y, Tensor scale, Tensor bias, int begin_norm_axis, float epsilon) output : Tensor(out), Tensor(mean), Tensor(variance), Tensor(z_add) infer_meta : func : AddLayernormXPUInferMeta diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc index 71c78d0830b..5a36873049d 100644 --- a/paddle/phi/backends/xpu/xpu2_op_list.cc +++ b/paddle/phi/backends/xpu/xpu2_op_list.cc @@ -24,8 +24,7 @@ XPUOpMap& get_kl2_ops() { static XPUOpMap s_xpu2_kernels{ {"add_act_xpu", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, - {"add_layernorm_xpu", - XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"add_layernorm_xpu", XPUKernelSet({phi::DataType::FLOAT32})}, {"abs", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"abs_grad", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, diff --git a/paddle/phi/infermeta/fusion.cc b/paddle/phi/infermeta/fusion.cc index 11a9faec848..80a4c47511e 100644 --- a/paddle/phi/infermeta/fusion.cc +++ b/paddle/phi/infermeta/fusion.cc @@ -96,8 +96,7 @@ void AddLayernormXPUInferMeta(const MetaTensor& x, const MetaTensor& y, const MetaTensor& scale, const MetaTensor& bias, - int64_t m, - int64_t n, + int begin_norm_axis, float epsilon, MetaTensor* out, MetaTensor* mean, @@ -106,12 +105,16 @@ void AddLayernormXPUInferMeta(const MetaTensor& x, int axis = -1; auto x_dims = x.dims(); auto y_dims = y.dims(); + auto out_dims = x_dims; if (x_dims != y_dims) { - auto out_dims = BroadCastInferShape(x_dims, y_dims, axis); + out_dims = BroadCastInferShape(x_dims, y_dims, axis); out->set_dims(out_dims); } else { - out->set_dims(x_dims); + out->set_dims(out_dims); } + auto layer_norm_x_mat_dims = phi::flatten_to_2d(out_dims, begin_norm_axis); + int64_t m = layer_norm_x_mat_dims[0]; + int64_t n = layer_norm_x_mat_dims[1]; out->set_dtype(x.dtype()); out->set_layout(x.layout()); out->share_lod(x); diff --git a/paddle/phi/infermeta/fusion.h b/paddle/phi/infermeta/fusion.h index be21fc80e73..605fb3dcaff 100644 --- a/paddle/phi/infermeta/fusion.h +++ b/paddle/phi/infermeta/fusion.h @@ -34,8 +34,7 @@ void AddLayernormXPUInferMeta(const MetaTensor& x, const MetaTensor& y, const MetaTensor& scale, const MetaTensor& bias, - int64_t m, - int64_t n, + int begin_norm_axis, float epsilon, MetaTensor* out, MetaTensor* mean, diff --git a/paddle/phi/kernels/fusion/xpu/add_layernorm_xpu_kernel.cc b/paddle/phi/kernels/fusion/xpu/add_layernorm_xpu_kernel.cc index 
66220d11873..616e81c138c 100644 --- a/paddle/phi/kernels/fusion/xpu/add_layernorm_xpu_kernel.cc +++ b/paddle/phi/kernels/fusion/xpu/add_layernorm_xpu_kernel.cc @@ -13,19 +13,65 @@ // limitations under the License. #include "paddle/phi/backends/xpu/enforce_xpu.h" + +#include "glog/logging.h" + #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/common_shape.h" namespace phi { namespace fusion { +static phi::DDim BroadCastInferShape(const DDim x_dims, + const DDim y_dims, + int axis) { + std::vector out_dims_array(x_dims.size(), -1); + if (x_dims != y_dims) { + int max_dim = std::max(x_dims.size(), y_dims.size()); + if (x_dims.size() == y_dims.size()) { + PADDLE_ENFORCE_EQ((axis == -1) || (axis == 0), + true, + phi::errors::InvalidArgument( + "axis should be -1 or 0 while the dimension of " + "tensor X (%s) is equal to the dimension of " + "tensor Y (%s), but received axis: %s", + x_dims.size(), + y_dims.size(), + axis)); + } + PADDLE_ENFORCE_EQ((axis >= (-1 * max_dim)) && (axis < max_dim), + true, + phi::errors::InvalidArgument( + "The axis range must be [%s, %s), but axis is %s. " + "Please set the axis again.", + -1 * max_dim, + max_dim, + axis)); + axis = (axis < 0 ? (std::abs(x_dims.size() - y_dims.size()) + axis + 1) + : axis); + std::vector x_dims_array(max_dim); + std::vector y_dims_array(max_dim); + out_dims_array.resize(max_dim); + phi::funcs::GetBroadcastDimsArrays(x_dims, + y_dims, + x_dims_array.data(), + y_dims_array.data(), + out_dims_array.data(), + max_dim, + axis); + + return phi::make_ddim(out_dims_array); + } + return x_dims; +} + template void AddLayernormXPUKernel(const Context& ctx, const DenseTensor& x, const DenseTensor& y, const DenseTensor& scale, const DenseTensor& bias, - int64_t m, - int64_t n, + int begin_norm_axis, float epsilon, DenseTensor* out, DenseTensor* mean, @@ -37,12 +83,19 @@ void AddLayernormXPUKernel(const Context& ctx, auto* y_data = reinterpret_cast(y.data()); const float* scale_data = scale.data(); const float* bias_data = bias.data(); - - auto* out_data = reinterpret_cast(ctx.template Alloc(out)); float* mean_data = ctx.template Alloc(mean); float* variance_data = ctx.template Alloc(variance); auto* z_add_data = reinterpret_cast(ctx.template Alloc(z_add)); + auto x_dims = x.dims(); + auto y_dims = y.dims(); + auto out_dims = BroadCastInferShape(x_dims, y_dims, -1); + auto layer_norm_x_mat_dims = phi::flatten_to_2d(out_dims, begin_norm_axis); + int64_t m = layer_norm_x_mat_dims[0]; + int64_t n = layer_norm_x_mat_dims[1]; + + auto* out_data = reinterpret_cast(ctx.template Alloc(out)); + int r = xpu::add_layer_norm_fusion( // T /* baidu::xpu::api::Context* ctx */ ctx.x_context(), /* const T* x */ x_data, @@ -66,5 +119,4 @@ PD_REGISTER_KERNEL(add_layernorm_xpu, XPU, ALL_LAYOUT, phi::fusion::AddLayernormXPUKernel, - float, - phi::dtype::float16) {} + float) {} -- GitLab
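
Note on the core behavioral change in this patch: add_layernorm_xpu no longer carries precomputed int64_t m/n attributes. Instead the op stores begin_norm_axis, and both AddLayernormXPUInferMeta and the XPU kernel derive (m, n) from the broadcasted elementwise_add output shape via phi::flatten_to_2d, which is also why the fuse pass drops its graph-time shape computation. Below is a minimal illustrative sketch of that derivation, not part of the patch; the [8, 128, 768] shape is hypothetical and the include path for phi::make_ddim / phi::flatten_to_2d is assumed to be paddle/phi/core/ddim.h.

    // Sketch: recover (m, n) from begin_norm_axis, mirroring
    // AddLayernormXPUInferMeta / AddLayernormXPUKernel in the diff above.
    #include <cstdint>
    #include <iostream>

    #include "paddle/phi/core/ddim.h"  // assumed location of DDim helpers

    int main() {
      // Hypothetical broadcasted add output: [batch, seq_len, hidden].
      phi::DDim out_dims = phi::make_ddim({8, 128, 768});
      int begin_norm_axis = 2;  // layer_norm normalizes over the last dim

      // flatten_to_2d([8, 128, 768], 2) -> [8 * 128, 768]
      phi::DDim mat_dims = phi::flatten_to_2d(out_dims, begin_norm_axis);
      int64_t m = mat_dims[0];  // rows passed to add_layer_norm_fusion (1024)
      int64_t n = mat_dims[1];  // normalized elements per row (768)

      std::cout << "m = " << m << ", n = " << n << std::endl;
      return 0;
    }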