# test_gather_nd_op.py
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from op_test import OpTest

import paddle
import paddle.fluid as fluid


class TestGatherNdOpWithEmptyIndex(OpTest):
    """gather_nd with an index whose last dimension is 0.

    An index of shape (2, 0) names no coordinates, so each of its two rows
    copies the entire X; the expected output is X stacked twice along a new
    leading axis, shape (2, 5, 20).
    """

    def setUp(self):
        self.op_type = "gather_nd"
        self.python_api = paddle.gather_nd
        x = np.random.random((5, 20)).astype("float64")
        empty_index = np.empty((2, 0), dtype="int32")
        self.inputs = {'X': x, 'Index': empty_index}
        # One full copy of X per (empty) index row.
        self.outputs = {'Out': np.stack([x, x])}

    def test_check_output(self):
        self.check_output(check_eager=False)

    def test_check_grad(self):
        self.check_grad(['X'], 'Out', check_eager=False)


class TestGatherNdOpWithIndex1(OpTest):
    """gather_nd with a 1-D index of length 1 on a 2-D X.

    Index [1] addresses row 1 of X, so the op returns that row with shape
    index.shape[:-1] + X.shape[1:] == (20,).
    """

    def setUp(self):
        self.op_type = "gather_nd"
        self.python_api = paddle.gather_nd
        xnp = np.random.random((5, 20)).astype("float64")
        index = np.array([1]).astype("int32")
        self.inputs = {'X': xnp, 'Index': index}
        # Build the reference like the other cases in this file:
        # xnp[tuple(index.T)] is xnp[1] with shape (20,). Plain fancy
        # indexing (xnp[index]) would keep a leading axis -> shape (1, 20),
        # which does not match gather_nd's output shape.
        self.outputs = {'Out': xnp[tuple(index.T)]}

    def test_check_output(self):
        self.check_output(check_eager=False)

    def test_check_grad(self):
        self.check_grad(['X'], 'Out', check_eager=False)


class TestGatherNdOpWithLowIndex(OpTest):
    """gather_nd where index.shape[-1] (1) is smaller than rank(X) (2).

    Each index row [i] selects the whole row X[i], so with rows [1] and [2]
    the expected output is X[1:3] with shape (2, 10).
    """

    def setUp(self):
        self.op_type = "gather_nd"
        self.python_api = paddle.gather_nd
        xnp = np.random.uniform(0, 100, (10, 10)).astype("float64")
        index = np.array([[1], [2]]).astype("int64")

        self.inputs = {'X': xnp, 'Index': index}

        # Rows 1 and 2 of X; X is random, so no literal values apply.
        self.outputs = {'Out': xnp[tuple(index.T)]}

    def test_check_output(self):
        self.check_output(check_eager=False)

    def test_check_grad(self):
        self.check_grad(['X'], 'Out', check_eager=False)


class TestGatherNdOpIndex1(OpTest):
    """gather_nd with a 1-D index [1, 2] into a 2-D X.

    The single index row is a full coordinate (index.shape[-1] == rank(X)),
    so the op selects the one element X[1, 2] and the expected output is
    0-D.
    """

    def setUp(self):
        self.op_type = "gather_nd"
        self.python_api = paddle.gather_nd
        x_data = np.random.uniform(0, 100, (10, 10)).astype("float64")
        coord = np.array([1, 2]).astype("int32")

        self.inputs = {'X': x_data, 'Index': coord}
        self.outputs = {'Out': x_data[tuple(coord)]}

    def test_check_output(self):
        self.check_output(check_eager=False)

    def test_check_grad(self):
        self.check_grad(['X'], 'Out', check_eager=False)


class TestGatherNdOpWithSameIndexAsX(OpTest):
    """gather_nd where index.shape[-1] equals rank(X) (both 2).

    Each index row [r, c] selects the scalar X[r, c]; with rows [1, 1] and
    [2, 1] the expected output is [X[1, 1], X[2, 1]] with shape (2,).
    """

    def setUp(self):
        self.op_type = "gather_nd"
        self.python_api = paddle.gather_nd
        xnp = np.random.uniform(0, 100, (10, 10)).astype("float64")
        index = np.array([[1, 1], [2, 1]]).astype("int64")

        self.inputs = {'X': xnp, 'Index': index}
        # X is random, so the reference is computed rather than hard-coded.
        self.outputs = {'Out': xnp[tuple(index.T)]}

    def test_check_output(self):
        self.check_output(check_eager=False)

    def test_check_grad(self):
        self.check_grad(['X'], 'Out', check_eager=False)


class TestGatherNdOpWithHighRankSame(OpTest):
    """gather_nd where each index row is a full coordinate into a 5-D X.

    Index has shape (2, 5). Its rank (2) differs from rank(X) (5); what
    matches is index.shape[-1] == rank(X), so every row addresses a single
    element and the expected output has shape (2,).
    """

    def setUp(self):
        self.op_type = "gather_nd"
        self.python_api = paddle.gather_nd
        shape = (5, 2, 3, 1, 10)
        xnp = np.random.rand(*shape).astype("float64")
        # One random in-range value per axis, two samples each -> (2, 5).
        index = np.vstack([np.random.randint(0, s, size=2) for s in shape]).T

        self.inputs = {'X': xnp, 'Index': index.astype("int32")}
        self.outputs = {'Out': xnp[tuple(index.T)]}

    def test_check_output(self):
        self.check_output(check_eager=False)

    def test_check_grad(self):
        self.check_grad(['X'], 'Out', check_eager=False)


class TestGatherNdOpWithHighRankDiff(OpTest):
    """gather_nd with a 4-D index into a 5-D X.

    Index has shape (20, 5, 2, 5): each length-5 innermost row is a full
    coordinate into X, so the expected output has shape
    index.shape[:-1] == (20, 5, 2).
    """

    def setUp(self):
        self.op_type = "gather_nd"
        self.python_api = paddle.gather_nd
        x_shape = (2, 3, 4, 1, 10)
        x_data = np.random.rand(*x_shape).astype("float64")
        # 200 random in-range coordinates, one per row of a (200, 5) matrix.
        coords = np.stack(
            [np.random.randint(0, dim, size=200) for dim in x_shape], axis=1
        )
        index_re = coords.reshape([20, 5, 2, 5])

        self.inputs = {'X': x_data, 'Index': index_re.astype("int32")}
        gathered = x_data[tuple(coords.T)]
        self.outputs = {'Out': gathered.reshape(index_re.shape[:-1])}

    def test_check_output(self):
        self.check_output(check_eager=False)

    def test_check_grad(self):
        self.check_grad(['X'], 'Out', check_eager=False)


# Test Python API
class TestGatherNdOpAPI(unittest.TestCase):
    """Smoke tests: paddle.gather_nd builds in a static graph for several
    X/Index rank combinations and for both int32 and int64 indices.

    Each case builds in its own Program so its named placeholders do not
    accumulate in the shared default main program.
    """

    def test_case1(self):
        with paddle.static.program_guard(
            paddle.static.Program(), paddle.static.Program()
        ):
            x1 = paddle.static.data(
                name='x1', shape=[-1, 30, 40, 50, 60], dtype='float32'
            )
            index1 = paddle.static.data(
                name='index1', shape=[-1, 2, 4], dtype='int32'
            )
            paddle.gather_nd(x1, index1)

    def test_case2(self):
        with paddle.static.program_guard(
            paddle.static.Program(), paddle.static.Program()
        ):
            x2 = paddle.static.data(
                name='x2', shape=[-1, 30, 40, 50], dtype='float32'
            )
            index2 = paddle.static.data(
                name='index2', shape=[-1, 2, 2], dtype='int64'
            )
            paddle.gather_nd(x2, index2)

    def test_case3(self):
        with paddle.static.program_guard(
            paddle.static.Program(), paddle.static.Program()
        ):
            x3 = paddle.static.data(
                name='x3', shape=[-1, 3, 4, 5], dtype='float32'
            )
            index3 = paddle.static.data(
                name='index3', shape=[-1, 2, 1], dtype='int32'
            )
            # Also exercises the optional `name` argument.
            paddle.gather_nd(x3, index3, name="gather_nd_layer")


# Test Raise Index Error
class TestGatherNdOpRaise(unittest.TestCase):
    def test_check_raise(self):
        """An index whose last dim (10) exceeds rank(X) (4) is rejected.

        Only the gather_nd call is expected to fail; any error while
        creating the placeholders propagates as a normal test error instead
        of being swallowed.
        """
        with paddle.static.program_guard(
            paddle.static.Program(), paddle.static.Program()
        ):
            x = paddle.static.data(
                name='x', shape=[-1, 3, 4, 5], dtype='float32'
            )
            index = paddle.static.data(
                name='index', shape=[-1, 2, 10], dtype='int32'
            )
            with self.assertRaises(Exception) as ctx:
                paddle.gather_nd(x, index)
        self.assertIn(
            "Input(Index).shape[-1] should be no greater than Input(X).rank",
            str(ctx.exception),
        )


class TestGatherNdError(unittest.TestCase):
    def test_error(self):
        """gather_nd raises TypeError for non-Variable inputs and for a
        floating-point index dtype."""
        main_prog = paddle.static.Program()
        startup_prog = paddle.static.Program()
        with paddle.static.program_guard(main_prog, startup_prog):
            shape = [8, 9, 6]
            x = paddle.fluid.data(shape=shape, dtype='float32', name='x')
            index = paddle.fluid.data(shape=shape, dtype='bool', name='index')
            index_float = paddle.fluid.data(
                shape=shape, dtype='float32', name='index_float'
            )
            x_np = np.random.random(shape).astype('float32')
            index_np = np.array(np.random.randint(2, size=shape, dtype=bool))

            # X passed as a raw numpy array.
            with self.assertRaises(TypeError):
                paddle.gather_nd(x_np, index)

            # Index passed as a raw numpy array.
            with self.assertRaises(TypeError):
                paddle.gather_nd(x, index_np)

            # Index Variable with a float32 dtype.
            with self.assertRaises(TypeError):
                paddle.gather_nd(x, index_float)


class TestGatherNdAPI2(unittest.TestCase):
    """End-to-end checks of paddle.gather_nd in static and dygraph modes."""

    def test_static(self):
        """Static graph on CPU: index [[1]] gathers row 1 -> [[3, 4]]."""
        with fluid.program_guard(fluid.Program(), fluid.Program()):
            data1 = paddle.static.data('data1', shape=[-1, 2], dtype='float64')
            data1.desc.set_need_check_feed(False)
            index = paddle.static.data('index', shape=[-1, 1], dtype='int32')
            index.desc.set_need_check_feed(False)
            out = paddle.gather_nd(data1, index)
            exe = fluid.Executor(fluid.CPUPlace())
            x_np = np.array([[1, 2], [3, 4], [5, 6]])
            index_np = np.array([[1]]).astype('int32')
            (result,) = exe.run(
                feed={"data1": x_np, "index": index_np}, fetch_list=[out]
            )
        expected_output = np.array([[3, 4]])
        np.testing.assert_allclose(result, expected_output, rtol=1e-05)

    def test_static_fp16_with_gpu(self):
        """float16 on GPU: index [[0, 1]] selects X[0, 1] -> [[3, 4]].

        Reported as skipped (rather than silently passing) on builds
        without CUDA.
        """
        if not paddle.fluid.core.is_compiled_with_cuda():
            self.skipTest("float16 gather_nd test needs a CUDA build")
        place = paddle.CUDAPlace(0)
        with paddle.static.program_guard(
            paddle.static.Program(), paddle.static.Program()
        ):
            x_np = np.array(
                [[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]],
                dtype='float16',
            )
            index_np = np.array([[0, 1]], dtype='int32')
            expected = np.array([[3, 4]], dtype='float16')

            x = paddle.static.data(
                name="x", shape=[2, 3, 2], dtype="float16"
            )
            x.desc.set_need_check_feed(False)
            idx = paddle.static.data(
                name="index", shape=[1, 2], dtype="int32"
            )
            idx.desc.set_need_check_feed(False)

            y = paddle.gather_nd(x, idx)

            exe = paddle.static.Executor(place)
            res = exe.run(
                paddle.static.default_main_program(),
                feed={"x": x_np, "index": index_np},
                fetch_list=[y],
            )

            np.testing.assert_allclose(res[0], expected, rtol=1e-05)

    def test_imperative(self):
        """Dygraph mode: gather_nd with index [[1]] gathers row 1 -> [[3, 4]]."""
        paddle.disable_static()
        try:
            x_np = np.array([[1, 2], [3, 4], [5, 6]])
            index_np = np.array([[1]])
            x = fluid.dygraph.to_variable(x_np)
            index = fluid.dygraph.to_variable(index_np)
            # gather_nd is the API under test (not paddle.gather).
            output = paddle.gather_nd(x, index)
            expected_output = np.array([[3, 4]])
            np.testing.assert_allclose(
                output.numpy(), expected_output, rtol=1e-05
            )
        finally:
            # Restore static mode even if the assertion fails so the other
            # (static-graph) tests in this module are unaffected.
            paddle.enable_static()


if __name__ == "__main__":
    # Run the suite in static-graph mode; tests that need dygraph mode
    # switch to it themselves and switch back afterwards.
    paddle.enable_static()
    unittest.main()