test_gather_nd_op.py 9.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid as fluid
21
import paddle
22 23 24


class TestGatherNdOpWithEmptyIndex(OpTest):
    """gather_nd with an empty index: each lookup copies the entire tensor.

    Index has shape (2, 0). Each of its two rows is an empty coordinate, so
    each selects the whole of X, and the expected Out stacks two copies of X.
    """

    def setUp(self):
        self.op_type = "gather_nd"
        self.python_api = paddle.gather_nd
        x = np.random.random((5, 20)).astype("float64")
        empty_index = np.array([[], []]).astype("int32")
        self.inputs = {'X': x, 'Index': empty_index}
        # Two full copies of X along a new leading axis -> shape (2, 5, 20).
        self.outputs = {'Out': np.stack([x, x])}

    def test_check_output(self):
        self.check_output(check_eager=False)

    def test_check_grad(self):
        self.check_grad(['X'], 'Out', check_eager=False)
41 42


43
class TestGatherNdOpWithIndex1(OpTest):
    """gather_nd with a 1-D, single-element index into a 2-D X."""

    def setUp(self):
        self.op_type = "gather_nd"
        self.python_api = paddle.gather_nd
        x = np.random.random((5, 20)).astype("float64")
        index = np.array([1]).astype("int32")
        self.inputs = {'X': x, 'Index': index}
        # Reference via NumPy integer-array indexing with the same index.
        # NOTE(review): x[index] has shape (1, 20), while gather_nd's shape
        # rule suggests (20,); presumably the output checker tolerates the
        # leading singleton axis -- confirm.
        self.outputs = {'Out': x[index]}

    def test_check_output(self):
        self.check_output(check_eager=False)

    def test_check_grad(self):
        self.check_grad(['X'], 'Out', check_eager=False)
57 58


59
class TestGatherNdOpWithLowIndex(OpTest):
    """Index has low rank, X has high rank.

    Each row of Index holds a single coordinate (Index.shape[-1] == 1), so
    every row selects a whole row of the 2-D X.
    """

    def setUp(self):
        self.op_type = "gather_nd"
        self.python_api = paddle.gather_nd
        xnp = np.random.uniform(0, 100, (10, 10)).astype("float64")
        index = np.array([[1], [2]]).astype("int64")

        self.inputs = {'X': xnp, 'Index': index}

        # index.T is [[1, 2]], so this selects rows 1 and 2 of X -> (2, 10).
        # X is random, so no literal expected values can be stated here.
        self.outputs = {'Out': xnp[tuple(index.T)]}

    def test_check_output(self):
        self.check_output(check_eager=False)

    def test_check_grad(self):
        self.check_grad(['X'], 'Out', check_eager=False)
77 78


79 80 81 82 83
class TestGatherNdOpIndex1(OpTest):
    """Index is a single 1-D coordinate whose length equals Rank(X).

    Index is [1, 2] and X is 2-D, so the op gathers exactly one scalar
    element, X[1, 2].
    """

    def setUp(self):
        self.op_type = "gather_nd"
        self.python_api = paddle.gather_nd
        xnp = np.random.uniform(0, 100, (10, 10)).astype("float64")
        index = np.array([1, 2]).astype("int32")

        self.inputs = {'X': xnp, 'Index': index}

        # For a 1-D index, .T is a no-op and tuple(index) == (1, 2), so the
        # reference is the scalar X[1, 2].
        self.outputs = {'Out': xnp[tuple(index.T)]}

    def test_check_output(self):
        self.check_output(check_eager=False)

    def test_check_grad(self):
        self.check_grad(['X'], 'Out', check_eager=False)
97 98


99
class TestGatherNdOpWithSameIndexAsX(OpTest):
    """Index has the same rank as X.

    Each row of Index is a full (row, col) coordinate into the 2-D X, so the
    op gathers one scalar per index row.
    """

    def setUp(self):
        self.op_type = "gather_nd"
        self.python_api = paddle.gather_nd
        xnp = np.random.uniform(0, 100, (10, 10)).astype("float64")
        index = np.array([[1, 1], [2, 1]]).astype("int64")

        self.inputs = {'X': xnp, 'Index': index}
        # Selects [X[1, 1], X[2, 1]] -> shape (2,). X is random, so no
        # literal expected values can be stated here.
        self.outputs = {'Out': xnp[tuple(index.T)]}

    def test_check_output(self):
        self.check_output(check_eager=False)

    def test_check_grad(self):
        self.check_grad(['X'], 'Out', check_eager=False)
116 117 118


class TestGatherNdOpWithHighRankSame(OpTest):
    """X has high rank and every Index row is a full coordinate into X.

    Index has shape (2, 5) (rank 2) and X is rank 5, so
    Index.shape[-1] == Rank(X). Each of the two index rows gathers a single
    scalar, and Out has shape (2,).
    """

    def setUp(self):
        self.op_type = "gather_nd"
        self.python_api = paddle.gather_nd
        shape = (5, 2, 3, 1, 10)
        xnp = np.random.rand(*shape).astype("float64")
        # One row of 2 random in-range positions per axis -> (5, 2); .T gives
        # two 5-element coordinates -> (2, 5).
        index = np.vstack([np.random.randint(0, s, size=2) for s in shape]).T

        self.inputs = {'X': xnp, 'Index': index.astype("int32")}
        self.outputs = {'Out': xnp[tuple(index.T)]}

    def test_check_output(self):
        self.check_output(check_eager=False)

    def test_check_grad(self):
        self.check_grad(['X'], 'Out', check_eager=False)
136 137 138


class TestGatherNdOpWithHighRankDiff(OpTest):
    """Both Index and X have high rank, and Rank(Index) < Rank(X).

    200 full coordinates into a (2, 3, 4, 1, 10) tensor are laid out as an
    Index of shape (20, 5, 2, 5), so the expected Out has shape (20, 5, 2).
    """

    def setUp(self):
        self.op_type = "gather_nd"
        self.python_api = paddle.gather_nd
        shape = (2, 3, 4, 1, 10)
        xnp = np.random.rand(*shape).astype("float64")
        # One column of 200 random in-range positions per axis -> (200, 5).
        coords = np.stack(
            [np.random.randint(0, s, size=200) for s in shape], axis=1)
        batched_coords = coords.reshape([20, 5, 2, 5])

        self.inputs = {'X': xnp, 'Index': batched_coords.astype("int32")}
        # Gather on the flat (200, 5) view, then restore the batch dims.
        flat_out = xnp[tuple(coords.T)]
        self.outputs = {'Out': flat_out.reshape([20, 5, 2])}

    def test_check_output(self):
        self.check_output(check_eager=False)

    def test_check_grad(self):
        self.check_grad(['X'], 'Out', check_eager=False)
157 158 159


# Test Python API
class TestGatherNdOpAPI(unittest.TestCase):
    """Graph-construction smoke tests for ``fluid.layers.gather_nd``.

    Each case only checks that appending the op raises no error; nothing is
    executed. Every case builds into its own fresh Program, so the data
    variables it declares are not added to the global default program.
    """

    def test_case1(self):
        with fluid.program_guard(fluid.Program(), fluid.Program()):
            x1 = fluid.layers.data(name='x1',
                                   shape=[30, 40, 50, 60],
                                   dtype='float32')
            index1 = fluid.layers.data(name='index1',
                                       shape=[2, 4],
                                       dtype='int32')
            output1 = fluid.layers.gather_nd(x1, index1)
            self.assertIsNotNone(output1)

    def test_case2(self):
        with fluid.program_guard(fluid.Program(), fluid.Program()):
            x2 = fluid.layers.data(name='x2',
                                   shape=[30, 40, 50],
                                   dtype='float32')
            index2 = fluid.layers.data(name='index2',
                                       shape=[2, 2],
                                       dtype='int64')
            output2 = fluid.layers.gather_nd(x2, index2)
            self.assertIsNotNone(output2)

    def test_case3(self):
        with fluid.program_guard(fluid.Program(), fluid.Program()):
            x3 = fluid.layers.data(name='x3', shape=[3, 4, 5], dtype='float32')
            index3 = fluid.layers.data(name='index3',
                                       shape=[2, 1],
                                       dtype='int32')
            # Also exercises the optional ``name`` argument.
            output3 = fluid.layers.gather_nd(x3,
                                             index3,
                                             name="gather_nd_layer")
            self.assertIsNotNone(output3)


# Test Raise Index Error
class TestGatherNdOpRaise(unittest.TestCase):
    """Building gather_nd with Index.shape[-1] > Rank(X) must fail."""

    def test_check_raise(self):
        expected_msg = \
            "Input(Index).shape[-1] should be no greater than Input(X).rank"
        with fluid.program_guard(fluid.Program(), fluid.Program()):
            x = fluid.layers.data(name='x', shape=[3, 4, 5], dtype='float32')
            # The last dimension of Index (10) exceeds the rank of X.
            index = fluid.layers.data(name='index',
                                      shape=[2, 10],
                                      dtype='int32')
            with self.assertRaises(Exception) as ctx:
                fluid.layers.gather_nd(x, index)
        # Check the message instead of re-raising a sentinel IndexError, so an
        # unrelated failure surfaces with its real text rather than being
        # swallowed.
        self.assertIn(expected_msg, str(ctx.exception))


203
class TestGatherNdError(unittest.TestCase):
    """paddle.gather_nd rejects non-Variable inputs and a float index."""

    def test_error(self):
        main_prog = paddle.static.Program()
        startup_prog = paddle.static.Program()
        with paddle.static.program_guard(main_prog, startup_prog):
            shape = [8, 9, 6]
            x = paddle.fluid.data(shape=shape, dtype='float32', name='x')
            index = paddle.fluid.data(shape=shape, dtype='bool', name='index')
            index_float = paddle.fluid.data(shape=shape,
                                            dtype='float32',
                                            name='index_float')
            np_x = np.random.random(shape).astype('float32')
            np_index = np.random.randint(2, size=shape, dtype=bool)

            # A raw numpy array passed as X.
            self.assertRaises(TypeError, paddle.gather_nd, np_x, index)
            # A raw numpy array passed as Index.
            self.assertRaises(TypeError, paddle.gather_nd, x, np_index)
            # A floating-point Variable passed as Index.
            self.assertRaises(TypeError, paddle.gather_nd, x, index_float)


class TestGatherNdAPI2(unittest.TestCase):
    """End-to-end checks of ``paddle.gather_nd`` in static and dygraph mode."""

    def test_static(self):
        with fluid.program_guard(fluid.Program(), fluid.Program()):
            data1 = fluid.layers.data('data1', shape=[-1, 2], dtype='float64')
            index = fluid.layers.data('index', shape=[-1, 1], dtype='int32')
            out = paddle.gather_nd(data1, index)
            place = fluid.CPUPlace()
            exe = fluid.Executor(place)
            x_np = np.array([[1, 2], [3, 4], [5, 6]])
            index_np = np.array([[1]])
            result, = exe.run(feed={
                "data1": x_np,
                "index": index_np
            },
                              fetch_list=[out])
            # Index [[1]] holds one single-element coordinate -> row 1 of X.
            expected_output = np.array([[3, 4]])
        self.assertTrue(np.allclose(result, expected_output))

    def test_imperative(self):
        paddle.disable_static()
        # Restore static mode even if an assertion fails, so later tests in
        # this process are not left running in dygraph mode.
        try:
            x_np = np.array([[1, 2], [3, 4], [5, 6]])
            index_np = np.array([[1]])
            x = fluid.dygraph.to_variable(x_np)
            index = fluid.dygraph.to_variable(index_np)
            # Exercise gather_nd itself: the same inputs as test_static must
            # give the same result, row 1 of x.
            output = paddle.gather_nd(x, index)
            expected_output = np.array([[3, 4]])
            self.assertTrue(np.allclose(output.numpy(), expected_output))
        finally:
            paddle.enable_static()


266
if __name__ == "__main__":
    # The fluid.layers-based tests build static programs, and
    # TestGatherNdAPI2.test_imperative switches back to static when done, so
    # static-graph mode is the expected default while the suite runs.
    paddle.enable_static()
    unittest.main()