Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
11a4b35c
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
11a4b35c
编写于
4月 20, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
4月 20, 2020
浏览文件
操作
浏览文件
下载
差异文件
!472 Fix inputs size and attr for AddN fission pass
Merge pull request !472 from YuJianfeng/master
上级
2f1a037f
bc2df2c9
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
14 addition
and
15 deletion
+14
-15
mindspore/ccsrc/pre_activate/ascend/ir_fission/addn_fission.cc
...pore/ccsrc/pre_activate/ascend/ir_fission/addn_fission.cc
+10
-6
mindspore/ccsrc/utils/utils.h
mindspore/ccsrc/utils/utils.h
+1
-1
tests/ut/cpp/python_input/gtest_input/pre_activate/addn_fission_test.py
...ython_input/gtest_input/pre_activate/addn_fission_test.py
+3
-8
未找到文件。
mindspore/ccsrc/pre_activate/ascend/ir_fission/addn_fission.cc
浏览文件 @
11a4b35c
...
...
@@ -34,6 +34,8 @@ AnfNodePtr CreateNewAddn(const FuncGraphPtr &func_graph, const CNodePtr &origin_
new_addn
->
set_scope
(
origin_addn_cnode
->
scope
());
new_addn
->
set_abstract
(
origin_addn_cnode
->
abstract
());
AnfAlgo
::
SetNodeAttr
(
kAttrN
,
MakeValue
(
SizeToInt
(
offset
)),
new_addn
);
std
::
vector
<
int
>
dyn_input_sizes
{
SizeToInt
(
offset
)};
AnfAlgo
::
SetNodeAttr
(
kAttrDynInputSizes
,
MakeValue
(
dyn_input_sizes
),
new_addn
);
return
new_addn
;
}
}
// namespace
...
...
@@ -55,22 +57,24 @@ const AnfNodePtr AddnFission::Process(const FuncGraphPtr &func_graph, const AnfN
}
CNodePtr
new_cnode
=
cnode
;
while
(
origin_input_size
>
inputs_divisor_
)
{
MS_EXCEPTION_IF_NULL
(
new_cnode
);
std
::
vector
<
AnfNodePtr
>
base_addn_inputs
{
NewValueNode
(
std
::
make_shared
<
Primitive
>
(
prim
::
kPrimAddN
->
name
()))};
size_t
cur_input_index
=
1
;
// Divide the inputs of addn by
63
.
while
(
origin_input_size
-
cur_input_index
+
1
>
inputs_divisor_
)
{
// Divide the inputs of addn by
inputs_divisor_
.
while
(
origin_input_size
-
cur_input_index
+
1
>
=
inputs_divisor_
)
{
base_addn_inputs
.
push_back
(
CreateNewAddn
(
func_graph
,
new_cnode
,
cur_input_index
,
inputs_divisor_
));
cur_input_index
+=
inputs_divisor_
;
}
base_addn_inputs
.
push_back
(
CreateNewAddn
(
func_graph
,
new_cnode
,
cur_input_index
,
origin_input_size
-
cur_input_index
+
1
));
for
(
size_t
i
=
cur_input_index
;
i
<=
origin_input_size
;
i
++
)
{
base_addn_inputs
.
push_back
(
new_cnode
->
input
(
i
));
}
CNodePtr
base_addn
=
func_graph
->
NewCNode
(
base_addn_inputs
);
MS_EXCEPTION_IF_NULL
(
base_addn
);
MS_EXCEPTION_IF_NULL
(
new_cnode
);
base_addn
->
set_scope
(
new_cnode
->
scope
());
base_addn
->
set_abstract
(
new_cnode
->
abstract
());
AnfAlgo
::
SetNodeAttr
(
kAttrN
,
MakeValue
(
SizeToInt
(
base_addn_inputs
.
size
()
-
1
)),
base_addn
);
std
::
vector
<
int
>
dyn_input_sizes
{
SizeToInt
(
base_addn_inputs
.
size
()
-
1
)};
AnfAlgo
::
SetNodeAttr
(
kAttrDynInputSizes
,
MakeValue
(
dyn_input_sizes
),
base_addn
);
new_cnode
=
base_addn
;
origin_input_size
=
base_addn
->
inputs
().
size
()
-
1
;
}
...
...
mindspore/ccsrc/utils/utils.h
浏览文件 @
11a4b35c
...
...
@@ -149,7 +149,7 @@ constexpr auto kAttrDynInputSizes = "dyn_input_sizes";
constexpr
auto
kAttrSrcFormat
=
"src_format"
;
constexpr
auto
kAttrOutputUsedNum
=
"output_used_num"
;
constexpr
auto
kAttrHasBias
=
"has_bias"
;
constexpr
auto
kAttrN
=
"
N
"
;
constexpr
auto
kAttrN
=
"
n
"
;
constexpr
auto
kAttrLabelForInsertStreamActive
=
"label_for_insert_stream_active"
;
// attr value
...
...
tests/ut/cpp/python_input/gtest_input/pre_activate/addn_fission_test.py
浏览文件 @
11a4b35c
...
...
@@ -45,13 +45,10 @@ def test_addn_fission(tag):
b
=
addn
((
input2
,
input3
))
c
=
addn
((
input4
,
input5
))
d
=
addn
((
input6
,
input7
))
e
=
addn
((
input8
,))
f
=
addn
((
a
,
b
))
g
=
addn
((
c
,
d
))
h
=
addn
((
e
,))
i
=
addn
((
f
,
g
))
j
=
addn
((
h
,))
return
addn
((
i
,
j
))
return
addn
((
i
,
input8
))
@
fns
def
after_divided_by_3
(
input0
,
input1
,
input2
,
input3
,
input4
,
input5
,
input6
,
input7
,
input8
):
...
...
@@ -64,14 +61,12 @@ def test_addn_fission(tag):
def
after_divided_by_4
(
input0
,
input1
,
input2
,
input3
,
input4
,
input5
,
input6
,
input7
,
input8
):
a
=
addn
((
input0
,
input1
,
input2
,
input3
))
b
=
addn
((
input4
,
input5
,
input6
,
input7
))
c
=
addn
((
input8
,))
return
addn
((
a
,
b
,
c
))
return
addn
((
a
,
b
,
input8
))
@
fns
def
after_divided_by_8
(
input0
,
input1
,
input2
,
input3
,
input4
,
input5
,
input6
,
input7
,
input8
):
a
=
addn
((
input0
,
input1
,
input2
,
input3
,
input4
,
input5
,
input6
,
input7
))
b
=
addn
((
input8
,))
return
addn
((
a
,
b
))
return
addn
((
a
,
input8
))
@
fns
def
after_divided_by_9
(
input0
,
input1
,
input2
,
input3
,
input4
,
input5
,
input6
,
input7
,
input8
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录