diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/bn_split.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fission/bn_split.cc
index c8d92f7200f2b5adf9441c1ee01c6617cdbe7cf3..66ffa24bf122fd78b38c9d8f482d9a1e3db7cf6b 100644
--- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/bn_split.cc
+++ b/mindspore/ccsrc/pre_activate/ascend/ir_fission/bn_split.cc
@@ -28,14 +28,14 @@
 namespace mindspore {
 namespace opt {
 namespace {
-void CreateOutputsOfBNTrainingReduce(const FuncGraphPtr &graph, const CNodePtr &bn_cnode,
+bool CreateOutputsOfBNTrainingReduce(const FuncGraphPtr &graph, const CNodePtr &bn_cnode,
                                      std::vector<AnfNodePtr> *bn_training_reduce_outputs) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(bn_cnode);
   if (bn_cnode->inputs().size() != kBnInputNum) {
-    MS_LOG(EXCEPTION) << "BN node has wrong input size";
+    MS_LOG(INFO) << "FusedbatchNorm's input size less than " << kBnInputNum << ". " << bn_cnode->DebugString();
+    return false;
   }
-  // All the inputs of BNTrainingReduce are from the inputs of BN
   std::vector<AnfNodePtr> bn_training_reduce_inputs = {
     NewValueNode(std::make_shared<Primitive>(kBNTrainingReduceOpName))};
   bn_training_reduce_inputs.push_back(bn_cnode->input(1));
@@ -45,8 +45,9 @@ void CreateOutputsOfBNTrainingReduce(const FuncGraphPtr &
   MS_EXCEPTION_IF_NULL(kernel_info);
   bn_training_reduce->set_kernel_info(kernel_info);
   std::vector<size_t> bn_shape_i0 = AnfAlgo::GetPrevNodeOutputInferShape(bn_cnode, 0);
-  if (bn_shape_i0.size() != kShape4dDims) {
-    MS_LOG(EXCEPTION) << "Get shape of FusedBatchNorm fail";
+  if (bn_shape_i0.size() < kShape2dDims) {
+    MS_LOG(INFO) << "The FusedBatchNorm's first input's shape dims less than " << kShape2dDims;
+    return false;
   }
   std::vector<size_t> bn_training_reduce_shape = {bn_shape_i0[1]};
   auto types = {kNumberTypeFloat32, kNumberTypeFloat32};
@@ -56,6 +57,7 @@ void CreateOutputsOfBNTrainingReduce(const FuncGraphPtr &
   AnfAlgo::CopyNodeAttrs(bn_cnode, bn_training_reduce);
   CreateMultipleOutputsOfAnfNode(graph,
                                  bn_training_reduce, kBNTrainingReduceOutputNum, bn_training_reduce_outputs);
+  return true;
 }
 
 AnfNodePtr CreateOutputsOfBNTrainingUpdate(const FuncGraphPtr &graph, const CNodePtr &bn_cnode,
@@ -99,11 +101,15 @@ AnfNodePtr SplitFusedBatchNormForTBE(const FuncGraphPtr &func_graph, const AnfNo
   auto cnode = node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
   if (cnode->inputs().size() < kBnInputNum) {
-    MS_LOG(EXCEPTION) << "op[FusedBatchNorm] has less than " << kBnInputNum << " inputs.";
+    MS_LOG(INFO) << "op[FusedBatchNorm] has less than " << kBnInputNum << " inputs.";
+    return nullptr;
   }
   // Create BNTrainingReduce node and get outputs of BNTrainingReduce
   std::vector<AnfNodePtr> bn_training_reduce_outputs;
-  CreateOutputsOfBNTrainingReduce(func_graph, cnode, &bn_training_reduce_outputs);
+  if (!CreateOutputsOfBNTrainingReduce(func_graph, cnode, &bn_training_reduce_outputs)) {
+    MS_LOG(WARNING) << "Create BNTrainingReduce fail, quit split";
+    return nullptr;
+  }
   if (bn_training_reduce_outputs.size() != kBN1OutputNum) {
     MS_LOG(EXCEPTION) << "make outputs of op BNTrainingReduce fail";
   }