diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index e4954aadf3ec179b828d4da9e6c4ed27068d5b7d..8f78248f5981eb2f6adb5dd6194f87d3e958d743 100755
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -177,6 +177,7 @@ if(WITH_MKLDNN)
   pass_library(matmul_elementwise_add_mkldnn_fuse_pass inference DIR mkldnn)
   pass_library(matmul_activation_mkldnn_fuse_pass inference DIR mkldnn)
   pass_library(operator_scale_onednn_fuse_pass inference DIR mkldnn)
+  pass_library(squeeze2_transpose2_onednn_fuse_pass inference DIR mkldnn)
   pass_library(operator_unsqueeze2_onednn_fuse_pass inference DIR mkldnn)
   pass_library(operator_reshape2_onednn_fuse_pass inference DIR mkldnn)
   pass_library(cpu_quantize_placement_pass base DIR mkldnn)
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc
index 352469b0afa4db83de206c1f8670a99dbc561f03..d03c5647e6a220ea49032c6806c5375db37a9f17 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.cc
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -958,6 +958,25 @@ PDNode *patterns::OperatorActivation::operator()(
   return activation_out;
 }
 
+PDNode *patterns::Squeeze2Transpose2::operator()() {
+  auto *squeeze2_op_in = pattern->NewNode(squeeze2_op_in_repr())
+                             ->AsInput()
+                             ->assert_is_op_input("squeeze2", "X");
+  auto *squeeze2_op = pattern->NewNode(squeeze2_op_repr())
+                          ->assert_is_op("squeeze2")
+                          ->assert_has_n_outputs(2);
+  auto *squeeze2_op_out = pattern->NewNode(squeeze2_op_out_repr())
+                              ->AsIntermediate()
+                              ->assert_is_op_output("squeeze2", "Out")
+                              ->assert_is_op_input("transpose2", "X");
+  auto *transpose2_op =
+      pattern->NewNode(transpose2_op_repr())->assert_is_op("transpose2");
+
+  squeeze2_op->LinksFrom({squeeze2_op_in}).LinksTo({squeeze2_op_out});
+  transpose2_op->LinksFrom({squeeze2_op_out});
+  return transpose2_op;
+}
+
 PDNode *patterns::OperatorUnsqueeze2::operator()(
     const std::string &operator_type, const int num_of_operator_outs) {
   auto *preceding_op = pattern->NewNode(preceding_op_repr())
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h
index 50372886b72591f15193e873ac0ef324353c4ca6..a4ee0a09831891ea54be0aaec8c55099a0fb51b8 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.h
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -539,6 +539,18 @@ struct OperatorActivation : public PatternBase {
   PATTERN_DECL_NODE(activation_out);
 };
 
+struct Squeeze2Transpose2 : public PatternBase {
+  Squeeze2Transpose2(PDPattern* pattern, const std::string& name_scope)
+      : PatternBase(pattern, name_scope, "squeeze2_transpose2") {}
+
+  PDNode* operator()();
+
+  PATTERN_DECL_NODE(squeeze2_op_in);
+  PATTERN_DECL_NODE(squeeze2_op);
+  PATTERN_DECL_NODE(squeeze2_op_out);
+  PATTERN_DECL_NODE(transpose2_op);
+};
+
 struct OperatorUnsqueeze2 : public PatternBase {
   OperatorUnsqueeze2(PDPattern* pattern, const std::string& name_scope)
       : PatternBase(pattern, name_scope, "operator_unsqueeze2") {}
@@ -2030,6 +2042,12 @@ struct AddSupportInt8 : public PatternBase {
   out_var->inputs.clear();       \
   out_var->inputs.push_back(op);
 
+// Set the in_var as the input of the op
+#define IR_VAR_OP_LINK(in_var, op) \
+  in_var->outputs.clear();         \
+  in_var->outputs.push_back(op);   \
+  op->inputs.push_back(in_var);
+
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
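The `IR_VAR_OP_LINK` macro added above is what lets the fuse pass below reconnect the squeeze2 input variable directly to transpose2. A minimal standalone sketch of its relinking semantics, using a hypothetical `Node` struct as a stand-in for `ir::Node` (not Paddle code):

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for paddle::framework::ir::Node.
struct Node {
  std::vector<Node*> inputs;
  std::vector<Node*> outputs;
};

// Mirrors the macro added in graph_pattern_detector.h.
#define IR_VAR_OP_LINK(in_var, op) \
  in_var->outputs.clear();         \
  in_var->outputs.push_back(op);   \
  op->inputs.push_back(in_var);

int main() {
  Node squeeze2_in, squeeze2_op, transpose2_op;
  // Before fusion: squeeze2_in feeds squeeze2_op.
  squeeze2_in.outputs.push_back(&squeeze2_op);
  squeeze2_op.inputs.push_back(&squeeze2_in);

  // The pass routes squeeze2_in directly into transpose2_op.
  Node* in_var = &squeeze2_in;
  Node* op = &transpose2_op;
  IR_VAR_OP_LINK(in_var, op);

  assert(squeeze2_in.outputs.size() == 1 &&
         squeeze2_in.outputs[0] == &transpose2_op);
  assert(transpose2_op.inputs[0] == &squeeze2_in);
  return 0;
}
```

The `outputs.clear()` detaches the variable from the soon-to-be-removed squeeze2 op before attaching it to transpose2.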
diff --git a/paddle/fluid/framework/ir/mkldnn/squeeze2_transpose2_onednn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/squeeze2_transpose2_onednn_fuse_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..9a3b420073ea825c9def560d090609cb54c457aa
--- /dev/null
+++ b/paddle/fluid/framework/ir/mkldnn/squeeze2_transpose2_onednn_fuse_pass.cc
@@ -0,0 +1,87 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/ir/mkldnn/squeeze2_transpose2_onednn_fuse_pass.h"
+#include "paddle/fluid/framework/op_version_registry.h"
+#include "paddle/fluid/platform/mkldnn_reuse.h"
+#include "paddle/fluid/string/pretty_log.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+using string::PrettyLogDetail;
+
+void FuseSqueeze2Transpose2OneDNNPass::ApplyImpl(Graph *graph) const {
+  PADDLE_ENFORCE_NOT_NULL(
+      graph,
+      platform::errors::InvalidArgument(
+          "Input graph pointer argument should not be nullptr."));
+
+  FusePassBase::Init("squeeze2_transpose2_onednn_fuse_pass", graph);
+
+  GraphPatternDetector gpd;
+  patterns::Squeeze2Transpose2 squeeze2_transpose2_pattern(
+      gpd.mutable_pattern(), "squeeze2_transpose2_onednn_fuse_pass");
+  squeeze2_transpose2_pattern();
+
+  int found_count = 0;
+
+  auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
+                     Graph *g) {
+    GET_IR_NODE_FROM_SUBGRAPH(
+        squeeze2_op_in, squeeze2_op_in, squeeze2_transpose2_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        squeeze2_op, squeeze2_op, squeeze2_transpose2_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        squeeze2_op_out, squeeze2_op_out, squeeze2_transpose2_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        transpose2_op, transpose2_op, squeeze2_transpose2_pattern);
+
+    if (!transpose2_op->Op()->HasAttr("use_mkldnn") ||
+        (transpose2_op->Op()->HasAttr("use_mkldnn") &&
+         !(PADDLE_GET_CONST(bool,
+                            transpose2_op->Op()->GetAttr("use_mkldnn"))))) {
+      VLOG(4) << "Only the oneDNN version of transpose2 can be fused with "
+                 "squeeze2.";
+      return;
+    }
+
+    std::vector<int> squeeze2_axes = PADDLE_GET_CONST(
+        std::vector<int>, squeeze2_op->Op()->GetAttr("axes"));
+    transpose2_op->Op()->SetAttr("fused_squeeze2_axes", squeeze2_axes);
+    transpose2_op->Op()->SetInput("X", {squeeze2_op_in->Name()});
+
+    IR_VAR_OP_LINK(squeeze2_op_in, transpose2_op);
+    GraphSafeRemoveNodes(g, {squeeze2_op, squeeze2_op_out});
+    found_count++;
+  };
+
+  gpd(graph, handler);
+  AddStatis(found_count);
+  if ((!Has("disable_logs") || !Get<bool>("disable_logs"))) {
+    PrettyLogDetail("---    fused %d squeeze2 with transpose2", found_count);
+  }
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(squeeze2_transpose2_onednn_fuse_pass,
+              paddle::framework::ir::FuseSqueeze2Transpose2OneDNNPass);
+REGISTER_PASS_CAPABILITY(squeeze2_transpose2_onednn_fuse_pass)
+    .AddCombination(
+        paddle::framework::compatible::OpVersionComparatorCombination()
+            .GE("squeeze2", 0)
+            .GE("transpose2", 0));
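A hedged sketch of how this pass might be exercised, modeled on the existing `*_fuse_pass_tester.cc` files; the helper name and the elided `ProgramDesc` setup are assumptions, not part of this diff. It assumes a program containing a squeeze2 -> transpose2 chain with `use_mkldnn=true` on transpose2:

```cpp
#include <memory>

#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/program_desc.h"

// Links in the pass registration from the translation unit above.
USE_PASS(squeeze2_transpose2_onednn_fuse_pass);

void RunSqueeze2Transpose2Fuse(paddle::framework::ProgramDesc* prog) {
  namespace fw = paddle::framework;
  auto graph = std::make_unique<fw::ir::Graph>(*prog);
  auto pass = fw::ir::PassRegistry::Instance().Get(
      "squeeze2_transpose2_onednn_fuse_pass");
  graph.reset(pass->Apply(graph.release()));
  // After Apply(): matched squeeze2 ops and their Out vars are removed, and
  // each downstream transpose2 carries the fused_squeeze2_axes attribute.
}
```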
diff --git a/paddle/fluid/framework/ir/mkldnn/squeeze2_transpose2_onednn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/squeeze2_transpose2_onednn_fuse_pass.h
new file mode 100644
index 0000000000000000000000000000000000000000..be3871bdfe2fbccc4eaa83fed37512d4515c02ec
--- /dev/null
+++ b/paddle/fluid/framework/ir/mkldnn/squeeze2_transpose2_onednn_fuse_pass.h
@@ -0,0 +1,35 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/ir/graph.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+class FuseSqueeze2Transpose2OneDNNPass : public FusePassBase {
+ public:
+  virtual ~FuseSqueeze2Transpose2OneDNNPass() {}
+
+ protected:
+  void ApplyImpl(Graph *graph) const override;
+};
+
+}  // namespace ir
+}  // namespace framework
+
+}  // namespace paddle
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index a602e3edc282e5cdc502cdd512573051e97938a8..ae7b81e9c305b54cbe6f12643509b994e6b12477 100755
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -308,6 +308,7 @@ void CpuPassStrategy::EnableMKLDNN() {
     passes_.insert(passes_.begin(), "mkldnn_placement_pass");
 
     for (auto &pass : std::vector<std::string>({
+             "squeeze2_transpose2_onednn_fuse_pass",
              "depthwise_conv_mkldnn_pass",    //
              "conv_bn_fuse_pass",             // Execute BN passes again to
              "conv_eltwiseadd_bn_fuse_pass",  // preserve correct pass order
@@ -387,6 +388,7 @@ void CpuPassStrategy::EnableMkldnnInt8() {
     passes_.push_back("mkldnn_placement_pass");
     passes_.push_back("simplify_with_basic_ops_pass");
     passes_.push_back("constant_folding_pass");
+    passes_.push_back("squeeze2_transpose2_onednn_fuse_pass");
     passes_.push_back("layer_norm_fuse_pass");
     passes_.push_back("attention_lstm_fuse_pass");
     passes_.push_back("seqconv_eltadd_relu_fuse_pass");
diff --git a/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc
index 5f02d087011f0bcd1b8455294873fcf9797f2993..d84cfe6de41d35d6b0fc71b9bea25b8b8f74d125 100644
--- a/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc
@@ -37,11 +37,14 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     const auto& dnnl_engine = dev_ctx.GetEngine();
     std::vector<int> transpose_axis = ctx.Attr<std::vector<int>>("axis");
     int ndims = transpose_axis.size();
-    auto* x = ctx.Input<phi::DenseTensor>("X");
+    const phi::DenseTensor* x = ctx.Input<phi::DenseTensor>("X");
     auto* out = ctx.Output<phi::DenseTensor>("Out");
 
     auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
 
+    platform::SetInMemDescWithLogicalLayoutFusesSupport(
+        ctx, const_cast<phi::DenseTensor*>(x), x->mem_desc());
+
     if (ndims == 1) {
       framework::TensorCopy(*x, x->place(), out);
       out->set_mem_desc(x->mem_desc());
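For context (not part of the diff): the two pass lists above only take effect once oneDNN is enabled on the inference config, so a minimal way to get the new fuse pass into the pipeline is via the public inference API. A sketch, where "model_dir" is a placeholder path:

```cpp
#include <memory>

#include "paddle_inference_api.h"

std::shared_ptr<paddle_infer::Predictor> MakeOneDNNPredictor() {
  paddle_infer::Config config("model_dir");  // placeholder model directory
  // Installs the CpuPassStrategy::EnableMKLDNN() pass list, which now starts
  // with squeeze2_transpose2_onednn_fuse_pass.
  config.EnableMKLDNN();
  return paddle_infer::CreatePredictor(config);
}
```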
diff --git a/paddle/fluid/operators/transpose_op.cc b/paddle/fluid/operators/transpose_op.cc
index 50417f3cfe743d6d5ee33fdab0382ab8ea7b8303..179339fae6b6ca92dc526934529e61196df0bf05 100644
--- a/paddle/fluid/operators/transpose_op.cc
+++ b/paddle/fluid/operators/transpose_op.cc
@@ -34,14 +34,17 @@ class TransposeOp : public framework::OperatorWithKernel {
     OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Transpose");
     auto x_dims = ctx->GetInputDim("X");
     std::vector<int> axis = ctx->Attrs().Get<std::vector<int>>("axis");
+
     size_t x_rank = x_dims.size();
     size_t axis_size = axis.size();
 
-    PADDLE_ENFORCE_EQ(x_rank,
+    // Note: x_rank > axis_size when the squeeze2 + transpose2 fuse is
+    // applied; otherwise x_rank == axis_size
+    PADDLE_ENFORCE_GE(x_rank,
                       axis_size,
                       platform::errors::InvalidArgument(
                           "The input tensor's dimension "
-                          "should be equal to the axis's size. "
+                          "should be equal to or greater than the axis's size. "
                           "But received input tensor's dimension is %d, "
                           "axis's size is %d",
                           x_rank,
diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h
index baa3c2aea588e7dcb84e586675b3b3885fd3a08f..b189196429beaeb5246cac6d1da9a32059793b13 100644
--- a/paddle/fluid/platform/mkldnn_reuse.h
+++ b/paddle/fluid/platform/mkldnn_reuse.h
@@ -167,11 +167,58 @@ static void SetOutMemDescWithLogicalLayoutFusesSupport(
     SetOutMemDescWithUnsqueeze2FuseSupport(ctx, out, out_md);
   } else if (ctx.HasAttr("fused_reshape2_shape")) {
     SetOutMemDescWithReshape2FuseSupport(ctx, out, out_md);
+  } else if (ctx.HasAttr("fused_squeeze2_axes")) {
+    out->set_mem_desc(out_md);
+    out->Resize(phi::make_ddim(out_md.dims()));
   } else {
     out->set_mem_desc(out_md);
   }
 }
 
+static void SetInMemDescWithSqueeze2FuseSupport(
+    const framework::ExecutionContext& ctx,
+    phi::DenseTensor* in,
+    const dnnl::memory::desc& in_md) {
+  const std::vector<int> fused_squeeze2_axes =
+      ctx.Attr<std::vector<int>>("fused_squeeze2_axes");
+  const std::set<int64_t> squeeze2_axes_set(fused_squeeze2_axes.begin(),
+                                            fused_squeeze2_axes.end());
+  const std::vector<int64_t>& x_vec_dims = in_md.dims();
+  std::vector<int64_t> squeezed_op_tz(
+      x_vec_dims.size() - fused_squeeze2_axes.size(), 0);
+
+  int j = 0;
+  for (size_t i = 0; i < x_vec_dims.size(); ++i) {
+    if (squeeze2_axes_set.count(i) ||
+        squeeze2_axes_set.count(i - x_vec_dims.size())) {
+      PADDLE_ENFORCE_EQ(
+          x_vec_dims[i],
+          1,
+          platform::errors::InvalidArgument(
+              "Squeeze2 input dim %d should be equal to one, but got %d.",
+              i,
+              x_vec_dims[i]));
+      continue;
+    }
+    squeezed_op_tz[j++] = x_vec_dims[i];
+  }
+
+  in->set_mem_desc(in_md.reshape(squeezed_op_tz));
+  in->Resize(phi::make_ddim(squeezed_op_tz));
+}
+
+static void SetInMemDescWithLogicalLayoutFusesSupport(
+    const framework::ExecutionContext& ctx,
+    phi::DenseTensor* in,
+    const dnnl::memory::desc& in_md) {
+  if (ctx.HasAttr("fused_squeeze2_axes")) {
+    SetInMemDescWithSqueeze2FuseSupport(ctx, in, in_md);
+  } else {
+    in->set_mem_desc(in_md);
+    in->Resize(phi::make_ddim(in_md.dims()));
+  }
+}
+
 template <typename T>
 constexpr bool IsInt8() {
   return std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
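The index arithmetic in `SetInMemDescWithSqueeze2FuseSupport` accepts both positive and negative squeeze axes, since `i - x_vec_dims.size()` matches the negative alias of each dimension. A self-contained sketch of the same dim-squeezing logic, using plain STL rather than oneDNN memory descriptors:

```cpp
#include <cassert>
#include <cstdint>
#include <set>
#include <vector>

// Drops the dims named in `axes` (positive or negative indices) from
// `x_dims`; every squeezed dim must equal 1, as in the fused kernel.
std::vector<int64_t> SqueezeDims(const std::vector<int64_t>& x_dims,
                                 const std::vector<int>& axes) {
  const std::set<int64_t> axes_set(axes.begin(), axes.end());
  const int64_t rank = static_cast<int64_t>(x_dims.size());
  std::vector<int64_t> out;
  out.reserve(x_dims.size() - axes.size());
  for (int64_t i = 0; i < rank; ++i) {
    // Match either the positive index or its negative alias (i - rank).
    if (axes_set.count(i) || axes_set.count(i - rank)) {
      assert(x_dims[i] == 1 && "squeezed dims must be 1");
      continue;
    }
    out.push_back(x_dims[i]);
  }
  return out;
}

int main() {
  // squeeze2 with axes {0, -1} on shape [1, 3, 5, 1] yields [3, 5].
  auto out = SqueezeDims({1, 3, 5, 1}, {0, -1});
  assert((out == std::vector<int64_t>{3, 5}));
  return 0;
}
```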