diff --git a/paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.cc index 09c5ec59d66445bdbd5349447b125be89cb2efdf..644911b1b8ada4c9d24cc9bcefce2f9f86bdf920 100644 --- a/paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.cc @@ -50,28 +50,13 @@ std::unique_ptr ConvReLUFusePass::ApplyImpl( GET_IR_NODE_FROM_SUBGRAPH(relu, relu, conv_relu_pattern); // ReLU op // Create an ConvReLU Node. - OpDesc desc; - std::string conv_relu_i_in = subgraph.at(conv_input)->Name(); - std::string conv_relu_w_in = conv_weight->Name(); - std::string conv_relu_b_in = conv_bias->Name(); - std::string conv_relu_out = relu_out->Name(); - desc.SetInput("Input", std::vector({conv_relu_i_in})); - desc.SetInput("Filter", std::vector({conv_relu_w_in})); - desc.SetInput("Bias", std::vector({conv_relu_b_in})); - desc.SetOutput("Output", std::vector({conv_relu_out})); - desc.SetType("conv2d"); - for (auto& attr : conv->Op()->GetAttrMap()) { - desc.SetAttr(attr.first, attr.second); - } - desc.SetAttr("fuse_relu", true); - auto conv_relu_node = g->CreateOpNode(&desc); // OpDesc will be copied. - GraphSafeRemoveNodes(graph.get(), {conv, relu, conv_out}); + OpDesc* desc = conv->Op(); + desc->SetOutput("Output", std::vector({relu_out->Name()})); + desc->SetAttr("fuse_relu", true); + GraphSafeRemoveNodes(graph.get(), {relu, conv_out}); PADDLE_ENFORCE(subgraph.count(conv_input)); - IR_NODE_LINK_TO(subgraph.at(conv_input), conv_relu_node); - IR_NODE_LINK_TO(conv_weight, conv_relu_node); - IR_NODE_LINK_TO(conv_bias, conv_relu_node); - IR_NODE_LINK_TO(conv_relu_node, relu_out); + IR_NODE_LINK_TO(conv, relu_out); found_conv_relu_count++; };