# test_gather_nd_op.py
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from op_test import OpTest

import paddle
import paddle.fluid as fluid


class TestGatherNdOpWithEmptyIndex(OpTest):
    """gather_nd where every index row is an empty coordinate.

    ``Index`` has shape (2, 0): each of its two rows is an empty
    coordinate, which selects the whole of ``X``. The expected output is
    therefore two stacked copies of ``X``, with shape (2, 5, 20).
    """

    def setUp(self):
        self.op_type = "gather_nd"
        self.python_api = paddle.gather_nd
        xnp = np.random.random((5, 20)).astype("float64")
        self.inputs = {'X': xnp, 'Index': np.array([[], []]).astype("int32")}
        # Each empty index row copies the entire tensor.
        self.outputs = {
            'Out': np.vstack((xnp[np.newaxis, :], xnp[np.newaxis, :]))
        }

    def test_check_output(self):
        self.check_output(check_eager=False)

    def test_check_grad(self):
        self.check_grad(['X'], 'Out', check_eager=False)


class TestGatherNdOpWithIndex1(OpTest):
    """gather_nd with a 1-D index of length 1 on a 2-D ``X``.

    ``Index`` = [1] addresses only the first axis of the (5, 20) ``X``, so
    the result is the full row ``X[1]``.
    """

    def setUp(self):
        self.op_type = "gather_nd"
        self.python_api = paddle.gather_nd
        xnp = np.random.random((5, 20)).astype("float64")
        index = np.array([1]).astype("int32")
        self.inputs = {'X': xnp, 'Index': index}
        # gather_nd's output shape is Index.shape[:-1] + X.shape[Index.shape[-1]:],
        # i.e. (20,) here. Fancy indexing X[index] would instead produce
        # shape (1, 20), so index with the scalar coordinate.
        self.outputs = {'Out': xnp[index[-1]]}

    def test_check_output(self):
        self.check_output(check_eager=False)

    def test_check_grad(self):
        self.check_grad(['X'], 'Out', check_eager=False)


class TestGatherNdOpWithLowIndex(OpTest):
    """gather_nd where ``Index.shape[-1]`` is smaller than ``rank(X)``.

    ``Index`` has shape (2, 1) and ``X`` has shape (10, 10). Each index row
    therefore selects an entire row of ``X``, giving an output of shape
    (2, 10).
    """

    def setUp(self):
        self.op_type = "gather_nd"
        self.python_api = paddle.gather_nd
        xnp = np.random.uniform(0, 100, (10, 10)).astype("float64")
        index = np.array([[1], [2]]).astype("int64")

        self.inputs = {'X': xnp, 'Index': index}

        # Rows 1 and 2 of X. X is random, so no literal values are shown.
        self.outputs = {'Out': xnp[tuple(index.T)]}

    def test_check_output(self):
        self.check_output(check_eager=False)

    def test_check_grad(self):
        self.check_grad(['X'], 'Out', check_eager=False)


class TestGatherNdOpIndex1(OpTest):
    """gather_nd with a 1-D index whose length equals ``rank(X)``.

    ``Index`` = [1, 2] is a single full coordinate into the (10, 10) ``X``,
    so the expected output is the single element ``X[1, 2]``.
    """

    def setUp(self):
        self.op_type = "gather_nd"
        self.python_api = paddle.gather_nd
        xnp = np.random.uniform(0, 100, (10, 10)).astype("float64")
        index = np.array([1, 2]).astype("int32")

        self.inputs = {'X': xnp, 'Index': index}

        self.outputs = {'Out': xnp[tuple(index.T)]}

    def test_check_output(self):
        self.check_output(check_eager=False)

    def test_check_grad(self):
        self.check_grad(['X'], 'Out', check_eager=False)


class TestGatherNdOpWithSameIndexAsX(OpTest):
    """gather_nd where ``Index.shape[-1]`` equals ``rank(X)``.

    Each row of the (2, 2) ``Index`` is a full coordinate into the (10, 10)
    ``X``, so the output holds the two elements X[1, 1] and X[2, 1].
    """

    def setUp(self):
        self.op_type = "gather_nd"
        self.python_api = paddle.gather_nd
        xnp = np.random.uniform(0, 100, (10, 10)).astype("float64")
        index = np.array([[1, 1], [2, 1]]).astype("int64")

        self.inputs = {'X': xnp, 'Index': index}
        self.outputs = {'Out': xnp[tuple(index.T)]}

    def test_check_output(self):
        self.check_output(check_eager=False)

    def test_check_grad(self):
        self.check_grad(['X'], 'Out', check_eager=False)


class TestGatherNdOpWithHighRankSame(OpTest):
    """gather_nd on a rank-5 ``X`` with ``Index.shape[-1] == rank(X)``.

    ``Index`` has shape (2, 5): two random full coordinates into ``X``.
    The expected output therefore contains two scalar elements.
    """

    def setUp(self):
        self.op_type = "gather_nd"
        self.python_api = paddle.gather_nd
        shape = (5, 2, 3, 1, 10)
        xnp = np.random.rand(*shape).astype("float64")
        # One random coordinate per axis, two samples each; transposing the
        # (5, 2) stack gives one coordinate per row, shape (2, 5).
        index = np.vstack([np.random.randint(0, s, size=2) for s in shape]).T

        self.inputs = {'X': xnp, 'Index': index.astype("int32")}
        self.outputs = {'Out': xnp[tuple(index.T)]}

    def test_check_output(self):
        self.check_output(check_eager=False)

    def test_check_grad(self):
        self.check_grad(['X'], 'Out', check_eager=False)


class TestGatherNdOpWithHighRankDiff(OpTest):
    """gather_nd with a rank-4 ``Index`` on a rank-5 ``X``.

    200 random full coordinates (shape (200, 5)) are reshaped into an
    ``Index`` of shape (20, 5, 2, 5). The expected output is the 200
    gathered elements reshaped to the index's leading dims, (20, 5, 2).
    """

    def setUp(self):
        self.op_type = "gather_nd"
        self.python_api = paddle.gather_nd
        shape = (2, 3, 4, 1, 10)
        xnp = np.random.rand(*shape).astype("float64")
        index = np.vstack([np.random.randint(0, s, size=200) for s in shape]).T
        index_re = index.reshape([20, 5, 2, 5])

        self.inputs = {'X': xnp, 'Index': index_re.astype("int32")}
        self.outputs = {'Out': xnp[tuple(index.T)].reshape([20, 5, 2])}

    def test_check_output(self):
        self.check_output(check_eager=False)

    def test_check_grad(self):
        self.check_grad(['X'], 'Out', check_eager=False)


# Test Python API
class TestGatherNdOpAPI(unittest.TestCase):
    """Build paddle.gather_nd in static mode for several X/Index shapes.

    Each case runs in its own Program so variables do not leak into the
    global default program. It checks the inferred output shape
    ``Index.shape[:-1] + X.shape[Index.shape[-1]:]`` (ignoring the
    unknown batch dim).
    """

    def test_case1(self):
        with paddle.static.program_guard(
            paddle.static.Program(), paddle.static.Program()
        ):
            x1 = paddle.static.data(
                name='x1', shape=[-1, 30, 40, 50, 60], dtype='float32'
            )
            index1 = paddle.static.data(
                name='index1', shape=[-1, 2, 4], dtype='int32'
            )
            output1 = paddle.gather_nd(x1, index1)
            self.assertEqual(list(output1.shape[1:]), [2, 60])

    def test_case2(self):
        with paddle.static.program_guard(
            paddle.static.Program(), paddle.static.Program()
        ):
            x2 = paddle.static.data(
                name='x2', shape=[-1, 30, 40, 50], dtype='float32'
            )
            index2 = paddle.static.data(
                name='index2', shape=[-1, 2, 2], dtype='int64'
            )
            output2 = paddle.gather_nd(x2, index2)
            self.assertEqual(list(output2.shape[1:]), [2, 40, 50])

    def test_case3(self):
        with paddle.static.program_guard(
            paddle.static.Program(), paddle.static.Program()
        ):
            x3 = paddle.static.data(
                name='x3', shape=[-1, 3, 4, 5], dtype='float32'
            )
            index3 = paddle.static.data(
                name='index3', shape=[-1, 2, 1], dtype='int32'
            )
            output3 = paddle.gather_nd(x3, index3, name="gather_nd_layer")
            self.assertEqual(list(output3.shape[1:]), [2, 3, 4, 5])


# Test Raise Index Error
class TestGatherNdOpRaise(unittest.TestCase):
    """gather_nd must reject an Index whose last dim exceeds rank(X)."""

    def test_check_raise(self):
        expected_msg = (
            "Input(Index).shape[-1] should be no greater than Input(X).rank"
        )

        def check_raise_is_test():
            with paddle.static.program_guard(
                paddle.static.Program(), paddle.static.Program()
            ):
                x = paddle.static.data(
                    name='x', shape=[-1, 3, 4, 5], dtype='float32'
                )
                # Index.shape[-1] (10) is larger than rank(X) (4).
                index = paddle.static.data(
                    name='index', shape=[-1, 2, 10], dtype='int32'
                )
                try:
                    paddle.gather_nd(x, index)
                except Exception as e:
                    if expected_msg in str(e):
                        raise IndexError from e
                    # Surface unexpected failures instead of hiding them
                    # behind a misleading "IndexError not raised".
                    raise

        self.assertRaises(IndexError, check_raise_is_test)


class TestGatherNdError(unittest.TestCase):
    """gather_nd must raise TypeError for invalid X / Index arguments."""

    def test_error(self):
        with paddle.static.program_guard(
            paddle.static.Program(), paddle.static.Program()
        ):
            shape = [8, 9, 6]
            x = paddle.static.data(shape=shape, dtype='float32', name='x')
            index = paddle.static.data(shape=shape, dtype='bool', name='index')
            index_float = paddle.static.data(
                shape=shape, dtype='float32', name='index_float'
            )
            np_x = np.random.random(shape).astype('float32')
            np_index = np.random.randint(2, size=shape, dtype=bool)

            def test_x_type():
                # X given as a numpy array instead of a graph variable.
                paddle.gather_nd(np_x, index)

            self.assertRaises(TypeError, test_x_type)

            def test_index_type():
                # Index given as a numpy array instead of a graph variable.
                paddle.gather_nd(x, np_index)

            self.assertRaises(TypeError, test_index_type)

            def test_index_dtype():
                # Index with a floating-point dtype.
                paddle.gather_nd(x, index_float)

            self.assertRaises(TypeError, test_index_dtype)


class TestGatherNdAPI2(unittest.TestCase):
    """End-to-end paddle.gather_nd checks in static and dynamic graph modes."""

    def test_static(self):
        with fluid.program_guard(fluid.Program(), fluid.Program()):
            data1 = paddle.static.data('data1', shape=[-1, 2], dtype='float64')
            data1.desc.set_need_check_feed(False)
            index = paddle.static.data('index', shape=[-1, 1], dtype='int32')
            index.desc.set_need_check_feed(False)
            out = paddle.gather_nd(data1, index)
            place = fluid.CPUPlace()
            exe = fluid.Executor(place)
            # Feed data matching the declared float64 dtype of `data1`.
            x_np = np.array([[1, 2], [3, 4], [5, 6]]).astype('float64')
            index_np = np.array([[1]]).astype('int32')
            (result,) = exe.run(
                feed={"data1": x_np, "index": index_np}, fetch_list=[out]
            )
        # Index [[1]] selects row 1 of the input.
        expected_output = np.array([[3, 4]])
        np.testing.assert_allclose(result, expected_output, rtol=1e-05)

    def test_imperative(self):
        paddle.disable_static()
        try:
            x_np = np.array([[1, 2], [3, 4], [5, 6]])
            index_np = np.array([[1]])
            x = fluid.dygraph.to_variable(x_np)
            index = fluid.dygraph.to_variable(index_np)
            # Exercise gather_nd, the op this file tests.
            output_np = paddle.gather_nd(x, index).numpy()
            expected_output = np.array([[3, 4]])
            np.testing.assert_allclose(output_np, expected_output, rtol=1e-05)
        finally:
            # Restore static mode even if the assertion fails, so later
            # tests in this module do not run in dynamic mode.
            paddle.enable_static()


if __name__ == "__main__":
    # Run the suite in static-graph mode; test_imperative switches to
    # dynamic mode only temporarily.
    paddle.enable_static()
    unittest.main()