#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np

from op_test import OpTest
import paddle
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard


# Situation 1: `shape` is a plain list (no tensors) and no actual-shape
# tensor is provided.
class TestReshapeOp(OpTest):
    """Check reshape2 when the target shape comes only from the `shape` attr.

    Subclasses override `init_data` to supply `ori_shape` (input shape),
    `new_shape` (value passed as the `shape` attr) and `infered_shape`
    (shape handed to numpy's reshape to build the expected output).
    """

    def setUp(self):
        self.init_data()
        self.op_type = "reshape2"
        self.inputs = {"X": np.random.random(self.ori_shape).astype("float32")}
        self.attrs = {"shape": self.new_shape}
        self.outputs = {
            "Out": self.inputs["X"].reshape(self.infered_shape),
            # Placeholder only: XShape is excluded from the output check.
            'XShape': np.random.random(self.ori_shape).astype("float32")
        }

    def init_data(self):
        self.ori_shape = (2, 60)
        self.new_shape = (12, 10)
        self.infered_shape = (12, 10)

    def test_check_output(self):
        self.check_output(no_check_set=['XShape'])

    def test_check_grad(self):
        self.check_grad(["X"], "Out")


class TestReshapeOpDimInfer1(TestReshapeOp):
    """Use -1 in the `shape` attr; numpy's reshape infers that dimension."""

    def init_data(self):
        self.ori_shape = (5, 25)
        self.new_shape = (5, -1, 5)
        self.infered_shape = (5, -1, 5)


class TestReshapeOpDimInfer2(TestReshapeOp):
    """Mix 0 and -1 in the `shape` attr.

    As `infered_shape` shows, the 0 at index 1 is expected to take the
    input's dimension 1 (size 2), and the -1 is left for numpy to infer.
    """

    def init_data(self):
        self.ori_shape = (10, 2, 6)
        self.new_shape = (10, 0, 3, -1)
        self.infered_shape = (10, 2, 3, -1)


# Situation 2: `shape` is a plain list (no tensors) and an actual-shape
# tensor is provided through the "Shape" input.
class TestReshapeOpWithInputShape(OpTest):
    """Check that the "Shape" input tensor takes precedence over the attr.

    The `shape` attr is set to `new_shape`, but the expected output is
    reshaped to `actual_shape`, the value fed as the "Shape" input.
    """

    def setUp(self):
        self.init_data()
        self.op_type = "reshape2"

        self.inputs = {
            "X": np.random.random(self.ori_shape).astype("float32"),
            "Shape": np.array(
                self.actual_shape, dtype="int32")
        }
        self.attrs = {"shape": self.new_shape}
        self.outputs = {
            "Out": self.inputs["X"].reshape(self.actual_shape),
            # Placeholder only: XShape is excluded from the output check.
            'XShape': np.random.random(self.ori_shape).astype("float32")
        }

    def init_data(self):
        self.ori_shape = (6, 20)
        self.new_shape = (0, -1, 20)
        self.actual_shape = (2, 3, 20)

    def test_check_output(self):
        self.check_output(no_check_set=['XShape'])

    def test_check_grad(self):
        self.check_grad(["X"], "Out")


# Situation 3: `shape` is a list whose elements are tensors (fed through the
# "ShapeTensor" input); no actual-shape tensor is provided.
class TestReshapeOp_attr_ShapeTensor(OpTest):
    """Check reshape2 when each target dimension arrives as its own tensor.

    The `shape` attr is set from `self.shape`, while "ShapeTensor" carries
    `new_shape`. The expected output is built from `infered_shape`, so the
    tensor list (not the attr) is what is expected to determine the result.
    """

    def setUp(self):
        self.init_data()
        self.op_type = "reshape2"

        # One (name, 1-element int32 array) pair per target dimension.
        shape_tensor = [("x" + str(index), np.ones(1).astype('int32') * ele)
                        for index, ele in enumerate(self.new_shape)]

        self.inputs = {
            "X": np.random.random(self.ori_shape).astype("float32"),
            'ShapeTensor': shape_tensor
        }
        self.attrs = {'shape': self.shape}
        self.outputs = {
            "Out": self.inputs["X"].reshape(self.infered_shape),
            # Placeholder only: XShape is excluded from the output check.
            'XShape': np.random.random(self.ori_shape).astype("float32")
        }

    def init_data(self):
        self.ori_shape = (4, 25)
        self.new_shape = (10, 10)
        self.infered_shape = (10, 10)
        # Value of the `shape` attr; the expected output follows
        # ShapeTensor (new_shape) instead.
        self.shape = (-1, -1)

    def test_check_output(self):
        self.check_output(no_check_set=['XShape'])

    def test_check_grad(self):
        self.check_grad(["X"], "Out")


class TestReshapeOpDimInfer1_attr_ShapeTensor(TestReshapeOp_attr_ShapeTensor):
    """ShapeTensor variant with a -1 element to be inferred."""

    def init_data(self):
        self.ori_shape = (5, 20)
        self.new_shape = (5, -1, 20)
        self.infered_shape = (5, -1, 20)
        self.shape = (5, -1, -1)


class TestReshapeOpDimInfer2_attr_ShapeTensor(TestReshapeOp_attr_ShapeTensor):
    """ShapeTensor variant mixing 0 (copy input dim) and -1 (infer)."""

    def init_data(self):
        self.ori_shape = (10, 2, 6)
        self.new_shape = (10, 0, 3, -1)
        self.infered_shape = (10, 2, 3, -1)
        self.shape = (10, 0, 3, -1)


# Situation 4: `shape` is a tensor (the "Shape" input) and no `shape` attr
# is set.
class TestReshapeOp_attr_OnlyShape(OpTest):
    """Check reshape2 driven solely by the "Shape" input tensor.

    `attrs` is left empty; `new_shape` is fed as the int32 "Shape" input and
    `infered_shape` builds the expected output.
    """

    def setUp(self):
        self.init_data()
        self.op_type = "reshape2"

        self.inputs = {
            "X": np.random.random(self.ori_shape).astype("float32"),
            "Shape": np.array(
                self.new_shape, dtype="int32")
        }
        self.attrs = {}
        self.outputs = {
            "Out": self.inputs["X"].reshape(self.infered_shape),
            # Placeholder only: XShape is excluded from the output check.
            'XShape': np.random.random(self.ori_shape).astype("float32")
        }

    def init_data(self):
        self.ori_shape = (4, 25)
        self.new_shape = (10, 10)
        self.infered_shape = (10, 10)

    def test_check_output(self):
        self.check_output(no_check_set=['XShape'])

    def test_check_grad(self):
        self.check_grad(["X"], "Out")


class TestReshapeOpDimInfer1_attr_OnlyShape(TestReshapeOp_attr_OnlyShape):
    """"Shape"-input variant with a -1 element to be inferred.

    No `self.shape` is set: the base `setUp` passes an empty `attrs` dict and
    never reads it.
    """

    def init_data(self):
        self.ori_shape = (5, 20)
        self.new_shape = (5, -1, 10)
        self.infered_shape = (5, -1, 10)


class TestReshapeOpDimInfer2_attr_OnlyShape(TestReshapeOp_attr_OnlyShape):
    """"Shape"-input variant mixing 0 (copy input dim) and -1 (infer).

    No `self.shape` is set: the base `setUp` passes an empty `attrs` dict and
    never reads it.
    """

    def init_data(self):
        self.ori_shape = (10, 2, 6)
        self.new_shape = (10, 0, 3, -1)
        self.infered_shape = (10, 2, 3, -1)


# Test the int8 data type on CPU.
class TestReshapeInt8Op(OpTest):
    """Check reshape2 on int8 input with `use_mkldnn` set, on CPU only.

    Subclasses override `init_dtype` to run the same case with another
    integer dtype.
    """

    def setUp(self):
        self.init_dtype()
        self.init_data()
        self.use_mkldnn = True
        self._cpu_only = True
        self.op_type = "reshape2"
        # randint's upper bound is exclusive: values lie in [0, 126], which
        # fits both int8 and uint8.
        x = np.random.randint(0, 127, self.ori_shape).astype(self.dtype)
        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
        self.attrs = {
            'shape': self.new_shape,
            'use_mkldnn': self.use_mkldnn,
        }
        self.outputs = {
            "Out": self.inputs["X"].reshape(self.infered_shape),
            # Placeholder only: XShape is excluded from the output check.
            'XShape': np.random.random(self.ori_shape).astype(np.float32)
        }

    def init_dtype(self):
        self.dtype = np.int8

    def init_data(self):
        self.ori_shape = (10, 2, 6)
        self.new_shape = (10, 0, 3, -1)
        self.infered_shape = (10, 2, 3, -1)

    def test_check_output(self):
        self.check_output_with_place(
            fluid.core.CPUPlace(), atol=1e-5, no_check_set=['XShape'])

    def test_check_grad(self):
        # Gradient check is overridden as a no-op for this integer-input test.
        pass


# Test the uint8 data type on CPU.
class TestReshapeUint8Op(TestReshapeInt8Op):
    """Same case as TestReshapeInt8Op, with uint8 input."""

    def init_dtype(self):
        self.dtype = np.uint8


# Test the Python API.
class TestReshapeAPI(unittest.TestCase):
    """Exercise reshape through the paddle.* and fluid.* APIs.

    `_test_api` builds a static-graph program with whichever API set the
    `_set_*_api` helper installed; `test_imperative` runs the paddle.* API
    under `fluid.dygraph.guard()`.
    """

    def _set_paddle_api(self):
        self.fill_constant = paddle.fill_constant
        self.data = paddle.data
        self.reshape = paddle.reshape
        self.to_tensor = paddle.to_tensor

    def _set_fluid_api(self):
        self.fill_constant = fluid.layers.fill_constant
        self.data = fluid.data
        self.reshape = fluid.layers.reshape

    def _test_api(self):
        input_np = np.random.random([2, 25]).astype("float32")
        shape = [2, 5, 5]
        main_prog = Program()
        with program_guard(main_prog, Program()):
            positive_five = self.fill_constant([1], "int32", 5)
            x = self.data(name="x", shape=[2, 25], dtype="float32")

            actual_shape = self.data(name="shape", shape=[3], dtype="int32")

            # Situation 1: shape is a plain list, no actual-shape tensor.
            out_1 = self.reshape(x, shape)

            # Situation 2: shape is a plain list, plus an actual-shape tensor.
            # Always goes through fluid.layers.reshape, which is the API whose
            # `actual_shape` argument is used here.
            out_2 = fluid.layers.reshape(
                x, shape=shape, actual_shape=actual_shape)

            # Situation 3: shape is a list containing a tensor element.
            out_3 = self.reshape(x, shape=[positive_five, 10])

            # Situation 4: shape is itself a tensor.
            out_4 = self.reshape(x, shape=actual_shape)

        exe = fluid.Executor(place=fluid.CPUPlace())
        res_1, res_2, res_3, res_4 = exe.run(
            main_prog,
            feed={"x": input_np,
                  "shape": np.array(shape).astype("int32")},
            fetch_list=[out_1, out_2, out_3, out_4])

        # np.testing asserts survive `python -O` and report mismatches,
        # unlike a bare `assert np.array_equal(...)`.
        np.testing.assert_array_equal(res_1, input_np.reshape(shape))
        np.testing.assert_array_equal(res_2, input_np.reshape(shape))
        np.testing.assert_array_equal(res_3, input_np.reshape([5, 10]))
        np.testing.assert_array_equal(res_4, input_np.reshape(shape))

    def test_paddle_api(self):
        self._set_paddle_api()
        self._test_api()

    def test_fluid_api(self):
        self._set_fluid_api()
        self._test_api()

    def test_imperative(self):
        self._set_paddle_api()
        input_np = np.random.random([2, 25]).astype("float32")
        shape = [2, 5, 5]
        with fluid.dygraph.guard():
            x = self.to_tensor(input_np)
            positive_five = self.fill_constant([1], "int32", 5)

            out_1 = self.reshape(x, shape)

            out_2 = self.reshape(x, shape=[positive_five, 10])

            shape_tensor = self.to_tensor(np.array([2, 5, 5]).astype("int32"))
            out_3 = self.reshape(x, shape=shape_tensor)

        np.testing.assert_array_equal(out_1.numpy(), input_np.reshape(shape))
        np.testing.assert_array_equal(out_2.numpy(),
                                      input_np.reshape([5, 10]))
        np.testing.assert_array_equal(out_3.numpy(), input_np.reshape(shape))


# Test input error handling.
class TestReshapeOpError(unittest.TestCase):
    """Check that invalid reshape arguments raise TypeError/AssertionError.

    The same checks run against the paddle.* and fluid.* APIs, selected by
    the `_set_*_api` helpers.
    """

    def _set_paddle_api(self):
        self.data = paddle.data
        self.reshape = paddle.reshape

    def _set_fluid_api(self):
        self.data = fluid.data
        self.reshape = fluid.layers.reshape

    def _test_errors(self):
        with program_guard(Program(), Program()):
            # The x type of reshape_op must be Variable.
            def test_x_type():
                x1 = fluid.create_lod_tensor(
                    np.array([[-1]]), [[1]], fluid.CPUPlace())
                self.reshape(x1, shape=[1])

            self.assertRaises(TypeError, test_x_type)

            # The x dtype of reshape_op must be float16, float32, float64,
            # int32 or int64.
            def test_x_dtype():
                x2 = self.data(name="x2", shape=[2, 25], dtype="bool")
                self.reshape(x2, shape=[2, 5, 5])

            self.assertRaises(TypeError, test_x_dtype)

            # float16 is an allowed dtype, so this call must not raise.
            def test_x_dtype_float16():
                x_float16 = self.data(
                    name="x_float16", shape=[2, 25], dtype="float16")
                self.reshape(x_float16, shape=[2, 5, 5])

            test_x_dtype_float16()

            x3 = self.data(name="x3", shape=[2, 25], dtype="float32")

            # The argument shape's type of reshape_op must be list, tuple or
            # Variable.
            def test_shape_type():
                self.reshape(x3, shape=1)

            self.assertRaises(TypeError, test_shape_type)

            # The argument actual_shape's type of reshape_op must be Variable
            # or None.
            def test_actual_shape_type():
                self.reshape(x3, shape=[25, 2], actual_shape=1)

            self.assertRaises(TypeError, test_actual_shape_type)

            # The argument shape has more than one -1.
            def test_shape_1():
                self.reshape(x3, shape=[-1, -1, 5])

            self.assertRaises(AssertionError, test_shape_1)

            # The argument shape has a 0 at an index beyond the input's rank
            # (x3 has 2 dimensions; the 0 is at index 3).
            def test_shape_2():
                self.reshape(x3, [2, 5, 5, 0])

            self.assertRaises(AssertionError, test_shape_2)

            # The argument shape has more than one negative value.
            def test_shape_3():
                self.reshape(x3, [-1, -2, 5])

            self.assertRaises(AssertionError, test_shape_3)

    def test_paddle_api_error(self):
        self._set_paddle_api()
        self._test_errors()

    def test_fluid_api_error(self):
        self._set_fluid_api()
        self._test_errors()


if __name__ == "__main__":
    unittest.main()