diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 279ab07ff31b0ddb7138d845424c3ecd6ad37e0b..e4954aadf3ec179b828d4da9e6c4ed27068d5b7d 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -177,6 +177,8 @@ if(WITH_MKLDNN) pass_library(matmul_elementwise_add_mkldnn_fuse_pass inference DIR mkldnn) pass_library(matmul_activation_mkldnn_fuse_pass inference DIR mkldnn) pass_library(operator_scale_onednn_fuse_pass inference DIR mkldnn) + pass_library(operator_unsqueeze2_onednn_fuse_pass inference DIR mkldnn) + pass_library(operator_reshape2_onednn_fuse_pass inference DIR mkldnn) pass_library(cpu_quantize_placement_pass base DIR mkldnn) pass_library(cpu_quantize_pass inference DIR mkldnn) pass_library(cpu_quantize_squash_pass inference DIR mkldnn) diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 55c7787012be16db6fd1f0f00c8fe459e38862e9..352469b0afa4db83de206c1f8670a99dbc561f03 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -958,6 +958,44 @@ PDNode *patterns::OperatorActivation::operator()( return activation_out; } +PDNode *patterns::OperatorUnsqueeze2::operator()( + const std::string &operator_type, const int num_of_operator_outs) { + auto *preceding_op = pattern->NewNode(preceding_op_repr()) + ->assert_is_op(operator_type) + ->assert_has_n_outputs(num_of_operator_outs); + auto *preceding_op_out = pattern->NewNode(preceding_op_out_repr()) + ->AsIntermediate() + ->assert_is_op_output(operator_type, "Out") + ->assert_is_op_input("unsqueeze2"); + auto *unsqueeze2_op = + pattern->NewNode(unsqueeze2_op_repr())->assert_is_op("unsqueeze2"); + auto *unsqueeze2_out = pattern->NewNode(unsqueeze2_out_repr()) + ->AsOutput() + ->assert_is_op_output("unsqueeze2"); + preceding_op->LinksTo({preceding_op_out}); + unsqueeze2_op->LinksFrom({preceding_op_out}).LinksTo({unsqueeze2_out}); + return unsqueeze2_out; +} + +PDNode *patterns::OperatorReshape2::operator()(const std::string &operator_type, + const int num_of_operator_outs) { + auto *preceding_op = pattern->NewNode(preceding_op_repr()) + ->assert_is_op(operator_type) + ->assert_has_n_outputs(num_of_operator_outs); + auto *preceding_op_out = pattern->NewNode(preceding_op_out_repr()) + ->AsIntermediate() + ->assert_is_op_output(operator_type, "Out") + ->assert_is_op_input("reshape2"); + auto *reshape2_op = + pattern->NewNode(reshape2_op_repr())->assert_is_op("reshape2"); + auto *reshape2_out = pattern->NewNode(reshape2_out_repr()) + ->AsOutput() + ->assert_is_op_output("reshape2"); + preceding_op->LinksTo({preceding_op_out}); + reshape2_op->LinksFrom({preceding_op_out}).LinksTo({reshape2_out}); + return reshape2_out; +} + PDNode *patterns::SeqConvEltAddRelu::operator()( paddle::framework::ir::PDNode *seqconv_input) { // Create Operators diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 110e73b228e540c3c7fcc58e053e0ca90ed04795..50372886b72591f15193e873ac0ef324353c4ca6 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -539,6 +539,32 @@ struct OperatorActivation : public PatternBase { PATTERN_DECL_NODE(activation_out); }; +struct OperatorUnsqueeze2 : public PatternBase { + OperatorUnsqueeze2(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, 
name_scope, "operator_unsqueeze2") {} + + PDNode* operator()(const std::string& operator_type, + const int num_of_outputs); + + PATTERN_DECL_NODE(preceding_op); + PATTERN_DECL_NODE(preceding_op_out); + PATTERN_DECL_NODE(unsqueeze2_op); + PATTERN_DECL_NODE(unsqueeze2_out); +}; + +struct OperatorReshape2 : public PatternBase { + OperatorReshape2(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "operator_reshape2") {} + + PDNode* operator()(const std::string& operator_type, + const int num_of_outputs); + + PATTERN_DECL_NODE(preceding_op); + PATTERN_DECL_NODE(preceding_op_out); + PATTERN_DECL_NODE(reshape2_op); + PATTERN_DECL_NODE(reshape2_out); +}; + // SEQCONV with Elementwise_Add ReLU // op: seqconv + elementwise_add + relu // named nodes: diff --git a/paddle/fluid/framework/ir/mkldnn/operator_reshape2_onednn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/operator_reshape2_onednn_fuse_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..0f8d0452aa17ba84073e845a3f50bdc54b69f6fa --- /dev/null +++ b/paddle/fluid/framework/ir/mkldnn/operator_reshape2_onednn_fuse_pass.cc @@ -0,0 +1,144 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/mkldnn/operator_reshape2_onednn_fuse_pass.h" + +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/platform/mkldnn_reuse.h" +#include "paddle/fluid/string/pretty_log.h" + +namespace paddle { +namespace framework { +namespace ir { + +using string::PrettyLogDetail; + +void FuseOperatorReshape2OneDNNPass::ApplyImpl(Graph *graph) const { + // THIS FUSE WILL WORK ONLY WITH OPERATORS THAT OUTPUTS PLAIN MEMORY, F.E. + // ABCD FOR 4D! BE AWARE OF THAT! 
+  std::vector<std::pair<std::string, int>> ops_and_outputs = {
+      {"fc", 1}, {"transpose2", 2}};
+
+  for (const auto &op_and_outputs : ops_and_outputs)
+    FuseReshape2(graph, op_and_outputs.first, op_and_outputs.second);
+}
+
+void FuseOperatorReshape2OneDNNPass::FuseReshape2(Graph *graph,
+                                                  const std::string &op_type,
+                                                  int num_of_outputs) const {
+  PADDLE_ENFORCE_NOT_NULL(
+      graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
+  FusePassBase::Init(op_type + "_reshape2_onednn_fuse_pass", graph);
+
+  GraphPatternDetector gpd;
+  patterns::OperatorReshape2 op_reshape2_pattern(
+      gpd.mutable_pattern(), op_type + "_reshape2_onednn_fuse_pass");
+  op_reshape2_pattern(op_type, num_of_outputs);
+
+  int found_operator_reshape2_count = 0;
+
+  auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
+                     Graph *g) {
+    GET_IR_NODE_FROM_SUBGRAPH(operator_op, preceding_op, op_reshape2_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        operator_out, preceding_op_out, op_reshape2_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(reshape2_op, reshape2_op, op_reshape2_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(reshape2_out, reshape2_out, op_reshape2_pattern);
+
+    if (!operator_op->Op()->HasAttr("use_mkldnn") ||
+        (operator_op->Op()->HasAttr("use_mkldnn") &&
+         !(PADDLE_GET_CONST(bool, operator_op->Op()->GetAttr("use_mkldnn"))))) {
+      VLOG(4) << "Only oneDNN version of " << op_type
+              << " can be fused with reshape2.";
+      return;
+    }
+
+    if (operator_op->Op()->HasAttr("fused_unsqueeze2_axes")) {
+      VLOG(4) << "Cannot do " << op_type << " + reshape2 fuse, because "
+              << op_type << " is already fused with unsqueeze2!";
+      return;
+    }
+
+    std::vector<int> reshape2_shape =
+        PADDLE_GET_CONST(std::vector<int>, reshape2_op->Op()->GetAttr("shape"));
+
+    int num_of_minus_ones = 0;
+
+    for (size_t i = 0; i < reshape2_shape.size(); ++i) {
+      if (reshape2_shape[i] == 0) {
+        VLOG(4) << "OneDNN op+reshape2 fuse pass does not support zero dims, "
+                   "skipping";
+        return;
+      } else if (reshape2_shape[i] == -1) {
+        ++num_of_minus_ones;
+      }
+    }
+
+    if (num_of_minus_ones > 1) {
+      VLOG(4) << "Number of -1 values inside of reshape2 shouldn't be greater "
+                 "than one in op+reshape2 oneDNN fuse pass, skipping";
+      return;
+    }
+
+    auto const &names = reshape2_op->Op()->InputNames();
+
+    bool has_shape_tensor =
+        std::find(names.begin(), names.end(), "ShapeTensor") != names.end();
+    bool has_shape_tensor_list =
+        std::find(names.begin(), names.end(), "ShapeTensorList") != names.end();
+
+    if (has_shape_tensor &&
+        reshape2_op->Op()->Input("ShapeTensor").size() > 0) {
+      VLOG(4) << "Cannot fuse " << op_type
+              << " and reshape2 because reshape2 dims are specified by "
+                 "ShapeTensor!";
+      return;
+    }
+
+    if (has_shape_tensor_list &&
+        reshape2_op->Op()->Input("ShapeTensorList").size() > 0) {
+      VLOG(4) << "Cannot fuse " << op_type
+              << " and reshape2 because reshape2 dims are specified by "
+                 "ShapeTensorList!";
+      return;
+    }
+
+    operator_op->Op()->SetAttr("fused_reshape2_shape", reshape2_shape);
+    operator_op->Op()->SetOutput("Out", {reshape2_out->Name()});
+
+    IR_OP_VAR_LINK(operator_op, reshape2_out);
+    GraphSafeRemoveNodes(g, {reshape2_op, operator_out});
+    found_operator_reshape2_count++;
+  };
+
+  gpd(graph, handler);
+  AddStatis(found_operator_reshape2_count);
+  if ((!Has("disable_logs") || !Get<bool>("disable_logs")) &&
+      found_operator_reshape2_count > 0)
+    PrettyLogDetail("--- fused %d %s with reshape2",
+                    found_operator_reshape2_count,
+                    op_type);
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
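+// A worked example of the shape checks above: shape = {0, -1, 64} is
+// rejected (zero dim), shape = {-1, -1, 64} is rejected (two -1 entries),
+// while shape = {-1, 3, 64} is fusable - the single -1 is resolved later
+// from the output's element count in SetOutMemDescWithReshape2FuseSupport.
+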
+REGISTER_PASS(operator_reshape2_onednn_fuse_pass,
+              paddle::framework::ir::FuseOperatorReshape2OneDNNPass);
+REGISTER_PASS_CAPABILITY(operator_reshape2_onednn_fuse_pass)
+    .AddCombination(
+        paddle::framework::compatible::OpVersionComparatorCombination()
+            .GE("reshape2", 0)
+            .GE("fc", 0));
diff --git a/paddle/fluid/framework/ir/mkldnn/operator_reshape2_onednn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/operator_reshape2_onednn_fuse_pass.h
new file mode 100644
index 0000000000000000000000000000000000000000..a3369b453deefae54d998728df96bd7f79dd97a0
--- /dev/null
+++ b/paddle/fluid/framework/ir/mkldnn/operator_reshape2_onednn_fuse_pass.h
@@ -0,0 +1,37 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/ir/graph.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+class FuseOperatorReshape2OneDNNPass : public FusePassBase {
+ public:
+  virtual ~FuseOperatorReshape2OneDNNPass() {}
+
+ protected:
+  void ApplyImpl(Graph *graph) const override;
+  void FuseReshape2(Graph *graph,
+                    const std::string &op_type,
+                    int num_of_outputs) const;
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/ir/mkldnn/operator_unsqueeze2_onednn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/operator_unsqueeze2_onednn_fuse_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..80f49613c63aca70504338f2024f9a1f1e783d1b
--- /dev/null
+++ b/paddle/fluid/framework/ir/mkldnn/operator_unsqueeze2_onednn_fuse_pass.cc
@@ -0,0 +1,119 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "paddle/fluid/framework/ir/mkldnn/operator_unsqueeze2_onednn_fuse_pass.h" + +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/platform/mkldnn_reuse.h" +#include "paddle/fluid/string/pretty_log.h" + +namespace paddle { +namespace framework { +namespace ir { + +using string::PrettyLogDetail; + +void FuseOperatorUnsqueeze2OneDNNPass::ApplyImpl(Graph *graph) const { + std::vector> ops_and_outputs = { + {"transpose2", 2}, {"elementwise_mul", 1}}; + + for (const auto &op_and_outputs : ops_and_outputs) + FuseUnsqueeze2(graph, op_and_outputs.first, op_and_outputs.second); +} + +void FuseOperatorUnsqueeze2OneDNNPass::FuseUnsqueeze2( + Graph *graph, const std::string &op_type, int num_of_outputs) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); + FusePassBase::Init(op_type + "_unsqueeze2_onednn_fuse_pass", graph); + + GraphPatternDetector gpd; + patterns::OperatorUnsqueeze2 op_unsqueeze2_pattern( + gpd.mutable_pattern(), op_type + "_unsqueeze2_onednn_fuse_pass"); + op_unsqueeze2_pattern(op_type, num_of_outputs); + + int found_operator_unsqueeze2_count = 0; + + auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, + Graph *g) { + GET_IR_NODE_FROM_SUBGRAPH(operator_op, preceding_op, op_unsqueeze2_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + operator_out, preceding_op_out, op_unsqueeze2_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + unsqueeze2_op, unsqueeze2_op, op_unsqueeze2_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + unsqueeze2_out, unsqueeze2_out, op_unsqueeze2_pattern); + + if (!operator_op->Op()->HasAttr("use_mkldnn") || + (operator_op->Op()->HasAttr("use_mkldnn") && + !(PADDLE_GET_CONST(bool, operator_op->Op()->GetAttr("use_mkldnn"))))) { + VLOG(4) << "Only oneDNN version of " << op_type + << "can be fused with unsqueeze2."; + return; + } + + std::vector unsqueeze2_axes = PADDLE_GET_CONST( + std::vector, unsqueeze2_op->Op()->GetAttr("axes")); + + auto const &names = unsqueeze2_op->Op()->InputNames(); + + bool has_axes_tensor = + std::find(names.begin(), names.end(), "AxesTensor") != names.end(); + bool has_axes_tensor_list = + std::find(names.begin(), names.end(), "AxesTensorList") != names.end(); + + if (has_axes_tensor && + unsqueeze2_op->Op()->Input("AxesTensor").size() > 0) { + VLOG(4) << "Cannot fuse " << op_type + << " and unsqueeze2 because unsqueeze2 dims are specified by " + "AxesTensor!"; + return; + } + + if (has_axes_tensor_list && + unsqueeze2_op->Op()->Input("AxesTensorList").size() > 0) { + VLOG(4) << "Cannot fuse " << op_type + << " and unsqueeze2 because unsqueeze2 dims are specified by " + "AxesTensorList!"; + return; + } + + operator_op->Op()->SetAttr("fused_unsqueeze2_axes", unsqueeze2_axes); + operator_op->Op()->SetOutput("Out", {unsqueeze2_out->Name()}); + + IR_OP_VAR_LINK(operator_op, unsqueeze2_out); + GraphSafeRemoveNodes(g, {unsqueeze2_op, operator_out}); + found_operator_unsqueeze2_count++; + }; + + gpd(graph, handler); + AddStatis(found_operator_unsqueeze2_count); + if ((!Has("disable_logs") || !Get("disable_logs")) && + found_operator_unsqueeze2_count > 0) + PrettyLogDetail("--- fused %d %s with unsqueeze2", + found_operator_unsqueeze2_count, + op_type); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(operator_unsqueeze2_onednn_fuse_pass, + paddle::framework::ir::FuseOperatorUnsqueeze2OneDNNPass); +REGISTER_PASS_CAPABILITY(operator_unsqueeze2_onednn_fuse_pass) + .AddCombination( + 
+REGISTER_PASS(operator_unsqueeze2_onednn_fuse_pass,
+              paddle::framework::ir::FuseOperatorUnsqueeze2OneDNNPass);
+REGISTER_PASS_CAPABILITY(operator_unsqueeze2_onednn_fuse_pass)
+    .AddCombination(
+        paddle::framework::compatible::OpVersionComparatorCombination()
+            .GE("unsqueeze2", 0)
+            .GE("transpose2", 0));
diff --git a/paddle/fluid/framework/ir/mkldnn/operator_unsqueeze2_onednn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/operator_unsqueeze2_onednn_fuse_pass.h
new file mode 100644
index 0000000000000000000000000000000000000000..eddd62e61062870a1947024a2d67d1fdd68a8cfc
--- /dev/null
+++ b/paddle/fluid/framework/ir/mkldnn/operator_unsqueeze2_onednn_fuse_pass.h
@@ -0,0 +1,37 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/ir/graph.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+class FuseOperatorUnsqueeze2OneDNNPass : public FusePassBase {
+ public:
+  virtual ~FuseOperatorUnsqueeze2OneDNNPass() {}
+
+ protected:
+  void ApplyImpl(Graph *graph) const override;
+  void FuseUnsqueeze2(Graph *graph,
+                      const std::string &op_type,
+                      int num_of_outputs) const;
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index 7286b603c6b0b477fafa43b6af651a3a13129116..a602e3edc282e5cdc502cdd512573051e97938a8 100755
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -334,6 +334,8 @@ void CpuPassStrategy::EnableMKLDNN() {
           "shuffle_channel_mkldnn_detect_pass",  //
           "elt_act_mkldnn_fuse_pass",            //
           "operator_scale_onednn_fuse_pass",     //
+          "operator_unsqueeze2_onednn_fuse_pass",  //
+          "operator_reshape2_onednn_fuse_pass",    //
           // TODO(intel): Please fix the bug on windows.
           // https://github.com/PaddlePaddle/Paddle/issues/29710
           // "mkldnn_inplace_pass",  // This pass should be activated after
@@ -428,6 +430,8 @@ void CpuPassStrategy::EnableMkldnnInt8() {
     passes_.push_back("reshape_transpose_matmul_mkldnn_fuse_pass");
     passes_.push_back("matmul_elementwise_add_mkldnn_fuse_pass");
     passes_.push_back("operator_scale_onednn_fuse_pass");
+    passes_.push_back("operator_unsqueeze2_onednn_fuse_pass");
+    passes_.push_back("operator_reshape2_onednn_fuse_pass");
     passes_.push_back("cpu_quantize_placement_pass");
     passes_.push_back("cpu_quantize_pass");
     passes_.push_back("cpu_quantize_squash_pass");
diff --git a/paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h b/paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h
index cc70155a217f0f34f568455c4ddf276a054d32ae..e2037d258f16b6bbe85d9c8dd159143120ff848e 100644
--- a/paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h
+++ b/paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h
@@ -196,12 +196,14 @@ class EltwiseMKLDNNKernel : public framework::OpKernel<T> {
     astream.wait();
 
     if (handler.use_broadcasting_hack == false) {
-      z->set_mem_desc(dst_memory->get_desc());
+      platform::SetOutMemDescWithLogicalLayoutFusesSupport(
+          ctx, z, dst_memory->get_desc());
     } else {
       auto dims = dst_memory->get_desc().dims();
       dims.insert(dims.begin(), x->dims()[0]);
       dims[1] /= dims[0];
-      z->set_mem_desc(dst_memory->get_desc().reshape(dims));
+      platform::SetOutMemDescWithLogicalLayoutFusesSupport(
+          ctx, z, dst_memory->get_desc().reshape(dims));
     }
   }
 };
diff --git a/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc
index eab1251828116ec6ae9bc3202ba38e296c728dfa..f47838283db95ebd32ade43dc74f0df5e9b2b38a 100644
--- a/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc
@@ -39,6 +39,13 @@ constexpr bool IsInt8() {
   return std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
 }
 
+struct InnerProductCache {
+  dnnl::inner_product_forward inner_product_p;
+  dnnl::memory src_mem;
+  dnnl::memory weights_mem;
+  dnnl::memory bias_mem;
+  dnnl::memory dst_mem;
+};
 template <typename T_in, typename T_w, typename T_out>
 class FCMKLDNNHandler
     : public platform::MKLDNNHandlerNoCachingT<T_in,
                                                dnnl::inner_product_forward> {
         }));
   }
 
+  void PrepareSrcMem(const std::shared_ptr<dnnl::inner_product_forward>& fc_p,
+                     const std::shared_ptr<dnnl::memory>& src_mem,
+                     const LoDTensor* x,
+                     const dnnl::engine& engine) const {
+    auto x_md = x->mem_desc().reshape(src_mem->get_desc().dims());
+    if (x_md != src_mem->get_desc()) {
+      dnnl::memory x_mem(x_md, engine, to_void_cast<T_in>(x->data<T_in>()));
+      auto reorder_p = dnnl::reorder(x_mem, *src_mem);
+
+      auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
+      reorder_p.execute(astream, x_mem, *src_mem);
+      astream.wait();
+    } else {
+      src_mem->set_data_handle(to_void_cast<T_in>(x->data<T_in>()));
+    }
+  }
+
   template <typename T_out = T_w>
   void RunKernel(const framework::ExecutionContext& ctx) const {
     const auto& dev_ctx =
@@ -368,29 +392,80 @@ class FCMKLDNNKernel : public framework::OpKernel<T_in> {
     const auto* bias = ctx.Input<Tensor>("Bias");
     auto out = ctx.Output<LoDTensor>("Out");
 
-    auto in_col_dims = ctx.Attr<int>("in_num_col_dims");
-
     const float scale_in = ctx.Attr<float>("Scale_in");
     const auto& scale_weights = ctx.Attr<std::vector<float>>("Scale_weights");
 
+    std::shared_ptr<dnnl::inner_product_forward> fc_p;
+    std::shared_ptr<dnnl::memory> src_memory_p;
+    std::shared_ptr<dnnl::memory> weights_memory_p;
+    std::shared_ptr<dnnl::memory> bias_memory_p;
+    std::shared_ptr<dnnl::memory> dst_memory_p;
+
+    std::string cache_key;
+    cache_key.reserve(64);
+    cache_key = platform::ExtendKeyWithThreadInfoIfNeeded(
+        dev_ctx,
+        platform::CreateKey(dev_ctx,
+                            ctx.InputName("Input"),
+                            ctx.InputName("W"),
+                            phi::vectorize(x->dims())));
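+    // cache_key built above combines thread info (when needed), the input
+    // and weight variable names, and the input dims, so a cached primitive
+    // is only reused by later executions with matching shapes.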
+
+    auto inner_product_cache =
+        std::static_pointer_cast<InnerProductCache>(dev_ctx.GetBlob(cache_key));
+
     RecomputeOutputDims(ctx, x, weights, out);
 
-    FCMKLDNNHandler<T_in, T_w, T_out> handler(ctx,
-                                              dev_ctx,
-                                              x,
-                                              weights,
-                                              bias,
-                                              out,
-                                              in_col_dims,
-                                              mkldnn_engine,
-                                              ctx.GetPlace());
-
-    auto src_memory_p = handler.AcquireSrcMemoryWithReorder(x);
-    auto weights_memory_p =
-        handler.AcquireWeightsMemoryWithReorder(weights, scale_weights);
-    auto dst_memory_p = handler.AcquireCustomDstMemory(ctx, out);
-
-    auto fc_p = handler.AcquireForwardPrimitive();
+    if (inner_product_cache) {
+      fc_p = std::make_shared<dnnl::inner_product_forward>(
+          inner_product_cache->inner_product_p);
+      src_memory_p =
+          std::make_shared<dnnl::memory>(inner_product_cache->src_mem);
+      PrepareSrcMem(fc_p, src_memory_p, x, mkldnn_engine);
+
+      weights_memory_p =
+          std::make_shared<dnnl::memory>(inner_product_cache->weights_mem);
+
+      dst_memory_p =
+          std::make_shared<dnnl::memory>(inner_product_cache->dst_mem);
+      if (ctx.HasAttr("fuse_residual_connection") &&
+          ctx.Attr<bool>("fuse_residual_connection")) {
+        auto* residual_param = ctx.Output<Tensor>("ResidualData");
+        out->ShareDataWith(*residual_param);
+      }
+      auto out_ptr = out->mutable_data<T_out>(
+          ctx.GetPlace(), dst_memory_p->get_desc().get_size());
+      dst_memory_p->set_data_handle(out_ptr);
+
+      if (bias) {
+        bias_memory_p =
+            std::make_shared<dnnl::memory>(inner_product_cache->bias_mem);
+      }
+    } else {
+      auto in_col_dims = ctx.Attr<int>("in_num_col_dims");
+
+      FCMKLDNNHandler<T_in, T_w, T_out> handler(ctx,
+                                                dev_ctx,
+                                                x,
+                                                weights,
+                                                bias,
+                                                out,
+                                                in_col_dims,
+                                                mkldnn_engine,
+                                                ctx.GetPlace());
+
+      src_memory_p = handler.AcquireSrcMemoryWithReorder(x);
+      weights_memory_p =
+          handler.AcquireWeightsMemoryWithReorder(weights, scale_weights);
+      dst_memory_p = handler.AcquireCustomDstMemory(ctx, out);
+
+      if (bias) {
+        bias_memory_p =
+            handler.AcquireBiasMemoryWithReorder(bias, scale_in, scale_weights);
+      }
+
+      fc_p = handler.AcquireForwardPrimitive();
+    }
+
     auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
 
     std::unordered_map<int, dnnl::memory> fc_args = {
@@ -399,15 +474,27 @@ class FCMKLDNNKernel : public framework::OpKernel<T_in> {
       {DNNL_ARG_DST, *dst_memory_p}};
 
     if (bias) {
-      auto bias_memory_p =
-          handler.AcquireBiasMemoryWithReorder(bias, scale_in, scale_weights);
       fc_args.insert({DNNL_ARG_BIAS, *bias_memory_p});
     }
 
     fc_p->execute(astream, fc_args);
     astream.wait();
 
-    out->set_mem_desc(
+    if (!inner_product_cache) {
+      auto ip_cache = std::make_shared<InnerProductCache>();
+      ip_cache->inner_product_p = *fc_p;
+      ip_cache->src_mem = *src_memory_p;
+      ip_cache->weights_mem = *weights_memory_p;
+      ip_cache->dst_mem = *dst_memory_p;
+      if (bias) {
+        ip_cache->bias_mem = *bias_memory_p;
+      }
+      dev_ctx.SetBlob(cache_key, ip_cache);
+    }
+
+    platform::SetOutMemDescWithLogicalLayoutFusesSupport(
+        ctx,
+        out,
         dst_memory_p->get_desc().reshape(phi::vectorize(out->dims())));
   }
diff --git a/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc
index d964fe8ec0c2f40fbc714de67594a59754c00003..5f02d087011f0bcd1b8455294873fcf9797f2993 100644
--- a/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc
@@ -81,8 +81,11 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
     astream.wait();
 
-    out->set_mem_desc(reorder_dst_memory_p->get_desc().permute_axes(
-        TransposeToPermuteAxis(transpose_axis)));
+    platform::SetOutMemDescWithLogicalLayoutFusesSupport(
+        ctx,
+        out,
+        reorder_dst_memory_p->get_desc().permute_axes(
+            TransposeToPermuteAxis(transpose_axis)));
   }
 
  private:
diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h
index 24884709833cb441e9696f19ce27ade3d158140a..baa3c2aea588e7dcb84e586675b3b3885fd3a08f 100644
--- a/paddle/fluid/platform/mkldnn_reuse.h
+++ b/paddle/fluid/platform/mkldnn_reuse.h
@@ -109,6 +109,69 @@ static void AppendActivation(const framework::ExecutionContext& ctx,
   }
 }
 
+static void SetOutMemDescWithUnsqueeze2FuseSupport(
+    const framework::ExecutionContext& ctx,
+    phi::DenseTensor* out,
+    const dnnl::memory::desc& out_md) {
+  const std::vector<int>& fused_unsqueeze2_axes =
+      ctx.Attr<std::vector<int>>("fused_unsqueeze2_axes");
+  const std::vector<int64_t>& op_tz = out_md.dims();
+  std::vector<int64_t> unsqueezed_op_tz(
+      op_tz.size() + fused_unsqueeze2_axes.size(), 0);
+
+  for (const auto& axis : fused_unsqueeze2_axes) {
+    int positive_axis = axis < 0 ? unsqueezed_op_tz.size() + axis : axis;
+    unsqueezed_op_tz[positive_axis] = 1;
+  }
+
+  int j = 0;
+  for (size_t i = 0; i < unsqueezed_op_tz.size(); ++i) {
+    if (unsqueezed_op_tz[i] == 0) {
+      unsqueezed_op_tz[i] = op_tz[j++];
+    }
+  }
+  out->set_mem_desc(out_md.reshape(unsqueezed_op_tz));
+  out->Resize(phi::make_ddim(unsqueezed_op_tz));
+}
+
+static void SetOutMemDescWithReshape2FuseSupport(
+    const framework::ExecutionContext& ctx,
+    phi::DenseTensor* out,
+    const dnnl::memory::desc& out_md) {
+  std::vector<int64_t> fused_reshape2_shape(
+      ctx.Attr<std::vector<int>>("fused_reshape2_shape").begin(),
+      ctx.Attr<std::vector<int>>("fused_reshape2_shape").end());
+
+  const int out_shape_numel = out->numel();
+  const int new_shape_numel = std::accumulate(fused_reshape2_shape.begin(),
+                                              fused_reshape2_shape.end(),
+                                              1,
+                                              std::multiplies<int64_t>());
+
+  for (size_t i = 0; i < fused_reshape2_shape.size(); ++i) {
+    if (fused_reshape2_shape[i] == -1) {
+      fused_reshape2_shape[i] = -out_shape_numel / new_shape_numel;
+      break;
+    }
+  }
+
+  out->set_mem_desc(out_md.reshape(fused_reshape2_shape));
+  out->Resize(phi::make_ddim(fused_reshape2_shape));
+}
+
+static void SetOutMemDescWithLogicalLayoutFusesSupport(
+    const framework::ExecutionContext& ctx,
+    phi::DenseTensor* out,
+    const dnnl::memory::desc& out_md) {
+  if (ctx.HasAttr("fused_unsqueeze2_axes")) {
+    SetOutMemDescWithUnsqueeze2FuseSupport(ctx, out, out_md);
+  } else if (ctx.HasAttr("fused_reshape2_shape")) {
+    SetOutMemDescWithReshape2FuseSupport(ctx, out, out_md);
+  } else {
+    out->set_mem_desc(out_md);
+  }
+}
+
 template <typename T>
 constexpr bool IsInt8() {
   return std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
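
Note: a minimal standalone sketch (not part of the patch) of the axis-expansion
logic in SetOutMemDescWithUnsqueeze2FuseSupport above; UnsqueezeDims is a
hypothetical helper name used only for illustration. The reshape2 counterpart
analogously resolves a single -1 entry from the output's element count.

#include <cassert>
#include <cstdint>
#include <vector>

// Mirror of the fused-unsqueeze2 dim expansion: mark every fused axis with a
// placeholder 1, then fill the remaining slots with the original dims in order.
std::vector<int64_t> UnsqueezeDims(const std::vector<int64_t>& dims,
                                   const std::vector<int>& axes) {
  std::vector<int64_t> unsqueezed(dims.size() + axes.size(), 0);
  for (int axis : axes) {
    int positive_axis =
        axis < 0 ? static_cast<int>(unsqueezed.size()) + axis : axis;
    unsqueezed[positive_axis] = 1;  // the inserted singleton dimension
  }
  size_t j = 0;
  for (size_t i = 0; i < unsqueezed.size(); ++i) {
    if (unsqueezed[i] == 0) unsqueezed[i] = dims[j++];
  }
  return unsqueezed;
}

int main() {
  // axes {0, 2} applied to a [6, 10] output yield [1, 6, 1, 10]
  assert((UnsqueezeDims({6, 10}, {0, 2}) ==
          std::vector<int64_t>({1, 6, 1, 10})));
  return 0;
}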