From 4224089354eff22f0fa13e881146240c61fd83ea Mon Sep 17 00:00:00 2001 From: Tomasz Patejko Date: Thu, 8 Nov 2018 15:18:44 +0100 Subject: [PATCH] MKLDNN residual connections fuse pass: Maybe removed and boost::optional used where it makes sense --- .../conv_elementwise_add_mkldnn_fuse_pass.cc | 125 ++++++++++-------- .../conv_elementwise_add_mkldnn_fuse_pass.h | 44 ++---- 2 files changed, 81 insertions(+), 88 deletions(-) diff --git a/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc index 5a6d20e84..f0e9ec2ae 100644 --- a/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc @@ -99,7 +99,7 @@ bool IsReachable(ir::Graph* graph, Node* from, Node* to) { return false; } -std::pair HasBias(const Node& op, const std::string& bias_name) { +boost::optional HasBias(const Node& op, const std::string& bias_name) { auto bias_input_names = op.Op()->Inputs(); auto bias_it = bias_input_names.find(bias_name); @@ -113,11 +113,11 @@ std::pair HasBias(const Node& op, const std::string& bias_name) { [&bias_names](Node* n) -> bool { return n->Name() == bias_names[0]; }); - return std::make_pair(has_bias, *bias_names_it); + return *bias_names_it; } } - return std::make_pair(false, nullptr); + return boost::none; } ResidualConnectionMKLDNNFusePass::FuseHandler::FuseHandler( @@ -125,7 +125,8 @@ ResidualConnectionMKLDNNFusePass::FuseHandler::FuseHandler( const ResidualConnectionMKLDNNFusePass::ElementwiseAddFunc& get_node_from_elementwise_add_op, const ResidualConnectionMKLDNNFusePass::CanFuseFunc& can_fuse_func) - : get_node_from_conv_op{get_node_from_conv_op}, + : fusion_stats{std::make_shared(0)}, + get_node_from_conv_op{get_node_from_conv_op}, get_node_from_elementwise_add_op{get_node_from_elementwise_add_op}, can_fuse_func{can_fuse_func} {} @@ -157,13 +158,10 @@ void ResidualConnectionMKLDNNFusePass::FuseHandler::operator()( op_desc.SetInput("ResidualData", {elementwise_add_identity->Name()}); op_desc.SetOutput("Output", {conv_output->Name()}); - bool has_bias; - Node* conv_bias; + auto conv_bias = HasBias(*conv_op, "Bias"); - std::tie(has_bias, conv_bias) = HasBias(*conv_op, "Bias"); - - if (has_bias) { - op_desc.SetInput("Bias", {conv_bias->Name()}); + if (conv_bias) { + op_desc.SetInput("Bias", {(*conv_bias)->Name()}); } for (const auto& attr : conv_op->Op()->GetAttrMap()) { @@ -179,40 +177,48 @@ void ResidualConnectionMKLDNNFusePass::FuseHandler::operator()( IR_NODE_LINK_TO(elementwise_add_identity, fused_conv_op); IR_NODE_LINK_TO(fused_conv_op, conv_output); - if (has_bias) { - IR_NODE_LINK_TO(conv_bias, fused_conv_op); + if (conv_bias) { + IR_NODE_LINK_TO((*conv_bias), fused_conv_op); } CorrectGraphEdges(graph, elementwise_add_out, conv_output); GraphSafeRemoveNodes(graph, {elementwise_add_out, conv_op, elementwise_add_op}); + (*fusion_stats)++; +} + +std::tuple +ResidualConnectionMKLDNNFusePass::GetNodesFromConv( + const patterns::Conv& conv_pattern, + const GraphPatternDetector::subgraph_t& subgraph) const { + GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern); + + return std::make_tuple(conv_op, conv_input, conv_filter, conv_output); } -graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsX( - const std::string& name_scope_, graph_ptr graph) const { +GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsX( + const std::string& name_scope, + const GraphWithStats& graph_with_stats) const { + ir::Graph* graph; + int stats; + + std::tie(graph, stats) = graph_with_stats; + GraphPatternDetector gpd; auto pattern = gpd.mutable_pattern(); - patterns::Conv conv_pattern{pattern, name_scope_}; + patterns::Conv conv_pattern{pattern, name_scope}; auto conv_output = conv_pattern(); - patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope_}; + patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope}; elementwise_add_pattern( conv_output, pattern->NewNode(elementwise_add_pattern.elementwise_add_y_repr())); conv_output->AsIntermediate(); - auto get_node_from_conv = - [&conv_pattern](const GraphPatternDetector::subgraph_t& subgraph) - -> std::tuple { - GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern); - GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern); - GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern); - GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern); - - return std::make_tuple(conv_op, conv_input, conv_filter, conv_output); - }; - auto get_node_from_elementwise_add = [&elementwise_add_pattern]( const GraphPatternDetector::subgraph_t& subgraph) -> std::tuple { @@ -227,43 +233,29 @@ graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsX( elementwise_add_out); }; - auto can_fuse = [this](Node* op1, Node* op2) -> bool { - return this->FindFuseOption(*op1, *op2) == FUSE_MKLDNN; - }; - - auto fuse_handler = - FuseHandler{get_node_from_conv, get_node_from_elementwise_add, can_fuse}; - - gpd(graph.get(), fuse_handler); - - return graph; + return ExecuteHandlerOnGraph( + &gpd, graph_with_stats, + [this, &conv_pattern](const GraphPatternDetector::subgraph_t& subgraph) { + return GetNodesFromConv(conv_pattern, subgraph); + }, + get_node_from_elementwise_add); } -graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsY( - const std::string& name_scope_, graph_ptr graph) const { +GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY( + const std::string& name_scope, + const GraphWithStats& graph_with_stats) const { GraphPatternDetector gpd; auto pattern = gpd.mutable_pattern(); - patterns::Conv conv_pattern{pattern, name_scope_}; + patterns::Conv conv_pattern{pattern, name_scope}; auto conv_output = conv_pattern(); - patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope_}; + patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope}; elementwise_add_pattern( pattern->NewNode(elementwise_add_pattern.elementwise_add_x_repr()), conv_output); conv_output->AsIntermediate(); - auto get_node_from_conv = - [&conv_pattern](const GraphPatternDetector::subgraph_t& subgraph) - -> std::tuple { - GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern); - GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern); - GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern); - GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern); - - return std::make_tuple(conv_op, conv_input, conv_filter, conv_output); - }; - auto get_node_from_elementwise_add = [&elementwise_add_pattern]( const GraphPatternDetector::subgraph_t& subgraph) -> std::tuple { @@ -278,6 +270,24 @@ graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsY( elementwise_add_out); }; + return ExecuteHandlerOnGraph( + &gpd, graph_with_stats, + [this, &conv_pattern](const GraphPatternDetector::subgraph_t& subgraph) { + return GetNodesFromConv(conv_pattern, subgraph); + }, + get_node_from_elementwise_add); +} + +GraphWithStats ResidualConnectionMKLDNNFusePass::ExecuteHandlerOnGraph( + GraphPatternDetector* gpd, const GraphWithStats& graph_with_stats, + const ResidualConnectionMKLDNNFusePass::ConvFunc& get_node_from_conv, + const ResidualConnectionMKLDNNFusePass::ElementwiseAddFunc& + get_node_from_elementwise_add) const { + ir::Graph* graph; + int stats; + + std::tie(graph, stats) = graph_with_stats; + auto can_fuse = [this](Node* op1, Node* op2) -> bool { return this->FindFuseOption(*op1, *op2) == FUSE_MKLDNN; }; @@ -285,15 +295,20 @@ graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsY( auto fuse_handler = FuseHandler{get_node_from_conv, get_node_from_elementwise_add, can_fuse}; - gpd(graph.get(), fuse_handler); + (*gpd)(graph, fuse_handler); - return graph; + return std::make_pair(graph, stats + fuse_handler.get_stats()); } graph_ptr ResidualConnectionMKLDNNFusePass::ApplyImpl(graph_ptr graph) const { FusePassBase::Init(name_scope_, graph.get()); - return FuseConvAsY(name_scope_, FuseConvAsX(name_scope_, std::move(graph))); + auto fused_graph_with_stats = FuseConvAsY( + name_scope_, FuseConvAsX(name_scope_, std::make_pair(graph.get(), 0))); + + std::cout << "Fused graph " << fused_graph_with_stats.second << std::endl; + AddStatis(fused_graph_with_stats.second); + return graph; } } // namespace ir } // namespace framework diff --git a/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h index de4d1075e..03a23404f 100644 --- a/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h +++ b/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h @@ -27,43 +27,12 @@ namespace paddle { namespace framework { namespace ir { -// poor replacement for C++17 std::optional and Boost.Optional -struct InPlace {}; -InPlace in_place; - -template -class Maybe { - private: - typename std::aligned_storage::type data; - bool is_initialized{false}; - - public: - template - explicit Maybe(InPlace, Args&&... args) { - new (&data) T(std::forward(args)...); - is_initialized = true; - } - - Maybe() {} - - operator bool() { return is_initialized; } - - T& value() { return *reinterpret_cast(&data); } - - ~Maybe() { reinterpret_cast(&data)->~T(); } -}; - -template -Maybe MakeMaybe(Args&&... args) { - return Maybe(in_place, std::forward(args)...); -} - using graph_ptr = std::unique_ptr; -using GraphWithStats = std::pair>; +using GraphWithStats = std::pair; void CorrectGraphEdges(Graph* graph, Node* from, Node* to); bool IsReachable(ir::Graph* graph, Node* from, Node* to); -std::pair HasBias(const Node& op, const std::string& bias_name); +boost::optional HasBias(const Node& op, const std::string& bias_name); class ResidualConnectionMKLDNNFusePass : public FusePassBase { private: @@ -79,6 +48,15 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase { using ElementwiseAddFunc = GetNodeFunc>; using CanFuseFunc = std::function; + std::tuple GetNodesFromConv( + const patterns::Conv& conv_pattern, + const GraphPatternDetector::subgraph_t& subgraph) const; + + GraphWithStats ExecuteHandlerOnGraph( + GraphPatternDetector* gpd, const GraphWithStats& graph_with_stats, + const ConvFunc& get_node_from_conv, + const ElementwiseAddFunc& get_node_from_elementwise_add) const; + struct FuseHandler { FuseHandler(const ConvFunc& get_node_from_conv_op, const ElementwiseAddFunc& get_node_from_elementwise_add_op, -- GitLab