From 55d1927534dc653bee05c7fdb68a838abae0a69e Mon Sep 17 00:00:00 2001 From: zhaozhenlong Date: Mon, 25 May 2020 17:29:44 +0800 Subject: [PATCH] add op scatter add vm --- mindspore/ops/_op_impl/tbe/__init__.py | 1 + mindspore/ops/_op_impl/tbe/scatter_add.py | 40 ++++++++++++++++ mindspore/ops/operations/__init__.py | 3 +- mindspore/ops/operations/array_ops.py | 58 ++++++++++++++++++++--- tests/ut/python/ops/test_ops.py | 24 ++++++++++ 5 files changed, 119 insertions(+), 7 deletions(-) create mode 100644 mindspore/ops/_op_impl/tbe/scatter_add.py diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py index ccbd30158..5a91e866d 100644 --- a/mindspore/ops/_op_impl/tbe/__init__.py +++ b/mindspore/ops/_op_impl/tbe/__init__.py @@ -198,3 +198,4 @@ from .apply_rms_prop import _apply_rms_prop_tbe from .cumprod import _cumprop_tbe from .reduce_prod import _reduce_prod_tbe from .flatten_grad import _flatten_grad_tbe +from .scatter_add import _scatter_add_tbe diff --git a/mindspore/ops/_op_impl/tbe/scatter_add.py b/mindspore/ops/_op_impl/tbe/scatter_add.py new file mode 100644 index 000000000..ea54719d4 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/scatter_add.py @@ -0,0 +1,40 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""ScatterAdd op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +scatter_add_op_info = TBERegOp("ScatterAdd") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("scatter_add.so") \ + .compute_cost(10) \ + .kernel_name("scatter_add") \ + .partial_flag(True) \ + .attr("use_locking", "optional", "bool", "all") \ + .input(0, "var", False, "required", "all") \ + .input(1, "indices", False, "required", "all") \ + .input(2, "updates", False, "required", "all") \ + .output(0, "var", False, "required", "all") \ + .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ + .get_op_info() + + +@op_info_register(scatter_add_op_info) +def _scatter_add_tbe(): + """ScatterAdd TBE register""" + return diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index e933fa970..fca4f57b7 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -24,7 +24,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack, Fill, GatherNd, GatherV2, InvertPermutation, IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike, Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue, Range, - SameTypeShape, ScatterMax, ScatterUpdate, + SameTypeShape, ScatterAdd, ScatterMax, ScatterUpdate, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select, Shape, Size, Slice, Split, Squeeze, StridedSlice, Tile, @@ -190,6 +190,7 @@ __all__ = [ 'BoundingBoxEncode', 'BoundingBoxDecode', 'L2Normalize', + 'ScatterAdd', 'ScatterNd', 'ScatterMax', 'ResizeNearestNeighbor', diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 21847abd0..280e24f7c 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -2145,6 +2145,12 @@ class ScatterNdUpdate(PrimitiveWithInfer): validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name) return x_dtype +def _check_scatter_shape(x_shape, indices_shape, updates_shape, prim_name): + if updates_shape and updates_shape != indices_shape + x_shape[1:]: + raise ValueError(f"For '{prim_name}', the shape of updates should be [] or " + f"updates_shape = indices_shape + x_shape[1:], but got x_shape: {x_shape}, " + f"indices_shape: {indices_shape}, updates_shape: {updates_shape}.") + class ScatterMax(PrimitiveWithInfer): """ @@ -2158,8 +2164,8 @@ class ScatterMax(PrimitiveWithInfer): Inputs: - **input_x** (Parameter) - The target parameter. - **indices** (Tensor) - The index to do max operation whose data type should be int. - - **updates** (Tensor) - The tensor doing the maximum operation with 'input_x', - the data type is same as 'input_x', the shape is 'indices_shape + x_shape[1:]'. + - **updates** (Tensor) - The tensor doing the maximum operation with `input_x`, + the data type is same as `input_x`, the shape is `indices_shape + x_shape[1:]`. Outputs: Tensor, has the same shape and data type as `input_x`. @@ -2180,10 +2186,7 @@ class ScatterMax(PrimitiveWithInfer): validator.check_value_type('use_locking', use_locking, (bool,), self.name) def infer_shape(self, x_shape, indices_shape, updates_shape): - if updates_shape and updates_shape != indices_shape + x_shape[1:]: - raise ValueError(f"For '{self.name}', the shape of update should be [] or " - f"update_shape = indices_shape + x_shape[1:], but got x_shape: {x_shape}, " - f"indices_shape: {indices_shape}, update_shape: {updates_shape}.") + _check_scatter_shape(x_shape, indices_shape, updates_shape, self.name) return x_shape def infer_dtype(self, x_dtype, indices_dtype, updates_dtype): @@ -2193,6 +2196,49 @@ class ScatterMax(PrimitiveWithInfer): return x_dtype +class ScatterAdd(PrimitiveWithInfer): + """ + Update the value of the input tensor through the add operation. + + Using given values to update tensor value through the add operation, along with the input indices. + + Args: + use_locking (bool): Whether protect the assignment by a lock. Default: False. + + Inputs: + - **input_x** (Parameter) - The target parameter. + - **indices** (Tensor) - The index to do add operation whose data type should be int. + - **updates** (Tensor) - The tensor doing the add operation with `input_x`, + the data type is same as `input_x`, the shape is `indices_shape + x_shape[1:]`. + + Outputs: + Tensor, has the same shape and data type as `input_x`. + + Examples: + >>> input_x = Parameter(Tensor(np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]), mindspore.float32), name="x") + >>> indices = Tensor(np.array([[0, 1], [1, 1]]), mindspore.int32) + >>> updates = Tensor(np.ones([2, 2, 3]), mindspore.float32) + >>> scatter_add = P.ScatterAdd() + >>> output = scatter_add(input_x, indices, updates) + [[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]] + """ + + @prim_attr_register + def __init__(self, use_locking=False): + """Init ScatterAdd""" + validator.check_value_type('use_locking', use_locking, (bool,), self.name) + + def infer_shape(self, x_shape, indices_shape, updates_shape): + _check_scatter_shape(x_shape, indices_shape, updates_shape, self.name) + return x_shape + + def infer_dtype(self, x_dtype, indices_dtype, updates_dtype): + validator.check_tensor_type_same({'indices': indices_dtype}, mstype.int_type, self.name) + args = {'x': x_dtype, 'updates': updates_dtype} + validator.check_tensor_type_same(args, mstype.number_type, self.name) + return x_dtype + + class SpaceToDepth(PrimitiveWithInfer): r""" Rearrange blocks of spatial data into depth. diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index 25c76033f..b08480858 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -196,6 +196,19 @@ class ScatterMax(nn.Cell): return out +class ScatterAdd(nn.Cell): + """ScatterAdd net definition""" + + def __init__(self, ref_shape): + super(ScatterAdd, self).__init__() + self.scatter_add = P.ScatterAdd() + self.ref = Parameter(Tensor(np.ones(ref_shape, np.float32)), name="ref") + + def construct(self, indices, updates): + out = self.scatter_add(self.ref, indices, updates) + return out + + class ApplyFtrlNet(nn.Cell): def __init__(self): super(ApplyFtrlNet, self).__init__() @@ -1257,6 +1270,17 @@ test_case_other_ops = [ 'desc_inputs': (Tensor(np.array([[0, 0], [1, 1]], np.int32)), Tensor(np.ones([2, 2, 3], np.float32) * 99)), 'skip': ['backward']}), + ('ScatterAdd', { + 'block': ScatterAdd((6,)), + 'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)), + Tensor(np.array([2.0, 3.0, 4.0], np.float32))), + 'skip': ['backward']}), + ('ScatterAdd2d', { + 'block': ScatterAdd((3, 4)), + 'desc_inputs': (Tensor(np.array([[0, 1], [1, 2]], np.int32)), + Tensor(np.array([[[1, 1, 1, 1], [2, 2, 2, 2]], + [[3, 3, 3, 3], [4, 4, 4, 4]]], np.float32))), + 'skip': ['backward']}), ('SmoothL1Loss', { 'block': P.SmoothL1Loss(), 'desc_inputs': [[256, 4], [256, 4]], -- GitLab