Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
31aae361
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
31aae361
编写于
4月 23, 2020
作者:
C
candanzg
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Tensor assign with integer
Signed-off-by:
N
candanzg
<
zhangshucheng@huawei.com
>
上级
496ffff3
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
183 addition
and
88 deletion
+183
-88
mindspore/ops/composite/multitype_ops/_multitype_ops_util.py
mindspore/ops/composite/multitype_ops/_multitype_ops_util.py
+47
-26
mindspore/ops/composite/multitype_ops/setitem_impl.py
mindspore/ops/composite/multitype_ops/setitem_impl.py
+85
-61
tests/ut/python/ops/test_tensor_slice.py
tests/ut/python/ops/test_tensor_slice.py
+51
-1
未找到文件。
mindspore/ops/composite/multitype_ops/_multitype_ops_util.py
浏览文件 @
31aae361
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
"""constexpr util"""
"""constexpr util"""
from
functools
import
reduce
import
numpy
as
np
import
numpy
as
np
from
...primitive
import
constexpr
from
...primitive
import
constexpr
from
....common.tensor
import
Tensor
from
....common.tensor
import
Tensor
...
@@ -23,26 +24,27 @@ from ...._extends.utils import Slice
...
@@ -23,26 +24,27 @@ from ...._extends.utils import Slice
@
constexpr
@
constexpr
def
check_equal
(
param1
,
param2
,
msg
=
"{},{}"
):
def
check_equal
(
param1
,
param2
,
msg
=
"{},{}"
):
"""Checks whether the two parameters are equal or not."""
if
param1
!=
param2
:
if
param1
!=
param2
:
raise
ValueError
(
msg
.
format
(
param1
,
param2
))
raise
ValueError
(
msg
.
format
(
param1
,
param2
))
return
param1
return
param1
@
constexpr
@
constexpr
def
check_tensor_setitem_index
(
index
,
element_type
=
None
):
def
check_tensor_setitem_index
(
index
,
element_type
=
None
):
"""Check tuple index type of tensor assignment."""
"""Check
s
tuple index type of tensor assignment."""
if
index
is
None
:
if
index
is
None
:
raise
ValueError
(
"Tensor's index cannot be None."
)
raise
ValueError
(
"Tensor's index cannot be None."
)
# eg. Tensor[Slice] = u
# eg. Tensor[Slice] = u
if
isinstance
(
index
,
Slice
):
if
isinstance
(
index
,
Slice
):
return
True
return
True
# eg. Tensor[
T
uple] = u
# eg. Tensor[
t
uple] = u
if
isinstance
(
index
,
tuple
):
if
isinstance
(
index
,
tuple
):
if
not
index
:
if
not
index
:
raise
ValueError
(
"Tensor's index cannot be empty."
)
raise
ValueError
(
"Tensor's index cannot be empty."
)
# eg. Tensor[
T
uple(Slice...)] = u
# eg. Tensor[
t
uple(Slice...)] = u
if
not
isinstance
(
index
[
0
],
Slice
):
if
isinstance
(
index
[
0
],
(
Slice
,
int
)
):
r
aise
ValueError
(
"Index of type '{}' is not supported yet."
.
format
(
type
(
index
[
0
])))
r
eturn
True
r
eturn
True
r
aise
ValueError
(
"Index of type '{}' is not supported yet."
.
format
(
type
(
index
[
0
])))
# eg. Tensor[Tensor[dtype=bool]] = u
# eg. Tensor[Tensor[dtype=bool]] = u
if
index
==
mstype
.
tensor
:
if
index
==
mstype
.
tensor
:
if
element_type
is
None
or
element_type
!=
mstype
.
bool_
:
if
element_type
is
None
or
element_type
!=
mstype
.
bool_
:
...
@@ -57,7 +59,7 @@ def check_tensor_setitem_index(index, element_type=None):
...
@@ -57,7 +59,7 @@ def check_tensor_setitem_index(index, element_type=None):
@
constexpr
@
constexpr
def
is_same_type
(
inst
,
type_
):
def
is_same_type
(
inst
,
type_
):
"""
"""
Check whether an object is an instance of a target type.
Check
s
whether an object is an instance of a target type.
Inputs:
Inputs:
inst (mindspore.dtype): Inspected type.
inst (mindspore.dtype): Inspected type.
...
@@ -69,34 +71,23 @@ def is_same_type(inst, type_):
...
@@ -69,34 +71,23 @@ def is_same_type(inst, type_):
return
inst
==
type_
return
inst
==
type_
@
constexpr
def
error_msg
(
msg
=
""
,
format_values
=
""
):
"""
Used to throw exception information.
Inputs:
msg (str): information content.
"""
raise
ValueError
(
msg
.
format
(
*
format_values
))
def
slice_expand
(
input_slices
,
shape
):
def
slice_expand
(
input_slices
,
shape
):
"""
"""
Convert slice to indices.
Convert
s
slice to indices.
Inputs:
Inputs:
slices (
List or Tuple(List, ...)
): Slice tuple or slice.
slices (
Union[Slice, tuple[Slice]]
): Slice tuple or slice.
shape (
T
uple): The shape of a sensor is an integer element tuple.
shape (
t
uple): The shape of a sensor is an integer element tuple.
Outputs:
Outputs:
(List, List, List)
, This is expressed as (begins, ends, strides).
tuple[list]
, This is expressed as (begins, ends, strides).
"""
"""
begin
=
[]
begin
=
[]
end
=
[]
end
=
[]
strides
=
[]
strides
=
[]
index
=
0
index
=
0
slices
=
None
slices
=
None
# Slice or
T
uple(Slice...)
# Slice or
t
uple(Slice...)
if
isinstance
(
input_slices
,
Slice
):
if
isinstance
(
input_slices
,
Slice
):
slices
=
(
input_slices
,)
slices
=
(
input_slices
,)
elif
isinstance
(
input_slices
,
(
tuple
,
list
))
and
input_slices
and
isinstance
(
input_slices
[
0
],
Slice
):
elif
isinstance
(
input_slices
,
(
tuple
,
list
))
and
input_slices
and
isinstance
(
input_slices
[
0
],
Slice
):
...
@@ -119,14 +110,15 @@ def slice_expand(input_slices, shape):
...
@@ -119,14 +110,15 @@ def slice_expand(input_slices, shape):
index
+=
1
index
+=
1
return
begin
,
end
,
strides
return
begin
,
end
,
strides
@
constexpr
@
constexpr
def
slice2indices
(
input_slices
,
shape
):
def
slice2indices
(
input_slices
,
shape
):
"""
"""
Convert slice to indices.
Convert
s
slice to indices.
Inputs:
Inputs:
slices (
List or Tuple(List, ...)
): Slice tuple or slice.
slices (
Union[Slice, tuple[Slice]]
): Slice tuple or slice.
shape (
Tuple): The shape of a s
ensor is an integer element tuple.
shape (
tuple): The shape of a t
ensor is an integer element tuple.
Outputs:
Outputs:
Tensor, the shape is (n, 1).
Tensor, the shape is (n, 1).
...
@@ -145,6 +137,7 @@ def slice2indices(input_slices, shape):
...
@@ -145,6 +137,7 @@ def slice2indices(input_slices, shape):
@
constexpr
@
constexpr
def
check_indices
(
indices_size
,
index
):
def
check_indices
(
indices_size
,
index
):
"""Checks indices whether is empty."""
if
indices_size
<
1
:
if
indices_size
<
1
:
raise
ValueError
(
"The tensor's index is unreasonable. index:{}"
.
format
(
index
))
raise
ValueError
(
"The tensor's index is unreasonable. index:{}"
.
format
(
index
))
return
indices_size
return
indices_size
...
@@ -152,6 +145,7 @@ def check_indices(indices_size, index):
...
@@ -152,6 +145,7 @@ def check_indices(indices_size, index):
@
constexpr
@
constexpr
def
check_indices_value_size
(
indices_size
,
value_size
):
def
check_indices_value_size
(
indices_size
,
value_size
):
"""Checks if the sizes are already matched."""
if
value_size
<
1
:
if
value_size
<
1
:
raise
ValueError
(
"The value assigned to tensor cannot be empty."
)
raise
ValueError
(
"The value assigned to tensor cannot be empty."
)
if
value_size
>
1
:
if
value_size
>
1
:
...
@@ -160,3 +154,30 @@ def check_indices_value_size(indices_size, value_size):
...
@@ -160,3 +154,30 @@ def check_indices_value_size(indices_size, value_size):
"The value given to tensor does not match the index size.
\
"The value given to tensor does not match the index size.
\
value size:{}, indics size:{}"
.
format
(
value_size
,
indices_size
))
value size:{}, indics size:{}"
.
format
(
value_size
,
indices_size
))
return
value_size
return
value_size
@
constexpr
def
integer_to_indices
(
index
,
shape
):
"""Converts int or tuple[int] to indices."""
size
=
reduce
(
lambda
x
,
y
:
x
*
y
,
shape
)
range_
=
np
.
arange
(
size
).
reshape
(
shape
)
value
=
range_
[
index
]
value
=
value
.
reshape
(
-
1
,
1
)
return
Tensor
(
value
,
dtype
=
mstype
.
int32
)
@
constexpr
def
tuple_element_is_slice
(
indexs
):
"""Judges tuple element type."""
if
not
indexs
:
raise
ValueError
(
"Tensor's index cannot be empty."
)
if
isinstance
(
indexs
,
tuple
)
and
isinstance
(
indexs
[
0
],
Slice
):
return
True
return
False
@
constexpr
def
tuple_element_is_int
(
indexs
):
"""Judges tuple element type."""
if
not
indexs
:
raise
ValueError
(
"Tensor's index cannot be empty."
)
if
isinstance
(
indexs
,
tuple
)
and
isinstance
(
indexs
[
0
],
int
):
return
True
return
False
mindspore/ops/composite/multitype_ops/setitem_impl.py
浏览文件 @
31aae361
...
@@ -25,15 +25,14 @@ setitem = base.MultitypeFuncGraph('setitem')
...
@@ -25,15 +25,14 @@ setitem = base.MultitypeFuncGraph('setitem')
@
setitem
.
register
(
"List"
,
"Number"
,
"String"
)
@
setitem
.
register
(
"List"
,
"Number"
,
"String"
)
def
_list_setitem_with_string
(
data
,
number_index
,
value
):
def
_list_setitem_with_string
(
data
,
number_index
,
value
):
"""
"""
Assign value to list.
Assign
s
value to list.
Inputs:
Inputs:
data (list): Data of type lis.
data (list): Data of type lis.
number_index (Number): Index of data.
number_index (Number): Index of data.
value (String): Value given.
Outputs:
Outputs:
L
ist, type is same as the element type of data.
l
ist, type is same as the element type of data.
"""
"""
return
F
.
list_setitem
(
data
,
number_index
,
value
)
return
F
.
list_setitem
(
data
,
number_index
,
value
)
...
@@ -41,7 +40,7 @@ def _list_setitem_with_string(data, number_index, value):
...
@@ -41,7 +40,7 @@ def _list_setitem_with_string(data, number_index, value):
@
setitem
.
register
(
"List"
,
"Number"
,
"Number"
)
@
setitem
.
register
(
"List"
,
"Number"
,
"Number"
)
def
_list_setitem_with_number
(
data
,
number_index
,
value
):
def
_list_setitem_with_number
(
data
,
number_index
,
value
):
"""
"""
Assign value to list.
Assign
s
value to list.
Inputs:
Inputs:
data (list): Data of type lis.
data (list): Data of type lis.
...
@@ -49,7 +48,7 @@ def _list_setitem_with_number(data, number_index, value):
...
@@ -49,7 +48,7 @@ def _list_setitem_with_number(data, number_index, value):
value (Number): Value given.
value (Number): Value given.
Outputs:
Outputs:
L
ist, type is same as the element type of data.
l
ist, type is same as the element type of data.
"""
"""
return
F
.
list_setitem
(
data
,
number_index
,
value
)
return
F
.
list_setitem
(
data
,
number_index
,
value
)
...
@@ -57,7 +56,7 @@ def _list_setitem_with_number(data, number_index, value):
...
@@ -57,7 +56,7 @@ def _list_setitem_with_number(data, number_index, value):
@
setitem
.
register
(
"List"
,
"Number"
,
"Tensor"
)
@
setitem
.
register
(
"List"
,
"Number"
,
"Tensor"
)
def
_list_setitem_with_Tensor
(
data
,
number_index
,
value
):
def
_list_setitem_with_Tensor
(
data
,
number_index
,
value
):
"""
"""
Assign value to list.
Assign
s
value to list.
Inputs:
Inputs:
data (list): Data of type lis.
data (list): Data of type lis.
...
@@ -65,7 +64,7 @@ def _list_setitem_with_Tensor(data, number_index, value):
...
@@ -65,7 +64,7 @@ def _list_setitem_with_Tensor(data, number_index, value):
value (Tensor): Value given.
value (Tensor): Value given.
Outputs:
Outputs:
L
ist, type is same as the element type of data.
l
ist, type is same as the element type of data.
"""
"""
return
F
.
list_setitem
(
data
,
number_index
,
value
)
return
F
.
list_setitem
(
data
,
number_index
,
value
)
...
@@ -73,15 +72,15 @@ def _list_setitem_with_Tensor(data, number_index, value):
...
@@ -73,15 +72,15 @@ def _list_setitem_with_Tensor(data, number_index, value):
@
setitem
.
register
(
"List"
,
"Number"
,
"List"
)
@
setitem
.
register
(
"List"
,
"Number"
,
"List"
)
def
_list_setitem_with_List
(
data
,
number_index
,
value
):
def
_list_setitem_with_List
(
data
,
number_index
,
value
):
"""
"""
Assign value to list.
Assign
s
value to list.
Inputs:
Inputs:
data (list): Data of type lis.
data (list): Data of type lis.
number_index (Number): Index of data.
number_index (Number): Index of data.
value (
L
ist): Value given.
value (
l
ist): Value given.
Outputs:
Outputs:
L
ist, type is same as the element type of data.
l
ist, type is same as the element type of data.
"""
"""
return
F
.
list_setitem
(
data
,
number_index
,
value
)
return
F
.
list_setitem
(
data
,
number_index
,
value
)
...
@@ -89,15 +88,15 @@ def _list_setitem_with_List(data, number_index, value):
...
@@ -89,15 +88,15 @@ def _list_setitem_with_List(data, number_index, value):
@
setitem
.
register
(
"Dictionary"
,
"String"
,
"Tensor"
)
@
setitem
.
register
(
"Dictionary"
,
"String"
,
"Tensor"
)
def
_dict_setitem_with_tensor
(
data
,
key
,
value
):
def
_dict_setitem_with_tensor
(
data
,
key
,
value
):
"""
"""
Assign value to dictionary.
Assign
s
value to dictionary.
Inputs:
Inputs:
data (
Dictionary
): Data of type dict.
data (
dict
): Data of type dict.
key (str): Key of the data.
key (str): Key of the data.
value (Tensor): Value given.
value (Tensor): Value given.
Outputs:
Outputs:
D
ict, type is as same as the element type of data.
d
ict, type is as same as the element type of data.
"""
"""
return
F
.
dict_setitem
(
data
,
key
,
value
)
return
F
.
dict_setitem
(
data
,
key
,
value
)
...
@@ -105,15 +104,15 @@ def _dict_setitem_with_tensor(data, key, value):
...
@@ -105,15 +104,15 @@ def _dict_setitem_with_tensor(data, key, value):
@
setitem
.
register
(
"Dictionary"
,
"String"
,
"Number"
)
@
setitem
.
register
(
"Dictionary"
,
"String"
,
"Number"
)
def
_dict_setitem_with_number
(
data
,
key
,
value
):
def
_dict_setitem_with_number
(
data
,
key
,
value
):
"""
"""
Assign value to dictionary.
Assign
s
value to dictionary.
Inputs:
Inputs:
data (
Dictionary
): Data of type dict.
data (
dict
): Data of type dict.
key (str): Key of the data.
key (str): Key of the data.
value (Number): Value given.
value (Number): Value given.
Outputs:
Outputs:
D
ict, type is as same as the element type of data.
d
ict, type is as same as the element type of data.
"""
"""
return
F
.
dict_setitem
(
data
,
key
,
value
)
return
F
.
dict_setitem
(
data
,
key
,
value
)
...
@@ -219,14 +218,14 @@ def _tensor_setitem_with_slice_v4(data, input_slice, value):
...
@@ -219,14 +218,14 @@ def _tensor_setitem_with_slice_v4(data, input_slice, value):
Tensor assignment.
Tensor assignment.
Note:
Note:
Syntax support: A[
Slice
] = U
Syntax support: A[
tuple(Slice)] = U, and A[tuple(Number)
] = U
Restraint condition: A is a Tensor
Restraint condition: A is a Tensor
Slice like "1:3, ::, :4:-1"
Slice like "1:3, ::, :4:-1"
U is a Tensor(size=1) or Tensor(size>1)
U is a Tensor(size=1) or Tensor(size>1)
Inputs:
Inputs:
data (Tensor): Assigned tensor.
data (Tensor): Assigned tensor.
input_slice (
Tuple(Slice)
): Slice expression.
input_slice (
Union[tuple[Slice], tuple[Number]]
): Slice expression.
value (Number): Assignment value.
value (Number): Assignment value.
Outputs:
Outputs:
...
@@ -236,39 +235,43 @@ def _tensor_setitem_with_slice_v4(data, input_slice, value):
...
@@ -236,39 +235,43 @@ def _tensor_setitem_with_slice_v4(data, input_slice, value):
def
_tensor_assgin_tensor
(
data
,
input_slice
,
value
):
def
_tensor_assgin_tensor
(
data
,
input_slice
,
value
):
"""Given a tensor value assign to tensor by slice"""
"""Assigns a tensor value to the tensor by slice."""
# 1. condition
result
=
None
result
=
None
check_result
=
mult_util
.
check_tensor_setitem_index
(
input_slice
)
check_result
=
mult_util
.
check_tensor_setitem_index
(
input_slice
)
if
check_result
:
if
check_result
:
data_shape
=
F
.
shape
(
data
)
data_shape
=
F
.
shape
(
data
)
data_size
=
F
.
size
(
data
)
data_dtype
=
F
.
dtype
(
data
)
indices
=
mult_util
.
slice2indices
(
input_slice
,
data_shape
)
indices
=
mult_util
.
slice2indices
(
input_slice
,
data_shape
)
indices_size
=
F
.
size
(
indices
)
is_tuple_int
=
mult_util
.
tuple_element_is_int
(
input_slice
)
indices_size
=
mult_util
.
check_indices
(
indices_size
,
input_slice
)
if
is_tuple_int
:
update
=
F
.
fill
(
data_dtype
,
(
indices_size
,),
1
)
indices
=
mult_util
.
integer_to_indices
(
input_slice
,
data_shape
)
condition_1d
=
F
.
scatter_nd
(
indices
,
update
,
(
data_size
,))
result
=
_tensor_indices_tensor
(
data
,
data_shape
,
input_slice
,
indices
,
value
)
condition_1d
=
F
.
cast
(
condition_1d
,
mstype
.
bool_
)
condition
=
F
.
reshape
(
condition_1d
,
data_shape
)
# 2. u
value_fill
=
None
value_size
=
F
.
size
(
value
)
value_size
=
mult_util
.
check_indices_value_size
(
indices_size
,
value_size
)
if
value_size
==
1
:
value_fill
=
F
.
fill
(
data_dtype
,
(
indices_size
,),
1
)
value
=
F
.
cast
(
value
,
data_dtype
)
value_fill
=
F
.
tensor_mul
(
value_fill
,
value
)
elif
value_size
>
1
:
value_fill
=
F
.
reshape
(
value
,
(
indices_size
,))
value_1d
=
F
.
scatter_nd
(
indices
,
value_fill
,
(
data_size
,))
u
=
F
.
reshape
(
value_1d
,
data_shape
)
# A[slice]= u -> A[B]=U -> select(B, U, A)
result
=
F
.
select
(
condition
,
u
,
data
)
return
result
return
result
def
_tensor_indices_tensor
(
data
,
data_shape
,
index
,
indices
,
value
):
"""Assigns a tensor value to the tensor."""
data_size
=
F
.
size
(
data
)
data_dtype
=
F
.
dtype
(
data
)
indices_size
=
F
.
size
(
indices
)
indices_size
=
mult_util
.
check_indices
(
indices_size
,
index
)
update
=
F
.
fill
(
data_dtype
,
(
indices_size
,),
1
)
condition_1d
=
F
.
scatter_nd
(
indices
,
update
,
(
data_size
,))
condition_1d
=
F
.
cast
(
condition_1d
,
mstype
.
bool_
)
condition
=
F
.
reshape
(
condition_1d
,
data_shape
)
value_fill
=
None
value_size
=
F
.
size
(
value
)
value_size
=
mult_util
.
check_indices_value_size
(
indices_size
,
value_size
)
if
value_size
==
1
:
value_fill
=
F
.
fill
(
data_dtype
,
(
indices_size
,),
1
)
value
=
F
.
cast
(
value
,
data_dtype
)
value_fill
=
F
.
tensor_mul
(
value_fill
,
value
)
elif
value_size
>
1
:
value_fill
=
F
.
reshape
(
value
,
(
indices_size
,))
value_1d
=
F
.
scatter_nd
(
indices
,
value_fill
,
(
data_size
,))
u
=
F
.
reshape
(
value_1d
,
data_shape
)
return
F
.
select
(
condition
,
u
,
data
)
@
setitem
.
register
(
"Tensor"
,
"Slice"
,
"Number"
)
@
setitem
.
register
(
"Tensor"
,
"Slice"
,
"Number"
)
def
_tensor_setitem_with_slice_v1
(
data
,
input_slice
,
value
):
def
_tensor_setitem_with_slice_v1
(
data
,
input_slice
,
value
):
"""
"""
...
@@ -297,14 +300,14 @@ def _tensor_setitem_with_slice_v2(data, input_slice, value):
...
@@ -297,14 +300,14 @@ def _tensor_setitem_with_slice_v2(data, input_slice, value):
Tensor assignment.
Tensor assignment.
Note:
Note:
Syntax support: A[
Slice
] = u
Syntax support: A[
tuple(Slice)] = u, and A[tuple(Number)
] = u
Restraint condition: A is a Tensor.
Restraint condition: A is a Tensor.
Slice like "1:3, ::, :4:-1"
Slice like "1:3, ::, :4:-1"
u is a scalar
u is a scalar
Inputs:
Inputs:
data (Tensor): Assigned tensor.
data (Tensor): Assigned tensor.
input_slice (
Tuple(Slice)
): slice expression.
input_slice (
Union[tuple[Slice], tuple[Number]]
): slice expression.
value (Number): Assignment value.
value (Number): Assignment value.
Outputs:
Outputs:
...
@@ -314,25 +317,46 @@ def _tensor_setitem_with_slice_v2(data, input_slice, value):
...
@@ -314,25 +317,46 @@ def _tensor_setitem_with_slice_v2(data, input_slice, value):
def
_tensor_assgin_number
(
data
,
input_slice
,
value
):
def
_tensor_assgin_number
(
data
,
input_slice
,
value
):
"""Given a scalar assign to tensor by slice"""
"""Givens a scalar assign to tensor by slice"""
# 1. condition
check_result
=
mult_util
.
check_tensor_setitem_index
(
input_slice
)
check_result
=
mult_util
.
check_tensor_setitem_index
(
input_slice
)
result
=
None
result
=
None
if
check_result
:
if
check_result
:
data_shape
=
F
.
shape
(
data
)
data_shape
=
F
.
shape
(
data
)
data_size
=
F
.
size
(
data
)
data_dtype
=
F
.
dtype
(
data
)
indices
=
mult_util
.
slice2indices
(
input_slice
,
data_shape
)
indices
=
mult_util
.
slice2indices
(
input_slice
,
data_shape
)
indices_size
=
F
.
size
(
indices
)
is_tuple_int
=
mult_util
.
tuple_element_is_int
(
input_slice
)
indices_size
=
mult_util
.
check_indices
(
indices_size
,
input_slice
)
if
is_tuple_int
:
update
=
F
.
fill
(
data_dtype
,
(
indices_size
,),
1
)
indices
=
mult_util
.
integer_to_indices
(
input_slice
,
data_shape
)
condition_1d
=
F
.
scatter_nd
(
indices
,
update
,
(
data_size
,))
result
=
_tensor_indices_number
(
data
,
data_shape
,
input_slice
,
indices
,
value
)
condition_1d
=
F
.
cast
(
condition_1d
,
mstype
.
bool_
)
condition
=
F
.
reshape
(
condition_1d
,
data_shape
)
# 2. u
value_fill
=
F
.
fill
(
data_dtype
,
(
indices_size
,),
value
)
value_1d
=
F
.
scatter_nd
(
indices
,
value_fill
,
(
data_size
,))
u
=
F
.
reshape
(
value_1d
,
data_shape
)
# A[slice]= u -> A[B]=U -> select(B, U, A)
result
=
F
.
select
(
condition
,
u
,
data
)
return
result
return
result
def
_tensor_indices_number
(
data
,
data_shape
,
index
,
indices
,
value
):
"""Assigns a scalar value to the tensor."""
data_size
=
F
.
size
(
data
)
data_dtype
=
F
.
dtype
(
data
)
indices_size
=
F
.
size
(
indices
)
indices_size
=
mult_util
.
check_indices
(
indices_size
,
index
)
update
=
F
.
fill
(
data_dtype
,
(
indices_size
,),
1
)
condition_1d
=
F
.
scatter_nd
(
indices
,
update
,
(
data_size
,))
condition_1d
=
F
.
cast
(
condition_1d
,
mstype
.
bool_
)
condition
=
F
.
reshape
(
condition_1d
,
data_shape
)
value_fill
=
F
.
fill
(
data_dtype
,
(
indices_size
,),
value
)
value_1d
=
F
.
scatter_nd
(
indices
,
value_fill
,
(
data_size
,))
u
=
F
.
reshape
(
value_1d
,
data_shape
)
return
F
.
select
(
condition
,
u
,
data
)
@
setitem
.
register
(
"Tensor"
,
"Number"
,
"Number"
)
def
_tensor_setitem_with_int_v1
(
data
,
index
,
value
):
"""Syntax: A[1] = 3"""
data_shape
=
F
.
shape
(
data
)
indices
=
mult_util
.
integer_to_indices
(
index
,
data_shape
)
return
_tensor_indices_number
(
data
,
data_shape
,
index
,
indices
,
value
)
@
setitem
.
register
(
"Tensor"
,
"Number"
,
"Tensor"
)
def
_tensor_setitem_with_int_v2
(
data
,
index
,
value
):
"""Syntax: A[1] = Tensor"""
data_shape
=
F
.
shape
(
data
)
indices
=
mult_util
.
integer_to_indices
(
index
,
data_shape
)
return
_tensor_indices_tensor
(
data
,
data_shape
,
index
,
indices
,
value
)
tests/ut/python/ops/test_tensor_slice.py
浏览文件 @
31aae361
...
@@ -138,7 +138,7 @@ class TensorAssignWithSlice(Cell):
...
@@ -138,7 +138,7 @@ class TensorAssignWithSlice(Cell):
z
=
a
z
=
a
return
z
return
z
def
test_tensor_assign
_with_slice
():
def
test_tensor_assign
():
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
save_graphs
=
True
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
save_graphs
=
True
)
net
=
TensorAssignWithSlice
()
net
=
TensorAssignWithSlice
()
net2
=
TensorAssignWithSlice2
()
net2
=
TensorAssignWithSlice2
()
...
@@ -147,6 +147,7 @@ def test_tensor_assign_with_slice():
...
@@ -147,6 +147,7 @@ def test_tensor_assign_with_slice():
a
=
np
.
arange
(
60
).
reshape
(
3
,
4
,
5
)
a
=
np
.
arange
(
60
).
reshape
(
3
,
4
,
5
)
b
=
Tensor
([
1
])
b
=
Tensor
([
1
])
Ta
=
Tensor
(
a
)
Ta
=
Tensor
(
a
)
Ta4d
=
Tensor
(
a
.
reshape
(
1
,
3
,
4
,
5
))
Tb
=
Tensor
([
1
,
3
])
Tb
=
Tensor
([
1
,
3
])
Tc
=
Tensor
([])
Tc
=
Tensor
([])
t
=
Tensor
([
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
])
t
=
Tensor
([
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
])
...
@@ -184,6 +185,47 @@ def test_tensor_assign_with_slice():
...
@@ -184,6 +185,47 @@ def test_tensor_assign_with_slice():
with
pytest
.
raises
(
ValueError
):
with
pytest
.
raises
(
ValueError
):
net_e1
(
Ta
,
2
)
net_e1
(
Ta
,
2
)
net
=
TensorAssignWithInteger
()
# Error for A[Number] = scalar/Tensor
# 1. A[Number] = U, U is a Tensor, u.size not match
with
pytest
.
raises
(
ValueError
):
net
(
Ta
,
Tb
)
with
pytest
.
raises
(
ValueError
):
net
(
Ta
,
Tc
)
# 2. A[Number] = U, the number index error
with
pytest
.
raises
(
IndexError
):
net
(
Ta4d
,
b
)
# Error for A[(n,m)] = scalar/Tensor
# 1. A[(n,m)] = U, U is a tensor. u.size not match
net
=
TensorAssignWithTupleInteger
()
with
pytest
.
raises
(
ValueError
):
net
(
Ta
,
Tc
)
with
pytest
.
raises
(
ValueError
):
net
(
Ta
,
Tb
)
# 2. A[(n,m)] = U, the number index error
with
pytest
.
raises
(
IndexError
):
net
(
Ta4d
,
b
)
class
TensorAssignWithInteger
(
Cell
):
def
__init__
(
self
):
super
(
TensorAssignWithInteger
,
self
).
__init__
()
def
construct
(
self
,
a
,
b
):
a
[
1
]
=
1
a
[
0
]
=
b
return
a
class
TensorAssignWithTupleInteger
(
Cell
):
def
__init__
(
self
):
super
(
TensorAssignWithTupleInteger
,
self
).
__init__
()
def
construct
(
self
,
a
,
b
):
a
[(
1
)]
=
1
a
[(
1
)]
=
b
a
[(
1
,
1
)]
=
b
a
[(
1
,
1
)]
=
1
return
a
class
TensorAssignWithBoolTensorIndex
(
Cell
):
class
TensorAssignWithBoolTensorIndex
(
Cell
):
def
__init__
(
self
):
def
__init__
(
self
):
...
@@ -273,6 +315,14 @@ def test_tensor_assign_bool_index():
...
@@ -273,6 +315,14 @@ def test_tensor_assign_bool_index():
net4
(
Ta
,
u_scalar
)
net4
(
Ta
,
u_scalar
)
test_cases
=
[
test_cases
=
[
(
'TensorAssignWithTupleInteger'
,
{
'block'
:
TensorAssignWithTupleInteger
(),
'desc_inputs'
:
[
Ta
,
u_tensor
],
}),
(
'TensorAssignWithInteger'
,
{
'block'
:
TensorAssignWithInteger
(),
'desc_inputs'
:
[
Ta
,
u_tensor
],
}),
(
'TensorAssignWithSlice'
,
{
(
'TensorAssignWithSlice'
,
{
'block'
:
TensorAssignWithSlice
(),
'block'
:
TensorAssignWithSlice
(),
'desc_inputs'
:
[
Ta
,
u_tensor
],
'desc_inputs'
:
[
Ta
,
u_tensor
],
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录