From 7db747d9e88a989fb48be970b687b5479c22f52f Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Mon, 26 Oct 2020 13:12:48 +0100 Subject: [PATCH] oneDNN BatchNorm + Act fusion pass. (#27912) --- paddle/fluid/framework/ir/CMakeLists.txt | 2 + .../framework/ir/graph_pattern_detector.cc | 20 + .../framework/ir/graph_pattern_detector.h | 21 + .../ir/mkldnn/batch_norm_act_fuse_pass.cc | 108 +++++ .../ir/mkldnn/batch_norm_act_fuse_pass.h | 44 ++ .../mkldnn/batch_norm_act_fuse_pass_tester.cc | 382 ++++++++++++++++++ .../inference/api/paddle_pass_builder.cc | 1 + .../test_mkldnn_batch_norm_act_fuse_pass.py | 79 ++++ 8 files changed, 657 insertions(+) create mode 100644 paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass.cc create mode 100644 paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass.h create mode 100644 paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass_tester.cc create mode 100644 python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_batch_norm_act_fuse_pass.py diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 5bb833f613..9415fe6e61 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -110,6 +110,7 @@ if(WITH_MKLDNN) pass_library(cpu_quantize_squash_pass inference DIR mkldnn) pass_library(reshape_transpose_matmul_mkldnn_fuse_pass inference DIR mkldnn) pass_library(matmul_transpose_reshape_fuse_pass inference DIR mkldnn) + pass_library(batch_norm_act_fuse_pass inference DIR mkldnn) endif() cc_library(fuse_bn_act_pass SRCS fuse_bn_act_pass.cc DEPS pass graph_pattern_detector ) @@ -151,6 +152,7 @@ if (WITH_MKLDNN) cc_test(test_conv_activation_mkldnn_fuse_pass SRCS mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc DEPS conv_activation_mkldnn_fuse_pass) cc_test(test_conv_concat_relu_mkldnn_fuse_pass SRCS mkldnn/conv_concat_relu_mkldnn_fuse_pass_tester.cc DEPS conv_concat_relu_mkldnn_fuse_pass) cc_test(test_conv_elementwise_add_mkldnn_fuse_pass SRCS mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass) + cc_test(test_batch_norm_act_fuse_pass SRCS mkldnn/batch_norm_act_fuse_pass_tester.cc DEPS batch_norm_act_fuse_pass) set(TEST_CONV_BN_PASS_DEPS conv_bn_fuse_pass graph_to_program_pass conv_op conv_transpose_op math_function im2col vol2col batch_norm_op gelu_op activation_op elementwise_add_op concat_and_split naive_executor device_context) if (WITH_GPU) set(TEST_CONV_BN_PASS_DEPS ${TEST_CONV_BN_PASS_DEPS} depthwise_conv) diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index ed2863e8bf..3127a3fd8a 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -1188,6 +1188,26 @@ PDNode *patterns::BatchNormActGrad::operator()( return bn_grad; } +PDNode *patterns::BatchNormActOneDNN::operator()(const std::string &act_type) { + auto *bn_x = pattern->NewNode(bn_in_repr()) + ->AsInput() + ->assert_is_op_input("batch_norm", "X"); + auto *bn = pattern->NewNode(batch_norm_repr())->assert_is_op("batch_norm"); + auto *bn_out = pattern->NewNode(bn_out_repr()) + ->assert_is_op_output("batch_norm", "Y") + ->assert_is_op_input(act_type); + auto *act = + pattern->NewNode(act_repr())->assert_is_op(act_type)->AsIntermediate(); + auto *act_out = pattern->NewNode(act_out_repr()) + ->assert_is_op_output(act_type, "Out") + ->AsOutput(); + + bn->LinksFrom({bn_x}).LinksTo({bn_out}); + act->LinksFrom({bn_out}).LinksTo({act_out}); + + return act_out; +} + PDNode *patterns::ElewiseAddAct::operator()( paddle::framework::ir::PDNode *ele_x_var, std::unordered_set act_types) { diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 15f6ea1541..c44c7b4059 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -664,6 +664,27 @@ struct BatchNormActGrad : public PatternBase { PATTERN_DECL_NODE(d_bn_bias); }; +// +// \brief Pattern looking for batch_norm and a directly following activation +// operator. +// +// \note Currently only ReLU is supported as an activation function. +// Formula: act(bn(x)) +// Op: batch_norm + act +struct BatchNormActOneDNN : public PatternBase { + BatchNormActOneDNN(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "bn_act_onednn") {} + + PDNode* operator()(const std::string& act_type); + + // declare operator node's name + PATTERN_DECL_NODE(bn_in); + PATTERN_DECL_NODE(batch_norm); + PATTERN_DECL_NODE(act); + PATTERN_DECL_NODE(bn_out); + PATTERN_DECL_NODE(act_out); +}; + // The following patterns are used to fuse elewise_add and act // formula: act(ele_add(x, y)) // op: elementwise_add + act diff --git a/paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass.cc new file mode 100644 index 0000000000..7e28ccd24a --- /dev/null +++ b/paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass.cc @@ -0,0 +1,108 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/string/pretty_log.h" + +namespace paddle { +namespace framework { +namespace ir { + +using string::PrettyLogDetail; + +void FuseBatchNormActOneDNNPass::ApplyImpl(Graph *graph) const { + std::string act_type("relu"); + FuseBatchNormAct(graph, act_type); +} + +void FuseBatchNormActOneDNNPass::FuseBatchNormAct( + Graph *graph, const std::string &act_type) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::InvalidArgument( + "The input graph of " + "FuseBatchNormActOneDNNPass should not be nullptr.")); + FusePassBase::Init("bn_act", graph); + + GraphPatternDetector gpd; + patterns::BatchNormActOneDNN bn_act_pattern(gpd.mutable_pattern(), "bn_act"); + bn_act_pattern(act_type); + + int found_bn_act_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, + Graph *g) { + VLOG(4) << "Fuse BatchNorm with ReLU activation op."; + // BN output + GET_IR_NODE_FROM_SUBGRAPH(bn_out, bn_out, bn_act_pattern); + // ACT output + GET_IR_NODE_FROM_SUBGRAPH(act_out, act_out, bn_act_pattern); + // ops + GET_IR_NODE_FROM_SUBGRAPH(batch_norm, batch_norm, bn_act_pattern); + GET_IR_NODE_FROM_SUBGRAPH(act, act, bn_act_pattern); + + auto *bn_op = batch_norm->Op(); + + if (bn_op->HasAttr("use_mkldnn")) { + PADDLE_ENFORCE( + BOOST_GET_CONST(bool, bn_op->GetAttr("use_mkldnn")), + platform::errors::PreconditionNotMet( + "The BatchNorm+Act fusion may happen only when oneDNN library " + "is used.")); + } + + if (bn_op->HasAttr("trainable_statistics")) { + PADDLE_ENFORCE( + !BOOST_GET_CONST(bool, bn_op->GetAttr("trainable_statistics")), + platform::errors::PreconditionNotMet( + "The BatchNorm+Act fusion may happen only when mean and variance " + "are not calculated by current batch statistics.")); + } + + if (bn_op->HasAttr("is_test")) { + PADDLE_ENFORCE( + BOOST_GET_CONST(bool, bn_op->GetAttr("is_test")), + platform::errors::PreconditionNotMet( + "The BatchNorm+Act fusion may happen only during inference.")); + } + + bn_op->SetAttr("use_mkldnn", true); + bn_op->SetAttr("is_test", true); + bn_op->SetAttr("fuse_with_relu", true); + bn_op->SetAttr("trainable_statistics", false); + bn_op->SetOutput("Y", {act_out->Name()}); + + IR_OP_VAR_LINK(batch_norm, act_out); + GraphSafeRemoveNodes(g, {act, bn_out}); + found_bn_act_count++; + }; + + gpd(graph, handler); + AddStatis(found_bn_act_count); + PrettyLogDetail("--- fused %d batch norm with relu activation", + found_bn_act_count); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(batch_norm_act_fuse_pass, + paddle::framework::ir::FuseBatchNormActOneDNNPass); +REGISTER_PASS_CAPABILITY(batch_norm_act_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .EQ("batch_norm", 0) + .EQ("relu", 0)); diff --git a/paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass.h new file mode 100644 index 0000000000..843e7e420b --- /dev/null +++ b/paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass.h @@ -0,0 +1,44 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph.h" + +namespace paddle { +namespace framework { +namespace ir { + +/* + * \brief Fuse the BatchNorm and activation operators into single OneDNN's + * BatchNorm with post-op. + * + * \note Currently only ReLU is supported as an activation function. + */ +class FuseBatchNormActOneDNNPass : public FusePassBase { + public: + virtual ~FuseBatchNormActOneDNNPass() {} + + protected: + void ApplyImpl(ir::Graph *graph) const override; + + void FuseBatchNormAct(ir::Graph *graph, const std::string &act_types) const; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass_tester.cc new file mode 100644 index 0000000000..5543d19b91 --- /dev/null +++ b/paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass_tester.cc @@ -0,0 +1,382 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "paddle/fluid/framework/ir/graph_traits.h" +#include "paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass.h" +#include "paddle/fluid/framework/op_desc.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/platform/errors.h" + +namespace paddle { +namespace framework { +namespace ir { + +// -------------------------- helper functions -------------------------------- +namespace { + +using InOutVarNamePair = std::pair; +using OpTypeCountPair = std::pair; + +/// +/// @brief Creates the specified operator and sets up its inputs/outputs. +/// +/// @param prog The program descriptor to which we add new op. +/// @param[in] op_type_name The operator type name. +/// @param[in] inputs The vector of input pairs: {input_name, variable +/// name} +/// @param[in] outputs The vector of output pairs {output_name, variable} +/// @param[in] use_mkldnn The flag deciding whether or not to set +/// 'use_mkldnn' attribute. +/// +/// @return Returns pointer to the created operator descriptor. +/// +OpDesc* CreateOp(ProgramDesc* prog, const std::string& op_type_name, + const std::vector& inputs, + const std::vector& outputs, + bool use_mkldnn = true) { + auto op = prog->MutableBlock(0)->AppendOp(); + op->SetType(op_type_name); + op->SetAttr("use_mkldnn", use_mkldnn); + + for (const auto& input : inputs) { + op->SetInput(input.first, {input.second}); + } + for (const auto& output : outputs) { + op->SetOutput(output.first, {output.second}); + } + + return op; +} + +/// +/// @brief Check whether node 'to' is reachable from node 'from' in graph. +/// +/// @param[in] graph The graph we're checking for reachability. +/// @param[in] from The 'from' node name. +/// @param[in] to The 'to' node name. +/// +/// @return True if there is connection between nodes 'from' and 'to'. +/// +bool TestIsReachable(const Graph& graph, std::string from, std::string to) { + auto hash = [](const Node* node) -> std::string { + return node->Name() + std::to_string(node->id()); + }; + + auto find_node = [&](const Graph& graph, const std::string& name) -> Node* { + for (auto& node : GraphTraits::DFS(graph)) { + if (name == hash(&node)) { + return &node; + } + } + + return nullptr; + }; + + if (from == to) return true; + + std::map visited; + // update the from and to strings to hashed equivs in loop from graph traits + for (auto& node : GraphTraits::DFS(graph)) { + auto hashed = hash(&node); + if (node.Name() == from) { + from = hashed; + } + if (node.Name() == to) { + to = hashed; + } + visited[hashed] = false; + } + + visited[from] = true; + + std::list queue; + queue.push_back(from); + + while (!queue.empty()) { + auto cur = find_node(graph, queue.front()); + queue.pop_front(); + if (cur == nullptr) { + return false; + } + + for (auto n : cur->outputs) { + auto hashed_name = hash(n); + if (hashed_name == to) { + return true; + } + + if (!visited[hashed_name]) { + visited[hashed_name] = true; + queue.push_back(hashed_name); + } + } + } + return false; +} + +/// +/// @brief Search through graph and counts provided operator occurences. +/// +/// @param[in] graph The graph we search through. +/// @param[in] op_type_count The vector of pairs {op_type_name, op count} +/// +/// @note After going through all graph nodes this function asserts +/// whether counted number for each requested op is as expected. +/// +void AssertOpsCount(const Graph& graph, + std::vector op_type_count) { + for (auto* node : graph.Nodes()) { + if (!node->IsOp()) { + continue; + } + + const std::string op_type_name = node->Op()->Type(); + auto op_it = + std::find_if(std::begin(op_type_count), std::end(op_type_count), + [op_type_name](const OpTypeCountPair& p) { + return op_type_name == p.first; + }); + if (op_it != std::end(op_type_count)) { + op_it->second--; + } + } + + for (const OpTypeCountPair& p : op_type_count) { + EXPECT_EQ(p.second, 0); + } +} + +/// +/// @brief Builds a program descriptor. +/// +/// @param[in] transient_vars The vector of transient variables names. +/// @param[in] persistent_vars The vector of persistent variables names. Those +/// will have persistable attribute set to true. +/// +/// @return The program descriptor object. +/// +ProgramDesc BuildProgramDesc(const std::vector& transient_vars, + const std::vector& persistent_vars) { + ProgramDesc prog; + + auto add_var_to_prog = [&prog](const std::string& var_name) -> VarDesc* { + auto var = prog.MutableBlock(0)->Var(var_name); + var->SetType(proto::VarType::LOD_TENSOR); + return var; + }; + + for (const auto& v : transient_vars) { + add_var_to_prog(v); + } + + for (const auto& v : persistent_vars) { + auto* var = add_var_to_prog(v); + var->SetPersistable(true); + } + + return prog; +} + +/// +/// @brief Execute pass on provided graph and perform checks. +/// +/// @param graph The graph we run pass on. +/// @param[in] from The name of a 'starting' node sequence in a +/// graph. This would be used to test for +/// correct node connections. +/// @param[in] to The name of a 'ending' node sequence in a +/// graph. This would be used to test for +/// correct node connections. +/// @param[in] removed_nodes_count The number of nodes we expect will be +/// removed/fused after pass execution. +/// @param[in] added_nodes_count The number of nodes we expect will be +/// added after pass execution. +/// +void RunPassAndAssert(Graph* graph, const std::string& from, + const std::string& to, int removed_nodes_count, + int added_nodes_count = 0) { + EXPECT_TRUE(TestIsReachable(*graph, from, to)); + int original_nodes_num = graph->Nodes().size(); + auto pass = PassRegistry::Instance().Get("batch_norm_act_fuse_pass"); + pass->Apply(graph); + int current_nodes_num = graph->Nodes().size(); + + EXPECT_TRUE(TestIsReachable(*graph, from, to)); + EXPECT_EQ(original_nodes_num - removed_nodes_count + added_nodes_count, + current_nodes_num); +} + +void SetBatchNormAttrs(OpDesc* bn_op, bool is_test = true, + bool trainable_stats = true) { + bn_op->SetAttr("is_test", is_test); + bn_op->SetAttr("trainable_statistics", trainable_stats); + bn_op->SetAttr("fuse_with_relu", false); +} + +} // namespace + +// ------------------------------ Test cases ----------------------------------- + +// The below test cases are distinguished by whether following attributes have +// true or false value: +// - is_test +// - trainable_statistics +// The test case name would have only attributes with true value in its name. + +TEST(FuseBatchNormActOneDNNPass, ThrowIsTestTrainableStats) { + auto prog = BuildProgramDesc( + {"x", "m", "v", "bn_y", "act_y", "m_out", "var_out", "sm", "sv"}, + {"scale", "bias"}); + auto* bn_op = CreateOp(&prog, "batch_norm", {{"X", "x"}, + {"Scale", "scale"}, + {"Bias", "bias"}, + {"Mean", "m"}, + {"Variance", "v"}}, + {{"Y", "bn_y"}, + {"MeanOut", "m_out"}, + {"VarianceOut", "var_out"}, + {"SavedMean", "sm"}, + {"SavedVariance", "sv"}}); + SetBatchNormAttrs(bn_op, true, true); + CreateOp(&prog, "relu", {{"X", "bn_y"}}, {{"Out", "act_y"}}, false); + + Graph graph(prog); + // No fusion in this attribute configuration + constexpr int removed_nodes_count = 0; + + EXPECT_THROW(RunPassAndAssert(&graph, "x", "act_y", removed_nodes_count), + paddle::platform::EnforceNotMet); +} + +TEST(FuseBatchNormActOneDNNPass, FuseIsTest) { + auto prog = + BuildProgramDesc({"x", "m", "v", "bn_y", "act_y"}, {"scale", "bias"}); + auto* bn_op = CreateOp(&prog, "batch_norm", {{"X", "x"}, + {"Scale", "scale"}, + {"Bias", "bias"}, + {"Mean", "m"}, + {"Variance", "v"}}, + {{"Y", "bn_y"}}); + SetBatchNormAttrs(bn_op, true, false); + CreateOp(&prog, "relu", {{"X", "bn_y"}}, {{"Out", "act_y"}}, false); + + Graph graph(prog); + constexpr int removed_nodes_count = 2; + + RunPassAndAssert(&graph, "x", "act_y", removed_nodes_count); + AssertOpsCount(graph, {{"batch_norm", 1}, {"relu", 0}}); + + for (const auto* node : graph.Nodes()) { + if (node->IsOp() && node->Op()->Type() == "batch_norm") { + const auto* op = node->Op(); + ASSERT_TRUE(op->HasAttr("use_mkldnn")); + EXPECT_TRUE(BOOST_GET_CONST(bool, op->GetAttr("use_mkldnn"))); + ASSERT_TRUE(op->HasAttr("fuse_with_relu")); + EXPECT_TRUE(BOOST_GET_CONST(bool, op->GetAttr("fuse_with_relu"))); + ASSERT_TRUE(op->HasAttr("trainable_statistics")); + EXPECT_FALSE(BOOST_GET_CONST(bool, op->GetAttr("trainable_statistics"))); + } + } +} + +TEST(FuseBatchNormActOneDNNPass, ThrowTrainableStats) { + auto prog = BuildProgramDesc( + {"x", "m", "v", "bn_y", "act_y", "m_out", "var_out", "sm", "sv"}, + {"scale", "bias"}); + auto* bn_op = CreateOp(&prog, "batch_norm", {{"X", "x"}, + {"Scale", "scale"}, + {"Bias", "bias"}, + {"Mean", "m"}, + {"Variance", "v"}}, + {{"Y", "bn_y"}, + {"MeanOut", "m_out"}, + {"VarianceOut", "var_out"}, + {"SavedMean", "sm"}, + {"SavedVariance", "sv"}}); + SetBatchNormAttrs(bn_op, false, true); + CreateOp(&prog, "relu", {{"X", "bn_y"}}, {{"Out", "act_y"}}, false); + + Graph graph(prog); + // No fusion in this attribute configuration + constexpr int removed_nodes_count = 0; + + EXPECT_THROW(RunPassAndAssert(&graph, "x", "act_y", removed_nodes_count), + paddle::platform::EnforceNotMet); +} + +TEST(FuseBatchNormActOneDNNPass, AllAttrsFalse) { + auto prog = BuildProgramDesc( + {"x", "m", "v", "bn_y", "act_y", "m_out", "var_out", "sm", "sv"}, + {"scale", "bias"}); + auto* bn_op = CreateOp(&prog, "batch_norm", {{"X", "x"}, + {"Scale", "scale"}, + {"Bias", "bias"}, + {"Mean", "m"}, + {"Variance", "v"}}, + {{"Y", "bn_y"}, + {"MeanOut", "m_out"}, + {"VarianceOut", "var_out"}, + {"SavedMean", "sm"}, + {"SavedVariance", "sv"}}); + SetBatchNormAttrs(bn_op, false, false); + CreateOp(&prog, "relu", {{"X", "bn_y"}}, {{"Out", "act_y"}}, false); + + Graph graph(prog); + // No fusion in this attribute configuration + constexpr int removed_nodes_count = 0; + + EXPECT_THROW(RunPassAndAssert(&graph, "x", "act_y", removed_nodes_count), + paddle::platform::EnforceNotMet); +} + +TEST(FuseBatchNormActOneDNNPass, ThrowUseMkldnn) { + auto prog = BuildProgramDesc( + {"x", "m", "v", "bn_y", "act_y", "m_out", "var_out", "sm", "sv"}, + {"scale", "bias"}); + auto* bn_op = CreateOp(&prog, "batch_norm", {{"X", "x"}, + {"Scale", "scale"}, + {"Bias", "bias"}, + {"Mean", "m"}, + {"Variance", "v"}}, + {{"Y", "bn_y"}, + {"MeanOut", "m_out"}, + {"VarianceOut", "var_out"}, + {"SavedMean", "sm"}, + {"SavedVariance", "sv"}}, + false); + SetBatchNormAttrs(bn_op, false, false); + CreateOp(&prog, "relu", {{"X", "bn_y"}}, {{"Out", "act_y"}}, false); + + Graph graph(prog); + // No fusion in this attribute configuration + constexpr int removed_nodes_count = 0; + + EXPECT_THROW(RunPassAndAssert(&graph, "x", "act_y", removed_nodes_count), + paddle::platform::EnforceNotMet); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +USE_PASS(batch_norm_act_fuse_pass); diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 19f52422b4..1448d56566 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -207,6 +207,7 @@ void CpuPassStrategy::EnableMKLDNN() { "matmul_transpose_reshape_fuse_pass", // // Disabled due to topology-dependent speed-up // "fc_mkldnn_pass", + "batch_norm_act_fuse_pass", "mkldnn_inplace_pass", // This pass should be activated after // fuses })) { diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_batch_norm_act_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_batch_norm_act_fuse_pass.py new file mode 100644 index 0000000000..c119cbec88 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_batch_norm_act_fuse_pass.py @@ -0,0 +1,79 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test for fusion of batch norm and activation.""" +from __future__ import print_function + +import unittest +import numpy as np + +import paddle.fluid as fluid +from inference_pass_test import InferencePassTest +from paddle import enable_static +from paddle.fluid.core import PassVersionChecker + +enable_static() + + +class BnReluOneDnnFusePassTest(InferencePassTest): + def setUp(self): + self.set_params() + with fluid.program_guard(self.main_program, self.startup_program): + data = fluid.data( + name="data", shape=[-1, 3, 100, 100], dtype="float32") + bn_out = fluid.layers.batch_norm( + input=data, is_test=True, use_global_stats=self.global_stats) + relu_out = fluid.layers.relu(bn_out) + + self.feeds = { + "data": np.random.random((1, 3, 100, 100)).astype("float32") + } + self.fetch_list = [relu_out] + self.enable_mkldnn = True + + def set_params(self): + self.global_stats = False + self.pass_name = "batch_norm_act_fuse_pass" + + def test_check_output(self): + self.check_output() + self.assertTrue(PassVersionChecker.IsCompatible(self.pass_name)) + + +class BnReluGlobalStatsOneDnnFusePassTest(InferencePassTest): + def setUp(self): + self.set_params() + with fluid.program_guard(self.main_program, self.startup_program): + data = fluid.data( + name="data", shape=[-1, 3, 100, 100], dtype="float32") + bn_out = fluid.layers.batch_norm( + input=data, is_test=True, use_global_stats=self.global_stats) + relu_out = fluid.layers.relu(bn_out) + + self.feeds = { + "data": np.random.random((1, 3, 100, 100)).astype("float32") + } + self.fetch_list = [relu_out] + self.enable_mkldnn = True + + def set_params(self): + self.global_stats = True + self.pass_name = "batch_norm_act_fuse_pass" + + def test_check_output(self): + self.check_output() + self.assertTrue(PassVersionChecker.IsCompatible(self.pass_name)) + + +if __name__ == "__main__": + unittest.main() -- GitLab