# test_cond.py
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
import os
16
import unittest
17 18 19 20

import numpy as np
from simple_nets import batchnorm_fc_with_inputs, simple_fc_net_with_inputs

21
import paddle
22 23 24
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.fluid.framework as framework
25
import paddle.fluid.layers as layers
26
from paddle.fluid.backward import append_backward
27
from paddle.fluid.framework import Program, program_guard
28 29

np.random.seed(123)
30 31


32
class TestCondInputOutput(unittest.TestCase):
33 34 35 36 37 38 39 40 41 42
    def test_return_single_var(self):
        """
        pseudocode:

        if 0.23 < 0.1:
            return 2
        else:
            return -1
        """

43 44
        paddle.enable_static()

45 46 47 48 49 50 51 52 53 54 55
        def true_func():
            return layers.fill_constant(shape=[2, 3], dtype='int32', value=2)

        def false_func():
            return layers.fill_constant(shape=[3, 2], dtype='int32', value=-1)

        main_program = Program()
        startup_program = Program()
        with program_guard(main_program, startup_program):
            x = layers.fill_constant(shape=[1], dtype='float32', value=0.1)
            y = layers.fill_constant(shape=[1], dtype='float32', value=0.23)
L
LiYuRio 已提交
56
            pred = paddle.less_than(y, x)
57
            out = paddle.static.nn.cond(pred, true_func, false_func)
58 59
            # out is one tensor

60 61 62 63 64
        place = (
            fluid.CUDAPlace(0)
            if core.is_compiled_with_cuda()
            else fluid.CPUPlace()
        )
65
        exe = fluid.Executor(place)
66 67 68 69
        (ret,) = exe.run(main_program, fetch_list=[out.name])
        np.testing.assert_allclose(
            np.asarray(ret), np.full((3, 2), -1, np.int32), rtol=1e-05
        )
70

71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179
    def test_return_0d_tensor(self):
        """
        pseudocode:

        if 0.23 >= 0.1:
            return 2
        else:
            return -1
        """

        paddle.enable_static()

        def true_func():
            return paddle.full(shape=[], dtype='int32', fill_value=2)

        def false_func():
            return paddle.full(shape=[], dtype='int32', fill_value=-1)

        main_program = Program()
        startup_program = Program()
        with program_guard(main_program, startup_program):
            x = paddle.full(shape=[1], dtype='float32', fill_value=0.1)
            y = paddle.full(shape=[1], dtype='float32', fill_value=0.23)
            pred = paddle.greater_equal(y, x)
            out = paddle.static.nn.cond(pred, true_func, false_func)
            # out is one tensor

        place = (
            fluid.CUDAPlace(0)
            if core.is_compiled_with_cuda()
            else fluid.CPUPlace()
        )
        exe = fluid.Executor(place)
        (ret,) = exe.run(main_program, fetch_list=[out.name])
        np.testing.assert_allclose(np.asarray(ret), np.array(2), rtol=1e-05)

    def test_0d_tensor_as_cond(self):
        """
        pseudocode:

        if 0.23 >= 0.1:
            return 2
        else:
            return -1
        """

        paddle.enable_static()

        def true_func():
            return paddle.full(shape=[3, 3], dtype='int32', fill_value=2)

        def false_func():
            return paddle.full(shape=[3, 3], dtype='int32', fill_value=-1)

        main_program = Program()
        startup_program = Program()
        with program_guard(main_program, startup_program):
            x = paddle.full(shape=[], dtype='float32', fill_value=0.1)
            y = paddle.full(shape=[], dtype='float32', fill_value=0.23)
            pred = paddle.greater_equal(y, x)
            out = paddle.static.nn.cond(pred, true_func, false_func)
            # out is one tensor

        place = (
            fluid.CUDAPlace(0)
            if core.is_compiled_with_cuda()
            else fluid.CPUPlace()
        )
        exe = fluid.Executor(place)
        (ret,) = exe.run(main_program, fetch_list=[out.name])
        np.testing.assert_allclose(
            np.asarray(ret), np.full((3, 3), 2, np.int32), rtol=1e-05
        )

    def test_0d_tensor_backward(self):
        """
        pseudocode:

        a = -2.0
        if a >= 0:
            return a
        else:
            return -a
        """

        paddle.enable_static()

        main_program = Program()
        startup_program = Program()
        with program_guard(main_program, startup_program):
            a = paddle.full(shape=[], dtype='float32', fill_value=-2.0)
            a.stop_gradient = False
            out = paddle.static.nn.cond(a >= 0, lambda: a, lambda: -a)
            append_backward(out)

        place = (
            fluid.CUDAPlace(0)
            if core.is_compiled_with_cuda()
            else fluid.CPUPlace()
        )
        exe = fluid.Executor(place)
        ret = exe.run(main_program, fetch_list=[out.name, a.grad_name])
        np.testing.assert_allclose(
            np.asarray(ret[0]), np.array(2.0), rtol=1e-05
        )
        np.testing.assert_allclose(
            np.asarray(ret[1]), np.array(-1.0), rtol=1e-05
        )

180 181 182 183 184 185 186 187 188 189
    def test_return_var_tuple(self):
        """
        pseudocode:

        if True:
            return 1, True
        else:
            return 3, 2
        """

190 191
        paddle.enable_static()

192
        def true_func():
193 194 195
            return layers.fill_constant(
                shape=[1, 2], dtype='int32', value=1
            ), layers.fill_constant(shape=[2, 3], dtype='bool', value=True)
196 197

        def false_func():
198 199 200
            return layers.fill_constant(
                shape=[3, 4], dtype='float32', value=3
            ), layers.fill_constant(shape=[4, 5], dtype='int64', value=2)
201 202 203 204 205

        main_program = Program()
        startup_program = Program()
        with program_guard(main_program, startup_program):
            pred = layers.fill_constant(shape=[1], dtype='bool', value=True)
206
            out = paddle.static.nn.cond(pred, true_func, false_func)
207 208
            # out is a tuple containing 2 tensors

209 210 211 212 213
        place = (
            fluid.CUDAPlace(0)
            if core.is_compiled_with_cuda()
            else fluid.CPUPlace()
        )
214 215
        exe = fluid.Executor(place)
        ret = exe.run(main_program, fetch_list=out)
216 217 218 219 220 221
        np.testing.assert_allclose(
            np.asarray(ret[0]), np.full((1, 2), 1, np.int32), rtol=1e-05
        )
        np.testing.assert_allclose(
            np.asarray(ret[1]), np.full((2, 3), True, bool), rtol=1e-05
        )
222 223 224 225 226 227 228 229 230 231 232 233

    def test_pass_and_modify_var(self):
        """
        pseudocode:
        for i in range(5):
            a = 7
            if i % 2 == 0:
                a = a * (i + 1)
            else:
                a = a - (i - 1)
        """

234 235
        paddle.enable_static()

236 237 238 239 240 241 242 243 244 245 246 247 248
        def true_func(a, i):
            a = a * (i + 1)
            return a

        def false_func(a, i):
            a = a - (i - 1)
            return a

        main_program = Program()
        startup_program = Program()
        with program_guard(main_program, startup_program):
            a = layers.fill_constant(shape=[3, 2, 1], dtype='int32', value=7)
            i = fluid.data(name="i", shape=[1], dtype='int32')
249
            pred = (i % 2) == 0
250
            a = paddle.static.nn.cond(
251 252 253 254 255 256 257
                pred, lambda: true_func(a, i), lambda: false_func(a, i)
            )
        place = (
            fluid.CUDAPlace(0)
            if core.is_compiled_with_cuda()
            else fluid.CPUPlace()
        )
258 259 260
        exe = fluid.Executor(place)
        for feed_i in range(5):
            expected_a = 7 * (feed_i + 1) if feed_i % 2 == 0 else 8 - feed_i
261 262 263 264 265 266 267 268 269 270
            (ret,) = exe.run(
                main_program,
                feed={'i': np.full((1), feed_i, np.int32)},
                fetch_list=[a],
            )
            np.testing.assert_allclose(
                np.asarray(ret),
                np.full((3, 2, 1), expected_a, np.int32),
                rtol=1e-05,
            )
271 272 273 274 275 276 277 278 279 280 281

    def test_return_none(self):
        """
        pseudocode: test doing nothing in branches
        for i in range(5):
            if i % 2 == 0:
                pass
            else:
                pass
        """

282 283
        paddle.enable_static()

284 285 286 287 288 289 290 291 292 293
        def true_func():
            pass

        def false_func():
            return None

        main_program = Program()
        startup_program = Program()
        with program_guard(main_program, startup_program):
            i = fluid.data(name="i", shape=[1], dtype='int32')
294
            pred = (i % 2) == 0
295 296 297
            out1 = paddle.static.nn.cond(pred, true_func, false_func)
            out2 = paddle.static.nn.cond(pred, None, false_func)
            out3 = paddle.static.nn.cond(pred, true_func, None)
298 299 300 301 302
        place = (
            fluid.CUDAPlace(0)
            if core.is_compiled_with_cuda()
            else fluid.CPUPlace()
        )
303 304 305 306 307 308 309 310 311 312 313 314 315
        exe = fluid.Executor(place)
        for feed_i in range(5):
            # Test that output is None is runnable
            exe.run(main_program, feed={'i': np.full((1), feed_i, np.int32)})
            self.assertIsNone(out1)
            self.assertIsNone(out2)
            self.assertIsNone(out3)

    def test_wrong_structure_exception(self):
        """
        test returning different number of tensors cannot merge into output
        """

316 317
        paddle.enable_static()

318 319 320 321 322 323 324
        def func_return_none():
            return None

        def func_return_one_tensor():
            return layers.fill_constant(shape=[2, 7], dtype='int32', value=3)

        def func_return_two_tensors():
325 326 327
            return layers.fill_constant(
                shape=[3, 1], dtype='int32', value=7
            ), layers.fill_constant(shape=[3, 1], dtype='int32', value=8)
328 329 330 331 332

        main_program = Program()
        startup_program = Program()
        with program_guard(main_program, startup_program):
            i = fluid.data(name="i", shape=[1], dtype='int32')
333
            pred = (i % 2) == 0
334
            with self.assertRaises(TypeError):
335
                out = paddle.static.nn.cond(pred, i, func_return_one_tensor)
336

337
            with self.assertRaises(TypeError):
338 339 340
                out = paddle.static.nn.cond(
                    pred, func_return_one_tensor, np.asarray([3])
                )
341 342

            with self.assertRaises(Exception) as e:
343
                out = paddle.static.nn.cond(
344 345
                    pred, func_return_none, func_return_one_tensor
                )
346
            self.assertTrue(
347 348 349
                "Incompatible return values of true_fn and false_fn in cond"
                in str(e.exception)
            )
350 351

            with self.assertRaises(Exception) as e:
352
                out = paddle.static.nn.cond(
353 354
                    pred, func_return_two_tensors, func_return_none
                )
355
            self.assertTrue(
356 357 358
                "Incompatible return values of true_fn and false_fn in cond"
                in str(e.exception)
            )
359 360

            with self.assertRaises(Exception) as e:
361
                out = paddle.static.nn.cond(
362 363
                    pred, func_return_one_tensor, func_return_two_tensors
                )
364
            self.assertTrue(
365
                "true fn returns 1 vars, but false fn returns 2 vars, which is not equals"
366 367
                in str(e.exception)
            )
368

369
    def test_extremely_simple_net_with_op_in_condition(self):
370
        paddle.enable_static()
371 372 373
        main_program = fluid.Program()
        startup_program = fluid.Program()
        with fluid.program_guard(main_program, startup_program):
374 375 376
            a = fluid.layers.fill_constant(
                shape=[1], dtype='float32', value=1.23
            )
377
            a.stop_gradient = False
378 379 380
            b = fluid.layers.fill_constant(
                shape=[1], dtype='float32', value=1.25
            )
381
            b.stop_gradient = False
382
            out = paddle.static.nn.cond(a - b < -1.0, lambda: a, lambda: b)
383 384
        append_backward(out)

385 386 387 388 389
        place = (
            fluid.CUDAPlace(0)
            if core.is_compiled_with_cuda()
            else fluid.CPUPlace()
        )
390
        exe = fluid.Executor(place)
391 392 393
        ret = exe.run(
            main_program, fetch_list=[out, b, a.grad_name, b.grad_name]
        )
394 395
        # Note: fill_constant has loss of precision, you have to assertEqual
        # with values doens't lose precision in float-point number.
396 397 398
        self.assertEqual(ret[0][0], ret[1][0])
        self.assertEqual(ret[2][0], 0.0)
        self.assertEqual(ret[3][0], 1.0)
399

400

401 402 403 404 405 406 407 408
class TestCondNestedControlFlow(unittest.TestCase):
    def test_cond_inside_cond(self):
        """
        pseudocode:
        for i in range(1, 10):
            a = 2 * i
            if i < 5:
                if i >= 3:
409
                    return a + a
410 411 412 413 414 415 416 417 418
                else:
                    return a - a
            else:
                if i < 8:
                    return a * a
                else:
                    return a / a
        """

419 420
        paddle.enable_static()

421
        def less_than_branch(i, a):
422
            return paddle.static.nn.cond(
423
                i >= 3.0,
424 425
                lambda: paddle.add(a, a),
                lambda: paddle.subtract(a, a),
426
            )
427 428

        def greater_equal_branch(i, a):
429
            return paddle.static.nn.cond(
430
                i < 8.0,
431 432
                lambda: paddle.multiply(a, a),
                lambda: paddle.divide(a, a),
433
            )
434 435 436 437 438 439

        main_program = Program()
        startup_program = Program()
        with program_guard(main_program, startup_program):
            i = fluid.data(name="i", shape=[1], dtype='float32')
            a = 2.0 * i
440
            out = paddle.static.nn.cond(
441 442 443 444
                i < 5.0,
                lambda: less_than_branch(i, a),
                lambda: greater_equal_branch(i, a),
            )
445
            mean = paddle.mean(out)
446 447
            append_backward(mean)

448 449 450 451 452
        place = (
            fluid.CUDAPlace(0)
            if core.is_compiled_with_cuda()
            else fluid.CPUPlace()
        )
453 454 455 456 457 458 459 460 461
        exe = fluid.Executor(place)
        for feed_i in range(0, 10):
            expected_a = 2.0 * feed_i
            if feed_i < 5:
                expected_ret = expected_a + expected_a if feed_i >= 3 else 0.0
                expected_a_grad = 2.0 if feed_i >= 3 else 0.0
            else:
                expected_ret = expected_a * expected_a if feed_i < 8 else 1.0
                expected_a_grad = 2.0 * expected_a if feed_i < 8 else 0.0
462 463 464 465 466
            ret = exe.run(
                main_program,
                feed={'i': np.full((1), feed_i, np.float32)},
                fetch_list=[out.name, a.grad_name],
            )
467 468 469
            self.assertEqual(ret[0][0], expected_ret)
            self.assertEqual(ret[1][0], expected_a_grad)

470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533
    def test_cond_inside_cond_0d_tensor(self):
        """
        pseudocode:
            i = 3.0
            a = 2 * i
            if i < 5:
                if i >= 3:
                    return a + 1
                else:
                    return 1 - a
            else:
                if i < 8:
                    return a * 2
                else:
                    return a / 2
        """

        paddle.enable_static()

        def less_than_branch(i, a):
            return paddle.static.nn.cond(
                i >= 3.0,
                lambda: a + 1,
                lambda: 1 - a,
            )

        def greater_equal_branch(i, a):
            return paddle.static.nn.cond(
                i < 8.0,
                lambda: a * 2,
                lambda: a / 2,
            )

        main_program = Program()
        startup_program = Program()
        with program_guard(main_program, startup_program):
            i = paddle.full(fill_value=3.0, shape=[], dtype='float32')
            i.stop_gradient = False
            a = 2.0 * i
            out = paddle.static.nn.cond(
                i < 5.0,
                lambda: less_than_branch(i, a),
                lambda: greater_equal_branch(i, a),
            )
            mean = paddle.mean(out)
            append_backward(out)

        place = (
            fluid.CUDAPlace(0)
            if core.is_compiled_with_cuda()
            else fluid.CPUPlace()
        )
        exe = fluid.Executor(place)
        ret = exe.run(
            main_program,
            fetch_list=[out.name, i.grad_name],
        )
        np.testing.assert_allclose(
            np.asarray(ret[0]), np.array(7.0), rtol=1e-05
        )
        np.testing.assert_allclose(
            np.asarray(ret[1]), np.array(2.0), rtol=1e-05
        )

534
    def test_cond_op_in_condition(self):
535
        paddle.enable_static()
536 537 538 539
        main_program = fluid.Program()
        startup_program = fluid.Program()

        with fluid.program_guard(main_program, startup_program):
540 541 542
            a = fluid.layers.fill_constant(
                shape=[1], dtype='float32', value=1.23
            )
543
            a.stop_gradient = False
544 545 546
            b = fluid.layers.fill_constant(
                shape=[1], dtype='float32', value=1.24
            )
547
            b.stop_gradient = False
548
            out = paddle.static.nn.cond(
549
                a < b,
550
                lambda: paddle.static.nn.cond(
551
                    a - b < -1.0,
552 553
                    lambda: paddle.add(a, b),
                    lambda: paddle.multiply(a, b),
554
                ),
555
                lambda: paddle.static.nn.cond(
556
                    a == b,
557
                    lambda: paddle.subtract(a, b),
558
                    lambda: paddle.pow(a, b),
559 560
                ),
            )
561 562
            append_backward(out)

563 564 565 566 567
        place = (
            fluid.CUDAPlace(0)
            if core.is_compiled_with_cuda()
            else fluid.CPUPlace()
        )
568 569
        exe = fluid.Executor(place)
        ret = exe.run(main_program, fetch_list=[out, a.grad_name, b.grad_name])
570
        # Note: fill_constant has loss of precision, so we assertAlmostEqual.
571 572 573 574
        self.assertAlmostEqual(ret[0][0], 1.5252)
        self.assertAlmostEqual(ret[1][0], 1.24)
        self.assertAlmostEqual(ret[2][0], 1.23)

575

576
class TestCondBackward(unittest.TestCase):
577
    def backward_value_helper(self, cond_func, use_cuda, use_parallel_exe):
578 579 580
        """
        Helper function that compares calculated backward value is close to dy/dx
        """
581
        paddle.enable_static()
582 583 584 585 586 587 588 589 590 591 592
        main_program = Program()
        main_program.random_seed = 123
        startup_program = Program()
        startup_program.random_seed = 123
        with program_guard(main_program, startup_program):
            img = fluid.data(name='image', shape=[-1, 9], dtype='float32')
            img.stop_gradient = False
            label = fluid.data(name='label', shape=[-1, 1], dtype='int64')
            i = fluid.data(name="i", shape=[1], dtype='int32')
            loss = cond_func(i, img, label)
            append_backward(loss)
593
        place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
594 595 596
        exe = fluid.Executor(place)
        exe.run(startup_program)

597 598 599
        num_devices = 1
        if use_parallel_exe:
            os.environ['CPU_NUM'] = str(2)
600 601 602 603 604
            exe = fluid.ParallelExecutor(
                use_cuda=use_cuda,
                main_program=main_program,
                loss_name=loss.name,
            )
605 606
            num_devices = exe.device_count

607 608 609
        delta = 0.005
        for feed_i in range(0, 10):
            feed_img = np.random.random(size=[1, 9]).astype(np.float32)
610 611 612
            feed_label = np.random.randint(
                low=0, high=10, size=[1, 1], dtype=np.int64
            )
613 614 615 616
            if use_parallel_exe:
                img_grad, loss_value = exe.run(
                    feed={
                        'i': np.full((num_devices), feed_i, np.int32),
617
                        'image': np.repeat(feed_img, num_devices, axis=0),
618
                        'label': np.repeat(feed_label, num_devices, axis=0),
619
                    },
620 621
                    fetch_list=[img.grad_name, loss.name],
                )
622 623 624 625 626 627
            else:
                img_grad, loss_value = exe.run(
                    main_program,
                    feed={
                        'i': np.full((1), feed_i, np.int32),
                        'image': feed_img,
628
                        'label': feed_label,
629
                    },
630 631
                    fetch_list=[img.grad_name, loss.name],
                )
632

633
            numerical_grad = np.zeros(shape=[num_devices, 9], dtype=np.float32)
634 635 636
            feed_img_delta = np.copy(feed_img)
            for j in range(9):
                feed_img_delta[0][j] = feed_img[0][j] + delta
637
                if use_parallel_exe:
638 639 640 641 642 643 644 645 646 647 648 649 650
                    loss_delta = exe.run(
                        feed={
                            'i': np.full((num_devices), feed_i, np.int32),
                            'image': np.repeat(
                                feed_img_delta, num_devices, axis=0
                            ),
                            'label': np.repeat(feed_label, num_devices, axis=0),
                        },
                        fetch_list=[loss.name],
                    )
                    multi_device_grad = (
                        (loss_delta[0] - loss_value[0]) / delta / num_devices
                    )
651 652 653
                    for d in range(num_devices):
                        numerical_grad[d][j] = multi_device_grad[d]
                else:
654 655 656 657 658 659 660 661 662 663 664 665
                    loss_delta = exe.run(
                        main_program,
                        feed={
                            'i': np.full((1), feed_i, np.int32),
                            'image': feed_img_delta,
                            'label': feed_label,
                        },
                        fetch_list=[loss.name],
                    )
                    numerical_grad[0][j] = (
                        loss_delta[0] - loss_value[0]
                    ) / delta
666
                feed_img_delta[0][j] = feed_img[0][j]
667 668 669
            np.testing.assert_allclose(
                img_grad, numerical_grad, rtol=0.05, atol=0.05
            )
670

671
    def add_optimizer_helper(self, cond_func, use_cuda, use_parallel_exe):
672 673 674 675 676 677 678 679 680 681 682 683 684
        """
        Test that program is runnable when add optimizer
        """
        main_program = Program()
        startup_program = Program()
        with program_guard(main_program, startup_program):
            img = fluid.data(name='image', shape=[-1, 784], dtype='float32')
            label = fluid.data(name='label', shape=[-1, 1], dtype='int64')
            i = fluid.data(name="i", shape=[1], dtype='int32')
            loss = cond_func(i, img, label)
            optimizer = fluid.optimizer.SGD(learning_rate=0.1)
            optimizer.minimize(loss)

685
        place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
686 687
        exe = fluid.Executor(place)
        exe.run(startup_program)
688 689
        if use_parallel_exe:
            os.environ['CPU_NUM'] = str(2)
690 691 692 693 694
            exe = fluid.ParallelExecutor(
                use_cuda=use_cuda,
                main_program=main_program,
                loss_name=loss.name,
            )
695
            num_devices = exe.device_count
696 697 698

        for feed_i in range(0, 10):
            feed_img = np.random.random(size=[16, 784]).astype(np.float32)
699 700 701
            feed_label = np.random.randint(
                low=0, high=10, size=[16, 1], dtype=np.int64
            )
702
            if use_parallel_exe:
703 704 705 706 707 708 709 710
                exe.run(
                    feed={
                        'i': np.full((num_devices), feed_i, np.int32),
                        'image': np.repeat(feed_img, num_devices, axis=0),
                        'label': np.repeat(feed_label, num_devices, axis=0),
                    },
                    fetch_list=[loss.name],
                )
711
            else:
712 713 714 715 716 717 718 719 720
                exe.run(
                    main_program,
                    feed={
                        'i': np.full((1), feed_i, np.int32),
                        'image': feed_img,
                        'label': feed_label,
                    },
                    fetch_list=[loss],
                )
721 722

    def test_cond_backward(self):
723

724 725
        paddle.enable_static()

726
        def cond_func(i, img, label):
727
            predicate = (i % 2) == 0
728
            return paddle.static.nn.cond(
729 730
                predicate,
                lambda: simple_fc_net_with_inputs(img, label, class_num=10),
731 732
                lambda: batchnorm_fc_with_inputs(img, label, class_num=10),
            )
733

734
        for use_parallel_exe in [False, True]:
735 736 737 738 739 740
            if use_parallel_exe and os.name == "nt":
                print(
                    "Skip use_parallel_exe=True in Windows because of flaky test when using PE under old Windows machine"
                )
                continue

741 742 743 744 745 746
            self.backward_value_helper(
                cond_func, core.is_compiled_with_cuda(), use_parallel_exe
            )
            self.add_optimizer_helper(
                cond_func, core.is_compiled_with_cuda(), use_parallel_exe
            )
747 748

    def test_half_nested_cond_backward(self):
749
        paddle.enable_static()
750

751
        def branch(i, img, label):
752
            return paddle.static.nn.cond(
753 754
                (i % 2) == 0,
                lambda: simple_fc_net_with_inputs(img, label, class_num=10),
755 756
                lambda: batchnorm_fc_with_inputs(img, label, class_num=10),
            )
757 758

        def cond_func_simple_net_at_true(i, img, label):
759
            return paddle.static.nn.cond(
760 761
                i < 5, lambda: branch(i, img, label), lambda: paddle.mean(img)
            )
762 763

        def cond_func_simple_net_at_false(i, img, label):
764
            return paddle.static.nn.cond(
765 766
                i < 5, lambda: paddle.mean(img), lambda: branch(i, img, label)
            )
767

768
        for use_parallel_exe in [False, True]:
769 770 771 772 773 774
            if use_parallel_exe and os.name == "nt":
                print(
                    "Skip use_parallel_exe=True in Windows because of flaky test when using PE under old Windows machine"
                )
                continue

775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794
            self.backward_value_helper(
                cond_func_simple_net_at_true,
                core.is_compiled_with_cuda(),
                use_parallel_exe,
            )
            self.add_optimizer_helper(
                cond_func_simple_net_at_true,
                core.is_compiled_with_cuda(),
                use_parallel_exe,
            )
            self.backward_value_helper(
                cond_func_simple_net_at_false,
                core.is_compiled_with_cuda(),
                use_parallel_exe,
            )
            self.add_optimizer_helper(
                cond_func_simple_net_at_false,
                core.is_compiled_with_cuda(),
                use_parallel_exe,
            )
795 796

    def test_nested_cond_backward(self):
797
        paddle.enable_static()
798

799 800
        def branch(i, img, label, mod_two):
            if mod_two:
801
                predicate = (i % 2) == 0
802
            else:
803
                predicate = (i % 2) != 0
804
            return paddle.static.nn.cond(
805 806
                predicate,
                lambda: simple_fc_net_with_inputs(img, label, class_num=10),
807 808
                lambda: batchnorm_fc_with_inputs(img, label, class_num=10),
            )
809 810

        def cond_func(i, img, label):
811
            return paddle.static.nn.cond(
812 813 814 815
                i < 5,
                lambda: branch(i, img, label, True),
                lambda: branch(i, img, label, False),
            )
816

817
        for use_parallel_exe in [False, True]:
818 819 820 821 822
            if use_parallel_exe and os.name == "nt":
                print(
                    "Skip use_parallel_exe=True in Windows because of flaky test when using PE under old Windows machine"
                )
                continue
823 824 825 826 827 828
            self.backward_value_helper(
                cond_func, core.is_compiled_with_cuda(), use_parallel_exe
            )
            self.add_optimizer_helper(
                cond_func, core.is_compiled_with_cuda(), use_parallel_exe
            )
829 830


831 832
class TestCondWithError(unittest.TestCase):
    def test_input_type_error(self):
833
        paddle.enable_static()
834 835 836 837 838 839 840 841 842
        main_program = framework.Program()
        startup_program = framework.Program()
        with framework.program_guard(main_program, startup_program):
            pred = fluid.data(name='y', shape=[1], dtype='bool')

            def func():
                return pred

            with self.assertRaises(TypeError):
843
                paddle.static.nn.cond(None, func, func)
844 845

            with self.assertRaises(TypeError):
846
                paddle.static.nn.cond(pred, func, set())
847 848

            with self.assertRaises(TypeError):
849
                paddle.static.nn.cond(pred, set(), func)
850 851

            with self.assertRaises(TypeError):
852
                paddle.static.nn.cond(pred, func, func, set())
853 854


855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893
class TestCondWithDict(unittest.TestCase):
    def test_input_with_dict(self):
        paddle.enable_static()
        main_program = framework.Program()
        startup_program = framework.Program()
        with framework.program_guard(main_program, startup_program):

            def true_func():
                return {
                    '1': paddle.full(shape=[3, 2], dtype='int32', fill_value=1),
                    '2': paddle.full(
                        shape=[2, 3], dtype='bool', fill_value=True
                    ),
                }

            def false_func():
                return {
                    '1': paddle.full(
                        shape=[3, 4], dtype='float32', fill_value=3
                    ),
                    '2': paddle.full(shape=[4, 5], dtype='int64', fill_value=2),
                }

            x = paddle.full(shape=[1], dtype='float32', fill_value=0.1)
            y = paddle.full(shape=[1], dtype='float32', fill_value=0.23)
            pred = paddle.less_than(x=x, y=y, name=None)
            ret = paddle.static.nn.cond(pred, true_func, false_func)
            self.assertEqual(
                ret['1'].shape,
                (3, -1),
                f"The shape is not correct, expects (3, -1) but gets {ret['1'].shape}.",
            )
            self.assertEqual(
                ret['2'].shape,
                (-1, -1),
                f"The shape is not correct, expects (-1, -1) but gets {ret['2'].shape}.",
            )


894 895
if __name__ == '__main__':
    unittest.main()