From a922168afc4a8aaec4e9e8539c6f08661504e7b9 Mon Sep 17 00:00:00 2001 From: Sylwester Fraczek Date: Tue, 14 Dec 2021 02:32:25 +0100 Subject: [PATCH] add reshape+transpose+matmul_v2 only (#37847) * reshape+transpose+matmul_v2 * in_name->input_name * fix pr-ci-static-check --- paddle/fluid/framework/ir/CMakeLists.txt | 3 +- .../framework/ir/graph_pattern_detector.cc | 9 +- .../framework/ir/graph_pattern_detector.h | 3 +- ...shape_transpose_matmul_mkldnn_fuse_pass.cc | 18 +-- ...eshape_transpose_matmul_mkldnn_fuse_pass.h | 2 + ...ranspose_matmul_mkldnn_fuse_pass_tester.cc | 32 +++-- ...pe_transpose_matmul_v2_mkldnn_fuse_pass.cc | 91 ++++++++++++ ...ape_transpose_matmul_v2_mkldnn_fuse_pass.h | 39 +++++ .../fluid/framework/ir/pass_tester_helper.h | 13 ++ .../inference/api/paddle_pass_builder.cc | 21 +-- paddle/fluid/operators/compat/matmul_v2.pbtxt | 16 +++ paddle/fluid/operators/matmul_v2_op.cc | 95 +++++++++++- .../fluid/operators/mkldnn/matmul_mkldnn_op.h | 2 + .../operators/mkldnn/matmul_v2_mkldnn_op.cc | 135 +++++++++++++++--- .../quantization/quant2_int8_mkldnn_pass.py | 2 + ...n_reshape_transpose_matmul_v2_fuse_pass.py | 78 ++++++++++ .../unittests/mkldnn/test_matmul_mkldnn_op.py | 42 +++--- .../mkldnn/test_matmul_v2_mkldnn_op.py | 61 +++++++- 18 files changed, 591 insertions(+), 71 deletions(-) create mode 100644 paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_v2_mkldnn_fuse_pass.cc create mode 100644 paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_v2_mkldnn_fuse_pass.h create mode 100644 python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_reshape_transpose_matmul_v2_fuse_pass.py diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 08055cd9a54..029055d9a40 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -123,6 +123,7 @@ if(WITH_MKLDNN) pass_library(cpu_quantize_pass inference DIR mkldnn) pass_library(cpu_quantize_squash_pass inference DIR mkldnn) pass_library(reshape_transpose_matmul_mkldnn_fuse_pass inference DIR mkldnn) + pass_library(reshape_transpose_matmul_v2_mkldnn_fuse_pass inference DIR mkldnn) pass_library(matmul_transpose_reshape_fuse_pass inference DIR mkldnn) pass_library(matmul_v2_transpose_reshape_fuse_pass inference DIR mkldnn) pass_library(batch_norm_act_fuse_pass inference DIR mkldnn) @@ -190,7 +191,7 @@ endif() cc_test(test_cpu_quantize_placement_pass SRCS mkldnn/cpu_quantize_placement_pass_tester.cc DEPS cpu_quantize_placement_pass) cc_test(test_cpu_quantize_pass SRCS mkldnn/cpu_quantize_pass_tester.cc DEPS cpu_quantize_pass naive_executor) cc_test(test_cpu_quantize_squash_pass SRCS mkldnn/cpu_quantize_squash_pass_tester.cc DEPS cpu_quantize_squash_pass naive_executor) - cc_test(test_reshape_transpose_matmul_mkldnn_fuse_pass SRCS mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass_tester.cc DEPS reshape_transpose_matmul_mkldnn_fuse_pass) + cc_test(test_reshape_transpose_matmul_mkldnn_fuse_pass SRCS mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass_tester.cc DEPS reshape_transpose_matmul_mkldnn_fuse_pass reshape_transpose_matmul_v2_mkldnn_fuse_pass) cc_test(test_matmul_transpose_reshape_fuse_pass SRCS mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc DEPS matmul_transpose_reshape_fuse_pass matmul_v2_transpose_reshape_fuse_pass) cc_test(test_cpu_bfloat16_placement_pass SRCS mkldnn/cpu_bfloat16_placement_pass_tester.cc DEPS cpu_bfloat16_placement_pass) cc_test(test_cpu_bfloat16_pass SRCS mkldnn/cpu_bfloat16_pass_tester.cc DEPS cpu_bfloat16_pass) diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 5334b082489..732e31d55b2 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -2711,12 +2711,13 @@ void patterns::DeleteQuantDequantFilterOpPattern::operator()() { } PDNode *patterns::ReshapeTransposeMatmulPattern::operator()( - bool with_reshape_xshape, bool with_transpose_xshape) { + const std::string &op_name, bool with_reshape_xshape, + bool with_transpose_xshape) { auto reshape_op = pattern->NewNode(reshape_op_repr())->assert_is_op("reshape2"); auto transpose_op = pattern->NewNode(transpose_op_repr())->assert_is_op("transpose2"); - auto matmul_op = pattern->NewNode(matmul_op_repr())->assert_is_op("matmul"); + auto matmul_op = pattern->NewNode(matmul_op_repr())->assert_is_op(op_name); auto reshape_in = pattern->NewNode(reshape_in_repr()) ->AsInput() @@ -2737,7 +2738,7 @@ PDNode *patterns::ReshapeTransposeMatmulPattern::operator()( auto transpose_out = pattern->NewNode(transpose_out_repr()) ->AsIntermediate() - ->assert_is_op_input("matmul") + ->assert_is_op_input(op_name) ->assert_is_op_output("transpose2", "Out"); if (!with_transpose_xshape) transpose_out->assert_is_only_output_of_op("transpose2"); @@ -2751,7 +2752,7 @@ PDNode *patterns::ReshapeTransposeMatmulPattern::operator()( auto matmul_out = pattern->NewNode(matmul_out_repr()) ->AsOutput() - ->assert_is_op_output("matmul", "Out"); + ->assert_is_op_output(op_name, "Out"); reshape_op->LinksFrom({reshape_in}).LinksTo({reshape_out}); if (with_reshape_xshape) reshape_op->LinksTo({reshape_xshape}); diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index fa8504d074a..b15a75312dd 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -1570,7 +1570,8 @@ struct ReshapeTransposeMatmulPattern : public PatternBase { const std::string& name_scope) : PatternBase(pattern, name_scope, "reshape_transpose_matmul") {} - PDNode* operator()(bool with_reshape_xshape, bool with_transpose_xshape); + PDNode* operator()(const std::string& op_name, bool with_reshape_xshape, + bool with_transpose_xshape); PATTERN_DECL_NODE(reshape_in); PATTERN_DECL_NODE(reshape_op); diff --git a/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.cc index e408440f26f..d0962757185 100644 --- a/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.cc @@ -24,6 +24,8 @@ namespace framework { namespace ir { ReshapeTransposeMatmulMkldnnFusePass::ReshapeTransposeMatmulMkldnnFusePass() { + op_name_ = "matmul"; + AddOpCompat(OpCompat("reshape2")) .AddInput("X") .IsTensor() @@ -55,7 +57,7 @@ ReshapeTransposeMatmulMkldnnFusePass::ReshapeTransposeMatmulMkldnnFusePass() { .IsType>() .End(); - AddOpCompat(OpCompat("matmul")) + AddOpCompat(OpCompat(op_name_)) .AddInput("X") .IsTensor() .End() @@ -82,17 +84,17 @@ void ReshapeTransposeMatmulMkldnnFusePass::Fuse( patterns::ReshapeTransposeMatmulPattern rtm_pattern(gpd.mutable_pattern(), name_scope_); - rtm_pattern(with_reshape_xshape, with_transpose_xshape); + rtm_pattern(op_name_, with_reshape_xshape, with_transpose_xshape); int found_reshape_transpose_matmul_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, Graph *g) { if (!IsCompat(subgraph, g)) { - LOG(WARNING) << "Op compatible check in " - "reshape_transpose_matmul_mkldnn_fuse_pass failed."; + LOG(WARNING) << "Op compatible check in reshape_transpose_" << op_name_ + << "_mkldnn_fuse_pass failed."; return; } - VLOG(4) << "handle ReshapeTransposeMatmulMkldnn fuse"; + VLOG(4) << "handle reshape_transpose_" << op_name_ << " fuse"; GET_IR_NODE_FROM_SUBGRAPH(reshape_in, reshape_in, rtm_pattern); GET_IR_NODE_FROM_SUBGRAPH(reshape_op, reshape_op, rtm_pattern); GET_IR_NODE_FROM_SUBGRAPH(reshape_out, reshape_out, rtm_pattern); @@ -131,8 +133,8 @@ void ReshapeTransposeMatmulMkldnnFusePass::Fuse( } else if (matmul_desc->Inputs().at("Y").at(0) == input_var_name) { UpdateMatmul("Y"); } else { - throw platform::errors::InvalidArgument( - "Unexpected input to MatMul encountered."); + throw platform::errors::InvalidArgument("Unexpected input to " + + op_name_ + " encountered."); } std::unordered_set nodes_to_remove{ @@ -151,7 +153,7 @@ void ReshapeTransposeMatmulMkldnnFusePass::Fuse( if (!Has("disable_logs") || !Get("disable_logs")) { std::stringstream msg_ss; msg_ss << "--- Fused " << found_reshape_transpose_matmul_count - << " ReshapeTransposeMatmulMkldnn patterns"; + << " ReshapeTransposeMatmul patterns for " << op_name_ << " Op"; if (with_reshape_xshape) msg_ss << " with reshape's xshape"; if (with_transpose_xshape) msg_ss << " with transpose's xshape"; string::PrettyLogDetail(msg_ss.str().c_str()); diff --git a/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.h index 4637d0659af..66f70942c0c 100644 --- a/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.h @@ -28,6 +28,7 @@ namespace ir { class ReshapeTransposeMatmulMkldnnFusePass : public FusePassBase { public: ReshapeTransposeMatmulMkldnnFusePass(); + virtual ~ReshapeTransposeMatmulMkldnnFusePass() {} protected: void ApplyImpl(ir::Graph* graph) const override; @@ -35,6 +36,7 @@ class ReshapeTransposeMatmulMkldnnFusePass : public FusePassBase { void Fuse(Graph* graph, bool with_reshape_xshape, bool with_transpose_xshape) const; + std::string op_name_; }; } // namespace ir } // namespace framework diff --git a/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass_tester.cc index e6c366efa01..e6886356460 100644 --- a/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass_tester.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.h" +#include "paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_v2_mkldnn_fuse_pass.h" #include #include "paddle/fluid/framework/ir/pass_tester_helper.h" @@ -37,7 +38,7 @@ Scope* CreateParamScope() { return param_scope; } -void TestMain(bool with_xshapes) { +void TestMain(const std::string& op_name, bool with_xshapes) { // inputs operator output // ----------------------------------------------- // a1,w1,bias1 fc -> b1 @@ -46,7 +47,7 @@ void TestMain(bool with_xshapes) { // a2,w2,bias2 fc -> b2 // b2 reshape -> c2 // c2 transpose -> d2 - // (d1, d2) matmul -> (...) + // (d1, d2) matmul(_v2) -> (...) Layers layers; auto* a1 = layers.data("a1", {-1, 128, 768}); auto* w1 = layers.data("w1", {768, 768}, true); @@ -66,7 +67,11 @@ void TestMain(bool with_xshapes) { c2->SetShape({-1, 128, 12, 64}); auto* d2 = layers.transpose2(c2, {0, 2, 1, 3}); d2->SetShape({-1, 12, 128, 64}); - layers.matmul(d1, d2); + if (op_name == "matmul_v2") { + layers.matmul_v2(d1, d2); + } else { + layers.matmul(d1, d2); + } std::unique_ptr graph(new ir::Graph(layers.main_program())); graph->Set("__param_scope__", CreateParamScope()); @@ -76,8 +81,8 @@ void TestMain(bool with_xshapes) { int total_nodes_before = graph->Nodes().size(); VLOG(3) << DebugString(graph); - auto pass = - PassRegistry::Instance().Get("reshape_transpose_matmul_mkldnn_fuse_pass"); + auto pass = PassRegistry::Instance().Get("reshape_transpose_" + op_name + + "_mkldnn_fuse_pass"); graph.reset(pass->Apply(graph.release())); int num_reshape_nodes_after = GetNumOpNodes(graph, "reshape2"); @@ -92,7 +97,7 @@ void TestMain(bool with_xshapes) { int removed = 8; // 2* reshape, reshape_out, transpose, transpose_out if (with_xshapes) removed += 2; // transpose_xshape, reshape_xshape EXPECT_EQ(total_nodes_before - removed, total_nodes_after); - auto* matmul_op_desc = GetOpNodes(graph, "matmul").at(0)->Op(); + auto* matmul_op_desc = GetOpNodes(graph, op_name).at(0)->Op(); auto check = [&matmul_op_desc](std::string a) { std::string shape_str = "fused_reshape_" + a; @@ -108,12 +113,22 @@ void TestMain(bool with_xshapes) { TEST(ReshapeTransposeMatmulMkldnnFusePass, both_matmul_inputs_reshape_transpose) { - TestMain(false); + TestMain("matmul", false); } TEST(ReshapeTransposeMatmulMkldnnFusePass, both_matmul_inputs_reshape_transpose_one_with_xshapes) { - TestMain(true); + TestMain("matmul", true); +} + +TEST(ReshapeTransposeMatmulV2MkldnnFusePass, + both_matmulv2_inputs_reshape_transpose) { + TestMain("matmul_v2", false); +} + +TEST(ReshapeTransposeMatmulV2MkldnnFusePass, + both_matmulv2_inputs_reshape_transpose_one_with_xshapes) { + TestMain("matmul_v2", true); } } // namespace ir @@ -121,3 +136,4 @@ TEST(ReshapeTransposeMatmulMkldnnFusePass, } // namespace paddle USE_PASS(reshape_transpose_matmul_mkldnn_fuse_pass); +USE_PASS(reshape_transpose_matmul_v2_mkldnn_fuse_pass); diff --git a/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_v2_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_v2_mkldnn_fuse_pass.cc new file mode 100644 index 00000000000..203966dc682 --- /dev/null +++ b/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_v2_mkldnn_fuse_pass.cc @@ -0,0 +1,91 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_v2_mkldnn_fuse_pass.h" +#include +#include +#include +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/string/pretty_log.h" + +namespace paddle { +namespace framework { +namespace ir { + +ReshapeTransposeMatmulV2MkldnnFusePass:: + ReshapeTransposeMatmulV2MkldnnFusePass() { + op_name_ = "matmul_v2"; + + AddOpCompat(OpCompat("reshape2")) + .AddInput("X") + .IsTensor() + .End() + // The reshape2 op for this pass should not have "Shape" and "ShapeTensor" + .AddOutput("Out") + .IsTensor() + .End() + .AddOutput("XShape") + .IsOptional() + .IsTensor() + .End() + .AddAttr("shape") + .IsType>() + .End(); + + AddOpCompat(OpCompat("transpose2")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddOutput("XShape") + .IsOptional() + .IsTensor() + .End() + .AddAttr("axis") + .IsType>() + .End(); + + AddOpCompat(OpCompat(op_name_)) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("trans_x") + .IsType() + .End() + .AddAttr("trans_y") + .IsType() + .End(); +} +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(reshape_transpose_matmul_v2_mkldnn_fuse_pass, + paddle::framework::ir::ReshapeTransposeMatmulV2MkldnnFusePass); + +REGISTER_PASS_CAPABILITY(reshape_transpose_matmul_v2_mkldnn_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .EQ("matmul_v2", 0) + .EQ("transpose2", 0) + .EQ("reshape2", 0)); diff --git a/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_v2_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_v2_mkldnn_fuse_pass.h new file mode 100644 index 00000000000..7eeda7f1a61 --- /dev/null +++ b/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_v2_mkldnn_fuse_pass.h @@ -0,0 +1,39 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.h" + +namespace paddle { +namespace framework { +namespace ir { +/* + * Fuse Reshape->Transpose->MatMulV2 when MatMulV2 uses mkldnn. + */ + +class ReshapeTransposeMatmulV2MkldnnFusePass + : public ReshapeTransposeMatmulMkldnnFusePass { + public: + ReshapeTransposeMatmulV2MkldnnFusePass(); + virtual ~ReshapeTransposeMatmulV2MkldnnFusePass() {} + + protected: + const std::string name_scope_{"reshape_transpose_matmul_v2_fuse"}; +}; +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/pass_tester_helper.h b/paddle/fluid/framework/ir/pass_tester_helper.h index 284e54b3cb9..acefde9df68 100644 --- a/paddle/fluid/framework/ir/pass_tester_helper.h +++ b/paddle/fluid/framework/ir/pass_tester_helper.h @@ -307,6 +307,19 @@ struct Layers { return out; } + VarDesc* matmul_v2(VarDesc* x, VarDesc* y, VarDesc* alpha = nullptr, + bool trans_x = false, bool trans_y = false) { + VarDesc* out = lod_tensor(unique_name()); + OpDesc* op = program_.MutableBlock(0)->AppendOp(); + op->SetType("matmul_v2"); + op->SetInput("X", {x->Name()}); + op->SetInput("Y", {y->Name()}); + op->SetOutput("Out", {out->Name()}); + op->SetAttr("trans_x", trans_x); + op->SetAttr("trans_y", trans_y); + return out; + } + VarDesc* transpose2(VarDesc* x, std::vector axis, bool with_xshape = false) { VarDesc* out = lod_tensor(unique_name()); diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 334a70d3e06..d571973a83f 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -244,16 +244,17 @@ void CpuPassStrategy::EnableMKLDNN() { "conv3d_bias_mkldnn_fuse_pass", // "conv_elementwise_add_mkldnn_fuse_pass", "conv_concat_relu_mkldnn_fuse_pass", - "conv_relu_mkldnn_fuse_pass", // - "conv_leaky_relu_mkldnn_fuse_pass", // - "conv_relu6_mkldnn_fuse_pass", // - "conv_swish_mkldnn_fuse_pass", // - "conv_hard_swish_mkldnn_fuse_pass", // - "conv_hard_sigmoid_mkldnn_fuse_pass", // - "scale_matmul_fuse_pass", // - "reshape_transpose_matmul_mkldnn_fuse_pass", // - "matmul_transpose_reshape_fuse_pass", // - "matmul_v2_transpose_reshape_fuse_pass", // + "conv_relu_mkldnn_fuse_pass", // + "conv_leaky_relu_mkldnn_fuse_pass", // + "conv_relu6_mkldnn_fuse_pass", // + "conv_swish_mkldnn_fuse_pass", // + "conv_hard_swish_mkldnn_fuse_pass", // + "conv_hard_sigmoid_mkldnn_fuse_pass", // + "scale_matmul_fuse_pass", // + "reshape_transpose_matmul_mkldnn_fuse_pass", // + "reshape_transpose_matmul_v2_mkldnn_fuse_pass", // + "matmul_transpose_reshape_fuse_pass", // + "matmul_v2_transpose_reshape_fuse_pass", // // Disabled due to topology-dependent speed-up // "fc_mkldnn_pass", // "fc_act_mkldnn_fuse_pass", diff --git a/paddle/fluid/operators/compat/matmul_v2.pbtxt b/paddle/fluid/operators/compat/matmul_v2.pbtxt index fa2559939bb..cefb964a59f 100644 --- a/paddle/fluid/operators/compat/matmul_v2.pbtxt +++ b/paddle/fluid/operators/compat/matmul_v2.pbtxt @@ -39,6 +39,22 @@ extra { name: "op_device" type: STRING } + attrs { + name: "fused_reshape_X" + type: INTS + } + attrs { + name: "fused_reshape_Y" + type: INTS + } + attrs { + name: "fused_transpose_X" + type: INTS + } + attrs { + name: "fused_transpose_Y" + type: INTS + } attrs { name: "fused_reshape_Out" type: INTS diff --git a/paddle/fluid/operators/matmul_v2_op.cc b/paddle/fluid/operators/matmul_v2_op.cc index bd32af1c8f6..24201b1ba84 100644 --- a/paddle/fluid/operators/matmul_v2_op.cc +++ b/paddle/fluid/operators/matmul_v2_op.cc @@ -19,6 +19,81 @@ namespace paddle { namespace operators { +static framework::DDim GetDimForInput(const framework::InferShapeContext& ctx, + const std::string input_name) { + auto shape = ctx.Attrs().Get>("fused_reshape_" + input_name); + auto axis = + ctx.Attrs().Get>("fused_transpose_" + input_name); + auto dim = ctx.GetInputDim(input_name); + + PADDLE_ENFORCE_GT(dim.size(), 0, + platform::errors::InvalidArgument( + "The Input(%s) has not been initialized properly. The " + "shape of Input(%s) = [%s].", + dim)); + + // if mkldnn reshape+transpose+matmul fuse activated + if (!shape.empty() && !axis.empty()) { + PADDLE_ENFORCE_GE( + shape.size(), 2, + platform::errors::InvalidArgument( + "shape_%s attribute of MatMulOp was implemented for 2, 3 " + "or 4 dimensions.", + input_name)); + PADDLE_ENFORCE_LE( + shape.size(), 4, + platform::errors::InvalidArgument( + "shape_%s attribute of MatMulOp was implemented for 2, 3 " + "or 4 dimensions.", + input_name)); + PADDLE_ENFORCE_EQ( + shape.size(), axis.size(), + platform::errors::InvalidArgument( + "Ranks of shape_%s and axis_%s attributes of MatMulOp " + "must be equal.", + input_name, input_name)); + + int num_negative = std::count(shape.begin(), shape.end(), -1); + PADDLE_ENFORCE_LE(num_negative, 1, + platform::errors::InvalidArgument( + "The max number of -1 in fused_reshape_%s is 1 " + "but received %d.", + input_name, num_negative)); + + auto it_zero = std::find(shape.begin(), shape.end(), 0); + if (it_zero != shape.end()) { + for (uint64_t i = 0; i < shape.size(); i++) { + if (shape[i] == 0) { + PADDLE_ENFORCE_LT(i, dim.size(), + platform::errors::InvalidArgument( + "The index of 0 in fused_reshape_%s ", + "should be less than output dim size, ", + "but the index is %d and output dim size is %d", + input_name, i, dim.size())); + shape[i] = dim.at(i); + } + } + } + + // if "-1" is present then one of reshape dims must be infered + auto it_negative = std::find(shape.begin(), shape.end(), -1); + if (it_negative != shape.end()) { + int64_t dim_product = 1; + for (int i = 0; i < dim.size(); i++) { + dim_product *= dim.at(i); + } + + int64_t shape_product = std::accumulate(shape.begin(), shape.end(), -1, + std::multiplies()); + int index = std::distance(shape.begin(), it_negative); + shape[index] = dim_product / shape_product; + } + + dim = dim.reshape(shape).transpose(axis); + } + return dim; +} + class MatMulV2Op : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -30,9 +105,9 @@ class MatMulV2Op : public framework::OperatorWithKernel { bool trans_y = ctx->Attrs().Get("trans_y"); std::vector dims_x = - paddle::framework::vectorize(ctx->GetInputDim("X")); + framework::vectorize(GetDimForInput(*ctx, "X")); std::vector dims_y = - paddle::framework::vectorize(ctx->GetInputDim("Y")); + framework::vectorize(GetDimForInput(*ctx, "Y")); auto ndims_x = dims_x.size(); auto ndims_y = dims_y.size(); PADDLE_ENFORCE_GT(ndims_x, 0, @@ -215,6 +290,22 @@ class MatMulV2OpMaker : public framework::OpProtoAndCheckerMaker { .SetDefault("float32") .InEnum({"float32", "bfloat16"}) .AsExtra(); + AddAttr>("fused_reshape_X", + R"DOC(Shape of fused reshape of `X` input.)DOC") + .SetDefault({}) + .AsExtra(); + AddAttr>("fused_reshape_Y", + R"DOC(Shape of fused reshape of `Y` input.)DOC") + .SetDefault({}) + .AsExtra(); + AddAttr>("fused_transpose_X", + R"DOC(Axis of fused transpose of `X` input.)DOC") + .SetDefault({}) + .AsExtra(); + AddAttr>("fused_transpose_Y", + R"DOC(Axis of fused transpose of `Y` input.)DOC") + .SetDefault({}) + .AsExtra(); AddComment( R"DOC(Matrix multiplication Out = X * Y. A has shape (d0, d1 ... M, K), B has shape (d0, d1 ... K, N), Out has shape ((d0, d1 ... M, N)). diff --git a/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.h b/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.h index c82119d06a0..af4c154cd37 100644 --- a/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.h +++ b/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.h @@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#pragma once + #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/operators/math/blas.h" diff --git a/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc index 0266edac75d..5cb6ae34dce 100644 --- a/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc @@ -25,10 +25,88 @@ using paddle::platform::MKLDNNDeviceContext; using paddle::platform::MKLDNNGetDataType; using paddle::platform::to_void_cast; using Tensor = paddle::framework::Tensor; +using paddle::framework::DDim; using paddle::framework::GradVarName; using paddle::framework::make_ddim; using paddle::framework::vectorize; +// Get row matrix shape from a vector shape. If the rank of x_dim > 1, the +// original x_dim is returned. +static DDim RowMatrixDimsFromVector(const DDim& x_dim) { + return x_dim.size() > 1 ? x_dim : paddle::framework::make_ddim({1, x_dim[0]}); +} + +// Get column matrix shape from a vector shape. If the ran of y_dim > 1, the +// original y_dim is returned. +static DDim ColumnMatrixDimsFromVector(const DDim& y_dim) { + return y_dim.size() > 1 ? y_dim : paddle::framework::make_ddim({y_dim[0], 1}); +} + +static std::vector Transpose(const std::vector& x, + const std::vector& axis) { + size_t in_rank = x.size(); + size_t axis_size = axis.size(); + + auto axis_set = std::set(axis.begin(), axis.end()); + PADDLE_ENFORCE_EQ(axis_set.size(), axis_size, + paddle::platform::errors::InvalidArgument( + "In an axis array, elements must be unique.")); + + PADDLE_ENFORCE_EQ(in_rank, axis_size, + paddle::platform::errors::InvalidArgument( + "The input dimension's size " + "should be equal to the axis's size. " + "But received dimension is %d, " + "axis's size is %d", + in_rank, axis_size)); + + PADDLE_ENFORCE_LT(*std::max_element(axis.begin(), axis.end()), axis_size, + paddle::platform::errors::InvalidArgument( + "Axis values must be ranging from 0 to (dims - 1).")); + + std::vector new_x(x.size()); + for (size_t i = 0; i < x.size(); i++) { + new_x[i] = x[axis[i]]; + } + return new_x; +} + +std::vector GetInputStrides(const ExecutionContext& ctx, + const std::string input_name) { + auto shape = ctx.Attr>("fused_reshape_" + input_name); + auto axis = ctx.Attr>("fused_transpose_" + input_name); + auto input_dims = ctx.Input(input_name)->dims(); + auto new_dims = input_dims; + if (!shape.empty() && !axis.empty()) { + new_dims = input_dims.reshape(shape).transpose(axis); + } + + auto& MatrixDimsFromVector = + input_name == "X" ? RowMatrixDimsFromVector : ColumnMatrixDimsFromVector; + paddle::operators::math::MatDescriptor mat_dim = + paddle::operators::math::CreateMatrixDescriptor( + MatrixDimsFromVector(new_dims), 0, + ctx.Attr(std::string("trans_") + + static_cast(std::tolower(input_name[0])))); + + std::vector strides; + if (!shape.empty()) { + auto shape2 = input_dims.reshape(shape); + strides.push_back(1); + for (auto i = shape2.size() - 1; i > 0; --i) { + strides.insert(strides.begin(), + strides.front() * static_cast(shape2[i])); + } + strides = Transpose(strides, axis); + if (shape.size() == 2) + strides.insert(strides.begin(), + static_cast(shape[0] * shape[1])); + mat_dim.stride_ = strides[0]; + if (mat_dim.trans_) std::swap(*strides.rbegin(), *(++strides.rbegin())); + } + return strides; +} + template class MatMulV2MKLDNNHandler : public paddle::platform::MKLDNNHandlerNoCachingT { @@ -37,7 +115,9 @@ class MatMulV2MKLDNNHandler paddle::platform::Place cpu_place, const std::vector& x_org_dims, bool trans_x, const std::vector& y_org_dims, bool trans_y, - bool is_output_fused) + bool is_output_fused, + const std::vector& x_strides_override, + const std::vector& y_strides_override) : paddle::platform::MKLDNNHandlerNoCachingT(engine, cpu_place) { // M X K * K X N @@ -64,16 +144,24 @@ class MatMulV2MKLDNNHandler y_strides.reserve(x_dims.size()); out_strides.reserve(x_dims.size()); - if (!trans_x) { - x_strides.insert(x_strides.end(), {M * K, K, 1}); + if (!x_strides_override.empty()) { + x_strides = x_strides_override; } else { - x_strides.insert(x_strides.end(), {M * K, 1, M}); + if (!trans_x) { + x_strides.insert(x_strides.end(), {M * K, K, 1}); + } else { + x_strides.insert(x_strides.end(), {M * K, 1, M}); + } } - if (!trans_y) { - y_strides.insert(y_strides.end(), {N * K, N, 1}); + if (!y_strides_override.empty()) { + y_strides = y_strides_override; } else { - y_strides.insert(y_strides.end(), {N * K, 1, K}); + if (!trans_y) { + y_strides.insert(y_strides.end(), {N * K, N, 1}); + } else { + y_strides.insert(y_strides.end(), {N * K, 1, K}); + } } out_strides.insert(out_strides.end(), {M * N, N, 1}); @@ -82,8 +170,12 @@ class MatMulV2MKLDNNHandler for (int i = x_dims.size() - 4; i >= 0; --i) { out_ddims[i] = std::max(x_dims[i], y_dims[i]); - x_strides[i] = x_dims[i + 1] * x_strides[i + 1]; - y_strides[i] = y_dims[i + 1] * y_strides[i + 1]; + if (x_strides_override.empty()) { + x_strides[i] = x_dims[i + 1] * x_strides[i + 1]; + } + if (y_strides_override.empty()) { + y_strides[i] = y_dims[i + 1] * y_strides[i + 1]; + } out_strides[i] = out_ddims[i + 1] * out_strides[i + 1]; } @@ -146,9 +238,11 @@ void ExecuteMatMulV2(const ExecutionContext& ctx, const Tensor* y, std::vector& y_dims, bool trans_y, Tensor* out, std::vector& out_dims, int execution_number = 0) { + std::vector x_strides_override = GetInputStrides(ctx, "X"); + std::vector y_strides_override = GetInputStrides(ctx, "Y"); MatMulV2MKLDNNHandler handler(onednn_engine, ctx.GetPlace(), x_dims, - trans_x, y_dims, trans_y, - IsOutputFused(ctx)); + trans_x, y_dims, trans_y, IsOutputFused(ctx), + x_strides_override, y_strides_override); const auto src_memory_p = handler.AcquireSrcMemory(x); const auto weights_memory_p = handler.AcquireWeightsMemory(y); @@ -171,6 +265,17 @@ void ExecuteMatMulV2(const ExecutionContext& ctx, out->set_format(format); } +DDim GetDimForInput(const paddle::framework::ExecutionContext& ctx, + const std::string& input_name) { + auto shape = ctx.Attr>("fused_reshape_" + input_name); + auto axis = ctx.Attr>("fused_transpose_" + input_name); + auto dim = ctx.Input(input_name)->dims(); + if (!shape.empty() && !axis.empty()) { + dim = dim.reshape(shape).transpose(axis); + } + return dim; +} + template class MatMulV2MKLDNNKernel : public paddle::framework::OpKernel { public: @@ -230,11 +335,11 @@ class MatMulV2MKLDNNKernel : public paddle::framework::OpKernel { bool trans_x = ctx.Attr("trans_x"); bool trans_y = ctx.Attr("trans_y"); - auto x_dims = vectorize(x->dims()); - auto y_dims = vectorize(y->dims()); + auto x_dims = vectorize(GetDimForInput(ctx, "X")); + auto y_dims = vectorize(GetDimForInput(ctx, "Y")); auto out_dims = vectorize(out->dims()); - int ndims = std::max(x->dims().size(), y->dims().size()); + int ndims = std::max(x_dims.size(), y_dims.size()); ndims = std::max(ndims, 3); std::vector x_bd_dims(ndims, 1); @@ -398,8 +503,6 @@ class MatMulV2GradMKLDNNKernel : public paddle::framework::OpKernel { }; } // anonymous namespace -namespace ops = paddle::operators; - REGISTER_OP_KERNEL(matmul_v2, MKLDNN, ::paddle::platform::CPUPlace, MatMulV2MKLDNNKernel, MatMulV2MKLDNNKernel); diff --git a/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py b/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py index 0627bf2123a..97c41443e04 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py @@ -638,6 +638,8 @@ class Quant2Int8MkldnnPass(object): graph = self._apply_pass(graph, 'scale_matmul_fuse_pass') graph = self._apply_pass(graph, 'reshape_transpose_matmul_mkldnn_fuse_pass') + graph = self._apply_pass(graph, + 'reshape_transpose_matmul_v2_mkldnn_fuse_pass') graph = self._apply_pass( graph, 'cpu_quantize_pass', ['quant_var_scales', 'data_layout'], [self._var_quant_scales, self._get_data_layout(graph)]) diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_reshape_transpose_matmul_v2_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_reshape_transpose_matmul_v2_fuse_pass.py new file mode 100644 index 00000000000..caf33156fc1 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_reshape_transpose_matmul_v2_fuse_pass.py @@ -0,0 +1,78 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from inference_pass_test import InferencePassTest +import paddle +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.fluid.core import AnalysisConfig +from paddle.fluid.core import PassVersionChecker + + +class TestReshapeTransposeMatmulV2OneDNNFusePass(InferencePassTest): + def setUp(self): + self.set_params() + self.tranpose_perm = [0, 2, 1, 3] + self.pass_name = 'reshape_transpose_matmul_v2_mkldnn_fuse_pass' + + with fluid.program_guard(self.main_program, self.startup_program): + data = fluid.data( + name="data", shape=self.data_shape, dtype="float32") + weight = fluid.layers.create_parameter( + shape=self.weight_shape, dtype="float32") + reshape = fluid.layers.reshape(data, shape=self.reshape_shape) + transpose = fluid.layers.transpose(reshape, self.tranpose_perm) + matmul = paddle.matmul( + transpose, + weight, + transpose_x=self.transpose_x, + transpose_y=self.transpose_y) + + self.fetch_list = [matmul] + self.enable_mkldnn = True + + def set_params(self): + self.data_shape = [-1, 128, 768] + self.weight_shape = [1, 12, 64, 128] + self.feeds = {"data": np.random.random((1, 128, 768)).astype("float32")} + self.transpose_x = False + self.transpose_y = False + self.reshape_shape = [0, 0, 12, 64] + + def test_check_output(self): + use_gpu = False + self.check_output_with_option(use_gpu) + + def test_pass_compatible(self): + self.assertTrue(PassVersionChecker.IsCompatible(self.pass_name)) + + +class TestReshapeTransposeMatmulV2OneDNNFusePassBroadcast( + TestReshapeTransposeMatmulV2OneDNNFusePass): + def set_params(self): + self.data_shape = [2, 64, 16] + self.weight_shape = [1, 2, 8, 64] + self.feeds = {"data": np.random.random((2, 64, 16)).astype("float32")} + self.transpose_x = True + self.transpose_y = True + self.reshape_shape = [0, 0, 2, 8] + + +if __name__ == "__main__": + paddle.enable_static() + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_mkldnn_op.py index 4ab15ac4480..d13012ee338 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_mkldnn_op.py @@ -252,7 +252,7 @@ class TestDnnlMatMulOpInt8ForceFP32BasicScales(TestDnnlMatMulOp): @skip_check_grad_ci(reason="DNNL's MatMul doesn't implement grad kernel.") -class TestMatMulOpReshapeTranspose(OpTest): +class TestReshapeTransposeMatMulOp(OpTest): def init_data_type(self): self.data_type_ = 'float32' @@ -267,10 +267,12 @@ class TestMatMulOpReshapeTranspose(OpTest): self.fused_reshape_Y = [] self.fused_transpose_Y = [] - def setUp(self): - # Set max isa, otherwise fails on SKX and earlier - os.environ["DNNL_MAX_CPU_ISA"] = "AVX" + def set_op_type_and_transpose_y_name(self): self.op_type = "matmul" + self.transpose_y_name = "transpose_Y" + + def setUp(self): + self.set_op_type_and_transpose_y_name() self._cpu_only = True self.use_mkldnn = True self.transpose_y = True @@ -280,7 +282,7 @@ class TestMatMulOpReshapeTranspose(OpTest): self.inputs = {'X': self.x, 'Y': self.y} self.attrs = { 'use_mkldnn': self.use_mkldnn, - 'transpose_Y': self.transpose_y + self.transpose_y_name: self.transpose_y } if len(self.fused_transpose_X) > 0: self.attrs['fused_transpose_X'] = self.fused_transpose_X @@ -297,7 +299,7 @@ class TestMatMulOpReshapeTranspose(OpTest): self.check_output() -class TestMatMulOpReshapeTranspose4DXFloat(TestMatMulOpReshapeTranspose): +class TestReshapeTransposeMatMulOp4DXFloat(TestReshapeTransposeMatMulOp): def generate_data(self): self.x = np.random.random([2, 128, 768]).astype("float32") self.y = np.random.random([2, 128, 768]).astype("float32").reshape( @@ -311,12 +313,12 @@ class TestMatMulOpReshapeTranspose4DXFloat(TestMatMulOpReshapeTranspose): self.y.transpose([0, 1, 3, 2])) -class TestMatMulOpReshapeTranspose4DXInt8(TestMatMulOpReshapeTranspose4DXFloat): +class TestReshapeTransposeMatMulOp4DXInt8(TestReshapeTransposeMatMulOp4DXFloat): def init_data_type(self): self.data_type_ = 'int8' -class TestMatMulOpReshapeTranspose4DYFloat(TestMatMulOpReshapeTranspose): +class TestReshapeTransposeMatMulOp4DYFloat(TestReshapeTransposeMatMulOp): def generate_data(self): self.x = np.random.random([2, 128, 768]).astype("float32").reshape( [2, 128, 12, 64]).transpose([0, 2, 1, 3]) @@ -329,12 +331,12 @@ class TestMatMulOpReshapeTranspose4DYFloat(TestMatMulOpReshapeTranspose): self.x, self.y.reshape([2, 128, 12, 64]).transpose([0, 2, 3, 1])) -class TestMatMulOpReshapeTranspose4DYInt8(TestMatMulOpReshapeTranspose4DYFloat): +class TestReshapeTransposeMatMulOp4DYInt8(TestReshapeTransposeMatMulOp4DYFloat): def init_data_type(self): self.data_type_ = 'int8' -class TestMatMulOpReshapeTranspose4DXYFloat(TestMatMulOpReshapeTranspose): +class TestReshapeTransposeMatMulOp4DXYFloat(TestReshapeTransposeMatMulOp): def generate_data(self): self.x = np.random.random([2, 128, 768]).astype("float32") self.y = np.random.random([2, 128, 768]).astype("float32") @@ -347,13 +349,13 @@ class TestMatMulOpReshapeTranspose4DXYFloat(TestMatMulOpReshapeTranspose): self.y.reshape([2, 128, 12, 64]).transpose([0, 2, 3, 1])) -class TestMatMulOpReshapeTranspose4DXYInt8( - TestMatMulOpReshapeTranspose4DXYFloat): +class TestReshapeTransposeMatMulOp4DXYInt8( + TestReshapeTransposeMatMulOp4DXYFloat): def init_data_type(self): self.data_type_ = 'int8' -class TestMatMulOpReshapeTranspose2DXFloat(TestMatMulOpReshapeTranspose): +class TestReshapeTransposeMatMulOp2DXFloat(TestReshapeTransposeMatMulOp): def generate_data(self): self.x = np.random.random([2, 5, 10]).astype("float32") self.y = np.random.random([2, 5, 10]).astype("float32").reshape( @@ -367,12 +369,12 @@ class TestMatMulOpReshapeTranspose2DXFloat(TestMatMulOpReshapeTranspose): self.y.transpose([1, 0])) -class TestMatMulOpReshapeTranspose2DXInt8(TestMatMulOpReshapeTranspose2DXFloat): +class TestReshapeTransposeMatMulOp2DXInt8(TestReshapeTransposeMatMulOp2DXFloat): def init_data_type(self): self.data_type_ = 'int8' -class TestMatMulOpReshapeTranspose2DYFloat(TestMatMulOpReshapeTranspose): +class TestReshapeTransposeMatMulOp2DYFloat(TestReshapeTransposeMatMulOp): def generate_data(self): self.x = np.random.random([2, 5, 10]).astype("float32").reshape( [10, 10]).transpose([1, 0]) @@ -384,12 +386,12 @@ class TestMatMulOpReshapeTranspose2DYFloat(TestMatMulOpReshapeTranspose): self.out = np.matmul(self.x, self.y.reshape([10, 10])) -class TestMatMulOpReshapeTranspose2DYInt8(TestMatMulOpReshapeTranspose2DYFloat): +class TestReshapeTransposeMatMulOp2DYInt8(TestReshapeTransposeMatMulOp2DYFloat): def init_data_type(self): self.data_type_ = 'int8' -class TestMatMulOpReshapeTranspose3DXFloat(TestMatMulOpReshapeTranspose): +class TestReshapeTransposeMatMulOp3DXFloat(TestReshapeTransposeMatMulOp): def generate_data(self): self.x = np.random.random([2, 2, 5, 5]).astype("float32") self.y = np.random.random([2, 2, 5, 5]).astype("float32").reshape( @@ -403,12 +405,12 @@ class TestMatMulOpReshapeTranspose3DXFloat(TestMatMulOpReshapeTranspose): self.y.transpose(0, 2, 1)) -class TestMatMulOpReshapeTranspose3DXInt8(TestMatMulOpReshapeTranspose3DXFloat): +class TestReshapeTransposeMatMulOp3DXInt8(TestReshapeTransposeMatMulOp3DXFloat): def init_data_type(self): self.data_type_ = 'int8' -class TestMatMulOpReshapeTranspose3DYFloat(TestMatMulOpReshapeTranspose): +class TestReshapeTransposeMatMulOp3DYFloat(TestReshapeTransposeMatMulOp): def generate_data(self): self.x = np.random.random([2, 2, 5, 5]).astype(self.data_type_).reshape( [2, 10, 5]).transpose([0, 2, 1]) @@ -420,7 +422,7 @@ class TestMatMulOpReshapeTranspose3DYFloat(TestMatMulOpReshapeTranspose): self.out = np.matmul(self.x, self.y.reshape([2, 10, 5])) -class TestMatMulOpReshapeTranspose3DYInt8(TestMatMulOpReshapeTranspose3DYFloat): +class TestReshapeTransposeMatMulOp3DYInt8(TestReshapeTransposeMatMulOp3DYFloat): def init_data_type(self): self.data_type_ = 'int8' diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_v2_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_v2_mkldnn_op.py index 9afe45efee3..5dd1795818c 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_v2_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_v2_mkldnn_op.py @@ -29,7 +29,11 @@ from paddle.fluid.tests.unittests.mkldnn.test_matmul_mkldnn_op import ( TestMatMulOpTransposeReshapeOtherDimFloat, TestMatMulOpTransposeReshapeTransposeAxisNotSupportedException, TestMatMulOpTransposeReshapeTransposeRankNotSupportedException, - TestMatMulOpTransposeReshapeRankOfReshapeNotSupportedException) + TestMatMulOpTransposeReshapeRankOfReshapeNotSupportedException, + TestReshapeTransposeMatMulOp, TestReshapeTransposeMatMulOp4DXFloat, + TestReshapeTransposeMatMulOp4DYFloat, TestReshapeTransposeMatMulOp4DXYFloat, + TestReshapeTransposeMatMulOp2DXFloat, TestReshapeTransposeMatMulOp2DYFloat, + TestReshapeTransposeMatMulOp3DXFloat, TestReshapeTransposeMatMulOp3DYFloat) def reference_matmul(X, Y, transpose_x=False, transpose_y=False): @@ -434,6 +438,61 @@ class TestMatMulV2OpTransposeReshapeTransposeRankNotSupportedException( self.op_type = "matmul_v2" +class TestMatMulV2OpReshapeTranspose(TestReshapeTransposeMatMulOp): + def set_op_type_and_transpose_y_name(self): + self.op_type = "matmul_v2" + self.transpose_y_name = "trans_y" + + +class TestMatMulV2OpReshapeTranspose4DXFloat( + TestReshapeTransposeMatMulOp4DXFloat): + def set_op_type_and_transpose_y_name(self): + self.op_type = "matmul_v2" + self.transpose_y_name = "trans_y" + + +class TestMatMulV2OpReshapeTranspose4DYFloat( + TestReshapeTransposeMatMulOp4DYFloat): + def set_op_type_and_transpose_y_name(self): + self.op_type = "matmul_v2" + self.transpose_y_name = "trans_y" + + +class TestMatMulV2OpReshapeTranspose4DXYFloat( + TestReshapeTransposeMatMulOp4DXYFloat): + def set_op_type_and_transpose_y_name(self): + self.op_type = "matmul_v2" + self.transpose_y_name = "trans_y" + + +class TestMatMulV2OpReshapeTranspose2DXFloat( + TestReshapeTransposeMatMulOp2DXFloat): + def set_op_type_and_transpose_y_name(self): + self.op_type = "matmul_v2" + self.transpose_y_name = "trans_y" + + +class TestMatMulV2OpReshapeTranspose2DYFloat( + TestReshapeTransposeMatMulOp2DYFloat): + def set_op_type_and_transpose_y_name(self): + self.op_type = "matmul_v2" + self.transpose_y_name = "trans_y" + + +class TestMatMulV2OpReshapeTranspose3DXFloat( + TestReshapeTransposeMatMulOp3DXFloat): + def set_op_type_and_transpose_y_name(self): + self.op_type = "matmul_v2" + self.transpose_y_name = "trans_y" + + +class TestMatMulV2OpReshapeTranspose3DYFloat( + TestReshapeTransposeMatMulOp3DYFloat): + def set_op_type_and_transpose_y_name(self): + self.op_type = "matmul_v2" + self.transpose_y_name = "trans_y" + + if __name__ == "__main__": paddle.enable_static() unittest.main() -- GitLab