# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
import unittest

import numpy as np

import paddle
import paddle.fluid as fluid
from op_test import OpTest, convert_float_to_uint16
import paddle.fluid.core as core

# Default the module to static-graph mode; individual test cases below call
# paddle.enable_static() / paddle.disable_static() themselves as needed.
paddle.enable_static()


# Correct: General.
class TestUnsqueezeOp(OpTest):
    """Check the ``unsqueeze`` op on float64 input against a NumPy reshape.

    Subclasses change the scenario by overriding ``init_test_case``, which
    sets the input shape (``ori_shape``), the axes to insert (``axes``) and
    the expected output shape (``new_shape``).
    """

    def setUp(self):
        self.op_type = "unsqueeze"
        self.init_test_case()
        self.inputs = {"X": np.random.random(self.ori_shape).astype("float64")}
        self.init_attrs()
        # Unsqueezing only inserts size-1 dimensions, so the reference output
        # is simply the input reshaped to ``new_shape``.
        expected = self.inputs["X"].reshape(self.new_shape)
        self.outputs = {"Out": expected}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        # Gradient of "Out" with respect to the single input "X".
        self.check_grad(["X"], "Out")

    def init_test_case(self):
        """Default case: insert two size-1 dims right after axis 0."""
        self.ori_shape = (3, 40)
        self.axes = (1, 2)
        self.new_shape = (3, 1, 1, 40)

    def init_attrs(self):
        """Forward ``self.axes`` to the op as its ``axes`` attribute."""
        self.attrs = {"axes": self.axes}


class TestUnsqueezeBF16Op(TestUnsqueezeOp):
    """Run the general unsqueeze case with bfloat16 data.

    Shapes (``init_test_case``), attributes (``init_attrs``) and the output
    and gradient checks are inherited from ``TestUnsqueezeOp``, so only the
    input/output construction differs. bfloat16 data is represented here as
    ``np.uint16`` (see ``self.dtype``): float32 data is generated and both the
    input and the reshaped reference are passed through
    ``convert_float_to_uint16``.
    """

    def setUp(self):
        self.init_test_case()
        self.op_type = "unsqueeze"
        self.dtype = np.uint16
        x = np.random.random(self.ori_shape).astype("float32")
        # Build the reference in float32, then convert input and output alike.
        out = x.reshape(self.new_shape)
        self.inputs = {"X": convert_float_to_uint16(x)}
        self.init_attrs()
        self.outputs = {"Out": convert_float_to_uint16(out)}


# Correct: Single input index.
class TestUnsqueezeOp1(TestUnsqueezeOp):
    """A lone negative axis appends a trailing size-1 dimension."""

    def init_test_case(self):
        self.ori_shape, self.axes = (20, 5), (-1,)
        self.new_shape = (20, 5, 1)


# Correct: Mixed input axis.
class TestUnsqueezeOp2(TestUnsqueezeOp):
    """A non-negative and a negative axis given in the same call."""

    def init_test_case(self):
        self.ori_shape, self.axes = (20, 5), (0, -1)
        self.new_shape = (1, 20, 5, 1)


# Correct: There is duplicated axis.
class TestUnsqueezeOp3(TestUnsqueezeOp):
    """Axis 3 is listed twice and yields two adjacent size-1 dimensions."""

    def init_test_case(self):
        self.ori_shape, self.axes = (10, 2, 5), (0, 3, 3)
        self.new_shape = (1, 10, 2, 1, 1, 5)


# Correct: Reversed axes.
class TestUnsqueezeOp4(TestUnsqueezeOp):
    """Axes listed in descending order, with axis 1 repeated."""

    def init_test_case(self):
        self.ori_shape, self.axes = (10, 2, 5), (3, 1, 1)
        self.new_shape = (10, 1, 1, 2, 5, 1)


class API_TestUnsqueeze(unittest.TestCase):
    """Static-graph ``paddle.unsqueeze`` with ``axis`` given as a Python list."""

    def test_out(self):
        paddle.enable_static()
        main_prog = paddle.static.Program()
        startup_prog = paddle.static.Program()
        with paddle.static.program_guard(main_prog, startup_prog):
            data1 = paddle.static.data('data1', shape=[-1, 10], dtype='float64')
            unsqueezed = paddle.unsqueeze(data1, axis=[1])
            exe = paddle.static.Executor(paddle.CPUPlace())
            # Start from the expected [5, 1, 10] array and drop axis 1 to get
            # the [5, 10] feed; unsqueezing must restore the original.
            expected = np.random.random([5, 1, 10]).astype('float64')
            feed_np = np.squeeze(expected, axis=1)
            fetched, = exe.run(feed={"data1": feed_np},
                               fetch_list=[unsqueezed])
            np.testing.assert_allclose(expected, fetched, rtol=1e-05)


class TestUnsqueezeOpError(unittest.TestCase):
    """Argument validation of ``paddle.unsqueeze`` in static-graph mode."""

    def test_errors(self):
        paddle.enable_static()
        with paddle.static.program_guard(paddle.static.Program(),
                                         paddle.static.Program()):

            # The axis of unsqueeze is given as an int, a list/tuple, or a
            # Tensor elsewhere in this file; a float axis must raise TypeError.
            def test_axes_type():
                x = paddle.static.data(shape=[-1, 10],
                                       dtype='float16',
                                       name='x3')
                paddle.unsqueeze(x, axis=3.2)

            self.assertRaises(TypeError, test_axes_type)


class API_TestUnsqueeze2(unittest.TestCase):
    """Static-graph ``paddle.unsqueeze`` with ``axis`` fed as an int32 tensor."""

    def test_out(self):
        paddle.enable_static()
        main_prog = paddle.static.Program()
        startup_prog = paddle.static.Program()
        with paddle.static.program_guard(main_prog, startup_prog):
            data1 = paddle.static.data('data1', shape=[-1, 10], dtype='float64')
            data2 = paddle.static.data('data2', shape=[1], dtype='int32')
            unsqueezed = paddle.unsqueeze(data1, axis=data2)
            exe = paddle.static.Executor(paddle.CPUPlace())
            expected = np.random.random([5, 1, 10]).astype('float64')
            axis_np = np.array([1]).astype('int32')
            # Feed the [5, 10] squeezed array plus the axis value 1.
            feed = {"data1": np.squeeze(expected, axis=1), "data2": axis_np}
            fetched, = exe.run(feed=feed, fetch_list=[unsqueezed])
            np.testing.assert_allclose(expected, fetched, rtol=1e-05)


class API_TestUnsqueeze3(unittest.TestCase):
    """Static-graph ``paddle.unsqueeze`` with ``axis`` as a [tensor, int] list."""

    def test_out(self):
        paddle.enable_static()
        main_prog = paddle.static.Program()
        startup_prog = paddle.static.Program()
        with paddle.static.program_guard(main_prog, startup_prog):
            data1 = paddle.static.data('data1', shape=[-1, 10], dtype='float64')
            data2 = paddle.static.data('data2', shape=[1], dtype='int32')
            # Axis 1 comes from the fed tensor, axis 3 is a plain Python int.
            unsqueezed = paddle.unsqueeze(data1, axis=[data2, 3])
            exe = paddle.static.Executor(paddle.CPUPlace())
            expected = np.random.random([5, 1, 10, 1]).astype('float64')
            axis_np = np.array([1]).astype('int32')
            # np.squeeze without an axis drops both size-1 dims -> [5, 10].
            feed = {"data1": np.squeeze(expected), "data2": axis_np}
            fetched, = exe.run(feed=feed, fetch_list=[unsqueezed])
            np.testing.assert_array_equal(expected, fetched)
            self.assertEqual(expected.shape, fetched.shape)


class API_TestDyUnsqueeze(unittest.TestCase):
    """Dygraph ``paddle.unsqueeze`` on int32 data with ``axis`` as a list."""

    def test_out(self):
        paddle.disable_static()
        # Draw real integers: np.random.random(...).astype("int32") truncates
        # every sample in [0, 1) to 0, which makes the value check vacuous.
        x_np = np.random.randint(-100, 100, size=[5, 1, 10]).astype("int32")
        expected = np.expand_dims(x_np, axis=1)
        output = paddle.unsqueeze(paddle.to_tensor(x_np), axis=[1])
        out_np = output.numpy()
        np.testing.assert_array_equal(expected, out_np)
        self.assertEqual(expected.shape, out_np.shape)


class API_TestDyUnsqueeze2(unittest.TestCase):
    """Dygraph ``paddle.unsqueeze`` on int32 data with ``axis`` as an int."""

    def test_out(self):
        paddle.disable_static()
        # Draw real integers: np.random.random(...).astype("int32") truncates
        # every sample in [0, 1) to 0, which makes the value check vacuous.
        x_np = np.random.randint(-100, 100, size=[5, 10]).astype("int32")
        expected = np.expand_dims(x_np, axis=1)
        output = paddle.unsqueeze(paddle.to_tensor(x_np), axis=1)
        out_np = output.numpy()
        np.testing.assert_array_equal(expected, out_np)
        self.assertEqual(expected.shape, out_np.shape)


class API_TestDyUnsqueezeAxisTensor(unittest.TestCase):
    """Dygraph ``paddle.unsqueeze`` with ``axis`` given as one 1-D tensor."""

    def test_out(self):
        paddle.disable_static()
        # Draw real integers: np.random.random(...).astype("int32") truncates
        # every sample in [0, 1) to 0, which makes the value check vacuous.
        x_np = np.random.randint(-100, 100, size=[5, 10]).astype("int32")
        expected = np.expand_dims(np.expand_dims(x_np, axis=1), axis=2)
        output = paddle.unsqueeze(paddle.to_tensor(x_np),
                                  axis=paddle.to_tensor([1, 2]))
        out_np = output.numpy()
        np.testing.assert_array_equal(expected, out_np)
        self.assertEqual(expected.shape, out_np.shape)


class API_TestDyUnsqueezeAxisTensorList(unittest.TestCase):
    """Dygraph ``paddle.unsqueeze`` with ``axis`` as a list of 1-D tensors."""

    def test_out(self):
        paddle.disable_static()
        # Draw real integers: np.random.random(...).astype("int32") truncates
        # every sample in [0, 1) to 0, which makes the value check vacuous.
        x_np = np.random.randint(-100, 100, size=[5, 10]).astype("int32")
        # np.expand_dims accepts a tuple of axes only since NumPy 1.18.0, so
        # insert the two axes one at a time.
        expected = np.expand_dims(np.expand_dims(x_np, axis=1), axis=2)
        output = paddle.unsqueeze(
            paddle.to_tensor(x_np),
            axis=[paddle.to_tensor([1]), paddle.to_tensor([2])])
        out_np = output.numpy()
        np.testing.assert_array_equal(expected, out_np)
        self.assertEqual(expected.shape, out_np.shape)


class API_TestDygraphUnSqueeze(unittest.TestCase):
    """Dygraph tests for ``paddle.unsqueeze`` across dtypes and axis forms.

    ``executed_api`` selects the function under test and stores it in
    ``self.unsqueeze``, so a subclass can rerun every case against another
    implementation by overriding only that method.
    """

    def setUp(self):
        self.executed_api()

    def executed_api(self):
        self.unsqueeze = paddle.unsqueeze

    def _check_unsqueeze(self, dtype, axis, np_axis):
        """Unsqueeze a random ``[5, 1, 10]`` array of ``dtype`` and compare.

        Args:
            dtype: NumPy dtype name of the input data.
            axis: ``axis`` argument passed to the API under test.
            np_axis: equivalent ``axis`` for ``np.expand_dims``, used to
                build the reference result.
        """
        paddle.disable_static()
        # Integers in [0, 100) fit every tested dtype (int8/uint8/int32).
        # Casting np.random.random output instead would give all zeros and
        # make the element-wise comparison vacuous.
        x_np = np.random.randint(0, 100, size=[5, 1, 10]).astype(dtype)
        expected = np.expand_dims(x_np, axis=np_axis)
        output = self.unsqueeze(paddle.to_tensor(x_np), axis=axis)
        out_np = output.numpy()
        np.testing.assert_array_equal(expected, out_np)

    def test_out(self):
        self._check_unsqueeze("int32", axis=[1], np_axis=1)

    def test_out_int8(self):
        self._check_unsqueeze("int8", axis=[1], np_axis=1)

    def test_out_uint8(self):
        self._check_unsqueeze("uint8", axis=1, np_axis=1)

    def test_axis_not_list(self):
        self._check_unsqueeze("int32", axis=1, np_axis=1)

    def test_dimension_not_1(self):
        self._check_unsqueeze("int32", axis=(1, 2), np_axis=(1, 2))


class API_TestDygraphUnSqueezeInplace(API_TestDygraphUnSqueeze):
    """Rerun the inherited dygraph unsqueeze cases with ``paddle.unsqueeze_``."""

    def executed_api(self):
        # Swap in the in-place variant; the inherited tests call
        # ``self.unsqueeze``, so they all exercise it.
        self.unsqueeze = paddle.unsqueeze_


if __name__ == "__main__":
    # Load and run every TestCase defined in this module.
    unittest.main(module="__main__")