test_cond.py 27.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
16 17 18 19

import numpy as np
from simple_nets import batchnorm_fc_with_inputs, simple_fc_net_with_inputs

20
import paddle
21 22 23
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.fluid.framework as framework
24
import paddle.fluid.layers as layers
25
from paddle.fluid.backward import append_backward
26
from paddle.fluid.framework import Program, program_guard
27 28

# Seed NumPy's global RNG so the randomly generated feeds used by the tests
# in this module are reproducible from run to run.
np.random.seed(123)
29 30


31
class TestCondInputOutput(unittest.TestCase):
32 33 34 35 36 37 38 39 40 41
    def test_return_single_var(self):
        """
        Check that ``cond`` hands back the single tensor built by the branch
        that is actually taken.

        pseudocode:

        if 0.23 < 0.1:
            return 2
        else:
            return -1
        """

        paddle.enable_static()

        def false_func():
            return layers.fill_constant(shape=[3, 2], dtype='int32', value=-1)

        def true_func():
            return layers.fill_constant(shape=[2, 3], dtype='int32', value=2)

        main_program = Program()
        startup_program = Program()
        with program_guard(main_program, startup_program):
            x = layers.fill_constant(shape=[1], dtype='float32', value=0.1)
            y = layers.fill_constant(shape=[1], dtype='float32', value=0.23)
            # y < x is False here, so the false branch supplies the result.
            pred = paddle.less_than(y, x)
            out = paddle.static.nn.cond(pred, true_func, false_func)
            # out is one tensor

        if core.is_compiled_with_cuda():
            place = fluid.CUDAPlace(0)
        else:
            place = fluid.CPUPlace()
        exe = fluid.Executor(place)
        [ret] = exe.run(main_program, fetch_list=[out.name])
        expected = np.full((3, 2), -1, np.int32)
        np.testing.assert_allclose(np.asarray(ret), expected, rtol=1e-05)
69

70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104
    def test_return_0d_tensor(self):
        """
        Check that ``cond`` can return a 0-D tensor and that the fetched
        value keeps the empty shape.

        pseudocode:

        if 0.23 >= 0.1:
            return 2
        else:
            return -1
        """

        paddle.enable_static()

        def false_func():
            return paddle.full(shape=[], dtype='int32', fill_value=-1)

        def true_func():
            return paddle.full(shape=[], dtype='int32', fill_value=2)

        main_program = Program()
        startup_program = Program()
        with program_guard(main_program, startup_program):
            x = paddle.full(shape=[1], dtype='float32', fill_value=0.1)
            y = paddle.full(shape=[1], dtype='float32', fill_value=0.23)
            pred = paddle.greater_equal(y, x)
            out = paddle.static.nn.cond(pred, true_func, false_func)
            # out is one tensor

        if core.is_compiled_with_cuda():
            place = fluid.CUDAPlace(0)
        else:
            place = fluid.CPUPlace()
        exe = fluid.Executor(place)
        [ret] = exe.run(main_program, fetch_list=[out.name])
        expected = np.array(2)
        np.testing.assert_allclose(np.asarray(ret), expected, rtol=1e-05)
        self.assertEqual(ret.shape, ())
106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131

    def test_0d_tensor_as_cond(self):
        """
        Check that a predicate computed from 0-D tensors can drive ``cond``.

        pseudocode:

        if 0.23 >= 0.1:
            return 2
        else:
            return -1
        """

        paddle.enable_static()

        def false_func():
            return paddle.full(shape=[3, 3], dtype='int32', fill_value=-1)

        def true_func():
            return paddle.full(shape=[3, 3], dtype='int32', fill_value=2)

        main_program = Program()
        startup_program = Program()
        with program_guard(main_program, startup_program):
            # Both operands of the comparison are created with shape=[].
            x = paddle.full(shape=[], dtype='float32', fill_value=0.1)
            y = paddle.full(shape=[], dtype='float32', fill_value=0.23)
            pred = paddle.greater_equal(y, x)
            out = paddle.static.nn.cond(pred, true_func, false_func)
            # out is a tensor

        if core.is_compiled_with_cuda():
            place = fluid.CUDAPlace(0)
        else:
            place = fluid.CPUPlace()
        exe = fluid.Executor(place)
        [ret] = exe.run(main_program, fetch_list=[out.name])
        expected = np.full((3, 3), 2, np.int32)
        np.testing.assert_allclose(np.asarray(ret), expected, rtol=1e-05)

    def test_0d_tensor_backward(self):
        """
        Check forward value and gradient of ``cond`` on a 0-D tensor in
        static mode.

        pseudocode:

        a = -2.0
        if a >= 0:
            return a
        else:
            return -a
        """

        paddle.enable_static()

        main_program = Program()
        startup_program = Program()
        with program_guard(main_program, startup_program):
            a = paddle.full(shape=[], dtype='float32', fill_value=-2.0)
            a.stop_gradient = False
            out = paddle.static.nn.cond(a >= 0, lambda: a, lambda: -a)
            append_backward(out)

        if core.is_compiled_with_cuda():
            place = fluid.CUDAPlace(0)
        else:
            place = fluid.CPUPlace()

        exe = fluid.Executor(place)
        ret = exe.run(main_program, fetch_list=[out.name, a.grad_name])

        # Forward result: the false branch negates a, giving 2.0.
        out_value = ret[0]
        np.testing.assert_allclose(
            np.asarray(out_value), np.array(2.0), rtol=1e-05
        )
        self.assertEqual(out_value.shape, ())

        # Gradient of -a with respect to a is -1.0.
        a_grad_value = ret[1]
        np.testing.assert_allclose(
            np.asarray(a_grad_value), np.array(-1.0), rtol=1e-05
        )
        self.assertEqual(a_grad_value.shape, ())

    def test_0d_tensor_dygraph(self):
        """
        Check forward value and gradient of ``cond`` on a 0-D tensor with
        static mode disabled.

        pseudocode:

        a = -2.0
        if a >= 0:
            return a
        else:
            return -a
        """
        paddle.disable_static()
        a = paddle.full(shape=[], dtype='float32', fill_value=-2.0)
        a.stop_gradient = False
        out = paddle.static.nn.cond(a >= 0, lambda: a, lambda: -a)
        out.backward()

        expected_out = np.array(2.0)
        np.testing.assert_allclose(np.asarray(out), expected_out, rtol=1e-05)
        self.assertEqual(out.shape, [])

        expected_grad = np.array(-1.0)
        np.testing.assert_allclose(
            np.asarray(a.grad), expected_grad, rtol=1e-05
        )
        self.assertEqual(a.grad.shape, [])
206

207 208 209 210 211 212 213 214 215 216
    def test_return_var_tuple(self):
        """
        Check that ``cond`` can return a tuple of tensors from each branch.

        pseudocode:

        if True:
            return 1, True
        else:
            return 3, 2
        """

        paddle.enable_static()

        def true_func():
            first = layers.fill_constant(shape=[1, 2], dtype='int32', value=1)
            second = layers.fill_constant(
                shape=[2, 3], dtype='bool', value=True
            )
            return first, second

        def false_func():
            first = layers.fill_constant(
                shape=[3, 4], dtype='float32', value=3
            )
            second = layers.fill_constant(shape=[4, 5], dtype='int64', value=2)
            return first, second

        main_program = Program()
        startup_program = Program()
        with program_guard(main_program, startup_program):
            pred = layers.fill_constant(shape=[1], dtype='bool', value=True)
            out = paddle.static.nn.cond(pred, true_func, false_func)
            # out is a tuple containing 2 tensors

        if core.is_compiled_with_cuda():
            place = fluid.CUDAPlace(0)
        else:
            place = fluid.CPUPlace()
        exe = fluid.Executor(place)
        ret = exe.run(main_program, fetch_list=out)

        # The predicate is True, so both values come from true_func.
        expected_first = np.full((1, 2), 1, np.int32)
        expected_second = np.full((2, 3), True, bool)
        np.testing.assert_allclose(
            np.asarray(ret[0]), expected_first, rtol=1e-05
        )
        np.testing.assert_allclose(
            np.asarray(ret[1]), expected_second, rtol=1e-05
        )
249 250 251 252 253 254 255 256 257 258 259 260

    def test_pass_and_modify_var(self):
        """
        Check that branches can read outer tensors and return a modified
        value, for a predicate that depends on fed data.

        pseudocode:
        for i in range(5):
            a = 7
            if i % 2 == 0:
                a = a * (i + 1)
            else:
                a = a - (i - 1)
        """

        paddle.enable_static()

        def true_func(a, i):
            return a * (i + 1)

        def false_func(a, i):
            return a - (i - 1)

        main_program = Program()
        startup_program = Program()
        with program_guard(main_program, startup_program):
            a = layers.fill_constant(shape=[3, 2, 1], dtype='int32', value=7)
            i = fluid.data(name="i", shape=[1], dtype='int32')
            pred = (i % 2) == 0
            # The lambdas run inside cond(), i.e. before `a` is rebound to
            # the result, so they see the fill_constant tensor.
            a = paddle.static.nn.cond(
                pred, lambda: true_func(a, i), lambda: false_func(a, i)
            )

        if core.is_compiled_with_cuda():
            place = fluid.CUDAPlace(0)
        else:
            place = fluid.CPUPlace()
        exe = fluid.Executor(place)
        for feed_i in range(5):
            if feed_i % 2 == 0:
                expected_a = 7 * (feed_i + 1)
            else:
                expected_a = 8 - feed_i
            [ret] = exe.run(
                main_program,
                feed={'i': np.full((1,), feed_i, np.int32)},
                fetch_list=[a],
            )
            expected = np.full((3, 2, 1), expected_a, np.int32)
            np.testing.assert_allclose(np.asarray(ret), expected, rtol=1e-05)
298 299 300 301 302 303 304 305 306 307 308

    def test_return_none(self):
        """
        Check that ``cond`` accepts branches that return nothing, including
        a branch passed as ``None``, and that the program still runs.

        pseudocode: test doing nothing in branches
        for i in range(5):
            if i % 2 == 0:
                pass
            else:
                pass
        """

        paddle.enable_static()

        def true_func():
            pass

        def false_func():
            return None

        main_program = Program()
        startup_program = Program()
        with program_guard(main_program, startup_program):
            i = fluid.data(name="i", shape=[1], dtype='int32')
            pred = (i % 2) == 0
            out1 = paddle.static.nn.cond(pred, true_func, false_func)
            out2 = paddle.static.nn.cond(pred, None, false_func)
            out3 = paddle.static.nn.cond(pred, true_func, None)

        if core.is_compiled_with_cuda():
            place = fluid.CUDAPlace(0)
        else:
            place = fluid.CPUPlace()
        exe = fluid.Executor(place)
        for feed_i in range(5):
            # Test that output is None is runnable
            feed = {'i': np.full((1,), feed_i, np.int32)}
            exe.run(main_program, feed=feed)
            for branch_output in (out1, out2, out3):
                self.assertIsNone(branch_output)

    def test_wrong_structure_exception(self):
        """
        Check that ``cond`` rejects invalid branch arguments and branches
        whose return structures cannot be merged into one output.

        Two cases expect ``TypeError`` when a branch argument is not a
        callable (a tensor and a NumPy array are passed instead). Three cases
        expect an exception whose message describes the mismatch between the
        true and false branch return values.
        """

        paddle.enable_static()

        def func_return_none():
            return None

        def func_return_one_tensor():
            return layers.fill_constant(shape=[2, 7], dtype='int32', value=3)

        def func_return_two_tensors():
            return layers.fill_constant(
                shape=[3, 1], dtype='int32', value=7
            ), layers.fill_constant(shape=[3, 1], dtype='int32', value=8)

        incompatible_msg = (
            "Incompatible return values of true_fn and false_fn in cond"
        )

        main_program = Program()
        startup_program = Program()
        with program_guard(main_program, startup_program):
            i = fluid.data(name="i", shape=[1], dtype='int32')
            pred = (i % 2) == 0

            # A tensor is not a valid true_fn.
            with self.assertRaises(TypeError):
                paddle.static.nn.cond(pred, i, func_return_one_tensor)

            # A NumPy array is not a valid false_fn.
            with self.assertRaises(TypeError):
                paddle.static.nn.cond(
                    pred, func_return_one_tensor, np.asarray([3])
                )

            # assertIn (rather than assertTrue(x in y)) prints both the
            # expected fragment and the actual message when the check fails.
            with self.assertRaises(Exception) as e:
                paddle.static.nn.cond(
                    pred, func_return_none, func_return_one_tensor
                )
            self.assertIn(incompatible_msg, str(e.exception))

            with self.assertRaises(Exception) as e:
                paddle.static.nn.cond(
                    pred, func_return_two_tensors, func_return_none
                )
            self.assertIn(incompatible_msg, str(e.exception))

            with self.assertRaises(Exception) as e:
                paddle.static.nn.cond(
                    pred, func_return_one_tensor, func_return_two_tensors
                )
            self.assertIn(
                "true fn returns 1 vars, but false fn returns 2 vars, which is not equals",
                str(e.exception),
            )
395

396
    def test_extremely_simple_net_with_op_in_condition(self):
        """
        Check value and gradients of a ``cond`` whose predicate is itself
        computed by ops (``a - b < -1.0``) and whose branches return the
        inputs unchanged.

        With a = 1.23 and b = 1.25 the predicate is False, so ``out`` is
        ``b``; the test expects a zero gradient for ``a`` and a gradient of
        one for ``b``.
        """
        paddle.enable_static()
        main_program = fluid.Program()
        startup_program = fluid.Program()
        with fluid.program_guard(main_program, startup_program):
            a = fluid.layers.fill_constant(
                shape=[1], dtype='float32', value=1.23
            )
            a.stop_gradient = False
            b = fluid.layers.fill_constant(
                shape=[1], dtype='float32', value=1.25
            )
            b.stop_gradient = False
            out = paddle.static.nn.cond(a - b < -1.0, lambda: a, lambda: b)
            # Build the backward pass under the same program guard as the
            # forward graph, as the other tests in this file do.
            append_backward(out)

        if core.is_compiled_with_cuda():
            place = fluid.CUDAPlace(0)
        else:
            place = fluid.CPUPlace()
        exe = fluid.Executor(place)
        ret = exe.run(
            main_program, fetch_list=[out, b, a.grad_name, b.grad_name]
        )
        # Note: fill_constant loses precision, so assertEqual is only used
        # against values that do not lose precision as floating-point numbers
        # (the fetched b itself, 0.0 and 1.0).
        self.assertEqual(ret[0][0], ret[1][0])
        self.assertEqual(ret[2][0], 0.0)
        self.assertEqual(ret[3][0], 1.0)
426

427

428 429 430 431 432 433 434 435
class TestCondNestedControlFlow(unittest.TestCase):
    def test_cond_inside_cond(self):
        """
        Check forward value and gradient when each branch of a ``cond`` is
        itself a ``cond``.

        pseudocode:
        for i in range(0, 10):
            a = 2 * i
            if i < 5:
                if i >= 3:
                    return a + a
                else:
                    return a - a
            else:
                if i < 8:
                    return a * a
                else:
                    return a / a
        """

        paddle.enable_static()

        def less_than_branch(i, a):
            def add_branch():
                return paddle.add(a, a)

            def subtract_branch():
                return paddle.subtract(a, a)

            return paddle.static.nn.cond(i >= 3.0, add_branch, subtract_branch)

        def greater_equal_branch(i, a):
            def multiply_branch():
                return paddle.multiply(a, a)

            def divide_branch():
                return paddle.divide(a, a)

            return paddle.static.nn.cond(i < 8.0, multiply_branch, divide_branch)

        main_program = Program()
        startup_program = Program()
        with program_guard(main_program, startup_program):
            i = fluid.data(name="i", shape=[1], dtype='float32')
            i.stop_gradient = False
            a = 2.0 * i
            out = paddle.static.nn.cond(
                i < 5.0,
                lambda: less_than_branch(i, a),
                lambda: greater_equal_branch(i, a),
            )
            mean = paddle.mean(out)
            append_backward(mean)

        if core.is_compiled_with_cuda():
            place = fluid.CUDAPlace(0)
        else:
            place = fluid.CPUPlace()
        exe = fluid.Executor(place)
        for feed_i in range(0, 10):
            expected_a = 2.0 * feed_i
            # Mirror the four leaves of the nested condition in plain Python.
            if feed_i < 5:
                if feed_i >= 3:
                    expected_ret = expected_a + expected_a
                    expected_a_grad = 2.0
                else:
                    expected_ret = 0.0
                    expected_a_grad = 0.0
            elif feed_i < 8:
                expected_ret = expected_a * expected_a
                expected_a_grad = 2.0 * expected_a
            else:
                expected_ret = 1.0
                expected_a_grad = 0.0
            ret = exe.run(
                main_program,
                feed={'i': np.full((1,), feed_i, np.float32)},
                fetch_list=[out.name, a.grad_name],
            )
            self.assertEqual(ret[0][0], expected_ret)
            self.assertEqual(ret[1][0], expected_a_grad)

498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557
    def test_cond_inside_cond_0d_tensor(self):
        """
        Check forward value and gradient of nested ``cond`` calls when the
        input is a 0-D tensor.

        pseudocode:
            i = 3.0
            a = 2 * i
            if i < 5:
                if i >= 3:
                    return a + 1
                else:
                    return 1 - a
            else:
                if i < 8:
                    return a * 2
                else:
                    return a / 2
        """

        paddle.enable_static()

        def less_than_branch(i, a):
            return paddle.static.nn.cond(
                i >= 3.0,
                lambda: a + 1,
                lambda: 1 - a,
            )

        def greater_equal_branch(i, a):
            return paddle.static.nn.cond(
                i < 8.0,
                lambda: a * 2,
                lambda: a / 2,
            )

        main_program = Program()
        startup_program = Program()
        with program_guard(main_program, startup_program):
            i = paddle.full(fill_value=3.0, shape=[], dtype='float32')
            i.stop_gradient = False
            a = 2.0 * i
            out = paddle.static.nn.cond(
                i < 5.0,
                lambda: less_than_branch(i, a),
                lambda: greater_equal_branch(i, a),
            )
            # Backward is appended on `out` directly; the previously computed
            # but never used mean of `out` has been removed.
            append_backward(out)

        if core.is_compiled_with_cuda():
            place = fluid.CUDAPlace(0)
        else:
            place = fluid.CPUPlace()
        exe = fluid.Executor(place)
        ret = exe.run(
            main_program,
            fetch_list=[out.name, i.grad_name],
        )
        # i = 3.0 takes the `a + 1` leaf: out = 2 * 3 + 1 = 7, d(out)/d(i) = 2.
        np.testing.assert_allclose(
            np.asarray(ret[0]), np.array(7.0), rtol=1e-05
        )
        self.assertEqual(ret[0].shape, ())
        np.testing.assert_allclose(
            np.asarray(ret[1]), np.array(2.0), rtol=1e-05
        )
        self.assertEqual(ret[1].shape, ())
563

564
    def test_cond_op_in_condition(self):
        """
        Check value and gradients of nested ``cond`` calls whose predicates
        are computed by ops on the inputs.

        With a = 1.23 and b = 1.24 the test expects ``a * b`` as the output
        (1.5252), with gradients b for ``a`` and a for ``b``.
        """
        paddle.enable_static()
        main_program = fluid.Program()
        startup_program = fluid.Program()

        with fluid.program_guard(main_program, startup_program):
            a = fluid.layers.fill_constant(
                shape=[1], dtype='float32', value=1.23
            )
            a.stop_gradient = False
            b = fluid.layers.fill_constant(
                shape=[1], dtype='float32', value=1.24
            )
            b.stop_gradient = False

            def when_a_less_than_b():
                return paddle.static.nn.cond(
                    a - b < -1.0,
                    lambda: paddle.add(a, b),
                    lambda: paddle.multiply(a, b),
                )

            def when_a_not_less_than_b():
                return paddle.static.nn.cond(
                    a == b,
                    lambda: paddle.subtract(a, b),
                    lambda: paddle.pow(a, b),
                )

            out = paddle.static.nn.cond(
                a < b, when_a_less_than_b, when_a_not_less_than_b
            )
            append_backward(out)

        if core.is_compiled_with_cuda():
            place = fluid.CUDAPlace(0)
        else:
            place = fluid.CPUPlace()
        exe = fluid.Executor(place)
        ret = exe.run(main_program, fetch_list=[out, a.grad_name, b.grad_name])
        out_value, a_grad, b_grad = ret[0], ret[1], ret[2]
        # Note: fill_constant has loss of precision, so we assertAlmostEqual.
        self.assertAlmostEqual(out_value[0], 1.5252)
        self.assertAlmostEqual(a_grad[0], 1.24)
        self.assertAlmostEqual(b_grad[0], 1.23)

605

606
class TestCondBackward(unittest.TestCase):
607
    def backward_value_helper(self, cond_func, use_cuda):
        """
        Compare the gradient of the loss with respect to the image input
        against a forward finite-difference estimate of dy/dx.

        ``cond_func(i, img, label)`` builds the loss inside the program;
        ``use_cuda`` selects ``CUDAPlace(0)`` instead of ``CPUPlace``.
        """
        paddle.enable_static()
        main_program = Program()
        main_program.random_seed = 123
        startup_program = Program()
        startup_program.random_seed = 123
        with program_guard(main_program, startup_program):
            img = fluid.data(name='image', shape=[-1, 9], dtype='float32')
            img.stop_gradient = False
            label = fluid.data(name='label', shape=[-1, 1], dtype='int64')
            i = fluid.data(name="i", shape=[1], dtype='int32')
            loss = cond_func(i, img, label)
            append_backward(loss)

        if use_cuda:
            place = fluid.CUDAPlace(0)
        else:
            place = fluid.CPUPlace()
        exe = fluid.Executor(place)
        exe.run(startup_program)

        def run_main(step, image, labels, fetch_list):
            # One execution of the main program with the given feeds.
            return exe.run(
                main_program,
                feed={
                    'i': np.full((1,), step, np.int32),
                    'image': image,
                    'label': labels,
                },
                fetch_list=fetch_list,
            )

        num_devices = 1
        num_features = 9

        delta = 0.005
        for feed_i in range(0, 10):
            feed_img = np.random.random(size=[1, 9]).astype(np.float32)
            feed_label = np.random.randint(
                low=0, high=10, size=[1, 1], dtype=np.int64
            )

            img_grad, loss_value = run_main(
                feed_i, feed_img, feed_label, [img.grad_name, loss.name]
            )

            numerical_grad = np.zeros(
                shape=[num_devices, num_features], dtype=np.float32
            )
            feed_img_delta = np.copy(feed_img)
            for j in range(num_features):
                # Nudge a single feature, measure the change in the loss,
                # then restore the feature before moving to the next one.
                feed_img_delta[0, j] = feed_img[0, j] + delta
                loss_delta = run_main(
                    feed_i, feed_img_delta, feed_label, [loss.name]
                )
                numerical_grad[0, j] = (loss_delta[0] - loss_value[0]) / delta
                feed_img_delta[0, j] = feed_img[0, j]
            np.testing.assert_allclose(
                img_grad, numerical_grad, rtol=0.05, atol=0.05
            )
664

665
    def add_optimizer_helper(self, cond_func, use_cuda):
        """
        Test that program is runnable when add optimizer

        ``cond_func(i, img, label)`` builds the loss inside the program;
        ``use_cuda`` selects ``CUDAPlace(0)`` instead of ``CPUPlace``. An SGD
        optimizer is attached to the loss and the program is run for ten
        steps on random data; only successful execution is checked.
        """
        # Enable static mode here, like backward_value_helper does, so this
        # helper does not depend on the caller having done it.
        paddle.enable_static()
        main_program = Program()
        startup_program = Program()
        with program_guard(main_program, startup_program):
            img = fluid.data(name='image', shape=[-1, 784], dtype='float32')
            label = fluid.data(name='label', shape=[-1, 1], dtype='int64')
            i = fluid.data(name="i", shape=[1], dtype='int32')
            loss = cond_func(i, img, label)
            optimizer = fluid.optimizer.SGD(learning_rate=0.1)
            optimizer.minimize(loss)

        place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
        exe = fluid.Executor(place)
        exe.run(startup_program)

        for feed_i in range(0, 10):
            feed_img = np.random.random(size=[16, 784]).astype(np.float32)
            feed_label = np.random.randint(
                low=0, high=10, size=[16, 1], dtype=np.int64
            )
            exe.run(
                main_program,
                feed={
                    'i': np.full((1,), feed_i, np.int32),
                    'image': feed_img,
                    'label': feed_label,
                },
                fetch_list=[loss],
            )
697 698

    def test_cond_backward(self):
        """
        Run the gradient check and the optimizer smoke test on a single
        ``cond`` that picks one of two networks depending on ``i % 2``.
        """

        paddle.enable_static()

        def cond_func(i, img, label):
            def even_branch():
                return simple_fc_net_with_inputs(img, label, class_num=10)

            def odd_branch():
                return batchnorm_fc_with_inputs(img, label, class_num=10)

            predicate = (i % 2) == 0
            return paddle.static.nn.cond(predicate, even_branch, odd_branch)

        use_cuda = core.is_compiled_with_cuda()
        self.backward_value_helper(cond_func, use_cuda)
        self.add_optimizer_helper(cond_func, use_cuda)
712 713

    def test_half_nested_cond_backward(self):
        """
        Run the gradient check and the optimizer smoke test on a ``cond``
        where only one side is a nested ``cond``; the other side is a plain
        mean of the image. Both placements of the nested side are covered.
        """
        paddle.enable_static()

        def branch(i, img, label):
            return paddle.static.nn.cond(
                (i % 2) == 0,
                lambda: simple_fc_net_with_inputs(img, label, class_num=10),
                lambda: batchnorm_fc_with_inputs(img, label, class_num=10),
            )

        def cond_func_simple_net_at_true(i, img, label):
            return paddle.static.nn.cond(
                i < 5, lambda: branch(i, img, label), lambda: paddle.mean(img)
            )

        def cond_func_simple_net_at_false(i, img, label):
            return paddle.static.nn.cond(
                i < 5, lambda: paddle.mean(img), lambda: branch(i, img, label)
            )

        # Same order as before: both helpers on the "true" variant first,
        # then both helpers on the "false" variant.
        for cond_func in (
            cond_func_simple_net_at_true,
            cond_func_simple_net_at_false,
        ):
            self.backward_value_helper(cond_func, core.is_compiled_with_cuda())
            self.add_optimizer_helper(cond_func, core.is_compiled_with_cuda())
749 750

    def test_nested_cond_backward(self):
        """
        Run the gradient check and the optimizer smoke test on a ``cond``
        whose two sides are both nested ``cond`` calls, one keyed on
        ``i % 2 == 0`` and the other on ``i % 2 != 0``.
        """
        paddle.enable_static()

        def branch(i, img, label, mod_two):
            remainder = i % 2
            predicate = (remainder == 0) if mod_two else (remainder != 0)
            return paddle.static.nn.cond(
                predicate,
                lambda: simple_fc_net_with_inputs(img, label, class_num=10),
                lambda: batchnorm_fc_with_inputs(img, label, class_num=10),
            )

        def cond_func(i, img, label):
            def low_side():
                return branch(i, img, label, True)

            def high_side():
                return branch(i, img, label, False)

            return paddle.static.nn.cond(i < 5, low_side, high_side)

        use_cuda = core.is_compiled_with_cuda()
        self.backward_value_helper(cond_func, use_cuda)
        self.add_optimizer_helper(cond_func, use_cuda)
773 774


775 776
class TestCondWithError(unittest.TestCase):
    """Argument-type validation of ``paddle.static.nn.cond``."""

    def test_input_type_error(self):
        """
        Each call below passes one argument of an unsupported type and is
        expected to raise ``TypeError``.
        """
        paddle.enable_static()
        main_program = framework.Program()
        startup_program = framework.Program()
        with framework.program_guard(main_program, startup_program):
            pred = fluid.data(name='y', shape=[1], dtype='bool')

            def func():
                return pred

            # Predicate is None.
            with self.assertRaises(TypeError):
                paddle.static.nn.cond(None, func, func)

            # false_fn is a set.
            with self.assertRaises(TypeError):
                paddle.static.nn.cond(pred, func, set())

            # true_fn is a set.
            with self.assertRaises(TypeError):
                paddle.static.nn.cond(pred, set(), func)

            # Fourth positional argument is a set.
            with self.assertRaises(TypeError):
                paddle.static.nn.cond(pred, func, func, set())
797 798


799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837
class TestCondWithDict(unittest.TestCase):
    """``paddle.static.nn.cond`` with branches that return dictionaries."""

    def test_input_with_dict(self):
        """
        Both branches return a dict with keys '1' and '2' whose tensors
        differ in shape; the test asserts the static shapes reported on the
        merged result.
        """
        paddle.enable_static()
        main_program = framework.Program()
        startup_program = framework.Program()
        with framework.program_guard(main_program, startup_program):

            def true_func():
                first = paddle.full(shape=[3, 2], dtype='int32', fill_value=1)
                second = paddle.full(
                    shape=[2, 3], dtype='bool', fill_value=True
                )
                return {'1': first, '2': second}

            def false_func():
                first = paddle.full(
                    shape=[3, 4], dtype='float32', fill_value=3
                )
                second = paddle.full(shape=[4, 5], dtype='int64', fill_value=2)
                return {'1': first, '2': second}

            x = paddle.full(shape=[1], dtype='float32', fill_value=0.1)
            y = paddle.full(shape=[1], dtype='float32', fill_value=0.23)
            pred = paddle.less_than(x=x, y=y, name=None)
            ret = paddle.static.nn.cond(pred, true_func, false_func)

            # Expected: -1 on every dimension where the branch shapes differ.
            self.assertEqual(
                ret['1'].shape,
                (3, -1),
                f"The shape is not correct, expects (3, -1) but gets {ret['1'].shape}.",
            )
            self.assertEqual(
                ret['2'].shape,
                (-1, -1),
                f"The shape is not correct, expects (-1, -1) but gets {ret['2'].shape}.",
            )


838 839
# Run every test case in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()