diff --git a/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc index eae55e0e266f365f8ade279a28376771653b1201..ac15e1b3d5f0874a21e63a5daecf51b38f822f5a 100644 --- a/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc @@ -1,4 +1,20 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #include "paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h" +#include + #include "paddle/fluid/framework/ir/graph_traits.h" namespace paddle { @@ -8,15 +24,14 @@ namespace patterns { struct Pattern : public PatternBase { Pattern(PDPattern* pattern, const std::string& name_scope) - : PatternBase{pattern, name_scope, ""} - { } - - private: + : PatternBase{pattern, name_scope, ""} {} + + private: std::string name_scope() { return name_scope_; } - std::string repr() { return repr_; } + std::string repr() { return repr_; } size_t id() { return id_; } PDPattern* node_pattern() { return pattern; } - + public: std::string node_name(std::string op_name) { return PDNodeName(name_scope(), repr(), id(), op_name); @@ -37,22 +52,18 @@ struct Conv { std::string filter_name() { return "Filter"; } std::string output_name() { return "Output"; } - std::function operator()(std::shared_ptr pattern) { + std::function operator()(std::shared_ptr pattern) { return [&]() -> PDNode* { - auto conv_op = pattern->new_node(op_name()) - ->assert_is_op("conv2d"); + auto conv_op = pattern->new_node(op_name())->assert_is_op("conv2d"); auto input_var = pattern->new_node(input_name()) - ->assert_is_op_input(op_name(), - input_name()); - + ->assert_is_op_input(op_name(), input_name()); + auto filter_var = pattern->new_node(filter_name()) - ->assert_is_op_input(op_name(), - filter_name()); + ->assert_is_op_input(op_name(), filter_name()); auto output_var = pattern->new_node(output_name()) - ->assert_is_op_output(op_name(), - output_name()); + ->assert_is_op_output(op_name(), output_name()); conv_op->LinksFrom({input_var, filter_var}); conv_op->LinksTo({output_var}); @@ -68,22 +79,19 @@ struct ElementwiseAdd { std::string y_name() { return "Y"; } std::string out_name() { return "Out"; } - std::function operator()(std::shared_ptr pattern) { + std::function operator()(std::shared_ptr pattern) { return [&](PDNode* conv_output) -> PDNode* { - auto elementwise_add_op = pattern->new_node(op_name()) - ->assert_is_op("elementwise_add"); + auto elementwise_add_op = + pattern->new_node(op_name())->assert_is_op("elementwise_add"); + + auto x_var = + pattern->new_node(x_name())->assert_is_op_input(op_name(), x_name()); - auto x_var = pattern->new_node(x_name()) - ->assert_is_op_input(op_name(), - x_name()); - - conv_output->assert_is_op_input(op_name(), - y_name()); + conv_output->assert_is_op_input(op_name(), y_name()); auto out_var = pattern->new_node(out_name()) - ->AsOutput() - ->assert_is_op_output(op_name(), - out_name()); + ->AsOutput() + ->assert_is_op_output(op_name(), out_name()); elementwise_add_op->LinksFrom({x_var, conv_output}); elementwise_add_op->LinksTo({out_var}); @@ -94,13 +102,13 @@ struct ElementwiseAdd { }; Node* GetNodeFromSubgraph(const GraphPatternDetector::subgraph_t& subgraph, - std::shared_ptr pattern, - const std::string& op_name) { + std::shared_ptr pattern, + const std::string& op_name) { PADDLE_ENFORCE(subgraph.count(pattern->retrieve_node(op_name)), "Node not found for PDNode %s", pattern->node_name(op_name)); Node* var = subgraph.at(pattern->retrieve_node(op_name)); PADDLE_ENFORCE(var, "node %s not exists in the sub-graph"); - + return var; } @@ -109,10 +117,9 @@ void LinkNodes(Node* from, Node* to) { to->inputs.push_back(from); } -template +template void ReplaceAllOccurances(IT s, IT e, FindFunc f, ReplaceFunc r) { - if (s == e) - return; + if (s == e) return; auto it = std::find_if(s, e, f); @@ -126,8 +133,7 @@ void ReplaceAllOccurances(IT s, IT e, FindFunc f, ReplaceFunc r) { void CorrectGraphEdges(Graph* graph, Node* from, Node* to) { for (auto& node : GraphTraits::DFS(*graph)) { - auto same = std::find_if(std::begin(node.inputs), - std::end(node.inputs), + auto same = std::find_if(std::begin(node.inputs), std::end(node.inputs), [from](Node* n) { return n == from; }); if (same != std::end(node.inputs)) { @@ -137,17 +143,19 @@ void CorrectGraphEdges(Graph* graph, Node* from, Node* to) { using input_type = VariableNameMap::value_type; - ReplaceAllOccurances(std::begin(inputs), std::end(inputs), - [from](const input_type& i) -> bool { - auto params = i.second; - auto pi = std::find_if(std::begin(params), std::end(params), - std::bind(std::equal_to(), - from->Name(), std::placeholders::_1)); - return pi != std::end(params); - }, - [to, &node](const input_type& i) { - node.Op()->SetInput(i.first, {to->Name()}); - }); + ReplaceAllOccurances( + std::begin(inputs), std::end(inputs), + [from](const input_type& i) -> bool { + auto params = i.second; + auto pi = + std::find_if(std::begin(params), std::end(params), + std::bind(std::equal_to(), + from->Name(), std::placeholders::_1)); + return pi != std::end(params); + }, + [to, &node](const input_type& i) { + node.Op()->SetInput(i.first, {to->Name()}); + }); } } } @@ -169,7 +177,8 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const { conv_output->AsIntermediate(); - auto fuse_conv = [](Graph* g, Node* conv_input, Node* conv_filter, Node* conv_output, Node* elementwise_add_x) { + auto fuse_conv = [](Graph* g, Node* conv_input, Node* conv_filter, + Node* conv_output, Node* elementwise_add_x) { OpDesc op_desc; op_desc.SetType("conv2d"); @@ -189,22 +198,23 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const { patterns::LinkNodes(fused_conv_op, conv_output); }; - auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { auto conv_op = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr, - conv_pattern.op_name()); + conv_pattern.op_name()); auto conv_input = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr, - conv_pattern.input_name()); - auto conv_filter = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr, - conv_pattern.filter_name()); - auto conv_output = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr, - conv_pattern.output_name()); - - auto elementwise_add_op = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr, - elementwise_add_pattern.op_name()); - auto elementwise_add_x = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr, - elementwise_add_pattern.x_name()); - auto elementwise_add_out = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr, - elementwise_add_pattern.out_name()); + conv_pattern.input_name()); + auto conv_filter = patterns::GetNodeFromSubgraph( + subgraph, pattern_ptr, conv_pattern.filter_name()); + auto conv_output = patterns::GetNodeFromSubgraph( + subgraph, pattern_ptr, conv_pattern.output_name()); + + auto elementwise_add_op = patterns::GetNodeFromSubgraph( + subgraph, pattern_ptr, elementwise_add_pattern.op_name()); + auto elementwise_add_x = patterns::GetNodeFromSubgraph( + subgraph, pattern_ptr, elementwise_add_pattern.x_name()); + auto elementwise_add_out = patterns::GetNodeFromSubgraph( + subgraph, pattern_ptr, elementwise_add_pattern.out_name()); fuse_conv(g, conv_input, conv_filter, conv_output, elementwise_add_x); patterns::CorrectGraphEdges(g, elementwise_add_out, conv_output); @@ -219,4 +229,5 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const { } // namespace framework } // namespace paddle -REGISTER_PASS(conv_elementwise_add_mkldnn_fuse_pass, paddle::framework::ir::ConvElementwiseAddMKLDNNFusePass); +REGISTER_PASS(conv_elementwise_add_mkldnn_fuse_pass, + paddle::framework::ir::ConvElementwiseAddMKLDNNFusePass); diff --git a/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass_tester.cc index 17de916c63a4f1398bc842a17f3b3f1721ba9ef7..58b1097a2598532d61013dd8fb785e7eb34c0819 100644 --- a/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass_tester.cc @@ -1,8 +1,22 @@ -#include "paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h" -#include "paddle/fluid/framework/ir/graph_traits.h" +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. -#include #include +#include + +#include "paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h" +#include "paddle/fluid/framework/ir/graph_traits.h" namespace paddle { namespace framework { @@ -33,10 +47,11 @@ void SetOp(ProgramDesc* prog, const std::string& type, } struct IsReachable { - using func = std::function; + using func = std::function; auto operator()(const std::unique_ptr& graph) -> func { - auto find_node = [](const std::unique_ptr& graph, const std::string& name) -> Node* { + auto find_node = [](const std::unique_ptr& graph, + const std::string& name) -> Node* { for (auto& node : GraphTraits::DFS(*graph)) { if (name == node.Name()) { return &node; @@ -47,8 +62,7 @@ struct IsReachable { }; return [&](std::string from, const std::string to) -> bool { - if (from == to) - return true; + if (from == to) return true; std::map visited; @@ -61,16 +75,14 @@ struct IsReachable { std::list queue; queue.push_back(from); - while(!queue.empty()) { + while (!queue.empty()) { auto cur = find_node(graph, queue.front()); queue.pop_front(); - if (cur == nullptr) - return false; + if (cur == nullptr) return false; for (auto n : cur->outputs) { - if (n->Name() == to) - return true; + if (n->Name() == to) return true; if (!visited[n->Name()]) { visited[n->Name()] = true; @@ -87,14 +99,14 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) { auto build_program_desc = [&]() -> ProgramDesc { ProgramDesc prog; for (auto& v : - std::vector({"a", "b", "weights", "c", "d", "e"})) { + std::vector({"a", "b", "weights", "c", "d", "e"})) { auto* var = prog.MutableBlock(0)->Var(v); var->SetType(proto::VarType::LOD_TENSOR); if (v == "weights") { var->SetPersistable(true); } } - + SetOp(&prog, "conv2d", {"a", "weights"}, {"b"}); SetOp(&prog, "elementwise_add", {"c", "b"}, {"d"}); SetOp(&prog, "relu", {"d"}, {"e"}); @@ -109,14 +121,16 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) { EXPECT_TRUE(is_reachable(graph)("a", "relu")); - auto pass = PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass"); + auto pass = + PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass"); int original_nodes_num = graph->Nodes().size(); graph = pass->Apply(std::move(graph)); int current_nodes_num = graph->Nodes().size(); EXPECT_TRUE(is_reachable(graph)("a", "relu")); - EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added, current_nodes_num); + EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added, + current_nodes_num); // Assert conv_relu op in newly generated graph int conv_count = 0; int elementwise_add_count = 0; @@ -136,15 +150,14 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) { TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionElementwiseAdd) { auto build_program_desc = [&]() -> ProgramDesc { ProgramDesc prog; - for (auto& v : - std::vector({"a", "b", "weights"})) { + for (auto& v : std::vector({"a", "b", "weights"})) { auto* var = prog.MutableBlock(0)->Var(v); var->SetType(proto::VarType::LOD_TENSOR); if (v == "weights" || v == "bias") { var->SetPersistable(true); } } - + SetOp(&prog, "conv2d", {"a", "weights"}, {"b"}); SetOp(&prog, "elementwise_add", {"c", "b"}, {"d"}); @@ -157,14 +170,16 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionElementwiseAdd) { IsReachable is_reachable; EXPECT_TRUE(is_reachable(graph)("a", "d")); - auto pass = PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass"); + auto pass = + PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass"); int original_nodes_num = graph->Nodes().size(); graph = pass->Apply(std::move(graph)); int current_nodes_num = graph->Nodes().size(); EXPECT_FALSE(is_reachable(graph)("a", "d")); - - EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added, current_nodes_num); + + EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added, + current_nodes_num); // Assert conv_relu op in newly generated graph int conv_count = 0; int elementwise_add_count = 0; @@ -185,14 +200,14 @@ TEST(ConvElementwiseAddMKLDNNFusePass, SigmoidConvolutionAddElementwiseRelu) { auto build_program_desc = [&]() -> ProgramDesc { ProgramDesc prog; for (auto& v : - std::vector({"a", "b", "weights", "c", "d", "e", "f"})) { + std::vector({"a", "b", "weights", "c", "d", "e", "f"})) { auto* var = prog.MutableBlock(0)->Var(v); var->SetType(proto::VarType::LOD_TENSOR); if (v.find("weights")) { var->SetPersistable(true); } } - + SetOp(&prog, "sigmoid", {"a"}, {"b"}); SetOp(&prog, "conv2d", {"b", "weights"}, {"c"}); SetOp(&prog, "elementwise_add", {"d", "c"}, {"e"}); @@ -208,14 +223,16 @@ TEST(ConvElementwiseAddMKLDNNFusePass, SigmoidConvolutionAddElementwiseRelu) { EXPECT_TRUE(is_reachable(graph)("a", "f")); - auto pass = PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass"); + auto pass = + PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass"); int original_nodes_num = graph->Nodes().size(); graph = pass->Apply(std::move(graph)); int current_nodes_num = graph->Nodes().size(); EXPECT_TRUE(is_reachable(graph)("a", "f")); - EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added, current_nodes_num); + EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added, + current_nodes_num); // Assert conv_relu op in newly generated graph int conv_count = 0; int elementwise_add_count = 0;