diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 7b80d331ff7077c526320573ebce33c48036a2c6..80ae0f04daa4a0a7e0e689032a7564d43b280c92 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -123,6 +123,7 @@ if(WITH_MKLDNN) pass_library(cpu_quantize_squash_pass inference DIR mkldnn) pass_library(reshape_transpose_matmul_mkldnn_fuse_pass inference DIR mkldnn) pass_library(matmul_transpose_reshape_fuse_pass inference DIR mkldnn) + pass_library(matmul_v2_transpose_reshape_fuse_pass inference DIR mkldnn) pass_library(batch_norm_act_fuse_pass inference DIR mkldnn) pass_library(multi_gru_fuse_pass inference DIR mkldnn) pass_library(multi_gru_seq_fuse_pass inference DIR mkldnn) @@ -192,7 +193,7 @@ endif() cc_test(test_cpu_quantize_pass SRCS mkldnn/cpu_quantize_pass_tester.cc DEPS cpu_quantize_pass naive_executor) cc_test(test_cpu_quantize_squash_pass SRCS mkldnn/cpu_quantize_squash_pass_tester.cc DEPS cpu_quantize_squash_pass naive_executor) cc_test(test_reshape_transpose_matmul_mkldnn_fuse_pass SRCS mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass_tester.cc DEPS reshape_transpose_matmul_mkldnn_fuse_pass) - cc_test(test_matmul_transpose_reshape_fuse_pass SRCS mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc DEPS matmul_transpose_reshape_fuse_pass) + cc_test(test_matmul_transpose_reshape_fuse_pass SRCS mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc DEPS matmul_transpose_reshape_fuse_pass matmul_v2_transpose_reshape_fuse_pass) cc_test(test_cpu_bfloat16_placement_pass SRCS mkldnn/cpu_bfloat16_placement_pass_tester.cc DEPS cpu_bfloat16_placement_pass) cc_test(test_cpu_bfloat16_pass SRCS mkldnn/cpu_bfloat16_pass_tester.cc DEPS cpu_bfloat16_pass) cc_test(test_multi_gru_fuse_pass SRCS mkldnn/multi_gru_fuse_pass_tester.cc DEPS multi_gru_fuse_pass) diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 2f18b678e2856b83a7b9fdf90b27601eae67c179..71b30d854ca24d49f116f932aea766885d85278b 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -2697,16 +2697,18 @@ PDNode *patterns::ReshapeTransposeMatmulPattern::operator()( return matmul_out; } -PDNode *patterns::MatmulTransposeReshapePattern::operator()() { +// shared function for matmul and matmul_v2 +PDNode *patterns::MatmulTransposeReshapePattern::operator()( + const std::string &op_name) { auto reshape_op = pattern->NewNode(reshape_op_repr())->assert_is_op("reshape2"); auto transpose_op = pattern->NewNode(transpose_op_repr())->assert_is_op("transpose2"); - auto matmul_op = pattern->NewNode(matmul_op_repr())->assert_is_op("matmul"); + auto matmul_op = pattern->NewNode(matmul_op_repr())->assert_is_op(op_name); auto matmul_out = pattern->NewNode(matmul_out_repr()) ->AsInput() - ->assert_is_op_output("matmul", "Out") + ->assert_is_op_output(op_name, "Out") ->assert_is_op_input("transpose2", "X"); auto transpose_out = pattern->NewNode(transpose_out_repr()) diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index ba0d982dcc481bf96ebe040fa2d5b60444fef88e..cc9d1c76ab11bfa6bba426864c6c50e91d0cf354 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -1546,7 +1546,7 @@ struct MatmulTransposeReshapePattern : public PatternBase { const std::string& name_scope) : PatternBase(pattern, 
name_scope, "matmul_transpose_reshape") {} - PDNode* operator()(); + PDNode* operator()(const std::string& op_name); PATTERN_DECL_NODE(matmul_op); PATTERN_DECL_NODE(matmul_out); diff --git a/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.cc index a61099b4986747073bf4cde39ce497f365cea51f..34a35877a7f2565f8a8903fab21ce1486073b837 100644 --- a/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.cc @@ -23,7 +23,9 @@ namespace framework { namespace ir { MatmulTransposeReshapeMKLDNNPass::MatmulTransposeReshapeMKLDNNPass() { - AddOpCompat(OpCompat("matmul")) + op_name_ = "matmul"; + + AddOpCompat(OpCompat(op_name_)) .AddInput("X") .IsTensor() .End() @@ -89,7 +91,7 @@ void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const { patterns::MatmulTransposeReshapePattern mtrp(gpd.mutable_pattern(), name_scope_); - mtrp(); + mtrp(op_name_); int found_matmul_transpose_reshape_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, @@ -98,7 +100,7 @@ void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const { LOG(WARNING) << "Pass in op compat failed."; return; } - VLOG(4) << "handle matmul_transpose_reshape fuse"; + VLOG(4) << "handle " + op_name_ + "_transpose_reshape fuse"; GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, mtrp); GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, mtrp); GET_IR_NODE_FROM_SUBGRAPH(transpose_op, transpose_op, mtrp); @@ -118,17 +120,17 @@ void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const { const bool supported_transpose_axis = std::equal( transpose_axis.begin(), transpose_axis.end(), supported_axis.begin()); if (transpose_out_size != 4) { - VLOG(3) << "do not perform matmul_transpose_reshape fuse: " + VLOG(3) << "do not perform " + op_name_ + "_transpose_reshape fuse: " << "supported rank is 4, received " << transpose_out_size; return; } if (!supported_transpose_axis) { - VLOG(3) << "do not perform matmul_transpose_reshape fuse: " + VLOG(3) << "do not perform " + op_name_ + "_transpose_reshape fuse: " << "supported transpose axis for the fuse are {0, 2, 1, 3}"; return; } if (reshape_out_size != 3) { - VLOG(3) << "do not perform matmul_transpose_reshape fuse: " + VLOG(3) << "do not perform " + op_name_ + "_transpose_reshape fuse: " << "reshape_out supported rank is 3, received " << reshape_out_size; return; @@ -152,7 +154,7 @@ void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const { if (!Has("disable_logs") || !Get<bool>("disable_logs")) { std::stringstream msg_ss; msg_ss << "--- Fused " << found_matmul_transpose_reshape_count - << " MatmulTransposeReshape patterns"; + << " MatmulTransposeReshape patterns for " + op_name_ + " Op"; paddle::string::PrettyLogDetail(msg_ss.str().c_str()); } } diff --git a/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.h index 09cbe9bdf7b2fb5c8fd0c8676730031482f3d6d9..e03746e6e80e85f693ffb634ffc9bb741802f8cc 100644 --- a/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.h @@ -31,6 +31,7 @@ class MatmulTransposeReshapeMKLDNNPass : public FusePassBase { protected: void ApplyImpl(Graph* graph) const override; const std::string name_scope_{"matmul_transpose_reshape_fuse"}; + std::string op_name_;
}; } // namespace ir } // namespace framework diff --git a/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc index d98d640e1002b1ff97e9d03a44a866987e3a2af8..ed99989cf382f1d03762e13460f7f6b2cf91f1b1 100644 --- a/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.h" #include <gtest/gtest.h> +#include "paddle/fluid/framework/ir/mkldnn/matmul_v2_transpose_reshape_fuse_pass.h" namespace paddle { namespace framework { @@ -42,9 +42,15 @@ void SetOp(ProgramDesc *prog, const std::string &type, op->SetAttr("transpose_X", true); op->SetAttr("transpose_Y", true); } + if (type == "matmul_v2") { + op->SetInput("Y", {inputs[1]}); + op->SetAttr("use_mkldnn", true); + op->SetAttr("trans_x", true); + op->SetAttr("trans_y", true); + } } -ProgramDesc BuildProgramDesc() { +ProgramDesc BuildProgramDesc(const std::string &op_name) { ProgramDesc prog; for (auto &v : std::initializer_list<std::string>( {"a1", "a2", "b", "c", "cx", "d", "dx", "e"})) { var->SetType(proto::VarType::SELECTED_ROWS); } - SetOp(&prog, "matmul", {"a1", "a2"}, {"b"}); + SetOp(&prog, op_name, {"a1", "a2"}, {"b"}); SetOp(&prog, "transpose2", {"b"}, {"c", "cx"}); SetOp(&prog, "reshape2", {"c"}, {"d", "dx"}); SetOp(&prog, "fc", {"d"}, {"e"}); return prog; } -void MainTest(const ProgramDesc &prog) { +void MainTest(const ProgramDesc &prog, const std::string &op_name) { std::unique_ptr<ir::Graph> graph(new ir::Graph(prog)); int original_nodes_num = graph->Nodes().size(); auto pass = - PassRegistry::Instance().Get("matmul_transpose_reshape_fuse_pass"); + PassRegistry::Instance().Get(op_name + "_transpose_reshape_fuse_pass"); graph.reset(pass->Apply(graph.release())); int current_nodes_num = graph->Nodes().size(); @@ -75,7 +81,7 @@ void MainTest(const ProgramDesc &prog) { for (auto *node : graph->Nodes()) { if (node->IsOp()) { auto *op = node->Op(); - if (op->Type() == "matmul") { + if (op->Type() == op_name) { EXPECT_EQ(op->GetAttrIfExists<std::vector<int>>("fused_reshape_Out"), std::vector<int>({4, 5, 6})); EXPECT_EQ(op->GetAttrIfExists<std::vector<int>>("fused_transpose_Out"), @@ -85,12 +91,18 @@ } } } -TEST(MatmulTransposeReshapeFusePass, matmul_inputs) { - auto prog = BuildProgramDesc(); - MainTest(prog); +TEST(MatmulTransposeReshapeFusePass, matmul_fuse_pass) { + auto prog = BuildProgramDesc("matmul"); + MainTest(prog, "matmul"); +} + +TEST(MatmulTransposeReshapeFusePass, matmul_v2_fuse_pass) { + auto prog = BuildProgramDesc("matmul_v2"); + MainTest(prog, "matmul_v2"); } } // namespace ir } // namespace framework } // namespace paddle USE_PASS(matmul_transpose_reshape_fuse_pass); +USE_PASS(matmul_v2_transpose_reshape_fuse_pass); diff --git a/paddle/fluid/framework/ir/mkldnn/matmul_v2_transpose_reshape_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/matmul_v2_transpose_reshape_fuse_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..dcf4664d963da77b7d480c7de14d692ad34238e4 --- /dev/null +++ b/paddle/fluid/framework/ir/mkldnn/matmul_v2_transpose_reshape_fuse_pass.cc @@ -0,0 +1,92 @@ +// Copyright (c) 2021 PaddlePaddle Authors.
All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/mkldnn/matmul_v2_transpose_reshape_fuse_pass.h" +#include <vector> +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace framework { +namespace ir { + +MatmulV2TransposeReshapeMKLDNNPass::MatmulV2TransposeReshapeMKLDNNPass() { + op_name_ = "matmul_v2"; + + AddOpCompat(OpCompat(op_name_)) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("trans_x") + .IsType<bool>() + .End() + .AddAttr("trans_y") + .IsType<bool>() + .End(); + + AddOpCompat(OpCompat("transpose2")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddOutput("XShape") + .IsTensor() + .End() + .AddAttr("axis") + .IsType<std::vector<int>>() + .End(); + + AddOpCompat(OpCompat("reshape2")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Shape") + .IsTensor() + .IsOptional() + .End() + .AddInput("ShapeTensor") + .IsTensor() + .IsOptional() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddOutput("XShape") + .IsTensor() + .End() + .AddAttr("shape") + .IsType<std::vector<int>>() + .End(); +} +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(matmul_v2_transpose_reshape_fuse_pass, + paddle::framework::ir::MatmulV2TransposeReshapeMKLDNNPass); + +REGISTER_PASS_CAPABILITY(matmul_v2_transpose_reshape_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .EQ("matmul_v2", 0) + .EQ("transpose2", 0) + .EQ("reshape2", 0)); diff --git a/paddle/fluid/framework/ir/mkldnn/matmul_v2_transpose_reshape_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/matmul_v2_transpose_reshape_fuse_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..60b7e98145698270b69ef22676a155724cb9060d --- /dev/null +++ b/paddle/fluid/framework/ir/mkldnn/matmul_v2_transpose_reshape_fuse_pass.h @@ -0,0 +1,35 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#pragma once + +#include <string> + +#include "paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.h" + +namespace paddle { +namespace framework { +namespace ir { +class MatmulV2TransposeReshapeMKLDNNPass + : public MatmulTransposeReshapeMKLDNNPass { + public: + MatmulV2TransposeReshapeMKLDNNPass(); + virtual ~MatmulV2TransposeReshapeMKLDNNPass() {} + + protected: + const std::string name_scope_{"matmul_v2_transpose_reshape_fuse"}; +}; +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 504f81bfa01ac631d4202662aef5f197bcd3af68..9eccf0a6142753c3ba2512e9ab76e20ebfcbd707 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -249,6 +249,7 @@ void CpuPassStrategy::EnableMKLDNN() { "scale_matmul_fuse_pass", // "reshape_transpose_matmul_mkldnn_fuse_pass", // "matmul_transpose_reshape_fuse_pass", // + "matmul_v2_transpose_reshape_fuse_pass", // // Disabled due to topology-dependent speed-up // "fc_mkldnn_pass", // "fc_act_mkldnn_fuse_pass", diff --git a/paddle/fluid/operators/compat/matmul_v2.pbtxt b/paddle/fluid/operators/compat/matmul_v2.pbtxt index 5f43e1f8bf0e0c502566a2cc783b8927e5df56cc..fa2559939bbd2fbfc0503d2ec688ab1930b8b948 100644 --- a/paddle/fluid/operators/compat/matmul_v2.pbtxt +++ b/paddle/fluid/operators/compat/matmul_v2.pbtxt @@ -39,4 +39,12 @@ extra { name: "op_device" type: STRING } + attrs { + name: "fused_reshape_Out" + type: INTS + } + attrs { + name: "fused_transpose_Out" + type: INTS + } } diff --git a/paddle/fluid/operators/matmul_v2_op.cc b/paddle/fluid/operators/matmul_v2_op.cc index 953c3a555fa4b7517bb909323082d1f64a1ae9e3..1b609b15d6e56934a460b6d2ec249f7dc6a916d6 100644 --- a/paddle/fluid/operators/matmul_v2_op.cc +++ b/paddle/fluid/operators/matmul_v2_op.cc @@ -90,8 +90,62 @@ class MatMulV2Op : public framework::OperatorWithKernel { new_dims.push_back(1); } - auto out_dims = framework::make_ddim(new_dims); - ctx->SetOutputDim("Out", out_dims); + auto ddim_out = framework::make_ddim(new_dims); + +#ifdef PADDLE_WITH_MKLDNN + // if the mkldnn matmul_v2+transpose+reshape fuse is activated + auto reshape_out = ctx->Attrs().Get<std::vector<int>>("fused_reshape_Out"); + auto transpose_out = + ctx->Attrs().Get<std::vector<int>>("fused_transpose_Out"); + + if (!reshape_out.empty() && !transpose_out.empty()) { + auto reshape_out_size = reshape_out.size(); + auto transpose_out_size = transpose_out.size(); + PADDLE_ENFORCE_EQ(transpose_out_size, 4, + platform::errors::InvalidArgument( + "transpose_out supported rank is 4, " + "received %d", + transpose_out_size)); + const std::vector<int> supported_axis{0, 2, 1, 3}; + const bool supported_transpose_axis = std::equal( + transpose_out.begin(), transpose_out.end(), supported_axis.begin()); + PADDLE_ENFORCE_EQ( + supported_transpose_axis, true, + platform::errors::InvalidArgument( + "supported transpose axis for the fuse are {0, 2, 1, 3}")); + PADDLE_ENFORCE_EQ( + reshape_out_size, 3, + platform::errors::InvalidArgument("reshape_out supported rank is 3, " + "received %d", + reshape_out_size)); + + auto it = std::find(reshape_out.begin(), reshape_out.end(), -1); + + // if "-1" is present then one of the reshape dims must be inferred + if (it != reshape_out.end()) { + int index = std::distance(reshape_out.begin(), it); + + auto ddim_out_vec = framework::vectorize(ddim_out); + + int ddim_out_product = + std::accumulate(ddim_out_vec.begin(), ddim_out_vec.end(), 1,
std::multiplies<int>()); + int reshape_out_product = std::accumulate( + reshape_out.begin(), reshape_out.end(), -1, std::multiplies<int>()); + + reshape_out[index] = ddim_out_product / reshape_out_product; + } + + framework::DDim shape_out = + ddim_out.transpose(transpose_out).reshape(reshape_out); + ctx->SetOutputDim("Out", shape_out); + } else { + ctx->SetOutputDim("Out", ddim_out); + } +#else + ctx->SetOutputDim("Out", ddim_out); +#endif + ctx->ShareLoD("X", /* --> */ "Out"); } @@ -139,6 +193,18 @@ class MatMulV2OpMaker : public framework::OpProtoAndCheckerMaker { "Set true to transpose the last two dimensions of Y before " "doing multiplication") .SetDefault(false); + AddAttr<std::vector<int>>( + "fused_reshape_Out", + R"DOC(When the MKLDNN matmul_v2_transpose_reshape fuse is activated, + this is the shape attribute of the fused reshape applied to the `Out` output.)DOC") + .SetDefault({}) + .AsExtra(); + AddAttr<std::vector<int>>( + "fused_transpose_Out", + R"DOC(When the MKLDNN matmul_v2_transpose_reshape fuse is activated, + this is the axis attribute of the fused transpose applied to the `Out` output.)DOC") + .SetDefault({}) + .AsExtra(); AddAttr<bool>("use_mkldnn", "(bool, default false) Only used in mkldnn kernel") .SetDefault(false) diff --git a/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc index c332b9194164ea3be52cf793febf90f7aea679c6..aa0a16944bcfabafa3a8184e7bc44c2c5bb9af20 100644 --- a/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc @@ -36,7 +36,8 @@ class MatMulV2MKLDNNHandler MatMulV2MKLDNNHandler(const mkldnn::engine engine, paddle::platform::Place cpu_place, const std::vector<int64_t>& x_org_dims, bool trans_x, - const std::vector<int64_t>& y_org_dims, bool trans_y) + const std::vector<int64_t>& y_org_dims, bool trans_y, + bool is_output_fused) : paddle::platform::MKLDNNHandlerNoCachingT<T, dnnl::matmul>(engine, cpu_place) { // M X K * K X N @@ -86,6 +87,10 @@ class MatMulV2MKLDNNHandler out_strides[i] = out_ddims[i + 1] * out_strides[i + 1]; } + if (is_output_fused) { + out_strides = FakeTransposeStrides(out_ddims); + } + auto x_md = memory::desc(x_dims, MKLDNNGetDataType<T>(), x_strides); auto y_md = memory::desc(y_dims, MKLDNNGetDataType<T>(), y_strides); auto out_md = memory::desc(out_ddims, MKLDNNGetDataType<T>(), out_strides); @@ -93,6 +98,24 @@ class MatMulV2MKLDNNHandler this->AcquireForwardPrimitiveDescriptor(x_md, y_md, out_md); } + std::vector<int64_t> FakeTransposeStrides( + const std::vector<int64_t>& matmul_out_dims) const { + // the matmul_v2 + transpose + reshape fuse guarantees that the output is 4D + // and the transpose axes are {0, 2, 1, 3} + std::vector<int64_t> transpose_axis = {0, 2, 1, 3}; + std::vector<int64_t> fake_strides(transpose_axis.size()); + int ndims = static_cast<int>(transpose_axis.size()); + + int total_stride = 1; + + for (int i = ndims - 1; i >= 0; --i) { + fake_strides[transpose_axis[i]] = total_stride; + total_stride *= matmul_out_dims[transpose_axis[i]]; + } + + return fake_strides; + } + std::shared_ptr<memory> AcquireWeightsMemory(const Tensor* input) { const T* input_data = input->data<T>(); return this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc(), @@ -116,7 +139,8 @@ class MatMulV2MKLDNNKernel bool trans_y, Tensor* out, std::vector<int64_t>& out_dims, int execution_number = 0) const { MatMulV2MKLDNNHandler<T> handler(onednn_engine, ctx.GetPlace(), x_dims, - trans_x, y_dims, trans_y); + trans_x, y_dims, trans_y, + IsOutputFused(ctx)); const auto src_memory_p = handler.AcquireSrcMemory(x); const auto weights_memory_p = handler.AcquireWeightsMemory(y); @@ -133,9 +157,10 @@ class MatMulV2MKLDNNKernel
matmul_p->execute(astream, matmul_args); astream.wait(); + auto format = paddle::platform::MKLDNNFormatForSize( + out->dims().size(), dnnl::memory::format_tag::nchw); out->set_layout(paddle::framework::DataLayout::kMKLDNN); - out->set_format( - GetMKLDNNFormat(dst_memory_p->get_desc().reshape(out_dims))); + out->set_format(format); } private: @@ -166,7 +191,8 @@ class MatMulV2MKLDNNKernel } } - if ((y_dims.size() == x_dims.size()) && y_dims.size() > 2) { + if ((y_dims.size() == x_dims.size()) && y_dims.size() > 2 && + !IsOutputFused(ctx)) { for (size_t i = 0; i < x_dims.size() - 2; ++i) { PADDLE_ENFORCE_EQ( x_dims[i] == y_dims[i] || x_dims[i] == 1 || y_dims[i] == 1, true, @@ -181,6 +207,13 @@ class MatMulV2MKLDNNKernel } } + bool IsOutputFused(const ExecutionContext& ctx) const { + auto& fused_reshape_Out = ctx.Attr<std::vector<int>>("fused_reshape_Out"); + auto& fused_transpose_Out = + ctx.Attr<std::vector<int>>("fused_transpose_Out"); + return !fused_reshape_Out.empty() && !fused_transpose_Out.empty(); + } + void RunKernel(const ExecutionContext& ctx) const { const auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>(); const auto& onednn_engine = dev_ctx.GetEngine(); diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_v2_transpose_reshape_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_v2_transpose_reshape_fuse_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..698e399c71ccd415218ecfea2adba3b68616f12a --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_v2_transpose_reshape_fuse_pass.py @@ -0,0 +1,82 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +from __future__ import print_function + +import unittest +import numpy as np +from inference_pass_test import InferencePassTest +import paddle +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.fluid.core import AnalysisConfig +from paddle.fluid.core import PassVersionChecker + + +class TestMatmulV2OneDNNTransposeReshapeFusePass(InferencePassTest): + def setUp(self): + self.set_params() + self.transpose_perm = [0, 2, 1, 3] + self.pass_name = 'matmul_v2_transpose_reshape_fuse_pass' + + with fluid.program_guard(self.main_program, self.startup_program): + data = fluid.data( + name="data", shape=self.data_shape, dtype="float32") + weight = fluid.layers.create_parameter( + shape=self.weight_shape, dtype="float32") + matmul = paddle.matmul( + data, + weight, + transpose_x=self.transpose_x, + transpose_y=self.transpose_y) + transpose = fluid.layers.transpose(matmul, self.transpose_perm) + reshape = fluid.layers.reshape(transpose, shape=self.reshape_shape) + + self.fetch_list = [reshape] + self.enable_mkldnn = True + + def set_params(self): + self.data_shape = [-1, 3, 100, 110] + self.weight_shape = [1, 3, 110, 100] + self.feeds = { + "data": np.random.random((1, 3, 100, 110)).astype("float32") + } + self.transpose_x = False + self.transpose_y = False + self.reshape_shape = [3, 100, 100] + + def test_check_output(self): + use_gpu = False + self.check_output_with_option(use_gpu) + + def test_pass_compatible(self): + self.assertTrue(PassVersionChecker.IsCompatible(self.pass_name)) + + +class TestMatmulV2OneDNNTransposeReshapeFusePassDifferentDims( + TestMatmulV2OneDNNTransposeReshapeFusePass): + def set_params(self): + self.data_shape = [-1, 4, 100, 80] + self.weight_shape = [1, 4, 80, 100] + self.feeds = { + "data": np.random.random((1, 4, 100, 80)).astype("float32") + } + self.transpose_x = True + self.transpose_y = True + self.reshape_shape = [8, 40, 80] + + +if __name__ == "__main__": + paddle.enable_static() + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_mkldnn_op.py index 724b9d9818dc4510bff4db7bbb9ea9889df4ec93..4ab15ac448047ccab0feae6617eede265c7b154d 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_mkldnn_op.py @@ -440,9 +440,11 @@ class TestMatMulOpTransposeReshapeEmptyFloat(OpTest): self.reshape_out = [] self.out = np.matmul(self.x, self.y) - def setUp(self): - os.environ["DNNL_MAX_CPU_ISA"] = "AVX" + def set_op_type(self): self.op_type = "matmul" + + def setUp(self): + self.set_op_type() self._cpu_only = True self.use_mkldnn = True self.init_data_type() diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_v2_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_v2_mkldnn_op.py index 994d78126bda5852a07cd04cbde82585ea739631..9afe45efee362ab5e9dd144d52911872bc90ddbf 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_v2_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_v2_mkldnn_op.py @@ -23,6 +23,13 @@ import paddle.fluid.core as core import paddle import paddle.fluid as fluid import paddle.fluid.framework as framework +from paddle.fluid.tests.unittests.mkldnn.test_matmul_mkldnn_op import ( + TestMatMulOpTransposeReshapeEmptyFloat, + TestMatMulOpTransposeReshapeBasicFloat, + TestMatMulOpTransposeReshapeOtherDimFloat, + TestMatMulOpTransposeReshapeTransposeAxisNotSupportedException, +
TestMatMulOpTransposeReshapeTransposeRankNotSupportedException, + TestMatMulOpTransposeReshapeRankOfReshapeNotSupportedException) def reference_matmul(X, Y, transpose_x=False, transpose_y=False): @@ -390,6 +397,43 @@ create_bf16_test_class(TestMatMulV2MatrixXMatrix5DTranposeYOneDNNOp) create_bf16_test_class(TestMatMulV2MatrixXMatrix6Dx2DOneDNNOp) create_bf16_test_class(TestMatMulV2MatrixXMatrix2Dx5DOneDNNOp) + +class TestMatMulV2OpTransposeReshapeEmptyFloat( + TestMatMulOpTransposeReshapeEmptyFloat): + def set_op_type(self): + self.op_type = "matmul_v2" + + +class TestMatMulV2OpTransposeReshapeBasicFloat( + TestMatMulOpTransposeReshapeBasicFloat): + def set_op_type(self): + self.op_type = "matmul_v2" + + +class TestMatMulV2OpTransposeReshapeOtherDimFloat( + TestMatMulOpTransposeReshapeOtherDimFloat): + def set_op_type(self): + self.op_type = "matmul_v2" + + +class TestMatMulV2OpTransposeReshapeTransposeAxisNotSupportedException( + TestMatMulOpTransposeReshapeTransposeAxisNotSupportedException): + def set_op_type(self): + self.op_type = "matmul_v2" + + +class TestMatMulV2OpTransposeReshapeRankOfReshapeNotSupportedException( + TestMatMulOpTransposeReshapeRankOfReshapeNotSupportedException): + def set_op_type(self): + self.op_type = "matmul_v2" + + +class TestMatMulV2OpTransposeReshapeTransposeRankNotSupportedException( + TestMatMulOpTransposeReshapeTransposeRankNotSupportedException): + def set_op_type(self): + self.op_type = "matmul_v2" + + if __name__ == "__main__": paddle.enable_static() unittest.main()
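Note on the mechanism (illustrative sketch, not part of the patch): the fuse removes the transpose2+reshape2 pair and records them as fused_transpose_Out / fused_reshape_Out attributes on matmul_v2. The oneDNN kernel then writes the matmul output through permuted ("fake transpose") strides, and InferShape computes the final 3-D shape, inferring any -1 entry in fused_reshape_Out. The following self-contained C++ program mirrors that shape arithmetic; the function names echo the patch, while main() uses the shapes from the new Python test as example inputs.

// Sketch of the pass's shape arithmetic under its own constraints:
// 4-D matmul output, transpose axis {0, 2, 1, 3}, 3-D reshape target.
#include <algorithm>
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// Mirrors MatMulV2MKLDNNHandler::FakeTransposeStrides: instead of executing
// transpose2, the matmul primitive writes its output through permuted
// strides, so the memory already holds the transposed layout.
std::vector<int64_t> FakeTransposeStrides(const std::vector<int64_t>& dims) {
  const std::vector<int64_t> axis = {0, 2, 1, 3};
  std::vector<int64_t> strides(axis.size());
  int64_t total = 1;
  for (int i = static_cast<int>(axis.size()) - 1; i >= 0; --i) {
    strides[axis[i]] = total;
    total *= dims[axis[i]];
  }
  return strides;
}

// Mirrors the InferShape branch added to matmul_v2_op.cc: a -1 entry in
// fused_reshape_Out is inferred from the matmul output's element count.
std::vector<int> InferReshape(const std::vector<int64_t>& out_dims,
                              std::vector<int> reshape) {
  auto it = std::find(reshape.begin(), reshape.end(), -1);
  if (it != reshape.end()) {
    int64_t out_product =
        std::accumulate(out_dims.begin(), out_dims.end(), int64_t{1},
                        std::multiplies<int64_t>());
    // Seeding with -1 cancels the sign of the -1 placeholder in the product.
    int64_t reshape_product =
        std::accumulate(reshape.begin(), reshape.end(), int64_t{-1},
                        std::multiplies<int64_t>());
    *it = static_cast<int>(out_product / reshape_product);
  }
  return reshape;
}

int main() {
  // matmul_v2 output [1, 3, 100, 100], as produced by the first Python test.
  const std::vector<int64_t> out_dims = {1, 3, 100, 100};
  for (int64_t s : FakeTransposeStrides(out_dims))
    std::cout << s << ' ';  // prints: 30000 100 300 1
  std::cout << '\n';
  for (int d : InferReshape(out_dims, {3, 100, -1}))
    std::cout << d << ' ';  // prints: 3 100 100
  std::cout << '\n';
  return 0;
}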