#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import unittest

import numpy as np
from simple_nets import batchnorm_fc_with_inputs, simple_fc_net_with_inputs

import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.fluid.framework as framework
import paddle.fluid.layers as layers
from paddle.fluid.backward import append_backward
from paddle.fluid.framework import Program, program_guard

np.random.seed(123)


class TestCondInputOutput(unittest.TestCase):
    def test_return_single_var(self):
        """
        pseudocode:

        if 0.23 < 0.1:
            return 2
        else:
            return -1
        """

        paddle.enable_static()

        def true_func():
            return layers.fill_constant(shape=[2, 3], dtype='int32', value=2)

        def false_func():
            return layers.fill_constant(shape=[3, 2], dtype='int32', value=-1)

        main_program = Program()
        startup_program = Program()
        with program_guard(main_program, startup_program):
            x = layers.fill_constant(shape=[1], dtype='float32', value=0.1)
            y = layers.fill_constant(shape=[1], dtype='float32', value=0.23)
            pred = paddle.less_than(y, x)
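            # pred evaluates 0.23 < 0.1, which is False, so false_func runs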
            out = paddle.static.nn.cond(pred, true_func, false_func)
            # out is one tensor

        place = (
            fluid.CUDAPlace(0)
            if core.is_compiled_with_cuda()
            else fluid.CPUPlace()
        )
        exe = fluid.Executor(place)
        (ret,) = exe.run(main_program, fetch_list=[out.name])
        np.testing.assert_allclose(
            np.asarray(ret), np.full((3, 2), -1, np.int32), rtol=1e-05
        )

    def test_return_var_tuple(self):
        """
        pseudocode:

        if True:
            return 1, True
        else:
            return 3, 2
        """

        paddle.enable_static()

        def true_func():
            return layers.fill_constant(
                shape=[1, 2], dtype='int32', value=1
            ), layers.fill_constant(shape=[2, 3], dtype='bool', value=True)

        def false_func():
            return layers.fill_constant(
                shape=[3, 4], dtype='float32', value=3
            ), layers.fill_constant(shape=[4, 5], dtype='int64', value=2)

        main_program = Program()
        startup_program = Program()
        with program_guard(main_program, startup_program):
            pred = layers.fill_constant(shape=[1], dtype='bool', value=True)
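            # pred is a constant True, so the outputs come from true_func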
            out = paddle.static.nn.cond(pred, true_func, false_func)
            # out is a tuple containing 2 tensors

        place = (
            fluid.CUDAPlace(0)
            if core.is_compiled_with_cuda()
            else fluid.CPUPlace()
        )
        exe = fluid.Executor(place)
        ret = exe.run(main_program, fetch_list=out)
        np.testing.assert_allclose(
            np.asarray(ret[0]), np.full((1, 2), 1, np.int32), rtol=1e-05
        )
        np.testing.assert_allclose(
            np.asarray(ret[1]), np.full((2, 3), True, bool), rtol=1e-05
        )

    def test_pass_and_modify_var(self):
        """
        pseudocode:
        for i in range(5):
            a = 7
            if i % 2 == 0:
                a = a * (i + 1)
            else:
                a = a - (i - 1)
        """

        paddle.enable_static()

        def true_func(a, i):
            a = a * (i + 1)
            return a

        def false_func(a, i):
            a = a - (i - 1)
            return a

        main_program = Program()
        startup_program = Program()
        with program_guard(main_program, startup_program):
            a = layers.fill_constant(shape=[3, 2, 1], dtype='int32', value=7)
            i = fluid.data(name="i", shape=[1], dtype='int32')
            pred = (i % 2) == 0
            a = paddle.static.nn.cond(
                pred, lambda: true_func(a, i), lambda: false_func(a, i)
            )
        place = (
            fluid.CUDAPlace(0)
            if core.is_compiled_with_cuda()
            else fluid.CPUPlace()
        )
        exe = fluid.Executor(place)
        for feed_i in range(5):
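            # even i: a * (i + 1) = 7 * (i + 1); odd i: a - (i - 1) = 8 - i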
            expected_a = 7 * (feed_i + 1) if feed_i % 2 == 0 else 8 - feed_i
            (ret,) = exe.run(
                main_program,
                feed={'i': np.full((1), feed_i, np.int32)},
                fetch_list=[a],
            )
            np.testing.assert_allclose(
                np.asarray(ret),
                np.full((3, 2, 1), expected_a, np.int32),
                rtol=1e-05,
            )

    def test_return_none(self):
        """
        pseudocode: test doing nothing in branches
        for i in range(5):
            if i % 2 == 0:
                pass
            else:
                pass
        """

        paddle.enable_static()

        def true_func():
            pass

        def false_func():
            return None

        main_program = Program()
        startup_program = Program()
        with program_guard(main_program, startup_program):
            i = fluid.data(name="i", shape=[1], dtype='int32')
            pred = (i % 2) == 0
            out1 = paddle.static.nn.cond(pred, true_func, false_func)
            out2 = paddle.static.nn.cond(pred, None, false_func)
            out3 = paddle.static.nn.cond(pred, true_func, None)
        place = (
            fluid.CUDAPlace(0)
            if core.is_compiled_with_cuda()
            else fluid.CPUPlace()
        )
        exe = fluid.Executor(place)
        for feed_i in range(5):
            # Test that a cond whose branches return None is still runnable
            exe.run(main_program, feed={'i': np.full((1), feed_i, np.int32)})
            self.assertIsNone(out1)
            self.assertIsNone(out2)
            self.assertIsNone(out3)

    def test_wrong_structure_exception(self):
        """
        test that branches returning different numbers of tensors cannot be merged into one output
        """

        paddle.enable_static()

        def func_return_none():
            return None

        def func_return_one_tensor():
            return layers.fill_constant(shape=[2, 7], dtype='int32', value=3)

        def func_return_two_tensors():
            return layers.fill_constant(
                shape=[3, 1], dtype='int32', value=7
            ), layers.fill_constant(shape=[3, 1], dtype='int32', value=8)

        main_program = Program()
        startup_program = Program()
        with program_guard(main_program, startup_program):
            i = fluid.data(name="i", shape=[1], dtype='int32')
            pred = (i % 2) == 0
            with self.assertRaises(TypeError):
                out = paddle.static.nn.cond(pred, i, func_return_one_tensor)

            with self.assertRaises(TypeError):
                out = paddle.static.nn.cond(
                    pred, func_return_one_tensor, np.asarray([3])
                )

            with self.assertRaises(Exception) as e:
                out = paddle.static.nn.cond(
                    pred, func_return_none, func_return_one_tensor
                )
            self.assertTrue(
                "Incompatible return values of true_fn and false_fn in cond"
                in str(e.exception)
            )

            with self.assertRaises(Exception) as e:
                out = paddle.static.nn.cond(
                    pred, func_return_two_tensors, func_return_none
                )
            self.assertTrue(
                "Incompatible return values of true_fn and false_fn in cond"
                in str(e.exception)
            )

            with self.assertRaises(Exception) as e:
                out = paddle.static.nn.cond(
                    pred, func_return_one_tensor, func_return_two_tensors
                )
            self.assertTrue(
                "true fn returns 1 vars, but false fn returns 2 vars, which is not equals"
                in str(e.exception)
            )

    def test_extremely_simple_net_with_op_in_condition(self):
        paddle.enable_static()
        main_program = fluid.Program()
        startup_program = fluid.Program()
        with fluid.program_guard(main_program, startup_program):
            a = fluid.layers.fill_constant(
                shape=[1], dtype='float32', value=1.23
            )
            a.stop_gradient = False
            b = fluid.layers.fill_constant(
                shape=[1], dtype='float32', value=1.25
            )
            b.stop_gradient = False
            out = paddle.static.nn.cond(a - b < -1.0, lambda: a, lambda: b)
        append_backward(out)
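        # a - b = -0.02, which is not < -1.0, so the false branch returns b:
        # out = b, hence d(out)/da = 0 and d(out)/db = 1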

        place = (
            fluid.CUDAPlace(0)
            if core.is_compiled_with_cuda()
            else fluid.CPUPlace()
        )
        exe = fluid.Executor(place)
        ret = exe.run(
            main_program, fetch_list=[out, b, a.grad_name, b.grad_name]
        )
        # Note: fill_constant may lose precision, so assertEqual only against
        # values that are exactly representable in floating point.
        self.assertEqual(ret[0][0], ret[1][0])
        self.assertEqual(ret[2][0], 0.0)
        self.assertEqual(ret[3][0], 1.0)


class TestCondNestedControlFlow(unittest.TestCase):
    def test_cond_inside_cond(self):
        """
        pseudocode:
        for i in range(10):
            a = 2 * i
            if i < 5:
                if i >= 3:
                    return a + a
                else:
                    return a - a
            else:
                if i < 8:
                    return a * a
                else:
                    return a / a
        """

        paddle.enable_static()

        def less_than_branch(i, a):
            return paddle.static.nn.cond(
                i >= 3.0,
                lambda: paddle.add(a, a),
                lambda: paddle.subtract(a, a),
            )

        def greater_equal_branch(i, a):
            return paddle.static.nn.cond(
                i < 8.0,
                lambda: paddle.multiply(a, a),
                lambda: paddle.divide(a, a),
            )

        main_program = Program()
        startup_program = Program()
        with program_guard(main_program, startup_program):
            i = fluid.data(name="i", shape=[1], dtype='float32')
            a = 2.0 * i
            out = paddle.static.nn.cond(
                i < 5.0,
                lambda: less_than_branch(i, a),
                lambda: greater_equal_branch(i, a),
            )
            mean = paddle.mean(out)
            append_backward(mean)

        place = (
            fluid.CUDAPlace(0)
            if core.is_compiled_with_cuda()
            else fluid.CPUPlace()
        )
        exe = fluid.Executor(place)
        for feed_i in range(0, 10):
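            # loss = mean(out) with a single-element out, so d(loss)/d(out) = 1;
            # d(a+a)/da = 2, d(a-a)/da = 0, d(a*a)/da = 2a, d(a/a)/da = 0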
            expected_a = 2.0 * feed_i
            if feed_i < 5:
                expected_ret = expected_a + expected_a if feed_i >= 3 else 0.0
                expected_a_grad = 2.0 if feed_i >= 3 else 0.0
            else:
                expected_ret = expected_a * expected_a if feed_i < 8 else 1.0
                expected_a_grad = 2.0 * expected_a if feed_i < 8 else 0.0
            ret = exe.run(
                main_program,
                feed={'i': np.full((1), feed_i, np.float32)},
                fetch_list=[out.name, a.grad_name],
            )
            self.assertEqual(ret[0][0], expected_ret)
            self.assertEqual(ret[1][0], expected_a_grad)

    def test_cond_op_in_condition(self):
        paddle.enable_static()
        main_program = fluid.Program()
        startup_program = fluid.Program()

        with fluid.program_guard(main_program, startup_program):
            a = fluid.layers.fill_constant(
                shape=[1], dtype='float32', value=1.23
            )
            a.stop_gradient = False
            b = fluid.layers.fill_constant(
                shape=[1], dtype='float32', value=1.24
            )
            b.stop_gradient = False
            out = paddle.static.nn.cond(
                a < b,
                lambda: paddle.static.nn.cond(
                    a - b < -1.0,
                    lambda: paddle.add(a, b),
                    lambda: paddle.multiply(a, b),
                ),
                lambda: paddle.static.nn.cond(
                    a == b,
                    lambda: paddle.subtract(a, b),
                    lambda: paddle.pow(a, b),
                ),
            )
            append_backward(out)
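            # a < b is True and a - b = -0.01 is not < -1.0, so out = a * b:
            # 1.23 * 1.24 = 1.5252, d(out)/da = b = 1.24, d(out)/db = a = 1.23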

        place = (
            fluid.CUDAPlace(0)
            if core.is_compiled_with_cuda()
            else fluid.CPUPlace()
        )
        exe = fluid.Executor(place)
        ret = exe.run(main_program, fetch_list=[out, a.grad_name, b.grad_name])
        # Note: fill_constant may lose precision, so we use assertAlmostEqual.
        self.assertAlmostEqual(ret[0][0], 1.5252)
        self.assertAlmostEqual(ret[1][0], 1.24)
        self.assertAlmostEqual(ret[2][0], 1.23)


class TestCondBackward(unittest.TestCase):
    def backward_value_helper(self, cond_func, use_cuda, use_parallel_exe):
        """
        Helper function that checks the computed backward gradients
        against the numerically estimated dy/dx
        """
        paddle.enable_static()
        main_program = Program()
        main_program.random_seed = 123
        startup_program = Program()
        startup_program.random_seed = 123
        with program_guard(main_program, startup_program):
            img = fluid.data(name='image', shape=[-1, 9], dtype='float32')
            img.stop_gradient = False
            label = fluid.data(name='label', shape=[-1, 1], dtype='int64')
            i = fluid.data(name="i", shape=[1], dtype='int32')
            loss = cond_func(i, img, label)
            append_backward(loss)
        place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
        exe = fluid.Executor(place)
        exe.run(startup_program)

        num_devices = 1
        if use_parallel_exe:
            os.environ['CPU_NUM'] = str(2)
            exe = fluid.ParallelExecutor(
                use_cuda=use_cuda,
                main_program=main_program,
                loss_name=loss.name,
            )
            num_devices = exe.device_count

        delta = 0.005
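        # Numerically estimate d(loss)/d(img) with forward differences:
        # perturb one input entry by delta at a time and measure the loss change.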
        for feed_i in range(0, 10):
            feed_img = np.random.random(size=[1, 9]).astype(np.float32)
            feed_label = np.random.randint(
                low=0, high=10, size=[1, 1], dtype=np.int64
            )
            if use_parallel_exe:
                img_grad, loss_value = exe.run(
                    feed={
                        'i': np.full((num_devices), feed_i, np.int32),
                        'image': np.repeat(feed_img, num_devices, axis=0),
                        'label': np.repeat(feed_label, num_devices, axis=0),
                    },
                    fetch_list=[img.grad_name, loss.name],
                )
            else:
                img_grad, loss_value = exe.run(
                    main_program,
                    feed={
                        'i': np.full((1), feed_i, np.int32),
                        'image': feed_img,
455
                        'label': feed_label,
456
                    },
457 458
                    fetch_list=[img.grad_name, loss.name],
                )

            numerical_grad = np.zeros(shape=[num_devices, 9], dtype=np.float32)
            feed_img_delta = np.copy(feed_img)
            for j in range(9):
                feed_img_delta[0][j] = feed_img[0][j] + delta
                if use_parallel_exe:
                    loss_delta = exe.run(
                        feed={
                            'i': np.full((num_devices), feed_i, np.int32),
                            'image': np.repeat(
                                feed_img_delta, num_devices, axis=0
                            ),
                            'label': np.repeat(feed_label, num_devices, axis=0),
                        },
                        fetch_list=[loss.name],
                    )
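                    # ParallelExecutor averages gradients across devices,
                    # hence the extra division by num_devices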
                    multi_device_grad = (
                        (loss_delta[0] - loss_value[0]) / delta / num_devices
                    )
                    for d in range(num_devices):
                        numerical_grad[d][j] = multi_device_grad[d]
                else:
                    loss_delta = exe.run(
                        main_program,
                        feed={
                            'i': np.full((1), feed_i, np.int32),
                            'image': feed_img_delta,
                            'label': feed_label,
                        },
                        fetch_list=[loss.name],
                    )
                    numerical_grad[0][j] = (
                        loss_delta[0] - loss_value[0]
                    ) / delta
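                # restore the perturbed entry before moving to the next one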
                feed_img_delta[0][j] = feed_img[0][j]
            np.testing.assert_allclose(
                img_grad, numerical_grad, rtol=0.05, atol=0.05
            )

    def add_optimizer_helper(self, cond_func, use_cuda, use_parallel_exe):
        """
        Test that the program is runnable after adding an optimizer
        """
        main_program = Program()
        startup_program = Program()
        with program_guard(main_program, startup_program):
            img = fluid.data(name='image', shape=[-1, 784], dtype='float32')
            label = fluid.data(name='label', shape=[-1, 1], dtype='int64')
            i = fluid.data(name="i", shape=[1], dtype='int32')
            loss = cond_func(i, img, label)
            optimizer = fluid.optimizer.SGD(learning_rate=0.1)
            optimizer.minimize(loss)

        place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
        exe = fluid.Executor(place)
        exe.run(startup_program)
        if use_parallel_exe:
            os.environ['CPU_NUM'] = str(2)
            exe = fluid.ParallelExecutor(
                use_cuda=use_cuda,
                main_program=main_program,
                loss_name=loss.name,
            )
            num_devices = exe.device_count

        for feed_i in range(0, 10):
            feed_img = np.random.random(size=[16, 784]).astype(np.float32)
            feed_label = np.random.randint(
                low=0, high=10, size=[16, 1], dtype=np.int64
            )
            if use_parallel_exe:
                exe.run(
                    feed={
                        'i': np.full((num_devices), feed_i, np.int32),
                        'image': np.repeat(feed_img, num_devices, axis=0),
                        'label': np.repeat(feed_label, num_devices, axis=0),
                    },
                    fetch_list=[loss.name],
                )
            else:
                exe.run(
                    main_program,
                    feed={
                        'i': np.full((1), feed_i, np.int32),
                        'image': feed_img,
                        'label': feed_label,
                    },
                    fetch_list=[loss],
                )

    def test_cond_backward(self):
        paddle.enable_static()

        def cond_func(i, img, label):
            predicate = (i % 2) == 0
            return paddle.static.nn.cond(
                predicate,
                lambda: simple_fc_net_with_inputs(img, label, class_num=10),
                lambda: batchnorm_fc_with_inputs(img, label, class_num=10),
            )

        for use_parallel_exe in [False, True]:
            if use_parallel_exe and os.name == "nt":
                print(
                    "Skip use_parallel_exe=True in Windows because of flaky test when using PE under old Windows machine"
                )
                continue

            self.backward_value_helper(
                cond_func, core.is_compiled_with_cuda(), use_parallel_exe
            )
            self.add_optimizer_helper(
                cond_func, core.is_compiled_with_cuda(), use_parallel_exe
            )

    def test_half_nested_cond_backward(self):
        paddle.enable_static()

        def branch(i, img, label):
            return paddle.static.nn.cond(
                (i % 2) == 0,
                lambda: simple_fc_net_with_inputs(img, label, class_num=10),
                lambda: batchnorm_fc_with_inputs(img, label, class_num=10),
            )

        def cond_func_simple_net_at_true(i, img, label):
            return paddle.static.nn.cond(
                i < 5, lambda: branch(i, img, label), lambda: paddle.mean(img)
            )

        def cond_func_simple_net_at_false(i, img, label):
            return paddle.static.nn.cond(
                i < 5, lambda: paddle.mean(img), lambda: branch(i, img, label)
            )

        for use_parallel_exe in [False, True]:
            if use_parallel_exe and os.name == "nt":
                print(
                    "Skip use_parallel_exe=True in Windows because of flaky test when using PE under old Windows machine"
                )
                continue

            self.backward_value_helper(
                cond_func_simple_net_at_true,
                core.is_compiled_with_cuda(),
                use_parallel_exe,
            )
            self.add_optimizer_helper(
                cond_func_simple_net_at_true,
                core.is_compiled_with_cuda(),
                use_parallel_exe,
            )
            self.backward_value_helper(
                cond_func_simple_net_at_false,
                core.is_compiled_with_cuda(),
                use_parallel_exe,
            )
            self.add_optimizer_helper(
                cond_func_simple_net_at_false,
                core.is_compiled_with_cuda(),
                use_parallel_exe,
            )

    def test_nested_cond_backward(self):
        paddle.enable_static()

        def branch(i, img, label, mod_two):
            if mod_two:
                predicate = (i % 2) == 0
            else:
                predicate = (i % 2) != 0
            return paddle.static.nn.cond(
                predicate,
                lambda: simple_fc_net_with_inputs(img, label, class_num=10),
                lambda: batchnorm_fc_with_inputs(img, label, class_num=10),
            )

        def cond_func(i, img, label):
            return paddle.static.nn.cond(
                i < 5,
                lambda: branch(i, img, label, True),
                lambda: branch(i, img, label, False),
            )

        for use_parallel_exe in [False, True]:
            if use_parallel_exe and os.name == "nt":
                print(
                    "Skip use_parallel_exe=True in Windows because of flaky test when using PE under old Windows machine"
                )
                continue
            self.backward_value_helper(
                cond_func, core.is_compiled_with_cuda(), use_parallel_exe
            )
            self.add_optimizer_helper(
                cond_func, core.is_compiled_with_cuda(), use_parallel_exe
            )


class TestCondWithError(unittest.TestCase):
    def test_input_type_error(self):
        paddle.enable_static()
        main_program = framework.Program()
        startup_program = framework.Program()
        with framework.program_guard(main_program, startup_program):
            pred = fluid.data(name='y', shape=[1], dtype='bool')

            def func():
                return pred

            with self.assertRaises(TypeError):
                paddle.static.nn.cond(None, func, func)

            with self.assertRaises(TypeError):
                paddle.static.nn.cond(pred, func, set())

            with self.assertRaises(TypeError):
                paddle.static.nn.cond(pred, set(), func)

            with self.assertRaises(TypeError):
                paddle.static.nn.cond(pred, func, func, set())


class TestCondWithDict(unittest.TestCase):
    def test_input_with_dict(self):
        paddle.enable_static()
        main_program = framework.Program()
        startup_program = framework.Program()
        with framework.program_guard(main_program, startup_program):

            def true_func():
                return {
                    '1': paddle.full(shape=[3, 2], dtype='int32', fill_value=1),
                    '2': paddle.full(
                        shape=[2, 3], dtype='bool', fill_value=True
                    ),
                }

            def false_func():
                return {
                    '1': paddle.full(
                        shape=[3, 4], dtype='float32', fill_value=3
                    ),
                    '2': paddle.full(shape=[4, 5], dtype='int64', fill_value=2),
                }

            x = paddle.full(shape=[1], dtype='float32', fill_value=0.1)
            y = paddle.full(shape=[1], dtype='float32', fill_value=0.23)
            pred = paddle.less_than(x=x, y=y, name=None)
            ret = paddle.static.nn.cond(pred, true_func, false_func)
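            # The branches return different static shapes, so dimensions that
            # differ between them are unified to -1 (dynamic):
            # '1': (3, 2) vs (3, 4) -> (3, -1); '2': (2, 3) vs (4, 5) -> (-1, -1)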
            self.assertEqual(
                ret['1'].shape,
                (3, -1),
                f"The shape is not correct, expects (3, -1) but gets {ret['1'].shape}.",
            )
            self.assertEqual(
                ret['2'].shape,
                (-1, -1),
                f"The shape is not correct, expects (-1, -1) but gets {ret['2'].shape}.",
            )


if __name__ == '__main__':
    unittest.main()