# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from eager_op_test import OpTest

import paddle

# The op tests and TestUnsqueezeAPI below build static-graph programs
# (paddle.static.data / Executor), so static mode is the module default.
paddle.enable_static()


# Correct: General.
class TestUnsqueezeOp(OpTest):
    """Checks the ``unsqueeze2`` op with the axes given as the ``axes`` attribute.

    Subclasses override ``init_test_case`` to pick the input shape
    (``ori_shape``), the axes to insert (``axes``) and the expected output
    shape (``new_shape``); ``init_attrs`` turns those into op attributes.
    """

    def setUp(self):
        self.init_test_case()
        self.op_type = "unsqueeze2"
        # Both OpTest API hooks point at paddle.unsqueeze.
        self.python_api = self.public_python_api = paddle.unsqueeze
        self.python_out_sig = ["Out"]

        x = np.random.random(self.ori_shape).astype("float64")
        self.inputs = {"X": x}
        self.init_attrs()

        # The expected Out is just the input reshaped to new_shape. XShape is
        # excluded from the output check (no_check_set below), so placeholder
        # random data of the input shape is enough.
        self.outputs = dict(
            Out=x.reshape(self.new_shape),
            XShape=np.random.random(self.ori_shape).astype("float64"),
        )
        self.prim_op_type = "comp"

    def test_check_output(self):
        # Compare only Out; check_prim also runs the prim checks configured
        # by prim_op_type in setUp.
        self.check_output(no_check_set=["XShape"], check_prim=True)

    def test_check_grad(self):
        # Gradient of Out with respect to the single input X.
        self.check_grad(["X"], "Out")

    def init_test_case(self):
        # Two size-1 dims inserted after the leading dim of a 2-D input.
        self.axes = (1, 2)
        self.ori_shape = (3, 40)
        self.new_shape = (3, 1, 1, 40)

    def init_attrs(self):
        self.attrs = dict(axes=self.axes)


# Correct: Single input index.
class TestUnsqueezeOp1(TestUnsqueezeOp):
    """A single negative axis: -1 appends a trailing size-1 dim."""

    def init_test_case(self):
        self.axes = (-1,)
        self.ori_shape = (20, 5)
        self.new_shape = (20, 5, 1)


# Correct: Mixed positive and negative axes.
class TestUnsqueezeOp2(TestUnsqueezeOp):
    """Mixed axes: 0 prepends a size-1 dim and -1 appends one."""

    def init_test_case(self):
        self.axes = (0, -1)
        self.ori_shape = (20, 5)
        self.new_shape = (1, 20, 5, 1)


# Correct: Duplicated axis.
class TestUnsqueezeOp3(TestUnsqueezeOp):
    """Duplicated axis: the expected shape inserts 0, then 3, then 3 again
    against the growing shape, giving two adjacent size-1 dims."""

    def init_test_case(self):
        self.axes = (0, 3, 3)
        self.ori_shape = (10, 2, 5)
        self.new_shape = (1, 10, 2, 1, 1, 5)


# Correct: Reversed axes.
class TestUnsqueezeOp4(TestUnsqueezeOp):
    """Descending axes with a repeat: the expected shape inserts 3, then 1,
    then 1 against the growing shape."""

    def init_test_case(self):
        self.axes = (3, 1, 1)
        self.ori_shape = (10, 2, 5)
        self.new_shape = (10, 1, 1, 2, 5, 1)


class TestUnsqueezeOp_ZeroDim1(TestUnsqueezeOp):
    """0-D input unsqueezed at axis -1 into a 1-D tensor of length 1."""

    def init_test_case(self):
        self.ori_shape = ()
        self.axes = (-1,)
        # Tuple form, matching every other case's new_shape.
        self.new_shape = (1,)
        # NOTE(review): presumably disables OpTest's CINN check for 0-D
        # input; confirm against OpTest.
        self.enable_cinn = False


class TestUnsqueezeOp_ZeroDim2(TestUnsqueezeOp):
    """0-D input unsqueezed at axes (-1, 1) into shape (1, 1)."""

    def init_test_case(self):
        self.axes = (-1, 1)
        self.ori_shape = ()
        self.new_shape = (1, 1)
        # NOTE(review): presumably disables OpTest's CINN check for 0-D
        # input; confirm against OpTest.
        self.enable_cinn = False


class TestUnsqueezeOp_ZeroDim3(TestUnsqueezeOp):
    """0-D input unsqueezed at axes (0, 1, 2) into shape (1, 1, 1)."""

    def init_test_case(self):
        self.axes = (0, 1, 2)
        self.ori_shape = ()
        self.new_shape = (1, 1, 1)
        # NOTE(review): presumably disables OpTest's CINN check for 0-D
        # input; confirm against OpTest.
        self.enable_cinn = False


# axes is given as a list of 1-element tensors (AxesTensorList input)
class TestUnsqueezeOp_AxesTensorList(OpTest):
    """Checks ``unsqueeze2`` when the axes arrive through ``AxesTensorList``.

    The ``axes`` attribute is left empty; the axes are supplied only as a
    list of named 1-element int32 arrays, one per axis.
    """

    def setUp(self):
        self.init_test_case()
        self.op_type = "unsqueeze2"
        self.python_out_sig = ["Out"]
        self.python_api = paddle.unsqueeze

        # (name, value) pairs "axes0", "axes1", ..., each holding one axis.
        axes_tensor_list = [
            ("axes" + str(pos), np.array([axis], dtype="int32"))
            for pos, axis in enumerate(self.axes)
        ]

        x = np.random.random(self.ori_shape).astype("float64")
        self.inputs = {"X": x, "AxesTensorList": axes_tensor_list}
        self.init_attrs()
        self.outputs = {
            "Out": x.reshape(self.new_shape),
            # Placeholder only: excluded from the output check below.
            "XShape": np.random.random(self.ori_shape).astype("float64"),
        }

    def test_check_output(self):
        self.check_output(no_check_set=["XShape"])

    def test_check_grad(self):
        # Gradient of Out with respect to the single input X.
        self.check_grad(["X"], "Out")

    def init_test_case(self):
        self.axes = (1, 2)
        self.ori_shape = (20, 5)
        self.new_shape = (20, 1, 1, 5)

    def init_attrs(self):
        # No static axes attribute: the axes come from AxesTensorList.
        self.attrs = dict()


class TestUnsqueezeOp1_AxesTensorList(TestUnsqueezeOp_AxesTensorList):
    """AxesTensorList variant of a single trailing axis (-1)."""

    def init_test_case(self):
        self.axes = (-1,)
        self.ori_shape = (20, 5)
        self.new_shape = (20, 5, 1)


class TestUnsqueezeOp2_AxesTensorList(TestUnsqueezeOp_AxesTensorList):
    """AxesTensorList variant of mixed positive/negative axes (0, -1)."""

    def init_test_case(self):
        self.axes = (0, -1)
        self.ori_shape = (20, 5)
        self.new_shape = (1, 20, 5, 1)


class TestUnsqueezeOp3_AxesTensorList(TestUnsqueezeOp_AxesTensorList):
    """AxesTensorList variant of a duplicated axis (0, 3, 3)."""

    def init_test_case(self):
        self.axes = (0, 3, 3)
        self.ori_shape = (10, 2, 5)
        self.new_shape = (1, 10, 2, 1, 1, 5)


class TestUnsqueezeOp4_AxesTensorList(TestUnsqueezeOp_AxesTensorList):
    """AxesTensorList variant of descending axes with a repeat (3, 1, 1)."""

    def init_test_case(self):
        self.axes = (3, 1, 1)
        self.ori_shape = (10, 2, 5)
        self.new_shape = (10, 1, 1, 2, 5, 1)


# axes is given as a single int32 tensor (AxesTensor input)
class TestUnsqueezeOp_AxesTensor(OpTest):
    """Checks ``unsqueeze2`` when all axes arrive as one ``AxesTensor`` input.

    The ``axes`` attribute is left empty; the axes are supplied only as a
    1-D int32 array.
    """

    def setUp(self):
        self.init_test_case()
        self.op_type = "unsqueeze2"
        self.python_out_sig = ["Out"]
        self.python_api = paddle.unsqueeze

        x = np.random.random(self.ori_shape).astype("float64")
        self.inputs = {
            "X": x,
            "AxesTensor": np.asarray(self.axes, dtype="int32"),
        }
        self.init_attrs()
        self.outputs = {
            "Out": x.reshape(self.new_shape),
            # Placeholder only: excluded from the output check below.
            "XShape": np.random.random(self.ori_shape).astype("float64"),
        }

    def test_check_output(self):
        self.check_output(no_check_set=["XShape"])

    def test_check_grad(self):
        # Gradient of Out with respect to the single input X.
        self.check_grad(["X"], "Out")

    def init_test_case(self):
        self.axes = (1, 2)
        self.ori_shape = (20, 5)
        self.new_shape = (20, 1, 1, 5)

    def init_attrs(self):
        # No static axes attribute: the axes come from AxesTensor.
        self.attrs = dict()


class TestUnsqueezeOp1_AxesTensor(TestUnsqueezeOp_AxesTensor):
    """AxesTensor variant of a single trailing axis (-1)."""

    def init_test_case(self):
        self.axes = (-1,)
        self.ori_shape = (20, 5)
        self.new_shape = (20, 5, 1)


class TestUnsqueezeOp2_AxesTensor(TestUnsqueezeOp_AxesTensor):
    """AxesTensor variant of mixed positive/negative axes (0, -1)."""

    def init_test_case(self):
        self.axes = (0, -1)
        self.ori_shape = (20, 5)
        self.new_shape = (1, 20, 5, 1)


class TestUnsqueezeOp3_AxesTensor(TestUnsqueezeOp_AxesTensor):
    """AxesTensor variant of a duplicated axis (0, 3, 3)."""

    def init_test_case(self):
        self.axes = (0, 3, 3)
        self.ori_shape = (10, 2, 5)
        self.new_shape = (1, 10, 2, 1, 1, 5)


class TestUnsqueezeOp4_AxesTensor(TestUnsqueezeOp_AxesTensor):
    """AxesTensor variant of descending axes with a repeat (3, 1, 1)."""

    def init_test_case(self):
        self.axes = (3, 1, 1)
        self.ori_shape = (10, 2, 5)
        self.new_shape = (10, 1, 1, 2, 5, 1)


# Tests for the paddle.unsqueeze Python API
class TestUnsqueezeAPI(unittest.TestCase):
    """Static-graph checks of the ``paddle.unsqueeze`` Python API.

    ``executed_api`` selects the callable under test, so subclasses can rerun
    every check against another entry point (e.g. the in-place variant).
    Each test builds into its own fresh programs, so repeated runs do not
    pile vars/ops into the global default main program.
    """

    def setUp(self):
        self.executed_api()

    def executed_api(self):
        # Hook overridden by subclasses to swap the API under test.
        self.unsqueeze = paddle.unsqueeze

    def test_api(self):
        x_np = np.random.random([3, 2, 5]).astype("float64")
        axes_np = [3, 1, 1]

        main_prog = paddle.static.Program()
        startup_prog = paddle.static.Program()
        with paddle.static.program_guard(main_prog, startup_prog):
            x = paddle.static.data(name='x', shape=[3, 2, 5], dtype="float64")

            positive_3_int32 = paddle.tensor.fill_constant([1], "int32", 3)
            positive_1_int64 = paddle.tensor.fill_constant([1], "int64", 1)
            axes_tensor_int32 = paddle.static.data(
                name='axes_tensor_int32', shape=[3], dtype="int32"
            )
            axes_tensor_int64 = paddle.static.data(
                name='axes_tensor_int64', shape=[3], dtype="int64"
            )

            # The same axes (3, 1, 1) given as: a Python list; a list mixing
            # 1-element tensors of both int dtypes with an int; a whole int32
            # tensor; and a whole int64 tensor. out_4 uses a single int axis.
            out_1 = self.unsqueeze(x, axis=axes_np)
            out_2 = self.unsqueeze(
                x, axis=[positive_3_int32, positive_1_int64, 1]
            )
            out_3 = self.unsqueeze(x, axis=axes_tensor_int32)
            out_4 = self.unsqueeze(x, axis=3)
            out_5 = self.unsqueeze(x, axis=axes_tensor_int64)

        exe = paddle.static.Executor(place=paddle.CPUPlace())
        res_1, res_2, res_3, res_4, res_5 = exe.run(
            main_prog,
            feed={
                "x": x_np,
                "axes_tensor_int32": np.array(axes_np).astype("int32"),
                "axes_tensor_int64": np.array(axes_np).astype("int64"),
            },
            fetch_list=[out_1, out_2, out_3, out_4, out_5],
        )

        # np.testing checks shape and values, reports a diff on failure, and
        # is not stripped under `python -O` the way bare `assert` is.
        expected = x_np.reshape([3, 1, 1, 2, 5, 1])
        np.testing.assert_array_equal(res_1, expected)
        np.testing.assert_array_equal(res_2, expected)
        np.testing.assert_array_equal(res_3, expected)
        np.testing.assert_array_equal(res_4, x_np.reshape([3, 2, 5, 1]))
        np.testing.assert_array_equal(res_5, expected)

    def test_error(self):
        def test_axes_type():
            x2 = paddle.static.data(name="x2", shape=[2, 25], dtype="int32")
            # A float axis is rejected.
            self.unsqueeze(x2, axis=2.1)

        with paddle.static.program_guard(
            paddle.static.Program(), paddle.static.Program()
        ):
            self.assertRaises(TypeError, test_axes_type)


class TestUnsqueezeInplaceAPI(TestUnsqueezeAPI):
    """Reruns every TestUnsqueezeAPI check against ``paddle.unsqueeze_``."""

    def executed_api(self):
        # The inherited tests call self.unsqueeze, so swapping it here
        # retargets them to the in-place API.
        inplace_api = paddle.unsqueeze_
        self.unsqueeze = inplace_api


class TestUnsqueezeAPI_ZeroDim(unittest.TestCase):
    """Dygraph checks of ``paddle.unsqueeze`` on a 0-D tensor, incl. grads."""

    def test_dygraph(self):
        paddle.disable_static()
        try:
            x = paddle.rand([])
            x.stop_gradient = False

            # (axes, expected output shape)
            cases = [
                ([-1], [1]),
                ([-1, 1], [1, 1]),
                ([0, 1, 2], [1, 1, 1]),
            ]
            for axes, expected_shape in cases:
                with self.subTest(axes=axes):
                    out = paddle.unsqueeze(x, axes)
                    # Keep out.grad so its shape can be checked below.
                    out.retain_grads()
                    out.backward()
                    self.assertEqual(out.shape, expected_shape)
                    # The gradient w.r.t. the 0-D input must stay 0-D.
                    self.assertEqual(x.grad.shape, [])
                    self.assertEqual(out.grad.shape, expected_shape)
        finally:
            # Restore static mode even if an assertion failed: the module
            # enables static mode at import and its other tests build
            # static programs.
            paddle.enable_static()


if __name__ == "__main__":
    # Discover and run every TestCase defined in this module.
    unittest.main(module="__main__")