Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
663d5973
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
663d5973
编写于
4月 23, 2020
作者:
C
candanzg
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
tensor assign with slice index
Signed-off-by:
N
candanzg
<
zhangshucheng@huawei.com
>
上级
9edc69af
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
432 addition
and
35 deletion
+432
-35
mindspore/_extends/parse/__init__.py
mindspore/_extends/parse/__init__.py
+2
-2
mindspore/_extends/parse/parser.py
mindspore/_extends/parse/parser.py
+5
-0
mindspore/_extends/utils.py
mindspore/_extends/utils.py
+11
-0
mindspore/ccsrc/ir/value.h
mindspore/ccsrc/ir/value.h
+3
-0
mindspore/ccsrc/pipeline/parse/parse_base.h
mindspore/ccsrc/pipeline/parse/parse_base.h
+2
-0
mindspore/ccsrc/pipeline/static_analysis/prim.cc
mindspore/ccsrc/pipeline/static_analysis/prim.cc
+7
-0
mindspore/ccsrc/utils/convert_utils.cc
mindspore/ccsrc/utils/convert_utils.cc
+8
-0
mindspore/ops/composite/multitype_ops/_multitype_ops_util.py
mindspore/ops/composite/multitype_ops/_multitype_ops_util.py
+117
-0
mindspore/ops/composite/multitype_ops/setitem_impl.py
mindspore/ops/composite/multitype_ops/setitem_impl.py
+172
-28
mindspore/ops/functional.py
mindspore/ops/functional.py
+1
-0
tests/ut/python/ops/test_tensor_slice.py
tests/ut/python/ops/test_tensor_slice.py
+104
-5
未找到文件。
mindspore/_extends/parse/__init__.py
浏览文件 @
663d5973
...
...
@@ -18,7 +18,7 @@ Interfaces for parser module in c++.
from
.parser
import
(
Parser
,
create_obj_instance
,
generate_scope
,
get_bprop_method_of_class
,
get_class_instance_type
,
get_class_member_namespace_symbol
,
get_class_member_namespace_symbol
,
create_slice_obj
,
get_dataclass_attributes
,
get_dataclass_methods
,
get_module_namespace
,
get_obj_type
,
get_object_key
,
get_parse_method_of_class
,
get_scope_name
,
...
...
@@ -29,4 +29,4 @@ __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class',
'get_object_key'
,
'get_class_instance_type'
,
'is_class_member'
,
'get_obj_type'
,
'create_obj_instance'
,
'get_module_namespace'
,
'get_class_member_namespace_symbol'
,
'Parser'
,
'get_dataclass_attributes'
,
'get_dataclass_methods'
,
'dump_obj'
,
'load_obj'
,
'get_dataclass_methods'
,
'get_scope_name'
]
'get_dataclass_methods'
,
'get_scope_name'
,
'create_slice_obj'
]
mindspore/_extends/parse/parser.py
浏览文件 @
663d5973
...
...
@@ -29,6 +29,7 @@ from mindspore.common.dtype import pytype_to_dtype
from
mindspore.common.api
import
_MindSporeFunction
from
.namespace
import
CellNamespace
,
ClosureNamespace
,
ClassMemberNamespace
from
.resources
import
parse_object_map
,
convert_object_map
,
trope_ns
,
SYMBOL_UNDEFINE
,
NO_IMPLEMENT
from
..utils
import
Slice
# define return value
RET_SUCCESS
=
0
...
...
@@ -69,6 +70,10 @@ parse_expr_statement_white_list = (
"append"
,
)
def
create_slice_obj
(
start
,
end
,
step
):
"""Create Slice object"""
return
Slice
(
start
,
end
,
step
)
def
parse_cb
(
func
,
parse_method
=
None
):
"""Implements the function of parse."""
...
...
mindspore/_extends/utils.py
浏览文件 @
663d5973
...
...
@@ -19,6 +19,7 @@ import logging
import
os
import
inspect
from
functools
import
wraps
from
dataclasses
import
dataclass
def
cal_sha256
(
file_path
):
...
...
@@ -99,3 +100,13 @@ def cell_attr_register(fn=None, attrs=None):
if
fn
is
not
None
:
return
wrap_cell
(
fn
)
return
wrap_cell
@
dataclass
class
Slice
:
"""
Slice class
"""
start
:
int
end
:
int
step
:
int
mindspore/ccsrc/ir/value.h
浏览文件 @
663d5973
...
...
@@ -123,6 +123,9 @@ class ValueSlice : public Value {
abstract
::
AbstractBasePtr
ToAbstract
()
override
;
std
::
string
DumpText
()
const
override
{
return
ToString
();
}
ValuePtr
start
()
const
{
return
start_
;
}
ValuePtr
stop
()
const
{
return
stop_
;
}
ValuePtr
step
()
const
{
return
step_
;
}
private:
ValuePtr
start_
;
...
...
mindspore/ccsrc/pipeline/parse/parse_base.h
浏览文件 @
663d5973
...
...
@@ -79,6 +79,8 @@ const char PYTHON_PARSE_EXPAND_EXPR_STATEMENT[] = "expand_expr_statement";
const
char
PYTHON_PARSE_GENERATE_SCOPE
[]
=
"generate_scope"
;
const
char
PYTHON_PARSE_GET_SCOPE_NAME
[]
=
"get_scope_name"
;
const
char
PYTHON_PARSE_CLASS_SLICE
[]
=
"create_slice_obj"
;
// define the common name
const
char
NAMED_PRIMITIVE_ITER
[]
=
"iter"
;
const
char
NAMED_PRIMITIVE_NEXT
[]
=
"next"
;
...
...
mindspore/ccsrc/pipeline/static_analysis/prim.cc
浏览文件 @
663d5973
...
...
@@ -289,6 +289,13 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
dic
[
"shape"
]
=
shape
;
dic
[
"dtype"
]
=
abs_base
->
BuildType
();
dic
[
"value"
]
=
BuildValue
(
abs_base
->
BuildValue
());
}
else
if
(
abs_base
->
isa
<
AbstractSlice
>
())
{
auto
arg_slice
=
dyn_cast
<
AbstractSlice
>
(
abs_base
);
std
::
vector
<
int
>
shape
;
dic
[
"shape"
]
=
shape
;
dic
[
"dtype"
]
=
arg_slice
->
BuildType
();
dic
[
"value"
]
=
BuildValue
(
arg_slice
->
BuildValue
());
}
else
if
(
abs_base
->
isa
<
AbstractTuple
>
())
{
auto
arg_tuple
=
dyn_cast
<
AbstractTuple
>
(
abs_base
);
size_t
len
=
arg_tuple
->
size
();
...
...
mindspore/ccsrc/utils/convert_utils.cc
浏览文件 @
663d5973
...
...
@@ -28,6 +28,7 @@
#include "ir/meta_tensor.h"
#include "pipeline/parse/parse.h"
#include "pipeline/parse/parse_base.h"
#include "ir/value.h"
namespace
mindspore
{
...
...
@@ -97,6 +98,13 @@ py::object ValuePtrToPyData(const ValuePtr &value) {
i
++
;
}
ret
=
rets
;
}
else
if
(
value
->
isa
<
ValueSlice
>
())
{
auto
slice
=
value
->
cast
<
ValueSlicePtr
>
();
auto
start
=
ValuePtrToPyData
(
slice
->
start
());
auto
end
=
ValuePtrToPyData
(
slice
->
stop
());
auto
step
=
ValuePtrToPyData
(
slice
->
step
());
ret
=
parse
::
python_adapter
::
CallPyFn
(
parse
::
PYTHON_MOD_PARSE_MODULE
,
parse
::
PYTHON_PARSE_CLASS_SLICE
,
start
,
end
,
step
);
}
else
if
(
value
->
isa
<
Type
>
())
{
py
::
tuple
v
(
1
);
v
[
0
]
=
value
->
cast
<
TypePtr
>
();
...
...
mindspore/ops/composite/multitype_ops/_multitype_ops_util.py
浏览文件 @
663d5973
...
...
@@ -15,7 +15,43 @@
"""constexpr util"""
import
numpy
as
np
from
...primitive
import
constexpr
from
....common.tensor
import
Tensor
from
....common
import
dtype
as
mstype
from
...._extends.utils
import
Slice
@
constexpr
def
check_equal
(
param1
,
param2
,
msg
=
"{},{}"
):
if
param1
!=
param2
:
raise
ValueError
(
msg
.
format
(
param1
,
param2
))
return
param1
@
constexpr
def
check_tensor_setitem_index
(
index
,
element_type
=
None
):
"""Check tuple index type of tensor assignment."""
if
index
is
None
:
raise
ValueError
(
"Tensor's index cannot be None."
)
# eg. Tensor[Slice] = u
if
isinstance
(
index
,
Slice
):
return
True
# eg. Tensor[Tuple] = u
if
isinstance
(
index
,
tuple
):
if
not
index
:
raise
ValueError
(
"Tensor's index cannot be empty."
)
# eg. Tensor[Tuple(Slice...)] = u
if
not
isinstance
(
index
[
0
],
Slice
):
raise
ValueError
(
"Index of type '{}' is not supported yet."
.
format
(
type
(
index
[
0
])))
return
True
# eg. Tensor[Tensor[dtype=bool]] = u
if
index
==
mstype
.
tensor
:
if
element_type
is
None
or
element_type
!=
mstype
.
bool_
:
raise
ValueError
(
"The index of tensor should be a bool type tensor.
\
{} type is not supported yet."
.
format
(
element_type
))
return
True
raise
ValueError
(
"Index of type '{}' is not supported yet."
.
format
(
type
(
index
)))
@
constexpr
...
...
@@ -43,3 +79,84 @@ def error_msg(msg="", format_values=""):
"""
raise
ValueError
(
msg
.
format
(
*
format_values
))
def
slice_expand
(
input_slices
,
shape
):
"""
Convert slice to indices.
Inputs:
slices (List or Tuple(List, ...)): Slice tuple or slice.
shape (Tuple): The shape of a sensor is an integer element tuple.
Outputs:
(List, List, List), This is expressed as (begins, ends, strides).
"""
begin
=
[]
end
=
[]
strides
=
[]
index
=
0
slices
=
None
# Slice or Tuple(Slice...)
if
isinstance
(
input_slices
,
Slice
):
slices
=
(
input_slices
,)
elif
isinstance
(
input_slices
,
(
tuple
,
list
))
and
input_slices
and
isinstance
(
input_slices
[
0
],
Slice
):
slices
=
input_slices
else
:
raise
ValueError
(
"Tensor's index type is not supported yet."
)
for
s
in
slices
:
start
=
0
if
(
s
.
start
is
None
)
else
s
.
start
stop
=
shape
[
index
]
if
(
s
.
end
is
None
)
else
s
.
end
step
=
1
if
(
s
.
step
is
None
)
else
s
.
step
begin
.
append
(
start
)
end
.
append
(
stop
)
strides
.
append
(
step
)
index
+=
1
while
index
<
len
(
shape
):
begin
.
append
(
0
)
end
.
append
(
shape
[
index
])
strides
.
append
(
1
)
index
+=
1
return
begin
,
end
,
strides
@
constexpr
def
slice2indices
(
input_slices
,
shape
):
"""
Convert slice to indices.
Inputs:
slices (List or Tuple(List, ...)): Slice tuple or slice.
shape (Tuple): The shape of a sensor is an integer element tuple.
Outputs:
Tensor, the shape is (n, 1).
"""
begin
,
end
,
strides
=
slice_expand
(
input_slices
,
shape
)
np_r
=
[]
for
i
,
element
in
enumerate
(
shape
):
s
=
begin
[
i
]
if
(
begin
[
i
]
>=
0
)
else
(
element
+
begin
[
i
])
e
=
end
[
i
]
if
(
end
[
i
]
>=
0
)
else
(
element
+
end
[
i
])
np_r
.
append
(
np
.
r_
[
s
:
e
:
strides
[
i
]])
# Reference: np.ravel_multi_index((np.ix_(np.r_[1:3:1], np.r_[0:4:1], np.r_[4:0:-1])), a.shape)
np_ix
=
np
.
ix_
(
*
np_r
)
ravel
=
np
.
ravel_multi_index
(
np_ix
,
shape
)
ravel
=
Tensor
(
ravel
.
reshape
(
-
1
,
1
),
dtype
=
mstype
.
int32
)
return
ravel
@
constexpr
def
check_indices
(
indices_size
,
index
):
if
indices_size
<
1
:
raise
ValueError
(
"The tensor's index is unreasonable. index:{}"
.
format
(
index
))
return
indices_size
@
constexpr
def
check_indices_value_size
(
indices_size
,
value_size
):
if
value_size
<
1
:
raise
ValueError
(
"The value assigned to tensor cannot be empty."
)
if
value_size
>
1
:
if
value_size
!=
indices_size
:
raise
ValueError
(
"The value given to tensor does not match the index size.
\
value size:{}, indics size:{}"
.
format
(
value_size
,
indices_size
))
return
value_size
mindspore/ops/composite/multitype_ops/setitem_impl.py
浏览文件 @
663d5973
...
...
@@ -138,25 +138,23 @@ def _tensor_setitem_by_tensor_v1(data, index, value_tensor):
Outputs:
Tensor, element type and shape is same as data.
"""
result
=
None
index_dtype
=
F
.
dtype
(
index
)
index_shape
=
F
.
shape
(
index
)
is_bool
=
mult_util
.
is_same_type
(
index_dtype
,
mstype
.
bool_
)
if
not
is_bool
:
return
mult_util
.
error_msg
(
"The tensor index should be a bool type tensor. {} type tensor is not supported yet."
,
(
index_dtype
,))
data_shape
=
F
.
shape
(
data
)
if
index_shape
!=
data_shape
:
return
mult_util
.
error_msg
(
"The tensor(shape={}) and tensor index(shape={}) should be the same shape."
,
(
data_shape
,
index_shape
))
size
=
F
.
size
(
value_tensor
)
if
size
!=
1
:
return
mult_util
.
error_msg
(
"When assign value is a tensor, its size should be 1, but current size is {}."
,
(
size
,))
dtype
=
F
.
dtype
(
data
)
u_cast
=
F
.
cast
(
value_tensor
,
dtype
)
one_data
=
F
.
ones_like
(
data
)
u
=
F
.
tensor_mul
(
one_data
,
u_cast
)
return
F
.
select
(
index
,
u
,
data
)
check_result
=
mult_util
.
check_tensor_setitem_index
(
mstype
.
tensor
,
index_dtype
)
if
check_result
:
data_shape
=
F
.
shape
(
data
)
data_shape
=
mult_util
.
check_equal
(
data_shape
,
index_shape
,
"The tensor(shape={}) and tensor index(shape={}) should be the same shape."
)
size
=
F
.
size
(
value_tensor
)
size
=
mult_util
.
check_equal
(
1
,
size
,
"When assign value is a tensor, its size should be {}, but current size is {}."
)
dtype
=
F
.
dtype
(
data
)
u_cast
=
F
.
cast
(
value_tensor
,
dtype
)
one_data
=
F
.
ones_like
(
data
)
u
=
F
.
tensor_mul
(
one_data
,
u_cast
)
result
=
F
.
select
(
index
,
u
,
data
)
return
result
@
setitem
.
register
(
"Tensor"
,
"Tensor"
,
"Number"
)
...
...
@@ -179,16 +177,162 @@ def _tensor_setitem_by_tensor_v2(data, index, value):
Outputs:
Tensor, element type and shape is same as data.
"""
result
=
None
index_dtype
=
F
.
dtype
(
index
)
index_shape
=
F
.
shape
(
index
)
is_bool
=
mult_util
.
is_same_type
(
index_dtype
,
mstype
.
bool_
)
if
not
is_bool
:
return
mult_util
.
error_msg
(
"The tensor index should be a bool type tensor. {} type tensor is not supported yet."
,
(
index_dtype
,))
shape
=
F
.
shape
(
data
)
if
index_shape
!=
shape
:
return
mult_util
.
error_msg
(
"The tensor(shape={}) and tensor index(shape={}) should be the same shape."
,
(
shape
,
index_shape
))
dtype
=
F
.
dtype
(
data
)
u
=
F
.
fill
(
dtype
,
shape
,
value
)
return
F
.
select
(
index
,
u
,
data
)
check_result
=
mult_util
.
check_tensor_setitem_index
(
mstype
.
tensor
,
index_dtype
)
if
check_result
:
shape
=
F
.
shape
(
data
)
shape
=
mult_util
.
check_equal
(
shape
,
index_shape
,
"The tensor(shape={}) and tensor index(shape={}) should be the same shape."
)
dtype
=
F
.
dtype
(
data
)
u
=
F
.
fill
(
dtype
,
shape
,
value
)
result
=
F
.
select
(
index
,
u
,
data
)
return
result
@
setitem
.
register
(
"Tensor"
,
"Slice"
,
"Tensor"
)
def
_tensor_setitem_with_slice_v3
(
data
,
input_slice
,
value
):
"""
Tensor assignment.
Note:
Syntax support: A[Slice] = U
Restraint condition: A is a Tensor
Slice like "1:3"
U is a Tensor(size=1) or Tensor(size>1)
Inputs:
data (Tensor): Assigned tensor.
input_slice (Slice): Slice expression.
value (Number): Assignment value.
Outputs:
Tensor, element type and shape is same as data.
"""
return
_tensor_assgin_tensor
(
data
,
input_slice
,
value
)
@
setitem
.
register
(
"Tensor"
,
"Tuple"
,
"Tensor"
)
def
_tensor_setitem_with_slice_v4
(
data
,
input_slice
,
value
):
"""
Tensor assignment.
Note:
Syntax support: A[Slice] = U
Restraint condition: A is a Tensor
Slice like "1:3, ::, :4:-1"
U is a Tensor(size=1) or Tensor(size>1)
Inputs:
data (Tensor): Assigned tensor.
input_slice (Tuple(Slice)): Slice expression.
value (Number): Assignment value.
Outputs:
Tensor, element type and shape is same as data.
"""
return
_tensor_assgin_tensor
(
data
,
input_slice
,
value
)
def
_tensor_assgin_tensor
(
data
,
input_slice
,
value
):
"""Given a tensor value assign to tensor by slice"""
# 1. condition
result
=
None
check_result
=
mult_util
.
check_tensor_setitem_index
(
input_slice
)
if
check_result
:
data_shape
=
F
.
shape
(
data
)
data_size
=
F
.
size
(
data
)
data_dtype
=
F
.
dtype
(
data
)
indices
=
mult_util
.
slice2indices
(
input_slice
,
data_shape
)
indices_size
=
F
.
size
(
indices
)
indices_size
=
mult_util
.
check_indices
(
indices_size
,
input_slice
)
update
=
F
.
fill
(
data_dtype
,
(
indices_size
,),
1
)
condition_1d
=
F
.
scatter_nd
(
indices
,
update
,
(
data_size
,))
condition_1d
=
F
.
cast
(
condition_1d
,
mstype
.
bool_
)
condition
=
F
.
reshape
(
condition_1d
,
data_shape
)
# 2. u
value_fill
=
None
value_size
=
F
.
size
(
value
)
value_size
=
mult_util
.
check_indices_value_size
(
indices_size
,
value_size
)
if
value_size
==
1
:
value_fill
=
F
.
fill
(
data_dtype
,
(
indices_size
,),
1
)
value
=
F
.
cast
(
value
,
data_dtype
)
value_fill
=
F
.
tensor_mul
(
value_fill
,
value
)
elif
value_size
>
1
:
value_fill
=
F
.
reshape
(
value
,
(
indices_size
,))
value_1d
=
F
.
scatter_nd
(
indices
,
value_fill
,
(
data_size
,))
u
=
F
.
reshape
(
value_1d
,
data_shape
)
# A[slice]= u -> A[B]=U -> select(B, U, A)
result
=
F
.
select
(
condition
,
u
,
data
)
return
result
@
setitem
.
register
(
"Tensor"
,
"Slice"
,
"Number"
)
def
_tensor_setitem_with_slice_v1
(
data
,
input_slice
,
value
):
"""
Tensor assignment.
Note:
Syntax support: A[Slice] = u
Restraint condition: A is a Tensor.
Slice like "1:3"
u is a scalar
Inputs:
data (Tensor): Assigned tensor.
input_slice (Slice): slice expression.
value (Number): Assignment value.
Outputs:
Tensor, element type and shape is same as data.
"""
return
_tensor_assgin_number
(
data
,
input_slice
,
value
)
@
setitem
.
register
(
"Tensor"
,
"Tuple"
,
"Number"
)
def
_tensor_setitem_with_slice_v2
(
data
,
input_slice
,
value
):
"""
Tensor assignment.
Note:
Syntax support: A[Slice] = u
Restraint condition: A is a Tensor.
Slice like "1:3, ::, :4:-1"
u is a scalar
Inputs:
data (Tensor): Assigned tensor.
input_slice (Tuple(Slice)): slice expression.
value (Number): Assignment value.
Outputs:
Tensor, element type and shape is same as data.
"""
return
_tensor_assgin_number
(
data
,
input_slice
,
value
)
def
_tensor_assgin_number
(
data
,
input_slice
,
value
):
"""Given a scalar assign to tensor by slice"""
# 1. condition
check_result
=
mult_util
.
check_tensor_setitem_index
(
input_slice
)
result
=
None
if
check_result
:
data_shape
=
F
.
shape
(
data
)
data_size
=
F
.
size
(
data
)
data_dtype
=
F
.
dtype
(
data
)
indices
=
mult_util
.
slice2indices
(
input_slice
,
data_shape
)
indices_size
=
F
.
size
(
indices
)
indices_size
=
mult_util
.
check_indices
(
indices_size
,
input_slice
)
update
=
F
.
fill
(
data_dtype
,
(
indices_size
,),
1
)
condition_1d
=
F
.
scatter_nd
(
indices
,
update
,
(
data_size
,))
condition_1d
=
F
.
cast
(
condition_1d
,
mstype
.
bool_
)
condition
=
F
.
reshape
(
condition_1d
,
data_shape
)
# 2. u
value_fill
=
F
.
fill
(
data_dtype
,
(
indices_size
,),
value
)
value_1d
=
F
.
scatter_nd
(
indices
,
value_fill
,
(
data_size
,))
u
=
F
.
reshape
(
value_1d
,
data_shape
)
# A[slice]= u -> A[B]=U -> select(B, U, A)
result
=
F
.
select
(
condition
,
u
,
data
)
return
result
mindspore/ops/functional.py
浏览文件 @
663d5973
...
...
@@ -68,6 +68,7 @@ tuple_to_array = P.TupleToArray()
scalar_cast
=
P
.
ScalarCast
()
print_
=
P
.
Print
()
expand_dims
=
P
.
ExpandDims
()
scatter_nd
=
P
.
ScatterNd
()
tuple_setitem
=
Primitive
(
'tuple_setitem'
)
tuple_getitem
=
Primitive
(
'tuple_getitem'
)
...
...
tests/ut/python/ops/test_tensor_slice.py
浏览文件 @
663d5973
...
...
@@ -94,10 +94,101 @@ class NetWorkReduceToScalar(Cell):
return
ret
class
TensorAssignWithSliceError1
(
Cell
):
def
__init__
(
self
):
super
(
TensorAssignWithSliceError1
,
self
).
__init__
()
def
construct
(
self
,
a
,
b
):
a
[
1
:
3
:
-
1
,::]
=
b
return
a
class
TensorAssignWithSliceError2
(
Cell
):
def
__init__
(
self
):
super
(
TensorAssignWithSliceError2
,
self
).
__init__
()
def
construct
(
self
,
a
,
b
):
a
[
1
:
3
:
-
1
]
=
b
return
a
class
TensorAssignWithSlice2
(
Cell
):
def
__init__
(
self
):
super
(
TensorAssignWithSlice2
,
self
).
__init__
()
def
construct
(
self
,
a
,
b
):
a
[
1
:
5
]
=
b
a
[
3
:
4
]
=
5
a
[
-
1
:
1
:
-
1
]
=
b
a
[
-
1
:
3
:
-
1
]
=
5
a
[::]
=
b
a
[::]
=
9
return
a
class
TensorAssignWithSlice
(
Cell
):
def
__init__
(
self
):
super
(
TensorAssignWithSlice
,
self
).
__init__
()
self
.
c
=
2
def
construct
(
self
,
a
,
b
):
a
[
1
:
3
,::]
=
b
a
[
2
:
3
:,
3
:]
=
b
a
[::]
=
b
a
[::]
=
self
.
c
a
[::,::]
=
b
a
[::,::]
=
self
.
c
a
[
2
:
3
:,
0
:,
4
:
1
:
-
1
]
=
b
a
[
2
:
3
:,
0
:,
4
:
1
:
-
1
]
=
self
.
c
z
=
a
return
z
def
test_tensor_assign_with_slice
():
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
save_graphs
=
True
)
net
=
TensorAssignWithSlice
()
net2
=
TensorAssignWithSlice2
()
net_e1
=
TensorAssignWithSliceError1
()
net_e2
=
TensorAssignWithSliceError2
()
a
=
np
.
arange
(
60
).
reshape
(
3
,
4
,
5
)
b
=
Tensor
([
1
])
Ta
=
Tensor
(
a
)
Tb
=
Tensor
([
1
,
3
])
Tc
=
Tensor
([])
t
=
Tensor
([
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
])
net
(
Ta
,
b
)
net2
(
t
,
b
)
# Error for A[Slice] = Number
# 1. A[Slice] = Number, Slice error
with
pytest
.
raises
(
ValueError
):
net_e2
(
t
,
2
)
# Error for A[Slice] = U, U is a Tensor
# 1. A[Slice] = U, u.size is error
with
pytest
.
raises
(
ValueError
):
net2
(
t
,
Tb
)
# 2. A[Slice] = U, U is empty
with
pytest
.
raises
(
ValueError
):
net2
(
t
,
Tc
)
# 3. A[Slice] = U, U.size error
with
pytest
.
raises
(
ValueError
):
net2
(
t
,
Tb
)
# Error for A[Tuple(Slice...)] = Tensor
# 1. A[Tuple(Slice...)] = U, U is empty
with
pytest
.
raises
(
ValueError
):
net
(
Ta
,
Tc
)
# 2. A[Tuple(Slice...)] = U, U.size error
with
pytest
.
raises
(
ValueError
):
net
(
Ta
,
Tb
)
# 3. A[Tuple(Slice...)] = U, Slice error
with
pytest
.
raises
(
ValueError
):
net_e1
(
Ta
,
b
)
# Error for A[Tuple(Slice...)] = Number
# 1. A[Tuple(Slice...)] = Number, Slice error
with
pytest
.
raises
(
ValueError
):
net_e1
(
Ta
,
2
)
class
TensorAssignWithBoolTensorIndex
(
Cell
):
def
__init__
(
self
):
super
(
TensorAssignWithBoolTensorIndex
,
self
).
__init__
()
self
.
t
=
Tensor
(
np
.
arange
(
6
).
reshape
([
2
,
3
]),
dtype
=
mstype
.
float64
)
self
.
t
=
Tensor
(
np
.
arange
(
6
0
).
reshape
([
3
,
4
,
5
]),
dtype
=
mstype
.
float64
)
def
construct
(
self
,
a
,
b
,
c
,
u_tensor
,
_scalar
):
a
[
c
]
=
u_scalar
...
...
@@ -119,6 +210,7 @@ class TensorAssignWithBoolTensorIndex2(Cell):
def
__init__
(
self
):
super
(
TensorAssignWithBoolTensorIndex2
,
self
).
__init__
()
self
.
t
=
Tensor
(
np
.
arange
(
6
).
reshape
([
2
,
3
]),
dtype
=
mstype
.
float64
)
self
.
t
=
Tensor
(
np
.
arange
(
60
).
reshape
([
3
,
4
,
5
]),
dtype
=
mstype
.
float64
)
def
construct
(
self
,
a
,
u_tensor
,
_scalar
):
a
[
a
>
8
]
=
u_tensor
...
...
@@ -139,7 +231,7 @@ class TensorAssignWithBoolTensorIndex2Error(Cell):
return
a
a
=
np
.
random
.
uniform
(
1
,
10
,
[
2
,
3
])
a
=
np
.
random
.
uniform
(
1
,
10
,[
3
,
4
,
5
])
b
=
a
>
5
c
=
a
<
3
Ta
=
Tensor
(
a
)
...
...
@@ -148,13 +240,13 @@ Tc = Tensor(c)
Td
=
Tensor
([
True
,
True
])
u_tensor
=
Tensor
([
1
])
u_tensor_error
=
Tensor
([
1
,
2
])
t_1d
=
Tensor
([
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
])
u_scalar
=
5
def
test_tensor_assign_bool_index
():
net1
=
TensorAssignWithBoolTensorIndex
()
net2
=
TensorAssignWithBoolTensorIndex2
()
net1
(
Ta
,
Tb
,
Tc
,
u_tensor
,
u_scalar
)
net1
(
Ta
,
Tb
,
Tc
,
u_tensor
,
u_scalar
)
with
pytest
.
raises
(
ValueError
):
net1
(
Ta
,
Td
,
Tc
,
u_tensor
,
u_scalar
)
...
...
@@ -180,8 +272,15 @@ def test_tensor_assign_bool_index():
with
pytest
.
raises
(
AttributeError
):
net4
(
Ta
,
u_scalar
)
test_cases
=
[
(
'TensorAssignWithSlice'
,
{
'block'
:
TensorAssignWithSlice
(),
'desc_inputs'
:
[
Ta
,
u_tensor
],
}),
(
'TensorAssignWithSlice2'
,
{
'block'
:
TensorAssignWithSlice2
(),
'desc_inputs'
:
[
t_1d
,
u_tensor
],
}),
(
'TensorAssignWithBoolTensorIndex'
,
{
'block'
:
TensorAssignWithBoolTensorIndex
(),
'desc_inputs'
:
[
Ta
,
Tb
,
Tc
,
u_tensor
,
u_scalar
],
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录