提交 274f6f01 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1133 support tensor get value by tensor index

Merge pull request !1133 from zhangbuxue/support_tensor_get_value_by_tensor_index
......@@ -1172,6 +1172,12 @@ int GenerateStridedSliceParametersFromNumber(const AbstractScalarPtr &scalar, co
return 1;
}
FuncGraphPtr ExpandADim(const FuncGraphPtr &ret_graph, const AnfNodePtr &tensor_node) {
auto PrimExpandDims = GetPythonOps("expand_dims", "mindspore.ops.functional");
ret_graph->set_output(NewCNode({NewValueNode(PrimExpandDims), tensor_node, NewValueNode(0)}, ret_graph));
return ret_graph;
}
FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
// slice a tensor
// args: tensor, slice or slice tuple
......@@ -1229,12 +1235,6 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec
return ret_graph;
}
FuncGraphPtr TensorSlice::ExpandADim(const FuncGraphPtr &ret_graph, const AnfNodePtr &tensor_node) const {
auto PrimExpandDims = GetPythonOps("expand_dims", "mindspore.ops.functional");
ret_graph->set_output(NewCNode({NewValueNode(PrimExpandDims), tensor_node, NewValueNode(0)}, ret_graph));
return ret_graph;
}
FuncGraphPtr TupleGetItemTensor::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
// select indexed item
// args: tuple of items, index
......
......@@ -206,8 +206,6 @@ class TensorSlice : public MetaFuncGraph {
MS_DECLARE_PARENT(TensorSlice, MetaFuncGraph)
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
friend bool operator==(const TensorSlice &lhs, const TensorSlice &rhs) { return lhs.name_ == rhs.name_; }
FuncGraphPtr ExpandADim(const FuncGraphPtr &ret_graph, const AnfNodePtr &tensor_node) const;
};
using TensorSlicePtr = std::shared_ptr<TensorSlice>;
......
......@@ -101,6 +101,7 @@ const char kNameReLU6[] = "ReLU6";
const char kNameReLU6Grad[] = "ReLU6Grad";
const char kNameElu[] = "Elu";
const char kNameEluGrad[] = "EluGrad";
const char kNameScatterUpdate[] = "ScatterUpdate";
const char kNameScatterNdUpdate[] = "ScatterNdUpdate";
const char kNameScatterMax[] = "ScatterMax";
const char kNameNMSWithMask[] = "NMSWithMask";
......@@ -256,6 +257,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
{string(kNameResizeBilinear), ADPT_DESC(ResizeBilinearV2D)},
{string(kNameZerosLike), ADPT_DESC(ZerosLike)},
{string(kNameOnesLike), ADPT_DESC(OnesLike)},
{string(kNameScatterUpdate), ADPT_DESC(ScatterUpdate)},
{string(kNameScatterNdUpdate), ADPT_DESC(ScatterNdUpdate)},
{string(kNameScatterMax), ADPT_DESC(ScatterMax)},
{string(kNameNMSWithMask), ADPT_DESC(NMSWithMask)},
......
......@@ -515,6 +515,11 @@ INPUT_MAP(Unpack) = {{1, INPUT_DESC(x)}};
ATTR_MAP(Unpack) = {{"axis", ATTR_DESC(axis, AnyTraits<int>())}, {"num", ATTR_DESC(num, AnyTraits<int>())}};
DYN_OUTPUT_MAP(Unpack) = {{0, DYN_OUTPUT_DESC(y)}};
// ScatterUpdate
INPUT_MAP(ScatterUpdate) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(updates)}};
ATTR_MAP(ScatterUpdate) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
OUTPUT_MAP(ScatterUpdate) = {{0, OUTPUT_DESC(var)}};
// ScatterNdUpdate
INPUT_MAP(ScatterNdUpdate) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(updates)}};
ATTR_MAP(ScatterNdUpdate) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
......
......@@ -132,6 +132,8 @@ DECLARE_OP_ADAPTER(ZerosLike)
DECLARE_OP_USE_OUTPUT(ZerosLike)
DECLARE_OP_ADAPTER(OnesLike)
DECLARE_OP_USE_OUTPUT(OnesLike)
DECLARE_OP_ADAPTER(ScatterUpdate)
DECLARE_OP_USE_OUTPUT(ScatterUpdate)
DECLARE_OP_ADAPTER(ScatterNdUpdate)
DECLARE_OP_USE_OUTPUT(ScatterNdUpdate)
DECLARE_OP_ADAPTER(ScatterMax)
......
......@@ -179,14 +179,15 @@ from .bounding_box_encode import _bounding_box_encode_tbe
from .check_valid import _check_valid_tbe
from .iou import _iou_tbe
from .arg_max import _arg_max_tbe
from .nms_with_mask import nms_with_mask_op_info
from .random_choice_with_mask import random_choice_with_mask_op_info
from .sgd import sgd_op_info
from .lars_update import lars_update_op_info
from .nms_with_mask import _nms_with_mask_tbe
from .random_choice_with_mask import _random_choice_with_mask_tbe
from .sgd import _sgd_tbe
from .lars_update import _lars_update_tbe
from .bn_training_update_v2 import _bn_training_update_v2_tbe
from .square_sum_all import square_sum_all_op_info
from .square_sum_all import _square_sum_all_tbe
from .pack import _pack_tbe
from .unpack import _unpack_tbe
from .scatter_update import _scatter_update_tbe
from .prelu import _prelu_tbe
from .prelu_grad import _prelu_grad_tbe
from .binary_cross_entropy import _binary_cross_entropy_tbe
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ScatterUpdate op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
scatter_update_op_info = TBERegOp("ScatterUpdate") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("scatter_update.so") \
.compute_cost(10) \
.kernel_name("scatter_update") \
.partial_flag(True) \
.attr("use_locking", "optional", "bool", "all") \
.input(0, "var", False, "required", "all") \
.input(1, "indices", False, "required", "all") \
.input(1, "updates", False, "required", "all") \
.output(0, "var", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default,) \
.dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.BOOL_Default, DataType.BOOL_Default) \
.get_op_info()
@op_info_register(scatter_update_op_info)
def _scatter_update_tbe():
"""ScatterUpdate TBE register"""
return
......@@ -14,6 +14,6 @@
# ============================================================================
"""ops utils."""
from .utils import _get_broadcast_shape, _get_concat_offset
from .utils import get_broadcast_shape, get_concat_offset
__all__ = ['_get_broadcast_shape', '_get_concat_offset']
__all__ = ['get_broadcast_shape', 'get_concat_offset']
......@@ -19,7 +19,8 @@ from ..._checkparam import Validator as validator
from ..._checkparam import Rel
from ...common import dtype as mstype
def _get_broadcast_shape(x_shape, y_shape, prim_name):
def get_broadcast_shape(x_shape, y_shape, prim_name):
"""
Doing broadcast between tensor x and tensor y.
......@@ -37,7 +38,7 @@ def _get_broadcast_shape(x_shape, y_shape, prim_name):
Examples:
>>> x_shape = [1, 2, 3]
>>> y_shape = [1, 2]
>>> broadcast_shape = _get_broadcast_shape(x_shape, y_shape)
>>> broadcast_shape = get_broadcast_shape(x_shape, y_shape)
"""
if x_shape == y_shape:
return x_shape
......@@ -54,15 +55,14 @@ def _get_broadcast_shape(x_shape, y_shape, prim_name):
elif x_shape[i] == y_shape[i]:
broadcast_shape_back.append(x_shape[i])
else:
raise ValueError("For '{}' the x_shape {} and y_shape {} can not broadcast.".format(
prim_name, x_shape, y_shape))
raise ValueError(f"For '{prim_name}', the x_shape {x_shape} and y_shape {y_shape} can not broadcast.")
broadcast_shape_front = y_shape[0: y_len - length] if length == x_len else x_shape[0: x_len - length]
broadcast_shape = broadcast_shape_front + broadcast_shape_back
broadcast_shape = list(broadcast_shape_front) + broadcast_shape_back
return broadcast_shape
def _get_concat_offset(x_shp, x_type, axis, prim_name):
def get_concat_offset(x_shp, x_type, axis, prim_name):
"""for concat and concatoffset check args and compute offset"""
validator.check_value_type("shape", x_shp, [tuple], prim_name)
validator.check_integer("input_x rank", len(x_shp), 0, Rel.GT, prim_name)
......@@ -73,7 +73,7 @@ def _get_concat_offset(x_shp, x_type, axis, prim_name):
if axis < 0:
axis = axis + rank_base
all_shp = x_shp[0][axis]
offset = [0,]
offset = [0]
for i in range(1, len(x_shp)):
v = x_shp[i]
validator.check('len of x_shp[%d]' % i, len(v), 'len of x_shp[0]', len(x_shp[0]), Rel.EQ, prim_name)
......
......@@ -14,13 +14,36 @@
# ============================================================================
"""constexpr util"""
from functools import reduce
import numpy as np
from ...primitive import constexpr
from ....common.tensor import Tensor
from ....common import dtype as mstype
from ...._extends.utils import Slice, Ellipsis_
from ....ops import _utils as op_utils
from ...composite import base
from .... import log as logger
from ... import functional as F
from ... import operations as P
hyper_map = base.HyperMap()
pack = P.Pack(axis=-1)
ALL_TENSOR = 0
NO_TENSOR = 1
CONTAIN_TENSOR = 2
ALL_SCALAR = 3
INT_ = 0
BOOL_ = 1
UNSUPPORTED_DTYPE = 2
TENSOR_SETITEM = "tensor setitem"
TENSOR_GETITEM = "tensor getitem"
SET_ITEM_BY_ONE_TENSOR = 0
SET_ITEM_BY_TUPLE_OF_TENSOR = 1
@constexpr
def check_equal(param1, param2, msg="{},{}"):
......@@ -55,7 +78,7 @@ def check_tensor_setitem_index(index, element_type=None):
return True
raise IndexError("Index of type '{}' is not supported yet.".format(type(index[0])))
# eg. Tensor[Tensor[dtype=bool]] = u
if index == mstype.tensor:
if isinstance(index, mstype.tensor_type):
if element_type is None or element_type != mstype.bool_:
raise TypeError(
"The index of tensor should be a bool type tensor. "
......@@ -172,6 +195,7 @@ def slice2indices(input_slices, shape):
ravel = Tensor(ravel.reshape(-1, 1), dtype=mstype.int32)
return ravel
@constexpr
def check_indices(indices_size, index):
"""Checks indices whether is empty."""
......@@ -192,6 +216,7 @@ def check_indices_value_size(indices_size, value_size):
" value size:{}, indics size:{}".format(value_size, indices_size))
return value_size
@constexpr
def integer_to_indices(index, shape):
"""Converts int or tuple[int] to indices."""
......@@ -201,6 +226,7 @@ def integer_to_indices(index, shape):
value = value.reshape(-1, 1)
return Tensor(value, dtype=mstype.int32)
@constexpr
def tuple_element_is_slice(indexs):
"""Judges tuple element type."""
......@@ -213,6 +239,7 @@ def tuple_element_is_slice(indexs):
return True
return False
@constexpr
def tuple_element_is_int(indexs):
"""Judges tuple element type."""
......@@ -224,3 +251,237 @@ def tuple_element_is_int(indexs):
return False
return True
return False
@constexpr
def tuple_elements_type(types):
"""Judges the type of all elements of the tuple."""
tensors_number = 0
for ele in types:
if isinstance(ele, mstype.tensor_type):
tensors_number += 1
if tensors_number == len(types):
return ALL_TENSOR
if tensors_number == 0:
return NO_TENSOR
return CONTAIN_TENSOR
@constexpr
def check_value_elements(data_dtype, types):
"""Judges the type of all elements of the tuple."""
tensors_number = 0
scalars_number = 0
for i, ele in enumerate(types):
if isinstance(ele, mstype.tensor_type):
ele_dtype = ele.element_type()
if data_dtype == ele_dtype:
tensors_number += 1
else:
raise TypeError(f"For '{TENSOR_SETITEM}', the data type of {i}th tensor '{ele_dtype}' "
f"in value tuple is not consistent with origin tensor data type '{data_dtype}'.")
elif mstype.issubclass_(ele, data_dtype):
scalars_number += 1
else:
raise TypeError(f"For '{TENSOR_SETITEM}', the {i}th element type '{ele}' in "
f"value tuple is not consistent with origin tensor data type '{data_dtype}'.")
if tensors_number == len(types):
return ALL_TENSOR
if scalars_number == len(types):
return ALL_SCALAR
raise TypeError(f"For '{TENSOR_SETITEM}', the value does not support scalar and tensor mixing, but got {types}.")
@constexpr
def get_index_tensor_dtype(dtype):
"""Check a tuple of tensor data type."""
if dtype == mstype.int32:
return INT_
if dtype == mstype.bool_:
return BOOL_
raise TypeError(f"For '{TENSOR_SETITEM}', the index tensor data type '{dtype}' is not supported.")
@constexpr
def check_index_tensors_dtype(dtypes, op_name):
"""Check a tuple of tensor data type."""
if op_name == TENSOR_GETITEM:
valid_dtypes = (mstype.int32, mstype.int64)
elif op_name == TENSOR_SETITEM:
valid_dtypes = (mstype.int32,)
else:
raise ValueError("Unsupported operation.")
for ele in dtypes:
if ele in valid_dtypes and ele == dtypes[0]:
continue
raise TypeError(f"For '{op_name}', the index tensors data type must be same, "
f"and should be one of the following: {valid_dtypes}, but got {dtypes}.")
return True
@constexpr
def check_tensor_dtype_valid(dtype, valid_dtypes):
"""Check a tensor data type."""
if dtype in valid_dtypes:
return True
raise TypeError(f"The index tensor data type must be one of "
f"the following: {valid_dtypes}, but got {dtype}.")
@constexpr
def check_tensors_dtype_same(x_dtype, y_dtype, op_name):
"""Check tensors data type same."""
if x_dtype == y_dtype:
return True
raise TypeError(f"For '{op_name}', the value data type '{y_dtype}' "
f"is not consistent with origin tensor data type {x_dtype}.")
@constexpr
def broadcast_shapes(shapes, op_name):
"""Broadcasts a tuple of tensor."""
broadcast_shape = shapes[0]
for i, shape in enumerate(shapes):
logger.debug(f"Broadcasts the {i}th tensor, the shape is {shape}.")
broadcast_shape = op_utils.get_broadcast_shape(broadcast_shape, shape, op_name)
return tuple(broadcast_shape)
@constexpr
def check_two_shapes_need_broadcast(shape_x, shape_y):
"""Check two shapes need broadcast."""
error = ValueError(f"For 'tensor setitem with tensor', the value tensor shape "
f"{shape_y} could not broadcast the required updates shape {shape_x}.")
if len(shape_y) > len(shape_x):
raise error
for i in range(-len(shape_y), 0):
if shape_y[i] > shape_x[i]:
raise error
if shape_y[i] < shape_x[i] and shape_y[i] != 1:
raise error
if shape_y == shape_x:
return False
return True
@constexpr
def compute_multiples(origin_shape, broadcast_shape):
"""Compute multiples between broadcast_shape with origin_shape."""
len_gap = len(broadcast_shape) - len(origin_shape)
return broadcast_shape[0:len_gap] + tuple(map(lambda x, y: x // y, broadcast_shape[len_gap:], origin_shape))
def tile(broadcast_shape, x):
multiples = compute_multiples(F.shape(x), broadcast_shape)
return F.tile(x, multiples)
@constexpr
def check_shapes_same(value_shapes, op_name):
"""Check if the shapes in the tuple are consistent."""
for i, shape in enumerate(value_shapes):
if shape != value_shapes[0]:
raise ValueError(f"For '{op_name}', the {i}th tensor shape in value tuple "
f"is not same as the first tensor shape.")
return True
@constexpr
def convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_type):
"""Convert a scalar to a tensor."""
if op_type == SET_ITEM_BY_ONE_TENSOR:
updates_shape = indices_shape + data_shape[1:]
else:
updates_shape = indices_shape[:-1] + data_shape[indices_shape[-1]:]
if isinstance(value, mstype.dtype_to_pytype(data_dtype)):
return Tensor(np.full(updates_shape, value), dtype=data_dtype)
raise TypeError(f"For '{TENSOR_SETITEM}', the value type '{value.__class__.__name__}'"
f" is not consistent with tensor data type {data_dtype}.")
@constexpr
def convert_tuple_of_scalar_to_tensor(data_shape, data_dtype, index_shape, value, op_type):
"""Convert a tuple of scalar to a tensor."""
updates_shape = generate_updates_shape(data_shape, index_shape, op_type)
if len(value) != updates_shape[-1]:
raise ValueError(f"For '{TENSOR_SETITEM}', the number of elements : {len(value)} in the updates tuple "
f"does not meet the requirements: {updates_shape[-1]}.")
array = np.array(value, dtype=mstype.dtype_to_nptype(data_dtype))
reps = compute_multiples(updates_shape[-1:], updates_shape)
return Tensor(np.tile(array, reps))
@constexpr
def generate_updates_shape(data_shape, index_shape, op_type):
"""Generate updates shape for 'tensor setitem'."""
if op_type == SET_ITEM_BY_ONE_TENSOR:
updates_shape = index_shape + data_shape[1:]
else:
updates_shape = index_shape[:-1] + data_shape[index_shape[-1]:]
return updates_shape
@constexpr
def check_number_of_index_tensor(data_shape, tuple_len, op_name):
"""Check if the number of index tensor exceeds the dimension of the operated tensor."""
if tuple_len <= len(data_shape):
return True
raise IndexError(f"For '{op_name}', the number {tuple_len} of index tensor "
f"is greater than the dimension {len(data_shape)} of the operated tensor.")
def generate_indeices_from_tuple_of_tensor(data, tuple_index, op_name):
"""Generate an indices tensor from a tuple of tensor."""
indices = None
check_index_tensor_number = check_number_of_index_tensor(F.shape(data), len(tuple_index), op_name)
if check_index_tensor_number:
dtype_tuple = hyper_map(F.dtype, tuple_index)
check_dtypes = check_index_tensors_dtype(dtype_tuple, op_name)
if check_dtypes:
shape_tuple = hyper_map(F.shape, tuple_index)
broadcast_shape = broadcast_shapes(shape_tuple, op_name)
broadcast_tensors = hyper_map(F.partial(tile, broadcast_shape), tuple_index)
indices = pack(broadcast_tensors)
return indices
def generate_updates_from_scalar(data, indices, value, op_type):
"""Generate an updates tensor from a scalar."""
data_shape = F.shape(data)
indices_shape = F.shape(indices)
data_dtype = F.dtype(data)
return convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_type)
def generate_updates_from_tuple(data, index, value, op_type):
"""Generate an updates tensor from a tuple."""
value_types = hyper_map(F.typeof, value)
data_dtype = F.dtype(data)
value_elements_type = check_value_elements(data_dtype, value_types)
if value_elements_type == ALL_TENSOR:
value_shapes = hyper_map(F.shape, value)
shapes_same = check_shapes_same(value_shapes, TENSOR_SETITEM)
if shapes_same:
value = F.pack(value)
return generate_updates_from_tensor(data, index, value, op_type)
data_shape = F.shape(data)
index_shape = F.shape(index)
return convert_tuple_of_scalar_to_tensor(data_shape, data_dtype, index_shape, value, op_type)
def generate_updates_from_tensor(data, index, value, op_type):
"""Generate an updates tensor from a tensor."""
data_shape = F.shape(data)
index_shape = F.shape(index)
value_shape = F.shape(value)
data_dtype = F.dtype(data)
value_dtype = F.dtype(value)
updates_shape = value_shape
check_dtype_same = check_tensors_dtype_same(data_dtype, value_dtype, TENSOR_SETITEM)
if check_dtype_same:
updates_shape = generate_updates_shape(data_shape, index_shape, op_type)
need_broadcast = check_two_shapes_need_broadcast(updates_shape, value_shape)
if need_broadcast:
return tile(updates_shape, value)
return value
......@@ -15,9 +15,10 @@
"""Implementation for getitem."""
from ...composite import base
from . import _utils as multi_utils
from ..import base
from ... import functional as F
from ....common import dtype as mstype
getitem = base.MultitypeFuncGraph('getitem')
"""
......@@ -214,19 +215,45 @@ def _tensor_getitem_by_slice(data, slice_index):
return _tensor_slice(data, slice_index)
@getitem.register("Tensor", "Tensor")
def _tensor_getitem_by_tensor(data, tensor_index):
"""
Getting item of tensor by slice.
Inputs:
data (Tensor): A tensor.
tensor_index (Tensor): An index expressed by tensor.
Outputs:
Tensor, element type is same as the element type of data.
"""
check_dtypes = multi_utils.check_tensor_dtype_valid(F.dtype(tensor_index), (mstype.int32, mstype.int64))
result = None
if check_dtypes:
result = F.gather(data, tensor_index, 0)
return result
@getitem.register("Tensor", "Tuple")
def _tensor_getitem_by_slice_tuple(data, slice_tuple_index):
def _tensor_getitem_by_tuple(data, tuple_index):
"""
Getting item of tensor by slice tuple.
Inputs:
data (Tensor): A tensor.
slice_tuple_index (tuple): Index in tuple.
tuple_index (tuple): Index in tuple.
Outputs:
Tensor, element type is same as the element type of data.
"""
return _tensor_slice(data, slice_tuple_index)
index_types = multi_utils.hyper_map(F.typeof, tuple_index)
index_elements_type = multi_utils.tuple_elements_type(index_types)
result = None
if index_elements_type == multi_utils.NO_TENSOR:
result = _tensor_slice(data, tuple_index)
if index_elements_type == multi_utils.ALL_TENSOR:
result = _tensor_getitem_by_tuple_of_tensor(data, tuple_index)
return result
@getitem.register("Tensor", "Ellipsis")
......@@ -242,3 +269,10 @@ def _tensor_getitem_by_ellipsis(data, ellipsis_index):
Tensor, same as data.
"""
return _tensor_slice(data, ellipsis_index)
def _tensor_getitem_by_tuple_of_tensor(data, tuple_index):
"""Tensor getitem by a tuple of tensor."""
indices = multi_utils.generate_indeices_from_tuple_of_tensor(data, tuple_index, multi_utils.TENSOR_GETITEM)
result = F.gather_nd(data, indices)
return result
......@@ -18,10 +18,11 @@
from ...composite import base
from ....common import dtype as mstype
from ... import functional as F
from . import _multitype_ops_util as mult_util
from . import _utils as multi_utils
setitem = base.MultitypeFuncGraph('setitem')
@setitem.register("List", "Number", "String")
def _list_setitem_with_string(data, number_index, value):
"""
......@@ -118,7 +119,7 @@ def _dict_setitem_with_number(data, key, value):
@setitem.register("Tensor", "Tensor", "Tensor")
def _tensor_setitem_by_tensor_v1(data, index, value_tensor):
def _tensor_setitem_by_tensor_with_tensor(data, index, value_tensor):
"""
Tensor assignment.
......@@ -137,27 +138,15 @@ def _tensor_setitem_by_tensor_v1(data, index, value_tensor):
Outputs:
Tensor, element type and shape is same as data.
"""
result = None
index_dtype = F.dtype(index)
index_shape = F.shape(index)
check_result = mult_util.check_tensor_setitem_index(mstype.tensor, index_dtype)
if check_result:
data_shape = F.shape(data)
data_shape = mult_util.check_equal(data_shape, index_shape,
"The tensor(shape={}) and tensor index(shape={}) should be the same shape.")
size = F.size(value_tensor)
size = mult_util.check_equal(1, size,
"When assign value is a tensor, its size should be {}, but current size is {}.")
dtype = F.dtype(data)
u_cast = F.cast(value_tensor, dtype)
one_data = F.ones_like(data)
u = F.tensor_mul(one_data, u_cast)
result = F.select(index, u, data)
return result
tensor_dtype = multi_utils.get_index_tensor_dtype(index_dtype)
if tensor_dtype == multi_utils.INT_:
return _tensor_setitem_by_int_tensor_with_tensor(data, index, value_tensor)
return _tensor_setitem_by_bool_tensor_with_tensor(data, index, value_tensor)
@setitem.register("Tensor", "Tensor", "Number")
def _tensor_setitem_by_tensor_v2(data, index, value):
def _tensor_setitem_by_tensor_with_number(data, index, value):
"""
Tensor assignment.
......@@ -171,143 +160,167 @@ def _tensor_setitem_by_tensor_v2(data, index, value):
Inputs:
data (Tensor): Assigned tensor.
index (Tensor): Tensor of bool type.
value_tensor (Number): Assignment value.
value (Number): Assignment value.
Outputs:
Tensor, element type and shape is same as data.
"""
result = None
index_dtype = F.dtype(index)
index_shape = F.shape(index)
check_result = mult_util.check_tensor_setitem_index(mstype.tensor, index_dtype)
if check_result:
shape = F.shape(data)
shape = mult_util.check_equal(
shape, index_shape, "The tensor(shape={}) and tensor index(shape={}) should be the same shape.")
dtype = F.dtype(data)
u = F.fill(dtype, shape, value)
result = F.select(index, u, data)
return result
tensor_dtype = multi_utils.get_index_tensor_dtype(index_dtype)
if tensor_dtype == multi_utils.BOOL_:
return _tensor_setitem_by_bool_tensor_with_scalar(data, index, value)
return _tensor_setitem_by_int_tensor_with_scalar(data, index, value)
@setitem.register("Tensor", "Slice", "Tensor")
def _tensor_setitem_with_slice_v3(data, input_slice, value):
@setitem.register("Tensor", "Tuple", "Number")
def _tensor_setitem_by_tuple_with_number(data, tuple_index, value):
"""
Tensor assignment.
Note:
Syntax support: A[Slice] = U
Restraint condition: A is a Tensor
Slice like "1:3"
U is a Tensor(size=1) or Tensor(size>1)
Syntax support: A[B, C, D] = u.
Restraint condition: 1) A is a Tensor, and B, C, D are index.
2) u is a scalar.
Inputs:
data (Tensor): Assigned tensor.
input_slice (Slice): Slice expression.
index (Tuple): An index tuple.
value (Number): Assignment value.
Outputs:
Tensor, element type and shape is same as data.
"""
return _tensor_assgin_tensor(data, input_slice, value)
index_types = multi_utils.hyper_map(F.typeof, tuple_index)
index_elements_type = multi_utils.tuple_elements_type(index_types)
result = None
if index_elements_type == multi_utils.NO_TENSOR:
result = _tensor_assgin_number(data, tuple_index, value)
if index_elements_type == multi_utils.ALL_TENSOR:
indices = multi_utils.generate_indeices_from_tuple_of_tensor(data, tuple_index, multi_utils.TENSOR_SETITEM)
updates = multi_utils.generate_updates_from_scalar(data, indices, value,
multi_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
result = F.scatter_nd_update(data, indices, updates)
return result
@setitem.register("Tensor", "Tuple", "Tensor")
def _tensor_setitem_with_slice_v4(data, input_slice, value):
def _tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
"""
Tensor assignment.
Note:
Syntax support: A[tuple(Slice)] = U, and A[tuple(Number)] = U
Restraint condition: A is a Tensor
Slice like "1:3, ::, :4:-1"
U is a Tensor(size=1) or Tensor(size>1)
Syntax support: A[B, C, D] = U.
Restraint condition: 1) A is a Tensor, and B, C, D are index Tensors.
2) U is a Tensor.
Inputs:
data (Tensor): Assigned tensor.
input_slice (Union[tuple[Slice], tuple[Number]]): Slice expression.
value (Number): Assignment value.
index (Tuple): An index tuple.
value (Tensor): Assignment tensor, should has the same data type as 'data'.
Outputs:
Tensor, element type and shape is same as data.
"""
return _tensor_assgin_tensor(data, input_slice, value)
index_types = multi_utils.hyper_map(F.typeof, tuple_index)
index_elements_type = multi_utils.tuple_elements_type(index_types)
result = None
if index_elements_type == multi_utils.NO_TENSOR:
result = _tensor_assgin_tensor(data, tuple_index, value)
if index_elements_type == multi_utils.ALL_TENSOR:
indices = multi_utils.generate_indeices_from_tuple_of_tensor(data, tuple_index, multi_utils.TENSOR_SETITEM)
updates = multi_utils.generate_updates_from_tensor(data, indices, value,
multi_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
result = F.scatter_nd_update(data, indices, updates)
return result
def _tensor_assgin_tensor(data, input_slice, value):
"""Assigns a tensor value to the tensor by slice."""
@setitem.register("Tensor", "Tuple", "Tuple")
def _tensor_setitem_by_tuple_with_tuple(data, tuple_index, value):
"""
Tensor assignment.
Note:
Syntax support: A[B, C, D] = U.
Restraint condition: 1) A is a Tensor, and B, C, D are index Tensors.
2) A B and C could be broadcast.
3) U is a Tensor.
Inputs:
data (Tensor): Assigned tensor.
index (Tuple): A tuple of tensor, these tensor could be broadcast.
value (Tensor): Assignment tensor, should has the same data type as 'data'.
Outputs:
Tensor, element type and shape is same as data.
"""
index_types = multi_utils.hyper_map(F.typeof, tuple_index)
index_elements_type = multi_utils.tuple_elements_type(index_types)
result = None
check_result = mult_util.check_tensor_setitem_index(input_slice)
if check_result:
data_shape = F.shape(data)
indices = mult_util.slice2indices(input_slice, data_shape)
is_tuple_int = mult_util.tuple_element_is_int(input_slice)
if is_tuple_int:
indices = mult_util.integer_to_indices(input_slice, data_shape)
result = _tensor_indices_tensor(data, data_shape, input_slice, indices, value)
if index_elements_type == multi_utils.ALL_TENSOR:
indices = multi_utils.generate_indeices_from_tuple_of_tensor(data, tuple_index, multi_utils.TENSOR_SETITEM)
updates = multi_utils.generate_updates_from_tuple(data, indices, value,
multi_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
result = F.scatter_nd_update(data, indices, updates)
return result
def _tensor_indices_tensor(data, data_shape, index, indices, value):
"""Assigns a tensor value to the tensor."""
data_size = F.size(data)
data_dtype = F.dtype(data)
indices_size = F.size(indices)
indices_size = mult_util.check_indices(indices_size, index)
update = F.fill(mstype.int32, (indices_size,), 1)
condition_1d = F.scatter_nd(indices, update, (data_size,))
condition = F.reshape(condition_1d, data_shape)
condition = F.cast(condition, mstype.bool_)
value_fill = None
value_size = F.size(value)
@setitem.register("Tensor", "Tensor", "Tuple")
def _tensor_setitem_by_tensor_v2(data, index, value):
"""
Tensor assignment.
value_size = mult_util.check_indices_value_size(indices_size, value_size)
if value_size == 1:
value_fill = F.fill(data_dtype, (indices_size,), 1)
value = F.cast(value, data_dtype)
value_fill = F.tensor_mul(value_fill, value)
elif value_size > 1:
value_fill = F.reshape(value, (indices_size,))
value_1d = F.scatter_nd(indices, value_fill, (data_size,))
u = F.reshape(value_1d, data_shape)
return F.select(condition, u, data)
Inputs:
data (Tensor): Assigned tensor.
index (Tensor): Tensor of bool type.
value (Tuple): Assignment value.
@setitem.register("Tensor", "Slice", "Number")
def _tensor_setitem_with_slice_v1(data, input_slice, value):
Outputs:
Tensor, element type and shape is same as data.
"""
index_dtype = F.dtype(index)
check_dtype = multi_utils.check_tensor_dtype_valid(index_dtype, (mstype.int32, mstype.int64))
result = None
if check_dtype:
result = _tensor_setitem_by_tensor_with_tuple(data, index, value)
return result
@setitem.register("Tensor", "Slice", "Tensor")
def _tensor_setitem_with_slice_v3(data, input_slice, value):
"""
Tensor assignment.
Note:
Syntax support: A[Slice] = u
Restraint condition: A is a Tensor.
Syntax support: A[Slice] = U
Restraint condition: A is a Tensor
Slice like "1:3"
u is a scalar
U is a Tensor(size=1) or Tensor(size>1)
Inputs:
data (Tensor): Assigned tensor.
input_slice (Slice): slice expression.
input_slice (Slice): Slice expression.
value (Number): Assignment value.
Outputs:
Tensor, element type and shape is same as data.
"""
return _tensor_assgin_number(data, input_slice, value)
return _tensor_assgin_tensor(data, input_slice, value)
@setitem.register("Tensor", "Tuple", "Number")
def _tensor_setitem_with_slice_v2(data, input_slice, value):
@setitem.register("Tensor", "Slice", "Number")
def _tensor_setitem_with_slice_v1(data, input_slice, value):
"""
Tensor assignment.
Note:
Syntax support: A[tuple(Slice)] = u, and A[tuple(Number)] = u
Syntax support: A[Slice] = u
Restraint condition: A is a Tensor.
Slice like "1:3, ::, :4:-1"
Slice like "1:3"
u is a scalar
Inputs:
data (Tensor): Assigned tensor.
input_slice (Union[tuple[Slice], tuple[Number]]): slice expression.
input_slice (Slice): slice expression.
value (Number): Assignment value.
Outputs:
......@@ -318,39 +331,23 @@ def _tensor_setitem_with_slice_v2(data, input_slice, value):
def _tensor_assgin_number(data, input_slice, value):
"""Givens a scalar assign to tensor by slice"""
check_result = mult_util.check_tensor_setitem_index(input_slice)
check_result = multi_utils.check_tensor_setitem_index(input_slice)
result = None
if check_result:
data_shape = F.shape(data)
indices = mult_util.slice2indices(input_slice, data_shape)
is_tuple_int = mult_util.tuple_element_is_int(input_slice)
indices = multi_utils.slice2indices(input_slice, data_shape)
is_tuple_int = multi_utils.tuple_element_is_int(input_slice)
if is_tuple_int:
indices = mult_util.integer_to_indices(input_slice, data_shape)
indices = multi_utils.integer_to_indices(input_slice, data_shape)
result = _tensor_indices_number(data, data_shape, input_slice, indices, value)
return result
def _tensor_indices_number(data, data_shape, index, indices, value):
"""Assigns a scalar value to the tensor."""
data_size = F.size(data)
data_dtype = F.dtype(data)
indices_size = F.size(indices)
indices_size = mult_util.check_indices(indices_size, index)
update = F.fill(mstype.int32, (indices_size,), 1)
condition_1d = F.scatter_nd(indices, update, (data_size,))
condition = F.reshape(condition_1d, data_shape)
condition = F.cast(condition, mstype.bool_)
value_fill = F.fill(data_dtype, (indices_size,), value)
value_1d = F.scatter_nd(indices, value_fill, (data_size,))
u = F.reshape(value_1d, data_shape)
return F.select(condition, u, data)
@setitem.register("Tensor", "Number", "Number")
def _tensor_setitem_with_int_v1(data, index, value):
"""Syntax: A[1] = 3"""
data_shape = F.shape(data)
indices = mult_util.integer_to_indices(index, data_shape)
indices = multi_utils.integer_to_indices(index, data_shape)
return _tensor_indices_number(data, data_shape, index, indices, value)
......@@ -358,7 +355,7 @@ def _tensor_setitem_with_int_v1(data, index, value):
def _tensor_setitem_with_int_v2(data, index, value):
"""Syntax: A[1] = Tensor"""
data_shape = F.shape(data)
indices = mult_util.integer_to_indices(index, data_shape)
indices = multi_utils.integer_to_indices(index, data_shape)
return _tensor_indices_tensor(data, data_shape, index, indices, value)
......@@ -379,7 +376,7 @@ def _tensor_setitem_with_ellipsis_v2(data, index, value):
data_size = F.size(data)
value_shape = F.shape(value)
value_size = F.size(value)
check_result = mult_util.check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size)
check_result = multi_utils.check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size)
if check_result:
if data_size == value_size:
result = F.reshape(value, data_shape)
......@@ -389,3 +386,108 @@ def _tensor_setitem_with_ellipsis_v2(data, index, value):
param2 = F.cast(value, data_dtype)
result = F.tensor_mul(param1, param2)
return result
def _tensor_assgin_tensor(data, input_slice, value):
"""Assigns a tensor value to the tensor by slice."""
result = None
check_result = multi_utils.check_tensor_setitem_index(input_slice)
if check_result:
data_shape = F.shape(data)
indices = multi_utils.slice2indices(input_slice, data_shape)
is_tuple_int = multi_utils.tuple_element_is_int(input_slice)
if is_tuple_int:
indices = multi_utils.integer_to_indices(input_slice, data_shape)
result = _tensor_indices_tensor(data, data_shape, input_slice, indices, value)
return result
def _tensor_indices_tensor(data, data_shape, index, indices, value):
"""Assigns a tensor value to the tensor."""
data_size = F.size(data)
data_dtype = F.dtype(data)
indices_size = F.size(indices)
indices_size = multi_utils.check_indices(indices_size, index)
update = F.fill(mstype.int32, (indices_size,), 1)
condition_1d = F.scatter_nd(indices, update, (data_size,))
condition = F.reshape(condition_1d, data_shape)
condition = F.cast(condition, mstype.bool_)
value_fill = None
value_size = F.size(value)
value_size = multi_utils.check_indices_value_size(indices_size, value_size)
if value_size == 1:
value_fill = F.fill(data_dtype, (indices_size,), 1)
value = F.cast(value, data_dtype)
value_fill = F.tensor_mul(value_fill, value)
elif value_size > 1:
value_fill = F.reshape(value, (indices_size,))
value_1d = F.scatter_nd(indices, value_fill, (data_size,))
u = F.reshape(value_1d, data_shape)
return F.select(condition, u, data)
def _tensor_indices_number(data, data_shape, index, indices, value):
"""Assigns a scalar value to the tensor."""
data_size = F.size(data)
data_dtype = F.dtype(data)
indices_size = F.size(indices)
indices_size = multi_utils.check_indices(indices_size, index)
update = F.fill(mstype.int32, (indices_size,), 1)
condition_1d = F.scatter_nd(indices, update, (data_size,))
condition = F.reshape(condition_1d, data_shape)
condition = F.cast(condition, mstype.bool_)
value_fill = F.fill(data_dtype, (indices_size,), value)
value_1d = F.scatter_nd(indices, value_fill, (data_size,))
u = F.reshape(value_1d, data_shape)
return F.select(condition, u, data)
def _tensor_setitem_by_tensor_with_tuple(data, index, value):
"""Set a tensor item by a tensor with a tuple."""
updates = multi_utils.generate_updates_from_tuple(data, index, value,
multi_utils.SET_ITEM_BY_ONE_TENSOR)
result = F.scatter_update(data, index, updates)
return result
def _tensor_setitem_by_int_tensor_with_scalar(data, index, value):
"""Set a tensor item by a int tensor with a scalar."""
updates = multi_utils.generate_updates_from_scalar(data, index, value,
multi_utils.SET_ITEM_BY_ONE_TENSOR)
return F.scatter_update(data, index, updates)
def _tensor_setitem_by_bool_tensor_with_scalar(data, index, value):
"""Set a tensor item by a bool tensor with a scalar."""
index_shape = F.shape(index)
shape = F.shape(data)
shape = multi_utils.check_equal(
shape, index_shape, "The tensor(shape={}) and tensor index(shape={}) should be the same shape.")
dtype = F.dtype(data)
u = F.fill(dtype, shape, value)
return F.select(index, u, data)
def _tensor_setitem_by_int_tensor_with_tensor(data, index, value):
"""Set a tensor item by a int tensor with a tensor."""
updates = multi_utils.generate_updates_from_tensor(data, index, value,
multi_utils.SET_ITEM_BY_ONE_TENSOR)
return F.scatter_update(data, index, updates)
def _tensor_setitem_by_bool_tensor_with_tensor(data, index, value):
"""Set a tensor item by a bool tensor with a tensor."""
index_shape = F.shape(index)
data_shape = F.shape(data)
data_shape = multi_utils.check_equal(data_shape, index_shape,
"The tensor(shape={}) and tensor index(shape={}) should be the same shape.")
size = F.size(value)
size = multi_utils.check_equal(1, size,
"When assign value is a tensor, its size should be {}, but current size is {}.")
dtype = F.dtype(data)
u_cast = F.cast(value, dtype)
one_data = F.ones_like(data)
u = F.tensor_mul(one_data, u_cast)
result = F.select(index, u, data)
return result
......@@ -31,6 +31,7 @@ dtype = P.DType()
issubclass_ = P.IsSubClass()
isinstance_ = P.IsInstance()
fill = P.Fill()
tile = P.Tile()
select = P.Select()
size = P.Size()
ones_like = P.OnesLike()
......@@ -70,6 +71,12 @@ scalar_cast = P.ScalarCast()
print_ = P.Print()
expand_dims = P.ExpandDims()
scatter_nd = P.ScatterNd()
gather = P.GatherV2()
gather_nd = P.GatherNd()
scatter_update = P.ScatterUpdate()
scatter_nd_update = P.ScatterNdUpdate()
pack = P.Pack()
tuple_setitem = Primitive('tuple_setitem')
tuple_getitem = Primitive('tuple_getitem')
......
......@@ -24,7 +24,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
Fill, GatherNd, GatherV2, InvertPermutation,
IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike,
Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue,
SameTypeShape, ScatterMax,
SameTypeShape, ScatterMax, ScatterUpdate,
ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
Shape, Size, Slice, Split,
Squeeze, StridedSlice, Tile,
......@@ -193,6 +193,7 @@ __all__ = [
'Pad',
'MirrorPad',
'GatherNd',
'ScatterUpdate',
'ScatterNdUpdate',
'Floor',
'NMSWithMask',
......
......@@ -19,7 +19,7 @@ from ..._c_expression import signature_rw as sig_rw
from ..._c_expression import signature_kind as sig_kind
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
from ..._checkparam import Validator as validator, Rel
from .._utils import _get_concat_offset
from .._utils import get_concat_offset
from ...common import dtype as mstype
......@@ -136,7 +136,7 @@ class ConcatOffset(PrimitiveWithInfer):
axis = self.axis
x_shp = input_x['shape']
x_type = input_x['dtype']
offset, _, axis = _get_concat_offset(x_shp, x_type, axis, self.name)
offset, _, axis = get_concat_offset(x_shp, x_type, axis, self.name)
self.add_prim_attr('T', x_type[0].element_type())
offset_values = []
for i in range(len(x_shp)):
......
......@@ -24,16 +24,15 @@ import itertools
import numbers
import numpy as np
from ..._c_expression import signature_rw as sig_rw
from ..._c_expression import signature_kind as sig_kind
from ..._checkparam import Validator as validator
from ..._checkparam import Rel
from ...common import dtype as mstype
from ...common.tensor import Tensor
from ..operations.math_ops import _infer_shape_reduce
from .._utils import _get_concat_offset
from .._utils import get_concat_offset
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
def _check_infer_attr_reduce(axis, keep_dims, prim_name):
validator.check_value_type('keep_dims', keep_dims, [bool], prim_name)
validator.check_value_type('axis', axis, [int, tuple], prim_name)
......@@ -931,7 +930,7 @@ class InvertPermutation(PrimitiveWithInfer):
z = [x_value[i] for i in range(len(x_value))]
z.sort()
y = [None]*len(x_value)
y = [None] * len(x_value)
for i, value in enumerate(x_value):
validator.check_value_type("input[%d]" % i, value, [int], self.name)
validator.check(f'value', z[i], f'index', i, Rel.EQ, self.name)
......@@ -1111,6 +1110,7 @@ class ArgMinWithValue(PrimitiveWithInfer):
>>> input_x = Tensor(np.random.rand(5))
>>> index, output = P.ArgMinWithValue()(input_x)
"""
@prim_attr_register
def __init__(self, axis=0, keep_dims=False):
"""init ArgMinWithValue"""
......@@ -1352,7 +1352,7 @@ class Concat(PrimitiveWithInfer):
axis = self.axis
x_shp = input_x['shape']
x_type = input_x['dtype']
_, all_shp, _ = _get_concat_offset(x_shp, x_type, axis, self.name)
_, all_shp, _ = get_concat_offset(x_shp, x_type, axis, self.name)
self.add_prim_attr('T', x_type[0].element_type())
self.add_prim_attr('inputNums', len(x_shp))
ret_shp = x_shp[0].copy()
......@@ -1376,15 +1376,13 @@ def _get_pack_shape(x_shape, x_type, axis, prim_name):
if axis < 0:
axis = axis + rank_base + 1
for i in range(1, N):
v = x_shape[i]
validator.check('len of x_shape[%d]' % i, len(v), 'len of rank_base', rank_base, Rel.EQ, prim_name)
validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0], Rel.EQ, prim_name, TypeError)
for j in range(rank_base):
if v[j] != x_shape[0][j]:
raise ValueError(f"For \'{prim_name}\' element {i} shape in input can not pack with first element")
if x_shape[i] != x_shape[0]:
raise ValueError(f"For \'{prim_name}\' element {i} shape in input can not pack with first element")
out_shape.insert(axis, N)
return out_shape
class Pack(PrimitiveWithInfer):
r"""
Packs a list of tensors in specified axis.
......@@ -1831,7 +1829,7 @@ class DiagPart(PrimitiveWithInfer):
return x_type
def infer_shape(self, x_shape):
if len(x_shape)%2 != 0 or \
if len(x_shape) % 2 != 0 or \
not x_shape:
raise ValueError(f"For \'{self.name}\' input rank must be non-zero and even, but got rank {len(x_shape)}, "
f"with shapes {x_shape}")
......@@ -2004,6 +2002,49 @@ class GatherNd(PrimitiveWithInfer):
return x_dtype
class ScatterUpdate(PrimitiveWithInfer):
"""
Update tensor value by using input indices and value.
Using given values to update tensor value, along with the input indices.
Args:
use_locking (bool): Whether protect the assignment by a lock. Default: True.
Inputs:
- **input_x** (Parameter) - The target tensor, with data type of Parameter.
- **indices** (Tensor) - The index of input tensor.
- **update** (Tensor) - The tensor to update the input tensor, has the same type as input,
and update.shape = indices.shape + input_x.shape[1:].
Outputs:
Tensor, has the same shape and type as `input_x`.
Examples:
>>> input_x = mindspore.Parameter(Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32))
>>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
>>> update = Tensor(np.array([1.0, 2.2]), mindspore.float32)
>>> op = P.ScatterNdUpdate()
>>> output = op(input_x, indices, update)
"""
@prim_attr_register
def __init__(self, use_locking=True):
"""Init ScatterNdUpdate"""
self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y'])
def infer_shape(self, x_shape, indices_shape, value_shape):
if indices_shape + x_shape[1:] != value_shape:
raise ValueError('Input value are not match with input indices.')
return x_shape
def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
validator.check_tensor_type_same({'indices': indices_dtype}, mstype.int_type, self.name)
args = {"x": x_dtype, "value": value_dtype}
validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
return x_dtype
class ScatterNdUpdate(PrimitiveWithInfer):
"""
Update tensor value by using input indices and value.
......@@ -2028,11 +2069,6 @@ class ScatterNdUpdate(PrimitiveWithInfer):
>>> op = P.ScatterNdUpdate()
>>> output = op(input_x, indices, update)
"""
__mindspore_signature__ = (
('input_x', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD),
('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD),
('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD)
)
@prim_attr_register
def __init__(self, use_locking=True):
......@@ -2142,10 +2178,10 @@ class SpaceToDepth(PrimitiveWithInfer):
validator.check('x dimension', len(x_shape), '', 4, Rel.EQ)
out_shape = copy.deepcopy(x_shape)
for i in range(2):
if out_shape[i+2] % self.block_size != 0:
raise ValueError(f'For \'{self.name}\' input shape[{i+2}] {out_shape[i+2]} should be '
if out_shape[i + 2] % self.block_size != 0:
raise ValueError(f'For \'{self.name}\' input shape[{i + 2}] {out_shape[i + 2]} should be '
f'fully divided by block_size {self.block_size}')
out_shape[i+2] //= self.block_size
out_shape[i + 2] //= self.block_size
out_shape[1] *= self.block_size * self.block_size
return out_shape
......@@ -2199,9 +2235,10 @@ class DepthToSpace(PrimitiveWithInfer):
validator.check('x dimension', len(x_shape), '', 4, Rel.EQ)
out_shape = copy.deepcopy(x_shape)
for i in range(2):
out_shape[i+2] *= self.block_size
out_shape[i + 2] *= self.block_size
validator.check_integer('x_shape[1] % (block_size*block_size)', x_shape[1] % (self.block_size*self.block_size),
validator.check_integer('x_shape[1] % (block_size*block_size)',
x_shape[1] % (self.block_size * self.block_size),
0, Rel.EQ, self.name)
out_shape[1] //= self.block_size * self.block_size
return out_shape
......@@ -2251,6 +2288,7 @@ class SpaceToBatch(PrimitiveWithInfer):
[[[[1.]]], [[[2.]]], [[[3.]]], [[[4.]]]]
"""
@prim_attr_register
def __init__(self, block_size, paddings):
"""Init SpaceToBatch"""
......@@ -2271,12 +2309,12 @@ class SpaceToBatch(PrimitiveWithInfer):
validator.check_integer('rank of input_x', len(x_shape), 4, Rel.EQ, self.name)
out_shape = copy.deepcopy(x_shape)
for i in range(2):
padded = out_shape[i+2] + self.paddings[i][0] + \
padded = out_shape[i + 2] + self.paddings[i][0] + \
self.paddings[i][1]
if padded % self.block_size != 0:
raise ValueError(f'For \'{self.name}\' padded[{i}] {padded} should be divisible by '
f'block_size {self.block_size}')
out_shape[i+2] = padded // self.block_size
out_shape[i + 2] = padded // self.block_size
out_shape[0] *= self.block_size * self.block_size
return out_shape
......@@ -2319,6 +2357,7 @@ class BatchToSpace(PrimitiveWithInfer):
[[[[1., 2.], [3., 4.]]]]
"""
@prim_attr_register
def __init__(self, block_size, crops):
"""Init BatchToSpace"""
......@@ -2339,10 +2378,10 @@ class BatchToSpace(PrimitiveWithInfer):
validator.check('rank of input_x', len(x_shape), '', 4)
out_shape = copy.deepcopy(x_shape)
for i in range(2):
x_block_prod = out_shape[i+2] * self.block_size
x_block_prod = out_shape[i + 2] * self.block_size
crops_sum = self.crops[i][0] + self.crops[i][1]
validator.check("x block shape prod", x_block_prod, 'crops sum', crops_sum, Rel.GT, self.name)
out_shape[i+2] = x_block_prod - crops_sum
out_shape[i + 2] = x_block_prod - crops_sum
block_size_prod = self.block_size * self.block_size
if out_shape[0] % block_size_prod != 0:
raise ValueError(f'For \'{self.name}\' input_x dimension 0 {out_shape[0]} should be divisible by '
......
......@@ -24,7 +24,7 @@ from ..._checkparam import Validator as validator
from ..._checkparam import Rel
from ...common import dtype as mstype
from ...common.tensor import Tensor
from .._utils import _get_broadcast_shape
from .._utils import get_broadcast_shape
from ..primitive import PrimitiveWithInfer, prim_attr_register, _run_op
......@@ -75,7 +75,7 @@ class _BinaryOp(PrimitiveWithInfer):
self.init_prim_io_names(inputs=['x', 'y'], outputs=['output'])
def infer_shape(self, x_shape, y_shape):
return _get_broadcast_shape(x_shape, y_shape, self.name)
return get_broadcast_shape(x_shape, y_shape, self.name)
class _MathBinaryOp(_BinaryOp):
......
......@@ -27,9 +27,15 @@ class IdentityEC(IExectorComponent):
def __call__(self):
result_id = self.function[keyword.id] + '-' + self.inputs[keyword.id]
group = self.function[keyword.group] + '-' + self.inputs[keyword.group]
return {
ret = {
keyword.id: result_id,
keyword.group: group,
keyword.desc_inputs: self.inputs[keyword.desc_inputs],
keyword.result: self.function[keyword.block](*self.inputs[keyword.desc_inputs])
}
print("buxue------------------------------------------------")
print("inputs")
print(ret[keyword.desc_inputs])
print("outputs")
print(ret[keyword.result])
return ret
......@@ -1307,7 +1307,7 @@ raise_set = [
('ScatterNdUpdate', {
'block': (P.ScatterNdUpdate(), {'exception': TypeError}),
'desc_inputs': (Tensor(np.ones((2, 3), np.float32)),
Tensor(np.ones((2, 2), np.int32)),
Tensor(np.ones((2, 2), np.float32)),
Tensor(np.ones((2,), np.float32))),
'desc_bprop': [[2, 3]]}),
('Pack', {
......
......@@ -16,13 +16,14 @@
import numpy as np
import pytest
from mindspore import Tensor
from mindspore import Tensor, Parameter
from mindspore import context
from mindspore import dtype as mstype
from mindspore.nn import Cell
from ....mindspore_test_framework.mindspore_test import mindspore_test
from ....mindspore_test_framework.pipeline.forward.compile_forward \
import pipeline_for_compile_forward_ge_graph_for_case_by_case_config
import pipeline_for_compile_forward_ge_graph_for_case_by_case_config, \
pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception
class NetWorkSlicePositive(Cell):
......@@ -145,6 +146,160 @@ class TensorAssignWithSlice(Cell):
return z
class TensorIndexByOneTensor(Cell):
def __init__(self):
super(TensorIndexByOneTensor, self).__init__()
self.const = Tensor(np.ones((5, 4, 7, 8)), mstype.int32)
def construct(self, x, index):
ret = x[index] + self.const
return ret
class TensorIndexByTwoTensors(Cell):
def __init__(self):
super(TensorIndexByTwoTensors, self).__init__()
self.const = Tensor(np.ones((3, 4, 5, 8)), mstype.int32)
def construct(self, x, index_0, index_1):
ret = x[index_0, index_1] + self.const
return ret
class TensorIndexByThreeTensors(Cell):
def __init__(self):
super(TensorIndexByThreeTensors, self).__init__()
self.const = Tensor(np.ones((5, 3, 4, 5)), mstype.int32)
def construct(self, x, index_0, index_1, index_2):
ret = x[index_0, index_1, index_2] + self.const
return ret
class TensorSetItemByOneTensorWithNumber(Cell):
def __init__(self, value):
super(TensorSetItemByOneTensorWithNumber, self).__init__()
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
self.value = value
def construct(self, index):
self.param[index] = self.value
ret = self.param + self.const
return ret
class TensorSetItemByOneTensorWithTensor(Cell):
def __init__(self):
super(TensorSetItemByOneTensorWithTensor, self).__init__()
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
def construct(self, index, value):
self.param[index] = value
ret = self.param + self.const
return ret
class TensorSetItemByOneTensorWithTupleOfNumber(Cell):
def __init__(self, value):
super(TensorSetItemByOneTensorWithTupleOfNumber, self).__init__()
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
self.value = value
def construct(self, index):
self.param[index] = self.value
ret = self.param + self.const
return ret
class TensorSetItemByOneTensorWithTupleOfTensor(Cell):
def __init__(self):
super(TensorSetItemByOneTensorWithTupleOfTensor, self).__init__()
self.const = Tensor(np.ones((6, 3, 8)), mstype.float32)
self.param = Parameter(Tensor(np.arange(6*3*8).reshape((6, 3, 8)), mstype.float32), name="x")
def construct(self, index, value_0, value_1, value_2):
self.param[index] = (value_0, value_1, value_2)
ret = self.param + self.const
return ret
class TensorSetItemByTensorsWithNumber(Cell):
def __init__(self, value):
super(TensorSetItemByTensorsWithNumber, self).__init__()
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
self.value = value
def construct(self, index_0, index_1, index_2):
self.param[index_0, index_1, index_2] = self.value
ret = self.param + self.const
return ret
class TensorSetItemByTensorsWithTensor(Cell):
def __init__(self):
super(TensorSetItemByTensorsWithTensor, self).__init__()
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
def construct(self, index_0, index_1, index_2, value):
self.param[index_0, index_1, index_2] = value
ret = self.param + self.const
return ret
class TensorSetItemByTensorsWithTensorNumberError(Cell):
def __init__(self):
super(TensorSetItemByTensorsWithTensorNumberError, self).__init__()
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
def construct(self, index_0, index_1, index_2, index_3, value):
self.param[index_0, index_1, index_2, index_3] = value
ret = self.param + self.const
return ret
class TensorSetItemByTensorsWithTupleOfNumber(Cell):
def __init__(self, value):
super(TensorSetItemByTensorsWithTupleOfNumber, self).__init__()
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
self.value = value
def construct(self, index_0, index_1, index_2):
self.param[index_0, index_1, index_2] = self.value
ret = self.param + self.const
return ret
class TensorSetItemByTensorsWithTupleOfTensor(Cell):
def __init__(self):
super(TensorSetItemByTensorsWithTupleOfTensor, self).__init__()
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
def construct(self, index_0, index_1, index_2, value_0, value_1, value_2):
self.param[index_0, index_1, index_2] = (value_0, value_1, value_2)
ret = self.param + self.const
return ret
class TensorSetItemByTensorsWithTupleOfTensorNumberError(Cell):
def __init__(self):
super(TensorSetItemByTensorsWithTupleOfTensorNumberError, self).__init__()
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
def construct(self, index_0, index_1, index_2, value_0, value_1):
self.param[index_0, index_1, index_2] = (value_0, value_1)
ret = self.param + self.const
return ret
def test_tensor_assign():
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
net = TensorAssignWithSlice()
......@@ -441,15 +596,206 @@ test_cases = [
'block': NetWorkSliceEllipsis(),
'desc_inputs': [Tensor(np.ones([6, 7, 8, 9], np.int32))],
}),
('TensorIndexByOneTensor', {
'block': TensorIndexByOneTensor(),
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
Tensor(np.random.randint(6, size=(5, 4)), mstype.int32)],
}),
('TensorIndexByTwoTensors', {
'block': TensorIndexByTwoTensors(),
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32)],
}),
('TensorIndexByThreeTensors', {
'block': TensorIndexByThreeTensors(),
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
}),
('TensorSetItemByOneTensorWithNumber', {
'block': TensorSetItemByOneTensorWithNumber(value=0.0),
'desc_inputs': [Tensor(np.random.randint(4, size=(5, 4)), mstype.int32)],
}),
('TensorSetItemByOneTensorWithTensor', {
'block': TensorSetItemByOneTensorWithTensor(),
'desc_inputs': [Tensor(np.random.randint(3, size=(5, 4)), mstype.int32),
Tensor(np.zeros((4, 7, 8)), mstype.float32)],
}),
('TensorSetItemByOneTensorWithTupleOfNumber', {
'block': TensorSetItemByOneTensorWithTupleOfNumber(value=(0.0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7)),
'desc_inputs': [Tensor(np.random.randint(5, size=(5, 4)), mstype.int32)],
}),
('TensorSetItemByOneTensorWithTupleOfTensor', {
'block': TensorSetItemByOneTensorWithTupleOfTensor(),
'desc_inputs': [Tensor(np.random.randint(6, size=(5, 4)), mstype.int32),
Tensor(np.zeros((8,), np.float32)),
Tensor(np.ones((8,), np.float32)),
Tensor(np.ones((8,), np.float32) * 2)],
}),
('TensorSetItemByTensorsWithNumber', {
'block': TensorSetItemByTensorsWithNumber(value=0.0),
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
}),
('TensorSetItemByTensorsWithTensor', {
'block': TensorSetItemByTensorsWithTensor(),
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
Tensor(np.zeros((4, 5)), mstype.float32)],
}),
('TensorSetItemByTensorsWithTupleOfNumber', {
'block': TensorSetItemByTensorsWithTupleOfNumber(value=(0.0, 1.1, 2.2, 3.3, 4.4)),
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
}),
('TensorSetItemByTensorsWithTupleOfTensor', {
'block': TensorSetItemByTensorsWithTupleOfTensor(),
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
Tensor(np.zeros((4, 5)), mstype.float32),
Tensor(np.ones((4, 5)), mstype.float32),
Tensor(np.ones((4, 5)) * 2, mstype.float32)],
})
]
raise_error_set = [
('TensorIndexByOneTensorDtypeError', {
'block': (TensorIndexByOneTensor(), {'exception': TypeError}),
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
Tensor(np.random.randint(6, size=(5, 4)), mstype.int8)],
}),
('TensorIndexByTwoTensorsShapeError', {
'block': (TensorIndexByTwoTensors(), {'exception': ValueError}),
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(2, 3, 5)), mstype.int32)],
}),
('TensorIndexByTwoTensorsDtypeError', {
'block': (TensorIndexByTwoTensors(), {'exception': TypeError}),
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(4, 5)), mstype.float32)],
}),
('TensorIndexByThreeTensorsShapeError', {
'block': (TensorIndexByThreeTensors(), {'exception': ValueError}),
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(8, size=(5, 2, 4, 5)), mstype.int32)],
}),
('TensorIndexByThreeTensorsDtypeError', {
'block': (TensorIndexByThreeTensors(), {'exception': TypeError}),
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int64),
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
}),
('TensorSetItemByOneTensorWithNumberTypeError', {
'block': (TensorSetItemByOneTensorWithNumber(value=0), {'exception': TypeError}),
'desc_inputs': [Tensor(np.random.randint(4, size=(5, 4)), mstype.int32)],
}),
('TensorSetItemByOneTensorWithTensorShapeError', {
'block': (TensorSetItemByOneTensorWithTensor(), {'exception': ValueError}),
'desc_inputs': [Tensor(np.random.randint(3, size=(5, 4)), mstype.int32),
Tensor(np.zeros((6, 7, 8)), mstype.float32)],
}),
('TensorSetItemByOneTensorWithTensorDtypeError', {
'block': (TensorSetItemByOneTensorWithTensor(), {'exception': TypeError}),
'desc_inputs': [Tensor(np.random.randint(3, size=(5, 4)), mstype.int32),
Tensor(np.zeros((6, 7, 8)), mstype.int32)],
}),
('TensorSetItemByOneTensorWithTupleOfNumberTypeError', {
'block': (TensorSetItemByOneTensorWithTupleOfNumber(value=(0, 1, 2, 3, 4, 5, 6, 7)), {'exception': TypeError}),
'desc_inputs': [Tensor(np.random.randint(5, size=(5, 4)), mstype.int32)],
}),
('TensorSetItemByOneTensorWithTupleOfNumberNumberError', {
'block': (TensorSetItemByOneTensorWithTupleOfNumber(value=(0.0, 1.1, 2.2)), {'exception': ValueError}),
'desc_inputs': [Tensor(np.random.randint(5, size=(5, 4)), mstype.int32)],
}),
('TensorSetItemByOneTensorWithTupleOfTensorDtyeError', {
'block': (TensorSetItemByOneTensorWithTupleOfTensor(), {'exception': TypeError}),
'desc_inputs': [Tensor(np.random.randint(6, size=(5, 4)), mstype.int32),
Tensor(np.zeros((8,), np.int32)),
Tensor(np.ones((8,), np.int32)),
Tensor(np.ones((8,), np.float32) * 2)],
}),
('TensorSetItemByTensorsWithNumberTypeError', {
'block': (TensorSetItemByTensorsWithNumber(value=0), {'exception': TypeError}),
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
}),
('TensorSetItemByTensorsWithTensorShapeError', {
'block': (TensorSetItemByTensorsWithTensor(), {'exception': ValueError}),
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
Tensor(np.zeros((2, 5)), mstype.float32)],
}),
('TensorSetItemByTensorsWithTensorTypeError', {
'block': (TensorSetItemByTensorsWithTensor(), {'exception': TypeError}),
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
Tensor(np.zeros((4, 5)), mstype.int32)],
}),
('TensorSetItemByTensorsWithTensorNumberError', {
'block': (TensorSetItemByTensorsWithTensorNumberError(), {'exception': IndexError}),
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
Tensor(np.random.randint(8, size=(1, 3, 4, 5)), mstype.int32),
Tensor(np.zeros((2, 5)), mstype.float32)],
}),
('TensorSetItemByTensorsWithTupleOfNumberTypeError', {
'block': (TensorSetItemByTensorsWithTupleOfNumber(value=(0, 1, 2, 3, 4)), {'exception': TypeError}),
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
}),
('TensorSetItemByTensorsWithTupleOfNumberNumberError', {
'block': (TensorSetItemByTensorsWithTupleOfNumber(value=(0.0, 1.0, 2.0, 3.0)), {'exception': ValueError}),
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
}),
('TensorSetItemByTensorsWithTupleOfTensorNumberError', {
'block': (TensorSetItemByTensorsWithTupleOfTensorNumberError(), {'exception': ValueError}),
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
Tensor(np.zeros((4, 5)), mstype.float32),
Tensor(np.ones((4, 5)), mstype.float32)],
}),
('TensorSetItemByTensorsWithTupleOfTensorTypeError', {
'block': (TensorSetItemByTensorsWithTupleOfTensor(), {'exception': TypeError}),
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
Tensor(np.zeros((4, 5)), mstype.float32),
Tensor(np.ones((4, 5)), mstype.int32),
Tensor(np.ones((4, 5)) * 2, mstype.int32)],
})
]
@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config)
def test_compile():
context.set_context(mode=context.GRAPH_MODE)
def test_exec():
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
return test_cases
@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception)
def test_check_exception():
return raise_error_set
def test_tensor_slice_reduce_out_of_bounds_neg():
class NetWork(Cell):
def __init__(self):
......
......@@ -26,7 +26,7 @@ from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.ops._grad.grad_base import bprop_getters
from mindspore.ops._grad.grad_math_ops import binop_grad_common
from mindspore.ops._utils import _get_broadcast_shape
from mindspore.ops._utils import get_broadcast_shape
from mindspore.ops.primitive import PrimitiveWithInfer, prim_attr_register
from mindspore.train.loss_scale_manager import DynamicLossScaleManager
......@@ -54,7 +54,7 @@ class MockSub(PrimitiveWithInfer):
self.init_prim_io_names(inputs=['x', 'y'], outputs=['output'])
def infer_shape(self, x_shape, y_shape):
return _get_broadcast_shape(x_shape, y_shape)
return get_broadcast_shape(x_shape, y_shape)
def infer_dtype(self, x_dtype, y_dtype):
return x_dtype
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册