Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
1a4abefa
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
1a4abefa
编写于
5月 28, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
5月 28, 2020
浏览文件
操作
浏览文件
下载
差异文件
!1385 support for multi nest switch
Merge pull request !1385 from amongo/SupportMultiSwitch
上级
2a6a3e01
334d0380
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
214 addition
and
34 deletion
+214
-34
mindspore/ccsrc/optimizer/irpass/branch_culling.cc
mindspore/ccsrc/optimizer/irpass/branch_culling.cc
+33
-23
mindspore/ccsrc/optimizer/irpass/branch_culling.h
mindspore/ccsrc/optimizer/irpass/branch_culling.h
+18
-3
mindspore/ccsrc/pipeline/pass.cc
mindspore/ccsrc/pipeline/pass.cc
+1
-1
mindspore/ccsrc/pipeline/static_analysis/prim.cc
mindspore/ccsrc/pipeline/static_analysis/prim.cc
+1
-1
tests/ut/python/ops/test_control_ops.py
tests/ut/python/ops/test_control_ops.py
+80
-6
tests/ut/python/ops/test_layer_switch.py
tests/ut/python/ops/test_layer_switch.py
+81
-0
未找到文件。
mindspore/ccsrc/optimizer/irpass/branch_culling.cc
浏览文件 @
1a4abefa
...
...
@@ -52,13 +52,17 @@ bool InConvertWhiteList(const AnfNodePtr &node, size_t index) {
// Example : when convert CNode(kPrimReduceSum, x, axis), node of index 2 in CNode->inputs is axis which should not be
// converted to switch guarded.
std
::
vector
<
std
::
pair
<
PrimitivePtr
,
std
::
vector
<
size_t
>>>
white_list
(
{{
prim
::
kPrimApplyMomentum
,
{
1
,
2
}},
{
prim
::
kPrimMomentum
,
{
2
,
3
}},
{
prim
::
kPrimStateSetItem
,
{
1
}},
{
prim
::
kPrimEnvGetItem
,
{
1
}},
{
prim
::
kPrimEnvSetItem
,
{
1
}},
{
prim
::
kPrimReduceSum
,
{
2
}},
{
prim
::
kPrimReduceMean
,
{
2
}},
{
prim
::
kPrimReduceAll
,
{
2
}},
{
prim
::
kPrimCast
,
{
2
}},
{
prim
::
kPrimTranspose
,
{
2
}},
{
prim
::
kPrimOneHot
,
{
2
}},
{
prim
::
kPrimGatherV2
,
{
3
}},
{
prim
::
kPrimReshape
,
{
2
}},
{
prim
::
kPrimAssign
,
{
1
}},
{
prim
::
kPrimAssignAdd
,
{
1
}},
{
prim
::
kPrimAssignSub
,
{
1
}},
{
prim
::
kPrimTensorSummary
,
{
1
}},
{
prim
::
kPrimImageSummary
,
{
1
}},
{
prim
::
kPrimScalarSummary
,
{
1
}},
{
prim
::
kPrimHistogramSummary
,
{
1
}}});
{{
prim
::
kPrimApplyMomentum
,
{
1
,
2
}},
{
prim
::
kPrimMomentum
,
{
2
,
3
}},
{
prim
::
kPrimStateSetItem
,
{
1
}},
{
prim
::
kPrimTupleGetItem
,
{
2
}},
{
prim
::
kPrimEnvGetItem
,
{
1
}},
{
prim
::
kPrimEnvSetItem
,
{
1
}},
{
prim
::
kPrimReduceSum
,
{
2
}},
{
prim
::
kPrimReduceMean
,
{
2
}},
{
prim
::
kPrimReduceAll
,
{
2
}},
{
prim
::
kPrimCast
,
{
2
}},
{
prim
::
kPrimTranspose
,
{
2
}},
{
prim
::
kPrimOneHot
,
{
2
}},
{
prim
::
kPrimGatherV2
,
{
3
}},
{
prim
::
kPrimReshape
,
{
2
}},
{
prim
::
kPrimAssign
,
{
1
}},
{
prim
::
kPrimAssignAdd
,
{
1
}},
{
prim
::
kPrimAssignSub
,
{
1
}},
{
prim
::
kPrimTensorSummary
,
{
1
}},
{
prim
::
kPrimImageSummary
,
{
1
}},
{
prim
::
kPrimScalarSummary
,
{
1
}},
{
prim
::
kPrimHistogramSummary
,
{
1
}}});
for
(
auto
&
item
:
white_list
)
{
auto
matched
=
std
::
any_of
(
item
.
second
.
begin
(),
item
.
second
.
end
(),
[
&
item
,
&
node
,
&
index
](
size_t
idx
)
{
return
IsPrimitiveCNode
(
node
,
item
.
first
)
&&
idx
==
index
;
...
...
@@ -80,7 +84,8 @@ bool InConvertWhiteList(const AnfNodePtr &node, size_t index) {
using
NodeInputReplMap
=
std
::
unordered_map
<
std
::
pair
<
AnfNodePtr
,
size_t
>
,
AnfNodePtr
,
PairHasher
>
;
// replace the nodes which should be changed
void
RunSwitchNodeReplace
(
const
FuncGraphManagerPtr
&
manager
,
std
::
vector
<
std
::
pair
<
CNodePtr
,
CNodePtr
>>
nodes_changed
,
std
::
unordered_map
<
AnfNodePtr
,
AnfNodePtr
>
repl_node
,
NodeInputReplMap
repl_node_inputs
)
{
std
::
unordered_map
<
AnfNodePtr
,
AnfNodePtr
>
repl_node
,
NodeInputReplMap
repl_node_inputs
,
const
FuncGraphPtr
&
func_graph
)
{
for
(
auto
&
node_pair
:
nodes_changed
)
{
CNodePtr
old_node
=
node_pair
.
first
;
CNodePtr
new_node
=
node_pair
.
second
;
...
...
@@ -99,9 +104,11 @@ void RunSwitchNodeReplace(const FuncGraphManagerPtr &manager, std::vector<std::p
}
for
(
auto
&
item
:
repl_node
)
{
if
(
!
manager
->
Replace
(
item
.
first
,
item
.
second
))
{
MS_LOG
(
EXCEPTION
)
<<
"TransformGraphDependNode replace node failed original:"
<<
item
.
first
->
DebugString
()
<<
" to new: "
<<
item
.
second
->
DebugString
();
if
(
IsPrimitiveCNode
(
item
.
second
,
prim
::
kPrimReturn
))
{
func_graph
->
set_output
(
item
.
second
->
cast
<
CNodePtr
>
()
->
input
(
1
));
}
else
if
(
!
manager
->
Replace
(
item
.
first
,
item
.
second
))
{
MS_LOG
(
EXCEPTION
)
<<
"TransformGraphDependNode replace node failed original:"
<<
item
.
first
->
DebugString
(
2
)
<<
" to new: "
<<
item
.
second
->
DebugString
(
2
);
}
}
}
...
...
@@ -154,7 +161,7 @@ FuncGraphPtr TransformGraphCondBranchNodes(
nodes_changed
.
emplace_back
(
node
->
cast
<
CNodePtr
>
(),
new_node
);
}
}
RunSwitchNodeReplace
(
manager
,
nodes_changed
,
repl_node
,
repl_node_inputs
);
RunSwitchNodeReplace
(
manager
,
nodes_changed
,
repl_node
,
repl_node_inputs
,
graph
);
return
graph
;
}
...
...
@@ -508,11 +515,12 @@ bool GraphOutputCompatible(const AbstractBasePtr &true_branch_abs, const Abstrac
AnfNodePtr
GenerateMergeNodes
(
const
AnfNodePtr
&
true_output_node
,
const
AnfNodePtr
&
false_output_node
,
const
AbstractBasePtr
&
true_graph_output_abs
,
const
AbstractBasePtr
&
false_graph_output_abs
,
const
AnfNodePtr
&
cond
)
{
const
AbstractBasePtr
&
false_graph_output_abs
,
const
FuncGraphPtr
&
switch_graph
,
const
AnfNodePtr
&
cond
)
{
MS_EXCEPTION_IF_NULL
(
true_graph_output_abs
);
MS_EXCEPTION_IF_NULL
(
false_graph_output_abs
);
MS_EXCEPTION_IF_NULL
(
cond
);
MS_EXCEPTION_IF_NULL
(
cond
->
func_graph
()
);
MS_EXCEPTION_IF_NULL
(
switch_graph
);
auto
PrimMerge
=
prim
::
GetPythonOps
(
"merge"
,
"mindspore.ops.functional"
)
->
cast
<
PrimitivePtr
>
();
MS_EXCEPTION_IF_NULL
(
PrimMerge
);
...
...
@@ -520,10 +528,10 @@ AnfNodePtr GenerateMergeNodes(const AnfNodePtr &true_output_node, const AnfNodeP
std
::
vector
<
AnfNodePtr
>
merge_nodes
;
merge_nodes
.
push_back
(
NewValueNode
(
PrimMerge
));
std
::
vector
<
AnfNodePtr
>
make_tuple_nodes
{
NewValueNode
(
prim
::
kPrimMakeTuple
),
true_output_node
,
false_output_node
};
merge_nodes
.
push_back
(
cond
->
func_graph
()
->
NewCNode
(
make_tuple_nodes
));
merge_nodes
.
push_back
(
switch_graph
->
NewCNode
(
make_tuple_nodes
));
std
::
vector
<
AnfNodePtr
>
tuple_getitem_nodes
{
NewValueNode
(
prim
::
kPrimTupleGetItem
),
cond
->
func_graph
()
->
NewCNode
(
merge_nodes
),
NewValueNode
(
MakeValue
(
0
))};
return
cond
->
func_graph
()
->
NewCNode
(
tuple_getitem_nodes
);
switch_graph
->
NewCNode
(
merge_nodes
),
NewValueNode
(
MakeValue
(
0
))};
return
switch_graph
->
NewCNode
(
tuple_getitem_nodes
);
}
else
{
abstract
::
AbstractTuplePtr
true_branch_tuple
=
true_graph_output_abs
->
cast
<
abstract
::
AbstractTuplePtr
>
();
abstract
::
AbstractTuplePtr
false_branch_tuple
=
false_graph_output_abs
->
cast
<
abstract
::
AbstractTuplePtr
>
();
...
...
@@ -533,27 +541,29 @@ AnfNodePtr GenerateMergeNodes(const AnfNodePtr &true_output_node, const AnfNodeP
for
(
size_t
i
=
0
;
i
<
true_branch_tuple
->
elements
().
size
();
i
++
)
{
std
::
vector
<
AnfNodePtr
>
true_getitem_nodes
{
NewValueNode
(
prim
::
kPrimTupleGetItem
),
true_output_node
,
NewValueNode
(
MakeValue
(
SizeToInt
(
i
)))};
auto
true_node
=
cond
->
func_graph
()
->
NewCNode
(
true_getitem_nodes
);
auto
true_node
=
switch_graph
->
NewCNode
(
true_getitem_nodes
);
std
::
vector
<
AnfNodePtr
>
false_getitem_nodes
{
NewValueNode
(
prim
::
kPrimTupleGetItem
),
false_output_node
,
NewValueNode
(
MakeValue
(
SizeToInt
(
i
)))};
auto
false_node
=
cond
->
func_graph
()
->
NewCNode
(
false_getitem_nodes
);
auto
false_node
=
switch_graph
->
NewCNode
(
false_getitem_nodes
);
auto
merge_node
=
GenerateMergeNodes
(
true_node
,
false_node
,
true_branch_tuple
->
elements
()[
i
],
false_branch_tuple
->
elements
()[
i
],
cond
);
false_branch_tuple
->
elements
()[
i
],
switch_graph
,
cond
);
make_tuple_nodes
.
push_back
(
merge_node
);
}
return
cond
->
func_graph
()
->
NewCNode
(
make_tuple_nodes
);
return
switch_graph
->
NewCNode
(
make_tuple_nodes
);
}
}
AnfNodePtr
TransformMergeBranches
(
const
AnfNodePtr
&
true_output_node
,
const
AnfNodePtr
&
false_output_node
,
const
AbstractBasePtr
&
true_graph_output_abs
,
const
AbstractBasePtr
&
false_graph_output_abs
,
const
AnfNodePtr
&
cond
)
{
const
AbstractBasePtr
&
false_graph_output_abs
,
const
AnfNodePtr
&
cond
,
const
FuncGraphPtr
&
switch_graph
)
{
if
(
!
GraphOutputCompatible
(
true_graph_output_abs
,
false_graph_output_abs
))
{
MS_LOG
(
EXCEPTION
)
<<
"Switch output branch not compatible, true:"
<<
true_graph_output_abs
->
ToString
()
<<
", false:"
<<
false_graph_output_abs
->
ToString
();
}
return
GenerateMergeNodes
(
true_output_node
,
false_output_node
,
true_graph_output_abs
,
false_graph_output_abs
,
cond
);
return
GenerateMergeNodes
(
true_output_node
,
false_output_node
,
true_graph_output_abs
,
false_graph_output_abs
,
switch_graph
,
cond
);
}
}
// namespace internal
}
// namespace irpass
...
...
mindspore/ccsrc/optimizer/irpass/branch_culling.h
浏览文件 @
1a4abefa
...
...
@@ -168,7 +168,8 @@ FuncGraphPtr TransformGraphCondTrueBranchNodes(const FuncGraphPtr &graph, const
FuncGraphPtr
TransformGraphCondFalseBranchNodes
(
const
FuncGraphPtr
&
graph
,
const
AnfNodePtr
&
cond
);
AnfNodePtr
TransformMergeBranches
(
const
AnfNodePtr
&
true_output_node
,
const
AnfNodePtr
&
false_output_node
,
const
AbstractBasePtr
&
true_graph_output_abs
,
const
AbstractBasePtr
&
false_graph_output_abs
,
const
AnfNodePtr
&
cond
);
const
AbstractBasePtr
&
false_graph_output_abs
,
const
AnfNodePtr
&
cond
,
const
FuncGraphPtr
&
func_graph
);
}
// namespace internal
// {{prim::kPrimSwitch, X, G1, G2}, Xs}
...
...
@@ -190,6 +191,20 @@ class ConvertSwitchReplacement : public AnfVisitor {
if
(
g2_
==
nullptr
||
g1_
->
output
()
==
nullptr
||
g2_
->
output
()
==
nullptr
)
{
return
nullptr
;
}
// for switch replace method, only graphs without graph inside can be replaced
for
(
auto
&
item
:
g1_
->
value_nodes
())
{
auto
value_node
=
item
.
first
;
if
(
IsValueNode
<
FuncGraph
>
(
value_node
))
{
return
nullptr
;
}
}
for
(
auto
&
item
:
g2_
->
value_nodes
())
{
auto
value_node
=
item
.
first
;
if
(
IsValueNode
<
FuncGraph
>
(
value_node
))
{
return
nullptr
;
}
}
auto
true_output
=
g1_
->
output
()
->
abstract
();
auto
false_output
=
g2_
->
output
()
->
abstract
();
...
...
@@ -200,8 +215,8 @@ class ConvertSwitchReplacement : public AnfVisitor {
auto
fg
=
node
->
func_graph
();
auto
cloned_g1
=
InlineClone
(
trans_g1
,
fg
,
params
);
auto
cloned_g2
=
InlineClone
(
trans_g2
,
fg
,
params
);
return
internal
::
TransformMergeBranches
(
cloned_g1
,
cloned_g2
,
true_output
,
false_output
,
x_
)
;
auto
nnode
=
internal
::
TransformMergeBranches
(
cloned_g1
,
cloned_g2
,
true_output
,
false_output
,
x_
,
fg
);
return
nnode
;
}
void
Visit
(
const
AnfNodePtr
&
node
)
override
{
...
...
mindspore/ccsrc/pipeline/pass.cc
浏览文件 @
1a4abefa
...
...
@@ -162,7 +162,7 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
}
OptPassGroupMap
GetControlPhases
(
const
opt
::
irpass
::
OptimizeIRPassLib
&
irpass
)
{
opt
::
OptPassConfig
control_group
=
opt
::
OptPassConfig
({
irpass
.
convert_switch_replacement_
});
opt
::
OptPassConfig
control_group
=
opt
::
OptPassConfig
({
irpass
.
convert_switch_replacement_
}
,
true
);
OptPassGroupMap
map
({
{
"control_group"
,
control_group
},
{
"renormalize"
,
opt
::
OptPassConfig
::
Renormalize
()},
...
...
mindspore/ccsrc/pipeline/static_analysis/prim.cc
浏览文件 @
1a4abefa
...
...
@@ -346,7 +346,7 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
if
((
*
value
==
*
kAnyValue
))
{
auto
value_desc
=
abs_base
->
value_desc
();
MS_EXCEPTION
(
TypeError
)
<<
"Unsupported parameter "
<<
(
value_desc
.
empty
()
?
"type"
:
value_desc
)
<<
" for python primitive."
;
<<
" for python primitive."
<<
abs_base
->
ToString
()
;
}
MS_EXCEPTION
(
TypeError
)
<<
"Unsupported parameter type for python primitive, the parameter value is "
<<
value
->
ToString
();
...
...
tests/ut/python/ops/test_control_ops.py
浏览文件 @
1a4abefa
...
...
@@ -24,6 +24,8 @@ from mindspore.common.parameter import Parameter, ParameterTuple
from
mindspore.ops
import
composite
as
C
from
mindspore.ops
import
functional
as
F
from
mindspore.ops
import
operations
as
P
from
mindspore.common.parameter
import
Parameter
,
ParameterTuple
from
mindspore.common
import
ms_function
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
)
...
...
@@ -371,7 +373,8 @@ def test_switch_layer():
class
Layer1
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
Layer1
,
self
).
__init__
()
self
.
z1
=
Parameter
(
Tensor
(
np
.
full
([
128
,
96
],
0.6
,
dtype
=
np
.
float32
)),
name
=
'z1'
)
self
.
z1
=
Parameter
(
Tensor
(
np
.
full
([
128
,
96
],
0.6
,
dtype
=
np
.
float32
)),
name
=
'z1'
)
def
construct
(
self
,
x
):
return
x
*
self
.
z1
...
...
@@ -379,7 +382,8 @@ def test_switch_layer():
class
Layer2
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
Layer2
,
self
).
__init__
()
self
.
z2
=
Parameter
(
Tensor
(
np
.
full
([
128
,
96
],
0.6
,
dtype
=
np
.
float32
)),
name
=
'z2'
)
self
.
z2
=
Parameter
(
Tensor
(
np
.
full
([
128
,
96
],
0.6
,
dtype
=
np
.
float32
)),
name
=
'z2'
)
def
construct
(
self
,
x
):
return
x
*
self
.
z2
...
...
@@ -388,7 +392,8 @@ def test_switch_layer():
def
__init__
(
self
):
super
(
SwitchLayerCell
,
self
).
__init__
()
self
.
layers
=
(
Layer1
(),
Layer2
())
self
.
z3
=
Parameter
(
Tensor
(
np
.
full
([
128
,
96
],
0.6
,
dtype
=
np
.
float32
)),
name
=
'z3'
)
self
.
z3
=
Parameter
(
Tensor
(
np
.
full
([
128
,
96
],
0.6
,
dtype
=
np
.
float32
)),
name
=
'z3'
)
def
construct
(
self
,
index
,
x
):
ret
=
F
.
switch_layer
(
index
,
self
.
layers
)(
x
)
*
self
.
z3
...
...
@@ -406,7 +411,8 @@ def test_index_to_switch_layer():
class
Layer1
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
Layer1
,
self
).
__init__
()
self
.
z1
=
Parameter
(
Tensor
(
np
.
full
([
128
,
96
],
0.6
,
dtype
=
np
.
float32
)),
name
=
'z1'
)
self
.
z1
=
Parameter
(
Tensor
(
np
.
full
([
128
,
96
],
0.6
,
dtype
=
np
.
float32
)),
name
=
'z1'
)
def
construct
(
self
,
x
):
return
x
*
self
.
z1
...
...
@@ -414,7 +420,8 @@ def test_index_to_switch_layer():
class
Layer2
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
Layer2
,
self
).
__init__
()
self
.
z2
=
Parameter
(
Tensor
(
np
.
full
([
128
,
96
],
0.6
,
dtype
=
np
.
float32
)),
name
=
'z2'
)
self
.
z2
=
Parameter
(
Tensor
(
np
.
full
([
128
,
96
],
0.6
,
dtype
=
np
.
float32
)),
name
=
'z2'
)
def
construct
(
self
,
x
):
return
x
*
self
.
z2
...
...
@@ -423,7 +430,8 @@ def test_index_to_switch_layer():
def
__init__
(
self
):
super
(
SwitchLayerCell
,
self
).
__init__
()
self
.
layers
=
(
Layer1
(),
Layer2
())
self
.
z3
=
Parameter
(
Tensor
(
np
.
full
([
128
,
96
],
0.6
,
dtype
=
np
.
float32
)),
name
=
'z3'
)
self
.
z3
=
Parameter
(
Tensor
(
np
.
full
([
128
,
96
],
0.6
,
dtype
=
np
.
float32
)),
name
=
'z3'
)
def
construct
(
self
,
index
,
x
):
ret
=
self
.
layers
[
index
](
x
)
*
self
.
z3
...
...
@@ -444,3 +452,69 @@ def test_control_depend_check():
depend
=
P
.
ControlDepend
(
2
)
with
pytest
.
raises
(
TypeError
)
as
e
:
depend
=
P
.
ControlDepend
((
2
,))
def
test_if_nested_compile
():
class
Net
(
nn
.
Cell
):
def
__init__
(
self
,
auto_prefix
=
True
):
super
().
__init__
(
auto_prefix
=
auto_prefix
)
self
.
squre
=
P
.
Square
()
self
.
value
=
Tensor
(
3
,
dtype
=
ms
.
float32
)
def
construct
(
self
,
x
,
y
):
res
=
self
.
value
if
x
<=
y
:
res
=
x
+
res
res
=
y
+
res
else
:
if
x
==
y
:
res
=
self
.
squre
(
self
.
value
*
y
)
else
:
res
=
self
.
squre
(
self
.
value
)
return
res
x
=
Tensor
(
1.0
,
dtype
=
ms
.
float32
)
y
=
Tensor
(
2.0
,
dtype
=
ms
.
float32
)
net
=
Net
()
net
(
x
,
y
)
def
test_if_inside_for
():
class
Net
(
nn
.
Cell
):
def
__init__
(
self
,
auto_prefix
=
True
):
super
().
__init__
(
auto_prefix
=
auto_prefix
)
self
.
squre
=
P
.
Square
()
self
.
value
=
Tensor
(
3
,
dtype
=
ms
.
float32
)
self
.
count
=
4
def
construct
(
self
,
x
,
y
):
res
=
0
for
i
in
range
(
self
.
count
):
if
i
==
x
:
res
=
res
+
x
else
:
res
=
res
-
y
return
res
c1
=
Tensor
(
1
,
dtype
=
ms
.
int32
)
c2
=
Tensor
(
1
,
dtype
=
ms
.
int32
)
net
=
Net
()
out
=
net
(
c1
,
c2
)
def
test_while_in_while
():
c1
=
Tensor
(
1
,
dtype
=
ms
.
int32
)
c2
=
Tensor
(
2
,
dtype
=
ms
.
int32
)
c3
=
Tensor
(
3
,
dtype
=
ms
.
int32
)
c4
=
Tensor
(
4
,
dtype
=
ms
.
int32
)
@
ms_function
def
while_in_while
(
x
,
y
,
z
,
u
):
out
=
c4
while
x
<
y
:
z
=
c4
+
c4
while
z
<
y
:
z
=
z
+
1
out
=
out
+
1
x
=
x
+
1
out
=
out
+
3
return
out
while_in_while
(
c1
,
c2
,
c3
,
c4
)
tests/ut/python/ops/test_layer_switch.py
0 → 100644
浏览文件 @
1a4abefa
import
numpy
as
np
import
mindspore
from
mindspore
import
nn
from
mindspore
import
Tensor
from
mindspore
import
context
from
mindspore.ops
import
operations
as
P
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
)
class
Layer1
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
Layer1
,
self
).
__init__
()
self
.
net
=
nn
.
Conv2d
(
3
,
1
,
3
,
pad_mode
=
'same'
)
self
.
pad
=
nn
.
Pad
(
paddings
=
((
0
,
0
),
(
0
,
2
),
(
0
,
0
),
(
0
,
0
)),
mode
=
"CONSTANT"
)
def
construct
(
self
,
x
):
y
=
self
.
net
(
x
)
return
self
.
pad
(
y
)
class
Layer2
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
Layer2
,
self
).
__init__
()
self
.
net
=
nn
.
Conv2d
(
3
,
1
,
7
,
pad_mode
=
'same'
)
self
.
pad
=
nn
.
Pad
(
paddings
=
((
0
,
0
),
(
0
,
2
),
(
0
,
0
),
(
0
,
0
)),
mode
=
"CONSTANT"
)
def
construct
(
self
,
x
):
y
=
self
.
net
(
x
)
return
self
.
pad
(
y
)
class
Layer3
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
Layer3
,
self
).
__init__
()
self
.
net
=
nn
.
Conv2d
(
3
,
3
,
3
,
pad_mode
=
'same'
)
def
construct
(
self
,
x
):
return
self
.
net
(
x
)
class
SwitchNet
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
SwitchNet
,
self
).
__init__
()
self
.
layer1
=
Layer1
()
self
.
layer2
=
Layer2
()
self
.
layer3
=
Layer3
()
self
.
layers
=
(
self
.
layer1
,
self
.
layer2
,
self
.
layer3
)
self
.
fill
=
P
.
Fill
()
def
construct
(
self
,
x
,
index
):
y
=
self
.
layers
[
index
](
x
)
return
y
class
MySwitchNet
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
MySwitchNet
,
self
).
__init__
()
self
.
layer1
=
Layer1
()
self
.
layer2
=
Layer2
()
self
.
layer3
=
Layer3
()
self
.
layers
=
(
self
.
layer1
,
self
.
layer2
,
self
.
layer3
)
self
.
fill
=
P
.
Fill
()
def
construct
(
self
,
x
,
index
):
y
=
self
.
layers
[
0
](
x
)
for
i
in
range
(
len
(
self
.
layers
)):
if
i
==
index
:
y
=
self
.
layers
[
i
](
x
)
return
y
def
test_layer_switch
():
net
=
MySwitchNet
()
x
=
Tensor
(
np
.
ones
((
3
,
3
,
24
,
24
)),
mindspore
.
float32
)
index
=
Tensor
(
0
,
dtype
=
mindspore
.
int32
)
y
=
net
(
x
,
index
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录