From 04527ee348da5aca998688c433c02752cbb2cab4 Mon Sep 17 00:00:00 2001 From: baoachun <962571062@qq.com> Date: Mon, 27 Dec 2021 10:23:22 +0800 Subject: [PATCH] add attr check for infer in batch_norm_act mkldnn fuse pass (#38443) --- .../ir/mkldnn/batch_norm_act_fuse_pass.cc | 22 ++++++----------- .../mkldnn/batch_norm_act_fuse_pass_tester.cc | 24 +++++++++---------- 2 files changed, 19 insertions(+), 27 deletions(-) diff --git a/paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass.cc index fe758a6b3a..9f6cd8992d 100644 --- a/paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass.cc @@ -67,12 +67,6 @@ FuseBatchNormActOneDNNPass::FuseBatchNormActOneDNNPass() { .AddAttr("epsilon") .IsNumGE(0.0f) .IsNumLE(0.001f) - .End() - .AddAttr("trainable_statistics") - .IsBoolEQ(false) - .End() - .AddAttr("is_test") - .IsBoolEQ(true) .End(); AddOpCompat(OpCompat("relu")) @@ -114,21 +108,19 @@ void FuseBatchNormActOneDNNPass::FuseBatchNormAct( GET_IR_NODE_FROM_SUBGRAPH(act, act, bn_act_pattern); auto *bn_op = batch_norm->Op(); - if (bn_op->HasAttr("use_mkldnn")) { + if (bn_op->HasAttr("trainable_statistics")) { PADDLE_ENFORCE( - BOOST_GET_CONST(bool, bn_op->GetAttr("use_mkldnn")), + !BOOST_GET_CONST(bool, bn_op->GetAttr("trainable_statistics")), platform::errors::PreconditionNotMet( - "The BatchNorm+Act fusion may happen only when oneDNN library " - "is used.")); + "The BatchNorm+Act fusion may happen only when mean and variance " + "are not calculated by current batch statistics.")); } - auto *act_op = act->Op(); - if (act_op->HasAttr("use_mkldnn")) { + if (bn_op->HasAttr("is_test")) { PADDLE_ENFORCE( - BOOST_GET_CONST(bool, bn_op->GetAttr("use_mkldnn")), + BOOST_GET_CONST(bool, bn_op->GetAttr("is_test")), platform::errors::PreconditionNotMet( - "The BatchNorm+Act fusion may happen only when oneDNN library " - "is used.")); + "The BatchNorm+Act fusion may happen only during inference.")); } bn_op->SetAttr("use_mkldnn", true); diff --git a/paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass_tester.cc index 26828fdd94..e13d44ac23 100644 --- a/paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass_tester.cc @@ -65,9 +65,9 @@ TEST(FuseBatchNormActOneDNNPass, ThrowIsTestTrainableStats) { // No fusion in this attribute configuration constexpr int removed_nodes_count = 0; - EXPECT_TRUE(test::RunPassAndAssert(&graph, "batch_norm_act_fuse_pass", "x", - "act_y", removed_nodes_count)); - EXPECT_TRUE(test::AssertOpsCount(graph, {{"batch_norm", 1}, {"relu", 1}})); + EXPECT_THROW(test::RunPassAndAssert(&graph, "batch_norm_act_fuse_pass", "x", + "act_y", removed_nodes_count), + paddle::platform::EnforceNotMet); } TEST(FuseBatchNormActOneDNNPass, FuseIsTest) { @@ -123,9 +123,9 @@ TEST(FuseBatchNormActOneDNNPass, ThrowTrainableStats) { // No fusion in this attribute configuration constexpr int removed_nodes_count = 0; - EXPECT_TRUE(test::RunPassAndAssert(&graph, "batch_norm_act_fuse_pass", "x", - "act_y", removed_nodes_count)); - EXPECT_TRUE(test::AssertOpsCount(graph, {{"batch_norm", 1}, {"relu", 1}})); + EXPECT_THROW(test::RunPassAndAssert(&graph, "batch_norm_act_fuse_pass", "x", + "act_y", removed_nodes_count), + paddle::platform::EnforceNotMet); } TEST(FuseBatchNormActOneDNNPass, AllAttrsFalse) { @@ -149,9 +149,9 @@ TEST(FuseBatchNormActOneDNNPass, AllAttrsFalse) { // No fusion in this attribute configuration constexpr int removed_nodes_count = 0; - EXPECT_TRUE(test::RunPassAndAssert(&graph, "batch_norm_act_fuse_pass", "x", - "act_y", removed_nodes_count)); - EXPECT_TRUE(test::AssertOpsCount(graph, {{"batch_norm", 1}, {"relu", 1}})); + EXPECT_THROW(test::RunPassAndAssert(&graph, "batch_norm_act_fuse_pass", "x", + "act_y", removed_nodes_count), + paddle::platform::EnforceNotMet); } TEST(FuseBatchNormActOneDNNPass, ThrowUseMkldnn) { @@ -176,9 +176,9 @@ TEST(FuseBatchNormActOneDNNPass, ThrowUseMkldnn) { // No fusion in this attribute configuration constexpr int removed_nodes_count = 0; - EXPECT_TRUE(test::RunPassAndAssert(&graph, "batch_norm_act_fuse_pass", "x", - "act_y", removed_nodes_count)); - EXPECT_TRUE(test::AssertOpsCount(graph, {{"batch_norm", 1}, {"relu", 1}})); + EXPECT_THROW(test::RunPassAndAssert(&graph, "batch_norm_act_fuse_pass", "x", + "act_y", removed_nodes_count), + paddle::platform::EnforceNotMet); } TEST(FuseBatchNormActOneDNNPass, pass_op_version_check) { -- GitLab