From 42c694df3ae82e29f2fb16b64e93ab3badfde5b3 Mon Sep 17 00:00:00 2001 From: yeliang2258 <30516196+yeliang2258@users.noreply.github.com> Date: Tue, 9 Aug 2022 19:32:39 +0800 Subject: [PATCH] fix mkldnn conv add pass when the dims of res and out are not equel (#45018) --- .../ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc | 7 +++++++ .../ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h | 3 +++ 2 files changed, 10 insertions(+) diff --git a/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc index 2eb35291803..a9bc746680c 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc @@ -113,6 +113,7 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConv( if (FindFuseOption(*conv_op, *elementwise_op) != FUSE_MKLDNN) return; if (!IsReachable(g, residual_data, conv_output)) return; if (HasFusedActivation(conv_op)) return; + if (HasFusedElementwiseAdd(conv_op)) return; if (!IsCompat(subgraph, g)) { LOG(WARNING) @@ -120,6 +121,12 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConv( return; } + if (residual_data->Var()->GetShape() != conv_output->Var()->GetShape()) { + LOG(WARNING) << "conv_elementwise_add_mkldnn_fuse_pass doesn't support " - + "broadcasting"; + return; + } + conv_op->Op()->SetInput("ResidualData", {residual_data->Name()}); conv_op->Op()->SetOutput("Output", {elementwise_out->Name()}); conv_op->Op()->SetAttr("fuse_residual_connection", true); diff --git a/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h index 7c6e9927163..86f65480ad1 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h @@ -44,6 +44,9 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase { ->GetAttrIfExists("fuse_activation") .empty()); } + static bool HasFusedElementwiseAdd(Node* conv_node) { + return conv_node->Op()->GetAttrIfExists("fuse_residual_connection"); + } const std::string name_scope_{"residual_connection_fuse_pass"}; }; -- GitLab