未验证 提交 42c694df 编写于 作者: Y yeliang2258 提交者: GitHub

fix mkldnn conv add pass when the dims of res and out are not equel (#45018)

上级 4615af2c
......@@ -113,6 +113,7 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConv(
if (FindFuseOption(*conv_op, *elementwise_op) != FUSE_MKLDNN) return;
if (!IsReachable(g, residual_data, conv_output)) return;
if (HasFusedActivation(conv_op)) return;
if (HasFusedElementwiseAdd(conv_op)) return;
if (!IsCompat(subgraph, g)) {
LOG(WARNING)
......@@ -120,6 +121,12 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConv(
return;
}
if (residual_data->Var()->GetShape() != conv_output->Var()->GetShape()) {
LOG(WARNING) << "conv_elementwise_add_mkldnn_fuse_pass doesn't support " -
"broadcasting";
return;
}
conv_op->Op()->SetInput("ResidualData", {residual_data->Name()});
conv_op->Op()->SetOutput("Output", {elementwise_out->Name()});
conv_op->Op()->SetAttr("fuse_residual_connection", true);
......
......@@ -44,6 +44,9 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase {
->GetAttrIfExists<std::string>("fuse_activation")
.empty());
}
static bool HasFusedElementwiseAdd(Node* conv_node) {
return conv_node->Op()->GetAttrIfExists<bool>("fuse_residual_connection");
}
const std::string name_scope_{"residual_connection_fuse_pass"};
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册