From d0cf9d9dc5d99f758cacab11c96f624d05bfd374 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C5=82awomir=20Siwek?= Date: Fri, 5 Aug 2022 11:57:41 +0200 Subject: [PATCH] Merge matmul_v1 and matmul_v2 fuse passes (#44870) * remove v2_transpose_reshape * matmul_transpose_reshape * reshape_transpose_matmul * restore ut * adjust old ut * restore parallel UT ruels * feedback from review --- paddle/fluid/framework/ir/CMakeLists.txt | 13 +- ...mul_transpose_reshape_mkldnn_fuse_pass.cc} | 183 ++++++++++-------- ...tmul_transpose_reshape_mkldnn_fuse_pass.h} | 11 +- ...nspose_reshape_mkldnn_fuse_pass_tester.cc} | 7 +- .../matmul_v2_transpose_reshape_fuse_pass.cc | 94 --------- .../matmul_v2_transpose_reshape_fuse_pass.h | 35 ---- ...shape_transpose_matmul_mkldnn_fuse_pass.cc | 183 ++++++++++-------- ...eshape_transpose_matmul_mkldnn_fuse_pass.h | 14 +- ...ranspose_matmul_mkldnn_fuse_pass_tester.cc | 6 +- ...pe_transpose_matmul_v2_mkldnn_fuse_pass.cc | 93 --------- ...ape_transpose_matmul_v2_mkldnn_fuse_pass.h | 39 ---- .../inference/api/paddle_pass_builder.cc | 18 +- .../quantization/quant2_int8_mkldnn_pass.py | 6 +- ...ldnn_matmul_transpose_reshape_fuse_pass.py | 4 +- ...n_matmul_v2_transpose_reshape_fuse_pass.py | 4 +- ...n_reshape_transpose_matmul_v2_fuse_pass.py | 2 +- 16 files changed, 235 insertions(+), 477 deletions(-) rename paddle/fluid/framework/ir/mkldnn/{matmul_transpose_reshape_fuse_pass.cc => matmul_transpose_reshape_mkldnn_fuse_pass.cc} (70%) rename paddle/fluid/framework/ir/mkldnn/{matmul_transpose_reshape_fuse_pass.h => matmul_transpose_reshape_mkldnn_fuse_pass.h} (80%) rename paddle/fluid/framework/ir/mkldnn/{matmul_transpose_reshape_fuse_pass_tester.cc => matmul_transpose_reshape_mkldnn_fuse_pass_tester.cc} (92%) delete mode 100644 paddle/fluid/framework/ir/mkldnn/matmul_v2_transpose_reshape_fuse_pass.cc delete mode 100644 paddle/fluid/framework/ir/mkldnn/matmul_v2_transpose_reshape_fuse_pass.h delete mode 100644 paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_v2_mkldnn_fuse_pass.cc delete mode 100644 paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_v2_mkldnn_fuse_pass.h diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 0b5af21ca5c..f17adca19e5 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -218,10 +218,7 @@ if(WITH_MKLDNN) pass_library(cpu_quantize_pass inference DIR mkldnn) pass_library(cpu_quantize_squash_pass inference DIR mkldnn) pass_library(reshape_transpose_matmul_mkldnn_fuse_pass inference DIR mkldnn) - pass_library(reshape_transpose_matmul_v2_mkldnn_fuse_pass inference DIR - mkldnn) - pass_library(matmul_transpose_reshape_fuse_pass inference DIR mkldnn) - pass_library(matmul_v2_transpose_reshape_fuse_pass inference DIR mkldnn) + pass_library(matmul_transpose_reshape_mkldnn_fuse_pass inference DIR mkldnn) pass_library(batch_norm_act_fuse_pass inference DIR mkldnn) pass_library(multi_gru_fuse_pass inference DIR mkldnn) pass_library(multi_gru_seq_fuse_pass inference DIR mkldnn) @@ -497,13 +494,11 @@ if(WITH_MKLDNN) cc_test( test_reshape_transpose_matmul_mkldnn_fuse_pass SRCS mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass_tester.cc - DEPS reshape_transpose_matmul_mkldnn_fuse_pass - reshape_transpose_matmul_v2_mkldnn_fuse_pass) + DEPS reshape_transpose_matmul_mkldnn_fuse_pass) cc_test( test_matmul_transpose_reshape_fuse_pass - SRCS mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc - DEPS matmul_transpose_reshape_fuse_pass - matmul_v2_transpose_reshape_fuse_pass) + SRCS mkldnn/matmul_transpose_reshape_mkldnn_fuse_pass_tester.cc + DEPS matmul_transpose_reshape_mkldnn_fuse_pass) cc_test( test_shuffle_channel_mkldnn_detect_pass SRCS mkldnn/shuffle_channel_mkldnn_detect_pass_tester.cc diff --git a/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_mkldnn_fuse_pass.cc similarity index 70% rename from paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.cc rename to paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_mkldnn_fuse_pass.cc index 09bf9c57c47..ce892aa8683 100644 --- a/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_mkldnn_fuse_pass.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,12 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.h" - +#include "paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_mkldnn_fuse_pass.h" #include - -#include - #include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/platform/enforce.h" @@ -25,76 +21,28 @@ namespace paddle { namespace framework { namespace ir { -MatmulTransposeReshapeMKLDNNPass::MatmulTransposeReshapeMKLDNNPass() { - op_name_ = "matmul"; +using string::PrettyLogDetail; - AddOpCompat(OpCompat(op_name_)) - .AddInput("X") - .IsTensor() - .End() - .AddInput("Y") - .IsTensor() - .End() - .AddOutput("Out") - .IsTensor() - .End() - .AddAttr("alpha") // unconstrained. can be any float value. - .IsType() - .End() - .AddAttr("transpose_X") // unconstrained. can be any bool value. - .IsType() - .End() - .AddAttr("transpose_Y") // unconstrained. can be any bool value. - .IsType() - .End(); +void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(Graph *graph) const { + auto matmul_types = {"matmul", "matmul_v2"}; - AddOpCompat(OpCompat("transpose2")) - .AddInput("X") - .IsTensor() - .End() - .AddOutput("Out") - .IsTensor() - .End() - .AddOutput("XShape") - .IsTensor() - .End() - .AddAttr("axis") // ints - .IsType>() - .End(); - - AddOpCompat(OpCompat("reshape2")) - .AddInput("X") - .IsTensor() - .End() - .AddInput("Shape") - .IsTensor() - .IsOptional() - .End() - .AddInput("ShapeTensor") - .IsTensor() - .IsOptional() - .End() - .AddOutput("Out") - .IsTensor() - .End() - .AddOutput("XShape") - .IsTensor() - .End() - .AddAttr("shape") // ints - .IsType>() - .End(); + for (const auto &matmul_type : matmul_types) { + Fuse(graph, matmul_type); + } } -void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const { + +void MatmulTransposeReshapeMKLDNNPass::Fuse( + Graph *graph, const std::string &matmul_type) const { PADDLE_ENFORCE_NOT_NULL(graph, platform::errors::InvalidArgument( "Pointer to graph argument should not be NULL.")); - FusePassBase::Init(name_scope_, graph); - + FusePassBase::Init(matmul_type + "_transpose_reshape_mkldnn_fuse_pass", + graph); GraphPatternDetector gpd; - patterns::MatmulTransposeReshapePattern mtrp(gpd.mutable_pattern(), - name_scope_); - - mtrp(op_name_); + patterns::MatmulTransposeReshapePattern mtrp( + gpd.mutable_pattern(), + matmul_type + "_transpose_reshape_mkldnn_fuse_pass"); + mtrp(matmul_type); int found_matmul_transpose_reshape_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, @@ -103,7 +51,7 @@ void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const { LOG(WARNING) << "Pass in op compat failed."; return; } - VLOG(4) << "handle " + op_name_ + "_transpose_reshape fuse"; + GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, mtrp); GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, mtrp); GET_IR_NODE_FROM_SUBGRAPH(transpose_op, transpose_op, mtrp); @@ -112,6 +60,7 @@ void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const { GET_IR_NODE_FROM_SUBGRAPH(reshape_op, reshape_op, mtrp); GET_IR_NODE_FROM_SUBGRAPH(reshape_out, reshape_out, mtrp); GET_IR_NODE_FROM_SUBGRAPH(reshape_out_xshape, reshape_out_xshape, mtrp); + auto reshape_shape = PADDLE_GET_CONST(std::vector, reshape_op->Op()->GetAttr("shape")); auto transpose_axis = @@ -123,17 +72,17 @@ void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const { const bool supported_transpose_axis = std::equal( transpose_axis.begin(), transpose_axis.end(), supported_axis.begin()); if (transpose_out_size != 4) { - VLOG(3) << "do not perform " + op_name_ + "_transpose_reshape fuse: " + VLOG(3) << "do not perform " + matmul_type + "_transpose_reshape fuse: " << "supported rank is 4, received " << transpose_out_size; return; } if (!supported_transpose_axis) { - VLOG(3) << "do not perform " + op_name_ + "_transpose_reshape fuse: " + VLOG(3) << "do not perform " + matmul_type + "_transpose_reshape fuse: " << "supported transpose axis for the fuse are {0, 2, 1, 3}"; return; } if (reshape_out_size != 3) { - VLOG(3) << "do not perform " + op_name_ + "_transpose_reshape fuse: " + VLOG(3) << "do not perform " + matmul_type + "_transpose_reshape fuse: " << "reshape_out supported rank is 3, received " << reshape_out_size; return; @@ -158,23 +107,93 @@ void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const { gpd(graph, handler); AddStatis(found_matmul_transpose_reshape_count); - if (!Has("disable_logs") || !Get("disable_logs")) { - std::stringstream msg_ss; - msg_ss << "--- Fused " << found_matmul_transpose_reshape_count - << " MatmulTransposeReshape patterns for " + op_name_ + " Op"; - paddle::string::PrettyLogDetail(msg_ss.str().c_str()); + if ((!Has("disable_logs") || !Get("disable_logs")) && + found_matmul_transpose_reshape_count > 0) { + PrettyLogDetail("--- fused %d %s + transpose + reshape patterns", + found_matmul_transpose_reshape_count, + matmul_type); } } + +MatmulTransposeReshapeMKLDNNPass::MatmulTransposeReshapeMKLDNNPass() { + AddOpCompat(OpCompat("matmul")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("alpha") + .IsType() + .End() + .AddAttr("transpose_X") + .IsType() + .End() + .AddAttr("transpose_Y") + .IsType() + .End(); + + AddOpCompat(OpCompat("matmul_v2")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("trans_x") + .IsType() + .End() + .AddAttr("trans_y") + .IsType() + .End(); + + AddOpCompat(OpCompat("transpose2")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddOutput("XShape") + .IsTensor() + .End() + .AddAttr("axis") + .IsType>() + .End(); + + AddOpCompat(OpCompat("reshape2")) + .AddInput("X") + .IsTensor() + .End() + // The reshape2 op for this pass should not have "Shape" and "ShapeTensor" + .AddOutput("Out") + .IsTensor() + .End() + .AddOutput("XShape") + .IsTensor() + .End() + .AddAttr("shape") + .IsType>() + .End(); +} + } // namespace ir } // namespace framework } // namespace paddle -REGISTER_PASS(matmul_transpose_reshape_fuse_pass, +REGISTER_PASS(matmul_transpose_reshape_mkldnn_fuse_pass, paddle::framework::ir::MatmulTransposeReshapeMKLDNNPass); -REGISTER_PASS_CAPABILITY(matmul_transpose_reshape_fuse_pass) +REGISTER_PASS_CAPABILITY(matmul_transpose_reshape_mkldnn_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() .LE("matmul", 1) - .EQ("transpose", 0) - .EQ("reshape", 0)); + .EQ("matmul_v2", 0) + .EQ("transpose2", 0) + .EQ("reshape2", 0)); diff --git a/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_mkldnn_fuse_pass.h similarity index 80% rename from paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.h rename to paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_mkldnn_fuse_pass.h index e03746e6e80..36bc97876ce 100644 --- a/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_mkldnn_fuse_pass.h @@ -1,4 +1,4 @@ -// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -14,14 +14,11 @@ #pragma once -#include - #include "paddle/fluid/framework/ir/fuse_pass_base.h" namespace paddle { namespace framework { namespace ir { -class Graph; class MatmulTransposeReshapeMKLDNNPass : public FusePassBase { public: @@ -29,10 +26,10 @@ class MatmulTransposeReshapeMKLDNNPass : public FusePassBase { virtual ~MatmulTransposeReshapeMKLDNNPass() {} protected: - void ApplyImpl(Graph* graph) const override; - const std::string name_scope_{"matmul_transpose_reshape_fuse"}; - std::string op_name_; + void ApplyImpl(Graph *graph) const override; + void Fuse(Graph *graph, const std::string &matmul_type) const; }; + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_mkldnn_fuse_pass_tester.cc similarity index 92% rename from paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc rename to paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_mkldnn_fuse_pass_tester.cc index 75cc3e12c2e..4149bb23473 100644 --- a/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_mkldnn_fuse_pass_tester.cc @@ -14,7 +14,7 @@ #include -#include "paddle/fluid/framework/ir/mkldnn/matmul_v2_transpose_reshape_fuse_pass.h" +#include "paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_mkldnn_fuse_pass.h" namespace paddle { namespace framework { @@ -74,7 +74,7 @@ void MainTest(const ProgramDesc &prog, const std::string &op_name) { int original_nodes_num = graph->Nodes().size(); auto pass = - PassRegistry::Instance().Get(op_name + "_transpose_reshape_fuse_pass"); + PassRegistry::Instance().Get("matmul_transpose_reshape_mkldnn_fuse_pass"); graph.reset(pass->Apply(graph.release())); int current_nodes_num = graph->Nodes().size(); @@ -106,5 +106,4 @@ TEST(MatmulTransposeReshapeFusePass, matmul_v2_fuse_pass) { } // namespace framework } // namespace paddle -USE_PASS(matmul_transpose_reshape_fuse_pass); -USE_PASS(matmul_v2_transpose_reshape_fuse_pass); +USE_PASS(matmul_transpose_reshape_mkldnn_fuse_pass); diff --git a/paddle/fluid/framework/ir/mkldnn/matmul_v2_transpose_reshape_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/matmul_v2_transpose_reshape_fuse_pass.cc deleted file mode 100644 index 6e106fa9dae..00000000000 --- a/paddle/fluid/framework/ir/mkldnn/matmul_v2_transpose_reshape_fuse_pass.cc +++ /dev/null @@ -1,94 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/framework/ir/mkldnn/matmul_v2_transpose_reshape_fuse_pass.h" - -#include - -#include "paddle/fluid/framework/op_version_registry.h" -#include "paddle/fluid/platform/enforce.h" - -namespace paddle { -namespace framework { -namespace ir { - -MatmulV2TransposeReshapeMKLDNNPass::MatmulV2TransposeReshapeMKLDNNPass() { - op_name_ = "matmul_v2"; - - AddOpCompat(OpCompat(op_name_)) - .AddInput("X") - .IsTensor() - .End() - .AddInput("Y") - .IsTensor() - .End() - .AddOutput("Out") - .IsTensor() - .End() - .AddAttr("trans_x") - .IsType() - .End() - .AddAttr("trans_y") - .IsType() - .End(); - - AddOpCompat(OpCompat("transpose2")) - .AddInput("X") - .IsTensor() - .End() - .AddOutput("Out") - .IsTensor() - .End() - .AddOutput("XShape") - .IsTensor() - .End() - .AddAttr("axis") - .IsType>() - .End(); - - AddOpCompat(OpCompat("reshape2")) - .AddInput("X") - .IsTensor() - .End() - .AddInput("Shape") - .IsTensor() - .IsOptional() - .End() - .AddInput("ShapeTensor") - .IsTensor() - .IsOptional() - .End() - .AddOutput("Out") - .IsTensor() - .End() - .AddOutput("XShape") - .IsTensor() - .End() - .AddAttr("shape") - .IsType>() - .End(); -} -} // namespace ir -} // namespace framework -} // namespace paddle - -REGISTER_PASS(matmul_v2_transpose_reshape_fuse_pass, - paddle::framework::ir::MatmulV2TransposeReshapeMKLDNNPass); - -REGISTER_PASS_CAPABILITY(matmul_v2_transpose_reshape_fuse_pass) - .AddCombination( - paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("matmul_v2", 0) - .EQ("transpose2", 0) - .EQ("reshape2", 0)); diff --git a/paddle/fluid/framework/ir/mkldnn/matmul_v2_transpose_reshape_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/matmul_v2_transpose_reshape_fuse_pass.h deleted file mode 100644 index 60b7e981456..00000000000 --- a/paddle/fluid/framework/ir/mkldnn/matmul_v2_transpose_reshape_fuse_pass.h +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include - -#include "paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.h" - -namespace paddle { -namespace framework { -namespace ir { -class MatmulV2TransposeReshapeMKLDNNPass - : public MatmulTransposeReshapeMKLDNNPass { - public: - MatmulV2TransposeReshapeMKLDNNPass(); - virtual ~MatmulV2TransposeReshapeMKLDNNPass() {} - - protected: - const std::string name_scope_{"matmul_v2_transpose_reshape_fuse"}; -}; -} // namespace ir -} // namespace framework -} // namespace paddle diff --git a/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.cc index 20bfe5726f6..29e013c55a4 100644 --- a/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -13,11 +13,6 @@ // limitations under the License. #include "paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.h" - -#include -#include -#include - #include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/string/pretty_log.h" @@ -26,78 +21,46 @@ namespace paddle { namespace framework { namespace ir { -ReshapeTransposeMatmulMkldnnFusePass::ReshapeTransposeMatmulMkldnnFusePass() { - op_name_ = "matmul"; - - AddOpCompat(OpCompat("reshape2")) - .AddInput("X") - .IsTensor() - .End() - // The reshape2 op for this pass should not have "Shape" and "ShapeTensor" - .AddOutput("Out") - .IsTensor() - .End() - .AddOutput("XShape") - .IsOptional() - .IsTensor() - .End() - .AddAttr("shape") - .IsType>() - .End(); - - AddOpCompat(OpCompat("transpose2")) - .AddInput("X") - .IsTensor() - .End() - .AddOutput("Out") - .IsTensor() - .End() - .AddOutput("XShape") - .IsOptional() - .IsTensor() - .End() - .AddAttr("axis") - .IsType>() - .End(); +void ReshapeTransposeMatmulMkldnnFusePass::ApplyImpl(Graph *graph) const { + auto matmul_types = {"matmul", "matmul_v2"}; + bool with_reshape_xshape = true; + bool with_transpose_xshape = true; - AddOpCompat(OpCompat(op_name_)) - .AddInput("X") - .IsTensor() - .End() - .AddInput("Y") - .IsTensor() - .End() - .AddOutput("Out") - .IsTensor() - .End() - .AddAttr("alpha") - .IsType() - .End() - .AddAttr("transpose_X") - .IsType() - .End() - .AddAttr("transpose_Y") - .IsType() - .End(); + for (const auto &matmul_type : matmul_types) { + Fuse(graph, matmul_type, with_reshape_xshape, with_transpose_xshape); + Fuse(graph, matmul_type, with_reshape_xshape, !with_transpose_xshape); + Fuse(graph, matmul_type, !with_reshape_xshape, with_transpose_xshape); + Fuse(graph, matmul_type, !with_reshape_xshape, !with_transpose_xshape); + } } void ReshapeTransposeMatmulMkldnnFusePass::Fuse( - Graph *graph, bool with_reshape_xshape, bool with_transpose_xshape) const { + Graph *graph, + const std::string &matmul_type, + bool with_reshape_xshape, + bool with_transpose_xshape) const { + PADDLE_ENFORCE_NOT_NULL(graph, + platform::errors::InvalidArgument( + "Pointer to graph argument should not be NULL.")); + FusePassBase::Init("reshape_transpose_" + matmul_type + "_mkldnn_fuse_pass", + graph); + GraphPatternDetector gpd; - patterns::ReshapeTransposeMatmulPattern rtm_pattern(gpd.mutable_pattern(), - name_scope_); + patterns::ReshapeTransposeMatmulPattern rtm_pattern( + gpd.mutable_pattern(), + "reshape_transpose_" + matmul_type + "_mkldnn_fuse_pass"); - rtm_pattern(op_name_, with_reshape_xshape, with_transpose_xshape); + rtm_pattern(matmul_type, with_reshape_xshape, with_transpose_xshape); int found_reshape_transpose_matmul_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, Graph *g) { if (!IsCompat(subgraph, g)) { - LOG(WARNING) << "Op compatible check in reshape_transpose_" << op_name_ + LOG(WARNING) << "Op compatible check in reshape_transpose_" << matmul_type << "_mkldnn_fuse_pass failed."; return; } - VLOG(4) << "handle reshape_transpose_" << op_name_ << " fuse"; + GET_IR_NODE_FROM_SUBGRAPH(reshape_in, reshape_in, rtm_pattern); GET_IR_NODE_FROM_SUBGRAPH(reshape_op, reshape_op, rtm_pattern); GET_IR_NODE_FROM_SUBGRAPH(reshape_out, reshape_out, rtm_pattern); @@ -137,7 +100,7 @@ void ReshapeTransposeMatmulMkldnnFusePass::Fuse( UpdateMatmul("Y"); } else { throw platform::errors::InvalidArgument("Unexpected input to " + - op_name_ + " encountered."); + matmul_type + " encountered."); } std::unordered_set nodes_to_remove{ @@ -153,26 +116,85 @@ void ReshapeTransposeMatmulMkldnnFusePass::Fuse( gpd(graph, handler); AddStatis(found_reshape_transpose_matmul_count); - if (!Has("disable_logs") || !Get("disable_logs")) { + if ((!Has("disable_logs") || !Get("disable_logs")) && + found_reshape_transpose_matmul_count > 0) { std::stringstream msg_ss; - msg_ss << "--- Fused " << found_reshape_transpose_matmul_count - << " ReshapeTransposeMatmul patterns for " << op_name_ << " Op"; + msg_ss << "--- fused " << found_reshape_transpose_matmul_count + << " reshape + transpose + " << matmul_type; if (with_reshape_xshape) msg_ss << " with reshape's xshape"; if (with_transpose_xshape) msg_ss << " with transpose's xshape"; string::PrettyLogDetail(msg_ss.str().c_str()); } } -void ReshapeTransposeMatmulMkldnnFusePass::ApplyImpl(ir::Graph *graph) const { - PADDLE_ENFORCE_NOT_NULL(graph, - platform::errors::InvalidArgument( - "Pointer to graph argument should not be NULL.")); - FusePassBase::Init(name_scope_, graph); +ReshapeTransposeMatmulMkldnnFusePass::ReshapeTransposeMatmulMkldnnFusePass() { + AddOpCompat(OpCompat("reshape2")) + .AddInput("X") + .IsTensor() + .End() + // The reshape2 op for this pass should not have "Shape" and "ShapeTensor" + .AddOutput("Out") + .IsTensor() + .End() + .AddOutput("XShape") + .IsOptional() + .IsTensor() + .End() + .AddAttr("shape") + .IsType>() + .End(); + + AddOpCompat(OpCompat("transpose2")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddOutput("XShape") + .IsOptional() + .IsTensor() + .End() + .AddAttr("axis") + .IsType>() + .End(); + + AddOpCompat(OpCompat("matmul")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("alpha") + .IsType() + .End() + .AddAttr("transpose_X") + .IsType() + .End() + .AddAttr("transpose_Y") + .IsType() + .End(); - Fuse(graph, false, false); - Fuse(graph, false, true); - Fuse(graph, true, false); - Fuse(graph, true, true); + AddOpCompat(OpCompat("matmul_v2")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("trans_x") + .IsType() + .End() + .AddAttr("trans_y") + .IsType() + .End(); } } // namespace ir @@ -184,5 +206,8 @@ REGISTER_PASS(reshape_transpose_matmul_mkldnn_fuse_pass, REGISTER_PASS_CAPABILITY(reshape_transpose_matmul_mkldnn_fuse_pass) .AddCombination( - paddle::framework::compatible::OpVersionComparatorCombination().EQ( - "matmul", 1)); + paddle::framework::compatible::OpVersionComparatorCombination() + .EQ("reshape2", 0) + .EQ("transpose2", 0) + .EQ("matmul", 1) + .EQ("matmul_v2", 0)); diff --git a/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.h index 187bfe0650a..4b595837b23 100644 --- a/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.h @@ -1,4 +1,4 @@ -// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -13,17 +13,11 @@ // limitations under the License. #pragma once - -#include - #include "paddle/fluid/framework/ir/fuse_pass_base.h" namespace paddle { namespace framework { namespace ir { -/* - * Fuse Reshape->Transpose->MatMul when MatMul uses mkldnn. - */ class ReshapeTransposeMatmulMkldnnFusePass : public FusePassBase { public: @@ -31,13 +25,11 @@ class ReshapeTransposeMatmulMkldnnFusePass : public FusePassBase { virtual ~ReshapeTransposeMatmulMkldnnFusePass() {} protected: - void ApplyImpl(ir::Graph* graph) const override; - const std::string name_scope_{"reshape_transpose_matmul_fuse"}; - + void ApplyImpl(Graph* graph) const override; void Fuse(Graph* graph, + const std::string& matmul_type, bool with_reshape_xshape, bool with_transpose_xshape) const; - std::string op_name_; }; } // namespace ir } // namespace framework diff --git a/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass_tester.cc index 369ceec934e..79164a32098 100644 --- a/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass_tester.cc @@ -15,7 +15,6 @@ #include #include "paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.h" -#include "paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_v2_mkldnn_fuse_pass.h" #include "paddle/fluid/framework/ir/pass_tester_helper.h" namespace paddle { @@ -82,8 +81,8 @@ void TestMain(const std::string& op_name, bool with_xshapes) { int total_nodes_before = graph->Nodes().size(); VLOG(3) << DebugString(graph); - auto pass = PassRegistry::Instance().Get("reshape_transpose_" + op_name + - "_mkldnn_fuse_pass"); + auto pass = + PassRegistry::Instance().Get("reshape_transpose_matmul_mkldnn_fuse_pass"); graph.reset(pass->Apply(graph.release())); int num_reshape_nodes_after = GetNumOpNodes(graph, "reshape2"); @@ -137,4 +136,3 @@ TEST(ReshapeTransposeMatmulV2MkldnnFusePass, } // namespace paddle USE_PASS(reshape_transpose_matmul_mkldnn_fuse_pass); -USE_PASS(reshape_transpose_matmul_v2_mkldnn_fuse_pass); diff --git a/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_v2_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_v2_mkldnn_fuse_pass.cc deleted file mode 100644 index ed57be12c78..00000000000 --- a/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_v2_mkldnn_fuse_pass.cc +++ /dev/null @@ -1,93 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_v2_mkldnn_fuse_pass.h" - -#include -#include -#include - -#include "paddle/fluid/framework/op_version_registry.h" -#include "paddle/fluid/platform/enforce.h" -#include "paddle/fluid/string/pretty_log.h" - -namespace paddle { -namespace framework { -namespace ir { - -ReshapeTransposeMatmulV2MkldnnFusePass:: - ReshapeTransposeMatmulV2MkldnnFusePass() { - op_name_ = "matmul_v2"; - - AddOpCompat(OpCompat("reshape2")) - .AddInput("X") - .IsTensor() - .End() - // The reshape2 op for this pass should not have "Shape" and "ShapeTensor" - .AddOutput("Out") - .IsTensor() - .End() - .AddOutput("XShape") - .IsOptional() - .IsTensor() - .End() - .AddAttr("shape") - .IsType>() - .End(); - - AddOpCompat(OpCompat("transpose2")) - .AddInput("X") - .IsTensor() - .End() - .AddOutput("Out") - .IsTensor() - .End() - .AddOutput("XShape") - .IsOptional() - .IsTensor() - .End() - .AddAttr("axis") - .IsType>() - .End(); - - AddOpCompat(OpCompat(op_name_)) - .AddInput("X") - .IsTensor() - .End() - .AddInput("Y") - .IsTensor() - .End() - .AddOutput("Out") - .IsTensor() - .End() - .AddAttr("trans_x") - .IsType() - .End() - .AddAttr("trans_y") - .IsType() - .End(); -} -} // namespace ir -} // namespace framework -} // namespace paddle - -REGISTER_PASS(reshape_transpose_matmul_v2_mkldnn_fuse_pass, - paddle::framework::ir::ReshapeTransposeMatmulV2MkldnnFusePass); - -REGISTER_PASS_CAPABILITY(reshape_transpose_matmul_v2_mkldnn_fuse_pass) - .AddCombination( - paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("matmul_v2", 0) - .EQ("transpose2", 0) - .EQ("reshape2", 0)); diff --git a/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_v2_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_v2_mkldnn_fuse_pass.h deleted file mode 100644 index 7eeda7f1a61..00000000000 --- a/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_v2_mkldnn_fuse_pass.h +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include - -#include "paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.h" - -namespace paddle { -namespace framework { -namespace ir { -/* - * Fuse Reshape->Transpose->MatMulV2 when MatMulV2 uses mkldnn. - */ - -class ReshapeTransposeMatmulV2MkldnnFusePass - : public ReshapeTransposeMatmulMkldnnFusePass { - public: - ReshapeTransposeMatmulV2MkldnnFusePass(); - virtual ~ReshapeTransposeMatmulV2MkldnnFusePass() {} - - protected: - const std::string name_scope_{"reshape_transpose_matmul_v2_fuse"}; -}; -} // namespace ir -} // namespace framework -} // namespace paddle diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 235fd99535f..6119714c38c 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -299,14 +299,12 @@ void CpuPassStrategy::EnableMKLDNN() { // "conv3d_bias_mkldnn_fuse_pass", // "conv_elementwise_add_mkldnn_fuse_pass", "conv_concat_relu_mkldnn_fuse_pass", - "conv_activation_mkldnn_fuse_pass", // - "scale_matmul_fuse_pass", // - "reshape_transpose_matmul_mkldnn_fuse_pass", // - "reshape_transpose_matmul_v2_mkldnn_fuse_pass", // - "matmul_transpose_reshape_fuse_pass", // - "matmul_v2_transpose_reshape_fuse_pass", // - "matmul_elementwise_add_mkldnn_fuse_pass", // - "matmul_activation_mkldnn_fuse_pass", // + "conv_activation_mkldnn_fuse_pass", // + "scale_matmul_fuse_pass", // + "reshape_transpose_matmul_mkldnn_fuse_pass", // + "matmul_transpose_reshape_mkldnn_fuse_pass", // + "matmul_elementwise_add_mkldnn_fuse_pass", // + "matmul_activation_mkldnn_fuse_pass", // // Disabled due to topology-dependent speed-up // "fc_mkldnn_pass", // "fc_act_mkldnn_fuse_pass", @@ -399,14 +397,12 @@ void CpuPassStrategy::EnableMkldnnInt8() { passes_.push_back("repeated_fc_relu_fuse_pass"); passes_.push_back("fc_mkldnn_pass"); passes_.push_back("fc_act_mkldnn_fuse_pass"); - passes_.push_back("matmul_transpose_reshape_fuse_pass"); - passes_.push_back("matmul_v2_transpose_reshape_fuse_pass"); + passes_.push_back("matmul_transpose_reshape_mkldnn_fuse_pass"); passes_.push_back("batch_norm_act_fuse_pass"); passes_.push_back("softplus_activation_mkldnn_fuse_pass"); passes_.push_back("compute_propagate_scales_mkldnn_pass"); passes_.push_back("scale_matmul_fuse_pass"); passes_.push_back("reshape_transpose_matmul_mkldnn_fuse_pass"); - passes_.push_back("reshape_transpose_matmul_v2_mkldnn_fuse_pass"); passes_.push_back("cpu_quantize_placement_pass"); passes_.push_back("cpu_quantize_pass"); passes_.push_back("cpu_quantize_squash_pass"); diff --git a/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py b/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py index 2f155ca0edf..9fb14e4e720 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py @@ -448,8 +448,8 @@ class Quant2Int8MkldnnPass(object): # Disabled due to topology-dependent speed-up graph = self._apply_pass(graph, 'fc_mkldnn_pass') graph = self._apply_pass(graph, 'fc_act_mkldnn_fuse_pass') - graph = self._apply_pass(graph, 'matmul_transpose_reshape_fuse_pass') - graph = self._apply_pass(graph, 'matmul_v2_transpose_reshape_fuse_pass') + graph = self._apply_pass(graph, + 'matmul_transpose_reshape_mkldnn_fuse_pass') graph = self._apply_pass(graph, 'batch_norm_act_fuse_pass') graph = self._apply_pass(graph, 'softplus_activation_mkldnn_fuse_pass') # the following pass should be the last one since it will work on all fused ops. @@ -650,8 +650,6 @@ class Quant2Int8MkldnnPass(object): graph = self._apply_pass(graph, 'scale_matmul_fuse_pass') graph = self._apply_pass(graph, 'reshape_transpose_matmul_mkldnn_fuse_pass') - graph = self._apply_pass( - graph, 'reshape_transpose_matmul_v2_mkldnn_fuse_pass') graph = self._apply_pass(graph, 'cpu_quantize_placement_pass', ['quantize_enabled_op_types'], [self._ops_to_quantize]) diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_transpose_reshape_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_transpose_reshape_fuse_pass.py index a5471eca2c2..c8fb49c10c1 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_transpose_reshape_fuse_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_transpose_reshape_fuse_pass.py @@ -121,8 +121,8 @@ class TestMatmulTransposeReshapeMkldnnFusePass(PassAutoScanTest): yield config, ["matmul"], (1e-5, 1e-5) def test(self): - self.run_and_statis(quant=False, - passes=["matmul_transpose_reshape_fuse_pass"]) + self.run_and_statis( + quant=False, passes=["matmul_transpose_reshape_mkldnn_fuse_pass"]) if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_v2_transpose_reshape_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_v2_transpose_reshape_fuse_pass.py index 28fe916a6ef..0e24c4a394f 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_v2_transpose_reshape_fuse_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_v2_transpose_reshape_fuse_pass.py @@ -142,8 +142,8 @@ class TestMatmulv2TransposeReshapeMkldnnFusePass(PassAutoScanTest): yield config, [fused_op], (1e-5, 1e-5) def test(self): - self.run_and_statis(quant=False, - passes=["matmul_v2_transpose_reshape_fuse_pass"]) + self.run_and_statis( + quant=False, passes=["matmul_transpose_reshape_mkldnn_fuse_pass"]) if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_reshape_transpose_matmul_v2_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_reshape_transpose_matmul_v2_fuse_pass.py index fb8dc034bd5..cc699a5e27a 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_reshape_transpose_matmul_v2_fuse_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_reshape_transpose_matmul_v2_fuse_pass.py @@ -29,7 +29,7 @@ class TestReshapeTransposeMatmulV2OneDNNFusePass(InferencePassTest): def setUp(self): self.set_params() self.tranpose_perm = [0, 2, 1, 3] - self.pass_name = 'reshape_transpose_matmul_v2_mkldnn_fuse_pass' + self.pass_name = 'reshape_transpose_matmul_mkldnn_fuse_pass' with fluid.program_guard(self.main_program, self.startup_program): data = fluid.data(name="data", -- GitLab