diff --git a/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc index 2eb35291803f62b419eb880f0418708c865ca8b3..a9bc746680c1637bcfa6f30a2e5a265ab30c9c03 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc @@ -113,6 +113,7 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConv( if (FindFuseOption(*conv_op, *elementwise_op) != FUSE_MKLDNN) return; if (!IsReachable(g, residual_data, conv_output)) return; if (HasFusedActivation(conv_op)) return; + if (HasFusedElementwiseAdd(conv_op)) return; if (!IsCompat(subgraph, g)) { LOG(WARNING) @@ -120,6 +121,12 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConv( return; } + if (residual_data->Var()->GetShape() != conv_output->Var()->GetShape()) { + LOG(WARNING) << "conv_elementwise_add_mkldnn_fuse_pass doesn't support " - + "broadcasting"; + return; + } + conv_op->Op()->SetInput("ResidualData", {residual_data->Name()}); conv_op->Op()->SetOutput("Output", {elementwise_out->Name()}); conv_op->Op()->SetAttr("fuse_residual_connection", true); diff --git a/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h index 7c6e9927163c71518cc9062d124d3007aa51bb7b..86f65480ad1d9e18e1b01cd5800d2fa6adf43f49 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h @@ -44,6 +44,9 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase { ->GetAttrIfExists("fuse_activation") .empty()); } + static bool HasFusedElementwiseAdd(Node* conv_node) { + return conv_node->Op()->GetAttrIfExists("fuse_residual_connection"); + } const std::string name_scope_{"residual_connection_fuse_pass"}; };