#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import os
import unittest
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.fluid.layers as layers
import paddle.fluid.framework as framework
from paddle.fluid.backward import append_backward
from paddle.fluid.framework import Program, program_guard
from simple_nets import simple_fc_net_with_inputs, batchnorm_fc_with_inputs

np.random.seed(123)


class TestCondInputOutput(unittest.TestCase):
    def test_return_single_var(self):
        """
        pseudocode:

        if 0.23 < 0.1:
            return 2
        else:
            return -1
        """

        paddle.enable_static()

        def true_func():
            return layers.fill_constant(shape=[2, 3], dtype='int32', value=2)

        def false_func():
            return layers.fill_constant(shape=[3, 2], dtype='int32', value=-1)

        main_program = Program()
        startup_program = Program()
        with program_guard(main_program, startup_program):
            x = layers.fill_constant(shape=[1], dtype='float32', value=0.1)
            y = layers.fill_constant(shape=[1], dtype='float32', value=0.23)
            pred = layers.less_than(y, x)
            out = layers.cond(pred, true_func, false_func)
            # out is one tensor

        place = (
            fluid.CUDAPlace(0)
            if core.is_compiled_with_cuda()
            else fluid.CPUPlace()
        )
        exe = fluid.Executor(place)
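        # pred evaluates to False (0.23 < 0.1 does not hold), so the false
        # branch runs and out is a [3, 2] int32 tensor filled with -1.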
        (ret,) = exe.run(main_program, fetch_list=[out.name])
        np.testing.assert_allclose(
            np.asarray(ret), np.full((3, 2), -1, np.int32), rtol=1e-05
        )

    def test_return_var_tuple(self):
        """
        pseudocode:

        if True:
            return 1, True
        else:
            return 3, 2
        """

        paddle.enable_static()

        def true_func():
            return layers.fill_constant(
                shape=[1, 2], dtype='int32', value=1
            ), layers.fill_constant(shape=[2, 3], dtype='bool', value=True)

        def false_func():
            return layers.fill_constant(
                shape=[3, 4], dtype='float32', value=3
            ), layers.fill_constant(shape=[4, 5], dtype='int64', value=2)

        main_program = Program()
        startup_program = Program()
        with program_guard(main_program, startup_program):
            pred = layers.fill_constant(shape=[1], dtype='bool', value=True)
            out = layers.cond(pred, true_func, false_func)
            # out is a tuple containing 2 tensors

        place = (
            fluid.CUDAPlace(0)
            if core.is_compiled_with_cuda()
            else fluid.CPUPlace()
        )
        exe = fluid.Executor(place)
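        # pred is a constant True, so the true branch runs and out is a tuple
        # of a [1, 2] int32 tensor of ones and a [2, 3] bool tensor of True.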
        ret = exe.run(main_program, fetch_list=out)
        np.testing.assert_allclose(
            np.asarray(ret[0]), np.full((1, 2), 1, np.int32), rtol=1e-05
        )
        np.testing.assert_allclose(
            np.asarray(ret[1]), np.full((2, 3), True, bool), rtol=1e-05
        )

    def test_pass_and_modify_var(self):
        """
        pseudocode:
        for i in range(5):
            a = 7
            if i % 2 == 0:
                a = a * (i + 1)
            else:
                a = a - (i - 1)
        """

        paddle.enable_static()

        def true_func(a, i):
            a = a * (i + 1)
            return a

        def false_func(a, i):
            a = a - (i - 1)
            return a

        main_program = Program()
        startup_program = Program()
        with program_guard(main_program, startup_program):
            a = layers.fill_constant(shape=[3, 2, 1], dtype='int32', value=7)
            i = fluid.data(name="i", shape=[1], dtype='int32')
            pred = (i % 2) == 0
            a = layers.cond(
                pred, lambda: true_func(a, i), lambda: false_func(a, i)
            )
        place = (
            fluid.CUDAPlace(0)
            if core.is_compiled_with_cuda()
            else fluid.CPUPlace()
        )
        exe = fluid.Executor(place)
        for feed_i in range(5):
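            # The true branch computes a * (i + 1) = 7 * (i + 1); the false
            # branch computes a - (i - 1) = 7 - (i - 1) = 8 - i.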
            expected_a = 7 * (feed_i + 1) if feed_i % 2 == 0 else 8 - feed_i
            (ret,) = exe.run(
                main_program,
                feed={'i': np.full((1), feed_i, np.int32)},
                fetch_list=[a],
            )
            np.testing.assert_allclose(
                np.asarray(ret),
                np.full((3, 2, 1), expected_a, np.int32),
                rtol=1e-05,
            )

    def test_return_none(self):
        """
        pseudocode: test doing nothing in branches
        for i in range(5):
            if i % 2 == 0:
                pass
            else:
                pass
        """

        paddle.enable_static()

        def true_func():
            pass

        def false_func():
            return None

        main_program = Program()
        startup_program = Program()
        with program_guard(main_program, startup_program):
            i = fluid.data(name="i", shape=[1], dtype='int32')
            pred = (i % 2) == 0
            out1 = layers.cond(pred, true_func, false_func)
            out2 = layers.cond(pred, None, false_func)
            out3 = layers.cond(pred, true_func, None)
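            # cond returns None because neither branch produces a tensor;
            # passing None for a branch means that branch does nothing.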
        place = (
            fluid.CUDAPlace(0)
            if core.is_compiled_with_cuda()
            else fluid.CPUPlace()
        )
        exe = fluid.Executor(place)
        for feed_i in range(5):
            # Test that a program whose cond outputs are None still runs
            exe.run(main_program, feed={'i': np.full((1), feed_i, np.int32)})
            self.assertIsNone(out1)
            self.assertIsNone(out2)
            self.assertIsNone(out3)

    def test_wrong_structure_exception(self):
        """
        test returning different number of tensors cannot merge into output
        """

        paddle.enable_static()

        def func_return_none():
            return None

        def func_return_one_tensor():
            return layers.fill_constant(shape=[2, 7], dtype='int32', value=3)

        def func_return_two_tensors():
            return layers.fill_constant(
                shape=[3, 1], dtype='int32', value=7
            ), layers.fill_constant(shape=[3, 1], dtype='int32', value=8)

        main_program = Program()
        startup_program = Program()
        with program_guard(main_program, startup_program):
            i = fluid.data(name="i", shape=[1], dtype='int32')
            pred = (i % 2) == 0
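            # true_fn and false_fn must be callable; passing a Variable or a
            # numpy array instead raises TypeError.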
            with self.assertRaises(TypeError):
                out = layers.cond(pred, i, func_return_one_tensor)

            with self.assertRaises(TypeError):
                out = layers.cond(pred, func_return_one_tensor, np.asarray([3]))

            with self.assertRaises(Exception) as e:
                out = layers.cond(
                    pred, func_return_none, func_return_one_tensor
                )
            self.assertTrue(
                "Incompatible return values of true_fn and false_fn in cond"
                in str(e.exception)
            )

            with self.assertRaises(Exception) as e:
                out = layers.cond(
                    pred, func_return_two_tensors, func_return_none
                )
            self.assertTrue(
                "Incompatible return values of true_fn and false_fn in cond"
                in str(e.exception)
            )

            with self.assertRaises(Exception) as e:
                out = layers.cond(
                    pred, func_return_one_tensor, func_return_two_tensors
                )
            self.assertTrue(
                "true fn returns 1 vars, but false fn returns 2 vars, which is not equals"
                in str(e.exception)
            )

    def test_extremely_simple_net_with_op_in_condition(self):
        paddle.enable_static()
        main_program = fluid.Program()
        startup_program = fluid.Program()
        with fluid.program_guard(main_program, startup_program):
            a = fluid.layers.fill_constant(
                shape=[1], dtype='float32', value=1.23
            )
            a.stop_gradient = False
            b = fluid.layers.fill_constant(
                shape=[1], dtype='float32', value=1.25
            )
            b.stop_gradient = False
            out = layers.cond(a - b < -1.0, lambda: a, lambda: b)
        append_backward(out)

        place = (
            fluid.CUDAPlace(0)
            if core.is_compiled_with_cuda()
            else fluid.CPUPlace()
        )
        exe = fluid.Executor(place)
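        # a - b = -0.02 is not less than -1.0, so the false branch runs and
        # out = b, hence d(out)/d(a) = 0 and d(out)/d(b) = 1.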
        ret = exe.run(
            main_program, fetch_list=[out, b, a.grad_name, b.grad_name]
        )
        # Note: fill_constant may lose precision, so only assertEqual against
        # values that are exactly representable in floating point.
        self.assertEqual(ret[0][0], ret[1][0])
        self.assertEqual(ret[2][0], 0.0)
        self.assertEqual(ret[3][0], 1.0)


class TestCondNestedControlFlow(unittest.TestCase):
    def test_cond_inside_cond(self):
        """
        pseudocode:
        for i in range(1, 10):
            a = 2 * i
            if i < 5:
                if i >= 3:
                    return a + a
                else:
                    return a - a
            else:
                if i < 8:
                    return a * a
                else:
                    return a / a
        """

        paddle.enable_static()

        def less_than_branch(i, a):
            return layers.cond(
                i >= 3.0,
                lambda: layers.elementwise_add(a, a),
                lambda: layers.elementwise_sub(a, a),
            )

        def greater_equal_branch(i, a):
            return layers.cond(
                i < 8.0,
                lambda: layers.elementwise_mul(a, a),
                lambda: layers.elementwise_div(a, a),
            )

        main_program = Program()
        startup_program = Program()
        with program_guard(main_program, startup_program):
            i = fluid.data(name="i", shape=[1], dtype='float32')
            a = 2.0 * i
            out = layers.cond(
                i < 5.0,
                lambda: less_than_branch(i, a),
                lambda: greater_equal_branch(i, a),
            )
            mean = paddle.mean(out)
            append_backward(mean)

        place = (
            fluid.CUDAPlace(0)
            if core.is_compiled_with_cuda()
            else fluid.CPUPlace()
        )
        exe = fluid.Executor(place)
        for feed_i in range(0, 10):
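            # Expected gradients of the taken branch w.r.t. a:
            # d(a + a)/da = 2, d(a - a)/da = 0, d(a * a)/da = 2a, d(a / a)/da = 0.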
            expected_a = 2.0 * feed_i
            if feed_i < 5:
                expected_ret = expected_a + expected_a if feed_i >= 3 else 0.0
                expected_a_grad = 2.0 if feed_i >= 3 else 0.0
            else:
                expected_ret = expected_a * expected_a if feed_i < 8 else 1.0
                expected_a_grad = 2.0 * expected_a if feed_i < 8 else 0.0
            ret = exe.run(
                main_program,
                feed={'i': np.full((1), feed_i, np.float32)},
                fetch_list=[out.name, a.grad_name],
            )
            self.assertEqual(ret[0][0], expected_ret)
            self.assertEqual(ret[1][0], expected_a_grad)

    def test_cond_op_in_condition(self):
        paddle.enable_static()
        main_program = fluid.Program()
        startup_program = fluid.Program()

        with fluid.program_guard(main_program, startup_program):
            a = fluid.layers.fill_constant(
                shape=[1], dtype='float32', value=1.23
            )
            a.stop_gradient = False
            b = fluid.layers.fill_constant(
                shape=[1], dtype='float32', value=1.24
            )
            b.stop_gradient = False
            out = fluid.layers.cond(
                a < b,
                lambda: fluid.layers.cond(
                    a - b < -1.0,
                    lambda: fluid.layers.elementwise_add(a, b),
                    lambda: fluid.layers.elementwise_mul(a, b),
                ),
                lambda: fluid.layers.cond(
                    a == b,
                    lambda: fluid.layers.elementwise_sub(a, b),
                    lambda: fluid.layers.elementwise_pow(a, b),
                ),
            )
            append_backward(out)

        place = (
            fluid.CUDAPlace(0)
            if core.is_compiled_with_cuda()
            else fluid.CPUPlace()
        )
        exe = fluid.Executor(place)
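        # a < b holds, and a - b = -0.01 is not less than -1.0, so the inner
        # cond takes elementwise_mul: out = a * b = 1.23 * 1.24 = 1.5252,
        # with d(out)/d(a) = b and d(out)/d(b) = a.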
        ret = exe.run(main_program, fetch_list=[out, a.grad_name, b.grad_name])
        # Note: fill_constant may lose precision, so we use assertAlmostEqual.
        self.assertAlmostEqual(ret[0][0], 1.5252)
        self.assertAlmostEqual(ret[1][0], 1.24)
        self.assertAlmostEqual(ret[2][0], 1.23)


class TestCondBackward(unittest.TestCase):
    def backward_value_helper(self, cond_func, use_cuda, use_parallel_exe):
        """
        Helper function that compares calculated backward value is close to dy/dx
        """
        paddle.enable_static()
        main_program = Program()
        main_program.random_seed = 123
        startup_program = Program()
        startup_program.random_seed = 123
        with program_guard(main_program, startup_program):
            img = fluid.data(name='image', shape=[-1, 9], dtype='float32')
            img.stop_gradient = False
            label = fluid.data(name='label', shape=[-1, 1], dtype='int64')
            i = fluid.data(name="i", shape=[1], dtype='int32')
            loss = cond_func(i, img, label)
            append_backward(loss)
        place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
        exe = fluid.Executor(place)
        exe.run(startup_program)

        num_devices = 1
        if use_parallel_exe:
            os.environ['CPU_NUM'] = str(2)
            exe = fluid.ParallelExecutor(
                use_cuda=use_cuda,
                main_program=main_program,
                loss_name=loss.name,
            )
            num_devices = exe.device_count

        delta = 0.005
        for feed_i in range(0, 10):
            feed_img = np.random.random(size=[1, 9]).astype(np.float32)
            feed_label = np.random.randint(
                low=0, high=10, size=[1, 1], dtype=np.int64
            )
            if use_parallel_exe:
                img_grad, loss_value = exe.run(
                    feed={
                        'i': np.full((num_devices), feed_i, np.int32),
                        'image': np.repeat(feed_img, num_devices, axis=0),
                        'label': np.repeat(feed_label, num_devices, axis=0),
                    },
                    fetch_list=[img.grad_name, loss.name],
                )
            else:
                img_grad, loss_value = exe.run(
                    main_program,
                    feed={
                        'i': np.full((1), feed_i, np.int32),
                        'image': feed_img,
                        'label': feed_label,
                    },
                    fetch_list=[img.grad_name, loss.name],
                )

            numerical_grad = np.zeros(shape=[num_devices, 9], dtype=np.float32)
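            # Estimate d(loss)/d(img[j]) with a one-sided finite difference:
            # (loss(img + delta * e_j) - loss(img)) / delta, one pixel at a time.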
            feed_img_delta = np.copy(feed_img)
            for j in range(9):
                feed_img_delta[0][j] = feed_img[0][j] + delta
                if use_parallel_exe:
                    loss_delta = exe.run(
                        feed={
                            'i': np.full((num_devices), feed_i, np.int32),
                            'image': np.repeat(
                                feed_img_delta, num_devices, axis=0
                            ),
                            'label': np.repeat(feed_label, num_devices, axis=0),
                        },
                        fetch_list=[loss.name],
                    )
                    multi_device_grad = (
                        (loss_delta[0] - loss_value[0]) / delta / num_devices
                    )
                    for d in range(num_devices):
                        numerical_grad[d][j] = multi_device_grad[d]
                else:
                    loss_delta = exe.run(
                        main_program,
                        feed={
                            'i': np.full((1), feed_i, np.int32),
                            'image': feed_img_delta,
                            'label': feed_label,
                        },
                        fetch_list=[loss.name],
                    )
                    numerical_grad[0][j] = (
                        loss_delta[0] - loss_value[0]
                    ) / delta
                feed_img_delta[0][j] = feed_img[0][j]
            np.testing.assert_allclose(
                img_grad, numerical_grad, rtol=0.05, atol=0.05
            )

    def add_optimizer_helper(self, cond_func, use_cuda, use_parallel_exe):
        """
        Test that program is runnable when add optimizer
        """
        main_program = Program()
        startup_program = Program()
        with program_guard(main_program, startup_program):
            img = fluid.data(name='image', shape=[-1, 784], dtype='float32')
            label = fluid.data(name='label', shape=[-1, 1], dtype='int64')
            i = fluid.data(name="i", shape=[1], dtype='int32')
            loss = cond_func(i, img, label)
            optimizer = fluid.optimizer.SGD(learning_rate=0.1)
            optimizer.minimize(loss)

        place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
        exe = fluid.Executor(place)
        exe.run(startup_program)
        if use_parallel_exe:
            os.environ['CPU_NUM'] = str(2)
            exe = fluid.ParallelExecutor(
                use_cuda=use_cuda,
                main_program=main_program,
                loss_name=loss.name,
            )
            num_devices = exe.device_count

        for feed_i in range(0, 10):
            feed_img = np.random.random(size=[16, 784]).astype(np.float32)
            feed_label = np.random.randint(
                low=0, high=10, size=[16, 1], dtype=np.int64
            )
            if use_parallel_exe:
                exe.run(
                    feed={
                        'i': np.full((num_devices), feed_i, np.int32),
                        'image': np.repeat(feed_img, num_devices, axis=0),
                        'label': np.repeat(feed_label, num_devices, axis=0),
                    },
                    fetch_list=[loss.name],
                )
            else:
                exe.run(
                    main_program,
                    feed={
                        'i': np.full((1), feed_i, np.int32),
                        'image': feed_img,
                        'label': feed_label,
                    },
                    fetch_list=[loss],
                )

    def test_cond_backward(self):
        paddle.enable_static()

        def cond_func(i, img, label):
            predicate = (i % 2) == 0
            return layers.cond(
                predicate,
                lambda: simple_fc_net_with_inputs(img, label, class_num=10),
                lambda: batchnorm_fc_with_inputs(img, label, class_num=10),
            )

        for use_parallel_exe in [False, True]:
            if use_parallel_exe and os.name == "nt":
                print(
                    "Skip use_parallel_exe=True in Windows because of flaky test when using PE under old Windows machine"
                )
                continue

            self.backward_value_helper(
                cond_func, core.is_compiled_with_cuda(), use_parallel_exe
            )
            self.add_optimizer_helper(
                cond_func, core.is_compiled_with_cuda(), use_parallel_exe
            )

    def test_half_nested_cond_backward(self):
        paddle.enable_static()

        def branch(i, img, label):
            return layers.cond(
                (i % 2) == 0,
                lambda: simple_fc_net_with_inputs(img, label, class_num=10),
                lambda: batchnorm_fc_with_inputs(img, label, class_num=10),
            )

        def cond_func_simple_net_at_true(i, img, label):
            return layers.cond(
                i < 5, lambda: branch(i, img, label), lambda: paddle.mean(img)
            )

        def cond_func_simple_net_at_false(i, img, label):
            return layers.cond(
                i < 5, lambda: paddle.mean(img), lambda: branch(i, img, label)
            )

        for use_parallel_exe in [False, True]:
            if use_parallel_exe and os.name == "nt":
                print(
                    "Skip use_parallel_exe=True in Windows because of flaky test when using PE under old Windows machine"
                )
                continue

            self.backward_value_helper(
                cond_func_simple_net_at_true,
                core.is_compiled_with_cuda(),
                use_parallel_exe,
            )
            self.add_optimizer_helper(
                cond_func_simple_net_at_true,
                core.is_compiled_with_cuda(),
                use_parallel_exe,
            )
            self.backward_value_helper(
                cond_func_simple_net_at_false,
                core.is_compiled_with_cuda(),
                use_parallel_exe,
            )
            self.add_optimizer_helper(
                cond_func_simple_net_at_false,
                core.is_compiled_with_cuda(),
                use_parallel_exe,
            )

    def test_nested_cond_backward(self):
        paddle.enable_static()

        def branch(i, img, label, mod_two):
            if mod_two:
                predicate = (i % 2) == 0
            else:
                predicate = (i % 2) != 0
            return layers.cond(
                predicate,
                lambda: simple_fc_net_with_inputs(img, label, class_num=10),
                lambda: batchnorm_fc_with_inputs(img, label, class_num=10),
            )

        def cond_func(i, img, label):
            return layers.cond(
                i < 5,
                lambda: branch(i, img, label, True),
                lambda: branch(i, img, label, False),
            )

        for use_parallel_exe in [False, True]:
            if use_parallel_exe and os.name == "nt":
                print(
                    "Skip use_parallel_exe=True in Windows because of flaky test when using PE under old Windows machine"
                )
                continue
            self.backward_value_helper(
                cond_func, core.is_compiled_with_cuda(), use_parallel_exe
            )
            self.add_optimizer_helper(
                cond_func, core.is_compiled_with_cuda(), use_parallel_exe
            )


class TestCondWithError(unittest.TestCase):
    def test_input_type_error(self):
        paddle.enable_static()
        main_program = framework.Program()
        startup_program = framework.Program()
        with framework.program_guard(main_program, startup_program):
            pred = fluid.data(name='y', shape=[1], dtype='bool')
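            # The checks below exercise invalid inputs: pred must be a
            # Variable and both branches must be callable, otherwise cond
            # raises TypeError.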

            def func():
                return pred

            with self.assertRaises(TypeError):
                layers.cond(None, func, func)

            with self.assertRaises(TypeError):
                layers.cond(pred, func, set())

            with self.assertRaises(TypeError):
                layers.cond(pred, set(), func)

            with self.assertRaises(TypeError):
                layers.cond(pred, func, func, set())


class TestCondWithDict(unittest.TestCase):
    def test_input_with_dict(self):
        paddle.enable_static()
        main_program = framework.Program()
        startup_program = framework.Program()
        with framework.program_guard(main_program, startup_program):

            def true_func():
                return {
                    '1': paddle.full(shape=[3, 2], dtype='int32', fill_value=1),
                    '2': paddle.full(
                        shape=[2, 3], dtype='bool', fill_value=True
                    ),
                }

            def false_func():
                return {
                    '1': paddle.full(
                        shape=[3, 4], dtype='float32', fill_value=3
                    ),
                    '2': paddle.full(shape=[4, 5], dtype='int64', fill_value=2),
                }

            x = paddle.full(shape=[1], dtype='float32', fill_value=0.1)
            y = paddle.full(shape=[1], dtype='float32', fill_value=0.23)
            pred = paddle.less_than(x=x, y=y, name=None)
            ret = paddle.static.nn.cond(pred, true_func, false_func)
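            # The two branches return different static shapes, so dimensions
            # that disagree are unified to -1 (dynamic): [3, 2] vs. [3, 4]
            # gives (3, -1) and [2, 3] vs. [4, 5] gives (-1, -1).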
            self.assertEqual(
                ret['1'].shape,
                (3, -1),
                f"The shape is not correct, expects (3, -1) but gets {ret['1'].shape}.",
            )
            self.assertEqual(
                ret['2'].shape,
                (-1, -1),
                f"The shape is not correct, expects (-1, -1) but gets {ret['2'].shape}.",
            )


if __name__ == '__main__':
    unittest.main()