magicwindyyd / mindspore (forked from MindSpore / mindspore)

Commit b7596e1f
Authored on April 24, 2020 by panyifeng
Parent: 5a03bd80

Add switch_case primitive

Showing 11 changed files with 138 additions and 2 deletions (+138, -2):
mindspore/ccsrc/operator/ops.cc                     +1   -0
mindspore/ccsrc/operator/ops.h                      +1   -0
mindspore/ccsrc/operator/prim_statement.cc          +24  -0
mindspore/ccsrc/optimizer/ad/dfunctor.cc            +49  -2
mindspore/ccsrc/optimizer/ad/dfunctor.h             +4   -0
mindspore/ccsrc/optimizer/ad/kprim.cc               +16  -0
mindspore/ccsrc/pipeline/static_analysis/prim.cc    +1   -0
mindspore/ccsrc/pipeline/static_analysis/prim.h     +2   -0
mindspore/ops/_grad/grad_implementations.py         +6   -0
mindspore/ops/functional.py                         +1   -0
tests/ut/python/ops/test_control_ops.py             +33  -0
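At a glance, the commit introduces a `switch_layer` primitive: given an integer index and a tuple of graphs (layers), it returns the graph selected by the index, and the caller then applies that graph to its inputs. A minimal plain-Python sketch of that selection behaviour, assuming the intended semantics (the names below are illustrative, not MindSpore API):

    # Plain-Python analogue of switch_layer(index, layers): pick one callable
    # out of a tuple by integer index; the caller applies the result itself.
    # Illustrative only; the real primitive operates on func graphs.
    def switch_layer(index, layers):
        return layers[index]

    layers = (lambda x: x * 2, lambda x: x * 3)
    assert switch_layer(1, layers)(10) == 30   # second layer selected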
mindspore/ccsrc/operator/ops.cc

@@ -59,6 +59,7 @@ const PrimitivePtr kPrimHasType = std::make_shared<Primitive>("hastype");
// Statements
const PrimitivePtr kPrimSwitch = std::make_shared<Primitive>("switch");
const PrimitivePtr kPrimSwitchLayer = std::make_shared<Primitive>("switch_layer");
const PrimitivePtr kPrimReturn = std::make_shared<Primitive>("return");
const PrimitivePtr kPrimAssign = std::make_shared<Primitive>("Assign");
const PrimitivePtr kPrimAssignAdd = std::make_shared<Primitive>("AssignAdd");
mindspore/ccsrc/operator/ops.h

@@ -65,6 +65,7 @@ extern const PrimitivePtr kPrimHasType;
// Statements
extern const PrimitivePtr kPrimSwitch;
extern const PrimitivePtr kPrimSwitchLayer;
extern const PrimitivePtr kPrimReturn;
extern const PrimitivePtr kPrimAssign;
extern const PrimitivePtr kPrimAssignAdd;
mindspore/ccsrc/operator/prim_statement.cc

@@ -126,6 +126,30 @@ AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &,
  MS_LOG(EXCEPTION) << "Invalid condition value for switch " << cond->ToString();
}

AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                     const AbstractBasePtrList &args_spec_list) {
  // Inputs: index, branch
  if (args_spec_list.size() != 2) {
    MS_LOG(EXCEPTION) << "SwitchLayer evaluator requires 2 parameters, while the input size is "
                      << args_spec_list.size() << ".";
  }
  AbstractTuplePtr branches_abs = CheckArg<AbstractTuple>(primitive->name(), args_spec_list, 1);
  AbstractBasePtrList branches = branches_abs->elements();
  const size_t maximum_layer_num = 1000;
  if (branches.size() < 0 || branches.size() > maximum_layer_num) {
    MS_EXCEPTION(ValueError) << "SwitchLayer support at least 1 and at most " << maximum_layer_num
                             << " but got " << branches.size() << " branches.";
  }

  MS_EXCEPTION_IF_NULL(branches[0]);
  auto b = branches[0];
  for (size_t i = 1; i < branches.size(); i++) {
    MS_EXCEPTION_IF_NULL(branches[i]);
    b = b->Join(branches[i]);
  }
  return b;
}

std::vector<ValuePtr> GetSupportedTargetValue() {
  std::vector<ValuePtr> list = {kNone, MakeValue(false), MakeValue(true)};
  return list;
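For readability, here is a rough Python rendering of the checks InferImplSwitchLayer performs. The real implementation works on abstract values and joins the branch abstractions; this sketch only validates concrete tuples (the 1000-branch bound is taken from the code above, everything else is illustrative):

    # Illustrative re-statement of InferImplSwitchLayer's argument checks.
    MAXIMUM_LAYER_NUM = 1000

    def infer_switch_layer(args_spec_list):
        if len(args_spec_list) != 2:
            raise RuntimeError("SwitchLayer evaluator requires 2 parameters, "
                               "while the input size is %d." % len(args_spec_list))
        _index, branches = args_spec_list
        if not 1 <= len(branches) <= MAXIMUM_LAYER_NUM:
            raise ValueError("SwitchLayer support at least 1 and at most %d but got %d branches."
                             % (MAXIMUM_LAYER_NUM, len(branches)))
        # The result abstraction is the join over all branches; assume here
        # that the branches agree and take the first one as representative.
        return branches[0]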
mindspore/ccsrc/optimizer/ad/dfunctor.cc

@@ -38,6 +38,7 @@ namespace mindspore {
namespace ad {
std::unordered_map<FuncGraphPtr, DFunctorPtr> DFunctor::func_graph_to_functor_;
std::unordered_map<AnfNodePtr, AdjointPtr> DFunctor::anfnode_to_adjoin_definition_;
FuncGraphSet DFunctor::scope_;

DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBasePtr &resources)
    : primal_graph_(primal_graph), resources_(resources), need_cut_(false), is_top_(false) {

@@ -55,11 +56,15 @@ DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBas
void DFunctor::Init(const DFunctorPtr &functor, bool is_top) {
  func_graph_to_functor_[primal_graph_] = functor;
  is_top_ = is_top;
  if (is_top) {
    scope_ = primal_graph_->scope();
  }
}

void DFunctor::Clear() {
  func_graph_to_functor_.clear();
  anfnode_to_adjoin_definition_.clear();
  scope_.clear();
}

void DFunctor::BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din) {

@@ -95,11 +100,48 @@ void DFunctor::BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din) {
  fv_adjoint->second->AccumulateDout(dfv);
}

void DFunctor::BackPropagateSwitchLayer(const CNodePtr &cnode_morph, const CNodePtr &env) {
  // Take switch_layer as a set of candidate functions.
  auto input = cnode_morph->input(2);
  if (!IsPrimitiveCNode(input, prim::kPrimMakeTuple)) {
    MS_LOG(EXCEPTION) << "The 2th input of switch_layer expect a tuple of graphs, but got " << input->ToString()
                      << ".";
  }
  auto tuple_graphs = input->cast<CNodePtr>();
  for (size_t i = 1; i < tuple_graphs->size(); ++i) {
    auto graph = tuple_graphs->input(i);
    if (!IsValueNode<FuncGraph>(graph)) {
      MS_LOG(EXCEPTION) << "The 2th input of switch_layer expect a tuple of graphs, but got " << graph->ToString()
                        << " as the " << i << "th element.";
    }
    auto func_graph = GetValueNode<FuncGraphPtr>(graph);
    auto functor = func_graph_to_functor_.find(func_graph);
    if (functor == func_graph_to_functor_.end()) {
      MS_LOG(EXCEPTION) << "BackPropagateSwitchLayer failed functor for subgraph does not exist input[" << i << "] "
                        << func_graph->ToString() << ".";
    }
    // Consider direct and indirect fvs.
    for (auto fv : func_graph->free_variables_nodes()) {
      BackPropagateFv(fv, env);
    }
    for (auto indirect_fv : functor->second->anfnode_to_adjoin_indirect_fv_) {
      MS_LOG(DEBUG) << "BackPropagateSwitchLayer backprop indirect fv " << func_graph->ToString() << " "
                    << indirect_fv.first->ToString() << ".";
      BackPropagateFv(indirect_fv.first, env);
    }
  }
}

void DFunctor::BackPropagate(const CNodePtr &cnode_morph, const CNodePtr &k_app, const AdjointPtr &node_adjoint) {
  auto bprop = k_graph_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), k_app, NewValueNode(1)});
  // Call with delimited continuation dout.
  auto bprop_app = tape_->NewCNode({bprop, node_adjoint->dout()});
  node_adjoint->RegisterDoutUser(bprop_app, 1);
  // Special case for switch_layer
  if (IsPrimitiveCNode(cnode_morph, prim::kPrimSwitchLayer)) {
    auto din = tape_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), bprop_app, NewValueNode(0)});
    BackPropagateSwitchLayer(cnode_morph, din);
    return;
  }
  for (size_t i = 0; i < cnode_morph->size(); i++) {
    auto din = tape_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), bprop_app, NewValueNode(SizeToInt(i))});
    auto input = cnode_morph->input(i);

@@ -402,6 +444,11 @@ AnfNodePtr DFunctor::MapToK(const AnfNodePtr &primal) {
  return primal;
}

bool DFunctor::IsInScope(const AnfNodePtr &node) {
  return std::any_of(scope_.begin(), scope_.end(),
                     [&](const FuncGraphPtr &graph) { return node->func_graph() == graph; });
}

void DFunctor::MapFvObject() {
  // Map free variable.
  const auto &free_variables_nodes = primal_graph_->free_variables_nodes();

@@ -414,8 +461,8 @@ void DFunctor::MapFvObject() {
    if (parent_adjoint != nullptr) {
      adjoint = std::make_shared<Adjoint>(node, parent_adjoint->k(), tape_);
    } else {
-     if (is_top_) {
-       // Top graph for ad, add adjoint for free variables.
+     if (is_top_ || node->isa<Parameter>() || !IsInScope(node)) {
+       // Out of ad scope, add adjoint for free variables.
        adjoint = std::make_shared<Adjoint>(node, node, tape_);
        UpdateAdjoint(adjoint);
      } else {
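Conceptually, the special case in BackPropagate routes the incoming sensitivity (element 0 of the bprop application) to the free variables, such as Parameters, of every candidate graph in the tuple; only the branch that was actually selected ends up contributing a non-zero gradient, which is what the fprop rule in grad_implementations.py expresses. A hand-rolled sketch of that gradient routing, under that usual interpretation and with purely illustrative names:

    # Illustrative gradient routing for an index-selected branch: the chosen
    # branch's parameter gradient receives din, the others receive zeros.
    def grads_through_switch_layer(index, branch_grad_fns, din):
        """branch_grad_fns[i](din) -> parameter gradient of candidate branch i."""
        return [fn(din) if i == index else 0.0 * fn(din)
                for i, fn in enumerate(branch_grad_fns)]

    # Two branches y = w_i * x evaluated at x = 3.0, so dL/dw_i = din * x for
    # the selected branch and 0 for the other.
    assert grads_through_switch_layer(1, [lambda d: d * 3.0, lambda d: d * 3.0], 1.0) == [0.0, 3.0]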
mindspore/ccsrc/optimizer/ad/dfunctor.h

@@ -62,9 +62,11 @@ class DFunctor {
  // Map one morphism.
  AdjointPtr MapMorphism(const AnfNodePtr &morph);
  bool IsFreeMorphism(const AnfNodePtr &node);
  bool IsInScope(const AnfNodePtr &node);
  // Map morphism that's not attached to output.
  void MapFreeMorphism();
  void BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din);
  void BackPropagateSwitchLayer(const CNodePtr &cnode_morph, const CNodePtr &env);
  void BackPropagate(const CNodePtr &cnode_morph, const CNodePtr &k_app, const AdjointPtr &node_adjoint);
  AnfNodePtr AttachFvDoutToTape(const AnfNodePtr &grad_fv);
  AnfNodePtr AttachIndirectFvDoutToTape(const AnfNodePtr &grad_fv);

@@ -101,6 +103,7 @@ class DFunctor {
  bool is_top_;
  static std::unordered_map<FuncGraphPtr, std::shared_ptr<DFunctor>> func_graph_to_functor_;
  static std::unordered_map<AnfNodePtr, AdjointPtr> anfnode_to_adjoin_definition_;
  static FuncGraphSet scope_;
};

// D Functor's rules to map primitive object.

@@ -120,6 +123,7 @@ class KPrim {
 private:
  FuncGraphPtr GetBprop(const PrimitivePtr &prim);
  FuncGraphPtr GetFprop(const PrimitivePtr &prim);
  FuncGraphPtr FakeBprop(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources);

  // Given a bprop rule, do the K mapping.
  template <typename T>
mindspore/ccsrc/optimizer/ad/kprim.cc

@@ -62,6 +62,15 @@ FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim) {
  return func_graph;
}

FuncGraphPtr KPrim::GetFprop(const PrimitivePtr &prim) {
  static const std::string ad_module = "mindspore.ops._grad.grad_implementations";
  std::string func_name = "_fprop_" + prim->name();
  py::function fn = parse::python_adapter::GetPyFn(ad_module, func_name);
  auto func_graph = parse::ParsePythonCode(fn);
  MS_EXCEPTION_IF_NULL(func_graph);
  return BasicClone(func_graph);
}

MetaFuncGraphPtr KPrim::KMetaFuncGraph(const PrimitivePtr &prim) {
  MS_EXCEPTION_IF_NULL(prim);

@@ -92,6 +101,13 @@ FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::R
    return iter->second;
  }

  if (prim->Hash() == prim::kPrimSwitchLayer->Hash() && prim->name() == "switch_layer") {
    auto fprop = GetFprop(prim);
    fprop->transforms().emplace("primal", FuncGraphTransform(prim::kPrimSwitchLayer));
    bprop_registry_[prim::kPrimSwitchLayer] = fprop;
    return fprop;
  }

  if (prim->name() == "make_tuple") {
    return nullptr;
  }
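GetFprop resolves the Python forward-with-backprop rule by name: it looks up "_fprop_" plus the primitive's name in mindspore.ops._grad.grad_implementations and parses that function into a func graph. A sketch of the same lookup convention in plain Python; the module path comes from the code above, while get_fprop itself is a hypothetical helper, not MindSpore API:

    # Hypothetical helper mirroring KPrim::GetFprop's "_fprop_" + name lookup.
    import importlib

    def get_fprop(prim_name, module="mindspore.ops._grad.grad_implementations"):
        """Return the Python fprop rule registered for a primitive name."""
        mod = importlib.import_module(module)
        return getattr(mod, "_fprop_" + prim_name)

    # get_fprop("switch_layer") would resolve to _fprop_switch_layer below,
    # assuming mindspore is importable in the current environment.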
mindspore/ccsrc/pipeline/static_analysis/prim.cc

@@ -50,6 +50,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
    {prim::kPrimHasType, {InferImplHasType, false}},
    {prim::kPrimDot, {InferImplDot, true}},
    {prim::kPrimSwitch, {InferImplSwitch, true}},
    {prim::kPrimSwitchLayer, {InferImplSwitchLayer, true}},
    {prim::kPrimIs_, {InferImplIs_, true}},
    {prim::kPrimIsNot, {InferImplIsNot, true}},
    {prim::kPrimInDict, {InferImplInDict, true}},
mindspore/ccsrc/pipeline/static_analysis/prim.h

@@ -174,6 +174,8 @@ AbstractBasePtr InferImplDot(const AnalysisEnginePtr &, const PrimitivePtr &prim
                             const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &,
                                const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitivePtr &,
                                     const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplIs_(const AnalysisEnginePtr &, const PrimitivePtr &,
                             const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplIsNot(const AnalysisEnginePtr &, const PrimitivePtr &,
mindspore/ops/_grad/grad_implementations.py

@@ -242,3 +242,9 @@ def bprop_switch(cond, tb, fb, out, dout):
    """Backpropagator for primitive `switch`."""
    return C.zeros_like(cond), F.switch(cond, dout, C.zeros_like(tb)), \
        F.switch(cond, C.zeros_like(fb), dout)


def _fprop_switch_layer(index, layers):
    """Backpropagator for primitive `switch_layer`."""
    def _bprop_switch_layer(dout):
        return dout, C.zeros_like(index), ()
    return F.switch_layer(index, layers), _bprop_switch_layer
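The fprop rule above returns a pair: the forward result (the selected graph) together with a bprop closure that maps the output sensitivity back to per-input sensitivities. A stand-alone exercise of that shape, with F.switch_layer replaced by tuple indexing and C.zeros_like(index) replaced by 0 so it runs without MindSpore; this is illustrative only, and the meaning of each slot in the bprop tuple follows MindSpore's internal convention:

    # Plain-Python stand-in for the fprop/bprop pair defined above.
    def _fprop_switch_layer_demo(index, layers):
        def _bprop_switch_layer(dout):
            return dout, 0, ()   # stand-in for (dout, C.zeros_like(index), ())
        return layers[index], _bprop_switch_layer

    selected, bprop = _fprop_switch_layer_demo(0, (lambda x: x - 1, lambda x: x + 1))
    assert selected(5) == 4
    assert bprop(2.0) == (2.0, 0, ())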
mindspore/ops/functional.py

@@ -135,6 +135,7 @@ env_getitem = Primitive('env_getitem')
env_add = Primitive('env_add')
J = Primitive('J')
switch = Primitive('switch')
switch_layer = Primitive('switch_layer')
# for sum bprop
reduced_shape = Primitive("reduced_shape")
# shape_mul:input mush be shape multiply elemts in tuple(shape)
tests/ut/python/ops/test_control_ops.py

@@ -19,6 +19,9 @@ from mindspore import nn
from mindspore import Tensor
from mindspore import context
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter, ParameterTuple

context.set_context(mode=context.GRAPH_MODE)

@@ -358,3 +361,33 @@ def test_if_compile_true():
def test_if_compile_false():
    output = if_compile_test(8, 3)
    print("test_if_compile_false:", output)


def test_switch_layer():
    class Layer1(nn.Cell):
        def __init__(self):
            super(Layer1, self).__init__()
            self.z1 = Parameter(Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z1')

        def construct(self, x):
            return x * self.z1

    class Layer2(nn.Cell):
        def __init__(self):
            super(Layer2, self).__init__()
            self.z2 = Parameter(Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z2')

        def construct(self, x):
            return x * self.z2

    class SwitchLayerCell(nn.Cell):
        def __init__(self):
            super(SwitchLayerCell, self).__init__()
            self.layers = (Layer1(), Layer2())
            self.z3 = Parameter(Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z3')

        def construct(self, index, x):
            ret = F.switch_layer(index, self.layers)(x) * self.z3
            return ret

    net = SwitchLayerCell()
    net(1, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
    C.grad_by_list(net, ParameterTuple(net.trainable_params()))(0, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
    C.grad_all(net)(0, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))