Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
30fc1bd0
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
30fc1bd0
编写于
7月 01, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
7月 01, 2020
浏览文件
操作
浏览文件
下载
差异文件
!2747 refactor StridedSlice op
Merge pull request !2747 from zhangbuxue/refactor_the_StridedSlice_op
上级
7b5b4837
be771aa9
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
238 addition
and
85 deletion
+238
-85
mindspore/ops/operations/array_ops.py
mindspore/ops/operations/array_ops.py
+162
-58
tests/ut/python/ops/test_ops.py
tests/ut/python/ops/test_ops.py
+73
-22
tests/ut/python/ops/test_tensor_slice.py
tests/ut/python/ops/test_tensor_slice.py
+3
-5
未找到文件。
mindspore/ops/operations/array_ops.py
浏览文件 @
30fc1bd0
...
...
@@ -37,6 +37,7 @@ from ..._c_expression import signature_kind as sig_kind
from
..._c_expression
import
signature_dtype
as
sig_dtype
from
..._c_expression
import
typing
def
_check_infer_attr_reduce
(
axis
,
keep_dims
,
prim_name
):
validator
.
check_value_type
(
'keep_dims'
,
keep_dims
,
[
bool
],
prim_name
)
validator
.
check_value_type
(
'axis'
,
axis
,
[
int
,
tuple
],
prim_name
)
...
...
@@ -193,7 +194,7 @@ class Cast(PrimitiveWithInfer):
self
.
init_prim_io_names
(
inputs
=
[
'x'
,
'dst_type'
],
outputs
=
[
'output'
])
def
check_elim
(
self
,
x
,
dtype
):
if
isinstance
(
x
,
(
Tensor
,
numbers
.
Number
,
Parameter
)):
if
isinstance
(
x
,
(
Tensor
,
numbers
.
Number
,
Parameter
)):
if
isinstance
(
x
,
Tensor
)
and
x
.
dtype
==
dtype
:
return
(
True
,
x
)
if
isinstance
(
x
,
numbers
.
Number
):
...
...
@@ -987,10 +988,10 @@ class InvertPermutation(PrimitiveWithInfer):
z
.
sort
()
for
i
in
range
(
1
,
len
(
z
)):
if
z
[
i
-
1
]
==
z
[
i
]:
if
z
[
i
-
1
]
==
z
[
i
]:
raise
ValueError
(
f
"For
{
self
.
name
}
,
{
z
[
i
]
}
is duplicated in the input."
)
validator
.
check
(
f
'value min'
,
min
(
x_value
),
''
,
0
,
Rel
.
EQ
,
self
.
name
)
validator
.
check
(
f
'value max'
,
max
(
x_value
),
''
,
len
(
x_value
)
-
1
,
Rel
.
EQ
,
self
.
name
)
validator
.
check
(
f
'value max'
,
max
(
x_value
),
''
,
len
(
x_value
)
-
1
,
Rel
.
EQ
,
self
.
name
)
y
=
[
None
]
*
len
(
x_value
)
for
i
,
value
in
enumerate
(
x_value
):
...
...
@@ -1693,6 +1694,57 @@ class Select(PrimitiveWithInfer):
return
None
def
_compute_slicing_length
(
begin
,
end
,
stride
,
x_shape
,
i
):
"""Compute the length of the slicing."""
if
i
>=
len
(
x_shape
):
raise
ValueError
(
f
"For 'StridedSlice', When their is no new axis, the index length must be less or "
f
"equal than the dim of x."
)
x_dim
=
x_shape
[
i
]
if
stride
>
0
:
# When slicing forward, convert begin and end to positive numbers.
if
begin
>=
x_dim
or
end
<
-
x_dim
:
# When slicing forward, if begin >= x_dim or end < -x_dim, the length of the slicing is 0.
slicing_length
=
0
else
:
if
-
x_dim
<=
begin
<
0
:
begin
+=
x_dim
if
begin
<
-
x_dim
:
# When slicing forward, if begin < -x_dim, set begin = 0, which means start from the 0th element.
begin
=
0
if
-
x_dim
<=
end
<
0
:
end
+=
x_dim
if
end
>
x_dim
:
# When slicing forward, if end > x_dim, set end = x_dims, which means slice to the last element.
end
=
x_dim
if
begin
>=
end
:
# When slicing forward, if begin >= end, the length of the slicing is 0.
slicing_length
=
0
else
:
slicing_length
=
1
+
(
end
-
1
-
begin
)
//
stride
else
:
# When slicing backward, convert begin and end to negative numbers.
if
begin
<
-
x_dim
or
end
>=
x_dim
:
# When slicing backward, if begin < -x_dim or end >= x_dim, the length of the slicing is 0.
slicing_length
=
0
else
:
if
0
<=
begin
<
x_dim
:
begin
+=
-
x_dim
if
begin
>=
x_dim
:
# When slicing backward, if begin >= x_dim, set begin = -1, which means start from the last element.
begin
=
-
1
if
0
<
end
<
x_dim
:
end
+=
-
x_dim
if
end
<
-
x_dim
-
1
:
# When slicing backward, if end < -x_dim - 1, set end = -x_dim - 1, which means
# slicing to the 0th element.
end
=
-
x_dim
-
1
if
begin
<=
end
:
# When slicing backward, if begin <= end, the length of the slicing is 0.
slicing_length
=
0
else
:
slicing_length
=
1
+
(
end
+
1
-
begin
)
//
stride
return
slicing_length
class
StridedSlice
(
PrimitiveWithInfer
):
r
"""
...
...
@@ -1756,13 +1808,15 @@ class StridedSlice(PrimitiveWithInfer):
ellipsis_mask
=
0
,
new_axis_mask
=
0
,
shrink_axis_mask
=
0
):
"""
i
nit StrideSlice"""
"""
I
nit StrideSlice"""
self
.
init_prim_io_names
(
inputs
=
[
'x'
,
'begin'
,
'end'
,
'strides'
],
outputs
=
[
'output'
])
validator
.
check_value_type
(
'begin_mask'
,
begin_mask
,
[
int
],
self
.
name
)
validator
.
check_value_type
(
'end_mask'
,
end_mask
,
[
int
],
self
.
name
)
validator
.
check_value_type
(
'ellipsis_mask'
,
ellipsis_mask
,
[
int
],
self
.
name
)
validator
.
check_value_type
(
'new_axis_mask'
,
new_axis_mask
,
[
int
],
self
.
name
)
validator
.
check_value_type
(
'shrink_axis_mask'
,
shrink_axis_mask
,
[
int
],
self
.
name
)
validator
.
check_integer
(
'begin_mask'
,
begin_mask
,
0
,
Rel
.
GE
,
self
.
name
)
validator
.
check_integer
(
'end_mask'
,
end_mask
,
0
,
Rel
.
GE
,
self
.
name
)
validator
.
check_integer
(
'ellipsis_mask'
,
ellipsis_mask
,
0
,
Rel
.
GE
,
self
.
name
)
if
len
(
tuple
(
filter
(
lambda
x
:
x
==
'1'
,
bin
(
ellipsis_mask
)[
-
1
:
1
:
-
1
])))
>
1
:
raise
ValueError
(
f
"For '
{
self
.
name
}
', only support one ellipsis in the index, but got
{
end_mask
}
."
)
validator
.
check_integer
(
'new_axis_mask'
,
new_axis_mask
,
0
,
Rel
.
GE
,
self
.
name
)
validator
.
check_integer
(
'shrink_axis_mask'
,
shrink_axis_mask
,
0
,
Rel
.
GE
,
self
.
name
)
def
__infer__
(
self
,
x
,
begin
,
end
,
strides
):
begin_v
,
end_v
,
strides_v
=
begin
[
'value'
],
end
[
'value'
],
strides
[
'value'
]
...
...
@@ -1770,58 +1824,103 @@ class StridedSlice(PrimitiveWithInfer):
validator
.
check_value_type
(
"end"
,
end_v
,
[
tuple
],
self
.
name
)
validator
.
check_value_type
(
"strides"
,
strides_v
,
[
tuple
],
self
.
name
)
x_shape
=
x
[
'shape'
]
x_shp_len
=
len
(
x_shape
)
if
len
(
begin_v
)
!=
x_shp_len
or
len
(
end_v
)
!=
x_shp_len
or
len
(
strides_v
)
!=
x_shp_len
:
raise
ValueError
(
f
"For
\'
{
self
.
name
}
\'
the length of begin index
{
begin_v
}
, end index
{
end_v
}
and "
f
"strides
{
strides_v
}
must be equal to the dims(
{
x_shp_len
}
) of input."
)
if
tuple
(
filter
(
lambda
x
:
not
isinstance
(
x
,
int
),
begin_v
+
end_v
+
strides_v
)):
raise
ValueError
(
f
"For
{
self
.
name
}
, both the begins, ends, and strides must be a tuple of int, "
f
"but got begins:
{
begin_v
}
, ends:
{
end_v
}
, strides:
{
strides_v
}
."
)
ret_shape
=
[]
append_dimensions
=
[]
shrink_pos
=
bin
(
self
.
shrink_axis_mask
)[::
-
1
]
new_pos
=
bin
(
self
.
new_axis_mask
)[::
-
1
]
for
i
in
range
(
x_shp_len
):
# After the integer is converted to binary, it is a str and the first two chars are the flag char '0b'
if
i
<
(
len
(
new_pos
)
-
2
)
and
new_pos
[
i
]
==
'1'
:
ret_shape
.
append
(
1
)
append_dimensions
.
append
(
x_shape
[
x_shp_len
-
1
-
len
(
append_dimensions
)])
continue
if
i
<
(
len
(
shrink_pos
)
-
2
)
and
shrink_pos
[
i
]
==
'1'
:
validator
.
check_integer
(
f
'begin[
{
i
}
]'
,
begin_v
[
i
],
-
x_shape
[
i
],
Rel
.
GE
,
self
.
name
)
validator
.
check_integer
(
f
'begin[
{
i
}
]'
,
begin_v
[
i
],
x_shape
[
i
],
Rel
.
LT
,
self
.
name
)
continue
begin_idx
=
begin_v
[
i
]
end_idx
=
end_v
[
i
]
strides_idx
=
strides_v
[
i
]
if
self
.
begin_mask
:
begin_idx
=
0
if
self
.
end_mask
:
end_idx
=
x_shape
[
i
]
validator
.
check_integer
(
f
'begin[
{
i
}
]'
,
begin_idx
,
x_shape
[
i
],
Rel
.
LE
,
self
.
name
)
validator
.
check_integer
(
f
'end[
{
i
}
]'
,
end_idx
,
x_shape
[
i
],
Rel
.
LE
,
self
.
name
)
validator
.
check_integer
(
f
'strides[
{
i
}
]'
,
strides_idx
,
0
,
Rel
.
NE
,
self
.
name
)
if
strides_idx
>
0
:
# If sliced forward , end_idx >= begin_idx
validator
.
check
(
f
'begin[
{
i
}
]'
,
begin_idx
,
f
'end[
{
i
}
]'
,
end_idx
,
Rel
.
LE
)
if
begin_idx
<
0
<
end_idx
:
# Turn negative begin_idx into positive values
begin_idx
=
x_shape
[
i
]
+
begin_idx
num_elems
=
(
end_idx
-
begin_idx
+
strides_idx
-
1
)
//
strides_idx
else
:
# If sliced backwards, end_idx <= begin_idx
validator
.
check
(
f
'begin[
{
i
}
]'
,
begin_idx
,
f
'end[
{
i
}
]'
,
end_idx
,
Rel
.
GE
)
if
end_idx
<
0
<
begin_idx
:
# Turn negative end_idx into positive values
end_idx
=
x_shape
[
i
]
+
end_idx
num_elems
=
(
end_idx
-
begin_idx
+
strides_idx
+
1
)
//
strides_idx
ret_shape
.
append
(
num_elems
)
if
append_dimensions
:
ret_shape
+=
append_dimensions
[::
-
1
]
if
tuple
(
filter
(
lambda
x
:
x
==
0
,
strides_v
)):
raise
ValueError
(
f
"For '
{
self
.
name
}
', the strides cannot contain 0, but got strides:
{
strides_v
}
."
)
if
len
(
end_v
)
!=
len
(
begin_v
)
or
len
(
strides_v
)
!=
len
(
begin_v
):
raise
ValueError
(
f
"For '
{
self
.
name
}
' the length of begin index:
{
begin_v
}
, end index:
{
end_v
}
and "
f
"strides:
{
strides_v
}
must be equal."
)
ret_shape
=
self
.
_compute_slicing_shape
(
x
[
'shape'
],
begin_v
,
end_v
,
strides_v
)
value
=
None
if
all
(
ret_shape
)
else
Tensor
(
np
.
array
([]).
reshape
(
ret_shape
),
x
[
'dtype'
].
element_type
())
return
{
'shape'
:
ret_shape
,
'dtype'
:
x
[
'dtype'
],
'value'
:
None
}
'value'
:
value
}
def
_compute_slicing_shape
(
self
,
x_shape
,
begin_v
,
end_v
,
strides_v
):
"""Compute the shape of the slicing."""
x_rank
=
len
(
x_shape
)
slice_len
=
len
(
begin_v
)
# After the integer is converted to binary, it is a str and the first two chars are the flag char '0b'.
begin_pos
=
bin
(
self
.
begin_mask
)[
-
1
:
1
:
-
1
]
end_pos
=
bin
(
self
.
end_mask
)[
-
1
:
1
:
-
1
]
ellipsis_pos
=
bin
(
self
.
ellipsis_mask
)[
-
1
:
1
:
-
1
]
new_axis_pos
=
bin
(
self
.
new_axis_mask
)[
-
1
:
1
:
-
1
]
shrink_axis_pos
=
bin
(
self
.
shrink_axis_mask
)[
-
1
:
1
:
-
1
]
ret_shape
=
[]
i
,
j
=
0
,
0
has_ellipsis
=
False
while
i
<
x_rank
or
j
<
slice_len
:
if
j
<
slice_len
:
begin
,
end
,
stride
=
begin_v
[
j
],
end_v
[
j
],
strides_v
[
j
]
if
j
<
len
(
ellipsis_pos
)
and
ellipsis_pos
[
j
]
==
'1'
:
# When there is ellipsis, the latter part of the ellipsis will be processed separately.
has_ellipsis
=
True
break
if
j
<
len
(
begin_pos
)
and
begin_pos
[
j
]
==
'1'
:
begin
=
-
1
if
strides_v
[
j
]
<
0
else
0
if
j
<
len
(
end_pos
)
and
end_pos
[
j
]
==
'1'
:
end
=
-
(
x_shape
[
i
]
+
1
)
if
strides_v
[
j
]
<
0
else
x_shape
[
i
]
if
j
<
len
(
new_axis_pos
)
and
new_axis_pos
[
j
]
==
'1'
:
ret_shape
.
append
(
1
)
j
+=
1
continue
if
j
<
len
(
shrink_axis_pos
)
and
shrink_axis_pos
[
j
]
==
'1'
:
if
(
not
-
x_shape
[
i
]
<=
begin
<
x_shape
[
i
])
or
stride
<
0
:
raise
ValueError
(
f
"For
{
self
.
name
}
, when shrink axis, the stride cannot be negative number, "
f
"and begin should be in [-
{
x_shape
[
i
]
}
,
{
x_shape
[
i
]
}
), "
f
"but got stride:
{
stride
}
, begin:
{
begin
}
."
)
j
+=
1
i
+=
1
continue
else
:
begin
,
end
,
stride
=
0
,
x_shape
[
i
],
1
slicing_length
=
_compute_slicing_length
(
begin
,
end
,
stride
,
x_shape
,
i
)
ret_shape
.
append
(
slicing_length
)
i
+=
1
j
+=
1
if
has_ellipsis
:
# When there is ellipsis, handle the second half of the ellipsis split.
ellipsis_occupied_dims
=
x_rank
-
i
-
(
slice_len
-
(
j
+
1
))
+
\
len
(
tuple
(
filter
(
lambda
x
:
x
==
'1'
,
new_axis_pos
[
j
+
1
:
slice_len
])))
ret_shape
.
extend
(
x_shape
[
i
:
i
+
ellipsis_occupied_dims
])
j
+=
1
i
+=
ellipsis_occupied_dims
while
i
<
x_rank
or
j
<
slice_len
:
begin
,
end
,
stride
=
begin_v
[
j
],
end_v
[
j
],
strides_v
[
j
]
if
j
<
len
(
begin_pos
)
and
begin_pos
[
j
]
==
'1'
:
begin
=
-
1
if
strides_v
[
j
]
<
0
else
0
if
j
<
len
(
end_pos
)
and
end_pos
[
j
]
==
'1'
:
end
=
-
(
x_shape
[
i
]
+
1
)
if
strides_v
[
j
]
<
0
else
x_shape
[
i
]
if
j
<
len
(
new_axis_pos
)
and
new_axis_pos
[
j
]
==
'1'
:
ret_shape
.
append
(
1
)
j
+=
1
continue
if
j
<
len
(
shrink_axis_pos
)
and
shrink_axis_pos
[
j
]
==
'1'
:
if
(
not
-
x_shape
[
i
]
<=
begin
<
x_shape
[
i
])
or
stride
<
0
:
raise
ValueError
(
f
"For
{
self
.
name
}
, when shrink axis, the stride cannot be negative number, "
f
"and begin should be in [-
{
x_shape
[
i
]
}
,
{
x_shape
[
i
]
}
), "
f
"but got stride:
{
stride
}
, begin:
{
begin
}
."
)
j
+=
1
i
+=
1
continue
slicing_length
=
_compute_slicing_length
(
begin
,
end
,
stride
,
x_shape
,
i
)
ret_shape
.
append
(
slicing_length
)
i
+=
1
j
+=
1
return
ret_shape
class
Diag
(
PrimitiveWithInfer
):
...
...
@@ -2102,6 +2201,7 @@ class TensorScatterUpdate(PrimitiveWithInfer):
>>> op = P.TensorScatterUpdate()
>>> output = op(input_x, indices, update)
"""
@
prim_attr_register
def
__init__
(
self
):
"""Init TensorScatterUpdate"""
...
...
@@ -2153,6 +2253,7 @@ class ScatterUpdate(PrimitiveWithInfer):
(
'indices'
,
sig_rw
.
RW_READ
,
sig_kind
.
KIND_POSITIONAL_KEYWORD
,
sig_kind
.
KIND_EMPTY_DEFAULT_VALUE
,
sig_dtype
.
T1
),
(
'value'
,
sig_rw
.
RW_READ
,
sig_kind
.
KIND_POSITIONAL_KEYWORD
,
sig_kind
.
KIND_EMPTY_DEFAULT_VALUE
,
sig_dtype
.
T
)
)
@
prim_attr_register
def
__init__
(
self
,
use_locking
=
True
):
"""Init ScatterUpdate"""
...
...
@@ -2201,6 +2302,7 @@ class ScatterNdUpdate(PrimitiveWithInfer):
(
'indices'
,
sig_rw
.
RW_READ
,
sig_kind
.
KIND_POSITIONAL_KEYWORD
,
sig_kind
.
KIND_EMPTY_DEFAULT_VALUE
,
sig_dtype
.
T1
),
(
'value'
,
sig_rw
.
RW_READ
,
sig_kind
.
KIND_POSITIONAL_KEYWORD
,
sig_kind
.
KIND_EMPTY_DEFAULT_VALUE
,
sig_dtype
.
T
)
)
@
prim_attr_register
def
__init__
(
self
,
use_locking
=
True
):
"""Init ScatterNdUpdate"""
...
...
@@ -2220,6 +2322,7 @@ class ScatterNdUpdate(PrimitiveWithInfer):
validator
.
check_tensor_type_same
(
args
,
(
mstype
.
bool_
,)
+
mstype
.
number_type
,
self
.
name
)
return
x_dtype
def
_check_scatter_shape
(
x_shape
,
indices_shape
,
updates_shape
,
prim_name
):
if
updates_shape
and
updates_shape
!=
indices_shape
+
x_shape
[
1
:]:
raise
ValueError
(
f
"For '
{
prim_name
}
', the shape of updates should be [] or "
...
...
@@ -2912,6 +3015,7 @@ class InplaceUpdate(PrimitiveWithInfer):
[ 4. 5.]
[ 6. 7.]]]
"""
@
prim_attr_register
def
__init__
(
self
,
indices
):
"""Init InplaceUpdate"""
...
...
tests/ut/python/ops/test_ops.py
浏览文件 @
30fc1bd0
...
...
@@ -35,25 +35,6 @@ from ....mindspore_test_framework.pipeline.gradient.compile_gradient \
import
pipeline_for_compile_grad_ge_graph_for_case_by_case_config
def
test_tensor_scatter_update
():
class
TensorScatterUpdateNet
(
nn
.
Cell
):
"""TensorScatterUpdate net definition"""
def
__init__
(
self
):
super
(
TensorScatterUpdateNet
,
self
).
__init__
()
self
.
tensor_scatter_update
=
P
.
TensorScatterUpdate
()
def
construct
(
self
,
x
,
i
,
u
):
out
=
self
.
tensor_scatter_update
(
x
,
i
,
u
)
return
out
net
=
TensorScatterUpdateNet
()
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
save_graphs
=
True
)
x
=
Tensor
(
np
.
arange
(
3
*
4
*
5
).
reshape
((
3
,
4
,
5
)),
mstype
.
float32
)
indices
=
Tensor
(
np
.
array
([[
0
,
0
],
[
1
,
1
]],
np
.
int32
))
updates
=
Tensor
(
np
.
ones
([
2
,
5
],
np
.
float32
))
net
(
x
,
indices
,
updates
)
class
InputBackward
(
nn
.
Cell
):
def
__init__
(
self
,
network
):
super
(
InputBackward
,
self
).
__init__
()
...
...
@@ -446,6 +427,7 @@ class SparseApplyAdagradNet(nn.Cell):
out
=
self
.
sparse_apply_adagrad
(
self
.
var
,
self
.
accum
,
grad
,
indices
)
return
out
class
ApplyRMSNet
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
ApplyRMSNet
,
self
).
__init__
()
...
...
@@ -496,6 +478,60 @@ class NormalNet(nn.Cell):
return
out
class
StridedSliceNet
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
StridedSliceNet
,
self
).
__init__
()
self
.
begins
=
(
1
,
2
,
3
,
2
,
1
)
self
.
ends
=
(
5
,
6
,
7
,
8
,
9
)
self
.
strides
=
(
1
,
2
,
3
,
2
,
1
)
self
.
strided_slice_0
=
P
.
StridedSlice
(
begin_mask
=
3
,
end_mask
=
5
,
ellipsis_mask
=
4
,
shrink_axis_mask
=
2
,
new_axis_mask
=
8
)
self
.
strided_slice_1
=
P
.
StridedSlice
(
begin_mask
=
5
,
end_mask
=
2
,
ellipsis_mask
=
2
,
shrink_axis_mask
=
6
,
new_axis_mask
=
10
)
self
.
strided_slice_2
=
P
.
StridedSlice
(
begin_mask
=
3
,
end_mask
=
3
,
ellipsis_mask
=
4
,
shrink_axis_mask
=
5
,
new_axis_mask
=
13
)
self
.
strided_slice_3
=
P
.
StridedSlice
(
begin_mask
=
0
,
end_mask
=
0
,
ellipsis_mask
=
4
,
shrink_axis_mask
=
12
,
new_axis_mask
=
15
)
self
.
const_0
=
Tensor
(
np
.
ones
([
6
,
8
,
9
,
1
,
8
],
np
.
float32
))
self
.
const_1
=
Tensor
(
np
.
ones
([
5
,
7
,
8
,
1
,
8
],
np
.
float32
))
self
.
const_2
=
Tensor
(
np
.
ones
([
1
,
3
,
7
,
8
,
9
,
1
,
8
],
np
.
float32
))
self
.
const_3
=
Tensor
(
np
.
ones
([
1
,
1
,
6
,
7
,
8
,
9
,
1
,
8
],
np
.
float32
))
def
construct
(
self
,
x
):
out_0
=
self
.
strided_slice_0
(
x
,
self
.
begins
,
self
.
ends
,
self
.
strides
)
+
self
.
const_0
out_1
=
self
.
strided_slice_1
(
x
,
self
.
begins
,
self
.
ends
,
self
.
strides
)
+
self
.
const_1
out_2
=
self
.
strided_slice_2
(
x
,
self
.
begins
,
self
.
ends
,
self
.
strides
)
+
self
.
const_2
out_3
=
self
.
strided_slice_3
(
x
,
self
.
begins
,
self
.
ends
,
self
.
strides
)
+
self
.
const_3
return
out_0
,
out_1
,
out_2
,
out_3
def
test_strided_slice_const
():
class
StridedSLiceConstNet
(
nn
.
Cell
):
"""StridedSLiceConstNet net definition"""
def
__init__
(
self
):
super
(
StridedSLiceConstNet
,
self
).
__init__
()
self
.
begins
=
(
0
,
2
,
-
5
,
2
,
1
)
self
.
ends
=
(
0
,
6
,
9
,
8
,
9
)
self
.
strides
=
(
1
,
2
,
1
,
2
,
1
)
self
.
strided_slice
=
P
.
StridedSlice
(
begin_mask
=
2
,
end_mask
=
6
,
ellipsis_mask
=
4
,
shrink_axis_mask
=
6
,
new_axis_mask
=
18
)
def
construct
(
self
,
x
):
out
=
self
.
strided_slice
(
x
,
self
.
begins
,
self
.
ends
,
self
.
strides
)
return
out
net
=
StridedSLiceConstNet
()
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
save_graphs
=
True
)
x
=
Tensor
(
np
.
ones
([
6
,
7
,
8
,
9
,
10
]),
mstype
.
float32
)
ret
=
net
(
x
)
assert
ret
.
shape
==
(
0
,
1
,
7
,
8
,
9
,
3
,
1
)
assert
(
ret
.
asnumpy
()
==
np
.
array
([],
np
.
float32
).
reshape
([
0
,
1
,
7
,
8
,
9
,
3
,
1
])).
all
()
test_case_math_ops
=
[
(
'BitwiseAnd'
,
{
'block'
:
P
.
BitwiseAnd
(),
...
...
@@ -1366,6 +1402,10 @@ test_case_nn_ops = [
'desc_inputs'
:
[
Tensor
(
np
.
array
([[
128
,
32
,
32
,
64
],
[
128
,
32
,
32
,
64
]]).
astype
(
np
.
float16
))],
'desc_bprop'
:
[
Tensor
(
np
.
array
([[
128
,
32
,
32
,
64
],
[
128
,
32
,
32
,
64
]]).
astype
(
np
.
float16
))],
'skip'
:
[
'backward'
]}),
(
'StridedSliceNet'
,
{
'block'
:
StridedSliceNet
(),
'desc_inputs'
:
[[
6
,
7
,
8
,
9
,
10
]],
'skip'
:
[
'backward'
]}),
(
'OneHot'
,
{
'block'
:
P
.
OneHot
(),
'desc_const'
:
[
3
,
Tensor
(
1.0
,
mstype
.
float32
),
Tensor
(
0.0
,
mstype
.
float32
)],
...
...
@@ -1763,7 +1803,7 @@ test_case_other_ops = [
'desc_bprop'
:
[([
3
,
3
],
{
'dtype'
:
np
.
int32
})]}),
(
'TensorScatterUpdate'
,
{
'block'
:
P
.
TensorScatterUpdate
(),
'desc_inputs'
:
(
Tensor
(
np
.
arange
(
3
*
4
*
5
).
reshape
((
3
,
4
,
5
)),
mstype
.
float32
),
'desc_inputs'
:
(
Tensor
(
np
.
arange
(
3
*
4
*
5
).
reshape
((
3
,
4
,
5
)),
mstype
.
float32
),
Tensor
(
np
.
array
([[
0
,
1
],
[
1
,
2
]],
np
.
int32
)),
Tensor
(
np
.
ones
([
2
,
5
],
np
.
float32
)
*
99
)),
'desc_bprop'
:
[([
3
,
4
,
5
],
{
'dtype'
:
np
.
float32
})]}),
...
...
@@ -1930,11 +1970,10 @@ test_case_other_ops = [
]
test_case_quant_ops
=
[
(
'AscendQuant_1'
,
{
'block'
:
inner
.
AscendQuant
(
0.5
,
0.0
,
False
,
"Round"
),
'desc_inputs'
:
[
Tensor
(
np
.
random
.
rand
(
1
,
2
,
4
,
4
),
mstype
.
float32
)],
'desc_inputs'
:
[
Tensor
(
np
.
random
.
rand
(
1
,
2
,
4
,
4
),
mstype
.
float32
)],
'skip'
:
[
'backward'
]}),
(
'AscendQuant_2'
,
{
'block'
:
inner
.
AscendQuant
(
80.0
,
10.0
,
True
,
"Round"
),
...
...
@@ -2027,6 +2066,18 @@ raise_set = [
'block'
:
(
nn
.
SSIM
(),
{
'exception'
:
ValueError
}),
'desc_inputs'
:
[
Tensor
(
np
.
ones
((
1
,
3
,
8
,
8
)),
mstype
.
float32
),
Tensor
(
np
.
ones
((
1
,
3
,
8
,
8
)),
mstype
.
float32
)]}),
(
'StridedSlice_0'
,
{
'block'
:
(
P
.
StridedSlice
(),
{
'exception'
:
ValueError
}),
'desc_const'
:
[(
1
,
2.2
,
3
),
(
3
,
4
,
5
),
(
1
,
1
,
1
)],
'desc_inputs'
:
[[
4
,
5
,
6
,
7
]]}),
(
'StridedSlice_1'
,
{
'block'
:
(
P
.
StridedSlice
(),
{
'exception'
:
ValueError
}),
'desc_const'
:
[(
1
,
2
,
3
),
(
3
,
4
,
5
),
(
1
,
1
)],
'desc_inputs'
:
[[
4
,
5
,
6
,
7
]]}),
(
'StridedSlice_2'
,
{
'block'
:
(
P
.
StridedSlice
(),
{
'exception'
:
ValueError
}),
'desc_const'
:
[(
1
,
2
,
3
),
(
3
,
4
,
5
),
(
1
,
1
,
0
)],
'desc_inputs'
:
[[
4
,
5
,
6
,
7
]]}),
]
...
...
tests/ut/python/ops/test_tensor_slice.py
浏览文件 @
30fc1bd0
...
...
@@ -25,6 +25,7 @@ from ....mindspore_test_framework.pipeline.forward.compile_forward \
import
pipeline_for_compile_forward_ge_graph_for_case_by_case_config
,
\
pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception
class
NetWorkSlicePositive
(
Cell
):
def
__init__
(
self
):
super
(
NetWorkSlicePositive
,
self
).
__init__
()
...
...
@@ -1159,10 +1160,8 @@ def test_tensor_slice_reduce_out_of_bounds_neg():
input_tensor
=
Tensor
(
np
.
ones
([
6
,
8
,
10
],
np
.
int32
))
net
=
NetWork
()
with
pytest
.
raises
(
ValueError
)
as
ex
:
with
pytest
.
raises
(
ValueError
):
net
(
input_tensor
)
assert
"For 'StridedSlice' the `begin[0]` should be an int and must greater or equal to -6, but got `-7`"
in
str
(
ex
.
value
)
def
test_tensor_slice_reduce_out_of_bounds_positive
():
...
...
@@ -1177,6 +1176,5 @@ def test_tensor_slice_reduce_out_of_bounds_positive():
input_tensor
=
Tensor
(
np
.
ones
([
6
,
8
,
10
],
np
.
int32
))
net
=
NetWork
()
with
pytest
.
raises
(
ValueError
)
as
ex
:
with
pytest
.
raises
(
ValueError
):
net
(
input_tensor
)
assert
"For 'StridedSlice' the `begin[0]` should be an int and must less than 6, but got `6`"
in
str
(
ex
.
value
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录