diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/single_batch_norm_fission.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fission/single_batch_norm_fission.cc index 4bcf90444494bb9c19ca3ed9d3d1262cf54c2a1e..159be2ac3b4e79619660904950912bde7cc17e67 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/single_batch_norm_fission.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fission/single_batch_norm_fission.cc @@ -23,42 +23,8 @@ namespace mindspore { namespace opt { namespace { -const std::vector kOutputIndex{0, 1, 2, 3, 4}; -constexpr size_t kBatchNormLeastOutputNum = 1; constexpr size_t kBatchNormRealInputNum = 3; -bool GetBatchNormOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, std::vector *bn_outputs) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(bn_outputs); - auto manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - auto iter = manager->node_users().find(bn); - if (iter == manager->node_users().end()) { - return false; - } - size_t output_num = 0; - for (const auto &node_index : iter->second) { - AnfNodePtr output = node_index.first; - MS_EXCEPTION_IF_NULL(output); - if (!IsPrimitiveCNode(output, prim::kPrimTupleGetItem)) { - continue; - } - auto tuple_getiterm_cnode = output->cast(); - MS_EXCEPTION_IF_NULL(tuple_getiterm_cnode); - auto index_node = tuple_getiterm_cnode->input(kInputNodeOutputIndexInTupleGetItem); - MS_EXCEPTION_IF_NULL(index_node); - auto value_node = index_node->cast(); - MS_EXCEPTION_IF_NULL(value_node); - int index = GetValue(value_node->value()); - if (std::find(kOutputIndex.begin(), kOutputIndex.end(), index) == kOutputIndex.end()) { - return false; - } - bn_outputs->push_back(output); - output_num++; - } - return output_num >= kBatchNormLeastOutputNum; -} - AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodePtr &bn) { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(bn); @@ -140,34 +106,12 @@ const AnfNodePtr SingleBatchNormFission::Process(const FuncGraphPtr &func_graph, MS_LOG(INFO) << "is training should be true if do fusion"; return nullptr; } - std::vector bn_outputs; - if (!GetBatchNormOutputs(func_graph, node, &bn_outputs)) { - MS_LOG(INFO) << "The BatchNorm node should only have output 0, 3 and 4. The node should not be changed"; - return nullptr; - } AnfNodePtr bn_training_reduce = CreateBNTrainingReduce(func_graph, node); std::vector bn_training_reduce_outputs; CreateMultipleOutputsOfAnfNode(func_graph, bn_training_reduce, kBNTrainingReduceOutputNum, &bn_training_reduce_outputs); - AnfNodePtr bn_training_update_v3 = CreateBNTrainingUpdateV3(func_graph, node, bn_training_reduce_outputs); - std::vector bn_training_update_v3_outputs; - CreateMultipleOutputsOfAnfNode(func_graph, bn_training_update_v3, kBNTrainingUpdateV3OutputNum, - &bn_training_update_v3_outputs); - if (bn_training_update_v3_outputs.size() != kBNTrainingUpdateV3OutputNum) { - MS_LOG(EXCEPTION) << "The output size of node bn_training_reduce must be " << kBNTrainingUpdateV2OutputNum - << ", but it is " << bn_training_update_v3_outputs.size(); - } - auto manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - sort(bn_outputs.begin(), bn_outputs.end(), CompareTupleGetitem); - size_t output_index = 0; - for (const auto &output : bn_outputs) { - (void)manager->Replace(output, bn_training_update_v3_outputs[output_index]); - output_index++; - } - // Return the new node for control depends. - return bn_training_update_v3; + return CreateBNTrainingUpdateV3(func_graph, node, bn_training_reduce_outputs); } } // namespace opt } // namespace mindspore