Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
e490618d
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
e490618d
编写于
5月 12, 2020
作者:
B
buxue
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support tensor get value by tensor index
support tensor set value by tensor index
上级
ca74e624
变更
21
隐藏空白更改
内联
并排
Showing
21 changed file
with
1251 addition
and
179 deletion
+1251
-179
mindspore/ccsrc/operator/composite/composite.cc
mindspore/ccsrc/operator/composite/composite.cc
+6
-6
mindspore/ccsrc/operator/composite/composite.h
mindspore/ccsrc/operator/composite/composite.h
+0
-2
mindspore/ccsrc/transform/convert.cc
mindspore/ccsrc/transform/convert.cc
+2
-0
mindspore/ccsrc/transform/op_declare.cc
mindspore/ccsrc/transform/op_declare.cc
+5
-0
mindspore/ccsrc/transform/op_declare.h
mindspore/ccsrc/transform/op_declare.h
+2
-0
mindspore/ops/_op_impl/tbe/__init__.py
mindspore/ops/_op_impl/tbe/__init__.py
+6
-5
mindspore/ops/_op_impl/tbe/scatter_update.py
mindspore/ops/_op_impl/tbe/scatter_update.py
+42
-0
mindspore/ops/_utils/__init__.py
mindspore/ops/_utils/__init__.py
+2
-2
mindspore/ops/_utils/utils.py
mindspore/ops/_utils/utils.py
+7
-7
mindspore/ops/composite/multitype_ops/_utils.py
mindspore/ops/composite/multitype_ops/_utils.py
+487
-0
mindspore/ops/composite/multitype_ops/getitem_impl.py
mindspore/ops/composite/multitype_ops/getitem_impl.py
+39
-5
mindspore/ops/composite/multitype_ops/setitem_impl.py
mindspore/ops/composite/multitype_ops/setitem_impl.py
+216
-114
mindspore/ops/functional.py
mindspore/ops/functional.py
+7
-0
mindspore/ops/operations/__init__.py
mindspore/ops/operations/__init__.py
+2
-1
mindspore/ops/operations/_grad_ops.py
mindspore/ops/operations/_grad_ops.py
+2
-2
mindspore/ops/operations/array_ops.py
mindspore/ops/operations/array_ops.py
+64
-25
mindspore/ops/operations/math_ops.py
mindspore/ops/operations/math_ops.py
+2
-2
tests/mindspore_test_framework/components/executor/exec_forward.py
...dspore_test_framework/components/executor/exec_forward.py
+7
-1
tests/ut/python/ops/test_ops.py
tests/ut/python/ops/test_ops.py
+1
-1
tests/ut/python/ops/test_tensor_slice.py
tests/ut/python/ops/test_tensor_slice.py
+350
-4
tests/ut/python/optimizer/test_debug_location.py
tests/ut/python/optimizer/test_debug_location.py
+2
-2
未找到文件。
mindspore/ccsrc/operator/composite/composite.cc
浏览文件 @
e490618d
...
...
@@ -1172,6 +1172,12 @@ int GenerateStridedSliceParametersFromNumber(const AbstractScalarPtr &scalar, co
return
1
;
}
FuncGraphPtr
ExpandADim
(
const
FuncGraphPtr
&
ret_graph
,
const
AnfNodePtr
&
tensor_node
)
{
auto
PrimExpandDims
=
GetPythonOps
(
"expand_dims"
,
"mindspore.ops.functional"
);
ret_graph
->
set_output
(
NewCNode
({
NewValueNode
(
PrimExpandDims
),
tensor_node
,
NewValueNode
(
0
)},
ret_graph
));
return
ret_graph
;
}
FuncGraphPtr
TensorSlice
::
GenerateFuncGraph
(
const
AbstractBasePtrList
&
args_spec_list
)
{
// slice a tensor
// args: tensor, slice or slice tuple
...
...
@@ -1229,12 +1235,6 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec
return
ret_graph
;
}
FuncGraphPtr
TensorSlice
::
ExpandADim
(
const
FuncGraphPtr
&
ret_graph
,
const
AnfNodePtr
&
tensor_node
)
const
{
auto
PrimExpandDims
=
GetPythonOps
(
"expand_dims"
,
"mindspore.ops.functional"
);
ret_graph
->
set_output
(
NewCNode
({
NewValueNode
(
PrimExpandDims
),
tensor_node
,
NewValueNode
(
0
)},
ret_graph
));
return
ret_graph
;
}
FuncGraphPtr
TupleGetItemTensor
::
GenerateFuncGraph
(
const
AbstractBasePtrList
&
args_spec_list
)
{
// select indexed item
// args: tuple of items, index
...
...
mindspore/ccsrc/operator/composite/composite.h
浏览文件 @
e490618d
...
...
@@ -206,8 +206,6 @@ class TensorSlice : public MetaFuncGraph {
MS_DECLARE_PARENT
(
TensorSlice
,
MetaFuncGraph
)
FuncGraphPtr
GenerateFuncGraph
(
const
AbstractBasePtrList
&
args_spec_list
)
override
;
friend
bool
operator
==
(
const
TensorSlice
&
lhs
,
const
TensorSlice
&
rhs
)
{
return
lhs
.
name_
==
rhs
.
name_
;
}
FuncGraphPtr
ExpandADim
(
const
FuncGraphPtr
&
ret_graph
,
const
AnfNodePtr
&
tensor_node
)
const
;
};
using
TensorSlicePtr
=
std
::
shared_ptr
<
TensorSlice
>
;
...
...
mindspore/ccsrc/transform/convert.cc
浏览文件 @
e490618d
...
...
@@ -101,6 +101,7 @@ const char kNameReLU6[] = "ReLU6";
const
char
kNameReLU6Grad
[]
=
"ReLU6Grad"
;
const
char
kNameElu
[]
=
"Elu"
;
const
char
kNameEluGrad
[]
=
"EluGrad"
;
const
char
kNameScatterUpdate
[]
=
"ScatterUpdate"
;
const
char
kNameScatterNdUpdate
[]
=
"ScatterNdUpdate"
;
const
char
kNameScatterMax
[]
=
"ScatterMax"
;
const
char
kNameNMSWithMask
[]
=
"NMSWithMask"
;
...
...
@@ -256,6 +257,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
{
string
(
kNameResizeBilinear
),
ADPT_DESC
(
ResizeBilinearV2D
)},
{
string
(
kNameZerosLike
),
ADPT_DESC
(
ZerosLike
)},
{
string
(
kNameOnesLike
),
ADPT_DESC
(
OnesLike
)},
{
string
(
kNameScatterUpdate
),
ADPT_DESC
(
ScatterUpdate
)},
{
string
(
kNameScatterNdUpdate
),
ADPT_DESC
(
ScatterNdUpdate
)},
{
string
(
kNameScatterMax
),
ADPT_DESC
(
ScatterMax
)},
{
string
(
kNameNMSWithMask
),
ADPT_DESC
(
NMSWithMask
)},
...
...
mindspore/ccsrc/transform/op_declare.cc
浏览文件 @
e490618d
...
...
@@ -515,6 +515,11 @@ INPUT_MAP(Unpack) = {{1, INPUT_DESC(x)}};
ATTR_MAP
(
Unpack
)
=
{{
"axis"
,
ATTR_DESC
(
axis
,
AnyTraits
<
int
>
())},
{
"num"
,
ATTR_DESC
(
num
,
AnyTraits
<
int
>
())}};
DYN_OUTPUT_MAP
(
Unpack
)
=
{{
0
,
DYN_OUTPUT_DESC
(
y
)}};
// ScatterUpdate
INPUT_MAP
(
ScatterUpdate
)
=
{{
1
,
INPUT_DESC
(
var
)},
{
2
,
INPUT_DESC
(
indices
)},
{
3
,
INPUT_DESC
(
updates
)}};
ATTR_MAP
(
ScatterUpdate
)
=
{{
"use_locking"
,
ATTR_DESC
(
use_locking
,
AnyTraits
<
bool
>
())}};
OUTPUT_MAP
(
ScatterUpdate
)
=
{{
0
,
OUTPUT_DESC
(
var
)}};
// ScatterNdUpdate
INPUT_MAP
(
ScatterNdUpdate
)
=
{{
1
,
INPUT_DESC
(
var
)},
{
2
,
INPUT_DESC
(
indices
)},
{
3
,
INPUT_DESC
(
updates
)}};
ATTR_MAP
(
ScatterNdUpdate
)
=
{{
"use_locking"
,
ATTR_DESC
(
use_locking
,
AnyTraits
<
bool
>
())}};
...
...
mindspore/ccsrc/transform/op_declare.h
浏览文件 @
e490618d
...
...
@@ -132,6 +132,8 @@ DECLARE_OP_ADAPTER(ZerosLike)
DECLARE_OP_USE_OUTPUT
(
ZerosLike
)
DECLARE_OP_ADAPTER
(
OnesLike
)
DECLARE_OP_USE_OUTPUT
(
OnesLike
)
DECLARE_OP_ADAPTER
(
ScatterUpdate
)
DECLARE_OP_USE_OUTPUT
(
ScatterUpdate
)
DECLARE_OP_ADAPTER
(
ScatterNdUpdate
)
DECLARE_OP_USE_OUTPUT
(
ScatterNdUpdate
)
DECLARE_OP_ADAPTER
(
ScatterMax
)
...
...
mindspore/ops/_op_impl/tbe/__init__.py
浏览文件 @
e490618d
...
...
@@ -178,13 +178,14 @@ from .bounding_box_encode import _bounding_box_encode_tbe
from
.check_valid
import
_check_valid_tbe
from
.iou
import
_iou_tbe
from
.arg_max
import
_arg_max_tbe
from
.nms_with_mask
import
nms_with_mask_op_info
from
.random_choice_with_mask
import
random_choice_with_mask_op_info
from
.sgd
import
sgd_op_info
from
.lars_update
import
lars_update_op_info
from
.nms_with_mask
import
_nms_with_mask_tbe
from
.random_choice_with_mask
import
_random_choice_with_mask_tbe
from
.sgd
import
_sgd_tbe
from
.lars_update
import
_lars_update_tbe
from
.bn_training_update_v2
import
_bn_training_update_v2_tbe
from
.square_sum_all
import
square_sum_all_op_info
from
.square_sum_all
import
_square_sum_all_tbe
from
.pack
import
_pack_tbe
from
.unpack
import
_unpack_tbe
from
.scatter_update
import
_scatter_update_tbe
from
.prelu
import
_prelu_tbe
from
.prelu_grad
import
_prelu_grad_tbe
mindspore/ops/_op_impl/tbe/scatter_update.py
0 → 100644
浏览文件 @
e490618d
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ScatterUpdate op"""
from
mindspore.ops.op_info_register
import
op_info_register
,
TBERegOp
,
DataType
scatter_update_op_info
=
TBERegOp
(
"ScatterUpdate"
)
\
.
fusion_type
(
"ELEMWISE"
)
\
.
async_flag
(
False
)
\
.
binfile_name
(
"scatter_update.so"
)
\
.
compute_cost
(
10
)
\
.
kernel_name
(
"scatter_update"
)
\
.
partial_flag
(
True
)
\
.
attr
(
"use_locking"
,
"optional"
,
"bool"
,
"all"
)
\
.
input
(
0
,
"var"
,
False
,
"required"
,
"all"
)
\
.
input
(
1
,
"indices"
,
False
,
"required"
,
"all"
)
\
.
input
(
1
,
"updates"
,
False
,
"required"
,
"all"
)
\
.
output
(
0
,
"var"
,
False
,
"required"
,
"all"
)
\
.
dtype_format
(
DataType
.
F16_Default
,
DataType
.
I32_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
)
\
.
dtype_format
(
DataType
.
F32_Default
,
DataType
.
I32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
)
\
.
dtype_format
(
DataType
.
I8_Default
,
DataType
.
I32_Default
,
DataType
.
I8_Default
,
DataType
.
I8_Default
)
\
.
dtype_format
(
DataType
.
U8_Default
,
DataType
.
I32_Default
,
DataType
.
U8_Default
,
DataType
.
U8_Default
,)
\
.
dtype_format
(
DataType
.
BOOL_Default
,
DataType
.
I32_Default
,
DataType
.
BOOL_Default
,
DataType
.
BOOL_Default
)
\
.
get_op_info
()
@
op_info_register
(
scatter_update_op_info
)
def
_scatter_update_tbe
():
"""ScatterUpdate TBE register"""
return
mindspore/ops/_utils/__init__.py
浏览文件 @
e490618d
...
...
@@ -14,6 +14,6 @@
# ============================================================================
"""ops utils."""
from
.utils
import
_get_broadcast_shape
,
_
get_concat_offset
from
.utils
import
get_broadcast_shape
,
get_concat_offset
__all__
=
[
'
_get_broadcast_shape'
,
'_
get_concat_offset'
]
__all__
=
[
'
get_broadcast_shape'
,
'
get_concat_offset'
]
mindspore/ops/_utils/utils.py
浏览文件 @
e490618d
...
...
@@ -19,7 +19,8 @@ from ..._checkparam import Validator as validator
from
..._checkparam
import
Rel
from
...common
import
dtype
as
mstype
def
_get_broadcast_shape
(
x_shape
,
y_shape
,
prim_name
):
def
get_broadcast_shape
(
x_shape
,
y_shape
,
prim_name
):
"""
Doing broadcast between tensor x and tensor y.
...
...
@@ -37,7 +38,7 @@ def _get_broadcast_shape(x_shape, y_shape, prim_name):
Examples:
>>> x_shape = [1, 2, 3]
>>> y_shape = [1, 2]
>>> broadcast_shape =
_
get_broadcast_shape(x_shape, y_shape)
>>> broadcast_shape = get_broadcast_shape(x_shape, y_shape)
"""
if
x_shape
==
y_shape
:
return
x_shape
...
...
@@ -54,15 +55,14 @@ def _get_broadcast_shape(x_shape, y_shape, prim_name):
elif
x_shape
[
i
]
==
y_shape
[
i
]:
broadcast_shape_back
.
append
(
x_shape
[
i
])
else
:
raise
ValueError
(
"For '{}' the x_shape {} and y_shape {} can not broadcast."
.
format
(
prim_name
,
x_shape
,
y_shape
))
raise
ValueError
(
f
"For '
{
prim_name
}
', the x_shape
{
x_shape
}
and y_shape
{
y_shape
}
can not broadcast."
)
broadcast_shape_front
=
y_shape
[
0
:
y_len
-
length
]
if
length
==
x_len
else
x_shape
[
0
:
x_len
-
length
]
broadcast_shape
=
broadcast_shape_front
+
broadcast_shape_back
broadcast_shape
=
list
(
broadcast_shape_front
)
+
broadcast_shape_back
return
broadcast_shape
def
_
get_concat_offset
(
x_shp
,
x_type
,
axis
,
prim_name
):
def
get_concat_offset
(
x_shp
,
x_type
,
axis
,
prim_name
):
"""for concat and concatoffset check args and compute offset"""
validator
.
check_value_type
(
"shape"
,
x_shp
,
[
tuple
],
prim_name
)
validator
.
check_integer
(
"input_x rank"
,
len
(
x_shp
),
0
,
Rel
.
GT
,
prim_name
)
...
...
@@ -73,7 +73,7 @@ def _get_concat_offset(x_shp, x_type, axis, prim_name):
if
axis
<
0
:
axis
=
axis
+
rank_base
all_shp
=
x_shp
[
0
][
axis
]
offset
=
[
0
,
]
offset
=
[
0
]
for
i
in
range
(
1
,
len
(
x_shp
)):
v
=
x_shp
[
i
]
validator
.
check
(
'len of x_shp[%d]'
%
i
,
len
(
v
),
'len of x_shp[0]'
,
len
(
x_shp
[
0
]),
Rel
.
EQ
,
prim_name
)
...
...
mindspore/ops/composite/multitype_ops/_
multitype_ops_util
.py
→
mindspore/ops/composite/multitype_ops/_
utils
.py
浏览文件 @
e490618d
...
...
@@ -14,13 +14,36 @@
# ============================================================================
"""constexpr util"""
from
functools
import
reduce
import
numpy
as
np
from
...primitive
import
constexpr
from
....common.tensor
import
Tensor
from
....common
import
dtype
as
mstype
from
...._extends.utils
import
Slice
,
Ellipsis_
from
....ops
import
_utils
as
op_utils
from
...composite
import
base
from
....
import
log
as
logger
from
...
import
functional
as
F
from
...
import
operations
as
P
hyper_map
=
base
.
HyperMap
()
pack
=
P
.
Pack
(
axis
=-
1
)
ALL_TENSOR
=
0
NO_TENSOR
=
1
CONTAIN_TENSOR
=
2
ALL_SCALAR
=
3
INT_
=
0
BOOL_
=
1
UNSUPPORTED_DTYPE
=
2
TENSOR_SETITEM
=
"tensor setitem"
TENSOR_GETITEM
=
"tensor getitem"
SET_ITEM_BY_ONE_TENSOR
=
0
SET_ITEM_BY_TUPLE_OF_TENSOR
=
1
@
constexpr
def
check_equal
(
param1
,
param2
,
msg
=
"{},{}"
):
...
...
@@ -55,7 +78,7 @@ def check_tensor_setitem_index(index, element_type=None):
return
True
raise
IndexError
(
"Index of type '{}' is not supported yet."
.
format
(
type
(
index
[
0
])))
# eg. Tensor[Tensor[dtype=bool]] = u
if
i
ndex
==
mstype
.
tensor
:
if
i
sinstance
(
index
,
mstype
.
tensor_type
)
:
if
element_type
is
None
or
element_type
!=
mstype
.
bool_
:
raise
TypeError
(
"The index of tensor should be a bool type tensor. "
...
...
@@ -172,6 +195,7 @@ def slice2indices(input_slices, shape):
ravel
=
Tensor
(
ravel
.
reshape
(
-
1
,
1
),
dtype
=
mstype
.
int32
)
return
ravel
@
constexpr
def
check_indices
(
indices_size
,
index
):
"""Checks indices whether is empty."""
...
...
@@ -192,6 +216,7 @@ def check_indices_value_size(indices_size, value_size):
" value size:{}, indics size:{}"
.
format
(
value_size
,
indices_size
))
return
value_size
@
constexpr
def
integer_to_indices
(
index
,
shape
):
"""Converts int or tuple[int] to indices."""
...
...
@@ -201,6 +226,7 @@ def integer_to_indices(index, shape):
value
=
value
.
reshape
(
-
1
,
1
)
return
Tensor
(
value
,
dtype
=
mstype
.
int32
)
@
constexpr
def
tuple_element_is_slice
(
indexs
):
"""Judges tuple element type."""
...
...
@@ -213,6 +239,7 @@ def tuple_element_is_slice(indexs):
return
True
return
False
@
constexpr
def
tuple_element_is_int
(
indexs
):
"""Judges tuple element type."""
...
...
@@ -224,3 +251,237 @@ def tuple_element_is_int(indexs):
return
False
return
True
return
False
@
constexpr
def
tuple_elements_type
(
types
):
"""Judges the type of all elements of the tuple."""
tensors_number
=
0
for
ele
in
types
:
if
isinstance
(
ele
,
mstype
.
tensor_type
):
tensors_number
+=
1
if
tensors_number
==
len
(
types
):
return
ALL_TENSOR
if
tensors_number
==
0
:
return
NO_TENSOR
return
CONTAIN_TENSOR
@
constexpr
def
check_value_elements
(
data_dtype
,
types
):
"""Judges the type of all elements of the tuple."""
tensors_number
=
0
scalars_number
=
0
for
i
,
ele
in
enumerate
(
types
):
if
isinstance
(
ele
,
mstype
.
tensor_type
):
ele_dtype
=
ele
.
element_type
()
if
data_dtype
==
ele_dtype
:
tensors_number
+=
1
else
:
raise
TypeError
(
f
"For '
{
TENSOR_SETITEM
}
', the data type of
{
i
}
th tensor '
{
ele_dtype
}
' "
f
"in value tuple is not consistent with origin tensor data type '
{
data_dtype
}
'."
)
elif
mstype
.
issubclass_
(
ele
,
data_dtype
):
scalars_number
+=
1
else
:
raise
TypeError
(
f
"For '
{
TENSOR_SETITEM
}
', the
{
i
}
th element type '
{
ele
}
' in "
f
"value tuple is not consistent with origin tensor data type '
{
data_dtype
}
'."
)
if
tensors_number
==
len
(
types
):
return
ALL_TENSOR
if
scalars_number
==
len
(
types
):
return
ALL_SCALAR
raise
TypeError
(
f
"For '
{
TENSOR_SETITEM
}
', the value does not support scalar and tensor mixing, but got
{
types
}
."
)
@
constexpr
def
get_index_tensor_dtype
(
dtype
):
"""Check a tuple of tensor data type."""
if
dtype
==
mstype
.
int32
:
return
INT_
if
dtype
==
mstype
.
bool_
:
return
BOOL_
raise
TypeError
(
f
"For '
{
TENSOR_SETITEM
}
', the index tensor data type '
{
dtype
}
' is not supported."
)
@
constexpr
def
check_index_tensors_dtype
(
dtypes
,
op_name
):
"""Check a tuple of tensor data type."""
if
op_name
==
TENSOR_GETITEM
:
valid_dtypes
=
(
mstype
.
int32
,
mstype
.
int64
)
elif
op_name
==
TENSOR_SETITEM
:
valid_dtypes
=
(
mstype
.
int32
,)
else
:
raise
ValueError
(
"Unsupported operation."
)
for
ele
in
dtypes
:
if
ele
in
valid_dtypes
and
ele
==
dtypes
[
0
]:
continue
raise
TypeError
(
f
"For '
{
op_name
}
', the index tensors data type must be same, "
f
"and should be one of the following:
{
valid_dtypes
}
, but got
{
dtypes
}
."
)
return
True
@
constexpr
def
check_tensor_dtype_valid
(
dtype
,
valid_dtypes
):
"""Check a tensor data type."""
if
dtype
in
valid_dtypes
:
return
True
raise
TypeError
(
f
"The index tensor data type must be one of "
f
"the following:
{
valid_dtypes
}
, but got
{
dtype
}
."
)
@
constexpr
def
check_tensors_dtype_same
(
x_dtype
,
y_dtype
,
op_name
):
"""Check tensors data type same."""
if
x_dtype
==
y_dtype
:
return
True
raise
TypeError
(
f
"For '
{
op_name
}
', the value data type '
{
y_dtype
}
' "
f
"is not consistent with origin tensor data type
{
x_dtype
}
."
)
@
constexpr
def
broadcast_shapes
(
shapes
,
op_name
):
"""Broadcasts a tuple of tensor."""
broadcast_shape
=
shapes
[
0
]
for
i
,
shape
in
enumerate
(
shapes
):
logger
.
debug
(
f
"Broadcasts the
{
i
}
th tensor, the shape is
{
shape
}
."
)
broadcast_shape
=
op_utils
.
get_broadcast_shape
(
broadcast_shape
,
shape
,
op_name
)
return
tuple
(
broadcast_shape
)
@
constexpr
def
check_two_shapes_need_broadcast
(
shape_x
,
shape_y
):
"""Check two shapes need broadcast."""
error
=
ValueError
(
f
"For 'tensor setitem with tensor', the value tensor shape "
f
"
{
shape_y
}
could not broadcast the required updates shape
{
shape_x
}
."
)
if
len
(
shape_y
)
>
len
(
shape_x
):
raise
error
for
i
in
range
(
-
len
(
shape_y
),
0
):
if
shape_y
[
i
]
>
shape_x
[
i
]:
raise
error
if
shape_y
[
i
]
<
shape_x
[
i
]
and
shape_y
[
i
]
!=
1
:
raise
error
if
shape_y
==
shape_x
:
return
False
return
True
@
constexpr
def
compute_multiples
(
origin_shape
,
broadcast_shape
):
"""Compute multiples between broadcast_shape with origin_shape."""
len_gap
=
len
(
broadcast_shape
)
-
len
(
origin_shape
)
return
broadcast_shape
[
0
:
len_gap
]
+
tuple
(
map
(
lambda
x
,
y
:
x
//
y
,
broadcast_shape
[
len_gap
:],
origin_shape
))
def
tile
(
broadcast_shape
,
x
):
multiples
=
compute_multiples
(
F
.
shape
(
x
),
broadcast_shape
)
return
F
.
tile
(
x
,
multiples
)
@
constexpr
def
check_shapes_same
(
value_shapes
,
op_name
):
"""Check if the shapes in the tuple are consistent."""
for
i
,
shape
in
enumerate
(
value_shapes
):
if
shape
!=
value_shapes
[
0
]:
raise
ValueError
(
f
"For '
{
op_name
}
', the
{
i
}
th tensor shape in value tuple "
f
"is not same as the first tensor shape."
)
return
True
@
constexpr
def
convert_scalar_to_tensor
(
data_shape
,
data_dtype
,
indices_shape
,
value
,
op_type
):
"""Convert a scalar to a tensor."""
if
op_type
==
SET_ITEM_BY_ONE_TENSOR
:
updates_shape
=
indices_shape
+
data_shape
[
1
:]
else
:
updates_shape
=
indices_shape
[:
-
1
]
+
data_shape
[
indices_shape
[
-
1
]:]
if
isinstance
(
value
,
mstype
.
dtype_to_pytype
(
data_dtype
)):
return
Tensor
(
np
.
full
(
updates_shape
,
value
),
dtype
=
data_dtype
)
raise
TypeError
(
f
"For '
{
TENSOR_SETITEM
}
', the value type '
{
value
.
__class__
.
__name__
}
'"
f
" is not consistent with tensor data type
{
data_dtype
}
."
)
@
constexpr
def
convert_tuple_of_scalar_to_tensor
(
data_shape
,
data_dtype
,
index_shape
,
value
,
op_type
):
"""Convert a tuple of scalar to a tensor."""
updates_shape
=
generate_updates_shape
(
data_shape
,
index_shape
,
op_type
)
if
len
(
value
)
!=
updates_shape
[
-
1
]:
raise
ValueError
(
f
"For '
{
TENSOR_SETITEM
}
', the number of elements :
{
len
(
value
)
}
in the updates tuple "
f
"does not meet the requirements:
{
updates_shape
[
-
1
]
}
."
)
array
=
np
.
array
(
value
,
dtype
=
mstype
.
dtype_to_nptype
(
data_dtype
))
reps
=
compute_multiples
(
updates_shape
[
-
1
:],
updates_shape
)
return
Tensor
(
np
.
tile
(
array
,
reps
))
@
constexpr
def
generate_updates_shape
(
data_shape
,
index_shape
,
op_type
):
"""Generate updates shape for 'tensor setitem'."""
if
op_type
==
SET_ITEM_BY_ONE_TENSOR
:
updates_shape
=
index_shape
+
data_shape
[
1
:]
else
:
updates_shape
=
index_shape
[:
-
1
]
+
data_shape
[
index_shape
[
-
1
]:]
return
updates_shape
@
constexpr
def
check_number_of_index_tensor
(
data_shape
,
tuple_len
,
op_name
):
"""Check if the number of index tensor exceeds the dimension of the operated tensor."""
if
tuple_len
<=
len
(
data_shape
):
return
True
raise
IndexError
(
f
"For '
{
op_name
}
', the number
{
tuple_len
}
of index tensor "
f
"is greater than the dimension
{
len
(
data_shape
)
}
of the operated tensor."
)
def
generate_indeices_from_tuple_of_tensor
(
data
,
tuple_index
,
op_name
):
"""Generate an indices tensor from a tuple of tensor."""
indices
=
None
check_index_tensor_number
=
check_number_of_index_tensor
(
F
.
shape
(
data
),
len
(
tuple_index
),
op_name
)
if
check_index_tensor_number
:
dtype_tuple
=
hyper_map
(
F
.
dtype
,
tuple_index
)
check_dtypes
=
check_index_tensors_dtype
(
dtype_tuple
,
op_name
)
if
check_dtypes
:
shape_tuple
=
hyper_map
(
F
.
shape
,
tuple_index
)
broadcast_shape
=
broadcast_shapes
(
shape_tuple
,
op_name
)
broadcast_tensors
=
hyper_map
(
F
.
partial
(
tile
,
broadcast_shape
),
tuple_index
)
indices
=
pack
(
broadcast_tensors
)
return
indices
def
generate_updates_from_scalar
(
data
,
indices
,
value
,
op_type
):
"""Generate an updates tensor from a scalar."""
data_shape
=
F
.
shape
(
data
)
indices_shape
=
F
.
shape
(
indices
)
data_dtype
=
F
.
dtype
(
data
)
return
convert_scalar_to_tensor
(
data_shape
,
data_dtype
,
indices_shape
,
value
,
op_type
)
def
generate_updates_from_tuple
(
data
,
index
,
value
,
op_type
):
"""Generate an updates tensor from a tuple."""
value_types
=
hyper_map
(
F
.
typeof
,
value
)
data_dtype
=
F
.
dtype
(
data
)
value_elements_type
=
check_value_elements
(
data_dtype
,
value_types
)
if
value_elements_type
==
ALL_TENSOR
:
value_shapes
=
hyper_map
(
F
.
shape
,
value
)
shapes_same
=
check_shapes_same
(
value_shapes
,
TENSOR_SETITEM
)
if
shapes_same
:
value
=
F
.
pack
(
value
)
return
generate_updates_from_tensor
(
data
,
index
,
value
,
op_type
)
data_shape
=
F
.
shape
(
data
)
index_shape
=
F
.
shape
(
index
)
return
convert_tuple_of_scalar_to_tensor
(
data_shape
,
data_dtype
,
index_shape
,
value
,
op_type
)
def
generate_updates_from_tensor
(
data
,
index
,
value
,
op_type
):
"""Generate an updates tensor from a tensor."""
data_shape
=
F
.
shape
(
data
)
index_shape
=
F
.
shape
(
index
)
value_shape
=
F
.
shape
(
value
)
data_dtype
=
F
.
dtype
(
data
)
value_dtype
=
F
.
dtype
(
value
)
updates_shape
=
value_shape
check_dtype_same
=
check_tensors_dtype_same
(
data_dtype
,
value_dtype
,
TENSOR_SETITEM
)
if
check_dtype_same
:
updates_shape
=
generate_updates_shape
(
data_shape
,
index_shape
,
op_type
)
need_broadcast
=
check_two_shapes_need_broadcast
(
updates_shape
,
value_shape
)
if
need_broadcast
:
return
tile
(
updates_shape
,
value
)
return
value
mindspore/ops/composite/multitype_ops/getitem_impl.py
浏览文件 @
e490618d
...
...
@@ -15,9 +15,10 @@
"""Implementation for getitem."""
from
...composite
import
base
from
.
import
_utils
as
multi_utils
from
..
import
base
from
...
import
functional
as
F
from
....common
import
dtype
as
mstype
getitem
=
base
.
MultitypeFuncGraph
(
'getitem'
)
"""
...
...
@@ -214,19 +215,45 @@ def _tensor_getitem_by_slice(data, slice_index):
return
_tensor_slice
(
data
,
slice_index
)
@
getitem
.
register
(
"Tensor"
,
"Tensor"
)
def
_tensor_getitem_by_tensor
(
data
,
tensor_index
):
"""
Getting item of tensor by slice.
Inputs:
data (Tensor): A tensor.
tensor_index (Tensor): An index expressed by tensor.
Outputs:
Tensor, element type is same as the element type of data.
"""
check_dtypes
=
multi_utils
.
check_tensor_dtype_valid
(
F
.
dtype
(
tensor_index
),
(
mstype
.
int32
,
mstype
.
int64
))
result
=
None
if
check_dtypes
:
result
=
F
.
gather
(
data
,
tensor_index
,
0
)
return
result
@
getitem
.
register
(
"Tensor"
,
"Tuple"
)
def
_tensor_getitem_by_
slice_tuple
(
data
,
slice_
tuple_index
):
def
_tensor_getitem_by_
tuple
(
data
,
tuple_index
):
"""
Getting item of tensor by slice tuple.
Inputs:
data (Tensor): A tensor.
slice_
tuple_index (tuple): Index in tuple.
tuple_index (tuple): Index in tuple.
Outputs:
Tensor, element type is same as the element type of data.
"""
return
_tensor_slice
(
data
,
slice_tuple_index
)
index_types
=
multi_utils
.
hyper_map
(
F
.
typeof
,
tuple_index
)
index_elements_type
=
multi_utils
.
tuple_elements_type
(
index_types
)
result
=
None
if
index_elements_type
==
multi_utils
.
NO_TENSOR
:
result
=
_tensor_slice
(
data
,
tuple_index
)
if
index_elements_type
==
multi_utils
.
ALL_TENSOR
:
result
=
_tensor_getitem_by_tuple_of_tensor
(
data
,
tuple_index
)
return
result
@
getitem
.
register
(
"Tensor"
,
"Ellipsis"
)
...
...
@@ -242,3 +269,10 @@ def _tensor_getitem_by_ellipsis(data, ellipsis_index):
Tensor, same as data.
"""
return
_tensor_slice
(
data
,
ellipsis_index
)
def
_tensor_getitem_by_tuple_of_tensor
(
data
,
tuple_index
):
"""Tensor getitem by a tuple of tensor."""
indices
=
multi_utils
.
generate_indeices_from_tuple_of_tensor
(
data
,
tuple_index
,
multi_utils
.
TENSOR_GETITEM
)
result
=
F
.
gather_nd
(
data
,
indices
)
return
result
mindspore/ops/composite/multitype_ops/setitem_impl.py
浏览文件 @
e490618d
...
...
@@ -18,10 +18,11 @@
from
...composite
import
base
from
....common
import
dtype
as
mstype
from
...
import
functional
as
F
from
.
import
_
multitype_ops_util
as
mult_util
from
.
import
_
utils
as
multi_utils
setitem
=
base
.
MultitypeFuncGraph
(
'setitem'
)
@
setitem
.
register
(
"List"
,
"Number"
,
"String"
)
def
_list_setitem_with_string
(
data
,
number_index
,
value
):
"""
...
...
@@ -118,7 +119,7 @@ def _dict_setitem_with_number(data, key, value):
@
setitem
.
register
(
"Tensor"
,
"Tensor"
,
"Tensor"
)
def
_tensor_setitem_by_tensor_
v1
(
data
,
index
,
value_tensor
):
def
_tensor_setitem_by_tensor_
with_tensor
(
data
,
index
,
value_tensor
):
"""
Tensor assignment.
...
...
@@ -137,27 +138,15 @@ def _tensor_setitem_by_tensor_v1(data, index, value_tensor):
Outputs:
Tensor, element type and shape is same as data.
"""
result
=
None
index_dtype
=
F
.
dtype
(
index
)
index_shape
=
F
.
shape
(
index
)
check_result
=
mult_util
.
check_tensor_setitem_index
(
mstype
.
tensor
,
index_dtype
)
if
check_result
:
data_shape
=
F
.
shape
(
data
)
data_shape
=
mult_util
.
check_equal
(
data_shape
,
index_shape
,
"The tensor(shape={}) and tensor index(shape={}) should be the same shape."
)
size
=
F
.
size
(
value_tensor
)
size
=
mult_util
.
check_equal
(
1
,
size
,
"When assign value is a tensor, its size should be {}, but current size is {}."
)
dtype
=
F
.
dtype
(
data
)
u_cast
=
F
.
cast
(
value_tensor
,
dtype
)
one_data
=
F
.
ones_like
(
data
)
u
=
F
.
tensor_mul
(
one_data
,
u_cast
)
result
=
F
.
select
(
index
,
u
,
data
)
return
result
tensor_dtype
=
multi_utils
.
get_index_tensor_dtype
(
index_dtype
)
if
tensor_dtype
==
multi_utils
.
INT_
:
return
_tensor_setitem_by_int_tensor_with_tensor
(
data
,
index
,
value_tensor
)
return
_tensor_setitem_by_bool_tensor_with_tensor
(
data
,
index
,
value_tensor
)
@
setitem
.
register
(
"Tensor"
,
"Tensor"
,
"Number"
)
def
_tensor_setitem_by_tensor_
v2
(
data
,
index
,
value
):
def
_tensor_setitem_by_tensor_
with_number
(
data
,
index
,
value
):
"""
Tensor assignment.
...
...
@@ -171,143 +160,167 @@ def _tensor_setitem_by_tensor_v2(data, index, value):
Inputs:
data (Tensor): Assigned tensor.
index (Tensor): Tensor of bool type.
value
_tensor
(Number): Assignment value.
value (Number): Assignment value.
Outputs:
Tensor, element type and shape is same as data.
"""
result
=
None
index_dtype
=
F
.
dtype
(
index
)
index_shape
=
F
.
shape
(
index
)
check_result
=
mult_util
.
check_tensor_setitem_index
(
mstype
.
tensor
,
index_dtype
)
if
check_result
:
shape
=
F
.
shape
(
data
)
shape
=
mult_util
.
check_equal
(
shape
,
index_shape
,
"The tensor(shape={}) and tensor index(shape={}) should be the same shape."
)
dtype
=
F
.
dtype
(
data
)
u
=
F
.
fill
(
dtype
,
shape
,
value
)
result
=
F
.
select
(
index
,
u
,
data
)
return
result
tensor_dtype
=
multi_utils
.
get_index_tensor_dtype
(
index_dtype
)
if
tensor_dtype
==
multi_utils
.
BOOL_
:
return
_tensor_setitem_by_bool_tensor_with_scalar
(
data
,
index
,
value
)
return
_tensor_setitem_by_int_tensor_with_scalar
(
data
,
index
,
value
)
@
setitem
.
register
(
"Tensor"
,
"
Slice"
,
"Tenso
r"
)
def
_tensor_setitem_
with_slice_v3
(
data
,
input_slice
,
value
):
@
setitem
.
register
(
"Tensor"
,
"
Tuple"
,
"Numbe
r"
)
def
_tensor_setitem_
by_tuple_with_number
(
data
,
tuple_index
,
value
):
"""
Tensor assignment.
Note:
Syntax support: A[Slice] = U
Restraint condition: A is a Tensor
Slice like "1:3"
U is a Tensor(size=1) or Tensor(size>1)
Syntax support: A[B, C, D] = u.
Restraint condition: 1) A is a Tensor, and B, C, D are index.
2) u is a scalar.
Inputs:
data (Tensor): Assigned tensor.
in
put_slice (Slice): Slice expression
.
in
dex (Tuple): An index tuple
.
value (Number): Assignment value.
Outputs:
Tensor, element type and shape is same as data.
"""
return
_tensor_assgin_tensor
(
data
,
input_slice
,
value
)
index_types
=
multi_utils
.
hyper_map
(
F
.
typeof
,
tuple_index
)
index_elements_type
=
multi_utils
.
tuple_elements_type
(
index_types
)
result
=
None
if
index_elements_type
==
multi_utils
.
NO_TENSOR
:
result
=
_tensor_assgin_number
(
data
,
tuple_index
,
value
)
if
index_elements_type
==
multi_utils
.
ALL_TENSOR
:
indices
=
multi_utils
.
generate_indeices_from_tuple_of_tensor
(
data
,
tuple_index
,
multi_utils
.
TENSOR_SETITEM
)
updates
=
multi_utils
.
generate_updates_from_scalar
(
data
,
indices
,
value
,
multi_utils
.
SET_ITEM_BY_TUPLE_OF_TENSOR
)
result
=
F
.
scatter_nd_update
(
data
,
indices
,
updates
)
return
result
@
setitem
.
register
(
"Tensor"
,
"Tuple"
,
"Tensor"
)
def
_tensor_setitem_
with_slice_v4
(
data
,
input_slice
,
value
):
def
_tensor_setitem_
by_tuple_with_tensor
(
data
,
tuple_index
,
value
):
"""
Tensor assignment.
Note:
Syntax support: A[tuple(Slice)] = U, and A[tuple(Number)] = U
Restraint condition: A is a Tensor
Slice like "1:3, ::, :4:-1"
U is a Tensor(size=1) or Tensor(size>1)
Syntax support: A[B, C, D] = U.
Restraint condition: 1) A is a Tensor, and B, C, D are index Tensors.
2) U is a Tensor.
Inputs:
data (Tensor): Assigned tensor.
in
put_slice (Union[tuple[Slice], tuple[Number]]): Slice expression
.
value (
Number): Assignment value
.
in
dex (Tuple): An index tuple
.
value (
Tensor): Assignment tensor, should has the same data type as 'data'
.
Outputs:
Tensor, element type and shape is same as data.
"""
return
_tensor_assgin_tensor
(
data
,
input_slice
,
value
)
index_types
=
multi_utils
.
hyper_map
(
F
.
typeof
,
tuple_index
)
index_elements_type
=
multi_utils
.
tuple_elements_type
(
index_types
)
result
=
None
if
index_elements_type
==
multi_utils
.
NO_TENSOR
:
result
=
_tensor_assgin_tensor
(
data
,
tuple_index
,
value
)
if
index_elements_type
==
multi_utils
.
ALL_TENSOR
:
indices
=
multi_utils
.
generate_indeices_from_tuple_of_tensor
(
data
,
tuple_index
,
multi_utils
.
TENSOR_SETITEM
)
updates
=
multi_utils
.
generate_updates_from_tensor
(
data
,
indices
,
value
,
multi_utils
.
SET_ITEM_BY_TUPLE_OF_TENSOR
)
result
=
F
.
scatter_nd_update
(
data
,
indices
,
updates
)
return
result
def
_tensor_assgin_tensor
(
data
,
input_slice
,
value
):
"""Assigns a tensor value to the tensor by slice."""
@
setitem
.
register
(
"Tensor"
,
"Tuple"
,
"Tuple"
)
def
_tensor_setitem_by_tuple_with_tuple
(
data
,
tuple_index
,
value
):
"""
Tensor assignment.
Note:
Syntax support: A[B, C, D] = U.
Restraint condition: 1) A is a Tensor, and B, C, D are index Tensors.
2) A B and C could be broadcast.
3) U is a Tensor.
Inputs:
data (Tensor): Assigned tensor.
index (Tuple): A tuple of tensor, these tensor could be broadcast.
value (Tensor): Assignment tensor, should has the same data type as 'data'.
Outputs:
Tensor, element type and shape is same as data.
"""
index_types
=
multi_utils
.
hyper_map
(
F
.
typeof
,
tuple_index
)
index_elements_type
=
multi_utils
.
tuple_elements_type
(
index_types
)
result
=
None
check_result
=
mult_util
.
check_tensor_setitem_index
(
input_slice
)
if
check_result
:
data_shape
=
F
.
shape
(
data
)
indices
=
mult_util
.
slice2indices
(
input_slice
,
data_shape
)
is_tuple_int
=
mult_util
.
tuple_element_is_int
(
input_slice
)
if
is_tuple_int
:
indices
=
mult_util
.
integer_to_indices
(
input_slice
,
data_shape
)
result
=
_tensor_indices_tensor
(
data
,
data_shape
,
input_slice
,
indices
,
value
)
if
index_elements_type
==
multi_utils
.
ALL_TENSOR
:
indices
=
multi_utils
.
generate_indeices_from_tuple_of_tensor
(
data
,
tuple_index
,
multi_utils
.
TENSOR_SETITEM
)
updates
=
multi_utils
.
generate_updates_from_tuple
(
data
,
indices
,
value
,
multi_utils
.
SET_ITEM_BY_TUPLE_OF_TENSOR
)
result
=
F
.
scatter_nd_update
(
data
,
indices
,
updates
)
return
result
def
_tensor_indices_tensor
(
data
,
data_shape
,
index
,
indices
,
value
):
"""Assigns a tensor value to the tensor."""
data_size
=
F
.
size
(
data
)
data_dtype
=
F
.
dtype
(
data
)
indices_size
=
F
.
size
(
indices
)
indices_size
=
mult_util
.
check_indices
(
indices_size
,
index
)
update
=
F
.
fill
(
mstype
.
int32
,
(
indices_size
,),
1
)
condition_1d
=
F
.
scatter_nd
(
indices
,
update
,
(
data_size
,))
condition
=
F
.
reshape
(
condition_1d
,
data_shape
)
condition
=
F
.
cast
(
condition
,
mstype
.
bool_
)
value_fill
=
None
value_size
=
F
.
size
(
value
)
@
setitem
.
register
(
"Tensor"
,
"Tensor"
,
"Tuple"
)
def
_tensor_setitem_by_tensor_v2
(
data
,
index
,
value
):
"""
Tensor assignment.
value_size
=
mult_util
.
check_indices_value_size
(
indices_size
,
value_size
)
if
value_size
==
1
:
value_fill
=
F
.
fill
(
data_dtype
,
(
indices_size
,),
1
)
value
=
F
.
cast
(
value
,
data_dtype
)
value_fill
=
F
.
tensor_mul
(
value_fill
,
value
)
elif
value_size
>
1
:
value_fill
=
F
.
reshape
(
value
,
(
indices_size
,))
value_1d
=
F
.
scatter_nd
(
indices
,
value_fill
,
(
data_size
,))
u
=
F
.
reshape
(
value_1d
,
data_shape
)
return
F
.
select
(
condition
,
u
,
data
)
Inputs:
data (Tensor): Assigned tensor.
index (Tensor): Tensor of bool type.
value (Tuple): Assignment value.
@
setitem
.
register
(
"Tensor"
,
"Slice"
,
"Number"
)
def
_tensor_setitem_with_slice_v1
(
data
,
input_slice
,
value
):
Outputs:
Tensor, element type and shape is same as data.
"""
index_dtype
=
F
.
dtype
(
index
)
check_dtype
=
multi_utils
.
check_tensor_dtype_valid
(
index_dtype
,
(
mstype
.
int32
,
mstype
.
int64
))
result
=
None
if
check_dtype
:
result
=
_tensor_setitem_by_tensor_with_tuple
(
data
,
index
,
value
)
return
result
@
setitem
.
register
(
"Tensor"
,
"Slice"
,
"Tensor"
)
def
_tensor_setitem_with_slice_v3
(
data
,
input_slice
,
value
):
"""
Tensor assignment.
Note:
Syntax support: A[Slice] =
u
Restraint condition: A is a Tensor
.
Syntax support: A[Slice] =
U
Restraint condition: A is a Tensor
Slice like "1:3"
u is a scalar
U is a Tensor(size=1) or Tensor(size>1)
Inputs:
data (Tensor): Assigned tensor.
input_slice (Slice):
s
lice expression.
input_slice (Slice):
S
lice expression.
value (Number): Assignment value.
Outputs:
Tensor, element type and shape is same as data.
"""
return
_tensor_assgin_
numbe
r
(
data
,
input_slice
,
value
)
return
_tensor_assgin_
tenso
r
(
data
,
input_slice
,
value
)
@
setitem
.
register
(
"Tensor"
,
"
Tupl
e"
,
"Number"
)
def
_tensor_setitem_with_slice_v
2
(
data
,
input_slice
,
value
):
@
setitem
.
register
(
"Tensor"
,
"
Slic
e"
,
"Number"
)
def
_tensor_setitem_with_slice_v
1
(
data
,
input_slice
,
value
):
"""
Tensor assignment.
Note:
Syntax support: A[
tuple(Slice)] = u, and A[tuple(Number)
] = u
Syntax support: A[
Slice
] = u
Restraint condition: A is a Tensor.
Slice like "1:3
, ::, :4:-1
"
Slice like "1:3"
u is a scalar
Inputs:
data (Tensor): Assigned tensor.
input_slice (
Union[tuple[Slice], tuple[Number]]
): slice expression.
input_slice (
Slice
): slice expression.
value (Number): Assignment value.
Outputs:
...
...
@@ -318,39 +331,23 @@ def _tensor_setitem_with_slice_v2(data, input_slice, value):
def
_tensor_assgin_number
(
data
,
input_slice
,
value
):
"""Givens a scalar assign to tensor by slice"""
check_result
=
mult
_util
.
check_tensor_setitem_index
(
input_slice
)
check_result
=
mult
i_utils
.
check_tensor_setitem_index
(
input_slice
)
result
=
None
if
check_result
:
data_shape
=
F
.
shape
(
data
)
indices
=
mult
_util
.
slice2indices
(
input_slice
,
data_shape
)
is_tuple_int
=
mult
_util
.
tuple_element_is_int
(
input_slice
)
indices
=
mult
i_utils
.
slice2indices
(
input_slice
,
data_shape
)
is_tuple_int
=
mult
i_utils
.
tuple_element_is_int
(
input_slice
)
if
is_tuple_int
:
indices
=
mult
_util
.
integer_to_indices
(
input_slice
,
data_shape
)
indices
=
mult
i_utils
.
integer_to_indices
(
input_slice
,
data_shape
)
result
=
_tensor_indices_number
(
data
,
data_shape
,
input_slice
,
indices
,
value
)
return
result
def
_tensor_indices_number
(
data
,
data_shape
,
index
,
indices
,
value
):
"""Assigns a scalar value to the tensor."""
data_size
=
F
.
size
(
data
)
data_dtype
=
F
.
dtype
(
data
)
indices_size
=
F
.
size
(
indices
)
indices_size
=
mult_util
.
check_indices
(
indices_size
,
index
)
update
=
F
.
fill
(
mstype
.
int32
,
(
indices_size
,),
1
)
condition_1d
=
F
.
scatter_nd
(
indices
,
update
,
(
data_size
,))
condition
=
F
.
reshape
(
condition_1d
,
data_shape
)
condition
=
F
.
cast
(
condition
,
mstype
.
bool_
)
value_fill
=
F
.
fill
(
data_dtype
,
(
indices_size
,),
value
)
value_1d
=
F
.
scatter_nd
(
indices
,
value_fill
,
(
data_size
,))
u
=
F
.
reshape
(
value_1d
,
data_shape
)
return
F
.
select
(
condition
,
u
,
data
)
@
setitem
.
register
(
"Tensor"
,
"Number"
,
"Number"
)
def
_tensor_setitem_with_int_v1
(
data
,
index
,
value
):
"""Syntax: A[1] = 3"""
data_shape
=
F
.
shape
(
data
)
indices
=
mult
_util
.
integer_to_indices
(
index
,
data_shape
)
indices
=
mult
i_utils
.
integer_to_indices
(
index
,
data_shape
)
return
_tensor_indices_number
(
data
,
data_shape
,
index
,
indices
,
value
)
...
...
@@ -358,7 +355,7 @@ def _tensor_setitem_with_int_v1(data, index, value):
def
_tensor_setitem_with_int_v2
(
data
,
index
,
value
):
"""Syntax: A[1] = Tensor"""
data_shape
=
F
.
shape
(
data
)
indices
=
mult
_util
.
integer_to_indices
(
index
,
data_shape
)
indices
=
mult
i_utils
.
integer_to_indices
(
index
,
data_shape
)
return
_tensor_indices_tensor
(
data
,
data_shape
,
index
,
indices
,
value
)
...
...
@@ -379,7 +376,7 @@ def _tensor_setitem_with_ellipsis_v2(data, index, value):
data_size
=
F
.
size
(
data
)
value_shape
=
F
.
shape
(
value
)
value_size
=
F
.
size
(
value
)
check_result
=
mult
_util
.
check_ellipsis_shape_size
(
data_shape
,
value_shape
,
data_size
,
value_size
)
check_result
=
mult
i_utils
.
check_ellipsis_shape_size
(
data_shape
,
value_shape
,
data_size
,
value_size
)
if
check_result
:
if
data_size
==
value_size
:
result
=
F
.
reshape
(
value
,
data_shape
)
...
...
@@ -389,3 +386,108 @@ def _tensor_setitem_with_ellipsis_v2(data, index, value):
param2
=
F
.
cast
(
value
,
data_dtype
)
result
=
F
.
tensor_mul
(
param1
,
param2
)
return
result
def
_tensor_assgin_tensor
(
data
,
input_slice
,
value
):
"""Assigns a tensor value to the tensor by slice."""
result
=
None
check_result
=
multi_utils
.
check_tensor_setitem_index
(
input_slice
)
if
check_result
:
data_shape
=
F
.
shape
(
data
)
indices
=
multi_utils
.
slice2indices
(
input_slice
,
data_shape
)
is_tuple_int
=
multi_utils
.
tuple_element_is_int
(
input_slice
)
if
is_tuple_int
:
indices
=
multi_utils
.
integer_to_indices
(
input_slice
,
data_shape
)
result
=
_tensor_indices_tensor
(
data
,
data_shape
,
input_slice
,
indices
,
value
)
return
result
def
_tensor_indices_tensor
(
data
,
data_shape
,
index
,
indices
,
value
):
"""Assigns a tensor value to the tensor."""
data_size
=
F
.
size
(
data
)
data_dtype
=
F
.
dtype
(
data
)
indices_size
=
F
.
size
(
indices
)
indices_size
=
multi_utils
.
check_indices
(
indices_size
,
index
)
update
=
F
.
fill
(
mstype
.
int32
,
(
indices_size
,),
1
)
condition_1d
=
F
.
scatter_nd
(
indices
,
update
,
(
data_size
,))
condition
=
F
.
reshape
(
condition_1d
,
data_shape
)
condition
=
F
.
cast
(
condition
,
mstype
.
bool_
)
value_fill
=
None
value_size
=
F
.
size
(
value
)
value_size
=
multi_utils
.
check_indices_value_size
(
indices_size
,
value_size
)
if
value_size
==
1
:
value_fill
=
F
.
fill
(
data_dtype
,
(
indices_size
,),
1
)
value
=
F
.
cast
(
value
,
data_dtype
)
value_fill
=
F
.
tensor_mul
(
value_fill
,
value
)
elif
value_size
>
1
:
value_fill
=
F
.
reshape
(
value
,
(
indices_size
,))
value_1d
=
F
.
scatter_nd
(
indices
,
value_fill
,
(
data_size
,))
u
=
F
.
reshape
(
value_1d
,
data_shape
)
return
F
.
select
(
condition
,
u
,
data
)
def
_tensor_indices_number
(
data
,
data_shape
,
index
,
indices
,
value
):
"""Assigns a scalar value to the tensor."""
data_size
=
F
.
size
(
data
)
data_dtype
=
F
.
dtype
(
data
)
indices_size
=
F
.
size
(
indices
)
indices_size
=
multi_utils
.
check_indices
(
indices_size
,
index
)
update
=
F
.
fill
(
mstype
.
int32
,
(
indices_size
,),
1
)
condition_1d
=
F
.
scatter_nd
(
indices
,
update
,
(
data_size
,))
condition
=
F
.
reshape
(
condition_1d
,
data_shape
)
condition
=
F
.
cast
(
condition
,
mstype
.
bool_
)
value_fill
=
F
.
fill
(
data_dtype
,
(
indices_size
,),
value
)
value_1d
=
F
.
scatter_nd
(
indices
,
value_fill
,
(
data_size
,))
u
=
F
.
reshape
(
value_1d
,
data_shape
)
return
F
.
select
(
condition
,
u
,
data
)
def
_tensor_setitem_by_tensor_with_tuple
(
data
,
index
,
value
):
"""Set a tensor item by a tensor with a tuple."""
updates
=
multi_utils
.
generate_updates_from_tuple
(
data
,
index
,
value
,
multi_utils
.
SET_ITEM_BY_ONE_TENSOR
)
result
=
F
.
scatter_update
(
data
,
index
,
updates
)
return
result
def
_tensor_setitem_by_int_tensor_with_scalar
(
data
,
index
,
value
):
"""Set a tensor item by a int tensor with a scalar."""
updates
=
multi_utils
.
generate_updates_from_scalar
(
data
,
index
,
value
,
multi_utils
.
SET_ITEM_BY_ONE_TENSOR
)
return
F
.
scatter_update
(
data
,
index
,
updates
)
def
_tensor_setitem_by_bool_tensor_with_scalar
(
data
,
index
,
value
):
"""Set a tensor item by a bool tensor with a scalar."""
index_shape
=
F
.
shape
(
index
)
shape
=
F
.
shape
(
data
)
shape
=
multi_utils
.
check_equal
(
shape
,
index_shape
,
"The tensor(shape={}) and tensor index(shape={}) should be the same shape."
)
dtype
=
F
.
dtype
(
data
)
u
=
F
.
fill
(
dtype
,
shape
,
value
)
return
F
.
select
(
index
,
u
,
data
)
def
_tensor_setitem_by_int_tensor_with_tensor
(
data
,
index
,
value
):
"""Set a tensor item by a int tensor with a tensor."""
updates
=
multi_utils
.
generate_updates_from_tensor
(
data
,
index
,
value
,
multi_utils
.
SET_ITEM_BY_ONE_TENSOR
)
return
F
.
scatter_update
(
data
,
index
,
updates
)
def
_tensor_setitem_by_bool_tensor_with_tensor
(
data
,
index
,
value
):
"""Set a tensor item by a bool tensor with a tensor."""
index_shape
=
F
.
shape
(
index
)
data_shape
=
F
.
shape
(
data
)
data_shape
=
multi_utils
.
check_equal
(
data_shape
,
index_shape
,
"The tensor(shape={}) and tensor index(shape={}) should be the same shape."
)
size
=
F
.
size
(
value
)
size
=
multi_utils
.
check_equal
(
1
,
size
,
"When assign value is a tensor, its size should be {}, but current size is {}."
)
dtype
=
F
.
dtype
(
data
)
u_cast
=
F
.
cast
(
value
,
dtype
)
one_data
=
F
.
ones_like
(
data
)
u
=
F
.
tensor_mul
(
one_data
,
u_cast
)
result
=
F
.
select
(
index
,
u
,
data
)
return
result
mindspore/ops/functional.py
浏览文件 @
e490618d
...
...
@@ -31,6 +31,7 @@ dtype = P.DType()
issubclass_
=
P
.
IsSubClass
()
isinstance_
=
P
.
IsInstance
()
fill
=
P
.
Fill
()
tile
=
P
.
Tile
()
select
=
P
.
Select
()
size
=
P
.
Size
()
ones_like
=
P
.
OnesLike
()
...
...
@@ -70,6 +71,12 @@ scalar_cast = P.ScalarCast()
print_
=
P
.
Print
()
expand_dims
=
P
.
ExpandDims
()
scatter_nd
=
P
.
ScatterNd
()
gather
=
P
.
GatherV2
()
gather_nd
=
P
.
GatherNd
()
scatter_update
=
P
.
ScatterUpdate
()
scatter_nd_update
=
P
.
ScatterNdUpdate
()
pack
=
P
.
Pack
()
tuple_setitem
=
Primitive
(
'tuple_setitem'
)
tuple_getitem
=
Primitive
(
'tuple_getitem'
)
...
...
mindspore/ops/operations/__init__.py
浏览文件 @
e490618d
...
...
@@ -24,7 +24,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
Fill
,
GatherNd
,
GatherV2
,
InvertPermutation
,
IsInstance
,
IsSubClass
,
ArgMaxWithValue
,
OnesLike
,
ZerosLike
,
Rank
,
Reshape
,
ResizeNearestNeighbor
,
ArgMinWithValue
,
SameTypeShape
,
ScatterMax
,
SameTypeShape
,
ScatterMax
,
ScatterUpdate
,
ScalarToArray
,
ScalarToTensor
,
ScatterNd
,
ScatterNdUpdate
,
Select
,
Shape
,
Size
,
Slice
,
Split
,
Squeeze
,
StridedSlice
,
Tile
,
...
...
@@ -193,6 +193,7 @@ __all__ = [
'Pad'
,
'MirrorPad'
,
'GatherNd'
,
'ScatterUpdate'
,
'ScatterNdUpdate'
,
'Floor'
,
'NMSWithMask'
,
...
...
mindspore/ops/operations/_grad_ops.py
浏览文件 @
e490618d
...
...
@@ -19,7 +19,7 @@ from ..._c_expression import signature_rw as sig_rw
from
..._c_expression
import
signature_kind
as
sig_kind
from
..primitive
import
Primitive
,
PrimitiveWithInfer
,
prim_attr_register
from
..._checkparam
import
Validator
as
validator
,
Rel
from
.._utils
import
_
get_concat_offset
from
.._utils
import
get_concat_offset
from
...common
import
dtype
as
mstype
...
...
@@ -136,7 +136,7 @@ class ConcatOffset(PrimitiveWithInfer):
axis
=
self
.
axis
x_shp
=
input_x
[
'shape'
]
x_type
=
input_x
[
'dtype'
]
offset
,
_
,
axis
=
_
get_concat_offset
(
x_shp
,
x_type
,
axis
,
self
.
name
)
offset
,
_
,
axis
=
get_concat_offset
(
x_shp
,
x_type
,
axis
,
self
.
name
)
self
.
add_prim_attr
(
'T'
,
x_type
[
0
].
element_type
())
offset_values
=
[]
for
i
in
range
(
len
(
x_shp
)):
...
...
mindspore/ops/operations/array_ops.py
浏览文件 @
e490618d
...
...
@@ -24,16 +24,15 @@ import itertools
import
numbers
import
numpy
as
np
from
..._c_expression
import
signature_rw
as
sig_rw
from
..._c_expression
import
signature_kind
as
sig_kind
from
..._checkparam
import
Validator
as
validator
from
..._checkparam
import
Rel
from
...common
import
dtype
as
mstype
from
...common.tensor
import
Tensor
from
..operations.math_ops
import
_infer_shape_reduce
from
.._utils
import
_
get_concat_offset
from
.._utils
import
get_concat_offset
from
..primitive
import
Primitive
,
PrimitiveWithInfer
,
prim_attr_register
def
_check_infer_attr_reduce
(
axis
,
keep_dims
,
prim_name
):
validator
.
check_value_type
(
'keep_dims'
,
keep_dims
,
[
bool
],
prim_name
)
validator
.
check_value_type
(
'axis'
,
axis
,
[
int
,
tuple
],
prim_name
)
...
...
@@ -931,7 +930,7 @@ class InvertPermutation(PrimitiveWithInfer):
z
=
[
x_value
[
i
]
for
i
in
range
(
len
(
x_value
))]
z
.
sort
()
y
=
[
None
]
*
len
(
x_value
)
y
=
[
None
]
*
len
(
x_value
)
for
i
,
value
in
enumerate
(
x_value
):
validator
.
check_value_type
(
"input[%d]"
%
i
,
value
,
[
int
],
self
.
name
)
validator
.
check
(
f
'value'
,
z
[
i
],
f
'index'
,
i
,
Rel
.
EQ
,
self
.
name
)
...
...
@@ -1111,6 +1110,7 @@ class ArgMinWithValue(PrimitiveWithInfer):
>>> input_x = Tensor(np.random.rand(5))
>>> index, output = P.ArgMinWithValue()(input_x)
"""
@
prim_attr_register
def
__init__
(
self
,
axis
=
0
,
keep_dims
=
False
):
"""init ArgMinWithValue"""
...
...
@@ -1352,7 +1352,7 @@ class Concat(PrimitiveWithInfer):
axis
=
self
.
axis
x_shp
=
input_x
[
'shape'
]
x_type
=
input_x
[
'dtype'
]
_
,
all_shp
,
_
=
_
get_concat_offset
(
x_shp
,
x_type
,
axis
,
self
.
name
)
_
,
all_shp
,
_
=
get_concat_offset
(
x_shp
,
x_type
,
axis
,
self
.
name
)
self
.
add_prim_attr
(
'T'
,
x_type
[
0
].
element_type
())
self
.
add_prim_attr
(
'inputNums'
,
len
(
x_shp
))
ret_shp
=
x_shp
[
0
].
copy
()
...
...
@@ -1376,15 +1376,13 @@ def _get_pack_shape(x_shape, x_type, axis, prim_name):
if
axis
<
0
:
axis
=
axis
+
rank_base
+
1
for
i
in
range
(
1
,
N
):
v
=
x_shape
[
i
]
validator
.
check
(
'len of x_shape[%d]'
%
i
,
len
(
v
),
'len of rank_base'
,
rank_base
,
Rel
.
EQ
,
prim_name
)
validator
.
check
(
'x_type[%d]'
%
i
,
x_type
[
i
],
'base'
,
x_type
[
0
],
Rel
.
EQ
,
prim_name
,
TypeError
)
for
j
in
range
(
rank_base
):
if
v
[
j
]
!=
x_shape
[
0
][
j
]:
raise
ValueError
(
f
"For
\'
{
prim_name
}
\'
element
{
i
}
shape in input can not pack with first element"
)
if
x_shape
[
i
]
!=
x_shape
[
0
]:
raise
ValueError
(
f
"For
\'
{
prim_name
}
\'
element
{
i
}
shape in input can not pack with first element"
)
out_shape
.
insert
(
axis
,
N
)
return
out_shape
class
Pack
(
PrimitiveWithInfer
):
r
"""
Packs a list of tensors in specified axis.
...
...
@@ -1831,7 +1829,7 @@ class DiagPart(PrimitiveWithInfer):
return
x_type
def
infer_shape
(
self
,
x_shape
):
if
len
(
x_shape
)
%
2
!=
0
or
\
if
len
(
x_shape
)
%
2
!=
0
or
\
not
x_shape
:
raise
ValueError
(
f
"For
\'
{
self
.
name
}
\'
input rank must be non-zero and even, but got rank
{
len
(
x_shape
)
}
, "
f
"with shapes
{
x_shape
}
"
)
...
...
@@ -2004,6 +2002,49 @@ class GatherNd(PrimitiveWithInfer):
return
x_dtype
class
ScatterUpdate
(
PrimitiveWithInfer
):
"""
Update tensor value by using input indices and value.
Using given values to update tensor value, along with the input indices.
Args:
use_locking (bool): Whether protect the assignment by a lock. Default: True.
Inputs:
- **input_x** (Parameter) - The target tensor, with data type of Parameter.
- **indices** (Tensor) - The index of input tensor.
- **update** (Tensor) - The tensor to update the input tensor, has the same type as input,
and update.shape = indices.shape + input_x.shape[1:].
Outputs:
Tensor, has the same shape and type as `input_x`.
Examples:
>>> input_x = mindspore.Parameter(Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32))
>>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
>>> update = Tensor(np.array([1.0, 2.2]), mindspore.float32)
>>> op = P.ScatterNdUpdate()
>>> output = op(input_x, indices, update)
"""
@
prim_attr_register
def
__init__
(
self
,
use_locking
=
True
):
"""Init ScatterNdUpdate"""
self
.
init_prim_io_names
(
inputs
=
[
'x'
,
'indices'
,
'value'
],
outputs
=
[
'y'
])
def
infer_shape
(
self
,
x_shape
,
indices_shape
,
value_shape
):
if
indices_shape
+
x_shape
[
1
:]
!=
value_shape
:
raise
ValueError
(
'Input value are not match with input indices.'
)
return
x_shape
def
infer_dtype
(
self
,
x_dtype
,
indices_dtype
,
value_dtype
):
validator
.
check_tensor_type_same
({
'indices'
:
indices_dtype
},
mstype
.
int_type
,
self
.
name
)
args
=
{
"x"
:
x_dtype
,
"value"
:
value_dtype
}
validator
.
check_tensor_type_same
(
args
,
(
mstype
.
bool_
,)
+
mstype
.
number_type
,
self
.
name
)
return
x_dtype
class
ScatterNdUpdate
(
PrimitiveWithInfer
):
"""
Update tensor value by using input indices and value.
...
...
@@ -2028,11 +2069,6 @@ class ScatterNdUpdate(PrimitiveWithInfer):
>>> op = P.ScatterNdUpdate()
>>> output = op(input_x, indices, update)
"""
__mindspore_signature__
=
(
(
'input_x'
,
sig_rw
.
RW_WRITE
,
sig_kind
.
KIND_POSITIONAL_KEYWORD
),
(
'indices'
,
sig_rw
.
RW_READ
,
sig_kind
.
KIND_POSITIONAL_KEYWORD
),
(
'value'
,
sig_rw
.
RW_READ
,
sig_kind
.
KIND_POSITIONAL_KEYWORD
)
)
@
prim_attr_register
def
__init__
(
self
,
use_locking
=
True
):
...
...
@@ -2142,10 +2178,10 @@ class SpaceToDepth(PrimitiveWithInfer):
validator
.
check
(
'x dimension'
,
len
(
x_shape
),
''
,
4
,
Rel
.
EQ
)
out_shape
=
copy
.
deepcopy
(
x_shape
)
for
i
in
range
(
2
):
if
out_shape
[
i
+
2
]
%
self
.
block_size
!=
0
:
raise
ValueError
(
f
'For
\'
{
self
.
name
}
\'
input shape[
{
i
+
2
}
]
{
out_shape
[
i
+
2
]
}
should be '
if
out_shape
[
i
+
2
]
%
self
.
block_size
!=
0
:
raise
ValueError
(
f
'For
\'
{
self
.
name
}
\'
input shape[
{
i
+
2
}
]
{
out_shape
[
i
+
2
]
}
should be '
f
'fully divided by block_size
{
self
.
block_size
}
'
)
out_shape
[
i
+
2
]
//=
self
.
block_size
out_shape
[
i
+
2
]
//=
self
.
block_size
out_shape
[
1
]
*=
self
.
block_size
*
self
.
block_size
return
out_shape
...
...
@@ -2199,9 +2235,10 @@ class DepthToSpace(PrimitiveWithInfer):
validator
.
check
(
'x dimension'
,
len
(
x_shape
),
''
,
4
,
Rel
.
EQ
)
out_shape
=
copy
.
deepcopy
(
x_shape
)
for
i
in
range
(
2
):
out_shape
[
i
+
2
]
*=
self
.
block_size
out_shape
[
i
+
2
]
*=
self
.
block_size
validator
.
check_integer
(
'x_shape[1] % (block_size*block_size)'
,
x_shape
[
1
]
%
(
self
.
block_size
*
self
.
block_size
),
validator
.
check_integer
(
'x_shape[1] % (block_size*block_size)'
,
x_shape
[
1
]
%
(
self
.
block_size
*
self
.
block_size
),
0
,
Rel
.
EQ
,
self
.
name
)
out_shape
[
1
]
//=
self
.
block_size
*
self
.
block_size
return
out_shape
...
...
@@ -2251,6 +2288,7 @@ class SpaceToBatch(PrimitiveWithInfer):
[[[[1.]]], [[[2.]]], [[[3.]]], [[[4.]]]]
"""
@
prim_attr_register
def
__init__
(
self
,
block_size
,
paddings
):
"""Init SpaceToBatch"""
...
...
@@ -2271,12 +2309,12 @@ class SpaceToBatch(PrimitiveWithInfer):
validator
.
check_integer
(
'rank of input_x'
,
len
(
x_shape
),
4
,
Rel
.
EQ
,
self
.
name
)
out_shape
=
copy
.
deepcopy
(
x_shape
)
for
i
in
range
(
2
):
padded
=
out_shape
[
i
+
2
]
+
self
.
paddings
[
i
][
0
]
+
\
padded
=
out_shape
[
i
+
2
]
+
self
.
paddings
[
i
][
0
]
+
\
self
.
paddings
[
i
][
1
]
if
padded
%
self
.
block_size
!=
0
:
raise
ValueError
(
f
'For
\'
{
self
.
name
}
\'
padded[
{
i
}
]
{
padded
}
should be divisible by '
f
'block_size
{
self
.
block_size
}
'
)
out_shape
[
i
+
2
]
=
padded
//
self
.
block_size
out_shape
[
i
+
2
]
=
padded
//
self
.
block_size
out_shape
[
0
]
*=
self
.
block_size
*
self
.
block_size
return
out_shape
...
...
@@ -2319,6 +2357,7 @@ class BatchToSpace(PrimitiveWithInfer):
[[[[1., 2.], [3., 4.]]]]
"""
@
prim_attr_register
def
__init__
(
self
,
block_size
,
crops
):
"""Init BatchToSpace"""
...
...
@@ -2339,10 +2378,10 @@ class BatchToSpace(PrimitiveWithInfer):
validator
.
check
(
'rank of input_x'
,
len
(
x_shape
),
''
,
4
)
out_shape
=
copy
.
deepcopy
(
x_shape
)
for
i
in
range
(
2
):
x_block_prod
=
out_shape
[
i
+
2
]
*
self
.
block_size
x_block_prod
=
out_shape
[
i
+
2
]
*
self
.
block_size
crops_sum
=
self
.
crops
[
i
][
0
]
+
self
.
crops
[
i
][
1
]
validator
.
check
(
"x block shape prod"
,
x_block_prod
,
'crops sum'
,
crops_sum
,
Rel
.
GT
,
self
.
name
)
out_shape
[
i
+
2
]
=
x_block_prod
-
crops_sum
out_shape
[
i
+
2
]
=
x_block_prod
-
crops_sum
block_size_prod
=
self
.
block_size
*
self
.
block_size
if
out_shape
[
0
]
%
block_size_prod
!=
0
:
raise
ValueError
(
f
'For
\'
{
self
.
name
}
\'
input_x dimension 0
{
out_shape
[
0
]
}
should be divisible by '
...
...
mindspore/ops/operations/math_ops.py
浏览文件 @
e490618d
...
...
@@ -24,7 +24,7 @@ from ..._checkparam import Validator as validator
from
..._checkparam
import
Rel
from
...common
import
dtype
as
mstype
from
...common.tensor
import
Tensor
from
.._utils
import
_
get_broadcast_shape
from
.._utils
import
get_broadcast_shape
from
..primitive
import
PrimitiveWithInfer
,
prim_attr_register
,
_run_op
...
...
@@ -75,7 +75,7 @@ class _BinaryOp(PrimitiveWithInfer):
self
.
init_prim_io_names
(
inputs
=
[
'x'
,
'y'
],
outputs
=
[
'output'
])
def
infer_shape
(
self
,
x_shape
,
y_shape
):
return
_
get_broadcast_shape
(
x_shape
,
y_shape
,
self
.
name
)
return
get_broadcast_shape
(
x_shape
,
y_shape
,
self
.
name
)
class
_MathBinaryOp
(
_BinaryOp
):
...
...
tests/mindspore_test_framework/components/executor/exec_forward.py
浏览文件 @
e490618d
...
...
@@ -27,9 +27,15 @@ class IdentityEC(IExectorComponent):
def
__call__
(
self
):
result_id
=
self
.
function
[
keyword
.
id
]
+
'-'
+
self
.
inputs
[
keyword
.
id
]
group
=
self
.
function
[
keyword
.
group
]
+
'-'
+
self
.
inputs
[
keyword
.
group
]
ret
urn
{
ret
=
{
keyword
.
id
:
result_id
,
keyword
.
group
:
group
,
keyword
.
desc_inputs
:
self
.
inputs
[
keyword
.
desc_inputs
],
keyword
.
result
:
self
.
function
[
keyword
.
block
](
*
self
.
inputs
[
keyword
.
desc_inputs
])
}
print
(
"buxue------------------------------------------------"
)
print
(
"inputs"
)
print
(
ret
[
keyword
.
desc_inputs
])
print
(
"outputs"
)
print
(
ret
[
keyword
.
result
])
return
ret
tests/ut/python/ops/test_ops.py
浏览文件 @
e490618d
...
...
@@ -1297,7 +1297,7 @@ raise_set = [
(
'ScatterNdUpdate'
,
{
'block'
:
(
P
.
ScatterNdUpdate
(),
{
'exception'
:
TypeError
}),
'desc_inputs'
:
(
Tensor
(
np
.
ones
((
2
,
3
),
np
.
float32
)),
Tensor
(
np
.
ones
((
2
,
2
),
np
.
in
t32
)),
Tensor
(
np
.
ones
((
2
,
2
),
np
.
floa
t32
)),
Tensor
(
np
.
ones
((
2
,),
np
.
float32
))),
'desc_bprop'
:
[[
2
,
3
]]}),
(
'Pack'
,
{
...
...
tests/ut/python/ops/test_tensor_slice.py
浏览文件 @
e490618d
...
...
@@ -16,13 +16,14 @@
import
numpy
as
np
import
pytest
from
mindspore
import
Tensor
from
mindspore
import
Tensor
,
Parameter
from
mindspore
import
context
from
mindspore
import
dtype
as
mstype
from
mindspore.nn
import
Cell
from
....mindspore_test_framework.mindspore_test
import
mindspore_test
from
....mindspore_test_framework.pipeline.forward.compile_forward
\
import
pipeline_for_compile_forward_ge_graph_for_case_by_case_config
import
pipeline_for_compile_forward_ge_graph_for_case_by_case_config
,
\
pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception
class
NetWorkSlicePositive
(
Cell
):
...
...
@@ -145,6 +146,160 @@ class TensorAssignWithSlice(Cell):
return
z
class
TensorIndexByOneTensor
(
Cell
):
def
__init__
(
self
):
super
(
TensorIndexByOneTensor
,
self
).
__init__
()
self
.
const
=
Tensor
(
np
.
ones
((
5
,
4
,
7
,
8
)),
mstype
.
int32
)
def
construct
(
self
,
x
,
index
):
ret
=
x
[
index
]
+
self
.
const
return
ret
class
TensorIndexByTwoTensors
(
Cell
):
def
__init__
(
self
):
super
(
TensorIndexByTwoTensors
,
self
).
__init__
()
self
.
const
=
Tensor
(
np
.
ones
((
3
,
4
,
5
,
8
)),
mstype
.
int32
)
def
construct
(
self
,
x
,
index_0
,
index_1
):
ret
=
x
[
index_0
,
index_1
]
+
self
.
const
return
ret
class
TensorIndexByThreeTensors
(
Cell
):
def
__init__
(
self
):
super
(
TensorIndexByThreeTensors
,
self
).
__init__
()
self
.
const
=
Tensor
(
np
.
ones
((
5
,
3
,
4
,
5
)),
mstype
.
int32
)
def
construct
(
self
,
x
,
index_0
,
index_1
,
index_2
):
ret
=
x
[
index_0
,
index_1
,
index_2
]
+
self
.
const
return
ret
class
TensorSetItemByOneTensorWithNumber
(
Cell
):
def
__init__
(
self
,
value
):
super
(
TensorSetItemByOneTensorWithNumber
,
self
).
__init__
()
self
.
const
=
Tensor
(
np
.
ones
((
6
,
7
,
8
)),
mstype
.
float32
)
self
.
param
=
Parameter
(
Tensor
(
np
.
arange
(
6
*
7
*
8
).
reshape
((
6
,
7
,
8
)),
mstype
.
float32
),
name
=
"x"
)
self
.
value
=
value
def
construct
(
self
,
index
):
self
.
param
[
index
]
=
self
.
value
ret
=
self
.
param
+
self
.
const
return
ret
class
TensorSetItemByOneTensorWithTensor
(
Cell
):
def
__init__
(
self
):
super
(
TensorSetItemByOneTensorWithTensor
,
self
).
__init__
()
self
.
const
=
Tensor
(
np
.
ones
((
6
,
7
,
8
)),
mstype
.
float32
)
self
.
param
=
Parameter
(
Tensor
(
np
.
arange
(
6
*
7
*
8
).
reshape
((
6
,
7
,
8
)),
mstype
.
float32
),
name
=
"x"
)
def
construct
(
self
,
index
,
value
):
self
.
param
[
index
]
=
value
ret
=
self
.
param
+
self
.
const
return
ret
class
TensorSetItemByOneTensorWithTupleOfNumber
(
Cell
):
def
__init__
(
self
,
value
):
super
(
TensorSetItemByOneTensorWithTupleOfNumber
,
self
).
__init__
()
self
.
const
=
Tensor
(
np
.
ones
((
6
,
7
,
8
)),
mstype
.
float32
)
self
.
param
=
Parameter
(
Tensor
(
np
.
arange
(
6
*
7
*
8
).
reshape
((
6
,
7
,
8
)),
mstype
.
float32
),
name
=
"x"
)
self
.
value
=
value
def
construct
(
self
,
index
):
self
.
param
[
index
]
=
self
.
value
ret
=
self
.
param
+
self
.
const
return
ret
class
TensorSetItemByOneTensorWithTupleOfTensor
(
Cell
):
def
__init__
(
self
):
super
(
TensorSetItemByOneTensorWithTupleOfTensor
,
self
).
__init__
()
self
.
const
=
Tensor
(
np
.
ones
((
6
,
3
,
8
)),
mstype
.
float32
)
self
.
param
=
Parameter
(
Tensor
(
np
.
arange
(
6
*
3
*
8
).
reshape
((
6
,
3
,
8
)),
mstype
.
float32
),
name
=
"x"
)
def
construct
(
self
,
index
,
value_0
,
value_1
,
value_2
):
self
.
param
[
index
]
=
(
value_0
,
value_1
,
value_2
)
ret
=
self
.
param
+
self
.
const
return
ret
class
TensorSetItemByTensorsWithNumber
(
Cell
):
def
__init__
(
self
,
value
):
super
(
TensorSetItemByTensorsWithNumber
,
self
).
__init__
()
self
.
const
=
Tensor
(
np
.
ones
((
6
,
7
,
8
)),
mstype
.
float32
)
self
.
param
=
Parameter
(
Tensor
(
np
.
arange
(
6
*
7
*
8
).
reshape
((
6
,
7
,
8
)),
mstype
.
float32
),
name
=
"x"
)
self
.
value
=
value
def
construct
(
self
,
index_0
,
index_1
,
index_2
):
self
.
param
[
index_0
,
index_1
,
index_2
]
=
self
.
value
ret
=
self
.
param
+
self
.
const
return
ret
class
TensorSetItemByTensorsWithTensor
(
Cell
):
def
__init__
(
self
):
super
(
TensorSetItemByTensorsWithTensor
,
self
).
__init__
()
self
.
const
=
Tensor
(
np
.
ones
((
6
,
7
,
8
)),
mstype
.
float32
)
self
.
param
=
Parameter
(
Tensor
(
np
.
arange
(
6
*
7
*
8
).
reshape
((
6
,
7
,
8
)),
mstype
.
float32
),
name
=
"x"
)
def
construct
(
self
,
index_0
,
index_1
,
index_2
,
value
):
self
.
param
[
index_0
,
index_1
,
index_2
]
=
value
ret
=
self
.
param
+
self
.
const
return
ret
class
TensorSetItemByTensorsWithTensorNumberError
(
Cell
):
def
__init__
(
self
):
super
(
TensorSetItemByTensorsWithTensorNumberError
,
self
).
__init__
()
self
.
const
=
Tensor
(
np
.
ones
((
6
,
7
,
8
)),
mstype
.
float32
)
self
.
param
=
Parameter
(
Tensor
(
np
.
arange
(
6
*
7
*
8
).
reshape
((
6
,
7
,
8
)),
mstype
.
float32
),
name
=
"x"
)
def
construct
(
self
,
index_0
,
index_1
,
index_2
,
index_3
,
value
):
self
.
param
[
index_0
,
index_1
,
index_2
,
index_3
]
=
value
ret
=
self
.
param
+
self
.
const
return
ret
class
TensorSetItemByTensorsWithTupleOfNumber
(
Cell
):
def
__init__
(
self
,
value
):
super
(
TensorSetItemByTensorsWithTupleOfNumber
,
self
).
__init__
()
self
.
const
=
Tensor
(
np
.
ones
((
6
,
7
,
8
)),
mstype
.
float32
)
self
.
param
=
Parameter
(
Tensor
(
np
.
arange
(
6
*
7
*
8
).
reshape
((
6
,
7
,
8
)),
mstype
.
float32
),
name
=
"x"
)
self
.
value
=
value
def
construct
(
self
,
index_0
,
index_1
,
index_2
):
self
.
param
[
index_0
,
index_1
,
index_2
]
=
self
.
value
ret
=
self
.
param
+
self
.
const
return
ret
class
TensorSetItemByTensorsWithTupleOfTensor
(
Cell
):
def
__init__
(
self
):
super
(
TensorSetItemByTensorsWithTupleOfTensor
,
self
).
__init__
()
self
.
const
=
Tensor
(
np
.
ones
((
6
,
7
,
8
)),
mstype
.
float32
)
self
.
param
=
Parameter
(
Tensor
(
np
.
arange
(
6
*
7
*
8
).
reshape
((
6
,
7
,
8
)),
mstype
.
float32
),
name
=
"x"
)
def
construct
(
self
,
index_0
,
index_1
,
index_2
,
value_0
,
value_1
,
value_2
):
self
.
param
[
index_0
,
index_1
,
index_2
]
=
(
value_0
,
value_1
,
value_2
)
ret
=
self
.
param
+
self
.
const
return
ret
class
TensorSetItemByTensorsWithTupleOfTensorNumberError
(
Cell
):
def
__init__
(
self
):
super
(
TensorSetItemByTensorsWithTupleOfTensorNumberError
,
self
).
__init__
()
self
.
const
=
Tensor
(
np
.
ones
((
6
,
7
,
8
)),
mstype
.
float32
)
self
.
param
=
Parameter
(
Tensor
(
np
.
arange
(
6
*
7
*
8
).
reshape
((
6
,
7
,
8
)),
mstype
.
float32
),
name
=
"x"
)
def
construct
(
self
,
index_0
,
index_1
,
index_2
,
value_0
,
value_1
):
self
.
param
[
index_0
,
index_1
,
index_2
]
=
(
value_0
,
value_1
)
ret
=
self
.
param
+
self
.
const
return
ret
def
test_tensor_assign
():
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
save_graphs
=
True
)
net
=
TensorAssignWithSlice
()
...
...
@@ -441,15 +596,206 @@ test_cases = [
'block'
:
NetWorkSliceEllipsis
(),
'desc_inputs'
:
[
Tensor
(
np
.
ones
([
6
,
7
,
8
,
9
],
np
.
int32
))],
}),
(
'TensorIndexByOneTensor'
,
{
'block'
:
TensorIndexByOneTensor
(),
'desc_inputs'
:
[
Tensor
(
np
.
arange
(
6
*
7
*
8
).
reshape
((
6
,
7
,
8
)),
mstype
.
int32
),
Tensor
(
np
.
random
.
randint
(
6
,
size
=
(
5
,
4
)),
mstype
.
int32
)],
}),
(
'TensorIndexByTwoTensors'
,
{
'block'
:
TensorIndexByTwoTensors
(),
'desc_inputs'
:
[
Tensor
(
np
.
arange
(
6
*
7
*
8
).
reshape
((
6
,
7
,
8
)),
mstype
.
int32
),
Tensor
(
np
.
random
.
randint
(
6
,
size
=
(
3
,
4
,
5
)),
mstype
.
int32
),
Tensor
(
np
.
random
.
randint
(
7
,
size
=
(
4
,
5
)),
mstype
.
int32
)],
}),
(
'TensorIndexByThreeTensors'
,
{
'block'
:
TensorIndexByThreeTensors
(),
'desc_inputs'
:
[
Tensor
(
np
.
arange
(
6
*
7
*
8
).
reshape
((
6
,
7
,
8
)),
mstype
.
int32
),
Tensor
(
np
.
random
.
randint
(
6
,
size
=
(
3
,
4
,
5
)),
mstype
.
int32
),
Tensor
(
np
.
random
.
randint
(
7
,
size
=
(
4
,
5
)),
mstype
.
int32
),
Tensor
(
np
.
random
.
randint
(
8
,
size
=
(
5
,
3
,
4
,
5
)),
mstype
.
int32
)],
}),
(
'TensorSetItemByOneTensorWithNumber'
,
{
'block'
:
TensorSetItemByOneTensorWithNumber
(
value
=
0.0
),
'desc_inputs'
:
[
Tensor
(
np
.
random
.
randint
(
4
,
size
=
(
5
,
4
)),
mstype
.
int32
)],
}),
(
'TensorSetItemByOneTensorWithTensor'
,
{
'block'
:
TensorSetItemByOneTensorWithTensor
(),
'desc_inputs'
:
[
Tensor
(
np
.
random
.
randint
(
3
,
size
=
(
5
,
4
)),
mstype
.
int32
),
Tensor
(
np
.
zeros
((
4
,
7
,
8
)),
mstype
.
float32
)],
}),
(
'TensorSetItemByOneTensorWithTupleOfNumber'
,
{
'block'
:
TensorSetItemByOneTensorWithTupleOfNumber
(
value
=
(
0.0
,
1.1
,
2.2
,
3.3
,
4.4
,
5.5
,
6.6
,
7.7
)),
'desc_inputs'
:
[
Tensor
(
np
.
random
.
randint
(
5
,
size
=
(
5
,
4
)),
mstype
.
int32
)],
}),
(
'TensorSetItemByOneTensorWithTupleOfTensor'
,
{
'block'
:
TensorSetItemByOneTensorWithTupleOfTensor
(),
'desc_inputs'
:
[
Tensor
(
np
.
random
.
randint
(
6
,
size
=
(
5
,
4
)),
mstype
.
int32
),
Tensor
(
np
.
zeros
((
8
,),
np
.
float32
)),
Tensor
(
np
.
ones
((
8
,),
np
.
float32
)),
Tensor
(
np
.
ones
((
8
,),
np
.
float32
)
*
2
)],
}),
(
'TensorSetItemByTensorsWithNumber'
,
{
'block'
:
TensorSetItemByTensorsWithNumber
(
value
=
0.0
),
'desc_inputs'
:
[
Tensor
(
np
.
random
.
randint
(
6
,
size
=
(
3
,
4
,
5
)),
mstype
.
int32
),
Tensor
(
np
.
random
.
randint
(
7
,
size
=
(
4
,
5
)),
mstype
.
int32
),
Tensor
(
np
.
random
.
randint
(
8
,
size
=
(
5
,
3
,
4
,
5
)),
mstype
.
int32
)],
}),
(
'TensorSetItemByTensorsWithTensor'
,
{
'block'
:
TensorSetItemByTensorsWithTensor
(),
'desc_inputs'
:
[
Tensor
(
np
.
random
.
randint
(
6
,
size
=
(
3
,
4
,
5
)),
mstype
.
int32
),
Tensor
(
np
.
random
.
randint
(
7
,
size
=
(
4
,
5
)),
mstype
.
int32
),
Tensor
(
np
.
random
.
randint
(
8
,
size
=
(
5
,
3
,
4
,
5
)),
mstype
.
int32
),
Tensor
(
np
.
zeros
((
4
,
5
)),
mstype
.
float32
)],
}),
(
'TensorSetItemByTensorsWithTupleOfNumber'
,
{
'block'
:
TensorSetItemByTensorsWithTupleOfNumber
(
value
=
(
0.0
,
1.1
,
2.2
,
3.3
,
4.4
)),
'desc_inputs'
:
[
Tensor
(
np
.
random
.
randint
(
6
,
size
=
(
3
,
4
,
5
)),
mstype
.
int32
),
Tensor
(
np
.
random
.
randint
(
7
,
size
=
(
4
,
5
)),
mstype
.
int32
),
Tensor
(
np
.
random
.
randint
(
8
,
size
=
(
5
,
3
,
4
,
5
)),
mstype
.
int32
)],
}),
(
'TensorSetItemByTensorsWithTupleOfTensor'
,
{
'block'
:
TensorSetItemByTensorsWithTupleOfTensor
(),
'desc_inputs'
:
[
Tensor
(
np
.
random
.
randint
(
6
,
size
=
(
3
,
4
,
5
)),
mstype
.
int32
),
Tensor
(
np
.
random
.
randint
(
7
,
size
=
(
4
,
5
)),
mstype
.
int32
),
Tensor
(
np
.
random
.
randint
(
8
,
size
=
(
5
,
3
,
4
,
5
)),
mstype
.
int32
),
Tensor
(
np
.
zeros
((
4
,
5
)),
mstype
.
float32
),
Tensor
(
np
.
ones
((
4
,
5
)),
mstype
.
float32
),
Tensor
(
np
.
ones
((
4
,
5
))
*
2
,
mstype
.
float32
)],
})
]
raise_error_set
=
[
(
'TensorIndexByOneTensorDtypeError'
,
{
'block'
:
(
TensorIndexByOneTensor
(),
{
'exception'
:
TypeError
}),
'desc_inputs'
:
[
Tensor
(
np
.
arange
(
6
*
7
*
8
).
reshape
((
6
,
7
,
8
)),
mstype
.
int32
),
Tensor
(
np
.
random
.
randint
(
6
,
size
=
(
5
,
4
)),
mstype
.
int8
)],
}),
(
'TensorIndexByTwoTensorsShapeError'
,
{
'block'
:
(
TensorIndexByTwoTensors
(),
{
'exception'
:
ValueError
}),
'desc_inputs'
:
[
Tensor
(
np
.
arange
(
6
*
7
*
8
).
reshape
((
6
,
7
,
8
)),
mstype
.
int32
),
Tensor
(
np
.
random
.
randint
(
6
,
size
=
(
3
,
4
,
5
)),
mstype
.
int32
),
Tensor
(
np
.
random
.
randint
(
7
,
size
=
(
2
,
3
,
5
)),
mstype
.
int32
)],
}),
(
'TensorIndexByTwoTensorsDtypeError'
,
{
'block'
:
(
TensorIndexByTwoTensors
(),
{
'exception'
:
TypeError
}),
'desc_inputs'
:
[
Tensor
(
np
.
arange
(
6
*
7
*
8
).
reshape
((
6
,
7
,
8
)),
mstype
.
int32
),
Tensor
(
np
.
random
.
randint
(
6
,
size
=
(
3
,
4
,
5
)),
mstype
.
int32
),
Tensor
(
np
.
random
.
randint
(
7
,
size
=
(
4
,
5
)),
mstype
.
float32
)],
}),
(
'TensorIndexByThreeTensorsShapeError'
,
{
'block'
:
(
TensorIndexByThreeTensors
(),
{
'exception'
:
ValueError
}),
'desc_inputs'
:
[
Tensor
(
np
.
arange
(
6
*
7
*
8
).
reshape
((
6
,
7
,
8
)),
mstype
.
int32
),
Tensor
(
np
.
random
.
randint
(
6
,
size
=
(
3
,
4
,
5
)),
mstype
.
int32
),
Tensor
(
np
.
random
.
randint
(
7
,
size
=
(
3
,
4
,
5
)),
mstype
.
int32
),
Tensor
(
np
.
random
.
randint
(
8
,
size
=
(
5
,
2
,
4
,
5
)),
mstype
.
int32
)],
}),
(
'TensorIndexByThreeTensorsDtypeError'
,
{
'block'
:
(
TensorIndexByThreeTensors
(),
{
'exception'
:
TypeError
}),
'desc_inputs'
:
[
Tensor
(
np
.
arange
(
6
*
7
*
8
).
reshape
((
6
,
7
,
8
)),
mstype
.
int32
),
Tensor
(
np
.
random
.
randint
(
6
,
size
=
(
3
,
4
,
5
)),
mstype
.
int32
),
Tensor
(
np
.
random
.
randint
(
7
,
size
=
(
3
,
4
,
5
)),
mstype
.
int64
),
Tensor
(
np
.
random
.
randint
(
8
,
size
=
(
5
,
3
,
4
,
5
)),
mstype
.
int32
)],
}),
(
'TensorSetItemByOneTensorWithNumberTypeError'
,
{
'block'
:
(
TensorSetItemByOneTensorWithNumber
(
value
=
0
),
{
'exception'
:
TypeError
}),
'desc_inputs'
:
[
Tensor
(
np
.
random
.
randint
(
4
,
size
=
(
5
,
4
)),
mstype
.
int32
)],
}),
(
'TensorSetItemByOneTensorWithTensorShapeError'
,
{
'block'
:
(
TensorSetItemByOneTensorWithTensor
(),
{
'exception'
:
ValueError
}),
'desc_inputs'
:
[
Tensor
(
np
.
random
.
randint
(
3
,
size
=
(
5
,
4
)),
mstype
.
int32
),
Tensor
(
np
.
zeros
((
6
,
7
,
8
)),
mstype
.
float32
)],
}),
(
'TensorSetItemByOneTensorWithTensorDtypeError'
,
{
'block'
:
(
TensorSetItemByOneTensorWithTensor
(),
{
'exception'
:
TypeError
}),
'desc_inputs'
:
[
Tensor
(
np
.
random
.
randint
(
3
,
size
=
(
5
,
4
)),
mstype
.
int32
),
Tensor
(
np
.
zeros
((
6
,
7
,
8
)),
mstype
.
int32
)],
}),
(
'TensorSetItemByOneTensorWithTupleOfNumberTypeError'
,
{
'block'
:
(
TensorSetItemByOneTensorWithTupleOfNumber
(
value
=
(
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
)),
{
'exception'
:
TypeError
}),
'desc_inputs'
:
[
Tensor
(
np
.
random
.
randint
(
5
,
size
=
(
5
,
4
)),
mstype
.
int32
)],
}),
(
'TensorSetItemByOneTensorWithTupleOfNumberNumberError'
,
{
'block'
:
(
TensorSetItemByOneTensorWithTupleOfNumber
(
value
=
(
0.0
,
1.1
,
2.2
)),
{
'exception'
:
ValueError
}),
'desc_inputs'
:
[
Tensor
(
np
.
random
.
randint
(
5
,
size
=
(
5
,
4
)),
mstype
.
int32
)],
}),
(
'TensorSetItemByOneTensorWithTupleOfTensorDtyeError'
,
{
'block'
:
(
TensorSetItemByOneTensorWithTupleOfTensor
(),
{
'exception'
:
TypeError
}),
'desc_inputs'
:
[
Tensor
(
np
.
random
.
randint
(
6
,
size
=
(
5
,
4
)),
mstype
.
int32
),
Tensor
(
np
.
zeros
((
8
,),
np
.
int32
)),
Tensor
(
np
.
ones
((
8
,),
np
.
int32
)),
Tensor
(
np
.
ones
((
8
,),
np
.
float32
)
*
2
)],
}),
(
'TensorSetItemByTensorsWithNumberTypeError'
,
{
'block'
:
(
TensorSetItemByTensorsWithNumber
(
value
=
0
),
{
'exception'
:
TypeError
}),
'desc_inputs'
:
[
Tensor
(
np
.
random
.
randint
(
6
,
size
=
(
3
,
4
,
5
)),
mstype
.
int32
),
Tensor
(
np
.
random
.
randint
(
7
,
size
=
(
4
,
5
)),
mstype
.
int32
),
Tensor
(
np
.
random
.
randint
(
8
,
size
=
(
5
,
3
,
4
,
5
)),
mstype
.
int32
)],
}),
(
'TensorSetItemByTensorsWithTensorShapeError'
,
{
'block'
:
(
TensorSetItemByTensorsWithTensor
(),
{
'exception'
:
ValueError
}),
'desc_inputs'
:
[
Tensor
(
np
.
random
.
randint
(
6
,
size
=
(
3
,
4
,
5
)),
mstype
.
int32
),
Tensor
(
np
.
random
.
randint
(
7
,
size
=
(
4
,
5
)),
mstype
.
int32
),
Tensor
(
np
.
random
.
randint
(
8
,
size
=
(
5
,
3
,
4
,
5
)),
mstype
.
int32
),
Tensor
(
np
.
zeros
((
2
,
5
)),
mstype
.
float32
)],
}),
(
'TensorSetItemByTensorsWithTensorTypeError'
,
{
'block'
:
(
TensorSetItemByTensorsWithTensor
(),
{
'exception'
:
TypeError
}),
'desc_inputs'
:
[
Tensor
(
np
.
random
.
randint
(
6
,
size
=
(
3
,
4
,
5
)),
mstype
.
int32
),
Tensor
(
np
.
random
.
randint
(
7
,
size
=
(
4
,
5
)),
mstype
.
int32
),
Tensor
(
np
.
random
.
randint
(
8
,
size
=
(
5
,
3
,
4
,
5
)),
mstype
.
int32
),
Tensor
(
np
.
zeros
((
4
,
5
)),
mstype
.
int32
)],
}),
(
'TensorSetItemByTensorsWithTensorNumberError'
,
{
'block'
:
(
TensorSetItemByTensorsWithTensorNumberError
(),
{
'exception'
:
IndexError
}),
'desc_inputs'
:
[
Tensor
(
np
.
random
.
randint
(
6
,
size
=
(
3
,
4
,
5
)),
mstype
.
int32
),
Tensor
(
np
.
random
.
randint
(
7
,
size
=
(
4
,
5
)),
mstype
.
int32
),
Tensor
(
np
.
random
.
randint
(
8
,
size
=
(
5
,
3
,
4
,
5
)),
mstype
.
int32
),
Tensor
(
np
.
random
.
randint
(
8
,
size
=
(
1
,
3
,
4
,
5
)),
mstype
.
int32
),
Tensor
(
np
.
zeros
((
2
,
5
)),
mstype
.
float32
)],
}),
(
'TensorSetItemByTensorsWithTupleOfNumberTypeError'
,
{
'block'
:
(
TensorSetItemByTensorsWithTupleOfNumber
(
value
=
(
0
,
1
,
2
,
3
,
4
)),
{
'exception'
:
TypeError
}),
'desc_inputs'
:
[
Tensor
(
np
.
random
.
randint
(
6
,
size
=
(
3
,
4
,
5
)),
mstype
.
int32
),
Tensor
(
np
.
random
.
randint
(
7
,
size
=
(
4
,
5
)),
mstype
.
int32
),
Tensor
(
np
.
random
.
randint
(
8
,
size
=
(
5
,
3
,
4
,
5
)),
mstype
.
int32
)],
}),
(
'TensorSetItemByTensorsWithTupleOfNumberNumberError'
,
{
'block'
:
(
TensorSetItemByTensorsWithTupleOfNumber
(
value
=
(
0.0
,
1.0
,
2.0
,
3.0
)),
{
'exception'
:
ValueError
}),
'desc_inputs'
:
[
Tensor
(
np
.
random
.
randint
(
6
,
size
=
(
3
,
4
,
5
)),
mstype
.
int32
),
Tensor
(
np
.
random
.
randint
(
7
,
size
=
(
4
,
5
)),
mstype
.
int32
),
Tensor
(
np
.
random
.
randint
(
8
,
size
=
(
5
,
3
,
4
,
5
)),
mstype
.
int32
)],
}),
(
'TensorSetItemByTensorsWithTupleOfTensorNumberError'
,
{
'block'
:
(
TensorSetItemByTensorsWithTupleOfTensorNumberError
(),
{
'exception'
:
ValueError
}),
'desc_inputs'
:
[
Tensor
(
np
.
random
.
randint
(
6
,
size
=
(
3
,
4
,
5
)),
mstype
.
int32
),
Tensor
(
np
.
random
.
randint
(
7
,
size
=
(
4
,
5
)),
mstype
.
int32
),
Tensor
(
np
.
random
.
randint
(
8
,
size
=
(
5
,
3
,
4
,
5
)),
mstype
.
int32
),
Tensor
(
np
.
zeros
((
4
,
5
)),
mstype
.
float32
),
Tensor
(
np
.
ones
((
4
,
5
)),
mstype
.
float32
)],
}),
(
'TensorSetItemByTensorsWithTupleOfTensorTypeError'
,
{
'block'
:
(
TensorSetItemByTensorsWithTupleOfTensor
(),
{
'exception'
:
TypeError
}),
'desc_inputs'
:
[
Tensor
(
np
.
random
.
randint
(
6
,
size
=
(
3
,
4
,
5
)),
mstype
.
int32
),
Tensor
(
np
.
random
.
randint
(
7
,
size
=
(
4
,
5
)),
mstype
.
int32
),
Tensor
(
np
.
random
.
randint
(
8
,
size
=
(
5
,
3
,
4
,
5
)),
mstype
.
int32
),
Tensor
(
np
.
zeros
((
4
,
5
)),
mstype
.
float32
),
Tensor
(
np
.
ones
((
4
,
5
)),
mstype
.
int32
),
Tensor
(
np
.
ones
((
4
,
5
))
*
2
,
mstype
.
int32
)],
})
]
@
mindspore_test
(
pipeline_for_compile_forward_ge_graph_for_case_by_case_config
)
def
test_
compile
():
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
)
def
test_
exec
():
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
save_graphs
=
True
)
return
test_cases
@
mindspore_test
(
pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception
)
def
test_check_exception
():
return
raise_error_set
def
test_tensor_slice_reduce_out_of_bounds_neg
():
class
NetWork
(
Cell
):
def
__init__
(
self
):
...
...
tests/ut/python/optimizer/test_debug_location.py
浏览文件 @
e490618d
...
...
@@ -26,7 +26,7 @@ from mindspore.ops import functional as F
from
mindspore.ops
import
operations
as
P
from
mindspore.ops._grad.grad_base
import
bprop_getters
from
mindspore.ops._grad.grad_math_ops
import
binop_grad_common
from
mindspore.ops._utils
import
_
get_broadcast_shape
from
mindspore.ops._utils
import
get_broadcast_shape
from
mindspore.ops.primitive
import
PrimitiveWithInfer
,
prim_attr_register
from
mindspore.train.loss_scale_manager
import
DynamicLossScaleManager
...
...
@@ -54,7 +54,7 @@ class MockSub(PrimitiveWithInfer):
self
.
init_prim_io_names
(
inputs
=
[
'x'
,
'y'
],
outputs
=
[
'output'
])
def
infer_shape
(
self
,
x_shape
,
y_shape
):
return
_
get_broadcast_shape
(
x_shape
,
y_shape
)
return
get_broadcast_shape
(
x_shape
,
y_shape
)
def
infer_dtype
(
self
,
x_dtype
,
y_dtype
):
return
x_dtype
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录