control_flow.py 64.8 KB
Newer Older
1
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
D
dzhwinter 已提交
2
#
D
dzhwinter 已提交
3 4 5
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
D
dzhwinter 已提交
6
#
D
dzhwinter 已提交
7
#     http://www.apache.org/licenses/LICENSE-2.0
D
dzhwinter 已提交
8
#
D
dzhwinter 已提交
9 10 11 12 13
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14

S
rename  
sneaxiy 已提交
15
from ..wrapped_decorator import signature_safe_contextmanager
D
dzhwinter 已提交
16

17
from .layer_function_generator import templatedoc
18
from .tensor import assign, cast, fill_constant
19
from .. import core
20 21 22 23 24 25 26 27 28
from ..framework import (
    Program,
    Variable,
    Operator,
    _non_static_mode,
    static_only,
    _in_legacy_dygraph,
    in_dygraph_mode,
)
29
from ..layer_helper import LayerHelper, unique_name
30 31 32 33 34 35 36 37 38 39 40
from .utils import (
    assert_same_structure,
    map_structure,
    hold_mutable_vars,
    copy_mutable_vars,
    padding_to_same_structure,
    is_sequence,
    pack_sequence_as,
    flatten,
    to_sequence,
)
Y
yuyang18 已提交
41
import numpy
42
import warnings
L
liym27 已提交
43
from functools import reduce, partial
44 45 46 47 48 49
from ..data_feeder import (
    convert_dtype,
    check_variable_and_dtype,
    check_type,
    check_dtype,
)
50
from ..backward import _infer_var_data_type_shape_
2
201716010711 已提交
51
import paddle
52
from paddle import _C_ops, _legacy_C_ops
D
dzhwinter 已提交
53

Q
QI JUN 已提交
54
__all__ = [
55 56 57 58
    'Switch',
    'StaticRNN',
    'Print',
    'while_loop',
D
dzhwinter 已提交
59 60
]

Y
Yu Yang 已提交
61

62 63
def select_output(input, outputs, mask):
    """
64
    **select_output**
65 66 67 68 69 70 71 72 73 74 75 76 77 78
    This API takes in one input and multiple outputs and an integer mask. It
    selects the output specified by the mask and copy the input to selected
    output. It is useful in control flow.

    Args:
        input(Variable): The input variable
        outputs(tuple|list): The output variables
        mask(Variable): A tensor containing 1 integer number selecting which
            output to be copied with input

    Returns:
        Variable: The outputs variables
    """
    helper = LayerHelper('select_output', **locals())
79 80 81 82
    check_type(input, 'input', (Variable), 'select_output')
    check_variable_and_dtype(mask, 'mask', ['int32'], 'select_output')
    check_type(outputs, 'outputs', (list, tuple), 'select_output')

83 84 85 86 87
    helper.append_op(
        type='select_output',
        inputs={'X': input, 'Mask': mask},
        outputs={'Out': outputs},
    )
88 89 90
    return outputs


91 92 93 94 95 96 97
def _select_input_infer_shape(first_shape, second_shape):
    """
    This function infer the output shape by following algorithm:
    1. if the dims is different, raise a error.
    2. compare axis one by one:
        if a == b: we set axis to a
        if a != b: we set axis to -1
98
    for compatibility, non declarative mode, we just return second_shape.
99 100 101 102 103 104 105
    """
    if len(first_shape) != len(second_shape):
        warnings.warn(
            f"the input shapes of select_input should have the same rank, but get {first_shape}, {second_shape}"
        )
        return second_shape
    out_shape = list(
106 107
        map(lambda a, b: a if a == b else -1, first_shape, second_shape)
    )
108 109 110
    return out_shape


111 112 113
def select_input(inputs, mask):
    """
    **select_input**
114

115 116 117 118 119 120 121 122 123 124 125 126
    This API takes in multiple inputs and uses an integer mask to select one
    input to output. It is useful in control flow.

    Args:
        inputs(tuple|list): The input variables
        mask(Variable): A tensor containing 1 integer number selecting which
            input to output

    Returns:
        Variable: The selected input variable
    """
    helper = LayerHelper('select_input', **locals())
127 128 129
    check_type(inputs, 'inputs', (list, tuple), 'select_input')
    check_variable_and_dtype(mask, 'mask', ['int32'], 'select_input')

130
    # Select input should expand the shape. If it is - 1 and valid number, use - 1 first. If the dim is different, an error will be reported directly
131
    # assert inputs[0].dtype == inputs[1].dtype, f"Expect the inputs should have the same dtype, but get {inputs[0].dtype} and {inputs[1].dtype}"
132

133 134 135
    output_shape = _select_input_infer_shape(inputs[0].shape, inputs[1].shape)
    output_dtype = inputs[1].dtype
    output_type = inputs[1].type
136

137 138 139 140 141 142 143 144
    out = helper.create_variable(
        dtype=output_dtype, shape=output_shape, type=output_type
    )
    helper.append_op(
        type='select_input',
        inputs={'X': inputs, 'Mask': mask},
        outputs={'Out': out},
    )
145 146 147
    return out


148
def split_lod_tensor(input, mask, level=0):
149 150 151 152
    """
    This function takes in an input that contains the complete lod information,
    and takes in a mask which is used to mask certain parts of the input.
    The output is the true branch and the false branch with the mask applied to
Q
qiaolongfei 已提交
153 154
    the input at a certain level in the tensor. Mainly used in IfElse to split
    data into two parts.
155 156

    Args:
157
        input(Variable|tuple|list|None): The input tensor that contains complete
158
                                lod information needed to construct the output.
159
        mask(Variable|list): A bool column vector which masks the input.
Q
qiaolongfei 已提交
160
        level(int): The specific lod level to split.
161 162

    Returns:
Q
qiaolongfei 已提交
163 164 165 166
        tuple(Variable, Variable):
        The true branch of tensor as per the mask applied to input.

        The false branch of tensor as per the mask applied to input.
167 168 169 170

    Examples:
        .. code-block:: python

171
          import paddle.fluid as fluid
Q
qiaolongfei 已提交
172
          x = fluid.layers.data(name='x', shape=[1])
173 174
          x.persistable = True

Q
qiaolongfei 已提交
175
          y = fluid.layers.data(name='y', shape=[1])
176 177
          y.persistable = True

Q
qiaolongfei 已提交
178
          out_true, out_false = fluid.layers.split_lod_tensor(
179
                input=x, mask=y, level=level)
180

181
    """
182 183 184 185 186 187
    check_type(
        input,
        'input',
        (Variable, list, tuple, type(None)),
        'fluid.layers.split_lod_tensor',
    )
188 189
    check_type(mask, 'mask', (Variable, list), 'fluid.layers.split_lod_tensor')
    check_type(level, 'level', int, 'fluid.layers.split_lod_tensor')
190
    helper = LayerHelper('split_lod_tensor', **locals())
X
Xin Pan 已提交
191 192
    out_true = helper.create_variable_for_type_inference(dtype=input.dtype)
    out_false = helper.create_variable_for_type_inference(dtype=input.dtype)
193 194 195 196 197 198 199 200 201
    helper.append_op(
        type='split_lod_tensor',
        inputs={
            'X': input,
            'Mask': mask,
        },
        outputs={'OutTrue': out_true, 'OutFalse': out_false},
        attrs={'level': level},
    )
202 203 204
    return out_true, out_false


205
def merge_lod_tensor(in_true, in_false, x, mask, level=0):
206 207 208 209 210
    """
    **merge_lod_tensor**

    This function takes in an input :math:`x`, the True branch, the False
    branch and a binary :math:`mask`. Using this information, this function
Q
qiaolongfei 已提交
211 212 213
    merges the True and False branches of the tensor into a single tensor as
    output at a certain lod level indicated by :math:`level`. Used in IfElse
    to merge the output if True block and False Block.
214 215

    Args:
216 217 218
        in_true(Variable|tuple|list|None): The True branch to be merged.
        in_false(Variable|tuple|list|None): The False branch to be merged.
        x(Variable|tuple|list|None): The input tensor that contains complete
219
                            lod information needed to construct the output.
220
        mask(Variable|list): A bool column vector which masks the input.
Q
qiaolongfei 已提交
221
        level(int): The specific lod level to merge.
222 223 224 225 226 227 228

    Returns:
        Variable: The merged output tensor.

    Examples:
        .. code-block:: python

229
          import paddle.fluid as fluid
230 231 232 233 234 235 236 237 238 239 240 241
          x = layers.data(
                      name='x', shape=[1], dtype='float32', stop_gradient=False)
          y = layers.data(
                name='y', shape=[1], dtype='bool', stop_gradient=False)

          level = 0

          out_true, out_false = layers.split_lod_tensor(
                input=x, mask=y, level=level)
          out = layers.merge_lod_tensor(
                in_true=out_true, in_false=out_false, mask=y, x=x, level=level)
    """
242
    helper = LayerHelper('merge_lod_tensor', **locals())
243 244 245 246 247 248
    check_type(
        x,
        'x',
        (Variable, list, tuple, type(None)),
        'fluid.layers.merge_lod_tensor',
    )
249
    check_type(mask, 'mask', (Variable, list), 'fluid.layers.merge_lod_tensor')
250 251 252 253 254 255 256 257 258 259 260 261
    check_type(
        in_true,
        'in_true',
        (Variable, list, tuple, type(None)),
        'fluid.layers.merge_lod_tensor',
    )
    check_type(
        in_false,
        'in_false',
        (Variable, list, tuple, type(None)),
        'fluid.layers.merge_lod_tensor',
    )
X
Xin Pan 已提交
262
    out = helper.create_variable_for_type_inference(dtype=in_true.dtype)
263 264 265 266 267 268
    helper.append_op(
        type='merge_lod_tensor',
        inputs={'X': x, 'Mask': mask, 'InTrue': in_true, 'InFalse': in_false},
        outputs={'Out': out},
        attrs={'level': level},
    )
269 270 271
    return out


272
@static_only
273 274 275 276 277 278 279 280 281 282 283 284
def Print(
    input,
    first_n=-1,
    message=None,
    summarize=20,
    print_tensor_name=True,
    print_tensor_type=True,
    print_tensor_shape=True,
    print_tensor_layout=True,
    print_tensor_lod=True,
    print_phase='both',
):
Y
Yan Chunwei 已提交
285
    '''
286 287
    :api_attr: Static Graph

Y
Yan Chunwei 已提交
288 289 290 291 292 293 294 295 296
    **Print operator**

    This creates a print op that will print when a tensor is accessed.

    Wraps the tensor passed in so that whenever that a tensor is accessed,
    the message `message` is printed, along with the current value of the
    tensor `t`.

    Args:
Y
yangyaming 已提交
297
        input (Variable): A Tensor to print.
298
        summarize (int): Number of elements in the tensor to be print. If it's
T
tianshuo78520a 已提交
299
                value is -1, then all elements in the tensor will be print.
Y
yangyaming 已提交
300 301
        message (str): A string message to print as a prefix.
        first_n (int): Only log `first_n` number of times.
302 303 304
        print_tensor_name (bool, optional): Print the tensor name. Default: True.
        print_tensor_type (bool, optional): Print the tensor type. Defaultt: True.
        print_tensor_shape (bool, optional): Print the tensor shape. Default: True.
305
        print_tensor_layout (bool, optional): Print the tensor layout. Default: True.
306
        print_tensor_lod (bool, optional): Print the tensor lod. Default: True.
307
        print_phase (str): Which phase to displace, including 'forward',
308
                'backward' and 'both'. Default: 'both'. If set to 'backward', will
309 310
                only print the gradients of input tensor; If set to 'both', will
                both print the input tensor itself and the gradients of input tensor.
Y
Yan Chunwei 已提交
311 312

    Returns:
313
        Variable: Output tensor.
Y
Yan Chunwei 已提交
314

315 316 317 318
    NOTES:
        The input and output are two different variables, and in the
        following process, you should use the output variable but not the input,
        otherwise, the print layer doesn't have backward.
Y
Yan Chunwei 已提交
319

Y
Yan Chunwei 已提交
320 321
    Examples:
        .. code-block:: python
322

323 324 325
           import paddle

           paddle.enable_static()
326

327 328 329 330 331 332 333 334 335 336 337 338 339 340
           x = paddle.full(shape=[2, 3], fill_value=3, dtype='int64')
           out = paddle.static.Print(x, message="The content of input layer:")

           main_program = paddle.static.default_main_program()
           exe = paddle.static.Executor(place=paddle.CPUPlace())
           res = exe.run(main_program, fetch_list=[out])
           # Variable: fill_constant_1.tmp_0
           #   - message: The content of input layer:
           #   - lod: {}
           #   - place: CPUPlace
           #   - shape: [2, 3]
           #   - layout: NCHW
           #   - dtype: long
           #   - data: [3 3 3 3 3 3]
Y
Yan Chunwei 已提交
341
    '''
342 343 344 345 346 347
    check_variable_and_dtype(
        input,
        'input',
        ['float32', 'float64', 'int32', 'int64', 'bool'],
        'fluid.layers.Print',
    )
348

349 350
    helper = LayerHelper('print' + "_" + input.name, **locals())
    output = helper.create_variable_for_type_inference(input.dtype)
351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366
    helper.append_op(
        type='print',
        inputs={'In': input},
        outputs={'Out': output},
        attrs={
            'first_n': first_n,
            'summarize': summarize,
            'message': message or "",
            'print_tensor_name': print_tensor_name,
            'print_tensor_type': print_tensor_type,
            'print_tensor_shape': print_tensor_shape,
            'print_tensor_layout': print_tensor_layout,
            'print_tensor_lod': print_tensor_lod,
            'print_phase': print_phase.upper(),
        },
    )
367
    return output
Y
Yan Chunwei 已提交
368 369


370
# (TODO: Mine) There exists dependency. It will be removed later.
371
class BlockGuard:
Y
Yu Yang 已提交
372
    """
373 374 375 376
    BlockGuard class.

    BlockGuard class is used to create a sub-block in a program by
    using the Python `with` keyword.
Y
Yu Yang 已提交
377 378
    """

379 380
    def __init__(self, main_program):
        if not isinstance(main_program, Program):
Y
Yu Yang 已提交
381
            raise TypeError("BlockGuard takes a program")
382
        self.main_program = main_program
Y
Yu Yang 已提交
383 384

    def __enter__(self):
W
Wu Yi 已提交
385
        self.main_program._create_block()
Y
Yu Yang 已提交
386 387

    def __exit__(self, exc_type, exc_val, exc_tb):
W
Wu Yi 已提交
388
        self.main_program._rollback()
Y
Yu Yang 已提交
389 390 391 392 393
        if exc_type is not None:
            return False  # re-raise exception
        return True


394
# (TODO: Mine) There exists dependency. It will be removed later.
Y
Yang Yang 已提交
395 396 397 398 399
class BlockGuardWithCompletion(BlockGuard):
    """
    BlockGuardWithCompletion class.

    BlockGuardWithCompletion class is used to create an op with a block in a program.
400 401
    """

Y
Yu Yang 已提交
402
    def __init__(self, rnn):
X
Xin Pan 已提交
403
        if not isinstance(rnn, StaticRNN):
X
Xin Pan 已提交
404
            raise TypeError("BlockGuardWithCompletion takes a StaticRNN")
405
        super().__init__(rnn.helper.main_program)
Y
Yu Yang 已提交
406 407 408 409
        self.rnn = rnn

    def __enter__(self):
        self.rnn.status = StaticRNN.IN_RNN_BLOCK
410
        return super().__enter__()
Y
Yu Yang 已提交
411 412

    def __exit__(self, exc_type, exc_val, exc_tb):
Y
Yu Yang 已提交
413 414
        if exc_type is not None:
            return False
Y
Yu Yang 已提交
415
        self.rnn.status = StaticRNN.AFTER_RNN_BLOCK
416
        self.rnn._complete_op()
417
        return super().__exit__(exc_type, exc_val, exc_tb)
Y
Yu Yang 已提交
418 419


420
class StaticRNNMemoryLink:
Y
Yu Yang 已提交
421
    """
422 423 424 425
    StaticRNNMemoryLink class.

    StaticRNNMemoryLink class is used to create a link between two
    memory cells of a StaticRNN.
Y
yuyang18 已提交
426 427 428 429 430 431 432 433 434


    NOTE: This is a internal data structure of a very low-level API.
    Please use StaticRNN instead.

    Args:
        init(Variable): the initial variable for Memory.
        pre_mem(Variable): the memory variable in previous time step.
        mem(Variable): the memory variable in current time step.
Y
Yu Yang 已提交
435 436 437 438 439 440 441 442
    """

    def __init__(self, init, pre_mem, mem=None):
        self.init = init
        self.pre_mem = pre_mem
        self.mem = mem


443
class StaticRNN:
444
    """
445 446
    :api_attr: Static Graph

447 448
    StaticRNN class.

449 450 451 452 453 454 455
    The StaticRNN can process a batch of sequence data. The first dimension of inputs
    represents sequence length, the length of each input sequence must be equal.
    StaticRNN will unfold sequence into time steps, user needs to define how to process
    each time step during the :code:`with` step.

    Args:
        name (str, optional): Please refer to :ref:`api_guide_Name`, Default None.
C
chengduo 已提交
456 457

    Examples:
458 459
        .. code-block:: python

460
            import paddle
461 462 463 464
            import paddle.fluid as fluid
            import paddle.fluid.layers as layers

            vocab_size, hidden_size=10000, 200
465
            paddle.enable_static()
466 467
            x = fluid.data(name="x", shape=[None, 1, 1], dtype='int64')
            # create word sequence
468 469 470 471 472
            x_emb = layers.embedding(
                input=x,
                size=[vocab_size, hidden_size],
                dtype='float32',
                is_sparse=False)
473
            # transform batch size to dim 1
474
            x_emb = paddle.transpose(x_emb, perm=[1, 0, 2])
475 476 477

            rnn = fluid.layers.StaticRNN()
            with rnn.step():
478
                # mark created x_emb as input, each step process a word
479
                word = rnn.step_input(x_emb)
480
                # create prev memory parameter, batch size comes from word
481 482
                prev = rnn.memory(shape=[-1, hidden_size], batch_ref = word)
                hidden = fluid.layers.fc(input=[word, prev], size=hidden_size, act='relu')
483 484
                # use hidden to update prev
                rnn.update_memory(prev, hidden)
485
                # mark hidden as output
486
                rnn.step_output(hidden)
487
            # get StaticrNN final output
488
            result = rnn()
C
chengduo 已提交
489

490
    """
491

Y
Yu Yang 已提交
492 493 494 495
    BEFORE_RNN_BLOCK = 0
    IN_RNN_BLOCK = 1
    AFTER_RNN_BLOCK = 2

496
    def __init__(self, name=None):
497
        check_type(name, "name", (str, type(None)), "fluid.layers.StaticRNN")
498
        self.helper = LayerHelper("static_rnn", name=name)
Y
Yu Yang 已提交
499 500 501 502 503 504 505 506
        self.memories = {}  # memory map, from pre_mem.name --> MemoryLink
        self.inputs = []  # input variable list in current block
        self.outputs = []  # output variable list in parent block
        self.status = StaticRNN.BEFORE_RNN_BLOCK  # status flag.
        # sequence length, since it is a static RNN, sequence length are fixed.
        self.seq_len = None

    def step(self):
C
chengduo 已提交
507
        """
508 509
        Define operators in each step. step is used in :code:`with` block, OP in :code:`with` block
        will be executed sequence_len times (sequence_len is the length of input)
C
chengduo 已提交
510
        """
Y
Yang Yang 已提交
511
        return BlockGuardWithCompletion(self)
Y
Yu Yang 已提交
512 513 514 515 516

    def _assert_in_rnn_block_(self, method):
        if self.status != StaticRNN.IN_RNN_BLOCK:
            raise ValueError("You must invoke {0} in rnn block".format(method))

517 518 519 520 521 522 523 524 525
    def memory(
        self,
        init=None,
        shape=None,
        batch_ref=None,
        init_value=0.0,
        init_batch_dim_idx=0,
        ref_batch_dim_idx=1,
    ):
526
        """
C
chengduo 已提交
527 528 529
        Create a memory variable for static rnn.
        If the :code:`init` is not None, :code:`memory` will be initialized by
        this Variable. If the :code:`init` is None, :code:`shape` and :code:`batch_ref`
530 531
        must be set, and this function will create a new variable with shape and batch_ref
        to initialize :code:`init` Variable.
C
chengduo 已提交
532

533
        Args:
534
            init(Variable, optional): Tensor used to init memory. If it is not set,
C
chengduo 已提交
535 536
                :code:`shape` and :code:`batch_ref` must be provided.
                Default: None.
537 538 539 540 541 542 543
            shape(list|tuple): When :code:`init` is None use this arg to initialize memory shape.
            NOTE the shape does not contain batch_size. Default: None.
            batch_ref(Variable, optional): When :code:`init` is None, memory's batch size will
            be set as batch_ref's ref_batch_dim_idx value. Default: None.
            init_value(float, optional): When :code:`init` is None, used to init memory's value. Default: 0.0.
            init_batch_dim_idx(int, optional): the batch_size axis of the :code:`init` Variable. Default: 0.
            ref_batch_dim_idx(int, optional): the batch_size axis of the :code:`batch_ref` Variable. Default: 1.
C
chengduo 已提交
544 545

        Returns:
546 547 548 549 550
            Variable: The memory variable.

        Examples 1:
            .. code-block:: python

551
                import paddle
552 553 554 555
                import paddle.fluid as fluid
                import paddle.fluid.layers as layers

                vocab_size, hidden_size=10000, 200
556
                paddle.enable_static()
557 558 559 560 561 562 563 564
                x = fluid.data(name="x", shape=[None, 1, 1], dtype='int64')
                # create word sequence
                x_emb = layers.embedding(
                        input=x,
                        size=[vocab_size, hidden_size],
                        dtype='float32',
                        is_sparse=False)
                # transform batch size to dim 1
565
                x_emb = paddle.transpose(x_emb, perm=[1, 0, 2])
566 567 568 569 570 571 572 573 574 575

                rnn = fluid.layers.StaticRNN()
                with rnn.step():
                        # mark created x_emb as input, each step process a word
                        word = rnn.step_input(x_emb)
                        # create prev memory parameter, batch size comes from word
                        prev = rnn.memory(shape=[-1, hidden_size], batch_ref = word)
                        hidden = fluid.layers.fc(input=[word, prev], size=hidden_size, act='relu')
                        # use hidden to update prev
                        rnn.update_memory(prev, hidden)
576 577 578


        Examples 2:
579 580
            .. code-block:: python

581
                import paddle
582 583 584
                import paddle.fluid as fluid
                import paddle.fluid.layers as layers
                vocab_size, hidden_size=10000, 200
585
                paddle.enable_static()
586 587 588 589 590 591 592 593
                x = fluid.data(name="x", shape=[None, 1, 1], dtype='int64')
                # create word sequence
                x_emb = layers.embedding(
                        input=x,
                        size=[vocab_size, hidden_size],
                        dtype='float32',
                        is_sparse=False)
                # transform batch size to dim 1
594
                x_emb = paddle.transpose(x_emb, perm=[1, 0, 2])
595 596 597 598 599 600 601 602 603 604
                boot_memory = fluid.layers.data(name='boot', shape=[hidden_size], dtype='float32', lod_level=1)
                rnn = fluid.layers.StaticRNN()
                with rnn.step():
                        # mark created x_emb as input, each step process a word
                        word = rnn.step_input(x_emb)
                        # init memory
                        prev = rnn.memory(init=boot_memory)
                        hidden = fluid.layers.fc(input=[word, prev], size=hidden_size, act='relu')
                        # update hidden with prev
                        rnn.update_memory(prev, hidden)
605

606
        """
Y
Yu Yang 已提交
607
        self._assert_in_rnn_block_('memory')
608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625
        check_type(
            init,
            "init",
            (Variable, type(None)),
            "fluid.layers.StaticRNN.memory",
        )
        check_type(
            shape,
            "shape",
            (list, tuple, type(None)),
            "fluid.layers.StaticRNN.memory",
        )
        check_type(
            batch_ref,
            "batch_ref",
            (Variable, type(None)),
            "fluid.layers.StaticRNN.memory",
        )
Y
Yu Yang 已提交
626
        if init is None:
627
            if shape is None or batch_ref is None:
Y
Yu Yang 已提交
628
                raise ValueError(
629 630
                    "if init is None, memory at least need shape and batch_ref"
                )
631
            parent_block = self._parent_block()
632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653
            var_name = unique_name.generate_with_ignorable_key(
                "@".join([self.helper.name, "memory_boot"])
            )
            boot_var = parent_block.create_var(
                name=var_name,
                shape=shape,
                dtype=batch_ref.dtype,
                persistable=False,
            )

            parent_block.append_op(
                type="fill_constant_batch_size_like",
                inputs={'Input': [batch_ref]},
                outputs={'Out': [boot_var]},
                attrs={
                    'value': init_value,
                    'shape': boot_var.shape,
                    'dtype': boot_var.dtype,
                    'input_dim_idx': ref_batch_dim_idx,
                    'output_dim_idx': init_batch_dim_idx,
                },
            )
Y
Yu Yang 已提交
654 655 656 657

            return self.memory(init=boot_var)
        else:
            pre_mem = self.helper.create_variable(
658 659 660
                name=unique_name.generate_with_ignorable_key(
                    "@".join([self.helper.name, "mem"])
                ),
F
fengjiayi 已提交
661
                dtype=init.dtype,
662 663 664 665 666
                shape=init.shape,
            )
            self.memories[pre_mem.name] = StaticRNNMemoryLink(
                init=init, pre_mem=pre_mem
            )
Y
Yu Yang 已提交
667 668 669
            return pre_mem

    def step_input(self, x):
C
chengduo 已提交
670 671 672 673 674 675 676 677
        """
        Mark a sequence as a StaticRNN input.

        Args:
            x(Variable): The input sequence, the shape of x
                should be [seq_len, ...].

        Returns:
678 679 680 681 682
            Variable: The current time step data in the input sequence.

        Examples:
            .. code-block:: python

683
                import paddle
684 685 686 687
                import paddle.fluid as fluid
                import paddle.fluid.layers as layers

                vocab_size, hidden_size=10000, 200
688
                paddle.enable_static()
689 690 691 692 693 694 695 696
                x = fluid.data(name="x", shape=[None, 1, 1], dtype='int64')
                # create word sequence
                x_emb = layers.embedding(
                        input=x,
                        size=[vocab_size, hidden_size],
                        dtype='float32',
                        is_sparse=False)
                # transform batch size to dim 1
697
                x_emb = paddle.transpose(x_emb, perm=[1, 0, 2])
698 699 700 701 702 703 704 705 706 707

                rnn = fluid.layers.StaticRNN()
                with rnn.step():
                        # mark created x_emb as input, each step process a word
                        word = rnn.step_input(x_emb)
                        # create prev memory parameter, batch size comes from word
                        prev = rnn.memory(shape=[-1, hidden_size], batch_ref = word)
                        hidden = fluid.layers.fc(input=[word, prev], size=hidden_size, act='relu')
                        # use hidden to update prev
                        rnn.update_memory(prev, hidden)
708

C
chengduo 已提交
709
        """
Y
Yu Yang 已提交
710
        self._assert_in_rnn_block_('step_input')
711
        check_type(x, "x", Variable, "fluid.layers.StaticRNN.step_input")
Y
Yu Yang 已提交
712
        if self.seq_len is None:
Y
Yu Yang 已提交
713
            self.seq_len = x.shape[0]
714
        elif x.shape[0] != -1 and self.seq_len != x.shape[0]:
Y
Yu Yang 已提交
715 716
            raise ValueError("Static RNN only take fix seq_len input")

717 718 719
        ipt = self.helper.create_variable(
            name=x.name, dtype=x.dtype, shape=list(x.shape[1:]), type=x.type
        )
Y
Yu Yang 已提交
720 721 722 723
        self.inputs.append(ipt)
        return ipt

    def step_output(self, o):
C
chengduo 已提交
724 725 726 727 728 729 730 731
        """
        Mark a sequence as a StaticRNN output.

        Args:
            o(Variable): The output sequence.

        Returns:
            None.
732 733 734 735

        Examples:
            .. code-block:: python

736
                import paddle
737 738 739 740
                import paddle.fluid as fluid
                import paddle.fluid.layers as layers

                vocab_size, hidden_size=10000, 200
741
                paddle.enable_static()
742 743 744 745 746 747 748 749
                x = fluid.data(name="x", shape=[None, 1, 1], dtype='int64')
                # create word sequence
                x_emb = layers.embedding(
                        input=x,
                        size=[vocab_size, hidden_size],
                        dtype='float32',
                        is_sparse=False)
                # transform batch size to dim 1
750
                x_emb = paddle.transpose(x_emb, perm=[1, 0, 2])
751 752 753 754 755 756 757 758 759 760 761 762 763

                rnn = fluid.layers.StaticRNN()
                with rnn.step():
                        # mark created x_emb as input, each step process a word
                        word = rnn.step_input(x_emb)
                        # create prev memory parameter, batch size comes from word
                        prev = rnn.memory(shape=[-1, hidden_size], batch_ref = word)
                        hidden = fluid.layers.fc(input=[word, prev], size=hidden_size, act='relu')
                        # use hidden to update prev
                        rnn.update_memory(prev, hidden)
                        rnn.step_output(hidden)

                result = rnn()
764

C
chengduo 已提交
765
        """
Y
Yu Yang 已提交
766
        self._assert_in_rnn_block_('step_output')
767
        check_type(o, "o", Variable, "fluid.layers.StaticRNN.step_output")
Y
Yu Yang 已提交
768

X
Xin Pan 已提交
769
        tmp_o = self.helper.create_variable_for_type_inference(dtype=o.dtype)
770 771 772 773 774 775
        self.helper.append_op(
            type='rnn_memory_helper',
            inputs={'X': [o]},
            outputs={'Out': tmp_o},
            attrs={'dtype': o.dtype},
        )
Y
Yu Yang 已提交
776

777 778 779 780 781
        out_var = self._parent_block().create_var(
            name=tmp_o.name,
            shape=[self.seq_len] + list(tmp_o.shape),
            dtype=tmp_o.dtype,
        )
Y
Yu Yang 已提交
782 783 784 785

        self.outputs.append(out_var)

    def output(self, *outputs):
C
chengduo 已提交
786 787 788 789
        """
        Mark the StaticRNN output variables.

        Args:
790
            outputs: The output Tensor, can mark multiple variables as output
C
chengduo 已提交
791 792 793

        Returns:
            None
794 795 796 797

        Examples:
            .. code-block:: python

798
                import paddle
799 800 801 802
                import paddle.fluid as fluid
                import paddle.fluid.layers as layers

                vocab_size, hidden_size=10000, 200
803
                paddle.enable_static()
804 805 806 807 808 809 810 811
                x = fluid.data(name="x", shape=[None, 1, 1], dtype='int64')
                # create word sequence
                x_emb = layers.embedding(
                        input=x,
                        size=[vocab_size, hidden_size],
                        dtype='float32',
                        is_sparse=False)
                # transform batch size to dim 1
812
                x_emb = paddle.transpose(x_emb, perm=[1, 0, 2])
813 814 815 816 817 818 819 820 821 822 823 824 825 826

                rnn = fluid.layers.StaticRNN()
                with rnn.step():
                        # mark created x_emb as input, each step process a word
                        word = rnn.step_input(x_emb)
                        # create prev memory parameter, batch size comes from word
                        prev = rnn.memory(shape=[-1, hidden_size], batch_ref = word)
                        hidden = fluid.layers.fc(input=[word, prev], size=hidden_size, act='relu')
                        # use hidden to update prev
                        rnn.update_memory(prev, hidden)
                        # mark each step's hidden and word as output
                        rnn.output(hidden, word)

                result = rnn()
C
chengduo 已提交
827
        """
Y
Yu Yang 已提交
828 829 830 831
        for each in outputs:
            self.step_output(each)

    def update_memory(self, mem, var):
C
chengduo 已提交
832
        """
833
        Update the memory from :code:`mem` to :code:`var`.
C
chengduo 已提交
834 835 836

        Args:
            mem(Variable): the memory variable.
837
            var(Variable): the plain variable generated in RNN block, used to update memory.
T
tianshuo78520a 已提交
838
                           var and mem should have same dims and data type.
C
chengduo 已提交
839 840 841

        Returns:
            None
842

C
chengduo 已提交
843
        """
844 845
        check_type(mem, "mem", Variable, "fluid.layers.StaticRNN.update_memory")
        check_type(var, "var", Variable, "fluid.layers.StaticRNN.update_memory")
Y
Yu Yang 已提交
846 847
        self.memories[mem.name].mem = var

848
    def _parent_block(self):
849
        prog = self.helper.main_program
Y
Yu Yang 已提交
850 851 852 853 854 855 856 857 858 859 860 861 862 863 864
        parent_idx = prog.current_block().parent_idx
        assert parent_idx >= 0
        parent_block = prog.block(parent_idx)
        return parent_block

    def __call__(self, *args, **kwargs):
        if self.status != StaticRNN.AFTER_RNN_BLOCK:
            raise ValueError("RNN output can only be retrieved after rnn block")
        if len(self.outputs) == 0:
            raise ValueError("RNN has no output")
        elif len(self.outputs) == 1:
            return self.outputs[0]
        else:
            return self.outputs

865
    def _complete_op(self):
866 867
        main_program = self.helper.main_program
        rnn_block = main_program.current_block()
868
        parent_block = self._parent_block()
Y
Yu Yang 已提交
869 870 871 872 873 874 875 876 877 878 879 880 881 882

        local_inputs = set()

        for op in rnn_block.ops:
            assert isinstance(op, Operator)
            for oname in op.output_names:
                for out_var_name in op.output(oname):
                    local_inputs.add(out_var_name)

        for var in self.inputs:
            local_inputs.add(var.name)
        for m in self.memories:
            local_inputs.add(m)

C
chengduo 已提交
883 884 885
        # NOTE(zcd): the params have two categories of variables.
        #   - the variables that are the out of StaticRnn.
        #   - the variables that are the parameters of some layers, for example, conv2d.
Y
Yu Yang 已提交
886 887 888 889 890 891 892 893
        params = list()
        for op in rnn_block.ops:
            assert isinstance(op, Operator)
            for iname in op.input_names:
                for in_var_name in op.input(iname):
                    if in_var_name not in local_inputs:
                        params.append(in_var_name)

894 895 896
        parameters = [
            parent_block._find_var_recursive(name) for name in set(params)
        ]
Y
Yu Yang 已提交
897 898

        step_scope = parent_block.create_var(
899 900
            type=core.VarDesc.VarType.STEP_SCOPES
        )
Y
Yu Yang 已提交
901 902 903 904

        inlinks = [parent_block.var(i.name) for i in self.inputs]
        outlinks = self.outputs

C
chengduo 已提交
905
        # NOTE(zcd): the states maybe empty in some case.
Y
Yu Yang 已提交
906 907 908
        boot_memories = []
        pre_memories = []
        memories = []
909
        for _, mem in self.memories.items():
Y
Yu Yang 已提交
910 911
            boot_memories.append(mem.init)
            pre_memories.append(mem.pre_mem.name)
912 913 914
            assert (
                mem.mem is not None
            ), "%s should be updated in every step." % (mem.init.name)
Y
Yu Yang 已提交
915 916
            mem_var = rnn_block.var(mem.mem.name)
            assert isinstance(mem_var, Variable)
X
Xin Pan 已提交
917
            new_mem = self.helper.create_variable_for_type_inference(
918 919 920 921 922 923 924 925
                dtype=mem_var.dtype
            )
            rnn_block.append_op(
                type='rnn_memory_helper',
                inputs={'X': [mem_var]},
                outputs={'Out': [new_mem]},
                attrs={'dtype': mem_var.dtype},
            )
Y
Yu Yang 已提交
926 927 928

            memories.append(new_mem.name)

929 930 931 932 933 934 935 936 937 938 939 940 941 942 943
        parent_block.append_op(
            type='recurrent',
            inputs={
                'inputs': inlinks,
                'initial_states': boot_memories,
                'parameters': parameters,
            },
            outputs={'outputs': outlinks, 'step_scopes': [step_scope]},
            attrs={
                'has_states': len(pre_memories) > 0,
                'ex_states': pre_memories,
                'states': memories,
                'sub_block': rnn_block,
            },
        )
Y
Yu Yang 已提交
944 945


946
# (TODO: Mine) There exists dependency. It will be removed later.
Y
Yang Yang(Tony) 已提交
947 948 949 950
class WhileGuard(BlockGuard):
    def __init__(self, while_op):
        if not isinstance(while_op, While):
            raise TypeError("WhileGuard takes a while op")
951
        super().__init__(while_op.helper.main_program)
Y
Yang Yang(Tony) 已提交
952 953 954 955
        self.while_op = while_op

    def __enter__(self):
        self.while_op.status = While.IN_WHILE_BLOCK
956
        return super().__enter__()
Y
Yang Yang(Tony) 已提交
957 958 959 960 961

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type is not None:
            return False
        self.while_op.status = While.AFTER_WHILE_BLOCK
962
        self.while_op._complete()
963
        return super().__exit__(exc_type, exc_val, exc_tb)
Y
Yang Yang(Tony) 已提交
964 965


966
# (TODO: Mine) There exists dependency. It will be removed later.
967 968 969
def get_inputs_outputs_in_block(
    current_block, inner_inputs, inner_outputs, helper
):
970 971 972 973 974 975 976 977
    """
    Find inputs and outputs in current control flow block.
    :param current_block: Current control flow block.
    :param inner_inputs: Input var name of ops in current block.
    :param inner_outputs: Output var name of ops in current block.
    :return: inner_inputs, inner_outputs
    """

978 979 980 981 982 983 984 985 986 987 988 989 990
    def is_ignore_vars(op, var_name):
        # NOTE(dev): There are some persistable var created in some non-standard API
        # such as "contrib.layers.shuffle_batch". It create a "Seed" used both in
        # Input and Output. This var shall not be considered as a loop_var in
        # control_flow.
        IGNORE_VAR_NAMES = {"shuffle_batch": ["shuffle_batch_seed"]}
        if op.type in IGNORE_VAR_NAMES:
            var_names = IGNORE_VAR_NAMES[op.type]
            for name in var_names:
                if name in var_name:
                    return True
        return False

991 992 993 994 995 996 997 998
    # Step1: update inner_inputs and inner_outputs
    # NOTE: Here assumes that all variables are input or output of Ops,
    # but some variables are created without appendding a real op.
    # For example, in `arr = create_array(dtype)`, `arr` is not a output of a op.
    for op in current_block.ops:
        assert isinstance(op, Operator)
        for iname in op.input_names:
            for in_var_name in op.input(iname):
999
                if in_var_name not in inner_outputs and not is_ignore_vars(
1000 1001
                    op, in_var_name
                ):
1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016
                    inner_inputs.add(in_var_name)

        for oname in op.output_names:
            for out_var_name in op.output(oname):
                inner_outputs.add(out_var_name)

    # Step2: Remove LOD_TENSOR_ARRAY created in current control flow block.
    remove_inner_inputs = set()
    parent_block = helper.main_program.block(current_block.parent_idx)

    for in_var_name in inner_inputs:
        parent_block_var = parent_block._find_var_recursive(in_var_name)
        current_block_var = None
        if current_block.has_var(in_var_name):
            current_block_var = current_block.var(in_var_name)
1017 1018 1019 1020 1021
        if (
            not parent_block_var
            and current_block_var
            and current_block_var.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY
        ):
1022 1023 1024 1025 1026 1027 1028
            remove_inner_inputs.add(in_var_name)

    inner_inputs = inner_inputs - remove_inner_inputs

    return inner_inputs, inner_outputs


1029
# (TODO: Mine) There exists dependency. It will be removed later.
1030
class While:
X
Xin Pan 已提交
1031
    """
1032
    :api_attr: Static Graph
1033

1034
    while loop control flow. Repeat while body until cond is False.
X
Xin Pan 已提交
1035

1036 1037 1038 1039
    Note:
        A new OP :ref:`api_fluid_layers_while_loop` is highly recommended instead of ``While`` if the shape of parameter ``cond`` is [1].
        OP :ref:`api_fluid_layers_while_loop` is easier to use and is called with less code but does the same thing as ``While`` .

1040 1041 1042 1043 1044 1045
    Notice:
        Local variables created in ``While`` are similar to that created in while of C++, and cannot be referenced externally.
        As a result, they cannot be obtained through ``fetch_list`` of ``Executor``. If you would like to access the variable
        out of ``while`` , PaddlePaddle provides ``assign`` API to assign local variables to external. Please refer to example
        code 2 or refer to `issue#22724 <https://github.com/PaddlePaddle/Paddle/issues/22724>`_.

X
Xin Pan 已提交
1046
    Args:
1047
        cond(Variable): A Tensor whose data type is bool controlling whether to continue looping.
G
guofei 已提交
1048
        is_test(bool, optional): A flag indicating whether execution is in test phase. Default value is False.
1049
        name(str, optional): The default value is None.  Normally there is no need for user to set this property.  For more information, please refer to :ref:`api_guide_Name` .
X
Xin Pan 已提交
1050

1051
    Examples 1:
X
Xin Pan 已提交
1052
          .. code-block:: python
1053

1054
            import paddle.fluid as fluid
1055 1056 1057 1058 1059
            import numpy as np

            i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0)           # loop counter

            loop_len = fluid.layers.fill_constant(shape=[1],dtype='int64', value=10)    # loop length
1060

L
LiYuRio 已提交
1061
            cond = paddle.less_than(x=i, y=loop_len)
1062
            while_op = fluid.layers.While(cond=cond)
1063
            with while_op.block():
1064
                i = paddle.increment(x=i, value=1)
L
LiYuRio 已提交
1065
                paddle.assign(paddle.less_than(x=i, y=loop_len), cond)
1066 1067 1068 1069 1070

            exe = fluid.Executor(fluid.CPUPlace())
            exe.run(fluid.default_startup_program())

            res = exe.run(fluid.default_main_program(), feed={}, fetch_list=[i])
1071 1072 1073 1074 1075 1076
            print(res) # [array([10])]


    Examples 2:
          .. code-block:: python

L
LiYuRio 已提交
1077
            import paddle
1078 1079 1080
            import paddle.fluid as fluid
            import numpy as np

1081
            paddle.enable_static()
1082 1083 1084 1085 1086 1087
            i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0)
            loop_len = fluid.layers.fill_constant(shape=[1], dtype='int64', value=10)
            one = fluid.layers.fill_constant(shape=[1], dtype='float32', value=1)
            data = fluid.data(name='data', shape=[1], dtype='float32')
            sums = fluid.layers.fill_constant(shape=[1], dtype='float32', value=0)  # Define the variable to be obtained ouside of While, which name should be different from the variable inside the While to be obtained

L
LiYuRio 已提交
1088
            cond = paddle.less_than(x=i, y=loop_len)
1089 1090 1091 1092
            while_op = fluid.layers.While(cond=cond)
            with while_op.block():
                sums_tensor = fluid.layers.elementwise_add(x=data, y=data)
                fluid.layers.assign(sums_tensor, sums)  # Update the value of sums_tensor defined in While to the sums which defined outside of While through layers.assign
1093
                i = paddle.increment(x=i, value=1)
1094
                data = fluid.layers.elementwise_add(x=data, y=one)
L
LiYuRio 已提交
1095
                paddle.assign(paddle.less_than(x=i, y=loop_len), cond)
1096 1097 1098 1099 1100 1101

            feed_data = np.ones(1).astype('float32')
            exe = fluid.Executor(fluid.CPUPlace())
            exe.run(fluid.default_startup_program())
            res = exe.run(fluid.default_main_program(), feed={'data': feed_data}, fetch_list=sums)
            print(res[0])  # [2.]    # Because the data in While does not update the value outside the While, the value of sums is [2.] after the loop
X
Xin Pan 已提交
1102 1103
    """

Y
Yang Yang(Tony) 已提交
1104 1105 1106 1107
    BEFORE_WHILE_BLOCK = 0
    IN_WHILE_BLOCK = 1
    AFTER_WHILE_BLOCK = 2

C
chengduo 已提交
1108
    def __init__(self, cond, is_test=False, name=None):
1109
        self.helper = LayerHelper("while", name=name)
Y
Yang Yang(Tony) 已提交
1110
        self.status = While.BEFORE_WHILE_BLOCK
1111
        check_variable_and_dtype(cond, 'cond', ['bool'], 'fluid.layers.While')
Y
Yang Yang(Tony) 已提交
1112
        if reduce(lambda a, b: a * b, cond.shape, 1) != 1:
1113
            raise TypeError(
1114 1115 1116 1117
                "condition expected shape as [1], but given shape as {0}.".format(
                    list(cond.shape)
                )
            )
Y
Yang Yang(Tony) 已提交
1118
        self.cond_var = cond
C
chengduo 已提交
1119
        self.is_test = is_test
Y
Yang Yang(Tony) 已提交
1120 1121 1122 1123

    def block(self):
        return WhileGuard(self)

1124
    def _complete(self):
Y
Yang Yang(Tony) 已提交
1125 1126
        main_program = self.helper.main_program
        while_block = main_program.current_block()
1127
        parent_block = main_program.block(
1128 1129
            main_program.current_block().parent_idx
        )
Y
Yang Yang(Tony) 已提交
1130 1131 1132

        inner_outputs = {self.cond_var.name}
        x_name_list = set()
1133
        x_name_list, inner_outputs = get_inputs_outputs_in_block(
1134 1135
            while_block, x_name_list, inner_outputs, self.helper
        )
Y
Yang Yang(Tony) 已提交
1136 1137 1138

        out_vars = []
        for inner_out_name in inner_outputs:
X
Xin Pan 已提交
1139 1140 1141
            inner_var = parent_block._find_var_recursive(inner_out_name)
            if inner_var:
                out_vars.append(inner_var)
Y
Yang Yang(Tony) 已提交
1142

1143
        x_name_list |= set(map(lambda x: x.name, out_vars))
1144 1145 1146
        # NOTE(dev): cond_var has been contained in Input('Condition'), so
        # we remove it from Input('X')
        x_name_list -= {self.cond_var.name}
1147

Y
Yang Yang(Tony) 已提交
1148
        step_scope = parent_block.create_var(
1149 1150
            type=core.VarDesc.VarType.STEP_SCOPES
        )
Y
Yang Yang(Tony) 已提交
1151 1152 1153 1154

        parent_block.append_op(
            type='while',
            inputs={
1155 1156 1157 1158 1159
                'X': [
                    parent_block._var_recursive(x_name)
                    for x_name in x_name_list
                ],
                'Condition': [self.cond_var],
1160
            },
1161 1162 1163
            outputs={'Out': out_vars, 'StepScopes': [step_scope]},
            attrs={'sub_block': while_block, "is_test": self.is_test},
        )
Y
Yang Yang(Tony) 已提交
1164 1165


1166
support_ret_buildin_type = (bool, float, int)
1167 1168


1169
# (TODO: Mine) There exists dependency. It will be removed later.
1170
def assign_skip_lod_tensor_array(input, output):
1171
    """
1172
    Assign input to output, but skip the process of copying LoDTensorArray unless it's created in while_block.
1173
    """
1174 1175

    def has_shape_diff(x_var, y_var):
1176 1177
        if len(x_var.shape) != len(y_var.shape):
            return True
1178
        for x_dim, y_dim in zip(x_var.shape, y_var.shape):
1179 1180
            if x_dim != y_dim and -1 not in [x_dim, y_dim]:
                return True
1181 1182
        return False

1183
    if not isinstance(input, (Variable, core.VarBase)):
1184
        if isinstance(output, Variable) and isinstance(
1185 1186
            input, support_ret_buildin_type
        ):
1187 1188 1189
            assign(input, output)
        else:
            output = input
1190 1191
        return

1192 1193
    if input.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY:
        main_program = input.block.program
1194
        parent_block = main_program.block(
1195 1196
            main_program.current_block().parent_idx
        )
1197 1198 1199
        if parent_block and not parent_block._find_var_recursive(input.name):
            assign(input, output)
    else:
1200 1201 1202 1203 1204
        if (
            isinstance(output, Variable)
            and isinstance(input, Variable)
            and has_shape_diff(input, output)
        ):
1205
            warnings.warn(
1206 1207 1208 1209
                "In dy2static mode, we attemp to assign a variable with shape {} into a variable with shape{}, which is not always right.".format(
                    input.shape, output.shape
                )
            )
1210
        assign(input, output)
1211 1212


1213
# (TODO: Mine) There exists dependency (jit.dy2static.convert_operators). It will be removed later.
G
guofei 已提交
1214
def while_loop(cond, body, loop_vars, is_test=False, name=None):
G
guofei 已提交
1215
    """
1216 1217
    :api_attr: Static Graph

G
guofei 已提交
1218 1219
    while_loop is one of the control flows. Repeats while_loop `body` until `cond` returns False.

1220 1221 1222 1223
    Notice:
        Local variables defined in ``body`` cannot be obtained through ``fetch_list`` of ``Executor`` , variables should
        be defined outside ``body`` and placed in ``loop_vars`` for looping, then these variables can be fetched by ``fetch_list`` .

G
guofei 已提交
1224
    Args:
1225
        cond(Callable): A callable returning a boolean tensor controlling whether to continue looping. And ``cond`` takes
1226
            as many arguments as ``loop_vars`` .
1227 1228 1229
        body(Callable): A callable returning a tuple or list of tensors or LoDTensorArrays of the same arity
            (length and structure) and types as ``loops_vars`` . And ``body`` takes as many arguments as ``loop_vars`` .
        loop_vars(list|tuple): A list or tuple of tensors or LoDTensorArrays that is passed to both ``cond`` and ``body`` .
G
guofei 已提交
1230
        is_test(bool, optional): A flag indicating whether execution is in test phase. Default value is False.
G
guofei 已提交
1231 1232
        name(str, optional): Normally there is no need for users to set this property. For more information, please
            refer to :ref:`api_guide_Name`. Default is None.
1233

G
guofei 已提交
1234
    Returns:
C
Chen Long 已提交
1235
        A list or tuple of Tensors or LoDTensorArrays which returned by ``body`` .
G
guofei 已提交
1236 1237 1238 1239

    Examples:
        .. code-block:: python

1240 1241 1242
            import paddle
            paddle.enable_static()

1243 1244
            def cond(i, ten):
                return i < ten
G
guofei 已提交
1245

1246 1247 1248
            def body(i, ten):
                i = i + 1
                return [i, ten]
G
guofei 已提交
1249

C
Chen Long 已提交
1250 1251 1252 1253 1254 1255
            main_program = paddle.static.default_main_program()
            startup_program = paddle.static.default_startup_program()
            with paddle.static.program_guard(main_program, startup_program):
                i = paddle.full(shape=[1], fill_value=0, dtype='int64')     # loop counter
                ten = paddle.full(shape=[1], fill_value=10, dtype='int64')  # loop length
                i, ten = paddle.static.nn.while_loop(cond, body, [i, ten])
1256

C
Chen Long 已提交
1257
                exe = paddle.static.Executor(paddle.CPUPlace())
1258
                res = exe.run(main_program, feed={}, fetch_list=[i])
G
guofei 已提交
1259 1260 1261 1262 1263 1264 1265 1266
                print(res) # [array([10])]
    """
    helper = LayerHelper('while_loop', **locals())

    if not callable(cond):
        raise TypeError("cond in while_loop should be callable")
    if not callable(body):
        raise TypeError("body in while_loop should be callable")
1267
    check_type(loop_vars, 'loop_vars', (list, tuple), 'fluid.layers.while_loop')
G
guofei 已提交
1268 1269 1270 1271
    if len(loop_vars) == 0:
        raise ValueError("loop_vars in while_loop should not be empty")

    pre_cond = cond(*loop_vars)
1272 1273 1274
    check_variable_and_dtype(
        pre_cond, 'var of cond returned', ['bool'], 'fluid.layers.while_loop'
    )
G
guofei 已提交
1275 1276
    if reduce(lambda a, b: a * b, pre_cond.shape, 1) != 1:
        raise TypeError(
1277
            "the shape of the variable returned by cond should be [1],"
1278 1279
            "but given shape as {0}.".format(list(pre_cond.shape))
        )
G
guofei 已提交
1280

J
Jiabin Yang 已提交
1281
    if _non_static_mode():
1282
        now_cond = pre_cond.numpy()[0]
1283
        while now_cond:
1284 1285 1286 1287 1288 1289
            output_vars = body(*loop_vars)
            if not isinstance(output_vars, (list, tuple)):
                output_vars = [output_vars]
            if len(output_vars) != len(loop_vars):
                raise ValueError(
                    "body in while_loop should return the same arity "
1290 1291
                    "(length and structure) and types as loop_vars"
                )
1292
            now_cond = cond(*output_vars).numpy()[0]
1293
            map_structure(assign_skip_lod_tensor_array, output_vars, loop_vars)
1294 1295
        return loop_vars

G
guofei 已提交
1296
    while_loop_block = While(pre_cond, is_test, name)
1297
    has_mutable_vars_in_loop = hold_mutable_vars(loop_vars)
G
guofei 已提交
1298
    with while_loop_block.block():
1299 1300 1301 1302 1303 1304 1305 1306 1307
        # If a variable with mutable type is included in loop_vars, like `dict/list`,
        # modifying it in the body function will cause origin variable to be modified
        # synchronously. This will raise an assignment error out of while block.
        # Here we make a copy of the mutable vars to avoid this problem.
        if has_mutable_vars_in_loop:
            new_loop_vars = copy_mutable_vars(loop_vars)
            output_vars = body(*new_loop_vars)
        else:
            output_vars = body(*loop_vars)
1308 1309
        if not isinstance(output_vars, (list, tuple)):
            output_vars = [output_vars]
1310
        try:
1311
            loop_vars = _deal_with_undefined_var(output_vars, loop_vars)
1312 1313
            assert_same_structure(output_vars, loop_vars, check_types=False)
        except ValueError as e:
1314 1315
            raise ValueError(
                "body in while_loop should return the same arity "
1316 1317
                "(length and structure) as loop_vars: {0}".format(e)
            )
1318
        now_cond = cond(*output_vars)
1319
        map_structure(assign_skip_lod_tensor_array, output_vars, loop_vars)
G
guofei 已提交
1320 1321 1322 1323
        assign(now_cond, pre_cond)
    return loop_vars


1324
# (TODO: Mine) There exists dependency. It will be removed later.
1325
def _deal_with_undefined_var(output_vars, loop_vars):
1326 1327 1328 1329 1330 1331 1332
    """Deal with undefined var cases, We create undefined variable based on the results of body().
    In Dy2Static, we use undefined var to represent the var created in control flow. This function
    expand the loop_vars and replace original loop_vars.
    1. UndefinedVar = Variable      # create a variable
    2. UndefinedVar = None          # create a undefined var with RETURN_NO_VALUE_MAGIC_NUM
    3. UndefinedVar = List(int)     # create a list of variable
    4. UndefinedVar = value         # create a variable
1333
    """
1334
    from paddle.jit.dy2static.utils import (
1335 1336 1337
        UndefinedVar,
        create_undefined_variable,
    )
1338 1339

    def create_var_like(o_var):
1340 1341 1342 1343
        if (
            isinstance(o_var, (Variable,) + support_ret_buildin_type)
            or o_var is None
        ):
1344
            return create_undefined_variable()
1345
        if is_sequence(o_var):
1346
            """
1347 1348 1349
            Create a complex container class inside the body of while, including Python list and python Dict
            """
            return map_structure(lambda x: create_undefined_variable(), o_var)
1350 1351 1352 1353 1354 1355 1356 1357 1358 1359 1360 1361 1362

    if len(output_vars) != len(loop_vars):
        raise ValueError("The length of loop_vars should be the same.")

    results = []
    for o_var, l_var in zip(output_vars, loop_vars):
        if isinstance(l_var, UndefinedVar) or l_var is None:
            results.append(create_var_like(o_var))
        else:
            results.append(l_var)
    return results


Y
Yu Yang 已提交
1363
class ConditionalBlockGuard(BlockGuard):
F
fengjiayi 已提交
1364
    """
1365 1366 1367
    ConditionalBlockGuard is derived from BlockGuard. It is dedicated for
    holding a ConditionalBlock, and helping users entering and exiting the
    ConditionalBlock via Python's 'with' keyword. However, ConditionalBlockGuard
F
fengjiayi 已提交
1368 1369 1370
    is generally an internal component of IfElse, users should not use it directly.
    """

Y
Yu Yang 已提交
1371
    def __init__(self, block):
1372
        check_type(block, "block", ConditionalBlock, "ConditionalBlockGuard")
1373
        super().__init__(block.helper.main_program)
Y
Yu Yang 已提交
1374 1375 1376
        self.block = block

    def __enter__(self):
1377
        return super().__enter__()
Y
Yu Yang 已提交
1378 1379 1380

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.block.complete()
1381
        return super().__exit__(exc_type, exc_val, exc_tb)
Y
Yu Yang 已提交
1382 1383


1384
class ConditionalBlock:
Y
Yan Chunwei 已提交
1385 1386 1387 1388 1389 1390 1391 1392
    '''
    **ConditionalBlock**

    ConditionalBlock is an operator that bind a block to a specific condition,
    if the condition matches, the corresponding block will be executed.

    Args:
        inputs (Variable): bool conditions.
T
tianshuo78520a 已提交
1393
        is_scalar_condition (bool): whether the branch is controlled by a scalar.
Y
Yan Chunwei 已提交
1394 1395 1396 1397 1398
        name(str): name of this ConditionalBlock.

    Examples:
        .. code-block:: python

L
LiYuRio 已提交
1399
             import paddle
1400
             import paddle.fluid as fluid
L
LiYuRio 已提交
1401
             cond = paddle.less_than(x=label, y=limit)
Y
Yan Chunwei 已提交
1402 1403 1404 1405 1406 1407 1408 1409 1410 1411
             true_image, false_image = layers.split_lod_tensor(
                 input=image, mask=cond)
             true_cond = layers.ConditionalBlock([true_image])

             with true_cond.block():
                 ...
             with false_cond.block():
                 ...
    '''

1412
    def __init__(self, inputs, is_scalar_condition=False, name=None):
Y
Yu Yang 已提交
1413
        for each_input in inputs:
1414
            check_type(each_input, "input", Variable, "ConditionalBlock")
Y
Yu Yang 已提交
1415
        self.inputs = inputs
1416
        self.is_scalar_condition = is_scalar_condition
1417
        self.helper = LayerHelper('conditional_block', name=name)
Y
Yu Yang 已提交
1418 1419 1420 1421 1422 1423 1424 1425 1426 1427

    def block(self):
        return ConditionalBlockGuard(self)

    def complete(self):
        inside_block = self.helper.main_program.current_block()
        parent_block = self.helper.main_program.block(inside_block.parent_idx)

        intermediate = set()
        params = set()
1428 1429 1430
        params, intermediate = get_inputs_outputs_in_block(
            inside_block, params, intermediate, helper=self.helper
        )
Y
Yu Yang 已提交
1431

1432 1433 1434
        # Todo(liym27) Here assume that all params are in recursive parent block
        # but when minimize() called in control flow, some params may be in
        # conditional grad block
Y
Yu Yang 已提交
1435
        param_list = [
W
Wu Yi 已提交
1436
            parent_block._var_recursive(each_name) for each_name in params
Y
Yu Yang 已提交
1437 1438
        ]

X
Xin Pan 已提交
1439 1440 1441 1442 1443
        out_list = []
        for inner_out_name in intermediate:
            inner_var = parent_block._find_var_recursive(inner_out_name)
            if inner_var:
                out_list.append(inner_var)
Y
Yu Yang 已提交
1444 1445

        step_scope = parent_block.create_var(
1446 1447
            type=core.VarDesc.VarType.STEP_SCOPES
        )
1448
        conditional_block_op = parent_block.append_op(
Y
Yu Yang 已提交
1449 1450
            type='conditional_block',
            inputs={
1451 1452
                'Cond': self.inputs,
                'Input': param_list,
Y
Yu Yang 已提交
1453
            },
1454
            outputs={'Out': out_list, 'Scope': [step_scope]},
1455 1456
            attrs={
                'sub_block': inside_block,
1457 1458 1459
                'is_scalar_condition': self.is_scalar_condition,
            },
        )
1460

1461
        if self.need_append_conditional_block_grad(inside_block):
1462 1463 1464
            self.append_conditional_block_grad(
                parent_block, inside_block, conditional_block_op
            )
1465 1466 1467

    def need_append_conditional_block_grad(self, inside_block):
        grad_sub_block_idx = inside_block.backward_block_idx
1468
        inside_block_idx = inside_block.idx
1469

1470 1471
        # if inside_block have grad_block and grad_block is not itself,
        # we will append conditional block grad.
1472 1473 1474
        return (
            grad_sub_block_idx != -1 and grad_sub_block_idx != inside_block_idx
        )
1475

1476 1477 1478
    def append_conditional_block_grad(
        self, parent_block, inside_block, conditional_block_op
    ):
1479 1480 1481 1482 1483 1484 1485 1486 1487 1488 1489 1490 1491 1492 1493 1494 1495 1496 1497 1498 1499 1500 1501 1502 1503 1504 1505 1506 1507 1508 1509 1510 1511 1512 1513
        '''
        Append op `conditional_block_grad` manually.
        When `optimizer.minimize/append_backward` is called in Paddle control flow,
        grad ops will be appended before appending op `conditional_block` so that
        op `conditional_block_grad` can't be appended when calling
        `optimizer.minimize/append_backward`. After appending op `conditional_block`,
        `conditional_block_grad` is appended manually.

        Args:
            parent_block (Block): The block that `conditional_block_op` blongs to.
            inside_block (Block): The sub block of `conditional_block_op`.
            conditional_block_op (Operator): The forward op conditional_block.
        '''

        grad_sub_block_idx = inside_block.backward_block_idx
        grad_sub_block = self.helper.main_program.block(grad_sub_block_idx)

        intermediate = set()
        params = set()

        for each_op in grad_sub_block.ops:
            assert isinstance(each_op, Operator)
            for iname in each_op.input_names:
                for in_var_name in each_op.input(iname):
                    if in_var_name not in intermediate:
                        params.add(in_var_name)

            for oname in each_op.output_names:
                for out_var_name in each_op.output(oname):
                    intermediate.add(out_var_name)

        param_list = []
        for inner_input_name in params:
            inner_var = parent_block._find_var_recursive(inner_input_name)
            if inner_var:
1514
                param_list.append(inner_var.name)
1515 1516

        grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
1517 1518
            conditional_block_op.desc, set(), [grad_sub_block.desc]
        )
1519 1520 1521 1522 1523 1524 1525 1526 1527

        # append op_desc in grad_op_descs to target_block
        op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
        backward = core.op_proto_and_checker_maker.OpRole.Backward
        new_op_desc = parent_block.desc.append_op()
        new_op_desc.copy_from(grad_op_desc[0])
        new_op_desc._set_attr(op_role_attr_name, backward)
        # set input and output manually
        new_op_desc.set_input('Input', param_list)
1528 1529 1530
        new_op_desc.set_output(
            'Input@GRAD', [param + "@GRAD" for param in param_list]
        )
1531 1532 1533

        new_vars = set()
        for grad_var_name in new_op_desc.output_arg_names():
1534 1535 1536 1537
            if (
                grad_sub_block.desc.has_var_recursive(grad_var_name.encode())
                or grad_var_name == core.empty_var_name()
            ):
1538
                continue
1539
            grad_sub_block.desc.var(grad_var_name.encode())
1540 1541 1542 1543 1544 1545 1546 1547 1548 1549 1550 1551 1552 1553
            new_vars.add(grad_var_name)
            if grad_var_name not in op_grad_to_var:
                continue

        # infer_shape and infer_type
        new_op_desc.infer_var_type(grad_sub_block.desc)
        new_op_desc.infer_shape(grad_sub_block.desc)

        for arg in new_op_desc.output_arg_names():
            if arg in new_vars:
                _infer_var_data_type_shape_(arg, grad_sub_block)

        self.helper.main_program._sync_with_cpp()

1554

1555 1556 1557 1558 1559 1560 1561 1562 1563 1564 1565 1566 1567 1568 1569 1570 1571 1572
def _to_sequence_except_dict(x):
    """
    In this function, dict is not viewed as sequence.
    """
    if isinstance(x, dict):
        return [x]
    return to_sequence(x)


def _is_sequence_except_dict(x):
    """
    In this function, dict is not viewed as sequence.
    """
    if isinstance(x, dict):
        return False
    return is_sequence(x)


1573
def expand_undefined_var(nest1, nest2, names):
1574 1575 1576 1577
    """TODO: make this function recursively.
    nest1: Var1, (UndefinedVar, [1,2,3])
    nest2: Var2, ([1,2,3,4], UndefinedVar)
    In this case, we should not expand recursively.
1578
    """
1579
    from paddle.jit.dy2static.utils import UndefinedVar
1580
    from paddle.jit.dy2static.return_transformer import (
1581 1582
        RETURN_VALUE_PREFIX,
    )
1583 1584

    def pack_undefined_var_as(seq):
1585 1586 1587
        return pack_sequence_as(
            seq, [UndefinedVar("padding") for i in flatten(seq)]
        )
1588

1589
    def map_fn(n1, n2, name, order):
1590 1591 1592
        if not name.startswith(RETURN_VALUE_PREFIX) and (
            isinstance(n1, UndefinedVar) or n1 is None
        ):
1593 1594 1595 1596 1597 1598
            if n1 is None and n2 is not None:
                if order == 0:
                    warnings.warn(
                        "In cond : Var '{}' or part of it is set differently in ifelse branchs, "
                        "<{}, {}> in true branch and <{}, {}> in false branch. Set var to "
                        "'None' in ifelse block might lead to error.".format(
1599 1600 1601
                            name, type(n1), n1, type(n2), n2
                        )
                    )
1602 1603 1604 1605 1606
                else:
                    warnings.warn(
                        "In cond : Var '{}' or part of it is set differently in ifelse branchs, "
                        "<{}, {}> in true branch and <{}, {}> in false branch. Set var to "
                        "'None' in ifelse block might lead to error.".format(
1607 1608 1609
                            name, type(n2), n2, type(n1), n1
                        )
                    )
1610 1611 1612 1613
            return pack_undefined_var_as(n2)
        return n1

    nest1_out = list(
1614 1615
        map(
            map_fn,
1616 1617 1618 1619
            _to_sequence_except_dict(nest1),
            _to_sequence_except_dict(nest2),
            _to_sequence_except_dict(names),
            [0 for i in _to_sequence_except_dict(names)],
1620 1621
        )
    )
1622
    nest2_out = list(
1623 1624
        map(
            map_fn,
1625 1626 1627 1628
            _to_sequence_except_dict(nest2),
            _to_sequence_except_dict(nest1),
            _to_sequence_except_dict(names),
            [1 for i in _to_sequence_except_dict(names)],
1629 1630
        )
    )
1631
    if not _is_sequence_except_dict(nest1):
1632
        nest1_out = nest1_out[0]
1633
    if not _is_sequence_except_dict(nest2):
1634
        nest2_out = nest2_out[0]
1635 1636 1637
    return nest1_out, nest2_out


1638
class Switch:
Q
qiaolongfei 已提交
1639
    """
1640
    :api_attr: Static Graph
Q
qiaolongfei 已提交
1641

1642 1643 1644 1645 1646
    This class is used to implement Switch branch control function.
    Switch branch contains several case branches and one default branch.
    Switch control flow checks whether the case branch conditions are satisfied in turn,
    and only executes the statement after the first case branch that satisfies the conditions.
    If there is no case branch that satisfies the condition,
1647 1648
    only the statement following the default branch is executed.

1649 1650 1651 1652
    Note:
        A new OP :ref:`api_fluid_layers_case` is highly recommended instead of ``Switch`` if the shape of parameter ``cond`` is [1].
        OP :ref:`api_fluid_layers_case` is easier to use and is called with less code but does the same thing as ``Switch`` .

1653
    Member Functions:
1654
        case(condition): The case branch of Switch whose parameter cond is a scalar Variable of bool type. Only if the cond of the current case branch is True and the cond of the previous case branch is False, the statement after the case branch will be executed, and the statement after the case branch will not be executed.
1655

1656 1657 1658 1659 1660
        default(): The default branch of Switch. When cond of all case branches is False, the statement after default branch is executed.

    Case and default functions can only be used inside the scope of Switch, as shown below:

    .. code-block:: python
1661

1662 1663 1664 1665 1666 1667 1668 1669 1670
        '''
        with fluid.layers.Switch() as switch:
            with switch.case(cond1):
                i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=1)
            with switch.case(cond2):
                i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=2)
            with switch.default():
                i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0)
        '''
Q
qiaolongfei 已提交
1671

1672 1673
    Args:
        name(str, optional): The default value is None.  Normally there is no need for user to set this property.  For more information, please refer to :ref:`api_guide_Name` .
Q
qiaolongfei 已提交
1674 1675 1676

    Examples:
        .. code-block:: python
1677

1678
            import paddle
1679
            import paddle.fluid as fluid
Q
qiaolongfei 已提交
1680

1681
            lr = paddle.static.create_global_var(
Q
qiaolongfei 已提交
1682 1683 1684 1685 1686
                shape=[1],
                value=0.0,
                dtype='float32',
                persistable=True,
                name="learning_rate")
1687
            zero_var = fluid.layers.fill_constant(
1688
                shape=[1], dtype='float32', value=0.0)
1689
            one_var = fluid.layers.fill_constant(
Q
qiaolongfei 已提交
1690
                shape=[1], dtype='float32', value=1.0)
1691
            two_var = fluid.layers.fill_constant(
1692
                shape=[1], dtype='float32', value=2.0)
1693

1694
            global_step = fluid.layers.autoincreased_step_counter(counter_name='@LR_DECAY_COUNTER@', begin=0, step=1)
Q
qiaolongfei 已提交
1695 1696

            with fluid.layers.control_flow.Switch() as switch:
Q
qiaolongfei 已提交
1697
                with switch.case(global_step == zero_var):
1698
                    fluid.layers.assign(input=one_var, output=lr)
Q
qiaolongfei 已提交
1699
                with switch.default():
1700
                    fluid.layers.assign(input=two_var, output=lr)
Q
qiaolongfei 已提交
1701

1702 1703 1704 1705 1706
            exe = fluid.Executor(fluid.CPUPlace())
            exe.run(fluid.default_startup_program())

            res = exe.run(fluid.default_main_program(), feed={}, fetch_list=[lr])
            print(res) # [array([1.], dtype=float32)]
Q
qiaolongfei 已提交
1707 1708
    """

1709 1710 1711 1712 1713 1714 1715 1716 1717
    def __init__(self, name=None):
        self.helper = LayerHelper('switch', name=name)
        self.inside_scope = False
        self.pre_not_conditions = []

    def case(self, condition):
        if not self.inside_scope:
            raise ValueError("case should be called inside with")

1718
        check_variable_and_dtype(
1719 1720 1721 1722 1723
            condition,
            'condition',
            ['bool'],
            'the member function case of fluid.layers.Switch',
        )
1724

1725 1726
        if len(self.pre_not_conditions) == 0:
            cond_block = ConditionalBlock([condition], is_scalar_condition=True)
2
201716010711 已提交
1727
            not_cond = paddle.logical_not(x=condition)
1728 1729 1730 1731
            self.pre_not_conditions.append(not_cond)
        else:
            pre_cond_num = len(self.pre_not_conditions)
            pre_not_cond = self.pre_not_conditions[pre_cond_num - 1]
1732
            new_not_cond = paddle.logical_and(
2
201716010711 已提交
1733
                x=pre_not_cond, y=paddle.logical_not(x=condition)
1734
            )
1735 1736
            self.pre_not_conditions.append(new_not_cond)
            cond_block = ConditionalBlock(
1737
                [paddle.logical_and(x=pre_not_cond, y=condition)],
1738 1739
                is_scalar_condition=True,
            )
1740 1741 1742 1743 1744 1745 1746 1747 1748

        return ConditionalBlockGuard(cond_block)

    def default(self):
        pre_cond_num = len(self.pre_not_conditions)
        if pre_cond_num == 0:
            raise ValueError("there should be at least one condition")
        cond_block = ConditionalBlock(
            [self.pre_not_conditions[pre_cond_num - 1]],
1749 1750
            is_scalar_condition=True,
        )
1751 1752 1753 1754 1755 1756 1757 1758 1759 1760 1761 1762 1763 1764 1765 1766
        return ConditionalBlockGuard(cond_block)

    def __enter__(self):
        """
        set flag that now is inside switch.block {}
        :return:
        """
        self.inside_scope = True
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.inside_scope = False
        if exc_type is not None:
            return False  # re-raise exception

        return True