test_gather_op.py 12.5 KB
Newer Older
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Z
zchen0211 已提交
15
import unittest
16

Q
qijun 已提交
17
import numpy as np
18
from op_test import OpTest, convert_float_to_uint16
19

20 21
import paddle
import paddle.fluid as fluid
22
from paddle.fluid.dygraph.base import switch_to_static_graph
23
from paddle.framework import core
Z
zchen0211 已提交
24 25


26 27 28 29 30 31 32
def gather_numpy(x, index, axis):
    """NumPy reference for gather: select entries of ``x`` along ``axis``.

    Args:
        x: Array to gather from.
        index: Indices applied to the leading axis after ``axis`` has been
            swapped to the front.
        axis: Axis of ``x`` along which ``index`` selects.

    Returns:
        The gathered array, with the selected axis swapped back to ``axis``.
    """
    # Bring the target axis to the front so leading-axis indexing applies.
    fronted = np.swapaxes(x, 0, axis)
    picked = fronted[index, ...]
    # Undo the swap so the gathered axis ends up where it started.
    return np.swapaxes(picked, 0, axis)


Q
qijun 已提交
33
class TestGatherOp(OpTest):
    """Base gather op test without an axis input.

    Subclasses override ``config`` to vary the input shape, dtypes and
    index list; the expected output is NumPy integer-array indexing of the
    input ``X`` by ``Index``.
    """

    def config(self):
        """
        Default case: two-dimensional float64 input, int32 index.
        """
        self.x_shape, self.x_type = (10, 20), "float64"
        self.index, self.index_type = [1, 3, 5], "int32"

    def setUp(self):
        self.op_type = "gather"
        self.python_api = paddle.gather
        self.config()
        x_data = np.random.random(self.x_shape).astype(self.x_type)
        index_data = np.array(self.index).astype(self.index_type)
        self.inputs = {'X': x_data, 'Index': index_data}
        # Reference result: plain NumPy indexing of X by the index array.
        self.outputs = {'Out': x_data[index_data]}

    def test_check_output(self):
        self.check_output(check_eager=True)

    def test_check_grad(self):
        self.check_grad(['X'], 'Out', check_eager=True)
W
whs 已提交
59 60 61 62


class TestCase1(TestGatherOp):
    def config(self):
        """
        One-dimensional input (100 elements) with an int32 index.
        """
        self.x_shape, self.x_type = 100, "float64"
        self.index, self.index_type = [1, 3, 5], "int32"


class TestCase2(TestGatherOp):
    def config(self):
        """
        One-dimensional input (100 elements) with an int64 index.
        """
        self.x_shape, self.x_type = 100, "float64"
        self.index, self.index_type = [1, 3, 5], "int64"


class TestCase3(TestGatherOp):
    def config(self):
        """
        Two-dimensional input with an int64 index.
        """
        self.x_shape, self.x_type = (10, 20), "float64"
        self.index, self.index_type = [1, 3, 5], "int64"
W
whs 已提交
92

Z
zchen0211 已提交
93

94 95 96 97 98 99 100 101 102 103 104 105 106
class TestCase4(TestGatherOp):
    def config(self):
        """
        Repeated index [1, 1] with 'overwrite' set to False; the input
        dtype is spelled "double".
        """
        self.attrs = {'overwrite': False}
        self.x_shape, self.x_type = (10, 20), "double"
        self.index, self.index_type = [1, 1], "int32"


class TestCase5(TestGatherOp):
    def config(self):
        """
        Partly repeated index [1, 1, 3] with 'overwrite' set to False.
        """
        self.attrs = {'overwrite': False}
        self.x_shape, self.x_type = (10, 20), "float64"
        self.index, self.index_type = [1, 1, 3], "int32"


class TestCase6(TestGatherOp):
    def config(self):
        """
        Distinct index [1, 3] with 'overwrite' set to True.
        """
        self.attrs = {'overwrite': True}
        self.x_shape, self.x_type = (10, 20), "float64"
        self.index, self.index_type = [1, 3], "int32"


121 122 123
class TestGatherBF16Op(OpTest):
    """Gather op test with a uint16 ``X`` and an ``Axis`` input tensor.

    ``X`` is random float32 data passed through ``convert_float_to_uint16``;
    the expected output is ``gather_numpy`` applied to that converted array.
    """

    def config(self):
        """
        Three-dimensional input gathered along axis 1.
        """
        self.x_shape = (3, 88, 3)
        self.index, self.index_type = [1, 3, 5], "int32"
        self.axis, self.axis_type = [1], "int32"

    def setUp(self):
        self.op_type = "gather"
        self.python_api = paddle.gather
        self.dtype = np.uint16
        self.config()
        x_fp32 = np.random.random(self.x_shape).astype(np.float32)
        axis_data = np.array(self.axis).astype(self.axis_type)
        index_data = np.array(self.index).astype(self.index_type)
        x_u16 = convert_float_to_uint16(x_fp32)
        self.inputs = {'X': x_u16, 'Index': index_data, 'Axis': axis_data}
        # The reference gathers from the converted (uint16) input.
        self.outputs = {'Out': gather_numpy(x_u16, index_data, axis_data[0])}

    def test_check_output(self):
        self.check_output(check_eager=True)

    def test_check_grad(self):
        self.check_grad(['X'], 'Out', numeric_grad_delta=0.5, check_eager=True)


155 156 157
class TestGatherOp1(OpTest):
    """Base gather op test that passes the axis as an ``Axis`` input tensor.

    Subclasses override ``config`` to vary the shape, dtypes, index list
    and axis. The expected output is computed with ``gather_numpy``.
    """

    def setUp(self):
        self.op_type = "gather"
        self.python_api = paddle.gather
        self.config()
        xnp = np.random.random(self.x_shape).astype(self.x_type)
        # Cast the axis with its own configured dtype. Previously the index
        # dtype was used, so ``axis_type`` from config() was silently
        # ignored. Fall back to the index dtype for any subclass whose
        # config() does not define ``axis_type``.
        axis_dtype = getattr(self, "axis_type", self.index_type)
        axis_np = np.array(self.axis).astype(axis_dtype)
        index_np = np.array(self.index).astype(self.index_type)
        out = gather_numpy(xnp, index_np, axis_np[0])
        self.inputs = {'X': xnp, 'Index': index_np, 'Axis': axis_np}
        self.outputs = {'Out': out}

    def test_check_output(self):
        self.check_output(check_eager=True)

    def test_check_grad(self):
        self.check_grad(['X'], 'Out', check_eager=True)

    def config(self):
        """
        Three-dimensional float64 input gathered along axis 1 (int32 index
        and axis).
        """
        self.x_shape = (3, 88, 3)
        self.x_type = "float64"
        self.index = [1, 3, 5]
        self.index_type = "int32"
        self.axis = [1]
        self.axis_type = "int32"


class TestGatherOp2(TestGatherOp1):
    def config(self):
        """
        Three-dimensional input gathered along axis 0 with an int64 index.
        """
        self.x_shape, self.x_type = (10, 88, 10), "float64"
        self.index, self.index_type = [1, 3, 5], "int64"
        self.axis, self.axis_type = [0], "int32"


class TestGatherOp3(TestGatherOp1):
    def config(self):
        """
        Three-dimensional input gathered along axis 2 with an int64 index.
        """
        self.x_shape, self.x_type = (10, 88, 10), "float64"
        self.index, self.index_type = [1, 3, 5], "int64"
        self.axis, self.axis_type = [2], "int32"


class TestGatherOp4(TestGatherOp1):
    def config(self):
        """
        Ten repeated indices (all 1) along axis 0, with 'overwrite' set to
        False.
        """
        self.attrs = {'overwrite': False}
        self.x_shape, self.x_type = (3, 100, 10), "float64"
        self.index, self.index_type = [1] * 10, "int64"
        self.axis, self.axis_type = [0], "int32"
223 224


225
class API_TestGather(unittest.TestCase):
    """Static-graph checks of ``paddle.gather`` run on a CPU executor."""

    def test_out1(self):
        """Gather rows 1 and 2 of a 3x2 array without passing an axis."""
        with fluid.program_guard(fluid.Program(), fluid.Program()):
            data1 = paddle.static.data('data1', shape=[-1, 2], dtype='float64')
            data1.desc.set_need_check_feed(False)
            index = paddle.static.data('index', shape=[-1, 1], dtype='int32')
            index.desc.set_need_check_feed(False)
            out = paddle.gather(data1, index)
            exe = fluid.Executor(fluid.CPUPlace())
            feed = {
                "data1": np.array([[1, 2], [3, 4], [5, 6]]),
                "index": np.array([1, 2]),
            }
            (result,) = exe.run(feed=feed, fetch_list=[out])
        expected = np.array([[3, 4], [5, 6]])
        np.testing.assert_allclose(result, expected, rtol=1e-05)

    def test_out2(self):
        """Gather with the axis fed as a tensor; compare to gather_numpy."""
        main_prog = paddle.static.Program()
        startup_prog = paddle.static.Program()
        with paddle.static.program_guard(main_prog, startup_prog):
            x = paddle.fluid.data('x', shape=[-1, 2], dtype='float64')
            index = paddle.fluid.data('index', shape=[-1, 1], dtype='int32')
            axis = paddle.fluid.data('axis', shape=[1], dtype='int32')
            out = paddle.gather(x, index, axis)
            exe = paddle.static.Executor(paddle.CPUPlace())
            x_np = np.array([[1, 2], [3, 4], [5, 6]]).astype('float64')
            index_np = np.array([1, 1]).astype('int32')
            axis_np = np.array([1]).astype('int32')
            feed = {"x": x_np, "index": index_np, 'axis': axis_np}
            (result,) = exe.run(feed=feed, fetch_list=[out])
        expected = gather_numpy(x_np, index_np, axis_np[0])
        np.testing.assert_allclose(result, expected, rtol=1e-05)
262

263 264

class API_TestDygraphGather(unittest.TestCase):
    """Dygraph-mode checks of ``paddle.gather``."""

    def _use_dygraph(self):
        """Switch to dygraph mode and restore static mode when the test ends.

        The restore is registered with ``addCleanup`` so that it also runs
        when an assertion fails; otherwise a failing test would leave the
        process in dygraph mode for whatever test runs next.
        """
        paddle.disable_static()
        self.addCleanup(paddle.enable_static)

    def test_out1(self):
        """Gather rows 1 and 2 of a 3x2 array without passing an axis."""
        self._use_dygraph()
        input_1 = np.array([[1, 2], [3, 4], [5, 6]])
        index_1 = np.array([1, 2])
        x = paddle.to_tensor(input_1)
        index = paddle.to_tensor(index_1)
        output = paddle.gather(x, index)
        output_np = output.numpy()
        expected_output = np.array([[3, 4], [5, 6]])
        np.testing.assert_allclose(output_np, expected_output, rtol=1e-05)

    def test_out12(self):
        """Gather with an explicit ``axis=0``; compare to gather_numpy."""
        self._use_dygraph()
        input_1 = np.array([[1, 2], [3, 4], [5, 6]])
        index_1 = np.array([1, 2])
        x = paddle.to_tensor(input_1)
        index = paddle.to_tensor(index_1)
        output = paddle.gather(x, index, axis=0)
        output_np = output.numpy()
        expected_output = gather_numpy(input_1, index_1, axis=0)
        np.testing.assert_allclose(output_np, expected_output, rtol=1e-05)

    def test_zero_index(self):
        """An empty index gives size 0 on the gathered axis, for every axis."""
        self._use_dygraph()
        x = paddle.to_tensor([[1, 2], [3, 4]])
        index = paddle.to_tensor(np.array([]).astype('int64'))
        for axis in range(len(x.shape)):
            out = paddle.gather(x, index, axis)
            expected_shape = list(x.shape)
            expected_shape[axis] = 0
            self.assertEqual(list(out.shape), expected_shape)

    def test_large_data(self):
        """Compare dygraph and static-graph gather results on large inputs.

        Does nothing when Paddle is not compiled with CUDA.
        """
        if not paddle.is_compiled_with_cuda():
            return

        x = np.random.rand(226862, 256).astype("float32")
        # NOTE(review): the upper bound 22682 is smaller than the 226862
        # rows of x -- confirm whether 226862 was intended.
        index = np.random.randint(0, 22682, size=(8859027))

        def test_dygraph():
            with fluid.dygraph.guard():
                gpu_out = paddle.gather(
                    paddle.to_tensor(x), paddle.to_tensor(index)
                )
                return gpu_out.numpy()

        @switch_to_static_graph
        def test_static_graph():
            with paddle.static.program_guard(
                paddle.static.Program(), paddle.static.Program()
            ):
                x_t = paddle.static.data(name="x", dtype=x.dtype, shape=x.shape)
                index_t = paddle.static.data(
                    name="index", dtype=index.dtype, shape=index.shape
                )
                out_t = paddle.gather(x_t, index_t)
                feed = {x_t.name: x, index_t.name: index}
                fetch = [out_t]

                gpu_exe = paddle.static.Executor(paddle.CUDAPlace(0))
                gpu_value = gpu_exe.run(feed=feed, fetch_list=fetch)[0]
                return gpu_value

        np.testing.assert_array_equal(test_dygraph(), test_static_graph())
332

333 334 335

class TestGathertError(unittest.TestCase):
    """Checks that ``paddle.gather`` raises TypeError for the calls below.

    NOTE(review): every call reuses the int8 ``x``, which on its own is
    expected to raise TypeError (first check in each test). The later
    checks may therefore be raising because of ``x`` rather than the
    index/axis argument -- confirm.
    """

    def test_error1(self):
        main_prog = paddle.static.Program()
        startup_prog = paddle.static.Program()
        with paddle.static.program_guard(main_prog, startup_prog):
            dims = [8, 9, 6]
            x = paddle.fluid.data(shape=dims, dtype='int8', name='x')
            axis = paddle.fluid.data(shape=[1], dtype='float32', name='axis')
            index = paddle.fluid.data(shape=dims, dtype='int32', name='index')
            index_float = paddle.fluid.data(
                shape=dims, dtype='float32', name='index_float'
            )

            # int8 input
            with self.assertRaises(TypeError):
                paddle.gather(x, index)

            # float32 index
            with self.assertRaises(TypeError):
                paddle.gather(x, index_float)

            # Python float axis
            with self.assertRaises(TypeError):
                paddle.gather(x, index, axis=1.11)

            # float32 axis tensor
            with self.assertRaises(TypeError):
                paddle.gather(x, index, axis=axis)

    def test_error2(self):
        with fluid.program_guard(fluid.Program(), fluid.Program()):
            dims = [8, 9, 6]
            x = fluid.data(shape=dims, dtype='int8', name='x')
            index = fluid.data(shape=dims, dtype='int32', name='mask')
            index_float = fluid.data(
                shape=dims, dtype='float32', name='index_float'
            )

            # int8 input
            with self.assertRaises(TypeError):
                paddle.gather(x, index)

            # float32 index
            with self.assertRaises(TypeError):
                paddle.gather(x, index_float)
387 388


389 390 391 392 393 394 395 396
class TestCheckOutType(unittest.TestCase):
    """Checks the dtype of the static-graph output of ``paddle.gather``."""

    def test_out_type(self):
        """Gathering from an int64 input yields an INT64 output variable."""
        # Build inside fresh programs, as the other static tests in this
        # file do, so the variables are not added to the default program.
        with paddle.static.program_guard(
            paddle.static.Program(), paddle.static.Program()
        ):
            data = paddle.static.data(shape=[16, 10], dtype='int64', name='x')
            index = paddle.static.data(shape=[4], dtype='int64', name='index')
            out = paddle.gather(data, index)
        # assertEqual reports both values on failure, unlike assertTrue(a == b).
        self.assertEqual(out.dtype, core.VarDesc.VarType.INT64)


Z
zchen0211 已提交
397
if __name__ == "__main__":
    # Enable static-graph mode before the unittest runner starts.
    paddle.enable_static()
    unittest.main()