diff --git a/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc index e470960ee178cbaa3ef47d77f9f600a4e2e471db..5a6d20e8478a79b8df8e2028b1d1c004d15f9513 100644 --- a/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc @@ -99,10 +99,9 @@ bool IsReachable(ir::Graph* graph, Node* from, Node* to) { return false; } -std::pair ResidualConnectionMKLDNNFusePass::HasBias( - const Node& op) const { +std::pair HasBias(const Node& op, const std::string& bias_name) { auto bias_input_names = op.Op()->Inputs(); - auto bias_it = bias_input_names.find("Bias"); + auto bias_it = bias_input_names.find(bias_name); if (bias_it != std::end(bias_input_names)) { bool has_bias = !bias_it->second.empty(); @@ -121,6 +120,74 @@ std::pair ResidualConnectionMKLDNNFusePass::HasBias( return std::make_pair(false, nullptr); } +ResidualConnectionMKLDNNFusePass::FuseHandler::FuseHandler( + const ResidualConnectionMKLDNNFusePass::ConvFunc& get_node_from_conv_op, + const ResidualConnectionMKLDNNFusePass::ElementwiseAddFunc& + get_node_from_elementwise_add_op, + const ResidualConnectionMKLDNNFusePass::CanFuseFunc& can_fuse_func) + : get_node_from_conv_op{get_node_from_conv_op}, + get_node_from_elementwise_add_op{get_node_from_elementwise_add_op}, + can_fuse_func{can_fuse_func} {} + +void ResidualConnectionMKLDNNFusePass::FuseHandler::operator()( + const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) { + Node* conv_op; + Node* conv_input; + Node* conv_filter; + Node* conv_output; + + Node* elementwise_add_op; + Node* elementwise_add_identity; + Node* elementwise_add_out; + + std::tie(conv_op, conv_input, conv_filter, conv_output) = + get_node_from_conv_op(subgraph); + std::tie(elementwise_add_op, elementwise_add_identity, elementwise_add_out) = + get_node_from_elementwise_add_op(subgraph); + + if (!can_fuse_func(conv_op, elementwise_add_op)) return; + + if (!IsReachable(graph, elementwise_add_identity, conv_output)) return; + + OpDesc op_desc; + op_desc.SetType("conv2d"); + + op_desc.SetInput("Input", {conv_input->Name()}); + op_desc.SetInput("Filter", {conv_filter->Name()}); + op_desc.SetInput("ResidualData", {elementwise_add_identity->Name()}); + op_desc.SetOutput("Output", {conv_output->Name()}); + + bool has_bias; + Node* conv_bias; + + std::tie(has_bias, conv_bias) = HasBias(*conv_op, "Bias"); + + if (has_bias) { + op_desc.SetInput("Bias", {conv_bias->Name()}); + } + + for (const auto& attr : conv_op->Op()->GetAttrMap()) { + op_desc.SetAttr(attr.first, attr.second); + } + + op_desc.SetAttr("fuse_residual_connection", true); + + auto fused_conv_op = graph->CreateOpNode(&op_desc); + + IR_NODE_LINK_TO(conv_input, fused_conv_op); + IR_NODE_LINK_TO(conv_filter, fused_conv_op); + IR_NODE_LINK_TO(elementwise_add_identity, fused_conv_op); + IR_NODE_LINK_TO(fused_conv_op, conv_output); + + if (has_bias) { + IR_NODE_LINK_TO(conv_bias, fused_conv_op); + } + + CorrectGraphEdges(graph, elementwise_add_out, conv_output); + GraphSafeRemoveNodes(graph, + {elementwise_add_out, conv_op, elementwise_add_op}); +} + graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsX( const std::string& name_scope_, graph_ptr graph) const { GraphPatternDetector gpd; @@ -135,8 +202,8 @@ graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsX( pattern->NewNode(elementwise_add_pattern.elementwise_add_y_repr())); conv_output->AsIntermediate(); - auto get_node_from_conv = [](const patterns::Conv& conv_pattern, - const GraphPatternDetector::subgraph_t& subgraph) + auto get_node_from_conv = + [&conv_pattern](const GraphPatternDetector::subgraph_t& subgraph) -> std::tuple { GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern); GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern); @@ -146,8 +213,7 @@ graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsX( return std::make_tuple(conv_op, conv_input, conv_filter, conv_output); }; - auto get_node_from_elementwise_add = []( - const patterns::ElementwiseAdd& elementwise_add_pattern, + auto get_node_from_elementwise_add = [&elementwise_add_pattern]( const GraphPatternDetector::subgraph_t& subgraph) -> std::tuple { GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op, @@ -161,10 +227,14 @@ graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsX( elementwise_add_out); }; - auto handler = - GenerateFuseHandler(conv_pattern, elementwise_add_pattern, - get_node_from_conv, get_node_from_elementwise_add); - gpd(graph.get(), handler); + auto can_fuse = [this](Node* op1, Node* op2) -> bool { + return this->FindFuseOption(*op1, *op2) == FUSE_MKLDNN; + }; + + auto fuse_handler = + FuseHandler{get_node_from_conv, get_node_from_elementwise_add, can_fuse}; + + gpd(graph.get(), fuse_handler); return graph; } @@ -183,8 +253,8 @@ graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsY( conv_output); conv_output->AsIntermediate(); - auto get_node_from_conv = [](const patterns::Conv& conv_pattern, - const GraphPatternDetector::subgraph_t& subgraph) + auto get_node_from_conv = + [&conv_pattern](const GraphPatternDetector::subgraph_t& subgraph) -> std::tuple { GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern); GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern); @@ -194,8 +264,7 @@ graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsY( return std::make_tuple(conv_op, conv_input, conv_filter, conv_output); }; - auto get_node_from_elementwise_add = []( - const patterns::ElementwiseAdd& elementwise_add_pattern, + auto get_node_from_elementwise_add = [&elementwise_add_pattern]( const GraphPatternDetector::subgraph_t& subgraph) -> std::tuple { GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op, @@ -209,10 +278,14 @@ graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsY( elementwise_add_out); }; - auto handler = - GenerateFuseHandler(conv_pattern, elementwise_add_pattern, - get_node_from_conv, get_node_from_elementwise_add); - gpd(graph.get(), handler); + auto can_fuse = [this](Node* op1, Node* op2) -> bool { + return this->FindFuseOption(*op1, *op2) == FUSE_MKLDNN; + }; + + auto fuse_handler = + FuseHandler{get_node_from_conv, get_node_from_elementwise_add, can_fuse}; + + gpd(graph.get(), fuse_handler); return graph; } diff --git a/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h index 7dfff3c2d3b28d41a1b37a29e9c1427f7b0b1fa9..b614b5c5230ccf25cf3bfc118addf25001cea1f5 100644 --- a/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h +++ b/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h @@ -15,6 +15,7 @@ #pragma once #include +#include #include #include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/graph.h" @@ -28,24 +29,32 @@ using graph_ptr = std::unique_ptr; void CorrectGraphEdges(Graph* graph, Node* from, Node* to); bool IsReachable(ir::Graph* graph, Node* from, Node* to); - -using handler_func = std::function; +std::pair HasBias(const Node& op, const std::string& bias_name); class ResidualConnectionMKLDNNFusePass : public FusePassBase { private: graph_ptr FuseConvAsX(const std::string& name_scope_, graph_ptr graph) const; graph_ptr FuseConvAsY(const std::string& name_scope_, graph_ptr graph) const; - std::pair HasBias(const Node& op) const; + template + using GetNodeFunc = + std::function; + using ConvFunc = GetNodeFunc>; + using ElementwiseAddFunc = GetNodeFunc>; + using CanFuseFunc = std::function; + + struct FuseHandler { + FuseHandler(const ConvFunc& get_node_from_conv_op, + const ElementwiseAddFunc& get_node_from_elementwise_add_op, + const CanFuseFunc& can_fuse_func); + + ConvFunc get_node_from_conv_op; + ElementwiseAddFunc get_node_from_elementwise_add_op; + CanFuseFunc can_fuse_func; - template - HANDLER_FUNC GenerateFuseHandler( - const patterns::Conv& conv_pattern, - const patterns::ElementwiseAdd& elementwise_add_pattern, - CONV_FUNC get_node_from_conv_op, - ELEMENTWISE_ADD_FUNC get_node_from_elementwise_add_op) const; + void operator()(const GraphPatternDetector::subgraph_t& subgraph, + Graph* graph); + }; public: virtual ~ResidualConnectionMKLDNNFusePass() {} @@ -55,74 +64,6 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase { const std::string name_scope_{"residual_connection_fuse_pass"}; }; - -template -HANDLER_FUNC ResidualConnectionMKLDNNFusePass::GenerateFuseHandler( - const patterns::Conv& conv_pattern, - const patterns::ElementwiseAdd& elementwise_add_pattern, - CONV_FUNC get_node_from_conv_op, - ELEMENTWISE_ADD_FUNC get_node_from_elementwise_add_op) const { - return [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) { - Node* conv_op; - Node* conv_input; - Node* conv_filter; - Node* conv_output; - - Node* elementwise_add_op; - Node* elementwise_add_identity; - Node* elementwise_add_out; - - std::tie(conv_op, conv_input, conv_filter, conv_output) = - get_node_from_conv_op(conv_pattern, subgraph); - std::tie(elementwise_add_op, elementwise_add_identity, - elementwise_add_out) = - get_node_from_elementwise_add_op(elementwise_add_pattern, subgraph); - - if (this->FindFuseOption(*conv_op, *elementwise_add_op) != FUSE_MKLDNN) - return; - - if (!IsReachable(graph, elementwise_add_identity, conv_output)) return; - - OpDesc op_desc; - op_desc.SetType("conv2d"); - - op_desc.SetInput("Input", {conv_input->Name()}); - op_desc.SetInput("Filter", {conv_filter->Name()}); - op_desc.SetInput("ResidualData", {elementwise_add_identity->Name()}); - op_desc.SetOutput("Output", {conv_output->Name()}); - - bool has_bias; - Node* conv_bias; - - std::tie(has_bias, conv_bias) = this->HasBias(*conv_op); - - if (has_bias) { - op_desc.SetInput("Bias", {conv_bias->Name()}); - } - - for (const auto& attr : conv_op->Op()->GetAttrMap()) { - op_desc.SetAttr(attr.first, attr.second); - } - - op_desc.SetAttr("fuse_residual_connection", true); - - auto fused_conv_op = graph->CreateOpNode(&op_desc); - - IR_NODE_LINK_TO(conv_input, fused_conv_op); - IR_NODE_LINK_TO(conv_filter, fused_conv_op); - IR_NODE_LINK_TO(elementwise_add_identity, fused_conv_op); - IR_NODE_LINK_TO(fused_conv_op, conv_output); - - if (has_bias) { - IR_NODE_LINK_TO(conv_bias, fused_conv_op); - } - - CorrectGraphEdges(graph, elementwise_add_out, conv_output); - GraphSafeRemoveNodes(graph, - {elementwise_add_out, conv_op, elementwise_add_op}); - }; -} } // namespace ir } // namespace framework } // namespace paddle