Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
abab21ed
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
abab21ed
编写于
8月 21, 2020
作者:
P
panyifeng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add func type check for switch layer
上级
492e41a4
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
30 addition
and
14 deletion
+30
-14
mindspore/ccsrc/frontend/operator/composite/composite.cc
mindspore/ccsrc/frontend/operator/composite/composite.cc
+6
-12
mindspore/core/abstract/prim_statement.cc
mindspore/core/abstract/prim_statement.cc
+2
-2
tests/ut/python/ops/test_control_ops.py
tests/ut/python/ops/test_control_ops.py
+22
-0
未找到文件。
mindspore/ccsrc/frontend/operator/composite/composite.cc
浏览文件 @
abab21ed
...
...
@@ -989,19 +989,13 @@ FuncGraphPtr TupleGetItemTensor::GenerateFuncGraph(const AbstractBasePtrList &ar
// args: tuple of items, index
const
std
::
string
op_name
=
std
::
string
(
"TupleGetItemTensor"
);
abstract
::
CheckArgsSize
(
op_name
,
args_spec_list
,
2
);
AbstractTuplePtr
branches_abs
=
abstract
::
CheckArg
<
AbstractTuple
>
(
op_name
,
args_spec_list
,
0
);
AbstractBasePtrList
branches
=
branches_abs
->
elements
();
if
(
branches
.
size
()
>
0
&&
branches
[
0
]
!=
nullptr
&&
branches
[
0
]
->
isa
<
AbstractFunction
>
())
{
FuncGraphPtr
ret_graph
=
std
::
make_shared
<
FuncGraph
>
();
ret_graph
->
set_flag
(
FUNC_GRAPH_FLAG_CORE
,
true
);
AnfNodePtr
functions
=
ret_graph
->
add_parameter
();
auto
index
=
ret_graph
->
add_parameter
();
auto
ret_graph
=
std
::
make_shared
<
FuncGraph
>
();
ret_graph
->
set_flag
(
FUNC_GRAPH_FLAG_CORE
,
true
);
auto
functions
=
ret_graph
->
add_parameter
();
auto
index
=
ret_graph
->
add_parameter
();
ret_graph
->
set_output
(
ret_graph
->
NewCNode
({
NewValueNode
(
prim
::
kPrimSwitchLayer
),
index
,
functions
}));
return
ret_graph
;
}
MS_LOG
(
EXCEPTION
)
<<
"TupleGetItemTensor does not support to index "
<<
branches_abs
->
ToString
()
<<
"."
;
ret_graph
->
set_output
(
ret_graph
->
NewCNode
({
NewValueNode
(
prim
::
kPrimSwitchLayer
),
index
,
functions
}));
return
ret_graph
;
}
REGISTER_PYBIND_DEFINE
(
TupleAdd_
,
([](
const
py
::
module
*
m
)
{
...
...
mindspore/core/abstract/prim_statement.cc
浏览文件 @
abab21ed
...
...
@@ -114,14 +114,14 @@ AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitiveP
AbstractTuplePtr
branches_abs
=
CheckArg
<
AbstractTuple
>
(
op_name
,
args_spec_list
,
1
);
AbstractBasePtrList
branches
=
branches_abs
->
elements
();
const
size_t
maximum_layer_num
=
1000
;
if
(
branches
.
size
()
<
0
||
branches
.
size
()
>
maximum_layer_num
)
{
if
(
branches
.
size
()
<
1
||
branches
.
size
()
>
maximum_layer_num
)
{
MS_EXCEPTION
(
ValueError
)
<<
op_name
<<
" support at least 1 and at most "
<<
maximum_layer_num
<<
" but got "
<<
branches
.
size
()
<<
" branches."
;
}
for
(
size_t
i
=
0
;
i
<
branches
.
size
();
i
++
)
{
MS_EXCEPTION_IF_NULL
(
branches
[
i
]);
if
(
!
branches
[
i
]
->
isa
<
AbstractFunction
>
())
{
if
(
!
branches
[
i
]
->
isa
<
FuncGraphAbstractClosure
>
())
{
MS_EXCEPTION
(
ValueError
)
<<
op_name
<<
" requires that the 2th arg be tuple of functions, but got "
<<
branches
[
i
]
->
ToString
()
<<
" as the "
<<
i
<<
"th element."
;
}
...
...
tests/ut/python/ops/test_control_ops.py
浏览文件 @
abab21ed
...
...
@@ -851,3 +851,25 @@ def test_tensor_all_construct_lack_branch():
net
=
NetConditionLackBranch
()
with
pytest
.
raises
(
Exception
):
net
(
input_tensor_1
,
input_tensor_2
)
def
test_parser_switch_layer_func_primitive
():
class
FinalNet
(
nn
.
Cell
):
def
__init__
(
self
,
funcs
):
super
().
__init__
()
self
.
funcs
=
funcs
def
construct
(
self
,
i
,
input1
):
x
=
self
.
funcs
[
i
](
input1
)
return
x
func1
=
P
.
ReLU
()
func2
=
P
.
Softmax
()
funcs
=
(
func1
,
func2
)
net
=
FinalNet
(
funcs
)
input1
=
Tensor
(
np
.
random
.
randn
(
2
,
3
,
4
,
5
).
astype
(
np
.
float32
))
i
=
Tensor
(
1
,
mstype
.
int32
)
with
pytest
.
raises
(
ValueError
):
net
(
i
,
input1
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录