提交 37cc6e26 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!3175 add ScatterNdAdd ScatterNdSub ScatterNonAliasingAdd ops

Merge pull request !3175 from fangzehua/scatter_add_vm
......@@ -37,6 +37,7 @@ static std::map<string, string> tbe_func_adapter_map = {
{"re_lu6", "relu6"},
{"re_lu6_grad", "relu6_grad"},
{"re_lu", "relu"},
{"reverse_v2", "reverse_v2_d"},
{"re_luv2", "relu_v2"},
{"p_re_lu", "prelu"},
{"p_re_lu_grad", "prelu_grad"},
......
......@@ -377,6 +377,18 @@ def get_bprop_pack(self):
return bprop
@bprop_getters.register(P.ReverseV2)
def get_bprop_reverse_v2(self):
"""Generate bprop for ReverseV2"""
axis = self.axis
def bprop(x, out, dout):
reverse_grad = P.ReverseV2(axis)
dx = reverse_grad(dout)
return (dx,)
return bprop
@bprop_getters.register(P.Unpack)
def get_bprop_unpack(self):
"""Generate bprop for Unpack"""
......@@ -495,6 +507,16 @@ def get_bprop_scatter_nd_update(self):
return bprop
@bprop_getters.register(P.ScatterNonAliasingAdd)
def get_bprop_scatter_non_aliasing_add_update(self):
"""Generate bprop for ScatterNonAliasingAdd"""
op = P.GatherNd()
def bprop(x, indices, update, out, dout):
return dout, zeros_like(indices), op(dout, indices)
return bprop
@bprop_getters.register(P.TensorScatterUpdate)
def get_bprop_tensor_scatter_update(self):
"""Generate bprop for TensorScatterUpdate"""
......@@ -509,6 +531,7 @@ def get_bprop_tensor_scatter_update(self):
return bprop
@bprop_getters.register(P.ScatterMax)
def get_bprop_scatter_max(self):
"""Generate bprop for ScatterMax"""
......
......@@ -81,6 +81,9 @@ from .sub import _sub_tbe
from .reduce_mean_d import _reduce_mean_d_tbe
from .scatter_nd import _scatter_nd_tbe
from .scatter_nd_d import _scatter_nd_d_tbe
from .scatter_nd_add import _scatter_nd_add_tbe
from .scatter_nd_sub import _scatter_nd_sub_tbe
from .scatter_non_aliasing_add import _scatter_non_aliasing_add_tbe
from .reduce_mean import _reduce_mean_tbe
from .tile import _tile_tbe
from .atomic_addr_clean import _atomic_addr_clean_tbe
......@@ -93,6 +96,8 @@ from .bn_training_update_grad import _bn_training_update_grad_tbe
from .bn_infer import _bn_infer_tbe
from .bn_infer_grad import _bn_infer_grad_tbe
from .reciprocal import _reciprocal_tbe
from .reverse_v2_d import _reverse_v2_d_tbe
from .rint import _rint_tbe
from .strided_slice_d import _strided_slice_d_tbe
from .strided_slice_grad_d import _strided_slice_grad_d_tbe
from .split_d import _split_d_tbe
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ReverseV2D op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
reverse_v2_d_op_info = TBERegOp("ReverseV2") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("reverse_v2_d.so") \
.compute_cost(10) \
.kernel_name("reverse_v2_d") \
.partial_flag(True) \
.op_pattern("dynamicFormat") \
.attr("axis", "required", "listInt", "all") \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.None_None, DataType.None_None) \
.get_op_info()
@op_info_register(reverse_v2_d_op_info)
def _reverse_v2_d_tbe():
"""ReverseV2D TBE register"""
return
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Rint op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
rint_op_info = TBERegOp("Rint") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("rint.so") \
.compute_cost(10) \
.kernel_name("rint") \
.partial_flag(True) \
.op_pattern("formatAgnostic") \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_None, DataType.F16_None) \
.dtype_format(DataType.F32_None, DataType.F32_None) \
.get_op_info()
@op_info_register(rint_op_info)
def _rint_tbe():
"""Rint TBE register"""
return
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ScatterNdAdd op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
scatter_nd_add_op_info = TBERegOp("ScatterNdAdd") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("scatter_nd_add.so") \
.compute_cost(10) \
.kernel_name("scatter_nd_add") \
.partial_flag(True) \
.attr("use_locking", "optional", "bool", "all") \
.input(0, "var", False, "required", "all") \
.input(1, "indices", False, "required", "all") \
.input(2, "updates", False, "required", "all") \
.output(0, "var", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.get_op_info()
@op_info_register(scatter_nd_add_op_info)
def _scatter_nd_add_tbe():
"""ScatterNdAdd TBE register"""
return
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ScatterNdSub op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
scatter_nd_sub_op_info = TBERegOp("ScatterNdSub") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("scatter_nd_sub.so") \
.compute_cost(10) \
.kernel_name("scatter_nd_sub") \
.partial_flag(True) \
.attr("use_locking", "optional", "bool", "all") \
.input(0, "var", False, "required", "all") \
.input(1, "indices", False, "required", "all") \
.input(2, "updates", False, "required", "all") \
.output(0, "var", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.get_op_info()
@op_info_register(scatter_nd_sub_op_info)
def _scatter_nd_sub_tbe():
"""ScatterNdSub TBE register"""
return
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ScatterNonAliasingAdd op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
scatter_non_aliasing_add_op_info = TBERegOp("ScatterNonAliasingAdd") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("scatter_non_aliasing_add.so") \
.compute_cost(10) \
.kernel_name("scatter_non_aliasing_add") \
.partial_flag(True) \
.input(0, "x", False, "required", "all") \
.input(1, "indices", False, "required", "all") \
.input(2, "updates", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.get_op_info()
@op_info_register(scatter_non_aliasing_add_op_info)
def _scatter_non_aliasing_add_tbe():
"""ScatterNonAliasingAdd TBE register"""
return
......@@ -28,6 +28,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
SameTypeShape, ScatterAdd, ScatterSub, ScatterMul, ScatterDiv, ScatterMax, ScatterMin,
ScatterUpdate, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
Shape, Size, Slice, Split, TransShape, ParallelConcat,
ScatterNdAdd, ScatterNdSub, ScatterNonAliasingAdd, ReverseV2, Rint,
Squeeze, StridedSlice, Tile, TensorScatterUpdate,
Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentProd,
UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,
......@@ -233,6 +234,11 @@ __all__ = [
'ScatterNd',
'ScatterMax',
'ScatterMin',
'ScatterNdAdd',
'ScatterNdSub',
'ScatterNonAliasingAdd',
'ReverseV2',
'Rint',
'ResizeNearestNeighbor',
'HistogramFixedWidth',
'Pad',
......
......@@ -47,8 +47,8 @@ class _ScatterOp(PrimitiveWithInfer):
('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1),
('updates', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
)
@staticmethod
def _check_scatter_shape(x_shape, indices_shape, updates_shape, prim_name):
def _check_scatter_shape(self, x_shape, indices_shape, updates_shape, prim_name):
if updates_shape and updates_shape != indices_shape + x_shape[1:]:
raise ValueError(f"For '{prim_name}', the shape of updates should be [] or "
f"updates_shape = indices_shape + x_shape[1:], but got x_shape: {x_shape}, "
......@@ -61,7 +61,7 @@ class _ScatterOp(PrimitiveWithInfer):
self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y'])
def infer_shape(self, x_shape, indices_shape, updates_shape):
_ScatterOp._check_scatter_shape(x_shape, indices_shape, updates_shape, self.name)
self._check_scatter_shape(x_shape, indices_shape, updates_shape, self.name)
return x_shape
def infer_dtype(self, x_dtype, indices_dtype, updates_dtype):
......@@ -71,6 +71,19 @@ class _ScatterOp(PrimitiveWithInfer):
return x_dtype
class _ScatterNdOp(_ScatterOp):
"""
Define _ScatterNd operators
"""
def _check_scatter_shape(self, x_shape, indices_shape, updates_shape, prim_name):
validator.check('the dimension of x', len(x_shape),
'the dimension of indices', indices_shape[-1], Rel.GE)
if indices_shape[:-1] + x_shape[indices_shape[-1]:] != updates_shape:
raise ValueError(f"For '{prim_name}', the shape of updates should be [] or updates_shape = "
f"indices_shape[:-1] + x_shape[indices_shape[-1]:], but got x_shape: {x_shape}, "
f"indices_shape: {indices_shape}, updates_shape: {updates_shape}.")
def _check_infer_attr_reduce(axis, keep_dims, prim_name):
validator.check_value_type('keep_dims', keep_dims, [bool], prim_name)
validator.check_value_type('axis', axis, [int, tuple], prim_name)
......@@ -1759,6 +1772,75 @@ class Slice(PrimitiveWithInfer):
'value': None}
class ReverseV2(PrimitiveWithInfer):
"""
Reverse specific dimensions of a tensor.
Args:
axis (Union[tuple(int), list(int)): The indices of the dimensions to reverse.
Inputs:
- **input_x** (Tensor) - The target tensor.
Outputs:
Tensor, has the same shape and type as `input_x`.
Examples:
>>> input_x = Tensor(np.array([[1, 2, 3, 4], [5, 6, 7, 8]]), mindspore.int32)
>>> op = P.ReverseV2(axis=[1])
>>> output = op(input_x)
[[4, 3, 2, 1], [8, 7, 6, 5]]
"""
@prim_attr_register
def __init__(self, axis):
validator.check_value_type('axis', axis, [list, tuple], self.name)
for i, each in enumerate(axis):
validator.check_value_type(f'axis[{i}]', each, [int], self.name)
self.axis = axis
self.init_prim_io_names(inputs=['x'], outputs=['output'])
def infer_shape(self, x_shape):
dim = len(x_shape)
for i, each in enumerate(self.axis):
validator.check_int_range(f'axis[{i}]', each, -dim, dim, Rel.INC_LEFT, self.name)
return x_shape
def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, (mstype.bool_,) + mstype.number_type, self.name)
return x_dtype
class Rint(PrimitiveWithInfer):
"""
Return element-wise integer closest to x.
Inputs:
- **input_x** (Tensor) - The target tensor, which must be one of the following types:
float16, float32.
Outputs:
Tensor, has the same shape and type as `input_x`.
Examples:
>>> input_x = Tensor(np.array([-1.6, -0.1, 1.5, 2.0]), mindspore.float32)
>>> op = P.Rint()
>>> output = op(input_x)
[-2., 0., 2., 2.]
"""
@prim_attr_register
def __init__(self):
self.init_prim_io_names(inputs=['x'], outputs=['output'])
def infer_shape(self, x_shape):
return x_shape
def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, [mstype.float16, mstype.float32], self.name)
return x_dtype
class Select(PrimitiveWithInfer):
r"""
......@@ -2404,7 +2486,7 @@ class ScatterUpdate(_ScatterOp):
return x_dtype
class ScatterNdUpdate(PrimitiveWithInfer):
class ScatterNdUpdate(_ScatterNdOp):
"""
Update tensor value by using input indices and value.
......@@ -2429,11 +2511,7 @@ class ScatterNdUpdate(PrimitiveWithInfer):
>>> op = P.ScatterNdUpdate()
>>> output = op(input_x, indices, update)
"""
__mindspore_signature__ = (
('x', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1),
('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
)
@prim_attr_register
def __init__(self, use_locking=True):
......@@ -2441,13 +2519,6 @@ class ScatterNdUpdate(PrimitiveWithInfer):
validator.check_value_type('use_locking', use_locking, [bool], self.name)
self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y'])
def infer_shape(self, x_shape, indices_shape, value_shape):
validator.check('the dimension of x', len(x_shape),
'the dimension of indices', indices_shape[-1], Rel.GE)
if indices_shape[:-1] + x_shape[indices_shape[-1]:] != value_shape:
raise ValueError("For 'ScatterNdUpdate', input value are not match with input indices.")
return x_shape
def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name)
args = {"x": x_dtype, "value": value_dtype}
......@@ -2635,6 +2706,101 @@ class ScatterDiv(_ScatterOp):
"""
class ScatterNdAdd(_ScatterNdOp):
"""
Applies sparse addition to individual values or slices in a Tensor.
Using given values to update tensor value through the add operation, along with the input indices.
This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value.
Args:
use_locking (bool): Whether protect the assignment by a lock. Default: False.
Inputs:
- **input_x** (Parameter) - The target parameter.
- **indices** (Tensor) - The index to do add operation whose data type should be mindspore.int32.
- **updates** (Tensor) - The tensor doing the add operation with `input_x`,
the data type is same as `input_x`, the shape is `indices_shape[:-1] + x_shape[indices_shape[-1]:]`.
Outputs:
Parameter, the updated `input_x`.
Examples:
>>> input_x = Parameter(Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8]), mindspore.float32), name="x")
>>> indices = Tensor(np.array([[2], [4], [1], [7]]), mindspore.int32)
>>> updates = Tensor(np.array([6, 7, 8, 9]), mindspore.float32)
>>> scatter_nd_add = P.ScatterNdAdd()
>>> output = scatter_nd_add(input_x, indices, updates)
[1, 10, 9, 4, 12, 6, 7, 17]
"""
class ScatterNdSub(_ScatterNdOp):
"""
Applies sparse subtraction to individual values or slices in a Tensor.
Using given values to update tensor value through the sub operation, along with the input indices.
This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value.
Args:
use_locking (bool): Whether protect the assignment by a lock. Default: False.
Inputs:
- **input_x** (Parameter) - The target parameter.
- **indices** (Tensor) - The index to do add operation whose data type should be mindspore.int32.
- **updates** (Tensor) - The tensor doing the sub operation with `input_x`,
the data type is same as `input_x`, the shape is `indices_shape[:-1] + x_shape[indices_shape[-1]:]`.
Outputs:
Parameter, the updated `input_x`.
Examples:
>>> input_x = Parameter(Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8]), mindspore.float32), name="x")
>>> indices = Tensor(np.array([[2], [4], [1], [7]]), mindspore.int32)
>>> updates = Tensor(np.array([6, 7, 8, 9]), mindspore.float32)
>>> scatter_nd_sub = P.ScatterNdSub()
>>> output = scatter_nd_sub(input_x, indices, updates)
[1, -6, -3, 4, -2, 6, 7, -1]
"""
class ScatterNonAliasingAdd(_ScatterNdOp):
"""
Applies sparse addition to input using individual values or slices.
Using given values to update tensor value through the add operation, along with the input indices.
This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value.
Inputs:
- **input_x** (Parameter) - The target parameter.
- **indices** (Tensor) - The index to do add operation whose data type should be mindspore.int32.
- **updates** (Tensor) - The tensor doing the add operation with `input_x`,
the data type is same as `input_x`, the shape is `indices_shape[:-1] + x_shape[indices_shape[-1]:]`.
Outputs:
Parameter, the updated `input_x`.
Examples:
>>> input_x = Parameter(Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8]), mindspore.float32), name="x")
>>> indices = Tensor(np.array([[2], [4], [1], [7]]), mindspore.int32)
>>> updates = Tensor(np.array([6, 7, 8, 9]), mindspore.float32)
>>> scatter_non_aliasing_add = P.ScatterNonAliasingAdd()
>>> output = scatter_non_aliasing_add(input_x, indices, updates)
[1, 10, 9, 4, 12, 6, 7, 17]
"""
@prim_attr_register
def __init__(self):
"""Init ScatterNonAliasingAdd"""
self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y'])
def infer_dtype(self, x_dtype, indices_dtype, updates_dtype):
validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name)
args = {"x": x_dtype, "updates": updates_dtype}
validator.check_tensor_type_same(args, [mstype.float16, mstype.float32, mstype.int32], self.name)
return x_dtype
class SpaceToDepth(PrimitiveWithInfer):
r"""
Rearrange blocks of spatial data into depth.
......
......@@ -237,6 +237,44 @@ class ScatterAdd(nn.Cell):
return out
class ScatterNonAliasingAdd(nn.Cell):
"""ScatterNonAliasingAdd net definition"""
def __init__(self, ref_shape, dtype=np.float32):
super(ScatterNonAliasingAdd, self).__init__()
self.scatter_no_aliasing_add = P.ScatterNonAliasingAdd()
self.ref = Parameter(Tensor(np.ones(ref_shape, dtype)), name="ref")
def construct(self, indices, updates):
out = self.scatter_no_aliasing_add(self.ref, indices, updates)
return out
class ScatterNdSub(nn.Cell):
"""ScatterNdSub net definition"""
def __init__(self, ref_shape, dtype=np.float32):
super(ScatterNdSub, self).__init__()
self.scatter_nd_sub = P.ScatterNdSub()
self.ref = Parameter(Tensor(np.ones(ref_shape, dtype)), name="ref")
def construct(self, indices, updates):
out = self.scatter_nd_sub(self.ref, indices, updates)
return out
class ScatterNdAdd(nn.Cell):
"""ScatterNdAdd net definition"""
def __init__(self, ref_shape, dtype=np.float32):
super(ScatterNdAdd, self).__init__()
self.scatter_nd_add = P.ScatterNdAdd()
self.ref = Parameter(Tensor(np.ones(ref_shape, dtype)), name="ref")
def construct(self, indices, updates):
out = self.scatter_nd_add(self.ref, indices, updates)
return out
class ScatterSub(nn.Cell):
"""ScatterSub net definition"""
......@@ -1811,6 +1849,14 @@ test_case_array_ops = [
'desc_const': [(2, 1, 1, 2)],
'desc_inputs': [[2, 2, 2]],
'desc_bprop': [[2, 2, 2, 4]]}),
('ReverseV2', {
'block': P.ReverseV2(axis=[1]),
'desc_inputs': [(Tensor(np.array([[1, 2, 3, 4], [5, 6, 7, 8]]).astype(np.float32)))],
'desc_bprop': [(Tensor(np.array([[1, 2, 3, 4], [5, 6, 7, 8]]).astype(np.float32)))]}),
('Rint', {
'block': P.Rint(),
'desc_inputs': [(Tensor(np.array([-1.6, -0.1, 1.5, 2.0]).astype(np.float32)))],
'skip': ['backward']}),
('ConcatV2_0', {
'block': P.Concat(),
'desc_inputs': [
......@@ -2074,6 +2120,21 @@ test_case_other_ops = [
'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)),
Tensor(np.array([2.0, 3.0, 4.0], np.float32))),
'skip': ['backward']}),
('ScatterNonAliasingAdd_1d', {
'block': ScatterNonAliasingAdd((8,)),
'desc_inputs': (Tensor(np.array([[2], [3], [4], [5]], np.int32)),
Tensor(np.array([2.0, 3.0, 4.0, 8.0], np.float32))),
'skip': ['backward']}),
('ScatterNdAdd', {
'block': ScatterNdAdd((8,)),
'desc_inputs': (Tensor(np.array([[2], [3], [4], [5]], np.int32)),
Tensor(np.array([2.0, 3.0, 4.0, 8.0], np.float32))),
'skip': ['backward']}),
('ScatterNdSub', {
'block': ScatterNdAdd((8,)),
'desc_inputs': (Tensor(np.array([[2], [3], [4], [5]], np.int32)),
Tensor(np.array([2.0, 3.0, 4.0, 8.0], np.float32))),
'skip': ['backward']}),
('ScatterAdd', {
'block': ScatterAdd((6,)),
'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册