diff --git a/paddle/fluid/framework/ddim.cc b/paddle/fluid/framework/ddim.cc index 1dae5e12a8c88f47ee946dc13ebbcb93a00e9741..11e11e7f822dbd5dfb7a50abaede581920c07bb3 100644 --- a/paddle/fluid/framework/ddim.cc +++ b/paddle/fluid/framework/ddim.cc @@ -131,5 +131,67 @@ DDim stride_numel(const DDim& ddim) { return strides; } +DDim DDim::reshape(const std::vector& shape) const { + const int64_t copy_dim_val = 0; + const DDim& in_dims = *this; + DDim out_dims; + out_dims.rank_ = shape.size(); + for (size_t i = 0; i < shape.size(); ++i) { + if (shape[i] == copy_dim_val) { + PADDLE_ENFORCE_LT(static_cast(i), in_dims.size(), + platform::errors::InvalidArgument( + "Index %d of shape under which the value of 0 " + "is stored, must be lower than the number of " + "old dimensions. But received shape[%d] = 0, " + "dimensions = %d, shape = [%s].", + i, in_dims.size(), in_dims)); + out_dims[i] = in_dims[i]; + } else { + out_dims[i] = shape[i]; + } + } + return out_dims; +} + +DDim DDim::transpose(const std::vector& axis) const { + const DDim& in_dims = *this; + size_t in_rank = in_dims.size(); + size_t axis_size = axis.size(); + + PADDLE_ENFORCE_EQ( + in_rank, axis_size, + platform::errors::InvalidArgument("The input dimension's size " + "should be equal to the axis's size. " + "But received dimension is %d, " + "axis's size is %d", + in_rank, axis_size)); + + std::vector count(axis_size, 0); + for (size_t i = 0; i < axis_size; i++) { + PADDLE_ENFORCE_LT(axis[i], static_cast(axis_size), + platform::errors::InvalidArgument( + "ValueError: Each element of axis must appear " + "exactly once in the range from 0 to (dims - 1), " + "where the dims is the axis's size, " + "but received axis[%d] is %d, axis_size is %d", + i, axis[i], axis_size)); + PADDLE_ENFORCE_EQ( + ++count[axis[i]], 1, + platform::errors::InvalidArgument( + "ValueError: Each element of axis should " + "be a unique value range from 0 to (dims - 1), " + "where the dims is the axis's size, " + "unique value means this axis value can appear only once. " + "But received count[axis[%d]] is %d", + i, count[axis[i]])); + } + + DDim out_dims(in_dims); + for (size_t i = 0; i < axis_size; i++) { + out_dims[i] = in_dims[axis[i]]; + } + return out_dims; +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ddim.h b/paddle/fluid/framework/ddim.h index 2f04c428e4402a8e5a4b84f424c17faac9917ac1..cbc8b0fb7cc7813a2bf1b309bc24a15d3af0f13e 100644 --- a/paddle/fluid/framework/ddim.h +++ b/paddle/fluid/framework/ddim.h @@ -126,6 +126,10 @@ class DDim { std::string to_str() const; + DDim reshape(const std::vector& shape) const; + + DDim transpose(const std::vector& axis) const; + private: template inline Dim& UnsafeCast() { diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 8dd9fd271f3da1cc200f7737b96f700909074e32..7290f07ca2e9ba6f295cab4e9e83dbe6a94183e4 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -97,6 +97,7 @@ if(WITH_MKLDNN) pass_library(cpu_quantize_placement_pass base DIR mkldnn) pass_library(cpu_quantize_pass inference DIR mkldnn) pass_library(cpu_quantize_squash_pass inference DIR mkldnn) + pass_library(matmul_transpose_reshape_fuse_pass inference DIR mkldnn) endif() cc_library(fuse_bn_act_pass SRCS fuse_bn_act_pass.cc DEPS pass graph_pattern_detector ) @@ -144,4 +145,5 @@ if (WITH_MKLDNN) cc_test(test_cpu_quantize_placement_pass SRCS mkldnn/cpu_quantize_placement_pass_tester.cc DEPS cpu_quantize_placement_pass) cc_test(test_cpu_quantize_pass SRCS mkldnn/cpu_quantize_pass_tester.cc DEPS cpu_quantize_pass naive_executor) cc_test(test_cpu_quantize_squash_pass SRCS mkldnn/cpu_quantize_squash_pass_tester.cc DEPS cpu_quantize_squash_pass naive_executor) + cc_test(test_matmul_transpose_reshape_fuse_pass SRCS mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc DEPS matmul_transpose_reshape_fuse_pass) endif () diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 4b20fd7a82f91cb5ab2f8a5e6bec1d24d0f6f5c7..dbaf631085b9af2078f46f805fe9adf58201dc37 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -2147,6 +2147,43 @@ void patterns::DeleteQuantDequantOpPattern::operator()() { any_op2->LinksFrom({quant_dequant_out}); } +PDNode *patterns::MatmulTransposeReshapePattern::operator()() { + auto reshape_op = + pattern->NewNode(reshape_op_repr())->assert_is_op("reshape2"); + auto transpose_op = + pattern->NewNode(transpose_op_repr())->assert_is_op("transpose2"); + auto matmul_op = pattern->NewNode(matmul_op_repr())->assert_is_op("matmul"); + + auto matmul_out = pattern->NewNode(matmul_out_repr()) + ->AsInput() + ->assert_is_op_output("matmul", "Out") + ->assert_is_op_input("transpose2", "X"); + + auto transpose_out = pattern->NewNode(transpose_out_repr()) + ->AsIntermediate() + ->assert_is_op_output("transpose2", "Out") + ->assert_is_op_input("reshape2", "X"); + + auto transpose_out_xshape = pattern->NewNode(transpose_out_xshape_repr()) + ->AsIntermediate() + ->assert_is_op_output("transpose2", "XShape"); + + auto reshape_out = pattern->NewNode(reshape_out_repr()) + ->AsOutput() + ->assert_is_op_output("reshape2"); + + auto reshape_out_xshape = pattern->NewNode(reshape_out_xshape_repr()) + ->AsIntermediate() + ->assert_is_op_output("reshape2", "XShape"); + + matmul_op->LinksTo({matmul_out}); + transpose_op->LinksTo({transpose_out_xshape}); + reshape_op->LinksTo({reshape_out_xshape}); + transpose_op->LinksFrom({matmul_out}).LinksTo({transpose_out}); + reshape_op->LinksFrom({transpose_out}).LinksTo({reshape_out}); + return reshape_out; +} + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 7e077a6bdcac8c49ca92e90312cedaae0c253c15..0f4ca1aa96b8264e8772de8967fe53062004fdf9 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -1210,6 +1210,24 @@ struct DeleteQuantDequantOpPattern : public PatternBase { PATTERN_DECL_NODE(any_op2); }; +// Matmul + Transpose + Reshape +struct MatmulTransposeReshapePattern : public PatternBase { + MatmulTransposeReshapePattern(PDPattern* pattern, + const std::string& name_scope) + : PatternBase(pattern, name_scope, "matmul_transpose_reshape") {} + + PDNode* operator()(); + + PATTERN_DECL_NODE(matmul_op); + PATTERN_DECL_NODE(matmul_out); + PATTERN_DECL_NODE(transpose_op); + PATTERN_DECL_NODE(transpose_out); + PATTERN_DECL_NODE(transpose_out_xshape); + PATTERN_DECL_NODE(reshape_op); + PATTERN_DECL_NODE(reshape_out); + PATTERN_DECL_NODE(reshape_out_xshape); +}; + } // namespace patterns // Link two ir::Nodes from each other. diff --git a/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..d08869685310a0b9ad718f990bad488522058692 --- /dev/null +++ b/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.cc @@ -0,0 +1,100 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.h" +#include +#include +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace framework { +namespace ir { + +void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const { + PADDLE_ENFORCE_NOT_NULL(graph, + platform::errors::InvalidArgument( + "Pointer to graph argument should not be NULL.")); + FusePassBase::Init(name_scope_, graph); + + GraphPatternDetector gpd; + patterns::MatmulTransposeReshapePattern mtrp(gpd.mutable_pattern(), + name_scope_); + + mtrp(); + + int found_matmul_transpose_reshape_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, + Graph *g) { + VLOG(4) << "handle matmul_transpose_reshape fuse"; + GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, mtrp); + GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, mtrp); + GET_IR_NODE_FROM_SUBGRAPH(transpose_op, transpose_op, mtrp); + GET_IR_NODE_FROM_SUBGRAPH(transpose_out, transpose_out, mtrp); + GET_IR_NODE_FROM_SUBGRAPH(transpose_out_xshape, transpose_out_xshape, mtrp); + GET_IR_NODE_FROM_SUBGRAPH(reshape_op, reshape_op, mtrp); + GET_IR_NODE_FROM_SUBGRAPH(reshape_out, reshape_out, mtrp); + GET_IR_NODE_FROM_SUBGRAPH(reshape_out_xshape, reshape_out_xshape, mtrp); + auto reshape_shape = + boost::get>(reshape_op->Op()->GetAttr("shape")); + auto transpose_axis = + boost::get>(transpose_op->Op()->GetAttr("axis")); + + auto reshape_out_size = reshape_shape.size(); + auto transpose_out_size = transpose_axis.size(); + const std::vector supported_axis{0, 2, 1, 3}; + const bool supported_transpose_axis = std::equal( + transpose_axis.begin(), transpose_axis.end(), supported_axis.begin()); + if (transpose_out_size != 4) { + VLOG(3) << "do not perform matmul_transpose_reshape fuse: " + << "supported rank is 4, received " << transpose_out_size; + return; + } + if (!supported_transpose_axis) { + VLOG(3) << "do not perform matmul_transpose_reshape fuse: " + << "supported transpose axis for the fuse are {0, 2, 1, 3}"; + return; + } + if (reshape_out_size != 3) { + VLOG(3) << "do not perform matmul_transpose_reshape fuse: " + << "reshape_out supported rank is 3, received " + << reshape_out_size; + return; + } + OpDesc *matmul_desc = matmul_op->Op(); + matmul_desc->SetOutput("Out", {reshape_out->Name()}); + matmul_desc->SetAttr("fused_reshape_Out", reshape_shape); + matmul_desc->SetAttr("fused_transpose_Out", transpose_axis); + + GraphSafeRemoveNodes(graph, + {matmul_out, transpose_op, transpose_out, reshape_op, + transpose_out_xshape, reshape_out_xshape}); + + IR_OP_VAR_LINK(matmul_op, reshape_out); + + found_matmul_transpose_reshape_count++; + }; + + gpd(graph, handler); + AddStatis(found_matmul_transpose_reshape_count); + std::stringstream msg_ss; + msg_ss << "--- Fused " << found_matmul_transpose_reshape_count + << " MatmulTransposeReshape patterns"; + paddle::string::PrettyLogDetail(msg_ss.str().c_str()); +} +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(matmul_transpose_reshape_fuse_pass, + paddle::framework::ir::MatmulTransposeReshapeMKLDNNPass); diff --git a/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..77e30b353467c7baca7baaac80b56e47ffef81ef --- /dev/null +++ b/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.h @@ -0,0 +1,35 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" + +namespace paddle { +namespace framework { +namespace ir { +class MatmulTransposeReshapeMKLDNNPass : public FusePassBase { + public: + virtual ~MatmulTransposeReshapeMKLDNNPass() {} + + protected: + void ApplyImpl(Graph* graph) const override; + const std::string name_scope_{"matmul_transpose_reshape_fuse"}; +}; +} +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc new file mode 100644 index 0000000000000000000000000000000000000000..122a7f802a52972612e2879eaea29d14e5d7c561 --- /dev/null +++ b/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc @@ -0,0 +1,93 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.h" +#include + +namespace paddle { +namespace framework { +namespace ir { + +void SetOp(ProgramDesc *prog, const std::string &type, + const std::vector &inputs, + const std::vector &outputs) { + auto *op = prog->MutableBlock(0)->AppendOp(); + op->SetType(type); + op->SetInput("X", {inputs[0]}); + op->SetOutput("Out", {outputs[0]}); + if (type == "transpose2") { + op->SetAttr("axis", std::vector({0, 2, 1, 3})); + op->SetOutput("XShape", {outputs[1]}); + } + if (type == "reshape2") { + op->SetAttr("shape", std::vector({4, 5, 6})); + op->SetOutput("XShape", {outputs[1]}); + } + + if (type == "matmul") { + op->SetInput("Y", {inputs[1]}); + op->SetAttr("use_mkldnn", true); + } +} + +ProgramDesc BuildProgramDesc() { + ProgramDesc prog; + for (auto &v : std::initializer_list( + {"a1", "a2", "b", "c", "cx", "d", "dx", "e"})) { + auto *var = prog.MutableBlock(0)->Var(v); + var->SetType(proto::VarType::SELECTED_ROWS); + } + + SetOp(&prog, "matmul", {"a1", "a2"}, {"b"}); + SetOp(&prog, "transpose2", {"b"}, {"c", "cx"}); + SetOp(&prog, "reshape2", {"c"}, {"d", "dx"}); + SetOp(&prog, "fc", {"d"}, {"e"}); + + return prog; +} + +void MainTest(const ProgramDesc &prog) { + std::unique_ptr graph(new ir::Graph(prog)); + + int original_nodes_num = graph->Nodes().size(); + + auto pass = + PassRegistry::Instance().Get("matmul_transpose_reshape_fuse_pass"); + graph.reset(pass->Apply(graph.release())); + + int current_nodes_num = graph->Nodes().size(); + EXPECT_EQ(original_nodes_num - 6, current_nodes_num); + + for (auto *node : graph->Nodes()) { + if (node->IsOp()) { + auto *op = node->Op(); + if (op->Type() == "matmul") { + EXPECT_EQ(op->GetAttrIfExists>("fused_reshape_Out"), + std::vector({4, 5, 6})); + EXPECT_EQ(op->GetAttrIfExists>("fused_transpose_Out"), + std::vector({0, 2, 1, 3})); + } + } + } +} + +TEST(MatmulTransposeReshapeFusePass, matmul_inputs) { + auto prog = BuildProgramDesc(); + MainTest(prog); +} +} // namespace ir +} // namespace framework +} // namespace paddle + +USE_PASS(matmul_transpose_reshape_fuse_pass); diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 411de95d3c18181cecf8c4c44be8f5ca183e69f3..f04b5692fffdd57438fd47e6537422fd6ea65369 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -191,11 +191,12 @@ void CpuPassStrategy::EnableMKLDNN() { "conv3d_bias_mkldnn_fuse_pass", // "conv_elementwise_add_mkldnn_fuse_pass", "conv_concat_relu_mkldnn_fuse_pass", - "conv_relu_mkldnn_fuse_pass", // - "conv_leaky_relu_mkldnn_fuse_pass", // - "conv_relu6_mkldnn_fuse_pass", // - "conv_swish_mkldnn_fuse_pass", // - "scale_matmul_fuse_pass", // + "conv_relu_mkldnn_fuse_pass", // + "conv_leaky_relu_mkldnn_fuse_pass", // + "conv_relu6_mkldnn_fuse_pass", // + "conv_swish_mkldnn_fuse_pass", // + "scale_matmul_fuse_pass", // + "matmul_transpose_reshape_fuse_pass", // // Disabled due to topology-dependent speed-up // "fc_mkldnn_pass", "mkldnn_inplace_pass", // This pass should be activated after diff --git a/paddle/fluid/operators/matmul_op.cc b/paddle/fluid/operators/matmul_op.cc index 14f7f00a5fd026ccb5d7110eefb6b78937fcf863..a91a4d55a1d384213d8a66c7844f64650c658897 100644 --- a/paddle/fluid/operators/matmul_op.cc +++ b/paddle/fluid/operators/matmul_op.cc @@ -407,7 +407,45 @@ class MatMulOp : public framework::OperatorWithKernel { if (dim_out.empty()) { dim_out = {1}; } - context->SetOutputDim("Out", framework::make_ddim(dim_out)); + + framework::DDim ddim_out = framework::make_ddim(dim_out); + +#ifdef PADDLE_WITH_MKLDNN + // if mkldnn matmul+transpose+reshape fuse activated + auto reshape_out = + context->Attrs().Get>("fused_reshape_Out"); + auto transpose_out = + context->Attrs().Get>("fused_transpose_Out"); + + if (!reshape_out.empty() && !transpose_out.empty()) { + auto reshape_out_size = reshape_out.size(); + auto transpose_out_size = transpose_out.size(); + PADDLE_ENFORCE_EQ(transpose_out_size, 4, + platform::errors::InvalidArgument( + "transpose_out supported rank is 4, " + "received %d", + transpose_out_size)); + const std::vector supported_axis{0, 2, 1, 3}; + const bool supported_transpose_axis = std::equal( + transpose_out.begin(), transpose_out.end(), supported_axis.begin()); + PADDLE_ENFORCE_EQ( + supported_transpose_axis, true, + platform::errors::InvalidArgument( + "supported transpose axis for the fuse are {0, 2, 1, 3}")); + PADDLE_ENFORCE_EQ( + reshape_out_size, 3, + platform::errors::InvalidArgument("reshape_out supported rank is 3, " + "received %d", + reshape_out_size)); + framework::DDim shape_out = + ddim_out.transpose(transpose_out).reshape(reshape_out); + context->SetOutputDim("Out", shape_out); + } else { + context->SetOutputDim("Out", ddim_out); + } +#else + context->SetOutputDim("Out", ddim_out); +#endif context->ShareLoD("X", /*->*/ "Out"); } @@ -446,6 +484,16 @@ class MatMulOpMaker : public framework::OpProtoAndCheckerMaker { "use_mkldnn", "(bool, default false) Indicates if MKL-DNN kernel will be used") .SetDefault(false); + AddAttr>( + "fused_reshape_Out", + R"DOC(When MKLDNN MatMul_transpose_reshape fuse activated, " + "it's a shape atribute of fused reshape for `Out` output.)DOC") + .SetDefault({}); + AddAttr>( + "fused_transpose_Out", + R"DOC(When MKLDNN MatMul_transpose_reshape fuse activated, " + "it's a axis atribute of fused transpose for `Out` output.)DOC") + .SetDefault({}); /* int8 parameters */ AddAttr("use_quantizer", "(bool, default false) " @@ -466,6 +514,7 @@ class MatMulOpMaker : public framework::OpProtoAndCheckerMaker { "(bool, default false) Force INT8 kernel output FP32, only " "used in MKL-DNN INT8") .SetDefault(false); + #if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA) AddAttr("head_number", "The number of heads of the matrix") .SetDefault(1); diff --git a/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc index 338a5206356e336a7fda002d33077400e8e306de..a9dc515a0f6dd853bf042bce687d66681f57e7e6 100644 --- a/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc @@ -31,6 +31,11 @@ using platform::MKLDNNDeviceContext; using framework::ExecutionContext; using Tensor = framework::Tensor; +template +constexpr bool IsInt8() { + return std::is_same::value || std::is_same::value; +} + // Get row matrix shape from a vector shape. If the rank of x_dim > 1, the // original x_dim is returned. static framework::DDim RowMatrixDimsFromVector(const framework::DDim& x_dim) { @@ -64,7 +69,8 @@ class MatMulFactory { private: struct MatMulDims { - const memory::dim BS, M, N, K; + const memory::dims x_dims, y_dims, out_dims, x_strides, y_strides, + out_strides; }; void SetDNNLEngine(const ExecutionContext& ctx) { @@ -80,6 +86,19 @@ class MatMulFactory { return dnnl::memory(md, engine_, to_void_cast(data)); } + bool IsOutputFused(const ExecutionContext& ctx) const { + auto& fused_reshape_Out = ctx.Attr>("fused_reshape_Out"); + auto& fused_transpose_Out = + ctx.Attr>("fused_transpose_Out"); + return !fused_reshape_Out.empty() && !fused_transpose_Out.empty(); + } + + void CorrectStridesWhenFloatOutputFused(const ExecutionContext& ctx, + const memory::dim N, memory::dim b, + memory::dims* out_strides) const { + if (!IsInt8() && IsOutputFused(ctx)) *out_strides = {N, b * N, 1}; + } + MatMulDims GetMatmulDims(const ExecutionContext& ctx) { auto mat_dim_x = math::CreateMatrixDescriptor( RowMatrixDimsFromVector(ctx.Input("X")->dims()), 0, @@ -100,34 +119,45 @@ class MatMulFactory { const memory::dim M = mat_dim_x.height_; const memory::dim N = mat_dim_y.width_; const memory::dim K = mat_dim_x.width_; - return {BS, M, N, K}; + + batch_size_ = 1; + auto b = BS; + if (BS > 1 && IsOutputFused(ctx)) { + batch_size_ = ctx.Input("X")->dims()[0]; + b = BS / batch_size_; + } + memory::dims x_dims = {b, M, K}; + memory::dims y_dims = {b, K, N}; + memory::dims out_dims = {b, M, N}; + + size_t x_size = b * M * K * sizeof(XT); + size_t y_size = b * K * N * sizeof(YT); + size_t out_size = b * M * N * sizeof(OT); + offsets_ = {x_size, y_size, out_size}; + + // Translate transA and transB + memory::dims strides_x = !ctx.Attr("transpose_X") + ? memory::dims{M * K, K, 1} + : memory::dims{M * K, 1, M}; + memory::dims strides_y = !ctx.Attr("transpose_Y") + ? memory::dims{N * K, N, 1} + : memory::dims{N * K, 1, K}; + memory::dims out_strides = memory::dims{M * N, N, 1}; + + CorrectStridesWhenFloatOutputFused(ctx, N, b, &out_strides); + + return {x_dims, y_dims, out_dims, strides_x, strides_y, out_strides}; } void CreateMemories(const ExecutionContext& ctx) { auto matmul_dims = GetMatmulDims(ctx); - auto BS = matmul_dims.BS; - auto M = matmul_dims.M; - auto N = matmul_dims.N; - auto K = matmul_dims.K; - bool x_trans = ctx.Attr("transpose_X"); - bool y_trans = ctx.Attr("transpose_Y"); - - typedef memory::dims dims; - dims x_dims = {BS, M, K}; - dims y_dims = {BS, K, N}; - dims out_dims = {BS, M, N}; - // Translate transA and transB - dims x_strides = !x_trans ? dims{M * K, K, 1} : dims{M * K, 1, M}; - dims y_strides = !y_trans ? dims{N * K, N, 1} : dims{N * K, 1, K}; - dims out_strides = {M * N, N, 1}; - - x_mem_ = - CreateMemory(x_dims, x_strides, ctx.Input("X")->data()); - y_mem_ = - CreateMemory(y_dims, y_strides, ctx.Input("Y")->data()); + x_mem_ = CreateMemory(matmul_dims.x_dims, matmul_dims.x_strides, + ctx.Input("X")->data()); + y_mem_ = CreateMemory(matmul_dims.y_dims, matmul_dims.y_strides, + ctx.Input("Y")->data()); out_mem_ = CreateMemory( - out_dims, out_strides, + matmul_dims.out_dims, matmul_dims.out_strides, ctx.Output("Out")->mutable_data(ctx.GetPlace())); } @@ -156,11 +186,25 @@ class MatMulFactory { void Execute() { dnnl::stream stream(engine_); - matmul_prim_.execute(stream, { - {MKLDNN_ARG_SRC, x_mem_}, - {MKLDNN_ARG_WEIGHTS, y_mem_}, - {MKLDNN_ARG_DST, out_mem_}, - }); + + auto offsets = offsets_; + unsigned bs = batch_size_; + void* x_ptr = x_mem_.get_data_handle(); + void* y_ptr = y_mem_.get_data_handle(); + void* out_ptr = out_mem_.get_data_handle(); + for (unsigned i = 0; i < bs; i++) { + x_mem_.set_data_handle(x_ptr); + y_mem_.set_data_handle(y_ptr); + out_mem_.set_data_handle(out_ptr); + matmul_prim_.execute(stream, { + {MKLDNN_ARG_SRC, x_mem_}, + {MKLDNN_ARG_WEIGHTS, y_mem_}, + {MKLDNN_ARG_DST, out_mem_}, + }); + x_ptr = static_cast(x_ptr) + offsets.x_offset; + y_ptr = static_cast(y_ptr) + offsets.y_offset; + out_ptr = static_cast(out_ptr) + offsets.out_offset; + } stream.wait(); } @@ -188,11 +232,19 @@ class MatMulFactory { void SetInitialized() { initialized_ = true; } private: + struct memory_offsets { + size_t x_offset; + size_t y_offset; + size_t out_offset; + }; + dnnl::engine engine_; dnnl::memory x_mem_; dnnl::memory y_mem_; dnnl::memory out_mem_; dnnl::matmul matmul_prim_; + memory_offsets offsets_; + unsigned batch_size_; bool initialized_ = false; }; @@ -217,10 +269,6 @@ static std::shared_ptr> GetPrimitiveFactory( return factory; } -template -constexpr bool IsInt8() { - return std::is_same::value || std::is_same::value; -} // Choose appropriate primitive factory implementation based on inferred // output type (uint8, int8 or float). template diff --git a/python/paddle/fluid/contrib/slim/quantization/qat2_int8_mkldnn_pass.py b/python/paddle/fluid/contrib/slim/quantization/qat2_int8_mkldnn_pass.py index 8f2170838654f37732efe894d91cf1400c298a51..a0997ac2cfc516e06903e32dfce011ee41037077 100644 --- a/python/paddle/fluid/contrib/slim/quantization/qat2_int8_mkldnn_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/qat2_int8_mkldnn_pass.py @@ -371,6 +371,7 @@ class Qat2Int8MkldnnPass(object): ['use_gpu', 'use_fc_padding'], [False, False]) graph = self._apply_pass(graph, 'fc_mkldnn_pass') + graph = self._apply_pass(graph, 'matmul_transpose_reshape_fuse_pass') return graph def _apply_pass(self, graph, pass_name, attrs=None, attr_values=None): diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_op_output_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_op_output_fuse_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..b1ad5804ebc2c6b696dc455d7ca64a2de9e3cd9b --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_op_output_fuse_pass.py @@ -0,0 +1,110 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np + +import paddle.fluid as fluid +from inference_pass_test import InferencePassTest + + +class TestMKLDNNMatmulFuseOp(InferencePassTest): + def init_data(self): + self.bs = 8 + self.d_type = np.float32 + self.shape_x = [12, 128, 128] + self.shape_y = [12, 128, 64] + self.enable_mkldnn = True + + def make_network(self): + with fluid.program_guard(self.main_program, self.startup_program): + x = fluid.data( + name='x', shape=[-1] + self.shape_x, dtype=self.d_type) + y = fluid.data( + name='y', shape=[-1] + self.shape_y, dtype=self.d_type) + out = fluid.layers.matmul(x, y) + out = fluid.layers.transpose(out, perm=[0, 2, 1, 3]) + out = fluid.layers.reshape( + out, [0, 0, self.shape_y[0] * self.shape_y[2]]) + out = fluid.layers.fc(out, size=1) + return out + + def setUp(self): + self.init_data() + out = self.make_network() + self.set_feeds(out) + + def set_feeds(self, out): + self.feeds = { + "x": np.random.random([self.bs] + self.shape_x).astype(self.d_type), + "y": np.random.random([self.bs] + self.shape_y).astype(self.d_type) + } + self.fetch_list = [out] + + def test_check_output(self): + use_gpu = False + self.check_output_with_option(use_gpu) + + +class TestMKLDNNMatmulOtherDimsFuseOp(TestMKLDNNMatmulFuseOp): + def init_data(self): + self.bs = 8 + self.d_type = np.float32 + self.shape_x = [12, 1, 1] + self.shape_y = [12, 1, 64] + self.enable_mkldnn = True + + +class TestMKLDNNMatmulOpNotFusedWrongTransposeAxis(TestMKLDNNMatmulFuseOp): + def make_network(self): + with fluid.program_guard(self.main_program, self.startup_program): + x = fluid.data( + name='x', shape=[-1] + self.shape_x, dtype=self.d_type) + y = fluid.data( + name='y', shape=[-1] + self.shape_y, dtype=self.d_type) + out = fluid.layers.matmul(x, y) + out = fluid.layers.transpose(out, perm=[0, 1, 2, 3]) + out = fluid.layers.reshape(out, [0, 0, 0, 0]) + out = fluid.layers.fc(out, size=1) + return out + + +class TestMKLDNNMatmulOpNotFusedBreakPattern(TestMKLDNNMatmulFuseOp): + def init_data(self): + self.bs = 7 + self.d_type = np.float32 + self.shape_x = [12, 128, 128] + self.shape_y = [12, 128, 64] + self.enable_mkldnn = True + + def make_network(self): + with fluid.program_guard(self.main_program, self.startup_program): + x = fluid.data( + name='x', shape=[-1] + self.shape_x, dtype=self.d_type) + y = fluid.data( + name='y', shape=[-1] + self.shape_y, dtype=self.d_type) + out = fluid.layers.matmul(x, y) + out = fluid.layers.transpose(out, perm=[0, 2, 1, 3]) + out = fluid.layers.transpose( + out, perm=[0, 1, 2, 3]) # breaks pattern + out = fluid.layers.reshape( + out, [0, 0, self.shape_y[0] * self.shape_y[2]]) + out = fluid.layers.fc(out, size=1) + return out + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_mkldnn_op.py index de547a6a19052ff9c0d50e7115acebc6833c2ca5..b6b5f0f134b9c9295c60e692c9b701adbd343470 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_mkldnn_op.py @@ -161,5 +161,134 @@ class TestDnnlMatMulOpInt8ForceFP32BasicScales(TestDnnlMatMulOp): self.attrs = {'force_fp32_output': True} +@skip_check_grad_ci(reason="Tests inference only optimization.") +class TestMatMulOpTransposeReshapeEmptyFloat(OpTest): + def init_data_type(self): + self.data_type_ = np.float32 + + def generate_data(self): + self.bs = 1 + self.x = np.random.random([self.bs, 128, 128]).astype(self.data_type_) + self.y = np.random.random([self.bs, 128, 64]).astype(self.data_type_) + + def init_params_and_out(self): + self.transpose_out = [] + self.reshape_out = [] + self.out = np.matmul(self.x, self.y) + + def setUp(self): + os.environ["DNNL_MAX_CPU_ISA"] = "AVX" + self.op_type = "matmul" + self._cpu_only = True + self.use_mkldnn = True + self.init_data_type() + self.generate_data() + self.init_params_and_out() + + self.inputs = {'X': self.x, 'Y': self.y} + self.attrs = {'use_mkldnn': self.use_mkldnn} + + if len(self.reshape_out) > 0: + self.attrs['fused_reshape_Out'] = self.reshape_out + if len(self.transpose_out) > 0: + self.attrs['fused_transpose_Out'] = self.transpose_out + + self.inputs = {'X': self.x, 'Y': self.y} + self.outputs = {'Out': self.out} + + def test_check_output(self): + self.check_output() + + def check_raise_error(self, msg): + try: + self.check_output() + except Exception as e: + if msg in str(e): + raise AttributeError + else: + print(e) + + +class TestMatMulOpTransposeReshapeIntEmptyInt( + TestMatMulOpTransposeReshapeEmptyFloat): + def init_data_type(self): + self.data_type_ = np.int8 + + +class TestMatMulOpTransposeReshapeBasicFloat( + TestMatMulOpTransposeReshapeEmptyFloat): + def generate_data(self): + self.bs = 8 + self.x = np.random.random( + [self.bs, 12, 128, 128]).astype(self.data_type_) + self.y = np.random.random( + [self.bs, 12, 128, 64]).astype(self.data_type_) + + def init_params_and_out(self): + self.transpose_out = [0, 2, 1, 3] + self.reshape_out = [0, 0, self.x.shape[1] * self.y.shape[-1]] + self.out = np.matmul(self.x, self.y).transpose([0, 2, 1, 3]).reshape( + [self.bs, -1, self.x.shape[1] * self.y.shape[-1]]) + + +class TestMatMulOpTransposeReshapeBasicInt( + TestMatMulOpTransposeReshapeBasicFloat): + def init_data_type(self): + self.data_type_ = np.int8 + + +class TestMatMulOpTransposeReshapeOtherDimFloat( + TestMatMulOpTransposeReshapeBasicFloat): + def generate_data(self): + self.bs = 11 + self.x = np.random.random([self.bs, 12, 14, 18]).astype(self.data_type_) + self.y = np.random.random([self.bs, 12, 18, 13]).astype(self.data_type_) + + +class TestMatMulOpTransposeReshapeOtherDimInt( + TestMatMulOpTransposeReshapeOtherDimFloat): + def init_data_type(self): + self.data_type_ = np.int8 + + +class TestMatMulOpTransposeReshapeTransposeAxisNotSupportedException( + TestMatMulOpTransposeReshapeBasicFloat): + def init_params_and_out(self): + self.transpose_out = [0, 1, 2, 3] + self.reshape_out = [0, 0, self.x.shape[1] * self.y.shape[-1]] + self.out = np.matmul(self.x, self.y) + + def test_check_output(self): + self.assertRaises(AttributeError, self.check_raise_error, + 'InvalidArgumentError: supported transpose axis ' + 'for the fuse are {0, 2, 1, 3}') + + +class TestMatMulOpTransposeReshapeTransposeRankNotSupportedException( + TestMatMulOpTransposeReshapeBasicFloat): + def init_params_and_out(self): + self.transpose_out = [0, 2, 1] + self.reshape_out = [0, 0, self.x.shape[1] * self.y.shape[-1]] + self.out = np.matmul(self.x, self.y) + + def test_check_output(self): + self.assertRaises( + AttributeError, self.check_raise_error, + 'InvalidArgumentError: transpose_out supported rank is 4') + + +class TestMatMulOpTransposeReshapeRankOfReshapeNotSupportedException( + TestMatMulOpTransposeReshapeBasicFloat): + def init_params_and_out(self): + self.transpose_out = [0, 2, 1, 3] + self.reshape_out = [0, 0] + self.out = np.matmul(self.x, self.y) + + def test_check_output(self): + self.assertRaises( + AttributeError, self.check_raise_error, + 'InvalidArgumentError: reshape_out supported rank is 3') + + if __name__ == "__main__": unittest.main()