From b8f265d232454f41b9dc9e9162b4d11a04e08db5 Mon Sep 17 00:00:00 2001
From: wz1qqx <55830058+wz1qqx@users.noreply.github.com>
Date: Fri, 7 Jul 2023 01:48:28 -0700
Subject: [PATCH] [XPU] Eliminate small ops (#55193)

---
 paddle/fluid/framework/ir/CMakeLists.txt      |   4 +-
 ...x_fuse_pass.cc => reduce_ops_fuse_pass.cc} | 138 +++++++++++-
 .../framework/ir/xpu/reduce_ops_fuse_pass.h   |  82 +++++++
 .../redundant_onnx_ops_elimination_pass.cc    | 209 ++++++++++++++++++
 ... => redundant_onnx_ops_elimination_pass.h} |  60 +++--
 .../inference/api/paddle_pass_builder.cc      |   3 +-
 ...ss.py => test_xpu_reduce_ops_fuse_pass.py} |   4 +-
 7 files changed, 466 insertions(+), 34 deletions(-)
 rename paddle/fluid/framework/ir/xpu/{reduce_max_fuse_pass.cc => reduce_ops_fuse_pass.cc} (60%)
 create mode 100644 paddle/fluid/framework/ir/xpu/reduce_ops_fuse_pass.h
 create mode 100644 paddle/fluid/framework/ir/xpu/redundant_onnx_ops_elimination_pass.cc
 rename paddle/fluid/framework/ir/xpu/{reduce_max_fuse_pass.h => redundant_onnx_ops_elimination_pass.h} (55%)
 rename test/ir/inference/{test_xpu_reduce_max_fuse_pass.py => test_xpu_reduce_ops_fuse_pass.py} (97%)

diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index 1c186373cdb..dc16744fbe2 100755
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -238,6 +238,8 @@ if(WITH_XPU)
   set(XPU_PASS_DEPS xpu_quant_utils xpu_pass_utils)
   pass_library(yolo_box_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
   pass_library(conv2d_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
+  pass_library(redundant_onnx_ops_elimination_pass inference DIR xpu DEPS
+               ${XPU_PASS_DEPS})
   pass_library(redundant_squeeze_unsqueeze_elimination_pass inference DIR xpu
                DEPS ${XPU_PASS_DEPS})
   pass_library(conv2d_transpose_xpu_fuse_pass inference DIR xpu DEPS
@@ -271,7 +273,7 @@ if(WITH_XPU)
                ${XPU_PASS_DEPS})
   pass_library(fold_two_squeeze2_fuse_pass inference DIR xpu DEPS
                ${XPU_PASS_DEPS})
-  pass_library(reduce_max_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
+  pass_library(reduce_ops_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
   pass_library(matmul_weight_trans_pass inference DIR xpu DEPS
                ${XPU_PASS_DEPS})
   pass_library(reshape2_matmul_xpu_fuse_pass inference DIR xpu DEPS
                ${XPU_PASS_DEPS})
diff --git a/paddle/fluid/framework/ir/xpu/reduce_max_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/reduce_ops_fuse_pass.cc
similarity index 60%
rename from paddle/fluid/framework/ir/xpu/reduce_max_fuse_pass.cc
rename to paddle/fluid/framework/ir/xpu/reduce_ops_fuse_pass.cc
index 31db4cdcbd3..501926d8b91 100644
--- a/paddle/fluid/framework/ir/xpu/reduce_max_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/xpu/reduce_ops_fuse_pass.cc
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/fluid/framework/ir/xpu/reduce_max_fuse_pass.h" +#include "paddle/fluid/framework/ir/xpu/reduce_ops_fuse_pass.h" #include #include "glog/logging.h" @@ -137,9 +137,88 @@ ReduceMaxFusePattern::ReduceMaxFusePattern(PDPattern* pattern, transpose2_2->LinksFrom({squeeze2_out}).LinksTo({transpose2_2_out}); } +struct ReduceMeanFusePattern : public PatternBase { + ReduceMeanFusePattern(PDPattern* pattern, const std::string& name_scope); + + // declare operator node's name + PATTERN_DECL_NODE(unsqueeze2); + PATTERN_DECL_NODE(pool2d); + PATTERN_DECL_NODE(squeeze2); + // declare variable node's name + PATTERN_DECL_NODE(x); + PATTERN_DECL_NODE(unsqueeze2_out); + PATTERN_DECL_NODE(pool2d_out); + PATTERN_DECL_NODE(squeeze2_out); +}; + +ReduceMeanFusePattern::ReduceMeanFusePattern(PDPattern* pattern, + const std::string& name_scope) + : PatternBase(pattern, name_scope, name_scope) { + auto* x = pattern->NewNode(x_repr()) + ->assert_is_op_input("unsqueeze2", "X") + ->assert_more([](Node* node) { + auto x_shape = node->Var()->GetShape(); + size_t x_rank = x_shape.size(); + return x_rank == 3; + }); + auto* unsqueeze2 = + pattern->NewNode(unsqueeze2_repr()) + ->assert_is_op("unsqueeze2") + ->assert_more([](Node* node) { + auto* op_desc = node->Op(); + auto axes_array = + op_desc->GetAttrIfExists>("axes"); + return axes_array == std::vector{2}; + }); + auto* unsqueeze2_out = pattern->NewNode(unsqueeze2_out_repr()) + ->assert_is_op_output("unsqueeze2", "Out") + ->assert_is_op_input("pool2d", "X"); + + auto* pool2d = + pattern->NewNode(pool2d_repr()) + ->assert_is_op("pool2d") + ->assert_more([](Node* node) { + auto* op_desc = node->Op(); + auto input_var = node->inputs[0]->Var(); + auto pool2d_x_shape = input_var->GetShape(); + std::vector HW = {static_cast(pool2d_x_shape[2]), + static_cast(pool2d_x_shape[3])}; + auto pool_type = + op_desc->GetAttrIfExists("pooling_type"); + auto ksize_array = + op_desc->GetAttrIfExists>("ksize"); + auto strides_array = + op_desc->GetAttrIfExists>("strides"); + auto paddings_array = + op_desc->GetAttrIfExists>("paddings"); + return pool_type == "avg" && ksize_array == HW && + strides_array == HW && + paddings_array == std::vector{0, 0}; + }); + auto* pool2d_out = pattern->NewNode(pool2d_out_repr()) + ->assert_is_op_output("pool2d", "Out") + ->assert_is_op_input("squeeze2", "X"); + + auto* squeeze2 = pattern->NewNode(squeeze2_repr()) + ->assert_is_op("squeeze2") + ->assert_more([](Node* node) { + auto* op_desc = node->Op(); + auto axes_array = + op_desc->GetAttrIfExists>("axes"); + return axes_array == std::vector{2}; + }); + auto* squeeze2_out = pattern->NewNode(squeeze2_out_repr()) + ->assert_is_op_output("squeeze2", "Out") + ->assert_is_op_input("transpose2", "X"); + + unsqueeze2->LinksFrom({x}).LinksTo({unsqueeze2_out}); + pool2d->LinksFrom({unsqueeze2_out}).LinksTo({pool2d_out}); + squeeze2->LinksFrom({pool2d_out}).LinksTo({squeeze2_out}); +} + } // namespace patterns -void ReduceMaxFusePass::FuseReduceMax(ir::Graph* graph) const { +void ReduceOpsFusePass::FuseReduceMax(ir::Graph* graph) const { GraphPatternDetector gpd; patterns::ReduceMaxFusePattern pattern(gpd.mutable_pattern(), name_scope_); int found_subgraph_count = 0; @@ -193,21 +272,66 @@ void ReduceMaxFusePass::FuseReduceMax(ir::Graph* graph) const { AddStatis(found_subgraph_count); } -void ReduceMaxFusePass::ApplyImpl(ir::Graph* graph) const { +void ReduceOpsFusePass::FuseReduceMean(ir::Graph* graph) const { + GraphPatternDetector gpd; + patterns::ReduceMeanFusePattern 
+      pattern(gpd.mutable_pattern(), name_scope_);
+  int found_subgraph_count = 0;
+
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* graph) {
+    VLOG(4) << "handle FuseReduceMean";
+    // declare operator node's name
+    GET_IR_NODE(unsqueeze2);
+    GET_IR_NODE(pool2d);
+    GET_IR_NODE(squeeze2);
+    // declare variable node's name
+    GET_IR_NODE(x);
+    GET_IR_NODE(unsqueeze2_out);
+    GET_IR_NODE(pool2d_out);
+    GET_IR_NODE(squeeze2_out);
+
+    auto* block = pool2d->Op()->Block();
+    // Generate reduce_mean op
+    framework::OpDesc reduce_op_desc(block);
+    reduce_op_desc.SetType("reduce_mean");
+    reduce_op_desc.SetInput("X", {x->Name()});
+    reduce_op_desc.SetAttr("dim", std::vector<int>{-2});
+    reduce_op_desc.SetAttr("reduce_all", false);
+    reduce_op_desc.SetAttr("keep_dim", true);
+    reduce_op_desc.SetOutput("Out", {squeeze2_out->Name()});
+
+    auto* reduce_op = graph->CreateOpNode(&reduce_op_desc);
+
+    IR_NODE_LINK_TO(x, reduce_op);
+    IR_NODE_LINK_TO(reduce_op, squeeze2_out);
+    // delete useless nodes
+    std::unordered_set<const Node*> delete_nodes = {
+        unsqueeze2, unsqueeze2_out, pool2d, pool2d_out, squeeze2};
+    GraphSafeRemoveNodes(graph, delete_nodes);
+    found_subgraph_count++;
+  };
+
+  gpd(graph, handler);
+  AddStatis(found_subgraph_count);
+}
+
+void ReduceOpsFusePass::ApplyImpl(ir::Graph* graph) const {
   PADDLE_ENFORCE_NOT_NULL(
       graph, platform::errors::PreconditionNotMet("graph should not be null."));
   Init(name_scope_, graph);
 
   FuseReduceMax(graph);
+  FuseReduceMean(graph);
 }
 
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
 
-REGISTER_PASS(reduce_max_fuse_pass, paddle::framework::ir::ReduceMaxFusePass);
+REGISTER_PASS(reduce_ops_fuse_pass, paddle::framework::ir::ReduceOpsFusePass);
 
-REGISTER_PASS_CAPABILITY(reduce_max_fuse_pass)
+REGISTER_PASS_CAPABILITY(reduce_ops_fuse_pass)
     .AddCombination(
-        paddle::framework::compatible::OpVersionComparatorCombination().EQ(
-            "reduce_max", 0));
+        paddle::framework::compatible::OpVersionComparatorCombination()
+            .EQ("reduce_max", 0)
+            .EQ("reduce_mean", 0));
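For intuition, the subgraph that FuseReduceMean matches is just a mean over the last axis of a rank-3 tensor: the input is lifted to NCHW with H == 1, globally average-pooled over H x W, and squeezed back. A minimal numerical sketch in plain Paddle Python (not part of the patch; the shapes, and the axis spelled here, are illustrative assumptions):

    import paddle
    import paddle.nn.functional as F

    x = paddle.rand([4, 16, 32])               # rank-3 input, as the pattern asserts

    # matched chain: unsqueeze2(axes={2}) -> pool2d(avg, ksize == HW) -> squeeze2(axes={2})
    y = paddle.unsqueeze(x, axis=2)            # [4, 16, 1, 32], NCHW with H == 1
    y = F.avg_pool2d(y, kernel_size=[1, 32])   # global average pool -> [4, 16, 1, 1]
    y = paddle.squeeze(y, axis=2)              # [4, 16, 1]

    # replacement: a single keep-dim mean
    z = paddle.mean(x, axis=-1, keepdim=True)  # [4, 16, 1]

    assert bool(paddle.allclose(y, z, atol=1e-6))

The reduce attributes the pass actually writes are the ones shown in the diff above; the axis used in this sketch simply follows the illustrative shapes.
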
diff --git a/paddle/fluid/framework/ir/xpu/reduce_ops_fuse_pass.h b/paddle/fluid/framework/ir/xpu/reduce_ops_fuse_pass.h
new file mode 100644
index 00000000000..abc560b70eb
--- /dev/null
+++ b/paddle/fluid/framework/ir/xpu/reduce_ops_fuse_pass.h
@@ -0,0 +1,82 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <string>
+
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/ir/pass.h"
+
+namespace phi {
+class DenseTensor;
+}  // namespace phi
+
+namespace paddle {
+namespace framework {
+class Scope;
+}  // namespace framework
+}  // namespace paddle
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+class ReduceOpsFusePass : public FusePassBase {
+ protected:
+  void ApplyImpl(ir::Graph* graph) const override;
+
+ private:
+  /*
+  Fuse a series of small ops into a single reduce_max op.
+  For example:
+  graph:
+        x
+        |
+    transpose2
+        |
+    unsqueeze2
+        |
+    pool2d(pooling_type : max)
+        |
+    squeeze2
+        |
+    transpose2
+        |
+  ------------------------------------------------------
+  After the pass is applied:
+        x
+        |
+    reduce_max
+        |
+  */
+  void FuseReduceMax(ir::Graph* graph) const;
+
+  /*
+  Origin subgraph:
+    unsqueeze2
+        |
+    pool2d(avg)
+        |
+    squeeze2
+
+  Fused subgraph:
+    reduce_mean
+  */
+  void FuseReduceMean(ir::Graph* graph) const;
+
+  const std::string name_scope_{"reduce_ops_fuse_pass"};
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
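The FuseReduceMax diagram in this header encodes the same trick for max, with a transpose2 on either side of the pooling chain. A hedged sketch of that equivalence in the same style (shapes and axes are illustrative assumptions; the real constraints live in ReduceMaxFusePattern, which this patch only renames):

    import paddle
    import paddle.nn.functional as F

    x = paddle.rand([4, 32, 16])

    # matched chain: transpose2 -> unsqueeze2 -> pool2d(max) -> squeeze2 -> transpose2
    y = paddle.transpose(x, perm=[0, 2, 1])    # [4, 16, 32]
    y = paddle.unsqueeze(y, axis=2)            # [4, 16, 1, 32]
    y = F.max_pool2d(y, kernel_size=[1, 32])   # global max pool -> [4, 16, 1, 1]
    y = paddle.squeeze(y, axis=2)              # [4, 16, 1]
    y = paddle.transpose(y, perm=[0, 2, 1])    # [4, 1, 16]

    # replacement: one keep-dim reduce_max over the middle axis
    z = paddle.max(x, axis=1, keepdim=True)    # [4, 1, 16]

    assert bool(paddle.allclose(y, z))
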
diff --git a/paddle/fluid/framework/ir/xpu/redundant_onnx_ops_elimination_pass.cc b/paddle/fluid/framework/ir/xpu/redundant_onnx_ops_elimination_pass.cc
new file mode 100644
index 00000000000..f63c51e36db
--- /dev/null
+++ b/paddle/fluid/framework/ir/xpu/redundant_onnx_ops_elimination_pass.cc
@@ -0,0 +1,209 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/ir/xpu/redundant_onnx_ops_elimination_pass.h"
+#include <string>
+
+#include "glog/logging.h"
+
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
+#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
+#include "paddle/fluid/framework/op_version_registry.h"
+#include "paddle/fluid/platform/enforce.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+namespace patterns {
+
+struct FoldConv1dSqueeze2Pattern : public PatternBase {
+  FoldConv1dSqueeze2Pattern(PDPattern* pattern,
+                            const std::string& name_scope,
+                            const std::string& act_type);
+
+  // declare operator node's name
+  PATTERN_DECL_NODE(squeeze2);
+  PATTERN_DECL_NODE(bn);
+  PATTERN_DECL_NODE(act);
+  PATTERN_DECL_NODE(unsqueeze2);
+  // declare variable node's name
+  PATTERN_DECL_NODE(x);
+  PATTERN_DECL_NODE(squeeze2_out);
+  PATTERN_DECL_NODE(bn_bias);
+  PATTERN_DECL_NODE(bn_mean);
+  PATTERN_DECL_NODE(bn_scale);
+  PATTERN_DECL_NODE(bn_var);
+  PATTERN_DECL_NODE(bn_out);
+  PATTERN_DECL_NODE(bn_mean_out);
+  PATTERN_DECL_NODE(bn_saved_mean);
+  PATTERN_DECL_NODE(bn_saved_var);
+  PATTERN_DECL_NODE(bn_var_out);
+  PATTERN_DECL_NODE(act_out);
+  PATTERN_DECL_NODE(unsqueeze2_out);
+
+ private:
+  std::string act_type_;
+};
+
+FoldConv1dSqueeze2Pattern::FoldConv1dSqueeze2Pattern(
+    PDPattern* pattern,
+    const std::string& name_scope,
+    const std::string& act_type)
+    : PatternBase(pattern, name_scope, name_scope), act_type_(act_type) {
+  auto* x = pattern->NewNode(x_repr())
+                ->assert_is_op_input("squeeze2", "X")
+                ->assert_more([](Node* node) {
+                  auto x_shape = node->Var()->GetShape();
+                  size_t x_rank = x_shape.size();
+                  return x_rank == 4 && x_shape[2] == 1;
+                });
+  auto* squeeze2 =
+      pattern->NewNode(squeeze2_repr())
+          ->assert_is_op("squeeze2")
+          ->assert_more([](Node* node) {
+            auto* op_desc = node->Op();
+            auto axes_array =
+                op_desc->GetAttrIfExists<std::vector<int>>("axes");
+            return axes_array == std::vector<int>{-2};
+          });
+  auto* squeeze2_out = pattern->NewNode(squeeze2_out_repr())
+                           ->assert_is_op_output("squeeze2", "Out")
+                           ->assert_is_op_input("batch_norm", "X");
+  squeeze2->LinksFrom({x}).LinksTo({squeeze2_out});
+
+  auto* bn_bias = pattern->NewNode(bn_bias_repr())
+                      ->AsInput()
+                      ->assert_is_persistable_var()
+                      ->assert_is_op_input("batch_norm", "Bias")
+                      ->assert_has_n_outputs(1);
+  auto* bn_mean = pattern->NewNode(bn_mean_repr())
+                      ->AsInput()
+                      ->assert_is_persistable_var()
+                      ->assert_is_op_input("batch_norm", "Mean")
+                      ->assert_has_n_outputs(1);
+  auto* bn_scale = pattern->NewNode(bn_scale_repr())
+                       ->AsInput()
+                       ->assert_is_persistable_var()
+                       ->assert_is_op_input("batch_norm", "Scale")
+                       ->assert_has_n_outputs(1);
+  auto* bn_var = pattern->NewNode(bn_var_repr())
+                     ->AsInput()
+                     ->assert_is_persistable_var()
+                     ->assert_is_op_input("batch_norm", "Variance")
+                     ->assert_has_n_outputs(1);
+  auto* bn = pattern->NewNode(bn_repr())->assert_is_op("batch_norm");
+  auto* bn_out = pattern->NewNode(bn_out_repr())
+                     ->assert_is_op_output("batch_norm", "Y")
+                     ->assert_is_op_input(act_type_, "X");
+  auto* bn_mean_out = pattern->NewNode(bn_mean_out_repr())
+                          ->assert_is_op_output("batch_norm", "MeanOut");
+  auto* bn_saved_mean = pattern->NewNode(bn_saved_mean_repr())
+                            ->assert_is_op_output("batch_norm", "SavedMean");
+  auto* bn_var_out = pattern->NewNode(bn_var_out_repr())
+                         ->assert_is_op_output("batch_norm", "VarianceOut");
+  auto* bn_saved_var = pattern->NewNode(bn_saved_var_repr())
+                           ->assert_is_op_output("batch_norm", "SavedVariance");
+  bn->LinksFrom({squeeze2_out, bn_bias, bn_mean, bn_scale, bn_var})
+      .LinksTo({bn_out, bn_mean_out, bn_var_out, bn_saved_mean, bn_saved_var});
+
+  auto act = pattern->NewNode(act_repr())->assert_is_op(act_type_);
+  auto act_out = pattern->NewNode(act_out_repr())
+                     ->assert_is_op_output(act_type_, "Out")
+                     ->assert_is_op_input("unsqueeze2", "X");
+  act->LinksFrom({bn_out}).LinksTo({act_out});
+
+  auto* unsqueeze2 =
+      pattern->NewNode(unsqueeze2_repr())
+          ->assert_is_op("unsqueeze2")
+          ->assert_more([](Node* node) {
+            auto* op_desc = node->Op();
+            auto axes_array =
+                op_desc->GetAttrIfExists<std::vector<int>>("axes");
+            return axes_array == std::vector<int>{-2} ||
+                   axes_array == std::vector<int>{2};
+          });
+  auto* unsqueeze2_out = pattern->NewNode(unsqueeze2_out_repr())
+                             ->assert_is_op_output("unsqueeze2", "Out");
+  unsqueeze2->LinksFrom({act_out}).LinksTo({unsqueeze2_out});
+}
+
+}  // namespace patterns
+
+void RedundantOnnxOpsEliminationPass::FoldConv1dSqueeze2Ops(
+    ir::Graph* graph, const std::string& act_type) const {
+  GraphPatternDetector gpd;
+  patterns::FoldConv1dSqueeze2Pattern pattern(
+      gpd.mutable_pattern(), name_scope_, act_type);
+  int found_subgraph_count = 0;
+
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* graph) {
+    VLOG(4) << "handle FoldConv1dSqueeze2Ops";
+    // declare operator node's name
+    GET_IR_NODE(squeeze2);
+    GET_IR_NODE(bn);
+    GET_IR_NODE(act);
+    GET_IR_NODE(unsqueeze2);
+    // declare variable node's name
+    GET_IR_NODE(x);
+    GET_IR_NODE(squeeze2_out);
+    GET_IR_NODE(bn_out);
+    GET_IR_NODE(act_out);
+    GET_IR_NODE(unsqueeze2_out);
+
+    auto bn_op_desc = bn->Op();
+    bn_op_desc->RenameInput(squeeze2_out->Var()->Name(), x->Var()->Name());
+    bn_out->Var()->SetShape(x->Var()->GetShape());
+    act_out->Var()->SetShape(x->Var()->GetShape());
+    bn_op_desc->Flush();
+    IR_NODE_LINK_TO(x, bn);
+    // relink the ops that consume the unsqueeze2 output
+    auto unsqueeze_out_link_nodes = unsqueeze2_out->outputs;
+    for (auto out_link_node : unsqueeze_out_link_nodes) {
+      auto op_desc = out_link_node->Op();
+      op_desc->RenameInput(unsqueeze2_out->Var()->Name(),
+                           act_out->Var()->Name());
+      op_desc->Flush();
+      IR_NODE_LINK_TO(act_out, out_link_node);
+    }
+    // delete useless nodes
+    std::unordered_set<const Node*> delete_nodes = {
+        squeeze2, squeeze2_out, unsqueeze2, unsqueeze2_out};
+    GraphSafeRemoveNodes(graph, delete_nodes);
+    found_subgraph_count++;
+  };
+
+  gpd(graph, handler);
+  AddStatis(found_subgraph_count);
+}
+
+void RedundantOnnxOpsEliminationPass::ApplyImpl(ir::Graph* graph) const {
+  PADDLE_ENFORCE_NOT_NULL(
+      graph, platform::errors::PreconditionNotMet("graph should not be null."));
+  Init(name_scope_, graph);
+  for (auto act_type : {"leaky_relu", "elu"}) {
+    FoldConv1dSqueeze2Ops(graph, act_type);
+  }
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(redundant_onnx_ops_elimination_pass,
+              paddle::framework::ir::RedundantOnnxOpsEliminationPass);
+
+REGISTER_PASS_CAPABILITY(redundant_onnx_ops_elimination_pass)
+    .AddCombination(
+        paddle::framework::compatible::OpVersionComparatorCombination().EQ(
+            "conv2d", 0));
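The fold above is sound because batch_norm and the matched activations are channel-wise: their output is the same whether they run on the squeezed NCL view or on the original NCHW view whose H axis is 1, so the inner squeeze2/unsqueeze2 pair is redundant once conv1d is already expressed as conv2d. A minimal numerical sketch (plain Paddle Python, not part of the patch; shapes, parameter values, and the leaky_relu choice are illustrative assumptions):

    import paddle
    import paddle.nn.functional as F

    x = paddle.rand([4, 8, 1, 32])                # conv1d-as-conv2d activations, H == 1
    rm, rv = paddle.zeros([8]), paddle.ones([8])  # running mean / variance
    gamma, beta = paddle.rand([8]), paddle.rand([8])

    # ONNX-exported chain: squeeze2 -> batch_norm -> act -> unsqueeze2
    y = paddle.squeeze(x, axis=2)                 # [4, 8, 32]
    y = F.batch_norm(y, rm, rv, weight=gamma, bias=beta, data_format="NCL")
    y = F.leaky_relu(y)
    y = paddle.unsqueeze(y, axis=2)               # back to [4, 8, 1, 32]

    # after the pass: the same channel-wise ops applied directly to the 4-D tensor
    z = F.leaky_relu(F.batch_norm(x, rm, rv, weight=gamma, bias=beta))

    assert bool(paddle.allclose(y, z, atol=1e-6))
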
diff --git a/paddle/fluid/framework/ir/xpu/reduce_max_fuse_pass.h b/paddle/fluid/framework/ir/xpu/redundant_onnx_ops_elimination_pass.h
similarity index 55%
rename from paddle/fluid/framework/ir/xpu/reduce_max_fuse_pass.h
rename to paddle/fluid/framework/ir/xpu/redundant_onnx_ops_elimination_pass.h
index 3c82f918e0c..ac7854761a9 100644
--- a/paddle/fluid/framework/ir/xpu/reduce_max_fuse_pass.h
+++ b/paddle/fluid/framework/ir/xpu/redundant_onnx_ops_elimination_pass.h
@@ -30,38 +30,52 @@ class Scope;
 
 namespace paddle {
 namespace framework {
 namespace ir {
 
-/*
-fuse series small ops to reduce_max op
-For example:
-graph:
-      x
+
+class RedundantOnnxOpsEliminationPass : public FusePassBase {
+ protected:
+  void ApplyImpl(ir::Graph* graph) const override;
+
+ private:
+  /*
+  Origin subgraph:
+            x       filter
+            |          |
+  unsqueeze2(axes={-2})   unsqueeze2(axes={-2})
+            \             /
+             \           /
+            conv2d(conv1d)
       |
-  transpose2
+       elementwise_add
       |
-  unsqueeze2
+       squeeze2(axes={-2})
+            |
+       batch_norm
       |
-  pool2d(pooling_type : max)
+          act
       |
-  squeeze2
+       unsqueeze2
       |
-  transpose2
+       conv2d(conv1d)
+  Fused subgraph:
+            x       filter
+            |          |
+  unsqueeze2(axes={-2})   unsqueeze2(axes={-2})
+            \             /
+             \           /
+            conv2d(conv1d)
       |
-------------------------------------------------------
-After the pass is applied:
-      x
+       elementwise_add
       |
-  reduce_max
+       batch_norm
       |
-*/
-
-class ReduceMaxFusePass : public FusePassBase {
- protected:
-  void ApplyImpl(ir::Graph* graph) const override;
-
- private:
-  void FuseReduceMax(ir::Graph* graph) const;
+          act
+          |
+       conv2d(conv1d)
+  */
+  void FoldConv1dSqueeze2Ops(ir::Graph* graph,
+                             const std::string& act_type) const;
 
-  const std::string name_scope_{"reduce_max_fuse_pass"};
+  const std::string name_scope_{"redundant_onnx_ops_elimination_pass"};
 };
 
 }  // namespace ir
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index 3d516a2ee17..3542322bbb9 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -526,7 +526,8 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
       "one_beam_size_fuse_pass",
       "fold_interp_outsize_fuse_pass",
       "fold_two_squeeze2_fuse_pass",
-      "reduce_max_fuse_pass",
+      "redundant_onnx_ops_elimination_pass",
+      "reduce_ops_fuse_pass",
       "delete_cast_op_pass",
       "xpu_delete_cast_op_pass",
       "stack_fuse_pass",
diff --git a/test/ir/inference/test_xpu_reduce_max_fuse_pass.py b/test/ir/inference/test_xpu_reduce_ops_fuse_pass.py
similarity index 97%
rename from test/ir/inference/test_xpu_reduce_max_fuse_pass.py
rename to test/ir/inference/test_xpu_reduce_ops_fuse_pass.py
index 83951980917..12d5cc92f01 100644
--- a/test/ir/inference/test_xpu_reduce_max_fuse_pass.py
+++ b/test/ir/inference/test_xpu_reduce_ops_fuse_pass.py
@@ -20,7 +20,7 @@ from auto_scan_test import PassAutoScanTest
 from program_config import OpConfig, ProgramConfig, TensorConfig
 
 
-class TestFcFusePass(PassAutoScanTest):
+class TestReduceMaxFusePass(PassAutoScanTest):
     def sample_predictor_configs(self, program_config):
         config = self.create_inference_config(use_xpu=True)
         yield config, ["reduce_max"], (1e-3, 1e-3)
@@ -101,7 +101,7 @@ class TestFcFusePass(PassAutoScanTest):
         self.run_and_statis(
             quant=False,
             max_examples=25,
-            passes=["reduce_max_fuse_pass"],
+            passes=["reduce_ops_fuse_pass"],
         )
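A reduce_mean-pattern case for the renamed test could be assembled the same way as the existing max case. A hedged sketch in the auto-scan style (the kwargs-as-attrs spelling of OpConfig follows Paddle's program_config test utilities; the tensor names and the 24-wide shape are assumptions, and note the pass additionally requires the squeeze2 output to feed a transpose2):

    from program_config import OpConfig

    # unsqueeze2(axes=[2]) -> pool2d(avg, ksize == HW) -> squeeze2(axes=[2]),
    # which reduce_ops_fuse_pass should collapse into a single reduce_mean.
    unsqueeze_op = OpConfig(
        "unsqueeze2",
        inputs={"X": ["in_x"]},
        outputs={"Out": ["unsqueeze_out"], "XShape": ["unsqueeze_xshape"]},
        axes=[2],
    )
    pool_op = OpConfig(
        "pool2d",
        inputs={"X": ["unsqueeze_out"]},
        outputs={"Out": ["pool_out"]},
        pooling_type="avg",
        ksize=[1, 24],        # must equal the H x W of the pooled input
        strides=[1, 24],
        paddings=[0, 0],
        global_pooling=False,
        adaptive=False,
    )
    squeeze_op = OpConfig(
        "squeeze2",
        inputs={"X": ["pool_out"]},
        outputs={"Out": ["squeeze_out"], "XShape": ["squeeze_xshape"]},
        axes=[2],
    )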