#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import numpy as np
from eager_op_test import OpTest, convert_float_to_uint16

import paddle
from paddle import fluid
from paddle.fluid.dygraph.base import switch_to_static_graph
from paddle.framework import core


def gather_numpy(x, index, axis):
    """NumPy reference: select the `index` slices of `x` along `axis`."""
    rolled = np.swapaxes(x, 0, axis)
    picked = rolled[index, ...]
    return np.swapaxes(picked, 0, axis)


Q
qijun 已提交
33
class TestGatherOp(OpTest):
    """Base test for the "gather" op: 2-D input gathered by an int32 index."""

    def setUp(self):
        self.op_type = "gather"
        self.prim_op_type = "prim"
        self.python_api = paddle.gather
        self.public_python_api = paddle.gather
        self.config()
        self.init_inputs_and_outputs()
        self.if_enable_cinn()

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out', check_prim=True)

    def config(self):
        """Configure shapes/dtypes; subclasses override for other cases."""
        self.x_shape = (10, 20)
        self.config_dtype()
        self.index = [1, 3, 5]
        self.index_type = "int32"

    def config_dtype(self):
        self.x_type = "float64"

    def init_inputs_and_outputs(self):
        x_np = np.random.random(self.x_shape).astype(self.x_type)
        index_np = np.array(self.index).astype(self.index_type)
        self.inputs = {'X': x_np, 'Index': index_np}
        # Reference result: plain NumPy fancy indexing along axis 0.
        self.outputs = {'Out': x_np[index_np]}

    def if_enable_cinn(self):
        # Subclasses flip enable_cinn off where CINN cannot run the case.
        pass

72

73 74 75 76 77 78 79 80 81 82 83 84 85 86
class TestGatherOp_ZeroDim(TestGatherOp):
    def config(self):
        """
        Scalar (0-D) index gathered from a 1-D input.
        """
        self.x_shape = 100
        self.config_dtype()
        self.index = 2
        self.index_type = "int32"

    def if_enable_cinn(self):
        # CINN is turned off for this case — presumably the 0-D index/output
        # shape is unsupported there; verify before re-enabling.
        self.enable_cinn = False


87 88 89 90
class TestGatherOpFP16(TestGatherOp):
    """TestGatherOp with float16 input data."""

    def config_dtype(self):
        self.x_type = "float16"

W
whs 已提交
91

92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124
@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or core.cudnn_version() < 8100
    or paddle.device.cuda.get_device_capability()[0] < 8,
    "only support compiled with CUDA and cudnn version need larger than 8.1.0 and device's compute capability is at least 8.0",
)
class TestGatherOpBFP16(TestGatherOp):
    """Gather test for bfloat16 input (packed as uint16); CUDA-only."""

    def config_dtype(self):
        # Data is generated in float32, then converted to bfloat16 bits.
        self.x_type = "float32"
        self.dtype = np.uint16

    def init_inputs_and_outputs(self):
        xnp = np.random.random(self.x_shape).astype(self.x_type)
        self.inputs = {
            'X': convert_float_to_uint16(xnp),
            'Index': np.array(self.index).astype(self.index_type),
        }
        # Expected output: gather in float32 first, then pack to bfloat16.
        self.outputs = {
            'Out': convert_float_to_uint16(xnp[self.inputs["Index"]])
        }

    def if_enable_cinn(self):
        # CINN disabled for bf16 — presumably unsupported; verify.
        self.enable_cinn = False

    def test_check_output(self):
        self.check_output_with_place(place=paddle.CUDAPlace(0))

    def test_check_grad(self):
        self.check_grad_with_place(
            paddle.CUDAPlace(0), ['X'], 'Out', check_prim=True
        )


W
whs 已提交
125 126
class TestCase1(TestGatherOp):
    def config(self):
        """Gather from a 1-D input with an int32 index."""
        self.config_dtype()
        self.x_shape = 100
        self.index_type = "int32"
        self.index = [1, 3, 5]

    def config_dtype(self):
        self.x_type = "float64"


class TestCase1FP16(TestCase1):
    """TestCase1 with float16 input data."""

    def config_dtype(self):
        self.x_type = "float16"

143

144 145 146 147 148 149 150 151
class TestCase1BFP16(TestGatherOpBFP16):
    """bfloat16 gather from a 1-D input with an int32 index."""

    def config(self):
        self.x_shape = 100
        self.config_dtype()
        self.index = [1, 3, 5]
        self.index_type = "int32"


152 153 154 155 156
class TestCase2(TestGatherOp):
    def config(self):
        """Gather from a 1-D input with an int64 index."""
        self.config_dtype()
        self.x_shape = 100
        self.index_type = "int64"
        self.index = [1, 3, 5]

    def config_dtype(self):
        self.x_type = "float64"


class TestCase2FP16(TestCase2):
    """TestCase2 with float16 input data."""

    def config_dtype(self):
        self.x_type = "float16"

170

171 172 173 174 175 176 177 178
class TestCase2BFP16(TestGatherOpBFP16):
    """bfloat16 gather from a 1-D input with an int64 index."""

    def config(self):
        self.x_shape = 100
        self.config_dtype()
        self.index = [1, 3, 5]
        self.index_type = "int64"


179 180 181 182 183 184
class TestCase3(TestGatherOp):
    def config(self):
        """Gather from a 2-D input with an int64 index."""
        self.config_dtype()
        self.x_shape = (10, 20)
        self.index_type = "int64"
        self.index = [1, 3, 5]

    def config_dtype(self):
        self.x_type = "float64"


class TestCase3Fp16(TestCase3):
    """TestCase3 with float16 input data."""

    def config_dtype(self):
        self.x_type = "float16"

Z
zchen0211 已提交
197

198 199 200 201 202 203 204 205
class TestCase3BFP16(TestGatherOpBFP16):
    """bfloat16 gather from a 2-D input with an int64 index."""

    def config(self):
        self.x_shape = (10, 20)
        self.config_dtype()
        self.index = [1, 3, 5]
        self.index_type = "int64"


206 207 208 209
class TestCase4(TestGatherOp):
    def config(self):
        """Duplicate index entries with overwrite=False."""
        self.attrs = {'overwrite': False}
        self.config_dtype()
        self.x_shape = (10, 20)
        self.index_type = "int32"
        self.index = [1, 1]

    def config_dtype(self):
        self.x_type = "float64"


class TestCase4FP16(TestCase4):
    """TestCase4 with float16 input data."""

    def config_dtype(self):
        self.x_type = "float16"

222

223 224 225 226 227 228 229 230 231
class TestCase4BFP16(TestGatherOpBFP16):
    """bfloat16 variant of TestCase4: duplicate indices, overwrite=False."""

    def config(self):
        self.x_shape = (10, 20)
        self.attrs = {'overwrite': False}
        self.config_dtype()
        self.index = [1, 1]
        self.index_type = "int32"


232 233 234 235
class TestCase5(TestGatherOp):
    def config(self):
        """Mixed duplicate/distinct indices with overwrite=False."""
        self.attrs = {'overwrite': False}
        self.config_dtype()
        self.x_shape = (10, 20)
        self.index_type = "int32"
        self.index = [1, 1, 3]

    def config_dtype(self):
        self.x_type = "float64"


244 245 246 247 248 249 250 251 252
class TestCase5BFP16(TestGatherOpBFP16):
    """bfloat16 variant of TestCase5: mixed indices, overwrite=False."""

    def config(self):
        self.x_shape = (10, 20)
        self.attrs = {'overwrite': False}
        self.config_dtype()
        # Fix: mirror TestCase5's index. The previous value [1, 1] exactly
        # duplicated TestCase4BFP16's config (apparent copy-paste slip),
        # so the TestCase5 pattern was never exercised in bf16.
        self.index = [1, 1, 3]
        self.index_type = "int32"


253 254 255 256
class TestCase5FP16(TestCase5):
    """TestCase5 with float16 input data."""

    def config_dtype(self):
        self.x_type = "float16"

257 258 259 260 261

class TestCase6(TestGatherOp):
    def config(self):
        """Distinct index entries with overwrite=True."""
        self.attrs = {'overwrite': True}
        self.config_dtype()
        self.x_shape = (10, 20)
        self.index_type = "int32"
        self.index = [1, 3]

    def config_dtype(self):
        self.x_type = "float64"


class TestCase6FP16(TestCase6):
    """TestCase6 with float16 input data."""

    def config_dtype(self):
        self.x_type = "float16"

274

275 276 277 278 279 280 281 282 283
class TestCase6BFP16(TestGatherOpBFP16):
    """bfloat16 variant of TestCase6: distinct indices, overwrite=True."""

    def config(self):
        self.x_shape = (10, 20)
        self.attrs = {'overwrite': True}
        self.config_dtype()
        self.index = [1, 3]
        self.index_type = "int32"


284 285 286
class TestGatherBF16Op(OpTest):
    """Gather with bfloat16 (uint16-packed) input and an explicit Axis input."""

    def setUp(self):
        self.op_type = "gather"
        self.python_api = paddle.gather
        self.dtype = np.uint16
        self.config()
        x_np = np.random.random(self.x_shape).astype(np.float32)
        index_np = np.array(self.index).astype(self.index_type)
        axis_np = np.array(self.axis).astype(self.axis_type)
        self.inputs = {
            'X': convert_float_to_uint16(x_np),
            'Index': index_np,
            'Axis': axis_np,
        }
        # Reference output is computed directly on the packed uint16 data.
        self.outputs = {
            'Out': gather_numpy(self.inputs['X'], index_np, axis_np[0])
        }

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        # Large numeric delta: bf16 carries very little mantissa precision.
        self.check_grad(['X'], 'Out', numeric_grad_delta=0.5)

    def config(self):
        """3-D input gathered along axis 1 with an int32 index."""
        self.x_shape = (3, 88, 3)
        self.index = [1, 3, 5]
        self.index_type = "int32"
        self.axis = [1]
        self.axis_type = "int32"


318 319 320
class TestGatherOp1(OpTest):
    """Test "gather" with an explicit Axis input tensor."""

    def setUp(self):
        self.op_type = "gather"
        self.python_api = paddle.gather
        self.config()
        xnp = np.random.random(self.x_shape).astype(self.x_type)
        # Fix: cast the Axis tensor with axis_type, not index_type.
        # Subclasses such as TestGatherOp2 declare index_type="int64" but
        # axis_type="int32"; the old code silently fed an int64 Axis tensor,
        # so the declared axis_type was never honored (TestGatherBF16Op
        # already does this correctly).
        axis_np = np.array(self.axis).astype(self.axis_type)
        index_np = np.array(self.index).astype(self.index_type)
        out = gather_numpy(xnp, index_np, axis_np[0])
        self.inputs = {'X': xnp, 'Index': index_np, 'Axis': axis_np}
        self.outputs = {'Out': out}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out')

    def config(self):
        """3-D input gathered along axis 1 with an int32 index."""
        self.x_shape = (3, 88, 3)
        self.config_dtype()
        self.index = [1, 3, 5]
        self.index_type = "int32"
        self.axis = [1]
        self.axis_type = "int32"

    def config_dtype(self):
        self.x_type = "float64"


class TestGatherOp1FP16(TestGatherOp1):
    """TestGatherOp1 with float16 input data."""

    def config_dtype(self):
        self.x_type = "float16"

355 356 357 358 359 360 361

class TestGatherOp2(TestGatherOp1):
    def config(self):
        """3-D input, int64 index, gather along axis 0."""
        self.config_dtype()
        self.x_shape = (10, 88, 10)
        self.index_type = "int64"
        self.index = [1, 3, 5]
        self.axis_type = "int32"
        self.axis = [0]

    def config_dtype(self):
        self.x_type = "float64"


class TestGatherOp2FP16(TestGatherOp2):
    """TestGatherOp2 with float16 input data."""

    def config_dtype(self):
        self.x_type = "float16"

376 377 378 379 380 381 382

class TestGatherOp3(TestGatherOp1):
    def config(self):
        """3-D input, int64 index, gather along the last axis."""
        self.config_dtype()
        self.x_shape = (10, 88, 10)
        self.index_type = "int64"
        self.index = [1, 3, 5]
        self.axis_type = "int32"
        self.axis = [2]

    def config_dtype(self):
        self.x_type = "float64"


class TestGatherOp3FP16(TestGatherOp3):
    """TestGatherOp3 with float16 input data."""

    def config_dtype(self):
        self.x_type = "float16"

397 398 399 400 401 402 403

class TestGatherOp4(TestGatherOp1):
    def config(self):
        """Many duplicate indices along axis 0 with overwrite=False."""
        self.config_dtype()
        self.x_shape = (3, 100, 10)
        self.index_type = "int64"
        self.index = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
        self.axis_type = "int32"
        self.axis = [0]
        self.attrs = {'overwrite': False}

    def config_dtype(self):
        self.x_type = "float64"


class TestGatherOp4FP16(TestGatherOp4):
    """TestGatherOp4 with float16 input data."""

    def config_dtype(self):
        self.x_type = "float16"

419

420
class API_TestGather(unittest.TestCase):
    def test_out1(self):
        """Gather through the legacy fluid static-graph API."""
        with fluid.program_guard(fluid.Program(), fluid.Program()):
            data1 = paddle.static.data('data1', shape=[-1, 2], dtype='float64')
            data1.desc.set_need_check_feed(False)
            index = paddle.static.data('index', shape=[-1, 1], dtype='int32')
            index.desc.set_need_check_feed(False)
            out = paddle.gather(data1, index)
            exe = fluid.Executor(fluid.CPUPlace())
            feed_x = np.array([[1, 2], [3, 4], [5, 6]])
            feed_index = np.array([1, 2])
            (result,) = exe.run(
                feed={"data1": feed_x, "index": feed_index}, fetch_list=[out]
            )
        np.testing.assert_allclose(
            result, np.array([[3, 4], [5, 6]]), rtol=1e-05
        )

    def test_out2(self):
        """Gather with an explicit axis tensor in the static-graph API."""
        with paddle.static.program_guard(
            paddle.static.Program(), paddle.static.Program()
        ):
            x = paddle.static.data('x', shape=[-1, 2], dtype='float64')
            index = paddle.static.data('index', shape=[-1, 1], dtype='int32')
            axis = paddle.static.data('axis', shape=[1], dtype='int32')
            out = paddle.gather(x, index, axis)
            exe = paddle.static.Executor(paddle.CPUPlace())
            x_np = np.array([[1, 2], [3, 4], [5, 6]]).astype('float64')
            index_np = np.array([1, 1]).astype('int32')
            axis_np = np.array([1]).astype('int32')
            (result,) = exe.run(
                feed={"x": x_np, "index": index_np, 'axis': axis_np},
                fetch_list=[out],
            )
        np.testing.assert_allclose(
            result, gather_numpy(x_np, index_np, axis_np[0]), rtol=1e-05
        )
457

458 459

class API_TestDygraphGather(unittest.TestCase):
    def test_out1(self):
        """Basic dygraph gather along the default axis."""
        paddle.disable_static()
        x_np = np.array([[1, 2], [3, 4], [5, 6]])
        idx_np = np.array([1, 2])
        result = paddle.gather(paddle.to_tensor(x_np), paddle.to_tensor(idx_np))
        np.testing.assert_allclose(
            result.numpy(), np.array([[3, 4], [5, 6]]), rtol=1e-05
        )
        paddle.enable_static()

    def test_out12(self):
        """Dygraph gather with an explicit axis keyword."""
        paddle.disable_static()
        x_np = np.array([[1, 2], [3, 4], [5, 6]])
        idx_np = np.array([1, 2])
        result = paddle.gather(
            paddle.to_tensor(x_np), paddle.to_tensor(idx_np), axis=0
        )
        np.testing.assert_allclose(
            result.numpy(), gather_numpy(x_np, idx_np, axis=0), rtol=1e-05
        )
        paddle.enable_static()

    def test_zero_index(self):
        """An empty index must produce a 0-length result along each axis."""
        paddle.disable_static()
        x = paddle.to_tensor([[1, 2], [3, 4]])
        empty_index = paddle.to_tensor(np.array([]).astype('int64'))
        for axis in range(len(x.shape)):
            out = paddle.gather(x, empty_index, axis)
            want = list(x.shape)
            want[axis] = 0
            self.assertEqual(list(out.shape), want)
        paddle.enable_static()

    def test_large_data(self):
        """Dygraph and static graph must agree on a large gather (GPU only)."""
        if not paddle.is_compiled_with_cuda():
            return

        x = np.random.rand(226862, 256).astype("float32")
        # NOTE(review): the upper bound 22682 is far below x.shape[0]
        # (226862) — looks like a dropped digit, but it is in bounds so it
        # is kept as-is; verify the original intent.
        index = np.random.randint(0, 22682, size=(8859027))

        def run_dygraph():
            with fluid.dygraph.guard():
                out = paddle.gather(
                    paddle.to_tensor(x), paddle.to_tensor(index)
                )
                return out.numpy()

        @switch_to_static_graph
        def run_static_graph():
            with paddle.static.program_guard(
                paddle.static.Program(), paddle.static.Program()
            ):
                x_t = paddle.static.data(name="x", dtype=x.dtype, shape=x.shape)
                index_t = paddle.static.data(
                    name="index", dtype=index.dtype, shape=index.shape
                )
                out_t = paddle.gather(x_t, index_t)
                exe = paddle.static.Executor(paddle.CUDAPlace(0))
                return exe.run(
                    feed={x_t.name: x, index_t.name: index}, fetch_list=[out_t]
                )[0]

        np.testing.assert_array_equal(run_dygraph(), run_static_graph())
527

528 529 530

class TestGathertError(unittest.TestCase):
    # NOTE: the class name keeps its historical spelling ("Gathert") so any
    # external test filters keep matching.

    def test_error1(self):
        """Static graph: bad dtypes for x / index / axis raise TypeError."""
        with paddle.static.program_guard(
            paddle.static.Program(), paddle.static.Program()
        ):
            shape = [8, 9, 6]
            x = paddle.static.data(shape=shape, dtype='int8', name='x')
            axis = paddle.static.data(shape=[1], dtype='float32', name='axis')
            index = paddle.static.data(shape=shape, dtype='int32', name='index')
            index_float = paddle.static.data(
                shape=shape, dtype='float32', name='index_float'
            )

            with self.assertRaises(TypeError):
                paddle.gather(x, index)  # int8 x is rejected
            with self.assertRaises(TypeError):
                paddle.gather(x, index_float)  # float index is rejected
            with self.assertRaises(TypeError):
                paddle.gather(x, index, axis=1.11)  # non-integer axis scalar
            with self.assertRaises(TypeError):
                paddle.gather(x, index, axis=axis)  # float32 axis tensor

    def test_error2(self):
        """Legacy fluid API raises the same dtype errors."""
        with fluid.program_guard(fluid.Program(), fluid.Program()):
            shape = [8, 9, 6]
            x = paddle.static.data(shape=shape, dtype='int8', name='x')
            index = paddle.static.data(shape=shape, dtype='int32', name='mask')
            index_float = paddle.static.data(
                shape=shape, dtype='float32', name='index_float'
            )

            with self.assertRaises(TypeError):
                paddle.gather(x, index)
            with self.assertRaises(TypeError):
                paddle.gather(x, index_float)

    def test_error3(self):
        """Out-of-range axis values raise ValueError."""
        with paddle.static.program_guard(
            paddle.static.Program(), paddle.static.Program()
        ):
            shape = [8, 9, 6]
            x = paddle.static.data(shape=shape, dtype='int32', name='x')
            axis = paddle.static.data(shape=[1], dtype='int32', name='axis')
            index = paddle.static.data(shape=shape, dtype='int32', name='index')
            index_float = paddle.static.data(
                shape=shape, dtype='float32', name='index_float'
            )

            with self.assertRaises(ValueError):
                paddle.gather(x, index, axis=-1)  # below the minimum axis
            with self.assertRaises(ValueError):
                paddle.gather(x, index, axis=512)  # beyond the rank

603

604 605 606 607 608 609 610 611
class TestCheckOutType(unittest.TestCase):
    """gather must preserve the input dtype (int64 in -> int64 out)."""

    def test_out_type(self):
        data = paddle.static.data(shape=[16, 10], dtype='int64', name='x')
        index = paddle.static.data(shape=[4], dtype='int64', name='index')
        out = paddle.gather(data, index)
        self.assertTrue(out.dtype == core.VarDesc.VarType.INT64)


Z
zchen0211 已提交
612
if __name__ == "__main__":
    # OpTest-based cases expect static-graph mode before discovery runs.
    paddle.enable_static()
    unittest.main()