Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
850171a3
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
850171a3
编写于
5月 22, 2020
作者:
B
buxue
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Restrict tensor getitem or setitem not support mixed tensor.
上级
b06c8028
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
61 addition
and
28 deletion
+61
-28
mindspore/ops/composite/multitype_ops/_utils.py
mindspore/ops/composite/multitype_ops/_utils.py
+2
-2
mindspore/ops/composite/multitype_ops/getitem_impl.py
mindspore/ops/composite/multitype_ops/getitem_impl.py
+1
-1
mindspore/ops/composite/multitype_ops/setitem_impl.py
mindspore/ops/composite/multitype_ops/setitem_impl.py
+3
-3
tests/ut/python/ops/test_tensor_slice.py
tests/ut/python/ops/test_tensor_slice.py
+55
-22
未找到文件。
mindspore/ops/composite/multitype_ops/_utils.py
浏览文件 @
850171a3
...
...
@@ -254,7 +254,7 @@ def tuple_element_is_int(indexs):
@
constexpr
def
tuple_
elements_type
(
types
):
def
tuple_
index_elements_type
(
types
,
op_name
):
"""Judges the type of all elements of the tuple."""
tensors_number
=
0
for
ele
in
types
:
...
...
@@ -264,7 +264,7 @@ def tuple_elements_type(types):
return
ALL_TENSOR
if
tensors_number
==
0
:
return
NO_TENSOR
r
eturn
CONTAIN_TENSOR
r
aise
IndexError
(
f
"For '
{
op_name
}
', the index does not support mixed tensor."
)
@
constexpr
...
...
mindspore/ops/composite/multitype_ops/getitem_impl.py
浏览文件 @
850171a3
...
...
@@ -247,7 +247,7 @@ def _tensor_getitem_by_tuple(data, tuple_index):
Tensor, element type is same as the element type of data.
"""
index_types
=
multi_utils
.
hyper_map
(
F
.
typeof
,
tuple_index
)
index_elements_type
=
multi_utils
.
tuple_
elements_type
(
index_types
)
index_elements_type
=
multi_utils
.
tuple_
index_elements_type
(
index_types
,
multi_utils
.
TENSOR_GETITEM
)
result
=
None
if
index_elements_type
==
multi_utils
.
NO_TENSOR
:
result
=
_tensor_slice
(
data
,
tuple_index
)
...
...
mindspore/ops/composite/multitype_ops/setitem_impl.py
浏览文件 @
850171a3
...
...
@@ -191,7 +191,7 @@ def _tensor_setitem_by_tuple_with_number(data, tuple_index, value):
Tensor, element type and shape is same as data.
"""
index_types
=
multi_utils
.
hyper_map
(
F
.
typeof
,
tuple_index
)
index_elements_type
=
multi_utils
.
tuple_
elements_type
(
index_types
)
index_elements_type
=
multi_utils
.
tuple_
index_elements_type
(
index_types
,
multi_utils
.
TENSOR_SETITEM
)
result
=
None
if
index_elements_type
==
multi_utils
.
NO_TENSOR
:
result
=
_tensor_assgin_number
(
data
,
tuple_index
,
value
)
...
...
@@ -222,7 +222,7 @@ def _tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
Tensor, element type and shape is same as data.
"""
index_types
=
multi_utils
.
hyper_map
(
F
.
typeof
,
tuple_index
)
index_elements_type
=
multi_utils
.
tuple_
elements_type
(
index_types
)
index_elements_type
=
multi_utils
.
tuple_
index_elements_type
(
index_types
,
multi_utils
.
TENSOR_SETITEM
)
result
=
None
if
index_elements_type
==
multi_utils
.
NO_TENSOR
:
result
=
_tensor_assgin_tensor
(
data
,
tuple_index
,
value
)
...
...
@@ -254,7 +254,7 @@ def _tensor_setitem_by_tuple_with_tuple(data, tuple_index, value):
Tensor, element type and shape is same as data.
"""
index_types
=
multi_utils
.
hyper_map
(
F
.
typeof
,
tuple_index
)
index_elements_type
=
multi_utils
.
tuple_
elements_type
(
index_types
)
index_elements_type
=
multi_utils
.
tuple_
index_elements_type
(
index_types
,
multi_utils
.
TENSOR_SETITEM
)
result
=
None
if
index_elements_type
==
multi_utils
.
ALL_TENSOR
:
indices
=
multi_utils
.
generate_indeices_from_tuple_of_tensor
(
data
,
tuple_index
,
multi_utils
.
TENSOR_SETITEM
)
...
...
tests/ut/python/ops/test_tensor_slice.py
浏览文件 @
850171a3
...
...
@@ -146,9 +146,9 @@ class TensorAssignWithSlice(Cell):
return
z
class
Tensor
Index
ByOneTensor
(
Cell
):
class
Tensor
GetItem
ByOneTensor
(
Cell
):
def
__init__
(
self
):
super
(
Tensor
Index
ByOneTensor
,
self
).
__init__
()
super
(
Tensor
GetItem
ByOneTensor
,
self
).
__init__
()
self
.
const
=
Tensor
(
np
.
ones
((
5
,
4
,
7
,
8
)),
mstype
.
int32
)
def
construct
(
self
,
x
,
index
):
...
...
@@ -156,9 +156,9 @@ class TensorIndexByOneTensor(Cell):
return
ret
class
Tensor
Index
ByTwoTensors
(
Cell
):
class
Tensor
GetItem
ByTwoTensors
(
Cell
):
def
__init__
(
self
):
super
(
Tensor
Index
ByTwoTensors
,
self
).
__init__
()
super
(
Tensor
GetItem
ByTwoTensors
,
self
).
__init__
()
self
.
const
=
Tensor
(
np
.
ones
((
3
,
4
,
5
,
8
)),
mstype
.
int32
)
def
construct
(
self
,
x
,
index_0
,
index_1
):
...
...
@@ -166,9 +166,9 @@ class TensorIndexByTwoTensors(Cell):
return
ret
class
Tensor
Index
ByThreeTensors
(
Cell
):
class
Tensor
GetItem
ByThreeTensors
(
Cell
):
def
__init__
(
self
):
super
(
Tensor
Index
ByThreeTensors
,
self
).
__init__
()
super
(
Tensor
GetItem
ByThreeTensors
,
self
).
__init__
()
self
.
const
=
Tensor
(
np
.
ones
((
5
,
3
,
4
,
5
)),
mstype
.
int32
)
def
construct
(
self
,
x
,
index_0
,
index_1
,
index_2
):
...
...
@@ -176,6 +176,15 @@ class TensorIndexByThreeTensors(Cell):
return
ret
class
TensorGetItemByMixedTensors
(
Cell
):
def
__init__
(
self
):
super
(
TensorGetItemByMixedTensors
,
self
).
__init__
()
def
construct
(
self
,
x
,
index_0
,
index_1
):
ret
=
x
[
index_0
,
index_1
,
0
:
6
]
return
ret
class
TensorSetItemByOneTensorWithNumber
(
Cell
):
def
__init__
(
self
,
value
):
super
(
TensorSetItemByOneTensorWithNumber
,
self
).
__init__
()
...
...
@@ -300,6 +309,19 @@ class TensorSetItemByTensorsWithTupleOfTensorNumberError(Cell):
return
ret
class
TensorSetItemByMixedTensors
(
Cell
):
def
__init__
(
self
):
super
(
TensorSetItemByMixedTensors
,
self
).
__init__
()
self
.
const
=
Tensor
(
np
.
ones
((
6
,
7
,
8
)),
mstype
.
float32
)
self
.
param
=
Parameter
(
Tensor
(
np
.
arange
(
6
*
7
*
8
).
reshape
((
6
,
7
,
8
)),
mstype
.
float32
),
name
=
"x"
)
self
.
value
=
99.0
def
construct
(
self
,
index_0
,
index_1
):
self
.
param
[
index_0
,
index_1
,
0
:
6
]
=
self
.
value
ret
=
self
.
param
+
self
.
const
return
ret
def
test_tensor_assign
():
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
save_graphs
=
True
)
net
=
TensorAssignWithSlice
()
...
...
@@ -596,19 +618,19 @@ test_cases = [
'block'
:
NetWorkSliceEllipsis
(),
'desc_inputs'
:
[
Tensor
(
np
.
ones
([
6
,
7
,
8
,
9
],
np
.
int32
))],
}),
(
'Tensor
Index
ByOneTensor'
,
{
'block'
:
Tensor
Index
ByOneTensor
(),
(
'Tensor
GetItem
ByOneTensor'
,
{
'block'
:
Tensor
GetItem
ByOneTensor
(),
'desc_inputs'
:
[
Tensor
(
np
.
arange
(
6
*
7
*
8
).
reshape
((
6
,
7
,
8
)),
mstype
.
int32
),
Tensor
(
np
.
random
.
randint
(
6
,
size
=
(
5
,
4
)),
mstype
.
int32
)],
}),
(
'Tensor
Index
ByTwoTensors'
,
{
'block'
:
Tensor
Index
ByTwoTensors
(),
(
'Tensor
GetItem
ByTwoTensors'
,
{
'block'
:
Tensor
GetItem
ByTwoTensors
(),
'desc_inputs'
:
[
Tensor
(
np
.
arange
(
6
*
7
*
8
).
reshape
((
6
,
7
,
8
)),
mstype
.
int32
),
Tensor
(
np
.
random
.
randint
(
6
,
size
=
(
3
,
4
,
5
)),
mstype
.
int32
),
Tensor
(
np
.
random
.
randint
(
7
,
size
=
(
4
,
5
)),
mstype
.
int32
)],
}),
(
'Tensor
Index
ByThreeTensors'
,
{
'block'
:
Tensor
Index
ByThreeTensors
(),
(
'Tensor
GetItem
ByThreeTensors'
,
{
'block'
:
Tensor
GetItem
ByThreeTensors
(),
'desc_inputs'
:
[
Tensor
(
np
.
arange
(
6
*
7
*
8
).
reshape
((
6
,
7
,
8
)),
mstype
.
int32
),
Tensor
(
np
.
random
.
randint
(
6
,
size
=
(
3
,
4
,
5
)),
mstype
.
int32
),
Tensor
(
np
.
random
.
randint
(
7
,
size
=
(
4
,
5
)),
mstype
.
int32
),
...
...
@@ -665,37 +687,43 @@ test_cases = [
]
raise_error_set
=
[
(
'Tensor
Index
ByOneTensorDtypeError'
,
{
'block'
:
(
Tensor
Index
ByOneTensor
(),
{
'exception'
:
TypeError
}),
(
'Tensor
GetItem
ByOneTensorDtypeError'
,
{
'block'
:
(
Tensor
GetItem
ByOneTensor
(),
{
'exception'
:
TypeError
}),
'desc_inputs'
:
[
Tensor
(
np
.
arange
(
6
*
7
*
8
).
reshape
((
6
,
7
,
8
)),
mstype
.
int32
),
Tensor
(
np
.
random
.
randint
(
6
,
size
=
(
5
,
4
)),
mstype
.
int8
)],
}),
(
'Tensor
Index
ByTwoTensorsShapeError'
,
{
'block'
:
(
Tensor
Index
ByTwoTensors
(),
{
'exception'
:
ValueError
}),
(
'Tensor
GetItem
ByTwoTensorsShapeError'
,
{
'block'
:
(
Tensor
GetItem
ByTwoTensors
(),
{
'exception'
:
ValueError
}),
'desc_inputs'
:
[
Tensor
(
np
.
arange
(
6
*
7
*
8
).
reshape
((
6
,
7
,
8
)),
mstype
.
int32
),
Tensor
(
np
.
random
.
randint
(
6
,
size
=
(
3
,
4
,
5
)),
mstype
.
int32
),
Tensor
(
np
.
random
.
randint
(
7
,
size
=
(
2
,
3
,
5
)),
mstype
.
int32
)],
}),
(
'Tensor
Index
ByTwoTensorsDtypeError'
,
{
'block'
:
(
Tensor
Index
ByTwoTensors
(),
{
'exception'
:
TypeError
}),
(
'Tensor
GetItem
ByTwoTensorsDtypeError'
,
{
'block'
:
(
Tensor
GetItem
ByTwoTensors
(),
{
'exception'
:
TypeError
}),
'desc_inputs'
:
[
Tensor
(
np
.
arange
(
6
*
7
*
8
).
reshape
((
6
,
7
,
8
)),
mstype
.
int32
),
Tensor
(
np
.
random
.
randint
(
6
,
size
=
(
3
,
4
,
5
)),
mstype
.
int32
),
Tensor
(
np
.
random
.
randint
(
7
,
size
=
(
4
,
5
)),
mstype
.
float32
)],
}),
(
'Tensor
Index
ByThreeTensorsShapeError'
,
{
'block'
:
(
Tensor
Index
ByThreeTensors
(),
{
'exception'
:
ValueError
}),
(
'Tensor
GetItem
ByThreeTensorsShapeError'
,
{
'block'
:
(
Tensor
GetItem
ByThreeTensors
(),
{
'exception'
:
ValueError
}),
'desc_inputs'
:
[
Tensor
(
np
.
arange
(
6
*
7
*
8
).
reshape
((
6
,
7
,
8
)),
mstype
.
int32
),
Tensor
(
np
.
random
.
randint
(
6
,
size
=
(
3
,
4
,
5
)),
mstype
.
int32
),
Tensor
(
np
.
random
.
randint
(
7
,
size
=
(
3
,
4
,
5
)),
mstype
.
int32
),
Tensor
(
np
.
random
.
randint
(
8
,
size
=
(
5
,
2
,
4
,
5
)),
mstype
.
int32
)],
}),
(
'Tensor
Index
ByThreeTensorsDtypeError'
,
{
'block'
:
(
Tensor
Index
ByThreeTensors
(),
{
'exception'
:
TypeError
}),
(
'Tensor
GetItem
ByThreeTensorsDtypeError'
,
{
'block'
:
(
Tensor
GetItem
ByThreeTensors
(),
{
'exception'
:
TypeError
}),
'desc_inputs'
:
[
Tensor
(
np
.
arange
(
6
*
7
*
8
).
reshape
((
6
,
7
,
8
)),
mstype
.
int32
),
Tensor
(
np
.
random
.
randint
(
6
,
size
=
(
3
,
4
,
5
)),
mstype
.
int32
),
Tensor
(
np
.
random
.
randint
(
7
,
size
=
(
3
,
4
,
5
)),
mstype
.
int64
),
Tensor
(
np
.
random
.
randint
(
8
,
size
=
(
5
,
3
,
4
,
5
)),
mstype
.
int32
)],
}),
(
'TensorGetItemByMixedTensors'
,
{
'block'
:
(
TensorGetItemByMixedTensors
(),
{
'exception'
:
IndexError
}),
'desc_inputs'
:
[
Tensor
(
np
.
arange
(
6
*
7
*
8
).
reshape
((
6
,
7
,
8
)),
mstype
.
int32
),
Tensor
(
np
.
random
.
randint
(
6
,
size
=
(
3
,
4
,
5
)),
mstype
.
int32
),
Tensor
(
np
.
random
.
randint
(
7
,
size
=
(
3
,
4
,
5
)),
mstype
.
int64
)],
}),
(
'TensorSetItemByOneTensorWithNumberTypeError'
,
{
'block'
:
(
TensorSetItemByOneTensorWithNumber
(
value
=
0
),
{
'exception'
:
TypeError
}),
'desc_inputs'
:
[
Tensor
(
np
.
random
.
randint
(
4
,
size
=
(
5
,
4
)),
mstype
.
int32
)],
...
...
@@ -781,6 +809,11 @@ raise_error_set = [
Tensor
(
np
.
zeros
((
4
,
5
)),
mstype
.
float32
),
Tensor
(
np
.
ones
((
4
,
5
)),
mstype
.
int32
),
Tensor
(
np
.
ones
((
4
,
5
))
*
2
,
mstype
.
int32
)],
}),
(
'TensorSetItemByMixedTensors'
,
{
'block'
:
(
TensorSetItemByMixedTensors
(),
{
'exception'
:
IndexError
}),
'desc_inputs'
:
[
Tensor
(
np
.
random
.
randint
(
6
,
size
=
(
3
,
4
,
5
)),
mstype
.
int32
),
Tensor
(
np
.
random
.
randint
(
7
,
size
=
(
4
,
5
)),
mstype
.
int32
)],
})
]
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录