Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
16b9004d
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
16b9004d
编写于
5月 11, 2020
作者:
F
fary86
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix bug of assign value to non Parameter class member
上级
ea4836e1
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
23 addition
and
2 deletion
+23
-2
mindspore/ccsrc/pipeline/parse/parse.cc
mindspore/ccsrc/pipeline/parse/parse.cc
+22
-1
tests/ut/python/pynative_mode/test_insert_grad_of.py
tests/ut/python/pynative_mode/test_insert_grad_of.py
+1
-1
未找到文件。
mindspore/ccsrc/pipeline/parse/parse.cc
浏览文件 @
16b9004d
...
...
@@ -1136,10 +1136,31 @@ void Parser::HandleAssignClassMember(const FunctionBlockPtr &block, const py::ob
AnfNodePtr
target_node
=
ParseExprNode
(
block
,
targ
);
MS_EXCEPTION_IF_NULL
(
target_node
);
std
::
string
attr_name
=
targ
.
attr
(
"attr"
).
cast
<
std
::
string
>
();
std
::
string
var_name
=
"self."
;
(
void
)
var_name
.
append
(
targ
.
attr
(
"attr"
).
cast
<
std
::
string
>
()
);
(
void
)
var_name
.
append
(
attr_name
);
MS_LOG
(
DEBUG
)
<<
"assign "
<<
var_name
;
// Get targ location info for error printing
py
::
list
location
=
ast_
->
CallParserObjMethod
(
PYTHON_PARSE_GET_LOCATION
,
targ
);
if
(
location
.
size
()
<
2
)
{
MS_LOG
(
EXCEPTION
)
<<
"List size should not be less than 2."
;
}
auto
filename
=
location
[
0
].
cast
<
std
::
string
>
();
auto
line_no
=
location
[
1
].
cast
<
int
>
();
// Now only support the self.xxx = yyy, where self.xxx must be a defined Parameter type
if
(
!
py
::
hasattr
(
ast
()
->
obj
(),
attr_name
.
c_str
()))
{
MS_EXCEPTION
(
TypeError
)
<<
"'"
<<
var_name
<<
"' should be a Parameter, but not defined, at "
<<
filename
<<
":"
<<
line_no
;
}
auto
obj
=
ast
()
->
obj
().
attr
(
attr_name
.
c_str
());
auto
obj_type
=
obj
.
attr
(
"__class__"
).
attr
(
"__name__"
);
if
(
!
py
::
hasattr
(
obj
,
"__parameter__"
))
{
MS_EXCEPTION
(
TypeError
)
<<
"'"
<<
var_name
<<
"' should be a Parameter, but got '"
<<
py
::
str
(
obj
).
cast
<
std
::
string
>
()
<<
"' with type '"
<<
py
::
str
(
obj_type
).
cast
<
std
::
string
>
()
<<
"' at "
<<
filename
<<
":"
<<
line_no
;
}
MS_EXCEPTION_IF_NULL
(
block
);
block
->
WriteVariable
(
var_name
,
assigned_node
);
MS_LOG
(
DEBUG
)
<<
"SetState write "
<<
var_name
<<
" : "
<<
target_node
->
ToString
();
...
...
tests/ut/python/pynative_mode/test_insert_grad_of.py
浏览文件 @
16b9004d
...
...
@@ -124,9 +124,9 @@ def test_cell_assign():
class
Mul
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
Mul
,
self
).
__init__
()
self
.
get_g
=
P
.
InsertGradientOf
(
self
.
save_gradient
)
self
.
matrix_w
=
mindspore
.
Parameter
(
Tensor
(
np
.
ones
([
2
,
2
],
np
.
float32
)),
name
=
"matrix_w"
)
self
.
matrix_g
=
mindspore
.
Parameter
(
Tensor
(
np
.
ones
([
2
,
2
],
np
.
float32
)),
name
=
"matrix_g"
)
self
.
get_g
=
P
.
InsertGradientOf
(
self
.
save_gradient
)
def
save_gradient
(
self
,
dout
):
self
.
matrix_g
=
dout
+
self
.
matrix_g
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录