# test_unsqueeze2_op.py
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

import paddle
import paddle.fluid as fluid
from op_test import OpTest

# Tests in this module default to static-graph mode; any test that needs
# dygraph mode switches into it explicitly and restores static mode after.
paddle.enable_static()


# Correct: General.
class TestUnsqueezeOp(OpTest):
    """Checks the ``unsqueeze2`` op with ``axes`` passed as an attribute.

    Subclasses override ``init_test_case`` to set ``ori_shape`` (input
    shape), ``axes`` (where size-1 dims are inserted) and ``new_shape``
    (expected output shape).
    """

    def setUp(self):
        self.init_test_case()
        self.op_type = "unsqueeze2"
        self.python_api = paddle.unsqueeze
        self.python_out_sig = ["Out"]
        self.inputs = {"X": np.random.random(self.ori_shape).astype("float64")}
        self.init_attrs()
        self.outputs = {
            # The expected output is the input reshaped to ``new_shape``.
            "Out": self.inputs["X"].reshape(self.new_shape),
            # XShape is excluded from checking via ``no_check_set`` in
            # ``test_check_output``; this value is only a placeholder.
            "XShape": np.random.random(self.ori_shape).astype("float64"),
        }

    def test_check_output(self):
        self.check_output(no_check_set=["XShape"], check_eager=True)

    def test_check_grad(self):
        self.check_grad(["X"], "Out", check_eager=True)

    def init_test_case(self):
        self.ori_shape = (3, 40)
        self.axes = (1, 2)
        self.new_shape = (3, 1, 1, 40)

    def init_attrs(self):
        self.attrs = {"axes": self.axes}


# Correct: Single axis given as a negative index.
class TestUnsqueezeOp1(TestUnsqueezeOp):
    """Axis -1 on a (20, 5) input yields (20, 5, 1)."""

    def init_test_case(self):
        self.ori_shape = (20, 5)
        self.axes = (-1, )
        self.new_shape = (20, 5, 1)


# Correct: Mixed positive and negative axes.
class TestUnsqueezeOp2(TestUnsqueezeOp):
    """Axes (0, -1) on a (20, 5) input yield (1, 20, 5, 1)."""

    def init_test_case(self):
        self.ori_shape = (20, 5)
        self.axes = (0, -1)
        self.new_shape = (1, 20, 5, 1)


# Correct: There is a duplicated axis.
class TestUnsqueezeOp3(TestUnsqueezeOp):
    """Axes (0, 3, 3) on a (10, 2, 5) input yield (1, 10, 2, 1, 1, 5)."""

    def init_test_case(self):
        self.ori_shape = (10, 2, 5)
        self.axes = (0, 3, 3)
        self.new_shape = (1, 10, 2, 1, 1, 5)


# Correct: Reversed (descending) axes, including a duplicate.
class TestUnsqueezeOp4(TestUnsqueezeOp):
    """Axes (3, 1, 1) on a (10, 2, 5) input yield (10, 1, 1, 2, 5, 1)."""

    def init_test_case(self):
        self.ori_shape = (10, 2, 5)
        self.axes = (3, 1, 1)
        self.new_shape = (10, 1, 1, 2, 5, 1)


class TestUnsqueezeOp_ZeroDim1(TestUnsqueezeOp):
    """A 0-D input unsqueezed at axis -1 becomes shape (1,)."""

    def init_test_case(self):
        self.ori_shape = ()
        self.axes = (-1, )
        # ``(1)`` is just the int 1; spell the 1-D shape as a tuple like the
        # sibling cases do.
        self.new_shape = (1, )


class TestUnsqueezeOp_ZeroDim2(TestUnsqueezeOp):
    """A 0-D input unsqueezed at axes (-1, 1) becomes shape (1, 1)."""

    def init_test_case(self):
        self.axes = (-1, 1)
        self.ori_shape = tuple()
        # Each requested axis contributes one size-1 dimension.
        self.new_shape = (1, ) * len(self.axes)


class TestUnsqueezeOp_ZeroDim3(TestUnsqueezeOp):
    """A 0-D input unsqueezed at axes (0, 1, 2) becomes shape (1, 1, 1)."""

    def init_test_case(self):
        self.ori_shape = tuple()
        self.axes = tuple(range(3))
        # Each requested axis contributes one size-1 dimension.
        self.new_shape = (1, ) * len(self.axes)


# axes given as a list of tensors
class TestUnsqueezeOp_AxesTensorList(OpTest):
    """Checks ``unsqueeze2`` with axes fed through ``AxesTensorList``.

    Each axis is supplied as its own one-element int32 array instead of via
    the ``axes`` attribute, which ``init_attrs`` leaves empty.
    """

    def setUp(self):
        self.init_test_case()
        self.op_type = "unsqueeze2"
        self.python_out_sig = ["Out"]
        self.python_api = paddle.unsqueeze

        # One ("axes<i>", int32 array of shape (1,)) pair per requested axis.
        axes_tensor_list = [("axes" + str(index), np.array([ele]).astype("int32"))
                            for index, ele in enumerate(self.axes)]

        self.inputs = {
            "X": np.random.random(self.ori_shape).astype("float64"),
            "AxesTensorList": axes_tensor_list,
        }
        self.init_attrs()
        self.outputs = {
            "Out": self.inputs["X"].reshape(self.new_shape),
            # XShape is excluded via ``no_check_set``; placeholder only.
            "XShape": np.random.random(self.ori_shape).astype("float64"),
        }

    def test_check_output(self):
        self.check_output(no_check_set=["XShape"], check_eager=True)

    def test_check_grad(self):
        self.check_grad(["X"], "Out", check_eager=True)

    def init_test_case(self):
        self.ori_shape = (20, 5)
        self.axes = (1, 2)
        self.new_shape = (20, 1, 1, 5)

    def init_attrs(self):
        self.attrs = {}


class TestUnsqueezeOp1_AxesTensorList(TestUnsqueezeOp_AxesTensorList):
    """AxesTensorList variant: axis -1 on (20, 5) yields (20, 5, 1)."""

    def init_test_case(self):
        self.ori_shape = (20, 5)
        self.axes = (-1, )
        self.new_shape = (20, 5, 1)


class TestUnsqueezeOp2_AxesTensorList(TestUnsqueezeOp_AxesTensorList):
    """AxesTensorList variant: axes (0, -1) on (20, 5) yield (1, 20, 5, 1)."""

    def init_test_case(self):
        self.ori_shape = (20, 5)
        self.axes = (0, -1)
        self.new_shape = (1, 20, 5, 1)


class TestUnsqueezeOp3_AxesTensorList(TestUnsqueezeOp_AxesTensorList):
    """AxesTensorList variant with a duplicated axis (0, 3, 3)."""

    def init_test_case(self):
        self.ori_shape = (10, 2, 5)
        self.axes = (0, 3, 3)
        self.new_shape = (1, 10, 2, 1, 1, 5)


class TestUnsqueezeOp4_AxesTensorList(TestUnsqueezeOp_AxesTensorList):
    """AxesTensorList variant with descending axes (3, 1, 1)."""

    def init_test_case(self):
        self.ori_shape = (10, 2, 5)
        self.axes = (3, 1, 1)
        self.new_shape = (10, 1, 1, 2, 5, 1)


# axes is a Tensor
class TestUnsqueezeOp_AxesTensor(OpTest):
    """Checks ``unsqueeze2`` with all axes fed as one ``AxesTensor`` input.

    The axes are supplied as a single int32 array instead of via the
    ``axes`` attribute, which ``init_attrs`` leaves empty.
    """

    def setUp(self):
        self.init_test_case()
        self.op_type = "unsqueeze2"
        self.python_out_sig = ["Out"]
        self.python_api = paddle.unsqueeze

        self.inputs = {
            "X": np.random.random(self.ori_shape).astype("float64"),
            "AxesTensor": np.array(self.axes).astype("int32"),
        }
        self.init_attrs()
        self.outputs = {
            "Out": self.inputs["X"].reshape(self.new_shape),
            # XShape is excluded via ``no_check_set``; placeholder only.
            "XShape": np.random.random(self.ori_shape).astype("float64"),
        }

    def test_check_output(self):
        self.check_output(no_check_set=["XShape"], check_eager=True)

    def test_check_grad(self):
        self.check_grad(["X"], "Out", check_eager=True)

    def init_test_case(self):
        self.ori_shape = (20, 5)
        self.axes = (1, 2)
        self.new_shape = (20, 1, 1, 5)

    def init_attrs(self):
        self.attrs = {}


class TestUnsqueezeOp1_AxesTensor(TestUnsqueezeOp_AxesTensor):
    """AxesTensor variant: axis -1 on (20, 5) yields (20, 5, 1)."""

    def init_test_case(self):
        self.ori_shape = (20, 5)
        self.axes = (-1, )
        self.new_shape = (20, 5, 1)


class TestUnsqueezeOp2_AxesTensor(TestUnsqueezeOp_AxesTensor):
    """AxesTensor variant: axes (0, -1) on (20, 5) yield (1, 20, 5, 1)."""

    def init_test_case(self):
        self.ori_shape = (20, 5)
        self.axes = (0, -1)
        self.new_shape = (1, 20, 5, 1)


class TestUnsqueezeOp3_AxesTensor(TestUnsqueezeOp_AxesTensor):
    """AxesTensor variant with a duplicated axis (0, 3, 3)."""

    def init_test_case(self):
        self.ori_shape = (10, 2, 5)
        self.axes = (0, 3, 3)
        self.new_shape = (1, 10, 2, 1, 1, 5)


class TestUnsqueezeOp4_AxesTensor(TestUnsqueezeOp_AxesTensor):
    """AxesTensor variant with descending axes (3, 1, 1)."""

    def init_test_case(self):
        self.ori_shape = (10, 2, 5)
        self.axes = (3, 1, 1)
        self.new_shape = (10, 1, 1, 2, 5, 1)


# test api
class TestUnsqueezeAPI(unittest.TestCase):
    """Static-graph tests of the unsqueeze API with each supported axis form.

    ``executed_api`` picks the function under test so a subclass can rerun
    these cases against another variant by overriding it.
    """

    def setUp(self):
        self.executed_api()

    def executed_api(self):
        self.unsqueeze = paddle.unsqueeze

    def test_api(self):
        x_np = np.random.random([3, 2, 5]).astype("float64")
        x = paddle.static.data(name='x', shape=[3, 2, 5], dtype="float64")
        positive_3_int32 = fluid.layers.fill_constant([1], "int32", 3)
        positive_1_int64 = fluid.layers.fill_constant([1], "int64", 1)
        axes_tensor_int32 = paddle.static.data(name='axes_tensor_int32',
                                               shape=[3],
                                               dtype="int32")
        axes_tensor_int64 = paddle.static.data(name='axes_tensor_int64',
                                               shape=[3],
                                               dtype="int64")

        # axis as: a Python list, a mixed list of tensors and ints, an int32
        # tensor, a bare int, and an int64 tensor.
        out_1 = self.unsqueeze(x, axis=[3, 1, 1])
        out_2 = self.unsqueeze(x, axis=[positive_3_int32, positive_1_int64, 1])
        out_3 = self.unsqueeze(x, axis=axes_tensor_int32)
        out_4 = self.unsqueeze(x, axis=3)
        out_5 = self.unsqueeze(x, axis=axes_tensor_int64)

        exe = paddle.static.Executor(place=paddle.CPUPlace())
        res_1, res_2, res_3, res_4, res_5 = exe.run(
            paddle.static.default_main_program(),
            feed={
                "x": x_np,
                "axes_tensor_int32": np.array([3, 1, 1]).astype("int32"),
                "axes_tensor_int64": np.array([3, 1, 1]).astype("int64")
            },
            fetch_list=[out_1, out_2, out_3, out_4, out_5])

        # Every case except out_4 unsqueezes at (3, 1, 1).
        expected_multi = x_np.reshape([3, 1, 1, 2, 5, 1])
        # np.testing asserts report a diff on failure and, unlike bare
        # ``assert``, are not stripped when Python runs with -O.
        np.testing.assert_array_equal(res_1, expected_multi)
        np.testing.assert_array_equal(res_2, expected_multi)
        np.testing.assert_array_equal(res_3, expected_multi)
        np.testing.assert_array_equal(res_4, x_np.reshape([3, 2, 5, 1]))
        np.testing.assert_array_equal(res_5, expected_multi)

    def test_error(self):

        def test_axes_type():
            x2 = paddle.static.data(name="x2", shape=[2, 25], dtype="int32")
            self.unsqueeze(x2, axis=2.1)

        # A float axis is expected to be rejected with TypeError.
        self.assertRaises(TypeError, test_axes_type)


class TestUnsqueezeInplaceAPI(TestUnsqueezeAPI):
    """Reruns the TestUnsqueezeAPI cases against ``paddle.unsqueeze_``."""

    def executed_api(self):
        self.unsqueeze = paddle.unsqueeze_


class TestUnsqueezeAPI_ZeroDim(unittest.TestCase):
    """Dygraph checks of ``paddle.unsqueeze`` on a 0-D tensor.

    Verifies the output shape and that gradients keep the shapes of their
    respective tensors (``[]`` for the input).
    """

    def test_dygraph(self):
        paddle.disable_static()
        flag = "FLAGS_retain_grad_for_all_tensor"
        # Remember the caller's flag value so it can be restored afterwards;
        # the flag is needed here so ``out.grad`` is retained.
        old_flags = fluid.get_flags([flag])
        fluid.set_flags({flag: True})
        try:
            x = paddle.rand([])
            x.stop_gradient = False

            for axes, expected_shape in (
                ([-1], [1]),
                ([-1, 1], [1, 1]),
                ([0, 1, 2], [1, 1, 1]),
            ):
                with self.subTest(axes=axes):
                    out = paddle.unsqueeze(x, axes)
                    out.backward()
                    self.assertEqual(out.shape, expected_shape)
                    self.assertEqual(x.grad.shape, [])
                    self.assertEqual(out.grad.shape, expected_shape)
        finally:
            # Restore global state even if an assertion failed, so later
            # tests still run in static mode with the original flag value.
            fluid.set_flags(old_flags)
            paddle.enable_static()


if __name__ == "__main__":
    unittest.main()