diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 2b6d9f98abba02a9131ff4bb519838da355d6e5a..a2f1e28ec753eb9e256cd5f53be48e8693f74d60 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -219,6 +219,8 @@ if(WITH_MKLDNN) pass_library(matmul_elementwise_add_mkldnn_fuse_pass inference DIR mkldnn) pass_library(matmul_activation_mkldnn_fuse_pass inference DIR mkldnn) pass_library(operator_scale_onednn_fuse_pass inference DIR mkldnn) + pass_library(operator_unsqueeze2_onednn_fuse_pass inference DIR mkldnn) + pass_library(operator_reshape2_onednn_fuse_pass inference DIR mkldnn) pass_library(cpu_quantize_placement_pass base DIR mkldnn) pass_library(cpu_quantize_pass inference DIR mkldnn) pass_library(cpu_quantize_squash_pass inference DIR mkldnn) diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 47d5db7ede300f2b10863309c6f6d0a55e982947..d959396096603cff31df60f53ff5d2b69351cbef 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -958,6 +958,44 @@ PDNode *patterns::OperatorActivation::operator()( return activation_out; } +PDNode *patterns::OperatorUnsqueeze2::operator()( + const std::string &operator_type, const int num_of_operator_outs) { + auto *preceding_op = pattern->NewNode(preceding_op_repr()) + ->assert_is_op(operator_type) + ->assert_has_n_outputs(num_of_operator_outs); + auto *preceding_op_out = pattern->NewNode(preceding_op_out_repr()) + ->AsIntermediate() + ->assert_is_op_output(operator_type, "Out") + ->assert_is_op_input("unsqueeze2"); + auto *unsqueeze2_op = + pattern->NewNode(unsqueeze2_op_repr())->assert_is_op("unsqueeze2"); + auto *unsqueeze2_out = pattern->NewNode(unsqueeze2_out_repr()) + ->AsOutput() + ->assert_is_op_output("unsqueeze2"); + preceding_op->LinksTo({preceding_op_out}); + unsqueeze2_op->LinksFrom({preceding_op_out}).LinksTo({unsqueeze2_out}); + return unsqueeze2_out; +} + +PDNode *patterns::OperatorReshape2::operator()(const std::string &operator_type, + const int num_of_operator_outs) { + auto *preceding_op = pattern->NewNode(preceding_op_repr()) + ->assert_is_op(operator_type) + ->assert_has_n_outputs(num_of_operator_outs); + auto *preceding_op_out = pattern->NewNode(preceding_op_out_repr()) + ->AsIntermediate() + ->assert_is_op_output(operator_type, "Out") + ->assert_is_op_input("reshape2"); + auto *reshape2_op = + pattern->NewNode(reshape2_op_repr())->assert_is_op("reshape2"); + auto *reshape2_out = pattern->NewNode(reshape2_out_repr()) + ->AsOutput() + ->assert_is_op_output("reshape2"); + preceding_op->LinksTo({preceding_op_out}); + reshape2_op->LinksFrom({preceding_op_out}).LinksTo({reshape2_out}); + return reshape2_out; +} + PDNode *patterns::SeqConvEltAddRelu::operator()( paddle::framework::ir::PDNode *seqconv_input) { // Create Operators diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 496a71ed4e57cbb85bad1972ba4516d0284a6d9e..99e0e3732cdf9769e6556c4a585f926759d23fdc 100755 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -539,6 +539,32 @@ struct OperatorActivation : public PatternBase { PATTERN_DECL_NODE(activation_out); }; +struct OperatorUnsqueeze2 : public PatternBase { + OperatorUnsqueeze2(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, 
name_scope, "operator_unsqueeze2") {} + + PDNode* operator()(const std::string& operator_type, + const int num_of_outputs); + + PATTERN_DECL_NODE(preceding_op); + PATTERN_DECL_NODE(preceding_op_out); + PATTERN_DECL_NODE(unsqueeze2_op); + PATTERN_DECL_NODE(unsqueeze2_out); +}; + +struct OperatorReshape2 : public PatternBase { + OperatorReshape2(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "operator_reshape2") {} + + PDNode* operator()(const std::string& operator_type, + const int num_of_outputs); + + PATTERN_DECL_NODE(preceding_op); + PATTERN_DECL_NODE(preceding_op_out); + PATTERN_DECL_NODE(reshape2_op); + PATTERN_DECL_NODE(reshape2_out); +}; + // SEQCONV with Elementwise_Add ReLU // op: seqconv + elementwise_add + relu // named nodes: diff --git a/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.cc index cdb0f70a56667d3f50801db138a1d40563981f7f..f4ac65a9ab1993d990f52191515d0cd5d4b6cd44 100644 --- a/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -14,9 +14,8 @@ #include "paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h" -#include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/op_version_registry.h" -#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/mkldnn_reuse.h" #include "paddle/fluid/string/pretty_log.h" namespace paddle { @@ -26,20 +25,20 @@ namespace ir { using string::PrettyLogDetail; void FuseFCActOneDNNPass::ApplyImpl(Graph *graph) const { - std::vector act_types = { - "gelu", "tanh", "sigmoid", "mish", "hard_swish"}; + auto act_types = paddle::platform::GetSupportedActivations(); - for (std::string act_type : act_types) FuseFCAct(graph, act_type); + for (auto act_type : act_types) FuseFCAct(graph, act_type); } void FuseFCActOneDNNPass::FuseFCAct(Graph *graph, const std::string &act_type) const { PADDLE_ENFORCE_NOT_NULL( graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); - FusePassBase::Init("fc_act", graph); + FusePassBase::Init("fc_" + act_type + "_mkldnn_fuse_pass", graph); GraphPatternDetector gpd; - patterns::OperatorActivation fc_act_pattern(gpd.mutable_pattern(), "fc_act"); + patterns::OperatorActivation fc_act_pattern( + gpd.mutable_pattern(), "fc_" + act_type + "_mkldnn_fuse_pass"); fc_act_pattern("fc", act_type); int found_fc_act_count = 0; @@ -62,15 +61,23 @@ void FuseFCActOneDNNPass::FuseFCAct(Graph *graph, "is used.")); } + auto attr_map = paddle::platform::GetAttributeMap(act_type); + for (const auto &attr : attr_map) { + if (act_op->HasAttr(attr.first)) { + fc_op->SetAttr(attr.second, act_op->GetAttr(attr.first)); + } + } + if (act_type == "gelu" && act_op->HasAttr("approximate")) { - bool approximate = PADDLE_GET_CONST(bool, act_op->GetAttr("approximate")); - std::string type = approximate ? "_tanh" : "_erf"; - fc_op->SetAttr("activation_type", act_type + type); + std::string gelu_act_type = + PADDLE_GET_CONST(bool, act_op->GetAttr("approximate")) ? 
"gelu_tanh" + : "gelu_erf"; + fc_op->SetAttr("fuse_activation", gelu_act_type); } else { - fc_op->SetAttr("activation_type", act_type); + fc_op->SetAttr("fuse_activation", act_type); } - fc_op->SetAttr("use_mkldnn", true); + fc_op->SetAttr("use_mkldnn", true); fc_op->SetOutput("Out", {act_out->Name()}); IR_OP_VAR_LINK(fc, act_out); @@ -80,7 +87,8 @@ void FuseFCActOneDNNPass::FuseFCAct(Graph *graph, gpd(graph, handler); AddStatis(found_fc_act_count); - if (!Has("disable_logs") || !Get("disable_logs")) + if ((!Has("disable_logs") || !Get("disable_logs")) && + found_fc_act_count > 0) PrettyLogDetail( "--- fused %d fc with %s activation", found_fc_act_count, act_type); } @@ -95,8 +103,16 @@ REGISTER_PASS_CAPABILITY(fc_act_mkldnn_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() .LE("fc", 0) - .LE("gelu", 0) - .LE("sigmoid", 0) - .LE("mish", 1) + .EQ("abs", 0) + .LE("clip", 1) + .EQ("gelu", 0) + .EQ("hard_sigmoid", 0) .LE("hard_swish", 0) - .LE("tanh", 0)); + .LE("leaky_relu", 1) + .LE("mish", 1) + .EQ("relu", 0) + .EQ("relu6", 0) + .EQ("sigmoid", 0) + .EQ("sqrt", 0) + .EQ("swish", 0) + .EQ("tanh", 0)); diff --git a/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h index 23f4296b98bcabab17c896a7ea0c80f72e358e06..7e4032d4a135292baa600146718627a486e5149e 100644 --- a/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h @@ -23,21 +23,14 @@ namespace paddle { namespace framework { namespace ir { -/* - * \brief Fuse the FC and activation operators into single OneDNN's - * FC with post-op. - * - * \note Currently only GeLU, hardswish, sigmoid, mish and tanh are supported - * as an activation function. 
- */ class FuseFCActOneDNNPass : public FusePassBase { public: virtual ~FuseFCActOneDNNPass() {} protected: - void ApplyImpl(ir::Graph *graph) const override; + void ApplyImpl(Graph *graph) const override; - void FuseFCAct(ir::Graph *graph, const std::string &act_types) const; + void FuseFCAct(Graph *graph, const std::string &act_types) const; }; } // namespace ir diff --git a/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass_tester.cc index 38f253703ceeec75f45ac89639e6483f65f472b5..2951e2522d0f5c47a291ec09a8b2b12e93acf0ac 100644 --- a/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass_tester.cc @@ -78,9 +78,9 @@ TEST(FuseFCActOneDNNPass, FuseWithGeluTanh) { const auto* op = node->Op(); ASSERT_TRUE(op->HasAttr("use_mkldnn")); EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn"))); - ASSERT_TRUE(op->HasAttr("activation_type")); + ASSERT_TRUE(op->HasAttr("fuse_activation")); auto act_type = - PADDLE_GET_CONST(std::string, op->GetAttr("activation_type")); + PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation")); EXPECT_EQ(act_type.compare("gelu_tanh"), 0); } } @@ -113,9 +113,9 @@ TEST(FuseFCActOneDNNPass, FuseWithGeluErf) { const auto* op = node->Op(); ASSERT_TRUE(op->HasAttr("use_mkldnn")); EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn"))); - ASSERT_TRUE(op->HasAttr("activation_type")); + ASSERT_TRUE(op->HasAttr("fuse_activation")); auto act_type = - PADDLE_GET_CONST(std::string, op->GetAttr("activation_type")); + PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation")); EXPECT_EQ(act_type.compare("gelu_erf"), 0); } } @@ -146,9 +146,9 @@ TEST(FuseFCActOneDNNPass, FuseWithGeluAuto) { const auto* op = node->Op(); ASSERT_TRUE(op->HasAttr("use_mkldnn")); EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn"))); - ASSERT_TRUE(op->HasAttr("activation_type")); + ASSERT_TRUE(op->HasAttr("fuse_activation")); auto act_type = - PADDLE_GET_CONST(std::string, op->GetAttr("activation_type")); + PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation")); EXPECT_EQ(act_type.compare("gelu"), 0); } } @@ -179,9 +179,9 @@ TEST(FuseFCActOneDNNPass, FuseWithTanh) { const auto* op = node->Op(); ASSERT_TRUE(op->HasAttr("use_mkldnn")); EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn"))); - ASSERT_TRUE(op->HasAttr("activation_type")); + ASSERT_TRUE(op->HasAttr("fuse_activation")); auto act_type = - PADDLE_GET_CONST(std::string, op->GetAttr("activation_type")); + PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation")); EXPECT_EQ(act_type.compare("tanh"), 0); } } @@ -213,9 +213,9 @@ TEST(FuseFCActOneDNNPass, FuseWithSigmoid) { const auto* op = node->Op(); ASSERT_TRUE(op->HasAttr("use_mkldnn")); EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn"))); - ASSERT_TRUE(op->HasAttr("activation_type")); + ASSERT_TRUE(op->HasAttr("fuse_activation")); auto act_type = - PADDLE_GET_CONST(std::string, op->GetAttr("activation_type")); + PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation")); EXPECT_EQ(act_type.compare("sigmoid"), 0); } } @@ -246,9 +246,9 @@ TEST(FuseFCActOneDNNPass, FuseWithMish) { const auto* op = node->Op(); ASSERT_TRUE(op->HasAttr("use_mkldnn")); EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn"))); - ASSERT_TRUE(op->HasAttr("activation_type")); + ASSERT_TRUE(op->HasAttr("fuse_activation")); auto act_type = - PADDLE_GET_CONST(std::string, op->GetAttr("activation_type")); + 
PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation")); EXPECT_EQ(act_type.compare("mish"), 0); } } @@ -280,9 +280,9 @@ TEST(FuseFCActOneDNNPass, FuseWithHardSwish) { const auto* op = node->Op(); ASSERT_TRUE(op->HasAttr("use_mkldnn")); EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn"))); - ASSERT_TRUE(op->HasAttr("activation_type")); + ASSERT_TRUE(op->HasAttr("fuse_activation")); auto act_type = - PADDLE_GET_CONST(std::string, op->GetAttr("activation_type")); + PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation")); EXPECT_EQ(act_type.compare("hard_swish"), 0); } } diff --git a/paddle/fluid/framework/ir/mkldnn/operator_reshape2_onednn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/operator_reshape2_onednn_fuse_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..0f8d0452aa17ba84073e845a3f50bdc54b69f6fa --- /dev/null +++ b/paddle/fluid/framework/ir/mkldnn/operator_reshape2_onednn_fuse_pass.cc @@ -0,0 +1,144 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/mkldnn/operator_reshape2_onednn_fuse_pass.h" + +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/platform/mkldnn_reuse.h" +#include "paddle/fluid/string/pretty_log.h" + +namespace paddle { +namespace framework { +namespace ir { + +using string::PrettyLogDetail; + +void FuseOperatorReshape2OneDNNPass::ApplyImpl(Graph *graph) const { + // THIS FUSE WILL WORK ONLY WITH OPERATORS THAT OUTPUTS PLAIN MEMORY, F.E. + // ABCD FOR 4D! BE AWARE OF THAT! 
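The warning above is the central constraint of this pass: the reshape2 is folded away and the fused operator is expected to expose the reshape2 output shape directly on its oneDNN destination memory descriptor, which is only valid for dense, plain layouts. A minimal standalone sketch of that idea using oneDNN's C++ API (illustrative only, not code from this patch):

#include "dnnl.hpp"

// A plain, dense 2D descriptor can be given a new logical shape without moving
// any data, which is what the fused operator relies on for plain (e.g. abcd) memory.
static dnnl::memory::desc LogicalReshapeSketch() {
  dnnl::memory::desc plain_2d(
      {8, 16}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::ab);
  return plain_2d.reshape({8, 4, 4});  // same 128 elements, new logical dims
}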
+ std::vector> ops_and_outputs = { + {"fc", 1}, {"transpose2", 2}}; + + for (const auto &op_and_outputs : ops_and_outputs) + FuseReshape2(graph, op_and_outputs.first, op_and_outputs.second); +} + +void FuseOperatorReshape2OneDNNPass::FuseReshape2(Graph *graph, + const std::string &op_type, + int num_of_outputs) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); + FusePassBase::Init(op_type + "_reshape2_onednn_fuse_pass", graph); + + GraphPatternDetector gpd; + patterns::OperatorReshape2 op_reshape2_pattern( + gpd.mutable_pattern(), op_type + "_reshape2_onednn_fuse_pass"); + op_reshape2_pattern(op_type, num_of_outputs); + + int found_operator_reshape2_count = 0; + + auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, + Graph *g) { + GET_IR_NODE_FROM_SUBGRAPH(operator_op, preceding_op, op_reshape2_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + operator_out, preceding_op_out, op_reshape2_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape2_op, reshape2_op, op_reshape2_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape2_out, reshape2_out, op_reshape2_pattern); + + if (!operator_op->Op()->HasAttr("use_mkldnn") || + (operator_op->Op()->HasAttr("use_mkldnn") && + !(PADDLE_GET_CONST(bool, operator_op->Op()->GetAttr("use_mkldnn"))))) { + VLOG(4) << "Only oneDNN version of " << op_type + << "can be fused with reshape2."; + return; + } + + if (operator_op->Op()->HasAttr("fused_unsqueeze2_axes")) { + VLOG(4) << "Cannot do " << op_type << " + reshape2 fuse, because " + << op_type << " is already fused with unsqueeze2!"; + return; + } + + std::vector reshape2_shape = + PADDLE_GET_CONST(std::vector, reshape2_op->Op()->GetAttr("shape")); + + int num_of_minus_ones = 0; + + for (size_t i = 0; i < reshape2_shape.size(); ++i) { + if (reshape2_shape[i] == 0) { + VLOG(4) << "OneDNN op+reshape2 fuse pass does not support zero dims, " + "skipping"; + return; + } else if (reshape2_shape[i] == -1) { + ++num_of_minus_ones; + } + } + + if (num_of_minus_ones > 1) { + VLOG(4) << "Number of -1 values inside of reshape2 shouldn't be greater " + "than one in op+reshape2 oneDNN fuse pass, skipping"; + return; + } + + auto const &names = reshape2_op->Op()->InputNames(); + + bool has_shape_tensor = + std::find(names.begin(), names.end(), "ShapeTensor") != names.end(); + bool has_shape_tensor_list = + std::find(names.begin(), names.end(), "ShapeTensorList") != names.end(); + + if (has_shape_tensor && + reshape2_op->Op()->Input("ShapeTensor").size() > 0) { + VLOG(4) << "Cannot fuse " << op_type + << " and reshape2 because reshape2 dims are specified by " + "ShapeTensor!"; + return; + } + + if (has_shape_tensor_list && + reshape2_op->Op()->Input("ShapeTensorList").size() > 0) { + VLOG(4) << "Cannot fuse " << op_type + << " and reshape2 because reshape2 dims are specified by " + "ShapeTensorList!"; + return; + } + + operator_op->Op()->SetAttr("fused_reshape2_shape", reshape2_shape); + operator_op->Op()->SetOutput("Out", {reshape2_out->Name()}); + + IR_OP_VAR_LINK(operator_op, reshape2_out); + GraphSafeRemoveNodes(g, {reshape2_op, operator_out}); + found_operator_reshape2_count++; + }; + + gpd(graph, handler); + AddStatis(found_operator_reshape2_count); + if ((!Has("disable_logs") || !Get("disable_logs")) && + found_operator_reshape2_count > 0) + PrettyLogDetail("--- fused %d %s with reshape2", + found_operator_reshape2_count, + op_type); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(operator_reshape2_onednn_fuse_pass, + 
paddle::framework::ir::FuseOperatorReshape2OneDNNPass); +REGISTER_PASS_CAPABILITY(operator_reshape2_onednn_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .GE("reshape2", 0) + .GE("fc", 0)); diff --git a/paddle/fluid/framework/ir/mkldnn/operator_reshape2_onednn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/operator_reshape2_onednn_fuse_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..a3369b453deefae54d998728df96bd7f79dd97a0 --- /dev/null +++ b/paddle/fluid/framework/ir/mkldnn/operator_reshape2_onednn_fuse_pass.h @@ -0,0 +1,37 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph.h" + +namespace paddle { +namespace framework { +namespace ir { + +class FuseOperatorReshape2OneDNNPass : public FusePassBase { + public: + virtual ~FuseOperatorReshape2OneDNNPass() {} + + protected: + void ApplyImpl(Graph *graph) const override; + void FuseReshape2(Graph *graph, + const std::string &op_type, + int num_of_outputs) const; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/mkldnn/operator_unsqueeze2_onednn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/operator_unsqueeze2_onednn_fuse_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..80f49613c63aca70504338f2024f9a1f1e783d1b --- /dev/null +++ b/paddle/fluid/framework/ir/mkldnn/operator_unsqueeze2_onednn_fuse_pass.cc @@ -0,0 +1,119 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
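For reference, a hedged sketch of how the pass registered above can be applied to a graph, mirroring the pattern used by the existing *_fuse_pass testers; the free function name is made up for illustration and is not part of this patch:

#include <memory>

#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"

// Hypothetical helper: fetch the registered pass by name and run it on a graph.
void ApplyOperatorReshape2FusePass(
    std::unique_ptr<paddle::framework::ir::Graph>* graph) {
  auto pass = paddle::framework::ir::PassRegistry::Instance().Get(
      "operator_reshape2_onednn_fuse_pass");
  graph->reset(pass->Apply(graph->release()));
}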
+ +#include "paddle/fluid/framework/ir/mkldnn/operator_unsqueeze2_onednn_fuse_pass.h" + +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/platform/mkldnn_reuse.h" +#include "paddle/fluid/string/pretty_log.h" + +namespace paddle { +namespace framework { +namespace ir { + +using string::PrettyLogDetail; + +void FuseOperatorUnsqueeze2OneDNNPass::ApplyImpl(Graph *graph) const { + std::vector> ops_and_outputs = { + {"transpose2", 2}, {"elementwise_mul", 1}}; + + for (const auto &op_and_outputs : ops_and_outputs) + FuseUnsqueeze2(graph, op_and_outputs.first, op_and_outputs.second); +} + +void FuseOperatorUnsqueeze2OneDNNPass::FuseUnsqueeze2( + Graph *graph, const std::string &op_type, int num_of_outputs) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); + FusePassBase::Init(op_type + "_unsqueeze2_onednn_fuse_pass", graph); + + GraphPatternDetector gpd; + patterns::OperatorUnsqueeze2 op_unsqueeze2_pattern( + gpd.mutable_pattern(), op_type + "_unsqueeze2_onednn_fuse_pass"); + op_unsqueeze2_pattern(op_type, num_of_outputs); + + int found_operator_unsqueeze2_count = 0; + + auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, + Graph *g) { + GET_IR_NODE_FROM_SUBGRAPH(operator_op, preceding_op, op_unsqueeze2_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + operator_out, preceding_op_out, op_unsqueeze2_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + unsqueeze2_op, unsqueeze2_op, op_unsqueeze2_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + unsqueeze2_out, unsqueeze2_out, op_unsqueeze2_pattern); + + if (!operator_op->Op()->HasAttr("use_mkldnn") || + (operator_op->Op()->HasAttr("use_mkldnn") && + !(PADDLE_GET_CONST(bool, operator_op->Op()->GetAttr("use_mkldnn"))))) { + VLOG(4) << "Only oneDNN version of " << op_type + << "can be fused with unsqueeze2."; + return; + } + + std::vector unsqueeze2_axes = PADDLE_GET_CONST( + std::vector, unsqueeze2_op->Op()->GetAttr("axes")); + + auto const &names = unsqueeze2_op->Op()->InputNames(); + + bool has_axes_tensor = + std::find(names.begin(), names.end(), "AxesTensor") != names.end(); + bool has_axes_tensor_list = + std::find(names.begin(), names.end(), "AxesTensorList") != names.end(); + + if (has_axes_tensor && + unsqueeze2_op->Op()->Input("AxesTensor").size() > 0) { + VLOG(4) << "Cannot fuse " << op_type + << " and unsqueeze2 because unsqueeze2 dims are specified by " + "AxesTensor!"; + return; + } + + if (has_axes_tensor_list && + unsqueeze2_op->Op()->Input("AxesTensorList").size() > 0) { + VLOG(4) << "Cannot fuse " << op_type + << " and unsqueeze2 because unsqueeze2 dims are specified by " + "AxesTensorList!"; + return; + } + + operator_op->Op()->SetAttr("fused_unsqueeze2_axes", unsqueeze2_axes); + operator_op->Op()->SetOutput("Out", {unsqueeze2_out->Name()}); + + IR_OP_VAR_LINK(operator_op, unsqueeze2_out); + GraphSafeRemoveNodes(g, {unsqueeze2_op, operator_out}); + found_operator_unsqueeze2_count++; + }; + + gpd(graph, handler); + AddStatis(found_operator_unsqueeze2_count); + if ((!Has("disable_logs") || !Get("disable_logs")) && + found_operator_unsqueeze2_count > 0) + PrettyLogDetail("--- fused %d %s with unsqueeze2", + found_operator_unsqueeze2_count, + op_type); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(operator_unsqueeze2_onednn_fuse_pass, + paddle::framework::ir::FuseOperatorUnsqueeze2OneDNNPass); +REGISTER_PASS_CAPABILITY(operator_unsqueeze2_onednn_fuse_pass) + .AddCombination( + 
paddle::framework::compatible::OpVersionComparatorCombination() + .GE("unsqueeze2", 0) + .GE("transpose2", 0)); diff --git a/paddle/fluid/framework/ir/mkldnn/operator_unsqueeze2_onednn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/operator_unsqueeze2_onednn_fuse_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..eddd62e61062870a1947024a2d67d1fdd68a8cfc --- /dev/null +++ b/paddle/fluid/framework/ir/mkldnn/operator_unsqueeze2_onednn_fuse_pass.h @@ -0,0 +1,37 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph.h" + +namespace paddle { +namespace framework { +namespace ir { + +class FuseOperatorUnsqueeze2OneDNNPass : public FusePassBase { + public: + virtual ~FuseOperatorUnsqueeze2OneDNNPass() {} + + protected: + void ApplyImpl(Graph *graph) const override; + void FuseUnsqueeze2(Graph *graph, + const std::string &op_type, + int num_of_outputs) const; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 2a6f47be48723d41f1ef3aea9de3ba7c2e67a7dd..770cf4577b1cd8e16a43daaaaa96fa2c2a9ec4e2 100755 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -327,6 +327,8 @@ void CpuPassStrategy::EnableMKLDNN() { "shuffle_channel_mkldnn_detect_pass", // "elt_act_mkldnn_fuse_pass", // "operator_scale_onednn_fuse_pass", // + "operator_unsqueeze2_onednn_fuse_pass", // + "operator_reshape2_onednn_fuse_pass", // // TODO(intel): Please fix the bug on windows. 
// https://github.com/PaddlePaddle/Paddle/issues/29710 // "mkldnn_inplace_pass", // This pass should be activated after @@ -421,6 +423,8 @@ void CpuPassStrategy::EnableMkldnnInt8() { passes_.push_back("reshape_transpose_matmul_mkldnn_fuse_pass"); passes_.push_back("matmul_elementwise_add_mkldnn_fuse_pass"); passes_.push_back("operator_scale_onednn_fuse_pass"); + passes_.push_back("operator_unsqueeze2_onednn_fuse_pass"); + passes_.push_back("operator_reshape2_onednn_fuse_pass"); passes_.push_back("cpu_quantize_placement_pass"); passes_.push_back("cpu_quantize_pass"); passes_.push_back("cpu_quantize_squash_pass"); diff --git a/paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h b/paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h index 0b061b09ec5660cef088073699dcc79746ce35dc..31372dc323f45c6e653597c0f75a86a46d0072e3 100644 --- a/paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h +++ b/paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h @@ -129,12 +129,14 @@ class EltwiseMKLDNNKernel : public framework::OpKernel { astream.wait(); if (handler.use_broadcasting_hack == false) { - z->set_mem_desc(dst_memory->get_desc()); + platform::SetOutMemDescWithLogicalLayoutFusesSupport( + ctx, z, dst_memory->get_desc()); } else { auto dims = dst_memory->get_desc().dims(); dims.insert(dims.begin(), x->dims()[0]); dims[1] /= dims[0]; - z->set_mem_desc(dst_memory->get_desc().reshape(dims)); + platform::SetOutMemDescWithLogicalLayoutFusesSupport( + ctx, z, dst_memory->get_desc().reshape(dims)); } } }; diff --git a/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc index 93b27b2caac02e4ef37b909ae5b9e821a6d72a94..1cd3e883b42bca1dce7e1be8eff3b497396d0365 100644 --- a/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc @@ -16,10 +16,7 @@ limitations under the License. */ #include "paddle/fluid/operators/fc_op.h" #include "paddle/fluid/platform/mkldnn_helper.h" - -namespace phi { -class DenseTensor; -} // namespace phi +#include "paddle/fluid/platform/mkldnn_reuse.h" namespace paddle { namespace operators { @@ -29,393 +26,131 @@ using dnnl::memory; using dnnl::primitive; using dnnl::prop_kind; using dnnl::stream; -using framework::DataLayout; using framework::DDim; using framework::ExecutionContext; -using framework::LoDTensor; -using framework::Tensor; +using LoDTensor = phi::DenseTensor; using platform::GetMKLDNNFormat; using platform::MKLDNNDeviceContext; +using platform::MKLDNNGetDataType; using platform::to_void_cast; +template +constexpr bool IsInt8() { + return std::is_same::value || std::is_same::value; +} + +struct InnerProductCache { + dnnl::inner_product_forward inner_product_p; + dnnl::memory src_mem; + dnnl::memory weights_mem; + dnnl::memory bias_mem; + dnnl::memory dst_mem; +}; template -class FCPrimitiveFactory { +class FCMKLDNNHandler + : public platform::MKLDNNHandlerNoCachingT { public: - explicit FCPrimitiveFactory(const dnnl::engine& engine) : engine_(engine) {} - - void ExecuteFcPrimitive(const LoDTensor* input, - const Tensor* weights, - const Tensor* bias, - LoDTensor* output, - const MKLDNNDeviceContext& dev_ctx, - const ExecutionContext& ctx) { - RecomputeOutputDims(ctx, input, weights, output); - // If primitive has already been created and cached, don't create new one, - // but update input and output data pointers and return it. 
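The removed comment above describes the old factory's caching; in the rewritten kernel that role is played by InnerProductCache blobs stored on the device context and keyed by a string, so later iterations only swap data handles. A small illustrative helper (hypothetical name, assuming the InnerProductCache declaration above) for the lookup half of that scheme:

// Returns the cached primitive bundle for this key, or null on the first run.
static std::shared_ptr<InnerProductCache> FetchCachedInnerProduct(
    const paddle::platform::MKLDNNDeviceContext& dev_ctx,
    const std::string& cache_key) {
  return std::static_pointer_cast<InnerProductCache>(dev_ctx.GetBlob(cache_key));
}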
- if (fc_) { - UpdateDataPointers(ctx, output, input); - this->Execute(); - return; - } // Otherwise, create a new one. - - auto in_col_dims = ctx.Attr("in_num_col_dims"); - PADDLE_ENFORCE_LE( - in_col_dims, - 2, - platform::errors::Unimplemented( - "DNNL FC doesn't support in_num_col_dims parameter to " - "be higher than " - "2.")); - if (in_col_dims == 2) { - PADDLE_ENFORCE_EQ( - input->dims().size(), - 3, - platform::errors::Unimplemented( - "DNNL FC only supports in_num_col_dims equal to 2 when " - "3 dim input is provided.")); - PADDLE_ENFORCE_EQ( - input->format(), - MKLDNNMemoryFormat::ncw, - platform::errors::Unimplemented( - "DNNL FC only supports in_num_col_dims equal to 2 when " - "input format is equal to ncw.")); + FCMKLDNNHandler(const paddle::framework::ExecutionContext& ctx, + const platform::MKLDNNDeviceContext& dev_ctx, + const phi::DenseTensor* x, + const phi::DenseTensor* weights, + const phi::DenseTensor* bias, + phi::DenseTensor* out, + const int in_num_col_dims, + dnnl::engine mkldnn_engine, + platform::Place cpu_place) + : platform::MKLDNNHandlerNoCachingT( + mkldnn_engine, cpu_place), + dev_ctx_(dev_ctx) { + this->memory_key_ = ctx.InputName("W"); + + auto x_vec_dims = phi::vectorize(x->dims()); + auto weights_vec_dims = phi::vectorize(weights->dims()); + + int MB = 1; + for (int i = 0; i < in_num_col_dims; ++i) { + MB *= x_vec_dims[i]; } - weights_ = CreateWeightsMemory(weights); - - // Since MKL-DNN has a lot of limitations on what the input/weights/output - // dimensions should be, to simplify the code, the creation of primitive - // descriptor has been divided into separate cases, based on the number - // of input dimensions. - size_t input_dim_num = input->dims().size(); - paddle::optional fc_prim_desc; - memory::desc usr_weights_desc = {}; - switch (input_dim_num) { - case 2: - fc_prim_desc = - Create2DFcPrimDescriptor(input, weights, bias, output, ctx); - usr_weights_desc = Create2DUserWeightsDesc(); - break; - case 3: - fc_prim_desc = - Create3DFcPrimDescriptor(input, weights, bias, output, ctx); - usr_weights_desc = Create3DUserWeightsDesc(weights); - break; - case 4: - fc_prim_desc = - Create4DFcPrimDescriptor(input, weights, bias, output, ctx); - usr_weights_desc = Create4DUserWeightsDesc(input, weights); - break; - default: - PADDLE_THROW(platform::errors::Unimplemented( - "DNNL FC doesn't support input dims different than 2, 3, 4.")); - break; + int IC = 1; + for (size_t i = in_num_col_dims; i < x_vec_dims.size(); ++i) { + IC *= x_vec_dims[i]; } - input_ = CreateMemory(fc_prim_desc->src_desc(), input); - // Update weights format inside of its memory - weights_ = Reorder( - usr_weights_desc, usr_weights_desc, weights_->get_data_handle()); - - // Quantize weights and reorder to format chosen by FC primitive descriptor. 
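A short worked example (values are illustrative only) of the flattening performed by the new FCMKLDNNHandler constructor:

//   x->dims() = {2, 3, 4}, in_num_col_dims = 1
//   MB = 2          (product of the first in_num_col_dims dimensions)
//   IC = 3 * 4 = 12 (product of the remaining dimensions)
//   OC = weights->dims()[1]
// so the primitive always sees a plain {MB, IC} x {OC, IC}^T -> {MB, OC} inner
// product, regardless of the original input rank.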
- QuantizeWeights(ctx, fc_prim_desc->weights_desc()); - - bias_ = CreateMemoryToBeCached(fc_prim_desc->bias_desc(), bias); - // If int8 is desired, quantize bias into 32-bit signed int - QuantizeBias(*fc_prim_desc, ctx); - - // Store weights and bias in the mkldnn cache - CacheWeightsAndBias(dev_ctx, ctx); - - // Based on format determined by inner_product, create output in desired - // memory format - output_ = CreateDstMemory(*fc_prim_desc, ctx, output); - - // Return MKL-DNN primitive ready to be fed into pipeline and executed - fc_ = inner_product_forward(*fc_prim_desc); - this->Execute(); - } - void Execute() { - auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); - if (bias_) { - fc_->execute(astream, - {{DNNL_ARG_SRC, *input_}, - {DNNL_ARG_WEIGHTS, *weights_}, - {DNNL_ARG_BIAS, *bias_}, - {DNNL_ARG_DST, *output_}}); - } else { - fc_->execute(astream, - {{DNNL_ARG_SRC, *input_}, - {DNNL_ARG_WEIGHTS, *weights_}, - {DNNL_ARG_DST, *output_}}); - } - astream.wait(); - } + int OC = weights_vec_dims[1]; - private: - // DNNL always returns 2-dimensional data block as a result of computing - // inner product. Hence the format 'nc' is always set for its output - // primitive. Therefore, function SetOutputFormat is needed to choose - // an appropriate format based on the number of input dimensions and - // format of an input tensor. - void SetOutputFormat(MKLDNNMemoryFormat in_format, Tensor* out) { - int dim_num = out->dims().size(); - // In case of 2 dims, we set the only possible format, nc - if (dim_num == 2) { - out->set_format(MKLDNNMemoryFormat::nc); - out->set_mem_desc({phi::vectorize(out->dims()), - platform::MKLDNNGetDataType(), - out->format()}); - // In case of 3 dims, we generate a format that is based on number - // of output dims and the layout of input format (nchw or nhwc). - } else if (dim_num == 3) { - if (in_format == MKLDNNMemoryFormat::nwc || - in_format == MKLDNNMemoryFormat::nhwc) { - out->set_format( - platform::MKLDNNFormatForSize(dim_num, MKLDNNMemoryFormat::nhwc)); - } else { - out->set_format( - platform::MKLDNNFormatForSize(dim_num, MKLDNNMemoryFormat::nchw)); - } - // In any other case we overwrite the output format with the input one. - } else { - out->set_format(in_format); - } - } + dnnl::memory::desc bias_md; - void UpdateDataPointers(const ExecutionContext& ctx, - Tensor* out, - const Tensor* in) { - input_->set_data_handle(to_void_cast(in->data())); - output_->set_data_handle(out->mutable_data(ctx.GetPlace())); - // If the primitive exists, but the output tensor has changed its - // variable, update its format to what has been determined in first - // call to CreateFcPrimitive method. 
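The handler builds its src/weights/dst descriptors with format_tag::any, so oneDNN, not the operator, picks the physical layouts when the primitive descriptor is created; user data is then reordered into whatever was chosen. A standalone toy sketch of that query pattern (sizes and function name are assumptions, not taken from this patch):

#include "dnnl.hpp"

// Toy sizes: MB=2, IC=12, OC=4. The library selects the weights layout; callers
// can query it afterwards and reorder user data to match.
static dnnl::memory::desc PreferredFcWeightsDesc(const dnnl::engine& eng) {
  using dt = dnnl::memory::data_type;
  using tag = dnnl::memory::format_tag;
  dnnl::memory::desc src_md({2, 12}, dt::f32, tag::any);
  dnnl::memory::desc wei_md({4, 12}, dt::f32, tag::any);
  dnnl::memory::desc dst_md({2, 4}, dt::f32, tag::any);
  dnnl::inner_product_forward::desc ip_desc(
      dnnl::prop_kind::forward_inference, src_md, wei_md, dst_md);
  dnnl::inner_product_forward::primitive_desc ip_pd(ip_desc, eng);
  return ip_pd.weights_desc();  // the layout oneDNN actually selected
}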
- if (out->format() == MKLDNNMemoryFormat::undef) { - SetOutputFormat(in->format(), out); + auto src_md = dnnl::memory::desc( + {MB, IC}, MKLDNNGetDataType(), dnnl::memory::format_tag::any); + auto weights_md = dnnl::memory::desc( + {OC, IC}, MKLDNNGetDataType(), dnnl::memory::format_tag::any); + auto dst_md = dnnl::memory::desc( + {MB, OC}, MKLDNNGetDataType(), dnnl::memory::format_tag::any); + if (bias) { + bias_md = dnnl::memory::desc({bias->numel()}, + MKLDNNGetDataType(), + dnnl::memory::format_tag::a); } - } - dnnl::inner_product_forward::primitive_desc Create2DFcPrimDescriptor( - const LoDTensor* input, - const Tensor* weights, - const Tensor* bias, - LoDTensor* output, - const ExecutionContext& ctx) { - auto src_desc = CreateMemDescriptor(input, MKLDNNMemoryFormat::any); - auto weight_dims = Get2DWeightDimsForDNNL(weights); - auto weights_desc = - CreateMemDescriptor(weight_dims, MKLDNNMemoryFormat::any); - auto bias_desc = CreateMemDescriptor(bias, MKLDNNMemoryFormat::x); - auto dst_desc = CreateMemDescriptor(output, MKLDNNMemoryFormat::any); - const auto attrs = CreateFCAttrs(ctx); - return CreateFcPrimDesc(src_desc, weights_desc, bias_desc, dst_desc, attrs); - } - - std::vector Get2DWeightDimsForDNNL(const Tensor* weights) { - auto dims = phi::vectorize(weights->dims()); - std::swap(dims[0], dims[1]); // swap input dim with output dim - return dims; - } - - memory::desc Create2DUserWeightsDesc() { return weights_->get_desc(); } - - dnnl::inner_product_forward::primitive_desc Create3DFcPrimDescriptor( - const LoDTensor* input, - const Tensor* weights, - const Tensor* bias, - LoDTensor* output, - const ExecutionContext& ctx) { - auto input_dims = phi::vectorize(input->dims()); - std::vector new_input_dims = { - input_dims[0] * input_dims[1], input_dims[2], 1}; - auto src_desc = - CreateMemDescriptor(new_input_dims, MKLDNNMemoryFormat::any); - - auto weight_dims = Get3DWeightDimsForDNNL(weights); - auto weights_desc = - CreateMemDescriptor(weight_dims, MKLDNNMemoryFormat::any); - - auto bias_desc = CreateMemDescriptor(bias, MKLDNNMemoryFormat::x); - - auto dst_dims = {input_dims[0] * input_dims[1], weight_dims[0]}; - auto dst_desc = - CreateMemDescriptor(dst_dims, MKLDNNMemoryFormat::any); - const auto attrs = CreateFCAttrs(ctx); - return CreateFcPrimDesc(src_desc, weights_desc, bias_desc, dst_desc, attrs); - } - - std::vector Get3DWeightDimsForDNNL(const Tensor* weights) { - auto paddle_w_dims = phi::vectorize(weights->dims()); - return {paddle_w_dims[1], paddle_w_dims[0], 1}; - } - - memory::desc Create3DUserWeightsDesc(const Tensor* weights) { - auto dims = Get3DWeightDimsForDNNL(weights); - return CreateMemDescriptor(dims, MKLDNNMemoryFormat::oiw); - } - - dnnl::inner_product_forward::primitive_desc Create4DFcPrimDescriptor( - const LoDTensor* input, - const Tensor* weights, - const Tensor* bias, - LoDTensor* output, - const ExecutionContext& ctx) { - auto src_desc = CreateMemDescriptor(input, MKLDNNMemoryFormat::any); - // Since MKL-DNN doesn't support 4D column-major data formats in - // inner_product primitive, transpose the weights to be in - // row-major format - auto dims = Get4DWeightDimsForDNNL(input, weights); - auto weights_desc = CreateMemDescriptor(dims, MKLDNNMemoryFormat::any); - auto bias_desc = CreateMemDescriptor(bias, MKLDNNMemoryFormat::x); - auto dst_desc = CreateMemDescriptor(output, MKLDNNMemoryFormat::any); const auto attrs = CreateFCAttrs(ctx); - return CreateFcPrimDesc(src_desc, weights_desc, bias_desc, dst_desc, attrs); - } - - std::vector 
Get4DWeightDimsForDNNL(const LoDTensor* input, - const Tensor* weights) { - auto old_w_dims = phi::vectorize(weights->dims()); - auto old_in_dims = phi::vectorize(input->dims()); - auto dims = {old_w_dims[1], old_in_dims[1], old_in_dims[2], old_in_dims[3]}; - return dims; - } - - memory::desc Create4DUserWeightsDesc(const LoDTensor* input, - const Tensor* weights) { - auto dims = Get4DWeightDimsForDNNL(input, weights); - return CreateMemDescriptor(dims, MKLDNNMemoryFormat::oihw); - } - // Convert data from one data format to another - std::shared_ptr Reorder(const memory::desc& src_desc, - const memory::desc& dst_desc, - void* src_data) { - auto src_mem = memory(src_desc, engine_, src_data); - auto dst_mem = std::make_shared(dst_desc, engine_); - - auto reorder = dnnl::reorder(src_mem, *dst_mem); - auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); - - { - platform::RecordEvent record_reorder( - "int_reorder", - platform::TracerEventType::UserDefined, - 2, - platform::EventRole::kUniqueOp); - reorder.execute(astream, src_mem, *dst_mem); - astream.wait(); - } - - return dst_mem; + this->AcquireForwardPrimitiveDescriptor(attrs, + prop_kind::forward_inference, + src_md, + weights_md, + bias_md, + dst_md); } - // Convert data from one data format to another and rescale it. - // If the desired data type is (un)signed int8, quantization occurs here. - std::shared_ptr ReorderWithScale( - const std::shared_ptr src_mem, - const memory::desc& dst_md, - const std::vector& scale_data) { - auto dst_mem = std::make_shared(dst_md, engine_); + private: + dnnl::primitive_attr CreateFCAttrs(const ExecutionContext& ctx) { dnnl::primitive_attr attributes; - // According to MKL-DNN's documentation mask determines along which - // dimensions should the scale be applied. - // 0 - Single scale applied to whole tensor - // 1 - Apply Scale along a slice of each dimension which index is 1. 
- // In case of weights quantization, that dimension is output, - // becuase we perform per-output-channel quantization - int mask = CreateMask(0, scale_data.size() > 1); - attributes.set_output_scales(mask, scale_data); - auto reorder = dnnl::reorder(*src_mem, *dst_mem, attributes); + dnnl::post_ops post_operations; - auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); - { - platform::RecordEvent record_reorder( - "int_reorder", - platform::TracerEventType::UserDefined, - 2, - platform::EventRole::kUniqueOp); - reorder.execute(astream, - {{DNNL_ARG_FROM, *src_mem}, {DNNL_ARG_TO, *dst_mem}}); - astream.wait(); + std::vector output_shift_scale; + float scale = 1.0f; + if (IsInt8()) { + std::tie(output_shift_scale, scale) = ComputeOutputShiftScale(ctx); + int mask = CreateMask(1, output_shift_scale.size() > 1); + attributes.set_output_scales(mask, output_shift_scale); } - return dst_mem; - } - - template - static dnnl::memory::desc CreateMemDescriptor( - const std::vector& dims, MKLDNNMemoryFormat format) { - return platform::MKLDNNMemDesc( - dims, platform::MKLDNNGetDataType(), format); - } - - template - static dnnl::memory::desc CreateMemDescriptor(const Tensor* tensor, - MKLDNNMemoryFormat format) { - auto dims = phi::vectorize(tensor->dims()); - return CreateMemDescriptor(dims, format); - } - - template - dnnl::memory CreateMemory(const dnnl::memory::desc& desc, - const Tensor* tensor) { - return CreateMemory(desc, platform::to_void_cast(tensor->data())); - } - - dnnl::memory CreateMemory(const dnnl::memory::desc& desc, void* data) { - return memory(desc, engine_, data); - } - - template - std::shared_ptr CreateMemoryToBeCached( - const dnnl::memory::desc& desc, const Tensor* tensor) { - return CreateMemoryToBeCached(desc, - platform::to_void_cast(tensor->data())); - } - - std::shared_ptr CreateMemoryToBeCached( - const dnnl::memory::desc& desc, void* data) { - return std::make_shared(desc, engine_, data); - } + float sum_scale = 1.0f; + if (ctx.HasAttr("fuse_residual_connection") && + ctx.Attr("fuse_residual_connection")) { + post_operations.append_sum(sum_scale); + } - // Create weights memory and transform to default MKL-DNN format - std::shared_ptr CreateWeightsMemory(const Tensor* weights) { - auto dims = phi::vectorize(weights->dims()); - std::swap(dims[0], dims[1]); // Correct output dimensions - auto src_desc = CreateMemDescriptor(dims, MKLDNNMemoryFormat::io); - auto dst_desc = CreateMemDescriptor(dims, MKLDNNMemoryFormat::oi); - // Transpose weights through MKL-DNN's reorder from io to oi format. 
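The new CreateFCAttrs above expresses every fused epilogue as an ordered oneDNN post-op chain; in particular, fused_output_scale becomes eltwise_linear, which computes y = alpha * x + beta. A minimal standalone sketch of such a chain with made-up values (not code from this patch):

#include "dnnl.hpp"

// ReLU first, then a fused output scale of 0.5; post-ops run in append order.
static dnnl::primitive_attr ExampleFcPostOps() {
  dnnl::post_ops ops;
  ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_relu, 0.0f, 0.0f);    // y = max(x, 0)
  ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, 0.5f, 0.0f);  // y = 0.5f * y
  dnnl::primitive_attr attrs;
  attrs.set_post_ops(ops);
  return attrs;
}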
- return Reorder(src_desc, - dst_desc, - platform::to_void_cast(weights->data())); - } + // ReLU from "fc_fuse_pass" + if (ctx.Attr("activation_type") == "relu") { + post_operations.append_eltwise( + scale, dnnl::algorithm::eltwise_relu, 0.0f, 0.0f); + } + platform::AppendActivation(ctx, post_operations, scale); - void CacheWeightsAndBias(const MKLDNNDeviceContext& dev_ctx, - const ExecutionContext& ctx) { - std::string key = platform::CreateKey(dev_ctx); - key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key); + if (ctx.HasAttr("fused_output_scale")) { + float scale_alpha = ctx.Attr("fused_output_scale"); + post_operations.append_eltwise( + 1.0, dnnl::algorithm::eltwise_linear, scale_alpha, 0.0f); + } - const std::string weights_key = key + ctx.InputName("W"); - const std::string bias_key = key + ctx.InputName("Bias"); - dev_ctx.SetBlob(weights_key, weights_); - dev_ctx.SetBlob(bias_key, bias_); + attributes.set_post_ops(post_operations); + return attributes; } // Compute the bias scales so that its values correspond to the // scale of data being an output of weights and input multiplication - std::vector ComputeBiasScales(const ExecutionContext& ctx) { - auto scale_in_data = ctx.Attr("Scale_in"); - auto scale_weights_data = ctx.Attr>("Scale_weights"); - const size_t weight_scales_num = scale_weights_data.size(); - std::vector bias_scales(weight_scales_num); + std::vector ComputeBiasScales( + const float scale_in, const std::vector& scale_weights) { + std::vector bias_scales(scale_weights.size()); -#pragma omp parallel for - for (size_t i = 0; i < weight_scales_num; i++) { - if (scale_weights_data[i] == 0.0) + for (size_t i = 0; i < bias_scales.size(); ++i) { + if (scale_weights[i] == 0.0) bias_scales[i] = 1.0f; else - bias_scales[i] = scale_in_data * scale_weights_data[i]; + bias_scales[i] = scale_in * scale_weights[i]; } return bias_scales; @@ -442,18 +177,16 @@ class FCPrimitiveFactory { ? 1.0f : ctx.Attr("Scale_out"); const size_t weight_scales_num = scale_weights_data.size(); - std::vector output_shift_scale(weight_scales_num); -#pragma omp parallel for - for (size_t i = 0; i < weight_scales_num; i++) { + for (size_t i = 0; i < weight_scales_num; ++i) { if (scale_weights_data[i] == 0.0) - output_shift_scale[i] = inner_scale; + scale_weights_data[i] = inner_scale; else - output_shift_scale[i] = + scale_weights_data[i] = inner_scale / (scale_in_data * scale_weights_data[i]); } - return make_tuple(output_shift_scale, scale); + return make_tuple(scale_weights_data, scale); } // Computing MKL-DNN's scaling mask which determines along which dimension @@ -464,137 +197,300 @@ class FCPrimitiveFactory { return is_multi_channel_quantizied ? 
1 << slice_dimension : 0; } - void QuantizeWeights(const ExecutionContext& ctx, memory::desc dst) { - weights_ = ReorderWithScale( - weights_, dst, ctx.Attr>("Scale_weights")); - } + std::shared_ptr AcquireMemoryWithReorderAndAttrs( + const dnnl::memory::desc& user_md, + const dnnl::memory::desc& target_md, + void* ptr, + const dnnl::primitive_attr& attrs) { + std::shared_ptr target_memory_p; - void QuantizeBias(const inner_product_forward::primitive_desc& fc_prim_desc, - const ExecutionContext& ctx) { - auto bias_scales = ComputeBiasScales(ctx); - bias_ = ReorderWithScale(bias_, fc_prim_desc.bias_desc(), bias_scales); - } + auto user_memory_p = + std::make_shared(user_md, this->engine_, ptr); + target_memory_p = std::make_shared(target_md, this->engine_); + auto reorder_p = std::make_shared( + *user_memory_p, *target_memory_p, attrs); - dnnl::primitive_attr CreateFCAttrs(const ExecutionContext& ctx) { - dnnl::primitive_attr attributes; - dnnl::post_ops post_operations; + auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); + reorder_p->execute( + astream, + {{DNNL_ARG_FROM, *user_memory_p}, {DNNL_ARG_TO, *target_memory_p}}); + astream.wait(); - std::vector output_shift_scale; - float scale; - std::tie(output_shift_scale, scale) = ComputeOutputShiftScale(ctx); - int mask = CreateMask(1, output_shift_scale.size() > 1); - attributes.set_output_scales(mask, output_shift_scale); + return target_memory_p; + } - float sum_scale = 1.0f; - if (ctx.HasAttr("fuse_residual_connection") && - ctx.Attr("fuse_residual_connection")) { - post_operations.append_sum(sum_scale); - } + std::string memory_key_; + const platform::MKLDNNDeviceContext& dev_ctx_; - if (ctx.Attr("activation_type") == "relu") { - constexpr float negative_slope = 0.0f; - constexpr float placeholder = 1.0f; // beta - post_operations.append_eltwise( - scale, dnnl::algorithm::eltwise_relu, negative_slope, placeholder); - } else if (ctx.Attr("activation_type") == "gelu") { - constexpr float alpha = 0.0f; - constexpr float beta = 0.0f; - post_operations.append_eltwise( - scale, dnnl::algorithm::eltwise_gelu, alpha, beta); - } else if (ctx.Attr("activation_type") == "gelu_tanh") { - constexpr float alpha = 0.0f; - constexpr float beta = 0.0f; - post_operations.append_eltwise( - scale, dnnl::algorithm::eltwise_gelu_tanh, alpha, beta); - } else if (ctx.Attr("activation_type") == "gelu_erf") { - constexpr float alpha = 0.0f; - constexpr float beta = 0.0f; - post_operations.append_eltwise( - scale, dnnl::algorithm::eltwise_gelu_erf, alpha, beta); - } else if (ctx.Attr("activation_type") == "tanh") { - constexpr float alpha = 0.0f; - constexpr float beta = 0.0f; - post_operations.append_eltwise( - scale, dnnl::algorithm::eltwise_tanh, alpha, beta); - } else if (ctx.Attr("activation_type") == "sigmoid") { - constexpr float alpha = 0.0f; - constexpr float beta = 0.0f; - post_operations.append_eltwise( - scale, dnnl::algorithm::eltwise_logistic, alpha, beta); - } else if (ctx.Attr("activation_type") == "mish") { - constexpr float alpha = 0.0f; - constexpr float beta = 0.0f; - post_operations.append_eltwise( - scale, dnnl::algorithm::eltwise_mish, alpha, beta); - } else if (ctx.Attr("activation_type") == "hard_swish") { - constexpr float alpha = 0.0f; - constexpr float beta = 0.0f; - post_operations.append_eltwise( - scale, dnnl::algorithm::eltwise_hardswish, alpha, beta); + public: + std::shared_ptr AcquireSrcMemoryWithReorder( + const phi::DenseTensor* x) { + const T_in* x_data = x->data(); + + auto user_md = x->mem_desc(); + if 
(x->dims().size() != 2) { + // reshape restrictions are always satisfied because in case of 3 or 4 dim + // input, plain layout is enforced + user_md = user_md.reshape(this->fwd_pd_->src_desc().dims()); } - if (ctx.HasAttr("fused_output_scale")) { - float scale_alpha = ctx.Attr("fused_output_scale"); - post_operations.append_eltwise( - 1.0, dnnl::algorithm::eltwise_linear, scale_alpha, 0.0f); - } + return this->AcquireMemoryWithReorder( + user_md, this->fwd_pd_->src_desc(), to_void_cast(x_data)); + } - attributes.set_post_ops(post_operations); - return attributes; + std::shared_ptr AcquireBiasMemoryWithReorder( + const phi::DenseTensor* bias, + const float scale_in, + const std::vector& scale_weights) { + const float* bias_data = bias->data(); + + if (IsInt8() == false) { + // for BF16/FP32 bias is 1D and has no scales, so reorder is not needed + return this->AcquireMemoryFromPrimitive(this->fwd_pd_->bias_desc(), + to_void_cast(bias_data)); + } else { + const std::string bias_key = this->memory_key_ + "@bias"; + auto memory_p = std::static_pointer_cast( + this->dev_ctx_.GetBlob(bias_key)); + + if (!memory_p) { + const auto& scale_data = ComputeBiasScales(scale_in, scale_weights); + dnnl::primitive_attr attrs; + + int mask = CreateMask(0, scale_data.size() > 1); + attrs.set_output_scales(mask, scale_data); + + auto user_md = dnnl::memory::desc({bias->dims()[0]}, + MKLDNNGetDataType(), + dnnl::memory::format_tag::a); + + memory_p = this->AcquireMemoryWithReorderAndAttrs( + user_md, + this->fwd_pd_->bias_desc(), + to_void_cast(bias_data), + attrs); + this->dev_ctx_.SetBlob(bias_key, memory_p); + } + return memory_p; + } } - dnnl::inner_product_forward::primitive_desc CreateFcPrimDesc( - const dnnl::memory::desc& input_desc, - const dnnl::memory::desc& weights_desc, - const dnnl::memory::desc& bias_desc, - const dnnl::memory::desc& dst_desc, - const dnnl::primitive_attr& attrs) { - auto fc_desc = inner_product_forward::desc(prop_kind::forward_scoring, - input_desc, - weights_desc, - bias_desc, - dst_desc); + std::shared_ptr AcquireWeightsMemoryWithReorder( + const phi::DenseTensor* weights, const std::vector& scale_data) { + const std::string weights_key = this->memory_key_ + "@weights"; + auto memory_p = std::static_pointer_cast( + this->dev_ctx_.GetBlob(weights_key)); + + if (!memory_p) { + const float* weights_data = weights->data(); + auto weights_dims = this->fwd_pd_->weights_desc().dims(); - return inner_product_forward::primitive_desc(fc_desc, attrs, engine_); + auto user_md = dnnl::memory::desc(weights_dims, + MKLDNNGetDataType(), + dnnl::memory::format_tag::io); + + if (IsInt8()) { + dnnl::primitive_attr attrs; + int mask = CreateMask(0, scale_data.size() > 1); + attrs.set_output_scales(mask, scale_data); + + memory_p = this->AcquireMemoryWithReorderAndAttrs( + user_md, + this->fwd_pd_->weights_desc(), + to_void_cast(weights_data), + attrs); + } else { + memory_p = + this->AcquireMemoryWithReorder(user_md, + this->fwd_pd_->weights_desc(), + to_void_cast(weights_data)); + } + + this->dev_ctx_.SetBlob(weights_key, memory_p); + } + return memory_p; } - // Create output memory based on output tensor and inner_product - // primitive descriptor format chosen for output - dnnl::memory CreateDstMemory( - const dnnl::inner_product_forward::primitive_desc& fc_prim_desc, - const ExecutionContext& ctx, - Tensor* output) { + std::shared_ptr AcquireCustomDstMemory( + const ExecutionContext& ctx, phi::DenseTensor* out) { if (ctx.HasAttr("fuse_residual_connection") && 
ctx.Attr("fuse_residual_connection")) { - auto* residual_param = ctx.Output("ResidualData"); + auto* residual_param = ctx.Output("ResidualData"); PADDLE_ENFORCE_EQ( - output->dims(), + out->dims(), residual_param->dims(), platform::errors::InvalidArgument( "Output and elementwise parameter need to have the " "same dimension sizes, but got output's dimension = %d" " and residual param's dimension =%d .", - output->dims().size(), + out->dims().size(), residual_param->dims().size())); - output->ShareDataWith(*residual_param); + out->ShareDataWith(*residual_param); } + return this->template AcquireDstMemory(out); + } // namespace operators +}; // namespace paddle - auto dst_desc = fc_prim_desc.dst_desc(); - auto buffer_size = dst_desc.get_size(); - T_out* output_data = - output->mutable_data(ctx.GetPlace(), buffer_size); - memory dst_mem(dst_desc, engine_, to_void_cast(output_data)); - SetOutputFormat(ctx.Input("Input")->format(), output); +template +class FCMKLDNNKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + bool force_fp32_output = ctx.Attr("force_fp32_output"); + bool fuse_relu = ctx.Attr("activation_type") == "relu"; - return dst_mem; + if (force_fp32_output) { + this->RunKernel(ctx); + } else if (IsInt8()) { + if (fuse_relu) { + this->RunKernel(ctx); + } else { + this->RunKernel(ctx); + } + } else { + this->RunKernel(ctx); + } + } + + void PrepareSrcMem(const std::shared_ptr& fc_p, + const std::shared_ptr& src_mem, + const LoDTensor* x, + const dnnl::engine& engine) const { + auto x_md = x->mem_desc().reshape(src_mem->get_desc().dims()); + if (x_md != src_mem->get_desc()) { + dnnl::memory x_mem(x_md, engine, to_void_cast(x->data())); + auto reorder_p = dnnl::reorder(x_mem, *src_mem); + + auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream(); + reorder_p.execute(astream, x_mem, *src_mem); + astream.wait(); + } else { + src_mem->set_data_handle(to_void_cast(x->data())); + } + } + + template + void RunKernel(const framework::ExecutionContext& ctx) const { + const auto& dev_ctx = + ctx.template device_context(); + const auto& mkldnn_engine = dev_ctx.GetEngine(); + + const auto* x = ctx.Input("Input"); + const auto* weights = ctx.Input("W"); + const auto* bias = ctx.Input("Bias"); + auto out = ctx.Output("Out"); + + const float scale_in = ctx.Attr("Scale_in"); + const auto& scale_weights = ctx.Attr>("Scale_weights"); + + std::shared_ptr fc_p; + std::shared_ptr src_memory_p; + std::shared_ptr weights_memory_p; + std::shared_ptr bias_memory_p; + std::shared_ptr dst_memory_p; + + std::string cache_key; + cache_key.reserve(64); + cache_key = platform::ExtendKeyWithThreadInfoIfNeeded( + dev_ctx, + platform::CreateKey(dev_ctx, + ctx.InputName("Input"), + ctx.InputName("W"), + phi::vectorize(x->dims()))); + + auto inner_product_cache = + std::static_pointer_cast(dev_ctx.GetBlob(cache_key)); + + RecomputeOutputDims(ctx, x, weights, out); + + if (inner_product_cache) { + fc_p = std::make_shared( + inner_product_cache->inner_product_p); + src_memory_p = + std::make_shared(inner_product_cache->src_mem); + PrepareSrcMem(fc_p, src_memory_p, x, mkldnn_engine); + + weights_memory_p = + std::make_shared(inner_product_cache->weights_mem); + + dst_memory_p = + std::make_shared(inner_product_cache->dst_mem); + if (ctx.HasAttr("fuse_residual_connection") && + ctx.Attr("fuse_residual_connection")) { + auto* residual_param = ctx.Output("ResidualData"); + out->ShareDataWith(*residual_param); + } + auto out_ptr = 
out->mutable_data( + ctx.GetPlace(), dst_memory_p->get_desc().get_size()); + dst_memory_p->set_data_handle(out_ptr); + + if (bias) { + bias_memory_p = + std::make_shared(inner_product_cache->bias_mem); + } + } else { + auto in_col_dims = ctx.Attr("in_num_col_dims"); + + FCMKLDNNHandler handler(ctx, + dev_ctx, + x, + weights, + bias, + out, + in_col_dims, + mkldnn_engine, + ctx.GetPlace()); + + src_memory_p = handler.AcquireSrcMemoryWithReorder(x); + weights_memory_p = + handler.AcquireWeightsMemoryWithReorder(weights, scale_weights); + dst_memory_p = handler.AcquireCustomDstMemory(ctx, out); + + if (bias) { + bias_memory_p = + handler.AcquireBiasMemoryWithReorder(bias, scale_in, scale_weights); + } + + fc_p = handler.AcquireForwardPrimitive(); + } + + auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream(); + + std::unordered_map fc_args = { + {DNNL_ARG_SRC, *src_memory_p}, + {DNNL_ARG_WEIGHTS, *weights_memory_p}, + {DNNL_ARG_DST, *dst_memory_p}}; + + if (bias) { + fc_args.insert({DNNL_ARG_BIAS, *bias_memory_p}); + } + + fc_p->execute(astream, fc_args); + astream.wait(); + + if (!inner_product_cache) { + auto ip_cache = std::make_shared(); + ip_cache->inner_product_p = *fc_p; + ip_cache->src_mem = *src_memory_p; + ip_cache->weights_mem = *weights_memory_p; + ip_cache->dst_mem = *dst_memory_p; + if (bias) { + ip_cache->bias_mem = *bias_memory_p; + } + dev_ctx.SetBlob(cache_key, ip_cache); + } + + platform::SetOutMemDescWithLogicalLayoutFusesSupport( + ctx, + out, + dst_memory_p->get_desc().reshape(phi::vectorize(out->dims()))); } void RecomputeOutputDims(const ExecutionContext& ctx, - const LoDTensor* input, - const Tensor* w, - LoDTensor* output) { + const LoDTensor* x, + const phi::DenseTensor* weights, + LoDTensor* out) const { int in_num_col_dims = ctx.Attr("in_num_col_dims"); bool padding_weights = ctx.Attr("padding_weights"); PADDLE_ENFORCE_EQ(padding_weights, @@ -602,102 +498,16 @@ class FCPrimitiveFactory { platform::errors::PermissionDenied( "Weight padding in fc can not be used in MKLDNN.")); std::vector output_dims; - FCOutputSize(input->dims(), - w->dims(), + FCOutputSize(x->dims(), + weights->dims(), output_dims, in_num_col_dims, padding_weights); - output->Resize(phi::make_ddim(output_dims)); - output->set_lod(input->lod()); + out->Resize(phi::make_ddim(output_dims)); + out->set_lod(x->lod()); } - - private: - const dnnl::engine& engine_; - paddle::optional input_; - paddle::optional output_; - std::shared_ptr bias_; - std::shared_ptr weights_; - paddle::optional fc_; }; -// Attempt to fetch cached primitive factory based on provided parameters -// of input format, weight dimensions and output name. -// If not cached, create a new one. -template -static std::shared_ptr> -GetPrimitiveFactory(const MKLDNNDeviceContext& dev_ctx, - const std::string& key) { - auto prim_creator = - std::static_pointer_cast>( - dev_ctx.GetBlob(key)); - if (prim_creator == nullptr) { - prim_creator = std::make_shared>( - dev_ctx.GetEngine()); - dev_ctx.SetBlob(key, prim_creator); - } - - return prim_creator; -} - -// Choose appropriate primitive factory implementation based on inferred -// output type (uint8, int8 or float). 
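The removed comment above referred to the old per-type primitive factories; the same output-type choice now lives in FCMKLDNNKernel::Compute, summarized here for reference (no new behaviour):

//   force_fp32_output                    -> RunKernel<float>
//   int8 input (u8/s8) with fuse_relu    -> RunKernel<uint8_t>  (ReLU output is non-negative)
//   int8 input (u8/s8) without fuse_relu -> RunKernel<int8_t>
//   otherwise (fp32 / bf16)              -> RunKernel<T_in>     (output keeps the input type)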
-template -static void ExecuteFc(const ExecutionContext& ctx, - const LoDTensor* input, - const Tensor* w, - const Tensor* bias, - LoDTensor* output, - bool fuse_relu, - bool force_fp32_output) { - auto& dev_ctx = ctx.template device_context(); - std::string prim_key = platform::CreateKey(dev_ctx, - input->format(), - input->dims()[0], - phi::vectorize(w->dims()), - ctx.OutputName("Out")); - prim_key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, prim_key); - - constexpr bool is_int8 = - std::is_same::value || std::is_same::value; - bool is_bfloat16 = std::is_same::value; - if ((!is_int8 && !is_bfloat16) || force_fp32_output) { - GetPrimitiveFactory(dev_ctx, prim_key) - ->ExecuteFcPrimitive(input, w, bias, output, dev_ctx, ctx); - } else if (is_bfloat16) { - GetPrimitiveFactory(dev_ctx, prim_key) - ->ExecuteFcPrimitive(input, w, bias, output, dev_ctx, ctx); - } else if (fuse_relu) { - GetPrimitiveFactory(dev_ctx, prim_key) - ->ExecuteFcPrimitive(input, w, bias, output, dev_ctx, ctx); - } else { - GetPrimitiveFactory(dev_ctx, prim_key) - ->ExecuteFcPrimitive(input, w, bias, output, dev_ctx, ctx); - } -} - -template -class FCMKLDNNOpKernel : public framework::OpKernel { - public: - void Compute(const paddle::framework::ExecutionContext& ctx) const override { - PADDLE_ENFORCE_EQ( - platform::is_cpu_place(ctx.GetPlace()), - true, - platform::errors::PreconditionNotMet("FC MKL-DNN must use CPUPlace.")); - platform::MKLDNNDeviceContext::tls().log_lib_version(); - auto input = ctx.Input("Input"); - auto w = ctx.Input("W"); - auto bias = ctx.Input("Bias"); - auto output = ctx.Output("Out"); - - bool fuse_relu = ctx.Attr("activation_type") == "relu"; - bool force_fp32_output = ctx.Attr("force_fp32_output"); - - ExecuteFc( - ctx, input, w, bias, output, fuse_relu, force_fp32_output); - - output->set_layout(DataLayout::kMKLDNN); - } -}; } // namespace operators } // namespace paddle @@ -710,7 +520,7 @@ REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(fc, ::paddle::platform::CPUPlace, FP32, ops::kFCMKLDNNFP32, - ops::FCMKLDNNOpKernel); + ops::FCMKLDNNKernel); REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE( fc, @@ -718,19 +528,19 @@ REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE( ::paddle::platform::CPUPlace, BF16, ops::kFCMKLDNNFP32, - ops::FCMKLDNNOpKernel); + ops::FCMKLDNNKernel); REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(fc, MKLDNN, ::paddle::platform::CPUPlace, U8, ops::kFCMKLDNNINT8, - ops::FCMKLDNNOpKernel); + ops::FCMKLDNNKernel); REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(fc, MKLDNN, ::paddle::platform::CPUPlace, S8, ops::kFCMKLDNNINT8, - ops::FCMKLDNNOpKernel); + ops::FCMKLDNNKernel); diff --git a/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc index 611c54d74c9cbb6535fdda785d10b846e5baba2a..c88377fccd37536ed9b84ce8fca67b4aace59211 100644 --- a/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc @@ -21,72 +21,8 @@ namespace paddle { namespace operators { -using Tensor = framework::Tensor; -using framework::DataLayout; - -template -class TransposeMKLDNNHandler { - public: - TransposeMKLDNNHandler(std::vector& dims, // NOLINT - std::vector& axis, // NOLINT - dnnl::engine engine) - : dims_(dims), - axis_(axis), - logical_axis_(dims.size(), 0), - engine_(engine) {} - - std::shared_ptr AcquireSrcMemory(const MKLDNNMemoryFormat& fmt, - void* ptr) { - // Make memory descriptor using input format, unless it - // cannot be trusted (nchw) then make up memory fmt manually - for (size_t i = 0; i < this->logical_axis_.size(); 
++i) { - this->logical_axis_[i] = i; - } - - auto src_md = fmt != MKLDNNMemoryFormat::nchw - ? platform::MKLDNNMemDesc( - dims_, platform::MKLDNNGetDataType(), fmt) - : Axis2MemoryDesc(dims_, logical_axis_); - return std::make_shared(src_md, engine_, ptr); - } - - std::shared_ptr AcquireDstMemory(framework::Tensor* output, - platform::Place place) { - auto dst_md = Axis2MemoryDesc(dims_, axis_); - auto dst_data = output->mutable_data(place, dst_md.get_size()); - return std::make_shared(dst_md, engine_, dst_data); - } - - std::shared_ptr AcquireTranspose( - std::shared_ptr dst_memory_p, - std::shared_ptr src_memory_p) { - return std::make_shared(*(src_memory_p), *(dst_memory_p)); - } - - protected: - dnnl::memory::desc Axis2MemoryDesc(std::vector& nchw_tz, // NOLINT - std::vector& axis // NOLINT - ) { - size_t ndims = axis.size(); - - std::vector strides(ndims); - unsigned int total_stride = 1; - for (int i = ndims - 1; i >= 0; --i) { - strides[axis[i]] = total_stride; - total_stride *= nchw_tz[axis[i]]; - } - dnnl::memory::desc mem_d( - nchw_tz, platform::MKLDNNGetDataType(), strides); - - return mem_d; - } - - private: - std::vector dims_; - std::vector axis_; - std::vector logical_axis_; - dnnl::engine engine_; -}; +using Tensor = phi::DenseTensor; +using phi::DataLayout; template class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel { @@ -98,37 +34,87 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel { "Operator DNNL Transpose must use CPUPlace")); auto& dev_ctx = ctx.template device_context(); - const auto& mkldnn_engine = dev_ctx.GetEngine(); - std::vector axis = ctx.Attr>("axis"); - int ndims = axis.size(); - auto* input = ctx.Input("X"); - auto* output = ctx.Output("Out"); - const T* input_data = input->data(); + const auto& dnnl_engine = dev_ctx.GetEngine(); + std::vector transpose_axis = ctx.Attr>("axis"); + int ndims = transpose_axis.size(); + auto* x = ctx.Input("X"); + auto* out = ctx.Output("Out"); + + auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); if (ndims == 1) { - framework::TensorCopy(*input, input->place(), output); - output->set_format(input->format()); + framework::TensorCopy(*x, x->place(), out); + out->set_mem_desc(x->mem_desc()); return; } - auto nchw_tz = phi::vectorize(input->dims()); + auto x_vec_dims = phi::vectorize(x->dims()); - TransposeMKLDNNHandler handler(nchw_tz, axis, mkldnn_engine); + framework::proto::VarType::Type x_paddle_type = + framework::TransToProtoVarType(x->dtype()); + dnnl::memory::data_type x_type = framework::ToMKLDNNDataType(x_paddle_type); + platform::ReorderMKLDNNHandler reorder_handler( + x_vec_dims, x_paddle_type, x_type, dnnl_engine); - auto transpose_src_memory_p = handler.AcquireSrcMemory( - input->format(), platform::to_void_cast(input_data)); - auto transpose_dst_memory_p = - handler.AcquireDstMemory(output, ctx.GetPlace()); - auto transpose_p = handler.AcquireTranspose(transpose_dst_memory_p, - transpose_src_memory_p); + auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory( + x->mem_desc(), platform::to_void_cast(x->data())); - auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); - transpose_p->execute( - astream, *transpose_src_memory_p, *transpose_dst_memory_p); + auto dst_md = + dnnl::memory::desc(x_vec_dims, + x->mem_desc().data_type(), + platform::GetPlainMKLDNNFormat(x_vec_dims.size())); + // a trick is used here to fake transpose of out_md, so later it will be + // "untransposed", leaving output data in plain format tag + auto dst_strides = 
FakeTranposeStrides(dst_md, transpose_axis); + + dst_md = + dnnl::memory::desc(x_vec_dims, x->mem_desc().data_type(), dst_strides); + auto dst_data = + out->mutable_data(ctx.GetPlace(), x->type(), dst_md.get_size()); + + auto reorder_dst_memory_p = + std::make_shared(dst_md, dnnl_engine, dst_data); + + auto reorder_p = reorder_handler.AcquireReorder(reorder_dst_memory_p, + reorder_src_memory_p); + + reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p); astream.wait(); - output->set_layout(DataLayout::kNCHW); - output->set_format(MKLDNNMemoryFormat::undef); + platform::SetOutMemDescWithLogicalLayoutFusesSupport( + ctx, + out, + reorder_dst_memory_p->get_desc().permute_axes( + TransposeToPermuteAxis(transpose_axis))); + } + + private: + // it is needed because oneDNN's permute axis understand axes order in + // different way PaddlePaddle's transpose + std::vector TransposeToPermuteAxis( + const std::vector& transpose_axis) const { + std::vector permute_axis(transpose_axis.size()); + + for (size_t i = 0; i < transpose_axis.size(); ++i) { + permute_axis[transpose_axis[i]] = i; + } + return permute_axis; + } + + std::vector FakeTranposeStrides( + const dnnl::memory::desc& dst_md, + const std::vector& transpose_axis) const { + std::vector fake_strides(transpose_axis.size()); + auto dims = dst_md.dims(); + int total_stride = 1; + int ndims = static_cast(dims.size()); + + for (int i = ndims - 1; i >= 0; --i) { + fake_strides[transpose_axis[i]] = total_stride; + total_stride *= dims[transpose_axis[i]]; + } + + return fake_strides; } }; @@ -140,44 +126,47 @@ class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel { true, paddle::platform::errors::PreconditionNotMet( "Operator DNNL TransposeGrad must use CPUPlace")); - auto* out_grad = - ctx.Input(framework::GradVarName("Out")); - auto* x_grad = ctx.Output(framework::GradVarName("X")); - if (!x_grad) return; + + const auto* dout = ctx.Input(framework::GradVarName("Out")); + auto* dx = ctx.Output(framework::GradVarName("X")); + if (!dx) return; auto& dev_ctx = ctx.template device_context(); - const auto& mkldnn_engine = dev_ctx.GetEngine(); - std::vector axis = ctx.Attr>("axis"); - std::vector reversed_axis(axis); - int ndims = axis.size(); + const auto& dnnl_engine = dev_ctx.GetEngine(); + std::vector transpose_axis = ctx.Attr>("axis"); + + auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); + + int ndims = transpose_axis.size(); if (ndims == 1) { - framework::TensorCopy(*out_grad, out_grad->place(), x_grad); - x_grad->set_format(out_grad->format()); + framework::TensorCopy(*dout, dout->place(), dx); + dx->set_mem_desc(dout->mem_desc()); return; } - for (size_t i = 0; i < axis.size(); i++) { - reversed_axis[axis[i]] = i; - } + auto dout_vec_dims = phi::vectorize(dout->dims()); - const T* out_grad_data = out_grad->data(); - x_grad->mutable_data(ctx.GetPlace()); + framework::proto::VarType::Type dout_paddle_type = + framework::TransToProtoVarType(dout->dtype()); + dnnl::memory::data_type dout_type = + framework::ToMKLDNNDataType(dout_paddle_type); - auto nchw_tz = phi::vectorize(out_grad->dims()); + platform::ReorderMKLDNNHandler reorder_handler( + dout_vec_dims, dout_paddle_type, dout_type, dnnl_engine); - TransposeMKLDNNHandler handler(nchw_tz, reversed_axis, mkldnn_engine); + auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory( + dout->mem_desc(), platform::to_void_cast(dout->data())); - auto transpose_src_memory_p = handler.AcquireSrcMemory( - out_grad->format(), 
platform::to_void_cast<T>(out_grad_data));
-    auto transpose_dst_memory_p =
-        handler.AcquireDstMemory(x_grad, ctx.GetPlace());
-    auto transpose_p = handler.AcquireTranspose(transpose_dst_memory_p,
-                                                transpose_src_memory_p);
+    auto reorder_dst_memory_p =
+        reorder_handler.AcquireDstMemory(dx, dout->mem_desc(), ctx.GetPlace());
-    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
-    transpose_p->execute(
-        astream, *transpose_src_memory_p, *transpose_dst_memory_p);
+    auto reorder_p = reorder_handler.AcquireReorder(reorder_dst_memory_p,
+                                                    reorder_src_memory_p);
+
+    reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
     astream.wait();
+    dx->set_mem_desc(
+        reorder_dst_memory_p->get_desc().permute_axes(transpose_axis));
   }
 };
diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h
index 9b870af90a1782596616b3a4158a6d57247b0b80..dcd8124fc6b8e8b641d04f97dbdafdc9ccd4d612 100644
--- a/paddle/fluid/platform/mkldnn_reuse.h
+++ b/paddle/fluid/platform/mkldnn_reuse.h
@@ -111,6 +111,69 @@ static void AppendActivation(const framework::ExecutionContext& ctx,
   }
 }
 
+static void SetOutMemDescWithUnsqueeze2FuseSupport(
+    const framework::ExecutionContext& ctx,
+    phi::DenseTensor* out,
+    const dnnl::memory::desc& out_md) {
+  const std::vector<int>& fused_unsqueeze2_axes =
+      ctx.Attr<std::vector<int>>("fused_unsqueeze2_axes");
+  const std::vector<int64_t>& op_tz = out_md.dims();
+  std::vector<int64_t> unsqueezed_op_tz(
+      op_tz.size() + fused_unsqueeze2_axes.size(), 0);
+
+  for (const auto& axis : fused_unsqueeze2_axes) {
+    int positive_axis = axis < 0 ? unsqueezed_op_tz.size() + axis : axis;
+    unsqueezed_op_tz[positive_axis] = 1;
+  }
+
+  int j = 0;
+  for (size_t i = 0; i < unsqueezed_op_tz.size(); ++i) {
+    if (unsqueezed_op_tz[i] == 0) {
+      unsqueezed_op_tz[i] = op_tz[j++];
+    }
+  }
+  out->set_mem_desc(out_md.reshape(unsqueezed_op_tz));
+  out->Resize(phi::make_ddim(unsqueezed_op_tz));
+}
+
+static void SetOutMemDescWithReshape2FuseSupport(
+    const framework::ExecutionContext& ctx,
+    phi::DenseTensor* out,
+    const dnnl::memory::desc& out_md) {
+  std::vector<int64_t> fused_reshape2_shape(
+      ctx.Attr<std::vector<int>>("fused_reshape2_shape").begin(),
+      ctx.Attr<std::vector<int>>("fused_reshape2_shape").end());
+
+  const int out_shape_numel = out->numel();
+  const int new_shape_numel = std::accumulate(fused_reshape2_shape.begin(),
+                                              fused_reshape2_shape.end(),
+                                              1,
+                                              std::multiplies<int64_t>());
+
+  for (size_t i = 0; i < fused_reshape2_shape.size(); ++i) {
+    if (fused_reshape2_shape[i] == -1) {
+      fused_reshape2_shape[i] = -out_shape_numel / new_shape_numel;
+      break;
+    }
+  }
+
+  out->set_mem_desc(out_md.reshape(fused_reshape2_shape));
+  out->Resize(phi::make_ddim(fused_reshape2_shape));
+}
+
+static void SetOutMemDescWithLogicalLayoutFusesSupport(
+    const framework::ExecutionContext& ctx,
+    phi::DenseTensor* out,
+    const dnnl::memory::desc& out_md) {
+  if (ctx.HasAttr("fused_unsqueeze2_axes")) {
+    SetOutMemDescWithUnsqueeze2FuseSupport(ctx, out, out_md);
+  } else if (ctx.HasAttr("fused_reshape2_shape")) {
+    SetOutMemDescWithReshape2FuseSupport(ctx, out, out_md);
+  } else {
+    out->set_mem_desc(out_md);
+  }
+}
+
 template <typename T>
 constexpr bool IsInt8() {
   return std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
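Side note, not part of the patch: the unsqueeze2 fusion helper above expands the output dims by marking every fused axis with 1 (negative axes counted from the back) and back-filling the original sizes in order. A small, self-contained sketch of that same logic as a plain function; the name UnsqueezeDims and the example shapes are made up for illustration.

#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int64_t> UnsqueezeDims(const std::vector<int64_t>& dims,
                                   const std::vector<int>& axes) {
  std::vector<int64_t> out(dims.size() + axes.size(), 0);
  // Mark each unsqueezed axis with 1; negative axes count from the back.
  for (int axis : axes) {
    int positive_axis = axis < 0 ? static_cast<int>(out.size()) + axis : axis;
    out[positive_axis] = 1;
  }
  // Fill the remaining slots with the original dimensions, in order.
  size_t j = 0;
  for (auto& d : out) {
    if (d == 0) d = dims[j++];
  }
  return out;
}

int main() {
  // e.g. an output of shape [2, 3] with fused_unsqueeze2_axes = {0, -1}
  // becomes [1, 2, 3, 1].
  for (int64_t d : UnsqueezeDims({2, 3}, {0, -1})) std::cout << d << ' ';
  std::cout << '\n';
}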