From d55c0f66799473e2140e1963994ab7124038e133 Mon Sep 17 00:00:00 2001 From: zhaozhenlong Date: Sun, 28 Jun 2020 15:30:06 +0800 Subject: [PATCH] complement scatter add support dtype --- mindspore/ops/_op_impl/tbe/scatter_add.py | 2 ++ tests/ut/python/ops/test_ops.py | 29 +++++++++++++++++++++-- 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/mindspore/ops/_op_impl/tbe/scatter_add.py b/mindspore/ops/_op_impl/tbe/scatter_add.py index ea54719d4..0202cdf45 100644 --- a/mindspore/ops/_op_impl/tbe/scatter_add.py +++ b/mindspore/ops/_op_impl/tbe/scatter_add.py @@ -31,6 +31,8 @@ scatter_add_op_info = TBERegOp("ScatterAdd") \ .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \ .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \ .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \ + .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \ .get_op_info() diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index cf6a6705a..def84be4b 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -220,10 +220,10 @@ class ScatterMax(nn.Cell): class ScatterAdd(nn.Cell): """ScatterAdd net definition""" - def __init__(self, ref_shape): + def __init__(self, ref_shape, dtype=np.float32): super(ScatterAdd, self).__init__() self.scatter_add = P.ScatterAdd() - self.ref = Parameter(Tensor(np.ones(ref_shape, np.float32)), name="ref") + self.ref = Parameter(Tensor(np.ones(ref_shape, dtype)), name="ref") def construct(self, indices, updates): out = self.scatter_add(self.ref, indices, updates) @@ -1677,12 +1677,37 @@ test_case_other_ops = [ 'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)), Tensor(np.array([2.0, 3.0, 4.0], np.float32))), 'skip': ['backward']}), + ('ScatterAddScalar', { + 'block': ScatterAdd((6,)), + 'desc_inputs': (Tensor(np.array([2], np.int32)), + Tensor(np.array([2.0], np.float32))), + 'skip': ['backward']}), ('ScatterAdd2d', { 'block': ScatterAdd((3, 4)), 'desc_inputs': (Tensor(np.array([[0, 1], [1, 2]], np.int32)), Tensor(np.array([[[1, 1, 1, 1], [2, 2, 2, 2]], [[3, 3, 3, 3], [4, 4, 4, 4]]], np.float32))), 'skip': ['backward']}), + ('ScatterAddF16', { + 'block': ScatterAdd((6,), np.float16), + 'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)), + Tensor(np.array([2.0, 3.0, 4.0], np.float16))), + 'skip': ['backward']}), + ('ScatterAddI8', { + 'block': ScatterAdd((6,), np.int8), + 'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)), + Tensor(np.array([2, 3, 4], np.int8))), + 'skip': ['backward']}), + ('ScatterAddI32', { + 'block': ScatterAdd((6,), np.int32), + 'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)), + Tensor(np.array([2, 3, 4], np.int32))), + 'skip': ['backward']}), + ('ScatterAddU8', { + 'block': ScatterAdd((6,), np.uint8), + 'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)), + Tensor(np.array([2, 3, 4], np.uint8))), + 'skip': ['backward']}), ('SmoothL1Loss', { 'block': P.SmoothL1Loss(), 'desc_inputs': [[256, 4], [256, 4]], -- GitLab