Unverified · Commit 04527ee3 authored by baoachun, committed by GitHub

add attr check for infer in batch_norm_act mkldnn fuse pass (#38443)

Parent 37022482
@@ -67,12 +67,6 @@ FuseBatchNormActOneDNNPass::FuseBatchNormActOneDNNPass() {
       .AddAttr("epsilon")
       .IsNumGE(0.0f)
       .IsNumLE(0.001f)
-      .End()
-      .AddAttr("trainable_statistics")
-      .IsBoolEQ(false)
-      .End()
-      .AddAttr("is_test")
-      .IsBoolEQ(true)
       .End();
 
   AddOpCompat(OpCompat("relu"))
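Note: the two constraints removed above (trainable_statistics == false, is_test == true) are not lost. The next hunk re-checks them inside FuseBatchNormAct as PADDLE_ENFORCE preconditions, so an incompatible batch_norm op now fails loudly at pass time rather than being silently skipped by the declarative compat check (the updated tests below confirm this). A condensed sketch of the combined guard follows that hunk.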
@@ -114,21 +108,19 @@ void FuseBatchNormActOneDNNPass::FuseBatchNormAct(
     GET_IR_NODE_FROM_SUBGRAPH(act, act, bn_act_pattern);
 
     auto *bn_op = batch_norm->Op();
-    if (bn_op->HasAttr("use_mkldnn")) {
+    if (bn_op->HasAttr("trainable_statistics")) {
       PADDLE_ENFORCE(
-          BOOST_GET_CONST(bool, bn_op->GetAttr("use_mkldnn")),
+          !BOOST_GET_CONST(bool, bn_op->GetAttr("trainable_statistics")),
           platform::errors::PreconditionNotMet(
-              "The BatchNorm+Act fusion may happen only when oneDNN library "
-              "is used."));
+              "The BatchNorm+Act fusion may happen only when mean and variance "
+              "are not calculated by current batch statistics."));
     }
 
-    auto *act_op = act->Op();
-    if (act_op->HasAttr("use_mkldnn")) {
+    if (bn_op->HasAttr("is_test")) {
       PADDLE_ENFORCE(
-          BOOST_GET_CONST(bool, bn_op->GetAttr("use_mkldnn")),
+          BOOST_GET_CONST(bool, bn_op->GetAttr("is_test")),
           platform::errors::PreconditionNotMet(
-              "The BatchNorm+Act fusion may happen only when oneDNN library "
-              "is used."));
+              "The BatchNorm+Act fusion may happen only during inference."));
     }
 
     bn_op->SetAttr("use_mkldnn", true);
......
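The following condensed sketch is not part of the commit; it merely restates the two guards above as a single check, assuming the same surrounding context (the bn_op descriptor and Paddle's PADDLE_ENFORCE / BOOST_GET_CONST macros), to make the intended invariant explicit: fusion is valid only for an inference-mode batch_norm with fixed statistics.

// Illustrative only -- assumes bn_op, PADDLE_ENFORCE and BOOST_GET_CONST
// from the surrounding pass code; not the literal change above.
if (bn_op->HasAttr("is_test") && bn_op->HasAttr("trainable_statistics")) {
  const bool is_test = BOOST_GET_CONST(bool, bn_op->GetAttr("is_test"));
  const bool trainable_stats =
      BOOST_GET_CONST(bool, bn_op->GetAttr("trainable_statistics"));
  PADDLE_ENFORCE(is_test && !trainable_stats,
                 platform::errors::PreconditionNotMet(
                     "BatchNorm+Act fusion requires is_test == true and "
                     "trainable_statistics == false."));
}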
@@ -65,9 +65,9 @@ TEST(FuseBatchNormActOneDNNPass, ThrowIsTestTrainableStats) {
   // No fusion in this attribute configuration
   constexpr int removed_nodes_count = 0;
 
-  EXPECT_TRUE(test::RunPassAndAssert(&graph, "batch_norm_act_fuse_pass", "x",
-                                     "act_y", removed_nodes_count));
-  EXPECT_TRUE(test::AssertOpsCount(graph, {{"batch_norm", 1}, {"relu", 1}}));
+  EXPECT_THROW(test::RunPassAndAssert(&graph, "batch_norm_act_fuse_pass", "x",
+                                      "act_y", removed_nodes_count),
+               paddle::platform::EnforceNotMet);
 }
 
 TEST(FuseBatchNormActOneDNNPass, FuseIsTest) {
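The same EXPECT_TRUE to EXPECT_THROW switch is applied to the remaining Throw* tests below: graphs that previously just kept their batch_norm and relu ops untouched are now expected to raise paddle::platform::EnforceNotMet. As a minimal, self-contained illustration of the gtest idiom (using std::runtime_error as a stand-in for EnforceNotMet and a hypothetical guard function):

#include <stdexcept>
#include "gtest/gtest.h"

// Hypothetical stand-in for a pass guarded by an inference-only precondition.
void RunGuardedFusion(bool is_test) {
  if (!is_test) throw std::runtime_error("fusion requires inference mode");
}

TEST(ExpectThrowIdiom, ViolatedPreconditionThrows) {
  // EXPECT_THROW passes only if the statement throws the named exception type.
  EXPECT_THROW(RunGuardedFusion(/*is_test=*/false), std::runtime_error);
  EXPECT_NO_THROW(RunGuardedFusion(/*is_test=*/true));
}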
@@ -123,9 +123,9 @@ TEST(FuseBatchNormActOneDNNPass, ThrowTrainableStats) {
   // No fusion in this attribute configuration
   constexpr int removed_nodes_count = 0;
 
-  EXPECT_TRUE(test::RunPassAndAssert(&graph, "batch_norm_act_fuse_pass", "x",
-                                     "act_y", removed_nodes_count));
-  EXPECT_TRUE(test::AssertOpsCount(graph, {{"batch_norm", 1}, {"relu", 1}}));
+  EXPECT_THROW(test::RunPassAndAssert(&graph, "batch_norm_act_fuse_pass", "x",
+                                      "act_y", removed_nodes_count),
+               paddle::platform::EnforceNotMet);
 }
 
 TEST(FuseBatchNormActOneDNNPass, AllAttrsFalse) {
@@ -149,9 +149,9 @@ TEST(FuseBatchNormActOneDNNPass, AllAttrsFalse) {
   // No fusion in this attribute configuration
   constexpr int removed_nodes_count = 0;
 
-  EXPECT_TRUE(test::RunPassAndAssert(&graph, "batch_norm_act_fuse_pass", "x",
-                                     "act_y", removed_nodes_count));
-  EXPECT_TRUE(test::AssertOpsCount(graph, {{"batch_norm", 1}, {"relu", 1}}));
+  EXPECT_THROW(test::RunPassAndAssert(&graph, "batch_norm_act_fuse_pass", "x",
+                                      "act_y", removed_nodes_count),
+               paddle::platform::EnforceNotMet);
 }
 
 TEST(FuseBatchNormActOneDNNPass, ThrowUseMkldnn) {
@@ -176,9 +176,9 @@ TEST(FuseBatchNormActOneDNNPass, ThrowUseMkldnn) {
   // No fusion in this attribute configuration
   constexpr int removed_nodes_count = 0;
 
-  EXPECT_TRUE(test::RunPassAndAssert(&graph, "batch_norm_act_fuse_pass", "x",
-                                     "act_y", removed_nodes_count));
-  EXPECT_TRUE(test::AssertOpsCount(graph, {{"batch_norm", 1}, {"relu", 1}}));
+  EXPECT_THROW(test::RunPassAndAssert(&graph, "batch_norm_act_fuse_pass", "x",
+                                      "act_y", removed_nodes_count),
+               paddle::platform::EnforceNotMet);
 }
 
 TEST(FuseBatchNormActOneDNNPass, pass_op_version_check) {
......