From 53da846d1ec156781d31184477bae97dea6a4774 Mon Sep 17 00:00:00 2001 From: Tomasz Patejko Date: Thu, 15 Nov 2018 16:59:36 +0100 Subject: [PATCH] MKLDNN residual connections fuse pass: initial implementation of fusion for projection pass test=develop --- .../conv_elementwise_add_mkldnn_fuse_pass.cc | 174 +++++++++++++++--- .../conv_elementwise_add_mkldnn_fuse_pass.h | 71 +++++-- 2 files changed, 206 insertions(+), 39 deletions(-) diff --git a/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc index f0e9ec2aeb9..5376fc163e2 100644 --- a/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc @@ -120,17 +120,18 @@ boost::optional HasBias(const Node& op, const std::string& bias_name) { return boost::none; } -ResidualConnectionMKLDNNFusePass::FuseHandler::FuseHandler( - const ResidualConnectionMKLDNNFusePass::ConvFunc& get_node_from_conv_op, - const ResidualConnectionMKLDNNFusePass::ElementwiseAddFunc& - get_node_from_elementwise_add_op, - const ResidualConnectionMKLDNNFusePass::CanFuseFunc& can_fuse_func) +ResidualConnectionMKLDNNFusePass::IdentityFuseHandle::IdentityFuseHandle( + const ResidualConnectionMKLDNNFusePass::CanFuseFunc& can_fuse_func, + const ResidualConnectionMKLDNNFusePass::IdentityConvFunc& + get_node_from_conv_op, + const ResidualConnectionMKLDNNFusePass::IdentityElementwiseAddFunc& + get_node_from_elementwise_add_op) : fusion_stats{std::make_shared(0)}, + can_fuse_func{can_fuse_func}, get_node_from_conv_op{get_node_from_conv_op}, - get_node_from_elementwise_add_op{get_node_from_elementwise_add_op}, - can_fuse_func{can_fuse_func} {} + get_node_from_elementwise_add_op{get_node_from_elementwise_add_op} {} -void ResidualConnectionMKLDNNFusePass::FuseHandler::operator()( +void ResidualConnectionMKLDNNFusePass::IdentityFuseHandle::operator()( const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) { Node* conv_op; Node* conv_input; @@ -187,6 +188,104 @@ void ResidualConnectionMKLDNNFusePass::FuseHandler::operator()( (*fusion_stats)++; } +ResidualConnectionMKLDNNFusePass::ProjectionFuseHandle::ProjectionFuseHandle( + const ResidualConnectionMKLDNNFusePass::CanFuseFunc& can_fuse_func, + const ResidualConnectionMKLDNNFusePass::ProjectionConvFunc& + get_node_from_conv_x_op, + const ResidualConnectionMKLDNNFusePass::ProjectionConvFunc& + get_node_from_conv_y_op, + const ResidualConnectionMKLDNNFusePass::ProjectionElementwiseAddFunc& + get_node_from_elementwise_add_op) + : fusion_stats{std::make_shared(0)}, + can_fuse_func{can_fuse_func}, + get_node_from_conv_x_op{get_node_from_conv_x_op}, + get_node_from_conv_y_op{get_node_from_conv_y_op}, + get_node_from_elementwise_add_op{get_node_from_elementwise_add_op} {} + +void ResidualConnectionMKLDNNFusePass::ProjectionFuseHandle::operator()( + const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) { + Node* conv_x_op; + Node* conv_x_input; + Node* conv_x_filter; + Node* conv_x_output; + + Node* conv_y_op; + Node* conv_y_input; + Node* conv_y_filter; + Node* conv_y_output; + + Node* elementwise_add_op; + Node* elementwise_add_out; + + std::tie(conv_x_op, conv_x_input, conv_x_filter, conv_x_output) = + get_node_from_conv_x_op(subgraph); + std::tie(conv_y_op, conv_y_input, conv_y_filter, conv_y_output) = + get_node_from_conv_y_op(subgraph); + std::tie(elementwise_add_op, elementwise_add_out) = + get_node_from_elementwise_add_op(subgraph); + + if (!can_fuse_func(conv_x_op, elementwise_add_op)) return; + if (!can_fuse_func(conv_y_op, elementwise_add_op)) return; + + Node* projection_node; + Node* residual_conv_op; + Node* residual_conv_input; + Node* residual_conv_filter; + Node* residual_conv_output; + + if (IsReachable(graph, conv_x_input, conv_y_output)) { + projection_node = conv_x_output; + residual_conv_op = conv_y_op; + residual_conv_input = conv_y_input; + residual_conv_filter = conv_y_filter; + residual_conv_output = conv_y_output; + } else if (IsReachable(graph, conv_y_input, conv_x_output)) { + projection_node = conv_y_output; + residual_conv_op = conv_x_op; + residual_conv_input = conv_x_input; + residual_conv_filter = conv_x_filter; + residual_conv_output = conv_x_output; + } else { + return; + } + + OpDesc op_desc; + op_desc.SetType("conv2d"); + + op_desc.SetInput("Input", {residual_conv_input->Name()}); + op_desc.SetInput("Filter", {residual_conv_filter->Name()}); + op_desc.SetInput("ResidualData", {projection_node->Name()}); + op_desc.SetOutput("Output", {residual_conv_output->Name()}); + + auto residual_conv_bias = HasBias(*residual_conv_op, "Bias"); + + if (residual_conv_bias) { + op_desc.SetInput("Bias", {(*residual_conv_bias)->Name()}); + } + + for (const auto& attr : residual_conv_op->Op()->GetAttrMap()) { + op_desc.SetAttr(attr.first, attr.second); + } + + op_desc.SetAttr("fuse_residual_connection", true); + + auto fused_conv_op = graph->CreateOpNode(&op_desc); + + IR_NODE_LINK_TO(residual_conv_input, fused_conv_op); + IR_NODE_LINK_TO(residual_conv_filter, fused_conv_op); + IR_NODE_LINK_TO(projection_node, fused_conv_op); + IR_NODE_LINK_TO(fused_conv_op, residual_conv_output); + + if (residual_conv_bias) { + IR_NODE_LINK_TO((*residual_conv_bias), fused_conv_op); + } + + CorrectGraphEdges(graph, elementwise_add_out, residual_conv_output); + GraphSafeRemoveNodes( + graph, {elementwise_add_out, residual_conv_op, elementwise_add_op}); + (*fusion_stats)++; +} + std::tuple ResidualConnectionMKLDNNFusePass::GetNodesFromConv( const patterns::Conv& conv_pattern, @@ -233,7 +332,7 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsX( elementwise_add_out); }; - return ExecuteHandlerOnGraph( + return ExecuteHandleOnGraph( &gpd, graph_with_stats, [this, &conv_pattern](const GraphPatternDetector::subgraph_t& subgraph) { return GetNodesFromConv(conv_pattern, subgraph); @@ -270,7 +369,7 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY( elementwise_add_out); }; - return ExecuteHandlerOnGraph( + return ExecuteHandleOnGraph( &gpd, graph_with_stats, [this, &conv_pattern](const GraphPatternDetector::subgraph_t& subgraph) { return GetNodesFromConv(conv_pattern, subgraph); @@ -278,33 +377,54 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY( get_node_from_elementwise_add); } -GraphWithStats ResidualConnectionMKLDNNFusePass::ExecuteHandlerOnGraph( - GraphPatternDetector* gpd, const GraphWithStats& graph_with_stats, - const ResidualConnectionMKLDNNFusePass::ConvFunc& get_node_from_conv, - const ResidualConnectionMKLDNNFusePass::ElementwiseAddFunc& - get_node_from_elementwise_add) const { - ir::Graph* graph; - int stats; +GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv( + const std::string& name_scope, + const GraphWithStats& graph_with_stats) const { + GraphPatternDetector gpd; + auto pattern = gpd.mutable_pattern(); - std::tie(graph, stats) = graph_with_stats; + patterns::Conv conv_x_pattern{pattern, name_scope}; + auto conv_x_output = conv_x_pattern(); - auto can_fuse = [this](Node* op1, Node* op2) -> bool { - return this->FindFuseOption(*op1, *op2) == FUSE_MKLDNN; - }; + patterns::Conv conv_y_pattern{pattern, name_scope}; + auto conv_y_output = conv_y_pattern(); - auto fuse_handler = - FuseHandler{get_node_from_conv, get_node_from_elementwise_add, can_fuse}; + patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope}; + elementwise_add_pattern(conv_x_output, conv_y_output); + conv_x_output->AsIntermediate(); + conv_y_output->AsIntermediate(); - (*gpd)(graph, fuse_handler); + auto get_node_from_elementwise_add = [&elementwise_add_pattern]( + const GraphPatternDetector::subgraph_t& subgraph) + -> std::tuple { + GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op, + elementwise_add_pattern); + GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out, + elementwise_add_pattern); - return std::make_pair(graph, stats + fuse_handler.get_stats()); + return std::make_tuple(elementwise_add_op, elementwise_add_out); + }; + + return ExecuteHandleOnGraph( + &gpd, graph_with_stats, + [this, + &conv_x_pattern](const GraphPatternDetector::subgraph_t& subgraph) { + return GetNodesFromConv(conv_x_pattern, subgraph); + }, + [this, + &conv_y_pattern](const GraphPatternDetector::subgraph_t& subgraph) { + return GetNodesFromConv(conv_y_pattern, subgraph); + }, + get_node_from_elementwise_add); } graph_ptr ResidualConnectionMKLDNNFusePass::ApplyImpl(graph_ptr graph) const { FusePassBase::Init(name_scope_, graph.get()); - auto fused_graph_with_stats = FuseConvAsY( - name_scope_, FuseConvAsX(name_scope_, std::make_pair(graph.get(), 0))); + name_scope_, + FuseConvAsX( + name_scope_, + FuseProjectionConv(name_scope_, std::make_pair(graph.get(), 0)))); std::cout << "Fused graph " << fused_graph_with_stats.second << std::endl; AddStatis(fused_graph_with_stats.second); diff --git a/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h index 03a23404f9a..6629dae425a 100644 --- a/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h +++ b/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h @@ -40,27 +40,73 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase { const GraphWithStats& graph_with_stats) const; GraphWithStats FuseConvAsY(const std::string& name_scope, const GraphWithStats& graph_with_stats) const; + GraphWithStats FuseProjectionConv( + const std::string& name_scope, + const GraphWithStats& graph_with_stats) const; template using GetNodeFunc = std::function; - using ConvFunc = GetNodeFunc>; - using ElementwiseAddFunc = GetNodeFunc>; + using IdentityConvFunc = GetNodeFunc>; + using IdentityElementwiseAddFunc = + GetNodeFunc>; + + using ProjectionConvFunc = IdentityConvFunc; + using ProjectionElementwiseAddFunc = GetNodeFunc>; + using CanFuseFunc = std::function; std::tuple GetNodesFromConv( const patterns::Conv& conv_pattern, const GraphPatternDetector::subgraph_t& subgraph) const; - GraphWithStats ExecuteHandlerOnGraph( - GraphPatternDetector* gpd, const GraphWithStats& graph_with_stats, - const ConvFunc& get_node_from_conv, - const ElementwiseAddFunc& get_node_from_elementwise_add) const; + std::tuple GetNodesFromProjectionConv( + const patterns::Conv& conv_pattern, + const GraphPatternDetector::subgraph_t& subgraph) const; + + template + GraphWithStats ExecuteHandleOnGraph(GraphPatternDetector* gpd, + const GraphWithStats& graph_with_stats, + OpFuncs&&... op_funcs) const { + ir::Graph* graph; + int stats; + + std::tie(graph, stats) = graph_with_stats; + + auto can_fuse = [this](Node* op1, Node* op2) -> bool { + return this->FindFuseOption(*op1, *op2) == FUSE_MKLDNN; + }; + + auto fuse_handle = HandleType{can_fuse, std::forward(op_funcs)...}; + + (*gpd)(graph, fuse_handle); + + return std::make_pair(graph, stats + fuse_handle.get_stats()); + } + + struct IdentityFuseHandle { + IdentityFuseHandle( + const CanFuseFunc& can_fuse_func, + const IdentityConvFunc& get_node_from_conv_op, + const IdentityElementwiseAddFunc& get_node_from_elementwise_add_op); + + void operator()(const GraphPatternDetector::subgraph_t& subgraph, + Graph* graph); + int get_stats() const { return *fusion_stats; } + + private: + std::shared_ptr fusion_stats; + CanFuseFunc can_fuse_func; + IdentityConvFunc get_node_from_conv_op; + IdentityElementwiseAddFunc get_node_from_elementwise_add_op; + }; - struct FuseHandler { - FuseHandler(const ConvFunc& get_node_from_conv_op, - const ElementwiseAddFunc& get_node_from_elementwise_add_op, - const CanFuseFunc& can_fuse_func); + struct ProjectionFuseHandle { + ProjectionFuseHandle( + const CanFuseFunc& can_fuse_func, + const ProjectionConvFunc& get_node_from_conv_x_op, + const ProjectionConvFunc& get_node_from_conv_y_op, + const ProjectionElementwiseAddFunc& get_node_from_elementwise_add_op); void operator()(const GraphPatternDetector::subgraph_t& subgraph, Graph* graph); @@ -68,9 +114,10 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase { private: std::shared_ptr fusion_stats; - ConvFunc get_node_from_conv_op; - ElementwiseAddFunc get_node_from_elementwise_add_op; CanFuseFunc can_fuse_func; + ProjectionConvFunc get_node_from_conv_x_op; + ProjectionConvFunc get_node_from_conv_y_op; + ProjectionElementwiseAddFunc get_node_from_elementwise_add_op; }; public: -- GitLab