提交 11a4b35c 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!472 Fix inputs size and attr for AddN fission pass

Merge pull request !472 from YuJianfeng/master
...@@ -34,6 +34,8 @@ AnfNodePtr CreateNewAddn(const FuncGraphPtr &func_graph, const CNodePtr &origin_ ...@@ -34,6 +34,8 @@ AnfNodePtr CreateNewAddn(const FuncGraphPtr &func_graph, const CNodePtr &origin_
new_addn->set_scope(origin_addn_cnode->scope()); new_addn->set_scope(origin_addn_cnode->scope());
new_addn->set_abstract(origin_addn_cnode->abstract()); new_addn->set_abstract(origin_addn_cnode->abstract());
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(offset)), new_addn); AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(offset)), new_addn);
std::vector<int> dyn_input_sizes{SizeToInt(offset)};
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), new_addn);
return new_addn; return new_addn;
} }
} // namespace } // namespace
...@@ -55,22 +57,24 @@ const AnfNodePtr AddnFission::Process(const FuncGraphPtr &func_graph, const AnfN ...@@ -55,22 +57,24 @@ const AnfNodePtr AddnFission::Process(const FuncGraphPtr &func_graph, const AnfN
} }
CNodePtr new_cnode = cnode; CNodePtr new_cnode = cnode;
while (origin_input_size > inputs_divisor_) { while (origin_input_size > inputs_divisor_) {
MS_EXCEPTION_IF_NULL(new_cnode);
std::vector<AnfNodePtr> base_addn_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimAddN->name()))}; std::vector<AnfNodePtr> base_addn_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimAddN->name()))};
size_t cur_input_index = 1; size_t cur_input_index = 1;
// Divide the inputs of addn by 63. // Divide the inputs of addn by inputs_divisor_.
while (origin_input_size - cur_input_index + 1 > inputs_divisor_) { while (origin_input_size - cur_input_index + 1 >= inputs_divisor_) {
base_addn_inputs.push_back(CreateNewAddn(func_graph, new_cnode, cur_input_index, inputs_divisor_)); base_addn_inputs.push_back(CreateNewAddn(func_graph, new_cnode, cur_input_index, inputs_divisor_));
cur_input_index += inputs_divisor_; cur_input_index += inputs_divisor_;
} }
base_addn_inputs.push_back( for (size_t i = cur_input_index; i <= origin_input_size; i++) {
CreateNewAddn(func_graph, new_cnode, cur_input_index, origin_input_size - cur_input_index + 1)); base_addn_inputs.push_back(new_cnode->input(i));
}
CNodePtr base_addn = func_graph->NewCNode(base_addn_inputs); CNodePtr base_addn = func_graph->NewCNode(base_addn_inputs);
MS_EXCEPTION_IF_NULL(base_addn); MS_EXCEPTION_IF_NULL(base_addn);
MS_EXCEPTION_IF_NULL(new_cnode);
base_addn->set_scope(new_cnode->scope()); base_addn->set_scope(new_cnode->scope());
base_addn->set_abstract(new_cnode->abstract()); base_addn->set_abstract(new_cnode->abstract());
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(base_addn_inputs.size() - 1)), base_addn); AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(base_addn_inputs.size() - 1)), base_addn);
std::vector<int> dyn_input_sizes{SizeToInt(base_addn_inputs.size() - 1)};
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), base_addn);
new_cnode = base_addn; new_cnode = base_addn;
origin_input_size = base_addn->inputs().size() - 1; origin_input_size = base_addn->inputs().size() - 1;
} }
......
...@@ -149,7 +149,7 @@ constexpr auto kAttrDynInputSizes = "dyn_input_sizes"; ...@@ -149,7 +149,7 @@ constexpr auto kAttrDynInputSizes = "dyn_input_sizes";
constexpr auto kAttrSrcFormat = "src_format"; constexpr auto kAttrSrcFormat = "src_format";
constexpr auto kAttrOutputUsedNum = "output_used_num"; constexpr auto kAttrOutputUsedNum = "output_used_num";
constexpr auto kAttrHasBias = "has_bias"; constexpr auto kAttrHasBias = "has_bias";
constexpr auto kAttrN = "N"; constexpr auto kAttrN = "n";
constexpr auto kAttrLabelForInsertStreamActive = "label_for_insert_stream_active"; constexpr auto kAttrLabelForInsertStreamActive = "label_for_insert_stream_active";
// attr value // attr value
......
...@@ -45,13 +45,10 @@ def test_addn_fission(tag): ...@@ -45,13 +45,10 @@ def test_addn_fission(tag):
b = addn((input2, input3)) b = addn((input2, input3))
c = addn((input4, input5)) c = addn((input4, input5))
d = addn((input6, input7)) d = addn((input6, input7))
e = addn((input8,))
f = addn((a, b)) f = addn((a, b))
g = addn((c, d)) g = addn((c, d))
h = addn((e,))
i = addn((f, g)) i = addn((f, g))
j = addn((h,)) return addn((i, input8))
return addn((i, j))
@fns @fns
def after_divided_by_3(input0, input1, input2, input3, input4, input5, input6, input7, input8): def after_divided_by_3(input0, input1, input2, input3, input4, input5, input6, input7, input8):
...@@ -64,14 +61,12 @@ def test_addn_fission(tag): ...@@ -64,14 +61,12 @@ def test_addn_fission(tag):
def after_divided_by_4(input0, input1, input2, input3, input4, input5, input6, input7, input8): def after_divided_by_4(input0, input1, input2, input3, input4, input5, input6, input7, input8):
a = addn((input0, input1, input2, input3)) a = addn((input0, input1, input2, input3))
b = addn((input4, input5, input6, input7)) b = addn((input4, input5, input6, input7))
c = addn((input8,)) return addn((a, b, input8))
return addn((a, b, c))
@fns @fns
def after_divided_by_8(input0, input1, input2, input3, input4, input5, input6, input7, input8): def after_divided_by_8(input0, input1, input2, input3, input4, input5, input6, input7, input8):
a = addn((input0, input1, input2, input3, input4, input5, input6, input7)) a = addn((input0, input1, input2, input3, input4, input5, input6, input7))
b = addn((input8,)) return addn((a, input8))
return addn((a, b))
@fns @fns
def after_divided_by_9(input0, input1, input2, input3, input4, input5, input6, input7, input8): def after_divided_by_9(input0, input1, input2, input3, input4, input5, input6, input7, input8):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册