From 6a2bc9a275f578fb728df17225afd012a5da5eb7 Mon Sep 17 00:00:00 2001 From: Michal Gallus Date: Mon, 25 Feb 2019 15:44:41 +0100 Subject: [PATCH] Add Conv Residual Connection UT for Projection test=develop --- ...elementwise_add_mkldnn_fuse_pass_tester.cc | 50 +++++++++++++++---- 1 file changed, 40 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc index 9ef5c298b8..433d89d8d3 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc @@ -44,10 +44,14 @@ struct TestIsReachable { using func = std::function; auto operator()(const std::unique_ptr& graph) -> func { - auto find_node = [](const std::unique_ptr& graph, - const std::string& name) -> Node* { + auto hash = [](const Node* node) -> std::string { + return node->Name() + std::to_string(node->id()); + }; + + auto find_node = [&](const std::unique_ptr& graph, + const std::string& name) -> Node* { for (auto& node : GraphTraits::DFS(*graph)) { - if (name == node.Name()) { + if (name == hash(&node)) { return &node; } } @@ -55,13 +59,17 @@ struct TestIsReachable { return nullptr; }; - return [&](std::string from, const std::string to) -> bool { + // update the from and to strings to hashed equivs in loop from graph traits + return [&](std::string from, std::string to) -> bool { if (from == to) return true; std::map visited; for (auto& node : GraphTraits::DFS(*graph)) { - visited[node.Name()] = false; + auto hashed = hash(&node); + if (node.Name() == from) from = hashed; + if (node.Name() == to) to = hashed; + visited[hashed] = false; } visited[from] = true; @@ -72,15 +80,15 @@ struct TestIsReachable { while (!queue.empty()) { auto cur = find_node(graph, queue.front()); queue.pop_front(); - if (cur == nullptr) return false; for (auto n : cur->outputs) { - if (n->Name() == to) return true; + auto hashed_name = hash(n); + if (hashed_name == to) return true; - if (!visited[n->Name()]) { - visited[n->Name()] = true; - queue.push_back(n->Name()); + if (!visited[hashed_name]) { + visited[hashed_name] = true; + queue.push_back(hashed_name); } } } @@ -166,6 +174,28 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionAsYWithElementwiseAddRelu) { RunPassAndAssert(&prog, "a", "relu", 1); } +TEST(ConvElementwiseAddMKLDNNFusePass, + ConvolutionProjectionAsYWithElementwiseAddRelu) { + auto prog = BuildProgramDesc({"a", "b", "c", "d", "e", "f"}, + {"bias", "weights", "bias2", "weights2"}); + + SetOp(&prog, "sigmoid", {{"X", "a"}}, {"Out", "b"}); + // right branch + SetOp(&prog, "conv2d", + {{"Input", "b"}, {"Bias", "bias"}, {"Filter", "weights"}}, + {"Output", "c"}); + + // left branch + SetOp(&prog, "conv2d", + {{"Input", "a"}, {"Bias", "bias2"}, {"Filter", "weights2"}}, + {"Output", "f"}); + + SetOp(&prog, "elementwise_add", {{"X", "f"}, {"Y", "c"}}, {"Out", "d"}); + SetOp(&prog, "relu", {{"X", "d"}}, {"Out", "e"}); + + RunPassAndAssert(&prog, "a", "relu", 2); +} + TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionAsYWithElementwiseAddReluNoBias) { auto prog = BuildProgramDesc({"a", "b", "c", "d", "e"}, {"weights"}); -- GitLab