Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
e886a318
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
e886a318
编写于
4月 26, 2020
作者:
C
candanzg
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
tensor assign with ellpsis
Signed-off-by:
N
candanzg
<
zhangshucheng@huawei.com
>
上级
decc8404
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
254 addition
and
86 deletion
+254
-86
mindspore/_extends/parse/__init__.py
mindspore/_extends/parse/__init__.py
+2
-2
mindspore/_extends/parse/parser.py
mindspore/_extends/parse/parser.py
+6
-1
mindspore/_extends/utils.py
mindspore/_extends/utils.py
+7
-0
mindspore/ccsrc/pipeline/parse/parse_base.h
mindspore/ccsrc/pipeline/parse/parse_base.h
+1
-0
mindspore/ccsrc/pipeline/static_analysis/prim.cc
mindspore/ccsrc/pipeline/static_analysis/prim.cc
+6
-0
mindspore/ccsrc/utils/convert_utils.cc
mindspore/ccsrc/utils/convert_utils.cc
+2
-0
mindspore/ops/composite/multitype_ops/_multitype_ops_util.py
mindspore/ops/composite/multitype_ops/_multitype_ops_util.py
+62
-19
mindspore/ops/composite/multitype_ops/setitem_impl.py
mindspore/ops/composite/multitype_ops/setitem_impl.py
+33
-4
tests/ut/python/ops/test_tensor_slice.py
tests/ut/python/ops/test_tensor_slice.py
+135
-60
未找到文件。
mindspore/_extends/parse/__init__.py
浏览文件 @
e886a318
...
...
@@ -22,11 +22,11 @@ from .parser import (Parser, create_obj_instance, generate_scope,
get_dataclass_attributes
,
get_dataclass_methods
,
get_module_namespace
,
get_obj_type
,
get_object_key
,
get_parse_method_of_class
,
get_scope_name
,
is_class_member
,
parse_cb
,
resolve_symbol
)
is_class_member
,
parse_cb
,
resolve_symbol
,
create_ellipsis_obj
)
from
.serialize
import
*
__all__
=
[
'parse_cb'
,
'get_parse_method_of_class'
,
'get_bprop_method_of_class'
,
'resolve_symbol'
,
'get_object_key'
,
'get_class_instance_type'
,
'is_class_member'
,
'get_obj_type'
,
'create_obj_instance'
,
'get_module_namespace'
,
'get_class_member_namespace_symbol'
,
'Parser'
,
'get_dataclass_attributes'
,
'get_dataclass_methods'
,
'dump_obj'
,
'load_obj'
,
'get_dataclass_methods'
,
'get_scope_name'
,
'create_slice_obj'
]
'get_dataclass_methods'
,
'get_scope_name'
,
'create_slice_obj'
,
'create_ellipsis_obj'
]
mindspore/_extends/parse/parser.py
浏览文件 @
e886a318
...
...
@@ -29,7 +29,7 @@ from mindspore.common.dtype import pytype_to_dtype
from
mindspore.common.api
import
_MindSporeFunction
from
.namespace
import
CellNamespace
,
ClosureNamespace
,
ClassMemberNamespace
from
.resources
import
parse_object_map
,
convert_object_map
,
trope_ns
,
SYMBOL_UNDEFINE
,
NO_IMPLEMENT
from
..utils
import
Slice
from
..utils
import
Slice
,
Ellipsis_
# define return value
RET_SUCCESS
=
0
...
...
@@ -70,6 +70,11 @@ parse_expr_statement_white_list = (
"append"
,
)
def
create_ellipsis_obj
():
"""Create Slice object"""
return
Ellipsis_
()
def
create_slice_obj
(
start
,
end
,
step
):
"""Create Slice object"""
return
Slice
(
start
,
end
,
step
)
...
...
mindspore/_extends/utils.py
浏览文件 @
e886a318
...
...
@@ -110,3 +110,10 @@ class Slice:
start
:
int
end
:
int
step
:
int
@
dataclass
class
Ellipsis_
:
"""
Ellipsis class
"""
mindspore/ccsrc/pipeline/parse/parse_base.h
浏览文件 @
e886a318
...
...
@@ -80,6 +80,7 @@ const char PYTHON_PARSE_GENERATE_SCOPE[] = "generate_scope";
const
char
PYTHON_PARSE_GET_SCOPE_NAME
[]
=
"get_scope_name"
;
const
char
PYTHON_PARSE_CLASS_SLICE
[]
=
"create_slice_obj"
;
const
char
PYTHON_PARSE_CLASS_ELLIPSIS
[]
=
"create_ellipsis_obj"
;
// define the common name
const
char
NAMED_PRIMITIVE_ITER
[]
=
"iter"
;
...
...
mindspore/ccsrc/pipeline/static_analysis/prim.cc
浏览文件 @
e886a318
...
...
@@ -298,6 +298,12 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
}
else
if
(
abs_base
->
isa
<
AbstractRef
>
())
{
auto
value
=
abs_base
->
cast
<
AbstractRefPtr
>
()
->
ref
();
dic
=
ConvertAbstractToPython
(
value
);
}
else
if
(
abs_base
->
isa
<
AbstractEllipsis
>
())
{
auto
arg_slice
=
dyn_cast
<
AbstractEllipsis
>
(
abs_base
);
std
::
vector
<
int
>
shape
;
dic
[
"shape"
]
=
shape
;
dic
[
"dtype"
]
=
arg_slice
->
BuildType
();
dic
[
"value"
]
=
BuildValue
(
arg_slice
->
BuildValue
());
}
else
if
(
abs_base
->
isa
<
AbstractTuple
>
())
{
auto
arg_tuple
=
dyn_cast
<
AbstractTuple
>
(
abs_base
);
size_t
len
=
arg_tuple
->
size
();
...
...
mindspore/ccsrc/utils/convert_utils.cc
浏览文件 @
e886a318
...
...
@@ -98,6 +98,8 @@ py::object ValuePtrToPyData(const ValuePtr &value) {
i
++
;
}
ret
=
rets
;
}
else
if
(
value
->
isa
<
EllipsisObj
>
())
{
ret
=
parse
::
python_adapter
::
CallPyFn
(
parse
::
PYTHON_MOD_PARSE_MODULE
,
parse
::
PYTHON_PARSE_CLASS_ELLIPSIS
);
}
else
if
(
value
->
isa
<
ValueSlice
>
())
{
auto
slice
=
value
->
cast
<
ValueSlicePtr
>
();
auto
start
=
ValuePtrToPyData
(
slice
->
start
());
...
...
mindspore/ops/composite/multitype_ops/_multitype_ops_util.py
浏览文件 @
e886a318
...
...
@@ -20,7 +20,7 @@ import numpy as np
from
...primitive
import
constexpr
from
....common.tensor
import
Tensor
from
....common
import
dtype
as
mstype
from
...._extends.utils
import
Slice
from
...._extends.utils
import
Slice
,
Ellipsis_
@
constexpr
def
check_equal
(
param1
,
param2
,
msg
=
"{},{}"
):
...
...
@@ -29,31 +29,40 @@ def check_equal(param1, param2, msg="{},{}"):
raise
ValueError
(
msg
.
format
(
param1
,
param2
))
return
param1
@
constexpr
def
check_ellipsis_shape_size
(
data_shape
,
value_shape
,
data_size
,
value_size
):
"""Checks the shape and size of the sensor and value."""
if
data_shape
==
value_shape
or
data_size
==
value_size
or
value_size
==
1
:
return
True
raise
ValueError
(
"The value(shape={}), can not assign to tensor(shape={})."
.
format
(
value_shape
,
data_shape
))
@
constexpr
def
check_tensor_setitem_index
(
index
,
element_type
=
None
):
"""Checks tuple index type of tensor assignment."""
if
index
is
None
:
raise
Value
Error
(
"Tensor's index cannot be None."
)
raise
Index
Error
(
"Tensor's index cannot be None."
)
# eg. Tensor[Slice] = u
if
isinstance
(
index
,
Slice
):
return
True
# eg. Tensor[tuple] = u
if
isinstance
(
index
,
tuple
):
if
not
index
:
raise
Value
Error
(
"Tensor's index cannot be empty."
)
raise
Index
Error
(
"Tensor's index cannot be empty."
)
# eg. Tensor[tuple(Slice...)] = u
if
isinstance
(
index
[
0
],
(
Slice
,
int
)):
if
isinstance
(
index
[
0
],
(
Slice
,
Ellipsis_
,
int
)):
return
True
raise
Value
Error
(
"Index of type '{}' is not supported yet."
.
format
(
type
(
index
[
0
])))
raise
Index
Error
(
"Index of type '{}' is not supported yet."
.
format
(
type
(
index
[
0
])))
# eg. Tensor[Tensor[dtype=bool]] = u
if
index
==
mstype
.
tensor
:
if
element_type
is
None
or
element_type
!=
mstype
.
bool_
:
raise
Valu
eError
(
"The index of tensor should be a bool type tensor.
\
{} type is not supported yet."
.
format
(
element_type
))
raise
Typ
eError
(
"The index of tensor should be a bool type tensor.
"
"
{} type is not supported yet."
.
format
(
element_type
))
return
True
raise
Value
Error
(
"Index of type '{}' is not supported yet."
.
format
(
type
(
index
)))
raise
Index
Error
(
"Index of type '{}' is not supported yet."
.
format
(
type
(
index
)))
@
constexpr
...
...
@@ -90,10 +99,18 @@ def slice_expand(input_slices, shape):
# Slice or tuple(Slice...)
if
isinstance
(
input_slices
,
Slice
):
slices
=
(
input_slices
,)
elif
isinstance
(
input_slices
,
(
tuple
,
list
))
and
input_slices
and
isinstance
(
input_slices
[
0
],
Slice
):
slices
=
input_slices
elif
isinstance
(
input_slices
,
(
tuple
,
list
))
and
input_slices
and
isinstance
(
input_slices
[
0
],
(
Slice
,
Ellipsis_
)):
is_have_ellipsis
=
False
for
_
,
element
in
enumerate
(
input_slices
):
if
isinstance
(
element
,
Ellipsis_
):
is_have_ellipsis
=
True
break
if
is_have_ellipsis
:
slices
=
ellipsis2slice
(
input_slices
,
shape
)
else
:
slices
=
input_slices
else
:
raise
Value
Error
(
"Tensor's index type is not supported yet."
)
raise
Index
Error
(
"Tensor's index type is not supported yet."
)
for
s
in
slices
:
start
=
0
if
(
s
.
start
is
None
)
else
s
.
start
...
...
@@ -111,6 +128,26 @@ def slice_expand(input_slices, shape):
return
begin
,
end
,
strides
def
ellipsis2slice
(
input_
,
shape
):
"""Converts ellipsis to slice."""
input_slice
=
input_
result
=
[]
if
isinstance
(
input_
,
Ellipsis_
):
input_slice
=
(
input_
,)
ell_count
=
0
for
_
,
element
in
enumerate
(
input_slice
):
if
not
isinstance
(
element
,
Ellipsis_
):
result
.
append
(
element
)
continue
ell_count
+=
1
if
ell_count
>
1
:
raise
IndexError
(
"There cannot be more than one ellisis (...) in the index of the tensor, "
"but it is currently {}"
.
format
(
input_slice
))
for
_
in
range
(
len
(
shape
)
-
len
(
input_slice
)
+
1
):
result
.
append
(
Slice
(
None
,
None
,
None
))
return
tuple
(
result
)
@
constexpr
def
slice2indices
(
input_slices
,
shape
):
"""
...
...
@@ -139,7 +176,7 @@ def slice2indices(input_slices, shape):
def
check_indices
(
indices_size
,
index
):
"""Checks indices whether is empty."""
if
indices_size
<
1
:
raise
Value
Error
(
"The tensor's index is unreasonable. index:{}"
.
format
(
index
))
raise
Index
Error
(
"The tensor's index is unreasonable. index:{}"
.
format
(
index
))
return
indices_size
...
...
@@ -151,8 +188,8 @@ def check_indices_value_size(indices_size, value_size):
if
value_size
>
1
:
if
value_size
!=
indices_size
:
raise
ValueError
(
"The value given to tensor does not match the index size
.
\
value size:{}, indics size:{}"
.
format
(
value_size
,
indices_size
))
"The value given to tensor does not match the index size
,"
"
value size:{}, indics size:{}"
.
format
(
value_size
,
indices_size
))
return
value_size
@
constexpr
...
...
@@ -168,8 +205,11 @@ def integer_to_indices(index, shape):
def
tuple_element_is_slice
(
indexs
):
"""Judges tuple element type."""
if
not
indexs
:
raise
ValueError
(
"Tensor's index cannot be empty."
)
if
isinstance
(
indexs
,
tuple
)
and
isinstance
(
indexs
[
0
],
Slice
):
raise
IndexError
(
"Tensor's index cannot be empty."
)
if
isinstance
(
indexs
,
tuple
):
for
_
,
ele
in
enumerate
(
indexs
):
if
not
isinstance
(
ele
,
Slice
):
return
False
return
True
return
False
...
...
@@ -177,7 +217,10 @@ def tuple_element_is_slice(indexs):
def
tuple_element_is_int
(
indexs
):
"""Judges tuple element type."""
if
not
indexs
:
raise
ValueError
(
"Tensor's index cannot be empty."
)
if
isinstance
(
indexs
,
tuple
)
and
isinstance
(
indexs
[
0
],
int
):
raise
IndexError
(
"Tensor's index cannot be empty."
)
if
isinstance
(
indexs
,
tuple
):
for
_
,
ele
in
enumerate
(
indexs
):
if
not
isinstance
(
ele
,
int
):
return
False
return
True
return
False
mindspore/ops/composite/multitype_ops/setitem_impl.py
浏览文件 @
e886a318
...
...
@@ -254,10 +254,10 @@ def _tensor_indices_tensor(data, data_shape, index, indices, value):
data_dtype
=
F
.
dtype
(
data
)
indices_size
=
F
.
size
(
indices
)
indices_size
=
mult_util
.
check_indices
(
indices_size
,
index
)
update
=
F
.
fill
(
data_dtype
,
(
indices_size
,),
1
)
update
=
F
.
fill
(
mstype
.
int32
,
(
indices_size
,),
1
)
condition_1d
=
F
.
scatter_nd
(
indices
,
update
,
(
data_size
,))
condition_1d
=
F
.
cast
(
condition_1d
,
mstype
.
bool_
)
condition
=
F
.
reshape
(
condition_1d
,
data_shape
)
condition
=
F
.
cast
(
condition
,
mstype
.
bool_
)
value_fill
=
None
value_size
=
F
.
size
(
value
)
...
...
@@ -336,10 +336,10 @@ def _tensor_indices_number(data, data_shape, index, indices, value):
data_dtype
=
F
.
dtype
(
data
)
indices_size
=
F
.
size
(
indices
)
indices_size
=
mult_util
.
check_indices
(
indices_size
,
index
)
update
=
F
.
fill
(
data_dtype
,
(
indices_size
,),
1
)
update
=
F
.
fill
(
mstype
.
int32
,
(
indices_size
,),
1
)
condition_1d
=
F
.
scatter_nd
(
indices
,
update
,
(
data_size
,))
condition_1d
=
F
.
cast
(
condition_1d
,
mstype
.
bool_
)
condition
=
F
.
reshape
(
condition_1d
,
data_shape
)
condition
=
F
.
cast
(
condition
,
mstype
.
bool_
)
value_fill
=
F
.
fill
(
data_dtype
,
(
indices_size
,),
value
)
value_1d
=
F
.
scatter_nd
(
indices
,
value_fill
,
(
data_size
,))
u
=
F
.
reshape
(
value_1d
,
data_shape
)
...
...
@@ -360,3 +360,32 @@ def _tensor_setitem_with_int_v2(data, index, value):
data_shape
=
F
.
shape
(
data
)
indices
=
mult_util
.
integer_to_indices
(
index
,
data_shape
)
return
_tensor_indices_tensor
(
data
,
data_shape
,
index
,
indices
,
value
)
@
setitem
.
register
(
"Tensor"
,
"Ellipsis"
,
"Number"
)
def
_tensor_setitem_with_ellipsis_v1
(
data
,
index
,
value
):
"""Syntax: A[...] = number."""
data_shape
=
F
.
shape
(
data
)
data_dtype
=
F
.
dtype
(
data
)
return
F
.
fill
(
data_dtype
,
data_shape
,
value
)
@
setitem
.
register
(
"Tensor"
,
"Ellipsis"
,
"Tensor"
)
def
_tensor_setitem_with_ellipsis_v2
(
data
,
index
,
value
):
"""Syntax: A[...] = Tensor."""
result
=
None
data_shape
=
F
.
shape
(
data
)
data_dtype
=
F
.
dtype
(
data
)
data_size
=
F
.
size
(
data
)
value_shape
=
F
.
shape
(
value
)
value_size
=
F
.
size
(
value
)
check_result
=
mult_util
.
check_ellipsis_shape_size
(
data_shape
,
value_shape
,
data_size
,
value_size
)
if
check_result
:
if
data_size
==
value_size
:
result
=
F
.
reshape
(
value
,
data_shape
)
result
=
F
.
cast
(
result
,
data_dtype
)
elif
value_size
==
1
:
param1
=
F
.
fill
(
data_dtype
,
data_shape
,
1
)
param2
=
F
.
cast
(
value
,
data_dtype
)
result
=
F
.
tensor_mul
(
param1
,
param2
)
return
result
tests/ut/python/ops/test_tensor_slice.py
浏览文件 @
e886a318
...
...
@@ -103,6 +103,7 @@ class TensorAssignWithSliceError1(Cell):
a
[
1
:
3
:
-
1
,::]
=
b
return
a
class
TensorAssignWithSliceError2
(
Cell
):
def
__init__
(
self
):
super
(
TensorAssignWithSliceError2
,
self
).
__init__
()
...
...
@@ -110,24 +111,29 @@ class TensorAssignWithSliceError2(Cell):
def
construct
(
self
,
a
,
b
):
a
[
1
:
3
:
-
1
]
=
b
return
a
class
TensorAssignWithSlice2
(
Cell
):
def
__init__
(
self
):
super
(
TensorAssignWithSlice2
,
self
).
__init__
()
def
construct
(
self
,
a
,
b
):
def
construct
(
self
,
a
,
b
,
ck
):
a
[
1
:
5
]
=
b
a
[
3
:
4
]
=
5
a
[
-
1
:
1
:
-
1
]
=
b
a
[
-
1
:
3
:
-
1
]
=
5
a
[::]
=
b
a
[::]
=
9
return
a
z
=
a
+
ck
return
z
class
TensorAssignWithSlice
(
Cell
):
def
__init__
(
self
):
super
(
TensorAssignWithSlice
,
self
).
__init__
()
self
.
c
=
2
def
construct
(
self
,
a
,
b
):
def
construct
(
self
,
a
,
b
,
ck
):
a
[
1
:
3
,::]
=
b
a
[
2
:
3
:,
3
:]
=
b
a
[::]
=
b
...
...
@@ -136,9 +142,10 @@ class TensorAssignWithSlice(Cell):
a
[::,::]
=
self
.
c
a
[
2
:
3
:,
0
:,
4
:
1
:
-
1
]
=
b
a
[
2
:
3
:,
0
:,
4
:
1
:
-
1
]
=
self
.
c
z
=
a
z
=
a
+
ck
return
z
def
test_tensor_assign
():
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
save_graphs
=
True
)
net
=
TensorAssignWithSlice
()
...
...
@@ -146,95 +153,145 @@ def test_tensor_assign():
net_e1
=
TensorAssignWithSliceError1
()
net_e2
=
TensorAssignWithSliceError2
()
a
=
np
.
arange
(
60
).
reshape
(
3
,
4
,
5
)
b
=
Tensor
([
1
])
Ta
=
Tensor
(
a
)
Ta4d
=
Tensor
(
a
.
reshape
(
1
,
3
,
4
,
5
))
Tb
=
Tensor
([
1
,
3
])
Tc
=
Tensor
([])
t
=
Tensor
([
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
])
net
(
Ta
,
b
)
net2
(
t
,
b
)
ck
=
np
.
arange
(
60
).
reshape
(
3
,
4
,
5
)
b
=
Tensor
([
1
],
dtype
=
mstype
.
float32
)
Ta
=
Tensor
(
a
,
dtype
=
mstype
.
float32
)
Tck
=
Tensor
(
ck
,
dtype
=
mstype
.
float32
)
Ta4d
=
Tensor
(
a
.
reshape
(
1
,
3
,
4
,
5
),
dtype
=
mstype
.
float32
)
Ta4d_ck
=
Tensor
(
ck
.
reshape
(
1
,
3
,
4
,
5
),
dtype
=
mstype
.
float32
)
Tb
=
Tensor
([
1
,
3
],
dtype
=
mstype
.
float32
)
Tc
=
Tensor
([],
dtype
=
mstype
.
float32
)
t
=
Tensor
([
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
],
dtype
=
mstype
.
float32
)
tck
=
Tensor
([
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
],
dtype
=
mstype
.
float32
)
net
(
Ta
,
b
,
Tck
)
net2
(
t
,
b
,
tck
)
# Error for A[Slice] = Number
# 1. A[Slice] = Number, Slice error
with
pytest
.
raises
(
Value
Error
):
with
pytest
.
raises
(
Index
Error
):
net_e2
(
t
,
2
)
# Error for A[Slice] = U, U is a Tensor
# 1. A[Slice] = U, u.size is error
with
pytest
.
raises
(
ValueError
):
net2
(
t
,
Tb
)
net2
(
t
,
Tb
,
tck
)
# 2. A[Slice] = U, U is empty
with
pytest
.
raises
(
ValueError
):
net2
(
t
,
Tc
)
net2
(
t
,
Tc
,
tck
)
# 3. A[Slice] = U, U.size error
with
pytest
.
raises
(
ValueError
):
net2
(
t
,
Tb
)
net2
(
t
,
Tb
,
tck
)
# Error for A[Tuple(Slice...)] = Tensor
# 1. A[Tuple(Slice...)] = U, U is empty
with
pytest
.
raises
(
ValueError
):
net
(
Ta
,
Tc
)
net
(
Ta
,
Tc
,
Tck
)
# 2. A[Tuple(Slice...)] = U, U.size error
with
pytest
.
raises
(
ValueError
):
net
(
Ta
,
Tb
)
net
(
Ta
,
Tb
,
Tck
)
# 3. A[Tuple(Slice...)] = U, Slice error
with
pytest
.
raises
(
Value
Error
):
with
pytest
.
raises
(
Index
Error
):
net_e1
(
Ta
,
b
)
# Error for A[Tuple(Slice...)] = Number
# 1. A[Tuple(Slice...)] = Number, Slice error
with
pytest
.
raises
(
Value
Error
):
with
pytest
.
raises
(
Index
Error
):
net_e1
(
Ta
,
2
)
net
=
TensorAssignWithInteger
()
# Error for A[Number] = scalar/Tensor
# 1. A[Number] = U, U is a Tensor, u.size not match
with
pytest
.
raises
(
ValueError
):
net
(
Ta
,
Tb
)
net
(
Ta
,
Tb
,
Tck
)
with
pytest
.
raises
(
ValueError
):
net
(
Ta
,
Tc
)
net
(
Ta
,
Tc
,
Tck
)
# 2. A[Number] = U, the number index error
with
pytest
.
raises
(
IndexError
):
net
(
Ta4d
,
b
)
net
(
Ta4d
,
b
,
Ta4d_ck
)
# Error for A[(n,m)] = scalar/Tensor
# 1. A[(n,m)] = U, U is a tensor. u.size not match
net
=
TensorAssignWithTupleInteger
()
with
pytest
.
raises
(
ValueError
):
net
(
Ta
,
Tc
)
net
(
Ta
,
Tc
,
Tck
)
with
pytest
.
raises
(
ValueError
):
net
(
Ta
,
Tb
)
net
(
Ta
,
Tb
,
Tck
)
# 2. A[(n,m)] = U, the number index error
with
pytest
.
raises
(
IndexError
):
net
(
Ta4d
,
b
)
net
(
Ta4d
,
b
,
Ta4d_ck
)
#Error for A[...] = U or A[1:, ...] = u
#1. A[...] = scalar/tensor
net
=
TensorAssignWithEllipsis
()
net
(
Ta
,
Ta4d
)
with
pytest
.
raises
(
ValueError
):
net
(
Ta
,
Tc
)
with
pytest
.
raises
(
ValueError
):
net
(
Ta
,
Tb
)
#2. A[::, 1:, ...] = scalar/tensor
net
=
TensorAssignWithTupleEllipsis
()
net
(
Ta
,
b
)
with
pytest
.
raises
(
ValueError
):
net
(
Ta
,
Tc
)
with
pytest
.
raises
(
ValueError
):
net
(
Ta
,
Tb
)
class
TensorAssignWithTupleEllipsis2
(
Cell
):
def
__init__
(
self
):
super
(
TensorAssignWithTupleEllipsis2
,
self
).
__init__
()
def
construct
(
self
,
a
,
b
):
a
[
1
:,
...,
::]
=
b
return
a
class
TensorAssignWithTupleEllipsis
(
Cell
):
def
__init__
(
self
):
super
(
TensorAssignWithTupleEllipsis
,
self
).
__init__
()
def
construct
(
self
,
a
,
b
):
a
[:
2
,
...]
=
1
a
[
1
:,
...]
=
b
return
a
class
TensorAssignWithEllipsis
(
Cell
):
def
__init__
(
self
):
super
(
TensorAssignWithEllipsis
,
self
).
__init__
()
def
construct
(
self
,
a
,
b
):
a
[...]
=
1
a
[...]
=
b
return
a
class
TensorAssignWithInteger
(
Cell
):
def
__init__
(
self
):
super
(
TensorAssignWithInteger
,
self
).
__init__
()
def
construct
(
self
,
a
,
b
):
def
construct
(
self
,
a
,
b
,
ck
):
a
[
1
]
=
1
a
[
0
]
=
b
return
a
z
=
a
+
ck
return
z
class
TensorAssignWithTupleInteger
(
Cell
):
def
__init__
(
self
):
super
(
TensorAssignWithTupleInteger
,
self
).
__init__
()
def
construct
(
self
,
a
,
b
):
def
construct
(
self
,
a
,
b
,
ck
):
a
[(
1
)]
=
1
a
[(
1
)]
=
b
a
[(
1
,
1
)]
=
b
a
[(
1
,
1
)]
=
1
return
a
z
=
a
+
ck
return
z
class
TensorAssignWithBoolTensorIndex
(
Cell
):
def
__init__
(
self
):
super
(
TensorAssignWithBoolTensorIndex
,
self
).
__init__
()
self
.
t
=
Tensor
(
np
.
arange
(
60
).
reshape
([
3
,
4
,
5
]),
dtype
=
mstype
.
float64
)
self
.
t
=
Tensor
(
np
.
arange
(
60
).
reshape
([
3
,
4
,
5
]),
dtype
=
mstype
.
float32
)
self
.
u_scalar
=
5
def
construct
(
self
,
a
,
b
,
c
,
u_tensor
,
_scalar
):
a
[
c
]
=
u_scalar
def
construct
(
self
,
a
,
b
,
c
,
u_tensor
):
a
[
c
]
=
self
.
u_scalar
a
[
b
]
=
u_tensor
z
=
a
+
self
.
t
return
z
...
...
@@ -252,15 +309,16 @@ class TensorAssignWithBoolTensorIndexError(Cell):
class
TensorAssignWithBoolTensorIndex2
(
Cell
):
def
__init__
(
self
):
super
(
TensorAssignWithBoolTensorIndex2
,
self
).
__init__
()
self
.
t
=
Tensor
(
np
.
arange
(
6
).
reshape
([
2
,
3
]),
dtype
=
mstype
.
float64
)
self
.
t
=
Tensor
(
np
.
arange
(
60
).
reshape
([
3
,
4
,
5
]),
dtype
=
mstype
.
float64
)
self
.
t
=
Tensor
(
np
.
arange
(
6
).
reshape
([
2
,
3
]),
dtype
=
mstype
.
float32
)
self
.
t
=
Tensor
(
np
.
arange
(
60
).
reshape
([
3
,
4
,
5
]),
dtype
=
mstype
.
float32
)
self
.
u_scalar
=
5
def
construct
(
self
,
a
,
u_tensor
,
_scalar
):
def
construct
(
self
,
a
,
u_tensor
):
a
[
a
>
8
]
=
u_tensor
a
[
a
>=
6
]
=
u_scalar
a
[
a
<
3
]
=
u_scalar
a
[
a
>=
6
]
=
self
.
u_scalar
a
[
a
<
3
]
=
self
.
u_scalar
a
[
a
<=
5
]
=
u_tensor
a
[
a
==
5
]
=
u_scalar
a
[
a
==
5
]
=
self
.
u_scalar
z
=
a
+
self
.
t
return
z
...
...
@@ -274,36 +332,41 @@ class TensorAssignWithBoolTensorIndex2Error(Cell):
return
a
a
=
np
.
random
.
uniform
(
1
,
10
,[
3
,
4
,
5
])
a
=
np
.
arange
(
60
).
reshape
(
3
,
4
,
5
)
ck
=
np
.
arange
(
60
).
reshape
(
3
,
4
,
5
)
a4
=
np
.
arange
(
60
).
reshape
(
3
,
2
,
2
,
5
)
b
=
a
>
5
c
=
a
<
3
Ta
=
Tensor
(
a
)
Ta
=
Tensor
(
a
,
dtype
=
mstype
.
float32
)
Tck
=
Tensor
(
ck
,
dtype
=
mstype
.
float32
)
Ta4
=
Tensor
(
a4
,
dtype
=
mstype
.
float32
)
Tb
=
Tensor
(
b
)
Tc
=
Tensor
(
c
)
Td
=
Tensor
([
True
,
True
])
u_tensor
=
Tensor
([
1
])
u_tensor_error
=
Tensor
([
1
,
2
])
t_1d
=
Tensor
([
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
])
u_tensor
=
Tensor
([
1
],
dtype
=
mstype
.
float32
)
u_tensor_error
=
Tensor
([
1
,
2
],
dtype
=
mstype
.
float32
)
t_1d
=
Tensor
([
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
],
dtype
=
mstype
.
float32
)
tck_1d
=
Tensor
([
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
],
dtype
=
mstype
.
float32
)
u_scalar
=
5
def
test_tensor_assign_bool_index
():
net1
=
TensorAssignWithBoolTensorIndex
()
net2
=
TensorAssignWithBoolTensorIndex2
()
net1
(
Ta
,
Tb
,
Tc
,
u_tensor
,
u_scalar
)
net1
(
Ta
,
Tb
,
Tc
,
u_tensor
,
u_scalar
)
with
pytest
.
raises
(
ValueError
):
net1
(
Ta
,
Td
,
Tc
,
u_tensor
,
u_scalar
)
with
pytest
.
raises
(
ValueError
):
net1
(
Ta
,
u_tensor
,
Tc
,
u_tensor
,
u_scalar
)
net1
(
Ta
,
Tb
,
Tc
,
u_tensor
)
net1
(
Ta
,
Tb
,
Tc
,
u_tensor
)
with
pytest
.
raises
(
ValueError
):
net1
(
Ta
,
Tb
,
Td
,
u_tensor
,
u_scalar
)
net1
(
Ta
,
Td
,
Tc
,
u_tensor
)
with
pytest
.
raises
(
TypeError
):
net1
(
Ta
,
u_tensor
,
Tc
,
u_tensor
)
with
pytest
.
raises
(
ValueError
):
net1
(
Ta
,
Tb
,
Ta
,
u_tensor
,
u_scalar
)
net1
(
Ta
,
Tb
,
Td
,
u_tensor
)
with
pytest
.
raises
(
TypeError
):
net1
(
Ta
,
Tb
,
Ta
,
u_tensor
)
with
pytest
.
raises
(
ValueError
):
net1
(
Ta
,
Tb
,
Tc
,
u_tensor_error
,
u_scalar
)
net1
(
Ta
,
Tb
,
Tc
,
u_tensor_error
)
# net1(Ta, u_tensor, Tc, u_tensor_error, u_scalar)
with
pytest
.
raises
(
ValueError
):
net2
(
Ta
,
u_tensor_error
,
u_scalar
)
net2
(
Ta
,
u_tensor_error
)
net3
=
TensorAssignWithBoolTensorIndexError
()
with
pytest
.
raises
(
AttributeError
):
net3
(
Ta
,
Tb
,
Tc
,
u_tensor
)
...
...
@@ -316,29 +379,41 @@ def test_tensor_assign_bool_index():
net4
(
Ta
,
u_scalar
)
test_cases
=
[
(
'TensorAssignWithTupleEllipsis2'
,
{
'block'
:
TensorAssignWithTupleEllipsis2
(),
'desc_inputs'
:
[
Ta4
,
u_tensor
],
}),
(
'TensorAssignWithTupleEllipsis'
,
{
'block'
:
TensorAssignWithTupleEllipsis
(),
'desc_inputs'
:
[
Ta
,
u_tensor
],
}),
(
'TensorAssignWithEllipsis'
,
{
'block'
:
TensorAssignWithEllipsis
(),
'desc_inputs'
:
[
Ta
,
u_tensor
],
}),
(
'TensorAssignWithTupleInteger'
,
{
'block'
:
TensorAssignWithTupleInteger
(),
'desc_inputs'
:
[
Ta
,
u_tensor
],
'desc_inputs'
:
[
Ta
,
u_tensor
,
Tck
],
}),
(
'TensorAssignWithInteger'
,
{
'block'
:
TensorAssignWithInteger
(),
'desc_inputs'
:
[
Ta
,
u_tensor
],
'desc_inputs'
:
[
Ta
,
u_tensor
,
Tck
],
}),
(
'TensorAssignWithSlice'
,
{
'block'
:
TensorAssignWithSlice
(),
'desc_inputs'
:
[
Ta
,
u_tensor
],
'desc_inputs'
:
[
Ta
,
u_tensor
,
Tck
],
}),
(
'TensorAssignWithSlice2'
,
{
'block'
:
TensorAssignWithSlice2
(),
'desc_inputs'
:
[
t_1d
,
u_tensor
],
'desc_inputs'
:
[
t_1d
,
u_tensor
,
tck_1d
],
}),
(
'TensorAssignWithBoolTensorIndex'
,
{
'block'
:
TensorAssignWithBoolTensorIndex
(),
'desc_inputs'
:
[
Ta
,
Tb
,
Tc
,
u_tensor
,
u_scalar
],
'desc_inputs'
:
[
Ta
,
Tb
,
Tc
,
u_tensor
],
}),
(
'TensorAssignWithBoolTensorIndex2'
,
{
'block'
:
TensorAssignWithBoolTensorIndex2
(),
'desc_inputs'
:
[
Ta
,
u_tensor
,
u_scalar
],
'desc_inputs'
:
[
Ta
,
u_tensor
],
}),
(
'SlicePositive'
,
{
'block'
:
NetWorkSlicePositive
(),
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录