test_scatter_op.py 24.1 KB
Newer Older
1
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
import os
Z
zchen0211 已提交
16
import unittest
17

Q
qijun 已提交
18
import numpy as np
W
wanghuancoder 已提交
19
from eager_op_test import OpTest, convert_float_to_uint16
20

S
ShenLiang 已提交
21
import paddle
22 23
from paddle import fluid
from paddle.fluid import core
Z
Zeng Jinle 已提交
24
from paddle.fluid.dygraph.base import switch_to_static_graph
Z
zchen0211 已提交
25 26


Q
qijun 已提交
27
class TestScatterOp(OpTest):
    """Scatter two distinct rows of updates into a (3, 50) tensor of ones.

    No ``overwrite`` attr is passed. Because the ids are distinct, the
    reference result is plain row assignment.
    """

    def _set_dtype(self):
        """Hook for subclasses: select the dtype under test."""
        self.dtype = np.float32

    def setUp(self):
        self.op_type = "scatter"
        self.python_api = paddle.scatter
        self.public_python_api = paddle.scatter
        self.prim_op_type = "prim"
        self._set_dtype()

        # Only float16 is computed natively; bfloat16 (np.uint16) data is
        # built in float32 and converted at the end.
        np_dtype = "float16" if self.dtype == np.float16 else "float32"
        x = np.ones((3, 50)).astype(np_dtype)
        ids = np.array([1, 2]).astype("int32")
        upd = np.random.random((2, 50)).astype(np_dtype)

        expected = x.copy()
        expected[ids] = upd

        if self.dtype == np.uint16:
            x, upd, expected = (
                convert_float_to_uint16(arr) for arr in (x, upd, expected)
            )

        self.inputs = {'X': x, 'Ids': ids, 'Updates': upd}
        self.outputs = {'Out': expected}

    def test_check_output(self):
        self.check_output(check_prim=True)

    def test_check_grad(self):
        self.check_grad(['X', 'Updates'], 'Out', check_prim=True)
Z
zchen0211 已提交
55 56


57 58 59 60 61 62 63 64 65 66 67 68 69
class TestScatterFP16Op(TestScatterOp):
    """TestScatterOp computed with float16 tensors."""

    def _set_dtype(self):
        # setUp derives the numpy dtype of X/Updates from this value.
        self.dtype = np.float16


@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not compiled with CUDA or does not support bfloat16",
)
class TestScatterBF16Op(TestScatterOp):
    """bfloat16 (stored as np.uint16) variant of TestScatterOp; CUDA only."""

    def _set_dtype(self):
        self.dtype = np.uint16
        # CINN is switched off for the bfloat16 variant.
        self.enable_cinn = False

    def test_check_output(self):
        # The skipIf above already guarantees a CUDA build here.
        self.check_output_with_place(core.CUDAPlace(0), check_prim=True)

    def test_check_grad(self):
        self.check_grad_with_place(
            core.CUDAPlace(0),
            ['X', 'Updates'],
            'Out',
            check_prim=True,
        )
86 87


88 89 90
class TestScatterOp0(OpTest):
    """Scatter with overwrite=True of two distinct rows into a (3, 3) tensor."""

    def _set_dtype(self):
        """Hook for subclasses: select the dtype under test."""
        self.dtype = np.float32

    def setUp(self):
        self.op_type = "scatter"
        self.python_api = paddle.scatter
        self.public_python_api = paddle.scatter
        self.prim_op_type = "prim"
        self._set_dtype()

        np_dtype = "float16" if self.dtype == np.float16 else "float32"
        x = np.ones((3, 3)).astype(np_dtype)
        ids = np.array([1, 2]).astype("int32")
        upd = np.random.random((2, 3)).astype(np_dtype)

        # overwrite=True: each indexed row is replaced by its update.
        expected = x.copy()
        expected[ids] = upd

        if self.dtype == np.uint16:
            x, upd, expected = (
                convert_float_to_uint16(arr) for arr in (x, upd, expected)
            )

        self.attrs = {'overwrite': True}
        self.inputs = {'X': x, 'Ids': ids, 'Updates': upd}
        self.outputs = {'Out': expected}

    def test_check_output(self):
        self.check_output(check_prim=True)

    def test_check_grad(self):
        self.check_grad(['X', 'Updates'], 'Out', check_prim=True)
117 118


119 120 121 122 123 124 125 126 127 128 129 130 131
class TestScatterFP16Op0(TestScatterOp0):
    """TestScatterOp0 computed with float16 tensors."""

    def _set_dtype(self):
        # setUp derives the numpy dtype of X/Updates from this value.
        self.dtype = np.float16


@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not compiled with CUDA or does not support bfloat16",
)
class TestScatterBF16Op0(TestScatterOp0):
    """bfloat16 (stored as np.uint16) variant of TestScatterOp0; CUDA only."""

    def _set_dtype(self):
        self.dtype = np.uint16
        # CINN is switched off for the bfloat16 variant.
        self.enable_cinn = False

    def test_check_output(self):
        # The skipIf above already guarantees a CUDA build here.
        self.check_output_with_place(core.CUDAPlace(0), check_prim=True)

    def test_check_grad(self):
        self.check_grad_with_place(
            core.CUDAPlace(0),
            ['X', 'Updates'],
            'Out',
            check_prim=True,
        )
148 149


150 151 152
class TestScatterOp1(OpTest):
    """Scatter with overwrite=False where the same row id appears twice."""

    def setUp(self):
        self.op_type = "scatter"
        self.python_api = paddle.scatter
        self.public_python_api = paddle.scatter
        self.prim_op_type = "prim"
        self._set_dtype()
        target_dtype = "float16" if self.dtype == np.float16 else "float32"
        ref_np = np.ones((3, 3)).astype(target_dtype)
        # Row 1 is targeted twice so that accumulation is exercised.
        index_np = np.array([1, 1]).astype("int32")
        updates_np = np.random.random((2, 3)).astype(target_dtype)

        # Reference for overwrite=False: indexed rows are cleared, then every
        # update aimed at a row is added into it. np.add.at is unbuffered, so
        # the duplicated id accumulates instead of keeping only the last write.
        output_np = np.copy(ref_np)
        output_np[index_np] = 0
        np.add.at(output_np, index_np, updates_np)

        if self.dtype == np.uint16:
            ref_np = convert_float_to_uint16(ref_np)
            updates_np = convert_float_to_uint16(updates_np)
            output_np = convert_float_to_uint16(output_np)
        self.attrs = {'overwrite': False}
        self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np}
        self.outputs = {'Out': output_np}

    def _set_dtype(self):
        """Hook for subclasses: select the dtype under test."""
        self.dtype = np.float32

    def test_check_output(self):
        self.check_output(check_prim=True)

    def test_check_grad(self):
        self.check_grad(["X", "Updates"], "Out", check_prim=True)
182 183


184 185 186 187 188 189 190 191 192 193 194 195 196
class TestScatterFP16Op1(TestScatterOp1):
    """TestScatterOp1 computed with float16 tensors."""

    def _set_dtype(self):
        # setUp derives the numpy dtype of X/Updates from this value.
        self.dtype = np.float16


@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not compiled with CUDA or does not support bfloat16",
)
class TestScatterBF16Op1(TestScatterOp1):
    """bfloat16 (stored as np.uint16) variant of TestScatterOp1; CUDA only."""

    def _set_dtype(self):
        self.dtype = np.uint16
        # CINN is switched off for the bfloat16 variant.
        self.enable_cinn = False

    def test_check_output(self):
        # The skipIf above already guarantees a CUDA build here.
        self.check_output_with_place(core.CUDAPlace(0), check_prim=True)

    def test_check_grad(self):
        self.check_grad_with_place(
            core.CUDAPlace(0),
            ['X', 'Updates'],
            'Out',
            check_prim=True,
        )
213 214


215 216 217
@unittest.skipIf(
    not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestScatterOp2(OpTest):
    """CUDA-only scatter of two distinct rows (int32 ids) into a (3, 3) tensor."""

    def _set_dtype(self):
        """Hook for subclasses: select the dtype under test."""
        self.dtype = np.float32

    def setUp(self):
        self.op_type = "scatter"
        self.python_api = paddle.scatter
        self.public_python_api = paddle.scatter
        self.prim_op_type = "prim"
        self._set_dtype()

        np_dtype = "float16" if self.dtype == np.float16 else "float32"
        x = np.ones((3, 3)).astype(np_dtype)
        ids = np.array([1, 2]).astype("int32")
        upd = np.random.random((2, 3)).astype(np_dtype)

        expected = x.copy()
        expected[ids] = upd

        if self.dtype == np.uint16:
            x, upd, expected = (
                convert_float_to_uint16(arr) for arr in (x, upd, expected)
            )

        self.inputs = {'X': x, 'Ids': ids, 'Updates': upd}
        self.outputs = {'Out': expected}

    def test_check_output(self):
        if not core.is_compiled_with_cuda():
            return
        self.check_output_with_place(
            core.CUDAPlace(0), atol=1e-3, check_prim=True
        )

    def test_check_grad(self):
        if not core.is_compiled_with_cuda():
            return
        self.check_grad_with_place(
            core.CUDAPlace(0), ['X', 'Updates'], 'Out', check_prim=True
        )
255 256


257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272
@unittest.skipIf(
    not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestScatterFP16Op2(TestScatterOp2):
    """TestScatterOp2 computed with float16 tensors."""

    def _set_dtype(self):
        # setUp derives the numpy dtype of X/Updates from this value.
        self.dtype = np.float16


@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not compiled with CUDA or does not support bfloat16",
)
class TestScatterBF16Op2(TestScatterOp2):
    """bfloat16 (stored as np.uint16) variant of TestScatterOp2; CUDA only."""

    def _set_dtype(self):
        self.dtype = np.uint16
        # CINN is switched off for the bfloat16 variant.
        self.enable_cinn = False
274 275


276 277 278
@unittest.skipIf(
    not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestScatterOp3(OpTest):
    """CUDA-only scatter with overwrite=False and a duplicated row id."""

    def setUp(self):
        self.op_type = "scatter"
        self.python_api = paddle.scatter
        self.public_python_api = paddle.scatter
        self.prim_op_type = "prim"
        self._set_dtype()
        target_dtype = "float16" if self.dtype == np.float16 else "float32"
        ref_np = np.ones((3, 3)).astype(target_dtype)
        # Row 1 is targeted twice so that accumulation is exercised.
        index_np = np.array([1, 1]).astype("int32")
        updates_np = np.random.random((2, 3)).astype(target_dtype)

        # Reference for overwrite=False: indexed rows are cleared, then every
        # update aimed at a row is added into it. np.add.at is unbuffered, so
        # the duplicated id accumulates instead of keeping only the last write.
        output_np = np.copy(ref_np)
        output_np[index_np] = 0
        np.add.at(output_np, index_np, updates_np)

        if self.dtype == np.uint16:
            ref_np = convert_float_to_uint16(ref_np)
            updates_np = convert_float_to_uint16(updates_np)
            output_np = convert_float_to_uint16(output_np)
        self.attrs = {'overwrite': False}
        self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np}
        self.outputs = {'Out': output_np}

    def _set_dtype(self):
        """Hook for subclasses: select the dtype under test."""
        self.dtype = np.float32

    def test_check_output(self):
        if core.is_compiled_with_cuda():
            place = core.CUDAPlace(0)
            self.check_output_with_place(place, atol=1e-3, check_prim=True)

    def test_check_grad(self):
        if core.is_compiled_with_cuda():
            place = core.CUDAPlace(0)
            self.check_grad_with_place(
                place,
                ['X', 'Updates'],
                'Out',
                check_prim=True,
            )
320 321


322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337
@unittest.skipIf(
    not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestScatterFP16Op3(TestScatterOp3):
    """TestScatterOp3 computed with float16 tensors."""

    def _set_dtype(self):
        # setUp derives the numpy dtype of X/Updates from this value.
        self.dtype = np.float16


@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not compiled with CUDA or does not support bfloat16",
)
class TestScatterBF16Op3(TestScatterOp3):
    """bfloat16 (stored as np.uint16) variant of TestScatterOp3; CUDA only."""

    def _set_dtype(self):
        self.dtype = np.uint16
        # CINN is switched off for the bfloat16 variant.
        self.enable_cinn = False
339 340


341 342 343
class TestScatterOp4(OpTest):
    """Scatter two distinct rows into a (3, 3) tensor using int64 ids."""

    def _set_dtype(self):
        """Hook for subclasses: select the dtype under test."""
        self.dtype = np.float32

    def setUp(self):
        self.op_type = "scatter"
        self.python_api = paddle.scatter
        self.public_python_api = paddle.scatter
        self.prim_op_type = "prim"
        self._set_dtype()

        np_dtype = "float16" if self.dtype == np.float16 else "float32"
        x = np.ones((3, 3)).astype(np_dtype)
        ids = np.array([1, 2]).astype("int64")
        upd = np.random.random((2, 3)).astype(np_dtype)

        expected = x.copy()
        expected[ids] = upd

        if self.dtype == np.uint16:
            x, upd, expected = (
                convert_float_to_uint16(arr) for arr in (x, upd, expected)
            )

        self.inputs = {'X': x, 'Ids': ids, 'Updates': upd}
        self.outputs = {'Out': expected}

    def test_check_output(self):
        self.check_output(check_prim=True)

    def test_check_grad(self):
        self.check_grad(['X', 'Updates'], 'Out', check_prim=True)
369 370


371 372 373 374 375 376 377 378 379 380 381 382 383
class TestScatterFP16Op4(TestScatterOp4):
    """TestScatterOp4 computed with float16 tensors."""

    def _set_dtype(self):
        # setUp derives the numpy dtype of X/Updates from this value.
        self.dtype = np.float16


@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not compiled with CUDA or does not support bfloat16",
)
class TestScatterBF16Op4(TestScatterOp4):
    """bfloat16 (stored as np.uint16) variant of TestScatterOp4; CUDA only."""

    def _set_dtype(self):
        self.dtype = np.uint16
        # CINN is switched off for the bfloat16 variant.
        self.enable_cinn = False

    def test_check_output(self):
        # The skipIf above already guarantees a CUDA build here.
        self.check_output_with_place(core.CUDAPlace(0), check_prim=True)

    def test_check_grad(self):
        self.check_grad_with_place(
            core.CUDAPlace(0),
            ['X', 'Updates'],
            'Out',
            check_prim=True,
        )
400 401


402 403 404
@unittest.skipIf(
    not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestScatterOp5(OpTest):
    """CUDA-only scatter of two distinct rows (int64 ids) into a (3, 3) tensor."""

    def _set_dtype(self):
        """Hook for subclasses: select the dtype under test."""
        self.dtype = np.float32

    def setUp(self):
        self.op_type = "scatter"
        self.python_api = paddle.scatter
        self.public_python_api = paddle.scatter
        self.prim_op_type = "prim"
        self._set_dtype()

        np_dtype = "float16" if self.dtype == np.float16 else "float32"
        x = np.ones((3, 3)).astype(np_dtype)
        ids = np.array([1, 2]).astype("int64")
        upd = np.random.random((2, 3)).astype(np_dtype)

        expected = x.copy()
        expected[ids] = upd

        if self.dtype == np.uint16:
            x, upd, expected = (
                convert_float_to_uint16(arr) for arr in (x, upd, expected)
            )

        self.inputs = {'X': x, 'Ids': ids, 'Updates': upd}
        self.outputs = {'Out': expected}

    def test_check_output(self):
        if not core.is_compiled_with_cuda():
            return
        self.check_output_with_place(
            core.CUDAPlace(0), atol=1e-3, check_prim=True
        )

    def test_check_grad(self):
        if not core.is_compiled_with_cuda():
            return
        self.check_grad_with_place(
            core.CUDAPlace(0), ['X', 'Updates'], 'Out', check_prim=True
        )
442 443


444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459
@unittest.skipIf(
    not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestScatterFP16Op5(TestScatterOp5):
    """TestScatterOp5 computed with float16 tensors."""

    def _set_dtype(self):
        # setUp derives the numpy dtype of X/Updates from this value.
        self.dtype = np.float16


@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not compiled with CUDA or does not support bfloat16",
)
class TestScatterBF16Op5(TestScatterOp5):
    """bfloat16 (stored as np.uint16) variant of TestScatterOp5; CUDA only."""

    def _set_dtype(self):
        self.dtype = np.uint16
        # CINN is switched off for the bfloat16 variant.
        self.enable_cinn = False
461 462


463 464 465 466
class TestScatterOp6(OpTest):
    """Scatter into a (3, 50) tensor where Ids is a (2, 1) column of row ids."""

    def _set_dtype(self):
        """Hook for subclasses: select the dtype under test."""
        self.dtype = np.float32

    def setUp(self):
        self.op_type = "scatter"
        self.python_api = paddle.scatter
        self.public_python_api = paddle.scatter
        self.prim_op_type = "prim"
        self.enable_cinn = False
        self._set_dtype()

        np_dtype = "float16" if self.dtype == np.float16 else "float32"
        x = np.ones((3, 50)).astype(np_dtype)
        ids = np.array([[1], [2]]).astype("int32")
        upd = np.random.random((2, 50)).astype(np_dtype)

        # The reference result indexes with the flattened ids, i.e. rows [1, 2].
        expected = x.copy()
        expected[ids.ravel()] = upd

        if self.dtype == np.uint16:
            x, upd, expected = (
                convert_float_to_uint16(arr) for arr in (x, upd, expected)
            )

        self.inputs = {'X': x, 'Ids': ids, 'Updates': upd}
        self.outputs = {'Out': expected}

    def test_check_output(self):
        self.check_output(check_prim=True)

    def test_check_grad(self):
        self.check_grad(['X', 'Updates'], 'Out', check_prim=True)
492 493


494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510
class TestScatterFP16Op6(TestScatterOp6):
    """TestScatterOp6 computed with float16 tensors."""

    def _set_dtype(self):
        # setUp derives the numpy dtype of X/Updates from this value.
        self.dtype = np.float16


@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not compiled with CUDA or does not support bfloat16",
)
class TestScatterBF16Op6(TestScatterOp6):
    """bfloat16 (stored as np.uint16) variant of TestScatterOp6; CUDA only.

    The base setUp already disables CINN, so only the dtype is overridden.
    """

    def _set_dtype(self):
        self.dtype = np.uint16

    def test_check_output(self):
        # The skipIf above already guarantees a CUDA build here.
        self.check_output_with_place(core.CUDAPlace(0), check_prim=True)

    def test_check_grad(self):
        self.check_grad_with_place(
            core.CUDAPlace(0),
            ['X', 'Updates'],
            'Out',
            check_prim=True,
        )
522 523


S
ShenLiang 已提交
524 525 526 527 528
class TestScatterAPI(unittest.TestCase):
    """End-to-end checks of the scatter Python API in static and dygraph mode.

    Subclasses override ``executed_api`` so that the checks going through
    ``self.scatter`` run against a different entry point.
    """

    def setUp(self):
        self.places = [fluid.CPUPlace()]
        if core.is_compiled_with_cuda():
            self.places.append(fluid.CUDAPlace(0))
        self.executed_api()

    def executed_api(self):
        """Hook: bind the scatter implementation under test."""
        self.scatter = paddle.scatter

    def check_static_result(self, place):
        """Build and run a static program on ``place`` with overwrite=False."""
        with fluid.program_guard(fluid.Program(), fluid.Program()):
            x_var = paddle.static.data(
                name="input", shape=[3, 2], dtype="float64"
            )
            index = paddle.static.data(name="index", shape=[4], dtype="int64")
            updates = paddle.static.data(
                name="updates", shape=[4, 2], dtype="float64"
            )
            result = self.scatter(x_var, index, updates, False)

            input_data = np.array([[1, 1], [2, 2], [3, 3]]).astype(np.float64)
            index_data = np.array([2, 1, 0, 1]).astype(np.int64)
            updates_data = np.array([[1, 1], [2, 2], [3, 3], [4, 4]]).astype(
                np.float64
            )

            exe = fluid.Executor(place)
            fetches = exe.run(
                fluid.default_main_program(),
                feed={
                    "input": input_data,
                    "index": index_data,
                    "updates": updates_data,
                },
                fetch_list=[result],
            )
            # Row 1 is hit by updates 1 and 3, so it becomes [2,2] + [4,4].
            np.testing.assert_array_equal(
                fetches[0], np.array([[3.0, 3.0], [6.0, 6.0], [1.0, 1.0]])
            )

    def test_static(self):
        for place in self.places:
            self.check_static_result(place=place)

    def test_dygraph(self):
        for place in self.places:
            with fluid.dygraph.guard(place):
                x_data = np.array([[1, 1], [2, 2], [3, 3]]).astype(np.float64)
                index_data = np.array([2, 1, 0, 1]).astype(np.int64)
                updates_data = np.array(
                    [[1, 1], [2, 2], [3, 3], [4, 4]]
                ).astype(np.float64)

                x = fluid.dygraph.to_variable(x_data)
                index = fluid.dygraph.to_variable(index_data)
                updates = fluid.dygraph.to_variable(updates_data)

                output1 = self.scatter(x, index, updates, overwrite=False)
                np.testing.assert_array_equal(
                    output1.numpy(),
                    np.array([[3.0, 3.0], [6.0, 6.0], [1.0, 1.0]]),
                )

    def test_large_data(self):
        """Dygraph and static results must agree on a ~10.7M-row scatter."""
        if os.name == "nt" or not paddle.is_compiled_with_cuda():
            # Report as skipped rather than silently passing.
            self.skipTest("large-data scatter needs CUDA and a non-Windows OS")

        x = np.random.rand(183826, 256).astype("float32")
        index = np.ones(10759233, dtype="int64")
        updates = np.ones(shape=[10759233, 256], dtype="float32")

        def test_dygraph():
            with fluid.dygraph.guard():
                gpu_out = paddle.scatter(
                    paddle.to_tensor(x),
                    paddle.to_tensor(index),
                    paddle.to_tensor(updates),
                )
                return gpu_out.numpy()

        @switch_to_static_graph
        def test_static_graph():
            with paddle.static.program_guard(
                paddle.static.Program(), paddle.static.Program()
            ):
                x_t = paddle.static.data(name="x", dtype=x.dtype, shape=x.shape)
                index_t = paddle.static.data(
                    name="index", dtype=index.dtype, shape=index.shape
                )
                updates_t = paddle.static.data(
                    name="updates", dtype=updates.dtype, shape=updates.shape
                )
                out_t = paddle.scatter(x_t, index_t, updates_t)
                feed = {
                    x_t.name: x,
                    index_t.name: index,
                    updates_t.name: updates,
                }
                fetch = [out_t]

                gpu_exe = paddle.static.Executor(paddle.CUDAPlace(0))
                gpu_value = gpu_exe.run(feed=feed, fetch_list=fetch)[0]
                return gpu_value

        np.testing.assert_array_equal(test_dygraph(), test_static_graph())
Z
Zeng Jinle 已提交
636

S
ShenLiang 已提交
637

638 639 640
@unittest.skipIf(
    not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestScatterOpFp16(OpTest):
    """float16 dygraph scatter whose gradients are checked by hand."""

    def setUp(self):
        self.__class__.op_type = "scatter"
        self.python_api = paddle.scatter
        # Gradients are computed by hand in test_scatter_fp16 below.
        self.__class__.no_need_check_grad = True
        self.x_type = 'float16'
        self.x_np = np.ones((3, 3)).astype(self.x_type)
        self.index_np = np.array([1, 2]).astype("int32")
        self.updates_np = np.random.random((2, 3)).astype(self.x_type)
        self.output_np = np.copy(self.x_np)
        self.output_np[self.index_np] = self.updates_np
        self.dout_np = np.random.random((3, 3)).astype(self.x_type)

        # Reference dX: the upstream gradient with the overwritten rows zeroed.
        self.ref_dx = np.copy(self.dout_np)
        self.ref_dx[self.index_np] = 0

    def compute_ref_grad_updates(self):
        """Reference dUpdates: the upstream gradient rows gathered at Ids."""
        ref_grad_updates = paddle.gather(
            paddle.to_tensor(self.dout_np), paddle.to_tensor(self.index_np)
        )
        return ref_grad_updates

    def test_scatter_fp16(self):
        # Remember the caller's mode so it can be restored afterwards;
        # leaving dygraph switched on could break later tests that build
        # static programs.
        was_static = not paddle.in_dynamic_mode()
        paddle.disable_static(place=paddle.CUDAPlace(0))
        try:
            x_tensor = paddle.to_tensor(self.x_np, stop_gradient=False)
            index_tensor = paddle.to_tensor(self.index_np)
            updates_tensor = paddle.to_tensor(
                self.updates_np, stop_gradient=False
            )
            out_tensor = paddle.scatter(x_tensor, index_tensor, updates_tensor)
            paddle.autograd.backward(
                [out_tensor],
                [paddle.to_tensor(self.dout_np)],
                retain_graph=True,
            )
            ref_grad_updates = self.compute_ref_grad_updates()
            np.testing.assert_allclose(
                ref_grad_updates.numpy(False),
                updates_tensor.grad.numpy(False),
                rtol=1e-5,
                atol=1e-5,
            )
            np.testing.assert_allclose(
                self.ref_dx, x_tensor.grad.numpy(False), rtol=1e-5, atol=1e-5
            )
        finally:
            if was_static:
                paddle.enable_static()
L
Li Min 已提交
685 686


687 688 689 690 691
class TestScatterInplaceAPI(TestScatterAPI):
    """Re-run the TestScatterAPI checks that use ``self.scatter`` in place."""

    def executed_api(self):
        # Swap in the in-place variant for every self.scatter call.
        self.scatter = paddle.scatter_


692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717
@unittest.skipIf(core.is_compiled_with_cuda(), "CUDA will not throw exception")
class TestScatterError(unittest.TestCase):
    """Out-of-range ids must raise IndexError (CPU builds only)."""

    def test_scatter_index(self):
        paddle.disable_static()
        # try/finally so a failing assertion cannot leave dygraph mode on.
        try:
            x = paddle.to_tensor([[1, 1], [2, 2], [3, 3]], dtype='float32')
            updates = paddle.to_tensor(
                [[1, 1], [2, 2], [3, 3], [4, 4]], dtype='float32'
            )

            # Negative id.
            with self.assertRaises(IndexError):
                index = paddle.to_tensor([2, 1, -1, 1], dtype='int64')
                paddle.scatter(x, index, updates)

            # Id 5 is past the last row of x (3 rows).
            with self.assertRaises(IndexError):
                index = paddle.to_tensor([2, 1, 5, 1], dtype='int64')
                paddle.scatter(x, index, updates)
        finally:
            paddle.enable_static()


Z
zchen0211 已提交
718
if __name__ == "__main__":
    # Switch the whole suite to static-graph mode before running it.
    paddle.enable_static()
    unittest.main()