From e1a7a880576f072bf27cbda568bcc4a5bcfb25fd Mon Sep 17 00:00:00 2001 From: Sylwester Fraczek Date: Tue, 28 Apr 2020 11:31:35 +0200 Subject: [PATCH] added reshape transpose matmul fuse pass (#23754) --- paddle/fluid/framework/ddim.cc | 28 +-- paddle/fluid/framework/ir/CMakeLists.txt | 4 + .../framework/ir/graph_pattern_detector.cc | 52 +++++- .../framework/ir/graph_pattern_detector.h | 23 +++ ...shape_transpose_matmul_mkldnn_fuse_pass.cc | 119 ++++++++++++ ...eshape_transpose_matmul_mkldnn_fuse_pass.h | 41 +++++ ...ranspose_matmul_mkldnn_fuse_pass_tester.cc | 124 +++++++++++++ .../fluid/framework/ir/pass_tester_helper.h | 25 ++- .../inference/api/paddle_pass_builder.cc | 13 +- paddle/fluid/operators/matmul_op.cc | 47 ++++- .../operators/mkldnn/matmul_mkldnn_op.cc | 131 +++++++++---- paddle/fluid/platform/device_context.cc | 5 +- .../quantization/qat2_int8_mkldnn_pass.py | 2 + .../unittests/mkldnn/test_matmul_mkldnn_op.py | 174 ++++++++++++++++++ 14 files changed, 724 insertions(+), 64 deletions(-) create mode 100644 paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.cc create mode 100644 paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.h create mode 100644 paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass_tester.cc diff --git a/paddle/fluid/framework/ddim.cc b/paddle/fluid/framework/ddim.cc index 11e11e7f822..799deec1b69 100644 --- a/paddle/fluid/framework/ddim.cc +++ b/paddle/fluid/framework/ddim.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/framework/ddim.h" +#include #include "paddle/fluid/platform/enforce.h" namespace paddle { @@ -158,6 +159,11 @@ DDim DDim::transpose(const std::vector& axis) const { size_t in_rank = in_dims.size(); size_t axis_size = axis.size(); + auto axis_set = std::set(axis.begin(), axis.end()); + PADDLE_ENFORCE_EQ(axis_set.size(), axis_size, + platform::errors::InvalidArgument( + "In an axis array, elements must be unique.")); + PADDLE_ENFORCE_EQ( in_rank, axis_size, platform::errors::InvalidArgument("The input dimension's size " @@ -166,25 +172,9 @@ DDim DDim::transpose(const std::vector& axis) const { "axis's size is %d", in_rank, axis_size)); - std::vector count(axis_size, 0); - for (size_t i = 0; i < axis_size; i++) { - PADDLE_ENFORCE_LT(axis[i], static_cast(axis_size), - platform::errors::InvalidArgument( - "ValueError: Each element of axis must appear " - "exactly once in the range from 0 to (dims - 1), " - "where the dims is the axis's size, " - "but received axis[%d] is %d, axis_size is %d", - i, axis[i], axis_size)); - PADDLE_ENFORCE_EQ( - ++count[axis[i]], 1, - platform::errors::InvalidArgument( - "ValueError: Each element of axis should " - "be a unique value range from 0 to (dims - 1), " - "where the dims is the axis's size, " - "unique value means this axis value can appear only once. 
" - "But received count[axis[%d]] is %d", - i, count[axis[i]])); - } + PADDLE_ENFORCE_LT(*std::max_element(axis.begin(), axis.end()), axis_size, + platform::errors::InvalidArgument( + "Axis values must be ranging from 0 to (dims - 1).")); DDim out_dims(in_dims); for (size_t i = 0; i < axis_size; i++) { diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 5dbac6239c1..7c49bc1dcd0 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -97,6 +97,7 @@ if(WITH_MKLDNN) pass_library(cpu_quantize_placement_pass base DIR mkldnn) pass_library(cpu_quantize_pass inference DIR mkldnn) pass_library(cpu_quantize_squash_pass inference DIR mkldnn) + pass_library(reshape_transpose_matmul_mkldnn_fuse_pass inference DIR mkldnn) pass_library(matmul_transpose_reshape_fuse_pass inference DIR mkldnn) endif() @@ -145,5 +146,8 @@ if (WITH_MKLDNN) cc_test(test_cpu_quantize_placement_pass SRCS mkldnn/cpu_quantize_placement_pass_tester.cc DEPS cpu_quantize_placement_pass) cc_test(test_cpu_quantize_pass SRCS mkldnn/cpu_quantize_pass_tester.cc DEPS cpu_quantize_pass naive_executor) cc_test(test_cpu_quantize_squash_pass SRCS mkldnn/cpu_quantize_squash_pass_tester.cc DEPS cpu_quantize_squash_pass naive_executor) + if(NOT WITH_COVERAGE) + cc_test(test_reshape_transpose_matmul_mkldnn_fuse_pass SRCS mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass_tester.cc DEPS reshape_transpose_matmul_mkldnn_fuse_pass) + endif() cc_test(test_matmul_transpose_reshape_fuse_pass SRCS mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc DEPS matmul_transpose_reshape_fuse_pass) endif () diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 02861b6edcd..f91bbdf0a29 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -33,7 +33,6 @@ namespace paddle { namespace framework { namespace ir { -using string::PrettyLogEndl; using string::PrettyLog; using string::Style; @@ -2148,6 +2147,57 @@ void patterns::DeleteQuantDequantOpPattern::operator()() { any_op2->LinksFrom({quant_dequant_out}); } +PDNode *patterns::ReshapeTransposeMatmulPattern::operator()( + bool with_reshape_xshape, bool with_transpose_xshape) { + auto reshape_op = + pattern->NewNode(reshape_op_repr())->assert_is_op("reshape2"); + auto transpose_op = + pattern->NewNode(transpose_op_repr())->assert_is_op("transpose2"); + auto matmul_op = pattern->NewNode(matmul_op_repr())->assert_is_op("matmul"); + + auto reshape_in = pattern->NewNode(reshape_in_repr()) + ->AsInput() + ->assert_is_op_input("reshape2", "X"); + + auto reshape_out = pattern->NewNode(reshape_out_repr()) + ->AsIntermediate() + ->assert_is_op_input("transpose2", "X") + ->assert_is_op_output("reshape2", "Out"); + if (!with_reshape_xshape) + reshape_out->assert_is_only_output_of_op("reshape2"); + + auto reshape_xshape = with_reshape_xshape + ? pattern->NewNode(reshape_xshape_repr()) + ->AsIntermediate() + ->assert_is_op_output("reshape2", "XShape") + : nullptr; + + auto transpose_out = pattern->NewNode(transpose_out_repr()) + ->AsIntermediate() + ->assert_is_op_input("matmul") + ->assert_is_op_output("transpose2", "Out"); + if (!with_transpose_xshape) + transpose_out->assert_is_only_output_of_op("transpose2"); + + auto transpose_xshape = + with_transpose_xshape + ? 
pattern->NewNode(transpose_xshape_repr()) + ->AsIntermediate() + ->assert_is_op_output("transpose2", "XShape") + : nullptr; + + auto matmul_out = pattern->NewNode(matmul_out_repr()) + ->AsOutput() + ->assert_is_op_output("matmul", "Out"); + + reshape_op->LinksFrom({reshape_in}).LinksTo({reshape_out}); + if (with_reshape_xshape) reshape_op->LinksTo({reshape_xshape}); + transpose_op->LinksFrom({reshape_out}).LinksTo({transpose_out}); + if (with_transpose_xshape) transpose_op->LinksTo({transpose_xshape}); + matmul_op->LinksFrom({transpose_out}).LinksTo({matmul_out}); + return matmul_out; +} + PDNode *patterns::MatmulTransposeReshapePattern::operator()() { auto reshape_op = pattern->NewNode(reshape_op_repr())->assert_is_op("reshape2"); diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 0f4ca1aa96b..4d7a4e283d3 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -1210,6 +1210,29 @@ struct DeleteQuantDequantOpPattern : public PatternBase { PATTERN_DECL_NODE(any_op2); }; +// Reshape + Transpose + Matmul +// named nodes: +// reshape_op, reshape_out, reshape_xshape, +// transpose_op, transpose_out, transpose_xshape, +// matmul_op, matmul_out +struct ReshapeTransposeMatmulPattern : public PatternBase { + ReshapeTransposeMatmulPattern(PDPattern* pattern, + const std::string& name_scope) + : PatternBase(pattern, name_scope, "reshape_transpose_matmul") {} + + PDNode* operator()(bool with_reshape_xshape, bool with_transpose_xshape); + + PATTERN_DECL_NODE(reshape_in); + PATTERN_DECL_NODE(reshape_op); + PATTERN_DECL_NODE(reshape_out); + PATTERN_DECL_NODE(reshape_xshape); + PATTERN_DECL_NODE(transpose_op); + PATTERN_DECL_NODE(transpose_out); + PATTERN_DECL_NODE(transpose_xshape); + PATTERN_DECL_NODE(matmul_op); + PATTERN_DECL_NODE(matmul_out); +}; + // Matmul + Transpose + Reshape struct MatmulTransposeReshapePattern : public PatternBase { MatmulTransposeReshapePattern(PDPattern* pattern, diff --git a/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.cc new file mode 100644 index 00000000000..b4c53ec5f91 --- /dev/null +++ b/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.cc @@ -0,0 +1,119 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.h" +#include +#include +#include +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/string/pretty_log.h" + +namespace paddle { +namespace framework { +namespace ir { + +void ReshapeTransposeMatmulMkldnnFusePass::Fuse( + Graph *graph, bool with_reshape_xshape, bool with_transpose_xshape) const { + GraphPatternDetector gpd; + patterns::ReshapeTransposeMatmulPattern rtm_pattern(gpd.mutable_pattern(), + name_scope_); + + rtm_pattern(with_reshape_xshape, with_transpose_xshape); + + int found_reshape_transpose_matmul_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, + Graph *g) { + VLOG(4) << "handle ReshapeTransposeMatmulMkldnn fuse"; + GET_IR_NODE_FROM_SUBGRAPH(reshape_in, reshape_in, rtm_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape_op, reshape_op, rtm_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape_out, reshape_out, rtm_pattern); + ir::Node *reshape_xshape{nullptr}; + if (with_reshape_xshape) { + GET_IR_NODE_FROM_SUBGRAPH(reshape_xshape1, reshape_xshape, rtm_pattern); + reshape_xshape = reshape_xshape1; + } + GET_IR_NODE_FROM_SUBGRAPH(transpose_op, transpose_op, rtm_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose_out, transpose_out, rtm_pattern); + ir::Node *transpose_xshape{nullptr}; + if (with_transpose_xshape) { + GET_IR_NODE_FROM_SUBGRAPH(transpose_xshape1, transpose_xshape, + rtm_pattern); + transpose_xshape = transpose_xshape1; + } + GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, rtm_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, rtm_pattern); + + auto reshape_shape = + boost::get>(reshape_op->Op()->GetAttr("shape")); + auto transpose_axis = + boost::get>(transpose_op->Op()->GetAttr("axis")); + + OpDesc *matmul_desc = matmul_op->Op(); + std::string input_var_name = transpose_out->Name(); + + auto UpdateMatmul = [&](std::string matmul_input_name) { + matmul_desc->SetInput(matmul_input_name, {(reshape_in)->Name()}); + matmul_desc->SetAttr("fused_reshape_" + matmul_input_name, reshape_shape); + matmul_desc->SetAttr("fused_transpose_" + matmul_input_name, + transpose_axis); + }; + if (matmul_desc->Inputs().at("X").at(0) == input_var_name) { + UpdateMatmul("X"); + } else if (matmul_desc->Inputs().at("Y").at(0) == input_var_name) { + UpdateMatmul("Y"); + } else { + throw platform::errors::InvalidArgument( + "Unexpected input to MatMul encountered."); + } + + std::unordered_set nodes_to_remove{ + reshape_op, reshape_out, transpose_op, transpose_out}; + if (with_reshape_xshape) nodes_to_remove.insert(reshape_xshape); + if (with_transpose_xshape) nodes_to_remove.insert(transpose_xshape); + GraphSafeRemoveNodes(graph, nodes_to_remove); + + IR_NODE_LINK_TO(reshape_in, matmul_op); + + ++found_reshape_transpose_matmul_count; + }; + + gpd(graph, handler); + AddStatis(found_reshape_transpose_matmul_count); + + std::stringstream msg_ss; + msg_ss << "--- Fused " << found_reshape_transpose_matmul_count + << " ReshapeTransposeMatmulMkldnn patterns"; + if (with_reshape_xshape) msg_ss << " with reshape's xshape"; + if (with_transpose_xshape) msg_ss << " with transpose's xshape"; + string::PrettyLogDetail(msg_ss.str().c_str()); +} + +void ReshapeTransposeMatmulMkldnnFusePass::ApplyImpl(ir::Graph *graph) const { + PADDLE_ENFORCE_NOT_NULL(graph, + platform::errors::InvalidArgument( + "Pointer to graph argument should not be NULL.")); + FusePassBase::Init(name_scope_, graph); + + Fuse(graph, false, false); + Fuse(graph, false, true); + Fuse(graph, true, false); 
+ Fuse(graph, true, true); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(reshape_transpose_matmul_mkldnn_fuse_pass, + paddle::framework::ir::ReshapeTransposeMatmulMkldnnFusePass); diff --git a/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.h new file mode 100644 index 00000000000..eab9f095623 --- /dev/null +++ b/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.h @@ -0,0 +1,41 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" + +namespace paddle { +namespace framework { +namespace ir { +/* + * Fuse Reshape->Transpose->MatMul when MatMul uses mkldnn. + */ +class ReshapeTransposeMatmulMkldnnFusePass : public FusePassBase { + public: + virtual ~ReshapeTransposeMatmulMkldnnFusePass() {} + + protected: + void ApplyImpl(ir::Graph* graph) const override; + const std::string name_scope_{"reshape_transpose_matmul_fuse"}; + + void Fuse(Graph* graph, bool with_reshape_xshape, + bool with_transpose_xshape) const; +}; +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass_tester.cc new file mode 100644 index 00000000000..a9392b998a9 --- /dev/null +++ b/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass_tester.cc @@ -0,0 +1,124 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.h" + +#include +#include +#include "paddle/fluid/framework/ir/pass_tester_helper.h" + +namespace paddle { +namespace framework { +namespace ir { + +void AddVarToScope(Scope* param_scope, const std::string& name, + const DDim& dims) { + auto* tensor = param_scope->Var(name)->GetMutable(); + tensor->Resize(dims); + tensor->mutable_data(platform::CPUPlace()); +} + +Scope* CreateParamScope() { + auto param_scope = new Scope(); + AddVarToScope(param_scope, "w1", {768, 768}); + AddVarToScope(param_scope, "bias1", {768}); + AddVarToScope(param_scope, "w2", {768, 768}); + AddVarToScope(param_scope, "bias2", {768}); + return param_scope; +} + +void TestMain(bool with_xshapes) { + // inputs operator output + // ----------------------------------------------- + // a1,w1,bias1 fc -> b1 + // b1 reshape -> c1 + // c1 transpose -> d1 + // a2,w2,bias2 fc -> b2 + // b2 reshape -> c2 + // c2 transpose -> d2 + // (d1, d2) matmul -> (...) + Layers layers; + auto* a1 = layers.data("a1", {-1, 128, 768}); + auto* w1 = layers.data("w1", {768, 768}, true); + auto* bias1 = layers.data("bias1", {768}, true); + auto* b1 = layers.fc(a1, w1, bias1, 2); + b1->SetShape({-1, 128, 768}); + auto* c1 = layers.reshape2(b1, {0, 0, 12, 64}, with_xshapes); + c1->SetShape({-1, 128, 12, 64}); + auto* d1 = layers.transpose2(c1, {0, 2, 1, 3}, with_xshapes); + d1->SetShape({-1, 12, 128, 64}); + auto* a2 = layers.data("a2", {-1, 128, 768}); + auto* w2 = layers.data("w2", {768, 768}, true); + auto* bias2 = layers.data("bias2", {768}, true); + auto* b2 = layers.fc(a2, w2, bias2, 2); + b2->SetShape({-1, 128, 768}); + auto* c2 = layers.reshape2(b2, {0, 0, 12, 64}); + c2->SetShape({-1, 128, 12, 64}); + auto* d2 = layers.transpose2(c2, {0, 2, 1, 3}); + d2->SetShape({-1, 12, 128, 64}); + layers.matmul(d1, d2); + + std::unique_ptr graph(new ir::Graph(layers.main_program())); + graph->Set("__param_scope__", CreateParamScope()); + + int num_reshape_nodes_before = GetNumOpNodes(graph, "reshape2"); + int num_transpose_nodes_before = GetNumOpNodes(graph, "transpose2"); + int total_nodes_before = graph->Nodes().size(); + VLOG(3) << DebugString(graph); + + auto pass = + PassRegistry::Instance().Get("reshape_transpose_matmul_mkldnn_fuse_pass"); + graph.reset(pass->Apply(graph.release())); + + int num_reshape_nodes_after = GetNumOpNodes(graph, "reshape2"); + int num_transpose_nodes_after = GetNumOpNodes(graph, "transpose2"); + int total_nodes_after = graph->Nodes().size(); + VLOG(3) << DebugString(graph); + + EXPECT_EQ(num_reshape_nodes_before, 2); + EXPECT_EQ(num_reshape_nodes_after, 0); + EXPECT_EQ(num_transpose_nodes_before, 2); + EXPECT_EQ(num_transpose_nodes_after, 0); + int removed = 8; // 2* reshape, reshape_out, transpose, transpose_out + if (with_xshapes) removed += 2; // transpose_xshape, reshape_xshape + EXPECT_EQ(total_nodes_before - removed, total_nodes_after); + auto* matmul_op_desc = GetOpNodes(graph, "matmul").at(0)->Op(); + + auto check = [&matmul_op_desc](std::string a) { + std::string shape_str = "fused_reshape_" + a; + EXPECT_THAT(matmul_op_desc->GetAttrIfExists>(shape_str), + testing::ElementsAre(0, 0, 12, 64)); + std::string axis_str = "fused_transpose_" + a; + EXPECT_THAT(matmul_op_desc->GetAttrIfExists>(axis_str), + testing::ElementsAre(0, 2, 1, 3)); + }; + check("X"); + check("Y"); +} + +TEST(ReshapeTransposeMatmulMkldnnFusePass, + both_matmul_inputs_reshape_transpose) { + TestMain(false); +} + +TEST(ReshapeTransposeMatmulMkldnnFusePass, + 
both_matmul_inputs_reshape_transpose_one_with_xshapes) { + TestMain(true); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +USE_PASS(reshape_transpose_matmul_mkldnn_fuse_pass); diff --git a/paddle/fluid/framework/ir/pass_tester_helper.h b/paddle/fluid/framework/ir/pass_tester_helper.h index ac438d368de..9001402233b 100644 --- a/paddle/fluid/framework/ir/pass_tester_helper.h +++ b/paddle/fluid/framework/ir/pass_tester_helper.h @@ -258,23 +258,33 @@ struct Layers { return out; } - VarDesc* transpose2(VarDesc* x, std::vector axis) { + VarDesc* transpose2(VarDesc* x, std::vector axis, + bool with_xshape = false) { VarDesc* out = lod_tensor(unique_name()); OpDesc* op = program_.MutableBlock(0)->AppendOp(); op->SetType("transpose2"); op->SetInput("X", {x->Name()}); op->SetAttr("axis", axis); op->SetOutput("Out", {out->Name()}); + if (with_xshape) { + VarDesc* xshape = lod_tensor(unique_name()); + op->SetOutput("XShape", {xshape->Name()}); + } return out; } - VarDesc* reshape2(VarDesc* x, std::vector shape) { + VarDesc* reshape2(VarDesc* x, std::vector shape, + bool with_xshape = false) { VarDesc* out = lod_tensor(unique_name()); OpDesc* op = program_.MutableBlock(0)->AppendOp(); op->SetType("reshape2"); op->SetInput("X", {x->Name()}); op->SetAttr("shape", shape); op->SetOutput("Out", {out->Name()}); + if (with_xshape) { + VarDesc* xshape = lod_tensor(unique_name()); + op->SetOutput("XShape", {xshape->Name()}); + } return out; } @@ -579,6 +589,17 @@ static std::string DebugString(const std::unique_ptr& graph) { return DebugString(graph.get()); } +static std::vector GetOpNodes(const std::unique_ptr& graph, + std::string op_type) { + std::vector rc; + for (auto* node : graph->Nodes()) { + if (node->IsOp() && node->Op() && node->Op()->Type() == op_type) { + rc.push_back(node); + } + } + return rc; +} + static int GetNumOpNodes(const std::unique_ptr& graph, std::string op_type) { int num_nodes = 0; diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index f04b5692fff..c07ac11e278 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -191,12 +191,13 @@ void CpuPassStrategy::EnableMKLDNN() { "conv3d_bias_mkldnn_fuse_pass", // "conv_elementwise_add_mkldnn_fuse_pass", "conv_concat_relu_mkldnn_fuse_pass", - "conv_relu_mkldnn_fuse_pass", // - "conv_leaky_relu_mkldnn_fuse_pass", // - "conv_relu6_mkldnn_fuse_pass", // - "conv_swish_mkldnn_fuse_pass", // - "scale_matmul_fuse_pass", // - "matmul_transpose_reshape_fuse_pass", // + "conv_relu_mkldnn_fuse_pass", // + "conv_leaky_relu_mkldnn_fuse_pass", // + "conv_relu6_mkldnn_fuse_pass", // + "conv_swish_mkldnn_fuse_pass", // + "scale_matmul_fuse_pass", // + "reshape_transpose_matmul_mkldnn_fuse_pass", // + "matmul_transpose_reshape_fuse_pass", // // Disabled due to topology-dependent speed-up // "fc_mkldnn_pass", "mkldnn_inplace_pass", // This pass should be activated after diff --git a/paddle/fluid/operators/matmul_op.cc b/paddle/fluid/operators/matmul_op.cc index a91a4d55a1d..98277df454e 100644 --- a/paddle/fluid/operators/matmul_op.cc +++ b/paddle/fluid/operators/matmul_op.cc @@ -318,6 +318,36 @@ class MatMulGradKernel : public framework::OpKernel { } }; +framework::DDim GetDimForInput(const framework::InferShapeContext &ctx, + std::string input_name) { + auto shape = ctx.Attrs().Get>("fused_reshape_" + input_name); + auto axis = + ctx.Attrs().Get>("fused_transpose_" + input_name); + auto dim = 
ctx.GetInputDim(input_name); + if (!shape.empty() && !axis.empty()) { + PADDLE_ENFORCE_GE( + shape.size(), 2, + platform::errors::InvalidArgument( + "shape_%s attribute of MatMulOp was implemented for 2, 3 " + "or 4 dimensions.", + input_name)); + PADDLE_ENFORCE_LE( + shape.size(), 4, + platform::errors::InvalidArgument( + "shape_%s attribute of MatMulOp was implemented for 2, 3 " + "or 4 dimensions.", + input_name)); + PADDLE_ENFORCE_EQ( + shape.size(), axis.size(), + platform::errors::InvalidArgument( + "Ranks of shape_%s and axis_%s attributes of MatMulOp " + "must be equal.", + input_name, input_name)); + dim = dim.reshape(shape).transpose(axis); + } + return dim; +} + class MatMulOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -328,9 +358,8 @@ class MatMulOp : public framework::OperatorWithKernel { OP_INOUT_CHECK(context->HasInput("Y"), "Input", "Y", "matmul"); OP_INOUT_CHECK(context->HasOutput("Out"), "Output", "Out", "matmul"); - auto dim_x = context->GetInputDim("X"); - auto dim_y = context->GetInputDim("Y"); - + auto dim_x = GetDimForInput(*context, "X"); + auto dim_y = GetDimForInput(*context, "Y"); auto mat_dim_x = math::CreateMatrixDescriptor(RowMatrixFromVector(dim_x), 0, context->Attrs().Get("transpose_X")); @@ -484,6 +513,18 @@ class MatMulOpMaker : public framework::OpProtoAndCheckerMaker { "use_mkldnn", "(bool, default false) Indicates if MKL-DNN kernel will be used") .SetDefault(false); + AddAttr>("fused_reshape_X", + R"DOC(Shape of fused reshape of `X` input.)DOC") + .SetDefault({}); + AddAttr>("fused_reshape_Y", + R"DOC(Shape of fused reshape of `Y` input.)DOC") + .SetDefault({}); + AddAttr>("fused_transpose_X", + R"DOC(Axis of fused transpose of `X` input.)DOC") + .SetDefault({}); + AddAttr>("fused_transpose_Y", + R"DOC(Axis of fused transpose of `Y` input.)DOC") + .SetDefault({}); AddAttr>( "fused_reshape_Out", R"DOC(When MKLDNN MatMul_transpose_reshape fuse activated, " diff --git a/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc index a9dc515a0f6..bc1a8522b0f 100644 --- a/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc @@ -23,12 +23,12 @@ namespace operators { using dnnl::memory; using dnnl::primitive; -using platform::to_void_cast; using framework::DataLayout; +using framework::ExecutionContext; using platform::GetMKLDNNFormat; -using platform::MKLDNNGetDataType; using platform::MKLDNNDeviceContext; -using framework::ExecutionContext; +using platform::MKLDNNGetDataType; +using platform::to_void_cast; using Tensor = framework::Tensor; template @@ -86,6 +86,74 @@ class MatMulFactory { return dnnl::memory(md, engine_, to_void_cast(data)); } + std::vector Transpose(const std::vector& x, + const std::vector& axis) { + size_t in_rank = x.size(); + size_t axis_size = axis.size(); + + auto axis_set = std::set(axis.begin(), axis.end()); + PADDLE_ENFORCE_EQ(axis_set.size(), axis_size, + platform::errors::InvalidArgument( + "In an axis array, elements must be unique.")); + + PADDLE_ENFORCE_EQ( + in_rank, axis_size, + platform::errors::InvalidArgument("The input dimension's size " + "should be equal to the axis's size. 
" + "But received dimension is %d, " + "axis's size is %d", + in_rank, axis_size)); + + PADDLE_ENFORCE_LT(*std::max_element(axis.begin(), axis.end()), axis_size, + platform::errors::InvalidArgument( + "Axis values must be ranging from 0 to (dims - 1).")); + + std::vector new_x(x.size()); + for (size_t i = 0; i < x.size(); i++) { + new_x[i] = x[axis[i]]; + } + return new_x; + } + + std::pair GetInputDimsAndStrides( + const ExecutionContext& ctx, std::string input_name) { + auto shape = ctx.Attr>("fused_reshape_" + input_name); + auto axis = ctx.Attr>("fused_transpose_" + input_name); + auto input_dims = ctx.Input(input_name)->dims(); + auto new_dims = input_dims; + if (!shape.empty() && !axis.empty()) { + new_dims = input_dims.reshape(shape).transpose(axis); + } + + auto& MatrixDimsFromVector = input_name == "X" ? RowMatrixDimsFromVector + : ColumnMatrixDimsFromVector; + math::MatDescriptor mat_dim = + math::CreateMatrixDescriptor(MatrixDimsFromVector(new_dims), 0, + ctx.Attr("transpose_" + input_name)); + + memory::dims strides; + if (!shape.empty()) { + auto shape2 = input_dims.reshape(shape); + strides.push_back(1); + for (auto i = shape2.size() - 1; i > 0; --i) { + strides.insert(strides.begin(), strides.front() * shape2[i]); + } + strides = Transpose(strides, axis); + if (shape.size() == 4) + strides.erase(strides.begin()); + else if (shape.size() == 2) + strides.insert(strides.begin(), shape[0] * shape[1]); + mat_dim.stride_ = strides[0]; + if (mat_dim.trans_) std::swap(*strides.rbegin(), *(++strides.rbegin())); + } + return std::make_pair(mat_dim, strides); + } + + bool IsInputFused(const ExecutionContext& ctx) const { + return !(ctx.Attr>("fused_reshape_X").empty() && + ctx.Attr>("fused_reshape_Y").empty()); + } + bool IsOutputFused(const ExecutionContext& ctx) const { auto& fused_reshape_Out = ctx.Attr>("fused_reshape_Out"); auto& fused_transpose_Out = @@ -100,12 +168,12 @@ class MatMulFactory { } MatMulDims GetMatmulDims(const ExecutionContext& ctx) { - auto mat_dim_x = math::CreateMatrixDescriptor( - RowMatrixDimsFromVector(ctx.Input("X")->dims()), 0, - ctx.Attr("transpose_X")); - auto mat_dim_y = math::CreateMatrixDescriptor( - ColumnMatrixDimsFromVector(ctx.Input("Y")->dims()), 0, - ctx.Attr("transpose_Y")); + math::MatDescriptor mat_dim_x; + memory::dims strides_x; + std::tie(mat_dim_x, strides_x) = GetInputDimsAndStrides(ctx, "X"); + math::MatDescriptor mat_dim_y; + memory::dims strides_y; + std::tie(mat_dim_y, strides_y) = GetInputDimsAndStrides(ctx, "Y"); const auto x_bs = mat_dim_x.batch_size_; const auto y_bs = mat_dim_y.batch_size_; @@ -122,26 +190,27 @@ class MatMulFactory { batch_size_ = 1; auto b = BS; - if (BS > 1 && IsOutputFused(ctx)) { - batch_size_ = ctx.Input("X")->dims()[0]; + if (BS > 1 && (IsOutputFused(ctx) || IsInputFused(ctx))) { + auto& x_dims = ctx.Input("X")->dims(); + auto& y_dims = ctx.Input("Y")->dims(); + batch_size_ = x_bs > y_bs ? x_dims[0] : y_dims[0]; b = BS / batch_size_; } memory::dims x_dims = {b, M, K}; memory::dims y_dims = {b, K, N}; memory::dims out_dims = {b, M, N}; - size_t x_size = b * M * K * sizeof(XT); - size_t y_size = b * K * N * sizeof(YT); - size_t out_size = b * M * N * sizeof(OT); - offsets_ = {x_size, y_size, out_size}; + x_offset_ = b * M * K * sizeof(XT); + y_offset_ = b * K * N * sizeof(YT); + out_offset_ = b * M * N * sizeof(OT); // Translate transA and transB - memory::dims strides_x = !ctx.Attr("transpose_X") - ? 
memory::dims{M * K, K, 1} - : memory::dims{M * K, 1, M}; - memory::dims strides_y = !ctx.Attr("transpose_Y") - ? memory::dims{N * K, N, 1} - : memory::dims{N * K, 1, K}; + if (strides_x.empty()) + strides_x = !ctx.Attr("transpose_X") ? memory::dims{M * K, K, 1} + : memory::dims{M * K, 1, M}; + if (strides_y.empty()) + strides_y = !ctx.Attr("transpose_Y") ? memory::dims{N * K, N, 1} + : memory::dims{N * K, 1, K}; memory::dims out_strides = memory::dims{M * N, N, 1}; CorrectStridesWhenFloatOutputFused(ctx, N, b, &out_strides); @@ -187,12 +256,10 @@ class MatMulFactory { void Execute() { dnnl::stream stream(engine_); - auto offsets = offsets_; - unsigned bs = batch_size_; void* x_ptr = x_mem_.get_data_handle(); void* y_ptr = y_mem_.get_data_handle(); void* out_ptr = out_mem_.get_data_handle(); - for (unsigned i = 0; i < bs; i++) { + for (uint16_t i = 0; i < batch_size_; i++) { x_mem_.set_data_handle(x_ptr); y_mem_.set_data_handle(y_ptr); out_mem_.set_data_handle(out_ptr); @@ -201,9 +268,9 @@ class MatMulFactory { {MKLDNN_ARG_WEIGHTS, y_mem_}, {MKLDNN_ARG_DST, out_mem_}, }); - x_ptr = static_cast(x_ptr) + offsets.x_offset; - y_ptr = static_cast(y_ptr) + offsets.y_offset; - out_ptr = static_cast(out_ptr) + offsets.out_offset; + x_ptr = static_cast(x_ptr) + x_offset_; + y_ptr = static_cast(y_ptr) + y_offset_; + out_ptr = static_cast(out_ptr) + out_offset_; } stream.wait(); } @@ -243,21 +310,21 @@ class MatMulFactory { dnnl::memory y_mem_; dnnl::memory out_mem_; dnnl::matmul matmul_prim_; - memory_offsets offsets_; - unsigned batch_size_; + uint32_t x_offset_; + uint32_t y_offset_; + uint32_t out_offset_; + uint16_t batch_size_; bool initialized_ = false; }; template static std::shared_ptr> GetPrimitiveFactory( const ExecutionContext& ctx) { - const auto x_dims = framework::vectorize(ctx.Input("X")->dims()); - const auto y_dims = framework::vectorize(ctx.Input("Y")->dims()); const auto& out_name = ctx.OutputName("Out"); const auto& dev_ctx = ctx.template device_context(); const std::string key = - platform::CreateKey(platform::ThreadIDasStr(), x_dims, y_dims, out_name); + platform::CreateKey(platform::ThreadIDasStr(), out_name); auto factory = std::static_pointer_cast>(dev_ctx.GetBlob(key)); diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 3a1405e95c4..634251a8560 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -408,7 +408,10 @@ framework::DataLayout get_cur_paddle_data_layout(void) { return cur_paddle_data_layout; } -void MKLDNNDeviceContext::ResetBlobMap() const { p_blobmap_->clear(); } +void MKLDNNDeviceContext::ResetBlobMap() const { + VLOG(3) << "Clearing DNNL cache."; + p_blobmap_->clear(); +} size_t MKLDNNDeviceContext::GetShapeBlobSize() const { std::lock_guard lock(*p_mutex_); diff --git a/python/paddle/fluid/contrib/slim/quantization/qat2_int8_mkldnn_pass.py b/python/paddle/fluid/contrib/slim/quantization/qat2_int8_mkldnn_pass.py index a0997ac2cfc..64a1c3cf4ed 100644 --- a/python/paddle/fluid/contrib/slim/quantization/qat2_int8_mkldnn_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/qat2_int8_mkldnn_pass.py @@ -500,6 +500,8 @@ class Qat2Int8MkldnnPass(object): graph.draw('.', 'qat_int8_{}'.format(ir_pass.type()), graph.all_op_nodes()) graph = self._apply_pass(graph, 'scale_matmul_fuse_pass') + graph = self._apply_pass(graph, + 'reshape_transpose_matmul_mkldnn_fuse_pass') graph = self._apply_pass( graph, 'cpu_quantize_pass', ['quant_var_scales', 'data_layout'], 
[self._var_quant_scales, self._get_data_layout()]) diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_mkldnn_op.py index b6b5f0f134b..bd8842da03e 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_mkldnn_op.py @@ -161,6 +161,180 @@ class TestDnnlMatMulOpInt8ForceFP32BasicScales(TestDnnlMatMulOp): self.attrs = {'force_fp32_output': True} +@skip_check_grad_ci(reason="DNNL's MatMul doesn't implement grad kernel.") +class TestMatMulOpReshapeTranspose(OpTest): + def init_data_type(self): + self.data_type_ = 'float32' + + def generate_data(self): + self.x = np.random.random([2, 128, 768]).astype("float32").reshape( + [2, 128, 12, 64]).transpose([0, 2, 1, 3]) + self.y = np.random.random([2, 128, 768]).astype("float32").reshape( + [2, 128, 12, 64]).transpose([0, 2, 1, 3]) + self.out = np.matmul(self.x, self.y.transpose([0, 1, 3, 2])) + self.fused_reshape_X = [] + self.fused_transpose_X = [] + self.fused_reshape_Y = [] + self.fused_transpose_Y = [] + + def setUp(self): + # Set max isa, otherwise fails on SKX and earlier + os.environ["DNNL_MAX_CPU_ISA"] = "AVX" + self.op_type = "matmul" + self._cpu_only = True + self.use_mkldnn = True + self.transpose_y = True + self.init_data_type() + self.generate_data() + + self.inputs = {'X': self.x, 'Y': self.y} + self.attrs = { + 'use_mkldnn': self.use_mkldnn, + 'transpose_Y': self.transpose_y + } + if len(self.fused_transpose_X) > 0: + self.attrs['fused_transpose_X'] = self.fused_transpose_X + if len(self.fused_transpose_Y) > 0: + self.attrs['fused_transpose_Y'] = self.fused_transpose_Y + if len(self.fused_reshape_X) > 0: + self.attrs['fused_reshape_X'] = self.fused_reshape_X + if len(self.fused_reshape_Y) > 0: + self.attrs['fused_reshape_Y'] = self.fused_reshape_Y + + self.outputs = {'Out': self.out} + + def test_check_output(self): + self.check_output() + + +class TestMatMulOpReshapeTranspose4DXFloat(TestMatMulOpReshapeTranspose): + def generate_data(self): + self.x = np.random.random([2, 128, 768]).astype("float32") + self.y = np.random.random([2, 128, 768]).astype("float32").reshape( + [2, 128, 12, 64]).transpose([0, 2, 1, 3]) + self.fused_transpose_X = [0, 2, 1, 3] + self.fused_reshape_X = [0, 0, 12, 64] + self.fused_transpose_Y = [] + self.fused_reshape_Y = [] + self.out = np.matmul( + self.x.reshape([2, 128, 12, 64]).transpose([0, 2, 1, 3]), + self.y.transpose([0, 1, 3, 2])) + + +class TestMatMulOpReshapeTranspose4DXInt8(TestMatMulOpReshapeTranspose4DXFloat): + def init_data_type(self): + self.data_type_ = 'int8' + + +class TestMatMulOpReshapeTranspose4DYFloat(TestMatMulOpReshapeTranspose): + def generate_data(self): + self.x = np.random.random([2, 128, 768]).astype("float32").reshape( + [2, 128, 12, 64]).transpose([0, 2, 1, 3]) + self.y = np.random.random([2, 128, 768]).astype("float32") + self.fused_transpose_X = [] + self.fused_reshape_X = [] + self.fused_transpose_Y = [0, 2, 1, 3] + self.fused_reshape_Y = [0, 0, 12, 64] + self.out = np.matmul( + self.x, self.y.reshape([2, 128, 12, 64]).transpose([0, 2, 3, 1])) + + +class TestMatMulOpReshapeTranspose4DYInt8(TestMatMulOpReshapeTranspose4DYFloat): + def init_data_type(self): + self.data_type_ = 'int8' + + +class TestMatMulOpReshapeTranspose4DXYFloat(TestMatMulOpReshapeTranspose): + def generate_data(self): + self.x = np.random.random([2, 128, 768]).astype("float32") + self.y = np.random.random([2, 128, 768]).astype("float32") 
+ self.fused_transpose_X = [0, 2, 1, 3] + self.fused_reshape_X = [0, 0, 12, 64] + self.fused_transpose_Y = [0, 2, 1, 3] + self.fused_reshape_Y = [0, 0, 12, 64] + self.out = np.matmul( + self.x.reshape([2, 128, 12, 64]).transpose([0, 2, 1, 3]), + self.y.reshape([2, 128, 12, 64]).transpose([0, 2, 3, 1])) + + +class TestMatMulOpReshapeTranspose4DXYInt8( + TestMatMulOpReshapeTranspose4DXYFloat): + def init_data_type(self): + self.data_type_ = 'int8' + + +class TestMatMulOpReshapeTranspose2DXFloat(TestMatMulOpReshapeTranspose): + def generate_data(self): + self.x = np.random.random([2, 5, 10]).astype("float32") + self.y = np.random.random([2, 5, 10]).astype("float32").reshape( + [10, 10]).transpose([1, 0]) + self.fused_transpose_X = [1, 0] + self.fused_reshape_X = [10, 10] + self.fused_transpose_Y = [] + self.fused_reshape_Y = [] + self.out = np.matmul( + self.x.reshape([10, 10]).transpose([1, 0]), + self.y.transpose([1, 0])) + + +class TestMatMulOpReshapeTranspose2DXInt8(TestMatMulOpReshapeTranspose2DXFloat): + def init_data_type(self): + self.data_type_ = 'int8' + + +class TestMatMulOpReshapeTranspose2DYFloat(TestMatMulOpReshapeTranspose): + def generate_data(self): + self.x = np.random.random([2, 5, 10]).astype("float32").reshape( + [10, 10]).transpose([1, 0]) + self.y = np.random.random([2, 5, 10]).astype("float32") + self.fused_transpose_X = [] + self.fused_reshape_X = [] + self.fused_transpose_Y = [1, 0] + self.fused_reshape_Y = [10, 10] + self.out = np.matmul(self.x, self.y.reshape([10, 10])) + + +class TestMatMulOpReshapeTranspose2DYInt8(TestMatMulOpReshapeTranspose2DYFloat): + def init_data_type(self): + self.data_type_ = 'int8' + + +class TestMatMulOpReshapeTranspose3DXFloat(TestMatMulOpReshapeTranspose): + def generate_data(self): + self.x = np.random.random([2, 2, 5, 5]).astype("float32") + self.y = np.random.random([2, 2, 5, 5]).astype("float32").reshape( + [2, 10, 5]).transpose([0, 2, 1]) + self.fused_transpose_X = [0, 2, 1] + self.fused_reshape_X = [2, 10, 5] + self.fused_transpose_Y = [] + self.fused_reshape_Y = [] + self.out = np.matmul( + self.x.reshape([2, 10, 5]).transpose(0, 2, 1), + self.y.transpose(0, 2, 1)) + + +class TestMatMulOpReshapeTranspose3DXInt8(TestMatMulOpReshapeTranspose3DXFloat): + def init_data_type(self): + self.data_type_ = 'int8' + + +class TestMatMulOpReshapeTranspose3DYFloat(TestMatMulOpReshapeTranspose): + def generate_data(self): + self.x = np.random.random([2, 2, 5, 5]).astype(self.data_type_).reshape( + [2, 10, 5]).transpose([0, 2, 1]) + self.y = np.random.random([2, 2, 5, 5]).astype(self.data_type_) + self.fused_transpose_X = [] + self.fused_reshape_X = [] + self.fused_transpose_Y = [0, 2, 1] + self.fused_reshape_Y = [2, 10, 5] + self.out = np.matmul(self.x, self.y.reshape([2, 10, 5])) + + +class TestMatMulOpReshapeTranspose3DYInt8(TestMatMulOpReshapeTranspose3DYFloat): + def init_data_type(self): + self.data_type_ = 'int8' + + @skip_check_grad_ci(reason="Tests inference only optimization.") class TestMatMulOpTransposeReshapeEmptyFloat(OpTest): def init_data_type(self): -- GitLab
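A note on the simplified axis validation in DDim::transpose (ddim.cc above): the per-element counting loop is replaced by three checks, and for nonnegative axis values those checks together still imply that the axis list is a permutation of 0..rank-1 (unique values, exactly rank of them, none exceeding rank - 1). A minimal Python sketch of the same rule, for illustration only (the helper name is made up):

    def check_transpose_axis(axis, rank):
        # Same guarantee as the new PADDLE_ENFORCE checks: for nonnegative values,
        # axis must be a permutation of 0..rank-1.
        assert len(set(axis)) == len(axis), "axis values must be unique"
        assert len(axis) == rank, "axis size must equal the input rank"
        assert max(axis) < rank, "axis values must range from 0 to (rank - 1)"

    check_transpose_axis([0, 2, 1, 3], 4)    # passes
    # check_transpose_axis([0, 2, 2, 3], 4)  # would fail: duplicate axis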
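The fuse handler in reshape_transpose_matmul_mkldnn_fuse_pass.cc rewires matmul's X (or Y) input directly to the reshape2 input, records the folded shape and axis as fused_reshape_X/fused_transpose_X (or _Y), and removes the reshape2/transpose2 ops together with their intermediate outputs. A toy before/after illustration of that rewrite, using plain dictionaries rather than Paddle's IR API (variable names are invented):

    # Before: reshape2 -> transpose2 -> matmul on the X side.
    before = [
        {"type": "reshape2",   "X": "a", "Out": "b", "shape": [0, 0, 12, 64]},
        {"type": "transpose2", "X": "b", "Out": "c", "axis": [0, 2, 1, 3]},
        {"type": "matmul",     "X": "c", "Y": "d",   "Out": "e"},
    ]
    # After the pass: a single matmul reading the original input, with the
    # reshape and transpose folded into attributes.
    after = [
        {"type": "matmul", "X": "a", "Y": "d", "Out": "e",
         "fused_reshape_X": [0, 0, 12, 64], "fused_transpose_X": [0, 2, 1, 3]},
    ]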
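The new fused_reshape_X/Y and fused_transpose_X/Y matmul attributes therefore mean: behave as if the corresponding input had first been reshaped to the given shape and then transposed by the given axis. A minimal NumPy sketch of that contract, modeled on TestMatMulOpReshapeTranspose4DXYFloat above (a 0 in the shape keeps the corresponding input dimension, as in reshape2; the function names are illustrative, not part of the patch):

    import numpy as np

    def resolve_shape(shape, in_shape):
        # reshape2 semantics: a 0 keeps the corresponding input dimension.
        return [in_shape[i] if s == 0 else s for i, s in enumerate(shape)]

    def fused_matmul_reference(x, y, reshape_x, transpose_x,
                               reshape_y, transpose_y, transpose_Y=False):
        # Reference for what the fused matmul is expected to compute: reshape and
        # transpose each input, then the ordinary (optionally transposed) matmul.
        x = x.reshape(resolve_shape(reshape_x, x.shape)).transpose(transpose_x)
        y = y.reshape(resolve_shape(reshape_y, y.shape)).transpose(transpose_y)
        if transpose_Y:
            y = y.swapaxes(-2, -1)
        return np.matmul(x, y)

    x = np.random.random([2, 128, 768]).astype("float32")
    y = np.random.random([2, 128, 768]).astype("float32")
    out = fused_matmul_reference(
        x, y,
        reshape_x=[0, 0, 12, 64], transpose_x=[0, 2, 1, 3],
        reshape_y=[0, 0, 12, 64], transpose_y=[0, 2, 1, 3],
        transpose_Y=True)
    print(out.shape)  # (2, 12, 128, 128)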
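On the kernel side, GetInputDimsAndStrides in matmul_mkldnn_op.cc avoids materializing the reshaped and transposed inputs: for a contiguous input, a reshape followed by a transpose only permutes the row-major strides, so the original buffer can be described to oneDNN with those permuted strides (the patch additionally drops the leading batch stride for 4-D shapes and prepends one for 2-D shapes, which is not reproduced here). A small NumPy illustration of the underlying equivalence, with strides shown in elements:

    import numpy as np

    x = np.arange(2 * 128 * 768, dtype=np.float32).reshape(2, 128, 768)

    # In NumPy, reshaping a contiguous array and then transposing it only creates
    # a view with new strides; no data is copied.
    view = x.reshape(2, 128, 12, 64).transpose(0, 2, 1, 3)

    elem_strides = tuple(s // view.itemsize for s in view.strides)
    print(view.shape)                 # (2, 12, 128, 64)
    print(elem_strides)               # (98304, 64, 768, 1): contiguous strides of the
                                      # reshaped tensor, permuted by axis (0, 2, 1, 3)
    print(np.shares_memory(view, x))  # True: the original buffer is read as-is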