# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ..wrapped_decorator import signature_safe_contextmanager

from .layer_function_generator import templatedoc
from .. import core
from ..framework import (
    Program,
    Variable,
    Operator,
    static_only,
    in_dygraph_mode,
)
from ..layer_helper import LayerHelper, unique_name
from ...utils import (
    assert_same_structure,
    map_structure,
    hold_mutable_vars,
    copy_mutable_vars,
    is_sequence,
    pack_sequence_as,
    flatten,
    to_sequence,
)
import numpy
import warnings
from functools import reduce, partial
from ..data_feeder import (
    convert_dtype,
    check_variable_and_dtype,
    check_type,
    check_dtype,
)
from ..backward import _infer_var_data_type_shape_
import paddle
from paddle import _C_ops, _legacy_C_ops

__all__ = [
    'Switch',
    'StaticRNN',
    'Print',
    'while_loop',
]


def select_output(input, outputs, mask):
    """
    **select_output**
    This API takes one input, multiple outputs, and an integer mask. It
    selects the output specified by the mask and copies the input to that
    output. It is useful in control flow.

    Args:
        input(Variable): The input variable
        outputs(tuple|list): The output variables
        mask(Variable): A tensor containing a single integer that selects
            which output the input is copied to

    Returns:
        Variable: The output variables
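
    Examples:
        A minimal illustrative sketch (the names below, such as
        ``branch_out_0``, are arbitrary placeholders):

        .. code-block:: python

            import paddle
            from paddle.fluid.layers.control_flow import select_output

            paddle.enable_static()
            x = paddle.full(shape=[2], fill_value=1.0, dtype='float32')
            mask = paddle.full(shape=[1], fill_value=0, dtype='int32')
            block = paddle.static.default_main_program().current_block()
            outs = [
                block.create_var(name='branch_out_%d' % i,
                                 dtype='float32', shape=[2])
                for i in range(2)
            ]
            # copies x into outs[0] at runtime, because mask == 0
            select_output(x, outs, mask)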
    """
    helper = LayerHelper('select_output', **locals())
    check_type(input, 'input', (Variable), 'select_output')
    check_variable_and_dtype(mask, 'mask', ['int32'], 'select_output')
    check_type(outputs, 'outputs', (list, tuple), 'select_output')

    helper.append_op(
        type='select_output',
        inputs={'X': input, 'Mask': mask},
        outputs={'Out': outputs},
    )
    return outputs


def _select_input_infer_shape(first_shape, second_shape):
    """
    This function infers the output shape by the following algorithm:
    1. if the ranks differ, warn and return second_shape.
    2. compare the axes one by one:
        if a == b: set the axis to a
        if a != b: set the axis to -1
    For compatibility in non-declarative mode, we just return second_shape.
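
    Example (an illustrative sketch of the rule above)::

        _select_input_infer_shape([2, 3, -1], [2, 5, -1])  # -> [2, -1, -1]
        _select_input_infer_shape([2, 3], [2, 3, 4])       # warns, returns [2, 3, 4]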
    """
    if len(first_shape) != len(second_shape):
        warnings.warn(
            f"the input shapes of select_input should have the same rank, but get {first_shape}, {second_shape}"
        )
        return second_shape
    out_shape = list(
        map(lambda a, b: a if a == b else -1, first_shape, second_shape)
    )
    return out_shape


def select_input(inputs, mask):
    """
    **select_input**

    This API takes multiple inputs and uses an integer mask to select one
    input as the output. It is useful in control flow.

    Args:
        inputs(tuple|list): The input variables
        mask(Variable): A tensor containing a single integer that selects
            which input to output

    Returns:
        Variable: The selected input variable
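
    Examples:
        A minimal illustrative sketch (variable names are arbitrary):

        .. code-block:: python

            import paddle
            from paddle.fluid.layers.control_flow import select_input

            paddle.enable_static()
            x = paddle.full(shape=[2], fill_value=1.0, dtype='float32')
            y = paddle.full(shape=[2], fill_value=2.0, dtype='float32')
            mask = paddle.full(shape=[1], fill_value=1, dtype='int32')
            # copies y into the returned variable at runtime, because mask == 1
            out = select_input([x, y], mask)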
    """
    helper = LayerHelper('select_input', **locals())
    check_type(inputs, 'inputs', (list, tuple), 'select_input')
    check_variable_and_dtype(mask, 'mask', ['int32'], 'select_input')

    # select_input should expand the shape: if one dim is -1 and the other is a valid number, use -1 first; if the dims are different, an error will be reported directly
    # assert inputs[0].dtype == inputs[1].dtype, f"Expect the inputs should have the same dtype, but get {inputs[0].dtype} and {inputs[1].dtype}"

    output_shape = _select_input_infer_shape(inputs[0].shape, inputs[1].shape)
    output_dtype = inputs[1].dtype
    output_type = inputs[1].type

    out = helper.create_variable(
        dtype=output_dtype, shape=output_shape, type=output_type
    )
    helper.append_op(
        type='select_input',
        inputs={'X': inputs, 'Mask': mask},
        outputs={'Out': out},
    )
    return out


@static_only
def Print(
    input,
    first_n=-1,
    message=None,
    summarize=20,
    print_tensor_name=True,
    print_tensor_type=True,
    print_tensor_shape=True,
    print_tensor_layout=True,
    print_tensor_lod=True,
    print_phase='both',
):
    '''
    :api_attr: Static Graph

    **Print operator**

    This creates a print op that will print when a tensor is accessed.

    Wraps the tensor passed in so that whenever the tensor is accessed,
    the message `message` is printed, along with the current value of the
    tensor `t`.

    Args:
        input (Tensor): A Tensor to print.
        first_n (int, optional): Only log `first_n` number of times. Default: -1.
        message (str, optional): A string message to print as a prefix. Default: None.
        summarize (int, optional): Number of elements in the tensor to be printed. If
                its value is -1, then all elements in the tensor will be printed.
        print_tensor_name (bool, optional): Print the tensor name. Default: True.
        print_tensor_type (bool, optional): Print the tensor type. Default: True.
        print_tensor_shape (bool, optional): Print the tensor shape. Default: True.
        print_tensor_layout (bool, optional): Print the tensor layout. Default: True.
        print_tensor_lod (bool, optional): Print the tensor lod. Default: True.
        print_phase (str, optional): Which phase to display, including 'forward',
                'backward' and 'both'. Default: 'both'. If set to 'backward', will
                only print the gradients of input tensor; If set to 'both', will
                both print the input tensor itself and the gradients of input tensor.

    Returns:
        Tensor: Output tensor.

    NOTES:
        The input and output are two different Tensors. In the
        following process, you should use the output Tensor rather than the input;
        otherwise, the print layer will not have a backward pass.

    Examples:
        .. code-block:: python

           import paddle

           paddle.enable_static()

           x = paddle.full(shape=[2, 3], fill_value=3, dtype='int64')
           out = paddle.static.Print(x, message="The content of input layer:")

           main_program = paddle.static.default_main_program()
           exe = paddle.static.Executor(place=paddle.CPUPlace())
           res = exe.run(main_program, fetch_list=[out])
           # Variable: fill_constant_1.tmp_0
           #   - message: The content of input layer:
           #   - lod: {}
           #   - place: CPUPlace
           #   - shape: [2, 3]
           #   - layout: NCHW
           #   - dtype: long
           #   - data: [3 3 3 3 3 3]
    '''
    check_variable_and_dtype(
        input,
        'input',
        ['float32', 'float64', 'int32', 'int64', 'bool'],
        'fluid.layers.Print',
    )

    helper = LayerHelper('print' + "_" + input.name, **locals())
    output = helper.create_variable_for_type_inference(input.dtype)
    helper.append_op(
        type='print',
        inputs={'In': input},
        outputs={'Out': output},
        attrs={
            'first_n': first_n,
            'summarize': summarize,
            'message': message or "",
            'print_tensor_name': print_tensor_name,
            'print_tensor_type': print_tensor_type,
            'print_tensor_shape': print_tensor_shape,
            'print_tensor_layout': print_tensor_layout,
            'print_tensor_lod': print_tensor_lod,
            'print_phase': print_phase.upper(),
        },
    )
    return output


# (TODO: Mine) There exists dependency. It will be removed later.
class BlockGuard:
    """
    BlockGuard class.

    BlockGuard class is used to create a sub-block in a program by
    using the Python `with` keyword.
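
    A typical (illustrative) usage pattern, where ``main_program`` is any
    ``Program`` instance::

        with BlockGuard(main_program):
            # ops appended here go into the newly created sub-block
            ...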
    """

    def __init__(self, main_program):
        if not isinstance(main_program, Program):
            raise TypeError("BlockGuard takes a program")
        self.main_program = main_program

    def __enter__(self):
        self.main_program._create_block()

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.main_program._rollback()
        if exc_type is not None:
            return False  # re-raise exception
        return True


# (TODO: Mine) There exists dependency. It will be removed later.
class BlockGuardWithCompletion(BlockGuard):
    """
    BlockGuardWithCompletion class.

    BlockGuardWithCompletion class is used to create an op with a block in a program.
    """

    def __init__(self, rnn):
        if not isinstance(rnn, StaticRNN):
            raise TypeError("BlockGuardWithCompletion takes a StaticRNN")
        super().__init__(rnn.helper.main_program)
        self.rnn = rnn

    def __enter__(self):
        self.rnn.status = StaticRNN.IN_RNN_BLOCK
        return super().__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type is not None:
            return False
        self.rnn.status = StaticRNN.AFTER_RNN_BLOCK
        self.rnn._complete_op()
        return super().__exit__(exc_type, exc_val, exc_tb)


class StaticRNNMemoryLink:
    """
    StaticRNNMemoryLink class.

    StaticRNNMemoryLink class is used to create a link between two
    memory cells of a StaticRNN.


    NOTE: This is an internal data structure of a very low-level API.
    Please use StaticRNN instead.

    Args:
        init(Variable): the initial variable for Memory.
        pre_mem(Variable): the memory variable in previous time step.
        mem(Variable): the memory variable in current time step.
    """

    def __init__(self, init, pre_mem, mem=None):
        self.init = init
        self.pre_mem = pre_mem
        self.mem = mem


class StaticRNN:
    """
    :api_attr: Static Graph

    StaticRNN class.

    The StaticRNN can process a batch of sequence data. The first dimension of the inputs
    represents the sequence length, and the length of each input sequence must be equal.
    StaticRNN will unfold the sequence into time steps, and the user needs to define how to
    process each time step inside the :code:`with` block.

    Args:
        name (str, optional): Please refer to :ref:`api_guide_Name`, Default None.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.fluid as fluid
            import paddle.fluid.layers as layers

            vocab_size, hidden_size=10000, 200
            paddle.enable_static()
            x = paddle.static.data(name="x", shape=[None, 1, 1], dtype='int64')
            # create word sequence
            x_emb = layers.embedding(
                input=x,
                size=[vocab_size, hidden_size],
                dtype='float32',
                is_sparse=False)
            # transform batch size to dim 1
            x_emb = paddle.transpose(x_emb, perm=[1, 0, 2])

            rnn = fluid.layers.StaticRNN()
            with rnn.step():
                # mark created x_emb as input, each step process a word
                word = rnn.step_input(x_emb)
                # create prev memory parameter, batch size comes from word
                prev = rnn.memory(shape=[-1, hidden_size], batch_ref = word)
                hidden = paddle.static.nn.fc(x=[word, prev], size=hidden_size, activation='relu')
                # use hidden to update prev
                rnn.update_memory(prev, hidden)
                # mark hidden as output
                rnn.step_output(hidden)
            # get StaticRNN final output
            result = rnn()

    """

    BEFORE_RNN_BLOCK = 0
    IN_RNN_BLOCK = 1
    AFTER_RNN_BLOCK = 2

    def __init__(self, name=None):
        check_type(name, "name", (str, type(None)), "fluid.layers.StaticRNN")
        self.helper = LayerHelper("static_rnn", name=name)
        self.memories = {}  # memory map, from pre_mem.name --> MemoryLink
        self.inputs = []  # input variable list in current block
        self.outputs = []  # output variable list in parent block
        self.status = StaticRNN.BEFORE_RNN_BLOCK  # status flag.
        # sequence length; since it is a static RNN, the sequence length is fixed.
        self.seq_len = None

    def step(self):
        """
        Define operators in each step. step is used in a :code:`with` block, and the OPs in the :code:`with` block
        will be executed sequence_len times (sequence_len is the length of the input)
        """
        return BlockGuardWithCompletion(self)

    def _assert_in_rnn_block_(self, method):
        if self.status != StaticRNN.IN_RNN_BLOCK:
            raise ValueError("You must invoke {0} in rnn block".format(method))

    def memory(
        self,
        init=None,
        shape=None,
        batch_ref=None,
        init_value=0.0,
        init_batch_dim_idx=0,
        ref_batch_dim_idx=1,
    ):
        """
        Create a memory variable for static rnn.
        If the :code:`init` is not None, :code:`memory` will be initialized by
        this Variable. If the :code:`init` is None, :code:`shape` and :code:`batch_ref`
        must be set, and this function will create a new variable with shape and batch_ref
        to initialize :code:`init` Variable.

        Args:
            init(Variable, optional): Tensor used to init memory. If it is not set,
                :code:`shape` and :code:`batch_ref` must be provided.
                Default: None.
            shape(list|tuple): When :code:`init` is None use this arg to initialize memory shape.
                NOTE the shape does not contain batch_size. Default: None.
            batch_ref(Variable, optional): When :code:`init` is None, memory's batch size will
                be set as batch_ref's ref_batch_dim_idx value. Default: None.
            init_value(float, optional): When :code:`init` is None, used to init memory's value. Default: 0.0.
            init_batch_dim_idx(int, optional): the batch_size axis of the :code:`init` Variable. Default: 0.
            ref_batch_dim_idx(int, optional): the batch_size axis of the :code:`batch_ref` Variable. Default: 1.

        Returns:
            Variable: The memory variable.

        Examples 1:
            .. code-block:: python

                import paddle
                import paddle.fluid as fluid
                import paddle.fluid.layers as layers

                vocab_size, hidden_size=10000, 200
                paddle.enable_static()
                x = paddle.static.data(name="x", shape=[None, 1, 1], dtype='int64')
                # create word sequence
                x_emb = layers.embedding(
                        input=x,
                        size=[vocab_size, hidden_size],
                        dtype='float32',
                        is_sparse=False)
                # transform batch size to dim 1
                x_emb = paddle.transpose(x_emb, perm=[1, 0, 2])

                rnn = fluid.layers.StaticRNN()
                with rnn.step():
                        # mark created x_emb as input, each step process a word
                        word = rnn.step_input(x_emb)
                        # create prev memory parameter, batch size comes from word
                        prev = rnn.memory(shape=[-1, hidden_size], batch_ref = word)
                        hidden = paddle.static.nn.fc(x=[word, prev], size=hidden_size, activation='relu')
                        # use hidden to update prev
                        rnn.update_memory(prev, hidden)


        Examples 2:
            .. code-block:: python

                import paddle
                import paddle.fluid as fluid
                import paddle.fluid.layers as layers
                vocab_size, hidden_size=10000, 200
                paddle.enable_static()
                x = paddle.static.data(name="x", shape=[None, 1, 1], dtype='int64')
                # create word sequence
                x_emb = layers.embedding(
                        input=x,
                        size=[vocab_size, hidden_size],
                        dtype='float32',
                        is_sparse=False)
                # transform batch size to dim 1
                x_emb = paddle.transpose(x_emb, perm=[1, 0, 2])
                boot_memory = paddle.static.data(name='boot', shape=[-1, hidden_size], dtype='float32', lod_level=1)
                rnn = fluid.layers.StaticRNN()
                with rnn.step():
                        # mark created x_emb as input, each step process a word
                        word = rnn.step_input(x_emb)
                        # init memory
                        prev = rnn.memory(init=boot_memory)
                        hidden = paddle.static.nn.fc(x=[word, prev], size=hidden_size, activation='relu')
                        # update hidden with prev
                        rnn.update_memory(prev, hidden)

        """
        self._assert_in_rnn_block_('memory')
        check_type(
            init,
            "init",
            (Variable, type(None)),
            "fluid.layers.StaticRNN.memory",
        )
        check_type(
            shape,
            "shape",
            (list, tuple, type(None)),
            "fluid.layers.StaticRNN.memory",
        )
        check_type(
            batch_ref,
            "batch_ref",
            (Variable, type(None)),
            "fluid.layers.StaticRNN.memory",
        )
        if init is None:
            if shape is None or batch_ref is None:
                raise ValueError(
                    "if init is None, memory at least need shape and batch_ref"
                )
            parent_block = self._parent_block()
            var_name = unique_name.generate_with_ignorable_key(
                "@".join([self.helper.name, "memory_boot"])
            )
            boot_var = parent_block.create_var(
                name=var_name,
                shape=shape,
                dtype=batch_ref.dtype,
                persistable=False,
            )

            parent_block.append_op(
                type="fill_constant_batch_size_like",
                inputs={'Input': [batch_ref]},
                outputs={'Out': [boot_var]},
                attrs={
                    'value': init_value,
                    'shape': boot_var.shape,
                    'dtype': boot_var.dtype,
                    'input_dim_idx': ref_batch_dim_idx,
                    'output_dim_idx': init_batch_dim_idx,
                },
            )

            return self.memory(init=boot_var)
        else:
            pre_mem = self.helper.create_variable(
                name=unique_name.generate_with_ignorable_key(
                    "@".join([self.helper.name, "mem"])
                ),
                dtype=init.dtype,
                shape=init.shape,
            )
            self.memories[pre_mem.name] = StaticRNNMemoryLink(
                init=init, pre_mem=pre_mem
            )
            return pre_mem

    def step_input(self, x):
        """
        Mark a sequence as a StaticRNN input.

        Args:
            x(Variable): The input sequence, the shape of x
                should be [seq_len, ...].

        Returns:
            Variable: The current time step data in the input sequence.

        Examples:
            .. code-block:: python

                import paddle
                import paddle.fluid as fluid
                import paddle.fluid.layers as layers

                vocab_size, hidden_size=10000, 200
                paddle.enable_static()
                x = paddle.static.data(name="x", shape=[None, 1, 1], dtype='int64')
                # create word sequence
                x_emb = layers.embedding(
                        input=x,
                        size=[vocab_size, hidden_size],
                        dtype='float32',
                        is_sparse=False)
                # transform batch size to dim 1
                x_emb = paddle.transpose(x_emb, perm=[1, 0, 2])

                rnn = fluid.layers.StaticRNN()
                with rnn.step():
                        # mark created x_emb as input, each step process a word
                        word = rnn.step_input(x_emb)
                        # create prev memory parameter, batch size comes from word
                        prev = rnn.memory(shape=[-1, hidden_size], batch_ref = word)
                        hidden = paddle.static.nn.fc(x=[word, prev], size=hidden_size, activation='relu')
                        # use hidden to update prev
                        rnn.update_memory(prev, hidden)

        """
        self._assert_in_rnn_block_('step_input')
        check_type(x, "x", Variable, "fluid.layers.StaticRNN.step_input")
        if self.seq_len is None:
            self.seq_len = x.shape[0]
        elif x.shape[0] != -1 and self.seq_len != x.shape[0]:
            raise ValueError("Static RNN only take fix seq_len input")

        ipt = self.helper.create_variable(
            name=x.name, dtype=x.dtype, shape=list(x.shape[1:]), type=x.type
        )
        self.inputs.append(ipt)
        return ipt

    def step_output(self, o):
        """
        Mark a sequence as a StaticRNN output.

        Args:
            o(Variable): The output sequence.

        Returns:
            None.
604 605 606 607

        Examples:
            .. code-block:: python

                import paddle
                import paddle.fluid as fluid
                import paddle.fluid.layers as layers

                vocab_size, hidden_size=10000, 200
                paddle.enable_static()
                x = paddle.static.data(name="x", shape=[None, 1, 1], dtype='int64')
                # create word sequence
                x_emb = layers.embedding(
                        input=x,
                        size=[vocab_size, hidden_size],
                        dtype='float32',
                        is_sparse=False)
                # transform batch size to dim 1
                x_emb = paddle.transpose(x_emb, perm=[1, 0, 2])

                rnn = fluid.layers.StaticRNN()
                with rnn.step():
                        # mark created x_emb as input, each step process a word
                        word = rnn.step_input(x_emb)
                        # create prev memory parameter, batch size comes from word
                        prev = rnn.memory(shape=[-1, hidden_size], batch_ref = word)
                        hidden = paddle.static.nn.fc(x=[word, prev], size=hidden_size, activation='relu')
                        # use hidden to update prev
                        rnn.update_memory(prev, hidden)
                        rnn.step_output(hidden)

                result = rnn()

        """
        self._assert_in_rnn_block_('step_output')
        check_type(o, "o", Variable, "fluid.layers.StaticRNN.step_output")

        tmp_o = self.helper.create_variable_for_type_inference(dtype=o.dtype)
        self.helper.append_op(
            type='rnn_memory_helper',
            inputs={'X': [o]},
            outputs={'Out': tmp_o},
            attrs={'dtype': o.dtype},
        )

        out_var = self._parent_block().create_var(
            name=tmp_o.name,
            shape=[self.seq_len] + list(tmp_o.shape),
            dtype=tmp_o.dtype,
        )

        self.outputs.append(out_var)

    def output(self, *outputs):
        """
        Mark the StaticRNN output variables.

        Args:
            outputs: The output Tensors; multiple variables can be marked as outputs

        Returns:
            None

        Examples:
            .. code-block:: python

                import paddle
                import paddle.fluid as fluid
                import paddle.fluid.layers as layers

                vocab_size, hidden_size=10000, 200
                paddle.enable_static()
                x = paddle.static.data(name="x", shape=[None, 1, 1], dtype='int64')
                # create word sequence
                x_emb = layers.embedding(
                        input=x,
                        size=[vocab_size, hidden_size],
                        dtype='float32',
                        is_sparse=False)
                # transform batch size to dim 1
                x_emb = paddle.transpose(x_emb, perm=[1, 0, 2])

                rnn = fluid.layers.StaticRNN()
                with rnn.step():
                        # mark created x_emb as input, each step process a word
                        word = rnn.step_input(x_emb)
                        # create prev memory parameter, batch size comes from word
                        prev = rnn.memory(shape=[-1, hidden_size], batch_ref = word)
                        hidden = paddle.static.nn.fc(x=[word, prev], size=hidden_size, activation='relu')
                        # use hidden to update prev
                        rnn.update_memory(prev, hidden)
                        # mark each step's hidden and word as output
                        rnn.output(hidden, word)

                result = rnn()
        """
        for each in outputs:
            self.step_output(each)

    def update_memory(self, mem, var):
        """
        Update the memory from :code:`mem` to :code:`var`.

        Args:
            mem(Variable): the memory variable.
            var(Variable): the plain variable generated in RNN block, used to update memory.
                           var and mem should have the same dims and data type.

        Returns:
            None

        """
        check_type(mem, "mem", Variable, "fluid.layers.StaticRNN.update_memory")
        check_type(var, "var", Variable, "fluid.layers.StaticRNN.update_memory")
        self.memories[mem.name].mem = var

    def _parent_block(self):
        prog = self.helper.main_program
        parent_idx = prog.current_block().parent_idx
        assert parent_idx >= 0
        parent_block = prog.block(parent_idx)
        return parent_block

    def __call__(self, *args, **kwargs):
        if self.status != StaticRNN.AFTER_RNN_BLOCK:
            raise ValueError("RNN output can only be retrieved after rnn block")
        if len(self.outputs) == 0:
            raise ValueError("RNN has no output")
        elif len(self.outputs) == 1:
            return self.outputs[0]
        else:
            return self.outputs

    def _complete_op(self):
        main_program = self.helper.main_program
        rnn_block = main_program.current_block()
        parent_block = self._parent_block()

        local_inputs = set()

        for op in rnn_block.ops:
            assert isinstance(op, Operator)
            for oname in op.output_names:
                for out_var_name in op.output(oname):
                    local_inputs.add(out_var_name)

        for var in self.inputs:
            local_inputs.add(var.name)
        for m in self.memories:
            local_inputs.add(m)

        # NOTE(zcd): the params have two categories of variables.
        #   - the variables that are the out of StaticRnn.
        #   - the variables that are the parameters of some layers, for example, conv2d.
        params = list()
        for op in rnn_block.ops:
            assert isinstance(op, Operator)
            for iname in op.input_names:
                for in_var_name in op.input(iname):
                    if in_var_name not in local_inputs:
                        params.append(in_var_name)

        parameters = [
            parent_block._find_var_recursive(name) for name in set(params)
        ]

        step_scope = parent_block.create_var(
            type=core.VarDesc.VarType.STEP_SCOPES
        )

        inlinks = [parent_block.var(i.name) for i in self.inputs]
        outlinks = self.outputs

        # NOTE(zcd): the states maybe empty in some case.
        boot_memories = []
        pre_memories = []
        memories = []
        for _, mem in self.memories.items():
            boot_memories.append(mem.init)
            pre_memories.append(mem.pre_mem.name)
            assert (
                mem.mem is not None
            ), "%s should be updated in every step." % (mem.init.name)
            mem_var = rnn_block.var(mem.mem.name)
            assert isinstance(mem_var, Variable)
            new_mem = self.helper.create_variable_for_type_inference(
                dtype=mem_var.dtype
            )
            rnn_block.append_op(
                type='rnn_memory_helper',
                inputs={'X': [mem_var]},
                outputs={'Out': [new_mem]},
                attrs={'dtype': mem_var.dtype},
            )

            memories.append(new_mem.name)

        parent_block.append_op(
            type='recurrent',
            inputs={
                'inputs': inlinks,
                'initial_states': boot_memories,
                'parameters': parameters,
            },
            outputs={'outputs': outlinks, 'step_scopes': [step_scope]},
            attrs={
                'has_states': len(pre_memories) > 0,
                'ex_states': pre_memories,
                'states': memories,
                'sub_block': rnn_block,
            },
        )


# (TODO: Mine) There exists dependency. It will be removed later.
class WhileGuard(BlockGuard):
    def __init__(self, while_op):
        if not isinstance(while_op, While):
            raise TypeError("WhileGuard takes a while op")
        super().__init__(while_op.helper.main_program)
        self.while_op = while_op

    def __enter__(self):
        self.while_op.status = While.IN_WHILE_BLOCK
        return super().__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type is not None:
            return False
        self.while_op.status = While.AFTER_WHILE_BLOCK
        self.while_op._complete()
        return super().__exit__(exc_type, exc_val, exc_tb)


# (TODO: Mine) There exists dependency. It will be removed later.
def get_inputs_outputs_in_block(
    current_block, inner_inputs, inner_outputs, helper
):
    """
    Find inputs and outputs in current control flow block.
    :param current_block: Current control flow block.
    :param inner_inputs: Input var name of ops in current block.
    :param inner_outputs: Output var name of ops in current block.
    :return: inner_inputs, inner_outputs
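
    For illustration, ``While._complete`` calls this helper roughly as::

        x_name_list, inner_outputs = get_inputs_outputs_in_block(
            while_block, set(), {self.cond_var.name}, self.helper
        )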
    """

    def is_ignore_vars(op, var_name):
        # NOTE(dev): There are some persistable vars created in some non-standard APIs
        # such as "contrib.layers.shuffle_batch", which creates a "Seed" used both in
        # Input and Output. Such a var shall not be considered as a loop_var in
        # control_flow.
        IGNORE_VAR_NAMES = {"shuffle_batch": ["shuffle_batch_seed"]}
        if op.type in IGNORE_VAR_NAMES:
            var_names = IGNORE_VAR_NAMES[op.type]
            for name in var_names:
                if name in var_name:
                    return True
        return False

    # Step1: update inner_inputs and inner_outputs
    # NOTE: Here we assume that all variables are inputs or outputs of Ops,
    # but some variables are created without appending a real op.
    # For example, in `arr = create_array(dtype)`, `arr` is not an output of any op.
    for op in current_block.ops:
        assert isinstance(op, Operator)
        for iname in op.input_names:
            for in_var_name in op.input(iname):
                if in_var_name not in inner_outputs and not is_ignore_vars(
                    op, in_var_name
                ):
                    inner_inputs.add(in_var_name)

        for oname in op.output_names:
            for out_var_name in op.output(oname):
                inner_outputs.add(out_var_name)

    # Step2: Remove LOD_TENSOR_ARRAY created in current control flow block.
    remove_inner_inputs = set()
    parent_block = helper.main_program.block(current_block.parent_idx)

    for in_var_name in inner_inputs:
        parent_block_var = parent_block._find_var_recursive(in_var_name)
        current_block_var = None
        if current_block.has_var(in_var_name):
            current_block_var = current_block.var(in_var_name)
        if (
            not parent_block_var
            and current_block_var
            and current_block_var.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY
        ):
            remove_inner_inputs.add(in_var_name)

    inner_inputs = inner_inputs - remove_inner_inputs

    return inner_inputs, inner_outputs


# (TODO: Mine) There exists dependency. It will be removed later.
class While:
    """
    :api_attr: Static Graph

    While loop control flow. Repeats the while body until cond is False.

    Note:
        A new OP :ref:`api_fluid_layers_while_loop` is highly recommended instead of ``While`` if the shape of parameter ``cond`` is [1].
        OP :ref:`api_fluid_layers_while_loop` is easier to use and is called with less code but does the same thing as ``While`` .

    Notice:
        Local variables created in ``While`` are similar to those created in a C++ while loop, and cannot be referenced externally.
        As a result, they cannot be obtained through ``fetch_list`` of ``Executor``. If you would like to access a variable
        outside of ``while`` , PaddlePaddle provides the ``assign`` API to assign local variables to external ones. Please refer to example
        code 2 or refer to `issue#22724 <https://github.com/PaddlePaddle/Paddle/issues/22724>`_.

X
919
        cond(Variable): A Tensor whose data type is bool controlling whether to continue looping.
G
guofei 已提交
920
        is_test(bool, optional): A flag indicating whether execution is in test phase. Default value is False.
921
        name(str, optional): The default value is None.  Normally there is no need for user to set this property.  For more information, please refer to :ref:`api_guide_Name` .
X
923
    Examples 1:
X
925

926
            import paddle.fluid as fluid
927
            import paddle
928 929
            import numpy as np

930
            i = paddle.full(shape=[1], dtype='int64', fill_value=0)           # loop counter
931

932
            loop_len = paddle.full(shape=[1],dtype='int64', fill_value=10)    # loop length
933

L
935
            while_op = fluid.layers.While(cond=cond)
936
            with while_op.block():
937
                i = paddle.increment(x=i, value=1)
L
939 940 941 942 943

            exe = fluid.Executor(fluid.CPUPlace())
            exe.run(fluid.default_startup_program())

            res = exe.run(fluid.default_main_program(), feed={}, fetch_list=[i])
944 945 946 947 948 949
            print(res) # [array([10])]


    Examples 2:
          .. code-block:: python

            import paddle
            import paddle.fluid as fluid
            import numpy as np

            paddle.enable_static()
            i = paddle.full(shape=[1], dtype='int64', fill_value=0)
            loop_len = paddle.full(shape=[1], dtype='int64', fill_value=10)
            one = paddle.full(shape=[1], dtype='float32', fill_value=1)
            data = paddle.static.data(name='data', shape=[1], dtype='float32')
            sums = paddle.full(shape=[1], dtype='float32', fill_value=0)  # Define the variable to be fetched outside of While; its name should be different from that of the variable obtained inside the While

            cond = paddle.less_than(x=i, y=loop_len)
            while_op = fluid.layers.While(cond=cond)
            with while_op.block():
                sums_tensor = paddle.add(x=data, y=data)
                fluid.layers.assign(sums_tensor, sums)  # Update the value of sums_tensor defined inside While into sums, which is defined outside of While, through layers.assign
                i = paddle.increment(x=i, value=1)
                data = paddle.add(x=data, y=one)
                paddle.assign(paddle.less_than(x=i, y=loop_len), cond)

            feed_data = np.ones(1).astype('float32')
            exe = fluid.Executor(fluid.CPUPlace())
            exe.run(fluid.default_startup_program())
            res = exe.run(fluid.default_main_program(), feed={'data': feed_data}, fetch_list=sums)
            print(res[0])  # [2.]    # Because the data in While does not update the value outside the While, the value of sums is [2.] after the loop
    """

    BEFORE_WHILE_BLOCK = 0
    IN_WHILE_BLOCK = 1
    AFTER_WHILE_BLOCK = 2

    def __init__(self, cond, is_test=False, name=None):
        self.helper = LayerHelper("while", name=name)
        self.status = While.BEFORE_WHILE_BLOCK
        check_variable_and_dtype(cond, 'cond', ['bool'], 'fluid.layers.While')
        if reduce(lambda a, b: a * b, cond.shape, 1) != 1:
            raise TypeError(
                "condition expected shape as [1], but given shape as {0}.".format(
                    list(cond.shape)
                )
            )
        self.cond_var = cond
        self.is_test = is_test

    def block(self):
        return WhileGuard(self)

    def _complete(self):
        main_program = self.helper.main_program
        while_block = main_program.current_block()
        parent_block = main_program.block(
            main_program.current_block().parent_idx
        )

        inner_outputs = {self.cond_var.name}
        x_name_list = set()
        x_name_list, inner_outputs = get_inputs_outputs_in_block(
            while_block, x_name_list, inner_outputs, self.helper
        )

        out_vars = []
        for inner_out_name in inner_outputs:
            inner_var = parent_block._find_var_recursive(inner_out_name)
            if inner_var:
                out_vars.append(inner_var)

        x_name_list |= set(map(lambda x: x.name, out_vars))
        # NOTE(dev): cond_var has been contained in Input('Condition'), so
        # we remove it from Input('X')
        x_name_list -= {self.cond_var.name}

        step_scope = parent_block.create_var(
            type=core.VarDesc.VarType.STEP_SCOPES
        )

        parent_block.append_op(
            type='while',
            inputs={
                'X': [
                    parent_block._var_recursive(x_name)
                    for x_name in x_name_list
                ],
                'Condition': [self.cond_var],
            },
            outputs={'Out': out_vars, 'StepScopes': [step_scope]},
            attrs={'sub_block': while_block, "is_test": self.is_test},
        )


support_ret_buildin_type = (bool, float, int)


# (TODO: Mine) There exists dependency. It will be removed later.
def assign_skip_lod_tensor_array(input, output):
    """
    Assign input to output, but skip the process of copying LoDTensorArray unless it's created in while_block.
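
    Illustrative sketch (``x`` and ``out`` are hypothetical static-graph Variables)::

        assign_skip_lod_tensor_array(x, out)    # falls through to paddle.assign(x, out)
        assign_skip_lod_tensor_array(1.0, out)  # Python scalars are assigned via paddle.assign too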
    """

    def has_shape_diff(x_var, y_var):
        if len(x_var.shape) != len(y_var.shape):
            return True
        for x_dim, y_dim in zip(x_var.shape, y_var.shape):
            if x_dim != y_dim and -1 not in [x_dim, y_dim]:
                return True
        return False

    if not isinstance(input, (Variable, core.eager.Tensor)):
        if isinstance(output, Variable) and isinstance(
            input, support_ret_buildin_type
        ):
            paddle.assign(input, output)
        else:
            output = input
        return

    if input.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY:
        main_program = input.block.program
        parent_block = main_program.block(
            main_program.current_block().parent_idx
        )
        if parent_block and not parent_block._find_var_recursive(input.name):
            paddle.assign(input, output)
    else:
        if (
            isinstance(output, Variable)
            and isinstance(input, Variable)
            and has_shape_diff(input, output)
        ):
            warnings.warn(
                "In dy2static mode, we attemp to assign a variable with shape {} into a variable with shape{}, which is not always right.".format(
                    input.shape, output.shape
                )
            )
        paddle.assign(input, output)


# (TODO: Mine) There exists dependency (jit.dy2static.convert_operators). It will be removed later.
def while_loop(cond, body, loop_vars, is_test=False, name=None):
    """
    :api_attr: Static Graph

    while_loop is one of the control flows. Repeats while_loop `body` until `cond` returns False.

    Notice:
        Local variables defined in ``body`` cannot be obtained through ``fetch_list`` of ``Executor`` , variables should
        be defined outside ``body`` and placed in ``loop_vars`` for looping, then these variables can be fetched by ``fetch_list`` .

    Args:
        cond(Callable): A callable returning a boolean tensor controlling whether to continue looping. And ``cond`` takes
            as many arguments as ``loop_vars`` .
        body(Callable): A callable returning a tuple or list of tensors or LoDTensorArrays of the same arity
            (length and structure) and types as ``loops_vars`` . And ``body`` takes as many arguments as ``loop_vars`` .
        loop_vars(list|tuple): A list or tuple of tensors or LoDTensorArrays that is passed to both ``cond`` and ``body`` .
        is_test(bool, optional): A flag indicating whether execution is in test phase. Default value is False.
        name(str, optional): Normally there is no need for users to set this property. For more information, please
            refer to :ref:`api_guide_Name`. Default is None.

    Returns:
        A list or tuple of Tensors or LoDTensorArrays which returned by ``body`` .

    Examples:
        .. code-block:: python

            import paddle
            paddle.enable_static()

            def cond(i, ten):
                return i < ten

            def body(i, ten):
                i = i + 1
                return [i, ten]

            main_program = paddle.static.default_main_program()
            startup_program = paddle.static.default_startup_program()
            with paddle.static.program_guard(main_program, startup_program):
                i = paddle.full(shape=[1], fill_value=0, dtype='int64')     # loop counter
                ten = paddle.full(shape=[1], fill_value=10, dtype='int64')  # loop length
                i, ten = paddle.static.nn.while_loop(cond, body, [i, ten])

                exe = paddle.static.Executor(paddle.CPUPlace())
                res = exe.run(main_program, feed={}, fetch_list=[i])
                print(res) # [array([10])]
    """
    helper = LayerHelper('while_loop', **locals())

    if not callable(cond):
        raise TypeError("cond in while_loop should be callable")
    if not callable(body):
        raise TypeError("body in while_loop should be callable")
    check_type(loop_vars, 'loop_vars', (list, tuple), 'fluid.layers.while_loop')
    if len(loop_vars) == 0:
        raise ValueError("loop_vars in while_loop should not be empty")

    pre_cond = cond(*loop_vars)

    if reduce(lambda a, b: a * b, pre_cond.shape, 1) != 1:
        raise TypeError(
            "the shape of the variable returned by cond should be [1],"
            "but given shape as {0}.".format(list(pre_cond.shape))
        )

    if in_dygraph_mode():
        now_cond = pre_cond.item()
        while now_cond:
            output_vars = body(*loop_vars)
            if not isinstance(output_vars, (list, tuple)):
                output_vars = [output_vars]
            if len(output_vars) != len(loop_vars):
                raise ValueError(
                    "body in while_loop should return the same arity "
1161 1162
                    "(length and structure) and types as loop_vars"
                )
1163
            now_cond = cond(*output_vars).item()
1164
            map_structure(assign_skip_lod_tensor_array, output_vars, loop_vars)
1165
        return loop_vars
姜永久 已提交
1166
    else:
1167 1168 1169 1170 1171 1172
        check_variable_and_dtype(
            pre_cond,
            'var of cond returned',
            ['bool'],
            'fluid.layers.while_loop',
        )
姜永久 已提交
1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196
        while_loop_block = While(pre_cond, is_test, name)
        has_mutable_vars_in_loop = hold_mutable_vars(loop_vars)
        with while_loop_block.block():
            # If a variable with mutable type is included in loop_vars, like `dict/list`,
            # modifying it in the body function will cause origin variable to be modified
            # synchronously. This will raise an assignment error out of while block.
            # Here we make a copy of the mutable vars to avoid this problem.
            if has_mutable_vars_in_loop:
                new_loop_vars = copy_mutable_vars(loop_vars)
                output_vars = body(*new_loop_vars)
            else:
                output_vars = body(*loop_vars)
            if not isinstance(output_vars, (list, tuple)):
                output_vars = [output_vars]
            try:
                loop_vars = _deal_with_undefined_var(output_vars, loop_vars)
                assert_same_structure(output_vars, loop_vars, check_types=False)
            except ValueError as e:
                raise ValueError(
                    "body in while_loop should return the same arity "
                    "(length and structure) as loop_vars: {0}".format(e)
                )
            now_cond = cond(*output_vars)
            map_structure(assign_skip_lod_tensor_array, output_vars, loop_vars)
1197
            paddle.assign(now_cond, pre_cond)
姜永久 已提交
1198
        return loop_vars
G
guofei 已提交
1199 1200


1201
# (TODO: Mine) There exists dependency. It will be removed later.
1202
def _deal_with_undefined_var(output_vars, loop_vars):
1203 1204 1205 1206 1207 1208 1209
    """Deal with undefined var cases, We create undefined variable based on the results of body().
    In Dy2Static, we use undefined var to represent the var created in control flow. This function
    expand the loop_vars and replace original loop_vars.
    1. UndefinedVar = Variable      # create a variable
    2. UndefinedVar = None          # create a undefined var with RETURN_NO_VALUE_MAGIC_NUM
    3. UndefinedVar = List(int)     # create a list of variable
    4. UndefinedVar = value         # create a variable
1210
    """
1211
    from paddle.jit.dy2static.utils import (
1212 1213 1214
        UndefinedVar,
        create_undefined_variable,
    )
1215 1216

    def create_var_like(o_var):
1217 1218 1219 1220
        if (
            isinstance(o_var, (Variable,) + support_ret_buildin_type)
            or o_var is None
        ):
1221
            return create_undefined_variable()
1222
        if is_sequence(o_var):
1223
            """
1224 1225 1226
            Create a complex container class inside the body of while, including Python list and python Dict
            """
            return map_structure(lambda x: create_undefined_variable(), o_var)
1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239

    if len(output_vars) != len(loop_vars):
        raise ValueError("The length of loop_vars should be the same.")

    results = []
    for o_var, l_var in zip(output_vars, loop_vars):
        if isinstance(l_var, UndefinedVar) or l_var is None:
            results.append(create_var_like(o_var))
        else:
            results.append(l_var)
    return results


class ConditionalBlockGuard(BlockGuard):
    """
    ConditionalBlockGuard is derived from BlockGuard. It is dedicated to
    holding a ConditionalBlock, and to helping users enter and exit the
    ConditionalBlock via Python's 'with' keyword. However, ConditionalBlockGuard
    is generally an internal component of IfElse, and users should not use it directly.
    """

    def __init__(self, block):
        check_type(block, "block", ConditionalBlock, "ConditionalBlockGuard")
        super().__init__(block.helper.main_program)
        self.block = block

    def __enter__(self):
        return super().__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.block.complete()
        return super().__exit__(exc_type, exc_val, exc_tb)


class ConditionalBlock:
    '''
    **ConditionalBlock**

    ConditionalBlock is an operator that binds a block to a specific condition;
    if the condition matches, the corresponding block will be executed.

    Args:
        inputs (Variable): bool conditions.
        is_scalar_condition (bool): whether the branch is controlled by a scalar.
        name(str): name of this ConditionalBlock.

    Examples:
        .. code-block:: python

             import paddle
             import paddle.fluid as fluid
             cond = paddle.less_than(x=label, y=limit)
             true_image, false_image = layers.split_lod_tensor(
                 input=image, mask=cond)
             true_cond = layers.ConditionalBlock([true_image])

             with true_cond.block():
                 ...
             with false_cond.block():
                 ...
    '''

    def __init__(self, inputs, is_scalar_condition=False, name=None):
        for each_input in inputs:
            check_type(each_input, "input", Variable, "ConditionalBlock")
        self.inputs = inputs
        self.is_scalar_condition = is_scalar_condition
        self.helper = LayerHelper('conditional_block', name=name)

    def block(self):
        return ConditionalBlockGuard(self)

    def complete(self):
        inside_block = self.helper.main_program.current_block()
        parent_block = self.helper.main_program.block(inside_block.parent_idx)

        intermediate = set()
        params = set()
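        # Collect the names the sub-block reads but does not define itself
        # (`params`) and the names it produces (`intermediate`).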
        params, intermediate = get_inputs_outputs_in_block(
            inside_block, params, intermediate, helper=self.helper
        )

        # TODO(liym27): Here we assume that all params live in a recursive parent
        # block, but when minimize() is called inside control flow, some params
        # may be in the conditional grad block instead.
        param_list = [
            parent_block._var_recursive(each_name) for each_name in params
        ]

        out_list = []
        for inner_out_name in intermediate:
            inner_var = parent_block._find_var_recursive(inner_out_name)
            if inner_var:
                out_list.append(inner_var)

        step_scope = parent_block.create_var(
            type=core.VarDesc.VarType.STEP_SCOPES
        )
        conditional_block_op = parent_block.append_op(
            type='conditional_block',
            inputs={
                'Cond': self.inputs,
                'Input': param_list,
            },
            outputs={'Out': out_list, 'Scope': [step_scope]},
            attrs={
                'sub_block': inside_block,
                'is_scalar_condition': self.is_scalar_condition,
            },
        )
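        # `Out` exposes the variables produced inside the sub-block to the parent
        # block, while `Scope` keeps the step scope the sub-block runs in so that
        # the backward pass can look up the forward values.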

        if self.need_append_conditional_block_grad(inside_block):
            self.append_conditional_block_grad(
                parent_block, inside_block, conditional_block_op
            )

    def need_append_conditional_block_grad(self, inside_block):
        grad_sub_block_idx = inside_block.backward_block_idx
        inside_block_idx = inside_block.idx

        # If inside_block has a grad block and that grad block is not inside_block
        # itself, append conditional_block_grad.
        return (
            grad_sub_block_idx != -1 and grad_sub_block_idx != inside_block_idx
        )

    def append_conditional_block_grad(
        self, parent_block, inside_block, conditional_block_op
    ):
        '''
        Append op `conditional_block_grad` manually.
        When `optimizer.minimize/append_backward` is called inside Paddle control
        flow, the grad ops are appended before the op `conditional_block` itself,
        so `conditional_block_grad` cannot be created by that call. Therefore,
        after the op `conditional_block` has been appended, `conditional_block_grad`
        is appended manually.

        Args:
            parent_block (Block): The block that `conditional_block_op` belongs to.
            inside_block (Block): The sub block of `conditional_block_op`.
            conditional_block_op (Operator): The forward op conditional_block.
        '''

        grad_sub_block_idx = inside_block.backward_block_idx
        grad_sub_block = self.helper.main_program.block(grad_sub_block_idx)

        intermediate = set()
        params = set()
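        # `params` will collect the names the grad sub-block reads before they
        # are produced inside it (i.e. values coming from outside the block);
        # `intermediate` collects the names produced inside the sub-block.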

        for each_op in grad_sub_block.ops:
            assert isinstance(each_op, Operator)
            for iname in each_op.input_names:
                for in_var_name in each_op.input(iname):
                    if in_var_name not in intermediate:
                        params.add(in_var_name)

            for oname in each_op.output_names:
                for out_var_name in each_op.output(oname):
                    intermediate.add(out_var_name)

        param_list = []
        for inner_input_name in params:
            inner_var = parent_block._find_var_recursive(inner_input_name)
            if inner_var:
                param_list.append(inner_var.name)

        grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
            conditional_block_op.desc, set(), [grad_sub_block.desc]
        )

        # append op_desc in grad_op_descs to target_block
        op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
        backward = core.op_proto_and_checker_maker.OpRole.Backward
        new_op_desc = parent_block.desc.append_op()
        new_op_desc.copy_from(grad_op_desc[0])
        new_op_desc._set_attr(op_role_attr_name, backward)
        # set input and output manually
        new_op_desc.set_input('Input', param_list)
        new_op_desc.set_output(
            'Input@GRAD', [param + "@GRAD" for param in param_list]
        )
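        # Each forward input X listed in `Input` receives its gradient through
        # the corresponding "X@GRAD" output of the appended grad op.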

        new_vars = set()
        for grad_var_name in new_op_desc.output_arg_names():
            if (
                grad_sub_block.desc.has_var_recursive(grad_var_name.encode())
                or grad_var_name == core.empty_var_name()
            ):
                continue
            grad_sub_block.desc.var(grad_var_name.encode())
            new_vars.add(grad_var_name)
            if grad_var_name not in op_grad_to_var:
                continue

        # infer_shape and infer_type
        new_op_desc.infer_var_type(grad_sub_block.desc)
        new_op_desc.infer_shape(grad_sub_block.desc)

        for arg in new_op_desc.output_arg_names():
            if arg in new_vars:
                _infer_var_data_type_shape_(arg, grad_sub_block)

        self.helper.main_program._sync_with_cpp()


def _to_sequence_except_dict(x):
    """
    In this function, a dict is not treated as a sequence.
    """
    if isinstance(x, dict):
        return [x]
    return to_sequence(x)


def _is_sequence_except_dict(x):
    """
    In this function, a dict is not treated as a sequence.
    """
    if isinstance(x, dict):
        return False
    return is_sequence(x)
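
# Illustrative only: _to_sequence_except_dict({'a': 1}) returns [{'a': 1}] and
# _is_sequence_except_dict({'a': 1}) returns False, while lists and tuples go
# through the generic to_sequence/is_sequence handling above.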


def expand_undefined_var(nest1, nest2, names):
    """TODO: make this function recursive.
    nest1: Var1, (UndefinedVar, [1,2,3])
    nest2: Var2, ([1,2,3,4], UndefinedVar)
    In this case, we should not expand recursively.
    """
    from paddle.jit.dy2static.utils import UndefinedVar
    from paddle.jit.dy2static.return_transformer import (
        RETURN_VALUE_PREFIX,
    )

    def pack_undefined_var_as(seq):
        return pack_sequence_as(
            seq, [UndefinedVar("padding") for i in flatten(seq)]
        )

    def map_fn(n1, n2, name, order):
        if not name.startswith(RETURN_VALUE_PREFIX) and (
            isinstance(n1, UndefinedVar) or n1 is None
        ):
            if n1 is None and n2 is not None:
                if order == 0:
                    warnings.warn(
                        "In cond: Var '{}' or part of it is set differently in the "
                        "if/else branches, <{}, {}> in the true branch and <{}, {}> "
                        "in the false branch. Setting the var to 'None' in an "
                        "if/else block might lead to an error.".format(
                            name, type(n1), n1, type(n2), n2
                        )
                    )
                else:
                    warnings.warn(
                        "In cond: Var '{}' or part of it is set differently in the "
                        "if/else branches, <{}, {}> in the true branch and <{}, {}> "
                        "in the false branch. Setting the var to 'None' in an "
                        "if/else block might lead to an error.".format(
                            name, type(n2), n2, type(n1), n1
                        )
                    )
            return pack_undefined_var_as(n2)
        return n1

    nest1_out = list(
        map(
            map_fn,
            _to_sequence_except_dict(nest1),
            _to_sequence_except_dict(nest2),
            _to_sequence_except_dict(names),
            [0 for i in _to_sequence_except_dict(names)],
        )
    )
    nest2_out = list(
        map(
            map_fn,
            _to_sequence_except_dict(nest2),
            _to_sequence_except_dict(nest1),
            _to_sequence_except_dict(names),
            [1 for i in _to_sequence_except_dict(names)],
        )
    )
    if not _is_sequence_except_dict(nest1):
        nest1_out = nest1_out[0]
    if not _is_sequence_except_dict(nest2):
        nest2_out = nest2_out[0]
    return nest1_out, nest2_out
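
# Illustrative sketch with hypothetical values: for
#   nest1 = (x, UndefinedVar('v')), nest2 = (y, [1, 2]), names = ('x', 'v'),
# the second slot of nest1 is packed as a list of UndefinedVar paddings to
# mirror nest2's structure, while already-defined entries (and names starting
# with RETURN_VALUE_PREFIX) are left untouched.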


class Switch:
    """
    :api_attr: Static Graph

    This class is used to implement the Switch branch control flow.
    A Switch contains several case branches and one default branch.
    The Switch control flow checks the case branch conditions in turn and only
    executes the statements of the first case branch whose condition is satisfied.
    If no case branch condition is satisfied,
    only the statements following the default branch are executed.

    Note:
        A new OP :ref:`api_fluid_layers_case` is highly recommended instead of ``Switch`` if the shape of parameter ``cond`` is [1].
        OP :ref:`api_fluid_layers_case` is easier to use and is called with less code but does the same thing as ``Switch`` .

    Member Functions:
        case(condition): The case branch of Switch whose parameter condition is a scalar Variable of bool type. The statements after this case branch are executed only if its condition is True and the conditions of all previous case branches are False; otherwise they are not executed.

        default(): The default branch of Switch. When cond of all case branches is False, the statement after default branch is executed.

    Case and default functions can only be used inside the scope of Switch, as shown below:

    .. code-block:: python

        '''
        import paddle
        import paddle.fluid as fluid
        with fluid.layers.Switch() as switch:
            with switch.case(cond1):
                i = paddle.full(shape=[1], dtype='int64', fill_value=1)
            with switch.case(cond2):
                i = paddle.full(shape=[1], dtype='int64', fill_value=2)
            with switch.default():
                i = paddle.full(shape=[1], dtype='int64', fill_value=0)
        '''

    Args:
        name(str, optional): The default value is None.  Normally there is no need for user to set this property.  For more information, please refer to :ref:`api_guide_Name` .

    Examples:
        .. code-block:: python

            import paddle
            import paddle.fluid as fluid

            lr = paddle.static.create_global_var(
                shape=[1],
                value=0.0,
                dtype='float32',
                persistable=True,
                name="learning_rate")
            zero_var = paddle.full(
                shape=[1], dtype='float32', fill_value=0.0)
            one_var = paddle.full(
                shape=[1], dtype='float32', fill_value=1.0)
            two_var = paddle.full(
                shape=[1], dtype='float32', fill_value=2.0)

            global_step = fluid.layers.autoincreased_step_counter(counter_name='@LR_DECAY_COUNTER@', begin=0, step=1)

            with fluid.layers.control_flow.Switch() as switch:
                with switch.case(global_step == zero_var):
                    paddle.assign(one_var, output=lr)
                with switch.default():
                    paddle.assign(two_var, output=lr)

            exe = fluid.Executor(fluid.CPUPlace())
            exe.run(fluid.default_startup_program())

            res = exe.run(fluid.default_main_program(), feed={}, fetch_list=[lr])
            print(res) # [array([1.], dtype=float32)]
    """

    def __init__(self, name=None):
        self.helper = LayerHelper('switch', name=name)
        self.inside_scope = False
        self.pre_not_conditions = []

    def case(self, condition):
        if not self.inside_scope:
            raise ValueError("case() should be called inside 'with Switch()'")

        check_variable_and_dtype(
            condition,
            'condition',
            ['bool'],
            'the member function case of fluid.layers.Switch',
        )

        if len(self.pre_not_conditions) == 0:
            cond_block = ConditionalBlock([condition], is_scalar_condition=True)
            not_cond = paddle.logical_not(x=condition)
            self.pre_not_conditions.append(not_cond)
        else:
            pre_cond_num = len(self.pre_not_conditions)
            pre_not_cond = self.pre_not_conditions[pre_cond_num - 1]
            new_not_cond = paddle.logical_and(
                x=pre_not_cond, y=paddle.logical_not(x=condition)
            )
            self.pre_not_conditions.append(new_not_cond)
            cond_block = ConditionalBlock(
                [paddle.logical_and(x=pre_not_cond, y=condition)],
                is_scalar_condition=True,
            )
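        # Whichever branch above built `cond_block`, its guarding condition is
        # true only when every previous case condition was False and the current
        # `condition` is True, which gives Switch its first-match-wins semantics.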

        return ConditionalBlockGuard(cond_block)

    def default(self):
        pre_cond_num = len(self.pre_not_conditions)
        if pre_cond_num == 0:
            raise ValueError("there should be at least one condition")
        cond_block = ConditionalBlock(
            [self.pre_not_conditions[pre_cond_num - 1]],
            is_scalar_condition=True,
        )
        return ConditionalBlockGuard(cond_block)

    def __enter__(self):
        """
        set flag that now is inside switch.block {}
        :return:
        """
        self.inside_scope = True
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.inside_scope = False
        if exc_type is not None:
            return False  # re-raise exception

        return True