diff --git a/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc index 8d0035ae98b093979eb8bbcc0a8d6ae5356d951f..5376fc163e259e5049955052baf02fd614aa511e 100644 --- a/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc @@ -14,14 +14,15 @@ #include "paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h" #include -#include +#include +#include +#include #include "paddle/fluid/framework/ir/graph_traits.h" namespace paddle { namespace framework { namespace ir { -namespace { // The function keeps the graph consistent by replacing // a node 'from' in the set of inputs nodes @@ -51,99 +52,382 @@ void CorrectGraphEdges(Graph* graph, Node* from, Node* to) { } } } -} // namespace -using graph_ptr = std::unique_ptr; -graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const { - FusePassBase::Init(name_scope_, graph.get()); +bool IsReachable(ir::Graph* graph, Node* from, Node* to) { + auto find_node = [](ir::Graph* graph, const Node* node) -> Node* { + for (auto n : graph->Nodes()) { + if (n == node) { + return n; + } + } - GraphPatternDetector gpd; - auto pattern = gpd.mutable_pattern(); + return nullptr; + }; - patterns::Conv conv_pattern{pattern, name_scope_}; - auto conv_output = conv_pattern(); + if (from == to) { + return true; + } - patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope_}; - elementwise_add_pattern(conv_output); + std::map visited; - conv_output->AsIntermediate(); + for (auto& node : GraphTraits::DFS(*graph)) { + visited[&node] = false; + } - auto conv_op_has_bias = [](const Node& conv_op) -> std::pair { - auto bias_input_names = conv_op.Op()->Inputs(); - auto bias_it = bias_input_names.find("Bias"); - - if (bias_it != std::end(bias_input_names)) { - bool has_bias = !bias_it->second.empty(); - - if (has_bias) { - auto conv_bias_names = bias_it->second; - auto conv_bias_names_it = - std::find_if(std::begin(conv_op.inputs), std::end(conv_op.inputs), - [&conv_bias_names](Node* n) -> bool { - return n->Name() == conv_bias_names[0]; - }); - return std::make_pair(has_bias, *conv_bias_names_it); - } - } + visited[from] = true; - return std::make_pair(false, nullptr); - }; + std::list queue; + queue.push_back(from); - auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, - Graph* g) { - GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern); - GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern); - GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern); - GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern); - GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op, - elementwise_add_pattern); - GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_x, elementwise_add_x, - elementwise_add_pattern); - GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out, - elementwise_add_pattern); + while (!queue.empty()) { + auto cur = find_node(graph, queue.front()); + queue.pop_front(); - if (FindFuseOption(*conv_op, *elementwise_add_op) != FUSE_MKLDNN) return; + if (!cur) return false; - OpDesc op_desc; - op_desc.SetType("conv2d"); + for (auto n : cur->outputs) { + if (n == to) { + return true; + } - op_desc.SetInput("Input", {conv_input->Name()}); - op_desc.SetInput("Filter", {conv_filter->Name()}); - op_desc.SetInput("ResidualData", {elementwise_add_x->Name()}); - op_desc.SetOutput("Output", {conv_output->Name()}); + if (!visited[n]) { + visited[n] = true; + queue.push_back(n); + } + } + } + return false; +} - bool has_bias; - Node* conv_bias; +boost::optional HasBias(const Node& op, const std::string& bias_name) { + auto bias_input_names = op.Op()->Inputs(); + auto bias_it = bias_input_names.find(bias_name); - std::tie(has_bias, conv_bias) = conv_op_has_bias(*conv_op); + if (bias_it != std::end(bias_input_names)) { + bool has_bias = !bias_it->second.empty(); if (has_bias) { - op_desc.SetInput("Bias", {conv_bias->Name()}); + auto bias_names = bias_it->second; + auto bias_names_it = + std::find_if(std::begin(op.inputs), std::end(op.inputs), + [&bias_names](Node* n) -> bool { + return n->Name() == bias_names[0]; + }); + return *bias_names_it; } + } - for (const auto& attr : conv_op->Op()->GetAttrMap()) { - op_desc.SetAttr(attr.first, attr.second); - } + return boost::none; +} - op_desc.SetAttr("fuse_residual_connection", true); +ResidualConnectionMKLDNNFusePass::IdentityFuseHandle::IdentityFuseHandle( + const ResidualConnectionMKLDNNFusePass::CanFuseFunc& can_fuse_func, + const ResidualConnectionMKLDNNFusePass::IdentityConvFunc& + get_node_from_conv_op, + const ResidualConnectionMKLDNNFusePass::IdentityElementwiseAddFunc& + get_node_from_elementwise_add_op) + : fusion_stats{std::make_shared(0)}, + can_fuse_func{can_fuse_func}, + get_node_from_conv_op{get_node_from_conv_op}, + get_node_from_elementwise_add_op{get_node_from_elementwise_add_op} {} + +void ResidualConnectionMKLDNNFusePass::IdentityFuseHandle::operator()( + const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) { + Node* conv_op; + Node* conv_input; + Node* conv_filter; + Node* conv_output; + + Node* elementwise_add_op; + Node* elementwise_add_identity; + Node* elementwise_add_out; + + std::tie(conv_op, conv_input, conv_filter, conv_output) = + get_node_from_conv_op(subgraph); + std::tie(elementwise_add_op, elementwise_add_identity, elementwise_add_out) = + get_node_from_elementwise_add_op(subgraph); + + if (!can_fuse_func(conv_op, elementwise_add_op)) return; + + if (!IsReachable(graph, elementwise_add_identity, conv_output)) return; + + OpDesc op_desc; + op_desc.SetType("conv2d"); + + op_desc.SetInput("Input", {conv_input->Name()}); + op_desc.SetInput("Filter", {conv_filter->Name()}); + op_desc.SetInput("ResidualData", {elementwise_add_identity->Name()}); + op_desc.SetOutput("Output", {conv_output->Name()}); + + auto conv_bias = HasBias(*conv_op, "Bias"); + + if (conv_bias) { + op_desc.SetInput("Bias", {(*conv_bias)->Name()}); + } - auto fused_conv_op = g->CreateOpNode(&op_desc); + for (const auto& attr : conv_op->Op()->GetAttrMap()) { + op_desc.SetAttr(attr.first, attr.second); + } - IR_NODE_LINK_TO(conv_input, fused_conv_op); - IR_NODE_LINK_TO(conv_filter, fused_conv_op); - IR_NODE_LINK_TO(elementwise_add_x, fused_conv_op); - IR_NODE_LINK_TO(fused_conv_op, conv_output); + op_desc.SetAttr("fuse_residual_connection", true); - if (has_bias) { - IR_NODE_LINK_TO(conv_bias, fused_conv_op); - } + auto fused_conv_op = graph->CreateOpNode(&op_desc); - CorrectGraphEdges(g, elementwise_add_out, conv_output); - GraphSafeRemoveNodes(g, {elementwise_add_out, conv_op, elementwise_add_op}); - }; + IR_NODE_LINK_TO(conv_input, fused_conv_op); + IR_NODE_LINK_TO(conv_filter, fused_conv_op); + IR_NODE_LINK_TO(elementwise_add_identity, fused_conv_op); + IR_NODE_LINK_TO(fused_conv_op, conv_output); - gpd(graph.get(), handler); + if (conv_bias) { + IR_NODE_LINK_TO((*conv_bias), fused_conv_op); + } + CorrectGraphEdges(graph, elementwise_add_out, conv_output); + GraphSafeRemoveNodes(graph, + {elementwise_add_out, conv_op, elementwise_add_op}); + (*fusion_stats)++; +} + +ResidualConnectionMKLDNNFusePass::ProjectionFuseHandle::ProjectionFuseHandle( + const ResidualConnectionMKLDNNFusePass::CanFuseFunc& can_fuse_func, + const ResidualConnectionMKLDNNFusePass::ProjectionConvFunc& + get_node_from_conv_x_op, + const ResidualConnectionMKLDNNFusePass::ProjectionConvFunc& + get_node_from_conv_y_op, + const ResidualConnectionMKLDNNFusePass::ProjectionElementwiseAddFunc& + get_node_from_elementwise_add_op) + : fusion_stats{std::make_shared(0)}, + can_fuse_func{can_fuse_func}, + get_node_from_conv_x_op{get_node_from_conv_x_op}, + get_node_from_conv_y_op{get_node_from_conv_y_op}, + get_node_from_elementwise_add_op{get_node_from_elementwise_add_op} {} + +void ResidualConnectionMKLDNNFusePass::ProjectionFuseHandle::operator()( + const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) { + Node* conv_x_op; + Node* conv_x_input; + Node* conv_x_filter; + Node* conv_x_output; + + Node* conv_y_op; + Node* conv_y_input; + Node* conv_y_filter; + Node* conv_y_output; + + Node* elementwise_add_op; + Node* elementwise_add_out; + + std::tie(conv_x_op, conv_x_input, conv_x_filter, conv_x_output) = + get_node_from_conv_x_op(subgraph); + std::tie(conv_y_op, conv_y_input, conv_y_filter, conv_y_output) = + get_node_from_conv_y_op(subgraph); + std::tie(elementwise_add_op, elementwise_add_out) = + get_node_from_elementwise_add_op(subgraph); + + if (!can_fuse_func(conv_x_op, elementwise_add_op)) return; + if (!can_fuse_func(conv_y_op, elementwise_add_op)) return; + + Node* projection_node; + Node* residual_conv_op; + Node* residual_conv_input; + Node* residual_conv_filter; + Node* residual_conv_output; + + if (IsReachable(graph, conv_x_input, conv_y_output)) { + projection_node = conv_x_output; + residual_conv_op = conv_y_op; + residual_conv_input = conv_y_input; + residual_conv_filter = conv_y_filter; + residual_conv_output = conv_y_output; + } else if (IsReachable(graph, conv_y_input, conv_x_output)) { + projection_node = conv_y_output; + residual_conv_op = conv_x_op; + residual_conv_input = conv_x_input; + residual_conv_filter = conv_x_filter; + residual_conv_output = conv_x_output; + } else { + return; + } + + OpDesc op_desc; + op_desc.SetType("conv2d"); + + op_desc.SetInput("Input", {residual_conv_input->Name()}); + op_desc.SetInput("Filter", {residual_conv_filter->Name()}); + op_desc.SetInput("ResidualData", {projection_node->Name()}); + op_desc.SetOutput("Output", {residual_conv_output->Name()}); + + auto residual_conv_bias = HasBias(*residual_conv_op, "Bias"); + + if (residual_conv_bias) { + op_desc.SetInput("Bias", {(*residual_conv_bias)->Name()}); + } + + for (const auto& attr : residual_conv_op->Op()->GetAttrMap()) { + op_desc.SetAttr(attr.first, attr.second); + } + + op_desc.SetAttr("fuse_residual_connection", true); + + auto fused_conv_op = graph->CreateOpNode(&op_desc); + + IR_NODE_LINK_TO(residual_conv_input, fused_conv_op); + IR_NODE_LINK_TO(residual_conv_filter, fused_conv_op); + IR_NODE_LINK_TO(projection_node, fused_conv_op); + IR_NODE_LINK_TO(fused_conv_op, residual_conv_output); + + if (residual_conv_bias) { + IR_NODE_LINK_TO((*residual_conv_bias), fused_conv_op); + } + + CorrectGraphEdges(graph, elementwise_add_out, residual_conv_output); + GraphSafeRemoveNodes( + graph, {elementwise_add_out, residual_conv_op, elementwise_add_op}); + (*fusion_stats)++; +} + +std::tuple +ResidualConnectionMKLDNNFusePass::GetNodesFromConv( + const patterns::Conv& conv_pattern, + const GraphPatternDetector::subgraph_t& subgraph) const { + GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern); + + return std::make_tuple(conv_op, conv_input, conv_filter, conv_output); +} + +GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsX( + const std::string& name_scope, + const GraphWithStats& graph_with_stats) const { + ir::Graph* graph; + int stats; + + std::tie(graph, stats) = graph_with_stats; + + GraphPatternDetector gpd; + auto pattern = gpd.mutable_pattern(); + + patterns::Conv conv_pattern{pattern, name_scope}; + auto conv_output = conv_pattern(); + + patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope}; + elementwise_add_pattern( + conv_output, + pattern->NewNode(elementwise_add_pattern.elementwise_add_y_repr())); + conv_output->AsIntermediate(); + + auto get_node_from_elementwise_add = [&elementwise_add_pattern]( + const GraphPatternDetector::subgraph_t& subgraph) + -> std::tuple { + GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op, + elementwise_add_pattern); + GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_y, elementwise_add_y, + elementwise_add_pattern); + GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out, + elementwise_add_pattern); + + return std::make_tuple(elementwise_add_op, elementwise_add_y, + elementwise_add_out); + }; + + return ExecuteHandleOnGraph( + &gpd, graph_with_stats, + [this, &conv_pattern](const GraphPatternDetector::subgraph_t& subgraph) { + return GetNodesFromConv(conv_pattern, subgraph); + }, + get_node_from_elementwise_add); +} + +GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY( + const std::string& name_scope, + const GraphWithStats& graph_with_stats) const { + GraphPatternDetector gpd; + auto pattern = gpd.mutable_pattern(); + + patterns::Conv conv_pattern{pattern, name_scope}; + auto conv_output = conv_pattern(); + + patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope}; + elementwise_add_pattern( + pattern->NewNode(elementwise_add_pattern.elementwise_add_x_repr()), + conv_output); + conv_output->AsIntermediate(); + + auto get_node_from_elementwise_add = [&elementwise_add_pattern]( + const GraphPatternDetector::subgraph_t& subgraph) + -> std::tuple { + GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op, + elementwise_add_pattern); + GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_x, elementwise_add_x, + elementwise_add_pattern); + GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out, + elementwise_add_pattern); + + return std::make_tuple(elementwise_add_op, elementwise_add_x, + elementwise_add_out); + }; + + return ExecuteHandleOnGraph( + &gpd, graph_with_stats, + [this, &conv_pattern](const GraphPatternDetector::subgraph_t& subgraph) { + return GetNodesFromConv(conv_pattern, subgraph); + }, + get_node_from_elementwise_add); +} + +GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv( + const std::string& name_scope, + const GraphWithStats& graph_with_stats) const { + GraphPatternDetector gpd; + auto pattern = gpd.mutable_pattern(); + + patterns::Conv conv_x_pattern{pattern, name_scope}; + auto conv_x_output = conv_x_pattern(); + + patterns::Conv conv_y_pattern{pattern, name_scope}; + auto conv_y_output = conv_y_pattern(); + + patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope}; + elementwise_add_pattern(conv_x_output, conv_y_output); + conv_x_output->AsIntermediate(); + conv_y_output->AsIntermediate(); + + auto get_node_from_elementwise_add = [&elementwise_add_pattern]( + const GraphPatternDetector::subgraph_t& subgraph) + -> std::tuple { + GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op, + elementwise_add_pattern); + GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out, + elementwise_add_pattern); + + return std::make_tuple(elementwise_add_op, elementwise_add_out); + }; + + return ExecuteHandleOnGraph( + &gpd, graph_with_stats, + [this, + &conv_x_pattern](const GraphPatternDetector::subgraph_t& subgraph) { + return GetNodesFromConv(conv_x_pattern, subgraph); + }, + [this, + &conv_y_pattern](const GraphPatternDetector::subgraph_t& subgraph) { + return GetNodesFromConv(conv_y_pattern, subgraph); + }, + get_node_from_elementwise_add); +} + +graph_ptr ResidualConnectionMKLDNNFusePass::ApplyImpl(graph_ptr graph) const { + FusePassBase::Init(name_scope_, graph.get()); + auto fused_graph_with_stats = FuseConvAsY( + name_scope_, + FuseConvAsX( + name_scope_, + FuseProjectionConv(name_scope_, std::make_pair(graph.get(), 0)))); + + std::cout << "Fused graph " << fused_graph_with_stats.second << std::endl; + AddStatis(fused_graph_with_stats.second); return graph; } } // namespace ir @@ -151,4 +435,4 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const { } // namespace paddle REGISTER_PASS(conv_elementwise_add_mkldnn_fuse_pass, - paddle::framework::ir::ConvElementwiseAddMKLDNNFusePass); + paddle::framework::ir::ResidualConnectionMKLDNNFusePass); diff --git a/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h index f4a899f1adb5e993895a40a8cfb846a67b41bb22..6629dae425ae85446fe2f6c8c172ca53f5ae8bea 100644 --- a/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h +++ b/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h @@ -15,24 +15,119 @@ #pragma once #include +#include +#include #include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include + namespace paddle { namespace framework { namespace ir { -class ConvElementwiseAddMKLDNNFusePass : public FusePassBase { +using graph_ptr = std::unique_ptr; +using GraphWithStats = std::pair; + +void CorrectGraphEdges(Graph* graph, Node* from, Node* to); +bool IsReachable(ir::Graph* graph, Node* from, Node* to); +boost::optional HasBias(const Node& op, const std::string& bias_name); + +class ResidualConnectionMKLDNNFusePass : public FusePassBase { + private: + GraphWithStats FuseConvAsX(const std::string& name_scope, + const GraphWithStats& graph_with_stats) const; + GraphWithStats FuseConvAsY(const std::string& name_scope, + const GraphWithStats& graph_with_stats) const; + GraphWithStats FuseProjectionConv( + const std::string& name_scope, + const GraphWithStats& graph_with_stats) const; + + template + using GetNodeFunc = + std::function; + using IdentityConvFunc = GetNodeFunc>; + using IdentityElementwiseAddFunc = + GetNodeFunc>; + + using ProjectionConvFunc = IdentityConvFunc; + using ProjectionElementwiseAddFunc = GetNodeFunc>; + + using CanFuseFunc = std::function; + + std::tuple GetNodesFromConv( + const patterns::Conv& conv_pattern, + const GraphPatternDetector::subgraph_t& subgraph) const; + + std::tuple GetNodesFromProjectionConv( + const patterns::Conv& conv_pattern, + const GraphPatternDetector::subgraph_t& subgraph) const; + + template + GraphWithStats ExecuteHandleOnGraph(GraphPatternDetector* gpd, + const GraphWithStats& graph_with_stats, + OpFuncs&&... op_funcs) const { + ir::Graph* graph; + int stats; + + std::tie(graph, stats) = graph_with_stats; + + auto can_fuse = [this](Node* op1, Node* op2) -> bool { + return this->FindFuseOption(*op1, *op2) == FUSE_MKLDNN; + }; + + auto fuse_handle = HandleType{can_fuse, std::forward(op_funcs)...}; + + (*gpd)(graph, fuse_handle); + + return std::make_pair(graph, stats + fuse_handle.get_stats()); + } + + struct IdentityFuseHandle { + IdentityFuseHandle( + const CanFuseFunc& can_fuse_func, + const IdentityConvFunc& get_node_from_conv_op, + const IdentityElementwiseAddFunc& get_node_from_elementwise_add_op); + + void operator()(const GraphPatternDetector::subgraph_t& subgraph, + Graph* graph); + int get_stats() const { return *fusion_stats; } + + private: + std::shared_ptr fusion_stats; + CanFuseFunc can_fuse_func; + IdentityConvFunc get_node_from_conv_op; + IdentityElementwiseAddFunc get_node_from_elementwise_add_op; + }; + + struct ProjectionFuseHandle { + ProjectionFuseHandle( + const CanFuseFunc& can_fuse_func, + const ProjectionConvFunc& get_node_from_conv_x_op, + const ProjectionConvFunc& get_node_from_conv_y_op, + const ProjectionElementwiseAddFunc& get_node_from_elementwise_add_op); + + void operator()(const GraphPatternDetector::subgraph_t& subgraph, + Graph* graph); + int get_stats() const { return *fusion_stats; } + + private: + std::shared_ptr fusion_stats; + CanFuseFunc can_fuse_func; + ProjectionConvFunc get_node_from_conv_x_op; + ProjectionConvFunc get_node_from_conv_y_op; + ProjectionElementwiseAddFunc get_node_from_elementwise_add_op; + }; + public: - virtual ~ConvElementwiseAddMKLDNNFusePass() {} + virtual ~ResidualConnectionMKLDNNFusePass() {} protected: - std::unique_ptr ApplyImpl(std::unique_ptr graph) const; + std::unique_ptr ApplyImpl(graph_ptr graph) const; - const std::string name_scope_{"residual_connections_fuse_pass"}; + const std::string name_scope_{"residual_connection_fuse_pass"}; }; - } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass_tester.cc index 348a3dfc5da78e860742595a60a0b7a8b2d92243..61ba097fd8cb55e25bda1947ea97d53308c55bd3 100644 --- a/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass_tester.cc @@ -40,7 +40,7 @@ void SetOp(ProgramDesc* prog, const std::string& type, op->SetOutput(output.first, {output.second}); } -struct IsReachable { +struct TestIsReachable { using func = std::function; auto operator()(const std::unique_ptr& graph) -> func { @@ -89,7 +89,9 @@ struct IsReachable { } }; -void AssertOpsCount(const std::unique_ptr& graph) { +void AssertOpsCount(const std::unique_ptr& graph, + int expected_conv_count, + int expected_elementwise_add_count = 0) { int conv_count = 0; int elementwise_add_count = 0; @@ -101,8 +103,8 @@ void AssertOpsCount(const std::unique_ptr& graph) { ++elementwise_add_count; } } - EXPECT_EQ(conv_count, 1); - EXPECT_EQ(elementwise_add_count, 0); + EXPECT_EQ(conv_count, expected_conv_count); + EXPECT_EQ(elementwise_add_count, expected_elementwise_add_count); } ProgramDesc BuildProgramDesc(const std::vector& transient_vars, @@ -127,22 +129,13 @@ ProgramDesc BuildProgramDesc(const std::vector& transient_vars, return prog; } -} // namespace - -TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) { - auto prog = - BuildProgramDesc({"a", "b", "c", "d", "e", "f"}, {"bias", "weights"}); - - SetOp(&prog, "conv2d", - {{"Input", "a"}, {"Bias", "bias"}, {"Filter", "weights"}}, - {"Output", "b"}); - SetOp(&prog, "elementwise_add", {{"X", "b"}, {"Y", "c"}}, {"Out", "d"}); - SetOp(&prog, "relu", {{"X", "d"}}, {"Out", "e"}); - std::unique_ptr graph(new ir::Graph(prog)); +void RunPassAndAssert(ProgramDesc* prog, const std::string& from, + const std::string& to, int expected_conv_num) { + std::unique_ptr graph(new ir::Graph(*prog)); - IsReachable is_reachable; - EXPECT_TRUE(is_reachable(graph)("a", "relu")); + TestIsReachable is_reachable; + EXPECT_TRUE(is_reachable(graph)(from, to)); auto pass = PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass"); @@ -150,82 +143,87 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) { graph = pass->Apply(std::move(graph)); int current_nodes_num = graph->Nodes().size(); - EXPECT_TRUE(is_reachable(graph)("a", "relu")); + EXPECT_TRUE(is_reachable(graph)(from, to)); EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added, current_nodes_num); - AssertOpsCount(graph); + AssertOpsCount(graph, expected_conv_num); } +} // namespace -TEST(ConvElementwiseAddMKLDNNFusePass, - ConvolutionWithElementwiseAddReluNoBias) { - auto prog = BuildProgramDesc({"a", "b", "c", "d", "e"}, {"weights"}); - SetOp(&prog, "conv2d", {{"Input", "a"}, {"Filter", "weights"}}, - {"Output", "b"}); - SetOp(&prog, "elementwise_add", {{"X", "b"}, {"Y", "c"}}, {"Out", "d"}); - SetOp(&prog, "relu", {{"X", "d"}}, {"Out", "e"}); - - std::unique_ptr graph(new ir::Graph(prog)); +TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionAsYWithElementwiseAddRelu) { + auto prog = BuildProgramDesc({"a", "b", "c", "d", "e"}, {"bias", "weights"}); - IsReachable is_reachable; + SetOp(&prog, "sigmoid", {{"X", "a"}}, {"Out", "b"}); + SetOp(&prog, "conv2d", + {{"Input", "b"}, {"Bias", "bias"}, {"Filter", "weights"}}, + {"Output", "c"}); - EXPECT_TRUE(is_reachable(graph)("a", "relu")); + SetOp(&prog, "elementwise_add", {{"X", "a"}, {"Y", "c"}}, {"Out", "d"}); + SetOp(&prog, "relu", {{"X", "d"}}, {"Out", "e"}); - auto pass = - PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass"); - int original_nodes_num = graph->Nodes().size(); - graph = pass->Apply(std::move(graph)); - int current_nodes_num = graph->Nodes().size(); + RunPassAndAssert(&prog, "a", "relu", 1); +} - EXPECT_TRUE(is_reachable(graph)("a", "relu")); +TEST(ConvElementwiseAddMKLDNNFusePass, + ConvolutionAsYWithElementwiseAddReluNoBias) { + auto prog = BuildProgramDesc({"a", "b", "c", "d", "e"}, {"weights"}); - EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added, - current_nodes_num); + SetOp(&prog, "sigmoid", {{"X", "a"}}, {"Out", "b"}); + SetOp(&prog, "conv2d", {{"Input", "b"}, {"Filter", "weights"}}, + {"Output", "c"}); + SetOp(&prog, "elementwise_add", {{"X", "a"}, {"Y", "c"}}, {"Out", "d"}); + SetOp(&prog, "relu", {{"X", "d"}}, {"Out", "e"}); - AssertOpsCount(graph); + RunPassAndAssert(&prog, "a", "relu", 1); } -TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionElementwiseAdd) { - auto prog = BuildProgramDesc({"a", "b", "c", "d"}, {"bias", "weights"}); +TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionAsXWithElementwiseAddRelu) { + auto prog = BuildProgramDesc({"a", "b", "c", "d", "e"}, {"bias", "weights"}); + + SetOp(&prog, "sigmoid", {{"X", "a"}}, {"Out", "b"}); SetOp(&prog, "conv2d", - {{"Input", "a"}, {"Bias", "bias"}, {"Filter", "weights"}}, - {"Output", "b"}); - SetOp(&prog, "elementwise_add", {{"X", "b"}, {"Y", "c"}}, {"Out", "d"}); + {{"Input", "b"}, {"Bias", "bias"}, {"Filter", "weights"}}, + {"Output", "c"}); - std::unique_ptr graph(new ir::Graph(prog)); + SetOp(&prog, "elementwise_add", {{"X", "c"}, {"Y", "a"}}, {"Out", "d"}); + SetOp(&prog, "relu", {{"X", "d"}}, {"Out", "e"}); - IsReachable is_reachable; - EXPECT_TRUE(is_reachable(graph)("a", "d")); + RunPassAndAssert(&prog, "a", "relu", 1); +} - auto pass = - PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass"); - int original_nodes_num = graph->Nodes().size(); - graph = pass->Apply(std::move(graph)); - int current_nodes_num = graph->Nodes().size(); +TEST(ConvElementwiseAddMKLDNNFusePass, + ConvolutionAsXWithElementwiseAddReluNoBias) { + auto prog = BuildProgramDesc({"a", "b", "c", "d", "e"}, {"weights"}); - EXPECT_FALSE(is_reachable(graph)("a", "d")); + SetOp(&prog, "sigmoid", {{"X", "a"}}, {"Out", "b"}); + SetOp(&prog, "conv2d", {{"Input", "b"}, {"Filter", "weights"}}, + {"Output", "c"}); + SetOp(&prog, "elementwise_add", {{"X", "c"}, {"Y", "a"}}, {"Out", "d"}); + SetOp(&prog, "relu", {{"X", "d"}}, {"Out", "e"}); - EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added, - current_nodes_num); - AssertOpsCount(graph); + RunPassAndAssert(&prog, "a", "relu", 1); } -TEST(ConvElementwiseAddMKLDNNFusePass, SigmoidConvolutionAddElementwiseRelu) { +TEST(ConvElementwiseAddMKLDNNFusePass, NoFusion) { auto prog = - BuildProgramDesc({"a", "b", "c", "d", "e", "f"}, {"bias", "weights"}); + BuildProgramDesc({"a", "b", "c", "d", "e", "f", "g"}, {"weights"}); + SetOp(&prog, "sigmoid", {{"X", "a"}}, {"Out", "b"}); - SetOp(&prog, "conv2d", - {{"Input", "b"}, {"Bias", "bias"}, {"Filter", "weights"}}, + SetOp(&prog, "conv2d", {{"Input", "b"}, {"Filter", "weights"}}, {"Output", "c"}); - SetOp(&prog, "elementwise_add", {{"X", "c"}, {"Y", "d"}}, {"Out", "e"}); - SetOp(&prog, "relu", {{"X", "e"}}, {"Out", "f"}); - std::unique_ptr graph(new ir::Graph(prog)); + SetOp(&prog, "conv2d", {{"Input", "d"}, {"Filter", "weights"}}, + {"Output", "e"}); - IsReachable is_reachable; + SetOp(&prog, "elementwise_add", {{"X", "c"}, {"Y", "e"}}, {"Out", "f"}); + SetOp(&prog, "relu", {{"X", "f"}}, {"Out", "g"}); - EXPECT_TRUE(is_reachable(graph)("a", "f")); + std::unique_ptr graph(new ir::Graph(prog)); + + TestIsReachable is_reachable; + EXPECT_TRUE(is_reachable(graph)("a", "g")); auto pass = PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass"); @@ -233,11 +231,10 @@ TEST(ConvElementwiseAddMKLDNNFusePass, SigmoidConvolutionAddElementwiseRelu) { graph = pass->Apply(std::move(graph)); int current_nodes_num = graph->Nodes().size(); - EXPECT_TRUE(is_reachable(graph)("a", "f")); + EXPECT_TRUE(is_reachable(graph)("a", "g")); + EXPECT_EQ(original_nodes_num, current_nodes_num); - EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added, - current_nodes_num); - AssertOpsCount(graph); + AssertOpsCount(graph, 2, 1); } } // namespace ir diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index b534a5509279ef7bfc5fc92ec726224e6c5ed16f..f1f971656ae6ab6bbf66c4a75dd7cf68b5848b7b 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -1084,16 +1084,12 @@ PDNode *patterns::Conv::operator()() { return output_var; } -PDNode *patterns::ElementwiseAdd::operator()(PDNode *x_var) { +PDNode *patterns::ElementwiseAdd::operator()(PDNode *x_var, PDNode *y_var) { auto elementwise_add_op = pattern->NewNode(elementwise_add_op_repr()) ->assert_is_op("elementwise_add"); - x_var->assert_is_op_input("elementwise_add", "X"); - - auto y_var = pattern->NewNode(elementwise_add_x_repr()) - ->AsInput() - ->assert_is_op_input("elementwise_add", "Y"); - + x_var->AsInput()->assert_is_op_input("elementwise_add", "X"); + y_var->AsInput()->assert_is_op_input("elementwise_add", "Y"); auto out_var = pattern->NewNode(elementwise_add_out_repr()) ->AsOutput() ->assert_is_op_output("elementwise_add", "Out"); diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 1c5155df7867f95fb403d51bf633084a6c400f12..c12b9503fd817757ec8d1e988be3e449fc63c6ff 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -664,7 +664,7 @@ struct ElementwiseAdd : public PatternBase { ElementwiseAdd(PDPattern* pattern, const std::string& name_scope) : PatternBase(pattern, name_scope, "elementwise_add") {} - PDNode* operator()(PDNode* x_var); + PDNode* operator()(PDNode* x_var, PDNode* y_var); PATTERN_DECL_NODE(elementwise_add_op); PATTERN_DECL_NODE(elementwise_add_x);