# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ..wrapped_decorator import signature_safe_contextmanager

from .layer_function_generator import templatedoc
from .tensor import assign, cast, fill_constant
from .. import core
from ..framework import (
    Program,
    Variable,
    Operator,
    _non_static_mode,
    static_only,
    _in_legacy_dygraph,
    in_dygraph_mode,
)
from ..layer_helper import LayerHelper, unique_name
from .utils import (
    assert_same_structure,
    map_structure,
    hold_mutable_vars,
    copy_mutable_vars,
    padding_to_same_structure,
    is_sequence,
    pack_sequence_as,
    flatten,
    to_sequence,
)
import numpy
import warnings
from functools import reduce, partial
from ..data_feeder import (
    convert_dtype,
    check_variable_and_dtype,
    check_type,
    check_dtype,
)
from ..backward import _infer_var_data_type_shape_
import paddle
from paddle import _C_ops, _legacy_C_ops

__all__ = [
    'Switch',
    'array_write',
    'array_read',
    'StaticRNN',
    'Print',
    'while_loop',
]


def select_output(input, outputs, mask):
    """
    **select_output**

    This API takes in one input, multiple outputs, and an integer mask. It
    selects the output specified by the mask and copies the input to the
    selected output. It is useful in control flow.

    Args:
        input(Variable): The input variable
        outputs(tuple|list): The output variables
        mask(Variable): A tensor containing 1 integer number selecting which
            output to be copied with input

    Returns:
        Variable: The outputs variables
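
    Examples:
        A minimal usage sketch (for illustration only; it assumes a static-graph
        program and imports this module-level helper directly):

        .. code-block:: python

            import paddle
            from paddle.fluid.layers.control_flow import select_output

            paddle.enable_static()
            x = paddle.full(shape=[1], fill_value=5.0, dtype='float32')
            out0 = paddle.full(shape=[1], fill_value=0.0, dtype='float32')
            out1 = paddle.full(shape=[1], fill_value=0.0, dtype='float32')
            mask = paddle.full(shape=[1], fill_value=1, dtype='int32')
            # mask is 1, so x is copied into out1
            select_output(x, [out0, out1], mask)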
    """
    helper = LayerHelper('select_output', **locals())
    check_type(input, 'input', (Variable), 'select_output')
    check_variable_and_dtype(mask, 'mask', ['int32'], 'select_output')
    check_type(outputs, 'outputs', (list, tuple), 'select_output')

    helper.append_op(
        type='select_output',
        inputs={'X': input, 'Mask': mask},
        outputs={'Out': outputs},
    )
    return outputs


def _select_input_infer_shape(first_shape, second_shape):
    """
    This function infers the output shape with the following algorithm:
    1. if the ranks differ, warn and fall back to second_shape.
    2. compare the axes one by one:
        if a == b: set the output axis to a
        if a != b: set the output axis to -1
    For example, first_shape [3, -1, 5] and second_shape [3, 4, 5] give [3, -1, 5].

    For compatibility, in non-declarative mode, we just return second_shape.
    """
    if len(first_shape) != len(second_shape):
        warnings.warn(
            f"the input shapes of select_input should have the same rank, but get {first_shape}, {second_shape}"
        )
        return second_shape
    out_shape = list(
        map(lambda a, b: a if a == b else -1, first_shape, second_shape)
    )
    return out_shape


def select_input(inputs, mask):
    """
    **select_input**

    This API takes in multiple inputs and uses an integer mask to select one
    input to output. It is useful in control flow.

    Args:
        inputs(tuple|list): The input variables
        mask(Variable): A tensor containing 1 integer number selecting which
            input to output

    Returns:
        Variable: The selected input variable
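
    Examples:
        A minimal usage sketch (for illustration only; it assumes a static-graph
        program and imports this module-level helper directly):

        .. code-block:: python

            import paddle
            from paddle.fluid.layers.control_flow import select_input

            paddle.enable_static()
            x0 = paddle.full(shape=[1], fill_value=1.0, dtype='float32')
            x1 = paddle.full(shape=[1], fill_value=2.0, dtype='float32')
            mask = paddle.full(shape=[1], fill_value=1, dtype='int32')
            # mask is 1, so the second input is selected
            out = select_input([x0, x1], mask)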
    """
    helper = LayerHelper('select_input', **locals())
    check_type(inputs, 'inputs', (list, tuple), 'select_input')
    check_variable_and_dtype(mask, 'mask', ['int32'], 'select_input')

    # select_input needs to infer the output shape: if two dims are equal, keep that
    # value; if they differ (e.g. -1 vs. a concrete number), use -1. If the ranks
    # differ, a warning is issued and the second shape is used (see _select_input_infer_shape).
    # assert inputs[0].dtype == inputs[1].dtype, f"Expect the inputs should have the same dtype, but get {inputs[0].dtype} and {inputs[1].dtype}"

    output_shape = _select_input_infer_shape(inputs[0].shape, inputs[1].shape)
    output_dtype = inputs[1].dtype
    output_type = inputs[1].type

    out = helper.create_variable(
        dtype=output_dtype, shape=output_shape, type=output_type
    )
    helper.append_op(
        type='select_input',
        inputs={'X': inputs, 'Mask': mask},
        outputs={'Out': out},
    )
    return out


def split_lod_tensor(input, mask, level=0):
    """
    This function takes in an input that contains the complete lod information,
    and takes in a mask which is used to mask certain parts of the input.
    The output is the true branch and the false branch with the mask applied to
    the input at a certain level in the tensor. Mainly used in IfElse to split
    data into two parts.

    Args:
        input(Variable|tuple|list|None): The input tensor that contains complete
                                lod information needed to construct the output.
        mask(Variable|list): A bool column vector which masks the input.
        level(int): The specific lod level to split.

    Returns:
        tuple(Variable, Variable):
        The true branch of tensor as per the mask applied to input.

        The false branch of tensor as per the mask applied to input.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          x = fluid.layers.data(name='x', shape=[1])
          x.persistable = True

          y = fluid.layers.data(name='y', shape=[1], dtype='bool')
          y.persistable = True

          level = 0
          out_true, out_false = fluid.layers.split_lod_tensor(
                input=x, mask=y, level=level)

    """
    check_type(
        input,
        'input',
        (Variable, list, tuple, type(None)),
        'fluid.layers.split_lod_tensor',
    )
    check_type(mask, 'mask', (Variable, list), 'fluid.layers.split_lod_tensor')
    check_type(level, 'level', int, 'fluid.layers.split_lod_tensor')
    helper = LayerHelper('split_lod_tensor', **locals())
    out_true = helper.create_variable_for_type_inference(dtype=input.dtype)
    out_false = helper.create_variable_for_type_inference(dtype=input.dtype)
    helper.append_op(
        type='split_lod_tensor',
        inputs={
            'X': input,
            'Mask': mask,
        },
        outputs={'OutTrue': out_true, 'OutFalse': out_false},
        attrs={'level': level},
    )
    return out_true, out_false


def merge_lod_tensor(in_true, in_false, x, mask, level=0):
    """
    **merge_lod_tensor**

    This function takes in an input :math:`x`, the True branch, the False
    branch and a binary :math:`mask`. Using this information, this function
    merges the True and False branches of the tensor into a single tensor as
    output at a certain lod level indicated by :math:`level`. Used in IfElse
    to merge the output of the True block and the False block.

    Args:
        in_true(Variable|tuple|list|None): The True branch to be merged.
        in_false(Variable|tuple|list|None): The False branch to be merged.
        x(Variable|tuple|list|None): The input tensor that contains complete
                            lod information needed to construct the output.
        mask(Variable|list): A bool column vector which masks the input.
        level(int): The specific lod level to merge.

    Returns:
        Variable: The merged output tensor.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          x = fluid.layers.data(
                      name='x', shape=[1], dtype='float32', stop_gradient=False)
          y = fluid.layers.data(
                name='y', shape=[1], dtype='bool', stop_gradient=False)

          level = 0

          out_true, out_false = fluid.layers.split_lod_tensor(
                input=x, mask=y, level=level)
          out = fluid.layers.merge_lod_tensor(
                in_true=out_true, in_false=out_false, mask=y, x=x, level=level)
    """
    helper = LayerHelper('merge_lod_tensor', **locals())
    check_type(
        x,
        'x',
        (Variable, list, tuple, type(None)),
        'fluid.layers.merge_lod_tensor',
    )
    check_type(mask, 'mask', (Variable, list), 'fluid.layers.merge_lod_tensor')
    check_type(
        in_true,
        'in_true',
        (Variable, list, tuple, type(None)),
        'fluid.layers.merge_lod_tensor',
    )
    check_type(
        in_false,
        'in_false',
        (Variable, list, tuple, type(None)),
        'fluid.layers.merge_lod_tensor',
    )
    out = helper.create_variable_for_type_inference(dtype=in_true.dtype)
    helper.append_op(
        type='merge_lod_tensor',
        inputs={'X': x, 'Mask': mask, 'InTrue': in_true, 'InFalse': in_false},
        outputs={'Out': out},
        attrs={'level': level},
    )
    return out


@static_only
def Print(
    input,
    first_n=-1,
    message=None,
    summarize=20,
    print_tensor_name=True,
    print_tensor_type=True,
    print_tensor_shape=True,
    print_tensor_layout=True,
    print_tensor_lod=True,
    print_phase='both',
):
    '''
    :api_attr: Static Graph

    **Print operator**

    This creates a print op that will print when a tensor is accessed.

    Wraps the tensor passed in so that whenever the tensor is accessed,
    the message `message` is printed, along with the current value of the
    tensor `t`.

    Args:
        input (Variable): A Tensor to print.
        summarize (int): Number of elements in the tensor to be printed. If its
                value is -1, then all elements in the tensor will be printed.
        message (str): A string message to print as a prefix.
        first_n (int): Only log `first_n` number of times.
        print_tensor_name (bool, optional): Print the tensor name. Default: True.
        print_tensor_type (bool, optional): Print the tensor type. Default: True.
        print_tensor_shape (bool, optional): Print the tensor shape. Default: True.
        print_tensor_layout (bool, optional): Print the tensor layout. Default: True.
        print_tensor_lod (bool, optional): Print the tensor lod. Default: True.
        print_phase (str): Which phase to display, including 'forward',
                'backward' and 'both'. Default: 'both'. If set to 'backward', will
                only print the gradients of input tensor; If set to 'both', will
                both print the input tensor itself and the gradients of input tensor.

    Returns:
        Variable: Output tensor.

    NOTES:
        The input and output are two different variables. In the
        following process, you should use the output variable rather than the input;
        otherwise, the print layer will not have a backward pass.

    Examples:
        .. code-block:: python

           import paddle

           paddle.enable_static()

           x = paddle.full(shape=[2, 3], fill_value=3, dtype='int64')
           out = paddle.static.Print(x, message="The content of input layer:")

           main_program = paddle.static.default_main_program()
           exe = paddle.static.Executor(place=paddle.CPUPlace())
           res = exe.run(main_program, fetch_list=[out])
           # Variable: fill_constant_1.tmp_0
           #   - message: The content of input layer:
           #   - lod: {}
           #   - place: CPUPlace
           #   - shape: [2, 3]
           #   - layout: NCHW
           #   - dtype: long
           #   - data: [3 3 3 3 3 3]
    '''
    check_variable_and_dtype(
        input,
        'input',
        ['float32', 'float64', 'int32', 'int64', 'bool'],
        'fluid.layers.Print',
    )

    helper = LayerHelper('print' + "_" + input.name, **locals())
    output = helper.create_variable_for_type_inference(input.dtype)
    helper.append_op(
        type='print',
        inputs={'In': input},
        outputs={'Out': output},
        attrs={
            'first_n': first_n,
            'summarize': summarize,
            'message': message or "",
            'print_tensor_name': print_tensor_name,
            'print_tensor_type': print_tensor_type,
            'print_tensor_shape': print_tensor_shape,
            'print_tensor_layout': print_tensor_layout,
            'print_tensor_lod': print_tensor_lod,
            'print_phase': print_phase.upper(),
        },
    )
    return output


# (TODO: Mine) There exists dependency. It will be removed later.
class BlockGuard:
    """
    BlockGuard class.

    BlockGuard class is used to create a sub-block in a program by
    using the Python `with` keyword.
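
    A minimal sketch of the intended pattern (illustrative only; ``prog`` is
    assumed to be an existing ``Program`` instance):

    .. code-block:: python

        with BlockGuard(prog):
            pass  # ops appended here go into the newly created sub-block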
    """

    def __init__(self, main_program):
        if not isinstance(main_program, Program):
            raise TypeError("BlockGuard takes a program")
        self.main_program = main_program

    def __enter__(self):
        self.main_program._create_block()

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.main_program._rollback()
        if exc_type is not None:
            return False  # re-raise exception
        return True


# (TODO: Mine) There exists dependency. It will be removed later.
class BlockGuardWithCompletion(BlockGuard):
    """
    BlockGuardWithCompletion class.

    BlockGuardWithCompletion class is used to create an op with a block in a program.
    """

    def __init__(self, rnn):
        if not isinstance(rnn, StaticRNN):
            raise TypeError("BlockGuardWithCompletion takes a StaticRNN")
        super().__init__(rnn.helper.main_program)
        self.rnn = rnn

    def __enter__(self):
        self.rnn.status = StaticRNN.IN_RNN_BLOCK
        return super().__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type is not None:
            return False
        self.rnn.status = StaticRNN.AFTER_RNN_BLOCK
        self.rnn._complete_op()
        return super().__exit__(exc_type, exc_val, exc_tb)


class StaticRNNMemoryLink:
    """
    StaticRNNMemoryLink class.

    StaticRNNMemoryLink class is used to create a link between two
    memory cells of a StaticRNN.


    NOTE: This is an internal data structure of a very low-level API.
    Please use StaticRNN instead.

    Args:
        init(Variable): the initial variable for Memory.
        pre_mem(Variable): the memory variable in previous time step.
        mem(Variable): the memory variable in current time step.
    """

    def __init__(self, init, pre_mem, mem=None):
        self.init = init
        self.pre_mem = pre_mem
        self.mem = mem


class StaticRNN:
    """
    :api_attr: Static Graph

    StaticRNN class.

    The StaticRNN can process a batch of sequence data. The first dimension of the input
    represents the sequence length, and the length of each input sequence must be equal.
    StaticRNN will unfold the sequence into time steps, and the user needs to define how to
    process each time step inside the :code:`with` block.

    Args:
        name (str, optional): Please refer to :ref:`api_guide_Name`, Default None.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.fluid as fluid
            import paddle.fluid.layers as layers

            vocab_size, hidden_size=10000, 200
            paddle.enable_static()
            x = fluid.data(name="x", shape=[None, 1, 1], dtype='int64')
            # create word sequence
            x_emb = layers.embedding(
                input=x,
                size=[vocab_size, hidden_size],
                dtype='float32',
                is_sparse=False)
            # transform batch size to dim 1
            x_emb = paddle.transpose(x_emb, perm=[1, 0, 2])

            rnn = fluid.layers.StaticRNN()
            with rnn.step():
                # mark created x_emb as input, each step process a word
                word = rnn.step_input(x_emb)
                # create prev memory parameter, batch size comes from word
                prev = rnn.memory(shape=[-1, hidden_size], batch_ref = word)
                hidden = fluid.layers.fc(input=[word, prev], size=hidden_size, act='relu')
                # use hidden to update prev
                rnn.update_memory(prev, hidden)
                # mark hidden as output
                rnn.step_output(hidden)
            # get StaticRNN final output
            result = rnn()

    """

    BEFORE_RNN_BLOCK = 0
    IN_RNN_BLOCK = 1
    AFTER_RNN_BLOCK = 2

    def __init__(self, name=None):
        check_type(name, "name", (str, type(None)), "fluid.layers.StaticRNN")
        self.helper = LayerHelper("static_rnn", name=name)
        self.memories = {}  # memory map, from pre_mem.name --> MemoryLink
        self.inputs = []  # input variable list in current block
        self.outputs = []  # output variable list in parent block
        self.status = StaticRNN.BEFORE_RNN_BLOCK  # status flag.
        # sequence length; since it is a static RNN, the sequence length is fixed.
        self.seq_len = None

    def step(self):
        """
        Define operators in each step. step is used in a :code:`with` block, and the OPs in
        the :code:`with` block will be executed sequence_len times (sequence_len is the length of the input).
        """
        return BlockGuardWithCompletion(self)

    def _assert_in_rnn_block_(self, method):
        if self.status != StaticRNN.IN_RNN_BLOCK:
            raise ValueError("You must invoke {0} in rnn block".format(method))

    def memory(
        self,
        init=None,
        shape=None,
        batch_ref=None,
        init_value=0.0,
        init_batch_dim_idx=0,
        ref_batch_dim_idx=1,
    ):
        """
        Create a memory variable for static rnn.
        If the :code:`init` is not None, :code:`memory` will be initialized by
        this Variable. If the :code:`init` is None, :code:`shape` and :code:`batch_ref`
        must be set, and this function will create a new variable with shape and batch_ref
        to initialize :code:`init` Variable.

        Args:
            init(Variable, optional): Tensor used to init memory. If it is not set,
                :code:`shape` and :code:`batch_ref` must be provided.
                Default: None.
            shape(list|tuple): When :code:`init` is None use this arg to initialize memory shape.
            NOTE the shape does not contain batch_size. Default: None.
            batch_ref(Variable, optional): When :code:`init` is None, memory's batch size will
            be set as batch_ref's ref_batch_dim_idx value. Default: None.
            init_value(float, optional): When :code:`init` is None, used to init memory's value. Default: 0.0.
            init_batch_dim_idx(int, optional): the batch_size axis of the :code:`init` Variable. Default: 0.
            ref_batch_dim_idx(int, optional): the batch_size axis of the :code:`batch_ref` Variable. Default: 1.

        Returns:
            Variable: The memory variable.

        Examples 1:
            .. code-block:: python

                import paddle
                import paddle.fluid as fluid
                import paddle.fluid.layers as layers

                vocab_size, hidden_size=10000, 200
                paddle.enable_static()
                x = fluid.data(name="x", shape=[None, 1, 1], dtype='int64')
                # create word sequence
                x_emb = layers.embedding(
                        input=x,
                        size=[vocab_size, hidden_size],
                        dtype='float32',
                        is_sparse=False)
                # transform batch size to dim 1
                x_emb = paddle.transpose(x_emb, perm=[1, 0, 2])

                rnn = fluid.layers.StaticRNN()
                with rnn.step():
                        # mark created x_emb as input, each step process a word
                        word = rnn.step_input(x_emb)
                        # create prev memory parameter, batch size comes from word
                        prev = rnn.memory(shape=[-1, hidden_size], batch_ref = word)
                        hidden = fluid.layers.fc(input=[word, prev], size=hidden_size, act='relu')
                        # use hidden to update prev
                        rnn.update_memory(prev, hidden)


        Examples 2:
            .. code-block:: python

                import paddle
                import paddle.fluid as fluid
                import paddle.fluid.layers as layers
                vocab_size, hidden_size=10000, 200
                paddle.enable_static()
                x = fluid.data(name="x", shape=[None, 1, 1], dtype='int64')
                # create word sequence
                x_emb = layers.embedding(
                        input=x,
                        size=[vocab_size, hidden_size],
                        dtype='float32',
                        is_sparse=False)
                # transform batch size to dim 1
                x_emb = paddle.transpose(x_emb, perm=[1, 0, 2])
                boot_memory = fluid.layers.data(name='boot', shape=[hidden_size], dtype='float32', lod_level=1)
                rnn = fluid.layers.StaticRNN()
                with rnn.step():
                        # mark created x_emb as input, each step process a word
                        word = rnn.step_input(x_emb)
                        # init memory
                        prev = rnn.memory(init=boot_memory)
                        hidden = fluid.layers.fc(input=[word, prev], size=hidden_size, act='relu')
                        # update hidden with prev
                        rnn.update_memory(prev, hidden)

        """
        self._assert_in_rnn_block_('memory')
        check_type(
            init,
            "init",
            (Variable, type(None)),
            "fluid.layers.StaticRNN.memory",
        )
        check_type(
            shape,
            "shape",
            (list, tuple, type(None)),
            "fluid.layers.StaticRNN.memory",
        )
        check_type(
            batch_ref,
            "batch_ref",
            (Variable, type(None)),
            "fluid.layers.StaticRNN.memory",
        )
        if init is None:
            if shape is None or batch_ref is None:
                raise ValueError(
                    "if init is None, memory at least need shape and batch_ref"
                )
            parent_block = self._parent_block()
            var_name = unique_name.generate_with_ignorable_key(
                "@".join([self.helper.name, "memory_boot"])
            )
            boot_var = parent_block.create_var(
                name=var_name,
                shape=shape,
                dtype=batch_ref.dtype,
                persistable=False,
            )

            parent_block.append_op(
                type="fill_constant_batch_size_like",
                inputs={'Input': [batch_ref]},
                outputs={'Out': [boot_var]},
                attrs={
                    'value': init_value,
                    'shape': boot_var.shape,
                    'dtype': boot_var.dtype,
                    'input_dim_idx': ref_batch_dim_idx,
                    'output_dim_idx': init_batch_dim_idx,
                },
            )

            return self.memory(init=boot_var)
        else:
            pre_mem = self.helper.create_variable(
                name=unique_name.generate_with_ignorable_key(
                    "@".join([self.helper.name, "mem"])
                ),
                dtype=init.dtype,
                shape=init.shape,
            )
            self.memories[pre_mem.name] = StaticRNNMemoryLink(
                init=init, pre_mem=pre_mem
            )
            return pre_mem

    def step_input(self, x):
        """
        Mark a sequence as a StaticRNN input.

        Args:
            x(Variable): The input sequence, the shape of x
                should be [seq_len, ...].

        Returns:
            Variable: The current time step data in the input sequence.

        Examples:
            .. code-block:: python

                import paddle
                import paddle.fluid as fluid
                import paddle.fluid.layers as layers

                vocab_size, hidden_size=10000, 200
                paddle.enable_static()
                x = fluid.data(name="x", shape=[None, 1, 1], dtype='int64')
                # create word sequence
                x_emb = layers.embedding(
                        input=x,
                        size=[vocab_size, hidden_size],
                        dtype='float32',
                        is_sparse=False)
                # transform batch size to dim 1
                x_emb = paddle.transpose(x_emb, perm=[1, 0, 2])

                rnn = fluid.layers.StaticRNN()
                with rnn.step():
                        # mark created x_emb as input, each step process a word
                        word = rnn.step_input(x_emb)
                        # create prev memory parameter, batch size comes from word
                        prev = rnn.memory(shape=[-1, hidden_size], batch_ref = word)
                        hidden = fluid.layers.fc(input=[word, prev], size=hidden_size, act='relu')
                        # use hidden to update prev
                        rnn.update_memory(prev, hidden)

        """
        self._assert_in_rnn_block_('step_input')
        check_type(x, "x", Variable, "fluid.layers.StaticRNN.step_input")
        if self.seq_len is None:
            self.seq_len = x.shape[0]
        elif x.shape[0] != -1 and self.seq_len != x.shape[0]:
            raise ValueError("Static RNN only takes fixed seq_len input")

        ipt = self.helper.create_variable(
            name=x.name, dtype=x.dtype, shape=list(x.shape[1:]), type=x.type
        )
        self.inputs.append(ipt)
        return ipt

    def step_output(self, o):
        """
        Mark a sequence as a StaticRNN output.

        Args:
            o(Variable): The output sequence.

        Returns:
            None.

        Examples:
            .. code-block:: python

                import paddle
                import paddle.fluid as fluid
                import paddle.fluid.layers as layers

                vocab_size, hidden_size=10000, 200
                paddle.enable_static()
                x = fluid.data(name="x", shape=[None, 1, 1], dtype='int64')
                # create word sequence
                x_emb = layers.embedding(
                        input=x,
                        size=[vocab_size, hidden_size],
                        dtype='float32',
                        is_sparse=False)
                # transform batch size to dim 1
                x_emb = paddle.transpose(x_emb, perm=[1, 0, 2])

                rnn = fluid.layers.StaticRNN()
                with rnn.step():
                        # mark created x_emb as input, each step process a word
                        word = rnn.step_input(x_emb)
                        # create prev memory parameter, batch size comes from word
                        prev = rnn.memory(shape=[-1, hidden_size], batch_ref = word)
                        hidden = fluid.layers.fc(input=[word, prev], size=hidden_size, act='relu')
                        # use hidden to update prev
                        rnn.update_memory(prev, hidden)
                        rnn.step_output(hidden)

                result = rnn()

        """
        self._assert_in_rnn_block_('step_output')
        check_type(o, "o", Variable, "fluid.layers.StaticRNN.step_output")

        tmp_o = self.helper.create_variable_for_type_inference(dtype=o.dtype)
        self.helper.append_op(
            type='rnn_memory_helper',
            inputs={'X': [o]},
            outputs={'Out': tmp_o},
            attrs={'dtype': o.dtype},
        )

        out_var = self._parent_block().create_var(
            name=tmp_o.name,
            shape=[self.seq_len] + list(tmp_o.shape),
            dtype=tmp_o.dtype,
        )

        self.outputs.append(out_var)

    def output(self, *outputs):
        """
        Mark the StaticRNN output variables.

        Args:
            outputs: The output Tensors; multiple variables can be marked as outputs.

        Returns:
            None

        Examples:
            .. code-block:: python

                import paddle
                import paddle.fluid as fluid
                import paddle.fluid.layers as layers

                vocab_size, hidden_size=10000, 200
                paddle.enable_static()
                x = fluid.data(name="x", shape=[None, 1, 1], dtype='int64')
                # create word sequence
                x_emb = layers.embedding(
                        input=x,
                        size=[vocab_size, hidden_size],
                        dtype='float32',
                        is_sparse=False)
                # transform batch size to dim 1
                x_emb = paddle.transpose(x_emb, perm=[1, 0, 2])

                rnn = fluid.layers.StaticRNN()
                with rnn.step():
                        # mark created x_emb as input, each step process a word
                        word = rnn.step_input(x_emb)
                        # create prev memory parameter, batch size comes from word
                        prev = rnn.memory(shape=[-1, hidden_size], batch_ref = word)
                        hidden = fluid.layers.fc(input=[word, prev], size=hidden_size, act='relu')
                        # use hidden to update prev
                        rnn.update_memory(prev, hidden)
                        # mark each step's hidden and word as output
                        rnn.output(hidden, word)

                result = rnn()
        """
        for each in outputs:
            self.step_output(each)

    def update_memory(self, mem, var):
        """
        Update the memory from :code:`mem` to :code:`var`.

        Args:
            mem(Variable): the memory variable.
            var(Variable): the plain variable generated in RNN block, used to update memory.
                           var and mem should have the same dims and data type.

        Returns:
            None

        """
        check_type(mem, "mem", Variable, "fluid.layers.StaticRNN.update_memory")
        check_type(var, "var", Variable, "fluid.layers.StaticRNN.update_memory")
        self.memories[mem.name].mem = var

    def _parent_block(self):
        prog = self.helper.main_program
        parent_idx = prog.current_block().parent_idx
        assert parent_idx >= 0
        parent_block = prog.block(parent_idx)
        return parent_block

    def __call__(self, *args, **kwargs):
        if self.status != StaticRNN.AFTER_RNN_BLOCK:
            raise ValueError("RNN output can only be retrieved after rnn block")
        if len(self.outputs) == 0:
            raise ValueError("RNN has no output")
        elif len(self.outputs) == 1:
            return self.outputs[0]
        else:
            return self.outputs

    def _complete_op(self):
        main_program = self.helper.main_program
        rnn_block = main_program.current_block()
        parent_block = self._parent_block()

        local_inputs = set()

        for op in rnn_block.ops:
            assert isinstance(op, Operator)
            for oname in op.output_names:
                for out_var_name in op.output(oname):
                    local_inputs.add(out_var_name)

        for var in self.inputs:
            local_inputs.add(var.name)
        for m in self.memories:
            local_inputs.add(m)

        # NOTE(zcd): the params have two categories of variables.
        #   - the variables that are the out of StaticRnn.
        #   - the variables that are the parameters of some layers, for example, conv2d.
        params = list()
        for op in rnn_block.ops:
            assert isinstance(op, Operator)
            for iname in op.input_names:
                for in_var_name in op.input(iname):
                    if in_var_name not in local_inputs:
                        params.append(in_var_name)

        parameters = [
            parent_block._find_var_recursive(name) for name in set(params)
        ]

        step_scope = parent_block.create_var(
            type=core.VarDesc.VarType.STEP_SCOPES
        )

        inlinks = [parent_block.var(i.name) for i in self.inputs]
        outlinks = self.outputs

        # NOTE(zcd): the states may be empty in some cases.
        boot_memories = []
        pre_memories = []
        memories = []
        for _, mem in self.memories.items():
            boot_memories.append(mem.init)
            pre_memories.append(mem.pre_mem.name)
            assert (
                mem.mem is not None
            ), "%s should be updated in every step." % (mem.init.name)
            mem_var = rnn_block.var(mem.mem.name)
            assert isinstance(mem_var, Variable)
            new_mem = self.helper.create_variable_for_type_inference(
                dtype=mem_var.dtype
            )
            rnn_block.append_op(
                type='rnn_memory_helper',
                inputs={'X': [mem_var]},
                outputs={'Out': [new_mem]},
                attrs={'dtype': mem_var.dtype},
            )

            memories.append(new_mem.name)

        parent_block.append_op(
            type='recurrent',
            inputs={
                'inputs': inlinks,
                'initial_states': boot_memories,
                'parameters': parameters,
            },
            outputs={'outputs': outlinks, 'step_scopes': [step_scope]},
            attrs={
                'has_states': len(pre_memories) > 0,
                'ex_states': pre_memories,
                'states': memories,
                'sub_block': rnn_block,
            },
        )


# (TODO: Mine) There exists dependency. It will be removed later.
class WhileGuard(BlockGuard):
    def __init__(self, while_op):
        if not isinstance(while_op, While):
            raise TypeError("WhileGuard takes a while op")
        super().__init__(while_op.helper.main_program)
        self.while_op = while_op

    def __enter__(self):
        self.while_op.status = While.IN_WHILE_BLOCK
        return super().__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type is not None:
            return False
        self.while_op.status = While.AFTER_WHILE_BLOCK
        self.while_op._complete()
        return super().__exit__(exc_type, exc_val, exc_tb)


# (TODO: Mine) There exists dependency. It will be removed later.
def get_inputs_outputs_in_block(
    current_block, inner_inputs, inner_outputs, helper
):
    """
    Find inputs and outputs in current control flow block.
    :param current_block: Current control flow block.
    :param inner_inputs: Input var name of ops in current block.
    :param inner_outputs: Output var name of ops in current block.
    :return: inner_inputs, inner_outputs
    """

    def is_ignore_vars(op, var_name):
        # NOTE(dev): There are some persistable var created in some non-standard API
        # such as "contrib.layers.shuffle_batch". It create a "Seed" used both in
        # Input and Output. This var shall not be considered as a loop_var in
        # control_flow.
        IGNORE_VAR_NAMES = {"shuffle_batch": ["shuffle_batch_seed"]}
        if op.type in IGNORE_VAR_NAMES:
            var_names = IGNORE_VAR_NAMES[op.type]
            for name in var_names:
                if name in var_name:
                    return True
        return False

    # Step1: update inner_inputs and inner_outputs
    # NOTE: Here we assume that all variables are inputs or outputs of Ops,
    # but some variables are created without appending a real op.
    # For example, in `arr = create_array(dtype)`, `arr` is not an output of any op.
    for op in current_block.ops:
        assert isinstance(op, Operator)
        for iname in op.input_names:
            for in_var_name in op.input(iname):
                if in_var_name not in inner_outputs and not is_ignore_vars(
                    op, in_var_name
                ):
                    inner_inputs.add(in_var_name)

        for oname in op.output_names:
            for out_var_name in op.output(oname):
                inner_outputs.add(out_var_name)

    # Step2: Remove LOD_TENSOR_ARRAY created in current control flow block.
    remove_inner_inputs = set()
    parent_block = helper.main_program.block(current_block.parent_idx)

    for in_var_name in inner_inputs:
        parent_block_var = parent_block._find_var_recursive(in_var_name)
        current_block_var = None
        if current_block.has_var(in_var_name):
            current_block_var = current_block.var(in_var_name)
        if (
            not parent_block_var
            and current_block_var
            and current_block_var.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY
        ):
            remove_inner_inputs.add(in_var_name)

    inner_inputs = inner_inputs - remove_inner_inputs

    return inner_inputs, inner_outputs


# (TODO: Mine) There exists dependency. It will be removed later.
class While:
    """
    :api_attr: Static Graph

    While loop control flow. Repeats the loop body until ``cond`` is False.

    Note:
        A new OP :ref:`api_fluid_layers_while_loop` is highly recommended instead of ``While`` if the shape of parameter ``cond`` is [1].
        OP :ref:`api_fluid_layers_while_loop` is easier to use and is called with less code but does the same thing as ``While`` .

    Notice:
        Local variables created in ``While`` are similar to those created in a C++ while loop, and cannot be referenced externally.
        As a result, they cannot be obtained through ``fetch_list`` of ``Executor``. If you would like to access the variable
        out of ``while`` , PaddlePaddle provides ``assign`` API to assign local variables to external. Please refer to example
        code 2 or refer to `issue#22724 <https://github.com/PaddlePaddle/Paddle/issues/22724>`_.

    Args:
        cond(Variable): A Tensor whose data type is bool controlling whether to continue looping.
        is_test(bool, optional): A flag indicating whether execution is in test phase. Default value is False.
        name(str, optional): The default value is None.  Normally there is no need for user to set this property.  For more information, please refer to :ref:`api_guide_Name` .

    Examples 1:
          .. code-block:: python

            import paddle
            import paddle.fluid as fluid
            import numpy as np

            paddle.enable_static()

            i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0)           # loop counter

            loop_len = fluid.layers.fill_constant(shape=[1],dtype='int64', value=10)    # loop length

            cond = paddle.less_than(x=i, y=loop_len)
            while_op = fluid.layers.While(cond=cond)
            with while_op.block():
                i = paddle.increment(x=i, value=1)
                paddle.assign(paddle.less_than(x=i, y=loop_len), cond)

            exe = fluid.Executor(fluid.CPUPlace())
            exe.run(fluid.default_startup_program())

            res = exe.run(fluid.default_main_program(), feed={}, fetch_list=[i])
            print(res) # [array([10])]


    Examples 2:
          .. code-block:: python

            import paddle
            import paddle.fluid as fluid
            import numpy as np

            paddle.enable_static()
            i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0)
            loop_len = fluid.layers.fill_constant(shape=[1], dtype='int64', value=10)
            one = fluid.layers.fill_constant(shape=[1], dtype='float32', value=1)
            data = fluid.data(name='data', shape=[1], dtype='float32')
            sums = fluid.layers.fill_constant(shape=[1], dtype='float32', value=0)  # Define the variable to be fetched outside of While; its name should be different from that of the variable to be obtained inside the While

            cond = paddle.less_than(x=i, y=loop_len)
            while_op = fluid.layers.While(cond=cond)
            with while_op.block():
                sums_tensor = fluid.layers.elementwise_add(x=data, y=data)
                fluid.layers.assign(sums_tensor, sums)  # Update the value of sums_tensor defined in While to the sums which defined outside of While through layers.assign
                i = paddle.increment(x=i, value=1)
                data = fluid.layers.elementwise_add(x=data, y=one)
                paddle.assign(paddle.less_than(x=i, y=loop_len), cond)

            feed_data = np.ones(1).astype('float32')
            exe = fluid.Executor(fluid.CPUPlace())
            exe.run(fluid.default_startup_program())
            res = exe.run(fluid.default_main_program(), feed={'data': feed_data}, fetch_list=sums)
            print(res[0])  # [2.]    # Because the data in While does not update the value outside the While, the value of sums is [2.] after the loop
    """

    BEFORE_WHILE_BLOCK = 0
    IN_WHILE_BLOCK = 1
    AFTER_WHILE_BLOCK = 2

    def __init__(self, cond, is_test=False, name=None):
        self.helper = LayerHelper("while", name=name)
        self.status = While.BEFORE_WHILE_BLOCK
        check_variable_and_dtype(cond, 'cond', ['bool'], 'fluid.layers.While')
        if reduce(lambda a, b: a * b, cond.shape, 1) != 1:
            raise TypeError(
                "condition expected shape as [1], but given shape as {0}.".format(
                    list(cond.shape)
                )
            )
        self.cond_var = cond
        self.is_test = is_test

    def block(self):
        return WhileGuard(self)

    def _complete(self):
        main_program = self.helper.main_program
        while_block = main_program.current_block()
        parent_block = main_program.block(
            main_program.current_block().parent_idx
        )

        inner_outputs = {self.cond_var.name}
        x_name_list = set()
        x_name_list, inner_outputs = get_inputs_outputs_in_block(
            while_block, x_name_list, inner_outputs, self.helper
        )

        out_vars = []
        for inner_out_name in inner_outputs:
            inner_var = parent_block._find_var_recursive(inner_out_name)
            if inner_var:
                out_vars.append(inner_var)

        x_name_list |= set(map(lambda x: x.name, out_vars))
        # NOTE(dev): cond_var has been contained in Input('Condition'), so
        # we remove it from Input('X')
        x_name_list -= {self.cond_var.name}

        step_scope = parent_block.create_var(
            type=core.VarDesc.VarType.STEP_SCOPES
        )

        parent_block.append_op(
            type='while',
            inputs={
                'X': [
                    parent_block._var_recursive(x_name)
                    for x_name in x_name_list
                ],
                'Condition': [self.cond_var],
            },
            outputs={'Out': out_vars, 'StepScopes': [step_scope]},
            attrs={'sub_block': while_block, "is_test": self.is_test},
        )


support_ret_buildin_type = (bool, float, int)


# (TODO: Mine) There exists dependency. It will be removed later.
def assign_skip_lod_tensor_array(input, output):
    """
    Assign input to output, but skip the process of copying LoDTensorArray unless it's created in while_block.
    """

    def has_shape_diff(x_var, y_var):
        if len(x_var.shape) != len(y_var.shape):
            return True
        for x_dim, y_dim in zip(x_var.shape, y_var.shape):
            if x_dim != y_dim and -1 not in [x_dim, y_dim]:
                return True
        return False

    if not isinstance(input, (Variable, core.VarBase)):
        if isinstance(output, Variable) and isinstance(
            input, support_ret_buildin_type
        ):
            assign(input, output)
        else:
            output = input
        return

    if input.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY:
        main_program = input.block.program
        parent_block = main_program.block(
            main_program.current_block().parent_idx
        )
        if parent_block and not parent_block._find_var_recursive(input.name):
            assign(input, output)
    else:
        if (
            isinstance(output, Variable)
            and isinstance(input, Variable)
            and has_shape_diff(input, output)
        ):
            warnings.warn(
                "In dy2static mode, we attempt to assign a variable with shape {} into a variable with shape {}, which is not always right.".format(
                    input.shape, output.shape
                )
            )
        assign(input, output)


# (TODO: Mine) There exists dependency (jit.dy2static.convert_operators). It will be removed later.
def while_loop(cond, body, loop_vars, is_test=False, name=None):
    """
    :api_attr: Static Graph

    while_loop is one of the control flows. Repeats while_loop `body` until `cond` returns False.

    Notice:
        Local variables defined in ``body`` cannot be obtained through ``fetch_list`` of ``Executor`` , variables should
        be defined outside ``body`` and placed in ``loop_vars`` for looping, then these variables can be fetched by ``fetch_list`` .

    Args:
        cond(Callable): A callable returning a boolean tensor controlling whether to continue looping. And ``cond`` takes
            as many arguments as ``loop_vars`` .
        body(Callable): A callable returning a tuple or list of tensors or LoDTensorArrays of the same arity
            (length and structure) and types as ``loop_vars`` . And ``body`` takes as many arguments as ``loop_vars`` .
        loop_vars(list|tuple): A list or tuple of tensors or LoDTensorArrays that is passed to both ``cond`` and ``body`` .
        is_test(bool, optional): A flag indicating whether execution is in test phase. Default value is False.
        name(str, optional): Normally there is no need for users to set this property. For more information, please
            refer to :ref:`api_guide_Name`. Default is None.

    Returns:
        A list or tuple of Tensors or LoDTensorArrays which are returned by ``body`` .

    Examples:
        .. code-block:: python

            import paddle
            paddle.enable_static()

            def cond(i, ten):
                return i < ten

            def body(i, ten):
                i = i + 1
                return [i, ten]

            main_program = paddle.static.default_main_program()
            startup_program = paddle.static.default_startup_program()
            with paddle.static.program_guard(main_program, startup_program):
                i = paddle.full(shape=[1], fill_value=0, dtype='int64')     # loop counter
                ten = paddle.full(shape=[1], fill_value=10, dtype='int64')  # loop length
                i, ten = paddle.static.nn.while_loop(cond, body, [i, ten])

                exe = paddle.static.Executor(paddle.CPUPlace())
                res = exe.run(main_program, feed={}, fetch_list=[i])
                print(res) # [array([10])]
    """
    helper = LayerHelper('while_loop', **locals())

    if not callable(cond):
        raise TypeError("cond in while_loop should be callable")
    if not callable(body):
        raise TypeError("body in while_loop should be callable")
    check_type(loop_vars, 'loop_vars', (list, tuple), 'fluid.layers.while_loop')
    if len(loop_vars) == 0:
        raise ValueError("loop_vars in while_loop should not be empty")

    pre_cond = cond(*loop_vars)
    check_variable_and_dtype(
        pre_cond, 'var of cond returned', ['bool'], 'fluid.layers.while_loop'
    )
    if reduce(lambda a, b: a * b, pre_cond.shape, 1) != 1:
        raise TypeError(
            "the shape of the variable returned by cond should be [1],"
            "but given shape as {0}.".format(list(pre_cond.shape))
        )

    if _non_static_mode():
        now_cond = pre_cond.numpy()[0]
        while now_cond:
            output_vars = body(*loop_vars)
            if not isinstance(output_vars, (list, tuple)):
                output_vars = [output_vars]
            if len(output_vars) != len(loop_vars):
                raise ValueError(
                    "body in while_loop should return the same arity "
1292 1293
                    "(length and structure) and types as loop_vars"
                )
1294
            now_cond = cond(*output_vars).numpy()[0]
1295
            map_structure(assign_skip_lod_tensor_array, output_vars, loop_vars)
1296 1297
        return loop_vars

G
guofei 已提交
1298
    while_loop_block = While(pre_cond, is_test, name)
1299
    has_mutable_vars_in_loop = hold_mutable_vars(loop_vars)
G
guofei 已提交
1300
    with while_loop_block.block():
1301 1302 1303 1304 1305 1306 1307 1308 1309
        # If a variable with mutable type is included in loop_vars, like `dict/list`,
        # modifying it in the body function will cause origin variable to be modified
        # synchronously. This will raise an assignment error out of while block.
        # Here we make a copy of the mutable vars to avoid this problem.
        if has_mutable_vars_in_loop:
            new_loop_vars = copy_mutable_vars(loop_vars)
            output_vars = body(*new_loop_vars)
        else:
            output_vars = body(*loop_vars)
1310 1311
        if not isinstance(output_vars, (list, tuple)):
            output_vars = [output_vars]
1312
        try:
1313
            loop_vars = _deal_with_undefined_var(output_vars, loop_vars)
1314 1315
            assert_same_structure(output_vars, loop_vars, check_types=False)
        except ValueError as e:
1316 1317
            raise ValueError(
                "body in while_loop should return the same arity "
1318 1319
                "(length and structure) as loop_vars: {0}".format(e)
            )
1320
        now_cond = cond(*output_vars)
1321
        map_structure(assign_skip_lod_tensor_array, output_vars, loop_vars)
G
guofei 已提交
1322 1323 1324 1325
        assign(now_cond, pre_cond)
    return loop_vars


# (TODO: Mine) There exists a dependency here. It will be removed later.
def _deal_with_undefined_var(output_vars, loop_vars):
    """Deal with undefined var cases. We create undefined variables based on the results of body().
    In Dy2Static, we use undefined vars to represent the vars created in control flow. This function
    expands the loop_vars and replaces the original loop_vars.
    1. UndefinedVar = Variable      # create a variable
    2. UndefinedVar = None          # create an undefined var with RETURN_NO_VALUE_MAGIC_NUM
    3. UndefinedVar = List(int)     # create a list of variables
    4. UndefinedVar = value         # create a variable
    """
    from paddle.jit.dy2static.utils import (
        UndefinedVar,
        create_undefined_variable,
    )

    def create_var_like(o_var):
        if (
            isinstance(o_var, (Variable,) + support_ret_buildin_type)
            or o_var is None
        ):
            return create_undefined_variable()
        if is_sequence(o_var):
            """
            Create a complex container class inside the body of while, including Python list and Python dict.
            """
            return map_structure(lambda x: create_undefined_variable(), o_var)

    if len(output_vars) != len(loop_vars):
        raise ValueError("The length of loop_vars should be the same.")

    results = []
    for o_var, l_var in zip(output_vars, loop_vars):
        if isinstance(l_var, UndefinedVar) or l_var is None:
            results.append(create_var_like(o_var))
        else:
            results.append(l_var)
    return results


def array_write(x, i, array=None):
    """
    This OP writes the input ``x`` into the i-th position of the ``array``
    :ref:`api_fluid_LoDTensorArray` and returns the modified array.
    If ``array`` is None, a new LoDTensorArray will be created and returned.
    This OP is often used together with :ref:`api_fluid_layers_array_read` OP.

    Args:
        x (Variable): The input data to be written into array. It's multi-dimensional
            Tensor or LoDTensor. Data type: float32, float64, int32, int64.
        i (Variable): 1-D Tensor with shape [1], which represents the position into which
            ``x`` is written. Data type: int64.
        array (LoDTensorArray, optional): The LoDTensorArray into which ``x`` is written.
            The default value is None, in which case a new LoDTensorArray will be created
            and returned as a result.

    Returns:
        Variable: The input ``array`` after ``x`` is written into.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid
            tmp = fluid.layers.fill_constant(shape=[3, 2], dtype='int64', value=5)
            i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=10)
            # Write tmp into the position of arr with subscript 10 and return arr.
            arr = fluid.layers.array_write(tmp, i=i)

            # Now, arr is a LoDTensorArray with length 11. We can use array_read OP to read
            # the data at subscript 10 and print it out.
            item = fluid.layers.array_read(arr, i=i)
            input = fluid.layers.Print(item, message="The content of i-th LoDTensor:")
            main_program = fluid.default_main_program()
            exe = fluid.Executor(fluid.CPUPlace())
            exe.run(main_program)

            # The printed result is:
            # 1570533133    The content of i-th LoDTensor:  The place is:CPUPlace
            # Tensor[array_read_0.tmp_0]
            #    shape: [3,2,]
            #    dtype: l
            #    data: 5,5,5,5,5,5,

            # the output is 2-D Tensor with shape [3,2], which is tmp above.
            # dtype is the corresponding C++ data type, which may vary in different environments.
            # Eg: if the data type of tensor is int64, then the corresponding C++ data type is int64_t,
            #       so the dtype value is typeid(int64_t).Name(), which is 'x' on MacOS, 'l' on Linux,
            #       and '__int64' on Windows. They all represent 64-bit integer variables.

    """
    if _non_static_mode():
        assert isinstance(
            x, Variable
        ), "The input data 'x' in array_write must be Variable in dygraph mode"
        assert isinstance(
            i, Variable
        ), "The index 'i' in array_write must be Variable in dygraph mode"
        assert i.shape == [
            1
        ], "The shape of index 'i' should be [1] in dygraph mode"
        i = i.numpy().item(0)
        if array is None:
            array = paddle.tensor.create_array(x.dtype)
        assert isinstance(
            array, list
        ), "The 'array' in array_write must be a list in dygraph mode"
        assert i <= len(
            array
        ), "The index 'i' should not be greater than the length of 'array' in dygraph mode"
        if i < len(array):
            array[i] = x
        else:
            array.append(x)
        return array

    check_variable_and_dtype(i, 'i', ['int64'], 'array_write')
    check_type(x, 'x', (Variable), 'array_write')
    helper = LayerHelper('array_write', **locals())
    if array is not None:
        if (
            not isinstance(array, Variable)
            or array.type != core.VarDesc.VarType.LOD_TENSOR_ARRAY
        ):
            raise TypeError(
                "array should be tensor array variable in array_write Op"
            )
    if array is None:
        array = helper.create_variable(
            name="{0}.out".format(helper.name),
            type=core.VarDesc.VarType.LOD_TENSOR_ARRAY,
            dtype=x.dtype,
        )
    helper.append_op(
        type='write_to_array',
        inputs={'X': [x], 'I': [i]},
        outputs={'Out': [array]},
    )
    return array


def array_read(array, i):
    """
    This OP is used to read data at the specified position from the input array
    :ref:`api_fluid_LoDTensorArray` . ``array`` is the input array and ``i``
    is the specified read position. This OP is often used together with
    :ref:`api_fluid_layers_array_write` OP.

    Case 1:
    ::
        Input:
            The shape of first three tensors are [1], and that of the last one is [1,2]:
                array = ([0.6], [0.1], [0.3], [0.4, 0.2])
            And:
                i = [3]

        Output:
            output = [0.4, 0.2]

    Args:
        array (LoDTensorArray): The input LoDTensorArray.
        i (Variable): 1-D Tensor, whose shape is [1] and dtype is int64. It represents the
            specified read position of ``array``.

    Returns:
        Variable: The LoDTensor or Tensor that is read at the specified position of ``array``.

    Examples:
        .. code-block:: python

            # First we're going to create a LoDTensorArray, then we're going to write the Tensor into
            # the specified position, and finally we're going to read the Tensor at that position.
            import paddle.fluid as fluid
            arr = fluid.layers.create_array(dtype='float32')
            tmp = fluid.layers.fill_constant(shape=[3, 2], dtype='int64', value=5)
            i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=10)
            # tmp is the Tensor with shape [3,2], and if we write it into the position with subscript 10
            # of the empty-array: arr, then the length of arr becomes 11.
            arr = fluid.layers.array_write(tmp, i, array=arr)
            # Read the data of the position with subscript 10.
            item = fluid.layers.array_read(arr, i)

            # You can print out the data via executor.
            input = fluid.layers.Print(item, message="The LoDTensor of the i-th position:")
            main_program = fluid.default_main_program()
            exe = fluid.Executor(fluid.CPUPlace())
            exe.run(main_program)

            # The printed result is:

            # 1569588169  The LoDTensor of the i-th position: The place is:CPUPlace
            # Tensor[array_read_0.tmp_0]
            #    shape: [3,2,]
            #    dtype: l
            #    data: 5,5,5,5,5,5,

            # the output is 2-D Tensor with shape [3,2].
            # dtype is the corresponding C++ data type, which may vary in different environments.
            # Eg: if the data type of tensor is int64, then the corresponding C++ data type is int64_t,
            #       so the dtype value is typeid(int64_t).Name(), which is 'x' on MacOS, 'l' on Linux,
            #       and '__int64' on Windows. They all represent 64-bit integer variables.
    """
    if _non_static_mode():
        assert isinstance(
            array, list
        ), "The 'array' in array_read must be list in dygraph mode"
        assert isinstance(
            i, Variable
        ), "The index 'i' in array_read must be Variable in dygraph mode"
        assert i.shape == [
            1
        ], "The shape of index 'i' should be [1] in dygraph mode"
        i = i.numpy().item(0)
        return array[i]

    check_variable_and_dtype(i, 'i', ['int64'], 'array_read')
    helper = LayerHelper('array_read', **locals())
    if (
        not isinstance(array, Variable)
        or array.type != core.VarDesc.VarType.LOD_TENSOR_ARRAY
    ):
        raise TypeError("array should be tensor array variable")
    out = helper.create_variable_for_type_inference(dtype=array.dtype)
    helper.append_op(
        type='read_from_array',
        inputs={'X': [array], 'I': [i]},
        outputs={'Out': [out]},
    )
    return out


class ConditionalBlockGuard(BlockGuard):
    """
    ConditionalBlockGuard is derived from BlockGuard. It is dedicated to
    holding a ConditionalBlock and helping users enter and exit the
    ConditionalBlock via Python's 'with' keyword. However, ConditionalBlockGuard
    is generally an internal component of IfElse, and users should not use it directly.
    """

    def __init__(self, block):
        check_type(block, "block", ConditionalBlock, "ConditionalBlockGuard")
        super().__init__(block.helper.main_program)
        self.block = block

    def __enter__(self):
        return super().__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.block.complete()
        return super().__exit__(exc_type, exc_val, exc_tb)


class ConditionalBlock:
    '''
    **ConditionalBlock**

    ConditionalBlock is an operator that binds a block to a specific condition.
    If the condition matches, the corresponding block will be executed.

    Args:
        inputs (Variable): bool conditions.
        is_scalar_condition (bool): whether the branch is controlled by a scalar.
        name(str): name of this ConditionalBlock.

    Examples:
        .. code-block:: python

             import paddle
             import paddle.fluid as fluid
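             # label, limit and image are assumed to be tensors defined earlier,
             # and `layers` is assumed to refer to paddle.fluid.layers.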
             cond = paddle.less_than(x=label, y=limit)
             true_image, false_image = layers.split_lod_tensor(
                 input=image, mask=cond)
             true_cond = layers.ConditionalBlock([true_image])
             false_cond = layers.ConditionalBlock([false_image])

             with true_cond.block():
                 ...
             with false_cond.block():
                 ...
    '''

    def __init__(self, inputs, is_scalar_condition=False, name=None):
        for each_input in inputs:
            check_type(each_input, "input", Variable, "ConditionalBlock")
        self.inputs = inputs
        self.is_scalar_condition = is_scalar_condition
        self.helper = LayerHelper('conditional_block', name=name)

    def block(self):
        return ConditionalBlockGuard(self)

    def complete(self):
        inside_block = self.helper.main_program.current_block()
        parent_block = self.helper.main_program.block(inside_block.parent_idx)

        intermediate = set()
        params = set()
        params, intermediate = get_inputs_outputs_in_block(
            inside_block, params, intermediate, helper=self.helper
        )

        # Todo(liym27): Here we assume that all params are in the recursive parent block,
        # but when minimize() is called in control flow, some params may be in the
        # conditional grad block.
        param_list = [
            parent_block._var_recursive(each_name) for each_name in params
        ]

        out_list = []
        for inner_out_name in intermediate:
            inner_var = parent_block._find_var_recursive(inner_out_name)
            if inner_var:
                out_list.append(inner_var)

        step_scope = parent_block.create_var(
            type=core.VarDesc.VarType.STEP_SCOPES
        )
        conditional_block_op = parent_block.append_op(
            type='conditional_block',
            inputs={
                'Cond': self.inputs,
                'Input': param_list,
            },
            outputs={'Out': out_list, 'Scope': [step_scope]},
            attrs={
                'sub_block': inside_block,
                'is_scalar_condition': self.is_scalar_condition,
            },
        )

        if self.need_append_conditional_block_grad(inside_block):
            self.append_conditional_block_grad(
                parent_block, inside_block, conditional_block_op
            )

    def need_append_conditional_block_grad(self, inside_block):
        grad_sub_block_idx = inside_block.backward_block_idx
        inside_block_idx = inside_block.idx

        # If inside_block has a grad_block and the grad_block is not itself,
        # we will append the conditional block grad.
        return (
            grad_sub_block_idx != -1 and grad_sub_block_idx != inside_block_idx
        )

    def append_conditional_block_grad(
        self, parent_block, inside_block, conditional_block_op
    ):
        '''
        Append op `conditional_block_grad` manually.
        When `optimizer.minimize/append_backward` is called in Paddle control flow,
        grad ops will be appended before appending op `conditional_block` so that
        op `conditional_block_grad` can't be appended when calling
        `optimizer.minimize/append_backward`. After appending op `conditional_block`,
        `conditional_block_grad` is appended manually.

        Args:
            parent_block (Block): The block that `conditional_block_op` belongs to.
            inside_block (Block): The sub block of `conditional_block_op`.
            conditional_block_op (Operator): The forward op conditional_block.
        '''

        grad_sub_block_idx = inside_block.backward_block_idx
        grad_sub_block = self.helper.main_program.block(grad_sub_block_idx)

        intermediate = set()
        params = set()

        for each_op in grad_sub_block.ops:
            assert isinstance(each_op, Operator)
            for iname in each_op.input_names:
                for in_var_name in each_op.input(iname):
                    if in_var_name not in intermediate:
                        params.add(in_var_name)

            for oname in each_op.output_names:
                for out_var_name in each_op.output(oname):
                    intermediate.add(out_var_name)

        param_list = []
        for inner_input_name in params:
            inner_var = parent_block._find_var_recursive(inner_input_name)
            if inner_var:
                param_list.append(inner_var.name)

        grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
            conditional_block_op.desc, set(), [grad_sub_block.desc]
        )

        # append op_desc in grad_op_descs to target_block
        op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
        backward = core.op_proto_and_checker_maker.OpRole.Backward
        new_op_desc = parent_block.desc.append_op()
        new_op_desc.copy_from(grad_op_desc[0])
        new_op_desc._set_attr(op_role_attr_name, backward)
        # set input and output manually
        new_op_desc.set_input('Input', param_list)
        new_op_desc.set_output(
            'Input@GRAD', [param + "@GRAD" for param in param_list]
        )

        new_vars = set()
        for grad_var_name in new_op_desc.output_arg_names():
            if (
                grad_sub_block.desc.has_var_recursive(grad_var_name.encode())
                or grad_var_name == core.empty_var_name()
            ):
                continue
            grad_sub_block.desc.var(grad_var_name.encode())
            new_vars.add(grad_var_name)
            if grad_var_name not in op_grad_to_var:
                continue

        # infer_shape and infer_type
        new_op_desc.infer_var_type(grad_sub_block.desc)
        new_op_desc.infer_shape(grad_sub_block.desc)

        for arg in new_op_desc.output_arg_names():
            if arg in new_vars:
                _infer_var_data_type_shape_(arg, grad_sub_block)

        self.helper.main_program._sync_with_cpp()


def _to_sequence_except_dict(x):
    """
    In this function, dict is not viewed as sequence.
    """
    if isinstance(x, dict):
        return [x]
    return to_sequence(x)


def _is_sequence_except_dict(x):
    """
    In this function, dict is not viewed as sequence.
    """
    if isinstance(x, dict):
        return False
    return is_sequence(x)


def expand_undefined_var(nest1, nest2, names):
    """TODO: make this function recursive.
    nest1: Var1, (UndefinedVar, [1,2,3])
    nest2: Var2, ([1,2,3,4], UndefinedVar)
    In this case, we should not expand recursively.
    """
    from paddle.jit.dy2static.utils import UndefinedVar
    from paddle.jit.dy2static.return_transformer import (
        RETURN_VALUE_PREFIX,
    )

    def pack_undefined_var_as(seq):
        return pack_sequence_as(
            seq, [UndefinedVar("padding") for i in flatten(seq)]
        )

    def map_fn(n1, n2, name, order):
        if not name.startswith(RETURN_VALUE_PREFIX) and (
            isinstance(n1, UndefinedVar) or n1 is None
        ):
            if n1 is None and n2 is not None:
                if order == 0:
                    warnings.warn(
                        "In cond : Var '{}' or part of it is set differently in ifelse branches, "
                        "<{}, {}> in true branch and <{}, {}> in false branch. Setting the var to "
                        "'None' in the ifelse block might lead to an error.".format(
                            name, type(n1), n1, type(n2), n2
                        )
                    )
                else:
                    warnings.warn(
                        "In cond : Var '{}' or part of it is set differently in ifelse branches, "
                        "<{}, {}> in true branch and <{}, {}> in false branch. Setting the var to "
                        "'None' in the ifelse block might lead to an error.".format(
                            name, type(n2), n2, type(n1), n1
                        )
                    )
            return pack_undefined_var_as(n2)
        return n1

    nest1_out = list(
        map(
            map_fn,
            _to_sequence_except_dict(nest1),
            _to_sequence_except_dict(nest2),
            _to_sequence_except_dict(names),
            [0 for i in _to_sequence_except_dict(names)],
        )
    )
    nest2_out = list(
        map(
            map_fn,
            _to_sequence_except_dict(nest2),
            _to_sequence_except_dict(nest1),
            _to_sequence_except_dict(names),
            [1 for i in _to_sequence_except_dict(names)],
        )
    )
    if not _is_sequence_except_dict(nest1):
        nest1_out = nest1_out[0]
    if not _is_sequence_except_dict(nest2):
        nest2_out = nest2_out[0]
    return nest1_out, nest2_out


class Switch:
    """
    :api_attr: Static Graph

    This class is used to implement the Switch branch control function.
    A Switch branch contains several case branches and one default branch.
    Switch control flow checks whether the case branch conditions are satisfied in turn,
    and only executes the statement after the first case branch that satisfies the conditions.
    If there is no case branch that satisfies the condition,
    only the statement following the default branch is executed.

    Note:
        A new OP :ref:`api_fluid_layers_case` is highly recommended instead of ``Switch`` if the shape of parameter ``cond`` is [1].
        OP :ref:`api_fluid_layers_case` is easier to use and is called with less code but does the same thing as ``Switch`` .

    Member Functions:
        case(condition): The case branch of Switch, whose parameter condition is a scalar Variable of bool type. Only if the condition of the current case branch is True and the conditions of all previous case branches are False, the statements after this case branch will be executed; otherwise they will be skipped.

        default(): The default branch of Switch. When the conditions of all case branches are False, the statements after the default branch are executed.

    Case and default functions can only be used inside the scope of Switch, as shown below:

    .. code-block:: python

        '''
        with fluid.layers.Switch() as switch:
            with switch.case(cond1):
                i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=1)
            with switch.case(cond2):
                i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=2)
            with switch.default():
                i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0)
        '''

    Args:
        name(str, optional): The default value is None. Normally there is no need for the user to set this property. For more information, please refer to :ref:`api_guide_Name` .

    Examples:
        .. code-block:: python

            import paddle
            import paddle.fluid as fluid

            lr = paddle.static.create_global_var(
                shape=[1],
                value=0.0,
                dtype='float32',
                persistable=True,
                name="learning_rate")
            zero_var = fluid.layers.fill_constant(
                shape=[1], dtype='float32', value=0.0)
            one_var = fluid.layers.fill_constant(
                shape=[1], dtype='float32', value=1.0)
            two_var = fluid.layers.fill_constant(
                shape=[1], dtype='float32', value=2.0)

            global_step = fluid.layers.autoincreased_step_counter(counter_name='@LR_DECAY_COUNTER@', begin=0, step=1)

            with fluid.layers.control_flow.Switch() as switch:
                with switch.case(global_step == zero_var):
                    fluid.layers.assign(input=one_var, output=lr)
                with switch.default():
                    fluid.layers.assign(input=two_var, output=lr)

            exe = fluid.Executor(fluid.CPUPlace())
            exe.run(fluid.default_startup_program())

            res = exe.run(fluid.default_main_program(), feed={}, fetch_list=[lr])
            print(res) # [array([1.], dtype=float32)]
    """

    def __init__(self, name=None):
        self.helper = LayerHelper('switch', name=name)
        self.inside_scope = False
        self.pre_not_conditions = []

    def case(self, condition):
        if not self.inside_scope:
            raise ValueError("case should be called inside with")

        check_variable_and_dtype(
            condition,
            'condition',
            ['bool'],
            'the member function case of fluid.layers.Switch',
        )

        if len(self.pre_not_conditions) == 0:
            cond_block = ConditionalBlock([condition], is_scalar_condition=True)
            not_cond = paddle.logical_not(x=condition)
            self.pre_not_conditions.append(not_cond)
        else:
            pre_cond_num = len(self.pre_not_conditions)
            pre_not_cond = self.pre_not_conditions[pre_cond_num - 1]
            new_not_cond = paddle.logical_and(
                x=pre_not_cond, y=paddle.logical_not(x=condition)
            )
            self.pre_not_conditions.append(new_not_cond)
            cond_block = ConditionalBlock(
                [paddle.logical_and(x=pre_not_cond, y=condition)],
                is_scalar_condition=True,
            )

        return ConditionalBlockGuard(cond_block)

    def default(self):
        pre_cond_num = len(self.pre_not_conditions)
        if pre_cond_num == 0:
            raise ValueError("there should be at least one condition")
        cond_block = ConditionalBlock(
            [self.pre_not_conditions[pre_cond_num - 1]],
            is_scalar_condition=True,
        )
        return ConditionalBlockGuard(cond_block)

    def __enter__(self):
        """
        Set the flag indicating that the program is now inside the switch block.
        :return:
        """
        self.inside_scope = True
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.inside_scope = False
        if exc_type is not None:
            return False  # re-raise exception

        return True