test_tensor_shape.py 17.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15 16
import unittest

17
import numpy as np
18

19
import paddle
20
from paddle import fluid
21 22 23 24


def dyfunc_tensor_shape_1(x):
    # Fixture: reshape `x` to its own shape, passing `x.shape` directly as the
    # `shape` keyword argument of `paddle.reshape`.
    x = fluid.dygraph.to_variable(x)
    res = paddle.reshape(x, shape=x.shape)
    return res


def dyfunc_tensor_shape_2(x):
    # Fixture: `x.shape` reaches `paddle.reshape` through two plain local
    # aliases (`shape`, then `shape2`) instead of being used directly.
    x = paddle.to_tensor(x)
    shape = x.shape
    shape2 = shape
    res = paddle.reshape(x, shape2)
    return res


def dyfunc_tensor_shape_3(x):
    # Fixture: the target shape is taken from a second value `y`, not from `x`.
    # Transform y.shape but run y.shape actually because y is not Tensor
    # NOTE(review): `y` is created with `paddle.ones` here, so the remark above
    # ("y is not Tensor") looks stale -- confirm and reword.
    x = fluid.dygraph.to_variable(x)
    y = paddle.ones([1, 5])
    res = paddle.reshape(x, shape=y.shape)
    return res


def dyfunc_tensor_shape_4(x):
    # Fixture: the reshape target is a tuple literal mixing a constant (-1),
    # an indexed shape entry (`x.shape[0]`) and `len(x.shape)`.
    x = fluid.dygraph.to_variable(x)
    res = paddle.reshape(x, shape=(-1, x.shape[0], len(x.shape)))
    return res


def dyfunc_tensor_shape_5(x):
    # Fixture: `x.shape[0]` is first stored in the local `s` and only then
    # used inside the reshape target.
    # `res = paddle.reshape(x, shape=(-1, s))` to
    # `res = paddle.reshape(x, shape=(-1,
    #           paddle.jit.dy2static.convert_var_shape(x)[0]))`
    x = fluid.dygraph.to_variable(x)
    s = x.shape[0]
    res = paddle.reshape(x, shape=(-1, s))
    return res


61 62 63 64 65 66
def dyfunc_tensor_shape_6(x):
    # Fixture: a full slice of the shape (`x.shape[0:]`) is stored in `s` and
    # passed as the whole reshape target.
    # `res = paddle.reshape(x, shape=s)` to
    # `res = paddle.reshape(x, shape=
    #           paddle.jit.dy2static.convert_var_shape(x)[0:])`
    x = fluid.dygraph.to_variable(x)
    s = x.shape[0:]
    res = paddle.reshape(x, shape=s)
    return res


71 72 73 74 75 76 77 78 79 80 81 82 83 84 85
def dyfunc_tuple_shape_1(x):
    # Fixture: tuple-unpack `x.shape` directly into `a, b` and reshape `x`
    # with the two dimensions swapped. The unpacking requires `x` to be 2-D.
    x = paddle.to_tensor(x)
    a, b = x.shape
    res = paddle.reshape(x, shape=(b, a))
    return res


def dyfunc_tuple_shape_2(x):
    # Fixture: same as the direct tuple-unpack case, except that `x.shape` is
    # first bound to the local `shape` and unpacked from there.
    x = paddle.to_tensor(x)
    shape = x.shape
    a, b = shape
    res = paddle.reshape(x, shape=(b, a))
    return res


86 87 88 89 90 91 92
def dyfunc_tuple_shape_3(x):
    # Fixture: tuple-unpack the result of the `paddle.shape(x)` API call
    # (rather than the `x.shape` attribute) and reshape with swapped dims.
    x = paddle.to_tensor(x)
    a, b = paddle.shape(x)
    res = paddle.reshape(x, shape=(b, a))
    return res


93 94 95 96 97 98
def dyfunc_paddle_shape_api(x):
    # Fixture: both dimensions are read through explicit `paddle.shape(x)`
    # calls, then `x` is reshaped with the two dimensions swapped.
    x = paddle.to_tensor(x)
    # paddle.shape will not be converted.
    a = paddle.shape(x)[0]
    b = paddle.shape(x)[1]
    res = paddle.reshape(x, shape=(b, a))
    return res


104 105
def dyfunc_with_if_1(x):
    # Fixture: shape entries are used as `if` conditions, once via a local
    # (`x_shape_0`) and once directly (`res.shape[0]`) in a nested `if`.
    x = fluid.dygraph.to_variable(x)
    res = paddle.reshape(x, [-1, 1])
    x_shape_0 = x.shape[0]
    if x_shape_0 < 1:
        # `res.shape[0]` is transformed into
        #   `paddle.jit.dy2static.convert_var_shape(res)[0]`
        if res.shape[0] > 1:
            res = paddle.tensor.fill_constant(
                value=2, shape=x.shape, dtype="int32"
            )
        else:
            res = paddle.tensor.fill_constant(
                value=3, shape=x.shape, dtype="int32"
            )
    return res


def dyfunc_with_if_2(x):
    # Fixture: the `if` condition depends on `len(x.shape)`; the `else` branch
    # fills a tensor of `x`'s shape with the constant 8.
    x = fluid.dygraph.to_variable(x)
    # `len(x.shape)` will not be transformed because x.shape is not used by Paddle api.
    if len(x.shape) < 1:
        res = x
    else:
        res = paddle.tensor.fill_constant(value=8, shape=x.shape, dtype="int32")

    return res


def dyfunc_with_for_1(x):
    # Fixture: `x.shape[0]` is used directly as the bound of a `for` loop that
    # increments a one-element int32 counter once per iteration.
    x = fluid.dygraph.to_variable(x)
    res = paddle.tensor.fill_constant(value=0, shape=[1], dtype="int32")
    # `x.shape[0]` is transformed into `paddle.jit.dy2static.convert_var_shape(x)[0]`
    for i in range(x.shape[0]):
        res += 1
    return res


def dyfunc_with_for_2(x):
    # Fixture: same counting loop as the direct `x.shape[0]` case, but the
    # bound is first stored in the local `x_shape_0`.
    x = fluid.dygraph.to_variable(x)
    x_shape_0 = x.shape[0]
    res = paddle.tensor.fill_constant(value=0, shape=[1], dtype="int32")

    # `x_shape_0` is transformed into `paddle.jit.dy2static.convert_var_shape(x)[0]`
    for i in range(x_shape_0):
        res += 1
    return res


def dyfunc_with_for_3(x):
    # Fixture: the `for` loop bound is `len(x.shape)` (the number of
    # dimensions) rather than a shape entry.
    x = fluid.dygraph.to_variable(x)
    res = paddle.tensor.fill_constant(value=0, shape=[1], dtype="int32")
    # `len(x.shape)` is not transformed.
    for i in range(len(x.shape)):
        res += 1

    return res


def dyfunc_with_while_1(x):
    # Fixture: `x.shape[0]` appears directly in a `while` condition; the loop
    # variable starts at 1 and advances by 2.
    x = fluid.dygraph.to_variable(x)
    res = paddle.tensor.fill_constant(value=0, shape=[1], dtype="int32")
    # `x.shape[0]` is transformed into `paddle.jit.dy2static.convert_var_shape(x)[0]`
    i = 1
    while i < x.shape[0]:
        res += 1
        i = i + 2
    return res


def dyfunc_with_while_2(x):
    # Fixture: same `while` loop as the direct `x.shape[0]` case, but the
    # bound is first stored in the local `x_shape_0`.
    x = fluid.dygraph.to_variable(x)
    x_shape_0 = x.shape[0]
    res = paddle.tensor.fill_constant(value=0, shape=[1], dtype="int32")
    i = 1
    # `x_shape_0` is transformed into `paddle.jit.dy2static.convert_var_shape(x)[0]`
    while i < x_shape_0:
        res += 1
        i = i + 2
    return res
184 185


186 187 188
def dyfunc_with_while_3(x):
    # Fixture: the whole shape is stored in `x_shape` and only its length is
    # used in the `while` condition.
    x = fluid.dygraph.to_variable(x)
    x_shape = x.shape
    res = paddle.tensor.fill_constant(value=0, shape=[1], dtype="int32")
    i = 1

    # `len(x.shape)` is not transformed.
    while len(x_shape) > i:
        res += 1
        i += 1
    return res


199
def dyfunc_with_while_4(x):
    # Fixture: the `while` bound comes from the shape of a second value `y`
    # (created as `[1, 5]` ones), and the loop body mutates `x` itself.
    x = paddle.to_tensor(x)
    y = paddle.ones([1, 5])
    y_shape_0 = y.shape[0]
    i = 1

    # Transform y_shape_0 but run y.shape[0] actually because y is not Tensor
    # NOTE(review): `y` is created with `paddle.ones` here, so the remark above
    # ("y is not Tensor") looks stale -- confirm and reword.
    while y_shape_0 > i:
        x += 1
        i += 1
    return x


212 213 214 215 216 217 218 219
def dyfunc_change_shape_after_assign(x):
    # Fixture: `a, b` are read from `x.shape` BEFORE `x` is rebound to a
    # reshaped tensor, then used afterwards; the final reshape must therefore
    # use the dimensions of the original `x`, not of the rebound one.
    x = paddle.to_tensor(x)
    a, b = x.shape
    x = paddle.reshape(x, shape=(-1, 1))
    res = paddle.reshape(x, shape=(b, a))
    return res


220 221 222 223 224 225
def dyfunc_len_paddle_shape():
    # Fixture: applies `len()` to the result of `paddle.shape(x)` inside an
    # `if` condition and prints `x` when the length is positive.
    # Returns nothing; no test in the visible part of this file references it.
    x = paddle.to_tensor([1, 2, 3])
    if len(paddle.shape(x)) > 0:
        print(x)


226 227 228 229 230 231
def dyfunc_dict_assign_shape():
    # Fixture: assigns a shape entry to a dict item (a subscript target)
    # instead of a plain local name.
    # Returns nothing; no test in the visible part of this file references it.
    x = paddle.to_tensor([1, 2])
    a = {}
    a['shape'] = x.shape[0]


232 233
# 1. Basic tests without control flow
class TestTensorShapeBasic(unittest.TestCase):
    """Base case for the tensor-shape dygraph-to-static tests.

    Two checks run for the function stored in ``self.dygraph_func``:

    * ``test_transformed_static_result`` compares the dygraph output with the
      output of the ``paddle.jit.to_static`` version of the same function.
    * ``test_op_num`` builds the static program from ``self.input_spec`` and
      compares its total, ``shape`` and ``slice`` op counts with the expected
      values.

    Subclasses customise the case by overriding ``init_test_func``,
    ``_set_input_spec`` and/or ``_set_expected_op_num``.
    """

    def setUp(self):
        self.input = np.ones(5).astype("int32")
        self.place = (
            fluid.CUDAPlace(0)
            if fluid.is_compiled_with_cuda()
            else fluid.CPUPlace()
        )
        self._set_input_spec()
        self._set_expected_op_num()
        # Called last so that subclasses may overwrite `input` and
        # `input_spec` from inside `init_test_func`.
        self.init_test_func()

    def init_test_func(self):
        """Select the dygraph function under test."""
        self.dygraph_func = dyfunc_tensor_shape_1

    def _set_input_spec(self):
        """Set the InputSpec list used to build the static program."""
        self.input_spec = [paddle.static.InputSpec(shape=[5], dtype="int32")]

    def _run(self, to_static):
        """Run the function on ``self.input`` and return a numpy result.

        Args:
            to_static: when true, wrap the function with
                ``paddle.jit.to_static`` before calling it.
        """
        with fluid.dygraph.guard():
            if to_static:
                res = paddle.jit.to_static(self.dygraph_func)(
                    self.input
                ).numpy()
            else:
                res = self.dygraph_func(self.input).numpy()
            return res

    def get_dygraph_output(self):
        return self._run(to_static=False)

    def get_static_output(self):
        return self._run(to_static=True)

    def test_transformed_static_result(self):
        static_res = self.get_static_output()
        dygraph_res = self.get_dygraph_output()
        np.testing.assert_allclose(dygraph_res, static_res, rtol=1e-05)

    def _set_expected_op_num(self):
        """Set the op counts that ``test_op_num`` expects."""
        self.expected_op_num = 2
        self.expected_shape_op_num = 0
        self.expected_slice_op_num = 0

    def _compute_op_num(self, program):
        """Count all ops, ``shape`` ops and ``slice`` ops over every block.

        Results are stored in ``self.op_num``, ``self.shape_op_num`` and
        ``self.slice_op_num``.
        """
        ops = [op for block in program.blocks for op in block.ops]
        self.op_num = len(ops)
        self.shape_op_num = sum(op.type == "shape" for op in ops)
        self.slice_op_num = sum(op.type == "slice" for op in ops)

    def test_op_num(self):
        static_layer = paddle.jit.to_static(self.dygraph_func, self.input_spec)
        program = static_layer.main_program
        self._compute_op_num(program)
        self.assertEqual(self.op_num, self.expected_op_num)
        self.assertEqual(self.shape_op_num, self.expected_shape_op_num)
        self.assertEqual(self.slice_op_num, self.expected_slice_op_num)

298 299 300 301 302

class TestTensorShapeBasic2(TestTensorShapeBasic):
    """Base checks applied to ``dyfunc_tensor_shape_2``."""

    def init_test_func(self):
        self.dygraph_func = dyfunc_tensor_shape_2

    def _set_expected_op_num(self):
        self.expected_op_num = 1
        self.expected_shape_op_num = 0
        self.expected_slice_op_num = 0

308 309 310 311 312

class TestTensorShapeBasic3(TestTensorShapeBasic):
    """Base checks applied to ``dyfunc_tensor_shape_3``."""

    def init_test_func(self):
        self.dygraph_func = dyfunc_tensor_shape_3

    def _set_expected_op_num(self):
        self.expected_op_num = 3
        self.expected_shape_op_num = 0
        self.expected_slice_op_num = 0

318 319 320 321 322 323 324 325 326 327

class TestTensorShapeBasic4(TestTensorShapeBasic):
    """Base checks applied to ``dyfunc_tensor_shape_4``.

    Expected op counts are inherited unchanged from the base class.
    """

    def init_test_func(self):
        self.dygraph_func = dyfunc_tensor_shape_4


class TestTensorShapeBasic5(TestTensorShapeBasic):
    """Base checks applied to ``dyfunc_tensor_shape_5``."""

    def init_test_func(self):
        self.dygraph_func = dyfunc_tensor_shape_5

    def _set_expected_op_num(self):
        self.expected_op_num = 2
        self.expected_shape_op_num = 0
        self.expected_slice_op_num = 0
332

333

334 335 336 337
class TestTensorShapeBasic6(TestTensorShapeBasic):
    """Base checks applied to ``dyfunc_tensor_shape_6``."""

    def init_test_func(self):
        self.dygraph_func = dyfunc_tensor_shape_6

    def _set_expected_op_num(self):
        self.expected_op_num = 2
        self.expected_shape_op_num = 0
        self.expected_slice_op_num = 0
342

343

344 345
class TestTupleShape1(TestTensorShapeBasic):
    """``dyfunc_tuple_shape_1`` on a (5, 7) input, spec ``[-1, -1]``."""

    def init_test_func(self):
        # Overrides the 1-D defaults set earlier in the base `setUp`.
        self.input = np.ones((5, 7)).astype("int32")
        self.input_spec = [
            paddle.static.InputSpec(shape=[-1, -1], dtype="int32")
        ]
        self.dygraph_func = dyfunc_tuple_shape_1

    def _set_expected_op_num(self):
        self.expected_op_num = 4
        self.expected_shape_op_num = 1
        self.expected_slice_op_num = 2

357 358 359

class TestTupleShape2(TestTensorShapeBasic):
    """``dyfunc_tuple_shape_2`` on a (5, 7) input, spec ``[-1, 7]``."""

    def init_test_func(self):
        # Overrides the 1-D defaults set earlier in the base `setUp`.
        self.input = np.ones((5, 7)).astype("int32")
        self.input_spec = [
            paddle.static.InputSpec(shape=[-1, 7], dtype="int32")
        ]
        self.dygraph_func = dyfunc_tuple_shape_2

    def _set_expected_op_num(self):
        self.expected_op_num = 4
        self.expected_shape_op_num = 1
        self.expected_slice_op_num = 1
370 371 372 373


class TestTupleShape3(TestTensorShapeBasic):
    """``dyfunc_tuple_shape_3`` on a (5, 7) input, spec ``[5, 7]``."""

    def init_test_func(self):
        # Overrides the 1-D defaults set earlier in the base `setUp`.
        self.input = np.ones((5, 7)).astype("int32")
        self.input_spec = [paddle.static.InputSpec(shape=[5, 7], dtype="int32")]
        self.dygraph_func = dyfunc_tuple_shape_3

    def _set_expected_op_num(self):
        self.expected_op_num = 4
        self.expected_shape_op_num = 1
        self.expected_slice_op_num = 2

383

384 385
class TestPaddleShapeApi(TestTensorShapeBasic):
    """``dyfunc_paddle_shape_api`` on a (5, 7) input, spec ``[5, 7]``."""

    def init_test_func(self):
        # Overrides the 1-D defaults set earlier in the base `setUp`.
        self.input = np.ones((5, 7)).astype("int32")
        self.input_spec = [paddle.static.InputSpec(shape=[5, 7], dtype="int32")]
        self.dygraph_func = dyfunc_paddle_shape_api

    def _set_expected_op_num(self):
        self.expected_op_num = 5
        self.expected_shape_op_num = 2
        self.expected_slice_op_num = 2


396 397 398 399 400
# 2. Tests with control flow if
class TestTensorShapeInIf1(TestTensorShapeBasic):
    """Base checks applied to ``dyfunc_with_if_1``."""

    def init_test_func(self):
        self.dygraph_func = dyfunc_with_if_1

    def _set_expected_op_num(self):
        self.expected_op_num = 2
        self.expected_shape_op_num = 0
        self.expected_slice_op_num = 0
405

406 407 408 409 410

class TestTensorShapeInIf2(TestTensorShapeBasic):
    """Base checks applied to ``dyfunc_with_if_2``."""

    def init_test_func(self):
        self.dygraph_func = dyfunc_with_if_2

    def _set_expected_op_num(self):
        self.expected_op_num = 2
        self.expected_shape_op_num = 0
        self.expected_slice_op_num = 0
415

416 417 418 419 420 421

# 3. Tests with control flow for loop
class TestTensorShapeInFor1(TestTensorShapeBasic):
    """Base checks applied to ``dyfunc_with_for_1``."""

    def init_test_func(self):
        self.dygraph_func = dyfunc_with_for_1

    def _set_expected_op_num(self):
        self.expected_op_num = 7
        self.expected_shape_op_num = 0
        self.expected_slice_op_num = 0
426

427

428
class TestTensorShapeInFor2(TestTensorShapeInFor1):
    """Base checks applied to ``dyfunc_with_for_2``."""

    def init_test_func(self):
        self.dygraph_func = dyfunc_with_for_2

    def _set_expected_op_num(self):
        self.expected_op_num = 7
        self.expected_shape_op_num = 0
        self.expected_slice_op_num = 0
436

437

438 439 440 441 442
class TestTensorShapeInFor3(TestTensorShapeInFor1):
    """Base checks applied to ``dyfunc_with_for_3``."""

    def init_test_func(self):
        self.dygraph_func = dyfunc_with_for_3

    def _set_expected_op_num(self):
        self.expected_op_num = 3
        self.expected_shape_op_num = 0
        self.expected_slice_op_num = 0
446 447


448
# 4. Tests with control flow while loop
449
class TestTensorShapeInWhile1(TestTensorShapeInFor1):
    """Base checks applied to ``dyfunc_with_while_1``."""

    def init_test_func(self):
        self.dygraph_func = dyfunc_with_while_1

    def _set_expected_op_num(self):
        self.expected_op_num = 4
        self.expected_shape_op_num = 0
        self.expected_slice_op_num = 0

458

459
class TestTensorShapeInWhile2(TestTensorShapeInFor1):
    """Base checks applied to ``dyfunc_with_while_2``."""

    def init_test_func(self):
        self.dygraph_func = dyfunc_with_while_2

    def _set_expected_op_num(self):
        self.expected_op_num = 4
        self.expected_shape_op_num = 0
        self.expected_slice_op_num = 0
467

468

469 470 471 472
class TestTensorShapeInWhile3(TestTensorShapeBasic):
    """Base checks applied to ``dyfunc_with_while_3``."""

    def init_test_func(self):
        self.dygraph_func = dyfunc_with_while_3

    def _set_expected_op_num(self):
        self.expected_op_num = 2
        self.expected_shape_op_num = 0
        self.expected_slice_op_num = 0
477

478 479 480 481 482

class TestTensorShapeInWhile4(TestTensorShapeBasic):
    """Base checks applied to ``dyfunc_with_while_4``."""

    def init_test_func(self):
        self.dygraph_func = dyfunc_with_while_4

    def _set_expected_op_num(self):
        self.expected_op_num = 1
        self.expected_shape_op_num = 0
        self.expected_slice_op_num = 0


# 5. Test op num for negative dim
class TestOpNumBasicWithTensorShape(unittest.TestCase):
    """Op-count check with a ``-1`` leading dimension in the InputSpec.

    Builds the static program for ``self.dygraph_func`` from an InputSpec of
    shape ``[-1, 5]`` and compares its total, ``shape`` and ``slice`` op
    counts with the expected values. Subclasses override ``_set_test_func``
    and ``_set_expected_op_num``.
    """

    def setUp(self):
        self._set_input_spec()
        self._set_test_func()
        self._set_expected_op_num()

    def _set_input_spec(self):
        """Set the InputSpec list used to build the static program."""
        self.input_spec = [
            paddle.static.InputSpec(shape=[-1, 5], dtype="int32")
        ]

    def _set_test_func(self):
        """Select the dygraph function under test."""
        self.dygraph_func = dyfunc_tensor_shape_1

    def _set_expected_op_num(self):
        """Set the op counts that ``test_op_num`` expects."""
        self.expected_op_num = 5
        self.expected_shape_op_num = 1
        self.expected_slice_op_num = 1

    def _compute_op_num(self, program):
        """Count all ops, ``shape`` ops and ``slice`` ops over every block.

        Results are stored in ``self.op_num``, ``self.shape_op_num`` and
        ``self.slice_op_num``.
        """
        ops = [op for block in program.blocks for op in block.ops]
        self.op_num = len(ops)
        self.shape_op_num = sum(op.type == "shape" for op in ops)
        self.slice_op_num = sum(op.type == "slice" for op in ops)

    def test_op_num(self):
        static_layer = paddle.jit.to_static(self.dygraph_func, self.input_spec)
        program = static_layer.main_program

        self._compute_op_num(program)
        self.assertEqual(self.op_num, self.expected_op_num)
        self.assertEqual(self.shape_op_num, self.expected_shape_op_num)
        self.assertEqual(self.slice_op_num, self.expected_slice_op_num)


class TestOpNumBasicWithTensorShape4(TestOpNumBasicWithTensorShape):
    """Op-count check for ``dyfunc_tensor_shape_4`` with spec ``[-1, 5]``."""

    def _set_test_func(self):
        self.dygraph_func = dyfunc_tensor_shape_4

    def _set_expected_op_num(self):
        self.expected_op_num = 8
        self.expected_shape_op_num = 2
        self.expected_slice_op_num = 2
540 541 542 543 544 545 546


class TestOpNumWithTensorShapeTuple1(TestOpNumBasicWithTensorShape):
    """Op-count check for ``dyfunc_tuple_shape_1`` with spec ``[-1, 5]``."""

    def _set_test_func(self):
        self.dygraph_func = dyfunc_tuple_shape_1

    def _set_expected_op_num(self):
        self.expected_op_num = 4
        self.expected_shape_op_num = 1
        self.expected_slice_op_num = 1
550 551 552 553 554 555 556


class TestOpNumWithTensorShapeInIf1(TestOpNumBasicWithTensorShape):
    """Op-count check for ``dyfunc_with_if_1`` with spec ``[-1, 5]``."""

    def _set_test_func(self):
        self.dygraph_func = dyfunc_with_if_1

    def _set_expected_op_num(self):
        self.expected_op_num = 32
        self.expected_shape_op_num = 4
        self.expected_slice_op_num = 4
560 561 562 563 564 565 566


class TestOpNumWithTensorShapeInFor1(TestOpNumBasicWithTensorShape):
    """Op-count check for ``dyfunc_with_for_1`` with spec ``[-1, 5]``."""

    def _set_test_func(self):
        self.dygraph_func = dyfunc_with_for_1

    def _set_expected_op_num(self):
        self.expected_op_num = 29
        self.expected_shape_op_num = 2
        self.expected_slice_op_num = 3


class TestOpNumWithTensorShapeInWhile1(TestOpNumBasicWithTensorShape):
    """Op-count check for ``dyfunc_with_while_1`` with spec ``[-1, 5]``."""

    def _set_test_func(self):
        self.dygraph_func = dyfunc_with_while_1

    def _set_expected_op_num(self):
        self.expected_op_num = 21
        self.expected_shape_op_num = 3
        self.expected_slice_op_num = 3

581

582 583
class TestChangeShapeAfterAssign(TestTensorShapeBasic):
    """``dyfunc_change_shape_after_assign`` on a (2, 3) input, spec ``[-1, 3]``."""

    def init_test_func(self):
        # Overrides the 1-D defaults set earlier in the base `setUp`.
        self.input = np.ones((2, 3)).astype("int32")
        self.input_spec = [
            paddle.static.InputSpec(shape=[-1, 3], dtype="int32")
        ]
        self.dygraph_func = dyfunc_change_shape_after_assign

    def _set_expected_op_num(self):
        self.expected_op_num = 5
        self.expected_shape_op_num = 1
        self.expected_slice_op_num = 1
594 595


596 597 598 599 600 601 602 603
def dyfunc_with_static_convert_var_shape(x):
    # Fixture: `batch_size` is read from `x.shape[0]` at function scope and
    # only used later, inside the `else` branch of an `if` on `len(x.shape)`.
    # Unlike the other fixtures, `x` is used as passed in (no tensor
    # conversion call).
    # Note: this will create `batch_size__static_convert_var_shape_suffix_0` firstly.
    batch_size = x.shape[0]
    if len(x.shape) < 1:
        res = x
    else:
        # Test for correctly to find `batch_size__static_convert_var_shape_suffix_0` in
        # deeply nested scope.
        res = paddle.tensor.fill_constant(
            value=8, shape=[batch_size], dtype="int32"
        )

    return res


class TestFindStatiConvertVarShapeSuffixVar(unittest.TestCase):
    """Smoke test: program translation succeeds for a ``[None, 10]`` spec.

    Only translation is triggered; no output or op count is asserted.
    """

    def test(self):
        x_spec = paddle.static.InputSpec(shape=[None, 10])
        # `dyfunc_with_if_2` is what this test translated originally.
        # `dyfunc_with_static_convert_var_shape` is the fixture written for
        # this case and was previously never exercised, so translate both.
        for dygraph_func in (
            dyfunc_with_if_2,
            dyfunc_with_static_convert_var_shape,
        ):
            with self.subTest(func=dygraph_func.__name__):
                func = paddle.jit.to_static(dygraph_func, input_spec=[x_spec])
                # Call this function to trigger program translation.
                func.concrete_program


619 620
# Run every test case in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()