#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from op_test import OpTest

import paddle
import paddle.fluid as fluid


class TestGatherNdOpWithEmptyIndex(OpTest):
    """An empty index row selects the entire X tensor.

    Index has shape (2, 0): two index rows, each with no coordinates, so the
    expected output is two full copies of X stacked on a new leading axis.
    """

    def setUp(self):
        self.op_type = "gather_nd"
        self.python_api = paddle.gather_nd
        x = np.random.random((5, 20)).astype("float64")
        empty_index = np.array([[], []]).astype("int32")
        self.inputs = {'X': x, 'Index': empty_index}
        # np.stack adds the leading axis itself; this equals vstack-ing
        # x[np.newaxis, :] twice, giving shape (2, 5, 20).
        self.outputs = {'Out': np.stack([x, x])}

    def test_check_output(self):
        self.check_output(check_eager=False)

    def test_check_grad(self):
        self.check_grad(['X'], 'Out', check_eager=False)


class TestGatherNdOpWithIndex1(OpTest):
    """A 1-D Index of length 1 selects one row of a 2-D X."""

    def setUp(self):
        self.op_type = "gather_nd"
        self.python_api = paddle.gather_nd
        x = np.random.random((5, 20)).astype("float64")
        idx = np.array([1]).astype("int32")
        self.inputs = {'X': x, 'Index': idx}
        # NumPy fancy indexing with a length-1 integer array keeps a leading
        # axis of size 1, so this expectation has shape (1, 20).
        # NOTE(review): gather_nd would presumably produce shape (20,) here;
        # confirm the comparison in OpTest tolerates the extra axis.
        self.outputs = {'Out': x[idx]}

    def test_check_output(self):
        self.check_output(check_eager=False)

    def test_check_grad(self):
        self.check_grad(['X'], 'Out', check_eager=False)


class TestGatherNdOpWithLowIndex(OpTest):
    """Index has low rank, X has high rank.

    Each Index row holds a single coordinate (Index.shape[-1] == 1, while
    X is 2-D), so every row gathers an entire row of X.
    """

    def setUp(self):
        self.op_type = "gather_nd"
        self.python_api = paddle.gather_nd
        xnp = np.random.uniform(0, 100, (10, 10)).astype("float64")
        index = np.array([[1], [2]]).astype("int64")

        self.inputs = {'X': xnp, 'Index': index}

        # X is random, so there are no fixed literal values to expect: the
        # expectation is rows 1 and 2 of X, i.e. shape (2, 10).
        self.outputs = {'Out': xnp[tuple(index.T)]}

    def test_check_output(self):
        self.check_output(check_eager=False)

    def test_check_grad(self):
        self.check_grad(['X'], 'Out', check_eager=False)


class TestGatherNdOpIndex1(OpTest):
    """A 1-D Index whose length equals rank(X).

    Index [1, 2] is one full coordinate into the 2-D X, so the expected
    output is the single element X[1, 2].
    """

    def setUp(self):
        self.op_type = "gather_nd"
        self.python_api = paddle.gather_nd
        xnp = np.random.uniform(0, 100, (10, 10)).astype("float64")
        coord = np.array([1, 2]).astype("int32")
        self.inputs = {'X': xnp, 'Index': coord}
        # For a 1-D array, .T is a no-op and tuple(...) gives (1, 2), so
        # this indexes a single scalar, xnp[1, 2].
        self.outputs = {'Out': xnp[tuple(coord.T)]}

    def test_check_output(self):
        self.check_output(check_eager=False)

    def test_check_grad(self):
        self.check_grad(['X'], 'Out', check_eager=False)


class TestGatherNdOpWithSameIndexAsX(OpTest):
    """Index has the same rank as X's rank.

    Index has shape (2, 2) against a 2-D X, so each Index row is a full
    coordinate and selects exactly one element.
    """

    def setUp(self):
        self.op_type = "gather_nd"
        self.python_api = paddle.gather_nd
        xnp = np.random.uniform(0, 100, (10, 10)).astype("float64")
        index = np.array([[1, 1], [2, 1]]).astype("int64")

        self.inputs = {'X': xnp, 'Index': index}
        # X is random, so the expectation is [X[1, 1], X[2, 1]] (shape (2,)),
        # not fixed literal values.
        self.outputs = {'Out': xnp[tuple(index.T)]}

    def test_check_output(self):
        self.check_output(check_eager=False)

    def test_check_grad(self):
        self.check_grad(['X'], 'Out', check_eager=False)


class TestGatherNdOpWithHighRankSame(OpTest):
    """High-rank X where each Index row is a full coordinate.

    X has rank 5 and Index has shape (2, 5), so Index.shape[-1] == rank(X)
    (Index itself is rank 2, not the same rank as X). Each of the 2 rows
    selects one element, giving an output of shape (2,).
    """

    def setUp(self):
        self.op_type = "gather_nd"
        self.python_api = paddle.gather_nd
        shape = (5, 2, 3, 1, 10)
        xnp = np.random.rand(*shape).astype("float64")
        # One random in-bounds coordinate per axis, 2 samples each; after
        # the transpose every row is a complete coordinate into X.
        index = np.vstack([np.random.randint(0, s, size=2) for s in shape]).T

        self.inputs = {'X': xnp, 'Index': index.astype("int32")}
        self.outputs = {'Out': xnp[tuple(index.T)]}

    def test_check_output(self):
        self.check_output(check_eager=False)

    def test_check_grad(self):
        self.check_grad(['X'], 'Out', check_eager=False)


class TestGatherNdOpWithHighRankDiff(OpTest):
    """High-rank X and Index, with Rank(Index) < Rank(X).

    Index has shape (20, 5, 2, 5) (rank 4) against a rank-5 X. Its last
    dimension is a full coordinate, so the expected output has shape
    (20, 5, 2).
    """

    def setUp(self):
        self.op_type = "gather_nd"
        self.python_api = paddle.gather_nd
        shape = (2, 3, 4, 1, 10)
        xnp = np.random.rand(*shape).astype("float64")
        # 200 random in-bounds coordinates, one column per axis -> (200, 5).
        coords = np.vstack(
            [np.random.randint(0, dim, size=200) for dim in shape]
        ).T
        self.inputs = {
            'X': xnp,
            'Index': coords.reshape([20, 5, 2, 5]).astype("int32"),
        }
        gathered = xnp[tuple(coords.T)]
        self.outputs = {'Out': gathered.reshape([20, 5, 2])}

    def test_check_output(self):
        self.check_output(check_eager=False)

    def test_check_grad(self):
        self.check_grad(['X'], 'Out', check_eager=False)


# Test Python API
class TestGatherNdOpAPI(unittest.TestCase):
    """Build-only checks that paddle.gather_nd accepts several X/Index shapes.

    The layers are created with fluid.layers.data and never executed; no
    output values are asserted.
    """

    def test_case1(self):
        x1 = fluid.layers.data(
            name='x1', shape=[30, 40, 50, 60], dtype='float32'
        )
        index1 = fluid.layers.data(name='index1', shape=[2, 4], dtype='int32')
        paddle.gather_nd(x1, index1)

    def test_case2(self):
        x2 = fluid.layers.data(name='x2', shape=[30, 40, 50], dtype='float32')
        index2 = fluid.layers.data(name='index2', shape=[2, 2], dtype='int64')
        paddle.gather_nd(x2, index2)

    def test_case3(self):
        x3 = fluid.layers.data(name='x3', shape=[3, 4, 5], dtype='float32')
        index3 = fluid.layers.data(name='index3', shape=[2, 1], dtype='int32')
        # Also exercises the optional `name` argument.
        paddle.gather_nd(x3, index3, name="gather_nd_layer")


# Test Raise Index Error
class TestGatherNdOpRaise(unittest.TestCase):
    """An Index whose last dimension exceeds rank(X) must be rejected."""

    def test_check_raise(self):
        expected_msg = (
            "Input(Index).shape[-1] should be no greater than Input(X).rank"
        )

        def check_raise_is_test():
            # Build in a private program so the 'x' / 'index' variables do
            # not leak into the default main program used by other tests.
            with fluid.program_guard(fluid.Program(), fluid.Program()):
                try:
                    x = fluid.layers.data(
                        name='x', shape=[3, 4, 5], dtype='float32'
                    )
                    index = fluid.layers.data(
                        name='index', shape=[2, 10], dtype='int32'
                    )
                    paddle.gather_nd(x, index)
                except Exception as e:
                    if expected_msg in str(e):
                        # Translate the expected failure into IndexError,
                        # keeping the original error chained for debugging.
                        raise IndexError(expected_msg) from e
                    # Any other failure is unexpected: surface it as-is
                    # instead of silently swallowing it.
                    raise

        self.assertRaises(IndexError, check_raise_is_test)


class TestGatherNdError(unittest.TestCase):
    """paddle.gather_nd raises TypeError for ndarray inputs and float indices."""

    def test_error(self):
        main_prog = paddle.static.Program()
        startup_prog = paddle.static.Program()
        with paddle.static.program_guard(main_prog, startup_prog):
            shape = [8, 9, 6]
            x = paddle.fluid.data(shape=shape, dtype='float32', name='x')
            index = paddle.fluid.data(shape=shape, dtype='bool', name='index')
            index_float = paddle.fluid.data(
                shape=shape, dtype='float32', name='index_float'
            )
            np_x = np.random.random(shape).astype('float32')
            np_index = np.random.randint(2, size=shape, dtype=bool)

            # A raw ndarray is not accepted as X.
            with self.assertRaises(TypeError):
                paddle.gather_nd(np_x, index)

            # A raw ndarray is not accepted as Index.
            with self.assertRaises(TypeError):
                paddle.gather_nd(x, np_index)

            # A float32 Index is rejected.
            with self.assertRaises(TypeError):
                paddle.gather_nd(x, index_float)


class TestGatherNdAPI2(unittest.TestCase):
    """Execute paddle.gather_nd in static and dynamic graph mode."""

    def test_static(self):
        with fluid.program_guard(fluid.Program(), fluid.Program()):
            data1 = fluid.layers.data('data1', shape=[-1, 2], dtype='float64')
            index = fluid.layers.data('index', shape=[-1, 1], dtype='int32')
            out = paddle.gather_nd(data1, index)
            place = fluid.CPUPlace()
            exe = fluid.Executor(place)
            x_np = np.array([[1, 2], [3, 4], [5, 6]])
            index_np = np.array([[1]])
            (result,) = exe.run(
                feed={"data1": x_np, "index": index_np}, fetch_list=[out]
            )
            expected_output = np.array([[3, 4]])
        np.testing.assert_allclose(result, expected_output, rtol=1e-05)

    def test_imperative(self):
        paddle.disable_static()
        try:
            x_np = np.array([[1, 2], [3, 4], [5, 6]])
            index_np = np.array([[1]])
            x = fluid.dygraph.to_variable(x_np)
            index = fluid.dygraph.to_variable(index_np)
            # This class tests gather_nd; the previous call to paddle.gather
            # happened to give the same answer but exercised a different op.
            output = paddle.gather_nd(x, index)
            expected_output = np.array([[3, 4]])
            np.testing.assert_allclose(
                output.numpy(), expected_output, rtol=1e-05
            )
        finally:
            # Restore static mode even when the assertion fails, so later
            # static-graph tests in this process are not left in dygraph mode.
            paddle.enable_static()


if __name__ == "__main__":
    # Run the whole suite in static graph mode; test_imperative switches to
    # dynamic mode only for its own duration.
    paddle.enable_static()
    unittest.main()