# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from eager_op_test import OpTest

import paddle

# Tests in this module run in static-graph mode unless they switch explicitly.
paddle.enable_static()


# Correct: General.
class TestUnsqueezeOp(OpTest):
    """Base OpTest for ``unsqueeze2`` with ``axes`` given as an attribute.

    Subclasses override ``init_test_case`` to set:
      * ``ori_shape`` -- shape of the random float64 input ``X``;
      * ``axes``      -- the ``axes`` attribute passed to the op;
      * ``new_shape`` -- the expected shape of ``Out``.
    """

    def setUp(self):
        self.init_test_case()
        self.op_type = "unsqueeze2"
        self.python_api = paddle.unsqueeze
        self.public_python_api = paddle.unsqueeze
        self.python_out_sig = ["Out"]
        self.inputs = {"X": np.random.random(self.ori_shape).astype("float64")}
        self.init_attrs()
        self.outputs = {
            # Unsqueeze only inserts size-1 dims, so the expected output is
            # the input data reshaped to the expected shape.
            "Out": self.inputs["X"].reshape(self.new_shape),
            # Placeholder only: XShape is listed in ``no_check_set`` in
            # test_check_output, so its values are never compared.
            "XShape": np.random.random(self.ori_shape).astype("float64"),
        }
        self.prim_op_type = "comp"
        self.if_enable_cinn()

    def if_enable_cinn(self):
        # Override hook for subclasses; intentionally a no-op here.
        pass

    def test_check_output(self):
        self.check_output(no_check_set=["XShape"], check_prim=True)

    def test_check_grad(self):
        self.check_grad(["X"], "Out", check_prim=True)

    def init_test_case(self):
        self.ori_shape = (3, 40)
        self.axes = (1, 2)
        self.new_shape = (3, 1, 1, 40)

    def init_attrs(self):
        self.attrs = {"axes": self.axes}


# Correct: Single input index.
class TestUnsqueezeOp1(TestUnsqueezeOp):
    """Single negative axis: (20, 5) with axes (-1,) -> (20, 5, 1)."""

    def init_test_case(self):
        self.ori_shape = (20, 5)
        self.axes = (-1,)
        self.new_shape = (20, 5, 1)


# Correct: Mixed input axis.
class TestUnsqueezeOp2(TestUnsqueezeOp):
    """Mixed-sign axes: (20, 5) with axes (0, -1) -> (1, 20, 5, 1)."""

    def init_test_case(self):
        self.ori_shape = (20, 5)
        self.axes = (0, -1)
        self.new_shape = (1, 20, 5, 1)


# Correct: There is duplicated axis.
class TestUnsqueezeOp3(TestUnsqueezeOp):
    """Repeated axis: (10, 2, 5) with axes (0, 3, 3) -> (1, 10, 2, 1, 1, 5)."""

    def init_test_case(self):
        self.ori_shape = (10, 2, 5)
        self.axes = (0, 3, 3)
        self.new_shape = (1, 10, 2, 1, 1, 5)


# Correct: Reversed axes.
class TestUnsqueezeOp4(TestUnsqueezeOp):
    """Non-ascending axes: (10, 2, 5) with axes (3, 1, 1) -> (10, 1, 1, 2, 5, 1)."""

    def init_test_case(self):
        self.ori_shape = (10, 2, 5)
        self.axes = (3, 1, 1)
        self.new_shape = (10, 1, 1, 2, 5, 1)


class TestUnsqueezeOp_ZeroDim1(TestUnsqueezeOp):
    """0-D input unsqueezed at axis -1 -> shape (1,)."""

    def init_test_case(self):
        self.ori_shape = ()
        self.axes = (-1,)
        # Tuple form, consistent with the other cases (reshape(1) and
        # reshape((1,)) produce the same shape).
        self.new_shape = (1,)


class TestUnsqueezeOp_ZeroDim2(TestUnsqueezeOp):
    """0-D input unsqueezed at axes (-1, 1) -> shape (1, 1)."""

    def init_test_case(self):
        # Expected result first, then the input shape and axes producing it.
        self.new_shape = (1, 1)
        self.ori_shape = ()
        self.axes = (-1, 1)


class TestUnsqueezeOp_ZeroDim3(TestUnsqueezeOp):
    """0-D input unsqueezed at axes (0, 1, 2) -> shape (1, 1, 1)."""

    def init_test_case(self):
        # Expected result first, then the input shape and axes producing it.
        self.new_shape = (1, 1, 1)
        self.ori_shape = ()
        self.axes = (0, 1, 2)


# axes is a list(with tensor)
class TestUnsqueezeOp_AxesTensorList(OpTest):
    """OpTest for ``unsqueeze2`` with axes fed via the ``AxesTensorList`` input.

    Each axis becomes its own 1-element int32 array; no ``axes`` attribute
    is set. Subclasses override ``init_test_case`` to set ``ori_shape``,
    ``axes`` and the expected ``new_shape``.
    """

    def setUp(self):
        self.init_test_case()
        self.op_type = "unsqueeze2"
        self.python_out_sig = ["Out"]
        self.python_api = paddle.unsqueeze

        # One ("axes<i>", 1-element int32 array) pair per requested axis.
        axes_tensor_list = [
            ("axes" + str(index), np.array([ele], dtype="int32"))
            for index, ele in enumerate(self.axes)
        ]

        self.inputs = {
            "X": np.random.random(self.ori_shape).astype("float64"),
            "AxesTensorList": axes_tensor_list,
        }
        self.init_attrs()
        self.outputs = {
            "Out": self.inputs["X"].reshape(self.new_shape),
            # Placeholder only: excluded from the check via no_check_set.
            "XShape": np.random.random(self.ori_shape).astype("float64"),
        }

    def test_check_output(self):
        self.check_output(no_check_set=["XShape"])

    def test_check_grad(self):
        self.check_grad(["X"], "Out")

    def init_test_case(self):
        self.ori_shape = (20, 5)
        self.axes = (1, 2)
        self.new_shape = (20, 1, 1, 5)

    def init_attrs(self):
        # Axes come only from the AxesTensorList input.
        self.attrs = {}


class TestUnsqueezeOp1_AxesTensorList(TestUnsqueezeOp_AxesTensorList):
    """Single negative axis via AxesTensorList: (20, 5) -> (20, 5, 1)."""

    def init_test_case(self):
        self.ori_shape = (20, 5)
        self.axes = (-1,)
        self.new_shape = (20, 5, 1)


class TestUnsqueezeOp2_AxesTensorList(TestUnsqueezeOp_AxesTensorList):
    """Mixed-sign axes via AxesTensorList: (20, 5) -> (1, 20, 5, 1)."""

    def init_test_case(self):
        self.ori_shape = (20, 5)
        self.axes = (0, -1)
        self.new_shape = (1, 20, 5, 1)


class TestUnsqueezeOp3_AxesTensorList(TestUnsqueezeOp_AxesTensorList):
    """Repeated axis via AxesTensorList: (10, 2, 5) -> (1, 10, 2, 1, 1, 5)."""

    def init_test_case(self):
        self.ori_shape = (10, 2, 5)
        self.axes = (0, 3, 3)
        self.new_shape = (1, 10, 2, 1, 1, 5)


class TestUnsqueezeOp4_AxesTensorList(TestUnsqueezeOp_AxesTensorList):
    """Non-ascending axes via AxesTensorList: (10, 2, 5) -> (10, 1, 1, 2, 5, 1)."""

    def init_test_case(self):
        self.ori_shape = (10, 2, 5)
        self.axes = (3, 1, 1)
        self.new_shape = (10, 1, 1, 2, 5, 1)


# axes is a Tensor
class TestUnsqueezeOp_AxesTensor(OpTest):
    """OpTest for ``unsqueeze2`` with all axes fed as one ``AxesTensor`` input.

    The axes are packed into a single int32 array; no ``axes`` attribute is
    set. Subclasses override ``init_test_case`` to set ``ori_shape``,
    ``axes`` and the expected ``new_shape``.
    """

    def setUp(self):
        self.init_test_case()
        self.op_type = "unsqueeze2"
        self.python_out_sig = ["Out"]
        self.python_api = paddle.unsqueeze

        self.inputs = {
            "X": np.random.random(self.ori_shape).astype("float64"),
            "AxesTensor": np.array(self.axes).astype("int32"),
        }
        self.init_attrs()
        self.outputs = {
            "Out": self.inputs["X"].reshape(self.new_shape),
            # Placeholder only: excluded from the check via no_check_set.
            "XShape": np.random.random(self.ori_shape).astype("float64"),
        }

    def test_check_output(self):
        self.check_output(no_check_set=["XShape"])

    def test_check_grad(self):
        self.check_grad(["X"], "Out")

    def init_test_case(self):
        self.ori_shape = (20, 5)
        self.axes = (1, 2)
        self.new_shape = (20, 1, 1, 5)

    def init_attrs(self):
        # Axes come only from the AxesTensor input.
        self.attrs = {}


class TestUnsqueezeOp1_AxesTensor(TestUnsqueezeOp_AxesTensor):
    """Single negative axis via AxesTensor: (20, 5) -> (20, 5, 1)."""

    def init_test_case(self):
        self.ori_shape = (20, 5)
        self.axes = (-1,)
        self.new_shape = (20, 5, 1)


class TestUnsqueezeOp2_AxesTensor(TestUnsqueezeOp_AxesTensor):
    """Mixed-sign axes via AxesTensor: (20, 5) -> (1, 20, 5, 1)."""

    def init_test_case(self):
        self.ori_shape = (20, 5)
        self.axes = (0, -1)
        self.new_shape = (1, 20, 5, 1)


class TestUnsqueezeOp3_AxesTensor(TestUnsqueezeOp_AxesTensor):
    """Repeated axis via AxesTensor: (10, 2, 5) -> (1, 10, 2, 1, 1, 5)."""

    def init_test_case(self):
        self.ori_shape = (10, 2, 5)
        self.axes = (0, 3, 3)
        self.new_shape = (1, 10, 2, 1, 1, 5)


class TestUnsqueezeOp4_AxesTensor(TestUnsqueezeOp_AxesTensor):
    """Non-ascending axes via AxesTensor: (10, 2, 5) -> (10, 1, 1, 2, 5, 1)."""

    def init_test_case(self):
        self.ori_shape = (10, 2, 5)
        self.axes = (3, 1, 1)
        self.new_shape = (10, 1, 1, 2, 5, 1)


# test api
class TestUnsqueezeAPI(unittest.TestCase):
    """Static-graph tests for the ``paddle.unsqueeze`` Python API.

    ``executed_api`` selects the function under test, so subclasses can
    re-run the same checks against another entry point.
    """

    def setUp(self):
        self.executed_api()

    def executed_api(self):
        self.unsqueeze = paddle.unsqueeze

    def test_api(self):
        """Exercise each supported ``axis`` form.

        The forms are: list of ints, list mixing ints and 1-element tensors,
        int32 / int64 axes tensors, and a single int.
        """
        # Build into a fresh Program so that repeated runs (e.g. from the
        # inplace subclass) neither redeclare the feed variables nor
        # re-execute earlier ops in the shared default main program.
        with paddle.static.program_guard(paddle.static.Program()):
            x_np = np.random.random([3, 2, 5]).astype("float64")
            x = paddle.static.data(name='x', shape=[3, 2, 5], dtype="float64")
            positive_3_int32 = paddle.tensor.fill_constant([1], "int32", 3)
            positive_1_int64 = paddle.tensor.fill_constant([1], "int64", 1)
            axes_tensor_int32 = paddle.static.data(
                name='axes_tensor_int32', shape=[3], dtype="int32"
            )
            axes_tensor_int64 = paddle.static.data(
                name='axes_tensor_int64', shape=[3], dtype="int64"
            )

            out_1 = self.unsqueeze(x, axis=[3, 1, 1])
            out_2 = self.unsqueeze(
                x, axis=[positive_3_int32, positive_1_int64, 1]
            )
            out_3 = self.unsqueeze(x, axis=axes_tensor_int32)
            out_4 = self.unsqueeze(x, axis=3)
            out_5 = self.unsqueeze(x, axis=axes_tensor_int64)

            exe = paddle.static.Executor(place=paddle.CPUPlace())
            # Inside program_guard the default main program is the fresh one.
            res_1, res_2, res_3, res_4, res_5 = exe.run(
                paddle.static.default_main_program(),
                feed={
                    "x": x_np,
                    "axes_tensor_int32": np.array([3, 1, 1]).astype("int32"),
                    "axes_tensor_int64": np.array([3, 1, 1]).astype("int64"),
                },
                fetch_list=[out_1, out_2, out_3, out_4, out_5],
            )

        # Every axes form equivalent to (3, 1, 1) must give (3, 1, 1, 2, 5, 1).
        expected_3_1_1 = x_np.reshape([3, 1, 1, 2, 5, 1])
        np.testing.assert_array_equal(res_1, expected_3_1_1)
        np.testing.assert_array_equal(res_2, expected_3_1_1)
        np.testing.assert_array_equal(res_3, expected_3_1_1)
        np.testing.assert_array_equal(res_4, x_np.reshape([3, 2, 5, 1]))
        np.testing.assert_array_equal(res_5, expected_3_1_1)

    def test_error(self):
        """A non-integer ``axis`` is rejected with ``TypeError``."""
        with paddle.static.program_guard(paddle.static.Program()):
            x2 = paddle.static.data(name="x2", shape=[2, 25], dtype="int32")
            with self.assertRaises(TypeError):
                self.unsqueeze(x2, axis=2.1)


class TestUnsqueezeInplaceAPI(TestUnsqueezeAPI):
    """Re-run the TestUnsqueezeAPI checks against ``paddle.unsqueeze_``."""

    def executed_api(self):
        self.unsqueeze = paddle.unsqueeze_


class TestUnsqueezeAPI_ZeroDim(unittest.TestCase):
    """Dygraph checks for unsqueezing a 0-D tensor and back-propagating."""

    def test_dygraph(self):
        paddle.disable_static()
        # The module runs in static mode; restore it even if an assertion
        # below fails, so later tests are not left in dygraph mode.
        self.addCleanup(paddle.enable_static)

        x = paddle.rand([])
        x.stop_gradient = False

        # (axes passed to unsqueeze, expected output shape)
        cases = (
            ([-1], [1]),
            ([-1, 1], [1, 1]),
            ([0, 1, 2], [1, 1, 1]),
        )
        for axes, expected_shape in cases:
            with self.subTest(axes=axes):
                out = paddle.unsqueeze(x, axes)
                out.retain_grads()
                out.backward()
                self.assertEqual(out.shape, expected_shape)
                # The input gradient keeps the 0-D shape of x.
                self.assertEqual(x.grad.shape, [])
                self.assertEqual(out.grad.shape, expected_shape)


if __name__ == "__main__":
    unittest.main()