Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
31aae361
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
31aae361
编写于
4月 23, 2020
作者:
C
candanzg
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Tensor assign with integer
Signed-off-by:
N
candanzg
<
zhangshucheng@huawei.com
>
上级
496ffff3
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
183 addition
and
88 deletion
+183
-88
mindspore/ops/composite/multitype_ops/_multitype_ops_util.py
mindspore/ops/composite/multitype_ops/_multitype_ops_util.py
+47
-26
mindspore/ops/composite/multitype_ops/setitem_impl.py
mindspore/ops/composite/multitype_ops/setitem_impl.py
+85
-61
tests/ut/python/ops/test_tensor_slice.py
tests/ut/python/ops/test_tensor_slice.py
+51
-1
未找到文件。
mindspore/ops/composite/multitype_ops/_multitype_ops_util.py
浏览文件 @
31aae361
...
...
@@ -15,6 +15,7 @@
"""constexpr util"""
from
functools
import
reduce
import
numpy
as
np
from
...primitive
import
constexpr
from
....common.tensor
import
Tensor
...
...
@@ -23,26 +24,27 @@ from ...._extends.utils import Slice
@
constexpr
def
check_equal
(
param1
,
param2
,
msg
=
"{},{}"
):
"""Checks whether the two parameters are equal or not."""
if
param1
!=
param2
:
raise
ValueError
(
msg
.
format
(
param1
,
param2
))
return
param1
@
constexpr
def
check_tensor_setitem_index
(
index
,
element_type
=
None
):
"""Check tuple index type of tensor assignment."""
"""Check
s
tuple index type of tensor assignment."""
if
index
is
None
:
raise
ValueError
(
"Tensor's index cannot be None."
)
# eg. Tensor[Slice] = u
if
isinstance
(
index
,
Slice
):
return
True
# eg. Tensor[
T
uple] = u
# eg. Tensor[
t
uple] = u
if
isinstance
(
index
,
tuple
):
if
not
index
:
raise
ValueError
(
"Tensor's index cannot be empty."
)
# eg. Tensor[
T
uple(Slice...)] = u
if
not
isinstance
(
index
[
0
],
Slice
):
r
aise
ValueError
(
"Index of type '{}' is not supported yet."
.
format
(
type
(
index
[
0
])))
r
eturn
True
# eg. Tensor[
t
uple(Slice...)] = u
if
isinstance
(
index
[
0
],
(
Slice
,
int
)
):
r
eturn
True
r
aise
ValueError
(
"Index of type '{}' is not supported yet."
.
format
(
type
(
index
[
0
])))
# eg. Tensor[Tensor[dtype=bool]] = u
if
index
==
mstype
.
tensor
:
if
element_type
is
None
or
element_type
!=
mstype
.
bool_
:
...
...
@@ -57,7 +59,7 @@ def check_tensor_setitem_index(index, element_type=None):
@
constexpr
def
is_same_type
(
inst
,
type_
):
"""
Check whether an object is an instance of a target type.
Check
s
whether an object is an instance of a target type.
Inputs:
inst (mindspore.dtype): Inspected type.
...
...
@@ -69,34 +71,23 @@ def is_same_type(inst, type_):
return
inst
==
type_
@
constexpr
def
error_msg
(
msg
=
""
,
format_values
=
""
):
"""
Used to throw exception information.
Inputs:
msg (str): information content.
"""
raise
ValueError
(
msg
.
format
(
*
format_values
))
def
slice_expand
(
input_slices
,
shape
):
"""
Convert slice to indices.
Convert
s
slice to indices.
Inputs:
slices (
List or Tuple(List, ...)
): Slice tuple or slice.
shape (
T
uple): The shape of a sensor is an integer element tuple.
slices (
Union[Slice, tuple[Slice]]
): Slice tuple or slice.
shape (
t
uple): The shape of a sensor is an integer element tuple.
Outputs:
(List, List, List)
, This is expressed as (begins, ends, strides).
tuple[list]
, This is expressed as (begins, ends, strides).
"""
begin
=
[]
end
=
[]
strides
=
[]
index
=
0
slices
=
None
# Slice or
T
uple(Slice...)
# Slice or
t
uple(Slice...)
if
isinstance
(
input_slices
,
Slice
):
slices
=
(
input_slices
,)
elif
isinstance
(
input_slices
,
(
tuple
,
list
))
and
input_slices
and
isinstance
(
input_slices
[
0
],
Slice
):
...
...
@@ -119,14 +110,15 @@ def slice_expand(input_slices, shape):
index
+=
1
return
begin
,
end
,
strides
@
constexpr
def
slice2indices
(
input_slices
,
shape
):
"""
Convert slice to indices.
Convert
s
slice to indices.
Inputs:
slices (
List or Tuple(List, ...)
): Slice tuple or slice.
shape (
Tuple): The shape of a s
ensor is an integer element tuple.
slices (
Union[Slice, tuple[Slice]]
): Slice tuple or slice.
shape (
tuple): The shape of a t
ensor is an integer element tuple.
Outputs:
Tensor, the shape is (n, 1).
...
...
@@ -145,6 +137,7 @@ def slice2indices(input_slices, shape):
@
constexpr
def
check_indices
(
indices_size
,
index
):
"""Checks indices whether is empty."""
if
indices_size
<
1
:
raise
ValueError
(
"The tensor's index is unreasonable. index:{}"
.
format
(
index
))
return
indices_size
...
...
@@ -152,6 +145,7 @@ def check_indices(indices_size, index):
@
constexpr
def
check_indices_value_size
(
indices_size
,
value_size
):
"""Checks if the sizes are already matched."""
if
value_size
<
1
:
raise
ValueError
(
"The value assigned to tensor cannot be empty."
)
if
value_size
>
1
:
...
...
@@ -160,3 +154,30 @@ def check_indices_value_size(indices_size, value_size):
"The value given to tensor does not match the index size.
\
value size:{}, indics size:{}"
.
format
(
value_size
,
indices_size
))
return
value_size
@
constexpr
def
integer_to_indices
(
index
,
shape
):
"""Converts int or tuple[int] to indices."""
size
=
reduce
(
lambda
x
,
y
:
x
*
y
,
shape
)
range_
=
np
.
arange
(
size
).
reshape
(
shape
)
value
=
range_
[
index
]
value
=
value
.
reshape
(
-
1
,
1
)
return
Tensor
(
value
,
dtype
=
mstype
.
int32
)
@
constexpr
def
tuple_element_is_slice
(
indexs
):
"""Judges tuple element type."""
if
not
indexs
:
raise
ValueError
(
"Tensor's index cannot be empty."
)
if
isinstance
(
indexs
,
tuple
)
and
isinstance
(
indexs
[
0
],
Slice
):
return
True
return
False
@
constexpr
def
tuple_element_is_int
(
indexs
):
"""Judges tuple element type."""
if
not
indexs
:
raise
ValueError
(
"Tensor's index cannot be empty."
)
if
isinstance
(
indexs
,
tuple
)
and
isinstance
(
indexs
[
0
],
int
):
return
True
return
False
mindspore/ops/composite/multitype_ops/setitem_impl.py
浏览文件 @
31aae361
...
...
@@ -25,15 +25,14 @@ setitem = base.MultitypeFuncGraph('setitem')
@
setitem
.
register
(
"List"
,
"Number"
,
"String"
)
def
_list_setitem_with_string
(
data
,
number_index
,
value
):
"""
Assign value to list.
Assign
s
value to list.
Inputs:
data (list): Data of type lis.
number_index (Number): Index of data.
value (String): Value given.
Outputs:
L
ist, type is same as the element type of data.
l
ist, type is same as the element type of data.
"""
return
F
.
list_setitem
(
data
,
number_index
,
value
)
...
...
@@ -41,7 +40,7 @@ def _list_setitem_with_string(data, number_index, value):
@
setitem
.
register
(
"List"
,
"Number"
,
"Number"
)
def
_list_setitem_with_number
(
data
,
number_index
,
value
):
"""
Assign value to list.
Assign
s
value to list.
Inputs:
data (list): Data of type lis.
...
...
@@ -49,7 +48,7 @@ def _list_setitem_with_number(data, number_index, value):
value (Number): Value given.
Outputs:
L
ist, type is same as the element type of data.
l
ist, type is same as the element type of data.
"""
return
F
.
list_setitem
(
data
,
number_index
,
value
)
...
...
@@ -57,7 +56,7 @@ def _list_setitem_with_number(data, number_index, value):
@
setitem
.
register
(
"List"
,
"Number"
,
"Tensor"
)
def
_list_setitem_with_Tensor
(
data
,
number_index
,
value
):
"""
Assign value to list.
Assign
s
value to list.
Inputs:
data (list): Data of type lis.
...
...
@@ -65,7 +64,7 @@ def _list_setitem_with_Tensor(data, number_index, value):
value (Tensor): Value given.
Outputs:
L
ist, type is same as the element type of data.
l
ist, type is same as the element type of data.
"""
return
F
.
list_setitem
(
data
,
number_index
,
value
)
...
...
@@ -73,15 +72,15 @@ def _list_setitem_with_Tensor(data, number_index, value):
@
setitem
.
register
(
"List"
,
"Number"
,
"List"
)
def
_list_setitem_with_List
(
data
,
number_index
,
value
):
"""
Assign value to list.
Assign
s
value to list.
Inputs:
data (list): Data of type lis.
number_index (Number): Index of data.
value (
L
ist): Value given.
value (
l
ist): Value given.
Outputs:
L
ist, type is same as the element type of data.
l
ist, type is same as the element type of data.
"""
return
F
.
list_setitem
(
data
,
number_index
,
value
)
...
...
@@ -89,15 +88,15 @@ def _list_setitem_with_List(data, number_index, value):
@
setitem
.
register
(
"Dictionary"
,
"String"
,
"Tensor"
)
def
_dict_setitem_with_tensor
(
data
,
key
,
value
):
"""
Assign value to dictionary.
Assign
s
value to dictionary.
Inputs:
data (
Dictionary
): Data of type dict.
data (
dict
): Data of type dict.
key (str): Key of the data.
value (Tensor): Value given.
Outputs:
D
ict, type is as same as the element type of data.
d
ict, type is as same as the element type of data.
"""
return
F
.
dict_setitem
(
data
,
key
,
value
)
...
...
@@ -105,15 +104,15 @@ def _dict_setitem_with_tensor(data, key, value):
@
setitem
.
register
(
"Dictionary"
,
"String"
,
"Number"
)
def
_dict_setitem_with_number
(
data
,
key
,
value
):
"""
Assign value to dictionary.
Assign
s
value to dictionary.
Inputs:
data (
Dictionary
): Data of type dict.
data (
dict
): Data of type dict.
key (str): Key of the data.
value (Number): Value given.
Outputs:
D
ict, type is as same as the element type of data.
d
ict, type is as same as the element type of data.
"""
return
F
.
dict_setitem
(
data
,
key
,
value
)
...
...
@@ -219,14 +218,14 @@ def _tensor_setitem_with_slice_v4(data, input_slice, value):
Tensor assignment.
Note:
Syntax support: A[
Slice
] = U
Syntax support: A[
tuple(Slice)] = U, and A[tuple(Number)
] = U
Restraint condition: A is a Tensor
Slice like "1:3, ::, :4:-1"
U is a Tensor(size=1) or Tensor(size>1)
Inputs:
data (Tensor): Assigned tensor.
input_slice (
Tuple(Slice)
): Slice expression.
input_slice (
Union[tuple[Slice], tuple[Number]]
): Slice expression.
value (Number): Assignment value.
Outputs:
...
...
@@ -236,39 +235,43 @@ def _tensor_setitem_with_slice_v4(data, input_slice, value):
def
_tensor_assgin_tensor
(
data
,
input_slice
,
value
):
"""Given a tensor value assign to tensor by slice"""
# 1. condition
"""Assigns a tensor value to the tensor by slice."""
result
=
None
check_result
=
mult_util
.
check_tensor_setitem_index
(
input_slice
)
if
check_result
:
data_shape
=
F
.
shape
(
data
)
data_size
=
F
.
size
(
data
)
data_dtype
=
F
.
dtype
(
data
)
indices
=
mult_util
.
slice2indices
(
input_slice
,
data_shape
)
indices_size
=
F
.
size
(
indices
)
indices_size
=
mult_util
.
check_indices
(
indices_size
,
input_slice
)
update
=
F
.
fill
(
data_dtype
,
(
indices_size
,),
1
)
condition_1d
=
F
.
scatter_nd
(
indices
,
update
,
(
data_size
,))
condition_1d
=
F
.
cast
(
condition_1d
,
mstype
.
bool_
)
condition
=
F
.
reshape
(
condition_1d
,
data_shape
)
# 2. u
value_fill
=
None
value_size
=
F
.
size
(
value
)
value_size
=
mult_util
.
check_indices_value_size
(
indices_size
,
value_size
)
if
value_size
==
1
:
value_fill
=
F
.
fill
(
data_dtype
,
(
indices_size
,),
1
)
value
=
F
.
cast
(
value
,
data_dtype
)
value_fill
=
F
.
tensor_mul
(
value_fill
,
value
)
elif
value_size
>
1
:
value_fill
=
F
.
reshape
(
value
,
(
indices_size
,))
value_1d
=
F
.
scatter_nd
(
indices
,
value_fill
,
(
data_size
,))
u
=
F
.
reshape
(
value_1d
,
data_shape
)
# A[slice]= u -> A[B]=U -> select(B, U, A)
result
=
F
.
select
(
condition
,
u
,
data
)
is_tuple_int
=
mult_util
.
tuple_element_is_int
(
input_slice
)
if
is_tuple_int
:
indices
=
mult_util
.
integer_to_indices
(
input_slice
,
data_shape
)
result
=
_tensor_indices_tensor
(
data
,
data_shape
,
input_slice
,
indices
,
value
)
return
result
def
_tensor_indices_tensor
(
data
,
data_shape
,
index
,
indices
,
value
):
"""Assigns a tensor value to the tensor."""
data_size
=
F
.
size
(
data
)
data_dtype
=
F
.
dtype
(
data
)
indices_size
=
F
.
size
(
indices
)
indices_size
=
mult_util
.
check_indices
(
indices_size
,
index
)
update
=
F
.
fill
(
data_dtype
,
(
indices_size
,),
1
)
condition_1d
=
F
.
scatter_nd
(
indices
,
update
,
(
data_size
,))
condition_1d
=
F
.
cast
(
condition_1d
,
mstype
.
bool_
)
condition
=
F
.
reshape
(
condition_1d
,
data_shape
)
value_fill
=
None
value_size
=
F
.
size
(
value
)
value_size
=
mult_util
.
check_indices_value_size
(
indices_size
,
value_size
)
if
value_size
==
1
:
value_fill
=
F
.
fill
(
data_dtype
,
(
indices_size
,),
1
)
value
=
F
.
cast
(
value
,
data_dtype
)
value_fill
=
F
.
tensor_mul
(
value_fill
,
value
)
elif
value_size
>
1
:
value_fill
=
F
.
reshape
(
value
,
(
indices_size
,))
value_1d
=
F
.
scatter_nd
(
indices
,
value_fill
,
(
data_size
,))
u
=
F
.
reshape
(
value_1d
,
data_shape
)
return
F
.
select
(
condition
,
u
,
data
)
@
setitem
.
register
(
"Tensor"
,
"Slice"
,
"Number"
)
def
_tensor_setitem_with_slice_v1
(
data
,
input_slice
,
value
):
"""
...
...
@@ -297,14 +300,14 @@ def _tensor_setitem_with_slice_v2(data, input_slice, value):
Tensor assignment.
Note:
Syntax support: A[
Slice
] = u
Syntax support: A[
tuple(Slice)] = u, and A[tuple(Number)
] = u
Restraint condition: A is a Tensor.
Slice like "1:3, ::, :4:-1"
u is a scalar
Inputs:
data (Tensor): Assigned tensor.
input_slice (
Tuple(Slice)
): slice expression.
input_slice (
Union[tuple[Slice], tuple[Number]]
): slice expression.
value (Number): Assignment value.
Outputs:
...
...
@@ -314,25 +317,46 @@ def _tensor_setitem_with_slice_v2(data, input_slice, value):
def
_tensor_assgin_number
(
data
,
input_slice
,
value
):
"""Given a scalar assign to tensor by slice"""
# 1. condition
"""Givens a scalar assign to tensor by slice"""
check_result
=
mult_util
.
check_tensor_setitem_index
(
input_slice
)
result
=
None
if
check_result
:
data_shape
=
F
.
shape
(
data
)
data_size
=
F
.
size
(
data
)
data_dtype
=
F
.
dtype
(
data
)
indices
=
mult_util
.
slice2indices
(
input_slice
,
data_shape
)
indices_size
=
F
.
size
(
indices
)
indices_size
=
mult_util
.
check_indices
(
indices_size
,
input_slice
)
update
=
F
.
fill
(
data_dtype
,
(
indices_size
,),
1
)
condition_1d
=
F
.
scatter_nd
(
indices
,
update
,
(
data_size
,))
condition_1d
=
F
.
cast
(
condition_1d
,
mstype
.
bool_
)
condition
=
F
.
reshape
(
condition_1d
,
data_shape
)
# 2. u
value_fill
=
F
.
fill
(
data_dtype
,
(
indices_size
,),
value
)
value_1d
=
F
.
scatter_nd
(
indices
,
value_fill
,
(
data_size
,))
u
=
F
.
reshape
(
value_1d
,
data_shape
)
# A[slice]= u -> A[B]=U -> select(B, U, A)
result
=
F
.
select
(
condition
,
u
,
data
)
is_tuple_int
=
mult_util
.
tuple_element_is_int
(
input_slice
)
if
is_tuple_int
:
indices
=
mult_util
.
integer_to_indices
(
input_slice
,
data_shape
)
result
=
_tensor_indices_number
(
data
,
data_shape
,
input_slice
,
indices
,
value
)
return
result
def
_tensor_indices_number
(
data
,
data_shape
,
index
,
indices
,
value
):
"""Assigns a scalar value to the tensor."""
data_size
=
F
.
size
(
data
)
data_dtype
=
F
.
dtype
(
data
)
indices_size
=
F
.
size
(
indices
)
indices_size
=
mult_util
.
check_indices
(
indices_size
,
index
)
update
=
F
.
fill
(
data_dtype
,
(
indices_size
,),
1
)
condition_1d
=
F
.
scatter_nd
(
indices
,
update
,
(
data_size
,))
condition_1d
=
F
.
cast
(
condition_1d
,
mstype
.
bool_
)
condition
=
F
.
reshape
(
condition_1d
,
data_shape
)
value_fill
=
F
.
fill
(
data_dtype
,
(
indices_size
,),
value
)
value_1d
=
F
.
scatter_nd
(
indices
,
value_fill
,
(
data_size
,))
u
=
F
.
reshape
(
value_1d
,
data_shape
)
return
F
.
select
(
condition
,
u
,
data
)
@
setitem
.
register
(
"Tensor"
,
"Number"
,
"Number"
)
def
_tensor_setitem_with_int_v1
(
data
,
index
,
value
):
"""Syntax: A[1] = 3"""
data_shape
=
F
.
shape
(
data
)
indices
=
mult_util
.
integer_to_indices
(
index
,
data_shape
)
return
_tensor_indices_number
(
data
,
data_shape
,
index
,
indices
,
value
)
@
setitem
.
register
(
"Tensor"
,
"Number"
,
"Tensor"
)
def
_tensor_setitem_with_int_v2
(
data
,
index
,
value
):
"""Syntax: A[1] = Tensor"""
data_shape
=
F
.
shape
(
data
)
indices
=
mult_util
.
integer_to_indices
(
index
,
data_shape
)
return
_tensor_indices_tensor
(
data
,
data_shape
,
index
,
indices
,
value
)
tests/ut/python/ops/test_tensor_slice.py
浏览文件 @
31aae361
...
...
@@ -138,7 +138,7 @@ class TensorAssignWithSlice(Cell):
z
=
a
return
z
def
test_tensor_assign
_with_slice
():
def
test_tensor_assign
():
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
save_graphs
=
True
)
net
=
TensorAssignWithSlice
()
net2
=
TensorAssignWithSlice2
()
...
...
@@ -147,6 +147,7 @@ def test_tensor_assign_with_slice():
a
=
np
.
arange
(
60
).
reshape
(
3
,
4
,
5
)
b
=
Tensor
([
1
])
Ta
=
Tensor
(
a
)
Ta4d
=
Tensor
(
a
.
reshape
(
1
,
3
,
4
,
5
))
Tb
=
Tensor
([
1
,
3
])
Tc
=
Tensor
([])
t
=
Tensor
([
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
])
...
...
@@ -184,6 +185,47 @@ def test_tensor_assign_with_slice():
with
pytest
.
raises
(
ValueError
):
net_e1
(
Ta
,
2
)
net
=
TensorAssignWithInteger
()
# Error for A[Number] = scalar/Tensor
# 1. A[Number] = U, U is a Tensor, u.size not match
with
pytest
.
raises
(
ValueError
):
net
(
Ta
,
Tb
)
with
pytest
.
raises
(
ValueError
):
net
(
Ta
,
Tc
)
# 2. A[Number] = U, the number index error
with
pytest
.
raises
(
IndexError
):
net
(
Ta4d
,
b
)
# Error for A[(n,m)] = scalar/Tensor
# 1. A[(n,m)] = U, U is a tensor. u.size not match
net
=
TensorAssignWithTupleInteger
()
with
pytest
.
raises
(
ValueError
):
net
(
Ta
,
Tc
)
with
pytest
.
raises
(
ValueError
):
net
(
Ta
,
Tb
)
# 2. A[(n,m)] = U, the number index error
with
pytest
.
raises
(
IndexError
):
net
(
Ta4d
,
b
)
class
TensorAssignWithInteger
(
Cell
):
def
__init__
(
self
):
super
(
TensorAssignWithInteger
,
self
).
__init__
()
def
construct
(
self
,
a
,
b
):
a
[
1
]
=
1
a
[
0
]
=
b
return
a
class
TensorAssignWithTupleInteger
(
Cell
):
def
__init__
(
self
):
super
(
TensorAssignWithTupleInteger
,
self
).
__init__
()
def
construct
(
self
,
a
,
b
):
a
[(
1
)]
=
1
a
[(
1
)]
=
b
a
[(
1
,
1
)]
=
b
a
[(
1
,
1
)]
=
1
return
a
class
TensorAssignWithBoolTensorIndex
(
Cell
):
def
__init__
(
self
):
...
...
@@ -273,6 +315,14 @@ def test_tensor_assign_bool_index():
net4
(
Ta
,
u_scalar
)
test_cases
=
[
(
'TensorAssignWithTupleInteger'
,
{
'block'
:
TensorAssignWithTupleInteger
(),
'desc_inputs'
:
[
Ta
,
u_tensor
],
}),
(
'TensorAssignWithInteger'
,
{
'block'
:
TensorAssignWithInteger
(),
'desc_inputs'
:
[
Ta
,
u_tensor
],
}),
(
'TensorAssignWithSlice'
,
{
'block'
:
TensorAssignWithSlice
(),
'desc_inputs'
:
[
Ta
,
u_tensor
],
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录