diff --git a/paddle/fluid/framework/ir/mkldnn/fc_elementwise_add_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/fc_elementwise_add_mkldnn_fuse_pass.cc index e0de720d049bf182592055c7c4e456a46b880f25..7e9a434be1c4bbc5358b537f242bb8ab1124c7ba 100644 --- a/paddle/fluid/framework/ir/mkldnn/fc_elementwise_add_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/fc_elementwise_add_mkldnn_fuse_pass.cc @@ -77,6 +77,7 @@ GraphWithStats FCResidualConnectionMKLDNNFusePass::FuseFC( auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { + VLOG(4) << "Fuse fc + elementwise_add as residual"; GET_IR_NODE_FROM_SUBGRAPH(fc_op, fc, fc_pattern); GET_IR_NODE_FROM_SUBGRAPH(fc_input, input, fc_pattern); GET_IR_NODE_FROM_SUBGRAPH(fc_weights, weights, fc_pattern); @@ -89,9 +90,28 @@ GraphWithStats FCResidualConnectionMKLDNNFusePass::FuseFC( GET_IR_NODE_FROM_SUBGRAPH( elementwise_out, elementwise_out, elementwise_pattern); - if (FindFuseOption(*fc_op, *elementwise_op) != FUSE_MKLDNN) return; - if (!IsReachable(g, residual_data, fc_output)) return; - if (HasFusedActivation(fc_op)) return; + if (FindFuseOption(*fc_op, *elementwise_op) != FUSE_MKLDNN) { + VLOG(4) << "Skipping fusion for " << fc_op->Name() << "(" << fc_op->id() + << ") with " << elementwise_op->Name() << "(" + << elementwise_op->id() + << ") because not both ops have use_mkldnn"; + return; + } + if (!IsReachable(g, residual_data, fc_output)) { + VLOG(4) << "Skipping fusion for " << fc_op->Name() << "(" << fc_op->id() + << ") with " << elementwise_op->Name() << "(" + << elementwise_op->id() << ") because residual input " + << residual_data->Name() << "(" << residual_data->id() + << ") is not " + "reachable"; + return; + } + if (HasFusedActivation(fc_op)) { + VLOG(4) << "Skipping fusion for " << fc_op->Name() << "(" << fc_op->id() + << ") with " << elementwise_op->Name() << "(" + << elementwise_op->id() << ") because fc has activation fused"; + return; + } if (!IsCompat(subgraph, g)) { LOG(WARNING) diff --git a/paddle/fluid/framework/ir/mkldnn/fc_elementwise_add_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/fc_elementwise_add_mkldnn_fuse_pass_tester.cc index 73193a3e9042440aa7ab5e7aabb28a7cab6823be..67db173ea31ec8b6e772a1b0a5bb1de942e42cca 100644 --- a/paddle/fluid/framework/ir/mkldnn/fc_elementwise_add_mkldnn_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/fc_elementwise_add_mkldnn_fuse_pass_tester.cc @@ -139,6 +139,28 @@ TEST(FCElementwiseAddMKLDNNFusePass, NoFusion_NotResidualConnection) { EXPECT_TRUE(test::AssertOpsCount(graph, {{"fc", 2}, {"elementwise_add", 1}})); } +TEST(FCElementwiseAddMKLDNNFusePass, NoFusion_HasActivationFused) { + auto prog = + test::BuildProgramDesc({"a", "b", "c", "d", "e"}, {"bias", "weights"}); + + test::CreateOp(&prog, "sigmoid", {{"X", "a"}}, {{"Out", "b"}}); + OpDesc* fc = + Create_Op_FC(&prog, + {{"Input", "b"}, {"Bias", "bias"}, {"W", "weights"}}, + {{"Out", "c"}}); + std::string activation{"relu"}; + fc->SetAttr("activation_type", activation); + + Create_Op_elementwise_add(&prog, {{"X", "c"}, {"Y", "a"}}, {{"Out", "d"}}); + test::CreateOp(&prog, "relu", {{"X", "d"}}, {{"Out", "e"}}); + + Graph graph(prog); + + EXPECT_TRUE(test::RunPassAndAssert( + &graph, "fc_elementwise_add_mkldnn_fuse_pass", "a", "e", 0, 0)); + EXPECT_TRUE(test::AssertOpsCount(graph, {{"fc", 1}, {"elementwise_add", 1}})); +} + TEST(FCElementwiseAddMKLDNNFusePass, FC_Residual_VITOCR) { auto prog = test::BuildProgramDesc( {"a", "b", "c", "d", "e", "f", "g", "h", "i"},