Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
7a367af9
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
7a367af9
编写于
4月 06, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
4月 06, 2020
浏览文件
操作
浏览文件
下载
差异文件
!135 fix grad missing due to indirect dependent free morphism
Merge pull request !135 from penn/fix_free_morphism_error
上级
32017f6d
1fb776fe
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
54 addition
and
17 deletion
+54
-17
mindspore/ccsrc/optimizer/ad/dfunctor.cc
mindspore/ccsrc/optimizer/ad/dfunctor.cc
+29
-14
mindspore/ccsrc/optimizer/ad/dfunctor.h
mindspore/ccsrc/optimizer/ad/dfunctor.h
+1
-0
mindspore/ccsrc/pipeline/pass.cc
mindspore/ccsrc/pipeline/pass.cc
+1
-1
tests/ut/python/pynative_mode/test_cell_bprop.py
tests/ut/python/pynative_mode/test_cell_bprop.py
+1
-2
tests/ut/python/pynative_mode/test_framstruct.py
tests/ut/python/pynative_mode/test_framstruct.py
+22
-0
未找到文件。
mindspore/ccsrc/optimizer/ad/dfunctor.cc
浏览文件 @
7a367af9
...
...
@@ -185,19 +185,32 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
return
node_adjoint
;
}
bool
DFunctor
::
IsFreeMorphism
(
const
AnfNodePtr
&
node
)
{
// Do not care about non-CNode
if
(
!
node
->
isa
<
CNode
>
())
{
return
false
;
}
// Do not care about kPrimReturn
if
(
IsPrimitiveCNode
(
node
,
prim
::
kPrimReturn
))
{
return
false
;
}
auto
&
users
=
primal_graph_
->
manager
()
->
node_users
()[
node
];
// Do not care about isolated morphisms
if
(
users
.
empty
())
{
return
false
;
}
// Not free if it's used by some node in primal_graph
bool
nonfree
=
std
::
any_of
(
std
::
begin
(
users
),
std
::
end
(
users
),
[
&
](
const
auto
&
kv
)
{
auto
&
user
=
kv
.
first
;
return
user
->
func_graph
()
==
primal_graph_
;
});
return
!
nonfree
;
}
void
DFunctor
::
MapFreeMorphism
()
{
// Handle cnode not attached to output, that might be refered in other functions.
for
(
auto
&
node
:
primal_graph_
->
nodes
())
{
auto
adjoint
=
FindAdjoint
(
node
);
if
(
adjoint
!=
nullptr
)
{
continue
;
}
if
(
!
node
->
isa
<
CNode
>
())
{
MS_LOG
(
DEBUG
)
<<
"MapFreeMorphism noncnode not mapped after MapMorphism "
<<
node
->
ToString
()
<<
" "
<<
node
->
type_name
()
<<
"."
;
continue
;
}
if
(
IsPrimitiveCNode
(
node
,
prim
::
kPrimReturn
))
{
if
(
!
IsFreeMorphism
(
node
))
{
continue
;
}
MS_LOG
(
DEBUG
)
<<
"MapFreeMorphism map nonoutput cnode after MapMorphism "
<<
node
->
ToString
()
<<
"."
;
...
...
@@ -256,9 +269,10 @@ void DFunctor::MapMorphism() {
// Set stop_gradient before MapMorphism.
BroadCastStopFlag
();
// Handle free morphism before output, because in some case, free morphism might depend on output's fv tangent
MapFreeMorphism
();
// Handle morphism from output.
(
void
)
MapMorphism
(
primal_graph_
->
output
());
MapFreeMorphism
();
// Construct K for primal_graph_
auto
output_adjoint
=
anfnode_to_adjoin_
.
find
(
primal_graph_
->
output
());
...
...
@@ -298,9 +312,10 @@ FuncGraphPtr DFunctor::KUserDefined(const FuncGraphPtr &primal) {
const
size_t
param_diff
=
1
;
if
(
bprop_graph
->
output
()
->
isa
<
CNode
>
()
&&
bprop_graph
->
output
()
->
cast
<
CNodePtr
>
()
->
size
()
+
param_diff
!=
bprop_graph
->
parameters
().
size
())
{
MS_LOG
(
EXCEPTION
)
<<
"User defined Cell bprop "
<<
primal
->
ToString
()
<<
" in scope "
<<
primal
->
output
()
->
scope
()
->
name
()
<<
" output must be a tuple and output number should be the same with inputs."
;
// It does not matter with the final tangents, just a tip for debugging
MS_LOG
(
DEBUG
)
<<
"User defined Cell bprop "
<<
primal
->
ToString
()
<<
" in scope "
<<
primal
->
output
()
->
scope
()
->
name
()
<<
" output must be a tuple and output number should be the same with inputs."
;
}
resources_
->
manager
()
->
AddFuncGraph
(
bprop_graph
);
...
...
mindspore/ccsrc/optimizer/ad/dfunctor.h
浏览文件 @
7a367af9
...
...
@@ -61,6 +61,7 @@ class DFunctor {
private:
// Map one morphism.
AdjointPtr
MapMorphism
(
const
AnfNodePtr
&
morph
);
bool
IsFreeMorphism
(
const
AnfNodePtr
&
node
);
// Map morphism that's not attached to output.
void
MapFreeMorphism
();
void
BackPropagateFv
(
const
AnfNodePtr
&
fv
,
const
AnfNodePtr
&
din
);
...
...
mindspore/ccsrc/pipeline/pass.cc
浏览文件 @
7a367af9
...
...
@@ -111,7 +111,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib& irpass) {
irpass
.
replace_applicator_
,
});
opt
::
OptPassConfig
virtual_dataset
=
opt
::
OptPassConfig
({
irpass
.
virtual_dataset_eliminate_
});
opt
::
OptPassConfig
grad
=
opt
::
OptPassConfig
({
irpass
.
inline_
,
irpass
.
expand_jprim_
},
true
);
opt
::
OptPassConfig
grad
=
opt
::
OptPassConfig
({
irpass
.
expand_jprim_
},
true
);
OptPassGroupMap
map_a
({{
"a_1"
,
a_1
},
{
"a_2"
,
a_2
},
...
...
tests/ut/python/pynative_mode/test_cell_bprop.py
浏览文件 @
7a367af9
...
...
@@ -304,5 +304,4 @@ class MulAddWithWrongOutputNum(nn.Cell):
def
test_grad_mul_add_with_wrong_output_num
():
mul_add
=
MulAddWithWrongOutputNum
()
with
pytest
.
raises
(
RuntimeError
):
C
.
grad_all
(
mul_add
)(
1
,
2
)
C
.
grad_all
(
mul_add
)(
1
,
2
)
tests/ut/python/pynative_mode/test_framstruct.py
浏览文件 @
7a367af9
...
...
@@ -15,6 +15,7 @@
""" test_framstruct """
import
pytest
import
numpy
as
np
import
mindspore
as
ms
import
mindspore.nn
as
nn
from
mindspore
import
context
from
mindspore.ops
import
composite
as
C
...
...
@@ -706,3 +707,24 @@ def grad_refactor_14(a, b):
return
inner1
(
b
)
+
inner2
(
a
)
+
inner3
(
a
)
def
test_grad_refactor_14
():
assert
C
.
grad_all
(
grad_refactor_14
)(
2
,
3
)
==
(
3
,
9
)
class
IfDeferInline
(
nn
.
Cell
):
def
__init__
(
self
,
mul_size
):
super
().
__init__
()
self
.
mul_weight
=
Tensor
(
np
.
full
(
mul_size
,
0.6
,
dtype
=
np
.
float32
))
self
.
mul
=
P
.
Mul
()
def
construct
(
self
,
inputs
):
x
=
self
.
mul
(
inputs
,
self
.
mul_weight
)
if
True
:
x
=
x
return
x
def
test_grad_if_defer_inline
():
""" test_grad_if_defer_inline """
network
=
IfDeferInline
([
128
,
96
])
network
.
add_flags
(
defer_inline
=
False
)
inp
=
Tensor
(
np
.
ones
([
128
,
96
]).
astype
(
np
.
float32
))
grads
=
C
.
grad_all
(
network
)(
inp
)
assert
grads
==
(
Tensor
(
np
.
full
([
128
,
96
],
0.6
,
dtype
=
np
.
float32
)),)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录