Commit 53f64c1f
magicwindyyd/mindspore (forked from MindSpore/mindspore)
Authored Aug 21, 2020 by mindspore-ci-bot; committed by Gitee on Aug 21, 2020
!4730 fix switch layer issues
Merge pull request !4730 from riemann_penn/fix_grad_operation_api
Parents: de0a60df, f9f3cd7c
Showing 6 changed files with 126 additions and 8 deletions (+126 -8)
mindspore/ccsrc/frontend/optimizer/irpass/env_item_eliminate.h    +2 -1
mindspore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h    +3 -2
mindspore/ccsrc/frontend/optimizer/irpass/switch_layer_defer_inline.h    +2 -3
mindspore/ccsrc/pipeline/jit/pass.cc    +1 -0
mindspore/core/abstract/prim_statement.cc    +9 -2
tests/ut/python/ops/test_control_ops.py    +109 -0
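For context before the per-file diffs: the "switch layer" being fixed is the dispatch produced when a tuple of Cells is indexed by a Tensor index inside construct(). A minimal sketch of that pattern, loosely adapted from the new tests in this commit (illustrative only, not part of the change):

import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor

class SwitchLayerNet(nn.Cell):
    """Sketch only: two branch Cells selected by a Tensor index."""
    def __init__(self):
        super(SwitchLayerNet, self).__init__()
        self.funcs = (nn.ReLU(), nn.Tanh())   # tuple of branch networks

    def construct(self, x, index):
        # Indexing the tuple with a Tensor index is compiled into switch_layer.
        return self.funcs[index](x)

net = SwitchLayerNet()
x = Tensor(np.ones((2, 2)).astype(np.float32))
out = net(x, Tensor(1, ms.int32))   # selects the second branch at run time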
mindspore/ccsrc/frontend/optimizer/irpass/env_item_eliminate.h
@@ -462,7 +462,8 @@ class IncorporateEnvGetitemSwitchLayer : public AnfVisitor {
     std::vector<FuncGraphPtr> graphs{};
     auto graphs_cnode = sw->input(2)->cast<CNodePtr>();
     auto &graphs_inputs = graphs_cnode->inputs();
-    if (IsPrimitiveCNode(graphs_cnode, prim::kPrimMakeTuple) && IsValueNode<FuncGraph>(graphs_inputs[1])) {
+    if (IsPrimitiveCNode(graphs_cnode, prim::kPrimMakeTuple) && graphs_inputs.size() >= 2 &&
+        IsValueNode<FuncGraph>(graphs_inputs[1])) {
       (void)std::transform(graphs_inputs.begin() + 1, graphs_inputs.end(), std::back_inserter(graphs),
                            [](const AnfNodePtr &vnode) { return GetValueNode<FuncGraphPtr>(vnode); });
     }
mindspore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h
@@ -89,6 +89,7 @@ class GetItemTransformACrossGraph {
     ss << idx;
     auto new_fg_outer = TransformableClone(fg, std::make_shared<TraceTransform>(ss.str()));
+    fg->manager()->AddFuncGraph(new_fg_outer);
     auto output_outer = new_fg_outer->output();
     if (!IsValueNode<FuncGraph>(output_outer)) {
       MS_LOG(WARNING) << "Output of outer graph should be a func_graph";

@@ -486,7 +487,7 @@ class IncorporateGetitemSwitchLayerA : public AnfVisitor {
       switch_layer_ = inputs[0];
       (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(args_));
     }
-    if (is_in_switch_ && cnode->size() > 2) {
+    if (is_in_switch_ && cnode->size() >= 2) {
       auto &inputs = cnode->inputs();
       if (IsPrimitiveCNode(cnode, prim::kPrimMakeTuple) && IsValueNode<FuncGraph>(inputs[1])) {
         (void)std::transform(inputs.begin() + 1, inputs.end(), std::back_inserter(graphs_),

@@ -578,7 +579,7 @@ class IncorporateGetitemSwitchLayerB : public AnfVisitor {
       switch_layer_call_ = inputs[0];
       (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(outer_call_args_));
     }
-    if (is_in_switch_ && cnode->size() > 2) {
+    if (is_in_switch_ && cnode->size() >= 2) {
       auto &inputs = cnode->inputs();
       if (IsPrimitiveCNode(cnode, prim::kPrimMakeTuple) && IsValueNode<FuncGraph>(inputs[1])) {
         (void)std::transform(inputs.begin() + 1, inputs.end(), std::back_inserter(graphs_),
mindspore/ccsrc/frontend/optimizer/irpass/switch_layer_defer_inline.h
@@ -36,10 +36,9 @@ class SwitchLayerDeferInline : public AnfVisitor {
     auto tuple = dyn_cast<abstract::AbstractTuple>(cnode->inputs()[2]->abstract());
     for (auto elem : tuple->elements()) {
       auto abstract = dyn_cast<abstract::FuncGraphAbstractClosure>(elem);
-      if (abstract == nullptr) {
-        return nullptr;
+      if (abstract != nullptr) {
+        *(abstract->func_graph()->switch_layer_input()) = true;
       }
-      *(abstract->func_graph()->switch_layer_input()) = true;
     }
     return nullptr;
   }
mindspore/ccsrc/pipeline/jit/pass.cc
@@ -137,6 +137,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
     irpass.arithmetic_simplify2_,
     irpass.same_eliminate_,
     irpass.check_bprop_eliminate_,
+    irpass.switch_layer_defer_inline_,
     irpass.replace_applicator_,
   });
   opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_});
mindspore/core/abstract/prim_statement.cc
@@ -16,6 +16,7 @@
 #include "abstract/param_validator.h"
 #include "abstract/infer_functions.h"
 #include "abstract/abstract_function.h"
+#include "abstract/utils.h"
 #include "utils/symbolic.h"

@@ -121,12 +122,18 @@ AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitiveP
   for (size_t i = 0; i < branches.size(); i++) {
     MS_EXCEPTION_IF_NULL(branches[i]);
     if (!branches[i]->isa<AbstractFunction>()) {
-      MS_LOG(EXCEPTION) << op_name << " requires that the 2th arg be tuple of functions, but got "
-                        << branches[i]->ToString() << " as the " << i << "th element.";
+      MS_EXCEPTION(ValueError) << op_name << " requires that the 2th arg be tuple of functions, but got "
+                               << branches[i]->ToString() << " as the " << i << "th element.";
     }
   }
 
   auto b = branches[0];
+  // Return AbstractFuncUnion, otherwise the switch_layer will be replaced by branches[0]
+  // which will cancel the out of bound checking for index
+  if (branches.size() == 1) {
+    AbstractFuncAtomPtrList func_list{b->cast<AbstractFuncAtomPtr>()};
+    return std::make_shared<AbstractFuncUnion>(func_list);
+  }
   for (size_t i = 1; i < branches.size(); i++) {
     b = b->Join(branches[i]);
   }
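The comment added in InferImplSwitchLayer above carries the key reasoning: with only one branch, inferring the result as branches[0] itself would let later passes replace switch_layer with that branch directly and drop the out-of-bound check on the index, so the single branch is wrapped in an AbstractFuncUnion to keep the dispatch alive. The new test_switch_layer_single_layer below exercises this case; a condensed, illustrative sketch of the single-branch pattern (not part of the diff):

import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor

class SingleBranchNet(nn.Cell):
    """Sketch only: a one-element tuple still dispatches through switch_layer."""
    def __init__(self):
        super(SingleBranchNet, self).__init__()
        self.conv = nn.Conv2d(1, 1, 3, pad_mode='same')
        self.funs = (self.conv,)

    def construct(self, x, index):
        # With the fix, index is still validated against the tuple length
        # rather than the call collapsing straight into self.conv.
        return self.funs[index](x)

net = SingleBranchNet()
x = Tensor(np.ones((3, 1, 12, 12)), ms.float32)
out = net(x, Tensor(0, ms.int32))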
tests/ut/python/ops/test_control_ops.py
@@ -444,6 +444,86 @@ def test_index_to_switch_layer():
     C.grad_all(net)(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
 
 
+def test_parser_switch_layer_switch_in_bprop():
+    class OneInputBprop(nn.Cell):
+        def __init__(self, funcs):
+            super(OneInputBprop, self).__init__()
+            self.op = P.ReLU()
+            self.funcs = funcs
+
+        def construct(self, i, x):
+            return self.op(x)
+
+        def bprop(self, i, x, out, dout):
+            return i, self.funcs[i](x, dout)
+
+    class Add(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.op = P.TensorAdd()
+
+        def construct(self, x, y):
+            return self.op(x, y)
+
+    class Mul(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.op = P.Mul()
+
+        def construct(self, x, y):
+            return self.op(x, y)
+
+    func1 = Add()
+    func2 = Mul()
+    funcs = (func1, func2)
+    net = OneInputBprop(funcs)
+    input1 = Tensor(np.ones([2, 2]).astype(np.float32))
+    grad = Tensor(np.random.randn(2, 2).astype(np.float32))
+    i = Tensor(1, mstype.int32)
+    grad_net = C.grad_all_with_sens(net)
+    grad_net(i, input1, grad)
+
+
+def test_parser_switch_layer_inputs_tuple():
+    class TwoInputTupleFinalNet(nn.Cell):
+        def __init__(self, funcs):
+            super().__init__()
+            self.funcs = funcs
+
+        def construct(self, i, inputa, inputb):
+            inputs = (inputa, inputb)
+            x = self.funcs[i](inputs)
+            return x
+
+    class Add(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.op = P.TensorAdd()
+
+        def construct(self, x):
+            y = self.op(x[0], x[1])
+            return self.op(x[0], y)
+
+    class Mul(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.op = P.Mul()
+
+        def construct(self, x):
+            y = self.op(x[0], x[1])
+            return self.op(x[0], y)
+
+    func1 = Add()
+    func2 = Mul()
+    funcs = (func1, func2)
+    net = TwoInputTupleFinalNet(funcs)
+
+    input1 = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
+    input2 = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
+    i = Tensor(1, mstype.int32)
+    grad = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
+    back_net = C.grad_all_with_sens(net)
+    back_out = back_net(i, input1, input2, grad)
+
+
 def test_switch_layer_with_single_prim():
     class SwitchLayerCell(nn.Cell):
         def __init__(self):

@@ -494,6 +574,35 @@ def test_switch_layer_env_eliminate():
     net2(x, i)
 
 
+def test_switch_layer_single_layer():
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+            self.conv = nn.Conv2d(1, 1, 3, pad_mode='same')
+            self.funs = (self.conv,)
+
+        def construct(self, x, index):
+            x = self.funs[index](x)
+            return x
+
+    class NetGrad(nn.Cell):
+        def __init__(self, net):
+            super(NetGrad, self).__init__()
+            self.grad_op = C.GradOperation('grad', get_by_list=True, sens_param=False)
+            self.net = net
+            self.weights = ParameterTuple(self.net.trainable_params())
+
+        def construct(self, x, index):
+            weights = self.weights
+            grad = self.grad_op(self.net, weights)(x, index)
+            return grad
+
+    net = Net()
+    net2 = NetGrad(net)
+    x = Tensor(np.ones((3, 1, 12, 12)), ms.float32)
+    i = Tensor(1, ms.int32)
+    net2(x, i)
+
+
 def test_control_depend_check():
     with pytest.raises(TypeError) as e:
         P.ControlDepend(0.0)