#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
from __future__ import print_function
16
import unittest
17

18
import numpy as np
19 20

import paddle
21 22
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard
23 24
from op_test import OpTest, convert_float_to_uint16
import paddle.fluid.core as core
25

26
paddle.enable_static()
27 28 29


# Correct: General.
class TestSqueezeOp(OpTest):
    """Check the ``squeeze`` op on float64 input against a NumPy reshape.

    Subclasses vary the case by overriding ``init_test_case``, which sets
    the input shape, the ``axes`` attribute and the expected output shape.
    """

    def setUp(self):
        self.op_type = "squeeze"
        self.init_test_case()
        self.inputs = {"X": np.random.random(self.ori_shape).astype("float64")}
        self.init_attrs()
        # The reference output is the input reshaped to the expected shape.
        self.outputs = {
            "Out": self.inputs["X"].reshape(self.new_shape),
        }

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(["X"], "Out")

    def init_test_case(self):
        """Set the input shape, the axes to squeeze and the expected shape."""
        self.ori_shape = (1, 3, 1, 40)
        self.axes = (0, 2)
        self.new_shape = (3, 40)

    def init_attrs(self):
        """Build the op attributes from ``self.axes``."""
        self.attrs = {"axes": self.axes}


class TestSqueezeBF16Op(OpTest):
    """Check the ``squeeze`` op on bfloat16 data.

    Input and reference output are generated in float32 and converted with
    ``convert_float_to_uint16``; ``self.dtype`` is set to ``np.uint16`` to
    match the converted tensors.
    """

    def setUp(self):
        self.op_type = "squeeze"
        self.dtype = np.uint16
        self.init_test_case()
        x = np.random.random(self.ori_shape).astype("float32")
        out = x.reshape(self.new_shape)
        self.inputs = {"X": convert_float_to_uint16(x)}
        self.init_attrs()
        self.outputs = {"Out": convert_float_to_uint16(out)}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(["X"], "Out")

    def init_test_case(self):
        """Set the input shape, the axes to squeeze and the expected shape."""
        self.ori_shape = (1, 3, 1, 40)
        self.axes = (0, 2)
        self.new_shape = (3, 40)

    def init_attrs(self):
        """Build the op attributes from ``self.axes``."""
        self.attrs = {"axes": self.axes}


# Correct: There is a negative axis.
class TestSqueezeOp1(TestSqueezeOp):
    """Same input as the base case, with axis 2 written as -2."""

    def init_test_case(self):
        self.ori_shape = (1, 3, 1, 40)
        self.axes = (0, -2)
        self.new_shape = (3, 40)


# Correct: No axes input.
class TestSqueezeOp2(TestSqueezeOp):
    """Empty ``axes``: the expected output drops every size-1 dimension."""

    def init_test_case(self):
        self.ori_shape = (1, 20, 1, 5)
        self.axes = ()
        self.new_shape = (20, 5)


# Correct: Only some of the size-1 axes are squeezed.
class TestSqueezeOp3(TestSqueezeOp):
    """Squeeze axes 1 and -1; the size-1 axis 3 is kept."""

    def init_test_case(self):
        self.ori_shape = (6, 1, 5, 1, 4, 1)
        self.axes = (1, -1)
        self.new_shape = (6, 5, 1, 4)


# Correct: An axis whose dimension is not of size 1 remains unchanged.
class TestSqueezeOp4(TestSqueezeOp):
    """Axis 2 has size 5, so only axis 1 is removed."""

    def init_test_case(self):
        self.ori_shape = (6, 1, 5, 1, 4, 1)
        self.axes = (1, 2)
        self.new_shape = (6, 5, 1, 4, 1)


class TestSqueezeOpError(unittest.TestCase):
    """Check that invalid calls to ``paddle.squeeze`` raise ``TypeError``."""

    def test_errors(self):
        paddle.enable_static()
        with program_guard(Program(), Program()):
            # The input of squeeze must be a Variable, not a LoDTensor.
            x1 = fluid.create_lod_tensor(np.array([[-1]]), [[1]],
                                         paddle.CPUPlace())
            self.assertRaises(TypeError, paddle.squeeze, x1)
            # NOTE(review): every other call in this file passes `axis=`,
            # not `axes=`. If paddle.squeeze has no `axes` parameter, the two
            # TypeErrors below come from the unexpected keyword rather than
            # from the checks their comments describe -- confirm. Note also
            # that API_TestDygraphSqueeze.test_axis_not_list passes an int
            # `axis` successfully.
            # The input axes of squeeze must be list.
            x2 = paddle.static.data(name='x2', shape=[4], dtype="int32")
            self.assertRaises(TypeError, paddle.squeeze, x2, axes=0)
            # The input dtype of squeeze not support float16.
            x3 = paddle.static.data(name='x3', shape=[4], dtype="float16")
            self.assertRaises(TypeError, paddle.squeeze, x3, axes=0)


class API_TestSqueeze(unittest.TestCase):
    """Static-graph test of the squeeze API chosen by ``executed_api``."""

    def setUp(self):
        self.executed_api()

    def executed_api(self):
        """Select the API under test; subclasses override this."""
        self.squeeze = paddle.squeeze

    def test_out(self):
        paddle.enable_static()
        with paddle.static.program_guard(paddle.static.Program(),
                                         paddle.static.Program()):
            data1 = paddle.static.data('data1',
                                       shape=[-1, 1, 10],
                                       dtype='float64')
            result_squeeze = self.squeeze(data1, axis=[1])
            place = paddle.CPUPlace()
            exe = paddle.static.Executor(place)
            input1 = np.random.random([5, 1, 10]).astype('float64')
            result, = exe.run(feed={"data1": input1},
                              fetch_list=[result_squeeze])
            expected_result = np.squeeze(input1, axis=1)
            # Same tolerances as np.allclose, but assert_allclose also
            # rejects a shape mismatch instead of broadcasting over it.
            np.testing.assert_allclose(result,
                                       expected_result,
                                       rtol=1e-05,
                                       atol=1e-08)


class API_TestStaticSqueeze_(API_TestSqueeze):
    """Run the static-graph squeeze test against ``paddle.squeeze_``."""

    def executed_api(self):
        self.squeeze = paddle.squeeze_


class API_TestDygraphSqueeze(unittest.TestCase):
    """Dygraph tests of the squeeze API chosen by ``executed_api``."""

    def setUp(self):
        self.executed_api()

    def tearDown(self):
        # Each test switches the process to dygraph mode; restore the
        # module's static-graph default so later test classes do not depend
        # on test execution order.
        paddle.enable_static()

    def executed_api(self):
        """Select the API under test; subclasses override this."""
        self.squeeze = paddle.squeeze

    def _check_squeeze(self, dtype, axis):
        """Squeeze a random (5, 1, 10) ``dtype`` tensor with ``axis``.

        The result must equal ``np.squeeze(x, axis=1)`` exactly, including
        its shape.
        """
        paddle.disable_static()
        # randint, not random(): random() lies in [0, 1), so casting it to an
        # integer dtype gives all zeros and the comparison proves nothing.
        x_np = np.random.randint(0, 100, size=[5, 1, 10]).astype(dtype)
        x = paddle.to_tensor(x_np)
        output = self.squeeze(x, axis=axis)
        expected_out = np.squeeze(x_np, axis=1)
        # Exact, shape-checked comparison (np.allclose would broadcast).
        np.testing.assert_array_equal(output.numpy(), expected_out)

    def test_out(self):
        self._check_squeeze("int32", axis=[1])

    def test_out_int8(self):
        self._check_squeeze("int8", axis=[1])

    def test_out_uint8(self):
        self._check_squeeze("uint8", axis=[1])

    def test_axis_not_list(self):
        self._check_squeeze("int32", axis=1)

    def test_dimension_not_1(self):
        # Axis 0 has size 5, so only axis 1 is removed.
        self._check_squeeze("int32", axis=(1, 0))


class API_TestDygraphSqueezeInplace(API_TestDygraphSqueeze):
    """Run the dygraph squeeze tests against ``paddle.squeeze_``."""

    def executed_api(self):
        self.squeeze = paddle.squeeze_


if __name__ == "__main__":
    unittest.main()