From a49aa4dac9989039a17f3d5efedaf7ea595a37b3 Mon Sep 17 00:00:00 2001 From: Sylwester Fraczek Date: Mon, 24 Sep 2018 11:38:04 +0200 Subject: [PATCH] make bias unnecessary for ConvRelu fuse --- paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.cc | 7 +++---- paddle/fluid/framework/ir/graph_pattern_detector.cc | 8 +------- paddle/fluid/framework/ir/graph_pattern_detector.h | 3 +-- 3 files changed, 5 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.cc index 644911b1b8..1f75455580 100644 --- a/paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.cc @@ -42,14 +42,13 @@ std::unique_ptr ConvReLUFusePass::ApplyImpl( Graph* g) { VLOG(4) << "handle ConvReLU fuse"; GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight, - conv_relu_pattern); // Filter - GET_IR_NODE_FROM_SUBGRAPH(conv_bias, conv_bias, conv_relu_pattern); // Bias - GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_relu_pattern); // tmp + conv_relu_pattern); // Filter + GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_relu_pattern); // tmp GET_IR_NODE_FROM_SUBGRAPH(conv, conv, conv_relu_pattern); // CONV op GET_IR_NODE_FROM_SUBGRAPH(relu_out, relu_out, conv_relu_pattern); // Out GET_IR_NODE_FROM_SUBGRAPH(relu, relu, conv_relu_pattern); // ReLU op - // Create an ConvReLU Node. + // Transform Conv node into ConvReLU node. OpDesc* desc = conv->Op(); desc->SetOutput("Output", std::vector({relu_out->Name()})); desc->SetAttr("fuse_relu", true); diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index ef51138196..6d2c51b0e9 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -638,11 +638,6 @@ PDNode *patterns::ConvReLU::operator()( ->AsInput() ->assert_is_persistable_var() ->assert_is_op_input("conv2d", "Filter"); - // Bias - auto *conv_bias_var = pattern->NewNode(conv_bias_repr()) - ->AsInput() - ->assert_is_persistable_var() - ->assert_is_op_input("conv2d", "Bias"); // intermediate variable, will be removed in the IR after fuse. auto *conv_out_var = pattern->NewNode(conv_out_repr()) ->AsIntermediate() @@ -653,8 +648,7 @@ PDNode *patterns::ConvReLU::operator()( ->AsOutput() ->assert_is_op_output("relu"); - conv_op->LinksFrom({conv_input, conv_weight_var, conv_bias_var}) - .LinksTo({conv_out_var}); + conv_op->LinksFrom({conv_input, conv_weight_var}).LinksTo({conv_out_var}); relu_op->LinksFrom({conv_out_var}).LinksTo({relu_out_var}); return relu_out_var; } diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 46950ed877..69b486c29d 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -379,7 +379,7 @@ struct PatternBase { // op: conv + relu // named nodes: // conv_input, conv_weight, -// conv_bias, conv_out, conv, +// conv_out, conv, // relu_out, relu struct ConvReLU : public PatternBase { ConvReLU(PDPattern* pattern, const std::string& name_scope) @@ -392,7 +392,6 @@ struct ConvReLU : public PatternBase { PATTERN_DECL_NODE(relu); // declare variable node's name PATTERN_DECL_NODE(conv_weight); - PATTERN_DECL_NODE(conv_bias); PATTERN_DECL_NODE(conv_out); PATTERN_DECL_NODE(relu_out); }; -- GitLab