From ce2464fd988b3817674e566b15c7c483b976eaad Mon Sep 17 00:00:00 2001 From: Tomasz Patejko Date: Fri, 19 Oct 2018 13:31:32 +0200 Subject: [PATCH] MKLDNN conv + elementwise_add fusion: UT for missing bias added. UTs refactored. Some minor changes in the pass --- .../conv_elementwise_add_mkldnn_fuse_pass.cc | 5 +- ...elementwise_add_mkldnn_fuse_pass_tester.cc | 202 +++++++++--------- .../framework/ir/graph_pattern_detector.cc | 2 +- .../framework/ir/graph_pattern_detector.h | 1 - 4 files changed, 99 insertions(+), 111 deletions(-) diff --git a/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc index 10b1d636e4b..8d0035ae98b 100644 --- a/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc @@ -68,8 +68,7 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const { conv_output->AsIntermediate(); - auto conv_op_has_bias = [](const Node& conv_op, - const Scope& scope) -> std::pair { + auto conv_op_has_bias = [](const Node& conv_op) -> std::pair { auto bias_input_names = conv_op.Op()->Inputs(); auto bias_it = bias_input_names.find("Bias"); @@ -116,7 +115,7 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const { bool has_bias; Node* conv_bias; - std::tie(has_bias, conv_bias) = conv_op_has_bias(*conv_op, *param_scope()); + std::tie(has_bias, conv_bias) = conv_op_has_bias(*conv_op); if (has_bias) { op_desc.SetInput("Bias", {conv_bias->Name()}); diff --git a/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass_tester.cc index fd47b96c101..348a3dfc5da 100644 --- a/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass_tester.cc @@ -22,29 +22,22 @@ namespace paddle { namespace framework { namespace ir { +namespace { constexpr int nodes_removed = 3; constexpr int nodes_added = 1; void SetOp(ProgramDesc* prog, const std::string& type, - const std::vector& inputs, - const std::vector& outputs) { + const std::vector>& inputs, + const std::pair& output) { auto op = prog->MutableBlock(0)->AppendOp(); op->SetType(type); + op->SetAttr("use_mkldnn", true); - if (type == "conv2d") { - op->SetAttr("use_mkldnn", true); - op->SetInput("Input", {inputs[0]}); - op->SetInput("Bias", {inputs[1]}); - op->SetInput("Filter", {inputs[2]}); - op->SetOutput("Output", outputs); - } else if (type == "elementwise_add") { - op->SetInput("X", {inputs[0]}); - op->SetInput("Y", {inputs[1]}); - op->SetOutput("Out", outputs); - } else if (type == "relu" || type == "sigmoid") { - op->SetInput("X", {inputs[0]}); - op->SetOutput("Out", outputs); + for (const auto& input : inputs) { + op->SetInput(input.first, {input.second}); } + + op->SetOutput(output.first, {output.second}); } struct IsReachable { @@ -96,30 +89,59 @@ struct IsReachable { } }; -TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) { - auto build_program_desc = [&]() -> ProgramDesc { - ProgramDesc prog; - for (auto& v : std::vector( - {"a", "b", "bias", "weights", "c", "d", "e", "f"})) { - auto* var = prog.MutableBlock(0)->Var(v); - var->SetType(proto::VarType::LOD_TENSOR); - if (v == "weights" || v == "bias") { - var->SetPersistable(true); - } +void AssertOpsCount(const std::unique_ptr& graph) { + int conv_count = 0; + int elementwise_add_count = 0; + + for (auto* node : graph->Nodes()) { + if (node->IsOp() && node->Op()->Type() == "conv2d") { + ++conv_count; + } + if (node->IsOp() && node->Op()->Type() == "elementwise_add") { + ++elementwise_add_count; } + } + EXPECT_EQ(conv_count, 1); + EXPECT_EQ(elementwise_add_count, 0); +} + +ProgramDesc BuildProgramDesc(const std::vector& transient_vars, + const std::vector& persistent_vars) { + ProgramDesc prog; - SetOp(&prog, "conv2d", {"a", "bias", "weights"}, {"b"}); - SetOp(&prog, "elementwise_add", {"b", "c"}, {"d"}); - SetOp(&prog, "relu", {"d"}, {"e"}); + auto add_var_to_prog = [&prog](const std::string& var_name) -> VarDesc* { + auto var = prog.MutableBlock(0)->Var(var_name); + var->SetType(proto::VarType::LOD_TENSOR); - return prog; + return var; }; - auto prog = build_program_desc(); + for (const auto& v : transient_vars) { + add_var_to_prog(v); + } + + for (const auto& v : persistent_vars) { + auto var = add_var_to_prog(v); + var->SetPersistable(true); + } + + return prog; +} +} // namespace + +TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) { + auto prog = + BuildProgramDesc({"a", "b", "c", "d", "e", "f"}, {"bias", "weights"}); + + SetOp(&prog, "conv2d", + {{"Input", "a"}, {"Bias", "bias"}, {"Filter", "weights"}}, + {"Output", "b"}); + SetOp(&prog, "elementwise_add", {{"X", "b"}, {"Y", "c"}}, {"Out", "d"}); + SetOp(&prog, "relu", {{"X", "d"}}, {"Out", "e"}); + std::unique_ptr graph(new ir::Graph(prog)); IsReachable is_reachable; - EXPECT_TRUE(is_reachable(graph)("a", "relu")); auto pass = @@ -132,40 +154,45 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) { EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added, current_nodes_num); - // Assert conv_relu op in newly generated graph - int conv_count = 0; - int elementwise_add_count = 0; - for (auto* node : graph->Nodes()) { - if (node->IsOp() && node->Op()->Type() == "conv2d") { - ++conv_count; - } - if (node->IsOp() && node->Op()->Type() == "elementwise_add") { - ++elementwise_add_count; - } - } - EXPECT_EQ(conv_count, 1); - EXPECT_EQ(elementwise_add_count, 0); + AssertOpsCount(graph); } -TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionElementwiseAdd) { - auto build_program_desc = [&]() -> ProgramDesc { - ProgramDesc prog; - for (auto& v : std::vector({"a", "b", "bias", "weights"})) { - auto* var = prog.MutableBlock(0)->Var(v); - var->SetType(proto::VarType::LOD_TENSOR); - if (v == "weights" || v == "bias") { - var->SetPersistable(true); - } - } +TEST(ConvElementwiseAddMKLDNNFusePass, + ConvolutionWithElementwiseAddReluNoBias) { + auto prog = BuildProgramDesc({"a", "b", "c", "d", "e"}, {"weights"}); + SetOp(&prog, "conv2d", {{"Input", "a"}, {"Filter", "weights"}}, + {"Output", "b"}); + SetOp(&prog, "elementwise_add", {{"X", "b"}, {"Y", "c"}}, {"Out", "d"}); + SetOp(&prog, "relu", {{"X", "d"}}, {"Out", "e"}); - SetOp(&prog, "conv2d", {"a", "bias", "weights"}, {"b"}); - SetOp(&prog, "elementwise_add", {"b", "c"}, {"d"}); + std::unique_ptr graph(new ir::Graph(prog)); - return prog; - }; + IsReachable is_reachable; + + EXPECT_TRUE(is_reachable(graph)("a", "relu")); + + auto pass = + PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass"); + int original_nodes_num = graph->Nodes().size(); + graph = pass->Apply(std::move(graph)); + int current_nodes_num = graph->Nodes().size(); + + EXPECT_TRUE(is_reachable(graph)("a", "relu")); + + EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added, + current_nodes_num); + + AssertOpsCount(graph); +} + +TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionElementwiseAdd) { + auto prog = BuildProgramDesc({"a", "b", "c", "d"}, {"bias", "weights"}); + SetOp(&prog, "conv2d", + {{"Input", "a"}, {"Bias", "bias"}, {"Filter", "weights"}}, + {"Output", "b"}); + SetOp(&prog, "elementwise_add", {{"X", "b"}, {"Y", "c"}}, {"Out", "d"}); - auto prog = build_program_desc(); std::unique_ptr graph(new ir::Graph(prog)); IsReachable is_reachable; @@ -181,43 +208,19 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionElementwiseAdd) { EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added, current_nodes_num); - // Assert conv_relu op in newly generated graph - int conv_count = 0; - int elementwise_add_count = 0; - - for (auto* node : graph->Nodes()) { - if (node->IsOp() && node->Op()->Type() == "conv2d") { - ++conv_count; - } - if (node->IsOp() && node->Op()->Type() == "elementwise_add") { - ++elementwise_add_count; - } - } - EXPECT_EQ(conv_count, 1); - EXPECT_EQ(elementwise_add_count, 0); + AssertOpsCount(graph); } TEST(ConvElementwiseAddMKLDNNFusePass, SigmoidConvolutionAddElementwiseRelu) { - auto build_program_desc = [&]() -> ProgramDesc { - ProgramDesc prog; - for (auto& v : std::vector( - {"a", "b", "bias", "weights", "c", "d", "e", "f"})) { - auto* var = prog.MutableBlock(0)->Var(v); - var->SetType(proto::VarType::LOD_TENSOR); - if (v.find("weights") || v.find("bias")) { - var->SetPersistable(true); - } - } - - SetOp(&prog, "sigmoid", {"a"}, {"b"}); - SetOp(&prog, "conv2d", {"b", "bias", "weights"}, {"c"}); - SetOp(&prog, "elementwise_add", {"c", "d"}, {"e"}); - SetOp(&prog, "relu", {"e"}, {"f"}); - - return prog; - }; + auto prog = + BuildProgramDesc({"a", "b", "c", "d", "e", "f"}, {"bias", "weights"}); + SetOp(&prog, "sigmoid", {{"X", "a"}}, {"Out", "b"}); + SetOp(&prog, "conv2d", + {{"Input", "b"}, {"Bias", "bias"}, {"Filter", "weights"}}, + {"Output", "c"}); + SetOp(&prog, "elementwise_add", {{"X", "c"}, {"Y", "d"}}, {"Out", "e"}); + SetOp(&prog, "relu", {{"X", "e"}}, {"Out", "f"}); - auto prog = build_program_desc(); std::unique_ptr graph(new ir::Graph(prog)); IsReachable is_reachable; @@ -234,20 +237,7 @@ TEST(ConvElementwiseAddMKLDNNFusePass, SigmoidConvolutionAddElementwiseRelu) { EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added, current_nodes_num); - // Assert conv_relu op in newly generated graph - int conv_count = 0; - int elementwise_add_count = 0; - - for (auto* node : graph->Nodes()) { - if (node->IsOp() && node->Op()->Type() == "conv2d") { - ++conv_count; - } - if (node->IsOp() && node->Op()->Type() == "elementwise_add") { - ++elementwise_add_count; - } - } - EXPECT_EQ(conv_count, 1); - EXPECT_EQ(elementwise_add_count, 0); + AssertOpsCount(graph); } } // namespace ir diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index da83bcdf375..84475251937 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -1014,7 +1014,7 @@ PDNode *patterns::Conv::operator()() { ->AsOutput() ->assert_is_op_output("conv2d", "Output"); - conv_op->LinksFrom({input_var, /*bias_var,*/ filter_var}); + conv_op->LinksFrom({input_var, filter_var}); conv_op->LinksTo({output_var}); return output_var; diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 8e4f4a14ab7..63189d95d71 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -617,7 +617,6 @@ struct Conv : public PatternBase { PATTERN_DECL_NODE(conv_op); PATTERN_DECL_NODE(conv_input); - PATTERN_DECL_NODE(conv_bias); PATTERN_DECL_NODE(conv_filter); PATTERN_DECL_NODE(conv_residual_data); PATTERN_DECL_NODE(conv_output); -- GitLab