From 160b3477745141b49a41c616af5b0c272c886a12 Mon Sep 17 00:00:00 2001 From: cc <52520497+juncaipeng@users.noreply.github.com> Date: Tue, 29 Dec 2020 18:55:41 +0800 Subject: [PATCH] [cherry-pick] map matmul/squeeze2+matmul/reshape2+matmul to mul #29911 (#29980) --- paddle/fluid/framework/ir/CMakeLists.txt | 1 + .../framework/ir/graph_pattern_detector.cc | 59 +++++ .../framework/ir/graph_pattern_detector.h | 46 +++- .../framework/ir/map_matmul_to_mul_pass.cc | 249 ++++++++++++++++++ .../framework/ir/map_matmul_to_mul_pass.h | 106 ++++++++ .../framework/ir/mkldnn/cpu_quantize_pass.cc | 2 +- .../inference/api/paddle_pass_builder.cc | 19 +- 7 files changed, 474 insertions(+), 8 deletions(-) create mode 100644 paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc create mode 100644 paddle/fluid/framework/ir/map_matmul_to_mul_pass.h diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 760e237bcc1..548a00d67df 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -60,6 +60,7 @@ pass_library(graph_to_program_pass base) pass_library(graph_viz_pass base) pass_library(lock_free_optimize_pass base) pass_library(fc_fuse_pass inference) +pass_library(map_matmul_to_mul_pass inference) pass_library(attention_lstm_fuse_pass inference) pass_library(fc_lstm_fuse_pass inference) pass_library(embedding_fc_lstm_fuse_pass inference) diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 2a72642b17d..22f6388597c 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -1555,6 +1555,65 @@ PDNode *patterns::Reshape::operator()() { } PDNode *patterns::Matmul::operator()() { + auto matmul_op = pattern->NewNode(matmul_op_repr())->assert_is_op("matmul"); + + auto matmul_in_x = pattern->NewNode(matmul_in_x_repr()) + ->AsInput() + ->assert_is_op_input("matmul", "X"); + auto matmul_in_y = pattern->NewNode(matmul_in_y_repr()) + ->AsInput() + ->assert_is_op_input("matmul", "Y"); + auto matmul_out = pattern->NewNode(matmul_out_repr()) + ->AsOutput() + ->assert_is_op_output("matmul", "Out"); + + matmul_op->LinksFrom({matmul_in_x, matmul_in_y}).LinksTo({matmul_out}); + return matmul_out; +} + +PDNode *patterns::Squeeze2Matmul::operator()() { + auto squeeze2_in_x = pattern->NewNode(squeeze2_in_x_repr()) + ->assert_is_op_input("squeeze2", "X") + ->AsInput(); + auto squeeze2_op = + pattern->NewNode(squeeze2_op_repr())->assert_is_op("squeeze2"); + auto matmul_in_x = pattern->NewNode(matmul_in_x_repr()) + ->assert_is_op_output("squeeze2", "Out") + ->assert_is_op_input("matmul", "X"); + auto matmul_in_y = + pattern->NewNode(matmul_in_y_repr())->assert_is_op_input("matmul", "Y"); + auto matmul_op = pattern->NewNode(matmul_op_repr())->assert_is_op("matmul"); + auto matmul_out = pattern->NewNode(matmul_out_repr()) + ->AsOutput() + ->assert_is_op_output("matmul", "Out"); + + squeeze2_op->LinksFrom({squeeze2_in_x}).LinksTo({matmul_in_x}); + matmul_op->LinksFrom({matmul_in_x, matmul_in_y}).LinksTo({matmul_out}); + return matmul_out; +} + +PDNode *patterns::Reshape2Matmul::operator()() { + auto reshape2_in_x = pattern->NewNode(reshape2_in_x_repr()) + ->assert_is_op_input("reshape2", "X") + ->AsInput(); + auto reshape2_op = + pattern->NewNode(reshape2_op_repr())->assert_is_op("reshape2"); + auto matmul_in_x = pattern->NewNode(matmul_in_x_repr()) + ->assert_is_op_output("reshape2", "Out") + ->assert_is_op_input("matmul", "X"); + auto matmul_in_y = + pattern->NewNode(matmul_in_y_repr())->assert_is_op_input("matmul", "Y"); + auto matmul_op = pattern->NewNode(matmul_op_repr())->assert_is_op("matmul"); + auto matmul_out = pattern->NewNode(matmul_out_repr()) + ->AsOutput() + ->assert_is_op_output("matmul", "Out"); + + reshape2_op->LinksFrom({reshape2_in_x}).LinksTo({matmul_in_x}); + matmul_op->LinksFrom({matmul_in_x, matmul_in_y}).LinksTo({matmul_out}); + return matmul_out; +} + +PDNode *patterns::MatmulWithInputOps::operator()() { auto prev_op_x = pattern->NewNode(prev_op_x_repr())->assert_is_op(); auto prev_op_y = pattern->NewNode(prev_op_y_repr())->assert_is_op(); diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index a1e7435523c..83feaa3a4bf 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -940,10 +940,52 @@ struct Reshape : public PatternBase { // Matmul op // Forward pass for matmul. -// matmul_out is a result of the operator. struct Matmul : public PatternBase { Matmul(PDPattern* pattern, const std::string& name_scope) - : PatternBase(pattern, name_scope, "reshape2") {} + : PatternBase(pattern, name_scope, "matmul") {} + + PDNode* operator()(); + PATTERN_DECL_NODE(matmul_in_x); + PATTERN_DECL_NODE(matmul_in_y); + PATTERN_DECL_NODE(matmul_op); + PATTERN_DECL_NODE(matmul_out); +}; + +// Squeeze2 + Matmul +// Forward pass. +struct Squeeze2Matmul : public PatternBase { + Squeeze2Matmul(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "squeeze2_matmul") {} + + PDNode* operator()(); + PATTERN_DECL_NODE(squeeze2_in_x); + PATTERN_DECL_NODE(squeeze2_op); + PATTERN_DECL_NODE(matmul_in_x); + PATTERN_DECL_NODE(matmul_in_y); + PATTERN_DECL_NODE(matmul_op); + PATTERN_DECL_NODE(matmul_out); +}; + +// Reshape2 + Matmul +// Forward pass. +struct Reshape2Matmul : public PatternBase { + Reshape2Matmul(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "reshape2_matmul") {} + + PDNode* operator()(); + PATTERN_DECL_NODE(reshape2_in_x); + PATTERN_DECL_NODE(reshape2_op); + PATTERN_DECL_NODE(matmul_in_x); + PATTERN_DECL_NODE(matmul_in_y); + PATTERN_DECL_NODE(matmul_op); + PATTERN_DECL_NODE(matmul_out); +}; + +// Forward pass for two input ops and matmul op. +// matmul_out is a result of the operator. +struct MatmulWithInputOps : public PatternBase { + MatmulWithInputOps(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "matmul_with_input_ops") {} PDNode* operator()(); PATTERN_DECL_NODE(prev_op_x); diff --git a/paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc b/paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc new file mode 100644 index 00000000000..76148a90074 --- /dev/null +++ b/paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc @@ -0,0 +1,249 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/map_matmul_to_mul_pass.h" + +#include +#include +#include + +#include "paddle/fluid/framework/ir/graph_helper.h" +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace framework { +namespace ir { + +void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); + std::string name_scope = "map_matmul_to_mul_pass"; + FusePassBase::Init(name_scope, graph); + + GraphPatternDetector gpd; + patterns::Matmul matmul_pattern(gpd.mutable_pattern(), name_scope); + matmul_pattern(); + + int found_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + VLOG(4) << "map matmul to mul"; + GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, matmul_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_in_y, matmul_in_y, matmul_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, matmul_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, matmul_pattern); + bool flag = true; + + bool transpose_X = + BOOST_GET_CONST(bool, matmul_op->Op()->GetAttr("transpose_X")); + bool transpose_Y = + BOOST_GET_CONST(bool, matmul_op->Op()->GetAttr("transpose_Y")); + float alpha = BOOST_GET_CONST(float, matmul_op->Op()->GetAttr("alpha")); + flag = flag && !transpose_X && !transpose_Y && std::abs(alpha - 1.0) < 1e-5; + + std::vector x_shape = matmul_in_x->Var()->GetShape(); + std::vector y_shape = matmul_in_y->Var()->GetShape(); + size_t x_rank = x_shape.size(); + size_t y_rank = y_shape.size(); + flag = flag && x_rank == 2 && y_rank == 2; + + std::vector& next_ops = matmul_out->outputs; + flag = flag && next_ops.size() == 1 && + next_ops[0]->Name() == "elementwise_add"; + + if (flag) { + OpDesc desc; + desc.SetType("mul"); + desc.SetInput("X", {matmul_in_x->Name()}); + desc.SetInput("Y", {matmul_in_y->Name()}); + desc.SetOutput("Out", {matmul_out->Name()}); + desc.SetAttr("x_num_col_dims", 1); + desc.SetAttr("y_num_col_dims", 1); + + auto mul_node = g->CreateOpNode(&desc); + IR_NODE_LINK_TO(matmul_in_x, mul_node); + IR_NODE_LINK_TO(matmul_in_y, mul_node); + IR_NODE_LINK_TO(mul_node, matmul_out); + GraphSafeRemoveNodes(graph, {matmul_op}); + ++found_count; + } + }; + + gpd(graph, handler); + AddStatis(found_count); +} + +void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); + std::string name_scope = "squeeze2_matmul_fuse_pass"; + FusePassBase::Init(name_scope, graph); + + GraphPatternDetector gpd; + patterns::Squeeze2Matmul fuse_pattern(gpd.mutable_pattern(), name_scope); + fuse_pattern(); + + int found_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + VLOG(4) << "fuse squeeze2+matmul to mul"; + GET_IR_NODE_FROM_SUBGRAPH(squeeze2_in_x, squeeze2_in_x, fuse_pattern); + GET_IR_NODE_FROM_SUBGRAPH(squeeze2_op, squeeze2_op, fuse_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, fuse_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_in_y, matmul_in_y, fuse_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, fuse_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, fuse_pattern); + bool flag = true; + + size_t squeeze2_in_x_rank = (squeeze2_in_x->Var()->GetShape()).size(); + std::vector squeeze2_op_axes = + BOOST_GET_CONST(std::vector, squeeze2_op->Op()->GetAttr("axes")); + flag = flag && squeeze2_in_x_rank == 4 && + squeeze2_op_axes == std::vector{2, 3} && + (matmul_in_x->outputs).size() == 1; + + bool transpose_X = + BOOST_GET_CONST(bool, matmul_op->Op()->GetAttr("transpose_X")); + bool transpose_Y = + BOOST_GET_CONST(bool, matmul_op->Op()->GetAttr("transpose_Y")); + float alpha = BOOST_GET_CONST(float, matmul_op->Op()->GetAttr("alpha")); + size_t matmul_in_x_rank = (matmul_in_x->Var()->GetShape()).size(); + size_t matmul_in_y_rank = (matmul_in_y->Var()->GetShape()).size(); + flag = flag && !transpose_X && !transpose_Y && + std::abs(alpha - 1.0) < 1e-5 && matmul_in_x_rank == 2 && + matmul_in_y_rank == 2; + + std::vector& next_ops = matmul_out->outputs; + flag = flag && next_ops.size() == 1 && + next_ops[0]->Name() == "elementwise_add"; + + if (flag) { + OpDesc desc; + desc.SetType("mul"); + desc.SetInput("X", {squeeze2_in_x->Name()}); + desc.SetInput("Y", {matmul_in_y->Name()}); + desc.SetOutput("Out", {matmul_out->Name()}); + desc.SetAttr("x_num_col_dims", 1); + desc.SetAttr("y_num_col_dims", 1); + + auto mul_node = g->CreateOpNode(&desc); + IR_NODE_LINK_TO(squeeze2_in_x, mul_node); + IR_NODE_LINK_TO(matmul_in_y, mul_node); + IR_NODE_LINK_TO(mul_node, matmul_out); + GraphSafeRemoveNodes(graph, {squeeze2_op, matmul_in_x, matmul_op}); + ++found_count; + } + }; + + gpd(graph, handler); + AddStatis(found_count); +} + +void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); + std::string name_scope = "reshape2_matmul_fuse_pass"; + FusePassBase::Init(name_scope, graph); + + GraphPatternDetector gpd; + patterns::Reshape2Matmul fuse_pattern(gpd.mutable_pattern(), name_scope); + fuse_pattern(); + + int found_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + VLOG(4) << "fuse reshape2+matmul to mul"; + GET_IR_NODE_FROM_SUBGRAPH(reshape2_in_x, reshape2_in_x, fuse_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape2_op, reshape2_op, fuse_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, fuse_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_in_y, matmul_in_y, fuse_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, fuse_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, fuse_pattern); + bool flag = true; + + size_t reshape2_in_nums = reshape2_op->inputs.size(); + auto reshape2_in_x_shape = reshape2_in_x->Var()->GetShape(); + size_t reshape2_in_x_rank = reshape2_in_x_shape.size(); + std::vector reshape2_op_shape = + BOOST_GET_CONST(std::vector, reshape2_op->Op()->GetAttr("shape")); + flag = flag && reshape2_in_nums == 1 && reshape2_in_x_rank == 4 && + reshape2_in_x_shape[2] == 1 && reshape2_in_x_shape[3] == 1 && + reshape2_op_shape.size() == 2 && (matmul_in_x->outputs).size() == 1; + + bool transpose_X = + BOOST_GET_CONST(bool, matmul_op->Op()->GetAttr("transpose_X")); + bool transpose_Y = + BOOST_GET_CONST(bool, matmul_op->Op()->GetAttr("transpose_Y")); + float alpha = BOOST_GET_CONST(float, matmul_op->Op()->GetAttr("alpha")); + size_t matmul_in_x_rank = (matmul_in_x->Var()->GetShape()).size(); + size_t matmul_in_y_rank = (matmul_in_y->Var()->GetShape()).size(); + flag = flag && !transpose_X && !transpose_Y && + std::abs(alpha - 1.0) < 1e-5 && matmul_in_x_rank == 2 && + matmul_in_y_rank == 2; + + std::vector& next_ops = matmul_out->outputs; + flag = flag && next_ops.size() == 1 && + next_ops[0]->Name() == "elementwise_add"; + + if (flag) { + OpDesc desc; + desc.SetType("mul"); + desc.SetInput("X", {reshape2_in_x->Name()}); + desc.SetInput("Y", {matmul_in_y->Name()}); + desc.SetOutput("Out", {matmul_out->Name()}); + desc.SetAttr("x_num_col_dims", 1); + desc.SetAttr("y_num_col_dims", 1); + + auto mul_node = g->CreateOpNode(&desc); + IR_NODE_LINK_TO(reshape2_in_x, mul_node); + IR_NODE_LINK_TO(matmul_in_y, mul_node); + IR_NODE_LINK_TO(mul_node, matmul_out); + GraphSafeRemoveNodes(graph, {reshape2_op, matmul_in_x, matmul_op}); + ++found_count; + } + }; + + gpd(graph, handler); + AddStatis(found_count); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(map_matmul_to_mul_pass, paddle::framework::ir::MapMatmul2MulPass); +REGISTER_PASS_CAPABILITY(map_matmul_to_mul_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .EQ("matmul", 0) + .EQ("mul", 0)); + +REGISTER_PASS(squeeze2_matmul_fuse_pass, + paddle::framework::ir::Squeeze2MatmulFusePass); +REGISTER_PASS_CAPABILITY(squeeze2_matmul_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .EQ("matmul", 0) + .EQ("squeeze2", 0) + .EQ("mul", 0)); + +REGISTER_PASS(reshape2_matmul_fuse_pass, + paddle::framework::ir::Reshape2MatmulFusePass); +REGISTER_PASS_CAPABILITY(reshape2_matmul_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .EQ("matmul", 0) + .EQ("reshape2", 0) + .EQ("mul", 0)); diff --git a/paddle/fluid/framework/ir/map_matmul_to_mul_pass.h b/paddle/fluid/framework/ir/map_matmul_to_mul_pass.h new file mode 100644 index 00000000000..1c89c97f96e --- /dev/null +++ b/paddle/fluid/framework/ir/map_matmul_to_mul_pass.h @@ -0,0 +1,106 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/ir/pass.h" + +namespace paddle { +namespace framework { +namespace ir { + +/* + * Map matmul to mul, so the optimization can use fc_fuse_pass. + * The mul op must satisfy the following conditions: + * 1. the transpose_X and transpose_Y attrs are false + * 2. the alpha attr is 1.0 + * 3. the rank of input X and Y is 2 + * 4. the next op of matmul is only elementwise_add + * + * Notice: + * the rank of input activation is obtained from var_desc, + * it maybe change in runtime. + */ +class Graph; + +class MapMatmul2MulPass : public FusePassBase { + public: + virtual ~MapMatmul2MulPass() {} + + protected: + void ApplyImpl(Graph* graph) const override; +}; + +/* + * Fuse squeeze2+matmul to mul, so the optimization can use fc_fuse_pass. + * The squeeze2 op must satisfy the following conditions: + * 1. the rank of input X is 4 + * 2. the axis attr is [2, 3] + * 3. the next op is only matmul + * + * The matmul op must satisfy the following conditions: + * 1. the transpose_X and transpose_Y attrs are false + * 2. the alpha attr is 1.0 + * 3. the rank of input X and Y is 2 + * 4. the next op of matmul is only elementwise_add + * + * Notice: + * the rank of input activation is obtained from var_desc, + * it maybe change in runtime. Therefore, the pass considers + * the above passes to reduce the impact on other models. + */ + +class Squeeze2MatmulFusePass : public FusePassBase { + public: + virtual ~Squeeze2MatmulFusePass() {} + + protected: + void ApplyImpl(Graph* graph) const override; +}; + +/* + * Fuse reshape2+matmul to mul, so the optimization can use fc_fuse_pass. + * The reshape2 op must satisfy the following conditions: + * 1. reshape2 has one input node, which means it don't + * have Shape or ShapeTensor input + * 2. the rank of input X is 4 and the last two dims of input X is 1 + * 3. the rank of shape attr is 2 + * 4. the next op is only matmul + * + * The matmul op must satisfy the following conditions: + * 1. the transpose_X and transpose_Y attrs are false + * 2. the alpha attr is 1.0 + * 3. the rank of input X and Y is 2 + * 4. the next op of matmul is only elementwise_add + * + * Notice: + * the shape and rank of input activation is obtained from var_desc, + * they maybe change in runtime. Therefore, the pass considers + * the above passes to reduce the impact on other models. + */ + +class Reshape2MatmulFusePass : public FusePassBase { + public: + virtual ~Reshape2MatmulFusePass() {} + + protected: + void ApplyImpl(Graph* graph) const override; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc index c7c4a1cf238..3c06c9ee41d 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc @@ -679,7 +679,7 @@ void CPUQuantizePass::QuantizeReshape(Graph* graph) const { void CPUQuantizePass::QuantizeMatmul(Graph* graph) const { GraphPatternDetector gpd; auto pattern = gpd.mutable_pattern(); - patterns::Matmul matmul_pattern{pattern, name_scope_}; + patterns::MatmulWithInputOps matmul_pattern{pattern, name_scope_}; matmul_pattern(); int quantize_matmul_count = 0; diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 6c255b67199..95303e7e850 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -82,11 +82,14 @@ const std::vector kTRTSubgraphPasses({ "embedding_eltwise_layernorm_fuse_pass", // "multihead_matmul_fuse_pass_v2", // "skip_layernorm_fuse_pass", // - "unsqueeze2_eltwise_fuse_pass", - "conv_bn_fuse_pass", // - "fc_fuse_pass", // - "tensorrt_subgraph_pass", // - "conv_bn_fuse_pass", // + "unsqueeze2_eltwise_fuse_pass", // + "conv_bn_fuse_pass", // + "squeeze2_matmul_fuse_pass", // + "reshape2_matmul_fuse_pass", // + "map_matmul_to_mul_pass", // + "fc_fuse_pass", // + "tensorrt_subgraph_pass", // + "conv_bn_fuse_pass", // #if CUDNN_VERSION >= 7100 // To run conv_fusion, the version of cudnn must be // guaranteed at least v7 "conv_elementwise_add_act_fuse_pass", // @@ -113,6 +116,9 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) { "conv_eltwiseadd_bn_fuse_pass", // "embedding_eltwise_layernorm_fuse_pass", // "multihead_matmul_fuse_pass_v2", // + "squeeze2_matmul_fuse_pass", // + "reshape2_matmul_fuse_pass", // + "map_matmul_to_mul_pass", // "fc_fuse_pass", // "fc_elementwise_layernorm_fuse_pass", // #if CUDNN_VERSION >= 7100 // To run conv_fusion, the version of cudnn must be @@ -164,6 +170,9 @@ CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) { "fc_gru_fuse_pass", // "mul_gru_fuse_pass", // "seq_concat_fc_fuse_pass", // + "squeeze2_matmul_fuse_pass", // + "reshape2_matmul_fuse_pass", // + "map_matmul_to_mul_pass", // "fc_fuse_pass", // "repeated_fc_relu_fuse_pass", // "squared_mat_sub_fuse_pass", // -- GitLab