未验证 提交 21668cb2 编写于 作者: S Sylwester Fraczek 提交者: GitHub

add logging to fc residual fuse pass (#46760)

* add logging to fc residual fuse pass

* expand logging message to fc residual fuse pass

* Add test for fc residual not fusing with activation
上级 01bf3b92
......@@ -77,6 +77,7 @@ GraphWithStats FCResidualConnectionMKLDNNFusePass::FuseFC(
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "Fuse fc + elementwise_add as residual";
GET_IR_NODE_FROM_SUBGRAPH(fc_op, fc, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(fc_input, input, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(fc_weights, weights, fc_pattern);
......@@ -89,9 +90,28 @@ GraphWithStats FCResidualConnectionMKLDNNFusePass::FuseFC(
GET_IR_NODE_FROM_SUBGRAPH(
elementwise_out, elementwise_out, elementwise_pattern);
if (FindFuseOption(*fc_op, *elementwise_op) != FUSE_MKLDNN) return;
if (!IsReachable(g, residual_data, fc_output)) return;
if (HasFusedActivation(fc_op)) return;
if (FindFuseOption(*fc_op, *elementwise_op) != FUSE_MKLDNN) {
VLOG(4) << "Skipping fusion for " << fc_op->Name() << "(" << fc_op->id()
<< ") with " << elementwise_op->Name() << "("
<< elementwise_op->id()
<< ") because not both ops have use_mkldnn";
return;
}
if (!IsReachable(g, residual_data, fc_output)) {
VLOG(4) << "Skipping fusion for " << fc_op->Name() << "(" << fc_op->id()
<< ") with " << elementwise_op->Name() << "("
<< elementwise_op->id() << ") because residual input "
<< residual_data->Name() << "(" << residual_data->id()
<< ") is not "
"reachable";
return;
}
if (HasFusedActivation(fc_op)) {
VLOG(4) << "Skipping fusion for " << fc_op->Name() << "(" << fc_op->id()
<< ") with " << elementwise_op->Name() << "("
<< elementwise_op->id() << ") because fc has activation fused";
return;
}
if (!IsCompat(subgraph, g)) {
LOG(WARNING)
......
......@@ -139,6 +139,28 @@ TEST(FCElementwiseAddMKLDNNFusePass, NoFusion_NotResidualConnection) {
EXPECT_TRUE(test::AssertOpsCount(graph, {{"fc", 2}, {"elementwise_add", 1}}));
}
TEST(FCElementwiseAddMKLDNNFusePass, NoFusion_HasActivationFused) {
auto prog =
test::BuildProgramDesc({"a", "b", "c", "d", "e"}, {"bias", "weights"});
test::CreateOp(&prog, "sigmoid", {{"X", "a"}}, {{"Out", "b"}});
OpDesc* fc =
Create_Op_FC(&prog,
{{"Input", "b"}, {"Bias", "bias"}, {"W", "weights"}},
{{"Out", "c"}});
std::string activation{"relu"};
fc->SetAttr("activation_type", activation);
Create_Op_elementwise_add(&prog, {{"X", "c"}, {"Y", "a"}}, {{"Out", "d"}});
test::CreateOp(&prog, "relu", {{"X", "d"}}, {{"Out", "e"}});
Graph graph(prog);
EXPECT_TRUE(test::RunPassAndAssert(
&graph, "fc_elementwise_add_mkldnn_fuse_pass", "a", "e", 0, 0));
EXPECT_TRUE(test::AssertOpsCount(graph, {{"fc", 1}, {"elementwise_add", 1}}));
}
TEST(FCElementwiseAddMKLDNNFusePass, FC_Residual_VITOCR) {
auto prog = test::BuildProgramDesc(
{"a", "b", "c", "d", "e", "f", "g", "h", "i"},
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册