diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index a2f1e28ec753eb9e256cd5f53be48e8693f74d60..3771df64a6261c9ef371cd854a10942d7bcfd991 100755
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -219,6 +219,7 @@ if(WITH_MKLDNN)
   pass_library(matmul_elementwise_add_mkldnn_fuse_pass inference DIR mkldnn)
   pass_library(matmul_activation_mkldnn_fuse_pass inference DIR mkldnn)
   pass_library(operator_scale_onednn_fuse_pass inference DIR mkldnn)
+  pass_library(squeeze2_transpose2_onednn_fuse_pass inference DIR mkldnn)
   pass_library(operator_unsqueeze2_onednn_fuse_pass inference DIR mkldnn)
   pass_library(operator_reshape2_onednn_fuse_pass inference DIR mkldnn)
   pass_library(cpu_quantize_placement_pass base DIR mkldnn)
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc
index d959396096603cff31df60f53ff5d2b69351cbef..92756e46ed48426af0f78e19cc1aa75661c5dc55 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.cc
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -1042,6 +1042,25 @@ PDNode *patterns::SeqConvEltAddRelu::operator()(
   return relu_out_var;
 }
 
+PDNode *patterns::Squeeze2Transpose2::operator()() {
+  auto *squeeze2_op_in = pattern->NewNode(squeeze2_op_in_repr())
+                             ->AsInput()
+                             ->assert_is_op_input("squeeze2", "X");
+  auto *squeeze2_op = pattern->NewNode(squeeze2_op_repr())
+                          ->assert_is_op("squeeze2")
+                          ->assert_has_n_outputs(2);
+  auto *squeeze2_op_out = pattern->NewNode(squeeze2_op_out_repr())
+                              ->AsIntermediate()
+                              ->assert_is_op_output("squeeze2", "Out")
+                              ->assert_is_op_input("transpose2", "X");
+  auto *transpose2_op =
+      pattern->NewNode(transpose2_op_repr())->assert_is_op("transpose2");
+
+  squeeze2_op->LinksFrom({squeeze2_op_in}).LinksTo({squeeze2_op_out});
+  transpose2_op->LinksFrom({squeeze2_op_out});
+  return transpose2_op;
+}
+
 PDNode *patterns::FC::operator()(paddle::framework::ir::PDNode *x,
                                  bool with_bias,
                                  bool with_relu) {
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h
old mode 100755
new mode 100644
index 99e0e3732cdf9769e6556c4a585f926759d23fdc..af09ce0b86a510c29b2598a683897cfd973dc629
--- a/paddle/fluid/framework/ir/graph_pattern_detector.h
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -634,6 +634,20 @@ struct FCMKLDNN : public PatternBase {
   PATTERN_DECL_NODE(output);
 };
 
+// Squeeze2 + Transpose2
+// Forward pass
+struct Squeeze2Transpose2 : public PatternBase {
+  Squeeze2Transpose2(PDPattern* pattern, const std::string& name_scope)
+      : PatternBase(pattern, name_scope, "squeeze2_transpose2") {}
+
+  PDNode* operator()();
+
+  PATTERN_DECL_NODE(squeeze2_op_in);
+  PATTERN_DECL_NODE(squeeze2_op);
+  PATTERN_DECL_NODE(squeeze2_op_out);
+  PATTERN_DECL_NODE(transpose2_op);
+};
+
 // Embedding
 struct Embedding : public PatternBase {
   Embedding(PDPattern* pattern, const std::string& name_scope)
@@ -2002,6 +2016,12 @@ struct AddSupportInt8 : public PatternBase {
   out_var->inputs.clear();         \
   out_var->inputs.push_back(op);
 
+// Set the in_var as the input of the op
+#define IR_VAR_OP_LINK(in_var, op) \
+  in_var->outputs.clear();         \
+  in_var->outputs.push_back(op);   \
+  op->inputs.push_back(in_var);
+
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
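For illustration, a minimal standalone sketch of the rewiring that the new IR_VAR_OP_LINK macro performs: the variable node drops all of its previous consumers and is attached as a direct input of the given operator node. The toy Node struct below is a stand-in for Paddle's ir::Node, not the real API.

    #include <iostream>
    #include <string>
    #include <vector>

    struct Node {  // toy stand-in for ir::Node
      std::string name;
      std::vector<Node*> inputs;
      std::vector<Node*> outputs;
    };

    #define IR_VAR_OP_LINK(in_var, op) \
      in_var->outputs.clear();         \
      in_var->outputs.push_back(op);   \
      op->inputs.push_back(in_var);

    int main() {
      Node x{"squeeze2_in"}, squeeze2{"squeeze2"}, transpose2{"transpose2"};
      x.outputs.push_back(&squeeze2);  // original edge: x -> squeeze2

      Node *x_ptr = &x, *transpose2_ptr = &transpose2;
      IR_VAR_OP_LINK(x_ptr, transpose2_ptr);  // rewire: x -> transpose2

      std::cout << x.name << " now feeds " << x.outputs[0]->name << "\n";
      return 0;
    }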
diff --git a/paddle/fluid/framework/ir/mkldnn/squeeze2_transpose2_onednn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/squeeze2_transpose2_onednn_fuse_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..00c077cc84d504a94870899baa939e2d35d91727
--- /dev/null
+++ b/paddle/fluid/framework/ir/mkldnn/squeeze2_transpose2_onednn_fuse_pass.cc
@@ -0,0 +1,86 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include "paddle/fluid/framework/ir/mkldnn/squeeze2_transpose2_onednn_fuse_pass.h"
+#include "paddle/fluid/framework/op_version_registry.h"
+#include "paddle/fluid/platform/mkldnn_reuse.h"
+#include "paddle/fluid/string/pretty_log.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+using string::PrettyLogDetail;
+
+void FuseSqueeze2Transpose2OneDNNPass::ApplyImpl(Graph *graph) const {
+  PADDLE_ENFORCE_NOT_NULL(graph,
+                          platform::errors::InvalidArgument(
+                              "Pointer to graph argument should not be NULL."));
+
+  FusePassBase::Init("squeeze2_transpose2_onednn_fuse_pass", graph);
+
+  GraphPatternDetector gpd;
+  patterns::Squeeze2Transpose2 squeeze2_transpose2_pattern(
+      gpd.mutable_pattern(), "squeeze2_transpose2_onednn_fuse_pass");
+  squeeze2_transpose2_pattern();
+
+  int found_count = 0;
+
+  auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
+                     Graph *g) {
+    GET_IR_NODE_FROM_SUBGRAPH(
+        squeeze2_op_in, squeeze2_op_in, squeeze2_transpose2_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        squeeze2_op, squeeze2_op, squeeze2_transpose2_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        squeeze2_op_out, squeeze2_op_out, squeeze2_transpose2_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        transpose2_op, transpose2_op, squeeze2_transpose2_pattern);
+
+    if (!transpose2_op->Op()->HasAttr("use_mkldnn") ||
+        (transpose2_op->Op()->HasAttr("use_mkldnn") &&
+         !(PADDLE_GET_CONST(bool,
+                            transpose2_op->Op()->GetAttr("use_mkldnn"))))) {
+      VLOG(4) << "Only oneDNN version of transpose2 can be fused after with "
+                 "squeeze2.";
+      return;
+    }
+
+    std::vector<int> squeeze2_axes =
+        PADDLE_GET_CONST(std::vector<int>, squeeze2_op->Op()->GetAttr("axes"));
+    transpose2_op->Op()->SetAttr("fused_squeeze2_axes", squeeze2_axes);
+    transpose2_op->Op()->SetInput("X", {squeeze2_op_in->Name()});
+
+    IR_VAR_OP_LINK(squeeze2_op_in, transpose2_op);
+    GraphSafeRemoveNodes(g, {squeeze2_op, squeeze2_op_out});
+    found_count++;
+  };
+
+  gpd(graph, handler);
+  AddStatis(found_count);
+  if ((!Has("disable_logs") || !Get<bool>("disable_logs"))) {
+    PrettyLogDetail("--- fused %d squeeze2 with transpose2", found_count);
+  }
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(squeeze2_transpose2_onednn_fuse_pass,
+              paddle::framework::ir::FuseSqueeze2Transpose2OneDNNPass);
+REGISTER_PASS_CAPABILITY(squeeze2_transpose2_onednn_fuse_pass)
+    .AddCombination(
+        paddle::framework::compatible::OpVersionComparatorCombination()
+            .GE("squeeze2", 0)
+            .GE("transpose2", 0));
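For illustration, the rewrite the handler above performs on a matched subgraph, with hypothetical shapes and attributes (squeeze2's secondary XShape output omitted):

    Before: X[1,512,1,12] -> squeeze2(axes={2}) -> [1,512,12]
                          -> transpose2(axis={0,2,1}) -> Out[1,12,512]
    After:  X[1,512,1,12] -> transpose2(axis={0,2,1},
                                        fused_squeeze2_axes={2}) -> Out[1,12,512]

The squeeze2 op and its intermediate output are removed from the graph; transpose2 keeps its original axis attribute and records the squeezed axes in fused_squeeze2_axes, so its oneDNN kernel can replay the squeeze on X's memory descriptor at run time.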
diff --git a/paddle/fluid/framework/ir/mkldnn/squeeze2_transpose2_onednn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/squeeze2_transpose2_onednn_fuse_pass.h
new file mode 100644
index 0000000000000000000000000000000000000000..be3871bdfe2fbccc4eaa83fed37512d4515c02ec
--- /dev/null
+++ b/paddle/fluid/framework/ir/mkldnn/squeeze2_transpose2_onednn_fuse_pass.h
@@ -0,0 +1,35 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/ir/graph.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+class FuseSqueeze2Transpose2OneDNNPass : public FusePassBase {
+ public:
+  virtual ~FuseSqueeze2Transpose2OneDNNPass() {}
+
+ protected:
+  void ApplyImpl(Graph *graph) const override;
+};
+
+}  // namespace ir
+}  // namespace framework
+
+}  // namespace paddle
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index 28fd91e3ba3fdc509016b68c726aa5f9e7829020..5696923198f6b4b0b632fa67540da8012d885d7c 100755
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -307,6 +307,7 @@ void CpuPassStrategy::EnableMKLDNN() {
     passes_.insert(passes_.begin(), "mkldnn_placement_pass");
 
     for (auto &pass : std::vector<std::string>({
+             "squeeze2_transpose2_onednn_fuse_pass",
              "depthwise_conv_mkldnn_pass",    //
              "conv_bn_fuse_pass",             // Execute BN passes again to
              "conv_eltwiseadd_bn_fuse_pass",  // preserve correct pass order
@@ -386,6 +387,7 @@ void CpuPassStrategy::EnableMkldnnInt8() {
     passes_.push_back("mkldnn_placement_pass");
     passes_.push_back("simplify_with_basic_ops_pass");
     passes_.push_back("constant_folding_pass");
+    passes_.push_back("squeeze2_transpose2_onednn_fuse_pass");
     passes_.push_back("layer_norm_fuse_pass");
     passes_.push_back("attention_lstm_fuse_pass");
     passes_.push_back("seqconv_eltadd_relu_fuse_pass");
diff --git a/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc
index c88377fccd37536ed9b84ce8fca67b4aace59211..246a9a772b08693a0a6e9accafe0113620ce758d 100644
--- a/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc
@@ -42,6 +42,9 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
 
     auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
 
+    platform::SetInMemDescWithLogicalLayoutFusesSupport(
+        ctx, const_cast<phi::DenseTensor *>(x), x->mem_desc());
+
     if (ndims == 1) {
       framework::TensorCopy(*x, x->place(), out);
       out->set_mem_desc(x->mem_desc());
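Because the oneDNN kernel now squeezes X's memory descriptor in place, InferShape still observes the pre-squeeze rank while the axis attribute matches the squeezed rank; that is why the transpose_op.cc check below is relaxed from EQ to GE. A minimal sketch with hypothetical shapes:

    #include <cassert>
    #include <cstdint>
    #include <vector>

    int main() {
      // squeeze2(axes={2}) fused into transpose2(axis={0,2,1})
      std::vector<int64_t> x_dims = {1, 512, 1, 12};  // rank InferShape sees: 4
      std::vector<int> axis = {0, 2, 1};  // written against squeezed rank: 3

      size_t x_rank = x_dims.size();
      size_t axis_size = axis.size();

      // The former PADDLE_ENFORCE_EQ(x_rank, axis_size) would reject this
      // graph; the relaxed GE check accepts x_rank > axis_size.
      assert(x_rank >= axis_size);
      return 0;
    }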
diff --git a/paddle/fluid/operators/transpose_op.cc b/paddle/fluid/operators/transpose_op.cc
index b342f01e46ff7661d4d76483b59fa0cb05d1fa58..535300c826d24d0a36160d13016cccdef3ef5876 100644
--- a/paddle/fluid/operators/transpose_op.cc
+++ b/paddle/fluid/operators/transpose_op.cc
@@ -39,11 +39,13 @@ class TransposeOp : public framework::OperatorWithKernel {
     size_t x_rank = x_dims.size();
     size_t axis_size = axis.size();
 
-    PADDLE_ENFORCE_EQ(x_rank,
+    // Note: x_rank > axis_size when fuse squeeze2 + transpose2, else x_rank ==
+    // axis_size
+    PADDLE_ENFORCE_GE(x_rank,
                       axis_size,
                       platform::errors::InvalidArgument(
                           "The input tensor's dimension "
-                          "should be equal to the axis's size. "
+                          "should be equal to or greater than the axis's size. "
                           "But received input tensor's dimension is %d, "
                           "axis's size is %d",
                           x_rank,
diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h
index dcd8124fc6b8e8b641d04f97dbdafdc9ccd4d612..028c2d1426e0816d56b70b9b13b4849d7d69c1bc 100644
--- a/paddle/fluid/platform/mkldnn_reuse.h
+++ b/paddle/fluid/platform/mkldnn_reuse.h
@@ -120,12 +120,10 @@ static void SetOutMemDescWithUnsqueeze2FuseSupport(
   const std::vector<int64_t>& op_tz = out_md.dims();
   std::vector<int64_t> unsqueezed_op_tz(
       op_tz.size() + fused_unsqueeze2_axes.size(), 0);
-
   for (const auto& axis : fused_unsqueeze2_axes) {
     int positive_axis = axis < 0 ? unsqueezed_op_tz.size() + axis : axis;
     unsqueezed_op_tz[positive_axis] = 1;
   }
-
   int j = 0;
   for (size_t i = 0; i < unsqueezed_op_tz.size(); ++i) {
     if (unsqueezed_op_tz[i] == 0) {
@@ -143,20 +141,17 @@ static void SetOutMemDescWithReshape2FuseSupport(
   std::vector<int64_t> fused_reshape2_shape(
       ctx.Attr<std::vector<int>>("fused_reshape2_shape").begin(),
       ctx.Attr<std::vector<int>>("fused_reshape2_shape").end());
-
   const int out_shape_numel = out->numel();
   const int new_shape_numel = std::accumulate(fused_reshape2_shape.begin(),
                                               fused_reshape2_shape.end(),
                                               1,
                                               std::multiplies<int64_t>());
-
   for (size_t i = 0; i < fused_reshape2_shape.size(); ++i) {
     if (fused_reshape2_shape[i] == -1) {
       fused_reshape2_shape[i] = -out_shape_numel / new_shape_numel;
       break;
     }
   }
-
   out->set_mem_desc(out_md.reshape(fused_reshape2_shape));
   out->Resize(phi::make_ddim(fused_reshape2_shape));
 }
@@ -169,11 +164,58 @@ static void SetOutMemDescWithLogicalLayoutFusesSupport(
     SetOutMemDescWithUnsqueeze2FuseSupport(ctx, out, out_md);
   } else if (ctx.HasAttr("fused_reshape2_shape")) {
     SetOutMemDescWithReshape2FuseSupport(ctx, out, out_md);
+  } else if (ctx.HasAttr("fused_squeeze2_axes")) {
+    out->set_mem_desc(out_md);
+    out->Resize(phi::make_ddim(out_md.dims()));
   } else {
     out->set_mem_desc(out_md);
   }
 }
 
+static void SetInMemDescWithSqueeze2FuseSupport(
+    const framework::ExecutionContext& ctx,
+    phi::DenseTensor* in,
+    const dnnl::memory::desc& in_md) {
+  const std::vector<int> fused_squeeze2_axes =
+      ctx.Attr<std::vector<int>>("fused_squeeze2_axes");
+  const std::set<int64_t> squeeze2_axes_set(fused_squeeze2_axes.begin(),
+                                            fused_squeeze2_axes.end());
+  const std::vector<int64_t>& x_vec_dims = in_md.dims();
+  std::vector<int64_t> squeezed_op_tz(
+      x_vec_dims.size() - fused_squeeze2_axes.size(), 0);
+
+  int j = 0;
+  for (size_t i = 0; i < x_vec_dims.size(); ++i) {
+    if (squeeze2_axes_set.count(i) ||
+        squeeze2_axes_set.count(i - x_vec_dims.size())) {
+      PADDLE_ENFORCE_EQ(
+          x_vec_dims[i],
+          1,
+          platform::errors::InvalidArgument(
+              "Squeeze2 input '%d' dim should be equal to one, but get '%d'.",
+              i,
+              x_vec_dims[i]));
+      continue;
+    }
+    squeezed_op_tz[j++] = x_vec_dims[i];
+  }
+
+  in->set_mem_desc(in_md.reshape(squeezed_op_tz));
+  in->Resize(phi::make_ddim(squeezed_op_tz));
+}
+
+static void SetInMemDescWithLogicalLayoutFusesSupport(
+    const framework::ExecutionContext& ctx,
+    phi::DenseTensor* in,
+    const dnnl::memory::desc& in_md) {
+  if (ctx.HasAttr("fused_squeeze2_axes")) {
+    SetInMemDescWithSqueeze2FuseSupport(ctx, in, in_md);
+  } else {
+    in->set_mem_desc(in_md);
+    in->Resize(phi::make_ddim(in_md.dims()));
+  }
+}
+
 template <typename T>
 constexpr bool IsInt8() {
   return std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
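A self-contained sketch of the dimension arithmetic in SetInMemDescWithSqueeze2FuseSupport above, with the dnnl/phi types replaced by plain vectors and a hypothetical input; the count(i - rank) lookup is what makes negative axes work:

    #include <cstdint>
    #include <iostream>
    #include <set>
    #include <vector>

    std::vector<int64_t> SqueezeDims(const std::vector<int64_t>& x_vec_dims,
                                     const std::vector<int>& fused_squeeze2_axes) {
      const std::set<int64_t> axes_set(fused_squeeze2_axes.begin(),
                                       fused_squeeze2_axes.end());
      const int64_t rank = static_cast<int64_t>(x_vec_dims.size());
      std::vector<int64_t> squeezed;
      for (int64_t i = 0; i < rank; ++i) {
        // Skip axes listed either as i or as the negative form i - rank;
        // the real helper additionally enforces that such dims equal 1.
        if (axes_set.count(i) || axes_set.count(i - rank)) continue;
        squeezed.push_back(x_vec_dims[i]);
      }
      return squeezed;
    }

    int main() {
      // squeeze2(axes={-2}) on [1, 512, 1, 12] -> prints: 1 512 12
      for (int64_t d : SqueezeDims({1, 512, 1, 12}, {-2})) std::cout << d << ' ';
      std::cout << '\n';
      return 0;
    }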