Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
28dabf03
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
28dabf03
编写于
8月 05, 2020
作者:
K
kingfo
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix grad flag update issue in pynative
上级
57fd31b2
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
21 addition
and
11 deletion
+21
-11
mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc
...re/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc
+6
-0
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
+1
-0
mindspore/nn/cell.py
mindspore/nn/cell.py
+7
-7
mindspore/ops/composite/base.py
mindspore/ops/composite/base.py
+3
-3
tests/ut/cpp/optimizer/lib_test.cc
tests/ut/cpp/optimizer/lib_test.cc
+3
-0
tests/ut/python/pynative_mode/test_hook.py
tests/ut/python/pynative_mode/test_hook.py
+1
-1
未找到文件。
mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc
浏览文件 @
28dabf03
...
...
@@ -20,6 +20,9 @@ namespace mindspore {
namespace
opt
{
namespace
irpass
{
AnfNodePtr
ArithmeticSimplify
::
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
{
if
(
MsContext
::
GetInstance
()
->
execution_mode
()
==
kPynativeMode
)
{
return
nullptr
;
}
PatternNode
x
,
y
,
z
,
xs
;
PConstant
one_
(
node
,
false
,
1
);
PConstant
one_scalar_
(
node
,
false
,
1
,
true
);
...
...
@@ -68,6 +71,9 @@ AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr
}
AnfNodePtr
ArithmeticSimplify2
::
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
{
if
(
MsContext
::
GetInstance
()
->
execution_mode
()
==
kPynativeMode
)
{
return
nullptr
;
}
PatternNode
x
,
y
;
PConstant
zero_
(
node
,
false
,
0
);
...
...
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
浏览文件 @
28dabf03
...
...
@@ -1223,6 +1223,7 @@ void PynativeExecutor::Clear(const std::string &flag) {
}
MS_LOG
(
DEBUG
)
<<
"Clear"
;
grad_flag_
=
false
;
top_g_
=
nullptr
;
df_builder_
=
nullptr
;
curr_g_
=
nullptr
;
...
...
mindspore/nn/cell.py
浏览文件 @
28dabf03
...
...
@@ -84,16 +84,16 @@ class Cell:
self
.
_backward_hook
=
None
self
.
enable_hook
=
False
self
.
_bprop_debug
=
False
self
.
_
is
_run
=
False
self
.
_
already
_run
=
False
self
.
cell_type
=
None
@
property
def
is
_run
(
self
):
return
self
.
_
is
_run
def
already
_run
(
self
):
return
self
.
_
already
_run
@
is
_run
.
setter
def
is
_run
(
self
,
value
):
self
.
_
is
_run
=
value
@
already
_run
.
setter
def
already
_run
(
self
,
value
):
self
.
_
already
_run
=
value
@
property
def
create_time
(
self
):
...
...
@@ -260,7 +260,7 @@ class Cell:
_pynative_exec
.
end_graph
(
self
,
output
,
*
inputs
)
for
i
,
cell
in
enumerate
(
self
.
cells
()):
cell
.
set_grad
(
orign_grad
[
i
])
self
.
_
is
_run
=
True
self
.
_
already
_run
=
True
return
output
def
__setattr__
(
self
,
name
,
value
):
...
...
mindspore/ops/composite/base.py
浏览文件 @
28dabf03
...
...
@@ -129,14 +129,14 @@ class GradOperation(GradOperation_):
output
=
fn
(
*
args
)
_pynative_exec
.
end_graph
(
fn
,
output
,
*
args
)
else
:
if
fn
.
is
_run
and
not
fn
.
requires_grad
:
if
fn
.
already
_run
and
not
fn
.
requires_grad
:
raise
ValueError
(
"obj must set_grad."
)
if
not
fn
.
is
_run
:
if
not
fn
.
already
_run
:
self
.
need_forward
=
True
print
(
"already has forward run before grad by user"
)
if
self
.
need_forward
:
fn
.
set_grad
()
fn
(
*
args
)
fn
.
already_run
=
False
def
__call__
(
self
,
fn
,
weights
=
None
):
grad_
=
GradOperation
(
'grad'
,
self
.
get_all
,
self
.
get_by_list
,
self
.
sens_param
)
...
...
tests/ut/cpp/optimizer/lib_test.cc
浏览文件 @
28dabf03
...
...
@@ -40,6 +40,9 @@ class TestOptLib : public UT::Common {
void
SetUp
()
{
UT
::
InitPythonPath
();
parse
::
data_converter
::
ClearObjectCache
();
auto
ms_context
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
ms_context
);
ms_context
->
set_execution_mode
(
kGraphMode
);
}
FuncGraphPtr
RunTransform
(
FuncGraphPtr
gbefore
,
const
SubstitutionList
&
transform
)
{
equiv_node
.
clear
();
...
...
tests/ut/python/pynative_mode/test_hook.py
浏览文件 @
28dabf03
...
...
@@ -152,7 +152,7 @@ def test_hook():
assert
cell_hook_done
assert
var_hook_done
assert
cell_bprop_done
print
(
loss_output
.
asnumpy
()
.
shape
)
print
(
loss_output
.
asnumpy
())
bprop_debug
=
False
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录