From 4e72ab411eece7345f4ab21a142d93e2004f716e Mon Sep 17 00:00:00 2001 From: Tomasz Patejko Date: Fri, 19 Oct 2018 09:50:10 +0200 Subject: [PATCH] MKLDNN conv + elementwise_add fusion: fix for crash when bias is not present --- .../conv_elementwise_add_mkldnn_fuse_pass.cc | 41 +++++++++++++++++-- .../framework/ir/graph_pattern_detector.cc | 6 +-- 2 files changed, 38 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc index 7aad9de1be..10b1d636e4 100644 --- a/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc @@ -14,6 +14,7 @@ #include "paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h" #include +#include #include "paddle/fluid/framework/ir/graph_traits.h" @@ -67,11 +68,32 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const { conv_output->AsIntermediate(); + auto conv_op_has_bias = [](const Node& conv_op, + const Scope& scope) -> std::pair { + auto bias_input_names = conv_op.Op()->Inputs(); + auto bias_it = bias_input_names.find("Bias"); + + if (bias_it != std::end(bias_input_names)) { + bool has_bias = !bias_it->second.empty(); + + if (has_bias) { + auto conv_bias_names = bias_it->second; + auto conv_bias_names_it = + std::find_if(std::begin(conv_op.inputs), std::end(conv_op.inputs), + [&conv_bias_names](Node* n) -> bool { + return n->Name() == conv_bias_names[0]; + }); + return std::make_pair(has_bias, *conv_bias_names_it); + } + } + + return std::make_pair(false, nullptr); + }; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern); GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern); - GET_IR_NODE_FROM_SUBGRAPH(conv_bias, conv_bias, conv_pattern); GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern); GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern); GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op, @@ -81,17 +103,25 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const { GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out, elementwise_add_pattern); - if (FindFuseOption(conv_op, elementwise_add_op) != FUSE_MKLDNN) return; + if (FindFuseOption(*conv_op, *elementwise_add_op) != FUSE_MKLDNN) return; OpDesc op_desc; op_desc.SetType("conv2d"); op_desc.SetInput("Input", {conv_input->Name()}); - op_desc.SetInput("Bias", {conv_bias->Name()}); op_desc.SetInput("Filter", {conv_filter->Name()}); op_desc.SetInput("ResidualData", {elementwise_add_x->Name()}); op_desc.SetOutput("Output", {conv_output->Name()}); + bool has_bias; + Node* conv_bias; + + std::tie(has_bias, conv_bias) = conv_op_has_bias(*conv_op, *param_scope()); + + if (has_bias) { + op_desc.SetInput("Bias", {conv_bias->Name()}); + } + for (const auto& attr : conv_op->Op()->GetAttrMap()) { op_desc.SetAttr(attr.first, attr.second); } @@ -101,11 +131,14 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const { auto fused_conv_op = g->CreateOpNode(&op_desc); IR_NODE_LINK_TO(conv_input, fused_conv_op); - IR_NODE_LINK_TO(conv_bias, fused_conv_op); IR_NODE_LINK_TO(conv_filter, fused_conv_op); IR_NODE_LINK_TO(elementwise_add_x, fused_conv_op); IR_NODE_LINK_TO(fused_conv_op, conv_output); + if (has_bias) { + IR_NODE_LINK_TO(conv_bias, fused_conv_op); + } + CorrectGraphEdges(g, elementwise_add_out, conv_output); GraphSafeRemoveNodes(g, {elementwise_add_out, conv_op, elementwise_add_op}); }; diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 786765bff7..da83bcdf37 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -1006,10 +1006,6 @@ PDNode *patterns::Conv::operator()() { ->AsInput() ->assert_is_op_input("conv2d", "Input"); - auto bias_var = pattern->NewNode(conv_bias_repr()) - ->AsInput() - ->assert_is_op_input("conv2d", "Bias"); - auto filter_var = pattern->NewNode(conv_filter_repr()) ->AsInput() ->assert_is_op_input("conv2d", "Filter"); @@ -1018,7 +1014,7 @@ PDNode *patterns::Conv::operator()() { ->AsOutput() ->assert_is_op_output("conv2d", "Output"); - conv_op->LinksFrom({input_var, bias_var, filter_var}); + conv_op->LinksFrom({input_var, /*bias_var,*/ filter_var}); conv_op->LinksTo({output_var}); return output_var; -- GitLab