提交 bc2df2c9 编写于 作者: Y YuJianfeng

Fix inputs size and attr for AddN fission pass

上级 ae7556ff
......@@ -34,6 +34,8 @@ AnfNodePtr CreateNewAddn(const FuncGraphPtr &func_graph, const CNodePtr &origin_
new_addn->set_scope(origin_addn_cnode->scope());
new_addn->set_abstract(origin_addn_cnode->abstract());
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(offset)), new_addn);
std::vector<int> dyn_input_sizes{SizeToInt(offset)};
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), new_addn);
return new_addn;
}
} // namespace
......@@ -55,22 +57,24 @@ const AnfNodePtr AddnFission::Process(const FuncGraphPtr &func_graph, const AnfN
}
CNodePtr new_cnode = cnode;
while (origin_input_size > inputs_divisor_) {
MS_EXCEPTION_IF_NULL(new_cnode);
std::vector<AnfNodePtr> base_addn_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimAddN->name()))};
size_t cur_input_index = 1;
// Divide the inputs of addn by 63.
while (origin_input_size - cur_input_index + 1 > inputs_divisor_) {
// Divide the inputs of addn by inputs_divisor_.
while (origin_input_size - cur_input_index + 1 >= inputs_divisor_) {
base_addn_inputs.push_back(CreateNewAddn(func_graph, new_cnode, cur_input_index, inputs_divisor_));
cur_input_index += inputs_divisor_;
}
base_addn_inputs.push_back(
CreateNewAddn(func_graph, new_cnode, cur_input_index, origin_input_size - cur_input_index + 1));
for (size_t i = cur_input_index; i <= origin_input_size; i++) {
base_addn_inputs.push_back(new_cnode->input(i));
}
CNodePtr base_addn = func_graph->NewCNode(base_addn_inputs);
MS_EXCEPTION_IF_NULL(base_addn);
MS_EXCEPTION_IF_NULL(new_cnode);
base_addn->set_scope(new_cnode->scope());
base_addn->set_abstract(new_cnode->abstract());
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(base_addn_inputs.size() - 1)), base_addn);
std::vector<int> dyn_input_sizes{SizeToInt(base_addn_inputs.size() - 1)};
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), base_addn);
new_cnode = base_addn;
origin_input_size = base_addn->inputs().size() - 1;
}
......
......@@ -149,7 +149,7 @@ constexpr auto kAttrDynInputSizes = "dyn_input_sizes";
constexpr auto kAttrSrcFormat = "src_format";
constexpr auto kAttrOutputUsedNum = "output_used_num";
constexpr auto kAttrHasBias = "has_bias";
constexpr auto kAttrN = "N";
constexpr auto kAttrN = "n";
constexpr auto kAttrLabelForInsertStreamActive = "label_for_insert_stream_active";
// attr value
......
......@@ -45,13 +45,10 @@ def test_addn_fission(tag):
b = addn((input2, input3))
c = addn((input4, input5))
d = addn((input6, input7))
e = addn((input8,))
f = addn((a, b))
g = addn((c, d))
h = addn((e,))
i = addn((f, g))
j = addn((h,))
return addn((i, j))
return addn((i, input8))
@fns
def after_divided_by_3(input0, input1, input2, input3, input4, input5, input6, input7, input8):
......@@ -64,14 +61,12 @@ def test_addn_fission(tag):
def after_divided_by_4(input0, input1, input2, input3, input4, input5, input6, input7, input8):
a = addn((input0, input1, input2, input3))
b = addn((input4, input5, input6, input7))
c = addn((input8,))
return addn((a, b, c))
return addn((a, b, input8))
@fns
def after_divided_by_8(input0, input1, input2, input3, input4, input5, input6, input7, input8):
a = addn((input0, input1, input2, input3, input4, input5, input6, input7))
b = addn((input8,))
return addn((a, b))
return addn((a, input8))
@fns
def after_divided_by_9(input0, input1, input2, input3, input4, input5, input6, input7, input8):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册