From c5ffad126c7c0e80cecd274d281267d2bad370d2 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Sat, 16 Jan 2021 12:11:57 +0100 Subject: [PATCH] [oneDNN] Refactor fuse pass helper functions to one place. (#30460) * Move pass tester helper functions to single common place. * Use helper functions in two more fuse pass tests. --- paddle/fluid/framework/ir/CMakeLists.txt | 7 +- .../mkldnn/batch_norm_act_fuse_pass_tester.cc | 345 ++++------------- ...elementwise_add_mkldnn_fuse_pass_tester.cc | 308 +++++----------- .../mkldnn/fc_act_mkldnn_fuse_pass_tester.cc | 346 ++++-------------- .../framework/ir/mkldnn/pass_test_util.cc | 174 +++++++++ .../framework/ir/mkldnn/pass_test_util.h | 119 ++++++ 6 files changed, 556 insertions(+), 743 deletions(-) create mode 100644 paddle/fluid/framework/ir/mkldnn/pass_test_util.cc create mode 100644 paddle/fluid/framework/ir/mkldnn/pass_test_util.h diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 201c1db9c5..ee25f16fde 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -158,13 +158,14 @@ if(NOT WIN32) cc_test(test_sync_batch_norm_pass SRCS sync_batch_norm_pass_tester.cc DEPS sync_batch_norm_pass) endif() if (WITH_MKLDNN) + cc_library(pass_test_util SRCS mkldnn/pass_test_util.cc DEPS graph pass) cc_test(test_depthwise_conv_mkldnn_pass SRCS mkldnn/depthwise_conv_mkldnn_pass_tester.cc DEPS depthwise_conv_mkldnn_pass) cc_test(test_conv_bias_mkldnn_fuse_pass SRCS mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc DEPS conv_bias_mkldnn_fuse_pass naive_executor) cc_test(test_conv_activation_mkldnn_fuse_pass SRCS mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc DEPS conv_activation_mkldnn_fuse_pass) cc_test(test_conv_concat_relu_mkldnn_fuse_pass SRCS mkldnn/conv_concat_relu_mkldnn_fuse_pass_tester.cc DEPS conv_concat_relu_mkldnn_fuse_pass) - cc_test(test_conv_elementwise_add_mkldnn_fuse_pass SRCS mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass) - cc_test(test_fc_act_mkldnn_fuse_pass SRCS mkldnn/fc_act_mkldnn_fuse_pass_tester.cc DEPS fc_act_mkldnn_fuse_pass) - cc_test(test_batch_norm_act_fuse_pass SRCS mkldnn/batch_norm_act_fuse_pass_tester.cc DEPS batch_norm_act_fuse_pass) + cc_test(test_conv_elementwise_add_mkldnn_fuse_pass SRCS mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass pass_test_util) + cc_test(test_fc_act_mkldnn_fuse_pass SRCS mkldnn/fc_act_mkldnn_fuse_pass_tester.cc DEPS fc_act_mkldnn_fuse_pass pass_test_util) + cc_test(test_batch_norm_act_fuse_pass SRCS mkldnn/batch_norm_act_fuse_pass_tester.cc DEPS batch_norm_act_fuse_pass pass_test_util) set(TEST_CONV_BN_PASS_DEPS conv_bn_fuse_pass graph_to_program_pass conv_op conv_transpose_op math_function im2col vol2col batch_norm_op gelu_op activation_op elementwise_add_op concat_and_split naive_executor device_context) if (WITH_GPU) set(TEST_CONV_BN_PASS_DEPS ${TEST_CONV_BN_PASS_DEPS} depthwise_conv) diff --git a/paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass_tester.cc index 5543d19b91..c8a4d94fe2 100644 --- a/paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass_tester.cc @@ -13,17 +13,11 @@ // limitations under the License. #include -#include -#include -#include -#include -#include -#include -#include -#include "paddle/fluid/framework/ir/graph_traits.h" #include "paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass.h" +#include "paddle/fluid/framework/ir/mkldnn/pass_test_util.h" #include "paddle/fluid/framework/op_desc.h" +#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/platform/errors.h" @@ -31,209 +25,15 @@ namespace paddle { namespace framework { namespace ir { -// -------------------------- helper functions -------------------------------- namespace { -using InOutVarNamePair = std::pair; -using OpTypeCountPair = std::pair; - -/// -/// @brief Creates the specified operator and sets up its inputs/outputs. -/// -/// @param prog The program descriptor to which we add new op. -/// @param[in] op_type_name The operator type name. -/// @param[in] inputs The vector of input pairs: {input_name, variable -/// name} -/// @param[in] outputs The vector of output pairs {output_name, variable} -/// @param[in] use_mkldnn The flag deciding whether or not to set -/// 'use_mkldnn' attribute. -/// -/// @return Returns pointer to the created operator descriptor. -/// -OpDesc* CreateOp(ProgramDesc* prog, const std::string& op_type_name, - const std::vector& inputs, - const std::vector& outputs, - bool use_mkldnn = true) { - auto op = prog->MutableBlock(0)->AppendOp(); - op->SetType(op_type_name); - op->SetAttr("use_mkldnn", use_mkldnn); - - for (const auto& input : inputs) { - op->SetInput(input.first, {input.second}); - } - for (const auto& output : outputs) { - op->SetOutput(output.first, {output.second}); - } - - return op; -} - -/// -/// @brief Check whether node 'to' is reachable from node 'from' in graph. -/// -/// @param[in] graph The graph we're checking for reachability. -/// @param[in] from The 'from' node name. -/// @param[in] to The 'to' node name. -/// -/// @return True if there is connection between nodes 'from' and 'to'. -/// -bool TestIsReachable(const Graph& graph, std::string from, std::string to) { - auto hash = [](const Node* node) -> std::string { - return node->Name() + std::to_string(node->id()); - }; - - auto find_node = [&](const Graph& graph, const std::string& name) -> Node* { - for (auto& node : GraphTraits::DFS(graph)) { - if (name == hash(&node)) { - return &node; - } - } - - return nullptr; - }; - - if (from == to) return true; - - std::map visited; - // update the from and to strings to hashed equivs in loop from graph traits - for (auto& node : GraphTraits::DFS(graph)) { - auto hashed = hash(&node); - if (node.Name() == from) { - from = hashed; - } - if (node.Name() == to) { - to = hashed; - } - visited[hashed] = false; - } - - visited[from] = true; - - std::list queue; - queue.push_back(from); - - while (!queue.empty()) { - auto cur = find_node(graph, queue.front()); - queue.pop_front(); - if (cur == nullptr) { - return false; - } - - for (auto n : cur->outputs) { - auto hashed_name = hash(n); - if (hashed_name == to) { - return true; - } - - if (!visited[hashed_name]) { - visited[hashed_name] = true; - queue.push_back(hashed_name); - } - } - } - return false; -} - -/// -/// @brief Search through graph and counts provided operator occurences. -/// -/// @param[in] graph The graph we search through. -/// @param[in] op_type_count The vector of pairs {op_type_name, op count} -/// -/// @note After going through all graph nodes this function asserts -/// whether counted number for each requested op is as expected. -/// -void AssertOpsCount(const Graph& graph, - std::vector op_type_count) { - for (auto* node : graph.Nodes()) { - if (!node->IsOp()) { - continue; - } - - const std::string op_type_name = node->Op()->Type(); - auto op_it = - std::find_if(std::begin(op_type_count), std::end(op_type_count), - [op_type_name](const OpTypeCountPair& p) { - return op_type_name == p.first; - }); - if (op_it != std::end(op_type_count)) { - op_it->second--; - } - } - - for (const OpTypeCountPair& p : op_type_count) { - EXPECT_EQ(p.second, 0); - } -} - -/// -/// @brief Builds a program descriptor. -/// -/// @param[in] transient_vars The vector of transient variables names. -/// @param[in] persistent_vars The vector of persistent variables names. Those -/// will have persistable attribute set to true. -/// -/// @return The program descriptor object. -/// -ProgramDesc BuildProgramDesc(const std::vector& transient_vars, - const std::vector& persistent_vars) { - ProgramDesc prog; - - auto add_var_to_prog = [&prog](const std::string& var_name) -> VarDesc* { - auto var = prog.MutableBlock(0)->Var(var_name); - var->SetType(proto::VarType::LOD_TENSOR); - return var; - }; - - for (const auto& v : transient_vars) { - add_var_to_prog(v); - } - - for (const auto& v : persistent_vars) { - auto* var = add_var_to_prog(v); - var->SetPersistable(true); - } - - return prog; -} - -/// -/// @brief Execute pass on provided graph and perform checks. -/// -/// @param graph The graph we run pass on. -/// @param[in] from The name of a 'starting' node sequence in a -/// graph. This would be used to test for -/// correct node connections. -/// @param[in] to The name of a 'ending' node sequence in a -/// graph. This would be used to test for -/// correct node connections. -/// @param[in] removed_nodes_count The number of nodes we expect will be -/// removed/fused after pass execution. -/// @param[in] added_nodes_count The number of nodes we expect will be -/// added after pass execution. -/// -void RunPassAndAssert(Graph* graph, const std::string& from, - const std::string& to, int removed_nodes_count, - int added_nodes_count = 0) { - EXPECT_TRUE(TestIsReachable(*graph, from, to)); - int original_nodes_num = graph->Nodes().size(); - auto pass = PassRegistry::Instance().Get("batch_norm_act_fuse_pass"); - pass->Apply(graph); - int current_nodes_num = graph->Nodes().size(); - - EXPECT_TRUE(TestIsReachable(*graph, from, to)); - EXPECT_EQ(original_nodes_num - removed_nodes_count + added_nodes_count, - current_nodes_num); -} - void SetBatchNormAttrs(OpDesc* bn_op, bool is_test = true, bool trainable_stats = true) { bn_op->SetAttr("is_test", is_test); bn_op->SetAttr("trainable_statistics", trainable_stats); bn_op->SetAttr("fuse_with_relu", false); } - -} // namespace +} // ------------------------------ Test cases ----------------------------------- @@ -244,47 +44,49 @@ void SetBatchNormAttrs(OpDesc* bn_op, bool is_test = true, // The test case name would have only attributes with true value in its name. TEST(FuseBatchNormActOneDNNPass, ThrowIsTestTrainableStats) { - auto prog = BuildProgramDesc( + auto prog = test::BuildProgramDesc( {"x", "m", "v", "bn_y", "act_y", "m_out", "var_out", "sm", "sv"}, {"scale", "bias"}); - auto* bn_op = CreateOp(&prog, "batch_norm", {{"X", "x"}, - {"Scale", "scale"}, - {"Bias", "bias"}, - {"Mean", "m"}, - {"Variance", "v"}}, - {{"Y", "bn_y"}, - {"MeanOut", "m_out"}, - {"VarianceOut", "var_out"}, - {"SavedMean", "sm"}, - {"SavedVariance", "sv"}}); + auto* bn_op = test::CreateOp(&prog, "batch_norm", {{"X", "x"}, + {"Scale", "scale"}, + {"Bias", "bias"}, + {"Mean", "m"}, + {"Variance", "v"}}, + {{"Y", "bn_y"}, + {"MeanOut", "m_out"}, + {"VarianceOut", "var_out"}, + {"SavedMean", "sm"}, + {"SavedVariance", "sv"}}); SetBatchNormAttrs(bn_op, true, true); - CreateOp(&prog, "relu", {{"X", "bn_y"}}, {{"Out", "act_y"}}, false); + test::CreateOp(&prog, "relu", {{"X", "bn_y"}}, {{"Out", "act_y"}}, false); Graph graph(prog); // No fusion in this attribute configuration constexpr int removed_nodes_count = 0; - EXPECT_THROW(RunPassAndAssert(&graph, "x", "act_y", removed_nodes_count), + EXPECT_THROW(test::RunPassAndAssert(&graph, "batch_norm_act_fuse_pass", "x", + "act_y", removed_nodes_count), paddle::platform::EnforceNotMet); } TEST(FuseBatchNormActOneDNNPass, FuseIsTest) { - auto prog = - BuildProgramDesc({"x", "m", "v", "bn_y", "act_y"}, {"scale", "bias"}); - auto* bn_op = CreateOp(&prog, "batch_norm", {{"X", "x"}, - {"Scale", "scale"}, - {"Bias", "bias"}, - {"Mean", "m"}, - {"Variance", "v"}}, - {{"Y", "bn_y"}}); + auto prog = test::BuildProgramDesc({"x", "m", "v", "bn_y", "act_y"}, + {"scale", "bias"}); + auto* bn_op = test::CreateOp(&prog, "batch_norm", {{"X", "x"}, + {"Scale", "scale"}, + {"Bias", "bias"}, + {"Mean", "m"}, + {"Variance", "v"}}, + {{"Y", "bn_y"}}); SetBatchNormAttrs(bn_op, true, false); - CreateOp(&prog, "relu", {{"X", "bn_y"}}, {{"Out", "act_y"}}, false); + test::CreateOp(&prog, "relu", {{"X", "bn_y"}}, {{"Out", "act_y"}}, false); Graph graph(prog); constexpr int removed_nodes_count = 2; - RunPassAndAssert(&graph, "x", "act_y", removed_nodes_count); - AssertOpsCount(graph, {{"batch_norm", 1}, {"relu", 0}}); + EXPECT_TRUE(test::RunPassAndAssert(&graph, "batch_norm_act_fuse_pass", "x", + "act_y", removed_nodes_count)); + EXPECT_TRUE(test::AssertOpsCount(graph, {{"batch_norm", 1}, {"relu", 0}})); for (const auto* node : graph.Nodes()) { if (node->IsOp() && node->Op()->Type() == "batch_norm") { @@ -300,81 +102,90 @@ TEST(FuseBatchNormActOneDNNPass, FuseIsTest) { } TEST(FuseBatchNormActOneDNNPass, ThrowTrainableStats) { - auto prog = BuildProgramDesc( + auto prog = test::BuildProgramDesc( {"x", "m", "v", "bn_y", "act_y", "m_out", "var_out", "sm", "sv"}, {"scale", "bias"}); - auto* bn_op = CreateOp(&prog, "batch_norm", {{"X", "x"}, - {"Scale", "scale"}, - {"Bias", "bias"}, - {"Mean", "m"}, - {"Variance", "v"}}, - {{"Y", "bn_y"}, - {"MeanOut", "m_out"}, - {"VarianceOut", "var_out"}, - {"SavedMean", "sm"}, - {"SavedVariance", "sv"}}); + auto* bn_op = test::CreateOp(&prog, "batch_norm", {{"X", "x"}, + {"Scale", "scale"}, + {"Bias", "bias"}, + {"Mean", "m"}, + {"Variance", "v"}}, + {{"Y", "bn_y"}, + {"MeanOut", "m_out"}, + {"VarianceOut", "var_out"}, + {"SavedMean", "sm"}, + {"SavedVariance", "sv"}}); SetBatchNormAttrs(bn_op, false, true); - CreateOp(&prog, "relu", {{"X", "bn_y"}}, {{"Out", "act_y"}}, false); + test::CreateOp(&prog, "relu", {{"X", "bn_y"}}, {{"Out", "act_y"}}, false); Graph graph(prog); // No fusion in this attribute configuration constexpr int removed_nodes_count = 0; - EXPECT_THROW(RunPassAndAssert(&graph, "x", "act_y", removed_nodes_count), + EXPECT_THROW(test::RunPassAndAssert(&graph, "batch_norm_act_fuse_pass", "x", + "act_y", removed_nodes_count), paddle::platform::EnforceNotMet); } TEST(FuseBatchNormActOneDNNPass, AllAttrsFalse) { - auto prog = BuildProgramDesc( + auto prog = test::BuildProgramDesc( {"x", "m", "v", "bn_y", "act_y", "m_out", "var_out", "sm", "sv"}, {"scale", "bias"}); - auto* bn_op = CreateOp(&prog, "batch_norm", {{"X", "x"}, - {"Scale", "scale"}, - {"Bias", "bias"}, - {"Mean", "m"}, - {"Variance", "v"}}, - {{"Y", "bn_y"}, - {"MeanOut", "m_out"}, - {"VarianceOut", "var_out"}, - {"SavedMean", "sm"}, - {"SavedVariance", "sv"}}); + auto* bn_op = test::CreateOp(&prog, "batch_norm", {{"X", "x"}, + {"Scale", "scale"}, + {"Bias", "bias"}, + {"Mean", "m"}, + {"Variance", "v"}}, + {{"Y", "bn_y"}, + {"MeanOut", "m_out"}, + {"VarianceOut", "var_out"}, + {"SavedMean", "sm"}, + {"SavedVariance", "sv"}}); SetBatchNormAttrs(bn_op, false, false); - CreateOp(&prog, "relu", {{"X", "bn_y"}}, {{"Out", "act_y"}}, false); + test::CreateOp(&prog, "relu", {{"X", "bn_y"}}, {{"Out", "act_y"}}, false); Graph graph(prog); // No fusion in this attribute configuration constexpr int removed_nodes_count = 0; - EXPECT_THROW(RunPassAndAssert(&graph, "x", "act_y", removed_nodes_count), + EXPECT_THROW(test::RunPassAndAssert(&graph, "batch_norm_act_fuse_pass", "x", + "act_y", removed_nodes_count), paddle::platform::EnforceNotMet); } TEST(FuseBatchNormActOneDNNPass, ThrowUseMkldnn) { - auto prog = BuildProgramDesc( + auto prog = test::BuildProgramDesc( {"x", "m", "v", "bn_y", "act_y", "m_out", "var_out", "sm", "sv"}, {"scale", "bias"}); - auto* bn_op = CreateOp(&prog, "batch_norm", {{"X", "x"}, - {"Scale", "scale"}, - {"Bias", "bias"}, - {"Mean", "m"}, - {"Variance", "v"}}, - {{"Y", "bn_y"}, - {"MeanOut", "m_out"}, - {"VarianceOut", "var_out"}, - {"SavedMean", "sm"}, - {"SavedVariance", "sv"}}, - false); + auto* bn_op = test::CreateOp(&prog, "batch_norm", {{"X", "x"}, + {"Scale", "scale"}, + {"Bias", "bias"}, + {"Mean", "m"}, + {"Variance", "v"}}, + {{"Y", "bn_y"}, + {"MeanOut", "m_out"}, + {"VarianceOut", "var_out"}, + {"SavedMean", "sm"}, + {"SavedVariance", "sv"}}, + false); SetBatchNormAttrs(bn_op, false, false); - CreateOp(&prog, "relu", {{"X", "bn_y"}}, {{"Out", "act_y"}}, false); + test::CreateOp(&prog, "relu", {{"X", "bn_y"}}, {{"Out", "act_y"}}, false); Graph graph(prog); // No fusion in this attribute configuration constexpr int removed_nodes_count = 0; - EXPECT_THROW(RunPassAndAssert(&graph, "x", "act_y", removed_nodes_count), + EXPECT_THROW(test::RunPassAndAssert(&graph, "batch_norm_act_fuse_pass", "x", + "act_y", removed_nodes_count), paddle::platform::EnforceNotMet); } +TEST(FuseBatchNormActOneDNNPass, pass_op_version_check) { + ASSERT_TRUE( + paddle::framework::compatible::PassVersionCheckerRegistrar::GetInstance() + .IsPassCompatible("batch_norm_act_fuse_pass")); +} + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc index fd4910fc8e..35b40ec471 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc @@ -13,259 +13,151 @@ // limitations under the License. #include -#include -#include "paddle/fluid/framework/ir/graph_traits.h" #include "paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h" +#include "paddle/fluid/framework/ir/mkldnn/pass_test_util.h" #include "paddle/fluid/framework/op_version_registry.h" namespace paddle { namespace framework { namespace ir { -namespace { constexpr int nodes_removed = 3; constexpr int nodes_added = 1; -void SetOp(ProgramDesc* prog, const std::string& type, - const std::vector>& inputs, - const std::pair& output) { - auto op = prog->MutableBlock(0)->AppendOp(); - op->SetType(type); - op->SetAttr("use_mkldnn", true); - - for (const auto& input : inputs) { - op->SetInput(input.first, {input.second}); - } - - op->SetOutput(output.first, {output.second}); -} - -struct TestIsReachable { - using func = std::function; - - auto operator()(const std::unique_ptr& graph) -> func { - auto hash = [](const Node* node) -> std::string { - return node->Name() + std::to_string(node->id()); - }; - - auto find_node = [&](const std::unique_ptr& graph, - const std::string& name) -> Node* { - for (auto& node : GraphTraits::DFS(*graph)) { - if (name == hash(&node)) { - return &node; - } - } - - return nullptr; - }; - - // update the from and to strings to hashed equivs in loop from graph traits - return [&](std::string from, std::string to) -> bool { - if (from == to) return true; - - std::map visited; - - for (auto& node : GraphTraits::DFS(*graph)) { - auto hashed = hash(&node); - if (node.Name() == from) from = hashed; - if (node.Name() == to) to = hashed; - visited[hashed] = false; - } - - visited[from] = true; - - std::list queue; - queue.push_back(from); - - while (!queue.empty()) { - auto cur = find_node(graph, queue.front()); - queue.pop_front(); - if (cur == nullptr) return false; - - for (auto n : cur->outputs) { - auto hashed_name = hash(n); - if (hashed_name == to) return true; - - if (!visited[hashed_name]) { - visited[hashed_name] = true; - queue.push_back(hashed_name); - } - } - } - return false; - }; - } -}; - -void AssertOpsCount(const std::unique_ptr& graph, - int expected_conv_count, - int expected_elementwise_add_count = 0) { - int conv_count = 0; - int elementwise_add_count = 0; - - for (auto* node : graph->Nodes()) { - if (node->IsOp() && node->Op()->Type() == "conv2d") { - ++conv_count; - } - if (node->IsOp() && node->Op()->Type() == "elementwise_add") { - ++elementwise_add_count; - } - } - EXPECT_EQ(conv_count, expected_conv_count); - EXPECT_EQ(elementwise_add_count, expected_elementwise_add_count); -} - -ProgramDesc BuildProgramDesc(const std::vector& transient_vars, - const std::vector& persistent_vars) { - ProgramDesc prog; - - auto add_var_to_prog = [&prog](const std::string& var_name) -> VarDesc* { - auto var = prog.MutableBlock(0)->Var(var_name); - var->SetType(proto::VarType::LOD_TENSOR); - - return var; - }; - - for (const auto& v : transient_vars) { - add_var_to_prog(v); - } - - for (const auto& v : persistent_vars) { - auto var = add_var_to_prog(v); - var->SetPersistable(true); - } - - return prog; -} - -void RunPassAndAssert(ProgramDesc* prog, const std::string& from, - const std::string& to, int expected_conv_num) { - std::unique_ptr graph(new ir::Graph(*prog)); - - TestIsReachable is_reachable; - EXPECT_TRUE(is_reachable(graph)(from, to)); - - auto pass = - PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass"); - int original_nodes_num = graph->Nodes().size(); - graph.reset(pass->Apply(graph.release())); - int current_nodes_num = graph->Nodes().size(); - - EXPECT_TRUE(is_reachable(graph)(from, to)); - - EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added, - current_nodes_num); - - AssertOpsCount(graph, expected_conv_num); -} -} // namespace - TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionAsYWithElementwiseAddRelu) { - auto prog = BuildProgramDesc({"a", "b", "c", "d", "e"}, {"bias", "weights"}); - - SetOp(&prog, "sigmoid", {{"X", "a"}}, {"Out", "b"}); - SetOp(&prog, "conv2d", - {{"Input", "b"}, {"Bias", "bias"}, {"Filter", "weights"}}, - {"Output", "c"}); - - SetOp(&prog, "elementwise_add", {{"X", "a"}, {"Y", "c"}}, {"Out", "d"}); - SetOp(&prog, "relu", {{"X", "d"}}, {"Out", "e"}); - - RunPassAndAssert(&prog, "a", "relu", 1); + auto prog = + test::BuildProgramDesc({"a", "b", "c", "d", "e"}, {"bias", "weights"}); + + test::CreateOp(&prog, "sigmoid", {{"X", "a"}}, {{"Out", "b"}}); + test::CreateOp(&prog, "conv2d", + {{"Input", "b"}, {"Bias", "bias"}, {"Filter", "weights"}}, + {{"Output", "c"}}); + test::CreateOp(&prog, "elementwise_add", {{"X", "a"}, {"Y", "c"}}, + {{"Out", "d"}}); + test::CreateOp(&prog, "relu", {{"X", "d"}}, {{"Out", "e"}}); + + Graph graph(prog); + + EXPECT_TRUE(test::RunPassAndAssert(&graph, + "conv_elementwise_add_mkldnn_fuse_pass", + "a", "relu", nodes_removed, nodes_added)); + EXPECT_TRUE( + test::AssertOpsCount(graph, {{"conv2d", 1}, {"elementwise_add", 0}})); } TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionProjectionAsYWithElementwiseAddRelu) { - auto prog = BuildProgramDesc({"a", "b", "c", "d", "e", "f"}, - {"bias", "weights", "bias2", "weights2"}); + auto prog = test::BuildProgramDesc({"a", "b", "c", "d", "e", "f"}, + {"bias", "weights", "bias2", "weights2"}); - SetOp(&prog, "sigmoid", {{"X", "a"}}, {"Out", "b"}); + test::CreateOp(&prog, "sigmoid", {{"X", "a"}}, {{"Out", "b"}}); // right branch - SetOp(&prog, "conv2d", - {{"Input", "b"}, {"Bias", "bias"}, {"Filter", "weights"}}, - {"Output", "c"}); + test::CreateOp(&prog, "conv2d", + {{"Input", "b"}, {"Bias", "bias"}, {"Filter", "weights"}}, + {{"Output", "c"}}); // left branch - SetOp(&prog, "conv2d", - {{"Input", "a"}, {"Bias", "bias2"}, {"Filter", "weights2"}}, - {"Output", "f"}); + test::CreateOp(&prog, "conv2d", + {{"Input", "a"}, {"Bias", "bias2"}, {"Filter", "weights2"}}, + {{"Output", "f"}}); - SetOp(&prog, "elementwise_add", {{"X", "f"}, {"Y", "c"}}, {"Out", "d"}); - SetOp(&prog, "relu", {{"X", "d"}}, {"Out", "e"}); + test::CreateOp(&prog, "elementwise_add", {{"X", "f"}, {"Y", "c"}}, + {{"Out", "d"}}); + test::CreateOp(&prog, "relu", {{"X", "d"}}, {{"Out", "e"}}); - RunPassAndAssert(&prog, "a", "relu", 2); + Graph graph(prog); + + EXPECT_TRUE(test::RunPassAndAssert(&graph, + "conv_elementwise_add_mkldnn_fuse_pass", + "a", "relu", nodes_removed, nodes_added)); + EXPECT_TRUE( + test::AssertOpsCount(graph, {{"conv2d", 2}, {"elementwise_add", 0}})); } TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionAsYWithElementwiseAddReluNoBias) { - auto prog = BuildProgramDesc({"a", "b", "c", "d", "e"}, {"weights"}); - - SetOp(&prog, "sigmoid", {{"X", "a"}}, {"Out", "b"}); - SetOp(&prog, "conv2d", {{"Input", "b"}, {"Filter", "weights"}}, - {"Output", "c"}); - SetOp(&prog, "elementwise_add", {{"X", "a"}, {"Y", "c"}}, {"Out", "d"}); - SetOp(&prog, "relu", {{"X", "d"}}, {"Out", "e"}); - - RunPassAndAssert(&prog, "a", "relu", 1); + auto prog = test::BuildProgramDesc({"a", "b", "c", "d", "e"}, {"weights"}); + + test::CreateOp(&prog, "sigmoid", {{"X", "a"}}, {{"Out", "b"}}); + test::CreateOp(&prog, "conv2d", {{"Input", "b"}, {"Filter", "weights"}}, + {{"Output", "c"}}); + test::CreateOp(&prog, "elementwise_add", {{"X", "a"}, {"Y", "c"}}, + {{"Out", "d"}}); + test::CreateOp(&prog, "relu", {{"X", "d"}}, {{"Out", "e"}}); + + Graph graph(prog); + + EXPECT_TRUE(test::RunPassAndAssert(&graph, + "conv_elementwise_add_mkldnn_fuse_pass", + "a", "relu", nodes_removed, nodes_added)); + EXPECT_TRUE( + test::AssertOpsCount(graph, {{"conv2d", 1}, {"elementwise_add", 0}})); } TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionAsXWithElementwiseAddRelu) { - auto prog = BuildProgramDesc({"a", "b", "c", "d", "e"}, {"bias", "weights"}); + auto prog = + test::BuildProgramDesc({"a", "b", "c", "d", "e"}, {"bias", "weights"}); - SetOp(&prog, "sigmoid", {{"X", "a"}}, {"Out", "b"}); - SetOp(&prog, "conv2d", - {{"Input", "b"}, {"Bias", "bias"}, {"Filter", "weights"}}, - {"Output", "c"}); + test::CreateOp(&prog, "sigmoid", {{"X", "a"}}, {{"Out", "b"}}); + test::CreateOp(&prog, "conv2d", + {{"Input", "b"}, {"Bias", "bias"}, {"Filter", "weights"}}, + {{"Output", "c"}}); - SetOp(&prog, "elementwise_add", {{"X", "c"}, {"Y", "a"}}, {"Out", "d"}); - SetOp(&prog, "relu", {{"X", "d"}}, {"Out", "e"}); + test::CreateOp(&prog, "elementwise_add", {{"X", "c"}, {"Y", "a"}}, + {{"Out", "d"}}); + test::CreateOp(&prog, "relu", {{"X", "d"}}, {{"Out", "e"}}); - RunPassAndAssert(&prog, "a", "relu", 1); + Graph graph(prog); + + EXPECT_TRUE(test::RunPassAndAssert(&graph, + "conv_elementwise_add_mkldnn_fuse_pass", + "a", "relu", nodes_removed, nodes_added)); + EXPECT_TRUE( + test::AssertOpsCount(graph, {{"conv2d", 1}, {"elementwise_add", 0}})); } TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionAsXWithElementwiseAddReluNoBias) { - auto prog = BuildProgramDesc({"a", "b", "c", "d", "e"}, {"weights"}); - - SetOp(&prog, "sigmoid", {{"X", "a"}}, {"Out", "b"}); - SetOp(&prog, "conv2d", {{"Input", "b"}, {"Filter", "weights"}}, - {"Output", "c"}); - SetOp(&prog, "elementwise_add", {{"X", "c"}, {"Y", "a"}}, {"Out", "d"}); - SetOp(&prog, "relu", {{"X", "d"}}, {"Out", "e"}); - - RunPassAndAssert(&prog, "a", "relu", 1); + auto prog = test::BuildProgramDesc({"a", "b", "c", "d", "e"}, {"weights"}); + + test::CreateOp(&prog, "sigmoid", {{"X", "a"}}, {{"Out", "b"}}); + test::CreateOp(&prog, "conv2d", {{"Input", "b"}, {"Filter", "weights"}}, + {{"Output", "c"}}); + test::CreateOp(&prog, "elementwise_add", {{"X", "c"}, {"Y", "a"}}, + {{"Out", "d"}}); + test::CreateOp(&prog, "relu", {{"X", "d"}}, {{"Out", "e"}}); + + Graph graph(prog); + + EXPECT_TRUE(test::RunPassAndAssert(&graph, + "conv_elementwise_add_mkldnn_fuse_pass", + "a", "relu", nodes_removed, nodes_added)); + EXPECT_TRUE( + test::AssertOpsCount(graph, {{"conv2d", 1}, {"elementwise_add", 0}})); } TEST(ConvElementwiseAddMKLDNNFusePass, NoFusion) { auto prog = - BuildProgramDesc({"a", "b", "c", "d", "e", "f", "g"}, {"weights"}); - - SetOp(&prog, "sigmoid", {{"X", "a"}}, {"Out", "b"}); - SetOp(&prog, "conv2d", {{"Input", "b"}, {"Filter", "weights"}}, - {"Output", "c"}); - - SetOp(&prog, "conv2d", {{"Input", "d"}, {"Filter", "weights"}}, - {"Output", "e"}); - - SetOp(&prog, "elementwise_add", {{"X", "c"}, {"Y", "e"}}, {"Out", "f"}); - SetOp(&prog, "relu", {{"X", "f"}}, {"Out", "g"}); + test::BuildProgramDesc({"a", "b", "c", "d", "e", "f", "g"}, {"weights"}); - std::unique_ptr graph(new ir::Graph(prog)); + test::CreateOp(&prog, "sigmoid", {{"X", "a"}}, {{"Out", "b"}}); + test::CreateOp(&prog, "conv2d", {{"Input", "b"}, {"Filter", "weights"}}, + {{"Output", "c"}}); - TestIsReachable is_reachable; - EXPECT_TRUE(is_reachable(graph)("a", "g")); + test::CreateOp(&prog, "conv2d", {{"Input", "d"}, {"Filter", "weights"}}, + {{"Output", "e"}}); - auto pass = - PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass"); - int original_nodes_num = graph->Nodes().size(); - graph.reset(pass->Apply(graph.release())); - int current_nodes_num = graph->Nodes().size(); + test::CreateOp(&prog, "elementwise_add", {{"X", "c"}, {"Y", "e"}}, + {{"Out", "f"}}); + test::CreateOp(&prog, "relu", {{"X", "f"}}, {{"Out", "g"}}); - EXPECT_TRUE(is_reachable(graph)("a", "g")); - EXPECT_EQ(original_nodes_num, current_nodes_num); + Graph graph(prog); - AssertOpsCount(graph, 2, 1); + EXPECT_TRUE(test::RunPassAndAssert( + &graph, "conv_elementwise_add_mkldnn_fuse_pass", "a", "g", 0, 0)); + EXPECT_TRUE( + test::AssertOpsCount(graph, {{"conv2d", 2}, {"elementwise_add", 1}})); } TEST(ConvElementwiseAddMKLDNNFusePass, pass_op_version_check) { diff --git a/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass_tester.cc index 634f44a258..e7d332864c 100644 --- a/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass_tester.cc @@ -13,17 +13,11 @@ // limitations under the License. #include -#include -#include -#include -#include -#include -#include -#include -#include "paddle/fluid/framework/ir/graph_traits.h" #include "paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h" +#include "paddle/fluid/framework/ir/mkldnn/pass_test_util.h" #include "paddle/fluid/framework/op_desc.h" +#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/platform/errors.h" @@ -31,238 +25,45 @@ namespace paddle { namespace framework { namespace ir { -// -------------------------- helper functions -------------------------------- -namespace { - -using InOutVarNamePair = std::pair; -using OpTypeCountPair = std::pair; - -/// -/// @brief Creates the specified operator and sets up its inputs/outputs. -/// -/// @param prog The program descriptor to which we add new op. -/// @param[in] op_type_name The operator type name. -/// @param[in] inputs The vector of input pairs: {input_name, variable -/// name} -/// @param[in] outputs The vector of output pairs {output_name, variable} -/// @param[in] use_mkldnn The flag deciding whether or not to set -/// 'use_mkldnn' attribute. -/// -/// @return Returns pointer to the created operator descriptor. -/// -OpDesc* CreateOp(ProgramDesc* prog, const std::string& op_type_name, - const std::vector& inputs, - const std::vector& outputs, - bool use_mkldnn = true) { - auto op = prog->MutableBlock(0)->AppendOp(); - op->SetType(op_type_name); - op->SetAttr("use_mkldnn", use_mkldnn); - - for (const auto& input : inputs) { - op->SetInput(input.first, {input.second}); - } - for (const auto& output : outputs) { - op->SetOutput(output.first, {output.second}); - } - - return op; -} - -/// -/// @brief Check whether node 'to' is reachable from node 'from' in graph. -/// -/// @param[in] graph The graph we're checking for reachability. -/// @param[in] from The 'from' node name. -/// @param[in] to The 'to' node name. -/// -/// @return True if there is connection between nodes 'from' and 'to'. -/// -bool TestIsReachable(const Graph& graph, std::string from, std::string to) { - auto hash = [](const Node* node) -> std::string { - return node->Name() + std::to_string(node->id()); - }; - - auto find_node = [&](const Graph& graph, const std::string& name) -> Node* { - for (auto& node : GraphTraits::DFS(graph)) { - if (name == hash(&node)) { - return &node; - } - } - - return nullptr; - }; - - if (from == to) return true; - - std::map visited; - // update the from and to strings to hashed equivs in loop from graph traits - for (auto& node : GraphTraits::DFS(graph)) { - auto hashed = hash(&node); - if (node.Name() == from) { - from = hashed; - } - if (node.Name() == to) { - to = hashed; - } - visited[hashed] = false; - } - - visited[from] = true; - - std::list queue; - queue.push_back(from); - - while (!queue.empty()) { - auto cur = find_node(graph, queue.front()); - queue.pop_front(); - if (cur == nullptr) { - return false; - } - - for (auto n : cur->outputs) { - auto hashed_name = hash(n); - if (hashed_name == to) { - return true; - } - - if (!visited[hashed_name]) { - visited[hashed_name] = true; - queue.push_back(hashed_name); - } - } - } - return false; -} - -/// -/// @brief Search through graph and counts provided operator occurences. -/// -/// @param[in] graph The graph we search through. -/// @param[in] op_type_count The vector of pairs {op_type_name, op count} -/// -/// @note After going through all graph nodes this function asserts -/// whether counted number for each requested op is as expected. -/// -void AssertOpsCount(const Graph& graph, - std::vector op_type_count) { - for (auto* node : graph.Nodes()) { - if (!node->IsOp()) { - continue; - } - - const std::string op_type_name = node->Op()->Type(); - auto op_it = - std::find_if(std::begin(op_type_count), std::end(op_type_count), - [op_type_name](const OpTypeCountPair& p) { - return op_type_name == p.first; - }); - if (op_it != std::end(op_type_count)) { - op_it->second--; - } - } - - for (const OpTypeCountPair& p : op_type_count) { - EXPECT_EQ(p.second, 0); - } -} - -/// -/// @brief Builds a program descriptor. -/// -/// @param[in] transient_vars The vector of transient variables names. -/// @param[in] persistent_vars The vector of persistent variables names. Those -/// will have persistable attribute set to true. -/// -/// @return The program descriptor object. -/// -ProgramDesc BuildProgramDesc(const std::vector& transient_vars, - const std::vector& persistent_vars) { - ProgramDesc prog; - - auto add_var_to_prog = [&prog](const std::string& var_name) -> VarDesc* { - auto var = prog.MutableBlock(0)->Var(var_name); - var->SetType(proto::VarType::LOD_TENSOR); - return var; - }; - - for (const auto& v : transient_vars) { - add_var_to_prog(v); - } - - for (const auto& v : persistent_vars) { - auto* var = add_var_to_prog(v); - var->SetPersistable(true); - } - - return prog; -} - -/// -/// @brief Execute pass on provided graph and perform checks. -/// -/// @param graph The graph we run pass on. -/// @param[in] from The name of a 'starting' node sequence in a -/// graph. This would be used to test for -/// correct node connections. -/// @param[in] to The name of a 'ending' node sequence in a -/// graph. This would be used to test for -/// correct node connections. -/// @param[in] removed_nodes_count The number of nodes we expect will be -/// removed/fused after pass execution. -/// @param[in] added_nodes_count The number of nodes we expect will be -/// added after pass execution. -/// -void RunPassAndAssert(Graph* graph, const std::string& from, - const std::string& to, int removed_nodes_count, - int added_nodes_count = 0) { - EXPECT_TRUE(TestIsReachable(*graph, from, to)); - int original_nodes_num = graph->Nodes().size(); - auto pass = PassRegistry::Instance().Get("fc_act_mkldnn_fuse_pass"); - pass->Apply(graph); - int current_nodes_num = graph->Nodes().size(); - - EXPECT_TRUE(TestIsReachable(*graph, from, to)); - EXPECT_EQ(original_nodes_num - removed_nodes_count + added_nodes_count, - current_nodes_num); -} - -} // namespace - // ------------------------------ Test cases ----------------------------------- TEST(FuseFCActOneDNNPass, ThrowUseMkldnn) { - auto prog = BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"}); - CreateOp(&prog, "fc", - { - {"Input", "x"}, {"Weights", "weights"}, {"Bias", "bias"}, - }, - {{"Out", "fc_y"}}, false); - CreateOp(&prog, "gelu", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false); + auto prog = + test::BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"}); + test::CreateOp(&prog, "fc", + { + {"Input", "x"}, {"Weights", "weights"}, {"Bias", "bias"}, + }, + {{"Out", "fc_y"}}, false); + test::CreateOp(&prog, "gelu", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false); Graph graph(prog); // No fusion in this attribute configuration constexpr int removed_nodes_count = 0; - EXPECT_THROW(RunPassAndAssert(&graph, "x", "act_y", removed_nodes_count), + EXPECT_THROW(test::RunPassAndAssert(&graph, "fc_act_mkldnn_fuse_pass", "x", + "act_y", removed_nodes_count), paddle::platform::EnforceNotMet); } TEST(FuseFCActOneDNNPass, FuseWithGeluTanh) { - auto prog = BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"}); - CreateOp(&prog, "fc", - { - {"Input", "x"}, {"Weights", "weights"}, {"Bias", "bias"}, - }, - {{"Out", "fc_y"}}); - auto* act_op = - CreateOp(&prog, "gelu", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false); + auto prog = + test::BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"}); + test::CreateOp(&prog, "fc", + { + {"Input", "x"}, {"Weights", "weights"}, {"Bias", "bias"}, + }, + {{"Out", "fc_y"}}); + auto* act_op = test::CreateOp(&prog, "gelu", {{"Input", "fc_y"}}, + {{"Out", "act_y"}}, false); act_op->SetAttr("approximate", true); Graph graph(prog); constexpr int removed_nodes_count = 2; - RunPassAndAssert(&graph, "x", "act_y", removed_nodes_count); - AssertOpsCount(graph, {{"fc", 1}, {"gelu", 0}}); + EXPECT_TRUE(test::RunPassAndAssert(&graph, "fc_act_mkldnn_fuse_pass", "x", + "act_y", removed_nodes_count)); + EXPECT_TRUE(test::AssertOpsCount(graph, {{"fc", 1}, {"gelu", 0}})); for (const auto* node : graph.Nodes()) { if (node->IsOp() && node->Op()->Type() == "fc") { @@ -272,27 +73,29 @@ TEST(FuseFCActOneDNNPass, FuseWithGeluTanh) { ASSERT_TRUE(op->HasAttr("activation_type")); auto act_type = BOOST_GET_CONST(std::string, op->GetAttr("activation_type")); - EXPECT_TRUE(act_type.compare("gelu_tanh") == 0); + EXPECT_EQ(act_type.compare("gelu_tanh"), 0); } } } TEST(FuseFCActOneDNNPass, FuseWithGeluErf) { - auto prog = BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"}); - CreateOp(&prog, "fc", - { - {"Input", "x"}, {"Weights", "weights"}, {"Bias", "bias"}, - }, - {{"Out", "fc_y"}}); - auto* act_op = - CreateOp(&prog, "gelu", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false); + auto prog = + test::BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"}); + test::CreateOp(&prog, "fc", + { + {"Input", "x"}, {"Weights", "weights"}, {"Bias", "bias"}, + }, + {{"Out", "fc_y"}}); + auto* act_op = test::CreateOp(&prog, "gelu", {{"Input", "fc_y"}}, + {{"Out", "act_y"}}, false); act_op->SetAttr("approximate", false); Graph graph(prog); constexpr int removed_nodes_count = 2; - RunPassAndAssert(&graph, "x", "act_y", removed_nodes_count); - AssertOpsCount(graph, {{"fc", 1}, {"gelu", 0}}); + EXPECT_TRUE(test::RunPassAndAssert(&graph, "fc_act_mkldnn_fuse_pass", "x", + "act_y", removed_nodes_count)); + EXPECT_TRUE(test::AssertOpsCount(graph, {{"fc", 1}, {"gelu", 0}})); for (const auto* node : graph.Nodes()) { if (node->IsOp() && node->Op()->Type() == "fc") { @@ -302,25 +105,27 @@ TEST(FuseFCActOneDNNPass, FuseWithGeluErf) { ASSERT_TRUE(op->HasAttr("activation_type")); auto act_type = BOOST_GET_CONST(std::string, op->GetAttr("activation_type")); - EXPECT_TRUE(act_type.compare("gelu_erf") == 0); + EXPECT_EQ(act_type.compare("gelu_erf"), 0); } } } TEST(FuseFCActOneDNNPass, FuseWithGeluAuto) { - auto prog = BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"}); - CreateOp(&prog, "fc", - { - {"Input", "x"}, {"Weights", "weights"}, {"Bias", "bias"}, - }, - {{"Out", "fc_y"}}); - CreateOp(&prog, "gelu", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false); + auto prog = + test::BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"}); + test::CreateOp(&prog, "fc", + { + {"Input", "x"}, {"Weights", "weights"}, {"Bias", "bias"}, + }, + {{"Out", "fc_y"}}); + test::CreateOp(&prog, "gelu", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false); Graph graph(prog); constexpr int removed_nodes_count = 2; - RunPassAndAssert(&graph, "x", "act_y", removed_nodes_count); - AssertOpsCount(graph, {{"fc", 1}, {"gelu", 0}}); + EXPECT_TRUE(test::RunPassAndAssert(&graph, "fc_act_mkldnn_fuse_pass", "x", + "act_y", removed_nodes_count)); + EXPECT_TRUE(test::AssertOpsCount(graph, {{"fc", 1}, {"gelu", 0}})); for (const auto* node : graph.Nodes()) { if (node->IsOp() && node->Op()->Type() == "fc") { @@ -330,25 +135,27 @@ TEST(FuseFCActOneDNNPass, FuseWithGeluAuto) { ASSERT_TRUE(op->HasAttr("activation_type")); auto act_type = BOOST_GET_CONST(std::string, op->GetAttr("activation_type")); - EXPECT_TRUE(act_type.compare("gelu") == 0); + EXPECT_EQ(act_type.compare("gelu"), 0); } } } TEST(FuseFCActOneDNNPass, FuseWithTanh) { - auto prog = BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"}); - CreateOp(&prog, "fc", - { - {"Input", "x"}, {"Weights", "weights"}, {"Bias", "bias"}, - }, - {{"Out", "fc_y"}}); - CreateOp(&prog, "tanh", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false); + auto prog = + test::BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"}); + test::CreateOp(&prog, "fc", + { + {"Input", "x"}, {"Weights", "weights"}, {"Bias", "bias"}, + }, + {{"Out", "fc_y"}}); + test::CreateOp(&prog, "tanh", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false); Graph graph(prog); constexpr int removed_nodes_count = 2; - RunPassAndAssert(&graph, "x", "act_y", removed_nodes_count); - AssertOpsCount(graph, {{"fc", 1}, {"tanh", 0}}); + EXPECT_TRUE(test::RunPassAndAssert(&graph, "fc_act_mkldnn_fuse_pass", "x", + "act_y", removed_nodes_count)); + EXPECT_TRUE(test::AssertOpsCount(graph, {{"fc", 1}, {"tanh", 0}})); for (const auto* node : graph.Nodes()) { if (node->IsOp() && node->Op()->Type() == "fc") { @@ -358,25 +165,28 @@ TEST(FuseFCActOneDNNPass, FuseWithTanh) { ASSERT_TRUE(op->HasAttr("activation_type")); auto act_type = BOOST_GET_CONST(std::string, op->GetAttr("activation_type")); - EXPECT_TRUE(act_type.compare("tanh") == 0); + EXPECT_EQ(act_type.compare("tanh"), 0); } } } TEST(FuseFCActOneDNNPass, FuseWithSigmoid) { - auto prog = BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"}); - CreateOp(&prog, "fc", - { - {"Input", "x"}, {"Weights", "weights"}, {"Bias", "bias"}, - }, - {{"Out", "fc_y"}}); - CreateOp(&prog, "sigmoid", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false); + auto prog = + test::BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"}); + test::CreateOp(&prog, "fc", + { + {"Input", "x"}, {"Weights", "weights"}, {"Bias", "bias"}, + }, + {{"Out", "fc_y"}}); + test::CreateOp(&prog, "sigmoid", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, + false); Graph graph(prog); constexpr int removed_nodes_count = 2; - RunPassAndAssert(&graph, "x", "act_y", removed_nodes_count); - AssertOpsCount(graph, {{"fc", 1}, {"sigmoid", 0}}); + EXPECT_TRUE(test::RunPassAndAssert(&graph, "fc_act_mkldnn_fuse_pass", "x", + "act_y", removed_nodes_count)); + EXPECT_TRUE(test::AssertOpsCount(graph, {{"fc", 1}, {"sigmoid", 0}})); for (const auto* node : graph.Nodes()) { if (node->IsOp() && node->Op()->Type() == "fc") { @@ -386,11 +196,17 @@ TEST(FuseFCActOneDNNPass, FuseWithSigmoid) { ASSERT_TRUE(op->HasAttr("activation_type")); auto act_type = BOOST_GET_CONST(std::string, op->GetAttr("activation_type")); - EXPECT_TRUE(act_type.compare("sigmoid") == 0); + EXPECT_EQ(act_type.compare("sigmoid"), 0); } } } +TEST(FuseFCActOneDNNPass, pass_op_version_check) { + ASSERT_TRUE( + paddle::framework::compatible::PassVersionCheckerRegistrar::GetInstance() + .IsPassCompatible("fc_act_mkldnn_fuse_pass")); +} + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/mkldnn/pass_test_util.cc b/paddle/fluid/framework/ir/mkldnn/pass_test_util.cc new file mode 100644 index 0000000000..a6c8a6662c --- /dev/null +++ b/paddle/fluid/framework/ir/mkldnn/pass_test_util.cc @@ -0,0 +1,174 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include + +#include "paddle/fluid/framework/ir/graph_traits.h" +#include "paddle/fluid/framework/ir/mkldnn/pass_test_util.h" +#include "paddle/fluid/framework/ir/pass.h" + +namespace paddle { +namespace framework { +namespace ir { +namespace test { + +OpDesc* CreateOp(ProgramDesc* prog, const std::string& op_type_name, + const std::vector& inputs, + const std::vector& outputs, + bool use_mkldnn) { + auto op = prog->MutableBlock(0)->AppendOp(); + op->SetType(op_type_name); + op->SetAttr("use_mkldnn", use_mkldnn); + + for (const auto& input : inputs) { + op->SetInput(input.first, {input.second}); + } + for (const auto& output : outputs) { + op->SetOutput(output.first, {output.second}); + } + + return op; +} + +bool TestIsReachable(const Graph& graph, std::string from, std::string to) { + auto hash = [](const Node* node) -> std::string { + return node->Name() + std::to_string(node->id()); + }; + + auto find_node = [&](const Graph& graph, const std::string& name) -> Node* { + for (auto& node : GraphTraits::DFS(graph)) { + if (name == hash(&node)) { + return &node; + } + } + + return nullptr; + }; + + if (from == to) return true; + + std::map visited; + // update the from and to strings to hashed equivs in loop from graph traits + for (auto& node : GraphTraits::DFS(graph)) { + auto hashed = hash(&node); + if (node.Name() == from) { + from = hashed; + } + if (node.Name() == to) { + to = hashed; + } + visited[hashed] = false; + } + + visited[from] = true; + + std::list queue; + queue.push_back(from); + + while (!queue.empty()) { + auto cur = find_node(graph, queue.front()); + queue.pop_front(); + if (cur == nullptr) { + return false; + } + + for (auto n : cur->outputs) { + auto hashed_name = hash(n); + if (hashed_name == to) { + return true; + } + + if (!visited[hashed_name]) { + visited[hashed_name] = true; + queue.push_back(hashed_name); + } + } + } + return false; +} + +bool AssertOpsCount(const Graph& graph, + std::vector op_type_count) { + for (auto* node : graph.Nodes()) { + if (!node->IsOp()) { + continue; + } + + const std::string op_type_name = node->Op()->Type(); + auto op_it = + std::find_if(std::begin(op_type_count), std::end(op_type_count), + [op_type_name](const OpTypeCountPair& p) { + return op_type_name == p.first; + }); + if (op_it != std::end(op_type_count)) { + op_it->second--; + } + } + + bool result{true}; + + for (const OpTypeCountPair& p : op_type_count) { + result = result && (p.second == 0); + } + return result; +} + +ProgramDesc BuildProgramDesc(const std::vector& transient_vars, + const std::vector& persistent_vars) { + ProgramDesc prog; + + auto add_var_to_prog = [&prog](const std::string& var_name) -> VarDesc* { + auto var = prog.MutableBlock(0)->Var(var_name); + var->SetType(proto::VarType::LOD_TENSOR); + return var; + }; + + for (const auto& v : transient_vars) { + add_var_to_prog(v); + } + + for (const auto& v : persistent_vars) { + auto* var = add_var_to_prog(v); + var->SetPersistable(true); + } + + return prog; +} + +bool RunPassAndAssert(Graph* graph, const std::string& pass_name, + const std::string& from, const std::string& to, + int removed_nodes_count, int added_nodes_count) { + if (!TestIsReachable(*graph, from, to)) return false; + + int original_nodes_num = graph->Nodes().size(); + auto pass = PassRegistry::Instance().Get(pass_name); + pass->Apply(graph); + int current_nodes_num = graph->Nodes().size(); + + if (!TestIsReachable(*graph, from, to)) return false; + + int expected_nodes_num = + original_nodes_num - removed_nodes_count + added_nodes_count; + return expected_nodes_num == current_nodes_num; +} + +} // namespace test +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/mkldnn/pass_test_util.h b/paddle/fluid/framework/ir/mkldnn/pass_test_util.h new file mode 100644 index 0000000000..08ee50e0f1 --- /dev/null +++ b/paddle/fluid/framework/ir/mkldnn/pass_test_util.h @@ -0,0 +1,119 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/op_desc.h" +#include "paddle/fluid/framework/program_desc.h" + +namespace paddle { +namespace framework { +namespace ir { + +// -------------------------- helper functions -------------------------------- +namespace test { + +/// The pair describing correlation between {input/output name, variable name}. +using InOutVarNamePair = std::pair; +/// The pair describing number of occurrences of given op type. +using OpTypeCountPair = std::pair; + +/// +/// @brief Creates the specified operator and sets up its inputs/outputs. +/// +/// @param prog The program descriptor to which we add new op. +/// @param[in] op_type_name The operator type name. +/// @param[in] inputs The vector of input pairs: {input_name, variable +/// name} +/// @param[in] outputs The vector of output pairs {output_name, variable} +/// @param[in] use_mkldnn The flag deciding whether or not to set +/// 'use_mkldnn' attribute. +/// +/// @return Returns pointer to the created operator descriptor. +/// +OpDesc* CreateOp(ProgramDesc* prog, const std::string& op_type_name, + const std::vector& inputs, + const std::vector& outputs, + bool use_mkldnn = true); + +/// +/// @brief Check whether node 'to' is reachable from node 'from' in graph. +/// +/// @param[in] graph The graph we're checking for reachability. +/// @param[in] from The 'from' node name. +/// @param[in] to The 'to' node name. +/// +/// @return True if there is connection between nodes 'from' and 'to'. +/// +bool TestIsReachable(const Graph& graph, std::string from, std::string to); + +/// +/// @brief Search through graph and counts provided operator occurrences. +/// +/// @param[in] graph The graph we search through. +/// @param[in] op_type_count The vector of pairs {op_type_name, op count} +/// +/// @note After going through all graph nodes this function asserts +/// whether counted number for each requested op is as expected. +/// +/// @return Returns true if occurrences of all ops is as expected. +/// +bool AssertOpsCount(const Graph& graph, + std::vector op_type_count); + +/// +/// @brief Builds a program descriptor. +/// +/// @param[in] transient_vars The vector of transient variables names. +/// @param[in] persistent_vars The vector of persistent variables names. Those +/// will have persistable attribute set to true. +/// +/// @return The program descriptor object. +/// +ProgramDesc BuildProgramDesc(const std::vector& transient_vars, + const std::vector& persistent_vars); + +/// +/// @brief Execute pass on provided graph and perform checks. +/// +/// @note Check whether the balance of removed and added nodes after pass +/// is as expected. +/// +/// @param graph The graph we run pass on. +/// @param[in] from The name of a 'starting' node sequence in a +/// graph. This would be used to test for +/// correct node connections. +/// @param[in] to The name of a 'ending' node sequence in a +/// graph. This would be used to test for +/// correct node connections. +/// @param[in] removed_nodes_count The number of nodes we expect will be +/// removed/fused after pass execution. +/// @param[in] added_nodes_count The number of nodes we expect will be added +/// after pass execution. +/// +/// @return Return true if all checks passed, otherwise false. +/// +bool RunPassAndAssert(Graph* graph, const std::string& pass_name, + const std::string& from, const std::string& to, + int removed_nodes_count, int added_nodes_count = 0); + +} // namespace test +} // namespace ir +} // namespace framework +} // namespace paddle -- GitLab