diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index 13c5f2d9838022b6fde355ab3b6cf58d12e08d4b..29e64f0f35612dba8435236dbc2f0c639f5559ed 100644
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -105,6 +105,7 @@ if(WITH_MKLDNN)
   pass_library(cpu_bfloat16_placement_pass inference DIR mkldnn)
   pass_library(cpu_bfloat16_pass inference DIR mkldnn)
   pass_library(fc_mkldnn_pass inference DIR mkldnn)
+  pass_library(fc_act_mkldnn_fuse_pass inference DIR mkldnn)
   pass_library(cpu_quantize_placement_pass base DIR mkldnn)
   pass_library(cpu_quantize_pass inference DIR mkldnn)
   pass_library(cpu_quantize_squash_pass inference DIR mkldnn)
@@ -155,6 +156,7 @@ if (WITH_MKLDNN)
   cc_test(test_conv_activation_mkldnn_fuse_pass SRCS mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc DEPS conv_activation_mkldnn_fuse_pass)
   cc_test(test_conv_concat_relu_mkldnn_fuse_pass SRCS mkldnn/conv_concat_relu_mkldnn_fuse_pass_tester.cc DEPS conv_concat_relu_mkldnn_fuse_pass)
   cc_test(test_conv_elementwise_add_mkldnn_fuse_pass SRCS mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass)
+  cc_test(test_fc_act_mkldnn_fuse_pass SRCS mkldnn/fc_act_mkldnn_fuse_pass_tester.cc DEPS fc_act_mkldnn_fuse_pass)
   cc_test(test_batch_norm_act_fuse_pass SRCS mkldnn/batch_norm_act_fuse_pass_tester.cc DEPS batch_norm_act_fuse_pass)
   set(TEST_CONV_BN_PASS_DEPS conv_bn_fuse_pass graph_to_program_pass conv_op conv_transpose_op math_function im2col vol2col batch_norm_op gelu_op activation_op elementwise_add_op concat_and_split naive_executor device_context)
   if (WITH_GPU)
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc
index 2a72642b17d23b590ec5a35d7b9680f740f1ec21..a1e70d2be72f25b897c1599895404f118128c754 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.cc
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -1017,6 +1017,23 @@ PDNode *patterns::FCMKLDNN::operator()(paddle::framework::ir::PDNode *x,
   return fc_out_var;
 }
 
+PDNode *patterns::FCActOneDNN::operator()(const std::string &act_type) {
+  auto *fc = pattern->NewNode(fc_repr())->assert_is_op("fc");
+  auto *fc_out = pattern->NewNode(fc_out_repr())
+                     ->assert_is_op_output("fc", "Out")
+                     ->assert_is_op_input(act_type);
+  auto *act =
+      pattern->NewNode(act_repr())->assert_is_op(act_type)->AsIntermediate();
+  auto *act_out = pattern->NewNode(act_out_repr())
+                      ->assert_is_op_output(act_type, "Out")
+                      ->AsOutput();
+
+  fc->LinksTo({fc_out});
+  act->LinksFrom({fc_out}).LinksTo({act_out});
+
+  return act_out;
+}
+
 PDNode *patterns::Embedding::operator()(PDNode *x) {
   x->assert_is_op_input("lookup_table", "Ids");
   auto *lookup_table_op =
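For orientation, the pattern registered above matches the chain sketched below. The act node is marked AsIntermediate, and the fuse handler in fc_act_mkldnn_fuse_pass.cc (further down) also removes the fc_out variable, so after fusion only the fc op and the act_out variable remain:

    (x) --> [fc] --> (fc_out) --> [act] --> (act_out)    matched subgraph
    (x) --> [fc, activation_type=act] --> (act_out)      after fusion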
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h
index a1e7435523c6ce7ddaba2a3e178eeedf7f46264e..f27a41808b502ddcec37775af042d8874890b711 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.h
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -552,6 +552,27 @@ struct FCMKLDNN : public PatternBase {
   PATTERN_DECL_NODE(output);
 };
 
+//
+// \brief   Pattern looking for fc and a directly following activation
+//          operator.
+//
+// \note    Currently only gelu, tanh and sigmoid are supported as an
+//          activation function.
+// Formula: act(fc(x))
+// Op:      fc + act
+struct FCActOneDNN : public PatternBase {
+  FCActOneDNN(PDPattern* pattern, const std::string& name_scope)
+      : PatternBase(pattern, name_scope, "fc_act_onednn") {}
+
+  PDNode* operator()(const std::string& act_type);
+
+  // declare operator node's name
+  PATTERN_DECL_NODE(fc);
+  PATTERN_DECL_NODE(act);
+  PATTERN_DECL_NODE(fc_out);
+  PATTERN_DECL_NODE(act_out);
+};
+
 // Embedding
 struct Embedding : public PatternBase {
   Embedding(PDPattern* pattern, const std::string& name_scope)
diff --git a/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..5fc6f92475e976ae858ec1c2a40ba5226a2c024c
--- /dev/null
+++ b/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.cc
@@ -0,0 +1,100 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h"
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
+#include "paddle/fluid/framework/op_version_registry.h"
+#include "paddle/fluid/platform/enforce.h"
+#include "paddle/fluid/string/pretty_log.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+using string::PrettyLogDetail;
+
+void FuseFCActOneDNNPass::ApplyImpl(Graph *graph) const {
+  std::vector<std::string> act_types = {"gelu", "tanh", "sigmoid"};
+
+  for (std::string act_type : act_types) FuseFCAct(graph, act_type);
+}
+
+void FuseFCActOneDNNPass::FuseFCAct(Graph *graph,
+                                    const std::string &act_type) const {
+  PADDLE_ENFORCE_NOT_NULL(
+      graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
+  FusePassBase::Init("fc_act", graph);
+
+  GraphPatternDetector gpd;
+  patterns::FCActOneDNN fc_act_pattern(gpd.mutable_pattern(), "fc_act");
+  fc_act_pattern(act_type);
+
+  int found_fc_act_count = 0;
+  auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
+                     Graph *g) {
+    VLOG(4) << "Fuse fc with activation op.";
+    // FC output
+    GET_IR_NODE_FROM_SUBGRAPH(fc_out, fc_out, fc_act_pattern);
+    // ACT output
+    GET_IR_NODE_FROM_SUBGRAPH(act_out, act_out, fc_act_pattern);
+    // ops
+    GET_IR_NODE_FROM_SUBGRAPH(fc, fc, fc_act_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(act, act, fc_act_pattern);
+
+    auto *fc_op = fc->Op();
+    auto *act_op = act->Op();
+
+    if (fc_op->HasAttr("use_mkldnn")) {
+      PADDLE_ENFORCE(
+          BOOST_GET_CONST(bool, fc_op->GetAttr("use_mkldnn")),
+          platform::errors::PreconditionNotMet(
+              "The FC+Act fusion may happen only when oneDNN library "
+              "is used."));
+    }
+
+    if (act_type == "gelu" && act_op->HasAttr("approximate")) {
+      bool approximate = BOOST_GET_CONST(bool, act_op->GetAttr("approximate"));
+      std::string type = approximate ? "_tanh" : "_erf";
+      fc_op->SetAttr("activation_type", act_type + type);
+    } else {
+      fc_op->SetAttr("activation_type", act_type);
+    }
+
+    fc_op->SetAttr("use_mkldnn", true);
+
+    fc_op->SetOutput("Out", {act_out->Name()});
+
+    IR_OP_VAR_LINK(fc, act_out);
+    GraphSafeRemoveNodes(g, {act, fc_out});
+    found_fc_act_count++;
+  };
+
+  gpd(graph, handler);
+  AddStatis(found_fc_act_count);
+  PrettyLogDetail("--- fused %d fc with %s activation", found_fc_act_count,
+                  act_type);
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(fc_act_mkldnn_fuse_pass,
+              paddle::framework::ir::FuseFCActOneDNNPass);
+REGISTER_PASS_CAPABILITY(fc_act_mkldnn_fuse_pass)
+    .AddCombination(
+        paddle::framework::compatible::OpVersionComparatorCombination()
+            .LE("fc", 0)
+            .LE("gelu", 0)
+            .LE("sigmoid", 0)
+            .LE("tanh", 0));
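For reference, the activation-name mapping the handler above applies can be written in isolation. This is a minimal sketch (the helper name ResolveActivationType is hypothetical, not part of the patch) of how the matched op type and gelu's "approximate" attribute resolve to the fused "activation_type" value:

    #include <string>

    // Mirrors the branch in the handler above: gelu splits into its two
    // oneDNN variants, while tanh and sigmoid map to themselves.
    std::string ResolveActivationType(const std::string &act_type,
                                      bool has_approximate, bool approximate) {
      if (act_type == "gelu" && has_approximate) {
        return approximate ? "gelu_tanh" : "gelu_erf";
      }
      return act_type;  // plain "gelu", "tanh" or "sigmoid"
    }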
"_tanh" : "_erf"; + fc_op->SetAttr("activation_type", act_type + type); + } else + fc_op->SetAttr("activation_type", act_type); + + fc_op->SetAttr("use_mkldnn", true); + + fc_op->SetOutput("Out", {act_out->Name()}); + + IR_OP_VAR_LINK(fc, act_out); + GraphSafeRemoveNodes(g, {act, fc_out}); + found_fc_act_count++; + }; + + gpd(graph, handler); + AddStatis(found_fc_act_count); + PrettyLogDetail("--- fused %d fc with %s activation", found_fc_act_count, + act_type); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(fc_act_mkldnn_fuse_pass, + paddle::framework::ir::FuseFCActOneDNNPass); +REGISTER_PASS_CAPABILITY(fc_act_mkldnn_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .LE("fc", 0) + .LE("gelu", 0) + .LE("sigmoid", 0) + .LE("tanh", 0)); diff --git a/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..aa2b1c425e73abae71569a7238cc94a0c0a1faf0 --- /dev/null +++ b/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h @@ -0,0 +1,45 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph.h" + +namespace paddle { +namespace framework { +namespace ir { + +/* + * \brief Fuse the FC and activation operators into single OneDNN's + * FC with post-op. + * + * \note Currently only GeLU, sigmoid and tanh are supported as an activation + * function. + */ +class FuseFCActOneDNNPass : public FusePassBase { + public: + virtual ~FuseFCActOneDNNPass() {} + + protected: + void ApplyImpl(ir::Graph *graph) const override; + + void FuseFCAct(ir::Graph *graph, const std::string &act_types) const; +}; + +} // namespace ir +} // namespace framework +} // namespace paddlea diff --git a/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass_tester.cc new file mode 100644 index 0000000000000000000000000000000000000000..634f44a25891c58c779095ff75b6d819f3d03bae --- /dev/null +++ b/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass_tester.cc @@ -0,0 +1,398 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
diff --git a/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass_tester.cc
new file mode 100644
index 0000000000000000000000000000000000000000..634f44a25891c58c779095ff75b6d819f3d03bae
--- /dev/null
+++ b/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass_tester.cc
@@ -0,0 +1,398 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gtest/gtest.h>
+#include <algorithm>
+#include <list>
+#include <map>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "paddle/fluid/framework/ir/graph_traits.h"
+#include "paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h"
+#include "paddle/fluid/framework/op_desc.h"
+#include "paddle/fluid/framework/program_desc.h"
+#include "paddle/fluid/platform/errors.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+// -------------------------- helper functions --------------------------------
+namespace {
+
+using InOutVarNamePair = std::pair<std::string, std::string>;
+using OpTypeCountPair = std::pair<std::string, int>;
+
+///
+/// @brief      Creates the specified operator and sets up its inputs/outputs.
+///
+/// @param      prog          The program descriptor to which we add new op.
+/// @param[in]  op_type_name  The operator type name.
+/// @param[in]  inputs        The vector of input pairs: {input_name, variable
+///                           name}.
+/// @param[in]  outputs       The vector of output pairs: {output_name,
+///                           variable name}.
+/// @param[in]  use_mkldnn    The flag deciding whether or not to set
+///                           'use_mkldnn' attribute.
+///
+/// @return     Returns pointer to the created operator descriptor.
+///
+OpDesc* CreateOp(ProgramDesc* prog, const std::string& op_type_name,
+                 const std::vector<InOutVarNamePair>& inputs,
+                 const std::vector<InOutVarNamePair>& outputs,
+                 bool use_mkldnn = true) {
+  auto op = prog->MutableBlock(0)->AppendOp();
+  op->SetType(op_type_name);
+  op->SetAttr("use_mkldnn", use_mkldnn);
+
+  for (const auto& input : inputs) {
+    op->SetInput(input.first, {input.second});
+  }
+  for (const auto& output : outputs) {
+    op->SetOutput(output.first, {output.second});
+  }
+
+  return op;
+}
+
+///
+/// @brief      Check whether node 'to' is reachable from node 'from' in graph.
+///
+/// @param[in]  graph  The graph we're checking for reachability.
+/// @param[in]  from   The 'from' node name.
+/// @param[in]  to     The 'to' node name.
+///
+/// @return     True if there is a connection between nodes 'from' and 'to'.
+///
+bool TestIsReachable(const Graph& graph, std::string from, std::string to) {
+  auto hash = [](const Node* node) -> std::string {
+    return node->Name() + std::to_string(node->id());
+  };
+
+  auto find_node = [&](const Graph& graph, const std::string& name) -> Node* {
+    for (auto& node : GraphTraits::DFS(graph)) {
+      if (name == hash(&node)) {
+        return &node;
+      }
+    }
+
+    return nullptr;
+  };
+
+  if (from == to) return true;
+
+  std::map<std::string, bool> visited;
+  // Update 'from' and 'to' to the hashed equivalents of the nodes they name,
+  // and mark every node as not yet visited.
+  for (auto& node : GraphTraits::DFS(graph)) {
+    auto hashed = hash(&node);
+    if (node.Name() == from) {
+      from = hashed;
+    }
+    if (node.Name() == to) {
+      to = hashed;
+    }
+    visited[hashed] = false;
+  }
+
+  visited[from] = true;
+
+  std::list<std::string> queue;
+  queue.push_back(from);
+
+  while (!queue.empty()) {
+    auto cur = find_node(graph, queue.front());
+    queue.pop_front();
+    if (cur == nullptr) {
+      return false;
+    }
+
+    for (auto n : cur->outputs) {
+      auto hashed_name = hash(n);
+      if (hashed_name == to) {
+        return true;
+      }
+
+      if (!visited[hashed_name]) {
+        visited[hashed_name] = true;
+        queue.push_back(hashed_name);
+      }
+    }
+  }
+  return false;
+}
+
+///
+/// @brief      Searches through the graph and counts occurrences of the
+///             provided operators.
+///
+/// @param[in]  graph          The graph we search through.
+/// @param[in]  op_type_count  The vector of pairs {op_type_name, op count}.
+///
+/// @note       After going through all graph nodes this function asserts that
+///             the counted number for each requested op is as expected.
+///
+void AssertOpsCount(const Graph& graph,
+                    std::vector<OpTypeCountPair> op_type_count) {
+  for (auto* node : graph.Nodes()) {
+    if (!node->IsOp()) {
+      continue;
+    }
+
+    const std::string op_type_name = node->Op()->Type();
+    auto op_it =
+        std::find_if(std::begin(op_type_count), std::end(op_type_count),
+                     [op_type_name](const OpTypeCountPair& p) {
+                       return op_type_name == p.first;
+                     });
+    if (op_it != std::end(op_type_count)) {
+      op_it->second--;
+    }
+  }
+
+  for (const OpTypeCountPair& p : op_type_count) {
+    EXPECT_EQ(p.second, 0);
+  }
+}
+
+///
+/// @brief      Builds a program descriptor.
+///
+/// @param[in]  transient_vars   The vector of transient variables names.
+/// @param[in]  persistent_vars  The vector of persistent variables names.
+///                              Those will have persistable attribute set to
+///                              true.
+///
+/// @return     The program descriptor object.
+///
+ProgramDesc BuildProgramDesc(const std::vector<std::string>& transient_vars,
+                             const std::vector<std::string>& persistent_vars) {
+  ProgramDesc prog;
+
+  auto add_var_to_prog = [&prog](const std::string& var_name) -> VarDesc* {
+    auto var = prog.MutableBlock(0)->Var(var_name);
+    var->SetType(proto::VarType::LOD_TENSOR);
+    return var;
+  };
+
+  for (const auto& v : transient_vars) {
+    add_var_to_prog(v);
+  }
+
+  for (const auto& v : persistent_vars) {
+    auto* var = add_var_to_prog(v);
+    var->SetPersistable(true);
+  }
+
+  return prog;
+}
+
+///
+/// @brief      Execute pass on provided graph and perform checks.
+///
+/// @param      graph                The graph we run pass on.
+/// @param[in]  from                 The name of a 'starting' node sequence in
+///                                  a graph. This would be used to test for
+///                                  correct node connections.
+/// @param[in]  to                   The name of an 'ending' node sequence in
+///                                  a graph. This would be used to test for
+///                                  correct node connections.
+/// @param[in]  removed_nodes_count  The number of nodes we expect will be
+///                                  removed/fused after pass execution.
+/// @param[in]  added_nodes_count    The number of nodes we expect will be
+///                                  added after pass execution.
+/// +void RunPassAndAssert(Graph* graph, const std::string& from, + const std::string& to, int removed_nodes_count, + int added_nodes_count = 0) { + EXPECT_TRUE(TestIsReachable(*graph, from, to)); + int original_nodes_num = graph->Nodes().size(); + auto pass = PassRegistry::Instance().Get("fc_act_mkldnn_fuse_pass"); + pass->Apply(graph); + int current_nodes_num = graph->Nodes().size(); + + EXPECT_TRUE(TestIsReachable(*graph, from, to)); + EXPECT_EQ(original_nodes_num - removed_nodes_count + added_nodes_count, + current_nodes_num); +} + +} // namespace + +// ------------------------------ Test cases ----------------------------------- + +TEST(FuseFCActOneDNNPass, ThrowUseMkldnn) { + auto prog = BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"}); + CreateOp(&prog, "fc", + { + {"Input", "x"}, {"Weights", "weights"}, {"Bias", "bias"}, + }, + {{"Out", "fc_y"}}, false); + CreateOp(&prog, "gelu", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false); + + Graph graph(prog); + // No fusion in this attribute configuration + constexpr int removed_nodes_count = 0; + + EXPECT_THROW(RunPassAndAssert(&graph, "x", "act_y", removed_nodes_count), + paddle::platform::EnforceNotMet); +} + +TEST(FuseFCActOneDNNPass, FuseWithGeluTanh) { + auto prog = BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"}); + CreateOp(&prog, "fc", + { + {"Input", "x"}, {"Weights", "weights"}, {"Bias", "bias"}, + }, + {{"Out", "fc_y"}}); + auto* act_op = + CreateOp(&prog, "gelu", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false); + act_op->SetAttr("approximate", true); + + Graph graph(prog); + constexpr int removed_nodes_count = 2; + + RunPassAndAssert(&graph, "x", "act_y", removed_nodes_count); + AssertOpsCount(graph, {{"fc", 1}, {"gelu", 0}}); + + for (const auto* node : graph.Nodes()) { + if (node->IsOp() && node->Op()->Type() == "fc") { + const auto* op = node->Op(); + ASSERT_TRUE(op->HasAttr("use_mkldnn")); + EXPECT_TRUE(BOOST_GET_CONST(bool, op->GetAttr("use_mkldnn"))); + ASSERT_TRUE(op->HasAttr("activation_type")); + auto act_type = + BOOST_GET_CONST(std::string, op->GetAttr("activation_type")); + EXPECT_TRUE(act_type.compare("gelu_tanh") == 0); + } + } +} + +TEST(FuseFCActOneDNNPass, FuseWithGeluErf) { + auto prog = BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"}); + CreateOp(&prog, "fc", + { + {"Input", "x"}, {"Weights", "weights"}, {"Bias", "bias"}, + }, + {{"Out", "fc_y"}}); + auto* act_op = + CreateOp(&prog, "gelu", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false); + act_op->SetAttr("approximate", false); + + Graph graph(prog); + constexpr int removed_nodes_count = 2; + + RunPassAndAssert(&graph, "x", "act_y", removed_nodes_count); + AssertOpsCount(graph, {{"fc", 1}, {"gelu", 0}}); + + for (const auto* node : graph.Nodes()) { + if (node->IsOp() && node->Op()->Type() == "fc") { + const auto* op = node->Op(); + ASSERT_TRUE(op->HasAttr("use_mkldnn")); + EXPECT_TRUE(BOOST_GET_CONST(bool, op->GetAttr("use_mkldnn"))); + ASSERT_TRUE(op->HasAttr("activation_type")); + auto act_type = + BOOST_GET_CONST(std::string, op->GetAttr("activation_type")); + EXPECT_TRUE(act_type.compare("gelu_erf") == 0); + } + } +} + +TEST(FuseFCActOneDNNPass, FuseWithGeluAuto) { + auto prog = BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"}); + CreateOp(&prog, "fc", + { + {"Input", "x"}, {"Weights", "weights"}, {"Bias", "bias"}, + }, + {{"Out", "fc_y"}}); + CreateOp(&prog, "gelu", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false); + + Graph graph(prog); + constexpr int removed_nodes_count = 2; + + 
+  RunPassAndAssert(&graph, "x", "act_y", removed_nodes_count);
+  AssertOpsCount(graph, {{"fc", 1}, {"gelu", 0}});
+
+  for (const auto* node : graph.Nodes()) {
+    if (node->IsOp() && node->Op()->Type() == "fc") {
+      const auto* op = node->Op();
+      ASSERT_TRUE(op->HasAttr("use_mkldnn"));
+      EXPECT_TRUE(BOOST_GET_CONST(bool, op->GetAttr("use_mkldnn")));
+      ASSERT_TRUE(op->HasAttr("activation_type"));
+      auto act_type =
+          BOOST_GET_CONST(std::string, op->GetAttr("activation_type"));
+      EXPECT_TRUE(act_type.compare("gelu") == 0);
+    }
+  }
+}
+
+TEST(FuseFCActOneDNNPass, FuseWithTanh) {
+  auto prog = BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"});
+  CreateOp(&prog, "fc",
+           {
+               {"Input", "x"}, {"Weights", "weights"}, {"Bias", "bias"},
+           },
+           {{"Out", "fc_y"}});
+  CreateOp(&prog, "tanh", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false);
+
+  Graph graph(prog);
+  constexpr int removed_nodes_count = 2;
+
+  RunPassAndAssert(&graph, "x", "act_y", removed_nodes_count);
+  AssertOpsCount(graph, {{"fc", 1}, {"tanh", 0}});
+
+  for (const auto* node : graph.Nodes()) {
+    if (node->IsOp() && node->Op()->Type() == "fc") {
+      const auto* op = node->Op();
+      ASSERT_TRUE(op->HasAttr("use_mkldnn"));
+      EXPECT_TRUE(BOOST_GET_CONST(bool, op->GetAttr("use_mkldnn")));
+      ASSERT_TRUE(op->HasAttr("activation_type"));
+      auto act_type =
+          BOOST_GET_CONST(std::string, op->GetAttr("activation_type"));
+      EXPECT_TRUE(act_type.compare("tanh") == 0);
+    }
+  }
+}
+
+TEST(FuseFCActOneDNNPass, FuseWithSigmoid) {
+  auto prog = BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"});
+  CreateOp(&prog, "fc",
+           {
+               {"Input", "x"}, {"Weights", "weights"}, {"Bias", "bias"},
+           },
+           {{"Out", "fc_y"}});
+  CreateOp(&prog, "sigmoid", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false);
+
+  Graph graph(prog);
+  constexpr int removed_nodes_count = 2;
+
+  RunPassAndAssert(&graph, "x", "act_y", removed_nodes_count);
+  AssertOpsCount(graph, {{"fc", 1}, {"sigmoid", 0}});
+
+  for (const auto* node : graph.Nodes()) {
+    if (node->IsOp() && node->Op()->Type() == "fc") {
+      const auto* op = node->Op();
+      ASSERT_TRUE(op->HasAttr("use_mkldnn"));
+      EXPECT_TRUE(BOOST_GET_CONST(bool, op->GetAttr("use_mkldnn")));
+      ASSERT_TRUE(op->HasAttr("activation_type"));
+      auto act_type =
+          BOOST_GET_CONST(std::string, op->GetAttr("activation_type"));
+      EXPECT_TRUE(act_type.compare("sigmoid") == 0);
+    }
+  }
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+USE_PASS(fc_act_mkldnn_fuse_pass);
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index 1448d565661653c2195a0ecda2cb113a52206dcb..deed620aa4d88859a1e2915adbb70343818be349 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -206,7 +206,8 @@ void CpuPassStrategy::EnableMKLDNN() {
           "reshape_transpose_matmul_mkldnn_fuse_pass",  //
           "matmul_transpose_reshape_fuse_pass",         //
           // Disabled due to topology-dependent speed-up
-          // "fc_mkldnn_pass",
+          //"fc_mkldnn_pass",
+          //"fc_act_mkldnn_fuse_pass",
           "batch_norm_act_fuse_pass",
           "mkldnn_inplace_pass",  // This pass should be activated after
                                   // fuses
diff --git a/paddle/fluid/inference/tests/api/analyzer_dam_tester.cc b/paddle/fluid/inference/tests/api/analyzer_dam_tester.cc
index d61c28c30d203acf4dd48e1461a881d61f8ec263..820bbf0701778f29b3431a93d81dfd0b5d2f408d 100644
--- a/paddle/fluid/inference/tests/api/analyzer_dam_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_dam_tester.cc
@@ -206,6 +206,7 @@ void profile(bool use_mkldnn = false) {
                                         "relu", "fc"};
     cfg.SetMKLDNNOp(op_list);
     cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
+    cfg.pass_builder()->AppendPass("fc_act_mkldnn_fuse_pass");
   }
 
   std::vector<std::vector<PaddleTensor>> outputs;
@@ -262,6 +263,7 @@ void compare(bool use_mkldnn = false) {
                                         "relu"};
     cfg.SetMKLDNNOp(op_list);
     cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
+    cfg.pass_builder()->AppendPass("fc_act_mkldnn_fuse_pass");
   }
 
   std::vector<std::vector<PaddleTensor>> input_slots_all;
diff --git a/paddle/fluid/inference/tests/api/analyzer_image_classification_tester.cc b/paddle/fluid/inference/tests/api/analyzer_image_classification_tester.cc
index c6a898dc2f315a67e3693abd73f481b08cac414a..af0a51e4ddbb4b49fbc9e7e72adc91f03279b283 100644
--- a/paddle/fluid/inference/tests/api/analyzer_image_classification_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_image_classification_tester.cc
@@ -50,8 +50,10 @@ void profile(bool use_mkldnn = false) {
   if (use_mkldnn) {
     cfg.EnableMKLDNN();
-    if (!FLAGS_disable_mkldnn_fc)
+    if (!FLAGS_disable_mkldnn_fc) {
       cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
+      cfg.pass_builder()->AppendPass("fc_act_mkldnn_fuse_pass");
+    }
   }
 
   std::vector<std::vector<PaddleTensor>> outputs;
@@ -83,8 +85,10 @@ void compare(bool use_mkldnn = false) {
   SetConfig(&cfg);
   if (use_mkldnn) {
     cfg.EnableMKLDNN();
-    if (!FLAGS_disable_mkldnn_fc)
+    if (!FLAGS_disable_mkldnn_fc) {
       cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
+      cfg.pass_builder()->AppendPass("fc_act_mkldnn_fuse_pass");
+    }
   }
 
   std::vector<std::vector<PaddleTensor>> input_slots_all;
diff --git a/paddle/fluid/inference/tests/api/analyzer_seq_pool1_tester_helper.h b/paddle/fluid/inference/tests/api/analyzer_seq_pool1_tester_helper.h
index 0dac11bc3452d3e3e88d86a76d439dd5b489c9c0..5d7f7c290f6a2f25af4ea95ebae7177c5ee2a27c 100644
--- a/paddle/fluid/inference/tests/api/analyzer_seq_pool1_tester_helper.h
+++ b/paddle/fluid/inference/tests/api/analyzer_seq_pool1_tester_helper.h
@@ -163,6 +163,7 @@ void SetConfig(AnalysisConfig *cfg, bool use_mkldnn = false) {
   if (use_mkldnn) {
     cfg->EnableMKLDNN();
     cfg->pass_builder()->AppendPass("fc_mkldnn_pass");
+    cfg->pass_builder()->AppendPass("fc_act_mkldnn_fuse_pass");
   }
   // Enable seqpool_concat_fuse_pass, disabled by default since it takes much
   // time
diff --git a/paddle/fluid/inference/tests/api/analyzer_transformer_compare_tester.cc b/paddle/fluid/inference/tests/api/analyzer_transformer_compare_tester.cc
index f26ec57103b76500eab99ef11eadc694e2c9b192..65306fd42edab9bc5db235a1ff456dcd9a239289 100644
--- a/paddle/fluid/inference/tests/api/analyzer_transformer_compare_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_transformer_compare_tester.cc
@@ -25,6 +25,7 @@ void compare(bool use_mkldnn = false) {
   if (use_mkldnn) {
     cfg.EnableMKLDNN();
     cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
+    cfg.pass_builder()->AppendPass("fc_act_mkldnn_fuse_pass");
   }
 
   std::vector<std::vector<PaddleTensor>> input_slots_all;
diff --git a/paddle/fluid/inference/tests/api/analyzer_transformer_profile_tester.cc b/paddle/fluid/inference/tests/api/analyzer_transformer_profile_tester.cc
index caeba3277163b2a15183972fb07d315bd951ccde..fc9492a0dfcf43585289e677e76daf2302de0e07 100644
--- a/paddle/fluid/inference/tests/api/analyzer_transformer_profile_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_transformer_profile_tester.cc
@@ -26,6 +26,7 @@ void profile(bool use_mkldnn = false) {
   if (use_mkldnn) {
     cfg.EnableMKLDNN();
     cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
+    cfg.pass_builder()->AppendPass("fc_act_mkldnn_fuse_pass");
   }
 
   std::vector<std::vector<PaddleTensor>> input_slots_all;
diff --git a/paddle/fluid/inference/tests/api/analyzer_vis_tester.cc b/paddle/fluid/inference/tests/api/analyzer_vis_tester.cc
index a2ced21a9ac9ad10c2b067a60597eee9fdff9eeb..faa15fc4f0a178be7b3d217ce8a19832676dcbf6 100644
--- a/paddle/fluid/inference/tests/api/analyzer_vis_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_vis_tester.cc
@@ -86,6 +86,7 @@ void profile(bool use_mkldnn = false) {
   if (use_mkldnn) {
     cfg.EnableMKLDNN();
     cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
+    cfg.pass_builder()->AppendPass("fc_act_mkldnn_fuse_pass");
   }
   // cfg.pass_builder()->TurnOnDebug();
   std::vector<std::vector<PaddleTensor>> outputs;
@@ -136,6 +137,7 @@ void compare(bool use_mkldnn = false) {
   if (use_mkldnn) {
     cfg.EnableMKLDNN();
     cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
+    cfg.pass_builder()->AppendPass("fc_act_mkldnn_fuse_pass");
   }
 
   std::vector<std::vector<PaddleTensor>> input_slots_all;
diff --git a/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc
index 613d193477b60e5960b9116d04fd885314078e8d..89a24cab5f6745e66b77ba02dd24daeae90aa122 100644
--- a/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc
@@ -459,6 +459,36 @@ class FCPrimitiveFactory {
       constexpr float placeholder = 1.0f;  // beta
       post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu,
                                      negative_slope, placeholder);
+    } else if (ctx.Attr<std::string>("activation_type") == "gelu") {
+      constexpr float scale = 1.0f;
+      constexpr float alpha = 0.0f;
+      constexpr float beta = 0.0f;
+      post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_gelu,
+                                     alpha, beta);
+    } else if (ctx.Attr<std::string>("activation_type") == "gelu_tanh") {
+      constexpr float scale = 1.0f;
+      constexpr float alpha = 0.0f;
+      constexpr float beta = 0.0f;
+      post_operations.append_eltwise(
+          scale, mkldnn::algorithm::eltwise_gelu_tanh, alpha, beta);
+    } else if (ctx.Attr<std::string>("activation_type") == "gelu_erf") {
+      constexpr float scale = 1.0f;
+      constexpr float alpha = 0.0f;
+      constexpr float beta = 0.0f;
+      post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_gelu_erf,
+                                     alpha, beta);
+    } else if (ctx.Attr<std::string>("activation_type") == "tanh") {
+      constexpr float scale = 1.0f;
+      constexpr float alpha = 0.0f;
+      constexpr float beta = 0.0f;
+      post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_tanh,
+                                     alpha, beta);
+    } else if (ctx.Attr<std::string>("activation_type") == "sigmoid") {
+      constexpr float scale = 1.0f;
+      constexpr float alpha = 0.0f;
+      constexpr float beta = 0.0f;
+      post_operations.append_eltwise(scale,
+                                     mkldnn::algorithm::eltwise_logistic,
+                                     alpha, beta);
+    }
 
     attributes.set_post_ops(post_operations);
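Every branch above appends the same eltwise post-op with scale = 1.0f and alpha = beta = 0.0f, differing only in the oneDNN algorithm. A condensed, table-driven sketch of the equivalent logic (the helper name AppendActivationPostOp is hypothetical; it assumes the same mkldnn.hpp API the file above uses):

    #include <string>
    #include <unordered_map>
    #include "mkldnn.hpp"

    // Maps the fused "activation_type" attribute to its oneDNN eltwise
    // algorithm and appends it as a post-op; unknown names are ignored.
    void AppendActivationPostOp(const std::string &act,
                                mkldnn::post_ops *post_operations) {
      static const std::unordered_map<std::string, mkldnn::algorithm> kAlgo = {
          {"gelu", mkldnn::algorithm::eltwise_gelu},
          {"gelu_tanh", mkldnn::algorithm::eltwise_gelu_tanh},
          {"gelu_erf", mkldnn::algorithm::eltwise_gelu_erf},
          {"tanh", mkldnn::algorithm::eltwise_tanh},
          {"sigmoid", mkldnn::algorithm::eltwise_logistic}};
      auto it = kAlgo.find(act);
      if (it != kAlgo.end()) {
        post_operations->append_eltwise(1.0f /*scale*/, it->second,
                                        0.0f /*alpha*/, 0.0f /*beta*/);
      }
    }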
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_fc_act_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_fc_act_fuse_pass.py
new file mode 100644
index 0000000000000000000000000000000000000000..28d1a239212e45f2641e21b9f26d937cbec15d11
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_fc_act_fuse_pass.py
@@ -0,0 +1,116 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Test for fusion of fc and activation."""
+from __future__ import print_function
+
+import unittest
+import numpy as np
+
+import paddle.fluid as fluid
+from inference_pass_test import InferencePassTest
+from paddle import enable_static
+from paddle.fluid.core import PassVersionChecker
+
+enable_static()
+
+
+class FCGeluTanhOneDnnFusePassTest(InferencePassTest):
+    def setUp(self):
+        self.set_params()
+        with fluid.program_guard(self.main_program, self.startup_program):
+            data = fluid.data(
+                name="data", shape=[-1, 128, 768], dtype="float32")
+            fc_out = fluid.layers.fc(input=data, size=3072, num_flatten_dims=2)
+            gelu_out = fluid.layers.gelu(fc_out, approximate=True)
+
+        self.feeds = {
+            "data": np.random.random((1, 128, 768)).astype("float32")
+        }
+
+        self.fetch_list = [gelu_out]
+        self.enable_mkldnn = True
+
+    def set_params(self):
+        self.pass_name = "fc_act_mkldnn_fuse_pass"
+
+    def test_check_output(self):
+        self.check_output()
+
+
+class FCGeluErfOneDnnFusePassTest(InferencePassTest):
+    def setUp(self):
+        self.set_params()
+        with fluid.program_guard(self.main_program, self.startup_program):
+            data = fluid.data(
+                name="data", shape=[-1, 128, 768], dtype="float32")
+            fc_out = fluid.layers.fc(input=data, size=3072, num_flatten_dims=2)
+            gelu_out = fluid.layers.gelu(fc_out, approximate=False)
+
+        self.feeds = {
+            "data": np.random.random((1, 128, 768)).astype("float32")
+        }
+
+        self.fetch_list = [gelu_out]
+        self.enable_mkldnn = True
+
+    def set_params(self):
+        self.pass_name = "fc_act_mkldnn_fuse_pass"
+
+    def test_check_output(self):
+        self.check_output()
+        self.assertTrue(PassVersionChecker.IsCompatible(self.pass_name))
+
+
+class FCTanhOneDnnFusePassTest(InferencePassTest):
+    def setUp(self):
+        self.set_params()
+        with fluid.program_guard(self.main_program, self.startup_program):
+            data = fluid.data(
+                name="data", shape=[-1, 128, 768], dtype="float32")
+            fc_out = fluid.layers.fc(input=data, size=3072, num_flatten_dims=2)
+            tanh_out = fluid.layers.tanh(fc_out)
+
+        self.feeds = {
+            "data": np.random.random((1, 128, 768)).astype("float32")
+        }
+
+        self.fetch_list = [tanh_out]
+        self.enable_mkldnn = True
+
+    def set_params(self):
+        self.pass_name = "fc_act_mkldnn_fuse_pass"
+
+    def test_check_output(self):
+        self.check_output()
+        self.assertTrue(PassVersionChecker.IsCompatible(self.pass_name))
+
+
+class FCSigmoidOneDnnFusePassTest(InferencePassTest):
+    def setUp(self):
+        self.set_params()
+        with fluid.program_guard(self.main_program, self.startup_program):
+            data = fluid.data(
+                name="data", shape=[-1, 128, 768], dtype="float32")
+            fc_out = fluid.layers.fc(input=data, size=3072, num_flatten_dims=2)
+            sigmoid_out = fluid.layers.sigmoid(fc_out)
+
+        self.feeds = {
+            "data": np.random.random((1, 128, 768)).astype("float32")
+        }
+
+        self.fetch_list = [sigmoid_out]
+        self.enable_mkldnn = True
+
+    def set_params(self):
+        self.pass_name = "fc_act_mkldnn_fuse_pass"
+
+    def test_check_output(self):
+        self.check_output()
+        self.assertTrue(PassVersionChecker.IsCompatible(self.pass_name))
+
+
+if __name__ == "__main__":
+    unittest.main()