#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest

import numpy as np
import paddle
import paddle.fluid as fluid

from op_test import OpTest


class TestGatherOp(OpTest):
    """Compare the ``gather`` operator with NumPy integer-array indexing.

    Subclasses override :meth:`config` to choose the input shape and dtype,
    the index values and the index dtype; ``setUp`` builds the operator
    inputs and the expected output from those settings.
    """

    def setUp(self):
        self.op_type = "gather"
        self.config()
        x = np.random.random(self.x_shape).astype(self.x_type)
        index = np.array(self.index).astype(self.index_type)
        self.inputs = {'X': x, 'Index': index}
        # Expected result: the entries of X along axis 0 picked out by Index,
        # computed with NumPy fancy indexing.
        self.outputs = {'Out': x[index]}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        # Gradient is checked with respect to X only; Index is an integer input.
        self.check_grad(['X'], 'Out')

    def config(self):
        """
        For multi-dimension input
        """
        self.x_shape = (10, 20)
        self.x_type = "float64"
        self.index = [1, 3, 5]
        self.index_type = "int32"


class TestCase1(TestGatherOp):
    def config(self):
        """
        For one dimension input
        """
        # Trailing comma makes this a one-element shape tuple (a 1-D input of
        # length 100), matching the tuple shapes used by the other cases.
        self.x_shape = (100, )
        self.x_type = "float64"
        self.index = [1, 3, 5]
        self.index_type = "int32"


class TestCase2(TestGatherOp):
    def config(self):
        """
        For int64 index type with one dimension input
        """
        # One-element shape tuple: a 1-D input of length 100.
        self.x_shape = (100, )
        self.x_type = "float64"
        self.index = [1, 3, 5]
        self.index_type = "int64"


class TestCase3(TestGatherOp):
    def config(self):
        """
        For int64 index type with multi-dimension input
        """
        # Same float64 2-D input as the base case; only the index dtype differs.
        self.x_shape = (10, 20)
        self.x_type = "float64"
        self.index = [1, 3, 5]
        self.index_type = "int64"


class TestCase4(TestGatherOp):
    def config(self):
        """
        For a repeated index with the ``overwrite`` attribute set to False
        """
        self.x_shape = (10, 20)
        # NOTE(review): ``overwrite`` presumably controls how gradients for
        # repeated indices are combined -- confirm against the gather op.
        self.attrs = {'overwrite': False}
        # "float64" is the same NumPy dtype as "double"; spelled like the
        # other cases for consistency.
        self.x_type = "float64"
        self.index = [1, 1]
        self.index_type = "int32"


class TestCase5(TestGatherOp):
    def config(self):
        """
        For a mix of repeated and unique indices with ``overwrite`` False
        """
        self.attrs = {'overwrite': False}
        self.index_type = "int32"
        self.index = [1, 1, 3]
        self.x_type = "float64"
        self.x_shape = (10, 20)


class TestCase6(TestGatherOp):
    def config(self):
        """
        For unique indices with the ``overwrite`` attribute set to True
        """
        self.attrs = {'overwrite': True}
        self.index_type = "int32"
        self.index = [1, 3]
        self.x_type = "float64"
        self.x_shape = (10, 20)


class API_TestGather(unittest.TestCase):
    """Static-graph check of ``fluid.layers.gather`` on a small 3x2 input."""

    def test_out(self):
        with fluid.program_guard(fluid.Program(), fluid.Program()):
            data1 = fluid.layers.data('data1', shape=[-1, 2], dtype='float64')
            # The gather index must be an integer tensor, not floating point.
            index = fluid.layers.data('index', shape=[-1, 1], dtype='int32')
            out = fluid.layers.gather(data1, index)
            place = fluid.CPUPlace()
            exe = fluid.Executor(place)
            # Feed arrays whose dtypes match the declared placeholders.
            input_np = np.array([[1, 2], [3, 4], [5, 6]]).astype('float64')
            index_np = np.array([1, 2]).astype('int32')
            result, = exe.run(feed={"data1": input_np,
                                    "index": index_np},
                              fetch_list=[out])
            # Rows 1 and 2 of the input.
            expected_output = np.array([[3, 4], [5, 6]])
        self.assertTrue(np.allclose(result, expected_output))


class API_TestDygraphGather(unittest.TestCase):
    """Dygraph (imperative) check of ``fluid.layers.gather``."""

    def test_out(self):
        with fluid.dygraph.guard():
            input_np = np.array([[1, 2], [3, 4], [5, 6]])
            index_np = np.array([1, 2])
            x = fluid.dygraph.to_variable(input_np)
            index = fluid.dygraph.to_variable(index_np)
            output = fluid.layers.gather(x, index)
            output_np = output.numpy()
            # Rows 1 and 2 of the input.
            expected_output = np.array([[3, 4], [5, 6]])
        self.assertTrue(np.allclose(output_np, expected_output))


if __name__ == "__main__":
    # Discover and run every unittest.TestCase subclass in this module.
    unittest.main()