diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/addn_fission.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fission/addn_fission.cc index f6eb6aca64ece922137077d2646327a4da011e77..b9a86f7bcb84440332c1048c7ab0e101a9ce932b 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/addn_fission.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fission/addn_fission.cc @@ -34,6 +34,8 @@ AnfNodePtr CreateNewAddn(const FuncGraphPtr &func_graph, const CNodePtr &origin_ new_addn->set_scope(origin_addn_cnode->scope()); new_addn->set_abstract(origin_addn_cnode->abstract()); AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(offset)), new_addn); + std::vector dyn_input_sizes{SizeToInt(offset)}; + AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), new_addn); return new_addn; } } // namespace @@ -55,22 +57,24 @@ const AnfNodePtr AddnFission::Process(const FuncGraphPtr &func_graph, const AnfN } CNodePtr new_cnode = cnode; while (origin_input_size > inputs_divisor_) { + MS_EXCEPTION_IF_NULL(new_cnode); std::vector base_addn_inputs{NewValueNode(std::make_shared(prim::kPrimAddN->name()))}; size_t cur_input_index = 1; - // Divide the inputs of addn by 63. - while (origin_input_size - cur_input_index + 1 > inputs_divisor_) { + // Divide the inputs of addn by inputs_divisor_. + while (origin_input_size - cur_input_index + 1 >= inputs_divisor_) { base_addn_inputs.push_back(CreateNewAddn(func_graph, new_cnode, cur_input_index, inputs_divisor_)); cur_input_index += inputs_divisor_; } - base_addn_inputs.push_back( - CreateNewAddn(func_graph, new_cnode, cur_input_index, origin_input_size - cur_input_index + 1)); - + for (size_t i = cur_input_index; i <= origin_input_size; i++) { + base_addn_inputs.push_back(new_cnode->input(i)); + } CNodePtr base_addn = func_graph->NewCNode(base_addn_inputs); MS_EXCEPTION_IF_NULL(base_addn); - MS_EXCEPTION_IF_NULL(new_cnode); base_addn->set_scope(new_cnode->scope()); base_addn->set_abstract(new_cnode->abstract()); AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(base_addn_inputs.size() - 1)), base_addn); + std::vector dyn_input_sizes{SizeToInt(base_addn_inputs.size() - 1)}; + AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), base_addn); new_cnode = base_addn; origin_input_size = base_addn->inputs().size() - 1; } diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index 10ef4abf62f334be15df9350db98be09a9ad2d7e..eac901b74de8fcfabbc9f8c866df3496d35e19e2 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -149,7 +149,7 @@ constexpr auto kAttrDynInputSizes = "dyn_input_sizes"; constexpr auto kAttrSrcFormat = "src_format"; constexpr auto kAttrOutputUsedNum = "output_used_num"; constexpr auto kAttrHasBias = "has_bias"; -constexpr auto kAttrN = "N"; +constexpr auto kAttrN = "n"; constexpr auto kAttrLabelForInsertStreamActive = "label_for_insert_stream_active"; // attr value diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/addn_fission_test.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/addn_fission_test.py index c120ac3e68e7eed4aa185224315a1b72af3a6ef9..76d7e73a800b81dd0e784c65bdca1ccea3cc8cf5 100644 --- a/tests/ut/cpp/python_input/gtest_input/pre_activate/addn_fission_test.py +++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/addn_fission_test.py @@ -45,13 +45,10 @@ def test_addn_fission(tag): b = addn((input2, input3)) c = addn((input4, input5)) d = addn((input6, input7)) - e = addn((input8,)) f = addn((a, b)) g = addn((c, d)) - h = addn((e,)) i = addn((f, g)) - j = addn((h,)) - return addn((i, j)) + return addn((i, input8)) @fns def after_divided_by_3(input0, input1, input2, input3, input4, input5, input6, input7, input8): @@ -64,14 +61,12 @@ def test_addn_fission(tag): def after_divided_by_4(input0, input1, input2, input3, input4, input5, input6, input7, input8): a = addn((input0, input1, input2, input3)) b = addn((input4, input5, input6, input7)) - c = addn((input8,)) - return addn((a, b, c)) + return addn((a, b, input8)) @fns def after_divided_by_8(input0, input1, input2, input3, input4, input5, input6, input7, input8): a = addn((input0, input1, input2, input3, input4, input5, input6, input7)) - b = addn((input8,)) - return addn((a, b)) + return addn((a, input8)) @fns def after_divided_by_9(input0, input1, input2, input3, input4, input5, input6, input7, input8):