未验证 提交 d0cf9d9d 编写于 作者: S Sławomir Siwek 提交者: GitHub

Merge matmul_v1 and matmul_v2 fuse passes (#44870)

* remove v2_transpose_reshape

* matmul_transpose_reshape

* reshape_transpose_matmul

* restore ut

* adjust old ut

* restore parallel UT ruels

* feedback from review
上级 1f7e9546
...@@ -218,10 +218,7 @@ if(WITH_MKLDNN) ...@@ -218,10 +218,7 @@ if(WITH_MKLDNN)
pass_library(cpu_quantize_pass inference DIR mkldnn) pass_library(cpu_quantize_pass inference DIR mkldnn)
pass_library(cpu_quantize_squash_pass inference DIR mkldnn) pass_library(cpu_quantize_squash_pass inference DIR mkldnn)
pass_library(reshape_transpose_matmul_mkldnn_fuse_pass inference DIR mkldnn) pass_library(reshape_transpose_matmul_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(reshape_transpose_matmul_v2_mkldnn_fuse_pass inference DIR pass_library(matmul_transpose_reshape_mkldnn_fuse_pass inference DIR mkldnn)
mkldnn)
pass_library(matmul_transpose_reshape_fuse_pass inference DIR mkldnn)
pass_library(matmul_v2_transpose_reshape_fuse_pass inference DIR mkldnn)
pass_library(batch_norm_act_fuse_pass inference DIR mkldnn) pass_library(batch_norm_act_fuse_pass inference DIR mkldnn)
pass_library(multi_gru_fuse_pass inference DIR mkldnn) pass_library(multi_gru_fuse_pass inference DIR mkldnn)
pass_library(multi_gru_seq_fuse_pass inference DIR mkldnn) pass_library(multi_gru_seq_fuse_pass inference DIR mkldnn)
...@@ -497,13 +494,11 @@ if(WITH_MKLDNN) ...@@ -497,13 +494,11 @@ if(WITH_MKLDNN)
cc_test( cc_test(
test_reshape_transpose_matmul_mkldnn_fuse_pass test_reshape_transpose_matmul_mkldnn_fuse_pass
SRCS mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass_tester.cc SRCS mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass_tester.cc
DEPS reshape_transpose_matmul_mkldnn_fuse_pass DEPS reshape_transpose_matmul_mkldnn_fuse_pass)
reshape_transpose_matmul_v2_mkldnn_fuse_pass)
cc_test( cc_test(
test_matmul_transpose_reshape_fuse_pass test_matmul_transpose_reshape_fuse_pass
SRCS mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc SRCS mkldnn/matmul_transpose_reshape_mkldnn_fuse_pass_tester.cc
DEPS matmul_transpose_reshape_fuse_pass DEPS matmul_transpose_reshape_mkldnn_fuse_pass)
matmul_v2_transpose_reshape_fuse_pass)
cc_test( cc_test(
test_shuffle_channel_mkldnn_detect_pass test_shuffle_channel_mkldnn_detect_pass
SRCS mkldnn/shuffle_channel_mkldnn_detect_pass_tester.cc SRCS mkldnn/shuffle_channel_mkldnn_detect_pass_tester.cc
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -12,12 +12,8 @@ ...@@ -12,12 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.h" #include "paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_mkldnn_fuse_pass.h"
#include <paddle/fluid/string/pretty_log.h> #include <paddle/fluid/string/pretty_log.h>
#include <vector>
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
...@@ -25,76 +21,28 @@ namespace paddle { ...@@ -25,76 +21,28 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
MatmulTransposeReshapeMKLDNNPass::MatmulTransposeReshapeMKLDNNPass() { using string::PrettyLogDetail;
op_name_ = "matmul";
AddOpCompat(OpCompat(op_name_)) void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(Graph *graph) const {
.AddInput("X") auto matmul_types = {"matmul", "matmul_v2"};
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("alpha") // unconstrained. can be any float value.
.IsType<float>()
.End()
.AddAttr("transpose_X") // unconstrained. can be any bool value.
.IsType<bool>()
.End()
.AddAttr("transpose_Y") // unconstrained. can be any bool value.
.IsType<bool>()
.End();
AddOpCompat(OpCompat("transpose2")) for (const auto &matmul_type : matmul_types) {
.AddInput("X") Fuse(graph, matmul_type);
.IsTensor() }
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsTensor()
.End()
.AddAttr("axis") // ints
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("reshape2"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Shape")
.IsTensor()
.IsOptional()
.End()
.AddInput("ShapeTensor")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsTensor()
.End()
.AddAttr("shape") // ints
.IsType<std::vector<int>>()
.End();
} }
void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const {
void MatmulTransposeReshapeMKLDNNPass::Fuse(
Graph *graph, const std::string &matmul_type) const {
PADDLE_ENFORCE_NOT_NULL(graph, PADDLE_ENFORCE_NOT_NULL(graph,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Pointer to graph argument should not be NULL.")); "Pointer to graph argument should not be NULL."));
FusePassBase::Init(name_scope_, graph); FusePassBase::Init(matmul_type + "_transpose_reshape_mkldnn_fuse_pass",
graph);
GraphPatternDetector gpd; GraphPatternDetector gpd;
patterns::MatmulTransposeReshapePattern mtrp(gpd.mutable_pattern(), patterns::MatmulTransposeReshapePattern mtrp(
name_scope_); gpd.mutable_pattern(),
matmul_type + "_transpose_reshape_mkldnn_fuse_pass");
mtrp(op_name_); mtrp(matmul_type);
int found_matmul_transpose_reshape_count = 0; int found_matmul_transpose_reshape_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
...@@ -103,7 +51,7 @@ void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const { ...@@ -103,7 +51,7 @@ void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const {
LOG(WARNING) << "Pass in op compat failed."; LOG(WARNING) << "Pass in op compat failed.";
return; return;
} }
VLOG(4) << "handle " + op_name_ + "_transpose_reshape fuse";
GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, mtrp); GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, mtrp);
GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, mtrp); GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, mtrp);
GET_IR_NODE_FROM_SUBGRAPH(transpose_op, transpose_op, mtrp); GET_IR_NODE_FROM_SUBGRAPH(transpose_op, transpose_op, mtrp);
...@@ -112,6 +60,7 @@ void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const { ...@@ -112,6 +60,7 @@ void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const {
GET_IR_NODE_FROM_SUBGRAPH(reshape_op, reshape_op, mtrp); GET_IR_NODE_FROM_SUBGRAPH(reshape_op, reshape_op, mtrp);
GET_IR_NODE_FROM_SUBGRAPH(reshape_out, reshape_out, mtrp); GET_IR_NODE_FROM_SUBGRAPH(reshape_out, reshape_out, mtrp);
GET_IR_NODE_FROM_SUBGRAPH(reshape_out_xshape, reshape_out_xshape, mtrp); GET_IR_NODE_FROM_SUBGRAPH(reshape_out_xshape, reshape_out_xshape, mtrp);
auto reshape_shape = auto reshape_shape =
PADDLE_GET_CONST(std::vector<int>, reshape_op->Op()->GetAttr("shape")); PADDLE_GET_CONST(std::vector<int>, reshape_op->Op()->GetAttr("shape"));
auto transpose_axis = auto transpose_axis =
...@@ -123,17 +72,17 @@ void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const { ...@@ -123,17 +72,17 @@ void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const {
const bool supported_transpose_axis = std::equal( const bool supported_transpose_axis = std::equal(
transpose_axis.begin(), transpose_axis.end(), supported_axis.begin()); transpose_axis.begin(), transpose_axis.end(), supported_axis.begin());
if (transpose_out_size != 4) { if (transpose_out_size != 4) {
VLOG(3) << "do not perform " + op_name_ + "_transpose_reshape fuse: " VLOG(3) << "do not perform " + matmul_type + "_transpose_reshape fuse: "
<< "supported rank is 4, received " << transpose_out_size; << "supported rank is 4, received " << transpose_out_size;
return; return;
} }
if (!supported_transpose_axis) { if (!supported_transpose_axis) {
VLOG(3) << "do not perform " + op_name_ + "_transpose_reshape fuse: " VLOG(3) << "do not perform " + matmul_type + "_transpose_reshape fuse: "
<< "supported transpose axis for the fuse are {0, 2, 1, 3}"; << "supported transpose axis for the fuse are {0, 2, 1, 3}";
return; return;
} }
if (reshape_out_size != 3) { if (reshape_out_size != 3) {
VLOG(3) << "do not perform " + op_name_ + "_transpose_reshape fuse: " VLOG(3) << "do not perform " + matmul_type + "_transpose_reshape fuse: "
<< "reshape_out supported rank is 3, received " << "reshape_out supported rank is 3, received "
<< reshape_out_size; << reshape_out_size;
return; return;
...@@ -158,23 +107,93 @@ void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const { ...@@ -158,23 +107,93 @@ void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const {
gpd(graph, handler); gpd(graph, handler);
AddStatis(found_matmul_transpose_reshape_count); AddStatis(found_matmul_transpose_reshape_count);
if (!Has("disable_logs") || !Get<bool>("disable_logs")) { if ((!Has("disable_logs") || !Get<bool>("disable_logs")) &&
std::stringstream msg_ss; found_matmul_transpose_reshape_count > 0) {
msg_ss << "--- Fused " << found_matmul_transpose_reshape_count PrettyLogDetail("--- fused %d %s + transpose + reshape patterns",
<< " MatmulTransposeReshape patterns for " + op_name_ + " Op"; found_matmul_transpose_reshape_count,
paddle::string::PrettyLogDetail(msg_ss.str().c_str()); matmul_type);
} }
} }
MatmulTransposeReshapeMKLDNNPass::MatmulTransposeReshapeMKLDNNPass() {
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("alpha")
.IsType<float>()
.End()
.AddAttr("transpose_X")
.IsType<bool>()
.End()
.AddAttr("transpose_Y")
.IsType<bool>()
.End();
AddOpCompat(OpCompat("matmul_v2"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("trans_x")
.IsType<bool>()
.End()
.AddAttr("trans_y")
.IsType<bool>()
.End();
AddOpCompat(OpCompat("transpose2"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsTensor()
.End()
.AddAttr("axis")
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("reshape2"))
.AddInput("X")
.IsTensor()
.End()
// The reshape2 op for this pass should not have "Shape" and "ShapeTensor"
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsTensor()
.End()
.AddAttr("shape")
.IsType<std::vector<int>>()
.End();
}
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
REGISTER_PASS(matmul_transpose_reshape_fuse_pass, REGISTER_PASS(matmul_transpose_reshape_mkldnn_fuse_pass,
paddle::framework::ir::MatmulTransposeReshapeMKLDNNPass); paddle::framework::ir::MatmulTransposeReshapeMKLDNNPass);
REGISTER_PASS_CAPABILITY(matmul_transpose_reshape_fuse_pass) REGISTER_PASS_CAPABILITY(matmul_transpose_reshape_mkldnn_fuse_pass)
.AddCombination( .AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination() paddle::framework::compatible::OpVersionComparatorCombination()
.LE("matmul", 1) .LE("matmul", 1)
.EQ("transpose", 0) .EQ("matmul_v2", 0)
.EQ("reshape", 0)); .EQ("transpose2", 0)
.EQ("reshape2", 0));
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -14,14 +14,11 @@ ...@@ -14,14 +14,11 @@
#pragma once #pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
class Graph;
class MatmulTransposeReshapeMKLDNNPass : public FusePassBase { class MatmulTransposeReshapeMKLDNNPass : public FusePassBase {
public: public:
...@@ -29,10 +26,10 @@ class MatmulTransposeReshapeMKLDNNPass : public FusePassBase { ...@@ -29,10 +26,10 @@ class MatmulTransposeReshapeMKLDNNPass : public FusePassBase {
virtual ~MatmulTransposeReshapeMKLDNNPass() {} virtual ~MatmulTransposeReshapeMKLDNNPass() {}
protected: protected:
void ApplyImpl(Graph* graph) const override; void ApplyImpl(Graph *graph) const override;
const std::string name_scope_{"matmul_transpose_reshape_fuse"}; void Fuse(Graph *graph, const std::string &matmul_type) const;
std::string op_name_;
}; };
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/mkldnn/matmul_v2_transpose_reshape_fuse_pass.h" #include "paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_mkldnn_fuse_pass.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -74,7 +74,7 @@ void MainTest(const ProgramDesc &prog, const std::string &op_name) { ...@@ -74,7 +74,7 @@ void MainTest(const ProgramDesc &prog, const std::string &op_name) {
int original_nodes_num = graph->Nodes().size(); int original_nodes_num = graph->Nodes().size();
auto pass = auto pass =
PassRegistry::Instance().Get(op_name + "_transpose_reshape_fuse_pass"); PassRegistry::Instance().Get("matmul_transpose_reshape_mkldnn_fuse_pass");
graph.reset(pass->Apply(graph.release())); graph.reset(pass->Apply(graph.release()));
int current_nodes_num = graph->Nodes().size(); int current_nodes_num = graph->Nodes().size();
...@@ -106,5 +106,4 @@ TEST(MatmulTransposeReshapeFusePass, matmul_v2_fuse_pass) { ...@@ -106,5 +106,4 @@ TEST(MatmulTransposeReshapeFusePass, matmul_v2_fuse_pass) {
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
USE_PASS(matmul_transpose_reshape_fuse_pass); USE_PASS(matmul_transpose_reshape_mkldnn_fuse_pass);
USE_PASS(matmul_v2_transpose_reshape_fuse_pass);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/matmul_v2_transpose_reshape_fuse_pass.h"
#include <vector>
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace ir {
MatmulV2TransposeReshapeMKLDNNPass::MatmulV2TransposeReshapeMKLDNNPass() {
op_name_ = "matmul_v2";
AddOpCompat(OpCompat(op_name_))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("trans_x")
.IsType<bool>()
.End()
.AddAttr("trans_y")
.IsType<bool>()
.End();
AddOpCompat(OpCompat("transpose2"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsTensor()
.End()
.AddAttr("axis")
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("reshape2"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Shape")
.IsTensor()
.IsOptional()
.End()
.AddInput("ShapeTensor")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsTensor()
.End()
.AddAttr("shape")
.IsType<std::vector<int>>()
.End();
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(matmul_v2_transpose_reshape_fuse_pass,
paddle::framework::ir::MatmulV2TransposeReshapeMKLDNNPass);
REGISTER_PASS_CAPABILITY(matmul_v2_transpose_reshape_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("matmul_v2", 0)
.EQ("transpose2", 0)
.EQ("reshape2", 0));
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.h"
namespace paddle {
namespace framework {
namespace ir {
class MatmulV2TransposeReshapeMKLDNNPass
: public MatmulTransposeReshapeMKLDNNPass {
public:
MatmulV2TransposeReshapeMKLDNNPass();
virtual ~MatmulV2TransposeReshapeMKLDNNPass() {}
protected:
const std::string name_scope_{"matmul_v2_transpose_reshape_fuse"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -13,11 +13,6 @@ ...@@ -13,11 +13,6 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.h" #include "paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.h"
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/pretty_log.h" #include "paddle/fluid/string/pretty_log.h"
...@@ -26,78 +21,46 @@ namespace paddle { ...@@ -26,78 +21,46 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
ReshapeTransposeMatmulMkldnnFusePass::ReshapeTransposeMatmulMkldnnFusePass() { void ReshapeTransposeMatmulMkldnnFusePass::ApplyImpl(Graph *graph) const {
op_name_ = "matmul"; auto matmul_types = {"matmul", "matmul_v2"};
bool with_reshape_xshape = true;
AddOpCompat(OpCompat("reshape2")) bool with_transpose_xshape = true;
.AddInput("X")
.IsTensor()
.End()
// The reshape2 op for this pass should not have "Shape" and "ShapeTensor"
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("shape")
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("transpose2"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("axis")
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat(op_name_)) for (const auto &matmul_type : matmul_types) {
.AddInput("X") Fuse(graph, matmul_type, with_reshape_xshape, with_transpose_xshape);
.IsTensor() Fuse(graph, matmul_type, with_reshape_xshape, !with_transpose_xshape);
.End() Fuse(graph, matmul_type, !with_reshape_xshape, with_transpose_xshape);
.AddInput("Y") Fuse(graph, matmul_type, !with_reshape_xshape, !with_transpose_xshape);
.IsTensor() }
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("alpha")
.IsType<float>()
.End()
.AddAttr("transpose_X")
.IsType<bool>()
.End()
.AddAttr("transpose_Y")
.IsType<bool>()
.End();
} }
void ReshapeTransposeMatmulMkldnnFusePass::Fuse( void ReshapeTransposeMatmulMkldnnFusePass::Fuse(
Graph *graph, bool with_reshape_xshape, bool with_transpose_xshape) const { Graph *graph,
const std::string &matmul_type,
bool with_reshape_xshape,
bool with_transpose_xshape) const {
PADDLE_ENFORCE_NOT_NULL(graph,
platform::errors::InvalidArgument(
"Pointer to graph argument should not be NULL."));
FusePassBase::Init("reshape_transpose_" + matmul_type + "_mkldnn_fuse_pass",
graph);
GraphPatternDetector gpd; GraphPatternDetector gpd;
patterns::ReshapeTransposeMatmulPattern rtm_pattern(gpd.mutable_pattern(), patterns::ReshapeTransposeMatmulPattern rtm_pattern(
name_scope_); gpd.mutable_pattern(),
"reshape_transpose_" + matmul_type + "_mkldnn_fuse_pass");
rtm_pattern(op_name_, with_reshape_xshape, with_transpose_xshape); rtm_pattern(matmul_type, with_reshape_xshape, with_transpose_xshape);
int found_reshape_transpose_matmul_count = 0; int found_reshape_transpose_matmul_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) { Graph *g) {
if (!IsCompat(subgraph, g)) { if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Op compatible check in reshape_transpose_" << op_name_ LOG(WARNING) << "Op compatible check in reshape_transpose_" << matmul_type
<< "_mkldnn_fuse_pass failed."; << "_mkldnn_fuse_pass failed.";
return; return;
} }
VLOG(4) << "handle reshape_transpose_" << op_name_ << " fuse";
GET_IR_NODE_FROM_SUBGRAPH(reshape_in, reshape_in, rtm_pattern); GET_IR_NODE_FROM_SUBGRAPH(reshape_in, reshape_in, rtm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape_op, reshape_op, rtm_pattern); GET_IR_NODE_FROM_SUBGRAPH(reshape_op, reshape_op, rtm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape_out, reshape_out, rtm_pattern); GET_IR_NODE_FROM_SUBGRAPH(reshape_out, reshape_out, rtm_pattern);
...@@ -137,7 +100,7 @@ void ReshapeTransposeMatmulMkldnnFusePass::Fuse( ...@@ -137,7 +100,7 @@ void ReshapeTransposeMatmulMkldnnFusePass::Fuse(
UpdateMatmul("Y"); UpdateMatmul("Y");
} else { } else {
throw platform::errors::InvalidArgument("Unexpected input to " + throw platform::errors::InvalidArgument("Unexpected input to " +
op_name_ + " encountered."); matmul_type + " encountered.");
} }
std::unordered_set<const ir::Node *> nodes_to_remove{ std::unordered_set<const ir::Node *> nodes_to_remove{
...@@ -153,26 +116,85 @@ void ReshapeTransposeMatmulMkldnnFusePass::Fuse( ...@@ -153,26 +116,85 @@ void ReshapeTransposeMatmulMkldnnFusePass::Fuse(
gpd(graph, handler); gpd(graph, handler);
AddStatis(found_reshape_transpose_matmul_count); AddStatis(found_reshape_transpose_matmul_count);
if (!Has("disable_logs") || !Get<bool>("disable_logs")) { if ((!Has("disable_logs") || !Get<bool>("disable_logs")) &&
found_reshape_transpose_matmul_count > 0) {
std::stringstream msg_ss; std::stringstream msg_ss;
msg_ss << "--- Fused " << found_reshape_transpose_matmul_count msg_ss << "--- fused " << found_reshape_transpose_matmul_count
<< " ReshapeTransposeMatmul patterns for " << op_name_ << " Op"; << " reshape + transpose + " << matmul_type;
if (with_reshape_xshape) msg_ss << " with reshape's xshape"; if (with_reshape_xshape) msg_ss << " with reshape's xshape";
if (with_transpose_xshape) msg_ss << " with transpose's xshape"; if (with_transpose_xshape) msg_ss << " with transpose's xshape";
string::PrettyLogDetail(msg_ss.str().c_str()); string::PrettyLogDetail(msg_ss.str().c_str());
} }
} }
void ReshapeTransposeMatmulMkldnnFusePass::ApplyImpl(ir::Graph *graph) const { ReshapeTransposeMatmulMkldnnFusePass::ReshapeTransposeMatmulMkldnnFusePass() {
PADDLE_ENFORCE_NOT_NULL(graph, AddOpCompat(OpCompat("reshape2"))
platform::errors::InvalidArgument( .AddInput("X")
"Pointer to graph argument should not be NULL.")); .IsTensor()
FusePassBase::Init(name_scope_, graph); .End()
// The reshape2 op for this pass should not have "Shape" and "ShapeTensor"
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("shape")
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("transpose2"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("axis")
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("alpha")
.IsType<float>()
.End()
.AddAttr("transpose_X")
.IsType<bool>()
.End()
.AddAttr("transpose_Y")
.IsType<bool>()
.End();
Fuse(graph, false, false); AddOpCompat(OpCompat("matmul_v2"))
Fuse(graph, false, true); .AddInput("X")
Fuse(graph, true, false); .IsTensor()
Fuse(graph, true, true); .End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("trans_x")
.IsType<bool>()
.End()
.AddAttr("trans_y")
.IsType<bool>()
.End();
} }
} // namespace ir } // namespace ir
...@@ -184,5 +206,8 @@ REGISTER_PASS(reshape_transpose_matmul_mkldnn_fuse_pass, ...@@ -184,5 +206,8 @@ REGISTER_PASS(reshape_transpose_matmul_mkldnn_fuse_pass,
REGISTER_PASS_CAPABILITY(reshape_transpose_matmul_mkldnn_fuse_pass) REGISTER_PASS_CAPABILITY(reshape_transpose_matmul_mkldnn_fuse_pass)
.AddCombination( .AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().EQ( paddle::framework::compatible::OpVersionComparatorCombination()
"matmul", 1)); .EQ("reshape2", 0)
.EQ("transpose2", 0)
.EQ("matmul", 1)
.EQ("matmul_v2", 0));
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -13,17 +13,11 @@ ...@@ -13,17 +13,11 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
/*
* Fuse Reshape->Transpose->MatMul when MatMul uses mkldnn.
*/
class ReshapeTransposeMatmulMkldnnFusePass : public FusePassBase { class ReshapeTransposeMatmulMkldnnFusePass : public FusePassBase {
public: public:
...@@ -31,13 +25,11 @@ class ReshapeTransposeMatmulMkldnnFusePass : public FusePassBase { ...@@ -31,13 +25,11 @@ class ReshapeTransposeMatmulMkldnnFusePass : public FusePassBase {
virtual ~ReshapeTransposeMatmulMkldnnFusePass() {} virtual ~ReshapeTransposeMatmulMkldnnFusePass() {}
protected: protected:
void ApplyImpl(ir::Graph* graph) const override; void ApplyImpl(Graph* graph) const override;
const std::string name_scope_{"reshape_transpose_matmul_fuse"};
void Fuse(Graph* graph, void Fuse(Graph* graph,
const std::string& matmul_type,
bool with_reshape_xshape, bool with_reshape_xshape,
bool with_transpose_xshape) const; bool with_transpose_xshape) const;
std::string op_name_;
}; };
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
......
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.h" #include "paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_v2_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h" #include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace paddle { namespace paddle {
...@@ -82,8 +81,8 @@ void TestMain(const std::string& op_name, bool with_xshapes) { ...@@ -82,8 +81,8 @@ void TestMain(const std::string& op_name, bool with_xshapes) {
int total_nodes_before = graph->Nodes().size(); int total_nodes_before = graph->Nodes().size();
VLOG(3) << DebugString(graph); VLOG(3) << DebugString(graph);
auto pass = PassRegistry::Instance().Get("reshape_transpose_" + op_name + auto pass =
"_mkldnn_fuse_pass"); PassRegistry::Instance().Get("reshape_transpose_matmul_mkldnn_fuse_pass");
graph.reset(pass->Apply(graph.release())); graph.reset(pass->Apply(graph.release()));
int num_reshape_nodes_after = GetNumOpNodes(graph, "reshape2"); int num_reshape_nodes_after = GetNumOpNodes(graph, "reshape2");
...@@ -137,4 +136,3 @@ TEST(ReshapeTransposeMatmulV2MkldnnFusePass, ...@@ -137,4 +136,3 @@ TEST(ReshapeTransposeMatmulV2MkldnnFusePass,
} // namespace paddle } // namespace paddle
USE_PASS(reshape_transpose_matmul_mkldnn_fuse_pass); USE_PASS(reshape_transpose_matmul_mkldnn_fuse_pass);
USE_PASS(reshape_transpose_matmul_v2_mkldnn_fuse_pass);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_v2_mkldnn_fuse_pass.h"
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/pretty_log.h"
namespace paddle {
namespace framework {
namespace ir {
ReshapeTransposeMatmulV2MkldnnFusePass::
ReshapeTransposeMatmulV2MkldnnFusePass() {
op_name_ = "matmul_v2";
AddOpCompat(OpCompat("reshape2"))
.AddInput("X")
.IsTensor()
.End()
// The reshape2 op for this pass should not have "Shape" and "ShapeTensor"
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("shape")
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("transpose2"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("axis")
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat(op_name_))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("trans_x")
.IsType<bool>()
.End()
.AddAttr("trans_y")
.IsType<bool>()
.End();
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(reshape_transpose_matmul_v2_mkldnn_fuse_pass,
paddle::framework::ir::ReshapeTransposeMatmulV2MkldnnFusePass);
REGISTER_PASS_CAPABILITY(reshape_transpose_matmul_v2_mkldnn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("matmul_v2", 0)
.EQ("transpose2", 0)
.EQ("reshape2", 0));
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.h"
namespace paddle {
namespace framework {
namespace ir {
/*
* Fuse Reshape->Transpose->MatMulV2 when MatMulV2 uses mkldnn.
*/
class ReshapeTransposeMatmulV2MkldnnFusePass
: public ReshapeTransposeMatmulMkldnnFusePass {
public:
ReshapeTransposeMatmulV2MkldnnFusePass();
virtual ~ReshapeTransposeMatmulV2MkldnnFusePass() {}
protected:
const std::string name_scope_{"reshape_transpose_matmul_v2_fuse"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
...@@ -302,9 +302,7 @@ void CpuPassStrategy::EnableMKLDNN() { ...@@ -302,9 +302,7 @@ void CpuPassStrategy::EnableMKLDNN() {
"conv_activation_mkldnn_fuse_pass", // "conv_activation_mkldnn_fuse_pass", //
"scale_matmul_fuse_pass", // "scale_matmul_fuse_pass", //
"reshape_transpose_matmul_mkldnn_fuse_pass", // "reshape_transpose_matmul_mkldnn_fuse_pass", //
"reshape_transpose_matmul_v2_mkldnn_fuse_pass", // "matmul_transpose_reshape_mkldnn_fuse_pass", //
"matmul_transpose_reshape_fuse_pass", //
"matmul_v2_transpose_reshape_fuse_pass", //
"matmul_elementwise_add_mkldnn_fuse_pass", // "matmul_elementwise_add_mkldnn_fuse_pass", //
"matmul_activation_mkldnn_fuse_pass", // "matmul_activation_mkldnn_fuse_pass", //
// Disabled due to topology-dependent speed-up // Disabled due to topology-dependent speed-up
...@@ -399,14 +397,12 @@ void CpuPassStrategy::EnableMkldnnInt8() { ...@@ -399,14 +397,12 @@ void CpuPassStrategy::EnableMkldnnInt8() {
passes_.push_back("repeated_fc_relu_fuse_pass"); passes_.push_back("repeated_fc_relu_fuse_pass");
passes_.push_back("fc_mkldnn_pass"); passes_.push_back("fc_mkldnn_pass");
passes_.push_back("fc_act_mkldnn_fuse_pass"); passes_.push_back("fc_act_mkldnn_fuse_pass");
passes_.push_back("matmul_transpose_reshape_fuse_pass"); passes_.push_back("matmul_transpose_reshape_mkldnn_fuse_pass");
passes_.push_back("matmul_v2_transpose_reshape_fuse_pass");
passes_.push_back("batch_norm_act_fuse_pass"); passes_.push_back("batch_norm_act_fuse_pass");
passes_.push_back("softplus_activation_mkldnn_fuse_pass"); passes_.push_back("softplus_activation_mkldnn_fuse_pass");
passes_.push_back("compute_propagate_scales_mkldnn_pass"); passes_.push_back("compute_propagate_scales_mkldnn_pass");
passes_.push_back("scale_matmul_fuse_pass"); passes_.push_back("scale_matmul_fuse_pass");
passes_.push_back("reshape_transpose_matmul_mkldnn_fuse_pass"); passes_.push_back("reshape_transpose_matmul_mkldnn_fuse_pass");
passes_.push_back("reshape_transpose_matmul_v2_mkldnn_fuse_pass");
passes_.push_back("cpu_quantize_placement_pass"); passes_.push_back("cpu_quantize_placement_pass");
passes_.push_back("cpu_quantize_pass"); passes_.push_back("cpu_quantize_pass");
passes_.push_back("cpu_quantize_squash_pass"); passes_.push_back("cpu_quantize_squash_pass");
......
...@@ -448,8 +448,8 @@ class Quant2Int8MkldnnPass(object): ...@@ -448,8 +448,8 @@ class Quant2Int8MkldnnPass(object):
# Disabled due to topology-dependent speed-up # Disabled due to topology-dependent speed-up
graph = self._apply_pass(graph, 'fc_mkldnn_pass') graph = self._apply_pass(graph, 'fc_mkldnn_pass')
graph = self._apply_pass(graph, 'fc_act_mkldnn_fuse_pass') graph = self._apply_pass(graph, 'fc_act_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'matmul_transpose_reshape_fuse_pass') graph = self._apply_pass(graph,
graph = self._apply_pass(graph, 'matmul_v2_transpose_reshape_fuse_pass') 'matmul_transpose_reshape_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'batch_norm_act_fuse_pass') graph = self._apply_pass(graph, 'batch_norm_act_fuse_pass')
graph = self._apply_pass(graph, 'softplus_activation_mkldnn_fuse_pass') graph = self._apply_pass(graph, 'softplus_activation_mkldnn_fuse_pass')
# the following pass should be the last one since it will work on all fused ops. # the following pass should be the last one since it will work on all fused ops.
...@@ -650,8 +650,6 @@ class Quant2Int8MkldnnPass(object): ...@@ -650,8 +650,6 @@ class Quant2Int8MkldnnPass(object):
graph = self._apply_pass(graph, 'scale_matmul_fuse_pass') graph = self._apply_pass(graph, 'scale_matmul_fuse_pass')
graph = self._apply_pass(graph, graph = self._apply_pass(graph,
'reshape_transpose_matmul_mkldnn_fuse_pass') 'reshape_transpose_matmul_mkldnn_fuse_pass')
graph = self._apply_pass(
graph, 'reshape_transpose_matmul_v2_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'cpu_quantize_placement_pass', graph = self._apply_pass(graph, 'cpu_quantize_placement_pass',
['quantize_enabled_op_types'], ['quantize_enabled_op_types'],
[self._ops_to_quantize]) [self._ops_to_quantize])
......
...@@ -121,8 +121,8 @@ class TestMatmulTransposeReshapeMkldnnFusePass(PassAutoScanTest): ...@@ -121,8 +121,8 @@ class TestMatmulTransposeReshapeMkldnnFusePass(PassAutoScanTest):
yield config, ["matmul"], (1e-5, 1e-5) yield config, ["matmul"], (1e-5, 1e-5)
def test(self): def test(self):
self.run_and_statis(quant=False, self.run_and_statis(
passes=["matmul_transpose_reshape_fuse_pass"]) quant=False, passes=["matmul_transpose_reshape_mkldnn_fuse_pass"])
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -142,8 +142,8 @@ class TestMatmulv2TransposeReshapeMkldnnFusePass(PassAutoScanTest): ...@@ -142,8 +142,8 @@ class TestMatmulv2TransposeReshapeMkldnnFusePass(PassAutoScanTest):
yield config, [fused_op], (1e-5, 1e-5) yield config, [fused_op], (1e-5, 1e-5)
def test(self): def test(self):
self.run_and_statis(quant=False, self.run_and_statis(
passes=["matmul_v2_transpose_reshape_fuse_pass"]) quant=False, passes=["matmul_transpose_reshape_mkldnn_fuse_pass"])
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -29,7 +29,7 @@ class TestReshapeTransposeMatmulV2OneDNNFusePass(InferencePassTest): ...@@ -29,7 +29,7 @@ class TestReshapeTransposeMatmulV2OneDNNFusePass(InferencePassTest):
def setUp(self): def setUp(self):
self.set_params() self.set_params()
self.tranpose_perm = [0, 2, 1, 3] self.tranpose_perm = [0, 2, 1, 3]
self.pass_name = 'reshape_transpose_matmul_v2_mkldnn_fuse_pass' self.pass_name = 'reshape_transpose_matmul_mkldnn_fuse_pass'
with fluid.program_guard(self.main_program, self.startup_program): with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(name="data", data = fluid.data(name="data",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册