# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ..wrapped_decorator import signature_safe_contextmanager

from .layer_function_generator import templatedoc
from .tensor import assign, cast, fill_constant
from .. import core
from ..framework import (
    Program,
    Variable,
    Operator,
    _non_static_mode,
    static_only,
    _in_legacy_dygraph,
    in_dygraph_mode,
)
from ..layer_helper import LayerHelper, unique_name
from .utils import (
    assert_same_structure,
    map_structure,
    hold_mutable_vars,
    copy_mutable_vars,
    padding_to_same_structure,
    is_sequence,
    pack_sequence_as,
    flatten,
    to_sequence,
)
import numpy
import warnings
from functools import reduce, partial
from ..data_feeder import (
    convert_dtype,
    check_variable_and_dtype,
    check_type,
    check_dtype,
)
from ..backward import _infer_var_data_type_shape_
import paddle
from paddle import _C_ops, _legacy_C_ops

__all__ = [
    'Switch',
    'StaticRNN',
    'Print',
    'while_loop',
]


def select_output(input, outputs, mask):
    """
    **select_output**

    This API takes one input, multiple outputs, and an integer mask. It
    selects the output specified by the mask and copies the input to that
    output. It is useful in control flow.

    Args:
        input(Variable): The input variable.
        outputs(tuple|list): The output variables.
        mask(Variable): A tensor containing a single integer that selects
            which output the input is copied to.

    Returns:
        tuple|list: The output variables.
    """
    helper = LayerHelper('select_output', **locals())
    check_type(input, 'input', (Variable), 'select_output')
    check_variable_and_dtype(mask, 'mask', ['int32'], 'select_output')
    check_type(outputs, 'outputs', (list, tuple), 'select_output')

    helper.append_op(
        type='select_output',
        inputs={'X': input, 'Mask': mask},
        outputs={'Out': outputs},
    )
    return outputs
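# NOTE(illustrative): a hypothetical usage sketch for ``select_output`` above;
# this is an assumption added for clarity, not an official example. In a
# static-graph program the candidate output variables are created first, and
# ``mask`` picks which of them receives a copy of ``input``:
#
#     import paddle
#     paddle.enable_static()
#     helper = LayerHelper('select_output_demo')
#     x = paddle.full(shape=[2], fill_value=1.0, dtype='float32')
#     outs = [
#         helper.create_variable_for_type_inference('float32') for _ in range(2)
#     ]
#     mask = paddle.full(shape=[1], fill_value=0, dtype='int32')
#     select_output(x, outs, mask)  # at runtime, copies ``x`` into ``outs[0]``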


def _select_input_infer_shape(first_shape, second_shape):
    """
    This function infers the output shape with the following algorithm:
    1. if the ranks differ, warn and fall back to second_shape.
    2. compare the axes one by one:
        if a == b: we set the axis to a
        if a != b: we set the axis to -1
    For compatibility, in non-declarative (dygraph) mode, we just return second_shape.
    """
    if len(first_shape) != len(second_shape):
        warnings.warn(
            f"the input shapes of select_input should have the same rank, but get {first_shape}, {second_shape}"
        )
        return second_shape
    out_shape = list(
        map(lambda a, b: a if a == b else -1, first_shape, second_shape)
    )
    return out_shape
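# NOTE(illustrative): worked examples of the shape inference above (added for
# clarity; the behaviour follows directly from the code):
#
#     _select_input_infer_shape([3, -1, 4], [3, 5, 4])  # -> [3, -1, 4]
#     _select_input_infer_shape([3, 5, 4], [3, 5, 4])   # -> [3, 5, 4]
#     _select_input_infer_shape([3, 5], [3, 5, 4])      # -> warns, returns [3, 5, 4]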


def select_input(inputs, mask):
    """
    **select_input**

    This API takes multiple inputs and uses an integer mask to select one
    input to output. It is useful in control flow.

    Args:
        inputs(tuple|list): The input variables.
        mask(Variable): A tensor containing a single integer that selects
            which input to output.

    Returns:
        Variable: The selected input variable.
    """
    helper = LayerHelper('select_input', **locals())
    check_type(inputs, 'inputs', (list, tuple), 'select_input')
    check_variable_and_dtype(mask, 'mask', ['int32'], 'select_input')

    # select_input should expand the shape: if one dim is -1 and the other is a
    # valid number, -1 takes precedence; if the ranks differ, a warning is reported.
    # assert inputs[0].dtype == inputs[1].dtype, f"Expect the inputs should have the same dtype, but get {inputs[0].dtype} and {inputs[1].dtype}"

    output_shape = _select_input_infer_shape(inputs[0].shape, inputs[1].shape)
    output_dtype = inputs[1].dtype
    output_type = inputs[1].type

    out = helper.create_variable(
        dtype=output_dtype, shape=output_shape, type=output_type
    )
    helper.append_op(
        type='select_input',
        inputs={'X': inputs, 'Mask': mask},
        outputs={'Out': out},
    )
    return out
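# NOTE(illustrative): a hypothetical usage sketch for ``select_input`` above;
# an assumption added for clarity, not an official example. ``mask`` holds a
# single int32 value that indexes into ``inputs``:
#
#     import paddle
#     paddle.enable_static()
#     a = paddle.full(shape=[2], fill_value=1.0, dtype='float32')
#     b = paddle.full(shape=[2], fill_value=2.0, dtype='float32')
#     mask = paddle.full(shape=[1], fill_value=1, dtype='int32')
#     out = select_input([a, b], mask)  # at runtime, forwards ``b``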


@static_only
def Print(
    input,
    first_n=-1,
    message=None,
    summarize=20,
    print_tensor_name=True,
    print_tensor_type=True,
    print_tensor_shape=True,
    print_tensor_layout=True,
    print_tensor_lod=True,
    print_phase='both',
):
    '''
    :api_attr: Static Graph

    **Print operator**

    This creates a print op that will print when a tensor is accessed.

    Wraps the tensor passed in so that whenever the tensor is accessed,
    the message `message` is printed, along with the current value of the
    tensor.

    Args:
        input (Tensor): A Tensor to print.
        first_n (int, optional): Only log `first_n` number of times. Default: -1.
        message (str, optional): A string message to print as a prefix. Default: None.
        summarize (int, optional): Number of elements in the tensor to be printed. If
                its value is -1, then all elements in the tensor will be printed.
        print_tensor_name (bool, optional): Print the tensor name. Default: True.
        print_tensor_type (bool, optional): Print the tensor type. Default: True.
        print_tensor_shape (bool, optional): Print the tensor shape. Default: True.
        print_tensor_layout (bool, optional): Print the tensor layout. Default: True.
        print_tensor_lod (bool, optional): Print the tensor lod. Default: True.
        print_phase (str, optional): Which phase to display, including 'forward',
                'backward' and 'both'. Default: 'both'. If set to 'backward', only
                the gradients of the input tensor are printed; if set to 'both',
                both the input tensor itself and its gradients are printed.

    Returns:
        Tensor: Output tensor.

    NOTES:
        The input and output are two different Tensors. In the
        following computation, you should use the output Tensor rather than the
        input; otherwise, the print layer will not have a backward pass.

    Examples:
        .. code-block:: python

           import paddle

           paddle.enable_static()

           x = paddle.full(shape=[2, 3], fill_value=3, dtype='int64')
           out = paddle.static.Print(x, message="The content of input layer:")

           main_program = paddle.static.default_main_program()
           exe = paddle.static.Executor(place=paddle.CPUPlace())
           res = exe.run(main_program, fetch_list=[out])
           # Variable: fill_constant_1.tmp_0
           #   - message: The content of input layer:
           #   - lod: {}
           #   - place: CPUPlace
           #   - shape: [2, 3]
           #   - layout: NCHW
           #   - dtype: long
           #   - data: [3 3 3 3 3 3]
    '''
    check_variable_and_dtype(
        input,
        'input',
        ['float32', 'float64', 'int32', 'int64', 'bool'],
        'fluid.layers.Print',
    )

    helper = LayerHelper('print' + "_" + input.name, **locals())
    output = helper.create_variable_for_type_inference(input.dtype)
    helper.append_op(
        type='print',
        inputs={'In': input},
        outputs={'Out': output},
        attrs={
            'first_n': first_n,
            'summarize': summarize,
            'message': message or "",
            'print_tensor_name': print_tensor_name,
            'print_tensor_type': print_tensor_type,
            'print_tensor_shape': print_tensor_shape,
            'print_tensor_layout': print_tensor_layout,
            'print_tensor_lod': print_tensor_lod,
            'print_phase': print_phase.upper(),
        },
    )
    return output


# (TODO: Mine) There exists dependency. It will be removed later.
class BlockGuard:
    """
    BlockGuard class.

    BlockGuard class is used to create a sub-block in a program by
    using the Python `with` keyword.
    """

    def __init__(self, main_program):
        if not isinstance(main_program, Program):
            raise TypeError("BlockGuard takes a program")
        self.main_program = main_program

    def __enter__(self):
        self.main_program._create_block()

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.main_program._rollback()
        if exc_type is not None:
            return False  # re-raise exception
        return True
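# NOTE(illustrative): a minimal sketch of how ``BlockGuard`` is used (an
# assumption added for clarity). Entering the guard pushes a new sub-block of
# the program; leaving it rolls back to the parent block:
#
#     program = Program()
#     with BlockGuard(program):
#         pass  # ops appended here would go into the new sub-block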


# (TODO: Mine) There exists dependency. It will be removed later.
class BlockGuardWithCompletion(BlockGuard):
    """
    BlockGuardWithCompletion class.

    BlockGuardWithCompletion class is used to create an op with a block in a program.
    """

    def __init__(self, rnn):
        if not isinstance(rnn, StaticRNN):
            raise TypeError("BlockGuardWithCompletion takes a StaticRNN")
        super().__init__(rnn.helper.main_program)
        self.rnn = rnn

    def __enter__(self):
        self.rnn.status = StaticRNN.IN_RNN_BLOCK
        return super().__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type is not None:
            return False
        self.rnn.status = StaticRNN.AFTER_RNN_BLOCK
        self.rnn._complete_op()
        return super().__exit__(exc_type, exc_val, exc_tb)


class StaticRNNMemoryLink:
    """
    StaticRNNMemoryLink class.

    StaticRNNMemoryLink class is used to create a link between two
    memory cells of a StaticRNN.


    NOTE: This is an internal data structure of a very low-level API.
    Please use StaticRNN instead.

    Args:
        init(Variable): the initial variable for Memory.
        pre_mem(Variable): the memory variable in previous time step.
        mem(Variable): the memory variable in current time step.
    """

    def __init__(self, init, pre_mem, mem=None):
        self.init = init
        self.pre_mem = pre_mem
        self.mem = mem


class StaticRNN:
    """
    :api_attr: Static Graph

    StaticRNN class.

    The StaticRNN can process a batch of sequence data. The first dimension of the inputs
    represents the sequence length, and the length of each input sequence must be equal.
    StaticRNN unfolds the sequence into time steps, and the user needs to define how to
    process each time step inside the :code:`with` block.

    Args:
        name (str, optional): Please refer to :ref:`api_guide_Name`, Default None.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.fluid as fluid
            import paddle.fluid.layers as layers

            vocab_size, hidden_size=10000, 200
            paddle.enable_static()
            x = fluid.data(name="x", shape=[None, 1, 1], dtype='int64')
            # create word sequence
            x_emb = layers.embedding(
                input=x,
                size=[vocab_size, hidden_size],
                dtype='float32',
                is_sparse=False)
            # transform batch size to dim 1
            x_emb = paddle.transpose(x_emb, perm=[1, 0, 2])

            rnn = fluid.layers.StaticRNN()
            with rnn.step():
                # mark created x_emb as input, each step processes a word
                word = rnn.step_input(x_emb)
                # create prev memory parameter, batch size comes from word
                prev = rnn.memory(shape=[-1, hidden_size], batch_ref = word)
                hidden = fluid.layers.fc(input=[word, prev], size=hidden_size, act='relu')
                # use hidden to update prev
                rnn.update_memory(prev, hidden)
                # mark hidden as output
                rnn.step_output(hidden)
            # get StaticRNN final output
            result = rnn()

    """

    BEFORE_RNN_BLOCK = 0
    IN_RNN_BLOCK = 1
    AFTER_RNN_BLOCK = 2

    def __init__(self, name=None):
        check_type(name, "name", (str, type(None)), "fluid.layers.StaticRNN")
        self.helper = LayerHelper("static_rnn", name=name)
        self.memories = {}  # memory map, from pre_mem.name --> MemoryLink
        self.inputs = []  # input variable list in current block
        self.outputs = []  # output variable list in parent block
        self.status = StaticRNN.BEFORE_RNN_BLOCK  # status flag.
        # sequence length, since it is a static RNN, sequence length are fixed.
        self.seq_len = None

    def step(self):
        """
        Define operators for each step. ``step`` is used in a :code:`with` block; the OPs in the
        :code:`with` block will be executed sequence_len times (sequence_len is the length of the input).
        """
        return BlockGuardWithCompletion(self)

    def _assert_in_rnn_block_(self, method):
        if self.status != StaticRNN.IN_RNN_BLOCK:
            raise ValueError("You must invoke {0} in rnn block".format(method))

    def memory(
        self,
        init=None,
        shape=None,
        batch_ref=None,
        init_value=0.0,
        init_batch_dim_idx=0,
        ref_batch_dim_idx=1,
    ):
        """
        Create a memory variable for static rnn.
        If the :code:`init` is not None, :code:`memory` will be initialized by
        this Variable. If the :code:`init` is None, :code:`shape` and :code:`batch_ref`
        must be set, and this function will create a new variable with shape and batch_ref
        to initialize :code:`init` Variable.

        Args:
            init(Variable, optional): Tensor used to init memory. If it is not set,
                :code:`shape` and :code:`batch_ref` must be provided.
                Default: None.
            shape(list|tuple): When :code:`init` is None, use this arg to initialize memory shape.
                NOTE the shape does not contain batch_size. Default: None.
            batch_ref(Variable, optional): When :code:`init` is None, memory's batch size will
                be set as batch_ref's ref_batch_dim_idx value. Default: None.
            init_value(float, optional): When :code:`init` is None, used to init memory's value. Default: 0.0.
            init_batch_dim_idx(int, optional): the batch_size axis of the :code:`init` Variable. Default: 0.
            ref_batch_dim_idx(int, optional): the batch_size axis of the :code:`batch_ref` Variable. Default: 1.

        Returns:
            Variable: The memory variable.

        Examples 1:
            .. code-block:: python

                import paddle
                import paddle.fluid as fluid
                import paddle.fluid.layers as layers

                vocab_size, hidden_size=10000, 200
                paddle.enable_static()
                x = fluid.data(name="x", shape=[None, 1, 1], dtype='int64')
                # create word sequence
                x_emb = layers.embedding(
                        input=x,
                        size=[vocab_size, hidden_size],
                        dtype='float32',
                        is_sparse=False)
                # transform batch size to dim 1
                x_emb = paddle.transpose(x_emb, perm=[1, 0, 2])

                rnn = fluid.layers.StaticRNN()
                with rnn.step():
                        # mark created x_emb as input, each step processes a word
                        word = rnn.step_input(x_emb)
                        # create prev memory parameter, batch size comes from word
                        prev = rnn.memory(shape=[-1, hidden_size], batch_ref = word)
                        hidden = fluid.layers.fc(input=[word, prev], size=hidden_size, act='relu')
                        # use hidden to update prev
                        rnn.update_memory(prev, hidden)


        Examples 2:
            .. code-block:: python

                import paddle
                import paddle.fluid as fluid
                import paddle.fluid.layers as layers
                vocab_size, hidden_size=10000, 200
                paddle.enable_static()
                x = fluid.data(name="x", shape=[None, 1, 1], dtype='int64')
                # create word sequence
                x_emb = layers.embedding(
                        input=x,
                        size=[vocab_size, hidden_size],
                        dtype='float32',
                        is_sparse=False)
                # transform batch size to dim 1
                x_emb = paddle.transpose(x_emb, perm=[1, 0, 2])
                boot_memory = fluid.layers.data(name='boot', shape=[hidden_size], dtype='float32', lod_level=1)
                rnn = fluid.layers.StaticRNN()
                with rnn.step():
                        # mark created x_emb as input, each step processes a word
                        word = rnn.step_input(x_emb)
                        # init memory
                        prev = rnn.memory(init=boot_memory)
                        hidden = fluid.layers.fc(input=[word, prev], size=hidden_size, act='relu')
                        # update hidden with prev
                        rnn.update_memory(prev, hidden)

        """
        self._assert_in_rnn_block_('memory')
        check_type(
            init,
            "init",
            (Variable, type(None)),
            "fluid.layers.StaticRNN.memory",
        )
        check_type(
            shape,
            "shape",
            (list, tuple, type(None)),
            "fluid.layers.StaticRNN.memory",
        )
        check_type(
            batch_ref,
            "batch_ref",
            (Variable, type(None)),
            "fluid.layers.StaticRNN.memory",
        )
        if init is None:
            if shape is None or batch_ref is None:
                raise ValueError(
                    "if init is None, memory at least need shape and batch_ref"
                )
            parent_block = self._parent_block()
            var_name = unique_name.generate_with_ignorable_key(
                "@".join([self.helper.name, "memory_boot"])
            )
            boot_var = parent_block.create_var(
                name=var_name,
                shape=shape,
                dtype=batch_ref.dtype,
                persistable=False,
            )

            parent_block.append_op(
                type="fill_constant_batch_size_like",
                inputs={'Input': [batch_ref]},
                outputs={'Out': [boot_var]},
                attrs={
                    'value': init_value,
                    'shape': boot_var.shape,
                    'dtype': boot_var.dtype,
                    'input_dim_idx': ref_batch_dim_idx,
                    'output_dim_idx': init_batch_dim_idx,
                },
            )

            return self.memory(init=boot_var)
        else:
            pre_mem = self.helper.create_variable(
                name=unique_name.generate_with_ignorable_key(
                    "@".join([self.helper.name, "mem"])
                ),
                dtype=init.dtype,
                shape=init.shape,
            )
            self.memories[pre_mem.name] = StaticRNNMemoryLink(
                init=init, pre_mem=pre_mem
            )
            return pre_mem

    def step_input(self, x):
        """
        Mark a sequence as a StaticRNN input.

        Args:
            x(Variable): The input sequence, the shape of x
                should be [seq_len, ...].

        Returns:
            Variable: The current time step data in the input sequence.

        Examples:
            .. code-block:: python

                import paddle
                import paddle.fluid as fluid
                import paddle.fluid.layers as layers

                vocab_size, hidden_size=10000, 200
                paddle.enable_static()
                x = fluid.data(name="x", shape=[None, 1, 1], dtype='int64')
                # create word sequence
                x_emb = layers.embedding(
                        input=x,
                        size=[vocab_size, hidden_size],
                        dtype='float32',
                        is_sparse=False)
                # transform batch size to dim 1
                x_emb = paddle.transpose(x_emb, perm=[1, 0, 2])

                rnn = fluid.layers.StaticRNN()
                with rnn.step():
                        # mark created x_emb as input, each step processes a word
                        word = rnn.step_input(x_emb)
                        # create prev memory parameter, batch size comes from word
                        prev = rnn.memory(shape=[-1, hidden_size], batch_ref = word)
                        hidden = fluid.layers.fc(input=[word, prev], size=hidden_size, act='relu')
                        # use hidden to update prev
                        rnn.update_memory(prev, hidden)

        """
        self._assert_in_rnn_block_('step_input')
        check_type(x, "x", Variable, "fluid.layers.StaticRNN.step_input")
        if self.seq_len is None:
            self.seq_len = x.shape[0]
        elif x.shape[0] != -1 and self.seq_len != x.shape[0]:
            raise ValueError("Static RNN only takes fixed seq_len input")

        ipt = self.helper.create_variable(
            name=x.name, dtype=x.dtype, shape=list(x.shape[1:]), type=x.type
        )
        self.inputs.append(ipt)
        return ipt

    def step_output(self, o):
        """
        Mark a sequence as a StaticRNN output.

        Args:
            o(Variable): The output sequence.

        Returns:
            None.

        Examples:
            .. code-block:: python

                import paddle
                import paddle.fluid as fluid
                import paddle.fluid.layers as layers

                vocab_size, hidden_size=10000, 200
                paddle.enable_static()
                x = fluid.data(name="x", shape=[None, 1, 1], dtype='int64')
                # create word sequence
                x_emb = layers.embedding(
                        input=x,
                        size=[vocab_size, hidden_size],
                        dtype='float32',
                        is_sparse=False)
                # transform batch size to dim 1
                x_emb = paddle.transpose(x_emb, perm=[1, 0, 2])

                rnn = fluid.layers.StaticRNN()
                with rnn.step():
                        # mark created x_emb as input, each step processes a word
                        word = rnn.step_input(x_emb)
                        # create prev memory parameter, batch size comes from word
                        prev = rnn.memory(shape=[-1, hidden_size], batch_ref = word)
                        hidden = fluid.layers.fc(input=[word, prev], size=hidden_size, act='relu')
                        # use hidden to update prev
                        rnn.update_memory(prev, hidden)
                        rnn.step_output(hidden)

                result = rnn()

        """
        self._assert_in_rnn_block_('step_output')
        check_type(o, "o", Variable, "fluid.layers.StaticRNN.step_output")

        tmp_o = self.helper.create_variable_for_type_inference(dtype=o.dtype)
        self.helper.append_op(
            type='rnn_memory_helper',
            inputs={'X': [o]},
            outputs={'Out': tmp_o},
            attrs={'dtype': o.dtype},
        )

        out_var = self._parent_block().create_var(
            name=tmp_o.name,
            shape=[self.seq_len] + list(tmp_o.shape),
            dtype=tmp_o.dtype,
        )

        self.outputs.append(out_var)

    def output(self, *outputs):
        """
        Mark the StaticRNN output variables.

        Args:
            outputs: The output Tensors; multiple variables can be marked as outputs.

        Returns:
            None

        Examples:
            .. code-block:: python

                import paddle
                import paddle.fluid as fluid
                import paddle.fluid.layers as layers

                vocab_size, hidden_size=10000, 200
                paddle.enable_static()
                x = fluid.data(name="x", shape=[None, 1, 1], dtype='int64')
                # create word sequence
                x_emb = layers.embedding(
                        input=x,
                        size=[vocab_size, hidden_size],
                        dtype='float32',
                        is_sparse=False)
                # transform batch size to dim 1
                x_emb = paddle.transpose(x_emb, perm=[1, 0, 2])

                rnn = fluid.layers.StaticRNN()
                with rnn.step():
                        # mark created x_emb as input, each step processes a word
                        word = rnn.step_input(x_emb)
                        # create prev memory parameter, batch size comes from word
                        prev = rnn.memory(shape=[-1, hidden_size], batch_ref = word)
                        hidden = fluid.layers.fc(input=[word, prev], size=hidden_size, act='relu')
                        # use hidden to update prev
                        rnn.update_memory(prev, hidden)
                        # mark each step's hidden and word as output
                        rnn.output(hidden, word)

                result = rnn()
        """
        for each in outputs:
            self.step_output(each)

    def update_memory(self, mem, var):
        """
        Update the memory from :code:`mem` to :code:`var`.

        Args:
            mem(Variable): the memory variable.
            var(Variable): the plain variable generated in RNN block, used to update memory.
                           var and mem should have the same dims and data type.

        Returns:
            None

        """
        check_type(mem, "mem", Variable, "fluid.layers.StaticRNN.update_memory")
        check_type(var, "var", Variable, "fluid.layers.StaticRNN.update_memory")
        self.memories[mem.name].mem = var

    def _parent_block(self):
        prog = self.helper.main_program
        parent_idx = prog.current_block().parent_idx
        assert parent_idx >= 0
        parent_block = prog.block(parent_idx)
        return parent_block

    def __call__(self, *args, **kwargs):
        if self.status != StaticRNN.AFTER_RNN_BLOCK:
            raise ValueError("RNN output can only be retrieved after rnn block")
        if len(self.outputs) == 0:
            raise ValueError("RNN has no output")
        elif len(self.outputs) == 1:
            return self.outputs[0]
        else:
            return self.outputs

    def _complete_op(self):
        main_program = self.helper.main_program
        rnn_block = main_program.current_block()
        parent_block = self._parent_block()

        local_inputs = set()

        for op in rnn_block.ops:
            assert isinstance(op, Operator)
            for oname in op.output_names:
                for out_var_name in op.output(oname):
                    local_inputs.add(out_var_name)

        for var in self.inputs:
            local_inputs.add(var.name)
        for m in self.memories:
            local_inputs.add(m)

        # NOTE(zcd): the params have two categories of variables.
        #   - the variables that are the out of StaticRnn.
        #   - the variables that are the parameters of some layers, for example, conv2d.
        params = list()
        for op in rnn_block.ops:
            assert isinstance(op, Operator)
            for iname in op.input_names:
                for in_var_name in op.input(iname):
                    if in_var_name not in local_inputs:
                        params.append(in_var_name)

        parameters = [
            parent_block._find_var_recursive(name) for name in set(params)
        ]

        step_scope = parent_block.create_var(
            type=core.VarDesc.VarType.STEP_SCOPES
        )

        inlinks = [parent_block.var(i.name) for i in self.inputs]
        outlinks = self.outputs

        # NOTE(zcd): the states maybe empty in some case.
        boot_memories = []
        pre_memories = []
        memories = []
        for _, mem in self.memories.items():
            boot_memories.append(mem.init)
            pre_memories.append(mem.pre_mem.name)
            assert (
                mem.mem is not None
            ), "%s should be updated in every step." % (mem.init.name)
            mem_var = rnn_block.var(mem.mem.name)
            assert isinstance(mem_var, Variable)
            new_mem = self.helper.create_variable_for_type_inference(
                dtype=mem_var.dtype
            )
            rnn_block.append_op(
                type='rnn_memory_helper',
                inputs={'X': [mem_var]},
                outputs={'Out': [new_mem]},
                attrs={'dtype': mem_var.dtype},
            )

            memories.append(new_mem.name)

        parent_block.append_op(
            type='recurrent',
            inputs={
                'inputs': inlinks,
                'initial_states': boot_memories,
                'parameters': parameters,
            },
            outputs={'outputs': outlinks, 'step_scopes': [step_scope]},
            attrs={
                'has_states': len(pre_memories) > 0,
                'ex_states': pre_memories,
                'states': memories,
                'sub_block': rnn_block,
            },
        )


# (TODO: Mine) There exists dependency. It will be removed later.
class WhileGuard(BlockGuard):
    def __init__(self, while_op):
        if not isinstance(while_op, While):
            raise TypeError("WhileGuard takes a while op")
        super().__init__(while_op.helper.main_program)
        self.while_op = while_op

    def __enter__(self):
        self.while_op.status = While.IN_WHILE_BLOCK
        return super().__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type is not None:
            return False
        self.while_op.status = While.AFTER_WHILE_BLOCK
        self.while_op._complete()
        return super().__exit__(exc_type, exc_val, exc_tb)


# (TODO: Mine) There exists dependency. It will be removed later.
def get_inputs_outputs_in_block(
    current_block, inner_inputs, inner_outputs, helper
):
    """
    Find inputs and outputs in current control flow block.
    :param current_block: Current control flow block.
    :param inner_inputs: Input var name of ops in current block.
    :param inner_outputs: Output var name of ops in current block.
    :return: inner_inputs, inner_outputs
    """

    def is_ignore_vars(op, var_name):
        # NOTE(dev): There are some persistable var created in some non-standard API
        # such as "contrib.layers.shuffle_batch". It create a "Seed" used both in
        # Input and Output. This var shall not be considered as a loop_var in
        # control_flow.
        IGNORE_VAR_NAMES = {"shuffle_batch": ["shuffle_batch_seed"]}
        if op.type in IGNORE_VAR_NAMES:
            var_names = IGNORE_VAR_NAMES[op.type]
            for name in var_names:
                if name in var_name:
                    return True
        return False

    # Step1: update inner_inputs and inner_outputs
    # NOTE: Here we assume that all variables are inputs or outputs of Ops,
    # but some variables are created without appending a real op.
    # For example, in `arr = create_array(dtype)`, `arr` is not an output of an op.
    for op in current_block.ops:
        assert isinstance(op, Operator)
        for iname in op.input_names:
            for in_var_name in op.input(iname):
                if in_var_name not in inner_outputs and not is_ignore_vars(
                    op, in_var_name
                ):
                    inner_inputs.add(in_var_name)

        for oname in op.output_names:
            for out_var_name in op.output(oname):
                inner_outputs.add(out_var_name)

    # Step2: Remove LOD_TENSOR_ARRAY created in current control flow block.
    remove_inner_inputs = set()
    parent_block = helper.main_program.block(current_block.parent_idx)

    for in_var_name in inner_inputs:
        parent_block_var = parent_block._find_var_recursive(in_var_name)
        current_block_var = None
        if current_block.has_var(in_var_name):
            current_block_var = current_block.var(in_var_name)
        if (
            not parent_block_var
            and current_block_var
            and current_block_var.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY
        ):
            remove_inner_inputs.add(in_var_name)

    inner_inputs = inner_inputs - remove_inner_inputs

    return inner_inputs, inner_outputs
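# NOTE(illustrative): a worked example for ``get_inputs_outputs_in_block``
# (an assumption added for clarity). If the ops inside ``current_block`` are,
# conceptually,
#
#     b = a + 1
#     c = b * 2
#
# then ``a`` is produced outside the block, so the call returns
# ``inner_inputs == {'a'}`` and ``inner_outputs == {'b', 'c'}`` (merged with
# whatever names were passed in).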


# (TODO: Mine) There exists dependency. It will be removed later.
class While:
    """
    :api_attr: Static Graph

    while loop control flow. Repeat while body until cond is False.

    Note:
        A new OP :ref:`api_fluid_layers_while_loop` is highly recommended instead of ``While`` if the shape of parameter ``cond`` is [1].
        OP :ref:`api_fluid_layers_while_loop` is easier to use and is called with less code but does the same thing as ``While`` .

    Notice:
        Local variables created in ``While`` are similar to those created in a C++ while loop, and cannot be referenced externally.
        As a result, they cannot be obtained through ``fetch_list`` of ``Executor``. If you would like to access a variable
        outside of ``while`` , PaddlePaddle provides the ``assign`` API to copy local variables to external ones. Please refer to example
        code 2 or refer to `issue#22724 <https://github.com/PaddlePaddle/Paddle/issues/22724>`_.

    Args:
        cond(Variable): A Tensor whose data type is bool controlling whether to continue looping.
        is_test(bool, optional): A flag indicating whether execution is in test phase. Default value is False.
        name(str, optional): The default value is None.  Normally there is no need for user to set this property.  For more information, please refer to :ref:`api_guide_Name` .

    Examples 1:
          .. code-block:: python

            import paddle
            import paddle.fluid as fluid
            import numpy as np

            paddle.enable_static()
            i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0)           # loop counter

            loop_len = fluid.layers.fill_constant(shape=[1],dtype='int64', value=10)    # loop length

            cond = paddle.less_than(x=i, y=loop_len)
            while_op = fluid.layers.While(cond=cond)
            with while_op.block():
                i = paddle.increment(x=i, value=1)
                paddle.assign(paddle.less_than(x=i, y=loop_len), cond)

            exe = fluid.Executor(fluid.CPUPlace())
            exe.run(fluid.default_startup_program())

            res = exe.run(fluid.default_main_program(), feed={}, fetch_list=[i])
            print(res) # [array([10])]


    Examples 2:
          .. code-block:: python

            import paddle
            import paddle.fluid as fluid
            import numpy as np

            paddle.enable_static()
            i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0)
            loop_len = fluid.layers.fill_constant(shape=[1], dtype='int64', value=10)
            one = fluid.layers.fill_constant(shape=[1], dtype='float32', value=1)
            data = fluid.data(name='data', shape=[1], dtype='float32')
            sums = fluid.layers.fill_constant(shape=[1], dtype='float32', value=0)  # Define the variable to be obtained outside of While, whose name should be different from the variable inside the While to be obtained

            cond = paddle.less_than(x=i, y=loop_len)
            while_op = fluid.layers.While(cond=cond)
            with while_op.block():
                sums_tensor = paddle.add(x=data, y=data)
                fluid.layers.assign(sums_tensor, sums)  # Update the value of sums_tensor defined in While to the sums which is defined outside of While through layers.assign
                i = paddle.increment(x=i, value=1)
                data = paddle.add(x=data, y=one)
                paddle.assign(paddle.less_than(x=i, y=loop_len), cond)

            feed_data = np.ones(1).astype('float32')
            exe = fluid.Executor(fluid.CPUPlace())
            exe.run(fluid.default_startup_program())
            res = exe.run(fluid.default_main_program(), feed={'data': feed_data}, fetch_list=sums)
            print(res[0])  # [2.]    # Because the data in While does not update the value outside the While, the value of sums is [2.] after the loop
    """

    BEFORE_WHILE_BLOCK = 0
    IN_WHILE_BLOCK = 1
    AFTER_WHILE_BLOCK = 2

    def __init__(self, cond, is_test=False, name=None):
        self.helper = LayerHelper("while", name=name)
        self.status = While.BEFORE_WHILE_BLOCK
        check_variable_and_dtype(cond, 'cond', ['bool'], 'fluid.layers.While')
        if reduce(lambda a, b: a * b, cond.shape, 1) != 1:
            raise TypeError(
                "condition expected shape as [1], but given shape as {0}.".format(
                    list(cond.shape)
                )
            )
        self.cond_var = cond
        self.is_test = is_test

    def block(self):
        return WhileGuard(self)

    def _complete(self):
        main_program = self.helper.main_program
        while_block = main_program.current_block()
        parent_block = main_program.block(
            main_program.current_block().parent_idx
        )

        inner_outputs = {self.cond_var.name}
        x_name_list = set()
        x_name_list, inner_outputs = get_inputs_outputs_in_block(
            while_block, x_name_list, inner_outputs, self.helper
        )

        out_vars = []
        for inner_out_name in inner_outputs:
            inner_var = parent_block._find_var_recursive(inner_out_name)
            if inner_var:
                out_vars.append(inner_var)

        x_name_list |= set(map(lambda x: x.name, out_vars))
        # NOTE(dev): cond_var has been contained in Input('Condition'), so
        # we remove it from Input('X')
        x_name_list -= {self.cond_var.name}

        step_scope = parent_block.create_var(
            type=core.VarDesc.VarType.STEP_SCOPES
        )

        parent_block.append_op(
            type='while',
            inputs={
                'X': [
                    parent_block._var_recursive(x_name)
                    for x_name in x_name_list
                ],
                'Condition': [self.cond_var],
            },
            outputs={'Out': out_vars, 'StepScopes': [step_scope]},
            attrs={'sub_block': while_block, "is_test": self.is_test},
        )


support_ret_buildin_type = (bool, float, int)


# (TODO: Mine) There exists dependency. It will be removed later.
def assign_skip_lod_tensor_array(input, output):
    """
    Assign input to output, but skip the process of copying LoDTensorArray unless it's created in while_block.
    """

    def has_shape_diff(x_var, y_var):
        if len(x_var.shape) != len(y_var.shape):
            return True
        for x_dim, y_dim in zip(x_var.shape, y_var.shape):
            if x_dim != y_dim and -1 not in [x_dim, y_dim]:
                return True
        return False

    if not isinstance(input, (Variable, core.VarBase)):
        if isinstance(output, Variable) and isinstance(
            input, support_ret_buildin_type
        ):
            assign(input, output)
        else:
            output = input
        return

    if input.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY:
        main_program = input.block.program
        parent_block = main_program.block(
            main_program.current_block().parent_idx
        )
        if parent_block and not parent_block._find_var_recursive(input.name):
            assign(input, output)
    else:
        if (
            isinstance(output, Variable)
            and isinstance(input, Variable)
            and has_shape_diff(input, output)
        ):
            warnings.warn(
                "In dy2static mode, we attempt to assign a variable with shape {} into a variable with shape {}, which is not always right.".format(
                    input.shape, output.shape
                )
            )
        assign(input, output)
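# NOTE(illustrative): behaviour sketch of the ``has_shape_diff`` helper above
# (added for clarity; it follows directly from the code):
#
#     [3, -1] vs. [3, 5]  -> False  (-1 matches any concrete dim)
#     [3, 4]  vs. [3, 5]  -> True
#     [3]     vs. [3, 5]  -> True   (different ranks)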


# (TODO: Mine) There exists dependency (jit.dy2static.convert_operators). It will be removed later.
def while_loop(cond, body, loop_vars, is_test=False, name=None):
    """
    :api_attr: Static Graph

    while_loop is one of the control flows. Repeats while_loop `body` until `cond` returns False.

    Notice:
        Local variables defined in ``body`` cannot be obtained through ``fetch_list`` of ``Executor`` . Variables should
        be defined outside ``body`` and placed in ``loop_vars`` for looping; these variables can then be fetched by ``fetch_list`` .

    Args:
        cond(Callable): A callable returning a boolean tensor controlling whether to continue looping. And ``cond`` takes
            as many arguments as ``loop_vars`` .
        body(Callable): A callable returning a tuple or list of tensors or LoDTensorArrays of the same arity
            (length and structure) and types as ``loop_vars`` . And ``body`` takes as many arguments as ``loop_vars`` .
        loop_vars(list|tuple): A list or tuple of tensors or LoDTensorArrays that is passed to both ``cond`` and ``body`` .
        is_test(bool, optional): A flag indicating whether execution is in test phase. Default value is False.
        name(str, optional): Normally there is no need for users to set this property. For more information, please
            refer to :ref:`api_guide_Name`. Default is None.

    Returns:
        A list or tuple of Tensors or LoDTensorArrays returned by ``body`` .

    Examples:
        .. code-block:: python

            import paddle
            paddle.enable_static()

            def cond(i, ten):
                return i < ten

            def body(i, ten):
                i = i + 1
                return [i, ten]

            main_program = paddle.static.default_main_program()
            startup_program = paddle.static.default_startup_program()
            with paddle.static.program_guard(main_program, startup_program):
                i = paddle.full(shape=[1], fill_value=0, dtype='int64')     # loop counter
                ten = paddle.full(shape=[1], fill_value=10, dtype='int64')  # loop length
                i, ten = paddle.static.nn.while_loop(cond, body, [i, ten])

                exe = paddle.static.Executor(paddle.CPUPlace())
                res = exe.run(main_program, feed={}, fetch_list=[i])
                print(res) # [array([10])]
    """
    helper = LayerHelper('while_loop', **locals())

    if not callable(cond):
        raise TypeError("cond in while_loop should be callable")
    if not callable(body):
        raise TypeError("body in while_loop should be callable")
    check_type(loop_vars, 'loop_vars', (list, tuple), 'fluid.layers.while_loop')
    if len(loop_vars) == 0:
        raise ValueError("loop_vars in while_loop should not be empty")

    pre_cond = cond(*loop_vars)
    check_variable_and_dtype(
        pre_cond, 'var of cond returned', ['bool'], 'fluid.layers.while_loop'
    )
    if reduce(lambda a, b: a * b, pre_cond.shape, 1) != 1:
        raise TypeError(
            "the shape of the variable returned by cond should be [1],"
            "but given shape as {0}.".format(list(pre_cond.shape))
        )

    if _non_static_mode():
        now_cond = pre_cond.numpy()[0]
        while now_cond:
            output_vars = body(*loop_vars)
            if not isinstance(output_vars, (list, tuple)):
                output_vars = [output_vars]
            if len(output_vars) != len(loop_vars):
                raise ValueError(
                    "body in while_loop should return the same arity "
                    "(length and structure) and types as loop_vars"
                )
            now_cond = cond(*output_vars).numpy()[0]
            map_structure(assign_skip_lod_tensor_array, output_vars, loop_vars)
        return loop_vars

    while_loop_block = While(pre_cond, is_test, name)
    has_mutable_vars_in_loop = hold_mutable_vars(loop_vars)
    with while_loop_block.block():
        # If a variable with mutable type is included in loop_vars, like `dict/list`,
        # modifying it in the body function will cause origin variable to be modified
        # synchronously. This will raise an assignment error out of while block.
        # Here we make a copy of the mutable vars to avoid this problem.
        if has_mutable_vars_in_loop:
            new_loop_vars = copy_mutable_vars(loop_vars)
            output_vars = body(*new_loop_vars)
        else:
            output_vars = body(*loop_vars)
        if not isinstance(output_vars, (list, tuple)):
            output_vars = [output_vars]
        try:
            loop_vars = _deal_with_undefined_var(output_vars, loop_vars)
            assert_same_structure(output_vars, loop_vars, check_types=False)
        except ValueError as e:
            raise ValueError(
                "body in while_loop should return the same arity "
                "(length and structure) as loop_vars: {0}".format(e)
            )
        now_cond = cond(*output_vars)
        map_structure(assign_skip_lod_tensor_array, output_vars, loop_vars)
        assign(now_cond, pre_cond)
    return loop_vars


# (TODO: Mine) There exists dependency. It will be removed later.
def _deal_with_undefined_var(output_vars, loop_vars):
    """Deal with undefined var cases. We create undefined variables based on the results of body().
    In Dy2Static, we use undefined vars to represent the vars created in control flow. This function
    expands the loop_vars and replaces the original loop_vars.
    1. UndefinedVar = Variable      # create a variable
    2. UndefinedVar = None          # create an undefined var with RETURN_NO_VALUE_MAGIC_NUM
    3. UndefinedVar = List(int)     # create a list of variables
    4. UndefinedVar = value         # create a variable
    """
    from paddle.jit.dy2static.utils import (
        UndefinedVar,
        create_undefined_variable,
    )

    def create_var_like(o_var):
        if (
            isinstance(o_var, (Variable,) + support_ret_buildin_type)
            or o_var is None
        ):
            return create_undefined_variable()
        if is_sequence(o_var):
            """
            Create a complex container class inside the body of while, including Python list and python Dict
            """
            return map_structure(lambda x: create_undefined_variable(), o_var)

    if len(output_vars) != len(loop_vars):
        raise ValueError("The length of loop_vars should be the same.")

    results = []
    for o_var, l_var in zip(output_vars, loop_vars):
        if isinstance(l_var, UndefinedVar) or l_var is None:
            results.append(create_var_like(o_var))
        else:
            results.append(l_var)
    return results
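# NOTE(illustrative): a worked example for ``_deal_with_undefined_var`` (an
# assumption added for clarity). Given
#
#     output_vars = [some_tensor, 3]
#     loop_vars   = [UndefinedVar('x'), another_tensor]
#
# the first slot is re-created as an undefined variable placeholder (because
# the corresponding loop var is an ``UndefinedVar``), while the second slot
# keeps ``another_tensor`` unchanged.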


class ConditionalBlockGuard(BlockGuard):
    """
    ConditionalBlockGuard is derived from BlockGuard. It is dedicated to
    holding a ConditionalBlock and helping users enter and exit the
    ConditionalBlock via Python's 'with' keyword. However, ConditionalBlockGuard
    is generally an internal component of IfElse, users should not use it directly.
    """

    def __init__(self, block):
        check_type(block, "block", ConditionalBlock, "ConditionalBlockGuard")
        super().__init__(block.helper.main_program)
        self.block = block

    def __enter__(self):
        return super().__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.block.complete()
        return super().__exit__(exc_type, exc_val, exc_tb)


class ConditionalBlock:
    '''
    **ConditionalBlock**

    ConditionalBlock is an operator that binds a block to a specific condition;
    if the condition matches, the corresponding block will be executed.

    Args:
        inputs (Variable): bool conditions.
        is_scalar_condition (bool): whether the branch is controlled by a scalar.
        name(str): name of this ConditionalBlock.

    Examples:
        .. code-block:: python

             import paddle
             import paddle.fluid as fluid
             cond = paddle.less_than(x=label, y=limit)
             true_image, false_image = layers.split_lod_tensor(
                 input=image, mask=cond)
             true_cond = layers.ConditionalBlock([true_image])

             with true_cond.block():
                 ...
             with false_cond.block():
                 ...
    '''

    def __init__(self, inputs, is_scalar_condition=False, name=None):
        for each_input in inputs:
            check_type(each_input, "input", Variable, "ConditionalBlock")
        self.inputs = inputs
        self.is_scalar_condition = is_scalar_condition
        self.helper = LayerHelper('conditional_block', name=name)

    def block(self):
        return ConditionalBlockGuard(self)

    def complete(self):
        inside_block = self.helper.main_program.current_block()
        parent_block = self.helper.main_program.block(inside_block.parent_idx)

        intermediate = set()
        params = set()
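        # `params` collects the names of variables that are read inside the
        # sub-block but defined outside of it; `intermediate` collects the names
        # of variables produced inside the sub-block.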
        params, intermediate = get_inputs_outputs_in_block(
            inside_block, params, intermediate, helper=self.helper
        )

        # TODO(liym27): Here we assume that all params are in the recursive parent
        # block, but when minimize() is called inside control flow, some params
        # may be in the conditional grad block.
        param_list = [
            parent_block._var_recursive(each_name) for each_name in params
        ]

        out_list = []
        for inner_out_name in intermediate:
            inner_var = parent_block._find_var_recursive(inner_out_name)
            if inner_var:
                out_list.append(inner_var)

        step_scope = parent_block.create_var(
            type=core.VarDesc.VarType.STEP_SCOPES
        )
        conditional_block_op = parent_block.append_op(
            type='conditional_block',
            inputs={
                'Cond': self.inputs,
                'Input': param_list,
            },
            outputs={'Out': out_list, 'Scope': [step_scope]},
            attrs={
                'sub_block': inside_block,
                'is_scalar_condition': self.is_scalar_condition,
            },
        )

        if self.need_append_conditional_block_grad(inside_block):
            self.append_conditional_block_grad(
                parent_block, inside_block, conditional_block_op
            )

    def need_append_conditional_block_grad(self, inside_block):
        grad_sub_block_idx = inside_block.backward_block_idx
        inside_block_idx = inside_block.idx

        # If inside_block has a grad block and that grad block is not
        # inside_block itself, we will append conditional_block_grad.
        return (
            grad_sub_block_idx != -1 and grad_sub_block_idx != inside_block_idx
        )

    def append_conditional_block_grad(
        self, parent_block, inside_block, conditional_block_op
    ):
        '''
        Append op `conditional_block_grad` manually.
        When `optimizer.minimize/append_backward` is called inside Paddle control flow,
        grad ops are appended before op `conditional_block` itself is appended, so
        op `conditional_block_grad` cannot be appended by
        `optimizer.minimize/append_backward`. Therefore, after op `conditional_block`
        has been appended, `conditional_block_grad` is appended here manually.

        Args:
            parent_block (Block): The block that `conditional_block_op` belongs to.
            inside_block (Block): The sub block of `conditional_block_op`.
            conditional_block_op (Operator): The forward op conditional_block.
        '''

        grad_sub_block_idx = inside_block.backward_block_idx
        grad_sub_block = self.helper.main_program.block(grad_sub_block_idx)

        intermediate = set()
        params = set()

        for each_op in grad_sub_block.ops:
            assert isinstance(each_op, Operator)
            for iname in each_op.input_names:
                for in_var_name in each_op.input(iname):
                    if in_var_name not in intermediate:
                        params.add(in_var_name)

            for oname in each_op.output_names:
                for out_var_name in each_op.output(oname):
                    intermediate.add(out_var_name)

        param_list = []
        for inner_input_name in params:
            inner_var = parent_block._find_var_recursive(inner_input_name)
            if inner_var:
                param_list.append(inner_var.name)

        grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
            conditional_block_op.desc, set(), [grad_sub_block.desc]
        )

        # append op_desc in grad_op_descs to target_block
        op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
        backward = core.op_proto_and_checker_maker.OpRole.Backward
        new_op_desc = parent_block.desc.append_op()
        new_op_desc.copy_from(grad_op_desc[0])
        new_op_desc._set_attr(op_role_attr_name, backward)
        # set input and output manually
        new_op_desc.set_input('Input', param_list)
        new_op_desc.set_output(
            'Input@GRAD', [param + "@GRAD" for param in param_list]
        )

        new_vars = set()
        for grad_var_name in new_op_desc.output_arg_names():
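            # Create the gradient variable in the grad sub-block if it does not
            # exist there yet; the special empty variable name is skipped.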
            if (
                grad_sub_block.desc.has_var_recursive(grad_var_name.encode())
                or grad_var_name == core.empty_var_name()
            ):
                continue
            grad_sub_block.desc.var(grad_var_name.encode())
            new_vars.add(grad_var_name)
            if grad_var_name not in op_grad_to_var:
                continue

        # infer_shape and infer_type
        new_op_desc.infer_var_type(grad_sub_block.desc)
        new_op_desc.infer_shape(grad_sub_block.desc)

        for arg in new_op_desc.output_arg_names():
            if arg in new_vars:
                _infer_var_data_type_shape_(arg, grad_sub_block)

        self.helper.main_program._sync_with_cpp()


def _to_sequence_except_dict(x):
    """
    In this function, a dict is not treated as a sequence.
    """
    if isinstance(x, dict):
        return [x]
    return to_sequence(x)


def _is_sequence_except_dict(x):
    """
    In this function, a dict is not treated as a sequence.
    """
    if isinstance(x, dict):
        return False
    return is_sequence(x)
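# A sketch of the helpers above: a dict is kept as a single element rather than
# being flattened, so _to_sequence_except_dict({"a": 1}) returns [{"a": 1}] and
# _is_sequence_except_dict({"a": 1}) returns False, while non-dict inputs fall
# through to the generic to_sequence/is_sequence utilities.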


def expand_undefined_var(nest1, nest2, names):
    """TODO: make this function recursively.
    nest1: Var1, (UndefinedVar, [1,2,3])
    nest2: Var2, ([1,2,3,4], UndefinedVar)
    In this case, we should not expand recursively.
    """
    from paddle.jit.dy2static.utils import UndefinedVar
    from paddle.jit.dy2static.return_transformer import (
        RETURN_VALUE_PREFIX,
    )

    def pack_undefined_var_as(seq):
        return pack_sequence_as(
            seq, [UndefinedVar("padding") for i in flatten(seq)]
        )

    def map_fn(n1, n2, name, order):
        if not name.startswith(RETURN_VALUE_PREFIX) and (
            isinstance(n1, UndefinedVar) or n1 is None
        ):
            if n1 is None and n2 is not None:
                if order == 0:
                    warnings.warn(
                        "In cond: Var '{}' or part of it is set differently in the ifelse branches, "
                        "<{}, {}> in the true branch and <{}, {}> in the false branch. Setting the var to "
                        "'None' in an ifelse block might lead to an error.".format(
                            name, type(n1), n1, type(n2), n2
                        )
                    )
                else:
                    warnings.warn(
                        "In cond: Var '{}' or part of it is set differently in the ifelse branches, "
                        "<{}, {}> in the true branch and <{}, {}> in the false branch. Setting the var to "
                        "'None' in an ifelse block might lead to an error.".format(
                            name, type(n2), n2, type(n1), n1
                        )
                    )
            return pack_undefined_var_as(n2)
        return n1

    nest1_out = list(
        map(
            map_fn,
            _to_sequence_except_dict(nest1),
            _to_sequence_except_dict(nest2),
            _to_sequence_except_dict(names),
            [0 for i in _to_sequence_except_dict(names)],
        )
    )
    nest2_out = list(
        map(
            map_fn,
            _to_sequence_except_dict(nest2),
            _to_sequence_except_dict(nest1),
            _to_sequence_except_dict(names),
            [1 for i in _to_sequence_except_dict(names)],
        )
    )
    if not _is_sequence_except_dict(nest1):
        nest1_out = nest1_out[0]
    if not _is_sequence_except_dict(nest2):
        nest2_out = nest2_out[0]
    return nest1_out, nest2_out
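    # A sketch of the padding described in the docstring above (names are
    # hypothetical and none of them start with RETURN_VALUE_PREFIX): with
    #     nest1 = (var1, UndefinedVar)
    #     nest2 = (UndefinedVar, [var2, var3])
    # the second slot of nest1_out becomes [UndefinedVar, UndefinedVar] (shaped
    # like the matching slot of nest2), and the first slot of nest2_out becomes
    # a single UndefinedVar, so both outputs share the same nested structure.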


class Switch:
    """
    :api_attr: Static Graph

    This class is used to implement the Switch branch control flow.
    A Switch contains several case branches and one default branch.
    The Switch control flow checks the case branch conditions in turn and only
    executes the statements after the first case branch whose condition is satisfied.
    If no case branch condition is satisfied,
    only the statements following the default branch are executed.

    Note:
        A new OP :ref:`api_fluid_layers_case` is highly recommended instead of ``Switch`` if the shape of parameter ``cond`` is [1].
        OP :ref:`api_fluid_layers_case` is easier to use and is called with less code but does the same thing as ``Switch`` .

    Member Functions:
        case(condition): The case branch of Switch whose parameter cond is a scalar Variable of bool type. The statements after this case branch are executed only if its cond is True and the conds of all previous case branches are False; otherwise they are skipped.

        default(): The default branch of Switch. When the cond of every case branch is False, the statements after the default branch are executed.

    Case and default functions can only be used inside the scope of Switch, as shown below:

    .. code-block:: python

        '''
        with fluid.layers.Switch() as switch:
            with switch.case(cond1):
                i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=1)
            with switch.case(cond2):
                i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=2)
            with switch.default():
                i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0)
        '''

    Args:
        name(str, optional): The default value is None.  Normally there is no need for user to set this property.  For more information, please refer to :ref:`api_guide_Name` .

    Examples:
        .. code-block:: python

            import paddle
            import paddle.fluid as fluid

            lr = paddle.static.create_global_var(
                shape=[1],
                value=0.0,
                dtype='float32',
                persistable=True,
                name="learning_rate")
            zero_var = fluid.layers.fill_constant(
                shape=[1], dtype='float32', value=0.0)
            one_var = fluid.layers.fill_constant(
                shape=[1], dtype='float32', value=1.0)
            two_var = fluid.layers.fill_constant(
                shape=[1], dtype='float32', value=2.0)

            global_step = fluid.layers.autoincreased_step_counter(counter_name='@LR_DECAY_COUNTER@', begin=0, step=1)

            with fluid.layers.control_flow.Switch() as switch:
                with switch.case(global_step == zero_var):
                    fluid.layers.assign(input=one_var, output=lr)
                with switch.default():
                    fluid.layers.assign(input=two_var, output=lr)

            exe = fluid.Executor(fluid.CPUPlace())
            exe.run(fluid.default_startup_program())

            res = exe.run(fluid.default_main_program(), feed={}, fetch_list=[lr])
            print(res) # [array([1.], dtype=float32)]
    """

    def __init__(self, name=None):
        self.helper = LayerHelper('switch', name=name)
        self.inside_scope = False
        self.pre_not_conditions = []

    def case(self, condition):
        if not self.inside_scope:
            raise ValueError("case should be called inside with")

        check_variable_and_dtype(
            condition,
            'condition',
            ['bool'],
            'the member function case of fluid.layers.Switch',
        )

        if len(self.pre_not_conditions) == 0:
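            # This is the first case: guard its block with `condition` itself and
            # start accumulating the negated conditions, so that later cases (and
            # default) only run when every earlier condition was False.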
            cond_block = ConditionalBlock([condition], is_scalar_condition=True)
            not_cond = paddle.logical_not(x=condition)
            self.pre_not_conditions.append(not_cond)
        else:
            pre_cond_num = len(self.pre_not_conditions)
            pre_not_cond = self.pre_not_conditions[pre_cond_num - 1]
            new_not_cond = paddle.logical_and(
                x=pre_not_cond, y=paddle.logical_not(x=condition)
            )
            self.pre_not_conditions.append(new_not_cond)
            cond_block = ConditionalBlock(
                [paddle.logical_and(x=pre_not_cond, y=condition)],
                is_scalar_condition=True,
            )

        return ConditionalBlockGuard(cond_block)

    def default(self):
        pre_cond_num = len(self.pre_not_conditions)
        if pre_cond_num == 0:
            raise ValueError("there should be at least one condition")
        cond_block = ConditionalBlock(
            [self.pre_not_conditions[pre_cond_num - 1]],
            is_scalar_condition=True,
        )
        return ConditionalBlockGuard(cond_block)

    def __enter__(self):
        """
        Set the flag indicating that we are now inside the Switch block.
        :return:
        """
        self.inside_scope = True
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.inside_scope = False
        if exc_type is not None:
            return False  # re-raise exception

        return True