Unverified commit 04527ee3, authored by baoachun, committed by GitHub

add attr check for infer in batch_norm_act mkldnn fuse pass (#38443)

Parent 37022482
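In short: the `trainable_statistics` and `is_test` constraints are dropped from the pass's OpCompat("batch_norm") definition and re-expressed as explicit runtime checks inside FuseBatchNormAct, so a batch_norm op configured for training now makes the pass raise PreconditionNotMet instead of being skipped by the compatibility check; the unit tests are updated accordingly to expect paddle::platform::EnforceNotMet. The minimal stand-alone sketch below mirrors that check-if-present-then-enforce pattern; AttrMap and the local EnforceNotMet type are hypothetical stand-ins for Paddle's OpDesc attributes and PADDLE_ENFORCE, so it illustrates the pattern rather than the actual pass code.

#include <iostream>
#include <map>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for the boolean attributes stored on an OpDesc.
using AttrMap = std::map<std::string, bool>;

// Hypothetical stand-in for the exception raised by PADDLE_ENFORCE.
struct EnforceNotMet : std::runtime_error {
  using std::runtime_error::runtime_error;
};

// Same guard pattern as the pass: only enforce an attribute's value when the
// op actually carries that attribute.
void CheckBatchNormAttrsForInference(const AttrMap &bn_attrs) {
  auto it = bn_attrs.find("trainable_statistics");
  if (it != bn_attrs.end() && it->second) {
    throw EnforceNotMet(
        "The BatchNorm+Act fusion may happen only when mean and variance "
        "are not calculated by current batch statistics.");
  }
  it = bn_attrs.find("is_test");
  if (it != bn_attrs.end() && !it->second) {
    throw EnforceNotMet(
        "The BatchNorm+Act fusion may happen only during inference.");
  }
}

int main() {
  // Inference-style attributes: both checks pass and fusion may proceed.
  CheckBatchNormAttrsForInference(
      {{"is_test", true}, {"trainable_statistics", false}});

  // Training-style attributes: the check throws, which is what the updated
  // tests assert via EXPECT_THROW(..., EnforceNotMet).
  try {
    CheckBatchNormAttrsForInference({{"is_test", false}});
  } catch (const EnforceNotMet &e) {
    std::cout << "rejected: " << e.what() << "\n";
  }
  return 0;
}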
@@ -67,12 +67,6 @@ FuseBatchNormActOneDNNPass::FuseBatchNormActOneDNNPass() {
       .AddAttr("epsilon")
       .IsNumGE(0.0f)
       .IsNumLE(0.001f)
-      .End()
-      .AddAttr("trainable_statistics")
-      .IsBoolEQ(false)
-      .End()
-      .AddAttr("is_test")
-      .IsBoolEQ(true)
       .End();
 
   AddOpCompat(OpCompat("relu"))
@@ -114,21 +108,19 @@ void FuseBatchNormActOneDNNPass::FuseBatchNormAct(
     GET_IR_NODE_FROM_SUBGRAPH(act, act, bn_act_pattern);
 
     auto *bn_op = batch_norm->Op();
-    if (bn_op->HasAttr("use_mkldnn")) {
+    if (bn_op->HasAttr("trainable_statistics")) {
       PADDLE_ENFORCE(
-          BOOST_GET_CONST(bool, bn_op->GetAttr("use_mkldnn")),
+          !BOOST_GET_CONST(bool, bn_op->GetAttr("trainable_statistics")),
           platform::errors::PreconditionNotMet(
-              "The BatchNorm+Act fusion may happen only when oneDNN library "
-              "is used."));
+              "The BatchNorm+Act fusion may happen only when mean and variance "
+              "are not calculated by current batch statistics."));
     }
 
-    auto *act_op = act->Op();
-    if (act_op->HasAttr("use_mkldnn")) {
+    if (bn_op->HasAttr("is_test")) {
       PADDLE_ENFORCE(
-          BOOST_GET_CONST(bool, bn_op->GetAttr("use_mkldnn")),
+          BOOST_GET_CONST(bool, bn_op->GetAttr("is_test")),
           platform::errors::PreconditionNotMet(
-              "The BatchNorm+Act fusion may happen only when oneDNN library "
-              "is used."));
+              "The BatchNorm+Act fusion may happen only during inference."));
     }
 
     bn_op->SetAttr("use_mkldnn", true);
@@ -65,9 +65,9 @@ TEST(FuseBatchNormActOneDNNPass, ThrowIsTestTrainableStats) {
   // No fusion in this attribute configuration
   constexpr int removed_nodes_count = 0;
-  EXPECT_TRUE(test::RunPassAndAssert(&graph, "batch_norm_act_fuse_pass", "x",
-                                     "act_y", removed_nodes_count));
-  EXPECT_TRUE(test::AssertOpsCount(graph, {{"batch_norm", 1}, {"relu", 1}}));
+  EXPECT_THROW(test::RunPassAndAssert(&graph, "batch_norm_act_fuse_pass", "x",
+                                      "act_y", removed_nodes_count),
+               paddle::platform::EnforceNotMet);
 }
 
 TEST(FuseBatchNormActOneDNNPass, FuseIsTest) {
@@ -123,9 +123,9 @@ TEST(FuseBatchNormActOneDNNPass, ThrowTrainableStats) {
   // No fusion in this attribute configuration
   constexpr int removed_nodes_count = 0;
-  EXPECT_TRUE(test::RunPassAndAssert(&graph, "batch_norm_act_fuse_pass", "x",
-                                     "act_y", removed_nodes_count));
-  EXPECT_TRUE(test::AssertOpsCount(graph, {{"batch_norm", 1}, {"relu", 1}}));
+  EXPECT_THROW(test::RunPassAndAssert(&graph, "batch_norm_act_fuse_pass", "x",
+                                      "act_y", removed_nodes_count),
+               paddle::platform::EnforceNotMet);
 }
 
 TEST(FuseBatchNormActOneDNNPass, AllAttrsFalse) {
@@ -149,9 +149,9 @@ TEST(FuseBatchNormActOneDNNPass, AllAttrsFalse) {
   // No fusion in this attribute configuration
   constexpr int removed_nodes_count = 0;
-  EXPECT_TRUE(test::RunPassAndAssert(&graph, "batch_norm_act_fuse_pass", "x",
-                                     "act_y", removed_nodes_count));
-  EXPECT_TRUE(test::AssertOpsCount(graph, {{"batch_norm", 1}, {"relu", 1}}));
+  EXPECT_THROW(test::RunPassAndAssert(&graph, "batch_norm_act_fuse_pass", "x",
+                                      "act_y", removed_nodes_count),
+               paddle::platform::EnforceNotMet);
 }
 
 TEST(FuseBatchNormActOneDNNPass, ThrowUseMkldnn) {
@@ -176,9 +176,9 @@ TEST(FuseBatchNormActOneDNNPass, ThrowUseMkldnn) {
   // No fusion in this attribute configuration
   constexpr int removed_nodes_count = 0;
-  EXPECT_TRUE(test::RunPassAndAssert(&graph, "batch_norm_act_fuse_pass", "x",
-                                     "act_y", removed_nodes_count));
-  EXPECT_TRUE(test::AssertOpsCount(graph, {{"batch_norm", 1}, {"relu", 1}}));
+  EXPECT_THROW(test::RunPassAndAssert(&graph, "batch_norm_act_fuse_pass", "x",
+                                      "act_y", removed_nodes_count),
+               paddle::platform::EnforceNotMet);
 }
 
 TEST(FuseBatchNormActOneDNNPass, pass_op_version_check) {