Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
065e25e1
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
065e25e1
编写于
5月 11, 2020
作者:
P
panyifeng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Support index to switch_layer
上级
08d86c48
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
125 addition
and
13 deletion
+125
-13
mindspore/ccsrc/operator/composite/composite.cc
mindspore/ccsrc/operator/composite/composite.cc
+27
-0
mindspore/ccsrc/operator/composite/composite.h
mindspore/ccsrc/operator/composite/composite.h
+12
-0
mindspore/ccsrc/operator/prim_statement.cc
mindspore/ccsrc/operator/prim_statement.cc
+13
-8
mindspore/ops/composite/base.py
mindspore/ops/composite/base.py
+2
-2
mindspore/ops/composite/multitype_ops/getitem_impl.py
mindspore/ops/composite/multitype_ops/getitem_impl.py
+37
-0
tests/ut/python/ops/test_control_ops.py
tests/ut/python/ops/test_control_ops.py
+34
-3
未找到文件。
mindspore/ccsrc/operator/composite/composite.cc
浏览文件 @
065e25e1
...
...
@@ -1233,6 +1233,27 @@ FuncGraphPtr TensorSlice::ExpandADim(const FuncGraphPtr &ret_graph, const AnfNod
return
ret_graph
;
}
FuncGraphPtr
TupleGetItemTensor
::
GenerateFuncGraph
(
const
AbstractBasePtrList
&
args_spec_list
)
{
// select indexed item
// args: tuple of items, index
const
std
::
string
op_name
=
std
::
string
(
"TupleGetItemTensor"
);
abstract
::
CheckArgsSize
(
op_name
,
args_spec_list
,
2
);
AbstractTuplePtr
branches_abs
=
abstract
::
CheckArg
<
AbstractTuple
>
(
op_name
,
args_spec_list
,
0
);
AbstractBasePtrList
branches
=
branches_abs
->
elements
();
if
(
branches
.
size
()
>
0
&&
branches
[
0
]
!=
nullptr
&&
branches
[
0
]
->
isa
<
AbstractFunction
>
())
{
FuncGraphPtr
ret_graph
=
std
::
make_shared
<
FuncGraph
>
();
ret_graph
->
set_flags
(
FUNC_GRAPH_FLAG_CORE
,
true
);
AnfNodePtr
functions
=
ret_graph
->
add_parameter
();
auto
index
=
ret_graph
->
add_parameter
();
ret_graph
->
set_output
(
ret_graph
->
NewCNode
({
NewValueNode
(
prim
::
kPrimSwitchLayer
),
index
,
functions
}));
return
ret_graph
;
}
MS_LOG
(
EXCEPTION
)
<<
"TupleGetItemTensor does not support to index "
<<
branches_abs
->
ToString
()
<<
"."
;
}
REGISTER_PYBIND_DEFINE
(
TupleAdd_
,
([](
const
py
::
module
*
m
)
{
(
void
)
py
::
class_
<
TupleAdd
,
MetaFuncGraph
,
std
::
shared_ptr
<
TupleAdd
>>
(
*
m
,
"TupleAdd_"
)
.
def
(
py
::
init
<
std
::
string
&>
());
...
...
@@ -1247,5 +1268,11 @@ REGISTER_PYBIND_DEFINE(TensorSlice_, ([](const py::module *m) {
(
void
)
py
::
class_
<
TensorSlice
,
MetaFuncGraph
,
std
::
shared_ptr
<
TensorSlice
>>
(
*
m
,
"TensorSlice_"
)
.
def
(
py
::
init
<
std
::
string
&>
());
}));
REGISTER_PYBIND_DEFINE
(
TupleGetItemTensor_
,
([](
const
py
::
module
*
m
)
{
(
void
)
py
::
class_
<
TupleGetItemTensor
,
MetaFuncGraph
,
std
::
shared_ptr
<
TupleGetItemTensor
>>
(
*
m
,
"TupleGetItemTensor_"
)
.
def
(
py
::
init
<
std
::
string
&>
());
}));
}
// namespace prim
}
// namespace mindspore
mindspore/ccsrc/operator/composite/composite.h
浏览文件 @
065e25e1
...
...
@@ -210,6 +210,18 @@ class TensorSlice : public MetaFuncGraph {
FuncGraphPtr
ExpandADim
(
const
FuncGraphPtr
&
ret_graph
,
const
AnfNodePtr
&
tensor_node
)
const
;
};
using
TensorSlicePtr
=
std
::
shared_ptr
<
TensorSlice
>
;
class
TupleGetItemTensor
:
public
MetaFuncGraph
{
public:
explicit
TupleGetItemTensor
(
const
std
::
string
&
name
)
:
MetaFuncGraph
(
name
)
{}
~
TupleGetItemTensor
()
override
=
default
;
MS_DECLARE_PARENT
(
TupleGetItemTensor
,
MetaFuncGraph
)
FuncGraphPtr
GenerateFuncGraph
(
const
AbstractBasePtrList
&
args_spec_list
)
override
;
friend
bool
operator
==
(
const
TupleGetItemTensor
&
lhs
,
const
TupleGetItemTensor
&
rhs
)
{
return
lhs
.
name_
==
rhs
.
name_
;
}
};
using
TupleGetItemTensorPtr
=
std
::
shared_ptr
<
TupleGetItemTensor
>
;
}
// namespace prim
}
// namespace mindspore
...
...
mindspore/ccsrc/operator/prim_statement.cc
浏览文件 @
065e25e1
...
...
@@ -129,22 +129,27 @@ AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &,
AbstractBasePtr
InferImplSwitchLayer
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
)
{
// Inputs: index, branch
if
(
args_spec_list
.
size
()
!=
2
)
{
MS_LOG
(
EXCEPTION
)
<<
"SwitchLayer evaluator requires 2 parameters, while the input size is "
<<
args_spec_list
.
size
()
<<
"."
;
}
AbstractTuplePtr
branches_abs
=
CheckArg
<
AbstractTuple
>
(
primitive
->
name
(),
args_spec_list
,
1
);
const
std
::
string
op_name
=
primitive
->
name
();
abstract
::
CheckArgsSize
(
op_name
,
args_spec_list
,
2
);
(
void
)
CheckArg
<
AbstractTensor
>
(
op_name
,
args_spec_list
,
0
);
AbstractTuplePtr
branches_abs
=
CheckArg
<
AbstractTuple
>
(
op_name
,
args_spec_list
,
1
);
AbstractBasePtrList
branches
=
branches_abs
->
elements
();
const
size_t
maximum_layer_num
=
1000
;
if
(
branches
.
size
()
<
0
||
branches
.
size
()
>
maximum_layer_num
)
{
MS_EXCEPTION
(
ValueError
)
<<
"SwitchLayer
support at least 1 and at most "
<<
maximum_layer_num
<<
" but got "
MS_EXCEPTION
(
ValueError
)
<<
op_name
<<
"
support at least 1 and at most "
<<
maximum_layer_num
<<
" but got "
<<
branches
.
size
()
<<
" branches."
;
}
MS_EXCEPTION_IF_NULL
(
branches
[
0
]);
for
(
size_t
i
=
0
;
i
<
branches
.
size
();
i
++
)
{
MS_EXCEPTION_IF_NULL
(
branches
[
i
]);
if
(
!
branches
[
i
]
->
isa
<
AbstractFunction
>
())
{
MS_LOG
(
EXCEPTION
)
<<
op_name
<<
" requires that the 2th arg be tuple of functions, but got "
<<
branches
[
i
]
->
ToString
()
<<
" as the "
<<
i
<<
"th element."
;
}
}
auto
b
=
branches
[
0
];
for
(
size_t
i
=
1
;
i
<
branches
.
size
();
i
++
)
{
MS_EXCEPTION_IF_NULL
(
branches
[
i
]);
b
=
b
->
Join
(
branches
[
i
]);
}
return
b
;
...
...
mindspore/ops/composite/base.py
浏览文件 @
065e25e1
...
...
@@ -18,13 +18,13 @@
"""Basic composite operations."""
from
..._c_expression
import
EnvInstance_
,
GradOperation_
,
HyperMap_
,
MultitypeFuncGraph_
,
Tail_
,
TensorSlice_
,
\
TupleAdd_
,
TupleSlice_
,
UnpackCall_
,
ZipOperation_
,
ListAppend_
TupleAdd_
,
TupleSlice_
,
UnpackCall_
,
ZipOperation_
,
ListAppend_
,
TupleGetItemTensor_
from
...common
import
dtype
as
mstype
from
...common.api
import
ms_function
from
..
import
functional
as
F
from
..
import
operations
as
P
__all__
=
[
EnvInstance_
,
TensorSlice_
,
TupleAdd_
,
TupleSlice_
,
UnpackCall_
]
__all__
=
[
EnvInstance_
,
TensorSlice_
,
TupleAdd_
,
TupleSlice_
,
UnpackCall_
,
TupleGetItemTensor_
]
def
add_flags
(
fn
,
**
flags
):
...
...
mindspore/ops/composite/multitype_ops/getitem_impl.py
浏览文件 @
065e25e1
...
...
@@ -72,6 +72,28 @@ _tensor_slice = _TensorSlice('tensor_slice')
"""_tensor_slice is an metafuncgraph object which will slice a tensor."""
class
_TupleGetItemTensor
(
base
.
TupleGetItemTensor_
):
"""
Getting item of tuple by tensor index.
Inputs:
data (tuple): A tuple of items.
index (Tensor): The index in tensor.
Outputs:
Type, is same as the element type of data.
"""
def
__init__
(
self
,
name
):
base
.
TupleGetItemTensor_
.
__init__
(
self
,
name
)
def
__call__
(
self
,
*
args
):
pass
_tuple_get_item_tensor
=
_TupleGetItemTensor
(
'tuple_get_item_tensor'
)
"""_tuple_get_item_tensor is an metafuncgraph object which will select indexed item."""
@
getitem
.
register
(
"Tuple"
,
"Number"
)
def
_tuple_getitem_by_number
(
data
,
number_index
):
"""
...
...
@@ -102,6 +124,21 @@ def _tuple_getitem_by_slice(data, slice_index):
return
_tuple_slice
(
data
,
slice_index
)
@
getitem
.
register
(
"Tuple"
,
"Tensor"
)
def
_tuple_getitem_by_tensor
(
data
,
tensor_index
):
"""
Getting item out of tuple by tensor index.
Inputs:
data (tuple): A tuple of items to index.
tensor_index (Tensor): Index to select item.
Outputs:
Type, is same as the element type of data.
"""
return
_tuple_get_item_tensor
(
data
,
tensor_index
)
@
getitem
.
register
(
"List"
,
"Number"
)
def
_list_getitem_by_number
(
data
,
number_index
):
"""
...
...
tests/ut/python/ops/test_control_ops.py
浏览文件 @
065e25e1
...
...
@@ -387,7 +387,38 @@ def test_switch_layer():
ret
=
F
.
switch_layer
(
index
,
self
.
layers
)(
x
)
*
self
.
z3
return
ret
index
=
Tensor
(
0
)
net
=
SwitchLayerCell
()
net
(
1
,
Tensor
(
np
.
full
([
128
,
96
],
0.6
,
dtype
=
np
.
float32
)))
C
.
grad_by_list
(
net
,
ParameterTuple
(
net
.
trainable_params
()))(
0
,
Tensor
(
np
.
full
([
128
,
96
],
0.6
,
dtype
=
np
.
float32
)))
C
.
grad_all
(
net
)(
0
,
Tensor
(
np
.
full
([
128
,
96
],
0.6
,
dtype
=
np
.
float32
)))
net
(
index
,
Tensor
(
np
.
full
([
128
,
96
],
0.6
,
dtype
=
np
.
float32
)))
C
.
grad_by_list
(
net
,
ParameterTuple
(
net
.
trainable_params
()))(
index
,
Tensor
(
np
.
full
([
128
,
96
],
0.6
,
dtype
=
np
.
float32
)))
C
.
grad_all
(
net
)(
index
,
Tensor
(
np
.
full
([
128
,
96
],
0.6
,
dtype
=
np
.
float32
)))
def
test_index_to_switch_layer
():
class
Layer1
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
Layer1
,
self
).
__init__
()
self
.
z1
=
Parameter
(
Tensor
(
np
.
full
([
128
,
96
],
0.6
,
dtype
=
np
.
float32
)),
name
=
'z1'
)
def
construct
(
self
,
x
):
return
x
*
self
.
z1
class
Layer2
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
Layer2
,
self
).
__init__
()
self
.
z2
=
Parameter
(
Tensor
(
np
.
full
([
128
,
96
],
0.6
,
dtype
=
np
.
float32
)),
name
=
'z2'
)
def
construct
(
self
,
x
):
return
x
*
self
.
z2
class
SwitchLayerCell
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
SwitchLayerCell
,
self
).
__init__
()
self
.
layers
=
(
Layer1
(),
Layer2
())
self
.
z3
=
Parameter
(
Tensor
(
np
.
full
([
128
,
96
],
0.6
,
dtype
=
np
.
float32
)),
name
=
'z3'
)
def
construct
(
self
,
index
,
x
):
ret
=
self
.
layers
[
index
](
x
)
*
self
.
z3
return
ret
index
=
Tensor
(
0
)
net
=
SwitchLayerCell
()
net
(
index
,
Tensor
(
np
.
full
([
128
,
96
],
0.6
,
dtype
=
np
.
float32
)))
C
.
grad_by_list
(
net
,
ParameterTuple
(
net
.
trainable_params
()))(
index
,
Tensor
(
np
.
full
([
128
,
96
],
0.6
,
dtype
=
np
.
float32
)))
C
.
grad_all
(
net
)(
index
,
Tensor
(
np
.
full
([
128
,
96
],
0.6
,
dtype
=
np
.
float32
)))
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录