From 21668cb2a657faf6581071da5f3150b12a3e4589 Mon Sep 17 00:00:00 2001 From: Sylwester Fraczek Date: Tue, 11 Oct 2022 14:34:48 +0200 Subject: [PATCH] add logging to fc residual fuse pass (#46760) * add logging to fc residual fuse pass * expand logging message to fc residual fuse pass * Add test for fc residual not fusing with activation --- .../fc_elementwise_add_mkldnn_fuse_pass.cc | 26 ++++++++++++++++--- ...elementwise_add_mkldnn_fuse_pass_tester.cc | 22 ++++++++++++++++ 2 files changed, 45 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/framework/ir/mkldnn/fc_elementwise_add_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/fc_elementwise_add_mkldnn_fuse_pass.cc index e0de720d049..7e9a434be1c 100644 --- a/paddle/fluid/framework/ir/mkldnn/fc_elementwise_add_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/fc_elementwise_add_mkldnn_fuse_pass.cc @@ -77,6 +77,7 @@ GraphWithStats FCResidualConnectionMKLDNNFusePass::FuseFC( auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { + VLOG(4) << "Fuse fc + elementwise_add as residual"; GET_IR_NODE_FROM_SUBGRAPH(fc_op, fc, fc_pattern); GET_IR_NODE_FROM_SUBGRAPH(fc_input, input, fc_pattern); GET_IR_NODE_FROM_SUBGRAPH(fc_weights, weights, fc_pattern); @@ -89,9 +90,28 @@ GraphWithStats FCResidualConnectionMKLDNNFusePass::FuseFC( GET_IR_NODE_FROM_SUBGRAPH( elementwise_out, elementwise_out, elementwise_pattern); - if (FindFuseOption(*fc_op, *elementwise_op) != FUSE_MKLDNN) return; - if (!IsReachable(g, residual_data, fc_output)) return; - if (HasFusedActivation(fc_op)) return; + if (FindFuseOption(*fc_op, *elementwise_op) != FUSE_MKLDNN) { + VLOG(4) << "Skipping fusion for " << fc_op->Name() << "(" << fc_op->id() + << ") with " << elementwise_op->Name() << "(" + << elementwise_op->id() + << ") because not both ops have use_mkldnn"; + return; + } + if (!IsReachable(g, residual_data, fc_output)) { + VLOG(4) << "Skipping fusion for " << fc_op->Name() << "(" << fc_op->id() + << ") with " << elementwise_op->Name() << "(" + << elementwise_op->id() << ") because residual input " + << residual_data->Name() << "(" << residual_data->id() + << ") is not " + "reachable"; + return; + } + if (HasFusedActivation(fc_op)) { + VLOG(4) << "Skipping fusion for " << fc_op->Name() << "(" << fc_op->id() + << ") with " << elementwise_op->Name() << "(" + << elementwise_op->id() << ") because fc has activation fused"; + return; + } if (!IsCompat(subgraph, g)) { LOG(WARNING) diff --git a/paddle/fluid/framework/ir/mkldnn/fc_elementwise_add_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/fc_elementwise_add_mkldnn_fuse_pass_tester.cc index 73193a3e904..67db173ea31 100644 --- a/paddle/fluid/framework/ir/mkldnn/fc_elementwise_add_mkldnn_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/fc_elementwise_add_mkldnn_fuse_pass_tester.cc @@ -139,6 +139,28 @@ TEST(FCElementwiseAddMKLDNNFusePass, NoFusion_NotResidualConnection) { EXPECT_TRUE(test::AssertOpsCount(graph, {{"fc", 2}, {"elementwise_add", 1}})); } +TEST(FCElementwiseAddMKLDNNFusePass, NoFusion_HasActivationFused) { + auto prog = + test::BuildProgramDesc({"a", "b", "c", "d", "e"}, {"bias", "weights"}); + + test::CreateOp(&prog, "sigmoid", {{"X", "a"}}, {{"Out", "b"}}); + OpDesc* fc = + Create_Op_FC(&prog, + {{"Input", "b"}, {"Bias", "bias"}, {"W", "weights"}}, + {{"Out", "c"}}); + std::string activation{"relu"}; + fc->SetAttr("activation_type", activation); + + Create_Op_elementwise_add(&prog, {{"X", "c"}, {"Y", "a"}}, {{"Out", "d"}}); + test::CreateOp(&prog, "relu", {{"X", "d"}}, {{"Out", "e"}}); + + Graph graph(prog); + + EXPECT_TRUE(test::RunPassAndAssert( + &graph, "fc_elementwise_add_mkldnn_fuse_pass", "a", "e", 0, 0)); + EXPECT_TRUE(test::AssertOpsCount(graph, {{"fc", 1}, {"elementwise_add", 1}})); +} + TEST(FCElementwiseAddMKLDNNFusePass, FC_Residual_VITOCR) { auto prog = test::BuildProgramDesc( {"a", "b", "c", "d", "e", "f", "g", "h", "i"}, -- GitLab