diff --git a/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc index c537d05738529dcb885e86cbcabf4405fd75270b..2403e60df3918394e99c3284b2a417e336fc3bae 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc @@ -22,6 +22,7 @@ #include "paddle/fluid/framework/ir/graph_traits.h" #include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/string/pretty_log.h" namespace paddle { namespace framework { @@ -135,157 +136,9 @@ ResidualConnectionMKLDNNFusePass::ResidualConnectionMKLDNNFusePass() { .End(); } -ResidualConnectionMKLDNNFusePass::IdentityFuseHandle::IdentityFuseHandle( - const ResidualConnectionMKLDNNFusePass::CanFuseFunc& can_fuse_func, - const ResidualConnectionMKLDNNFusePass::IdentityConvFunc& - get_node_from_conv_op, - const ResidualConnectionMKLDNNFusePass::IdentityElementwiseAddFunc& - get_node_from_elementwise_add_op, - const ResidualConnectionMKLDNNFusePass* pass) - : fusion_stats{std::make_shared(0)}, - can_fuse_func{can_fuse_func}, - get_node_from_conv_op{get_node_from_conv_op}, - get_node_from_elementwise_add_op{get_node_from_elementwise_add_op}, - pass_{pass} {} - -void ResidualConnectionMKLDNNFusePass::IdentityFuseHandle::operator()( - const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) { - Node* conv_op; - Node* conv_input; - Node* conv_filter; - Node* conv_output; - - Node* elementwise_add_op; - Node* elementwise_add_identity; - Node* elementwise_add_out; - - std::tie(conv_op, conv_input, conv_filter, conv_output) = - get_node_from_conv_op(subgraph); - std::tie(elementwise_add_op, elementwise_add_identity, elementwise_add_out) = - get_node_from_elementwise_add_op(subgraph); - - if (!can_fuse_func(conv_op, elementwise_add_op)) return; - - if (!IsReachable(graph, elementwise_add_identity, conv_output)) return; - - if (HasFusedActivation(conv_op)) return; - - if (!pass_->IsCompat(subgraph, graph)) { - LOG(WARNING) - << "conv_elementwise_add_mkldnn_fuse_pass in op compat failed."; - return; - } - - conv_op->Op()->SetInput("ResidualData", {elementwise_add_identity->Name()}); - conv_op->Op()->SetOutput("Output", {elementwise_add_out->Name()}); - conv_op->Op()->SetAttr("fuse_residual_connection", true); - - GraphSafeRemoveNodes(graph, {conv_output, elementwise_add_op}); - - IR_NODE_LINK_TO(elementwise_add_identity, conv_op); - IR_NODE_LINK_TO(conv_op, elementwise_add_out); - - (*fusion_stats)++; -} - -ResidualConnectionMKLDNNFusePass::ProjectionFuseHandle::ProjectionFuseHandle( - const ResidualConnectionMKLDNNFusePass::CanFuseFunc& can_fuse_func, - const ResidualConnectionMKLDNNFusePass::ProjectionConvFunc& - get_node_from_conv_x_op, - const ResidualConnectionMKLDNNFusePass::ProjectionConvFunc& - get_node_from_conv_y_op, - const ResidualConnectionMKLDNNFusePass::ProjectionElementwiseAddFunc& - get_node_from_elementwise_add_op, - const ResidualConnectionMKLDNNFusePass* pass) - : fusion_stats{std::make_shared(0)}, - can_fuse_func{can_fuse_func}, - get_node_from_conv_x_op{get_node_from_conv_x_op}, - get_node_from_conv_y_op{get_node_from_conv_y_op}, - get_node_from_elementwise_add_op{get_node_from_elementwise_add_op}, - pass_{pass} {} - -void ResidualConnectionMKLDNNFusePass::ProjectionFuseHandle::operator()( - const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) { - Node* conv_x_op; - Node* conv_x_input; - Node* conv_x_filter; - Node* conv_x_output; - - Node* conv_y_op; - Node* conv_y_input; - Node* conv_y_filter; - Node* conv_y_output; - - Node* elementwise_add_op; - Node* elementwise_add_out; - - if (!pass_->IsCompat(subgraph, graph)) { - LOG(WARNING) - << "conv_elementwise_add_mkldnn_fuse_pass in op compat failed."; - return; - } - - std::tie(conv_x_op, conv_x_input, conv_x_filter, conv_x_output) = - get_node_from_conv_x_op(subgraph); - std::tie(conv_y_op, conv_y_input, conv_y_filter, conv_y_output) = - get_node_from_conv_y_op(subgraph); - std::tie(elementwise_add_op, elementwise_add_out) = - get_node_from_elementwise_add_op(subgraph); - - if (!can_fuse_func(conv_x_op, elementwise_add_op)) return; - if (!can_fuse_func(conv_y_op, elementwise_add_op)) return; - - Node* projection_node; - Node* residual_conv_op; - Node* residual_conv_output; - - if (IsReachable(graph, conv_x_input, conv_y_output)) { - projection_node = conv_x_output; - residual_conv_op = conv_y_op; - residual_conv_output = conv_y_output; - } else if (IsReachable(graph, conv_y_input, conv_x_output)) { - projection_node = conv_y_output; - residual_conv_op = conv_x_op; - residual_conv_output = conv_x_output; - } else { - return; - } - - if (HasFusedActivation(residual_conv_op)) return; - - residual_conv_op->Op()->SetInput("ResidualData", {projection_node->Name()}); - residual_conv_op->Op()->SetOutput("Output", {elementwise_add_out->Name()}); - - residual_conv_op->Op()->SetAttr("fuse_residual_connection", true); - - GraphSafeRemoveNodes(graph, {residual_conv_output, elementwise_add_op}); - - IR_NODE_LINK_TO(projection_node, residual_conv_op); - IR_NODE_LINK_TO(residual_conv_op, elementwise_add_out); - - (*fusion_stats)++; -} - -std::tuple -ResidualConnectionMKLDNNFusePass::GetNodesFromConv( - const patterns::Conv& conv_pattern, - const GraphPatternDetector::subgraph_t& subgraph) const { - GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern); - GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern); - GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern); - GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern); - - return std::make_tuple(conv_op, conv_input, conv_filter, conv_output); -} - GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsX( const std::string& name_scope, const GraphWithStats& graph_with_stats) const { - ir::Graph* graph; - int stats; - - std::tie(graph, stats) = graph_with_stats; - GraphPatternDetector gpd; auto pattern = gpd.mutable_pattern(); @@ -298,26 +151,56 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsX( pattern->NewNode(elementwise_add_pattern.elementwise_add_y_repr())); conv_output->AsIntermediate(); - auto get_node_from_elementwise_add = [&elementwise_add_pattern]( - const GraphPatternDetector::subgraph_t& subgraph) - -> std::tuple { - GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op, - elementwise_add_pattern); - GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_y, elementwise_add_y, - elementwise_add_pattern); - GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out, - elementwise_add_pattern); - - return std::make_tuple(elementwise_add_op, elementwise_add_y, - elementwise_add_out); - }; - - return ExecuteHandleOnGraph( - &gpd, graph_with_stats, - [this, &conv_pattern](const GraphPatternDetector::subgraph_t& subgraph) { - return GetNodesFromConv(conv_pattern, subgraph); - }, - get_node_from_elementwise_add, this); + int found_conv_as_x_count = 0; + + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op, + elementwise_add_pattern); + GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_identity, elementwise_add_y, + elementwise_add_pattern); + GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out, + elementwise_add_pattern); + + if (FindFuseOption(*conv_op, *elementwise_add_op) != FUSE_MKLDNN) return; + + if (!IsReachable(g, elementwise_add_identity, conv_output)) return; + + if (HasFusedActivation(conv_op)) return; + + if (!IsCompat(subgraph, g)) { + LOG(WARNING) + << "conv_elementwise_add_mkldnn_fuse_pass in op compat failed."; + return; + } + + conv_op->Op()->SetInput("ResidualData", {elementwise_add_identity->Name()}); + conv_op->Op()->SetOutput("Output", {elementwise_add_out->Name()}); + conv_op->Op()->SetAttr("fuse_residual_connection", true); + + GraphSafeRemoveNodes(g, {conv_output, elementwise_add_op}); + + IR_NODE_LINK_TO(elementwise_add_identity, conv_op); + IR_NODE_LINK_TO(conv_op, elementwise_add_out); + + found_conv_as_x_count++; + }; + + gpd(graph_with_stats.first, handler); + if (!Has("disable_logs") || !Get("disable_logs")) { + std::stringstream msg_ss; + msg_ss << "--- Fused " << found_conv_as_x_count + << " conv (as x) + elementwise_add patterns"; + paddle::string::PrettyLogDetail(msg_ss.str().c_str()); + } + + return std::make_pair(graph_with_stats.first, + found_conv_as_x_count + graph_with_stats.second); } GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY( @@ -335,26 +218,56 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY( conv_output); conv_output->AsIntermediate(); - auto get_node_from_elementwise_add = [&elementwise_add_pattern]( - const GraphPatternDetector::subgraph_t& subgraph) - -> std::tuple { - GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op, - elementwise_add_pattern); - GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_x, elementwise_add_x, - elementwise_add_pattern); - GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out, - elementwise_add_pattern); - - return std::make_tuple(elementwise_add_op, elementwise_add_x, - elementwise_add_out); - }; - - return ExecuteHandleOnGraph( - &gpd, graph_with_stats, - [this, &conv_pattern](const GraphPatternDetector::subgraph_t& subgraph) { - return GetNodesFromConv(conv_pattern, subgraph); - }, - get_node_from_elementwise_add, this); + int found_conv_as_y_count = 0; + + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op, + elementwise_add_pattern); + GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_x, elementwise_add_x, + elementwise_add_pattern); + GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out, + elementwise_add_pattern); + + if (FindFuseOption(*conv_op, *elementwise_add_op) != FUSE_MKLDNN) return; + + if (!IsReachable(g, elementwise_add_x, conv_output)) return; + + if (HasFusedActivation(conv_op)) return; + + if (!IsCompat(subgraph, g)) { + LOG(WARNING) + << "conv_elementwise_add_mkldnn_fuse_pass in op compat failed."; + return; + } + + conv_op->Op()->SetInput("ResidualData", {elementwise_add_x->Name()}); + conv_op->Op()->SetOutput("Output", {elementwise_add_out->Name()}); + conv_op->Op()->SetAttr("fuse_residual_connection", true); + + GraphSafeRemoveNodes(g, {conv_output, elementwise_add_op}); + + IR_NODE_LINK_TO(elementwise_add_x, conv_op); + IR_NODE_LINK_TO(conv_op, elementwise_add_out); + + found_conv_as_y_count++; + }; + + gpd(graph_with_stats.first, handler); + if (!Has("disable_logs") || !Get("disable_logs")) { + std::stringstream msg_ss; + msg_ss << "--- Fused " << found_conv_as_y_count + << " conv (as y) + elementwise_add patterns"; + paddle::string::PrettyLogDetail(msg_ss.str().c_str()); + } + + return std::make_pair(graph_with_stats.first, + found_conv_as_y_count + graph_with_stats.second); } GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv( @@ -374,39 +287,84 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv( conv_x_output->AsIntermediate(); conv_y_output->AsIntermediate(); - auto get_node_from_elementwise_add = [&elementwise_add_pattern]( - const GraphPatternDetector::subgraph_t& subgraph) - -> std::tuple { - GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op, - elementwise_add_pattern); - GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out, - elementwise_add_pattern); - - return std::make_tuple(elementwise_add_op, elementwise_add_out); - }; - - return ExecuteHandleOnGraph( - &gpd, graph_with_stats, - [this, - &conv_x_pattern](const GraphPatternDetector::subgraph_t& subgraph) { - return GetNodesFromConv(conv_x_pattern, subgraph); - }, - [this, - &conv_y_pattern](const GraphPatternDetector::subgraph_t& subgraph) { - return GetNodesFromConv(conv_y_pattern, subgraph); - }, - get_node_from_elementwise_add, this); + int found_projection_conv_count = 0; + + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + GET_IR_NODE_FROM_SUBGRAPH(conv_x_op, conv_op, conv_x_pattern); + GET_IR_NODE_FROM_SUBGRAPH(conv_x_input, conv_input, conv_x_pattern); + GET_IR_NODE_FROM_SUBGRAPH(conv_x_filter, conv_filter, conv_x_pattern); + GET_IR_NODE_FROM_SUBGRAPH(conv_x_output, conv_output, conv_x_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(conv_y_op, conv_op, conv_y_pattern); + GET_IR_NODE_FROM_SUBGRAPH(conv_y_input, conv_input, conv_y_pattern); + GET_IR_NODE_FROM_SUBGRAPH(conv_y_filter, conv_filter, conv_y_pattern); + GET_IR_NODE_FROM_SUBGRAPH(conv_y_output, conv_output, conv_y_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op, + elementwise_add_pattern); + GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out, + elementwise_add_pattern); + + if (!IsCompat(subgraph, g)) { + LOG(WARNING) + << "conv_elementwise_add_mkldnn_fuse_pass in op compat failed."; + return; + } + + if (FindFuseOption(*conv_x_op, *elementwise_add_op) != FUSE_MKLDNN) return; + if (FindFuseOption(*conv_y_op, *elementwise_add_op) != FUSE_MKLDNN) return; + + Node* projection_node; + Node* residual_conv_op; + Node* residual_conv_output; + if (IsReachable(g, conv_x_input, conv_y_output)) { + projection_node = conv_x_output; + residual_conv_op = conv_y_op; + residual_conv_output = conv_y_output; + } else if (IsReachable(g, conv_y_input, conv_x_output)) { + projection_node = conv_y_output; + residual_conv_op = conv_x_op; + residual_conv_output = conv_x_output; + } else { + return; + } + + if (HasFusedActivation(residual_conv_op)) return; + + residual_conv_op->Op()->SetInput("ResidualData", {projection_node->Name()}); + residual_conv_op->Op()->SetOutput("Output", {elementwise_add_out->Name()}); + + residual_conv_op->Op()->SetAttr("fuse_residual_connection", true); + + GraphSafeRemoveNodes(g, {residual_conv_output, elementwise_add_op}); + + IR_NODE_LINK_TO(projection_node, residual_conv_op); + IR_NODE_LINK_TO(residual_conv_op, elementwise_add_out); + + found_projection_conv_count++; + }; + + gpd(graph_with_stats.first, handler); + if (!Has("disable_logs") || !Get("disable_logs")) { + std::stringstream msg_ss; + msg_ss << "--- Fused " << found_projection_conv_count + << " projection conv (as y) + elementwise_add patterns"; + paddle::string::PrettyLogDetail(msg_ss.str().c_str()); + } + + return std::make_pair(graph_with_stats.first, + found_projection_conv_count + graph_with_stats.second); } -void ResidualConnectionMKLDNNFusePass::ApplyImpl(graph_ptr graph) const { +void ResidualConnectionMKLDNNFusePass::ApplyImpl(ir::Graph* graph) const { FusePassBase::Init(name_scope_, graph); - auto fused_graph_with_stats = FuseConvAsY( - name_scope_, - FuseConvAsX(name_scope_, - FuseProjectionConv(name_scope_, std::make_pair(graph, 0)))); + auto graph_with_stats = + FuseProjectionConv(name_scope_, std::make_pair(graph, 0)); + graph_with_stats = FuseConvAsX(name_scope_, graph_with_stats); + graph_with_stats = FuseConvAsY(name_scope_, graph_with_stats); - LOG(INFO) << "Fused graph " << fused_graph_with_stats.second << "\n"; - AddStatis(fused_graph_with_stats.second); + AddStatis(graph_with_stats.second); } } // namespace ir } // namespace framework diff --git a/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h index c83335da2f629c128fcf4819b2645ab1ef5eae42..c4351b382187d9062a059d013ddb237520645b6d 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h @@ -28,19 +28,9 @@ namespace paddle { namespace framework { namespace ir { -class Graph; -class GraphPatternDetector; -class Node; -namespace patterns { -struct Conv; -} // namespace patterns - -using graph_ptr = ir::Graph*; using GraphWithStats = std::pair; -void CorrectGraphEdges(Graph* graph, Node* from, Node* to); bool IsReachable(ir::Graph* graph, Node* from, Node* to); -paddle::optional HasBias(const Node& op, const std::string& bias_name); class ResidualConnectionMKLDNNFusePass : public FusePassBase { private: @@ -52,91 +42,13 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase { const std::string& name_scope, const GraphWithStats& graph_with_stats) const; - template - using GetNodeFunc = - std::function; - using IdentityConvFunc = GetNodeFunc>; - using IdentityElementwiseAddFunc = - GetNodeFunc>; - - using ProjectionConvFunc = IdentityConvFunc; - using ProjectionElementwiseAddFunc = GetNodeFunc>; - - using CanFuseFunc = std::function; - - std::tuple GetNodesFromConv( - const patterns::Conv& conv_pattern, - const GraphPatternDetector::subgraph_t& subgraph) const; - - std::tuple GetNodesFromProjectionConv( - const patterns::Conv& conv_pattern, - const GraphPatternDetector::subgraph_t& subgraph) const; - - template - GraphWithStats ExecuteHandleOnGraph(GraphPatternDetector* gpd, - const GraphWithStats& graph_with_stats, - OpFuncs&&... op_funcs) const { - ir::Graph* graph; - int stats; - - std::tie(graph, stats) = graph_with_stats; - - auto can_fuse = [this](Node* op1, Node* op2) -> bool { - return this->FindFuseOption(*op1, *op2) == FUSE_MKLDNN; - }; - auto fuse_handle = HandleType{can_fuse, std::forward(op_funcs)...}; - - (*gpd)(graph, fuse_handle); - - return std::make_pair(graph, stats + fuse_handle.get_stats()); - } - - struct IdentityFuseHandle { - IdentityFuseHandle( - const CanFuseFunc& can_fuse_func, - const IdentityConvFunc& get_node_from_conv_op, - const IdentityElementwiseAddFunc& get_node_from_elementwise_add_op, - const ResidualConnectionMKLDNNFusePass* pass); - - void operator()(const GraphPatternDetector::subgraph_t& subgraph, - Graph* graph); - int get_stats() const { return *fusion_stats; } - - private: - std::shared_ptr fusion_stats; - CanFuseFunc can_fuse_func; - IdentityConvFunc get_node_from_conv_op; - IdentityElementwiseAddFunc get_node_from_elementwise_add_op; - const ResidualConnectionMKLDNNFusePass* pass_; - }; - - struct ProjectionFuseHandle { - ProjectionFuseHandle( - const CanFuseFunc& can_fuse_func, - const ProjectionConvFunc& get_node_from_conv_x_op, - const ProjectionConvFunc& get_node_from_conv_y_op, - const ProjectionElementwiseAddFunc& get_node_from_elementwise_add_op, - const ResidualConnectionMKLDNNFusePass* pass); - - void operator()(const GraphPatternDetector::subgraph_t& subgraph, - Graph* graph); - int get_stats() const { return *fusion_stats; } - - private: - std::shared_ptr fusion_stats; - CanFuseFunc can_fuse_func; - ProjectionConvFunc get_node_from_conv_x_op; - ProjectionConvFunc get_node_from_conv_y_op; - ProjectionElementwiseAddFunc get_node_from_elementwise_add_op; - const ResidualConnectionMKLDNNFusePass* pass_; - }; - public: ResidualConnectionMKLDNNFusePass(); virtual ~ResidualConnectionMKLDNNFusePass() {} protected: - void ApplyImpl(graph_ptr graph) const; + void ApplyImpl(ir::Graph* graph) const; + static bool HasFusedActivation(Node* conv_node) { return !(conv_node->Op() ->GetAttrIfExists("fuse_activation")