From ea5f44b8e4c0fa7d154483a41aef90112a772d92 Mon Sep 17 00:00:00 2001
From: Hui Zhang
Date: Wed, 9 Nov 2022 11:37:25 +0800
Subject: [PATCH] [cherry-pick] Squeeze2 and transpose2 fuse using oneDNN
 (#47712)

* squeeze2 + transpose2 fuse onednn cherrypick 2.4

* format

* fix merge
---
 paddle/fluid/framework/ir/CMakeLists.txt      |  1 +
 .../framework/ir/graph_pattern_detector.cc    | 19 ++++
 .../framework/ir/graph_pattern_detector.h     | 20 +++++
 .../squeeze2_transpose2_onednn_fuse_pass.cc   | 86 +++++++++++++++++++
 .../squeeze2_transpose2_onednn_fuse_pass.h    | 35 ++++++++
 .../inference/api/paddle_pass_builder.cc      |  2 +
 .../operators/mkldnn/transpose_mkldnn_op.cc   |  3 +
 paddle/fluid/operators/transpose_op.cc        |  6 +-
 paddle/fluid/platform/mkldnn_reuse.h          | 52 +++++++++--
 9 files changed, 217 insertions(+), 7 deletions(-)
 mode change 100755 => 100644 paddle/fluid/framework/ir/graph_pattern_detector.h
 create mode 100644 paddle/fluid/framework/ir/mkldnn/squeeze2_transpose2_onednn_fuse_pass.cc
 create mode 100644 paddle/fluid/framework/ir/mkldnn/squeeze2_transpose2_onednn_fuse_pass.h

diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index a2f1e28ec75..3771df64a62 100755
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -219,6 +219,7 @@ if(WITH_MKLDNN)
   pass_library(matmul_elementwise_add_mkldnn_fuse_pass inference DIR mkldnn)
   pass_library(matmul_activation_mkldnn_fuse_pass inference DIR mkldnn)
   pass_library(operator_scale_onednn_fuse_pass inference DIR mkldnn)
+  pass_library(squeeze2_transpose2_onednn_fuse_pass inference DIR mkldnn)
   pass_library(operator_unsqueeze2_onednn_fuse_pass inference DIR mkldnn)
   pass_library(operator_reshape2_onednn_fuse_pass inference DIR mkldnn)
   pass_library(cpu_quantize_placement_pass base DIR mkldnn)
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc
index d9593960966..92756e46ed4 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.cc
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -1042,6 +1042,25 @@ PDNode *patterns::SeqConvEltAddRelu::operator()(
   return relu_out_var;
 }
 
+PDNode *patterns::Squeeze2Transpose2::operator()() {
+  auto *squeeze2_op_in = pattern->NewNode(squeeze2_op_in_repr())
+                             ->AsInput()
+                             ->assert_is_op_input("squeeze2", "X");
+  auto *squeeze2_op = pattern->NewNode(squeeze2_op_repr())
+                          ->assert_is_op("squeeze2")
+                          ->assert_has_n_outputs(2);
+  auto *squeeze2_op_out = pattern->NewNode(squeeze2_op_out_repr())
+                              ->AsIntermediate()
+                              ->assert_is_op_output("squeeze2", "Out")
+                              ->assert_is_op_input("transpose2", "X");
+  auto *transpose2_op =
+      pattern->NewNode(transpose2_op_repr())->assert_is_op("transpose2");
+
+  squeeze2_op->LinksFrom({squeeze2_op_in}).LinksTo({squeeze2_op_out});
+  transpose2_op->LinksFrom({squeeze2_op_out});
+  return transpose2_op;
+}
+
 PDNode *patterns::FC::operator()(paddle::framework::ir::PDNode *x,
                                  bool with_bias,
                                  bool with_relu) {
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h
old mode 100755
new mode 100644
index 99e0e3732cd..af09ce0b86a
--- a/paddle/fluid/framework/ir/graph_pattern_detector.h
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -634,6 +634,20 @@ struct FCMKLDNN : public PatternBase {
   PATTERN_DECL_NODE(output);
 };
 
+// Squeeze2 + Transpose2
+// Forward pass
+struct Squeeze2Transpose2 : public PatternBase {
+  Squeeze2Transpose2(PDPattern* pattern, const std::string& name_scope)
+      : PatternBase(pattern, name_scope, "squeeze2_transpose2") {}
+
+  PDNode* operator()();
+
+  PATTERN_DECL_NODE(squeeze2_op_in);
+  PATTERN_DECL_NODE(squeeze2_op);
+  PATTERN_DECL_NODE(squeeze2_op_out);
+  PATTERN_DECL_NODE(transpose2_op);
+};
+
 // Embedding
 struct Embedding : public PatternBase {
   Embedding(PDPattern* pattern, const std::string& name_scope)
@@ -2002,6 +2016,12 @@ struct AddSupportInt8 : public PatternBase {
   out_var->inputs.clear();                  \
   out_var->inputs.push_back(op);
 
+// Set the in_var as the input of the op
+#define IR_VAR_OP_LINK(in_var, op) \
+  in_var->outputs.clear();         \
+  in_var->outputs.push_back(op);   \
+  op->inputs.push_back(in_var);
+
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/ir/mkldnn/squeeze2_transpose2_onednn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/squeeze2_transpose2_onednn_fuse_pass.cc
new file mode 100644
index 00000000000..00c077cc84d
--- /dev/null
+++ b/paddle/fluid/framework/ir/mkldnn/squeeze2_transpose2_onednn_fuse_pass.cc
@@ -0,0 +1,86 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/ir/mkldnn/squeeze2_transpose2_onednn_fuse_pass.h"
+
+#include "paddle/fluid/framework/op_version_registry.h"
+#include "paddle/fluid/platform/mkldnn_reuse.h"
+#include "paddle/fluid/string/pretty_log.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+using string::PrettyLogDetail;
+
+void FuseSqueeze2Transpose2OneDNNPass::ApplyImpl(Graph *graph) const {
+  PADDLE_ENFORCE_NOT_NULL(graph,
+                          platform::errors::InvalidArgument(
+                              "Pointer to graph argument should not be NULL."));
+
+  FusePassBase::Init("squeeze2_transpose2_onednn_fuse_pass", graph);
+
+  GraphPatternDetector gpd;
+  patterns::Squeeze2Transpose2 squeeze2_transpose2_pattern(
+      gpd.mutable_pattern(), "squeeze2_transpose2_onednn_fuse_pass");
+  squeeze2_transpose2_pattern();
+
+  int found_count = 0;
+
+  auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
+                     Graph *g) {
+    GET_IR_NODE_FROM_SUBGRAPH(
+        squeeze2_op_in, squeeze2_op_in, squeeze2_transpose2_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        squeeze2_op, squeeze2_op, squeeze2_transpose2_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        squeeze2_op_out, squeeze2_op_out, squeeze2_transpose2_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        transpose2_op, transpose2_op, squeeze2_transpose2_pattern);
+
+    if (!transpose2_op->Op()->HasAttr("use_mkldnn") ||
+        (transpose2_op->Op()->HasAttr("use_mkldnn") &&
+         !(PADDLE_GET_CONST(bool,
+                            transpose2_op->Op()->GetAttr("use_mkldnn"))))) {
+      VLOG(4) << "Only the oneDNN version of transpose2 can be fused after "
+                 "squeeze2.";
+      return;
+    }
+
+    std::vector<int> squeeze2_axes = PADDLE_GET_CONST(
+        std::vector<int>, squeeze2_op->Op()->GetAttr("axes"));
+    transpose2_op->Op()->SetAttr("fused_squeeze2_axes", squeeze2_axes);
+    transpose2_op->Op()->SetInput("X", {squeeze2_op_in->Name()});
+
+    IR_VAR_OP_LINK(squeeze2_op_in, transpose2_op);
+    GraphSafeRemoveNodes(g, {squeeze2_op, squeeze2_op_out});
+    found_count++;
+  };
+
+  gpd(graph, handler);
+  AddStatis(found_count);
+  if ((!Has("disable_logs") || !Get<bool>("disable_logs"))) {
+    PrettyLogDetail("--- fused %d squeeze2 with transpose2", found_count);
+  }
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(squeeze2_transpose2_onednn_fuse_pass,
+              paddle::framework::ir::FuseSqueeze2Transpose2OneDNNPass);
+REGISTER_PASS_CAPABILITY(squeeze2_transpose2_onednn_fuse_pass)
+    .AddCombination(
+        paddle::framework::compatible::OpVersionComparatorCombination()
+            .GE("squeeze2", 0)
+            .GE("transpose2", 0));
diff --git a/paddle/fluid/framework/ir/mkldnn/squeeze2_transpose2_onednn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/squeeze2_transpose2_onednn_fuse_pass.h
new file mode 100644
index 00000000000..be3871bdfe2
--- /dev/null
+++ b/paddle/fluid/framework/ir/mkldnn/squeeze2_transpose2_onednn_fuse_pass.h
@@ -0,0 +1,35 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/ir/graph.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+class FuseSqueeze2Transpose2OneDNNPass : public FusePassBase {
+ public:
+  virtual ~FuseSqueeze2Transpose2OneDNNPass() {}
+
+ protected:
+  void ApplyImpl(Graph *graph) const override;
+};
+
+}  // namespace ir
+}  // namespace framework
+
+}  // namespace paddle
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index 28fd91e3ba3..5696923198f 100755
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -307,6 +307,7 @@ void CpuPassStrategy::EnableMKLDNN() {
     passes_.insert(passes_.begin(), "mkldnn_placement_pass");
 
     for (auto &pass : std::vector<std::string>({
+             "squeeze2_transpose2_onednn_fuse_pass",
              "depthwise_conv_mkldnn_pass",    //
              "conv_bn_fuse_pass",             // Execute BN passes again to
              "conv_eltwiseadd_bn_fuse_pass",  // preserve correct pass order
@@ -386,6 +387,7 @@ void CpuPassStrategy::EnableMkldnnInt8() {
     passes_.push_back("mkldnn_placement_pass");
     passes_.push_back("simplify_with_basic_ops_pass");
     passes_.push_back("constant_folding_pass");
+    passes_.push_back("squeeze2_transpose2_onednn_fuse_pass");
    passes_.push_back("layer_norm_fuse_pass");
     passes_.push_back("attention_lstm_fuse_pass");
     passes_.push_back("seqconv_eltadd_relu_fuse_pass");
diff --git a/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc
index c88377fccd3..246a9a772b0 100644
--- a/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc
@@ -42,6 +42,9 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
 
     auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
 
+    platform::SetInMemDescWithLogicalLayoutFusesSupport(
+        ctx, const_cast<phi::DenseTensor*>(x), x->mem_desc());
+
     if (ndims == 1) {
       framework::TensorCopy(*x, x->place(), out);
       out->set_mem_desc(x->mem_desc());
diff --git a/paddle/fluid/operators/transpose_op.cc b/paddle/fluid/operators/transpose_op.cc
index b342f01e46f..535300c826d 100644
--- a/paddle/fluid/operators/transpose_op.cc
+++ b/paddle/fluid/operators/transpose_op.cc
@@ -39,11 +39,13 @@ class TransposeOp : public framework::OperatorWithKernel {
     size_t x_rank = x_dims.size();
     size_t axis_size = axis.size();
 
-    PADDLE_ENFORCE_EQ(x_rank,
+    // Note: x_rank > axis_size when fusing squeeze2 + transpose2; otherwise
+    // x_rank == axis_size
+    PADDLE_ENFORCE_GE(x_rank,
                       axis_size,
                       platform::errors::InvalidArgument(
                           "The input tensor's dimension "
-                          "should be equal to the axis's size. "
+                          "should be equal to or greater than the axis's size. "
                           "But received input tensor's dimension is %d, "
                           "axis's size is %d",
                           x_rank,
diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h
index dcd8124fc6b..028c2d1426e 100644
--- a/paddle/fluid/platform/mkldnn_reuse.h
+++ b/paddle/fluid/platform/mkldnn_reuse.h
@@ -120,12 +120,10 @@ static void SetOutMemDescWithUnsqueeze2FuseSupport(
   const std::vector<int64_t>& op_tz = out_md.dims();
   std::vector<int64_t> unsqueezed_op_tz(
       op_tz.size() + fused_unsqueeze2_axes.size(), 0);
-
   for (const auto& axis : fused_unsqueeze2_axes) {
     int positive_axis = axis < 0 ? unsqueezed_op_tz.size() + axis : axis;
     unsqueezed_op_tz[positive_axis] = 1;
   }
-
   int j = 0;
   for (size_t i = 0; i < unsqueezed_op_tz.size(); ++i) {
     if (unsqueezed_op_tz[i] == 0) {
@@ -143,20 +141,17 @@ static void SetOutMemDescWithReshape2FuseSupport(
   std::vector<int64_t> fused_reshape2_shape(
       ctx.Attr<std::vector<int>>("fused_reshape2_shape").begin(),
       ctx.Attr<std::vector<int>>("fused_reshape2_shape").end());
-
   const int out_shape_numel = out->numel();
   const int new_shape_numel = std::accumulate(fused_reshape2_shape.begin(),
                                               fused_reshape2_shape.end(),
                                               1,
                                               std::multiplies<int64_t>());
-
   for (size_t i = 0; i < fused_reshape2_shape.size(); ++i) {
     if (fused_reshape2_shape[i] == -1) {
       fused_reshape2_shape[i] = -out_shape_numel / new_shape_numel;
       break;
     }
   }
-
   out->set_mem_desc(out_md.reshape(fused_reshape2_shape));
   out->Resize(phi::make_ddim(fused_reshape2_shape));
 }
@@ -169,11 +164,58 @@ static void SetOutMemDescWithLogicalLayoutFusesSupport(
     SetOutMemDescWithUnsqueeze2FuseSupport(ctx, out, out_md);
   } else if (ctx.HasAttr("fused_reshape2_shape")) {
     SetOutMemDescWithReshape2FuseSupport(ctx, out, out_md);
+  } else if (ctx.HasAttr("fused_squeeze2_axes")) {
+    out->set_mem_desc(out_md);
+    out->Resize(phi::make_ddim(out_md.dims()));
   } else {
     out->set_mem_desc(out_md);
   }
 }
 
+static void SetInMemDescWithSqueeze2FuseSupport(
+    const framework::ExecutionContext& ctx,
+    phi::DenseTensor* in,
+    const dnnl::memory::desc& in_md) {
+  const std::vector<int> fused_squeeze2_axes =
+      ctx.Attr<std::vector<int>>("fused_squeeze2_axes");
+  const std::set<int64_t> squeeze2_axes_set(fused_squeeze2_axes.begin(),
+                                            fused_squeeze2_axes.end());
+  const std::vector<int64_t>& x_vec_dims = in_md.dims();
+  std::vector<int64_t> squeezed_op_tz(
+      x_vec_dims.size() - fused_squeeze2_axes.size(), 0);
+
+  int j = 0;
+  for (size_t i = 0; i < x_vec_dims.size(); ++i) {
+    if (squeeze2_axes_set.count(i) ||
+        squeeze2_axes_set.count(i - x_vec_dims.size())) {
+      PADDLE_ENFORCE_EQ(
+          x_vec_dims[i],
+          1,
+          platform::errors::InvalidArgument(
+              "Squeeze2 input dim '%d' should be equal to one, but got '%d'.",
+              i,
+              x_vec_dims[i]));
+      continue;
+    }
+    squeezed_op_tz[j++] = x_vec_dims[i];
+  }
+
+  in->set_mem_desc(in_md.reshape(squeezed_op_tz));
+  in->Resize(phi::make_ddim(squeezed_op_tz));
+}
+
+static void SetInMemDescWithLogicalLayoutFusesSupport(
+    const framework::ExecutionContext& ctx,
+    phi::DenseTensor* in,
+    const dnnl::memory::desc& in_md) {
+  if (ctx.HasAttr("fused_squeeze2_axes")) {
+    SetInMemDescWithSqueeze2FuseSupport(ctx, in, in_md);
+  } else {
+    in->set_mem_desc(in_md);
+    in->Resize(phi::make_ddim(in_md.dims()));
+  }
+}
+
 template <typename T>
 constexpr bool IsInt8() {
   return std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
-- 
GitLab
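
For readers tracing what the fused transpose2 kernel does with the
fused_squeeze2_axes attribute, the dim arithmetic of
SetInMemDescWithSqueeze2FuseSupport above can be exercised in isolation.
The following is a minimal standalone sketch, not part of the patch: the
helper name SqueezeDims is invented for illustration, and it mirrors the
loop above under the same assumption that every axis listed in
fused_squeeze2_axes has extent 1.

    #include <cassert>
    #include <cstdint>
    #include <iostream>
    #include <set>
    #include <vector>

    // Drop the size-1 dims listed in `axes` (negative axes count from the
    // back), mirroring how the fused transpose2 kernel shrinks its input
    // memory descriptor before transposing.
    std::vector<int64_t> SqueezeDims(const std::vector<int64_t>& dims,
                                     const std::vector<int>& axes) {
      std::set<int64_t> axes_set(axes.begin(), axes.end());
      std::vector<int64_t> squeezed;
      squeezed.reserve(dims.size() - axes.size());
      const int64_t rank = static_cast<int64_t>(dims.size());
      for (int64_t i = 0; i < rank; ++i) {
        if (axes_set.count(i) || axes_set.count(i - rank)) {
          assert(dims[i] == 1 && "squeezed axes must have extent 1");
          continue;
        }
        squeezed.push_back(dims[i]);
      }
      return squeezed;
    }

    int main() {
      // squeeze2 with axes={1, -1} on a [2, 1, 3, 1] tensor yields [2, 3],
      // so the fused transpose2 reads its original input at the squeezed
      // shape instead of running a separate squeeze2 op.
      for (int64_t d : SqueezeDims({2, 1, 3, 1}, {1, -1})) {
        std::cout << d << ' ';  // prints: 2 3
      }
      std::cout << '\n';
    }

This also explains the TransposeOp change above from PADDLE_ENFORCE_EQ to
PADDLE_ENFORCE_GE: after the fuse, the op's X input still carries the
unsqueezed rank, while the axis attribute refers to the squeezed rank, so
x_rank may legitimately exceed axis_size.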