test_scatter_nd_op.py 16.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
16

17
import numpy as np
W
wanghuancoder 已提交
18
from eager_op_test import OpTest, convert_float_to_uint16
19

20
import paddle
21 22
from paddle import fluid
from paddle.fluid import core
23
from paddle.fluid.dygraph.base import switch_to_static_graph
24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70


def numpy_scatter_nd(ref, index, updates, fun):
    """NumPy reference implementation of a scatter_nd-style update.

    Every row ``index[..., :]`` is a coordinate into the leading
    ``index.shape[-1]`` dimensions of ``ref``; the slice it addresses is
    replaced by ``fun(current_slice, update_slice)``. Rows are applied in
    order, so duplicate coordinates accumulate. An index whose last dimension
    is 0 addresses the whole of ``ref`` on every row.

    Args:
        ref: array to scatter into. It is NOT modified; a copy is updated.
        index: integer array; its last dimension holds coordinates.
        updates: array holding ``prod(index.shape[:-1])`` slices of
            ``prod(ref.shape[index.shape[-1]:])`` elements (reshaped here).
        fun: binary callable ``fun(current_slice, update_slice)``.

    Returns:
        A new array with the same shape as ``ref``.
    """
    ref_shape = ref.shape
    index_shape = index.shape

    end_size = index_shape[-1]
    # Number of index rows: product of all but the last index dimension.
    remain_numl = 1
    for dim in index_shape[:-1]:
        remain_numl *= dim

    # Elements per addressed slice: product of the un-indexed trailing dims.
    slice_size = 1
    for dim in ref_shape[end_size:]:
        slice_size *= dim

    flat_index = index.reshape([remain_numl] + list(index_shape[-1:]))
    flat_updates = updates.reshape((remain_numl, slice_size))
    # Copy first: ``ref.reshape`` may return a view, and writing through it
    # would silently mutate the caller's array.
    flat_output = np.array(ref).reshape(
        list(ref_shape[:end_size]) + [slice_size]
    )

    for i_up, i_out in enumerate(flat_index):
        i_out = tuple(i_out)
        flat_output[i_out] = fun(flat_output[i_out], flat_updates[i_up])
    return flat_output.reshape(ref_shape)


def numpy_scatter_nd_add(ref, index, updates):
    """Return ``ref`` with ``updates`` summed in at the positions in ``index``.

    Thin wrapper over ``numpy_scatter_nd`` using addition as the combiner.
    """

    def _add(current, update):
        return current + update

    return numpy_scatter_nd(ref, index, updates, _add)


def judge_update_shape(ref, index):
    """Return the Updates shape (as a list) matching ``ref`` and ``index``.

    It is ``index.shape[:-1]`` followed by the trailing dimensions of
    ``ref`` that ``index`` does not address, i.e. ``ref.shape[index.shape[-1]:]``.
    """
    batch_dims = list(index.shape[:-1])
    slice_dims = list(ref.shape[index.shape[-1]:])
    return batch_dims + slice_dims


class TestScatterNdAddSimpleOp(OpTest):
    """
    A simple example: 1-D X with 100 elements and 100 single-coordinate
    indices drawn at random (duplicates possible, so updates accumulate).
    """

    def setUp(self):
        self.op_type = "scatter_nd_add"
        self.python_api = paddle.scatter_nd_add
        self._set_dtype()
        # Data for any dtype other than float64/float16 (i.e. the bf16 case,
        # np.uint16) is generated as float32 and converted below.
        if self.dtype == np.float64:
            target_dtype = "float64"
        elif self.dtype == np.float16:
            target_dtype = "float16"
        else:
            target_dtype = "float32"
        ref_np = np.random.random([100]).astype(target_dtype)
        index_np = np.random.randint(0, 100, [100, 1]).astype("int32")
        updates_np = np.random.random([100]).astype(target_dtype)
        expect_np = numpy_scatter_nd_add(ref_np.copy(), index_np, updates_np)
        if self.dtype == np.uint16:
            ref_np = convert_float_to_uint16(ref_np)
            updates_np = convert_float_to_uint16(updates_np)
            expect_np = convert_float_to_uint16(expect_np)
        self.inputs = {'X': ref_np, 'Index': index_np, 'Updates': updates_np}
        self.outputs = {'Out': expect_np}

    def _set_dtype(self):
        """Element dtype under test; subclasses override it."""
        self.dtype = np.float64

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X', 'Updates'], 'Out')
98 99


100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124
class TestScatterNdAddSimpleFP16Op(TestScatterNdAddSimpleOp):
    """
    The simple 1-D example from TestScatterNdAddSimpleOp, run in float16.
    """

    def _set_dtype(self):
        """Select float16 as the element dtype."""
        self.dtype = np.float16


@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not compiled with CUDA or does not support bfloat16",
)
class TestScatterNdAddSimpleBF16Op(TestScatterNdAddSimpleOp):
    """
    The simple 1-D example in bfloat16 (carried as np.uint16), checked on
    CUDAPlace(0) only.
    """

    def _set_dtype(self):
        self.dtype = np.uint16

    def test_check_output(self):
        if core.is_compiled_with_cuda():
            place = core.CUDAPlace(0)
            self.check_output_with_place(place)

    def test_check_grad(self):
        if core.is_compiled_with_cuda():
            place = core.CUDAPlace(0)
            self.check_grad_with_place(place, ['X', 'Updates'], 'Out')
131 132


133 134 135 136 137 138 139
class TestScatterNdAddWithEmptyIndex(OpTest):
    """
    Index has empty element: Index has shape (2, 0), so each of its two rows
    addresses the whole (10, 10) X and both slices of Updates (2, 10, 10)
    are added to it.
    """

    def setUp(self):
        self.op_type = "scatter_nd_add"
        self.python_api = paddle.scatter_nd_add
        self._set_dtype()
        # The bf16 case (np.uint16) generates float32 data and converts below.
        if self.dtype == np.float64:
            target_dtype = "float64"
        elif self.dtype == np.float16:
            target_dtype = "float16"
        else:
            target_dtype = "float32"
        ref_np = np.random.random((10, 10)).astype(target_dtype)
        index_np = np.array([[], []]).astype("int32")
        updates_np = np.random.random((2, 10, 10)).astype(target_dtype)

        expect_np = numpy_scatter_nd_add(ref_np.copy(), index_np, updates_np)

        if self.dtype == np.uint16:
            ref_np = convert_float_to_uint16(ref_np)
            updates_np = convert_float_to_uint16(updates_np)
            expect_np = convert_float_to_uint16(expect_np)

        self.inputs = {'X': ref_np, 'Index': index_np, 'Updates': updates_np}
        self.outputs = {'Out': expect_np}

    def _set_dtype(self):
        """Element dtype under test; subclasses override it."""
        self.dtype = np.float64

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X', 'Updates'], 'Out')
170 171


172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196
class TestScatterNdAddWithEmptyIndexFP16(TestScatterNdAddWithEmptyIndex):
    """
    The empty-index example from TestScatterNdAddWithEmptyIndex, in float16.
    """

    def _set_dtype(self):
        """Select float16 as the element dtype."""
        self.dtype = np.float16


@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not compiled with CUDA or does not support bfloat16",
)
class TestScatterNdAddWithEmptyIndexBF16(TestScatterNdAddWithEmptyIndex):
    """
    The empty-index example in bfloat16 (carried as np.uint16), checked on
    CUDAPlace(0) only.
    """

    def _set_dtype(self):
        self.dtype = np.uint16

    def test_check_output(self):
        if core.is_compiled_with_cuda():
            place = core.CUDAPlace(0)
            self.check_output_with_place(place)

    def test_check_grad(self):
        if core.is_compiled_with_cuda():
            place = core.CUDAPlace(0)
            self.check_grad_with_place(place, ['X', 'Updates'], 'Out')
203 204


205 206 207 208 209 210 211
class TestScatterNdAddWithHighRankSame(OpTest):
    """
    High-rank X with Index's last dimension equal to Rank(X): X has shape
    (3, 2, 2, 1, 10) and Index has shape (100, 5), so every index row
    addresses a single element and Updates has shape (100,).
    """

    def setUp(self):
        self.op_type = "scatter_nd_add"
        self.python_api = paddle.scatter_nd_add
        self._set_dtype()
        # The bf16 case (np.uint16) generates float32 data and converts below.
        if self.dtype == np.float64:
            target_dtype = "float64"
        elif self.dtype == np.float16:
            target_dtype = "float16"
        else:
            target_dtype = "float32"
        shape = (3, 2, 2, 1, 10)
        ref_np = np.random.rand(*shape).astype(target_dtype)
        # One column of in-range coordinates per X dimension -> (100, 5).
        index_np = np.vstack(
            [np.random.randint(0, s, size=100) for s in shape]
        ).T.astype("int32")
        update_shape = judge_update_shape(ref_np, index_np)
        updates_np = np.random.rand(*update_shape).astype(target_dtype)
        expect_np = numpy_scatter_nd_add(ref_np.copy(), index_np, updates_np)

        if self.dtype == np.uint16:
            ref_np = convert_float_to_uint16(ref_np)
            updates_np = convert_float_to_uint16(updates_np)
            expect_np = convert_float_to_uint16(expect_np)

        self.inputs = {'X': ref_np, 'Index': index_np, 'Updates': updates_np}
        self.outputs = {'Out': expect_np}

    def _set_dtype(self):
        """Element dtype under test; subclasses override it."""
        self.dtype = np.float64

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X', 'Updates'], 'Out')
245 246


247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271
class TestScatterNdAddWithHighRankSameFP16(TestScatterNdAddWithHighRankSame):
    """
    The high-rank example from TestScatterNdAddWithHighRankSame, in float16.
    """

    def _set_dtype(self):
        """Select float16 as the element dtype."""
        self.dtype = np.float16


@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not compiled with CUDA or does not support bfloat16",
)
class TestScatterNdAddWithHighRankSameBF16(TestScatterNdAddWithHighRankSame):
    """
    The high-rank example in bfloat16 (carried as np.uint16), checked on
    CUDAPlace(0) only.
    """

    def _set_dtype(self):
        self.dtype = np.uint16

    def test_check_output(self):
        if core.is_compiled_with_cuda():
            place = core.CUDAPlace(0)
            self.check_output_with_place(place)

    def test_check_grad(self):
        if core.is_compiled_with_cuda():
            place = core.CUDAPlace(0)
            self.check_grad_with_place(place, ['X', 'Updates'], 'Out')
278 279


280 281 282 283 284 285 286
class TestScatterNdAddWithHighRankDiff(OpTest):
    """
    Both Index and X have high rank, and Rank(Index) < Rank(X): X has shape
    (8, 2, 2, 1, 10) (rank 5) and Index has shape (10, 5, 10, 5) (rank 4).
    Each index row still addresses one element, so Updates is (10, 5, 10).
    """

    def setUp(self):
        self.op_type = "scatter_nd_add"
        self.python_api = paddle.scatter_nd_add
        shape = (8, 2, 2, 1, 10)
        ref_np = np.random.rand(*shape).astype("double")
        # 500 coordinate rows of length 5, regrouped into a rank-4 index.
        index = np.vstack([np.random.randint(0, s, size=500) for s in shape]).T
        index_np = index.reshape([10, 5, 10, 5]).astype("int64")
        update_shape = judge_update_shape(ref_np, index_np)
        updates_np = np.random.rand(*update_shape).astype("double")
        expect_np = numpy_scatter_nd_add(ref_np.copy(), index_np, updates_np)

        self.inputs = {'X': ref_np, 'Index': index_np, 'Updates': updates_np}
        self.outputs = {'Out': expect_np}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X', 'Updates'], 'Out')
304 305


306
# Test Python API
307
class TestScatterNdOpAPI(unittest.TestCase):
    """
    test scatter_nd_add api and scatter_nd api
    """

    def testcase1(self):
        # Build-only: scatter_nd_add on float32 paddle.static.data inputs.
        ref1 = paddle.static.data(
            name='ref1',
            shape=[10, 9, 8, 1, 3],
            dtype='float32',
        )
        index1 = paddle.static.data(
            name='index1',
            shape=[5, 5, 8, 5],
            dtype='int32',
        )
        updates1 = paddle.static.data(
            name='update1',
            shape=[5, 5, 8],
            dtype='float32',
        )
        output1 = paddle.scatter_nd_add(ref1, index1, updates1)

    def testcase2(self):
        # Build-only: scatter_nd_add on double inputs with an explicit name.
        ref2 = paddle.static.data(
            name='ref2',
            shape=[10, 9, 8, 1, 3],
            dtype='double',
        )
        index2 = paddle.static.data(
            name='index2',
            shape=[5, 8, 5],
            dtype='int32',
        )
        updates2 = paddle.static.data(
            name='update2',
            shape=[5, 8],
            dtype='double',
        )
        output2 = paddle.scatter_nd_add(
            ref2, index2, updates2, name="scatter_nd_add"
        )

    def testcase3(self):
        # Build-only: scatter_nd into a zero tensor of the given shape.
        shape3 = [10, 9, 8, 1, 3]
        index3 = paddle.static.data(
            name='index3',
            shape=[5, 5, 8, 5],
            dtype='int32',
        )
        updates3 = paddle.static.data(
            name='update3',
            shape=[5, 5, 8],
            dtype='float32',
        )
        output3 = paddle.scatter_nd(index3, updates3, shape3)

    def testcase4(self):
        # Build-only: scatter_nd on double updates with an explicit name.
        shape4 = [10, 9, 8, 1, 3]
        index4 = paddle.static.data(
            name='index4',
            shape=[5, 5, 8, 5],
            dtype='int32',
        )
        updates4 = paddle.static.data(
            name='update4',
            shape=[5, 5, 8],
            dtype='double',
        )
        output4 = paddle.scatter_nd(index4, updates4, shape4, name='scatter_nd')

    def testcase5(self):
        # GPU and CPU results must match, in dygraph and in static graph.
        # Returns early (passes) on builds without CUDA.
        if not fluid.core.is_compiled_with_cuda():
            return

        shape = [2, 3, 4]
        x = np.arange(int(np.prod(shape))).reshape(shape)
        index = np.array([[0, 0, 2], [0, 1, 2]])
        val = np.array([-1, -3])

        with fluid.dygraph.guard():
            device = paddle.get_device()
            try:
                paddle.set_device('gpu')
                gpu_value = paddle.scatter_nd_add(
                    paddle.to_tensor(x),
                    paddle.to_tensor(index),
                    paddle.to_tensor(val),
                )
                paddle.set_device('cpu')
                cpu_value = paddle.scatter_nd_add(
                    paddle.to_tensor(x),
                    paddle.to_tensor(index),
                    paddle.to_tensor(val),
                )
                np.testing.assert_array_equal(
                    gpu_value.numpy(), cpu_value.numpy()
                )
            finally:
                # Restore the global device even if the comparison fails, so
                # later tests do not inherit 'cpu'.
                paddle.set_device(device)

        @switch_to_static_graph
        def test_static_graph():
            with paddle.static.program_guard(
                paddle.static.Program(), paddle.static.Program()
            ):
                x_t = paddle.static.data(name="x", dtype=x.dtype, shape=x.shape)
                index_t = paddle.static.data(
                    name="index", dtype=index.dtype, shape=index.shape
                )
                val_t = paddle.static.data(
                    name="val", dtype=val.dtype, shape=val.shape
                )
                out_t = paddle.scatter_nd_add(x_t, index_t, val_t)
                feed = {x_t.name: x, index_t.name: index, val_t.name: val}
                fetch = [out_t]

                gpu_exe = paddle.static.Executor(paddle.CUDAPlace(0))
                gpu_value = gpu_exe.run(feed=feed, fetch_list=fetch)[0]
                cpu_exe = paddle.static.Executor(paddle.CPUPlace())
                cpu_value = cpu_exe.run(feed=feed, fetch_list=fetch)[0]
                np.testing.assert_array_equal(gpu_value, cpu_value)

        test_static_graph()

428

429
# Test Raise Error
430
class TestScatterNdOpRaise(unittest.TestCase):
    """Invalid scatter_nd / scatter_nd_add inputs must be rejected."""

    def test_check_raise(self):
        # Index's last dimension (10) exceeds Rank(X) (4). The expected error
        # text is matched and re-raised as IndexError for assertRaises; any
        # other exception, or none at all, makes the test fail.
        def check_raise_is_test():
            try:
                ref5 = paddle.static.data(
                    name='ref5', shape=[-1, 3, 4, 5], dtype='float32'
                )
                index5 = paddle.static.data(
                    name='index5', shape=[-1, 2, 10], dtype='int32'
                )
                updates5 = paddle.static.data(
                    name='updates5', shape=[-1, 2, 10], dtype='float32'
                )
                output5 = paddle.scatter_nd_add(ref5, index5, updates5)
            except Exception as e:
                t = "The last dimension of Input(Index)'s shape should be no greater "
                if t in str(e):
                    raise IndexError

        self.assertRaises(IndexError, check_raise_is_test)

    def test_check_raise2(self):
        # X is double while Updates is float32: expected to raise ValueError.
        with self.assertRaises(ValueError):
            ref6 = paddle.static.data(
                name='ref6',
                shape=[10, 9, 8, 1, 3],
                dtype='double',
            )
            index6 = paddle.static.data(
                name='index6',
                shape=[5, 8, 5],
                dtype='int32',
            )
            updates6 = paddle.static.data(
                name='update6',
                shape=[5, 8],
                dtype='float32',
            )
            output6 = paddle.scatter_nd_add(ref6, index6, updates6)

    def test_check_raise3(self):
        # Updates shape does not match index.shape[:-1] + shape[1:]; the
        # expected error text is re-raised as ValueError for assertRaises.
        def check_raise_is_test():
            try:
                shape = [3, 4, 5]
                index7 = paddle.static.data(
                    name='index7', shape=[-1, 2, 1], dtype='int32'
                )
                updates7 = paddle.static.data(
                    name='updates7', shape=[-1, 2, 4, 5, 20], dtype='float32'
                )
                output7 = paddle.scatter_nd(index7, updates7, shape)
            except Exception as e:
                t = "Updates has wrong shape"
                if t in str(e):
                    raise ValueError

        self.assertRaises(ValueError, check_raise_is_test)


489 490 491 492 493 494 495 496 497
class TestDygraph(unittest.TestCase):
    """Smoke tests: scatter_nd / scatter_nd_add run in dygraph mode on CPU."""

    def test_dygraph(self):
        with fluid.dygraph.guard(fluid.CPUPlace()):
            index_data = np.array([[1, 1], [0, 1], [1, 3]]).astype(np.int64)
            index = fluid.dygraph.to_variable(index_data)
            updates = paddle.rand(shape=[3, 9, 10], dtype='float32')
            shape = [3, 5, 9, 10]
            output = paddle.scatter_nd(index, updates, shape)

    def test_dygraph_1(self):
        with fluid.dygraph.guard(fluid.CPUPlace()):
            x = paddle.rand(shape=[3, 5, 9, 10], dtype='float32')
            updates = paddle.rand(shape=[3, 9, 10], dtype='float32')
            index_data = np.array([[1, 1], [0, 1], [1, 3]]).astype(np.int64)
            index = fluid.dygraph.to_variable(index_data)
            output = paddle.scatter_nd_add(x, index, updates)


507
if __name__ == "__main__":
    # Static mode for the paddle.static.data-based tests; the dygraph tests
    # opt in explicitly via fluid.dygraph.guard.
    paddle.enable_static()
    unittest.main()