Commit 55d19275 authored by zhaozhenlong

add op scatter add vm

Parent 10076ffe
@@ -198,3 +198,4 @@ from .apply_rms_prop import _apply_rms_prop_tbe
from .cumprod import _cumprop_tbe
from .reduce_prod import _reduce_prod_tbe
from .flatten_grad import _flatten_grad_tbe
from .scatter_add import _scatter_add_tbe
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ScatterAdd op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

scatter_add_op_info = TBERegOp("ScatterAdd") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("scatter_add.so") \
    .compute_cost(10) \
    .kernel_name("scatter_add") \
    .partial_flag(True) \
    .attr("use_locking", "optional", "bool", "all") \
    .input(0, "var", False, "required", "all") \
    .input(1, "indices", False, "required", "all") \
    .input(2, "updates", False, "required", "all") \
    .output(0, "var", False, "required", "all") \
    .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
    .get_op_info()


@op_info_register(scatter_add_op_info)
def _scatter_add_tbe():
    """ScatterAdd TBE register"""
    return
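
Note: each dtype_format row above lists the types for the declared inputs and output in registration order, i.e. (var, indices, updates, var). The indices are always int32, while var and updates must share one of float16, float32, or int32.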
@@ -24,7 +24,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
                        Fill, GatherNd, GatherV2, InvertPermutation,
                        IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike,
                        Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue, Range,
                        SameTypeShape, ScatterMax, ScatterUpdate,
                        SameTypeShape, ScatterAdd, ScatterMax, ScatterUpdate,
                        ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
                        Shape, Size, Slice, Split,
                        Squeeze, StridedSlice, Tile,
@@ -190,6 +190,7 @@ __all__ = [
    'BoundingBoxEncode',
    'BoundingBoxDecode',
    'L2Normalize',
    'ScatterAdd',
    'ScatterNd',
    'ScatterMax',
    'ResizeNearestNeighbor',
@@ -2145,6 +2145,12 @@ class ScatterNdUpdate(PrimitiveWithInfer):
        validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
        return x_dtype


def _check_scatter_shape(x_shape, indices_shape, updates_shape, prim_name):
    if updates_shape and updates_shape != indices_shape + x_shape[1:]:
        raise ValueError(f"For '{prim_name}', the shape of updates should be [] or "
                         f"updates_shape = indices_shape + x_shape[1:], but got x_shape: {x_shape}, "
                         f"indices_shape: {indices_shape}, updates_shape: {updates_shape}.")
class ScatterMax(PrimitiveWithInfer):
"""
@@ -2158,8 +2164,8 @@ class ScatterMax(PrimitiveWithInfer):
    Inputs:
        - **input_x** (Parameter) - The target parameter.
        - **indices** (Tensor) - The index to do max operation whose data type should be int.
        - **updates** (Tensor) - The tensor doing the maximum operation with 'input_x',
          the data type is same as 'input_x', the shape is 'indices_shape + x_shape[1:]'.
        - **updates** (Tensor) - The tensor doing the maximum operation with `input_x`,
          the data type is same as `input_x`, the shape is `indices_shape + x_shape[1:]`.

    Outputs:
        Tensor, has the same shape and data type as `input_x`.
@@ -2180,10 +2186,7 @@ class ScatterMax(PrimitiveWithInfer):
        validator.check_value_type('use_locking', use_locking, (bool,), self.name)

    def infer_shape(self, x_shape, indices_shape, updates_shape):
        if updates_shape and updates_shape != indices_shape + x_shape[1:]:
            raise ValueError(f"For '{self.name}', the shape of update should be [] or "
                             f"update_shape = indices_shape + x_shape[1:], but got x_shape: {x_shape}, "
                             f"indices_shape: {indices_shape}, update_shape: {updates_shape}.")
        _check_scatter_shape(x_shape, indices_shape, updates_shape, self.name)
        return x_shape

    def infer_dtype(self, x_dtype, indices_dtype, updates_dtype):
@@ -2193,6 +2196,49 @@ class ScatterMax(PrimitiveWithInfer):
        return x_dtype


class ScatterAdd(PrimitiveWithInfer):
    """
    Update the value of the input tensor through the add operation.

    Using the given values and the input indices, update the tensor values through the add operation.

    Args:
        use_locking (bool): Whether to protect the assignment by a lock. Default: False.

    Inputs:
        - **input_x** (Parameter) - The target parameter.
        - **indices** (Tensor) - The index to do the add operation, whose data type should be int.
        - **updates** (Tensor) - The tensor doing the add operation with `input_x`,
          the data type is the same as `input_x`, and the shape is `indices_shape + x_shape[1:]`.

    Outputs:
        Tensor, has the same shape and data type as `input_x`.

    Examples:
        >>> input_x = Parameter(Tensor(np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]), mindspore.float32), name="x")
        >>> indices = Tensor(np.array([[0, 1], [1, 1]]), mindspore.int32)
        >>> updates = Tensor(np.ones([2, 2, 3]), mindspore.float32)
        >>> scatter_add = P.ScatterAdd()
        >>> output = scatter_add(input_x, indices, updates)
        [[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]]
    """

    @prim_attr_register
    def __init__(self, use_locking=False):
        """Init ScatterAdd"""
        validator.check_value_type('use_locking', use_locking, (bool,), self.name)

    def infer_shape(self, x_shape, indices_shape, updates_shape):
        _check_scatter_shape(x_shape, indices_shape, updates_shape, self.name)
        return x_shape

    def infer_dtype(self, x_dtype, indices_dtype, updates_dtype):
        validator.check_tensor_type_same({'indices': indices_dtype}, mstype.int_type, self.name)
        args = {'x': x_dtype, 'updates': updates_dtype}
        validator.check_tensor_type_same(args, mstype.number_type, self.name)
        return x_dtype
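
The docstring example can be sanity-checked against NumPy's unbuffered np.add.at, which has the same accumulate-on-duplicate-indices semantics (a reference sketch, not part of the commit):

import numpy as np

# NumPy reference for the ScatterAdd docstring example above (not MindSpore code).
var = np.zeros((2, 3), np.float32)
indices = np.array([[0, 1], [1, 1]], np.int32)
updates = np.ones((2, 2, 3), np.float32)
np.add.at(var, indices, updates)  # row 0 is updated once, row 1 three times
print(var)  # [[1. 1. 1.] [3. 3. 3.]]
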
class SpaceToDepth(PrimitiveWithInfer):
r"""
Rearrange blocks of spatial data into depth.
@@ -196,6 +196,19 @@ class ScatterMax(nn.Cell):
        return out


class ScatterAdd(nn.Cell):
    """ScatterAdd net definition"""

    def __init__(self, ref_shape):
        super(ScatterAdd, self).__init__()
        self.scatter_add = P.ScatterAdd()
        self.ref = Parameter(Tensor(np.ones(ref_shape, np.float32)), name="ref")

    def construct(self, indices, updates):
        out = self.scatter_add(self.ref, indices, updates)
        return out


class ApplyFtrlNet(nn.Cell):
    def __init__(self):
        super(ApplyFtrlNet, self).__init__()
@@ -1257,6 +1270,17 @@ test_case_other_ops = [
        'desc_inputs': (Tensor(np.array([[0, 0], [1, 1]], np.int32)),
                        Tensor(np.ones([2, 2, 3], np.float32) * 99)),
        'skip': ['backward']}),
    ('ScatterAdd', {
        'block': ScatterAdd((6,)),
        'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)),
                        Tensor(np.array([2.0, 3.0, 4.0], np.float32))),
        'skip': ['backward']}),
    ('ScatterAdd2d', {
        'block': ScatterAdd((3, 4)),
        'desc_inputs': (Tensor(np.array([[0, 1], [1, 2]], np.int32)),
                        Tensor(np.array([[[1, 1, 1, 1], [2, 2, 2, 2]],
                                         [[3, 3, 3, 3], [4, 4, 4, 4]]], np.float32))),
        'skip': ['backward']}),
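
For reference, the expected outputs of the two new cases can be derived with the same np.add.at equivalence, remembering that ref is initialized to ones (a sketch, not part of the test file):

import numpy as np

ref = np.ones(6, np.float32)        # 1-D case: add 2, 3, 4 at indices 2, 0, 5
np.add.at(ref, [2, 0, 5], [2.0, 3.0, 4.0])
print(ref)   # [4. 1. 3. 1. 1. 5.]

ref2 = np.ones((3, 4), np.float32)  # 2-D case: row 1 receives two update slices
np.add.at(ref2, np.array([[0, 1], [1, 2]]),
          np.array([[[1, 1, 1, 1], [2, 2, 2, 2]],
                    [[3, 3, 3, 3], [4, 4, 4, 4]]], np.float32))
print(ref2)  # [[2. 2. 2. 2.] [6. 6. 6. 6.] [5. 5. 5. 5.]]
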
    ('SmoothL1Loss', {
        'block': P.SmoothL1Loss(),
        'desc_inputs': [[256, 4], [256, 4]],