diff --git a/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc
index 973cd73e485f0f72752cc72338ae11088fe2fb2f..111e08d4fc0582a3f1502ca7ef7ec56ee2a5ddbf 100644
--- a/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc
@@ -45,17 +45,13 @@ struct Conv {
                        ->assert_is_op("conv2d");
 
     auto input_var = pattern->new_node(input_name())
-                             ->AsInput()
-                             ->assert_is_op_input(conv_name());
+                             ->assert_is_op_input(conv_name(), input_name());
 
     auto filter_var = pattern->new_node(filter_name())
-                              ->AsInput()
-                              ->assert_is_persistable_var()
-                              ->assert_is_op_input(conv_name());
+                              ->assert_is_op_input(conv_name(), filter_name());
 
     auto output_var = pattern->new_node(output_name())
-                              ->AsOutput()
-                              ->assert_is_op_output(conv_name());
+                              ->assert_is_op_output(conv_name(), output_name());
 
     conv_op->LinksFrom({input_var, filter_var});
     conv_op->LinksTo({output_var});
@@ -77,19 +73,13 @@ struct ElementwiseAdd {
                                   ->assert_is_op("elementwise_add");
 
     auto y_var = pattern->new_node(y_name())
-                         ->AsInput()
-                         ->assert_is_op_input(elementwise_add_name());
+                         ->assert_is_op_input(elementwise_add_name(), y_name());
 
-    conv_output->assert_is_op_input(pattern->node_name(elementwise_add_name()),
-                                    pattern->node_name(x_name()));
-//    auto y_var = pattern->NewNode(y_name())
-//                         ->AsInput()
-//                         ->assert_is_op_input(elementwise_add_name());
+    conv_output->assert_is_op_input(elementwise_add_name(), x_name());
 
     auto out_var = pattern->new_node(out_name())
                            ->AsOutput()
-                           ->assert_is_op_output(
-                               pattern->node_name(elementwise_add_name()));
+                           ->assert_is_op_output(elementwise_add_name(),
+                                                 out_name());
 
     elementwise_add_op->LinksFrom({y_var, conv_output});
     elementwise_add_op->LinksTo({out_var});
@@ -118,16 +108,16 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
   GraphPatternDetector gpd;
   auto pattern = gpd.mutable_pattern();
-
   auto pattern_ptr = std::make_shared<patterns::Pattern>(pattern, name_scope_);
 
   patterns::Conv conv_pattern;
   auto conv_output = conv_pattern(pattern_ptr)();
-  conv_output->AsIntermediate();
 
   patterns::ElementwiseAdd elementwise_add_pattern;
   elementwise_add_pattern(pattern_ptr)(conv_output);
+  conv_output->AsIntermediate();
+
   auto link_nodes_to = [](Node* a, Node* b) {
     a->outputs.push_back(b);
     b->inputs.push_back(a);
@@ -139,7 +129,7 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
 
     op_desc.SetInput("Input", {conv_input->Name()});
     op_desc.SetInput("Filter", {conv_filter->Name()});
-    op_desc.SetOutput("Ouput", {y->Name()});
+    op_desc.SetOutput("Output", {y->Name()});
 
     op_desc.SetAttr("fuse_sum", true);
 
@@ -155,16 +145,17 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
   };
 
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
-    auto elementwise_add_x = node_from_subgraph(
-        subgraph, pattern_ptr, elementwise_add_pattern.x_name());
-    auto elementwise_add_y = node_from_subgraph(
-        subgraph, pattern_ptr, elementwise_add_pattern.y_name());
-    auto elementwise_add_out = node_from_subgraph(
-        subgraph, pattern_ptr, elementwise_add_pattern.out_name());
-
-    auto conv_filter = node_from_subgraph(subgraph, pattern_ptr,
-                                          conv_pattern.filter_name());
+    auto conv_op = node_from_subgraph(subgraph, pattern_ptr,
+                                      conv_pattern.conv_name());
     auto conv_input = node_from_subgraph(subgraph, pattern_ptr,
                                          conv_pattern.input_name());
+    auto conv_filter = node_from_subgraph(subgraph, pattern_ptr,
+                                          conv_pattern.filter_name());
     auto conv_output = node_from_subgraph(subgraph, pattern_ptr,
                                           conv_pattern.output_name());
+    auto elementwise_add_op = node_from_subgraph(
+        subgraph, pattern_ptr, elementwise_add_pattern.elementwise_add_name());
+    auto elementwise_add_y = node_from_subgraph(
+        subgraph, pattern_ptr, elementwise_add_pattern.y_name());
+    auto elementwise_add_out = node_from_subgraph(
+        subgraph, pattern_ptr, elementwise_add_pattern.out_name());
+
     fuse_conv(g, conv_input, conv_filter, elementwise_add_y);
-    remove_unused_nodes(g,
-                        {elementwise_add_x, conv_output, elementwise_add_out});
+    remove_unused_nodes(
+        g, {conv_output, elementwise_add_out, conv_op, elementwise_add_op});
   };
 
   gpd(graph.get(), handler);
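Note: the handler above relies on helpers defined earlier in ApplyImpl and untouched by these hunks: node_from_subgraph, fuse_conv, and remove_unused_nodes. As a rough sketch only (the patterns::Pattern wrapper and its retrieve_node method are assumptions inferred from how pattern_ptr is used here, not the literal code), a lookup helper over GraphPatternDetector::subgraph_t, which maps each matched PDNode* to the concrete graph Node*, could look like this:

    // Sketch, not part of the patch: fetch the graph Node matched for a
    // named PDNode. retrieve_node() is assumed to wrap
    // PDPattern::RetrieveNode with the pass's name scope.
    Node* node_from_subgraph(const GraphPatternDetector::subgraph_t& subgraph,
                             std::shared_ptr<patterns::Pattern> pattern,
                             const std::string& name) {
      PDNode* pd_node = pattern->retrieve_node(name);
      PADDLE_ENFORCE(subgraph.count(pd_node),
                     "Pattern node %s was not matched.", name);
      return subgraph.at(pd_node);
    }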
diff --git a/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass_tester.cc
index 62dbb1eccd36aea55622d2b00b86d4e11de58e9e..ffecf35de2af50c10e937e27c7cf802be405f067 100644
--- a/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass_tester.cc
+++ b/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass_tester.cc
@@ -16,7 +16,7 @@ void SetOp(ProgramDesc* prog, const std::string& type,
     op->SetAttr("use_mkldnn", true);
     op->SetInput("Input", {inputs[0]});
     op->SetInput("Filter", {inputs[1]});
-    op->SetInput("Output", {outputs});
+    op->SetOutput("Output", outputs);
   } else if (type == "elementwise_add") {
     op->SetInput("X", {inputs[0]});
     op->SetInput("Y", {inputs[1]});
@@ -24,54 +24,119 @@ void SetOp(ProgramDesc* prog, const std::string& type,
   }
 }
 
-ProgramDesc BuildProgramDesc() {
-  ProgramDesc prog;
-  for (auto& v :
-       std::vector<std::string>({"a", "b", "c", "d", "weights", "f", "g"})) {
-    auto* var = prog.MutableBlock(0)->Var(v);
-    var->SetType(proto::VarType::LOD_TENSOR);
-    if (v == "weights" || v == "bias") {
-      var->SetPersistable(true);
+TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddWithOps) {
+  auto build_program_desc = [&]() -> ProgramDesc {
+    ProgramDesc prog;
+    for (auto& v :
+         std::vector<std::string>({"a", "b", "weights", "c", "d", "e", "f", "g"})) {
+      auto* var = prog.MutableBlock(0)->Var(v);
+      var->SetType(proto::VarType::LOD_TENSOR);
+      if (v == "weights" || v == "bias") {
+        var->SetPersistable(true);
+      }
     }
-  }
 
-  SetOp(&prog, "OP0", {"a"}, {"b"});
-  SetOp(&prog, "OP1", {"c"}, {"d"});
-  SetOp(&prog, "conv2d", {"d", "weights"}, {"f"});
-  SetOp(&prog, "elemenwise_add", {"d", "f"}, {"g"});
+    SetOp(&prog, "OP0", {"a"}, {"b"});
+    SetOp(&prog, "OP1", {"c"}, {"d"});
+    SetOp(&prog, "conv2d", {"b", "weights"}, {"e"});
+    SetOp(&prog, "elementwise_add", {"e", "d"}, {"f"});
+    SetOp(&prog, "OP3", {"f"}, {"g"});
+
+    return prog;
+  };
 
-  return prog;
+  auto prog = build_program_desc();
+  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
+  auto pass = PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
+  int original_nodes_num = graph->Nodes().size();
+  graph = pass->Apply(std::move(graph));
+  int current_nodes_num = graph->Nodes().size();
+
+  EXPECT_EQ(original_nodes_num - 4 + 1, current_nodes_num);
+  // Assert fused conv op in the newly generated graph
+  int conv_count = 0;
+  int elementwise_add_count = 0;
+
+  for (auto* node : graph->Nodes()) {
+    if (node->IsOp() && node->Op()->Type() == "conv2d") {
+      ++conv_count;
+    }
+    if (node->IsOp() && node->Op()->Type() == "elementwise_add") {
+      ++elementwise_add_count;
+    }
+    /*
+    if (node->Op()->HasAttr("use_mkldnn")) {
+      bool use_mkldnn = boost::get<bool>(node->Op()->GetAttr("use_mkldnn"));
+      if (use_mkldnn) {
+        if (node->Op()->HasAttr("fuse_sum")) {
+//          bool fuse_sum = boost::get<bool>(node->Op()->GetAttr("fuse_sum"));
+          if (fuse_sum) {
+            ++conv_elementwise_add_count;
+          }
+        }
+      }
+    }
+    */
+  }
+  EXPECT_EQ(conv_count, 1);
+  EXPECT_EQ(elementwise_add_count, 0);
 }
 
-TEST(ConvElementwiseAddMKLDNNFusePass, basic) {
-  auto prog = BuildProgramDesc();
+TEST(ConvElementwiseAddMKLDNNFusePass, OnlyConvolutionElementwiseAdd) {
+  auto build_program_desc = [&]() -> ProgramDesc {
+    ProgramDesc prog;
+    for (auto& v :
+         std::vector<std::string>({"a", "b", "weights"})) {
+      auto* var = prog.MutableBlock(0)->Var(v);
+      var->SetType(proto::VarType::LOD_TENSOR);
+      if (v == "weights" || v == "bias") {
+        var->SetPersistable(true);
+      }
+    }
+
+    SetOp(&prog, "conv2d", {"a", "weights"}, {"b"});
+    SetOp(&prog, "elementwise_add", {"b", "c"}, {"d"});
+
+    return prog;
+  };
+
+  auto prog = build_program_desc();
   std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
   auto pass = PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
   int original_nodes_num = graph->Nodes().size();
   graph = pass->Apply(std::move(graph));
   int current_nodes_num = graph->Nodes().size();
 
-  EXPECT_EQ(original_nodes_num - 2, current_nodes_num);
+  EXPECT_EQ(original_nodes_num - 4 + 1, current_nodes_num);
   // Assert conv_relu op in newly generated graph
-  int conv_elementwise_add_count = 0;
+  int conv_count = 0;
+  int elementwise_add_count = 0;
 
   for (auto* node : graph->Nodes()) {
     if (node->IsOp() && node->Op()->Type() == "conv2d") {
+      ++conv_count;
+    }
+    if (node->IsOp() && node->Op()->Type() == "elementwise_add") {
+      ++elementwise_add_count;
+    }
+    /*
       if (node->Op()->HasAttr("use_mkldnn")) {
         bool use_mkldnn = boost::get<bool>(node->Op()->GetAttr("use_mkldnn"));
         if (use_mkldnn) {
-          // TODO tpatejko: it is commented because convolution does not support this attribute
-          if (true/*node->Op()->HasAttr("fuse_sum")*/) {
+          if (node->Op()->HasAttr("fuse_sum")) {
 //            bool fuse_sum = boost::get<bool>(node->Op()->GetAttr("fuse_sum"));
-            if (true /*fuse_sum*/) {
+            if (fuse_sum) {
               ++conv_elementwise_add_count;
             }
           }
         }
       }
     }
+    */
   }
-  EXPECT_EQ(conv_elementwise_add_count, 1);
+  EXPECT_EQ(conv_count, 1);
+  EXPECT_EQ(elementwise_add_count, 0);
 }
 
 }  // namespace ir
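A note on the original_nodes_num - 4 + 1 assertions: each match removes four nodes (the conv2d op, its output variable, the elementwise_add op, and its output variable) while fuse_conv inserts a single fused conv2d op, so every fusion shrinks the graph by three nodes. In OnlyConvolutionElementwiseAdd, for example, assuming the graph builds one node per op and per referenced variable, the initial graph holds seven nodes (variables a, weights, b, c, d plus the two ops); after the pass, 7 - 4 + 1 = 4 nodes remain: a, weights, c, and the fused conv2d.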