From 667b6617864c65ef1f2c31d938c26b778fff5ae8 Mon Sep 17 00:00:00 2001 From: Sylwester Fraczek Date: Mon, 24 Sep 2018 10:32:44 +0200 Subject: [PATCH] updated the test --- .../ir/conv_relu_mkldnn_fuse_pass_tester.cc | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass_tester.cc index 82b5fa18860..9dd780ec89a 100644 --- a/paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass_tester.cc @@ -85,16 +85,13 @@ TEST(ConvReLUFusePass, basic) { for (auto* node : graph->Nodes()) { if (node->IsOp() && node->Op()->Type() == "conv2d") { - if (node->Op()->HasAttr("use_mkldnn")) { - bool use_mkldnn = boost::get(node->Op()->GetAttr("use_mkldnn")); - if (use_mkldnn) { - if (node->Op()->HasAttr("fuse_relu")) { - bool fuse_relu = boost::get(node->Op()->GetAttr("fuse_relu")); - if (fuse_relu) { - ++conv_relu_count; - } - } - } + auto* op = node->Op(); + ASSERT_TRUE(op->HasAttr("use_mkldnn")); + EXPECT_TRUE(boost::get(op->GetAttr("use_mkldnn"))); + ASSERT_TRUE(op->HasAttr("fuse_relu")); + bool fuse_relu = boost::get(op->GetAttr("fuse_relu")); + if (fuse_relu) { + ++conv_relu_count; } } } -- GitLab