#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from op_test import OpTest
import paddle.fluid as fluid
import paddle


class TestGatherNdOpWithEmptyIndex(OpTest):
    """gather_nd with an Index whose last dimension is 0.

    Index has shape (2, 0): each of its two rows is an empty coordinate,
    which means "copy the entire tensor", so the expected output is two
    stacked copies of X.
    """

    def setUp(self):
        """Build X, the empty Index and the NumPy reference output."""
        self.op_type = "gather_nd"
        self.python_api = paddle.gather_nd
        xnp = np.random.random((5, 20)).astype("float64")
        self.inputs = {'X': xnp, 'Index': np.array([[], []]).astype("int32")}
        # Two copies of X stacked along a new leading axis -> (2, 5, 20).
        self.outputs = {
            'Out': np.vstack((xnp[np.newaxis, :], xnp[np.newaxis, :]))
        }

    def test_check_output(self):
        """Compare the op's output against self.outputs."""
        self.check_output(check_eager=False)

    def test_check_grad(self):
        """Check the gradient of Out with respect to X."""
        self.check_grad(['X'], 'Out', check_eager=False)


class TestGatherNdOpWithIndex1(OpTest):
    """gather_nd with a 1-D Index holding the single coordinate [1]."""

    def setUp(self):
        """Build X, the 1-D Index and the NumPy reference output."""
        self.op_type = "gather_nd"
        self.python_api = paddle.gather_nd
        xnp = np.random.random((5, 20)).astype("float64")
        self.inputs = {'X': xnp, 'Index': np.array([1]).astype("int32")}
        # Reference via NumPy fancy indexing: X[[1]] is row 1 of X.
        # NOTE(review): X[[1]] has shape (1, 20); confirm this matches the
        # op's output shape or that the comparison tolerates the extra axis.
        self.outputs = {'Out': self.inputs["X"][self.inputs["Index"]]}

    def test_check_output(self):
        """Compare the op's output against self.outputs."""
        self.check_output(check_eager=False)

    def test_check_grad(self):
        """Check the gradient of Out with respect to X."""
        self.check_grad(['X'], 'Out', check_eager=False)


class TestGatherNdOpWithLowIndex(OpTest):
    """Index.shape[-1] (1) is smaller than X's rank (2).

    Each Index row holds a single coordinate, so it selects a whole row of
    X: the expected output is rows 1 and 2 of X.
    """

    def setUp(self):
        """Build X, the partial-coordinate Index and the reference output."""
        self.op_type = "gather_nd"
        self.python_api = paddle.gather_nd
        xnp = np.random.uniform(0, 100, (10, 10)).astype("float64")
        index = np.array([[1], [2]]).astype("int64")

        self.inputs = {'X': xnp, 'Index': index}

        # X is random, so the reference is computed, not hard-coded:
        # tuple(index.T) == (array([1, 2]),) -> xnp[[1, 2]].
        self.outputs = {'Out': xnp[tuple(index.T)]}

    def test_check_output(self):
        """Compare the op's output against self.outputs."""
        self.check_output(check_eager=False)

    def test_check_grad(self):
        """Check the gradient of Out with respect to X."""
        self.check_grad(['X'], 'Out', check_eager=False)


class TestGatherNdOpIndex1(OpTest):
    """1-D Index whose length equals X's rank.

    Index [1, 2] is one full coordinate into the 2-D X, so the expected
    output is the single element X[1, 2].
    """

    def setUp(self):
        """Build X, the 1-D full-coordinate Index and the reference output."""
        self.op_type = "gather_nd"
        self.python_api = paddle.gather_nd
        xnp = np.random.uniform(0, 100, (10, 10)).astype("float64")
        index = np.array([1, 2]).astype("int32")

        self.inputs = {'X': xnp, 'Index': index}

        # .T is a no-op on a 1-D array, so this is xnp[(1, 2)] == xnp[1, 2].
        self.outputs = {'Out': xnp[tuple(index.T)]}

    def test_check_output(self):
        """Compare the op's output against self.outputs."""
        self.check_output(check_eager=False)

    def test_check_grad(self):
        """Check the gradient of Out with respect to X."""
        self.check_grad(['X'], 'Out', check_eager=False)


class TestGatherNdOpWithSameIndexAsX(OpTest):
    """Index.shape[-1] equals X's rank (2).

    Each Index row is a full coordinate into X, so the expected output is
    the two elements X[1, 1] and X[2, 1].
    """

    def setUp(self):
        """Build X, the full-coordinate Index and the reference output."""
        self.op_type = "gather_nd"
        self.python_api = paddle.gather_nd
        xnp = np.random.uniform(0, 100, (10, 10)).astype("float64")
        index = np.array([[1, 1], [2, 1]]).astype("int64")

        self.inputs = {'X': xnp, 'Index': index}
        # tuple(index.T) == ([1, 2], [1, 1]) -> [xnp[1, 1], xnp[2, 1]].
        self.outputs = {'Out': xnp[tuple(index.T)]}

    def test_check_output(self):
        """Compare the op's output against self.outputs."""
        self.check_output(check_eager=False)

    def test_check_grad(self):
        """Check the gradient of Out with respect to X."""
        self.check_grad(['X'], 'Out', check_eager=False)


class TestGatherNdOpWithHighRankSame(OpTest):
    """High-rank X whose every dimension is addressed by Index.

    X has rank 5 and Index has shape (2, 5): Index.shape[-1] equals X's
    rank, so each of the two Index rows is a full coordinate and the
    expected output holds two elements of X.
    """

    def setUp(self):
        """Build a rank-5 X, random in-bounds coordinates and the reference."""
        self.op_type = "gather_nd"
        self.python_api = paddle.gather_nd
        shape = (5, 2, 3, 1, 10)
        xnp = np.random.rand(*shape).astype("float64")
        # Two random in-bounds values per dimension -> (5, 2); transposing
        # yields two coordinates of length 5 -> (2, 5).
        index = np.vstack([np.random.randint(0, s, size=2) for s in shape]).T

        self.inputs = {'X': xnp, 'Index': index.astype("int32")}
        self.outputs = {'Out': xnp[tuple(index.T)]}

    def test_check_output(self):
        """Compare the op's output against self.outputs."""
        self.check_output(check_eager=False)

    def test_check_grad(self):
        """Check the gradient of Out with respect to X."""
        self.check_grad(['X'], 'Out', check_eager=False)


class TestGatherNdOpWithHighRankDiff(OpTest):
    """High-rank X and high-rank Index, with Rank(Index) < Rank(X).

    X has rank 5; Index has shape (20, 5, 2, 5), i.e. a (20, 5, 2) grid of
    full coordinates into X, so the expected output has shape (20, 5, 2).
    """

    def setUp(self):
        """Build a rank-5 X, a rank-4 Index and the reference output."""
        self.op_type = "gather_nd"
        self.python_api = paddle.gather_nd
        shape = (2, 3, 4, 1, 10)
        xnp = np.random.rand(*shape).astype("float64")
        # 200 random in-bounds coordinates of length 5 -> (200, 5).
        index = np.vstack([np.random.randint(0, s, size=200) for s in shape]).T
        index_re = index.reshape([20, 5, 2, 5])

        self.inputs = {'X': xnp, 'Index': index_re.astype("int32")}
        # Gather with the flat (200, 5) index, then reshape into the same
        # (20, 5, 2) grid used for index_re (both reshapes are C-order).
        self.outputs = {'Out': xnp[tuple(index.T)].reshape([20, 5, 2])}

    def test_check_output(self):
        """Compare the op's output against self.outputs."""
        self.check_output(check_eager=False)

    def test_check_grad(self):
        """Check the gradient of Out with respect to X."""
        self.check_grad(['X'], 'Out', check_eager=False)


# Test Python API
class TestGatherNdOpAPI(unittest.TestCase):
    """Smoke tests: building ``fluid.layers.gather_nd`` must not raise.

    Each case builds its graph in a fresh Program so the data variables do
    not accumulate in the global default Program.
    """

    def test_case1(self):
        """Rank-4 X with an int32 Index."""
        with fluid.program_guard(fluid.Program(), fluid.Program()):
            x1 = fluid.layers.data(name='x1',
                                   shape=[30, 40, 50, 60],
                                   dtype='float32')
            index1 = fluid.layers.data(name='index1',
                                       shape=[2, 4],
                                       dtype='int32')
            fluid.layers.gather_nd(x1, index1)

    def test_case2(self):
        """Rank-3 X with an int64 Index."""
        with fluid.program_guard(fluid.Program(), fluid.Program()):
            x2 = fluid.layers.data(name='x2',
                                   shape=[30, 40, 50],
                                   dtype='float32')
            index2 = fluid.layers.data(name='index2',
                                       shape=[2, 2],
                                       dtype='int64')
            fluid.layers.gather_nd(x2, index2)

    def test_case3(self):
        """Explicit ``name`` argument is accepted."""
        with fluid.program_guard(fluid.Program(), fluid.Program()):
            x3 = fluid.layers.data(name='x3', shape=[3, 4, 5], dtype='float32')
            index3 = fluid.layers.data(name='index3',
                                       shape=[2, 1],
                                       dtype='int32')
            fluid.layers.gather_nd(x3, index3, name="gather_nd_layer")


# Test that an Index whose last dimension exceeds X's rank is rejected
class TestGatherNdOpRaise(unittest.TestCase):

    def test_check_raise(self):
        """Building gather_nd must fail with the shape[-1] > rank message.

        Index's last dimension (10) is larger than X's rank, so graph
        construction is expected to raise an error mentioning that limit.
        """
        msg = "Input(Index).shape[-1] should be no greater than Input(X).rank"
        with fluid.program_guard(fluid.Program(), fluid.Program()):
            x = fluid.layers.data(name='x', shape=[3, 4, 5], dtype='float32')
            index = fluid.layers.data(name='index',
                                      shape=[2, 10],
                                      dtype='int32')
            # The concrete exception type raised by Paddle is not pinned
            # here; the message check below is what identifies the error.
            with self.assertRaises(Exception) as ctx:
                fluid.layers.gather_nd(x, index)
        self.assertIn(msg, str(ctx.exception))


class TestGatherNdError(unittest.TestCase):
    """paddle.gather_nd must reject bad argument types and dtypes."""

    def test_error(self):
        """Non-Variable X, non-Variable Index and float Index raise TypeError."""
        with paddle.static.program_guard(paddle.static.Program(),
                                         paddle.static.Program()):

            shape = [8, 9, 6]
            x = paddle.fluid.data(shape=shape, dtype='float32', name='x')
            index = paddle.fluid.data(shape=shape, dtype='bool', name='index')
            index_float = paddle.fluid.data(shape=shape,
                                            dtype='float32',
                                            name='index_float')
            np_x = np.random.random(shape).astype('float32')
            np_index = np.array(np.random.randint(2, size=shape, dtype=bool))

            # X given as a NumPy array instead of a static Variable.
            with self.assertRaises(TypeError):
                paddle.gather_nd(np_x, index)

            # Index given as a NumPy array instead of a static Variable.
            with self.assertRaises(TypeError):
                paddle.gather_nd(x, np_index)

            # Index with a floating-point dtype.
            with self.assertRaises(TypeError):
                paddle.gather_nd(x, index_float)


class TestGatherNdAPI2(unittest.TestCase):
    """End-to-end checks of paddle.gather_nd in static and dygraph mode.

    Both cases gather row 1 of [[1, 2], [3, 4], [5, 6]] with Index [[1]]
    and expect [[3, 4]].
    """

    def test_static(self):
        """Run gather_nd through a static-graph Executor on CPU."""
        with fluid.program_guard(fluid.Program(), fluid.Program()):
            data1 = fluid.layers.data('data1', shape=[-1, 2], dtype='float64')
            index = fluid.layers.data('index', shape=[-1, 1], dtype='int32')
            out = paddle.gather_nd(data1, index)
            place = fluid.CPUPlace()
            exe = fluid.Executor(place)
            x_np = np.array([[1, 2], [3, 4], [5, 6]])
            index_np = np.array([[1]])
            result, = exe.run(feed={
                "data1": x_np,
                "index": index_np
            },
                              fetch_list=[out])
            expected_output = np.array([[3, 4]])
        np.testing.assert_allclose(result, expected_output, rtol=1e-05)

    def test_imperative(self):
        """Run gather_nd eagerly (dygraph mode)."""
        paddle.disable_static()
        # Restore static mode even if an assertion fails, so later tests
        # that build static Programs are not left running in dygraph mode.
        try:
            x_np = np.array([[1, 2], [3, 4], [5, 6]])
            index_np = np.array([[1]])
            x = fluid.dygraph.to_variable(x_np)
            index = fluid.dygraph.to_variable(index_np)
            # Exercise the API under test (previously fluid.layers.gather).
            output = paddle.gather_nd(x, index)
            expected_output = np.array([[3, 4]])
            np.testing.assert_allclose(output.numpy(),
                                       expected_output,
                                       rtol=1e-05)
        finally:
            paddle.enable_static()


if __name__ == "__main__":
    # Most tests above build fluid/static Programs, so run the suite in
    # static-graph mode.
    paddle.enable_static()
    unittest.main()