diff --git a/paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.cc index 09c5ec59d66445bdbd5349447b125be89cb2efdf..d7df6389cfd595324e284e0da10f65213ccee80f 100644 --- a/paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.cc @@ -26,8 +26,6 @@ std::unique_ptr<ir::Graph> ConvReLUFusePass::ApplyImpl( PADDLE_ENFORCE(graph.get()); FusePassBase::Init("conv_relu_mkldnn_fuse", graph.get()); - std::unordered_set<Node*> nodes2delete; - GraphPatternDetector gpd; auto* conv_input = gpd.mutable_pattern() ->NewNode("conv_relu_mkldnn_fuse/conv_input") @@ -42,36 +40,20 @@ std::unique_ptr<ir::Graph> ConvReLUFusePass::ApplyImpl( Graph* g) { VLOG(4) << "handle ConvReLU fuse"; GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight, - conv_relu_pattern); // Filter - GET_IR_NODE_FROM_SUBGRAPH(conv_bias, conv_bias, conv_relu_pattern); // Bias - GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_relu_pattern); // tmp + conv_relu_pattern); // Filter + GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_relu_pattern); // tmp GET_IR_NODE_FROM_SUBGRAPH(conv, conv, conv_relu_pattern); // CONV op GET_IR_NODE_FROM_SUBGRAPH(relu_out, relu_out, conv_relu_pattern); // Out GET_IR_NODE_FROM_SUBGRAPH(relu, relu, conv_relu_pattern); // ReLU op - // Create an ConvReLU Node. - OpDesc desc; - std::string conv_relu_i_in = subgraph.at(conv_input)->Name(); - std::string conv_relu_w_in = conv_weight->Name(); - std::string conv_relu_b_in = conv_bias->Name(); - std::string conv_relu_out = relu_out->Name(); - desc.SetInput("Input", std::vector<std::string>({conv_relu_i_in})); - desc.SetInput("Filter", std::vector<std::string>({conv_relu_w_in})); - desc.SetInput("Bias", std::vector<std::string>({conv_relu_b_in})); - desc.SetOutput("Output", std::vector<std::string>({conv_relu_out})); - desc.SetType("conv2d"); - for (auto& attr : conv->Op()->GetAttrMap()) { - desc.SetAttr(attr.first, attr.second); - } - desc.SetAttr("fuse_relu", true); - auto conv_relu_node = g->CreateOpNode(&desc); // OpDesc will be copied. - GraphSafeRemoveNodes(graph.get(), {conv, relu, conv_out}); + // Transform Conv node into ConvReLU node. + OpDesc* desc = conv->Op(); + desc->SetOutput("Output", std::vector<std::string>({relu_out->Name()})); + desc->SetAttr("fuse_relu", true); + GraphSafeRemoveNodes(graph.get(), {relu, conv_out}); PADDLE_ENFORCE(subgraph.count(conv_input)); - IR_NODE_LINK_TO(subgraph.at(conv_input), conv_relu_node); - IR_NODE_LINK_TO(conv_weight, conv_relu_node); - IR_NODE_LINK_TO(conv_bias, conv_relu_node); - IR_NODE_LINK_TO(conv_relu_node, relu_out); + IR_NODE_LINK_TO(conv, relu_out); found_conv_relu_count++; }; diff --git a/paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass_tester.cc index 82b5fa1886098ca3b19c147c307d3f2fc3ba03d6..9dd780ec89ab991d6d99cb66fa2a9b683be2b9ca 100644 --- a/paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass_tester.cc @@ -85,16 +85,13 @@ TEST(ConvReLUFusePass, basic) { for (auto* node : graph->Nodes()) { if (node->IsOp() && node->Op()->Type() == "conv2d") { - if (node->Op()->HasAttr("use_mkldnn")) { - bool use_mkldnn = boost::get<bool>(node->Op()->GetAttr("use_mkldnn")); - if (use_mkldnn) { - if (node->Op()->HasAttr("fuse_relu")) { - bool fuse_relu = boost::get<bool>(node->Op()->GetAttr("fuse_relu")); - if (fuse_relu) { - ++conv_relu_count; - } - } - } + auto* op = node->Op(); + ASSERT_TRUE(op->HasAttr("use_mkldnn")); + EXPECT_TRUE(boost::get<bool>(op->GetAttr("use_mkldnn"))); + ASSERT_TRUE(op->HasAttr("fuse_relu")); + bool fuse_relu = boost::get<bool>(op->GetAttr("fuse_relu")); + if (fuse_relu) { + ++conv_relu_count; } } } diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index ef5113819696238d4e06b826bf43064b0f368dea..6d2c51b0e9bed8461f6491b84a36a3bf6663a138 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -638,11 +638,6 @@ PDNode *patterns::ConvReLU::operator()( ->AsInput() ->assert_is_persistable_var() ->assert_is_op_input("conv2d", "Filter"); - // Bias - auto *conv_bias_var = pattern->NewNode(conv_bias_repr()) - ->AsInput() - ->assert_is_persistable_var() - ->assert_is_op_input("conv2d", "Bias"); // intermediate variable, will be removed in the IR after fuse. auto *conv_out_var = pattern->NewNode(conv_out_repr()) ->AsIntermediate() @@ -653,8 +648,7 @@ PDNode *patterns::ConvReLU::operator()( ->AsOutput() ->assert_is_op_output("relu"); - conv_op->LinksFrom({conv_input, conv_weight_var, conv_bias_var}) - .LinksTo({conv_out_var}); + conv_op->LinksFrom({conv_input, conv_weight_var}).LinksTo({conv_out_var}); relu_op->LinksFrom({conv_out_var}).LinksTo({relu_out_var}); return relu_out_var; } diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 46950ed877cda7cd83ead4e9aa9a3aaae5d5ecfa..69b486c29d8bd1102a8372f5041051c25ce19359 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -379,7 +379,7 @@ struct PatternBase { // op: conv + relu // named nodes: // conv_input, conv_weight, -// conv_bias, conv_out, conv, +// conv_out, conv, // relu_out, relu struct ConvReLU : public PatternBase { ConvReLU(PDPattern* pattern, const std::string& name_scope) @@ -392,7 +392,6 @@ struct ConvReLU : public PatternBase { PATTERN_DECL_NODE(relu); // declare variable node's name PATTERN_DECL_NODE(conv_weight); - PATTERN_DECL_NODE(conv_bias); PATTERN_DECL_NODE(conv_out); PATTERN_DECL_NODE(relu_out); };