diff --git a/paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass.cc index fe758a6b3a7af43dafebc75a81a07226703c9ce2..9f6cd8992dcb9b5100ff88112117567cbd7c478c 100644 --- a/paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass.cc @@ -67,12 +67,6 @@ FuseBatchNormActOneDNNPass::FuseBatchNormActOneDNNPass() { .AddAttr("epsilon") .IsNumGE(0.0f) .IsNumLE(0.001f) - .End() - .AddAttr("trainable_statistics") - .IsBoolEQ(false) - .End() - .AddAttr("is_test") - .IsBoolEQ(true) .End(); AddOpCompat(OpCompat("relu")) @@ -114,21 +108,19 @@ void FuseBatchNormActOneDNNPass::FuseBatchNormAct( GET_IR_NODE_FROM_SUBGRAPH(act, act, bn_act_pattern); auto *bn_op = batch_norm->Op(); - if (bn_op->HasAttr("use_mkldnn")) { + if (bn_op->HasAttr("trainable_statistics")) { PADDLE_ENFORCE( - BOOST_GET_CONST(bool, bn_op->GetAttr("use_mkldnn")), + !BOOST_GET_CONST(bool, bn_op->GetAttr("trainable_statistics")), platform::errors::PreconditionNotMet( - "The BatchNorm+Act fusion may happen only when oneDNN library " - "is used.")); + "The BatchNorm+Act fusion may happen only when mean and variance " + "are not calculated by current batch statistics.")); } - auto *act_op = act->Op(); - if (act_op->HasAttr("use_mkldnn")) { + if (bn_op->HasAttr("is_test")) { PADDLE_ENFORCE( - BOOST_GET_CONST(bool, bn_op->GetAttr("use_mkldnn")), + BOOST_GET_CONST(bool, bn_op->GetAttr("is_test")), platform::errors::PreconditionNotMet( - "The BatchNorm+Act fusion may happen only when oneDNN library " - "is used.")); + "The BatchNorm+Act fusion may happen only during inference.")); } bn_op->SetAttr("use_mkldnn", true); diff --git a/paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass_tester.cc index 26828fdd94b730ca4f7edceb5ab70c5fc7d4083b..e13d44ac23222187a82753a027dd3585f423800b 100644 --- a/paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass_tester.cc @@ -65,9 +65,9 @@ TEST(FuseBatchNormActOneDNNPass, ThrowIsTestTrainableStats) { // No fusion in this attribute configuration constexpr int removed_nodes_count = 0; - EXPECT_TRUE(test::RunPassAndAssert(&graph, "batch_norm_act_fuse_pass", "x", - "act_y", removed_nodes_count)); - EXPECT_TRUE(test::AssertOpsCount(graph, {{"batch_norm", 1}, {"relu", 1}})); + EXPECT_THROW(test::RunPassAndAssert(&graph, "batch_norm_act_fuse_pass", "x", + "act_y", removed_nodes_count), + paddle::platform::EnforceNotMet); } TEST(FuseBatchNormActOneDNNPass, FuseIsTest) { @@ -123,9 +123,9 @@ TEST(FuseBatchNormActOneDNNPass, ThrowTrainableStats) { // No fusion in this attribute configuration constexpr int removed_nodes_count = 0; - EXPECT_TRUE(test::RunPassAndAssert(&graph, "batch_norm_act_fuse_pass", "x", - "act_y", removed_nodes_count)); - EXPECT_TRUE(test::AssertOpsCount(graph, {{"batch_norm", 1}, {"relu", 1}})); + EXPECT_THROW(test::RunPassAndAssert(&graph, "batch_norm_act_fuse_pass", "x", + "act_y", removed_nodes_count), + paddle::platform::EnforceNotMet); } TEST(FuseBatchNormActOneDNNPass, AllAttrsFalse) { @@ -149,9 +149,9 @@ TEST(FuseBatchNormActOneDNNPass, AllAttrsFalse) { // No fusion in this attribute configuration constexpr int removed_nodes_count = 0; - EXPECT_TRUE(test::RunPassAndAssert(&graph, "batch_norm_act_fuse_pass", "x", - "act_y", removed_nodes_count)); - EXPECT_TRUE(test::AssertOpsCount(graph, {{"batch_norm", 1}, {"relu", 1}})); + EXPECT_THROW(test::RunPassAndAssert(&graph, "batch_norm_act_fuse_pass", "x", + "act_y", removed_nodes_count), + paddle::platform::EnforceNotMet); } TEST(FuseBatchNormActOneDNNPass, ThrowUseMkldnn) { @@ -176,9 +176,9 @@ TEST(FuseBatchNormActOneDNNPass, ThrowUseMkldnn) { // No fusion in this attribute configuration constexpr int removed_nodes_count = 0; - EXPECT_TRUE(test::RunPassAndAssert(&graph, "batch_norm_act_fuse_pass", "x", - "act_y", removed_nodes_count)); - EXPECT_TRUE(test::AssertOpsCount(graph, {{"batch_norm", 1}, {"relu", 1}})); + EXPECT_THROW(test::RunPassAndAssert(&graph, "batch_norm_act_fuse_pass", "x", + "act_y", removed_nodes_count), + paddle::platform::EnforceNotMet); } TEST(FuseBatchNormActOneDNNPass, pass_op_version_check) {