Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
weixin_44025039
mindspore
提交
b812c1a1
M
mindspore
项目概览
weixin_44025039
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
b812c1a1
编写于
7月 23, 2020
作者:
B
buxue
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support call super when class define in test_case.
上级
684ff4f4
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
57 addition
and
22 deletion
+57
-22
mindspore/_extends/parse/parser.py
mindspore/_extends/parse/parser.py
+19
-19
mindspore/nn/cell.py
mindspore/nn/cell.py
+3
-0
tests/ut/python/pipeline/parse/test_super.py
tests/ut/python/pipeline/parse/test_super.py
+35
-3
未找到文件。
mindspore/_extends/parse/parser.py
浏览文件 @
b812c1a1
...
...
@@ -459,27 +459,27 @@ class Parser:
logger
.
debug
(
"ops info = %r"
,
ops_info
)
return
ops_info
def
analyze_super
(
self
,
father_class
_node
,
subclass_instance
):
def
analyze_super
(
self
,
class_type
_node
,
subclass_instance
):
"""Analyze super and return a class instance."""
father_class
=
None
if
father_class_node
is
None
:
father_class
=
type
(
subclass_instance
)
if
isinstance
(
father_class_node
,
ast
.
Name
):
father_class_name
=
getattr
(
father_class_node
,
'id'
)
father_class
=
self
.
global_namespace
[
father_class_name
]
if
isinstance
(
father_class_node
,
ast
.
Attribute
):
value
=
getattr
(
father_class_node
,
'value'
)
attr
=
getattr
(
father_class_node
,
'attr'
)
module_name
=
getattr
(
value
,
'id'
)
father_class_module
=
self
.
global_namespace
[
module_name
]
father_class
=
getattr
(
father_class_module
,
attr
)
if
father_class
is
None
:
raise
ValueError
(
"When call 'super', the father class is None."
)
if
not
isinstance
(
subclass_instance
,
father_class
):
sub_class
=
type
(
subclass_instance
)
if
class_type_node
is
None
:
return
super
(
sub_class
,
subclass_instance
)
if
isinstance
(
class_type_node
,
ast
.
Name
):
class_name
=
getattr
(
class_type_node
,
'id'
)
elif
isinstance
(
class_type_node
,
ast
.
Attribute
):
class_name
=
getattr
(
class_type_node
,
'attr'
)
else
:
raise
ValueError
(
f
"When call 'super', the first arg should be a class type, "
f
"but got
{
class_type_node
.
__class__
.
__name__
}
."
)
target_father_class
=
None
for
class_element
in
sub_class
.
mro
():
if
class_element
.
__name__
==
class_name
:
target_father_class
=
class_element
break
if
target_father_class
is
None
:
raise
ValueError
(
"When call 'super', the second arg should be an instance of first arg."
)
target_class_instance
=
super
(
father_class
,
subclass_instance
)
return
target_class_instance
return
super
(
target_father_class
,
subclass_instance
)
def
get_location
(
self
,
node
):
"""
...
...
mindspore/nn/cell.py
浏览文件 @
b812c1a1
...
...
@@ -58,6 +58,7 @@ class Cell:
>>> def construct(self, x):
>>> return self.relu(x)
"""
def
__init__
(
self
,
auto_prefix
=
True
,
flags
=
None
):
self
.
_params
=
OrderedDict
()
self
.
_cells
=
OrderedDict
()
...
...
@@ -888,6 +889,7 @@ class Cell:
for
param
in
params
:
param
.
set_param_ps
(
init_in_server
)
class
GraphKernel
(
Cell
):
"""
Base class for GraphKernel.
...
...
@@ -904,6 +906,7 @@ class GraphKernel(Cell):
>>> def construct(self, x):
>>> return self.max(P.Fill()(P.DType()(x), P.Shape()(x), 0.0), x)
"""
def
__init__
(
self
,
auto_prefix
=
True
,
pips
=
None
):
super
(
GraphKernel
,
self
).
__init__
(
auto_prefix
,
pips
)
class_name
=
self
.
__class__
.
__name__
...
...
tests/ut/python/pipeline/parse/test_super.py
浏览文件 @
b812c1a1
...
...
@@ -92,7 +92,7 @@ class Net(nn.Cell):
def
test_single_super
():
single_net
=
SingleSubNet
(
2
,
3
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
save_graphs
=
True
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
)
x
=
Tensor
(
np
.
ones
([
1
,
2
,
3
],
np
.
int32
))
y
=
Tensor
(
np
.
ones
([
1
,
2
,
3
],
np
.
int32
))
single_net
(
x
,
y
)
...
...
@@ -100,7 +100,7 @@ def test_single_super():
def
test_mul_super
():
mul_net
=
MulSubNet
(
2
,
3
,
4
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
save_graphs
=
True
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
)
x
=
Tensor
(
np
.
ones
([
1
,
2
,
3
],
np
.
int32
))
y
=
Tensor
(
np
.
ones
([
1
,
2
,
3
],
np
.
int32
))
mul_net
(
x
,
y
)
...
...
@@ -108,9 +108,41 @@ def test_mul_super():
def
test_super_cell
():
net
=
Net
(
2
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
save_graphs
=
True
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
)
x
=
Tensor
(
np
.
ones
([
1
,
2
,
3
],
np
.
int32
))
y
=
Tensor
(
np
.
ones
([
1
,
2
,
3
],
np
.
int32
))
with
pytest
.
raises
(
RuntimeError
)
as
er
:
net
(
x
,
y
)
assert
"Unsupported syntax 'Raise'"
in
str
(
er
.
value
)
def
test_single_super_in
():
class
FatherNetIn
(
nn
.
Cell
):
def
__init__
(
self
,
x
):
super
(
FatherNetIn
,
self
).
__init__
(
x
)
self
.
x
=
x
def
construct
(
self
,
x
,
y
):
return
self
.
x
*
x
def
test_father
(
self
,
x
):
return
self
.
x
+
x
class
SingleSubNetIN
(
FatherNetIn
):
def
__init__
(
self
,
x
,
z
):
super
(
SingleSubNetIN
,
self
).
__init__
(
x
)
self
.
z
=
z
def
construct
(
self
,
x
,
y
):
ret_father_construct
=
super
().
construct
(
x
,
y
)
ret_father_test
=
super
(
SingleSubNetIN
,
self
).
test_father
(
x
)
ret_father_x
=
super
(
SingleSubNetIN
,
self
).
x
ret_sub_z
=
self
.
z
return
ret_father_construct
,
ret_father_test
,
ret_father_x
,
ret_sub_z
single_net_in
=
SingleSubNetIN
(
2
,
3
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
save_graphs
=
True
)
x
=
Tensor
(
np
.
ones
([
1
,
2
,
3
],
np
.
int32
))
y
=
Tensor
(
np
.
ones
([
1
,
2
,
3
],
np
.
int32
))
single_net_in
(
x
,
y
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录