test_gather_op.py 12.8 KB
Newer Older
1
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
D
dzhwinter 已提交
2
#
D
dzhwinter 已提交
3 4 5
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
D
dzhwinter 已提交
6
#
D
dzhwinter 已提交
7
#     http://www.apache.org/licenses/LICENSE-2.0
D
dzhwinter 已提交
8
#
D
dzhwinter 已提交
9 10 11 12 13 14
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15 16
from __future__ import print_function

Z
zchen0211 已提交
17
import unittest
Q
qijun 已提交
18
import numpy as np
19
from op_test import OpTest, convert_float_to_uint16
20 21
import paddle
import paddle.fluid as fluid
22
from paddle.framework import core
23
from paddle.fluid.dygraph.base import switch_to_static_graph
Z
zchen0211 已提交
24 25


26 27 28 29 30 31 32
def gather_numpy(x, index, axis):
    """Reference gather: select entries of ``x`` at ``index`` along ``axis``.

    ``axis`` is swapped to the front, the leading dimension is fancy-indexed
    with ``index``, and the two axes are swapped back into place.
    """
    moved_front = np.swapaxes(x, 0, axis)
    return np.swapaxes(moved_front[index, ...], 0, axis)


Q
qijun 已提交
33
class TestGatherOp(OpTest):

    def setUp(self):
        """Gather from a random ``X`` with a fixed ``Index`` and no ``Axis`` input.

        Subclasses customise the case by overriding :meth:`config`.
        """
        self.op_type = "gather"
        self.python_api = paddle.gather
        self.config()
        x = np.random.random(self.x_shape).astype(self.x_type)
        idx = np.array(self.index).astype(self.index_type)
        self.inputs = {'X': x, 'Index': idx}
        # Expected result: NumPy fancy indexing on the first dimension of X.
        self.outputs = {'Out': x[idx]}

    def test_check_output(self):
        """Compare the forward output (with ``check_eager=True``)."""
        self.check_output(check_eager=True)

    def test_check_grad(self):
        """Check the gradient of ``Out`` w.r.t. ``X`` (with ``check_eager=True``)."""
        self.check_grad(['X'], 'Out', check_eager=True)

    def config(self):
        """Default case: 2-D float64 input gathered with three int32 indices."""
        self.x_shape, self.x_type = (10, 20), "float64"
        self.index, self.index_type = [1, 3, 5], "int32"
W
whs 已提交
60 61 62


class TestCase1(TestGatherOp):

    def config(self):
        """
        For one dimension input
        """
        # ``(100,)`` is a 1-tuple shape; the former ``(100)`` was just the
        # int 100 wrapped in parentheses.
        self.x_shape = (100,)
        self.x_type = "float64"
        self.index = [1, 3, 5]
        self.index_type = "int32"


class TestCase2(TestGatherOp):

    def config(self):
        """
        For int64 index type (with a one dimension input)
        """
        # Written as a real 1-tuple shape rather than the bare int ``(100)``.
        self.x_shape = (100,)
        self.x_type = "float64"
        self.index = [1, 3, 5]
        self.index_type = "int64"


class TestCase3(TestGatherOp):

    def config(self):
        """
        For int64 index type with a multi-dimension (2-D) float64 input
        """
        self.x_shape = (10, 20)
        self.x_type = "float64"
        self.index = [1, 3, 5]
        self.index_type = "int64"
W
whs 已提交
96

Z
zchen0211 已提交
97

98
class TestCase4(TestGatherOp):

    def config(self):
        """
        For a repeated index (the same row twice) with ``overwrite`` False
        """
        self.x_shape = (10, 20)
        self.attrs = {'overwrite': False}
        # Spelled "float64" like every other case ("double" is the same dtype).
        self.x_type = "float64"
        self.index = [1, 1]
        self.index_type = "int32"


class TestCase5(TestGatherOp):

    def config(self):
        """Duplicate index mixed with a distinct one, ``overwrite`` set to False."""
        self.attrs = {'overwrite': False}
        self.x_shape, self.x_type = (10, 20), "float64"
        self.index, self.index_type = [1, 1, 3], "int32"


class TestCase6(TestGatherOp):

    def config(self):
        """Distinct indices with the ``overwrite`` attribute set to True."""
        self.attrs = {'overwrite': True}
        self.x_shape, self.x_type = (10, 20), "float64"
        self.index, self.index_type = [1, 3], "int32"


128
class TestGatherBF16Op(OpTest):

    def setUp(self):
        """Gather bfloat16 data (stored as uint16) along ``self.axis``."""
        self.op_type = "gather"
        self.python_api = paddle.gather
        self.dtype = np.uint16
        self.config()
        x_fp32 = np.random.random(self.x_shape).astype(np.float32)
        axis_arr = np.array(self.axis).astype(self.axis_type)
        index_arr = np.array(self.index).astype(self.index_type)
        x_bf16 = convert_float_to_uint16(x_fp32)
        self.inputs = {'X': x_bf16, 'Index': index_arr, 'Axis': axis_arr}
        # gather_numpy only selects elements, so it is applied directly to
        # the uint16 bit patterns to build the expected output.
        self.outputs = {'Out': gather_numpy(x_bf16, index_arr, axis_arr[0])}

    def test_check_output(self):
        """Compare the forward output (with ``check_eager=True``)."""
        self.check_output(check_eager=True)

    def test_check_grad(self):
        """Gradient check for ``X``; uses a coarse ``numeric_grad_delta`` of 0.5."""
        self.check_grad(['X'], 'Out', numeric_grad_delta=0.5, check_eager=True)

    def config(self):
        """3-D input gathered along axis 1 with int32 indices and axis."""
        self.x_shape = (3, 88, 3)
        self.index, self.index_type = [1, 3, 5], "int32"
        self.axis, self.axis_type = [1], "int32"


163
class TestGatherOp1(OpTest):

    def setUp(self):
        """Gather from a random ``X`` along an explicit ``Axis`` input.

        :meth:`config` supplies ``x_shape``, ``x_type``, ``index``,
        ``index_type``, ``axis`` and ``axis_type``; the expected output is
        computed with ``gather_numpy``.
        """
        self.op_type = "gather"
        self.python_api = paddle.gather
        self.config()
        xnp = np.random.random(self.x_shape).astype(self.x_type)
        # The axis tensor uses its own configured dtype (axis_type), as in
        # TestGatherBF16Op, rather than reusing the index dtype.
        axis_np = np.array(self.axis).astype(self.axis_type)
        index_np = np.array(self.index).astype(self.index_type)
        out = gather_numpy(xnp, index_np, axis_np[0])
        self.inputs = {'X': xnp, 'Index': index_np, 'Axis': axis_np}
        self.outputs = {'Out': out}

    def test_check_output(self):
        """Compare the forward output (with ``check_eager=True``)."""
        self.check_output(check_eager=True)

    def test_check_grad(self):
        """Check the gradient of ``Out`` w.r.t. ``X`` (with ``check_eager=True``)."""
        self.check_grad(['X'], 'Out', check_eager=True)

    def config(self):
        """
        For multi-dimension input: 3-D float64 input gathered along axis 1
        """
        self.x_shape = (3, 88, 3)
        self.x_type = "float64"
        self.index = [1, 3, 5]
        self.index_type = "int32"
        self.axis = [1]
        self.axis_type = "int32"


class TestGatherOp2(TestGatherOp1):

    def config(self):
        """3-D input gathered along axis 0 with int64 indices."""
        self.x_shape, self.x_type = (10, 88, 10), "float64"
        self.index, self.index_type = [1, 3, 5], "int64"
        self.axis, self.axis_type = [0], "int32"


class TestGatherOp3(TestGatherOp1):

    def config(self):
        """3-D input gathered along its last axis (2) with int64 indices."""
        self.x_shape, self.x_type = (10, 88, 10), "float64"
        self.index, self.index_type = [1, 3, 5], "int64"
        self.axis, self.axis_type = [2], "int32"


class TestGatherOp4(TestGatherOp1):

    def config(self):
        """Ten copies of the same index along axis 0, ``overwrite`` False."""
        self.x_shape, self.x_type = (3, 100, 10), "float64"
        # Every position selects slice 1.
        self.index, self.index_type = [1] * 10, "int64"
        self.axis, self.axis_type = [0], "int32"
        self.attrs = {'overwrite': False}
235 236


237
class API_TestGather(unittest.TestCase):

    def test_out1(self):
        """Static-graph ``fluid.layers.gather`` selects rows 1 and 2."""
        with fluid.program_guard(fluid.Program(), fluid.Program()):
            data1 = fluid.layers.data('data1', shape=[-1, 2], dtype='float64')
            index = fluid.layers.data('index', shape=[-1, 1], dtype='int32')
            out = paddle.fluid.layers.gather(data1, index)
            place = fluid.CPUPlace()
            exe = fluid.Executor(place)
            # Named input_np so the builtin ``input`` is not shadowed.
            input_np = np.array([[1, 2], [3, 4], [5, 6]])
            index_np = np.array([1, 2])
            result, = exe.run(feed={
                "data1": input_np,
                "index": index_np
            },
                              fetch_list=[out])
            expected_output = np.array([[3, 4], [5, 6]])
        # assert_allclose reports the mismatch and also checks shapes.
        np.testing.assert_allclose(result, expected_output)

    def test_out2(self):
        """Static-graph ``paddle.gather`` with a fed ``axis`` tensor (axis 1)."""
        with paddle.static.program_guard(paddle.static.Program(),
                                         paddle.static.Program()):
            x = paddle.fluid.data('x', shape=[-1, 2], dtype='float64')
            index = paddle.fluid.data('index', shape=[-1, 1], dtype='int32')
            axis = paddle.fluid.data('axis', shape=[1], dtype='int32')
            out = paddle.gather(x, index, axis)
            place = paddle.CPUPlace()
            exe = paddle.static.Executor(place)
            x_np = np.array([[1, 2], [3, 4], [5, 6]]).astype('float64')
            index_np = np.array([1, 1]).astype('int32')
            axis_np = np.array([1]).astype('int32')
            result, = exe.run(feed={
                "x": x_np,
                "index": index_np,
                'axis': axis_np
            },
                              fetch_list=[out])
            expected_output = gather_numpy(x_np, index_np, axis_np[0])
        np.testing.assert_allclose(result, expected_output)

277 278

class API_TestDygraphGather(unittest.TestCase):

    def test_out1(self):
        """Dygraph ``fluid.layers.gather`` selects rows 1 and 2."""
        paddle.disable_static()
        try:
            input_1 = np.array([[1, 2], [3, 4], [5, 6]])
            index_1 = np.array([1, 2])
            x = paddle.to_tensor(input_1)
            index = paddle.to_tensor(index_1)
            output = paddle.fluid.layers.gather(x, index)
            expected_output = np.array([[3, 4], [5, 6]])
            np.testing.assert_allclose(output.numpy(), expected_output)
        finally:
            # Restore static mode even when an assertion fails, so later
            # tests are not left running in dygraph mode.
            paddle.enable_static()

    def test_out12(self):
        """Dygraph ``paddle.gather`` with ``axis=0`` matches gather_numpy."""
        paddle.disable_static()
        try:
            input_1 = np.array([[1, 2], [3, 4], [5, 6]])
            index_1 = np.array([1, 2])
            x = paddle.to_tensor(input_1)
            index = paddle.to_tensor(index_1)
            output = paddle.gather(x, index, axis=0)
            expected_output = gather_numpy(input_1, index_1, axis=0)
            np.testing.assert_allclose(output.numpy(), expected_output)
        finally:
            paddle.enable_static()

    def test_zero_index(self):
        """An empty index yields a zero-length dimension along each axis."""
        paddle.disable_static()
        try:
            x = paddle.to_tensor([[1, 2], [3, 4]])
            index = paddle.to_tensor(np.array([]).astype('int64'))
            for axis in range(len(x.shape)):
                out = paddle.gather(x, index, axis)
                expected_shape = list(x.shape)
                expected_shape[axis] = 0
                self.assertEqual(list(out.shape), expected_shape)
        finally:
            paddle.enable_static()

    def test_large_data(self):
        """Dygraph and static-graph GPU gathers agree on a large input."""
        if not paddle.is_compiled_with_cuda():
            return

        x = np.random.rand(226862, 256).astype("float32")
        # NOTE(review): indices are drawn from [0, 22682) although x has
        # 226862 rows -- possibly a typo; the indices are still valid.
        index = np.random.randint(0, 22682, size=(8859027))

        def test_dygraph():
            with fluid.dygraph.guard():
                gpu_out = paddle.gather(paddle.to_tensor(x),
                                        paddle.to_tensor(index))
                return gpu_out.numpy()

        @switch_to_static_graph
        def test_static_graph():
            with paddle.static.program_guard(paddle.static.Program(),
                                             paddle.static.Program()):
                x_t = paddle.static.data(name="x", dtype=x.dtype, shape=x.shape)
                index_t = paddle.static.data(name="index",
                                             dtype=index.dtype,
                                             shape=index.shape)
                out_t = paddle.gather(x_t, index_t)
                feed = {x_t.name: x, index_t.name: index}
                fetch = [out_t]

                gpu_exe = paddle.static.Executor(paddle.CUDAPlace(0))
                gpu_value = gpu_exe.run(feed=feed, fetch_list=fetch)[0]
                return gpu_value

        self.assertTrue(np.array_equal(test_dygraph(), test_static_graph()))

346 347

class TestGathertError(unittest.TestCase):

    def test_error1(self):
        """``paddle.gather`` raises TypeError for bad x / index / axis dtypes."""
        with paddle.static.program_guard(paddle.static.Program(),
                                         paddle.static.Program()):
            shape = [8, 9, 6]
            x = paddle.fluid.data(shape=shape, dtype='int8', name='x')
            axis = paddle.fluid.data(shape=[1], dtype='float32', name='axis')
            index = paddle.fluid.data(shape=shape, dtype='int32', name='index')
            index_float = paddle.fluid.data(shape=shape,
                                            dtype='float32',
                                            name='index_float')

            # x is int8 here; the call is expected to be rejected.
            with self.assertRaises(TypeError):
                paddle.gather(x, index)
            # A float32 index tensor.
            with self.assertRaises(TypeError):
                paddle.gather(x, index_float)
            # A float Python scalar as axis.
            with self.assertRaises(TypeError):
                paddle.gather(x, index, axis=1.11)
            # A float32 tensor as axis.
            with self.assertRaises(TypeError):
                paddle.gather(x, index, axis=axis)

    def test_error2(self):
        """``fluid.layers.gather`` raises TypeError for bad x / index dtypes."""
        with fluid.program_guard(fluid.Program(), fluid.Program()):
            shape = [8, 9, 6]
            x = fluid.data(shape=shape, dtype='int8', name='x')
            index = fluid.data(shape=shape, dtype='int32', name='mask')
            index_float = fluid.data(shape=shape,
                                     dtype='float32',
                                     name='index_float')

            with self.assertRaises(TypeError):
                paddle.fluid.layers.gather(x, index)
            with self.assertRaises(TypeError):
                paddle.fluid.layers.gather(x, index_float)
400 401


402
class TestCheckOutType(unittest.TestCase):

    def test_out_type(self):
        """For int64 input, the static ``paddle.gather`` output is also int64."""
        # Build in fresh programs, like the other static tests in this file,
        # instead of adding 'x'/'index' to the global default program.
        with paddle.static.program_guard(paddle.static.Program(),
                                         paddle.static.Program()):
            data = paddle.static.data(shape=[16, 10], dtype='int64', name='x')
            index = paddle.static.data(shape=[4], dtype='int64', name='index')
            out = paddle.gather(data, index)
            self.assertEqual(out.dtype, core.VarDesc.VarType.INT64)


Z
zchen0211 已提交
411
if __name__ == "__main__":
    # Tests in this file run in static-graph mode by default; the dygraph
    # tests switch modes themselves.
    paddle.enable_static()
    unittest.main()