diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.cc index 129c6e1f59918db56d677c4c30507491b7aa7b66..715fed7d79376428549ecba29596ac6e441192d9 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.cc @@ -27,6 +27,20 @@ namespace mindspore { namespace opt { +namespace { +constexpr size_t kEltwiseInputSize = 2; +constexpr size_t kEltwiseOutputSize = 2; +bool CheckEltwiseInputAndOutputSize(const AnfNodePtr &node) { + if (AnfAlgo::GetInputTensorNum(node) == kEltwiseInputSize) { + return true; + } + if (AnfAlgo::GetOutputTensorNum(node) == kEltwiseOutputSize) { + return true; + } + return false; +} +} // namespace + void BnupdateEltwiseEltwiseFusionPass::MatchBnupdateAddRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input, const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) { @@ -74,8 +88,9 @@ void BnupdateEltwiseEltwiseFusionPass::MatchSingleFusionPattern(const session::K auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && - AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) { + AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && CheckEltwiseInputAndOutputSize(cnode)) { auto eltwise_input = cnode->input(1); + MS_EXCEPTION_IF_NULL(eltwise_input); if (eltwise_input->isa() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimTensorAdd)) { MatchBnupdateAddRelu(cnode, eltwise_input, kernel_graph, candidate_fusion); }