#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import unittest

import numpy as np
from eager_op_test import OpTest, convert_float_to_uint16

import paddle
from paddle import fluid
from paddle.fluid import core
from paddle.fluid.dygraph.base import switch_to_static_graph


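# Base case: scatter with the default (overwrite) semantics on a (3, 50) float32
# input; each row of X addressed by Ids is replaced with the matching row of Updates.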
class TestScatterOp(OpTest):
    def setUp(self):
        self.op_type = "scatter"
        self.python_api = paddle.scatter
        self.public_python_api = paddle.scatter
        self.prim_op_type = "prim"
        self._set_dtype()
        self.if_enable_cinn()
        target_dtype = "float16" if self.dtype == np.float16 else "float32"
        ref_np = np.ones((3, 50)).astype(target_dtype)
        index_np = np.array([1, 2]).astype("int32")
        updates_np = np.random.random((2, 50)).astype(target_dtype)
        output_np = np.copy(ref_np)
        output_np[index_np] = updates_np
        if self.dtype == np.uint16:
            ref_np = convert_float_to_uint16(ref_np)
            updates_np = convert_float_to_uint16(updates_np)
            output_np = convert_float_to_uint16(output_np)
        self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np}
        self.outputs = {'Out': output_np}

    def if_enable_cinn(self):
        pass

    def _set_dtype(self):
        self.dtype = np.float32

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(["X", "Updates"], "Out", check_prim=True)


class TestScatterFP16Op(TestScatterOp):
    def _set_dtype(self):
        self.dtype = np.float16


@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not complied with CUDA and not support the bfloat16",
)
class TestScatterBF16Op(TestScatterOp):
    def _set_dtype(self):
        self.dtype = np.uint16

    def if_enable_cinn(self):
        self.enable_cinn = False

    def test_check_output(self):
        if core.is_compiled_with_cuda():
            place = core.CUDAPlace(0)
            self.check_output_with_place(place)

    def test_check_grad(self):
        if core.is_compiled_with_cuda():
            place = core.CUDAPlace(0)
            self.check_grad_with_place(
                place,
                ['X', 'Updates'],
                'Out',
                check_prim=True,
            )


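# Same as the base case, but with the 'overwrite' attribute set explicitly to True.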
class TestScatterOp0(OpTest):
    def setUp(self):
        self.op_type = "scatter"
        self.python_api = paddle.scatter
        self.public_python_api = paddle.scatter
        self.prim_op_type = "prim"
        self.if_enable_cinn()
        self._set_dtype()
        target_dtype = "float16" if self.dtype == np.float16 else "float32"
        ref_np = np.ones((3, 3)).astype(target_dtype)
        index_np = np.array([1, 2]).astype("int32")
        updates_np = np.random.random((2, 3)).astype(target_dtype)
        output_np = np.copy(ref_np)
        output_np[index_np] = updates_np
        if self.dtype == np.uint16:
            ref_np = convert_float_to_uint16(ref_np)
            updates_np = convert_float_to_uint16(updates_np)
            output_np = convert_float_to_uint16(output_np)
        self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np}
        self.attrs = {'overwrite': True}
        self.outputs = {'Out': output_np}

    def if_enable_cinn(self):
        pass

    def _set_dtype(self):
        self.dtype = np.float32

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(["X", "Updates"], "Out", check_prim=True)


class TestScatterFP16Op0(TestScatterOp0):
    def _set_dtype(self):
        self.dtype = np.float16


@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not complied with CUDA and not support the bfloat16",
)
class TestScatterBF16Op0(TestScatterOp0):
    def _set_dtype(self):
        self.dtype = np.uint16

    def if_enable_cinn(self):
        self.enable_cinn = False

    def test_check_output(self):
        if core.is_compiled_with_cuda():
            place = core.CUDAPlace(0)
            self.check_output_with_place(place)

    def test_check_grad(self):
        if core.is_compiled_with_cuda():
            place = core.CUDAPlace(0)
            self.check_grad_with_place(
                place,
                ['X', 'Updates'],
                'Out',
                check_prim=True,
            )


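# overwrite=False with a duplicated index: updates addressed to the same row
# accumulate, so the reference zeroes the indexed rows and then sums the updates.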
class TestScatterOp1(OpTest):
    def setUp(self):
        self.op_type = "scatter"
        self.python_api = paddle.scatter
        self.public_python_api = paddle.scatter
        self.prim_op_type = "prim"
        self._set_dtype()
        self.if_enable_cinn()
        target_dtype = "float16" if self.dtype == np.float16 else "float32"
        ref_np = np.ones((3, 3)).astype(target_dtype)
        zeros_np = np.zeros([2, 3]).astype(target_dtype)
        index_np = np.array([1, 1]).astype("int32")
        updates_np = np.random.random((2, 3)).astype(target_dtype)
        output_np = np.copy(ref_np)
        output_np[index_np] = zeros_np
        for i in range(0, len(index_np)):
            output_np[index_np[i]] += updates_np[i]
        if self.dtype == np.uint16:
            ref_np = convert_float_to_uint16(ref_np)
            updates_np = convert_float_to_uint16(updates_np)
            output_np = convert_float_to_uint16(output_np)
        self.attrs = {'overwrite': False}
        self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np}
        self.outputs = {'Out': output_np}

    def if_enable_cinn(self):
        pass

    def _set_dtype(self):
        self.dtype = np.float32

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(["X", "Updates"], "Out", check_prim=True)


class TestScatterFP16Op1(TestScatterOp1):
    def _set_dtype(self):
        self.dtype = np.float16


@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not complied with CUDA and not support the bfloat16",
)
class TestScatterBF16Op1(TestScatterOp1):
    def _set_dtype(self):
        self.dtype = np.uint16

    def if_enable_cinn(self):
        self.enable_cinn = False

    def test_check_output(self):
        if core.is_compiled_with_cuda():
            place = core.CUDAPlace(0)
            self.check_output_with_place(place)

    def test_check_grad(self):
        if core.is_compiled_with_cuda():
            place = core.CUDAPlace(0)
            self.check_grad_with_place(
                place,
                ['X', 'Updates'],
                'Out',
                check_prim=True,
            )


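# CUDA-only variant of the overwrite case, checked on a GPU place with atol=1e-3.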
@unittest.skipIf(
    not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestScatterOp2(OpTest):
    def setUp(self):
        self.op_type = "scatter"
        self.python_api = paddle.scatter
        self.public_python_api = paddle.scatter
        self.prim_op_type = "prim"
        self._set_dtype()
        self.if_enable_cinn()
        target_dtype = "float16" if self.dtype == np.float16 else "float32"
        ref_np = np.ones((3, 3)).astype(target_dtype)
        index_np = np.array([1, 2]).astype("int32")
        updates_np = np.random.random((2, 3)).astype(target_dtype)
        output_np = np.copy(ref_np)
        output_np[index_np] = updates_np
        if self.dtype == np.uint16:
            ref_np = convert_float_to_uint16(ref_np)
            updates_np = convert_float_to_uint16(updates_np)
            output_np = convert_float_to_uint16(output_np)
        self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np}
        self.outputs = {'Out': output_np}

    def if_enable_cinn(self):
        pass

    def _set_dtype(self):
        self.dtype = np.float32

    def test_check_output(self):
        if core.is_compiled_with_cuda():
            place = core.CUDAPlace(0)
            self.check_output_with_place(place, atol=1e-3)

    def test_check_grad(self):
        if core.is_compiled_with_cuda():
            place = core.CUDAPlace(0)
            self.check_grad_with_place(
                place,
                ['X', 'Updates'],
                'Out',
                check_prim=True,
            )


@unittest.skipIf(
    not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestScatterFP16Op2(TestScatterOp2):
    def _set_dtype(self):
        self.dtype = np.float16


@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not complied with CUDA and not support the bfloat16",
)
class TestScatterBF16Op2(TestScatterOp2):
    def _set_dtype(self):
        self.dtype = np.uint16

    def if_enable_cinn(self):
        self.enable_cinn = False


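# CUDA-only variant of the accumulate (overwrite=False) case with a duplicated index.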
@unittest.skipIf(
    not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestScatterOp3(OpTest):
    def setUp(self):
        self.op_type = "scatter"
        self.python_api = paddle.scatter
        self.public_python_api = paddle.scatter
        self.prim_op_type = "prim"
        self._set_dtype()
        self.if_enable_cinn()
        target_dtype = "float16" if self.dtype == np.float16 else "float32"
        ref_np = np.ones((3, 3)).astype(target_dtype)
        zeros_np = np.zeros([2, 3]).astype(target_dtype)
        index_np = np.array([1, 1]).astype("int32")
        updates_np = np.random.random((2, 3)).astype(target_dtype)
        output_np = np.copy(ref_np)
        output_np[index_np] = zeros_np
        for i in range(0, len(index_np)):
            output_np[index_np[i]] += updates_np[i]
        if self.dtype == np.uint16:
            ref_np = convert_float_to_uint16(ref_np)
            updates_np = convert_float_to_uint16(updates_np)
            output_np = convert_float_to_uint16(output_np)
        self.attrs = {'overwrite': False}
        self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np}
        self.outputs = {'Out': output_np}

    def if_enable_cinn(self):
        pass

    def _set_dtype(self):
        self.dtype = np.float32

    def test_check_output(self):
        if core.is_compiled_with_cuda():
            place = core.CUDAPlace(0)
            self.check_output_with_place(place, atol=1e-3)

    def test_check_grad(self):
        if core.is_compiled_with_cuda():
            place = core.CUDAPlace(0)
            self.check_grad_with_place(
                place,
                ['X', 'Updates'],
                'Out',
                check_prim=True,
            )


@unittest.skipIf(
    not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestScatterFP16Op3(TestScatterOp3):
    def _set_dtype(self):
        self.dtype = np.float16


@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not complied with CUDA and not support the bfloat16",
)
class TestScatterBF16Op3(TestScatterOp3):
    def _set_dtype(self):
        self.dtype = np.uint16

    def if_enable_cinn(self):
        self.enable_cinn = False


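# Overwrite scatter with int64 indices instead of int32.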
class TestScatterOp4(OpTest):
    def setUp(self):
        self.op_type = "scatter"
        self.python_api = paddle.scatter
        self.public_python_api = paddle.scatter
        self.prim_op_type = "prim"
        self._set_dtype()
        self.if_enable_cinn()
        target_dtype = "float16" if self.dtype == np.float16 else "float32"
        ref_np = np.ones((3, 3)).astype(target_dtype)
        index_np = np.array([1, 2]).astype("int64")
        updates_np = np.random.random((2, 3)).astype(target_dtype)
        output_np = np.copy(ref_np)
        output_np[index_np] = updates_np
        if self.dtype == np.uint16:
            ref_np = convert_float_to_uint16(ref_np)
            updates_np = convert_float_to_uint16(updates_np)
            output_np = convert_float_to_uint16(output_np)
        self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np}
        self.outputs = {'Out': output_np}

    def if_enable_cinn(self):
        pass

    def _set_dtype(self):
        self.dtype = np.float32

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X', 'Updates'], 'Out', check_prim=True)


class TestScatterFP16Op4(TestScatterOp4):
    def _set_dtype(self):
        self.dtype = np.float16


@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not complied with CUDA and not support the bfloat16",
)
class TestScatterBF16Op4(TestScatterOp4):
    def _set_dtype(self):
        self.dtype = np.uint16

    def if_enable_cinn(self):
        self.enable_cinn = False

    def test_check_output(self):
        if core.is_compiled_with_cuda():
            place = core.CUDAPlace(0)
            self.check_output_with_place(place)

    def test_check_grad(self):
        if core.is_compiled_with_cuda():
            place = core.CUDAPlace(0)
            self.check_grad_with_place(
                place,
                ['X', 'Updates'],
                'Out',
                check_prim=True,
            )


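# CUDA-only variant of the int64-index case.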
@unittest.skipIf(
    not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestScatterOp5(OpTest):
    def setUp(self):
        self.op_type = "scatter"
        self.python_api = paddle.scatter
        self.public_python_api = paddle.scatter
        self.prim_op_type = "prim"
        self._set_dtype()
        self.if_enable_cinn()
        target_dtype = "float16" if self.dtype == np.float16 else "float32"
        ref_np = np.ones((3, 3)).astype(target_dtype)
        index_np = np.array([1, 2]).astype("int64")
        updates_np = np.random.random((2, 3)).astype(target_dtype)
        output_np = np.copy(ref_np)
        output_np[index_np] = updates_np
        if self.dtype == np.uint16:
            ref_np = convert_float_to_uint16(ref_np)
            updates_np = convert_float_to_uint16(updates_np)
            output_np = convert_float_to_uint16(output_np)
        self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np}
        self.outputs = {'Out': output_np}

    def if_enable_cinn(self):
        pass

    def _set_dtype(self):
        self.dtype = np.float32

    def test_check_output(self):
        if core.is_compiled_with_cuda():
            place = core.CUDAPlace(0)
            self.check_output_with_place(place, atol=1e-3)

    def test_check_grad(self):
        if core.is_compiled_with_cuda():
            place = core.CUDAPlace(0)
            self.check_grad_with_place(
                place,
                ['X', 'Updates'],
                'Out',
                check_prim=True,
            )


@unittest.skipIf(
    not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestScatterFP16Op5(TestScatterOp5):
    def _set_dtype(self):
        self.dtype = np.float16


@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not complied with CUDA and not support the bfloat16",
)
class TestScatterBF16Op5(TestScatterOp5):
    def _set_dtype(self):
        self.dtype = np.uint16

    def if_enable_cinn(self):
        self.enable_cinn = False


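# Index passed as a 2-D column vector of shape [2, 1]; the result must match the
# equivalent 1-D index [1, 2].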
class TestScatterOp6(OpTest):
    def setUp(self):
        self.op_type = "scatter"
        self.python_api = paddle.scatter
        self.public_python_api = paddle.scatter
        self.prim_op_type = "prim"
        self.if_enable_cinn()
        self._set_dtype()
        target_dtype = "float16" if self.dtype == np.float16 else "float32"
        ref_np = np.ones((3, 50)).astype(target_dtype)
        index_np = np.array([[1], [2]]).astype("int32")
        updates_np = np.random.random((2, 50)).astype(target_dtype)
        output_np = np.copy(ref_np)
        output_np[np.array([1, 2]).astype("int32")] = updates_np
        if self.dtype == np.uint16:
            ref_np = convert_float_to_uint16(ref_np)
            updates_np = convert_float_to_uint16(updates_np)
            output_np = convert_float_to_uint16(output_np)
        self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np}
        self.outputs = {'Out': output_np}

    def if_enable_cinn(self):
        pass

    def _set_dtype(self):
        self.dtype = np.float32

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(["X", "Updates"], "Out", check_prim=True)


class TestScatterFP16Op6(TestScatterOp6):
    def _set_dtype(self):
        self.dtype = np.float16


@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not complied with CUDA and not support the bfloat16",
)
class TestScatterBF16Op6(TestScatterOp6):
    def if_enable_cinn(self):
        self.enable_cinn = False

    def _set_dtype(self):
        self.dtype = np.uint16

    def test_check_output(self):
        if core.is_compiled_with_cuda():
            place = core.CUDAPlace(0)
            self.check_output_with_place(place)

    def test_check_grad(self):
        if core.is_compiled_with_cuda():
            place = core.CUDAPlace(0)
            self.check_grad_with_place(
                place,
                ['X', 'Updates'],
                'Out',
                check_prim=True,
            )


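# API-level tests: run paddle.scatter through the static-graph and dygraph front
# ends on CPU (and GPU when available) and compare against hand-computed results.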
class TestScatterAPI(unittest.TestCase):
    def setUp(self):
        self.places = [fluid.CPUPlace()]
        if core.is_compiled_with_cuda():
            self.places.append(fluid.CUDAPlace(0))
        self.executed_api()

    def executed_api(self):
        self.scatter = paddle.scatter

    def check_static_result(self, place):
        with fluid.program_guard(fluid.Program(), fluid.Program()):
            input = paddle.static.data(
                name="input", shape=[3, 2], dtype="float64"
            )
            index = paddle.static.data(name="index", shape=[4], dtype="int64")
            updates = paddle.static.data(
                name="updates", shape=[4, 2], dtype="float64"
            )
            result = self.scatter(input, index, updates, False)

            input_data = np.array([[1, 1], [2, 2], [3, 3]]).astype(np.float64)
            index_data = np.array([2, 1, 0, 1]).astype(np.int64)
            updates_data = np.array([[1, 1], [2, 2], [3, 3], [4, 4]]).astype(
                np.float64
            )

            exe = fluid.Executor(place)
            fetches = exe.run(
                fluid.default_main_program(),
                feed={
                    "input": input_data,
                    "index": index_data,
                    "updates": updates_data,
                },
                fetch_list=[result],
            )
            self.assertEqual(
                (
                    fetches[0] == np.array([[3.0, 3.0], [6.0, 6.0], [1.0, 1.0]])
                ).all(),
                True,
            )

    def test_static(self):
        for place in self.places:
            self.check_static_result(place=place)

    def test_dygraph(self):
        for place in self.places:
            with fluid.dygraph.guard(place):
                x_data = np.array([[1, 1], [2, 2], [3, 3]]).astype(np.float64)
                index_data = np.array([2, 1, 0, 1]).astype(np.int64)
                updates_data = np.array(
                    [[1, 1], [2, 2], [3, 3], [4, 4]]
                ).astype(np.float64)

                x = fluid.dygraph.to_variable(x_data)
                index = fluid.dygraph.to_variable(index_data)
                updates = fluid.dygraph.to_variable(updates_data)

633
                output1 = self.scatter(x, index, updates, overwrite=False)
634 635 636 637 638 639 640
                self.assertEqual(
                    (
                        output1.numpy()
                        == np.array([[3.0, 3.0], [6.0, 6.0], [1.0, 1.0]])
                    ).all(),
                    True,
                )

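    # Consistency check between dygraph and static graph for a large update
    # (10759233 rows); GPU only, skipped on Windows.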
    def test_large_data(self):
        if os.name == "nt" or not paddle.is_compiled_with_cuda():
            return

        x = np.random.rand(183826, 256).astype("float32")
        index = np.ones(10759233, dtype="int64")
        updates = np.ones(shape=[10759233, 256], dtype="float32")

        def test_dygraph():
            with fluid.dygraph.guard():
                gpu_out = paddle.scatter(
                    paddle.to_tensor(x),
                    paddle.to_tensor(index),
                    paddle.to_tensor(updates),
                )
                return gpu_out.numpy()

        @switch_to_static_graph
        def test_static_graph():
            with paddle.static.program_guard(
                paddle.static.Program(), paddle.static.Program()
            ):
                x_t = paddle.static.data(name="x", dtype=x.dtype, shape=x.shape)
                index_t = paddle.static.data(
                    name="index", dtype=index.dtype, shape=index.shape
                )
                updates_t = paddle.static.data(
                    name="updates", dtype=updates.dtype, shape=updates.shape
                )
                out_t = paddle.scatter(x_t, index_t, updates_t)
                feed = {
                    x_t.name: x,
                    index_t.name: index,
                    updates_t.name: updates,
                }
                fetch = [out_t]

                gpu_exe = paddle.static.Executor(paddle.CUDAPlace(0))
                gpu_value = gpu_exe.run(feed=feed, fetch_list=fetch)[0]
                return gpu_value

        np.testing.assert_array_equal(test_dygraph(), test_static_graph())


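# float16 forward plus manually derived gradients: the Updates gradient is a
# gather of dout at Ids, and the X gradient is dout with the indexed rows zeroed.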
@unittest.skipIf(
    not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestScatterOpFp16(OpTest):
    def setUp(self):
        self.__class__.op_type = "scatter"
        self.python_api = paddle.scatter
        # compute grad manually in the following code.
        self.__class__.no_need_check_grad = True
        self.x_type = 'float16'
        self.x_np = np.ones((3, 3)).astype(self.x_type)
        self.index_np = np.array([1, 2]).astype("int32")
        self.updates_np = np.random.random((2, 3)).astype(self.x_type)
        self.output_np = np.copy(self.x_np)
        self.output_np[self.index_np] = self.updates_np
        self.dout_np = np.random.random((3, 3)).astype(self.x_type)

        # compute ref_dx
        self.ref_dx = np.copy(self.dout_np)
        zero_np = np.zeros((2, 3)).astype(self.x_type)
        self.ref_dx[self.index_np] = zero_np

    def compute_ref_grad_updates(self):
709 710 711
        ref_grad_updates = paddle.gather(
            paddle.to_tensor(self.dout_np), paddle.to_tensor(self.index_np)
        )
        return ref_grad_updates

    def test_scatter_fp16(self):
        paddle.disable_static(place=paddle.CUDAPlace(0))
        x_tensor = paddle.to_tensor(self.x_np, stop_gradient=False)
        index_tensor = paddle.to_tensor(self.index_np)
        updates_tensor = paddle.to_tensor(self.updates_np, stop_gradient=False)
        out_tensor = paddle.scatter(x_tensor, index_tensor, updates_tensor)
        paddle.autograd.backward(
            [out_tensor], [paddle.to_tensor(self.dout_np)], retain_graph=True
        )
        ref_grad_updates = self.compute_ref_grad_updates()
        np.testing.assert_allclose(
            ref_grad_updates.numpy(False),
            updates_tensor.grad.numpy(False),
            rtol=1e-5,
            atol=1e-5,
        )
        np.testing.assert_allclose(
            self.ref_dx, x_tensor.grad.numpy(False), rtol=1e-5, atol=1e-5
        )


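# Re-run the API tests with the in-place variant paddle.scatter_.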
class TestScatterInplaceAPI(TestScatterAPI):
    def executed_api(self):
        self.scatter = paddle.scatter_


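# On CPU-only builds, out-of-range indices (negative or >= dim 0 of x) should
# raise IndexError.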
@unittest.skipIf(core.is_compiled_with_cuda(), "CUDA will not throw an exception")
class TestScatterError(unittest.TestCase):
    def test_scatter_index(self):
        paddle.disable_static()
        x = paddle.to_tensor([[1, 1], [2, 2], [3, 3]], dtype='float32')

        def test_neg_index():
            index = paddle.to_tensor([2, 1, -1, 1], dtype='int64')
            updates = paddle.to_tensor(
                [[1, 1], [2, 2], [3, 3], [4, 4]], dtype='float32'
            )
            out = paddle.scatter(x, index, updates)

        self.assertRaises(IndexError, test_neg_index)

        def test_too_big_index():
            index = paddle.to_tensor([2, 1, 5, 1], dtype='int64')
            updates = paddle.to_tensor(
                [[1, 1], [2, 2], [3, 3], [4, 4]], dtype='float32'
            )
            out = paddle.scatter(x, index, updates)

        self.assertRaises(IndexError, test_too_big_index)
        paddle.enable_static()


if __name__ == "__main__":
    paddle.enable_static()
    unittest.main()